Commit fd32928

wuhuxiao authored
[Feature] Cache Blend (#467)
* blend ready
* clean code
* update blend doc
* clean code
* modify prompt

Co-authored-by: wuhuxiao <whu.whx@gmail.com>
1 parent 9a0b3ba commit fd32928

File tree

16 files changed: +1950 -57 lines changed
Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
# CacheBlend: Fast Large Language Model Serving for RAG with Cached Knowledge Fusion

<div align="center">

![blend_scheme.jpg](../../_static/images/blend_scheme.jpg)

**🚀 Cached Knowledge Fusion Algorithm | 📄 EuroSys 2025 Paper**

[![License](https://img.shields.io/badge/License-MIT-green.svg)](https://github.com/ModelEngine-Group/unified-cache-management/blob/main/LICENSE)
[![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)](https://python.org)

</div>

## 🌟 What is CacheBlend?

**CacheBlend** is a cached knowledge fusion system that combines multiple pre-computed KV caches when their corresponding texts are concatenated in the LLM input. By selectively recomputing the KV cache values of a small fraction of tokens, CacheBlend reduces TTFT by 2.2 ~ 3.3× and increases throughput by 2.8 ~ 5× with a negligible quality drop.

### 🎯 Key Components

- **🔍 Loading Controller**: orchestrates which KV caches to load, where to load them from, and how much recomputation is needed.
- **⚡ KV Cache Store**: manages persistent storage, lookup, and eviction of precomputed KV caches keyed by text-chunk identity.
- **🎛️ Cache Fusor**: merges multiple chunk-level caches into one coherent, cross-attention-correct KV cache using minimal recomputation.

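The sketch below illustrates how these three roles could fit together for a single request. It is only a toy illustration of the data flow; all class and method names here are hypothetical and do not reflect UCM's actual API.

```python
from typing import Dict, List, Optional, Tuple

KVCache = List[float]  # stand-in for a real per-chunk KV tensor


class KVCacheStore:
    """Toy store: persists per-chunk KV caches keyed by the chunk's hash."""

    def __init__(self) -> None:
        self._caches: Dict[str, KVCache] = {}

    def lookup(self, chunk_hash: str) -> Optional[KVCache]:
        return self._caches.get(chunk_hash)

    def dump(self, chunk_hash: str, kv: KVCache) -> None:
        self._caches[chunk_hash] = kv


class LoadingController:
    """Decides, chunk by chunk, whether a cached KV can be loaded or must be recomputed."""

    def __init__(self, store: KVCacheStore) -> None:
        self.store = store

    def plan(self, chunk_hashes: List[str]) -> List[Tuple[str, Optional[KVCache]]]:
        return [(h, self.store.lookup(h)) for h in chunk_hashes]


class CacheFusor:
    """Concatenates loaded chunk caches and marks a small fraction of their tokens
    for recomputation so cross-chunk attention stays correct."""

    def __init__(self, recompute_ratio: float = 0.2) -> None:
        self.recompute_ratio = recompute_ratio

    def fuse(self, plan: List[Tuple[str, Optional[KVCache]]]) -> Tuple[KVCache, List[int]]:
        fused: KVCache = []
        recompute_idx: List[int] = []
        for _, kv in plan:
            if kv is None:
                continue  # cache miss: this chunk is prefilled from scratch by the engine
            start = len(fused)
            fused.extend(kv)
            n = max(1, int(self.recompute_ratio * len(kv)))
            # CacheBlend would pick the tokens whose KV deviates most; placeholder indices here.
            recompute_idx.extend(range(start, start + n))
        return fused, recompute_idx
```
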
### 🔥 Key Results

- **2.2 ~ 3.3× TTFT speedup** and **2.8 ~ 5× higher throughput** for long sequences
- **High quality preserved**: no more than a 1% ~ 3% quality drop compared to full KV recompute

## 🧠 UCM Implementation

### Native Block-Wise Chunk KV Cache Dump, Load, PostProcess and Recompute

1. **🔐 Chunk Hash Encoding**: Similar to the prefix hash encoder, all blocks in each chunk are hashed starting from the same hash meta.
2. **⚡ Combine Prefix Cache and Chunk Cache**: Since the chunk cache and the native prefix cache share the same hash space, UCM first performs a prefix cache lookup to fetch the fully reusable cache, then conducts a chunk cache lookup to fetch the candidate caches for blending.
3. **🎯 Delta-RoPE PostProcess**: Rectify the loaded chunk caches according to their positions in the new request (see the sketch after this list).
4. **🔍 Integrate Cache Blend and First Token Generation**: Construct the compute mask and attention metadata from the HKVD tokens, cache-miss tokens, and suffix tokens, then compute their KV caches in a single model forward pass.
5. **🚀 Comprehensive Hooks for the LLM Forward Pipeline**: Built on the UCM sparse module, the blend module sparsifies the prefill tokens not only in the attention stage but also in the FFN and layer stages.

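To make step 3 concrete, here is a minimal delta-RoPE sketch in PyTorch. A key cached at position `old_pos` already carries the rotary rotation for that position, so reusing it at `new_pos` only needs the residual rotation for `new_pos - old_pos` (rotations at the same frequency compose additively). This is an illustration rather than UCM's actual post-process kernel, and it assumes the interleaved (pairwise) RoPE layout; half-split layouts pair channels differently.

```python
import torch


def delta_rope(k: torch.Tensor, old_pos: torch.Tensor, new_pos: torch.Tensor,
               base: float = 10000.0) -> torch.Tensor:
    """Re-position cached keys: k is [num_tokens, num_heads, head_dim],
    old_pos/new_pos are [num_tokens] integer position tensors."""
    head_dim = k.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    delta = (new_pos - old_pos).to(torch.float32)              # position shift per token
    angles = delta[:, None] * inv_freq[None, :]                # [num_tokens, head_dim // 2]
    cos = angles.cos()[:, None, :].to(k.dtype)                 # broadcast over heads
    sin = angles.sin()[:, None, :].to(k.dtype)
    k_even, k_odd = k[..., 0::2], k[..., 1::2]                 # interleaved channel pairs
    out = torch.empty_like(k)
    out[..., 0::2] = k_even * cos - k_odd * sin                # rotate each pair by the delta angle
    out[..., 1::2] = k_even * sin + k_odd * cos
    return out
```
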
## 🚀 Quick Start

### Installation

Blend is part of the UCM Sparse Attention module. For installation instructions, please refer to the [UCM top-level README](https://github.com/ModelEngine-Group/unified-cache-management). Once UCM is installed, Blend is supported out of the box; run the following example script:

```bash
export ENABLE_SPARSE=TRUE
export DATA_DIR=/home/data/kv_cache
export MODEL_PATH=/home/models/mistralai/Mistral-7B-Instruct-v0.2
export BLEND_DATASET_PATH=/home/datasets/LongBench/data/2wikimqa.jsonl
python <ucm-repo>/examples/offline_inference_blend.py
```

### Basic Usage

Usage is similar to UCM's `offline_inference_esa.py` example: we only need to specify `ucm_sparse_method` as `Blend` and provide the meta config, as shown below.

```python
...
ktc = KVTransferConfig(
    kv_connector=name,
    kv_connector_module_path=module_path,
    kv_role="kv_both",
    kv_connector_extra_config={
        "ucm_connectors": [
            {
                "ucm_connector_name": "UcmNfsStore",
                "ucm_connector_config": {
                    "storage_backends": data_dir,
                    "kv_block_size": 33554432,
                },
            }
        ],
        "load_only_first_rank": False,
        "ucm_sparse_config": {
            "Blend": {
                "chunk_end_token_id": chunk_end_token_id,
                "compute_meta": {
                    "model.layers.1.self_attn.attn": {
                        "ratio": 0.2,
                    },
                },
            }
        },
        "use_layerwise": True,
    },
)
...
```

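Under `compute_meta`, the `ratio` bounds what fraction of loaded tokens gets recomputed at the configured probe layer (here `model.layers.1.self_attn.attn`). As a hedged illustration of the idea rather than UCM's internal code: CacheBlend-style selection recomputes the tokens whose freshly computed KV deviates most from the loaded chunk cache (the HKVD tokens) and reuses the cached values for everything else.

```python
import torch


def select_hkvd_tokens(kv_cached: torch.Tensor, kv_fresh: torch.Tensor, ratio: float = 0.2) -> torch.Tensor:
    """Return indices of the top-`ratio` tokens whose recomputed KV at the probe layer
    deviates most from the loaded chunk cache; these are the tokens worth recomputing."""
    deviation = (kv_fresh - kv_cached).norm(dim=-1)   # per-token L2 deviation, shape [num_tokens]
    k = max(1, int(ratio * deviation.numel()))
    return torch.topk(deviation, k).indices


# Toy usage: 10 tokens, hidden size 4; ratio 0.2 selects the 2 most-deviating tokens.
cached = torch.zeros(10, 4)
fresh = torch.zeros(10, 4)
fresh[3] += 1.0
fresh[7] += 2.0
print(select_hkvd_tokens(cached, fresh, ratio=0.2))   # tensor([7, 3])
```
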
## 📊 Supported Models

Llama-based and Qwen-based models are currently supported.

## 🎓 Citation

```bibtex
@inproceedings{yao2025cacheblend,
  title={CacheBlend: Fast large language model serving for RAG with cached knowledge fusion},
  author={Yao, Jiayi and Li, Hanchen and Liu, Yuhan and Ray, Siddhant and Cheng, Yihua and Zhang, Qizheng and Du, Kuntai and Lu, Shan and Jiang, Junchen},
  booktitle={Proceedings of the Twentieth European Conference on Computer Systems},
  pages={94--109},
  year={2025}
}
```

---

<div align="center">

**🌟 Star the [UCM](https://github.com/ModelEngine-Group/unified-cache-management) repository if you find CacheBlend useful!**

</div>
Lines changed: 273 additions & 0 deletions
@@ -0,0 +1,273 @@
import contextlib
import csv
import json
import os
import random
import re
import time
from dataclasses import asdict

from tqdm import tqdm
from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Vector

random.seed(0)

import sys

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import TokensPrompt

from ucm.logger import init_logger

logger = init_logger(__name__)

model = ""
data_dir = ""
path_to_dataset = ""
tokenizer = None
# 28705 is the token id for <space> char in llama model
# 151643 is the pad token id in qwen model
chunk_end_token_id = -1
chunk_pad_token_id = -1
block_size = 64


def setup_environment_variables():
    os.environ["VLLM_USE_V1"] = "1"
    os.environ["PYTHONHASHSEED"] = "123456"

    global model, data_dir, path_to_dataset, tokenizer, chunk_end_token_id, chunk_pad_token_id
    model = os.getenv("MODEL_PATH", "/home/models/mistralai/Mistral-7B-Instruct-v0.2")
    if not os.path.isdir(model):
        model = input(
            "Enter path to model, e.g. /home/models/mistralai/Mistral-7B-Instruct-v0.2: "
        )
        if not os.path.isdir(model):
            print("Exiting. Incorrect model_path")
            sys.exit(1)

    data_dir = os.getenv("DATA_DIR", "/home/data/kv_cache")
    if not os.path.isdir(data_dir):
        data_dir = input(
            "Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: "
        )
        create = input(f"Directory {data_dir} does not exist. Create it? (Y/n): ")
        if create.lower() == "y":
            os.makedirs(data_dir, exist_ok=True)
        else:
            print("Exiting. Directory not created.")
            sys.exit(1)

    # currently supports the 2wikimqa subset of LongBench
    path_to_dataset = os.getenv(
        "BLEND_DATASET_PATH", "/home/data/Longbench/data/2wikimqa.jsonl"
    )
    if not os.path.isfile(path_to_dataset):
        path_to_dataset = input(
            "Enter path of the 2wikimqa dataset in LongBench, e.g. /home/data/Longbench/data/2wikimqa.jsonl: "
        )
        if not os.path.isfile(path_to_dataset):
            print("Exiting. Incorrect dataset path")
            sys.exit(1)

    tokenizer = AutoTokenizer.from_pretrained(model, use_chat_template=True)
    # for Qwen models, use pad_token_id for padding blocks
    # for Llama models, currently use the unk token for padding blocks
    chunk_pad_token_id = tokenizer.encode("▁", add_special_tokens=False)[0]
    chunk_end_token_id = chunk_pad_token_id

    if tokenizer.pad_token_id is not None:
        chunk_pad_token_id = tokenizer.pad_token_id
        chunk_end_token_id = tokenizer.pad_token_id


@contextlib.contextmanager
def build_llm_with_uc(module_path: str, name: str, model: str):
    ktc = KVTransferConfig(
        kv_connector=name,
        kv_connector_module_path=module_path,
        kv_role="kv_both",
        kv_connector_extra_config={
            "ucm_connectors": [
                {
                    "ucm_connector_name": "UcmNfsStore",
                    "ucm_connector_config": {
                        "storage_backends": data_dir,
                        "kv_block_size": 33554432,
                    },
                }
            ],
            "load_only_first_rank": False,
            "ucm_sparse_config": {
                "Blend": {
                    "chunk_end_token_id": chunk_end_token_id,
                    "compute_meta": {
                        "model.layers.1.self_attn.attn": {
                            "ratio": 0.2,
                        },
                    },
                }
            },
            "use_layerwise": True,
        },
    )

    llm_args = EngineArgs(
        model=model,
        enforce_eager=True,
        kv_transfer_config=ktc,
        max_model_len=16384 * 2,
        max_num_batched_tokens=16384 * 2,
        gpu_memory_utilization=0.8,
        block_size=block_size,
        enable_prefix_caching=False,
        distributed_executor_backend="mp",
        tensor_parallel_size=1,
        trust_remote_code=True,
    )

    llm = LLM(**asdict(llm_args))
    try:
        yield llm
    finally:
        logger.info("LLM engine is exiting.")


def get_output(
    llm: LLM,
    prompt,
    sampling_params: SamplingParams,
):
    start = time.time()
    outputs = llm.generate(prompt, sampling_params)
    print("-" * 50)
    generated_text = None
    for output in outputs:
        generated_text = output.outputs[0].text
    e2e_time = time.time() - start
    print("-" * 50)
    return e2e_time, generated_text


def pad_rag_chunks(token_ids, block_size, pad_id, end_id):
    """
    Pad token_ids with pad_id to a whole number of blocks and terminate the chunk with end_id.
    """
    # assert pad_id != end_id
    remainder = len(token_ids) % block_size

    if remainder == 0 and token_ids[-1] in [pad_id, end_id]:
        # no need to pad
        token_ids[-1] = end_id
        return token_ids

    pad_len = block_size - remainder - 1
    padded = token_ids + [pad_id] * pad_len + [end_id]
    return padded


systemPrompt = "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n"


def main():
    module_path = "ucm.integration.vllm.blend_connector"
    name = "UCMBlendConnector"

    setup_environment_variables()

    with build_llm_with_uc(module_path, name, model) as llm:
        prefill_sampling_params = SamplingParams(
            temperature=0.0, top_p=0.95, max_tokens=1
        )
        sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=128)
        # choose one data row in LongBench v1 (2wikimqa)
        assert os.path.isfile(
            path_to_dataset
        ), "Incorrect dataset path. Please specify the dataset path via `export BLEND_DATASET_PATH=/path/to/longbench/2wikimqa.jsonl`"
        with open(path_to_dataset, "r") as f:
            lines = f.readlines()
            dataset_row = json.loads(lines[0])

        # split the context into per-passage RAG chunks
        passages = re.findall(
            r"Passage\s+(\d+):(.*?)(?=Passage\s+\d+:|$)", dataset_row["context"], re.S
        )
        chunks = [f"Passage {i}:{passages[i][1]}" for i in range(len(passages))]
        question = f"\n\nAnswer the question based on the given passages. Answer the question within 5 words. Do NOT repeat the question or output any other words. Question: {dataset_row['input']}\nAnswer:"
        origin_sys_prompt_ids = tokenizer.encode(systemPrompt)
        padded_sys_prompt_ids = pad_rag_chunks(
            origin_sys_prompt_ids, block_size, chunk_pad_token_id, chunk_end_token_id
        )
        # 1. sys prompt warm up
        print(f"---------------1. sys prompt: warm up---------------")
        get_output(
            llm,
            TokensPrompt(prompt_token_ids=padded_sys_prompt_ids),
            prefill_sampling_params,
        )
        time.sleep(0.5)

        padded_contexts_ids = []
        padded_prompt_ids = padded_sys_prompt_ids
        origin_prompt_ids = origin_sys_prompt_ids
        for text_chunk in chunks:
            un_pad_ids = tokenizer.encode(text_chunk, add_special_tokens=False)
            padded_ids = pad_rag_chunks(
                un_pad_ids, block_size, chunk_pad_token_id, chunk_end_token_id
            )
            padded_prompt_ids = padded_prompt_ids + padded_ids
            origin_prompt_ids = origin_prompt_ids + un_pad_ids
            padded_contexts_ids.append(padded_ids)

        question_ids = tokenizer.encode(question, add_special_tokens=False)
        padded_prompt_ids = padded_prompt_ids + question_ids
        origin_prompt_ids = origin_prompt_ids + question_ids

        print(f"--------------- baseline with no cache blend ---------------")
        baseline_time, baseline_gen_text = get_output(
            llm, TokensPrompt(prompt_token_ids=origin_prompt_ids), sampling_params
        )
        time.sleep(0.5)

        print(f"--------------- cache rag chunks ---------------")
        llm.generate(
            [TokensPrompt(prompt_token_ids=ids) for ids in padded_contexts_ids],
            sampling_params,
        )
        time.sleep(0.5)

        print(f"--------------- warm up blend code ---------------")
        warm_up_blend_prompt_ids = padded_sys_prompt_ids
        for ids in reversed(padded_contexts_ids):
            warm_up_blend_prompt_ids = warm_up_blend_prompt_ids + ids
        warm_up_blend_prompt_ids = warm_up_blend_prompt_ids + question_ids
        llm.generate(
            TokensPrompt(prompt_token_ids=warm_up_blend_prompt_ids), sampling_params
        )
        time.sleep(0.5)

        print(f"--------------- cache blend ---------------")
        blend_time, blend_gen_text = get_output(
            llm, TokensPrompt(prompt_token_ids=padded_prompt_ids), sampling_params
        )
        time.sleep(0.5)

        print(f"--------------- prefix cache ---------------")
        pc_time, pc_gen_text = get_output(
            llm, TokensPrompt(prompt_token_ids=origin_prompt_ids), sampling_params
        )

        print(f"Baseline generated text: {baseline_gen_text!r}")
        print(f"Baseline generated cost time: {baseline_time:.2f} seconds")
        print(f"Blend generated text: {blend_gen_text!r}")
        print(f"Blend generated cost time: {blend_time:.2f} seconds")
        print(f"Prefix Cache generated text: {pc_gen_text!r}")
        print(f"Prefix Cache generated cost time: {pc_time:.2f} seconds")
        print(f"Question:{dataset_row['input']}")
        print(f"Golden answer:{dataset_row['answers']}")


if __name__ == "__main__":
    main()

0 commit comments
