Skip to content

Commit a3658e6

Browse files
authored
[Feat] UCM supports recovery from load failure (#477)
1 parent 79d250d commit a3658e6

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

ucm/integration/vllm/ucm_connector.py

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -184,6 +184,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
184184
else torch.npu.synchronize
185185
)
186186

187+
# invalid block ids due to load errors
188+
self._invalid_block_ids: set[int] = set()
189+
187190
def generate_hash(self, block_size: int, request: "Request") -> list[str]:
188191
token_ids = request.all_token_ids
189192

@@ -513,6 +516,9 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
513516
# TODO error handling
514517
if self.global_rank == 0 or not self.load_only_first_rank:
515518
if self.store.wait(task) != 0:
519+
self._invalid_block_ids.update(
520+
metadata.request_meta[request_id].load_block_ids[1]
521+
)
516522
logger.error(f"request {request_id} load kv cache failed.")
517523
if self.load_only_first_rank:
518524
self._broadcast(req_broadcast_addr[request_id])
@@ -626,6 +632,18 @@ def wait_for_save(self) -> None:
626632
def clear_connector_metadata(self) -> None:
627633
super().clear_connector_metadata()
628634

635+
def get_block_ids_with_load_errors(self) -> set[int]:
636+
"""
637+
Get the set of block IDs that failed to load.
638+
639+
Returns:
640+
Set of block IDs that encountered load errors.
641+
Empty set if no load errors occurred.
642+
"""
643+
res = self._invalid_block_ids
644+
self._invalid_block_ids = set()
645+
return res
646+
629647

630648
class UCMLayerWiseConnector(UCMDirectConnector):
631649
"""
@@ -866,3 +884,13 @@ def clear_connector_metadata(self) -> None:
866884
after the model execution.
867885
"""
868886
self.connector.clear_connector_metadata()
887+
888+
def get_block_ids_with_load_errors(self) -> set[int]:
889+
"""
890+
Get the set of block IDs that failed to load.
891+
892+
Returns:
893+
Set of block IDs that encountered load errors.
894+
Empty set if no load errors occurred.
895+
"""
896+
return self.connector.get_block_ids_with_load_errors()

0 commit comments

Comments
 (0)