Skip to content

Commit a56ed00

Browse files
committed
Refactors bias gradient computation and memory access pattern
Moves bias gradient writing to occur immediately after computing the gradient values, improving memory locality and reducing synchronization overhead. Consolidates mask and bias loading operations to occur together in the main loop iteration, eliminating redundant memory access patterns and improving cache efficiency. Adds proper gradient bias tensor partitioning to support the new computation flow.
1 parent b42be8b commit a56ed00

File tree

1 file changed

+33
-31
lines changed

1 file changed

+33
-31
lines changed

csrc/src/flash_bwd_kernel.h

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
304304
Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask);
305305
Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N)
306306
Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias);
307-
307+
Tensor tdBiasgdBias = gmem_thr_copy_Bias.partition_D(gdBias);
308308

309309
Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom, AtomNum), ATOM_M, ATOM_N)
310310
Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
@@ -753,33 +753,32 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
753753
const int sQ_offset = m_block % 2 == 0 ? size(sQ) : -size(sQ);
754754
tQsQ.data() = tQsQ.data() + sQ_offset;
755755
tSsQ.data() = tSsQ.data() + sQ_offset;
756-
// Advance gQ, gMask, gBias
756+
// Advance gQ
757757
tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
758-
tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride));
759-
tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
760758
FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(
761759
gmem_tiled_copy_QKV,
762760
tQgQ, tQsQ,
763761
tQcQ, tQpQ
764762
);
765-
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
766-
gmem_tiled_copy_Mask,
767-
tMaskgMask, tMasksMask,
768-
tMaskcMask,
769-
binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
770-
);
771-
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
772-
gmem_tiled_copy_Bias,
773-
tBiasgBias, tBiassBias,
774-
tBiascBias,
775-
binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
776-
);
777763
FLASH_NAMESPACE::cp_async_fence();
778764
}
779765

780766
Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout());
781767
// Convert dS from fp32 to fp16
782768
Tensor tdSrdS = FLASH_NAMESPACE::convert_type<Element>(dS_reshaped);
769+
770+
// Write tdSrdS to gdBias
771+
Tensor tdBiasrdS = smem_thr_copy_Bias.retile_S(tdSrdS);
772+
cute::copy(smem_tiled_copy_Bias, tdBiasrdS, tSsBias);
773+
__syncthreads();
774+
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
775+
gmem_tiled_copy_Bias,
776+
tBiassBias, tdBiasgdBias,
777+
tBiascBias,
778+
binfo.actual_seqlen_q - m_block * kBlockM,
779+
binfo.actual_seqlen_k - n_block * kBlockN
780+
);
781+
783782
// if (cute::thread0()) { print(tPrP); }
784783
Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom, AtomNum), MMA_N, MMA_N)
785784
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
@@ -838,9 +837,26 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
838837

839838
if (m_block > m_block_min) {
840839
gLSE.data() = gLSE.data() + (-int(kBlockM));
840+
tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride));
841+
tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
842+
tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
841843
#pragma unroll
842844
for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); }
843845
gdPsum.data() = gdPsum.data() + (-int(kBlockM));
846+
// Advance gMask, gBias
847+
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
848+
gmem_tiled_copy_Mask,
849+
tMaskgMask, tMasksMask,
850+
tMaskcMask,
851+
binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
852+
);
853+
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
854+
gmem_tiled_copy_Bias,
855+
tBiasgBias, tBiassBias,
856+
tBiascBias,
857+
binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
858+
);
859+
FLASH_NAMESPACE::cp_async_fence();
844860
}
845861

846862
if (!Is_last) {
@@ -879,27 +895,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
879895
}
880896
if (!Double_buffer && m_block > m_block_min) {
881897
__syncthreads();
882-
// Advance gQ, gMask, gBias
898+
// Advance gQ
883899
tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
884-
tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride));
885-
tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
886900
FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(
887901
gmem_tiled_copy_QKV,
888902
tQgQ, tQsQ,
889903
tQcQ, tQpQ
890904
);
891-
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
892-
gmem_tiled_copy_Mask,
893-
tMaskgMask, tMasksMask,
894-
tMaskcMask,
895-
binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
896-
);
897-
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
898-
gmem_tiled_copy_Bias,
899-
tBiasgBias, tBiassBias,
900-
tBiascBias,
901-
binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
902-
);
903905
FLASH_NAMESPACE::cp_async_fence();
904906
}
905907

0 commit comments

Comments (0)