Commit 6900b29

Adds backend auto-selection API
Exposes backend availability flags so callers can probe supported runtimes without import errors, and provides an auto-selection helper that falls back to the first available backend for attention execution.
1 parent 938aec9 commit 6900b29
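
A minimal sketch of the probing pattern the commit message describes. The caller-side code here is illustrative and not part of the diff, but the flags and helper it uses are defined in the file below:

import flash_sparse_attn as fsa

# The availability flags are plain booleans, so probing a runtime
# never raises ImportError.
if fsa.CUDA_AVAILABLE:
    print("CUDA extension is built")

# Or enumerate every usable runtime at once.
print(fsa.get_available_backends())  # e.g. ['triton'] when only Triton is installed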

File tree

1 file changed: +97 -0 lines


flash_sparse_attn/__init__.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
# Copyright (c) 2025, Jingze Shi.

from functools import partial
from typing import Optional

__version__ = "1.2.3"


# Import CUDA functions when available
try:
    from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func, flash_sparse_attn_varlen_func
    CUDA_AVAILABLE = True
except ImportError:
    CUDA_AVAILABLE = False
    flash_sparse_attn_func, flash_sparse_attn_varlen_func = None, None

# Import Triton functions when available
try:
    from flash_sparse_attn.flash_sparse_attn_triton import triton_sparse_attn_func
    TRITON_AVAILABLE = True
except ImportError:
    TRITON_AVAILABLE = False
    triton_sparse_attn_func = None

# Import Flex functions when available
try:
    from flash_sparse_attn.flash_sparse_attn_flex import flex_sparse_attn_func
    FLEX_AVAILABLE = True
except ImportError:
    FLEX_AVAILABLE = False
    flex_sparse_attn_func = None


def get_available_backends():
    """Return a list of available backends."""
    backends = []
    if CUDA_AVAILABLE:
        backends.append("cuda")
    if TRITON_AVAILABLE:
        backends.append("triton")
    if FLEX_AVAILABLE:
        backends.append("flex")
    return backends


def flash_sparse_attn_func_auto(backend: Optional[str] = None, **kwargs):
    """
    Flash Sparse Attention function with automatic backend selection.

    Args:
        backend (str, optional): Backend to use ('cuda', 'triton', 'flex').
            If None, the first available backend is used, in order: cuda, triton, flex.
        **kwargs: Keyword arguments pre-bound to the returned attention function.

    Returns:
        The attention function for the specified or auto-selected backend,
        with any given keyword arguments pre-bound.
    """
    if backend is None:
        # Auto-select the first available backend
        if CUDA_AVAILABLE:
            backend = "cuda"
        elif TRITON_AVAILABLE:
            backend = "triton"
        elif FLEX_AVAILABLE:
            backend = "flex"
        else:
            raise RuntimeError(
                "No flash attention backend is available. Please install at least "
                "one of: triton, transformers, or build the CUDA extension."
            )

    if backend == "cuda":
        if not CUDA_AVAILABLE:
            raise RuntimeError("CUDA backend is not available. Please build the CUDA extension.")
        func = flash_sparse_attn_func
    elif backend == "triton":
        if not TRITON_AVAILABLE:
            raise RuntimeError("Triton backend is not available. Please install triton: pip install triton")
        func = triton_sparse_attn_func
    elif backend == "flex":
        if not FLEX_AVAILABLE:
            raise RuntimeError("Flex backend is not available. Please install transformers: pip install transformers")
        func = flex_sparse_attn_func
    else:
        raise ValueError(f"Unknown backend: {backend}. Available backends: {get_available_backends()}")

    # Pre-bind any provided keyword arguments so they reach the selected
    # attention function when it is eventually called.
    return partial(func, **kwargs) if kwargs else func


__all__ = [
    "CUDA_AVAILABLE",
    "TRITON_AVAILABLE",
    "FLEX_AVAILABLE",
    "flash_sparse_attn_func",
    "flash_sparse_attn_varlen_func",
    "triton_sparse_attn_func",
    "flex_sparse_attn_func",
    "get_available_backends",
    "flash_sparse_attn_func_auto",
]
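
A hedged usage sketch for the auto-selection helper. The tensor layout and the attention function's exact call signature are assumptions for illustration; the diff above only defines the dispatch layer:

import torch
import flash_sparse_attn as fsa

# Auto-select: falls back through cuda -> triton -> flex.
attn = fsa.flash_sparse_attn_func_auto()

# Or pin a backend explicitly; an unavailable backend raises RuntimeError.
try:
    attn = fsa.flash_sparse_attn_func_auto(backend="triton")
except RuntimeError as err:
    print(err)

# Hypothetical call: a (batch, seqlen, heads, head_dim) layout is assumed
# here, not taken from this commit.
q = k = v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
out = attn(q, k, v)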
