260 changes: 257 additions & 3 deletions exir/emit/_emitter.py
@@ -944,22 +944,275 @@ def forward(self, x,y):

return subemitter_binding_output_values

def _emit_scan(
self,
args: Tuple[_Argument, ...],
subemitter_binding_output_values: List[_AbstractValue],
) -> List[_AbstractValue]:
"""Emits torch.scan.

Converts the higher order scan op into a loop constructed from jump instructions
and primitive operations. Scan differs from map in that it maintains a carry state
that evolves across iterations.

Scan signature: scan(combine_fn, init, xs, additional_inputs)
- combine_fn: GraphModule that takes (carry, x_slice, *additional_inputs)
and returns (next_carry, y_slice)
- init: Initial carry state (list of tensors)
- xs: Input tensors to scan over (list of tensors, scanned along dim 0)
- additional_inputs: Additional arguments passed to combine_fn

Output: (final_carry, stacked_ys)
- final_carry: The carry state after the last iteration
- stacked_ys: All y outputs stacked along dim 0

Memory Layout:
- carry_outputs (subemitter_binding_output_values[:num_carry]):
Working carry buffers, initialized from init, updated each iteration
- y_outputs (subemitter_binding_output_values[num_carry:]):
Pre-allocated stacked output buffers, filled via et_copy_index

The combine_fn writes to its own temporary output buffers (concrete_output_ids).
After each iteration:
1. Copy combine_fn's carry output -> carry_outputs (for next iteration)
2. et_copy_index(y_outputs, combine_fn's y output, iter_idx)

This explicit copy approach is used because in-place op.out(x, out=x) is unsafe.
Comment on lines +978 to +980
Contributor


I was under the impression that this might be fine. We basically emit scan at the very end of the lowering process and I'm not convinced we still require the graph to be functional.

Contributor Author


No, the problem isn't the graph being functional; it's that ATen (and ET) ops are not guaranteed to work when the input and the out tensor alias the same memory.

You could very easily end up writing over sections of the tensor before they are read.
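As a toy illustration of that write-before-read hazard (a hypothetical, naively written out-variant; not an actual ATen kernel):

```python
import torch

# Hypothetical out-variant that shifts elements right by one: out[i] = src[i - 1].
# Correct when out is a separate buffer; wrong when out aliases src.
def shift_right_out(src: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    out[0] = 0
    for i in range(1, out.numel()):
        out[i] = src[i - 1]  # if out is src, src[i - 1] may already have been overwritten
    return out

x = torch.arange(1.0, 6.0)                       # [1, 2, 3, 4, 5]
print(shift_right_out(x, torch.empty_like(x)))   # [0, 1, 2, 3, 4]  (correct)
print(shift_right_out(x, x))                     # [0, 0, 0, 0, 0]  (each write clobbers the next read)
```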

"""
combine_fn, init, xs, additional_inputs = args

assert isinstance(subemitter_binding_output_values, (list, tuple)), (
f"Expected list for subemitter_binding_output_values. "
f"Got {type(subemitter_binding_output_values).__name__}: "
f"{subemitter_binding_output_values}."
)

assert isinstance(combine_fn, torch.fx.GraphModule)
assert isinstance(init, (list, tuple))
assert isinstance(xs, (list, tuple))
assert isinstance(additional_inputs, (list, tuple))

num_carry = len(init)
num_xs = len(xs)

carry_outputs = list(subemitter_binding_output_values[:num_carry])
y_outputs = list(subemitter_binding_output_values[num_carry:])

if num_xs < 1:
raise RuntimeError(
f"Scan requires at least one xs tensor to scan over but got {num_xs}"
)

iter_idx = self._emit_evalue(EValue(Int(0)))

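# Query the trip count: aten::sym_size.int(xs[0], 0), i.e. the length of the scan dim.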
op_index, op = self._get_operator(
name="aten::sym_size",
overload="int",
)
sym_size = self._emit_evalue(EValue(Int(0)))
kernel = Instruction(
KernelCall(
op_index=op_index,
args=[xs[0].id, self._emit_evalue(EValue(Int(0))).id, sym_size.id],
)
)
self.chain.instructions.append(kernel)

# Initialize carry_outputs from init
op_index_copy, _ = self._get_operator(name="aten::copy_")
for init_val, carry_out in zip(init, carry_outputs):
kernel = Instruction(
KernelCall(
op_index=op_index_copy,
args=[
carry_out.id,
init_val.id,
self._emit_evalue(EValue(Bool(False))).id,
carry_out.id,
],
)
)
self.chain.instructions.append(kernel)

# Slice each xs tensor for the current iteration
op_index_select, _ = self._get_operator(
name="aten::select_copy",
overload="int_out",
)
xs_slice_instructions = []
for x in xs:
kernel = Instruction(
KernelCall(
op_index=op_index_select,
args=[
x.id,
self._emit_evalue(EValue(Int(0))).id,
iter_idx.id,
-1,
-1,
],
)
)
xs_slice_instructions.append(kernel)

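# Record the loop re-entry point; the JumpFalseCall at the bottom of the loop jumps back to the first xs slice instruction appended below.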
jump_to_instruction = self.instruction_start_offset + len(
self.chain.instructions
)

for kernel in xs_slice_instructions:
self.chain.instructions.append(kernel)

# Emit combine_fn submodule
binding_input_values: List[Any] = []
binding_input_values.extend(carry_outputs)
binding_input_values.extend([-1] * num_xs)
binding_input_values.extend(additional_inputs)

scan_emitter = _Emitter(
combine_fn,
self.emitter_state,
self.program_state,
instruction_start_offset=self.instruction_start_offset
+ len(self.chain.instructions),
binding_input_values=binding_input_values,
binding_output_values=None,
)
scan_emitter.run()

self._merge_chain(scan_emitter.chain)

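# Back-patch the select_copy placeholders (-1) now that the submodule's xs input EValues exist, so each slice is written directly into the corresponding combine_fn input.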
for i, kernel in enumerate(xs_slice_instructions):
xs_placeholder_id = scan_emitter.binding_input_values[num_carry + i].id
kernel.instr_args.args[-1] = xs_placeholder_id
kernel.instr_args.args[-2] = xs_placeholder_id

concrete_outputs = scan_emitter.concrete_output_ids
carry_temp = concrete_outputs[:num_carry]
y_temp = concrete_outputs[num_carry:]

self._internal_assert_emitter(
len(carry_temp) == num_carry,
self.node,
f"Scan combine_fn should output {num_carry} carry values, got {len(carry_temp)}",
)
self._internal_assert_emitter(
len(y_temp) == len(y_outputs),
self.node,
f"Scan combine_fn should output {len(y_outputs)} y values, got {len(y_temp)}",
)

# Copy carry_temp -> carry_outputs for next iteration
for carry_t, carry_out in zip(carry_temp, carry_outputs):
kernel = Instruction(
KernelCall(
op_index=op_index_copy,
args=[
carry_out.id,
carry_t.id,
self._emit_evalue(EValue(Bool(False))).id,
carry_out.id,
],
)
)
self.chain.instructions.append(kernel)

# Copy y_temp to stacked y_outputs
op_index_copy_index, _ = self._get_operator(
name="executorch_prim::et_copy_index",
overload="tensor",
)
for y_t, y_out in zip(y_temp, y_outputs):
kernel = Instruction(
KernelCall(
op_index=op_index_copy_index,
args=[y_out.id, y_t.id, iter_idx.id],
)
)
self.chain.instructions.append(kernel)

# Increment iter_idx
op_index_add, _ = self._get_operator(
name="executorch_prim::add",
overload="Scalar",
)
kernel = Instruction(
KernelCall(
op_index=op_index_add,
args=[iter_idx.id, self._emit_evalue(EValue(Int(1))).id, iter_idx.id],
)
)
self.chain.instructions.append(kernel)

# Check if iteration is complete
jump_bool_value = self._emit_evalue(EValue(Bool(False)))
op_index_eq, _ = self._get_operator(
name="executorch_prim::eq",
overload="Scalar",
)
kernel = Instruction(
KernelCall(
op_index=op_index_eq,
args=[iter_idx.id, sym_size.id, jump_bool_value.id],
)
)
self.chain.instructions.append(kernel)

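# While iter_idx != sym_size, jump back to the top of the loop body.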
jf_beginning_loop = Instruction(
JumpFalseCall(
cond_value_index=jump_bool_value.id,
destination_instruction=jump_to_instruction,
)
)
self.chain.instructions.append(jf_beginning_loop)

# Reset iter_idx for potential re-runs
op_index_sub, _ = self._get_operator(
name="executorch_prim::sub",
overload="Scalar",
)
kernel = Instruction(
KernelCall(
op_index=op_index_sub,
args=[iter_idx.id, sym_size.id, iter_idx.id],
)
)
self.chain.instructions.append(kernel)

return subemitter_binding_output_values

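For orientation, the instruction sequence built by `_emit_scan` corresponds roughly to the eager loop below. This is a sketch only: the flattened calling convention for the combine_fn GraphModule and the buffer allocation are assumptions for illustration, not code from this diff.

```python
# Rough eager equivalent of the emitted scan loop (illustrative sketch).
def scan_reference(combine_fn, init, xs, additional_inputs=()):
    num_carry = len(init)
    carry = [t.clone() for t in init]                 # carry_outputs initialized via aten::copy_
    num_iters = xs[0].size(0)                         # aten::sym_size.int(xs[0], 0)
    y_outputs = None
    for i in range(num_iters):                        # back-edge emitted as a JumpFalseCall
        x_slices = [x.select(0, i) for x in xs]       # aten::select_copy.int_out per xs tensor
        outs = combine_fn(*carry, *x_slices, *additional_inputs)
        carry_tmp, ys = outs[:num_carry], outs[num_carry:]
        carry = [c.clone() for c in carry_tmp]        # explicit copy back into carry_outputs
        if y_outputs is None:
            y_outputs = [y.new_empty((num_iters, *y.shape)) for y in ys]
        for y_out, y in zip(y_outputs, ys):
            y_out[i] = y                              # executorch_prim::et_copy_index
    return carry, y_outputs
```

The emitted program differs in that the y buffers are pre-allocated from the output specs and the select_copy out-slots are patched to write directly into the combine_fn's input EValues.
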
def _emit_control_flow(
self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
) -> _EmitterValue:
"""Wraps common logic for emitting all control flow operations.

See the more specific emission functions for more details on how cond or map get emitted.
See the more specific emission functions for more details on how cond, map, or scan get emitted.
"""
specs = self.node.meta["spec"]

# For scan, set the shape_dynamism for the stacked outputs (y_outputs) to DYNAMIC_BOUND
# BEFORE emitting the specs. This is because et_copy_index has cat shape semantics but
# stack memory behavior, so the shape must be able to grow by one slice each iteration,
# which is not possible for tensors marked static.
if target is torch.ops.higher_order.scan:
combine_fn, init, xs, additional_inputs = args
num_carry = len(init)
if isinstance(specs, (list, tuple)):
y_specs = specs[num_carry:]
for y_spec in y_specs:
if isinstance(y_spec, TensorSpec):
y_spec.shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND
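# Roughly, et_copy_index(y_out, y_t, iter_idx) writes y_t into y_out[iter_idx] and bumps
# y_out's leading dim to iter_idx + 1: storage is allocated up front (stack-like memory)
# while the reported shape grows one slice per iteration (cat-like shape), hence
# DYNAMIC_BOUND rather than a static shape.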

subemitter_binding_output_values = pytree.tree_map(
lambda spec: self._emit_spec(spec),
self.node.meta["spec"],
specs,
)

if target is torch.ops.higher_order.cond:
return self._emit_cond(args, subemitter_binding_output_values)
elif target is torch.ops.higher_order.map_impl:
return self._emit_map(args, subemitter_binding_output_values)
elif target is torch.ops.higher_order.scan:
return self._emit_scan(args, subemitter_binding_output_values)
else:
raise InternalError(
self._emit_node_specific_error(
@@ -1190,7 +1443,7 @@ def _emit_delegate(

return delegate_ret

def _get_operator(self, name: str, overload: str) -> Tuple[int, Operator]:
def _get_operator(self, name: str, overload: str = "") -> Tuple[int, Operator]:
"""Given a fully qualified name, lookups the operator in the ExecuTorch Program, or adds it
if it is not already present"""
key = (name, overload)
@@ -1511,6 +1764,7 @@ def call_function( # pyre-fixme[14]
torch.ops.higher_order.cond,
torch.ops.higher_order.map_impl,
torch.ops.higher_order.while_loop,
torch.ops.higher_order.scan,
):
return self._emit_control_flow(target, args, kwargs)
