 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))

-PACKAGE_NAME = "flash_dmattn"
+PACKAGE_NAME = "flash_sparse_attn"

 BASE_WHEEL_URL = (
-    "https://github.com/SmallDoges/flash-dmattn/releases/download/{tag_name}/{wheel_name}"
+    "https://github.com/SmallDoges/flash-sparse-attention/releases/download/{tag_name}/{wheel_name}"
 )

 # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
 # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
 # Also useful when user only wants Triton/Flex backends without CUDA compilation
-FORCE_BUILD = os.getenv("FLASH_DMATTN_FORCE_BUILD", "FALSE") == "TRUE"
-SKIP_CUDA_BUILD = os.getenv("FLASH_DMATTN_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
+FORCE_BUILD = os.getenv("FLASH_SPARSE_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
+SKIP_CUDA_BUILD = os.getenv("FLASH_SPARSE_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
 # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
-FORCE_CXX11_ABI = os.getenv("FLASH_DMATTN_FORCE_CXX11_ABI", "FALSE") == "TRUE"
+FORCE_CXX11_ABI = os.getenv("FLASH_SPARSE_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"

 # Auto-detect if user wants only Triton/Flex backends based on pip install command
 # This helps avoid unnecessary CUDA compilation when user only wants Python backends
@@ -69,7 +69,7 @@ def should_skip_cuda_build():

     if has_triton_or_flex and not has_all_or_dev:
         print("Detected Triton/Flex-only installation. Skipping CUDA compilation.")
-        print("Set FLASH_DMATTN_FORCE_BUILD=TRUE to force CUDA compilation.")
+        print("Set FLASH_SPARSE_ATTENTION_FORCE_BUILD=TRUE to force CUDA compilation.")
         return True

     return False
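For readers skimming the diff, the auto-detection that this renamed message belongs to works by scanning the install arguments. A minimal sketch of that idea follows; only the `has_triton_or_flex` and `has_all_or_dev` names appear in the hunk above, so the helper name and the exact substrings checked are assumptions, not the repo's actual `should_skip_cuda_build()`:

```python
import sys

def _triton_flex_only_requested(argv=None):
    # Hypothetical sketch: look for "triton"/"flex" extras in the install
    # command while making sure "all"/"dev" extras were not also requested.
    text = " ".join(sys.argv if argv is None else argv).lower()
    has_triton_or_flex = "triton" in text or "flex" in text
    has_all_or_dev = "all" in text or "dev" in text
    return has_triton_or_flex and not has_all_or_dev
```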
@@ -79,7 +79,7 @@ def should_skip_cuda_build():

 @functools.lru_cache(maxsize=None)
 def cuda_archs():
-    return os.getenv("FLASH_DMATTN_CUDA_ARCHS", "80;90;100").split(";")
+    return os.getenv("FLASH_SPARSE_ATTENTION_CUDA_ARCHS", "80;90;100").split(";")


 def detect_preferred_sm_arch() -> Optional[str]:
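The semicolon-separated list returned by `cuda_archs()` is what eventually feeds `cc_flag` (note `nvcc_flags + cc_flag` in the extension hunk below). A hedged sketch of how such a list is commonly turned into nvcc code-generation flags; the exact construction in this repo is not part of the diff:

```python
def _gencode_flags(archs):
    # Illustrative only: "80" -> ["-gencode", "arch=compute_80,code=sm_80"], etc.
    flags = []
    for arch in archs:  # e.g. ["80", "90", "100"] from FLASH_SPARSE_ATTENTION_CUDA_ARCHS
        flags += ["-gencode", f"arch=compute_{arch},code=sm_{arch}"]
    return flags
```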
@@ -154,14 +154,14 @@ def append_nvcc_threads(nvcc_extra_args):
     TORCH_MAJOR = int(torch.__version__.split(".")[0])
     TORCH_MINOR = int(torch.__version__.split(".")[1])

-    check_if_cuda_home_none("flash_dmattn")
+    check_if_cuda_home_none("flash_sparse_attn")
     # Check, if CUDA11 is installed for compute capability 8.0
     cc_flag = []
     if CUDA_HOME is not None:
         _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
         if bare_metal_version < Version("11.7"):
             raise RuntimeError(
-                "Flash Dynamic Mask Attention is only supported on CUDA 11.7 and above. "
+                "Flash Sparse Attention is only supported on CUDA 11.7 and above. "
                 "Note: make sure nvcc has a supported version by running nvcc -V."
             )

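`get_cuda_bare_metal_version` itself is not shown in this diff; in flash-attention-style setup scripts it typically shells out to `nvcc -V` and parses the release number into something comparable with `Version("11.7")`. A sketch under that assumption (helper name and parsing details are illustrative):

```python
import subprocess
from pathlib import Path

from packaging.version import parse as parse_version

def _bare_metal_cuda_version(cuda_dir):
    # Assumed behavior: run `<CUDA_HOME>/bin/nvcc -V` and pull out the
    # "release X.Y" token so callers can compare it against Version("11.7").
    raw_output = subprocess.check_output(
        [str(Path(cuda_dir) / "bin" / "nvcc"), "-V"], universal_newlines=True
    )
    tokens = raw_output.split()
    release = tokens[tokens.index("release") + 1].rstrip(",")
    return raw_output, parse_version(release)
```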
@@ -218,31 +218,31 @@ def append_nvcc_threads(nvcc_extra_args):

     ext_modules.append(
         CUDAExtension(
-            name="flash_dmattn_cuda",
+            name="flash_sparse_attn_cuda",
             sources=(
                 [
-                    "csrc/flash_dmattn/flash_api.cpp",
+                    "csrc/flash_sparse_attn/flash_api.cpp",
                 ]
-                + sorted(glob.glob("csrc/flash_dmattn/src/instantiations/flash_*.cu"))
+                + sorted(glob.glob("csrc/flash_sparse_attn/src/instantiations/flash_*.cu"))
             ),
             extra_compile_args={
                 "cxx": compiler_c17_flag,
                 "nvcc": append_nvcc_threads(nvcc_flags + cc_flag),
             },
             include_dirs=[
-                Path(this_dir) / "csrc" / "flash_dmattn",
-                Path(this_dir) / "csrc" / "flash_dmattn" / "src",
+                Path(this_dir) / "csrc" / "flash_sparse_attn",
+                Path(this_dir) / "csrc" / "flash_sparse_attn" / "src",
                 Path(this_dir) / "csrc" / "cutlass" / "include",
             ],
         )
     )


 def get_package_version():
-    with open(Path(this_dir) / "flash_dmattn" / "__init__.py", "r") as f:
+    with open(Path(this_dir) / "flash_sparse_attn" / "__init__.py", "r") as f:
         version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
     public_version = ast.literal_eval(version_match.group(1))
-    local_version = os.environ.get("FLASH_DMATTN_LOCAL_VERSION")
+    local_version = os.environ.get("FLASH_SPARSE_ATTENTION_LOCAL_VERSION")
     if local_version:
         return f"{public_version}+{local_version}"
     else:
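As a small usage note on the renamed `FLASH_SPARSE_ATTENTION_LOCAL_VERSION` variable, the suffix logic above behaves roughly as follows; the version string is a made-up example, not the package's actual version:

```python
import os

public_version = "1.0.0"  # assumed value parsed from flash_sparse_attn/__init__.py
local_version = os.environ.get("FLASH_SPARSE_ATTENTION_LOCAL_VERSION")  # e.g. "cu124"
# Prints "1.0.0+cu124" when the variable is set, plain "1.0.0" otherwise.
print(f"{public_version}+{local_version}" if local_version else public_version)
```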