
Commit cb7997b

Removes dropout support from flash attention API
Eliminates dropout functionality across forward pass implementations to simplify the codebase and reduce compilation overhead. Removes dropout parameter handling, probability calculations, random number generation setup, and dropout-related conditional logic from both regular and variable-length attention functions. Simplifies split-KV logic by removing dropout conditional checks and enables certain optimizations that were previously gated by dropout requirements. Updates return signatures to exclude RNG state tensors that are no longer needed without dropout functionality.
1 parent d7075c9 commit cb7997b
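
In effect, both forward entry points lose the dropout-related arguments and the RNG-state output. The mha_fwd signature, excerpted and abbreviated from the diff below (elided arguments unchanged), changes roughly as follows; mha_varlen_fwd changes the same way:

    // Before
    mha_fwd(..., const float p_dropout, const float softmax_scale, bool is_causal, const float softcap,
            const bool return_softmax, std::optional<at::Generator> gen_);  // returns {out, softmax_lse, p, rng_state}

    // After
    mha_fwd(..., const float softmax_scale, bool is_causal, const float softcap,
            const bool return_softmax);                                     // returns {out, softmax_lse, p}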

1 file changed: +22 -88 lines changed

csrc/flash_api.cpp

Lines changed: 22 additions & 88 deletions
@@ -47,7 +47,6 @@ void set_params_fprop(
     void *seqused_k,
     void *p_d,
     void *softmax_lse_d,
-    float p_dropout,
     float softmax_scale,
     bool is_causal,
     const float softcap,
@@ -134,20 +133,6 @@ void set_params_fprop(
         params.scale_softmax_log2 = softmax_scale * M_LOG2E;
     }

-    // Set this to probability of keeping an element to simplify things.
-    params.p_dropout = 1.f - p_dropout;
-    // Convert p from float to int so we don't have to convert the random uint to float to compare.
-    // [Minor] We want to round down since when we do the comparison we use <= instead of <
-    // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
-    // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
-    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
-    params.rp_dropout = 1.f / params.p_dropout;
-    params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
-    TORCH_CHECK(p_dropout < 1.f);
-    #ifdef FLASHATTENTION_DISABLE_DROPOUT
-        TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
-    #endif
-
     params.is_causal = is_causal;
     params.is_seqlens_k_cumulative = true;

@@ -223,7 +208,6 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
     const int max_seqlen_k,
     const int max_seqlen_q,
     const int head_size_rounded,
-    const float p_dropout,
     const int num_splits,
     const int num_sm,
     struct c10::TensorOptions opts
@@ -239,19 +223,17 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
     at::Tensor softmax_lse_accum;
     at::Tensor out_accum;

-    if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
-        if (num_splits < 1) {
-            // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
-            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128);
-        }
-        if (params.num_splits > 1) {
-            softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
-            out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
-            params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
-            params.oaccum_ptr = out_accum.data_ptr();
-        }
-        TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
+    if (num_splits < 1) {
+        // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
+        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128);
+    }
+    if (params.num_splits > 1) {
+        softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
+        out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
+        params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
+        params.oaccum_ptr = out_accum.data_ptr();
     }
+    TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");

     // Temporarily disable Split-KV, because some bugs are still being fixed.
     // See: https://github.com/SmallDoges/flash-dmattn/issues/47
@@ -272,12 +254,10 @@ mha_fwd(
     const at::Tensor &mask,           // batch_size x num_heads_k x seqlen_q x seqlen_k
     const at::Tensor &bias,           // batch_size x num_heads_k x seqlen_q x seqlen_k
     std::optional<at::Tensor> &out_,  // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
-    const float p_dropout,
     const float softmax_scale,
     bool is_causal,
     const float softcap,
-    const bool return_softmax,
-    std::optional<at::Generator> gen_
+    const bool return_softmax
 ) {

     // Otherwise the kernel will be launched from cuda:0 device
@@ -313,14 +293,12 @@ mha_fwd(
     TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

-    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
-
     // causal=true is the same as causal=false in this case
     if (seqlen_q == 1) { is_causal = false; }

     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size % 8 == 0;
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0;
     const int ngroups = num_heads / num_heads_k;
     if (seqlenq_ngroups_swapped) {
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
@@ -357,12 +335,10 @@ mha_fwd(

     auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
     at::Tensor p;
-    // Only return softmax if there's dropout to reduce compilation time
+
     if (return_softmax) {
-        TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
         p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
-    }
-    else {
+    } else {
         p = torch::empty({ 0 }, opts);
     }

@@ -380,7 +356,6 @@ mha_fwd(
         /*seqused_k=*/nullptr,
         return_softmax ? p.data_ptr() : nullptr,
         softmax_lse.data_ptr(),
-        p_dropout,
         softmax_scale,
         is_causal,
         softcap
@@ -390,26 +365,9 @@ mha_fwd(
     at::Tensor softmax_lse_accum, out_accum;
     std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
         params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
-        head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts
+        head_size_rounded, /*num_splits*/ 0, get_num_sm(get_current_device()), opts
     );

-    // number of times random will be generated per thread, to offset philox counter in thc random
-    // state
-    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
-    int64_t counter_offset = params.b * params.h * 32;
-    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
-    // Forward kernel will populate memory with the seed and offset.
-    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
-
-    if (p_dropout > 0.0) {
-        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
-            gen_, at::cuda::detail::getDefaultCUDAGenerator());
-        // See Note [Acquire lock when using random generators]
-        std::lock_guard<std::mutex> lock(gen->mutex_);
-        params.philox_args = gen->philox_cuda_state(counter_offset);
-    }
-
     if (seqlen_k > 0) {
         auto stream = at::cuda::getCurrentCUDAStream().stream();
         run_mha_fwd(params, stream);
@@ -424,7 +382,7 @@ mha_fwd(
         q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
         softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
     }
-    return {out, softmax_lse, p, rng_state};
+    return {out, softmax_lse, p};
 }

 std::vector<at::Tensor>
@@ -442,13 +400,11 @@ mha_varlen_fwd(
     std::optional<at::Tensor> &block_table_,  // batch_size x max_num_blocks_per_seq
     int max_seqlen_q,
     const int max_seqlen_k,
-    const float p_dropout,
     const float softmax_scale,
     const bool zero_tensors,
     bool is_causal,
     const float softcap,
-    const bool return_softmax,
-    std::optional<at::Generator> gen_
+    const bool return_softmax
 ) {
     // Otherwise the kernel will be launched from cuda:0 device
     at::cuda::CUDAGuard device_guard{q.device()};
@@ -494,8 +450,6 @@ mha_varlen_fwd(
     const int head_size = sizes[2];
     const int num_heads_k = paged_KV ? k.size(2) : k.size(1);

-    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
-
     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
     const int num_blocks = !paged_KV ? 0 : k.size(0);
     const int page_block_size = !paged_KV ? 1 : k.size(1);
@@ -507,7 +461,7 @@ mha_varlen_fwd(

     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size % 8 == 0;
+    const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0;
     const int ngroups = num_heads / num_heads_k;
     if (seqlenq_ngroups_swapped) {
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
@@ -568,12 +522,10 @@ mha_varlen_fwd(
     auto opts = q.options();
     auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
     at::Tensor p;
-    // Only return softmax if there's dropout to reduce compilation time
+
     if (return_softmax) {
-        TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
         p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
-    }
-    else {
+    } else {
         p = torch::empty({ 0 }, opts);
     }

@@ -597,7 +549,6 @@ mha_varlen_fwd(
         seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
         return_softmax ? p.data_ptr() : nullptr,
         softmax_lse.data_ptr(),
-        p_dropout,
         softmax_scale,
         is_causal,
         softcap,
@@ -621,7 +572,7 @@ mha_varlen_fwd(
         set_params_splitkv(
             params, batch_size, num_heads, head_size,
             max_seqlen_k, max_seqlen_q, head_size_rounded,
-            p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts
+            /*num_splits*/ 0, get_num_sm(get_current_device()), opts
         );
     }

@@ -635,23 +586,6 @@ mha_varlen_fwd(
         params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
     }

-    // number of times random will be generated per thread, to offset philox counter in thc random
-    // state
-    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
-    int64_t counter_offset = params.b * params.h * 32;
-    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
-    // Forward kernel will populate memory with the seed and offset.
-    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
-
-    if (p_dropout > 0.0) {
-        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
-            gen_, at::cuda::detail::getDefaultCUDAGenerator());
-        // See Note [Acquire lock when using random generators]
-        std::lock_guard<std::mutex> lock(gen->mutex_);
-        params.philox_args = gen->philox_cuda_state(counter_offset);
-    }
-
     if (max_seqlen_k > 0) {
         auto stream = at::cuda::getCurrentCUDAStream().stream();
         run_mha_fwd(params, stream, paged_KV);
@@ -669,7 +603,7 @@ mha_varlen_fwd(
         softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
     }

-    return {out, softmax_lse, p, rng_state};
+    return {out, softmax_lse, p};
 }

 } // namespace FLASH_NAMESPACE
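
One side effect visible above: the seqlenq_ngroups_swapped fast path is no longer gated on p_dropout == 0.f, so the decode-style case (seqlen_q == 1 with more query heads than KV heads) now applies the transpose trick whenever the remaining conditions hold. A minimal, standalone libtorch sketch of that reshape follows; the shapes are made up for illustration and the snippet is not part of the commit:

    #include <torch/torch.h>
    #include <iostream>

    int main() {
        // Hypothetical shapes: batch of 2, 4 KV heads, 8 query groups per KV head, head_size 64.
        const int64_t batch_size = 2, num_heads_k = 4, ngroups = 8, head_size = 64;
        // q arrives as (b, 1, nheads_kv * ngroups, d), i.e. seqlen_q == 1 with grouped heads.
        auto q = torch::randn({batch_size, 1, num_heads_k * ngroups, head_size});
        // Same transform as the entry point: fold the head groups into the seqlen_q dimension.
        q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
        std::cout << q.sizes() << std::endl;  // [2, 8, 4, 64] -> (b, ngroups, nheads_kv, d)
        return 0;
    }

After the kernel runs, mha_fwd reverses the swap (the q.transpose(1, 2).reshape(...) in the @@ -424,7 +382,7 @@ hunk above), so callers still receive the original layout.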
