|
15 | 15 | import yaml |
16 | 16 |
|
17 | 17 | from vec_inf.client._client_vars import MODEL_READY_SIGNATURE |
| 18 | +from vec_inf.client._slurm_vars import CACHED_CONFIG_DIR |
18 | 19 | from vec_inf.client.config import ModelConfig |
19 | 20 | from vec_inf.client.models import ModelStatus |
20 | | -from vec_inf.client.slurm_vars import CACHED_CONFIG |
21 | 21 |
|
22 | 22 |
|
23 | 23 | def run_bash_command(command: str) -> tuple[str, str]: |
@@ -217,44 +217,61 @@ def load_yaml_config(path: Path) -> dict[str, Any]: |
217 | 217 | except yaml.YAMLError as err: |
218 | 218 | raise ValueError(f"Error parsing YAML config at {path}: {err}") from err |
219 | 219 |
|
220 | | - # 1. If config_path is given, use only that |
221 | | - if config_path: |
222 | | - config = load_yaml_config(Path(config_path)) |
| 220 | + def process_config(config: dict[str, Any]) -> list[ModelConfig]: |
| 221 | + """Process the config based on the config type.""" |
223 | 222 | return [ |
224 | 223 | ModelConfig(model_name=name, **model_data) |
225 | 224 | for name, model_data in config.get("models", {}).items() |
226 | 225 | ] |
227 | 226 |
|
| 227 | + def resolve_config_path_from_env_var() -> Path | None: |
| 228 | + """Resolve the config path from the environment variable.""" |
| 229 | + config_dir = os.getenv("VEC_INF_CONFIG_DIR") |
| 230 | + config_path = os.getenv("VEC_INF_MODEL_CONFIG") |
| 231 | + if config_path: |
| 232 | + return Path(config_path) |
| 233 | + if config_dir: |
| 234 | + return Path(config_dir, "models.yaml") |
| 235 | + return None |
| 236 | + |
| 237 | + def update_config( |
| 238 | + config: dict[str, Any], user_config: dict[str, Any] |
| 239 | + ) -> dict[str, Any]: |
| 240 | + """Update the config with the user config.""" |
| 241 | + for name, data in user_config.get("models", {}).items(): |
| 242 | + if name in config.get("models", {}): |
| 243 | + config["models"][name].update(data) |
| 244 | + else: |
| 245 | + config.setdefault("models", {})[name] = data |
| 246 | + |
| 247 | + return config |
| 248 | + |
| 249 | + # 1. If config_path is given, use only that |
| 250 | + if config_path: |
| 251 | + config = load_yaml_config(Path(config_path)) |
| 252 | + return process_config(config) |
| 253 | + |
228 | 254 | # 2. Otherwise, load default config |
229 | 255 | default_path = ( |
230 | | - CACHED_CONFIG |
231 | | - if CACHED_CONFIG.exists() |
| 256 | + CACHED_CONFIG_DIR / "models_latest.yaml" |
| 257 | + if CACHED_CONFIG_DIR.exists() |
232 | 258 | else Path(__file__).resolve().parent.parent / "config" / "models.yaml" |
233 | 259 | ) |
234 | 260 | config = load_yaml_config(default_path) |
235 | 261 |
|
236 | 262 | # 3. If user config exists, merge it |
237 | | - user_path = os.getenv("VEC_INF_CONFIG") |
238 | | - if user_path: |
239 | | - user_path_obj = Path(user_path) |
240 | | - if user_path_obj.exists(): |
241 | | - user_config = load_yaml_config(user_path_obj) |
242 | | - for name, data in user_config.get("models", {}).items(): |
243 | | - if name in config.get("models", {}): |
244 | | - config["models"][name].update(data) |
245 | | - else: |
246 | | - config.setdefault("models", {})[name] = data |
247 | | - else: |
248 | | - warnings.warn( |
249 | | - f"WARNING: Could not find user config: {user_path}, revert to default config located at {default_path}", |
250 | | - UserWarning, |
251 | | - stacklevel=2, |
252 | | - ) |
253 | | - |
254 | | - return [ |
255 | | - ModelConfig(model_name=name, **model_data) |
256 | | - for name, model_data in config.get("models", {}).items() |
257 | | - ] |
| 263 | + user_path = resolve_config_path_from_env_var() |
| 264 | + if user_path and user_path.exists(): |
| 265 | + user_config = load_yaml_config(user_path) |
| 266 | + config = update_config(config, user_config) |
| 267 | + elif user_path: |
| 268 | + warnings.warn( |
| 269 | + f"WARNING: Could not find user config: {str(user_path)}, revert to default config located at {default_path}", |
| 270 | + UserWarning, |
| 271 | + stacklevel=2, |
| 272 | + ) |
| 273 | + |
| 274 | + return process_config(config) |
258 | 275 |
|
259 | 276 |
|
260 | 277 | def parse_launch_output(output: str) -> tuple[str, dict[str, str]]: |
|
0 commit comments