
 logger = init_logger(__name__)

+# TPU requires the head size to be a multiple of 128.
+TPU_HEAD_SIZE_ALIGNMENT = 128
+

 class PallasAttentionBackend(AttentionBackend):

@@ -43,6 +46,14 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> tuple[int, ...]:
+        padded_head_size = cdiv(
+            head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+        num_blocks = num_blocks * head_size // padded_head_size
+        if padded_head_size != head_size:
+            logger.warning_once(
+                "head size is padded to %d, and num_blocks is adjusted to %d"
+                " accordingly", padded_head_size, num_blocks)
+            head_size = padded_head_size
         return (num_blocks, block_size, num_kv_heads * 2, head_size)

     @staticmethod
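
For intuition, here is a standalone sketch of the shape arithmetic above. It is not part of the diff: `cdiv` is a local stand-in for vLLM's ceiling-division helper, and the sizes (an 80-wide head, 1000 blocks) are made up. Scaling `num_blocks` by `head_size / padded_head_size` keeps the padded cache close to the original memory budget.

```python
# Illustrative only -- the sizes below are hypothetical.
def cdiv(a: int, b: int) -> int:
    """Ceiling division, standing in for vLLM's helper."""
    return -(a // -b)

TPU_HEAD_SIZE_ALIGNMENT = 128
num_blocks, block_size, num_kv_heads, head_size = 1000, 16, 8, 80

padded_head_size = cdiv(head_size,
                        TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
# Fewer, wider blocks: 1000 * 80 // 128 == 625, so total cache bytes stay
# roughly constant after padding each head from 80 to 128 elements.
num_blocks = num_blocks * head_size // padded_head_size

print((num_blocks, block_size, num_kv_heads * 2, padded_head_size))
# -> (625, 16, 16, 128)
```
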
@@ -132,8 +143,6 @@ def __init__(
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-        if head_size % 128 != 0:
-            raise NotImplementedError("Head size must be a multiple of 128.")
         if alibi_slopes is not None:
             raise NotImplementedError("Alibi slopes is not supported.")
         if kv_cache_dtype != "auto":
@@ -187,6 +196,18 @@ def forward(
         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
         num_tokens, hidden_size = query.shape
         query = query.view(num_tokens, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+        if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
+            padded_head_size = cdiv(
+                self.head_size,
+                TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+            query = torch.nn.functional.pad(
+                query, (0, padded_head_size - self.head_size), value=0.0)
+            key = torch.nn.functional.pad(
+                key, (0, padded_head_size - self.head_size), value=0.0)
+            value = torch.nn.functional.pad(
+                value, (0, padded_head_size - self.head_size), value=0.0)

         if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
             # Write input keys and values to the KV cache.
@@ -213,6 +234,9 @@ def forward(
             soft_cap=self.logits_soft_cap,
         )

+        if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
+            output = output[:, :, :self.head_size]
+
         return output.reshape(num_tokens, hidden_size)


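As a sanity check on the forward-path changes, here is a standalone sketch (outside the diff, with made-up shapes) of the pad-then-slice round trip: the head dimension is zero-padded up to the alignment before the kernel call and sliced back afterwards, so aligned models are untouched and padded ones only lose the zero tail.

```python
# Illustrative only -- shapes are hypothetical.
import torch

TPU_HEAD_SIZE_ALIGNMENT = 128
num_tokens, num_heads, head_size = 4, 8, 96  # 96 is not 128-aligned

query = torch.randn(num_tokens, num_heads, head_size)
padded_head_size = -(head_size // -TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT

# F.pad pads the last dimension with (left, right) = (0, delta) zeros.
padded = torch.nn.functional.pad(
    query, (0, padded_head_size - head_size), value=0.0)
assert padded.shape == (num_tokens, num_heads, padded_head_size)

# After attention, slicing [:, :, :head_size] drops the zero tail again.
assert torch.equal(padded[:, :, :head_size], query)
```
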
@@ -231,11 +255,8 @@ def write_to_kv_cache(

     """
     _, _, num_combined_kv_heads, head_size = kv_cache.shape
-    num_kv_heads = num_combined_kv_heads // 2
-
-    key = key.view(-1, num_kv_heads, head_size)
-    value = value.view(-1, num_kv_heads, head_size)
-
+    head_size = cdiv(head_size,
+                     TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
     kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
                                                   head_size)