
Commit 9f5d48d (parent: 612b85c)

Renames package to flash_sparse_attn

Aligns packaging metadata and build hooks with the flash_sparse_attn name so prebuilt wheels, env vars, and CUDA builds resolve correctly.

1 file changed (+16 −16 lines)


setup.py

@@ -34,19 +34,19 @@
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
 
-PACKAGE_NAME = "flash_dmattn"
+PACKAGE_NAME = "flash_sparse_attn"
 
 BASE_WHEEL_URL = (
-    "https://github.com/SmallDoges/flash-dmattn/releases/download/{tag_name}/{wheel_name}"
+    "https://github.com/SmallDoges/flash-sparse-attention/releases/download/{tag_name}/{wheel_name}"
 )
 
 # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
 # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
 # Also useful when user only wants Triton/Flex backends without CUDA compilation
-FORCE_BUILD = os.getenv("FLASH_DMATTN_FORCE_BUILD", "FALSE") == "TRUE"
-SKIP_CUDA_BUILD = os.getenv("FLASH_DMATTN_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
+FORCE_BUILD = os.getenv("FLASH_SPARSE_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
+SKIP_CUDA_BUILD = os.getenv("FLASH_SPARSE_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
 # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
-FORCE_CXX11_ABI = os.getenv("FLASH_DMATTN_FORCE_CXX11_ABI", "FALSE") == "TRUE"
+FORCE_CXX11_ABI = os.getenv("FLASH_SPARSE_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"
 
 # Auto-detect if user wants only Triton/Flex backends based on pip install command
 # This helps avoid unnecessary CUDA compilation when user only wants Python backends
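Note that these flags follow an exact-match convention: a variable is honored only when it is the literal string TRUE, so values like "true" or "1" leave the flag off. A minimal sketch of that behavior under the renamed variable (the demo values are illustrative, not from the repo):

import os

# Only the exact string "TRUE" enables the flag; "true" or "1" do not.
os.environ["FLASH_SPARSE_ATTENTION_FORCE_BUILD"] = "true"
print(os.getenv("FLASH_SPARSE_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE")  # False

os.environ["FLASH_SPARSE_ATTENTION_FORCE_BUILD"] = "TRUE"
print(os.getenv("FLASH_SPARSE_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE")  # True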
@@ -69,7 +69,7 @@ def should_skip_cuda_build():
 
     if has_triton_or_flex and not has_all_or_dev:
         print("Detected Triton/Flex-only installation. Skipping CUDA compilation.")
-        print("Set FLASH_DMATTN_FORCE_BUILD=TRUE to force CUDA compilation.")
+        print("Set FLASH_SPARSE_ATTENTION_FORCE_BUILD=TRUE to force CUDA compilation.")
         return True
 
     return False
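The hunk touches only the printed hint; the detection logic above it is not part of this diff. A hypothetical reconstruction of how such extras-based detection might work, with the argv scan and extras names assumed rather than copied from the file:

import sys

def should_skip_cuda_build_sketch(argv=None):
    # Assumption: the real function inspects the pip/setup command line
    # for extras such as [triton] or [flex] vs. [all] or [dev].
    args = " ".join(argv if argv is not None else sys.argv).lower()
    has_triton_or_flex = "[triton]" in args or "[flex]" in args
    has_all_or_dev = "[all]" in args or "[dev]" in args
    if has_triton_or_flex and not has_all_or_dev:
        print("Detected Triton/Flex-only installation. Skipping CUDA compilation.")
        print("Set FLASH_SPARSE_ATTENTION_FORCE_BUILD=TRUE to force CUDA compilation.")
        return True
    return False

print(should_skip_cuda_build_sketch(["pip", "install", "flash-sparse-attn[triton]"]))  # True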
@@ -79,7 +79,7 @@ def should_skip_cuda_build():
 
 @functools.lru_cache(maxsize=None)
 def cuda_archs():
-    return os.getenv("FLASH_DMATTN_CUDA_ARCHS", "80;90;100").split(";")
+    return os.getenv("FLASH_SPARSE_ATTENTION_CUDA_ARCHS", "80;90;100").split(";")
 
 
 def detect_preferred_sm_arch() -> Optional[str]:
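cuda_archs() itself only splits the semicolon-separated list; turning each entry into compiler flags happens elsewhere in setup.py and is not shown in this commit. A sketch of the usual nvcc -gencode expansion, with the flag format assumed:

import os

def cuda_archs():
    return os.getenv("FLASH_SPARSE_ATTENTION_CUDA_ARCHS", "80;90;100").split(";")

# Assumed expansion to nvcc flags; the real file may format these differently.
cc_flag = []
for arch in cuda_archs():
    cc_flag += ["-gencode", f"arch=compute_{arch},code=sm_{arch}"]

print(cc_flag[:4])
# ['-gencode', 'arch=compute_80,code=sm_80', '-gencode', 'arch=compute_90,code=sm_90']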
@@ -154,14 +154,14 @@ def append_nvcc_threads(nvcc_extra_args):
     TORCH_MAJOR = int(torch.__version__.split(".")[0])
     TORCH_MINOR = int(torch.__version__.split(".")[1])
 
-    check_if_cuda_home_none("flash_dmattn")
+    check_if_cuda_home_none("flash_sparse_attn")
     # Check, if CUDA11 is installed for compute capability 8.0
     cc_flag = []
     if CUDA_HOME is not None:
         _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
         if bare_metal_version < Version("11.7"):
             raise RuntimeError(
-                "Flash Dynamic Mask Attention is only supported on CUDA 11.7 and above. "
+                "Flash Sparse Attention is only supported on CUDA 11.7 and above. "
                 "Note: make sure nvcc has a supported version by running nvcc -V."
             )
 
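get_cuda_bare_metal_version is a helper defined elsewhere in the file. A stand-in that reproduces the Version comparison above, assuming the helper parses the release number out of nvcc --version output:

import subprocess
from packaging.version import Version

def get_cuda_bare_metal_version_sketch(cuda_home):
    # Assumption: pull "release 12.4" (for example) out of `nvcc --version`.
    raw = subprocess.check_output([f"{cuda_home}/bin/nvcc", "--version"], text=True)
    bare_metal_version = Version(raw.split("release ")[1].split(",")[0])
    return raw, bare_metal_version

_, bare_metal_version = get_cuda_bare_metal_version_sketch("/usr/local/cuda")
if bare_metal_version < Version("11.7"):
    raise RuntimeError(
        "Flash Sparse Attention is only supported on CUDA 11.7 and above. "
        "Note: make sure nvcc has a supported version by running nvcc -V."
    )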
@@ -218,31 +218,31 @@ def append_nvcc_threads(nvcc_extra_args):
 
     ext_modules.append(
         CUDAExtension(
-            name="flash_dmattn_cuda",
+            name="flash_sparse_attn_cuda",
             sources=(
                 [
-                    "csrc/flash_dmattn/flash_api.cpp",
+                    "csrc/flash_sparse_attn/flash_api.cpp",
                 ]
-                + sorted(glob.glob("csrc/flash_dmattn/src/instantiations/flash_*.cu"))
+                + sorted(glob.glob("csrc/flash_sparse_attn/src/instantiations/flash_*.cu"))
             ),
             extra_compile_args={
                 "cxx": compiler_c17_flag,
                 "nvcc": append_nvcc_threads(nvcc_flags + cc_flag),
             },
             include_dirs=[
-                Path(this_dir) / "csrc" / "flash_dmattn",
-                Path(this_dir) / "csrc" / "flash_dmattn" / "src",
+                Path(this_dir) / "csrc" / "flash_sparse_attn",
+                Path(this_dir) / "csrc" / "flash_sparse_attn" / "src",
                 Path(this_dir) / "csrc" / "cutlass" / "include",
             ],
         )
     )
 
 
 def get_package_version():
-    with open(Path(this_dir) / "flash_dmattn" / "__init__.py", "r") as f:
+    with open(Path(this_dir) / "flash_sparse_attn" / "__init__.py", "r") as f:
         version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
         public_version = ast.literal_eval(version_match.group(1))
-        local_version = os.environ.get("FLASH_DMATTN_LOCAL_VERSION")
+        local_version = os.environ.get("FLASH_SPARSE_ATTENTION_LOCAL_VERSION")
         if local_version:
             return f"{public_version}+{local_version}"
         else:
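The version logic is untouched apart from the renamed package directory and env var. A self-contained illustration of the regex and the +local suffix (the sample __init__.py text and the cu124 value are made up):

import ast
import os
import re

# Stand-in for the contents of flash_sparse_attn/__init__.py.
init_text = '__version__ = "1.0.0"\n'

version_match = re.search(r"^__version__\s*=\s*(.*)$", init_text, re.MULTILINE)
public_version = ast.literal_eval(version_match.group(1))  # -> "1.0.0"

os.environ["FLASH_SPARSE_ATTENTION_LOCAL_VERSION"] = "cu124"
local_version = os.environ.get("FLASH_SPARSE_ATTENTION_LOCAL_VERSION")
print(f"{public_version}+{local_version}" if local_version else public_version)  # 1.0.0+cu124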
