@@ -25,7 +25,13 @@ using namespace cute;
 // ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template <typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
-__forceinline__ __device__ auto get_lse_tile(const Params &params, const int bidb, const int bidh, const int m_block, const BlockInfo</*Varlen=*/!Is_even_MN> &binfo) {
+__forceinline__ __device__ auto get_lse_tile(
+    const Params &params,
+    const int bidb,
+    const int bidh,
+    const int m_block,
+    const BlockInfo</*Varlen=*/!Is_even_MN> &binfo
+) {
     // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is the non-variable-seqlen path.
     // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for the seqlen_q <-> h swapping trick.
     // Otherwise, it's written as (h, b, seqlen_q).
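
Note: the three layouts above imply three different linear offsets for the LSE entry of (batch b, head h, query position q). A minimal host-side sketch of that indexing, assuming contiguous row-major strides and hypothetical extents B, H, Q (illustration only, not the kernel's code):

#include <cstddef>

// Hypothetical offset helper mirroring the layout rules in the comments above.
size_t lse_offset(bool unpadded_lse, bool seqlenq_ngroups_swapped,
                  size_t b, size_t h, size_t q,
                  size_t B, size_t H, size_t Q) {
    if (!unpadded_lse) {
        return (b * H + h) * Q + q;   // (b, h, seqlen_q): non-varlen path
    }
    if (seqlenq_ngroups_swapped) {
        return (h * Q + q) * B + b;   // (h, seqlen_q, b): seqlen_q <-> h swap trick
    }
    return (h * B + b) * Q + q;       // (h, b, seqlen_q): default unpadded layout
}
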
@@ -244,8 +250,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor tSrQ = thr_mma.partition_fragment_A(sQ);                            // (MMA, MMA_M, MMA_K)
     Tensor tSrK = thr_mma.partition_fragment_B(sK);                            // (MMA, MMA_N, MMA_K)
     Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);                 // (MMA, MMA_K, MMA_N)
-    Tensor tSrMask = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA, MMA_M, MMA_N)
-    Tensor tSrBias = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA, MMA_M, MMA_N)
+    // Tensor tSrMask = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA, MMA_M, MMA_N)
+    // Tensor tSrBias = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA, MMA_M, MMA_N)
     Tensor tSgS = thr_mma.partition_C(gP);
     Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});   // (MMA, MMA_M, MMA_K)
 
@@ -268,7 +274,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
     Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);
 
+
     // PREDICATES
+
     // // Allocate predicate tensors for m and n
     // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
     // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
@@ -294,9 +302,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);          // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k)
     Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask);   // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n)
     Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias);   // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n)
+
     // Allocate predicate tensors for k
     Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
     Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
+
     // Set predicates for k bounds
     if (!Is_even_K) {
         #pragma unroll
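
Note: the loop bodies fall outside this hunk, but they presumably fill tQpQ and tKVpKV with per-k-slice bounds tests against the head dimension, in the style of upstream FlashAttention (a hedged reconstruction, not part of the diff):

    for (int k = 0; k < size(tQpQ); ++k) {
        tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d;      // Q column in bounds?
    }
    #pragma unroll
    for (int k = 0; k < size(tKVpKV); ++k) {
        tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d;  // K/V column in bounds?
    }

Only the k extent needs predication here; the m/n bounds appear to be handled by the Is_even_MN path inside the copy routine.
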
@@ -309,7 +319,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         }
     }
 
+
     // Prologue
+
     // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
     FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
         gmem_tiled_copy_QKV,
@@ -393,6 +405,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     __syncthreads();
 
     // Copy Mask and Bias from smem to registers
+    Tensor tSrMask = make_tensor<Element>(shape(acc_s));
+    Tensor tSrBias = make_tensor<Element>(shape(acc_s));
     Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
     cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
     Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
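
Note: these register fragments replace the upfront partition_fragment_C allocations commented out earlier. tSrMask and tSrBias are now created per main-loop iteration, shaped like the per-thread accumulator acc_s and typed as Element rather than the accumulator type, then retiled to serve as the destination of the smem-to-register copy. A minimal sketch of the CUTE pattern (assumes kernel scope; names generic):

    Tensor frag = make_tensor<Element>(shape(acc_s));           // registers, same per-thread shape as acc_s
    Tensor frag_view = smem_thr_copy.retile_D(frag);            // view matching the copy atom's value layout
    cute::copy(smem_tiled_copy, smem_partition, frag_view);     // smem -> registers
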
@@ -419,9 +433,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // Use sparse general matrix multiplication
     FLASH_NAMESPACE::sparse_gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
         acc_s,
-        tSrQ,
-        tSrK, tSsQ, tSsK, tSrMask,  // Active key mask for sparse K matrix multiplication
-        tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+        tSrQ, tSrK, tSsQ, tSsK, tSrMask,  // Active key mask for sparse K matrix multiplication
+        tiled_mma,
+        smem_tiled_copy_Q, smem_tiled_copy_K,
         smem_thr_copy_Q, smem_thr_copy_K
     );
     // if (cute::thread0()) { print(acc_s); }
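
Note: as a rough scalar reference for the masked S = Q K^T step above (illustration only; the real sparse_gemm operates on tensor-core fragments and presumably skips work at MMA-tile granularity rather than per element):

#include <vector>

// Hypothetical scalar analogue of a mask-gated QK^T: entries whose mask is
// inactive are skipped; the surrounding masking step is assumed to overwrite
// them before softmax.
void masked_qk(const std::vector<float>& Q, const std::vector<float>& K,
               const std::vector<bool>& active, std::vector<float>& S,
               int M, int N, int D) {
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            if (!active[m * N + n]) { S[m * N + n] = 0.f; continue; }  // skipped entry
            float acc = 0.f;
            for (int d = 0; d < D; ++d) { acc += Q[m * D + d] * K[n * D + d]; }
            S[m * N + n] = acc;
        }
    }
}
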
@@ -483,7 +497,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     FLASH_NAMESPACE::sparse_gemm_rs(
         acc_o,
         tOrP, tOrVt, tOsVt, tSrMask,  // Apply the same mask for sparse V matrix multiplication
-        tiled_mma, smem_tiled_copy_V, smem_thr_copy_V
+        tiled_mma,
+        smem_tiled_copy_V, smem_thr_copy_V
     );
     // if (cute::thread0()) { print(scores); }
 
@@ -502,6 +517,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     __syncthreads();
 
     // Copy Mask and Bias from smem to registers
+    Tensor tSrMask = make_tensor<Element>(shape(acc_s));
+    Tensor tSrBias = make_tensor<Element>(shape(acc_s));
     Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
     cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
     Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
@@ -514,11 +531,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     );
     cute::cp_async_fence();
 
+    // Use sparse general matrix multiplication
     FLASH_NAMESPACE::sparse_gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
         acc_s,
-        tSrQ,
-        tSrK, tSsQ, tSsK, tSrMask,  // Active key mask for sparse K matrix multiplication
-        tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+        tSrQ, tSrK, tSsQ, tSsK, tSrMask,  // Active key mask for sparse K matrix multiplication
+        tiled_mma,
+        smem_tiled_copy_Q, smem_tiled_copy_K,
         smem_thr_copy_Q, smem_thr_copy_K
     );
     if constexpr (Is_softcap){
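
Note: the Is_softcap body falls outside this hunk; upstream FlashAttention applies logit soft-capping at this point, roughly scores = softcap * tanhf(scores / softcap), and the same is presumably done here (an assumption, since the branch body is not shown).
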
@@ -574,10 +592,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     FLASH_NAMESPACE::sparse_gemm_rs(
         acc_o,
         tOrP, tOrVt, tOsVt, tSrMask,  // Apply the same mask for sparse V matrix multiplication
-        tiled_mma, smem_tiled_copy_V, smem_thr_copy_V
+        tiled_mma,
+        smem_tiled_copy_V, smem_thr_copy_V
     );
     }
 
+
     // Epilogue
 
     Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax);
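
Note: normalize_softmax_lse finishes the online softmax. Per row, it folds the running max and running sum into a log-sum-exp value and rescales the output accumulator by the softmax denominator. A scalar sketch of the math, assumed from the standard FlashAttention recurrence (the real routine works on register fragments):

#include <cmath>

// Per-row epilogue: returns lse = max * scale + log(sum) and divides the
// output row by the softmax denominator.
float normalize_row(float* acc_o_row, int d, float row_max, float row_sum,
                    float scale_softmax) {
    float lse = row_max * scale_softmax + std::log(row_sum);
    float inv_sum = 1.0f / row_sum;
    for (int i = 0; i < d; ++i) { acc_o_row[i] *= inv_sum; }  // finish the weighted average
    return lse;
}
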
@@ -857,8 +877,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor tSrQ = thr_mma.partition_fragment_A(sQ);                            // (MMA, MMA_M, MMA_K)
     Tensor tSrK = thr_mma.partition_fragment_B(sK);                            // (MMA, MMA_N, MMA_K)
     Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);                 // (MMA, MMA_K, MMA_N)
-    Tensor tSrMask = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA, MMA_M, MMA_N)
-    Tensor tSrBias = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA, MMA_M, MMA_N)
+    // Tensor tSrMask = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA, MMA_M, MMA_N)
+    // Tensor tSrBias = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA, MMA_M, MMA_N)
     Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});   // (MMA, MMA_M, MMA_K)
 
     // Copy Atom retiling
@@ -878,7 +898,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
     Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);
 
+
     // PREDICATES
+
     // Construct identity layout for sQ and sK
     Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));    // (BLK_M, BLK_K) -> (blk_m, blk_k)
     Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));   // (BLK_N, BLK_K) -> (blk_n, blk_k)
@@ -904,7 +926,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         }
     }
 
+
     // Prologue
+
     // Read Q from gmem to smem
     // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
     FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
@@ -969,6 +993,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     __syncthreads();
 
     // Copy Mask and Bias from smem to registers
+    Tensor tSrMask = make_tensor<Element>(shape(acc_s));
+    Tensor tSrBias = make_tensor<Element>(shape(acc_s));
     Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
     cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
     Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
@@ -1004,9 +1030,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // Use sparse general matrix multiplication
     FLASH_NAMESPACE::sparse_gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
         acc_s,
-        tSrQ,
-        tSrK, tSsQ, tSsK, tSrMask,  // Active key mask for sparse K matrix multiplication
-        tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+        tSrQ, tSrK, tSsQ, tSsK, tSrMask,  // Active key mask for sparse K matrix multiplication
+        tiled_mma,
+        smem_tiled_copy_Q, smem_tiled_copy_K,
         smem_thr_copy_Q, smem_thr_copy_K
     );
     // if (cute::thread0()) { print(acc_s); }
@@ -1080,7 +1106,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     FLASH_NAMESPACE::sparse_gemm_rs(
         acc_o,
         tOrP, tOrVt, tOsVt, tSrMask,  // Apply the same mask for sparse V matrix multiplication
-        tiled_mma, smem_tiled_copy_V, smem_thr_copy_V
+        tiled_mma,
+        smem_tiled_copy_V, smem_thr_copy_V
     );
 
     // This check is at the end of the loop since we always have at least 1 iteration
@@ -1098,6 +1125,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     __syncthreads();
 
     // Copy Mask and Bias from smem to registers
+    Tensor tSrMask = make_tensor<Element>(shape(acc_s));
+    Tensor tSrBias = make_tensor<Element>(shape(acc_s));
     Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
     cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
     Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
@@ -1120,11 +1149,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     );
     cute::cp_async_fence();
 
+    // Use sparse general matrix multiplication
     FLASH_NAMESPACE::sparse_gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
         acc_s,
-        tSrQ,
-        tSrK, tSsQ, tSsK, tSrMask,  // Active key mask for sparse K matrix multiplication
-        tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+        tSrQ, tSrK, tSsQ, tSsK, tSrMask,  // Active key mask for sparse K matrix multiplication
+        tiled_mma,
+        smem_tiled_copy_Q, smem_tiled_copy_K,
         smem_thr_copy_Q, smem_thr_copy_K
     );
     if constexpr (Is_softcap){
@@ -1190,10 +1220,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     FLASH_NAMESPACE::sparse_gemm_rs(
         acc_o,
         tOrP, tOrVt, tOsVt, tSrMask,  // Apply the same mask for sparse V matrix multiplication
-        tiled_mma, smem_tiled_copy_V, smem_thr_copy_V
+        tiled_mma,
+        smem_tiled_copy_V, smem_thr_copy_V
     );
     }
 
+
     // Epilogue
 
     Tensor lse = softmax.template normalize_softmax_lse<Split>(acc_o, params.scale_softmax);