From 89279ba0f858b3994e349c4c7cabfdb086d99f5d Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Mon, 1 Dec 2025 13:16:53 -0800 Subject: [PATCH 01/19] scan support --- exir/capture/_unlift.py | 43 +++++ exir/emit/_emitter.py | 268 +++++++++++++++++++++++++++++- exir/emit/test/test_emit.py | 192 +++++++++++++++++++++ exir/graph_module.py | 29 +++- exir/memory_planning.py | 22 ++- exir/pass_base.py | 49 ++++++ exir/passes/spec_prop_pass.py | 68 ++++++++ exir/tests/control_flow_models.py | 64 +++++++ 8 files changed, 731 insertions(+), 4 deletions(-) diff --git a/exir/capture/_unlift.py b/exir/capture/_unlift.py index 24c84494883..502dcc4ba3e 100644 --- a/exir/capture/_unlift.py +++ b/exir/capture/_unlift.py @@ -97,6 +97,49 @@ def _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict): _unlift( body_gm, inp_pos_to_buffer_name_for_submod, in_spec, None, state_dict ) + if node.op == "call_function" and node.target.__name__ == "scan": + # scan signature: scan(combine_fn, init, xs, additional_inputs) + # - combine_fn: GraphModule for the scan body + # - init: list of initial carry tensors + # - xs: list of input tensors to scan over + # - additional_inputs: tuple of additional arguments (may contain lifted params/buffers) + combine_fn, init, xs, additional_inputs = node.args + combine_gm = getattr(gm, combine_fn.name) + inp_pos_to_buffer_name_for_submod = {} + real_additional_inputs = [] + + # additional_inputs may contain lifted parameters/buffers that need to be + # registered in the combine_fn submodule + for ix, operand in enumerate(additional_inputs): + if ( + hasattr(operand, "target") + and operand.target in inp_pos_to_param_buffer_name.values() + ): + # This is a lifted param/buffer, register it in the submodule + # The index needs to account for init and xs inputs to combine_fn + # combine_fn inputs: (*init, *xs_slice, *additional_inputs) + num_init = len(init) if isinstance(init, (list, tuple)) else 1 + num_xs = len(xs) if isinstance(xs, (list, tuple)) else 1 + adjusted_ix = num_init + num_xs + ix + inp_pos_to_buffer_name_for_submod[adjusted_ix] = operand.target + combine_gm.register_buffer( + operand.target, state_dict[operand.target] + ) + else: + real_additional_inputs.append(operand) + + # Update node args with the filtered additional_inputs + node.args = (combine_fn, init, xs, tuple(real_additional_inputs)) + + _, in_spec = pytree.tree_flatten((init, xs, tuple(real_additional_inputs))) + + _unlift( + combine_gm, + inp_pos_to_buffer_name_for_submod, + in_spec, + None, + state_dict, + ) gm.graph.lint() gm.graph.eliminate_dead_code() gm.recompile() diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 5c24f08c732..b9bc431ccfd 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -944,12 +944,275 @@ def forward(self, x,y): return subemitter_binding_output_values + def _emit_scan( + self, + args: Tuple[_Argument, ...], + subemitter_binding_output_values: List[_AbstractValue], + ) -> List[_AbstractValue]: + """Emits torch.scan. + + Converts the higher order scan op into a loop constructed from jump instructions + and primitive operations. Scan differs from map in that it maintains a carry state + that evolves across iterations. + + Scan signature: scan(combine_fn, init, xs, additional_inputs) + - combine_fn: GraphModule that takes (carry, x_slice, *additional_inputs) + and returns (next_carry, y_slice) + - init: Initial carry state (list of tensors) + - xs: Input tensors to scan over (list of tensors, scanned along dim 0) + - additional_inputs: Additional arguments passed to combine_fn + + Output: (final_carry, stacked_ys) + - final_carry: The carry state after the last iteration + - stacked_ys: All y outputs stacked along dim 0 + + Memory Layout: + - carry_outputs (subemitter_binding_output_values[:num_carry]): + Working carry buffers, initialized from init, updated each iteration + - y_outputs (subemitter_binding_output_values[num_carry:]): + Pre-allocated stacked output buffers, filled via et_copy_index + + The combine_fn writes to its own temporary output buffers (concrete_output_ids). + After each iteration: + 1. Copy combine_fn's carry output -> carry_outputs (for next iteration) + 2. et_copy_index(y_outputs, combine_fn's y output, iter_idx) + + This explicit copy approach is used because in-place op.out(x, out=x) is unsafe. + """ + combine_fn, init, xs, additional_inputs = args + + assert isinstance( + subemitter_binding_output_values, (list, tuple) + ), f"Expected list for subemitter_binding_output_values. Got {subemitter_binding_output_values}." + + assert isinstance(combine_fn, torch.fx.GraphModule) + assert isinstance(init, (list, tuple)) + assert isinstance(xs, (list, tuple)) + assert isinstance(additional_inputs, (list, tuple)) + + num_carry = len(init) + num_xs = len(xs) + num_additional = len(additional_inputs) + + # Split output values into carry outputs and y outputs + carry_outputs = list(subemitter_binding_output_values[:num_carry]) + y_outputs = list(subemitter_binding_output_values[num_carry:]) + + if num_xs < 1: + raise RuntimeError("Scan requires at least one xs tensor to scan over.") + + # === INITIALIZATION === + + # Generate iterator index EValue + iter_idx = self._emit_evalue(EValue(Int(0))) + + # Get scan length from first xs tensor + op_index, op = self._get_operator( + name="aten::sym_size", + overload="int", + ) + sym_size = self._emit_evalue(EValue(Int(0))) + kernel = Instruction( + KernelCall( + op_index=op_index, + args=[xs[0].id, self._emit_evalue(EValue(Int(0))).id, sym_size.id], + ) + ) + self.chain.instructions.append(kernel) + + # Initialize carry_outputs from init by copying init -> carry_outputs + # This is necessary because we shouldn't mutate the original init tensors + op_index_copy, _ = self._get_operator( + name="aten::copy_", + overload="default", + ) + for i, (init_val, carry_out) in enumerate(zip(init, carry_outputs)): + kernel = Instruction( + KernelCall( + op_index=op_index_copy, + args=[carry_out.id, init_val.id], + ) + ) + self.chain.instructions.append(kernel) + + # === LOOP START === + + # Slice each xs tensor for the current iteration + # We use -1 as placeholder for the output tensor id, which will be filled + # after the scan_emitter runs and allocates the input placeholder EValues + op_index_select, _ = self._get_operator( + name="aten::select_copy", + overload="int_out", + ) + xs_slice_instructions = [] + for i, x in enumerate(xs): + kernel = Instruction( + KernelCall( + op_index=op_index_select, + args=[ + x.id, + self._emit_evalue(EValue(Int(0))).id, # dim=0 + iter_idx.id, + -1, # placeholder for output tensor id + -1, # placeholder (repeated for out variant) + ], + ) + ) + xs_slice_instructions.append(kernel) + + # Store jump target - this is where we jump back to after each iteration + jump_to_instruction = self.instruction_start_offset + len( + self.chain.instructions + ) + + # Add all xs slice instructions + for kernel in xs_slice_instructions: + self.chain.instructions.append(kernel) + + # === EMIT COMBINE_FN SUBMODULE === + + # combine_fn inputs: (*carry, *xs_slice, *additional_inputs) + # We bind carry inputs to carry_outputs (the working carry buffers) + # xs_slice inputs will be filled in after emitter runs (using -1 placeholder) + # additional_inputs are passed through directly + binding_input_values: List[Any] = [] + binding_input_values.extend( + carry_outputs + ) # Carry inputs bound to carry_outputs + binding_input_values.extend([-1] * num_xs) # Placeholders for xs slices + binding_input_values.extend(additional_inputs) # Additional inputs + + # combine_fn outputs: (*next_carry, *y_slice) + # We don't bind outputs to the final destinations directly because we need + # to copy them explicitly (in-place is unsafe) + scan_emitter = _Emitter( + combine_fn, + self.emitter_state, + self.program_state, + instruction_start_offset=self.instruction_start_offset + + len(self.chain.instructions), + binding_input_values=binding_input_values, + binding_output_values=None, # Let combine_fn use its own output buffers + ) + scan_emitter.run() + + # Merge combine_fn instructions + self._merge_chain(scan_emitter.chain) + # Remove the return instruction from combine_fn + self.chain.instructions.pop() + + # Update xs_slice instructions with the actual placeholder EValue ids + # The xs placeholders start after the carry inputs in combine_fn + for i, kernel in enumerate(xs_slice_instructions): + xs_placeholder_id = scan_emitter.binding_input_values[num_carry + i].id + kernel.instr_args.args[-1] = xs_placeholder_id + kernel.instr_args.args[-2] = xs_placeholder_id + + # === COPY OUTPUTS === + + # Get combine_fn's actual output EValues + # concrete_output_ids contains: (*carry_temp, *y_temp) + concrete_outputs = scan_emitter.concrete_output_ids + carry_temp = concrete_outputs[:num_carry] + y_temp = concrete_outputs[num_carry:] + + self._internal_assert_emitter( + len(carry_temp) == num_carry, + self.node, + f"Scan combine_fn should output {num_carry} carry values, got {len(carry_temp)}", + ) + self._internal_assert_emitter( + len(y_temp) == len(y_outputs), + self.node, + f"Scan combine_fn should output {len(y_outputs)} y values, got {len(y_temp)}", + ) + + # Copy carry_temp -> carry_outputs for next iteration + # This explicit copy is required because in-place op.out(x, out=x) is unsafe + for carry_t, carry_out in zip(carry_temp, carry_outputs): + kernel = Instruction( + KernelCall( + op_index=op_index_copy, + args=[carry_out.id, carry_t.id], + ) + ) + self.chain.instructions.append(kernel) + + # Copy y_temp to stacked y_outputs using et_copy_index + op_index_copy_index, _ = self._get_operator( + name="executorch_prim::et_copy_index", + overload="tensor", + ) + for y_t, y_out in zip(y_temp, y_outputs): + kernel = Instruction( + KernelCall( + op_index=op_index_copy_index, + args=[y_out.id, y_t.id, iter_idx.id], + ) + ) + self.chain.instructions.append(kernel) + + # === LOOP CONTROL === + + # Increment iter_idx + op_index_add, _ = self._get_operator( + name="executorch_prim::add", + overload="Scalar", + ) + kernel = Instruction( + KernelCall( + op_index=op_index_add, + args=[iter_idx.id, self._emit_evalue(EValue(Int(1))).id, iter_idx.id], + ) + ) + self.chain.instructions.append(kernel) + + # Check if iteration is complete + jump_bool_value = self._emit_evalue(EValue(Bool(False))) + op_index_eq, _ = self._get_operator( + name="executorch_prim::eq", + overload="Scalar", + ) + kernel = Instruction( + KernelCall( + op_index=op_index_eq, + args=[iter_idx.id, sym_size.id, jump_bool_value.id], + ) + ) + self.chain.instructions.append(kernel) + + # Jump back to loop start if not done + jf_beginning_loop = Instruction( + JumpFalseCall( + cond_value_index=jump_bool_value.id, + destination_instruction=jump_to_instruction, + ) + ) + self.chain.instructions.append(jf_beginning_loop) + + # === CLEANUP === + + # Reset iter_idx for potential re-runs of the model + op_index_sub, _ = self._get_operator( + name="executorch_prim::sub", + overload="Scalar", + ) + kernel = Instruction( + KernelCall( + op_index=op_index_sub, + args=[iter_idx.id, sym_size.id, iter_idx.id], + ) + ) + self.chain.instructions.append(kernel) + + return subemitter_binding_output_values + def _emit_control_flow( self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument] ) -> _EmitterValue: """Wraps common logic for emitting all control flow operations. - See the more specific emission functions for more details on how cond or map get emitted. + See the more specific emission functions for more details on how cond, map, or scan get emitted. """ subemitter_binding_output_values = pytree.tree_map( lambda spec: self._emit_spec(spec), @@ -960,6 +1223,8 @@ def _emit_control_flow( return self._emit_cond(args, subemitter_binding_output_values) elif target is torch.ops.higher_order.map_impl: return self._emit_map(args, subemitter_binding_output_values) + elif target is torch.ops.higher_order.scan: + return self._emit_scan(args, subemitter_binding_output_values) else: raise InternalError( self._emit_node_specific_error( @@ -1511,6 +1776,7 @@ def call_function( # pyre-fixme[14] torch.ops.higher_order.cond, torch.ops.higher_order.map_impl, torch.ops.higher_order.while_loop, + torch.ops.higher_order.scan, ): return self._emit_control_flow(target, args, kwargs) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 165bc2951f7..52026ab612b 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -812,6 +812,198 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: outputs = loaded_model(inputs)[0] torch.allclose(outputs, f(*inputs)) + def test_emit_scan_basic(self) -> None: + """Test basic scan emission: verifies instruction structure for cumulative sum.""" + + class ScanCumSum(torch.nn.Module): + def forward(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def combine_fn(carry, x): + new_carry = carry + x + return new_carry, new_carry.clone() + + init = torch.zeros_like(xs[0]) + return torch.scan(combine_fn, init, xs) + + f = ScanCumSum() + inputs = (torch.arange(5).float().unsqueeze(1).expand(5, 3),) + + module = to_edge( + export(f, inputs, strict=True), + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), + ) + program = module.to_executorch().executorch_program + + op_table = program.execution_plan[0].operators + instructions = program.execution_plan[0].chains[0].instructions + + # Verify the instruction structure for scan: + # 1. First instruction should be sym_size to get scan length + self.assertEqual( + op_table[instructions[0].instr_args.op_index].name, + "aten::sym_size", + ) + + # 2. Should have copy_ instructions to initialize carry from init + copy_found = False + for instr in instructions: + if hasattr(instr.instr_args, "op_index"): + op_name = op_table[instr.instr_args.op_index].name + if op_name == "aten::copy_": + copy_found = True + break + self.assertTrue(copy_found, "Should have aten::copy_ for carry initialization") + + # 3. Should have select_copy to slice xs + select_copy_found = False + for instr in instructions: + if hasattr(instr.instr_args, "op_index"): + op_name = op_table[instr.instr_args.op_index].name + if op_name == "aten::select_copy": + select_copy_found = True + break + self.assertTrue(select_copy_found, "Should have select_copy for xs slicing") + + # 4. Should have et_copy_index to accumulate y outputs + et_copy_index_found = False + for instr in instructions: + if hasattr(instr.instr_args, "op_index"): + op_name = op_table[instr.instr_args.op_index].name + if op_name == "executorch_prim::et_copy_index": + et_copy_index_found = True + break + self.assertTrue( + et_copy_index_found, "Should have et_copy_index for y accumulation" + ) + + # 5. Loop control: should have add, eq for iteration control + add_found = False + eq_found = False + for instr in instructions: + if hasattr(instr.instr_args, "op_index"): + op_name = op_table[instr.instr_args.op_index].name + if op_name == "executorch_prim::add": + add_found = True + if op_name == "executorch_prim::eq": + eq_found = True + self.assertTrue( + add_found, "Should have executorch_prim::add for iter increment" + ) + self.assertTrue( + eq_found, "Should have executorch_prim::eq for completion check" + ) + + # 6. Should have JumpFalseCall for loop back + jump_false_found = False + for instr in instructions: + if isinstance(instr.instr_args, JumpFalseCall): + jump_false_found = True + break + self.assertTrue(jump_false_found, "Should have JumpFalseCall for loop control") + + # 7. Last instruction should be sub to reset iter_idx + self.assertEqual( + op_table[instructions[-1].instr_args.op_index].name, + "executorch_prim::sub", + ) + + def test_load_emit_scan(self) -> None: + """Test that scan program can be loaded by the runtime.""" + + class ScanCumSum(torch.nn.Module): + def forward(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def combine_fn(carry, x): + new_carry = carry + x + return new_carry, new_carry.clone() + + init = torch.zeros_like(xs[0]) + return torch.scan(combine_fn, init, xs) + + f = ScanCumSum() + inputs = (torch.arange(5).float().unsqueeze(1).expand(5, 3),) + + module = to_edge( + export(f, inputs, strict=True), + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), + ) + # This should not raise - verifies the program is loadable + _load_for_executorch_from_buffer(module.to_executorch().buffer) + + def test_run_emit_scan_cumsum(self) -> None: + """Test scan execution correctness: cumulative sum.""" + + class ScanCumSum(torch.nn.Module): + def forward(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def combine_fn(carry, x): + new_carry = carry + x + return new_carry, new_carry.clone() + + init = torch.zeros_like(xs[0]) + return torch.scan(combine_fn, init, xs) + + f = ScanCumSum() + inputs = (torch.arange(5).float().unsqueeze(1).expand(5, 3),) + + module = to_edge( + export(f, inputs, strict=True), + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), + ) + buffer = module.to_executorch().buffer + loaded_model = _load_for_executorch_from_buffer(buffer) + + # Run through executorch + et_outputs = loaded_model(inputs) + + # Run through eager PyTorch + eager_outputs = f(*inputs) + + # Compare final carry + self.assertTrue( + torch.allclose(et_outputs[0], eager_outputs[0]), + f"Final carry mismatch: {et_outputs[0]} vs {eager_outputs[0]}", + ) + + # Compare stacked outputs + self.assertTrue( + torch.allclose(et_outputs[1], eager_outputs[1]), + f"Stacked outputs mismatch: {et_outputs[1]} vs {eager_outputs[1]}", + ) + + def test_emit_scan_add_mul(self) -> None: + """Test scan with add operation in combine_fn.""" + + class ScanAddMul(torch.nn.Module): + def forward( + self, xs: torch.Tensor, y: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + def combine_fn(carry, x): + # Multiply carry by 2 and add x + new_carry = carry * 2 + x + return new_carry, new_carry.clone() + + init = torch.zeros_like(xs[0]) + return torch.scan(combine_fn, init, xs) + + f = ScanAddMul() + inputs = (torch.ones(4, 3), torch.ones(3)) + + module = to_edge( + export(f, inputs, strict=True), + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), + ) + program = module.to_executorch().executorch_program + + # Verify we have the expected operators + op_names = [op.name for op in program.execution_plan[0].operators] + + # Should have mul and add operations from the combine_fn body + self.assertIn("aten::mul", op_names) + self.assertIn("aten::add", op_names) + + # Should have scan control flow operators + self.assertIn("aten::sym_size", op_names) + self.assertIn("aten::select_copy", op_names) + self.assertIn("executorch_prim::et_copy_index", op_names) + def test_dim_order(self) -> None: class SimpleLinear(torch.nn.Module): def __init__(self) -> None: diff --git a/exir/graph_module.py b/exir/graph_module.py index 2adf62ab0b8..cc85a596bfb 100644 --- a/exir/graph_module.py +++ b/exir/graph_module.py @@ -78,14 +78,18 @@ def get_control_flow_submodules( ) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]: """ Returns a list of submodules used for control flow operations - (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look + (torch.ops.higher_order.cond/map/scan) that are in the given toplevel graph (does not look into submodules). Specifically, the returned value is a list containing tuples of (name of the submodule that's stored in the graph module, the submodule itself, and the fx node that uses this submodule). """ return _get_control_flow_submodules( graph_module, - {torch.ops.higher_order.cond: [1, 2], torch.ops.higher_order.map_impl: [0]}, + { + torch.ops.higher_order.cond: [1, 2], + torch.ops.higher_order.map_impl: [0], + torch.ops.higher_order.scan: [0], # combine_fn is at arg index 0 + }, ) @@ -108,6 +112,27 @@ def get_cond_while_submodules( ) +def get_scan_submodules( + graph_module: torch.fx.GraphModule, +) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]: + """ + Returns a list of submodules used for scan operations + (torch.ops.higher_order.scan) that are in the given toplevel graph (does not look + into submodules). Specifically, the returned value is a list containing + tuples of (name of the submodule that's stored in the graph module, the + submodule itself, and the fx node that uses this submodule). + + For scan, the combine_fn submodule is at argument index 0. + The scan operator signature is: scan(combine_fn, init, xs, additional_inputs) + """ + return _get_control_flow_submodules( + graph_module, + { + torch.ops.higher_order.scan: [0], + }, + ) + + def bfs_trace_with_node_process( gm: torch.fx.GraphModule, node_op: Callable[[torch.fx.Node], None] ) -> None: diff --git a/exir/memory_planning.py b/exir/memory_planning.py index 0394ed9c529..a413e5f1b60 100644 --- a/exir/memory_planning.py +++ b/exir/memory_planning.py @@ -1023,6 +1023,17 @@ def get_map_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]: yield nd +def get_scan_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]: + """Get all scan nodes in the graph module. + + Scan nodes have the signature: scan(combine_fn, init, xs, additional_inputs) + where combine_fn is a submodule at args[0]. + """ + for nd in graph_module.graph.nodes: + if nd.target is torch.ops.higher_order.scan: + yield nd + + def get_return_specs(graph_module: fx.GraphModule) -> Set[TensorSpec]: return_specs = set() nodes = graph_module.graph.nodes @@ -1125,7 +1136,7 @@ def _apply_algo_to_submodules( alignment: int, graph_signature: Optional[ExportGraphSignature] = None, ) -> list[int]: - """Apply algo to map/cond/while nodes in the graph module. + """Apply algo to map/cond/while/scan nodes in the graph module. This method will popuate graph_module.meta["non_const_buffer_sizes"] for all submodules and return a bufsizes list that is the maximum size of all @@ -1161,6 +1172,15 @@ def _handle( for map_node in get_map_nodes(graph_module): _handle(cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True) + # Handle scan nodes + # Scan signature: scan(combine_fn, init, xs, additional_inputs) + # combine_fn is at args[0] + # Like map, scan needs alloc_graph_input=True because the runtime slices + # xs tensors during each iteration, requiring allocated input buffers. + # Additionally, scan has carry state that flows between iterations. + for scan_node in get_scan_nodes(graph_module): + _handle(cast(torch.fx.Node, scan_node.args[0]), alloc_graph_input=True) + # TODO: We can handle delegates the same way as map/cond/while. # Maybe populate the graph_module.meta["non_const_buffer_sizes"] for delegates. diff --git a/exir/pass_base.py b/exir/pass_base.py index 497970fae34..bbf6bbd0368 100644 --- a/exir/pass_base.py +++ b/exir/pass_base.py @@ -348,6 +348,11 @@ def call_function( elif target == torch.ops.higher_order.map_impl: f, mapped_args, operands = args # type: ignore[assignment] return self.callback.call_map(f, mapped_args, operands, meta) + elif target == torch.ops.higher_order.scan: + combine_fn, init, xs, additional_inputs = args # type: ignore[assignment] + return self.callback.call_scan( + combine_fn, init, xs, additional_inputs, meta + ) # For other unregistered HigherOrderOps, just interpret them blindly elif isinstance(target, torch._ops.HigherOrderOperator): return self.callback._fx( @@ -545,6 +550,50 @@ def call_map( meta, ) + def call_scan( + self, + combine_fn: torch.fx.GraphModule, + init: List[ProxyValue], + xs: List[ProxyValue], + additional_inputs: List[ProxyValue], + meta: NodeMetadata, + ) -> ProxyValue: + """ + Process a scan higher-order operation. + + Scan applies combine_fn iteratively, carrying state across iterations: + combine_fn(carry, x_slice) -> (next_carry, y_slice) + + Args: + combine_fn: GraphModule implementing the scan body + init: Initial carry state values + xs: Input tensors to scan over (along dim 0) + additional_inputs: Additional arguments passed to combine_fn + meta: Node metadata + + Returns: + ProxyValue containing (final_carry, stacked_outputs) + """ + # Get the first slice of xs to determine input shapes for combine_fn + # combine_fn inputs: (*init, *xs_slice, *additional_inputs) + xs_first_slice = _unstack_pytree([arg.data for arg in xs])[0] + init_data = [arg.data for arg in init] + additional_data = [arg.data for arg in additional_inputs] + + # Call submodule with representative inputs + combine_fn_result = self.call_submodule( + combine_fn, tuple(init_data + xs_first_slice + additional_data) + ) + assert combine_fn_result is not None + + return self._fx( + "call_function", + torch.ops.higher_order.scan, + (combine_fn_result.graph_module, init, xs, additional_inputs), + {}, + meta, + ) + def call_getitem( self, value: ProxyValue, key: int, meta: NodeMetadata ) -> ProxyValue: diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py index ab5367d1b20..33e24af6532 100644 --- a/exir/passes/spec_prop_pass.py +++ b/exir/passes/spec_prop_pass.py @@ -142,6 +142,74 @@ def call_map( meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor) return super().call_map(f, mapped_args, operands, meta) + def call_scan( + self, + combine_fn: torch.fx.GraphModule, + init: List[ProxyValue], + xs: List[ProxyValue], + additional_inputs: List[ProxyValue], + meta: NodeMetadata, + ) -> ProxyValue: + """ + Propagate specs for scan higher-order operation. + + Scan returns (final_carry, stacked_outputs) where: + - final_carry: Same shape as init (NOT stacked, just the final carry state) + - stacked_outputs: Outputs stacked along dim 0 with scan_length + + The combine_fn signature is: + combine_fn(*init, *xs_slice, *additional_inputs) -> (*next_carry, *y_slice) + + So the combine_fn outputs are split into: + - First len(init) outputs: carry values (same shape as init) + - Remaining outputs: y values (to be stacked) + + Memory Layout Note: + The specs created here are for the FINAL outputs of the scan operation: + - carry specs: Working carry buffers that persist across iterations. + These are SEPARATE from combine_fn's output buffers. The emitter + must copy from combine_fn's temporary carry output to these buffers + after each iteration (in-place op.out(x, out=x) is unsafe). + - y specs: Pre-allocated stacked buffers filled via et_copy_index. + + The combine_fn's internal temporary buffers are allocated separately + via memory planning with alloc_graph_input=True, alloc_graph_output=True. + """ + # Get scan length from first xs tensor + scan_length = [arg.data for arg in xs][0].size(0) + + # Get the output node from combine_fn + *_, body_out_node = combine_fn.graph.nodes + body_out_fake = body_out_node.meta["val"] + + # The combine_fn outputs are: (*next_carry, *y_slice) + # Split them based on the number of init values + num_carry = len(init) + + # Flatten the outputs to handle them uniformly + flat_body_out, out_spec = pytree.tree_flatten(body_out_fake) + + # Split into carry outputs and y outputs + carry_out = flat_body_out[:num_carry] + y_out = flat_body_out[num_carry:] + + # Create specs: + # - Carry: same shape as combine_fn output (NOT stacked) + # These are working buffers that get updated each iteration + # - Y: stacked along dim 0 with scan_length + carry_fake = carry_out # Carry keeps same shape + + y_fake = [ + x.new_empty(scan_length, *x.shape) if isinstance(x, torch.Tensor) else x + for x in y_out + ] + + # Combine carry and stacked y outputs + combined_fake = carry_fake + y_fake + + meta["spec"] = pytree.tree_map(make_spec, combined_fake) + return super().call_scan(combine_fn, init, xs, additional_inputs, meta) + # pyre-ignore def call_delegate(self, lowered_module, args, kwargs, meta): args_data, kwargs_data = pytree.tree_map_only( diff --git a/exir/tests/control_flow_models.py b/exir/tests/control_flow_models.py index 3c0fd8badab..82bc5ffb30d 100644 --- a/exir/tests/control_flow_models.py +++ b/exir/tests/control_flow_models.py @@ -103,3 +103,67 @@ def get_upper_bound_inputs(self): def get_random_inputs(self): return torch.rand(2, 4), torch.rand(4) + + +class FTScanBasic(Module): + """Basic scan model that computes cumulative sum.""" + + def __init__(self): + super().__init__() + + def forward(self, xs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def combine_fn(carry, x): + new_carry = carry + x + y = new_carry.clone() + return new_carry, y + + init = torch.zeros_like(xs[0]) + return torch.scan(combine_fn, init, xs) + + def get_random_inputs(self): + return (torch.arange(5).float().unsqueeze(1).expand(5, 3),) + + +class FTScanMultipleCarry(Module): + """Scan model with multiple carry values (sum and product).""" + + def __init__(self): + super().__init__() + + def forward( + self, xs: torch.Tensor + ) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + def combine_fn(carry, x): + sum_carry, prod_carry = carry + new_sum = sum_carry + x + new_prod = prod_carry * x + y = new_sum + new_prod + return (new_sum, new_prod), y.clone() + + init_sum = torch.zeros_like(xs[0]) + init_prod = torch.ones_like(xs[0]) + return torch.scan(combine_fn, (init_sum, init_prod), xs) + + def get_random_inputs(self): + return (torch.arange(1, 5).float().unsqueeze(1).expand(4, 2),) + + +class FTScanWithAdditionalInputs(Module): + """Scan model with additional inputs (closure-like behavior).""" + + def __init__(self): + super().__init__() + + def forward( + self, xs: torch.Tensor, scale: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + def combine_fn(carry, x): + # Scale is captured from outer scope + new_carry = carry + x * scale + return new_carry, new_carry.clone() + + init = torch.zeros_like(xs[0]) + return torch.scan(combine_fn, init, xs) + + def get_random_inputs(self): + return (torch.arange(5).float().unsqueeze(1).expand(5, 3), torch.tensor([2.0])) From 646d1ce224db35f44c12ca1c53f93094df59fe4d Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Mon, 1 Dec 2025 15:22:45 -0800 Subject: [PATCH 02/19] make it work --- exir/emit/_emitter.py | 39 +++++++----- exir/emit/test/test_emit.py | 114 ++++++++++++++++-------------------- exir/passes/__init__.py | 5 ++ 3 files changed, 80 insertions(+), 78 deletions(-) diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index b9bc431ccfd..3fbc3b968d6 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -1022,15 +1022,18 @@ def _emit_scan( # Initialize carry_outputs from init by copying init -> carry_outputs # This is necessary because we shouldn't mutate the original init tensors - op_index_copy, _ = self._get_operator( - name="aten::copy_", - overload="default", - ) - for i, (init_val, carry_out) in enumerate(zip(init, carry_outputs)): + # Use aten::copy_.default which copies src to self in-place + op_index_copy, _ = self._get_operator(name="aten::copy_") + for init_val, carry_out in zip(init, carry_outputs): kernel = Instruction( KernelCall( op_index=op_index_copy, - args=[carry_out.id, init_val.id], + args=[ + carry_out.id, + init_val.id, + self._emit_evalue(EValue(Bool(False))).id, + carry_out.id, + ], ) ) self.chain.instructions.append(kernel) @@ -1083,8 +1086,9 @@ def _emit_scan( binding_input_values.extend(additional_inputs) # Additional inputs # combine_fn outputs: (*next_carry, *y_slice) - # We don't bind outputs to the final destinations directly because we need - # to copy them explicitly (in-place is unsafe) + # Pass binding_output_values=None so the combine_fn writes directly to its + # own output buffers (concrete_output_ids). We then copy from these directly + # to the final carry/y buffers, avoiding unnecessary temp buffers and MOVEs. scan_emitter = _Emitter( combine_fn, self.emitter_state, @@ -1092,14 +1096,14 @@ def _emit_scan( instruction_start_offset=self.instruction_start_offset + len(self.chain.instructions), binding_input_values=binding_input_values, - binding_output_values=None, # Let combine_fn use its own output buffers + binding_output_values=None, # Use concrete outputs directly ) scan_emitter.run() # Merge combine_fn instructions self._merge_chain(scan_emitter.chain) - # Remove the return instruction from combine_fn - self.chain.instructions.pop() + # NOTE: When binding_output_values=None, no return/move instruction is added + # by the output() method, so we don't need to pop anything. # Update xs_slice instructions with the actual placeholder EValue ids # The xs placeholders start after the carry inputs in combine_fn @@ -1111,7 +1115,8 @@ def _emit_scan( # === COPY OUTPUTS === # Get combine_fn's actual output EValues - # concrete_output_ids contains: (*carry_temp, *y_temp) + # concrete_output_ids contains the actual EValues that the combine_fn + # graph operations write to: (*carry_temp, *y_temp) concrete_outputs = scan_emitter.concrete_output_ids carry_temp = concrete_outputs[:num_carry] y_temp = concrete_outputs[num_carry:] @@ -1129,11 +1134,17 @@ def _emit_scan( # Copy carry_temp -> carry_outputs for next iteration # This explicit copy is required because in-place op.out(x, out=x) is unsafe + # aten::copy_ signature: (self, src, non_blocking, out) -> self for carry_t, carry_out in zip(carry_temp, carry_outputs): kernel = Instruction( KernelCall( op_index=op_index_copy, - args=[carry_out.id, carry_t.id], + args=[ + carry_out.id, + carry_t.id, + self._emit_evalue(EValue(Bool(False))).id, + carry_out.id, + ], ) ) self.chain.instructions.append(kernel) @@ -1455,7 +1466,7 @@ def _emit_delegate( return delegate_ret - def _get_operator(self, name: str, overload: str) -> Tuple[int, Operator]: + def _get_operator(self, name: str, overload: str = "") -> Tuple[int, Operator]: """Given a fully qualified name, lookups the operator in the ExecuTorch Program, or adds it if it is not already present""" key = (name, overload) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 52026ab612b..af1231d021e 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -814,6 +814,7 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def test_emit_scan_basic(self) -> None: """Test basic scan emission: verifies instruction structure for cumulative sum.""" + from torch._higher_order_ops.scan import scan class ScanCumSum(torch.nn.Module): def forward(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -822,10 +823,11 @@ def combine_fn(carry, x): return new_carry, new_carry.clone() init = torch.zeros_like(xs[0]) - return torch.scan(combine_fn, init, xs) + return scan(combine_fn, init, xs) f = ScanCumSum() - inputs = (torch.arange(5).float().unsqueeze(1).expand(5, 3),) + # Use contiguous tensor to avoid stride=0 issue + inputs = (torch.arange(15).float().reshape(5, 3),) module = to_edge( export(f, inputs, strict=True), @@ -836,63 +838,44 @@ def combine_fn(carry, x): op_table = program.execution_plan[0].operators instructions = program.execution_plan[0].chains[0].instructions - # Verify the instruction structure for scan: - # 1. First instruction should be sym_size to get scan length - self.assertEqual( - op_table[instructions[0].instr_args.op_index].name, - "aten::sym_size", + # Collect all operator names in the program + op_names = [op.name for op in op_table] + + # Verify the key operators are present for scan: + # 1. sym_size - to get scan length + self.assertIn( + "aten::sym_size", op_names, "Should have sym_size for scan length" ) - # 2. Should have copy_ instructions to initialize carry from init - copy_found = False - for instr in instructions: - if hasattr(instr.instr_args, "op_index"): - op_name = op_table[instr.instr_args.op_index].name - if op_name == "aten::copy_": - copy_found = True - break - self.assertTrue(copy_found, "Should have aten::copy_ for carry initialization") - - # 3. Should have select_copy to slice xs - select_copy_found = False - for instr in instructions: - if hasattr(instr.instr_args, "op_index"): - op_name = op_table[instr.instr_args.op_index].name - if op_name == "aten::select_copy": - select_copy_found = True - break - self.assertTrue(select_copy_found, "Should have select_copy for xs slicing") - - # 4. Should have et_copy_index to accumulate y outputs - et_copy_index_found = False - for instr in instructions: - if hasattr(instr.instr_args, "op_index"): - op_name = op_table[instr.instr_args.op_index].name - if op_name == "executorch_prim::et_copy_index": - et_copy_index_found = True - break - self.assertTrue( - et_copy_index_found, "Should have et_copy_index for y accumulation" + # 2. copy_ - for carry initialization and carry updates + self.assertIn("aten::copy_", op_names, "Should have copy_ for carry handling") + + # 3. select_copy - to slice xs + self.assertIn( + "aten::select_copy", op_names, "Should have select_copy for xs slicing" ) - # 5. Loop control: should have add, eq for iteration control - add_found = False - eq_found = False - for instr in instructions: - if hasattr(instr.instr_args, "op_index"): - op_name = op_table[instr.instr_args.op_index].name - if op_name == "executorch_prim::add": - add_found = True - if op_name == "executorch_prim::eq": - eq_found = True - self.assertTrue( - add_found, "Should have executorch_prim::add for iter increment" + # 4. et_copy_index - to accumulate y outputs + self.assertIn( + "executorch_prim::et_copy_index", + op_names, + "Should have et_copy_index for y accumulation", ) - self.assertTrue( - eq_found, "Should have executorch_prim::eq for completion check" + + # 5. Loop control: add, eq for iteration control + self.assertIn( + "executorch_prim::add", op_names, "Should have add for iter increment" + ) + self.assertIn( + "executorch_prim::eq", op_names, "Should have eq for completion check" + ) + + # 6. sub - to reset iter_idx for re-runs + self.assertIn( + "executorch_prim::sub", op_names, "Should have sub to reset iterator" ) - # 6. Should have JumpFalseCall for loop back + # 7. Should have JumpFalseCall for loop back jump_false_found = False for instr in instructions: if isinstance(instr.instr_args, JumpFalseCall): @@ -900,14 +883,12 @@ def combine_fn(carry, x): break self.assertTrue(jump_false_found, "Should have JumpFalseCall for loop control") - # 7. Last instruction should be sub to reset iter_idx - self.assertEqual( - op_table[instructions[-1].instr_args.op_index].name, - "executorch_prim::sub", - ) + # 8. Verify we have the body operations (add from combine_fn) + self.assertIn("aten::add", op_names, "Should have add from combine_fn body") def test_load_emit_scan(self) -> None: """Test that scan program can be loaded by the runtime.""" + from torch._higher_order_ops.scan import scan class ScanCumSum(torch.nn.Module): def forward(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -916,10 +897,11 @@ def combine_fn(carry, x): return new_carry, new_carry.clone() init = torch.zeros_like(xs[0]) - return torch.scan(combine_fn, init, xs) + return scan(combine_fn, init, xs) f = ScanCumSum() - inputs = (torch.arange(5).float().unsqueeze(1).expand(5, 3),) + # Use contiguous tensor to avoid stride=0 issue + inputs = (torch.arange(15).float().reshape(5, 3),) module = to_edge( export(f, inputs, strict=True), @@ -930,6 +912,7 @@ def combine_fn(carry, x): def test_run_emit_scan_cumsum(self) -> None: """Test scan execution correctness: cumulative sum.""" + from torch._higher_order_ops.scan import scan class ScanCumSum(torch.nn.Module): def forward(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -938,16 +921,19 @@ def combine_fn(carry, x): return new_carry, new_carry.clone() init = torch.zeros_like(xs[0]) - return torch.scan(combine_fn, init, xs) + return scan(combine_fn, init, xs) f = ScanCumSum() - inputs = (torch.arange(5).float().unsqueeze(1).expand(5, 3),) + # Use contiguous tensor to avoid stride=0 issue + inputs = (torch.arange(15).float().reshape(5, 3),) module = to_edge( export(f, inputs, strict=True), compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), ) - buffer = module.to_executorch().buffer + et = module.to_executorch() + et.dump_executorch_program(False) + buffer = et.buffer loaded_model = _load_for_executorch_from_buffer(buffer) # Run through executorch @@ -970,6 +956,7 @@ def combine_fn(carry, x): def test_emit_scan_add_mul(self) -> None: """Test scan with add operation in combine_fn.""" + from torch._higher_order_ops.scan import scan class ScanAddMul(torch.nn.Module): def forward( @@ -981,7 +968,7 @@ def combine_fn(carry, x): return new_carry, new_carry.clone() init = torch.zeros_like(xs[0]) - return torch.scan(combine_fn, init, xs) + return scan(combine_fn, init, xs) f = ScanAddMul() inputs = (torch.ones(4, 3), torch.ones(3)) @@ -1792,7 +1779,6 @@ def forward(self, x): model = to_edge(export(MutableStateModule(), (torch.zeros(1),), strict=True)) model = model.to_executorch() - model.dump_executorch_program(True) self.assertTrue( model.executorch_program.execution_plan[0].values[0].val.allocation_info is not None diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py index 5c6eb63db46..cdc1c22a9d5 100644 --- a/exir/passes/__init__.py +++ b/exir/passes/__init__.py @@ -344,6 +344,11 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule: self.call(get_submodule(node.args[0])) self.call(get_submodule(node.args[1])) continue + elif target == torch.ops.higher_order.scan: + # scan(combine_fn, init, xs, additional_inputs) + # combine_fn is at args[0] + self.call(get_submodule(node.args[0])) + continue elif getattr(target, "__module__", None) in ("builtins", "_operator"): continue elif target in to_out_var_skiplist: From c1beb3d20487ab1d792eda4ebb99d3c113f9ff96 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Mon, 1 Dec 2025 15:24:10 -0800 Subject: [PATCH 03/19] lint --- exir/emit/_emitter.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 3fbc3b968d6..4005fc073ee 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -992,7 +992,6 @@ def _emit_scan( num_carry = len(init) num_xs = len(xs) - num_additional = len(additional_inputs) # Split output values into carry outputs and y outputs carry_outputs = list(subemitter_binding_output_values[:num_carry]) @@ -1048,7 +1047,7 @@ def _emit_scan( overload="int_out", ) xs_slice_instructions = [] - for i, x in enumerate(xs): + for x in xs: kernel = Instruction( KernelCall( op_index=op_index_select, From d1473c7c73cf4f7c23380cbc8a861eeaa8256c8b Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Mon, 1 Dec 2025 15:25:01 -0800 Subject: [PATCH 04/19] remove useless test --- exir/emit/test/test_emit.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index af1231d021e..2f55091e0e3 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -886,30 +886,6 @@ def combine_fn(carry, x): # 8. Verify we have the body operations (add from combine_fn) self.assertIn("aten::add", op_names, "Should have add from combine_fn body") - def test_load_emit_scan(self) -> None: - """Test that scan program can be loaded by the runtime.""" - from torch._higher_order_ops.scan import scan - - class ScanCumSum(torch.nn.Module): - def forward(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - def combine_fn(carry, x): - new_carry = carry + x - return new_carry, new_carry.clone() - - init = torch.zeros_like(xs[0]) - return scan(combine_fn, init, xs) - - f = ScanCumSum() - # Use contiguous tensor to avoid stride=0 issue - inputs = (torch.arange(15).float().reshape(5, 3),) - - module = to_edge( - export(f, inputs, strict=True), - compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), - ) - # This should not raise - verifies the program is loadable - _load_for_executorch_from_buffer(module.to_executorch().buffer) - def test_run_emit_scan_cumsum(self) -> None: """Test scan execution correctness: cumulative sum.""" from torch._higher_order_ops.scan import scan From d563028610127e41c1712d7340ae9bed86b4dd46 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Tue, 2 Dec 2025 11:06:23 -0800 Subject: [PATCH 05/19] test --- exir/emit/test/test_emit.py | 110 ++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 2f55091e0e3..6c05640269c 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -967,6 +967,116 @@ def combine_fn(carry, x): self.assertIn("aten::select_copy", op_names) self.assertIn("executorch_prim::et_copy_index", op_names) + def test_emit_scan_gru(self) -> None: + """Test scan with a simple GRU-like computation.""" + from torch._higher_order_ops.scan import scan + + class SimpleGRU(torch.nn.Module): + """Simple single-layer unidirectional GRU using scan.""" + + def __init__(self, input_size: int, hidden_size: int): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + + # GRU gates: reset, update, new + self.weight_ih = torch.nn.Parameter( + torch.randn(3 * hidden_size, input_size), requires_grad=False + ) + self.weight_hh = torch.nn.Parameter( + torch.randn(3 * hidden_size, hidden_size), requires_grad=False + ) + self.bias_ih = torch.nn.Parameter( + torch.randn(3 * hidden_size), requires_grad=False + ) + self.bias_hh = torch.nn.Parameter( + torch.randn(3 * hidden_size), requires_grad=False + ) + + def forward( + self, x: torch.Tensor, h0: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: Input tensor of shape [seq_len, batch, input_size] + h0: Initial hidden state of shape [batch, hidden_size] + Returns: + output: Output tensor of shape [seq_len, batch, hidden_size] + h_n: Final hidden state of shape [batch, hidden_size] + """ + weight_ih = self.weight_ih + weight_hh = self.weight_hh + bias_ih = self.bias_ih + bias_hh = self.bias_hh + + def gru_cell( + h: torch.Tensor, x_t: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Compute gates + gates_ih = torch.nn.functional.linear(x_t, weight_ih, bias_ih) + gates_hh = torch.nn.functional.linear(h, weight_hh, bias_hh) + + # Split into reset, update, new gates + r_ih, z_ih, n_ih = gates_ih.chunk(3, dim=-1) + r_hh, z_hh, n_hh = gates_hh.chunk(3, dim=-1) + + r = torch.sigmoid(r_ih + r_hh) + z = torch.sigmoid(z_ih + z_hh) + n = torch.tanh(n_ih + r * n_hh) + + h_new = (1 - z) * n + z * h + return h_new, h_new.clone() + + final_h, outputs = scan(gru_cell, h0, x) + return outputs, final_h + + # Create model and inputs + input_size = 4 + hidden_size = 8 + seq_len = 5 + batch_size = 2 + + model = SimpleGRU(input_size, hidden_size) + x = torch.randn(seq_len, batch_size, input_size) + h0 = torch.randn(batch_size, hidden_size) + inputs = (x, h0) + + # Run through eager PyTorch + eager_outputs = model(*inputs) + + # Export and convert to edge + module = to_edge( + export(model, inputs, strict=True), + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), + ) + et = module.to_executorch() + program = et.executorch_program + + # Verify the program has expected operators + op_names = [op.name for op in program.execution_plan[0].operators] + + # Should have scan control flow operators + self.assertIn("aten::sym_size", op_names) + self.assertIn("aten::select_copy", op_names) + self.assertIn("executorch_prim::et_copy_index", op_names) + + # Verify we can load the program + buffer = et.buffer + loaded_model = _load_for_executorch_from_buffer(buffer) + + # Run through executorch + et_outputs = loaded_model(inputs) + + # Compare outputs (with tolerance for floating point) + self.assertTrue( + torch.allclose(et_outputs[0], eager_outputs[0], atol=1e-5), + f"Output mismatch: {et_outputs[0]} vs {eager_outputs[0]}", + ) + self.assertTrue( + torch.allclose(et_outputs[1], eager_outputs[1], atol=1e-5), + f"Final hidden state mismatch: {et_outputs[1]} vs {eager_outputs[1]}", + ) + def test_dim_order(self) -> None: class SimpleLinear(torch.nn.Module): def __init__(self) -> None: From bfd7f369f681ec129f45200341d63fcddb766cc3 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Fri, 5 Dec 2025 11:05:11 -0800 Subject: [PATCH 06/19] clean up --- exir/emit/_emitter.py | 65 ++++++------------------------- exir/memory_planning.py | 11 ------ exir/pass_base.py | 19 --------- exir/passes/spec_prop_pass.py | 40 +------------------ exir/tests/control_flow_models.py | 1 - 5 files changed, 13 insertions(+), 123 deletions(-) diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 4005fc073ee..780c0d7e490 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -993,19 +993,14 @@ def _emit_scan( num_carry = len(init) num_xs = len(xs) - # Split output values into carry outputs and y outputs carry_outputs = list(subemitter_binding_output_values[:num_carry]) y_outputs = list(subemitter_binding_output_values[num_carry:]) if num_xs < 1: raise RuntimeError("Scan requires at least one xs tensor to scan over.") - # === INITIALIZATION === - - # Generate iterator index EValue iter_idx = self._emit_evalue(EValue(Int(0))) - # Get scan length from first xs tensor op_index, op = self._get_operator( name="aten::sym_size", overload="int", @@ -1019,9 +1014,7 @@ def _emit_scan( ) self.chain.instructions.append(kernel) - # Initialize carry_outputs from init by copying init -> carry_outputs - # This is necessary because we shouldn't mutate the original init tensors - # Use aten::copy_.default which copies src to self in-place + # Initialize carry_outputs from init op_index_copy, _ = self._get_operator(name="aten::copy_") for init_val, carry_out in zip(init, carry_outputs): kernel = Instruction( @@ -1037,11 +1030,7 @@ def _emit_scan( ) self.chain.instructions.append(kernel) - # === LOOP START === - # Slice each xs tensor for the current iteration - # We use -1 as placeholder for the output tensor id, which will be filled - # after the scan_emitter runs and allocates the input placeholder EValues op_index_select, _ = self._get_operator( name="aten::select_copy", overload="int_out", @@ -1053,41 +1042,28 @@ def _emit_scan( op_index=op_index_select, args=[ x.id, - self._emit_evalue(EValue(Int(0))).id, # dim=0 + self._emit_evalue(EValue(Int(0))).id, iter_idx.id, - -1, # placeholder for output tensor id - -1, # placeholder (repeated for out variant) + -1, + -1, ], ) ) xs_slice_instructions.append(kernel) - # Store jump target - this is where we jump back to after each iteration jump_to_instruction = self.instruction_start_offset + len( self.chain.instructions ) - # Add all xs slice instructions for kernel in xs_slice_instructions: self.chain.instructions.append(kernel) - # === EMIT COMBINE_FN SUBMODULE === - - # combine_fn inputs: (*carry, *xs_slice, *additional_inputs) - # We bind carry inputs to carry_outputs (the working carry buffers) - # xs_slice inputs will be filled in after emitter runs (using -1 placeholder) - # additional_inputs are passed through directly + # Emit combine_fn submodule binding_input_values: List[Any] = [] - binding_input_values.extend( - carry_outputs - ) # Carry inputs bound to carry_outputs - binding_input_values.extend([-1] * num_xs) # Placeholders for xs slices - binding_input_values.extend(additional_inputs) # Additional inputs - - # combine_fn outputs: (*next_carry, *y_slice) - # Pass binding_output_values=None so the combine_fn writes directly to its - # own output buffers (concrete_output_ids). We then copy from these directly - # to the final carry/y buffers, avoiding unnecessary temp buffers and MOVEs. + binding_input_values.extend(carry_outputs) + binding_input_values.extend([-1] * num_xs) + binding_input_values.extend(additional_inputs) + scan_emitter = _Emitter( combine_fn, self.emitter_state, @@ -1095,27 +1071,17 @@ def _emit_scan( instruction_start_offset=self.instruction_start_offset + len(self.chain.instructions), binding_input_values=binding_input_values, - binding_output_values=None, # Use concrete outputs directly + binding_output_values=None, ) scan_emitter.run() - # Merge combine_fn instructions self._merge_chain(scan_emitter.chain) - # NOTE: When binding_output_values=None, no return/move instruction is added - # by the output() method, so we don't need to pop anything. - # Update xs_slice instructions with the actual placeholder EValue ids - # The xs placeholders start after the carry inputs in combine_fn for i, kernel in enumerate(xs_slice_instructions): xs_placeholder_id = scan_emitter.binding_input_values[num_carry + i].id kernel.instr_args.args[-1] = xs_placeholder_id kernel.instr_args.args[-2] = xs_placeholder_id - # === COPY OUTPUTS === - - # Get combine_fn's actual output EValues - # concrete_output_ids contains the actual EValues that the combine_fn - # graph operations write to: (*carry_temp, *y_temp) concrete_outputs = scan_emitter.concrete_output_ids carry_temp = concrete_outputs[:num_carry] y_temp = concrete_outputs[num_carry:] @@ -1132,8 +1098,6 @@ def _emit_scan( ) # Copy carry_temp -> carry_outputs for next iteration - # This explicit copy is required because in-place op.out(x, out=x) is unsafe - # aten::copy_ signature: (self, src, non_blocking, out) -> self for carry_t, carry_out in zip(carry_temp, carry_outputs): kernel = Instruction( KernelCall( @@ -1148,7 +1112,7 @@ def _emit_scan( ) self.chain.instructions.append(kernel) - # Copy y_temp to stacked y_outputs using et_copy_index + # Copy y_temp to stacked y_outputs op_index_copy_index, _ = self._get_operator( name="executorch_prim::et_copy_index", overload="tensor", @@ -1162,8 +1126,6 @@ def _emit_scan( ) self.chain.instructions.append(kernel) - # === LOOP CONTROL === - # Increment iter_idx op_index_add, _ = self._get_operator( name="executorch_prim::add", @@ -1191,7 +1153,6 @@ def _emit_scan( ) self.chain.instructions.append(kernel) - # Jump back to loop start if not done jf_beginning_loop = Instruction( JumpFalseCall( cond_value_index=jump_bool_value.id, @@ -1200,9 +1161,7 @@ def _emit_scan( ) self.chain.instructions.append(jf_beginning_loop) - # === CLEANUP === - - # Reset iter_idx for potential re-runs of the model + # Reset iter_idx for potential re-runs op_index_sub, _ = self._get_operator( name="executorch_prim::sub", overload="Scalar", diff --git a/exir/memory_planning.py b/exir/memory_planning.py index a413e5f1b60..ad6b01e4c9d 100644 --- a/exir/memory_planning.py +++ b/exir/memory_planning.py @@ -1024,11 +1024,6 @@ def get_map_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]: def get_scan_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]: - """Get all scan nodes in the graph module. - - Scan nodes have the signature: scan(combine_fn, init, xs, additional_inputs) - where combine_fn is a submodule at args[0]. - """ for nd in graph_module.graph.nodes: if nd.target is torch.ops.higher_order.scan: yield nd @@ -1172,12 +1167,6 @@ def _handle( for map_node in get_map_nodes(graph_module): _handle(cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True) - # Handle scan nodes - # Scan signature: scan(combine_fn, init, xs, additional_inputs) - # combine_fn is at args[0] - # Like map, scan needs alloc_graph_input=True because the runtime slices - # xs tensors during each iteration, requiring allocated input buffers. - # Additionally, scan has carry state that flows between iterations. for scan_node in get_scan_nodes(graph_module): _handle(cast(torch.fx.Node, scan_node.args[0]), alloc_graph_input=True) diff --git a/exir/pass_base.py b/exir/pass_base.py index bbf6bbd0368..17feadd0a43 100644 --- a/exir/pass_base.py +++ b/exir/pass_base.py @@ -558,29 +558,10 @@ def call_scan( additional_inputs: List[ProxyValue], meta: NodeMetadata, ) -> ProxyValue: - """ - Process a scan higher-order operation. - - Scan applies combine_fn iteratively, carrying state across iterations: - combine_fn(carry, x_slice) -> (next_carry, y_slice) - - Args: - combine_fn: GraphModule implementing the scan body - init: Initial carry state values - xs: Input tensors to scan over (along dim 0) - additional_inputs: Additional arguments passed to combine_fn - meta: Node metadata - - Returns: - ProxyValue containing (final_carry, stacked_outputs) - """ - # Get the first slice of xs to determine input shapes for combine_fn - # combine_fn inputs: (*init, *xs_slice, *additional_inputs) xs_first_slice = _unstack_pytree([arg.data for arg in xs])[0] init_data = [arg.data for arg in init] additional_data = [arg.data for arg in additional_inputs] - # Call submodule with representative inputs combine_fn_result = self.call_submodule( combine_fn, tuple(init_data + xs_first_slice + additional_data) ) diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py index 33e24af6532..efd2fe44a74 100644 --- a/exir/passes/spec_prop_pass.py +++ b/exir/passes/spec_prop_pass.py @@ -150,61 +150,23 @@ def call_scan( additional_inputs: List[ProxyValue], meta: NodeMetadata, ) -> ProxyValue: - """ - Propagate specs for scan higher-order operation. - - Scan returns (final_carry, stacked_outputs) where: - - final_carry: Same shape as init (NOT stacked, just the final carry state) - - stacked_outputs: Outputs stacked along dim 0 with scan_length - - The combine_fn signature is: - combine_fn(*init, *xs_slice, *additional_inputs) -> (*next_carry, *y_slice) - - So the combine_fn outputs are split into: - - First len(init) outputs: carry values (same shape as init) - - Remaining outputs: y values (to be stacked) - - Memory Layout Note: - The specs created here are for the FINAL outputs of the scan operation: - - carry specs: Working carry buffers that persist across iterations. - These are SEPARATE from combine_fn's output buffers. The emitter - must copy from combine_fn's temporary carry output to these buffers - after each iteration (in-place op.out(x, out=x) is unsafe). - - y specs: Pre-allocated stacked buffers filled via et_copy_index. - - The combine_fn's internal temporary buffers are allocated separately - via memory planning with alloc_graph_input=True, alloc_graph_output=True. - """ - # Get scan length from first xs tensor scan_length = [arg.data for arg in xs][0].size(0) - # Get the output node from combine_fn *_, body_out_node = combine_fn.graph.nodes body_out_fake = body_out_node.meta["val"] - # The combine_fn outputs are: (*next_carry, *y_slice) - # Split them based on the number of init values num_carry = len(init) - - # Flatten the outputs to handle them uniformly flat_body_out, out_spec = pytree.tree_flatten(body_out_fake) - # Split into carry outputs and y outputs carry_out = flat_body_out[:num_carry] y_out = flat_body_out[num_carry:] - # Create specs: - # - Carry: same shape as combine_fn output (NOT stacked) - # These are working buffers that get updated each iteration - # - Y: stacked along dim 0 with scan_length - carry_fake = carry_out # Carry keeps same shape - + carry_fake = carry_out y_fake = [ x.new_empty(scan_length, *x.shape) if isinstance(x, torch.Tensor) else x for x in y_out ] - # Combine carry and stacked y outputs combined_fake = carry_fake + y_fake meta["spec"] = pytree.tree_map(make_spec, combined_fake) diff --git a/exir/tests/control_flow_models.py b/exir/tests/control_flow_models.py index 82bc5ffb30d..edca61bd85b 100644 --- a/exir/tests/control_flow_models.py +++ b/exir/tests/control_flow_models.py @@ -158,7 +158,6 @@ def forward( self, xs: torch.Tensor, scale: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: def combine_fn(carry, x): - # Scale is captured from outer scope new_carry = carry + x * scale return new_carry, new_carry.clone() From 9b618dd569ac2c9d0fd50336cc6d90c5575bce48 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Fri, 5 Dec 2025 11:42:12 -0800 Subject: [PATCH 07/19] ditch useless unlift change --- exir/capture/_unlift.py | 43 ----------------------------------------- 1 file changed, 43 deletions(-) diff --git a/exir/capture/_unlift.py b/exir/capture/_unlift.py index 502dcc4ba3e..24c84494883 100644 --- a/exir/capture/_unlift.py +++ b/exir/capture/_unlift.py @@ -97,49 +97,6 @@ def _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict): _unlift( body_gm, inp_pos_to_buffer_name_for_submod, in_spec, None, state_dict ) - if node.op == "call_function" and node.target.__name__ == "scan": - # scan signature: scan(combine_fn, init, xs, additional_inputs) - # - combine_fn: GraphModule for the scan body - # - init: list of initial carry tensors - # - xs: list of input tensors to scan over - # - additional_inputs: tuple of additional arguments (may contain lifted params/buffers) - combine_fn, init, xs, additional_inputs = node.args - combine_gm = getattr(gm, combine_fn.name) - inp_pos_to_buffer_name_for_submod = {} - real_additional_inputs = [] - - # additional_inputs may contain lifted parameters/buffers that need to be - # registered in the combine_fn submodule - for ix, operand in enumerate(additional_inputs): - if ( - hasattr(operand, "target") - and operand.target in inp_pos_to_param_buffer_name.values() - ): - # This is a lifted param/buffer, register it in the submodule - # The index needs to account for init and xs inputs to combine_fn - # combine_fn inputs: (*init, *xs_slice, *additional_inputs) - num_init = len(init) if isinstance(init, (list, tuple)) else 1 - num_xs = len(xs) if isinstance(xs, (list, tuple)) else 1 - adjusted_ix = num_init + num_xs + ix - inp_pos_to_buffer_name_for_submod[adjusted_ix] = operand.target - combine_gm.register_buffer( - operand.target, state_dict[operand.target] - ) - else: - real_additional_inputs.append(operand) - - # Update node args with the filtered additional_inputs - node.args = (combine_fn, init, xs, tuple(real_additional_inputs)) - - _, in_spec = pytree.tree_flatten((init, xs, tuple(real_additional_inputs))) - - _unlift( - combine_gm, - inp_pos_to_buffer_name_for_submod, - in_spec, - None, - state_dict, - ) gm.graph.lint() gm.graph.eliminate_dead_code() gm.recompile() From f96524493a267aa3cd7b0631d5d12da0b4bd9335 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Fri, 5 Dec 2025 14:28:23 -0800 Subject: [PATCH 08/19] fix dynamic shape --- exir/emit/_emitter.py | 27 ++++++++++++- exir/emit/test/test_emit.py | 64 +++++++++++++++++++++++++++--- exir/passes/spec_prop_pass.py | 46 +++++++++++++++++++-- kernels/prim_ops/et_copy_index.cpp | 26 ++++-------- 4 files changed, 134 insertions(+), 29 deletions(-) diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 780c0d7e490..b8dbd03c1a0 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -1183,9 +1183,34 @@ def _emit_control_flow( See the more specific emission functions for more details on how cond, map, or scan get emitted. """ + specs = self.node.meta["spec"] + + # For map and scan operations, the stacked output tensors are built up + # incrementally via et_copy_index. These tensors need to be marked as + # DYNAMIC_BOUND to allow resizing during execution, even if their shape + # appears static. + if target is torch.ops.higher_order.map_impl: + # Map stacks all outputs, so all output specs need DYNAMIC_BOUND + def mark_dynamic_bounded(spec: TensorSpec) -> TensorSpec: + spec.shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND + return spec + + specs = pytree.tree_map(mark_dynamic_bounded, specs) + + elif target is torch.ops.higher_order.scan: + # Scan has (carry_outputs, y_outputs). Only y_outputs are stacked + # and need DYNAMIC_BOUND. carry_outputs keep their original dynamism. + init = args[1] + num_carry = len(init) if isinstance(init, (list, tuple)) else 1 + + flat_specs, spec_tree = pytree.tree_flatten(specs) + for i in range(num_carry, len(flat_specs)): + flat_specs[i].shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND + specs = pytree.tree_unflatten(flat_specs, spec_tree) + subemitter_binding_output_values = pytree.tree_map( lambda spec: self._emit_spec(spec), - self.node.meta["spec"], + specs, ) if target is torch.ops.higher_order.cond: diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 6c05640269c..3ed432c1872 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -62,10 +62,10 @@ _load_for_executorch_from_buffer, ) from executorch.runtime import Runtime - -from functorch.experimental import control_flow from torch import nn +from torch._higher_order_ops import cond as torch_cond, map as torch_map + from torch.export import Dim, export from torch.export.experimental import _export_forward_backward @@ -674,7 +674,7 @@ def true_fn(y: torch.Tensor) -> torch.Tensor: def false_fn(y: torch.Tensor) -> torch.Tensor: return torch.mm(y, y) - ret = control_flow.cond(pred, true_fn, false_fn, [x]) + ret = torch_cond(pred, true_fn, false_fn, [x]) return ret module = to_edge( @@ -708,7 +708,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + y - return control_flow.map(map_fn, x, y) + return torch_map(map_fn, x, y) f = Foo() @@ -781,7 +781,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + y - return control_flow.map(map_fn, x, y) + return torch_map(map_fn, x, y) f = Foo() @@ -798,7 +798,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + y - return control_flow.map(map_fn, x, y) + return torch_map(map_fn, x, y) f = Foo() @@ -1077,6 +1077,58 @@ def gru_cell( f"Final hidden state mismatch: {et_outputs[1]} vs {eager_outputs[1]}", ) + def test_run_emit_scan_dynamic_shape(self) -> None: + """Test scan execution with dynamic sequence length.""" + from torch._higher_order_ops.scan import scan + + class ScanCumSum(torch.nn.Module): + def forward(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def combine_fn(carry, x): + new_carry = carry + x + return new_carry, new_carry.clone() + + init = torch.zeros_like(xs[0]) + return scan(combine_fn, init, xs) + + f = ScanCumSum() + + upper_bound_inputs = (torch.arange(24).float().reshape(8, 3),) + + seq_dim = Dim("seq", max=8) + dynamic_shapes = {"xs": {0: seq_dim}} + + module = to_edge( + export(f, upper_bound_inputs, dynamic_shapes=dynamic_shapes, strict=True), + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), + ) + et = module.to_executorch( + config=exir.ExecutorchBackendConfig( + sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), + ) + ) + et.dump_executorch_program(True) + buffer = et.buffer + loaded_model = _load_for_executorch_from_buffer(buffer) + + test_cases = [ + (torch.arange(15).float().reshape(5, 3),), # seq_len=5 + (torch.arange(9).float().reshape(3, 3),), # seq_len=3 + (torch.arange(21).float().reshape(7, 3),), # seq_len=7 + ] + + for test_inputs in test_cases: + et_outputs = loaded_model(test_inputs) + eager_outputs = f(*test_inputs) + + self.assertTrue( + torch.allclose(et_outputs[0], eager_outputs[0]), + f"Final carry mismatch for shape {test_inputs[0].shape}: {et_outputs[0]} vs {eager_outputs[0]}", + ) + self.assertTrue( + torch.allclose(et_outputs[1], eager_outputs[1]), + f"Stacked outputs mismatch for shape {test_inputs[0].shape}: {et_outputs[1]} vs {eager_outputs[1]}", + ) + def test_dim_order(self) -> None: class SimpleLinear(torch.nn.Module): def __init__(self) -> None: diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py index efd2fe44a74..78b1bcfc82e 100644 --- a/exir/passes/spec_prop_pass.py +++ b/exir/passes/spec_prop_pass.py @@ -11,6 +11,7 @@ import torch from executorch.exir.delegate import executorch_call_delegate from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue +from executorch.exir.schema import TensorShapeDynamism from executorch.exir.tensor import TensorSpec from torch.export.exported_program import ExportGraphSignature from torch.fx.node import Node @@ -134,9 +135,16 @@ def call_map( mapped_dim_size = [arg.data for arg in mapped_args][0].size(0) *_, body_out_node = f.graph.nodes body_out_node_fake_tensor = body_out_node.meta["val"] + + # For dynamic shapes, initialize with size 0 in the mapped dimension. + # The et_copy_index op will resize as it writes to each index. + # Check if the mapped dimension is symbolic (dynamic). + is_dynamic = isinstance(mapped_dim_size, torch.SymInt) + init_size = 0 if is_dynamic else mapped_dim_size + map_fake_tensor = pytree.tree_map_only( torch.Tensor, - lambda x: x.new_empty(mapped_dim_size, *x.shape), + lambda x: x.new_empty(init_size, *x.shape), body_out_node_fake_tensor, ) meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor) @@ -150,7 +158,9 @@ def call_scan( additional_inputs: List[ProxyValue], meta: NodeMetadata, ) -> ProxyValue: - scan_length = [arg.data for arg in xs][0].size(0) + # Get the scan length - this may be symbolic for dynamic shapes + xs_tensor = [arg.data for arg in xs][0] + scan_length = xs_tensor.size(0) *_, body_out_node = combine_fn.graph.nodes body_out_fake = body_out_node.meta["val"] @@ -161,15 +171,43 @@ def call_scan( carry_out = flat_body_out[:num_carry] y_out = flat_body_out[num_carry:] + # Check if the scan dimension is symbolic (dynamic) + is_dynamic = isinstance(scan_length, torch.SymInt) + + # For the y outputs, we need to use the upper bound size to allocate memory, + # but also mark the tensor spec as DYNAMIC_BOUND so it can be resized at runtime. + if is_dynamic: + # Get the upper bound by evaluating the symbolic int + # Using hint gives us the concrete upper bound value + upper_bound_size = scan_length.node.shape_env.size_hint( + scan_length.node.expr + ) + else: + upper_bound_size = scan_length + carry_fake = carry_out y_fake = [ - x.new_empty(scan_length, *x.shape) if isinstance(x, torch.Tensor) else x + ( + x.new_empty(upper_bound_size, *x.shape) + if isinstance(x, torch.Tensor) + else x + ) for x in y_out ] combined_fake = carry_fake + y_fake - meta["spec"] = pytree.tree_map(make_spec, combined_fake) + # Create specs from the fake tensors + specs = pytree.tree_map(make_spec, combined_fake) + + # For dynamic shapes, mark the y_output specs as DYNAMIC_BOUND + # so that et_copy_index can resize them at runtime + if is_dynamic and isinstance(specs, list): + for i in range(num_carry, len(specs)): + if isinstance(specs[i], TensorSpec): + specs[i].shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND + + meta["spec"] = specs return super().call_scan(combine_fn, init, xs, additional_inputs, meta) # pyre-ignore diff --git a/kernels/prim_ops/et_copy_index.cpp b/kernels/prim_ops/et_copy_index.cpp index 2ef076ad1a0..d81fd947ca4 100644 --- a/kernels/prim_ops/et_copy_index.cpp +++ b/kernels/prim_ops/et_copy_index.cpp @@ -114,27 +114,17 @@ void et_copy_index(KernelRuntimeContext& context, Span stack) { expected_output_size[i + 1] = copy_from.sizes()[i]; } - if (copy_to.sizes()[0] < expected_output_size[0]) { - // Resize `copy_to` to the expected output size. - const void* data_ptr = copy_to.const_data_ptr(); - Error err = - resize_tensor(copy_to, {expected_output_size, copy_to.sizes().size()}); - ET_CHECK(err == Error::Ok); - ET_KERNEL_CHECK_MSG( - context, - data_ptr == copy_to.const_data_ptr(), - InvalidState, - /* void */, - "Data ptr of copy_to tensor changed after resize which isn't allowed for static/upper-bounded tensors"); - } - - // After potential resize, verify that index is within bounds. + // Resize `copy_to` to the expected output size. This grows the tensor + // as we write to each index (0→1, 1→2, 2→3, etc.). + const void* data_ptr = copy_to.const_data_ptr(); + Error err = + resize_tensor(copy_to, {expected_output_size, copy_to.sizes().size()}); ET_KERNEL_CHECK_MSG( context, - index < copy_to.sizes()[0], - InvalidArgument, + err == Error::Ok, + InvalidState, /* void */, - "Index out of bounds"); + "Failed to resize copy_to tensor"); auto copy_to_ptr = copy_to.const_data_ptr(); auto copy_from_ptr = copy_from.const_data_ptr(); From 45eee873256fecd727af1b3eb02195cf0bdc491d Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Tue, 9 Dec 2025 15:05:52 -0800 Subject: [PATCH 09/19] undo unneeded change --- exir/tests/control_flow_models.py | 63 ------------------------------- 1 file changed, 63 deletions(-) diff --git a/exir/tests/control_flow_models.py b/exir/tests/control_flow_models.py index edca61bd85b..3c0fd8badab 100644 --- a/exir/tests/control_flow_models.py +++ b/exir/tests/control_flow_models.py @@ -103,66 +103,3 @@ def get_upper_bound_inputs(self): def get_random_inputs(self): return torch.rand(2, 4), torch.rand(4) - - -class FTScanBasic(Module): - """Basic scan model that computes cumulative sum.""" - - def __init__(self): - super().__init__() - - def forward(self, xs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - def combine_fn(carry, x): - new_carry = carry + x - y = new_carry.clone() - return new_carry, y - - init = torch.zeros_like(xs[0]) - return torch.scan(combine_fn, init, xs) - - def get_random_inputs(self): - return (torch.arange(5).float().unsqueeze(1).expand(5, 3),) - - -class FTScanMultipleCarry(Module): - """Scan model with multiple carry values (sum and product).""" - - def __init__(self): - super().__init__() - - def forward( - self, xs: torch.Tensor - ) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]: - def combine_fn(carry, x): - sum_carry, prod_carry = carry - new_sum = sum_carry + x - new_prod = prod_carry * x - y = new_sum + new_prod - return (new_sum, new_prod), y.clone() - - init_sum = torch.zeros_like(xs[0]) - init_prod = torch.ones_like(xs[0]) - return torch.scan(combine_fn, (init_sum, init_prod), xs) - - def get_random_inputs(self): - return (torch.arange(1, 5).float().unsqueeze(1).expand(4, 2),) - - -class FTScanWithAdditionalInputs(Module): - """Scan model with additional inputs (closure-like behavior).""" - - def __init__(self): - super().__init__() - - def forward( - self, xs: torch.Tensor, scale: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - def combine_fn(carry, x): - new_carry = carry + x * scale - return new_carry, new_carry.clone() - - init = torch.zeros_like(xs[0]) - return torch.scan(combine_fn, init, xs) - - def get_random_inputs(self): - return (torch.arange(5).float().unsqueeze(1).expand(5, 3), torch.tensor([2.0])) From 843b3d275663e0cb68f3529edb8c285c45833b81 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Tue, 9 Dec 2025 15:14:43 -0800 Subject: [PATCH 10/19] more clean up --- exir/emit/_emitter.py | 23 ---------------- kernels/prim_ops/test/prim_ops_test.cpp | 36 ------------------------- 2 files changed, 59 deletions(-) diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index b8dbd03c1a0..86d6d369387 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -1185,29 +1185,6 @@ def _emit_control_flow( """ specs = self.node.meta["spec"] - # For map and scan operations, the stacked output tensors are built up - # incrementally via et_copy_index. These tensors need to be marked as - # DYNAMIC_BOUND to allow resizing during execution, even if their shape - # appears static. - if target is torch.ops.higher_order.map_impl: - # Map stacks all outputs, so all output specs need DYNAMIC_BOUND - def mark_dynamic_bounded(spec: TensorSpec) -> TensorSpec: - spec.shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND - return spec - - specs = pytree.tree_map(mark_dynamic_bounded, specs) - - elif target is torch.ops.higher_order.scan: - # Scan has (carry_outputs, y_outputs). Only y_outputs are stacked - # and need DYNAMIC_BOUND. carry_outputs keep their original dynamism. - init = args[1] - num_carry = len(init) if isinstance(init, (list, tuple)) else 1 - - flat_specs, spec_tree = pytree.tree_flatten(specs) - for i in range(num_carry, len(flat_specs)): - flat_specs[i].shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND - specs = pytree.tree_unflatten(flat_specs, spec_tree) - subemitter_binding_output_values = pytree.tree_map( lambda spec: self._emit_spec(spec), specs, diff --git a/kernels/prim_ops/test/prim_ops_test.cpp b/kernels/prim_ops/test/prim_ops_test.cpp index b46733045fb..ff0d633fd94 100644 --- a/kernels/prim_ops/test/prim_ops_test.cpp +++ b/kernels/prim_ops/test/prim_ops_test.cpp @@ -281,42 +281,6 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndexMismatchShape) { getOpsFn("executorch_prim::et_copy_index.tensor")(context_, stack)); } -TEST_F(RegisterPrimOpsTest, TestETCopyIndexStaticShape) { - int64_t index = 1; - testing::TensorFactory tf; - - EValue values[3]; - EValue* stack[3]; - - // Test with static shape tensors. - const std::vector buf = {1, 2, 3, 4}; - auto copy_to = tf.make({2, 2}, buf); - auto to_copy = tf.make({2}, {5, 6}); - - values[0] = EValue(copy_to); - values[1] = EValue(to_copy); - values[2] = EValue(index); - - stack[0] = &values[0]; - stack[1] = &values[1]; - stack[2] = &values[2]; - - // Copy and replace at index 1. - getOpsFn("executorch_prim::et_copy_index.tensor")(context_, stack); - EXPECT_EQ(copy_to.sizes()[0], 2); - EXPECT_EQ(copy_to.sizes()[1], 2); - EXPECT_TENSOR_EQ(copy_to, tf.make({2, 2}, {1, 2, 5, 6})); - -#ifndef USE_ATEN_LIB - // Copy and replace at index 2. This should trigger an EXPECT - // in lean mode. - index = 2; - values[2] = EValue(index); - ET_EXPECT_DEATH( - getOpsFn("executorch_prim::et_copy_index.tensor")(context_, stack), ""); -#endif -} - TEST_F(RegisterPrimOpsTest, TestBooleanOps) { EValue values[3]; double a = 3; From a2789d89730cf418710e73cd03876f6797d1306f Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Wed, 10 Dec 2025 10:59:20 -0800 Subject: [PATCH 11/19] missing type --- exir/passes/spec_prop_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py index c39d23845a3..2ab113f5e2a 100644 --- a/exir/passes/spec_prop_pass.py +++ b/exir/passes/spec_prop_pass.py @@ -7,7 +7,7 @@ # pyre-strict import operator -from typing import Optional +from typing import List, Optional import torch from executorch.exir.delegate import executorch_call_delegate From 4497c27392f2d757ee64319a53d03dda3359ad3b Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Wed, 10 Dec 2025 11:45:42 -0800 Subject: [PATCH 12/19] more missing type --- exir/passes/spec_prop_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py index 2ab113f5e2a..eb23171f1f4 100644 --- a/exir/passes/spec_prop_pass.py +++ b/exir/passes/spec_prop_pass.py @@ -11,7 +11,7 @@ import torch from executorch.exir.delegate import executorch_call_delegate -from executorch.exir.pass_base import ExportPass, ProxyValue +from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue from executorch.exir.schema import TensorShapeDynamism from executorch.exir.tensor import TensorSpec from torch.export.exported_program import ExportGraphSignature From 58bfc9c812a768744a8baa7ea9cdb52248754c17 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Wed, 10 Dec 2025 11:51:51 -0800 Subject: [PATCH 13/19] dead code --- kernels/prim_ops/et_copy_index.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/kernels/prim_ops/et_copy_index.cpp b/kernels/prim_ops/et_copy_index.cpp index d81fd947ca4..2d46069fcde 100644 --- a/kernels/prim_ops/et_copy_index.cpp +++ b/kernels/prim_ops/et_copy_index.cpp @@ -116,7 +116,6 @@ void et_copy_index(KernelRuntimeContext& context, Span stack) { // Resize `copy_to` to the expected output size. This grows the tensor // as we write to each index (0→1, 1→2, 2→3, etc.). - const void* data_ptr = copy_to.const_data_ptr(); Error err = resize_tensor(copy_to, {expected_output_size, copy_to.sizes().size()}); ET_KERNEL_CHECK_MSG( From 6bb49079fbed504687f573d47ded0b290b92373b Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Wed, 10 Dec 2025 12:05:42 -0800 Subject: [PATCH 14/19] improve logs --- exir/emit/_emitter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 86d6d369387..c25db311cec 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -983,7 +983,7 @@ def _emit_scan( assert isinstance( subemitter_binding_output_values, (list, tuple) - ), f"Expected list for subemitter_binding_output_values. Got {subemitter_binding_output_values}." + ), f"Expected list for subemitter_binding_output_values. Got {type(subemitter_binding_output_values).__name__}: {subemitter_binding_output_values}." assert isinstance(combine_fn, torch.fx.GraphModule) assert isinstance(init, (list, tuple)) @@ -997,7 +997,7 @@ def _emit_scan( y_outputs = list(subemitter_binding_output_values[num_carry:]) if num_xs < 1: - raise RuntimeError("Scan requires at least one xs tensor to scan over.") + raise RuntimeError(f"Scan requires at least one xs tensor to scan over but got {num_xs}") iter_idx = self._emit_evalue(EValue(Int(0))) From d18aa7839090836d5ca4a1adc649c2da7890ca44 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Wed, 10 Dec 2025 12:06:50 -0800 Subject: [PATCH 15/19] lint --- exir/emit/_emitter.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index c25db311cec..b0ac3b15ff3 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -981,9 +981,11 @@ def _emit_scan( """ combine_fn, init, xs, additional_inputs = args - assert isinstance( - subemitter_binding_output_values, (list, tuple) - ), f"Expected list for subemitter_binding_output_values. Got {type(subemitter_binding_output_values).__name__}: {subemitter_binding_output_values}." + assert isinstance(subemitter_binding_output_values, (list, tuple)), ( + f"Expected list for subemitter_binding_output_values. " + f"Got {type(subemitter_binding_output_values).__name__}: " + f"{subemitter_binding_output_values}." + ) assert isinstance(combine_fn, torch.fx.GraphModule) assert isinstance(init, (list, tuple)) @@ -997,7 +999,9 @@ def _emit_scan( y_outputs = list(subemitter_binding_output_values[num_carry:]) if num_xs < 1: - raise RuntimeError(f"Scan requires at least one xs tensor to scan over but got {num_xs}") + raise RuntimeError( + f"Scan requires at least one xs tensor to scan over but got {num_xs}" + ) iter_idx = self._emit_evalue(EValue(Int(0))) From 3bbb02cdfafa0ba3bed31c6e58440f5ccd1a40a4 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Wed, 10 Dec 2025 15:05:30 -0800 Subject: [PATCH 16/19] ok fix the spec issues at the source --- exir/pass_base.py | 19 +++-- exir/passes/spec_prop_pass.py | 153 ---------------------------------- exir/tests/test_passes.py | 103 +++++++++++++++++++++++ 3 files changed, 117 insertions(+), 158 deletions(-) diff --git a/exir/pass_base.py b/exir/pass_base.py index 17feadd0a43..500d57bf9b7 100644 --- a/exir/pass_base.py +++ b/exir/pass_base.py @@ -554,16 +554,25 @@ def call_scan( self, combine_fn: torch.fx.GraphModule, init: List[ProxyValue], - xs: List[ProxyValue], + xs: List[Argument], additional_inputs: List[ProxyValue], meta: NodeMetadata, ) -> ProxyValue: - xs_first_slice = _unstack_pytree([arg.data for arg in xs])[0] - init_data = [arg.data for arg in init] - additional_data = [arg.data for arg in additional_inputs] + # Get the expected x element shapes from the combine_fn's placeholders + # The combine_fn expects: (carry..., x_element..., additional_inputs...) + combine_fn_placeholders = [ + n for n in combine_fn.graph.nodes if n.op == "placeholder" + ] + num_init = len(init) + # The x_element placeholders are at indices [num_init : num_init + num_xs] + xs_element_data = [] + for i, x_proxy in enumerate(xs): + ph = combine_fn_placeholders[num_init + i] + # Use the placeholder's val which has the correct shape + xs_element_data.append(ph.meta["val"]) combine_fn_result = self.call_submodule( - combine_fn, tuple(init_data + xs_first_slice + additional_data) + combine_fn, (*init , *xs_element_data , *additional_inputs) ) assert combine_fn_result is not None diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py index eb23171f1f4..0b7db0f75e9 100644 --- a/exir/passes/spec_prop_pass.py +++ b/exir/passes/spec_prop_pass.py @@ -12,7 +12,6 @@ import torch from executorch.exir.delegate import executorch_call_delegate from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue -from executorch.exir.schema import TensorShapeDynamism from executorch.exir.tensor import TensorSpec from torch.export.exported_program import ExportGraphSignature from torch.fx.node import Node @@ -60,18 +59,15 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult: res = ExportPass()(graph_module) assert res is not None gm = res.graph_module - def get_spec(x): if hasattr(x, "meta"): return x.meta.get("spec", None) else: return None - for module in gm.modules(): if isinstance(module, torch.fx.GraphModule): for node in module.graph.nodes: meta_val = node.meta.get("val", None) - if node.op == "output": node.meta["spec"] = pytree.tree_map(get_spec, node.args[0]) elif node.op == "call_function" and node.target == operator.getitem: @@ -123,152 +119,3 @@ def update_placeholder_tensor_specs( in exported_program.graph_signature.inputs_to_lifted_tensor_constants ): spec.const = True - - # pyre-ignore - def placeholder(self, name: str, arg, meta): - meta["spec"] = make_spec(arg) - return super().placeholder(name, arg, meta) - - # pyre-ignore - def call_operator(self, op, args, kwargs, meta): - args_data, kwargs_data = pytree.tree_map_only( - ProxyValue, lambda x: x.data, (args, kwargs) - ) - meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data)) - return super().call_operator(op, args, kwargs, meta) - - # pyre-ignore - def call_getitem(self, value, key: int, meta): - meta["spec"] = value.node.meta["spec"][key] - return super().call_getitem(value, key, meta) - - # pyre-ignore - def call_cond(self, pred, true_fn, false_fn, inputs, meta): - # true_fn/false_fn return tensors of the same shape, so we can pick - # either one here. - *_, true_out_node = true_fn.graph.nodes - meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"]) - return super().call_cond(pred, true_fn, false_fn, inputs, meta) - - def call_while( - self, - cond_fn: torch.fx.GraphModule, - body_fn: torch.fx.GraphModule, - carried_inputs: List[ProxyValue], - additional_inputs: List[ProxyValue], - meta: NodeMetadata, - ): - meta["spec"] = pytree.tree_map(make_spec, carried_inputs) - return super().call_while( - cond_fn, body_fn, carried_inputs, additional_inputs, meta - ) - - def call_map( - self, - f: torch.fx.GraphModule, - mapped_args: List[ProxyValue], - operands: List[ProxyValue], - meta: NodeMetadata, - ) -> ProxyValue: - mapped_dim_size = [arg.data for arg in mapped_args][0].size(0) - *_, body_out_node = f.graph.nodes - body_out_node_fake_tensor = body_out_node.meta["val"] - - # For dynamic shapes, initialize with size 0 in the mapped dimension. - # The et_copy_index op will resize as it writes to each index. - # Check if the mapped dimension is symbolic (dynamic). - is_dynamic = isinstance(mapped_dim_size, torch.SymInt) - init_size = 0 if is_dynamic else mapped_dim_size - - map_fake_tensor = pytree.tree_map_only( - torch.Tensor, - lambda x: x.new_empty(init_size, *x.shape), - body_out_node_fake_tensor, - ) - meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor) - return super().call_map(f, mapped_args, operands, meta) - - def call_scan( - self, - combine_fn: torch.fx.GraphModule, - init: List[ProxyValue], - xs: List[ProxyValue], - additional_inputs: List[ProxyValue], - meta: NodeMetadata, - ) -> ProxyValue: - # Get the scan length - this may be symbolic for dynamic shapes - xs_tensor = [arg.data for arg in xs][0] - scan_length = xs_tensor.size(0) - - *_, body_out_node = combine_fn.graph.nodes - body_out_fake = body_out_node.meta["val"] - - num_carry = len(init) - flat_body_out, out_spec = pytree.tree_flatten(body_out_fake) - - carry_out = flat_body_out[:num_carry] - y_out = flat_body_out[num_carry:] - - # Check if the scan dimension is symbolic (dynamic) - is_dynamic = isinstance(scan_length, torch.SymInt) - - # For the y outputs, we need to use the upper bound size to allocate memory, - # but also mark the tensor spec as DYNAMIC_BOUND so it can be resized at runtime. - if is_dynamic: - # Get the upper bound by evaluating the symbolic int - # Using hint gives us the concrete upper bound value - upper_bound_size = scan_length.node.shape_env.size_hint( - scan_length.node.expr - ) - else: - upper_bound_size = scan_length - - carry_fake = carry_out - y_fake = [ - ( - x.new_empty(upper_bound_size, *x.shape) - if isinstance(x, torch.Tensor) - else x - ) - for x in y_out - ] - - combined_fake = carry_fake + y_fake - - # Create specs from the fake tensors - specs = pytree.tree_map(make_spec, combined_fake) - - # For dynamic shapes, mark the y_output specs as DYNAMIC_BOUND - # so that et_copy_index can resize them at runtime - if is_dynamic and isinstance(specs, list): - for i in range(num_carry, len(specs)): - if isinstance(specs[i], TensorSpec): - specs[i].shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND - - meta["spec"] = specs - return super().call_scan(combine_fn, init, xs, additional_inputs, meta) - - # pyre-ignore - def call_delegate(self, lowered_module, args, kwargs, meta): - args_data, kwargs_data = pytree.tree_map_only( - ProxyValue, lambda x: x.data, (args, kwargs) - ) - # If spec is missing, re-genenrate it with args data - if "spec" not in meta: - meta["spec"] = pytree.tree_map( - make_spec, - executorch_call_delegate(lowered_module, *args_data), - ) - return super().call_delegate(lowered_module, args, kwargs, meta) - - # pyre-ignore - def output(self, results, meta): - # pyre-ignore - def get_spec(x): - if isinstance(x, ProxyValue): - return x.node.meta["spec"] - else: - return make_spec(x) - - meta["spec"] = pytree.tree_map(get_spec, results) - return super().output(results, meta) diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index d398b81ee8f..ff116cc3f88 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -745,6 +745,109 @@ def loop_body(i, acc): upper_bound = eval_upper_bound(spec[1].shape[0]) self.assertEqual(upper_bound, 10) + def test_spec_prop_pass_scan(self) -> None: + from torch._higher_order_ops.scan import scan + + class ModelWithScan(torch.nn.Module): + def forward(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def combine_fn(carry, x): + new_carry = carry + x + return new_carry, new_carry.clone() + + init = torch.zeros_like(xs[0]) + return scan(combine_fn, init, xs) + + model = ModelWithScan() + inputs = (torch.arange(15).float().reshape(5, 3),) + + # Run the spec prop pass and sanity check the spec on the scan. + edge_program = to_edge( + export(model, inputs, strict=True), + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + gm = edge_program.exported_program().graph_module + new_gm = SpecPropPass()(gm) + self.assertIsNotNone(new_gm) + + # Check the spec for the scan node. + scan_node = next( + n + for n in new_gm.graph_module.graph.nodes + if hasattr(n.target, "name") and n.target.name() == "scan" + ) + self.assertIsNotNone(scan_node) + + # Spec for the scan node should be a two-element tuple (carry, stacked_outputs) + spec: Tuple[TensorSpec, TensorSpec] = scan_node.meta["spec"] + self.assertTrue(isinstance(spec, tuple)) + self.assertEqual(len(spec), 2) + + # Carry should have shape [3] (same as xs[0]) + self.assertEqual(list(spec[0].shape), [3]) + # Stacked outputs should have shape [5, 3] (same as xs) + self.assertEqual(list(spec[1].shape), [5, 3]) + + def test_spec_prop_pass_scan_dynamic_shape(self) -> None: + from torch._higher_order_ops.scan import scan + + class ModelWithScan(torch.nn.Module): + def forward(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def combine_fn(carry, x): + new_carry = carry + x + return new_carry, new_carry.clone() + + init = torch.zeros_like(xs[0]) + return scan(combine_fn, init, xs) + + model = ModelWithScan() + inputs = (torch.arange(15).float().reshape(5, 3),) + dynamic_shapes = {"xs": {0: torch.export.Dim("seq_len", min=1, max=20)}} + + # First verify that export preserves symbolic shapes + exported = export(model, inputs, dynamic_shapes=dynamic_shapes, strict=True) + scan_node_after_export = next( + n + for n in exported.graph.nodes + if hasattr(n.target, "name") and n.target.name() == "scan" + ) + val_after_export = scan_node_after_export.meta.get("val") + self.assertIsNotNone(val_after_export) + # After export, the stacked output should have a symbolic first dimension + self.assertIsInstance(val_after_export[1].shape[0], torch.SymInt) + + # Run the spec prop pass and sanity check the spec on the scan. + edge_program = to_edge( + exported, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + gm = edge_program.exported_program().graph_module + new_gm = SpecPropPass()(gm) + self.assertIsNotNone(new_gm) + + # Check the spec for the scan node. + scan_node = next( + n + for n in new_gm.graph_module.graph.nodes + if hasattr(n.target, "name") and n.target.name() == "scan" + ) + self.assertIsNotNone(scan_node) + + # Spec for the scan node should be a two-element tuple (carry, stacked_outputs) + spec: Tuple[TensorSpec, TensorSpec] = scan_node.meta["spec"] + self.assertTrue(isinstance(spec, tuple)) + self.assertEqual(len(spec), 2) + + # Carry should have static shape [3] + self.assertEqual(list(spec[0].shape), [3]) + self.assertEqual(spec[0].shape_dynamism, TensorShapeDynamism.STATIC) + + # Stacked outputs should have dynamic first dimension with upper bound 20 + self.assertEqual(len(spec[1].shape), 2) + upper_bound = eval_upper_bound(spec[1].shape[0]) + self.assertEqual(upper_bound, 20) + self.assertEqual(spec[1].shape_dynamism, TensorShapeDynamism.DYNAMIC_BOUND) + self.assertEqual(spec[1].shape[1], 3) # Second dim is static + def test_compile_fix_broken_ops(self) -> None: class ExportableLoop(nn.Module): def __init__(self, hidden_size, out_channels): From 6c3b0e15130f52def84ecb6fc2252304ad3a43c0 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Wed, 10 Dec 2025 15:26:29 -0800 Subject: [PATCH 17/19] Now add the pseudo hack in the emitter --- exir/emit/_emitter.py | 13 +++++++++++++ exir/pass_base.py | 2 +- exir/passes/spec_prop_pass.py | 2 ++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index b0ac3b15ff3..ff654670919 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -1189,6 +1189,19 @@ def _emit_control_flow( """ specs = self.node.meta["spec"] + # For scan, set the shape_dynamism for the stacked outputs (y_outputs) to DYNAMIC_BOUND + # BEFORE emitting the specs. This is because et_copy_index has cat shape semantics but + # stack memory behavior, so we need to be able to update the shape +1 for each iteration + # which we can't do for tensors marked static. + if target is torch.ops.higher_order.scan: + combine_fn, init, xs, additional_inputs = args + num_carry = len(init) + if isinstance(specs, (list, tuple)): + y_specs = specs[num_carry:] + for y_spec in y_specs: + if isinstance(y_spec, TensorSpec): + y_spec.shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND + subemitter_binding_output_values = pytree.tree_map( lambda spec: self._emit_spec(spec), specs, diff --git a/exir/pass_base.py b/exir/pass_base.py index 500d57bf9b7..8b365fceac9 100644 --- a/exir/pass_base.py +++ b/exir/pass_base.py @@ -572,7 +572,7 @@ def call_scan( xs_element_data.append(ph.meta["val"]) combine_fn_result = self.call_submodule( - combine_fn, (*init , *xs_element_data , *additional_inputs) + combine_fn, (*init, *xs_element_data, *additional_inputs) ) assert combine_fn_result is not None diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py index 0b7db0f75e9..36fbc2cc8f4 100644 --- a/exir/passes/spec_prop_pass.py +++ b/exir/passes/spec_prop_pass.py @@ -59,11 +59,13 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult: res = ExportPass()(graph_module) assert res is not None gm = res.graph_module + def get_spec(x): if hasattr(x, "meta"): return x.meta.get("spec", None) else: return None + for module in gm.modules(): if isinstance(module, torch.fx.GraphModule): for node in module.graph.nodes: From 04264ac35fcb1dbd6f6fa4a50c0d29eec6d3b8cc Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Wed, 10 Dec 2025 15:56:16 -0800 Subject: [PATCH 18/19] back fix map --- exir/emit/_emitter.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index ff654670919..9e35c3291fd 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -1189,7 +1189,7 @@ def _emit_control_flow( """ specs = self.node.meta["spec"] - # For scan, set the shape_dynamism for the stacked outputs (y_outputs) to DYNAMIC_BOUND + # For scan/map, set the shape_dynamism for the stacked outputs (y_outputs) to DYNAMIC_BOUND # BEFORE emitting the specs. This is because et_copy_index has cat shape semantics but # stack memory behavior, so we need to be able to update the shape +1 for each iteration # which we can't do for tensors marked static. @@ -1201,6 +1201,10 @@ def _emit_control_flow( for y_spec in y_specs: if isinstance(y_spec, TensorSpec): y_spec.shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND + elif target is torch.ops.higher_order.map_impl: + assert len(specs) == 1 + assert isinstance(specs[0], TensorSpec) + specs[0].shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND subemitter_binding_output_values = pytree.tree_map( lambda spec: self._emit_spec(spec), From 90e55dd5fe8b2ddf1897b50480e478a8093ce254 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Wed, 10 Dec 2025 20:51:39 -0800 Subject: [PATCH 19/19] lint --- exir/pass_base.py | 2 +- exir/passes/spec_prop_pass.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/exir/pass_base.py b/exir/pass_base.py index 8b365fceac9..27dd9bf6b73 100644 --- a/exir/pass_base.py +++ b/exir/pass_base.py @@ -566,7 +566,7 @@ def call_scan( num_init = len(init) # The x_element placeholders are at indices [num_init : num_init + num_xs] xs_element_data = [] - for i, x_proxy in enumerate(xs): + for i in range(0, len(xs)): ph = combine_fn_placeholders[num_init + i] # Use the placeholder's val which has the correct shape xs_element_data.append(ph.meta["val"]) diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py index 36fbc2cc8f4..9adbf65dd90 100644 --- a/exir/passes/spec_prop_pass.py +++ b/exir/passes/spec_prop_pass.py @@ -7,11 +7,11 @@ # pyre-strict import operator -from typing import List, Optional +from typing import Optional import torch from executorch.exir.delegate import executorch_call_delegate -from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue +from executorch.exir.pass_base import ExportPass, ProxyValue from executorch.exir.tensor import TensorSpec from torch.export.exported_program import ExportGraphSignature from torch.fx.node import Node