
Commit 5a37f34

Author: wuhuxiao
Message: clean code
1 parent f982864 commit 5a37f34

File tree

3 files changed (+19, -21 lines)


docs/source/user-guide/sparse-attention/cacheblend.md

Lines changed: 4 additions & 4 deletions
@@ -29,10 +29,10 @@ CacheBlend reduces TTFT by 2.2 ~ 3.3× and increases throughput by 2.8 ~ 5× und
 
 ### Native Block-Wise Chunk KV Cache Dump, Load, PostProcess and Recompute
 1. **🔐 Chunk Hash Encoding**: Similar as prefix hash encoder, hash all blocks in each chunk from the same hash meta beginning.
-2. **⚡ Combine Prefix Cache and Chunk Cache**: Since chunk cache and native prefix cache share the same hash space, ucm first performs prefix cache lookup to fetch fully resued cache and then conduct chunk cache lookup to fetch the candidate cache for blending.
+2. **⚡ Combine Prefix Cache and Chunk Cache**: Since the chunk cache and the native prefix cache share the same hash space, ucm first performs a prefix cache lookup to fetch the fully reused cache, then conducts a chunk cache lookup to fetch the candidate cache for blending.
 3. **🎯 Delta-Rope PostProcess**: Rectify loaded chunk cache according to their position in the new request.
-3. **🔍 Integrate Cache Blend and First Token Generation**: Construct compute mask and attention meta according to HKVD tokens, cache miss tokens and suffix tokens, then compute their kv cache in a single model forward stage
-4. **🚀 Comprehensive Hook for LLM Forward Pipeline**: Based on ucm sparse module, blend module spare the prefill tokens not only in attention stage but also in ffn, layer stage.
+4. **🔍 Integrate Cache Blend and First Token Generation**: Construct the compute mask and attention meta according to HKVD tokens, cache-miss tokens, and suffix tokens, then compute their kv cache in a single model forward pass.
+5. **🚀 Comprehensive Hook for LLM Forward Pipeline**: Built on the ucm sparse module, the blend module sparsifies the prefill tokens not only in the attention stage but also in the ffn and layer stages.
 
 ## 🚀 Quick Start
 
@@ -49,7 +49,7 @@ python <ucm-repo>/examples/offline_inference_blend.py
 ```
 
 ### Basic Usage
-Similr to UCM's `offline_inference_esa.py` examples. We only need to specify `ucm_sparse_method` to be `Blend` and specify meta config, as shown below.
+Similar to UCM's `offline_inference_esa.py` example, we only need to set `ucm_sparse_method` to `Blend` and provide the meta config, as shown below.
 
 ```python
 ...
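
Aside on the **🎯 Delta-Rope PostProcess** step above: a chunk's KV cache was rotary-encoded at the position it was originally dumped, so after loading it must be re-rotated by the position delta in the new request; the `position_offset` property in `blend_connector.py` below computes exactly this delta. A minimal conceptual sketch of the re-rotation, assuming an interleaved RoPE pair layout and the usual base of 10000 (the `delta_rope` helper is illustrative, not ucm's actual kernel):

```python
import torch

def delta_rope(k: torch.Tensor, delta: int, base: float = 10000.0) -> torch.Tensor:
    """Re-rotate cached keys by `delta` positions; k has shape [num_tokens, head_dim]."""
    d = k.shape[-1]
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)
    angle = delta * inv_freq                      # one angle per dim pair
    cos, sin = torch.cos(angle), torch.sin(angle)
    k1, k2 = k[..., 0::2], k[..., 1::2]           # interleaved (even, odd) pairs
    out = torch.empty_like(k)
    out[..., 0::2] = k1 * cos - k2 * sin          # standard 2-D rotation per pair
    out[..., 1::2] = k1 * sin + k2 * cos
    return out

# Rotating by delta turns keys encoded at cached_pos into keys encoded at
# cached_pos + delta, e.g. a chunk cached at position 0 and reused at 64:
k_new = delta_rope(torch.randn(32, 128), delta=64)
```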

ucm/integration/vllm/blend_connector.py

Lines changed: 13 additions & 13 deletions
@@ -32,10 +32,10 @@
 @dataclass
 class ChunkMetaData:
     # [start, start + len)
-    start_idx_in_req: int
+    start_token_dix: int
     chunk_tokens_len: int

-    start_idx_in_req_blks: int
+    start_blk_idx: int
     chunk_blks_len: int

     cached_start_position: int
@@ -45,20 +45,20 @@ class ChunkMetaData:
     store_hits: List[bool] = field(default_factory=list)

     @property
-    def end_idx_in_req(self) -> int:
-        return self.start_idx_in_req + self.chunk_tokens_len
+    def end_token_dix(self) -> int:
+        return self.start_token_dix + self.chunk_tokens_len

     @property
-    def end_idx_in_req_blks(self) -> int:
-        return self.start_idx_in_req_blks + self.chunk_blks_len
+    def end_blk_idx(self) -> int:
+        return self.start_blk_idx + self.chunk_blks_len

     @property
     def cached_end_position(self) -> int:
         return self.cached_start_position + self.chunk_tokens_len

     @property
     def position_offset(self) -> int:
-        return self.start_idx_in_req - self.cached_start_position
+        return self.start_token_dix - self.cached_start_position

     @property
     def hits_vllm_blk_ids(self) -> List[int]:
@@ -77,10 +77,10 @@ def merge_chunk(self, temp_chunk_meta: Self) -> None:

     def update_meta_partial_pc(self, num_pc_part_blks: int, block_size: int) -> None:
         if num_pc_part_blks > 0:
-            self.start_idx_in_req += num_pc_part_blks * block_size
+            self.start_token_dix += num_pc_part_blks * block_size
             self.chunk_tokens_len -= num_pc_part_blks * block_size

-            self.start_idx_in_req_blks += num_pc_part_blks
+            self.start_blk_idx += num_pc_part_blks
             self.chunk_blks_len -= num_pc_part_blks

             self.chunk_blks_hash = self.chunk_blks_hash[num_pc_part_blks:]
@@ -211,9 +211,9 @@ def _process_req(self, all_token_ids: List[int]):
         chunk_tokens_len = chunk_blks_len * self.block_size

         rag_chunk_meta = ChunkMetaData(
-            start_idx_in_req=start_token_dix,
+            start_token_dix=start_token_dix,
             chunk_tokens_len=chunk_tokens_len,
-            start_idx_in_req_blks=start_blk_idx,
+            start_blk_idx=start_blk_idx,
             chunk_blks_len=chunk_blks_len,
             chunk_blks_hash=chunk_blks_hash,
             cached_start_position=0,
@@ -271,7 +271,7 @@ def _get_req_chunk_hit(
         # for cache blend
         for i, chunk_meta in enumerate(req_chunks_meta):
             chunk_meta.store_hits = chunk_lookup_results[
-                chunk_meta.start_idx_in_req_blks : chunk_meta.end_idx_in_req_blks
+                chunk_meta.start_blk_idx : chunk_meta.end_blk_idx
             ]
         first_chunk_meta = req_chunks_meta[0]
         first_chunk_meta.update_meta_partial_pc(pc_hit_blocks, self.block_size)
@@ -324,7 +324,7 @@ def _generate_blend_dispatch_meta(
         # just need to load, in future we may create a multi-chunk hash to dump and reuse the blended cache
         for chunk_meta in req_meta.chunks_meta:
             chunk_meta.vllm_blk_ids = vllm_block_ids[
-                chunk_meta.start_idx_in_req_blks : chunk_meta.end_idx_in_req_blks
+                chunk_meta.start_blk_idx : chunk_meta.end_blk_idx
             ]
             load_ucm_block_ids.extend(chunk_meta.hits_chunk_blks_hash)
             load_vllm_block_ids.extend(chunk_meta.hits_vllm_blk_ids)
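
Putting the renamed fields together, a trimmed, standalone mirror of `ChunkMetaData` shows how the token/block indices and the partial prefix-cache trim interact (illustrative values only, not ucm test code):

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class ChunkMeta:
    start_token_dix: int        # first token of the chunk within the request
    chunk_tokens_len: int
    start_blk_idx: int          # first block of the chunk within the request
    chunk_blks_len: int
    cached_start_position: int  # position the chunk's cache was written at
    chunk_blks_hash: List[str] = field(default_factory=list)

    @property
    def position_offset(self) -> int:
        # delta consumed by the Delta-Rope post-process
        return self.start_token_dix - self.cached_start_position

    def update_meta_partial_pc(self, num_pc_part_blks: int, block_size: int) -> None:
        # drop leading blocks already served by the prefix cache
        if num_pc_part_blks > 0:
            self.start_token_dix += num_pc_part_blks * block_size
            self.chunk_tokens_len -= num_pc_part_blks * block_size
            self.start_blk_idx += num_pc_part_blks
            self.chunk_blks_len -= num_pc_part_blks
            self.chunk_blks_hash = self.chunk_blks_hash[num_pc_part_blks:]

# A 4-block chunk starting at token 64 whose first block is a prefix-cache hit:
meta = ChunkMeta(64, 64, 4, 4, 0, ["h0", "h1", "h2", "h3"])
meta.update_meta_partial_pc(num_pc_part_blks=1, block_size=16)
assert (meta.start_token_dix, meta.chunk_tokens_len) == (80, 48)
assert (meta.start_blk_idx, meta.chunk_blks_len) == (5, 3)
assert meta.chunk_blks_hash == ["h1", "h2", "h3"]
assert meta.position_offset == 80
```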

ucm/sparse/blend/blend.py

Lines changed: 2 additions & 4 deletions
@@ -85,10 +85,8 @@ def add_request(
             hit_mask.extend(meta.store_hits)
         reqMeta = ReqMeta(
             req_idx=req_idx_batch,
-            prefix_len=chunks_meta[0].start_idx_in_req,
-            prefix_blk_len=get_num_blks(
-                chunks_meta[0].start_idx_in_req, block_size
-            ),
+            prefix_len=chunks_meta[0].start_token_dix,
+            prefix_blk_len=get_num_blks(chunks_meta[0].start_token_dix, block_size),
             chunks_len=len(hit_mask) * block_size,
             chunks_blk_len=len(hit_mask),
             chunk_hit_mask=hit_mask,
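
The collapsed `prefix_blk_len` line relies on `get_num_blks`, which this diff doesn't show. Assuming it is a ceiling division of a token count over the KV block size (an assumption; for the block-aligned prefixes produced by `update_meta_partial_pc`, floor and ceiling agree), a sketch:

```python
# Assumed definition of get_num_blks (not shown in this diff):
# ceiling division of a token count over the KV block size.
def get_num_blks(num_tokens: int, block_size: int) -> int:
    return (num_tokens + block_size - 1) // block_size

# With block_size=16, a block-aligned prefix of 80 tokens spans 5 blocks.
assert get_num_blks(80, 16) == 5
```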
