@@ -1151,8 +1151,8 @@ def flash_dmattn_qkvpacked_func(
     attn_mask: Optional[torch.Tensor] = None,
     attn_bias: Optional[torch.Tensor] = None,
     dropout_p: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
+    scale: Optional[float] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
     return_attn_probs: Optional[bool] = None,
@@ -1174,9 +1174,9 @@ def flash_dmattn_qkvpacked_func(
         attn_bias: (batch_size, nheads, seqlen, seqlen). Attention Bias to add to the attention scores.
             If None, no bias is applied.
         dropout_p: float. Dropout probability.
-        softmax_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
         softcap: float. Anything > 0 activates softcapping attention.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
@@ -1197,7 +1197,7 @@ def flash_dmattn_qkvpacked_func(
         attn_mask,
         attn_bias,
         dropout_p,
-        softmax_scale,
+        scale,
         is_causal,
         softcap,
         deterministic,
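
With this rename, callers that passed `softmax_scale` by keyword must switch to `scale`. A minimal sketch of the updated call, assuming the usual qkvpacked layout of (batch_size, seqlen, 3, nheads, headdim) for the leading `qkv` argument and assuming the `flash_dmattn` import path (neither appears in this diff):

import torch
from flash_dmattn import flash_dmattn_qkvpacked_func  # assumed import path

# Packed QKV: (batch_size, seqlen, 3, nheads, headdim) -- assumed layout.
qkv = torch.randn(2, 128, 3, 8, 64, device="cuda", dtype=torch.bfloat16)

# Old: flash_dmattn_qkvpacked_func(qkv, softmax_scale=0.125)
# New: the scaling keyword is `scale`; the default stays 1 / sqrt(headdim).
out = flash_dmattn_qkvpacked_func(qkv, is_causal=True, scale=64 ** -0.5)
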
@@ -1212,7 +1212,7 @@ def flash_dmattn_kvpacked_func(
     attn_mask: Optional[torch.Tensor] = None,
     attn_bias: Optional[torch.Tensor] = None,
     dropout_p: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
+    scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
@@ -1247,9 +1247,9 @@ def flash_dmattn_kvpacked_func(
         attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores.
             If None, no bias is applied.
         dropout_p: float. Dropout probability.
-        softmax_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
         softcap: float. Anything > 0 activates softcapping attention.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
@@ -1271,7 +1271,7 @@ def flash_dmattn_kvpacked_func(
         attn_mask,
         attn_bias,
         dropout_p,
-        softmax_scale,
+        scale,
         is_causal,
         softcap,
         deterministic,
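
The same keyword change applies to the kvpacked variant. A sketch under the assumption that the function takes `q` and packed `kv` of shapes (batch_size, seqlen_q, nheads, headdim) and (batch_size, seqlen_k, 2, nheads_k, headdim), following the usual *_kvpacked convention (those leading parameters are outside this diff):

import torch
from flash_dmattn import flash_dmattn_kvpacked_func  # assumed import path

q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
# Packed KV: (batch_size, seqlen_k, 2, nheads_k, headdim) -- assumed layout.
kv = torch.randn(2, 256, 2, 8, 64, device="cuda", dtype=torch.bfloat16)

# `scale` replaces `softmax_scale`; omit it to keep the 1 / sqrt(headdim) default.
out = flash_dmattn_kvpacked_func(q, kv, scale=64 ** -0.5)
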
@@ -1281,13 +1281,13 @@ def flash_dmattn_kvpacked_func(
 
 
 def flash_dmattn_func(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
     attn_mask: Optional[torch.Tensor] = None,
     attn_bias: Optional[torch.Tensor] = None,
     dropout_p: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
+    scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
@@ -1312,17 +1312,17 @@ def flash_dmattn_func(
     If the row of the mask is all zero, the output will be zero.
 
     Arguments:
-        q: (batch_size, seqlen, nheads, headdim)
-        k: (batch_size, seqlen, nheads_k, headdim)
-        v: (batch_size, seqlen, nheads_k, headdim)
+        query: (batch_size, seqlen, nheads, headdim)
+        key: (batch_size, seqlen, nheads_k, headdim)
+        value: (batch_size, seqlen, nheads_k, headdim)
         attn_mask: (batch_size, nheads, seqlen_q, seqlen_k). Attention mask to apply to the attention scores.
             If None, no mask is applied.
         attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores.
             If None, no bias is applied.
         dropout_p: float. Dropout probability.
-        softmax_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
@@ -1338,13 +1338,13 @@ def flash_dmattn_func(
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashDMAttnFunc.apply(
-        q,
-        k,
-        v,
+        query,
+        key,
+        value,
         attn_mask,
         attn_bias,
         dropout_p,
-        softmax_scale,
+        scale,
         is_causal,
         softcap,
         deterministic,
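
Because the first three parameters are renamed as well, keyword callers of flash_dmattn_func need both updates. A sketch using the shapes from the docstring above, with a bias built to match (batch_size, nheads, seqlen_q, seqlen_k); the import path is assumed:

import torch
from flash_dmattn import flash_dmattn_func  # assumed import path

B, L, H, D = 2, 128, 8, 64
query = torch.randn(B, L, H, D, device="cuda", dtype=torch.bfloat16)
key = torch.randn(B, L, H, D, device="cuda", dtype=torch.bfloat16)
value = torch.randn(B, L, H, D, device="cuda", dtype=torch.bfloat16)
attn_bias = torch.zeros(B, H, L, L, device="cuda", dtype=torch.bfloat16)

# Old: flash_dmattn_func(q=..., k=..., v=..., softmax_scale=...)
# New: q/k/v are now query/key/value, and softmax_scale is now scale.
out = flash_dmattn_func(query=query, key=key, value=value,
                        attn_bias=attn_bias, is_causal=True, scale=D ** -0.5)
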
@@ -1360,7 +1360,7 @@ def flash_dmattn_varlen_qkvpacked_func(
     cu_seqlens: torch.Tensor = None,
     max_seqlen: int = None,
     dropout_p: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
+    scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
@@ -1383,9 +1383,9 @@ def flash_dmattn_varlen_qkvpacked_func(
             of the sequences in the batch, used to index into qkv.
         max_seqlen: int. Maximum sequence length in the batch.
         dropout_p: float. Dropout probability.
-        softmax_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
         softcap: float. Anything > 0 activates softcapping attention.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
@@ -1408,7 +1408,7 @@ def flash_dmattn_varlen_qkvpacked_func(
         cu_seqlens,
         max_seqlen,
         dropout_p,
-        softmax_scale,
+        scale,
         is_causal,
         softcap,
         deterministic,
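
For the varlen variant only the keyword changes at the call site. A sketch assuming the packed layout (total, 3, nheads, headdim) and int32 cumulative sequence lengths, which is the usual varlen convention (not shown in this diff):

import torch
from flash_dmattn import flash_dmattn_varlen_qkvpacked_func  # assumed import path

# Two sequences of lengths 100 and 28, packed into total = 128 tokens.
qkv = torch.randn(128, 3, 8, 64, device="cuda", dtype=torch.bfloat16)
cu_seqlens = torch.tensor([0, 100, 128], device="cuda", dtype=torch.int32)

out = flash_dmattn_varlen_qkvpacked_func(
    qkv, cu_seqlens=cu_seqlens, max_seqlen=100,
    scale=64 ** -0.5,  # was softmax_scale before this change
)
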
@@ -1427,7 +1427,7 @@ def flash_dmattn_varlen_kvpacked_func(
     max_seqlen_q: int = None,
     max_seqlen_k: int = None,
     dropout_p: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
+    scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
@@ -1468,9 +1468,9 @@ def flash_dmattn_varlen_kvpacked_func(
         max_seqlen_q: int. Maximum query sequence length in the batch.
         max_seqlen_k: int. Maximum key sequence length in the batch.
         dropout_p: float. Dropout probability.
-        softmax_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
         softcap: float. Anything > 0 activates softcapping attention.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
@@ -1496,7 +1496,7 @@ def flash_dmattn_varlen_kvpacked_func(
         max_seqlen_q,
         max_seqlen_k,
         dropout_p,
-        softmax_scale,
+        scale,
         is_causal,
         softcap,
         deterministic,
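
A corresponding sketch for the kvpacked varlen call, assuming `q` of shape (total_q, nheads, headdim) and packed `kv` of shape (total_k, 2, nheads_k, headdim), with separate cumulative lengths for queries and keys (the leading parameters are outside this diff):

import torch
from flash_dmattn import flash_dmattn_varlen_kvpacked_func  # assumed import path

q = torch.randn(128, 8, 64, device="cuda", dtype=torch.bfloat16)
kv = torch.randn(256, 2, 8, 64, device="cuda", dtype=torch.bfloat16)
cu_seqlens_q = torch.tensor([0, 100, 128], device="cuda", dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, 200, 256], device="cuda", dtype=torch.int32)

out = flash_dmattn_varlen_kvpacked_func(
    q, kv,
    cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=100, max_seqlen_k=200,
    scale=64 ** -0.5,  # was softmax_scale before this change
)
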
@@ -1506,17 +1506,17 @@ def flash_dmattn_varlen_kvpacked_func(
 
 
 def flash_dmattn_varlen_func(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
     attn_mask: Optional[torch.Tensor] = None,
     attn_bias: Optional[torch.Tensor] = None,
     cu_seqlens_q: torch.Tensor = None,
     cu_seqlens_k: torch.Tensor = None,
     max_seqlen_q: int = None,
     max_seqlen_k: int = None,
     dropout_p: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
+    scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
@@ -1542,9 +1542,9 @@ def flash_dmattn_varlen_func(
     If the row of the mask is all zero, the output will be zero.
 
     Arguments:
-        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
-        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
-        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        query: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
+        key: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        value: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
         attn_mask: (batch_size, nheads, seqlen_q, seqlen_k). Attention mask to apply to the attention scores.
             If None, no mask is applied.
         attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores.
@@ -1556,9 +1556,9 @@ def flash_dmattn_varlen_func(
         max_seqlen_q: int. Maximum query sequence length in the batch.
         max_seqlen_k: int. Maximum key sequence length in the batch.
         dropout_p: float. Dropout probability.
-        softmax_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
         softcap: float. Anything > 0 activates softcapping attention.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
@@ -1575,17 +1575,17 @@ def flash_dmattn_varlen_func(
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashDMAttnVarlenFunc.apply(
-        q,
-        k,
-        v,
+        query,
+        key,
+        value,
         attn_mask,
         attn_bias,
         cu_seqlens_q,
         cu_seqlens_k,
         max_seqlen_q,
         max_seqlen_k,
         dropout_p,
-        softmax_scale,
+        scale,
         is_causal,
         softcap,
         deterministic,
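
Finally, flash_dmattn_varlen_func gets both renames. A sketch grounded in the docstring shapes above, with the import path assumed:

import torch
from flash_dmattn import flash_dmattn_varlen_func  # assumed import path

# total_q = 128 and total_k = 256 tokens across a batch of two sequences.
query = torch.randn(128, 8, 64, device="cuda", dtype=torch.bfloat16)
key = torch.randn(256, 8, 64, device="cuda", dtype=torch.bfloat16)
value = torch.randn(256, 8, 64, device="cuda", dtype=torch.bfloat16)
cu_seqlens_q = torch.tensor([0, 100, 128], device="cuda", dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, 200, 256], device="cuda", dtype=torch.int32)

out = flash_dmattn_varlen_func(
    query=query, key=key, value=value,   # formerly q, k, v
    cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=100, max_seqlen_k=200,
    scale=64 ** -0.5,                    # formerly softmax_scale
)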