Skip to content

Commit b21f844

Browse files
committed
Rewrite MCP schema generation to be simpler and fix some bugs
1 parent b8ffd77 commit b21f844

File tree

2 files changed

+86
-108
lines changed

2 files changed

+86
-108
lines changed

examples/mcp_example.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,6 @@
77

88
mcp = McpServer("example")
99

10-
class SystemInfo(TypedDict):
11-
platform: Annotated[str, "Operating system platform"]
12-
python_version: Annotated[str, "Python version"]
13-
machine: Annotated[str, "Machine architecture"]
14-
timestamp: Annotated[float, "Current timestamp"]
15-
16-
class GreetingResponse(TypedDict):
17-
message: Annotated[str, "Greeting message"]
18-
name: Annotated[str, "Name that was greeted"]
19-
age: Annotated[NotRequired[int], "Age if provided"]
20-
2110
@mcp.tool
2211
def divide(
2312
numerator: Annotated[float, "Numerator"],
@@ -26,6 +15,11 @@ def divide(
2615
"""Divide two numbers (no zero check - tests natural exceptions)"""
2716
return numerator / denominator
2817

18+
class GreetingResponse(TypedDict):
19+
message: Annotated[str, "Greeting message"]
20+
name: Annotated[str, "Name that was greeted"]
21+
age: Annotated[NotRequired[int], "Age if provided"]
22+
2923
@mcp.tool
3024
def greet(
3125
name: Annotated[str, "Name to greet"],
@@ -43,6 +37,12 @@ def greet(
4337
"name": name
4438
}
4539

40+
class SystemInfo(TypedDict):
41+
platform: Annotated[str, "Operating system platform"]
42+
python_version: Annotated[str, "Python version"]
43+
machine: Annotated[str, "Machine architecture"]
44+
timestamp: Annotated[float, "Current timestamp"]
45+
4646
@mcp.tool
4747
def get_system_info() -> SystemInfo:
4848
"""Get system information"""
@@ -79,6 +79,16 @@ def struct_get(
7979
for name in (names if isinstance(names, list) else [names])
8080
]
8181

82+
@mcp.tool
83+
def random_dict(param: dict[str, int] | None) -> dict:
84+
"""Return a random dictionary for testing serialization"""
85+
return {
86+
**(param or {}),
87+
"x": 42,
88+
"y": 7,
89+
"z": 99,
90+
}
91+
8292
@mcp.resource("example://system_info")
8393
def system_info_resource() -> SystemInfo:
8494
"""Resource providing system information"""

src/zeromcp/mcp.py

Lines changed: 65 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import threading
88
import traceback
99
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
10-
from typing import Any, Callable, Union, Annotated, BinaryIO, NotRequired, get_origin, get_args, get_type_hints
10+
from typing import Any, Callable, Union, Annotated, BinaryIO, NotRequired, get_origin, get_args, get_type_hints, is_typeddict
11+
from types import UnionType
1112
from urllib.parse import urlparse, parse_qs
1213
from io import BufferedIOBase
1314

@@ -161,9 +162,7 @@ def _handle_sse_post(self, body: bytes):
161162
sse_conn = self.mcp_server._sse_connections.get(session_id)
162163
if sse_conn is None or not sse_conn.alive:
163164
# No SSE connection found
164-
error_msg = f"No active SSE connection found for session {session_id}"
165-
print(f"[MCP SSE ERROR] {error_msg}")
166-
self.send_error(400, error_msg)
165+
self.send_error(400, f"No active SSE connection found for session {session_id}")
167166
return
168167

169168
# Send response via SSE event stream
@@ -362,7 +361,7 @@ def _mcp_tools_call(self, name: str, arguments: dict | None = None, _meta: dict
362361
result = tool_response.get("result") if tool_response else None
363362
return {
364363
"content": [{"type": "text", "text": json.dumps(result, indent=2)}],
365-
"structuredContent": result if isinstance(result, dict) else {"value": result},
364+
"structuredContent": result if isinstance(result, dict) else {"result": result},
366365
"isError": False,
367366
}
368367

@@ -469,88 +468,68 @@ def _mcp_resources_read(self, uri: str, _meta: dict | None = None) -> dict:
469468

470469
def _type_to_json_schema(self, py_type: Any) -> dict:
471470
"""Convert Python type hint to JSON schema object"""
472-
# Handle Annotated[Type, "description"]
473-
if get_origin(py_type) is Annotated:
471+
origin = get_origin(py_type)
472+
# Annotated[T, "description"]
473+
if origin is Annotated:
474474
args = get_args(py_type)
475-
actual_type = args[0]
476-
description = args[1] if len(args) > 1 else None
477-
schema = self._type_to_json_schema(actual_type)
478-
if description:
479-
schema["description"] = description
480-
return schema
481-
482-
# Handle Union/Optional types
483-
if get_origin(py_type) is Union:
484-
union_args = get_args(py_type)
485-
non_none = [t for t in union_args if t is not type(None)]
486-
if len(non_none) == 1:
487-
return self._type_to_json_schema(non_none[0])
488-
# Multiple types -> anyOf
489-
return {"anyOf": [self._type_to_json_schema(t) for t in non_none]}
475+
return {
476+
**self._type_to_json_schema(args[0]),
477+
"description": str(args[-1]),
478+
}
490479

491-
# Primitives
492-
if py_type == int:
493-
return {"type": "integer"}
494-
if py_type == float:
495-
return {"type": "number"}
496-
if py_type == str:
497-
return {"type": "string"}
498-
if py_type == bool:
499-
return {"type": "boolean"}
500-
501-
# Handle list types
502-
if py_type == list or get_origin(py_type) is list:
503-
args = get_args(py_type)
504-
schema: dict[str, Any] = {"type": "array"}
505-
if args:
506-
schema["items"] = self._type_to_json_schema(args[0])
507-
return schema
480+
# NotRequired[T]
481+
if origin is NotRequired:
482+
return self._type_to_json_schema(get_args(py_type)[0])
483+
484+
# Union[Ts..], Optional[T] and T1 | T2
485+
if origin in (Union, UnionType):
486+
return {"anyOf": [self._type_to_json_schema(t) for t in get_args(py_type)]}
487+
488+
# list[T]
489+
if origin is list:
490+
return {
491+
"type": "array",
492+
"items": self._type_to_json_schema(get_args(py_type)[0]),
493+
}
508494

509-
# Handle dict types
510-
if py_type == dict or get_origin(py_type) is dict:
511-
return {"type": "object"}
495+
# dict[str, T]
496+
if origin is dict:
497+
return {
498+
"type": "object",
499+
"additionalProperties": self._type_to_json_schema(get_args(py_type)[1]),
500+
}
512501

513-
# TypedDict detection
514-
if hasattr(py_type, "__annotations__"):
515-
if hasattr(py_type, "__required_keys__") or hasattr(py_type, "__optional_keys__"):
516-
return self._typed_dict_to_schema(py_type)
502+
# TypedDict
503+
if is_typeddict(py_type):
504+
return self._typed_dict_to_schema(py_type)
517505

518-
# Fallback
519-
return {"type": "object"}
506+
# Primitives
507+
return {
508+
"type": {
509+
int: "integer",
510+
float: "number",
511+
str: "string",
512+
bool: "boolean",
513+
list: "array",
514+
dict: "object",
515+
type(None): "null",
516+
}.get(py_type, "object"),
517+
}
520518

521519
def _typed_dict_to_schema(self, typed_dict_class) -> dict:
522520
"""Convert TypedDict to JSON schema"""
523521
hints = get_type_hints(typed_dict_class, include_extras=True)
524-
properties = {}
525-
required = []
522+
required_keys = getattr(typed_dict_class, '__required_keys__', set(hints.keys()))
526523

527-
for field_name, field_type in hints.items():
528-
# Check if field is NotRequired
529-
is_not_required = get_origin(field_type) is NotRequired
530-
if is_not_required:
531-
field_type = get_args(field_type)[0]
532-
533-
properties[field_name] = self._type_to_json_schema(field_type)
534-
535-
# Add to required if not NotRequired
536-
if not is_not_required:
537-
# Also check __required_keys__ if available
538-
if hasattr(typed_dict_class, "__required_keys__"):
539-
if field_name in typed_dict_class.__required_keys__:
540-
if field_name not in required:
541-
required.append(field_name)
542-
else:
543-
# Default to required if no __required_keys__
544-
required.append(field_name)
545-
546-
schema = {
524+
return {
547525
"type": "object",
548-
"properties": properties,
526+
"properties": {
527+
field_name: self._type_to_json_schema(field_type)
528+
for field_name, field_type in hints.items()
529+
},
530+
"required": [key for key in hints.keys() if key in required_keys],
531+
"additionalProperties": False
549532
}
550-
if required:
551-
schema["required"] = required
552-
553-
return schema
554533

555534
def _generate_tool_schema(self, func_name: str, func: Callable) -> dict:
556535
"""Generate MCP tool schema from a function"""
@@ -563,25 +542,16 @@ def _generate_tool_schema(self, func_name: str, func: Callable) -> dict:
563542
required = []
564543

565544
for param_name, param_type in hints.items():
566-
# Check if parameter has default value
567-
param = sig.parameters.get(param_name)
568-
has_default = param and param.default is not inspect.Parameter.empty
569-
570-
# Use _type_to_json_schema to handle all type conversions including Union
571545
properties[param_name] = self._type_to_json_schema(param_type)
572546

573-
# Only add to required if no default value
574-
if not has_default:
547+
# Add to required if no default value
548+
param = sig.parameters.get(param_name)
549+
if not param or param.default is inspect.Parameter.empty:
575550
required.append(param_name)
576551

577-
# Get docstring as description
578-
description = func.__doc__ or f"Call {func_name}"
579-
if description:
580-
description = description.strip()
581-
582552
schema: dict[str, Any] = {
583553
"name": func_name,
584-
"description": description,
554+
"description": (func.__doc__ or f"Call {func_name}").strip(),
585555
"inputSchema": {
586556
"type": "object",
587557
"properties": properties,
@@ -592,17 +562,15 @@ def _generate_tool_schema(self, func_name: str, func: Callable) -> dict:
592562
# Add outputSchema if return type exists and is not None
593563
if return_type and return_type is not type(None):
594564
return_schema = self._type_to_json_schema(return_type)
595-
# MCP spec requires outputSchema to always be type: object
596-
# Wrap primitives in an object with a "value" property
565+
566+
# Wrap non-object returns in a "result" property
597567
if return_schema.get("type") != "object":
598-
schema["outputSchema"] = {
568+
return_schema = {
599569
"type": "object",
600-
"properties": {
601-
"value": return_schema,
602-
},
603-
"required": ["value"],
570+
"properties": {"result": return_schema},
571+
"required": ["result"],
604572
}
605-
else:
606-
schema["outputSchema"] = return_schema
573+
574+
schema["outputSchema"] = return_schema
607575

608576
return schema

0 commit comments

Comments
 (0)