
Commit 3e5e5ad

Fix wrong var names and data access for client, remove unnecessary try/excepts in api.py, update client tests accordingly
1 parent 8e2f8e6 commit 3e5e5ad

6 files changed: 108 additions, 181 deletions


tests/vec_inf/client/test_api.py

Lines changed: 3 additions & 3 deletions
@@ -113,10 +113,10 @@ def test_wait_until_ready():
     with patch.object(VecInfClient, "get_status") as mock_status:
         # First call returns LAUNCHING, second call returns READY
         status1 = MagicMock()
-        status1.status = ModelStatus.LAUNCHING
+        status1.server_status = ModelStatus.LAUNCHING

         status2 = MagicMock()
-        status2.status = ModelStatus.READY
+        status2.server_status = ModelStatus.READY
         status2.base_url = "http://gpu123:8080/v1"

         mock_status.side_effect = [status1, status2]
@@ -125,6 +125,6 @@ def test_wait_until_ready():
         client = VecInfClient()
         result = client.wait_until_ready("12345678", timeout_seconds=5)

-        assert result.status == ModelStatus.READY
+        assert result.server_status == ModelStatus.READY
         assert result.base_url == "http://gpu123:8080/v1"
         assert mock_status.call_count == 2
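The renamed attribute (server_status instead of status) is what callers should now read off a status response. A minimal usage sketch mirroring this test, assuming VecInfClient and ModelStatus are importable from vec_inf.client (the exact public import path is not shown in this diff; the job id and timeout are illustrative):

    from vec_inf.client import VecInfClient, ModelStatus  # import path assumed

    client = VecInfClient()
    # Per the test above, wait_until_ready polls get_status until the job
    # reports READY (or the timeout elapses) and returns the final status.
    status = client.wait_until_ready("12345678", timeout_seconds=300)
    if status.server_status == ModelStatus.READY:
        print(f"Server ready at {status.base_url}")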

tests/vec_inf/client/test_models.py

Lines changed: 3 additions & 3 deletions
@@ -10,14 +10,14 @@ def test_model_info_creation():
         family="test-family",
         variant="test-variant",
         type=ModelType.LLM,
-        config={"num_gpus": 1},
+        config={"gpus_per_node": 1},
     )

     assert model.name == "test-model"
     assert model.family == "test-family"
     assert model.variant == "test-variant"
     assert model.type == ModelType.LLM
-    assert model.config["num_gpus"] == 1
+    assert model.config["gpus_per_node"] == 1


 def test_model_info_optional_fields():
@@ -40,7 +40,7 @@ def test_launch_options_default_values():
     """Test LaunchOptions with default values."""
     options = LaunchOptions()

-    assert options.num_gpus is None
+    assert options.gpus_per_node is None
     assert options.partition is None
     assert options.data_type is None
     assert options.num_nodes is None

vec_inf/client/_helper.py

Lines changed: 53 additions & 61 deletions
@@ -74,9 +74,7 @@ def _get_model_configuration(self) -> ModelConfig:
         )

         if not model_weights_parent_dir:
-            raise ValueError(
-                f"Could not determine model_weights_parent_dir and '{self.model_name}' not found in configuration"
-            )
+            raise ModelNotFoundError("Could not determine model weights parent directory")

         model_weights_path = Path(model_weights_parent_dir, self.model_name)

@@ -266,53 +264,47 @@ def _get_base_status_data(self) -> StatusResponse:
     def _check_model_health(self) -> None:
         """Check model health and update status accordingly."""
         status, status_code = utils.model_health_check(
-            cast(str, self.status_info["model_name"]), self.slurm_job_id, self.log_dir
+            self.status_info.model_name, self.slurm_job_id, self.log_dir
         )
         if status == ModelStatus.READY:
-            self.status_info["base_url"] = utils.get_base_url(
-                cast(str, self.status_info["model_name"]),
+            self.status_info.base_url = utils.get_base_url(
+                self.status_info.model_name,
                 self.slurm_job_id,
                 self.log_dir,
             )
-            self.status_info["server_status"] = status
+            self.status_info.server_status = status
         else:
-            self.status_info["server_status"], self.status_info["failed_reason"] = (
-                status,
-                cast(str, status_code),
-            )
+            self.status_info.server_status = status
+            self.status_info.failed_reason = cast(str, status_code)

     def _process_running_state(self) -> None:
         """Process RUNNING job state and check server status."""
         server_status = utils.is_server_running(
-            cast(str, self.status_info["model_name"]), self.slurm_job_id, self.log_dir
+            self.status_info.model_name, self.slurm_job_id, self.log_dir
         )

         if isinstance(server_status, tuple):
-            self.status_info["server_status"], self.status_info["failed_reason"] = (
-                server_status
-            )
+            self.status_info.server_status, self.status_info.failed_reason = server_status
             return

         if server_status == "RUNNING":
             self._check_model_health()
         else:
-            self.status_info["server_status"] = server_status
+            self.status_info.server_status = server_status

     def _process_pending_state(self) -> None:
         """Process PENDING job state."""
         try:
-            self.status_info["pending_reason"] = self.output.split(" ")[10].split("=")[
-                1
-            ]
-            self.status_info["server_status"] = ModelStatus.PENDING
+            self.status_info.pending_reason = self.output.split(" ")[10].split("=")[1]
+            self.status_info.server_status = ModelStatus.PENDING
         except IndexError:
-            self.status_info["pending_reason"] = "Unknown pending reason"
+            self.status_info.pending_reason = "Unknown pending reason"

     def process_model_status(self) -> StatusResponse:
         """Process different job states and update status information."""
-        if self.status_info["job_state"] == ModelStatus.PENDING:
+        if self.status_info.job_state == ModelStatus.PENDING:
             self._process_pending_state()
-        elif self.status_info["job_state"] == "RUNNING":
+        elif self.status_info.job_state == "RUNNING":
             self._process_running_state()

         return self.status_info
@@ -360,7 +352,7 @@ def _build_metrics_url(self) -> str:
     def _check_prefix_caching(self) -> bool:
         """Check if prefix caching is enabled."""
         job_json = utils.read_slurm_log(
-            cast(str, self.status_info["model_name"]),
+            self.status_info.model_name,
             self.slurm_job_id,
             "json",
             self.log_dir,
@@ -369,6 +361,43 @@ def _check_prefix_caching(self) -> bool:
             return False
         return bool(cast(dict[str, str], job_json).get("enable_prefix_caching", False))

+    def _parse_metrics(self, metrics_text: str) -> dict[str, float]:
+        """Parse metrics with latency count and sum."""
+        key_metrics = {
+            "vllm:prompt_tokens_total": "total_prompt_tokens",
+            "vllm:generation_tokens_total": "total_generation_tokens",
+            "vllm:e2e_request_latency_seconds_sum": "request_latency_sum",
+            "vllm:e2e_request_latency_seconds_count": "request_latency_count",
+            "vllm:request_queue_time_seconds_sum": "queue_time_sum",
+            "vllm:request_success_total": "successful_requests_total",
+            "vllm:num_requests_running": "requests_running",
+            "vllm:num_requests_waiting": "requests_waiting",
+            "vllm:num_requests_swapped": "requests_swapped",
+            "vllm:gpu_cache_usage_perc": "gpu_cache_usage",
+            "vllm:cpu_cache_usage_perc": "cpu_cache_usage",
+        }
+
+        if self.enabled_prefix_caching:
+            key_metrics["vllm:gpu_prefix_cache_hit_rate"] = "gpu_prefix_cache_hit_rate"
+            key_metrics["vllm:cpu_prefix_cache_hit_rate"] = "cpu_prefix_cache_hit_rate"
+
+        parsed: dict[str, float] = {}
+        for line in metrics_text.split("\n"):
+            if line.startswith("#") or not line.strip():
+                continue
+
+            parts = line.split()
+            if len(parts) < 2:
+                continue
+
+            metric_name = parts[0].split("{")[0]
+            if metric_name in key_metrics:
+                try:
+                    parsed[key_metrics[metric_name]] = float(parts[1])
+                except (ValueError, IndexError):
+                    continue
+        return parsed
+
     def fetch_metrics(self) -> Union[dict[str, float], str]:
         """Fetch metrics from the endpoint."""
         try:
@@ -443,43 +472,6 @@ def fetch_metrics(self) -> Union[dict[str, float], str]:
         except requests.RequestException as e:
             return f"Metrics request failed, `metrics` endpoint might not be ready yet: {str(e)}"

-    def _parse_metrics(self, metrics_text: str) -> dict[str, float]:
-        """Parse metrics with latency count and sum."""
-        key_metrics = {
-            "vllm:prompt_tokens_total": "total_prompt_tokens",
-            "vllm:generation_tokens_total": "total_generation_tokens",
-            "vllm:e2e_request_latency_seconds_sum": "request_latency_sum",
-            "vllm:e2e_request_latency_seconds_count": "request_latency_count",
-            "vllm:request_queue_time_seconds_sum": "queue_time_sum",
-            "vllm:request_success_total": "successful_requests_total",
-            "vllm:num_requests_running": "requests_running",
-            "vllm:num_requests_waiting": "requests_waiting",
-            "vllm:num_requests_swapped": "requests_swapped",
-            "vllm:gpu_cache_usage_perc": "gpu_cache_usage",
-            "vllm:cpu_cache_usage_perc": "cpu_cache_usage",
-        }
-
-        if self.enabled_prefix_caching:
-            key_metrics["vllm:gpu_prefix_cache_hit_rate"] = "gpu_prefix_cache_hit_rate"
-            key_metrics["vllm:cpu_prefix_cache_hit_rate"] = "cpu_prefix_cache_hit_rate"
-
-        parsed: dict[str, float] = {}
-        for line in metrics_text.split("\n"):
-            if line.startswith("#") or not line.strip():
-                continue
-
-            parts = line.split()
-            if len(parts) < 2:
-                continue
-
-            metric_name = parts[0].split("{")[0]
-            if metric_name in key_metrics:
-                try:
-                    parsed[key_metrics[metric_name]] = float(parts[1])
-                except (ValueError, IndexError):
-                    continue
-        return parsed
-

 class ModelRegistry:
     """Class for handling model listing and configuration management."""

vec_inf/client/_models.py

Lines changed: 2 additions & 2 deletions
@@ -78,7 +78,7 @@ class LaunchOptions:
     max_num_batched_tokens: Optional[int] = None
     partition: Optional[str] = None
     num_nodes: Optional[int] = None
-    num_gpus: Optional[int] = None
+    gpus_per_node: Optional[int] = None
     qos: Optional[str] = None
     time: Optional[str] = None
     vocab_size: Optional[int] = None
@@ -104,7 +104,7 @@ class LaunchOptionsDict(TypedDict):
     max_num_batched_tokens: NotRequired[Optional[int]]
     partition: NotRequired[Optional[str]]
     num_nodes: NotRequired[Optional[int]]
-    num_gpus: NotRequired[Optional[int]]
+    gpus_per_node: NotRequired[Optional[int]]
     qos: NotRequired[Optional[str]]
     time: NotRequired[Optional[str]]
     vocab_size: NotRequired[Optional[int]]
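With the rename applied to both LaunchOptions and LaunchOptionsDict, launch options now select GPUs per node alongside num_nodes rather than a single total GPU count. A small sketch, assuming LaunchOptions is a keyword-constructible dataclass exported for client use (the import path and partition value are illustrative):

    from vec_inf.client import LaunchOptions  # import path assumed

    # Every field defaults to None, so only overrides need to be supplied.
    options = LaunchOptions(gpus_per_node=4, num_nodes=1, partition="a40")
    assert options.gpus_per_node == 4
    assert options.qos is None  # untouched fields keep their None default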

vec_inf/client/_utils.py

Lines changed: 4 additions & 8 deletions
@@ -3,6 +3,7 @@
 import json
 import os
 import subprocess
+import warnings
 from pathlib import Path
 from typing import Any, Optional, Union, cast

@@ -151,8 +152,9 @@ def load_config() -> list[ModelConfig]:
             else:
                 config.setdefault("models", {})[name] = data
     else:
-        print(
-            f"WARNING: Could not find user config: {user_path}, revert to default config located at {default_path}"
+        warnings.warn(
+            f"WARNING: Could not find user config: {user_path}, revert to default config located at {default_path}", UserWarning,
+            stacklevel=2
         )

     return [
@@ -161,12 +163,6 @@ def load_config() -> list[ModelConfig]:
     ]


-def shutdown_model(slurm_job_id: int) -> None:
-    """Shutdown a running model on the cluster."""
-    shutdown_cmd = f"scancel {slurm_job_id}"
-    run_bash_command(shutdown_cmd)
-
-
 def parse_launch_output(output: str) -> tuple[str, dict[str, str]]:
     """Parse output from model launch command.

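Since the missing-user-config message now goes through the warnings machinery instead of print, callers can filter or escalate it with the standard library. A short sketch against the module path shown above (choosing the "error" filter here is just an example):

    import warnings

    from vec_inf.client._utils import load_config  # private module, path as in this diff

    # Escalate the config-fallback UserWarning to an exception for strict runs;
    # use "ignore" instead to silence it. Filters are restored when the block exits.
    with warnings.catch_warnings():
        warnings.simplefilter("error", UserWarning)
        configs = load_config()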