Skip to content

Commit 9b618dd

Browse files
committed
ditch useless unlift change
1 parent bfd7f36 commit 9b618dd

File tree

1 file changed

+0
-43
lines changed

1 file changed

+0
-43
lines changed

exir/capture/_unlift.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -97,49 +97,6 @@ def _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict):
9797
_unlift(
9898
body_gm, inp_pos_to_buffer_name_for_submod, in_spec, None, state_dict
9999
)
100-
if node.op == "call_function" and node.target.__name__ == "scan":
101-
# scan signature: scan(combine_fn, init, xs, additional_inputs)
102-
# - combine_fn: GraphModule for the scan body
103-
# - init: list of initial carry tensors
104-
# - xs: list of input tensors to scan over
105-
# - additional_inputs: tuple of additional arguments (may contain lifted params/buffers)
106-
combine_fn, init, xs, additional_inputs = node.args
107-
combine_gm = getattr(gm, combine_fn.name)
108-
inp_pos_to_buffer_name_for_submod = {}
109-
real_additional_inputs = []
110-
111-
# additional_inputs may contain lifted parameters/buffers that need to be
112-
# registered in the combine_fn submodule
113-
for ix, operand in enumerate(additional_inputs):
114-
if (
115-
hasattr(operand, "target")
116-
and operand.target in inp_pos_to_param_buffer_name.values()
117-
):
118-
# This is a lifted param/buffer, register it in the submodule
119-
# The index needs to account for init and xs inputs to combine_fn
120-
# combine_fn inputs: (*init, *xs_slice, *additional_inputs)
121-
num_init = len(init) if isinstance(init, (list, tuple)) else 1
122-
num_xs = len(xs) if isinstance(xs, (list, tuple)) else 1
123-
adjusted_ix = num_init + num_xs + ix
124-
inp_pos_to_buffer_name_for_submod[adjusted_ix] = operand.target
125-
combine_gm.register_buffer(
126-
operand.target, state_dict[operand.target]
127-
)
128-
else:
129-
real_additional_inputs.append(operand)
130-
131-
# Update node args with the filtered additional_inputs
132-
node.args = (combine_fn, init, xs, tuple(real_additional_inputs))
133-
134-
_, in_spec = pytree.tree_flatten((init, xs, tuple(real_additional_inputs)))
135-
136-
_unlift(
137-
combine_gm,
138-
inp_pos_to_buffer_name_for_submod,
139-
in_spec,
140-
None,
141-
state_dict,
142-
)
143100
gm.graph.lint()
144101
gm.graph.eliminate_dead_code()
145102
gm.recompile()

0 commit comments

Comments
 (0)