@@ -944,12 +944,275 @@ def forward(self, x,y):
944944
945945 return subemitter_binding_output_values
946946
def _emit_scan(
    self,
    args: Tuple[_Argument, ...],
    subemitter_binding_output_values: List[_AbstractValue],
) -> List[_AbstractValue]:
    """Emits torch.scan.

    Converts the higher order scan op into a loop constructed from jump instructions
    and primitive operations. Scan differs from map in that it maintains a carry state
    that evolves across iterations.

    Scan signature: scan(combine_fn, init, xs, additional_inputs)
    - combine_fn: GraphModule that takes (carry, x_slice, *additional_inputs)
      and returns (next_carry, y_slice)
    - init: Initial carry state (list of tensors)
    - xs: Input tensors to scan over (list of tensors, scanned along dim 0)
    - additional_inputs: Additional arguments passed to combine_fn

    Output: (final_carry, stacked_ys)
    - final_carry: The carry state after the last iteration
    - stacked_ys: All y outputs stacked along dim 0

    Memory Layout:
    - carry_outputs (subemitter_binding_output_values[:num_carry]):
      Working carry buffers, initialized from init, updated each iteration
    - y_outputs (subemitter_binding_output_values[num_carry:]):
      Pre-allocated stacked output buffers, filled via et_copy_index

    The combine_fn writes to its own temporary output buffers (concrete_output_ids).
    After each iteration:
    1. Copy combine_fn's carry output -> carry_outputs (for next iteration)
    2. et_copy_index(y_outputs, combine_fn's y output, iter_idx)

    This explicit copy approach is used because in-place op.out(x, out=x) is unsafe.

    Args:
        args: (combine_fn, init, xs, additional_inputs) as described above.
        subemitter_binding_output_values: Pre-allocated output EValues; the
            first len(init) entries are the carry buffers, the rest the
            stacked y buffers.

    Returns:
        subemitter_binding_output_values, unchanged, as the emitted outputs.

    Raises:
        RuntimeError: If xs is empty (there is nothing to scan over, so the
            scan length cannot be derived).
    """
    combine_fn, init, xs, additional_inputs = args

    assert isinstance(
        subemitter_binding_output_values, (list, tuple)
    ), f"Expected list for subemitter_binding_output_values. Got {subemitter_binding_output_values}."

    assert isinstance(combine_fn, torch.fx.GraphModule)
    assert isinstance(init, (list, tuple))
    assert isinstance(xs, (list, tuple))
    assert isinstance(additional_inputs, (list, tuple))

    num_carry = len(init)
    num_xs = len(xs)

    # Split output values into carry outputs and y outputs.
    carry_outputs = list(subemitter_binding_output_values[:num_carry])
    y_outputs = list(subemitter_binding_output_values[num_carry:])

    if num_xs < 1:
        raise RuntimeError("Scan requires at least one xs tensor to scan over.")

    # === INITIALIZATION ===

    # Iterator index EValue. This one is mutated every iteration (by
    # executorch_prim::add below), so it must be a dedicated EValue and
    # must never be shared with other constant-0 uses.
    iter_idx = self._emit_evalue(EValue(Int(0)))

    # A read-only constant 0 shared by every `dim=0` argument below.
    # Sharing is safe because no instruction writes to it.
    dim_zero = self._emit_evalue(EValue(Int(0)))

    # Get scan length from first xs tensor (all xs are scanned along dim 0).
    op_index, _ = self._get_operator(
        name="aten::sym_size",
        overload="int",
    )
    sym_size = self._emit_evalue(EValue(Int(0)))
    kernel = Instruction(
        KernelCall(
            op_index=op_index,
            args=[xs[0].id, dim_zero.id, sym_size.id],
        )
    )
    self.chain.instructions.append(kernel)

    # Initialize carry_outputs from init by copying init -> carry_outputs.
    # This is necessary because we shouldn't mutate the original init tensors.
    op_index_copy, _ = self._get_operator(
        name="aten::copy_",
        overload="default",
    )
    for init_val, carry_out in zip(init, carry_outputs):
        kernel = Instruction(
            KernelCall(
                op_index=op_index_copy,
                args=[carry_out.id, init_val.id],
            )
        )
        self.chain.instructions.append(kernel)

    # === LOOP START ===

    # Slice each xs tensor for the current iteration.
    # We use -1 as placeholder for the output tensor id, which will be filled
    # after the scan_emitter runs and allocates the input placeholder EValues.
    op_index_select, _ = self._get_operator(
        name="aten::select_copy",
        overload="int_out",
    )
    xs_slice_instructions = []
    for x in xs:
        kernel = Instruction(
            KernelCall(
                op_index=op_index_select,
                args=[
                    x.id,
                    dim_zero.id,  # dim=0
                    iter_idx.id,
                    -1,  # placeholder for output tensor id
                    -1,  # placeholder (repeated for out variant)
                ],
            )
        )
        xs_slice_instructions.append(kernel)

    # Store jump target - this is where we jump back to after each iteration.
    jump_to_instruction = self.instruction_start_offset + len(
        self.chain.instructions
    )

    # Add all xs slice instructions.
    self.chain.instructions.extend(xs_slice_instructions)

    # === EMIT COMBINE_FN SUBMODULE ===

    # combine_fn inputs: (*carry, *xs_slice, *additional_inputs)
    # We bind carry inputs to carry_outputs (the working carry buffers).
    # xs_slice inputs will be filled in after emitter runs (using -1 placeholder).
    # additional_inputs are passed through directly.
    binding_input_values: List[Any] = []
    binding_input_values.extend(
        carry_outputs
    )  # Carry inputs bound to carry_outputs
    binding_input_values.extend([-1] * num_xs)  # Placeholders for xs slices
    binding_input_values.extend(additional_inputs)  # Additional inputs

    # combine_fn outputs: (*next_carry, *y_slice)
    # We don't bind outputs to the final destinations directly because we need
    # to copy them explicitly (in-place is unsafe).
    scan_emitter = _Emitter(
        combine_fn,
        self.emitter_state,
        self.program_state,
        instruction_start_offset=self.instruction_start_offset
        + len(self.chain.instructions),
        binding_input_values=binding_input_values,
        binding_output_values=None,  # Let combine_fn use its own output buffers
    )
    scan_emitter.run()

    # Merge combine_fn instructions.
    self._merge_chain(scan_emitter.chain)
    # Remove the return instruction from combine_fn.
    self.chain.instructions.pop()

    # Update xs_slice instructions with the actual placeholder EValue ids.
    # The xs placeholders start after the carry inputs in combine_fn.
    for i, kernel in enumerate(xs_slice_instructions):
        xs_placeholder_id = scan_emitter.binding_input_values[num_carry + i].id
        kernel.instr_args.args[-1] = xs_placeholder_id
        kernel.instr_args.args[-2] = xs_placeholder_id

    # === COPY OUTPUTS ===

    # Get combine_fn's actual output EValues.
    # concrete_output_ids contains: (*carry_temp, *y_temp)
    concrete_outputs = scan_emitter.concrete_output_ids
    carry_temp = concrete_outputs[:num_carry]
    y_temp = concrete_outputs[num_carry:]

    self._internal_assert_emitter(
        len(carry_temp) == num_carry,
        self.node,
        f"Scan combine_fn should output {num_carry} carry values, got {len(carry_temp)}",
    )
    self._internal_assert_emitter(
        len(y_temp) == len(y_outputs),
        self.node,
        f"Scan combine_fn should output {len(y_outputs)} y values, got {len(y_temp)}",
    )

    # Copy carry_temp -> carry_outputs for next iteration.
    # This explicit copy is required because in-place op.out(x, out=x) is unsafe.
    for carry_t, carry_out in zip(carry_temp, carry_outputs):
        kernel = Instruction(
            KernelCall(
                op_index=op_index_copy,
                args=[carry_out.id, carry_t.id],
            )
        )
        self.chain.instructions.append(kernel)

    # Copy y_temp to stacked y_outputs using et_copy_index.
    op_index_copy_index, _ = self._get_operator(
        name="executorch_prim::et_copy_index",
        overload="tensor",
    )
    for y_t, y_out in zip(y_temp, y_outputs):
        kernel = Instruction(
            KernelCall(
                op_index=op_index_copy_index,
                args=[y_out.id, y_t.id, iter_idx.id],
            )
        )
        self.chain.instructions.append(kernel)

    # === LOOP CONTROL ===

    # Increment iter_idx (writes back into iter_idx in place).
    op_index_add, _ = self._get_operator(
        name="executorch_prim::add",
        overload="Scalar",
    )
    kernel = Instruction(
        KernelCall(
            op_index=op_index_add,
            args=[iter_idx.id, self._emit_evalue(EValue(Int(1))).id, iter_idx.id],
        )
    )
    self.chain.instructions.append(kernel)

    # Check if iteration is complete: jump_bool_value = (iter_idx == sym_size).
    jump_bool_value = self._emit_evalue(EValue(Bool(False)))
    op_index_eq, _ = self._get_operator(
        name="executorch_prim::eq",
        overload="Scalar",
    )
    kernel = Instruction(
        KernelCall(
            op_index=op_index_eq,
            args=[iter_idx.id, sym_size.id, jump_bool_value.id],
        )
    )
    self.chain.instructions.append(kernel)

    # Jump back to loop start if not done.
    jf_beginning_loop = Instruction(
        JumpFalseCall(
            cond_value_index=jump_bool_value.id,
            destination_instruction=jump_to_instruction,
        )
    )
    self.chain.instructions.append(jf_beginning_loop)

    # === CLEANUP ===

    # Reset iter_idx (iter_idx -= sym_size, i.e. back to 0) for potential
    # re-runs of the model.
    op_index_sub, _ = self._get_operator(
        name="executorch_prim::sub",
        overload="Scalar",
    )
    kernel = Instruction(
        KernelCall(
            op_index=op_index_sub,
            args=[iter_idx.id, sym_size.id, iter_idx.id],
        )
    )
    self.chain.instructions.append(kernel)

    return subemitter_binding_output_values
9471210 def _emit_control_flow (
9481211 self , target : _Target , args : Tuple [_Argument , ...], kwargs : Dict [str , _Argument ]
9491212 ) -> _EmitterValue :
9501213 """Wraps common logic for emitting all control flow operations.
9511214
952- See the more specific emission functions for more details on how cond or map get emitted.
1215+ See the more specific emission functions for more details on how cond, map, or scan get emitted.
9531216 """
9541217 subemitter_binding_output_values = pytree .tree_map (
9551218 lambda spec : self ._emit_spec (spec ),
@@ -960,6 +1223,8 @@ def _emit_control_flow(
9601223 return self ._emit_cond (args , subemitter_binding_output_values )
9611224 elif target is torch .ops .higher_order .map_impl :
9621225 return self ._emit_map (args , subemitter_binding_output_values )
1226+ elif target is torch .ops .higher_order .scan :
1227+ return self ._emit_scan (args , subemitter_binding_output_values )
9631228 else :
9641229 raise InternalError (
9651230 self ._emit_node_specific_error (
@@ -1511,6 +1776,7 @@ def call_function( # pyre-fixme[14]
15111776 torch .ops .higher_order .cond ,
15121777 torch .ops .higher_order .map_impl ,
15131778 torch .ops .higher_order .while_loop ,
1779+ torch .ops .higher_order .scan ,
15141780 ):
15151781 return self ._emit_control_flow (target , args , kwargs )
15161782
0 commit comments