
Conversation

@JacobSzwejbka
Contributor

Add support for the higher-order op scan. It's inefficient today because we manually deep copy from output to input for every carry. We could do better by shallow swapping the pointers, but I'll do that in a follow-up if needed.

Test plan: Unit tests and internal verification against harder patterns
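
For readers unfamiliar with the op, here is a minimal pure-Python sketch of scan's semantics (the `combine_fn` is illustrative; this is not the emitter code). The `copy_` from the carry output back into the carry input is the per-iteration deep copy the description calls out as the current inefficiency:

```python
import torch

def combine_fn(carry, x):
    # Illustrative combine function: returns (next_carry, per_step_output).
    return carry + x, carry * x

def scan_reference(init, xs):
    carry = init.clone()
    ys = []
    for i in range(xs.shape[0]):
        new_carry, y = combine_fn(carry, xs[i])
        carry.copy_(new_carry)  # deep copy; a follow-up could swap buffer pointers instead
        ys.append(y)
    return carry, torch.stack(ys)

final_carry, ys = scan_reference(torch.zeros(3), torch.randn(5, 3))
```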


pytorch-bot bot commented Dec 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16028

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit 90e55dd with merge base 9eaea4a:

NEW FAILURE - The following job has failed:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Dec 1, 2025

github-actions bot commented Dec 1, 2025

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


meta-codesync bot commented Dec 1, 2025

@JacobSzwejbka has imported this pull request. If you are a Meta employee, you can view this in D88107948.

JacobSzwejbka changed the title from [WIP] Scan support to Scan support on Dec 9, 2025
op_table = program.execution_plan[0].operators
instructions = program.execution_plan[0].chains[0].instructions

# Collect all operator names in the program
Contributor

honestly all the ops seem like implementation details and should not be tested

Contributor Author

I was using it as a sort of proxy that the general pattern was emitted. If you want, we can just test the end-to-end behavior though.

Contributor

Don't have a strong opinion, but you might have to maintain this test if there's a change to the exported graph in the future

Contributor Author

The ops we are querying over are the ones *not* in the original model definition, but instead created by the emitter to maintain the semantics of scan.
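
For illustration, the kind of assertion under discussion might look like the sketch below, reusing `program` from the snippet above; the `et_copy_index` name comes from the next thread, and matching on it this way is an assumption, not the PR's actual test:

```python
# Hypothetical sketch of the test's intent: check that the emitter's scan
# bookkeeping ops (absent from the original model) show up in the op table.
op_names = {op.name for op in program.execution_plan[0].operators}
assert any("et_copy_index" in name for name in op_names)
```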

Comment on lines +978 to +980
2. et_copy_index(y_outputs, combine_fn's y output, iter_idx)
This explicit copy approach is used because in-place op.out(x, out=x) is unsafe.
Contributor

I was under the impression that this might be fine. We basically emit scan at the very end of the lowering process and I'm not convinced we still require the graph to be functional.

Contributor Author

No, the problem isn't being functional; it's that ATen (and ET ops) are not guaranteed to work when in and out alias the same memory.

You could very easily write before you read over sections of the tensor.
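
A toy illustration of that hazard (a hypothetical shift op, not an ET kernel): when `out` aliases `src`, elements are overwritten before they are read.

```python
import torch

def shift_right_naive(src, out):
    # out[i] = src[i - 1], written in ascending order; only correct when
    # src and out occupy distinct memory.
    out[0] = 0.0
    for i in range(1, out.numel()):
        out[i] = src[i - 1]  # may read an element that was already overwritten

x = torch.arange(1.0, 5.0)   # [1., 2., 3., 4.]
y = torch.empty(4)
shift_right_naive(x, y)      # distinct buffers -> y == [0., 1., 2., 3.]
shift_right_naive(x, x)      # aliased in/out   -> x == [0., 0., 0., 0.]
```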

meta,
)

def call_scan(
Contributor Author

@angelayi can you check that I'm not doing anything stupid here

# Use the placeholder's val which has the correct shape
xs_element_data.append(ph.meta["val"])

combine_fn_result = self.call_submodule(
Contributor Author

I mostly copied torch.cond here with running call_submodule. Is this just so the subgraph also gets a chance to be run over by spec prop before calling the original? It just seems weird that I'm calling scan on this subgraph instead of the original one passed in as an arg.

JacobSzwejbka merged commit fae5d1b into main on Dec 11, 2025
164 of 166 checks passed
JacobSzwejbka deleted the scan_support branch on December 11, 2025 at 17:41
for i in range(0, len(xs)):
ph = combine_fn_placeholders[num_init + i]
# Use the placeholder's val which has the correct shape
xs_element_data.append(ph.meta["val"])
Contributor

I think this part is a little sus where you look at the subgraph's placeholder nodes. I think the xs_element_data should just be something like xs[0]?
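
A sketch of that suggestion (illustrative only; it assumes `xs` holds the stacked scan inputs, or their fake values, with the scan dimension leading):

```python
# Derive each per-iteration element by slicing the leading (scan) dimension
# off the input itself, rather than reading the subgraph's placeholder metadata.
xs_element_data = [x[0] for x in xs]
```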
