1+ from datetime import datetime
12from pathlib import Path
3+ from typing import Any
24
35
46VLLM_TASK_MAP = {
1012
1113
class SlurmScriptGenerator:
    """Renders SLURM batch scripts that launch a vLLM OpenAI-compatible server."""

    def __init__(self, params: dict[str, Any], src_dir: str):
        """Capture launch parameters and derive the values the generators need.

        Args:
            params: launch configuration; must provide "num_nodes",
                "model_weights_parent_dir", "model_name" and "model_type".
            src_dir: directory holding helper shell scripts (find_port.sh).
        """
        self.params = params
        self.src_dir = src_dir
        # Multi-node mode is derived from config rather than passed as a flag.
        self.is_multinode = int(self.params["num_nodes"]) > 1
        self.model_weights_path = str(
            Path(params["model_weights_parent_dir"], params["model_name"])
        )
        self.task = VLLM_TASK_MAP[self.params["model_type"]]
2224 def _generate_script_content (self ) -> str :
23- return (
24- self ._generate_multinode_script ()
25- if self .is_multinode
26- else self ._generate_single_node_script ()
27- )
25+ preamble = self ._generate_preamble ()
26+ server = self ._generate_server_script ()
27+ env_exports = self ._export_parallel_vars ()
28+ launcher = self ._generate_launcher ()
29+ args = self ._generate_shared_args ()
30+ return preamble + server + env_exports + launcher + args
2831
29- def _generate_preamble (self , is_multinode : bool = False ) -> str :
32+ def _generate_preamble (self ) -> str :
3033 base = [
3134 "#!/bin/bash" ,
3235 "#SBATCH --cpus-per-task=16" ,
3336 "#SBATCH --mem=64G" ,
3437 ]
35- if is_multinode :
38+ if self . is_multinode :
3639 base += [
3740 "#SBATCH --exclusive" ,
3841 "#SBATCH --tasks-per-node=1" ,
3942 ]
40- base += [f"source { self . src_dir } /find_port.sh" , "" ]
43+ base += ["" ]
4144 return "\n " .join (base )
4245
4346 def _export_parallel_vars (self ) -> str :
4447 if self .is_multinode :
4548 return """if [ "$PIPELINE_PARALLELISM" = "True" ]; then
46- export PIPELINE_PARALLEL_SIZE=$SLURM_JOB_NUM_NODES
47- export TENSOR_PARALLEL_SIZE=$SLURM_GPUS_PER_NODE
49+ export PIPELINE_PARALLEL_SIZE=$SLURM_JOB_NUM_NODES
50+ export TENSOR_PARALLEL_SIZE=$SLURM_GPUS_PER_NODE
4851else
49- export PIPELINE_PARALLEL_SIZE=1
50- export TENSOR_PARALLEL_SIZE=$((SLURM_JOB_NUM_NODES*SLURM_GPUS_PER_NODE))
52+ export PIPELINE_PARALLEL_SIZE=1
53+ export TENSOR_PARALLEL_SIZE=$((SLURM_JOB_NUM_NODES*SLURM_GPUS_PER_NODE))
5154fi
55+
5256"""
53- return "export TENSOR_PARALLEL_SIZE=$SLURM_GPUS_PER_NODE\n "
57+ return "export TENSOR_PARALLEL_SIZE=$SLURM_GPUS_PER_NODE\n \n "
5458
55- def _generate_shared_args (self ) -> list [ str ] :
59+ def _generate_shared_args (self ) -> str :
5660 args = [
5761 f"--model { self .model_weights_path } \\ " ,
5862 f"--served-model-name { self .params ['model_name' ]} \\ " ,
@@ -81,56 +85,44 @@ def _generate_shared_args(self) -> list[str]:
8185 if self .params .get ("enforce_eager" ) == "True" :
8286 args .append ("--enforce-eager" )
8387
84- return args
85-
86- def _generate_single_node_script (self ) -> str :
87- preamble = self ._generate_preamble (is_multinode = False )
88-
89- server = f"""hostname=${{SLURMD_NODENAME}}
90- vllm_port_number=$(find_available_port ${{hostname}} 8080 65535)
91-
92- SERVER_ADDR="http://${{hostname}}:${{vllm_port_number}}/v1"
93- echo "Server address: $SERVER_ADDR"
88+ return "\n " .join (args )
9489
90+ def _generate_server_script (self ) -> str :
91+ server_script = ["" ]
92+ if self .params ["venv" ] == "singularity" :
93+ server_script .append ("""export SINGULARITY_IMAGE=/model-weights/vec-inf-shared/vector-inference_latest.sif
94+ export VLLM_NCCL_SO_PATH=/vec-inf/nccl/libnccl.so.2.18.1
95+ module load singularity-ce/3.8.2
96+ singularity exec $SINGULARITY_IMAGE ray stop
97+ """ )
98+ server_script .append (f"source { self .src_dir } /find_port.sh\n " )
99+ server_script .append (
100+ self ._generate_multinode_server_script ()
101+ if self .is_multinode
102+ else self ._generate_single_node_server_script ()
103+ )
104+ server_script .append (f"""echo "Updating server address in $JSON_PATH"
95105JSON_PATH="{ self .params ["log_dir" ]} /{ self .params ["model_name" ]} .$SLURM_JOB_ID/{ self .params ["model_name" ]} .$SLURM_JOB_ID.json"
96- echo "Updating server address in $JSON_PATH"
97106jq --arg server_addr "$SERVER_ADDR" \\
98107 '. + {{"server_address": $server_addr}}' \\
99108 "$JSON_PATH" > temp.json \\
100109 && mv temp.json "$JSON_PATH" \\
101110 && rm -f temp.json
102- """
103111
104- env_exports = self ._export_parallel_vars ()
105-
106- if self .params ["venv" ] == "singularity" :
107- launcher = f"""export SINGULARITY_IMAGE=/model-weights/vec-inf-shared/vector-inference_latest.sif
108- export VLLM_NCCL_SO_PATH=/vec-inf/nccl/libnccl.so.2.18.1
109- module load singularity-ce/3.8.2
110- singularity exec $SINGULARITY_IMAGE ray stop
111- singularity exec --nv --bind { self .model_weights_path } :{ self .model_weights_path } $SINGULARITY_IMAGE \\
112- python3.10 -m vllm.entrypoints.openai.api_server \\
113- """
114- else :
115- launcher = f"""source { self .params ["venv" ]} /bin/activate
116- python3 -m vllm.entrypoints.openai.api_server \\
117- """
118-
119- args = "\n " .join (self ._generate_shared_args ())
120- return preamble + server + env_exports + launcher + args
112+ """ )
113+ return "\n " .join (server_script )
121114
122- def _generate_multinode_script (self ) -> str :
123- preamble = self ._generate_preamble (is_multinode = True )
115+ def _generate_single_node_server_script (self ) -> str :
116+ return """hostname=${SLURMD_NODENAME}
117+ vllm_port_number=$(find_available_port ${hostname} 8080 65535)
124118
125- cluster_setup = []
126- if self .params ["venv" ] == "singularity" :
127- cluster_setup .append ("""export SINGULARITY_IMAGE=/model-weights/vec-inf-shared/vector-inference_latest.sif
128- export VLLM_NCCL_SO_PATH=/vec-inf/nccl/libnccl.so.2.18.1
129- module load singularity-ce/3.8.2
130- singularity exec $SINGULARITY_IMAGE ray stop
131- """ )
119+ SERVER_ADDR="http://${hostname}:${vllm_port_number}/v1"
120+ echo "Server address: $SERVER_ADDR"
121+ """
132122
133- cluster_setup .append ("""nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
123+ def _generate_multinode_server_script (self ) -> str :
124+ server_script = []
125+ server_script .append ("""nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
134126nodes_array=($nodes)
135127
136128head_node=${nodes_array[0]}
@@ -146,11 +138,11 @@ def _generate_multinode_script(self) -> str:
146138srun --nodes=1 --ntasks=1 -w "$head_node" \\ """ )
147139
148140 if self .params ["venv" ] == "singularity" :
149- cluster_setup .append (
150- f""" singularity exec --nv --bind { self .model_weights_path } :{ self .model_weights_path } $SINGULARITY_IMAGE \\ "" "
141+ server_script .append (
142+ f" singularity exec --nv --bind { self .model_weights_path } :{ self .model_weights_path } $SINGULARITY_IMAGE \\ "
151143 )
152144
153- cluster_setup .append (""" ray start --head --node-ip-address="$head_node_ip" --port=$head_node_port \\
145+ server_script .append (""" ray start --head --node-ip-address="$head_node_ip" --port=$head_node_port \\
154146 --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block &
155147
156148sleep 10
@@ -162,48 +154,41 @@ def _generate_multinode_script(self) -> str:
162154 srun --nodes=1 --ntasks=1 -w "$node_i" \\ """ )
163155
164156 if self .params ["venv" ] == "singularity" :
165- cluster_setup .append (
157+ server_script .append (
166158 f""" singularity exec --nv --bind { self .model_weights_path } :{ self .model_weights_path } $SINGULARITY_IMAGE \\ """
167159 )
168- cluster_setup .append (f """ ray start --address "$ip_head" \\
169- --num-cpus "${{ SLURM_CPUS_PER_TASK}} " --num-gpus "${{ SLURM_GPUS_PER_NODE} }" --block &
160+ server_script .append (""" ray start --address "$ip_head" \\
161+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block &
170162 sleep 5
171163done
172164
173-
174165vllm_port_number=$(find_available_port $head_node_ip 8080 65535)
175166
176-
177- SERVER_ADDR="http://${{head_node_ip}}:${{vllm_port_number}}/v1"
167+ SERVER_ADDR="http://${head_node_ip}:${vllm_port_number}/v1"
178168echo "Server address: $SERVER_ADDR"
179169
180- JSON_PATH="{ self .params ["log_dir" ]} /{ self .params ["model_name" ]} .$SLURM_JOB_ID/{ self .params ["model_name" ]} .$SLURM_JOB_ID.json"
181- echo "Updating server address in $JSON_PATH"
182- jq --arg server_addr "$SERVER_ADDR" \\
183- '. + {{"server_address": $server_addr}}' \\
184- "$JSON_PATH" > temp.json \\
185- && mv temp.json "$JSON_PATH" \\
186- && rm -f temp.json
187170""" )
188- cluster_setup = "\n " .join (cluster_setup )
189- env_exports = self ._export_parallel_vars ()
171+ return "\n " .join (server_script )
190172
173+ def _generate_launcher (self ) -> str :
191174 if self .params ["venv" ] == "singularity" :
192- launcher = f"""singularity exec --nv --bind { self . model_weights_path } : { self . model_weights_path } $SINGULARITY_IMAGE \\
193- python3.10 -m vllm.entrypoints.openai.api_server \\
194- """
175+ launcher_script = [
176+ f"""singularity exec --nv --bind { self . model_weights_path } : { self . model_weights_path } $SINGULARITY_IMAGE \\ """
177+ ]
195178 else :
196- launcher = f"""source { self .params ["venv" ]} /bin/activate
197- python3 -m vllm.entrypoints.openai.api_server \\
198- """
199-
200- args = "\n " .join (self ._generate_shared_args ())
201- return preamble + cluster_setup + env_exports + launcher + args
179+ launcher_script = [f"""source { self .params ["venv" ]} /bin/activate""" ]
180+ launcher_script .append (
181+ """python3.10 -m vllm.entrypoints.openai.api_server \\ \n """
182+ )
183+ return "\n " .join (launcher_script )
202184
203185 def write_to_log_dir (self ) -> Path :
204- log_subdir = Path (self .params ["log_dir" ]) / self .params ["model_name" ]
186+ log_subdir : Path = Path (self .params ["log_dir" ]) / self .params ["model_name" ]
205187 log_subdir .mkdir (parents = True , exist_ok = True )
206- script_path = log_subdir / "launch.slurm"
188+
189+ timestamp = datetime .now ().strftime ("%Y%m%d_%H%M%S" )
190+ script_path : Path = log_subdir / f"launch_{ timestamp } .slurm"
191+
207192 content = self ._generate_script_content ()
208193 script_path .write_text (content )
209194 return script_path
0 commit comments