Commit ed0a5dd

Refactor CLI logic to use client instead of inheriting client helper classes
Parent: 3e5e5ad

3 files changed: +188 −162 lines changed
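
In substance, the commit swaps inheritance for composition: each CLI command now builds a VecInfClient for the actual work plus a formatter class for display, instead of relying on CLI helper classes that inherited client behavior. A minimal sketch of the new call shape, assembled from names in the diff below (the model name and the bare LaunchOptions() are placeholder assumptions, not defaults confirmed by this commit):

from rich.console import Console

from vec_inf.cli._helper import LaunchResponseFormatter
from vec_inf.client._models import LaunchOptions
from vec_inf.client.api import VecInfClient

# The client does the work; the formatter only renders the response.
client = VecInfClient()
launch_response = client.launch_model("test-model", LaunchOptions())
formatter = LaunchResponseFormatter("test-model", launch_response.config)
Console().print(formatter.format_table_output())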

tests/vec_inf/cli/test_cli.py

Lines changed: 34 additions & 16 deletions
@@ -322,6 +322,25 @@ def test_launch_command_with_json_output(
     assert str(test_log_dir) in output.get("log_dir", "")


+def test_launch_command_no_model_weights_parent_dir(runner, debug_helper, base_patches):
+    """Test handling when model weights parent dir is not set."""
+    with ExitStack() as stack:
+        # Apply all base patches
+        for patch_obj in base_patches:
+            stack.enter_context(patch_obj)
+
+        # Mock load_config to return empty list
+        stack.enter_context(
+            patch("vec_inf.client._utils.load_config", return_value=[])
+        )
+
+        result = runner.invoke(cli, ["launch", "test-model"])
+        debug_helper.print_debug_info(result)
+
+        assert result.exit_code == 1
+        assert "Could not determine model weights parent directory" in result.output
+
+
 def test_launch_command_model_not_in_config_with_weights(
     runner, mock_launch_output, path_exists, debug_helper, test_paths, base_patches
 ):
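
The new test enters its mocks through contextlib.ExitStack. As a standalone illustration of that pattern (stand-in patch targets, unrelated to vec_inf):

import os
from contextlib import ExitStack
from unittest.mock import patch

# Enter an arbitrary collection of patches; ExitStack unwinds them
# all when the with-block exits, even if one of them raises.
patches = [
    patch("os.cpu_count", return_value=1),
    patch("os.getpid", return_value=12345),
]

with ExitStack() as stack:
    for p in patches:
        stack.enter_context(p)
    assert os.cpu_count() == 1
    assert os.getpid() == 12345
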
@@ -346,18 +365,19 @@ def test_launch_command_model_not_in_config_with_weights(
     expected_job_id = "14933051"
     mock_run.return_value = mock_launch_output(expected_job_id)

-    result = runner.invoke(cli, ["launch", "unknown-model"])
-    debug_helper.print_debug_info(result)
+    with pytest.warns(UserWarning) as record:
+        result = runner.invoke(cli, ["launch", "unknown-model"])
+        debug_helper.print_debug_info(result)

-    assert result.exit_code == 1
-    assert (
-        "Could not determine model_weights_parent_dir and 'unknown-model' not found in configuration"
-        in result.output
+    assert result.exit_code == 0
+    assert len(record) == 1
+    assert str(record[0].message) == (
+        "Warning: 'unknown-model' configuration not found in config, please ensure model configuration are properly set in command arguments"
     )


 def test_launch_command_model_not_found(
-    runner, path_exists, debug_helper, test_paths, base_patches
+    runner, debug_helper, test_paths, base_patches
 ):
     """Test handling of a model that's neither in config nor has weights."""

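The rewritten test above now expects launch to succeed and emit a UserWarning instead of exiting with an error. For reference, a minimal standalone example of the pytest.warns recording idiom it uses (hypothetical function, not vec_inf code):

import warnings

import pytest

def configure(name: str) -> str:
    if name == "unknown-model":
        warnings.warn(f"'{name}' configuration not found", UserWarning)
    return name

def test_configure_warns():
    with pytest.warns(UserWarning) as record:
        configure("unknown-model")
    assert len(record) == 1
    assert "not found" in str(record[0].message)
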
@@ -389,7 +409,8 @@ def custom_path_exists(p):

     assert result.exit_code == 1
     assert (
-        "Could not determine model_weights_parent_dir and 'unknown-model' not found in configuration"
+        "'unknown-model' not found in configuration and model weights "
+        "not found at expected path '/model-weights/unknown-model'"
         in result.output
     )

@@ -428,10 +449,9 @@ def test_metrics_command_pending_server(
     debug_helper.print_debug_info(result)

     assert result.exit_code == 0
-    assert "Server State" in result.output
-    assert "PENDING" in result.output
+    assert "ERROR" in result.output
     assert (
-        "Metrics endpoint unavailable or server not ready - Pending"
+        "Pending resources for server initialization"
         in result.output
     )

@@ -452,10 +472,9 @@ def test_metrics_command_server_not_ready(
     debug_helper.print_debug_info(result)

     assert result.exit_code == 0
-    assert "Server State" in result.output
-    assert "RUNNING" in result.output
+    assert "ERROR" in result.output
     assert (
-        "Metrics endpoint unavailable or server not ready - Server not"
+        "Server not ready"
         in result.output
     )

@@ -519,8 +538,7 @@ def test_metrics_command_request_failed(
     debug_helper.print_debug_info(result)

     # KeyboardInterrupt is expected and ok
-    assert "Server State" in result.output
-    assert "RUNNING" in result.output
+    assert "ERROR" in result.output
     assert (
         "Metrics request failed, `metrics` endpoint might not be ready"
         in result.output

vec_inf/cli/_cli.py

Lines changed: 47 additions & 32 deletions
@@ -9,11 +9,13 @@

 import vec_inf.client._utils as utils
 from vec_inf.cli._helper import (
-    CLIMetricsCollector,
-    CLIModelLauncher,
-    CLIModelRegistry,
-    CLIModelStatusMonitor,
+    LaunchResponseFormatter,
+    ListCmdDisplay,
+    MetricsResponseFormatter,
+    StatusResponseFormatter,
 )
+from vec_inf.client._models import LaunchOptions
+from vec_inf.client.api import VecInfClient


 CONSOLE = Console()
@@ -131,14 +133,19 @@ def launch(
 ) -> None:
     """Launch a model on the cluster."""
     try:
-        model_launcher = CLIModelLauncher(model_name, cli_kwargs)
-        # Launch model inference server
-        model_launcher.launch()
+        # Convert cli_kwargs to LaunchOptions
+        launch_options = LaunchOptions(**{k: v for k, v in cli_kwargs.items() if k != "json_mode"})
+
+        # Start the client and launch model inference server
+        client = VecInfClient()
+        launch_response = client.launch_model(model_name, launch_options)
+
         # Display launch information
+        launch_formatter = LaunchResponseFormatter(model_name, launch_response.config)
         if cli_kwargs.get("json_mode"):
-            click.echo(model_launcher.params)
+            click.echo(launch_response.config)
         else:
-            launch_info_table = model_launcher.format_table_output()
+            launch_info_table = launch_formatter.format_table_output()
             CONSOLE.print(launch_info_table)

     except click.ClickException as e:
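
An aside on the dict comprehension above: if LaunchOptions is a dataclass (its definition is not shown in this diff), the json_mode filter could be generalized so every CLI-only key is dropped. A sketch, with to_launch_options as a hypothetical helper:

from dataclasses import fields

from vec_inf.client._models import LaunchOptions

def to_launch_options(cli_kwargs: dict) -> LaunchOptions:
    # Assumption: LaunchOptions is a dataclass. Keep only the keys it
    # declares, so CLI-only flags such as json_mode never reach it.
    allowed = {f.name for f in fields(LaunchOptions)}
    return LaunchOptions(**{k: v for k, v in cli_kwargs.items() if k in allowed})
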
@@ -164,14 +171,16 @@ def status(
 ) -> None:
     """Get the status of a running model on the cluster."""
     try:
-        # Get model inference server status
-        model_status_monitor = CLIModelStatusMonitor(slurm_job_id, log_dir)
-        model_status_monitor.process_model_status()
+        # Start the client and get model inference server status
+        client = VecInfClient()
+        status_response = client.get_status(slurm_job_id, log_dir)
         # Display status information
+        status_formatter = StatusResponseFormatter(status_response)
         if json_mode:
-            model_status_monitor.output_json()
+            status_formatter.output_json()
         else:
-            model_status_monitor.output_table(CONSOLE)
+            status_info_table = status_formatter.output_table()
+            CONSOLE.print(status_info_table)

     except click.ClickException as e:
         raise e
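
A note on the display split here: the old monitor printed directly via output_table(CONSOLE), while the new StatusResponseFormatter.output_table() returns a table that the CLI prints itself. This appears to be a deliberate choice: the formatter stays free of console state, and rendering decisions stay in the CLI layer.
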
@@ -197,8 +206,15 @@ def shutdown(slurm_job_id: int) -> None:
 def list_models(model_name: Optional[str] = None, json_mode: bool = False) -> None:
     """List all available models, or get default setup of a specific model."""
     try:
-        model_registry = CLIModelRegistry(json_mode)
-        model_registry.process_list_command(CONSOLE, model_name)
+        # Start the client
+        client = VecInfClient()
+        list_display = ListCmdDisplay(CONSOLE, json_mode)
+        if model_name:
+            model_config = client.get_model_config(model_name)
+            list_display.display_single_model_output(model_config)
+        else:
+            model_infos = client.list_models()
+            list_display.display_all_models_output(model_infos)
     except click.ClickException as e:
         raise e
     except Exception as e:
@@ -213,30 +229,29 @@ def list_models(model_name: Optional[str] = None, json_mode: bool = False) -> None:
 def metrics(slurm_job_id: int, log_dir: Optional[str] = None) -> None:
     """Stream real-time performance metrics from the model endpoint."""
     try:
-        metrics_collector = CLIMetricsCollector(slurm_job_id, log_dir)
-
-        # Check if metrics URL is ready
-        if not metrics_collector.metrics_url.startswith("http"):
-            table = utils.create_table("Metric", "Value")
-            metrics_collector.display_failed_metrics(
-                table,
-                f"Metrics endpoint unavailable or server not ready - {metrics_collector.metrics_url}",
-            )
-            CONSOLE.print(table)
+        # Start the client and get inference server metrics
+        client = VecInfClient()
+        metrics_response = client.get_metrics(slurm_job_id, log_dir)
+        metrics_formatter = MetricsResponseFormatter(metrics_response.metrics)
+
+        # Check if metrics response is ready
+        if isinstance(metrics_response.metrics, str):
+            metrics_formatter.format_failed_metrics(metrics_response.metrics)
+            CONSOLE.print(metrics_formatter.table)
             return

         with Live(refresh_per_second=1, console=CONSOLE) as live:
             while True:
-                metrics = metrics_collector.fetch_metrics()
-                table = utils.create_table("Metric", "Value")
+                metrics_response = client.get_metrics(slurm_job_id, log_dir)
+                metrics_formatter = MetricsResponseFormatter(metrics_response.metrics)

-                if isinstance(metrics, str):
+                if isinstance(metrics_response.metrics, str):
                     # Show status information if metrics aren't available
-                    metrics_collector.display_failed_metrics(table, metrics)
+                    metrics_formatter.format_failed_metrics(metrics_response.metrics)
                 else:
-                    metrics_collector.display_metrics(table, metrics)
+                    metrics_formatter.format_metrics()

-                live.update(table)
+                live.update(metrics_formatter.table)
                 time.sleep(2)
     except click.ClickException as e:
         raise e
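
The loop above is the standard rich Live pattern: rebuild a renderable each tick and hand it to live.update(). A self-contained sketch with a stubbed metrics source standing in for client.get_metrics:

import random
import time

from rich.live import Live
from rich.table import Table

def fetch_metrics() -> dict:
    # Stub standing in for client.get_metrics(...).metrics
    return {"throughput (tokens/s)": f"{random.uniform(30, 50):.1f}"}

with Live(refresh_per_second=1) as live:
    for _ in range(3):  # the real CLI loops until interrupted
        table = Table("Metric", "Value")
        for name, value in fetch_metrics().items():
            table.add_row(name, value)
        live.update(table)
        time.sleep(2)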
