Commit a848420

add scheduler side connector

1 parent 9affb6c, commit a848420

2 files changed: +28 lines, -25 lines
ucm/integration/vllm/ucm_connector.py (27 additions, 24 deletions)
@@ -117,17 +117,14 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.k_store: UcmKVStoreBaseV1
         self.v_store: Optional[UcmKVStoreBaseV1] = None
 
-        if role == KVConnectorRole.SCHEDULER:
-            self.request_hasher = RequestHasher(vllm_config, 0)
-        else:
-            self.request_hasher = RequestHasher(vllm_config, self.global_rank)
-
         # save block info, avoid hash request twice, and track them until request finished
         self.requests_meta: dict[str, RequestMeta] = {}
 
         ucm_config = Config(vllm_config.kv_transfer_config)
         self.launch_config = ucm_config.get_config()
-
+        logger.info(f"self.launch_config: {self.launch_config}")
+        self.connector_configs = self.launch_config.get("ucm_connectors", [])
+        assert len(self.connector_configs) > 0, "no storage connector name in config."
         self.load_only_first_rank: bool = (
             self.launch_config.get("load_only_first_rank", self.is_mla) and self.is_mla
         )
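
Note: the connector lookup now happens in __init__, so the launch config must already carry a non-empty ucm_connectors list at construction time. A minimal sketch of the shape this code reads, inferred only from the keys accessed in this diff (the connector name and backend paths below are placeholders, not values from this repo):

    # Hypothetical launch_config fragment; only keys read in this diff are shown.
    launch_config = {
        "ucm_connectors": [
            {
                "ucm_connector_name": "PcStore",                # placeholder name
                "ucm_connector_config": {
                    "storage_backends": "/mnt/ssd0:/mnt/ssd1",  # colon-separated roots
                },
            }
        ],
        "load_only_first_rank": False,
        "metrics_config_path": "",
    }

    connector_configs = launch_config.get("ucm_connectors", [])
    assert len(connector_configs) > 0, "no storage connector name in config."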
@@ -136,8 +133,24 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
             self.group_coordinator = get_tp_group()
             self.broadcast_fn = self.group_coordinator.broadcast
             self.broadcast_stream = torch.cuda.Stream()
+
+        name = self.connector_configs[0].get("ucm_connector_name")
+        config = self.connector_configs[0].get("ucm_connector_config") or {}
+        storage_backends = [path for path in config["storage_backends"].split(":") if path]
+        self.k_storage_backends = [os.path.join(p, "k") for p in storage_backends]
+        self.v_storage_backends = [os.path.join(p, "v") for p in storage_backends]
+        os.makedirs(self.k_storage_backends[0], exist_ok=True)
+        os.makedirs(self.v_storage_backends[0], exist_ok=True)
+        logger.info(f"Created subdirectories: {self.k_storage_backends}, {self.v_storage_backends}")
 
-        logger.info(f"self.launch_config: {self.launch_config}")
+        if role == KVConnectorRole.SCHEDULER:
+            self.request_hasher = RequestHasher(vllm_config, 0)
+            # init scheduler-side connector
+            config["storage_backends"] = ":".join(self.k_storage_backends)
+            config["role"] = "scheduler"
+            self.k_store = UcmConnectorFactoryV1.create_connector(name, config)
+        else:
+            self.request_hasher = RequestHasher(vllm_config, self.global_rank)
 
         self.metrics_config = self.launch_config.get("metrics_config_path", "")
         if self.metrics_config:
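
For reference, a standalone sketch of the path handling the new __init__ code performs: each colon-separated backend root gets "k" and "v" subdirectories, and the scheduler-side connector is handed only the rejoined "k" paths. The backend paths below are placeholders.

    import os

    def split_backends(storage_backends: str) -> tuple[list[str], list[str]]:
        """Derive per-K and per-V subdirectory lists from a colon-separated backend string."""
        roots = [p for p in storage_backends.split(":") if p]
        k_dirs = [os.path.join(p, "k") for p in roots]
        v_dirs = [os.path.join(p, "v") for p in roots]
        return k_dirs, v_dirs

    # Placeholder backend roots, for illustration only.
    k_dirs, v_dirs = split_backends("/mnt/ssd0:/mnt/ssd1")
    # k_dirs == ['/mnt/ssd0/k', '/mnt/ssd1/k'], v_dirs == ['/mnt/ssd0/v', '/mnt/ssd1/v']

    # Scheduler-side connector config, mirroring the added branch above.
    scheduler_cfg = {"storage_backends": ":".join(k_dirs), "role": "scheduler"}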
@@ -195,30 +208,19 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             self.is_mla = False
             self.is_dsa = True
 
+        # init worker-side connector
         # When handling the GQA case, we will separately dump the k_cache and v_cache.
-        connector_configs = self.launch_config.get("ucm_connectors", [])
-        assert len(connector_configs) > 0, "no storage connector name in config."
-
-        name = connector_configs[0].get("ucm_connector_name")
-        config = connector_configs[0].get("ucm_connector_config") or {}
+        name = self.connector_configs[0].get("ucm_connector_name")
+        config = self.connector_configs[0].get("ucm_connector_config") or {}
         config["device"] = self.local_rank
-        config["role"] = (
-            "scheduler" if self._role == KVConnectorRole.SCHEDULER else "worker"
-        )
+        config["role"] = "worker"
         if len(sample_kv_layer) == 2:
-            storage_backends = config["storage_backends"]
-            k_dir = os.path.join(storage_backends, "k")
-            v_dir = os.path.join(storage_backends, "v")
-            os.makedirs(k_dir, exist_ok=True)
-            os.makedirs(v_dir, exist_ok=True)
-            logger.info(f"Created subdirectories: {k_dir}, {v_dir}")
-
             k_io_size = (
                 sample_kv_layer[0][0].numel() * sample_kv_layer[0][0].element_size()
             )
             config["io_size"] = k_io_size
             config["kv_block_size"] = k_io_size * self.num_layers
-            config["storage_backends"] = k_dir
+            config["storage_backends"] = ":".join(self.k_storage_backends)
             self.k_store = UcmConnectorFactoryV1.create_connector(name, config)
             logger.info("init UCConnectorImpl, k_connector: %s", name)
             logger.info(
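
The sizing in this branch follows directly from one sampled KV layer: io_size is taken as the byte size of the sampled K (or V) tensor, and kv_block_size multiplies that by the layer count. A minimal sketch with illustrative values only (shape, dtype, and layer count are made up, not taken from this repo):

    import torch

    num_layers = 32                                          # illustrative layer count
    # Stand-in for sample_kv_layer[0][0]: the sampled K tensor (shape and dtype are made up).
    sample_k = torch.empty(128, 8, 128, dtype=torch.bfloat16)

    k_io_size = sample_k.numel() * sample_k.element_size()   # bytes of the sampled K tensor
    kv_block_size = k_io_size * num_layers                   # scaled across all layers
    # With these values: k_io_size = 262144 bytes, kv_block_size = 8388608 bytes (8 MiB)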
@@ -232,7 +234,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             )
             config["io_size"] = v_io_size
             config["kv_block_size"] = v_io_size * self.num_layers
-            config["storage_backends"] = v_dir
+            config["storage_backends"] = ":".join(self.v_storage_backends)
             self.v_store = UcmConnectorFactoryV1.create_connector(name, config)
             logger.info("init UCConnectorImpl, v_connector: %s", name)
             logger.info(
@@ -245,6 +247,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             k_io_size = sample_kv_layer[0].numel() * sample_kv_layer[0].element_size()
             config["io_size"] = k_io_size
             config["kv_block_size"] = k_io_size * self.num_layers
+            config["storage_backends"] = ":".join(self.k_storage_backends)
             self.k_store = UcmConnectorFactoryV1.create_connector(name, config)
             logger.info("init UCConnectorImpl, k_connector: %s", name)
             logger.info(

ucm/store/pcstore/pcstore_connector.py (1 addition, 1 deletion)
@@ -43,7 +43,7 @@ def __init__(self, config: Dict):
         storage_backends = [
             path for path in config["storage_backends"].split(":") if path
         ]
-        block_size = int(config["kv_block_size"])
+        block_size = config.get("kv_block_size", 33554432)
         transfer_enable = True if config["role"] == "worker" else False
         param = ucmpcstore.PcStore.Config(storage_backends, block_size, transfer_enable)
        if transfer_enable:
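
The pcstore change makes kv_block_size optional with a 32 MiB default (33554432 bytes). This fits the scheduler-side connector created in __init__ above: that config carries only storage_backends and role, so unless the launch config supplies kv_block_size itself, the previous int(config["kv_block_size"]) would have raised a KeyError. A minimal sketch of the fallback (placeholder path; the surrounding ucmpcstore call is omitted):

    # Scheduler-side config as built in __init__ above; the path is a placeholder.
    scheduler_cfg = {"storage_backends": "/mnt/ssd0/k", "role": "scheduler"}

    block_size = scheduler_cfg.get("kv_block_size", 33554432)   # missing key -> 32 MiB default
    transfer_enable = scheduler_cfg["role"] == "worker"          # scheduler side: False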
