diff --git a/examples/ucm_config_example.yaml b/examples/ucm_config_example.yaml
index e9189941..ee19bc82 100644
--- a/examples/ucm_config_example.yaml
+++ b/examples/ucm_config_example.yaml
@@ -16,7 +16,15 @@ ucm_connectors:
     load_only_first_rank: false
 
-metrics_config_path: "/vllm-workspace/metrics_config.yaml"
+# Enable UCM metrics; metrics can be viewed in Grafana via Prometheus
+# metrics_config_path: "/workspace/unified-cache-management/examples/metrics/metrics_configs.yaml"
+
+# UCM operation recording: whether to write UCM dump/load operation logs to a file
+record_config:
+  enable: false
+  log_path: "/workspace/ucm_ops.log"
+  flush_size: 10        # flush after this many buffered entries
+  flush_interval: 5.0   # flush at least every N seconds
 
 # Sparse attention configuration
 # Format 1: Dictionary format (for methods like ESA, KvComp)
@@ -33,5 +41,4 @@ metrics_config_path: "/vllm-workspace/metrics_config.yaml"
 # Whether to use layerwise loading/saving (optional, default: True for UnifiedCacheConnectorV1)
 # use_layerwise: true
 
-# hit_ratio: 0.9
-
+# hit_ratio: 0.9
\ No newline at end of file
diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py
index f4b1f4d3..271d4b4b 100644
--- a/ucm/integration/vllm/ucm_connector.py
+++ b/ucm/integration/vllm/ucm_connector.py
@@ -19,6 +19,7 @@ from vllm.v1.request import Request
 
 from ucm.logger import init_logger
+from ucm.profiling.profiler import Profiler
 from ucm.shared.metrics import ucmmonitor
 from ucm.shared.metrics.observability import UCMStatsLogger
 from ucm.store.factory import UcmConnectorFactory
@@ -171,7 +172,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         if self.metrics_config:
             self.stats_logger = UCMStatsLogger(
                 vllm_config.model_config.served_model_name,
-                self.rank,
+                self.global_rank,
                 self.metrics_config,
             )
             self.monitor = ucmmonitor.StatsMonitor.get_instance()
@@ -181,6 +182,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
             else torch.npu.synchronize
         )
 
+        self.profiler = Profiler(self.launch_config, self.block_size, self.local_rank)
+
     def generate_hash(self, block_size: int, request: "Request") -> list[str]:
         token_ids = request.all_token_ids
 
@@ -507,6 +510,12 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
                 request_to_task[request_id] = self.store.load(
                     ucm_total_block_ids, ucm_offsets, dst_tensor_addr
                 )
+                self.profiler.log_operation(
+                    {
+                        "op_type": "load",
+                        "blocks": ucm_block_ids,
+                    }
+                )
             else:
                 request_to_task[request_id] = None
             req_broadcast_addr[request_id] = dst_tensor_addr
@@ -598,6 +607,12 @@ def wait_for_save(self) -> None:
                 ucm_total_block_ids, ucm_offsets, dst_tensor_addr
             )
             request_to_blocks[request_id] = ucm_block_ids
+            self.profiler.log_operation(
+                {
+                    "op_type": "dump",
+                    "blocks": ucm_block_ids,
+                }
+            )
 
         for request_id, task in request_to_task.items():
             ucm_block_ids = request_to_blocks[request_id]
diff --git a/ucm/profiling/profiler.py b/ucm/profiling/profiler.py
new file mode 100644
index 00000000..b6dea171
--- /dev/null
+++ b/ucm/profiling/profiler.py
@@ -0,0 +1,90 @@
+import json
+import queue
+import threading
+import time
+from typing import Any
+
+from ucm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class Profiler:
+    def __init__(
+        self,
+        launch_config: dict[str, Any],
+        block_size: int,
+        rank: int,
+    ) -> None:
+        self.block_size = block_size
+        self.rank = rank
+
+        self.record_config = launch_config.get("record_config", {})
+        self.enable_record: bool = (
+            self.record_config.get("enable", False) and self.rank == 0
+        )
+        if self.enable_record:
+            # Create the queue before starting the writer thread so that
+            # log_operation() can never race with queue initialization.
+            self.log_queue: queue.Queue = queue.Queue(maxsize=10000)  # Max cache: 10000 entries
+            self.write_thread = threading.Thread(
+                target=self._async_record_loop, daemon=True
+            )
+            self.write_thread.start()
+
+    def log_operation(self, operation_data: dict[str, Any]) -> None:
+        """Record an operation log entry (non-blocking)."""
+        if not self.enable_record:
+            return
+
+        default_data = {
+            "timestamp": time.time(),
+            "op_type": "None",
+            "block_size": self.block_size,
+        }
+        log_entry = {**default_data, **operation_data}
+
+        try:
+            self.log_queue.put_nowait(log_entry)
+        except queue.Full:
+            logger.error(
+                f"Log queue is full, dropping one log: {log_entry.get('op_type')}"
+            )
+
+    def _async_record_loop(self) -> None:
+        log_path = self.record_config.get(
+            "log_path", "/vllm-workspace/ucm_logs/ucm_ops.log"
+        )
+        flush_size = self.record_config.get("flush_size", 100)
+        flush_interval = self.record_config.get("flush_interval", 5.0)
+        batch_buffer: list[dict[str, Any]] = []
+        last_flush_time = time.time()
+        while True:
+            is_flush = False
+            current_time = time.time()
+            try:
+                # Get one log entry from the queue (1 second timeout)
+                log_entry = self.log_queue.get(timeout=1.0)
+                batch_buffer.append(log_entry)
+
+                # Flush if the batch is full or the flush interval elapsed
+                if (
+                    len(batch_buffer) >= flush_size
+                    or (current_time - last_flush_time) >= flush_interval
+                ):
+                    is_flush = True
+                    last_flush_time = current_time
+                self.log_queue.task_done()
+            except queue.Empty:
+                # Flush pending entries even when no new log arrives,
+                # otherwise a quiet period would strand buffered entries.
+                if batch_buffer and (current_time - last_flush_time) >= flush_interval:
+                    is_flush = True
+                    last_flush_time = current_time
+            except Exception as e:
+                logger.error(f"Log thread exception: {str(e)}")
+
+            if is_flush:
+                with open(log_path, "a", encoding="utf-8") as f:
+                    for log_entry in batch_buffer:
+                        f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
+                batch_buffer.clear()
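
For context on the output format: `_async_record_loop` writes each entry that `log_operation` enqueues as one JSON object per line, so the log file is plain JSONL with `timestamp`, `op_type`, `block_size`, and `blocks` fields. A minimal sketch of consuming it for offline analysis follows; the path mirrors the `record_config` example above, and this reader script is illustrative, not part of the change:

```python
import json
from collections import Counter

# Path from the record_config example above; adjust to your deployment.
LOG_PATH = "/workspace/ucm_ops.log"

op_counts: Counter = Counter()
blocks_touched = 0
with open(LOG_PATH, encoding="utf-8") as f:
    for line in f:
        entry = json.loads(line)          # one JSON object per line (JSONL)
        op_counts[entry["op_type"]] += 1  # "load" or "dump"
        blocks_touched += len(entry.get("blocks", []))

print(f"ops: {dict(op_counts)}, blocks touched: {blocks_touched}")
```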