diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py
index 66216a255..c7755a59f 100644
--- a/ucm/integration/vllm/ucm_connector.py
+++ b/ucm/integration/vllm/ucm_connector.py
@@ -99,17 +99,17 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         if current_platform.is_cuda_alike():
             logger.info("CUDA device is available.")
-            torch_dev = torch
+            self.torch_dev = torch.cuda
            dev_name = "cuda"
         elif current_platform.device_type == "npu":
             logger.info("NPU device is available.")
-            torch_dev = torch.npu
+            self.torch_dev = torch.npu
             dev_name = "npu"
         else:
             raise RuntimeError("Unsupported device platform for UCMDirectConnector.")

         if self.local_rank >= 0:
-            self.device = torch_dev.device(f"{dev_name}:{self.local_rank}")
+            self.device = torch.device(f"{dev_name}:{self.local_rank}")

         self._layer_offset_cache = {}
         self.store: UcmKVStoreBase
@@ -132,7 +132,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         if role == KVConnectorRole.WORKER:
             self.group_coordinator = get_tp_group()
             self.broadcast_fn = self.group_coordinator.broadcast
-            self.broadcast_stream = torch.cuda.Stream()
+            self.broadcast_stream = self.torch_dev.Stream()
+            self._broadcast_buffer = None
+            self._broadcast_buffer_size = 0

         logger.info(f"self.launch_config: {self.launch_config}")
         connector_configs = self.launch_config.get("ucm_connectors", [])
@@ -179,12 +181,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
             )
             self.monitor = ucmmonitor.StatsMonitor.get_instance()

-        self.synchronize = (
-            torch.cuda.synchronize
-            if current_platform.is_cuda_alike()
-            else torch.npu.synchronize
-        )
-
         # invlalid block ids due to load errors
         self._invalid_block_ids: set[int] = set()
@@ -433,7 +429,7 @@ def _get_tensor_and_offset(
             if not self.is_mla:
                 v_tensors.append(kv_layer[1][vllm_block_id])
                 v_offsets.append(v_offset)
-        return k_tensors + v_tensors, k_offsets + v_offsets
+        return (k_tensors, v_tensors), k_offsets + v_offsets

     def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[str]):
         if not self._layer_offset_cache:
@@ -443,41 +439,62 @@ def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[str]):
         num_blocks_per_layer = len(vllm_block_ids)
         num_tensors_per_layer = num_blocks_per_layer * (1 if self.is_mla else 2)
         dst_tensor_addr = [None] * (num_layers * num_tensors_per_layer)
+        k_tensor_addr = [None] * num_blocks_per_layer
+        v_tensor_addr = [None] * num_blocks_per_layer
         ucm_offsets = [0] * (num_layers * num_tensors_per_layer)
         idx = 0
+        kv_idx = 0
         for layer_name, one_layer_kv_cache in self.kv_caches.items():
-            tensors, offsets = self._get_tensor_and_offset(
+            (k_tensors, v_tensors), offsets = self._get_tensor_and_offset(
                 vllm_block_ids, one_layer_kv_cache, layer_name
             )
+            tensors = k_tensors + v_tensors
+            k_tensor_addr[kv_idx : kv_idx + len(k_tensors)] = k_tensors
+            if v_tensors:
+                v_tensor_addr[kv_idx : kv_idx + len(v_tensors)] = v_tensors
             dst_tensor_addr[idx : idx + len(tensors)] = tensors
             ucm_offsets[idx : idx + len(offsets)] = offsets
             idx += len(tensors)
+            kv_idx += len(k_tensors)
         repeat_times = len(self.kv_caches) * (1 if self.is_mla else 2)
         ucm_total_block_ids = ucm_block_ids * repeat_times
         assert len(ucm_total_block_ids) == len(ucm_offsets) == len(dst_tensor_addr)
-        return ucm_total_block_ids, ucm_offsets, dst_tensor_addr
+        return (
+            ucm_total_block_ids,
+            ucm_offsets,
+            dst_tensor_addr,
+            (k_tensor_addr, v_tensor_addr),
+        )
+
+    def _ensure_buffer(self, total_numel: int):
+        if self._broadcast_buffer is None or self._broadcast_buffer_size < total_numel:
+            self._broadcast_buffer = torch.empty(
+                total_numel,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+            self._broadcast_buffer_size = total_numel

     def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
         rec_tensor: torch.Tensor = None
-        with torch.cuda.stream(self.broadcast_stream):
-            # TODO support broadcast when PP
+        total_numel = len(dst_tensor_addr) * dst_tensor_addr[0].numel()
+        with self.torch_dev.stream(self.broadcast_stream):
             if self.global_rank == 0:
                 tensor_to_broadcast = torch.stack(dst_tensor_addr, dim=0)
                 self.broadcast_fn(tensor_to_broadcast, 0)
             else:
                 shape = (len(dst_tensor_addr),) + dst_tensor_addr[0].shape
-                # TODO create earlier
-                rec_tensor = torch.empty(
-                    shape, dtype=self.kv_cache_dtype, device=self.device
-                )
+                self._ensure_buffer(total_numel)
+                rec_tensor = self._broadcast_buffer[:total_numel].view(shape)
                 self.broadcast_fn(rec_tensor, 0)
         self.broadcast_stream.synchronize()
+
         if self.global_rank != 0 and rec_tensor is not None:
-            for i, tensor in enumerate(dst_tensor_addr):
-                tensor.copy_(rec_tensor[i])
+            rec_tensor_list = list(torch.unbind(rec_tensor, dim=0))
+            torch._foreach_copy_(dst_tensor_addr, rec_tensor_list)

     def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
         metadata = self._get_connector_metadata()
@@ -502,16 +519,19 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
             if self.global_rank != 0 and not self.is_mla and not self.is_dsa:
                 for i, ucm_block_id in enumerate(ucm_block_ids):
                     ucm_block_ids[i] = str(self.request_hasher(ucm_block_id))
-            ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
-                vllm_block_ids, ucm_block_ids
-            )
+            (
+                ucm_total_block_ids,
+                ucm_offsets,
+                dst_tensor_addr,
+                (k_tensor_addr, v_tensor_addr),
+            ) = self._generate_task(vllm_block_ids, ucm_block_ids)
             if self.global_rank == 0 or not self.load_only_first_rank:
                 request_to_task[request_id] = self.store.load(
                     ucm_total_block_ids, ucm_offsets, dst_tensor_addr
                 )
             else:
                 request_to_task[request_id] = None
-                req_broadcast_addr[request_id] = dst_tensor_addr
+                req_broadcast_addr[request_id] = (k_tensor_addr, v_tensor_addr)

         for request_id, task in request_to_task.items():
             # TODO error handling
@@ -522,7 +542,16 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
                 )
                 logger.error(f"request {request_id} load kv cache failed.")
             if self.load_only_first_rank:
-                self._broadcast(req_broadcast_addr[request_id])
+                if self.is_mla:
+                    self._broadcast(req_broadcast_addr[request_id][0])
+                else:
+                    for kv_addrs in req_broadcast_addr[request_id]:
+                        self._broadcast(kv_addrs)
+                    if not self.is_dsa:
+                        logger.warning(
+                            "For best performance, do not load only first rank in non-MLA models"
+                        )
+
         load_end_time = time.perf_counter() * 1000
         load_speed = (
             num_loaded_block
@@ -562,7 +591,7 @@ def wait_for_save(self) -> None:
         if self.metrics_config or current_platform.device_type == "npu":
             # When use vllm_ascend, we should add synchronize here, otherwise accuracy problem will raise
             # This has already been fixed in the latest main branch of vllm_ascend, so synchronize will no longer be needed in future versions.
-            self.synchronize()
+            self.torch_dev.synchronize()

         metadata = self._get_connector_metadata()
         assert isinstance(metadata, UCMConnectorMetadata)
@@ -598,7 +627,7 @@ def wait_for_save(self) -> None:
                 continue
             ucm_block_ids = ucm_block_ids[:end]
             vllm_block_ids = vllm_block_ids[:end]
-            ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
+            ucm_total_block_ids, ucm_offsets, dst_tensor_addr, _ = self._generate_task(
                 vllm_block_ids, ucm_block_ids
             )
             request_to_task[request_id] = self.store.dump(
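
Note: on the receiving ranks, _ensure_buffer plus torch._foreach_copy_ amounts to a grow-only staging buffer and one batched scatter back into the per-block KV tensors, replacing a fresh torch.empty allocation and a Python loop of copy_ calls on every load. Below is a minimal, self-contained sketch of that pattern, assuming an already initialized torch.distributed process group instead of vLLM's TP group coordinator; ensure_buffer and broadcast_into are illustrative names, not part of the UCM or vLLM API.

    import torch
    import torch.distributed as dist

    _buffer = None  # grow-only staging buffer, reused across calls
    _buffer_numel = 0

    def ensure_buffer(total_numel, dtype, device):
        # Reallocate only when a larger request arrives, so steady-state
        # decode steps reuse the same device allocation every time.
        global _buffer, _buffer_numel
        if _buffer is None or _buffer_numel < total_numel:
            _buffer = torch.empty(total_numel, dtype=dtype, device=device)
            _buffer_numel = total_numel
        return _buffer

    def broadcast_into(dst_tensors, src_rank=0):
        # All tensors in dst_tensors must share one shape and dtype for the
        # stack/view below to be valid; broadcasting the K list and the V list
        # separately, as the patch does for non-MLA models, guarantees this
        # per call.
        total_numel = len(dst_tensors) * dst_tensors[0].numel()
        if dist.get_rank() == src_rank:
            dist.broadcast(torch.stack(dst_tensors, dim=0), src=src_rank)
        else:
            shape = (len(dst_tensors),) + dst_tensors[0].shape
            buf = ensure_buffer(
                total_numel, dst_tensors[0].dtype, dst_tensors[0].device
            )
            rec = buf[:total_numel].view(shape)
            dist.broadcast(rec, src=src_rank)
            # One fused foreach kernel instead of N separate copy_ launches.
            torch._foreach_copy_(dst_tensors, list(torch.unbind(rec, dim=0)))

The torch.stack on the source rank still allocates a fresh tensor per call; the same staging trick could be applied there too, but this patch only changes the receive path.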