diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py
index 5c24f08c732..9e35c3291fd 100644
--- a/exir/emit/_emitter.py
+++ b/exir/emit/_emitter.py
@@ -944,22 +944,279 @@ def forward(self, x,y):
 
         return subemitter_binding_output_values
 
+    def _emit_scan(
+        self,
+        args: Tuple[_Argument, ...],
+        subemitter_binding_output_values: List[_AbstractValue],
+    ) -> List[_AbstractValue]:
+        """Emits torch.ops.higher_order.scan.
+
+        Converts the higher-order scan op into a loop constructed from jump
+        instructions and primitive operations. Scan differs from map in that it
+        maintains a carry state that evolves across iterations.
+
+        Scan signature: scan(combine_fn, init, xs, additional_inputs)
+        - combine_fn: GraphModule that takes (carry, x_slice, *additional_inputs)
+          and returns (next_carry, y_slice)
+        - init: Initial carry state (list of tensors)
+        - xs: Input tensors to scan over (list of tensors, scanned along dim 0)
+        - additional_inputs: Additional arguments passed to combine_fn
+
+        Output: (final_carry, stacked_ys)
+        - final_carry: The carry state after the last iteration
+        - stacked_ys: All y outputs stacked along dim 0
+
+        Memory Layout:
+        - carry_outputs (subemitter_binding_output_values[:num_carry]):
+          Working carry buffers, initialized from init, updated each iteration
+        - y_outputs (subemitter_binding_output_values[num_carry:]):
+          Pre-allocated stacked output buffers, filled via et_copy_index
+
+        The combine_fn writes to its own temporary output buffers
+        (concrete_output_ids). After each iteration:
+        1. Copy combine_fn's carry output -> carry_outputs (for next iteration)
+        2. et_copy_index(y_outputs, combine_fn's y output, iter_idx)
+
+        This explicit copy approach is used because an in-place op.out(x, out=x)
+        is unsafe.
+        """
+        combine_fn, init, xs, additional_inputs = args
+
+        assert isinstance(subemitter_binding_output_values, (list, tuple)), (
+            f"Expected list for subemitter_binding_output_values. "
+            f"Got {type(subemitter_binding_output_values).__name__}: "
+            f"{subemitter_binding_output_values}."
+        )
+
+        assert isinstance(combine_fn, torch.fx.GraphModule)
+        assert isinstance(init, (list, tuple))
+        assert isinstance(xs, (list, tuple))
+        assert isinstance(additional_inputs, (list, tuple))
+
+        num_carry = len(init)
+        num_xs = len(xs)
+
+        carry_outputs = list(subemitter_binding_output_values[:num_carry])
+        y_outputs = list(subemitter_binding_output_values[num_carry:])
+
+        if num_xs < 1:
+            raise RuntimeError(
+                f"Scan requires at least one xs tensor to scan over but got {num_xs}"
+            )
+
+        iter_idx = self._emit_evalue(EValue(Int(0)))
+
+        # Emit sym_size(xs[0], 0) to get the number of loop iterations.
+        op_index, _ = self._get_operator(
+            name="aten::sym_size",
+            overload="int",
+        )
+        sym_size = self._emit_evalue(EValue(Int(0)))
+        kernel = Instruction(
+            KernelCall(
+                op_index=op_index,
+                args=[xs[0].id, self._emit_evalue(EValue(Int(0))).id, sym_size.id],
+            )
+        )
+        self.chain.instructions.append(kernel)
+
+        # Initialize carry_outputs from init
+        op_index_copy, _ = self._get_operator(name="aten::copy_")
+        for init_val, carry_out in zip(init, carry_outputs):
+            kernel = Instruction(
+                KernelCall(
+                    op_index=op_index_copy,
+                    args=[
+                        carry_out.id,
+                        init_val.id,
+                        self._emit_evalue(EValue(Bool(False))).id,
+                        carry_out.id,
+                    ],
+                )
+            )
+            self.chain.instructions.append(kernel)
+
+        # Slice each xs tensor for the current iteration. The output args are
+        # placeholders (-1) that are patched after combine_fn is emitted below.
+        op_index_select, _ = self._get_operator(
+            name="aten::select_copy",
+            overload="int_out",
+        )
+        xs_slice_instructions = []
+        for x in xs:
+            kernel = Instruction(
+                KernelCall(
+                    op_index=op_index_select,
+                    args=[
+                        x.id,
+                        self._emit_evalue(EValue(Int(0))).id,
+                        iter_idx.id,
+                        -1,
+                        -1,
+                    ],
+                )
+            )
+            xs_slice_instructions.append(kernel)
+
+        # Record the loop-back target before appending the slice kernels: each
+        # iteration re-executes the xs slicing first.
+        jump_to_instruction = self.instruction_start_offset + len(
+            self.chain.instructions
+        )
+
+        for kernel in xs_slice_instructions:
+            self.chain.instructions.append(kernel)
+
+        # Emit combine_fn submodule
+        binding_input_values: List[Any] = []
+        binding_input_values.extend(carry_outputs)
+        binding_input_values.extend([-1] * num_xs)
+        binding_input_values.extend(additional_inputs)
+
+        scan_emitter = _Emitter(
+            combine_fn,
+            self.emitter_state,
+            self.program_state,
+            instruction_start_offset=self.instruction_start_offset
+            + len(self.chain.instructions),
+            binding_input_values=binding_input_values,
+            binding_output_values=None,
+        )
+        scan_emitter.run()
+
+        self._merge_chain(scan_emitter.chain)
+
+        # Patch the slice kernels to write into combine_fn's xs placeholders.
+        for i, kernel in enumerate(xs_slice_instructions):
+            xs_placeholder_id = scan_emitter.binding_input_values[num_carry + i].id
+            kernel.instr_args.args[-1] = xs_placeholder_id
+            kernel.instr_args.args[-2] = xs_placeholder_id
+
+        concrete_outputs = scan_emitter.concrete_output_ids
+        carry_temp = concrete_outputs[:num_carry]
+        y_temp = concrete_outputs[num_carry:]
+
+        self._internal_assert_emitter(
+            len(carry_temp) == num_carry,
+            self.node,
+            f"Scan combine_fn should output {num_carry} carry values, got {len(carry_temp)}",
+        )
+        self._internal_assert_emitter(
+            len(y_temp) == len(y_outputs),
+            self.node,
+            f"Scan combine_fn should output {len(y_outputs)} y values, got {len(y_temp)}",
+        )
+
+        # Copy carry_temp -> carry_outputs for next iteration
+        for carry_t, carry_out in zip(carry_temp, carry_outputs):
+            kernel = Instruction(
+                KernelCall(
+                    op_index=op_index_copy,
+                    args=[
+                        carry_out.id,
+                        carry_t.id,
+                        self._emit_evalue(EValue(Bool(False))).id,
+                        carry_out.id,
+                    ],
+                )
+            )
+            self.chain.instructions.append(kernel)
+
+        # Copy y_temp to stacked y_outputs
+        op_index_copy_index, _ = self._get_operator(
+            name="executorch_prim::et_copy_index",
+            overload="tensor",
+        )
+        for y_t, y_out in zip(y_temp, y_outputs):
+            kernel = Instruction(
+                KernelCall(
+                    op_index=op_index_copy_index,
+                    args=[y_out.id, y_t.id, iter_idx.id],
+                )
+            )
+            self.chain.instructions.append(kernel)
+
+        # Increment iter_idx
+        op_index_add, _ = self._get_operator(
+            name="executorch_prim::add",
+            overload="Scalar",
+        )
+        kernel = Instruction(
+            KernelCall(
+                op_index=op_index_add,
+                args=[iter_idx.id, self._emit_evalue(EValue(Int(1))).id, iter_idx.id],
+            )
+        )
+        self.chain.instructions.append(kernel)
+
+        # Check if iteration is complete
+        jump_bool_value = self._emit_evalue(EValue(Bool(False)))
+        op_index_eq, _ = self._get_operator(
+            name="executorch_prim::eq",
+            overload="Scalar",
+        )
+        kernel = Instruction(
+            KernelCall(
+                op_index=op_index_eq,
+                args=[iter_idx.id, sym_size.id, jump_bool_value.id],
+            )
+        )
+        self.chain.instructions.append(kernel)
+
+        jf_beginning_loop = Instruction(
+            JumpFalseCall(
+                cond_value_index=jump_bool_value.id,
+                destination_instruction=jump_to_instruction,
+            )
+        )
+        self.chain.instructions.append(jf_beginning_loop)
+
+        # Reset iter_idx to 0 (iter_idx == sym_size here) so the program can be
+        # re-run.
+        op_index_sub, _ = self._get_operator(
+            name="executorch_prim::sub",
+            overload="Scalar",
+        )
+        kernel = Instruction(
+            KernelCall(
+                op_index=op_index_sub,
+                args=[iter_idx.id, sym_size.id, iter_idx.id],
+            )
+        )
+        self.chain.instructions.append(kernel)
+
+        return subemitter_binding_output_values
+
     def _emit_control_flow(
         self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
     ) -> _EmitterValue:
         """Wraps common logic for emitting all control flow operations.
 
-        See the more specific emission functions for more details on how cond or map get emitted.
+        See the more specific emission functions for more details on how cond, map, or scan get emitted.
         """
+        specs = self.node.meta["spec"]
+
+        # For scan/map, set the shape_dynamism of the stacked outputs (y_outputs)
+        # to DYNAMIC_BOUND *before* emitting the specs: et_copy_index has cat
+        # shape semantics but stack memory behavior, so it must grow the output's
+        # leading dimension by one on each iteration, which is not allowed for
+        # tensors marked static.
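+        # For example, a scan over xs of shape [5, 3] pre-allocates a [5, 3]
+        # stacked y buffer; at runtime et_copy_index grows its leading dim
+        # through [1, 3], [2, 3], ..., [5, 3] as each iteration's slice is
+        # written, which requires DYNAMIC_BOUND even when the scan length is
+        # static.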
+        if target is torch.ops.higher_order.scan:
+            combine_fn, init, xs, additional_inputs = args
+            num_carry = len(init)
+            if isinstance(specs, (list, tuple)):
+                y_specs = specs[num_carry:]
+                for y_spec in y_specs:
+                    if isinstance(y_spec, TensorSpec):
+                        y_spec.shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND
+        elif target is torch.ops.higher_order.map_impl:
+            assert len(specs) == 1
+            assert isinstance(specs[0], TensorSpec)
+            specs[0].shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND
+
         subemitter_binding_output_values = pytree.tree_map(
             lambda spec: self._emit_spec(spec),
-            self.node.meta["spec"],
+            specs,
         )
 
         if target is torch.ops.higher_order.cond:
             return self._emit_cond(args, subemitter_binding_output_values)
         elif target is torch.ops.higher_order.map_impl:
             return self._emit_map(args, subemitter_binding_output_values)
+        elif target is torch.ops.higher_order.scan:
+            return self._emit_scan(args, subemitter_binding_output_values)
         else:
             raise InternalError(
                 self._emit_node_specific_error(
@@ -1190,7 +1447,7 @@ def _emit_delegate(
 
         return delegate_ret
 
-    def _get_operator(self, name: str, overload: str) -> Tuple[int, Operator]:
+    def _get_operator(self, name: str, overload: str = "") -> Tuple[int, Operator]:
         """Given a fully qualified name, lookups the operator in the ExecuTorch
         Program, or adds it if it is not already present"""
         key = (name, overload)
@@ -1511,6 +1768,7 @@ def call_function(  # pyre-fixme[14]
             torch.ops.higher_order.cond,
             torch.ops.higher_order.map_impl,
             torch.ops.higher_order.while_loop,
+            torch.ops.higher_order.scan,
         ):
             return self._emit_control_flow(target, args, kwargs)
 
diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py
index 165bc2951f7..3ed432c1872 100644
--- a/exir/emit/test/test_emit.py
+++ b/exir/emit/test/test_emit.py
@@ -62,10 +62,10 @@
     _load_for_executorch_from_buffer,
 )
 from executorch.runtime import Runtime
-
-from functorch.experimental import control_flow
 from torch import nn
 
+from torch._higher_order_ops import cond as torch_cond, map as torch_map
+
 from torch.export import Dim, export
 from torch.export.experimental import _export_forward_backward
 
@@ -674,7 +674,7 @@ def true_fn(y: torch.Tensor) -> torch.Tensor:
                 def false_fn(y: torch.Tensor) -> torch.Tensor:
                     return torch.mm(y, y)
 
-                ret = control_flow.cond(pred, true_fn, false_fn, [x])
+                ret = torch_cond(pred, true_fn, false_fn, [x])
                 return ret
 
         module = to_edge(
@@ -708,7 +708,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                 return x + y
 
-            return control_flow.map(map_fn, x, y)
+            return torch_map(map_fn, x, y)
 
         f = Foo()
 
@@ -781,7 +781,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                 return x + y
 
-            return control_flow.map(map_fn, x, y)
+            return torch_map(map_fn, x, y)
 
         f = Foo()
 
@@ -798,7 +798,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                 return x + y
 
-            return control_flow.map(map_fn, x, y)
+            return torch_map(map_fn, x, y)
 
         f = Foo()
 
@@ -812,6 +812,323 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         outputs = loaded_model(inputs)[0]
         torch.allclose(outputs, f(*inputs))
 
+    def test_emit_scan_basic(self) -> None:
+        """Test basic scan emission: verifies instruction structure for cumulative sum."""
+        from torch._higher_order_ops.scan import scan
+
+        class ScanCumSum(torch.nn.Module):
+            def forward(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+                def combine_fn(carry, x):
+                    new_carry = carry + x
+                    return new_carry, new_carry.clone()
+
+                init = torch.zeros_like(xs[0])
+                return scan(combine_fn, init, xs)
+
+        f = ScanCumSum()
+        # Use contiguous tensor to avoid stride=0 issue
+        inputs = (torch.arange(15).float().reshape(5, 3),)
+
+        module = to_edge(
+            export(f, inputs, strict=True),
+            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
+        )
+        program = module.to_executorch().executorch_program
+
+        op_table = program.execution_plan[0].operators
+        instructions = program.execution_plan[0].chains[0].instructions
+
+        # Collect all operator names in the program
+        op_names = [op.name for op in op_table]
+
+        # Verify the key operators are present for scan:
+        # 1. sym_size - to get scan length
+        self.assertIn(
+            "aten::sym_size", op_names, "Should have sym_size for scan length"
+        )
+
+        # 2. copy_ - for carry initialization and carry updates
+        self.assertIn("aten::copy_", op_names, "Should have copy_ for carry handling")
+
+        # 3. select_copy - to slice xs
+        self.assertIn(
+            "aten::select_copy", op_names, "Should have select_copy for xs slicing"
+        )
+
+        # 4. et_copy_index - to accumulate y outputs
+        self.assertIn(
+            "executorch_prim::et_copy_index",
+            op_names,
+            "Should have et_copy_index for y accumulation",
+        )
+
+        # 5. Loop control: add, eq for iteration control
+        self.assertIn(
+            "executorch_prim::add", op_names, "Should have add for iter increment"
+        )
+        self.assertIn(
+            "executorch_prim::eq", op_names, "Should have eq for completion check"
+        )
+
+        # 6. sub - to reset iter_idx for re-runs
+        self.assertIn(
+            "executorch_prim::sub", op_names, "Should have sub to reset iterator"
+        )
+
+        # 7. Should have JumpFalseCall for loop back
+        jump_false_found = any(
+            isinstance(instr.instr_args, JumpFalseCall) for instr in instructions
+        )
+        self.assertTrue(jump_false_found, "Should have JumpFalseCall for loop control")
+
+        # 8. Verify we have the body operations (add from combine_fn)
+        self.assertIn("aten::add", op_names, "Should have add from combine_fn body")
+
+    def test_run_emit_scan_cumsum(self) -> None:
+        """Test scan execution correctness: cumulative sum."""
+        from torch._higher_order_ops.scan import scan
+
+        class ScanCumSum(torch.nn.Module):
+            def forward(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+                def combine_fn(carry, x):
+                    new_carry = carry + x
+                    return new_carry, new_carry.clone()
+
+                init = torch.zeros_like(xs[0])
+                return scan(combine_fn, init, xs)
+
+        f = ScanCumSum()
+        # Use contiguous tensor to avoid stride=0 issue
+        inputs = (torch.arange(15).float().reshape(5, 3),)
+
+        module = to_edge(
+            export(f, inputs, strict=True),
+            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
+        )
+        et = module.to_executorch()
+        buffer = et.buffer
+        loaded_model = _load_for_executorch_from_buffer(buffer)
+
+        # Run through executorch
+        et_outputs = loaded_model(inputs)
+
+        # Run through eager PyTorch
+        eager_outputs = f(*inputs)
+
+        # Compare final carry
+        self.assertTrue(
+            torch.allclose(et_outputs[0], eager_outputs[0]),
+            f"Final carry mismatch: {et_outputs[0]} vs {eager_outputs[0]}",
+        )
+
+        # Compare stacked outputs
+        self.assertTrue(
+            torch.allclose(et_outputs[1], eager_outputs[1]),
+            f"Stacked outputs mismatch: {et_outputs[1]} vs {eager_outputs[1]}",
+        )
+
+    def test_emit_scan_add_mul(self) -> None:
+        """Test scan with mul and add operations in combine_fn."""
+        from torch._higher_order_ops.scan import scan
+
+        class ScanAddMul(torch.nn.Module):
+            def forward(
+                self, xs: torch.Tensor, y: torch.Tensor
+            ) -> Tuple[torch.Tensor, torch.Tensor]:
+                def combine_fn(carry, x):
+                    # Multiply carry by 2 and add x
+                    new_carry = carry * 2 + x
+                    return new_carry, new_carry.clone()
+
+                init = torch.zeros_like(xs[0])
+                return scan(combine_fn, init, xs)
+
+        f = ScanAddMul()
+        inputs = (torch.ones(4, 3), torch.ones(3))
+
+        module = to_edge(
+            export(f, inputs, strict=True),
+            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
+        )
+        program = module.to_executorch().executorch_program
+
+        # Verify we have the expected operators
+        op_names = [op.name for op in program.execution_plan[0].operators]
+
+        # Should have mul and add operations from the combine_fn body
+        self.assertIn("aten::mul", op_names)
+        self.assertIn("aten::add", op_names)
+
+        # Should have scan control flow operators
+        self.assertIn("aten::sym_size", op_names)
+        self.assertIn("aten::select_copy", op_names)
+        self.assertIn("executorch_prim::et_copy_index", op_names)
+
+    def test_emit_scan_gru(self) -> None:
+        """Test scan with a simple GRU-like computation."""
+        from torch._higher_order_ops.scan import scan
+
+        class SimpleGRU(torch.nn.Module):
+            """Simple single-layer unidirectional GRU using scan."""
+
+            def __init__(self, input_size: int, hidden_size: int):
+                super().__init__()
+                self.input_size = input_size
+                self.hidden_size = hidden_size
+
+                # GRU gates: reset, update, new
+                self.weight_ih = torch.nn.Parameter(
+                    torch.randn(3 * hidden_size, input_size), requires_grad=False
+                )
+                self.weight_hh = torch.nn.Parameter(
+                    torch.randn(3 * hidden_size, hidden_size), requires_grad=False
+                )
+                self.bias_ih = torch.nn.Parameter(
+                    torch.randn(3 * hidden_size), requires_grad=False
+                )
+                self.bias_hh = torch.nn.Parameter(
+                    torch.randn(3 * hidden_size), requires_grad=False
+                )
+
+            def forward(
+                self, x: torch.Tensor, h0: torch.Tensor
+            ) -> Tuple[torch.Tensor, torch.Tensor]:
+                """
+                Args:
+                    x: Input tensor of shape [seq_len, batch, input_size]
+                    h0: Initial hidden state of shape [batch, hidden_size]
+                Returns:
+                    output: Output tensor of shape [seq_len, batch, hidden_size]
+                    h_n: Final hidden state of shape [batch, hidden_size]
+                """
+                weight_ih = self.weight_ih
+                weight_hh = self.weight_hh
+                bias_ih = self.bias_ih
+                bias_hh = self.bias_hh
+
+                def gru_cell(
+                    h: torch.Tensor, x_t: torch.Tensor
+                ) -> Tuple[torch.Tensor, torch.Tensor]:
+                    # Compute gates
+                    gates_ih = torch.nn.functional.linear(x_t, weight_ih, bias_ih)
+                    gates_hh = torch.nn.functional.linear(h, weight_hh, bias_hh)
+
+                    # Split into reset, update, new gates
+                    r_ih, z_ih, n_ih = gates_ih.chunk(3, dim=-1)
+                    r_hh, z_hh, n_hh = gates_hh.chunk(3, dim=-1)
+
+                    r = torch.sigmoid(r_ih + r_hh)
+                    z = torch.sigmoid(z_ih + z_hh)
+                    n = torch.tanh(n_ih + r * n_hh)
+
+                    h_new = (1 - z) * n + z * h
+                    return h_new, h_new.clone()
+
+                final_h, outputs = scan(gru_cell, h0, x)
+                return outputs, final_h
+
+        # Create model and inputs
+        input_size = 4
+        hidden_size = 8
+        seq_len = 5
+        batch_size = 2
+
+        model = SimpleGRU(input_size, hidden_size)
+        x = torch.randn(seq_len, batch_size, input_size)
+        h0 = torch.randn(batch_size, hidden_size)
+        inputs = (x, h0)
+
+        # Run through eager PyTorch
+        eager_outputs = model(*inputs)
+
+        # Export and convert to edge
+        module = to_edge(
+            export(model, inputs, strict=True),
+            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
+        )
+        et = module.to_executorch()
+        program = et.executorch_program
+
+        # Verify the program has expected operators
+        op_names = [op.name for op in program.execution_plan[0].operators]
+
+        # Should have scan control flow operators
+        self.assertIn("aten::sym_size", op_names)
+        self.assertIn("aten::select_copy", op_names)
+        self.assertIn("executorch_prim::et_copy_index", op_names)
+
+        # Verify we can load the program
+        buffer = et.buffer
+        loaded_model = _load_for_executorch_from_buffer(buffer)
+
+        # Run through executorch
+        et_outputs = loaded_model(inputs)
+
+        # Compare outputs (with tolerance for floating point)
+        self.assertTrue(
+            torch.allclose(et_outputs[0], eager_outputs[0], atol=1e-5),
+            f"Output mismatch: {et_outputs[0]} vs {eager_outputs[0]}",
+        )
+        self.assertTrue(
+            torch.allclose(et_outputs[1], eager_outputs[1], atol=1e-5),
+            f"Final hidden state mismatch: {et_outputs[1]} vs {eager_outputs[1]}",
+        )
+
+    def test_run_emit_scan_dynamic_shape(self) -> None:
+        """Test scan execution with dynamic sequence length."""
+        from torch._higher_order_ops.scan import scan
+
+        class ScanCumSum(torch.nn.Module):
+            def forward(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+                def combine_fn(carry, x):
+                    new_carry = carry + x
+                    return new_carry, new_carry.clone()
+
+                init = torch.zeros_like(xs[0])
+                return scan(combine_fn, init, xs)
+
+        f = ScanCumSum()
+
+        upper_bound_inputs = (torch.arange(24).float().reshape(8, 3),)
+
+        seq_dim = Dim("seq", max=8)
+        dynamic_shapes = {"xs": {0: seq_dim}}
+
+        module = to_edge(
+            export(f, upper_bound_inputs, dynamic_shapes=dynamic_shapes, strict=True),
+            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
+        )
+        et = module.to_executorch(
+            config=exir.ExecutorchBackendConfig(
+                sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
+            )
+        )
+        buffer = et.buffer
+        loaded_model = _load_for_executorch_from_buffer(buffer)
+
+        test_cases = [
+            (torch.arange(15).float().reshape(5, 3),),  # seq_len=5
+            (torch.arange(9).float().reshape(3, 3),),  # seq_len=3
+            (torch.arange(21).float().reshape(7, 3),),  # seq_len=7
+        ]
+
+        for test_inputs in test_cases:
+            et_outputs = loaded_model(test_inputs)
+            eager_outputs = f(*test_inputs)
+
+            self.assertTrue(
+                torch.allclose(et_outputs[0], eager_outputs[0]),
+                f"Final carry mismatch for shape {test_inputs[0].shape}: {et_outputs[0]} vs {eager_outputs[0]}",
+            )
+            self.assertTrue(
+                torch.allclose(et_outputs[1], eager_outputs[1]),
+                f"Stacked outputs mismatch for shape {test_inputs[0].shape}: {et_outputs[1]} vs {eager_outputs[1]}",
+            )
+
     def test_dim_order(self) -> None:
         class SimpleLinear(torch.nn.Module):
             def __init__(self) -> None:
@@ -1600,7 +1917,6 @@ def forward(self, x):
 
         model = to_edge(export(MutableStateModule(), (torch.zeros(1),), strict=True))
         model = model.to_executorch()
-        model.dump_executorch_program(True)
         self.assertTrue(
             model.executorch_program.execution_plan[0].values[0].val.allocation_info
             is not None
diff --git a/exir/graph_module.py b/exir/graph_module.py
index 2adf62ab0b8..cc85a596bfb 100644
--- a/exir/graph_module.py
+++ b/exir/graph_module.py
@@ -78,14 +78,18 @@ def get_control_flow_submodules(
 ) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
     """
     Returns a list of submodules used for control flow operations
-    (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look
+    (torch.ops.higher_order.cond/map/scan) that are in the given toplevel graph (does not look
     into submodules). Specifically, the returned value is a list containing
     tuples of (name of the submodule that's stored in the graph module, the
     submodule itself, and the fx node that uses this submodule).
     """
     return _get_control_flow_submodules(
         graph_module,
-        {torch.ops.higher_order.cond: [1, 2], torch.ops.higher_order.map_impl: [0]},
+        {
+            torch.ops.higher_order.cond: [1, 2],
+            torch.ops.higher_order.map_impl: [0],
+            torch.ops.higher_order.scan: [0],  # combine_fn is at arg index 0
+        },
     )
 
 
@@ -108,6 +112,27 @@ def get_cond_while_submodules(
     )
 
 
+def get_scan_submodules(
+    graph_module: torch.fx.GraphModule,
+) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
+    """
+    Returns a list of submodules used for scan operations
+    (torch.ops.higher_order.scan) that are in the given toplevel graph (does not look
+    into submodules). Specifically, the returned value is a list containing
+    tuples of (name of the submodule that's stored in the graph module, the
+    submodule itself, and the fx node that uses this submodule).
+
+    For scan, the combine_fn submodule is at argument index 0.
+    The scan operator signature is: scan(combine_fn, init, xs, additional_inputs)
+    """
+    return _get_control_flow_submodules(
+        graph_module,
+        {
+            torch.ops.higher_order.scan: [0],
+        },
+    )
+
+
 def bfs_trace_with_node_process(
     gm: torch.fx.GraphModule, node_op: Callable[[torch.fx.Node], None]
 ) -> None:
diff --git a/exir/memory_planning.py b/exir/memory_planning.py
index 0394ed9c529..ad6b01e4c9d 100644
--- a/exir/memory_planning.py
+++ b/exir/memory_planning.py
@@ -1023,6 +1023,12 @@ def get_map_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
             yield nd
 
 
+def get_scan_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
+    for nd in graph_module.graph.nodes:
+        if nd.target is torch.ops.higher_order.scan:
+            yield nd
+
+
 def get_return_specs(graph_module: fx.GraphModule) -> Set[TensorSpec]:
     return_specs = set()
     nodes = graph_module.graph.nodes
@@ -1125,7 +1131,7 @@ def _apply_algo_to_submodules(
     alignment: int,
     graph_signature: Optional[ExportGraphSignature] = None,
 ) -> list[int]:
-    """Apply algo to map/cond/while nodes in the graph module.
+    """Apply algo to map/cond/while/scan nodes in the graph module.
 
     This method will popuate graph_module.meta["non_const_buffer_sizes"] for
     all submodules and return a bufsizes list that is the maximum size of all
@@ -1161,6 +1167,9 @@ def _handle(
     for map_node in get_map_nodes(graph_module):
         _handle(cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True)
 
+    for scan_node in get_scan_nodes(graph_module):
+        _handle(cast(torch.fx.Node, scan_node.args[0]), alloc_graph_input=True)
+
     # TODO: We can handle delegates the same way as map/cond/while.
     # Maybe populate the graph_module.meta["non_const_buffer_sizes"] for delegates.
 
diff --git a/exir/pass_base.py b/exir/pass_base.py
index 497970fae34..27dd9bf6b73 100644
--- a/exir/pass_base.py
+++ b/exir/pass_base.py
@@ -348,6 +348,11 @@ def call_function(
         elif target == torch.ops.higher_order.map_impl:
             f, mapped_args, operands = args  # type: ignore[assignment]
             return self.callback.call_map(f, mapped_args, operands, meta)
+        elif target == torch.ops.higher_order.scan:
+            combine_fn, init, xs, additional_inputs = args  # type: ignore[assignment]
+            return self.callback.call_scan(
+                combine_fn, init, xs, additional_inputs, meta
+            )
         # For other unregistered HigherOrderOps, just interpret them blindly
         elif isinstance(target, torch._ops.HigherOrderOperator):
             return self.callback._fx(
@@ -545,6 +550,40 @@ def call_map(
             meta,
         )
 
+    def call_scan(
+        self,
+        combine_fn: torch.fx.GraphModule,
+        init: List[ProxyValue],
+        xs: List[Argument],
+        additional_inputs: List[ProxyValue],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        # Get the expected x element shapes from the combine_fn's placeholders.
+        # The combine_fn expects: (carry..., x_element..., additional_inputs...)
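+        # For example, with init=[c] and xs=[x0, x1] plus one additional input
+        # w, the placeholder order is (c, x0_slice, x1_slice, w), so the
+        # per-iteration x placeholders sit at indices
+        # [len(init) : len(init) + len(xs)].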
+        combine_fn_placeholders = [
+            n for n in combine_fn.graph.nodes if n.op == "placeholder"
+        ]
+        num_init = len(init)
+        # The x_element placeholders are at indices [num_init : num_init + len(xs)]
+        xs_element_data = []
+        for i in range(len(xs)):
+            ph = combine_fn_placeholders[num_init + i]
+            # Use the placeholder's val which has the correct shape
+            xs_element_data.append(ph.meta["val"])
+
+        combine_fn_result = self.call_submodule(
+            combine_fn, (*init, *xs_element_data, *additional_inputs)
+        )
+        assert combine_fn_result is not None
+
+        return self._fx(
+            "call_function",
+            torch.ops.higher_order.scan,
+            (combine_fn_result.graph_module, init, xs, additional_inputs),
+            {},
+            meta,
+        )
+
     def call_getitem(
         self, value: ProxyValue, key: int, meta: NodeMetadata
     ) -> ProxyValue:
diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py
index 5c6eb63db46..cdc1c22a9d5 100644
--- a/exir/passes/__init__.py
+++ b/exir/passes/__init__.py
@@ -344,6 +344,11 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule:
                 self.call(get_submodule(node.args[0]))
                 self.call(get_submodule(node.args[1]))
                 continue
+            elif target == torch.ops.higher_order.scan:
+                # scan(combine_fn, init, xs, additional_inputs)
+                # combine_fn is at args[0]
+                self.call(get_submodule(node.args[0]))
+                continue
             elif getattr(target, "__module__", None) in ("builtins", "_operator"):
                 continue
             elif target in to_out_var_skiplist:
diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py
index 637cc0013f0..9adbf65dd90 100644
--- a/exir/passes/spec_prop_pass.py
+++ b/exir/passes/spec_prop_pass.py
@@ -70,7 +70,6 @@ def get_spec(x):
         if isinstance(module, torch.fx.GraphModule):
             for node in module.graph.nodes:
                 meta_val = node.meta.get("val", None)
-
                 if node.op == "output":
                     node.meta["spec"] = pytree.tree_map(get_spec, node.args[0])
                 elif node.op == "call_function" and node.target == operator.getitem:
diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py
index d398b81ee8f..ff116cc3f88 100644
--- a/exir/tests/test_passes.py
+++ b/exir/tests/test_passes.py
@@ -745,6 +745,109 @@ def loop_body(i, acc):
         upper_bound = eval_upper_bound(spec[1].shape[0])
         self.assertEqual(upper_bound, 10)
 
+    def test_spec_prop_pass_scan(self) -> None:
+        from torch._higher_order_ops.scan import scan
+
+        class ModelWithScan(torch.nn.Module):
+            def forward(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+                def combine_fn(carry, x):
+                    new_carry = carry + x
+                    return new_carry, new_carry.clone()
+
+                init = torch.zeros_like(xs[0])
+                return scan(combine_fn, init, xs)
+
+        model = ModelWithScan()
+        inputs = (torch.arange(15).float().reshape(5, 3),)
+
+        # Run the spec prop pass and sanity check the spec on the scan.
+        edge_program = to_edge(
+            export(model, inputs, strict=True),
+            compile_config=EdgeCompileConfig(_check_ir_validity=False),
+        )
+        gm = edge_program.exported_program().graph_module
+        new_gm = SpecPropPass()(gm)
+        self.assertIsNotNone(new_gm)
+
+        # Check the spec for the scan node.
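+        # The scan target is a HigherOrderOperator; match it by its reported
+        # name, guarding with hasattr since not every node target defines
+        # name().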
+        scan_node = next(
+            n
+            for n in new_gm.graph_module.graph.nodes
+            if hasattr(n.target, "name") and n.target.name() == "scan"
+        )
+        self.assertIsNotNone(scan_node)
+
+        # Spec for the scan node should be a two-element tuple (carry, stacked_outputs)
+        spec: Tuple[TensorSpec, TensorSpec] = scan_node.meta["spec"]
+        self.assertIsInstance(spec, tuple)
+        self.assertEqual(len(spec), 2)
+
+        # Carry should have shape [3] (same as xs[0])
+        self.assertEqual(list(spec[0].shape), [3])
+        # Stacked outputs should have shape [5, 3] (same as xs)
+        self.assertEqual(list(spec[1].shape), [5, 3])
+
+    def test_spec_prop_pass_scan_dynamic_shape(self) -> None:
+        from torch._higher_order_ops.scan import scan
+
+        class ModelWithScan(torch.nn.Module):
+            def forward(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+                def combine_fn(carry, x):
+                    new_carry = carry + x
+                    return new_carry, new_carry.clone()
+
+                init = torch.zeros_like(xs[0])
+                return scan(combine_fn, init, xs)
+
+        model = ModelWithScan()
+        inputs = (torch.arange(15).float().reshape(5, 3),)
+        dynamic_shapes = {"xs": {0: torch.export.Dim("seq_len", min=1, max=20)}}
+
+        # First verify that export preserves symbolic shapes
+        exported = export(model, inputs, dynamic_shapes=dynamic_shapes, strict=True)
+        scan_node_after_export = next(
+            n
+            for n in exported.graph.nodes
+            if hasattr(n.target, "name") and n.target.name() == "scan"
+        )
+        val_after_export = scan_node_after_export.meta.get("val")
+        self.assertIsNotNone(val_after_export)
+        # After export, the stacked output should have a symbolic first dimension
+        self.assertIsInstance(val_after_export[1].shape[0], torch.SymInt)
+
+        # Run the spec prop pass and sanity check the spec on the scan.
+        edge_program = to_edge(
+            exported,
+            compile_config=EdgeCompileConfig(_check_ir_validity=False),
+        )
+        gm = edge_program.exported_program().graph_module
+        new_gm = SpecPropPass()(gm)
+        self.assertIsNotNone(new_gm)
+
+        # Check the spec for the scan node.
+        scan_node = next(
+            n
+            for n in new_gm.graph_module.graph.nodes
+            if hasattr(n.target, "name") and n.target.name() == "scan"
+        )
+        self.assertIsNotNone(scan_node)
+
+        # Spec for the scan node should be a two-element tuple (carry, stacked_outputs)
+        spec: Tuple[TensorSpec, TensorSpec] = scan_node.meta["spec"]
+        self.assertIsInstance(spec, tuple)
+        self.assertEqual(len(spec), 2)
+
+        # Carry should have static shape [3]
+        self.assertEqual(list(spec[0].shape), [3])
+        self.assertEqual(spec[0].shape_dynamism, TensorShapeDynamism.STATIC)
+
+        # Stacked outputs should have dynamic first dimension with upper bound 20
+        self.assertEqual(len(spec[1].shape), 2)
+        upper_bound = eval_upper_bound(spec[1].shape[0])
+        self.assertEqual(upper_bound, 20)
+        self.assertEqual(spec[1].shape_dynamism, TensorShapeDynamism.DYNAMIC_BOUND)
+        self.assertEqual(spec[1].shape[1], 3)  # Second dim is static
+
     def test_compile_fix_broken_ops(self) -> None:
         class ExportableLoop(nn.Module):
             def __init__(self, hidden_size, out_channels):
diff --git a/kernels/prim_ops/et_copy_index.cpp b/kernels/prim_ops/et_copy_index.cpp
index 2ef076ad1a0..2d46069fcde 100644
--- a/kernels/prim_ops/et_copy_index.cpp
+++ b/kernels/prim_ops/et_copy_index.cpp
@@ -114,27 +114,16 @@ void et_copy_index(KernelRuntimeContext& context, Span<EValue*> stack) {
     expected_output_size[i + 1] = copy_from.sizes()[i];
   }
 
-  if (copy_to.sizes()[0] < expected_output_size[0]) {
-    // Resize `copy_to` to the expected output size.
-    const void* data_ptr = copy_to.const_data_ptr();
-    Error err =
-        resize_tensor(copy_to, {expected_output_size, copy_to.sizes().size()});
-    ET_CHECK(err == Error::Ok);
-    ET_KERNEL_CHECK_MSG(
-        context,
-        data_ptr == copy_to.const_data_ptr(),
-        InvalidState,
-        /* void */,
-        "Data ptr of copy_to tensor changed after resize which isn't allowed for static/upper-bounded tensors");
-  }
-
-  // After potential resize, verify that index is within bounds.
+  // Resize `copy_to` to the expected output size. This grows the tensor
+  // as we write to each index (0→1, 1→2, 2→3, etc.).
+  Error err =
+      resize_tensor(copy_to, {expected_output_size, copy_to.sizes().size()});
   ET_KERNEL_CHECK_MSG(
       context,
-      index < copy_to.sizes()[0],
-      InvalidArgument,
+      err == Error::Ok,
+      InvalidState,
       /* void */,
-      "Index out of bounds");
+      "Failed to resize copy_to tensor");
 
   auto copy_to_ptr = copy_to.const_data_ptr();
   auto copy_from_ptr = copy_from.const_data_ptr();
diff --git a/kernels/prim_ops/test/prim_ops_test.cpp b/kernels/prim_ops/test/prim_ops_test.cpp
index b46733045fb..ff0d633fd94 100644
--- a/kernels/prim_ops/test/prim_ops_test.cpp
+++ b/kernels/prim_ops/test/prim_ops_test.cpp
@@ -281,42 +281,6 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndexMismatchShape) {
       getOpsFn("executorch_prim::et_copy_index.tensor")(context_, stack));
 }
 
-TEST_F(RegisterPrimOpsTest, TestETCopyIndexStaticShape) {
-  int64_t index = 1;
-  testing::TensorFactory<ScalarType::Int> tf;
-
-  EValue values[3];
-  EValue* stack[3];
-
-  // Test with static shape tensors.
-  const std::vector<int32_t> buf = {1, 2, 3, 4};
-  auto copy_to = tf.make({2, 2}, buf);
-  auto to_copy = tf.make({2}, {5, 6});
-
-  values[0] = EValue(copy_to);
-  values[1] = EValue(to_copy);
-  values[2] = EValue(index);
-
-  stack[0] = &values[0];
-  stack[1] = &values[1];
-  stack[2] = &values[2];
-
-  // Copy and replace at index 1.
-  getOpsFn("executorch_prim::et_copy_index.tensor")(context_, stack);
-  EXPECT_EQ(copy_to.sizes()[0], 2);
-  EXPECT_EQ(copy_to.sizes()[1], 2);
-  EXPECT_TENSOR_EQ(copy_to, tf.make({2, 2}, {1, 2, 5, 6}));
-
-#ifndef USE_ATEN_LIB
-  // Copy and replace at index 2. This should trigger an EXPECT
-  // in lean mode.
-  index = 2;
-  values[2] = EValue(index);
-  ET_EXPECT_DEATH(
-      getOpsFn("executorch_prim::et_copy_index.tensor")(context_, stack), "");
-#endif
-}
-
 TEST_F(RegisterPrimOpsTest, TestBooleanOps) {
   EValue values[3];
   double a = 3;