
Commit f567c1f

Merge pull request #118 from SmallDoges/Fix-varlen-bug
Integrate Flash Dynamic Mask Attention (FDMA) Into Transformers-Style Attention Flow
2 parents ee43866 + a7ee9bc commit f567c1f
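
The new attention callable is meant to plug into Transformers' pluggable attention dispatch. As a minimal sketch of how it could be wired in, assuming transformers' AttentionInterface registration API (available in recent versions) and a hypothetical import path for the function added in this commit:

from transformers import AttentionInterface

# Hypothetical import path; the actual module location follows this repo's layout.
from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward

# Register under a name that can then be passed as `attn_implementation=...`.
AttentionInterface.register("flash_dynamic_mask_attention", flash_dynamic_mask_attention_forward)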

File tree

7 files changed: +474 -930 lines changed


demo_varlen_fix.py

Lines changed: 0 additions & 221 deletions
This file was deleted.
Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
from typing import Optional

import torch

from .modeling_flash_dynamic_mask_attention_utils import _flash_dynamic_mask_attention_forward
from transformers.utils import logging


logger = logging.get_logger(__name__)


def flash_dynamic_mask_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    attention_bias: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
        logger.warning_once(
            "`flash_dynamic_mask_attention` does not support `output_attentions=True` or `head_mask`."
            " Please set your attention to `eager` if you want any of these features."
        )

    # This is before the transpose
    seq_len = query.shape[2]

    if any(dim == 0 for dim in query.shape):
        raise ValueError(
            "Tensor query has shape with a zero dimension.\n"
            "FlashDynamicMaskAttention does not support inputs with dim=0.\n"
            "Please check your input shapes or use SDPA instead."
        )

    # FDMA uses non-transposed inputs
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    # In PEFT, the layer norms are usually cast to float32 for training stability,
    # so the input hidden states get silently cast to float32. We therefore cast them
    # back to the expected dtype to make sure everything works as intended.
    # This can slow down training & inference, so it is recommended not to cast the
    # LayerNorms to fp32 (our RMSNorm modules usually handle this correctly).
    target_dtype = None
    if query.dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(module.config, "_pre_quantization_dtype"):
            target_dtype = module.config._pre_quantization_dtype
        else:
            target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype

    # FDMA always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
    kwargs.pop("is_causal", None)

    attn_output = _flash_dynamic_mask_attention_forward(
        query,
        key,
        value,
        attention_mask,
        attention_bias,
        query_length=seq_len,
        is_causal=module.is_causal,
        softmax_scale=scaling,
        softcap=softcap,
        target_dtype=target_dtype,
        attn_implementation=module.config._attn_implementation,
        layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
        **kwargs,
    )

    return attn_output, None
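
A minimal usage sketch for the wrapper above, not part of this commit: it assumes a CUDA device with the flash_dmattn kernels installed and fakes just enough module state (config, is_causal, layer_idx) for the call to go through. The [batch, heads, seq_len, head_dim] input layout matches what the wrapper expects before its internal transpose; the attention_bias shape is an assumption.

import torch
from types import SimpleNamespace

class ToyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.config = SimpleNamespace(_attn_implementation="flash_dynamic_mask_attention")
        self.is_causal = True
        self.layer_idx = 0
        # A Linear layer so the fp32 fallback path can infer a target dtype if needed.
        self.o_proj = torch.nn.Linear(64, 64)

batch, heads, seq_len, head_dim = 2, 8, 128, 64
q = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
# Assumed bias layout: one score offset per (batch, head, query, key) position.
bias = torch.zeros(batch, heads, seq_len, seq_len, dtype=torch.float16, device="cuda")

out, _ = flash_dynamic_mask_attention_forward(
    ToyAttention().cuda(), q, k, v, attention_mask=None, attention_bias=bias
)
# Expected to follow the usual flash-attention convention of (batch, seq_len, heads, head_dim).
print(out.shape)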
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
# Copyright 2025 Jingze Shi and the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Import utilities: Utilities related to imports and our lazy inits.
"""

import importlib.metadata
import importlib.util
from functools import lru_cache
from typing import Union

from transformers import is_torch_available
from transformers.utils import logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.). Talk to Sylvain to see how to handle it better.
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[tuple[bool, str], bool]:
    # Check if the package spec exists and grab its version to avoid importing a local directory
    package_exists = importlib.util.find_spec(pkg_name) is not None
    package_version = "N/A"
    if package_exists:
        try:
            # TODO: Once python 3.9 support is dropped, `importlib.metadata.packages_distributions()`
            # should be used here to map from package name to distribution names
            # e.g. PIL -> Pillow, Pillow-SIMD; quark -> amd-quark; onnxruntime -> onnxruntime-gpu.
            # `importlib.metadata.packages_distributions()` is not available in Python 3.9.

            # Primary method to get the package version
            package_version = importlib.metadata.version(pkg_name)
        except importlib.metadata.PackageNotFoundError:
            # Fallback method: only for "torch" and versions containing "dev"
            if pkg_name == "torch":
                try:
                    package = importlib.import_module(pkg_name)
                    temp_version = getattr(package, "__version__", "N/A")
                    # Check if the version contains "dev"
                    if "dev" in temp_version:
                        package_version = temp_version
                        package_exists = True
                    else:
                        package_exists = False
                except ImportError:
                    # If the package can't be imported, it's not available
                    package_exists = False
            elif pkg_name == "quark":
                # TODO: remove once `importlib.metadata.packages_distributions()` is supported.
                try:
                    package_version = importlib.metadata.version("amd-quark")
                except Exception:
                    package_exists = False
            elif pkg_name == "triton":
                try:
                    package_version = importlib.metadata.version("pytorch-triton")
                except Exception:
                    package_exists = False
            else:
                # For packages other than "torch", don't attempt the fallback and mark as not available
                package_exists = False
        logger.debug(f"Detected {pkg_name} version: {package_version}")
    if return_version:
        return package_exists, package_version
    else:
        return package_exists


@lru_cache
def is_flash_dmattn_available():
    if not is_torch_available():
        return False

    if not _is_package_available("flash_dmattn"):
        return False

    import torch

    if not torch.cuda.is_available():
        return False

    return True
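
A short, hypothetical example of how the availability check might drive the choice of attention implementation. The checkpoint name and the "flash_dynamic_mask_attention" implementation string are assumptions (the latter only works once the forward above has been registered under that name), not something this commit defines.

from transformers import AutoModelForCausalLM

# Fall back to SDPA when the flash_dmattn kernels or a CUDA device are missing.
attn_impl = "flash_dynamic_mask_attention" if is_flash_dmattn_available() else "sdpa"

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",  # placeholder checkpoint; any causal LM works
    attn_implementation=attn_impl,
)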
