@@ -97,49 +97,6 @@ def _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict):
9797 _unlift (
9898 body_gm , inp_pos_to_buffer_name_for_submod , in_spec , None , state_dict
9999 )
100- if node .op == "call_function" and node .target .__name__ == "scan" :
101- # scan signature: scan(combine_fn, init, xs, additional_inputs)
102- # - combine_fn: GraphModule for the scan body
103- # - init: list of initial carry tensors
104- # - xs: list of input tensors to scan over
105- # - additional_inputs: tuple of additional arguments (may contain lifted params/buffers)
106- combine_fn , init , xs , additional_inputs = node .args
107- combine_gm = getattr (gm , combine_fn .name )
108- inp_pos_to_buffer_name_for_submod = {}
109- real_additional_inputs = []
110-
111- # additional_inputs may contain lifted parameters/buffers that need to be
112- # registered in the combine_fn submodule
113- for ix , operand in enumerate (additional_inputs ):
114- if (
115- hasattr (operand , "target" )
116- and operand .target in inp_pos_to_param_buffer_name .values ()
117- ):
118- # This is a lifted param/buffer, register it in the submodule
119- # The index needs to account for init and xs inputs to combine_fn
120- # combine_fn inputs: (*init, *xs_slice, *additional_inputs)
121- num_init = len (init ) if isinstance (init , (list , tuple )) else 1
122- num_xs = len (xs ) if isinstance (xs , (list , tuple )) else 1
123- adjusted_ix = num_init + num_xs + ix
124- inp_pos_to_buffer_name_for_submod [adjusted_ix ] = operand .target
125- combine_gm .register_buffer (
126- operand .target , state_dict [operand .target ]
127- )
128- else :
129- real_additional_inputs .append (operand )
130-
131- # Update node args with the filtered additional_inputs
132- node .args = (combine_fn , init , xs , tuple (real_additional_inputs ))
133-
134- _ , in_spec = pytree .tree_flatten ((init , xs , tuple (real_additional_inputs )))
135-
136- _unlift (
137- combine_gm ,
138- inp_pos_to_buffer_name_for_submod ,
139- in_spec ,
140- None ,
141- state_dict ,
142- )
143100 gm .graph .lint ()
144101 gm .graph .eliminate_dead_code ()
145102 gm .recompile ()
0 commit comments