diff --git a/test/common/uc_eval/task.py b/test/common/uc_eval/task.py new file mode 100644 index 00000000..eae5e3de --- /dev/null +++ b/test/common/uc_eval/task.py @@ -0,0 +1,409 @@ +import time +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Union + +from common.uc_eval.utils.config_loader import ConfigLoader, TaskFactory +from common.uc_eval.utils.data_class import ( + BenchmarkModeType, + EvalConfig, + LatencyStatistics, + ModelConfig, + MultiTurnDialogRecord, + PerfConfig, + RequestRecord, + SynthericParams, +) +from common.uc_eval.utils.utils import FileUtil, PathUtil, get_current_time, get_logger + +MS_SCALE = 1000 +BAD_COMPLETION_TOKENS_THR = 20 +logger = get_logger() +PERF_CSV_HEADER = [ + "Test Time", + "Total Cases", + "Parallel Num", + "Prefix Cache", + "Total Latency(ms)", + "E2E TPS(tokens/s)", + "Per Request TPS(tokens/s)", + "TTFT P50(ms)", + "TTFT P90(ms)", + "TTFT P99(ms)", + "MAX TTFT(ms)", + "Average TTFT(ms)", + "TBT P50(ms)", + "TBT P90(ms)", + "TBT P99(ms)", + "TBT MAX(ms)", + "TBT Average(ms)", +] + +SYNC_PERF_CSV_HEADER = [ + "Test Time", + "Parallel Num", + "Input Tokens", + "Output Tokens", + "Parallel Num", + "Prefix Cache", + "Hit Rate", + "Total Latency(ms)", + "E2E TPS(tokens/s)", + "Per Request TPS(tokens/s)", + "TTFT P50(ms)", + "TTFT P90(ms)", + "TTFT P99(ms)", + "MAX TTFT(ms)", + "Average TTFT(ms)", + "TBT P50(ms)", + "TBT P90(ms)", + "TBT P99(ms)", + "TBT MAX(ms)", + "TBT Average(ms)", +] + +CASE_PERF_CSV_HEADER = [ + "Test Time", + "Prefix Cache", + "Total Cases", + "Current Case", + "Case Name", + "Input Tokens", + "Output Tokens", + "Latency(ms)", + "TTFT(ms)", + "TBT(ms)", +] + +CASE_EVAL_CSV_HEADER = [ + "Test Time", + "Prefix Cache", + "Total Cases", + "Current Case", + "Case Name", + "Input Tokens", + "Output Tokens", + "Input Text", + "Question", + "Expected Output", + "Real Output", + "Is Match", + "Match Class", +] + + +class BaseTask(ABC): + def __init__( + self, + model_config: ModelConfig, + perf_config: PerfConfig = None, + eval_config: EvalConfig = None, + save_to_excel: bool = True, + file_save_path: str = None, + ): + ConfigLoader(model_config, perf_config, eval_config) + self.current_time = get_current_time() + self.model_config = model_config + self.perf_config = perf_config + self.eval_config = eval_config + common_config = perf_config if perf_config else eval_config + self.data_type = common_config.data_type + self.parallel_num = common_config.parallel_num + self.enable_prefix_cache = common_config.enable_prefix_cache + self.benchmark_mode = common_config.benchmark_mode + self.save_to_excel = save_to_excel + self.file_save_path = PathUtil.get_datasets_dir_path(file_save_path).joinpath( + self.benchmark_mode, f"{self.data_type}_latency.xlsx" + ) + + self.dataset, self.client, self.benchmark = TaskFactory.create_task( + model_config, perf_config, eval_config + ) + + def run(self): + logger.info("-----------------------------------------------------------") + logger.info( + f"Begin test, the data type: {self.data_type}, the benchmark mode: {self.benchmark_mode}" + ) + latency_results, case_len = self.process() + result_to_pytest = self.pytest_result(latency_results, case_len) + return result_to_pytest + + @abstractmethod + def process(self) -> Any: + raise NotImplementedError + + def pytest_result( + self, records: Union[LatencyStatistics, List[Dict]], case_len: int + ): + if isinstance(records, list): + # If records is a list, it indicates the result of SyntheticPerfTask which has been processed before + 
return records + + data_dict = self.update_single_record(records, case_len) + data = list(data_dict.values()) + if self.perf_config and self.save_to_excel: + logger.info( + f"Begin save latency data to excel, file name: {self.file_save_path}" + ) + FileUtil.save_excel( + self.file_save_path, [data], PERF_CSV_HEADER, "Overall Performance" + ) + return data_dict + + def update_single_record(self, record: LatencyStatistics, case_len: int): + logger.info(f"There are {case_len} cases to save to the database.") + data_dict = { + "current_time": self.current_time, + "total_case_num": case_len, + "parallel_num": self.parallel_num, + "enable_prefix_cache": self.enable_prefix_cache, + } + record_dict = record.to_dict() + metric_key = list(record_dict.keys())[-1] + latency_key = list(record_dict.keys())[:-1] + if self.perf_config: + data_dict.update({k: record_dict[k] for k in latency_key}) + else: + data_dict.update({metric_key: record_dict[metric_key]}) + + return data_dict + + def save_perf_cases_excel( + self, records: List[RequestRecord | MultiTurnDialogRecord] + ): + save_data = [] + common_columns = [self.current_time, self.enable_prefix_cache] + for idx, record in enumerate(records): + if isinstance(record, MultiTurnDialogRecord): + columns = common_columns + [record.total_turns, record.turn_id] + elif isinstance(record, RequestRecord): + columns = common_columns + [len(records), idx] + columns += [ + record.case_name, + record.input_tokens, + record.output_tokens, + round(record.req_cost * MS_SCALE, 3), + round(record.prefill_latency * MS_SCALE, 3), + round(record.tbt_latency * MS_SCALE, 3), + ] + save_data.append(columns) + FileUtil.save_excel( + self.file_save_path, + save_data, + CASE_PERF_CSV_HEADER, + "Single Case Performance", + ) + + def save_eval_cases_excel( + self, records: List[RequestRecord | MultiTurnDialogRecord], match_cls: str + ): + save_data = [] + common_columns = [self.current_time, self.enable_prefix_cache] + for idx, record in enumerate(records): + if isinstance(record, MultiTurnDialogRecord): + columns = common_columns + [record.total_turns, record.turn_id] + elif isinstance(record, RequestRecord): + columns = common_columns + [len(records), idx] + columns += [ + record.case_name, + record.input_tokens, + record.output_tokens, + record.input_data, + record.question, + record.expected_output, + record.output_data, + record.is_match, + match_cls, + ] + save_data.append(columns) + FileUtil.save_excel( + self.file_save_path, + save_data, + CASE_EVAL_CSV_HEADER, + "Single Case Evaluation", + ) + + +class SyntheticPerfTask(BaseTask): + def __init__( + self, + model_config: ModelConfig, + perf_config: PerfConfig, + file_save_path: str, + stable_rate: int = 5, + ): + super().__init__( + model_config=model_config, + perf_config=perf_config, + file_save_path=file_save_path, + ) + self.enable_clear_hbm = model_config.enable_clear_hbm + self.prompt_tokens = perf_config.prompt_tokens + self.output_tokens = perf_config.output_tokens + self.prefix_cache_num = perf_config.prefix_cache_num + self.prompt_seed = 0 if self.enable_prefix_cache else -1 + self.stable_perf = self.benchmark_mode == BenchmarkModeType.STABLE_PREF + self.stable_rate = stable_rate + + def process(self): + result = [] + for parallel_num in self.parallel_num: + for idx in range(len(self.prompt_tokens)): + syntheric_params = SynthericParams() + syntheric_params.parallel_num = parallel_num + if self.stable_perf: + syntheric_params.parallel_num *= self.stable_rate + if self.enable_prefix_cache: + syntheric_params.seeds 
= [ + self.prompt_seed + i + for i in range(syntheric_params.parallel_num) + ] + self.prompt_seed += syntheric_params.parallel_num + else: + syntheric_params.seeds = [ + self.prompt_seed + ] * syntheric_params.parallel_num + syntheric_params.prompt_tokens = self.prompt_tokens[idx] + syntheric_params.prefix_cache_tokens = ( + int(self.prefix_cache_num[idx] * syntheric_params.prompt_tokens) + if self.enable_prefix_cache + else 0 + ) + logger.info( + f"Performance benchmark running with: enable prefix cache: ({self.enable_prefix_cache}), {syntheric_params=}" + ) + if self.enable_prefix_cache and self.prefix_cache_num[idx] > 0: + logger.info(f"Begin build kvcache...") + input_data = self.dataset.prepare_data(syntheric_params) + self.client.handle_requests_with_pool( + input_data, parallel_num, BAD_COMPLETION_TOKENS_THR + ) + logger.info( + "To ensure thal all kvcache is offload2ssd, sleep for 10 seconds" + ) + time.sleep(10) + + if self.enable_clear_hbm: + self.client.clear_hbm() + + logger.info(f"Begin post cases...") + input_data = self.dataset.prepare_data(syntheric_params) + records: List[RequestRecord] = self.client.handle_requests_with_pool( + input_data, parallel_num, self.output_tokens[idx] + ) + latency_statistics: LatencyStatistics = self.benchmark.perf_show( + records, parallel_num + ) + # Make sure to store the data after each test is completed, to prevent data loss after a request fails + data_dict = { + "current_time": self.current_time, + "total_case_num": syntheric_params.parallel_num, + "input_tokens": self.prompt_tokens[idx], + "output_tokens": self.output_tokens[idx], + "parallel_num": parallel_num, + "enable_prefix_cache": self.enable_prefix_cache, + "prefix_cache_num": ( + self.prefix_cache_num[idx] if self.enable_prefix_cache else 0 + ), + } + latency_dict = latency_statistics.to_dict() + data_dict.update(dict(list(latency_dict.items())[:-1])) + data = data_dict.values() + if self.save_to_excel: + logger.info( + f"Begin save latency data to excel, file name: {self.file_save_path}" + ) + FileUtil.save_excel( + self.file_save_path, + [data], + SYNC_PERF_CSV_HEADER, + "Overall Performance", + ) + + result.append(data_dict) + + return result, len(result) + + +class MultiTurnDialogPerfTask(BaseTask): + def __init__( + self, model_config: ModelConfig, perf_config: PerfConfig, file_save_path: str + ): + super().__init__( + model_config=model_config, + perf_config=perf_config, + file_save_path=file_save_path, + ) + self.dataset_file_path = perf_config.dataset_file_path + + def process(self): + cases = self.dataset.prepare_data(self.dataset_file_path) + records: List[List[MultiTurnDialogRecord]] = ( + self.client.handle_requests_with_pool(cases, self.parallel_num) + ) + for record in records: + self.save_perf_cases_excel(record) + all_records = [r for record in records for r in record] + latency_statistics = self.benchmark.perf_show(all_records, self.parallel_num) + return latency_statistics, len(records) + + +class DocQaPerfTask(BaseTask): + def __init__( + self, model_config: ModelConfig, perf_config: PerfConfig, file_save_path: str + ): + super().__init__( + model_config=model_config, + perf_config=perf_config, + file_save_path=file_save_path, + ) + self.dataset_file_path = perf_config.dataset_file_path + self.max_tokens = model_config.payload.get("max_tokens") + + def process(self): + cases_list = self.dataset.prepare_data(self.dataset_file_path) + if self.enable_prefix_cache: + logger.info("Begin build kvcache...") + self.client.handle_requests_with_pool( + cases_list, 
self.parallel_num, BAD_COMPLETION_TOKENS_THR + ) + + logger.info("Begin post cases...") + records: List[RequestRecord] = self.client.handle_requests_with_pool( + cases_list, self.parallel_num, self.max_tokens + ) + self.save_perf_cases_excel(records) + latency_statistics = self.benchmark.perf_show(records, self.parallel_num) + return latency_statistics, len(records) + + +class DocQaEvalTask(BaseTask): + def __init__( + self, model_config: ModelConfig, eval_config: EvalConfig, file_save_path: str + ): + super().__init__( + model_config=model_config, + eval_config=eval_config, + file_save_path=file_save_path, + ) + self.dataset_file_path = eval_config.dataset_file_path + self.max_tokens = model_config.payload.get("max_tokens") + self.eval_cls = eval_config.eval_class + + def process(self): + cases_list = self.dataset.prepare_data(self.dataset_file_path) + if self.enable_prefix_cache: + logger.info("Begin build kvcache...") + self.client.handle_requests_with_pool( + cases_list, self.parallel_num, BAD_COMPLETION_TOKENS_THR + ) + + logger.info("Begin post cases...") + records: List[RequestRecord] = self.client.handle_requests_with_pool( + cases_list, self.parallel_num, self.max_tokens + ) + metric_result, match_record_list = self.benchmark.perf_show( + records, self.parallel_num + ) + self.save_eval_cases_excel(match_record_list, self.eval_cls) + return metric_result, len(records) diff --git a/test/common/uc_eval/utils/benchmark.py b/test/common/uc_eval/utils/benchmark.py new file mode 100644 index 00000000..996829df --- /dev/null +++ b/test/common/uc_eval/utils/benchmark.py @@ -0,0 +1,276 @@ +import functools +import importlib +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +import numpy as np +from common.uc_eval.utils.data_class import ( + EvalConfig, + LatencyStatistics, + MultiTurnDialogRecord, + RequestRecord, +) +from common.uc_eval.utils.utils import get_logger +from tqdm import tqdm + +logger = get_logger() +MS_SCALE = 1000 +# the max wave rate for stable perf +MAX_WAVE_RATE = 0.05 + + +def make_object(object_ref: str, *args: Any, **kwargs: Any) -> Any: + """create object based on class name""" + modname, qualname_separator, qualname = object_ref.partition(":") + obj = importlib.import_module(modname) + if qualname_separator: + for attr in qualname.split("."): + obj = getattr(obj, attr) + return functools.partial(obj, *args, **kwargs) + + +class BenchmarkBase(ABC): + def __init__(self, eval_config: Optional[EvalConfig], stable_perf: bool = False): + self.eval_config = eval_config + self.stable_perf = stable_perf + + def get_success_request(self, data: List[RequestRecord | MultiTurnDialogRecord]): + """ + Get the successful request from the record + """ + success_request = [] + for request in data: + if request.is_success: + success_request.append(request) + if len(success_request) == 0: + logger.warning(f"No success request found, please check the result") + return success_request + + def result_to_column_dict( + self, data: List[RequestRecord | MultiTurnDialogRecord] + ) -> Dict[str, List[Any]]: + """ + format: list[dict] ---> dict[list] + """ + if not data: + return {} + keys = list(data[0].to_dict().keys()) + result = {key: [] for key in keys} + for item in data: + for key in keys: + result[key].append(item.to_dict()[key]) + return result + + @abstractmethod + def perf_show(self, records: Any, parallel_num: int = 1): + raise NotImplementedError + + +class EvaluatorBenchmark(BenchmarkBase): + def __init__(self, eval_config: EvalConfig): + 
super().__init__(eval_config=eval_config)
+        self.metric_method = eval_config.metrics
+        self.eval_class = eval_config.eval_class
+
+    def perf_show(
+        self,
+        record_list: List[RequestRecord | MultiTurnDialogRecord],
+        parallel_num: int,
+    ):
+        logger.info("Begin calculating metrics...")
+        success_request = self.get_success_request(record_list)
+        eval_cls = make_object(self.eval_class)(success_request)
+        latency = LatencyStatistics()
+        metric_result = eval_cls.calculate_metric(self.metric_method)
+        latency.metric_dict = metric_result
+        match_record_list = eval_cls.record_list
+
+        return latency, match_record_list
+
+
+class PerformanceBenchmark(BenchmarkBase):
+    def __init__(self, stable_perf: bool):
+        super().__init__(eval_config=None, stable_perf=stable_perf)
+        self.stable_perf = stable_perf
+        self.stable_work_time = [0, 0]
+
+    def perf_show(
+        self,
+        record_list: List[RequestRecord | MultiTurnDialogRecord],
+        parallel_num: int,
+    ) -> LatencyStatistics:
+        logger.info("Begin calculating latency...")
+        success_request = self.get_success_request(record_list)
+        request_record_dict = self.result_to_column_dict(success_request)
+        if self.stable_perf:
+            request_ids = self._get_stable_request_id(request_record_dict, parallel_num)
+        else:
+            request_ids = request_record_dict.get("request_id")
+        records = [record for record in record_list if record.request_id in request_ids]
+        perf_result = self._get_performance_data(records)
+        return perf_result
+
+    def _get_performance_data(
+        self, record_list: List[RequestRecord | MultiTurnDialogRecord]
+    ) -> LatencyStatistics:
+        """
+        After all requests are completed, get the performance data
+        """
+        if len(record_list) == 0:
+            logger.warning("record_list is empty, please check the results")
+        latency = LatencyStatistics()
+        record_dict = self.result_to_column_dict(record_list)
+
+        e2e_latency_all = (
+            max(record_dict["end_time"]) - min(record_dict["start_time"])
+        ) * MS_SCALE
+        latency.e2e_latency_all = round(e2e_latency_all, 2)
+        logger.debug("All request latencies: %.4f ms", e2e_latency_all)
+
+        total_output_tokens = sum(record_dict["output_tokens"])
+        output_token_throughput = total_output_tokens / e2e_latency_all * MS_SCALE
+        latency.output_token_throughput = round(output_token_throughput, 2)
+        logger.debug(
+            "Total output token throughput: %.4f tokens/s", output_token_throughput
+        )
+
+        throughputs = []
+        for tokens, cost in zip(record_dict["output_tokens"], record_dict["req_cost"]):
+            if cost > 0:
+                throughputs.append(tokens / cost)
+        if throughputs:
+            token_throughput_per_request = np.mean(throughputs).item()
+            latency.token_throughput_per_request = round(
+                token_throughput_per_request, 2
+            )
+            logger.debug(
+                "Average per-request throughput: %.4f tokens/s",
+                token_throughput_per_request,
+            )
+        else:
+            logger.warning("No valid requests for throughput calculation")
+
+        prefill_latency_list = [record_dict["prefill_latency"]]
+        p50_prefill_latency = np.percentile(prefill_latency_list, 50).item() * MS_SCALE
+        latency.p50_prefill_latency = round(p50_prefill_latency, 2)
+        logger.debug("Time to first token latency P50: %.4f ms", p50_prefill_latency)
+
+        p90_prefill_latency = np.percentile(prefill_latency_list, 90).item() * MS_SCALE
+        latency.p90_prefill_latency = round(p90_prefill_latency, 2)
+        logger.debug("Time to first token latency P90: %.4f ms", p90_prefill_latency)
+
+        p99_prefill_latency = np.percentile(prefill_latency_list, 99).item() * MS_SCALE
+        latency.p99_prefill_latency = round(p99_prefill_latency, 2)
+        logger.debug("Time to first token latency P99: %.4f ms", p99_prefill_latency)
+
+        max_prefill_latency = np.max(prefill_latency_list).item() * MS_SCALE
+        latency.max_prefill_latency = round(max_prefill_latency, 2)
+        logger.debug(
+            "Maximum time to first token latency: %.4f ms", max_prefill_latency
+        )
+
+        avg_prefill_latency = np.mean(prefill_latency_list).item() * MS_SCALE
+        latency.avg_prefill_latency = round(avg_prefill_latency, 2)
+        logger.debug(
+            "Average time to first token latency: %.4f ms", avg_prefill_latency
+        )
+
+        decode_latency_list = []
+        for tbt_latency in record_dict["tbt_latency"]:
+            decode_latency_list.append(tbt_latency)
+
+        p50_decode_latency = np.percentile(decode_latency_list, 50).item() * MS_SCALE
+        latency.p50_decode_latency = round(p50_decode_latency, 2)
+        logger.debug("Time between tokens latency P50: %.4f ms", p50_decode_latency)
+
+        p90_decode_latency = np.percentile(decode_latency_list, 90).item() * MS_SCALE
+        latency.p90_decode_latency = round(p90_decode_latency, 2)
+        logger.debug("Time between tokens latency P90: %.4f ms", p90_decode_latency)
+
+        p99_decode_latency = np.percentile(decode_latency_list, 99).item() * MS_SCALE
+        latency.p99_decode_latency = round(p99_decode_latency, 2)
+        logger.debug("Time between tokens latency P99: %.4f ms", p99_decode_latency)
+
+        max_decode_latency = np.max(decode_latency_list).item() * MS_SCALE
+        latency.max_decode_latency = round(max_decode_latency, 2)
+        logger.debug("Maximum time between tokens latency: %.4f ms", max_decode_latency)
+
+        avg_decode_latency = np.mean(decode_latency_list).item() * MS_SCALE
+        latency.avg_decode_latency = round(avg_decode_latency, 2)
+        logger.debug("Average time between tokens latency: %.4f ms", avg_decode_latency)
+
+        return latency
+
+    def _get_stable_request_id(
+        self, result: Dict[str, List[Any]], target_concurrency: int
+    ):
+        """
+        Get steady-state request ids via start_time vs.
end_time delta + """ + # the number of concurrent requests at each request start and end + request_num = len(result.get("request_id", [])) + concurrent_levels = [0] * 2 * request_num + request_events = [] + for idx in range(request_num): + request_events.append( + { + "request_id": result.get("request_id", [])[idx], + "event_type": "start", + "timestamp": result.get("start_time", [])[idx], + } + ) + request_events.append( + { + "request_id": result.get("request_id", [])[idx], + "event_type": "end", + "timestamp": result.get("end_time", [])[idx], + } + ) + sorted_events = sorted(request_events, key=lambda x: x["timestamp"]) + stable_stage_requests = [] + logger.info("Start calculating stable request id") + used_request_num = 0 + for idx, item in enumerate( + tqdm(sorted_events, desc="search stable request id") + ): + if item["event_type"] == "start": + used_request_num += 1 + concurrent_levels[idx] = ( + concurrent_levels[idx - 1] + 1 if idx > 0 else 1 + ) + else: + concurrent_levels[idx] = concurrent_levels[idx - 1] - 1 + if ( + item["event_type"] == "start" + and concurrent_levels[idx] == target_concurrency + ): + stable_stage_requests.append(item["request_id"]) + if len(stable_stage_requests) == 2: + self.stable_work_time[0] = item["timestamp"] + elif ( + item["event_type"] == "start" + and concurrent_levels[idx] + >= int(target_concurrency * (1 - MAX_WAVE_RATE)) + and len(stable_stage_requests) > 2 + ): + stable_stage_requests.append(item["request_id"]) + elif used_request_num == request_num and item["event_type"] == "end": + self.stable_work_time[1] = item["timestamp"] + break + elif ( + len(stable_stage_requests) > 1 + and item["event_type"] == "end" + and concurrent_levels[idx] + < int(target_concurrency * (1 - MAX_WAVE_RATE)) + ): + self.stable_work_time[1] = item["timestamp"] + break + + if len(stable_stage_requests) > 1: + # ignore first request + stable_stage_requests.pop(0) + if len(stable_stage_requests) == 0: + logger.error("cannot find stable stage, please check your settings") + raise ValueError("cannot find stable stage, please check your settings") + logger.info(f"stable request id list: {stable_stage_requests=}") + return stable_stage_requests diff --git a/test/common/uc_eval/utils/client.py b/test/common/uc_eval/utils/client.py new file mode 100644 index 00000000..82c509e2 --- /dev/null +++ b/test/common/uc_eval/utils/client.py @@ -0,0 +1,490 @@ +import concurrent.futures +import copy +import json +import os +import time +import uuid +from concurrent.futures import ThreadPoolExecutor +from typing import Dict, List, Optional, Union + +import requests +from common.uc_eval.utils.data_class import ( + ModelConfig, + MultiTurnDialogRecord, + RequestRecord, +) +from common.uc_eval.utils.utils import PathUtil, get_logger +from tqdm import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizer +from typing_extensions import override + +logger = get_logger() +TIMEOUT = 6000 +HEADERS = {"User-Agent": "Benchmark Client", "Content-Type": "application/json"} +CHUNK_SIZE = 2**16 + + +def _excute_with_pool( + task_func: callable, + process_func: callable, + tasks: List, + parallel_num: int, + desc: str = "Processing Requests", +) -> List[RequestRecord | MultiTurnDialogRecord]: + record_results: List[RequestRecord | MultiTurnDialogRecord] = [] + if parallel_num > len(tasks): + logger.error( + f"The number of requests: {len(tasks)} is less than parallel_num: {parallel_num}, please check..." 
+ ) + raise ValueError( + f"The number of requests: {len(tasks)} is less than parallel_num: {parallel_num}, please check..." + ) + logger.info(f"Start to send {len(tasks)} requests to server...") + with ThreadPoolExecutor(max_workers=parallel_num) as executor: + futures = [executor.submit(task_func, task) for task in tasks] + + with tqdm(total=len(futures), desc=desc, mininterval=0.5) as pbar: + for future in concurrent.futures.as_completed(futures): + try: + pbar.update(1) + result = process_func(future.result()) + record_results.append(result) + pbar.set_postfix( + { + "Completed": len(record_results), + "Pending": len(futures) - pbar.n, + } + ) + except Exception as e: + pbar.update(1) + logger.error(f"Requested failed: {str(e)}") + raise Exception(f"Requested failed: {str(e)}") + return record_results + + +class BaseClient: + def __init__( + self, + config: ModelConfig, + stream: bool = False, + **kwargs, + ): + self.ip_ports = config.ip_ports + self.url = f"http://{self.ip_ports}/v1/chat/completions" + self.served_model_name = config.served_model_name + tokenizer_path = PathUtil.get_datasets_dir_path(config.tokenizer_path) + self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained( + tokenizer_path + ) + self.session = requests.Session() + self.payload = config.payload + self.stream = stream + if self.stream: + self.payload.update( + {"stream": True, "ignore_eos": True, "temperature": 0.0} + ) + else: + self.payload.update( + {"stream": False, "ignore_eos": False, "temperature": 0.0} + ) + + def handle_requests_with_pool( + self, prompt_list: List, parallel_num: int, max_tokens: int + ) -> List[RequestRecord]: + return _excute_with_pool( + task_func=lambda prompt: self.send_request(prompt, max_tokens), + process_func=self.update_request_record, + tasks=prompt_list, + parallel_num=parallel_num, + ) + + def send_request(self, prompt, max_tokens) -> List[RequestRecord]: + """ + update payload and send request + """ + payload = self._update_payload(prompt, max_tokens) + record = self._create_record(prompt) + if self.stream: + record = self.do_stream_request(payload, record) + else: + record = self.do_request(payload, record) + return record + + def _update_payload(self, prompt, max_tokens) -> Dict: + """ + update request payload + """ + payload = copy.deepcopy(self.payload) + payload.update({"model": self.served_model_name}) + # If payload already has default max_tokens, the input max_tokens will be set to 0 + if max_tokens > 0: + payload.update({"max_tokens": max_tokens}) + if isinstance(prompt, str): + message = [{"role": "user", "content": prompt}] + if isinstance(prompt, list): + # Multi-turn conversation - prompt already contains full message history. + # No need to update messages as they are already properly formatted + message = prompt + payload.update({"messages": message}) + + return payload + + def _create_record(self, prompt): + # If the prompt is not a dict, it must be a list of dicts for multi-turn dialogue. 
+ if isinstance(prompt, str): + record = RequestRecord(input_data=prompt) + else: + record = RequestRecord(input_data=str(prompt)) + + return record + + def update_request_record( + self, records: Union[RequestRecord, List[RequestRecord]] + ) -> Union[RequestRecord, List[RequestRecord]]: + """ + Get the number of input and output tokens for each request record + """ + if not records: + logger.warning("No records to update, please check...") + if isinstance(records, RequestRecord): + single_record = records + records = [single_record] + else: + single_record = None + + for record in records: + record.input_tokens = len(self.tokenizer.tokenize(record.input_data)) + record.output_tokens = len(self.tokenizer.tokenize(record.output_data)) + record.tbt_list = record.tbt_list[2:] if record.tbt_list else [] + record.tbt_latency = ( + sum(record.tbt_list) / len(record.tbt_list) if record.tbt_list else 0 + ) + + return records[0] if single_record is not None else records + + def _requset(self, payload): + response = None + try: + response = self.session.post( + self.url, + headers=HEADERS, + json=payload, + timeout=TIMEOUT, + stream=self.stream, + ) + response.raise_for_status() + return response + except Exception as err: + raise self._handle_request_error(err) + + def do_request(self, payload: Dict, record: RequestRecord) -> RequestRecord: + record.start_time = time.time() + + response = self._requset(payload) + result = json.loads(response.text) + request_id = result.get("id", "request_id not found") + output = self._get_message_from_response(result) + + record.request_id = request_id + record.output_data = output + record.is_success = True + record.end_time = time.time() + record.req_cost = record.end_time - record.start_time + return record + + def _get_message_from_response(self, response) -> str: + message = response.get("choices", [])[0].get("message", {}) + output = "" + if message.get("content", "") is not None: + output += message.get("content", "") + elif message.get("reasoning_content", "") is not None: + output += message.get("reasoning_content", "") + return output + + def do_stream_request(self, payload: Dict, record: RequestRecord) -> RequestRecord: + while True: + all_chunks = [] + first_token = True + last_chunk = None + timeout_finish_reason = False + cur_time = last_time = time.perf_counter() + record.start_time = last_time + response = self._requset(payload) + for chunk in response.iter_content(chunk_size=CHUNK_SIZE): + all_chunks.append(chunk) + if len(chunk.strip()) == 0: + continue + last_chunk = chunk + cur_time = time.perf_counter() + time_diff = cur_time - last_time + if first_token: + record.prefill_latency = time_diff + first_token = False + else: + record.tbt_list.append(time_diff) + last_time = cur_time + chunk_output = chunk[5:].strip().decode("utf-8") + + # when the MindIE engine side timeout, it will return timeout information + if chunk.startswith(b"Engine callback timeout"): + self._print_request_info( + request_id=record.request_id, + chunk=chunk, + content=record.output_data, + all_chunks=all_chunks, + payload=payload, + msg="Engine callback timeout", + ) + record.output_data = "TIMEOUT" + return record + if "[DONE]" in chunk_output: + logger.debug(f"Finished chunk: {chunk_output=}") + continue + output = self._get_message_from_stream_response( + json.loads(chunk_output) + ) + if record.request_id == "": + record.request_id = json.loads(chunk_output).get( + "id", "request_id not found" + ) + record.output_data += output + + # when the uc-vllm request timeout, 
finish_reason == "length" and the final output is empty + finish_reason = ( + json.loads(chunk_output) + .get("choices", [])[0] + .get("finish_reason", "") + ) + if finish_reason == "length": + timeout_finish_reason = True + + # handle the last chunk + if last_chunk.startswith(b"data:"): + chunk_output = last_chunk[5:].strip().decode("utf-8") + else: + chunk_output = last_chunk.strip().strip().decode("utf-8").rstrip("\0") + # while the last chunk meets the following conditions, the request is finished successfully + if "[DONE]" in chunk_output: + break + else: + self._print_request_info( + request_id=record.request_id, + chunk=chunk, + content=record.output_data, + all_chunks=all_chunks, + payload=payload, + msg="request failed, please retry!!!", + ) + break + # while the request is done, we need to check the content to see if the request is successful + if record.output_data == "": + if timeout_finish_reason: + self._print_request_info( + request_id=record.request_id, + chunk=chunk, + content=record.output_data, + all_chunks=all_chunks, + payload=payload, + msg="vllm server scheduling timeout, please check", + ) + return record + else: + self._print_request_info( + request_id=record.request_id, + chunk=chunk, + content=record.output_data, + all_chunks=all_chunks, + payload=payload, + msg="the request returned an empty message, which may be an unknown error on the engine side. Please check the specific reason!", + ) + return record + record.is_success = True + record.end_time = time.perf_counter() + record.req_cost = record.end_time - record.start_time + logger.debug(f"{record.request_id} finished, cost: {record.req_cost:.2f}s") + return record + + def _get_message_from_stream_response(self, response) -> str: + message = response.get("choices", [])[0].get("delta", {}) + output = "" + if message.get("content", "") is not None: + output += message.get("content", "") + elif message.get("reasoning_content", "") is not None: + output += message.get("reasoning_content", "") + return output + + def clear_hbm(self) -> bool: + """ + The API is used to clear HBM. It is available only when the serving backend is VLLM. 
+ """ + os.environ["NO_PROXY"] = "127.0.0.1, localhost, local, .local" + logger.info("Begin to clear HBM") + headers = {"Content-Type": "application/json"} + payload = {} + url = f"http://{self.ip_ports}/reset_prefix_cache" + try: + response = requests.post( + url, json=payload, headers=headers, timeout=TIMEOUT + ) + response.raise_for_status() + except Exception as err: + raise self._handle_request_error(err) + time.sleep(5) + logger.info("Clear HBM success") + return True + + def _handle_request_error(self, err: Exception) -> Exception: + """ + Used to handle request errors + """ + if isinstance(err, requests.exceptions.ConnectionError): + logger.error(f"Cannot connect to {self.url}, please check your network") + return ConnectionError(f"Cannot connect to {self.url}") + elif isinstance(err, requests.exceptions.Timeout): + logger.error("The request timed out, please check your server status") + return TimeoutError( + "The request timed out, please check your server status" + ) + elif isinstance(err, requests.exceptions.HTTPError): + status_code = err.response.status_code + if status_code == 404: + logger.error( + f"The requested resource does not exist, or the served model name is incorrect" + ) + else: + logger.error(f"HTTP error, status code: {status_code}") + return Exception(f"HTTP error, status code: {status_code}, err: {err}") + else: + logger.error(f"Other error: {err}") + return Exception(f"Other error: {err}") + + @staticmethod + def _print_request_info(**kwargs): + """print request info when the request is failed""" + for key, value in kwargs.items(): + value = ( + json.dumps(value, ensure_ascii=False) + if isinstance(value, dict) + else value + ) + logger.error(f"{key} => {value}") + + +class MultiDialogClient(BaseClient): + def __init__(self, config: ModelConfig, stream: bool, **kwargs): + super().__init__(config, stream, **kwargs) + self.uuid = uuid.uuid4().hex + self.enable_prefix_cache = kwargs.get("enable_prefix_cache", False) + + @override + def handle_requests_with_pool( + self, + cases: List[List[Union[str, Dict]]], + parallel_num: int, + max_tokens: int = -1, + ) -> List[List[MultiTurnDialogRecord]]: + return _excute_with_pool( + task_func=lambda case: self._send_multi_request(case, max_tokens), + process_func=self.update_request_record, + tasks=cases, + parallel_num=parallel_num, + ) + + def _send_multi_request( + self, case: List[Union[str, Dict]], max_tokens: int = -1 + ) -> List[MultiTurnDialogRecord]: + case_name, dialog = case + history, conv_record = [], [] + conversion = dialog["conversations"] + turns = self._convert_conversation_2_turns(conversion, 2) + for i, turn in enumerate(turns): + in_content, reply = turn[0]["content"], turn[1]["content"] + # Update payload, then send request + prompt = self._update_request_body(history, in_content) + record: RequestRecord = self.send_request(prompt, max_tokens) + record.case_name = case_name + history = self._update_history(history, in_content, reply) + multi_turn_record: MultiTurnDialogRecord = ( + self._update_multi_turn_request_record(record, len(turns), i) + ) + conv_record.append(multi_turn_record) + return conv_record + + def _update_multi_turn_request_record( + self, record: RequestRecord, total_turns: int, turn_id: int + ) -> MultiTurnDialogRecord: + """ + Update multi-tuen dialogue request record + """ + request_record = MultiTurnDialogRecord() + request_record.__dict__.update(record.__dict__) + request_record.total_turns = total_turns + request_record.turn_id = turn_id + return request_record + + 
@staticmethod + def _convert_conversation_2_turns(conversion_list: list, chunk_size: int): + """ + Convert conversation list to turns + """ + if chunk_size < 0: + raise ValueError(f"the chunk size {chunk_size} must be greater than 0") + num_full_chunks = len(conversion_list) // chunk_size + return [ + conversion_list[i * chunk_size : (i + 1) * chunk_size] + for i in range(num_full_chunks) + ] + + def _update_request_body(self, history: Optional[List[Dict]], in_content: str): + """ + Multi turn dialogue request body + """ + history = copy.deepcopy(history) + if history and self.enable_prefix_cache: + # To make sure the prefix cache is unique + history[0]["content"] = f"uuid: [{self.uuid}]" + history[0]["content"] + if history and not self.enable_prefix_cache: + history[0]["content"] = ( + f"uuid: [{uuid.uuid4().hex}]" + history[0]["content"] + ) + + message = history + [{"role": "user", "content": in_content}] + return message + + @staticmethod + def _update_history( + history: Optional[List[Dict]], in_content: str, out_content: str + ) -> List[Dict]: + """ + Update conversation history + """ + history.append({"role": "user", "content": in_content}) + history.append({"role": "assistant", "content": out_content}) + return history + + +class DocQaClient(BaseClient): + def __init__(self, config: ModelConfig, stream: bool, **kwargs): + super().__init__(config, stream, **kwargs) + + @override + def handle_requests_with_pool( + self, cases: List[Union[str, str, str]], parallel_num: int, max_tokens: int = -1 + ) -> List[List[MultiTurnDialogRecord]]: + return _excute_with_pool( + task_func=lambda case: self.send_qa_request(case, max_tokens), + process_func=self.update_request_record, + tasks=cases, + parallel_num=parallel_num, + ) + + def send_qa_request( + self, case: Union[str, str, str, str], max_tokens: int = -1 + ) -> RequestRecord: + case_name, context, question, answer = case + prompt = context + question + record: RequestRecord = self.send_request(prompt, max_tokens) + record.case_name = case_name + record.question = question + record.expected_output = answer + return record diff --git a/test/common/uc_eval/utils/config_loader.py b/test/common/uc_eval/utils/config_loader.py new file mode 100644 index 00000000..89ef9e47 --- /dev/null +++ b/test/common/uc_eval/utils/config_loader.py @@ -0,0 +1,201 @@ +import dataclasses +import json +from typing import Optional, Tuple + +from common.uc_eval.utils.benchmark import ( + BenchmarkBase, + EvaluatorBenchmark, + PerformanceBenchmark, +) +from common.uc_eval.utils.client import BaseClient, DocQaClient, MultiDialogClient +from common.uc_eval.utils.data_class import ( + BenchmarkModeType, + DatasetType, + EvalConfig, + ModelConfig, + PerfConfig, +) +from common.uc_eval.utils.dataloader import ( + BaseDataset, + DocQADataset, + MultiTurnDialogueDataset, + SyntheticDataset, +) +from common.uc_eval.utils.utils import get_logger + +logger = get_logger() + + +class ConfigLoader: + def __init__( + self, + model_config: ModelConfig, + perf_config: PerfConfig = None, + eval_config: EvalConfig = None, + ): + + self.model_config = model_config + self.perf_config = perf_config + self.eval_config = eval_config + self._valid_config() + + def _valid_config(self) -> bool: + logger.info("Validating config...") + if self.perf_config is not None and self.eval_config is not None: + raise ValueError( + "perf_config and eval_config are mutually exclusive, one must be None." 
+ ) + if self.perf_config is None and self.eval_config is None: + raise ValueError( + "At least one of perf_config or eval_config must be provided." + ) + + result = self._valid_model_config() and ( + self._valid_perf_config() + if self.perf_config is not None + else self._valid_eval_config() + ) + logger.info("Complete validation...") + return result + + def _valid_model_config(self) -> bool: + payload = self.model_config.payload + if isinstance(payload, str): + try: + self.model_config.payload = json.loads(payload) + except Exception as e: + raise ValueError(f"Invalid payload JSON format: {e}") + + empty_fields = [] + field_names = [field.name for field in dataclasses.fields(ModelConfig)] + for field_name in field_names: + value = getattr(self.model_config, field_name) + if value is None or (isinstance(value, str) and not value.strip()): + empty_fields.append(field_name) + + if empty_fields: + raise ValueError( + f"The following model config fields can't be empty: {', '.join(empty_fields)}" + ) + + return True + + def _valid_perf_config(self) -> bool: + data_type = self.perf_config.data_type + benchmark_mode = self.perf_config.benchmark_mode + if benchmark_mode not in [ + BenchmarkModeType.DEFAULT_PERF, + BenchmarkModeType.STABLE_PREF, + ]: + raise ValueError( + f"Invalid benchmark mode: {benchmark_mode}. Valid modes are: {BenchmarkModeType.DEFAULT_PERF}, {BenchmarkModeType.STABLE_PREF}" + ) + prompt_fields = ["prompt_tokens", "output_tokens"] + ( + ["prefix_cache_num"] if self.perf_config.enable_prefix_cache else [] + ) + if data_type == DatasetType.SYNTHETIC: + invalid_fields = [] + for field in prompt_fields: + value = getattr(self.perf_config, field) + if not isinstance(value, list) or not value: + invalid_fields.append(field) + if invalid_fields: + raise ValueError( + f"The following dataset config fields must be non-empty list for synthetic data: {', '.join(invalid_fields)}" + ) + + length = { + field: len(getattr(self.perf_config, field)) for field in prompt_fields + } + if len(set(length.values())) > 1: + raise ValueError( + f"The following dataset config is not matched: {', '.join(length.keys())}" + ) + else: + if self.perf_config.dataset_file_path is None: + raise ValueError( + f"dataset_file_path is required for {data_type} data type" + ) + if not isinstance(self.perf_config.parallel_num, int): + raise TypeError( + f"parallel_num must be an integer for {data_type} data type" + ) + not_empty_fields = [ + field for field in prompt_fields if getattr(self.perf_config, field) + ] + if not_empty_fields: + raise ValueError( + f"The following dataset fields should be None for {data_type} data type: {not_empty_fields}" + ) + + return True + + def _valid_eval_config(self) -> bool: + data_type = self.eval_config.data_type + dataset_file_path = self.eval_config.dataset_file_path + benchmark_mode = self.eval_config.benchmark_mode + parallem_num = self.eval_config.parallel_num + eval_cls = self.eval_config.eval_class + metrics = self.eval_config.metrics + if benchmark_mode != BenchmarkModeType.EVAL: + raise ValueError( + f"Invalid benchmark mode: {benchmark_mode}. 
Valid modes are: {BenchmarkModeType.EVAL}" + ) + if data_type == DatasetType.SYNTHETIC or dataset_file_path is None: + raise ValueError( + f"Invalid dataset type: {data_type} or Invalid dataset file path: {dataset_file_path}" + ) + if not isinstance(parallem_num, int): + raise TypeError( + f"parallel_num must be an integer for {data_type} data type" + ) + if not metrics or not eval_cls: + raise ValueError( + f"metrics and eval_class must be provided for {data_type} data type" + ) + + return True + + +class TaskFactory: + _dataset: BaseDataset = { + DatasetType.SYNTHETIC: SyntheticDataset, + DatasetType.MULTI_DIALOGUE: MultiTurnDialogueDataset, + DatasetType.DOC_QA: DocQADataset, + } + _client: BaseClient = { + DatasetType.SYNTHETIC: BaseClient, + DatasetType.MULTI_DIALOGUE: MultiDialogClient, + DatasetType.DOC_QA: DocQaClient, + } + _benchmark: BenchmarkBase = { + BenchmarkModeType.EVAL: EvaluatorBenchmark, + BenchmarkModeType.STABLE_PREF: PerformanceBenchmark, + BenchmarkModeType.DEFAULT_PERF: PerformanceBenchmark, + } + + @classmethod + def create_task( + cls, + model_config: ModelConfig, + perf_config: Optional[PerfConfig], + eval_config: Optional[EvalConfig], + ) -> Tuple[BaseDataset, BaseClient, BenchmarkBase]: + stream = False + data_type = (perf_config or eval_config).data_type + tokenizer_path = model_config.tokenizer_path + benchmark_mode = (perf_config or eval_config).benchmark_mode + stable = benchmark_mode == BenchmarkModeType.STABLE_PREF + if benchmark_mode in [ + BenchmarkModeType.STABLE_PREF, + BenchmarkModeType.DEFAULT_PERF, + ]: + stream = True + client_kwargs = {} + if data_type is DatasetType.MULTI_DIALOGUE: + client_kwargs["enable_prefix_cache"] = perf_config.enable_prefix_cache + return ( + cls._dataset[data_type](tokenizer_path), + cls._client[data_type](model_config, stream, **client_kwargs), + cls._benchmark[benchmark_mode](stable if perf_config else eval_config), + ) diff --git a/test/common/uc_eval/utils/data_class.py b/test/common/uc_eval/utils/data_class.py new file mode 100644 index 00000000..7cdbcc47 --- /dev/null +++ b/test/common/uc_eval/utils/data_class.py @@ -0,0 +1,167 @@ +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional + + +class DatasetType(str, Enum): + """ + The dataset type of uc_eval, including synthetic, multi-turn dialogue, and document-QA. + """ + + SYNTHETIC = "synthetic" + MULTI_DIALOGUE = "multi_turn_dialogue" + DOC_QA = "doc_qa" + + +class BenchmarkModeType(str, Enum): + """ + The benchmark mode of uc_eval, including evaluate, stable-perf, and default-perf. 
+ """ + + EVAL = "evaluate" + STABLE_PREF = "stable-perf" + DEFAULT_PERF = "default-perf" + + +@dataclass +class ModelConfig: + ip_ports: str = "" + tokenizer_path: str = "" + served_model_name: str = "" + enable_clear_hbm: bool = False + payload: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class EvalConfig: + data_type: str = "" + dataset_file_path: str = "" + enable_prefix_cache: str = False + parallel_num: int = 1 + benchmark_mode: str = "evaluate" + metrics: Optional[List[str]] = field(default_factory=list) + eval_class: Optional[str] = None + + +@dataclass +class PerfConfig: + data_type: str = "" + dataset_file_path: str = "" + enable_prefix_cache: bool = False + parallel_num: int | List[int] = 1 + prompt_tokens: List[int] = field(default_factory=list) + output_tokens: List[int] = field(default_factory=list) + prefix_cache_num: List[float] = field(default_factory=list) + benchmark_mode: str = "" + + +@dataclass +class SynthericParams: + """ + The parameters for synthetic dataset + """ + + parallel_num: int = -1 + # The number of tokens for total prompts + prompt_tokens: int = -1 + # The number of tokens for prefix cache + prefix_cache_tokens: int = -1 + # List of seeds, to ensure the prefix cache is consistent between warmup and inference + seeds: list[int] = field(default_factory=list) + + def to_dict(self): + return vars(self) + + +@dataclass +class RequestRecord: + """ + The record for single request + """ + + case_name: str = "" + request_id: str = "" + input_data: Optional[str] = "" + input_tokens: int = 0 + # The real output + output_data: str = "" + output_tokens: int = 0 + # The expected output + expected_output: str = "" + # The question of the request + question: str = "" + start_time: float = 0.0 + end_time: float = 0.0 + # The cost of the request + req_cost: float = 0.0 + # Time to first token, cost of the prefill + prefill_latency: float = 0.0 + # Time between tokens + tbt_list: list[float] = field(default_factory=list) + # Average latency of the tbt_list + tbt_latency: float = 0.0 + # Whether the request is successful + is_success: bool = False + # whether the output_data matches the expected output + is_match: bool = False + + def to_dict(self): + return vars(self) + + +@dataclass +class MultiTurnDialogRecord(RequestRecord): + """ + The record for multi-turn dialogue request + """ + + # The total turn of the conversation + total_turns: int = -1 + # The current turn of the dialog + turn_id: int = -1 + # The input content of this dialog, which deletes the history information + in_content: str = "" + # If this request belongs to QA dialog + is_qa: bool = False + + def to_dict(self): + return vars(self) + + +@dataclass +class LatencyStatistics: + """ + the latency statistics of all requests + """ + + # The total latency of all requests(ms) + e2e_latency_all: float = -1 + # The end to end average throughput(tokens/s) + output_token_throughput: float = -1 + # The average throughput of all requests(tokens/s) + token_throughput_per_request: float = -1 + # The TP50 latency of time to first tokens(ms) + p50_prefill_latency: float = -1 + # The TP90 latency of time to first tokens(ms) + p90_prefill_latency: float = -1 + # The TP99 latency of time to first tokens(ms) + p99_prefill_latency: float = -1 + # The max latency of time to first tokens(ms) + max_prefill_latency: float = -1 + # The average latency of time to first tokens(ms) + avg_prefill_latency: float = -1 + # The TP50 latency of decoder latency(ms) + p50_decode_latency: float = -1 + # The TP90 latency of 
decoder latency(ms) + p90_decode_latency: float = -1 + # The TP99 latency of decoder latency(ms) + p99_decode_latency: float = -1 + # The max latency of decoder latency(ms) + max_decode_latency: float = -1 + # The average latency of decoder latency(ms) + avg_decode_latency: float = -1 + # The metrics + metric_dict: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self): + return vars(self) diff --git a/test/common/uc_eval/utils/dataloader.py b/test/common/uc_eval/utils/dataloader.py new file mode 100644 index 00000000..3903d683 --- /dev/null +++ b/test/common/uc_eval/utils/dataloader.py @@ -0,0 +1,214 @@ +import json +import random +import time +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Union + +import numpy as np +from common.uc_eval.utils.data_class import SynthericParams +from common.uc_eval.utils.utils import PathUtil, get_logger +from tqdm import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizer + +logger = get_logger() +EPOCH_NUM = 10 + + +class BaseDataset(ABC): + def __init__( + self, + tokenizer_path: str = None, + ): + tokenizer_path = PathUtil.get_datasets_dir_path(tokenizer_path) + self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained( + tokenizer_path + ) + + @abstractmethod + def prepare_data(self, param: Any): + raise NotImplementedError + + +class SyntheticDataset(BaseDataset): + def __init__(self, tokenizer_path: str): + super().__init__(tokenizer_path) + + def prepare_data(self, syntheric_params: SynthericParams) -> list[str]: + prompt_list = [] + for parallel_num in tqdm( + range(syntheric_params.parallel_num), + desc="Generate synthetic data", + unit="prompt", + ): + random_prompt_len = max( + 0, syntheric_params.prompt_tokens - syntheric_params.prefix_cache_tokens + ) + random_prompt = self.generate_random_str(random_prompt_len, time.time_ns()) + if syntheric_params.prefix_cache_tokens > 0: + pc_prompt = self.generate_random_str( + syntheric_params.prefix_cache_tokens, + syntheric_params.seeds[parallel_num], + ) + else: + pc_prompt = "" + final_prompt = pc_prompt + random_prompt + prompt_list.append(final_prompt) + return prompt_list + + def generate_random_str(self, length: int, seed: int) -> str: + """ + Sample random tokens from the tokenizer using a seed. + Use timestamp when cache hit is not required; otherwise use an incrementing seed. 
+ """ + if length <= 0: + return "" + vocab_size = self.tokenizer.vocab_size + random.seed(seed) + ids_list = random.choices(range(vocab_size // 4, vocab_size // 3), k=length) + ids = np.array(ids_list) + text = self.tokenizer.decode(ids) + completion_token_ids = self.tokenizer([text]).input_ids + logger.debug( + f"len(completion_token_ids[0]) = {len(completion_token_ids[0])}, length = {length}" + ) + + epoch = EPOCH_NUM + while len(completion_token_ids[0]) != length and epoch > 0: + epoch -= 1 + while len(completion_token_ids[0]) > length: + diff = len(completion_token_ids[0]) - length + now_length = ids.shape[0] - diff + ids = ids[:now_length] + text = self.tokenizer.decode(ids) + completion_token_ids = self.tokenizer([text]).input_ids + + while len(completion_token_ids[0]) < length: + diff = length - len(completion_token_ids[0]) + diff_ids_list = random.choices( + range(vocab_size // 4, vocab_size // 3), k=diff + ) + diff_ids = np.array(diff_ids_list) + ids = np.append(ids, diff_ids) + text = self.tokenizer.decode(ids) + completion_token_ids = self.tokenizer([text]).input_ids + + if len(completion_token_ids[0]) != length: + logger.warning( + "The length of completion token ids is not equal to the length of input token ids" + ) + logger.warning( + f"Generate tokens, target: {length}, actual: {len(completion_token_ids[0])}" + ) + + return text + + +class MultiTurnDialogueDataset(BaseDataset): + def __init__(self, tokenizer_path: str): + super().__init__(tokenizer_path) + + def prepare_data(self, dataset_file_path) -> List[List[Union[str, Dict]]]: + """ + Load a JSON file containing multi-turn dialogue dataset paths. + :param file_path: JSON file listing multi-turn dialogue dataset paths to traverse. + the multi-turn dataset format: {"kimi": [{"conversion": [{"role": "user", "content": "xxx"}, ...], "qa": [{"question": "xxx", "answer": "xxx"}, ...]}]} + """ + cases = [] + # the path of multiturndialog.json + json_path = PathUtil.get_datasets_dir_path(dataset_file_path) + mtd_data: dict = self.load_json_file(json_path) + for dataset_name, files_list in mtd_data.items(): + for file_name in files_list: + case_path = PathUtil.get_dirname(json_path).joinpath( + dataset_name, file_name + ) + if case_path.exists(): + dialogues = self.load_json_file(case_path) + cases.extend(self.process_single_case_file(dialogues)) + else: + logger.warning( + f"JSON file {case_path} does not exist, please check the file path" + ) + if len(cases) == 0: + logger.warning( + f"The file {json_path} does not contain multi-turn dialogue data" + ) + return cases + + def process_single_case_file(self, dialogues: dict) -> List[List[Union[str, Dict]]]: + cases = [] + for dialogue_name, dialogue_data in dialogues.items(): + for i, dialog in enumerate(dialogue_data): + dialog_tokens = len( + self.tokenizer.tokenize(str(dialog["conversations"])) + ) + logger.info( + f"Current dialogue {dialogue_name}-{i} token count: {dialog_tokens}" + ) + cases.append([f"{dialogue_name}-{i}", dialog]) + return cases + + def load_json_file(self, file_path): + try: + with open(file_path, "r", encoding="utf-8") as f: + data = json.load(f) + return data + except FileNotFoundError: + logger.error(f"JSON file not found: {file_path}") + raise FileNotFoundError(f"JSON file not found: {file_path}") + except json.JSONDecodeError as e: + logger.error(f"JSON decode error in file {file_path}: {e}") + raise ValueError(f"Invalid JSON format in file {file_path}: {e}") + except Exception as e: + logger.error(f"Unexpected error while loading JSON file 
{file_path}: {e}") + raise ValueError(f"Failed to load JSON file {file_path}: {e}") + + +class DocQADataset(BaseDataset): + def __init__(self, tokenizer_path: str): + super().__init__(tokenizer_path) + + def prepare_data(self, dataset_file_path) -> List[Union[str, str, str]]: + cases_list = [] + case_data = self._load_jsonl_file(dataset_file_path) + for case in case_data: + context = case.get("context") + question = case.get("question") + answer = case.get("answers") + case_name = case.get("dataset") + "_" + case.get("_id") + cases_list.append([case_name, context, question, answer]) + return cases_list + + def _load_jsonl_file(self, file_path: str) -> List[Dict[str, Any]]: + """ + Load a JSONL file containing doc_qa data + :param file_path: Path to the jsonl file + :return: List of doc_qa data + """ + case_data = [] + try: + with open(file_path, "r", encoding="utf-8") as f: + for line in f: + # In doc_qa, one line per sample; each sample contains: question, context, answer, etc. + json_line = json.loads(line) + extracted_data = { + "question": json_line.get("input", None), + "context": json_line.get("context", None), + "answers": json_line.get("answers", None), + "length": json_line.get("length", None), + "dataset": json_line.get("dataset", None), + "language": json_line.get("language", None), + "all_classes": json_line.get("all_classes", None), + "_id": json_line.get("_id", None), + } + case_data.append(extracted_data) + return case_data + except FileNotFoundError: + logger.error(f"JSONL file not found: {file_path}") + raise FileNotFoundError(f"JSONL file not found: {file_path}") + except json.JSONDecodeError as e: + logger.error(f"JSONL decode error in file {file_path}: {e}") + raise ValueError(f"Invalid JSONL format in file {file_path}: {e}") + except Exception as e: + logger.error(f"Unexpected error while loading JSONL file {file_path}: {e}") + raise ValueError(f"Failed to load JSONL file {file_path}: {e}") diff --git a/test/common/uc_eval/utils/metric.py b/test/common/uc_eval/utils/metric.py new file mode 100644 index 00000000..b6b0b800 --- /dev/null +++ b/test/common/uc_eval/utils/metric.py @@ -0,0 +1,223 @@ +import random +import re +import string +from abc import ABC, abstractmethod +from collections import Counter +from pathlib import Path +from typing import Callable, List, Optional, Union + +import jieba +import numpy as np +from common.uc_eval.utils.data_class import MultiTurnDialogRecord, RequestRecord + +stopwords_path = Path(__file__).parent.joinpath("stopwords.txt") +STOPWORDS: List[str] = [ + line.strip() for line in stopwords_path.open("r", encoding="utf-8").readlines() +] + + +def normalize_text(text: str) -> str: + # Remove punctuation (CJK full-width & ASCII) + pattern = r"[\u3000-\u303F\uFF00-\uFFEF" + re.escape(string.punctuation) + "]" + text = re.sub(pattern, "", text) + # Segment with jieba (precise mode) and lowercase + words = jieba.lcut(text) + words = [word.strip().lower() for word in words] + # Drop stop-words + filtered_words = [word for word in words if word not in STOPWORDS and word != ""] + text = " ".join(map(str, filtered_words)) + return text + + +class MetricClass(ABC): + def __init__(self, record_list: List[RequestRecord | MultiTurnDialogRecord]): + self.record_list = record_list + self.ACCURACY_METIRC_FUNCTION_MAP: dict[str, Callable] = { + "accuracy": self.get_accuracy, + "bootstrap-accuracy": self.get_bootstrap_accuracy_std, + "f1-score": self.get_f1_score, + } + + def calculate_metric(self, metric_names: List[str]): + record_match_num = 0 + for 
record in self.record_list: + expected_output, real_output = self.get_normalize_text(record) + + if self.match(expected_output, real_output): + record.is_match = True + record_match_num += 1 + + metric_dict = {} + for metric in metric_names: + metric_function = self.ACCURACY_METIRC_FUNCTION_MAP[metric] + metric_dict[metric] = metric_function(self.record_list) + + return metric_dict + + def get_normalize_text(self, record: Union[RequestRecord, MultiTurnDialogRecord]): + expected_output = record.expected_output + real_output = record.output_data + if isinstance(expected_output, tuple): + expected_output = list(expected_output) + elif not isinstance(expected_output, list): + expected_output = [expected_output] + + expected_output = [normalize_text(output) for output in expected_output] + real_output = normalize_text(real_output) + + return expected_output, real_output + + @abstractmethod + def match( + self, + expected_output: Union[str, List[str], tuple[str]], + real_output: str, + **kwargs + ): + pass + + def get_accuracy( + self, record_list: List[RequestRecord | MultiTurnDialogRecord] + ) -> float: + record_total = len(record_list) + match_num = sum(record.is_match for record in record_list) + return match_num / record_total if record_total != 0 else float("nan") + + def get_bootstrap_accuracy_std( + self, + record_list: List[RequestRecord | MultiTurnDialogRecord], + num_samples: int = 1000, + ): + """ + Compute standard deviation of accuracy using the Bootstrap method. + """ + if not record_list: + return float("nan") + + vals = [record.is_match for record in record_list] + return np.std( + [np.mean(random.sample(vals, len(vals) // 2)) for _ in range(num_samples)] + ).item() + + def get_f1_score( + self, + record_list: List[RequestRecord | MultiTurnDialogRecord], + ): + f1_score = [] + for record in record_list: + expected_output, real_output = self.get_normalize_text(record) + f1_score.append(self._f1_score(expected_output, real_output)) + return np.mean(f1_score).item() + + def _f1_score(self, expected_output: List[str], real_output: str) -> float: + max_f1_score = 0 + for output in expected_output: + common = Counter(output.split()) & Counter(real_output.split()) + num_same = sum(common.values()) + if num_same != 0: + precision = 1.0 * num_same / len(output.split()) + recall = 1.0 * num_same / len(real_output.split()) + f1 = (2 * precision * recall) / (precision + recall) + max_f1_score = max(max_f1_score, f1) + return max_f1_score + + +class Match(MetricClass): + def __init__(self, record_list: List[RequestRecord | MultiTurnDialogRecord]): + super().__init__(record_list) + + def match( + self, + expected_output: List[str], + real_output: str, + separator: Callable[[str], bool] = None, + options: Optional[list[str]] = None, + ) -> bool: + """ + Exact match: expected and picked must be identical + :param expected_output: the answer from dataset + :param real_output: actual output generated by model + :param separator: separator function to prevent partial matches + :param options: optional list of matching options; for multiple-choice questions, options must be present + """ + if options is None: + options = expected_output + + picked = None + for option in options: + if not real_output.startswith(option): + continue + if ( + separator is not None + and len(real_output) > len(options) + and not separator(real_output[len(option)]) + ): + continue + picked = option + break + + match = picked in expected_output + return match + + +class Includes(MetricClass): + def __init__(self, 
record_list: List[RequestRecord | MultiTurnDialogRecord]): + super().__init__(record_list) + + def match( + self, + expected_output: List[str], + real_output: str, + ) -> bool: + """ + Match succeeds if any part expected_output is found in real_output + :param expected_output: the answer from dataset + :param real_output: actual output generated by model + """ + for output in expected_output: + if real_output.rfind(output) != -1: + return True + return False + + +class FuzzyMatch(MetricClass): + def __init__(self, record_list: List[RequestRecord | MultiTurnDialogRecord]): + super().__init__(record_list) + + def match( + self, + expected_output: List[str], + real_output: str, + strategy: str = "substring", + threshold: float = 0.8, + ) -> bool: + """ + Fuzzy matching + :param expected_output: the answer from dataset + :param real_output: actual output generated by model + :param strategy: matching strategy, currently supports substring and jaccard + :param threshold: similarity threshold for jaccard strategy + """ + return any( + self._single_match(expected, real_output, strategy, threshold) + for expected in expected_output + ) + + def _single_match( + self, + expected: str, + real: str, + strategy: str = "substring", + threshold: float = 0.8, + ) -> bool: + if strategy == "substring": + return expected in real or real in expected + else: + set_exp, set_real = set(expected.split()), set(real.split()) + if not set_exp and not set_real: + return True + if not set_exp or not set_real: + return False + inter = len(set_exp & set_real) + union = len(set_exp | set_real) + return (inter / union) >= threshold diff --git a/test/common/uc_eval/utils/utils.py b/test/common/uc_eval/utils/utils.py new file mode 100644 index 00000000..589a0f11 --- /dev/null +++ b/test/common/uc_eval/utils/utils.py @@ -0,0 +1,264 @@ +import logging +import logging.handlers +import math +import os +import sys +import time +from pathlib import Path +from typing import Any, Dict, List, Union + +import pandas as pd +from transformers import AutoConfig, AutoTokenizer + +current_dir = os.path.dirname(os.path.abspath(__file__)) + + +def get_current_time() -> str: + return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())) + + +class PathUtil(object): + + @staticmethod + def get_dirname(file_path: str | Path): + return Path(os.path.dirname(file_path)) + + @staticmethod + def get_root_dir_path() -> Path: + root_path = Path(current_dir).parent.parent + return root_path + + @staticmethod + def get_other_dir_path(other: str) -> Path: + root_path = PathUtil.get_root_dir_path() + other_path = Path.joinpath(root_path, other) + other_path.mkdir(parents=True, exist_ok=True) + return other_path + + @staticmethod + def _default_datasets_path() -> Path: + return PathUtil.get_other_dir_path("UC-Eval-datasets") + + @staticmethod + def get_datasets_dir_path(in_file_path: str) -> Path: + if not in_file_path or in_file_path == "": + return PathUtil._default_datasets_path() + input_path = Path(in_file_path) + if input_path.is_absolute(): + return Path(in_file_path) + else: + return PathUtil.get_other_dir_path(in_file_path) + + +class FileUtil(object): + @staticmethod + def save_excel( + file_path: Path, + data: List[Any], + headers: List[str] = None, + sheet_name: str = "Sheet1", + ): + """ + Write test results to excel, one List[Any] represents one row of data + """ + df = ( + pd.DataFrame(data=data, columns=headers) + if headers + else pd.DataFrame(data=data) + ) + file_path.parent.mkdir(parents=True, exist_ok=True) + if 
file_path.exists(): + with pd.ExcelWriter( + file_path, mode="a", engine="openpyxl", if_sheet_exists="overlay" + ) as writer: + workbook = writer.book + # If the excel and sheet exist, append write + if sheet_name in workbook.sheetnames: + existing_df = pd.read_excel(file_path, sheet_name=sheet_name) + start_now = existing_df.shape[0] + 1 + df.to_excel( + writer, + sheet_name=sheet_name, + index=False, + startrow=start_now, + header=False if start_now > 0 else True, + ) + else: + # If the excel exists but the sheet does not, create a new sheet and write + df.to_excel( + writer, + sheet_name=sheet_name, + index=False, + header=(headers is not None), + ) + else: + # if the excel does not exist, create a new excel and sheet + with pd.ExcelWriter(file_path, mode="w", engine="openpyxl") as writer: + df.to_excel( + writer, + sheet_name=sheet_name, + index=False, + header=(headers is not None), + ) + + +class LoggerHandler(logging.Logger): + def __init__( + self, name: str, level: int = logging.INFO, log_path: str = None + ) -> None: + super().__init__(name, level) + # format of the log message + fmt = "%(asctime)s.%(msecs)03d %(levelname)s [pid:%(process)d] [%(threadName)s] [tid:%(thread)d] [%(filename)s:%(lineno)d %(funcName)s] %(message)s" + data_fmt = "%Y-%m-%d %H:%M:%S" + formatter = logging.Formatter(fmt, data_fmt) + + # using file handler to log to file + if log_path is not None: + file_handler = logging.handlers.RotatingFileHandler( + filename=log_path, + maxBytes=1024 * 1024 * 10, + backupCount=20, + delay=True, + encoding="utf-8", + ) + file_handler.setFormatter(formatter) + file_handler.setLevel(self.level) + self.addHandler(file_handler) + + console_handler = logging.StreamHandler(stream=sys.stdout) + console_handler.setFormatter(formatter) + console_handler.setLevel(self.level) + self.addHandler(console_handler) + + def setLevel(self, level) -> None: + super().setLevel(level) + for handler in self.handlers: + handler.setLevel(level) + + +# the global dictionary to store all the logger instances +_logger_instances: Dict[str, LoggerHandler] = {} + + +def get_logger( + name: str = "evals", level: int = logging.INFO, log_file: str = None +) -> logging.Logger: + if name in _logger_instances: + return _logger_instances[name] + + # create a new logger instance + logger = LoggerHandler(name, level, log_file) + _logger_instances[name] = logger + return logger + + +class ModelMemoryCalculator: + def __init__(self, model_path: Union[Path, str]): + if isinstance(model_path, str): + model_path = PathUtil.get_datasets_dir_path(model_path) + self.config = AutoConfig.from_pretrained(model_path) + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.dtype_bytes_map = {"fp16": 2, "bf16": 2, "fp32": 4, "int8": 1} + + def _get_model_info(self): + """ + Get model architecture information + """ + hidden_size = getattr(self.config, "hidden_size", None) + num_layers = getattr(self.config, "num_hidden_layers", None) + num_attention_heads = getattr(self.config, "num_attention_heads", None) + num_kv_heads = getattr(self.config, "num_key_value_heads", num_attention_heads) + qk_rope_head_dim = getattr(self.config, "qk_rope_head_dim", None) + kv_lora_rank = getattr(self.config, "kv_lora_rank", None) + + head_dim = self._calculate_head_dimension( + hidden_size, num_attention_heads, qk_rope_head_dim, kv_lora_rank + ) + + return { + "hidden_size": hidden_size, + "num_layers": num_layers, + "num_attention_heads": num_attention_heads, + "num_kv_heads": num_kv_heads, + "qk_rope_head_dim": qk_rope_head_dim, 
+ "kv_lora_rank": kv_lora_rank, + "head_dim": head_dim, + "model_type": self.config.model_type, + "element_calculate_type": 1 if qk_rope_head_dim and kv_lora_rank else 0, + } + + def _calculate_head_dimension( + self, hidden_size, num_attention_heads, qk_rope_head_dim, kv_lora_rank + ): + """ + Calculate head dimension + """ + # First, check if both qk_rope_head_dim and kv_lora_rank parameters exist; if so, use these two parameters for calculation. + if qk_rope_head_dim is not None and kv_lora_rank is not None: + return qk_rope_head_dim + kv_lora_rank + + # Then, check if there is a head_dim parameter available and use it if present. + head_dim = getattr(self.config, "head_dim", None) + if head_dim is not None: + return head_dim + + # Next, check if both hidden_size and num_attention_heads parameters exist; if so, use these two parameters for calculation. + if hidden_size is not None and num_attention_heads is not None: + if num_attention_heads == 0: + raise ValueError("num_attention_heads cannot be zero") + return hidden_size // num_attention_heads + + # If none of the above exist, raise an error. + raise ValueError( + "Unable to calculate head dimension with current model configuration. " + "Please check if the model configuration contains required parameters." + ) + + def calculate_kv_cache_memory(self, sequence_length, batch_size=1, dtype="fp16"): + """ + Calculate KV Cache memory usage: + For models like DeepSeek-R1: batch_size * sequence_length * num_hidden_layers * head_dim * bytes_per_element + For models like Qwen3-32B: 2 * batch_size * sequence_length * num_hidden_layers * num_kv_heads * head_dim * bytes_per_element + :param sequence_length: Sequence length (number of tokens) + :param batch_size: Batch size + :param dtype: Data type ('fp16', 'bf16', 'fp32', 'int8') + """ + model_info = self._get_model_info() + + # Check required parameters + required_params = ["num_layers", "head_dim"] + ( + [] if model_info["element_calculate_type"] else ["num_attention_heads"] + ) + for param in required_params: + if model_info[param] is None: + raise ValueError(f"Cannot retrieve {param} from configuration file") + + # Round up any input sequence_length to the nearest multiple of 128 + sequence_length = math.ceil(sequence_length / 128) * 128 + bytes_per_element = self.dtype_bytes_map.get(dtype, 2) + + if model_info["element_calculate_type"]: + total_elements = ( + batch_size + * sequence_length + * model_info["num_layers"] + * model_info["head_dim"] + ) + else: + # Use KV heads count from configuration, if not available use attention heads count + num_kv_heads = ( + model_info["num_kv_heads"] or model_info["num_attention_heads"] + ) + total_elements = ( + batch_size + * sequence_length + * model_info["num_layers"] + * num_kv_heads + * model_info["head_dim"] + * 2 # key + value + ) + + memory_bytes = total_elements * bytes_per_element + memory_gb = memory_bytes / (1024**3) + + return total_elements, round(memory_gb, 4) diff --git a/test/config.yaml b/test/config.yaml index 7ac32f48..3e923945 100644 --- a/test/config.yaml +++ b/test/config.yaml @@ -9,7 +9,7 @@ reports: database: backup: "results/" - enabled: true + enabled: false host: "127.0.0.1" port: 3306 name: "ucm_pytest" @@ -17,6 +17,14 @@ database: password: "123456" charset: "utf8mb4" +models: + ip_ports: "" + tokenizer_path: "" + served_model_name: "" + payload: '{}' + enable_clear_hbm: false + + # LLM Connection Configuration llm_connection: model: "qwen3" @@ -24,4 +32,4 @@ llm_connection: tokenizer_path: "/home/models/QwQ-32B" stream: 
true # stream output ignore_eos: true # Ignore the returned terminator - timeout: 180 # request time out \ No newline at end of file + timeout: 180 # request time out diff --git a/test/suites/E2E/test_evaluator.py b/test/suites/E2E/test_evaluator.py new file mode 100644 index 00000000..69c05cae --- /dev/null +++ b/test/suites/E2E/test_evaluator.py @@ -0,0 +1,44 @@ +import dataclasses + +import pytest +from common.capture_utils import export_vars +from common.config_utils import config_utils as config_instance +from common.uc_eval.task import DocQaEvalTask +from common.uc_eval.utils.data_class import EvalConfig, ModelConfig + + +@pytest.fixture(scope="session") +def model_config() -> ModelConfig: + cfg = config_instance.get_config("models") or {} + field_name = [field.name for field in dataclasses.fields(ModelConfig)] + kwargs = {k: v for k, v in cfg.items() if k in field_name and v is not None} + return ModelConfig(**kwargs) + + +doc_qa_eval_cases = [ + pytest.param( + EvalConfig( + data_type="doc_qa", + dataset_file_path="common/uc_eval/datasets/doc_qa/demo.jsonl", + enable_prefix_cache=False, + parallel_num=1, + benchmark_mode="evaluate", + metrics=["accuracy", "bootstrap-accuracy", "f1-score"], + eval_class="common.uc_eval.utils.metric:Includes", + ), + id="doc-qa-complete-recalculate-evaluate", + ) +] + + +@pytest.mark.feature("eval_test") +@pytest.mark.stage(2) +@pytest.mark.parametrize("eval_config", doc_qa_eval_cases) +@export_vars +def test_doc_qa_perf( + eval_config: EvalConfig, model_config: ModelConfig, request: pytest.FixtureRequest +): + file_save_path = config_instance.get_config("reports").get("base_dir") + task = DocQaEvalTask(model_config, eval_config, file_save_path) + result = task.run() + return {"_name": request.node.callspec.id, "_data": result} diff --git a/test/suites/E2E/test_uc_performance.py b/test/suites/E2E/test_uc_performance.py index dbec0318..6f4d7512 100644 --- a/test/suites/E2E/test_uc_performance.py +++ b/test/suites/E2E/test_uc_performance.py @@ -1,6 +1,15 @@ +import dataclasses + import pytest from common.capture_utils import export_vars +from common.config_utils import config_utils as config_instance from common.llmperf.run_inference import inference_results +from common.uc_eval.task import ( + DocQaPerfTask, + MultiTurnDialogPerfTask, + SyntheticPerfTask, +) +from common.uc_eval.utils.data_class import ModelConfig, PerfConfig @pytest.mark.parametrize("mean_input_tokens", [[2000, 3000]]) @@ -156,3 +165,105 @@ def test_performance( print("\n[INFO] All values are greater than 0. 
Assertion passed!")
     return {"_name": "llmperf", "_data": value_lists}
+
+
+@pytest.fixture(scope="session")
+def model_config() -> ModelConfig:
+    cfg = config_instance.get_config("models") or {}
+    field_name = [field.name for field in dataclasses.fields(ModelConfig)]
+    kwargs = {k: v for k, v in cfg.items() if k in field_name and v is not None}
+    return ModelConfig(**kwargs)
+
+
+sync_perf_cases = [
+    pytest.param(
+        PerfConfig(
+            data_type="synthetic",
+            enable_prefix_cache=False,
+            parallel_num=[1, 4, 8],
+            prompt_tokens=[4000, 8000],
+            output_tokens=[1000, 1000],
+            benchmark_mode="default-perf",
+        ),
+        id="benchmark-complete-recalculate-default-perf",
+    ),
+    pytest.param(
+        PerfConfig(
+            data_type="synthetic",
+            enable_prefix_cache=True,
+            parallel_num=[1, 4, 8],
+            prompt_tokens=[4000, 8000],
+            output_tokens=[1000, 1000],
+            prefix_cache_num=[0.8, 0.8],
+            benchmark_mode="stable-perf",
+        ),
+        id="benchmark-prefix-cache-stable-perf",
+    ),
+]
+
+
+@pytest.mark.feature("perf_test")
+@pytest.mark.stage(2)
+@pytest.mark.parametrize("perf_config", sync_perf_cases)
+@export_vars
+def test_sync_perf(
+    perf_config: PerfConfig, model_config: ModelConfig, request: pytest.FixtureRequest
+):
+    file_save_path = config_instance.get_config("reports").get("base_dir")
+    task = SyntheticPerfTask(model_config, perf_config, file_save_path)
+    result = task.run()
+    return {"_name": request.node.callspec.id, "_data": result}
+
+
+multiturn_dialogue_perf_cases = [
+    pytest.param(
+        PerfConfig(
+            data_type="multi_turn_dialogue",
+            dataset_file_path="common/uc_eval/datasets/multi_turn_dialogues/multiturndialog.json",
+            enable_prefix_cache=False,
+            parallel_num=1,
+            benchmark_mode="default-perf",
+        ),
+        id="multiturn-dialogue-complete-recalculate-default-perf",
+    )
+]
+
+
+@pytest.mark.feature("perf_test")
+@pytest.mark.stage(2)
+@pytest.mark.parametrize("perf_config", multiturn_dialogue_perf_cases)
+@export_vars
+def test_multiturn_dialogue_perf(
+    perf_config: PerfConfig, model_config: ModelConfig, request: pytest.FixtureRequest
+):
+    file_save_path = config_instance.get_config("reports").get("base_dir")
+    task = MultiTurnDialogPerfTask(model_config, perf_config, file_save_path)
+    result = task.run()
+    return {"_name": request.node.callspec.id, "_data": result}
+
+
+doc_qa_perf_cases = [
+    pytest.param(
+        PerfConfig(
+            data_type="doc_qa",
+            dataset_file_path="common/uc_eval/datasets/doc_qa/demo.jsonl",
+            enable_prefix_cache=False,
+            parallel_num=1,
+            benchmark_mode="default-perf",
+        ),
+        id="doc-qa-complete-recalculate-default-perf",
+    )
+]
+
+
+@pytest.mark.feature("perf_test")
+@pytest.mark.stage(2)
+@pytest.mark.parametrize("perf_config", doc_qa_perf_cases)
+@export_vars
+def test_doc_qa_perf(
+    perf_config: PerfConfig, model_config: ModelConfig, request: pytest.FixtureRequest
+):
+    file_save_path = config_instance.get_config("reports").get("base_dir")
+    task = DocQaPerfTask(model_config, perf_config, file_save_path)
+    result = task.run()
+    return {"_name": request.node.callspec.id, "_data": result}
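
Note (illustrative sketch, not part of the patch): the KV-cache sizing rule documented in ModelMemoryCalculator.calculate_kv_cache_memory is easiest to sanity-check with concrete numbers. The snippet below mirrors the GQA branch of that formula (2 * batch * seq_len * num_layers * num_kv_heads * head_dim * bytes_per_element) using hypothetical config values -- 64 layers, 8 KV heads, head_dim 128, fp16 -- that are chosen only to make the arithmetic easy to follow and are not taken from this diff.

import math

def kv_cache_bytes(
    seq_len: int,
    batch: int = 1,
    num_layers: int = 64,      # hypothetical value
    num_kv_heads: int = 8,     # hypothetical value
    head_dim: int = 128,       # hypothetical value
    bytes_per_element: int = 2,  # fp16 / bf16
) -> int:
    # Round the sequence length up to a multiple of 128, as the utility does.
    seq_len = math.ceil(seq_len / 128) * 128
    # key + value tensors across all layers and KV heads
    elements = 2 * batch * seq_len * num_layers * num_kv_heads * head_dim
    return elements * bytes_per_element

if __name__ == "__main__":
    # 4096 tokens on this hypothetical config -> exactly 1 GiB of KV cache.
    print(kv_cache_bytes(4096) / 1024**3)  # 1.0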