Conversation

Contributor

@shiyuan680 shiyuan680 commented Nov 8, 2025

What this PR does / why we need it?

Does this PR introduce any user-facing change?

How was this patch tested?


github-actions bot commented Nov 8, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message and fill in the PR description to help reviewers and future developers understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds Triton kernel implementations for chunk_gated_delta_rule operations, seemingly for use with Huawei Ascend NPUs within the vLLM framework. The implementation is a substantial port from the fla library, introducing several new files for forward and backward passes. While the effort to optimize these operations is commendable, the current implementation has several critical issues that will prevent it from running correctly. My review has identified missing Python imports, undefined variables causing NameErrors, inconsistent and likely incorrect use of chunk_size, missing parameters in a kernel launch, and a leftover debugging statement. These issues must be addressed to ensure the correctness and performance of the new operations.

Comment on lines +3 to +5
import warnings

import torch
Contributor

critical

The type hint Optional is used in this file (e.g., on line 188), but it is not imported from the typing module. This will cause a NameError at runtime.

Suggested change
 import warnings
+from typing import Optional
 
 import torch
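Since other signatures in this diff already use the PEP 604 union syntax (e.g. cu_seqlens: torch.LongTensor | None = None), an equivalent fix is to drop Optional and annotate the affected parameter in the same style; a hypothetical sketch (the parameter name is invented for illustration):

# instead of: initial_state: Optional[torch.Tensor] = None
initial_state: torch.Tensor | None = None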

output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
Contributor

critical

There is an inconsistency in the chunk_size used. Here, chunk_local_cumsum is called with chunk_size=64, but all subsequent chunked operations in both the forward and backward passes use chunk_size=16. Since chunk_local_cumsum performs a cumulative sum within chunks, this discrepancy will likely lead to incorrect calculations in later stages that expect data to be processed in chunks of 16. To ensure correctness, the chunk size should be consistent across all related operations.

Suggested change
-g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
+g = chunk_local_cumsum(g, chunk_size=16, cu_seqlens=cu_seqlens)
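More robustly, the chunk size could be defined once and passed to every chunked call so the values cannot drift apart again; a rough sketch (the constant name is chosen for illustration):

CHUNK_SIZE = 16  # single source of truth for every chunked op on this path

g = chunk_local_cumsum(g, chunk_size=CHUNK_SIZE, cu_seqlens=cu_seqlens)
# downstream helpers take the same value, e.g.
# o = chunk_fwd_o(q, k, v, h, g=g, cu_seqlens=cu_seqlens, chunk_size=CHUNK_SIZE)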

g += bos * H + i_h
p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_o = b_o * exp(b_g)[:, None]
Contributor

critical

The function exp is used within this Triton kernel, but it is not a standard tl function and has not been imported. This will result in a NameError. In other files within this PR, a custom exp function is imported from fla.ops.utils.op. The same import is needed here. Please add from fla.ops.utils.op import exp to the imports at the top of the file.
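A minimal sketch of the missing import, mirroring the import blocks of the other ported files in this PR (assuming fla.ops.utils.op is importable in this environment):

# alongside the existing triton imports at the top of the file
from fla.ops.utils.op import exp

# the kernel epilogue can then keep using it as written:
# b_o = b_o * exp(b_g)[:, None]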

for num_stages in [2, 3, 4]
],
key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G', 'USE_G_GAMMA', 'USE_DW'],
**autotune_cache_kwargs,
Contributor

critical

The variable autotune_cache_kwargs is used in the @triton.autotune decorator, but it is not defined anywhere in this file. This will cause a NameError during module loading. You should define it at the top of the file, for example, by adapting the definition from the fla library.
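One possible minimal sketch, assuming no autotune result caching is needed here (the exact definition used by the fla library may differ and could be copied instead):

# module-level default so @triton.autotune(..., **autotune_cache_kwargs) resolves;
# an empty dict adds no extra autotune options
autotune_cache_kwargs: dict = {}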

Comment on lines +477 to +514
def chunk_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: torch.Tensor | None = None,
    g_gamma: torch.Tensor | None = None,
    scale: float | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
) -> torch.Tensor:
    B, T, H, K, V = *q.shape, v.shape[-1]
    BT = chunk_size
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    if scale is None:
        scale = k.shape[-1] ** -0.5

    o = torch.empty_like(v)
    def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
    chunk_fwd_kernel_o[grid](
        q=q,
        k=k,
        v=v,
        h=h,
        g=g,
        g_gamma=g_gamma,
        o=o,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return o
Contributor

critical

The chunk_fwd_kernel_o kernel is not autotuned and requires BK and BV to be passed as constexpr arguments. However, these are missing from the kernel launch call, which will lead to a runtime error. You should define BK and BV and pass them to the kernel, similar to how it's done in other wrapper functions in this file.

def chunk_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: torch.Tensor | None = None,
    g_gamma: torch.Tensor | None = None,
    scale: float | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
) -> torch.Tensor:
    B, T, H, K, V = *q.shape, v.shape[-1]
    BT = chunk_size
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    if scale is None:
        scale = k.shape[-1] ** -0.5

    if check_shared_mem('hopper', k.device.index):
        CONST_TILING = 128
    elif check_shared_mem:
        CONST_TILING = 64
    else:
        CONST_TILING = 32
    BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
    BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)

    o = torch.empty_like(v)
    def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
    chunk_fwd_kernel_o[grid](
        q=q,
        k=k,
        v=v,
        h=h,
        g=g,
        g_gamma=g_gamma,
        o=o,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV
    )
    return o

p_dw = tl.make_block_ptr(dw, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))

tl.debug_barrier()
Contributor

high

A tl.debug_barrier() is present here. This is typically used for debugging and should be removed from production code as it forces synchronization and can negatively impact performance.
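A sketch of the cleanup, assuming nothing downstream relies on this barrier for ordering (if something does, that dependency should be made explicit instead):

p_dw = tl.make_block_ptr(dw, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
# tl.debug_barrier() removed: it forces a synchronization and was only useful for debugging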

