Skip to content

Commit bfd7f36

Browse files
committed
clean up
1 parent d563028 commit bfd7f36

File tree

5 files changed

+13
-123
lines changed

5 files changed

+13
-123
lines changed

exir/emit/_emitter.py

Lines changed: 12 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -993,19 +993,14 @@ def _emit_scan(
993993
num_carry = len(init)
994994
num_xs = len(xs)
995995

996-
# Split output values into carry outputs and y outputs
997996
carry_outputs = list(subemitter_binding_output_values[:num_carry])
998997
y_outputs = list(subemitter_binding_output_values[num_carry:])
999998

1000999
if num_xs < 1:
10011000
raise RuntimeError("Scan requires at least one xs tensor to scan over.")
10021001

1003-
# === INITIALIZATION ===
1004-
1005-
# Generate iterator index EValue
10061002
iter_idx = self._emit_evalue(EValue(Int(0)))
10071003

1008-
# Get scan length from first xs tensor
10091004
op_index, op = self._get_operator(
10101005
name="aten::sym_size",
10111006
overload="int",
@@ -1019,9 +1014,7 @@ def _emit_scan(
10191014
)
10201015
self.chain.instructions.append(kernel)
10211016

1022-
# Initialize carry_outputs from init by copying init -> carry_outputs
1023-
# This is necessary because we shouldn't mutate the original init tensors
1024-
# Use aten::copy_.default which copies src to self in-place
1017+
# Initialize carry_outputs from init
10251018
op_index_copy, _ = self._get_operator(name="aten::copy_")
10261019
for init_val, carry_out in zip(init, carry_outputs):
10271020
kernel = Instruction(
@@ -1037,11 +1030,7 @@ def _emit_scan(
10371030
)
10381031
self.chain.instructions.append(kernel)
10391032

1040-
# === LOOP START ===
1041-
10421033
# Slice each xs tensor for the current iteration
1043-
# We use -1 as placeholder for the output tensor id, which will be filled
1044-
# after the scan_emitter runs and allocates the input placeholder EValues
10451034
op_index_select, _ = self._get_operator(
10461035
name="aten::select_copy",
10471036
overload="int_out",
@@ -1053,69 +1042,46 @@ def _emit_scan(
10531042
op_index=op_index_select,
10541043
args=[
10551044
x.id,
1056-
self._emit_evalue(EValue(Int(0))).id, # dim=0
1045+
self._emit_evalue(EValue(Int(0))).id,
10571046
iter_idx.id,
1058-
-1, # placeholder for output tensor id
1059-
-1, # placeholder (repeated for out variant)
1047+
-1,
1048+
-1,
10601049
],
10611050
)
10621051
)
10631052
xs_slice_instructions.append(kernel)
10641053

1065-
# Store jump target - this is where we jump back to after each iteration
10661054
jump_to_instruction = self.instruction_start_offset + len(
10671055
self.chain.instructions
10681056
)
10691057

1070-
# Add all xs slice instructions
10711058
for kernel in xs_slice_instructions:
10721059
self.chain.instructions.append(kernel)
10731060

1074-
# === EMIT COMBINE_FN SUBMODULE ===
1075-
1076-
# combine_fn inputs: (*carry, *xs_slice, *additional_inputs)
1077-
# We bind carry inputs to carry_outputs (the working carry buffers)
1078-
# xs_slice inputs will be filled in after emitter runs (using -1 placeholder)
1079-
# additional_inputs are passed through directly
1061+
# Emit combine_fn submodule
10801062
binding_input_values: List[Any] = []
1081-
binding_input_values.extend(
1082-
carry_outputs
1083-
) # Carry inputs bound to carry_outputs
1084-
binding_input_values.extend([-1] * num_xs) # Placeholders for xs slices
1085-
binding_input_values.extend(additional_inputs) # Additional inputs
1086-
1087-
# combine_fn outputs: (*next_carry, *y_slice)
1088-
# Pass binding_output_values=None so the combine_fn writes directly to its
1089-
# own output buffers (concrete_output_ids). We then copy from these directly
1090-
# to the final carry/y buffers, avoiding unnecessary temp buffers and MOVEs.
1063+
binding_input_values.extend(carry_outputs)
1064+
binding_input_values.extend([-1] * num_xs)
1065+
binding_input_values.extend(additional_inputs)
1066+
10911067
scan_emitter = _Emitter(
10921068
combine_fn,
10931069
self.emitter_state,
10941070
self.program_state,
10951071
instruction_start_offset=self.instruction_start_offset
10961072
+ len(self.chain.instructions),
10971073
binding_input_values=binding_input_values,
1098-
binding_output_values=None, # Use concrete outputs directly
1074+
binding_output_values=None,
10991075
)
11001076
scan_emitter.run()
11011077

1102-
# Merge combine_fn instructions
11031078
self._merge_chain(scan_emitter.chain)
1104-
# NOTE: When binding_output_values=None, no return/move instruction is added
1105-
# by the output() method, so we don't need to pop anything.
11061079

1107-
# Update xs_slice instructions with the actual placeholder EValue ids
1108-
# The xs placeholders start after the carry inputs in combine_fn
11091080
for i, kernel in enumerate(xs_slice_instructions):
11101081
xs_placeholder_id = scan_emitter.binding_input_values[num_carry + i].id
11111082
kernel.instr_args.args[-1] = xs_placeholder_id
11121083
kernel.instr_args.args[-2] = xs_placeholder_id
11131084

1114-
# === COPY OUTPUTS ===
1115-
1116-
# Get combine_fn's actual output EValues
1117-
# concrete_output_ids contains the actual EValues that the combine_fn
1118-
# graph operations write to: (*carry_temp, *y_temp)
11191085
concrete_outputs = scan_emitter.concrete_output_ids
11201086
carry_temp = concrete_outputs[:num_carry]
11211087
y_temp = concrete_outputs[num_carry:]
@@ -1132,8 +1098,6 @@ def _emit_scan(
11321098
)
11331099

11341100
# Copy carry_temp -> carry_outputs for next iteration
1135-
# This explicit copy is required because in-place op.out(x, out=x) is unsafe
1136-
# aten::copy_ signature: (self, src, non_blocking, out) -> self
11371101
for carry_t, carry_out in zip(carry_temp, carry_outputs):
11381102
kernel = Instruction(
11391103
KernelCall(
@@ -1148,7 +1112,7 @@ def _emit_scan(
11481112
)
11491113
self.chain.instructions.append(kernel)
11501114

1151-
# Copy y_temp to stacked y_outputs using et_copy_index
1115+
# Copy y_temp to stacked y_outputs
11521116
op_index_copy_index, _ = self._get_operator(
11531117
name="executorch_prim::et_copy_index",
11541118
overload="tensor",
@@ -1162,8 +1126,6 @@ def _emit_scan(
11621126
)
11631127
self.chain.instructions.append(kernel)
11641128

1165-
# === LOOP CONTROL ===
1166-
11671129
# Increment iter_idx
11681130
op_index_add, _ = self._get_operator(
11691131
name="executorch_prim::add",
@@ -1191,7 +1153,6 @@ def _emit_scan(
11911153
)
11921154
self.chain.instructions.append(kernel)
11931155

1194-
# Jump back to loop start if not done
11951156
jf_beginning_loop = Instruction(
11961157
JumpFalseCall(
11971158
cond_value_index=jump_bool_value.id,
@@ -1200,9 +1161,7 @@ def _emit_scan(
12001161
)
12011162
self.chain.instructions.append(jf_beginning_loop)
12021163

1203-
# === CLEANUP ===
1204-
1205-
# Reset iter_idx for potential re-runs of the model
1164+
# Reset iter_idx for potential re-runs
12061165
op_index_sub, _ = self._get_operator(
12071166
name="executorch_prim::sub",
12081167
overload="Scalar",

exir/memory_planning.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,11 +1024,6 @@ def get_map_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
10241024

10251025

10261026
def get_scan_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
1027-
"""Get all scan nodes in the graph module.
1028-
1029-
Scan nodes have the signature: scan(combine_fn, init, xs, additional_inputs)
1030-
where combine_fn is a submodule at args[0].
1031-
"""
10321027
for nd in graph_module.graph.nodes:
10331028
if nd.target is torch.ops.higher_order.scan:
10341029
yield nd
@@ -1172,12 +1167,6 @@ def _handle(
11721167
for map_node in get_map_nodes(graph_module):
11731168
_handle(cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True)
11741169

1175-
# Handle scan nodes
1176-
# Scan signature: scan(combine_fn, init, xs, additional_inputs)
1177-
# combine_fn is at args[0]
1178-
# Like map, scan needs alloc_graph_input=True because the runtime slices
1179-
# xs tensors during each iteration, requiring allocated input buffers.
1180-
# Additionally, scan has carry state that flows between iterations.
11811170
for scan_node in get_scan_nodes(graph_module):
11821171
_handle(cast(torch.fx.Node, scan_node.args[0]), alloc_graph_input=True)
11831172

exir/pass_base.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -558,29 +558,10 @@ def call_scan(
558558
additional_inputs: List[ProxyValue],
559559
meta: NodeMetadata,
560560
) -> ProxyValue:
561-
"""
562-
Process a scan higher-order operation.
563-
564-
Scan applies combine_fn iteratively, carrying state across iterations:
565-
combine_fn(carry, x_slice) -> (next_carry, y_slice)
566-
567-
Args:
568-
combine_fn: GraphModule implementing the scan body
569-
init: Initial carry state values
570-
xs: Input tensors to scan over (along dim 0)
571-
additional_inputs: Additional arguments passed to combine_fn
572-
meta: Node metadata
573-
574-
Returns:
575-
ProxyValue containing (final_carry, stacked_outputs)
576-
"""
577-
# Get the first slice of xs to determine input shapes for combine_fn
578-
# combine_fn inputs: (*init, *xs_slice, *additional_inputs)
579561
xs_first_slice = _unstack_pytree([arg.data for arg in xs])[0]
580562
init_data = [arg.data for arg in init]
581563
additional_data = [arg.data for arg in additional_inputs]
582564

583-
# Call submodule with representative inputs
584565
combine_fn_result = self.call_submodule(
585566
combine_fn, tuple(init_data + xs_first_slice + additional_data)
586567
)

exir/passes/spec_prop_pass.py

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -150,61 +150,23 @@ def call_scan(
150150
additional_inputs: List[ProxyValue],
151151
meta: NodeMetadata,
152152
) -> ProxyValue:
153-
"""
154-
Propagate specs for scan higher-order operation.
155-
156-
Scan returns (final_carry, stacked_outputs) where:
157-
- final_carry: Same shape as init (NOT stacked, just the final carry state)
158-
- stacked_outputs: Outputs stacked along dim 0 with scan_length
159-
160-
The combine_fn signature is:
161-
combine_fn(*init, *xs_slice, *additional_inputs) -> (*next_carry, *y_slice)
162-
163-
So the combine_fn outputs are split into:
164-
- First len(init) outputs: carry values (same shape as init)
165-
- Remaining outputs: y values (to be stacked)
166-
167-
Memory Layout Note:
168-
The specs created here are for the FINAL outputs of the scan operation:
169-
- carry specs: Working carry buffers that persist across iterations.
170-
These are SEPARATE from combine_fn's output buffers. The emitter
171-
must copy from combine_fn's temporary carry output to these buffers
172-
after each iteration (in-place op.out(x, out=x) is unsafe).
173-
- y specs: Pre-allocated stacked buffers filled via et_copy_index.
174-
175-
The combine_fn's internal temporary buffers are allocated separately
176-
via memory planning with alloc_graph_input=True, alloc_graph_output=True.
177-
"""
178-
# Get scan length from first xs tensor
179153
scan_length = [arg.data for arg in xs][0].size(0)
180154

181-
# Get the output node from combine_fn
182155
*_, body_out_node = combine_fn.graph.nodes
183156
body_out_fake = body_out_node.meta["val"]
184157

185-
# The combine_fn outputs are: (*next_carry, *y_slice)
186-
# Split them based on the number of init values
187158
num_carry = len(init)
188-
189-
# Flatten the outputs to handle them uniformly
190159
flat_body_out, out_spec = pytree.tree_flatten(body_out_fake)
191160

192-
# Split into carry outputs and y outputs
193161
carry_out = flat_body_out[:num_carry]
194162
y_out = flat_body_out[num_carry:]
195163

196-
# Create specs:
197-
# - Carry: same shape as combine_fn output (NOT stacked)
198-
# These are working buffers that get updated each iteration
199-
# - Y: stacked along dim 0 with scan_length
200-
carry_fake = carry_out # Carry keeps same shape
201-
164+
carry_fake = carry_out
202165
y_fake = [
203166
x.new_empty(scan_length, *x.shape) if isinstance(x, torch.Tensor) else x
204167
for x in y_out
205168
]
206169

207-
# Combine carry and stacked y outputs
208170
combined_fake = carry_fake + y_fake
209171

210172
meta["spec"] = pytree.tree_map(make_spec, combined_fake)

exir/tests/control_flow_models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@ def forward(
158158
self, xs: torch.Tensor, scale: torch.Tensor
159159
) -> tuple[torch.Tensor, torch.Tensor]:
160160
def combine_fn(carry, x):
161-
# Scale is captured from outer scope
162161
new_carry = carry + x * scale
163162
return new_carry, new_carry.clone()
164163

0 commit comments

Comments
 (0)