@@ -47,7 +47,6 @@ void set_params_fprop(
     void *seqused_k,
     void *p_d,
     void *softmax_lse_d,
-    float p_dropout,
     float softmax_scale,
     bool is_causal,
     const float softcap,
@@ -134,20 +133,6 @@ void set_params_fprop(
         params.scale_softmax_log2 = softmax_scale * M_LOG2E;
     }

-    // Set this to probability of keeping an element to simplify things.
-    params.p_dropout = 1.f - p_dropout;
-    // Convert p from float to int so we don't have to convert the random uint to float to compare.
-    // [Minor] We want to round down since when we do the comparison we use <= instead of <
-    // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
-    // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
-    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
-    params.rp_dropout = 1.f / params.p_dropout;
-    params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
-    TORCH_CHECK(p_dropout < 1.f);
-    #ifdef FLASHATTENTION_DISABLE_DROPOUT
-        TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
-    #endif
-
     params.is_causal = is_causal;
     params.is_seqlens_k_cumulative = true;

@@ -223,7 +208,6 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
     const int max_seqlen_k,
     const int max_seqlen_q,
     const int head_size_rounded,
-    const float p_dropout,
     const int num_splits,
     const int num_sm,
     struct c10::TensorOptions opts
@@ -239,19 +223,17 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
     at::Tensor softmax_lse_accum;
     at::Tensor out_accum;

-    if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
-        if (num_splits < 1) {
-            // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
-            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128);
-        }
-        if (params.num_splits > 1) {
-            softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
-            out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
-            params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
-            params.oaccum_ptr = out_accum.data_ptr();
-        }
-        TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
+    if (num_splits < 1) {
+        // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
+        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128);
+    }
+    if (params.num_splits > 1) {
+        softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
+        out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
+        params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
+        params.oaccum_ptr = out_accum.data_ptr();
     }
+    TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");

     // Temporarily disable Split-KV, because some bugs are still being fixed.
     // See: https://github.com/SmallDoges/flash-dmattn/issues/47
@@ -272,12 +254,10 @@ mha_fwd(
     const at::Tensor &mask,          // batch_size x num_heads_k x seqlen_q x seqlen_k
     const at::Tensor &bias,          // batch_size x num_heads_k x seqlen_q x seqlen_k
     std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
-    const float p_dropout,
     const float softmax_scale,
     bool is_causal,
     const float softcap,
-    const bool return_softmax,
-    std::optional<at::Generator> gen_
+    const bool return_softmax
 ) {

     // Otherwise the kernel will be launched from cuda:0 device
@@ -313,14 +293,12 @@ mha_fwd(
     TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

-    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
-
     // causal=true is the same as causal=false in this case
     if (seqlen_q == 1) { is_causal = false; }

     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size % 8 == 0;
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0;
     const int ngroups = num_heads / num_heads_k;
     if (seqlenq_ngroups_swapped) {
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
@@ -357,12 +335,10 @@ mha_fwd(

     auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
     at::Tensor p;
-    // Only return softmax if there's dropout to reduce compilation time
+
     if (return_softmax) {
-        TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
         p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
-    }
-    else {
+    } else {
         p = torch::empty({ 0 }, opts);
     }

@@ -380,7 +356,6 @@ mha_fwd(
         /*seqused_k=*/nullptr,
         return_softmax ? p.data_ptr() : nullptr,
         softmax_lse.data_ptr(),
-        p_dropout,
         softmax_scale,
         is_causal,
         softcap
@@ -390,26 +365,9 @@ mha_fwd(
     at::Tensor softmax_lse_accum, out_accum;
     std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
         params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
-        head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts
+        head_size_rounded, /*num_splits*/ 0, get_num_sm(get_current_device()), opts
     );

-    // number of times random will be generated per thread, to offset philox counter in thc random
-    // state
-    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
-    int64_t counter_offset = params.b * params.h * 32;
-    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
-    // Forward kernel will populate memory with the seed and offset.
-    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
-
-    if (p_dropout > 0.0) {
-        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
-            gen_, at::cuda::detail::getDefaultCUDAGenerator());
-        // See Note [Acquire lock when using random generators]
-        std::lock_guard<std::mutex> lock(gen->mutex_);
-        params.philox_args = gen->philox_cuda_state(counter_offset);
-    }
-
     if (seqlen_k > 0) {
         auto stream = at::cuda::getCurrentCUDAStream().stream();
         run_mha_fwd(params, stream);
@@ -424,7 +382,7 @@ mha_fwd(
         q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
         softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
     }
-    return {out, softmax_lse, p, rng_state};
+    return {out, softmax_lse, p};
 }

 std::vector<at::Tensor>
@@ -442,13 +400,11 @@ mha_varlen_fwd(
     std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
     int max_seqlen_q,
     const int max_seqlen_k,
-    const float p_dropout,
     const float softmax_scale,
     const bool zero_tensors,
     bool is_causal,
     const float softcap,
-    const bool return_softmax,
-    std::optional<at::Generator> gen_
+    const bool return_softmax
 ) {
     // Otherwise the kernel will be launched from cuda:0 device
     at::cuda::CUDAGuard device_guard{q.device()};
@@ -494,8 +450,6 @@ mha_varlen_fwd(
     const int head_size = sizes[2];
     const int num_heads_k = paged_KV ? k.size(2) : k.size(1);

-    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
-
     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
     const int num_blocks = !paged_KV ? 0 : k.size(0);
     const int page_block_size = !paged_KV ? 1 : k.size(1);
@@ -507,7 +461,7 @@ mha_varlen_fwd(

     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size % 8 == 0;
+    const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0;
     const int ngroups = num_heads / num_heads_k;
     if (seqlenq_ngroups_swapped) {
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
@@ -568,12 +522,10 @@ mha_varlen_fwd(
     auto opts = q.options();
     auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
     at::Tensor p;
-    // Only return softmax if there's dropout to reduce compilation time
+
     if (return_softmax) {
-        TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
         p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
-    }
-    else {
+    } else {
         p = torch::empty({ 0 }, opts);
     }

@@ -597,7 +549,6 @@ mha_varlen_fwd(
         seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
         return_softmax ? p.data_ptr() : nullptr,
         softmax_lse.data_ptr(),
-        p_dropout,
         softmax_scale,
         is_causal,
         softcap,
@@ -621,7 +572,7 @@ mha_varlen_fwd(
         set_params_splitkv(
             params, batch_size, num_heads, head_size,
             max_seqlen_k, max_seqlen_q, head_size_rounded,
-            p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts
+            /*num_splits*/ 0, get_num_sm(get_current_device()), opts
         );
     }

@@ -635,23 +586,6 @@ mha_varlen_fwd(
         params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
     }

-    // number of times random will be generated per thread, to offset philox counter in thc random
-    // state
-    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
-    int64_t counter_offset = params.b * params.h * 32;
-    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
-    // Forward kernel will populate memory with the seed and offset.
-    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
-
-    if (p_dropout > 0.0) {
-        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
-            gen_, at::cuda::detail::getDefaultCUDAGenerator());
-        // See Note [Acquire lock when using random generators]
-        std::lock_guard<std::mutex> lock(gen->mutex_);
-        params.philox_args = gen->philox_cuda_state(counter_offset);
-    }
-
     if (max_seqlen_k > 0) {
         auto stream = at::cuda::getCurrentCUDAStream().stream();
         run_mha_fwd(params, stream, paged_KV);
@@ -669,7 +603,7 @@ mha_varlen_fwd(
         softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
     }

-    return {out, softmax_lse, p, rng_state};
+    return {out, softmax_lse, p};
 }

 } // namespace FLASH_NAMESPACE
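
A minimal host-side sketch of how a caller might invoke the trimmed-down forward entry point after this change, assuming the extension is built and linked. The declaration below is paraphrased from the parameter list shown in the diff; the std::vector<at::Tensor> return type, the tensor dtypes and shapes, and the mask/bias contents are assumptions for illustration only (the real function lives in the extension's namespace and is normally reached through its Python bindings), not part of this commit.

    // Illustrative sketch only: dropout and generator arguments are gone, and the
    // result no longer carries an rng_state tensor, only {out, softmax_lse, p}.
    #include <cmath>
    #include <optional>
    #include <vector>
    #include <torch/torch.h>

    // Assumed declaration, paraphrased from the updated parameter list above.
    std::vector<at::Tensor> mha_fwd(
        at::Tensor &q,                    // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,              // batch_size x seqlen_k x num_heads_k x head_size (assumed)
        const at::Tensor &v,              // batch_size x seqlen_k x num_heads_k x head_size (assumed)
        const at::Tensor &mask,           // batch_size x num_heads_k x seqlen_q x seqlen_k
        const at::Tensor &bias,           // batch_size x num_heads_k x seqlen_q x seqlen_k
        std::optional<at::Tensor> &out_,
        const float softmax_scale,
        bool is_causal,
        const float softcap,
        const bool return_softmax);

    int main() {
        const int64_t b = 2, sq = 128, sk = 128, h = 8, hk = 8, d = 64;
        const auto opts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);

        auto q    = torch::randn({b, sq, h,  d}, opts);
        auto k    = torch::randn({b, sk, hk, d}, opts);
        auto v    = torch::randn({b, sk, hk, d}, opts);
        auto mask = torch::ones ({b, hk, sq, sk}, opts);   // dtype/semantics per the extension's docs
        auto bias = torch::zeros({b, hk, sq, sk}, opts);
        std::optional<at::Tensor> out_;                    // let the kernel allocate the output

        // Argument list matches the updated signature: no p_dropout, no generator.
        auto result = mha_fwd(q, k, v, mask, bias, out_,
                              /*softmax_scale=*/1.0f / std::sqrt(static_cast<float>(d)),
                              /*is_causal=*/true,
                              /*softcap=*/0.0f,
                              /*return_softmax=*/false);
        at::Tensor out         = result[0];  // attention output
        at::Tensor softmax_lse = result[1];  // per-row log-sum-exp of the softmax
        return 0;
    }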