Skip to content

Commit 5ef4e1f

Browse files
committed
mypy fixes
1 parent 35b96dc commit 5ef4e1f

File tree

5 files changed

+23
-19
lines changed

5 files changed

+23
-19
lines changed

examples/api/basic_usage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
# Get metrics
3434
print("\nRetrieving metrics...")
3535
metrics = client.get_metrics(job_id)
36-
if metrics.metrics:
36+
if isinstance(metrics.metrics, dict):
3737
for key, value in metrics.metrics.items():
3838
print(f"- {key}: {value}")
3939

vec_inf/cli/_cli.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,13 @@
77
from rich.console import Console
88
from rich.live import Live
99

10-
import vec_inf.client._utils as utils
1110
from vec_inf.cli._helper import (
1211
LaunchResponseFormatter,
1312
ListCmdDisplay,
1413
MetricsResponseFormatter,
1514
StatusResponseFormatter,
1615
)
17-
from vec_inf.client._models import LaunchOptions
16+
from vec_inf.client._models import LaunchOptions, LaunchOptionsDict
1817
from vec_inf.client.api import VecInfClient
1918

2019

@@ -129,14 +128,15 @@ def cli() -> None:
129128
)
130129
def launch(
131130
model_name: str,
132-
**cli_kwargs: Optional[Union[str, int, bool]],
131+
**cli_kwargs: Optional[Union[str, int, float, bool]],
133132
) -> None:
134133
"""Launch a model on the cluster."""
135134
try:
136135
# Convert cli_kwargs to LaunchOptions
137-
launch_options = LaunchOptions(
138-
**{k: v for k, v in cli_kwargs.items() if k != "json_mode"}
139-
)
136+
kwargs = {k: v for k, v in cli_kwargs.items() if k != "json_mode"}
137+
# Cast the dictionary to LaunchOptionsDict
138+
options_dict: LaunchOptionsDict = kwargs # type: ignore
139+
launch_options = LaunchOptions(**options_dict)
140140

141141
# Start the client and launch model inference server
142142
client = VecInfClient()
@@ -194,8 +194,12 @@ def status(
194194
@click.argument("slurm_job_id", type=int, nargs=1)
195195
def shutdown(slurm_job_id: int) -> None:
196196
"""Shutdown a running model on the cluster."""
197-
utils.shutdown_model(slurm_job_id)
198-
click.echo(f"Shutting down model with Slurm Job ID: {slurm_job_id}")
197+
try:
198+
client = VecInfClient()
199+
client.shutdown_model(slurm_job_id)
200+
click.echo(f"Shutting down model with Slurm Job ID: {slurm_job_id}")
201+
except Exception as e:
202+
raise click.ClickException(f"Shutdown failed: {str(e)}") from e
199203

200204

201205
@cli.command("list")

vec_inf/cli/_helper.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,19 +105,18 @@ def output_table(self) -> Table:
105105
class MetricsResponseFormatter:
106106
"""CLI Helper class for formatting MetricsResponse."""
107107

108-
def __init__(self, metrics: dict[str, float]):
109-
self.metrics = metrics
108+
def __init__(self, metrics: Union[dict[str, float], str]):
109+
self.metrics = self._set_metrics(metrics)
110110
self.table = utils.create_table("Metric", "Value")
111111
self.enabled_prefix_caching = self._check_prefix_caching()
112112

113+
def _set_metrics(self, metrics: Union[dict[str, float], str]) -> dict[str, float]:
114+
"""Set the metrics attribute."""
115+
return metrics if isinstance(metrics, dict) else {}
116+
113117
def _check_prefix_caching(self) -> bool:
114118
"""Check if prefix caching is enabled by looking for prefix cache metrics."""
115-
if isinstance(self.metrics, str):
116-
# If metrics is a string, it's an error message
117-
return False
118-
119-
cache_rate = self.metrics.get("gpu_prefix_cache_hit_rate")
120-
return cache_rate is not None
119+
return self.metrics.get("gpu_prefix_cache_hit_rate") is not None
121120

122121
def format_failed_metrics(self, message: str) -> None:
123122
self.table.add_row("ERROR", message)

vec_inf/client/_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def _process_running_state(self) -> None:
294294
if server_status == "RUNNING":
295295
self._check_model_health()
296296
else:
297-
self.status_info.server_status = server_status
297+
self.status_info.server_status = cast(ModelStatus, server_status)
298298

299299
def _process_pending_state(self) -> None:
300300
"""Process PENDING job state."""

vec_inf/client/api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
import time
8-
from typing import Any, Optional
8+
from typing import Any, Optional, Union
99

1010
from vec_inf.client._config import ModelConfig
1111
from vec_inf.client._exceptions import (
@@ -147,6 +147,7 @@ def get_metrics(
147147
slurm_job_id, log_dir
148148
)
149149

150+
metrics: Union[dict[str, float], str]
150151
if not performance_metrics_collector.metrics_url.startswith("http"):
151152
metrics = performance_metrics_collector.metrics_url
152153
else:

0 commit comments

Comments
 (0)