
Commit 9fa7885

Removes dropout functionality from flash attention kernel
Eliminates dropout-related template parameters, includes, and implementation code throughout the attention computation functions. Simplifies the kernel interface by removing the Is_dropout template parameter and the associated dropout logic, including RNG state management, dropout application during attention computation, and the dropout-specific normalization path. Streamlines the codebase by removing the dependency on ATen CUDA utilities and the dropout/rotary headers that are no longer needed.
1 parent 1cf7fd4 commit 9fa7885
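
With Is_dropout gone from the kernel's template list, the host-side launcher no longer needs a compile-time branch on the dropout probability. Below is a minimal sketch of what that saves, using hypothetical names (run_flash_fwd_sketch, compute_attn_sketch, ParamsSketch) rather than the real launch template, which is not part of this diff:

    // Hedged sketch: each removed compile-time switch halves the number of kernel instantiations
    // the dispatcher selects among. Previously a BOOL_SWITCH-style macro on
    // (params.p_dropout < 1.f) would also have wrapped the calls below to pick Is_dropout.
    #include <cstdio>

    struct ParamsSketch { float p_dropout = 1.0f; };

    template<bool Is_causal, bool Is_even_MN>
    void compute_attn_sketch(const ParamsSketch&) {
        std::printf("compute_attn<Is_causal=%d, Is_even_MN=%d>\n", Is_causal, Is_even_MN);
    }

    template<bool Is_causal>
    void run_flash_fwd_sketch(const ParamsSketch& p, bool even_mn) {
        // No Is_dropout dispatch needed any more; only the remaining flags are switched on.
        if (even_mn) { compute_attn_sketch<Is_causal, true>(p); }
        else         { compute_attn_sketch<Is_causal, false>(p); }
    }

    int main() {
        run_flash_fwd_sketch</*Is_causal=*/true>(ParamsSketch{}, /*even_mn=*/true);
    }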

1 file changed: +7 −49 lines changed
csrc/src/flash_fwd_kernel.h

Lines changed: 7 additions & 49 deletions
@@ -5,7 +5,6 @@
 #pragma once
 
 #include "namespace_config.h"
-#include <ATen/cuda/detail/UnpackRaw.cuh> // For at::cuda::philox::unpack
 
 #include <cute/tensor.hpp>
 
@@ -18,8 +17,6 @@
 #include "utils.h"
 #include "softmax.h"
 #include "mask.h"
-#include "dropout.h"
-#include "rotary.h"
 
 namespace FLASH_NAMESPACE {
 
@@ -47,7 +44,7 @@ __forceinline__ __device__ auto get_lse_tile(const Params &params, const int bid
     return local_tile(mLSE_slice, Shape<Int<kBlockM>>{}, make_coord(m_block));
 }
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
 
     using Element = typename Kernel_traits::Element;
@@ -65,17 +62,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     constexpr int kHeadDim = Kernel_traits::kHeadDim; // head_dim
     constexpr int kNWarps = Kernel_traits::kNWarps;
 
-    auto seed_offset = at::cuda::philox::unpack(params.philox_args);
-    FLASH_NAMESPACE::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
-                                     bidb, bidh, tidx, params.h);
-
-    // Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might
-    // exit early and no one saves the rng states.
-    if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
-        params.rng_state[0] = std::get<0>(seed_offset);
-        params.rng_state[1] = std::get<1>(seed_offset);
-    }
-
     // Check if there are any queries to process in the block
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
@@ -477,20 +463,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
         // Convert acc_s from fp32 to fp16/bf16
         Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);
-        int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
-        int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
-            Tensor rP_drop = make_fragment_like(rP);
-            cute::copy(rP, rP_drop);
-            dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                rP_drop, block_row_idx, block_col_idx, kNWarps
-            );
-            cute::copy(rP_drop, tSgS);
+            cute::copy(rP, tSgS);
             tSgS.data() = tSgS.data() + (-kBlockN);
         }
-        if (Is_dropout) {
-            dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
-        }
 
         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
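
The block removed above was the only consumer of the dropout mask in this loop. For context, here is a hedged, element-wise paraphrase of what apply_dropout with encode_dropout_in_sign_bit did to each probability; the real code operates on cute register fragments, and the helper name below is illustrative only:

    // Sketch: when returning the softmax with dropout enabled, dropped entries were written back
    // negated rather than zeroed, so the sign bit carried the dropout mask while the magnitude
    // still held the softmax probability. After this commit, rP is copied out unmodified.
    inline float encode_dropout_in_sign_bit_sketch(float prob, bool keep) {
        return keep ? prob : -prob;
    }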
@@ -574,20 +550,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
         // Convert acc_s from fp32 to fp16/bf16
         Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);
-        int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
-        int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
-            Tensor rP_drop = make_fragment_like(rP);
-            cute::copy(rP, rP_drop);
-            dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                rP_drop, block_row_idx, block_col_idx, kNWarps
-            );
-            cute::copy(rP_drop, tSgS);
+            cute::copy(rP, tSgS);
             tSgS.data() = tSgS.data() + (-kBlockN);
         }
-        if (Is_dropout) {
-            dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
-        }
 
         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
@@ -603,7 +569,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
     // Epilogue
 
-    Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout);
+    Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax);
 
     // Convert acc_o from fp32 to fp16/bf16
     Tensor rO = FLASH_NAMESPACE::convert_type<Element>(acc_o);
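
The dropped rp_dropout argument was the inverse of the keep probability (1 / (1 − dropout rate)), folded into the epilogue so the expected magnitude of the output matched the no-dropout case. A minimal sketch of the per-row scaling, paraphrased rather than copied from the softmax.h implementation:

    #include <cmath>

    // Sketch: scale applied to one row of acc_o in the epilogue.
    // Before this commit (dropout path): inv_sum * rp_dropout; after: inv_sum only.
    inline float epilogue_scale_sketch(float row_sum) {
        const float inv_sum = (row_sum == 0.f || std::isnan(row_sum)) ? 1.f : 1.f / row_sum;
        return inv_sum;
    }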
@@ -1198,7 +1164,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
     // Epilogue
 
-    Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, params.scale_softmax);
+    Tensor lse = softmax.template normalize_softmax_lse<Split>(acc_o, params.scale_softmax);
     // if (cute::thread0()) { print(lse); }
 
     Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
@@ -1276,23 +1242,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
 inline __device__ void compute_attn(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
     const int bidb = blockIdx.y;
     // The block index for the head.
     const int bidh = blockIdx.z;
 
-    // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
-    // them to have the same number of threads or have to traverse the attention matrix
-    // in the same order.
-    // In the Philox RNG, we use the offset to store the batch, head, and the lane id
-    // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within
-    // the attention matrix. This way, as long as we have the batch, head, and the location of
-    // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
-
-    FLASH_NAMESPACE::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
+    FLASH_NAMESPACE::compute_attn_1rowblock<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
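
For reference, a hedged sketch of a __global__ entry point matching the reduced template list; the name is hypothetical and the real kernel wrapper and launch template are defined elsewhere and untouched by this diff:

    // Sketch only: forwards the reduced template parameter list into compute_attn.
    template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K,
             bool Is_softcap, bool Return_softmax, typename Params>
    __global__ void flash_fwd_kernel_sketch(__grid_constant__ const Params params) {
        FLASH_NAMESPACE::compute_attn<Kernel_traits, Is_causal, Is_even_MN, Is_even_K,
                                      Is_softcap, Return_softmax>(params);
    }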
