@@ -99,17 +99,17 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):

         if current_platform.is_cuda_alike():
             logger.info("CUDA device is available.")
-            torch_dev = torch
+            self.torch_dev = torch.cuda
             dev_name = "cuda"
         elif current_platform.device_type == "npu":
             logger.info("NPU device is available.")
-            torch_dev = torch.npu
+            self.torch_dev = torch.npu
             dev_name = "npu"
         else:
             raise RuntimeError("Unsupported device platform for UCMDirectConnector.")

         if self.local_rank >= 0:
-            self.device = torch_dev.device(f"{dev_name}:{self.local_rank}")
+            self.device = torch.device(f"{dev_name}:{self.local_rank}")
         self._layer_offset_cache = {}

         self.store: UcmKVStoreBase
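For readers skimming the diff, a minimal standalone sketch of the pattern this hunk adopts: keep the backend-specific torch namespace on the object and build the `torch.device` from a plain string. `select_device_backend` is a hypothetical helper, and the `torch.npu` namespace is assumed to come from `torch_npu` on Ascend hosts; the connector itself detects the platform via `current_platform`.

```python
import torch

def select_device_backend(local_rank: int):
    """Hypothetical helper: pick the per-backend torch namespace and this rank's device."""
    if torch.cuda.is_available():
        # torch.cuda exposes Stream(), stream(), synchronize(), ...
        torch_dev, dev_name = torch.cuda, "cuda"
    elif hasattr(torch, "npu") and torch.npu.is_available():
        # torch_npu patches an equivalent torch.npu namespace on Ascend.
        torch_dev, dev_name = torch.npu, "npu"
    else:
        raise RuntimeError("No supported accelerator found.")
    device = torch.device(f"{dev_name}:{local_rank}") if local_rank >= 0 else None
    return torch_dev, device

# Later code can then call torch_dev.Stream(), torch_dev.stream(...), and
# torch_dev.synchronize() without branching on the platform again.
```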
@@ -132,7 +132,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         if role == KVConnectorRole.WORKER:
             self.group_coordinator = get_tp_group()
             self.broadcast_fn = self.group_coordinator.broadcast
-            self.broadcast_stream = torch.cuda.Stream()
+            self.broadcast_stream = self.torch_dev.Stream()
+            self._broadcast_buffer = None
+            self._broadcast_buffer_size = 0

         logger.info(f"self.launch_config: {self.launch_config}")
         connector_configs = self.launch_config.get("ucm_connectors", [])
@@ -179,12 +181,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         )
         self.monitor = ucmmonitor.StatsMonitor.get_instance()

-        self.synchronize = (
-            torch.cuda.synchronize
-            if current_platform.is_cuda_alike()
-            else torch.npu.synchronize
-        )
-
         # invlalid block ids due to load errors
         self._invalid_block_ids: set[int] = set()

@@ -433,7 +429,7 @@ def _get_tensor_and_offset(
             if not self.is_mla:
                 v_tensors.append(kv_layer[1][vllm_block_id])
                 v_offsets.append(v_offset)
-        return k_tensors + v_tensors, k_offsets + v_offsets
+        return (k_tensors, v_tensors), k_offsets + v_offsets

     def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[str]):
         if not self._layer_offset_cache:
@@ -443,41 +439,62 @@ def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[str]):
         num_blocks_per_layer = len(vllm_block_ids)
         num_tensors_per_layer = num_blocks_per_layer * (1 if self.is_mla else 2)
         dst_tensor_addr = [None] * (num_layers * num_tensors_per_layer)
+        k_tensor_addr = [None] * num_blocks_per_layer
+        v_tensor_addr = [None] * num_blocks_per_layer
         ucm_offsets = [0] * (num_layers * num_tensors_per_layer)

         idx = 0
+        kv_idx = 0
         for layer_name, one_layer_kv_cache in self.kv_caches.items():
-            tensors, offsets = self._get_tensor_and_offset(
+            (k_tensors, v_tensors), offsets = self._get_tensor_and_offset(
                 vllm_block_ids, one_layer_kv_cache, layer_name
             )
+            tensors = k_tensors + v_tensors
+            k_tensor_addr[kv_idx : kv_idx + len(k_tensors)] = k_tensors
+            if v_tensors:
+                v_tensor_addr[kv_idx : kv_idx + len(v_tensors)] = v_tensors
             dst_tensor_addr[idx : idx + len(tensors)] = tensors
             ucm_offsets[idx : idx + len(offsets)] = offsets
             idx += len(tensors)
+            kv_idx += len(k_tensors)

         repeat_times = len(self.kv_caches) * (1 if self.is_mla else 2)
         ucm_total_block_ids = ucm_block_ids * repeat_times

         assert len(ucm_total_block_ids) == len(ucm_offsets) == len(dst_tensor_addr)
-        return ucm_total_block_ids, ucm_offsets, dst_tensor_addr
+        return (
+            ucm_total_block_ids,
+            ucm_offsets,
+            dst_tensor_addr,
+            (k_tensor_addr, v_tensor_addr),
+        )
+
+    def _ensure_buffer(self, total_numel: int):
+        if self._broadcast_buffer is None or self._broadcast_buffer_size < total_numel:
+            self._broadcast_buffer = torch.empty(
+                total_numel,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+            self._broadcast_buffer_size = total_numel

     def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
         rec_tensor: torch.Tensor = None
-        with torch.cuda.stream(self.broadcast_stream):
-            # TODO support broadcast when PP
+        total_numel = len(dst_tensor_addr) * dst_tensor_addr[0].numel()
+        with self.torch_dev.stream(self.broadcast_stream):
             if self.global_rank == 0:
                 tensor_to_broadcast = torch.stack(dst_tensor_addr, dim=0)
                 self.broadcast_fn(tensor_to_broadcast, 0)
             else:
                 shape = (len(dst_tensor_addr),) + dst_tensor_addr[0].shape
-                # TODO create earlier
-                rec_tensor = torch.empty(
-                    shape, dtype=self.kv_cache_dtype, device=self.device
-                )
+                self._ensure_buffer(total_numel)
+                rec_tensor = self._broadcast_buffer[:total_numel].view(shape)
                 self.broadcast_fn(rec_tensor, 0)
         self.broadcast_stream.synchronize()
+
         if self.global_rank != 0 and rec_tensor is not None:
-            for i, tensor in enumerate(dst_tensor_addr):
-                tensor.copy_(rec_tensor[i])
+            rec_tensor_list = list(torch.unbind(rec_tensor, dim=0))
+            torch._foreach_copy_(dst_tensor_addr, rec_tensor_list)

     def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
         metadata = self._get_connector_metadata()
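The rewritten `_broadcast` replaces a fresh `torch.empty` and a Python loop of `.copy_()` calls with a grow-only staging buffer and one fused `torch._foreach_copy_`. The sketch below shows the same pattern against plain `torch.distributed` (the connector instead goes through its TP group coordinator and a dedicated stream); `stage_and_scatter` and `_buf` are illustrative names, and `torch._foreach_copy_` is a private PyTorch op, as in the diff.

```python
from typing import Optional

import torch
import torch.distributed as dist

_buf: Optional[torch.Tensor] = None  # grow-only staging buffer, reused across calls


def stage_and_scatter(dst: list[torch.Tensor], src_rank: int = 0) -> None:
    """Broadcast a list of same-shaped tensors with one collective, then scatter."""
    global _buf
    total = len(dst) * dst[0].numel()
    if dist.get_rank() == src_rank:
        # The sender packs the blocks into one contiguous tensor so a single
        # broadcast replaces many small ones.
        dist.broadcast(torch.stack(dst, dim=0), src=src_rank)
        return
    # Receivers reuse (and only ever grow) the staging buffer instead of
    # allocating a new receive tensor on every call.
    if _buf is None or _buf.numel() < total:
        _buf = torch.empty(total, dtype=dst[0].dtype, device=dst[0].device)
    recv = _buf[:total].view((len(dst),) + dst[0].shape)
    dist.broadcast(recv, src=src_rank)
    # One fused multi-tensor copy replaces a Python loop of per-tensor .copy_() calls.
    torch._foreach_copy_(dst, list(torch.unbind(recv, dim=0)))
```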
@@ -502,16 +519,19 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
             if self.global_rank != 0 and not self.is_mla and not self.is_dsa:
                 for i, ucm_block_id in enumerate(ucm_block_ids):
                     ucm_block_ids[i] = str(self.request_hasher(ucm_block_id))
-            ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
-                vllm_block_ids, ucm_block_ids
-            )
+            (
+                ucm_total_block_ids,
+                ucm_offsets,
+                dst_tensor_addr,
+                (k_tensor_addr, v_tensor_addr),
+            ) = self._generate_task(vllm_block_ids, ucm_block_ids)
             if self.global_rank == 0 or not self.load_only_first_rank:
                 request_to_task[request_id] = self.store.load(
                     ucm_total_block_ids, ucm_offsets, dst_tensor_addr
                 )
             else:
                 request_to_task[request_id] = None
-            req_broadcast_addr[request_id] = dst_tensor_addr
+            req_broadcast_addr[request_id] = (k_tensor_addr, v_tensor_addr)

         for request_id, task in request_to_task.items():
             # TODO error handling
@@ -522,7 +542,16 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
                 )
                 logger.error(f"request {request_id} load kv cache failed.")
             if self.load_only_first_rank:
-                self._broadcast(req_broadcast_addr[request_id])
+                if self.is_mla:
+                    self._broadcast(req_broadcast_addr[request_id][0])
+                else:
+                    for kv_addrs in req_broadcast_addr[request_id]:
+                        self._broadcast(kv_addrs)
+                    if not self.is_dsa:
+                        logger.warning(
+                            "For best performance, do not load only first rank in non-mla models"
+                        )
+
         load_end_time = time.perf_counter() * 1000
         load_speed = (
             num_loaded_block
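Because `_broadcast` packs its argument list with `torch.stack`, broadcasting the K-block and V-block address lists separately keeps each packed tensor uniform in shape and halves its peak size relative to stacking the combined K+V list. The numbers below are purely illustrative, not measured from the connector.

```python
# Purely illustrative sizes: 32 layers, 16 blocks per request, 128 KiB per KV block tensor.
num_layers, num_blocks, block_bytes = 32, 16, 128 * 1024

mixed_peak = 2 * num_layers * num_blocks * block_bytes   # stack K and V blocks together
split_peak = num_layers * num_blocks * block_bytes       # stack K only, then V only

print(f"packed broadcast tensor: mixed={mixed_peak / 2**20:.0f} MiB, "
      f"split={split_peak / 2**20:.0f} MiB")  # 128 MiB vs 64 MiB
```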
@@ -562,7 +591,7 @@ def wait_for_save(self) -> None:
         if self.metrics_config or current_platform.device_type == "npu":
             # When use vllm_ascend, we should add synchronize here, otherwise accuracy problem will raise
             # This has already been fixed in the latest main branch of vllm_ascend, so synchronize will no longer be needed in future versions.
-            self.synchronize()
+            self.torch_dev.synchronize()

         metadata = self._get_connector_metadata()
         assert isinstance(metadata, UCMConnectorMetadata)
@@ -598,7 +627,7 @@ def wait_for_save(self) -> None:
                 continue
             ucm_block_ids = ucm_block_ids[:end]
             vllm_block_ids = vllm_block_ids[:end]
-            ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
+            ucm_total_block_ids, ucm_offsets, dst_tensor_addr, _ = self._generate_task(
                 vllm_block_ids, ucm_block_ids
             )
             request_to_task[request_id] = self.store.dump(