
Commit 3bb3774
Removes alibi and local window support from backward kernels
Simplifies the flash attention backward kernel templates by removing the Has_alibi and Is_local template parameters. This reduces the number of kernel instantiations and the compilation cost, at the price of dropping support for position-dependent attention biases (ALiBi) and local windowed attention in the backward pass. Also eliminates the static assertion that prevented combining causal and local attention, and removes the nested switch statements for the alibi and local configurations.
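For context on why each removed boolean matters: the launch path (see the diff below) dispatches runtime flags to compile-time template parameters through switch macros (BOOL_SWITCH, ALIBI_SWITCH, LOCAL_SWITCH, SOFTCAP_SWITCH), and every such boolean doubles the number of kernel bodies the compiler must emit. A minimal sketch of the pattern, modeled on the BOOL_SWITCH macro commonly defined in the repository's static_switch.h (treat the file location and exact definition here as assumptions, not a copy of this repo's header):

// Sketch: turn a runtime bool into a compile-time constant by branching
// once and instantiating the continuation for both values. N stacked
// switches over N booleans yield up to 2^N template instantiations.
#define BOOL_SWITCH(COND, CONST_NAME, ...)             \
    [&] {                                              \
        if (COND) {                                    \
            constexpr static bool CONST_NAME = true;   \
            return __VA_ARGS__();                      \
        } else {                                       \
            constexpr static bool CONST_NAME = false;  \
            return __VA_ARGS__();                      \
        }                                              \
    }()

// Usage (illustrative):
//   BOOL_SWITCH(params.is_causal, Is_causal, [&] { launch<Is_causal>(params); });

Under that model, the old seqk-parallel kernel took six boolean template parameters (Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap), for up to 2^6 = 64 instantiations per Kernel_traits before accounting for combinations the dispatch logic already collapses (e.g., forcing IsEvenMNConst to false for head dims above 128); after this change it takes four, for at most 2^4 = 16.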
1 parent 858cabd commit 3bb3774

File tree: 1 file changed (+15 −21 lines)


csrc/src/flash_bwd_launch_template.h

@@ -31,18 +31,17 @@ namespace FLASH_NAMESPACE {
 template<typename Kernel_traits, __VA_ARGS__> \
 __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
 
-DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
+DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_causal, bool Is_even_M, bool Is_even_K) {
     #if defined(ARCH_SUPPORTS_FLASH)
-        FLASH_NAMESPACE::compute_dq_dk_dv<Kernel_traits, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
+        FLASH_NAMESPACE::compute_dq_dk_dv<Kernel_traits, Is_causal, Is_even_M, Is_even_K>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
 }
 
-DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
+DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
     #if defined(ARCH_SUPPORTS_FLASH)
-        static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
-        FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
+        FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
@@ -96,22 +95,17 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
     // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
         EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
-            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
-                ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
-                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
-                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                        // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-                        // If Is_local, set Is_causal to false
-                        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !Has_alibi && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !Has_alibi, Is_softcap>;
-                        // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
-                        if (smem_size_dq_dk_dv >= 48 * 1024) {
-                            C10_CUDA_CHECK(cudaFuncSetAttribute(
-                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
-                        }
-                        kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
-                        C10_CUDA_KERNEL_LAUNCH_CHECK();
-                    });
-                });
+            SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_causal, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
+                // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
+                if (smem_size_dq_dk_dv >= 48 * 1024) {
+                    C10_CUDA_CHECK(cudaFuncSetAttribute(
+                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+                }
+                kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
+                C10_CUDA_KERNEL_LAUNCH_CHECK();
             });
         });
     });
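Two notes the diff implies. First, the retained cudaFuncSetAttribute call still opts the kernel into more than 48 KB of dynamic shared memory, which CUDA requires explicitly; only the alibi/local dispatch around it is removed. Second, since params.alibi_slopes_ptr and the window sizes are no longer consulted here, a caller that sets them would get an unbiased, full-attention backward pass unless some earlier layer rejects those parameters. For readers unfamiliar with what the dropped template parameters guarded, here is a hedged sketch of the usual shape of such masking code inside the compute loop (the function and its details are illustrative assumptions, not code from this repository):

#include <cmath>

// Illustrative only: how Has_alibi / Is_local style template parameters are
// typically consumed in device code. if constexpr prunes the false branches
// at compile time, so the flags cost nothing at runtime; their price is the
// number of kernel variants that must be compiled.
template <bool Has_alibi, bool Is_local>
__device__ inline void apply_bias_and_window(float &score, int row, int col,
                                             float alibi_slope,
                                             int window_left, int window_right) {
    if constexpr (Has_alibi) {
        // ALiBi: subtract a bias that grows linearly with query/key distance.
        score -= alibi_slope * fabsf(static_cast<float>(col - row));
    }
    if constexpr (Is_local) {
        // Local (sliding-window) attention: mask keys outside the window.
        if (col < row - window_left || col > row + window_right) {
            score = -INFINITY;
        }
    }
}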
