
Commit 0402b39

Adds flash sparse attention wrapper
Supports future HF integration by routing calls through flash sparse attention logic and normalizing autocast, causal, and dtype handling
1 parent 186c725 commit 0402b39

File tree

1 file changed: +111 −0 lines changed

@@ -0,0 +1,111 @@
from typing import Optional

import torch

from .modeling_flash_sparse_attention_utils import _flash_sparse_attention_forward
from transformers.utils import logging


logger = logging.get_logger(__name__)


def flash_sparse_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    attention_bias: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    window_size: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    """
    A wrapper around the _flash_sparse_attention_forward function, to be used in
    the FlashSparseAttention class from Hugging Face Transformers.

    Args:
        module (torch.nn.Module): The attention module.
        query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim).
        key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
        value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
        attention_mask (Optional[torch.Tensor]): The boolean attention mask tensor of shape
            (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
        attention_bias (Optional[torch.Tensor]): The float attention bias tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
        scaling (Optional[float]): The scaling factor for the attention scores.
        window_size (Optional[int]): The size of the attention window to keep.
        softcap (Optional[float]): The softcap value for the attention scores.
        **kwargs: Additional keyword arguments. Includes:
            - is_causal (bool): Whether to apply a causal mask.
            - layer_idx (int): The index of the layer (for logging purposes).
            - implementation (str): The implementation to use ("flash_sparse_attn" or None).

    Returns:
        tuple[torch.Tensor, None]: The output tensor of shape (batch_size, seq_len, num_heads, head_dim)
            and None (for compatibility with other attention implementations).
    """

    if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
        logger.warning_once(
            "`flash_sparse_attention` does not support `output_attentions=True` or `head_mask`."
            " Please set your attention to `eager` if you want any of these features."
        )

    # Sequence lengths are read before the transpose below.
    query_len = query.shape[2]
    key_len = key.shape[2]

    if any(dim == 0 for dim in query.shape):
        raise ValueError(
            "Tensor query has shape with a zero dimension.\n"
            "FlashSparseAttention does not support inputs with dim=0.\n"
            "Please check your input shapes or use SDPA instead."
        )

    # FSA uses non-transposed inputs: (batch_size, seq_len, num_heads, head_dim).
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    # In PEFT, the layer norms are usually cast to float32 for training stability,
    # so the input hidden states get silently cast to float32. We therefore cast
    # them back to the correct dtype to make sure everything works as expected.
    # This can slow down training & inference, so it is recommended not to cast
    # the LayerNorms to fp32 (usually our RMSNorm modules handle it correctly).
    target_dtype = None
    if query.dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(module.config, "_pre_quantization_dtype"):
            target_dtype = module.config._pre_quantization_dtype
        else:
            target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype

    # Instead of relying on the value set on the module directly, use the
    # is_causal passed in kwargs if it is present.
    is_causal = kwargs.pop("is_causal", None)
    if is_causal is None:
        is_causal = module.is_causal

    attn_output = _flash_sparse_attention_forward(
        query,
        key,
        value,
        attention_mask,
        attention_bias,
        query_length=query_len,
        key_length=key_len,
        is_causal=is_causal,
        softmax_scale=scaling,
        softcap=softcap,
        window_size=window_size,
        target_dtype=target_dtype,
        implementation="flash_sparse_attn",
        layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
        **kwargs,
    )

    return attn_output, None
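
For context, the sketch below shows one way a wrapper with this signature could be hooked into the pluggable attention interface mentioned in the commit message. This is a minimal sketch, not part of the commit: it assumes a recent Transformers release that exposes AttentionInterface, and the import path, the registration name, the attention_bias=None adapter, and the checkpoint name are all illustrative assumptions.

import torch
from transformers import AttentionInterface, AutoModelForCausalLM

# Hypothetical import path for the wrapper added in this commit.
from flash_sparse_attention.modeling_flash_sparse_attention import flash_sparse_attention_forward


def fsa_attention_interface(module, query, key, value, attention_mask, **kwargs):
    # The stock attention interface does not pass `attention_bias`, so supply None explicitly.
    return flash_sparse_attention_forward(
        module, query, key, value, attention_mask, attention_bias=None, **kwargs
    )


# Register under a custom name, then select it like any built-in implementation.
AttentionInterface.register("flash_sparse_attn", fsa_attention_interface)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # placeholder checkpoint
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_sparse_attn",
)

Once registered this way, each attention layer would hand the wrapper its module and the (batch_size, num_heads, seq_len, head_dim) tensors, and the wrapper takes care of the transpose, autocast/dtype, and causal normalization shown above.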
