55
66from text_generation_server .models .globals import BLOCK_SIZE
77from text_generation_server .utils .weights import Weights
8- from vllm_hpu_extension import cache_ops
98
109
1110@dataclass
@@ -55,12 +54,12 @@ def __init__(
5554
5655 self .kv_cache = (
5756 torch .zeros (
58- (num_blocks , BLOCK_SIZE , num_heads , head_size ),
57+ (num_blocks * BLOCK_SIZE , num_heads , head_size ),
5958 dtype = dtype ,
6059 device = device ,
6160 ),
6261 torch .zeros (
63- (num_blocks , BLOCK_SIZE , num_heads , head_size ),
62+ (num_blocks * BLOCK_SIZE , num_heads , head_size ),
6463 dtype = dtype ,
6564 device = device ,
6665 ),
@@ -129,7 +128,7 @@ def __init__(
129128 raise ValueError ("torch.float8_e5m2 is not supported in hpu. " )
130129
131130 self .kv_cache = torch .zeros (
132- (num_blocks , BLOCK_SIZE , 1 , head_size ),
131+ (num_blocks * BLOCK_SIZE , 1 , head_size ),
133132 dtype = dtype ,
134133 device = device ,
135134 )
@@ -161,14 +160,11 @@ def store(
161160 ):
162161 """Store the key and value at the given slots."""
163162 ## TODO FP8 kv cache support
164-
165- block_idx = slots // BLOCK_SIZE
166- block_offset = slots % BLOCK_SIZE
167163 if self .kv_cache .dtype == torch .float8_e4m3fn :
168164 key = torch .ops .hpu .cast_to_fp8_v2 (
169165 key , kv_scales .key_scale , False , False , torch .float8_e4m3fn
170166 )[0 ]
171- cache_ops . insert_or_update_cache ( key , self .kv_cache , block_idx , block_offset )
167+ self .kv_cache . index_copy_ ( 0 , slots , key )
172168
173169
174170def paged_reshape_and_cache (
@@ -180,17 +176,15 @@ def paged_reshape_and_cache(
180176 k_scale : torch .Tensor ,
181177 v_scale : torch .Tensor ,
182178):
183- block_idx = slots // BLOCK_SIZE
184- block_offset = slots % BLOCK_SIZE
185179 if key_cache .dtype == torch .float8_e4m3fn :
186180 key = torch .ops .hpu .cast_to_fp8_v2 (
187181 key , k_scale , False , False , torch .float8_e4m3fn
188182 )[0 ]
189183 value = torch .ops .hpu .cast_to_fp8_v2 (
190184 value , v_scale , False , False , torch .float8_e4m3fn
191185 )[0 ]
192- cache_ops . insert_or_update_cache ( key , key_cache , block_idx , block_offset )
193- cache_ops . insert_or_update_cache ( value , value_cache , block_idx , block_offset )
186+ key_cache . index_copy_ ( 0 , slots , key )
187+ value_cache . index_copy_ ( 0 , slots , value )
194188
195189
196190def get_kv_scales (weights : Weights , prefix : str ) -> KVScales :
0 commit comments