Commit fac15c0

add scheduler side connector
1 parent 9affb6c commit fac15c0

File tree

2 files changed: +32 -25 lines changed

ucm/integration/vllm/ucm_connector.py

Lines changed: 31 additions & 24 deletions
@@ -117,17 +117,14 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.k_store: UcmKVStoreBaseV1
         self.v_store: Optional[UcmKVStoreBaseV1] = None

-        if role == KVConnectorRole.SCHEDULER:
-            self.request_hasher = RequestHasher(vllm_config, 0)
-        else:
-            self.request_hasher = RequestHasher(vllm_config, self.global_rank)
-
         # save block info, avoid hash request twice, and track them until request finished
         self.requests_meta: dict[str, RequestMeta] = {}

         ucm_config = Config(vllm_config.kv_transfer_config)
         self.launch_config = ucm_config.get_config()
-
+        logger.info(f"self.launch_config: {self.launch_config}")
+        self.connector_configs = self.launch_config.get("ucm_connectors", [])
+        assert len(self.connector_configs) > 0, "no storage connector name in config."
         self.load_only_first_rank: bool = (
             self.launch_config.get("load_only_first_rank", self.is_mla) and self.is_mla
         )
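
For context, the assertion added here expects at least one entry under ucm_connectors in the launch config. A minimal sketch of what such an entry could look like, based only on the keys read in this diff; the connector name and backend paths are hypothetical:

# Hypothetical launch-config fragment; key names mirror the lookups above,
# the connector name and backend paths are illustrative only.
launch_config = {
    "ucm_connectors": [
        {
            "ucm_connector_name": "UcmPcStore",  # hypothetical registered name
            "ucm_connector_config": {
                # colon-separated backend paths, later split with split(":")
                "storage_backends": "/mnt/ssd0/ucm:/mnt/ssd1/ucm",
            },
        }
    ],
}

connector_configs = launch_config.get("ucm_connectors", [])
assert len(connector_configs) > 0, "no storage connector name in config."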
@@ -137,7 +134,27 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
             self.broadcast_fn = self.group_coordinator.broadcast
             self.broadcast_stream = torch.cuda.Stream()

-        logger.info(f"self.launch_config: {self.launch_config}")
+        name = self.connector_configs[0].get("ucm_connector_name")
+        config = self.connector_configs[0].get("ucm_connector_config") or {}
+        storage_backends = [
+            path for path in config["storage_backends"].split(":") if path
+        ]
+        self.k_storage_backends = [os.path.join(p, "k") for p in storage_backends]
+        self.v_storage_backends = [os.path.join(p, "v") for p in storage_backends]
+        os.makedirs(self.k_storage_backends[0], exist_ok=True)
+        os.makedirs(self.v_storage_backends[0], exist_ok=True)
+        logger.info(
+            f"Created subdirectories: {self.k_storage_backends}, {self.v_storage_backends}"
+        )
+
+        if role == KVConnectorRole.SCHEDULER:
+            self.request_hasher = RequestHasher(vllm_config, 0)
+            # init scheduler-side connector
+            config["storage_backends"] = ":".join(self.k_storage_backends)
+            config["role"] = "scheduler"
+            self.k_store = UcmConnectorFactoryV1.create_connector(name, config)
+        else:
+            self.request_hasher = RequestHasher(vllm_config, self.global_rank)

         self.metrics_config = self.launch_config.get("metrics_config_path", "")
         if self.metrics_config:
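
The effect of the new splitting logic is easier to see in isolation. A standalone sketch of the same transformation with made-up paths; as in the diff, only the first backend's subdirectories are created, and the scheduler-side config carries no io_size or kv_block_size:

import os

# Illustrative input; in the connector this comes from ucm_connector_config.
raw_backends = "/mnt/ssd0/ucm:/mnt/ssd1/ucm"
storage_backends = [path for path in raw_backends.split(":") if path]

# Per-backend "k" and "v" subdirectories, as built in __init__ above.
k_storage_backends = [os.path.join(p, "k") for p in storage_backends]
v_storage_backends = [os.path.join(p, "v") for p in storage_backends]
# -> ['/mnt/ssd0/ucm/k', '/mnt/ssd1/ucm/k'] and ['/mnt/ssd0/ucm/v', '/mnt/ssd1/ucm/v']

# Scheduler-side store: only the "k" directories, tagged with the scheduler role.
scheduler_config = {
    "storage_backends": ":".join(k_storage_backends),
    "role": "scheduler",
}
print(scheduler_config)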
@@ -195,30 +212,19 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             self.is_mla = False
             self.is_dsa = True

+        # init worker-side connector
         # When handling the GQA case, we will separately dump the k_cache and v_cache.
-        connector_configs = self.launch_config.get("ucm_connectors", [])
-        assert len(connector_configs) > 0, "no storage connector name in config."
-
-        name = connector_configs[0].get("ucm_connector_name")
-        config = connector_configs[0].get("ucm_connector_config") or {}
+        name = self.connector_configs[0].get("ucm_connector_name")
+        config = self.connector_configs[0].get("ucm_connector_config") or {}
         config["device"] = self.local_rank
-        config["role"] = (
-            "scheduler" if self._role == KVConnectorRole.SCHEDULER else "worker"
-        )
+        config["role"] = "worker"
         if len(sample_kv_layer) == 2:
-            storage_backends = config["storage_backends"]
-            k_dir = os.path.join(storage_backends, "k")
-            v_dir = os.path.join(storage_backends, "v")
-            os.makedirs(k_dir, exist_ok=True)
-            os.makedirs(v_dir, exist_ok=True)
-            logger.info(f"Created subdirectories: {k_dir}, {v_dir}")
-
             k_io_size = (
                 sample_kv_layer[0][0].numel() * sample_kv_layer[0][0].element_size()
             )
             config["io_size"] = k_io_size
             config["kv_block_size"] = k_io_size * self.num_layers
-            config["storage_backends"] = k_dir
+            config["storage_backends"] = ":".join(self.k_storage_backends)
             self.k_store = UcmConnectorFactoryV1.create_connector(name, config)
             logger.info("init UCConnectorImpl, k_connector: %s", name)
             logger.info(
@@ -232,7 +238,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             )
             config["io_size"] = v_io_size
             config["kv_block_size"] = v_io_size * self.num_layers
-            config["storage_backends"] = v_dir
+            config["storage_backends"] = ":".join(self.v_storage_backends)
             self.v_store = UcmConnectorFactoryV1.create_connector(name, config)
             logger.info("init UCConnectorImpl, v_connector: %s", name)
             logger.info(
@@ -245,6 +251,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             k_io_size = sample_kv_layer[0].numel() * sample_kv_layer[0].element_size()
             config["io_size"] = k_io_size
             config["kv_block_size"] = k_io_size * self.num_layers
+            config["storage_backends"] = ":".join(self.k_storage_backends)
             self.k_store = UcmConnectorFactoryV1.create_connector(name, config)
             logger.info("init UCConnectorImpl, k_connector: %s", name)
             logger.info(
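
On the worker side, the per-block I/O size and the store block size are derived from one sample KV-cache layer. A rough sketch of that arithmetic, assuming a GQA-style layout where the layer is a (key, value) pair of block-major tensors; the shapes and layer count are illustrative, not vLLM's actual ones:

import torch

num_layers = 32  # illustrative

# Assumed layout: (num_blocks, block_size, num_kv_heads, head_dim) per cache.
k_cache = torch.empty(1024, 16, 8, 128, dtype=torch.float16)
v_cache = torch.empty(1024, 16, 8, 128, dtype=torch.float16)
sample_kv_layer = (k_cache, v_cache)

# Bytes of one block of the key cache -> per-block I/O size.
k_io_size = sample_kv_layer[0][0].numel() * sample_kv_layer[0][0].element_size()
# A logical KV block spans every layer, so the store block size scales with num_layers.
kv_block_size = k_io_size * num_layers
print(k_io_size, kv_block_size)  # 32768 1048576 with these shapes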

ucm/store/pcstore/pcstore_connector.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def __init__(self, config: Dict):
         storage_backends = [
             path for path in config["storage_backends"].split(":") if path
         ]
-        block_size = int(config["kv_block_size"])
+        block_size = config.get("kv_block_size", 33554432)
         transfer_enable = True if config["role"] == "worker" else False
         param = ucmpcstore.PcStore.Config(storage_backends, block_size, transfer_enable)
         if transfer_enable:
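
The new fallback matters because the scheduler-side config built in __init__ never sets kv_block_size; only the worker path computes it from a sample layer. 33554432 bytes is 32 MiB. A small sketch of how the changed line behaves with such a config (values illustrative):

# 32 MiB default when the config (e.g. the scheduler-side one) omits kv_block_size.
assert 33554432 == 32 * 1024 * 1024

config = {"storage_backends": "/mnt/ssd0/ucm/k", "role": "scheduler"}
block_size = config.get("kv_block_size", 33554432)
transfer_enable = True if config["role"] == "worker" else False
print(block_size, transfer_enable)  # 33554432 False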
