
Commit 858cabd

Adds FlashAttention backward pass launch template
Introduces comprehensive kernel launch infrastructure for FlashAttention backward computation with support for multiple head dimensions (32, 64, 96, 128, 192, 256). Implements architecture-aware kernel selection that automatically chooses optimal configurations based on available shared memory and device capabilities. Provides specialized kernel variants for different execution modes including sequence-parallel processing, causal masking, and deterministic computation. Centralizes kernel parameter handling with macro-based definitions to reduce code duplication and improve maintainability.
1 parent 5fdb0b9 commit 858cabd
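
For context, a host-side dispatcher would route a backward call to one of the head-dimension-specialized launchers defined in this template. A minimal sketch (the dispatcher name and the switch on params.d are assumptions about the surrounding code, not part of this commit):

// Hypothetical dispatcher, for illustration only: selects the launcher whose
// compiled head dimension covers params.d (is_even_K inside the launchers
// handles d smaller than the compiled Headdim).
template<typename T, bool Is_causal>
void run_mha_bwd_dispatch(Flash_bwd_params &params, cudaStream_t stream) {
    if (params.d <= 32)       { run_mha_bwd_hdim32<T, Is_causal>(params, stream); }
    else if (params.d <= 64)  { run_mha_bwd_hdim64<T, Is_causal>(params, stream); }
    else if (params.d <= 96)  { run_mha_bwd_hdim96<T, Is_causal>(params, stream); }
    else if (params.d <= 128) { run_mha_bwd_hdim128<T, Is_causal>(params, stream); }
    else if (params.d <= 192) { run_mha_bwd_hdim192<T, Is_causal>(params, stream); }
    else                      { run_mha_bwd_hdim256<T, Is_causal>(params, stream); }
}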

File tree

1 file changed

Lines changed: 287 additions & 0 deletions
@@ -0,0 +1,287 @@
/******************************************************************************
 * Copyright (c) 2025, Jingze Shi and Tri Dao.
 ******************************************************************************/

#pragma once

#include "namespace_config.h"
#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#include "static_switch.h"
#include "hardware_info.h"
#include "flash.h"
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_kernel.h"
namespace FLASH_NAMESPACE {

// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif

// Define a macro for unsupported architecture handling to centralize the error message
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashDynamicMaskAttention requires building with sm version sm80-sm90, but was built for < 8.0!");

// Use a macro to clean up kernel definitions
#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
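
// For illustration: applied to the first definition below, the macro expands to
//
//   template<typename Kernel_traits, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K>
//   __global__ void flash_bwd_dq_dk_dv_loop_kernel(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
//
// so every kernel variant picks up the __grid_constant__ qualifier on sm80+
// without repeating the signature boilerplate.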

DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
    #if defined(ARCH_SUPPORTS_FLASH)
        FLASH_NAMESPACE::compute_dq_dk_dv<Kernel_traits, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}

DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
    #if defined(ARCH_SUPPORTS_FLASH)
        static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
        FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}

// Preprocess kernel: computes the rowwise dot(dO, O) term (softmax_d) consumed
// by the softmax gradient, and zeroes dQaccum when Clear_dQaccum is true.
template<bool Clear_dQaccum=true, typename Kernel_traits>
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
    FLASH_NAMESPACE::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
}

// Zeroes the float dK/dV accumulators.
template<typename Kernel_traits>
__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
    FLASH_NAMESPACE::clear_dKVaccum<Kernel_traits>(params);
}

// Postprocess kernels: reduce the float accumulators (over nsplits splits for
// dQ in deterministic mode) and convert them to the output dtype.
template<typename Kernel_traits>
__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
    FLASH_NAMESPACE::convert_dQ<Kernel_traits>(params, nsplits);
}

template<typename Kernel_traits>
__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
    FLASH_NAMESPACE::convert_dKV<Kernel_traits>(params);
}

template<typename Kernel_traits, bool Is_causal>
void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid_m(num_m_block, params.b, params.h);
    const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
    int gridDimx = num_n_block;
    if (params.deterministic) {
        int num_sm = get_num_sm(get_current_device());
        gridDimx = (num_sm + params.b * params.h - 1) / (params.b * params.h);
    }
    dim3 grid_n(gridDimx, params.b, params.h);

    if (!params.deterministic) {
        flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
    } else {
        flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
    // a multiple of kBlockN, we'll need to apply mask in the loop.
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
    // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                        // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                        // If Is_local, set Is_causal to false
                        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !Has_alibi && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !Has_alibi, Is_softcap>;
                        // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
                        if (smem_size_dq_dk_dv >= 48 * 1024) {
                            C10_CUDA_CHECK(cudaFuncSetAttribute(
                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                        }
                        kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                    });
                });
            });
        });
    });

    auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
    if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
        C10_CUDA_CHECK(cudaFuncSetAttribute(
            kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
    }
    kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}
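
// Illustrative arithmetic (example values, not taken from this file): with
// seqlen_k = 4096 and kBlockN = 128, the non-deterministic path launches
// num_n_block = 32 column blocks. In deterministic mode on a device with
// 108 SMs and b * h = 16, gridDimx = (108 + 16 - 1) / 16 = 7, so each of the
// 7 column blocks strides over the 32 n-blocks and accumulates into its own
// dQaccum split; flash_bwd_convert_dq_kernel then reduces those 7 splits in a
// fixed order, which is what makes the dQ accumulation deterministic.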

template<typename Kernel_traits, bool Is_causal>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    run_flash_bwd_seqk_parallel<Kernel_traits, Is_causal>(params, stream);
#endif
}
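
// Note: building with FLASHATTENTION_DISABLE_BACKWARD defined compiles
// run_flash_bwd down to an empty function, so inference-only builds can omit
// the backward kernels entirely.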

template<typename T, bool Is_causal>
void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 32;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) {  // 104 KB
        // We can afford more registers to keep V in registers
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
    } else {  // 96 KB
        // Currently the same configuration serves the smaller smem budget
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
    }
}
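
// Worked arithmetic for the smem threshold above (the tile breakdown is an
// interpretation of the expression, not documented here): five 128 x Headdim
// tiles plus two 128 x 128 score tiles at 2 bytes per element gives
// 2 * ((3 * 128 + 2 * 128) * 32 + 2 * 128 * 128) = 2 * 53248 = 106496 bytes,
// i.e. the 104 KB noted in the comment.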

template<typename T, bool Is_causal>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    // Changing AtomLayoutMdQ from 2 to 4 takes the same time
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>>(params, stream);
    // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
    if (max_smem_per_block >= 144 * 1024) {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
        // This has a lot of register spilling
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>>(params, stream);
    } else {
        // if (params.h == params.h_k) {
        //     run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>>(params, stream);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
        // } else {
        // }
    }
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
    // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);

    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
}

template<typename T, bool Is_causal>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 96;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    if (max_smem_per_block >= 116 * 1024) {
        // 92KB
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
    } else {
        // Currently the same configuration serves the smaller smem budget
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
    }
}

template<typename T, bool Is_causal>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
    // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
    // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
    if (max_smem_per_block >= 144 * 1024) {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_causal>(params, stream);
        // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>>(params, stream);
        // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>>(params, stream);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>>(params, stream);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>>(params, stream);
    } else {
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>>(params, stream);
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_causal>(params, stream);
    }
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);

    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream);
}

template<typename T, bool Is_causal>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 192;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    if (max_smem_per_block >= 136 * 1024) {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
    } else {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_causal>(params, stream);
    }
}
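
// For reference (inferred from the trait order and the "V in regs and no
// double buffering" comment in run_mha_bwd_hdim256 below): the two trailing
// bools in the fallback select keeping V in registers and disabling double
// buffering, trading pipelining for a smaller shared-memory footprint.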

template<typename T, bool Is_causal>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    if (max_smem_per_block >= 176 * 1024) {  // H100
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
    } else if (max_smem_per_block >= 144 * 1024) {  // A100, we don't do double buffering to save smem
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
    } else {  // sm86 and sm89, max smem is 99 KB. V in regs and no double buffering.
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal>(params, stream);
    }
}

} // namespace FLASH_NAMESPACE
