@@ -117,17 +117,14 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
117117 self .k_store : UcmKVStoreBaseV1
118118 self .v_store : Optional [UcmKVStoreBaseV1 ] = None
119119
120- if role == KVConnectorRole .SCHEDULER :
121- self .request_hasher = RequestHasher (vllm_config , 0 )
122- else :
123- self .request_hasher = RequestHasher (vllm_config , self .global_rank )
124-
125120 # save block info so a request is not hashed twice, and track it until the request finishes
126121 self .requests_meta : dict [str , RequestMeta ] = {}
127122
128123 ucm_config = Config (vllm_config .kv_transfer_config )
129124 self .launch_config = ucm_config .get_config ()
130-
125+ logger .info (f"self.launch_config: { self .launch_config } " )
126+ self .connector_configs = self .launch_config .get ("ucm_connectors" , [])
127+ assert len (self .connector_configs ) > 0 , "no storage connector name in config."
131128 self .load_only_first_rank : bool = (
132129 self .launch_config .get ("load_only_first_rank" , self .is_mla ) and self .is_mla
133130 )
@@ -136,8 +133,24 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
136133 self .group_coordinator = get_tp_group ()
137134 self .broadcast_fn = self .group_coordinator .broadcast
138135 self .broadcast_stream = torch .cuda .Stream ()
136+
137+ name = self .connector_configs [0 ].get ("ucm_connector_name" )
138+ config = self .connector_configs [0 ].get ("ucm_connector_config" ) or {}
139+ storage_backends = [path for path in config ["storage_backends" ].split (":" ) if path ]
140+ self .k_storage_backends = [os .path .join (p , "k" ) for p in storage_backends ]
141+ self .v_storage_backends = [os .path .join (p , "v" ) for p in storage_backends ]
142+ os .makedirs (self .k_storage_backends [0 ], exist_ok = True )
143+ os .makedirs (self .v_storage_backends [0 ], exist_ok = True )
144+ logger .info (f"Created subdirectories: { self .k_storage_backends } , { self .v_storage_backends } " )
139145
140- logger .info (f"self.launch_config: { self .launch_config } " )
146+ if role == KVConnectorRole .SCHEDULER :
147+ self .request_hasher = RequestHasher (vllm_config , 0 )
148+ # init scheduler-side connector
149+ config ["storage_backends" ] = ":" .join (self .k_storage_backends )
150+ config ["role" ] = "scheduler"
151+ self .k_store = UcmConnectorFactoryV1 .create_connector (name , config )
152+ else :
153+ self .request_hasher = RequestHasher (vllm_config , self .global_rank )
141154
142155 self .metrics_config = self .launch_config .get ("metrics_config_path" , "" )
143156 if self .metrics_config :
@@ -195,30 +208,19 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
195208 self .is_mla = False
196209 self .is_dsa = True
197210
211+ # init worker-side connector
198212 # When handling the GQA case, we will separately dump the k_cache and v_cache.
199- connector_configs = self .launch_config .get ("ucm_connectors" , [])
200- assert len (connector_configs ) > 0 , "no storage connector name in config."
201-
202- name = connector_configs [0 ].get ("ucm_connector_name" )
203- config = connector_configs [0 ].get ("ucm_connector_config" ) or {}
213+ name = self .connector_configs [0 ].get ("ucm_connector_name" )
214+ config = self .connector_configs [0 ].get ("ucm_connector_config" ) or {}
204215 config ["device" ] = self .local_rank
205- config ["role" ] = (
206- "scheduler" if self ._role == KVConnectorRole .SCHEDULER else "worker"
207- )
216+ config ["role" ] = "worker"
208217 if len (sample_kv_layer ) == 2 :
209- storage_backends = config ["storage_backends" ]
210- k_dir = os .path .join (storage_backends , "k" )
211- v_dir = os .path .join (storage_backends , "v" )
212- os .makedirs (k_dir , exist_ok = True )
213- os .makedirs (v_dir , exist_ok = True )
214- logger .info (f"Created subdirectories: { k_dir } , { v_dir } " )
215-
216218 k_io_size = (
217219 sample_kv_layer [0 ][0 ].numel () * sample_kv_layer [0 ][0 ].element_size ()
218220 )
219221 config ["io_size" ] = k_io_size
220222 config ["kv_block_size" ] = k_io_size * self .num_layers
221- config ["storage_backends" ] = k_dir
223+ config ["storage_backends" ] = ":" . join ( self . k_storage_backends )
222224 self .k_store = UcmConnectorFactoryV1 .create_connector (name , config )
223225 logger .info ("init UCConnectorImpl, k_connector: %s" , name )
224226 logger .info (
@@ -232,7 +234,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
232234 )
233235 config ["io_size" ] = v_io_size
234236 config ["kv_block_size" ] = v_io_size * self .num_layers
235- config ["storage_backends" ] = v_dir
237+ config ["storage_backends" ] = ":" . join ( self . v_storage_backends )
236238 self .v_store = UcmConnectorFactoryV1 .create_connector (name , config )
237239 logger .info ("init UCConnectorImpl, v_connector: %s" , name )
238240 logger .info (
@@ -245,6 +247,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
245247 k_io_size = sample_kv_layer [0 ].numel () * sample_kv_layer [0 ].element_size ()
246248 config ["io_size" ] = k_io_size
247249 config ["kv_block_size" ] = k_io_size * self .num_layers
250+ config ["storage_backends" ] = ":" .join (self .k_storage_backends )
248251 self .k_store = UcmConnectorFactoryV1 .create_connector (name , config )
249252 logger .info ("init UCConnectorImpl, k_connector: %s" , name )
250253 logger .info (
0 commit comments