Commit a5b1b49

Removes dropout support from flash attention kernels
Simplifies the kernel interface by eliminating the Is_dropout template parameter and associated conditional logic throughout the forward pass implementations. Reduces template instantiation complexity and removes branching logic that was previously used to handle dropout variations for different head dimensions. Streamlines kernel dispatch by removing DROPOUT_SWITCH macros and consolidating execution paths that were previously split based on dropout configuration.
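
Note on the removed dispatch macro: DROPOUT_SWITCH follows the BOOL_SWITCH pattern used by this codebase's static-switch helpers, turning the runtime check params.p_dropout < 1.f into a compile-time constant by instantiating the lambda body for both values. The sketch below is an assumption about the typical shape of such a macro, not code taken from this commit; the real definition lives outside the file touched here.

// Hedged sketch of a BOOL_SWITCH-style macro (what DROPOUT_SWITCH is assumed
// to resemble). Illustrative only.
#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...)  \
  [&] {                                            \
    if (COND) {                                    \
      constexpr static bool CONST_NAME = true;     \
      return __VA_ARGS__();                        \
    } else {                                       \
      constexpr static bool CONST_NAME = false;    \
      return __VA_ARGS__();                        \
    }                                              \
  }()

// Old-style call site (the form removed by this commit): the runtime dropout
// check became a compile-time Is_dropout constant inside the lambda.
// BOOL_SWITCH_SKETCH(params.p_dropout < 1.f, Is_dropout, [&] {
//     run_flash_fwd<Traits, Is_dropout, Is_causal>(params, stream);
// });

With Is_dropout removed from flash_fwd_kernel and run_flash_fwd there is nothing left for the switch to dispatch on, so each launcher in the diff collapses to a single unconditional run_flash_fwd call and the kernel is instantiated for half as many boolean combinations.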
1 parent c685a49 commit a5b1b49

File tree: 1 file changed (+57, −87 lines)

csrc/src/flash_fwd_launch_template.h

Lines changed: 57 additions & 87 deletions
@@ -30,9 +30,9 @@ namespace FLASH_NAMESPACE {
     template<typename Kernel_traits, __VA_ARGS__> \
     __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)

-DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
     #if defined(ARCH_SUPPORTS_FLASH)
-        FLASH_NAMESPACE::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
+        FLASH_NAMESPACE::compute_attn<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
@@ -51,7 +51,7 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int L
     FLASH_NAMESPACE::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
 }

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
+template<typename Kernel_traits, bool Is_causal>
 void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     const size_t smem_size = Kernel_traits::kSmemSize;
     // printf("smem_size = %d (includes mask memory)\n", int(smem_size));
@@ -72,9 +72,9 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
     // If return_softmax, set IsEvenMNConst to false to reduce number of templates
     // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-    auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, IsEvenMNConst && IsEvenKConst && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !ReturnSoftmaxConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
+    auto kernel = &flash_fwd_kernel<Kernel_traits, Is_causal, IsEvenMNConst && IsEvenKConst && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !ReturnSoftmaxConst, Is_softcap, ReturnSoftmaxConst && !Is_softcap>;
     // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
-    // printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
+    // printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst));
     // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
     if (smem_size >= 48 * 1024) {
         C10_CUDA_CHECK(cudaFuncSetAttribute(
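
The hunk above is truncated mid-call by the diff context; for reference, the guard it belongs to is the standard CUDA opt-in for more than 48 KB of dynamic shared memory per block. A minimal sketch of that pattern follows; kernel, smem_size, params, and stream mirror the launcher above, while the grid setup, Kernel_traits::kNThreads, and the launch-check line are assumptions about code not shown in this hunk.

// Sketch of the >48 KB dynamic shared-memory opt-in and launch (assumed shape,
// not copied from this hunk).
if (smem_size >= 48 * 1024) {
    C10_CUDA_CHECK(cudaFuncSetAttribute(
        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();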
@@ -162,107 +162,78 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-    });
+    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
 }

 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        if constexpr(!Is_dropout) {
-            // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
-            // Using block size (64 x 128) is 27% slower for seqlen=2k
-            // Using block size (128 x 64) is 85% slower for seqlen=2k, because of register spilling
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
-        } else {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        }
-    });
+    // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
+    // Using block size (64 x 128) is 27% slower for seqlen=2k
+    // Using block size (128 x 64) is 85% slower for seqlen=2k, because of register spilling
+    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_causal>(params, stream);
 }

 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
     bool is_sm8x = cc_major == 8 && cc_minor > 0;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
-        if (is_sm8x) {
-            if constexpr(!Is_causal) {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            }
+    // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+    if (is_sm8x) {
+        if constexpr(!Is_causal) {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
         } else {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
         }
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
-        // These two are always slower
-        // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
-    });
+    } else {
+        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
+    }
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_causal>(params, stream);
+    // These two are always slower
+    // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
 }

 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
     bool is_sm8x = cc_major == 8 && cc_minor > 0;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        if constexpr(!Is_dropout) {
-            // For sm86 or sm89, 64 x 32 (40 KB smem) is the fastest for causal and non-causal since we get 2 CTAs per SM.
-            // Use block configuration (kBlockM = 64, kBlockN = 64) for better memory alignment
-            if (is_sm8x) {
-                if constexpr(!Is_causal) {
-                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                } else {
-                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                }
-            } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            }
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // 1st ones are good for H100, A100
-            // 2nd one is good for A6000 bc we get slightly better occupancy
+    // For sm86 or sm89, 64 x 32 (40 KB smem) is the fastest for causal and non-causal since we get 2 CTAs per SM.
+    // Use block configuration (kBlockM = 64, kBlockN = 64) for better memory alignment
+    if (is_sm8x) {
+        if constexpr(!Is_causal) {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_causal>(params, stream);
         } else {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_causal>(params, stream);
         }
-    });
+    } else {
+        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
+    }
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_causal>(params, stream);
+    // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_causal>(params, stream);
+    // 1st ones are good for H100, A100
+    // 2nd one is good for A6000 bc we get slightly better occupancy
 }

 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        if constexpr(!Is_dropout) {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 32, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        } else {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 32, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        }
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
-    });
+    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 32, 32, 4, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
 }

 template<typename T, bool Is_causal>
@@ -279,15 +250,14 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
         C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // For A100, we want to run with 64 x 64 (112KB smem).
-        // For H100 we want to run with 64 x 32 (72KB smem) since then we can get 2 CTAs per SM.
-        if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        } else {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        }
-    });
+
+    // For A100, we want to run with 64 x 64 (112KB smem).
+    // For H100 we want to run with 64 x 32 (72KB smem) since then we can get 2 CTAs per SM.
+    if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
+        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 8, false, false, T>, Is_causal>(params, stream);
+    } else {
+        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_causal>(params, stream);
+    }
 }

 } // namespace FLASH_NAMESPACE
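
Side note on the hdim256 heuristic in the last hunk: the two thresholds work out to 128 KB (per block) and 192 KB (per SM), which is consistent with the in-code comments about picking 64 x 64 traits on A100 and 64 x 32 traits on H100. The per-device shared-memory figures below are commonly published values and an assumption here, not something this diff states; the snippet is just a standalone arithmetic check.

#include <cstdio>

int main() {
    constexpr int Headdim = 256;
    // Threshold applied to max_smem_per_block: 2 * 256 * (128 + 2*64) bytes.
    constexpr int block_threshold = 2 * Headdim * (128 + 2 * 64);  // 131072 B = 128 KB
    // Threshold applied to max_smem_per_sm: 4 * 256 * (64 + 2*64) bytes.
    constexpr int sm_threshold = 4 * Headdim * (64 + 2 * 64);      // 196608 B = 192 KB
    // A100 (~164 KB smem/SM, ~163 KB/block, assumed): both conditions hold -> 64 x 64 traits.
    // H100 (~228 KB smem/SM, ~227 KB/block, assumed): the per-SM bound fails -> 64 x 32 traits.
    std::printf("per-block threshold = %d B, per-SM threshold = %d B\n",
                block_threshold, sm_threshold);
    return 0;
}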
