@@ -117,17 +117,14 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
117117 self .k_store : UcmKVStoreBaseV1
118118 self .v_store : Optional [UcmKVStoreBaseV1 ] = None
119119
120- if role == KVConnectorRole .SCHEDULER :
121- self .request_hasher = RequestHasher (vllm_config , 0 )
122- else :
123- self .request_hasher = RequestHasher (vllm_config , self .global_rank )
124-
125120 # save block info, avoid hash request twice, and track them until request finished
126121 self .requests_meta : dict [str , RequestMeta ] = {}
127122
128123 ucm_config = Config (vllm_config .kv_transfer_config )
129124 self .launch_config = ucm_config .get_config ()
130-
125+ logger .info (f"self.launch_config: { self .launch_config } " )
126+ self .connector_configs = self .launch_config .get ("ucm_connectors" , [])
127+ assert len (self .connector_configs ) > 0 , "no storage connector name in config."
131128 self .load_only_first_rank : bool = (
132129 self .launch_config .get ("load_only_first_rank" , self .is_mla ) and self .is_mla
133130 )
@@ -137,7 +134,27 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
137134 self .broadcast_fn = self .group_coordinator .broadcast
138135 self .broadcast_stream = torch .cuda .Stream ()
139136
140- logger .info (f"self.launch_config: { self .launch_config } " )
137+ name = self .connector_configs [0 ].get ("ucm_connector_name" )
138+ config = self .connector_configs [0 ].get ("ucm_connector_config" ) or {}
139+ storage_backends = [
140+ path for path in config ["storage_backends" ].split (":" ) if path
141+ ]
142+ self .k_storage_backends = [os .path .join (p , "k" ) for p in storage_backends ]
143+ self .v_storage_backends = [os .path .join (p , "v" ) for p in storage_backends ]
144+ os .makedirs (self .k_storage_backends [0 ], exist_ok = True )
145+ os .makedirs (self .v_storage_backends [0 ], exist_ok = True )
146+ logger .info (
147+ f"Created subdirectories: { self .k_storage_backends } , { self .v_storage_backends } "
148+ )
149+
150+ if role == KVConnectorRole .SCHEDULER :
151+ self .request_hasher = RequestHasher (vllm_config , 0 )
152+ # init scheduler-size connector
153+ config ["storage_backends" ] = ":" .join (self .k_storage_backends )
154+ config ["role" ] = "scheduler"
155+ self .k_store = UcmConnectorFactoryV1 .create_connector (name , config )
156+ else :
157+ self .request_hasher = RequestHasher (vllm_config , self .global_rank )
141158
142159 self .metrics_config = self .launch_config .get ("metrics_config_path" , "" )
143160 if self .metrics_config :
@@ -195,30 +212,19 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
195212 self .is_mla = False
196213 self .is_dsa = True
197214
215+ # init work-side connector
198216 # When handling the GQA case, we will separately dump the k_cache and v_cache.
199- connector_configs = self .launch_config .get ("ucm_connectors" , [])
200- assert len (connector_configs ) > 0 , "no storage connector name in config."
201-
202- name = connector_configs [0 ].get ("ucm_connector_name" )
203- config = connector_configs [0 ].get ("ucm_connector_config" ) or {}
217+ name = self .connector_configs [0 ].get ("ucm_connector_name" )
218+ config = self .connector_configs [0 ].get ("ucm_connector_config" ) or {}
204219 config ["device" ] = self .local_rank
205- config ["role" ] = (
206- "scheduler" if self ._role == KVConnectorRole .SCHEDULER else "worker"
207- )
220+ config ["role" ] = "worker"
208221 if len (sample_kv_layer ) == 2 :
209- storage_backends = config ["storage_backends" ]
210- k_dir = os .path .join (storage_backends , "k" )
211- v_dir = os .path .join (storage_backends , "v" )
212- os .makedirs (k_dir , exist_ok = True )
213- os .makedirs (v_dir , exist_ok = True )
214- logger .info (f"Created subdirectories: { k_dir } , { v_dir } " )
215-
216222 k_io_size = (
217223 sample_kv_layer [0 ][0 ].numel () * sample_kv_layer [0 ][0 ].element_size ()
218224 )
219225 config ["io_size" ] = k_io_size
220226 config ["kv_block_size" ] = k_io_size * self .num_layers
221- config ["storage_backends" ] = k_dir
227+ config ["storage_backends" ] = ":" . join ( self . k_storage_backends )
222228 self .k_store = UcmConnectorFactoryV1 .create_connector (name , config )
223229 logger .info ("init UCConnectorImpl, k_connector: %s" , name )
224230 logger .info (
@@ -232,7 +238,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
232238 )
233239 config ["io_size" ] = v_io_size
234240 config ["kv_block_size" ] = v_io_size * self .num_layers
235- config ["storage_backends" ] = v_dir
241+ config ["storage_backends" ] = ":" . join ( self . v_storage_backends )
236242 self .v_store = UcmConnectorFactoryV1 .create_connector (name , config )
237243 logger .info ("init UCConnectorImpl, v_connector: %s" , name )
238244 logger .info (
@@ -245,6 +251,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
245251 k_io_size = sample_kv_layer [0 ].numel () * sample_kv_layer [0 ].element_size ()
246252 config ["io_size" ] = k_io_size
247253 config ["kv_block_size" ] = k_io_size * self .num_layers
254+ config ["storage_backends" ] = ":" .join (self .k_storage_backends )
248255 self .k_store = UcmConnectorFactoryV1 .create_connector (name , config )
249256 logger .info ("init UCConnectorImpl, k_connector: %s" , name )
250257 logger .info (
0 commit comments