diff --git a/sparse_attention_hub/sparse_attention/efficient_attention/bias_sparse_attention_backend.py b/sparse_attention_hub/sparse_attention/efficient_attention/bias_sparse_attention_backend.py new file mode 100644 index 00000000..6016b3c5 --- /dev/null +++ b/sparse_attention_hub/sparse_attention/efficient_attention/bias_sparse_attention_backend.py @@ -0,0 +1,295 @@ +import math + +import torch +import triton +import triton.language as tl + +from sparse_attention_backend import ( + sparse_decode_stage2 as _sparse_decode_stage2, # stage-2 kernel is weight-agnostic +) + +# ----------------------------------------------------------------------------- +# Stage-1 kernel – incorporate per-token weight (bias) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _fwd_kernel_bias_sparse_decode_stage1( + Q, + K, + V, + sm_scale, + Sparse_List, # [B, H, S] + Sparse_Len, # [B, H] + Weight_List, # [B, H, S] + Mid_O, + Mid_O_LogExpSum, + # strides (all element-wise) + stride_sparse_b, + stride_sparse_h, + stride_qbs, + stride_qh, + stride_qd, + stride_kbb, + stride_kh, + stride_ks, + stride_vbb, + stride_vh, + stride_vs, + stride_weight_b, + stride_weight_h, + stride_weight_s, + stride_splen_b, + stride_splen_h, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + gqa_group_size: tl.constexpr, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + seq_block_id = tl.program_id(2) + + cur_kv_head = cur_head // gqa_group_size + + offs_d = tl.arange(0, BLOCK_DMODEL) + + # Sequence length of this (b,h) + cur_seq_len_ptr = Sparse_Len + cur_batch * stride_splen_b + cur_head * stride_splen_h + cur_seq_len = tl.load(cur_seq_len_ptr) + + block_start = seq_block_id * BLOCK_SEQ + block_end = tl.minimum(cur_seq_len, block_start + BLOCK_SEQ) + + # Base pointers + sparse_ptr_base = Sparse_List + cur_batch * stride_sparse_b + cur_head * stride_sparse_h + weight_ptr_base = Weight_List + cur_batch * stride_weight_b + cur_head * stride_weight_h + + # Load query + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + q = tl.load(Q + off_q) + + sum_exp = 0.0 + max_l = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + block_n_size = ( + tl.where(block_end - block_start <= 0, 0, block_end - block_start + BLOCK_N - 1) + // BLOCK_N + ) + + offs_n = block_start + tl.arange(0, BLOCK_N) + + for start_n in range(0, block_n_size, 1): + offs_n_new = start_n * BLOCK_N + offs_n + token_idx = tl.load( + sparse_ptr_base + offs_n_new, + mask=offs_n_new < cur_seq_len, + other=0, + ) + weight_val = tl.load( + weight_ptr_base + token_idx * stride_weight_s, + mask=offs_n_new < cur_seq_len, + other=0.0, + ).to(tl.float32) + weight_val = tl.where(weight_val > 0.0, weight_val, 1e-30) + + base_k_ptr = cur_batch * stride_kbb + cur_kv_head * stride_kh + off_k = base_k_ptr + token_idx[:, None] * stride_ks + offs_d[None, :] + k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_seq_len, other=0.0) + v = tl.load(V + off_k, mask=offs_n_new[:, None] < cur_seq_len, other=0.0) + + att_val = tl.sum(q[None, :] * k, axis=1) * sm_scale # [BLOCK_N] + att_val = att_val + tl.log(weight_val) + att_val = tl.where(offs_n_new < cur_seq_len, att_val, float("-inf")) + + cur_max = tl.max(att_val, axis=0) + new_max = tl.maximum(cur_max, max_l) + + exp_l = tl.exp(att_val - new_max) + scale = tl.exp(max_l - new_max) + + acc 
*= scale + acc += tl.sum(exp_l[:, None] * v, axis=0) + + sum_exp = sum_exp * scale + tl.sum(exp_l, axis=0) + max_l = new_max + + need_store = tl.where(block_n_size == 0, 0, 1) + for _ in range(0, need_store, 1): + off_mid = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + seq_block_id * stride_mid_os + + offs_d + ) + off_log = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_block_id + tl.store(Mid_O + off_mid, acc / sum_exp) + tl.store(Mid_O_LogExpSum + off_log, max_l + tl.log(sum_exp)) + + +# ----------------------------------------------------------------------------- +# Python wrappers +# ----------------------------------------------------------------------------- + + +@torch.no_grad() +def bias_sparse_decode_stage1( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sparse_list: torch.Tensor, + sparse_len: torch.Tensor, + weight_list: torch.Tensor, + max_len_in_batch: int, + mid_out: torch.Tensor, + mid_out_logsumexp: torch.Tensor, + block_seq: int, +): + BLOCK_N = 16 + BLOCK_SEQ = block_seq + + D = q.shape[-1] + assert D in {16, 32, 64, 128} + assert k.shape[-1] == D + + sm_scale = 1.0 / math.sqrt(D) + + B, H = q.shape[0], q.shape[1] + grid = (B, H, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) + + gqa_group_size = H // k.shape[1] + + _fwd_kernel_bias_sparse_decode_stage1[grid]( + q, + k, + v, + sm_scale, + sparse_list, + sparse_len, + weight_list, + mid_out, + mid_out_logsumexp, + sparse_list.stride(0), + sparse_list.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + weight_list.stride(0), + weight_list.stride(1), + weight_list.stride(2), + sparse_len.stride(0), + sparse_len.stride(1), + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + mid_out.stride(3), + mid_out_logsumexp.stride(0), + mid_out_logsumexp.stride(1), + mid_out_logsumexp.stride(2), + gqa_group_size, + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=D, + BLOCK_N=BLOCK_N, + num_warps=1, + num_stages=2, + ) + + +@torch.no_grad() +def bias_sparse_attention_fwd( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sparse_list: torch.Tensor, + sparse_len: torch.Tensor, + weight_list: torch.Tensor, + block_seq: int = 256, +): + """Triton-accelerated biased sparse attention forward.""" + + assert all(t.is_cuda for t in [query, key, value, sparse_list, weight_list]) + + B, H, D = query.shape + max_len = int(sparse_len.max().item()) + + blk_num = (max_len + block_seq - 1) // block_seq + mid_out = torch.empty((B, H, blk_num, D), dtype=torch.float32, device=query.device) + mid_log = torch.empty((B, H, blk_num), dtype=torch.float32, device=query.device) + out = torch.empty((B, H, D), dtype=query.dtype, device=query.device) + + bias_sparse_decode_stage1( + query, + key, + value, + sparse_list, + sparse_len, + weight_list, + max_len, + mid_out, + mid_log, + block_seq, + ) + + # stage-2 (weight-independent) + _sparse_decode_stage2(mid_out, mid_log, sparse_len, out, block_seq) + return out + + +# ----------------------------------------------------------------------------- +# Quick verification +# ----------------------------------------------------------------------------- + +if __name__ == "__main__": + from ref_bias_sparse_attention_backend import ref_bias_sparse_attention_fwd + + torch.manual_seed(0) + + B, H, D, S = 32, 32, 128, 4096 + gqa = 4 + Kv = H // gqa + + dtype = torch.float16 + + q = torch.randn(B, H, D, device="cuda", dtype=dtype) + k = torch.randn(B, Kv, S, D, device="cuda", 
dtype=dtype) + v = torch.randn(B, Kv, S, D, device="cuda", dtype=dtype) + + sparse_list = torch.empty((B, H, S), dtype=torch.int32, device="cuda") + sparse_len = torch.empty((B, H), dtype=torch.int32, device="cuda") + for b in range(B): + for h in range(H): + perm = torch.randperm(S, device="cuda", dtype=torch.int32) + sparse_list[b, h] = perm + sparse_len[b, h] = torch.randint(1, S + 1, (1,), device="cuda", dtype=torch.int32) + + weight_list = torch.rand((B, H, S), device="cuda", dtype=dtype) * 2.0 + 0.5 # positive weights + # weight_list = torch.ones((B, H, S), device="cuda", dtype=dtype) + + print(f"{weight_list[0, :5, :5]=}") + + out_ref = ref_bias_sparse_attention_fwd(q, k, v, sparse_list, sparse_len, weight_list) + out_triton = bias_sparse_attention_fwd(q, k, v, sparse_list, sparse_len, weight_list) + + print(f"{out_ref[0, :5, :5]=}") + print(f"{out_triton[0, :5, :5]=}") + + max_err = (out_ref - out_triton).abs().max().item() + mean_err = (out_ref - out_triton).abs().mean().item() + print(f"[SPARSE ATTENTION TEST] max|ref - triton| = {max_err:.6e}") + print(f"[SPARSE ATTENTION TEST] mean|ref - triton| = {mean_err:.6e}") + assert mean_err < 1e-4, "Triton sparse attention does not match reference!" + print("[SPARSE ATTENTION TEST] Passed!") \ No newline at end of file diff --git a/sparse_attention_hub/sparse_attention/efficient_attention/ref_bias_sparse_attention_backend.py b/sparse_attention_hub/sparse_attention/efficient_attention/ref_bias_sparse_attention_backend.py new file mode 100644 index 00000000..9da637d7 --- /dev/null +++ b/sparse_attention_hub/sparse_attention/efficient_attention/ref_bias_sparse_attention_backend.py @@ -0,0 +1,140 @@ +import math +from typing import Tuple + +import torch + +from ref_sparse_attention_backend import ref_sparse_attention_fwd + + +def _get_gqa_group_size(H: int, Kv: int) -> int: + assert H % Kv == 0, "H must be divisible by Kv (H // gqa)" + return H // Kv + + +# @torch.no_grad() +# def ref_sparse_attention_fwd( +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# sparse_list: torch.Tensor, +# sparse_len: torch.Tensor, +# ) -> torch.Tensor: +# """Reference sparse attention (no bias) – same as earlier helper. + +# Args are identical to the previous spec. +# """ +# assert query.ndim == 3 and key.ndim == 4 and value.ndim == 4 + +# B, H, D = query.shape +# _, Kv, S, _ = key.shape + +# gqa_group_size = _get_gqa_group_size(H, Kv) +# sm_scale = 1.0 / math.sqrt(D) + +# out = torch.empty_like(query) + +# for b in range(B): +# for h in range(H): +# kv_h = h // gqa_group_size +# L = int(sparse_len[b, h].item()) +# if L == 0: +# out[b, h].zero_() +# continue + +# idx = sparse_list[b, h, :L].to(dtype=torch.long, device=query.device) +# k_vec = key[b, kv_h].index_select(0, idx).to(torch.float32) +# v_vec = value[b, kv_h].index_select(0, idx).to(torch.float32) +# q_vec = query[b, h].to(torch.float32) + +# att_logits = (k_vec * q_vec).sum(dim=-1) * sm_scale # [L] +# att_weights = torch.softmax(att_logits, dim=-1) +# out[b, h] = (att_weights.unsqueeze(-1) * v_vec).sum(dim=0).to(query.dtype) + +# return out + + +@torch.no_grad() +def ref_bias_sparse_attention_fwd( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sparse_list: torch.Tensor, + sparse_len: torch.Tensor, + weight_list: torch.Tensor, +) -> torch.Tensor: + """Reference implementation of *biased* sparse attention. + + The weight_list supplies a per-(b,h,token) positive weight w. The attention + weights become w * exp(q·k) / Σ w * exp(q·k). 
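+
+    Equivalently, the weight enters the logits as an additive bias in log-space:
+
+        softmax(q·k * sm_scale + log w) = w * exp(q·k * sm_scale) / Σ w * exp(q·k * sm_scale),
+
+    which is exactly what this reference (``att_logits + torch.log(w_vec)``) and the
+    Triton kernel (``att_val + tl.log(weight_val)``) compute.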
+ """ + + assert query.ndim == 3 and key.ndim == 4 and value.ndim == 4 + assert weight_list.shape == sparse_list.shape, "weight_list must be [B,H,S]" + + B, H, D = query.shape + _, Kv, S, _ = key.shape + + gqa_group_size = _get_gqa_group_size(H, Kv) + sm_scale = 1.0 / math.sqrt(D) + + out = torch.empty_like(query) + + for b in range(B): + for h in range(H): + kv_h = h // gqa_group_size + L = int(sparse_len[b, h].item()) + if L == 0: + out[b, h].zero_() + continue + + idx = sparse_list[b, h, :L].to(dtype=torch.long, device=query.device) + k_vec = key[b, kv_h].index_select(0, idx) + v_vec = value[b, kv_h].index_select(0, idx) + w_vec = weight_list[b, h].index_select(0, idx).to(torch.float32) # [L] + + # Ensure positivity to avoid log(-) + w_vec = torch.clamp_min(w_vec, 1e-30) + + q_vec = query[b, h].to(torch.float32) + att_logits = (k_vec * q_vec).sum(dim=-1).to(torch.float32) * sm_scale # [L] + + # Incorporate weight as additive bias in log-space + logits_with_bias = att_logits + torch.log(w_vec) + att_weights = torch.softmax(logits_with_bias, dim=-1).to(query.dtype) + out_vec = (att_weights.unsqueeze(-1) * v_vec).sum(dim=0) + out[b, h] = out_vec.to(query.dtype) + + return out + + +if __name__ == "__main__": + # Simple correctness check: when weight == 1, biased == un-biased + torch.manual_seed(0) + + B, H, D, S = 32, 32, 128, 4096 + gqa_group_size = 2 + Kv = H // gqa_group_size + + q = torch.randn(B, H, D, device="cuda", dtype=torch.float16) + k = torch.randn(B, Kv, S, D, device="cuda", dtype=torch.float16) + v = torch.randn(B, Kv, S, D, device="cuda", dtype=torch.float16) + + # Sparse pattern + sparse_list = torch.empty((B, H, S), dtype=torch.long, device="cuda") + sparse_len = torch.empty((B, H), dtype=torch.long, device="cuda") + for b in range(B): + for h in range(H): + perm = torch.randperm(S, device="cuda") + sparse_list[b, h] = perm + sparse_len[b, h] = torch.randint(1, S + 1, (1,), device="cuda") + + # All-ones weight + weight_list = torch.ones((B, H, S), device="cuda", dtype=torch.float16) + + out_ref = ref_sparse_attention_fwd(q, k, v, sparse_list, sparse_len) + out_bias = ref_bias_sparse_attention_fwd(q, k, v, sparse_list, sparse_len, weight_list) + + max_err = (out_ref - out_bias).abs().max().item() + print(f"[BIAS REF TEST] max|no-bias - bias(1)| = {max_err:.6e}") + assert max_err < 2e-3, "Biased sparse attention (w=1) should equal un-biased!" + print("[BIAS REF TEST] Passed.") \ No newline at end of file diff --git a/sparse_attention_hub/sparse_attention/efficient_attention/ref_sparse_attention_backend.py b/sparse_attention_hub/sparse_attention/efficient_attention/ref_sparse_attention_backend.py new file mode 100644 index 00000000..5c448fa9 --- /dev/null +++ b/sparse_attention_hub/sparse_attention/efficient_attention/ref_sparse_attention_backend.py @@ -0,0 +1,135 @@ +import torch +import math + + +def ref_sparse_attention_fwd( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sparse_list: torch.Tensor, + sparse_len: torch.Tensor, +): + """Reference implementation of sparse attention flash-decoding. + + Args: + query: Tensor of shape [B, H, D] + key: Tensor of shape [B, H // gqa, S, D] + value: Tensor of shape [B, H // gqa, S, D] + sparse_list: Tensor of shape [B, H, S] that stores the token indices to + attend to. Only the first ``sparse_len[b, h]`` entries of + the last dimension are valid. + sparse_len: Tensor of shape [B, H] giving the valid length in + ``sparse_list`` for every (b, h). 
+ + Returns: + Tensor of shape [B, H, D] – the attention output for each query head. + + This is a *slow* but very clear reference used for correctness checks. It + supports grouped-query attention (GQA) where several query heads share the + same key / value head. Setting ``gqa = 1`` reduces to standard multi-head + attention (MHA). + """ + + assert query.ndim == 3, "query must be [B, H, D]" + assert key.ndim == value.ndim == 4, "key/value must be [B, Kv, S, D]" + + B, H, D = query.shape + _, Kv, S, _ = key.shape + device = query.device + dtype = query.dtype + + # Infer group size from the shapes. gqa == number of Q heads per KV head. + gqa_group_size = H // Kv + assert gqa_group_size * Kv == H, "H must be divisible by Kv (H//gqa)" + + sm_scale = 1.0 / math.sqrt(D) + + # Output tensor + out = torch.empty_like(query) + + # Iterate over batch and heads – this is a slow reference so clarity beats speed. + for b in range(B): + for h in range(H): + kv_h = h // gqa_group_size # which KV head this Q head should use + + # Number of tokens that this (b, h) attends to + L = int(sparse_len[b, h].item()) + if L == 0: + # Edge-case: no tokens attended -> return zeros (like softmax over empty set) + out[b, h].zero_() + continue + + # The token indices we actually attend to (shape [L]) + idx = sparse_list[b, h, :L].to(dtype=torch.long, device=device) + + # Gather the key/value vectors we need (shape [L, D]) + k_vec = key[b, kv_h].index_select(0, idx) # [L, D] + v_vec = value[b, kv_h].index_select(0, idx) # [L, D] + + # Attention logits – [L] + q_vec = query[b, h] # [D] + attn_logits = (k_vec * q_vec).sum(dim=-1).to(torch.float32) * sm_scale + + attn_weights = torch.softmax(attn_logits, dim=-1).to(query.dtype) # [L] + out[b, h] = torch.sum(attn_weights.unsqueeze(-1) * v_vec, dim=0) + + return out + + +def ref_dense_attention_fwd(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor): + """Vectorised dense attention (reference). + + We replicate key / value along the head dimension so each query head has its + own slice, then compute attention in batch using two Einsums – this is + clearer and avoids Python-side loops. + """ + + assert query.ndim == 3 and key.ndim == 4 and value.ndim == 4 + + B, H, D = query.shape + _, Kv, S, _ = key.shape + + gqa_group_size = H // Kv # heads per KV group + sm_scale = 1.0 / math.sqrt(D) + + # Repeat key/value so we have one slice per query head: [B, H, S, D] + key_rep = key.repeat_interleave(gqa_group_size, dim=1) + value_rep = value.repeat_interleave(gqa_group_size, dim=1) + + # Compute attention logits: [B, H, S] + attn_logits = torch.einsum("bhd,bhsd->bhs", query, key_rep).to(torch.float32) * sm_scale + attn_weights = torch.softmax(attn_logits, dim=-1).to(query.dtype) + + # Output: [B, H, D] + out = torch.einsum("bhs,bhsd->bhd", attn_weights, value_rep) + return out + + +if __name__ == "__main__": + # Simple self-test: when every token is attended, sparse == dense. 
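+    # Because softmax depends only on the attended set (not on the order of the gathered
+    # indices), attending to all S tokens reproduces the dense softmax exactly, so the
+    # two outputs should agree up to fp16 rounding.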
+ torch.manual_seed(0) + torch_dtype = torch.float16 + + B, H, D, S = 32, 32, 128, 4096 + gqa_group_size = 4 # change as you like – 1 corresponds to MHA + Kv = H // gqa_group_size + + query = torch.randn(B, H, D, device="cuda", dtype=torch_dtype) + key = torch.randn(B, Kv, S, D, device="cuda", dtype=torch_dtype) + value = torch.randn(B, Kv, S, D, device="cuda", dtype=torch_dtype) + + # Build full sparse_list / sparse_len that cover ALL tokens + sparse_list = torch.arange(S, device="cuda").view(1, 1, S).repeat(B, H, 1) + sparse_len = torch.full((B, H), S, dtype=torch.long, device="cuda") + + out_sparse = ref_sparse_attention_fwd(query, key, value, sparse_list, sparse_len) + out_dense = ref_dense_attention_fwd(query, key, value) + + max_abs_err = (out_sparse - out_dense).abs().max().item() + mean_abs_err = (out_sparse - out_dense).abs().mean().item() + print(f"[TEST] mean|sparse - dense| = {(out_sparse - out_dense).abs().mean().item():.6e}") + print(f"[TEST] max|sparse - dense| = {max_abs_err:.6e}") + # Assert the two results are (almost) identical – tolerance 1e-4 in fp32. + assert mean_abs_err < 1e-4, "Sparse and dense results differ!" + + print("[TEST] Passed – sparse attention matches dense attention when all tokens are attended.") \ No newline at end of file diff --git a/sparse_attention_hub/sparse_attention/efficient_attention/sparse_attention_backend.py b/sparse_attention_hub/sparse_attention/efficient_attention/sparse_attention_backend.py new file mode 100644 index 00000000..6086dd21 --- /dev/null +++ b/sparse_attention_hub/sparse_attention/efficient_attention/sparse_attention_backend.py @@ -0,0 +1,393 @@ +import math +from typing import Tuple + +import torch +import triton +import triton.language as tl + +# ------------------------------- +# Kernel: Stage-1 – compute per-block partial results & log-sum-exp +# ------------------------------- + +@triton.jit +def _fwd_kernel_sparse_decode_stage1( + Q, # [B, H, D] + K, # [B, Kv, S, D] + V, # [B, Kv, S, D] + sm_scale, # scalar + Sparse_List, # [B, H, S] + Sparse_Len, # [B, H] – seq length per (b, h) + Mid_O, # [B, H, seq_block_num, D] + Mid_O_LogExpSum, # [B, H, seq_block_num] + # strides – note that all strides are in *elements* + stride_sparse_b, + stride_sparse_h, + stride_qbs, + stride_qh, + stride_qd, + stride_kbb, # K.stride(0) – batch + stride_kh, # K.stride(1) – kv head + stride_ks, # K.stride(2) – seq + stride_vbb, + stride_vh, + stride_vs, + stride_splen_b, + stride_splen_h, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + gqa_group_size: tl.constexpr, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """Each program instance processes (b, h, seq_block). + + Within the sequence block (<= BLOCK_SEQ tokens) we iterate in tiles of + BLOCK_N tokens to compute numerically-stable softmax partials akin to + Flash-Attention / Flash-Decode stage-1. 
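+
+    Per program we keep a running max ``m``, a running denominator ``s`` and an
+    unnormalised accumulator ``acc``. For each tile with logits ``x_i`` and values
+    ``v_i`` the online-softmax update is::
+
+        m'   = max(m, max_i x_i)
+        acc' = acc * exp(m - m') + sum_i exp(x_i - m') * v_i
+        s'   = s   * exp(m - m') + sum_i exp(x_i - m')
+
+    The block finally stores ``acc / s`` into ``Mid_O`` and the log-sum-exp
+    ``m + log(s)`` into ``Mid_O_LogExpSum`` so that stage-2 can merge blocks exactly.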
+ """ + + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + seq_start_block = tl.program_id(2) + + cur_kv_head = cur_head // gqa_group_size # shared key/value head + + offs_d = tl.arange(0, BLOCK_DMODEL) + + # Sequence length for (b, h) + cur_seq_len_ptr = Sparse_Len + cur_batch * stride_splen_b + cur_head * stride_splen_h + cur_seq_len = tl.load(cur_seq_len_ptr) + + # Start / end position (in sparse_list) of this sequence block + cur_block_start = seq_start_block * BLOCK_SEQ + cur_block_end = tl.minimum(cur_seq_len, cur_block_start + BLOCK_SEQ) + + # Pointers base for sparse_list of this head + sparse_ptr_base = Sparse_List + cur_batch * stride_sparse_b + cur_head * stride_sparse_h + + # Load query vector (shape [D]) – no sequence dim for decode (one query) + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + q = tl.load(Q + off_q) # [D] + + # Prepare accumulators + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + # Number of micro-blocks within this sequence block + block_n_size = ( + tl.where(cur_block_end - cur_block_start <= 0, 0, + cur_block_end - cur_block_start + BLOCK_N - 1) // BLOCK_N + ) + + offs_n = cur_block_start + tl.arange(0, BLOCK_N) + + for start_n in range(0, block_n_size, 1): + offs_n_new = start_n * BLOCK_N + offs_n # absolute positions inside sparse_list + + # Load token indices for these positions – mask out-of-range + token_idx = tl.load( + sparse_ptr_base + offs_n_new, + mask=offs_n_new < cur_seq_len, + other=0, + ) # [BLOCK_N] + + # Build pointer to K/V: token_idx is [n] so broadcast with d + base_ptr = cur_batch * stride_kbb + cur_kv_head * stride_kh + off_k = base_ptr + token_idx[:, None] * stride_ks + offs_d[None, :] + # Note: stride_kbs == K.stride(2) because K is [B, Kv, S, D] and we want S dimension + k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_seq_len, other=0.0) + v = tl.load(V + off_k, mask=offs_n_new[:, None] < cur_seq_len, other=0.0) + + # Attention scores + att_value = tl.sum(q[None, :] * k, 1) # [BLOCK_N] + att_value *= sm_scale + att_value = tl.where(offs_n_new < cur_seq_len, att_value, float("-inf")) + + # Numerically-stable softmax merge + cur_max_logic = tl.max(att_value, axis=0) + new_max_logic = tl.maximum(cur_max_logic, max_logic) + + exp_logic = tl.exp(att_value - new_max_logic) + logic_scale = tl.exp(max_logic - new_max_logic) + + acc *= logic_scale + acc += tl.sum(exp_logic[:, None] * v, axis=0) + + sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0) + max_logic = new_max_logic + + # Decide whether to store (skip if sequence length 0) + need_store = tl.where(block_n_size == 0, 0, 1) + for _ in range(0, need_store, 1): + off_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + seq_start_block * stride_mid_os + + offs_d + ) + off_mid_o_logexpsum = ( + cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block + ) + tl.store(Mid_O + off_mid_o, acc / sum_exp) + tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp)) + + +# ------------------------------- +# Kernel: Stage-2 – reduce across sequence blocks +# identical logic to flash-decode stage-2 +# ------------------------------- + +@triton.jit +def _fwd_kernel_sparse_decode_stage2( + Sparse_Len, # [B, H] + Mid_O, # [B, H, seq_block_num, D] + Mid_O_LogExpSum, # [B, H, seq_block_num] + O, # [B, H, D] + stride_splen_b, + stride_splen_h, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + 
stride_mid_o_es, + stride_obs, + stride_oh, + stride_od, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + """Second stage reduction over sequence blocks (identical to Flash-Decode).""" + + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + offs_d = tl.arange(0, BLOCK_DMODEL) + + # Sequence length for (b,h) + cur_seq_len_ptr = Sparse_Len + cur_batch * stride_splen_b + cur_head * stride_splen_h + cur_seq_len = tl.load(cur_seq_len_ptr) + + # Number of blocks covering this sequence + block_n_size = ( + tl.where(cur_seq_len <= 0, 0, cur_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + ) + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + # Precompute starting offsets into Mid tensors + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + + for block_seq_n in range(0, block_n_size, 1): + tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) + tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) + + new_max_logic = tl.maximum(tlogic, max_logic) + + old_scale = tl.exp(max_logic - new_max_logic) + acc *= old_scale + exp_logic = tl.exp(tlogic - new_max_logic) + acc += exp_logic * tv + sum_exp = sum_exp * old_scale + exp_logic + max_logic = new_max_logic + + # Write output + off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d + tl.store(O + off_o, acc / sum_exp) + + +# ------------------------------- +# Python helper functions +# ------------------------------- + + +@torch.no_grad() +def sparse_decode_stage1( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sparse_list: torch.Tensor, + sparse_len: torch.Tensor, + max_len_in_batch: int, + mid_out: torch.Tensor, + mid_out_logsumexp: torch.Tensor, + block_seq: int, +): + BLOCK_SEQ = block_seq + BLOCK_N = 16 + + Lq, Lk = q.shape[-1], k.shape[-1] + assert Lq == Lk + assert Lk in {16, 32, 64, 128} + + sm_scale = 1.0 / math.sqrt(Lk) + + batch, head_num = q.shape[0], q.shape[1] + grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) + + gqa_group_size = head_num // k.shape[1] + + _fwd_kernel_sparse_decode_stage1[grid]( + q, + k, + v, + sm_scale, + sparse_list, + sparse_len, + mid_out, + mid_out_logsumexp, + sparse_list.stride(0), + sparse_list.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), # stride over B (kbb) + k.stride(1), + k.stride(2), # stride over S (ks) + v.stride(0), # stride over B (vbb) + v.stride(1), + v.stride(2), # stride over S (vs) + sparse_len.stride(0), + sparse_len.stride(1), + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + mid_out.stride(3), + mid_out_logsumexp.stride(0), + mid_out_logsumexp.stride(1), + mid_out_logsumexp.stride(2), + gqa_group_size, + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK_N, + num_warps=1, + num_stages=2, + ) + + +@torch.no_grad() +def sparse_decode_stage2( + mid_out: torch.Tensor, + mid_out_logsumexp: torch.Tensor, + sparse_len: torch.Tensor, + O: torch.Tensor, + block_seq: int, +): + Lk = mid_out.shape[-1] + assert Lk in {16, 32, 64, 128} + + batch, head_num = mid_out.shape[0], mid_out.shape[1] + grid = (batch, head_num) + + _fwd_kernel_sparse_decode_stage2[grid]( + sparse_len, + mid_out, + mid_out_logsumexp, + O, + sparse_len.stride(0), + sparse_len.stride(1), + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + mid_out.stride(3), + mid_out_logsumexp.stride(0), + mid_out_logsumexp.stride(1), + mid_out_logsumexp.stride(2), + O.stride(0), + 
O.stride(1), + O.stride(2), + BLOCK_SEQ=block_seq, + BLOCK_DMODEL=Lk, + num_warps=4, + num_stages=2, + ) + + +def sparse_attention_fwd( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sparse_list: torch.Tensor, + sparse_len: torch.Tensor, + block_seq: int = 256, +) -> torch.Tensor: + """Triton-accelerated sparse attention (flash-decode style). + + Args follow the same convention as the reference implementation. + Returns: Tensor [B, H, D]. + """ + + assert query.is_cuda and key.is_cuda and value.is_cuda and sparse_list.is_cuda + + B, H, D = query.shape + max_len_in_batch = int(sparse_len.max().item()) + + # Allocate intermediate + output + block_seq_num = (max_len_in_batch + block_seq - 1) // block_seq + mid_o = torch.empty((B, H, block_seq_num, D), dtype=torch.float32, device=query.device) + mid_o_log = torch.empty((B, H, block_seq_num), dtype=torch.float32, device=query.device) + out = torch.empty((B, H, D), dtype=query.dtype, device=query.device) + + sparse_decode_stage1(query, key, value, sparse_list, sparse_len, max_len_in_batch, mid_o, mid_o_log, block_seq) + sparse_decode_stage2(mid_o, mid_o_log, sparse_len, out, block_seq) + + return out + + +# ------------------------------- +# Quick correctness test vs reference implementation +# ------------------------------- + + +if __name__ == "__main__": + from ref_sparse_attention_backend import ref_sparse_attention_fwd + + torch.manual_seed(0) + + B, H, D, S = 32, 32, 128, 4096 + gqa_group_size = 4 + Kv = H // gqa_group_size + + dtype = torch.float16 + + q = torch.randn(B, H, D, device="cuda", dtype=dtype) + k = torch.randn(B, Kv, S, D, device="cuda", dtype=dtype) + v = torch.randn(B, Kv, S, D, device="cuda", dtype=dtype) + + # Build random sparse pattern + sparse_list = torch.randint(0, S, (B, H, S), device="cuda", dtype=torch.int32) + sparse_len = torch.randint(1, S + 1, (B, H), device="cuda", dtype=torch.int32) + + print(sparse_list[:5, :5, :10]) + print(sparse_len[:5, :5]) + + # Ensure first part of list are unique indices < S (for fairness) – we'll do simple + for b in range(B): + for h in range(H): + perm = torch.randperm(S, device="cuda") + sparse_list[b, h] = perm + sparse_len[b, h] = torch.randint(1, S + 1, (1,), device="cuda", dtype=torch.int32) + + out_ref = ref_sparse_attention_fwd(q, k, v, sparse_list, sparse_len) + out_triton = sparse_attention_fwd(q, k, v, sparse_list, sparse_len) + + print(f"{out_ref[0, :5, :5]=}") + print(f"{out_triton[0, :5, :5]=}") + + max_err = (out_ref - out_triton).abs().max().item() + mean_err = (out_ref - out_triton).abs().mean().item() + print(f"[SPARSE ATTENTION TEST] max|ref - triton| = {max_err:.6e}") + print(f"[SPARSE ATTENTION TEST] mean|ref - triton| = {mean_err:.6e}") + assert mean_err < 1e-4, "Triton sparse attention does not match reference!" + print("[SPARSE ATTENTION TEST] Passed!")
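+
+    # Optional follow-up check (a minimal sketch; ``full_list`` / ``full_len`` are ad-hoc
+    # names for this snippet): when every token is attended, the Triton kernel should
+    # also match the dense reference from ref_sparse_attention_backend.
+    from ref_sparse_attention_backend import ref_dense_attention_fwd
+
+    full_list = torch.arange(S, device="cuda", dtype=torch.int32).view(1, 1, S).repeat(B, H, 1)
+    full_len = torch.full((B, H), S, dtype=torch.int32, device="cuda")
+    out_full = sparse_attention_fwd(q, k, v, full_list, full_len)
+    out_dense = ref_dense_attention_fwd(q, k, v)
+    print(f"[FULL COVERAGE TEST] mean|triton - dense| = {(out_full - out_dense).abs().mean().item():.6e}")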