Skip to content

Commit 4c3627a

Browse files
committed
Adds support for mask and bias tensors in backward kernel
Extends tensor partitioning to include mask and bias identity tensors alongside existing query and key-value tensors. Enables proper handling of attention masks and bias terms during backward pass computation by creating corresponding partitioned tensors with appropriate layouts.
1 parent 0679a60 commit 4c3627a

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

csrc/src/flash_bwd_kernel.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,8 +409,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
409409

410410
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M, BLK_K) -> (blk_m, blk_k)
411411
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N, BLK_K) -> (blk_n, blk_k)
412-
Tensor tQcQ = gmem_thr_copy_QKV.partition_D(cQ);
413-
Tensor tKVcKV = gmem_thr_copy_QKV.partition_D(cKV);
412+
Tensor cMask = make_identity_tensor(make_shape(size<0>(sMask), size<1>(sMask))); // (BLK_M, BLK_N) -> (blk_m, blk_n)
413+
Tensor cBias = make_identity_tensor(make_shape(size<0>(sBias), size<1>(sBias))); // (BLK_M, BLK_N) -> (blk_m, blk_n)
414+
415+
// Repeat the partitioning with identity layouts
416+
Tensor tQcQ = gmem_thr_copy_QKV.partition_D(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k)
417+
Tensor tKVcKV = gmem_thr_copy_QKV.partition_D(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k)
418+
Tensor tMaskcMask = gmem_thr_copy_Mask.partition_D(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n)
419+
Tensor tBiascBias = gmem_thr_copy_Bias.partition_D(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n)
414420

415421
// Allocate predicate tensors for k
416422
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));

0 commit comments

Comments
 (0)