Commit a7ee9bc

Adds flash dynamic mask attention forward pass
Implements a forward function that integrates with the transformers library to provide flash attention with dynamic masking. It handles input validation, tensor transposition, and dtype casting for PEFT compatibility, then delegates to the core attention implementation with the proper parameter mapping. It warns about unsupported features such as output_attentions and head_mask, directing users to eager attention mode when those are needed.
1 parent 7eb71dc commit a7ee9bc
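
For context on the integration mentioned above, here is a minimal sketch of how a custom attention forward like this one can be hooked into transformers. It assumes a recent transformers release that exposes AttentionInterface; the registration key, import path, and checkpoint name are placeholders, not part of this commit:

from transformers import AttentionInterface, AutoModelForCausalLM

# Placeholder import path; the actual module name is not shown in this diff.
from flash_dmattn.integrations import flash_dynamic_mask_attention_forward

# Register the forward under a custom key so models can select it by name.
AttentionInterface.register("flash_dynamic_mask_attention", flash_dynamic_mask_attention_forward)

model = AutoModelForCausalLM.from_pretrained(
    "org/model-name",  # placeholder checkpoint
    attn_implementation="flash_dynamic_mask_attention",
)

Because attention_bias is a positional parameter of this forward, the calling attention module would be expected to supply it alongside the standard query/key/value/mask arguments.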

1 file changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
from typing import Optional

import torch

from .modeling_flash_dynamic_mask_attention_utils import _flash_dynamic_mask_attention_forward
from transformers.utils import logging


logger = logging.get_logger(__name__)


def flash_dynamic_mask_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    attention_bias: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
        logger.warning_once(
            "`flash_dynamic_mask_attention` does not support `output_attentions=True` or `head_mask`."
            " Please set your attention to `eager` if you want any of these features."
        )

    # This is before the transpose
    seq_len = query.shape[2]

    if any(dim == 0 for dim in query.shape):
        raise ValueError(
            "Tensor query has shape with a zero dimension.\n"
            "FlashDynamicMaskAttention does not support inputs with dim=0.\n"
            "Please check your input shapes or use SDPA instead."
        )

    # FDMA uses non-transposed inputs
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    # In PEFT, we usually cast the layer norms to float32 for training stability,
    # so the input hidden states get silently cast to float32. Hence, we need to
    # cast them back to the correct dtype just to be sure everything works as expected.
    # This might slow down training & inference, so it is recommended not to cast the
    # LayerNorms to fp32. (Usually our RMSNorm modules handle it correctly.)
    target_dtype = None
    if query.dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(module.config, "_pre_quantization_dtype"):
            target_dtype = module.config._pre_quantization_dtype
        else:
            target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype

    # FDMA always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
    kwargs.pop("is_causal", None)

    attn_output = _flash_dynamic_mask_attention_forward(
        query,
        key,
        value,
        attention_mask,
        attention_bias,
        query_length=seq_len,
        is_causal=module.is_causal,
        softmax_scale=scaling,
        softcap=softcap,
        target_dtype=target_dtype,
        attn_implementation=module.config._attn_implementation,
        layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
        **kwargs,
    )

    return attn_output, None
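
As a rough illustration of the calling convention above, here is a hedged sketch that invokes the forward directly with dummy tensors. The stub module, import path, and shapes are assumptions for illustration; it also assumes a CUDA device and that the underlying flash dynamic mask attention kernels are installed:

import torch
import torch.nn as nn

# Placeholder import path for the function added in this commit.
from flash_dmattn.integrations import flash_dynamic_mask_attention_forward


class _StubConfig:
    # Only the attribute read by the forward; the key name is an assumption.
    _attn_implementation = "flash_dynamic_mask_attention"


class _StubAttention(nn.Module):
    # Minimal module exposing just the attributes the forward accesses.
    is_causal = True
    layer_idx = 0
    config = _StubConfig()


B, H, S, D = 2, 8, 128, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)  # (batch, heads, seq, head_dim), pre-transpose layout
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

attn_output, _ = flash_dynamic_mask_attention_forward(
    _StubAttention(),
    q, k, v,
    attention_mask=None,   # assuming no padding mask is needed
    attention_bias=None,   # assuming no dynamic-mask bias for this sketch
    scaling=D ** -0.5,
)

Using float16 inputs keeps target_dtype as None, so the dtype fallback path (and its search for an nn.Linear submodule) is never exercised in this sketch.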
