diff --git a/tools/server/tests/unit/test_basic.py b/tools/server/tests/unit/test_basic.py index cadaa91849f..3405be3e25d 100644 --- a/tools/server/tests/unit/test_basic.py +++ b/tools/server/tests/unit/test_basic.py @@ -65,6 +65,7 @@ def test_server_slots(): def test_load_split_model(): global server + server.offline = False server.model_hf_repo = "ggml-org/models" server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf" server.model_alias = "tinyllama-split" diff --git a/tools/server/tests/unit/test_router.py b/tools/server/tests/unit/test_router.py index e6f3c6485c0..e85f2c33829 100644 --- a/tools/server/tests/unit/test_router.py +++ b/tools/server/tests/unit/test_router.py @@ -17,7 +17,6 @@ def create_server(): ] ) def test_router_chat_completion_stream(model: str, success: bool): - # TODO: make sure the model is in cache (ie. ServerProcess.load_all()) before starting the router server global server server.start() content = "" @@ -48,3 +47,148 @@ def test_router_chat_completion_stream(model: str, success: bool): else: assert ex is not None assert content == "" + + +def _get_model_status(model_id: str) -> str: + res = server.make_request("GET", "/models") + assert res.status_code == 200 + for item in res.body.get("data", []): + if item.get("id") == model_id or item.get("model") == model_id: + return item["status"]["value"] + raise AssertionError(f"Model {model_id} not found in /models response") + + +def _wait_for_model_status(model_id: str, desired: set[str], timeout: int = 60) -> str: + deadline = time.time() + timeout + last_status = None + while time.time() < deadline: + last_status = _get_model_status(model_id) + if last_status in desired: + return last_status + time.sleep(1) + raise AssertionError( + f"Timed out waiting for {model_id} to reach {desired}, last status: {last_status}" + ) + + +def _load_model_and_wait( + model_id: str, timeout: int = 60, headers: dict | None = None +) -> None: + load_res = server.make_request( + "POST", "/models/load", data={"model": model_id}, headers=headers + ) + assert load_res.status_code == 200 + assert isinstance(load_res.body, dict) + assert load_res.body.get("success") is True + _wait_for_model_status(model_id, {"loaded"}, timeout=timeout) + + +def test_router_unload_model(): + global server + server.start() + model_id = "ggml-org/tinygemma3-GGUF:Q8_0" + + _load_model_and_wait(model_id) + + unload_res = server.make_request("POST", "/models/unload", data={"model": model_id}) + assert unload_res.status_code == 200 + assert unload_res.body.get("success") is True + _wait_for_model_status(model_id, {"unloaded"}) + + +def test_router_models_max_evicts_lru(): + global server + server.models_max = 2 + server.start() + + candidate_models = [ + "ggml-org/tinygemma3-GGUF:Q8_0", + "ggml-org/test-model-stories260K", + "ggml-org/test-model-stories260K-infill", + ] + + # Load only the first 2 models to fill the cache + first, second, third = candidate_models[:3] + + _load_model_and_wait(first, timeout=120) + _load_model_and_wait(second, timeout=120) + + # Verify both models are loaded + assert _get_model_status(first) == "loaded" + assert _get_model_status(second) == "loaded" + + # Load the third model - this should trigger LRU eviction of the first model + _load_model_and_wait(third, timeout=120) + + # Verify eviction: third is loaded, first was evicted + assert _get_model_status(third) == "loaded" + assert _get_model_status(first) == "unloaded" + + +def test_router_no_models_autoload(): + global server + server.no_models_autoload = True + server.start() + model_id = "ggml-org/tinygemma3-GGUF:Q8_0" + + res = server.make_request( + "POST", + "/v1/chat/completions", + data={ + "model": model_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 4, + }, + ) + assert res.status_code == 400 + assert "error" in res.body + + _load_model_and_wait(model_id) + + success_res = server.make_request( + "POST", + "/v1/chat/completions", + data={ + "model": model_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 4, + }, + ) + assert success_res.status_code == 200 + assert "error" not in success_res.body + + +def test_router_api_key_required(): + global server + server.api_key = "sk-router-secret" + server.start() + + model_id = "ggml-org/tinygemma3-GGUF:Q8_0" + auth_headers = {"Authorization": f"Bearer {server.api_key}"} + + res = server.make_request( + "POST", + "/v1/chat/completions", + data={ + "model": model_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 4, + }, + ) + assert res.status_code == 401 + assert res.body.get("error", {}).get("type") == "authentication_error" + + _load_model_and_wait(model_id, headers=auth_headers) + + authed = server.make_request( + "POST", + "/v1/chat/completions", + headers=auth_headers, + data={ + "model": model_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 4, + }, + ) + assert authed.status_code == 200 + assert "error" not in authed.body diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index dfd2c8a260a..48e7403602f 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -7,6 +7,7 @@ import os import re import json +from json import JSONDecodeError import sys import requests import time @@ -83,6 +84,9 @@ class ServerProcess: pooling: str | None = None draft: int | None = None api_key: str | None = None + models_dir: str | None = None + models_max: int | None = None + no_models_autoload: bool | None = None lora_files: List[str] | None = None enable_ctx_shift: int | None = False draft_min: int | None = None @@ -143,6 +147,10 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: server_args.extend(["--hf-repo", self.model_hf_repo]) if self.model_hf_file: server_args.extend(["--hf-file", self.model_hf_file]) + if self.models_dir: + server_args.extend(["--models-dir", self.models_dir]) + if self.models_max is not None: + server_args.extend(["--models-max", self.models_max]) if self.n_batch: server_args.extend(["--batch-size", self.n_batch]) if self.n_ubatch: @@ -204,6 +212,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: server_args.extend(["--draft-min", self.draft_min]) if self.no_webui: server_args.append("--no-webui") + if self.no_models_autoload: + server_args.append("--no-models-autoload") if self.jinja: server_args.append("--jinja") else: @@ -295,7 +305,13 @@ def make_request( result = ServerResponse() result.headers = dict(response.headers) result.status_code = response.status_code - result.body = response.json() if parse_body else None + if parse_body: + try: + result.body = response.json() + except JSONDecodeError: + result.body = response.text + else: + result.body = None print("Response from server", json.dumps(result.body, indent=2)) return result @@ -434,8 +450,9 @@ def load_all() -> None: @staticmethod def tinyllama2() -> ServerProcess: server = ServerProcess() - server.model_hf_repo = "ggml-org/models" - server.model_hf_file = "tinyllamas/stories260K.gguf" + server.offline = True # will be downloaded by load_all() + server.model_hf_repo = "ggml-org/test-model-stories260K" + server.model_hf_file = None server.model_alias = "tinyllama-2" server.n_ctx = 512 server.n_batch = 32 @@ -479,8 +496,8 @@ def bert_bge_small_with_fa() -> ServerProcess: def tinyllama_infill() -> ServerProcess: server = ServerProcess() server.offline = True # will be downloaded by load_all() - server.model_hf_repo = "ggml-org/models" - server.model_hf_file = "tinyllamas/stories260K-infill.gguf" + server.model_hf_repo = "ggml-org/test-model-stories260K-infill" + server.model_hf_file = None server.model_alias = "tinyllama-infill" server.n_ctx = 2048 server.n_batch = 1024 @@ -537,6 +554,7 @@ def tinygemma3() -> ServerProcess: @staticmethod def router() -> ServerProcess: server = ServerProcess() + server.offline = True # will be downloaded by load_all() # router server has no models server.model_file = None server.model_alias = None