Skip to content

Commit 6c3b0e1

Browse files
committed
Now add the pseudo hack in the emitter
1 parent 3bbb02c commit 6c3b0e1

File tree

3 files changed

+16
-1
lines changed

3 files changed

+16
-1
lines changed

exir/emit/_emitter.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,6 +1189,19 @@ def _emit_control_flow(
11891189
"""
11901190
specs = self.node.meta["spec"]
11911191

1192+
# For scan, set the shape_dynamism for the stacked outputs (y_outputs) to DYNAMIC_BOUND
1193+
# BEFORE emitting the specs. This is because et_copy_index has cat shape semantics but
1194+
# stack memory behavior, so we need to be able to grow the shape by 1 on each iteration,
1195+
# which we can't do for tensors marked static.
1196+
if target is torch.ops.higher_order.scan:
1197+
combine_fn, init, xs, additional_inputs = args
1198+
num_carry = len(init)
1199+
if isinstance(specs, (list, tuple)):
1200+
y_specs = specs[num_carry:]
1201+
for y_spec in y_specs:
1202+
if isinstance(y_spec, TensorSpec):
1203+
y_spec.shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND
1204+
11921205
subemitter_binding_output_values = pytree.tree_map(
11931206
lambda spec: self._emit_spec(spec),
11941207
specs,

exir/pass_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ def call_scan(
572572
xs_element_data.append(ph.meta["val"])
573573

574574
combine_fn_result = self.call_submodule(
575-
combine_fn, (*init , *xs_element_data , *additional_inputs)
575+
combine_fn, (*init, *xs_element_data, *additional_inputs)
576576
)
577577
assert combine_fn_result is not None
578578

exir/passes/spec_prop_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,13 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
5959
res = ExportPass()(graph_module)
6060
assert res is not None
6161
gm = res.graph_module
62+
6263
def get_spec(x):
6364
if hasattr(x, "meta"):
6465
return x.meta.get("spec", None)
6566
else:
6667
return None
68+
6769
for module in gm.modules():
6870
if isinstance(module, torch.fx.GraphModule):
6971
for node in module.graph.nodes:

0 commit comments

Comments (0)