@@ -116,7 +116,7 @@ void set_params_fprop(
 
     // Set the different scale values.
     #ifdef FLASHATTENTION_DISABLE_SOFTCAP
-        TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
+        TORCH_CHECK(softcap <= 0.0, "This flash dynamic mask attention build does not support softcap.");
     #endif
     if (softcap > 0.0) {
         params.softcap = softmax_scale / softcap;
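For context on the rescaling just above: when soft-capping is enabled, the softmax scale is folded into the cap so the kernel can cap the logits with a single tanh. A minimal sketch of that arithmetic, with illustrative names only (this is not code from this repository):

```cpp
#include <cmath>

// Sketch of the folded constants: the value stored above (softmax_scale / softcap)
// is the pre-tanh scale, and the cap itself is applied after the tanh, recovering
//   capped_logit = softcap * tanh(softmax_scale * qk / softcap).
float softcapped_logit(float qk, float softmax_scale, float softcap) {
    if (softcap <= 0.0f) {
        return qk * softmax_scale;                      // capping disabled
    }
    const float pre_scale  = softmax_scale / softcap;   // what set_params_fprop stores
    const float post_scale = softcap;                   // applied together with the softmax scaling
    return post_scale * std::tanh(qk * pre_scale);
}
```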
@@ -133,7 +133,7 @@ void set_params_fprop(
     params.is_seqlens_k_cumulative = true;
 
     #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
-        TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
+        TORCH_CHECK(d == d_rounded, "This flash dynamic mask attention build does not support headdim not being a multiple of 32.");
     #endif
 
     params.unpadded_lse = unpadded_lse;
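The check above is equivalent to requiring that rounding head_size up to the next multiple of 32 is a no-op. A tiny self-contained illustration (the helper name here is ours, not the file's):

```cpp
// Round x up to the next multiple of m; d == d_rounded holds exactly when
// the head dimension is already a multiple of 32.
constexpr int round_multiple(int x, int m) { return (x + m - 1) / m * m; }

static_assert(round_multiple(64, 32) == 64, "headdim 64 passes the check");
static_assert(round_multiple(72, 32) == 96, "headdim 72 would be rejected by this build");
```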
@@ -231,7 +231,7 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
     FP16_SWITCH(!params.is_bf16, [&] {
         HEADDIM_SWITCH(params.d, [&] {
             BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-                if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
+                if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
                     run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
                 } else {
                     run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
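For readers unfamiliar with the FP16_SWITCH / HEADDIM_SWITCH / BOOL_SWITCH pattern used here: each macro branches on a runtime value once and re-invokes its body with that value as a compile-time constant, which is how Is_causal and kHeadDim end up as template arguments. A paraphrased sketch of the boolean case (the real macro lives in the repo's static_switch header and may differ in detail):

```cpp
// Paraphrased dispatch pattern, not copied from this repository.
#define BOOL_SWITCH(COND, CONST_NAME, ...)             \
    [&] {                                              \
        if (COND) {                                    \
            constexpr static bool CONST_NAME = true;   \
            return __VA_ARGS__();                      \
        } else {                                       \
            constexpr static bool CONST_NAME = false;  \
            return __VA_ARGS__();                      \
        }                                              \
    }()
```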
@@ -354,6 +354,8 @@ mha_fwd(
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension");
 
     const auto sizes = q.sizes();
 
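The two new checks mean callers must pass mask and bias whose last dimension is contiguous, just like q/k/v. A hypothetical caller-side helper (not part of this PR) that satisfies them without copying unnecessarily:

```cpp
#include <ATen/ATen.h>

// Make the last dimension contiguous only when needed; a tensor produced by
// slicing or transposing can have stride(-1) != 1, and .contiguous() copies
// only in that case.
at::Tensor ensure_last_dim_contiguous(const at::Tensor& t) {
    return t.stride(-1) == 1 ? t : t.contiguous();
}
```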
@@ -375,17 +377,21 @@ mha_fwd(
     // H/t Daniel Haziza
     const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0;
     const int ngroups = num_heads / num_heads_k;
+    at::Tensor mask_view = mask;
+    at::Tensor bias_view = bias;
     if (seqlenq_ngroups_swapped) {
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
+        mask_view = mask.expand({batch_size, num_heads_k, ngroups, seqlen_k});
+        bias_view = bias.expand({batch_size, num_heads_k, ngroups, seqlen_k});
         seqlen_q = ngroups;
         num_heads = num_heads_k;
     }
 
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
-    CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k);
-    CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k);
+    CHECK_SHAPE(mask_view, batch_size, num_heads_k, seqlen_q, seqlen_k);
+    CHECK_SHAPE(bias_view, batch_size, num_heads_k, seqlen_q, seqlen_k);
 
     at::Tensor out;
     if (out_.has_value()) {
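The expand in the seqlenq_ngroups_swapped branch does not materialize anything: it returns a broadcast view whose new ngroups dimension has stride 0, so the CHECK_SHAPE calls see the [batch, num_heads_k, ngroups, seqlen_k] layout that matches the reshaped q while the underlying storage is untouched. A standalone illustration with made-up sizes (not this file's variables):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
    const int64_t batch = 2, heads_k = 4, ngroups = 8, seqlen_k = 16;
    // seqlen_q == 1 case: mask arrives as [batch, heads_k, 1, seqlen_k].
    at::Tensor mask = at::ones({batch, heads_k, 1, seqlen_k}, at::kBool);
    at::Tensor mask_view = mask.expand({batch, heads_k, ngroups, seqlen_k});

    std::cout << mask_view.sizes() << "\n";                            // [2, 4, 8, 16]
    std::cout << mask_view.stride(2) << "\n";                          // 0: broadcast dimension
    std::cout << (mask_view.data_ptr() == mask.data_ptr()) << "\n";    // 1: same storage, no copy
    return 0;
}
```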
@@ -425,7 +431,7 @@ mha_fwd(
                      seqlen_q_rounded, seqlen_k_rounded,
                      num_heads, num_heads_k,
                      head_size, head_size_rounded,
-                     q, k, v, mask, bias, out,
+                     q, k, v, mask_view, bias_view, out,
                      /*cu_seqlens_q_d=*/nullptr,
                      /*cu_seqlens_k_d=*/nullptr,
                      /*seqused_k=*/nullptr,