|
| 1 | +# Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from typing import Optional, TypedDict |
| 16 | +import torch |
| 17 | +from torch.nn import functional as F |
| 18 | +from .import_utils import is_flash_dmattn_available |
| 19 | + |
| 20 | +from transformers.utils import logging |
| 21 | + |
| 22 | + |
| 23 | +logger = logging.get_logger(__name__) |
| 24 | + |
| 25 | + |
| 26 | +def _index_first_axis(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: |
| 27 | + reshaped = tensor.contiguous().reshape(-1, *tensor.shape[2:]) |
| 28 | + return reshaped[indices] |
| 29 | + |
| 30 | + |
| 31 | +def _fdma_unpad_input(hidden_states, attention_mask, unused_mask=None): |
| 32 | + """ |
| 33 | + FDMA-compatible unpad_input function. |
| 34 | + Arguments: |
| 35 | + hidden_states: (batch, seqlen, ...) |
| 36 | + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. |
| 37 | + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. |
| 38 | + Return: |
| 39 | + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. |
| 40 | + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. |
| 41 | + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. |
| 42 | + max_seqlen_in_batch: int |
| 43 | + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. |
| 44 | + """ |
| 45 | + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask |
| 46 | + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) |
| 47 | + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
| 48 | + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() |
| 49 | + max_seqlen_in_batch = seqlens_in_batch.max().item() |
| 50 | + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) |
| 51 | + |
| 52 | + return ( |
| 53 | + _index_first_axis(hidden_states, indices), |
| 54 | + indices, |
| 55 | + cu_seqlens, |
| 56 | + max_seqlen_in_batch, |
| 57 | + used_seqlens_in_batch, |
| 58 | + ) |
| 59 | + |
| 60 | + |
| 61 | +def _fdma_pad_input(hidden_states, indices, batch, seqlen): |
| 62 | + """ |
| 63 | + FDMA-compatible pad_input function. |
| 64 | + Arguments: |
| 65 | + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. |
| 66 | + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. |
| 67 | + batch: int, batch size for the padded sequence. |
| 68 | + seqlen: int, maximum sequence length for the padded sequence. |
| 69 | + Return: |
| 70 | + hidden_states: (batch, seqlen, ...) |
| 71 | + """ |
| 72 | + dim = hidden_states.shape[1:] |
| 73 | + output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) |
| 74 | + output[indices] = hidden_states |
| 75 | + return output.view(batch, seqlen, *dim) |
| 76 | + |
| 77 | + |
| 78 | +def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: |
| 79 | + """ |
| 80 | + Retrieves indexing data required to repad unpadded (ragged) tensors. |
| 81 | + Arguments: |
| 82 | + attention_mask (`torch.Tensor`): |
| 83 | + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. |
| 84 | + Return: |
| 85 | + indices (`torch.Tensor`): |
| 86 | + The indices of non-masked tokens from the flattened input sequence. |
| 87 | + cu_seqlens (`torch.Tensor`): |
| 88 | + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). |
| 89 | + max_seqlen_in_batch (`int`): |
| 90 | + Maximum sequence length in batch. |
| 91 | + """ |
| 92 | + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
| 93 | + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() |
| 94 | + # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile, |
| 95 | + # this might cause a graph break |
| 96 | + max_seqlen_in_batch = seqlens_in_batch.max().item() |
| 97 | + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) |
| 98 | + return ( |
| 99 | + indices, |
| 100 | + cu_seqlens, |
| 101 | + max_seqlen_in_batch, |
| 102 | + ) |
| 103 | + |
| 104 | + |
| 105 | +def _upad_input( |
| 106 | + query_layer: torch.Tensor, |
| 107 | + key_layer: torch.Tensor, |
| 108 | + value_layer: torch.Tensor, |
| 109 | + bias_layer: torch.Tensor, |
| 110 | + attention_mask: torch.Tensor, |
| 111 | + query_length: int, |
| 112 | + unpad_input_func, |
| 113 | +): |
| 114 | + """ |
| 115 | + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. |
| 116 | + This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary |
| 117 | + tensors for query, key, value tensors. |
| 118 | + Arguments: |
| 119 | + query_layer (`torch.Tensor`): |
| 120 | + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). |
| 121 | + key_layer (`torch.Tensor`): |
| 122 | + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). |
| 123 | + value_layer (`torch.Tensor`): |
| 124 | + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). |
| 125 | + bias_layer (`torch.Tensor`): |
| 126 | + Attention bias tensor of shape (batch_size, num_key_value_heads, query_length, kv_seq_len). |
| 127 | + attention_mask (`torch.Tensor`): |
| 128 | + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. |
| 129 | + query_length (`int`): |
| 130 | + Target length. |
| 131 | + unpad_input_func: |
| 132 | + The function to use for unpadding the input tensors. |
| 133 | + Return: |
| 134 | + query_layer (`torch.Tensor`): |
| 135 | + Query state without padding. Shape: (total_target_length, num_heads, head_dim). |
| 136 | + key_layer (`torch.Tensor`): |
| 137 | + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). |
| 138 | + value_layer (`torch.Tensor`): |
| 139 | + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). |
| 140 | + bias_layer (`torch.Tensor`): |
| 141 | + Attention bias tensor without padding. Shape: (total_target_length, num_key_value_heads, query_length, kv_seq_len). |
| 142 | + indices_q (`torch.Tensor`): |
| 143 | + The indices of non-masked tokens from the flattened input target sequence. |
| 144 | + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): |
| 145 | + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). |
| 146 | + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): |
| 147 | + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). |
| 148 | + """ |
| 149 | + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) |
| 150 | + |
| 151 | + # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage |
| 152 | + # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores |
| 153 | + if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]): |
| 154 | + key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :] |
| 155 | + |
| 156 | + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape |
| 157 | + key_lens_per_batch = attention_mask.sum(-1) |
| 158 | + |
| 159 | + key_layer = _index_first_axis(key_layer, indices_k) |
| 160 | + value_layer = _index_first_axis(value_layer, indices_k) |
| 161 | + |
| 162 | + if query_length == kv_seq_len: |
| 163 | + query_layer = _index_first_axis(query_layer, indices_k) |
| 164 | + cu_seqlens_q = cu_seqlens_k |
| 165 | + max_seqlen_in_batch_q = max_seqlen_in_batch_k |
| 166 | + indices_q = indices_k |
| 167 | + |
| 168 | + query_mask = attention_mask |
| 169 | + bias_view = bias_layer |
| 170 | + elif query_length == 1: |
| 171 | + max_seqlen_in_batch_q = 1 |
| 172 | + cu_seqlens_q = torch.arange( |
| 173 | + batch_size + 1, dtype=torch.int32, device=query_layer.device |
| 174 | + ) # There is a memcpy here, that is very bad. |
| 175 | + indices_q = cu_seqlens_q[:-1] |
| 176 | + query_layer = query_layer.squeeze(1) |
| 177 | + |
| 178 | + query_mask = torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device) |
| 179 | + bias_view = bias_layer[:, :, :1, :] |
| 180 | + else: |
| 181 | + # The -q_len: slice assumes left padding. |
| 182 | + attention_mask = attention_mask[:, -query_length:] |
| 183 | + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask) |
| 184 | + |
| 185 | + query_mask = attention_mask[:, -query_length:] |
| 186 | + bias_view = bias_layer[:, :, -query_length:, :] |
| 187 | + |
| 188 | + b_idx_q, pos_in_q = torch.nonzero(query_mask, as_tuple=True) |
| 189 | + bias_layer = bias_view[b_idx_q, :, pos_in_q, :] |
| 190 | + row_key_lens = key_lens_per_batch[b_idx_q] |
| 191 | + bias_layer = bias_layer[:, :, :max_seqlen_in_batch_k] |
| 192 | + col_idx = torch.arange(max_seqlen_in_batch_k, device=query_layer.device).view(1, 1, max_seqlen_in_batch_k) |
| 193 | + valid_cols = col_idx < row_key_lens.view(-1, 1, 1) |
| 194 | + bias_layer = bias_layer * valid_cols |
| 195 | + |
| 196 | + return ( |
| 197 | + query_layer, |
| 198 | + key_layer, |
| 199 | + value_layer, |
| 200 | + bias_layer, |
| 201 | + indices_q, |
| 202 | + (cu_seqlens_q, cu_seqlens_k), |
| 203 | + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), |
| 204 | + ) |
| 205 | + |
| 206 | + |
| 207 | +def fdma_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = None): |
| 208 | + if target_dtype and q.dtype == torch.float32: |
| 209 | + logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.") |
| 210 | + q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype) |
| 211 | + return q, k, v |
| 212 | + |
| 213 | + |
| 214 | +def _lazy_imports(impl: Optional[str]): |
| 215 | + # returns funcs and pad/unpad based on impl |
| 216 | + is_fdma = is_flash_dmattn_available() |
| 217 | + |
| 218 | + if impl == "flash_dmattn" or (impl is None and is_fdma): |
| 219 | + pad_input, unpad_input = _fdma_pad_input, _fdma_unpad_input |
| 220 | + from flash_dmattn import flash_dmattn_func, flash_dmattn_varlen_func |
| 221 | + return flash_dmattn_func, flash_dmattn_varlen_func, pad_input, unpad_input |
| 222 | + |
| 223 | + else: |
| 224 | + pad_input, unpad_input = _fdma_pad_input, _fdma_unpad_input |
| 225 | + return ( |
| 226 | + getattr(impl, "flash_dmattn_func", None), |
| 227 | + getattr(impl, "flash_dmattn_varlen_func", None), |
| 228 | + pad_input, |
| 229 | + unpad_input, |
| 230 | + ) |
| 231 | + |
| 232 | + |
| 233 | +class FlashDynamicMaskAttentionKwargs(TypedDict, total=False): |
| 234 | + cumulative_seqlens_q: Optional[torch.LongTensor] |
| 235 | + cumulative_seqlens_k: Optional[torch.LongTensor] |
| 236 | + |
| 237 | + |
| 238 | +def _flash_dynamic_mask_attention_forward( |
| 239 | + query_states: torch.Tensor, |
| 240 | + key_states: torch.Tensor, |
| 241 | + value_states: torch.Tensor, |
| 242 | + attention_mask: Optional[torch.Tensor], |
| 243 | + attention_bias: Optional[torch.Tensor], |
| 244 | + query_length: int, |
| 245 | + is_causal: bool, |
| 246 | + softmax_scale: Optional[float] = None, |
| 247 | + softcap: Optional[float] = None, |
| 248 | + deterministic: Optional[bool] = None, |
| 249 | + target_dtype: Optional[torch.dtype] = None, |
| 250 | + implementation: Optional[str] = None, |
| 251 | + **kwargs, |
| 252 | +): |
| 253 | + |
| 254 | + if not all(k in globals() for k in ("_flash_fn", "_flash_varlen_fn", "_pad_fn", "_unpad_fn")): |
| 255 | + flash_fn, flash_varlen_fn, pad_fn, unpad_fn = _lazy_imports(implementation) |
| 256 | + globals()["_flash_fn"] = flash_fn |
| 257 | + globals()["_flash_varlen_fn"] = flash_varlen_fn |
| 258 | + globals()["_pad_fn"] = pad_fn |
| 259 | + globals()["_unpad_fn"] = unpad_fn |
| 260 | + else: |
| 261 | + flash_fn = globals()["_flash_fn"] |
| 262 | + flash_varlen_fn = globals()["_flash_varlen_fn"] |
| 263 | + pad_fn = globals()["_pad_fn"] |
| 264 | + unpad_fn = globals()["_unpad_fn"] |
| 265 | + |
| 266 | + is_causal = is_causal and not query_length == 1 |
| 267 | + flash_kwargs = {} |
| 268 | + if deterministic is not None: |
| 269 | + flash_kwargs["deterministic"] = deterministic |
| 270 | + if softcap is not None: |
| 271 | + flash_kwargs["softcap"] = softcap |
| 272 | + query_states, key_states, value_states = fdma_peft_integration_check( |
| 273 | + query_states, key_states, value_states, target_dtype |
| 274 | + ) |
| 275 | + if attention_mask is not None: |
| 276 | + q, k, v, bias, idx, (cu_q, cu_k), (mq, mk) = _upad_input( |
| 277 | + query_states, key_states, value_states, attention_bias, attention_mask, query_length, _fdma_unpad_input |
| 278 | + ) |
| 279 | + if "mps" in str(q.device): |
| 280 | + cu_k = cu_k.clone() |
| 281 | + out_unpad = flash_varlen_fn( |
| 282 | + query=q, |
| 283 | + key=k, |
| 284 | + value=v, |
| 285 | + attn_bias=bias, |
| 286 | + cu_seqlens_q=cu_q.to(torch.int32), |
| 287 | + cu_seqlens_k=cu_k.to(torch.int32), |
| 288 | + max_seqlen_q=mq, |
| 289 | + max_seqlen_k=mk, |
| 290 | + scale=softmax_scale, |
| 291 | + is_causal=is_causal, |
| 292 | + ) |
| 293 | + if isinstance(out_unpad, tuple): |
| 294 | + out_unpad = out_unpad[0] |
| 295 | + out = _fdma_pad_input(out_unpad, idx, query_states.shape[0], query_length) |
| 296 | + else: |
| 297 | + out = flash_fn( |
| 298 | + query_states, key_states, value_states, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal |
| 299 | + ) |
| 300 | + |
| 301 | + return out[0] if isinstance(out, tuple) else out |
0 commit comments