Commit fae5d1b

Scan support (#16028)
Add support for the higher order op scan. It's inefficient today because we manually deep copy from output to input for every carry. We could do better by shallow-swapping the pointers, but I'll do that in a follow-up if needed. Test plan: Unit tests and internal verification against harder patterns.
1 parent 56c9b2c commit fae5d1b
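
For readers unfamiliar with the op, the semantics this emitter targets (as documented in the new `_emit_scan` docstring below) can be summarized with a small eager-mode reference sketch. The names `reference_scan` and `running_sum` are illustrative only and are not part of this change; the real higher order op's calling convention differs, this only illustrates the carry/stacked-ys behavior.

from typing import Callable, List, Sequence, Tuple

import torch


def reference_scan(
    combine_fn: Callable[..., Tuple[List[torch.Tensor], List[torch.Tensor]]],
    init: List[torch.Tensor],
    xs: List[torch.Tensor],
    additional_inputs: Sequence[torch.Tensor] = (),
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Illustrative eager reference: the carry evolves across iterations, ys stack on dim 0."""
    carry = list(init)
    ys_per_step: List[List[torch.Tensor]] = []
    num_iters = xs[0].size(0)  # every xs tensor is scanned along dim 0
    for i in range(num_iters):
        x_slices = [x[i] for x in xs]
        carry, y = combine_fn(carry, x_slices, *additional_inputs)
        ys_per_step.append(y)
    # Stack each y output across iterations along dim 0.
    stacked_ys = [torch.stack(step) for step in zip(*ys_per_step)]
    return carry, stacked_ys


def running_sum(carry, x_slices):
    """Example combine_fn: returns (next_carry, y_slice) for a cumulative sum."""
    (total,) = carry
    (x,) = x_slices
    new_total = total + x
    return [new_total], [new_total]


final_carry, stacked_ys = reference_scan(running_sum, [torch.zeros(3)], [torch.ones(5, 3)])
# final_carry[0] is a tensor of 5.0s; stacked_ys[0] has shape (5, 3) holding the running sums.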

File tree

10 files changed: +775 -68 lines changed

exir/emit/_emitter.py

Lines changed: 261 additions & 3 deletions
@@ -944,22 +944,279 @@ def forward(self, x,y):
 
         return subemitter_binding_output_values
 
+    def _emit_scan(
+        self,
+        args: Tuple[_Argument, ...],
+        subemitter_binding_output_values: List[_AbstractValue],
+    ) -> List[_AbstractValue]:
+        """Emits torch.scan.
+
+        Converts the higher order scan op into a loop constructed from jump instructions
+        and primitive operations. Scan differs from map in that it maintains a carry state
+        that evolves across iterations.
+
+        Scan signature: scan(combine_fn, init, xs, additional_inputs)
+        - combine_fn: GraphModule that takes (carry, x_slice, *additional_inputs)
+          and returns (next_carry, y_slice)
+        - init: Initial carry state (list of tensors)
+        - xs: Input tensors to scan over (list of tensors, scanned along dim 0)
+        - additional_inputs: Additional arguments passed to combine_fn
+
+        Output: (final_carry, stacked_ys)
+        - final_carry: The carry state after the last iteration
+        - stacked_ys: All y outputs stacked along dim 0
+
+        Memory Layout:
+        - carry_outputs (subemitter_binding_output_values[:num_carry]):
+          Working carry buffers, initialized from init, updated each iteration
+        - y_outputs (subemitter_binding_output_values[num_carry:]):
+          Pre-allocated stacked output buffers, filled via et_copy_index
+
+        The combine_fn writes to its own temporary output buffers (concrete_output_ids).
+        After each iteration:
+        1. Copy combine_fn's carry output -> carry_outputs (for next iteration)
+        2. et_copy_index(y_outputs, combine_fn's y output, iter_idx)
+
+        This explicit copy approach is used because in-place op.out(x, out=x) is unsafe.
+        """
+        combine_fn, init, xs, additional_inputs = args
+
+        assert isinstance(subemitter_binding_output_values, (list, tuple)), (
+            f"Expected list for subemitter_binding_output_values. "
+            f"Got {type(subemitter_binding_output_values).__name__}: "
+            f"{subemitter_binding_output_values}."
+        )
+
+        assert isinstance(combine_fn, torch.fx.GraphModule)
+        assert isinstance(init, (list, tuple))
+        assert isinstance(xs, (list, tuple))
+        assert isinstance(additional_inputs, (list, tuple))
+
+        num_carry = len(init)
+        num_xs = len(xs)
+
+        carry_outputs = list(subemitter_binding_output_values[:num_carry])
+        y_outputs = list(subemitter_binding_output_values[num_carry:])
+
+        if num_xs < 1:
+            raise RuntimeError(
+                f"Scan requires at least one xs tensor to scan over but got {num_xs}"
+            )
+
+        iter_idx = self._emit_evalue(EValue(Int(0)))
+
+        op_index, op = self._get_operator(
+            name="aten::sym_size",
+            overload="int",
+        )
+        sym_size = self._emit_evalue(EValue(Int(0)))
+        kernel = Instruction(
+            KernelCall(
+                op_index=op_index,
+                args=[xs[0].id, self._emit_evalue(EValue(Int(0))).id, sym_size.id],
+            )
+        )
+        self.chain.instructions.append(kernel)
+
+        # Initialize carry_outputs from init
+        op_index_copy, _ = self._get_operator(name="aten::copy_")
+        for init_val, carry_out in zip(init, carry_outputs):
+            kernel = Instruction(
+                KernelCall(
+                    op_index=op_index_copy,
+                    args=[
+                        carry_out.id,
+                        init_val.id,
+                        self._emit_evalue(EValue(Bool(False))).id,
+                        carry_out.id,
+                    ],
+                )
+            )
+            self.chain.instructions.append(kernel)
+
+        # Slice each xs tensor for the current iteration
+        op_index_select, _ = self._get_operator(
+            name="aten::select_copy",
+            overload="int_out",
+        )
+        xs_slice_instructions = []
+        for x in xs:
+            kernel = Instruction(
+                KernelCall(
+                    op_index=op_index_select,
+                    args=[
+                        x.id,
+                        self._emit_evalue(EValue(Int(0))).id,
+                        iter_idx.id,
+                        -1,
+                        -1,
+                    ],
+                )
+            )
+            xs_slice_instructions.append(kernel)
+
+        jump_to_instruction = self.instruction_start_offset + len(
+            self.chain.instructions
+        )
+
+        for kernel in xs_slice_instructions:
+            self.chain.instructions.append(kernel)
+
+        # Emit combine_fn submodule
+        binding_input_values: List[Any] = []
+        binding_input_values.extend(carry_outputs)
+        binding_input_values.extend([-1] * num_xs)
+        binding_input_values.extend(additional_inputs)
+
+        scan_emitter = _Emitter(
+            combine_fn,
+            self.emitter_state,
+            self.program_state,
+            instruction_start_offset=self.instruction_start_offset
+            + len(self.chain.instructions),
+            binding_input_values=binding_input_values,
+            binding_output_values=None,
+        )
+        scan_emitter.run()
+
+        self._merge_chain(scan_emitter.chain)
+
+        for i, kernel in enumerate(xs_slice_instructions):
+            xs_placeholder_id = scan_emitter.binding_input_values[num_carry + i].id
+            kernel.instr_args.args[-1] = xs_placeholder_id
+            kernel.instr_args.args[-2] = xs_placeholder_id
+
+        concrete_outputs = scan_emitter.concrete_output_ids
+        carry_temp = concrete_outputs[:num_carry]
+        y_temp = concrete_outputs[num_carry:]
+
+        self._internal_assert_emitter(
+            len(carry_temp) == num_carry,
+            self.node,
+            f"Scan combine_fn should output {num_carry} carry values, got {len(carry_temp)}",
+        )
+        self._internal_assert_emitter(
+            len(y_temp) == len(y_outputs),
+            self.node,
+            f"Scan combine_fn should output {len(y_outputs)} y values, got {len(y_temp)}",
+        )
+
+        # Copy carry_temp -> carry_outputs for next iteration
+        for carry_t, carry_out in zip(carry_temp, carry_outputs):
+            kernel = Instruction(
+                KernelCall(
+                    op_index=op_index_copy,
+                    args=[
+                        carry_out.id,
+                        carry_t.id,
+                        self._emit_evalue(EValue(Bool(False))).id,
+                        carry_out.id,
+                    ],
+                )
+            )
+            self.chain.instructions.append(kernel)
+
+        # Copy y_temp to stacked y_outputs
+        op_index_copy_index, _ = self._get_operator(
+            name="executorch_prim::et_copy_index",
+            overload="tensor",
+        )
+        for y_t, y_out in zip(y_temp, y_outputs):
+            kernel = Instruction(
+                KernelCall(
+                    op_index=op_index_copy_index,
+                    args=[y_out.id, y_t.id, iter_idx.id],
+                )
+            )
+            self.chain.instructions.append(kernel)
+
+        # Increment iter_idx
+        op_index_add, _ = self._get_operator(
+            name="executorch_prim::add",
+            overload="Scalar",
+        )
+        kernel = Instruction(
+            KernelCall(
+                op_index=op_index_add,
+                args=[iter_idx.id, self._emit_evalue(EValue(Int(1))).id, iter_idx.id],
+            )
+        )
+        self.chain.instructions.append(kernel)
+
+        # Check if iteration is complete
+        jump_bool_value = self._emit_evalue(EValue(Bool(False)))
+        op_index_eq, _ = self._get_operator(
+            name="executorch_prim::eq",
+            overload="Scalar",
+        )
+        kernel = Instruction(
+            KernelCall(
+                op_index=op_index_eq,
+                args=[iter_idx.id, sym_size.id, jump_bool_value.id],
+            )
+        )
+        self.chain.instructions.append(kernel)
+
+        jf_beginning_loop = Instruction(
+            JumpFalseCall(
+                cond_value_index=jump_bool_value.id,
+                destination_instruction=jump_to_instruction,
+            )
+        )
+        self.chain.instructions.append(jf_beginning_loop)
+
+        # Reset iter_idx for potential re-runs
+        op_index_sub, _ = self._get_operator(
+            name="executorch_prim::sub",
+            overload="Scalar",
+        )
+        kernel = Instruction(
+            KernelCall(
+                op_index=op_index_sub,
+                args=[iter_idx.id, sym_size.id, iter_idx.id],
+            )
+        )
+        self.chain.instructions.append(kernel)
+
+        return subemitter_binding_output_values
+
     def _emit_control_flow(
         self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
     ) -> _EmitterValue:
         """Wraps common logic for emitting all control flow operations.
 
-        See the more specific emission functions for more details on how cond or map get emitted.
+        See the more specific emission functions for more details on how cond, map, or scan get emitted.
         """
+        specs = self.node.meta["spec"]
+
+        # For scan/map, set the shape_dynamism for the stacked outputs (y_outputs) to DYNAMIC_BOUND
+        # BEFORE emitting the specs. This is because et_copy_index has cat shape semantics but
+        # stack memory behavior, so we need to be able to update the shape +1 for each iteration
+        # which we can't do for tensors marked static.
+        if target is torch.ops.higher_order.scan:
+            combine_fn, init, xs, additional_inputs = args
+            num_carry = len(init)
+            if isinstance(specs, (list, tuple)):
+                y_specs = specs[num_carry:]
+                for y_spec in y_specs:
+                    if isinstance(y_spec, TensorSpec):
+                        y_spec.shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND
+        elif target is torch.ops.higher_order.map_impl:
+            assert len(specs) == 1
+            assert isinstance(specs[0], TensorSpec)
+            specs[0].shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND
+
         subemitter_binding_output_values = pytree.tree_map(
             lambda spec: self._emit_spec(spec),
-            self.node.meta["spec"],
+            specs,
         )
 
         if target is torch.ops.higher_order.cond:
             return self._emit_cond(args, subemitter_binding_output_values)
         elif target is torch.ops.higher_order.map_impl:
             return self._emit_map(args, subemitter_binding_output_values)
+        elif target is torch.ops.higher_order.scan:
+            return self._emit_scan(args, subemitter_binding_output_values)
         else:
             raise InternalError(
                 self._emit_node_specific_error(
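
To make the loop structure above easier to follow, here is a hand-written sketch (not emitter output) of the instruction stream `_emit_scan` produces for one carry, one xs tensor, and no additional inputs. The angle-bracketed names are illustrative EValue slots, not identifiers from the code.

# Illustration of the emitted chain; names in <> are EValue slots.
emitted_scan_chain = [
    "KernelCall aten::sym_size.int(xs, 0) -> <num_iters>",              # loop bound from xs dim 0
    "KernelCall aten::copy_(carry_out, init) -> carry_out",             # seed carry from init
    # --- loop body; the JumpFalseCall below jumps back to this select_copy ---
    "KernelCall aten::select_copy.int_out(xs, 0, <iter_idx>) -> <x_slice>",
    "... combine_fn kernel calls: (carry_out, <x_slice>) -> (<carry_tmp>, <y_tmp>) ...",
    "KernelCall aten::copy_(carry_out, <carry_tmp>) -> carry_out",      # carry for next iteration
    "KernelCall executorch_prim::et_copy_index(y_out, <y_tmp>, <iter_idx>)",  # stack y at iter_idx
    "KernelCall executorch_prim::add.Scalar(<iter_idx>, 1) -> <iter_idx>",
    "KernelCall executorch_prim::eq.Scalar(<iter_idx>, <num_iters>) -> <done>",
    "JumpFalseCall(<done>) -> back to the select_copy above",
    "KernelCall executorch_prim::sub.Scalar(<iter_idx>, <num_iters>) -> <iter_idx>",  # reset for re-runs
]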
@@ -1190,7 +1447,7 @@ def _emit_delegate(
 
         return delegate_ret
 
-    def _get_operator(self, name: str, overload: str) -> Tuple[int, Operator]:
+    def _get_operator(self, name: str, overload: str = "") -> Tuple[int, Operator]:
         """Given a fully qualified name, lookups the operator in the ExecuTorch Program, or adds it
         if it is not already present"""
         key = (name, overload)
@@ -1511,6 +1768,7 @@ def call_function( # pyre-fixme[14]
             torch.ops.higher_order.cond,
             torch.ops.higher_order.map_impl,
             torch.ops.higher_order.while_loop,
+            torch.ops.higher_order.scan,
         ):
             return self._emit_control_flow(target, args, kwargs)

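The DYNAMIC_BOUND change in `_emit_control_flow` exists because of how `et_copy_index` fills the stacked y buffers. Below is a rough eager-mode stand-in for that behavior, illustrative only; the real op lives in the ExecuTorch runtime and resizes the destination's size metadata in place.

import torch


def et_copy_index_sketch(copy_to: torch.Tensor, to_copy: torch.Tensor, index: int) -> torch.Tensor:
    """Rough eager model of executorch_prim::et_copy_index (illustrative, not the runtime op).

    The destination is pre-allocated with stacked storage ("stack memory behavior"),
    but its visible size along dim 0 grows to index + 1 on each call ("cat shape
    semantics"), which is why the y output specs must be DYNAMIC_BOUND rather than static.
    """
    copy_to[index].copy_(to_copy)  # write this iteration's slice into pre-allocated storage
    # The runtime op then updates copy_to's reported shape so dim 0 == index + 1;
    # here we just return the visible prefix to mimic that.
    return copy_to[: index + 1]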