
Commit 6f036c1

Standardizes parameter naming and ordering across attention functions
Renames `softmax_scale` to `scale` and `q`/`k`/`v` to `query`/`key`/`value` across all flash attention function variants, and reorders parameters so that `is_causal` comes before `scale` in function signatures, aligning the API with common attention interface patterns. All call sites, docstrings, and parameter passing are updated to the standardized names.
1 parent 43adc40 commit 6f036c1
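For downstream callers, the practical effect is the new argument names and order. A minimal usage sketch of `flash_dmattn_func` (assuming a working CUDA build of the extension and that the function is importable from `flash_dmattn.flash_dmattn_interface`, as the file path below suggests):

```python
import math
import torch
from flash_dmattn.flash_dmattn_interface import flash_dmattn_func  # import path assumed from the file layout

batch, seqlen, nheads, headdim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Keyword arguments keep the call robust to the signature reordering;
# `scale` (formerly `softmax_scale`) still defaults to 1 / sqrt(headdim) when left as None.
out = flash_dmattn_func(
    query=q,                        # formerly `q`
    key=k,                          # formerly `k`
    value=v,                        # formerly `v`
    attn_mask=None,
    attn_bias=None,
    is_causal=True,
    scale=1.0 / math.sqrt(headdim),
)
```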

3 files changed: 47 additions, 47 deletions
flash_dmattn/flash_dmattn_flex.py (1 addition, 1 deletion)

```diff
@@ -10,8 +10,8 @@ def flex_attention_forward(
     value: torch.Tensor,
     attn_mask: torch.Tensor,
     attn_bias: torch.Tensor,
-    scale: Optional[float] = None,
     is_causal: bool = True,
+    scale: Optional[float] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     query = query.transpose(1, 2).contiguous()  # [B, H, Q_LEN, D]
```
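The reorder is a breaking change for positional callers. A tiny hypothetical sketch (toy functions, not the real `flex_attention_forward`) shows why `scale` should now be passed by keyword:

```python
from typing import Optional

# Toy stand-ins that mirror the parameter reorder in this commit.
def attn_old(x: float, scale: Optional[float] = None, is_causal: bool = True) -> str:
    return f"scale={scale}, is_causal={is_causal}"

def attn_new(x: float, is_causal: bool = True, scale: Optional[float] = None) -> str:
    return f"scale={scale}, is_causal={is_causal}"

print(attn_old(1.0, 0.125))        # scale=0.125, is_causal=True
print(attn_new(1.0, 0.125))        # 0.125 now binds to is_causal (truthy!) and scale stays None
print(attn_new(1.0, scale=0.125))  # keyword call preserves the intended meaning
```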

flash_dmattn/flash_dmattn_interface.py (42 additions, 42 deletions)

```diff
@@ -1151,8 +1151,8 @@ def flash_dmattn_qkvpacked_func(
     attn_mask: Optional[torch.Tensor] = None,
     attn_bias: Optional[torch.Tensor] = None,
     dropout_p: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
+    scale: Optional[float] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
     return_attn_probs: Optional[bool] = None,
@@ -1174,9 +1174,9 @@ def flash_dmattn_qkvpacked_func(
         attn_bias: (batch_size, nheads, seqlen, seqlen). Attention Bias to add to the attention scores.
             If None, no bias is applied.
         dropout_p: float. Dropout probability.
-        softmax_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
         softcap: float. Anything > 0 activates softcapping attention.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
@@ -1197,7 +1197,7 @@ def flash_dmattn_qkvpacked_func(
         attn_mask,
         attn_bias,
         dropout_p,
-        softmax_scale,
+        scale,
         is_causal,
         softcap,
         deterministic,
@@ -1212,7 +1212,7 @@ def flash_dmattn_kvpacked_func(
     attn_mask: Optional[torch.Tensor] = None,
     attn_bias: Optional[torch.Tensor] = None,
     dropout_p: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
+    scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
@@ -1247,9 +1247,9 @@ def flash_dmattn_kvpacked_func(
         attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores.
             If None, no bias is applied.
         dropout_p: float. Dropout probability.
-        softmax_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
         softcap: float. Anything > 0 activates softcapping attention.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
@@ -1271,7 +1271,7 @@ def flash_dmattn_kvpacked_func(
         attn_mask,
         attn_bias,
         dropout_p,
-        softmax_scale,
+        scale,
         is_causal,
         softcap,
         deterministic,
@@ -1281,13 +1281,13 @@ def flash_dmattn_kvpacked_func(
 
 
 def flash_dmattn_func(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
     attn_mask: Optional[torch.Tensor] = None,
     attn_bias: Optional[torch.Tensor] = None,
     dropout_p: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
+    scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
@@ -1312,17 +1312,17 @@ def flash_dmattn_func(
     If the row of the mask is all zero, the output will be zero.
 
     Arguments:
-        q: (batch_size, seqlen, nheads, headdim)
-        k: (batch_size, seqlen, nheads_k, headdim)
-        v: (batch_size, seqlen, nheads_k, headdim)
+        query: (batch_size, seqlen, nheads, headdim)
+        key: (batch_size, seqlen, nheads_k, headdim)
+        value: (batch_size, seqlen, nheads_k, headdim)
         attn_mask: (batch_size, nheads, seqlen_q, seqlen_k). Attention mask to apply to the attention scores.
             If None, no mask is applied.
         attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores.
             If None, no bias is applied.
         dropout_p: float. Dropout probability.
-        softmax_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
@@ -1338,13 +1338,13 @@ def flash_dmattn_func(
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashDMAttnFunc.apply(
-        q,
-        k,
-        v,
+        query,
+        key,
+        value,
         attn_mask,
         attn_bias,
         dropout_p,
-        softmax_scale,
+        scale,
         is_causal,
         softcap,
         deterministic,
@@ -1360,7 +1360,7 @@ def flash_dmattn_varlen_qkvpacked_func(
     cu_seqlens: torch.Tensor = None,
     max_seqlen: int = None,
     dropout_p: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
+    scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
@@ -1383,9 +1383,9 @@ def flash_dmattn_varlen_qkvpacked_func(
             of the sequences in the batch, used to index into qkv.
         max_seqlen: int. Maximum sequence length in the batch.
         dropout_p: float. Dropout probability.
-        softmax_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
         softcap: float. Anything > 0 activates softcapping attention.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
@@ -1408,7 +1408,7 @@ def flash_dmattn_varlen_qkvpacked_func(
         cu_seqlens,
         max_seqlen,
         dropout_p,
-        softmax_scale,
+        scale,
         is_causal,
         softcap,
         deterministic,
@@ -1427,7 +1427,7 @@ def flash_dmattn_varlen_kvpacked_func(
     max_seqlen_q: int = None,
     max_seqlen_k: int = None,
     dropout_p: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
+    scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
@@ -1468,9 +1468,9 @@ def flash_dmattn_varlen_kvpacked_func(
         max_seqlen_q: int. Maximum query sequence length in the batch.
         max_seqlen_k: int. Maximum key sequence length in the batch.
         dropout_p: float. Dropout probability.
-        softmax_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
         softcap: float. Anything > 0 activates softcapping attention.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
@@ -1496,7 +1496,7 @@ def flash_dmattn_varlen_kvpacked_func(
         max_seqlen_q,
         max_seqlen_k,
         dropout_p,
-        softmax_scale,
+        scale,
         is_causal,
         softcap,
         deterministic,
@@ -1506,17 +1506,17 @@ def flash_dmattn_varlen_kvpacked_func(
 
 
 def flash_dmattn_varlen_func(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
     attn_mask: Optional[torch.Tensor] = None,
     attn_bias: Optional[torch.Tensor] = None,
     cu_seqlens_q: torch.Tensor = None,
     cu_seqlens_k: torch.Tensor = None,
     max_seqlen_q: int = None,
     max_seqlen_k: int = None,
     dropout_p: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
+    scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
@@ -1542,9 +1542,9 @@ def flash_dmattn_varlen_func(
     If the row of the mask is all zero, the output will be zero.
 
     Arguments:
-        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
-        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
-        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        query: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
+        key: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        value: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
         attn_mask: (batch_size, nheads, seqlen_q, seqlen_k). Attention mask to apply to the attention scores.
             If None, no mask is applied.
         attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores.
@@ -1556,9 +1556,9 @@ def flash_dmattn_varlen_func(
         max_seqlen_q: int. Maximum query sequence length in the batch.
         max_seqlen_k: int. Maximum key sequence length in the batch.
         dropout_p: float. Dropout probability.
-        softmax_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
         softcap: float. Anything > 0 activates softcapping attention.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
@@ -1575,17 +1575,17 @@ def flash_dmattn_varlen_func(
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
    return FlashDMAttnVarlenFunc.apply(
-        q,
-        k,
-        v,
+        query,
+        key,
+        value,
         attn_mask,
         attn_bias,
         cu_seqlens_q,
         cu_seqlens_k,
         max_seqlen_q,
         max_seqlen_k,
         dropout_p,
-        softmax_scale,
+        scale,
         is_causal,
         softcap,
         deterministic,
```
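The variable-length entry point follows the same convention. A sketch of a packed two-sequence batch (again assuming a CUDA build and the `flash_dmattn.flash_dmattn_interface` import path; shapes and the int32 `cu_seqlens` convention follow the docstring above and the usual flash-attention layout):

```python
import torch
from flash_dmattn.flash_dmattn_interface import flash_dmattn_varlen_func  # import path assumed

nheads, headdim = 8, 64
seqlens = [100, 28]            # two sequences packed into one "batch" of 128 tokens
total = sum(seqlens)

query = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
key = torch.randn_like(query)
value = torch.randn_like(query)

# Cumulative sequence boundaries [0, 100, 128] index the packed tokens.
cu_seqlens = torch.tensor([0, 100, 128], device="cuda", dtype=torch.int32)

out = flash_dmattn_varlen_func(
    query, key, value,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens),
    max_seqlen_k=max(seqlens),
    is_causal=True,
    scale=None,                # formerly softmax_scale; None means 1 / sqrt(headdim)
)
```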

flash_dmattn/flash_dmattn_triton.py (4 additions, 4 deletions)

```diff
@@ -1052,15 +1052,15 @@ def _flash_attn_backward(
 
 class FlashDMAttnFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, softmax_scale=None, is_causal=False):
+    def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=False, softmax_scale=None):
         """
         query: (batch_size, seqlen_q, nheads, headdim)
         key: (batch_size, seqlen_k, nheads, headdim)
         value: (batch_size, seqlen_k, nheads, headdim)
         attn_mask: optional, (batch, nheads, seqlen_q, seqlen_k)
         attn_bias: optional, (batch, nheads, seqlen_q, seqlen_k)
-        softmax_scale: float, scaling factor for attention scores
         is_causal: bool, whether to apply causal masking
+        softmax_scale: float, scaling factor for attention scores
         """
         batch, seqlen_q, nheads, _ = query.shape
         _, seqlen_k, _, _ = key.shape
@@ -1111,5 +1111,5 @@ def backward(ctx, do):
         return dq, dk, dv, None, dbias, None, None
 
 
-def triton_dmattn_func(query, key, value, attn_mask=None, attn_bias=None, scale=None, is_causal=False):
-    return FlashDMAttnFunc.apply(query, key, value, attn_mask, attn_bias, scale, is_causal)
+def triton_dmattn_func(query, key, value, attn_mask=None, attn_bias=None, is_causal=False, scale=None):
+    return FlashDMAttnFunc.apply(query, key, value, attn_mask, attn_bias, is_causal, scale)
```
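The Triton wrapper keeps the same external contract, now with `is_causal` ahead of `scale`. A minimal sketch (assuming Triton is installed and `triton_dmattn_func` is importable from `flash_dmattn.flash_dmattn_triton`, as the file path suggests):

```python
import torch
from flash_dmattn.flash_dmattn_triton import triton_dmattn_func  # import path assumed

batch, seqlen_q, seqlen_k, nheads, headdim = 2, 64, 64, 8, 64
query = torch.randn(batch, seqlen_q, nheads, headdim, device="cuda", dtype=torch.float16)
key = torch.randn(batch, seqlen_k, nheads, headdim, device="cuda", dtype=torch.float16)
value = torch.randn(batch, seqlen_k, nheads, headdim, device="cuda", dtype=torch.float16)

# Positional order is now (..., attn_bias, is_causal, scale); keywords avoid any ambiguity.
out = triton_dmattn_func(query, key, value, attn_mask=None, attn_bias=None, is_causal=True, scale=None)
```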
