@@ -511,12 +511,12 @@ def forward(
     ):
         # qkv is expected to be of shape (total, 3, num_heads, head_size)
         batch_size = cu_seqlens.numel() - 1
-        _, num_heads, _ = qkv.shape
+        total_tokens, num_heads, _ = qkv.shape
         is_grad = is_grad_enabled and qkv.requires_grad
         if mask is None:
-            mask = torch.ones((batch_size, num_heads, max_seqlen, max_seqlen), dtype=qkv.dtype, device=qkv.device)
+            mask = torch.ones((total_tokens, num_heads, max_seqlen), dtype=qkv.dtype, device=qkv.device)
         if bias is None:
-            bias = torch.zeros((batch_size, num_heads, max_seqlen, max_seqlen), dtype=qkv.dtype, device=qkv.device)
+            bias = torch.zeros((total_tokens, num_heads, max_seqlen), dtype=qkv.dtype, device=qkv.device)
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
         if is_causal is None:
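
Shape semantics of this change, as a minimal pure-PyTorch sketch (the sequence lengths and head counts below are invented for illustration; only the mask/bias shapes mirror the hunk above):

import torch

# Hypothetical packed batch: two sequences of lengths 3 and 5.
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)  # (batch_size + 1,)
max_seqlen = 5
num_heads, head_size = 4, 16

batch_size = cu_seqlens.numel() - 1   # 2
total_tokens = int(cu_seqlens[-1])    # 8, i.e. the sum of the sequence lengths

# Old default: one padded (max_seqlen x max_seqlen) score grid per batch element.
old_mask = torch.ones((batch_size, num_heads, max_seqlen, max_seqlen))
old_bias = torch.zeros((batch_size, num_heads, max_seqlen, max_seqlen))

# New default: one row of key scores per packed query token, no padded batch dim.
new_mask = torch.ones((total_tokens, num_heads, max_seqlen))
new_bias = torch.zeros((total_tokens, num_heads, max_seqlen))

print(old_mask.shape, new_mask.shape)  # torch.Size([2, 4, 5, 5]) torch.Size([8, 4, 5])
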
@@ -737,14 +737,15 @@ def forward(
         # q is expected to be of shape (total, num_heads, head_size)
         # kv is expected to be of shape (total, 2, num_heads, head_size)
         batch_size = cu_seqlens_q.numel() - 1
-        _, num_heads, _ = q.shape
+        total_q, num_heads, _ = q.shape
+        _, _, num_heads_k, _ = kv.shape
         is_grad = is_grad_enabled and any(
             x.requires_grad for x in [q, kv]
         )
         if mask is None:
-            mask = torch.ones((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device)
+            mask = torch.ones((total_q, num_heads_k, max_seqlen_k), dtype=q.dtype, device=q.device)
         if bias is None:
-            bias = torch.zeros((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device)
+            bias = torch.zeros((total_q, num_heads_k, max_seqlen_k), dtype=q.dtype, device=q.device)
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         if is_causal is None:
@@ -967,14 +968,15 @@ def forward(
     ):
         # q, k, v are expected to be of shape (total, num_heads, head_size)
         batch_size = cu_seqlens_q.numel() - 1
-        _, num_heads, _ = q.shape
+        total_q, num_heads, _ = q.shape
+        _, num_heads_k, _ = k.shape
         is_grad = is_grad_enabled and any(
             x.requires_grad for x in [q, k, v]
         )
         if mask is None:
-            mask = torch.ones((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device)
+            mask = torch.ones((total_q, num_heads_k, max_seqlen_k), dtype=q.dtype, device=q.device)
         if bias is None:
-            bias = torch.zeros((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device)
+            bias = torch.zeros((total_q, num_heads_k, max_seqlen_k), dtype=q.dtype, device=q.device)
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         if is_causal is None:
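
The kvpacked and unpacked paths above make the same switch, and additionally size the defaults by the key/value head count, which matters under MQA/GQA where nheads_k can be smaller than nheads. A small illustrative sketch (all sizes invented):

import torch

# Hypothetical varlen setup with GQA-style head counts.
cu_seqlens_q = torch.tensor([0, 2, 6], dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, 4, 9], dtype=torch.int32)
max_seqlen_q, max_seqlen_k = 4, 5
num_heads, num_heads_k, head_size = 8, 2, 16   # nheads_k < nheads under GQA

total_q = int(cu_seqlens_q[-1])                # 6 packed query tokens
total_k = int(cu_seqlens_k[-1])                # 9 packed key tokens

q = torch.randn(total_q, num_heads, head_size)
k = torch.randn(total_k, num_heads_k, head_size)
v = torch.randn(total_k, num_heads_k, head_size)

# Defaults now follow the key head count and the packed query dimension,
# matching the + lines above: (total_q, num_heads_k, max_seqlen_k).
mask = torch.ones((total_q, num_heads_k, max_seqlen_k), dtype=q.dtype)
bias = torch.zeros((total_q, num_heads_k, max_seqlen_k), dtype=q.dtype)
print(mask.shape)  # torch.Size([6, 2, 5])
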
@@ -1282,9 +1284,9 @@ def flash_dmattn_varlen_qkvpacked_func(
 
     Arguments:
         qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
-        attn_mask: (batch_size, nheads, seqlen_q, seqlen_k). Attention mask to apply to the attention scores.
+        attn_mask: (total, nheads, max_seqlen). Attention mask to apply to the attention scores.
             If None, no mask is applied.
-        attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores.
+        attn_bias: (total, nheads, max_seqlen). Attention Bias to add to the attention scores.
             If None, no bias is applied.
         cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
             of the sequences in the batch, used to index into qkv.
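
A hedged usage sketch of the updated qkvpacked interface. The import path, the max_seqlen keyword, and the fp16/CUDA choices are assumptions based on this docstring excerpt and the forward hunk earlier in the diff, not a verified signature:

import torch
from flash_dmattn import flash_dmattn_varlen_qkvpacked_func  # import path assumed

cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
max_seqlen, nheads, headdim = 5, 4, 64
total = int(cu_seqlens[-1])

qkv = torch.randn(total, 3, nheads, headdim, dtype=torch.float16, device="cuda")
# New mask/bias layout: (total, nheads, max_seqlen) instead of
# (batch_size, nheads, seqlen_q, seqlen_k).
attn_mask = torch.ones(total, nheads, max_seqlen, dtype=qkv.dtype, device="cuda")
attn_bias = torch.zeros(total, nheads, max_seqlen, dtype=qkv.dtype, device="cuda")

out = flash_dmattn_varlen_qkvpacked_func(
    qkv,
    attn_mask=attn_mask,
    attn_bias=attn_bias,
    cu_seqlens=cu_seqlens,
    max_seqlen=max_seqlen,  # keyword assumed from the forward hunk above
)
print(out.shape)  # expected (total, nheads, headdim)
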
@@ -1360,9 +1362,9 @@ def flash_dmattn_varlen_kvpacked_func(
     Arguments:
         q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
         kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
-        attn_mask: (batch_size, nheads, seqlen_q, seqlen_k). Attention mask to apply to the attention scores.
+        attn_mask: (total_q, nheads_k, max_seqlen_k). Attention mask to apply to the attention scores.
             If None, no mask is applied.
-        attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores.
+        attn_bias: (total_q, nheads_k, max_seqlen_k). Attention Bias to add to the attention scores.
             If None, no bias is applied.
         cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
             of the sequences in the batch, used to index into q.
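
Similarly, a sketch for the kvpacked variant; keyword names other than those quoted in the docstring are assumed:

import torch
from flash_dmattn import flash_dmattn_varlen_kvpacked_func  # import path assumed

cu_seqlens_q = torch.tensor([0, 2, 6], dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0, 4, 9], dtype=torch.int32, device="cuda")
max_seqlen_q, max_seqlen_k = 4, 5
nheads, nheads_k, headdim = 8, 2, 64
total_q, total_k = int(cu_seqlens_q[-1]), int(cu_seqlens_k[-1])

q = torch.randn(total_q, nheads, headdim, dtype=torch.float16, device="cuda")
kv = torch.randn(total_k, 2, nheads_k, headdim, dtype=torch.float16, device="cuda")
# Mask/bias are now indexed by packed query token and key head:
attn_mask = torch.ones(total_q, nheads_k, max_seqlen_k, dtype=q.dtype, device="cuda")
attn_bias = torch.zeros(total_q, nheads_k, max_seqlen_k, dtype=q.dtype, device="cuda")

out = flash_dmattn_varlen_kvpacked_func(
    q, kv,
    attn_mask=attn_mask,
    attn_bias=attn_bias,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=max_seqlen_q,  # keywords assumed from the forward hunks above
    max_seqlen_k=max_seqlen_k,
)
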
@@ -1444,9 +1446,9 @@ def flash_dmattn_varlen_func(
         query: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
         key: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
         value: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
-        attn_mask: (batch_size, nheads, seqlen_q, seqlen_k). Attention mask to apply to the attention scores.
+        attn_mask: (total_q, nheads_k, max_seqlen_k). Attention mask to apply to the attention scores.
             If None, no mask is applied.
-        attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores.
+        attn_bias: (total_q, nheads_k, max_seqlen_k). Attention Bias to add to the attention scores.
             If None, no bias is applied.
         cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
             of the sequences in the batch, used to index into q.
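
Finally, a sketch for the unpacked variant that also builds a key-validity mask in the new (total_q, nheads_k, max_seqlen_k) layout. The convention that the last dimension indexes key positions within the query token's own sequence (1 = attend, 0 = masked), and the keyword names, are assumptions inferred from the shapes in this diff:

import torch
from flash_dmattn import flash_dmattn_varlen_func  # import path assumed

cu_seqlens_q = torch.tensor([0, 2, 6], dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0, 4, 9], dtype=torch.int32, device="cuda")
max_seqlen_q, max_seqlen_k = 4, 5
nheads, nheads_k, headdim = 8, 2, 64
total_q, total_k = int(cu_seqlens_q[-1]), int(cu_seqlens_k[-1])

query = torch.randn(total_q, nheads, headdim, dtype=torch.float16, device="cuda")
key = torch.randn(total_k, nheads_k, headdim, dtype=torch.float16, device="cuda")
value = torch.randn(total_k, nheads_k, headdim, dtype=torch.float16, device="cuda")

# Mark key positions beyond each sequence's actual length as invalid.
seqlens_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).long()                # per-sequence key lengths
q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).long()                   # per-sequence query lengths
seqlen_k_per_q = torch.repeat_interleave(seqlens_k, q_lens)              # (total_q,)
key_pos = torch.arange(max_seqlen_k, device="cuda")
attn_mask = (key_pos[None, :] < seqlen_k_per_q[:, None]).to(query.dtype)  # (total_q, max_seqlen_k)
attn_mask = attn_mask[:, None, :].expand(total_q, nheads_k, max_seqlen_k).contiguous()

out = flash_dmattn_varlen_func(
    query, key, value,
    attn_mask=attn_mask,
    attn_bias=None,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=max_seqlen_q,  # keywords assumed from the forward hunks above
    max_seqlen_k=max_seqlen_k,
)
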