
Commit bddff0f

Makes attention parameters optional with defaults
Improves flexibility by making the attn_mask, attn_bias, is_causal, and scale parameters optional with sensible defaults. Creates a default all-ones attention mask and all-zeros attention bias when they are not provided, defaults causal attention to True, and derives the softmax scale from the head dimension when it is not specified. Adds None checks before the tensor slicing operations so the function no longer errors when the optional parameters are omitted.
1 parent 9fa7885 commit bddff0f

File tree: 1 file changed (+18 −5)

flash_dmattn/flash_dmattn_flex.py

@@ -1,4 +1,5 @@
 from typing import Optional, Tuple
+import math
 import torch
 from torch.nn.attention.flex_attention import create_block_mask
 from transformers.integrations.flex_attention import compile_friendly_flex_attention
@@ -8,17 +9,29 @@ def flex_attention_forward(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    attn_mask: torch.Tensor,
-    attn_bias: torch.Tensor,
-    is_causal: bool = True,
+    attn_mask: Optional[torch.Tensor] = None,
+    attn_bias: Optional[torch.Tensor] = None,
+    is_causal: Optional[bool] = None,
     scale: Optional[float] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    batch, seqlen_q, nheads, dhead = query.shape
+    _, seqlen_k, _, _ = key.shape
     query = query.transpose(1, 2).contiguous()  # [B, H, Q_LEN, D]
     key = key.transpose(1, 2).contiguous()  # [B, H, KV_LEN, D]
     value = value.transpose(1, 2).contiguous()  # [B, H, KV_LEN, D]
-    attn_mask = attn_mask[:, :, :, : key.shape[-2]]
-    attn_bias = attn_bias[:, :, :, : key.shape[-2]]
+    if attn_mask is not None:
+        attn_mask = attn_mask[:, :, :, : key.shape[-2]]
+    else:
+        attn_mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
+    if attn_bias is not None:
+        attn_bias = attn_bias[:, :, :, : key.shape[-2]]
+    else:
+        attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
+    if is_causal is None:
+        is_causal = True
+    if scale is None:
+        scale = 1.0 / math.sqrt(dhead)
 
     def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
         score = score + attn_bias[batch_idx][head_idx][q_idx][kv_idx]
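
For reference, a minimal usage sketch of the updated signature. This is not part of the commit: the import path, the [batch, seq_len, num_heads, head_dim] input layout, the placeholder device/dtype, and the names of the two returned tensors are assumptions read off the diff above.

import math
import torch
# Import path assumed from the file touched by this commit.
from flash_dmattn.flash_dmattn_flex import flex_attention_forward

# Hypothetical inputs in [batch, seq_len, num_heads, head_dim] layout,
# matching the `batch, seqlen_q, nheads, dhead = query.shape` unpacking above.
# Device and dtype are placeholders; adjust to your setup.
q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)

# After this commit the optional arguments can be omitted entirely:
# an all-ones mask and all-zeros bias are created, is_causal defaults to True,
# and scale falls back to 1 / sqrt(head_dim).
out, aux = flex_attention_forward(q, k, v)

# Explicit call with a custom bias; the function slices it to the KV length internally.
bias = torch.zeros(1, 8, 128, 128, device="cuda", dtype=torch.bfloat16)
out, aux = flex_attention_forward(q, k, v, attn_bias=bias, is_causal=True, scale=1.0 / math.sqrt(64))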
