Commit e5eb029

Updates API to use auto backend selection function

- Replaces the direct function import with auto backend selection for better flexibility.
- Changes parameter names from abbreviated forms to full descriptive names for improved clarity.
- Updates num_heads from 12 to 16 in the examples to reflect more common model configurations.
- Renames the softmax_scale parameter to scale for consistency with standard naming conventions.
1 parent 6f036c1 commit e5eb029
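
The core of the change is how the attention function is obtained. A minimal sketch of the new import pattern, assuming the package is installed and the CUDA backend is available:

```python
from flash_dmattn import flash_dmattn_func_auto

# Previously the kernel entry point was imported directly:
#   from flash_dmattn import flash_dmattn_func
# Now the callable is resolved through backend selection at runtime:
flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
```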

2 files changed (README.md, README_zh.md): +19 −13 lines
README.md
Lines changed: 8 additions & 5 deletions

@@ -72,11 +72,11 @@ pip install .
 
 ```python
 import torch
-from flash_dmattn import flash_dmattn_func
+from flash_dmattn import flash_dmattn_func_auto
 import math
 
 # Setup
-batch_size, seq_len, num_heads, head_dim = 2, 4096, 12, 128
+batch_size, seq_len, num_heads, head_dim = 2, 4096, 16, 128
 device = torch.device('cuda')
 dtype = torch.bfloat16
 
@@ -103,18 +103,21 @@ if seq_len > keep_window_size:
     attention_mask.zero_()
     attention_mask.scatter(-1, topk_indices, 1.0)
 
+# Select backend
+flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
+
 # Run Flash Dynamic Mask Attention
 output = flash_dmattn_func(
     q=query,
     k=key,
     v=value,
     attn_mask=attention_mask,
     attn_bias=attention_bias,
-    softmax_scale=1.0/math.sqrt(head_dim),
-    is_causal=True
+    is_causal=True,
+    scale=1.0/math.sqrt(head_dim),
 )
 
-print(f"Output shape: {output.shape}") # [2, 4096, 12, 128]
+print(f"Output shape: {output.shape}") # [2, 4096, 16, 128]
 ```

README_zh.md
Lines changed: 11 additions & 8 deletions

@@ -72,11 +72,11 @@ pip install .
 
 ```python
 import torch
-from flash_dmattn import flash_dmattn_func
+from flash_dmattn import flash_dmattn_func_auto
 import math
 
 # 设置
-batch_size, seq_len, num_heads, head_dim = 2, 4096, 12, 128
+batch_size, seq_len, num_heads, head_dim = 2, 4096, 16, 128
 device = torch.device('cuda')
 dtype = torch.bfloat16
 
@@ -103,18 +103,21 @@ if seq_len > keep_window_size:
     attention_mask.zero_()
     attention_mask.scatter(-1, topk_indices, 1.0)
 
+# 选择后端
+flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
+
 # 运行 Flash 动态掩码注意力
 output = flash_dmattn_func(
-    q=query,
-    k=key,
-    v=value,
+    query=query,
+    key=key,
+    value=value,
    attn_mask=attention_mask,
    attn_bias=attention_bias,
-    softmax_scale=1.0/math.sqrt(head_dim),
-    is_causal=True
+    is_causal=True,
+    scale=1.0/math.sqrt(head_dim),
 )
 
-print(f"输出形状: {output.shape}") # [2, 4096, 12, 128]
+print(f"输出形状: {output.shape}") # [2, 4096, 16, 128]
 ```
