85 changes: 57 additions & 28 deletions ucm/integration/vllm/ucm_connector.py
@@ -99,17 +99,17 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):

if current_platform.is_cuda_alike():
logger.info("CUDA device is available.")
torch_dev = torch
self.torch_dev = torch.cuda
dev_name = "cuda"
elif current_platform.device_type == "npu":
logger.info("NPU device is available.")
torch_dev = torch.npu
self.torch_dev = torch.npu
dev_name = "npu"
else:
raise RuntimeError("Unsupported device platform for UCMDirectConnector.")

if self.local_rank >= 0:
self.device = torch_dev.device(f"{dev_name}:{self.local_rank}")
self.device = torch.device(f"{dev_name}:{self.local_rank}")
self._layer_offset_cache = {}

self.store: UcmKVStoreBase
@@ -132,7 +132,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
if role == KVConnectorRole.WORKER:
self.group_coordinator = get_tp_group()
self.broadcast_fn = self.group_coordinator.broadcast
self.broadcast_stream = torch.cuda.Stream()
self.broadcast_stream = self.torch_dev.Stream()
self._broadcast_buffer = None
self._broadcast_buffer_size = 0

logger.info(f"self.launch_config: {self.launch_config}")
connector_configs = self.launch_config.get("ucm_connectors", [])
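A minimal standalone sketch of the device-dispatch pattern these changes converge on: bind self.torch_dev once to either torch.cuda or torch.npu, then route Stream(), stream() and synchronize() through it instead of branching per call site. The NPU branch assumes torch_npu is installed (which registers torch.npu); the tensor size and rank-0 device below are illustrative.

import torch

use_cuda = torch.cuda.is_available()
torch_dev = torch.cuda if use_cuda else torch.npu   # backend module, bound once
dev_name = "cuda" if use_cuda else "npu"

device = torch.device(f"{dev_name}:0")              # torch.device parses "cuda:N" and "npu:N" alike
side_stream = torch_dev.Stream()                    # replaces the hard-coded torch.cuda.Stream()

with torch_dev.stream(side_stream):                 # enqueue async work on the side stream
    x = torch.ones(4, device=device) * 2
side_stream.synchronize()                           # wait for the side stream only
torch_dev.synchronize()                             # device-wide sync, as in wait_for_save
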
@@ -179,12 +181,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
)
self.monitor = ucmmonitor.StatsMonitor.get_instance()

self.synchronize = (
torch.cuda.synchronize
if current_platform.is_cuda_alike()
else torch.npu.synchronize
)

# invalid block ids due to load errors
self._invalid_block_ids: set[int] = set()

@@ -433,7 +429,7 @@ def _get_tensor_and_offset(
if not self.is_mla:
v_tensors.append(kv_layer[1][vllm_block_id])
v_offsets.append(v_offset)
return k_tensors + v_tensors, k_offsets + v_offsets
return (k_tensors, v_tensors), k_offsets + v_offsets

def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[str]):
if not self._layer_offset_cache:
@@ -443,41 +439,62 @@ def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[str]):
num_blocks_per_layer = len(vllm_block_ids)
num_tensors_per_layer = num_blocks_per_layer * (1 if self.is_mla else 2)
dst_tensor_addr = [None] * (num_layers * num_tensors_per_layer)
k_tensor_addr = [None] * num_blocks_per_layer
v_tensor_addr = [None] * num_blocks_per_layer
ucm_offsets = [0] * (num_layers * num_tensors_per_layer)

idx = 0
kv_idx = 0
for layer_name, one_layer_kv_cache in self.kv_caches.items():
tensors, offsets = self._get_tensor_and_offset(
(k_tensors, v_tensors), offsets = self._get_tensor_and_offset(
vllm_block_ids, one_layer_kv_cache, layer_name
)
tensors = k_tensors + v_tensors
k_tensor_addr[kv_idx : kv_idx + len(k_tensors)] = k_tensors
if v_tensors:
v_tensor_addr[kv_idx : kv_idx + len(v_tensors)] = v_tensors
dst_tensor_addr[idx : idx + len(tensors)] = tensors
ucm_offsets[idx : idx + len(offsets)] = offsets
idx += len(tensors)
kv_idx += len(k_tensors)

repeat_times = len(self.kv_caches) * (1 if self.is_mla else 2)
ucm_total_block_ids = ucm_block_ids * repeat_times

assert len(ucm_total_block_ids) == len(ucm_offsets) == len(dst_tensor_addr)
return ucm_total_block_ids, ucm_offsets, dst_tensor_addr
return (
ucm_total_block_ids,
ucm_offsets,
dst_tensor_addr,
(k_tensor_addr, v_tensor_addr),
)
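
As a worked example of the sizes this return implies (illustrative numbers, not from the PR), take two layers, three vLLM blocks and a non-MLA model:

num_layers, num_blocks, kv_factor = 2, 3, 2     # kv_factor = 1 for MLA
len_dst = num_layers * num_blocks * kv_factor   # 12 views in dst_tensor_addr, K blocks then V blocks per layer
len_ids = num_blocks * num_layers * kv_factor   # ucm_block_ids repeated 4x -> 12 entries, matching the assert above
len_k = num_layers * num_blocks                 # k_tensor_addr: 6 K views across all layers
len_v = num_layers * num_blocks                 # v_tensor_addr: 6 V views (left as None placeholders for MLA)
assert len_dst == len_ids == 12 and len_k == len_v == 6
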

def _ensure_buffer(self, total_numel: int):
if self._broadcast_buffer is None or self._broadcast_buffer_size < total_numel:
self._broadcast_buffer = torch.empty(
total_numel,
dtype=self.kv_cache_dtype,
device=self.device,
)
self._broadcast_buffer_size = total_numel
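
A CPU-only sketch of the grow-only buffer behaviour above (the connector allocates in the KV-cache dtype on the worker's device; plain float16 on CPU is used here for brevity):

import torch

buf, buf_size = None, 0

def ensure(total_numel: int) -> torch.Tensor:
    global buf, buf_size
    if buf is None or buf_size < total_numel:   # reallocate only when the request grows
        buf = torch.empty(total_numel, dtype=torch.float16)
        buf_size = total_numel
    return buf[:total_numel]                    # a view into the cached buffer, no copy

a = ensure(8)                                   # first call allocates 8 elements
b = ensure(4)                                   # smaller request reuses the same storage
assert a.data_ptr() == b.data_ptr()
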

def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
rec_tensor: torch.Tensor = None
with torch.cuda.stream(self.broadcast_stream):
# TODO support broadcast when PP
total_numel = len(dst_tensor_addr) * dst_tensor_addr[0].numel()
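# Assumes every entry of dst_tensor_addr has the same numel (torch.stack / the buffer view below require matching shapes).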
with self.torch_dev.stream(self.broadcast_stream):
if self.global_rank == 0:
tensor_to_broadcast = torch.stack(dst_tensor_addr, dim=0)
self.broadcast_fn(tensor_to_broadcast, 0)
else:
shape = (len(dst_tensor_addr),) + dst_tensor_addr[0].shape
# TODO create earlier
rec_tensor = torch.empty(
shape, dtype=self.kv_cache_dtype, device=self.device
)
self._ensure_buffer(total_numel)
rec_tensor = self._broadcast_buffer[:total_numel].view(shape)
self.broadcast_fn(rec_tensor, 0)
self.broadcast_stream.synchronize()

if self.global_rank != 0 and rec_tensor is not None:
for i, tensor in enumerate(dst_tensor_addr):
tensor.copy_(rec_tensor[i])
rec_tensor_list = list(torch.unbind(rec_tensor, dim=0))
torch._foreach_copy_(dst_tensor_addr, rec_tensor_list)
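
For reference, the batched copy is semantically the same as the per-tensor loop it replaces; torch._foreach_copy_ (available in recent PyTorch) issues the whole list of copies as one fused op group. A small CPU sketch, not from the PR:

import torch

dst = [torch.zeros(4) for _ in range(3)]                     # stand-ins for the KV block views
src = torch.arange(12, dtype=torch.float32).view(3, 4)

for i, t in enumerate(dst):                                  # loop form (previous behaviour)
    t.copy_(src[i])

torch._foreach_copy_(dst, list(torch.unbind(src, dim=0)))    # batched form (new behaviour)
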

def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
metadata = self._get_connector_metadata()
@@ -502,16 +519,19 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
if self.global_rank != 0 and not self.is_mla and not self.is_dsa:
for i, ucm_block_id in enumerate(ucm_block_ids):
ucm_block_ids[i] = str(self.request_hasher(ucm_block_id))
ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
vllm_block_ids, ucm_block_ids
)
(
ucm_total_block_ids,
ucm_offsets,
dst_tensor_addr,
(k_tensor_addr, v_tensor_addr),
) = self._generate_task(vllm_block_ids, ucm_block_ids)
if self.global_rank == 0 or not self.load_only_first_rank:
request_to_task[request_id] = self.store.load(
ucm_total_block_ids, ucm_offsets, dst_tensor_addr
)
else:
request_to_task[request_id] = None
req_broadcast_addr[request_id] = dst_tensor_addr
req_broadcast_addr[request_id] = (k_tensor_addr, v_tensor_addr)

for request_id, task in request_to_task.items():
# TODO error handling
@@ -522,7 +542,16 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
)
logger.error(f"request {request_id} load kv cache failed.")
if self.load_only_first_rank:
self._broadcast(req_broadcast_addr[request_id])
if self.is_mla:
self._broadcast(req_broadcast_addr[request_id][0])
else:
for kv_addrs in req_broadcast_addr[request_id]:
self._broadcast(kv_addrs)
if not self.is_dsa:
logger.warning(
"For best performance, do not load only first rank in non-mla models"
)

load_end_time = time.perf_counter() * 1000
load_speed = (
num_loaded_block
@@ -562,7 +591,7 @@ def wait_for_save(self) -> None:
if self.metrics_config or current_platform.device_type == "npu":
# When using vllm_ascend we need to synchronize here, otherwise an accuracy problem will arise.
# This has already been fixed in the latest main branch of vllm_ascend, so synchronize will no longer be needed in future versions.
self.synchronize()
self.torch_dev.synchronize()

metadata = self._get_connector_metadata()
assert isinstance(metadata, UCMConnectorMetadata)
@@ -598,7 +627,7 @@ def wait_for_save(self) -> None:
continue
ucm_block_ids = ucm_block_ids[:end]
vllm_block_ids = vllm_block_ids[:end]
ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
ucm_total_block_ids, ucm_offsets, dst_tensor_addr, _ = self._generate_task(
vllm_block_ids, ucm_block_ids
)
request_to_task[request_id] = self.store.dump(