Skip to content

Commit 72087da

Browse files
committed
Refactor for mypy
1 parent dec8065 commit 72087da

File tree

2 files changed

+99
-56
lines changed

2 files changed

+99
-56
lines changed

vec_inf/cli/_helper.py

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,43 @@ def __init__(self, model_name: str, params: dict[str, Any]):
3636
self.model_name = model_name
3737
self.params = params
3838

39+
def _add_resource_allocation_details(self, table: Table) -> None:
    """Append one row per optional resource allocation setting that is set.

    Settings are emitted in a fixed order (account, working directory,
    resource type, partition, QoS). A setting that is missing from
    ``self.params`` or whose value is falsy produces no row.

    Parameters
    ----------
    table : Table
        Table that receives the rows; it is modified in place.
    """
    # Insertion order of the mapping defines the row order in the table.
    labels_by_param = {
        "account": "Account",
        "work_dir": "Working Directory",
        "resource_type": "Resource Type",
        "partition": "Partition",
        "qos": "QoS",
    }
    for param_name, row_label in labels_by_param.items():
        setting = self.params.get(param_name)
        if not setting:
            continue
        table.add_row(row_label, setting)
51+
52+
def _add_vllm_config(self, table: Table) -> None:
    """Append the vLLM argument section to the table.

    Nothing is added when ``self.params`` has no ``"vllm_args"`` entry or
    the entry is empty. Otherwise a magenta header row is added, followed
    by one row per argument with its value converted to ``str``.

    Parameters
    ----------
    table : Table
        Table that receives the rows; it is modified in place.
    """
    vllm_args = self.params.get("vllm_args")
    if not vllm_args:
        return

    table.add_row("vLLM Arguments:", style="magenta")
    for flag, setting in vllm_args.items():
        table.add_row(f" {flag}:", str(setting))
58+
59+
def _add_env_vars(self, table: Table) -> None:
    """Append the environment variable section to the table.

    Nothing is added when ``self.params`` has no ``"env"`` entry or the
    entry is empty. Otherwise a magenta header row is added, followed by
    one row per variable with its value converted to ``str``.

    Parameters
    ----------
    table : Table
        Table that receives the rows; it is modified in place.
    """
    env_vars = self.params.get("env")
    if not env_vars:
        return

    table.add_row("Environment Variables", style="magenta")
    detail_rows = [(f" {name}:", str(val)) for name, val in env_vars.items()]
    for name_cell, value_cell in detail_rows:
        table.add_row(name_cell, value_cell)
65+
66+
def _add_bind_paths(self, table: Table) -> None:
    """Append the bind path section to the table.

    ``self.params["bind"]`` is read as a comma-separated string of bind
    entries. Each entry is either a single path (shown as both host and
    target) or ``host:target``. Nothing is added when the ``"bind"`` entry
    is missing or empty.

    Parameters
    ----------
    table : Table
        Table that receives the rows; it is modified in place.
    """
    bind_spec = self.params.get("bind")
    if not bind_spec:
        return

    table.add_row("Bind Paths", style="magenta")
    for entry in bind_spec.split(","):
        if not entry:
            # A stray or trailing comma yields an empty entry; there is
            # nothing meaningful to display for it.
            continue
        # Split on the FIRST colon only. Unpacking entry.split(":") raised
        # ValueError for entries with more than one colon (for example
        # "src:dest:opts"); anything after the first colon is now shown
        # as the target.
        host, colon, target = entry.partition(":")
        if not colon:
            # No explicit target: the path is bound to the same location.
            target = host
        table.add_row(f" {host}:", target)
75+
3976
def format_table_output(self) -> Table:
4077
"""Format output as rich Table.
4178
@@ -59,16 +96,7 @@ def format_table_output(self) -> Table:
5996
table.add_row("Vocabulary Size", self.params["vocab_size"])
6097

6198
# Add resource allocation details
62-
if self.params.get("account"):
63-
table.add_row("Account", self.params["account"])
64-
if self.params.get("work_dir"):
65-
table.add_row("Working Directory", self.params["work_dir"])
66-
if self.params.get("resource_type"):
67-
table.add_row("Resource Type", self.params["resource_type"])
68-
if self.params.get("partition"):
69-
table.add_row("Partition", self.params["partition"])
70-
if self.params.get("qos"):
71-
table.add_row("QoS", self.params["qos"])
99+
self._add_resource_allocation_details(table)
72100
table.add_row("Time Limit", self.params["time"])
73101
table.add_row("Num Nodes", self.params["num_nodes"])
74102
table.add_row("GPUs/Node", self.params["gpus_per_node"])
@@ -84,26 +112,10 @@ def format_table_output(self) -> Table:
84112
)
85113
table.add_row("Log Directory", self.params["log_dir"])
86114

87-
# Add vLLM configuration details
88-
if self.params.get("vllm_args"):
89-
table.add_row("vLLM Arguments:", style="magenta")
90-
for arg, value in self.params["vllm_args"].items():
91-
table.add_row(f" {arg}:", str(value))
92-
93-
# Add environment variable configuration details
94-
if self.params.get("env"):
95-
table.add_row("Environment Variables", style="magenta")
96-
for arg, value in self.params["env"].items():
97-
table.add_row(f" {arg}:", str(value))
98-
99-
# Add bind path configuration details
100-
if self.params.get("bind"):
101-
table.add_row("Bind Paths", style="magenta")
102-
for path in self.params["bind"].split(","):
103-
host = target = path
104-
if ":" in path:
105-
host, target = path.split(":")
106-
table.add_row(f" {host}:", target)
115+
# Add configuration details
116+
self._add_vllm_config(table)
117+
self._add_env_vars(table)
118+
self._add_bind_paths(table)
107119

108120
return table
109121

vec_inf/client/_helper.py

Lines changed: 57 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -196,23 +196,14 @@ def _process_env_vars(self, env_arg: str) -> dict[str, str]:
196196
print(f"WARNING: Could not parse env var: {line}")
197197
return env_vars
198198

199-
def _get_launch_params(self) -> dict[str, Any]:
200-
"""Prepare launch parameters, set log dir, and validate required fields.
201-
202-
Returns
203-
-------
204-
dict[str, Any]
205-
Dictionary of prepared launch parameters
199+
def _apply_cli_overrides(self, params: dict[str, Any]) -> None:
200+
"""Apply CLI argument overrides to params.
206201
207-
Raises
208-
------
209-
MissingRequiredFieldsError
210-
If required fields are missing or tensor parallel size is not specified
211-
when using multiple GPUs
202+
Parameters
203+
----------
204+
params : dict[str, Any]
205+
Dictionary of launch parameters to override
212206
"""
213-
params = self.model_config.model_dump(exclude_none=True)
214-
215-
# Override config defaults with CLI arguments
216207
if self.kwargs.get("vllm_args"):
217208
vllm_args = self._process_vllm_args(self.kwargs["vllm_args"])
218209
for key, value in vllm_args.items():
@@ -232,10 +223,22 @@ def _get_launch_params(self) -> dict[str, Any]:
232223
for key, value in self.kwargs.items():
233224
params[key] = value
234225

235-
# Check for required fields without default vals, will raise an error if missing
236-
utils.check_required_fields(params)
226+
def _validate_resource_allocation(self, params: dict[str, Any]) -> None:
227+
"""Validate resource allocation and parallelization settings.
237228
238-
# Validate resource allocation and parallelization settings
229+
Parameters
230+
----------
231+
params : dict[str, Any]
232+
Dictionary of launch parameters to validate
233+
234+
Raises
235+
------
236+
MissingRequiredFieldsError
237+
If tensor parallel size is not specified when using multiple GPUs
238+
ValueError
239+
If total # of GPUs requested is not a power of two
240+
If mismatch between total # of GPUs requested and parallelization settings
241+
"""
239242
if (
240243
int(params["gpus_per_node"]) > 1
241244
and params["vllm_args"].get("--tensor-parallel-size") is None
@@ -256,19 +259,18 @@ def _get_launch_params(self) -> dict[str, Any]:
256259
"Mismatch between total number of GPUs requested and parallelization settings"
257260
)
258261

259-
# Convert gpus_per_node and resource_type to gres
260-
resource_type = params.get("resource_type")
261-
if resource_type:
262-
params["gres"] = f"gpu:{resource_type}:{params['gpus_per_node']}"
263-
else:
264-
params["gres"] = f"gpu:{params['gpus_per_node']}"
262+
def _setup_log_files(self, params: dict[str, Any]) -> None:
263+
"""Set up log directory and file paths.
265264
266-
# Create log directory
265+
Parameters
266+
----------
267+
params : dict[str, Any]
268+
Dictionary of launch parameters to set up log files
269+
"""
267270
params["log_dir"] = Path(params["log_dir"], params["model_family"]).expanduser()
268271
params["log_dir"].mkdir(parents=True, exist_ok=True)
269272
params["src_dir"] = SRC_DIR
270273

271-
# Construct slurm log file paths
272274
params["out_file"] = (
273275
f"{params['log_dir']}/{self.model_name}.%j/{self.model_name}.%j.out"
274276
)
@@ -279,6 +281,35 @@ def _get_launch_params(self) -> dict[str, Any]:
279281
f"{params['log_dir']}/{self.model_name}.$SLURM_JOB_ID/{self.model_name}.$SLURM_JOB_ID.json"
280282
)
281283

284+
def _get_launch_params(self) -> dict[str, Any]:
285+
"""Prepare launch parameters, set log dir, and validate required fields.
286+
287+
Returns
288+
-------
289+
dict[str, Any]
290+
Dictionary of prepared launch parameters
291+
"""
292+
params = self.model_config.model_dump(exclude_none=True)
293+
294+
# Override config defaults with CLI arguments
295+
self._apply_cli_overrides(params)
296+
297+
# Check for required fields without default vals, will raise an error if missing
298+
utils.check_required_fields(params)
299+
300+
# Validate resource allocation and parallelization settings
301+
self._validate_resource_allocation(params)
302+
303+
# Convert gpus_per_node and resource_type to gres
304+
resource_type = params.get("resource_type")
305+
if resource_type:
306+
params["gres"] = f"gpu:{resource_type}:{params['gpus_per_node']}"
307+
else:
308+
params["gres"] = f"gpu:{params['gpus_per_node']}"
309+
310+
# Setup log files
311+
self._setup_log_files(params)
312+
282313
# Convert path to string for JSON serialization
283314
for field in params:
284315
if field in ["vllm_args", "env"]:

0 commit comments

Comments
 (0)