Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions examples/ucm_config_example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,15 @@ ucm_connectors:

load_only_first_rank: false

metrics_config_path: "/vllm-workspace/metrics_config.yaml"
# Enable UCM metrics; metrics can be viewed online via Grafana and Prometheus
# metrics_config_path: "/workspace/unified-cache-management/examples/metrics/metrics_configs.yaml"

# UCM operation recording configuration, whether to write UCM dump/load logs to a file
record_config:
enable: false
log_path: "/workspace/ucm_ops.log"
flush_size: 10
flush_interval: 5.0

# Sparse attention configuration
# Format 1: Dictionary format (for methods like ESA, KvComp)
Expand All @@ -33,5 +41,4 @@ metrics_config_path: "/vllm-workspace/metrics_config.yaml"

# Whether to use layerwise loading/saving (optional, default: True for UnifiedCacheConnectorV1)
# use_layerwise: true
# hit_ratio: 0.9

# hit_ratio: 0.9
17 changes: 16 additions & 1 deletion ucm/integration/vllm/ucm_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from vllm.v1.request import Request

from ucm.logger import init_logger
from ucm.profiling.profiler import Profiler
from ucm.shared.metrics import ucmmonitor
from ucm.shared.metrics.observability import UCMStatsLogger
from ucm.store.factory import UcmConnectorFactory
Expand Down Expand Up @@ -171,7 +172,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
if self.metrics_config:
self.stats_logger = UCMStatsLogger(
vllm_config.model_config.served_model_name,
self.rank,
self.global_rank,
self.metrics_config,
)
self.monitor = ucmmonitor.StatsMonitor.get_instance()
Expand All @@ -181,6 +182,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
else torch.npu.synchronize
)

self.profiler = Profiler(self.launch_config, self.block_size, self.local_rank)

def generate_hash(self, block_size: int, request: "Request") -> list[str]:
token_ids = request.all_token_ids

Expand Down Expand Up @@ -507,6 +510,12 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
request_to_task[request_id] = self.store.load(
ucm_total_block_ids, ucm_offsets, dst_tensor_addr
)
self.profiler.log_operation(
{
"op_type": "load",
"blocks": ucm_block_ids,
}
)
else:
request_to_task[request_id] = None
req_broadcast_addr[request_id] = dst_tensor_addr
Expand Down Expand Up @@ -598,6 +607,12 @@ def wait_for_save(self) -> None:
ucm_total_block_ids, ucm_offsets, dst_tensor_addr
)
request_to_blocks[request_id] = ucm_block_ids
self.profiler.log_operation(
{
"op_type": "dump",
"blocks": ucm_block_ids,
}
)

for request_id, task in request_to_task.items():
ucm_block_ids = request_to_blocks[request_id]
Expand Down
86 changes: 86 additions & 0 deletions ucm/profiling/profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import json
import queue
import threading
import time
from typing import Any

from ucm.logger import init_logger

logger = init_logger(__name__)


class Profiler:
    """Asynchronously record UCM dump/load operations to a JSONL file.

    When enabled (``record_config.enable`` is true and this is rank 0), a
    daemon thread drains an in-memory queue and appends one JSON object per
    operation to ``record_config.log_path``.  ``log_operation`` is
    non-blocking: if the queue is full the entry is dropped with an error log.
    """

    def __init__(
        self,
        launch_config: dict[str, Any],
        block_size: int,
        rank: int,
    ) -> None:
        """Initialize the profiler.

        Args:
            launch_config: UCM launch config; only ``record_config`` is read
                (keys: ``enable``, ``log_path``, ``flush_size``,
                ``flush_interval``).
            block_size: KV-cache block size, stamped into every log entry.
            rank: worker rank; only rank 0 records, so multi-rank deployments
                do not write duplicate logs.
        """
        self.block_size = block_size
        self.rank = rank

        self.record_config = launch_config.get("record_config", {})
        self.enable_record: bool = (
            self.record_config.get("enable", False) and self.rank == 0
        )
        if self.enable_record:
            # Create the queue here, not in the worker thread: otherwise
            # log_operation() racing thread startup would hit an
            # AttributeError on self.log_queue.
            self.log_queue: queue.Queue = queue.Queue(maxsize=10000)
            self.write_thread = threading.Thread(
                target=self._async_record_loop, daemon=True
            )
            self.write_thread.start()

    def log_operation(self, operation_data: dict[str, Any]) -> None:
        """Record one operation log entry (non-blocking).

        ``operation_data`` is merged over defaults (``timestamp``,
        ``op_type``, ``block_size``); its keys win on conflict.  If the queue
        is full the entry is dropped and an error is logged.
        """
        if not self.enable_record:
            return

        log_entry = {
            "timestamp": time.time(),
            "op_type": "None",
            "block_size": self.block_size,
            **operation_data,
        }

        try:
            self.log_queue.put_nowait(log_entry)
        except queue.Full:
            logger.error(
                f"Log queue is full, dropping one log: {log_entry.get('request_id')}"
            )

    def _async_record_loop(self) -> None:
        """Daemon-thread loop: batch queue entries and append them to disk.

        Flushes when the batch reaches ``flush_size`` entries or
        ``flush_interval`` seconds have elapsed since the last flush —
        including while the queue is idle, so buffered entries are never
        stranded waiting for new traffic.
        """
        # NOTE(review): the parent directory of log_path is assumed to exist;
        # writes fail (and are logged) otherwise — confirm deployment creates it.
        log_path = self.record_config.get(
            "log_path", "/vllm-workspace/ucm_logs/ucm_ops.log"
        )
        flush_size = self.record_config.get("flush_size", 100)
        flush_interval = self.record_config.get("flush_interval", 5.0)
        batch_buffer: list[dict[str, Any]] = []
        last_flush_time = time.time()
        while True:
            try:
                # Wait up to 1 second so the interval-based flush below still
                # runs while the queue is idle.
                batch_buffer.append(self.log_queue.get(timeout=1.0))
                self.log_queue.task_done()
            except queue.Empty:
                pass
            except Exception as e:
                logger.error(f"Log thread exception: {str(e)}")
                continue

            now = time.time()
            if batch_buffer and (
                len(batch_buffer) >= flush_size
                or (now - last_flush_time) >= flush_interval
            ):
                try:
                    with open(log_path, "a", encoding="utf-8") as f:
                        for entry in batch_buffer:
                            f.write(json.dumps(entry, ensure_ascii=False) + "\n")
                except OSError as e:
                    # Keep the thread alive on write failure; the batch is
                    # dropped rather than retried to bound memory.
                    logger.error(f"Failed to write UCM op log: {str(e)}")
                batch_buffer.clear()
                last_flush_time = now