@@ -5,7 +5,6 @@
 #pragma once
 
 #include "namespace_config.h"
-#include <ATen/cuda/detail/UnpackRaw.cuh>  // For at::cuda::philox::unpack
 
 #include <cute/tensor.hpp>
 
@@ -18,8 +17,6 @@
 #include "utils.h"
 #include "softmax.h"
 #include "mask.h"
-#include "dropout.h"
-#include "rotary.h"
 
 namespace FLASH_NAMESPACE {
 
@@ -47,7 +44,7 @@ __forceinline__ __device__ auto get_lse_tile(const Params &params, const int bid
     return local_tile(mLSE_slice, Shape<Int<kBlockM>>{}, make_coord(m_block));
 }
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
 
     using Element = typename Kernel_traits::Element;
@@ -65,17 +62,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     constexpr int kHeadDim = Kernel_traits::kHeadDim;  // head_dim
     constexpr int kNWarps = Kernel_traits::kNWarps;
 
-    auto seed_offset = at::cuda::philox::unpack(params.philox_args);
-    FLASH_NAMESPACE::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
-                                     bidb, bidh, tidx, params.h);
-
-    // Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might
-    // exit early and no one saves the rng states.
-    if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
-        params.rng_state[0] = std::get<0>(seed_offset);
-        params.rng_state[1] = std::get<1>(seed_offset);
-    }
-
     // Check if there are any queries to process in the block
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
@@ -477,20 +463,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
     // Convert acc_s from fp32 to fp16/bf16
     Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);
-    int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
-    int block_col_idx = n_block * (kBlockN / 32);
     if (Return_softmax) {
-        Tensor rP_drop = make_fragment_like(rP);
-        cute::copy(rP, rP_drop);
-        dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-            rP_drop, block_row_idx, block_col_idx, kNWarps
-        );
-        cute::copy(rP_drop, tSgS);
+        cute::copy(rP, tSgS);
         tSgS.data() = tSgS.data() + (-kBlockN);
     }
-    if (Is_dropout) {
-        dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
-    }
 
     // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
     // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
@@ -574,20 +550,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
     // Convert acc_s from fp32 to fp16/bf16
     Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);
-    int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
-    int block_col_idx = n_block * (kBlockN / 32);
     if (Return_softmax) {
-        Tensor rP_drop = make_fragment_like(rP);
-        cute::copy(rP, rP_drop);
-        dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-            rP_drop, block_row_idx, block_col_idx, kNWarps
-        );
-        cute::copy(rP_drop, tSgS);
+        cute::copy(rP, tSgS);
         tSgS.data() = tSgS.data() + (-kBlockN);
     }
-    if (Is_dropout) {
-        dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
-    }
 
     // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
     // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
@@ -603,7 +569,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
     // Epilogue
 
-    Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout);
+    Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax);
 
     // Convert acc_o from fp32 to fp16/bf16
     Tensor rO = FLASH_NAMESPACE::convert_type<Element>(acc_o);
@@ -1198,7 +1164,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
     // Epilogue
 
-    Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, params.scale_softmax);
+    Tensor lse = softmax.template normalize_softmax_lse<Split>(acc_o, params.scale_softmax);
     // if (cute::thread0()) { print(lse); }
 
     Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{});  // (SMEM_M,SMEM_N)
@@ -1276,23 +1242,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
 inline __device__ void compute_attn(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
     const int bidb = blockIdx.y;
     // The block index for the head.
     const int bidh = blockIdx.z;
 
-    // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
-    // them to have the same number of threads or have to traverse the attention matrix
-    // in the same order.
-    // In the Philox RNG, we use the offset to store the batch, head, and the lane id
-    // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within
-    // the attention matrix. This way, as long as we have the batch, head, and the location of
-    // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
-
-    FLASH_NAMESPACE::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
+    FLASH_NAMESPACE::compute_attn_1rowblock<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
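
With Is_dropout gone from the template parameter lists, a launcher only needs to forward the remaining compile-time flags. Below is a minimal sketch of such a wrapper kernel; flash_fwd_kernel and Flash_fwd_params are illustrative assumptions and not part of this diff.

// Hypothetical wrapper kernel (not part of this diff): forwards the reduced set of
// compile-time flags to the updated compute_attn template. Flash_fwd_params is
// assumed to be the host-filled parameter struct passed to the kernel by value.
template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax>
__global__ void flash_fwd_kernel(const Flash_fwd_params params) {
    FLASH_NAMESPACE::compute_attn<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
}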