Skip to content

Commit 1ab7e1a

Browse files
committed
Fix check required field for batch launch
1 parent 92e0da9 commit 1ab7e1a

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

vec_inf/client/_helper.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -554,11 +554,17 @@ def _get_launch_params(
554554
raise ValueError(
555555
f"Mismatch found for {arg}: {params[arg]} != {params['models'][model_name][arg]}, check your configuration"
556556
)
557-
# Check for required fields, will raise an error if missing any
558-
utils.check_required_fields(
557+
# Check for required fields and return environment variable overrides
558+
env_overrides = utils.check_required_fields(
559559
{**params["models"][model_name], **common_params}
560560
)
561561

562+
for arg, value in env_overrides.items():
563+
if arg in common_params:
564+
params[arg] = value
565+
else:
566+
params["models"][model_name][arg] = value
567+
562568
return params
563569

564570
def _build_launch_command(self) -> str:

vec_inf/client/_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,20 +436,23 @@ def find_matching_dirs(
436436
return matched
437437

438438

439-
def check_required_fields(params: dict[str, Any]) -> None:
439+
def check_required_fields(params: dict[str, Any]) -> dict[str, Any]:
440440
"""Check for required fields without default vals and their corresponding env vars.
441441
442442
Parameters
443443
----------
444444
params : dict[str, Any]
445445
Dictionary of parameters to check.
446446
"""
447+
env_overrides = {}
447448
for arg in REQUIRED_ARGS:
448449
if not params.get(arg):
449450
default_value = os.getenv(REQUIRED_ARGS[arg])
450451
if default_value:
451452
params[arg] = default_value
453+
env_overrides[arg] = default_value
452454
else:
453455
raise MissingRequiredFieldsError(
454456
f"{arg} is required, please set it in the command arguments or environment variables"
455457
)
458+
return env_overrides

0 commit comments

Comments
 (0)