Skip to content

Commit 645223b

Browse files
authored
Merge pull request #114 from VectorInstitute/test_slurm_generator
Unit tests for slurm script generator component in the vec_inf.client.
2 parents 7f382ba + 2f7bb50 commit 645223b

File tree

1 file changed

+224
-0
lines changed

1 file changed

+224
-0
lines changed
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
"""Unit tests for slurm script generator component in the vec_inf.client module."""
2+
3+
import tempfile
4+
from pathlib import Path
5+
from unittest.mock import mock_open, patch
6+
7+
import pytest
8+
9+
from vec_inf.client._slurm_script_generator import (
10+
SlurmScriptGenerator,
11+
)
12+
13+
14+
class TestSlurmScriptGenerator:
15+
"""Tests for SlurmScriptGenerator class."""
16+
17+
@pytest.fixture
18+
def basic_params(self):
19+
"""Generate basic SLURM configuration parameters."""
20+
return {
21+
"model_name": "test-model",
22+
"model_weights_parent_dir": "/path/to/model_weights",
23+
"src_dir": "/path/to/src",
24+
"log_dir": "/path/to/logs",
25+
"num_nodes": "1",
26+
"venv": "/path/to/venv",
27+
"gpus_per_node": "4",
28+
"partition": "gpu",
29+
"account": "test-account",
30+
"time": "01:00:00",
31+
"vllm_args": {
32+
"--tensor-parallel-size": "4",
33+
"--max-model-len": "8192",
34+
"--enforce-eager": True,
35+
},
36+
}
37+
38+
@pytest.fixture
39+
def multinode_params(self, basic_params):
40+
"""Generate multi-node SLURM configuration parameters."""
41+
multinode = basic_params.copy()
42+
multinode.update(
43+
{
44+
"num_nodes": "2",
45+
}
46+
)
47+
return multinode
48+
49+
@pytest.fixture
50+
def singularity_params(self, basic_params):
51+
"""Generate singularity-based SLURM configuration parameters."""
52+
singularity = basic_params.copy()
53+
singularity.update(
54+
{
55+
"venv": "singularity",
56+
"bind": "/scratch:/scratch,/data:/data",
57+
}
58+
)
59+
return singularity
60+
61+
@pytest.fixture
62+
def temp_log_dir(self):
63+
"""Generate temporary directory for log files."""
64+
with tempfile.TemporaryDirectory() as tmpdir:
65+
yield Path(tmpdir)
66+
67+
def test_init_single_node(self, basic_params):
68+
"""Test initialization with single-node configuration."""
69+
generator = SlurmScriptGenerator(basic_params)
70+
71+
assert generator.params == basic_params
72+
assert not generator.is_multinode
73+
assert not generator.use_singularity
74+
assert generator.additional_binds == ""
75+
assert generator.model_weights_path == "/path/to/model_weights/test-model"
76+
77+
def test_init_multinode(self, multinode_params):
78+
"""Test initialization with multi-node configuration."""
79+
generator = SlurmScriptGenerator(multinode_params)
80+
81+
assert generator.params == multinode_params
82+
assert generator.is_multinode
83+
assert not generator.use_singularity
84+
assert generator.additional_binds == ""
85+
assert generator.model_weights_path == "/path/to/model_weights/test-model"
86+
87+
def test_init_singularity(self, singularity_params):
88+
"""Test initialization with Singularity configuration."""
89+
generator = SlurmScriptGenerator(singularity_params)
90+
91+
assert generator.params == singularity_params
92+
assert generator.use_singularity
93+
assert not generator.is_multinode
94+
assert generator.additional_binds == " --bind /scratch:/scratch,/data:/data"
95+
assert generator.model_weights_path == "/path/to/model_weights/test-model"
96+
97+
def test_init_singularity_no_bind(self, basic_params):
98+
"""Test Singularity initialization without additional binds."""
99+
params = basic_params.copy()
100+
params["venv"] = "singularity"
101+
generator = SlurmScriptGenerator(params)
102+
103+
assert generator.params == params
104+
assert generator.use_singularity
105+
assert not generator.is_multinode
106+
assert generator.additional_binds == ""
107+
assert generator.model_weights_path == "/path/to/model_weights/test-model"
108+
109+
def test_generate_shebang_single_node(self, basic_params):
110+
"""Test shebang generation for single-node setup."""
111+
generator = SlurmScriptGenerator(basic_params)
112+
shebang = generator._generate_shebang()
113+
114+
assert shebang.startswith("#!/bin/bash")
115+
assert "#SBATCH --job-name=test-model" in shebang
116+
assert "#SBATCH --partition=gpu" in shebang
117+
assert "#SBATCH --nodes=1" in shebang
118+
assert "#SBATCH --exclusive" not in shebang
119+
120+
def test_generate_shebang_multinode(self, multinode_params):
121+
"""Test shebang generation for multi-node setup."""
122+
generator = SlurmScriptGenerator(multinode_params)
123+
shebang = generator._generate_shebang()
124+
125+
assert "#SBATCH --nodes=2" in shebang
126+
assert "#SBATCH --exclusive" in shebang
127+
assert "#SBATCH --tasks-per-node=1" in shebang
128+
129+
def test_generate_server_setup_single_node(self, basic_params):
130+
"""Test server setup generation for single-node."""
131+
generator = SlurmScriptGenerator(basic_params)
132+
setup = generator._generate_server_setup()
133+
134+
assert "head_node_ip=${SLURMD_NODENAME}" in setup
135+
assert "source /path/to/src/find_port.sh" in setup
136+
assert "export LD_LIBRARY_PATH=" in setup
137+
assert "ray start --head" not in setup
138+
139+
def test_generate_server_setup_multinode(self, multinode_params):
140+
"""Test server setup generation for multi-node."""
141+
generator = SlurmScriptGenerator(multinode_params)
142+
setup = generator._generate_server_setup()
143+
144+
assert "ray start --head" in setup
145+
assert "ray start --address" in setup
146+
assert "scontrol show hostnames" in setup
147+
assert "worker_num=$((SLURM_JOB_NUM_NODES - 1))" in setup
148+
149+
def test_generate_server_setup_singularity(self, singularity_params):
150+
"""Test server setup with Singularity container."""
151+
generator = SlurmScriptGenerator(singularity_params)
152+
setup = generator._generate_server_setup()
153+
154+
assert "singularity exec" in setup
155+
assert "ray stop" in setup
156+
assert "module load singularity" in setup
157+
158+
def test_generate_launch_cmd_venv(self, basic_params):
159+
"""Test launch command generation with virtual environment."""
160+
generator = SlurmScriptGenerator(basic_params)
161+
launch_cmd = generator._generate_launch_cmd()
162+
163+
assert "source /path/to/venv/bin/activate" in launch_cmd
164+
assert "vllm serve /path/to/model_weights/test-model" in launch_cmd
165+
assert "--served-model-name test-model" in launch_cmd
166+
assert "--tensor-parallel-size 4" in launch_cmd
167+
assert "--max-model-len 8192" in launch_cmd
168+
assert "--enforce-eager" in launch_cmd
169+
170+
def test_generate_launch_cmd_singularity(self, singularity_params):
171+
"""Test launch command generation with Singularity."""
172+
generator = SlurmScriptGenerator(singularity_params)
173+
launch_cmd = generator._generate_launch_cmd()
174+
175+
assert "singularity exec --nv" in launch_cmd
176+
assert "--bind /path/to/model_weights/test-model" in launch_cmd
177+
assert "--bind /scratch:/scratch,/data:/data" in launch_cmd
178+
assert "source" not in launch_cmd
179+
180+
def test_generate_launch_cmd_boolean_args(self, basic_params):
181+
"""Test launch command with boolean vLLM arguments."""
182+
params = basic_params.copy()
183+
params["vllm_args"] = {
184+
"--trust-remote-code": True,
185+
"--disable-log-stats": True,
186+
"--tensor-parallel-size": "2",
187+
}
188+
189+
generator = SlurmScriptGenerator(params)
190+
launch_cmd = generator._generate_launch_cmd()
191+
192+
assert "--trust-remote-code" in launch_cmd
193+
assert "--disable-log-stats" in launch_cmd
194+
assert "--tensor-parallel-size 2" in launch_cmd
195+
196+
@patch("builtins.open", new_callable=mock_open)
197+
@patch("vec_inf.client._slurm_script_generator.datetime")
198+
def test_write_to_log_dir(
199+
self, mock_datetime, mock_file, basic_params, temp_log_dir
200+
):
201+
"""Test writing SLURM script to log directory."""
202+
mock_datetime.now.return_value.strftime.return_value = "20240101_120000"
203+
204+
params = basic_params.copy()
205+
params["log_dir"] = str(temp_log_dir)
206+
207+
generator = SlurmScriptGenerator(params)
208+
209+
with patch.object(Path, "write_text") as mock_write:
210+
script_path = generator.write_to_log_dir()
211+
212+
expected_path = temp_log_dir / "launch_test-model_20240101_120000.slurm"
213+
assert script_path == expected_path
214+
mock_write.assert_called_once()
215+
216+
def test_generate_script_content_integration(self, basic_params):
217+
"""Test complete script generation integration."""
218+
generator = SlurmScriptGenerator(basic_params)
219+
content = generator._generate_script_content()
220+
221+
assert content.startswith("#!/bin/bash")
222+
assert "vllm serve" in content
223+
assert "find_available_port" in content
224+
assert "source /path/to/venv/bin/activate" in content

0 commit comments

Comments
 (0)