@@ -181,18 +181,16 @@ def generate_hash(self, block_size: int, request: "Request") -> list[bytes]:
181181 def register_kv_caches (self , kv_caches : dict [str , torch .Tensor ]):
182182 self .kv_caches = kv_caches
183183 sample_kv_layer = next (iter (self .kv_caches .values ()))
184+ if self .kv_cache_dtype is None :
185+ self .kv_cache_dtype = sample_kv_layer [0 ].dtype
184186 if isinstance (sample_kv_layer , torch .Tensor ):
185187 logger .info (f"kv cache shape { sample_kv_layer .shape } " )
186- if self .kv_cache_dtype is None :
187- self .kv_cache_dtype = sample_kv_layer .dtype
188188 elif isinstance (sample_kv_layer , Tuple ):
189189 # Since vllm_ascend >= 0.10.0, the MLA model's tensor shape has changed to Tuple
190190 # [(num_blocks, block_size, num_kv_heads, nope_dim/rope_dim)]
191191 # Currently, we treat it as GQA, and use is_dsa to mark it
192192 for i , tensor in enumerate (sample_kv_layer ):
193193 logger .info (f"kv cache shape { i } : { tensor .shape } " )
194- if self .kv_cache_dtype is None :
195- self .kv_cache_dtype = sample_kv_layer [0 ].dtype
196194 if self .is_mla :
197195 self .is_mla = False
198196 self .is_dsa = True
@@ -408,30 +406,6 @@ def build_connector_meta(
408406
409407 return UCMConnectorMetadata (requests_dispatch_meta )
410408
411- def _init_kv_caches_from_forward_context (self , forward_context : "ForwardContext" ):
412- if len (self .kv_caches ) > 0 :
413- return
414- for layer_name in forward_context .no_compile_layers :
415- attn_layer = forward_context .no_compile_layers [layer_name ]
416- if not hasattr (attn_layer , "kv_cache" ):
417- continue
418-
419- if layer_name not in self .kv_caches :
420- self .kv_caches [layer_name ] = attn_layer .kv_cache [
421- forward_context .virtual_engine
422- ]
423- # Since vllm_ascend >= 0.10.0, the MLA model's tensor shape has changed to
424- # (2, num_blocks, block_size, num_kv_heads, nope_dim/rope_dim).
425- # Currently, we treat it as GQA, and use is_dsa to mark it,
426- # which works but leads to space inefficiency.
427- # TODO: Optimize this to avoid unnecessary space usage.
428- sample_kv_layer = next (iter (self .kv_caches .values ()))
429- if self .is_mla and len (sample_kv_layer ) == 2 :
430- self .is_mla = False
431- self .is_dsa = True
432- if self .kv_cache_dtype is None :
433- self .kv_cache_dtype = sample_kv_layer [0 ].dtype
434-
435409 @staticmethod
436410 def _extract_layer_index (layer_name : str ) -> Optional [int ]:
437411 """
@@ -496,8 +470,6 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
496470 metadata = self ._get_connector_metadata ()
497471 assert isinstance (metadata , UCMConnectorMetadata )
498472
499- self ._init_kv_caches_from_forward_context (forward_context )
500-
501473 request_to_task : dict [str , Optional [List [Task ]]] = {}
502474 req_broadcast_addr = {}
503475 is_load = False
0 commit comments