Skip to content

Commit 86c1c3f

Browse files
committed
removed export vars.
1 parent 1d2a1ae commit 86c1c3f

File tree

1 file changed

+10
-17
lines changed

1 file changed

+10
-17
lines changed

vec_inf/cli/_slurm_script_generator.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,9 @@ def __init__(self, params: dict[str, Any], src_dir: str):
2424
def _generate_script_content(self) -> str:
2525
preamble = self._generate_preamble()
2626
server = self._generate_server_script()
27-
env_exports = self._export_parallel_vars()
2827
launcher = self._generate_launcher()
2928
args = self._generate_shared_args()
30-
return preamble + server + env_exports + launcher + args
29+
return preamble + server + launcher + args
3130

3231
def _generate_preamble(self) -> str:
3332
base = [
@@ -43,26 +42,20 @@ def _generate_preamble(self) -> str:
4342
base += [""]
4443
return "\n".join(base)
4544

46-
def _export_parallel_vars(self) -> str:
47-
if self.is_multinode:
48-
return """if [ "$PIPELINE_PARALLELISM" = "True" ]; then
49-
export PIPELINE_PARALLEL_SIZE=$SLURM_JOB_NUM_NODES
50-
export TENSOR_PARALLEL_SIZE=$SLURM_GPUS_PER_NODE
51-
else
52-
export PIPELINE_PARALLEL_SIZE=1
53-
export TENSOR_PARALLEL_SIZE=$((SLURM_JOB_NUM_NODES*SLURM_GPUS_PER_NODE))
54-
fi
55-
56-
"""
57-
return "export TENSOR_PARALLEL_SIZE=$SLURM_GPUS_PER_NODE\n\n"
58-
5945
def _generate_shared_args(self) -> str:
46+
if self.is_multinode and not self.params["pipeline_parallelism"]:
47+
tensor_parallel_size = "$((SLURM_JOB_NUM_NODES*SLURM_GPUS_PER_NODE))"
48+
pipeline_parallel_size = "1"
49+
else:
50+
tensor_parallel_size = "$SLURM_GPUS_PER_NODE"
51+
pipeline_parallel_size = "$SLURM_JOB_NUM_NODES"
52+
6053
args = [
6154
f"--model {self.model_weights_path} \\",
6255
f"--served-model-name {self.params['model_name']} \\",
6356
'--host "0.0.0.0" \\',
6457
"--port $vllm_port_number \\",
65-
"--tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \\",
58+
f"--tensor-parallel-size {tensor_parallel_size} \\",
6659
f"--dtype {self.params['data_type']} \\",
6760
"--trust-remote-code \\",
6861
f"--max-logprobs {self.params['vocab_size']} \\",
@@ -73,7 +66,7 @@ def _generate_shared_args(self) -> str:
7366
f"--task {self.task} \\",
7467
]
7568
if self.is_multinode:
76-
args.insert(4, "--pipeline-parallel-size ${PIPELINE_PARALLEL_SIZE} \\")
69+
args.insert(4, f"--pipeline-parallel-size {pipeline_parallel_size} \\")
7770
if self.params.get("max_num_batched_tokens"):
7871
args.append(
7972
f"--max-num-batched-tokens={self.params['max_num_batched_tokens']} \\"

0 commit comments

Comments
 (0)