Commit e02668c

Adds Flex flash-sparse attention hook
Introduces a Flex Attention forward path that constructs causal block masks, normalizes mask and bias defaults, and applies compile-friendly kernel options to ease sparse Flash workloads.
1 parent 71add00 commit e02668c

1 file changed: 80 additions, 0 deletions

@@ -0,0 +1,80 @@
from typing import Optional
import math
import torch
from torch.nn.attention.flex_attention import create_block_mask
from transformers.integrations.flex_attention import compile_friendly_flex_attention


def flex_attention_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    attn_bias: Optional[torch.Tensor] = None,
    is_causal: Optional[bool] = None,
    softmax_scale: Optional[float] = None,
    **kwargs,
) -> torch.Tensor:
    # Inputs arrive as [B, S, H, D]; FlexAttention expects [B, H, S, D].
    batch, seqlen_q, nheads, dhead = query.shape
    _, seqlen_k, _, _ = key.shape
    query = query.transpose(1, 2).contiguous()  # [B, H, Q_LEN, D]
    key = key.transpose(1, 2).contiguous()  # [B, H, KV_LEN, D]
    value = value.transpose(1, 2).contiguous()  # [B, H, KV_LEN, D]

    # Normalize mask/bias defaults: trim to the key length if provided, otherwise
    # fall back to an all-ones mask and an all-zeros (no-op) bias.
    if attn_mask is not None:
        attn_mask = attn_mask[:, :, :, : key.shape[-2]]
    else:
        attn_mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
    if attn_bias is not None:
        attn_bias = attn_bias[:, :, :, : key.shape[-2]]
    else:
        attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
    if is_causal is None:
        is_causal = True
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(dhead)

    def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
        # Add the (possibly all-zero) additive bias to each attention score.
        score = score + attn_bias[batch_idx][head_idx][q_idx][kv_idx]
        return score

    def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        # It looks like you're attempting to use a Tensor in some data-dependent control flow.
        # We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .
        # return q_idx >= kv_idx and attn_mask[batch_idx][head_idx][q_idx][kv_idx] > 0
        # Until then, the dense attn_mask above goes unused and only causality is enforced here.
        return q_idx >= kv_idx

    block_mask = create_block_mask(
        mask_mod=causal_mask_mod,
        B=query.shape[0],
        H=None,  # share the block mask across heads
        Q_LEN=query.shape[2],
        KV_LEN=key.shape[2],
        device=query.device,
        _compile=True,
    )

    kernel_options = {
        "BLOCK_M": 64,
        "BLOCK_N": 64,
        "BLOCK_DMODEL": 32,
        "num_stages": 1,
        "num_warps": 8,
    }
    attn_output = compile_friendly_flex_attention(
        query,
        key,
        value,
        score_mod=score_mod,
        block_mask=block_mask if is_causal else None,
        scale=softmax_scale,
        kernel_options=kernel_options,
        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse internally
        # regardless; we do not need it here, so we skip returning it.
        return_lse=False,
        training=False,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()  # back to [B, S, H, D]

    return attn_output


flex_sparse_attn_func = flex_attention_forward
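
As a rough usage sketch (not part of the commit): assuming the file is importable as flex_sparse_attention (a placeholder name, since the diff does not show the file path) and a CUDA device is available for the compiled FlexAttention kernels, the exported flex_sparse_attn_func can be smoke-tested against PyTorch's reference scaled_dot_product_attention. The shapes, dtype, and tolerances below are illustrative assumptions.

import torch
import torch.nn.functional as F

from flex_sparse_attention import flex_sparse_attn_func  # placeholder import path

# Small causal self-attention problem; head_dim matches the BLOCK_DMODEL=32 kernel option above.
batch, seq_len, num_heads, head_dim = 2, 128, 8, 32
q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Inputs and output both use the [batch, seq_len, num_heads, head_dim] layout.
out = flex_sparse_attn_func(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 128, 8, 32])

# Loose sanity check against the SDPA reference, which expects [batch, num_heads, seq_len, head_dim].
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)
print(torch.allclose(out, ref, atol=2e-2, rtol=2e-2))

The tolerance is deliberately loose to absorb fp16 accumulation differences between the FlexAttention kernel and the SDPA reference.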
