
Commit d9a4d5a

Removes dropout functionality from backward kernel
Eliminates dropout-related template parameters, includes, and logic throughout the backward computation kernel, replacing the dropout-aware scaling parameters with the plain softmax scale and dropping the conditional dropout application during score computation. Adds support for mask and bias by introducing new global-to-shared memory copy operations and tensor partitioning for these inputs. The result removes the dropout branching while keeping the core attention backward pass intact.
1 parent 3bb3774 commit d9a4d5a
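For orientation, the computation this kernel tiles and fuses is the standard attention backward pass with an additive bias and a boolean keep-mask. Below is a minimal single-head reference in plain C++ (shapes, layout, and the point at which the softmax scale meets the bias are assumptions for illustration, not taken from this kernel); it shows the quantities the kernel accumulates as acc_dv, acc_dk, and acc_dq, the D = rowsum(dO * O) term accumulated by dot_do_o, and the dS = P * (dP - D) step implemented by pointwise_mult.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Naive single-head reference (hypothetical shapes): Q, dQ are M x D; K, V, dK, dV are N x D;
// dO is M x D; bias is M x N (additive); mask is M x N (true = keep). Row-major storage.
void attention_bwd_reference(const float* Q, const float* K, const float* V, const float* dO,
                             const float* bias, const bool* mask,
                             float* dQ, float* dK, float* dV,
                             std::size_t M, std::size_t N, std::size_t D, float scale) {
    const float NEG_INF = -std::numeric_limits<float>::infinity();
    std::vector<float> P(M * N), dS(M * N), O(M * D);

    // Forward recompute: S = scale * Q K^T + bias, masked scores -> -inf, P = softmax(S), O = P V.
    for (std::size_t i = 0; i < M; ++i) {
        float row_max = NEG_INF;
        for (std::size_t j = 0; j < N; ++j) {
            float s = 0.f;
            for (std::size_t d = 0; d < D; ++d) s += Q[i * D + d] * K[j * D + d];
            s = mask[i * N + j] ? s * scale + bias[i * N + j] : NEG_INF;
            P[i * N + j] = s;
            row_max = std::max(row_max, s);
        }
        float denom = 0.f;
        for (std::size_t j = 0; j < N; ++j) {
            float e = (P[i * N + j] == NEG_INF) ? 0.f : std::exp(P[i * N + j] - row_max);
            P[i * N + j] = e;
            denom += e;
        }
        if (denom == 0.f) denom = 1.f;  // fully-masked row
        for (std::size_t j = 0; j < N; ++j) P[i * N + j] /= denom;
        for (std::size_t d = 0; d < D; ++d) {
            float o = 0.f;
            for (std::size_t j = 0; j < N; ++j) o += P[i * N + j] * V[j * D + d];
            O[i * D + d] = o;
        }
    }

    // dV = P^T dO (the acc_dv accumulator)
    for (std::size_t j = 0; j < N; ++j)
        for (std::size_t d = 0; d < D; ++d) {
            float acc = 0.f;
            for (std::size_t i = 0; i < M; ++i) acc += P[i * N + j] * dO[i * D + d];
            dV[j * D + d] = acc;
        }

    // D_i = sum_d dO[i,d] * O[i,d] (dot_do_o), dP = dO V^T, dS = P * (dP - D) (pointwise_mult)
    for (std::size_t i = 0; i < M; ++i) {
        float Di = 0.f;
        for (std::size_t d = 0; d < D; ++d) Di += dO[i * D + d] * O[i * D + d];
        for (std::size_t j = 0; j < N; ++j) {
            float dPij = 0.f;
            for (std::size_t d = 0; d < D; ++d) dPij += dO[i * D + d] * V[j * D + d];
            dS[i * N + j] = P[i * N + j] * (dPij - Di);
        }
    }

    // dQ = scale * dS K (acc_dq) and dK = scale * dS^T Q (acc_dk); with dropout removed,
    // the softmax scale is the only factor applied in the epilogue.
    for (std::size_t i = 0; i < M; ++i)
        for (std::size_t d = 0; d < D; ++d) {
            float acc = 0.f;
            for (std::size_t j = 0; j < N; ++j) acc += dS[i * N + j] * K[j * D + d];
            dQ[i * D + d] = acc * scale;
        }
    for (std::size_t j = 0; j < N; ++j)
        for (std::size_t d = 0; d < D; ++d) {
            float acc = 0.f;
            for (std::size_t i = 0; i < M; ++i) acc += dS[i * N + j] * Q[i * D + d];
            dK[j * D + d] = acc * scale;
        }
}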

File tree

1 file changed: +67, -61 lines

csrc/src/flash_bwd_kernel.h

Lines changed: 67 additions & 61 deletions
@@ -16,7 +16,6 @@
 #include "utils.h"
 #include "softmax.h"
 #include "mask.h"
-#include "dropout.h"
 
 namespace FLASH_NAMESPACE {
 
@@ -75,7 +74,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
 inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
 
     using Element = typename Kernel_traits::Element;
@@ -273,9 +272,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         typename Kernel_traits::SmemLayoutdQ{}
     );
 
-
+    // Global to Shared Memory operation
     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask;
+    auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias;
+    auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx);
     using GmemTiledCopydO = std::conditional_t<
         Is_first,
         typename Kernel_traits::GmemTiledCopydO,
@@ -298,11 +301,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
     Tensor tdOsdO = gmem_thr_copy_dO.partition_D(sdO);
     Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
-    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
+    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);    // (KCPY, KCPY_N, KCPY_K)
     Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
-    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
+    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);    // (VCPY, VCPY_N, VCPY_K)
     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
-    Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
+    Tensor tMaskgMask = gmem_thr_copy_Mask.partition_S(gMask);    // (MaskCPY, MaskCPY_M, MaskCPY_N)
+    Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask);
+    Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias);    // (BiasCPY, BiasCPY_M, BiasCPY_N)
+    Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias);
+    Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);    // ((Atom,AtomNum),ATOM_M,ATOM_N)
     Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
     Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
     // if (cute::thread0()) { print(tdQgdQaccum.layout()); printf("\n"); }
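The new GmemTiledCopyMask and GmemTiledCopyBias tiled copies stage the (kBlockM x kBlockN) mask and bias tiles into shared memory alongside K and V. As a rough CuTe-free illustration of what such a predicated global-to-shared copy amounts to (tile sizes, row-major layout, and zero-fill for out-of-range elements are assumptions for this sketch, not this kernel's code):

// Illustrative CUDA device helper: the threads of a block cooperatively copy one
// BLK_M x BLK_N tile starting at (row0, col0) from a row-major global array with leading
// dimension ld into shared memory, predicating on the real extents (rows, cols).
template <int BLK_M, int BLK_N, typename T>
__device__ void copy_tile_g2s(const T* __restrict__ g, T* __restrict__ s,
                              int row0, int col0, int rows, int cols, int ld) {
    for (int idx = threadIdx.x; idx < BLK_M * BLK_N; idx += blockDim.x) {
        const int r = idx / BLK_N;
        const int c = idx % BLK_N;
        const bool in_bounds = (row0 + r) < rows && (col0 + c) < cols;
        s[idx] = in_bounds ? g[(row0 + r) * ld + (col0 + c)] : T(0);
    }
}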
@@ -311,32 +318,32 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // printf("tidx = %d, tdQgdQaccum = 0x%p\n", tidx, tdQgdQaccum.data());
     // }
 
+    // Matrix Multiply Accumulate
     typename Kernel_traits::TiledMmaSdP tiled_mma_sdp;
     auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx);
-    Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ);       // (MMA,MMA_N,MMA_K)
-    Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK);       // (MMA,MMA_N,MMA_K)
-    Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO);    // (MMA,MMA_N,MMA_K)
-    Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV);      // (MMA,MMA_N,MMA_K)
+    Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ);       // (MMA, MMA_N, MMA_K)
+    Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK);       // (MMA, MMA_N, MMA_K)
+    Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO);    // (MMA, MMA_N, MMA_K)
+    Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV);      // (MMA, MMA_N, MMA_K)
+    Tensor tSrMask = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});    // (MMA, MMA_M, MMA_N)
+    Tensor tSrBias = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});    // (MMA, MMA_M, MMA_N)
 
     typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
     auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx);
-    Tensor tdKrdSt = thr_mma_dkv.partition_fragment_A(sdStNoSwizzle);           // (MMA, MMA_N, MMA_N)
-    Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle);             // (MMA, MMA_K, MMA_N)
-    Tensor tdVrPt = thr_mma_dkv.partition_fragment_A(sPtNoSwizzle);             // (MMA, MMA_N, MMA_N)
-    Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtransposedNoSwizzle);   // (MMA, MMA_K, MMA_N)
+    Tensor tdKrdSt = thr_mma_dkv.partition_fragment_A(sdStNoSwizzle);            // (MMA, MMA_N, MMA_N)
+    Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle);              // (MMA, MMA_K, MMA_N)
+    Tensor tdVrPt = thr_mma_dkv.partition_fragment_A(sPtNoSwizzle);              // (MMA, MMA_N, MMA_N)
+    Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtransposedNoSwizzle);    // (MMA, MMA_K, MMA_N)
 
     typename Kernel_traits::TiledMmadQ tiled_mma_dq;
     auto thr_mma_dq = tiled_mma_dq.get_thread_slice(tidx);
-    Tensor tdQrdS = thr_mma_dq.partition_fragment_A(sdS);            // (MMA, MMA_N, MMA_N)
-    Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle);   // (MMA, MMA_K, MMA_N)
+    Tensor tdQrdS = thr_mma_dq.partition_fragment_A(sdS);             // (MMA, MMA_N, MMA_N)
+    Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle);    // (MMA, MMA_K, MMA_N)
 
-    Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});    // MMA, MMA_N, MMA_K
-    Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});    // MMA, MMA_N, MMA_K
+    Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});    // (MMA, MMA_N, MMA_K)
+    Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});    // (MMA, MMA_N, MMA_K)
 
-    //
     // Copy Atom retiling
-    //
-
     auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
     auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx);
     Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ);
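Note that tSrMask and tSrBias are created with partition_fragment_C over the same (kBlockM, kBlockN) tile that backs the score accumulator, so each thread holds its mask and bias values in the same per-thread register layout as its slice of the scores and can combine them element-wise when the mask is applied.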
@@ -363,6 +370,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // }
     Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS);    // ((Atom,AtomNum),PIPE_M,PIPE_N)
 
+    auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma_sdp);
+    auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx);
+    Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask);
+
+    auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma_sdp);
+    auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
+    Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);
+
     auto smem_tiled_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);
     auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(tidx);
     Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt);
@@ -385,14 +400,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
     Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);    // ((Atom,AtomNum),PIPE_M,PIPE_N)
 
-    //
     // PREDICATES
-    //
-
-    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));     // (BLK_M,BLK_K) -> (blk_m,blk_k)
-    Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));    // (BLK_N,BLK_K) -> (blk_n,blk_k)
+    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));        // (BLK_M,BLK_K) -> (blk_m,blk_k)
+    Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));       // (BLK_N,BLK_K) -> (blk_n,blk_k)
+    Tensor cMask = make_identity_tensor(make_shape(size<0>(sMask), size<1>(sMask)));    // (BLK_M,BLK_N) -> (blk_m,blk_n)
+    Tensor cBias = make_identity_tensor(make_shape(size<0>(sBias), size<1>(sBias)));    // (BLK_M,BLK_N) -> (blk_m,blk_n)
     Tensor tQcQ = gmem_thr_copy_QKV.partition_D(cQ);
     Tensor tKVcKV = gmem_thr_copy_QKV.partition_D(cKV);
+    Tensor tMaskcMask = gmem_thr_copy_Mask.partition_D(cMask);
+    Tensor tBiascBias = gmem_thr_copy_Bias.partition_D(cBias);
 
     // Allocate predicate tensors for k
     Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
@@ -407,7 +423,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     }
 
     // Prologue
-
     // We'll advance gdQ and gdQaccum before the 1st read/write.
     tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride;
     tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;
@@ -531,7 +546,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (Is_first) {
         cute::copy(tdOrdO, tdOsdO);
         dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
-                                                    Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
+                                                    Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), 0.0f);
     }
 
     if (Kernel_traits::Is_V_in_regs) {
@@ -542,9 +557,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
     }
 
-    FLASH_NAMESPACE::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
-                                     bidb, bidh, tidx, params.h);
-
     clear(acc_dv);
     clear(acc_dk);
 
@@ -608,24 +620,23 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                                               AtomLayoutMS * 16);
             }
         }
+        FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal>(
+            scores,
+            mask,
+            bias,
+            params.scale_softmax,
+            n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
+            binfo.actual_seqlen_k,
+            m_block * kBlockM + get<0>(taccScS_row(0)),
+            binfo.actual_seqlen_q,
+            AtomLayoutMS * 16
+        );
 
         // if (cute::thread(32, 0)) { print(scores); }
         // Compute the exponential value.
         FLASH_NAMESPACE::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
-        if constexpr (Is_dropout) {
-            int warp_id = tidx / 32;
-            int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
-            // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
-            static_assert(MMA_N_SdP % 2 == 0);
-            int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
-            dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                acc_s, block_row_idx, block_col_idx, AtomLayoutMS
-            );
-        }
         // Convert scores from fp32 to fp16/bf16
-        Tensor rP = !Is_dropout
-            ? FLASH_NAMESPACE::convert_type<Element>(acc_s)
-            : FLASH_NAMESPACE::convert_type_relu<Element>(acc_s);
+        Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);
         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2)
         // if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8.
         Tensor tPrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMmaSdP>(rP.layout()));
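The apply_mask call above takes the place of the dropout application: it receives the register-resident scores together with the mask and bias tiles, the softmax scale, and the block offsets and actual sequence lengths used for bounds handling. Its exact semantics live in mask.h and are not part of this diff; as a loose, assumption-labeled reference, a mask-plus-bias adjustment of a score tile can look like the following (whether the softmax scale is folded in here or left to the subsequent scale_apply_exp2 is a detail this sketch does not decide):

#include <limits>

// Hypothetical reference, not this fork's apply_mask: add the bias to each kept score and
// force masked-out positions to -infinity so they contribute nothing after exponentiation.
inline void apply_mask_bias_reference(float* scores, const bool* keep, const float* bias,
                                      int rows, int cols) {
    const float kNegInf = -std::numeric_limits<float>::infinity();
    for (int i = 0; i < rows; ++i)
        for (int j = 0; j < cols; ++j) {
            const int idx = i * cols + j;
            scores[idx] = keep[idx] ? scores[idx] + bias[idx] : kNegInf;
        }
}

With the sign-bit dropout encoding gone, the scores can then be converted with a plain convert_type<Element> instead of convert_type_relu, which is exactly the simplification shown in the hunk above.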
@@ -660,7 +671,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N))
         Tensor dS = make_tensor(acc_dp.data(), scores.layout());
         auto pointwise_mult = [](float p, float dp, float d) {
-            return p * (!Is_dropout || p >= 0 ? dp - d : d);
+            return p * (dp - d);
         };
         #pragma unroll
         for (int mi = 0; mi < size<0>(dS); ++mi) {
@@ -757,7 +768,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             }
         } else {
             #pragma unroll
-            for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
+            for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
             // Convert acc_dq from fp32 to fp16
             Tensor rdQ = FLASH_NAMESPACE::convert_type<Element>(acc_dq);
             Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);    // ((Atom,AtomNum), MMA_N, MMA_N)
@@ -781,7 +792,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (Is_first && m_block > m_block_min) {
         cute::copy(tdOrdO, tdOsdO);
         dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
-                                                    Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
+                                                    Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), 0.0f);
     }
 
     if (Is_last) {
@@ -803,12 +814,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
     // Epilogue
 
-    if (Is_dropout) {
-        #pragma unroll
-        for (int i = 0; i < size(acc_dv); ++i) { acc_dv(i) *= params.rp_dropout; }
-    }
     #pragma unroll
-    for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax_rp_dropout; }
+    for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax; }
 
     // Convert acc_dv from fp32 to fp16
     Tensor rdK = FLASH_NAMESPACE::convert_type<Element>(acc_dk);
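The epilogue now multiplies acc_dk (and, in the earlier hunk, acc_dq) by params.scale_softmax where it previously used params.scale_softmax_rp_dropout, and the extra rp_dropout rescale of acc_dv disappears entirely. In upstream FlashAttention those parameters are related through the keep probability, which is why dropping dropout collapses them; the sketch below mirrors that convention and is an assumption about the host-side params setup, not something shown in this diff:

// Hypothetical host-side parameter setup following the upstream FlashAttention convention:
// p_dropout stores the keep probability, rp_dropout its reciprocal, and
// scale_softmax_rp_dropout folds the dropout rescaling into the softmax scale.
struct ScalingParams {
    float scale_softmax;             // e.g. 1 / sqrt(head_dim)
    float p_dropout;                 // keep probability = 1 - dropout_rate
    float rp_dropout;                // 1 / p_dropout
    float scale_softmax_rp_dropout;  // scale_softmax * rp_dropout
};

inline ScalingParams make_scaling_params(float scale_softmax, float dropout_rate) {
    ScalingParams p;
    p.scale_softmax = scale_softmax;
    p.p_dropout = 1.0f - dropout_rate;
    p.rp_dropout = 1.0f / p.p_dropout;
    p.scale_softmax_rp_dropout = p.scale_softmax * p.rp_dropout;
    // With dropout_rate == 0 the two scales coincide, matching the simplified epilogue above.
    return p;
}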
@@ -874,7 +881,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_even_M, bool Is_even_K, typename Params>
 inline __device__ void compute_dq_dk_dv(const Params &params) {
 
     // The block index for the batch.
@@ -888,20 +895,20 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {
 
     const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
     if (n_block_max == 1) {
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, true, true>(params, bidb, bidh, 0);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false, true, true>(params, bidb, bidh, 0);
    } else {
         // Iterating backward from n_block_max - 1 to 0 might save 1 register
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, true, false>(params, bidb, bidh, n_block_max - 1);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false, true, false>(params, bidb, bidh, n_block_max - 1);
         for (int n_block = n_block_max - 2; n_block > 0; n_block--) {
-            compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, false, false>(params, bidb, bidh, n_block);
+            compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false, false, false>(params, bidb, bidh, n_block);
         }
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, false, true>(params, bidb, bidh, 0);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false, false, true>(params, bidb, bidh, 0);
     }
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
 inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
 
     // The block index for the batch.
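In the updated compute_dq_dk_dv dispatch, the explicit false literal fills the Is_softcap slot of the new compute_dq_dk_dv_1colblock parameter list, and the two booleans that follow are Is_first and Is_last; the slot previously occupied by Is_dropout is simply gone. The sequence-parallel path below keeps forwarding its own Is_softcap template parameter unchanged.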
@@ -911,9 +918,8 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
 
     // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
     for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Is_softcap, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
     }
 }
 
-////////////////////////////////////////////////////////////////////////////////////////////////////
-} // namespace flash
+} // namespace FLASH_NAMESPACE
