Skip to content

Commit 3250fca

Browse files
committed
fix code style
1 parent 4b0d330 commit 3250fca

File tree

6 files changed

+30
-13
lines changed

6 files changed

+30
-13
lines changed

ucm/integration/vllm/ucm_connector.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ucm.shared.metrics import ucmmonitor
2222
from ucm.shared.metrics.observability import UCMStatsLogger
2323
from ucm.store.factory import UcmConnectorFactory
24-
from ucm.store.ucmstore_v1 import UcmKVStoreBaseV1, Task
24+
from ucm.store.ucmstore_v1 import Task, UcmKVStoreBaseV1
2525
from ucm.utils import Config
2626

2727
if TYPE_CHECKING:
@@ -394,20 +394,25 @@ def _extract_layer_index(layer_name: str) -> Optional[int]:
394394
return int(chunk)
395395
return None
396396

397-
def _get_tensors(self, vllm_block_id: int) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
397+
def _get_tensors(
398+
self, vllm_block_id: int
399+
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
398400
"""
399401
GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
400402
MLA: one layer shape is (num_blocks, block_size, head_size)
401403
"""
402404
k_tensors, v_tensors = [], []
403405
for _, kv_layer in self.kv_caches.items():
404-
k_tensors.append(kv_layer[vllm_block_id] if self.is_mla else kv_layer[0][vllm_block_id])
406+
k_tensors.append(
407+
kv_layer[vllm_block_id] if self.is_mla else kv_layer[0][vllm_block_id]
408+
)
405409
if not self.is_mla:
406410
v_tensors.append(kv_layer[1][vllm_block_id])
407411
return k_tensors, v_tensors
408412

409-
410-
def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[bytes]) -> Tuple[List[bytes], List[int], List[List[torch.Tensor]]]:
413+
def _generate_task(
414+
self, vllm_block_ids: List[int], ucm_block_ids: List[bytes]
415+
) -> Tuple[List[bytes], List[int], List[List[torch.Tensor]]]:
411416
"""
412417
GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
413418
MLA: one layer shape is (num_blocks, block_size, head_size)
@@ -416,7 +421,7 @@ def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[bytes])
416421
for i, vllm_block_id in enumerate(vllm_block_ids):
417422
k_tensors, v_tensors = self._get_tensors(vllm_block_id)
418423
block_ids.append(ucm_block_ids[i])
419-
tensors.append(k_tensors+v_tensors)
424+
tensors.append(k_tensors + v_tensors)
420425
shard_indexs.append(0)
421426

422427
return block_ids, shard_indexs, tensors

ucm/store/factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,4 @@ def create_connector(cls, connector_name: str, config: dict) -> UcmKVStoreBaseV1
6060

6161
UcmConnectorFactory.register_connector(
6262
"UcmPcStore", "ucm.store.pcstore.pcstore_connector", "UcmPcStore"
63-
)
63+
)

ucm/store/pcstore/cc/domain/trans/trans_task.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class TransTask {
6161
void Append(const std::string& block, const std::vector<uintptr_t>& addresses)
6262
{
6363
grouped_[block] = addresses;
64-
number_ += addresses.size();
64+
number_ += addresses.size();
6565
}
6666
auto Str() const noexcept { return fmt::format("{},{},{}", id, brief_, number_); }
6767
size_t GroupNumber() const { return grouped_.size(); }

ucm/store/pcstore/pcstore_connector.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,20 @@ def prefetch(self, block_ids: List[bytes]) -> None:
6969
pass
7070

7171
def load(
72-
self, block_ids: List[bytes], shard_index: List[int], dst_tensor: List[List[torch.Tensor]]
72+
self,
73+
block_ids: List[bytes],
74+
shard_index: List[int],
75+
dst_tensor: List[List[torch.Tensor]],
7376
) -> Task:
7477
dst_tensor_ptrs = [[t.data_ptr() for t in tensors] for tensors in dst_tensor]
7578
task_id = self.store.LoadToDevice(block_ids, dst_tensor_ptrs)
7679
return PcTask(task_id=task_id)
7780

7881
def dump(
79-
self, block_ids: List[bytes], shard_index: List[int], src_tensor: List[List[torch.Tensor]]
82+
self,
83+
block_ids: List[bytes],
84+
shard_index: List[int],
85+
src_tensor: List[List[torch.Tensor]],
8086
) -> Task:
8187
src_tensor_ptrs = [[t.data_ptr() for t in tensors] for tensors in src_tensor]
8288
task_id = self.store.DumpFromDevice(block_ids, src_tensor_ptrs)

ucm/store/test/e2e/pcstore_embed.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,18 @@ def make_buffers(
6262
return hashes, tensors
6363

6464

65-
def embed(store: UcmKVStoreBaseV1, hashes: List[bytes], tensors: List[List[torch.Tensor]]):
65+
def embed(
66+
store: UcmKVStoreBaseV1, hashes: List[bytes], tensors: List[List[torch.Tensor]]
67+
):
6668
shard_index = [0] * len(hashes)
6769
task = store.dump(hashes, shard_index, tensors)
6870
assert task.task_id > 0
6971
store.wait(task)
7072

7173

72-
def fetch(store: UcmKVStoreBaseV1, hashes: List[bytes], tensors: List[List[torch.Tensor]]):
74+
def fetch(
75+
store: UcmKVStoreBaseV1, hashes: List[bytes], tensors: List[List[torch.Tensor]]
76+
):
7377
founds = store.lookup(hashes)
7478
for found in founds:
7579
assert found

ucm/store/test/e2e/pcstore_fetch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ def make_buffers(device_id, batch_size, block_dim, block_len, block_layer):
7070
return tensors
7171

7272

73-
def fetch(store: UcmKVStoreBaseV1, hashes: List[bytes], tensors: List[List[torch.Tensor]]):
73+
def fetch(
74+
store: UcmKVStoreBaseV1, hashes: List[bytes], tensors: List[List[torch.Tensor]]
75+
):
7476
founds = store.lookup(hashes)
7577
for found in founds:
7678
assert found

0 commit comments

Comments (0)