Commit 3bbb02c

ok fix the spec issues at the source
1 parent d18aa78 commit 3bbb02c

File tree

3 files changed: +117 -158 lines changed


exir/pass_base.py

Lines changed: 14 additions & 5 deletions
@@ -554,16 +554,25 @@ def call_scan(
         self,
         combine_fn: torch.fx.GraphModule,
         init: List[ProxyValue],
-        xs: List[ProxyValue],
+        xs: List[Argument],
         additional_inputs: List[ProxyValue],
         meta: NodeMetadata,
     ) -> ProxyValue:
-        xs_first_slice = _unstack_pytree([arg.data for arg in xs])[0]
-        init_data = [arg.data for arg in init]
-        additional_data = [arg.data for arg in additional_inputs]
+        # Get the expected x element shapes from the combine_fn's placeholders.
+        # The combine_fn expects: (carry..., x_element..., additional_inputs...)
+        combine_fn_placeholders = [
+            n for n in combine_fn.graph.nodes if n.op == "placeholder"
+        ]
+        num_init = len(init)
+        # The x_element placeholders are at indices [num_init : num_init + num_xs]
+        xs_element_data = []
+        for i, x_proxy in enumerate(xs):
+            ph = combine_fn_placeholders[num_init + i]
+            # Use the placeholder's val which has the correct shape
+            xs_element_data.append(ph.meta["val"])
 
         combine_fn_result = self.call_submodule(
-            combine_fn, tuple(init_data + xs_first_slice + additional_data)
+            combine_fn, (*init, *xs_element_data, *additional_inputs)
         )
         assert combine_fn_result is not None
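The change works because export records a fake tensor under `meta["val"]` on each placeholder of the traced `combine_fn`, and those fake tensors already carry the per-element (sliced) shapes the submodule expects, so the pass no longer needs to unstack the `xs` operands. A minimal sketch of that lookup pattern; the helper name below is hypothetical and not part of this commit:

```python
import torch


def placeholder_fake_inputs(gm: torch.fx.GraphModule):
    """Hypothetical helper: collect the fake ("val") tensors recorded on a
    GraphModule's placeholder nodes.

    Each placeholder carries the shape the submodule was traced with, so
    slicing this list gives per-element example inputs without unstacking
    the actual scan operands.
    """
    return [
        n.meta["val"]
        for n in gm.graph.nodes
        if n.op == "placeholder" and "val" in n.meta
    ]


# For a scan combine_fn traced as (carry..., x_element..., additional...),
# the first len(init) entries are the carry shapes and the next len(xs)
# entries are the per-element x shapes.
```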

exir/passes/spec_prop_pass.py

Lines changed: 0 additions & 153 deletions
@@ -12,7 +12,6 @@
 import torch
 from executorch.exir.delegate import executorch_call_delegate
 from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
-from executorch.exir.schema import TensorShapeDynamism
 from executorch.exir.tensor import TensorSpec
 from torch.export.exported_program import ExportGraphSignature
 from torch.fx.node import Node
@@ -60,18 +59,15 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
         res = ExportPass()(graph_module)
         assert res is not None
         gm = res.graph_module
-
         def get_spec(x):
             if hasattr(x, "meta"):
                 return x.meta.get("spec", None)
             else:
                 return None
-
         for module in gm.modules():
             if isinstance(module, torch.fx.GraphModule):
                 for node in module.graph.nodes:
                     meta_val = node.meta.get("val", None)
-
                     if node.op == "output":
                         node.meta["spec"] = pytree.tree_map(get_spec, node.args[0])
                     elif node.op == "call_function" and node.target == operator.getitem:
@@ -123,152 +119,3 @@ def update_placeholder_tensor_specs(
                 in exported_program.graph_signature.inputs_to_lifted_tensor_constants
             ):
                 spec.const = True
-
-    # pyre-ignore
-    def placeholder(self, name: str, arg, meta):
-        meta["spec"] = make_spec(arg)
-        return super().placeholder(name, arg, meta)
-
-    # pyre-ignore
-    def call_operator(self, op, args, kwargs, meta):
-        args_data, kwargs_data = pytree.tree_map_only(
-            ProxyValue, lambda x: x.data, (args, kwargs)
-        )
-        meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data))
-        return super().call_operator(op, args, kwargs, meta)
-
-    # pyre-ignore
-    def call_getitem(self, value, key: int, meta):
-        meta["spec"] = value.node.meta["spec"][key]
-        return super().call_getitem(value, key, meta)
-
-    # pyre-ignore
-    def call_cond(self, pred, true_fn, false_fn, inputs, meta):
-        # true_fn/false_fn return tensors of the same shape, so we can pick
-        # either one here.
-        *_, true_out_node = true_fn.graph.nodes
-        meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"])
-        return super().call_cond(pred, true_fn, false_fn, inputs, meta)
-
-    def call_while(
-        self,
-        cond_fn: torch.fx.GraphModule,
-        body_fn: torch.fx.GraphModule,
-        carried_inputs: List[ProxyValue],
-        additional_inputs: List[ProxyValue],
-        meta: NodeMetadata,
-    ):
-        meta["spec"] = pytree.tree_map(make_spec, carried_inputs)
-        return super().call_while(
-            cond_fn, body_fn, carried_inputs, additional_inputs, meta
-        )
-
-    def call_map(
-        self,
-        f: torch.fx.GraphModule,
-        mapped_args: List[ProxyValue],
-        operands: List[ProxyValue],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        mapped_dim_size = [arg.data for arg in mapped_args][0].size(0)
-        *_, body_out_node = f.graph.nodes
-        body_out_node_fake_tensor = body_out_node.meta["val"]
-
-        # For dynamic shapes, initialize with size 0 in the mapped dimension.
-        # The et_copy_index op will resize as it writes to each index.
-        # Check if the mapped dimension is symbolic (dynamic).
-        is_dynamic = isinstance(mapped_dim_size, torch.SymInt)
-        init_size = 0 if is_dynamic else mapped_dim_size
-
-        map_fake_tensor = pytree.tree_map_only(
-            torch.Tensor,
-            lambda x: x.new_empty(init_size, *x.shape),
-            body_out_node_fake_tensor,
-        )
-        meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor)
-        return super().call_map(f, mapped_args, operands, meta)
-
-    def call_scan(
-        self,
-        combine_fn: torch.fx.GraphModule,
-        init: List[ProxyValue],
-        xs: List[ProxyValue],
-        additional_inputs: List[ProxyValue],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        # Get the scan length - this may be symbolic for dynamic shapes
-        xs_tensor = [arg.data for arg in xs][0]
-        scan_length = xs_tensor.size(0)
-
-        *_, body_out_node = combine_fn.graph.nodes
-        body_out_fake = body_out_node.meta["val"]
-
-        num_carry = len(init)
-        flat_body_out, out_spec = pytree.tree_flatten(body_out_fake)
-
-        carry_out = flat_body_out[:num_carry]
-        y_out = flat_body_out[num_carry:]
-
-        # Check if the scan dimension is symbolic (dynamic)
-        is_dynamic = isinstance(scan_length, torch.SymInt)
-
-        # For the y outputs, we need to use the upper bound size to allocate memory,
-        # but also mark the tensor spec as DYNAMIC_BOUND so it can be resized at runtime.
-        if is_dynamic:
-            # Get the upper bound by evaluating the symbolic int
-            # Using hint gives us the concrete upper bound value
-            upper_bound_size = scan_length.node.shape_env.size_hint(
-                scan_length.node.expr
-            )
-        else:
-            upper_bound_size = scan_length
-
-        carry_fake = carry_out
-        y_fake = [
-            (
-                x.new_empty(upper_bound_size, *x.shape)
-                if isinstance(x, torch.Tensor)
-                else x
-            )
-            for x in y_out
-        ]
-
-        combined_fake = carry_fake + y_fake
-
-        # Create specs from the fake tensors
-        specs = pytree.tree_map(make_spec, combined_fake)
-
-        # For dynamic shapes, mark the y_output specs as DYNAMIC_BOUND
-        # so that et_copy_index can resize them at runtime
-        if is_dynamic and isinstance(specs, list):
-            for i in range(num_carry, len(specs)):
-                if isinstance(specs[i], TensorSpec):
-                    specs[i].shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND
-
-        meta["spec"] = specs
-        return super().call_scan(combine_fn, init, xs, additional_inputs, meta)
-
-    # pyre-ignore
-    def call_delegate(self, lowered_module, args, kwargs, meta):
-        args_data, kwargs_data = pytree.tree_map_only(
-            ProxyValue, lambda x: x.data, (args, kwargs)
-        )
-        # If spec is missing, re-genenrate it with args data
-        if "spec" not in meta:
-            meta["spec"] = pytree.tree_map(
-                make_spec,
-                executorch_call_delegate(lowered_module, *args_data),
-            )
-        return super().call_delegate(lowered_module, args, kwargs, meta)
-
-    # pyre-ignore
-    def output(self, results, meta):
-        # pyre-ignore
-        def get_spec(x):
-            if isinstance(x, ProxyValue):
-                return x.node.meta["spec"]
-            else:
-                return make_spec(x)
-
-        meta["spec"] = pytree.tree_map(get_spec, results)
-        return super().output(results, meta)
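All of the higher-order-op spec overrides above (`call_cond`, `call_while`, `call_map`, `call_scan`, `call_delegate`, `output`, ...) are deleted; per the commit message, the shape information is now expected to be correct in each node's `val` metadata at the source, so specs can be derived generically. A rough sketch of that generic derivation, assuming a `make_spec`-style helper that turns a fake tensor into a spec object; this is illustrative, not the exact ExecuTorch implementation:

```python
import torch
from torch.utils import _pytree as pytree


def specs_from_meta_val(gm: torch.fx.GraphModule, make_spec):
    """Illustrative sketch: derive a per-node spec pytree from node.meta["val"].

    make_spec is assumed to map a single fake tensor (or scalar) to a spec.
    Everything HOP-specific (cond/while/map/scan output shapes) is already
    encoded in the "val" metadata, so no per-op handling is needed.
    """
    for node in gm.graph.nodes:
        meta_val = node.meta.get("val", None)
        if meta_val is not None:
            node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
    return gm
```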

exir/tests/test_passes.py

Lines changed: 103 additions & 0 deletions
@@ -745,6 +745,109 @@ def loop_body(i, acc):
         upper_bound = eval_upper_bound(spec[1].shape[0])
         self.assertEqual(upper_bound, 10)
 
+    def test_spec_prop_pass_scan(self) -> None:
+        from torch._higher_order_ops.scan import scan
+
+        class ModelWithScan(torch.nn.Module):
+            def forward(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+                def combine_fn(carry, x):
+                    new_carry = carry + x
+                    return new_carry, new_carry.clone()
+
+                init = torch.zeros_like(xs[0])
+                return scan(combine_fn, init, xs)
+
+        model = ModelWithScan()
+        inputs = (torch.arange(15).float().reshape(5, 3),)
+
+        # Run the spec prop pass and sanity check the spec on the scan.
+        edge_program = to_edge(
+            export(model, inputs, strict=True),
+            compile_config=EdgeCompileConfig(_check_ir_validity=False),
+        )
+        gm = edge_program.exported_program().graph_module
+        new_gm = SpecPropPass()(gm)
+        self.assertIsNotNone(new_gm)
+
+        # Check the spec for the scan node.
+        scan_node = next(
+            n
+            for n in new_gm.graph_module.graph.nodes
+            if hasattr(n.target, "name") and n.target.name() == "scan"
+        )
+        self.assertIsNotNone(scan_node)
+
+        # Spec for the scan node should be a two-element tuple (carry, stacked_outputs)
+        spec: Tuple[TensorSpec, TensorSpec] = scan_node.meta["spec"]
+        self.assertTrue(isinstance(spec, tuple))
+        self.assertEqual(len(spec), 2)
+
+        # Carry should have shape [3] (same as xs[0])
+        self.assertEqual(list(spec[0].shape), [3])
+        # Stacked outputs should have shape [5, 3] (same as xs)
+        self.assertEqual(list(spec[1].shape), [5, 3])
+
+    def test_spec_prop_pass_scan_dynamic_shape(self) -> None:
+        from torch._higher_order_ops.scan import scan
+
+        class ModelWithScan(torch.nn.Module):
+            def forward(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+                def combine_fn(carry, x):
+                    new_carry = carry + x
+                    return new_carry, new_carry.clone()
+
+                init = torch.zeros_like(xs[0])
+                return scan(combine_fn, init, xs)
+
+        model = ModelWithScan()
+        inputs = (torch.arange(15).float().reshape(5, 3),)
+        dynamic_shapes = {"xs": {0: torch.export.Dim("seq_len", min=1, max=20)}}
+
+        # First verify that export preserves symbolic shapes
+        exported = export(model, inputs, dynamic_shapes=dynamic_shapes, strict=True)
+        scan_node_after_export = next(
+            n
+            for n in exported.graph.nodes
+            if hasattr(n.target, "name") and n.target.name() == "scan"
+        )
+        val_after_export = scan_node_after_export.meta.get("val")
+        self.assertIsNotNone(val_after_export)
+        # After export, the stacked output should have a symbolic first dimension
+        self.assertIsInstance(val_after_export[1].shape[0], torch.SymInt)
+
+        # Run the spec prop pass and sanity check the spec on the scan.
+        edge_program = to_edge(
+            exported,
+            compile_config=EdgeCompileConfig(_check_ir_validity=False),
+        )
+        gm = edge_program.exported_program().graph_module
+        new_gm = SpecPropPass()(gm)
+        self.assertIsNotNone(new_gm)
+
+        # Check the spec for the scan node.
+        scan_node = next(
+            n
+            for n in new_gm.graph_module.graph.nodes
+            if hasattr(n.target, "name") and n.target.name() == "scan"
+        )
+        self.assertIsNotNone(scan_node)
+
+        # Spec for the scan node should be a two-element tuple (carry, stacked_outputs)
+        spec: Tuple[TensorSpec, TensorSpec] = scan_node.meta["spec"]
+        self.assertTrue(isinstance(spec, tuple))
+        self.assertEqual(len(spec), 2)
+
+        # Carry should have static shape [3]
+        self.assertEqual(list(spec[0].shape), [3])
+        self.assertEqual(spec[0].shape_dynamism, TensorShapeDynamism.STATIC)
+
+        # Stacked outputs should have dynamic first dimension with upper bound 20
+        self.assertEqual(len(spec[1].shape), 2)
+        upper_bound = eval_upper_bound(spec[1].shape[0])
+        self.assertEqual(upper_bound, 20)
+        self.assertEqual(spec[1].shape_dynamism, TensorShapeDynamism.DYNAMIC_BOUND)
+        self.assertEqual(spec[1].shape[1], 3)  # Second dim is static
+
     def test_compile_fix_broken_ops(self) -> None:
         class ExportableLoop(nn.Module):
             def __init__(self, hidden_size, out_channels):
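For context on the shapes these tests expect: `scan(combine_fn, init, xs)` threads a carry along the leading dimension of `xs` and stacks the second output of `combine_fn` at every step, so the carry keeps the shape of `init` (`[3]`) while the stacked output gains `xs`'s leading dimension in front (`[5, 3]`, or a bounded dynamic length in the second test). A plain-loop sketch of the same semantics, independent of this commit:

```python
import torch


def scan_reference(combine_fn, init, xs):
    """Plain-loop equivalent of scan(combine_fn, init, xs) for a single tensor xs.

    combine_fn(carry, x) -> (new_carry, y); the final carry has the shape of
    init, and the stacked ys have shape [xs.shape[0], *y.shape].
    """
    carry, ys = init, []
    for x in xs:  # iterate over the leading (scan) dimension
        carry, y = combine_fn(carry, x)
        ys.append(y)
    return carry, torch.stack(ys)


xs = torch.arange(15).float().reshape(5, 3)
carry, stacked = scan_reference(
    lambda c, x: (c + x, (c + x).clone()), torch.zeros_like(xs[0]), xs
)
assert carry.shape == (3,) and stacked.shape == (5, 3)
```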
