 import torch
 from executorch.exir.delegate import executorch_call_delegate
 from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
-from executorch.exir.schema import TensorShapeDynamism
 from executorch.exir.tensor import TensorSpec
 from torch.export.exported_program import ExportGraphSignature
 from torch.fx.node import Node
@@ -60,18 +59,15 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
         res = ExportPass()(graph_module)
         assert res is not None
         gm = res.graph_module
-
         def get_spec(x):
             if hasattr(x, "meta"):
                 return x.meta.get("spec", None)
             else:
                 return None
-
         for module in gm.modules():
             if isinstance(module, torch.fx.GraphModule):
                 for node in module.graph.nodes:
                     meta_val = node.meta.get("val", None)
-
                     if node.op == "output":
                         node.meta["spec"] = pytree.tree_map(get_spec, node.args[0])
                     elif node.op == "call_function" and node.target == operator.getitem:
@@ -123,152 +119,3 @@ def update_placeholder_tensor_specs(
                 in exported_program.graph_signature.inputs_to_lifted_tensor_constants
             ):
                 spec.const = True
-
-    # pyre-ignore
-    def placeholder(self, name: str, arg, meta):
-        meta["spec"] = make_spec(arg)
-        return super().placeholder(name, arg, meta)
-
-    # pyre-ignore
-    def call_operator(self, op, args, kwargs, meta):
-        args_data, kwargs_data = pytree.tree_map_only(
-            ProxyValue, lambda x: x.data, (args, kwargs)
-        )
-        meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data))
-        return super().call_operator(op, args, kwargs, meta)
-
-    # pyre-ignore
-    def call_getitem(self, value, key: int, meta):
-        meta["spec"] = value.node.meta["spec"][key]
-        return super().call_getitem(value, key, meta)
-
-    # pyre-ignore
-    def call_cond(self, pred, true_fn, false_fn, inputs, meta):
-        # true_fn/false_fn return tensors of the same shape, so we can pick
-        # either one here.
-        *_, true_out_node = true_fn.graph.nodes
-        meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"])
-        return super().call_cond(pred, true_fn, false_fn, inputs, meta)
-
-    def call_while(
-        self,
-        cond_fn: torch.fx.GraphModule,
-        body_fn: torch.fx.GraphModule,
-        carried_inputs: List[ProxyValue],
-        additional_inputs: List[ProxyValue],
-        meta: NodeMetadata,
-    ):
-        meta["spec"] = pytree.tree_map(make_spec, carried_inputs)
-        return super().call_while(
-            cond_fn, body_fn, carried_inputs, additional_inputs, meta
-        )
-
-    def call_map(
-        self,
-        f: torch.fx.GraphModule,
-        mapped_args: List[ProxyValue],
-        operands: List[ProxyValue],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        mapped_dim_size = [arg.data for arg in mapped_args][0].size(0)
-        *_, body_out_node = f.graph.nodes
-        body_out_node_fake_tensor = body_out_node.meta["val"]
-
-        # For dynamic shapes, initialize with size 0 in the mapped dimension.
-        # The et_copy_index op will resize as it writes to each index.
-        # Check if the mapped dimension is symbolic (dynamic).
-        is_dynamic = isinstance(mapped_dim_size, torch.SymInt)
-        init_size = 0 if is_dynamic else mapped_dim_size
-
-        map_fake_tensor = pytree.tree_map_only(
-            torch.Tensor,
-            lambda x: x.new_empty(init_size, *x.shape),
-            body_out_node_fake_tensor,
-        )
-        meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor)
-        return super().call_map(f, mapped_args, operands, meta)
-
-    def call_scan(
-        self,
-        combine_fn: torch.fx.GraphModule,
-        init: List[ProxyValue],
-        xs: List[ProxyValue],
-        additional_inputs: List[ProxyValue],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        # Get the scan length - this may be symbolic for dynamic shapes
-        xs_tensor = [arg.data for arg in xs][0]
-        scan_length = xs_tensor.size(0)
-
-        *_, body_out_node = combine_fn.graph.nodes
-        body_out_fake = body_out_node.meta["val"]
-
-        num_carry = len(init)
-        flat_body_out, out_spec = pytree.tree_flatten(body_out_fake)
-
-        carry_out = flat_body_out[:num_carry]
-        y_out = flat_body_out[num_carry:]
-
-        # Check if the scan dimension is symbolic (dynamic)
-        is_dynamic = isinstance(scan_length, torch.SymInt)
-
-        # For the y outputs, we need to use the upper bound size to allocate memory,
-        # but also mark the tensor spec as DYNAMIC_BOUND so it can be resized at runtime.
-        if is_dynamic:
-            # Get the upper bound by evaluating the symbolic int
-            # Using hint gives us the concrete upper bound value
-            upper_bound_size = scan_length.node.shape_env.size_hint(
-                scan_length.node.expr
-            )
-        else:
-            upper_bound_size = scan_length
-
-        carry_fake = carry_out
-        y_fake = [
-            (
-                x.new_empty(upper_bound_size, *x.shape)
-                if isinstance(x, torch.Tensor)
-                else x
-            )
-            for x in y_out
-        ]
-
-        combined_fake = carry_fake + y_fake
-
-        # Create specs from the fake tensors
-        specs = pytree.tree_map(make_spec, combined_fake)
-
-        # For dynamic shapes, mark the y_output specs as DYNAMIC_BOUND
-        # so that et_copy_index can resize them at runtime
-        if is_dynamic and isinstance(specs, list):
-            for i in range(num_carry, len(specs)):
-                if isinstance(specs[i], TensorSpec):
-                    specs[i].shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND
-
-        meta["spec"] = specs
-        return super().call_scan(combine_fn, init, xs, additional_inputs, meta)
-
-    # pyre-ignore
-    def call_delegate(self, lowered_module, args, kwargs, meta):
-        args_data, kwargs_data = pytree.tree_map_only(
-            ProxyValue, lambda x: x.data, (args, kwargs)
-        )
-        # If spec is missing, re-generate it with args data
-        if "spec" not in meta:
-            meta["spec"] = pytree.tree_map(
-                make_spec,
-                executorch_call_delegate(lowered_module, *args_data),
-            )
-        return super().call_delegate(lowered_module, args, kwargs, meta)
-
-    # pyre-ignore
-    def output(self, results, meta):
-        # pyre-ignore
-        def get_spec(x):
-            if isinstance(x, ProxyValue):
-                return x.node.meta["spec"]
-            else:
-                return make_spec(x)
-
-        meta["spec"] = pytree.tree_map(get_spec, results)
-        return super().output(results, meta)
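
The overrides removed above all follow one pattern: obtain a fake-tensor result for the node, then map a spec-constructing function over it with pytree.tree_map and store the result in meta["spec"]. The sketch below illustrates that pattern in isolation, outside ExecuTorch: it walks a torch.export-ed graph and derives a per-node spec from node.meta["val"]. SimpleSpec, make_simple_spec, and AddMul are hypothetical names used only for illustration; the real pass builds executorch.exir.tensor.TensorSpec objects via its make_spec helper.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch.export import export


@dataclass
class SimpleSpec:
    # Hypothetical stand-in for executorch.exir.tensor.TensorSpec.
    shape: Tuple[int, ...]
    dtype: torch.dtype


def make_simple_spec(x) -> Optional[SimpleSpec]:
    # Tensors (including fake tensors) get a spec; other leaves pass through as None.
    if isinstance(x, torch.Tensor):
        return SimpleSpec(tuple(x.shape), x.dtype)
    return None


class AddMul(torch.nn.Module):
    def forward(self, x, y):
        return x + y, x * y


ep = export(AddMul(), (torch.randn(2, 3), torch.randn(2, 3)))
for node in ep.graph_module.graph.nodes:
    # meta["val"] may be a single fake tensor, a tuple of them (the output node),
    # or absent; tree_map handles all three uniformly.
    node.meta["spec"] = pytree.tree_map(make_simple_spec, node.meta.get("val"))
    print(node.op, node.name, node.meta["spec"])
```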