Skip to content

Commit 6e68b61

Browse files
committed
Adds import utilities for flash attention integration
Introduces package availability checking with version detection and fallback mechanisms for special cases like torch dev versions. Includes dedicated function to verify flash attention availability by checking torch, CUDA, and package dependencies.
1 parent ba77faa commit 6e68b61

File tree

1 file changed

+95
-0
lines changed

1 file changed

+95
-0
lines changed
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright 2025 Jingze Shi and the HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Import utilities: Utilities related to imports and our lazy inits.
16+
"""
17+
18+
import importlib.metadata
19+
import importlib.util
20+
from functools import lru_cache
21+
from typing import Union
22+
23+
24+
from transformers import is_torch_available
25+
from transformers.utils import logging
26+
27+
28+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29+
30+
31+
# TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to do with it better.
32+
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[tuple[bool, str], bool]:
33+
# Check if the package spec exists and grab its version to avoid importing a local directory
34+
package_exists = importlib.util.find_spec(pkg_name) is not None
35+
package_version = "N/A"
36+
if package_exists:
37+
try:
38+
# TODO: Once python 3.9 support is dropped, `importlib.metadata.packages_distributions()`
39+
# should be used here to map from package name to distribution names
40+
# e.g. PIL -> Pillow, Pillow-SIMD; quark -> amd-quark; onnxruntime -> onnxruntime-gpu.
41+
# `importlib.metadata.packages_distributions()` is not available in Python 3.9.
42+
43+
# Primary method to get the package version
44+
package_version = importlib.metadata.version(pkg_name)
45+
except importlib.metadata.PackageNotFoundError:
46+
# Fallback method: Only for "torch" and versions containing "dev"
47+
if pkg_name == "torch":
48+
try:
49+
package = importlib.import_module(pkg_name)
50+
temp_version = getattr(package, "__version__", "N/A")
51+
# Check if the version contains "dev"
52+
if "dev" in temp_version:
53+
package_version = temp_version
54+
package_exists = True
55+
else:
56+
package_exists = False
57+
except ImportError:
58+
# If the package can't be imported, it's not available
59+
package_exists = False
60+
elif pkg_name == "quark":
61+
# TODO: remove once `importlib.metadata.packages_distributions()` is supported.
62+
try:
63+
package_version = importlib.metadata.version("amd-quark")
64+
except Exception:
65+
package_exists = False
66+
elif pkg_name == "triton":
67+
try:
68+
package_version = importlib.metadata.version("pytorch-triton")
69+
except Exception:
70+
package_exists = False
71+
else:
72+
# For packages other than "torch", don't attempt the fallback and set as not available
73+
package_exists = False
74+
logger.debug(f"Detected {pkg_name} version: {package_version}")
75+
if return_version:
76+
return package_exists, package_version
77+
else:
78+
return package_exists
79+
80+
81+
82+
@lru_cache
83+
def is_flash_dmattn_available():
84+
if not is_torch_available():
85+
return False
86+
87+
if not _is_package_available("flash_dmattn"):
88+
return False
89+
90+
import torch
91+
92+
if not torch.cuda.is_available():
93+
return False
94+
95+
return True

0 commit comments

Comments
 (0)