
Commit 49b0443

split store to k_store and v_store
Parent: e451e5b

1 file changed: 126 additions, 55 deletions

ucm/integration/vllm/ucm_connector.py
@@ -95,6 +95,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.block_size = self._vllm_config.cache_config.block_size
         self.is_mla = self._vllm_config.model_config.is_deepseek_mla
         self.is_dsa = False
+        self.num_layers = self._vllm_config.model_config.get_num_layers(
+            self._vllm_config.parallel_config
+        )
         self.kv_cache_dtype: torch.dtype = None

         if current_platform.is_cuda_alike():
@@ -111,7 +114,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         if self.local_rank >= 0:
             self.device = torch_dev.device(f"{dev_name}:{self.local_rank}")

-        self.store: UcmKVStoreBaseV1
+        self.k_store: UcmKVStoreBaseV1
+        self.v_store: Optional[UcmKVStoreBaseV1] = None

         if role == KVConnectorRole.SCHEDULER:
             self.request_hasher = RequestHasher(vllm_config, 0)
@@ -134,40 +138,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.broadcast_stream = torch.cuda.Stream()

         logger.info(f"self.launch_config: {self.launch_config}")
-        connector_configs = self.launch_config.get("ucm_connectors", [])
-        assert len(connector_configs) > 0, "no storage connector name in config."
-
-        name = connector_configs[0].get("ucm_connector_name")
-        config = connector_configs[0].get("ucm_connector_config") or {}
-        config["device"] = self.local_rank
-        config["role"] = "scheduler" if role == KVConnectorRole.SCHEDULER else "worker"
-        element_size = vllm_config.model_config.dtype.itemsize
-        single_head_dim = vllm_config.model_config.get_head_size()
-        num_head_per_tp = vllm_config.model_config.get_num_kv_heads(
-            vllm_config.parallel_config
-        )
-        total_tp_size = vllm_config.parallel_config.tensor_parallel_size
-        num_layers = vllm_config.model_config.get_num_layers(
-            vllm_config.parallel_config
-        )
-        block_size_per_layer = self.block_size * element_size * single_head_dim
-        config["kv_block_size"] = (
-            block_size_per_layer
-            * num_layers
-            * (1 if self.is_mla else num_head_per_tp * 2)
-        )
-        config["io_size"] = block_size_per_layer * (
-            1 if self.is_mla else num_head_per_tp
-        )
-        self.store = UcmConnectorFactoryV1.create_connector(name, config)
-        self.block_data_size = config["kv_block_size"]
-
-        logger.info("init UCConnectorImpl, connector: %s", name)
-        logger.info(
-            "single file size = %d MB, io_size = %d KB,",
-            config["kv_block_size"] / 1024 / 1024,
-            config["io_size"] / 1024,
-        )

         self.metrics_config = self.launch_config.get("metrics_config_path", "")
         if self.metrics_config:
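Note: the block removed above is not deleted outright; it reappears, split per store, in the new register_kv_caches in the next hunk. The only keys it reads from the launch config are "ucm_connectors", "ucm_connector_name", "ucm_connector_config" and (in the new code) "storage_backends". A minimal sketch of a matching entry; the connector name and path are placeholders, not values from this commit:

    ucm_connectors_example = [
        {
            "ucm_connector_name": "MyUcmStore",        # placeholder backend name
            "ucm_connector_config": {
                "storage_backends": "/mnt/ucm_kv",     # placeholder path; "k"/"v" subdirs are created under it
            },
        },
    ]
    # "device", "role", "io_size" and "kv_block_size" are filled in by the connector
    # itself before UcmConnectorFactoryV1.create_connector(name, config) is called.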
@@ -208,6 +178,84 @@ def generate_hash(self, block_size: int, request: "Request") -> list[bytes]:

         return ret

+    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
+        self.kv_caches = kv_caches
+        sample_kv_layer = next(iter(self.kv_caches.values()))
+        if isinstance(sample_kv_layer, torch.Tensor):
+            logger.info(f"kv cache shape {sample_kv_layer.shape}")
+            if self.kv_cache_dtype is None:
+                self.kv_cache_dtype = sample_kv_layer.dtype
+        elif isinstance(sample_kv_layer, Tuple):
+            # Since vllm_ascend >= 0.10.0, the MLA model's tensor shape has changed to Tuple
+            # [(num_blocks, block_size, num_kv_heads, nope_dim/rope_dim)]
+            # Currently, we treat it as GQA, and use is_dsa to mark it
+            for i, tensor in enumerate(sample_kv_layer):
+                logger.info(f"kv cache shape {i}: {tensor.shape}")
+            if self.kv_cache_dtype is None:
+                self.kv_cache_dtype = sample_kv_layer[0].dtype
+            if self.is_mla:
+                self.is_mla = False
+                self.is_dsa = True
+
+        # When handling the GQA case, we will separately dump the k_cache and v_cache.
+        connector_configs = self.launch_config.get("ucm_connectors", [])
+        assert len(connector_configs) > 0, "no storage connector name in config."
+
+        name = connector_configs[0].get("ucm_connector_name")
+        config = connector_configs[0].get("ucm_connector_config") or {}
+        config["device"] = self.local_rank
+        config["role"] = (
+            "scheduler" if self._role == KVConnectorRole.SCHEDULER else "worker"
+        )
+        if len(sample_kv_layer) == 2:
+            storage_backends = config["storage_backends"]
+            k_dir = os.path.join(storage_backends, "k")
+            v_dir = os.path.join(storage_backends, "v")
+            os.makedirs(k_dir, exist_ok=True)
+            os.makedirs(v_dir, exist_ok=True)
+            logger.info(f"Created subdirectories: {k_dir}, {v_dir}")
+
+            k_io_size = (
+                sample_kv_layer[0][0].numel() * sample_kv_layer[0][0].element_size()
+            )
+            config["io_size"] = k_io_size
+            config["kv_block_size"] = k_io_size * self.num_layers
+            config["storage_backends"] = k_dir
+            self.k_store = UcmConnectorFactoryV1.create_connector(name, config)
+            logger.info("init UCConnectorImpl, k_connector: %s", name)
+            logger.info(
+                "single file size = %d MB, io_size = %d KB,",
+                config["kv_block_size"] / 1024 / 1024,
+                config["io_size"] / 1024,
+            )
+
+            v_io_size = (
+                sample_kv_layer[1][0].numel() * sample_kv_layer[1][0].element_size()
+            )
+            config["io_size"] = v_io_size
+            config["kv_block_size"] = v_io_size * self.num_layers
+            config["storage_backends"] = v_dir
+            self.v_store = UcmConnectorFactoryV1.create_connector(name, config)
+            logger.info("init UCConnectorImpl, v_connector: %s", name)
+            logger.info(
+                "single file size = %d MB, io_size = %d KB,",
+                config["kv_block_size"] / 1024 / 1024,
+                config["io_size"] / 1024,
+            )
+            self.block_data_size = (k_io_size + v_io_size) * self.num_layers
+        else:
+            k_io_size = sample_kv_layer[0].numel() * sample_kv_layer[0].element_size()
+            config["io_size"] = k_io_size
+            config["kv_block_size"] = k_io_size * self.num_layers
+            self.k_store = UcmConnectorFactoryV1.create_connector(name, config)
+            logger.info("init UCConnectorImpl, k_connector: %s", name)
+            logger.info(
+                "single file size = %d MB, io_size = %d KB,",
+                config["kv_block_size"] / 1024 / 1024,
+                config["io_size"] / 1024,
+            )
+            self.block_data_size = k_io_size * self.num_layers
+
     def get_num_new_matched_tokens(
         self,
         request: "Request",
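To make the new sizing concrete, here is a rough worked example of the GQA branch; the shapes, layer count and dtype below are assumptions for illustration, not values from the commit:

    import torch

    num_blocks, block_size, num_kv_heads, head_dim, num_layers = 1024, 128, 8, 128, 32
    k_cache = torch.empty(num_blocks, block_size, num_kv_heads, head_dim, dtype=torch.float16)
    v_cache = torch.empty_like(k_cache)
    sample_kv_layer = (k_cache, v_cache)

    # One block of one layer's K cache is the I/O unit:
    k_io_size = sample_kv_layer[0][0].numel() * sample_kv_layer[0][0].element_size()
    # 128 * 8 * 128 elements * 2 bytes = 262144 B = 256 KB
    kv_block_size = k_io_size * num_layers        # 8 MB per block in the k_store
    block_data_size = 2 * k_io_size * num_layers  # 16 MB of K+V data per block overall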
@@ -222,7 +270,7 @@ def get_num_new_matched_tokens(
         if not external_block_ids:
             return 0, False

-        lookup_results = self.store.lookup(external_block_ids)
+        lookup_results = self.k_store.lookup(external_block_ids)
         external_hit_blocks = 0
         for i, hit in enumerate(lookup_results):
             if not hit:
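lookup() is asked once per candidate block and, judging from the surrounding context, returns a hit flag per block which the code then folds into a contiguous hit prefix. The loop body after the `if not hit:` is cut off by the diff, so the sketch below is an assumption about that pattern, not the exact code:

    lookup_results = [True, True, False, True]   # assumed: one flag per external block id
    external_hit_blocks = 0
    for hit in lookup_results:
        if not hit:
            break                                # stop at the first miss; only a prefix of blocks is reusable
        external_hit_blocks += 1
    # external_hit_blocks == 2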
@@ -412,15 +460,18 @@ def _get_tensors(

     def _generate_task(
         self, vllm_block_ids: List[int], ucm_block_ids: List[bytes]
-    ) -> Tuple[List[bytes], List[int], List[List[torch.Tensor]]]:
-        block_ids, shard_indexs, tensors = [], [], []
+    ) -> Tuple[
+        List[bytes], List[int], List[List[torch.Tensor]], List[List[torch.Tensor]]
+    ]:
+        block_ids, shard_indexs, total_k_tensors, total_v_tensors = [], [], [], []
         for i, vllm_block_id in enumerate(vllm_block_ids):
             k_tensors, v_tensors = self._get_tensors(vllm_block_id)
             block_ids.append(ucm_block_ids[i])
-            tensors.append(k_tensors + v_tensors)
+            total_k_tensors.append(k_tensors)
+            total_v_tensors.append(v_tensors)
             shard_indexs.append(0)

-        return block_ids, shard_indexs, tensors
+        return block_ids, shard_indexs, total_k_tensors, total_v_tensors

     def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
         rec_tensor: torch.Tensor = None
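_generate_task now keeps K and V apart: one inner list of per-layer tensors per block, for each of the two stores. A short sketch of how a caller consumes the four-tuple (the names come from the diff; the assertion is illustrative):

    block_ids, shard_indexs, k_tensors, v_tensors = self._generate_task(vllm_block_ids, ucm_block_ids)
    # k_tensors[i] / v_tensors[i] hold the per-layer K / V views of the i-th block,
    # so they can be handed to self.k_store and self.v_store independently.
    assert len(block_ids) == len(shard_indexs) == len(k_tensors) == len(v_tensors)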
@@ -447,7 +498,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:

         self._init_kv_caches_from_forward_context(forward_context)

-        request_to_task: dict[str, Optional[Task]] = {}
+        request_to_task: dict[str, Optional[List[Task]]] = {}
         req_broadcast_addr = {}
         is_load = False
         num_loaded_block = 0
@@ -464,22 +515,28 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
             if self.global_rank != 0 and not self.is_mla and not self.is_dsa:
                 for i, ucm_block_id in enumerate(ucm_block_ids):
                     ucm_block_ids[i] = self.request_hasher(ucm_block_id)
-            block_ids, shard_indexs, tensors = self._generate_task(
+            block_ids, shard_indexs, k_tensors, v_tensors = self._generate_task(
                 vllm_block_ids, ucm_block_ids
             )
             if self.global_rank == 0 or not self.load_only_first_rank:
-                request_to_task[request_id] = self.store.load(
-                    block_ids, shard_indexs, tensors
-                )
+                k_task = self.k_store.load(block_ids, shard_indexs, k_tensors)
+                request_to_task[request_id] = [k_task]
+                if v_tensors and self.v_store:
+                    v_task = self.v_store.load(block_ids, shard_indexs, v_tensors)
+                    request_to_task[request_id].append(v_task)
             else:
                 request_to_task[request_id] = None
-            req_broadcast_addr[request_id] = [t for row in tensors for t in row]
+            req_broadcast_addr[request_id] = [t for row in k_tensors for t in row] + [
+                t for row in v_tensors for t in row
+            ]

-        for request_id, task in request_to_task.items():
+        for request_id, tasks in request_to_task.items():
             # TODO error handling
             if self.global_rank == 0 or not self.load_only_first_rank:
                 try:
-                    self.store.wait(task)
+                    self.k_store.wait(tasks[0])
+                    if len(tasks) > 1 and self.v_store:
+                        self.v_store.wait(tasks[1])
                 except RuntimeError as e:
                     logger.error("request {request_id} load kv cache failed.:", e)
                     self._invalid_block_ids.update(
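After this change each request maps to a small list of tasks rather than a single one. A simplified restatement of the shape that start_load_kv builds and of how the wait loop drains it, under the assumption (consistent with register_kv_caches above) that v_store is only set when the KV cache arrived as a (k, v) tuple:

    # request_to_task after the loads are issued:
    #   {"req-0": [k_task, v_task]}   # k_store and v_store both in use
    #   {"req-1": [k_task]}           # v_store is None (e.g. MLA layout)
    #   {"req-2": None}               # rank that skips loading when load_only_first_rank is set
    for tasks in request_to_task.values():
        if tasks is None:
            continue
        self.k_store.wait(tasks[0])
        if len(tasks) > 1 and self.v_store:
            self.v_store.wait(tasks[1])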
@@ -531,7 +588,7 @@ def wait_for_save(self) -> None:
         metadata = self._get_connector_metadata()
         assert isinstance(metadata, UCMConnectorMetadata)

-        request_to_task: dict[str, Task] = {}
+        request_to_task: dict[str, List[Task]] = {}
         is_save = False
         num_saved_block = 0
         num_saved_request = 0
@@ -547,16 +604,20 @@ def wait_for_save(self) -> None:
             if self.global_rank != 0:
                 for i, ucm_block_id in enumerate(ucm_block_ids):
                     ucm_block_ids[i] = self.request_hasher(ucm_block_id)
-            block_ids, shard_indexs, tensors = self._generate_task(
+            block_ids, shard_indexs, k_tensors, v_tensors = self._generate_task(
                 vllm_block_ids, ucm_block_ids
             )
-            request_to_task[request_id] = self.store.dump(
-                block_ids, shard_indexs, tensors
-            )
+            k_task = self.k_store.dump(block_ids, shard_indexs, k_tensors)
+            request_to_task[request_id] = [k_task]
+            if v_tensors and self.v_store:
+                v_task = self.v_store.dump(block_ids, shard_indexs, v_tensors)
+                request_to_task[request_id].append(v_task)

-        for request_id, task in request_to_task.items():
+        for request_id, tasks in request_to_task.items():
             try:
-                self.store.wait(task)
+                self.k_store.wait(tasks[0])
+                if len(tasks) > 1 and self.v_store:
+                    self.v_store.wait(tasks[1])
             except RuntimeError as e:
                 logger.error("request {request_id} dump kv cache failed.:", e)
         save_end_time = time.perf_counter() * 1000
@@ -739,6 +800,16 @@ def update_state_after_alloc(
         """
         self.connector.update_state_after_alloc(request, blocks, num_external_tokens)

+    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
+        """
+        Initialize with the KV caches. Useful for pre-registering the
+        KV Caches in the KVConnector (e.g. for NIXL).
+
+        Args: kv_caches:
+            dictionary of layer names, kv cache
+        """
+        self.connector.register_kv_caches(kv_caches)
+
     def build_connector_meta(
         self, scheduler_output: SchedulerOutput
     ) -> KVConnectorMetadata:
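For context on the new wrapper method: vLLM hands register_kv_caches a mapping from layer name to that layer's KV cache, which the implementation above probes via next(iter(...)). A minimal sketch of the two shapes the new code distinguishes; the layer names, shapes and the `connector` instance are illustrative assumptions, not taken from the commit:

    import torch

    # (k_cache, v_cache) tuple per layer -> len(sample) == 2, so k_store and v_store are both created
    kv_caches = {
        "model.layers.0.attn": (
            torch.empty(1024, 128, 8, 128, dtype=torch.float16),   # K: (num_blocks, block_size, num_kv_heads, head_dim)
            torch.empty(1024, 128, 8, 128, dtype=torch.float16),   # V
        ),
    }
    connector.register_kv_caches(kv_caches)

    # Single tensor per layer whose leading dim is not 2 (e.g. an MLA latent cache) -> only k_store is created
    kv_caches_mla = {
        "model.layers.0.attn": torch.empty(1024, 128, 1, 576, dtype=torch.float16),
    }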
