Commit 598d018

[Feature] Add performance and evaluation testing tools using the pytest framework (#462)
Add eval and performance test
1 parent 24e6bfa commit 598d018

File tree

11 files changed: +2407 -0 lines changed


test/common/uc_eval/task.py

Lines changed: 409 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 276 additions & 0 deletions
@@ -0,0 +1,276 @@
import functools
import importlib
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import numpy as np
from common.uc_eval.utils.data_class import (
    EvalConfig,
    LatencyStatistics,
    MultiTurnDialogRecord,
    RequestRecord,
)
from common.uc_eval.utils.utils import get_logger
from tqdm import tqdm

logger = get_logger()
MS_SCALE = 1000
# the max wave rate for stable perf
MAX_WAVE_RATE = 0.05

def make_object(object_ref: str, *args: Any, **kwargs: Any) -> Any:
    """Create an object factory from a reference of the form "module:Class.attr"."""
    modname, qualname_separator, qualname = object_ref.partition(":")
    obj = importlib.import_module(modname)
    if qualname_separator:
        for attr in qualname.split("."):
            obj = getattr(obj, attr)
    return functools.partial(obj, *args, **kwargs)

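For reference, a minimal usage sketch of make_object; the target class here is an arbitrary stdlib example, not something the test suite actually uses:

# Resolve "collections:OrderedDict" and build an instance through the returned factory.
factory = make_object("collections:OrderedDict")
od = factory()  # equivalent to collections.OrderedDict()
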
class BenchmarkBase(ABC):
    def __init__(self, eval_config: Optional[EvalConfig], stable_perf: bool = False):
        self.eval_config = eval_config
        self.stable_perf = stable_perf

    def get_success_request(self, data: List[RequestRecord | MultiTurnDialogRecord]):
        """
        Get the successful requests from the records
        """
        success_request = []
        for request in data:
            if request.is_success:
                success_request.append(request)
        if len(success_request) == 0:
            logger.warning("No successful request found, please check the result")
        return success_request

    def result_to_column_dict(
        self, data: List[RequestRecord | MultiTurnDialogRecord]
    ) -> Dict[str, List[Any]]:
        """
        Convert a list of records into columns: list[dict] ---> dict[list]
        """
        if not data:
            return {}
        keys = list(data[0].to_dict().keys())
        result = {key: [] for key in keys}
        for item in data:
            item_dict = item.to_dict()
            for key in keys:
                result[key].append(item_dict[key])
        return result

    @abstractmethod
    def perf_show(self, records: Any, parallel_num: int = 1):
        raise NotImplementedError

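As an illustration of the column transform above, a small self-contained sketch with a stand-in record type (RecordStub is hypothetical and only mimics the to_dict() interface):

from dataclasses import asdict, dataclass

@dataclass
class RecordStub:
    request_id: str
    req_cost: float

    def to_dict(self):
        return asdict(self)

rows = [RecordStub("req-1", 1.2), RecordStub("req-2", 0.8)]
# result_to_column_dict(rows) would yield:
# {"request_id": ["req-1", "req-2"], "req_cost": [1.2, 0.8]}
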
class EvaluatorBenchmark(BenchmarkBase):
    def __init__(self, eval_config: EvalConfig):
        super().__init__(eval_config=eval_config)
        self.metric_method = eval_config.metrics
        self.eval_class = eval_config.eval_class

    def perf_show(
        self,
        record_list: List[RequestRecord | MultiTurnDialogRecord],
        parallel_num: int,
    ):
        logger.info("Begin calculating metrics...")
        success_request = self.get_success_request(record_list)
        eval_cls = make_object(self.eval_class)(success_request)
        latency = LatencyStatistics()
        metric_result = eval_cls.calculate_metric(self.metric_method)
        latency.metric_dict = metric_result
        match_record_list = eval_cls.record_list

        return latency, match_record_list

class PerformanceBenchmark(BenchmarkBase):
    def __init__(self, stable_perf: bool):
        # pass stable_perf by keyword so it is not mistaken for eval_config
        super().__init__(eval_config=None, stable_perf=stable_perf)
        self.stable_perf = stable_perf
        self.stable_work_time = [0, 0]

    def perf_show(
        self,
        record_list: List[RequestRecord | MultiTurnDialogRecord],
        parallel_num: int,
    ) -> LatencyStatistics:
        logger.info("Begin calculating latency...")
        success_request = self.get_success_request(record_list)
        request_record_dict = self.result_to_column_dict(success_request)
        if self.stable_perf:
            request_ids = self._get_stable_request_id(request_record_dict, parallel_num)
        else:
            request_ids = request_record_dict.get("request_id", [])
        records = [record for record in record_list if record.request_id in request_ids]
        perf_result = self._get_performance_data(records)
        return perf_result

    def _get_performance_data(
        self, record_list: List[RequestRecord | MultiTurnDialogRecord]
    ) -> LatencyStatistics:
        """
        After all requests are completed, aggregate the performance data
        """
        latency = LatencyStatistics()
        if len(record_list) == 0:
            logger.warning("There is no request in the record_list, please check")
            return latency
        record_dict = self.result_to_column_dict(record_list)

        e2e_latency_all = (
            max(record_dict["end_time"]) - min(record_dict["start_time"])
        ) * MS_SCALE
        latency.e2e_latency_all = round(e2e_latency_all, 2)
        logger.debug("All request latencies: %.4f ms", e2e_latency_all)

        total_output_tokens = sum(record_dict["output_tokens"])
        output_token_throughput = total_output_tokens / e2e_latency_all * MS_SCALE
        latency.output_token_throughput = round(output_token_throughput, 2)
        logger.debug(
            "Total output token throughput: %.4f tokens/s", output_token_throughput
        )

        throughputs = []
        for tokens, cost in zip(record_dict["output_tokens"], record_dict["req_cost"]):
            if cost > 0:
                throughputs.append(tokens / cost)
        if throughputs:
            token_throughput_per_request = np.mean(throughputs).item()
            latency.token_throughput_per_request = round(
                token_throughput_per_request, 2
            )
            logger.debug(
                "Average per-request throughput: %.4f tokens/s",
                token_throughput_per_request,
            )
        else:
            logger.warning("No valid requests for throughput calculation")

        prefill_latency_list = record_dict["prefill_latency"]
        p50_prefill_latency = np.percentile(prefill_latency_list, 50).item() * MS_SCALE
        latency.p50_prefill_latency = round(p50_prefill_latency, 2)
        logger.debug("Time to first token latency P50: %.4f ms", p50_prefill_latency)

        p90_prefill_latency = np.percentile(prefill_latency_list, 90).item() * MS_SCALE
        latency.p90_prefill_latency = round(p90_prefill_latency, 2)
        logger.debug("Time to first token latency P90: %.4f ms", p90_prefill_latency)

        p99_prefill_latency = np.percentile(prefill_latency_list, 99).item() * MS_SCALE
        latency.p99_prefill_latency = round(p99_prefill_latency, 2)
        logger.debug("Time to first token latency P99: %.4f ms", p99_prefill_latency)

        max_prefill_latency = np.max(prefill_latency_list).item() * MS_SCALE
        latency.max_prefill_latency = round(max_prefill_latency, 2)
        logger.debug(
            "Maximum time to first token latency: %.4f ms", max_prefill_latency
        )

        avg_prefill_latency = np.mean(prefill_latency_list).item() * MS_SCALE
        latency.avg_prefill_latency = round(avg_prefill_latency, 2)
        logger.debug(
            "Average time to first token latency: %.4f ms", avg_prefill_latency
        )

        decode_latency_list = list(record_dict["tbt_latency"])

        p50_decode_latency = np.percentile(decode_latency_list, 50).item() * MS_SCALE
        latency.p50_decode_latency = round(p50_decode_latency, 2)
        logger.debug("Time between tokens latency P50: %.4f ms", p50_decode_latency)

        p90_decode_latency = np.percentile(decode_latency_list, 90).item() * MS_SCALE
        latency.p90_decode_latency = round(p90_decode_latency, 2)
        logger.debug("Time between tokens latency P90: %.4f ms", p90_decode_latency)

        p99_decode_latency = np.percentile(decode_latency_list, 99).item() * MS_SCALE
        latency.p99_decode_latency = round(p99_decode_latency, 2)
        logger.debug("Time between tokens latency P99: %.4f ms", p99_decode_latency)

        max_decode_latency = np.max(decode_latency_list).item() * MS_SCALE
        latency.max_decode_latency = round(max_decode_latency, 2)
        logger.debug("Maximum time between tokens latency: %.4f ms", max_decode_latency)

        avg_decode_latency = np.mean(decode_latency_list).item() * MS_SCALE
        latency.avg_decode_latency = round(avg_decode_latency, 2)
        logger.debug("Average time between tokens latency: %.4f ms", avg_decode_latency)

        return latency

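To make the aggregation above concrete, a small standalone sketch of the same percentile calculation with made-up latencies (the values are illustrative only):

import numpy as np

MS_SCALE = 1000
# Hypothetical per-request prefill latencies, in seconds.
prefill_latency_list = [0.031, 0.040, 0.045, 0.052]

p50 = np.percentile(prefill_latency_list, 50).item() * MS_SCALE  # 42.5 ms
p99 = np.percentile(prefill_latency_list, 99).item() * MS_SCALE  # ~51.8 ms
print(round(p50, 2), round(p99, 2))
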
    def _get_stable_request_id(
        self, result: Dict[str, List[Any]], target_concurrency: int
    ):
        """
        Identify the request ids issued while concurrency holds at the steady-state
        level, by replaying request start/end events in timestamp order
        """
        # the number of concurrent requests at each request start and end
        request_num = len(result.get("request_id", []))
        concurrent_levels = [0] * 2 * request_num
        request_events = []
        for idx in range(request_num):
            request_events.append(
                {
                    "request_id": result.get("request_id", [])[idx],
                    "event_type": "start",
                    "timestamp": result.get("start_time", [])[idx],
                }
            )
            request_events.append(
                {
                    "request_id": result.get("request_id", [])[idx],
                    "event_type": "end",
                    "timestamp": result.get("end_time", [])[idx],
                }
            )
        sorted_events = sorted(request_events, key=lambda x: x["timestamp"])
        stable_stage_requests = []
        logger.info("Start calculating stable request id")
        used_request_num = 0
        for idx, item in enumerate(
            tqdm(sorted_events, desc="search stable request id")
        ):
            # update the instantaneous concurrency level at this event
            if item["event_type"] == "start":
                used_request_num += 1
                concurrent_levels[idx] = (
                    concurrent_levels[idx - 1] + 1 if idx > 0 else 1
                )
            else:
                concurrent_levels[idx] = concurrent_levels[idx - 1] - 1
            if (
                item["event_type"] == "start"
                and concurrent_levels[idx] == target_concurrency
            ):
                # request started exactly at the target concurrency
                stable_stage_requests.append(item["request_id"])
                if len(stable_stage_requests) == 2:
                    self.stable_work_time[0] = item["timestamp"]
            elif (
                item["event_type"] == "start"
                and concurrent_levels[idx]
                >= int(target_concurrency * (1 - MAX_WAVE_RATE))
                and len(stable_stage_requests) > 2
            ):
                # concurrency is still within the allowed wave rate: keep collecting
                stable_stage_requests.append(item["request_id"])
            elif used_request_num == request_num and item["event_type"] == "end":
                # all requests have started: the stable window ends here
                self.stable_work_time[1] = item["timestamp"]
                break
            elif (
                len(stable_stage_requests) > 1
                and item["event_type"] == "end"
                and concurrent_levels[idx]
                < int(target_concurrency * (1 - MAX_WAVE_RATE))
            ):
                # concurrency dropped below the allowed wave rate: stable window is over
                self.stable_work_time[1] = item["timestamp"]
                break

        if len(stable_stage_requests) > 1:
            # ignore the first request
            stable_stage_requests.pop(0)
        if len(stable_stage_requests) == 0:
            logger.error("cannot find stable stage, please check your settings")
            raise ValueError("cannot find stable stage, please check your settings")
        logger.info(f"stable request id list: {stable_stage_requests}")
        return stable_stage_requests
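
For intuition, a minimal standalone sketch of the same event-replay idea with fabricated timestamps (none of these values come from the diff): replaying start/end events in time order yields the instantaneous concurrency, and requests that start while the level holds at the target fall inside the stable window.

# Four hypothetical requests, (start, end) in seconds, with a target concurrency of 2.
events = []
for rid, (start, end) in enumerate([(0.0, 2.0), (0.1, 3.0), (2.1, 4.0), (3.1, 5.0)]):
    events.append((start, "start", rid))
    events.append((end, "end", rid))

level = 0
for ts, kind, rid in sorted(events):
    level += 1 if kind == "start" else -1
    print(f"t={ts:.1f}s {kind:>5} req={rid} concurrency={level}")
# Requests 1, 2, and 3 start while concurrency == 2, so they belong to the stable stage.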
