[L2Norm] Avoid recompilation for variable-length inputs #669
Conversation
Walkthrough

Changed two Triton kernel signatures in `fla/modules/l2norm.py`.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Summary of Changes

Hello @retonym, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the efficiency of L2 normalization by allowing its Triton kernels to handle variable-length inputs without requiring recompilation. This change is particularly beneficial for training methodologies like Supervised Fine-Tuning (SFT), where input sequence lengths can vary. Additionally, it cleans up the codebase by removing an unused kernel argument.
Code Review
This pull request effectively addresses the performance issue of kernel recompilation for variable-length inputs in L2Norm. By changing the sequence length `T` from a `tl.constexpr` to a regular argument in the Triton kernels, the compiled kernels can be reused across different sequence lengths, which is a significant improvement for scenarios like supervised fine-tuning. The removal of the unused `NB` argument and its associated logic is a good code cleanup. The changes are well implemented and correctly isolated to the relevant functions.
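For readers skimming the diff, a minimal sketch of the pattern being described, with illustrative names and a simplified body rather than the actual `fla/modules/l2norm.py` kernel:

```python
import triton
import triton.language as tl


# Before (sketch): `T: tl.constexpr` meant every new sequence length produced a
# new specialization of the kernel.
# After (sketch below): `T` is a plain runtime argument, excluded from Triton's
# integer-value specialization, so one compiled kernel serves all lengths.
@triton.jit(do_not_specialize=['T'])
def l2norm_fwd_sketch(
    x,                   # [T, D] input
    y,                   # [T, D] output
    T,                   # runtime sequence length (no longer tl.constexpr)
    eps,                 # small constant for numerical stability
    D: tl.constexpr,     # feature dimension, compile-time
    BD: tl.constexpr,    # block size over D, compile-time
    BT: tl.constexpr,    # block size over T, compile-time
):
    i_t = tl.program_id(0)
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    # Row-wise L2 normalization of the (BT, BD) tile.
    b_y = b_x / tl.sqrt(tl.sum(b_x * b_x, axis=1)[:, None] + eps)
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
```

The key point is that `T` participates only in block-pointer shapes and offsets, so it can stay a runtime scalar while `D`, `BD`, and `BT` remain compile-time constants.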
LGTM. But I prefer to use varlen and autotune for better performance; we found a 50% performance gap between different tune configurations.
gogongxt left a comment
The `BD` argument also depends on the input element size, so it may still trigger recompilation.
Maybe we can remove the recompilation but keep autotune.
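A minimal sketch of that combination: keep `@triton.autotune` keyed on shape parameters other than `T`, while also excluding `T` from specialization. The configs and the kernel below are illustrative, not fla's actual ones:

```python
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BT': 16}, num_warps=4),
        triton.Config({'BT': 32}, num_warps=8),
        triton.Config({'BT': 64}, num_warps=8),
    ],
    key=['D'],  # re-tune when D changes; T deliberately stays out of the key
)
@triton.jit(do_not_specialize=['T'])
def tile_copy_sketch(x, y, T, D: tl.constexpr, BD: tl.constexpr, BT: tl.constexpr):
    # Trivial body (copies one (BT, BD) tile); the point is the decorator
    # stack: autotuning is kept while T is neither a constexpr nor part of the
    # autotune key, so varying sequence lengths reuse both the tuning result
    # and the compiled binaries.
    i_t = tl.program_id(0)
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    tl.store(p_y, tl.load(p_x, boundary_check=(0, 1)), boundary_check=(0, 1))
```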
Hi @gogongxt
Hi @zhiyuan1i, restore
In the infra scenario, the input element size will vary:

```python
MAX_FUSED_SIZE = 65536 // x.element_size()
BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
```

Therefore the `BD` will still change and trigger recompilation.
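To make the concern concrete, here is a small check using the formula quoted above (`block_d` is just a hypothetical helper name), showing that `BD` shifts with the element size and hence with the dtype:

```python
import torch
import triton


def block_d(D: int, dtype: torch.dtype) -> int:
    # Mirrors the host-side computation quoted above: BD is capped by the
    # number of elements that fit in 64 KiB for the given dtype.
    MAX_FUSED_SIZE = 65536 // torch.tensor([], dtype=dtype).element_size()
    return min(MAX_FUSED_SIZE, triton.next_power_of_2(D))


D = 32768
print(block_d(D, torch.float32))   # 16384 -> capped by MAX_FUSED_SIZE (4-byte elements)
print(block_d(D, torch.bfloat16))  # 32768 -> fits for 2-byte elements
# BD is a tl.constexpr, so switching dtypes (fp32 vs. bf16) still yields a new
# specialization even though changing T alone no longer does.
```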
Actionable comments posted: 0
🧹 Nitpick comments (1)
fla/modules/l2norm.py (1)
120-132: Bwd kernel change is consistent; consider dropping unused `eps` from the kernel

The same treatment of `T` in `l2norm_bwd_kernel` (runtime arg + `do_not_specialize=['T']`) is consistent with the fwd kernel and should likewise avoid recompilation for changing `T`. Usage of `T` is limited to `tl.make_block_ptr` shapes/offsets, so removing `tl.constexpr` is safe.

Static analysis correctly points out that `eps` in `l2norm_bwd_kernel` is unused; the backward math only depends on `rstd`, which already encodes the forward `eps`. If you want to clean this up, you could remove `eps` from the Triton kernel signature and its call site while keeping the Python-level `eps` parameter:

```diff
-@triton.jit(do_not_specialize=['T'])
-def l2norm_bwd_kernel(
-    y,
-    rstd,
-    dy,
-    dx,
-    eps,
-    T,
+@triton.jit(do_not_specialize=['T'])
+def l2norm_bwd_kernel(
+    y,
+    rstd,
+    dy,
+    dx,
+    T,
     D: tl.constexpr,
     BD: tl.constexpr,
     NB: tl.constexpr,
     BT: tl.constexpr,
 ):
@@
     if D <= 512:
         NB = triton.cdiv(T, 2048)

     def grid(meta):
         return (triton.cdiv(T, meta['BT']), )
     l2norm_bwd_kernel[grid](
         y=y,
         rstd=rstd,
         dy=dy,
         dx=dx,
-        eps=eps,
         T=T,
         D=D,
         BD=BD,
         NB=NB,
     )
```

If you prefer to keep `eps` for API symmetry, adding a short comment explaining that it is intentionally unused would also address the lint.

Please verify that there are no other call sites to `l2norm_bwd_kernel` before applying the signature change, and that your tests still pass after dropping `eps` from the kernel.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
fla/modules/l2norm.py (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: KevlarKanou
Repo: fla-org/flash-linear-attention PR: 544
File: fla/utils.py:0-0
Timestamp: 2025-07-24T11:34:30.618Z
Learning: In `fla/utils.py`, the `TRITON_F32_DEFAULT` environment variable should be forcibly overridden to `'ieee'` for NVIDIA GPUs that don't support tf32 (device capability < 8), rather than using `setdefault()`. This override is required for compatibility with Turing GPUs and should not respect existing user settings.
🪛 Ruff (0.14.6)
fla/modules/l2norm.py
126-126: Unused function argument: eps
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (1)
fla/modules/l2norm.py (1)
86-97: `T` as runtime arg + `do_not_specialize` in fwd kernel matches the varlen goal

Making `T` a runtime argument and adding `@triton.jit(do_not_specialize=['T'])` aligns with the intent to avoid recompilation when only the sequence length changes. Within `l2norm_fwd_kernel`, `T` is only used in `tl.make_block_ptr` shapes and offsets, which do not require `tl.constexpr`, and the host-side launch still passes `T` consistently.

This looks functionally correct and should reduce JIT cache pressure for variable-length batches while keeping autotune keyed on `D`/`NB` as before.

Please confirm on your target Triton version that `do_not_specialize=['T']` indeed prevents additional specializations for different `T` values (e.g., by checking the number of compiled variants or JIT logs).
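One low-tech way to run such a check is sketched below, assuming `l2norm` is the functional entry point exported by `fla.modules.l2norm` (adjust the import if the actual name differs): time the first call at each new sequence length, since a T-specialized kernel would pay a visible JIT cost on every new `T`.

```python
import time

import torch
from fla.modules.l2norm import l2norm  # assumed entry point; adjust if needed

D = 2048
# Warm-up call: pays the one-time compilation (and autotuning) cost.
l2norm(torch.randn(1024, D, device='cuda', dtype=torch.bfloat16))
torch.cuda.synchronize()

for T in (777, 1234, 4096, 9999):  # deliberately varied, non-power-of-two lengths
    x = torch.randn(T, D, device='cuda', dtype=torch.bfloat16)
    t0 = time.perf_counter()
    l2norm(x)
    torch.cuda.synchronize()
    # With do_not_specialize=['T'], these first-time-per-T calls should stay
    # fast rather than showing a JIT-compilation spike for each new length.
    print(f"T={T}: {(time.perf_counter() - t0) * 1e3:.2f} ms")
```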
Thank you.
The `T` argument was previously marked as `tl.constexpr`, forcing recompilation whenever the sequence length changed. This PR fixes this issue to improve performance in variable-length training (e.g., SFT).

It also removes the `NB` argument, which was unused in the kernel.