[Draft] Support triton chunk_gated_delta_rule ops #4070
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request adds Triton kernel implementations for chunk_gated_delta_rule operations, seemingly for use with Huawei Ascend NPUs within the vLLM framework. The implementation is a substantial port from the fla library, introducing several new files for forward and backward passes. While the effort to optimize these operations is commendable, the current implementation has several critical issues that will prevent it from running correctly. My review has identified missing Python imports, undefined variables causing NameErrors, inconsistent and likely incorrect use of chunk_size, missing parameters in a kernel launch, and a leftover debugging statement. These issues must be addressed to ensure the correctness and performance of the new operations.
import warnings

import torch
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
):
    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
There is an inconsistency in the chunk_size used. Here, chunk_local_cumsum is called with chunk_size=64, but all subsequent chunked operations in both the forward and backward passes use chunk_size=16. Since chunk_local_cumsum performs a cumulative sum within chunks, this discrepancy will likely lead to incorrect calculations in later stages that expect data to be processed in chunks of 16. To ensure correctness, the chunk size should be consistent across all related operations.
Suggested change:
-    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
+    g = chunk_local_cumsum(g, chunk_size=16, cu_seqlens=cu_seqlens)
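To make the failure mode concrete, here is a small self-contained illustration; chunk_local_cumsum_ref below is only a simplified stand-in for the real Triton op, not code from this PR:

import torch

def chunk_local_cumsum_ref(g: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # simplified stand-in: cumulative sum computed independently inside each chunk
    out = g.clone()
    for start in range(0, g.shape[0], chunk_size):
        out[start:start + chunk_size] = g[start:start + chunk_size].cumsum(dim=0)
    return out

g = torch.arange(64, dtype=torch.float32)
# gates accumulated over 64-wide chunks are not what a kernel that assumes
# 16-wide chunk boundaries will read back; the values diverge from index 16 onward
assert not torch.equal(chunk_local_cumsum_ref(g, 64), chunk_local_cumsum_ref(g, 16))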
    g += bos * H + i_h
    p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
    b_g = tl.load(p_g, boundary_check=(0,))
    b_o = b_o * exp(b_g)[:, None]
The function exp is used within this Triton kernel, but it is not a standard tl function and has not been imported. This will result in a NameError. In other files within this PR, a custom exp function is imported from fla.ops.utils.op. The same import is needed here. Please add from fla.ops.utils.op import exp to the imports at the top of the file.
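As a sketch of the fix described above, the import would sit alongside the other top-of-file imports (the module path is the one used by the other files in this PR):

from fla.ops.utils.op import exp  # custom exp helper used by the other fla kernels in this PR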
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G', 'USE_G_GAMMA', 'USE_DW'],
    **autotune_cache_kwargs,
def chunk_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: torch.Tensor | None = None,
    g_gamma: torch.Tensor | None = None,
    scale: float | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
) -> torch.Tensor:
    B, T, H, K, V = *q.shape, v.shape[-1]
    BT = chunk_size
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    if scale is None:
        scale = k.shape[-1] ** -0.5

    o = torch.empty_like(v)
    def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
    chunk_fwd_kernel_o[grid](
        q=q,
        k=k,
        v=v,
        h=h,
        g=g,
        g_gamma=g_gamma,
        o=o,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return o
The chunk_fwd_kernel_o kernel is not autotuned and requires BK and BV to be passed as constexpr arguments. However, these are missing from the kernel launch call, which will lead to a runtime error. You should define BK and BV and pass them to the kernel, similar to how it's done in other wrapper functions in this file.
def chunk_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: torch.Tensor | None = None,
    g_gamma: torch.Tensor | None = None,
    scale: float | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
) -> torch.Tensor:
    B, T, H, K, V = *q.shape, v.shape[-1]
    BT = chunk_size
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if check_shared_mem('hopper', k.device.index):
        CONST_TILING = 128
    elif check_shared_mem:
        CONST_TILING = 64
    else:
        CONST_TILING = 32
    BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
    BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)
    o = torch.empty_like(v)
    def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
    chunk_fwd_kernel_o[grid](
        q=q,
        k=k,
        v=v,
        h=h,
        g=g,
        g_gamma=g_gamma,
        o=o,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV
    )
    return o

    p_dw = tl.make_block_ptr(dw, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))

    tl.debug_barrier()
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?