File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed
Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -1189,7 +1189,7 @@ def _emit_control_flow(
11891189 """
11901190 specs = self .node .meta ["spec" ]
11911191
1192- # For scan, set the shape_dynamism for the stacked outputs (y_outputs) to DYNAMIC_BOUND
1192+ # For scan/map , set the shape_dynamism for the stacked outputs (y_outputs) to DYNAMIC_BOUND
11931193 # BEFORE emitting the specs. This is because et_copy_index has cat shape semantics but
11941194 # stack memory behavior, so we need to be able to update the shape +1 for each iteration
11951195 # which we can't do for tensors marked static.
@@ -1201,6 +1201,10 @@ def _emit_control_flow(
12011201 for y_spec in y_specs :
12021202 if isinstance (y_spec , TensorSpec ):
12031203 y_spec .shape_dynamism = TensorShapeDynamism .DYNAMIC_BOUND
1204+ elif target is torch .ops .higher_order .map_impl :
1205+ assert len (specs ) == 1
1206+ assert isinstance (specs [0 ], TensorSpec )
1207+ specs [0 ].shape_dynamism = TensorShapeDynamism .DYNAMIC_BOUND
12041208
12051209 subemitter_binding_output_values = pytree .tree_map (
12061210 lambda spec : self ._emit_spec (spec ),
You can’t perform that action at this time.
0 commit comments