
Commit 53352d0

llama-server: add router multi-model tests (#17704)
Add 4 test cases for the model router:

- test_router_unload_model: explicit model unloading
- test_router_models_max_evicts_lru: LRU eviction with --models-max
- test_router_no_models_autoload: --no-models-autoload flag behavior
- test_router_api_key_required: API key authentication

The tests use async model loading with polling, and skip gracefully when too few models are available to exercise eviction.

utils.py changes:

- Add models_max, models_dir, no_models_autoload attributes to ServerProcess
- Handle JSONDecodeError for non-JSON error responses (fall back to the raw response text)
1 parent b3e3060 commit 53352d0
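
For orientation before the diffs: the three new ServerProcess attributes map one-to-one onto llama-server flags (see the utils.py hunks below). A minimal sketch of how a router test might set them up; ServerPreset.tinyllama2() is an assumed preset name borrowed from the existing test suite, not part of this commit:

    # hypothetical test setup; only the three attributes are from this commit
    server = ServerPreset.tinyllama2()   # assumed preset from the existing suite
    server.models_dir = "./models"       # forwarded as: --models-dir ./models
    server.models_max = 2                # forwarded as: --models-max 2
    server.no_models_autoload = True     # forwarded as: --no-models-autoload
    server.start()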

File tree

2 files changed: +166 −1

- tools/server/tests/unit/test_router.py
- tools/server/tests/utils.py


tools/server/tests/unit/test_router.py

Lines changed: 149 additions & 0 deletions
@@ -48,3 +48,152 @@ def test_router_chat_completion_stream(model: str, success: bool):
     else:
         assert ex is not None
         assert content == ""
+
+
+def _get_model_status(model_id: str) -> str:
+    res = server.make_request("GET", "/models")
+    assert res.status_code == 200
+    for item in res.body.get("data", []):
+        if item.get("id") == model_id or item.get("model") == model_id:
+            return item["status"]["value"]
+    raise AssertionError(f"Model {model_id} not found in /models response")
+
+
+def _wait_for_model_status(model_id: str, desired: set[str], timeout: int = 60) -> str:
+    deadline = time.time() + timeout
+    last_status = None
+    while time.time() < deadline:
+        last_status = _get_model_status(model_id)
+        if last_status in desired:
+            return last_status
+        time.sleep(1)
+    raise AssertionError(
+        f"Timed out waiting for {model_id} to reach {desired}, last status: {last_status}"
+    )
+
+
+def _load_model_and_wait(
+    model_id: str, timeout: int = 60, headers: dict | None = None
+) -> None:
+    load_res = server.make_request(
+        "POST", "/models/load", data={"model": model_id}, headers=headers
+    )
+    assert load_res.status_code == 200
+    assert isinstance(load_res.body, dict)
+    assert load_res.body.get("success") is True
+    _wait_for_model_status(model_id, {"loaded"}, timeout=timeout)
+
+
+def test_router_unload_model():
+    global server
+    server.start()
+    model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
+
+    _load_model_and_wait(model_id)
+
+    unload_res = server.make_request("POST", "/models/unload", data={"model": model_id})
+    assert unload_res.status_code == 200
+    assert unload_res.body.get("success") is True
+    _wait_for_model_status(model_id, {"unloaded"})
+
+
+def test_router_models_max_evicts_lru():
+    global server
+    server.models_max = 2
+    server.start()
+
+    candidate_models = [
+        "ggml-org/tinygemma3-GGUF:Q8_0",
+        "ggml-org/models/tinyllamas/stories260K.gguf",
+        "ggml-org/models/bert-bge-small/ggml-model-f16.gguf",
+    ]
+
+    loaded_models: list[str] = []
+    for model_id in candidate_models:
+        try:
+            _load_model_and_wait(model_id, timeout=120)
+            loaded_models.append(model_id)
+        except AssertionError:
+            continue
+
+    if len(loaded_models) < 3:
+        pytest.skip("Not enough models could be loaded to exercise eviction")
+
+    first, second, third = loaded_models[:3]
+
+    _wait_for_model_status(first, {"loaded"})
+    _wait_for_model_status(second, {"loaded"})
+
+    _load_model_and_wait(third, timeout=120)
+
+    assert _get_model_status(third) == "loaded"
+    assert _get_model_status(first) != "loaded"
+
+
+def test_router_no_models_autoload():
+    global server
+    server.no_models_autoload = True
+    server.start()
+    model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
+
+    res = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert res.status_code == 400
+    assert "error" in res.body
+
+    _load_model_and_wait(model_id)
+
+    success_res = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert success_res.status_code == 200
+    assert "error" not in success_res.body
+
+
+def test_router_api_key_required():
+    global server
+    server.api_key = "sk-router-secret"
+    server.start()
+
+    model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
+    auth_headers = {"Authorization": f"Bearer {server.api_key}"}
+
+    res = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert res.status_code == 401
+    assert res.body.get("error", {}).get("type") == "authentication_error"
+
+    _load_model_and_wait(model_id, headers=auth_headers)
+
+    authed = server.make_request(
+        "POST",
+        "/v1/chat/completions",
+        headers=auth_headers,
+        data={
+            "model": model_id,
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 4,
+        },
+    )
+    assert authed.status_code == 200
+    assert "error" not in authed.body
tools/server/tests/utils.py

Lines changed: 17 additions & 1 deletion
@@ -7,6 +7,7 @@
 import os
 import re
 import json
+from json import JSONDecodeError
 import sys
 import requests
 import time
@@ -83,6 +84,9 @@ class ServerProcess:
     pooling: str | None = None
     draft: int | None = None
     api_key: str | None = None
+    models_dir: str | None = None
+    models_max: int | None = None
+    no_models_autoload: bool | None = None
     lora_files: List[str] | None = None
     enable_ctx_shift: int | None = False
     draft_min: int | None = None
@@ -143,6 +147,10 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
             server_args.extend(["--hf-repo", self.model_hf_repo])
         if self.model_hf_file:
             server_args.extend(["--hf-file", self.model_hf_file])
+        if self.models_dir:
+            server_args.extend(["--models-dir", self.models_dir])
+        if self.models_max is not None:
+            server_args.extend(["--models-max", self.models_max])
         if self.n_batch:
             server_args.extend(["--batch-size", self.n_batch])
         if self.n_ubatch:
@@ -204,6 +212,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
             server_args.extend(["--draft-min", self.draft_min])
         if self.no_webui:
             server_args.append("--no-webui")
+        if self.no_models_autoload:
+            server_args.append("--no-models-autoload")
         if self.jinja:
             server_args.append("--jinja")
         else:
@@ -295,7 +305,13 @@ def make_request(
         result = ServerResponse()
         result.headers = dict(response.headers)
         result.status_code = response.status_code
-        result.body = response.json() if parse_body else None
+        if parse_body:
+            try:
+                result.body = response.json()
+            except JSONDecodeError:
+                result.body = response.text
+        else:
+            result.body = None
         print("Response from server", json.dumps(result.body, indent=2))
         return result
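
One consequence of the JSONDecodeError fallback: result.body can now be either a parsed dict/list or a plain string, so callers that previously assumed JSON should branch on the type. A small defensive sketch (the endpoint is only an example):

    res = server.make_request("GET", "/models")
    if isinstance(res.body, (dict, list)):
        data = res.body  # parsed JSON, as before
    else:
        # non-JSON body (e.g. a plain-text error page) now arrives as a str
        print("raw response text:", res.body)

The trailing print() in make_request keeps working either way, since json.dumps serializes a bare string too.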
