Skip to content

Commit d1fac58

Browse files
committed
Black formatting
1 parent 487aef8 commit d1fac58

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

vec_inf/cli/_cli.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import inspect
21
import os
32
import time
43
from typing import Optional
@@ -39,7 +38,7 @@ def cli():
3938
"--partition",
4039
type=str,
4140
default="a40",
42-
help="Type of compute partition, default to a40"
41+
help="Type of compute partition, default to a40",
4342
)
4443
@click.option(
4544
"--num-nodes",
@@ -68,12 +67,14 @@ def cli():
6867
type=int,
6968
help="Vocabulary size, this option is intended for custom models",
7069
)
71-
@click.option("--data-type", type=str, default="auto", help="Model data type, default to auto")
70+
@click.option(
71+
"--data-type", type=str, default="auto", help="Model data type, default to auto"
72+
)
7273
@click.option(
7374
"--venv",
7475
type=str,
7576
default="singularity",
76-
help="Path to virtual environment, default to preconfigured singularity container"
77+
help="Path to virtual environment, default to preconfigured singularity container",
7778
)
7879
@click.option(
7980
"--log-dir",
@@ -293,11 +294,15 @@ def list_all(models_df: pd.DataFrame, json_mode: bool):
293294
"Text Embedding": "purple",
294295
}
295296
custom_order = ["LLM", "VLM", "Text Embedding"]
296-
models_df["model_type"] = pd.Categorical(models_df["model_type"], categories=custom_order, ordered=True)
297+
models_df["model_type"] = pd.Categorical(
298+
models_df["model_type"], categories=custom_order, ordered=True
299+
)
297300
models_df = models_df.sort_values(by="model_type")
298301
for _, row in models_df.iterrows():
299302
panel_color = model_type_colors.get(row["model_type"], "white")
300-
styled_text = f"[magenta]{row['model_family']}[/magenta]-{row['model_variant']}"
303+
styled_text = (
304+
f"[magenta]{row['model_family']}[/magenta]-{row['model_variant']}"
305+
)
301306
panels.append(Panel(styled_text, expand=True, border_style=panel_color))
302307
CONSOLE.print(Columns(panels, equal=True))
303308

@@ -324,17 +329,18 @@ def metrics(slurm_job_id: int, log_dir: Optional[str] = None) -> None:
324329
output = utils.run_bash_command(status_cmd)
325330
slurm_job_name = output.split(" ")[1].split("=")[1]
326331
out_logs = utils.read_slurm_log(slurm_job_name, slurm_job_id, "out", log_dir)
327-
332+
328333
with Live(refresh_per_second=1, console=CONSOLE) as live:
329334
while True:
330335
metrics = utils.get_latest_metric(out_logs)
331336
table = utils.create_table(key_title="Metric", value_title="Value")
332337
for key, value in metrics.items():
333338
table.add_row(key, value)
334-
339+
335340
live.update(table)
336341

337342
time.sleep(10)
338343

344+
339345
if __name__ == "__main__":
340346
cli()

0 commit comments

Comments
 (0)