Skip to content

Commit 1d2a1ae

Browse files
committed
Refactored Slurm script generation.
1 parent 8d88cb1 commit 1d2a1ae

File tree

4 files changed

+72
-331
lines changed

4 files changed

+72
-331
lines changed

vec_inf/cli/_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def build_launch_command(self) -> str:
155155
)
156156

157157
slurm_script_path = SlurmScriptGenerator(
158-
self.params, src_dir=SRC_DIR, is_multinode=int(self.params["num_nodes"]) > 1
158+
self.params, src_dir=SRC_DIR
159159
).write_to_log_dir()
160160

161161
command_list.append(str(slurm_script_path))
Lines changed: 71 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from datetime import datetime
12
from pathlib import Path
3+
from typing import Any
24

35

46
VLLM_TASK_MAP = {
@@ -10,49 +12,51 @@
1012

1113

1214
class SlurmScriptGenerator:
13-
def __init__(self, params: dict, src_dir: str, is_multinode: bool = False):
15+
def __init__(self, params: dict[str, Any], src_dir: str):
1416
self.params = params
1517
self.src_dir = src_dir
16-
self.is_multinode = is_multinode
17-
self.model_weights_path = Path(
18-
params["model_weights_parent_dir"], params["model_name"]
18+
self.is_multinode = int(self.params["num_nodes"]) > 1
19+
self.model_weights_path = str(
20+
Path(params["model_weights_parent_dir"], params["model_name"])
1921
)
2022
self.task = VLLM_TASK_MAP[self.params["model_type"]]
2123

2224
def _generate_script_content(self) -> str:
23-
return (
24-
self._generate_multinode_script()
25-
if self.is_multinode
26-
else self._generate_single_node_script()
27-
)
25+
preamble = self._generate_preamble()
26+
server = self._generate_server_script()
27+
env_exports = self._export_parallel_vars()
28+
launcher = self._generate_launcher()
29+
args = self._generate_shared_args()
30+
return preamble + server + env_exports + launcher + args
2831

29-
def _generate_preamble(self, is_multinode: bool = False) -> str:
32+
def _generate_preamble(self) -> str:
3033
base = [
3134
"#!/bin/bash",
3235
"#SBATCH --cpus-per-task=16",
3336
"#SBATCH --mem=64G",
3437
]
35-
if is_multinode:
38+
if self.is_multinode:
3639
base += [
3740
"#SBATCH --exclusive",
3841
"#SBATCH --tasks-per-node=1",
3942
]
40-
base += [f"source {self.src_dir}/find_port.sh", ""]
43+
base += [""]
4144
return "\n".join(base)
4245

4346
def _export_parallel_vars(self) -> str:
4447
if self.is_multinode:
4548
return """if [ "$PIPELINE_PARALLELISM" = "True" ]; then
46-
export PIPELINE_PARALLEL_SIZE=$SLURM_JOB_NUM_NODES
47-
export TENSOR_PARALLEL_SIZE=$SLURM_GPUS_PER_NODE
49+
export PIPELINE_PARALLEL_SIZE=$SLURM_JOB_NUM_NODES
50+
export TENSOR_PARALLEL_SIZE=$SLURM_GPUS_PER_NODE
4851
else
49-
export PIPELINE_PARALLEL_SIZE=1
50-
export TENSOR_PARALLEL_SIZE=$((SLURM_JOB_NUM_NODES*SLURM_GPUS_PER_NODE))
52+
export PIPELINE_PARALLEL_SIZE=1
53+
export TENSOR_PARALLEL_SIZE=$((SLURM_JOB_NUM_NODES*SLURM_GPUS_PER_NODE))
5154
fi
55+
5256
"""
53-
return "export TENSOR_PARALLEL_SIZE=$SLURM_GPUS_PER_NODE\n"
57+
return "export TENSOR_PARALLEL_SIZE=$SLURM_GPUS_PER_NODE\n\n"
5458

55-
def _generate_shared_args(self) -> list[str]:
59+
def _generate_shared_args(self) -> str:
5660
args = [
5761
f"--model {self.model_weights_path} \\",
5862
f"--served-model-name {self.params['model_name']} \\",
@@ -81,56 +85,44 @@ def _generate_shared_args(self) -> list[str]:
8185
if self.params.get("enforce_eager") == "True":
8286
args.append("--enforce-eager")
8387

84-
return args
85-
86-
def _generate_single_node_script(self) -> str:
87-
preamble = self._generate_preamble(is_multinode=False)
88-
89-
server = f"""hostname=${{SLURMD_NODENAME}}
90-
vllm_port_number=$(find_available_port ${{hostname}} 8080 65535)
91-
92-
SERVER_ADDR="http://${{hostname}}:${{vllm_port_number}}/v1"
93-
echo "Server address: $SERVER_ADDR"
88+
return "\n".join(args)
9489

90+
def _generate_server_script(self) -> str:
91+
server_script = [""]
92+
if self.params["venv"] == "singularity":
93+
server_script.append("""export SINGULARITY_IMAGE=/model-weights/vec-inf-shared/vector-inference_latest.sif
94+
export VLLM_NCCL_SO_PATH=/vec-inf/nccl/libnccl.so.2.18.1
95+
module load singularity-ce/3.8.2
96+
singularity exec $SINGULARITY_IMAGE ray stop
97+
""")
98+
server_script.append(f"source {self.src_dir}/find_port.sh\n")
99+
server_script.append(
100+
self._generate_multinode_server_script()
101+
if self.is_multinode
102+
else self._generate_single_node_server_script()
103+
)
104+
server_script.append(f"""echo "Updating server address in $JSON_PATH"
95105
JSON_PATH="{self.params["log_dir"]}/{self.params["model_name"]}.$SLURM_JOB_ID/{self.params["model_name"]}.$SLURM_JOB_ID.json"
96-
echo "Updating server address in $JSON_PATH"
97106
jq --arg server_addr "$SERVER_ADDR" \\
98107
'. + {{"server_address": $server_addr}}' \\
99108
"$JSON_PATH" > temp.json \\
100109
&& mv temp.json "$JSON_PATH" \\
101110
&& rm -f temp.json
102-
"""
103111
104-
env_exports = self._export_parallel_vars()
105-
106-
if self.params["venv"] == "singularity":
107-
launcher = f"""export SINGULARITY_IMAGE=/model-weights/vec-inf-shared/vector-inference_latest.sif
108-
export VLLM_NCCL_SO_PATH=/vec-inf/nccl/libnccl.so.2.18.1
109-
module load singularity-ce/3.8.2
110-
singularity exec $SINGULARITY_IMAGE ray stop
111-
singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\
112-
python3.10 -m vllm.entrypoints.openai.api_server \\
113-
"""
114-
else:
115-
launcher = f"""source {self.params["venv"]}/bin/activate
116-
python3 -m vllm.entrypoints.openai.api_server \\
117-
"""
118-
119-
args = "\n".join(self._generate_shared_args())
120-
return preamble + server + env_exports + launcher + args
112+
""")
113+
return "\n".join(server_script)
121114

122-
def _generate_multinode_script(self) -> str:
123-
preamble = self._generate_preamble(is_multinode=True)
115+
def _generate_single_node_server_script(self) -> str:
116+
return """hostname=${SLURMD_NODENAME}
117+
vllm_port_number=$(find_available_port ${hostname} 8080 65535)
124118
125-
cluster_setup = []
126-
if self.params["venv"] == "singularity":
127-
cluster_setup.append("""export SINGULARITY_IMAGE=/model-weights/vec-inf-shared/vector-inference_latest.sif
128-
export VLLM_NCCL_SO_PATH=/vec-inf/nccl/libnccl.so.2.18.1
129-
module load singularity-ce/3.8.2
130-
singularity exec $SINGULARITY_IMAGE ray stop
131-
""")
119+
SERVER_ADDR="http://${hostname}:${vllm_port_number}/v1"
120+
echo "Server address: $SERVER_ADDR"
121+
"""
132122

133-
cluster_setup.append("""nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
123+
def _generate_multinode_server_script(self) -> str:
124+
server_script = []
125+
server_script.append("""nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
134126
nodes_array=($nodes)
135127
136128
head_node=${nodes_array[0]}
@@ -146,11 +138,11 @@ def _generate_multinode_script(self) -> str:
146138
srun --nodes=1 --ntasks=1 -w "$head_node" \\""")
147139

148140
if self.params["venv"] == "singularity":
149-
cluster_setup.append(
150-
f""" singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\"""
141+
server_script.append(
142+
f" singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\"
151143
)
152144

153-
cluster_setup.append(""" ray start --head --node-ip-address="$head_node_ip" --port=$head_node_port \\
145+
server_script.append(""" ray start --head --node-ip-address="$head_node_ip" --port=$head_node_port \\
154146
--num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block &
155147
156148
sleep 10
@@ -162,48 +154,41 @@ def _generate_multinode_script(self) -> str:
162154
srun --nodes=1 --ntasks=1 -w "$node_i" \\""")
163155

164156
if self.params["venv"] == "singularity":
165-
cluster_setup.append(
157+
server_script.append(
166158
f""" singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\"""
167159
)
168-
cluster_setup.append(f""" ray start --address "$ip_head" \\
169-
--num-cpus "${{SLURM_CPUS_PER_TASK}}" --num-gpus "${{SLURM_GPUS_PER_NODE}}" --block &
160+
server_script.append(""" ray start --address "$ip_head" \\
161+
--num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block &
170162
sleep 5
171163
done
172164
173-
174165
vllm_port_number=$(find_available_port $head_node_ip 8080 65535)
175166
176-
177-
SERVER_ADDR="http://${{head_node_ip}}:${{vllm_port_number}}/v1"
167+
SERVER_ADDR="http://${head_node_ip}:${vllm_port_number}/v1"
178168
echo "Server address: $SERVER_ADDR"
179169
180-
JSON_PATH="{self.params["log_dir"]}/{self.params["model_name"]}.$SLURM_JOB_ID/{self.params["model_name"]}.$SLURM_JOB_ID.json"
181-
echo "Updating server address in $JSON_PATH"
182-
jq --arg server_addr "$SERVER_ADDR" \\
183-
'. + {{"server_address": $server_addr}}' \\
184-
"$JSON_PATH" > temp.json \\
185-
&& mv temp.json "$JSON_PATH" \\
186-
&& rm -f temp.json
187170
""")
188-
cluster_setup = "\n".join(cluster_setup)
189-
env_exports = self._export_parallel_vars()
171+
return "\n".join(server_script)
190172

173+
def _generate_launcher(self) -> str:
191174
if self.params["venv"] == "singularity":
192-
launcher = f"""singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\
193-
python3.10 -m vllm.entrypoints.openai.api_server \\
194-
"""
175+
launcher_script = [
176+
f"""singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\"""
177+
]
195178
else:
196-
launcher = f"""source {self.params["venv"]}/bin/activate
197-
python3 -m vllm.entrypoints.openai.api_server \\
198-
"""
199-
200-
args = "\n".join(self._generate_shared_args())
201-
return preamble + cluster_setup + env_exports + launcher + args
179+
launcher_script = [f"""source {self.params["venv"]}/bin/activate"""]
180+
launcher_script.append(
181+
"""python3.10 -m vllm.entrypoints.openai.api_server \\\n"""
182+
)
183+
return "\n".join(launcher_script)
202184

203185
def write_to_log_dir(self) -> Path:
204-
log_subdir = Path(self.params["log_dir"]) / self.params["model_name"]
186+
log_subdir: Path = Path(self.params["log_dir"]) / self.params["model_name"]
205187
log_subdir.mkdir(parents=True, exist_ok=True)
206-
script_path = log_subdir / "launch.slurm"
188+
189+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
190+
script_path: Path = log_subdir / f"launch_{timestamp}.slurm"
191+
207192
content = self._generate_script_content()
208193
script_path.write_text(content)
209194
return script_path

0 commit comments

Comments
 (0)