File tree Expand file tree Collapse file tree 2 files changed +12
-3
lines changed
Expand file tree Collapse file tree 2 files changed +12
-3
lines changed Original file line number Diff line number Diff line change @@ -554,11 +554,17 @@ def _get_launch_params(
554554 raise ValueError (
555555 f"Mismatch found for { arg } : { params [arg ]} != { params ['models' ][model_name ][arg ]} , check your configuration"
556556 )
557- # Check for required fields, will raise an error if missing any
558- utils .check_required_fields (
557+ # Check for required fields and return environment variable overrides
558+ env_overrides = utils .check_required_fields (
559559 {** params ["models" ][model_name ], ** common_params }
560560 )
561561
562+ for arg , value in env_overrides .items ():
563+ if arg in common_params :
564+ params [arg ] = value
565+ else :
566+ params ["models" ][model_name ][arg ] = value
567+
562568 return params
563569
564570 def _build_launch_command (self ) -> str :
Original file line number Diff line number Diff line change @@ -436,20 +436,23 @@ def find_matching_dirs(
436436 return matched
437437
438438
439- def check_required_fields (params : dict [str , Any ]) -> None :
439+ def check_required_fields (params : dict [str , Any ]) -> dict [ str , Any ] :
440440 """Check for required fields without default vals and their corresponding env vars.
441441
442442 Parameters
443443 ----------
444444 params : dict[str, Any]
445445 Dictionary of parameters to check.
446446 """
447+ env_overrides = {}
447448 for arg in REQUIRED_ARGS :
448449 if not params .get (arg ):
449450 default_value = os .getenv (REQUIRED_ARGS [arg ])
450451 if default_value :
451452 params [arg ] = default_value
453+ env_overrides [arg ] = default_value
452454 else :
453455 raise MissingRequiredFieldsError (
454456 f"{ arg } is required, please set it in the command arguments or environment variables"
455457 )
458+ return env_overrides
You can’t perform that action at this time.
0 commit comments