
Commit 9373b51

split store to k_store and v_store

1 parent e451e5b

File tree: 1 file changed

ucm/integration/vllm/ucm_connector.py (+120 -54)
@@ -95,6 +95,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.block_size = self._vllm_config.cache_config.block_size
         self.is_mla = self._vllm_config.model_config.is_deepseek_mla
         self.is_dsa = False
+        self.num_layers = self._vllm_config.model_config.get_num_layers(
+            self._vllm_config.parallel_config
+        )
         self.kv_cache_dtype: torch.dtype = None
 
         if current_platform.is_cuda_alike():
@@ -111,7 +114,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         if self.local_rank >= 0:
             self.device = torch_dev.device(f"{dev_name}:{self.local_rank}")
 
-        self.store: UcmKVStoreBaseV1
+        self.k_store: UcmKVStoreBaseV1
+        self.v_store: Optional[UcmKVStoreBaseV1] = None
 
         if role == KVConnectorRole.SCHEDULER:
             self.request_hasher = RequestHasher(vllm_config, 0)
@@ -134,40 +138,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
             self.broadcast_stream = torch.cuda.Stream()
 
         logger.info(f"self.launch_config: {self.launch_config}")
-        connector_configs = self.launch_config.get("ucm_connectors", [])
-        assert len(connector_configs) > 0, "no storage connector name in config."
-
-        name = connector_configs[0].get("ucm_connector_name")
-        config = connector_configs[0].get("ucm_connector_config") or {}
-        config["device"] = self.local_rank
-        config["role"] = "scheduler" if role == KVConnectorRole.SCHEDULER else "worker"
-        element_size = vllm_config.model_config.dtype.itemsize
-        single_head_dim = vllm_config.model_config.get_head_size()
-        num_head_per_tp = vllm_config.model_config.get_num_kv_heads(
-            vllm_config.parallel_config
-        )
-        total_tp_size = vllm_config.parallel_config.tensor_parallel_size
-        num_layers = vllm_config.model_config.get_num_layers(
-            vllm_config.parallel_config
-        )
-        block_size_per_layer = self.block_size * element_size * single_head_dim
-        config["kv_block_size"] = (
-            block_size_per_layer
-            * num_layers
-            * (1 if self.is_mla else num_head_per_tp * 2)
-        )
-        config["io_size"] = block_size_per_layer * (
-            1 if self.is_mla else num_head_per_tp
-        )
-        self.store = UcmConnectorFactoryV1.create_connector(name, config)
-        self.block_data_size = config["kv_block_size"]
-
-        logger.info("init UCConnectorImpl, connector: %s", name)
-        logger.info(
-            "single file size = %d MB, io_size = %d KB,",
-            config["kv_block_size"] / 1024 / 1024,
-            config["io_size"] / 1024,
-        )
 
         self.metrics_config = self.launch_config.get("metrics_config_path", "")
         if self.metrics_config:
@@ -208,6 +178,78 @@ def generate_hash(self, block_size: int, request: "Request") -> list[bytes]:
 
         return ret
 
+    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
+        self.kv_caches = kv_caches
+        sample_kv_layer = next(iter(self.kv_caches.values()))
+        if isinstance(sample_kv_layer, torch.Tensor):
+            logger.info(f"kv cache shape {sample_kv_layer.shape}")
+            if self.kv_cache_dtype is None:
+                self.kv_cache_dtype = sample_kv_layer.dtype
+        elif isinstance(sample_kv_layer, Tuple):
+            # Since vllm_ascend >= 0.10.0, the MLA model's tensor shape has changed to Tuple
+            # [(num_blocks, block_size, num_kv_heads, nope_dim/rope_dim)]
+            # Currently, we treat it as GQA, and use is_dsa to mark it
+            for i, tensor in enumerate(sample_kv_layer):
+                logger.info(f"kv cache shape {i}: {tensor.shape}")
+            if self.kv_cache_dtype is None:
+                self.kv_cache_dtype = sample_kv_layer[0].dtype
+            if self.is_mla:
+                self.is_mla = False
+                self.is_dsa = True
+
+        # When handling the GQA case, we will separately dump the k_cache and v_cache.
+        connector_configs = self.launch_config.get("ucm_connectors", [])
+        assert len(connector_configs) > 0, "no storage connector name in config."
+
+        name = connector_configs[0].get("ucm_connector_name")
+        config = connector_configs[0].get("ucm_connector_config") or {}
+        config["device"] = self.local_rank
+        config["role"] = "scheduler" if self._role == KVConnectorRole.SCHEDULER else "worker"
+        if len(sample_kv_layer) == 2:
+            storage_backends = config["storage_backends"]
+            k_dir = os.path.join(storage_backends, "k")
+            v_dir = os.path.join(storage_backends, "v")
+            os.makedirs(k_dir, exist_ok=True)
+            os.makedirs(v_dir, exist_ok=True)
+            logger.info(f"Created subdirectories: {k_dir}, {v_dir}")
+
+            k_io_size = sample_kv_layer[0][0].numel() * sample_kv_layer[0][0].element_size()
+            config["io_size"] = k_io_size
+            config["kv_block_size"] = k_io_size * self.num_layers
+            config["storage_backends"] = k_dir
+            self.k_store = UcmConnectorFactoryV1.create_connector(name, config)
+            logger.info("init UCConnectorImpl, k_connector: %s", name)
+            logger.info(
+                "single file size = %d MB, io_size = %d KB,",
+                config["kv_block_size"] / 1024 / 1024,
+                config["io_size"] / 1024,
+            )
+
+            v_io_size = sample_kv_layer[1][0].numel() * sample_kv_layer[1][0].element_size()
+            config["io_size"] = v_io_size
+            config["kv_block_size"] = v_io_size * self.num_layers
+            config["storage_backends"] = v_dir
+            self.v_store = UcmConnectorFactoryV1.create_connector(name, config)
+            logger.info("init UCConnectorImpl, v_connector: %s", name)
+            logger.info(
+                "single file size = %d MB, io_size = %d KB,",
+                config["kv_block_size"] / 1024 / 1024,
+                config["io_size"] / 1024,
+            )
+            self.block_data_size = (k_io_size + v_io_size) * self.num_layers
+        else:
+            k_io_size = sample_kv_layer[0].numel() * sample_kv_layer[0].element_size()
+            config["io_size"] = k_io_size
+            config["kv_block_size"] = k_io_size * self.num_layers
+            self.k_store = UcmConnectorFactoryV1.create_connector(name, config)
+            logger.info("init UCConnectorImpl, k_connector: %s", name)
+            logger.info(
+                "single file size = %d MB, io_size = %d KB,",
+                config["kv_block_size"] / 1024 / 1024,
+                config["io_size"] / 1024,
+            )
+            self.block_data_size = k_io_size * self.num_layers
+
     def get_num_new_matched_tokens(
         self,
         request: "Request",
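
For context, a minimal sketch of the launch config shape that register_kv_caches above consumes; only the key names (ucm_connectors, ucm_connector_name, ucm_connector_config, storage_backends) come from the code, while the connector name and path below are hypothetical placeholders.

# Hypothetical launch config; key names taken from register_kv_caches above,
# the connector name and storage path are placeholders.
launch_config = {
    "ucm_connectors": [
        {
            "ucm_connector_name": "UcmExampleStore",      # hypothetical backend
            "ucm_connector_config": {
                "storage_backends": "/mnt/ucm_kv_cache",  # hypothetical root dir
            },
        }
    ],
}
# With a (k, v) tuple layout, register_kv_caches creates /mnt/ucm_kv_cache/k and
# /mnt/ucm_kv_cache/v, builds one store per directory, and sets io_size to the
# byte size of one per-layer block tensor and kv_block_size to io_size * num_layers.
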
@@ -222,7 +264,7 @@ def get_num_new_matched_tokens(
         if not external_block_ids:
             return 0, False
 
-        lookup_results = self.store.lookup(external_block_ids)
+        lookup_results = self.k_store.lookup(external_block_ids)
         external_hit_blocks = 0
         for i, hit in enumerate(lookup_results):
             if not hit:
@@ -412,15 +454,16 @@ def _get_tensors(
 
     def _generate_task(
         self, vllm_block_ids: List[int], ucm_block_ids: List[bytes]
-    ) -> Tuple[List[bytes], List[int], List[List[torch.Tensor]]]:
-        block_ids, shard_indexs, tensors = [], [], []
+    ) -> Tuple[List[bytes], List[int], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
+        block_ids, shard_indexs, total_k_tensors, total_v_tensors = [], [], [], []
         for i, vllm_block_id in enumerate(vllm_block_ids):
             k_tensors, v_tensors = self._get_tensors(vllm_block_id)
             block_ids.append(ucm_block_ids[i])
-            tensors.append(k_tensors + v_tensors)
+            total_k_tensors.append(k_tensors)
+            total_v_tensors.append(v_tensors)
             shard_indexs.append(0)
 
-        return block_ids, shard_indexs, tensors
+        return block_ids, shard_indexs, total_k_tensors, total_v_tensors
 
     def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
         rec_tensor: torch.Tensor = None
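
A hedged illustration of the new _generate_task contract introduced here; the block ids and hashes are hypothetical.

# Hypothetical call: _generate_task now returns key and value tensors in
# separate per-block lists instead of one concatenated k+v list per block.
block_ids, shard_indexs, k_tensors, v_tensors = self._generate_task(
    [3, 7],                    # hypothetical vLLM block ids
    [b"hash-a", b"hash-b"],    # hypothetical ucm block hashes
)
# k_tensors[i] and v_tensors[i] each hold one tensor per layer for block i;
# they are routed to self.k_store and self.v_store respectively.
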
@@ -447,7 +490,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
 
         self._init_kv_caches_from_forward_context(forward_context)
 
-        request_to_task: dict[str, Optional[Task]] = {}
+        request_to_task: dict[str, Optional[List[Task]]] = {}
         req_broadcast_addr = {}
         is_load = False
         num_loaded_block = 0
464507
if self.global_rank != 0 and not self.is_mla and not self.is_dsa:
465508
for i, ucm_block_id in enumerate(ucm_block_ids):
466509
ucm_block_ids[i] = self.request_hasher(ucm_block_id)
467-
block_ids, shard_indexs, tensors = self._generate_task(
510+
block_ids, shard_indexs, k_tensors, v_tensors = self._generate_task(
468511
vllm_block_ids, ucm_block_ids
469512
)
470513
if self.global_rank == 0 or not self.load_only_first_rank:
471-
request_to_task[request_id] = self.store.load(
472-
block_ids, shard_indexs, tensors
473-
)
514+
k_task = self.k_store.load(block_ids, shard_indexs, k_tensors)
515+
request_to_task[request_id] = [k_task]
516+
if v_tensors and self.v_store:
517+
v_task = self.v_store.load(block_ids, shard_indexs, v_tensors)
518+
request_to_task[request_id].append(v_task)
474519
else:
475520
request_to_task[request_id] = None
476-
req_broadcast_addr[request_id] = [t for row in tensors for t in row]
521+
req_broadcast_addr[request_id] = [t for row in k_tensors for t in row] + [t for row in v_tensors for t in row]
477522

478-
for request_id, task in request_to_task.items():
523+
for request_id, tasks in request_to_task.items():
479524
# TODO error handling
480525
if self.global_rank == 0 or not self.load_only_first_rank:
481526
try:
482-
self.store.wait(task)
527+
self.k_store.wait(tasks[0])
528+
if len(tasks) > 1 and self.v_store:
529+
self.v_store.wait(tasks[1])
483530
except RuntimeError as e:
484531
logger.error("request {request_id} load kv cache failed.:", e)
485532
self._invalid_block_ids.update(
@@ -531,7 +578,7 @@ def wait_for_save(self) -> None:
         metadata = self._get_connector_metadata()
         assert isinstance(metadata, UCMConnectorMetadata)
 
-        request_to_task: dict[str, Task] = {}
+        request_to_task: dict[str, List[Task]] = {}
        is_save = False
         num_saved_block = 0
         num_saved_request = 0
@@ -547,16 +594,25 @@ def wait_for_save(self) -> None:
             if self.global_rank != 0:
                 for i, ucm_block_id in enumerate(ucm_block_ids):
                     ucm_block_ids[i] = self.request_hasher(ucm_block_id)
-            block_ids, shard_indexs, tensors = self._generate_task(
+            block_ids, shard_indexs, k_tensors, v_tensors = self._generate_task(
                 vllm_block_ids, ucm_block_ids
             )
-            request_to_task[request_id] = self.store.dump(
-                block_ids, shard_indexs, tensors
+            k_task = self.k_store.dump(
+                block_ids, shard_indexs, k_tensors
             )
+            request_to_task[request_id] = [k_task]
+            if v_tensors and self.v_store:
+                v_task = self.v_store.dump(
+                    block_ids, shard_indexs, v_tensors
+                )
+                request_to_task[request_id].append(v_task)
+
 
-        for request_id, task in request_to_task.items():
+        for request_id, tasks in request_to_task.items():
             try:
-                self.store.wait(task)
+                self.k_store.wait(tasks[0])
+                if len(tasks) > 1 and self.v_store:
+                    self.v_store.wait(tasks[1])
             except RuntimeError as e:
                 logger.error("request {request_id} dump kv cache failed.:", e)
         save_end_time = time.perf_counter() * 1000
@@ -739,6 +795,16 @@ def update_state_after_alloc(
         """
         self.connector.update_state_after_alloc(request, blocks, num_external_tokens)
 
+    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
+        """
+        Initialize with the KV caches. Useful for pre-registering the
+        KV Caches in the KVConnector (e.g. for NIXL).
+
+        Args: kv_caches:
+            dictionary of layer names, kv cache
+        """
+        self.connector.register_kv_caches(kv_caches)
+
     def build_connector_meta(
         self, scheduler_output: SchedulerOutput
     ) -> KVConnectorMetadata:
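
A hedged usage sketch of the new register_kv_caches pass-through, assuming the tuple-per-layer layout described in the worker-side method; the layer names, shapes, and the connector variable are hypothetical.

import torch

# Hypothetical per-layer (k_cache, v_cache) layout that triggers the split
# k/v store path; shapes are (num_blocks, block_size, num_kv_heads, head_dim).
num_blocks, block_size, num_kv_heads, head_dim = 16, 128, 8, 128
kv_caches = {
    f"model.layers.{i}.self_attn.attn": (
        torch.empty(num_blocks, block_size, num_kv_heads, head_dim, dtype=torch.bfloat16),
        torch.empty(num_blocks, block_size, num_kv_heads, head_dim, dtype=torch.bfloat16),
    )
    for i in range(2)  # two layers, for illustration only
}
connector.register_kv_caches(kv_caches)  # delegates to UCConnectorImpl.register_kv_caches
# Under these hypothetical shapes, io_size per store would be
# block_size * num_kv_heads * head_dim * 2 bytes = 128 * 8 * 128 * 2 = 262144 bytes (256 KB).
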
