2121from ucm .shared .metrics import ucmmonitor
2222from ucm .shared .metrics .observability import UCMStatsLogger
2323from ucm .store .factory import UcmConnectorFactory
24- from ucm .store .ucmstore_v1 import UcmKVStoreBaseV1 , Task
24+ from ucm .store .ucmstore_v1 import Task , UcmKVStoreBaseV1
2525from ucm .utils import Config
2626
2727if TYPE_CHECKING :
@@ -394,20 +394,25 @@ def _extract_layer_index(layer_name: str) -> Optional[int]:
394394 return int (chunk )
395395 return None
396396
397- def _get_tensors (self , vllm_block_id : int ) -> Tuple [List [torch .Tensor ], List [torch .Tensor ]]:
397+ def _get_tensors (
398+ self , vllm_block_id : int
399+ ) -> Tuple [List [torch .Tensor ], List [torch .Tensor ]]:
398400 """
399401 GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
400402 MLA: one layer shape is (num_blocks, block_size, head_size)
401403 """
402404 k_tensors , v_tensors = [], []
403405 for _ , kv_layer in self .kv_caches .items ():
404- k_tensors .append (kv_layer [vllm_block_id ] if self .is_mla else kv_layer [0 ][vllm_block_id ])
406+ k_tensors .append (
407+ kv_layer [vllm_block_id ] if self .is_mla else kv_layer [0 ][vllm_block_id ]
408+ )
405409 if not self .is_mla :
406410 v_tensors .append (kv_layer [1 ][vllm_block_id ])
407411 return k_tensors , v_tensors
408412
409-
410- def _generate_task (self , vllm_block_ids : List [int ], ucm_block_ids : List [bytes ]) -> Tuple [List [bytes ], List [int ], List [List [torch .Tensor ]]]:
413+ def _generate_task (
414+ self , vllm_block_ids : List [int ], ucm_block_ids : List [bytes ]
415+ ) -> Tuple [List [bytes ], List [int ], List [List [torch .Tensor ]]]:
411416 """
412417 GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
413418 MLA: one layer shape is (num_blocks, block_size, head_size)
@@ -416,7 +421,7 @@ def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[bytes])
416421 for i , vllm_block_id in enumerate (vllm_block_ids ):
417422 k_tensors , v_tensors = self ._get_tensors (vllm_block_id )
418423 block_ids .append (ucm_block_ids [i ])
419- tensors .append (k_tensors + v_tensors )
424+ tensors .append (k_tensors + v_tensors )
420425 shard_indexs .append (0 )
421426
422427 return block_ids , shard_indexs , tensors
0 commit comments