@@ -993,19 +993,14 @@ def _emit_scan(
993993 num_carry = len (init )
994994 num_xs = len (xs )
995995
996- # Split output values into carry outputs and y outputs
997996 carry_outputs = list (subemitter_binding_output_values [:num_carry ])
998997 y_outputs = list (subemitter_binding_output_values [num_carry :])
999998
1000999 if num_xs < 1 :
10011000 raise RuntimeError ("Scan requires at least one xs tensor to scan over." )
10021001
1003- # === INITIALIZATION ===
1004-
1005- # Generate iterator index EValue
10061002 iter_idx = self ._emit_evalue (EValue (Int (0 )))
10071003
1008- # Get scan length from first xs tensor
10091004 op_index , op = self ._get_operator (
10101005 name = "aten::sym_size" ,
10111006 overload = "int" ,
@@ -1019,9 +1014,7 @@ def _emit_scan(
10191014 )
10201015 self .chain .instructions .append (kernel )
10211016
1022- # Initialize carry_outputs from init by copying init -> carry_outputs
1023- # This is necessary because we shouldn't mutate the original init tensors
1024- # Use aten::copy_.default which copies src to self in-place
1017+ # Initialize carry_outputs from init
10251018 op_index_copy , _ = self ._get_operator (name = "aten::copy_" )
10261019 for init_val , carry_out in zip (init , carry_outputs ):
10271020 kernel = Instruction (
@@ -1037,11 +1030,7 @@ def _emit_scan(
10371030 )
10381031 self .chain .instructions .append (kernel )
10391032
1040- # === LOOP START ===
1041-
10421033 # Slice each xs tensor for the current iteration
1043- # We use -1 as placeholder for the output tensor id, which will be filled
1044- # after the scan_emitter runs and allocates the input placeholder EValues
10451034 op_index_select , _ = self ._get_operator (
10461035 name = "aten::select_copy" ,
10471036 overload = "int_out" ,
@@ -1053,69 +1042,46 @@ def _emit_scan(
10531042 op_index = op_index_select ,
10541043 args = [
10551044 x .id ,
1056- self ._emit_evalue (EValue (Int (0 ))).id , # dim=0
1045+ self ._emit_evalue (EValue (Int (0 ))).id ,
10571046 iter_idx .id ,
1058- - 1 , # placeholder for output tensor id
1059- - 1 , # placeholder (repeated for out variant)
1047+ - 1 ,
1048+ - 1 ,
10601049 ],
10611050 )
10621051 )
10631052 xs_slice_instructions .append (kernel )
10641053
1065- # Store jump target - this is where we jump back to after each iteration
10661054 jump_to_instruction = self .instruction_start_offset + len (
10671055 self .chain .instructions
10681056 )
10691057
1070- # Add all xs slice instructions
10711058 for kernel in xs_slice_instructions :
10721059 self .chain .instructions .append (kernel )
10731060
1074- # === EMIT COMBINE_FN SUBMODULE ===
1075-
1076- # combine_fn inputs: (*carry, *xs_slice, *additional_inputs)
1077- # We bind carry inputs to carry_outputs (the working carry buffers)
1078- # xs_slice inputs will be filled in after emitter runs (using -1 placeholder)
1079- # additional_inputs are passed through directly
1061+ # Emit combine_fn submodule
10801062 binding_input_values : List [Any ] = []
1081- binding_input_values .extend (
1082- carry_outputs
1083- ) # Carry inputs bound to carry_outputs
1084- binding_input_values .extend ([- 1 ] * num_xs ) # Placeholders for xs slices
1085- binding_input_values .extend (additional_inputs ) # Additional inputs
1086-
1087- # combine_fn outputs: (*next_carry, *y_slice)
1088- # Pass binding_output_values=None so the combine_fn writes directly to its
1089- # own output buffers (concrete_output_ids). We then copy from these directly
1090- # to the final carry/y buffers, avoiding unnecessary temp buffers and MOVEs.
1063+ binding_input_values .extend (carry_outputs )
1064+ binding_input_values .extend ([- 1 ] * num_xs )
1065+ binding_input_values .extend (additional_inputs )
1066+
10911067 scan_emitter = _Emitter (
10921068 combine_fn ,
10931069 self .emitter_state ,
10941070 self .program_state ,
10951071 instruction_start_offset = self .instruction_start_offset
10961072 + len (self .chain .instructions ),
10971073 binding_input_values = binding_input_values ,
1098- binding_output_values = None , # Use concrete outputs directly
1074+ binding_output_values = None ,
10991075 )
11001076 scan_emitter .run ()
11011077
1102- # Merge combine_fn instructions
11031078 self ._merge_chain (scan_emitter .chain )
1104- # NOTE: When binding_output_values=None, no return/move instruction is added
1105- # by the output() method, so we don't need to pop anything.
11061079
1107- # Update xs_slice instructions with the actual placeholder EValue ids
1108- # The xs placeholders start after the carry inputs in combine_fn
11091080 for i , kernel in enumerate (xs_slice_instructions ):
11101081 xs_placeholder_id = scan_emitter .binding_input_values [num_carry + i ].id
11111082 kernel .instr_args .args [- 1 ] = xs_placeholder_id
11121083 kernel .instr_args .args [- 2 ] = xs_placeholder_id
11131084
1114- # === COPY OUTPUTS ===
1115-
1116- # Get combine_fn's actual output EValues
1117- # concrete_output_ids contains the actual EValues that the combine_fn
1118- # graph operations write to: (*carry_temp, *y_temp)
11191085 concrete_outputs = scan_emitter .concrete_output_ids
11201086 carry_temp = concrete_outputs [:num_carry ]
11211087 y_temp = concrete_outputs [num_carry :]
@@ -1132,8 +1098,6 @@ def _emit_scan(
11321098 )
11331099
11341100 # Copy carry_temp -> carry_outputs for next iteration
1135- # This explicit copy is required because in-place op.out(x, out=x) is unsafe
1136- # aten::copy_ signature: (self, src, non_blocking, out) -> self
11371101 for carry_t , carry_out in zip (carry_temp , carry_outputs ):
11381102 kernel = Instruction (
11391103 KernelCall (
@@ -1148,7 +1112,7 @@ def _emit_scan(
11481112 )
11491113 self .chain .instructions .append (kernel )
11501114
1151- # Copy y_temp to stacked y_outputs using et_copy_index
1115+ # Copy y_temp to stacked y_outputs
11521116 op_index_copy_index , _ = self ._get_operator (
11531117 name = "executorch_prim::et_copy_index" ,
11541118 overload = "tensor" ,
@@ -1162,8 +1126,6 @@ def _emit_scan(
11621126 )
11631127 self .chain .instructions .append (kernel )
11641128
1165- # === LOOP CONTROL ===
1166-
11671129 # Increment iter_idx
11681130 op_index_add , _ = self ._get_operator (
11691131 name = "executorch_prim::add" ,
@@ -1191,7 +1153,6 @@ def _emit_scan(
11911153 )
11921154 self .chain .instructions .append (kernel )
11931155
1194- # Jump back to loop start if not done
11951156 jf_beginning_loop = Instruction (
11961157 JumpFalseCall (
11971158 cond_value_index = jump_bool_value .id ,
@@ -1200,9 +1161,7 @@ def _emit_scan(
12001161 )
12011162 self .chain .instructions .append (jf_beginning_loop )
12021163
1203- # === CLEANUP ===
1204-
1205- # Reset iter_idx for potential re-runs of the model
1164+ # Reset iter_idx for potential re-runs
12061165 op_index_sub , _ = self ._get_operator (
12071166 name = "executorch_prim::sub" ,
12081167 overload = "Scalar" ,
0 commit comments