
Commit 7c5142e

adapt store v1
1 parent 523bbc4 commit 7c5142e

9 files changed: 364 additions & 191 deletions


ucm/integration/vllm/ucm_connector.py

Lines changed: 49 additions & 99 deletions
@@ -4,7 +4,7 @@
 import pickle
 import time
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
 
 import torch
 from vllm.config import VllmConfig
@@ -20,8 +20,8 @@
 from ucm.logger import init_logger
 from ucm.shared.metrics import ucmmonitor
 from ucm.shared.metrics.observability import UCMStatsLogger
-from ucm.store.factory import UcmConnectorFactory
-from ucm.store.ucmstore import Task, UcmKVStoreBase
+from ucm.store.factory_v1 import UcmConnectorFactoryV1
+from ucm.store.ucmstore_v1 import Task, UcmKVStoreBaseV1
 from ucm.utils import Config
 
 if TYPE_CHECKING:
@@ -35,7 +35,7 @@
 
 @dataclass
 class RequestMeta:
-    ucm_block_ids: list[str] = field(default_factory=list)
+    ucm_block_ids: list[bytes] = field(default_factory=list)
     hbm_hit_block_num: int = 0
     # local_computed_block + external_computed_block
     total_hit_block_num: int = 0
@@ -47,9 +47,9 @@ class RequestMeta:
 @dataclass
 class RequestDispatchMeta:
     load_block_ids: tuple[
-        list[str], list[int]
+        list[bytes], list[int]
     ]  # [0] mean ucm_block_ids, [1] means vllm_block_ids
-    dump_block_ids: tuple[list[str], list[int]]
+    dump_block_ids: tuple[list[bytes], list[int]]
 
 
 @dataclass
@@ -69,14 +69,14 @@ def __init__(self, vllm_config, rank_id):
         if RequestHasher._SEED_HASH is None:
             RequestHasher._SEED_HASH = self("UCM_HASH_SEED")
 
-    def __call__(self, input_data) -> int:
-        if isinstance(input_data, str):
-            input_bytes = input_data.encode("utf-8")
+    def __call__(self, input_data) -> bytes:
+        if isinstance(input_data, bytes):
+            input_bytes = input_data
         else:
            input_bytes = pickle.dumps(input_data, protocol=pickle.HIGHEST_PROTOCOL)
 
         h = hashlib.md5(self.meta_bytes + input_bytes)
-        return int.from_bytes(h.digest(), byteorder="big")
+        return h.digest()
 
 
 class UCMDirectConnector(KVConnectorBase_V1):
@@ -110,9 +110,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
 
         if self.local_rank >= 0:
             self.device = torch_dev.device(f"{dev_name}:{self.local_rank}")
-            self._layer_offset_cache = {}
 
-        self.store: UcmKVStoreBase
+        self.store: UcmKVStoreBaseV1
 
         if role == KVConnectorRole.SCHEDULER:
             self.request_hasher = RequestHasher(vllm_config, 0)
@@ -160,7 +159,7 @@
            config["io_size"] = block_size_per_layer * (
                1 if self.is_mla else num_head_per_tp
            )
-            self.store = UcmConnectorFactory.create_connector(name, config)
+            self.store = UcmConnectorFactoryV1.create_connector(name, config)
            self.block_data_size = config["kv_block_size"]
 
            logger.info("init UCConnectorImpl, connector: %s", name)
@@ -188,7 +187,7 @@
         # invlalid block ids due to load errors
         self._invalid_block_ids: set[int] = set()
 
-    def generate_hash(self, block_size: int, request: "Request") -> list[str]:
+    def generate_hash(self, block_size: int, request: "Request") -> list[bytes]:
         token_ids = request.all_token_ids
 
         ret = []
@@ -205,7 +204,7 @@ def generate_hash(self, block_size: int, request: "Request") -> list[str]:
                (parent_block_hash_value, block_token_ids_tuple)
            )
            parent_block_hash_value = hash_value
-            ret.append(str(hash_value))
+            ret.append(hash_value)
 
         return ret
 
@@ -395,70 +394,37 @@ def _extract_layer_index(layer_name: str) -> Optional[int]:
                 return int(chunk)
         return None
 
-    def _precompute_layer_offsets(self):
-        if not self.kv_caches:
-            return
-
-        sample_kv_layer = next(iter(self.kv_caches.values()))
-        elem_size = sample_kv_layer[0].element_size()
-        block_data_size = (
-            sample_kv_layer[0].numel() if self.is_mla else sample_kv_layer[0][0].numel()
-        ) * elem_size
-        layer_data_size = block_data_size if self.is_mla else block_data_size * 2
-
-        # precompute all layers offset
-        for layer_name, _ in self.kv_caches.items():
-            layer_id = self._extract_layer_index(layer_name)
-            assert layer_id is not None
-            k_offset = layer_data_size * layer_id
-            v_offset = k_offset + block_data_size if not self.is_mla else 0
-            self._layer_offset_cache[layer_name] = (k_offset, v_offset)
-
-    def _get_tensor_and_offset(
-        self, vllm_block_ids: list[int], kv_layer: torch.Tensor, layer_name: str
-    ) -> tuple[list[torch.Tensor], list[int]]:
+    def _get_tensors(
+        self, vllm_block_id: int
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         """
         GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
         MLA: one layer shape is (num_blocks, block_size, head_size)
         """
-        k_tensors, k_offsets = [], []
-        v_tensors, v_offsets = [], []
-        k_offset, v_offset = self._layer_offset_cache[layer_name]
-
-        for vllm_block_id in vllm_block_ids:
+        k_tensors, v_tensors = [], []
+        for _, kv_layer in self.kv_caches.items():
            k_tensors.append(
                kv_layer[vllm_block_id] if self.is_mla else kv_layer[0][vllm_block_id]
            )
-            k_offsets.append(k_offset)
            if not self.is_mla:
                v_tensors.append(kv_layer[1][vllm_block_id])
-            v_offsets.append(v_offset)
-        return k_tensors + v_tensors, k_offsets + v_offsets
-
-    def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[str]):
-        if not self._layer_offset_cache:
-            self._precompute_layer_offsets()
-
-        num_layers = len(self.kv_caches)
-        num_blocks_per_layer = len(vllm_block_ids)
-        num_tensors_per_layer = num_blocks_per_layer * (1 if self.is_mla else 2)
-        dst_tensor_addr = [None] * (num_layers * num_tensors_per_layer)
-        ucm_offsets = [0] * (num_layers * num_tensors_per_layer)
-
-        idx = 0
-        for layer_name, one_layer_kv_cache in self.kv_caches.items():
-            tensors, offsets = self._get_tensor_and_offset(
-                vllm_block_ids, one_layer_kv_cache, layer_name
-            )
-            dst_tensor_addr[idx : idx + len(tensors)] = tensors
-            ucm_offsets[idx : idx + len(offsets)] = offsets
-            idx += len(tensors)
+        return k_tensors, v_tensors
 
-        repeat_times = len(self.kv_caches) * (1 if self.is_mla else 2)
-        ucm_total_block_ids = ucm_block_ids * repeat_times
+    def _generate_task(
+        self, vllm_block_ids: List[int], ucm_block_ids: List[bytes]
+    ) -> Tuple[List[bytes], List[int], List[List[torch.Tensor]]]:
+        """
+        GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
+        MLA: one layer shape is (num_blocks, block_size, head_size)
+        """
+        block_ids, shard_indexs, tensors = [], [], []
+        for i, vllm_block_id in enumerate(vllm_block_ids):
+            k_tensors, v_tensors = self._get_tensors(vllm_block_id)
+            block_ids.append(ucm_block_ids[i])
+            tensors.append(k_tensors + v_tensors)
+            shard_indexs.append(0)
 
-        assert len(ucm_total_block_ids) == len(ucm_offsets) == len(dst_tensor_addr)
-        return ucm_total_block_ids, ucm_offsets, dst_tensor_addr
+        return block_ids, shard_indexs, tensors
 
     def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
         rec_tensor: torch.Tensor = None
@@ -501,26 +467,28 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
            ucm_block_ids, vllm_block_ids = request.load_block_ids
            if self.global_rank != 0 and not self.is_mla and not self.is_dsa:
                for i, ucm_block_id in enumerate(ucm_block_ids):
-                    ucm_block_ids[i] = str(self.request_hasher(ucm_block_id))
-            ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
+                    ucm_block_ids[i] = self.request_hasher(ucm_block_id)
+            block_ids, shard_indexs, tensors = self._generate_task(
                vllm_block_ids, ucm_block_ids
            )
            if self.global_rank == 0 or not self.load_only_first_rank:
                request_to_task[request_id] = self.store.load(
-                    ucm_total_block_ids, ucm_offsets, dst_tensor_addr
+                    block_ids, shard_indexs, tensors
                )
            else:
                request_to_task[request_id] = None
-            req_broadcast_addr[request_id] = dst_tensor_addr
+            req_broadcast_addr[request_id] = [t for row in tensors for t in row]
 
        for request_id, task in request_to_task.items():
            # TODO error handling
            if self.global_rank == 0 or not self.load_only_first_rank:
-                if self.store.wait(task) != 0:
+                try:
+                    self.store.wait(task)
+                except RuntimeError as e:
+                    logger.error(f"request {request_id} load kv cache failed: {e}")
                    self._invalid_block_ids.update(
                        metadata.request_meta[request_id].load_block_ids[1]
                    )
-                    logger.error(f"request {request_id} load kv cache failed.")
            if self.load_only_first_rank:
                self._broadcast(req_broadcast_addr[request_id])
        load_end_time = time.perf_counter() * 1000
@@ -568,7 +536,6 @@ def wait_for_save(self) -> None:
        assert isinstance(metadata, UCMConnectorMetadata)
 
        request_to_task: dict[str, Task] = {}
-        request_to_blocks: dict[str, list[str]] = {}
        is_save = False
        num_saved_block = 0
        num_saved_request = 0
@@ -583,36 +550,19 @@ def wait_for_save(self) -> None:
            ucm_block_ids, vllm_block_ids = request.dump_block_ids
            if self.global_rank != 0:
                for i, ucm_block_id in enumerate(ucm_block_ids):
-                    ucm_block_ids[i] = str(self.request_hasher(ucm_block_id))
-            rets = self.store.create(ucm_block_ids)
-            end = 0
-            for i, ret in enumerate(rets):
-                if ret != 0:
-                    logger.error(
-                        f"create blocks for {request_id} failed, block index: {i}, ret code: {ret}"
-                    )
-                    break
-                end += 1
-
-            if end == 0:
-                continue
-            ucm_block_ids = ucm_block_ids[:end]
-            vllm_block_ids = vllm_block_ids[:end]
-            ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
+                    ucm_block_ids[i] = self.request_hasher(ucm_block_id)
+            block_ids, shard_indexs, tensors = self._generate_task(
                vllm_block_ids, ucm_block_ids
            )
            request_to_task[request_id] = self.store.dump(
-                ucm_total_block_ids, ucm_offsets, dst_tensor_addr
+                block_ids, shard_indexs, tensors
            )
-            request_to_blocks[request_id] = ucm_block_ids
 
        for request_id, task in request_to_task.items():
-            ucm_block_ids = request_to_blocks[request_id]
-            if self.store.wait(task) == 0:
-                self.store.commit(ucm_block_ids, True)
-            else:
-                logger.error(f"request {request_id} dump kv cache failed.")
-                self.store.commit(ucm_block_ids, False)
+            try:
+                self.store.wait(task)
+            except RuntimeError as e:
+                logger.error(f"request {request_id} dump kv cache failed: {e}")
        save_end_time = time.perf_counter() * 1000
        save_speed = (
            num_saved_block
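
The hunks above switch the connector to the v1 store calling convention: per-block byte IDs from RequestHasher, a shard index per block, and the block's K/V tensors grouped together, with store.wait() raising RuntimeError on failure instead of returning a status code. A minimal sketch of that convention, using a hypothetical DummyStoreV1 stand-in rather than the real UcmKVStoreBaseV1 interface:

```python
import hashlib
from typing import List

import torch


class DummyStoreV1:
    """Illustrative stand-in for a v1 store; not the real UcmKVStoreBaseV1."""

    def dump(
        self,
        block_ids: List[bytes],
        shard_indexs: List[int],
        tensors: List[List[torch.Tensor]],
    ) -> dict:
        # One entry per block: a bytes id, a shard index, and that block's K/V tensors.
        assert len(block_ids) == len(shard_indexs) == len(tensors)
        return {"num_blocks": len(block_ids)}  # opaque task handle

    def wait(self, task: dict) -> None:
        # v1 semantics: failures surface as exceptions, not status codes.
        if task["num_blocks"] == 0:
            raise RuntimeError("empty dump task")


# Two layers of fake K and V data for one vLLM block (4 tensors total).
block_tensors = [torch.zeros(16, 8, 64) for _ in range(4)]
block_id = hashlib.md5(b"seed" + b"block-0-tokens").digest()  # bytes, no longer str(int)

store = DummyStoreV1()
task = store.dump([block_id], [0], [block_tensors])
try:
    store.wait(task)
except RuntimeError as exc:
    print(f"dump failed: {exc}")
```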

ucm/store/factory.py

Lines changed: 0 additions & 3 deletions
@@ -63,9 +63,6 @@ def create_connector(cls, connector_name: str, config: dict) -> UcmKVStoreBase:
 UcmConnectorFactory.register_connector(
     "UcmNfsStore", "ucm.store.nfsstore.nfsstore_connector", "UcmNfsStore"
 )
-UcmConnectorFactory.register_connector(
-    "UcmPcStore", "ucm.store.pcstore.pcstore_connector", "UcmPcStore"
-)
 UcmConnectorFactory.register_connector(
     "UcmMooncakeStore",
     "ucm.store.mooncakestore.mooncake_connector",

ucm/store/factory_v1.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+#
+# MIT License
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+
+import importlib
+from typing import Callable
+
+from ucm.logger import init_logger
+from ucm.store.ucmstore_v1 import UcmKVStoreBaseV1
+
+logger = init_logger(__name__)
+
+
+class UcmConnectorFactoryV1:
+    _registry: dict[str, Callable[[], type[UcmKVStoreBaseV1]]] = {}
+
+    @classmethod
+    def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
+        """Register a connector with a lazy-loading module and class name."""
+        if name in cls._registry:
+            raise ValueError(f"Connector '{name}' is already registered.")
+
+        def loader() -> type[UcmKVStoreBaseV1]:
+            module = importlib.import_module(module_path)
+            return getattr(module, class_name)
+
+        cls._registry[name] = loader
+
+    @classmethod
+    def create_connector(cls, connector_name: str, config: dict) -> UcmKVStoreBaseV1:
+        if connector_name in cls._registry:
+            connector_cls = cls._registry[connector_name]()
+        else:
+            raise ValueError(f"Unsupported connector type: {connector_name}")
+        assert issubclass(connector_cls, UcmKVStoreBaseV1)
+        logger.info("Creating connector with name: %s", connector_name)
+        return connector_cls(config)
+
+
+UcmConnectorFactoryV1.register_connector(
+    "UcmPcStore", "ucm.store.pcstore.pcstore_connector", "UcmPcStore"
+)
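
factory_v1.py mirrors the legacy factory but hands out lazily imported v1 connectors, and only UcmPcStore is registered by default. A short usage sketch follows; the "MyStoreV1" registration and the config keys are illustrative assumptions (the real schema is store-specific, and creating "UcmPcStore" assumes the pcstore extension is available):

```python
from ucm.store.factory_v1 import UcmConnectorFactoryV1

# Registration is lazy: "my_pkg.my_store_connector" is only imported when
# create_connector() is called. Name and module path here are hypothetical.
UcmConnectorFactoryV1.register_connector(
    "MyStoreV1", "my_pkg.my_store_connector", "MyStoreV1"
)

# UcmPcStore is registered by factory_v1.py itself; the config keys below follow
# the connector code above ("io_size", "kv_block_size"), but the exact schema
# depends on the store implementation.
store = UcmConnectorFactoryV1.create_connector(
    "UcmPcStore", {"io_size": 4096, "kv_block_size": 262144}
)
```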

ucm/store/pcstore/cc/domain/trans/trans_task.h

Lines changed: 3 additions & 3 deletions
@@ -58,10 +58,10 @@ class TransTask {
        : id{NextId()}, type{std::move(type)}, startTp{NowTp()}, brief_{std::move(brief)}
    {
    }
-    void Append(const std::string& block, const uintptr_t address)
+    void Append(const std::string& block, const std::vector<uintptr_t>& addresses)
    {
-        grouped_[block].push_back(address);
-        number_++;
+        grouped_[block] = addresses;
+        number_ += addresses.size();
    }
    auto Str() const noexcept { return fmt::format("{},{},{}", id, brief_, number_); }
    size_t GroupNumber() const { return grouped_.size(); }

ucm/store/pcstore/cpy/pcstore.py.cc

Lines changed: 3 additions & 1 deletion
@@ -71,7 +71,9 @@ class PcStorePy : public PcStore {
        auto blockId = blockIds.begin();
        auto address = addresses.begin();
        while ((blockId != blockIds.end()) && (address != addresses.end())) {
-            task.Append(blockId->cast<std::string>(), address->cast<uintptr_t>());
+            std::string id = blockId->cast<py::bytes>();
+            std::vector<uintptr_t> addrs = address->cast<std::vector<uintptr_t>>();
+            task.Append(id, addrs);
            blockId++;
            address++;
        }
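
Together with the TransTask::Append change above, the binding now expects one bytes block ID paired with a list of addresses (all K/V tensors of that block) per iteration. A hedged Python-side sketch of that shape; the to_addresses helper and the direct use of Tensor.data_ptr() are assumptions about how a caller might build the address lists, not code from this commit:

```python
from typing import List

import torch


def to_addresses(block_tensors: List[torch.Tensor]) -> List[int]:
    # Each address becomes a uintptr_t on the C++ side; data_ptr() returns the
    # start address of the tensor's storage.
    return [t.data_ptr() for t in block_tensors]


block_ids: List[bytes] = [b"\x01" * 16, b"\x02" * 16]  # e.g. md5 digests per block
per_block_tensors = [[torch.zeros(16, 8, 64) for _ in range(4)] for _ in block_ids]
addresses = [to_addresses(ts) for ts in per_block_tensors]

# A binding shaped like pcstore.py.cc iterates block_ids and addresses in lockstep,
# calling TransTask::Append(id, addrs) once per block.
print(len(block_ids), [len(a) for a in addresses])
```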
