Skip to content

Commit 9900589

Browse files
committed
Adds commented tensor declarations for mask and bias
Includes placeholder tensor declarations for future mask and bias support in the backward kernel computation. These commented lines prepare the codebase for upcoming attention mask and bias functionality.
1 parent 7b90c9c commit 9900589

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

csrc/src/flash_bwd_kernel.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
322322
Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK); // (MMA, MMA_N, MMA_K)
323323
Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO); // (MMA, MMA_N, MMA_K)
324324
Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV); // (MMA, MMA_N, MMA_K)
325+
// Tensor tSrMask = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
326+
// Tensor tSrBias = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
325327

326328
typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
327329
auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx);

0 commit comments

Comments
 (0)