Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions src/mcp/server/fastmcp/resources/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class ResourceTemplate(BaseModel):
fn: Callable[..., Any] = Field(exclude=True)
parameters: dict[str, Any] = Field(description="JSON schema for function parameters")
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
_uri_pattern: re.Pattern[str] | None = None

@classmethod
def from_function(
Expand Down Expand Up @@ -66,7 +67,10 @@ def from_function(
# ensure the arguments are properly cast
fn = validate_call(fn)

return cls(
pattern_str = uri_template.replace("{", "(?P<").replace("}", ">[^/]+)")
compiled_pattern = re.compile(f"^{pattern_str}$")

instance = cls(
uri_template=uri_template,
name=func_name,
title=title,
Expand All @@ -78,15 +82,15 @@ def from_function(
parameters=parameters,
context_kwarg=context_kwarg,
)
instance._uri_pattern = compiled_pattern
return instance

def matches(self, uri: str) -> dict[str, Any] | None:
"""Check if URI matches template and extract parameters."""
# Convert template to regex pattern
pattern = self.uri_template.replace("{", "(?P<").replace("}", ">[^/]+)")
match = re.match(f"^{pattern}$", uri)
if match:
return match.groupdict()
return None
if self._uri_pattern is None:
pattern_str = self.uri_template.replace("{", "(?P<").replace("}", ">[^/]+)")
self._uri_pattern = re.compile(f"^{pattern_str}$")
match = self._uri_pattern.match(uri)
return match.groupdict() if match else None

async def create_resource(
self,
Expand Down
31 changes: 16 additions & 15 deletions src/mcp/server/fastmcp/utilities/func_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@ class FuncMetadata(BaseModel):
output_schema: dict[str, Any] | None = None
output_model: Annotated[type[BaseModel], WithJsonSchema(None)] | None = None
wrap_output: bool = False
_key_to_field_info: dict[str, FieldInfo] | None = None

@property
def key_to_field_info(self) -> dict[str, FieldInfo]:
if self._key_to_field_info is None:
mapping: dict[str, FieldInfo] = {}
for field_name, field_info in self.arg_model.model_fields.items():
mapping[field_name] = field_info
if field_info.alias:
mapping[field_info.alias] = field_info
self._key_to_field_info = mapping
return self._key_to_field_info

async def call_fn_with_arg_validation(
self,
Expand Down Expand Up @@ -141,30 +153,19 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
it seems incapable of NOT doing this. For sub-models, it tends to pass
dicts (JSON objects) as JSON strings, which can be pre-parsed here.
"""
new_data = data.copy() # Shallow copy

# Build a mapping from input keys (including aliases) to field info
key_to_field_info: dict[str, FieldInfo] = {}
for field_name, field_info in self.arg_model.model_fields.items():
# Map both the field name and its alias (if any) to the field info
key_to_field_info[field_name] = field_info
if field_info.alias:
key_to_field_info[field_info.alias] = field_info
new_data = data.copy()

for data_key, data_value in data.items():
if data_key not in key_to_field_info: # pragma: no cover
if data_key not in self.key_to_field_info: # pragma: no cover
continue

field_info = key_to_field_info[data_key]
field_info = self.key_to_field_info[data_key]
if isinstance(data_value, str) and field_info.annotation is not str:
try:
pre_parsed = json.loads(data_value)
except json.JSONDecodeError:
continue # Not JSON - skip
continue
if isinstance(pre_parsed, str | int | float):
# This is likely that the raw value is e.g. `"hello"` which we
# Should really be parsed as '"hello"' in Python - but if we parse
# it as JSON it'll turn into just 'hello'. So we skip it.
continue
new_data[data_key] = pre_parsed
assert new_data.keys() == data.keys()
Expand Down
64 changes: 14 additions & 50 deletions src/mcp/shared/experimental/tasks/in_memory_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
from mcp.shared.experimental.tasks.store import TaskStore
from mcp.types import Result, Task, TaskMetadata, TaskStatus

CLEANUP_INTERVAL_SECONDS = 1.0


@dataclass
class StoredTask:
"""Internal storage representation of a task."""

task: Task
result: Result | None = None
# Time when this task should be removed (None = never)
expires_at: datetime | None = field(default=None)


Expand All @@ -49,21 +48,26 @@ def __init__(self, page_size: int = 10) -> None:
self._tasks: dict[str, StoredTask] = {}
self._page_size = page_size
self._update_events: dict[str, anyio.Event] = {}
self._last_cleanup: datetime | None = None

def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None:
"""Calculate expiry time from TTL in milliseconds."""
if ttl_ms is None:
return None
return datetime.now(timezone.utc) + timedelta(milliseconds=ttl_ms)

def _is_expired(self, stored: StoredTask) -> bool:
"""Check if a task has expired."""
if stored.expires_at is None:
return False
return datetime.now(timezone.utc) >= stored.expires_at

def _cleanup_expired(self) -> None:
"""Remove all expired tasks. Called lazily during access operations."""
now = datetime.now(timezone.utc)
if self._last_cleanup is not None:
elapsed = (now - self._last_cleanup).total_seconds()
if elapsed < CLEANUP_INTERVAL_SECONDS:
return

self._last_cleanup = now
expired_ids = [task_id for task_id, stored in self._tasks.items() if self._is_expired(stored)]
for task_id in expired_ids:
del self._tasks[task_id]
Expand All @@ -73,34 +77,21 @@ async def create_task(
metadata: TaskMetadata,
task_id: str | None = None,
) -> Task:
"""Create a new task with the given metadata."""
# Cleanup expired tasks on access
self._cleanup_expired()

task = create_task_state(metadata, task_id)

if task.taskId in self._tasks:
raise ValueError(f"Task with ID {task.taskId} already exists")

stored = StoredTask(
task=task,
expires_at=self._calculate_expiry(metadata.ttl),
)
stored = StoredTask(task=task, expires_at=self._calculate_expiry(metadata.ttl))
self._tasks[task.taskId] = stored

# Return a copy to prevent external modification
return Task(**task.model_dump())

async def get_task(self, task_id: str) -> Task | None:
"""Get a task by ID."""
# Cleanup expired tasks on access
self._cleanup_expired()

stored = self._tasks.get(task_id)
if stored is None:
return None

# Return a copy to prevent external modification
return Task(**stored.task.model_dump())

async def update_task(
Expand All @@ -109,12 +100,10 @@ async def update_task(
status: TaskStatus | None = None,
status_message: str | None = None,
) -> Task:
"""Update a task's status and/or message."""
stored = self._tasks.get(task_id)
if stored is None:
raise ValueError(f"Task with ID {task_id} not found")

# Per spec: Terminal states MUST NOT transition to any other status
if status is not None and status != stored.task.status and is_terminal(stored.task.status):
raise ValueError(f"Cannot transition from terminal status '{stored.task.status}'")

Expand All @@ -126,94 +115,69 @@ async def update_task(
if status_message is not None:
stored.task.statusMessage = status_message

# Update lastUpdatedAt on any change
stored.task.lastUpdatedAt = datetime.now(timezone.utc)

# If task is now terminal and has TTL, reset expiry timer
if status is not None and is_terminal(status) and stored.task.ttl is not None:
stored.expires_at = self._calculate_expiry(stored.task.ttl)

# Notify waiters if status changed
if status_changed:
await self.notify_update(task_id)

return Task(**stored.task.model_dump())

async def store_result(self, task_id: str, result: Result) -> None:
"""Store the result for a task."""
stored = self._tasks.get(task_id)
if stored is None:
raise ValueError(f"Task with ID {task_id} not found")

stored.result = result

async def get_result(self, task_id: str) -> Result | None:
"""Get the stored result for a task."""
stored = self._tasks.get(task_id)
if stored is None:
return None

return stored.result
return stored.result if stored else None

async def list_tasks(
self,
cursor: str | None = None,
) -> tuple[list[Task], str | None]:
"""List tasks with pagination."""
# Cleanup expired tasks on access
self._cleanup_expired()

all_task_ids = list(self._tasks.keys())

start_index = 0
if cursor is not None:
try:
cursor_index = all_task_ids.index(cursor)
start_index = cursor_index + 1
start_index = all_task_ids.index(cursor) + 1
except ValueError:
raise ValueError(f"Invalid cursor: {cursor}")

page_task_ids = all_task_ids[start_index : start_index + self._page_size]
tasks = [Task(**self._tasks[tid].task.model_dump()) for tid in page_task_ids]

# Determine next cursor
next_cursor = None
if start_index + self._page_size < len(all_task_ids) and page_task_ids:
next_cursor = page_task_ids[-1]

return tasks, next_cursor

async def delete_task(self, task_id: str) -> bool:
"""Delete a task."""
if task_id not in self._tasks:
return False

del self._tasks[task_id]
return True

async def wait_for_update(self, task_id: str) -> None:
"""Wait until the task status changes."""
if task_id not in self._tasks:
raise ValueError(f"Task with ID {task_id} not found")

# Create a fresh event for waiting (anyio.Event can't be cleared)
self._update_events[task_id] = anyio.Event()
event = self._update_events[task_id]
await event.wait()
await self._update_events[task_id].wait()

async def notify_update(self, task_id: str) -> None:
"""Signal that a task has been updated."""
if task_id in self._update_events:
self._update_events[task_id].set()

# --- Testing/debugging helpers ---

def cleanup(self) -> None:
"""Cleanup all tasks (useful for testing or graceful shutdown)."""
self._tasks.clear()
self._update_events.clear()

def get_all_tasks(self) -> list[Task]:
"""Get all tasks (useful for debugging). Returns copies to prevent modification."""
self._cleanup_expired()
return [Task(**stored.task.model_dump()) for stored in self._tasks.values()]
29 changes: 7 additions & 22 deletions src/mcp/shared/experimental/tasks/message_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""

from abc import ABC, abstractmethod
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal
Expand Down Expand Up @@ -162,67 +163,51 @@ class InMemoryTaskMessageQueue(TaskMessageQueue):
"""

def __init__(self) -> None:
self._queues: dict[str, list[QueuedMessage]] = {}
self._queues: dict[str, deque[QueuedMessage]] = {}
self._events: dict[str, anyio.Event] = {}

def _get_queue(self, task_id: str) -> list[QueuedMessage]:
"""Get or create the queue for a task."""
def _get_queue(self, task_id: str) -> deque[QueuedMessage]:
if task_id not in self._queues:
self._queues[task_id] = []
self._queues[task_id] = deque()
return self._queues[task_id]

async def enqueue(self, task_id: str, message: QueuedMessage) -> None:
"""Add a message to the queue."""
queue = self._get_queue(task_id)
queue.append(message)
# Signal that a message is available
await self.notify_message_available(task_id)

async def dequeue(self, task_id: str) -> QueuedMessage | None:
"""Remove and return the next message."""
queue = self._get_queue(task_id)
if not queue:
return None
return queue.pop(0)
return queue.popleft()

async def peek(self, task_id: str) -> QueuedMessage | None:
"""Return the next message without removing it."""
queue = self._get_queue(task_id)
if not queue:
return None
return queue[0]
return queue[0] if queue else None

async def is_empty(self, task_id: str) -> bool:
"""Check if the queue is empty."""
queue = self._get_queue(task_id)
return len(queue) == 0
return len(self._get_queue(task_id)) == 0

async def clear(self, task_id: str) -> list[QueuedMessage]:
"""Remove and return all messages."""
queue = self._get_queue(task_id)
messages = list(queue)
queue.clear()
return messages

async def wait_for_message(self, task_id: str) -> None:
"""Wait until a message is available."""
# Check if there are already messages
if not await self.is_empty(task_id):
return

# Create a fresh event for waiting (anyio.Event can't be cleared)
self._events[task_id] = anyio.Event()
event = self._events[task_id]

# Double-check after creating event (avoid race condition)
if not await self.is_empty(task_id):
return

# Wait for a new message
await event.wait()

async def notify_message_available(self, task_id: str) -> None:
"""Signal that a message is available."""
if task_id in self._events:
self._events[task_id].set()

Expand Down
Loading