44import pickle
55import time
66from dataclasses import dataclass , field
7- from typing import TYPE_CHECKING , Callable , List , Optional
7+ from typing import TYPE_CHECKING , Callable , List , Optional , Tuple
88
99import torch
1010from vllm .config import VllmConfig
2121from ucm .shared .metrics import ucmmonitor
2222from ucm .shared .metrics .observability import UCMStatsLogger
2323from ucm .store .factory import UcmConnectorFactory
24- from ucm .store .ucmstore import Task , UcmKVStoreBase
24+ from ucm .store .ucmstore_v1 import UcmKVStoreBaseV1 , Task
2525from ucm .utils import Config
2626
2727if TYPE_CHECKING :
3535
3636@dataclass
3737class RequestMeta :
38- ucm_block_ids : list [str ] = field (default_factory = list )
38+ ucm_block_ids : list [bytes ] = field (default_factory = list )
3939 hbm_hit_block_num : int = 0
4040 # local_computed_block + external_computed_block
4141 total_hit_block_num : int = 0
@@ -47,9 +47,9 @@ class RequestMeta:
4747@dataclass
4848class RequestDispatchMeta :
4949 load_block_ids : tuple [
50- list [str ], list [int ]
50+ list [bytes ], list [int ]
5151 ] # [0] means ucm_block_ids, [1] means vllm_block_ids
52- dump_block_ids : tuple [list [str ], list [int ]]
52+ dump_block_ids : tuple [list [bytes ], list [int ]]
5353
5454
5555@dataclass
@@ -69,14 +69,14 @@ def __init__(self, vllm_config, rank_id):
6969 if RequestHasher ._SEED_HASH is None :
7070 RequestHasher ._SEED_HASH = self ("UCM_HASH_SEED" )
7171
72- def __call__ (self , input_data ) -> int :
73- if isinstance (input_data , str ):
74- input_bytes = input_data . encode ( "utf-8" )
72+ def __call__ (self , input_data ) -> bytes :
73+ if isinstance (input_data , bytes ):
74+ input_bytes = input_data
7575 else :
7676 input_bytes = pickle .dumps (input_data , protocol = pickle .HIGHEST_PROTOCOL )
7777
7878 h = hashlib .md5 (self .meta_bytes + input_bytes )
79- return int . from_bytes ( h .digest (), byteorder = "big" )
79+ return h .digest ()
8080
8181
8282class UCMDirectConnector (KVConnectorBase_V1 ):
@@ -110,9 +110,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
110110
111111 if self .local_rank >= 0 :
112112 self .device = torch_dev .device (f"{ dev_name } :{ self .local_rank } " )
113- self ._layer_offset_cache = {}
114113
115- self .store : UcmKVStoreBase
114+ self .store : UcmKVStoreBaseV1
116115
117116 if role == KVConnectorRole .SCHEDULER :
118117 self .request_hasher = RequestHasher (vllm_config , 0 )
@@ -188,7 +187,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
188187 # invalid block ids due to load errors
189188 self ._invalid_block_ids : set [int ] = set ()
190189
191- def generate_hash (self , block_size : int , request : "Request" ) -> list [str ]:
190+ def generate_hash (self , block_size : int , request : "Request" ) -> list [bytes ]:
192191 token_ids = request .all_token_ids
193192
194193 ret = []
@@ -205,7 +204,7 @@ def generate_hash(self, block_size: int, request: "Request") -> list[str]:
205204 (parent_block_hash_value , block_token_ids_tuple )
206205 )
207206 parent_block_hash_value = hash_value
208- ret .append (str ( hash_value ) )
207+ ret .append (hash_value )
209208
210209 return ret
211210
@@ -395,70 +394,32 @@ def _extract_layer_index(layer_name: str) -> Optional[int]:
395394 return int (chunk )
396395 return None
397396
398- def _precompute_layer_offsets (self ):
399- if not self .kv_caches :
400- return
401-
402- sample_kv_layer = next (iter (self .kv_caches .values ()))
403- elem_size = sample_kv_layer [0 ].element_size ()
404- block_data_size = (
405- sample_kv_layer [0 ].numel () if self .is_mla else sample_kv_layer [0 ][0 ].numel ()
406- ) * elem_size
407- layer_data_size = block_data_size if self .is_mla else block_data_size * 2
408-
409- # precompute all layers offset
410- for layer_name , _ in self .kv_caches .items ():
411- layer_id = self ._extract_layer_index (layer_name )
412- assert layer_id is not None
413- k_offset = layer_data_size * layer_id
414- v_offset = k_offset + block_data_size if not self .is_mla else 0
415- self ._layer_offset_cache [layer_name ] = (k_offset , v_offset )
416-
417- def _get_tensor_and_offset (
418- self , vllm_block_ids : list [int ], kv_layer : torch .Tensor , layer_name : str
419- ) -> tuple [list [torch .Tensor ], list [int ]]:
397+ def _get_tensors (self , vllm_block_id : int ) -> Tuple [List [torch .Tensor ], List [torch .Tensor ]]:
420398 """
421399 GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
422400 MLA: one layer shape is (num_blocks, block_size, head_size)
423401 """
424- k_tensors , k_offsets = [], []
425- v_tensors , v_offsets = [], []
426- k_offset , v_offset = self ._layer_offset_cache [layer_name ]
427-
428- for vllm_block_id in vllm_block_ids :
429- k_tensors .append (
430- kv_layer [vllm_block_id ] if self .is_mla else kv_layer [0 ][vllm_block_id ]
431- )
432- k_offsets .append (k_offset )
402+ k_tensors , v_tensors = [], []
403+ for _ , kv_layer in self .kv_caches .items ():
404+ k_tensors .append (kv_layer [vllm_block_id ] if self .is_mla else kv_layer [0 ][vllm_block_id ])
433405 if not self .is_mla :
434406 v_tensors .append (kv_layer [1 ][vllm_block_id ])
435- v_offsets .append (v_offset )
436- return k_tensors + v_tensors , k_offsets + v_offsets
437-
438- def _generate_task (self , vllm_block_ids : List [int ], ucm_block_ids : List [str ]):
439- if not self ._layer_offset_cache :
440- self ._precompute_layer_offsets ()
441-
442- num_layers = len (self .kv_caches )
443- num_blocks_per_layer = len (vllm_block_ids )
444- num_tensors_per_layer = num_blocks_per_layer * (1 if self .is_mla else 2 )
445- dst_tensor_addr = [None ] * (num_layers * num_tensors_per_layer )
446- ucm_offsets = [0 ] * (num_layers * num_tensors_per_layer )
447-
448- idx = 0
449- for layer_name , one_layer_kv_cache in self .kv_caches .items ():
450- tensors , offsets = self ._get_tensor_and_offset (
451- vllm_block_ids , one_layer_kv_cache , layer_name
452- )
453- dst_tensor_addr [idx : idx + len (tensors )] = tensors
454- ucm_offsets [idx : idx + len (offsets )] = offsets
455- idx += len (tensors )
407+ return k_tensors , v_tensors
456408
457- repeat_times = len (self .kv_caches ) * (1 if self .is_mla else 2 )
458- ucm_total_block_ids = ucm_block_ids * repeat_times
459409
460- assert len (ucm_total_block_ids ) == len (ucm_offsets ) == len (dst_tensor_addr )
461- return ucm_total_block_ids , ucm_offsets , dst_tensor_addr
410+ def _generate_task (self , vllm_block_ids : List [int ], ucm_block_ids : List [bytes ]) -> Tuple [List [bytes ], List [int ], List [List [torch .Tensor ]]]:
411+ """
412+ GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
413+ MLA: one layer shape is (num_blocks, block_size, head_size)
414+ """
415+ block_ids , shard_indexs , tensors = [], [], []
416+ for i , vllm_block_id in enumerate (vllm_block_ids ):
417+ k_tensors , v_tensors = self ._get_tensors (vllm_block_id )
418+ block_ids .append (ucm_block_ids [i ])
419+ tensors .append (k_tensors + v_tensors )
420+ shard_indexs .append (0 )
421+
422+ return block_ids , shard_indexs , tensors
462423
463424 def _broadcast (self , dst_tensor_addr : list [torch .Tensor ]):
464425 rec_tensor : torch .Tensor = None
@@ -501,17 +462,17 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
501462 ucm_block_ids , vllm_block_ids = request .load_block_ids
502463 if self .global_rank != 0 and not self .is_mla and not self .is_dsa :
503464 for i , ucm_block_id in enumerate (ucm_block_ids ):
504- ucm_block_ids [i ] = str ( self .request_hasher (ucm_block_id ) )
505- ucm_total_block_ids , ucm_offsets , dst_tensor_addr = self ._generate_task (
465+ ucm_block_ids [i ] = self .request_hasher (ucm_block_id )
466+ block_ids , shard_indexs , tensors = self ._generate_task (
506467 vllm_block_ids , ucm_block_ids
507468 )
508469 if self .global_rank == 0 or not self .load_only_first_rank :
509470 request_to_task [request_id ] = self .store .load (
510- ucm_total_block_ids , ucm_offsets , dst_tensor_addr
471+ block_ids , shard_indexs , tensors
511472 )
512473 else :
513474 request_to_task [request_id ] = None
514- req_broadcast_addr [request_id ] = dst_tensor_addr
475+ req_broadcast_addr [request_id ] = tensors
515476
516477 for request_id , task in request_to_task .items ():
517478 # TODO error handling
@@ -568,7 +529,6 @@ def wait_for_save(self) -> None:
568529 assert isinstance (metadata , UCMConnectorMetadata )
569530
570531 request_to_task : dict [str , Task ] = {}
571- request_to_blocks : dict [str , list [str ]] = {}
572532 is_save = False
573533 num_saved_block = 0
574534 num_saved_request = 0
@@ -583,36 +543,16 @@ def wait_for_save(self) -> None:
583543 ucm_block_ids , vllm_block_ids = request .dump_block_ids
584544 if self .global_rank != 0 :
585545 for i , ucm_block_id in enumerate (ucm_block_ids ):
586- ucm_block_ids [i ] = str (self .request_hasher (ucm_block_id ))
587- rets = self .store .create (ucm_block_ids )
588- end = 0
589- for i , ret in enumerate (rets ):
590- if ret != 0 :
591- logger .error (
592- f"create blocks for { request_id } failed, block index: { i } , ret code: { ret } "
593- )
594- break
595- end += 1
596-
597- if end == 0 :
598- continue
599- ucm_block_ids = ucm_block_ids [:end ]
600- vllm_block_ids = vllm_block_ids [:end ]
601- ucm_total_block_ids , ucm_offsets , dst_tensor_addr = self ._generate_task (
546+ ucm_block_ids [i ] = self .request_hasher (ucm_block_id )
547+ block_ids , shard_indexs , tensors = self ._generate_task (
602548 vllm_block_ids , ucm_block_ids
603549 )
604550 request_to_task [request_id ] = self .store .dump (
605- ucm_total_block_ids , ucm_offsets , dst_tensor_addr
551+ block_ids , shard_indexs , tensors
606552 )
607- request_to_blocks [request_id ] = ucm_block_ids
608553
609554 for request_id , task in request_to_task .items ():
610- ucm_block_ids = request_to_blocks [request_id ]
611- if self .store .wait (task ) == 0 :
612- self .store .commit (ucm_block_ids , True )
613- else :
614- logger .error (f"request { request_id } dump kv cache failed." )
615- self .store .commit (ucm_block_ids , False )
555+ self .store .wait (task )
616556 save_end_time = time .perf_counter () * 1000
617557 save_speed = (
618558 num_saved_block