
Commit 11a0862

Adds dynamic mask helpers
Introduces mask utilities for top-k and ReLU masking to support flash sparse attention. Enables optional block smoothing to stabilize dynamic sparsity patterns.
1 parent 6bb896f commit 11a0862

File tree

1 file changed: +240 -0

flash_sparse_attn/utils/mask.py

Lines changed: 240 additions & 0 deletions
@@ -0,0 +1,240 @@
# Copyright 2025 Jingze Shi and Liangdong Wang. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch


def topk_indices(
    attention_bias: torch.Tensor,
    window_size: int,
    **kwargs,
) -> torch.Tensor:
    r"""
    This function generates top-k indices based on the attention bias.

    Args:
        attention_bias (torch.Tensor): The attention bias tensor of shape
            (batch_size, num_kv_heads, key_len).
        window_size (int): The number of top elements to consider for the mask.
        **kwargs: Additional keyword arguments.

    Returns:
        topk_indices (Tensor): The top-k indices tensor of shape
            (batch_size, num_kv_heads, window_size).
    """
    attention_bias = attention_bias.detach()
    topk_indices = torch.topk(
        attention_bias,
        window_size, dim=-1, largest=True, sorted=False
    ).indices
    topk_indices = torch.sort(topk_indices, dim=-1).values
    return topk_indices
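
# Example (illustrative): for an `attention_bias` of shape (2, 8, 1024) and
# window_size=128, the result has shape (2, 8, 128) and holds, for each
# batch/head, the key positions with the largest bias values, sorted ascending.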


def block_smooth(
    attention_mask: torch.Tensor,
    key_len: int,
    block_size: int,
):
    if block_size <= 0:
        raise ValueError(f"block_size must be a positive integer, got {block_size}.")

    if block_size > 1:
        full_len = (key_len // block_size) * block_size

        if full_len:
            # Majority vote within each full block of the key dimension:
            # keep the whole block only if more than half of its positions are kept.
            block_view = attention_mask[..., :full_len]
            block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
            blocks = block_view.view(*block_shape)
            block_counts = blocks.sum(dim=-1).to(torch.int64)
            block_keep = (block_counts * 2) > block_size
            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))

        if key_len > full_len:
            # Majority vote over the trailing partial block, if any.
            tail_slice = attention_mask[..., full_len:]
            tail_len = tail_slice.shape[-1]
            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int64)
            tail_keep = (tail_counts * 2) > tail_len
            tail_slice.copy_(tail_keep.expand_as(tail_slice))

    return attention_mask
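
# Example (illustrative): with block_size=4, a mask row
#     [1, 1, 1, 0,  0, 1, 0, 0,  1, 1]
# becomes
#     [1, 1, 1, 1,  0, 0, 0, 0,  1, 1]
# since the first block keeps 3 of 4 positions (a majority), the second keeps
# only 1 of 4, and the 2-element tail keeps 2 of 2.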


def topk_mask(
    attention_bias: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    window_size: int,
    min_dtype: float,
    block_size: Optional[int] = None,
    **kwargs,
):
    r"""
    This function generates a dynamic mask based on the top-k attention bias.

    Args:
        attention_bias (torch.Tensor): The attention bias tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
        window_size (int): The number of top elements to consider for the mask.
        min_dtype (float): The minimum value to use for masking.
        block_size (Optional[int]): Optional size of aggregation blocks to smooth the
            resulting mask along the key dimension.

    Returns:
        attention_mask (Tensor): The attention mask tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
    """

    attention_bias = attention_bias.detach()
    attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
    topk_values, topk_indices = torch.topk(
        attention_bias,
        window_size, dim=-1, largest=True, sorted=False
    )
    attention_mask = torch.zeros_like(
        attention_bias, dtype=torch.bool, device=attention_bias.device
    ).scatter_(-1, topk_indices, topk_values != min_dtype)

    if block_size is not None and block_size > 1:
        key_len = attention_mask.shape[-1]
        attention_mask = block_smooth(
            attention_mask=attention_mask,
            key_len=key_len,
            block_size=block_size
        )

    return attention_mask
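
# Example (illustrative): for an `attention_bias` of shape (2, 8, 1, 1024) and
# window_size=128, each (batch, head, query) row keeps at most 128 key
# positions; top-k positions whose bias equals `min_dtype` (already padded
# out) are dropped again via the `topk_values != min_dtype` scatter source.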


def relu_mask(
    attention_bias: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    min_dtype: float,
    block_size: Optional[int] = None,
    **kwargs
):
    r"""
    This function generates a dynamic mask based on the ReLU of the attention bias.

    Args:
        attention_bias (torch.Tensor): The attention bias tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
        min_dtype (float): The minimum value to use for masking.
        block_size (Optional[int]): Optional size of aggregation blocks to smooth the
            resulting mask along the key dimension.

    Returns:
        attention_mask (Tensor): The attention mask tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
    """

    attention_bias = attention_bias.detach()
    attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
    attention_mask = attention_bias > 0

    if block_size is not None and block_size > 1:
        key_len = attention_mask.shape[-1]
        attention_mask = block_smooth(
            attention_mask=attention_mask,
            key_len=key_len,
            block_size=block_size
        )

    return attention_mask
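
# Example (illustrative): a bias row [0.3, -1.2, 0.0, 2.1] yields the mask
# [True, False, False, True]; only strictly positive (unpadded) bias values
# are kept, so the number of active keys adapts to the bias itself rather
# than to a fixed window size.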


def create_mask(
    attention_bias: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    batch_size: int,
    query_len: int,
    key_len: int,
    window_size: int,
    min_dtype: float,
    block_size: Optional[int] = None,
    type: str = "topk",
) -> torch.Tensor:
    r"""
    This function creates a mask tensor for Flash Sparse Attention.

    If attention_mask is not of shape (batch_size, seq_len), it must match the shape of attention_bias.

    Args:
        attention_bias (torch.Tensor): The attention bias tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
            (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
        batch_size (int): The batch size.
        query_len (int): The sequence length of the query.
        key_len (int): The sequence length of the key.
        window_size (int): The number of top elements to consider for the attention mask.
        min_dtype (float): The minimum value to use for masking.
        block_size (Optional[int]): Optional size of aggregation blocks after top-k or ReLU masking.
        type (str): The type of mask to create. Options are "topk" and "relu".

    Returns:
        attention_mask (Tensor): The attention mask tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
    """

    # If attention_mask is of shape (batch_size, seq_len), reshape it to (batch_size, 1, 1, key_len)
    if attention_mask is not None and attention_mask.dim() == 2:
        if attention_mask.shape[-1] == key_len:
            attention_mask = attention_mask.view(batch_size, 1, 1, key_len)
        elif attention_mask.shape[-1] == query_len:
            pad_len = key_len - query_len
            if pad_len > 0:
                # Left-pad the key dimension with True so cached key positions stay visible.
                pad_mask = torch.ones(
                    (batch_size, 1, 1, pad_len),
                    dtype=torch.bool,
                    device=attention_mask.device,
                )
                attention_mask = torch.cat(
                    [pad_mask, attention_mask.view(batch_size, 1, 1, query_len)],
                    dim=-1,
                )
            else:
                attention_mask = attention_mask.view(batch_size, 1, 1, query_len)
        else:
            raise ValueError(
                f"attention_mask shape {attention_mask.shape} is not compatible with key_len {key_len} or query_len {query_len}."
            )

    # Generate dynamic mask based on attention_bias and attention_mask
    if type == "topk":
        attention_mask = topk_mask(
            attention_bias=attention_bias,
            attention_mask=attention_mask,
            window_size=window_size,
            min_dtype=min_dtype,
            block_size=block_size,
        )
    elif type == "relu":
        # window_size is unused by relu_mask; it is absorbed by its **kwargs.
        attention_mask = relu_mask(
            attention_bias=attention_bias,
            attention_mask=attention_mask,
            window_size=window_size,
            min_dtype=min_dtype,
            block_size=block_size,
        )
    else:
        raise ValueError(f"Unsupported mask type: {type}. Supported types are 'topk' and 'relu'.")

    return attention_mask
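
For context, a minimal usage sketch of the new helpers (shapes, window_size, and block_size are illustrative, not taken from the repository), assuming a float32 bias and a 2-D padding mask over the current query tokens:

import torch

from flash_sparse_attn.utils.mask import create_mask

batch_size, num_kv_heads, query_len, key_len = 2, 4, 1, 1024

# Hypothetical learned bias over key positions, broadcast over queries.
attention_bias = torch.randn(batch_size, num_kv_heads, 1, key_len)
# (batch_size, query_len) padding mask; create_mask left-pads it with True
# up to key_len so cached key positions remain visible.
padding_mask = torch.ones(batch_size, query_len, dtype=torch.bool)

sparse_mask = create_mask(
    attention_bias=attention_bias,
    attention_mask=padding_mask,
    batch_size=batch_size,
    query_len=query_len,
    key_len=key_len,
    window_size=128,
    min_dtype=torch.finfo(attention_bias.dtype).min,
    block_size=64,
    type="topk",
)
# sparse_mask: bool tensor of shape (2, 4, 1, 1024); at most 128 keys are kept
# per row before the 64-wide block smoothing is applied.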
