8 changes: 4 additions & 4 deletions fla/ops/nsa/parallel.py
@@ -542,7 +542,7 @@ def parallel_nsa_fwd(
token_indices: Optional[torch.LongTensor] = None,
):
B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
- HQ = q.shape[2]
+ _, T_q, HQ, _ = q.shape
Contributor (severity: medium)

Using _ to ignore dimensions can be concise, but it is less explicit and misses an opportunity to validate tensor shapes. For better robustness and readability, it is recommended to unpack all dimensions explicitly and add assertions to ensure that the batch size and key dimension of q and k are compatible.

Suggested change
- _, T_q, HQ, _ = q.shape
+ B_q, T_q, HQ, K_q = q.shape
+ assert B == B_q, f"q and k must have the same batch size, but got {B_q} and {B}"
+ assert K == K_q, f"q and k must have the same key dimension, but got {K_q} and {K}"
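As a quick illustration (hypothetical shapes, not part of this PR): with explicit unpacking, a mismatched call fails at the assertion with a readable message instead of surfacing later as an opaque shape error inside the Triton kernel.

import torch

B, T, H, K = 2, 4096, 4, 128
k = torch.randn(B, T, H, K)
q = torch.randn(3, 1, 64, K)  # wrong batch size on purpose

B_q, T_q, HQ, K_q = q.shape
assert B == B_q, f"q and k must have the same batch size, but got {B_q} and {B}"
# AssertionError: q and k must have the same batch size, but got 3 and 2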

G = HQ // H
BS = block_size
if check_shared_mem('hopper', q.device.index):
@@ -555,9 +555,9 @@ def parallel_nsa_fwd(
NV = triton.cdiv(V, BV)
assert NK == 1, "The key dimension can not be larger than 256"

- grid = (T, NV, B * H)
- o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
- lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
+ grid = (T_q, NV, B * H)
+ o = torch.empty(B, T_q, HQ, V, dtype=v.dtype, device=q.device)
+ lse = torch.empty(B, T_q, HQ, dtype=torch.float, device=q.device)

parallel_nsa_fwd_kernel[grid](
q=q,
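A minimal sketch of why grid, o, and lse are now sized by T_q rather than T (the shapes below are hypothetical, not taken from this PR): with a KV cache, q can hold fewer tokens than k and v, e.g. a single decoding step, so the kernel grid and the output tensors must follow the number of query positions.

import torch

B, T, T_q, H, HQ, K, V = 2, 4096, 1, 4, 64, 128, 128
q = torch.randn(B, T_q, HQ, K)  # only the newest query token
k = torch.randn(B, T, H, K)     # full cached keys
v = torch.randn(B, T, H, V)     # full cached values

# Outputs are allocated per query position, so they use T_q, not T
o = torch.empty(B, T_q, HQ, V, dtype=v.dtype, device=q.device)
lse = torch.empty(B, T_q, HQ, dtype=torch.float, device=q.device)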