
Commit 1458b60

Add compilation config to model config, remove max num batched tokens, and change boolean options to flags
1 parent: 8f45869

3 files changed, +44 -46 lines

vec_inf/cli/_cli.py (10 additions, 10 deletions)

@@ -41,13 +41,13 @@ def cli() -> None:
 )
 @click.option(
     "--enable-prefix-caching",
-    type=click.Choice(["True", "False"]),
-    help="Enables automatic prefix caching, accepts 'True' or 'False', default to 'False'",
+    is_flag=True,
+    help="Enables automatic prefix caching",
 )
 @click.option(
     "--enable-chunked-prefill",
-    type=click.Choice(["True", "False"]),
-    help="Enable chunked prefill, accepts 'True' or 'False', default to 'True' if max-num-seqs > 32k, else 'False'",
+    is_flag=True,
+    help="Enable chunked prefill, enabled by default if max number of sequences > 32k",
 )
 @click.option(
     "--max-num-batched-tokens",
@@ -102,18 +102,18 @@ def cli() -> None:
 )
 @click.option(
     "--pipeline-parallelism",
-    type=str,
-    help="Enable pipeline parallelism, accepts 'True' or 'False', default to 'True' for supported models",
+    is_flag=True,
+    help="Enable pipeline parallelism, enabled by default for supported models",
 )
 @click.option(
     "--compilation-config",
-    type=click.Choice(["0", "3"]),
-    help="torch.compile optimization level, accepts '0' or '3', default to '0', which means no optimization is applied",
+    type=click.Choice(["0", "1", "2", "3"]),
+    help="torch.compile optimization level, accepts '0', '1', '2', or '3', default to '0', which means no optimization is applied",
 )
 @click.option(
     "--enforce-eager",
-    type=str,
-    help="Always use eager-mode PyTorch, accepts 'True' or 'False', default to 'False' for custom models if not set",
+    is_flag=True,
+    help="Always use eager-mode PyTorch",
 )
 @click.option(
     "--json-mode",

vec_inf/cli/_config.py (5 additions, 5 deletions)

@@ -47,11 +47,11 @@ class ModelConfig(BaseModel):
     max_num_seqs: int = Field(
         default=256, gt=0, le=1024, description="Maximum concurrent request sequences"
     )
-    max_num_batched_tokens: Optional[int] = Field(
-        default=None,
-        gt=0,
-        le=1_000_000,
-        description="Maximum batched tokens per iteration",
+    compilation_config: int = Field(
+        default=0,
+        gt=-1,
+        le=4,
+        description="torch.compile optimization level",
     )
     gpu_memory_utilization: float = Field(
         default=0.9, gt=0.0, le=1.0, description="GPU memory utilization"
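
A quick validation check on the new field: Pydantic enforces the numeric bounds at construction time. Note that the committed bounds (gt=-1, le=4) admit the value 4 even though the CLI only offers choices 0 through 3; the sketch below (with an illustrative class name) assumes the stricter range was intended:

from pydantic import BaseModel, Field, ValidationError


class CompileSettings(BaseModel):
    # Assumed bounds matching click.Choice(["0", "1", "2", "3"]) on the CLI side
    compilation_config: int = Field(
        default=0, ge=0, le=3, description="torch.compile optimization level"
    )


print(CompileSettings().compilation_config)                      # 0 (default)
print(CompileSettings(compilation_config=3).compilation_config)  # 3

try:
    CompileSettings(compilation_config=4)  # outside the assumed 0-3 range
except ValidationError as exc:
    print(exc)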

vec_inf/cli/_helper.py (29 additions, 31 deletions)

@@ -34,6 +34,13 @@
     "max_model_len",
 }

+BOOLEAN_FIELDS = {
+    "pipeline_parallelism",
+    "enforce_eager",
+    "enable_prefix_caching",
+    "enable_chunked_prefill",
+}
+
 LD_LIBRARY_PATH = "/scratch/ssd001/pkgs/cudnn-11.7-v8.5.0.96/lib/:/scratch/ssd001/pkgs/cuda-11.7/targets/x86_64-linux/lib/"
 SRC_DIR = str(Path(__file__).parent.parent)
@@ -90,36 +97,18 @@ def _get_launch_params(self) -> dict[str, Any]:
         params = self.model_config.model_dump()

         # Process boolean fields
-        for bool_field in [
-            "pipeline_parallelism",
-            "enforce_eager",
-            "enable_prefix_caching",
-            "enable_chunked_prefill",
-        ]:
-            if (value := self.cli_kwargs.get(bool_field)) is not None:
-                params[bool_field] = utils.convert_boolean_value(value)
+        for bool_field in BOOLEAN_FIELDS:
+            if self.cli_kwargs[bool_field]:
+                params[bool_field] = True

         # Merge other overrides
         for key, value in self.cli_kwargs.items():
             if value is not None and key not in [
                 "json_mode",
-                "pipeline_parallelism",
-                "enforce_eager",
-                "enable_prefix_caching",
-                "enable_chunked_prefill",
+                *BOOLEAN_FIELDS,
             ]:
                 params[key] = value

-        if "compilation_config" not in params:
-            params["compilation_config"] = "0"
-        if "enable_prefix_caching" not in params:
-            params["enable_prefix_caching"] = False
-        if "enable_chunked_prefill" not in params:
-            params["enable_chunked_prefill"] = False
-
-        if params["max_model_len"] > 32_000:  # this is the default behavior of vLLM
-            params["enable_chunked_prefill"] = True
-
         # Validate required fields
         if not REQUIRED_FIELDS.issubset(set(params.keys())):
             raise click.ClickException(
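
The reworked merge treats flags as presence-only overrides: a Click flag is always present in the kwargs (False when omitted), so only a truthy flag should override the config, while every other non-None CLI value overrides directly. A self-contained sketch of that behavior, with an illustrative function name:

from typing import Any

BOOLEAN_FIELDS = {
    "pipeline_parallelism",
    "enforce_eager",
    "enable_prefix_caching",
    "enable_chunked_prefill",
}


def merge_params(config: dict[str, Any], cli_kwargs: dict[str, Any]) -> dict[str, Any]:
    params = dict(config)
    # Flags default to False, so only an explicitly passed flag overrides the config.
    for bool_field in BOOLEAN_FIELDS:
        if cli_kwargs.get(bool_field):
            params[bool_field] = True
    # Any other non-None CLI value wins over the config default.
    for key, value in cli_kwargs.items():
        if value is not None and key not in {"json_mode", *BOOLEAN_FIELDS}:
            params[key] = value
    return params


merged = merge_params(
    {"enforce_eager": False, "max_num_seqs": 256},
    {"enforce_eager": True, "max_num_seqs": 512, "json_mode": None},
)
print(merged)  # {'enforce_eager': True, 'max_num_seqs': 512}
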
@@ -146,11 +135,7 @@ def set_env_vars(self) -> None:
         os.environ["GPU_MEMORY_UTILIZATION"] = self.params["gpu_memory_utilization"]
         os.environ["TASK"] = VLLM_TASK_MAP[self.params["model_type"]]
         os.environ["PIPELINE_PARALLELISM"] = self.params["pipeline_parallelism"]
-        os.environ["ENABLE_PREFIX_CACHING"] = self.params["enable_prefix_caching"]
-        os.environ["ENABLE_CHUNKED_PREFILL"] = self.params["enable_chunked_prefill"]
-        os.environ["MAX_NUM_BATCHED_TOKENS"] = self.params["max_num_batched_tokens"]
         os.environ["COMPILATION_CONFIG"] = self.params["compilation_config"]
-        os.environ["ENFORCE_EAGER"] = self.params["enforce_eager"]
         os.environ["SRC_DIR"] = SRC_DIR
         os.environ["MODEL_WEIGHTS"] = str(
             Path(self.params["model_weights_parent_dir"], self.model_name)
@@ -159,6 +144,15 @@ def set_env_vars(self) -> None:
         os.environ["VENV_BASE"] = self.params["venv"]
         os.environ["LOG_DIR"] = self.params["log_dir"]

+        if self.params.get("enable_prefix_caching"):
+            os.environ["ENABLE_PREFIX_CACHING"] = self.params["enable_prefix_caching"]
+        if self.params.get("enable_chunked_prefill"):
+            os.environ["ENABLE_CHUNKED_PREFILL"] = self.params["enable_chunked_prefill"]
+        if self.params.get("max_num_batched_tokens"):
+            os.environ["MAX_NUM_BATCHED_TOKENS"] = self.params["max_num_batched_tokens"]
+        if self.params.get("enforce_eager"):
+            os.environ["ENFORCE_EAGER"] = self.params["enforce_eager"]
+
     def build_launch_command(self) -> str:
         """Construct the full launch command with parameters."""
         # Base command
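
One caveat with the conditional exports: os.environ only accepts string values, so if these params hold booleans (as the new flag handling sets them), assigning them directly raises TypeError. A hedged sketch of the same conditional-export pattern with an explicit str() conversion at the boundary (names are illustrative):

import os
from typing import Any


def export_optional(params: dict[str, Any]) -> None:
    # Map env var names to param keys; only truthy params are exported,
    # and values are stringified because os.environ requires str.
    optional = {
        "ENABLE_PREFIX_CACHING": "enable_prefix_caching",
        "ENABLE_CHUNKED_PREFILL": "enable_chunked_prefill",
        "MAX_NUM_BATCHED_TOKENS": "max_num_batched_tokens",
        "ENFORCE_EAGER": "enforce_eager",
    }
    for env_name, key in optional.items():
        if params.get(key):
            os.environ[env_name] = str(params[key])


export_optional({"enable_prefix_caching": True, "max_num_batched_tokens": 2048})
print(os.environ.get("ENABLE_PREFIX_CACHING"))   # "True"
print(os.environ.get("ENABLE_CHUNKED_PREFILL"))  # None (flag was not set)
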
@@ -206,12 +200,16 @@ def format_table_output(self, job_id: str) -> Table:
         table.add_row("Max Model Length", self.params["max_model_len"])
         table.add_row("Max Num Seqs", self.params["max_num_seqs"])
         table.add_row("GPU Memory Utilization", self.params["gpu_memory_utilization"])
-        table.add_row("Pipeline Parallelism", self.params["pipeline_parallelism"])
-        table.add_row("Enable Prefix Caching", self.params["enable_prefix_caching"])
-        table.add_row("Enable Chunked Prefill", self.params["enable_chunked_prefill"])
-        table.add_row("Max Num Batched Tokens", self.params["max_num_batched_tokens"])
         table.add_row("Compilation Config", self.params["compilation_config"])
-        table.add_row("Enforce Eager", self.params["enforce_eager"])
+        table.add_row("Pipeline Parallelism", self.params["pipeline_parallelism"])
+        if self.params.get("enable_prefix_caching"):
+            table.add_row("Enable Prefix Caching", self.params["enable_prefix_caching"])
+        if self.params.get("enable_chunked_prefill"):
+            table.add_row("Enable Chunked Prefill", self.params["enable_chunked_prefill"])
+        if self.params.get("max_num_batched_tokens"):
+            table.add_row("Max Num Batched Tokens", self.params["max_num_batched_tokens"])
+        if self.params.get("enforce_eager"):
+            table.add_row("Enforce Eager", self.params["enforce_eager"])
         table.add_row("Model Weights Directory", os.environ.get("MODEL_WEIGHTS"))
         table.add_row("Log Directory", self.params["log_dir"])
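
The status table follows the same pattern: optional rows appear only when the corresponding param is set. A minimal rich sketch (note that Table.add_row expects strings or renderables, so the boolean param is stringified here; the real params dict is assumed to hold strings):

from rich.console import Console
from rich.table import Table

params = {"compilation_config": "3", "enable_prefix_caching": True}

table = Table(show_header=False)
table.add_row("Compilation Config", params["compilation_config"])
# Optional rows: skipped entirely when the param is absent or falsy.
if params.get("enable_prefix_caching"):
    table.add_row("Enable Prefix Caching", str(params["enable_prefix_caching"]))
if params.get("enforce_eager"):
    table.add_row("Enforce Eager", str(params["enforce_eager"]))

Console().print(table)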
