
Commit 820fecf

Fix CUDA forward crash when seqlen_q == 1
2 parents: c73d643 + 1af076b

3 files changed (+82, -75 lines)


benchmarks/forward_equivalence.py

Lines changed: 21 additions & 20 deletions
@@ -518,27 +518,28 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
     # If you encounter NAN issues when running multiple configurations, try running a single configuration
     test_configs = [
         # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
-        # (1, 1, 1, 64, 64, 32, True),
-        # (1, 1, 1, 64, 64, 32, False),
-        # (1, 1, 1, 128, 128, 32, True),
-        # (1, 1, 1, 128, 128, 32, False),
-        # (1, 1, 1, 256, 256, 32, True),
-        # (1, 1, 1, 256, 256, 32, False),
-        # (1, 1, 1, 512, 512, 32, True),
-        # (1, 1, 1, 512, 512, 32, False),
-        # (1, 1, 1, 1024, 1024, 32, True),
-        # (1, 1, 1, 1024, 1024, 32, False),
-        # (1, 1, 1, 2048, 2048, 32, True),
-        # (1, 1, 1, 2048, 2048, 32, False),
+        (1, 1, 1, 64, 64, 32, True),
+        (1, 1, 1, 64, 64, 32, False),
+        (1, 1, 1, 128, 128, 32, True),
+        (1, 1, 1, 128, 128, 32, False),
+        (1, 1, 1, 256, 256, 32, True),
+        (1, 1, 1, 256, 256, 32, False),
+        (1, 1, 1, 512, 512, 32, True),
+        (1, 1, 1, 512, 512, 32, False),
+        (1, 1, 1, 1024, 1024, 32, True),
+        (1, 1, 1, 1024, 1024, 32, False),
+        (1, 1, 1, 2048, 2048, 32, True),
+        (1, 1, 1, 2048, 2048, 32, False),
         (1, 1, 1, 4096, 4096, 32, True),
-        # (1, 1, 1, 4096, 4096, 32, False),
-        # (1, 2, 1, 64, 64, 32, True),
-        # (2, 1, 1, 128, 128, 32, True),
-        # (2, 2, 1, 128, 128, 32, True),
-        # (1, 2, 1, 64, 64, 128, True),
-        # (1, 2, 1, 128, 128, 128, True),
-        # (1, 2, 1, 256, 256, 128, True),
-        # (1, 2, 1, 512, 512, 128, True),
+        (1, 1, 1, 4096, 4096, 32, False),
+        (1, 2, 1, 64, 64, 32, True),
+        (2, 1, 1, 128, 128, 32, True),
+        (2, 2, 1, 128, 128, 32, True),
+        (1, 2, 1, 64, 64, 128, True),
+        (1, 2, 1, 128, 128, 128, True),
+        (1, 2, 1, 256, 256, 128, True),
+        (1, 2, 1, 3, 512, 128, True),
+        (1, 2, 1, 1, 512, 128, True),
     ]

     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
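
Note: the two appended configurations exercise short query lengths, including the query_len == 1 decode case this commit fixes. As a rough, non-authoritative sketch of the tensor shapes such a configuration involves (layouts taken from the CHECK_SHAPE calls in csrc/flash_api.cpp below; the attention entry point is a hypothetical placeholder, not the project's actual API):

import torch

# Illustrative shapes for the (batch, heads, kv_heads, query_len, key_len, head_dim, is_causal)
# = (1, 2, 1, 1, 512, 128, True) entry; device and dtype are placeholders.
batch, num_heads, num_kv_heads, q_len, k_len, head_dim = 1, 2, 1, 1, 512, 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

q = torch.randn(batch, q_len, num_heads, head_dim, device=device)
k = torch.randn(batch, k_len, num_kv_heads, head_dim, device=device)
v = torch.randn(batch, k_len, num_kv_heads, head_dim, device=device)
# mask and bias follow the (batch, num_kv_heads, query_len, key_len) layout checked in flash_api.cpp.
mask = torch.ones(batch, num_kv_heads, q_len, k_len, device=device)
bias = torch.zeros(batch, num_kv_heads, q_len, k_len, device=device)
# out = dynamic_mask_attention_forward(q, k, v, mask, bias, is_causal=True)  # hypothetical name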

benchmarks/forward_performance.py

Lines changed: 49 additions & 49 deletions
@@ -732,57 +732,57 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
         (1, 2, 1, 4096, 4096, 128, 2048, True),
         (1, 2, 1, 8192, 8192, 128, 2048, True),
         (1, 2, 1, 16384, 16384, 128, 2048, True),
-        # (1, 2, 1, 32768, 32768, 128, 2048, True),
+        (1, 2, 1, 32768, 32768, 128, 2048, True),

         # # Inference
-        # (1, 2, 1, 2, 256, 128, 2048, True),
-        # (1, 2, 1, 2, 512, 128, 2048, True),
-        # (1, 2, 1, 2, 1024, 128, 2048, True),
-        # (1, 2, 1, 2, 2048, 128, 2048, True),
-        # (1, 2, 1, 2, 4096, 128, 2048, True),
-        # (1, 2, 1, 2, 8192, 128, 2048, True),
-        # (1, 2, 1, 2, 16384, 128, 2048, True),
-        # (1, 2, 1, 2, 32768, 128, 2048, True),
-        (1, 2, 1, 2, 65536, 128, 2048, True),
-        # (1, 2, 1, 2, 131072, 128, 2048, True),
-        # (1, 2, 1, 2, 262144, 128, 2048, True),
-        # (1, 2, 1, 2, 524288, 128, 2048, True),
-
-        # # Vary batch size
-        # (1, 2, 1, 4096, 4096, 32, 2048, True),
-        # (2, 2, 1, 4096, 4096, 32, 2048, True),
-        # (4, 2, 1, 4096, 4096, 32, 2048, True),
-        # (8, 2, 1, 4096, 4096, 32, 2048, True),
-
-        # # Vary head count
-        # (1, 1, 1, 4096, 4096, 32, 2048, True),
-        # (1, 2, 1, 4096, 4096, 32, 2048, True),
-        # (1, 4, 1, 4096, 4096, 32, 2048, True),
-        # (1, 8, 2, 4096, 4096, 32, 2048, True),
-
-        # # Vary head dimension
-        # (1, 2, 1, 4096, 4096, 32, 2048, True),
-        # (1, 2, 1, 4096, 4096, 64, 2048, True),
-        # (1, 2, 1, 4096, 4096, 96, 2048, True),
-        # (1, 2, 1, 4096, 4096, 128, 2048, True),
-        # (1, 2, 1, 4096, 4096, 192, 2048, True),
-        # (1, 2, 1, 4096, 4096, 256, 2048, True),
-
-        # # Vary keep_window_size
-        # (1, 2, 1, 32768, 32768, 128, 32, True),
-        # (1, 2, 1, 32768, 32768, 128, 64, True),
-        # (1, 2, 1, 32768, 32768, 128, 128, True),
-        # (1, 2, 1, 32768, 32768, 128, 256, True),
-        # (1, 2, 1, 32768, 32768, 128, 512, True),
-        # (1, 2, 1, 32768, 32768, 128, 1024, True),
-        # (1, 2, 1, 32768, 32768, 128, 2048, True),
-        # (1, 2, 1, 32768, 32768, 128, 4096, True),
-        # (1, 2, 1, 32768, 32768, 128, 8192, True),
-        # (1, 2, 1, 32768, 32768, 128, 16384, True),
-        # (1, 2, 1, 32768, 32768, 128, 32768, True),
-
-        # # Test non-causal
-        # (1, 2, 1, 4096, 4096, 128, 2048, False),
+        (1, 2, 1, 1, 256, 128, 2048, True),
+        (1, 2, 1, 1, 512, 128, 2048, True),
+        (1, 2, 1, 1, 1024, 128, 2048, True),
+        (1, 2, 1, 1, 2048, 128, 2048, True),
+        (1, 2, 1, 1, 4096, 128, 2048, True),
+        (1, 2, 1, 1, 8192, 128, 2048, True),
+        (1, 2, 1, 1, 16384, 128, 2048, True),
+        (1, 2, 1, 1, 32768, 128, 2048, True),
+        (1, 2, 1, 1, 65536, 128, 2048, True),
+        (1, 2, 1, 1, 131072, 128, 2048, True),
+        (1, 2, 1, 1, 262144, 128, 2048, True),
+        (1, 2, 1, 1, 524288, 128, 2048, True),
+
+        # Vary batch size
+        (1, 2, 1, 4096, 4096, 32, 2048, True),
+        (2, 2, 1, 4096, 4096, 32, 2048, True),
+        (4, 2, 1, 4096, 4096, 32, 2048, True),
+        (8, 2, 1, 4096, 4096, 32, 2048, True),
+
+        # Vary head count
+        (1, 1, 1, 4096, 4096, 32, 2048, True),
+        (1, 2, 1, 4096, 4096, 32, 2048, True),
+        (1, 4, 1, 4096, 4096, 32, 2048, True),
+        (1, 8, 2, 4096, 4096, 32, 2048, True),
+
+        # Vary head dimension
+        (1, 2, 1, 4096, 4096, 32, 2048, True),
+        (1, 2, 1, 4096, 4096, 64, 2048, True),
+        (1, 2, 1, 4096, 4096, 96, 2048, True),
+        (1, 2, 1, 4096, 4096, 128, 2048, True),
+        (1, 2, 1, 4096, 4096, 192, 2048, True),
+        (1, 2, 1, 4096, 4096, 256, 2048, True),
+
+        # Vary keep_window_size
+        (1, 2, 1, 32768, 32768, 128, 32, True),
+        (1, 2, 1, 32768, 32768, 128, 64, True),
+        (1, 2, 1, 32768, 32768, 128, 128, True),
+        (1, 2, 1, 32768, 32768, 128, 256, True),
+        (1, 2, 1, 32768, 32768, 128, 512, True),
+        (1, 2, 1, 32768, 32768, 128, 1024, True),
+        (1, 2, 1, 32768, 32768, 128, 2048, True),
+        (1, 2, 1, 32768, 32768, 128, 4096, True),
+        (1, 2, 1, 32768, 32768, 128, 8192, True),
+        (1, 2, 1, 32768, 32768, 128, 16384, True),
+        (1, 2, 1, 32768, 32768, 128, 32768, True),
+
+        # Test non-causal
+        (1, 2, 1, 4096, 4096, 128, 2048, False),
     ]

     print(f"\n📊 Benchmark Results (averaged over {num_runs} runs):")
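
Note: with the inference rows enabled, the benchmark also covers decode-style forwards (query_len == 1) over long key lengths. For reference, a minimal, generic way to time one such call with CUDA events might look like the sketch below; this is not the benchmark's own timing code, and fn is a stand-in for whatever forward implementation is being profiled, reusing the warmup_runs/num_runs parameters from run_performance_benchmark's signature.

import torch

def time_forward(fn, warmup_runs=2, num_runs=3):
    # Warm up to exclude one-time compilation/allocation costs, then average CUDA-event timings.
    for _ in range(warmup_runs):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(num_runs):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / num_runs  # average milliseconds per forward call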

csrc/flash_api.cpp

Lines changed: 12 additions & 6 deletions
@@ -116,7 +116,7 @@ void set_params_fprop(

     // Set the different scale values.
     #ifdef FLASHATTENTION_DISABLE_SOFTCAP
-    TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
+    TORCH_CHECK(softcap <= 0.0, "This flash dynamic mask attention build does not support softcap.");
     #endif
     if (softcap > 0.0) {
         params.softcap = softmax_scale / softcap;
@@ -133,7 +133,7 @@ void set_params_fprop(
     params.is_seqlens_k_cumulative = true;

     #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
-    TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
+    TORCH_CHECK(d == d_rounded, "This flash dynamic mask attention build does not support headdim not being a multiple of 32.");
     #endif

     params.unpadded_lse = unpadded_lse;
@@ -231,7 +231,7 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
     FP16_SWITCH(!params.is_bf16, [&] {
         HEADDIM_SWITCH(params.d, [&] {
             BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-                if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
+                if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
                     run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
                 } else {
                     run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
@@ -354,6 +354,8 @@ mha_fwd(
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension");

     const auto sizes = q.sizes();

@@ -375,17 +377,21 @@ mha_fwd(
     // H/t Daniel Haziza
     const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0;
     const int ngroups = num_heads / num_heads_k;
+    at::Tensor mask_view = mask;
+    at::Tensor bias_view = bias;
     if (seqlenq_ngroups_swapped) {
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
+        mask_view = mask.expand({batch_size, num_heads_k, ngroups, seqlen_k});
+        bias_view = bias.expand({batch_size, num_heads_k, ngroups, seqlen_k});
         seqlen_q = ngroups;
         num_heads = num_heads_k;
     }

     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
-    CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k);
-    CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k);
+    CHECK_SHAPE(mask_view, batch_size, num_heads_k, seqlen_q, seqlen_k);
+    CHECK_SHAPE(bias_view, batch_size, num_heads_k, seqlen_q, seqlen_k);

     at::Tensor out;
     if (out_.has_value()) {
@@ -425,7 +431,7 @@ mha_fwd(
         seqlen_q_rounded, seqlen_k_rounded,
         num_heads, num_heads_k,
         head_size, head_size_rounded,
-        q, k, v, mask, bias, out,
+        q, k, v, mask_view, bias_view, out,
         /*cu_seqlens_q_d=*/nullptr,
         /*cu_seqlens_k_d=*/nullptr,
         /*seqused_k=*/nullptr,
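
Note: the seqlen_q == 1 handling above can be pictured in PyTorch terms as folding the single query row into ngroups extra query positions per KV head, while mask and bias are broadcast along that new dimension rather than copied. A shape-only sketch (illustrative, not part of the extension code):

import torch

# Shape-only mirror of the seqlenq_ngroups_swapped branch in mha_fwd; values are arbitrary examples.
batch_size, num_heads, num_heads_k, seqlen_k, head_size = 1, 8, 2, 512, 128
ngroups = num_heads // num_heads_k

q = torch.randn(batch_size, 1, num_heads, head_size)      # (B, seqlen_q=1, H, D)
mask = torch.ones(batch_size, num_heads_k, 1, seqlen_k)   # (B, H_k, seqlen_q=1, S_k)
bias = torch.zeros(batch_size, num_heads_k, 1, seqlen_k)

# Fold the query-head groups into the query-length dimension: (B, 1, H, D) -> (B, ngroups, H_k, D).
q = q.reshape(batch_size, num_heads_k, ngroups, head_size).transpose(1, 2)
# Broadcast mask/bias across the new ngroups query rows without materializing copies.
mask_view = mask.expand(batch_size, num_heads_k, ngroups, seqlen_k)
bias_view = bias.expand(batch_size, num_heads_k, ngroups, seqlen_k)

assert q.shape == (batch_size, ngroups, num_heads_k, head_size)
assert mask_view.shape == (batch_size, num_heads_k, ngroups, seqlen_k)
assert bias_view.shape == (batch_size, num_heads_k, ngroups, seqlen_k)

The expanded views keep the CHECK_SHAPE assertions satisfied once seqlen_q becomes ngroups and num_heads becomes num_heads_k, which is what the previous code missed for mask and bias.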
