Commit 89279ba

scan support
1 parent 5e96b43 commit 89279ba

File tree

8 files changed: +731 −4 lines changed


exir/capture/_unlift.py

Lines changed: 43 additions & 0 deletions
@@ -97,6 +97,49 @@ def _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict):
             _unlift(
                 body_gm, inp_pos_to_buffer_name_for_submod, in_spec, None, state_dict
             )
+        if node.op == "call_function" and node.target.__name__ == "scan":
+            # scan signature: scan(combine_fn, init, xs, additional_inputs)
+            # - combine_fn: GraphModule for the scan body
+            # - init: list of initial carry tensors
+            # - xs: list of input tensors to scan over
+            # - additional_inputs: tuple of additional arguments (may contain lifted params/buffers)
+            combine_fn, init, xs, additional_inputs = node.args
+            combine_gm = getattr(gm, combine_fn.name)
+            inp_pos_to_buffer_name_for_submod = {}
+            real_additional_inputs = []
+
+            # additional_inputs may contain lifted parameters/buffers that need to be
+            # registered in the combine_fn submodule
+            for ix, operand in enumerate(additional_inputs):
+                if (
+                    hasattr(operand, "target")
+                    and operand.target in inp_pos_to_param_buffer_name.values()
+                ):
+                    # This is a lifted param/buffer, register it in the submodule.
+                    # The index needs to account for init and xs inputs to combine_fn:
+                    # combine_fn inputs: (*init, *xs_slice, *additional_inputs)
+                    num_init = len(init) if isinstance(init, (list, tuple)) else 1
+                    num_xs = len(xs) if isinstance(xs, (list, tuple)) else 1
+                    adjusted_ix = num_init + num_xs + ix
+                    inp_pos_to_buffer_name_for_submod[adjusted_ix] = operand.target
+                    combine_gm.register_buffer(
+                        operand.target, state_dict[operand.target]
+                    )
+                else:
+                    real_additional_inputs.append(operand)
+
+            # Update node args with the filtered additional_inputs
+            node.args = (combine_fn, init, xs, tuple(real_additional_inputs))
+
+            _, in_spec = pytree.tree_flatten((init, xs, tuple(real_additional_inputs)))
+
+            _unlift(
+                combine_gm,
+                inp_pos_to_buffer_name_for_submod,
+                in_spec,
+                None,
+                state_dict,
+            )
     gm.graph.lint()
     gm.graph.eliminate_dead_code()
     gm.recompile()

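For orientation, the contract that the comments above describe — a carry that threads through iterations, per-step y outputs stacked along dim 0, and a combine_fn called with the flattened inputs (*init, *xs_slice, *additional_inputs) — can be sketched as a plain-Python reference loop. This is an illustration only, not the real higher order op; the helper name scan_reference and the use of torch.stack are assumptions made for the example.

import torch

def scan_reference(combine_fn, init, xs, additional_inputs=()):
    # Illustrative only: combine_fn is called as (*carry, *x_slice, *additional_inputs)
    # and returns the flat tuple (*next_carry, *y_slice), mirroring the unlifted graph.
    carry = list(init)
    num_carry = len(carry)
    per_step_ys = None
    scan_length = xs[0].shape[0]  # every xs tensor is scanned along dim 0
    for i in range(scan_length):
        x_slice = [x[i] for x in xs]
        outs = combine_fn(*carry, *x_slice, *additional_inputs)
        carry = list(outs[:num_carry])
        y_slice = list(outs[num_carry:])
        if per_step_ys is None:
            per_step_ys = [[] for _ in y_slice]
        for acc, y in zip(per_step_ys, y_slice):
            acc.append(y)
    stacked_ys = [torch.stack(acc, dim=0) for acc in per_step_ys]
    return carry, stacked_ys
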
exir/emit/_emitter.py

Lines changed: 267 additions & 1 deletion
@@ -944,12 +944,275 @@ def forward(self, x,y):

         return subemitter_binding_output_values

+    def _emit_scan(
+        self,
+        args: Tuple[_Argument, ...],
+        subemitter_binding_output_values: List[_AbstractValue],
+    ) -> List[_AbstractValue]:
+        """Emits torch.scan.
+
+        Converts the higher order scan op into a loop constructed from jump instructions
+        and primitive operations. Scan differs from map in that it maintains a carry state
+        that evolves across iterations.
+
+        Scan signature: scan(combine_fn, init, xs, additional_inputs)
+        - combine_fn: GraphModule that takes (carry, x_slice, *additional_inputs)
+          and returns (next_carry, y_slice)
+        - init: Initial carry state (list of tensors)
+        - xs: Input tensors to scan over (list of tensors, scanned along dim 0)
+        - additional_inputs: Additional arguments passed to combine_fn
+
+        Output: (final_carry, stacked_ys)
+        - final_carry: The carry state after the last iteration
+        - stacked_ys: All y outputs stacked along dim 0
+
+        Memory Layout:
+        - carry_outputs (subemitter_binding_output_values[:num_carry]):
+          Working carry buffers, initialized from init, updated each iteration
+        - y_outputs (subemitter_binding_output_values[num_carry:]):
+          Pre-allocated stacked output buffers, filled via et_copy_index
+
+        The combine_fn writes to its own temporary output buffers (concrete_output_ids).
+        After each iteration:
+        1. Copy combine_fn's carry output -> carry_outputs (for next iteration)
+        2. et_copy_index(y_outputs, combine_fn's y output, iter_idx)
+
+        This explicit copy approach is used because in-place op.out(x, out=x) is unsafe.
+        """
+        combine_fn, init, xs, additional_inputs = args
+
+        assert isinstance(
+            subemitter_binding_output_values, (list, tuple)
+        ), f"Expected list for subemitter_binding_output_values. Got {subemitter_binding_output_values}."
+
+        assert isinstance(combine_fn, torch.fx.GraphModule)
+        assert isinstance(init, (list, tuple))
+        assert isinstance(xs, (list, tuple))
+        assert isinstance(additional_inputs, (list, tuple))
+
+        num_carry = len(init)
+        num_xs = len(xs)
+        num_additional = len(additional_inputs)
+
+        # Split output values into carry outputs and y outputs
+        carry_outputs = list(subemitter_binding_output_values[:num_carry])
+        y_outputs = list(subemitter_binding_output_values[num_carry:])
+
+        if num_xs < 1:
+            raise RuntimeError("Scan requires at least one xs tensor to scan over.")
+
+        # === INITIALIZATION ===
+
+        # Generate iterator index EValue
+        iter_idx = self._emit_evalue(EValue(Int(0)))
+
+        # Get scan length from first xs tensor
+        op_index, op = self._get_operator(
+            name="aten::sym_size",
+            overload="int",
+        )
+        sym_size = self._emit_evalue(EValue(Int(0)))
+        kernel = Instruction(
+            KernelCall(
+                op_index=op_index,
+                args=[xs[0].id, self._emit_evalue(EValue(Int(0))).id, sym_size.id],
+            )
+        )
+        self.chain.instructions.append(kernel)
+
+        # Initialize carry_outputs from init by copying init -> carry_outputs
+        # This is necessary because we shouldn't mutate the original init tensors
+        op_index_copy, _ = self._get_operator(
+            name="aten::copy_",
+            overload="default",
+        )
+        for i, (init_val, carry_out) in enumerate(zip(init, carry_outputs)):
+            kernel = Instruction(
+                KernelCall(
+                    op_index=op_index_copy,
+                    args=[carry_out.id, init_val.id],
+                )
+            )
+            self.chain.instructions.append(kernel)
+
+        # === LOOP START ===
+
+        # Slice each xs tensor for the current iteration
+        # We use -1 as placeholder for the output tensor id, which will be filled
+        # after the scan_emitter runs and allocates the input placeholder EValues
+        op_index_select, _ = self._get_operator(
+            name="aten::select_copy",
+            overload="int_out",
+        )
+        xs_slice_instructions = []
+        for i, x in enumerate(xs):
+            kernel = Instruction(
+                KernelCall(
+                    op_index=op_index_select,
+                    args=[
+                        x.id,
+                        self._emit_evalue(EValue(Int(0))).id,  # dim=0
+                        iter_idx.id,
+                        -1,  # placeholder for output tensor id
+                        -1,  # placeholder (repeated for out variant)
+                    ],
+                )
+            )
+            xs_slice_instructions.append(kernel)
+
+        # Store jump target - this is where we jump back to after each iteration
+        jump_to_instruction = self.instruction_start_offset + len(
+            self.chain.instructions
+        )
+
+        # Add all xs slice instructions
+        for kernel in xs_slice_instructions:
+            self.chain.instructions.append(kernel)
+
+        # === EMIT COMBINE_FN SUBMODULE ===
+
+        # combine_fn inputs: (*carry, *xs_slice, *additional_inputs)
+        # We bind carry inputs to carry_outputs (the working carry buffers)
+        # xs_slice inputs will be filled in after emitter runs (using -1 placeholder)
+        # additional_inputs are passed through directly
+        binding_input_values: List[Any] = []
+        binding_input_values.extend(
+            carry_outputs
+        )  # Carry inputs bound to carry_outputs
+        binding_input_values.extend([-1] * num_xs)  # Placeholders for xs slices
+        binding_input_values.extend(additional_inputs)  # Additional inputs
+
+        # combine_fn outputs: (*next_carry, *y_slice)
+        # We don't bind outputs to the final destinations directly because we need
+        # to copy them explicitly (in-place is unsafe)
+        scan_emitter = _Emitter(
+            combine_fn,
+            self.emitter_state,
+            self.program_state,
+            instruction_start_offset=self.instruction_start_offset
+            + len(self.chain.instructions),
+            binding_input_values=binding_input_values,
+            binding_output_values=None,  # Let combine_fn use its own output buffers
+        )
+        scan_emitter.run()
+
+        # Merge combine_fn instructions
+        self._merge_chain(scan_emitter.chain)
+        # Remove the return instruction from combine_fn
+        self.chain.instructions.pop()
+
+        # Update xs_slice instructions with the actual placeholder EValue ids
+        # The xs placeholders start after the carry inputs in combine_fn
+        for i, kernel in enumerate(xs_slice_instructions):
+            xs_placeholder_id = scan_emitter.binding_input_values[num_carry + i].id
+            kernel.instr_args.args[-1] = xs_placeholder_id
+            kernel.instr_args.args[-2] = xs_placeholder_id
+
+        # === COPY OUTPUTS ===
+
+        # Get combine_fn's actual output EValues
+        # concrete_output_ids contains: (*carry_temp, *y_temp)
+        concrete_outputs = scan_emitter.concrete_output_ids
+        carry_temp = concrete_outputs[:num_carry]
+        y_temp = concrete_outputs[num_carry:]
+
+        self._internal_assert_emitter(
+            len(carry_temp) == num_carry,
+            self.node,
+            f"Scan combine_fn should output {num_carry} carry values, got {len(carry_temp)}",
+        )
+        self._internal_assert_emitter(
+            len(y_temp) == len(y_outputs),
+            self.node,
+            f"Scan combine_fn should output {len(y_outputs)} y values, got {len(y_temp)}",
+        )
+
+        # Copy carry_temp -> carry_outputs for next iteration
+        # This explicit copy is required because in-place op.out(x, out=x) is unsafe
+        for carry_t, carry_out in zip(carry_temp, carry_outputs):
+            kernel = Instruction(
+                KernelCall(
+                    op_index=op_index_copy,
+                    args=[carry_out.id, carry_t.id],
+                )
+            )
+            self.chain.instructions.append(kernel)
+
+        # Copy y_temp to stacked y_outputs using et_copy_index
+        op_index_copy_index, _ = self._get_operator(
+            name="executorch_prim::et_copy_index",
+            overload="tensor",
+        )
+        for y_t, y_out in zip(y_temp, y_outputs):
+            kernel = Instruction(
+                KernelCall(
+                    op_index=op_index_copy_index,
+                    args=[y_out.id, y_t.id, iter_idx.id],
+                )
+            )
+            self.chain.instructions.append(kernel)
+
+        # === LOOP CONTROL ===
+
+        # Increment iter_idx
+        op_index_add, _ = self._get_operator(
+            name="executorch_prim::add",
+            overload="Scalar",
+        )
+        kernel = Instruction(
+            KernelCall(
+                op_index=op_index_add,
+                args=[iter_idx.id, self._emit_evalue(EValue(Int(1))).id, iter_idx.id],
+            )
+        )
+        self.chain.instructions.append(kernel)
+
+        # Check if iteration is complete
+        jump_bool_value = self._emit_evalue(EValue(Bool(False)))
+        op_index_eq, _ = self._get_operator(
+            name="executorch_prim::eq",
+            overload="Scalar",
+        )
+        kernel = Instruction(
+            KernelCall(
+                op_index=op_index_eq,
+                args=[iter_idx.id, sym_size.id, jump_bool_value.id],
+            )
+        )
+        self.chain.instructions.append(kernel)
+
+        # Jump back to loop start if not done
+        jf_beginning_loop = Instruction(
+            JumpFalseCall(
+                cond_value_index=jump_bool_value.id,
+                destination_instruction=jump_to_instruction,
+            )
+        )
+        self.chain.instructions.append(jf_beginning_loop)
+
+        # === CLEANUP ===
+
+        # Reset iter_idx for potential re-runs of the model
+        op_index_sub, _ = self._get_operator(
+            name="executorch_prim::sub",
+            overload="Scalar",
+        )
+        kernel = Instruction(
+            KernelCall(
+                op_index=op_index_sub,
+                args=[iter_idx.id, sym_size.id, iter_idx.id],
+            )
+        )
+        self.chain.instructions.append(kernel)
+
+        return subemitter_binding_output_values
+
     def _emit_control_flow(
         self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
     ) -> _EmitterValue:
         """Wraps common logic for emitting all control flow operations.

-        See the more specific emission functions for more details on how cond or map get emitted.
+        See the more specific emission functions for more details on how cond, map, or scan get emitted.
         """
         subemitter_binding_output_values = pytree.tree_map(
             lambda spec: self._emit_spec(spec),
@@ -960,6 +1223,8 @@ def _emit_control_flow(
             return self._emit_cond(args, subemitter_binding_output_values)
         elif target is torch.ops.higher_order.map_impl:
             return self._emit_map(args, subemitter_binding_output_values)
+        elif target is torch.ops.higher_order.scan:
+            return self._emit_scan(args, subemitter_binding_output_values)
         else:
             raise InternalError(
                 self._emit_node_specific_error(
@@ -1511,6 +1776,7 @@ def call_function( # pyre-fixme[14]
             torch.ops.higher_order.cond,
             torch.ops.higher_order.map_impl,
             torch.ops.higher_order.while_loop,
+            torch.ops.higher_order.scan,
         ):
             return self._emit_control_flow(target, args, kwargs)

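Taken together, the instruction stream that _emit_scan produces behaves roughly like the following eager-mode loop. This is a hedged model, not code from this commit: buffer names are illustrative, aten::sym_size, aten::copy_, aten::select_copy and the executorch_prim loop-control ops are mapped onto ordinary Python operations, and et_copy_index is approximated with an indexed copy into a pre-allocated stacked output.

import torch

def scan_loop_model(combine_fn, init, xs, additional_inputs=()):
    # Eager-mode model of the emitted chain. Like that chain, this is a
    # do-while loop: the body runs once before the termination check.
    scan_length = xs[0].size(0)                                 # aten::sym_size.int
    carry_out = [torch.empty_like(t).copy_(t) for t in init]    # seed working carries (aten::copy_)
    y_out = None
    iter_idx = 0
    while True:
        x_slice = [torch.select_copy(x, 0, iter_idx) for x in xs]  # aten::select_copy.int_out
        outs = combine_fn(*carry_out, *x_slice, *additional_inputs)
        carry_temp, y_temp = outs[: len(init)], outs[len(init):]
        for dst, src in zip(carry_out, carry_temp):             # explicit copy back into carries
            dst.copy_(src)
        if y_out is None:                                       # pre-allocated stacked outputs
            y_out = [y.new_empty((scan_length, *y.shape)) for y in y_temp]
        for dst, src in zip(y_out, y_temp):                     # et_copy_index equivalent
            dst[iter_idx].copy_(src)
        iter_idx += 1                                           # executorch_prim::add
        if iter_idx == scan_length:                             # executorch_prim::eq + jump_false
            break
    return carry_out, y_out

Note the two explicit copies per iteration: binding combine_fn's outputs directly to the working carry buffers would amount to op.out(x, out=x), which the docstring above calls out as unsafe.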