
Commit 951abdc

Adds comprehensive documentation with logo and Chinese translation
Enhances project presentation with professional branding and multilingual support. Introduces centered logo banner and performance visualization charts to improve visual appeal and communicate project capabilities more effectively. Provides complete Chinese translation of documentation to expand accessibility for Chinese-speaking developers and researchers. Updates acknowledgments section with emoji for friendlier tone while maintaining professional content structure.
1 parent ac9d2aa commit 951abdc

File tree

4 files changed: +309 -2 lines changed

README.md

Lines changed: 27 additions & 2 deletions
@@ -1,9 +1,25 @@
-# Flash Dynamic Mask Attention
+<div align="center">
+<img src="./assets/logo.png" alt="SmallDoges" width="100%">
+</div>
+
+<div align="center">
+
+
+**English** | [简体中文](./README_zh.md)
+
+</div>
+
+**Trainable Dynamic Mask Sparse Attention**
+
+> Jingze Shi, Yifan Wu, Bingheng Wu, Yiran Peng, Liangdong Wang, Guang Liu, Yuyu Luo
+
+> Paper: https://huggingface.co/papers/2508.02124
 
 ![Flash-DMA Banner](assets/flash_dmattn_banner.png)
 
 Flash-DMA is a high-performance attention implementation that integrates Flash Attention's memory efficiency with Dynamic Mask Attention's sparse computation capabilities for processing extremely long sequences in transformer models.
 
+
 ## Key Features
 
 - **Sparse Attention Computation**: Dynamically selects the most important keys for each query, reducing computation from $O(N^2)$ to $O(N \cdot w)$ where $w \ll N$.
@@ -12,6 +28,14 @@ Flash-DMA is a high-performance attention implementation that integrates Flash A
 - **Long Sequence Support**: Efficiently handles sequences of 128K+ tokens through dynamic masking when sequence length exceeds `keep_window_size`.
 - **Advanced Integration**: Complete integration from Python frontend to CUDA backend with optimized memory layouts and sparse computation strategies.
 
+
+## Performance
+
+We present the expected speedup of Flash-DMA over standard PyTorch SDPA.
+
+![Speedup](assets/speedup.png)
+
+
 ## Installation
 
 ### Prerequisites
@@ -43,6 +67,7 @@ git submodule update --init --recursive
 pip install .
 ```
 
+
 ## Quick Start
 
 ```python
@@ -254,4 +279,4 @@ This project builds upon and integrates several excellent works:
 - **[Flash-Attention](https://github.com/Dao-AILab/flash-attention)** - Memory-efficient attention computation
 - **[NVIDIA CUTLASS](https://github.com/NVIDIA/cutlass)** - High-performance matrix operations library
 
-We thank the open-source community for their contributions to efficient transformer implementations.
+We thank the open-source community for their contributions to efficient transformer implementations. 🤗

README_zh.md

Lines changed: 282 additions & 0 deletions
@@ -0,0 +1,282 @@
<div align="center">
<img src="./assets/logo.png" alt="SmallDoges" width="100%">
</div>

<div align="center">

[English](./README.md) | **简体中文**

</div>

**Trainable Dynamic Mask Sparse Attention**

> Jingze Shi, Yifan Wu, Bingheng Wu, Yiran Peng, Liangdong Wang, Guang Liu, Yuyu Luo

> Paper: https://huggingface.co/papers/2508.02124

![Flash-DMA Banner](assets/flash_dmattn_banner.png)

Flash-DMA is a high-performance attention implementation that integrates Flash Attention's memory efficiency with Dynamic Mask Attention's sparse computation capabilities for processing extremely long sequences in transformer models.


## Key Features

- **Sparse Attention Computation**: Dynamically selects the most important keys for each query, reducing computation from $O(N^2)$ to $O(N \cdot w)$ where $w \ll N$ (see the quick calculation below).
- **Memory Efficiency**: Maintains Flash Attention's $O(N)$ memory complexity without materializing the full attention matrix.
- **CUDA Acceleration**: Deep integration at the CUDA kernel level with custom sparse GEMM operations for optimal performance.
- **Long Sequence Support**: Efficiently handles sequences of 128K+ tokens through dynamic masking when sequence length exceeds `keep_window_size`.
- **Advanced Integration**: Complete integration from Python frontend to CUDA backend with optimized memory layouts and sparse computation strategies.
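
A quick back-of-the-envelope calculation puts the $O(N \cdot w)$ claim in perspective, taking the Quick Start default of `keep_window_size = 2048` and a 128K-token sequence (illustrative numbers, not a benchmark):

```python
# Fraction of query-key pairs actually scored under the sparse scheme.
N = 128 * 1024   # sequence length (128K tokens)
w = 2048         # keep_window_size, as in the Quick Start example
print(f"{N * w / (N * N):.2%} of the dense O(N^2) pairs")  # 1.56%
```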


## Performance

We present the expected speedup of Flash-DMA over standard PyTorch SDPA.

![Speedup](assets/speedup.png)


## Installation

### Prerequisites

- **Python**: 3.8 or later
- **PyTorch**: 2.0.0 or later
- **CUDA**: 11.8 or later
- **NVIDIA GPU**: Compute capability 8.0 or higher
- **C++ compiler**: GCC 7+

### CUDA Environment Setup

Make sure your CUDA environment is configured correctly:

```bash
# Check the CUDA installation
nvcc --version

# Set CUDA_HOME if needed
export CUDA_HOME=/usr/local/cuda
```

### Install from Source

```bash
git clone https://github.com/SmallDoges/flash-dmattn.git
cd flash-dmattn
git submodule update --init --recursive
pip install .
```


## Quick Start

```python
import torch
from flash_dmattn import flash_dmattn_func
import math

# Setup
batch_size, seq_len, num_heads, head_dim = 2, 4096, 12, 128
device = torch.device('cuda')
dtype = torch.bfloat16

# Input tensors
query = torch.randn(batch_size, seq_len, num_heads, head_dim,
                    device=device, dtype=dtype)
key = torch.randn(batch_size, seq_len, num_heads, head_dim,
                  device=device, dtype=dtype)
value = torch.randn(batch_size, seq_len, num_heads, head_dim,
                    device=device, dtype=dtype)

# Create the mask and bias for sparse attention
attention_bias = torch.randn(batch_size, num_heads, seq_len, seq_len,
                             device=device, dtype=dtype)
attention_mask = torch.ones(batch_size, num_heads, seq_len, seq_len,
                            device=device, dtype=dtype)

# Apply dynamic masking (keep the top-k keys for long sequences)
keep_window_size = 2048
if seq_len > keep_window_size:
    # Select the top-k most important keys for each query
    topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1,
                              largest=True, sorted=False).indices
    attention_mask.zero_()
    attention_mask.scatter_(-1, topk_indices, 1.0)

# Run Flash Dynamic Mask Attention
output = flash_dmattn_func(
    q=query,
    k=key,
    v=value,
    attn_mask=attention_mask,
    attn_bias=attention_bias,
    softmax_scale=1.0/math.sqrt(head_dim),
    is_causal=True
)

print(f"Output shape: {output.shape}")  # [2, 4096, 12, 128]
```
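
Continuing the example above, the same call also participates in autograd when the inputs require gradients; the gradient benchmark below exercises this path. A minimal sketch, assuming `flash_dmattn_func` is differentiable with respect to `q`, `k`, and `v` (and possibly the bias as well, per the paper's trainable-mask formulation):

```python
# Continues the Quick Start example: mark the inputs as differentiable and backprop.
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)

output = flash_dmattn_func(
    q=query, k=key, v=value,
    attn_mask=attention_mask,
    attn_bias=attention_bias,
    softmax_scale=1.0 / math.sqrt(head_dim),
    is_causal=True,
)

# A scalar loss is enough to drive the backward pass.
output.sum().backward()
print(query.grad.shape, key.grad.shape, value.grad.shape)
```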


## How It Works

Flash-DMA combines two complementary techniques:

- **Dynamic Mask Attention**: Computes relevance scores for the keys and selects only the most important ones for attention computation
- **Flash Attention**: Processes attention block by block to reduce memory usage and HBM accesses

### Integration Approach

The integration happens at the CUDA kernel level and has several key components:

- **ZOH states**: Precomputed importance scores for key selection
- **Active mask**: A binary mask indicating which keys each query should consider
- **Sparse matrix multiplication**: Custom CUDA kernels for efficient sparse attention computation
- **Block-wise processing**: Retains Flash Attention's tiled approach for memory efficiency

This yields a hybrid attention mechanism that achieves both memory and computational efficiency for long sequences.
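
For intuition only, the sketch below is a dense PyTorch rendering of the masking semantics described above: scores plus a bias, a top-k active mask per query, and softmax restricted to the selected keys. It materializes the full attention matrix, which is exactly what the fused CUDA kernel avoids, and the helper name is ours rather than part of the package:

```python
import math
import torch

def dmattn_reference(q, k, v, bias, keep_window_size=2048):
    # q, k, v: [batch, heads, seq_len, head_dim]; bias: [batch, heads, seq_len, seq_len]
    # Dense reference of dynamic-mask attention (causal masking omitted for brevity).
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale + bias

    seq_len = q.shape[-2]
    if seq_len > keep_window_size:
        # Active mask: keep only the top-k keys per query, ranked by the bias.
        topk = torch.topk(bias, keep_window_size, dim=-1, largest=True, sorted=False).indices
        active = torch.zeros_like(bias).scatter_(-1, topk, 1.0).bool()
        scores = scores.masked_fill(~active, float("-inf"))

    return torch.matmul(torch.softmax(scores, dim=-1), v)
```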


## Documentation

📚 **Complete documentation is available in the [docs](docs/) directory:**

- **[API Reference](docs/api_reference.md)** - Complete function documentation and usage examples
- **[Integration Guide](docs/integration.md)** - Detailed technical documentation of the Flash Attention integration


## Building from Source

### Development Setup

```bash
# Clone with submodules
git clone --recursive https://github.com/SmallDoges/flash-dmattn.git
cd flash-dmattn

# Build in development mode
pip install -e .

# Run a quick test to verify the installation
python -c "import flash_dma_cuda; print('✅ Flash DMA CUDA extension imported successfully')"
```

### Build Requirements

- CUDA Toolkit 11.8+
- CUTLASS library
- PyTorch built with CUDA support

### Supported Architectures

- **SM 8.0**
- **SM 9.0**
- **SM 10.0**
- **SM 12.0**

**Note**: Flash Dynamic Mask Attention requires CUDA compute capability 8.0+ for optimal performance. Earlier architectures are not supported.
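
A quick way to check whether the local GPU meets this requirement, using standard PyTorch APIs:

```python
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"Compute capability: {major}.{minor}")
    print("Supported" if (major, minor) >= (8, 0) else "Not supported: compute capability 8.0+ is required")
else:
    print("No CUDA device detected")
```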


## Benchmarking

Flash-DMA ships comprehensive benchmarking tools for evaluating performance across different configurations:

### Forward Pass Equivalence
```bash
python benchmarks/benchmark_forward_equivalence.py
```
Validates numerical consistency between the Python reference implementation and the CUDA implementation.
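
Conceptually, such a check compares the Flash-DMA output against a dense reference under the same bias. The sketch below uses made-up shapes and assumes that an all-ones mask reduces the kernel to dense attention with an additive bias; the bundled script is the authoritative version:

```python
import math
import torch
import torch.nn.functional as F
from flash_dmattn import flash_dmattn_func

b, s, h, d = 1, 512, 4, 64
q = torch.randn(b, s, h, d, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
bias = torch.randn(b, h, s, s, device="cuda", dtype=torch.bfloat16)
mask = torch.ones(b, h, s, s, device="cuda", dtype=torch.bfloat16)

out = flash_dmattn_func(q=q, k=k, v=v, attn_mask=mask, attn_bias=bias,
                        softmax_scale=1.0 / math.sqrt(d), is_causal=True)

# Dense reference: SDPA accepts an additive float mask, so fold the bias and a causal mask together.
causal = torch.triu(torch.full((s, s), float("-inf"), device="cuda"), diagonal=1)
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2).float(), k.transpose(1, 2).float(), v.transpose(1, 2).float(),
    attn_mask=bias.float() + causal,
).transpose(1, 2)

print("max abs diff:", (out.float() - ref).abs().max().item())
```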

### Performance Benchmarks
```bash
python benchmarks/benchmark_forward_performance.py
```
Compares Flash-DMA against standard Flash Attention across various sequence lengths and batch sizes.
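
For ad-hoc measurements outside the bundled script, CUDA events give a simple timing harness. A sketch with illustrative shapes; the `time_cuda` helper is ours, not part of the package:

```python
import math
import torch
import torch.nn.functional as F
from flash_dmattn import flash_dmattn_func

def time_cuda(fn, warmup=5, iters=20):
    # Average wall time of a CUDA callable in milliseconds, measured with events.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

b, s, h, d = 1, 4096, 12, 128
q = torch.randn(b, s, h, d, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
bias = torch.randn(b, h, s, s, device="cuda", dtype=torch.bfloat16)
mask = torch.ones_like(bias)

sdpa_ms = time_cuda(lambda: F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True))
fdma_ms = time_cuda(lambda: flash_dmattn_func(
    q=q, k=k, v=v, attn_mask=mask, attn_bias=bias,
    softmax_scale=1.0 / math.sqrt(d), is_causal=True))
print(f"SDPA: {sdpa_ms:.2f} ms | Flash-DMA: {fdma_ms:.2f} ms")
```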

### Gradient Computation
```bash
python benchmarks/benchmark_grad.py
```
Tests the backward pass implementation and gradient equivalence.

### Multi-Query Associative Recall
```bash
python benchmarks/benchmark_mqar.py
```
Evaluates performance on long-range reasoning tasks.


## Troubleshooting

### Common Issues

**Compilation errors**
```bash
# Make sure CUDA_HOME is set correctly
echo $CUDA_HOME       # Linux/Mac
echo $env:CUDA_HOME   # Windows PowerShell

# Check the CUDA toolkit version
nvcc --version

# Verify PyTorch CUDA support
python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
```

**Import errors**
```python
# Test the basic import
try:
    from flash_dmattn import flash_dmattn_func, get_available_backends
    print("✅ Flash Dynamic Mask Attention imported successfully")
    print(f"Available backends: {get_available_backends()}")
except ImportError as e:
    print(f"❌ Import failed: {e}")
    print("Make sure the package is installed correctly with: pip install -e .")
```

**Performance issues**
```python
# Monitor GPU memory usage (query, key, and value are the tensors from the Quick Start example)
import torch
from flash_dmattn import flash_dmattn_func

def print_memory_stats():
    if torch.cuda.is_available():
        print(f"GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

print_memory_stats()
output = flash_dmattn_func(q=query, k=key, v=value, is_causal=True)
print_memory_stats()

# Clear the cache if needed
torch.cuda.empty_cache()
```
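
Peak usage is often more telling than the instantaneous counter; PyTorch tracks it per device and lets you reset it between runs:

```python
import torch

torch.cuda.reset_peak_memory_stats()
# ... run the attention call under investigation here ...
print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
```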


## License

This project is licensed under the BSD 3-Clause License. See [LICENSE](LICENSE) for details.

## Citation

If you use Flash-DMA in your research, please cite:

```bibtex
@misc{shi2025trainabledynamicmasksparse,
      title={Trainable Dynamic Mask Sparse Attention},
      author={Jingze Shi and Yifan Wu and Bingheng Wu and Yiran Peng and Liangdong Wang and Guang Liu and Yuyu Luo},
      year={2025},
      eprint={2508.02124},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2508.02124},
}
```

## Acknowledgments

This project builds upon and integrates several excellent works:

- **[OpenSeek](https://github.com/FlagAI-Open/OpenSeek)** - Kernel development support
- **[Flash-Attention](https://github.com/Dao-AILab/flash-attention)** - Memory-efficient attention computation
- **[NVIDIA CUTLASS](https://github.com/NVIDIA/cutlass)** - High-performance matrix operations library

We thank the open-source community for their contributions to efficient transformer implementations. 🤗

assets/logo.png

308 KB

assets/speedup.png

260 KB

0 commit comments
