
Commit 9affb6c

remove init_kv_caches in start_load_kv
1 parent 49b0443 commit 9affb6c

File tree

1 file changed: 2 additions (+), 30 deletions (-)


ucm/integration/vllm/ucm_connector.py

Lines changed: 2 additions & 30 deletions
@@ -181,18 +181,16 @@ def generate_hash(self, block_size: int, request: "Request") -> list[bytes]:
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         self.kv_caches = kv_caches
         sample_kv_layer = next(iter(self.kv_caches.values()))
+        if self.kv_cache_dtype is None:
+            self.kv_cache_dtype = sample_kv_layer[0].dtype
         if isinstance(sample_kv_layer, torch.Tensor):
             logger.info(f"kv cache shape {sample_kv_layer.shape}")
-            if self.kv_cache_dtype is None:
-                self.kv_cache_dtype = sample_kv_layer.dtype
         elif isinstance(sample_kv_layer, Tuple):
             # Since vllm_ascend >= 0.10.0, the MLA model's tensor shape has changed to Tuple
             # [(num_blocks, block_size, num_kv_heads, nope_dim/rope_dim)]
             # Currently, we treat it as GQA, and use is_dsa to mark it
             for i, tensor in enumerate(sample_kv_layer):
                 logger.info(f"kv cache shape {i}: {tensor.shape}")
-                if self.kv_cache_dtype is None:
-                    self.kv_cache_dtype = sample_kv_layer[0].dtype
             if self.is_mla:
                 self.is_mla = False
                 self.is_dsa = True
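Note: the hoisted assignment works for both cache layouts because indexing with [0] is dtype-preserving in each case. A minimal sketch (illustrative only; the shapes and the gqa_layer/mla_layer names are assumptions, not taken from ucm_connector.py):

    import torch

    # GQA layout: the layer is a single tensor; indexing [0] returns a
    # sub-tensor with the same dtype as the full cache.
    gqa_layer = torch.zeros(4, 16, 8, 64, dtype=torch.float16)
    assert gqa_layer[0].dtype == gqa_layer.dtype

    # MLA layout (vllm_ascend >= 0.10.0): the layer is a tuple of tensors;
    # indexing [0] returns the first tensor, whose dtype is the cache dtype.
    mla_layer = (
        torch.zeros(4, 16, 1, 512, dtype=torch.float16),  # nope part
        torch.zeros(4, 16, 1, 64, dtype=torch.float16),   # rope part
    )
    assert mla_layer[0].dtype == torch.float16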
@@ -408,30 +406,6 @@ def build_connector_meta(
 
         return UCMConnectorMetadata(requests_dispatch_meta)
 
-    def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"):
-        if len(self.kv_caches) > 0:
-            return
-        for layer_name in forward_context.no_compile_layers:
-            attn_layer = forward_context.no_compile_layers[layer_name]
-            if not hasattr(attn_layer, "kv_cache"):
-                continue
-
-            if layer_name not in self.kv_caches:
-                self.kv_caches[layer_name] = attn_layer.kv_cache[
-                    forward_context.virtual_engine
-                ]
-        # Since vllm_ascend >= 0.10.0, the MLA model's tensor shape has changed to
-        # (2, num_blocks, block_size, num_kv_heads, nope_dim/rope_dim).
-        # Currently, we treat it as GQA, and use is_dsa to mark it,
-        # which works but leads to space inefficiency.
-        # TODO: Optimize this to avoid unnecessary space usage.
-        sample_kv_layer = next(iter(self.kv_caches.values()))
-        if self.is_mla and len(sample_kv_layer) == 2:
-            self.is_mla = False
-            self.is_dsa = True
-        if self.kv_cache_dtype is None:
-            self.kv_cache_dtype = sample_kv_layer[0].dtype
-
     @staticmethod
     def _extract_layer_index(layer_name: str) -> Optional[int]:
         """
@@ -496,8 +470,6 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
         metadata = self._get_connector_metadata()
         assert isinstance(metadata, UCMConnectorMetadata)
 
-        self._init_kv_caches_from_forward_context(forward_context)
-
         request_to_task: dict[str, Optional[List[Task]]] = {}
         req_broadcast_addr = {}
         is_load = False
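Note: with the call removed, start_load_kv implicitly relies on the worker having invoked register_kv_caches beforehand. A hedged sketch of that contract (the guard below is hypothetical and not part of the commit):

    def start_load_kv(self, forward_context, **kwargs):
        # self.kv_caches is now populated only by register_kv_caches, so it
        # must have run first; this assert is an illustration, not in the diff.
        assert self.kv_caches, "register_kv_caches must run before start_load_kv"
        metadata = self._get_connector_metadata()
        ...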
