From c15b04f6f55811a845574c05fe1eaa2005cf1649 Mon Sep 17 00:00:00 2001
From: herniqeu
Date: Tue, 23 Dec 2025 00:06:10 -0300
Subject: [PATCH] perf: optimize hot paths with caching and O(1) operations

- Replace list.pop(0) with deque.popleft() for O(1) queue dequeue
- Cache compiled regex patterns in ResourceTemplate for URI matching
- Cache field info mapping in FuncMetadata via lazy property
- Throttle expired task cleanup with interval-based execution

These optimizations target high-frequency operations in message queuing,
resource lookups, tool calls, and task store access.
---
 src/mcp/server/fastmcp/resources/templates.py | 20 +++---
 .../server/fastmcp/utilities/func_metadata.py | 31 ++++-----
 .../tasks/in_memory_task_store.py             | 64 ++++---------------
 .../experimental/tasks/message_queue.py       | 29 ++-------
 4 files changed, 49 insertions(+), 95 deletions(-)

diff --git a/src/mcp/server/fastmcp/resources/templates.py b/src/mcp/server/fastmcp/resources/templates.py
index a98d37f0ac..9a2441867f 100644
--- a/src/mcp/server/fastmcp/resources/templates.py
+++ b/src/mcp/server/fastmcp/resources/templates.py
@@ -33,6 +33,7 @@ class ResourceTemplate(BaseModel):
     fn: Callable[..., Any] = Field(exclude=True)
     parameters: dict[str, Any] = Field(description="JSON schema for function parameters")
     context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
+    _uri_pattern: re.Pattern[str] | None = None

     @classmethod
     def from_function(
@@ -66,7 +67,10 @@ def from_function(
         # ensure the arguments are properly cast
         fn = validate_call(fn)

-        return cls(
+        pattern_str = uri_template.replace("{", "(?P<").replace("}", ">[^/]+)")
+        compiled_pattern = re.compile(f"^{pattern_str}$")
+
+        instance = cls(
             uri_template=uri_template,
             name=func_name,
             title=title,
@@ -78,15 +82,15 @@ def from_function(
             parameters=parameters,
             context_kwarg=context_kwarg,
         )
+        instance._uri_pattern = compiled_pattern
+        return instance

     def matches(self, uri: str) -> dict[str, Any] | None:
-        """Check if URI matches template and extract parameters."""
-        # Convert template to regex pattern
-        pattern = self.uri_template.replace("{", "(?P<").replace("}", ">[^/]+)")
-        match = re.match(f"^{pattern}$", uri)
-        if match:
-            return match.groupdict()
-        return None
+        if self._uri_pattern is None:
+            pattern_str = self.uri_template.replace("{", "(?P<").replace("}", ">[^/]+)")
+            self._uri_pattern = re.compile(f"^{pattern_str}$")
+        match = self._uri_pattern.match(uri)
+        return match.groupdict() if match else None

     async def create_resource(
         self,
diff --git a/src/mcp/server/fastmcp/utilities/func_metadata.py b/src/mcp/server/fastmcp/utilities/func_metadata.py
index fa443d2fcb..f57e7b5920 100644
--- a/src/mcp/server/fastmcp/utilities/func_metadata.py
+++ b/src/mcp/server/fastmcp/utilities/func_metadata.py
@@ -70,6 +70,18 @@ class FuncMetadata(BaseModel):
     output_schema: dict[str, Any] | None = None
     output_model: Annotated[type[BaseModel], WithJsonSchema(None)] | None = None
     wrap_output: bool = False
+    _key_to_field_info: dict[str, FieldInfo] | None = None
+
+    @property
+    def key_to_field_info(self) -> dict[str, FieldInfo]:
+        if self._key_to_field_info is None:
+            mapping: dict[str, FieldInfo] = {}
+            for field_name, field_info in self.arg_model.model_fields.items():
+                mapping[field_name] = field_info
+                if field_info.alias:
+                    mapping[field_info.alias] = field_info
+            self._key_to_field_info = mapping
+        return self._key_to_field_info

     async def call_fn_with_arg_validation(
         self,
@@ -141,30 +153,19 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
         it seems incapable of NOT doing this.

         For sub-models, it tends to pass dicts (JSON objects) as JSON strings, which can be pre-parsed here.
         """
-        new_data = data.copy()  # Shallow copy
-
-        # Build a mapping from input keys (including aliases) to field info
-        key_to_field_info: dict[str, FieldInfo] = {}
-        for field_name, field_info in self.arg_model.model_fields.items():
-            # Map both the field name and its alias (if any) to the field info
-            key_to_field_info[field_name] = field_info
-            if field_info.alias:
-                key_to_field_info[field_info.alias] = field_info
+        new_data = data.copy()

         for data_key, data_value in data.items():
-            if data_key not in key_to_field_info:  # pragma: no cover
+            if data_key not in self.key_to_field_info:  # pragma: no cover
                 continue
-            field_info = key_to_field_info[data_key]
+            field_info = self.key_to_field_info[data_key]
             if isinstance(data_value, str) and field_info.annotation is not str:
                 try:
                     pre_parsed = json.loads(data_value)
                 except json.JSONDecodeError:
-                    continue  # Not JSON - skip
+                    continue
                 if isinstance(pre_parsed, str | int | float):
-                    # This is likely that the raw value is e.g. `"hello"` which we
-                    # Should really be parsed as '"hello"' in Python - but if we parse
-                    # it as JSON it'll turn into just 'hello'. So we skip it.
                     continue
                 new_data[data_key] = pre_parsed
         assert new_data.keys() == data.keys()
diff --git a/src/mcp/shared/experimental/tasks/in_memory_task_store.py b/src/mcp/shared/experimental/tasks/in_memory_task_store.py
index 7b630ce6e2..f96aba975f 100644
--- a/src/mcp/shared/experimental/tasks/in_memory_task_store.py
+++ b/src/mcp/shared/experimental/tasks/in_memory_task_store.py
@@ -17,14 +17,13 @@
 from mcp.shared.experimental.tasks.store import TaskStore
 from mcp.types import Result, Task, TaskMetadata, TaskStatus

+CLEANUP_INTERVAL_SECONDS = 1.0
+

 @dataclass
 class StoredTask:
-    """Internal storage representation of a task."""
-
     task: Task
     result: Result | None = None
-    # Time when this task should be removed (None = never)
     expires_at: datetime | None = field(default=None)

@@ -49,21 +48,26 @@ def __init__(self, page_size: int = 10) -> None:
         self._tasks: dict[str, StoredTask] = {}
         self._page_size = page_size
         self._update_events: dict[str, anyio.Event] = {}
+        self._last_cleanup: datetime | None = None

     def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None:
-        """Calculate expiry time from TTL in milliseconds."""
         if ttl_ms is None:
             return None
         return datetime.now(timezone.utc) + timedelta(milliseconds=ttl_ms)

     def _is_expired(self, stored: StoredTask) -> bool:
-        """Check if a task has expired."""
         if stored.expires_at is None:
             return False
         return datetime.now(timezone.utc) >= stored.expires_at

     def _cleanup_expired(self) -> None:
-        """Remove all expired tasks. Called lazily during access operations."""
+        now = datetime.now(timezone.utc)
+        if self._last_cleanup is not None:
+            elapsed = (now - self._last_cleanup).total_seconds()
+            if elapsed < CLEANUP_INTERVAL_SECONDS:
+                return
+
+        self._last_cleanup = now
         expired_ids = [task_id for task_id, stored in self._tasks.items() if self._is_expired(stored)]
         for task_id in expired_ids:
             del self._tasks[task_id]
@@ -73,34 +77,21 @@ async def create_task(
         metadata: TaskMetadata,
         task_id: str | None = None,
     ) -> Task:
-        """Create a new task with the given metadata."""
-        # Cleanup expired tasks on access
         self._cleanup_expired()
-
         task = create_task_state(metadata, task_id)
         if task.taskId in self._tasks:
             raise ValueError(f"Task with ID {task.taskId} already exists")

-        stored = StoredTask(
-            task=task,
-            expires_at=self._calculate_expiry(metadata.ttl),
-        )
+        stored = StoredTask(task=task, expires_at=self._calculate_expiry(metadata.ttl))
         self._tasks[task.taskId] = stored
-
-        # Return a copy to prevent external modification
         return Task(**task.model_dump())

     async def get_task(self, task_id: str) -> Task | None:
-        """Get a task by ID."""
-        # Cleanup expired tasks on access
         self._cleanup_expired()
-
         stored = self._tasks.get(task_id)
         if stored is None:
             return None
-
-        # Return a copy to prevent external modification
         return Task(**stored.task.model_dump())

     async def update_task(
         self,
@@ -109,12 +100,10 @@ async def update_task(
         task_id: str,
         status: TaskStatus | None = None,
         status_message: str | None = None,
     ) -> Task:
-        """Update a task's status and/or message."""
         stored = self._tasks.get(task_id)
         if stored is None:
             raise ValueError(f"Task with ID {task_id} not found")

-        # Per spec: Terminal states MUST NOT transition to any other status
         if status is not None and status != stored.task.status and is_terminal(stored.task.status):
             raise ValueError(f"Cannot transition from terminal status '{stored.task.status}'")
@@ -126,57 +115,43 @@ async def update_task(
         if status_message is not None:
             stored.task.statusMessage = status_message

-        # Update lastUpdatedAt on any change
         stored.task.lastUpdatedAt = datetime.now(timezone.utc)

-        # If task is now terminal and has TTL, reset expiry timer
         if status is not None and is_terminal(status) and stored.task.ttl is not None:
             stored.expires_at = self._calculate_expiry(stored.task.ttl)

-        # Notify waiters if status changed
         if status_changed:
             await self.notify_update(task_id)

         return Task(**stored.task.model_dump())

     async def store_result(self, task_id: str, result: Result) -> None:
-        """Store the result for a task."""
         stored = self._tasks.get(task_id)
         if stored is None:
             raise ValueError(f"Task with ID {task_id} not found")
-
         stored.result = result

     async def get_result(self, task_id: str) -> Result | None:
-        """Get the stored result for a task."""
         stored = self._tasks.get(task_id)
-        if stored is None:
-            return None
-
-        return stored.result
+        return stored.result if stored else None

     async def list_tasks(
         self,
         cursor: str | None = None,
     ) -> tuple[list[Task], str | None]:
-        """List tasks with pagination."""
-        # Cleanup expired tasks on access
         self._cleanup_expired()
-
         all_task_ids = list(self._tasks.keys())

         start_index = 0
         if cursor is not None:
             try:
-                cursor_index = all_task_ids.index(cursor)
-                start_index = cursor_index + 1
+                start_index = all_task_ids.index(cursor) + 1
             except ValueError:
                 raise ValueError(f"Invalid cursor: {cursor}")

         page_task_ids = all_task_ids[start_index : start_index + self._page_size]
         tasks = [Task(**self._tasks[tid].task.model_dump()) for tid in page_task_ids]

-        # Determine next cursor
         next_cursor = None
         if start_index + self._page_size < len(all_task_ids) and page_task_ids:
             next_cursor = page_task_ids[-1]
@@ -184,36 +159,25 @@ async def list_tasks(
         return tasks, next_cursor

     async def delete_task(self, task_id: str) -> bool:
-        """Delete a task."""
         if task_id not in self._tasks:
             return False
-
         del self._tasks[task_id]
         return True

     async def wait_for_update(self, task_id: str) -> None:
-        """Wait until the task status changes."""
         if task_id not in self._tasks:
             raise ValueError(f"Task with ID {task_id} not found")
-
-        # Create a fresh event for waiting (anyio.Event can't be cleared)
         self._update_events[task_id] = anyio.Event()
-        event = self._update_events[task_id]
-        await event.wait()
+        await self._update_events[task_id].wait()

     async def notify_update(self, task_id: str) -> None:
-        """Signal that a task has been updated."""
         if task_id in self._update_events:
             self._update_events[task_id].set()

-    # --- Testing/debugging helpers ---
-
     def cleanup(self) -> None:
-        """Cleanup all tasks (useful for testing or graceful shutdown)."""
         self._tasks.clear()
         self._update_events.clear()

     def get_all_tasks(self) -> list[Task]:
-        """Get all tasks (useful for debugging). Returns copies to prevent modification."""
         self._cleanup_expired()
         return [Task(**stored.task.model_dump()) for stored in self._tasks.values()]
diff --git a/src/mcp/shared/experimental/tasks/message_queue.py b/src/mcp/shared/experimental/tasks/message_queue.py
index 69b6609887..f2f37af7a9 100644
--- a/src/mcp/shared/experimental/tasks/message_queue.py
+++ b/src/mcp/shared/experimental/tasks/message_queue.py
@@ -13,6 +13,7 @@
 """

 from abc import ABC, abstractmethod
+from collections import deque
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from typing import Any, Literal
@@ -162,67 +163,51 @@ class InMemoryTaskMessageQueue(TaskMessageQueue):
     """

     def __init__(self) -> None:
-        self._queues: dict[str, list[QueuedMessage]] = {}
+        self._queues: dict[str, deque[QueuedMessage]] = {}
         self._events: dict[str, anyio.Event] = {}

-    def _get_queue(self, task_id: str) -> list[QueuedMessage]:
-        """Get or create the queue for a task."""
+    def _get_queue(self, task_id: str) -> deque[QueuedMessage]:
         if task_id not in self._queues:
-            self._queues[task_id] = []
+            self._queues[task_id] = deque()
         return self._queues[task_id]

     async def enqueue(self, task_id: str, message: QueuedMessage) -> None:
-        """Add a message to the queue."""
         queue = self._get_queue(task_id)
         queue.append(message)
-        # Signal that a message is available
         await self.notify_message_available(task_id)

     async def dequeue(self, task_id: str) -> QueuedMessage | None:
-        """Remove and return the next message."""
         queue = self._get_queue(task_id)
         if not queue:
             return None
-        return queue.pop(0)
+        return queue.popleft()

     async def peek(self, task_id: str) -> QueuedMessage | None:
-        """Return the next message without removing it."""
         queue = self._get_queue(task_id)
-        if not queue:
-            return None
-        return queue[0]
+        return queue[0] if queue else None

     async def is_empty(self, task_id: str) -> bool:
-        """Check if the queue is empty."""
-        queue = self._get_queue(task_id)
-        return len(queue) == 0
+        return len(self._get_queue(task_id)) == 0

     async def clear(self, task_id: str) -> list[QueuedMessage]:
-        """Remove and return all messages."""
         queue = self._get_queue(task_id)
         messages = list(queue)
         queue.clear()
         return messages

     async def wait_for_message(self, task_id: str) -> None:
-        """Wait until a message is available."""
-        # Check if there are already messages
         if not await self.is_empty(task_id):
             return

-        # Create a fresh event for waiting (anyio.Event can't be cleared)
         self._events[task_id] = anyio.Event()
         event = self._events[task_id]

-        # Double-check after creating event (avoid race condition)
         if not await self.is_empty(task_id):
             return

-        # Wait for a new message
         await event.wait()

     async def notify_message_available(self, task_id: str) -> None:
-        """Signal that a message is available."""
         if task_id in self._events:
             self._events[task_id].set()
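
Editor's note (not part of the patch): a minimal, self-contained micro-benchmark sketching the first optimization listed in the commit message, the FIFO dequeue cost of list.pop(0) versus deque.popleft(). All names below are illustrative and do not come from the repository.

import timeit
from collections import deque

N = 10_000

def drain_list() -> None:
    queue = list(range(N))
    while queue:
        queue.pop(0)  # O(n) per pop: every remaining element shifts left

def drain_deque() -> None:
    queue = deque(range(N))
    while queue:
        queue.popleft()  # O(1) per pop: items leave the head without shifting

print("list.pop(0)    :", timeit.timeit(drain_list, number=20))
print("deque.popleft():", timeit.timeit(drain_deque, number=20))

On CPython the list version grows quadratically with queue depth, which is why InMemoryTaskMessageQueue switches its per-task queues to deque.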
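
A second illustrative sketch (again not from the repository): the caching idea behind ResourceTemplate.matches() in isolation. The URI template is converted to a named-group regex once, and the compiled pattern is reused for every lookup; the TemplateMatcher class below is hypothetical.

import re
from typing import Any

class TemplateMatcher:
    def __init__(self, uri_template: str) -> None:
        # Turn "resource://{owner}/{repo}" into ^resource://(?P<owner>[^/]+)/(?P<repo>[^/]+)$ once.
        pattern_str = uri_template.replace("{", "(?P<").replace("}", ">[^/]+)")
        self._pattern = re.compile(f"^{pattern_str}$")

    def matches(self, uri: str) -> dict[str, Any] | None:
        match = self._pattern.match(uri)
        return match.groupdict() if match else None

matcher = TemplateMatcher("resource://{owner}/{repo}")
print(matcher.matches("resource://modelcontextprotocol/python-sdk"))
# {'owner': 'modelcontextprotocol', 'repo': 'python-sdk'}

Worth noting when reviewing: Python's re module keeps its own internal cache of recently compiled patterns, so the per-call saving here comes mostly from skipping the repeated template-to-pattern string conversion and the cache lookup rather than from avoiding a full recompile. The cached key_to_field_info property in FuncMetadata applies the same compute-once idea to the alias mapping previously rebuilt on every tool call.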
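
Finally, a standalone sketch of the interval-based throttling applied in InMemoryTaskStore._cleanup_expired(): the O(n) expiry sweep runs at most once per CLEANUP_INTERVAL_SECONDS instead of on every create/get/list call. The class and method names are hypothetical; only the gating logic mirrors the patch.

from datetime import datetime, timezone

CLEANUP_INTERVAL_SECONDS = 1.0  # same value the patch introduces

class ThrottledCleanup:
    def __init__(self) -> None:
        self._last_cleanup: datetime | None = None

    def maybe_cleanup(self) -> bool:
        """Run the sweep only if the interval has elapsed since the last one."""
        now = datetime.now(timezone.utc)
        if self._last_cleanup is not None:
            if (now - self._last_cleanup).total_seconds() < CLEANUP_INTERVAL_SECONDS:
                return False  # too soon, skip the O(n) scan
        self._last_cleanup = now
        # ... the scan that drops expired entries would run here ...
        return True

throttle = ThrottledCleanup()
print(throttle.maybe_cleanup())  # True: the first call always sweeps
print(throttle.maybe_cleanup())  # False: the second call lands inside the interval

The trade-off, as in the patch, is that an expired task can linger for up to one interval past its TTL before it is actually removed.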