Skip to content

Commit f69c810

Browse files
committed
Fix required fields checking for batch launch
1 parent e9bebab commit f69c810

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

vec_inf/client/_helper.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -469,16 +469,15 @@ def _get_launch_params(
469469
If required fields are missing or tensor parallel size is not specified
470470
when using multiple GPUs
471471
"""
472-
params: dict[str, Any] = {
473-
"models": {},
472+
common_params: dict[str, Any] = {
474473
"slurm_job_name": self.slurm_job_name,
475474
"src_dir": str(SRC_DIR),
476475
"account": account,
477476
"work_dir": work_dir,
478477
}
479478

480-
# Check for required fields without default vals, will raise an error if missing
481-
utils.check_required_fields(params)
479+
params: dict[str, Any] = common_params.copy()
480+
params["models"] = {}
482481

483482
for i, (model_name, config) in enumerate(self.model_configs.items()):
484483
params["models"][model_name] = config.model_dump(exclude_none=True)
@@ -555,6 +554,8 @@ def _get_launch_params(
555554
raise ValueError(
556555
f"Mismatch found for {arg}: {params[arg]} != {params['models'][model_name][arg]}, check your configuration"
557556
)
557+
# Check for required fields without default vals, will raise an error if missing
558+
utils.check_required_fields({**params["models"][model_name], **common_params})
558559

559560
return params
560561

vec_inf/client/_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ def check_required_fields(params: dict[str, Any]) -> None:
444444
params : dict[str, Any]
445445
Dictionary of parameters to check.
446446
"""
447+
447448
for arg in REQUIRED_ARGS:
448449
if not params.get(arg):
449450
default_value = os.getenv(REQUIRED_ARGS[arg])

0 commit comments

Comments
 (0)