Skip to content

Commit a092dfd

Browse files
committed
Adds mask and bias copying in backward kernel
Ensures the mask and bias tensors are copied during the backward pass by adding a `copy_MN` call for each in `compute_dq_dk_dv_1colblock`, both with out-of-bounds clearing enabled (`Clear_OOB_MN=true`).
1 parent 4c3627a commit a092dfd

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

csrc/src/flash_bwd_kernel.h

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -573,6 +573,18 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
573573
tKVcKV, tKVpKV,
574574
binfo.actual_seqlen_k - n_block * kBlockN
575575
);
576+
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
577+
gmem_tiled_copy_Mask,
578+
tMaskgMask, tMasksMask,
579+
tMaskcMask,
580+
binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
581+
);
582+
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
583+
gmem_tiled_copy_Bias,
584+
tBiasgBias, tBiassBias,
585+
tBiascBias,
586+
binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
587+
);
576588
if (!Kernel_traits::Is_V_in_regs) {
577589
FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
578590
gmem_tiled_copy_QKV,

0 commit comments

Comments
 (0)