Skip to content

Commit 6d6ab5d

Browse files
committed
Moves tensor declarations to fix scope issues
Moves mask and bias tensor declarations from global scope to local scope within loops where they are used. This fixes potential compilation or runtime issues by ensuring tensors are properly scoped and initialized with the correct dimensions based on the accumulator tensor shape at the point of use.
1 parent 44e4d9f commit 6d6ab5d

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

csrc/src/flash_fwd_kernel.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
250250
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA, MMA_M, MMA_K)
251251
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA, MMA_N, MMA_K)
252252
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K, MMA_N)
253-
Tensor tSrMask = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
254-
Tensor tSrBias = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
253+
// Tensor tSrMask = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
254+
// Tensor tSrBias = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
255255
Tensor tSgS = thr_mma.partition_C(gP);
256256
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (MMA, MMA_M, MMA_K)
257257

@@ -405,6 +405,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
405405
__syncthreads();
406406

407407
// Copy Mask and Bias from smem to registers
408+
Tensor tSrMask = make_tensor<Element>(shape(acc_s));
409+
Tensor tSrBias = make_tensor<Element>(shape(acc_s));
408410
Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
409411
cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
410412
Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
@@ -515,6 +517,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
515517
__syncthreads();
516518

517519
// Copy Mask and Bias from smem to registers
520+
Tensor tSrMask = make_tensor<Element>(shape(acc_s));
521+
Tensor tSrBias = make_tensor<Element>(shape(acc_s));
518522
Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
519523
cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
520524
Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
@@ -873,8 +877,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
873877
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA, MMA_M, MMA_K)
874878
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA, MMA_N, MMA_K)
875879
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K, MMA_N)
876-
Tensor tSrMask = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
877-
Tensor tSrBias = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
880+
// Tensor tSrMask = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
881+
// Tensor tSrBias = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
878882
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (MMA, MMA_M, MMA_K)
879883

880884
// Copy Atom retiling
@@ -989,6 +993,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
989993
__syncthreads();
990994

991995
// Copy Mask and Bias from smem to registers
996+
Tensor tSrMask = make_tensor<Element>(shape(acc_s));
997+
Tensor tSrBias = make_tensor<Element>(shape(acc_s));
992998
Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
993999
cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
9941000
Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
@@ -1119,6 +1125,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
11191125
__syncthreads();
11201126

11211127
// Copy Mask and Bias from smem to registers
1128+
Tensor tSrMask = make_tensor<Element>(shape(acc_s));
1129+
Tensor tSrBias = make_tensor<Element>(shape(acc_s));
11221130
Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
11231131
cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
11241132
Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);

0 commit comments

Comments
 (0)