@@ -304,7 +304,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
304304 Tensor tMasksMask = gmem_thr_copy_Mask.partition_D (sMask );
305305 Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S (gBias ); // (BiasCPY, BiasCPY_M, BiasCPY_N)
306306 Tensor tBiassBias = gmem_thr_copy_Bias.partition_D (sBias );
307-
307+ Tensor tdBiasgdBias = gmem_thr_copy_Bias. partition_D (gdBias);
308308
309309 Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S (sdQ); // ((Atom, AtomNum), ATOM_M, ATOM_N)
310310 Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D (gdQ);
@@ -753,33 +753,32 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
753753 const int sQ_offset = m_block % 2 == 0 ? size (sQ ) : -size (sQ );
754754 tQsQ.data () = tQsQ.data () + sQ_offset ;
755755 tSsQ.data () = tSsQ.data () + sQ_offset ;
756- // Advance gQ, gMask, gBias
756+ // Advance gQ
757757 tQgQ.data () = tQgQ.data () + (-int (kBlockM * params.q_row_stride ));
758- tMaskgMask.data () = tMaskgMask.data () + (-int (kBlockM * params.mask_row_stride ));
759- tBiasgBias.data () = tBiasgBias.data () + (-int (kBlockM * params.bias_row_stride ));
760758 FLASH_NAMESPACE::copy</* Is_even_MN=*/ true , Is_even_K>(
761759 gmem_tiled_copy_QKV,
762760 tQgQ, tQsQ,
763761 tQcQ, tQpQ
764762 );
765- FLASH_NAMESPACE::copy_MN<Is_even_MN, /* Clear_OOB_MN=*/ true >(
766- gmem_tiled_copy_Mask,
767- tMaskgMask, tMasksMask,
768- tMaskcMask,
769- binfo.actual_seqlen_q - (m_block - 1 ) * kBlockM , binfo.actual_seqlen_k - n_block * kBlockN
770- );
771- FLASH_NAMESPACE::copy_MN<Is_even_MN, /* Clear_OOB_MN=*/ true >(
772- gmem_tiled_copy_Bias,
773- tBiasgBias, tBiassBias,
774- tBiascBias,
775- binfo.actual_seqlen_q - (m_block - 1 ) * kBlockM , binfo.actual_seqlen_k - n_block * kBlockN
776- );
777763 FLASH_NAMESPACE::cp_async_fence ();
778764 }
779765
780766 Tensor dS_reshaped = make_tensor (dS.data (), acc_dp.layout ());
781767 // Convert dS from fp32 to fp16
782768 Tensor tdSrdS = FLASH_NAMESPACE::convert_type<Element>(dS_reshaped);
769+
770+ // Write tdSrdS to gdBias
771+ Tensor tdBiasrdS = smem_thr_copy_Bias.retile_S (tdSrdS);
772+ cute::copy (smem_tiled_copy_Bias, tdBiasrdS, tSsBias);
773+ __syncthreads ();
774+ FLASH_NAMESPACE::copy_MN<Is_even_MN, /* Clear_OOB_MN=*/ false >(
775+ gmem_tiled_copy_Bias,
776+ tBiassBias, tdBiasgdBias,
777+ tBiascBias,
778+ binfo.actual_seqlen_q - m_block * kBlockM ,
779+ binfo.actual_seqlen_k - n_block * kBlockN
780+ );
781+
783782 // if (cute::thread0()) { print(tPrP); }
784783 Tensor tdSadS = smem_thr_copy_PdS.retile_S (tdSrdS); // ((Atom, AtomNum), MMA_N, MMA_N)
785784 cute::copy (smem_tiled_copy_PdS, tdSadS, tdSsdS);
@@ -838,9 +837,26 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
838837
839838 if (m_block > m_block_min) {
840839 gLSE .data () = gLSE .data () + (-int (kBlockM ));
840+ tMaskgMask.data () = tMaskgMask.data () + (-int (kBlockM * params.mask_row_stride ));
841+ tBiasgBias.data () = tBiasgBias.data () + (-int (kBlockM * params.bias_row_stride ));
842+ tdBiasgdBias.data () = tdBiasgdBias.data () + (-int (kBlockM * params.dbias_row_stride ));
841843 #pragma unroll
842844 for (int mi = 0 ; mi < size (lse); ++mi) { lse (mi) = gLSE (get<0 >(taccScS_row (mi))); }
843845 gdPsum.data () = gdPsum.data () + (-int (kBlockM ));
846+ // Copy the next gMask, gBias tiles (their pointers were already advanced above)
847+ FLASH_NAMESPACE::copy_MN<Is_even_MN, /* Clear_OOB_MN=*/ true >(
848+ gmem_tiled_copy_Mask,
849+ tMaskgMask, tMasksMask,
850+ tMaskcMask,
851+ binfo.actual_seqlen_q - (m_block - 1 ) * kBlockM , binfo.actual_seqlen_k - n_block * kBlockN
852+ );
853+ FLASH_NAMESPACE::copy_MN<Is_even_MN, /* Clear_OOB_MN=*/ true >(
854+ gmem_tiled_copy_Bias,
855+ tBiasgBias, tBiassBias,
856+ tBiascBias,
857+ binfo.actual_seqlen_q - (m_block - 1 ) * kBlockM , binfo.actual_seqlen_k - n_block * kBlockN
858+ );
859+ FLASH_NAMESPACE::cp_async_fence ();
844860 }
845861
846862 if (!Is_last) {
@@ -879,27 +895,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
879895 }
880896 if (!Double_buffer && m_block > m_block_min) {
881897 __syncthreads ();
882- // Advance gQ, gMask, gBias
898+ // Advance gQ
883899 tQgQ.data () = tQgQ.data () + (-int (kBlockM * params.q_row_stride ));
884- tMaskgMask.data () = tMaskgMask.data () + (-int (kBlockM * params.mask_row_stride ));
885- tBiasgBias.data () = tBiasgBias.data () + (-int (kBlockM * params.bias_row_stride ));
886900 FLASH_NAMESPACE::copy</* Is_even_MN=*/ true , Is_even_K>(
887901 gmem_tiled_copy_QKV,
888902 tQgQ, tQsQ,
889903 tQcQ, tQpQ
890904 );
891- FLASH_NAMESPACE::copy_MN<Is_even_MN, /* Clear_OOB_MN=*/ true >(
892- gmem_tiled_copy_Mask,
893- tMaskgMask, tMasksMask,
894- tMaskcMask,
895- binfo.actual_seqlen_q - (m_block - 1 ) * kBlockM , binfo.actual_seqlen_k - n_block * kBlockN
896- );
897- FLASH_NAMESPACE::copy_MN<Is_even_MN, /* Clear_OOB_MN=*/ true >(
898- gmem_tiled_copy_Bias,
899- tBiasgBias, tBiassBias,
900- tBiascBias,
901- binfo.actual_seqlen_q - (m_block - 1 ) * kBlockM , binfo.actual_seqlen_k - n_block * kBlockN
902- );
903905 FLASH_NAMESPACE::cp_async_fence ();
904906 }
905907
0 commit comments