Skip to content

Commit 170c61a

Browse files
authored
Merge branch 'main' into slurm_dependency
2 parents 85eb85a + 019ca54 commit 170c61a

File tree

9 files changed: +44 −52 lines changed

Dockerfile

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
-FROM nvidia/cuda:12.3.1-devel-ubuntu20.04
+FROM nvidia/cuda:12.4.1-devel-ubuntu20.04
 
 # Non-interactive apt-get commands
 ARG DEBIAN_FRONTEND=noninteractive
@@ -41,8 +41,10 @@ COPY . /vec-inf
 
 # Install project dependencies with build requirements
 RUN PIP_INDEX_URL="https://download.pytorch.org/whl/cu121" uv pip install --system -e .[dev]
-# Install Flash Attention
+# Install FlashAttention
 RUN python3.10 -m pip install flash-attn --no-build-isolation
+# Install FlashInfer
+RUN python3.10 -m pip install flashinfer-python -i https://flashinfer.ai/whl/cu124/torch2.6/
 
 # Final configuration
 RUN mkdir -p /vec-inf/nccl && \

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
 ----------------------------------------------------
 
 [![PyPI](https://img.shields.io/pypi/v/vec-inf)](https://pypi.org/project/vec-inf)
-[![downloads](https://img.shields.io/pypi/dm/vec-inf)]
+[![downloads](https://img.shields.io/pypi/dm/vec-inf)](https://pypistats.org/packages/vec-inf)
 [![code checks](https://github.com/VectorInstitute/vector-inference/actions/workflows/code_checks.yml/badge.svg)](https://github.com/VectorInstitute/vector-inference/actions/workflows/code_checks.yml)
 [![docs](https://github.com/VectorInstitute/vector-inference/actions/workflows/docs.yml/badge.svg)](https://github.com/VectorInstitute/vector-inference/actions/workflows/docs.yml)
 [![codecov](https://codecov.io/github/VectorInstitute/vector-inference/branch/main/graph/badge.svg?token=NI88QSIGAC)](https://app.codecov.io/github/VectorInstitute/vector-inference/tree/main)

vec_inf/cli/_cli.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
     MetricsResponseFormatter,
     StatusResponseFormatter,
 )
-from vec_inf.client import LaunchOptions, LaunchOptionsDict, VecInfClient
+from vec_inf.client import LaunchOptions, VecInfClient
 
 
 CONSOLE = Console()
@@ -63,6 +63,11 @@ def cli() -> None:
     type=int,
     help="Number of GPUs/node to use, default to suggested resource allocation for model",
 )
+@click.option(
+    "--account",
+    type=str,
+    help="Charge resources used by this job to specified account.",
+)
 @click.option(
     "--qos",
     type=str,
@@ -142,17 +147,18 @@ def launch(
     """
     try:
         # Convert cli_kwargs to LaunchOptions
-        kwargs = {k: v for k, v in cli_kwargs.items() if k != "json_mode"}
-        # Cast the dictionary to LaunchOptionsDict
-        options_dict: LaunchOptionsDict = kwargs  # type: ignore
-        launch_options = LaunchOptions(**options_dict)
+        json_mode = cli_kwargs["json_mode"]
+        del cli_kwargs["json_mode"]
+
+        launch_options = LaunchOptions(**cli_kwargs)  # type: ignore
 
         # Start the client and launch model inference server
         client = VecInfClient()
         launch_response = client.launch_model(model_name, launch_options)
 
         # Display launch information
         launch_formatter = LaunchResponseFormatter(model_name, launch_response.config)
+
         if cli_kwargs.get("json_mode"):
             click.echo(json.dumps(launch_response.config))
         else:

vec_inf/client/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
 from vec_inf.client.config import ModelConfig
 from vec_inf.client.models import (
     LaunchOptions,
-    LaunchOptionsDict,
     LaunchResponse,
     MetricsResponse,
     ModelInfo,
@@ -28,6 +27,5 @@
     "ModelStatus",
     "ModelType",
     "LaunchOptions",
-    "LaunchOptionsDict",
     "ModelConfig",
 ]

vec_inf/client/_client_vars.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
 SLURM_JOB_CONFIG_ARGS = {
     "job-name": "model_name",
     "partition": "partition",
+    "account": "account",
     "qos": "qos",
     "time": "time",
     "nodes": "num_nodes",
@@ -66,6 +67,13 @@
     "error": "err_file",
 }
 
+# vLLM engine args mapping between short and long names
+VLLM_SHORT_TO_LONG_MAP = {
+    "-tp": "--tensor-parallel-size",
+    "-pp": "--pipeline-parallel-size",
+    "-O": "--compilation-config",
+}
+
 
 # Slurm script templates
 class ShebangConfig(TypedDict):

vec_inf/client/_helper.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
     KEY_METRICS,
     REQUIRED_FIELDS,
     SRC_DIR,
+    VLLM_SHORT_TO_LONG_MAP,
 )
 from vec_inf.client._exceptions import (
     MissingRequiredFieldsError,
@@ -156,9 +157,14 @@ def _process_vllm_args(self, arg_string: str) -> dict[str, Any]:
         for arg in arg_string.split(","):
             if "=" in arg:
                 key, value = arg.split("=")
-                vllm_args[key] = value
+                if key.strip() in VLLM_SHORT_TO_LONG_MAP:
+                    key = VLLM_SHORT_TO_LONG_MAP[key.strip()]
+                vllm_args[key.strip()] = value.strip()
+            elif "-O" in arg.strip():
+                key = VLLM_SHORT_TO_LONG_MAP["-O"]
+                vllm_args[key] = arg.strip()[2:].strip()
             else:
-                vllm_args[arg] = True
+                vllm_args[arg.strip()] = True
         return vllm_args
 
     def _get_launch_params(self) -> dict[str, Any]:
@@ -175,7 +181,7 @@ def _get_launch_params(self) -> dict[str, Any]:
         If required fields are missing or tensor parallel size is not specified
         when using multiple GPUs
         """
-        params = self.model_config.model_dump()
+        params = self.model_config.model_dump(exclude_none=True)
 
         # Override config defaults with CLI arguments
         if self.kwargs.get("vllm_args"):

vec_inf/client/_slurm_script_generator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ def _generate_shebang(self) -> str:
         """
         shebang = [SLURM_SCRIPT_TEMPLATE["shebang"]["base"]]
         for arg, value in SLURM_JOB_CONFIG_ARGS.items():
-            shebang.append(f"#SBATCH --{arg}={self.params[value]}")
+            if self.params.get(value):
+                shebang.append(f"#SBATCH --{arg}={self.params[value]}")
         if self.is_multinode:
             shebang += SLURM_SCRIPT_TEMPLATE["shebang"]["multinode"]
         return "\n".join(shebang)

vec_inf/client/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class ModelConfig(BaseModel):
         Memory allocation per node in GB format (e.g., '32G')
     vocab_size : int
         Size of the model's vocabulary (1-1,000,000)
+    account : Optional[str], optional
+        Charge resources used by this job to specified account.
     qos : Union[QOS, str], optional
         Quality of Service tier for job scheduling
     time : str, optional
@@ -92,6 +94,9 @@ class ModelConfig(BaseModel):
         description="Memory per node",
     )
     vocab_size: int = Field(..., gt=0, le=1_000_000)
+    account: Optional[str] = Field(
+        default=None, description="Account name for job scheduling"
+    )
     qos: Union[QOS, str] = Field(
         default=cast(str, DEFAULT_ARGS["qos"]), description="Quality of Service tier"
     )

vec_inf/client/models.py

Lines changed: 4 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
 
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Optional, TypedDict, Union
+from typing import Any, Optional, Union
 
 
 class ModelStatus(str, Enum):
@@ -164,6 +164,8 @@ class LaunchOptions:
         Number of nodes to allocate
     gpus_per_node : int, optional
         Number of GPUs per node
+    account : str, optional
+        Account name for job scheduling
     qos : str, optional
         Quality of Service level
     time : str, optional
@@ -187,6 +189,7 @@ class LaunchOptions:
     partition: Optional[str] = None
     num_nodes: Optional[int] = None
     gpus_per_node: Optional[int] = None
+    account: Optional[str] = None
     qos: Optional[str] = None
     time: Optional[str] = None
     vocab_size: Optional[int] = None
@@ -197,43 +200,6 @@ class LaunchOptions:
     vllm_args: Optional[str] = None
 
 
-class LaunchOptionsDict(TypedDict):
-    """TypedDict for LaunchOptions.
-
-    A TypedDict representation of LaunchOptions for type checking and
-    serialization purposes. All fields are optional and may be None.
-
-    Attributes
-    ----------
-    model_family : str, optional
-        Family/architecture of the model
-    model_variant : str, optional
-        Specific variant/version of the model
-    partition : str, optional
-        SLURM partition to use
-    num_nodes : int, optional
-        Number of nodes to allocate
-    gpus_per_node : int, optional
-        Number of GPUs per node
-    qos : str, optional
-        Quality of Service level
-    time : str, optional
-        Time limit for the job
-    vocab_size : int, optional
-        Size of model vocabulary
-    data_type : str, optional
-        Data type for model weights
-    venv : str, optional
-        Virtual environment to use
-    log_dir : str, optional
-        Directory for logs
-    model_weights_parent_dir : str, optional
-        Parent directory containing model weights
-    vllm_args : str, optional
-        Additional arguments for vLLM
-    """
-
-
 @dataclass
 class ModelInfo:
     """Information about an available model.

0 commit comments

Comments (0)