1+ # Copyright 2025 Jingze Shi and the HuggingFace Team. All rights reserved.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+ """
15+ Import utilities: Utilities related to imports and our lazy inits.
16+ """
17+
18+ import importlib .metadata
19+ import importlib .util
20+ from functools import lru_cache
21+ from typing import Union
22+
23+
24+ from transformers import is_torch_available
25+ from transformers .utils import logging
26+
27+
28+ logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
29+
30+
31+ # TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to do with it better.
32+ def _is_package_available (pkg_name : str , return_version : bool = False ) -> Union [tuple [bool , str ], bool ]:
33+ # Check if the package spec exists and grab its version to avoid importing a local directory
34+ package_exists = importlib .util .find_spec (pkg_name ) is not None
35+ package_version = "N/A"
36+ if package_exists :
37+ try :
38+ # TODO: Once python 3.9 support is dropped, `importlib.metadata.packages_distributions()`
39+ # should be used here to map from package name to distribution names
40+ # e.g. PIL -> Pillow, Pillow-SIMD; quark -> amd-quark; onnxruntime -> onnxruntime-gpu.
41+ # `importlib.metadata.packages_distributions()` is not available in Python 3.9.
42+
43+ # Primary method to get the package version
44+ package_version = importlib .metadata .version (pkg_name )
45+ except importlib .metadata .PackageNotFoundError :
46+ # Fallback method: Only for "torch" and versions containing "dev"
47+ if pkg_name == "torch" :
48+ try :
49+ package = importlib .import_module (pkg_name )
50+ temp_version = getattr (package , "__version__" , "N/A" )
51+ # Check if the version contains "dev"
52+ if "dev" in temp_version :
53+ package_version = temp_version
54+ package_exists = True
55+ else :
56+ package_exists = False
57+ except ImportError :
58+ # If the package can't be imported, it's not available
59+ package_exists = False
60+ elif pkg_name == "quark" :
61+ # TODO: remove once `importlib.metadata.packages_distributions()` is supported.
62+ try :
63+ package_version = importlib .metadata .version ("amd-quark" )
64+ except Exception :
65+ package_exists = False
66+ elif pkg_name == "triton" :
67+ try :
68+ package_version = importlib .metadata .version ("pytorch-triton" )
69+ except Exception :
70+ package_exists = False
71+ else :
72+ # For packages other than "torch", don't attempt the fallback and set as not available
73+ package_exists = False
74+ logger .debug (f"Detected { pkg_name } version: { package_version } " )
75+ if return_version :
76+ return package_exists , package_version
77+ else :
78+ return package_exists
79+
80+
81+
82+ @lru_cache
83+ def is_flash_dmattn_available ():
84+ if not is_torch_available ():
85+ return False
86+
87+ if not _is_package_available ("flash_dmattn" ):
88+ return False
89+
90+ import torch
91+
92+ if not torch .cuda .is_available ():
93+ return False
94+
95+ return True
0 commit comments