import math
from typing import Optional

import torch
from torch.nn.attention.flex_attention import create_block_mask
from transformers.integrations.flex_attention import compile_friendly_flex_attention


def flex_attention_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    attn_bias: Optional[torch.Tensor] = None,
    is_causal: Optional[bool] = None,
    softmax_scale: Optional[float] = None,
    **kwargs,
) -> torch.Tensor:
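    """Causal flex-attention wrapper around PyTorch's FlexAttention kernel.

    Expects `query`/`key`/`value` in [batch, seq_len, num_heads, head_dim] layout and
    returns the attention output in the same layout. An optional additive `attn_bias`
    is folded into the scores via `score_mod`; causal masking is applied through a
    block mask when `is_causal` is true (the default). The `attn_mask` argument is
    accepted but not applied inside the mask (see the note in `causal_mask_mod`).
    """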
    batch, seqlen_q, nheads, dhead = query.shape
    _, seqlen_k, _, _ = key.shape
    query = query.transpose(1, 2).contiguous()  # [B, H, Q_LEN, D]
    key = key.transpose(1, 2).contiguous()  # [B, H, KV_LEN, D]
    value = value.transpose(1, 2).contiguous()  # [B, H, KV_LEN, D]
    # Trim the provided mask / bias to the key length, or fall back to a dense
    # all-ones mask and a zero bias when none is given.
    if attn_mask is not None:
        attn_mask = attn_mask[:, :, :, : key.shape[-2]]
    else:
        attn_mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
    if attn_bias is not None:
        attn_bias = attn_bias[:, :, :, : key.shape[-2]]
    else:
        attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype)
    if is_causal is None:
        is_causal = True
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(dhead)

    def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
        # Add the (optional) additive attention bias to the raw attention score.
        return score + attn_bias[batch_idx][head_idx][q_idx][kv_idx]

    def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        # Combining the causal check with the dense `attn_mask` via Python `and` is
        # data-dependent control flow on a Tensor, which the compiled mask cannot
        # trace (see https://github.com/pytorch/functorch/issues/257):
        # return q_idx >= kv_idx and attn_mask[batch_idx][head_idx][q_idx][kv_idx] > 0
        return q_idx >= kv_idx

    block_mask = create_block_mask(
        mask_mod=causal_mask_mod,
        B=query.shape[0],
        H=None,  # None broadcasts the same mask across all heads
        Q_LEN=query.shape[2],
        KV_LEN=key.shape[2],
        device=query.device,
        _compile=True,
    )

    # Kernel tuning options forwarded to the underlying Triton flex-attention kernel.
    kernel_options = {
        "BLOCK_M": 64,
        "BLOCK_N": 64,
        "BLOCK_DMODEL": 32,
        "num_stages": 1,
        "num_warps": 8,
    }
    attn_output = compile_friendly_flex_attention(
        query,
        key,
        value,
        score_mod=score_mod,
        block_mask=block_mask if is_causal else None,
        scale=softmax_scale,
        kernel_options=kernel_options,
        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse
        # internally regardless; we simply do not return it here, so only the attention
        # output is produced.
        return_lse=False,
        training=False,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()  # back to [B, Q_LEN, H, D]

    return attn_output


flex_sparse_attn_func = flex_attention_forward
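

# A minimal usage sketch, assuming a CUDA device, fp16 inputs, and head_dim = 32
# (to match BLOCK_DMODEL above); shapes and values here are illustrative only.
# Inputs follow the [batch, seq_len, num_heads, head_dim] layout that
# flex_attention_forward expects.
if __name__ == "__main__":
    batch, seqlen, nheads, dhead = 2, 128, 8, 32
    q = torch.randn(batch, seqlen, nheads, dhead, device="cuda", dtype=torch.float16)
    k = torch.randn(batch, seqlen, nheads, dhead, device="cuda", dtype=torch.float16)
    v = torch.randn(batch, seqlen, nheads, dhead, device="cuda", dtype=torch.float16)
    out = flex_sparse_attn_func(q, k, v, is_causal=True)
    print(out.shape)  # expected: torch.Size([2, 128, 8, 32])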