@@ -30,9 +30,9 @@ namespace FLASH_NAMESPACE {
 template <typename Kernel_traits, __VA_ARGS__> \
 __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
 
-DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
     #if defined(ARCH_SUPPORTS_FLASH)
-        FLASH_NAMESPACE::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
+        FLASH_NAMESPACE::compute_attn<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
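For reference, the DEFINE_FLASH_FORWARD_KERNEL macro shown above means the new kernel declaration expands to roughly this (a sketch assembled only from the macro body in this hunk):

    template <typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax>
    __global__ void flash_fwd_kernel(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) {
        #if defined(ARCH_SUPPORTS_FLASH)
            FLASH_NAMESPACE::compute_attn<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
        #else
            FLASH_UNSUPPORTED_ARCH
        #endif
    }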
@@ -51,7 +51,7 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int L
     FLASH_NAMESPACE::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
 }
 
-template <typename Kernel_traits, bool Is_dropout, bool Is_causal>
+template <typename Kernel_traits, bool Is_causal>
 void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     const size_t smem_size = Kernel_traits::kSmemSize;
     // printf("smem_size = %d (includes mask memory)\n", int(smem_size));
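With Is_dropout gone, instantiations of run_flash_fwd shrink to the traits plus the causal flag; the head-dim launchers below now call it as, e.g.:

    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);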
@@ -72,9 +72,9 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
             // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
             // If return_softmax, set IsEvenMNConst to false to reduce number of templates
             // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-            auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, IsEvenMNConst && IsEvenKConst && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !ReturnSoftmaxConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
+            auto kernel = &flash_fwd_kernel<Kernel_traits, Is_causal, IsEvenMNConst && IsEvenKConst && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !ReturnSoftmaxConst, Is_softcap, ReturnSoftmaxConst && !Is_softcap>;
             // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
-            // printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
+            // printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst));
             // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
             if (smem_size >= 48 * 1024) {
                 C10_CUDA_CHECK(cudaFuncSetAttribute(
@@ -162,107 +162,78 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
 template <typename T, bool Is_causal>
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-    });
+    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
 }
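The positional arguments of Flash_fwd_kernel_traits are not restated in this diff; assuming they match the upstream FlashAttention definition, they read as follows (worth verifying against kernel_traits.h):

    // assumed: Flash_fwd_kernel_traits<kHeadDim, kBlockM, kBlockN, kNWarps, Is_Q_in_regs, Share_Q_K_smem, elem_type>
    // so the hdim32 launch above selects 128 x 64 tiles, 4 warps, and element type T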
 
 template <typename T, bool Is_causal>
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        if constexpr (!Is_dropout) {
-            // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
-            // Using block size (64 x 128) is 27% slower for seqlen=2k
-            // Using block size (128 x 64) is 85% slower for seqlen=2k, because of register spilling
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
-        } else {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        }
-    });
+    // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
+    // Using block size (64 x 128) is 27% slower for seqlen=2k
+    // Using block size (128 x 64) is 85% slower for seqlen=2k, because of register spilling
+    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_causal>(params, stream);
 }
 
 template <typename T, bool Is_causal>
 void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
     bool is_sm8x = cc_major == 8 && cc_minor > 0;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
-        if (is_sm8x) {
-            if constexpr (!Is_causal) {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            }
+    // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+    if (is_sm8x) {
+        if constexpr (!Is_causal) {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
         } else {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
         }
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
-        // These two are always slower
-        // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
-    });
+    } else {
+        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
+    }
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_causal>(params, stream);
+    // These two are always slower
+    // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
 }
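The is_sm8x test above distinguishes sm86/sm89 from sm80, which takes the 128 x 64 branch; the concrete device names below are an assumption for illustration, not from this diff:

    bool is_sm8x = cc_major == 8 && cc_minor > 0;  // true on sm86/sm89 (e.g. A6000, RTX 30xx/40xx), false on sm80 (A100)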
 
 template <typename T, bool Is_causal>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
     bool is_sm8x = cc_major == 8 && cc_minor > 0;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        if constexpr (!Is_dropout) {
-            // For sm86 or sm89, 64 x 32 (40 KB smem) is the fastest for causal and non-causal since we get 2 CTAs per SM.
-            // Use block configuration (kBlockM = 64, kBlockN = 64) for better memory alignment
-            if (is_sm8x) {
-                if constexpr (!Is_causal) {
-                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                } else {
-                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                }
-            } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            }
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // 1st ones are good for H100, A100
-            // 2nd one is good for A6000 bc we get slightly better occupancy
+    // For sm86 or sm89, 64 x 32 (40 KB smem) is the fastest for causal and non-causal since we get 2 CTAs per SM.
+    // Use block configuration (kBlockM = 64, kBlockN = 64) for better memory alignment
+    if (is_sm8x) {
+        if constexpr (!Is_causal) {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_causal>(params, stream);
         } else {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_causal>(params, stream);
         }
-    });
+    } else {
+        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
+    }
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_causal>(params, stream);
+    // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_causal>(params, stream);
+    // 1st ones are good for H100, A100
+    // 2nd one is good for A6000 bc we get slightly better occupancy
 }
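A quick check on the "2 CTAs per SM" remark: assuming sm86/sm89 expose roughly 100 KB of shared memory per SM (a hardware assumption, not stated in this diff), two 40 KB CTAs fit concurrently:

    // assumed: ~100 KB smem per SM on sm86/sm89
    // 2 CTAs * 40 KB = 80 KB <= ~100 KB, so both blocks can be resident at once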
 
 template <typename T, bool Is_causal>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        if constexpr (!Is_dropout) {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 32, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        } else {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 32, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        }
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
-    });
+    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 32, 32, 4, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
 }
 
 template <typename T, bool Is_causal>
@@ -279,15 +250,14 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
         C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // For A100, we want to run with 64 x 64 (112KB smem).
-        // For H100 we want to run with 64 x 32 (72KB smem) since then we can get 2 CTAs per SM.
-        if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        } else {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        }
-    });
+
+    // For A100, we want to run with 64 x 64 (112KB smem).
+    // For H100 we want to run with 64 x 32 (72KB smem) since then we can get 2 CTAs per SM.
+    if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
+        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 8, false, false, T>, Is_causal>(params, stream);
+    } else {
+        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_causal>(params, stream);
+    }
 }
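Plugging Headdim = 256 into the branch condition above gives the concrete thresholds (plain arithmetic on the expression, stated here for readability):

    // 2 * 256 * (128 + 2 * 64) = 131072 bytes = 128 KB   -> required max_smem_per_block
    // 4 * 256 * (64 + 2 * 64)  = 196608 bytes = 192 KB   -> max_smem_per_sm must be below this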
 
 } // namespace FLASH_NAMESPACE