Skip to content

Commit 8c5bcff

Browse files
committed
Adds mask and bias copying to backward kernel loop
Extends the existing query tensor copying logic to also handle mask and bias tensors during backward pass computation. Updates pointer advancement to include mask and bias row strides, ensuring proper memory alignment across iterations. Adds bounds checking for out-of-bounds elements to prevent memory access violations when copying mask and bias data.
1 parent 2a9b919 commit 8c5bcff

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

csrc/src/flash_bwd_kernel.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -753,13 +753,27 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
753753
const int sQ_offset = m_block % 2 == 0 ? size(sQ) : -size(sQ);
754754
tQsQ.data() = tQsQ.data() + sQ_offset;
755755
tSsQ.data() = tSsQ.data() + sQ_offset;
756-
// Advance gQ
756+
// Advance gQ, gMask, gBias
757757
tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
758+
tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride));
759+
tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
758760
FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(
759761
gmem_tiled_copy_QKV,
760762
tQgQ, tQsQ,
761763
tQcQ, tQpQ
762764
);
765+
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
766+
gmem_tiled_copy_Mask,
767+
tMaskgMask, tMasksMask,
768+
tMaskcMask,
769+
binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
770+
);
771+
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
772+
gmem_tiled_copy_Bias,
773+
tBiasgBias, tBiassBias,
774+
tBiascBias,
775+
binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
776+
);
763777
FLASH_NAMESPACE::cp_async_fence();
764778
}
765779

0 commit comments

Comments
 (0)