Skip to content

Commit b42be8b

Browse files
committed
Adds mask and bias handling to backward kernel block loop
Extends the memory advancement logic to include mask and bias tensors alongside the existing query tensor handling. This ensures all relevant tensors are properly synchronized when processing multiple blocks in the backward pass, maintaining consistency across attention computations. The change mirrors the query tensor advancement pattern by updating pointers and copying data for both mask and bias tensors using the same block-based iteration approach.
1 parent 8c5bcff commit b42be8b

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

csrc/src/flash_bwd_kernel.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -879,13 +879,27 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
879879
}
880880
if (!Double_buffer && m_block > m_block_min) {
881881
__syncthreads();
882-
// Advance gQ
882+
// Advance gQ, gMask, gBias
883883
tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
884+
tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride));
885+
tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
884886
FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(
885887
gmem_tiled_copy_QKV,
886888
tQgQ, tQsQ,
887889
tQcQ, tQpQ
888890
);
891+
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
892+
gmem_tiled_copy_Mask,
893+
tMaskgMask, tMasksMask,
894+
tMaskcMask,
895+
binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
896+
);
897+
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
898+
gmem_tiled_copy_Bias,
899+
tBiasgBias, tBiassBias,
900+
tBiascBias,
901+
binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
902+
);
889903
FLASH_NAMESPACE::cp_async_fence();
890904
}
891905

0 commit comments

Comments
 (0)