Skip to content

Commit 04264ac

Browse files
committed
back fix map
1 parent 6c3b0e1 commit 04264ac

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

exir/emit/_emitter.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1189,7 +1189,7 @@ def _emit_control_flow(
11891189
"""
11901190
specs = self.node.meta["spec"]
11911191

1192-
# For scan, set the shape_dynamism for the stacked outputs (y_outputs) to DYNAMIC_BOUND
1192+
# For scan/map, set the shape_dynamism for the stacked outputs (y_outputs) to DYNAMIC_BOUND
11931193
# BEFORE emitting the specs. This is because et_copy_index has cat shape semantics but
11941194
# stack memory behavior, so we need to be able to update the shape +1 for each iteration
11951195
# which we can't do for tensors marked static.
@@ -1201,6 +1201,10 @@ def _emit_control_flow(
12011201
for y_spec in y_specs:
12021202
if isinstance(y_spec, TensorSpec):
12031203
y_spec.shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND
1204+
elif target is torch.ops.higher_order.map_impl:
1205+
assert len(specs) == 1
1206+
assert isinstance(specs[0], TensorSpec)
1207+
specs[0].shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND
12041208

12051209
subemitter_binding_output_values = pytree.tree_map(
12061210
lambda spec: self._emit_spec(spec),

0 commit comments

Comments
 (0)