Skip to content

Commit e41b9b3

Browse files
authored
Merge branch 'main' into add_multiversion_docs_support
2 parents 777f91b + 14a7776 commit e41b9b3

File tree

1 file changed

+132
-51
lines changed

1 file changed

+132
-51
lines changed

vec_inf/client/_slurm_script_generator.py

Lines changed: 132 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Class for generating SLURM scripts to run vLLM servers."""
2+
13
from datetime import datetime
24
from pathlib import Path
35
from typing import Any
@@ -6,7 +8,24 @@
68

79

810
class SlurmScriptGenerator:
11+
"""A class to generate SLURM scripts for running vLLM servers.
12+
13+
This class handles the generation of SLURM scripts for both single-node and
14+
multi-node configurations, supporting different virtualization environments
15+
(venv or singularity).
16+
17+
Args:
18+
params (dict[str, Any]): Configuration parameters for the SLURM script
19+
src_dir (str): Source directory path containing necessary scripts
20+
"""
21+
922
def __init__(self, params: dict[str, Any], src_dir: str):
23+
"""Initialize the SlurmScriptGenerator with configuration parameters.
24+
25+
Args:
26+
params (dict[str, Any]): Configuration parameters for the SLURM script
27+
src_dir (str): Source directory path containing necessary scripts
28+
"""
1029
self.params = params
1130
self.src_dir = src_dir
1231
self.is_multinode = int(self.params["num_nodes"]) > 1
@@ -16,13 +35,25 @@ def __init__(self, params: dict[str, Any], src_dir: str):
1635
self.task = VLLM_TASK_MAP[self.params["model_type"]]
1736

1837
def _generate_script_content(self) -> str:
38+
"""Generate the complete SLURM script content.
39+
40+
Returns
41+
-------
42+
str: The complete SLURM script as a string
43+
"""
1944
preamble = self._generate_preamble()
2045
server = self._generate_server_script()
2146
launcher = self._generate_launcher()
2247
args = self._generate_shared_args()
2348
return preamble + server + launcher + args
2449

2550
def _generate_preamble(self) -> str:
51+
"""Generate the SLURM script preamble with job specifications.
52+
53+
Returns
54+
-------
55+
str: SLURM preamble containing resource requests and job parameters
56+
"""
2657
base = [
2758
"#!/bin/bash",
2859
"#SBATCH --cpus-per-task=16",
@@ -37,6 +68,15 @@ def _generate_preamble(self) -> str:
3768
return "\n".join(base)
3869

3970
def _generate_shared_args(self) -> str:
71+
"""Generate the command-line arguments for the vLLM server.
72+
73+
Handles both single-node and multi-node configurations, setting appropriate
74+
parallel processing parameters based on the configuration.
75+
76+
Returns
77+
-------
78+
str: Command-line arguments for the vLLM server
79+
"""
4080
if self.is_multinode and not self.params["pipeline_parallelism"]:
4181
tensor_parallel_size = (
4282
self.params["num_nodes"] * self.params["gpus_per_node"]
@@ -77,88 +117,121 @@ def _generate_shared_args(self) -> str:
77117
return "\n".join(args)
78118

79119
def _generate_server_script(self) -> str:
120+
"""Generate the server initialization script.
121+
122+
Creates the script section that handles server setup, including Ray
123+
initialization for multi-node setups and port configuration.
124+
125+
Returns
126+
-------
127+
str: Server initialization script content
128+
"""
80129
server_script = [""]
81130
if self.params["venv"] == "singularity":
82-
server_script.append("""module load singularity-ce/3.8.2
83-
singularity exec $SINGULARITY_IMAGE ray stop
84-
""")
131+
server_script.append(
132+
"module load singularity-ce/3.8.2\n"
133+
"singularity exec $SINGULARITY_IMAGE ray stop\n"
134+
)
85135
server_script.append(f"source {self.src_dir}/find_port.sh\n")
86136
server_script.append(
87137
self._generate_multinode_server_script()
88138
if self.is_multinode
89139
else self._generate_single_node_server_script()
90140
)
91-
server_script.append(f"""json_path="{self.params["log_dir"]}/{self.params["model_name"]}.$SLURM_JOB_ID/{self.params["model_name"]}.$SLURM_JOB_ID.json"
92-
jq --arg server_addr "$server_address" \\
93-
'. + {{"server_address": $server_addr}}' \\
94-
"$json_path" > temp.json \\
95-
&& mv temp.json "$json_path"
96-
97-
""")
141+
server_script.append(
142+
f'json_path="{self.params["log_dir"]}/{self.params["model_name"]}.$SLURM_JOB_ID/{self.params["model_name"]}.$SLURM_JOB_ID.json"\n'
143+
'jq --arg server_addr "$server_address" \\\n'
144+
" '. + {{\"server_address\": $server_addr}}' \\\n"
145+
' "$json_path" > temp.json \\\n'
146+
' && mv temp.json "$json_path"\n\n'
147+
)
98148
return "\n".join(server_script)
99149

100150
def _generate_single_node_server_script(self) -> str:
101-
return """hostname=${SLURMD_NODENAME}
102-
vllm_port_number=$(find_available_port ${hostname} 8080 65535)
151+
"""Generate the server script for single-node deployment.
103152
104-
server_address="http://${hostname}:${vllm_port_number}/v1"
105-
echo "Server address: $server_address"
106-
"""
153+
Returns
154+
-------
155+
str: Script content for single-node server setup
156+
"""
157+
return (
158+
"hostname=${SLURMD_NODENAME}\n"
159+
"vllm_port_number=$(find_available_port ${hostname} 8080 65535)\n\n"
160+
'server_address="http://${hostname}:${vllm_port_number}/v1"\n'
161+
'echo "Server address: $server_address"\n'
162+
)
107163

108164
def _generate_multinode_server_script(self) -> str:
109-
server_script = []
110-
server_script.append("""nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
111-
nodes_array=($nodes)
165+
"""Generate the server script for multi-node deployment.
112166
113-
head_node=${nodes_array[0]}
114-
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
167+
Creates a script that initializes Ray cluster with head and worker nodes,
168+
configuring networking and GPU resources appropriately.
115169
116-
head_node_port=$(find_available_port $head_node_ip 8080 65535)
117-
118-
ip_head=$head_node_ip:$head_node_port
119-
export ip_head
120-
echo "IP Head: $ip_head"
121-
122-
echo "Starting HEAD at $head_node"
123-
srun --nodes=1 --ntasks=1 -w "$head_node" \\""")
170+
Returns
171+
-------
172+
str: Script content for multi-node server setup
173+
"""
174+
server_script = []
175+
server_script.append(
176+
'nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")\n'
177+
"nodes_array=($nodes)\n\n"
178+
"head_node=${nodes_array[0]}\n"
179+
'head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)\n\n'
180+
"head_node_port=$(find_available_port $head_node_ip 8080 65535)\n\n"
181+
"ip_head=$head_node_ip:$head_node_port\n"
182+
"export ip_head\n"
183+
'echo "IP Head: $ip_head"\n\n'
184+
'echo "Starting HEAD at $head_node"\n'
185+
'srun --nodes=1 --ntasks=1 -w "$head_node" \\'
186+
)
124187

125188
if self.params["venv"] == "singularity":
126189
server_script.append(
127-
f" singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\"
190+
f" singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} "
191+
"--containall $SINGULARITY_IMAGE \\"
128192
)
129193

130-
server_script.append(""" ray start --head --node-ip-address="$head_node_ip" --port=$head_node_port \\
131-
--num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block &
132-
133-
sleep 10
134-
worker_num=$((SLURM_JOB_NUM_NODES - 1))
135-
136-
for ((i = 1; i <= worker_num; i++)); do
137-
node_i=${nodes_array[$i]}
138-
echo "Starting WORKER $i at $node_i"
139-
srun --nodes=1 --ntasks=1 -w "$node_i" \\""")
194+
server_script.append(
195+
' ray start --head --node-ip-address="$head_node_ip" --port=$head_node_port \\\n'
196+
' --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block &\n\n'
197+
"sleep 10\n"
198+
"worker_num=$((SLURM_JOB_NUM_NODES - 1))\n\n"
199+
"for ((i = 1; i <= worker_num; i++)); do\n"
200+
" node_i=${nodes_array[$i]}\n"
201+
' echo "Starting WORKER $i at $node_i"\n'
202+
' srun --nodes=1 --ntasks=1 -w "$node_i" \\'
203+
)
140204

141205
if self.params["venv"] == "singularity":
142206
server_script.append(
143-
f""" singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\"""
207+
f" singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} "
208+
"--containall $SINGULARITY_IMAGE \\"
144209
)
145-
server_script.append(""" ray start --address "$ip_head" \\
146-
--num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block &
147-
sleep 5
148-
done
149210

150-
vllm_port_number=$(find_available_port $head_node_ip 8080 65535)
151-
152-
server_address="http://${head_node_ip}:${vllm_port_number}/v1"
153-
echo "Server address: $server_address"
154-
155-
""")
211+
server_script.append(
212+
' ray start --address "$ip_head" \\\n'
213+
' --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block &\n'
214+
" sleep 5\n"
215+
"done\n\n"
216+
"vllm_port_number=$(find_available_port $head_node_ip 8080 65535)\n\n"
217+
'server_address="http://${head_node_ip}:${vllm_port_number}/v1"\n'
218+
'echo "Server address: $server_address"\n\n'
219+
)
156220
return "\n".join(server_script)
157221

158222
def _generate_launcher(self) -> str:
223+
"""Generate the vLLM server launch command.
224+
225+
Creates the command to launch the vLLM server, handling different virtualization
226+
environments (venv or singularity).
227+
228+
Returns
229+
-------
230+
str: Server launch command
231+
"""
159232
if self.params["venv"] == "singularity":
160233
launcher_script = [
161-
f"""singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\"""
234+
f"""singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} --containall $SINGULARITY_IMAGE \\"""
162235
]
163236
else:
164237
launcher_script = [f"""source {self.params["venv"]}/bin/activate"""]
@@ -168,6 +241,14 @@ def _generate_launcher(self) -> str:
168241
return "\n".join(launcher_script)
169242

170243
def write_to_log_dir(self) -> Path:
244+
"""Write the generated SLURM script to the log directory.
245+
246+
Creates a timestamped script file in the configured log directory.
247+
248+
Returns
249+
-------
250+
Path: Path to the generated SLURM script file
251+
"""
171252
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
172253
script_path: Path = (
173254
Path(self.params["log_dir"])

0 commit comments

Comments
 (0)