@@ -31,18 +31,17 @@ namespace FLASH_NAMESPACE {
3131template <typename Kernel_traits, __VA_ARGS__> \
3232__global__ void kernelName (KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
3333
34- DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
34+ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_causal, bool Is_even_M, bool Is_even_K) {
3535 #if defined(ARCH_SUPPORTS_FLASH)
36- FLASH_NAMESPACE::compute_dq_dk_dv<Kernel_traits, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
36+ FLASH_NAMESPACE::compute_dq_dk_dv<Kernel_traits, Is_causal, Is_even_M, Is_even_K>(params);
3737 #else
3838 FLASH_UNSUPPORTED_ARCH
3939 #endif
4040}
4141
42- DEFINE_FLASH_BACKWARD_KERNEL (flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
42+ DEFINE_FLASH_BACKWARD_KERNEL (flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
4343 #if defined(ARCH_SUPPORTS_FLASH)
44- static_assert (!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
45- FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
44+ FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap>(params);
4645 #else
4746 FLASH_UNSUPPORTED_ARCH
4847 #endif
@@ -96,22 +95,17 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream)
9695 // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
9796 BOOL_SWITCH (is_even_MN, IsEvenMNConst, [&] {
9897 EVENK_SWITCH (is_even_K, IsEvenKConst, [&] {
99- LOCAL_SWITCH ((params.window_size_left >= 0 || params.window_size_right >= 0 ) && !params.is_causal , Is_local, [&] {
100- ALIBI_SWITCH (params.alibi_slopes_ptr != nullptr , Has_alibi, [&] {
101- SOFTCAP_SWITCH (params.softcap > 0.0 , Is_softcap, [&] {
102- // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
103- // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
104- // If Is_local, set Is_causal to false
105- auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !Has_alibi && Kernel_traits::kHeadDim <= 128 , IsEvenKConst && !Has_alibi, Is_softcap>;
106- // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
107- if (smem_size_dq_dk_dv >= 48 * 1024 ) {
108- C10_CUDA_CHECK (cudaFuncSetAttribute (
109- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
110- }
111- kernel<<<grid_n, Kernel_traits::kNThreads , smem_size_dq_dk_dv, stream>>>(params);
112- C10_CUDA_KERNEL_LAUNCH_CHECK ();
113- });
114- });
98+ SOFTCAP_SWITCH (params.softcap > 0.0 , Is_softcap, [&] {
99+ // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
100+ // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
101+ auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_causal, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128 , IsEvenKConst, Is_softcap>;
102+ // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
103+ if (smem_size_dq_dk_dv >= 48 * 1024 ) {
104+ C10_CUDA_CHECK (cudaFuncSetAttribute (
105+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
106+ }
107+ kernel<<<grid_n, Kernel_traits::kNThreads , smem_size_dq_dk_dv, stream>>>(params);
108+ C10_CUDA_KERNEL_LAUNCH_CHECK ();
115109 });
116110 });
117111 });
0 commit comments