
Commit e3379c3

Standardize parameter naming and improve API consistency in attention functions
Merge commit e3379c3 (2 parents: 6403ad1 + 39ddb40)

7 files changed: +92 −86 lines

README.md

Lines changed: 8 additions & 5 deletions

@@ -72,11 +72,11 @@ pip install .
 
 ```python
 import torch
-from flash_dmattn import flash_dmattn_func
+from flash_dmattn import flash_dmattn_func_auto
 import math
 
 # Setup
-batch_size, seq_len, num_heads, head_dim = 2, 4096, 12, 128
+batch_size, seq_len, num_heads, head_dim = 2, 4096, 16, 128
 device = torch.device('cuda')
 dtype = torch.bfloat16
 
@@ -103,18 +103,21 @@ if seq_len > keep_window_size:
     attention_mask.zero_()
     attention_mask.scatter(-1, topk_indices, 1.0)
 
+# Select backend
+flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
+
 # Run Flash Dynamic Mask Attention
 output = flash_dmattn_func(
     q=query,
     k=key,
     v=value,
     attn_mask=attention_mask,
     attn_bias=attention_bias,
-    softmax_scale=1.0/math.sqrt(head_dim),
-    is_causal=True
+    is_causal=True,
+    scale=1.0/math.sqrt(head_dim),
 )
 
-print(f"Output shape: {output.shape}")  # [2, 4096, 12, 128]
+print(f"Output shape: {output.shape}")  # [2, 4096, 16, 128]
 ```
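For quick orientation, here is a self-contained sketch of the updated usage pattern from this README hunk. It is an illustration only: the tensor shapes follow the comments in the benchmark files of this commit, a smaller sequence length is used so the placeholder dense mask/bias fit comfortably in memory, and an all-ones mask stands in for the top-k mask construction the README builds above this hunk.

```python
import math

import torch
from flash_dmattn import flash_dmattn_func_auto

# Resolve the CUDA backend once and reuse the returned callable (as the README now does)
flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")

batch_size, seq_len, num_heads, head_dim = 2, 1024, 16, 128
device, dtype = torch.device("cuda"), torch.bfloat16

query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)

# Placeholder mask/bias: every position kept, zero bias (shapes follow the benchmark comments)
attention_mask = torch.ones(batch_size, num_heads, seq_len, seq_len, device=device, dtype=dtype)
attention_bias = torch.zeros(batch_size, num_heads, seq_len, seq_len, device=device, dtype=dtype)

output = flash_dmattn_func(
    q=query,
    k=key,
    v=value,
    attn_mask=attention_mask,
    attn_bias=attention_bias,
    is_causal=True,
    scale=1.0 / math.sqrt(head_dim),  # renamed from softmax_scale in this commit
)
print(output.shape)  # expected: torch.Size([2, 1024, 16, 128])
```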
README_zh.md

Lines changed: 11 additions & 8 deletions

@@ -72,11 +72,11 @@ pip install .
 
 ```python
 import torch
-from flash_dmattn import flash_dmattn_func
+from flash_dmattn import flash_dmattn_func_auto
 import math
 
 # Setup
-batch_size, seq_len, num_heads, head_dim = 2, 4096, 12, 128
+batch_size, seq_len, num_heads, head_dim = 2, 4096, 16, 128
 device = torch.device('cuda')
 dtype = torch.bfloat16
 
@@ -103,18 +103,21 @@ if seq_len > keep_window_size:
     attention_mask.zero_()
     attention_mask.scatter(-1, topk_indices, 1.0)
 
+# Select backend
+flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
+
 # Run Flash Dynamic Mask Attention
 output = flash_dmattn_func(
-    q=query,
-    k=key,
-    v=value,
+    query=query,
+    key=key,
+    value=value,
     attn_mask=attention_mask,
     attn_bias=attention_bias,
-    softmax_scale=1.0/math.sqrt(head_dim),
-    is_causal=True
+    is_causal=True,
+    scale=1.0/math.sqrt(head_dim),
 )
 
-print(f"Output shape: {output.shape}")  # [2, 4096, 12, 128]
+print(f"Output shape: {output.shape}")  # [2, 4096, 16, 128]
 ```
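Both README hunks sit just below a top-k mask construction that appears only as context here (keep_window_size, attention_mask.zero_(), attention_mask.scatter(...)). A minimal sketch of that masking step on small placeholder tensors; the shapes and variable names are assumptions based on the visible context, not the exact README code.

```python
import torch

batch_size, num_heads, seq_len = 1, 2, 1024
keep_window_size = 512

# Bias scores for every query/key pair; higher bias means more likely to be kept
attention_bias = torch.randn(batch_size, num_heads, seq_len, seq_len)
attention_mask = torch.ones_like(attention_bias)

if seq_len > keep_window_size:
    # Keep only the keep_window_size highest-bias keys for each query position
    topk_indices = attention_bias.topk(keep_window_size, dim=-1).indices
    attention_mask.zero_()
    attention_mask.scatter_(-1, topk_indices, 1.0)  # in-place so the result is retained

print(attention_mask.sum(dim=-1).unique())  # every row keeps exactly keep_window_size positions
```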
benchmarks/benchmark_forward_equivalence.py

Lines changed: 13 additions & 13 deletions

@@ -257,8 +257,8 @@ def dynamic_mask_attention_cuda(
         attn_mask=attn_mask,            # [batch, num_kv_heads, query_len, key_len]
         attn_bias=attn_bias,            # [batch, num_kv_heads, query_len, key_len]
         dropout_p=0.0,
-        softmax_scale=scaling,
         is_causal=is_causal,
+        scale=scaling,
         softcap=0.0,
         deterministic=True,
         return_attn_probs=return_softmax
@@ -331,10 +331,10 @@ def dynamic_mask_attention_triton(
         query_states,                   # q: [batch, seqlen_q, num_heads, head_dim]
         key_states,                     # k: [batch, seqlen_k, num_heads, head_dim]
         value_states,                   # v: [batch, seqlen_k, num_heads, head_dim]
-        attn_mask,                      # mask: [batch, num_heads, seqlen_q, seqlen_k]
-        attn_bias,                      # bias: [batch, num_heads, seqlen_q, seqlen_k]
-        is_causal,                      # causal masking
-        scaling                         # scaling factor
+        attn_mask=attn_mask,            # mask: [batch, num_heads, seqlen_q, seqlen_k]
+        attn_bias=attn_bias,            # bias: [batch, num_heads, seqlen_q, seqlen_k]
+        is_causal=is_causal,            # causal masking
+        scale=scaling                   # scaling factor
     )
 
     return attn_outputs                 # [batch, query_len, num_heads, head_dim]
@@ -396,14 +396,14 @@ def dynamic_mask_attention_flex(
     # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format
 
     # Call the Flex Attention implementation
-    attn_outputs, _ = flex_dmattn_func(
-        query_states,                   # q: [batch, num_heads, query_len, head_dim]
-        key_states,                     # k: [batch, num_heads, key_len, head_dim]
-        value_states,                   # v: [batch, num_heads, key_len, head_dim]
-        attention_mask=attn_mask,       # attention_mask: [batch, num_heads, query_len, key_len]
-        attention_bias=attn_bias,       # attention_bias: [batch, num_heads, query_len, key_len]
-        is_causal=is_causal,            # is_causal: whether to apply causal masking
-        scaling=scaling                 # scaling factor
+    attn_outputs = flex_dmattn_func(
+        query_states.transpose(1, 2),   # q: [batch, query_len, num_heads, head_dim]
+        key_states.transpose(1, 2),     # k: [batch, key_len, num_heads, head_dim]
+        value_states.transpose(1, 2),   # v: [batch, key_len, num_heads, head_dim]
+        attn_mask=attn_mask,            # attn_mask: [batch, num_heads, query_len, key_len]
+        attn_bias=attn_bias,            # attn_bias: [batch, num_heads, query_len, key_len]
+        is_causal=is_causal,            # is_causal: whether to apply causal masking
+        scale=scaling                   # scaling factor
     )
 
     return attn_outputs                 # [batch, query_len, num_heads, head_dim]

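The flex path now hands flex_dmattn_func tensors in [batch, seq_len, num_heads, head_dim] layout and lets the wrapper transpose back internally (see the flash_dmattn_flex.py hunk below, which calls .transpose(1, 2).contiguous()). A small standalone illustration of that layout round trip, with made-up sizes:

```python
import torch

batch, num_heads, query_len, head_dim = 2, 4, 128, 64

# Layout the benchmark keeps internally for the flex path: [batch, num_heads, query_len, head_dim]
q_bhld = torch.randn(batch, num_heads, query_len, head_dim)

# What the updated call passes in: [batch, query_len, num_heads, head_dim]
q_blhd = q_bhld.transpose(1, 2)
print(q_blhd.shape)  # torch.Size([2, 128, 4, 64])

# The wrapper's first step undoes this and makes the tensor contiguous for the kernel
q_back = q_blhd.transpose(1, 2).contiguous()
assert torch.equal(q_back, q_bhld)
```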
benchmarks/benchmark_forward_performance.py

Lines changed: 13 additions & 13 deletions

@@ -266,8 +266,8 @@ def dynamic_mask_attention_cuda(
         attn_mask=attn_mask,            # [batch, num_kv_heads, query_len, key_len]
         attn_bias=attn_bias,            # [batch, num_kv_heads, query_len, key_len]
         dropout_p=0.0,
-        softmax_scale=scaling,
         is_causal=is_causal,
+        scale=scaling,
         softcap=0.0,
         deterministic=True,
         return_attn_probs=return_softmax
@@ -350,10 +350,10 @@ def dynamic_mask_attention_triton(
         query_states,                   # q: [batch, seqlen_q, num_heads, head_dim]
         key_states,                     # k: [batch, seqlen_k, num_heads, head_dim]
         value_states,                   # v: [batch, seqlen_k, num_heads, head_dim]
-        attn_mask,                      # mask: [batch, num_heads, seqlen_q, seqlen_k]
-        attn_bias,                      # bias: [batch, num_heads, seqlen_q, seqlen_k]
-        is_causal,                      # causal masking
-        scaling                         # scaling factor
+        attn_mask=attn_mask,            # mask: [batch, num_heads, seqlen_q, seqlen_k]
+        attn_bias=attn_bias,            # bias: [batch, num_heads, seqlen_q, seqlen_k]
+        is_causal=is_causal,            # causal masking
+        scale=scaling                   # scaling factor
     )
 
     torch.cuda.synchronize()
@@ -425,14 +425,14 @@ def dynamic_mask_attention_flex(
     start_time = time.time()
 
     # Call the Flex Attention implementation
-    attn_outputs, _ = flex_dmattn_func(
-        query_states,                   # q: [batch, num_heads, query_len, head_dim]
-        key_states,                     # k: [batch, num_heads, key_len, head_dim]
-        value_states,                   # v: [batch, num_heads, key_len, head_dim]
-        attention_mask=attn_mask,       # attention_mask: [batch, num_heads, query_len, key_len]
-        attention_bias=attn_bias,       # attention_bias: [batch, num_heads, query_len, key_len]
-        is_causal=is_causal,            # is_causal: Whether to apply causal masking
-        scaling=scaling                 # scaling factor
+    attn_outputs = flex_dmattn_func(
+        query_states.transpose(1, 2),   # q: [batch, query_len, num_heads, head_dim]
+        key_states.transpose(1, 2),     # k: [batch, key_len, num_heads, head_dim]
+        value_states.transpose(1, 2),   # v: [batch, key_len, num_heads, head_dim]
+        attn_mask=attn_mask,            # attn_mask: [batch, num_heads, query_len, key_len]
+        attn_bias=attn_bias,            # attn_bias: [batch, num_heads, query_len, key_len]
+        is_causal=is_causal,            # is_causal: whether to apply causal masking
+        scale=scaling                   # scaling factor
     )
 
     torch.cuda.synchronize()

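The context lines around these hunks show how the performance benchmark brackets each call with time.time() and torch.cuda.synchronize(). A generic sketch of that timing pattern with a hypothetical timed() helper and a stand-in workload (not code from this repository):

```python
import time

import torch


def timed(fn, *args, warmup: int = 3, iters: int = 10, **kwargs) -> float:
    """Hypothetical helper: average wall-clock seconds per call, synchronizing the GPU."""
    for _ in range(warmup):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    start_time = time.time()
    for _ in range(iters):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    return (time.time() - start_time) / iters


if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    print(f"matmul: {timed(torch.matmul, x, x) * 1e3:.3f} ms")
```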
flash_dmattn/flash_dmattn_flex.py

Lines changed: 1 addition & 1 deletion

@@ -10,8 +10,8 @@ def flex_attention_forward(
     value: torch.Tensor,
     attn_mask: torch.Tensor,
     attn_bias: torch.Tensor,
-    scale: Optional[float] = None,
     is_causal: bool = True,
+    scale: Optional[float] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     query = query.transpose(1, 2).contiguous()  # [B, H, Q_LEN, D]

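Moving scale after is_causal is transparent to keyword callers but changes what a positional sixth argument means. A tiny stand-in function (hypothetical, not the library's API) makes the hazard concrete; this is consistent with the benchmarks in this commit switching the triton and flex calls to keyword arguments.

```python
from typing import Optional


# Stand-in mirroring the new parameter order from this commit
def forward(query, key, value, attn_mask, attn_bias,
            is_causal: bool = True, scale: Optional[float] = None):
    return is_causal, scale


# Keyword call: unaffected by the reorder
print(forward(None, None, None, None, None, scale=0.125, is_causal=True))  # (True, 0.125)

# Positional call written against the old order (scale in sixth position)
# now lands in is_causal instead:
print(forward(None, None, None, None, None, 0.125))  # (0.125, None)
```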