Skip to content

Commit e7b6871

Browse files
committed
Added enforce eager option, added support for reward modeling models, refactors based on mypy
1 parent e0b194c commit e7b6871

File tree

2 files changed

+75
-37
lines changed

2 files changed

+75
-37
lines changed

vec_inf/cli/_cli.py

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import os
22
import time
3-
from typing import Optional
3+
from typing import Optional, cast
44

55
import click
6-
import pandas as pd
6+
7+
import polars as pl
78
from rich.columns import Columns
89
from rich.console import Console
910
from rich.live import Live
@@ -91,6 +92,11 @@ def cli():
9192
type=str,
9293
help="Enable pipeline parallelism, accepts 'True' or 'False', default to 'True' for supported models",
9394
)
95+
@click.option(
96+
"--enforce-eager",
97+
type=str,
98+
help="Always use eager-mode PyTorch, accepts 'True' or 'False', default to 'False' for custom models if not set",
99+
)
94100
@click.option(
95101
"--json-mode",
96102
is_flag=True,
@@ -113,14 +119,17 @@ def launch(
113119
log_dir: Optional[str] = None,
114120
model_weights_parent_dir: Optional[str] = None,
115121
pipeline_parallelism: Optional[str] = None,
122+
enforce_eager: Optional[str] = None,
116123
json_mode: bool = False,
117124
) -> None:
118125
"""
119126
Launch a model on the cluster
120127
"""
121128

122129
if isinstance(pipeline_parallelism, str):
123-
pipeline_parallelism = pipeline_parallelism.lower() == "true"
130+
pipeline_parallelism = (
131+
"True" if pipeline_parallelism.lower() == "true" else "False"
132+
)
124133

125134
launch_script_path = os.path.join(
126135
os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "launch_server.sh"
@@ -129,15 +138,15 @@ def launch(
129138

130139
models_df = utils.load_models_df()
131140

132-
if model_name in models_df["model_name"].values:
141+
if model_name in models_df["model_name"].to_list():
133142
default_args = utils.load_default_args(models_df, model_name)
134143
for arg in default_args:
135144
if arg in locals() and locals()[arg] is not None:
136145
default_args[arg] = locals()[arg]
137146
renamed_arg = arg.replace("_", "-")
138147
launch_cmd += f" --{renamed_arg} {default_args[arg]}"
139148
else:
140-
model_args = models_df.columns.tolist()
149+
model_args = models_df.columns
141150
model_args.remove("model_name")
142151
model_args.remove("model_type")
143152
for arg in model_args:
@@ -265,45 +274,58 @@ def shutdown(slurm_job_id: int) -> None:
265274
is_flag=True,
266275
help="Output in JSON string",
267276
)
268-
def list(model_name: Optional[str] = None, json_mode: bool = False) -> None:
277+
def list_models(model_name: Optional[str] = None, json_mode: bool = False) -> None:
269278
"""
270279
List all available models, or get default setup of a specific model
271280
"""
272281

273-
def list_model(model_name: str, models_df: pd.DataFrame, json_mode: bool):
274-
if model_name not in models_df["model_name"].values:
282+
def list_model(model_name: str, models_df: pl.DataFrame, json_mode: bool):
283+
if model_name not in models_df["model_name"].to_list():
275284
raise ValueError(f"Model name {model_name} not found in available models")
276285

277286
excluded_keys = {"venv", "log_dir"}
278-
model_row = models_df.loc[models_df["model_name"] == model_name]
287+
model_row = models_df.filter(models_df["model_name"] == model_name)
279288

280289
if json_mode:
281-
filtered_model_row = model_row.drop(columns=excluded_keys, errors="ignore")
282-
click.echo(filtered_model_row.to_json(orient="records"))
290+
filtered_model_row = model_row.drop(excluded_keys, strict=False)
291+
click.echo(filtered_model_row.to_dicts()[0])
283292
return
284293
table = utils.create_table(key_title="Model Config", value_title="Value")
285-
for _, row in model_row.iterrows():
294+
for row in model_row.to_dicts():
286295
for key, value in row.items():
287296
if key not in excluded_keys:
288297
table.add_row(key, str(value))
289298
CONSOLE.print(table)
290299

291-
def list_all(models_df: pd.DataFrame, json_mode: bool):
300+
def list_all(models_df: pl.DataFrame, json_mode: bool):
292301
if json_mode:
293-
click.echo(models_df["model_name"].to_json(orient="records"))
302+
click.echo(models_df["model_name"].to_list())
294303
return
295304
panels = []
296305
model_type_colors = {
297306
"LLM": "cyan",
298307
"VLM": "bright_blue",
299308
"Text Embedding": "purple",
309+
"Reward Modeling": "bright_magenta",
300310
}
301-
custom_order = ["LLM", "VLM", "Text Embedding"]
302-
models_df["model_type"] = pd.Categorical(
303-
models_df["model_type"], categories=custom_order, ordered=True
311+
312+
models_df = models_df.with_columns(
313+
pl.when(pl.col("model_type") == "LLM")
314+
.then(0)
315+
.when(pl.col("model_type") == "VLM")
316+
.then(1)
317+
.when(pl.col("model_type") == "Text Embedding")
318+
.then(2)
319+
.when(pl.col("model_type") == "Reward Modeling")
320+
.then(3)
321+
.otherwise(-1)
322+
.alias("model_type_order")
304323
)
305-
models_df = models_df.sort_values(by="model_type")
306-
for _, row in models_df.iterrows():
324+
325+
models_df = models_df.sort("model_type_order")
326+
models_df = models_df.drop("model_type_order")
327+
328+
for row in models_df.to_dicts():
307329
panel_color = model_type_colors.get(row["model_type"], "white")
308330
styled_text = (
309331
f"[magenta]{row['model_family']}[/magenta]-{row['model_variant']}"
@@ -336,10 +358,22 @@ def metrics(slurm_job_id: int, log_dir: Optional[str] = None) -> None:
336358

337359
with Live(refresh_per_second=1, console=CONSOLE) as live:
338360
while True:
339-
out_logs = utils.read_slurm_log(slurm_job_name, slurm_job_id, "out", log_dir)
340-
metrics = utils.get_latest_metric(out_logs)
361+
out_logs = utils.read_slurm_log(
362+
slurm_job_name, slurm_job_id, "out", log_dir
363+
)
364+
# if out_logs is a string, then it is an error message
365+
if isinstance(out_logs, str):
366+
live.update(out_logs)
367+
break
368+
out_logs = cast(list, out_logs)
369+
latest_metrics = utils.get_latest_metric(out_logs)
370+
# if latest_metrics is a string, then it is an error message
371+
if isinstance(latest_metrics, str):
372+
live.update(latest_metrics)
373+
break
374+
latest_metrics = cast(dict, latest_metrics)
341375
table = utils.create_table(key_title="Metric", value_title="Value")
342-
for key, value in metrics.items():
376+
for key, value in latest_metrics.items():
343377
table.add_row(key, value)
344378

345379
live.update(table)

vec_inf/cli/_utils.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import os
22
import subprocess
3-
from typing import Optional, Union
3+
from typing import Optional, Union, cast
44

5-
import pandas as pd
5+
import polars as pl
66
import requests
77
from rich.table import Table
88

@@ -35,9 +35,11 @@ def read_slurm_log(
3535
log_dir = os.path.join(models_dir, dir)
3636
break
3737

38+
log_dir = cast(str, log_dir)
39+
3840
try:
3941
file_path = os.path.join(
40-
log_dir,
42+
log_dir,
4143
f"{slurm_job_name}.{slurm_job_id}.{slurm_log_type}",
4244
)
4345
with open(file_path, "r") as file:
@@ -58,13 +60,15 @@ def is_server_running(
5860
if isinstance(log_content, str):
5961
return log_content
6062

61-
status = None
63+
status: Union[str, tuple[str, str]] = "LAUNCHING"
64+
6265
for line in log_content:
6366
if "error" in line.lower():
6467
status = ("FAILED", line.strip("\n"))
6568
if MODEL_READY_SIGNATURE in line:
6669
status = "RUNNING"
67-
return "LAUNCHING" if not status else status
70+
71+
return status
6872

6973

7074
def get_base_url(slurm_job_name: str, slurm_job_id: int, log_dir: Optional[str]) -> str:
@@ -115,11 +119,11 @@ def create_table(
115119
return table
116120

117121

118-
def load_models_df() -> pd.DataFrame:
122+
def load_models_df() -> pl.DataFrame:
119123
"""
120124
Load the models dataframe
121125
"""
122-
models_df = pd.read_csv(
126+
models_df = pl.read_csv(
123127
os.path.join(
124128
os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
125129
"models/models.csv",
@@ -128,14 +132,14 @@ def load_models_df() -> pd.DataFrame:
128132
return models_df
129133

130134

131-
def load_default_args(models_df: pd.DataFrame, model_name: str) -> dict:
135+
def load_default_args(models_df: pl.DataFrame, model_name: str) -> dict:
132136
"""
133137
Load the default arguments for a model
134138
"""
135-
row_data = models_df.loc[models_df["model_name"] == model_name]
136-
default_args = row_data.iloc[0].to_dict()
137-
default_args.pop("model_name")
138-
default_args.pop("model_type")
139+
row_data = models_df.filter(models_df["model_name"] == model_name)
140+
default_args = row_data.to_dicts()[0]
141+
default_args.pop("model_name", None)
142+
default_args.pop("model_type", None)
139143
return default_args
140144

141145

@@ -147,9 +151,9 @@ def get_latest_metric(log_lines: list[str]) -> dict | str:
147151
for line in reversed(log_lines):
148152
if "Avg prompt throughput" in line:
149153
# Parse the metric values from the line
150-
metrics = line.split("] ")[1].strip().strip(".")
151-
metrics = metrics.split(", ")
152-
for metric in metrics:
154+
metrics_str = line.split("] ")[1].strip().strip(".")
155+
metrics_list = metrics_str.split(", ")
156+
for metric in metrics_list:
153157
key, value = metric.split(": ")
154158
latest_metric[key] = value
155159
break

0 commit comments

Comments
 (0)