Skip to content

Commit 922e7dc

Browse files
committed
modify local_rank_size
1 parent fac15c0 commit 922e7dc

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

ucm/integration/vllm/ucm_connector.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
9898
self.num_layers = self._vllm_config.model_config.get_num_layers(
9999
self._vllm_config.parallel_config
100100
)
101+
self.tp_size = self._vllm_config.parallel_config.tensor_parallel_size
101102
self.kv_cache_dtype: torch.dtype = None
102103

103104
if current_platform.is_cuda_alike():
@@ -218,6 +219,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
218219
config = self.connector_configs[0].get("ucm_connector_config") or {}
219220
config["device"] = self.local_rank
220221
config["role"] = "worker"
222+
config["local_rank_size"] = self.tp_size if self.is_mla or self.is_dsa else 1
221223
if len(sample_kv_layer) == 2:
222224
k_io_size = (
223225
sample_kv_layer[0][0].numel() * sample_kv_layer[0][0].element_size()

ucm/store/pcstore/pcstore_connector.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ def __init__(self, config: Dict):
5252
param.transferIoDirect = config.get("use_direct", False)
5353
param.transferStreamNumber = config.get("stream_number", 8)
5454
param.transferBufferNumber = config.get("buffer_number", 4096)
55-
param.transferLocalRankSize = config.get("local_rank_size", 1)
56-
param.transferScatterGatherEnable = config.get("use_scatter_gatter", False)
55+
param.transferScatterGatherEnable = config.get("use_scatter_gatter", True)
5756
ret = self.store.Setup(param)
5857
if ret != 0:
5958
msg = f"Failed to initialize ucmpcstore, errcode: {ret}."

0 commit comments

Comments
 (0)