
Commit 037f9d0

Generate Slurm files dynamically and fix issues in venv.sh
1 parent 45ef380 commit 037f9d0

4 files changed: 216 additions, 33 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ description = "Efficient LLM inference on Slurm clusters using vLLM."
 readme = "README.md"
 authors = [{name = "Marshall Wang", email = "marshall.wang@vectorinstitute.ai"}]
 license = "MIT"
-requires-python = ">=3.10"
+requires-python = ">=3.10,<4.0"
 dependencies = [
     "requests>=2.31.0",
     "click>=8.1.0",

vec_inf/cli/_helper.py

Lines changed: 14 additions & 29 deletions
@@ -16,6 +16,7 @@
 
 import vec_inf.cli._utils as utils
 from vec_inf.cli._config import ModelConfig
+from vec_inf.cli._slurm_script_generator import SlurmScriptGenerator
 
 
 VLLM_TASK_MAP = {
@@ -127,31 +128,7 @@ def _get_launch_params(self) -> dict[str, Any]:
 
     def set_env_vars(self) -> None:
         """Set environment variables for the launch command."""
-        os.environ["MODEL_NAME"] = self.model_name
-        os.environ["MAX_MODEL_LEN"] = self.params["max_model_len"]
-        os.environ["MAX_LOGPROBS"] = self.params["vocab_size"]
-        os.environ["DATA_TYPE"] = self.params["data_type"]
-        os.environ["MAX_NUM_SEQS"] = self.params["max_num_seqs"]
-        os.environ["GPU_MEMORY_UTILIZATION"] = self.params["gpu_memory_utilization"]
-        os.environ["TASK"] = VLLM_TASK_MAP[self.params["model_type"]]
-        os.environ["PIPELINE_PARALLELISM"] = self.params["pipeline_parallelism"]
-        os.environ["COMPILATION_CONFIG"] = self.params["compilation_config"]
-        os.environ["SRC_DIR"] = SRC_DIR
-        os.environ["MODEL_WEIGHTS"] = str(
-            Path(self.params["model_weights_parent_dir"], self.model_name)
-        )
         os.environ["LD_LIBRARY_PATH"] = LD_LIBRARY_PATH
-        os.environ["VENV_BASE"] = self.params["venv"]
-        os.environ["LOG_DIR"] = self.params["log_dir"]
-
-        if self.params.get("enable_prefix_caching"):
-            os.environ["ENABLE_PREFIX_CACHING"] = self.params["enable_prefix_caching"]
-        if self.params.get("enable_chunked_prefill"):
-            os.environ["ENABLE_CHUNKED_PREFILL"] = self.params["enable_chunked_prefill"]
-        if self.params.get("max_num_batched_tokens"):
-            os.environ["MAX_NUM_BATCHED_TOKENS"] = self.params["max_num_batched_tokens"]
-        if self.params.get("enforce_eager"):
-            os.environ["ENFORCE_EAGER"] = self.params["enforce_eager"]
 
     def build_launch_command(self) -> str:
         """Construct the full launch command with parameters."""
@@ -177,11 +154,19 @@ def build_launch_command(self) -> str:
             ]
         )
         # Add slurm script
-        slurm_script = "vllm.slurm"
-        if int(self.params["num_nodes"]) > 1:
-            slurm_script = "multinode_vllm.slurm"
-        command_list.append(f"{SRC_DIR}/{slurm_script}")
+        # slurm_script = "vllm.slurm"
+        # if int(self.params["num_nodes"]) > 1:
+        #     slurm_script = "multinode_vllm.slurm"
+        # command_list.append(f"{SRC_DIR}/{slurm_script}")
+
+        slurm_script_path = SlurmScriptGenerator(
+            self.params, src_dir=SRC_DIR, is_multinode=int(self.params["num_nodes"]) > 1
+        ).write_to_log_dir()
+
+        command_list.append(str(slurm_script_path))
         return " ".join(command_list)
+
+
 
     def format_table_output(self, job_id: str) -> Table:
         """Format output as rich Table."""
@@ -214,7 +199,7 @@ def format_table_output(self, job_id: str) -> Table:
         )
         if self.params.get("enforce_eager"):
             table.add_row("Enforce Eager", self.params["enforce_eager"])
-        table.add_row("Model Weights Directory", os.environ.get("MODEL_WEIGHTS"))
+        table.add_row("Model Weights Directory", str(Path(self.params["model_weights_parent_dir"], self.model_name)))
         table.add_row("Log Directory", self.params["log_dir"])
 
         return table
vec_inf/cli/_slurm_script_generator.py

Lines changed: 195 additions & 0 deletions

@@ -0,0 +1,195 @@
+from pathlib import Path
+
+VLLM_TASK_MAP = {
+    "LLM": "generate",
+    "VLM": "generate",
+    "Text_Embedding": "embed",
+    "Reward_Modeling": "reward",
+}
+
+class SlurmScriptGenerator:
+    def __init__(self, params: dict, src_dir: str, is_multinode: bool = False):
+        self.params = params
+        self.src_dir = src_dir
+        self.is_multinode = is_multinode
+        self.model_weights_path = Path(params["model_weights_parent_dir"], params["model_name"])
+        self.task = VLLM_TASK_MAP[self.params["model_type"]]
+
+    def _generate_script_content(self) -> str:
+        return self._generate_multinode_script() if self.is_multinode else self._generate_single_node_script()
+
+    def _generate_preamble(self, is_multinode: bool = False) -> str:
+        base = [
+            "#!/bin/bash",
+            "#SBATCH --cpus-per-task=16",
+            "#SBATCH --mem=64G",
+        ]
+        if is_multinode:
+            base += [
+                "#SBATCH --exclusive",
+                "#SBATCH --tasks-per-node=1",
+            ]
+        base += [f"source {self.src_dir}/find_port.sh", ""]
+        return "\n".join(base)
+
+    def _export_parallel_vars(self) -> str:
+        if self.is_multinode:
+            return """if [ "$PIPELINE_PARALLELISM" = "True" ]; then
+export PIPELINE_PARALLEL_SIZE=$SLURM_JOB_NUM_NODES
+export TENSOR_PARALLEL_SIZE=$SLURM_GPUS_PER_NODE
+else
+export PIPELINE_PARALLEL_SIZE=1
+export TENSOR_PARALLEL_SIZE=$((SLURM_JOB_NUM_NODES*SLURM_GPUS_PER_NODE))
+fi
+"""
+        else:
+            return "export TENSOR_PARALLEL_SIZE=$SLURM_GPUS_PER_NODE\n"
+
+    def _generate_shared_args(self) -> list[str]:
+        args = [
+            f"--model {self.model_weights_path} \\",
+            f"--served-model-name {self.params['model_name']} \\",
+            "--host \"0.0.0.0\" \\",
+            "--port $vllm_port_number \\",
+            "--tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \\",
+            f"--dtype {self.params['data_type']} \\",
+            "--trust-remote-code \\",
+            f"--max-logprobs {self.params['vocab_size']} \\",
+            f"--max-model-len {self.params['max_model_len']} \\",
+            f"--max-num-seqs {self.params['max_num_seqs']} \\",
+            f"--gpu-memory-utilization {self.params['gpu_memory_utilization']} \\",
+            f"--compilation-config {self.params['compilation_config']} \\",
+            f"--task {self.task} \\",
+        ]
+        if self.is_multinode:
+            args.insert(4, "--pipeline-parallel-size ${PIPELINE_PARALLEL_SIZE} \\")
+        if self.params.get("max_num_batched_tokens"):
+            args.append(f"--max-num-batched-tokens={self.params['max_num_batched_tokens']} \\")
+        if self.params.get("enable_prefix_caching") == "True":
+            args.append("--enable-prefix-caching \\")
+        if self.params.get("enable_chunked_prefill") == "True":
+            args.append("--enable-chunked-prefill \\")
+        if self.params.get("enforce_eager") == "True":
+            args.append("--enforce-eager")
+
+        return args
+
+    def _generate_single_node_script(self) -> str:
+        preamble = self._generate_preamble(is_multinode=False)
+
+        server = f"""hostname=${{SLURMD_NODENAME}}
+vllm_port_number=$(find_available_port ${{hostname}} 8080 65535)
+
+SERVER_ADDR="http://${{hostname}}:${{vllm_port_number}}/v1"
+echo "Server address: $SERVER_ADDR"
+
+JSON_PATH="{self.params['log_dir']}/{self.params['model_name']}.$SLURM_JOB_ID/{self.params['model_name']}.$SLURM_JOB_ID.json"
+echo "Updating server address in $JSON_PATH"
+jq --arg server_addr "$SERVER_ADDR" \\
+'. + {{"server_address": $server_addr}}' \\
+"$JSON_PATH" > temp.json \\
+&& mv temp.json "$JSON_PATH" \\
+&& rm -f temp.json
+"""
+
+        env_exports = self._export_parallel_vars()
+
+        if self.params["venv"] == "singularity":
+            launcher = f"""export SINGULARITY_IMAGE=/model-weights/vec-inf-shared/vector-inference_latest.sif
+export VLLM_NCCL_SO_PATH=/vec-inf/nccl/libnccl.so.2.18.1
+module load singularity-ce/3.8.2
+singularity exec $SINGULARITY_IMAGE ray stop
+singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\
+python3.10 -m vllm.entrypoints.openai.api_server \\
+"""
+        else:
+            launcher = f"""source {self.params['venv']}/bin/activate
+python3 -m vllm.entrypoints.openai.api_server \\
+"""
+
+        args = "\n".join(self._generate_shared_args())
+        return preamble + server + env_exports + launcher + args
+
+
+    def _generate_multinode_script(self) -> str:
+        preamble = self._generate_preamble(is_multinode=True)
+
+        cluster_setup = []
+        if self.params["venv"] == "singularity":
+            cluster_setup.append(f"""export SINGULARITY_IMAGE=/model-weights/vec-inf-shared/vector-inference_latest.sif
+export VLLM_NCCL_SO_PATH=/vec-inf/nccl/libnccl.so.2.18.1
+module load singularity-ce/3.8.2
+singularity exec $SINGULARITY_IMAGE ray stop
+""")
+
+        cluster_setup.append(f"""nodes=$(scontrol show hostnames "${{SLURM_JOB_NODELIST}}")
+nodes_array=(${{nodes}})
+
+head_node=${{nodes_array[0]}}
+head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
+
+head_node_port=$(find_available_port $head_node_ip 8080 65535)
+vllm_port_number=$(find_available_port $head_node_ip 8080 65535)
+
+ip_head=$head_node_ip:$head_node_port
+export ip_head
+echo "IP Head: $ip_head"
+
+echo "Starting HEAD at $head_node"
+srun --nodes=1 --ntasks=1 -w "$head_node" \\""")
+
+        if self.params["venv"] == "singularity":
+            cluster_setup.append(f"""    singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\""")
+
+        cluster_setup.append(f"""    ray start --head --node-ip-address="$head_node_ip" --port=$head_node_port \\
+--num-cpus "${{SLURM_CPUS_PER_TASK}}" --num-gpus "${{SLURM_GPUS_PER_NODE}}" --block &
+
+sleep 10
+worker_num=$((SLURM_JOB_NUM_NODES - 1))
+
+for ((i = 1; i <= worker_num; i++)); do
+node_i=${{nodes_array[$i]}}
+echo "Starting WORKER $i at $node_i"
+srun --nodes=1 --ntasks=1 -w "$node_i" \\""")
+
+        if self.params["venv"] == "singularity":
+            cluster_setup.append(f"""    singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\""")
+        cluster_setup.append(f"""    ray start --address "$ip_head" \\
+--num-cpus "${{SLURM_CPUS_PER_TASK}}" --num-gpus "${{SLURM_GPUS_PER_NODE}}" --block &
+sleep 5
+done
+
+SERVER_ADDR="http://$head_node_ip:$vllm_port_number/v1"
+echo "Server address: $SERVER_ADDR"
+
+JSON_PATH="{self.params['log_dir']}/{self.params['model_name']}.$SLURM_JOB_ID/{self.params['model_name']}.$SLURM_JOB_ID.json"
+echo "Updating server address in $JSON_PATH"
+jq --arg server_addr "$SERVER_ADDR" \\
+'. + {{"server_address": $server_addr}}' \\
+"$JSON_PATH" > temp.json \\
+&& mv temp.json "$JSON_PATH" \\
+&& rm -f temp.json
+""")
+        cluster_setup = "\n".join(cluster_setup)
+        env_exports = self._export_parallel_vars()
+
+
+        if self.params["venv"] == "singularity":
+            launcher = f"""singularity exec --nv --bind {self.model_weights_path}:{self.model_weights_path} $SINGULARITY_IMAGE \\
+python3.10 -m vllm.entrypoints.openai.api_server \\
+"""
+        else:
+            launcher = f"""source {self.params['venv']}/bin/activate
+python3 -m vllm.entrypoints.openai.api_server \\
+"""
+
+        args = "\n".join(self._generate_shared_args())
+        return preamble + cluster_setup + env_exports + launcher + args
+
+    def write_to_log_dir(self) -> Path:
+        log_subdir = Path(self.params["log_dir"]) / self.params["model_name"]
+        log_subdir.mkdir(parents=True, exist_ok=True)
+        script_path = log_subdir / "launch.slurm"
+        content = self._generate_script_content()
+        script_path.write_text(content)
+        return script_path
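
For context, a minimal usage sketch of the generator added above, mirroring the call now made in build_launch_command; all parameter values below are hypothetical placeholders, not taken from the repo's configs:

# Sketch only: drive SlurmScriptGenerator directly with placeholder params.
from vec_inf.cli._slurm_script_generator import SlurmScriptGenerator

params = {
    "model_name": "Meta-Llama-3-8B",        # hypothetical
    "model_type": "LLM",
    "model_weights_parent_dir": "/model-weights",
    "log_dir": "/tmp/vec-inf-logs",          # hypothetical
    "venv": "singularity",
    "num_nodes": "1",
    "data_type": "auto",
    "vocab_size": "128256",
    "max_model_len": "8192",
    "max_num_seqs": "256",
    "gpu_memory_utilization": "0.9",
    "compilation_config": "0",
}

script_path = SlurmScriptGenerator(
    params,
    src_dir="/path/to/vec_inf",              # hypothetical
    is_multinode=int(params["num_nodes"]) > 1,
).write_to_log_dir()
print(script_path)  # <log_dir>/<model_name>/launch.slurm

With venv set to "singularity" the single-node branch renders the Singularity launcher; any other value is treated as a virtualenv path to source.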

venv.sh

Mode changed: 100644 → 100755
Lines changed: 6 additions & 3 deletions
@@ -1,8 +1,8 @@
-#!bin/bash
+#!/bin/bash
 
 # Load python module if you are on Vector cluster and install poetry
 module load python/3.10.12
-pip install poetry
+pip3 install poetry
 
 # Optional: it's recommended to change the cache directory to somewhere in the scratch space to avoid
 # running out of space in your home directory, below is an example for the Vector cluster
@@ -13,11 +13,14 @@ export POETRY_CACHE_DIR=/scratch/ssd004/scratch/$(whoami)/poetry_cache
 # poetry config cache-dir
 echo "Cache directory set to: $(poetry config cache-dir)"
 
+echo "📜 Telling Poetry to use Python 3.10..."
+poetry env use python3.10
+
 # Install dependencies via poetry
 poetry install
 
 # Activate the virtual environment
-poetry shell
+# poetry shell
 
 # Deactivate the virtual environment
 # deactivate
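
Note: with `poetry shell` commented out, the "Activate the virtual environment" step can still be done manually; one common alternative (assuming a Poetry-managed virtualenv) is `source "$(poetry env info --path)/bin/activate"`.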
