
Commit bdfeffc

Merge pull request #109 from SmallDoges/Support-backward
Add backward pass support for FlashDynamicMaskAttention
2 parents 820fecf + 974451e · commit bdfeffc

39 files changed · +2877 −509 lines changed

benchmarks/backward_equivalence.py

Lines changed: 1246 additions & 0 deletions
Large diffs are not rendered by default.
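The 1246-line benchmark itself is not rendered here. As a rough illustration of the kind of check a backward-equivalence script performs, the sketch below compares gradients from the CUDA path against a plain PyTorch reference. The import path, tensor shapes, and the reference implementation are assumptions for illustration; the real benchmark also exercises the dynamic mask and bias inputs, which are omitted here.

import torch
from flash_dmattn import flash_dmattn_func  # assumed import path

def reference_attention(q, k, v):
    # Plain PyTorch reference in [batch, seq_len, num_heads, head_dim] layout.
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * scale
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", probs, v.float()).to(q.dtype)

batch, q_len, k_len, heads, head_dim = 1, 128, 128, 2, 64
q = torch.randn(batch, q_len, heads, head_dim, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(batch, k_len, heads, head_dim, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(batch, k_len, heads, head_dim, device="cuda", dtype=torch.bfloat16, requires_grad=True)
grad_out = torch.randn(batch, q_len, heads, head_dim, device="cuda", dtype=torch.bfloat16)

out_cuda = flash_dmattn_func(q, k, v)   # CUDA kernel path (now differentiable)
out_ref = reference_attention(q, k, v)  # reference path

dq_c, dk_c, dv_c = torch.autograd.grad(out_cuda, (q, k, v), grad_out, retain_graph=True)
dq_r, dk_r, dv_r = torch.autograd.grad(out_ref, (q, k, v), grad_out)

for name, a, b in [("dQ", dq_c, dq_r), ("dK", dk_c, dk_r), ("dV", dv_c, dv_r)]:
    print(name, "max abs diff:", (a.float() - b.float()).abs().max().item())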

benchmarks/forward_equivalence.py

Lines changed: 1 addition & 1 deletion
@@ -249,7 +249,7 @@ def dynamic_mask_attention_cuda(
     key_states = key_states.transpose(1, 2)      # [batch, key_len, num_kv_heads, head_dim]
     value_states = value_states.transpose(1, 2)  # [batch, key_len, num_kv_heads, head_dim]

-    # Call the new flash_dmattn_func interface
+    # Call the flash_dmattn_func interface
     attn_outputs = flash_dmattn_func(
         query_states,    # [batch, query_len, num_heads, head_dim]
         key_states,      # [batch, key_len, num_kv_heads, head_dim]
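For context, the call above assumes the states have already been put into sequence-major layout. A minimal, self-contained sketch of that preparation and the call (the import path is an assumption, and any mask/bias arguments the benchmark also passes are omitted here):

import torch
from flash_dmattn import flash_dmattn_func  # assumed import path

batch, num_heads, num_kv_heads, q_len, k_len, head_dim = 1, 8, 2, 64, 256, 64
# States as typically produced by an attention module: [batch, num_heads, seq_len, head_dim]
query_states = torch.randn(batch, num_heads, q_len, head_dim, device="cuda", dtype=torch.bfloat16)
key_states = torch.randn(batch, num_kv_heads, k_len, head_dim, device="cuda", dtype=torch.bfloat16)
value_states = torch.randn(batch, num_kv_heads, k_len, head_dim, device="cuda", dtype=torch.bfloat16)

# The interface expects heads after the sequence dimension.
query_states = query_states.transpose(1, 2)  # [batch, query_len, num_heads, head_dim]
key_states = key_states.transpose(1, 2)      # [batch, key_len, num_kv_heads, head_dim]
value_states = value_states.transpose(1, 2)  # [batch, key_len, num_kv_heads, head_dim]

attn_outputs = flash_dmattn_func(
    query_states,  # [batch, query_len, num_heads, head_dim]
    key_states,    # [batch, key_len, num_kv_heads, head_dim]
    value_states,  # [batch, key_len, num_kv_heads, head_dim]
)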

csrc/flash_api.cpp

Lines changed: 434 additions & 2 deletions
Large diffs are not rendered by default.

csrc/src/flash.h

Lines changed: 0 additions & 1 deletion
@@ -153,7 +153,6 @@ struct Flash_bwd_params : public Flash_fwd_params {
     void *__restrict__ dq_accum_ptr;
     void *__restrict__ dk_accum_ptr;
     void *__restrict__ dv_accum_ptr;
-    void *__restrict__ dbias_accum_ptr;

     // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
     // dimension void *__restrict__ dk_accum_ptr; void *__restrict__

csrc/src/flash_bwd_kernel.h

Lines changed: 310 additions & 181 deletions
Large diffs are not rendered by default.

csrc/src/flash_bwd_launch_template.h

Lines changed: 10 additions & 3 deletions
@@ -137,11 +137,18 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
-    if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
+    // 2 * (...)                             - Double buffering factor
+    // (3 * kBlockM + 2 * kBlockN) * Headdim - Vector tiles in shared memory
+    //   - 3 * kBlockM * Headdim: Q tile, dQ tile, dOut tile
+    //   - 2 * kBlockN * Headdim: K tile, V tile
+    // 4 * kBlockM * kBlockN                 - Matrix tiles in shared memory
+    //   - 2 * kBlockM * kBlockN: S tile, P tile
+    //   - 2 * kBlockM * kBlockN: Mask tile, Bias tile
+    if (max_smem_per_block >= 2 * ((3 * 64 + 2 * 128) * Headdim + 4 * 64 * 128)) { // 94 KB
         // We can afford more registers to keep V in registers
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
     } else { // 96 KB
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
     }
 }
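The new shared-memory condition for the hdim32 backward launch can be tallied directly from the tile sizes in the comment block above; a quick arithmetic sketch, interpreting the factors exactly as those comments do:

# Shared-memory tallies for run_mha_bwd_hdim32 (Headdim = 32).
head_dim = 32

# Old condition: kBlockM = kBlockN = 128, S and P tiles only.
old_bytes = 2 * ((3 * 128 + 2 * 128) * head_dim + 2 * 128 * 128)
# New condition: kBlockM = 64, kBlockN = 128, with Mask and Bias tiles added.
new_bytes = 2 * ((3 * 64 + 2 * 128) * head_dim + 4 * 64 * 128)

print(old_bytes)  # 106496, the "// 104 KB" annotation on the old line
print(new_bytes)  # 94208, the "// 94 KB" annotation on the new line

Shrinking kBlockM from 128 to 64 is what makes room for the extra Mask and Bias tiles that the dynamic-mask backward pass keeps in shared memory.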

csrc/src/flash_fwd_kernel.h

Lines changed: 53 additions & 21 deletions
@@ -25,7 +25,13 @@ using namespace cute;
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
-__forceinline__ __device__ auto get_lse_tile(const Params &params, const int bidb, const int bidh, const int m_block, const BlockInfo</*Varlen=*/!Is_even_MN> &binfo) {
+__forceinline__ __device__ auto get_lse_tile(
+    const Params &params,
+    const int bidb,
+    const int bidh,
+    const int m_block,
+    const BlockInfo</*Varlen=*/!Is_even_MN> &binfo
+) {
     // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path.
     // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick.
     // Otherwise, it's written as (h, b, seqlen_q).
@@ -244,8 +250,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor tSrQ = thr_mma.partition_fragment_A(sQ);                                           // (MMA, MMA_M, MMA_K)
     Tensor tSrK = thr_mma.partition_fragment_B(sK);                                           // (MMA, MMA_N, MMA_K)
     Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);                                // (MMA, MMA_K, MMA_N)
-    Tensor tSrMask = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});    // (MMA, MMA_M, MMA_N)
-    Tensor tSrBias = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});    // (MMA, MMA_M, MMA_N)
+    // Tensor tSrMask = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});    // (MMA, MMA_M, MMA_N)
+    // Tensor tSrBias = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});    // (MMA, MMA_M, MMA_N)
     Tensor tSgS = thr_mma.partition_C(gP);
     Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});     // (MMA, MMA_M, MMA_K)

@@ -268,7 +274,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
     Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);

+
     // PREDICATES
+
     // // Allocate predicate tensors for m and n
     // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
     // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
@@ -294,9 +302,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);          // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k)
     Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask);   // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n)
     Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias);   // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n)
+
     // Allocate predicate tensors for k
     Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
     Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
+
     // Set predicates for k bounds
     if (!Is_even_K) {
         #pragma unroll
@@ -309,7 +319,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         }
     }

+
     // Prologue
+
     // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
     FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
         gmem_tiled_copy_QKV,
@@ -393,6 +405,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     __syncthreads();

     // Copy Mask and Bias from smem to registers
+    Tensor tSrMask = make_tensor<Element>(shape(acc_s));
+    Tensor tSrBias = make_tensor<Element>(shape(acc_s));
     Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
     cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
     Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
@@ -419,9 +433,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // Use sparse general matrix multiplication
     FLASH_NAMESPACE::sparse_gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
         acc_s,
-        tSrQ,
-        tSrK, tSsQ, tSsK, tSrMask, // Active key mask for sparse K matrix multiplication
-        tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+        tSrQ, tSrK, tSsQ, tSsK, tSrMask, // Active key mask for sparse K matrix multiplication
+        tiled_mma,
+        smem_tiled_copy_Q, smem_tiled_copy_K,
         smem_thr_copy_Q, smem_thr_copy_K
     );
     // if (cute::thread0()) { print(acc_s); }
@@ -483,7 +497,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     FLASH_NAMESPACE::sparse_gemm_rs(
         acc_o,
         tOrP, tOrVt, tOsVt, tSrMask, // Apply the same mask for sparse V matrix multiplication
-        tiled_mma, smem_tiled_copy_V, smem_thr_copy_V
+        tiled_mma,
+        smem_tiled_copy_V, smem_thr_copy_V
     );
     // if (cute::thread0()) { print(scores); }
@@ -502,6 +517,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     __syncthreads();

     // Copy Mask and Bias from smem to registers
+    Tensor tSrMask = make_tensor<Element>(shape(acc_s));
+    Tensor tSrBias = make_tensor<Element>(shape(acc_s));
     Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
     cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
     Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
@@ -514,11 +531,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     );
     cute::cp_async_fence();

+    // Use sparse general matrix multiplication
     FLASH_NAMESPACE::sparse_gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
         acc_s,
-        tSrQ,
-        tSrK, tSsQ, tSsK, tSrMask, // Active key mask for sparse K matrix multiplication
-        tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+        tSrQ, tSrK, tSsQ, tSsK, tSrMask, // Active key mask for sparse K matrix multiplication
+        tiled_mma,
+        smem_tiled_copy_Q, smem_tiled_copy_K,
         smem_thr_copy_Q, smem_thr_copy_K
     );
     if constexpr (Is_softcap){
@@ -574,10 +592,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         FLASH_NAMESPACE::sparse_gemm_rs(
             acc_o,
             tOrP, tOrVt, tOsVt, tSrMask, // Apply the same mask for sparse V matrix multiplication
-            tiled_mma, smem_tiled_copy_V, smem_thr_copy_V
+            tiled_mma,
+            smem_tiled_copy_V, smem_thr_copy_V
         );
     }

+
     // Epilogue

     Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax);
@@ -857,8 +877,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor tSrQ = thr_mma.partition_fragment_A(sQ);                                           // (MMA, MMA_M, MMA_K)
     Tensor tSrK = thr_mma.partition_fragment_B(sK);                                           // (MMA, MMA_N, MMA_K)
     Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);                                // (MMA, MMA_K, MMA_N)
-    Tensor tSrMask = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});    // (MMA, MMA_M, MMA_N)
-    Tensor tSrBias = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});    // (MMA, MMA_M, MMA_N)
+    // Tensor tSrMask = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});    // (MMA, MMA_M, MMA_N)
+    // Tensor tSrBias = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});    // (MMA, MMA_M, MMA_N)
     Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});     // (MMA, MMA_M, MMA_K)

     // Copy Atom retiling
@@ -878,7 +898,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
     Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);

+
     // PREDICATES
+
     // Construct identity layout for sQ and sK
     Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));     // (BLK_M, BLK_K) -> (blk_m, blk_k)
     Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));    // (BLK_N, BLK_K) -> (blk_n, blk_k)
@@ -904,7 +926,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         }
     }

+
     // Prologue
+
     // Read Q from gmem to smem
     // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
     FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
@@ -969,6 +993,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     __syncthreads();

     // Copy Mask and Bias from smem to registers
+    Tensor tSrMask = make_tensor<Element>(shape(acc_s));
+    Tensor tSrBias = make_tensor<Element>(shape(acc_s));
     Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
     cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
     Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
@@ -1004,9 +1030,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // Use sparse general matrix multiplication
     FLASH_NAMESPACE::sparse_gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
         acc_s,
-        tSrQ,
-        tSrK, tSsQ, tSsK, tSrMask, // Active key mask for sparse K matrix multiplication
-        tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+        tSrQ, tSrK, tSsQ, tSsK, tSrMask, // Active key mask for sparse K matrix multiplication
+        tiled_mma,
+        smem_tiled_copy_Q, smem_tiled_copy_K,
         smem_thr_copy_Q, smem_thr_copy_K
     );
     // if (cute::thread0()) { print(acc_s); }
@@ -1080,7 +1106,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     FLASH_NAMESPACE::sparse_gemm_rs(
         acc_o,
         tOrP, tOrVt, tOsVt, tSrMask, // Apply the same mask for sparse V matrix multiplication
-        tiled_mma, smem_tiled_copy_V, smem_thr_copy_V
+        tiled_mma,
+        smem_tiled_copy_V, smem_thr_copy_V
     );

     // This check is at the end of the loop since we always have at least 1 iteration
@@ -1098,6 +1125,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     __syncthreads();

     // Copy Mask and Bias from smem to registers
+    Tensor tSrMask = make_tensor<Element>(shape(acc_s));
+    Tensor tSrBias = make_tensor<Element>(shape(acc_s));
     Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask);
     cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view);
     Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
@@ -1120,11 +1149,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     );
     cute::cp_async_fence();

+    // Use sparse general matrix multiplication
     FLASH_NAMESPACE::sparse_gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
         acc_s,
-        tSrQ,
-        tSrK, tSsQ, tSsK, tSrMask, // Active key mask for sparse K matrix multiplication
-        tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+        tSrQ, tSrK, tSsQ, tSsK, tSrMask, // Active key mask for sparse K matrix multiplication
+        tiled_mma,
+        smem_tiled_copy_Q, smem_tiled_copy_K,
         smem_thr_copy_Q, smem_thr_copy_K
     );
     if constexpr (Is_softcap){
@@ -1190,10 +1220,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         FLASH_NAMESPACE::sparse_gemm_rs(
             acc_o,
             tOrP, tOrVt, tOsVt, tSrMask, // Apply the same mask for sparse V matrix multiplication
-            tiled_mma, smem_tiled_copy_V, smem_thr_copy_V
+            tiled_mma,
+            smem_tiled_copy_V, smem_thr_copy_V
         );
     }

+
     // Epilogue

     Tensor lse = softmax.template normalize_softmax_lse<Split>(acc_o, params.scale_softmax);
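Taken together, these hunks allocate the Mask and Bias register tiles only for the lifetime of each score block and reuse the same mask in both GEMMs. Functionally, each block computes the equivalent of the reference sketch below (illustrative PyTorch, not the kernel's code; it assumes the bias is added to the scores before softmax and that masked-out keys contribute nothing, as the tile names and the sparse_gemm comments indicate).

import torch

def dynamic_mask_attention_reference(q, k, v, mask, bias, scale):
    # q: [batch, heads, q_len, d]; k, v: [batch, heads, k_len, d]
    # mask: [batch, heads, q_len, k_len] in {0, 1}; bias: same shape, additive.
    scores = torch.einsum("bhqd,bhkd->bhqk", q.float(), k.float()) * scale  # S tile
    scores = scores + bias.float()                                          # Bias tile
    scores = scores.masked_fill(mask == 0, float("-inf"))                   # Mask tile
    probs = torch.softmax(scores, dim=-1)                                   # P tile
    probs = torch.nan_to_num(probs)                                         # rows with no active keys
    return torch.einsum("bhqk,bhkd->bhqd", probs, v.float())

The kernel reaches the same result without materializing the full score matrix: sparse_gemm skips inactive keys when forming Q·K^T, and sparse_gemm_rs applies the same mask again when accumulating P·V.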

csrc/src/generate_kernels.py

Lines changed: 18 additions & 11 deletions
@@ -15,7 +15,8 @@
 NAMESPACE_INCLUDE = '#include "namespace_config.h"\n'

 def get_fwd_template() -> str:
-    return NAMESPACE_INCLUDE + """#include "flash_fwd_launch_template.h"
+    return NAMESPACE_INCLUDE + """
+#include "flash_fwd_launch_template.h"

 namespace FLASH_NAMESPACE {{

@@ -24,19 +25,23 @@ def get_fwd_template() -> str:
     run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
 }}

-}} // namespace FLASH_NAMESPACE"""
+}} // namespace FLASH_NAMESPACE
+""".strip()

 def get_fwd_split_template() -> str:
-    return NAMESPACE_INCLUDE + """#include "flash_fwd_launch_template.h"
+    return NAMESPACE_INCLUDE + """
+#include "flash_fwd_launch_template.h"

 namespace FLASH_NAMESPACE {{

 template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream);

-}} // namespace FLASH_NAMESPACE"""
+}} // namespace FLASH_NAMESPACE
+""".strip()

 def get_bwd_template() -> str:
-    return NAMESPACE_INCLUDE + """#include "flash_bwd_launch_template.h"
+    return NAMESPACE_INCLUDE + """
+#include "flash_bwd_launch_template.h"

 namespace FLASH_NAMESPACE {{

@@ -45,7 +50,8 @@ def get_bwd_template() -> str:
     run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
 }}

-}} // namespace FLASH_NAMESPACE"""
+}} // namespace FLASH_NAMESPACE
+""".strip()

 @dataclass
 class Kernel:
@@ -59,7 +65,7 @@ class Kernel:
     def template(self) -> str:
         template_funcs = {
             "fwd": get_fwd_template,
-            # "bwd": get_bwd_template,
+            "bwd": get_bwd_template,
             "fwd_split": get_fwd_split_template
         }
         template_func = template_funcs[self.direction]
@@ -74,15 +80,16 @@ def filename(self) -> str:
         return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu"

 def get_all_kernels() -> Generator[Kernel, None, None]:
-    # for direction in ["fwd", "fwd_split", "bwd"]:
-    for direction in ["fwd", "fwd_split"]:
+    for direction in ["fwd", "fwd_split", "bwd"]:
         for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM):
             yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction)

 def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
-    prelude = """// Copyright (c) 2025, Jingze Shi and Tri Dao.
+    prelude = """
+// Copyright (c) 2025, Jingze Shi and Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"\n"""
+// This file is auto-generated. See "generate_kernels.py"\n
+""".strip()
     content = prelude + kernel.template
     (autogen_dir / kernel.filename).write_text(content)
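With the "bwd" direction restored in get_all_kernels, the generator now emits backward kernels alongside the forward and split-KV ones. A quick sketch of the filenames it produces, using hypothetical values for DTYPE_MAP, HEAD_DIMENSIONS, IS_CAUSAL, and SM (those constants live outside this hunk):

import itertools

# Hypothetical stand-ins for the module-level constants.
DTYPE_MAP = {"fp16": "cutlass::half_t", "bf16": "cutlass::bfloat16_t"}
HEAD_DIMENSIONS = [32, 64, 96, 128]
IS_CAUSAL = ["false", "true"]
SM = [80]

for direction in ["fwd", "fwd_split", "bwd"]:
    for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM):
        causal = "causal_" if is_causal == "true" else ""
        print(f"flash_{direction}_hdim{head_dim}_{dtype}_{causal}sm{sm}.cu")

With these values the "bwd" direction alone contributes 16 files, e.g. flash_bwd_hdim128_bf16_causal_sm80.cu, the bf16 / hdim128 / causal combination instantiated in the auto-generated file shown further below.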

csrc/src/hardware_info.h

Lines changed: 12 additions & 9 deletions
@@ -1,5 +1,5 @@
 /******************************************************************************
- * Copyright (c) 2024, Tri Dao.
+ * Copyright (c) 2025, Jingze Shi and Tri Dao.
 ******************************************************************************/

 #pragma once
@@ -10,14 +10,17 @@
 #include "cuda_runtime.h"
 #endif

-#define CHECK_CUDA(call)                                                        \
-    do {                                                                        \
-        cudaError_t status_ = call;                                             \
-        if (status_ != cudaSuccess) {                                           \
-            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__,     \
-                    cudaGetErrorString(status_));                               \
-            exit(1);                                                            \
-        }                                                                       \
+#define CHECK_CUDA(call)                                                        \
+    do {                                                                        \
+        cudaError_t status_ = call;                                             \
+        if (status_ != cudaSuccess) {                                           \
+            fprintf(                                                            \
+                stderr,                                                         \
+                "CUDA error (%s:%d): %s\n", __FILE__, __LINE__,                 \
+                cudaGetErrorString(status_)                                     \
+            );                                                                  \
+            exit(1);                                                            \
+        }                                                                       \
     } while (0)

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+// Copyright (c) 2025, Jingze Shi and Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "namespace_config.h"
+#include "flash_bwd_launch_template.h"
+
+namespace FLASH_NAMESPACE {
+
+template<>
+void run_mha_bwd_<cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
+}
+
+} // namespace FLASH_NAMESPACE

0 commit comments
