
Commit d6bc98d

Author: wuhuxiao
Commit message: clean code
1 parent 0919b19 commit d6bc98d

File tree

7 files changed: +23 -68 lines changed


ucm/integration/vllm/blend_connector.py

Lines changed: 9 additions & 26 deletions
@@ -1,21 +1,14 @@
-import hashlib
 import itertools
-import os
-import pickle
-import time
 from dataclasses import dataclass, field
 from enum import Enum, auto
-from typing import TYPE_CHECKING, Callable, List, Optional, Self, Tuple
+from typing import TYPE_CHECKING, List, Self, Tuple
 
 import torch
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
-    KVConnectorBase_V1,
     KVConnectorMetadata,
     KVConnectorRole,
 )
-from vllm.distributed.parallel_state import get_tp_group, get_world_group
-from vllm.platforms import current_platform
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import Request
 
@@ -28,16 +21,9 @@
 )
 from ucm.logger import init_logger
 from ucm.shared.metrics import ucmmonitor
-from ucm.shared.metrics.observability import UCMStatsLogger
 from ucm.sparse.blend.blockwise_rope import block_wise_rope_forward
-from ucm.sparse.kvstar.multistep import ReqStage
-from ucm.store.factory import UcmConnectorFactory
-from ucm.store.ucmstore import Task, UcmKVStoreBase
-from ucm.utils import Config
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionMetadata
-    from vllm.forward_context import ForwardContext
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 
 logger = init_logger(__name__)
@@ -82,7 +68,7 @@ def hits_vllm_blk_ids(self) -> List[int]:
     def hits_chunk_blks_hash(self) -> List[str]:
         return list(itertools.compress(self.chunk_blks_hash, self.store_hits))
 
-    def merge_chunk(self, temp_chunk_meta: Self):
+    def merge_chunk(self, temp_chunk_meta: Self) -> None:
         # currently we use a fixed pattern (ends with a fixed token id) to recognize a text token chunk
         # in some special situations one text chunk may be split into multiple text chunks, so we merge them into one
         self.chunk_tokens_len += temp_chunk_meta.chunk_tokens_len
@@ -107,10 +93,10 @@ class BlendStage(Enum):
     BUILD_PREFIX_CACHE = auto()
     CACHE_BLEND = auto()
 
-    def is_blend_cache(self):
+    def is_blend_cache(self) -> bool:
         return self == BlendStage.CACHE_BLEND
 
-    def is_prefix_cache(self):
+    def is_prefix_cache(self) -> bool:
         return self == BlendStage.BUILD_PREFIX_CACHE
 
 
@@ -137,10 +123,7 @@ class UCMBlendConnectorMetadata(UCMConnectorMetadata):
 
 class UCMBlendConnector(UCMDirectConnector):
     """
-    This Connector means overlap:
-    load l0 -> forward l0 -> save l0
-    load l1 -> forward l1 -> save l1
-    load l2 -> forward l2 -> save l2
+    This Connector processes chunk hashes and the prefix cache
     """
 
     def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
@@ -265,7 +248,7 @@ def _get_req_chunk_hit(
         prefix_block_hashes: List[str],
         req_chunks_meta: List[ChunkMetaData],
         req_chunks_hashes: List[str],
-    ):
+    ) -> Tuple[int, int]:
 
         # first perform prefix cache lookup
         pc_lookup_results = self.store.lookup(prefix_block_hashes)
@@ -312,7 +295,7 @@ def _generate_blend_dispatch_meta(
     ----------------------------------------------------------------------------------------------------------
     | LOAD | DUMP |
     ----------------------------------------------------------------------------------------------------------
-    | REUSE | RECOMPUTE |
+    | REUSE | RECOMPUTE |
     ----------------------------------------------------------------------------------------------------------
 
 
@@ -362,7 +345,7 @@ def _generate_blend_dispatch_meta(
             req_meta.chunks_meta,
         )
 
-    def _post_process_chunk_cache(self, k_cache, vllm_ids, positions):
+    def _post_process_chunk_cache(self, k_cache, vllm_ids, positions) -> None:
         """
         post process loaded chunk kcache
         """
@@ -371,7 +354,7 @@ def _post_process_chunk_cache(self, k_cache, vllm_ids, positions):
         # triton kernel for block-wise delta rope
         block_wise_rope_forward(k_cache, vllm_ids, positions, self.cos_sin_cache)
 
-    def _register_cos_sin_cache(self, model: "Model"):
+    def _register_cos_sin_cache(self, model: "Model") -> None:
         try:
             rotary_emb = model.model.layers[0].self_attn.rotary_emb
             self.cos_sin_cache = rotary_emb.cos_sin_cache
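
Side note for reviewers: the annotations above are mechanical, but the merge_chunk / BlendStage semantics they annotate are easy to misread. A minimal, self-contained sketch of the pattern, with hypothetical ChunkMetaData fields (the real class carries more state):

from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List, Self  # typing.Self requires Python 3.11+


class BlendStage(Enum):
    BUILD_PREFIX_CACHE = auto()
    CACHE_BLEND = auto()

    def is_blend_cache(self) -> bool:
        return self == BlendStage.CACHE_BLEND

    def is_prefix_cache(self) -> bool:
        return self == BlendStage.BUILD_PREFIX_CACHE


@dataclass
class ChunkMetaData:
    # Illustrative fields only; see the diff for the real class.
    chunk_tokens_len: int = 0
    chunk_blks_hash: List[str] = field(default_factory=list)

    def merge_chunk(self, temp_chunk_meta: Self) -> None:
        # One logical text chunk can arrive split into several pieces;
        # fold the piece's bookkeeping back into this chunk.
        self.chunk_tokens_len += temp_chunk_meta.chunk_tokens_len
        self.chunk_blks_hash.extend(temp_chunk_meta.chunk_blks_hash)


# Usage: merge two fragments of the same text chunk.
a = ChunkMetaData(128, ["h0"])
b = ChunkMetaData(64, ["h1"])
a.merge_chunk(b)
assert a.chunk_tokens_len == 192
assert BlendStage.CACHE_BLEND.is_blend_cache()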

ucm/sparse/blend/README.md

Lines changed: 0 additions & 1 deletion
This file was deleted.

ucm/sparse/blend/blend.py

Lines changed: 1 addition & 2 deletions
@@ -13,10 +13,9 @@
 
 from vllm.config import VllmConfig
 from vllm.forward_context import ForwardContext
-from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import Request
 
-from ucm.integration.vllm.blend_connector import BlendRequestDispatchMeta, ChunkMetaData
+from ucm.integration.vllm.blend_connector import BlendRequestDispatchMeta
 from ucm.sparse.base import (
     INVALID_SLOT,
     UcmSparseBase,

ucm/sparse/blend/blockwise_rope.py

Lines changed: 12 additions & 2 deletions
@@ -59,7 +59,12 @@ def _triton_rope_blockwise_kernel(
     tl.store(k_ptr + offs + hd // 2, new_k_tile_2, mask=mask)
 
 
-def block_wise_rope_forward(k_cache, vllm_ids, positions, cos_sin_cache):
+def block_wise_rope_forward(
+    k_cache: torch.Tensor,
+    vllm_ids: torch.Tensor,
+    positions: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+) -> torch.Tensor:
     """
     Args:
         k_cache: torch.Tensor (total_blocks, seq_len, n_kv_heads, hd), vllm owned.
@@ -96,7 +101,12 @@ def block_wise_rope_forward(k_cache, vllm_ids, positions, cos_sin_cache):
     return k_cache
 
 
-def rope_naive_torch(k_cache, vllm_ids, positions, cos_sin_cache):
+def rope_naive_torch(
+    k_cache: torch.Tensor,
+    vllm_ids: torch.Tensor,
+    positions: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+) -> torch.Tensor:
     """
     naive torch implementation for accuracy and perf baseline
     Args:
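
For anyone wiring up the newly typed signatures, a hedged usage sketch follows. Only the k_cache layout (total_blocks, seq_len, n_kv_heads, hd) comes from the docstring above; the shapes of vllm_ids, positions, and cos_sin_cache are illustrative assumptions, not the kernel's verified contract:

import torch

from ucm.sparse.blend.blockwise_rope import block_wise_rope_forward, rope_naive_torch

total_blocks, block_len, n_kv_heads, hd = 8, 16, 4, 128

# Layout per the docstring: (total_blocks, seq_len, n_kv_heads, hd), vllm owned.
k_cache = torch.randn(total_blocks, block_len, n_kv_heads, hd, device="cuda")
vllm_ids = torch.arange(total_blocks, device="cuda")  # blocks to re-rope (assumed)
positions = torch.arange(total_blocks * block_len, device="cuda")  # per-token positions (assumed)
cos_sin_cache = torch.randn(4096, hd, device="cuda")  # rotary table (assumed layout)

# Cross-check the Triton kernel (GPU required) against the naive torch
# implementation the module documents as its accuracy/perf baseline.
ref = rope_naive_torch(k_cache.clone(), vllm_ids, positions, cos_sin_cache)
out = block_wise_rope_forward(k_cache, vllm_ids, positions, cos_sin_cache)
torch.testing.assert_close(out, ref)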

ucm/sparse/blend/utils.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

ucm/sparse/gsa/gsa.py

Lines changed: 1 addition & 2 deletions
@@ -933,6 +933,7 @@ def execute_finished(self, logits_indices: torch.Tensor):
         self.prefetch_engine.deal_async_prefetch(
             False, self.gsa_metadata, kv_caches, None
         )
+        return logits_indices
 
     def launch_transfer_task(self, all_free_block_ids, all_miss_ids, kv_caches):
         if all_free_block_ids == None:
@@ -1006,8 +1007,6 @@ def check_transfer_task_done(self) -> bool:
             self.task_load.clear()
             return True
 
-        return logits_indices
-
     def build_sparse_meta(
         self, scheduler_output: SchedulerOutput, requests, input_batch, attn_metadata
     ) -> None:
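
Worth calling out: this gsa.py change is the one functional fix in the commit rather than pure cleanup. execute_finished previously fell off the end and implicitly returned None, while a stray return logits_indices sat in check_transfer_task_done, a method annotated -> bool. A stripped-down sketch of the corrected shape (hypothetical minimal class, real bookkeeping elided):

import torch


class GSASketch:
    def execute_finished(self, logits_indices: torch.Tensor) -> torch.Tensor:
        # ... async prefetch bookkeeping elided ...
        return logits_indices  # now returned by the method that owns the value

    def check_transfer_task_done(self) -> bool:
        # ... waiting on outstanding load tasks elided ...
        return True  # a bool, as the annotation promises


gsa = GSASketch()
idx = torch.tensor([0, 3, 7])
assert gsa.execute_finished(idx) is idx  # no longer falls through to None
assert gsa.check_transfer_task_done() is True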

ucm/sparse/state.py

Lines changed: 0 additions & 4 deletions
@@ -82,7 +82,6 @@ def maybe_execute_sparse_layer_begin(
     if not has_ucm_sparse():
         return positions, hidden_states, residual
     ucm_spare = get_ucm_sparse()
-    # after sparse, n_tokens of source tensor is larger than target
     return ucm_spare.layer_begin(positions, hidden_states, residual)
 
 
@@ -92,15 +91,13 @@ def maybe_execute_sparse_layer_finished(
     if not has_ucm_sparse():
         return positions, hidden_states, residual
     ucm_spare = get_ucm_sparse()
-    # after sparse, n_tokens of source tensor is larger than target
     return ucm_spare.layer_finished(positions, hidden_states, residual)
 
 
 def maybe_execute_sparse_ffn_begin(hidden_states: torch.Tensor, residual: torch.Tensor):
     if not has_ucm_sparse():
         return hidden_states, residual
     ucm_spare = get_ucm_sparse()
-    # after sparse, n_tokens of source tensor is larger than target
     return ucm_spare.ffn_begin(hidden_states, residual)
 
 
@@ -110,5 +107,4 @@ def maybe_execute_sparse_ffn_finished(
     if not has_ucm_sparse():
         return hidden_states, residual
     ucm_spare = get_ucm_sparse()
-    # after sparse, n_tokens of source tensor is larger than target
     return ucm_spare.ffn_finished(hidden_states, residual)
