44import pickle
55import time
66from dataclasses import dataclass , field
7- from typing import TYPE_CHECKING , Callable , List , Optional
7+ from typing import TYPE_CHECKING , Callable , List , Optional , Tuple
88
99import torch
1010from vllm .config import VllmConfig
2020from ucm .logger import init_logger
2121from ucm .shared .metrics import ucmmonitor
2222from ucm .shared .metrics .observability import UCMStatsLogger
23- from ucm .store .factory import UcmConnectorFactory
24- from ucm .store .ucmstore import Task , UcmKVStoreBase
23+ from ucm .store .factory_v1 import UcmConnectorFactoryV1
24+ from ucm .store .ucmstore_v1 import Task , UcmKVStoreBaseV1
2525from ucm .utils import Config
2626
2727if TYPE_CHECKING :
3535
3636@dataclass
3737class RequestMeta :
38- ucm_block_ids : list [str ] = field (default_factory = list )
38+ ucm_block_ids : list [bytes ] = field (default_factory = list )
3939 hbm_hit_block_num : int = 0
4040 # local_computed_block + external_computed_block
4141 total_hit_block_num : int = 0
@@ -47,9 +47,9 @@ class RequestMeta:
@dataclass
class RequestDispatchMeta:
    """Per-request block-id pairing handed to the worker for KV transfer.

    Both fields are 2-tuples whose element [0] holds ucm_block_ids and
    whose element [1] holds the matching vllm_block_ids.
    """

    # blocks to load from the store into the vLLM KV cache
    load_block_ids: tuple[list[bytes], list[int]]
    # blocks to dump from the vLLM KV cache into the store
    dump_block_ids: tuple[list[bytes], list[int]]
5353
5454
5555@dataclass
@@ -69,14 +69,14 @@ def __init__(self, vllm_config, rank_id):
6969 if RequestHasher ._SEED_HASH is None :
7070 RequestHasher ._SEED_HASH = self ("UCM_HASH_SEED" )
7171
72- def __call__ (self , input_data ) -> int :
73- if isinstance (input_data , str ):
74- input_bytes = input_data . encode ( "utf-8" )
72+ def __call__ (self , input_data ) -> bytes :
73+ if isinstance (input_data , bytes ):
74+ input_bytes = input_data
7575 else :
7676 input_bytes = pickle .dumps (input_data , protocol = pickle .HIGHEST_PROTOCOL )
7777
7878 h = hashlib .md5 (self .meta_bytes + input_bytes )
79- return int . from_bytes ( h .digest (), byteorder = "big" )
79+ return h .digest ()
8080
8181
8282class UCMDirectConnector (KVConnectorBase_V1 ):
@@ -110,9 +110,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
110110
111111 if self .local_rank >= 0 :
112112 self .device = torch_dev .device (f"{ dev_name } :{ self .local_rank } " )
113- self ._layer_offset_cache = {}
114113
115- self .store : UcmKVStoreBase
114+ self .store : UcmKVStoreBaseV1
116115
117116 if role == KVConnectorRole .SCHEDULER :
118117 self .request_hasher = RequestHasher (vllm_config , 0 )
@@ -160,7 +159,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
160159 config ["io_size" ] = block_size_per_layer * (
161160 1 if self .is_mla else num_head_per_tp
162161 )
163- self .store = UcmConnectorFactory .create_connector (name , config )
162+ self .store = UcmConnectorFactoryV1 .create_connector (name , config )
164163 self .block_data_size = config ["kv_block_size" ]
165164
166165 logger .info ("init UCConnectorImpl, connector: %s" , name )
@@ -188,7 +187,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
188187 # invlalid block ids due to load errors
189188 self ._invalid_block_ids : set [int ] = set ()
190189
191- def generate_hash (self , block_size : int , request : "Request" ) -> list [str ]:
190+ def generate_hash (self , block_size : int , request : "Request" ) -> list [bytes ]:
192191 token_ids = request .all_token_ids
193192
194193 ret = []
@@ -205,7 +204,7 @@ def generate_hash(self, block_size: int, request: "Request") -> list[str]:
205204 (parent_block_hash_value , block_token_ids_tuple )
206205 )
207206 parent_block_hash_value = hash_value
208- ret .append (str ( hash_value ) )
207+ ret .append (hash_value )
209208
210209 return ret
211210
@@ -395,70 +394,37 @@ def _extract_layer_index(layer_name: str) -> Optional[int]:
395394 return int (chunk )
396395 return None
397396
398- def _precompute_layer_offsets (self ):
399- if not self .kv_caches :
400- return
401-
402- sample_kv_layer = next (iter (self .kv_caches .values ()))
403- elem_size = sample_kv_layer [0 ].element_size ()
404- block_data_size = (
405- sample_kv_layer [0 ].numel () if self .is_mla else sample_kv_layer [0 ][0 ].numel ()
406- ) * elem_size
407- layer_data_size = block_data_size if self .is_mla else block_data_size * 2
408-
409- # precompute all layers offset
410- for layer_name , _ in self .kv_caches .items ():
411- layer_id = self ._extract_layer_index (layer_name )
412- assert layer_id is not None
413- k_offset = layer_data_size * layer_id
414- v_offset = k_offset + block_data_size if not self .is_mla else 0
415- self ._layer_offset_cache [layer_name ] = (k_offset , v_offset )
416-
417- def _get_tensor_and_offset (
418- self , vllm_block_ids : list [int ], kv_layer : torch .Tensor , layer_name : str
419- ) -> tuple [list [torch .Tensor ], list [int ]]:
397+ def _get_tensors (
398+ self , vllm_block_id : int
399+ ) -> Tuple [List [torch .Tensor ], List [torch .Tensor ]]:
420400 """
421401 GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
422402 MLA: one layer shape is (num_blocks, block_size, head_size)
423403 """
424- k_tensors , k_offsets = [], []
425- v_tensors , v_offsets = [], []
426- k_offset , v_offset = self ._layer_offset_cache [layer_name ]
427-
428- for vllm_block_id in vllm_block_ids :
404+ k_tensors , v_tensors = [], []
405+ for _ , kv_layer in self .kv_caches .items ():
429406 k_tensors .append (
430407 kv_layer [vllm_block_id ] if self .is_mla else kv_layer [0 ][vllm_block_id ]
431408 )
432- k_offsets .append (k_offset )
433409 if not self .is_mla :
434410 v_tensors .append (kv_layer [1 ][vllm_block_id ])
435- v_offsets .append (v_offset )
436- return k_tensors + v_tensors , k_offsets + v_offsets
437-
438- def _generate_task (self , vllm_block_ids : List [int ], ucm_block_ids : List [str ]):
439- if not self ._layer_offset_cache :
440- self ._precompute_layer_offsets ()
441-
442- num_layers = len (self .kv_caches )
443- num_blocks_per_layer = len (vllm_block_ids )
444- num_tensors_per_layer = num_blocks_per_layer * (1 if self .is_mla else 2 )
445- dst_tensor_addr = [None ] * (num_layers * num_tensors_per_layer )
446- ucm_offsets = [0 ] * (num_layers * num_tensors_per_layer )
447-
448- idx = 0
449- for layer_name , one_layer_kv_cache in self .kv_caches .items ():
450- tensors , offsets = self ._get_tensor_and_offset (
451- vllm_block_ids , one_layer_kv_cache , layer_name
452- )
453- dst_tensor_addr [idx : idx + len (tensors )] = tensors
454- ucm_offsets [idx : idx + len (offsets )] = offsets
455- idx += len (tensors )
411+ return k_tensors , v_tensors
456412
457- repeat_times = len (self .kv_caches ) * (1 if self .is_mla else 2 )
458- ucm_total_block_ids = ucm_block_ids * repeat_times
413+ def _generate_task (
414+ self , vllm_block_ids : List [int ], ucm_block_ids : List [bytes ]
415+ ) -> Tuple [List [bytes ], List [int ], List [List [torch .Tensor ]]]:
416+ """
417+ GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
418+ MLA: one layer shape is (num_blocks, block_size, head_size)
419+ """
420+ block_ids , shard_indexs , tensors = [], [], []
421+ for i , vllm_block_id in enumerate (vllm_block_ids ):
422+ k_tensors , v_tensors = self ._get_tensors (vllm_block_id )
423+ block_ids .append (ucm_block_ids [i ])
424+ tensors .append (k_tensors + v_tensors )
425+ shard_indexs .append (0 )
459426
460- assert len (ucm_total_block_ids ) == len (ucm_offsets ) == len (dst_tensor_addr )
461- return ucm_total_block_ids , ucm_offsets , dst_tensor_addr
427+ return block_ids , shard_indexs , tensors
462428
463429 def _broadcast (self , dst_tensor_addr : list [torch .Tensor ]):
464430 rec_tensor : torch .Tensor = None
@@ -501,26 +467,28 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
501467 ucm_block_ids , vllm_block_ids = request .load_block_ids
502468 if self .global_rank != 0 and not self .is_mla and not self .is_dsa :
503469 for i , ucm_block_id in enumerate (ucm_block_ids ):
504- ucm_block_ids [i ] = str ( self .request_hasher (ucm_block_id ) )
505- ucm_total_block_ids , ucm_offsets , dst_tensor_addr = self ._generate_task (
470+ ucm_block_ids [i ] = self .request_hasher (ucm_block_id )
471+ block_ids , shard_indexs , tensors = self ._generate_task (
506472 vllm_block_ids , ucm_block_ids
507473 )
508474 if self .global_rank == 0 or not self .load_only_first_rank :
509475 request_to_task [request_id ] = self .store .load (
510- ucm_total_block_ids , ucm_offsets , dst_tensor_addr
476+ block_ids , shard_indexs , tensors
511477 )
512478 else :
513479 request_to_task [request_id ] = None
514- req_broadcast_addr [request_id ] = dst_tensor_addr
480+ req_broadcast_addr [request_id ] = [ t for row in tensors for t in row ]
515481
516482 for request_id , task in request_to_task .items ():
517483 # TODO error handling
518484 if self .global_rank == 0 or not self .load_only_first_rank :
519- if self .store .wait (task ) != 0 :
485+ try :
486+ self .store .wait (task )
487+ except RuntimeError as e :
488+ logger .error ("request {request_id} load kv cache failed.:" , e )
520489 self ._invalid_block_ids .update (
521490 metadata .request_meta [request_id ].load_block_ids [1 ]
522491 )
523- logger .error (f"request { request_id } load kv cache failed." )
524492 if self .load_only_first_rank :
525493 self ._broadcast (req_broadcast_addr [request_id ])
526494 load_end_time = time .perf_counter () * 1000
@@ -568,7 +536,6 @@ def wait_for_save(self) -> None:
568536 assert isinstance (metadata , UCMConnectorMetadata )
569537
570538 request_to_task : dict [str , Task ] = {}
571- request_to_blocks : dict [str , list [str ]] = {}
572539 is_save = False
573540 num_saved_block = 0
574541 num_saved_request = 0
@@ -583,36 +550,19 @@ def wait_for_save(self) -> None:
583550 ucm_block_ids , vllm_block_ids = request .dump_block_ids
584551 if self .global_rank != 0 :
585552 for i , ucm_block_id in enumerate (ucm_block_ids ):
586- ucm_block_ids [i ] = str (self .request_hasher (ucm_block_id ))
587- rets = self .store .create (ucm_block_ids )
588- end = 0
589- for i , ret in enumerate (rets ):
590- if ret != 0 :
591- logger .error (
592- f"create blocks for { request_id } failed, block index: { i } , ret code: { ret } "
593- )
594- break
595- end += 1
596-
597- if end == 0 :
598- continue
599- ucm_block_ids = ucm_block_ids [:end ]
600- vllm_block_ids = vllm_block_ids [:end ]
601- ucm_total_block_ids , ucm_offsets , dst_tensor_addr = self ._generate_task (
553+ ucm_block_ids [i ] = self .request_hasher (ucm_block_id )
554+ block_ids , shard_indexs , tensors = self ._generate_task (
602555 vllm_block_ids , ucm_block_ids
603556 )
604557 request_to_task [request_id ] = self .store .dump (
605- ucm_total_block_ids , ucm_offsets , dst_tensor_addr
558+ block_ids , shard_indexs , tensors
606559 )
607- request_to_blocks [request_id ] = ucm_block_ids
608560
609561 for request_id , task in request_to_task .items ():
610- ucm_block_ids = request_to_blocks [request_id ]
611- if self .store .wait (task ) == 0 :
612- self .store .commit (ucm_block_ids , True )
613- else :
614- logger .error (f"request { request_id } dump kv cache failed." )
615- self .store .commit (ucm_block_ids , False )
562+ try :
563+ self .store .wait (task )
564+ except RuntimeError as e :
565+ logger .error ("request {request_id} dump kv cache failed.:" , e )
616566 save_end_time = time .perf_counter () * 1000
617567 save_speed = (
618568 num_saved_block
0 commit comments