jupyter_ai_jupyternaut/models/parameter_schemas.py (2 changes: 1 addition & 1 deletion)
@@ -107,7 +107,7 @@
},
"api_base": {
"type": "string",
"description": "Base URL where LLM requests are sent, used for enterprise proxy gateways."
"description": "Base URL where LLM requests are sent, used for local models, enterprise proxies, or other hosting providers (e.g. vLLM)."
}
}

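The broadened `api_base` description above covers local models, enterprise proxies, and other hosting providers. As a minimal sketch (not part of this diff, and assuming the stored value is ultimately forwarded to LiteLLM), an override pointing at a local vLLM-compatible server might look like this:

```python
import litellm

# Hedged sketch: assumes the configured api_base is eventually passed through
# to litellm.completion(); the actual wiring lives outside this excerpt.
response = litellm.completion(
    model="hosted_vllm/mlx-community/gpt-oss-20b-MXFP4-Q8",
    messages=[{"role": "user", "content": "Hello"}],
    api_base="http://127.0.0.1:8080",  # local vLLM server or enterprise proxy endpoint
)
print(response.choices[0].message.content)
```

The model id and URL here mirror the values used in `test_put_params` below.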
jupyter_ai_jupyternaut/models/parameters_rest_api.py (8 changes: 6 additions & 2 deletions)
@@ -32,6 +32,9 @@ def get(self):

# Temporary common parameters that work across most models
common_params = ["temperature", "max_tokens", "top_p", "stop"]
# parameters that are actually inputs to the client, not the model
# always include these (get_supported_openai_params doesn't include them)
client_params = ["api_base"]
# Params controlling tool availability & usage require a unique UX
# if they are to be made configurable from the frontend. Therefore
# they are disabled for now.
@@ -49,7 +52,8 @@
parameter_names = common_params
else:
parameter_names = common_params

# always include client parameters
parameter_names.extend(client_params)
# Filter out excluded params
parameter_names = [n for n in parameter_names if n not in EXCLUDED_PARAMS]
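The hunk above always appends the client parameters because LiteLLM's `get_supported_openai_params` reports model parameters only. A rough sketch of the pattern, with the helper name chosen here purely for illustration (the handler's real logic differs):

```python
from typing import Optional

import litellm

CLIENT_PARAMS = ["api_base"]  # inputs to the client itself, not the model

def collect_parameter_names(model_id: Optional[str]) -> list[str]:
    # Hypothetical helper illustrating the pattern above, not the handler's code.
    names = ["temperature", "max_tokens", "top_p", "stop"]
    if model_id:
        supported = litellm.get_supported_openai_params(model=model_id) or []
        names.extend(p for p in supported if p not in names)
    # Client parameters are never reported by get_supported_openai_params,
    # so they are appended unconditionally.
    names.extend(CLIENT_PARAMS)
    return names
```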

@@ -109,7 +113,7 @@ def put(self):
except ValueError as e:
raise HTTPError(400, f"Invalid value for parameter '{param_name}': {str(e)}")

config_manager = self.settings.get("jai_config_manager")
config_manager = self.settings.get("jupyternaut.config_manager")
if not config_manager:
raise HTTPError(500, "Config manager not available")

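The settings key changes from `jai_config_manager` to `jupyternaut.config_manager`, so the lookup above only succeeds if the extension registers the manager under the new name. A hedged sketch of what that registration might look like (the class, `name`, and factory below are assumptions for illustration, not taken from this diff):

```python
from jupyter_server.extension.application import ExtensionApp

def make_config_manager():
    # Placeholder factory; the real config manager comes from the package.
    return object()

class JupyternautExtension(ExtensionApp):
    # Hypothetical extension class used only to show the settings key.
    name = "jupyternaut"

    def initialize_settings(self):
        # Key must match the handler's self.settings.get("jupyternaut.config_manager").
        self.settings["jupyternaut.config_manager"] = make_config_manager()
```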
jupyter_ai_jupyternaut/tests/test_handlers.py (59 changes: 56 additions & 3 deletions)
@@ -1,4 +1,7 @@
import json
import pytest

from jupyter_ai_jupyternaut.models import parameter_schemas


async def test_get_example(jp_fetch):
@@ -8,6 +11,56 @@ async def test_get_example(jp_fetch):
# Then
assert response.code == 200
payload = json.loads(response.body)
assert payload == {
"data": "This is /api/jupyternaut/get-example endpoint!"
}
assert payload == {"data": "This is /api/jupyternaut/get-example endpoint!"}


@pytest.mark.parametrize(
"model",
[
None,
"openai/gpt-oss-120b",
"hosted_vllm/doesntmatter",
"anthropic/claude-3-5-haiku-latest",
],
)
async def test_get_parameters(jp_fetch, model):
params = {}
if model:
params["model"] = model
response = await jp_fetch("api/jupyternaut/model-parameters", params=params)
assert response.code == 200
payload = json.loads(response.body)
expected_params = [
"api_base",
"max_tokens",
"stop",
"temperature",
"top_p",
]
if model:
expected_params.extend(["max_completion_tokens"])
if not model.startswith("anthropic/"):
expected_params.extend(["frequency_penalty"])

for param in expected_params:
assert param in payload["parameter_names"]
assert param in payload["parameters"]
assert "description" in payload["parameters"][param]


async def test_put_params(jp_fetch):
# TODO: validate all types, error handling
response = await jp_fetch(
"api/jupyternaut/model-parameters",
body=json.dumps({
"model_id": "hosted_vllm/mlx-community/gpt-oss-20b-MXFP4-Q8",
"parameters": {
"api_base": {
"value": "http://127.0.0.1:8080",
"type": "string",
},
},
}),
method="PUT",
)
assert response.code == 200
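
Outside the test suite, the same PUT could be issued against a running server along these lines; the URL and token are placeholders, and the body simply mirrors the test above:

```python
import json

import requests

resp = requests.put(
    "http://localhost:8888/api/jupyternaut/model-parameters",  # placeholder server URL
    headers={"Authorization": "token <your-token>"},            # placeholder token
    data=json.dumps(
        {
            "model_id": "hosted_vllm/mlx-community/gpt-oss-20b-MXFP4-Q8",
            "parameters": {
                "api_base": {"value": "http://127.0.0.1:8080", "type": "string"},
            },
        }
    ),
)
resp.raise_for_status()
```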