Skip to content

Commit 7eda3e2

Browse files
committed
Update slurm template for RDMA setup and binding to resolve Ray issues
1 parent 814bb7b commit 7eda3e2

File tree

2 files changed

+53
-15
lines changed

2 files changed

+53
-15
lines changed

vec_inf/client/_slurm_script_generator.py

Lines changed: 34 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -40,15 +40,29 @@ def __init__(self, params: dict[str, Any]):
4040
self.model_weights_path = str(
4141
Path(self.params["model_weights_parent_dir"], self.params["model_name"])
4242
)
43+
self.env_str = self._generate_env_str()
44+
45+
def _generate_env_str(self) -> str:
46+
"""Generate the environment variables string for the Slurm script.
47+
48+
Returns
49+
-------
50+
str
51+
Formatted environment variables string for container or shell export commands.
52+
"""
4353
env_dict: dict[str, str] = self.params.get("env", {})
44-
# Create string of environment variables
45-
self.env_str = ""
46-
for key, val in env_dict.items():
47-
if len(self.env_str) == 0:
48-
self.env_str = "--env "
49-
else:
50-
self.env_str += ","
51-
self.env_str += key + "=" + val
54+
55+
if not env_dict:
56+
return ""
57+
58+
if self.use_container:
59+
# Format for container: --env KEY1=VAL1,KEY2=VAL2
60+
env_pairs = [f"{key}={val}" for key, val in env_dict.items()]
61+
return f"--env {','.join(env_pairs)}"
62+
else:
63+
# Format for shell: export KEY1=VAL1\nexport KEY2=VAL2
64+
export_lines = [f"export {key}={val}" for key, val in env_dict.items()]
65+
return "\n".join(export_lines)
5266

5367
def _generate_script_content(self) -> str:
5468
"""Generate the complete Slurm script content.
@@ -94,7 +108,12 @@ def _generate_server_setup(self) -> str:
94108
server_script = ["\n"]
95109
if self.use_container:
96110
server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_setup"]))
97-
server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["env_vars"]))
111+
server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_env_vars"]))
112+
else:
113+
server_script.append(
114+
SLURM_SCRIPT_TEMPLATE["activate_venv"].format(venv=self.params["venv"])
115+
)
116+
server_script.append(self.env_str)
98117
server_script.append(
99118
SLURM_SCRIPT_TEMPLATE["imports"].format(src_dir=self.params["src_dir"])
100119
)
@@ -111,6 +130,11 @@ def _generate_server_setup(self) -> str:
111130
env_str=self.env_str,
112131
),
113132
)
133+
else:
134+
server_setup_str = server_setup_str.replace(
135+
"CONTAINER_PLACEHOLDER",
136+
"\\",
137+
)
114138
else:
115139
server_setup_str = "\n".join(
116140
SLURM_SCRIPT_TEMPLATE["server_setup"]["single_node"]
@@ -144,10 +168,7 @@ def _generate_launch_cmd(self) -> str:
144168
env_str=self.env_str,
145169
)
146170
)
147-
else:
148-
launcher_script.append(
149-
SLURM_SCRIPT_TEMPLATE["activate_venv"].format(venv=self.params["venv"])
150-
)
171+
151172
launcher_script.append(
152173
"\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"]).format(
153174
model_weights_path=self.model_weights_path,

vec_inf/client/_slurm_templates.py

Lines changed: 19 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -96,8 +96,8 @@ class SlurmScriptTemplate(TypedDict):
9696
f"{CONTAINER_MODULE_NAME} exec {IMAGE_PATH} ray stop",
9797
],
9898
"imports": "source {src_dir}/find_port.sh",
99-
"env_vars": [
100-
f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,$(echo /dev/infiniband* | sed -e 's/ /,/g')"
99+
"container_env_vars": [
100+
f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,/dev,/tmp"
101101
],
102102
"container_command": f"{CONTAINER_MODULE_NAME} exec --nv {{env_str}} --bind {{model_weights_path}}{{additional_binds}} --containall {IMAGE_PATH} \\",
103103
"activate_venv": "source {venv}/bin/activate",
@@ -112,6 +112,23 @@ class SlurmScriptTemplate(TypedDict):
112112
"nodes_array=($nodes)",
113113
"head_node=${{nodes_array[0]}}",
114114
'head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)',
115+
"\n# Check for RDMA devices and set environment variable accordingly",
116+
"if ! command -v ibv_devices >/dev/null 2>&1; then",
117+
" echo \"ibv_devices not found; forcing TCP. (No RDMA userland on host?)\"",
118+
" export NCCL_IB_DISABLE=1",
119+
" export NCCL_ENV_ARG=\"--env NCCL_IB_DISABLE=1\"",
120+
"else",
121+
" # Pick GID index based on link layer (IB vs RoCE)",
122+
" if ibv_devinfo 2>/dev/null | grep -q \"link_layer:.*Ethernet\"; then",
123+
" # RoCEv2 typically needs a nonzero GID index; 3 is common, try 2 if your fabric uses it",
124+
" export NCCL_IB_GID_INDEX={{NCCL_IB_GID_INDEX:-3}}",
125+
" export NCCL_ENV_ARG=\"--env NCCL_IB_GID_INDEX={{NCCL_IB_GID_INDEX:-3}}\"",
126+
" else",
127+
" # Native InfiniBand => GID 0",
128+
" export NCCL_IB_GID_INDEX={{NCCL_IB_GID_INDEX:-0}}",
129+
" export NCCL_ENV_ARG=\"--env NCCL_IB_GID_INDEX={{NCCL_IB_GID_INDEX:-0}}\"",
130+
" fi",
131+
"fi",
115132
"\n# Start Ray head node",
116133
"head_node_port=$(find_available_port $head_node_ip 8080 65535)",
117134
"ray_head=$head_node_ip:$head_node_port",

0 commit comments

Comments
 (0)