@@ -250,8 +250,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
250250 Tensor tSrQ = thr_mma.partition_fragment_A (sQ ); // (MMA, MMA_M, MMA_K)
251251 Tensor tSrK = thr_mma.partition_fragment_B (sK ); // (MMA, MMA_N, MMA_K)
252252 Tensor tOrVt = thr_mma.partition_fragment_B (sVtNoSwizzle ); // (MMA, MMA_K, MMA_N)
253- Tensor tSrMask = partition_fragment_C (tiled_mma, Shape<Int<kBlockM >, Int<kBlockN >>{}); // (MMA, MMA_M, MMA_N)
254- Tensor tSrBias = partition_fragment_C (tiled_mma, Shape<Int<kBlockM >, Int<kBlockN >>{}); // (MMA, MMA_M, MMA_N)
253+ // Tensor tSrMask = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
254+ // Tensor tSrBias = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
255255 Tensor tSgS = thr_mma.partition_C (gP );
256256 Tensor acc_o = partition_fragment_C (tiled_mma, Shape<Int<kBlockM >, Int<kHeadDim >>{}); // (MMA, MMA_M, MMA_K)
257257
@@ -405,6 +405,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
405405 __syncthreads ();
406406
407407 // Copy Mask and Bias from smem to registers
408+ Tensor tSrMask = make_tensor<Element>(shape (acc_s));
409+ Tensor tSrBias = make_tensor<Element>(shape (acc_s));
408410 Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D (tSrMask);
409411 cute::copy (smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
410412 Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D (tSrBias);
@@ -515,6 +517,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
515517 __syncthreads ();
516518
517519 // Copy Mask and Bias from smem to registers
520+ Tensor tSrMask = make_tensor<Element>(shape (acc_s));
521+ Tensor tSrBias = make_tensor<Element>(shape (acc_s));
518522 Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D (tSrMask);
519523 cute::copy (smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
520524 Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D (tSrBias);
@@ -873,8 +877,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
873877 Tensor tSrQ = thr_mma.partition_fragment_A (sQ ); // (MMA, MMA_M, MMA_K)
874878 Tensor tSrK = thr_mma.partition_fragment_B (sK ); // (MMA, MMA_N, MMA_K)
875879 Tensor tOrVt = thr_mma.partition_fragment_B (sVtNoSwizzle ); // (MMA, MMA_K, MMA_N)
876- Tensor tSrMask = partition_fragment_C (tiled_mma, Shape<Int<kBlockM >, Int<kBlockN >>{}); // (MMA, MMA_M, MMA_N)
877- Tensor tSrBias = partition_fragment_C (tiled_mma, Shape<Int<kBlockM >, Int<kBlockN >>{}); // (MMA, MMA_M, MMA_N)
880+ // Tensor tSrMask = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
881+ // Tensor tSrBias = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
878882 Tensor acc_o = partition_fragment_C (tiled_mma, Shape<Int<kBlockM >, Int<kHeadDim >>{}); // (MMA, MMA_M, MMA_K)
879883
880884 // Copy Atom retiling
@@ -989,6 +993,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
989993 __syncthreads ();
990994
991995 // Copy Mask and Bias from smem to registers
996+ Tensor tSrMask = make_tensor<Element>(shape (acc_s));
997+ Tensor tSrBias = make_tensor<Element>(shape (acc_s));
992998 Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D (tSrMask);
993999 cute::copy (smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
9941000 Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D (tSrBias);
@@ -1119,6 +1125,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
11191125 __syncthreads ();
11201126
11211127 // Copy Mask and Bias from smem to registers
1128+ Tensor tSrMask = make_tensor<Element>(shape (acc_s));
1129+ Tensor tSrBias = make_tensor<Element>(shape (acc_s));
11221130 Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D (tSrMask);
11231131 cute::copy (smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
11241132 Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D (tSrBias);
0 commit comments