|
| 1 | +# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch.nn.functional as F |
| 5 | + |
| 6 | + |
| 7 | +def index_first_axis(tensor, indices): |
| 8 | + """ |
| 9 | + A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis, |
| 10 | + after flattening the first two dimensions of the tensor. |
| 11 | + """ |
| 12 | + # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first |
| 13 | + # two dimensions to get (total_tokens, ...) before indexing. |
| 14 | + reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:]) |
| 15 | + return reshaped_tensor[indices] |
| 16 | + |
| 17 | + |
| 18 | +def unpad_input(hidden_states, attention_mask, unused_mask=None): |
| 19 | + """ |
| 20 | + Arguments: |
| 21 | + hidden_states: (batch, seqlen, ...) |
| 22 | + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. |
| 23 | + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. |
| 24 | +
|
| 25 | + Return: |
| 26 | + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. |
| 27 | + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. |
| 28 | + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. |
| 29 | + max_seqlen_in_batch: int |
| 30 | + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. |
| 31 | + """ |
| 32 | + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask |
| 33 | + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) |
| 34 | + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
| 35 | + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() |
| 36 | + max_seqlen_in_batch = seqlens_in_batch.max().item() |
| 37 | + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) |
| 38 | + |
| 39 | + return ( |
| 40 | + index_first_axis(hidden_states, indices), |
| 41 | + indices, |
| 42 | + cu_seqlens, |
| 43 | + max_seqlen_in_batch, |
| 44 | + used_seqlens_in_batch, |
| 45 | + ) |
| 46 | + |
| 47 | + |
| 48 | +def pad_input(hidden_states, indices, batch, seqlen): |
| 49 | + """ |
| 50 | + Arguments: |
| 51 | + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. |
| 52 | + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. |
| 53 | + batch: int, batch size for the padded sequence. |
| 54 | + seqlen: int, maximum sequence length for the padded sequence. |
| 55 | +
|
| 56 | + Return: |
| 57 | + hidden_states: (batch, seqlen, ...) |
| 58 | + """ |
| 59 | + dim = hidden_states.shape[1:] |
| 60 | + output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) |
| 61 | + output[indices] = hidden_states |
| 62 | + return output.view(batch, seqlen, *dim) |
| 63 | + |
| 64 | + |
| 65 | +def get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: |
| 66 | + """ |
| 67 | + Retrieves indexing data required to repad unpadded (ragged) tensors. |
| 68 | +
|
| 69 | + Arguments: |
| 70 | + attention_mask (`torch.Tensor`): |
| 71 | + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. |
| 72 | +
|
| 73 | + Return: |
| 74 | + indices (`torch.Tensor`): |
| 75 | + The indices of non-masked tokens from the flattened input sequence. |
| 76 | + cu_seqlens (`torch.Tensor`): |
| 77 | + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). |
| 78 | + max_seqlen_in_batch (`int`): |
| 79 | + Maximum sequence length in batch. |
| 80 | + """ |
| 81 | + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
| 82 | + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() |
| 83 | + # NOTE: Similar to the `.item()` in prepare_fdma_from_position_ids, with torch compile, |
| 84 | + # this might cause a graph break |
| 85 | + max_seqlen_in_batch = seqlens_in_batch.max().item() |
| 86 | + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) |
| 87 | + return ( |
| 88 | + indices, |
| 89 | + cu_seqlens, |
| 90 | + max_seqlen_in_batch, |
| 91 | + ) |
| 92 | + |
| 93 | + |
| 94 | +def upad_input( |
| 95 | + query_layer: torch.Tensor, |
| 96 | + key_layer: torch.Tensor, |
| 97 | + value_layer: torch.Tensor, |
| 98 | + attention_mask: torch.Tensor, |
| 99 | + query_length: int, |
| 100 | + unpad_input_func, |
| 101 | +): |
| 102 | + """ |
| 103 | + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. |
| 104 | + This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary |
| 105 | + tensors for query, key, value tensors. |
| 106 | +
|
| 107 | + Arguments: |
| 108 | + query_layer (`torch.Tensor`): |
| 109 | + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). |
| 110 | + key_layer (`torch.Tensor`): |
| 111 | + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). |
| 112 | + value_layer (`torch.Tensor`): |
| 113 | + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). |
| 114 | + attention_mask (`torch.Tensor`): |
| 115 | + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. |
| 116 | + query_length (`int`): |
| 117 | + Target length. |
| 118 | + unpad_input_func: |
| 119 | + The function to use for unpadding the input tensors. |
| 120 | +
|
| 121 | + Return: |
| 122 | + query_layer (`torch.Tensor`): |
| 123 | + Query state without padding. Shape: (total_target_length, num_heads, head_dim). |
| 124 | + key_layer (`torch.Tensor`): |
| 125 | + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). |
| 126 | + value_layer (`torch.Tensor`): |
| 127 | + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). |
| 128 | + indices_q (`torch.Tensor`): |
| 129 | + The indices of non-masked tokens from the flattened input target sequence. |
| 130 | + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): |
| 131 | + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). |
| 132 | + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): |
| 133 | + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). |
| 134 | + """ |
| 135 | + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(attention_mask) |
| 136 | + |
| 137 | + # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage |
| 138 | + # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores |
| 139 | + if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]): |
| 140 | + key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :] |
| 141 | + |
| 142 | + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape |
| 143 | + |
| 144 | + key_layer = index_first_axis(key_layer, indices_k) |
| 145 | + value_layer = index_first_axis(value_layer, indices_k) |
| 146 | + if query_length == kv_seq_len: |
| 147 | + query_layer = index_first_axis(query_layer, indices_k) |
| 148 | + cu_seqlens_q = cu_seqlens_k |
| 149 | + max_seqlen_in_batch_q = max_seqlen_in_batch_k |
| 150 | + indices_q = indices_k |
| 151 | + elif query_length == 1: |
| 152 | + max_seqlen_in_batch_q = 1 |
| 153 | + cu_seqlens_q = torch.arange( |
| 154 | + batch_size + 1, dtype=torch.int32, device=query_layer.device |
| 155 | + ) # There is a memcpy here, that is very bad. |
| 156 | + indices_q = cu_seqlens_q[:-1] |
| 157 | + query_layer = query_layer.squeeze(1) |
| 158 | + else: |
| 159 | + # The -q_len: slice assumes left padding. |
| 160 | + attention_mask = attention_mask[:, -query_length:] |
| 161 | + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask) |
| 162 | + |
| 163 | + return ( |
| 164 | + query_layer, |
| 165 | + key_layer, |
| 166 | + value_layer, |
| 167 | + indices_q, |
| 168 | + (cu_seqlens_q, cu_seqlens_k), |
| 169 | + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), |
| 170 | + ) |
0 commit comments