[KDA] Fuse kda gate and cumsum #672
base: main
Conversation
Walkthrough: This PR introduces a new Triton-based chunk-cumulative-sum kernel, exposed through `kda_gate_chunk_cumsum`, which fuses the KDA gate computation with the chunk-local cumulative sum.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Summary of Changes
Hello @zhiyuan1i, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request optimizes the KDA implementation by fusing the KDA gate and chunk-local cumulative sum operations into a single Triton kernel. The fusion reduces the overhead of multiple kernel launches and memory transfers, improving performance for KDA-based models, especially when gating is applied within the kernel. The changes are tested to maintain numerical accuracy and functionality across various configurations.
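Concretely, at the call-site level the fusion collapses two launches into one. A rough equivalence sketch using the function names from this PR; the shapes, dtypes, and exact argument lists (in particular for `kda_gate_fwd` and `chunk_local_cumsum`) are assumptions for illustration:

```python
import torch

from fla.ops.kda.gate import kda_gate_chunk_cumsum, kda_gate_fwd
from fla.ops.utils.cumsum import chunk_local_cumsum

B, T, H, D = 2, 64, 4, 32
g = torch.randn(B, T, H, D, device="cuda", dtype=torch.bfloat16)
A_log = torch.randn(H, D, device="cuda", dtype=torch.float32)

# Before: gate transform and chunk-local cumsum as two separate kernel launches.
g_cum_unfused = chunk_local_cumsum(kda_gate_fwd(g, A_log), chunk_size=64)

# After: a single fused Triton launch when use_gate_in_kernel=True.
g_cum_fused = kda_gate_chunk_cumsum(g, A_log, chunk_size=64)

torch.testing.assert_close(g_cum_unfused, g_cum_fused, rtol=1e-3, atol=1e-3)
```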
Code Review
This pull request introduces a fused kernel for the KDA gate and chunked cumulative sum, which is a great optimization. The forward pass implementation and test updates look solid. However, I've identified a critical issue in the backward pass where gradient propagation through the chunk_local_cumsum operation appears to be missing. This could lead to incorrect gradients when use_gate_in_kernel=True. I've also included a couple of minor suggestions to improve code clarity and test hygiene.
```python
if chunk_indices is None and cu_seqlens is not None:
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
```
This assertion to check if chunk_size is a power of two is a bit obscure and might be hard to understand for future maintainers. A more conventional and clearer way to perform this check is by using bitwise operations.
```diff
- assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
+ assert (chunk_size > 0) and (chunk_size & (chunk_size - 1) == 0), "chunk_size must be a power of 2"
```
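For reference, both forms test the same property. The bit trick relies on a power of two having exactly one set bit, so subtracting 1 clears that bit and sets every lower bit; a quick check:

```python
for n in (1, 2, 31, 32, 33, 64):
    bit_length_check = n == 2 ** (n.bit_length() - 1)
    bit_trick_check = n > 0 and (n & (n - 1)) == 0
    assert bit_length_check == bit_trick_check
    print(n, bin(n), bin(n - 1), bit_trick_check)  # e.g. 32 0b100000 0b11111 True
```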
Actionable comments posted: 0
🧹 Nitpick comments (4)
tests/ops/test_kda.py (3)
431-442: Remove duplicate test parameters. The test parameter list contains duplicate entries that will run the same test twice unnecessarily; the following entries appear twice: (1, 2, 2, 12), (1, 32, 2, 16), (2, 64, 4, 32), (4, 128, 8, 64), (4, 128, 8, 128).

```diff
 [
     pytest.param(*test, id="B{}-T{}-H{}-D{}".format(*test))
     for test in [
         (1, 2, 2, 12),
         (1, 32, 2, 16),
         (2, 64, 4, 32),
         (4, 128, 8, 64),
         (4, 128, 8, 128),
-        (1, 2, 2, 12),
-        (1, 32, 2, 16),
-        (2, 64, 4, 32),
-        (4, 128, 8, 64),
-        (4, 128, 8, 128),
     ]
 ],
```
456-457: Redundant condition: `dt_bias` is always non-None. The condition `if dt_bias is not None` on line 456 is always true since `dt_bias` is unconditionally created on line 454. Consider either removing the condition or making `dt_bias` conditionally created (similar to `test_gate`) to test both code paths.

```diff
-    if dt_bias is not None:
-        dt_bias = dt_bias.to(device).requires_grad_(True)
+    dt_bias = dt_bias.to(device).requires_grad_(True)
```

Alternatively, parameterize `HAS_BIAS` to test both paths (see the sketch below).
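A possible shape for that parameterization; the test name, tensor shapes, and the `A_log`/`dt_bias` layouts below are illustrative placeholders rather than the ones used in `tests/ops/test_kda.py`:

```python
import pytest
import torch

from fla.ops.kda.gate import kda_gate_chunk_cumsum


@pytest.mark.parametrize("has_bias", [False, True])
def test_gate_cumsum_bias_paths(has_bias: bool):
    B, T, H, D = 2, 64, 4, 32
    torch.manual_seed(42)
    g = torch.randn(B, T, H, D, device="cuda", requires_grad=True)
    A_log = torch.randn(H, D, device="cuda", requires_grad=True)
    # dt_bias is only materialized on one branch, so both code paths get exercised.
    dt_bias = torch.randn(H, D, device="cuda", requires_grad=True) if has_bias else None
    out = kda_gate_chunk_cumsum(g, A_log, chunk_size=64, dt_bias=dt_bias)
    assert out.shape == g.shape
```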
467-475: Consider adding backward gradient test for completeness. The test validates forward equivalence but doesn't verify backward gradient flow. While the backward path is tested indirectly through `test_chunk` and `test_chunk_varlen`, a direct gradient test here would provide more targeted coverage of the fused kernel's interaction with autograd (one possible sketch follows).
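A hedged sketch of such a test. This assumes the fused `kda_gate_chunk_cumsum` and the unfused `kda_gate_fwd` + `chunk_local_cumsum` composition are both differentiable at the Python level (they may only be wired into autograd through the outer chunk op, in which case the check would target that op instead); shapes and argument lists are placeholders:

```python
import torch

from fla.ops.kda.gate import kda_gate_chunk_cumsum, kda_gate_fwd
from fla.ops.utils.cumsum import chunk_local_cumsum


def test_gate_cumsum_backward():
    B, T, H, D = 2, 64, 4, 32
    g = torch.randn(B, T, H, D, device="cuda", dtype=torch.float32)
    A_log = torch.randn(H, D, device="cuda", dtype=torch.float32)
    do = torch.randn(B, T, H, D, device="cuda")

    # Fused path.
    g1, A1 = g.clone().requires_grad_(True), A_log.clone().requires_grad_(True)
    kda_gate_chunk_cumsum(g1, A1, chunk_size=64).backward(do)

    # Unfused reference: gate transform followed by chunk-local cumsum.
    g2, A2 = g.clone().requires_grad_(True), A_log.clone().requires_grad_(True)
    chunk_local_cumsum(kda_gate_fwd(g2, A2), chunk_size=64).backward(do)

    torch.testing.assert_close(g1.grad, g2.grad, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(A1.grad, A2.grad, rtol=1e-3, atol=1e-3)
```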
fla/ops/kda/gate.py (1)

377-383: Address type hint and unused parameter. Per static analysis hints:

- Line 377: `scale: float = None` should be `scale: float | None = None` per PEP 484.
- Line 383: `**kwargs` is unused but may be intentional for API compatibility with `chunk_local_cumsum`.

```diff
 def kda_gate_chunk_cumsum(
     g: torch.Tensor,
     A_log: torch.Tensor,
     chunk_size: int,
     reverse: bool = False,
-    scale: float = None,
+    scale: float | None = None,
     dt_bias: torch.Tensor | None = None,
     cu_seqlens: torch.Tensor | None = None,
     head_first: bool = False,
     output_dtype: torch.dtype | None = torch.float,
     chunk_indices: torch.LongTensor | None = None,
-    **kwargs,
+    **kwargs,  # For API compatibility with chunk_local_cumsum
 ) -> torch.Tensor:
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- fla/ops/kda/chunk.py (2 hunks)
- fla/ops/kda/gate.py (2 hunks)
- tests/ops/test_kda.py (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
fla/ops/kda/chunk.py (2)
- fla/ops/kda/gate.py (2)
  - kda_gate_chunk_cumsum (372-417)
  - kda_gate_fwd (165-191)
- fla/ops/utils/cumsum.py (1)
  - chunk_local_cumsum (429-469)

fla/ops/kda/gate.py (2)
- fla/ops/utils/index.py (1)
  - prepare_chunk_indices (114-119)
- fla/utils.py (2)
  - check_shared_mem (447-453)
  - input_guard (137-168)
🪛 Ruff (0.14.6)
fla/ops/kda/gate.py
320-320: Unused function argument: `B` (ARG001)

377-377: PEP 484 prohibits implicit Optional; convert to `T | None` (RUF013)

383-383: Unused function argument: `kwargs` (ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
- GitHub Check: check-pt-python-compatibility
🔇 Additional comments (11)
fla/ops/kda/chunk.py (3)
11-11: LGTM! The import correctly adds `kda_gate_chunk_cumsum` alongside the existing `kda_gate_bwd` and `kda_gate_fwd` imports.
212-225: LGTM! The forward path correctly uses the fused `kda_gate_chunk_cumsum` kernel when `use_gate_in_kernel=True`, combining gate computation and chunk cumsum into a single kernel launch. The fallback to `chunk_local_cumsum` when gating is disabled maintains backward compatibility.
265-271: LGTM! The backward path correctly recomputes `g` using the separate `kda_gate_fwd` followed by `chunk_local_cumsum`. This is appropriate because the gradient computation requires access to intermediate values that the fused forward kernel doesn't preserve; the general save-then-recompute pattern is sketched below.
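A self-contained toy version of that pattern, using plain torch ops in place of the Triton kernels and hypothetical names (not the actual autograd function in chunk.py):

```python
import torch
import torch.nn.functional as F


class RecomputeGateCumsum(torch.autograd.Function):
    """Toy illustration: forward saves only the raw input, backward rebuilds
    the intermediates it needs instead of storing them."""

    @staticmethod
    def forward(ctx, g):
        ctx.save_for_backward(g)                      # no intermediates kept
        return torch.cumsum(F.logsigmoid(g), dim=1)   # stand-in for the fused kernel

    @staticmethod
    def backward(ctx, d_out):
        (g,) = ctx.saved_tensors
        with torch.enable_grad():
            g_ = g.detach().requires_grad_(True)
            out = torch.cumsum(F.logsigmoid(g_), dim=1)   # recompute the forward
            (dg,) = torch.autograd.grad(out, g_, d_out)
        return dg


g = torch.randn(2, 8, requires_grad=True)
RecomputeGateCumsum.apply(g).sum().backward()

g_ref = g.detach().clone().requires_grad_(True)
torch.cumsum(F.logsigmoid(g_ref), dim=1).sum().backward()
assert torch.allclose(g.grad, g_ref.grad)
```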
tests/ops/test_kda.py (3)

10-12: LGTM! The imports correctly add `kda_gate_chunk_cumsum` from the gate module and `chunk_local_cumsum` from utils to support the new test coverage.
212-221: LGTM! The reference computation correctly casts tensors to `float32` for the naive gate computation when `use_gate_in_kernel` is enabled, and the backward pass properly handles dtype conversion for loss computation.
290-298: Potential issue with mask application when `use_gate_in_kernel=True`. When `use_gate_in_kernel=True`, `g` is not processed through `logsigmoid` (lines 295-296), but lines 297-298 still apply a mask that sets values to `-1000`. This mask application may not be the intended behavior for the raw gate input path, as the gate kernel applies its own transformation. Please verify that applying the mask (lines 297-298) to the raw gate values when `use_gate_in_kernel=True` is the intended behavior, as this differs from how the mask is applied in the `test_chunk` function.
fla/ops/kda/gate.py (5)

8-12: LGTM! The new imports correctly bring in `prepare_chunk_indices` for variable-length sequence handling and `check_shared_mem` for dynamic `BS_LIST` configuration.
296-310: LGTM! The heuristics and autotune configuration are well-designed. Using `do_not_specialize=['T']` allows the kernel to handle varying sequence lengths without triggering recompilation.
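For readers less familiar with this decorator stack, a minimal stand-alone illustration; the config values, tuning key, and kernel body are placeholders, not the ones in gate.py:

```python
import triton
import triton.language as tl


@triton.heuristics({'USE_OFFSETS': lambda args: args['cu_seqlens'] is not None})
@triton.autotune(
    configs=[
        triton.Config({'BS': bs}, num_warps=w)
        for bs in [16, 32, 64]
        for w in [2, 4, 8]
    ],
    key=['S'],  # re-tune when the feature dimension changes
)
@triton.jit(do_not_specialize=['T'])  # avoid recompiling for every new sequence length
def scale_fwd_kernel(
    x, y, cu_seqlens,
    T, S,
    BS: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
):
    i_s = tl.program_id(0)
    o_s = i_s * BS + tl.arange(0, BS)
    m_s = o_s < S
    b_x = tl.load(x + o_s, mask=m_s, other=0.)
    tl.store(y + o_s, b_x * 2.0, mask=m_s)
```

The heuristic turns the presence of `cu_seqlens` into a constexpr flag at launch time, the autotuner picks `BS`/`num_warps` keyed on `S`, and `do_not_specialize` keeps `T` out of the specialization key, which is what avoids per-length recompiles.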
320-320: Static analysis false positive: `B` is used. The static analysis flags `B` as unused, but it's used in the grid calculation `(triton.cdiv(meta['S'], meta['BS']), NT, B * H)` at line 400. The `B` parameter is correctly passed and used for determining the 3rd grid dimension.
359-367: LGTM! The cumsum implementation correctly handles both forward and reverse directions, matching the semantics of `chunk_local_cumsum`. The optional scaling is applied after the cumsum, which is consistent with the existing API.
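For reference, the semantics being matched can be written in plain PyTorch as below. This sketch handles only fixed-length `[B, T, H, S]` inputs, assumes the reverse variant is an inclusive suffix sum (check cumsum.py for the exact inclusive/exclusive convention and varlen handling), and applies the optional scale at the end:

```python
import torch


def chunk_local_cumsum_ref(g: torch.Tensor, chunk_size: int,
                           reverse: bool = False, scale: float | None = None) -> torch.Tensor:
    """Cumulative sum over time, restarted at every chunk boundary."""
    B, T, H, S = g.shape
    assert T % chunk_size == 0, "padding/varlen handling omitted for brevity"
    x = g.reshape(B, T // chunk_size, chunk_size, H, S)
    if reverse:
        x = x.flip(2).cumsum(2).flip(2)   # inclusive suffix sums within each chunk
    else:
        x = x.cumsum(2)                   # inclusive prefix sums within each chunk
    out = x.reshape(B, T, H, S)
    return out * scale if scale is not None else out
```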
340-345: Pointer arithmetic for both HEAD_FIRST layouts is mathematically correct and consistent across the codebase.

The verification confirms:

- `HEAD_FIRST=False` with `[B, T, H, S]` layout: the formula `s + (bos * H + i_h) * S` with stride `(H*S, 1)` correctly computes offsets.
- `HEAD_FIRST=True` with `[B, H, T, S]` layout: the formula `s + (bos * H + i_h*T) * S` with stride `(S, 1)` correctly computes offsets.

The same pattern is consistently applied across multiple files (cumsum.py, comba/utils.py) and the tests pass with valid outputs.
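Those two formulas can also be checked numerically against the flat index PyTorch itself uses; a small sketch under the layouts stated above, where `bos` is the sequence start row (`i_b * T` in the fixed-length case):

```python
import torch

B, T, H, S = 2, 8, 3, 4
i_b, i_h, t, s = 1, 2, 5, 3
bos = i_b * T  # start row of this sequence in the flattened [B*T] time axis

# HEAD_FIRST=False: [B, T, H, S], rows advance with stride H*S
x = torch.arange(B * T * H * S).view(B, T, H, S)
base = s + (bos * H + i_h) * S
assert x.flatten()[base + t * H * S] == x[i_b, t, i_h, s]

# HEAD_FIRST=True: [B, H, T, S], rows advance with stride S
y = torch.arange(B * H * T * S).view(B, H, T, S)
base = s + (bos * H + i_h * T) * S
assert y.flatten()[base + t * S] == y[i_b, i_h, t, s]
```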