diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py
index 66216a255..ee8558e84 100644
--- a/ucm/integration/vllm/ucm_connector.py
+++ b/ucm/integration/vllm/ucm_connector.py
@@ -4,7 +4,7 @@
 import pickle
 import time
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
 
 import torch
 from vllm.config import VllmConfig
@@ -20,8 +20,8 @@
 from ucm.logger import init_logger
 from ucm.shared.metrics import ucmmonitor
 from ucm.shared.metrics.observability import UCMStatsLogger
-from ucm.store.factory import UcmConnectorFactory
-from ucm.store.ucmstore import Task, UcmKVStoreBase
+from ucm.store.factory_v1 import UcmConnectorFactoryV1
+from ucm.store.ucmstore_v1 import Task, UcmKVStoreBaseV1
 from ucm.utils import Config
 
 if TYPE_CHECKING:
@@ -35,7 +35,7 @@
 
 @dataclass
 class RequestMeta:
-    ucm_block_ids: list[str] = field(default_factory=list)
+    ucm_block_ids: list[bytes] = field(default_factory=list)
     hbm_hit_block_num: int = 0
     # local_computed_block + external_computed_block
    total_hit_block_num: int = 0
@@ -47,9 +47,9 @@ class RequestMeta:
 @dataclass
 class RequestDispatchMeta:
     load_block_ids: tuple[
-        list[str], list[int]
+        list[bytes], list[int]
     ]  # [0] mean ucm_block_ids, [1] means vllm_block_ids
-    dump_block_ids: tuple[list[str], list[int]]
+    dump_block_ids: tuple[list[bytes], list[int]]
 
 
 @dataclass
@@ -69,14 +69,14 @@ def __init__(self, vllm_config, rank_id):
         if RequestHasher._SEED_HASH is None:
             RequestHasher._SEED_HASH = self("UCM_HASH_SEED")
 
-    def __call__(self, input_data) -> int:
-        if isinstance(input_data, str):
-            input_bytes = input_data.encode("utf-8")
+    def __call__(self, input_data) -> bytes:
+        if isinstance(input_data, bytes):
+            input_bytes = input_data
         else:
             input_bytes = pickle.dumps(input_data, protocol=pickle.HIGHEST_PROTOCOL)
 
         h = hashlib.md5(self.meta_bytes + input_bytes)
-        return int.from_bytes(h.digest(), byteorder="big")
+        return h.digest()
 
 
 class UCMDirectConnector(KVConnectorBase_V1):
@@ -95,6 +95,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.block_size = self._vllm_config.cache_config.block_size
         self.is_mla = self._vllm_config.model_config.is_deepseek_mla
         self.is_dsa = False
+        self.num_layers = self._vllm_config.model_config.get_num_layers(
+            self._vllm_config.parallel_config
+        )
+        self.tp_size = self._vllm_config.parallel_config.tensor_parallel_size
 
         self.kv_cache_dtype: torch.dtype = None
         if current_platform.is_cuda_alike():
@@ -110,21 +114,19 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         if self.local_rank >= 0:
             self.device = torch_dev.device(f"{dev_name}:{self.local_rank}")
 
-        self._layer_offset_cache = {}
-
-        self.store: UcmKVStoreBase
-        if role == KVConnectorRole.SCHEDULER:
-            self.request_hasher = RequestHasher(vllm_config, 0)
-        else:
-            self.request_hasher = RequestHasher(vllm_config, self.global_rank)
+        self.k_store: UcmKVStoreBaseV1
+        self.v_store: Optional[UcmKVStoreBaseV1] = None
 
         # save block info, avoid hash request twice, and track them until request finished
         self.requests_meta: dict[str, RequestMeta] = {}
 
         ucm_config = Config(vllm_config.kv_transfer_config)
+        self.engine_id = vllm_config.kv_transfer_config.engine_id
         self.launch_config = ucm_config.get_config()
-
+        logger.info(f"self.launch_config: {self.launch_config}")
+        self.connector_configs = self.launch_config.get("ucm_connectors", [])
+        assert len(self.connector_configs) > 0, "no storage connector name in config."
         self.load_only_first_rank: bool = (
             self.launch_config.get("load_only_first_rank", self.is_mla) and self.is_mla
         )
@@ -134,42 +136,28 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
             self.broadcast_fn = self.group_coordinator.broadcast
             self.broadcast_stream = torch.cuda.Stream()
 
-        logger.info(f"self.launch_config: {self.launch_config}")
-        connector_configs = self.launch_config.get("ucm_connectors", [])
-        assert len(connector_configs) > 0, "no storage connector name in config."
-
-        name = connector_configs[0].get("ucm_connector_name")
-        config = connector_configs[0].get("ucm_connector_config") or {}
-        config["device"] = self.local_rank
-        config["role"] = "scheduler" if role == KVConnectorRole.SCHEDULER else "worker"
-        element_size = vllm_config.model_config.dtype.itemsize
-        single_head_dim = vllm_config.model_config.get_head_size()
-        num_head_per_tp = vllm_config.model_config.get_num_kv_heads(
-            vllm_config.parallel_config
-        )
-        total_tp_size = vllm_config.parallel_config.tensor_parallel_size
-        num_layers = vllm_config.model_config.get_num_layers(
-            vllm_config.parallel_config
-        )
-        block_size_per_layer = self.block_size * element_size * single_head_dim
-        config["kv_block_size"] = (
-            block_size_per_layer
-            * num_layers
-            * (1 if self.is_mla else num_head_per_tp * 2)
-        )
-        config["io_size"] = block_size_per_layer * (
-            1 if self.is_mla else num_head_per_tp
-        )
-        self.store = UcmConnectorFactory.create_connector(name, config)
-        self.block_data_size = config["kv_block_size"]
-
-        logger.info("init UCConnectorImpl, connector: %s", name)
+        name = self.connector_configs[0].get("ucm_connector_name")
+        config = self.connector_configs[0].get("ucm_connector_config") or {}
+        storage_backends = [
+            path for path in config["storage_backends"].split(":") if path
+        ]
+        self.k_storage_backends = [os.path.join(p, "k") for p in storage_backends]
+        self.v_storage_backends = [os.path.join(p, "v") for p in storage_backends]
+        os.makedirs(self.k_storage_backends[0], exist_ok=True)
+        os.makedirs(self.v_storage_backends[0], exist_ok=True)
         logger.info(
-            "single file size = %d MB, io_size = %d KB,",
-            config["kv_block_size"] / 1024 / 1024,
-            config["io_size"] / 1024,
+            f"Created subdirectories: {self.k_storage_backends}, {self.v_storage_backends}"
         )
+        if role == KVConnectorRole.SCHEDULER:
+            self.request_hasher = RequestHasher(vllm_config, 0)
+            # init scheduler-side connector
+            config["storage_backends"] = ":".join(self.k_storage_backends)
+            config["role"] = "scheduler"
+            self.k_store = UcmConnectorFactoryV1.create_connector(name, config)
+        else:
+            self.request_hasher = RequestHasher(vllm_config, self.global_rank)
+
         self.metrics_config = self.launch_config.get("metrics_config_path", "")
         if self.metrics_config:
             self.stats_logger = UCMStatsLogger(
@@ -188,7 +176,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         # invlalid block ids due to load errors
         self._invalid_block_ids: set[int] = set()
 
-    def generate_hash(self, block_size: int, request: "Request") -> list[str]:
+    def generate_hash(self, block_size: int, request: "Request") -> list[bytes]:
         token_ids = request.all_token_ids
 
         ret = []
@@ -205,10 +193,81 @@ def generate_hash(self, block_size: int, request: "Request") -> list[str]:
                 (parent_block_hash_value, block_token_ids_tuple)
             )
             parent_block_hash_value = hash_value
-            ret.append(str(hash_value))
+            ret.append(hash_value)
 
         return ret
 
+    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
+        self.kv_caches = kv_caches
+        sample_kv_layer = next(iter(self.kv_caches.values()))
+        if self.kv_cache_dtype is None:
+            self.kv_cache_dtype = sample_kv_layer[0].dtype
+        if isinstance(sample_kv_layer, torch.Tensor):
+            logger.info(f"kv cache shape {sample_kv_layer.shape}")
+        elif isinstance(sample_kv_layer, tuple):
+            # Since vllm_ascend >= 0.10.0, the MLA model's tensors come as a tuple:
+            # [(num_blocks, block_size, num_kv_heads, nope_dim/rope_dim)]
+            # Currently, we treat it as GQA, and use is_dsa to mark it
+            for i, tensor in enumerate(sample_kv_layer):
+                logger.info(f"kv cache shape {i}: {tensor.shape}")
+            if self.is_mla:
+                self.is_mla = False
+                self.is_dsa = True
+        logger.info(f"use mla: {self.is_mla}, use dsa: {self.is_dsa}")
+
+        # init worker-side connector
+        # When handling the GQA case, we dump the k_cache and v_cache separately.
+        name = self.connector_configs[0].get("ucm_connector_name")
+        config = self.connector_configs[0].get("ucm_connector_config") or {}
+        config["device"] = self.local_rank
+        config["role"] = "worker"
+        config["local_rank_size"] = self.tp_size if self.is_mla or self.is_dsa else 1
+        if len(sample_kv_layer) == 2:
+            k_io_size = (
+                sample_kv_layer[0][0].numel() * sample_kv_layer[0][0].element_size()
+            )
+            config["io_size"] = k_io_size
+            config["kv_block_size"] = k_io_size * self.num_layers
+            config["storage_backends"] = ":".join(self.k_storage_backends)
+            config["unique_id"] = self.engine_id + "k"
+            self.k_store = UcmConnectorFactoryV1.create_connector(name, config)
+            logger.info("init UCConnectorImpl, k_connector: %s", name)
+            logger.info(
+                "single file size = %.3f MB, io_size = %d KB,",
+                config["kv_block_size"] / 1024 / 1024,
+                config["io_size"] / 1024,
+            )
+
+            v_io_size = (
+                sample_kv_layer[1][0].numel() * sample_kv_layer[1][0].element_size()
+            )
+            config["io_size"] = v_io_size
+            config["kv_block_size"] = v_io_size * self.num_layers
+            config["storage_backends"] = ":".join(self.v_storage_backends)
+            config["unique_id"] = self.engine_id + "v"
+            self.v_store = UcmConnectorFactoryV1.create_connector(name, config)
+            logger.info("init UCConnectorImpl, v_connector: %s", name)
+            logger.info(
+                "single file size = %.3f MB, io_size = %d KB,",
+                config["kv_block_size"] / 1024 / 1024,
+                config["io_size"] / 1024,
+            )
+            self.block_data_size = (k_io_size + v_io_size) * self.num_layers
+        else:
+            k_io_size = sample_kv_layer[0].numel() * sample_kv_layer[0].element_size()
+            config["io_size"] = k_io_size
+            config["kv_block_size"] = k_io_size * self.num_layers
+            config["storage_backends"] = ":".join(self.k_storage_backends)
+            config["unique_id"] = self.engine_id + "k"
+            self.k_store = UcmConnectorFactoryV1.create_connector(name, config)
+            logger.info("init UCConnectorImpl, k_connector: %s", name)
+            logger.info(
+                "single file size = %.3f MB, io_size = %d KB,",
+                config["kv_block_size"] / 1024 / 1024,
+                config["io_size"] / 1024,
+            )
+            self.block_data_size = k_io_size * self.num_layers
+
     def get_num_new_matched_tokens(
         self,
         request: "Request",
@@ -223,7 +282,7 @@ def get_num_new_matched_tokens(
         if not external_block_ids:
             return 0, False
 
-        lookup_results = self.store.lookup(external_block_ids)
+        lookup_results = self.k_store.lookup(external_block_ids)
         external_hit_blocks = 0
         for i, hit in enumerate(lookup_results):
             if not hit:
@@ -361,30 +420,6 @@ def build_connector_meta(
 
         return UCMConnectorMetadata(requests_dispatch_meta)
 
-    def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"):
-        if len(self.kv_caches) > 0:
-            return
-        for layer_name in forward_context.no_compile_layers:
-            attn_layer = forward_context.no_compile_layers[layer_name]
-            if not hasattr(attn_layer, "kv_cache"):
-                continue
-
-            if layer_name not in self.kv_caches:
-                self.kv_caches[layer_name] = attn_layer.kv_cache[
-                    forward_context.virtual_engine
-                ]
-
-        # Since vllm_ascend >= 0.10.0, the MLA model's tensor shape has changed to
-        # (2, num_blocks, block_size, num_kv_heads, nope_dim/rope_dim).
-        # Currently, we treat it as GQA, and use is_dsa to mark it,
-        # which works but leads to space inefficiency.
-        # TODO: Optimize this to avoid unnecessary space usage.
-        sample_kv_layer = next(iter(self.kv_caches.values()))
-        if self.is_mla and len(sample_kv_layer) == 2:
-            self.is_mla = False
-            self.is_dsa = True
-        if self.kv_cache_dtype is None:
-            self.kv_cache_dtype = sample_kv_layer[0].dtype
-
     @staticmethod
     def _extract_layer_index(layer_name: str) -> Optional[int]:
         """
@@ -395,70 +430,36 @@ def _extract_layer_index(layer_name: str) -> Optional[int]:
                 return int(chunk)
         return None
 
-    def _precompute_layer_offsets(self):
-        if not self.kv_caches:
-            return
-
-        sample_kv_layer = next(iter(self.kv_caches.values()))
-        elem_size = sample_kv_layer[0].element_size()
-        block_data_size = (
-            sample_kv_layer[0].numel() if self.is_mla else sample_kv_layer[0][0].numel()
-        ) * elem_size
-        layer_data_size = block_data_size if self.is_mla else block_data_size * 2
-
-        # precompute all layers offset
-        for layer_name, _ in self.kv_caches.items():
-            layer_id = self._extract_layer_index(layer_name)
-            assert layer_id is not None
-            k_offset = layer_data_size * layer_id
-            v_offset = k_offset + block_data_size if not self.is_mla else 0
-            self._layer_offset_cache[layer_name] = (k_offset, v_offset)
-
-    def _get_tensor_and_offset(
-        self, vllm_block_ids: list[int], kv_layer: torch.Tensor, layer_name: str
-    ) -> tuple[list[torch.Tensor], list[int]]:
+    def _get_tensors(
+        self, vllm_block_id: int
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         """
         GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
         MLA: one layer shape is (num_blocks, block_size, head_size)
         """
-        k_tensors, k_offsets = [], []
-        v_tensors, v_offsets = [], []
-        k_offset, v_offset = self._layer_offset_cache[layer_name]
-
-        for vllm_block_id in vllm_block_ids:
+        k_tensors, v_tensors = [], []
+        for _, kv_layer in self.kv_caches.items():
             k_tensors.append(
                 kv_layer[vllm_block_id] if self.is_mla else kv_layer[0][vllm_block_id]
             )
-            k_offsets.append(k_offset)
             if not self.is_mla:
                 v_tensors.append(kv_layer[1][vllm_block_id])
-                v_offsets.append(v_offset)
-        return k_tensors + v_tensors, k_offsets + v_offsets
-
-    def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[str]):
-        if not self._layer_offset_cache:
-            self._precompute_layer_offsets()
-
-        num_layers = len(self.kv_caches)
-        num_blocks_per_layer = len(vllm_block_ids)
-        num_tensors_per_layer = num_blocks_per_layer * (1 if self.is_mla else 2)
-        dst_tensor_addr = [None] * (num_layers * num_tensors_per_layer)
-        ucm_offsets = [0] * (num_layers * num_tensors_per_layer)
-
-        idx = 0
-        for layer_name, one_layer_kv_cache in self.kv_caches.items():
-            tensors, offsets = self._get_tensor_and_offset(
-                vllm_block_ids, one_layer_kv_cache, layer_name
-            )
-            dst_tensor_addr[idx : idx + len(tensors)] = tensors
-            ucm_offsets[idx : idx + len(offsets)] = offsets
-            idx += len(tensors)
-
-        repeat_times = len(self.kv_caches) * (1 if self.is_mla else 2)
-        ucm_total_block_ids = ucm_block_ids * repeat_times
-
-        assert len(ucm_total_block_ids) == len(ucm_offsets) == len(dst_tensor_addr)
-        return ucm_total_block_ids, ucm_offsets, dst_tensor_addr
+        return k_tensors, v_tensors
+
+    def _generate_task(
+        self, vllm_block_ids: List[int], ucm_block_ids: List[bytes]
+    ) -> Tuple[
+        List[bytes], List[int], List[List[torch.Tensor]], List[List[torch.Tensor]]
+    ]:
+        block_ids, shard_indices, total_k_tensors, total_v_tensors = [], [], [], []
+        for i, vllm_block_id in enumerate(vllm_block_ids):
+            k_tensors, v_tensors = self._get_tensors(vllm_block_id)
+            block_ids.append(ucm_block_ids[i])
+            total_k_tensors.append(k_tensors)
+            total_v_tensors.append(v_tensors)
+            shard_indices.append(0)
+
+        return block_ids, shard_indices, total_k_tensors, total_v_tensors
 
     def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
         rec_tensor: torch.Tensor = None
@@ -483,9 +484,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
         metadata = self._get_connector_metadata()
         assert isinstance(metadata, UCMConnectorMetadata)
 
-        self._init_kv_caches_from_forward_context(forward_context)
-
-        request_to_task: dict[str, Optional[Task]] = {}
+        request_to_task: dict[str, Optional[List[Task]]] = {}
         req_broadcast_addr = {}
         is_load = False
         num_loaded_block = 0
@@ -501,26 +500,34 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
             ucm_block_ids, vllm_block_ids = request.load_block_ids
             if self.global_rank != 0 and not self.is_mla and not self.is_dsa:
                 for i, ucm_block_id in enumerate(ucm_block_ids):
-                    ucm_block_ids[i] = str(self.request_hasher(ucm_block_id))
-            ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
+                    ucm_block_ids[i] = self.request_hasher(ucm_block_id)
+            block_ids, shard_indices, k_tensors, v_tensors = self._generate_task(
                 vllm_block_ids, ucm_block_ids
             )
             if self.global_rank == 0 or not self.load_only_first_rank:
-                request_to_task[request_id] = self.store.load(
-                    ucm_total_block_ids, ucm_offsets, dst_tensor_addr
-                )
+                k_task = self.k_store.load(block_ids, shard_indices, k_tensors)
+                request_to_task[request_id] = [k_task]
+                if v_tensors and self.v_store:
+                    v_task = self.v_store.load(block_ids, shard_indices, v_tensors)
+                    request_to_task[request_id].append(v_task)
             else:
                 request_to_task[request_id] = None
-            req_broadcast_addr[request_id] = dst_tensor_addr
+            req_broadcast_addr[request_id] = [t for row in k_tensors for t in row] + [
+                t for row in v_tensors for t in row
+            ]
 
-        for request_id, task in request_to_task.items():
+        for request_id, tasks in request_to_task.items():
             # TODO error handling
             if self.global_rank == 0 or not self.load_only_first_rank:
-                if self.store.wait(task) != 0:
+                try:
+                    self.k_store.wait(tasks[0])
+                    if len(tasks) > 1 and self.v_store:
+                        self.v_store.wait(tasks[1])
+                except RuntimeError as e:
+                    logger.error(f"request {request_id} load kv cache failed: {e}")
                     self._invalid_block_ids.update(
                         metadata.request_meta[request_id].load_block_ids[1]
                     )
-                    logger.error(f"request {request_id} load kv cache failed.")
             if self.load_only_first_rank:
                 self._broadcast(req_broadcast_addr[request_id])
         load_end_time = time.perf_counter() * 1000
@@ -567,8 +574,7 @@ def wait_for_save(self) -> None:
         metadata = self._get_connector_metadata()
         assert isinstance(metadata, UCMConnectorMetadata)
 
-        request_to_task: dict[str, Task] = {}
-        request_to_blocks: dict[str, list[str]] = {}
+        request_to_task: dict[str, List[Task]] = {}
         is_save = False
         num_saved_block = 0
         num_saved_request = 0
@@ -583,36 +589,23 @@ def wait_for_save(self) -> None:
             ucm_block_ids, vllm_block_ids = request.dump_block_ids
             if self.global_rank != 0:
                 for i, ucm_block_id in enumerate(ucm_block_ids):
-                    ucm_block_ids[i] = str(self.request_hasher(ucm_block_id))
-            rets = self.store.create(ucm_block_ids)
-            end = 0
-            for i, ret in enumerate(rets):
-                if ret != 0:
-                    logger.error(
-                        f"create blocks for {request_id} failed, block index: {i}, ret code: {ret}"
-                    )
-                    break
-                end += 1
-
-            if end == 0:
-                continue
-            ucm_block_ids = ucm_block_ids[:end]
-            vllm_block_ids = vllm_block_ids[:end]
-            ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
+                    ucm_block_ids[i] = self.request_hasher(ucm_block_id)
+            block_ids, shard_indices, k_tensors, v_tensors = self._generate_task(
                 vllm_block_ids, ucm_block_ids
             )
-            request_to_task[request_id] = self.store.dump(
-                ucm_total_block_ids, ucm_offsets, dst_tensor_addr
-            )
-            request_to_blocks[request_id] = ucm_block_ids
-
-        for request_id, task in request_to_task.items():
-            ucm_block_ids = request_to_blocks[request_id]
-            if self.store.wait(task) == 0:
-                self.store.commit(ucm_block_ids, True)
-            else:
-                logger.error(f"request {request_id} dump kv cache failed.")
-                self.store.commit(ucm_block_ids, False)
+            k_task = self.k_store.dump(block_ids, shard_indices, k_tensors)
+            request_to_task[request_id] = [k_task]
+            if v_tensors and self.v_store:
+                v_task = self.v_store.dump(block_ids, shard_indices, v_tensors)
+                request_to_task[request_id].append(v_task)
+
+        for request_id, tasks in request_to_task.items():
+            try:
+                self.k_store.wait(tasks[0])
+                if len(tasks) > 1 and self.v_store:
+                    self.v_store.wait(tasks[1])
+            except RuntimeError as e:
+                logger.error(f"request {request_id} dump kv cache failed: {e}")
         save_end_time = time.perf_counter() * 1000
         save_speed = (
             num_saved_block
@@ -793,6 +786,16 @@ def update_state_after_alloc(
         """
         self.connector.update_state_after_alloc(request, blocks, num_external_tokens)
 
+    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
+        """
+        Initialize with the KV caches. Useful for pre-registering the
+        KV Caches in the KVConnector (e.g. for NIXL).
+
+        Args:
+            kv_caches: dictionary of layer names, kv cache
+        """
+        self.connector.register_kv_caches(kv_caches)
+
     def build_connector_meta(
         self, scheduler_output: SchedulerOutput
     ) -> KVConnectorMetadata:
diff --git a/ucm/store/pcstore/pcstore_connector_v1.py b/ucm/store/pcstore/pcstore_connector_v1.py
index d6fcac1bd..3cf964fea 100644
--- a/ucm/store/pcstore/pcstore_connector_v1.py
+++ b/ucm/store/pcstore/pcstore_connector_v1.py
@@ -44,7 +44,7 @@ def __init__(self, config: Dict):
         storage_backends = [
             path for path in config["storage_backends"].split(":") if path
        ]
-        block_size = int(config["kv_block_size"])
+        block_size = int(config.get("kv_block_size", 33554432))  # default: 32 MiB
         transfer_enable = True if config["role"] == "worker" else False
         param = ucmpcstore.PcStore.Config(storage_backends, block_size, transfer_enable)
         if transfer_enable:
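
Note for reviewers: after this change `RequestHasher.__call__` returns the raw 16-byte md5 digest (`bytes`) instead of a stringified big integer, and `generate_hash` chains each block's digest into the next block's hash input. A minimal standalone sketch of that chaining; `meta_bytes` and the token blocks below are illustrative placeholders (in the connector they come from the vLLM config, the rank, and `request.all_token_ids`):

# Standalone sketch of the prefix-chained block hashing after this patch.
# meta_bytes and the token blocks are placeholders; in the connector they come
# from the vLLM config (plus rank) and request.all_token_ids.
import hashlib
import pickle

meta_bytes = b"model-name|world-size|rank"  # assumed per-instance metadata


def ucm_hash(payload) -> bytes:
    data = (
        payload
        if isinstance(payload, bytes)
        else pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL)
    )
    return hashlib.md5(meta_bytes + data).digest()


parent = ucm_hash("UCM_HASH_SEED")  # mirrors RequestHasher._SEED_HASH
block_ids: list[bytes] = []
for block_tokens in [(1, 2, 3, 4), (5, 6, 7, 8)]:  # one tuple per full block
    parent = ucm_hash((parent, block_tokens))
    block_ids.append(parent)  # 16-byte md5 digests, as generate_hash now returns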
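
The V1 store interface used above also changes shape: instead of the old flat `(block_ids, offsets, tensor_addrs)` triple, `load`/`dump` take one entry per block (a `bytes` id, a shard index, and that block's per-layer tensors), and `wait` raises `RuntimeError` on failure rather than returning a status code. A hedged sketch of the round trip as this connector drives it; the connector name `"SomeStore"`, paths, sizes, and shapes are assumptions for illustration only:

# Hypothetical round trip through the V1 store API as wired up in this diff.
# "SomeStore", the path, sizes, and tensor shapes are illustrative assumptions.
import torch

from ucm.store.factory_v1 import UcmConnectorFactoryV1

num_layers = 32
config = {
    "storage_backends": "/mnt/ucm/k",  # worker gets the k (or v) subdirectories
    "device": 0,
    "role": "worker",
    "io_size": 2 * 1024 * 1024,  # bytes per layer per block (assumed)
    "kv_block_size": 2 * 1024 * 1024 * num_layers,  # io_size * num_layers
    "unique_id": "engine-0" + "k",
    "local_rank_size": 1,
}
k_store = UcmConnectorFactoryV1.create_connector("SomeStore", config)

# One entry per block: a 16-byte id, a shard index, and per-layer tensors.
block_ids = [bytes(16)]
shard_indices = [0]
k_tensors = [[torch.zeros(16, 8, 128) for _ in range(num_layers)]]

k_store.wait(k_store.dump(block_ids, shard_indices, k_tensors))  # raises on error
k_store.wait(k_store.load(block_ids, shard_indices, k_tensors))  # loads in place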
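
Finally, the `connector_configs`/`storage_backends` handling above implies a launch config of roughly the following shape; again, the connector name and paths are hypothetical:

# Hypothetical launch config consumed via Config(vllm_config.kv_transfer_config).
# The connector name and paths are placeholders.
launch_config = {
    "ucm_connectors": [
        {
            "ucm_connector_name": "SomeStore",
            "ucm_connector_config": {
                # Colon-separated roots; the connector derives <root>/k and
                # <root>/v for each backend and creates the first pair.
                "storage_backends": "/mnt/nfs0:/mnt/nfs1",
            },
        }
    ],
    "load_only_first_rank": True,  # only honored for MLA models
    "metrics_config_path": "",  # empty disables UCMStatsLogger
}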