@@ -38,11 +38,11 @@ def _generate_preamble(self) -> str:
3838
3939 def _generate_shared_args (self ) -> str :
4040 if self .is_multinode and not self .params ["pipeline_parallelism" ]:
41- tensor_parallel_size = "$((SLURM_JOB_NUM_NODES*SLURM_GPUS_PER_NODE))"
42- pipeline_parallel_size = "1"
41+ tensor_parallel_size = self . params [ "num_nodes" ] * self . params [ "gpus_per_node" ]
42+ pipeline_parallel_size = 1
4343 else :
44- tensor_parallel_size = "$SLURM_GPUS_PER_NODE"
45- pipeline_parallel_size = "$SLURM_JOB_NUM_NODES"
44+ tensor_parallel_size = self . params [ "gpus_per_node" ]
45+ pipeline_parallel_size = self . params [ "num_nodes" ]
4646
4747 args = [
4848 f"--model { self .model_weights_path } \\ " ,
@@ -77,9 +77,7 @@ def _generate_shared_args(self) -> str:
7777 def _generate_server_script (self ) -> str :
7878 server_script = ["" ]
7979 if self .params ["venv" ] == "singularity" :
80- server_script .append ("""export SINGULARITY_IMAGE=/model-weights/vec-inf-shared/vector-inference_latest.sif
81- export VLLM_NCCL_SO_PATH=/vec-inf/nccl/libnccl.so.2.18.1
82- module load singularity-ce/3.8.2
80+ server_script .append ("""module load singularity-ce/3.8.2
8381singularity exec $SINGULARITY_IMAGE ray stop
8482""" )
8583 server_script .append (f"source { self .src_dir } /find_port.sh\n " )
@@ -88,13 +86,11 @@ def _generate_server_script(self) -> str:
8886 if self .is_multinode
8987 else self ._generate_single_node_server_script ()
9088 )
91- server_script .append (f"""echo "Updating server address in $JSON_PATH"
92- JSON_PATH="{ self .params ["log_dir" ]} /{ self .params ["model_name" ]} .$SLURM_JOB_ID/{ self .params ["model_name" ]} .$SLURM_JOB_ID.json"
93- jq --arg server_addr "$SERVER_ADDR" \\
89+ server_script .append (f"""json_path="{ self .params ["log_dir" ]} /{ self .params ["model_name" ]} .$SLURM_JOB_ID/{ self .params ["model_name" ]} .$SLURM_JOB_ID.json"
90+ jq --arg server_addr "$server_address" \\
9491 '. + {{"server_address": $server_addr}}' \\
95- "$JSON_PATH" > temp.json \\
96- && mv temp.json "$JSON_PATH" \\
97- && rm -f temp.json
92+ "$json_path" > temp.json \\
93+ && mv temp.json "$json_path"
9894
9995""" )
10096 return "\n " .join (server_script )
@@ -103,8 +99,8 @@ def _generate_single_node_server_script(self) -> str:
10399 return """hostname=${SLURMD_NODENAME}
104100vllm_port_number=$(find_available_port ${hostname} 8080 65535)
105101
106- SERVER_ADDR ="http://${hostname}:${vllm_port_number}/v1"
107- echo "Server address: $SERVER_ADDR "
102+ server_address ="http://${hostname}:${vllm_port_number}/v1"
103+ echo "Server address: $server_address "
108104"""
109105
110106 def _generate_multinode_server_script (self ) -> str :
@@ -151,8 +147,8 @@ def _generate_multinode_server_script(self) -> str:
151147
152148vllm_port_number=$(find_available_port $head_node_ip 8080 65535)
153149
154- SERVER_ADDR ="http://${head_node_ip}:${vllm_port_number}/v1"
155- echo "Server address: $SERVER_ADDR "
150+ server_address ="http://${head_node_ip}:${vllm_port_number}/v1"
151+ echo "Server address: $server_address "
156152
157153""" )
158154 return "\n " .join (server_script )
0 commit comments