diff --git a/jupyter_ai_jupyternaut/models/parameter_schemas.py b/jupyter_ai_jupyternaut/models/parameter_schemas.py index 6efd7b0..89c64c0 100644 --- a/jupyter_ai_jupyternaut/models/parameter_schemas.py +++ b/jupyter_ai_jupyternaut/models/parameter_schemas.py @@ -107,7 +107,7 @@ }, "api_base": { "type": "string", - "description": "Base URL where LLM requests are sent, used for enterprise proxy gateways." + "description": "Base URL where LLM requests are sent, used for local models, enterprise proxies, or other hosting providers (e.g. vLLM)." } } diff --git a/jupyter_ai_jupyternaut/models/parameters_rest_api.py b/jupyter_ai_jupyternaut/models/parameters_rest_api.py index b7d4971..921d8d5 100644 --- a/jupyter_ai_jupyternaut/models/parameters_rest_api.py +++ b/jupyter_ai_jupyternaut/models/parameters_rest_api.py @@ -32,6 +32,9 @@ def get(self): # Temporary common parameters that work across most models common_params = ["temperature", "max_tokens", "top_p", "stop"] + # parameters that are actually inputs to the client, not the model + # always include these (get_supported_openai_params doesn't include them) + client_params = ["api_base"] # Params controlling tool availability & usage require a unique UX # if they are to be made configurable from the frontend. Therefore # they are disabled for now. @@ -49,7 +52,8 @@ def get(self): parameter_names = common_params else: parameter_names = common_params - + # always include client parameters + parameter_names.extend(client_params) # Filter out excluded params parameter_names = [n for n in parameter_names if n not in EXCLUDED_PARAMS] @@ -109,7 +113,7 @@ def put(self): except ValueError as e: raise HTTPError(400, f"Invalid value for parameter '{param_name}': {str(e)}") - config_manager = self.settings.get("jai_config_manager") + config_manager = self.settings.get("jupyternaut.config_manager") if not config_manager: raise HTTPError(500, "Config manager not available") diff --git a/jupyter_ai_jupyternaut/tests/test_handlers.py b/jupyter_ai_jupyternaut/tests/test_handlers.py index 9f22fa5..d999742 100644 --- a/jupyter_ai_jupyternaut/tests/test_handlers.py +++ b/jupyter_ai_jupyternaut/tests/test_handlers.py @@ -1,4 +1,7 @@ import json +import pytest + +from jupyter_ai_jupyternaut.models import parameter_schemas async def test_get_example(jp_fetch): @@ -8,6 +11,56 @@ async def test_get_example(jp_fetch): # Then assert response.code == 200 payload = json.loads(response.body) - assert payload == { - "data": "This is /api/jupyternaut/get-example endpoint!" - } \ No newline at end of file + assert payload == {"data": "This is /api/jupyternaut/get-example endpoint!"} + + +@pytest.mark.parametrize( + "model", + [ + None, + "openai/gpt-oss-120b", + "hosted_vllm/doesntmatter", + "anthropic/claude-3-5-haiku-latest", + ], +) +async def test_get_parameters(jp_fetch, model): + params = {} + if model: + params["model"] = model + response = await jp_fetch("api/jupyternaut/model-parameters", params=params) + assert response.code == 200 + payload = json.loads(response.body) + expected_params = [ + "api_base", + "max_tokens", + "stop", + "temperature", + "top_p", + ] + if model: + expected_params.extend(["max_completion_tokens"]) + if not model.startswith("anthropic/"): + expected_params.extend(["frequency_penalty"]) + + for param in expected_params: + assert param in payload["parameter_names"] + assert param in payload["parameters"] + assert "description" in payload["parameters"][param] + + +async def test_put_params(jp_fetch): + # TODO: validate all types, error handling + response = await jp_fetch( + "api/jupyternaut/model-parameters", + body=json.dumps({ + "model_id": "hosted_vllm/mlx-community/gpt-oss-20b-MXFP4-Q8", + "parameters": { + "api_base": { + "value": "http://127.0.0.1:8080", + "type": "string", + }, + }, + }), + method="PUT", + ) + assert response.code == 200