Skip to content

Commit e9bebab

Browse files
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 11da920 commit e9bebab

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

vec_inf/client/_slurm_script_generator.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ def __init__(self, params: dict[str, Any]):
3434
self.params = params
3535
self.is_multinode = int(self.params["num_nodes"]) > 1
3636
self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME
37-
self.additional_binds = f",{self.params['bind']}" if self.params.get("bind") else ""
37+
self.additional_binds = (
38+
f",{self.params['bind']}" if self.params.get("bind") else ""
39+
)
3840
self.model_weights_path = str(
3941
Path(self.params["model_weights_parent_dir"], self.params["model_name"])
4042
)
@@ -105,7 +107,12 @@ def _generate_server_setup(self) -> str:
105107
server_script = ["\n"]
106108
if self.use_container:
107109
server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_setup"]))
108-
server_script.append(SLURM_SCRIPT_TEMPLATE["bind_path"].format(model_weights_path=self.model_weights_path, additional_binds=self.additional_binds))
110+
server_script.append(
111+
SLURM_SCRIPT_TEMPLATE["bind_path"].format(
112+
model_weights_path=self.model_weights_path,
113+
additional_binds=self.additional_binds,
114+
)
115+
)
109116
else:
110117
server_script.append(
111118
SLURM_SCRIPT_TEMPLATE["activate_venv"].format(venv=self.params["venv"])
@@ -211,7 +218,9 @@ def __init__(self, params: dict[str, Any]):
211218
self.script_paths: list[Path] = []
212219
self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME
213220
for model_name in self.params["models"]:
214-
self.params["models"][model_name]["additional_binds"] = f",{self.params['bind']}" if self.params.get("bind") else ""
221+
self.params["models"][model_name]["additional_binds"] = (
222+
f",{self.params['bind']}" if self.params.get("bind") else ""
223+
)
215224
self.params["models"][model_name]["model_weights_path"] = str(
216225
Path(
217226
self.params["models"][model_name]["model_weights_parent_dir"],
@@ -251,10 +260,12 @@ def _generate_model_launch_script(self, model_name: str) -> Path:
251260
script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["shebang"])
252261
if self.use_container:
253262
script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_setup"])
254-
script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["bind_path"].format(
255-
model_weights_path=model_params["model_weights_path"],
256-
additional_binds=model_params["additional_binds"],
257-
))
263+
script_content.append(
264+
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["bind_path"].format(
265+
model_weights_path=model_params["model_weights_path"],
266+
additional_binds=model_params["additional_binds"],
267+
)
268+
)
258269
script_content.append(
259270
"\n".join(
260271
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["server_address_setup"]

0 commit comments

Comments
 (0)