Skip to content

Commit 348a22e

Browse files
committed
Move all environment variable declarations to Python, remove unnecessary environment variable usage in the generated Slurm script, and rename non-environment variables to lowercase
1 parent f6273a0 commit 348a22e

File tree

3 files changed

+19
-17
lines changed

3 files changed

+19
-17
lines changed

vec_inf/client/_helper.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,9 @@
3030
BOOLEAN_FIELDS,
3131
LD_LIBRARY_PATH,
3232
REQUIRED_FIELDS,
33+
SINGULARITY_IMAGE,
3334
SRC_DIR,
35+
VLLM_NCCL_SO_PATH,
3436
)
3537

3638

@@ -139,6 +141,8 @@ def _get_launch_params(self) -> dict[str, Any]:
139141
def _set_env_vars(self) -> None:
140142
"""Set environment variables for the launch command."""
141143
os.environ["LD_LIBRARY_PATH"] = LD_LIBRARY_PATH
144+
os.environ["VLLM_NCCL_SO_PATH"] = VLLM_NCCL_SO_PATH
145+
os.environ["SINGULARITY_IMAGE"] = SINGULARITY_IMAGE
142146

143147
def _build_launch_command(self) -> str:
144148
"""Construct the full launch command with parameters."""

vec_inf/client/_slurm_script_generator.py

Lines changed: 13 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -38,11 +38,11 @@ def _generate_preamble(self) -> str:
3838

3939
def _generate_shared_args(self) -> str:
4040
if self.is_multinode and not self.params["pipeline_parallelism"]:
41-
tensor_parallel_size = "$((SLURM_JOB_NUM_NODES*SLURM_GPUS_PER_NODE))"
42-
pipeline_parallel_size = "1"
41+
tensor_parallel_size = self.params["num_nodes"] * self.params["gpus_per_node"]
42+
pipeline_parallel_size = 1
4343
else:
44-
tensor_parallel_size = "$SLURM_GPUS_PER_NODE"
45-
pipeline_parallel_size = "$SLURM_JOB_NUM_NODES"
44+
tensor_parallel_size = self.params["gpus_per_node"]
45+
pipeline_parallel_size = self.params["num_nodes"]
4646

4747
args = [
4848
f"--model {self.model_weights_path} \\",
@@ -77,9 +77,7 @@ def _generate_shared_args(self) -> str:
7777
def _generate_server_script(self) -> str:
7878
server_script = [""]
7979
if self.params["venv"] == "singularity":
80-
server_script.append("""export SINGULARITY_IMAGE=/model-weights/vec-inf-shared/vector-inference_latest.sif
81-
export VLLM_NCCL_SO_PATH=/vec-inf/nccl/libnccl.so.2.18.1
82-
module load singularity-ce/3.8.2
80+
server_script.append("""module load singularity-ce/3.8.2
8381
singularity exec $SINGULARITY_IMAGE ray stop
8482
""")
8583
server_script.append(f"source {self.src_dir}/find_port.sh\n")
@@ -88,13 +86,11 @@ def _generate_server_script(self) -> str:
8886
if self.is_multinode
8987
else self._generate_single_node_server_script()
9088
)
91-
server_script.append(f"""echo "Updating server address in $JSON_PATH"
92-
JSON_PATH="{self.params["log_dir"]}/{self.params["model_name"]}.$SLURM_JOB_ID/{self.params["model_name"]}.$SLURM_JOB_ID.json"
93-
jq --arg server_addr "$SERVER_ADDR" \\
89+
server_script.append(f"""json_path="{self.params["log_dir"]}/{self.params["model_name"]}.$SLURM_JOB_ID/{self.params["model_name"]}.$SLURM_JOB_ID.json"
90+
jq --arg server_addr "$server_address" \\
9491
'. + {{"server_address": $server_addr}}' \\
95-
"$JSON_PATH" > temp.json \\
96-
&& mv temp.json "$JSON_PATH" \\
97-
&& rm -f temp.json
92+
"$json_path" > temp.json \\
93+
&& mv temp.json "$json_path"
9894
9995
""")
10096
return "\n".join(server_script)
@@ -103,8 +99,8 @@ def _generate_single_node_server_script(self) -> str:
10399
return """hostname=${SLURMD_NODENAME}
104100
vllm_port_number=$(find_available_port ${hostname} 8080 65535)
105101
106-
SERVER_ADDR="http://${hostname}:${vllm_port_number}/v1"
107-
echo "Server address: $SERVER_ADDR"
102+
server_address="http://${hostname}:${vllm_port_number}/v1"
103+
echo "Server address: $server_address"
108104
"""
109105

110106
def _generate_multinode_server_script(self) -> str:
@@ -151,8 +147,8 @@ def _generate_multinode_server_script(self) -> str:
151147
152148
vllm_port_number=$(find_available_port $head_node_ip 8080 65535)
153149
154-
SERVER_ADDR="http://${head_node_ip}:${vllm_port_number}/v1"
155-
echo "Server address: $SERVER_ADDR"
150+
server_address="http://${head_node_ip}:${vllm_port_number}/v1"
151+
echo "Server address: $server_address"
156152
157153
""")
158154
return "\n".join(server_script)

vec_inf/client/_vars.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@
77
CACHED_CONFIG = Path("/", "model-weights", "vec-inf-shared", "models.yaml")
88
SRC_DIR = str(Path(__file__).parent.parent)
99
LD_LIBRARY_PATH = "/scratch/ssd001/pkgs/cudnn-11.7-v8.5.0.96/lib/:/scratch/ssd001/pkgs/cuda-11.7/targets/x86_64-linux/lib/"
10+
VLLM_NCCL_SO_PATH = "/vec-inf/nccl/libnccl.so.2.18.1"
11+
SINGULARITY_IMAGE = "/model-weights/vec-inf-shared/vector-inference_latest.sif"
1012

1113
# Maps model types to vLLM tasks
1214
VLLM_TASK_MAP = {

0 commit comments

Comments
 (0)