Skip to content

Commit 960f9a5

Browse files
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 019f153 commit 960f9a5

File tree

1 file changed

+53
-18
lines changed

1 file changed

+53
-18
lines changed

tests/vec_inf/client/test_helper.py

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,10 @@ def test_init_with_missing_config_and_existing_weights(
7979
mock_path_exists.return_value = True
8080

8181
with pytest.warns(UserWarning):
82-
launcher = ModelLauncher("unknown-model", {"account": "test-account", "work_dir": "/tmp/test-work"})
82+
launcher = ModelLauncher(
83+
"unknown-model",
84+
{"account": "test-account", "work_dir": "/tmp/test-work"},
85+
)
8386

8487
assert launcher.model_name == "unknown-model"
8588
assert launcher.model_config.model_name == "unknown-model"
@@ -276,7 +279,11 @@ def batch_model_configs(self) -> list[ModelConfig]:
276279
def test_init_with_valid_configs(self, mock_load_config, batch_model_configs):
277280
"""Test launcher initializes correctly with valid model configurations."""
278281
mock_load_config.return_value = batch_model_configs
279-
launcher = BatchModelLauncher(["family1-variant1", "family2-variant1"], account="test-account", work_dir="/tmp/test-work")
282+
launcher = BatchModelLauncher(
283+
["family1-variant1", "family2-variant1"],
284+
account="test-account",
285+
work_dir="/tmp/test-work",
286+
)
280287

281288
assert launcher.model_names == ["family1-variant1", "family2-variant1"]
282289
assert launcher.slurm_job_name == "BATCH-family1-variant1-family2-variant1"
@@ -302,9 +309,9 @@ def test_get_slurm_job_name(self, mock_load_config, batch_model_configs):
302309
"""Test SLURM job name is constructed correctly from model names."""
303310
mock_load_config.return_value = batch_model_configs
304311
launcher = BatchModelLauncher(
305-
["family1-variant1", "family2-variant1", "family1-variant2"],
306-
account="test-account",
307-
work_dir="/tmp/test-work"
312+
["family1-variant1", "family2-variant1", "family1-variant2"],
313+
account="test-account",
314+
work_dir="/tmp/test-work",
308315
)
309316

310317
assert (
@@ -321,9 +328,9 @@ def test_get_launch_params_creates_log_dirs(
321328
mock_load_config.return_value = batch_model_configs
322329

323330
launcher = BatchModelLauncher(
324-
["family1-variant1", "family2-variant1", "family1-variant2"],
325-
account="test-account",
326-
work_dir="/tmp/test-work"
331+
["family1-variant1", "family2-variant1", "family1-variant2"],
332+
account="test-account",
333+
work_dir="/tmp/test-work",
327334
)
328335
params = launcher.params
329336

@@ -352,7 +359,11 @@ def test_get_launch_params_with_multi_gpu_no_tp(
352359
mock_load_config.return_value = updated_configs
353360

354361
with pytest.raises(MissingRequiredFieldsError) as excinfo:
355-
BatchModelLauncher(["family1-variant1", "family2-variant1"], account="test-account", work_dir="/tmp/test-work")
362+
BatchModelLauncher(
363+
["family1-variant1", "family2-variant1"],
364+
account="test-account",
365+
work_dir="/tmp/test-work",
366+
)
356367

357368
assert "--tensor-parallel-size" in str(excinfo.value)
358369
assert "family1-variant1" in str(excinfo.value)
@@ -375,7 +386,11 @@ def test_get_launch_params_with_non_power_of_two_gpus(
375386
mock_load_config.return_value = updated_configs
376387

377388
with pytest.raises(ValueError) as excinfo:
378-
BatchModelLauncher(["family1-variant1", "family2-variant1"], account="test-account", work_dir="/tmp/test-work")
389+
BatchModelLauncher(
390+
["family1-variant1", "family2-variant1"],
391+
account="test-account",
392+
work_dir="/tmp/test-work",
393+
)
379394

380395
assert "power of two" in str(excinfo.value)
381396
assert "family1-variant1" in str(excinfo.value)
@@ -401,7 +416,11 @@ def test_get_launch_params_with_mismatched_batch_args(
401416
mock_load_config.return_value = updated_configs
402417

403418
with pytest.raises(ValueError) as excinfo:
404-
BatchModelLauncher(["family1-variant1", "family2-variant1"], account="test-account", work_dir="/tmp/test-work")
419+
BatchModelLauncher(
420+
["family1-variant1", "family2-variant1"],
421+
account="test-account",
422+
work_dir="/tmp/test-work",
423+
)
405424

406425
assert "Mismatch between total number of GPUs requested" in str(excinfo.value)
407426

@@ -451,7 +470,11 @@ def test_launch_success(
451470
# Mock copy2 to do nothing (avoid file operations)
452471
mock_copy2.return_value = None
453472

454-
launcher = BatchModelLauncher(["family1-variant1", "family2-variant1"], account="test-account", work_dir="/tmp/test-work")
473+
launcher = BatchModelLauncher(
474+
["family1-variant1", "family2-variant1"],
475+
account="test-account",
476+
work_dir="/tmp/test-work",
477+
)
455478
response = launcher.launch()
456479

457480
assert response.slurm_job_id == "12345"
@@ -475,7 +498,11 @@ def test_launch_with_slurm_error(
475498
mock_load_config.return_value = batch_model_configs
476499
mock_run_bash.return_value = ("", "sbatch: error: Invalid partition specified")
477500

478-
launcher = BatchModelLauncher(["family1-variant1", "family2-variant1"], account="test-account", work_dir="/tmp/test-work")
501+
launcher = BatchModelLauncher(
502+
["family1-variant1", "family2-variant1"],
503+
account="test-account",
504+
work_dir="/tmp/test-work",
505+
)
479506
with pytest.raises(SlurmJobError):
480507
launcher.launch()
481508

@@ -484,7 +511,11 @@ def test_launch_params_het_group_ids(self, mock_load_config, batch_model_configs
484511
"""Test that heterogeneous group IDs are assigned correctly."""
485512
mock_load_config.return_value = batch_model_configs
486513

487-
launcher = BatchModelLauncher(["family1-variant1", "family2-variant1"], account="test-account", work_dir="/tmp/test-work")
514+
launcher = BatchModelLauncher(
515+
["family1-variant1", "family2-variant1"],
516+
account="test-account",
517+
work_dir="/tmp/test-work",
518+
)
488519
params = launcher.params
489520

490521
assert params["models"]["family1-variant1"]["het_group_id"] == 0
@@ -495,7 +526,11 @@ def test_launch_params_log_file_paths(self, mock_load_config, batch_model_config
495526
"""Test that log file paths are constructed correctly."""
496527
mock_load_config.return_value = batch_model_configs
497528

498-
launcher = BatchModelLauncher(["family1-variant1", "family2-variant1"], account="test-account", work_dir="/tmp/test-work")
529+
launcher = BatchModelLauncher(
530+
["family1-variant1", "family2-variant1"],
531+
account="test-account",
532+
work_dir="/tmp/test-work",
533+
)
499534
params = launcher.params
500535

501536
# Check individual model log files
@@ -535,10 +570,10 @@ def test_init_with_batch_config(self, mock_load_config, batch_model_configs):
535570
mock_load_config.return_value = batch_model_configs
536571

537572
launcher = BatchModelLauncher(
538-
["family1-variant1", "family2-variant1"],
573+
["family1-variant1", "family2-variant1"],
539574
batch_config="custom_config.yaml",
540-
account="test-account",
541-
work_dir="/tmp/test-work"
575+
account="test-account",
576+
work_dir="/tmp/test-work",
542577
)
543578

544579
assert launcher.batch_config == "custom_config.yaml"

0 commit comments

Comments
 (0)