1+ """Class for generating SLURM scripts to run vLLM servers."""
2+
13from datetime import datetime
24from pathlib import Path
35from typing import Any
68
79
810class SlurmScriptGenerator :
11+ """A class to generate SLURM scripts for running vLLM servers.
12+
13+ This class handles the generation of SLURM scripts for both single-node and
14+ multi-node configurations, supporting different virtualization environments
15+ (venv or singularity).
16+
17+ Args:
18+ params (dict[str, Any]): Configuration parameters for the SLURM script
19+ src_dir (str): Source directory path containing necessary scripts
20+ """
21+
    def __init__(self, params: dict[str, Any], src_dir: str):
        """Initialize the SlurmScriptGenerator with configuration parameters.

        Args:
            params (dict[str, Any]): Configuration parameters for the SLURM script
            src_dir (str): Source directory path containing necessary scripts
        """
        self.params = params
        self.src_dir = src_dir
        self.is_multinode = int(self.params["num_nodes"]) > 1
@@ -16,13 +35,25 @@ def __init__(self, params: dict[str, Any], src_dir: str):
        self.task = VLLM_TASK_MAP[self.params["model_type"]]

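    # Illustrative usage (a sketch, not part of this module). The keys shown are
    # the ones read in the code visible here and may not be exhaustive; all values
    # and paths are hypothetical.
    #
    # >>> params = {
    # ...     "model_name": "my-model",            # hypothetical
    # ...     "model_type": "LLM",                 # assumed key of VLLM_TASK_MAP
    # ...     "num_nodes": 2,
    # ...     "gpus_per_node": 4,
    # ...     "pipeline_parallelism": False,
    # ...     "venv": "singularity",               # or a path to a Python venv
    # ...     "log_dir": "/scratch/vllm-logs",     # hypothetical
    # ... }
    # >>> generator = SlurmScriptGenerator(params, src_dir="/opt/vec-inf/src")  # hypothetical path
    # >>> script_path = generator.write_to_log_dir()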
    def _generate_script_content(self) -> str:
        """Generate the complete SLURM script content.

        Returns
        -------
        str: The complete SLURM script as a string
        """
        preamble = self._generate_preamble()
        server = self._generate_server_script()
        launcher = self._generate_launcher()
        args = self._generate_shared_args()
        return preamble + server + launcher + args

    def _generate_preamble(self) -> str:
        """Generate the SLURM script preamble with job specifications.

        Returns
        -------
        str: SLURM preamble containing resource requests and job parameters
        """
        base = [
            "#!/bin/bash",
            "#SBATCH --cpus-per-task=16",
@@ -37,6 +68,15 @@ def _generate_preamble(self) -> str:
        return "\n".join(base)

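    # For orientation, the generated preamble begins with the lines visible above;
    # the remaining entries of `base` are elided from this view, so anything beyond
    # the first two lines is not shown here:
    #
    #   #!/bin/bash
    #   #SBATCH --cpus-per-task=16
    #   ...further #SBATCH directives built from params...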
    def _generate_shared_args(self) -> str:
        """Generate the command-line arguments for the vLLM server.

        Handles both single-node and multi-node configurations, setting appropriate
        parallel processing parameters based on the configuration.

        Returns
        -------
        str: Command-line arguments for the vLLM server
        """
        if self.is_multinode and not self.params["pipeline_parallelism"]:
            tensor_parallel_size = (
                self.params["num_nodes"] * self.params["gpus_per_node"]
@@ -77,88 +117,121 @@ def _generate_shared_args(self) -> str:
        return "\n".join(args)

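    # Worked example of the branch above: with num_nodes=2, gpus_per_node=4 and
    # pipeline_parallelism disabled, every GPU in the job is used for tensor
    # parallelism.
    #
    # >>> 2 * 4  # tensor_parallel_size
    # 8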
    def _generate_server_script(self) -> str:
        """Generate the server initialization script.

        Creates the script section that handles server setup, including Ray
        initialization for multi-node setups and port configuration.

        Returns
        -------
        str: Server initialization script content
        """
        server_script = [""]
        if self.params["venv"] == "singularity":
            server_script.append(
                "module load singularity-ce/3.8.2\n"
                "singularity exec $SINGULARITY_IMAGE ray stop\n"
            )
        server_script.append(f"source {self.src_dir}/find_port.sh\n")
        server_script.append(
            self._generate_multinode_server_script()
            if self.is_multinode
            else self._generate_single_node_server_script()
        )
        server_script.append(
            f'json_path="{self.params["log_dir"]}/{self.params["model_name"]}.$SLURM_JOB_ID/{self.params["model_name"]}.$SLURM_JOB_ID.json"\n'
            'jq --arg server_addr "$server_address" \\\n'
            "    '. + {\"server_address\": $server_addr}' \\\n"
            '    "$json_path" > temp.json \\\n'
            '    && mv temp.json "$json_path"\n\n'
        )
        return "\n".join(server_script)

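    # Sketch of what the jq step emitted above does, assuming the launch JSON file
    # already exists (contents here are hypothetical): it merges the resolved
    # server address into the file via a temporary file.
    #
    #   before: {"model_name": "my-model"}
    #   after:  {"model_name": "my-model", "server_address": "http://gpu001:8081/v1"}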
    def _generate_single_node_server_script(self) -> str:
        """Generate the server script for single-node deployment.

        Returns
        -------
        str: Script content for single-node server setup
        """
        return (
            "hostname=${SLURMD_NODENAME}\n"
            "vllm_port_number=$(find_available_port ${hostname} 8080 65535)\n\n"
            'server_address="http://${hostname}:${vllm_port_number}/v1"\n'
            'echo "Server address: $server_address"\n'
        )

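    # Example of the single-node result, assuming SLURMD_NODENAME resolves to
    # "gpu001" and find_available_port (sourced from find_port.sh) returns 8081:
    #
    #   server_address="http://gpu001:8081/v1"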
    def _generate_multinode_server_script(self) -> str:
        """Generate the server script for multi-node deployment.

        Creates a script that initializes a Ray cluster with head and worker nodes,
        configuring networking and GPU resources appropriately.

        Returns
        -------
        str: Script content for multi-node server setup
        """
        server_script = []
        server_script.append(
            'nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")\n'
            "nodes_array=($nodes)\n\n"
            "head_node=${nodes_array[0]}\n"
            'head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)\n\n'
            "head_node_port=$(find_available_port $head_node_ip 8080 65535)\n\n"
            "ip_head=$head_node_ip:$head_node_port\n"
            "export ip_head\n"
            'echo "IP Head: $ip_head"\n\n'
            'echo "Starting HEAD at $head_node"\n'
            'srun --nodes=1 --ntasks=1 -w "$head_node" \\'
        )

        if self.params["venv"] == "singularity":
            server_script.append(
                f"    singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} "
                "--containall $SINGULARITY_IMAGE \\"
            )

        server_script.append(
            '    ray start --head --node-ip-address="$head_node_ip" --port=$head_node_port \\\n'
            '    --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block &\n\n'
            "sleep 10\n"
            "worker_num=$((SLURM_JOB_NUM_NODES - 1))\n\n"
            "for ((i = 1; i <= worker_num; i++)); do\n"
            "    node_i=${nodes_array[$i]}\n"
            '    echo "Starting WORKER $i at $node_i"\n'
            '    srun --nodes=1 --ntasks=1 -w "$node_i" \\'
        )

        if self.params["venv"] == "singularity":
            server_script.append(
                f"    singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} "
                "--containall $SINGULARITY_IMAGE \\"
            )
        server_script.append(
            '    ray start --address "$ip_head" \\\n'
            '    --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block &\n'
            "    sleep 5\n"
            "done\n\n"
            "vllm_port_number=$(find_available_port $head_node_ip 8080 65535)\n\n"
            'server_address="http://${head_node_ip}:${vllm_port_number}/v1"\n'
            'echo "Server address: $server_address"\n\n'
        )
        return "\n".join(server_script)

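    # Topology produced by the script above: the first node in the allocation runs
    # `ray start --head` on $head_node_ip:$head_node_port, the remaining nodes join
    # with `ray start --address "$ip_head"`, and vLLM then sees a single Ray cluster
    # spanning num_nodes * gpus_per_node GPUs. The sleep 10/sleep 5 pauses give each
    # Ray process time to register before the next step runs (a heuristic).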
    def _generate_launcher(self) -> str:
        """Generate the vLLM server launch command.

        Creates the command to launch the vLLM server, handling the different
        execution environments (a Python venv or a Singularity container).

        Returns
        -------
        str: Server launch command
        """
        if self.params["venv"] == "singularity":
            launcher_script = [
                f"""singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} --containall $SINGULARITY_IMAGE \\"""
            ]
        else:
            launcher_script = [f"""source {self.params["venv"]}/bin/activate"""]
@@ -168,6 +241,14 @@ def _generate_launcher(self) -> str:
        return "\n".join(launcher_script)

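    # Illustrative launch prefixes produced above (paths are hypothetical); the
    # vLLM invocation appended after this prefix is elided from this view:
    #
    #   singularity: singularity exec --nv --bind /w/model:/w/model --containall $SINGULARITY_IMAGE \
    #   venv:        source /path/to/venv/bin/activate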
    def write_to_log_dir(self) -> Path:
        """Write the generated SLURM script to the log directory.

        Creates a timestamped script file in the configured log directory.

        Returns
        -------
        Path: Path to the generated SLURM script file
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        script_path: Path = (
            Path(self.params["log_dir"])