Commit 4b0d330

adapt store v1

1 parent 523bbc4 commit 4b0d330

8 files changed: +286 -205 lines

ucm/integration/vllm/ucm_connector.py

Lines changed: 38 additions & 98 deletions

@@ -4,7 +4,7 @@
 import pickle
 import time
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

 import torch
 from vllm.config import VllmConfig
@@ -21,7 +21,7 @@
 from ucm.shared.metrics import ucmmonitor
 from ucm.shared.metrics.observability import UCMStatsLogger
 from ucm.store.factory import UcmConnectorFactory
-from ucm.store.ucmstore import Task, UcmKVStoreBase
+from ucm.store.ucmstore_v1 import UcmKVStoreBaseV1, Task
 from ucm.utils import Config

 if TYPE_CHECKING:
@@ -35,7 +35,7 @@

 @dataclass
 class RequestMeta:
-    ucm_block_ids: list[str] = field(default_factory=list)
+    ucm_block_ids: list[bytes] = field(default_factory=list)
     hbm_hit_block_num: int = 0
     # local_computed_block + external_computed_block
     total_hit_block_num: int = 0
@@ -47,9 +47,9 @@ class RequestMeta:
 @dataclass
 class RequestDispatchMeta:
     load_block_ids: tuple[
-        list[str], list[int]
+        list[bytes], list[int]
     ]  # [0] mean ucm_block_ids, [1] means vllm_block_ids
-    dump_block_ids: tuple[list[str], list[int]]
+    dump_block_ids: tuple[list[bytes], list[int]]


 @dataclass
@@ -69,14 +69,14 @@ def __init__(self, vllm_config, rank_id):
         if RequestHasher._SEED_HASH is None:
             RequestHasher._SEED_HASH = self("UCM_HASH_SEED")

-    def __call__(self, input_data) -> int:
-        if isinstance(input_data, str):
-            input_bytes = input_data.encode("utf-8")
+    def __call__(self, input_data) -> bytes:
+        if isinstance(input_data, bytes):
+            input_bytes = input_data
         else:
             input_bytes = pickle.dumps(input_data, protocol=pickle.HIGHEST_PROTOCOL)

         h = hashlib.md5(self.meta_bytes + input_bytes)
-        return int.from_bytes(h.digest(), byteorder="big")
+        return h.digest()


 class UCMDirectConnector(KVConnectorBase_V1):
@@ -110,9 +110,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):

         if self.local_rank >= 0:
             self.device = torch_dev.device(f"{dev_name}:{self.local_rank}")
-            self._layer_offset_cache = {}

-        self.store: UcmKVStoreBase
+        self.store: UcmKVStoreBaseV1

         if role == KVConnectorRole.SCHEDULER:
             self.request_hasher = RequestHasher(vllm_config, 0)
@@ -188,7 +187,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         # invlalid block ids due to load errors
         self._invalid_block_ids: set[int] = set()

-    def generate_hash(self, block_size: int, request: "Request") -> list[str]:
+    def generate_hash(self, block_size: int, request: "Request") -> list[bytes]:
         token_ids = request.all_token_ids

         ret = []
@@ -205,7 +204,7 @@ def generate_hash(self, block_size: int, request: "Request") -> list[str]:
                 (parent_block_hash_value, block_token_ids_tuple)
             )
             parent_block_hash_value = hash_value
-            ret.append(str(hash_value))
+            ret.append(hash_value)

         return ret

@@ -395,70 +394,32 @@ def _extract_layer_index(layer_name: str) -> Optional[int]:
                 return int(chunk)
         return None

-    def _precompute_layer_offsets(self):
-        if not self.kv_caches:
-            return
-
-        sample_kv_layer = next(iter(self.kv_caches.values()))
-        elem_size = sample_kv_layer[0].element_size()
-        block_data_size = (
-            sample_kv_layer[0].numel() if self.is_mla else sample_kv_layer[0][0].numel()
-        ) * elem_size
-        layer_data_size = block_data_size if self.is_mla else block_data_size * 2
-
-        # precompute all layers offset
-        for layer_name, _ in self.kv_caches.items():
-            layer_id = self._extract_layer_index(layer_name)
-            assert layer_id is not None
-            k_offset = layer_data_size * layer_id
-            v_offset = k_offset + block_data_size if not self.is_mla else 0
-            self._layer_offset_cache[layer_name] = (k_offset, v_offset)
-
-    def _get_tensor_and_offset(
-        self, vllm_block_ids: list[int], kv_layer: torch.Tensor, layer_name: str
-    ) -> tuple[list[torch.Tensor], list[int]]:
+    def _get_tensors(self, vllm_block_id: int) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         """
         GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
         MLA: one layer shape is (num_blocks, block_size, head_size)
         """
-        k_tensors, k_offsets = [], []
-        v_tensors, v_offsets = [], []
-        k_offset, v_offset = self._layer_offset_cache[layer_name]
-
-        for vllm_block_id in vllm_block_ids:
-            k_tensors.append(
-                kv_layer[vllm_block_id] if self.is_mla else kv_layer[0][vllm_block_id]
-            )
-            k_offsets.append(k_offset)
+        k_tensors, v_tensors = [], []
+        for _, kv_layer in self.kv_caches.items():
+            k_tensors.append(kv_layer[vllm_block_id] if self.is_mla else kv_layer[0][vllm_block_id])
             if not self.is_mla:
                 v_tensors.append(kv_layer[1][vllm_block_id])
-                v_offsets.append(v_offset)
-        return k_tensors + v_tensors, k_offsets + v_offsets
-
-    def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[str]):
-        if not self._layer_offset_cache:
-            self._precompute_layer_offsets()
-
-        num_layers = len(self.kv_caches)
-        num_blocks_per_layer = len(vllm_block_ids)
-        num_tensors_per_layer = num_blocks_per_layer * (1 if self.is_mla else 2)
-        dst_tensor_addr = [None] * (num_layers * num_tensors_per_layer)
-        ucm_offsets = [0] * (num_layers * num_tensors_per_layer)
-
-        idx = 0
-        for layer_name, one_layer_kv_cache in self.kv_caches.items():
-            tensors, offsets = self._get_tensor_and_offset(
-                vllm_block_ids, one_layer_kv_cache, layer_name
-            )
-            dst_tensor_addr[idx : idx + len(tensors)] = tensors
-            ucm_offsets[idx : idx + len(offsets)] = offsets
-            idx += len(tensors)
+        return k_tensors, v_tensors

-        repeat_times = len(self.kv_caches) * (1 if self.is_mla else 2)
-        ucm_total_block_ids = ucm_block_ids * repeat_times

-        assert len(ucm_total_block_ids) == len(ucm_offsets) == len(dst_tensor_addr)
-        return ucm_total_block_ids, ucm_offsets, dst_tensor_addr
+    def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[bytes]) -> Tuple[List[bytes], List[int], List[List[torch.Tensor]]]:
+        """
+        GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
+        MLA: one layer shape is (num_blocks, block_size, head_size)
+        """
+        block_ids, shard_indexs, tensors = [], [], []
+        for i, vllm_block_id in enumerate(vllm_block_ids):
+            k_tensors, v_tensors = self._get_tensors(vllm_block_id)
+            block_ids.append(ucm_block_ids[i])
+            tensors.append(k_tensors + v_tensors)
+            shard_indexs.append(0)
+
+        return block_ids, shard_indexs, tensors

     def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
         rec_tensor: torch.Tensor = None
@@ -501,17 +462,17 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
             ucm_block_ids, vllm_block_ids = request.load_block_ids
             if self.global_rank != 0 and not self.is_mla and not self.is_dsa:
                 for i, ucm_block_id in enumerate(ucm_block_ids):
-                    ucm_block_ids[i] = str(self.request_hasher(ucm_block_id))
-            ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
+                    ucm_block_ids[i] = self.request_hasher(ucm_block_id)
+            block_ids, shard_indexs, tensors = self._generate_task(
                 vllm_block_ids, ucm_block_ids
             )
             if self.global_rank == 0 or not self.load_only_first_rank:
                 request_to_task[request_id] = self.store.load(
-                    ucm_total_block_ids, ucm_offsets, dst_tensor_addr
+                    block_ids, shard_indexs, tensors
                 )
             else:
                 request_to_task[request_id] = None
-                req_broadcast_addr[request_id] = dst_tensor_addr
+                req_broadcast_addr[request_id] = tensors

         for request_id, task in request_to_task.items():
             # TODO error handling
@@ -568,7 +529,6 @@ def wait_for_save(self) -> None:
         assert isinstance(metadata, UCMConnectorMetadata)

         request_to_task: dict[str, Task] = {}
-        request_to_blocks: dict[str, list[str]] = {}
         is_save = False
         num_saved_block = 0
         num_saved_request = 0
@@ -583,36 +543,16 @@ def wait_for_save(self) -> None:
             ucm_block_ids, vllm_block_ids = request.dump_block_ids
             if self.global_rank != 0:
                 for i, ucm_block_id in enumerate(ucm_block_ids):
-                    ucm_block_ids[i] = str(self.request_hasher(ucm_block_id))
-            rets = self.store.create(ucm_block_ids)
-            end = 0
-            for i, ret in enumerate(rets):
-                if ret != 0:
-                    logger.error(
-                        f"create blocks for {request_id} failed, block index: {i}, ret code: {ret}"
-                    )
-                    break
-                end += 1
-
-            if end == 0:
-                continue
-            ucm_block_ids = ucm_block_ids[:end]
-            vllm_block_ids = vllm_block_ids[:end]
-            ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
+                    ucm_block_ids[i] = self.request_hasher(ucm_block_id)
+            block_ids, shard_indexs, tensors = self._generate_task(
                 vllm_block_ids, ucm_block_ids
             )
             request_to_task[request_id] = self.store.dump(
-                ucm_total_block_ids, ucm_offsets, dst_tensor_addr
+                block_ids, shard_indexs, tensors
            )
-            request_to_blocks[request_id] = ucm_block_ids

         for request_id, task in request_to_task.items():
-            ucm_block_ids = request_to_blocks[request_id]
-            if self.store.wait(task) == 0:
-                self.store.commit(ucm_block_ids, True)
-            else:
-                logger.error(f"request {request_id} dump kv cache failed.")
-                self.store.commit(ucm_block_ids, False)
+            self.store.wait(task)
         save_end_time = time.perf_counter() * 1000
         save_speed = (
             num_saved_block
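Note on the API shape this diff moves to: the connector now hands the store one entry per KV block, a bytes block id plus a shard index plus the list of that block's layer tensors, instead of the old flattened (block_ids, offsets, tensor_addrs) triple with per-layer byte offsets. The real UcmKVStoreBaseV1 is added elsewhere in this commit and is not shown here; the sketch below is only a toy in-memory stand-in, under the assumption that load/dump/wait follow the call shape visible in this diff (wait returning 0 on success).

    # Illustrative only: a toy store mirroring the load/dump/wait call shape
    # used by UCMDirectConnector after this commit. Names other than the
    # method signatures seen in the diff are assumptions.
    from dataclasses import dataclass
    from typing import Dict, List, Tuple

    import torch


    @dataclass
    class ToyTask:
        failed: bool = False


    class ToyKVStoreV1:
        def __init__(self) -> None:
            # (block_id, shard_index) -> saved tensor copies, one per layer shard
            self._data: Dict[Tuple[bytes, int], List[torch.Tensor]] = {}

        def dump(self, block_ids: List[bytes], shard_indexs: List[int],
                 tensors: List[List[torch.Tensor]]) -> ToyTask:
            # one entry per block: copy every layer tensor of that block out
            for bid, shard, ts in zip(block_ids, shard_indexs, tensors):
                self._data[(bid, shard)] = [t.detach().cpu().clone() for t in ts]
            return ToyTask()

        def load(self, block_ids: List[bytes], shard_indexs: List[int],
                 tensors: List[List[torch.Tensor]]) -> ToyTask:
            task = ToyTask()
            for bid, shard, ts in zip(block_ids, shard_indexs, tensors):
                saved = self._data.get((bid, shard))
                if saved is None or len(saved) != len(ts):
                    task.failed = True
                    continue
                for dst, src in zip(ts, saved):
                    dst.copy_(src)  # write back into the vLLM KV-cache block view
            return task

        def wait(self, task: ToyTask) -> int:
            return 1 if task.failed else 0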

ucm/store/factory.py

Lines changed: 6 additions & 16 deletions

@@ -27,47 +27,37 @@

 from ucm.logger import init_logger
 from ucm.store.ucmstore import UcmKVStoreBase
+from ucm.store.ucmstore_v1 import UcmKVStoreBaseV1

 logger = init_logger(__name__)


 class UcmConnectorFactory:
-    _registry: dict[str, Callable[[], type[UcmKVStoreBase]]] = {}
+    _registry: dict[str, Callable[[], type[UcmKVStoreBaseV1]]] = {}

     @classmethod
     def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
         """Register a connector with a lazy-loading module and class name."""
         if name in cls._registry:
             raise ValueError(f"Connector '{name}' is already registered.")

-        def loader() -> type[UcmKVStoreBase]:
+        def loader() -> type[UcmKVStoreBaseV1]:
             module = importlib.import_module(module_path)
             return getattr(module, class_name)

         cls._registry[name] = loader

     @classmethod
-    def create_connector(cls, connector_name: str, config: dict) -> UcmKVStoreBase:
+    def create_connector(cls, connector_name: str, config: dict) -> UcmKVStoreBaseV1:
         if connector_name in cls._registry:
             connector_cls = cls._registry[connector_name]()
         else:
             raise ValueError(f"Unsupported connector type: {connector_name}")
-        assert issubclass(connector_cls, UcmKVStoreBase)
+        assert issubclass(connector_cls, UcmKVStoreBaseV1)
         logger.info("Creating connector with name: %s", connector_name)
         return connector_cls(config)


-UcmConnectorFactory.register_connector(
-    "UcmDramStore", "ucm.store.dramstore.dramstore_connector", "UcmDramStore"
-)
-UcmConnectorFactory.register_connector(
-    "UcmNfsStore", "ucm.store.nfsstore.nfsstore_connector", "UcmNfsStore"
-)
 UcmConnectorFactory.register_connector(
     "UcmPcStore", "ucm.store.pcstore.pcstore_connector", "UcmPcStore"
-)
-UcmConnectorFactory.register_connector(
-    "UcmMooncakeStore",
-    "ucm.store.mooncakestore.mooncake_connector",
-    "UcmMooncakeStore",
-)
+)
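After this change only UcmPcStore is pre-registered at import time; the dram, nfs, and mooncake backends would have to call register_connector themselves before create_connector can find them. A minimal usage sketch, assuming a config dict whose keys are placeholders (the real UcmPcStore configuration is not shown in this commit):

    # Hypothetical usage of the factory after this commit; config keys are illustrative.
    from ucm.store.factory import UcmConnectorFactory

    # UcmPcStore is registered by factory.py itself; other backends would need
    # an explicit register_connector(name, module_path, class_name) call first.
    store = UcmConnectorFactory.create_connector("UcmPcStore", {"storage_path": "/tmp/ucm"})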

ucm/store/pcstore/cc/domain/trans/trans_task.h

Lines changed: 3 additions & 3 deletions

@@ -58,10 +58,10 @@ class TransTask {
         : id{NextId()}, type{std::move(type)}, startTp{NowTp()}, brief_{std::move(brief)}
     {
     }
-    void Append(const std::string& block, const uintptr_t address)
+    void Append(const std::string& block, const std::vector<uintptr_t>& addresses)
     {
-        grouped_[block].push_back(address);
-        number_++;
+        grouped_[block] = addresses;
+        number_ += addresses.size();
     }
     auto Str() const noexcept { return fmt::format("{},{},{}", id, brief_, number_); }
     size_t GroupNumber() const { return grouped_.size(); }

ucm/store/pcstore/cpy/pcstore.py.cc

Lines changed: 3 additions & 1 deletion

@@ -71,7 +71,9 @@ class PcStorePy : public PcStore {
         auto blockId = blockIds.begin();
         auto address = addresses.begin();
         while ((blockId != blockIds.end()) && (address != addresses.end())) {
-            task.Append(blockId->cast<std::string>(), address->cast<uintptr_t>());
+            std::string id = blockId->cast<py::bytes>();
+            std::vector<uintptr_t> addrs = address->cast<std::vector<uintptr_t>>();
+            task.Append(id, addrs);
             blockId++;
             address++;
         }
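Together with the TransTask::Append change above, the binding now consumes one bytes block id paired with a list of addresses (one per layer tensor of that block) rather than a single address per entry. A rough sketch of the caller-side payload shape, assuming the bound method zips the two sequences as the loop above does; the variable names and tensor shapes here are illustrative, not taken from this commit:

    # Illustrative payload shape only; the bound method name and exact signature
    # are defined in pcstore.py.cc / the v1 store wrapper, not in this sketch.
    import torch

    block_id = b"\x8f" * 16  # md5 digest used as the UCM block id after this commit
    kv_block_tensors = [torch.empty(16, 8, 64) for _ in range(4)]  # one block's K/V tensors

    block_ids = [block_id]                                   # -> blockId->cast<py::bytes>()
    addresses = [[t.data_ptr() for t in kv_block_tensors]]   # -> cast<std::vector<uintptr_t>>()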
