@@ -944,22 +944,279 @@ def forward(self, x,y):
 
         return subemitter_binding_output_values
 
+    def _emit_scan(
+        self,
+        args: Tuple[_Argument, ...],
+        subemitter_binding_output_values: List[_AbstractValue],
+    ) -> List[_AbstractValue]:
+        """Emits torch.scan.
+
+        Converts the higher order scan op into a loop constructed from jump instructions
+        and primitive operations. Scan differs from map in that it maintains a carry state
+        that evolves across iterations.
+
+        Scan signature: scan(combine_fn, init, xs, additional_inputs)
+        - combine_fn: GraphModule that takes (carry, x_slice, *additional_inputs)
+          and returns (next_carry, y_slice)
+        - init: Initial carry state (list of tensors)
+        - xs: Input tensors to scan over (list of tensors, scanned along dim 0)
+        - additional_inputs: Additional arguments passed to combine_fn
+
+        Output: (final_carry, stacked_ys)
+        - final_carry: The carry state after the last iteration
+        - stacked_ys: All y outputs stacked along dim 0
+
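+        Reference semantics (an illustrative Python sketch of the signature above,
+        not the emitted program):
+
+            carry = init
+            ys = []
+            for i in range(xs[0].shape[0]):  # scan along dim 0
+                x_slice = [x[i] for x in xs]
+                carry, y = combine_fn(carry, x_slice, *additional_inputs)
+                ys.append(y)
+            # result: (carry, torch.stack(ys, dim=0))
+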
+        Memory Layout:
+        - carry_outputs (subemitter_binding_output_values[:num_carry]):
+          Working carry buffers, initialized from init, updated each iteration
+        - y_outputs (subemitter_binding_output_values[num_carry:]):
+          Pre-allocated stacked output buffers, filled via et_copy_index
+
+        The combine_fn writes to its own temporary output buffers (concrete_output_ids).
+        After each iteration:
+        1. Copy combine_fn's carry output -> carry_outputs (for next iteration)
+        2. et_copy_index(y_outputs, combine_fn's y output, iter_idx)
+
+        This explicit copy approach is used because in-place op.out(x, out=x) is unsafe.
+        """
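+        # Rough layout of the emitted instruction stream (a sketch; see the code below):
+        #   sym_size.int(xs[0], 0)                    -> n (trip count)
+        #   copy_(init -> carry_outputs)
+        # loop_start:
+        #   select_copy(x, 0, iter_idx)               -> combine_fn's xs placeholders
+        #   <combine_fn instructions>                 -> carry_temp, y_temp
+        #   copy_(carry_temp -> carry_outputs)
+        #   et_copy_index(y_outputs, y_temp, iter_idx)
+        #   iter_idx += 1
+        #   jump_false(iter_idx == n) -> loop_start
+        #   iter_idx -= n                             (reset so the program can run again)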
+        combine_fn, init, xs, additional_inputs = args
+
+        assert isinstance(subemitter_binding_output_values, (list, tuple)), (
+            f"Expected list for subemitter_binding_output_values. "
+            f"Got {type(subemitter_binding_output_values).__name__}: "
+            f"{subemitter_binding_output_values}."
+        )
+
+        assert isinstance(combine_fn, torch.fx.GraphModule)
+        assert isinstance(init, (list, tuple))
+        assert isinstance(xs, (list, tuple))
+        assert isinstance(additional_inputs, (list, tuple))
+
+        num_carry = len(init)
+        num_xs = len(xs)
+
+        carry_outputs = list(subemitter_binding_output_values[:num_carry])
+        y_outputs = list(subemitter_binding_output_values[num_carry:])
+
+        if num_xs < 1:
+            raise RuntimeError(
+                f"Scan requires at least one xs tensor to scan over but got {num_xs}"
+            )
+
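+        # iter_idx is the loop counter; the sym_size.int call below computes the trip
+        # count n = xs[0].size(0), since every xs tensor is scanned along dim 0.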
+        iter_idx = self._emit_evalue(EValue(Int(0)))
+
+        op_index, op = self._get_operator(
+            name="aten::sym_size",
+            overload="int",
+        )
+        sym_size = self._emit_evalue(EValue(Int(0)))
+        kernel = Instruction(
+            KernelCall(
+                op_index=op_index,
+                args=[xs[0].id, self._emit_evalue(EValue(Int(0))).id, sym_size.id],
+            )
+        )
+        self.chain.instructions.append(kernel)
+
+        # Initialize carry_outputs from init
+        op_index_copy, _ = self._get_operator(name="aten::copy_")
+        for init_val, carry_out in zip(init, carry_outputs):
+            kernel = Instruction(
+                KernelCall(
+                    op_index=op_index_copy,
+                    args=[
+                        carry_out.id,
+                        init_val.id,
+                        self._emit_evalue(EValue(Bool(False))).id,
+                        carry_out.id,
+                    ],
+                )
+            )
+            self.chain.instructions.append(kernel)
+
+        # Slice each xs tensor for the current iteration
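+        # The last two args of each select_copy must refer to combine_fn's
+        # corresponding xs placeholder, whose EValue id is only known after the
+        # submodule is emitted below; use -1 sentinels now and patch them afterwards.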
+        op_index_select, _ = self._get_operator(
+            name="aten::select_copy",
+            overload="int_out",
+        )
+        xs_slice_instructions = []
+        for x in xs:
+            kernel = Instruction(
+                KernelCall(
+                    op_index=op_index_select,
+                    args=[
+                        x.id,
+                        self._emit_evalue(EValue(Int(0))).id,
+                        iter_idx.id,
+                        -1,
+                        -1,
+                    ],
+                )
+            )
+            xs_slice_instructions.append(kernel)
+
+        jump_to_instruction = self.instruction_start_offset + len(
+            self.chain.instructions
+        )
+
+        for kernel in xs_slice_instructions:
+            self.chain.instructions.append(kernel)
+
+        # Emit combine_fn submodule
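+        # Placeholder bindings for combine_fn: carry placeholders reuse the working
+        # carry buffers directly, xs placeholders are left as -1 sentinels so the
+        # sub-emitter allocates fresh values for the per-iteration slices, and
+        # additional_inputs are passed through unchanged.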
+        binding_input_values: List[Any] = []
+        binding_input_values.extend(carry_outputs)
+        binding_input_values.extend([-1] * num_xs)
+        binding_input_values.extend(additional_inputs)
+
+        scan_emitter = _Emitter(
+            combine_fn,
+            self.emitter_state,
+            self.program_state,
+            instruction_start_offset=self.instruction_start_offset
+            + len(self.chain.instructions),
+            binding_input_values=binding_input_values,
+            binding_output_values=None,
+        )
+        scan_emitter.run()
+
+        self._merge_chain(scan_emitter.chain)
+
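+        # combine_fn's xs placeholders now have concrete EValue ids, so patch the -1
+        # sentinels in the select_copy instructions to write each slice where the
+        # submodule expects to read it.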
+        for i, kernel in enumerate(xs_slice_instructions):
+            xs_placeholder_id = scan_emitter.binding_input_values[num_carry + i].id
+            kernel.instr_args.args[-1] = xs_placeholder_id
+            kernel.instr_args.args[-2] = xs_placeholder_id
+
+        concrete_outputs = scan_emitter.concrete_output_ids
+        carry_temp = concrete_outputs[:num_carry]
+        y_temp = concrete_outputs[num_carry:]
+
+        self._internal_assert_emitter(
+            len(carry_temp) == num_carry,
+            self.node,
+            f"Scan combine_fn should output {num_carry} carry values, got {len(carry_temp)}",
+        )
+        self._internal_assert_emitter(
+            len(y_temp) == len(y_outputs),
+            self.node,
+            f"Scan combine_fn should output {len(y_outputs)} y values, got {len(y_temp)}",
+        )
+
+        # Copy carry_temp -> carry_outputs for next iteration
+        for carry_t, carry_out in zip(carry_temp, carry_outputs):
+            kernel = Instruction(
+                KernelCall(
+                    op_index=op_index_copy,
+                    args=[
+                        carry_out.id,
+                        carry_t.id,
+                        self._emit_evalue(EValue(Bool(False))).id,
+                        carry_out.id,
+                    ],
+                )
+            )
+            self.chain.instructions.append(kernel)
+
+        # Copy y_temp to stacked y_outputs
+        op_index_copy_index, _ = self._get_operator(
+            name="executorch_prim::et_copy_index",
+            overload="tensor",
+        )
+        for y_t, y_out in zip(y_temp, y_outputs):
+            kernel = Instruction(
+                KernelCall(
+                    op_index=op_index_copy_index,
+                    args=[y_out.id, y_t.id, iter_idx.id],
+                )
+            )
+            self.chain.instructions.append(kernel)
+
+        # Increment iter_idx
+        op_index_add, _ = self._get_operator(
+            name="executorch_prim::add",
+            overload="Scalar",
+        )
+        kernel = Instruction(
+            KernelCall(
+                op_index=op_index_add,
+                args=[iter_idx.id, self._emit_evalue(EValue(Int(1))).id, iter_idx.id],
+            )
+        )
+        self.chain.instructions.append(kernel)
+
+        # Check if iteration is complete
+        jump_bool_value = self._emit_evalue(EValue(Bool(False)))
+        op_index_eq, _ = self._get_operator(
+            name="executorch_prim::eq",
+            overload="Scalar",
+        )
+        kernel = Instruction(
+            KernelCall(
+                op_index=op_index_eq,
+                args=[iter_idx.id, sym_size.id, jump_bool_value.id],
+            )
+        )
+        self.chain.instructions.append(kernel)
+
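+        # JumpFalseCall branches back to the slice instructions while iter_idx != n;
+        # once the comparison above yields True, execution falls through out of the loop.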
+        jf_beginning_loop = Instruction(
+            JumpFalseCall(
+                cond_value_index=jump_bool_value.id,
+                destination_instruction=jump_to_instruction,
+            )
+        )
+        self.chain.instructions.append(jf_beginning_loop)
+
+        # Reset iter_idx for potential re-runs
+        op_index_sub, _ = self._get_operator(
+            name="executorch_prim::sub",
+            overload="Scalar",
+        )
+        kernel = Instruction(
+            KernelCall(
+                op_index=op_index_sub,
+                args=[iter_idx.id, sym_size.id, iter_idx.id],
+            )
+        )
+        self.chain.instructions.append(kernel)
+
+        return subemitter_binding_output_values
+
     def _emit_control_flow(
         self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
     ) -> _EmitterValue:
         """Wraps common logic for emitting all control flow operations.
 
-        See the more specific emission functions for more details on how cond or map get emitted.
+        See the more specific emission functions for more details on how cond, map, or scan get emitted.
         """
+        specs = self.node.meta["spec"]
+
+        # For scan/map, set shape_dynamism for the stacked outputs (y_outputs) to DYNAMIC_BOUND
+        # BEFORE emitting the specs. et_copy_index has cat-style shape semantics but stack-style
+        # memory behavior, so the output's leading dimension must grow by one each iteration,
+        # which is not allowed for tensors marked static.
+        if target is torch.ops.higher_order.scan:
+            combine_fn, init, xs, additional_inputs = args
+            num_carry = len(init)
+            if isinstance(specs, (list, tuple)):
+                y_specs = specs[num_carry:]
+                for y_spec in y_specs:
+                    if isinstance(y_spec, TensorSpec):
+                        y_spec.shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND
+        elif target is torch.ops.higher_order.map_impl:
+            assert len(specs) == 1
+            assert isinstance(specs[0], TensorSpec)
+            specs[0].shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND
+
         subemitter_binding_output_values = pytree.tree_map(
             lambda spec: self._emit_spec(spec),
-            self.node.meta["spec"],
+            specs,
         )
 
         if target is torch.ops.higher_order.cond:
             return self._emit_cond(args, subemitter_binding_output_values)
         elif target is torch.ops.higher_order.map_impl:
             return self._emit_map(args, subemitter_binding_output_values)
+        elif target is torch.ops.higher_order.scan:
+            return self._emit_scan(args, subemitter_binding_output_values)
         else:
             raise InternalError(
                 self._emit_node_specific_error(
@@ -1190,7 +1447,7 @@ def _emit_delegate(
 
         return delegate_ret
 
-    def _get_operator(self, name: str, overload: str) -> Tuple[int, Operator]:
+    def _get_operator(self, name: str, overload: str = "") -> Tuple[int, Operator]:
         """Given a fully qualified name, lookups the operator in the ExecuTorch Program, or adds it
         if it is not already present"""
         key = (name, overload)
@@ -1511,6 +1768,7 @@ def call_function( # pyre-fixme[14]
             torch.ops.higher_order.cond,
             torch.ops.higher_order.map_impl,
             torch.ops.higher_order.while_loop,
+            torch.ops.higher_order.scan,
         ):
             return self._emit_control_flow(target, args, kwargs)
 