@@ -184,6 +184,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
184184 else torch .npu .synchronize
185185 )
186186
187+ # invalid block ids due to load errors
188+ self ._invalid_block_ids : set [int ] = set ()
189+
187190 def generate_hash (self , block_size : int , request : "Request" ) -> list [str ]:
188191 token_ids = request .all_token_ids
189192
@@ -513,6 +516,9 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
513516 # TODO error handling
514517 if self .global_rank == 0 or not self .load_only_first_rank :
515518 if self .store .wait (task ) != 0 :
519+ self ._invalid_block_ids .update (
520+ metadata .request_meta [request_id ].load_block_ids [1 ]
521+ )
516522 logger .error (f"request { request_id } load kv cache failed." )
517523 if self .load_only_first_rank :
518524 self ._broadcast (req_broadcast_addr [request_id ])
@@ -626,6 +632,18 @@ def wait_for_save(self) -> None:
def clear_connector_metadata(self) -> None:
    """Clear the connector metadata by delegating to the parent class.

    This override adds no behavior of its own; it only forwards the call
    to the superclass implementation of ``clear_connector_metadata``.
    """
    super().clear_connector_metadata()
628634
def get_block_ids_with_load_errors(self) -> set[int]:
    """Hand back the block IDs recorded as failed loads and reset the record.

    The set currently stored in ``self._invalid_block_ids`` is returned
    as-is (the same object, not a copy), and the attribute is rebound to a
    brand-new empty set, so each recorded ID is reported by one call only.

    Returns:
        The set of block IDs that encountered load errors since the
        previous call; an empty set if none were recorded.
    """
    # The right-hand side is evaluated before either target is bound, so
    # ``failed`` receives the old set and the attribute gets the fresh one.
    failed, self._invalid_block_ids = self._invalid_block_ids, set()
    return failed
646+
629647
630648class UCMLayerWiseConnector (UCMDirectConnector ):
631649 """
@@ -866,3 +884,13 @@ def clear_connector_metadata(self) -> None:
866884 after the model execution.
867885 """
868886 self .connector .clear_connector_metadata ()
887+
def get_block_ids_with_load_errors(self) -> set[int]:
    """Report the block IDs that failed to load, as seen by the delegate.

    This is a pure pass-through: the call is forwarded to
    ``self.connector.get_block_ids_with_load_errors()`` and its result is
    returned unchanged.

    Returns:
        Whatever set of block IDs the wrapped connector reports.
    """
    delegate = self.connector
    failed_block_ids = delegate.get_block_ids_with_load_errors()
    return failed_block_ids
0 commit comments