
Conversation

@retonym
Contributor

@retonym retonym commented Nov 28, 2025

The T argument was previously marked as tl.constexpr, forcing a recompilation whenever the sequence length changed. This PR makes T a regular runtime argument so the compiled kernels can be reused across sequence lengths, improving performance in variable-length training (e.g., SFT).
It also removes the NB argument, which was unused in the kernel.
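For illustration, here is a minimal sketch of the pattern this PR adopts (simplified and assumed; the kernel body, the names l2norm_fwd_sketch and l2norm_fwd_sketch_launch, and the epsilon are illustrative and not the actual code in fla/modules/l2norm.py): T is passed as a plain runtime argument and excluded from specialization via do_not_specialize, while D and BD stay compile-time for the block-pointer shapes.

    # A minimal, assumed sketch of the pattern; not the actual kernel in
    # fla/modules/l2norm.py. T is a plain runtime argument excluded from
    # specialization, so a change in sequence length reuses the compiled kernel.
    import torch
    import triton
    import triton.language as tl

    @triton.jit(do_not_specialize=['T'])
    def l2norm_fwd_sketch(
        x,
        y,
        T,                # runtime sequence length, no tl.constexpr
        D: tl.constexpr,  # feature dimension
        BD: tl.constexpr, # block size along D
    ):
        i_t = tl.program_id(0)
        p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t, 0), (1, BD), (1, 0))
        b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
        b_y = b_x / tl.sqrt(tl.sum(b_x * b_x) + 1e-6)
        p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t, 0), (1, BD), (1, 0))
        tl.store(p_y, b_y.to(y.dtype.element_ty), boundary_check=(0, 1))

    def l2norm_fwd_sketch_launch(x: torch.Tensor) -> torch.Tensor:
        T, D = x.shape
        y = torch.empty_like(x)
        # One program per row; different T values hit the same compiled kernel.
        l2norm_fwd_sketch[(T,)](x, y, T, D=D, BD=triton.next_power_of_2(D))
        return y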

Summary by CodeRabbit

  • Refactor
    • Adjusted L2-normalization kernel parameter handling so a previously compile-time parameter (the sequence length T) is now treated as a runtime value.
    • Prevented compile-time specialization on that parameter to simplify kernel invocation paths.
    • No changes to normalization results or numerical behavior; the refactor only avoids unnecessary recompilation when the sequence length varies.


@coderabbitai
Contributor

coderabbitai bot commented Nov 28, 2025

Walkthrough

Changed two Triton kernel signatures in fla/modules/l2norm.py: parameter T is no longer a compile-time tl.constexpr but a runtime argument, and do_not_specialize=['T'] was added to both @triton.jit decorators. Call sites continue passing T; no other control-flow or NB/autotune changes present.

Changes

Cohort / File(s) Summary
L2Norm kernel signature change
fla/modules/l2norm.py
Removed tl.constexpr from parameter T in l2norm_fwd_kernel and l2norm_bwd_kernel (now runtime parameters) and added do_not_specialize=['T'] to both @triton.jit decorators. No other parameter or NB/autotune logic changes.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

  • Check fla/modules/l2norm.py for only the described signature/decorator edits.
  • Verify kernel invocation sites compile and run with T as a runtime parameter.


Poem

🐰 I hopped through kernels, small and neat,
Unbound a T from compile-time seat.
Now runtime breezes let it roam,
Same math, new freedom, nearer home. 🥕

Pre-merge checks

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title clearly describes the main change: treating T as a runtime parameter to avoid kernel recompilation for variable-length sequence inputs.

@gemini-code-assist
Contributor

Summary of Changes

Hello @retonym, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the efficiency of L2 normalization by allowing its Triton kernels to handle variable-length inputs without requiring recompilation. This change is particularly beneficial for training methodologies like Supervised Fine-Tuning (SFT) where input sequence lengths can vary. Additionally, it cleans up the codebase by removing an unused kernel argument.

Highlights

  • Variable-Length Input Support: The 'T' argument in L2 norm kernels is no longer 'tl.constexpr', preventing recompilation for different sequence lengths and improving performance for variable-length inputs.
  • Unused Argument Removal: The 'NB' argument, which was not utilized in the kernel, has been removed from the kernel definitions and their respective calls, simplifying the code.
  • Autotune Key Update: The 'key' parameter in the 'autotune' decorators for 'l2norm_bwd_kernel1' and 'l2norm_fwd_kernel' has been adjusted to reflect the removal of the 'NB' argument.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request effectively addresses the performance issue of kernel recompilation for variable-length inputs in L2Norm. By changing the sequence length T from a tl.constexpr to a regular argument in the Triton kernels, the compiled kernels can be reused across different sequence lengths, which is a significant improvement for scenarios like supervised fine-tuning. The removal of the unused NB argument and its associated logic is a good code cleanup. The changes are well-implemented and correctly isolated to the relevant functions.

@zhiyuan1i
Collaborator

zhiyuan1i commented Nov 28, 2025

LGTM. But I would prefer to use varlen and autotune for better performance; we found a 50% performance gap between different tuning configurations.


@gogongxt gogongxt left a comment


The BD argument also depends on the input element size, so it may still trigger recompilation.

@zhiyuan1i
Collaborator

Maybe we can remove the recompile but keep autotune

@retonym
Contributor Author

retonym commented Dec 1, 2025

The BD argument also depends on the input element size, so it may still trigger recompilation.

Hi @gogongxt,
As I understand it, there are very few scenarios where a model's head dimension varies within the same training session, so we probably don't need to worry about that.

@retonym
Contributor Author

retonym commented Dec 1, 2025

Maybe we can remove the recompile but keep autotune

Hi @zhiyuan1i, I've restored NB for autotuning and recompilation and only removed the tl.constexpr on T.
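For illustration, the resulting decorator stack could look roughly like this (a hedged sketch; the config values, the BT/num_warps choices, and the placeholder body are assumptions, not the actual code in fla/modules/l2norm.py): autotune keys only on compile-time shape parameters such as D and NB, while T stays unspecialized.

    # Hedged sketch of "keep autotune, avoid recompiling on T": the autotune key
    # contains only compile-time shape parameters, and T is excluded from
    # specialization. Config values and names are assumed.
    import triton
    import triton.language as tl

    @triton.autotune(
        configs=[
            triton.Config({'BT': BT}, num_warps=num_warps)
            for BT in [16, 32, 64]
            for num_warps in [4, 8]
        ],
        key=['D', 'NB'],  # re-tune when these change, not when T does
    )
    @triton.jit(do_not_specialize=['T'])
    def l2norm_kernel_sketch(
        x, y, T,
        D: tl.constexpr, BD: tl.constexpr, NB: tl.constexpr, BT: tl.constexpr,
    ):
        # Placeholder body: copies one (BT, BD) tile; the real kernels normalize rows.
        i_t = tl.program_id(0)
        p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
        p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
        tl.store(p_y, tl.load(p_x, boundary_check=(0, 1)), boundary_check=(0, 1))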

@gogongxt

gogongxt commented Dec 1, 2025

The BD argument also depends on the input element size, so it may still trigger recompilation.

Hi @gogongxt, as I understand it, there are very few scenarios where a model's head dimension varies within the same training session, so we probably don't need to worry about that.

In infra scenarios, the input element size will vary:

    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))

Therefore, BD will change dynamically.
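As a concrete, hypothetical illustration of when this matters (block_dim below is just a helper wrapping the computation above, and the D value is contrived so the two dtypes diverge): BD only differs across element sizes once triton.next_power_of_2(D) exceeds 65536 // element_size.

    # Hypothetical illustration: BD, a tl.constexpr, depends on the element size
    # once next_power_of_2(D) exceeds 65536 // element_size, so a dtype change
    # can still lead to a different specialization.
    import torch
    import triton

    def block_dim(x: torch.Tensor, D: int) -> int:
        MAX_FUSED_SIZE = 65536 // x.element_size()
        return min(MAX_FUSED_SIZE, triton.next_power_of_2(D))

    D = 20480  # contrived large feature dimension
    print(block_dim(torch.empty(1, dtype=torch.float32), D))   # 16384
    print(block_dim(torch.empty(1, dtype=torch.bfloat16), D))  # 32768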

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
fla/modules/l2norm.py (1)

120-132: Bwd kernel change is consistent; consider dropping unused eps from the kernel

The same treatment of T in l2norm_bwd_kernel (runtime arg + do_not_specialize=['T']) is consistent with the fwd kernel and should likewise avoid recompilation for changing T. Usage of T is limited to tl.make_block_ptr shapes/offsets, so removing tl.constexpr is safe.

Static analysis correctly points out that eps in l2norm_bwd_kernel is unused; the backward math only depends on rstd, which already encodes the forward eps. If you want to clean this up, you could remove eps from the Triton kernel signature and its call site while keeping the Python-level eps parameter:

-@triton.jit(do_not_specialize=['T'])
-def l2norm_bwd_kernel(
-    y,
-    rstd,
-    dy,
-    dx,
-    eps,
-    T,
+@triton.jit(do_not_specialize=['T'])
+def l2norm_bwd_kernel(
+    y,
+    rstd,
+    dy,
+    dx,
+    T,
     D: tl.constexpr,
     BD: tl.constexpr,
     NB: tl.constexpr,
     BT: tl.constexpr,
 ):
@@
     if D <= 512:
         NB = triton.cdiv(T, 2048)
         def grid(meta): return (triton.cdiv(T, meta['BT']), )
         l2norm_bwd_kernel[grid](
             y=y,
             rstd=rstd,
             dy=dy,
             dx=dx,
-            eps=eps,
             T=T,
             D=D,
             BD=BD,
             NB=NB,
         )

If you prefer to keep eps for API symmetry, adding a short comment explaining that it is intentionally unused would also address the lint.

Please verify that there are no other call sites to l2norm_bwd_kernel before applying the signature change, and that your tests still pass after dropping eps from the kernel.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1165df2 and 0b94179.

📒 Files selected for processing (1)
  • fla/modules/l2norm.py (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: KevlarKanou
Repo: fla-org/flash-linear-attention PR: 544
File: fla/utils.py:0-0
Timestamp: 2025-07-24T11:34:30.618Z
Learning: In `fla/utils.py`, the `TRITON_F32_DEFAULT` environment variable should be forcibly overridden to `'ieee'` for NVIDIA GPUs that don't support tf32 (device capability < 8), rather than using `setdefault()`. This override is required for compatibility with Turing GPUs and should not respect existing user settings.
🪛 Ruff (0.14.6)
fla/modules/l2norm.py

126-126: Unused function argument: eps

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (1)
fla/modules/l2norm.py (1)

86-97: T as runtime + do_not_specialize in fwd kernel matches the varlen goal

Making T a runtime argument and adding @triton.jit(do_not_specialize=['T']) aligns with the intent to avoid recompilation when only the sequence length changes. Within l2norm_fwd_kernel, T is only used in tl.make_block_ptr shapes and offsets, which do not require tl.constexpr, and the host-side launch still passes T consistently.

This looks functionally correct and should reduce JIT cache pressure for variable-length batches while keeping autotune keyed on D/NB as before.

Please confirm on your target Triton version that do_not_specialize=['T'] indeed prevents additional specializations for different T values (e.g., by checking the number of compiled variants or JIT logs).
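One rough, hedged way to check this is to simply time calls with several distinct sequence lengths after a warm-up, rather than inspecting Triton's internal JIT cache; recompilation would show up as occasional calls that are orders of magnitude slower than the rest. The l2norm_fwd import and its call signature below are assumptions about the module's host-level API; adjust to the actual entry point.

    # Hedged sketch: time l2norm calls for several distinct sequence lengths.
    # The l2norm_fwd name and signature are assumed; adapt to the actual API.
    import time
    import torch
    from fla.modules.l2norm import l2norm_fwd  # assumed entry point

    x = torch.randn(512, 2048, device='cuda', dtype=torch.bfloat16)
    l2norm_fwd(x)                      # warm-up: first compilation happens here
    torch.cuda.synchronize()

    for T in (777, 1234, 4096, 9999):  # distinct sequence lengths
        x = torch.randn(T, 2048, device='cuda', dtype=torch.bfloat16)
        t0 = time.perf_counter()
        l2norm_fwd(x)
        torch.cuda.synchronize()
        print(T, f"{(time.perf_counter() - t0) * 1e3:.2f} ms")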

@yzhangcs yzhangcs merged commit f10e9b1 into fla-org:main Dec 1, 2025
4 checks passed
@yzhangcs
Member

yzhangcs commented Dec 1, 2025

Thank you
