
Commit 7836ab7

[Feat] Support broadcast on ascend
1 parent 523bbc4 commit 7836ab7

File tree

1 file changed (+57, -28 lines)


ucm/integration/vllm/ucm_connector.py

Lines changed: 57 additions & 28 deletions
@@ -99,17 +99,17 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
 
         if current_platform.is_cuda_alike():
             logger.info("CUDA device is available.")
-            torch_dev = torch
+            self.torch_dev = torch.cuda
             dev_name = "cuda"
         elif current_platform.device_type == "npu":
             logger.info("NPU device is available.")
-            torch_dev = torch.npu
+            self.torch_dev = torch.npu
             dev_name = "npu"
         else:
             raise RuntimeError("Unsupported device platform for UCMDirectConnector.")
 
         if self.local_rank >= 0:
-            self.device = torch_dev.device(f"{dev_name}:{self.local_rank}")
+            self.device = torch.device(f"{dev_name}:{self.local_rank}")
         self._layer_offset_cache = {}
 
         self.store: UcmKVStoreBase
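
Note: this hunk replaces the bare torch_dev local with a self.torch_dev handle that points at the backend-specific module (torch.cuda or torch.npu), so later stream and synchronize calls stay device-agnostic, while plain torch.device is enough to name the device itself. A minimal sketch of the same pattern, assuming the optional torch_npu plugin has been imported so that torch.npu exists; the helper name pick_device_module is illustrative, not part of the connector:

import torch

def pick_device_module(local_rank: int):
    # Prefer CUDA; fall back to an Ascend NPU when the torch_npu plugin is loaded.
    if torch.cuda.is_available():
        dev_mod, dev_name = torch.cuda, "cuda"
    elif hasattr(torch, "npu"):
        dev_mod, dev_name = torch.npu, "npu"
    else:
        raise RuntimeError("no supported accelerator found")
    # torch.device covers both backends; the module handle is only needed
    # for Stream/stream/synchronize style calls.
    device = torch.device(f"{dev_name}:{local_rank}")
    return dev_mod, device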
@@ -132,7 +132,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         if role == KVConnectorRole.WORKER:
             self.group_coordinator = get_tp_group()
             self.broadcast_fn = self.group_coordinator.broadcast
-            self.broadcast_stream = torch.cuda.Stream()
+            self.broadcast_stream = self.torch_dev.Stream()
+            self._broadcast_buffer = None
+            self._broadcast_buffer_size = 0
 
         logger.info(f"self.launch_config: {self.launch_config}")
         connector_configs = self.launch_config.get("ucm_connectors", [])
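
The broadcast stream is now created through self.torch_dev, which is what lets the broadcast path run on Ascend as well as CUDA; the change assumes torch_npu mirrors the CUDA stream surface (Stream, stream, synchronize), which the rest of this commit relies on. A rough sketch of that usage, with torch_dev standing in for whichever module was selected above:

def run_on_side_stream(torch_dev, fn, *args):
    # Works for torch.cuda and, assuming API parity, torch.npu.
    stream = torch_dev.Stream()
    with torch_dev.stream(stream):  # route the kernels launched by fn to `stream`
        fn(*args)
    stream.synchronize()            # wait for the side-stream work to finish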
@@ -179,12 +181,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         )
         self.monitor = ucmmonitor.StatsMonitor.get_instance()
 
-        self.synchronize = (
-            torch.cuda.synchronize
-            if current_platform.is_cuda_alike()
-            else torch.npu.synchronize
-        )
-
         # invlalid block ids due to load errors
         self._invalid_block_ids: set[int] = set()
 
@@ -433,7 +429,7 @@ def _get_tensor_and_offset(
             if not self.is_mla:
                 v_tensors.append(kv_layer[1][vllm_block_id])
                 v_offsets.append(v_offset)
-        return k_tensors + v_tensors, k_offsets + v_offsets
+        return (k_tensors, v_tensors), k_offsets + v_offsets
 
     def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[str]):
        if not self._layer_offset_cache:
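
Returning K and V as separate lists (instead of one concatenated list) is presumably what later allows each group to be stacked and broadcast on its own, and lets MLA caches, which carry no V blocks, skip the second group entirely. A toy illustration of the difference, with made-up shapes:

import torch

k_tensors = [torch.zeros(16, 8) for _ in range(4)]  # per-block K cache views
v_tensors = [torch.zeros(16, 8) for _ in range(4)]  # per-block V cache views

# Old result shape: one flat list, K blocks followed by V blocks.
flat = k_tensors + v_tensors
# New result shape: the two groups stay addressable on their own, so each
# can be stacked into a single contiguous tensor for one broadcast.
k_batch = torch.stack(k_tensors, dim=0)  # (4, 16, 8)
v_batch = torch.stack(v_tensors, dim=0)  # (4, 16, 8)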
@@ -443,41 +439,62 @@ def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[str]):
         num_blocks_per_layer = len(vllm_block_ids)
         num_tensors_per_layer = num_blocks_per_layer * (1 if self.is_mla else 2)
         dst_tensor_addr = [None] * (num_layers * num_tensors_per_layer)
+        k_tensor_addr = [None] * num_blocks_per_layer
+        v_tensor_addr = [None] * num_blocks_per_layer
         ucm_offsets = [0] * (num_layers * num_tensors_per_layer)
 
         idx = 0
+        kv_idx = 0
         for layer_name, one_layer_kv_cache in self.kv_caches.items():
-            tensors, offsets = self._get_tensor_and_offset(
+            (k_tensors, v_tensors), offsets = self._get_tensor_and_offset(
                 vllm_block_ids, one_layer_kv_cache, layer_name
             )
+            tensors = k_tensors + v_tensors
+            k_tensor_addr[kv_idx : kv_idx + len(k_tensors)] = k_tensors
+            if v_tensors:
+                v_tensor_addr[kv_idx : kv_idx + len(v_tensors)] = v_tensors
             dst_tensor_addr[idx : idx + len(tensors)] = tensors
             ucm_offsets[idx : idx + len(offsets)] = offsets
             idx += len(tensors)
+            kv_idx += len(k_tensors)
 
         repeat_times = len(self.kv_caches) * (1 if self.is_mla else 2)
         ucm_total_block_ids = ucm_block_ids * repeat_times
 
         assert len(ucm_total_block_ids) == len(ucm_offsets) == len(dst_tensor_addr)
-        return ucm_total_block_ids, ucm_offsets, dst_tensor_addr
+        return (
+            ucm_total_block_ids,
+            ucm_offsets,
+            dst_tensor_addr,
+            (k_tensor_addr, v_tensor_addr),
+        )
+
+    def _ensure_buffer(self, total_numel: int):
+        if self._broadcast_buffer is None or self._broadcast_buffer_size < total_numel:
+            self._broadcast_buffer = torch.empty(
+                total_numel,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+            self._broadcast_buffer_size = total_numel
 
     def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
         rec_tensor: torch.Tensor = None
-        with torch.cuda.stream(self.broadcast_stream):
-            # TODO support broadcast when PP
+        total_numel = len(dst_tensor_addr) * dst_tensor_addr[0].numel()
+        with self.torch_dev.stream(self.broadcast_stream):
             if self.global_rank == 0:
                 tensor_to_broadcast = torch.stack(dst_tensor_addr, dim=0)
                 self.broadcast_fn(tensor_to_broadcast, 0)
             else:
                 shape = (len(dst_tensor_addr),) + dst_tensor_addr[0].shape
-                # TODO create earlier
-                rec_tensor = torch.empty(
-                    shape, dtype=self.kv_cache_dtype, device=self.device
-                )
+                self._ensure_buffer(total_numel)
+                rec_tensor = self._broadcast_buffer[:total_numel].view(shape)
                 self.broadcast_fn(rec_tensor, 0)
         self.broadcast_stream.synchronize()
+
         if self.global_rank != 0 and rec_tensor is not None:
-            for i, tensor in enumerate(dst_tensor_addr):
-                tensor.copy_(rec_tensor[i])
+            rec_tensor_list = list(torch.unbind(rec_tensor, dim=0))
+            torch._foreach_copy_(dst_tensor_addr, rec_tensor_list)
 
     def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
         metadata = self._get_connector_metadata()
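
Taken together, _ensure_buffer and the reworked _broadcast replace the per-call torch.empty on the receiving ranks with a grow-only staging buffer, and replace the Python copy loop with a fused torch._foreach_copy_. A self-contained sketch of the receive side under those assumptions, using torch.distributed.broadcast as a stand-in for the connector's broadcast_fn (process-group setup omitted; the class and function names below are illustrative):

import torch
import torch.distributed as dist

class StagingBuffer:
    """Grow-only flat buffer reused across broadcast calls."""

    def __init__(self, dtype, device):
        self.buf = None
        self.numel = 0
        self.dtype, self.device = dtype, device

    def ensure(self, numel: int) -> torch.Tensor:
        # Only reallocate when the request outgrows the cached buffer.
        if self.buf is None or self.numel < numel:
            self.buf = torch.empty(numel, dtype=self.dtype, device=self.device)
            self.numel = numel
        return self.buf

def receive_into_views(dst_views, staging: StagingBuffer) -> None:
    shape = (len(dst_views),) + dst_views[0].shape
    numel = len(dst_views) * dst_views[0].numel()
    rec = staging.ensure(numel)[:numel].view(shape)
    dist.broadcast(rec, src=0)
    # One fused copy into every destination view instead of a Python loop.
    torch._foreach_copy_(dst_views, list(torch.unbind(rec, dim=0)))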
@@ -502,16 +519,19 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
             if self.global_rank != 0 and not self.is_mla and not self.is_dsa:
                 for i, ucm_block_id in enumerate(ucm_block_ids):
                     ucm_block_ids[i] = str(self.request_hasher(ucm_block_id))
-            ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
-                vllm_block_ids, ucm_block_ids
-            )
+            (
+                ucm_total_block_ids,
+                ucm_offsets,
+                dst_tensor_addr,
+                (k_tensor_addr, v_tensor_addr),
+            ) = self._generate_task(vllm_block_ids, ucm_block_ids)
             if self.global_rank == 0 or not self.load_only_first_rank:
                 request_to_task[request_id] = self.store.load(
                     ucm_total_block_ids, ucm_offsets, dst_tensor_addr
                 )
             else:
                 request_to_task[request_id] = None
-                req_broadcast_addr[request_id] = dst_tensor_addr
+                req_broadcast_addr[request_id] = (k_tensor_addr, v_tensor_addr)
 
         for request_id, task in request_to_task.items():
             # TODO error handling
@@ -522,7 +542,16 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
                 )
                 logger.error(f"request {request_id} load kv cache failed.")
             if self.load_only_first_rank:
-                self._broadcast(req_broadcast_addr[request_id])
+                if self.is_mla:
+                    self._broadcast(req_broadcast_addr[request_id][0])
+                else:
+                    for kv_addrs in req_broadcast_addr[request_id]:
+                        self._broadcast(kv_addrs)
+                    if not self.is_dsa:
+                        logger.warning(
+                            "For best performance, do not load only first rank in non-mla models"
+                        )
+
         load_end_time = time.perf_counter() * 1000
         load_speed = (
             num_loaded_block
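
With the task now carrying (k_tensor_addr, v_tensor_addr), the broadcast on non-zero ranks runs once for MLA caches (which keep no separate V blocks) and once per group otherwise. A hedged sketch of that dispatch, with broadcast_group standing in for the connector's _broadcast:

def broadcast_loaded_kv(kv_addrs, is_mla: bool, broadcast_group) -> None:
    k_addrs, v_addrs = kv_addrs
    if is_mla:
        # MLA keeps a single latent cache per block, so only one group exists.
        broadcast_group(k_addrs)
    else:
        # Two broadcasts: one for the K views, one for the V views.
        broadcast_group(k_addrs)
        broadcast_group(v_addrs)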
@@ -562,7 +591,7 @@ def wait_for_save(self) -> None:
         if self.metrics_config or current_platform.device_type == "npu":
             # When use vllm_ascend, we should add synchronize here, otherwise accuracy problem will raise
             # This has already been fixed in the latest main branch of vllm_ascend, so synchronize will no longer be needed in future versions.
-            self.synchronize()
+            self.torch_dev.synchronize()
 
         metadata = self._get_connector_metadata()
         assert isinstance(metadata, UCMConnectorMetadata)
@@ -598,7 +627,7 @@ def wait_for_save(self) -> None:
                 continue
             ucm_block_ids = ucm_block_ids[:end]
             vllm_block_ids = vllm_block_ids[:end]
-            ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
+            ucm_total_block_ids, ucm_offsets, dst_tensor_addr, _ = self._generate_task(
                 vllm_block_ids, ucm_block_ids
             )
            request_to_task[request_id] = self.store.dump(