Skip to content

Commit 7722344

Browse files
committed
Refactor CLI to use client
1 parent 988040c commit 7722344

File tree

3 files changed

+66
-82
lines changed

3 files changed

+66
-82
lines changed

vec_inf/cli/_cli.py

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
from rich.console import Console
88
from rich.live import Live
99

10-
import vec_inf.shared._utils as utils
10+
import vec_inf.client._utils as utils
1111
from vec_inf.cli._helper import (
12-
CLILaunchHelper,
13-
CLIListHelper,
14-
CLIMetricsHelper,
15-
CLIStatusHelper,
12+
CLIMetricsCollector,
13+
CLIModelLauncher,
14+
CLIModelRegistry,
15+
CLIModelStatusMonitor,
1616
)
1717

1818

@@ -131,14 +131,15 @@ def launch(
131131
) -> None:
132132
"""Launch a model on the cluster."""
133133
try:
134-
launch_helper = CLILaunchHelper(model_name, cli_kwargs)
135-
136-
launch_helper.set_env_vars()
137-
launch_command = launch_helper.build_launch_command()
138-
command_output, stderr = utils.run_bash_command(launch_command)
139-
if stderr:
140-
raise click.ClickException(f"Error: {stderr}")
141-
launch_helper.post_launch_processing(command_output, CONSOLE)
134+
model_launcher = CLIModelLauncher(model_name, cli_kwargs)
135+
# Launch model inference server
136+
model_launcher.launch()
137+
# Display launch information
138+
if cli_kwargs.get("json_mode"):
139+
click.echo(model_launcher.params)
140+
else:
141+
launch_info_table = model_launcher.format_table_output()
142+
CONSOLE.print(launch_info_table)
142143

143144
except click.ClickException as e:
144145
raise e
@@ -163,18 +164,14 @@ def status(
163164
) -> None:
164165
"""Get the status of a running model on the cluster."""
165166
try:
166-
status_cmd = f"scontrol show job {slurm_job_id} --oneliner"
167-
output, stderr = utils.run_bash_command(status_cmd)
168-
if stderr:
169-
raise click.ClickException(f"Error: {stderr}")
170-
171-
status_helper = CLIStatusHelper(slurm_job_id, output, log_dir)
172-
173-
status_helper.process_job_state()
167+
# Get model inference server status
168+
model_status_monitor = CLIModelStatusMonitor(slurm_job_id, log_dir)
169+
model_status_monitor.process_model_status()
170+
# Display status information
174171
if json_mode:
175-
status_helper.output_json()
172+
model_status_monitor.output_json()
176173
else:
177-
status_helper.output_table(CONSOLE)
174+
model_status_monitor.output_table(CONSOLE)
178175

179176
except click.ClickException as e:
180177
raise e
@@ -200,8 +197,8 @@ def shutdown(slurm_job_id: int) -> None:
200197
def list_models(model_name: Optional[str] = None, json_mode: bool = False) -> None:
201198
"""List all available models, or get default setup of a specific model."""
202199
try:
203-
list_helper = CLIListHelper(json_mode)
204-
list_helper.process_list_command(CONSOLE, model_name)
200+
model_registry = CLIModelRegistry(json_mode)
201+
model_registry.process_list_command(CONSOLE, model_name)
205202
except click.ClickException as e:
206203
raise e
207204
except Exception as e:
@@ -216,28 +213,28 @@ def list_models(model_name: Optional[str] = None, json_mode: bool = False) -> No
216213
def metrics(slurm_job_id: int, log_dir: Optional[str] = None) -> None:
217214
"""Stream real-time performance metrics from the model endpoint."""
218215
try:
219-
metrics_helper = CLIMetricsHelper(slurm_job_id, log_dir)
216+
metrics_collector = CLIMetricsCollector(slurm_job_id, log_dir)
220217

221218
# Check if metrics URL is ready
222-
if not metrics_helper.metrics_url.startswith("http"):
219+
if not metrics_collector.metrics_url.startswith("http"):
223220
table = utils.create_table("Metric", "Value")
224-
metrics_helper.display_failed_metrics(
221+
metrics_collector.display_failed_metrics(
225222
table,
226-
f"Metrics endpoint unavailable or server not ready - {metrics_helper.metrics_url}",
223+
f"Metrics endpoint unavailable or server not ready - {metrics_collector.metrics_url}",
227224
)
228225
CONSOLE.print(table)
229226
return
230227

231228
with Live(refresh_per_second=1, console=CONSOLE) as live:
232229
while True:
233-
metrics = metrics_helper.fetch_metrics()
230+
metrics = metrics_collector.fetch_metrics()
234231
table = utils.create_table("Metric", "Value")
235232

236233
if isinstance(metrics, str):
237234
# Show status information if metrics aren't available
238-
metrics_helper.display_failed_metrics(table, metrics)
235+
metrics_collector.display_failed_metrics(table, metrics)
239236
else:
240-
metrics_helper.display_metrics(table, metrics)
237+
metrics_collector.display_metrics(table, metrics)
241238

242239
live.update(table)
243240
time.sleep(2)

vec_inf/cli/_helper.py

Lines changed: 22 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,19 @@
1111
from rich.panel import Panel
1212
from rich.table import Table
1313

14-
import vec_inf.shared._utils as utils
15-
from vec_inf.shared._config import ModelConfig
16-
from vec_inf.shared._helper import LaunchHelper, ListHelper, MetricsHelper, StatusHelper
14+
import vec_inf.client._utils as utils
15+
from vec_inf.cli._models import MODEL_TYPE_COLORS, MODEL_TYPE_PRIORITY
16+
from vec_inf.client._config import ModelConfig
17+
from vec_inf.client._helper import (
18+
ModelLauncher,
19+
ModelRegistry,
20+
ModelStatusMonitor,
21+
PerformanceMetricsCollector,
22+
)
1723

1824

19-
class CLILaunchHelper(LaunchHelper):
20-
"""CLI Helper class for handling launch information."""
25+
class CLIModelLauncher(ModelLauncher):
26+
"""CLI Helper class for handling inference server launch."""
2127

2228
def __init__(self, model_name: str, kwargs: Optional[dict[str, Any]]):
2329
super().__init__(model_name, kwargs)
@@ -26,12 +32,12 @@ def _warn(self, message: str) -> None:
2632
"""Warn the user about a potential issue."""
2733
click.echo(click.style(f"Warning: {message}", fg="yellow"), err=True)
2834

29-
def _format_table_output(self, job_id: str) -> Table:
35+
def format_table_output(self) -> Table:
3036
"""Format output as rich Table."""
3137
table = utils.create_table(key_title="Job Config", value_title="Value")
3238

3339
# Add key information with consistent styling
34-
table.add_row("Slurm Job ID", job_id, style="blue")
40+
table.add_row("Slurm Job ID", self.slurm_job_id, style="blue")
3541
table.add_row("Job Name", self.model_name)
3642

3743
# Add model details
@@ -71,33 +77,12 @@ def _format_table_output(self, job_id: str) -> Table:
7177

7278
return table
7379

74-
def post_launch_processing(self, output: str, console: Console) -> None:
75-
"""Process and display launch output."""
76-
json_mode = bool(self.kwargs.get("json_mode", False))
77-
slurm_job_id = output.split(" ")[-1].strip().strip("\n")
78-
self.params["slurm_job_id"] = slurm_job_id
79-
job_json = Path(
80-
self.params["log_dir"],
81-
f"{self.model_name}.{slurm_job_id}",
82-
f"{self.model_name}.{slurm_job_id}.json",
83-
)
84-
job_json.parent.mkdir(parents=True, exist_ok=True)
85-
job_json.touch(exist_ok=True)
86-
87-
with job_json.open("w") as file:
88-
json.dump(self.params, file, indent=4)
89-
if json_mode:
90-
click.echo(self.params)
91-
else:
92-
table = self._format_table_output(slurm_job_id)
93-
console.print(table)
9480

81+
class CLIModelStatusMonitor(ModelStatusMonitor):
82+
"""CLI Helper class for handling server status information and monitoring."""
9583

96-
class CLIStatusHelper(StatusHelper):
97-
"""CLI Helper class for handling status information."""
98-
99-
def __init__(self, slurm_job_id: int, output: str, log_dir: Optional[str] = None):
100-
super().__init__(slurm_job_id, output, log_dir)
84+
def __init__(self, slurm_job_id: int, log_dir: Optional[str] = None):
85+
super().__init__(slurm_job_id, log_dir)
10186

10287
def output_json(self) -> None:
10388
"""Format and output JSON data."""
@@ -127,7 +112,7 @@ def output_table(self, console: Console) -> None:
127112
console.print(table)
128113

129114

130-
class CLIMetricsHelper(MetricsHelper):
115+
class CLIMetricsCollector(PerformanceMetricsCollector):
131116
"""CLI Helper class for streaming metrics information."""
132117

133118
def __init__(self, slurm_job_id: int, log_dir: Optional[str] = None):
@@ -204,8 +189,8 @@ def display_metrics(self, table: Table, metrics: dict[str, float]) -> None:
204189
)
205190

206191

207-
class CLIListHelper(ListHelper):
208-
"""Helper class for handling model listing functionality."""
192+
class CLIModelRegistry(ModelRegistry):
193+
"""CLI Helper class for handling model listing functionality."""
209194

210195
def __init__(self, json_mode: bool = False):
211196
super().__init__()
@@ -237,28 +222,15 @@ def format_all_models_output(self) -> Union[list[str], list[Panel]]:
237222
return [config.model_name for config in self.model_configs]
238223

239224
# Sort by model type priority
240-
type_priority = {
241-
"LLM": 0,
242-
"VLM": 1,
243-
"Text_Embedding": 2,
244-
"Reward_Modeling": 3,
245-
}
246225
sorted_configs = sorted(
247226
self.model_configs,
248-
key=lambda x: type_priority.get(x.model_type, 4),
227+
key=lambda x: MODEL_TYPE_PRIORITY.get(x.model_type, 4),
249228
)
250229

251230
# Create panels with color coding
252-
model_type_colors = {
253-
"LLM": "cyan",
254-
"VLM": "bright_blue",
255-
"Text_Embedding": "purple",
256-
"Reward_Modeling": "bright_magenta",
257-
}
258-
259231
panels = []
260232
for config in sorted_configs:
261-
color = model_type_colors.get(config.model_type, "white")
233+
color = MODEL_TYPE_COLORS.get(config.model_type, "white")
262234
variant = config.model_variant or ""
263235
display_text = f"[magenta]{config.model_family}[/magenta]"
264236
if variant:

vec_inf/cli/_models.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Data models for CLI rendering."""
2+
3+
MODEL_TYPE_PRIORITY = {
4+
"LLM": 0,
5+
"VLM": 1,
6+
"Text_Embedding": 2,
7+
"Reward_Modeling": 3,
8+
}
9+
10+
MODEL_TYPE_COLORS = {
11+
"LLM": "cyan",
12+
"VLM": "bright_blue",
13+
"Text_Embedding": "purple",
14+
"Reward_Modeling": "bright_magenta",
15+
}

0 commit comments

Comments
 (0)