Skip to content

Commit 238d70e

Browse files
authored
Merge pull request #137 from VectorInstitute/add_config_flag
Add config flag
2 parents f170696 + 74305a8 commit 238d70e

File tree

3 files changed

+37
-4
lines changed

3 files changed

+37
-4
lines changed

vec_inf/cli/_cli.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# ruff: noqa: D301, D411
2+
# Using \f and \b for click --help formatting, which violates these rules.
13
"""Command line interface for Vector Inference.
24
35
This module provides the command-line interface for interacting with Vector
@@ -124,16 +126,23 @@ def cli() -> None:
124126
type=str,
125127
help="Environment variables to be set. Separate variables with commas. Can also include path to a file containing environment variables separated by newlines. e.g. --env 'TRITON_CACHE_DIR=/scratch/.cache/triton,my_custom_vars_file.env'",
126128
)
129+
@click.option(
130+
"--config",
131+
type=str,
132+
help="Path to a model config yaml file to use in place of the default",
133+
)
127134
def launch(
128135
model_name: str,
129136
**cli_kwargs: Optional[Union[str, int, float, bool]],
130137
) -> None:
131138
"""Launch a model on the cluster.
132139
140+
\b
133141
Parameters
134142
----------
135143
model_name : str
136144
Name of the model to launch
145+
\f
137146
**cli_kwargs : dict
138147
Additional launch options including:
139148
- model_family : str, optional
@@ -166,6 +175,10 @@ def launch(
166175
Path to model weights directory
167176
- vllm_args : str, optional
168177
vLLM engine arguments
178+
- env : str, optional
179+
Environment variables
180+
- config : str, optional
181+
Path to custom model config yaml file
169182
- json_mode : bool, optional
170183
Output in JSON format
171184
@@ -220,10 +233,12 @@ def batch_launch(
220233
) -> None:
221234
"""Launch multiple models in a batch.
222235
236+
\b
223237
Parameters
224238
----------
225239
model_names : tuple[str, ...]
226240
Names of the models to launch
241+
\f
227242
batch_config : str
228243
Model configuration for batch launch
229244
json_mode : bool, default=False
@@ -267,10 +282,12 @@ def batch_launch(
267282
def status(slurm_job_id: str, json_mode: bool = False) -> None:
268283
"""Get the status of a running model on the cluster.
269284
285+
\b
270286
Parameters
271287
----------
272288
slurm_job_id : str
273289
ID of the SLURM job to check
290+
\f
274291
json_mode : bool, default=False
275292
Whether to output in JSON format
276293
@@ -302,10 +319,12 @@ def status(slurm_job_id: str, json_mode: bool = False) -> None:
302319
def shutdown(slurm_job_id: str) -> None:
303320
"""Shutdown a running model on the cluster.
304321
322+
\b
305323
Parameters
306324
----------
307325
slurm_job_id : str
308326
ID of the SLURM job to shut down
327+
\f
309328
310329
Raises
311330
------
@@ -330,10 +349,12 @@ def shutdown(slurm_job_id: str) -> None:
330349
def list_models(model_name: Optional[str] = None, json_mode: bool = False) -> None:
331350
"""List all available models, or get default setup of a specific model.
332351
352+
\b
333353
Parameters
334354
----------
335355
model_name : str, optional
336356
Name of specific model to get information for
357+
\f
337358
json_mode : bool, default=False
338359
Whether to output in JSON format
339360
@@ -363,10 +384,12 @@ def list_models(model_name: Optional[str] = None, json_mode: bool = False) -> No
363384
def metrics(slurm_job_id: str) -> None:
364385
"""Stream real-time performance metrics from the model endpoint.
365386
387+
\b
366388
Parameters
367389
----------
368390
slurm_job_id : str
369391
ID of the SLURM job to monitor
392+
\f
370393
371394
Raises
372395
------
@@ -433,6 +456,8 @@ def cleanup_logs_cli(
433456
) -> None:
434457
"""Clean up log files based on optional filters.
435458
459+
\f
460+
436461
Parameters
437462
----------
438463
log_dir : str or Path, optional
@@ -447,7 +472,7 @@ def cleanup_logs_cli(
447472
If provided, only delete logs with job ID less than this value.
448473
dry_run : bool
449474
If True, return matching files without deleting them.
450-
"""
475+
""" # NOQA: D301, the \f prevents click from printing options twice.
451476
try:
452477
client = VecInfClient()
453478
matched = client.cleanup_logs(

vec_inf/client/_helper.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(self, model_name: str, kwargs: Optional[dict[str, Any]]):
6161
self.kwargs = kwargs or {}
6262
self.slurm_job_id = ""
6363
self.slurm_script_path = Path("")
64-
self.model_config = self._get_model_configuration()
64+
self.model_config = self._get_model_configuration(self.kwargs.get("config"))
6565
self.params = self._get_launch_params()
6666

6767
def _warn(self, message: str) -> None:
@@ -74,9 +74,14 @@ def _warn(self, message: str) -> None:
7474
"""
7575
warnings.warn(message, UserWarning, stacklevel=2)
7676

77-
def _get_model_configuration(self) -> ModelConfig:
77+
def _get_model_configuration(self, config_path: str | None = None) -> ModelConfig:
7878
"""Load and validate model configuration.
7979
80+
Parameters
81+
----------
82+
config_path : str | None, optional
83+
Path to a yaml file with custom model config to use in place of the default
84+
8085
Returns
8186
-------
8287
ModelConfig
@@ -89,7 +94,7 @@ def _get_model_configuration(self) -> ModelConfig:
8994
ModelConfigurationError
9095
If model configuration is not found and weights don't exist
9196
"""
92-
model_configs = utils.load_config()
97+
model_configs = utils.load_config(config_path=config_path)
9398
config = next(
9499
(m for m in model_configs if m.model_name == self.model_name), None
95100
)

vec_inf/client/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,8 @@ class LaunchOptions:
218218
Additional arguments for vLLM
219219
env : str, optional
220220
Environment variables to be set
221+
config : str, optional
222+
Path to custom model config yaml
221223
"""
222224

223225
model_family: Optional[str] = None
@@ -238,6 +240,7 @@ class LaunchOptions:
238240
model_weights_parent_dir: Optional[str] = None
239241
vllm_args: Optional[str] = None
240242
env: Optional[str] = None
243+
config: Optional[str] = None
241244

242245

243246
@dataclass

0 commit comments

Comments
 (0)