@@ -79,7 +79,10 @@ def test_init_with_missing_config_and_existing_weights(
7979 mock_path_exists .return_value = True
8080
8181 with pytest .warns (UserWarning ):
82- launcher = ModelLauncher ("unknown-model" , {"account" : "test-account" , "work_dir" : "/tmp/test-work" })
82+ launcher = ModelLauncher (
83+ "unknown-model" ,
84+ {"account" : "test-account" , "work_dir" : "/tmp/test-work" },
85+ )
8386
8487 assert launcher .model_name == "unknown-model"
8588 assert launcher .model_config .model_name == "unknown-model"
@@ -276,7 +279,11 @@ def batch_model_configs(self) -> list[ModelConfig]:
276279 def test_init_with_valid_configs (self , mock_load_config , batch_model_configs ):
277280 """Test launcher initializes correctly with valid model configurations."""
278281 mock_load_config .return_value = batch_model_configs
279- launcher = BatchModelLauncher (["family1-variant1" , "family2-variant1" ], account = "test-account" , work_dir = "/tmp/test-work" )
282+ launcher = BatchModelLauncher (
283+ ["family1-variant1" , "family2-variant1" ],
284+ account = "test-account" ,
285+ work_dir = "/tmp/test-work" ,
286+ )
280287
281288 assert launcher .model_names == ["family1-variant1" , "family2-variant1" ]
282289 assert launcher .slurm_job_name == "BATCH-family1-variant1-family2-variant1"
@@ -302,9 +309,9 @@ def test_get_slurm_job_name(self, mock_load_config, batch_model_configs):
302309 """Test SLURM job name is constructed correctly from model names."""
303310 mock_load_config .return_value = batch_model_configs
304311 launcher = BatchModelLauncher (
305- ["family1-variant1" , "family2-variant1" , "family1-variant2" ],
306- account = "test-account" ,
307- work_dir = "/tmp/test-work"
312+ ["family1-variant1" , "family2-variant1" , "family1-variant2" ],
313+ account = "test-account" ,
314+ work_dir = "/tmp/test-work" ,
308315 )
309316
310317 assert (
@@ -321,9 +328,9 @@ def test_get_launch_params_creates_log_dirs(
321328 mock_load_config .return_value = batch_model_configs
322329
323330 launcher = BatchModelLauncher (
324- ["family1-variant1" , "family2-variant1" , "family1-variant2" ],
325- account = "test-account" ,
326- work_dir = "/tmp/test-work"
331+ ["family1-variant1" , "family2-variant1" , "family1-variant2" ],
332+ account = "test-account" ,
333+ work_dir = "/tmp/test-work" ,
327334 )
328335 params = launcher .params
329336
@@ -352,7 +359,11 @@ def test_get_launch_params_with_multi_gpu_no_tp(
352359 mock_load_config .return_value = updated_configs
353360
354361 with pytest .raises (MissingRequiredFieldsError ) as excinfo :
355- BatchModelLauncher (["family1-variant1" , "family2-variant1" ], account = "test-account" , work_dir = "/tmp/test-work" )
362+ BatchModelLauncher (
363+ ["family1-variant1" , "family2-variant1" ],
364+ account = "test-account" ,
365+ work_dir = "/tmp/test-work" ,
366+ )
356367
357368 assert "--tensor-parallel-size" in str (excinfo .value )
358369 assert "family1-variant1" in str (excinfo .value )
@@ -375,7 +386,11 @@ def test_get_launch_params_with_non_power_of_two_gpus(
375386 mock_load_config .return_value = updated_configs
376387
377388 with pytest .raises (ValueError ) as excinfo :
378- BatchModelLauncher (["family1-variant1" , "family2-variant1" ], account = "test-account" , work_dir = "/tmp/test-work" )
389+ BatchModelLauncher (
390+ ["family1-variant1" , "family2-variant1" ],
391+ account = "test-account" ,
392+ work_dir = "/tmp/test-work" ,
393+ )
379394
380395 assert "power of two" in str (excinfo .value )
381396 assert "family1-variant1" in str (excinfo .value )
@@ -401,7 +416,11 @@ def test_get_launch_params_with_mismatched_batch_args(
401416 mock_load_config .return_value = updated_configs
402417
403418 with pytest .raises (ValueError ) as excinfo :
404- BatchModelLauncher (["family1-variant1" , "family2-variant1" ], account = "test-account" , work_dir = "/tmp/test-work" )
419+ BatchModelLauncher (
420+ ["family1-variant1" , "family2-variant1" ],
421+ account = "test-account" ,
422+ work_dir = "/tmp/test-work" ,
423+ )
405424
406425 assert "Mismatch between total number of GPUs requested" in str (excinfo .value )
407426
@@ -451,7 +470,11 @@ def test_launch_success(
451470 # Mock copy2 to do nothing (avoid file operations)
452471 mock_copy2 .return_value = None
453472
454- launcher = BatchModelLauncher (["family1-variant1" , "family2-variant1" ], account = "test-account" , work_dir = "/tmp/test-work" )
473+ launcher = BatchModelLauncher (
474+ ["family1-variant1" , "family2-variant1" ],
475+ account = "test-account" ,
476+ work_dir = "/tmp/test-work" ,
477+ )
455478 response = launcher .launch ()
456479
457480 assert response .slurm_job_id == "12345"
@@ -475,7 +498,11 @@ def test_launch_with_slurm_error(
475498 mock_load_config .return_value = batch_model_configs
476499 mock_run_bash .return_value = ("" , "sbatch: error: Invalid partition specified" )
477500
478- launcher = BatchModelLauncher (["family1-variant1" , "family2-variant1" ], account = "test-account" , work_dir = "/tmp/test-work" )
501+ launcher = BatchModelLauncher (
502+ ["family1-variant1" , "family2-variant1" ],
503+ account = "test-account" ,
504+ work_dir = "/tmp/test-work" ,
505+ )
479506 with pytest .raises (SlurmJobError ):
480507 launcher .launch ()
481508
@@ -484,7 +511,11 @@ def test_launch_params_het_group_ids(self, mock_load_config, batch_model_configs
484511 """Test that heterogeneous group IDs are assigned correctly."""
485512 mock_load_config .return_value = batch_model_configs
486513
487- launcher = BatchModelLauncher (["family1-variant1" , "family2-variant1" ], account = "test-account" , work_dir = "/tmp/test-work" )
514+ launcher = BatchModelLauncher (
515+ ["family1-variant1" , "family2-variant1" ],
516+ account = "test-account" ,
517+ work_dir = "/tmp/test-work" ,
518+ )
488519 params = launcher .params
489520
490521 assert params ["models" ]["family1-variant1" ]["het_group_id" ] == 0
@@ -495,7 +526,11 @@ def test_launch_params_log_file_paths(self, mock_load_config, batch_model_config
495526 """Test that log file paths are constructed correctly."""
496527 mock_load_config .return_value = batch_model_configs
497528
498- launcher = BatchModelLauncher (["family1-variant1" , "family2-variant1" ], account = "test-account" , work_dir = "/tmp/test-work" )
529+ launcher = BatchModelLauncher (
530+ ["family1-variant1" , "family2-variant1" ],
531+ account = "test-account" ,
532+ work_dir = "/tmp/test-work" ,
533+ )
499534 params = launcher .params
500535
501536 # Check individual model log files
@@ -535,10 +570,10 @@ def test_init_with_batch_config(self, mock_load_config, batch_model_configs):
535570 mock_load_config .return_value = batch_model_configs
536571
537572 launcher = BatchModelLauncher (
538- ["family1-variant1" , "family2-variant1" ],
573+ ["family1-variant1" , "family2-variant1" ],
539574 batch_config = "custom_config.yaml" ,
540- account = "test-account" ,
541- work_dir = "/tmp/test-work"
575+ account = "test-account" ,
576+ work_dir = "/tmp/test-work" ,
542577 )
543578
544579 assert launcher .batch_config == "custom_config.yaml"
0 commit comments