
Commit 7eb71dc

Adds flash dynamic mask attention utilities module
Implements utilities for flash dynamic mask attention (FDMA) operations, including tensor padding/unpadding functions, input preprocessing, and the attention computation workflow. Provides FDMA-compatible helpers for handling variable-length sequences with attention masks, and supports both the regular and the variable-length flash attention variants. Also includes a PEFT integration check for dtype compatibility and a lazy import mechanism for flexible implementation selection.
1 parent 6e68b61 commit 7eb71dc

1 file changed: 301 additions, 0 deletions
@@ -0,0 +1,301 @@
# Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, TypedDict
import torch
from torch.nn import functional as F
from .import_utils import is_flash_dmattn_available

from transformers.utils import logging


logger = logging.get_logger(__name__)


def _index_first_axis(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Flatten the leading (batch, seqlen) dimensions and gather the selected token rows.
    reshaped = tensor.contiguous().reshape(-1, *tensor.shape[2:])
    return reshaped[indices]
def _fdma_unpad_input(hidden_states, attention_mask, unused_mask=None):
    """
    FDMA-compatible unpad_input function.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
        seqused: (batch), the number of tokens selected in attention_mask + unused_mask.
    """
    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

    return (
        _index_first_axis(hidden_states, indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        used_seqlens_in_batch,
    )


def _fdma_pad_input(hidden_states, indices, batch, seqlen):
    """
    FDMA-compatible pad_input function.

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    dim = hidden_states.shape[1:]
    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
    output[indices] = hidden_states
    return output.view(batch, seqlen, *dim)
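# Illustrative round trip for the two helpers above (example values, not executed here):
#     hidden = torch.randn(2, 4, 8)                      # (batch, seqlen, hidden)
#     mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])  # 1 = valid token
# _fdma_unpad_input(hidden, mask) returns the 6 valid token rows as a (6, 8) tensor,
# their flat indices [0, 1, 2, 3, 4, 5], cu_seqlens [0, 4, 6], max_seqlen 4 and seqused [4, 2];
# _fdma_pad_input(unpadded, indices, batch=2, seqlen=4) scatters those rows back into a
# zero-initialised (2, 4, 8) tensor, restoring the original padded layout.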
def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
    """
    Retrieves indexing data required to repad unpadded (ragged) tensors.

    Arguments:
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
    Return:
        indices (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input sequence.
        cu_seqlens (`torch.Tensor`):
            The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        max_seqlen_in_batch (`int`):
            Maximum sequence length in batch.
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile,
    # this might cause a graph break
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
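# Worked example (illustrative): for attention_mask = [[1, 1, 1, 0],
#                                                      [1, 1, 1, 1]]
# seqlens_in_batch    -> [3, 4]
# indices             -> [0, 1, 2, 4, 5, 6, 7]   (flat positions of the 1s)
# max_seqlen_in_batch -> 4
# cu_seqlens          -> [0, 3, 7]               (cumulative sum, left-padded with a leading 0)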
def _upad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    bias_layer: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
    unpad_input_func,
):
    """
    Unpads query, key, and value tensors, using a single dimension for all tokens even though they belong to different batches.

    This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same
    intermediary tensors for the query, key, and value tensors.

    Arguments:
        query_layer (`torch.Tensor`):
            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
        key_layer (`torch.Tensor`):
            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        value_layer (`torch.Tensor`):
            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        bias_layer (`torch.Tensor`):
            Attention bias tensor of shape (batch_size, num_key_value_heads, query_length, kv_seq_len).
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
        query_length (`int`):
            Target length.
        unpad_input_func:
            The function to use for unpadding the input tensors.
    Return:
        query_layer (`torch.Tensor`):
            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
        key_layer (`torch.Tensor`):
            Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        value_layer (`torch.Tensor`):
            Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        bias_layer (`torch.Tensor`):
            Attention bias tensor without padding. Shape: (total_target_length, num_key_value_heads, max_seqlen_in_batch_k).
        indices_q (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input target sequence.
        (cu_seqlens_q, cu_seqlens_k) (`tuple[torch.Tensor]`):
            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

    # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage.
    # It's a bit of an anti-pattern, but otherwise we silently compute wrong attention scores.
    if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]):
        key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :]

    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
    key_lens_per_batch = attention_mask.sum(-1)

    key_layer = _index_first_axis(key_layer, indices_k)
    value_layer = _index_first_axis(value_layer, indices_k)

    if query_length == kv_seq_len:
        query_layer = _index_first_axis(query_layer, indices_k)
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k

        query_mask = attention_mask
        bias_view = bias_layer
    elif query_length == 1:
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1, dtype=torch.int32, device=query_layer.device
        )  # There is a memcpy here, that is very bad.
        indices_q = cu_seqlens_q[:-1]
        query_layer = query_layer.squeeze(1)

        query_mask = torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device)
        bias_view = bias_layer[:, :, :1, :]
    else:
        # The -q_len: slice assumes left padding.
        attention_mask = attention_mask[:, -query_length:]
        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask)

        query_mask = attention_mask[:, -query_length:]
        bias_view = bias_layer[:, :, -query_length:, :]

    # Gather the bias rows that correspond to the kept query tokens, then zero out the key columns
    # that fall beyond each row's valid key length.
    b_idx_q, pos_in_q = torch.nonzero(query_mask, as_tuple=True)
    bias_layer = bias_view[b_idx_q, :, pos_in_q, :]
    row_key_lens = key_lens_per_batch[b_idx_q]
    bias_layer = bias_layer[:, :, :max_seqlen_in_batch_k]
    col_idx = torch.arange(max_seqlen_in_batch_k, device=query_layer.device).view(1, 1, max_seqlen_in_batch_k)
    valid_cols = col_idx < row_key_lens.view(-1, 1, 1)
    bias_layer = bias_layer * valid_cols

    return (
        query_layer,
        key_layer,
        value_layer,
        bias_layer,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )
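# Shape sketch (illustrative, query_length == kv_seq_len path): with
#     query_layer (2, 4, 8, 16), key_layer / value_layer (2, 4, 2, 16),
#     bias_layer  (2, 2, 4, 4),  attention_mask [[1, 1, 1, 1], [1, 1, 1, 0]],
# the 7 valid tokens are flattened and the function returns
#     query_layer (7, 8, 16), key_layer / value_layer (7, 2, 16),
#     bias_layer  (7, 2, 4), i.e. (total_target_length, num_key_value_heads, max_seqlen_in_batch_k),
#     cu_seqlens_q == cu_seqlens_k == [0, 4, 7], max_seqlen_q == max_seqlen_k == 4.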
def fdma_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = None):
    if target_dtype and q.dtype == torch.float32:
        logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.")
        q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)
    return q, k, v


def _lazy_imports(impl: Optional[str]):
    # returns funcs and pad/unpad based on impl
    is_fdma = is_flash_dmattn_available()

    if impl == "flash_dmattn" or (impl is None and is_fdma):
        pad_input, unpad_input = _fdma_pad_input, _fdma_unpad_input
        from flash_dmattn import flash_dmattn_func, flash_dmattn_varlen_func
        return flash_dmattn_func, flash_dmattn_varlen_func, pad_input, unpad_input

    else:
        pad_input, unpad_input = _fdma_pad_input, _fdma_unpad_input
        return (
            getattr(impl, "flash_dmattn_func", None),
            getattr(impl, "flash_dmattn_varlen_func", None),
            pad_input,
            unpad_input,
        )
class FlashDynamicMaskAttentionKwargs(TypedDict, total=False):
    cumulative_seqlens_q: Optional[torch.LongTensor]
    cumulative_seqlens_k: Optional[torch.LongTensor]


def _flash_dynamic_mask_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    attention_bias: Optional[torch.Tensor],
    query_length: int,
    is_causal: bool,
    softmax_scale: Optional[float] = None,
    softcap: Optional[float] = None,
    deterministic: Optional[bool] = None,
    target_dtype: Optional[torch.dtype] = None,
    implementation: Optional[str] = None,
    **kwargs,
):
    # Resolve the flash/pad/unpad callables once and cache them at module level, so the lazy
    # import only runs on the first call.
    if not all(k in globals() for k in ("_flash_fn", "_flash_varlen_fn", "_pad_fn", "_unpad_fn")):
        flash_fn, flash_varlen_fn, pad_fn, unpad_fn = _lazy_imports(implementation)
        globals()["_flash_fn"] = flash_fn
        globals()["_flash_varlen_fn"] = flash_varlen_fn
        globals()["_pad_fn"] = pad_fn
        globals()["_unpad_fn"] = unpad_fn
    else:
        flash_fn = globals()["_flash_fn"]
        flash_varlen_fn = globals()["_flash_varlen_fn"]
        pad_fn = globals()["_pad_fn"]
        unpad_fn = globals()["_unpad_fn"]

    is_causal = is_causal and not query_length == 1
    flash_kwargs = {}
    if deterministic is not None:
        flash_kwargs["deterministic"] = deterministic
    if softcap is not None:
        flash_kwargs["softcap"] = softcap
    query_states, key_states, value_states = fdma_peft_integration_check(
        query_states, key_states, value_states, target_dtype
    )

    if attention_mask is not None:
        q, k, v, bias, idx, (cu_q, cu_k), (mq, mk) = _upad_input(
            query_states, key_states, value_states, attention_bias, attention_mask, query_length, _fdma_unpad_input
        )
        if "mps" in str(q.device):
            cu_k = cu_k.clone()
        out_unpad = flash_varlen_fn(
            query=q,
            key=k,
            value=v,
            attn_bias=bias,
            cu_seqlens_q=cu_q.to(torch.int32),
            cu_seqlens_k=cu_k.to(torch.int32),
            max_seqlen_q=mq,
            max_seqlen_k=mk,
            scale=softmax_scale,
            is_causal=is_causal,
        )
        if isinstance(out_unpad, tuple):
            out_unpad = out_unpad[0]
        out = _fdma_pad_input(out_unpad, idx, query_states.shape[0], query_length)
    else:
        out = flash_fn(
            query_states, key_states, value_states, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal
        )

    return out[0] if isinstance(out, tuple) else out
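A minimal call-site sketch for the new `_flash_dynamic_mask_attention_forward` helper (illustrative shapes and values; it assumes the flash_dmattn package is installed and a CUDA device is available, neither of which is shown in this commit):

import torch

batch, q_len, kv_len, n_heads, n_kv_heads, head_dim = 2, 8, 8, 8, 4, 64
device, dtype = "cuda", torch.bfloat16

query = torch.randn(batch, q_len, n_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(batch, kv_len, n_kv_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(batch, kv_len, n_kv_heads, head_dim, device=device, dtype=dtype)
bias = torch.zeros(batch, n_kv_heads, q_len, kv_len, device=device, dtype=dtype)
mask = torch.ones(batch, kv_len, dtype=torch.long, device=device)
mask[1, -2:] = 0  # right-pad the second sequence by two tokens

out = _flash_dynamic_mask_attention_forward(
    query, key, value,
    attention_mask=mask,  # a non-None mask selects the variable-length (unpadded) path
    attention_bias=bias,
    query_length=q_len,
    is_causal=True,
    softmax_scale=head_dim**-0.5,
)
# out: (batch, q_len, n_heads, head_dim), repadded to the original layout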

0 commit comments
