Skip to content

Commit 9ffac32

Browse files
authored
Merge pull request #84 from VectorInstitute/dynamic_slurm
Generate slurm scripts on the fly, retire generic single node and multi-node slurm scripts
2 parents acbc33b + a6a090c commit 9ffac32

File tree

7 files changed

+244
-320
lines changed

7 files changed

+244
-320
lines changed

tests/vec_inf/cli/test_cli.py

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -226,13 +226,12 @@ def base_patches(test_paths, mock_truediv, debug_helper):
226226
"pathlib.Path.parent", return_value=debug_helper.config_file.parent.parent
227227
),
228228
patch("pathlib.Path.__truediv__", side_effect=mock_truediv),
229-
patch("pathlib.Path.iterdir", return_value=[]), # Mock empty directory listing
229+
patch("pathlib.Path.iterdir", return_value=[]),
230230
patch("json.dump"),
231231
patch("pathlib.Path.touch"),
232232
patch("vec_inf.client._utils.Path", return_value=test_paths["weights_dir"]),
233-
patch(
234-
"pathlib.Path.home", return_value=Path("/home/user")
235-
), # Mock home directory
233+
patch("pathlib.Path.home", return_value=Path("/home/user")),
234+
patch("pathlib.Path.rename"),
236235
]
237236

238237

@@ -246,25 +245,25 @@ def apply_base_patches(base_patches):
246245
yield
247246

248247

249-
def test_launch_command_success(runner, mock_launch_output, path_exists, debug_helper):
248+
def test_launch_command_success(
249+
runner,
250+
mock_launch_output,
251+
path_exists,
252+
debug_helper,
253+
mock_truediv,
254+
test_paths,
255+
base_patches,
256+
):
250257
"""Test successful model launch with minimal required arguments."""
251-
test_log_dir = Path("/tmp/test_vec_inf_logs")
258+
with ExitStack() as stack:
259+
# Apply all base patches
260+
for patch_obj in base_patches:
261+
stack.enter_context(patch_obj)
262+
263+
# Apply specific patches for this test
264+
mock_run = stack.enter_context(patch("vec_inf.client._utils.run_bash_command"))
265+
stack.enter_context(patch("pathlib.Path.exists", new=path_exists))
252266

253-
with (
254-
patch("vec_inf.client._utils.run_bash_command") as mock_run,
255-
patch("pathlib.Path.mkdir"),
256-
patch("builtins.open", debug_helper.tracked_mock_open),
257-
patch("pathlib.Path.open", debug_helper.tracked_mock_open),
258-
patch("pathlib.Path.exists", new=path_exists),
259-
patch("pathlib.Path.expanduser", return_value=test_log_dir),
260-
patch("pathlib.Path.resolve", return_value=debug_helper.config_file.parent),
261-
patch(
262-
"pathlib.Path.parent", return_value=debug_helper.config_file.parent.parent
263-
),
264-
patch("json.dump"),
265-
patch("pathlib.Path.touch"),
266-
patch("pathlib.Path.__truediv__", return_value=test_log_dir),
267-
):
268267
expected_job_id = "14933053"
269268
mock_run.return_value = mock_launch_output(expected_job_id)
270269

@@ -277,25 +276,24 @@ def test_launch_command_success(runner, mock_launch_output, path_exists, debug_h
277276

278277

279278
def test_launch_command_with_json_output(
280-
runner, mock_launch_output, path_exists, debug_helper
279+
runner,
280+
mock_launch_output,
281+
path_exists,
282+
debug_helper,
283+
mock_truediv,
284+
test_paths,
285+
base_patches,
281286
):
282287
"""Test JSON output format for launch command."""
283-
test_log_dir = Path("/tmp/test_vec_inf_logs")
284-
with (
285-
patch("vec_inf.client._utils.run_bash_command") as mock_run,
286-
patch("pathlib.Path.mkdir"),
287-
patch("builtins.open", debug_helper.tracked_mock_open),
288-
patch("pathlib.Path.open", debug_helper.tracked_mock_open),
289-
patch("pathlib.Path.exists", new=path_exists),
290-
patch("pathlib.Path.expanduser", return_value=test_log_dir),
291-
patch("pathlib.Path.resolve", return_value=debug_helper.config_file.parent),
292-
patch(
293-
"pathlib.Path.parent", return_value=debug_helper.config_file.parent.parent
294-
),
295-
patch("json.dump"),
296-
patch("pathlib.Path.touch"),
297-
patch("pathlib.Path.__truediv__", return_value=test_log_dir),
298-
):
288+
with ExitStack() as stack:
289+
# Apply all base patches
290+
for patch_obj in base_patches:
291+
stack.enter_context(patch_obj)
292+
293+
# Apply specific patches for this test
294+
mock_run = stack.enter_context(patch("vec_inf.client._utils.run_bash_command"))
295+
stack.enter_context(patch("pathlib.Path.exists", new=path_exists))
296+
299297
expected_job_id = "14933051"
300298
mock_run.return_value = mock_launch_output(expected_job_id)
301299

@@ -319,7 +317,7 @@ def test_launch_command_with_json_output(
319317
assert output.get("slurm_job_id") == expected_job_id
320318
assert output.get("model_name") == "Meta-Llama-3.1-8B"
321319
assert output.get("model_type") == "LLM"
322-
assert str(test_log_dir) in output.get("log_dir", "")
320+
assert str(test_paths["log_dir"]) in output.get("log_dir", "")
323321

324322

325323
def test_launch_command_no_model_weights_parent_dir(runner, debug_helper, base_patches):

vec_inf/cli/_helper.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Helper classes for the CLI."""
22

3-
import os
3+
from pathlib import Path
44
from typing import Any, Union
55

66
import click
@@ -59,9 +59,10 @@ def format_table_output(self) -> Table:
5959
)
6060
if self.params.get("enforce_eager"):
6161
table.add_row("Enforce Eager", self.params["enforce_eager"])
62-
63-
# Add path details
64-
table.add_row("Model Weights Directory", os.environ.get("MODEL_WEIGHTS"))
62+
table.add_row(
63+
"Model Weights Directory",
64+
str(Path(self.params["model_weights_parent_dir"], self.model_name)),
65+
)
6566
table.add_row("Log Directory", self.params["log_dir"])
6667

6768
return table

vec_inf/client/_helper.py

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@
2525
ModelType,
2626
StatusResponse,
2727
)
28+
from vec_inf.client._slurm_script_generator import SlurmScriptGenerator
2829
from vec_inf.client._vars import (
2930
BOOLEAN_FIELDS,
3031
LD_LIBRARY_PATH,
3132
REQUIRED_FIELDS,
33+
SINGULARITY_IMAGE,
3234
SRC_DIR,
33-
VLLM_TASK_MAP,
35+
VLLM_NCCL_SO_PATH,
3436
)
3537

3638

@@ -50,6 +52,7 @@ def __init__(self, model_name: str, kwargs: Optional[dict[str, Any]]):
5052
self.model_name = model_name
5153
self.kwargs = kwargs or {}
5254
self.slurm_job_id = ""
55+
self.slurm_script_path = Path("")
5356
self.model_config = self._get_model_configuration()
5457
self.params = self._get_launch_params()
5558

@@ -137,31 +140,9 @@ def _get_launch_params(self) -> dict[str, Any]:
137140

138141
def _set_env_vars(self) -> None:
139142
"""Set environment variables for the launch command."""
140-
os.environ["MODEL_NAME"] = self.model_name
141-
os.environ["MAX_MODEL_LEN"] = self.params["max_model_len"]
142-
os.environ["MAX_LOGPROBS"] = self.params["vocab_size"]
143-
os.environ["DATA_TYPE"] = self.params["data_type"]
144-
os.environ["MAX_NUM_SEQS"] = self.params["max_num_seqs"]
145-
os.environ["GPU_MEMORY_UTILIZATION"] = self.params["gpu_memory_utilization"]
146-
os.environ["TASK"] = VLLM_TASK_MAP[self.params["model_type"]]
147-
os.environ["PIPELINE_PARALLELISM"] = self.params["pipeline_parallelism"]
148-
os.environ["COMPILATION_CONFIG"] = self.params["compilation_config"]
149-
os.environ["SRC_DIR"] = SRC_DIR
150-
os.environ["MODEL_WEIGHTS"] = str(
151-
Path(self.params["model_weights_parent_dir"], self.model_name)
152-
)
153143
os.environ["LD_LIBRARY_PATH"] = LD_LIBRARY_PATH
154-
os.environ["VENV_BASE"] = self.params["venv"]
155-
os.environ["LOG_DIR"] = self.params["log_dir"]
156-
157-
if self.params.get("enable_prefix_caching"):
158-
os.environ["ENABLE_PREFIX_CACHING"] = self.params["enable_prefix_caching"]
159-
if self.params.get("enable_chunked_prefill"):
160-
os.environ["ENABLE_CHUNKED_PREFILL"] = self.params["enable_chunked_prefill"]
161-
if self.params.get("max_num_batched_tokens"):
162-
os.environ["MAX_NUM_BATCHED_TOKENS"] = self.params["max_num_batched_tokens"]
163-
if self.params.get("enforce_eager"):
164-
os.environ["ENFORCE_EAGER"] = self.params["enforce_eager"]
144+
os.environ["VLLM_NCCL_SO_PATH"] = VLLM_NCCL_SO_PATH
145+
os.environ["SINGULARITY_IMAGE"] = SINGULARITY_IMAGE
165146

166147
def _build_launch_command(self) -> str:
167148
"""Construct the full launch command with parameters."""
@@ -187,10 +168,10 @@ def _build_launch_command(self) -> str:
187168
]
188169
)
189170
# Add slurm script
190-
slurm_script = "vllm.slurm"
191-
if int(self.params["num_nodes"]) > 1:
192-
slurm_script = "multinode_vllm.slurm"
193-
command_list.append(f"{SRC_DIR}/{slurm_script}")
171+
self.slurm_script_path = SlurmScriptGenerator(
172+
self.params, SRC_DIR
173+
).write_to_log_dir()
174+
command_list.append(str(self.slurm_script_path))
194175
return " ".join(command_list)
195176

196177
def launch(self) -> LaunchResponse:
@@ -207,15 +188,22 @@ def launch(self) -> LaunchResponse:
207188
self.slurm_job_id = command_output.split(" ")[-1].strip().strip("\n")
208189
self.params["slurm_job_id"] = self.slurm_job_id
209190

210-
# Create log directory and job json file
191+
# Create log directory and job json file, move slurm script to job log directory
192+
job_log_dir = Path(
193+
self.params["log_dir"], f"{self.model_name}.{self.slurm_job_id}"
194+
)
195+
job_log_dir.mkdir(parents=True, exist_ok=True)
196+
211197
job_json = Path(
212-
self.params["log_dir"],
213-
f"{self.model_name}.{self.slurm_job_id}",
198+
job_log_dir,
214199
f"{self.model_name}.{self.slurm_job_id}.json",
215200
)
216-
job_json.parent.mkdir(parents=True, exist_ok=True)
217201
job_json.touch(exist_ok=True)
218202

203+
self.slurm_script_path.rename(
204+
job_log_dir / f"{self.model_name}.{self.slurm_job_id}.slurm"
205+
)
206+
219207
with job_json.open("w") as file:
220208
json.dump(self.params, file, indent=4)
221209

0 commit comments

Comments
 (0)