@@ -45,95 +45,6 @@ Thus, a more effective approach is sparse attention: interacting each query with
- Further performance improvements for skipping memory access and computation


- ## Performance
-
- We present the expected speedup of FSA over standard PyTorch SDPA under mask and bias conditions.
-
- ![FSA Performance Overview](assets/performance_overview.png)
-
- ---
-
- ### Forward Pass Performance
-
- The following table shows the forward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs.
-
- | Mode | Q len | K len | Window W | SDPA (ms) | FSA (ms) | Speedup |
- | --------| -------| --------| ----------| -----------| -----------| ---------|
- | Train | 256 | 256 | 1024 | 0.29 | 0.19 | 1.58x |
- | Train | 512 | 512 | 1024 | 0.35 | 0.19 | 1.86x |
- | Train | 1024 | 1024 | 1024 | 0.51 | 0.18 | 2.81x |
- | Train | 2048 | 2048 | 1024 | 1.04 | 0.18 | 5.68x |
- | Train | 4096 | 4096 | 1024 | 2.53 | 0.24 | 10.41x |
- | Train | 8192 | 8192 | 1024 | 9.38 | 0.36 | 25.93x |
- | Train | 16384 | 16384 | 1024 | 28.39 | 0.81 | 35.25x |
- | Train | 32768 | 32768 | 1024 | 111.87 | 2.25 | 49.78x |
- | Train | 32768 | 32768 | 32 | 113.19 | 2.10 | 53.97x |
- | Train | 32768 | 32768 | 64 | 113.17 | 2.12 | 53.32x |
- | Train | 32768 | 32768 | 128 | 113.14 | 2.10 | 53.78x |
- | Train | 32768 | 32768 | 256 | 113.18 | 2.13 | 53.18x |
- | Train | 32768 | 32768 | 512 | 113.19 | 2.17 | 52.17x |
- | Train | 32768 | 32768 | 1024 | 113.19 | 2.24 | 50.45x |
- | Train | 32768 | 32768 | 2048 | 113.15 | 2.39 | 47.35x |
- | Train | 32768 | 32768 | 4096 | 113.16 | 2.67 | 42.39x |
- | Train | 32768 | 32768 | 8192 | 113.11 | 3.20 | 35.29x |
- | Train | 32768 | 32768 | 16384 | 113.15 | 3.97 | 28.51x |
- | Train | 32768 | 32768 | 32768 | 113.11 | 4.90 | 23.10x |
- | Infer | 1 | 256 | 1024 | 0.25 | 0.19 | 1.28x |
- | Infer | 1 | 512 | 1024 | 0.25 | 0.19 | 1.27x |
- | Infer | 1 | 1024 | 1024 | 0.25 | 0.20 | 1.28x |
- | Infer | 1 | 2048 | 1024 | 0.25 | 0.20 | 1.24x |
- | Infer | 1 | 4096 | 1024 | 0.25 | 0.19 | 1.29x |
- | Infer | 1 | 8192 | 1024 | 0.25 | 0.20 | 1.25x |
- | Infer | 1 | 16384 | 1024 | 0.25 | 0.19 | 1.29x |
- | Infer | 1 | 32768 | 1024 | 0.27 | 0.20 | 1.33x |
- | Infer | 1 | 65536 | 1024 | 0.42 | 0.20 | 2.10x |
- | Infer | 1 | 131072 | 1024 | 0.72 | 0.20 | 3.65x |
- | Infer | 1 | 262144 | 1024 | 1.31 | 0.22 | 6.06x |
- | Infer | 1 | 524288 | 1024 | 2.49 | 0.24 | 10.45x |
- | Infer | 1 | 524288 | 32 | 2.48 | 0.21 | 11.60x |
- | Infer | 1 | 524288 | 64 | 2.44 | 0.21 | 11.66x |
- | Infer | 1 | 524288 | 128 | 2.45 | 0.21 | 11.47x |
- | Infer | 1 | 524288 | 256 | 2.43 | 0.21 | 11.47x |
- | Infer | 1 | 524288 | 512 | 2.44 | 0.22 | 10.89x |
- | Infer | 1 | 524288 | 1024 | 2.44 | 0.24 | 10.31x |
- | Infer | 1 | 524288 | 2048 | 2.44 | 0.27 | 9.07x |
- | Infer | 1 | 524288 | 4096 | 2.45 | 0.33 | 7.41x |
- | Infer | 1 | 524288 | 8192 | 2.44 | 0.35 | 6.93x |
- | Infer | 1 | 524288 | 16384 | 2.44 | 0.35 | 6.93x |
- | Infer | 1 | 524288 | 32768 | 2.45 | 0.35 | 6.96x |
- | Infer | 1 | 524288 | 65536 | 2.44 | 0.35 | 6.88x |
-
- ---
-
- ### Backward Pass Performance
-
- The following table shows the backward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs.
-
- | Mode | Q len | K len | Window W | SDPA-BWD (ms) | FSA-BWD (ms) | Speedup |
- | -------| -------| --------| ----------| ---------------| ---------------| ---------|
- | Train | 256 | 256 | 1024 | 0.42 | 0.62 | 0.7x |
- | Train | 512 | 512 | 1024 | 0.56 | 0.60 | 0.9x |
- | Train | 1024 | 1024 | 1024 | 0.94 | 0.61 | 1.5x |
- | Train | 2048 | 2048 | 1024 | 1.79 | 0.69 | 2.6x |
- | Train | 4096 | 4096 | 1024 | 3.76 | 1.08 | 3.5x |
- | Train | 8192 | 8192 | 1024 | 14.39 | 2.06 | 7.0x |
- | Train | 16384 | 16384 | 1024 | 39.56 | 4.97 | 8.0x |
- | Train | 32768 | 32768 | 1024 | 142.07 | 25.63 | 5.5x |
- | Train | 32768 | 32768 | 32 | 142.70 | 21.91 | 6.5x |
- | Train | 32768 | 32768 | 64 | 142.65 | 22.29 | 6.4x |
- | Train | 32768 | 32768 | 128 | 142.69 | 23.04 | 6.2x |
- | Train | 32768 | 32768 | 256 | 142.69 | 24.27 | 5.9x |
- | Train | 32768 | 32768 | 512 | 142.67 | 25.12 | 5.7x |
- | Train | 32768 | 32768 | 1024 | 142.55 | 25.58 | 5.6x |
- | Train | 32768 | 32768 | 2048 | 142.75 | 25.64 | 5.6x |
- | Train | 32768 | 32768 | 4096 | 142.61 | 24.84 | 5.7x |
- | Train | 32768 | 32768 | 8192 | 142.33 | 25.63 | 5.6x |
- | Train | 32768 | 32768 | 16384 | 142.40 | 25.62 | 5.6x |
- | Train | 32768 | 32768 | 32768 | 142.43 | 25.63 | 5.6x |
-
- ---
-
-
## Installation

### Requirements
@@ -150,14 +61,14 @@ The following table shows the backward pass performance comparison between FSA a
You can install FSA via pre-compiled wheels:

```bash
- pip install flash_sparse_attn --no-build-isolation
+ pip install flash-sparse-attn --no-build-isolation
```

Alternatively, you can compile and install from source:

```bash
- git clone https://github.com/SmallDoges/flash_sparse_attn.git
- cd flash_sparse_attn
+ git clone https://github.com/SmallDoges/flash-sparse-attn.git
+ cd flash-sparse-attn
pip install . --no-build-isolation
```

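After either install path, a minimal import check can confirm the build. This is a small sketch, assuming the installed module is importable as `flash_sparse_attn` (the distribution name uses dashes, the module name underscores):

```python
# Post-install sanity check; the module name `flash_sparse_attn` is an assumption.
import torch
import flash_sparse_attn  # noqa: F401

print("flash_sparse_attn imported, CUDA available:", torch.cuda.is_available())
```
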
@@ -245,6 +156,95 @@ print(f"Bias gradient shape: {attn_bias.grad.shape}")
```


+ ## Performance
+
+ We present the expected speedup of FSA over standard PyTorch SDPA under mask and bias conditions.
+
+ ![FSA Performance Overview](assets/performance_overview.png)
+
+ ---
+
+ ### Forward Pass Performance
+
+ The following table shows the forward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs.
+
+ | Mode | Q len | K len | Window W | SDPA (ms) | FSA (ms) | Speedup |
+ | --------| -------| --------| ----------| -----------| -----------| ---------|
+ | Train | 256 | 256 | 1024 | 0.29 | 0.19 | 1.58x |
+ | Train | 512 | 512 | 1024 | 0.35 | 0.19 | 1.86x |
+ | Train | 1024 | 1024 | 1024 | 0.51 | 0.18 | 2.81x |
+ | Train | 2048 | 2048 | 1024 | 1.04 | 0.18 | 5.68x |
+ | Train | 4096 | 4096 | 1024 | 2.53 | 0.24 | 10.41x |
+ | Train | 8192 | 8192 | 1024 | 9.38 | 0.36 | 25.93x |
+ | Train | 16384 | 16384 | 1024 | 28.39 | 0.81 | 35.25x |
+ | Train | 32768 | 32768 | 1024 | 111.87 | 2.25 | 49.78x |
+ | Train | 32768 | 32768 | 32 | 113.19 | 2.10 | 53.97x |
+ | Train | 32768 | 32768 | 64 | 113.17 | 2.12 | 53.32x |
+ | Train | 32768 | 32768 | 128 | 113.14 | 2.10 | 53.78x |
+ | Train | 32768 | 32768 | 256 | 113.18 | 2.13 | 53.18x |
+ | Train | 32768 | 32768 | 512 | 113.19 | 2.17 | 52.17x |
+ | Train | 32768 | 32768 | 1024 | 113.19 | 2.24 | 50.45x |
+ | Train | 32768 | 32768 | 2048 | 113.15 | 2.39 | 47.35x |
+ | Train | 32768 | 32768 | 4096 | 113.16 | 2.67 | 42.39x |
+ | Train | 32768 | 32768 | 8192 | 113.11 | 3.20 | 35.29x |
+ | Train | 32768 | 32768 | 16384 | 113.15 | 3.97 | 28.51x |
+ | Train | 32768 | 32768 | 32768 | 113.11 | 4.90 | 23.10x |
+ | Infer | 1 | 256 | 1024 | 0.25 | 0.19 | 1.28x |
+ | Infer | 1 | 512 | 1024 | 0.25 | 0.19 | 1.27x |
+ | Infer | 1 | 1024 | 1024 | 0.25 | 0.20 | 1.28x |
+ | Infer | 1 | 2048 | 1024 | 0.25 | 0.20 | 1.24x |
+ | Infer | 1 | 4096 | 1024 | 0.25 | 0.19 | 1.29x |
+ | Infer | 1 | 8192 | 1024 | 0.25 | 0.20 | 1.25x |
+ | Infer | 1 | 16384 | 1024 | 0.25 | 0.19 | 1.29x |
+ | Infer | 1 | 32768 | 1024 | 0.27 | 0.20 | 1.33x |
+ | Infer | 1 | 65536 | 1024 | 0.42 | 0.20 | 2.10x |
+ | Infer | 1 | 131072 | 1024 | 0.72 | 0.20 | 3.65x |
+ | Infer | 1 | 262144 | 1024 | 1.31 | 0.22 | 6.06x |
+ | Infer | 1 | 524288 | 1024 | 2.49 | 0.24 | 10.45x |
+ | Infer | 1 | 524288 | 32 | 2.48 | 0.21 | 11.60x |
+ | Infer | 1 | 524288 | 64 | 2.44 | 0.21 | 11.66x |
+ | Infer | 1 | 524288 | 128 | 2.45 | 0.21 | 11.47x |
+ | Infer | 1 | 524288 | 256 | 2.43 | 0.21 | 11.47x |
+ | Infer | 1 | 524288 | 512 | 2.44 | 0.22 | 10.89x |
+ | Infer | 1 | 524288 | 1024 | 2.44 | 0.24 | 10.31x |
+ | Infer | 1 | 524288 | 2048 | 2.44 | 0.27 | 9.07x |
+ | Infer | 1 | 524288 | 4096 | 2.45 | 0.33 | 7.41x |
+ | Infer | 1 | 524288 | 8192 | 2.44 | 0.35 | 6.93x |
+ | Infer | 1 | 524288 | 16384 | 2.44 | 0.35 | 6.93x |
+ | Infer | 1 | 524288 | 32768 | 2.45 | 0.35 | 6.96x |
+ | Infer | 1 | 524288 | 65536 | 2.44 | 0.35 | 6.88x |
+
+ ---
+
+ ### Backward Pass Performance
+
+ The following table shows the backward pass performance comparison between FSA and standard PyTorch SDPA on an NVIDIA A100-SXM4-80GB. Results are averaged over 3 runs after 2 warmup runs.
+
+ | Mode | Q len | K len | Window W | SDPA-BWD (ms) | FSA-BWD (ms) | Speedup |
+ | -------| -------| --------| ----------| ---------------| ---------------| ---------|
+ | Train | 256 | 256 | 1024 | 0.42 | 0.62 | 0.7x |
+ | Train | 512 | 512 | 1024 | 0.56 | 0.60 | 0.9x |
+ | Train | 1024 | 1024 | 1024 | 0.94 | 0.61 | 1.5x |
+ | Train | 2048 | 2048 | 1024 | 1.79 | 0.69 | 2.6x |
+ | Train | 4096 | 4096 | 1024 | 3.76 | 1.08 | 3.5x |
+ | Train | 8192 | 8192 | 1024 | 14.39 | 2.06 | 7.0x |
+ | Train | 16384 | 16384 | 1024 | 39.56 | 4.97 | 8.0x |
+ | Train | 32768 | 32768 | 1024 | 142.07 | 25.63 | 5.5x |
+ | Train | 32768 | 32768 | 32 | 142.70 | 21.91 | 6.5x |
+ | Train | 32768 | 32768 | 64 | 142.65 | 22.29 | 6.4x |
+ | Train | 32768 | 32768 | 128 | 142.69 | 23.04 | 6.2x |
+ | Train | 32768 | 32768 | 256 | 142.69 | 24.27 | 5.9x |
+ | Train | 32768 | 32768 | 512 | 142.67 | 25.12 | 5.7x |
+ | Train | 32768 | 32768 | 1024 | 142.55 | 25.58 | 5.6x |
+ | Train | 32768 | 32768 | 2048 | 142.75 | 25.64 | 5.6x |
+ | Train | 32768 | 32768 | 4096 | 142.61 | 24.84 | 5.7x |
+ | Train | 32768 | 32768 | 8192 | 142.33 | 25.63 | 5.6x |
+ | Train | 32768 | 32768 | 16384 | 142.40 | 25.62 | 5.6x |
+ | Train | 32768 | 32768 | 32768 | 142.43 | 25.63 | 5.6x |
+
+ ---
+
+
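For reference, the timing methodology above (2 warmup runs, then an average over 3 timed runs) can be sketched as follows. This is a minimal illustration, not the repo's benchmark script: the batch size, head count, and head dimension are assumed, and `fsa_attention` is a hypothetical placeholder rather than the actual FSA API.

```python
# Rough sketch of the benchmark timing loop: 2 warmup runs, then the
# average of 3 timed runs measured with CUDA events (milliseconds).
import torch
import torch.nn.functional as F

def time_ms(fn, warmup=2, iters=3):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Assumed shapes for illustration only; the table above does not specify them.
b, h, seqlen, d = 1, 16, 4096, 128
q = torch.randn(b, h, seqlen, d, device="cuda", dtype=torch.bfloat16)
k = torch.randn(b, h, seqlen, d, device="cuda", dtype=torch.bfloat16)
v = torch.randn(b, h, seqlen, d, device="cuda", dtype=torch.bfloat16)

# Dense SDPA baseline (the "SDPA (ms)" column).
sdpa_ms = time_ms(lambda: F.scaled_dot_product_attention(q, k, v, is_causal=True))
print(f"SDPA forward: {sdpa_ms:.2f} ms")

# FSA would be timed the same way, e.g.:
# fsa_ms = time_ms(lambda: fsa_attention(q, k, v, window_size=1024))  # hypothetical call
```

CUDA events are used instead of wall-clock timers so that asynchronous kernel launches are measured correctly.
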
## Benchmarking

FSA provides comprehensive benchmarking tools to evaluate performance across different configurations: