
Commit 77e4e61

Clarifies install docs and performance layout
Updates package and repo naming so installation commands match the published distribution. Repositions performance benchmarks after usage guidance for both languages and aligns tensor examples to current API expectations.
1 parent 152c73a commit 77e4e61

File tree

2 files changed (+185, -185 lines)


README.md

Lines changed: 92 additions & 92 deletions
@@ -45,95 +45,6 @@ Thus, a more effective approach is sparse attention: interacting each query with
 - Further performance improvements for skipping memory access and computation
 
 
-## Performance
-
-We present the expected speedup of FSA over standard PyTorch SDPA under mask and bias conditions.
-
-![FSA Performance Overview](assets/performance_overview.png)
-
----
-
-### Forward Pass Performance
-
-The following table shows the forward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs.
-
-| Mode | Q len | K len | Window W | SDPA (ms) | FSA (ms) | Speedup |
-|--------|-------|--------|----------|-----------|-----------|---------|
-| Train | 256 | 256 | 1024 | 0.29 | 0.19 | 1.58x |
-| Train | 512 | 512 | 1024 | 0.35 | 0.19 | 1.86x |
-| Train | 1024 | 1024 | 1024 | 0.51 | 0.18 | 2.81x |
-| Train | 2048 | 2048 | 1024 | 1.04 | 0.18 | 5.68x |
-| Train | 4096 | 4096 | 1024 | 2.53 | 0.24 | 10.41x |
-| Train | 8192 | 8192 | 1024 | 9.38 | 0.36 | 25.93x |
-| Train | 16384 | 16384 | 1024 | 28.39 | 0.81 | 35.25x |
-| Train | 32768 | 32768 | 1024 | 111.87 | 2.25 | 49.78x |
-| Train | 32768 | 32768 | 32 | 113.19 | 2.10 | 53.97x |
-| Train | 32768 | 32768 | 64 | 113.17 | 2.12 | 53.32x |
-| Train | 32768 | 32768 | 128 | 113.14 | 2.10 | 53.78x |
-| Train | 32768 | 32768 | 256 | 113.18 | 2.13 | 53.18x |
-| Train | 32768 | 32768 | 512 | 113.19 | 2.17 | 52.17x |
-| Train | 32768 | 32768 | 1024 | 113.19 | 2.24 | 50.45x |
-| Train | 32768 | 32768 | 2048 | 113.15 | 2.39 | 47.35x |
-| Train | 32768 | 32768 | 4096 | 113.16 | 2.67 | 42.39x |
-| Train | 32768 | 32768 | 8192 | 113.11 | 3.20 | 35.29x |
-| Train | 32768 | 32768 | 16384 | 113.15 | 3.97 | 28.51x |
-| Train | 32768 | 32768 | 32768 | 113.11 | 4.90 | 23.10x |
-| Infer | 1 | 256 | 1024 | 0.25 | 0.19 | 1.28x |
-| Infer | 1 | 512 | 1024 | 0.25 | 0.19 | 1.27x |
-| Infer | 1 | 1024 | 1024 | 0.25 | 0.20 | 1.28x |
-| Infer | 1 | 2048 | 1024 | 0.25 | 0.20 | 1.24x |
-| Infer | 1 | 4096 | 1024 | 0.25 | 0.19 | 1.29x |
-| Infer | 1 | 8192 | 1024 | 0.25 | 0.20 | 1.25x |
-| Infer | 1 | 16384 | 1024 | 0.25 | 0.19 | 1.29x |
-| Infer | 1 | 32768 | 1024 | 0.27 | 0.20 | 1.33x |
-| Infer | 1 | 65536 | 1024 | 0.42 | 0.20 | 2.10x |
-| Infer | 1 | 131072 | 1024 | 0.72 | 0.20 | 3.65x |
-| Infer | 1 | 262144 | 1024 | 1.31 | 0.22 | 6.06x |
-| Infer | 1 | 524288 | 1024 | 2.49 | 0.24 | 10.45x |
-| Infer | 1 | 524288 | 32 | 2.48 | 0.21 | 11.60x |
-| Infer | 1 | 524288 | 64 | 2.44 | 0.21 | 11.66x |
-| Infer | 1 | 524288 | 128 | 2.45 | 0.21 | 11.47x |
-| Infer | 1 | 524288 | 256 | 2.43 | 0.21 | 11.47x |
-| Infer | 1 | 524288 | 512 | 2.44 | 0.22 | 10.89x |
-| Infer | 1 | 524288 | 1024 | 2.44 | 0.24 | 10.31x |
-| Infer | 1 | 524288 | 2048 | 2.44 | 0.27 | 9.07x |
-| Infer | 1 | 524288 | 4096 | 2.45 | 0.33 | 7.41x |
-| Infer | 1 | 524288 | 8192 | 2.44 | 0.35 | 6.93x |
-| Infer | 1 | 524288 | 16384 | 2.44 | 0.35 | 6.93x |
-| Infer | 1 | 524288 | 32768 | 2.45 | 0.35 | 6.96x |
-| Infer | 1 | 524288 | 65536 | 2.44 | 0.35 | 6.88x |
-
----
-
-### Backward Pass Performance
-
-The following table shows the backward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs.
-
-| Mode | Q len | K len | Window W | SDPA-BWD (ms) | FSA-BWD (ms) | Speedup |
-|-------|-------|--------|----------|---------------|---------------|---------|
-| Train | 256 | 256 | 1024 | 0.42 | 0.62 | 0.7x |
-| Train | 512 | 512 | 1024 | 0.56 | 0.60 | 0.9x |
-| Train | 1024 | 1024 | 1024 | 0.94 | 0.61 | 1.5x |
-| Train | 2048 | 2048 | 1024 | 1.79 | 0.69 | 2.6x |
-| Train | 4096 | 4096 | 1024 | 3.76 | 1.08 | 3.5x |
-| Train | 8192 | 8192 | 1024 | 14.39 | 2.06 | 7.0x |
-| Train | 16384 | 16384 | 1024 | 39.56 | 4.97 | 8.0x |
-| Train | 32768 | 32768 | 1024 | 142.07 | 25.63 | 5.5x |
-| Train | 32768 | 32768 | 32 | 142.70 | 21.91 | 6.5x |
-| Train | 32768 | 32768 | 64 | 142.65 | 22.29 | 6.4x |
-| Train | 32768 | 32768 | 128 | 142.69 | 23.04 | 6.2x |
-| Train | 32768 | 32768 | 256 | 142.69 | 24.27 | 5.9x |
-| Train | 32768 | 32768 | 512 | 142.67 | 25.12 | 5.7x |
-| Train | 32768 | 32768 | 1024 | 142.55 | 25.58 | 5.6x |
-| Train | 32768 | 32768 | 2048 | 142.75 | 25.64 | 5.6x |
-| Train | 32768 | 32768 | 4096 | 142.61 | 24.84 | 5.7x |
-| Train | 32768 | 32768 | 8192 | 142.33 | 25.63 | 5.6x |
-| Train | 32768 | 32768 | 16384 | 142.40 | 25.62 | 5.6x |
-| Train | 32768 | 32768 | 32768 | 142.43 | 25.63 | 5.6x |
-
----
-
-
 ## Installation
 
 ### Requirements
@@ -150,14 +61,14 @@ The following table shows the backward pass performance comparison between FSA a
 You can install FSA via pre-compiled wheels:
 
 ```bash
-pip install flash_sparse_attn --no-build-isolation
+pip install flash-sparse-attn --no-build-isolation
 ```
 
 Alternatively, you can compile and install from source:
 
 ```bash
-git clone https://github.com/SmallDoges/flash_sparse_attn.git
-cd flash_sparse_attn
+git clone https://github.com/SmallDoges/flash-sparse-attn.git
+cd flash-sparse-attn
 pip install . --no-build-isolation
 ```
 
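The rename in this hunk only affects the distribution name that pip resolves; hyphenated PyPI projects are conventionally imported under an underscored module name. A quick post-install sanity check might look like the sketch below, where the import name `flash_sparse_attn` is an assumption based on that convention, not something this diff confirms:

```python
# Post-install sanity check. The import name `flash_sparse_attn` is an
# assumption: hyphenated PyPI distributions are conventionally imported
# with underscores, but this commit does not show the module name.
import importlib.util

spec = importlib.util.find_spec("flash_sparse_attn")
print("FSA importable:", spec is not None)
```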
@@ -245,6 +156,95 @@ print(f"Bias gradient shape: {attn_bias.grad.shape}")
 ```
 
 
+## Performance
+
+We present the expected speedup of FSA over standard PyTorch SDPA under mask and bias conditions.
+
+![FSA Performance Overview](assets/performance_overview.png)
+
+---
+
+### Forward Pass Performance
+
+The following table shows the forward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs.
+
+| Mode | Q len | K len | Window W | SDPA (ms) | FSA (ms) | Speedup |
+|--------|-------|--------|----------|-----------|-----------|---------|
+| Train | 256 | 256 | 1024 | 0.29 | 0.19 | 1.58x |
+| Train | 512 | 512 | 1024 | 0.35 | 0.19 | 1.86x |
+| Train | 1024 | 1024 | 1024 | 0.51 | 0.18 | 2.81x |
+| Train | 2048 | 2048 | 1024 | 1.04 | 0.18 | 5.68x |
+| Train | 4096 | 4096 | 1024 | 2.53 | 0.24 | 10.41x |
+| Train | 8192 | 8192 | 1024 | 9.38 | 0.36 | 25.93x |
+| Train | 16384 | 16384 | 1024 | 28.39 | 0.81 | 35.25x |
+| Train | 32768 | 32768 | 1024 | 111.87 | 2.25 | 49.78x |
+| Train | 32768 | 32768 | 32 | 113.19 | 2.10 | 53.97x |
+| Train | 32768 | 32768 | 64 | 113.17 | 2.12 | 53.32x |
+| Train | 32768 | 32768 | 128 | 113.14 | 2.10 | 53.78x |
+| Train | 32768 | 32768 | 256 | 113.18 | 2.13 | 53.18x |
+| Train | 32768 | 32768 | 512 | 113.19 | 2.17 | 52.17x |
+| Train | 32768 | 32768 | 1024 | 113.19 | 2.24 | 50.45x |
+| Train | 32768 | 32768 | 2048 | 113.15 | 2.39 | 47.35x |
+| Train | 32768 | 32768 | 4096 | 113.16 | 2.67 | 42.39x |
+| Train | 32768 | 32768 | 8192 | 113.11 | 3.20 | 35.29x |
+| Train | 32768 | 32768 | 16384 | 113.15 | 3.97 | 28.51x |
+| Train | 32768 | 32768 | 32768 | 113.11 | 4.90 | 23.10x |
+| Infer | 1 | 256 | 1024 | 0.25 | 0.19 | 1.28x |
+| Infer | 1 | 512 | 1024 | 0.25 | 0.19 | 1.27x |
+| Infer | 1 | 1024 | 1024 | 0.25 | 0.20 | 1.28x |
+| Infer | 1 | 2048 | 1024 | 0.25 | 0.20 | 1.24x |
+| Infer | 1 | 4096 | 1024 | 0.25 | 0.19 | 1.29x |
+| Infer | 1 | 8192 | 1024 | 0.25 | 0.20 | 1.25x |
+| Infer | 1 | 16384 | 1024 | 0.25 | 0.19 | 1.29x |
+| Infer | 1 | 32768 | 1024 | 0.27 | 0.20 | 1.33x |
+| Infer | 1 | 65536 | 1024 | 0.42 | 0.20 | 2.10x |
+| Infer | 1 | 131072 | 1024 | 0.72 | 0.20 | 3.65x |
+| Infer | 1 | 262144 | 1024 | 1.31 | 0.22 | 6.06x |
+| Infer | 1 | 524288 | 1024 | 2.49 | 0.24 | 10.45x |
+| Infer | 1 | 524288 | 32 | 2.48 | 0.21 | 11.60x |
+| Infer | 1 | 524288 | 64 | 2.44 | 0.21 | 11.66x |
+| Infer | 1 | 524288 | 128 | 2.45 | 0.21 | 11.47x |
+| Infer | 1 | 524288 | 256 | 2.43 | 0.21 | 11.47x |
+| Infer | 1 | 524288 | 512 | 2.44 | 0.22 | 10.89x |
+| Infer | 1 | 524288 | 1024 | 2.44 | 0.24 | 10.31x |
+| Infer | 1 | 524288 | 2048 | 2.44 | 0.27 | 9.07x |
+| Infer | 1 | 524288 | 4096 | 2.45 | 0.33 | 7.41x |
+| Infer | 1 | 524288 | 8192 | 2.44 | 0.35 | 6.93x |
+| Infer | 1 | 524288 | 16384 | 2.44 | 0.35 | 6.93x |
+| Infer | 1 | 524288 | 32768 | 2.45 | 0.35 | 6.96x |
+| Infer | 1 | 524288 | 65536 | 2.44 | 0.35 | 6.88x |
+
+---
+
+### Backward Pass Performance
+
+The following table shows the backward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs.
+
+| Mode | Q len | K len | Window W | SDPA-BWD (ms) | FSA-BWD (ms) | Speedup |
+|-------|-------|--------|----------|---------------|---------------|---------|
+| Train | 256 | 256 | 1024 | 0.42 | 0.62 | 0.7x |
+| Train | 512 | 512 | 1024 | 0.56 | 0.60 | 0.9x |
+| Train | 1024 | 1024 | 1024 | 0.94 | 0.61 | 1.5x |
+| Train | 2048 | 2048 | 1024 | 1.79 | 0.69 | 2.6x |
+| Train | 4096 | 4096 | 1024 | 3.76 | 1.08 | 3.5x |
+| Train | 8192 | 8192 | 1024 | 14.39 | 2.06 | 7.0x |
+| Train | 16384 | 16384 | 1024 | 39.56 | 4.97 | 8.0x |
+| Train | 32768 | 32768 | 1024 | 142.07 | 25.63 | 5.5x |
+| Train | 32768 | 32768 | 32 | 142.70 | 21.91 | 6.5x |
+| Train | 32768 | 32768 | 64 | 142.65 | 22.29 | 6.4x |
+| Train | 32768 | 32768 | 128 | 142.69 | 23.04 | 6.2x |
+| Train | 32768 | 32768 | 256 | 142.69 | 24.27 | 5.9x |
+| Train | 32768 | 32768 | 512 | 142.67 | 25.12 | 5.7x |
+| Train | 32768 | 32768 | 1024 | 142.55 | 25.58 | 5.6x |
+| Train | 32768 | 32768 | 2048 | 142.75 | 25.64 | 5.6x |
+| Train | 32768 | 32768 | 4096 | 142.61 | 24.84 | 5.7x |
+| Train | 32768 | 32768 | 8192 | 142.33 | 25.63 | 5.6x |
+| Train | 32768 | 32768 | 16384 | 142.40 | 25.62 | 5.6x |
+| Train | 32768 | 32768 | 32768 | 142.43 | 25.63 | 5.6x |
+
+---
+
+
 ## Benchmarking
 
 FSA provides comprehensive benchmarking tools to evaluate performance across different configurations:
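Those benchmarking tools are not shown in this diff, but the relocated Performance section does state its measurement protocol: timings on an NVIDIA A100-SXM4-80GB, averaged over 3 runs after 2 warmup runs, with the reported speedup being the ratio of the two times. A minimal harness reproducing that protocol for the dense SDPA baseline might look like the sketch below; the FSA entry point is deliberately left abstract because this commit does not show its API, and the head count and head dimension are illustrative assumptions:

```python
# Sketch of the stated benchmark protocol: 2 warmup runs, then the mean of
# 3 timed runs, measured with CUDA events. Only the dense SDPA baseline is
# shown; FSA's own kernel call is not part of this diff.
import torch
import torch.nn.functional as F

def time_fn(fn, *args, warmup=2, iters=3):
    """Return the mean latency of fn(*args) in milliseconds."""
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Sequence lengths follow the Train / 4096 / 4096 row of the forward table;
# batch size, head count, and head dim are assumptions, not README values.
b, h, seq, d = 1, 16, 4096, 64
q = torch.randn(b, h, seq, d, device="cuda", dtype=torch.float16)
k = torch.randn(b, h, seq, d, device="cuda", dtype=torch.float16)
v = torch.randn(b, h, seq, d, device="cuda", dtype=torch.float16)

sdpa_ms = time_fn(F.scaled_dot_product_attention, q, k, v)
print(f"SDPA: {sdpa_ms:.2f} ms")  # speedup would be sdpa_ms / fsa_ms
```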
