@@ -1,21 +1,14 @@
-import hashlib
 import itertools
-import os
-import pickle
-import time
 from dataclasses import dataclass, field
 from enum import Enum, auto
-from typing import TYPE_CHECKING, Callable, List, Optional, Self, Tuple
+from typing import TYPE_CHECKING, List, Self, Tuple
 
 import torch
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
-    KVConnectorBase_V1,
     KVConnectorMetadata,
     KVConnectorRole,
 )
-from vllm.distributed.parallel_state import get_tp_group, get_world_group
-from vllm.platforms import current_platform
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import Request
 
@@ -28,16 +21,9 @@
 )
 from ucm.logger import init_logger
 from ucm.shared.metrics import ucmmonitor
-from ucm.shared.metrics.observability import UCMStatsLogger
 from ucm.sparse.blend.blockwise_rope import block_wise_rope_forward
-from ucm.sparse.kvstar.multistep import ReqStage
-from ucm.store.factory import UcmConnectorFactory
-from ucm.store.ucmstore import Task, UcmKVStoreBase
-from ucm.utils import Config
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionMetadata
-    from vllm.forward_context import ForwardContext
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 
 logger = init_logger(__name__)
@@ -82,7 +68,7 @@ def hits_vllm_blk_ids(self) -> List[int]:
     def hits_chunk_blks_hash(self) -> List[str]:
         return list(itertools.compress(self.chunk_blks_hash, self.store_hits))
 
-    def merge_chunk(self, temp_chunk_meta: Self):
+    def merge_chunk(self, temp_chunk_meta: Self) -> None:
         # Currently we use a fixed pattern (a chunk ends with a fixed token id) to recognize text token chunks.
         # In some special situations one text chunk may be split into multiple chunks, so we merge them back into one.
         self.chunk_tokens_len += temp_chunk_meta.chunk_tokens_len
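For orientation, a minimal runnable sketch of the merge the comments describe; `ChunkMeta` is a stand-in for the diff's `ChunkMetaData`, and any field not visible above is an assumption:

```python
from dataclasses import dataclass, field
from typing import List, Self


@dataclass
class ChunkMeta:  # stand-in for ChunkMetaData; fields beyond these are unknown
    chunk_tokens_len: int = 0
    chunk_blks_hash: List[str] = field(default_factory=list)

    def merge_chunk(self, other: Self) -> None:
        # A text chunk split into several pieces is folded back into one:
        # token counts add up and per-block hashes concatenate in order.
        self.chunk_tokens_len += other.chunk_tokens_len
        self.chunk_blks_hash.extend(other.chunk_blks_hash)


a = ChunkMeta(32, ["h0", "h1"])
a.merge_chunk(ChunkMeta(16, ["h2"]))
assert a.chunk_tokens_len == 48 and a.chunk_blks_hash == ["h0", "h1", "h2"]
```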
@@ -107,10 +93,10 @@ class BlendStage(Enum):
     BUILD_PREFIX_CACHE = auto()
     CACHE_BLEND = auto()
 
-    def is_blend_cache(self):
+    def is_blend_cache(self) -> bool:
         return self == BlendStage.CACHE_BLEND
 
-    def is_prefix_cache(self):
+    def is_prefix_cache(self) -> bool:
         return self == BlendStage.BUILD_PREFIX_CACHE
 
 
@@ -137,10 +123,7 @@ class UCMBlendConnectorMetadata(UCMConnectorMetadata):
 
 class UCMBlendConnector(UCMDirectConnector):
     """
-    This Connector means overlap:
-        load l0 -> forward l0 -> save l0
-        load l1 -> forward l1 -> save l1
-        load l2 -> forward l2 -> save l2
+    This connector handles chunk hashing and prefix caching.
     """
 
     def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
@@ -265,7 +248,7 @@ def _get_req_chunk_hit(
         prefix_block_hashes: List[str],
         req_chunks_meta: List[ChunkMetaData],
         req_chunks_hashes: List[str],
-    ):
+    ) -> Tuple[int, int]:
 
         # first perform prefix cache lookup
         pc_lookup_results = self.store.lookup(prefix_block_hashes)
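A hedged sketch of the two-phase lookup this hunk annotates; the meaning of the new `Tuple[int, int]` return (read here as prefix-cache hits and chunk-cache hits) and the contiguous-prefix rule are assumptions, since the diff omits the rest of the body:

```python
from typing import Callable, List, Tuple


def get_req_chunk_hit(
    store_lookup: Callable[[List[str]], List[bool]],  # e.g. self.store.lookup
    prefix_block_hashes: List[str],
    req_chunks_hashes: List[str],
) -> Tuple[int, int]:
    # Phase 1: prefix-cache lookup. Prefix hits are only useful while
    # contiguous from the front, so stop counting at the first miss.
    pc_hits = 0
    for hit in store_lookup(prefix_block_hashes):
        if not hit:
            break
        pc_hits += 1

    # Phase 2: chunk-cache lookup; chunk hits need not be contiguous
    # because each chunk is addressed by its own content hash.
    chunk_hits = sum(store_lookup(req_chunks_hashes))
    return pc_hits, chunk_hits


# Usage with an in-memory stand-in for the store:
store = {"h0", "h1", "c2"}
lookup = lambda hashes: [h in store for h in hashes]
print(get_req_chunk_hit(lookup, ["h0", "h1", "h2"], ["c1", "c2"]))  # (2, 1)
```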
@@ -312,7 +295,7 @@ def _generate_blend_dispatch_meta(
         ----------------------------------------------------------------------------------------------------------
         |                           LOAD                          |                       DUMP                    |
         ----------------------------------------------------------------------------------------------------------
-        | REUSE | RECOMPUTE |
+        |                           REUSE                         |                     RECOMPUTE                 |
         ----------------------------------------------------------------------------------------------------------
 
 
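One possible reading of the quadrant table, as a sketch; the names and the classification rule (prefix hits reused, chunk hits loaded, misses recomputed and then dumped) are assumptions based only on the docstring:

```python
from enum import Enum, auto
from typing import List, Tuple


class BlockAction(Enum):  # hypothetical names for the table's four quadrants
    REUSE = auto()      # prefix-cache hit: KV is usable as-is
    LOAD = auto()       # chunk-cache hit: fetch KV from the store
    RECOMPUTE = auto()  # miss: prefill recomputes these tokens
    DUMP = auto()       # freshly computed KV is written back to the store


def classify_blocks(pc_hits: int, chunk_hits: List[bool]) -> List[Tuple[BlockAction, ...]]:
    actions: List[Tuple[BlockAction, ...]] = []
    for i, chunk_hit in enumerate(chunk_hits):
        if i < pc_hits:
            actions.append((BlockAction.REUSE,))
        elif chunk_hit:
            actions.append((BlockAction.LOAD,))
        else:
            # Misses are recomputed, then dumped so later requests can load them.
            actions.append((BlockAction.RECOMPUTE, BlockAction.DUMP))
    return actions


print(classify_blocks(1, [True, True, False]))  # REUSE, LOAD, (RECOMPUTE, DUMP)
```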
@@ -362,7 +345,7 @@ def _generate_blend_dispatch_meta(
             req_meta.chunks_meta,
         )
 
-    def _post_process_chunk_cache(self, k_cache, vllm_ids, positions):
+    def _post_process_chunk_cache(self, k_cache, vllm_ids, positions) -> None:
         """
         Post-process the loaded chunk K-cache.
         """
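This post-processing exists because cached chunk keys were RoPE-rotated for the positions they originally occupied; spliced into a new sequence, they need one corrective rotation by the position delta. A minimal eager-mode sketch of that idea, assuming a neox-style half split and a `[cos | sin]` cache layout (the actual Triton kernel may differ):

```python
import torch


def delta_rope(
    k: torch.Tensor,              # [num_tokens, head_dim], cached keys
    old_pos: torch.Tensor,        # [num_tokens], positions the keys were cached at
    new_pos: torch.Tensor,        # [num_tokens], positions in the new sequence
    cos_sin_cache: torch.Tensor,  # [max_pos, head_dim], assumed [cos | sin] layout
) -> torch.Tensor:
    # RoPE rotations compose additively, R(new) = R(new - old) @ R(old),
    # so cached keys need only one extra rotation by the position delta.
    # Assumes new_pos >= old_pos elementwise.
    cos, sin = cos_sin_cache[new_pos - old_pos].chunk(2, dim=-1)
    half = k.shape[-1] // 2
    k1, k2 = k[..., :half], k[..., half:]
    return torch.cat((k1 * cos - k2 * sin, k1 * sin + k2 * cos), dim=-1)
```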
@@ -371,7 +354,7 @@ def _post_process_chunk_cache(self, k_cache, vllm_ids, positions):
         # Triton kernel for block-wise delta RoPE
         block_wise_rope_forward(k_cache, vllm_ids, positions, self.cos_sin_cache)
 
-    def _register_cos_sin_cache(self, model: "Model"):
+    def _register_cos_sin_cache(self, model: "Model") -> None:
         try:
             rotary_emb = model.model.layers[0].self_attn.rotary_emb
             self.cos_sin_cache = rotary_emb.cos_sin_cache
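The `try` block above is cut off by the hunk boundary; a self-contained sketch of a defensive completion, where the fallback branch is purely an assumption:

```python
import logging

logger = logging.getLogger(__name__)


class _CosSinRegistrar:  # hypothetical minimal host for the method
    cos_sin_cache = None

    def register(self, model) -> None:
        # The diff's try-block ends before its except branch; logging and
        # disabling the feature is one plausible fallback, not the source's.
        try:
            self.cos_sin_cache = model.model.layers[0].self_attn.rotary_emb.cos_sin_cache
        except AttributeError:
            logger.warning("layer 0 has no rotary_emb; blend delta-RoPE disabled")
```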