#include "utils.h"
#include "softmax.h"
#include "mask.h"
- #include "dropout.h"

namespace FLASH_NAMESPACE {

@@ -75,7 +74,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,

////////////////////////////////////////////////////////////////////////////////////////////////////

- template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
+ template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {

    using Element = typename Kernel_traits::Element;
@@ -273,9 +272,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
        typename Kernel_traits::SmemLayoutdQ{}
    );

-
+     // Global to Shared Memory operation
    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
+     typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask;
+     auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx);
+     typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias;
+     auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx);
    using GmemTiledCopydO = std::conditional_t<
        Is_first,
        typename Kernel_traits::GmemTiledCopydO,
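// The new gmem tiled copies above stage the (kBlockM, kBlockN) mask and bias tiles from global into
// shared memory alongside Q/K/V. Their trait definitions are not part of this hunk; a minimal sketch
// of what such a trait could look like, assuming the usual FlashAttention kernel_traits pattern
// (GmemLayoutAtom thread layout, 128-bit vectorized loads) rather than this fork's actual code:
//
//     using GmemTiledCopyMask = decltype(make_tiled_copy(
//         Copy_Atom<UniversalCopy<cute::uint128_t>, Element>{},
//         GmemLayoutAtom{},            // thread layout shared with the QKV copies
//         Layout<Shape<_1, _8>>{}));   // each thread moves 8 contiguous values per copy
//
// GmemTiledCopyBias would follow the same shape; the exact copy atom depends on the mask element type.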
@@ -298,11 +301,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
    Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
    Tensor tdOsdO = gmem_thr_copy_dO.partition_D(sdO);
    Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
-     Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);   // (KCPY, KCPY_N, KCPY_K)
+     Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);            // (KCPY, KCPY_N, KCPY_K)
    Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
-     Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);   // (VCPY, VCPY_N, VCPY_K)
+     Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);            // (VCPY, VCPY_N, VCPY_K)
    Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
-     Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);   // ((Atom,AtomNum),ATOM_M,ATOM_N)
+     Tensor tMaskgMask = gmem_thr_copy_Mask.partition_S(gMask);  // (MaskCPY, MaskCPY_M, MaskCPY_N)
+     Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask);
+     Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias);  // (BiasCPY, BiasCPY_M, BiasCPY_N)
+     Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias);
+     Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);          // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
    // if (cute::thread0()) { print(tdQgdQaccum.layout()); printf("\n"); }
@@ -311,32 +318,32 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
    //     printf("tidx = %d, tdQgdQaccum = 0x%p\n", tidx, tdQgdQaccum.data());
    // }

+     // Matrix Multiply Accumulate
    typename Kernel_traits::TiledMmaSdP tiled_mma_sdp;
    auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx);
-     Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ);      // (MMA,MMA_N,MMA_K)
-     Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK);      // (MMA,MMA_N,MMA_K)
-     Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO);   // (MMA,MMA_N,MMA_K)
-     Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV);     // (MMA,MMA_N,MMA_K)
+     Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ);      // (MMA, MMA_N, MMA_K)
+     Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK);      // (MMA, MMA_N, MMA_K)
+     Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO);   // (MMA, MMA_N, MMA_K)
+     Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV);     // (MMA, MMA_N, MMA_K)
+     Tensor tSrMask = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});   // (MMA, MMA_M, MMA_N)
+     Tensor tSrBias = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});   // (MMA, MMA_M, MMA_N)

    typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
    auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx);
-     Tensor tdKrdSt = thr_mma_dkv.partition_fragment_A(sdStNoSwizzle);           // (MMA, MMA_N, MMA_N)
-     Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle);             // (MMA, MMA_K, MMA_N)
-     Tensor tdVrPt = thr_mma_dkv.partition_fragment_A(sPtNoSwizzle);             // (MMA, MMA_N, MMA_N)
-     Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtransposedNoSwizzle);   // (MMA, MMA_K, MMA_N)
+     Tensor tdKrdSt = thr_mma_dkv.partition_fragment_A(sdStNoSwizzle);           // (MMA, MMA_N, MMA_N)
+     Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle);             // (MMA, MMA_K, MMA_N)
+     Tensor tdVrPt = thr_mma_dkv.partition_fragment_A(sPtNoSwizzle);             // (MMA, MMA_N, MMA_N)
+     Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtransposedNoSwizzle);   // (MMA, MMA_K, MMA_N)

    typename Kernel_traits::TiledMmadQ tiled_mma_dq;
    auto thr_mma_dq = tiled_mma_dq.get_thread_slice(tidx);
-     Tensor tdQrdS = thr_mma_dq.partition_fragment_A(sdS);            // (MMA, MMA_N, MMA_N)
-     Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle);   // (MMA, MMA_K, MMA_N)
+     Tensor tdQrdS = thr_mma_dq.partition_fragment_A(sdS);            // (MMA, MMA_N, MMA_N)
+     Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle);   // (MMA, MMA_K, MMA_N)

-     Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});   // MMA, MMA_N, MMA_K
-     Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});   // MMA, MMA_N, MMA_K
+     Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});   // (MMA, MMA_N, MMA_K)
+     Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});   // (MMA, MMA_N, MMA_K)

-     //
    // Copy Atom retiling
-     //
-
    auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
    auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx);
    Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ);
@@ -363,6 +370,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
    // }
    Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS);   // ((Atom,AtomNum),PIPE_M,PIPE_N)

+     auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma_sdp);
+     auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx);
+     Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask);
+
+     auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma_sdp);
+     auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
+     Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);
+
    auto smem_tiled_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);
    auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(tidx);
    Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt);
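// make_tiled_copy_C retiles the smem copy atom to match the C (accumulator) partitioning of
// tiled_mma_sdp, so the mask and bias tiles read through tSsMask / tSsBias line up element-for-element
// with the (MMA, MMA_M, MMA_N) register fragments tSrMask / tSrBias declared earlier. This mirrors how
// make_tiled_copy_A / make_tiled_copy_B are used for the Q/dO and K/V operands; only the operand role differs.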
@@ -385,14 +400,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
    auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
    Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);   // ((Atom,AtomNum),PIPE_M,PIPE_N)

-     //
    // PREDICATES
-     //
-
-     Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
-     Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));   // (BLK_N,BLK_K) -> (blk_n,blk_k)
+     Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));            // (BLK_M,BLK_K) -> (blk_m,blk_k)
+     Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));           // (BLK_N,BLK_K) -> (blk_n,blk_k)
+     Tensor cMask = make_identity_tensor(make_shape(size<0>(sMask), size<1>(sMask)));   // (BLK_M,BLK_N) -> (blk_m,blk_n)
+     Tensor cBias = make_identity_tensor(make_shape(size<0>(sBias), size<1>(sBias)));   // (BLK_M,BLK_N) -> (blk_m,blk_n)
    Tensor tQcQ = gmem_thr_copy_QKV.partition_D(cQ);
    Tensor tKVcKV = gmem_thr_copy_QKV.partition_D(cKV);
+     Tensor tMaskcMask = gmem_thr_copy_Mask.partition_D(cMask);
+     Tensor tBiascBias = gmem_thr_copy_Bias.partition_D(cBias);

    // Allocate predicate tensors for k
    Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
@@ -407,7 +423,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
    }

    // Prologue
-
    // We'll advance gdQ and gdQaccum before the 1st read/write.
    tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride;
    tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;
@@ -531,7 +546,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
    if (Is_first) {
        cute::copy(tdOrdO, tdOsdO);
        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
-                                                     Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
+                                                     Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), 0.0f);
    }

    if (Kernel_traits::Is_V_in_regs) {
@@ -542,9 +557,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
        cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
    }

-     FLASH_NAMESPACE::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
-                                      bidb, bidh, tidx, params.h);
-
    clear(acc_dv);
    clear(acc_dk);

@@ -608,24 +620,23 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                AtomLayoutMS * 16);
            }
        }
+         FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal>(
+             scores,
+             mask,
+             bias,
+             params.scale_softmax,
+             n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
+             binfo.actual_seqlen_k,
+             m_block * kBlockM + get<0>(taccScS_row(0)),
+             binfo.actual_seqlen_q,
+             AtomLayoutMS * 16
+         );

        // if (cute::thread(32, 0)) { print(scores); }
        // Compute the exponential value.
        FLASH_NAMESPACE::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
-         if constexpr (Is_dropout) {
-             int warp_id = tidx / 32;
-             int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
-             // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
-             static_assert(MMA_N_SdP % 2 == 0);
-             int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
-             dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                 acc_s, block_row_idx, block_col_idx, AtomLayoutMS
-             );
-         }
        // Convert scores from fp32 to fp16/bf16
-         Tensor rP = !Is_dropout
-             ? FLASH_NAMESPACE::convert_type<Element>(acc_s)
-             : FLASH_NAMESPACE::convert_type_relu<Element>(acc_s);
+         Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);
        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2)
        // if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8.
        Tensor tPrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMmaSdP>(rP.layout()));
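// Note on the removed convert_type_relu path: with dropout enabled, dropped entries of acc_s were
// encoded via the sign bit, and convert_type_relu zeroed them during the fp32 -> fp16/bf16 conversion
// so that the P used for dV = P^T dO contained true zeros. Without dropout nothing is encoded in the
// sign bit, so the plain convert_type in the single-line replacement above is sufficient.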
@@ -660,7 +671,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
        // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N))
        Tensor dS = make_tensor(acc_dp.data(), scores.layout());
        auto pointwise_mult = [](float p, float dp, float d) {
-             return p * (!Is_dropout || p >= 0 ? dp - d : d);
+             return p * (dp - d);
        };
        #pragma unroll
        for (int mi = 0; mi < size<0>(dS); ++mi) {
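// Scalar reference for the simplified lambda (a sketch of the math, not code from this file):
// the softmax backward pass gives dS_ij = P_ij * (dP_ij - D_i), where D_i = sum_j P_ij * dP_ij = dO_i . O_i
// is the per-row value accumulated into gdPsum by dot_do_o. The dropout variant needed the extra branch
// for sign-bit-encoded dropped entries; without dropout it reduces to p * (dp - d) as above.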
@@ -757,7 +768,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
            }
        } else {
            #pragma unroll
-             for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
+             for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
            // Convert acc_dq from fp32 to fp16
            Tensor rdQ = FLASH_NAMESPACE::convert_type<Element>(acc_dq);
            Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);   // ((Atom,AtomNum), MMA_N, MMA_N)
@@ -781,7 +792,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
    if (Is_first && m_block > m_block_min) {
        cute::copy(tdOrdO, tdOsdO);
        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
-                                                     Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
+                                                     Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), 0.0f);
    }

    if (Is_last) {
@@ -803,12 +814,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

    // Epilogue

-     if (Is_dropout) {
-         #pragma unroll
-         for (int i = 0; i < size(acc_dv); ++i) { acc_dv(i) *= params.rp_dropout; }
-     }
    #pragma unroll
-     for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax_rp_dropout; }
+     for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax; }

    // Convert acc_dv from fp32 to fp16
    Tensor rdK = FLASH_NAMESPACE::convert_type<Element>(acc_dk);
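// acc_dk picks up the softmax scale because S = softmax_scale * Q K^T, so dK = softmax_scale * dS^T Q,
// while acc_dv needs no such factor since dV = P^T dO. The old scale_softmax_rp_dropout constant also
// folded in 1 / (1 - p_dropout), which is no longer needed once dropout support is removed, hence the
// plain params.scale_softmax here and the deleted rp_dropout rescale of acc_dv just above.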
@@ -874,7 +881,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

////////////////////////////////////////////////////////////////////////////////////////////////////

- template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K, typename Params>
+ template<typename Kernel_traits, bool Is_causal, bool Is_even_M, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv(const Params &params) {

    // The block index for the batch.
@@ -888,20 +895,20 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {

    const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
    if (n_block_max == 1) {
-         compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, true, true>(params, bidb, bidh, 0);
+         compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false, true, true>(params, bidb, bidh, 0);
    } else {
        // Iterating backward from n_block_max - 1 to 0 might save 1 register
-         compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, true, false>(params, bidb, bidh, n_block_max - 1);
+         compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false, true, false>(params, bidb, bidh, n_block_max - 1);
        for (int n_block = n_block_max - 2; n_block > 0; n_block--) {
-             compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, false, false>(params, bidb, bidh, n_block);
+             compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false, false, false>(params, bidb, bidh, n_block);
        }
-         compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, false, true>(params, bidb, bidh, 0);
+         compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false, false, true>(params, bidb, bidh, 0);
    }
}

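// The device function above is normally wrapped by a thin __global__ kernel; with Is_dropout removed
// the wrapper loses one template parameter. A minimal sketch assuming the upstream FlashAttention
// wrapper name and Flash_bwd_params type (both assumptions here, not taken from this diff):
//
//     template<typename Kernel_traits, bool Is_causal, bool Is_even_M, bool Is_even_K>
//     __global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
//         FLASH_NAMESPACE::compute_dq_dk_dv<Kernel_traits, Is_causal, Is_even_M, Is_even_K>(params);
//     }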
////////////////////////////////////////////////////////////////////////////////////////////////////

- template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
+ template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {

    // The block index for the batch.
@@ -911,9 +918,8 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {

    // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
    for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
-         compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Is_softcap, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+         compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
    }
}

- ////////////////////////////////////////////////////////////////////////////////////////////////////
- }  // namespace flash
+ }  // namespace FLASH_NAMESPACE