
Commit c616403

Copilot and LoserCheems committed
Fix varlen mask and bias shapes in all varlen functions
Co-authored-by: LoserCheems <124847097+LoserCheems@users.noreply.github.com>
1 parent ad22798 commit c616403

1 file changed: +17 -15


flash_dmattn/flash_dmattn_interface.py

Lines changed: 17 additions & 15 deletions
@@ -511,12 +511,12 @@ def forward(
     ):
         # qkv is expected to be of shape (total, 3, num_heads, head_size)
         batch_size = cu_seqlens.numel() - 1
-        _, num_heads, _ = qkv.shape
+        total_tokens, num_heads, _ = qkv.shape
         is_grad = is_grad_enabled and qkv.requires_grad
         if mask is None:
-            mask = torch.ones((batch_size, num_heads, max_seqlen, max_seqlen), dtype=qkv.dtype, device=qkv.device)
+            mask = torch.ones((total_tokens, num_heads, max_seqlen), dtype=qkv.dtype, device=qkv.device)
         if bias is None:
-            bias = torch.zeros((batch_size, num_heads, max_seqlen, max_seqlen), dtype=qkv.dtype, device=qkv.device)
+            bias = torch.zeros((total_tokens, num_heads, max_seqlen), dtype=qkv.dtype, device=qkv.device)
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
         if is_causal is None:
@@ -737,14 +737,15 @@ def forward(
         # q is expected to be of shape (total, num_heads, head_size)
         # kv is expected to be of shape (total, 2, num_heads, head_size)
         batch_size = cu_seqlens_q.numel() - 1
-        _, num_heads, _ = q.shape
+        total_q, num_heads, _ = q.shape
+        _, _, num_heads_k, _ = kv.shape
         is_grad = is_grad_enabled and any(
             x.requires_grad for x in [q, kv]
         )
         if mask is None:
-            mask = torch.ones((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device)
+            mask = torch.ones((total_q, num_heads_k, max_seqlen_k), dtype=q.dtype, device=q.device)
         if bias is None:
-            bias = torch.zeros((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device)
+            bias = torch.zeros((total_q, num_heads_k, max_seqlen_k), dtype=q.dtype, device=q.device)
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         if is_causal is None:
@@ -967,14 +968,15 @@ def forward(
     ):
         # q, k, v are expected to be of shape (total, num_heads, head_size)
         batch_size = cu_seqlens_q.numel() - 1
-        _, num_heads, _ = q.shape
+        total_q, num_heads, _ = q.shape
+        _, num_heads_k, _ = k.shape
         is_grad = is_grad_enabled and any(
             x.requires_grad for x in [q, k, v]
         )
         if mask is None:
-            mask = torch.ones((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device)
+            mask = torch.ones((total_q, num_heads_k, max_seqlen_k), dtype=q.dtype, device=q.device)
         if bias is None:
-            bias = torch.zeros((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device)
+            bias = torch.zeros((total_q, num_heads_k, max_seqlen_k), dtype=q.dtype, device=q.device)
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         if is_causal is None:
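Taken together, the three forward hunks above replace the 4-D (batch_size, num_heads, max_seqlen_q, max_seqlen_k) defaults with a 3-D per-token layout indexed by packed token, key/value head, and key position. The sketch below shows how a caller might build a padding mask in that layout from the cumulative sequence lengths; the helper name and the row-per-packed-token interpretation are illustrative assumptions, not part of this commit. A value of 1.0 means "attend", matching the all-ones default in the diff.

```python
import torch

def build_varlen_padding_mask(cu_seqlens_q, cu_seqlens_k, num_heads_k, max_seqlen_k, dtype, device):
    # Hypothetical helper: one row per packed query token, one column per key
    # position; positions past each sequence's key length stay masked (0.0).
    batch_size = cu_seqlens_q.numel() - 1
    total_q = int(cu_seqlens_q[-1])
    mask = torch.zeros((total_q, num_heads_k, max_seqlen_k), dtype=dtype, device=device)
    for b in range(batch_size):
        q_start, q_end = int(cu_seqlens_q[b]), int(cu_seqlens_q[b + 1])
        k_len = int(cu_seqlens_k[b + 1]) - int(cu_seqlens_k[b])
        mask[q_start:q_end, :, :k_len] = 1.0  # keep only valid key positions
    return mask
```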
@@ -1282,9 +1284,9 @@ def flash_dmattn_varlen_qkvpacked_func(
 
     Arguments:
         qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
-        attn_mask: (batch_size, nheads, seqlen_q, seqlen_k). Attention mask to apply to the attention scores.
+        attn_mask: (total, nheads, max_seqlen). Attention mask to apply to the attention scores.
             If None, no mask is applied.
-        attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores.
+        attn_bias: (total, nheads, max_seqlen). Attention Bias to add to the attention scores.
             If None, no bias is applied.
         cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
             of the sequences in the batch, used to index into qkv.
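For the packed-QKV entry point, a call with the new mask/bias shapes might look like the sketch below. The argument names follow the docstring above, but the exact call signature and the import path are assumed rather than verified against the library.

```python
import torch
from flash_dmattn import flash_dmattn_varlen_qkvpacked_func  # import path assumed

# Two packed sequences of lengths 2 and 4 -> total = 6 tokens.
total, nheads, headdim = 6, 8, 64
cu_seqlens = torch.tensor([0, 2, 6], dtype=torch.int32, device="cuda")
max_seqlen = 4

qkv = torch.randn(total, 3, nheads, headdim, dtype=torch.float16, device="cuda")
# After this commit, varlen mask/bias are 3-D: (total, nheads, max_seqlen).
attn_mask = torch.ones(total, nheads, max_seqlen, dtype=qkv.dtype, device=qkv.device)
attn_bias = torch.zeros(total, nheads, max_seqlen, dtype=qkv.dtype, device=qkv.device)

out = flash_dmattn_varlen_qkvpacked_func(
    qkv,
    attn_mask=attn_mask,
    attn_bias=attn_bias,
    cu_seqlens=cu_seqlens,
    max_seqlen=max_seqlen,
)
```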
@@ -1360,9 +1362,9 @@ def flash_dmattn_varlen_kvpacked_func(
     Arguments:
         q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
         kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
-        attn_mask: (batch_size, nheads, seqlen_q, seqlen_k). Attention mask to apply to the attention scores.
+        attn_mask: (total_q, nheads_k, max_seqlen_k). Attention mask to apply to the attention scores.
             If None, no mask is applied.
-        attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores.
+        attn_bias: (total_q, nheads_k, max_seqlen_k). Attention Bias to add to the attention scores.
             If None, no bias is applied.
         cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
             of the sequences in the batch, used to index into q.
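The kvpacked variant now sizes the default mask and bias by nheads_k (read from kv.shape in the forward hunk above), which matters under grouped-query attention where nheads_k < nheads. A minimal sketch, with the call signature and import path assumed:

```python
import torch
from flash_dmattn import flash_dmattn_varlen_kvpacked_func  # import path assumed

total_q, total_k = 6, 10
nheads, nheads_k, headdim = 16, 4, 64  # grouped KV heads: nheads_k < nheads
cu_seqlens_q = torch.tensor([0, 2, 6], dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0, 4, 10], dtype=torch.int32, device="cuda")
max_seqlen_q, max_seqlen_k = 4, 6

q = torch.randn(total_q, nheads, headdim, dtype=torch.float16, device="cuda")
kv = torch.randn(total_k, 2, nheads_k, headdim, dtype=torch.float16, device="cuda")
# Mask/bias follow the KV head count, not the query head count.
attn_mask = torch.ones(total_q, nheads_k, max_seqlen_k, dtype=q.dtype, device=q.device)
attn_bias = torch.zeros(total_q, nheads_k, max_seqlen_k, dtype=q.dtype, device=q.device)

out = flash_dmattn_varlen_kvpacked_func(
    q, kv,
    attn_mask=attn_mask,
    attn_bias=attn_bias,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=max_seqlen_q,
    max_seqlen_k=max_seqlen_k,
)
```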
@@ -1444,9 +1446,9 @@ def flash_dmattn_varlen_func(
         query: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
         key: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
         value: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
-        attn_mask: (batch_size, nheads, seqlen_q, seqlen_k). Attention mask to apply to the attention scores.
+        attn_mask: (total_q, nheads_k, max_seqlen_k). Attention mask to apply to the attention scores.
             If None, no mask is applied.
-        attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores.
+        attn_bias: (total_q, nheads_k, max_seqlen_k). Attention Bias to add to the attention scores.
             If None, no bias is applied.
         cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
             of the sequences in the batch, used to index into q.
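For the unpacked entry point, the simplest path is to leave attn_mask and attn_bias as None and let the forward pass allocate the defaults, which after this commit use the 3-D (total_q, nheads_k, max_seqlen_k) layout. A sketch under the same assumptions about the call signature and import path:

```python
import torch
from flash_dmattn import flash_dmattn_varlen_func  # import path assumed

total_q, total_k = 6, 10
nheads, nheads_k, headdim = 16, 4, 64
cu_seqlens_q = torch.tensor([0, 2, 6], dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0, 4, 10], dtype=torch.int32, device="cuda")

query = torch.randn(total_q, nheads, headdim, dtype=torch.float16, device="cuda")
key = torch.randn(total_k, nheads_k, headdim, dtype=torch.float16, device="cuda")
value = torch.randn(total_k, nheads_k, headdim, dtype=torch.float16, device="cuda")

out = flash_dmattn_varlen_func(
    query, key, value,
    attn_mask=None,   # default built internally as (total_q, nheads_k, max_seqlen_k) ones
    attn_bias=None,   # default built internally as (total_q, nheads_k, max_seqlen_k) zeros
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=4,
    max_seqlen_k=6,
)
```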
