Commit 85b7565

[pre-commit.ci] Add auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent fdd02d5 commit 85b7565

2 files changed: +40, -28 lines

vec_inf/cli/_helper.py

Lines changed: 4 additions & 3 deletions
@@ -165,8 +165,6 @@ def build_launch_command(self) -> str:
 
         command_list.append(str(slurm_script_path))
         return " ".join(command_list)
-
-
 
     def format_table_output(self, job_id: str) -> Table:
         """Format output as rich Table."""
@@ -199,7 +197,10 @@ def format_table_output(self, job_id: str) -> Table:
             )
         if self.params.get("enforce_eager"):
             table.add_row("Enforce Eager", self.params["enforce_eager"])
-        table.add_row("Model Weights Directory", str(Path(self.params["model_weights_parent_dir"], self.model_name)))
+        table.add_row(
+            "Model Weights Directory",
+            str(Path(self.params["model_weights_parent_dir"], self.model_name)),
+        )
         table.add_row("Log Directory", self.params["log_dir"])
 
         return table
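
Both hunks above are formatting-only: the hook wraps long calls across multiple lines without changing the arguments that are passed. As a minimal sketch of why the wrapped table.add_row(...) call behaves identically, here is a self-contained example with a hypothetical rich Table; the title, columns, and parameter values below are made up for illustration and are not taken from the repository:

from pathlib import Path

from rich.console import Console
from rich.table import Table

# Hypothetical stand-ins for self.params and self.model_name.
params = {"model_weights_parent_dir": "/model-weights", "log_dir": "/tmp/logs"}
model_name = "example-model"

table = Table(title="Launch details")
table.add_column("Field")
table.add_column("Value")

# Same arguments as the original one-line call; only the source layout differs.
table.add_row(
    "Model Weights Directory",
    str(Path(params["model_weights_parent_dir"], model_name)),
)
table.add_row("Log Directory", params["log_dir"])

Console().print(table)

Printing the table yields the same rows the original single-line add_row call would have produced.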

vec_inf/cli/_slurm_script_generator.py

Lines changed: 36 additions & 25 deletions
@@ -1,22 +1,30 @@
 from pathlib import Path
 
+
 VLLM_TASK_MAP = {
     "LLM": "generate",
     "VLM": "generate",
     "Text_Embedding": "embed",
     "Reward_Modeling": "reward",
 }
 
+
 class SlurmScriptGenerator:
     def __init__(self, params: dict, src_dir: str, is_multinode: bool = False):
         self.params = params
         self.src_dir = src_dir
         self.is_multinode = is_multinode
-        self.model_weights_path = Path(params["model_weights_parent_dir"], params["model_name"])
+        self.model_weights_path = Path(
+            params["model_weights_parent_dir"], params["model_name"]
+        )
         self.task = VLLM_TASK_MAP[self.params["model_type"]]
 
     def _generate_script_content(self) -> str:
-        return self._generate_multinode_script() if self.is_multinode else self._generate_single_node_script()
+        return (
+            self._generate_multinode_script()
+            if self.is_multinode
+            else self._generate_single_node_script()
+        )
 
     def _generate_preamble(self, is_multinode: bool = False) -> str:
         base = [
@@ -42,14 +50,13 @@ def _export_parallel_vars(self) -> str:
 export TENSOR_PARALLEL_SIZE=$((SLURM_JOB_NUM_NODES*SLURM_GPUS_PER_NODE))
 fi
 """
-        else:
-            return "export TENSOR_PARALLEL_SIZE=$SLURM_GPUS_PER_NODE\n"
+        return "export TENSOR_PARALLEL_SIZE=$SLURM_GPUS_PER_NODE\n"
 
     def _generate_shared_args(self) -> list[str]:
         args = [
             f"--model {self.model_weights_path} \\",
             f"--served-model-name {self.params['model_name']} \\",
-            "--host \"0.0.0.0\" \\",
+            '--host "0.0.0.0" \\',
             "--port $vllm_port_number \\",
             "--tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \\",
             f"--dtype {self.params['data_type']} \\",
@@ -64,7 +71,9 @@ def _generate_shared_args(self) -> list[str]:
         if self.is_multinode:
             args.insert(4, "--pipeline-parallel-size ${PIPELINE_PARALLEL_SIZE} \\")
         if self.params.get("max_num_batched_tokens"):
-            args.append(f"--max-num-batched-tokens={self.params['max_num_batched_tokens']} \\")
+            args.append(
+                f"--max-num-batched-tokens={self.params['max_num_batched_tokens']} \\"
+            )
         if self.params.get("enable_prefix_caching") == "True":
             args.append("--enable-prefix-caching \\")
         if self.params.get("enable_chunked_prefill") == "True":
@@ -83,7 +92,7 @@ def _generate_single_node_script(self) -> str:
 SERVER_ADDR="http://${{hostname}}:${{vllm_port_number}}/v1"
 echo "Server address: $SERVER_ADDR"
 
-JSON_PATH="{self.params['log_dir']}/{self.params['model_name']}.$SLURM_JOB_ID/{self.params['model_name']}.$SLURM_JOB_ID.json"
+JSON_PATH="{self.params["log_dir"]}/{self.params["model_name"]}.$SLURM_JOB_ID/{self.params["model_name"]}.$SLURM_JOB_ID.json"
 echo "Updating server address in $JSON_PATH"
 jq --arg server_addr "$SERVER_ADDR" \\
 '. + {{"server_address": $server_addr}}' \\
@@ -103,86 +112,88 @@ def _generate_single_node_script(self) -> str:
 python3.10 -m vllm.entrypoints.openai.api_server \\
 """
         else:
-            launcher = f"""source {self.params['venv']}/bin/activate
+            launcher = f"""source {self.params["venv"]}/bin/activate
 python3 -m vllm.entrypoints.openai.api_server \\
 """
 
         args = "\n".join(self._generate_shared_args())
         return preamble + server + env_exports + launcher + args
-
-
+
     def _generate_multinode_script(self) -> str:
         preamble = self._generate_preamble(is_multinode=True)
 
         cluster_setup = []
         if self.params["venv"] == "singularity":
-            cluster_setup.append(f"""export SINGULARITY_IMAGE=/model-weights/vec-inf-shared/vector-inference_latest.sif
+            cluster_setup.append("""export SINGULARITY_IMAGE=/model-weights/vec-inf-shared/vector-inference_latest.sif
 export VLLM_NCCL_SO_PATH=/vec-inf/nccl/libnccl.so.2.18.1
 module load singularity-ce/3.8.2
 singularity exec $SINGULARITY_IMAGE ray stop
 """)
 
-        cluster_setup.append(f"""nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
+        cluster_setup.append("""nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
 nodes_array=($nodes)
 
-head_node=${{nodes_array[0]}}
+head_node=${nodes_array[0]}
 head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
 
 head_node_port=$(find_available_port $head_node_ip 8080 65535)
 
 ip_head=$head_node_ip:$head_node_port
 export ip_head
 echo "IP Head: $ip_head"
-
+
 echo "Starting HEAD at $head_node"
 srun --nodes=1 --ntasks=1 -w "$head_node" \\""")
 
         if self.params["venv"] == "singularity":
-            cluster_setup.append(f""" singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\""")
+            cluster_setup.append(
+                f""" singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\"""
+            )
 
-        cluster_setup.append(f""" ray start --head --node-ip-address="$head_node_ip" --port=$head_node_port \\
---num-cpus "${{SLURM_CPUS_PER_TASK}}" --num-gpus "${{SLURM_GPUS_PER_NODE}}" --block &
+        cluster_setup.append(""" ray start --head --node-ip-address="$head_node_ip" --port=$head_node_port \\
+--num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block &
 
 sleep 10
 worker_num=$((SLURM_JOB_NUM_NODES - 1))
 
 for ((i = 1; i <= worker_num; i++)); do
-node_i=${{nodes_array[$i]}}
+node_i=${nodes_array[$i]}
 echo "Starting WORKER $i at $node_i"
 srun --nodes=1 --ntasks=1 -w "$node_i" \\""")
 
         if self.params["venv"] == "singularity":
-            cluster_setup.append(f""" singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\""")
+            cluster_setup.append(
+                f""" singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\"""
+            )
         cluster_setup.append(f""" ray start --address "$ip_head" \\
 --num-cpus "${{SLURM_CPUS_PER_TASK}}" --num-gpus "${{SLURM_GPUS_PER_NODE}}" --block &
 sleep 5
 done
 
-
+
 vllm_port_number=$(find_available_port $head_node_ip 8080 65535)
 
-
+
 SERVER_ADDR="http://${{head_node_ip}}:${{vllm_port_number}}/v1"
 echo "Server address: $SERVER_ADDR"
 
-JSON_PATH="{self.params['log_dir']}/{self.params['model_name']}.$SLURM_JOB_ID/{self.params['model_name']}.$SLURM_JOB_ID.json"
+JSON_PATH="{self.params["log_dir"]}/{self.params["model_name"]}.$SLURM_JOB_ID/{self.params["model_name"]}.$SLURM_JOB_ID.json"
 echo "Updating server address in $JSON_PATH"
 jq --arg server_addr "$SERVER_ADDR" \\
 '. + {{"server_address": $server_addr}}' \\
 "$JSON_PATH" > temp.json \\
 && mv temp.json "$JSON_PATH" \\
-&& rm -f temp.json
+&& rm -f temp.json
 """)
         cluster_setup = "\n".join(cluster_setup)
         env_exports = self._export_parallel_vars()
 
-
         if self.params["venv"] == "singularity":
             launcher = f"""singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\
 python3.10 -m vllm.entrypoints.openai.api_server \\
 """
         else:
-            launcher = f"""source {self.params['venv']}/bin/activate
+            launcher = f"""source {self.params["venv"]}/bin/activate
 python3 -m vllm.entrypoints.openai.api_server \\
 """
 
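
Beyond line-wrapping, the hunks in this file rewrite string literals in two ways that leave the resulting string values unchanged: escaped double quotes become an outer single-quoted literal, and the f-prefix is dropped from literals that contain no placeholders, which also lets the doubled {{ }} escapes collapse back to single braces. A minimal sketch checking both equivalences; the literals are illustrative examples modeled on the diff, not an exhaustive test:

# Quote style: escaped double quotes vs. an outer single-quoted literal.
escaped = "--host \"0.0.0.0\" \\"
single_quoted = '--host "0.0.0.0" \\'
assert escaped == single_quoted

# Dropping the f-prefix: once the literal is no longer an f-string, the
# doubled braces that escaped { and } are written as plain braces.
with_f = f"""--num-cpus "${{SLURM_CPUS_PER_TASK}}" --block &"""
without_f = """--num-cpus "${SLURM_CPUS_PER_TASK}" --block &"""
assert with_f == without_f

print("both pairs compare equal")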