@@ -24,10 +24,9 @@ def __init__(self, params: dict[str, Any], src_dir: str):
2424 def _generate_script_content (self ) -> str :
2525 preamble = self ._generate_preamble ()
2626 server = self ._generate_server_script ()
27- env_exports = self ._export_parallel_vars ()
2827 launcher = self ._generate_launcher ()
2928 args = self ._generate_shared_args ()
30- return preamble + server + env_exports + launcher + args
29+ return preamble + server + launcher + args
3130
3231 def _generate_preamble (self ) -> str :
3332 base = [
@@ -43,26 +42,20 @@ def _generate_preamble(self) -> str:
4342 base += ["" ]
4443 return "\n " .join (base )
4544
46- def _export_parallel_vars (self ) -> str :
47- if self .is_multinode :
48- return """if [ "$PIPELINE_PARALLELISM" = "True" ]; then
49- export PIPELINE_PARALLEL_SIZE=$SLURM_JOB_NUM_NODES
50- export TENSOR_PARALLEL_SIZE=$SLURM_GPUS_PER_NODE
51- else
52- export PIPELINE_PARALLEL_SIZE=1
53- export TENSOR_PARALLEL_SIZE=$((SLURM_JOB_NUM_NODES*SLURM_GPUS_PER_NODE))
54- fi
55-
56- """
57- return "export TENSOR_PARALLEL_SIZE=$SLURM_GPUS_PER_NODE\n \n "
58-
5945 def _generate_shared_args (self ) -> str :
46+ if self .is_multinode and not self .params ["pipeline_parallelism" ]:
47+ tensor_parallel_size = "$((SLURM_JOB_NUM_NODES*SLURM_GPUS_PER_NODE))"
48+ pipeline_parallel_size = "1"
49+ else :
50+ tensor_parallel_size = "$SLURM_GPUS_PER_NODE"
51+ pipeline_parallel_size = "$SLURM_JOB_NUM_NODES"
52+
6053 args = [
6154 f"--model { self .model_weights_path } \\ " ,
6255 f"--served-model-name { self .params ['model_name' ]} \\ " ,
6356 '--host "0.0.0.0" \\ ' ,
6457 "--port $vllm_port_number \\ " ,
65- "--tensor-parallel-size ${TENSOR_PARALLEL_SIZE } \\ " ,
58+ f "--tensor-parallel-size { tensor_parallel_size } \\ " ,
6659 f"--dtype { self .params ['data_type' ]} \\ " ,
6760 "--trust-remote-code \\ " ,
6861 f"--max-logprobs { self .params ['vocab_size' ]} \\ " ,
@@ -73,7 +66,7 @@ def _generate_shared_args(self) -> str:
7366 f"--task { self .task } \\ " ,
7467 ]
7568 if self .is_multinode :
76- args .insert (4 , "--pipeline-parallel-size ${PIPELINE_PARALLEL_SIZE } \\ " )
69+ args .insert (4 , f "--pipeline-parallel-size { pipeline_parallel_size } \\ " )
7770 if self .params .get ("max_num_batched_tokens" ):
7871 args .append (
7972 f"--max-num-batched-tokens={ self .params ['max_num_batched_tokens' ]} \\ "
0 commit comments