
Conversation

@zhiyuan1i (Collaborator) commented Nov 29, 2025

Summary by CodeRabbit

Release Notes

  • New Features

    • Added chunked cumulative sum functionality for KDA gate operations with improved support for variable-length sequences and optional bias/scaling.
  • Improvements

    • Enhanced KDA gate forward and backward paths with optimized chunk processing and fallback logic.
    • Improved numerical consistency in gradient computation for gate-enabled operations.
  • Tests

    • Expanded test coverage for kernel-based gating pathways and added validation for chunk-cumsum equivalence.


@coderabbitai bot (Contributor) commented Nov 29, 2025

Walkthrough

This PR introduces a new Triton-based chunk-cumulative-sum kernel (kda_gate_chunk_cumsum) for KDA gating in gate.py, updates chunk.py to use this new kernel instead of the separate kda_gate_fwd and chunk_local_cumsum sequence, and expands test coverage in test_kda.py to validate gate-in-kernel pathways and cumsum equivalence.

Changes

Cohort / File(s) — Summary

KDA Gate Implementation — fla/ops/kda/gate.py
Adds a new Triton kernel, kda_gate_chunk_cumsum_vector_kernel, with autotune support for variable-length sequences, optional bias, scaling, and a HEAD_FIRST layout. Introduces the host wrapper kda_gate_chunk_cumsum, which prepares grid parameters, validates that the chunk size is a power of two, and invokes the vector kernel (see the usage sketch after this table).

KDA Chunk Forward Path — fla/ops/kda/chunk.py
Replaces the gate forward path with kda_gate_chunk_cumsum in both the gating-enabled and standard paths; propagates the chunk_size parameter (set to 64) and passes cu_seqlens and chunk_indices. Standardizes chunk_size handling via a local variable; keeps kda_gate_fwd in the backward path.

Test Coverage Expansion — tests/ops/test_kda.py
Adds imports for kda_gate_chunk_cumsum and chunk_local_cumsum; introduces a new test, test_gate_cumsum, validating equivalence between the fused gate-plus-chunked-cumsum and chunked KDA with local cumsum. Extends existing tests with explicit float32 dtype handling for gate tensors, gate-in-kernel parameter propagation, and conditional backward assertions.
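
For orientation, a minimal sketch of how the new host wrapper might be called, based on the signature quoted in the review comments below; the tensor shapes and the per-head A_log layout are illustrative assumptions, not taken from the PR itself:

import torch
from fla.ops.kda.gate import kda_gate_chunk_cumsum

B, T, H, S = 2, 128, 4, 64                    # hypothetical sizes
g = torch.randn(B, T, H, S, device='cuda')    # raw gate input (assumed [B, T, H, S] layout)
A_log = torch.randn(H, device='cuda')         # assumed per-head decay parameter

# One launch fuses the gate transform with the chunk-local cumulative sum.
g_cumsum = kda_gate_chunk_cumsum(
    g, A_log,
    chunk_size=64,             # must be a power of two (asserted by the wrapper)
    cu_seqlens=None,           # or cumulative sequence lengths for varlen batches
    output_dtype=torch.float,  # default per the signature quoted below
)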

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Triton kernel implementation (kda_gate_chunk_cumsum_vector_kernel): Dense kernel logic with autotune configurations, pointer arithmetic, and memory layout handling requires careful verification
  • Parameter propagation changes: Chunk size and indices threading through multiple code paths in chunk.py needs validation for correctness
  • Test refactoring scope: Comprehensive changes across multiple test functions with dtype casting, parameter renaming (use_tma → use_gate_in_kernel), and new backward assertions increase the review surface
  • Cross-file dependencies: Changes span gate implementation, chunking logic, and tests with interdependencies that need cohesive validation


Possibly related PRs

  • [KDA] Support fused forget gate #662: Modifies KDA gating and chunked KDA implementation; updates gate kernels and chunk callsites directly related to this PR's gate path changes.
  • Add KDA #621: Builds on KDA gate and chunk infrastructure; extends same files (fla/ops/kda/chunk.py and gate functions) introduced in the foundational KDA implementation.
  • [KDA]: Fuse beta.float().sigmoid() in fused_kda_gate #642: Updates KDA gate function signatures and return values in fla/ops/kda/gate.py; interacts with the same gating functions being refactored in this PR.

Poem

🐰 A kernel bloomed in Triton's garden fair,
Cumsum chunks gathered with utmost care,
Gate and gather now dance as one,
Fused and ready—let kernels run!
✨ Optimization hops through the test queue,
Flash-linear-attention shines anew! 🚀

Pre-merge checks

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 11.11%, which is below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.

✅ Passed checks (2 passed)

  • Description Check ✅ Passed — Skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — The title accurately reflects the main change: introducing fused KDA gate and cumsum functionality by combining gate computations with chunked cumulative sum operations across multiple files.


@gemini-code-assist bot (Contributor) commented

Summary of Changes

Hello @zhiyuan1i, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly optimizes the KDA (Kimi Delta Attention) implementation by fusing the KDA gate and chunk-local cumulative sum operations into a single, efficient Triton kernel. This fusion reduces the overhead of multiple kernel launches and memory transfers, improving performance for KDA-based models, especially when gating is applied within the kernel. The changes are tested to maintain numerical accuracy and functionality across various configurations.

Highlights

  • Fused KDA Gate and Cumulative Sum: Introduced a new Triton kernel kda_gate_chunk_cumsum that combines the KDA gating mechanism with a chunk-local cumulative sum operation, improving efficiency.
  • Integration into Chunk KDA: The chunk.py module now utilizes this fused operation when the use_gate_in_kernel flag is enabled, replacing a two-step process with a single optimized kernel call.
  • Comprehensive Testing: New and updated test cases ensure the correctness and numerical stability of the fused operation, including variable-length sequences and different gate configurations.

@zhiyuan1i linked an issue Nov 29, 2025 that may be closed by this pull request

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a fused kernel for the KDA gate and chunked cumulative sum, which is a great optimization. The forward pass implementation and test updates look solid. However, I've identified a critical issue in the backward pass where gradient propagation through the chunk_local_cumsum operation appears to be missing. This could lead to incorrect gradients when use_gate_in_kernel=True. I've also included a couple of minor suggestions to improve code clarity and test hygiene.

if chunk_indices is None and cu_seqlens is not None:
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"

Severity: medium

This assertion to check if chunk_size is a power of two is a bit obscure and might be hard to understand for future maintainers. A more conventional and clearer way to perform this check is by using bitwise operations.

Suggested change
-assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
+assert (chunk_size > 0) and (chunk_size & (chunk_size - 1) == 0), "chunk_size must be a power of 2"
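
For future maintainers, a quick illustration of why the bitwise form works: a positive power of two has exactly one set bit, so n & (n - 1) clears it to zero.

# 64 = 0b1000000, 63 = 0b0111111 -> 64 & 63 == 0   (power of two)
# 48 = 0b0110000, 47 = 0b0101111 -> 48 & 47 == 32  (not a power of two)
def is_power_of_two(n: int) -> bool:
    return n > 0 and (n & (n - 1)) == 0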

@coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (4)
tests/ops/test_kda.py (3)

431-442: Remove duplicate test parameters.

The test parameter list contains duplicate entries that will run the same test twice unnecessarily:

  • (1, 2, 2, 12) - duplicated
  • (1, 32, 2, 16) - duplicated
  • (2, 64, 4, 32) - duplicated
  • (4, 128, 8, 64) - duplicated
  • (4, 128, 8, 128) - duplicated
     [
         pytest.param(*test, id="B{}-T{}-H{}-D{}".format(*test))
         for test in [
             (1, 2, 2, 12),
             (1, 32, 2, 16),
             (2, 64, 4, 32),
             (4, 128, 8, 64),
             (4, 128, 8, 128),
-            (1, 2, 2, 12),
-            (1, 32, 2, 16),
-            (2, 64, 4, 32),
-            (4, 128, 8, 64),
-            (4, 128, 8, 128),
         ]
     ],

456-457: Redundant condition: dt_bias is always non-None.

The condition if dt_bias is not None on line 456 is always true since dt_bias is unconditionally created on line 454. Consider either removing the condition or making dt_bias conditionally created (similar to test_gate) to test both code paths.

-    if dt_bias is not None:
-        dt_bias = dt_bias.to(device).requires_grad_(True)
+    dt_bias = dt_bias.to(device).requires_grad_(True)

Alternatively, parameterize HAS_BIAS to test both paths.


467-475: Consider adding a backward gradient test for completeness.

The test validates forward equivalence but doesn't verify backward gradient flow. While the backward path is tested indirectly through test_chunk and test_chunk_varlen, a direct gradient test here would provide more targeted coverage of the fused kernel's interaction with autograd.
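
A minimal sketch of what such a gradient check could look like, assuming the fused wrapper is differentiable end-to-end and that an autograd-capable gate function (fused_kda_gate, named in the related PRs) can serve as the reference; shapes, argument names, and tolerances are illustrative:

# Hypothetical gradient check, not from the source: compare the fused path
# against the two-step reference path under the same upstream gradient.
g_fused = g.detach().clone().requires_grad_(True)
g_ref = g.detach().clone().requires_grad_(True)

out_fused = kda_gate_chunk_cumsum(g_fused, A_log, chunk_size=64)
out_ref = chunk_local_cumsum(fused_kda_gate(g_ref, A_log), chunk_size=64)  # reference two-step path

grad = torch.randn_like(out_fused)
out_fused.backward(grad)
out_ref.backward(grad.to(out_ref.dtype))
torch.testing.assert_close(g_fused.grad, g_ref.grad, rtol=1e-3, atol=1e-3)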

fla/ops/kda/gate.py (1)

377-383: Address type hint and unused parameter.

Per static analysis hints:

  1. Line 377: scale: float = None should be scale: float | None = None per PEP 484
  2. Line 383: **kwargs is unused but may be intentional for API compatibility with chunk_local_cumsum
 def kda_gate_chunk_cumsum(
     g: torch.Tensor,
     A_log: torch.Tensor,
     chunk_size: int,
     reverse: bool = False,
-    scale: float = None,
+    scale: float | None = None,
     dt_bias: torch.Tensor | None = None,
     cu_seqlens: torch.Tensor | None = None,
     head_first: bool = False,
     output_dtype: torch.dtype | None = torch.float,
     chunk_indices: torch.LongTensor | None = None,
-    **kwargs,
+    **kwargs,  # For API compatibility with chunk_local_cumsum
 ) -> torch.Tensor:
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR, between commits 78d7a93 and 7bc5a30.

📒 Files selected for processing (3)
  • fla/ops/kda/chunk.py (2 hunks)
  • fla/ops/kda/gate.py (2 hunks)
  • tests/ops/test_kda.py (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
fla/ops/kda/chunk.py (2)
fla/ops/kda/gate.py (2)
  • kda_gate_chunk_cumsum (372-417)
  • kda_gate_fwd (165-191)
fla/ops/utils/cumsum.py (1)
  • chunk_local_cumsum (429-469)
fla/ops/kda/gate.py (2)
fla/ops/utils/index.py (1)
  • prepare_chunk_indices (114-119)
fla/utils.py (2)
  • check_shared_mem (447-453)
  • input_guard (137-168)
🪛 Ruff (0.14.6)
fla/ops/kda/gate.py

320-320: Unused function argument: B

(ARG001)


377-377: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


383-383: Unused function argument: kwargs

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
  • GitHub Check: check-pt-python-compatibility
🔇 Additional comments (11)
fla/ops/kda/chunk.py (3)

11-11: LGTM!

The import correctly adds kda_gate_chunk_cumsum alongside the existing kda_gate_bwd and kda_gate_fwd imports.


212-225: LGTM!

The forward path correctly uses the fused kda_gate_chunk_cumsum kernel when use_gate_in_kernel=True, combining gate computation and chunk cumsum into a single kernel launch. The fallback to chunk_local_cumsum when gating is disabled maintains backward compatibility.
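
To make the dispatch concrete, a sketch of the branch described above, assuming variable names that match the walkthrough (not the verbatim source):

chunk_size = 64  # local variable standardized in this PR
if use_gate_in_kernel:
    # fused path: gate transform + chunk-local cumsum in a single launch
    g = kda_gate_chunk_cumsum(g, A_log, chunk_size=chunk_size,
                              cu_seqlens=cu_seqlens, chunk_indices=chunk_indices)
else:
    # fallback: gate already materialized, so only the chunk-local cumsum runs
    g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)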


265-271: LGTM!

The backward path correctly recomputes g using the separate kda_gate_fwd followed by chunk_local_cumsum. This is appropriate because the gradient computation requires access to intermediate values that the fused forward kernel doesn't preserve.
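
A sketch of the recompute pattern this refers to, with hypothetical argument names (kda_gate_fwd's exact signature is not shown in this review):

# Backward path: rebuild the cumulative gate from saved inputs, since the
# fused forward kernel does not preserve the intermediate gate values.
g_gate = kda_gate_fwd(g, A_log, dt_bias)  # gate transform only (arguments assumed)
g_cum = chunk_local_cumsum(g_gate, chunk_size=chunk_size, cu_seqlens=cu_seqlens)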

tests/ops/test_kda.py (3)

10-12: LGTM!

The imports correctly add kda_gate_chunk_cumsum from the gate module and chunk_local_cumsum from utils to support the new test coverage.


212-221: LGTM!

The reference computation correctly casts tensors to float32 for the naive gate computation when use_gate_in_kernel is enabled, and the backward pass properly handles dtype conversion for loss computation.


290-298: Potential issue with mask application when use_gate_in_kernel=True.

When use_gate_in_kernel=True, g is not processed through logsigmoid (lines 295-296), but lines 297-298 still apply a mask that sets values to -1000. This mask application may not be the intended behavior for the raw gate input path, as the gate kernel applies its own transformation.

Please verify that applying the mask (lines 297-298) to the raw gate values when use_gate_in_kernel=True is the intended behavior, as this differs from how the mask is applied in the test_chunk function.

fla/ops/kda/gate.py (5)

8-12: LGTM!

The new imports correctly bring in prepare_chunk_indices for variable-length sequence handling and check_shared_mem for dynamic BS_LIST configuration.


296-310: LGTM!

The heuristics and autotune configuration are well-designed. Using do_not_specialize=['T'] allows the kernel to handle varying sequence lengths without triggering recompilation.
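
As a generic illustration of that decorator pattern (autotune keyed on S, with T left unspecialized), a self-contained toy kernel; this is not the PR's kernel:

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[triton.Config({'BS': bs}, num_warps=nw)
             for bs in [16, 32] for nw in [2, 4]],
    key=['S'],  # retune only when the feature dimension changes
)
@triton.jit(do_not_specialize=['T'])  # avoid recompilation for every new T
def copy_kernel(x, y, T, S, BS: tl.constexpr):
    o_s = tl.program_id(0) * BS + tl.arange(0, BS)
    mask = o_s < S
    for t in range(0, T):  # T is a runtime value, not baked into the binary
        tl.store(y + t * S + o_s, tl.load(x + t * S + o_s, mask=mask), mask=mask)

T, S = 64, 128
x = torch.randn(T, S, device='cuda')
y = torch.empty_like(x)
copy_kernel[lambda meta: (triton.cdiv(S, meta['BS']),)](x, y, T, S)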


320-320: Static analysis false positive: B is used.

The static analysis flags B as unused, but it's used in the grid calculation (triton.cdiv(meta['S'], meta['BS']), NT, B * H) at line 400. The B parameter is correctly passed and used for determining the 3rd grid dimension.


359-367: LGTM!

The cumsum implementation correctly handles both forward and reverse directions, matching the semantics of chunk_local_cumsum. The optional scaling is applied after the cumsum, which is consistent with the existing API.
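
For reference, a minimal PyTorch model of these semantics, assuming a [B, T, H, S] layout with T divisible by the chunk size (illustrative only, not the library implementation):

import torch

def chunk_local_cumsum_ref(x: torch.Tensor, chunk_size: int,
                           reverse: bool = False, scale: float | None = None) -> torch.Tensor:
    # x: [B, T, H, S]; the cumsum runs over time independently within each chunk.
    B, T, H, S = x.shape
    assert T % chunk_size == 0, "illustration assumes T divisible by chunk_size"
    xc = x.view(B, T // chunk_size, chunk_size, H, S)
    xc = xc.flip(2).cumsum(2).flip(2) if reverse else xc.cumsum(2)
    out = xc.reshape(B, T, H, S)
    return out * scale if scale is not None else out  # scaling applied after the cumsum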


340-345: Pointer arithmetic for both HEAD_FIRST layouts is mathematically correct and consistent across the codebase.

The mathematical verification confirms:

  • HEAD_FIRST=False with [B,T,H,S] layout: The formula s + (bos * H + i_h) * S with stride (H*S, 1) correctly computes offsets
  • HEAD_FIRST=True with [B,H,T,S] layout: The formula s + (bos * H + i_h*T)*S with stride (S, 1) correctly computes offsets

The same pattern is consistently applied across multiple files (cumsum.py, comba/utils.py) and the tests pass with valid outputs.
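
One lightweight way to double-check such offset formulas is to compare them against PyTorch's own strides on a contiguous tensor:

import torch

x = torch.arange(2 * 8 * 4 * 16).reshape(2, 8, 4, 16)  # e.g. [B, T, H, S]
b, t, h, s = 1, 5, 2, 7
# The flat offset computed from the strides must land on the same element.
off = sum(i * st for i, st in zip((b, t, h, s), x.stride()))
assert x.view(-1)[off] == x[b, t, h, s]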



Development

Successfully merging this pull request may close the linked issue:

  • [Feature Request] Fuse kda_gate_fwd and chunk_local_cumsum