File tree Expand file tree Collapse file tree 2 files changed +6
-4
lines changed
Expand file tree Collapse file tree 2 files changed +6
-4
lines changed Original file line number Diff line number Diff line change @@ -469,16 +469,15 @@ def _get_launch_params(
469469 If required fields are missing or tensor parallel size is not specified
470470 when using multiple GPUs
471471 """
472- params : dict [str , Any ] = {
473- "models" : {},
472+ common_params : dict [str , Any ] = {
474473 "slurm_job_name" : self .slurm_job_name ,
475474 "src_dir" : str (SRC_DIR ),
476475 "account" : account ,
477476 "work_dir" : work_dir ,
478477 }
479478
480- # Check for required fields without default vals, will raise an error if missing
481- utils . check_required_fields ( params )
479+ params : dict [ str , Any ] = common_params . copy ()
480+ params [ "models" ] = {}
482481
483482 for i , (model_name , config ) in enumerate (self .model_configs .items ()):
484483 params ["models" ][model_name ] = config .model_dump (exclude_none = True )
@@ -555,6 +554,8 @@ def _get_launch_params(
555554 raise ValueError (
556555 f"Mismatch found for { arg } : { params [arg ]} != { params ['models' ][model_name ][arg ]} , check your configuration"
557556 )
557+ # Check for required fields without default vals, will raise an error if missing
558+ utils .check_required_fields ({** params ["models" ][model_name ], ** common_params })
558559
559560 return params
560561
Original file line number Diff line number Diff line change @@ -444,6 +444,7 @@ def check_required_fields(params: dict[str, Any]) -> None:
444444 params : dict[str, Any]
445445 Dictionary of parameters to check.
446446 """
447+
447448 for arg in REQUIRED_ARGS :
448449 if not params .get (arg ):
449450 default_value = os .getenv (REQUIRED_ARGS [arg ])
You can’t perform that action at this time.
0 commit comments