Commit 8e2f8e6

Refactoring client for CLI use
1 parent b392bb0 commit 8e2f8e6

3 files changed: 39 additions, 58 deletions

3 files changed

+39
-58
lines changed

vec_inf/client/_helper.py

Lines changed: 30 additions & 37 deletions

@@ -244,42 +244,26 @@ def _get_raw_status_output(self) -> str:
             raise SlurmJobError(f"Error: {stderr}")
         return output
 
-    def _get_base_status_data(self) -> dict[str, Union[str, None]]:
+    def _get_base_status_data(self) -> StatusResponse:
         """Extract basic job status information from scontrol output."""
         try:
             job_name = self.output.split(" ")[1].split("=")[1]
             job_state = self.output.split(" ")[9].split("=")[1]
         except IndexError:
-            job_name = ModelStatus.UNAVAILABLE
+            job_name = "UNAVAILABLE"
             job_state = ModelStatus.UNAVAILABLE
 
-        return {
-            "model_name": job_name,
-            "status": ModelStatus.UNAVAILABLE,
-            "base_url": ModelStatus.UNAVAILABLE,
-            "state": job_state,
-            "pending_reason": None,
-            "failed_reason": None,
-        }
-
-    def process_model_status(self) -> StatusResponse:
-        """Process different job states and update status information."""
-        if self.status_info["state"] == ModelStatus.PENDING:
-            self.process_pending_state()
-        elif self.status_info["state"] == "RUNNING":
-            self.process_running_state()
-
         return StatusResponse(
-            slurm_job_id=self.slurm_job_id,
-            model_name=cast(str, self.status_info["model_name"]),
-            status=cast(ModelStatus, self.status_info["status"]),
+            model_name=job_name,
+            server_status=ModelStatus.UNAVAILABLE,
+            job_state=job_state,
             raw_output=self.output,
-            base_url=self.status_info["base_url"],
-            pending_reason=self.status_info["pending_reason"],
-            failed_reason=self.status_info["failed_reason"],
+            base_url="UNAVAILABLE",
+            pending_reason=None,
+            failed_reason=None,
         )
 
-    def check_model_health(self) -> None:
+    def _check_model_health(self) -> None:
         """Check model health and update status accordingly."""
         status, status_code = utils.model_health_check(
             cast(str, self.status_info["model_name"]), self.slurm_job_id, self.log_dir
@@ -290,40 +274,49 @@ def check_model_health(self) -> None:
                 self.slurm_job_id,
                 self.log_dir,
             )
-            self.status_info["status"] = status
+            self.status_info["server_status"] = status
         else:
-            self.status_info["status"], self.status_info["failed_reason"] = (
+            self.status_info["server_status"], self.status_info["failed_reason"] = (
                 status,
                 cast(str, status_code),
             )
 
-    def process_running_state(self) -> None:
+    def _process_running_state(self) -> None:
         """Process RUNNING job state and check server status."""
         server_status = utils.is_server_running(
             cast(str, self.status_info["model_name"]), self.slurm_job_id, self.log_dir
         )
 
         if isinstance(server_status, tuple):
-            self.status_info["status"], self.status_info["failed_reason"] = (
+            self.status_info["server_status"], self.status_info["failed_reason"] = (
                 server_status
             )
             return
 
         if server_status == "RUNNING":
-            self.check_model_health()
+            self._check_model_health()
         else:
-            self.status_info["status"] = server_status
+            self.status_info["server_status"] = server_status
 
-    def process_pending_state(self) -> None:
+    def _process_pending_state(self) -> None:
         """Process PENDING job state."""
         try:
             self.status_info["pending_reason"] = self.output.split(" ")[10].split("=")[
                 1
             ]
-            self.status_info["status"] = ModelStatus.PENDING
+            self.status_info["server_status"] = ModelStatus.PENDING
         except IndexError:
             self.status_info["pending_reason"] = "Unknown pending reason"
 
+    def process_model_status(self) -> StatusResponse:
+        """Process different job states and update status information."""
+        if self.status_info["job_state"] == ModelStatus.PENDING:
+            self._process_pending_state()
+        elif self.status_info["job_state"] == "RUNNING":
+            self._process_running_state()
+
+        return self.status_info
+
 
 class PerformanceMetricsCollector:
     """Class for handling metrics collection and processing."""
@@ -340,18 +333,18 @@ def __init__(self, slurm_job_id: int, log_dir: Optional[str] = None):
         self._last_updated: Optional[float] = None
         self._last_throughputs = {"prompt": 0.0, "generation": 0.0}
 
-    def _get_status_info(self) -> dict[str, Union[str, None]]:
+    def _get_status_info(self) -> StatusResponse:
         """Retrieve status info using existing StatusHelper."""
         status_helper = ModelStatusMonitor(self.slurm_job_id, self.log_dir)
-        return status_helper.status_info
+        return status_helper.process_model_status()
 
     def _build_metrics_url(self) -> str:
         """Construct metrics endpoint URL from base URL with version stripping."""
-        if self.status_info.get("state") == "PENDING":
+        if self.status_info.job_state == ModelStatus.PENDING:
            return "Pending resources for server initialization"
 
        base_url = utils.get_base_url(
-            cast(str, self.status_info["model_name"]),
+            self.status_info.model_name,
            self.slurm_job_id,
            self.log_dir,
        )
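
The caller-facing effect of this refactor is that status info moves from dict lookups to attribute access on StatusResponse, as the updated _build_metrics_url above shows. A brief illustrative sketch (not from the commit; describe() is a hypothetical helper, and the ModelStatus import location is assumed):

# Hypothetical helper, not in the repo; ModelStatus import location assumed.
from vec_inf.client._models import ModelStatus, StatusResponse

def describe(info: StatusResponse) -> str:
    """Summarize a status response using the new attribute-based fields."""
    if info.job_state == ModelStatus.PENDING:  # was: info.get("state") == "PENDING"
        return f"{info.model_name}: pending ({info.pending_reason or 'unknown reason'})"
    return f"{info.model_name}: {info.server_status} at {info.base_url or 'UNAVAILABLE'}"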

vec_inf/client/_models.py

Lines changed: 4 additions & 5 deletions

@@ -7,7 +7,7 @@
 
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Optional, TypedDict
+from typing import Any, Optional, TypedDict, Union
 
 from typing_extensions import NotRequired
 
@@ -46,9 +46,9 @@ class LaunchResponse:
 class StatusResponse:
     """Response from checking a model's status."""
 
-    slurm_job_id: int
     model_name: str
-    status: ModelStatus
+    server_status: ModelStatus
+    job_state: Union[str, ModelStatus]
     raw_output: str = field(repr=False)
     base_url: Optional[str] = None
     pending_reason: Optional[str] = None
@@ -59,9 +59,8 @@ class StatusResponse:
 class MetricsResponse:
     """Response from retrieving model metrics."""
 
-    slurm_job_id: int
     model_name: str
-    metrics: dict[str, float]
+    metrics: Union[dict[str, float], str]
    timestamp: float
 
 
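
For illustration (not part of the commit), constructing the revised dataclasses shows which fields were renamed, added, and dropped; every value below is a placeholder, and the ModelStatus import location is assumed:

# Placeholder values throughout; ModelStatus import location assumed.
from vec_inf.client._models import MetricsResponse, ModelStatus, StatusResponse

status = StatusResponse(
    model_name="Meta-Llama-3.1-8B",          # placeholder model name
    server_status=ModelStatus.UNAVAILABLE,   # renamed from `status`; `slurm_job_id` dropped
    job_state="PENDING",                     # new field: raw Slurm job state
    raw_output="JobId=123456 JobName=...",   # placeholder scontrol output
    pending_reason="Resources",              # placeholder Slurm pending reason
)

metrics = MetricsResponse(
    model_name="Meta-Llama-3.1-8B",
    metrics="Pending resources for server initialization",  # now str or dict[str, float]
    timestamp=1_700_000_000.0,
)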

vec_inf/client/api.py

Lines changed: 5 additions & 16 deletions

@@ -5,9 +5,7 @@
 """
 
 import time
-from typing import Any, Optional, cast
-
-import requests
+from typing import Any, Optional
 
 from vec_inf.client._config import ModelConfig
 from vec_inf.client._exceptions import (
@@ -138,7 +136,6 @@ def launch_model(
 
             # Create and use the API Launch Helper
             model_launcher = ModelLauncher(model_name, options_dict)
-
             return model_launcher.launch()
 
         except ValueError as e:
@@ -211,20 +208,12 @@ def get_metrics(
         )
 
         if not performance_metrics_collector.metrics_url.startswith("http"):
-            raise ServerError(
-                f"Metrics endpoint unavailable or server not ready - {performance_metrics_collector.metrics_url}"
-            )
-
-        metrics = performance_metrics_collector.fetch_metrics()
-
-        if isinstance(metrics, str):
-            raise requests.RequestException(metrics)
+            metrics = performance_metrics_collector.metrics_url
+        else:
+            metrics = performance_metrics_collector.fetch_metrics()
 
         return MetricsResponse(
-            slurm_job_id=slurm_job_id,
-            model_name=cast(
-                str, performance_metrics_collector.status_info["model_name"]
-            ),
+            model_name=performance_metrics_collector.status_info.model_name,
            metrics=metrics,
            timestamp=time.time(),
        )
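
With this change, get_metrics returns the not-ready message through MetricsResponse.metrics instead of raising, so callers branch on the type of metrics. A hedged sketch (render_metrics is a hypothetical helper, not from the repo):

# Hypothetical helper; `response` stands for whatever get_metrics(...) returns.
from vec_inf.client._models import MetricsResponse

def render_metrics(response: MetricsResponse) -> str:
    """Format a MetricsResponse for CLI display."""
    if isinstance(response.metrics, str):
        # Server not ready or metrics endpoint unavailable: the message is
        # passed through rather than raised, per this commit.
        return f"{response.model_name}: {response.metrics}"
    lines = [f"{response.model_name} (t={response.timestamp:.0f})"]
    lines += [f"  {name}: {value:.3f}" for name, value in response.metrics.items()]
    return "\n".join(lines)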
