Skip to content

Commit 35e9a84

Browse files
committed
Decouple slurm variables from package and move them into a config, default cached config directory path is still set within the package
1 parent b3db973 commit 35e9a84

File tree

6 files changed

+136
-73
lines changed

6 files changed

+136
-73
lines changed

vec_inf/client/_slurm_templates.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66

77
from typing import TypedDict
88

9-
from vec_inf.client.slurm_vars import (
9+
from vec_inf.client._slurm_vars import (
1010
LD_LIBRARY_PATH,
1111
SINGULARITY_IMAGE,
1212
SINGULARITY_LOAD_CMD,
13+
SINGULARITY_MODULE_NAME,
1314
VLLM_NCCL_SO_PATH,
1415
)
1516

@@ -93,14 +94,14 @@ class SlurmScriptTemplate(TypedDict):
9394
},
9495
"singularity_setup": [
9596
SINGULARITY_LOAD_CMD,
96-
f"singularity exec {SINGULARITY_IMAGE} ray stop",
97+
f"{SINGULARITY_MODULE_NAME} exec {SINGULARITY_IMAGE} ray stop",
9798
],
9899
"imports": "source {src_dir}/find_port.sh",
99100
"env_vars": [
100101
f"export LD_LIBRARY_PATH={LD_LIBRARY_PATH}",
101102
f"export VLLM_NCCL_SO_PATH={VLLM_NCCL_SO_PATH}",
102103
],
103-
"singularity_command": f"singularity exec --nv --bind {{model_weights_path}}{{additional_binds}} --containall {SINGULARITY_IMAGE} \\",
104+
"singularity_command": f"{SINGULARITY_MODULE_NAME} exec --nv --bind {{model_weights_path}}{{additional_binds}} --containall {SINGULARITY_IMAGE} \\",
104105
"activate_venv": "source {venv}/bin/activate",
105106
"server_setup": {
106107
"single_node": [
@@ -240,7 +241,7 @@ class BatchModelLaunchScriptTemplate(TypedDict):
240241
' "$json_path" > temp_{model_name}.json \\',
241242
' && mv temp_{model_name}.json "$json_path"\n',
242243
],
243-
"singularity_command": f"singularity exec --nv --bind {{model_weights_path}}{{additional_binds}} --containall {SINGULARITY_IMAGE} \\",
244+
"singularity_command": f"{SINGULARITY_MODULE_NAME} exec --nv --bind {{model_weights_path}}{{additional_binds}} --containall {SINGULARITY_IMAGE} \\",
244245
"launch_cmd": [
245246
"vllm serve {model_weights_path} \\",
246247
" --served-model-name {model_name} \\",

vec_inf/client/_slurm_vars.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""Slurm cluster configuration variables."""
2+
3+
import os
4+
import warnings
5+
from pathlib import Path
6+
from typing import Any, TypeAlias
7+
8+
import yaml
9+
from typing_extensions import Literal
10+
11+
12+
CACHED_CONFIG_DIR = Path("/model-weights/vec-inf-shared")
13+
14+
15+
def load_env_config() -> dict[str, Any]:
16+
"""Load the environment configuration."""
17+
18+
def load_yaml_config(path: Path) -> dict[str, Any]:
19+
"""Load YAML config with error handling."""
20+
try:
21+
with path.open() as f:
22+
return yaml.safe_load(f) or {}
23+
except FileNotFoundError as err:
24+
raise FileNotFoundError(f"Could not find config: {path}") from err
25+
except yaml.YAMLError as err:
26+
raise ValueError(f"Error parsing YAML config at {path}: {err}") from err
27+
28+
cached_config_path = CACHED_CONFIG_DIR / "environment.yaml"
29+
default_path = (
30+
cached_config_path
31+
if cached_config_path.exists()
32+
else Path(__file__).resolve().parent.parent / "config" / "environment.yaml"
33+
)
34+
config = load_yaml_config(default_path)
35+
36+
user_path = os.getenv("VEC_INF_CONFIG_DIR")
37+
if user_path:
38+
user_path_obj = Path(user_path, "environment.yaml")
39+
if user_path_obj.exists():
40+
user_config = load_yaml_config(user_path_obj)
41+
config.update(user_config)
42+
else:
43+
warnings.warn(
44+
f"WARNING: Could not find user config directory: {user_path}, revert to default config located at {default_path}",
45+
UserWarning,
46+
stacklevel=2,
47+
)
48+
49+
return config
50+
51+
52+
_config = load_env_config()
53+
54+
# Extract path values
55+
LD_LIBRARY_PATH = _config["paths"]["ld_library_path"]
56+
SINGULARITY_IMAGE = _config["paths"]["image_path"]
57+
VLLM_NCCL_SO_PATH = _config["paths"]["vllm_nccl_so_path"]
58+
59+
# Extract containerization info
60+
SINGULARITY_LOAD_CMD = _config["containerization"]["module_load_cmd"]
61+
SINGULARITY_MODULE_NAME = _config["containerization"]["module_name"]
62+
63+
# Extract limits
64+
MAX_GPUS_PER_NODE = _config["limits"]["max_gpus_per_node"]
65+
MAX_NUM_NODES = _config["limits"]["max_num_nodes"]
66+
MAX_CPUS_PER_TASK = _config["limits"]["max_cpus_per_task"]
67+
68+
# Create dynamic Literal types
69+
QOS: TypeAlias = str # Runtime validation will handle the actual values
70+
PARTITION: TypeAlias = str # Runtime validation will handle the actual values
71+
QOS = Literal[tuple(_config["allowed_values"]["qos"])]
72+
PARTITION = Literal[tuple(_config["allowed_values"]["partition"])]
73+
74+
# Extract default arguments
75+
DEFAULT_ARGS = _config["default_args"]

vec_inf/client/_utils.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
import yaml
1616

1717
from vec_inf.client._client_vars import MODEL_READY_SIGNATURE
18+
from vec_inf.client._slurm_vars import CACHED_CONFIG_DIR
1819
from vec_inf.client.config import ModelConfig
1920
from vec_inf.client.models import ModelStatus
20-
from vec_inf.client.slurm_vars import CACHED_CONFIG
2121

2222

2323
def run_bash_command(command: str) -> tuple[str, str]:
@@ -217,44 +217,54 @@ def load_yaml_config(path: Path) -> dict[str, Any]:
217217
except yaml.YAMLError as err:
218218
raise ValueError(f"Error parsing YAML config at {path}: {err}") from err
219219

220-
# 1. If config_path is given, use only that
221-
if config_path:
222-
config = load_yaml_config(Path(config_path))
220+
def process_config(config: dict[str, Any]) -> list[ModelConfig]:
221+
"""Process the config based on the config type."""
223222
return [
224223
ModelConfig(model_name=name, **model_data)
225224
for name, model_data in config.get("models", {}).items()
226225
]
226+
227+
228+
def update_config(
229+
config: dict[str, Any], user_config: dict[str, Any]
230+
) -> dict[str, Any]:
231+
"""Update the config with the user config."""
232+
for name, data in user_config.get("models", {}).items():
233+
if name in config.get("models", {}):
234+
config["models"][name].update(data)
235+
else:
236+
config.setdefault("models", {})[name] = data
237+
238+
return config
239+
240+
# 1. If config_path is given, use only that
241+
if config_path:
242+
config = load_yaml_config(Path(config_path))
243+
return process_config(config)
227244

228245
# 2. Otherwise, load default config
229246
default_path = (
230-
CACHED_CONFIG
231-
if CACHED_CONFIG.exists()
247+
CACHED_CONFIG_DIR / "models_latest.yaml"
248+
if CACHED_CONFIG_DIR.exists()
232249
else Path(__file__).resolve().parent.parent / "config" / "models.yaml"
233250
)
234251
config = load_yaml_config(default_path)
235252

236253
# 3. If user config exists, merge it
237-
user_path = os.getenv("VEC_INF_CONFIG")
254+
user_path = os.getenv("VEC_INF_CONFIG_DIR")
238255
if user_path:
239-
user_path_obj = Path(user_path)
256+
user_path_obj = Path(user_path, "models.yaml")
240257
if user_path_obj.exists():
241258
user_config = load_yaml_config(user_path_obj)
242-
for name, data in user_config.get("models", {}).items():
243-
if name in config.get("models", {}):
244-
config["models"][name].update(data)
245-
else:
246-
config.setdefault("models", {})[name] = data
259+
config = update_config(config, user_config)
247260
else:
248261
warnings.warn(
249-
f"WARNING: Could not find user config: {user_path}, revert to default config located at {default_path}",
262+
f"WARNING: Could not find user config directory: {user_path}, revert to default config located at {default_path}",
250263
UserWarning,
251264
stacklevel=2,
252265
)
253266

254-
return [
255-
ModelConfig(model_name=name, **model_data)
256-
for name, model_data in config.get("models", {}).items()
257-
]
267+
return process_config(config)
258268

259269

260270
def parse_launch_output(output: str) -> tuple[str, dict[str, str]]:

vec_inf/client/config.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pydantic import BaseModel, ConfigDict, Field
1111
from typing_extensions import Literal
1212

13-
from vec_inf.client.slurm_vars import (
13+
from vec_inf.client._slurm_vars import (
1414
DEFAULT_ARGS,
1515
MAX_CPUS_PER_TASK,
1616
MAX_GPUS_PER_NODE,
@@ -132,7 +132,6 @@ class ModelConfig(BaseModel):
132132
vllm_args: Optional[dict[str, Any]] = Field(
133133
default={}, description="vLLM engine arguments"
134134
)
135-
136135
model_config = ConfigDict(
137136
extra="forbid", str_strip_whitespace=True, validate_default=True, frozen=True
138137
)

vec_inf/client/slurm_vars.py

Lines changed: 0 additions & 49 deletions
This file was deleted.

vec_inf/config/environment.yaml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
paths:
2+
ld_library_path: "/scratch/ssd001/pkgs/cudnn-11.7-v8.5.0.96/lib/:/scratch/ssd001/pkgs/cuda-11.7/targets/x86_64-linux/lib/"
3+
image_path: "/model-weights/vec-inf-shared/vector-inference_latest.sif"
4+
vllm_nccl_so_path: "/vec-inf/nccl/libnccl.so.2.18.1"
5+
6+
containerization:
7+
module_load_cmd: "module load singularity-ce/3.8.2"
8+
module_name: "singularity"
9+
10+
limits:
11+
max_gpus_per_node: 8
12+
max_num_nodes: 16
13+
max_cpus_per_task: 128
14+
15+
allowed_values:
16+
qos: ["normal", "m", "m2", "m3", "m4", "m5", "long", "deadline", "high", "scavenger", "llm", "a100"]
17+
partition: ["a40", "a100", "t4v1", "t4v2", "rtx6000"]
18+
19+
default_args:
20+
cpus_per_task: 16
21+
mem_per_node: "64G"
22+
qos: "m2"
23+
time: "08:00:00"
24+
partition: "a40"
25+
data_type: "auto"
26+
log_dir: "~/.vec-inf-logs"
27+
model_weights_parent_dir: "/model-weights"

0 commit comments

Comments
 (0)