
Commit 1af076b

Adds support for dynamic mask attention tensors
Updates error messages to reflect "flash dynamic mask attention" branding. Adds contiguity checks for mask and bias tensors to ensure proper memory layout. Handles tensor reshaping for grouped query attention scenarios by expanding mask and bias tensors to match the reshaped query dimensions, ensuring consistent tensor shapes throughout the attention computation.
1 parent 5d3bbc6 commit 1af076b
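
The substantive part of the change is the grouped query attention fast path in mha_fwd: when seqlen_q == 1 and num_heads > num_heads_k, the query is reshaped so that the query-head groups take the place of the sequence dimension, and the mask and bias (which carry a singleton query dimension on that path) are broadcast with expand() to match. A minimal libtorch sketch of that pattern follows; the tensor sizes are made up for illustration, and only the reshape/expand calls mirror the diff.

#include <torch/torch.h>
#include <iostream>

int main() {
    // Illustrative sizes only (not from the commit): a single-token decode step
    // with 8 query heads sharing 2 KV heads.
    const int64_t batch_size = 2, num_heads = 8, num_heads_k = 2;
    const int64_t seqlen_q = 1, seqlen_k = 16, head_size = 64;
    const int64_t ngroups = num_heads / num_heads_k;  // 4

    auto q    = torch::randn({batch_size, seqlen_q, num_heads, head_size});
    auto mask = torch::randn({batch_size, num_heads_k, seqlen_q, seqlen_k});
    auto bias = torch::randn({batch_size, num_heads_k, seqlen_q, seqlen_k});

    // The reshape the diff applies to q ...
    q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);

    // ... and the matching broadcast of mask and bias over the ngroups dimension.
    // expand() only adjusts strides, so no data is copied.
    auto mask_view = mask.expand({batch_size, num_heads_k, ngroups, seqlen_k});
    auto bias_view = bias.expand({batch_size, num_heads_k, ngroups, seqlen_k});

    std::cout << q.sizes() << "\n";          // [2, 4, 2, 64]
    std::cout << mask_view.sizes() << "\n";  // [2, 2, 4, 16]
    std::cout << bias_view.sizes() << "\n";  // [2, 2, 4, 16]
}

Because expand() returns a view, the broadcast costs no extra memory, and the last (seqlen_k) dimension keeps a unit stride, so the layout the new checks enforce on mask and bias is preserved.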

File tree

1 file changed (+12 −6 lines)

csrc/flash_api.cpp

Lines changed: 12 additions & 6 deletions
@@ -116,7 +116,7 @@ void set_params_fprop(
 
     // Set the different scale values.
     #ifdef FLASHATTENTION_DISABLE_SOFTCAP
-    TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
+    TORCH_CHECK(softcap <= 0.0, "This flash dynamic mask attention build does not support softcap.");
     #endif
     if (softcap > 0.0) {
         params.softcap = softmax_scale / softcap;
@@ -133,7 +133,7 @@ void set_params_fprop(
     params.is_seqlens_k_cumulative = true;
 
     #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
-    TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
+    TORCH_CHECK(d == d_rounded, "This flash dynamic mask attention build does not support headdim not being a multiple of 32.");
     #endif
 
     params.unpadded_lse = unpadded_lse;
@@ -231,7 +231,7 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
    FP16_SWITCH(!params.is_bf16, [&] {
        HEADDIM_SWITCH(params.d, [&] {
            BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-                if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
+                if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
                    run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
                } else {
                    run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
@@ -354,6 +354,8 @@ mha_fwd(
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension");
 
     const auto sizes = q.sizes();
 
@@ -375,17 +377,21 @@ mha_fwd(
     // H/t Daniel Haziza
     const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0;
     const int ngroups = num_heads / num_heads_k;
+    at::Tensor mask_view = mask;
+    at::Tensor bias_view = bias;
     if (seqlenq_ngroups_swapped) {
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
+        mask_view = mask.expand({batch_size, num_heads_k, ngroups, seqlen_k});
+        bias_view = bias.expand({batch_size, num_heads_k, ngroups, seqlen_k});
         seqlen_q = ngroups;
         num_heads = num_heads_k;
     }
 
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
-    CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k);
-    CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k);
+    CHECK_SHAPE(mask_view, batch_size, num_heads_k, seqlen_q, seqlen_k);
+    CHECK_SHAPE(bias_view, batch_size, num_heads_k, seqlen_q, seqlen_k);
 
     at::Tensor out;
     if (out_.has_value()) {
@@ -425,7 +431,7 @@ mha_fwd(
                      seqlen_q_rounded, seqlen_k_rounded,
                      num_heads, num_heads_k,
                      head_size, head_size_rounded,
-                     q, k, v, mask, bias, out,
+                     q, k, v, mask_view, bias_view, out,
                      /*cu_seqlens_q_d=*/nullptr,
                      /*cu_seqlens_k_d=*/nullptr,
                      /*seqused_k=*/nullptr,
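
The new TORCH_CHECKs at the top of mha_fwd only require that the last (seqlen_k) dimension of mask and bias be contiguous, the same condition already imposed on q, k and v. A small libtorch sketch of what that condition accepts and rejects; the shapes here are illustrative and not taken from the commit.

#include <torch/torch.h>
#include <iostream>

int main() {
    // (batch, num_heads_k, seqlen_q, seqlen_k); sizes are illustrative only.
    auto mask = torch::randn({2, 4, 8, 16});
    std::cout << mask.stride(-1) << "\n";        // 1  -> satisfies the new check

    auto transposed = mask.transpose(-1, -2);    // last dimension now strides by 16
    std::cout << transposed.stride(-1) << "\n";  // 16 -> the new check would reject this

    auto fixed = transposed.contiguous();        // materialize a packed copy
    std::cout << fixed.stride(-1) << "\n";       // 1  -> accepted again

    // The same condition the commit adds for mask and bias:
    TORCH_CHECK(fixed.stride(-1) == 1, "Input tensor must have contiguous last dimension");
}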
