@@ -95,6 +95,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.block_size = self._vllm_config.cache_config.block_size
         self.is_mla = self._vllm_config.model_config.is_deepseek_mla
         self.is_dsa = False
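+        # Cached layer count (parallel-config aware); register_kv_caches uses it to size each stored KV block.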
+        self.num_layers = self._vllm_config.model_config.get_num_layers(
+            self._vllm_config.parallel_config
+        )
         self.kv_cache_dtype: torch.dtype = None
 
         if current_platform.is_cuda_alike():
@@ -111,7 +114,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         if self.local_rank >= 0:
             self.device = torch_dev.device(f"{dev_name}:{self.local_rank}")
 
-        self.store: UcmKVStoreBaseV1
+        self.k_store: UcmKVStoreBaseV1
+        self.v_store: Optional[UcmKVStoreBaseV1] = None
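+        # v_store is created only when K and V caches are stored separately; it stays None for single-store layouts.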
 
         if role == KVConnectorRole.SCHEDULER:
             self.request_hasher = RequestHasher(vllm_config, 0)
@@ -134,40 +138,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.broadcast_stream = torch.cuda.Stream()
 
         logger.info(f"self.launch_config: {self.launch_config}")
-        connector_configs = self.launch_config.get("ucm_connectors", [])
-        assert len(connector_configs) > 0, "no storage connector name in config."
-
-        name = connector_configs[0].get("ucm_connector_name")
-        config = connector_configs[0].get("ucm_connector_config") or {}
-        config["device"] = self.local_rank
-        config["role"] = "scheduler" if role == KVConnectorRole.SCHEDULER else "worker"
-        element_size = vllm_config.model_config.dtype.itemsize
-        single_head_dim = vllm_config.model_config.get_head_size()
-        num_head_per_tp = vllm_config.model_config.get_num_kv_heads(
-            vllm_config.parallel_config
-        )
-        total_tp_size = vllm_config.parallel_config.tensor_parallel_size
-        num_layers = vllm_config.model_config.get_num_layers(
-            vllm_config.parallel_config
-        )
-        block_size_per_layer = self.block_size * element_size * single_head_dim
-        config["kv_block_size"] = (
-            block_size_per_layer
-            * num_layers
-            * (1 if self.is_mla else num_head_per_tp * 2)
-        )
-        config["io_size"] = block_size_per_layer * (
-            1 if self.is_mla else num_head_per_tp
-        )
-        self.store = UcmConnectorFactoryV1.create_connector(name, config)
-        self.block_data_size = config["kv_block_size"]
-
-        logger.info("init UCConnectorImpl, connector: %s", name)
-        logger.info(
-            "single file size = %d MB, io_size = %d KB,",
-            config["kv_block_size"] / 1024 / 1024,
-            config["io_size"] / 1024,
-        )
 
         self.metrics_config = self.launch_config.get("metrics_config_path", "")
         if self.metrics_config:
@@ -208,6 +178,84 @@ def generate_hash(self, block_size: int, request: "Request") -> list[bytes]:
 
         return ret
 
+    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
+        self.kv_caches = kv_caches
+        sample_kv_layer = next(iter(self.kv_caches.values()))
+        if isinstance(sample_kv_layer, torch.Tensor):
+            logger.info(f"kv cache shape {sample_kv_layer.shape}")
+            if self.kv_cache_dtype is None:
+                self.kv_cache_dtype = sample_kv_layer.dtype
+        elif isinstance(sample_kv_layer, Tuple):
+            # Since vllm_ascend >= 0.10.0, the MLA model's tensor shape has changed to a Tuple:
+            # [(num_blocks, block_size, num_kv_heads, nope_dim/rope_dim)]
+            # Currently, we treat it as GQA and use is_dsa to mark it.
+            for i, tensor in enumerate(sample_kv_layer):
+                logger.info(f"kv cache shape {i}: {tensor.shape}")
+            if self.kv_cache_dtype is None:
+                self.kv_cache_dtype = sample_kv_layer[0].dtype
+            if self.is_mla:
+                self.is_mla = False
+                self.is_dsa = True
+
+        # When handling the GQA case, we will separately dump the k_cache and v_cache.
+        connector_configs = self.launch_config.get("ucm_connectors", [])
+        assert len(connector_configs) > 0, "no storage connector name in config."
+
+        name = connector_configs[0].get("ucm_connector_name")
+        config = connector_configs[0].get("ucm_connector_config") or {}
+        config["device"] = self.local_rank
+        config["role"] = (
+            "scheduler" if self._role == KVConnectorRole.SCHEDULER else "worker"
+        )
+        if len(sample_kv_layer) == 2:
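+            # Two-entry layout (separate K and V caches): keep each half in its own store under k/ and v/ subdirectories.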
+            storage_backends = config["storage_backends"]
+            k_dir = os.path.join(storage_backends, "k")
+            v_dir = os.path.join(storage_backends, "v")
+            os.makedirs(k_dir, exist_ok=True)
+            os.makedirs(v_dir, exist_ok=True)
+            logger.info(f"Created subdirectories: {k_dir}, {v_dir}")
+
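+            # io_size = bytes of one block's K data for a single layer; kv_block_size packs that block across all layers into one object.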
+            k_io_size = (
+                sample_kv_layer[0][0].numel() * sample_kv_layer[0][0].element_size()
+            )
+            config["io_size"] = k_io_size
+            config["kv_block_size"] = k_io_size * self.num_layers
+            config["storage_backends"] = k_dir
+            self.k_store = UcmConnectorFactoryV1.create_connector(name, config)
+            logger.info("init UCConnectorImpl, k_connector: %s", name)
+            logger.info(
+                "single file size = %d MB, io_size = %d KB,",
+                config["kv_block_size"] / 1024 / 1024,
+                config["io_size"] / 1024,
+            )
+
+            v_io_size = (
+                sample_kv_layer[1][0].numel() * sample_kv_layer[1][0].element_size()
+            )
+            config["io_size"] = v_io_size
+            config["kv_block_size"] = v_io_size * self.num_layers
+            config["storage_backends"] = v_dir
+            self.v_store = UcmConnectorFactoryV1.create_connector(name, config)
+            logger.info("init UCConnectorImpl, v_connector: %s", name)
+            logger.info(
+                "single file size = %d MB, io_size = %d KB,",
+                config["kv_block_size"] / 1024 / 1024,
+                config["io_size"] / 1024,
+            )
+            self.block_data_size = (k_io_size + v_io_size) * self.num_layers
+        else:
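+            # Single-entry layout (e.g. the MLA latent cache): one store holds the whole per-block payload.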
+            k_io_size = sample_kv_layer[0].numel() * sample_kv_layer[0].element_size()
+            config["io_size"] = k_io_size
+            config["kv_block_size"] = k_io_size * self.num_layers
+            self.k_store = UcmConnectorFactoryV1.create_connector(name, config)
+            logger.info("init UCConnectorImpl, k_connector: %s", name)
+            logger.info(
+                "single file size = %d MB, io_size = %d KB,",
+                config["kv_block_size"] / 1024 / 1024,
+                config["io_size"] / 1024,
+            )
+            self.block_data_size = k_io_size * self.num_layers
+
     def get_num_new_matched_tokens(
         self,
         request: "Request",
@@ -222,7 +270,7 @@ def get_num_new_matched_tokens(
         if not external_block_ids:
             return 0, False
 
-        lookup_results = self.store.lookup(external_block_ids)
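+        # Block existence is checked against k_store only; when K/V are split, V is written and awaited alongside K.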
+        lookup_results = self.k_store.lookup(external_block_ids)
         external_hit_blocks = 0
         for i, hit in enumerate(lookup_results):
             if not hit:
@@ -412,15 +460,18 @@ def _get_tensors(
 
     def _generate_task(
         self, vllm_block_ids: List[int], ucm_block_ids: List[bytes]
-    ) -> Tuple[List[bytes], List[int], List[List[torch.Tensor]]]:
-        block_ids, shard_indexs, tensors = [], [], []
+    ) -> Tuple[
+        List[bytes], List[int], List[List[torch.Tensor]], List[List[torch.Tensor]]
+    ]:
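+        # Collect per-block K and V tensors in separate lists so each can be routed to its own store.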
+        block_ids, shard_indexs, total_k_tensors, total_v_tensors = [], [], [], []
         for i, vllm_block_id in enumerate(vllm_block_ids):
             k_tensors, v_tensors = self._get_tensors(vllm_block_id)
             block_ids.append(ucm_block_ids[i])
-            tensors.append(k_tensors + v_tensors)
+            total_k_tensors.append(k_tensors)
+            total_v_tensors.append(v_tensors)
             shard_indexs.append(0)
 
-        return block_ids, shard_indexs, tensors
+        return block_ids, shard_indexs, total_k_tensors, total_v_tensors
 
     def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
         rec_tensor: torch.Tensor = None
@@ -447,7 +498,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
 
         self._init_kv_caches_from_forward_context(forward_context)
 
-        request_to_task: dict[str, Optional[Task]] = {}
+        request_to_task: dict[str, Optional[List[Task]]] = {}
         req_broadcast_addr = {}
         is_load = False
         num_loaded_block = 0
@@ -464,22 +515,28 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
             if self.global_rank != 0 and not self.is_mla and not self.is_dsa:
                 for i, ucm_block_id in enumerate(ucm_block_ids):
                     ucm_block_ids[i] = self.request_hasher(ucm_block_id)
-            block_ids, shard_indexs, tensors = self._generate_task(
+            block_ids, shard_indexs, k_tensors, v_tensors = self._generate_task(
                 vllm_block_ids, ucm_block_ids
             )
             if self.global_rank == 0 or not self.load_only_first_rank:
-                request_to_task[request_id] = self.store.load(
-                    block_ids, shard_indexs, tensors
-                )
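+                # Load K first; queue a V load only when a split v_store exists.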
+                k_task = self.k_store.load(block_ids, shard_indexs, k_tensors)
+                request_to_task[request_id] = [k_task]
+                if v_tensors and self.v_store:
+                    v_task = self.v_store.load(block_ids, shard_indexs, v_tensors)
+                    request_to_task[request_id].append(v_task)
             else:
                 request_to_task[request_id] = None
-            req_broadcast_addr[request_id] = [t for row in tensors for t in row]
+            req_broadcast_addr[request_id] = [t for row in k_tensors for t in row] + [
+                t for row in v_tensors for t in row
+            ]
 
-        for request_id, task in request_to_task.items():
+        for request_id, tasks in request_to_task.items():
             # TODO error handling
             if self.global_rank == 0 or not self.load_only_first_rank:
                 try:
-                    self.store.wait(task)
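+                    # tasks[0] is the K load; tasks[1] (when present) is the V load.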
+                    self.k_store.wait(tasks[0])
+                    if len(tasks) > 1 and self.v_store:
+                        self.v_store.wait(tasks[1])
                 except RuntimeError as e:
                     logger.error("request %s load kv cache failed: %s", request_id, e)
                     self._invalid_block_ids.update(
@@ -531,7 +588,7 @@ def wait_for_save(self) -> None:
         metadata = self._get_connector_metadata()
         assert isinstance(metadata, UCMConnectorMetadata)
 
-        request_to_task: dict[str, Task] = {}
+        request_to_task: dict[str, List[Task]] = {}
         is_save = False
         num_saved_block = 0
         num_saved_request = 0
@@ -547,16 +604,20 @@ def wait_for_save(self) -> None:
             if self.global_rank != 0:
                 for i, ucm_block_id in enumerate(ucm_block_ids):
                     ucm_block_ids[i] = self.request_hasher(ucm_block_id)
-            block_ids, shard_indexs, tensors = self._generate_task(
+            block_ids, shard_indexs, k_tensors, v_tensors = self._generate_task(
                 vllm_block_ids, ucm_block_ids
             )
-            request_to_task[request_id] = self.store.dump(
-                block_ids, shard_indexs, tensors
-            )
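+            # Dump K, then V when split stores are configured; both tasks are awaited below.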
+            k_task = self.k_store.dump(block_ids, shard_indexs, k_tensors)
+            request_to_task[request_id] = [k_task]
+            if v_tensors and self.v_store:
+                v_task = self.v_store.dump(block_ids, shard_indexs, v_tensors)
+                request_to_task[request_id].append(v_task)
 
-        for request_id, task in request_to_task.items():
+        for request_id, tasks in request_to_task.items():
             try:
-                self.store.wait(task)
+                self.k_store.wait(tasks[0])
+                if len(tasks) > 1 and self.v_store:
+                    self.v_store.wait(tasks[1])
             except RuntimeError as e:
                 logger.error("request %s dump kv cache failed: %s", request_id, e)
         save_end_time = time.perf_counter() * 1000
@@ -739,6 +800,16 @@ def update_state_after_alloc(
739800 """
740801 self .connector .update_state_after_alloc (request , blocks , num_external_tokens )
741802
+    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
+        """
+        Initialize with the KV caches. Useful for pre-registering the
+        KV caches in the KVConnector (e.g. for NIXL).
+
+        Args:
+            kv_caches: dictionary mapping layer names to KV cache tensors.
+        """
+        self.connector.register_kv_caches(kv_caches)
+
     def build_connector_meta(
         self, scheduler_output: SchedulerOutput
     ) -> KVConnectorMetadata: