From 23109b8f9ce7077362e74438925a0077a9e3bc0d Mon Sep 17 00:00:00 2001 From: "Swathi Murali(sm3223)" Date: Tue, 2 Dec 2025 08:52:56 -0500 Subject: [PATCH 1/5] Initial plumbing for sesison management --- src/mcp/server/checkpoint.py | 142 +++++++++ src/mcp/server/session.py | 594 ++++++++--------------------------- src/mcp/types.py | 90 ++++-- 3 files changed, 333 insertions(+), 493 deletions(-) create mode 100644 src/mcp/server/checkpoint.py diff --git a/src/mcp/server/checkpoint.py b/src/mcp/server/checkpoint.py new file mode 100644 index 000000000..a7951e9f0 --- /dev/null +++ b/src/mcp/server/checkpoint.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import abc +import time +from dataclasses import dataclass +from typing import Protocol, runtime_checkable + +from mcp.server.session import ServerSession +from mcp.types import ( + CheckpointCreateParams, + CheckpointCreateResult, + CheckpointValidateParams, + CheckpointValidateResult, + CheckpointResumeParams, + CheckpointResumeResult, + CheckpointDeleteParams, + CheckpointDeleteResult, +) + + +@runtime_checkable +class CheckpointBackend(Protocol): + """Backend that actually stores and restores state behind handles.""" + + async def create_checkpoint( + self, + session: ServerSession, + params: CheckpointCreateParams, + ) -> CheckpointCreateResult: ... + + async def validate_checkpoint( + self, + session: ServerSession, + params: CheckpointValidateParams, + ) -> CheckpointValidateResult: ... + + async def resume_checkpoint( + self, + session: ServerSession, + params: CheckpointResumeParams, + ) -> CheckpointResumeResult: ... + + async def delete_checkpoint( + self, + session: ServerSession, + params: CheckpointDeleteParams, + ) -> CheckpointDeleteResult: ... + + +@dataclass +class InMemoryHandleEntry: + value: object + digest: str + expires_at: float + + +class InMemoryCheckpointBackend(CheckpointBackend): + """Simple in-memory backend you can use for tests/POC. + + This is intentionally generic; concrete servers (data, browser, etc.) + decide *what* `value` is and how to interpret it. + """ + + def __init__(self, ttl_seconds: int = 1800) -> None: + self._ttl = ttl_seconds + self._handles: dict[str, InMemoryHandleEntry] = {} + + def _now(self) -> float: + return time.time() + + async def create_checkpoint( + self, + session: ServerSession, + params: CheckpointCreateParams, + ) -> CheckpointCreateResult: + # session.fastmcp or session.server can expose some "current state" + # For now you can override this backend in your server and implement + # your own snapshot logic. + raise NotImplementedError( + "Subclass InMemoryCheckpointBackend and override create_checkpoint " + "to capture concrete state (e.g. data tables, browser session)." + ) + + async def validate_checkpoint( + self, + session: ServerSession, + params: CheckpointValidateParams, + ) -> CheckpointValidateResult: + entry = self._handles.get(params.handle) + if not entry: + return CheckpointValidateResult( + valid=False, + remainingTtlSeconds=0, + digestMatch=False, + ) + + now = self._now() + if now >= entry.expires_at: + return CheckpointValidateResult( + valid=False, + remainingTtlSeconds=0, + digestMatch=params.expectedDigest == entry.digest, + ) + + remaining = int(entry.expires_at - now) + return CheckpointValidateResult( + valid=True, + remainingTtlSeconds=remaining, + digestMatch=( + params.expectedDigest is None + or params.expectedDigest == entry.digest + ), + ) + + async def resume_checkpoint( + self, + session: ServerSession, + params: CheckpointResumeParams, + ) -> CheckpointResumeResult: + entry = self._handles.get(params.handle) + if not entry: + # You’ll map this to HANDLE_NOT_FOUND at JSON-RPC level + return CheckpointResumeResult(resumed=False, handle=params.handle) + + if self._now() >= entry.expires_at: + # Map to EXPIRED + return CheckpointResumeResult(resumed=False, handle=params.handle) + + # Subclasses should take `entry.value` and rehydrate into session state. + raise NotImplementedError( + "Subclass InMemoryCheckpointBackend.resume_checkpoint to rehydrate " + "concrete session state from stored value." + ) + + async def delete_checkpoint( + self, + session: ServerSession, + params: CheckpointDeleteParams, + ) -> CheckpointDeleteResult: + deleted = params.handle in self._handles + self._handles.pop(params.handle, None) + return CheckpointDeleteResult(deleted=deleted) \ No newline at end of file diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 8f0baa3e9..05eadc05a 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -1,12 +1,11 @@ -""" -ServerSession Module +"""ServerSession Module -This module provides the ServerSession class, which manages communication between the -server and client in the MCP (Model Context Protocol) framework. It is most commonly -used in MCP servers to interact with the client. +This module provides the ServerSession class, which manages communication +between the server and client in the MCP (Model Context Protocol) framework. +It is most commonly used in MCP servers to interact with the client. Common usage pattern: -``` + server = Server(name) @server.call_tool() @@ -20,7 +19,6 @@ async def handle_tool_call(ctx: RequestContext, arguments: dict[str, Any]) -> An else: # Fall back to basic tool operations result = await perform_basic_tool_operation(arguments) - return result @server.list_prompts() @@ -31,14 +29,13 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: return generate_custom_prompts(ctx.session.client_params) else: return default_prompts -``` -The ServerSession class is typically used internally by the Server class and should not -be instantiated directly by users of the MCP framework. +The ServerSession class is typically used internally by the Server class +and should not be instantiated directly by users of the MCP framework. """ from enum import Enum -from typing import Any, TypeVar, overload +from typing import Any, TypeVar, TYPE_CHECKING import anyio import anyio.lowlevel @@ -46,17 +43,17 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: from pydantic import AnyUrl import mcp.types as types -from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures from mcp.server.models import InitializationOptions -from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages -from mcp.shared.experimental.tasks.capabilities import check_tasks_capability -from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY -from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, RequestResponder, ) -from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS + +# Only import the checkpoint backend for type checking to avoid runtime +# dependencies or circular imports. The actual implementation lives in +# mcp.server.checkpoint. +if TYPE_CHECKING: # pragma: no cover + from mcp.server.checkpoint import CheckpointBackend class InitializationState(Enum): @@ -66,9 +63,10 @@ class InitializationState(Enum): ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") - ServerRequestResponder = ( - RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception + RequestResponder[types.ClientRequest, types.ServerResult] + | types.ClientNotification + | Exception ) @@ -83,47 +81,64 @@ class ServerSession( ): _initialized: InitializationState = InitializationState.NotInitialized _client_params: types.InitializeRequestParams | None = None - _experimental_features: ExperimentalServerSessionFeatures | None = None def __init__( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[types.JSONRPCMessage], init_options: InitializationOptions, - stateless: bool = False, + # NEW: optional checkpoint backend for servers that support + # checkpoint/create, checkpoint/validate, checkpoint/resume, + # checkpoint/delete at a higher level (e.g., FastMCP). + checkpoint_backend: "CheckpointBackend | None" = None, ) -> None: - super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification) - self._initialization_state = ( - InitializationState.Initialized if stateless else InitializationState.NotInitialized + super().__init__( + read_stream, + write_stream, + types.ClientRequest, + types.ClientNotification, ) - + self._initialization_state = InitializationState.NotInitialized self._init_options = init_options - self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[ - ServerRequestResponder - ](0) - self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose()) + self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( + anyio.create_memory_object_stream + ) + self._exit_stack.push_async_callback( + lambda: self._incoming_message_stream_reader.aclose() + ) + self._exit_stack.push_async_callback( + lambda: self._incoming_message_stream_writer.aclose() + ) + + # Store optional checkpoint backend. This does not itself handle any + # checkpoint RPCs; it just makes the backend available to higher-level + # server code via ctx.session.checkpoint_backend. + self._checkpoint_backend = checkpoint_backend @property def client_params(self) -> types.InitializeRequestParams | None: - return self._client_params # pragma: no cover + return self._client_params @property - def experimental(self) -> ExperimentalServerSessionFeatures: - """Experimental APIs for server→client task operations. + def checkpoint_backend(self) -> "CheckpointBackend | None": + """Optional checkpoint backend attached to this session. - WARNING: These APIs are experimental and may change without notice. + FastMCP or custom servers can use this to implement the + checkpoint/create, checkpoint/validate, checkpoint/resume, + checkpoint/delete methods by delegating into the backend and using + this session as the context. """ - if self._experimental_features is None: - self._experimental_features = ExperimentalServerSessionFeatures(self) - return self._experimental_features + return self._checkpoint_backend - def check_client_capability(self, capability: types.ClientCapabilities) -> bool: # pragma: no cover + def check_client_capability(self, capability: types.ClientCapabilities) -> bool: """Check if the client supports a specific capability.""" if self._client_params is None: return False + # Get client capabilities from initialization params client_caps = self._client_params.capabilities + # Check each specified capability in the passed in capability object if capability.roots is not None: if client_caps.roots is None: return False @@ -133,107 +148,93 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: if capability.sampling is not None: if client_caps.sampling is None: return False - if capability.sampling.context is not None and client_caps.sampling.context is None: - return False - if capability.sampling.tools is not None and client_caps.sampling.tools is None: - return False - - if capability.elicitation is not None and client_caps.elicitation is None: - return False if capability.experimental is not None: if client_caps.experimental is None: return False + # Check each experimental capability for exp_key, exp_value in capability.experimental.items(): - if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value: + if ( + exp_key not in client_caps.experimental + or client_caps.experimental[exp_key] != exp_value + ): return False - if capability.tasks is not None: - if client_caps.tasks is None: - return False - if not check_tasks_capability(capability.tasks, client_caps.tasks): - return False - return True - async def _receive_loop(self) -> None: - async with self._incoming_message_stream_writer: - await super()._receive_loop() - - async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]): + async def _received_request( + self, responder: RequestResponder[types.ClientRequest, types.ServerResult] + ): match responder.request.root: case types.InitializeRequest(params=params): - requested_version = params.protocolVersion self._initialization_state = InitializationState.Initializing self._client_params = params with responder: await responder.respond( types.ServerResult( types.InitializeResult( - protocolVersion=requested_version - if requested_version in SUPPORTED_PROTOCOL_VERSIONS - else types.LATEST_PROTOCOL_VERSION, + protocolVersion=types.LATEST_PROTOCOL_VERSION, capabilities=self._init_options.capabilities, serverInfo=types.Implementation( name=self._init_options.server_name, version=self._init_options.server_version, - websiteUrl=self._init_options.website_url, - icons=self._init_options.icons, ), instructions=self._init_options.instructions, ) ) ) - self._initialization_state = InitializationState.Initialized - case types.PingRequest(): - # Ping requests are allowed at any time - pass case _: if self._initialization_state != InitializationState.Initialized: - raise RuntimeError("Received request before initialization was complete") + raise RuntimeError( + "Received request before initialization was complete" + ) - async def _received_notification(self, notification: types.ClientNotification) -> None: + async def _received_notification( + self, notification: types.ClientNotification + ) -> None: # Need this to avoid ASYNC910 await anyio.lowlevel.checkpoint() + match notification.root: case types.InitializedNotification(): self._initialization_state = InitializationState.Initialized case _: - if self._initialization_state != InitializationState.Initialized: # pragma: no cover - raise RuntimeError("Received notification before initialization was complete") + if self._initialization_state != InitializationState.Initialized: + raise RuntimeError( + "Received notification before initialization was complete" + ) async def send_log_message( self, level: types.LoggingLevel, data: Any, logger: str | None = None, - related_request_id: types.RequestId | None = None, ) -> None: """Send a log message notification.""" await self.send_notification( types.ServerNotification( types.LoggingMessageNotification( + method="notifications/message", params=types.LoggingMessageNotificationParams( level=level, data=data, logger=logger, ), ) - ), - related_request_id, + ) ) - async def send_resource_updated(self, uri: AnyUrl) -> None: # pragma: no cover + async def send_resource_updated(self, uri: AnyUrl) -> None: """Send a resource updated notification.""" await self.send_notification( types.ServerNotification( types.ResourceUpdatedNotification( + method="notifications/resources/updated", params=types.ResourceUpdatedNotificationParams(uri=uri), ) ) ) - @overload async def create_message( self, messages: list[types.SamplingMessage], @@ -245,205 +246,46 @@ async def create_message( stop_sequences: list[str] | None = None, metadata: dict[str, Any] | None = None, model_preferences: types.ModelPreferences | None = None, - tools: None = None, - tool_choice: types.ToolChoice | None = None, - related_request_id: types.RequestId | None = None, ) -> types.CreateMessageResult: - """Overload: Without tools, returns single content.""" - ... - - @overload - async def create_message( - self, - messages: list[types.SamplingMessage], - *, - max_tokens: int, - system_prompt: str | None = None, - include_context: types.IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: types.ModelPreferences | None = None, - tools: list[types.Tool], - tool_choice: types.ToolChoice | None = None, - related_request_id: types.RequestId | None = None, - ) -> types.CreateMessageResultWithTools: - """Overload: With tools, returns array-capable content.""" - ... - - async def create_message( - self, - messages: list[types.SamplingMessage], - *, - max_tokens: int, - system_prompt: str | None = None, - include_context: types.IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: types.ModelPreferences | None = None, - tools: list[types.Tool] | None = None, - tool_choice: types.ToolChoice | None = None, - related_request_id: types.RequestId | None = None, - ) -> types.CreateMessageResult | types.CreateMessageResultWithTools: - """Send a sampling/create_message request. - - Args: - messages: The conversation messages to send. - max_tokens: Maximum number of tokens to generate. - system_prompt: Optional system prompt. - include_context: Optional context inclusion setting. - Should only be set to "thisServer" or "allServers" - if the client has sampling.context capability. - temperature: Optional sampling temperature. - stop_sequences: Optional stop sequences. - metadata: Optional metadata to pass through to the LLM provider. - model_preferences: Optional model selection preferences. - tools: Optional list of tools the LLM can use during sampling. - Requires client to have sampling.tools capability. - tool_choice: Optional control over tool usage behavior. - Requires client to have sampling.tools capability. - related_request_id: Optional ID of a related request. - - Returns: - The sampling result from the client. - - Raises: - McpError: If tools are provided but client doesn't support them. - ValueError: If tool_use or tool_result message structure is invalid. - """ - client_caps = self._client_params.capabilities if self._client_params else None - validate_sampling_tools(client_caps, tools, tool_choice) - validate_tool_use_result_messages(messages) - - request = types.ServerRequest( - types.CreateMessageRequest( - params=types.CreateMessageRequestParams( - messages=messages, - systemPrompt=system_prompt, - includeContext=include_context, - temperature=temperature, - maxTokens=max_tokens, - stopSequences=stop_sequences, - metadata=metadata, - modelPreferences=model_preferences, - tools=tools, - toolChoice=tool_choice, - ), - ) - ) - metadata_obj = ServerMessageMetadata(related_request_id=related_request_id) - - # Use different result types based on whether tools are provided - if tools is not None: - return await self.send_request( - request=request, - result_type=types.CreateMessageResultWithTools, - metadata=metadata_obj, - ) - return await self.send_request( - request=request, - result_type=types.CreateMessageResult, - metadata=metadata_obj, - ) - - async def list_roots(self) -> types.ListRootsResult: - """Send a roots/list request.""" - return await self.send_request( - types.ServerRequest(types.ListRootsRequest()), - types.ListRootsResult, - ) - - async def elicit( - self, - message: str, - requestedSchema: types.ElicitRequestedSchema, - related_request_id: types.RequestId | None = None, - ) -> types.ElicitResult: - """Send a form mode elicitation/create request. - - Args: - message: The message to present to the user - requestedSchema: Schema defining the expected response structure - related_request_id: Optional ID of the request that triggered this elicitation - - Returns: - The client's response - - Note: - This method is deprecated in favor of elicit_form(). It remains for - backward compatibility but new code should use elicit_form(). - """ - return await self.elicit_form(message, requestedSchema, related_request_id) - - async def elicit_form( - self, - message: str, - requestedSchema: types.ElicitRequestedSchema, - related_request_id: types.RequestId | None = None, - ) -> types.ElicitResult: - """Send a form mode elicitation/create request. - - Args: - message: The message to present to the user - requestedSchema: Schema defining the expected response structure - related_request_id: Optional ID of the request that triggered this elicitation - - Returns: - The client's response with form data - """ + """Send a sampling/createMessage request.""" return await self.send_request( types.ServerRequest( - types.ElicitRequest( - params=types.ElicitRequestFormParams( - message=message, - requestedSchema=requestedSchema, + types.CreateMessageRequest( + method="sampling/createMessage", + params=types.CreateMessageRequestParams( + messages=messages, + systemPrompt=system_prompt, + includeContext=include_context, + temperature=temperature, + maxTokens=max_tokens, + stopSequences=stop_sequences, + metadata=metadata, + modelPreferences=model_preferences, ), ) ), - types.ElicitResult, - metadata=ServerMessageMetadata(related_request_id=related_request_id), + types.CreateMessageResult, ) - - async def elicit_url( - self, - message: str, - url: str, - elicitation_id: str, - related_request_id: types.RequestId | None = None, - ) -> types.ElicitResult: - """Send a URL mode elicitation/create request. - - This directs the user to an external URL for out-of-band interactions - like OAuth flows, credential collection, or payment processing. - - Args: - message: Human-readable explanation of why the interaction is needed - url: The URL the user should navigate to - elicitation_id: Unique identifier for tracking this elicitation - related_request_id: Optional ID of the request that triggered this elicitation - - Returns: - The client's response indicating acceptance, decline, or cancellation - """ + + async def list_roots(self) -> types.ListRootsResult: + """Send a roots/list request.""" return await self.send_request( types.ServerRequest( - types.ElicitRequest( - params=types.ElicitRequestURLParams( - message=message, - url=url, - elicitationId=elicitation_id, - ), + types.ListRootsRequest( + method="roots/list", ) ), - types.ElicitResult, - metadata=ServerMessageMetadata(related_request_id=related_request_id), + types.ListRootsResult, ) - async def send_ping(self) -> types.EmptyResult: # pragma: no cover + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" return await self.send_request( - types.ServerRequest(types.PingRequest()), + types.ServerRequest( + types.PingRequest( + method="ping", + ) + ), types.EmptyResult, ) @@ -452,235 +294,51 @@ async def send_progress_notification( progress_token: str | int, progress: float, total: float | None = None, - message: str | None = None, - related_request_id: str | None = None, ) -> None: """Send a progress notification.""" await self.send_notification( types.ServerNotification( types.ProgressNotification( + method="notifications/progress", params=types.ProgressNotificationParams( progressToken=progress_token, progress=progress, total=total, - message=message, ), ) - ), - related_request_id, + ) ) - async def send_resource_list_changed(self) -> None: # pragma: no cover + async def send_resource_list_changed(self) -> None: """Send a resource list changed notification.""" - await self.send_notification(types.ServerNotification(types.ResourceListChangedNotification())) - - async def send_tool_list_changed(self) -> None: # pragma: no cover - """Send a tool list changed notification.""" - await self.send_notification(types.ServerNotification(types.ToolListChangedNotification())) - - async def send_prompt_list_changed(self) -> None: # pragma: no cover - """Send a prompt list changed notification.""" - await self.send_notification(types.ServerNotification(types.PromptListChangedNotification())) - - async def send_elicit_complete( - self, - elicitation_id: str, - related_request_id: types.RequestId | None = None, - ) -> None: - """Send an elicitation completion notification. - - This should be sent when a URL mode elicitation has been completed - out-of-band to inform the client that it may retry any requests - that were waiting for this elicitation. - - Args: - elicitation_id: The unique identifier of the completed elicitation - related_request_id: Optional ID of the request that triggered this - """ await self.send_notification( types.ServerNotification( - types.ElicitCompleteNotification( - params=types.ElicitCompleteNotificationParams(elicitationId=elicitation_id) + types.ResourceListChangedNotification( + method="notifications/resources/list_changed", ) - ), - related_request_id, - ) - - def _build_elicit_form_request( - self, - message: str, - requestedSchema: types.ElicitRequestedSchema, - related_task_id: str | None = None, - task: types.TaskMetadata | None = None, - ) -> types.JSONRPCRequest: - """Build a form mode elicitation request without sending it. - - Args: - message: The message to present to the user - requestedSchema: Schema defining the expected response structure - related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata - task: If provided, makes this a task-augmented request - - Returns: - A JSONRPCRequest ready to be sent or queued - """ - params = types.ElicitRequestFormParams( - message=message, - requestedSchema=requestedSchema, - task=task, - ) - params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) - - # Add related-task metadata if associated with a parent task - if related_task_id is not None: - # Defensive: model_dump() never includes _meta, but guard against future changes - if "_meta" not in params_data: # pragma: no cover - params_data["_meta"] = {} - params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( - taskId=related_task_id - ).model_dump(by_alias=True) - - request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id - if related_task_id is None: - self._request_id += 1 - - return types.JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - method="elicitation/create", - params=params_data, + ) ) - def _build_elicit_url_request( - self, - message: str, - url: str, - elicitation_id: str, - related_task_id: str | None = None, - ) -> types.JSONRPCRequest: - """Build a URL mode elicitation request without sending it. - - Args: - message: Human-readable explanation of why the interaction is needed - url: The URL the user should navigate to - elicitation_id: Unique identifier for tracking this elicitation - related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata - - Returns: - A JSONRPCRequest ready to be sent or queued - """ - params = types.ElicitRequestURLParams( - message=message, - url=url, - elicitationId=elicitation_id, - ) - params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) - - # Add related-task metadata if associated with a parent task - if related_task_id is not None: - # Defensive: model_dump() never includes _meta, but guard against future changes - if "_meta" not in params_data: # pragma: no cover - params_data["_meta"] = {} - params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( - taskId=related_task_id - ).model_dump(by_alias=True) - - request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id - if related_task_id is None: - self._request_id += 1 - - return types.JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - method="elicitation/create", - params=params_data, + async def send_tool_list_changed(self) -> None: + """Send a tool list changed notification.""" + await self.send_notification( + types.ServerNotification( + types.ToolListChangedNotification( + method="notifications/tools/list_changed", + ) + ) ) - def _build_create_message_request( - self, - messages: list[types.SamplingMessage], - *, - max_tokens: int, - system_prompt: str | None = None, - include_context: types.IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: types.ModelPreferences | None = None, - tools: list[types.Tool] | None = None, - tool_choice: types.ToolChoice | None = None, - related_task_id: str | None = None, - task: types.TaskMetadata | None = None, - ) -> types.JSONRPCRequest: - """Build a sampling/createMessage request without sending it. - - Args: - messages: The conversation messages to send - max_tokens: Maximum number of tokens to generate - system_prompt: Optional system prompt - include_context: Optional context inclusion setting - temperature: Optional sampling temperature - stop_sequences: Optional stop sequences - metadata: Optional metadata to pass through to the LLM provider - model_preferences: Optional model selection preferences - tools: Optional list of tools the LLM can use during sampling - tool_choice: Optional control over tool usage behavior - related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata - task: If provided, makes this a task-augmented request - - Returns: - A JSONRPCRequest ready to be sent or queued - """ - params = types.CreateMessageRequestParams( - messages=messages, - systemPrompt=system_prompt, - includeContext=include_context, - temperature=temperature, - maxTokens=max_tokens, - stopSequences=stop_sequences, - metadata=metadata, - modelPreferences=model_preferences, - tools=tools, - toolChoice=tool_choice, - task=task, - ) - params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) - - # Add related-task metadata if associated with a parent task - if related_task_id is not None: - # Defensive: model_dump() never includes _meta, but guard against future changes - if "_meta" not in params_data: # pragma: no cover - params_data["_meta"] = {} - params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( - taskId=related_task_id - ).model_dump(by_alias=True) - - request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id - if related_task_id is None: - self._request_id += 1 - - return types.JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - method="sampling/createMessage", - params=params_data, + async def send_prompt_list_changed(self) -> None: + """Send a prompt list changed notification.""" + await self.send_notification( + types.ServerNotification( + types.PromptListChangedNotification( + method="notifications/prompts/list_changed", + ) + ) ) - async def send_message(self, message: SessionMessage) -> None: - """Send a raw session message. - - This is primarily used by TaskResultHandler to deliver queued messages - (elicitation/sampling requests) to the client during task execution. - - WARNING: This is a low-level experimental method that may change without - notice. Prefer using higher-level methods like send_notification() or - send_request() for normal operations. - - Args: - message: The session message to send - """ - await self._write_stream.send(message) - async def _handle_incoming(self, req: ServerRequestResponder) -> None: await self._incoming_message_stream_writer.send(req) @@ -688,4 +346,4 @@ async def _handle_incoming(self, req: ServerRequestResponder) -> None: def incoming_messages( self, ) -> MemoryObjectReceiveStream[ServerRequestResponder]: - return self._incoming_message_stream_reader + return self._incoming_message_stream_reader \ No newline at end of file diff --git a/src/mcp/types.py b/src/mcp/types.py index 7a46ad620..5784edbe2 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -1146,10 +1146,6 @@ class ToolResultContent(BaseModel): SamplingMessageContentBlock: TypeAlias = TextContent | ImageContent | AudioContent | ToolUseContent | ToolResultContent """Content block types allowed in sampling messages.""" -SamplingContent: TypeAlias = TextContent | ImageContent | AudioContent -"""Basic content types for sampling responses (without tool use). -Used for backwards-compatible CreateMessageResult when tools are not used.""" - class SamplingMessage(BaseModel): """Describes a message issued to or received from an LLM API.""" @@ -1547,27 +1543,7 @@ class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling class CreateMessageResult(Result): - """The client's response to a sampling/create_message request from the server. - - This is the backwards-compatible version that returns single content (no arrays). - Used when the request does not include tools. - """ - - role: Role - """The role of the message sender (typically 'assistant' for LLM responses).""" - content: SamplingContent - """Response content. Single content block (text, image, or audio).""" - model: str - """The name of the model that generated the message.""" - stopReason: StopReason | None = None - """The reason why sampling stopped, if known.""" - - -class CreateMessageResultWithTools(Result): - """The client's response to a sampling/create_message request when tools were provided. - - This version supports array content for tool use flows. - """ + """The client's response to a sampling/create_message request from the server.""" role: Role """The role of the message sender (typically 'assistant' for LLM responses).""" @@ -1996,3 +1972,67 @@ class ServerNotification(RootModel[ServerNotificationType]): class ServerResult(RootModel[ServerResultType]): pass + + +# --- Checkpoint protocol extensions ----------------------------------------- + +class CheckpointHandle(BaseModel): + """Opaque checkpoint handle returned by servers.""" + handle: str + digest: str + ttlSeconds: int + + +class CheckpointCreateParams(BaseModel): + """Params for checkpoint/create. + + For v1 you can keep this empty – the server infers the session + from transport/session context – but we define it for forward compat. + """ + # Optional: allow tools to tag a logical name + label: str | None = None + + +class CheckpointCreateResult(BaseModel): + """Result of checkpoint/create.""" + handle: str + digest: str + ttlSeconds: int + + +class CheckpointValidateParams(BaseModel): + """Params for checkpoint/validate.""" + handle: str + expectedDigest: str | None = None + + +class CheckpointValidateResult(BaseModel): + """Result of checkpoint/validate.""" + valid: bool + remainingTtlSeconds: int + digestMatch: bool + + +class CheckpointResumeParams(BaseModel): + """Params for checkpoint/resume.""" + handle: str + + +class CheckpointResumeResult(BaseModel): + """Result of checkpoint/resume. + + You can expand this later if you want to + surface metadata to the client. + """ + resumed: bool + handle: str + + +class CheckpointDeleteParams(BaseModel): + """Params for checkpoint/delete.""" + handle: str + + +class CheckpointDeleteResult(BaseModel): + """Result of checkpoint/delete.""" + deleted: bool \ No newline at end of file From 1741b1dc558017919f216baa8062f001f3dd1fff Mon Sep 17 00:00:00 2001 From: "Swathi Murali(sm3223)" Date: Thu, 4 Dec 2025 19:38:20 -0500 Subject: [PATCH 2/5] draft --- src/mcp/__init__.py | 4 ++-- src/mcp/server/fastmcp/server.py | 8 ++++++++ src/mcp/server/lowlevel/server.py | 7 +++++++ src/mcp/server/session.py | 5 +++++ src/mcp/types.py | 1 + 5 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index fbec40d0a..bf0b78c54 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -13,7 +13,7 @@ CompleteRequest, CreateMessageRequest, CreateMessageResult, - CreateMessageResultWithTools, + #CreateMessageResultWithTools, ErrorData, GetPromptRequest, GetPromptResult, @@ -43,7 +43,7 @@ ResourceUpdatedNotification, RootsCapability, SamplingCapability, - SamplingContent, + #SamplingContent, SamplingContextCapability, SamplingMessage, SamplingMessageContentBlock, diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index f74b65557..93147ee3d 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -64,6 +64,7 @@ from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.server.checkpoint import CheckpointBackend from mcp.shared.context import LifespanContextT, RequestContext, RequestT from mcp.types import Annotations, AnyFunction, ContentBlock, GetPromptResult, Icon, ToolAnnotations from mcp.types import Prompt as MCPPrompt @@ -173,6 +174,7 @@ def __init__( # noqa: PLR0913 lifespan: (Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None) = None, auth: AuthSettings | None = None, transport_security: TransportSecuritySettings | None = None, + checkpoint_backend: CheckpointBackend | None = None, ): # Auto-enable DNS rebinding protection for localhost (IPv4 and IPv6) if transport_security is None and host in ("127.0.0.1", "localhost", "::1"): @@ -230,6 +232,7 @@ def __init__( # noqa: PLR0913 if auth_server_provider and not token_verifier: # pragma: no cover self._token_verifier = ProviderTokenVerifier(auth_server_provider) self._event_store = event_store + self._checkpoint_backend = checkpoint_backend self._retry_interval = retry_interval self._custom_starlette_routes: list[Route] = [] self.dependencies = self.settings.dependencies @@ -275,6 +278,11 @@ def session_manager(self) -> StreamableHTTPSessionManager: "to avoid unnecessary initialization." ) return self._session_manager # pragma: no cover + + @property + def checkpoint_backend(self) -> CheckpointBackend | None: + """Return the checkpoint backend (if any) attached to this server.""" + return self._checkpoint_backend def run( self, diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index e29c021b7..6f79ca9cf 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -94,6 +94,7 @@ async def main(): from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder from mcp.shared.tool_name_validation import validate_and_warn_tool_name +from mcp.server.checkpoint import CheckpointBackend logger = logging.getLogger(__name__) @@ -146,6 +147,9 @@ def __init__( [Server[LifespanResultT, RequestT]], AbstractAsyncContextManager[LifespanResultT], ] = lifespan, + *, + stateless: bool = False, + checkpoint_backend: CheckpointBackend | None = None, ): self.name = name self.version = version @@ -159,6 +163,8 @@ def __init__( self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self._tool_cache: dict[str, types.Tool] = {} self._experimental_handlers: ExperimentalHandlers | None = None + self._stateless = stateless + self._checkpoint_backend = checkpoint_backend logger.debug("Initializing server %r", name) def create_initialization_options( @@ -650,6 +656,7 @@ async def run( write_stream, initialization_options, stateless=stateless, + checkpoint_backend=self._checkpoint_backend, ) ) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 05eadc05a..63ec51c1b 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -90,6 +90,8 @@ def __init__( # NEW: optional checkpoint backend for servers that support # checkpoint/create, checkpoint/validate, checkpoint/resume, # checkpoint/delete at a higher level (e.g., FastMCP). + *, + stateless: bool = False, checkpoint_backend: "CheckpointBackend | None" = None, ) -> None: super().__init__( @@ -110,6 +112,9 @@ def __init__( lambda: self._incoming_message_stream_writer.aclose() ) + # Preserve original stateless behaviour (if anything uses it) + self._stateless = stateless + # Store optional checkpoint backend. This does not itself handle any # checkpoint RPCs; it just makes the backend available to higher-level # server code via ctx.session.checkpoint_backend. diff --git a/src/mcp/types.py b/src/mcp/types.py index 5784edbe2..3ed3421bb 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel from pydantic.networks import AnyUrl, UrlConstraints from typing_extensions import deprecated +from typing import Any """ Model Context Protocol bindings for Python From 373ead42b8be3dc3dd08436f411f158113682455 Mon Sep 17 00:00:00 2001 From: "Swathi Murali(sm3223)" Date: Thu, 4 Dec 2025 22:45:40 -0500 Subject: [PATCH 3/5] edit session.py --- src/mcp/server/session.py | 597 ++++++++++++++++++++++++++++++-------- src/mcp/types.py | 27 +- 2 files changed, 496 insertions(+), 128 deletions(-) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 63ec51c1b..ae0153801 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -1,11 +1,12 @@ -"""ServerSession Module +""" +ServerSession Module -This module provides the ServerSession class, which manages communication -between the server and client in the MCP (Model Context Protocol) framework. -It is most commonly used in MCP servers to interact with the client. +This module provides the ServerSession class, which manages communication between the +server and client in the MCP (Model Context Protocol) framework. It is most commonly +used in MCP servers to interact with the client. Common usage pattern: - +``` server = Server(name) @server.call_tool() @@ -19,6 +20,7 @@ async def handle_tool_call(ctx: RequestContext, arguments: dict[str, Any]) -> An else: # Fall back to basic tool operations result = await perform_basic_tool_operation(arguments) + return result @server.list_prompts() @@ -29,13 +31,14 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: return generate_custom_prompts(ctx.session.client_params) else: return default_prompts +``` -The ServerSession class is typically used internally by the Server class -and should not be instantiated directly by users of the MCP framework. +The ServerSession class is typically used internally by the Server class and should not +be instantiated directly by users of the MCP framework. """ from enum import Enum -from typing import Any, TypeVar, TYPE_CHECKING +from typing import Any, TypeVar, overload, TYPE_CHECKING import anyio import anyio.lowlevel @@ -43,19 +46,20 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: from pydantic import AnyUrl import mcp.types as types +from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures from mcp.server.models import InitializationOptions +from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages +from mcp.shared.experimental.tasks.capabilities import check_tasks_capability +from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY +from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, RequestResponder, ) - -# Only import the checkpoint backend for type checking to avoid runtime -# dependencies or circular imports. The actual implementation lives in -# mcp.server.checkpoint. -if TYPE_CHECKING: # pragma: no cover +from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +if TYPE_CHECKING: from mcp.server.checkpoint import CheckpointBackend - class InitializationState(Enum): NotInitialized = 1 Initializing = 2 @@ -63,10 +67,9 @@ class InitializationState(Enum): ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") + ServerRequestResponder = ( - RequestResponder[types.ClientRequest, types.ServerResult] - | types.ClientNotification - | Exception + RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception ) @@ -81,69 +84,54 @@ class ServerSession( ): _initialized: InitializationState = InitializationState.NotInitialized _client_params: types.InitializeRequestParams | None = None + _experimental_features: ExperimentalServerSessionFeatures | None = None def __init__( self, - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], init_options: InitializationOptions, - # NEW: optional checkpoint backend for servers that support - # checkpoint/create, checkpoint/validate, checkpoint/resume, - # checkpoint/delete at a higher level (e.g., FastMCP). - *, stateless: bool = False, checkpoint_backend: "CheckpointBackend | None" = None, ) -> None: - super().__init__( - read_stream, - write_stream, - types.ClientRequest, - types.ClientNotification, - ) - self._initialization_state = InitializationState.NotInitialized - self._init_options = init_options - self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( - anyio.create_memory_object_stream - ) - self._exit_stack.push_async_callback( - lambda: self._incoming_message_stream_reader.aclose() - ) - self._exit_stack.push_async_callback( - lambda: self._incoming_message_stream_writer.aclose() + super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification) + self._initialization_state = ( + InitializationState.Initialized if stateless else InitializationState.NotInitialized ) - # Preserve original stateless behaviour (if anything uses it) - self._stateless = stateless - - # Store optional checkpoint backend. This does not itself handle any - # checkpoint RPCs; it just makes the backend available to higher-level - # server code via ctx.session.checkpoint_backend. + self._init_options = init_options + self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[ + ServerRequestResponder + ](0) + self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose()) self._checkpoint_backend = checkpoint_backend @property def client_params(self) -> types.InitializeRequestParams | None: - return self._client_params + return self._client_params # pragma: no cover @property - def checkpoint_backend(self) -> "CheckpointBackend | None": - """Optional checkpoint backend attached to this session. + def experimental(self) -> ExperimentalServerSessionFeatures: + """Experimental APIs for server→client task operations. - FastMCP or custom servers can use this to implement the - checkpoint/create, checkpoint/validate, checkpoint/resume, - checkpoint/delete methods by delegating into the backend and using - this session as the context. + WARNING: These APIs are experimental and may change without notice. """ + if self._experimental_features is None: + self._experimental_features = ExperimentalServerSessionFeatures(self) + return self._experimental_features + + @property + def checkpoint_backend(self) -> "CheckpointBackend | None": + """Optional checkpoint backend attached to this session.""" return self._checkpoint_backend - def check_client_capability(self, capability: types.ClientCapabilities) -> bool: + def check_client_capability(self, capability: types.ClientCapabilities) -> bool: # pragma: no cover """Check if the client supports a specific capability.""" if self._client_params is None: return False - # Get client capabilities from initialization params client_caps = self._client_params.capabilities - # Check each specified capability in the passed in capability object if capability.roots is not None: if client_caps.roots is None: return False @@ -153,93 +141,107 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: if capability.sampling is not None: if client_caps.sampling is None: return False + if capability.sampling.context is not None and client_caps.sampling.context is None: + return False + if capability.sampling.tools is not None and client_caps.sampling.tools is None: + return False + + if capability.elicitation is not None and client_caps.elicitation is None: + return False if capability.experimental is not None: if client_caps.experimental is None: return False - # Check each experimental capability for exp_key, exp_value in capability.experimental.items(): - if ( - exp_key not in client_caps.experimental - or client_caps.experimental[exp_key] != exp_value - ): + if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value: return False + if capability.tasks is not None: + if client_caps.tasks is None: + return False + if not check_tasks_capability(capability.tasks, client_caps.tasks): + return False + return True - async def _received_request( - self, responder: RequestResponder[types.ClientRequest, types.ServerResult] - ): + async def _receive_loop(self) -> None: + async with self._incoming_message_stream_writer: + await super()._receive_loop() + + async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]): match responder.request.root: case types.InitializeRequest(params=params): + requested_version = params.protocolVersion self._initialization_state = InitializationState.Initializing self._client_params = params with responder: await responder.respond( types.ServerResult( types.InitializeResult( - protocolVersion=types.LATEST_PROTOCOL_VERSION, + protocolVersion=requested_version + if requested_version in SUPPORTED_PROTOCOL_VERSIONS + else types.LATEST_PROTOCOL_VERSION, capabilities=self._init_options.capabilities, serverInfo=types.Implementation( name=self._init_options.server_name, version=self._init_options.server_version, + websiteUrl=self._init_options.website_url, + icons=self._init_options.icons, ), instructions=self._init_options.instructions, ) ) ) + self._initialization_state = InitializationState.Initialized + case types.PingRequest(): + # Ping requests are allowed at any time + pass case _: if self._initialization_state != InitializationState.Initialized: - raise RuntimeError( - "Received request before initialization was complete" - ) + raise RuntimeError("Received request before initialization was complete") - async def _received_notification( - self, notification: types.ClientNotification - ) -> None: + async def _received_notification(self, notification: types.ClientNotification) -> None: # Need this to avoid ASYNC910 await anyio.lowlevel.checkpoint() - match notification.root: case types.InitializedNotification(): self._initialization_state = InitializationState.Initialized case _: - if self._initialization_state != InitializationState.Initialized: - raise RuntimeError( - "Received notification before initialization was complete" - ) + if self._initialization_state != InitializationState.Initialized: # pragma: no cover + raise RuntimeError("Received notification before initialization was complete") async def send_log_message( self, level: types.LoggingLevel, data: Any, logger: str | None = None, + related_request_id: types.RequestId | None = None, ) -> None: """Send a log message notification.""" await self.send_notification( types.ServerNotification( types.LoggingMessageNotification( - method="notifications/message", params=types.LoggingMessageNotificationParams( level=level, data=data, logger=logger, ), ) - ) + ), + related_request_id, ) - async def send_resource_updated(self, uri: AnyUrl) -> None: + async def send_resource_updated(self, uri: AnyUrl) -> None: # pragma: no cover """Send a resource updated notification.""" await self.send_notification( types.ServerNotification( types.ResourceUpdatedNotification( - method="notifications/resources/updated", params=types.ResourceUpdatedNotificationParams(uri=uri), ) ) ) + @overload async def create_message( self, messages: list[types.SamplingMessage], @@ -251,46 +253,205 @@ async def create_message( stop_sequences: list[str] | None = None, metadata: dict[str, Any] | None = None, model_preferences: types.ModelPreferences | None = None, + tools: None = None, + tool_choice: types.ToolChoice | None = None, + related_request_id: types.RequestId | None = None, ) -> types.CreateMessageResult: - """Send a sampling/createMessage request.""" + """Overload: Without tools, returns single content.""" + ... + + @overload + async def create_message( + self, + messages: list[types.SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: types.IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: types.ModelPreferences | None = None, + tools: list[types.Tool], + tool_choice: types.ToolChoice | None = None, + related_request_id: types.RequestId | None = None, + ) -> types.CreateMessageResultWithTools: + """Overload: With tools, returns array-capable content.""" + ... + + async def create_message( + self, + messages: list[types.SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: types.IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: types.ModelPreferences | None = None, + tools: list[types.Tool] | None = None, + tool_choice: types.ToolChoice | None = None, + related_request_id: types.RequestId | None = None, + ) -> types.CreateMessageResult | types.CreateMessageResultWithTools: + """Send a sampling/create_message request. + + Args: + messages: The conversation messages to send. + max_tokens: Maximum number of tokens to generate. + system_prompt: Optional system prompt. + include_context: Optional context inclusion setting. + Should only be set to "thisServer" or "allServers" + if the client has sampling.context capability. + temperature: Optional sampling temperature. + stop_sequences: Optional stop sequences. + metadata: Optional metadata to pass through to the LLM provider. + model_preferences: Optional model selection preferences. + tools: Optional list of tools the LLM can use during sampling. + Requires client to have sampling.tools capability. + tool_choice: Optional control over tool usage behavior. + Requires client to have sampling.tools capability. + related_request_id: Optional ID of a related request. + + Returns: + The sampling result from the client. + + Raises: + McpError: If tools are provided but client doesn't support them. + ValueError: If tool_use or tool_result message structure is invalid. + """ + client_caps = self._client_params.capabilities if self._client_params else None + validate_sampling_tools(client_caps, tools, tool_choice) + validate_tool_use_result_messages(messages) + + request = types.ServerRequest( + types.CreateMessageRequest( + params=types.CreateMessageRequestParams( + messages=messages, + systemPrompt=system_prompt, + includeContext=include_context, + temperature=temperature, + maxTokens=max_tokens, + stopSequences=stop_sequences, + metadata=metadata, + modelPreferences=model_preferences, + tools=tools, + toolChoice=tool_choice, + ), + ) + ) + metadata_obj = ServerMessageMetadata(related_request_id=related_request_id) + + # Use different result types based on whether tools are provided + if tools is not None: + return await self.send_request( + request=request, + result_type=types.CreateMessageResultWithTools, + metadata=metadata_obj, + ) return await self.send_request( - types.ServerRequest( - types.CreateMessageRequest( - method="sampling/createMessage", - params=types.CreateMessageRequestParams( - messages=messages, - systemPrompt=system_prompt, - includeContext=include_context, - temperature=temperature, - maxTokens=max_tokens, - stopSequences=stop_sequences, - metadata=metadata, - modelPreferences=model_preferences, - ), - ) - ), - types.CreateMessageResult, + request=request, + result_type=types.CreateMessageResult, + metadata=metadata_obj, ) - + async def list_roots(self) -> types.ListRootsResult: """Send a roots/list request.""" + return await self.send_request( + types.ServerRequest(types.ListRootsRequest()), + types.ListRootsResult, + ) + + async def elicit( + self, + message: str, + requestedSchema: types.ElicitRequestedSchema, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send a form mode elicitation/create request. + + Args: + message: The message to present to the user + requestedSchema: Schema defining the expected response structure + related_request_id: Optional ID of the request that triggered this elicitation + + Returns: + The client's response + + Note: + This method is deprecated in favor of elicit_form(). It remains for + backward compatibility but new code should use elicit_form(). + """ + return await self.elicit_form(message, requestedSchema, related_request_id) + + async def elicit_form( + self, + message: str, + requestedSchema: types.ElicitRequestedSchema, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send a form mode elicitation/create request. + + Args: + message: The message to present to the user + requestedSchema: Schema defining the expected response structure + related_request_id: Optional ID of the request that triggered this elicitation + + Returns: + The client's response with form data + """ return await self.send_request( types.ServerRequest( - types.ListRootsRequest( - method="roots/list", + types.ElicitRequest( + params=types.ElicitRequestFormParams( + message=message, + requestedSchema=requestedSchema, + ), ) ), - types.ListRootsResult, + types.ElicitResult, + metadata=ServerMessageMetadata(related_request_id=related_request_id), ) - async def send_ping(self) -> types.EmptyResult: - """Send a ping request.""" + async def elicit_url( + self, + message: str, + url: str, + elicitation_id: str, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send a URL mode elicitation/create request. + + This directs the user to an external URL for out-of-band interactions + like OAuth flows, credential collection, or payment processing. + + Args: + message: Human-readable explanation of why the interaction is needed + url: The URL the user should navigate to + elicitation_id: Unique identifier for tracking this elicitation + related_request_id: Optional ID of the request that triggered this elicitation + + Returns: + The client's response indicating acceptance, decline, or cancellation + """ return await self.send_request( types.ServerRequest( - types.PingRequest( - method="ping", + types.ElicitRequest( + params=types.ElicitRequestURLParams( + message=message, + url=url, + elicitationId=elicitation_id, + ), ) ), + types.ElicitResult, + metadata=ServerMessageMetadata(related_request_id=related_request_id), + ) + + async def send_ping(self) -> types.EmptyResult: # pragma: no cover + """Send a ping request.""" + return await self.send_request( + types.ServerRequest(types.PingRequest()), types.EmptyResult, ) @@ -299,51 +460,235 @@ async def send_progress_notification( progress_token: str | int, progress: float, total: float | None = None, + message: str | None = None, + related_request_id: str | None = None, ) -> None: """Send a progress notification.""" await self.send_notification( types.ServerNotification( types.ProgressNotification( - method="notifications/progress", params=types.ProgressNotificationParams( progressToken=progress_token, progress=progress, total=total, + message=message, ), ) - ) + ), + related_request_id, ) - async def send_resource_list_changed(self) -> None: + async def send_resource_list_changed(self) -> None: # pragma: no cover """Send a resource list changed notification.""" - await self.send_notification( - types.ServerNotification( - types.ResourceListChangedNotification( - method="notifications/resources/list_changed", - ) - ) - ) + await self.send_notification(types.ServerNotification(types.ResourceListChangedNotification())) - async def send_tool_list_changed(self) -> None: + async def send_tool_list_changed(self) -> None: # pragma: no cover """Send a tool list changed notification.""" - await self.send_notification( - types.ServerNotification( - types.ToolListChangedNotification( - method="notifications/tools/list_changed", - ) - ) - ) + await self.send_notification(types.ServerNotification(types.ToolListChangedNotification())) - async def send_prompt_list_changed(self) -> None: + async def send_prompt_list_changed(self) -> None: # pragma: no cover """Send a prompt list changed notification.""" + await self.send_notification(types.ServerNotification(types.PromptListChangedNotification())) + + async def send_elicit_complete( + self, + elicitation_id: str, + related_request_id: types.RequestId | None = None, + ) -> None: + """Send an elicitation completion notification. + + This should be sent when a URL mode elicitation has been completed + out-of-band to inform the client that it may retry any requests + that were waiting for this elicitation. + + Args: + elicitation_id: The unique identifier of the completed elicitation + related_request_id: Optional ID of the request that triggered this + """ await self.send_notification( types.ServerNotification( - types.PromptListChangedNotification( - method="notifications/prompts/list_changed", + types.ElicitCompleteNotification( + params=types.ElicitCompleteNotificationParams(elicitationId=elicitation_id) ) - ) + ), + related_request_id, + ) + + def _build_elicit_form_request( + self, + message: str, + requestedSchema: types.ElicitRequestedSchema, + related_task_id: str | None = None, + task: types.TaskMetadata | None = None, + ) -> types.JSONRPCRequest: + """Build a form mode elicitation request without sending it. + + Args: + message: The message to present to the user + requestedSchema: Schema defining the expected response structure + related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata + task: If provided, makes this a task-augmented request + + Returns: + A JSONRPCRequest ready to be sent or queued + """ + params = types.ElicitRequestFormParams( + message=message, + requestedSchema=requestedSchema, + task=task, + ) + params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) + + # Add related-task metadata if associated with a parent task + if related_task_id is not None: + # Defensive: model_dump() never includes _meta, but guard against future changes + if "_meta" not in params_data: # pragma: no cover + params_data["_meta"] = {} + params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( + taskId=related_task_id + ).model_dump(by_alias=True) + + request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id + if related_task_id is None: + self._request_id += 1 + + return types.JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + method="elicitation/create", + params=params_data, ) + def _build_elicit_url_request( + self, + message: str, + url: str, + elicitation_id: str, + related_task_id: str | None = None, + ) -> types.JSONRPCRequest: + """Build a URL mode elicitation request without sending it. + + Args: + message: Human-readable explanation of why the interaction is needed + url: The URL the user should navigate to + elicitation_id: Unique identifier for tracking this elicitation + related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata + + Returns: + A JSONRPCRequest ready to be sent or queued + """ + params = types.ElicitRequestURLParams( + message=message, + url=url, + elicitationId=elicitation_id, + ) + params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) + + # Add related-task metadata if associated with a parent task + if related_task_id is not None: + # Defensive: model_dump() never includes _meta, but guard against future changes + if "_meta" not in params_data: # pragma: no cover + params_data["_meta"] = {} + params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( + taskId=related_task_id + ).model_dump(by_alias=True) + + request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id + if related_task_id is None: + self._request_id += 1 + + return types.JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + method="elicitation/create", + params=params_data, + ) + + def _build_create_message_request( + self, + messages: list[types.SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: types.IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: types.ModelPreferences | None = None, + tools: list[types.Tool] | None = None, + tool_choice: types.ToolChoice | None = None, + related_task_id: str | None = None, + task: types.TaskMetadata | None = None, + ) -> types.JSONRPCRequest: + """Build a sampling/createMessage request without sending it. + + Args: + messages: The conversation messages to send + max_tokens: Maximum number of tokens to generate + system_prompt: Optional system prompt + include_context: Optional context inclusion setting + temperature: Optional sampling temperature + stop_sequences: Optional stop sequences + metadata: Optional metadata to pass through to the LLM provider + model_preferences: Optional model selection preferences + tools: Optional list of tools the LLM can use during sampling + tool_choice: Optional control over tool usage behavior + related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata + task: If provided, makes this a task-augmented request + + Returns: + A JSONRPCRequest ready to be sent or queued + """ + params = types.CreateMessageRequestParams( + messages=messages, + systemPrompt=system_prompt, + includeContext=include_context, + temperature=temperature, + maxTokens=max_tokens, + stopSequences=stop_sequences, + metadata=metadata, + modelPreferences=model_preferences, + tools=tools, + toolChoice=tool_choice, + task=task, + ) + params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) + + # Add related-task metadata if associated with a parent task + if related_task_id is not None: + # Defensive: model_dump() never includes _meta, but guard against future changes + if "_meta" not in params_data: # pragma: no cover + params_data["_meta"] = {} + params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( + taskId=related_task_id + ).model_dump(by_alias=True) + + request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id + if related_task_id is None: + self._request_id += 1 + + return types.JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + method="sampling/createMessage", + params=params_data, + ) + + async def send_message(self, message: SessionMessage) -> None: + """Send a raw session message. + + This is primarily used by TaskResultHandler to deliver queued messages + (elicitation/sampling requests) to the client during task execution. + + WARNING: This is a low-level experimental method that may change without + notice. Prefer using higher-level methods like send_notification() or + send_request() for normal operations. + + Args: + message: The session message to send + """ + await self._write_stream.send(message) + async def _handle_incoming(self, req: ServerRequestResponder) -> None: await self._incoming_message_stream_writer.send(req) diff --git a/src/mcp/types.py b/src/mcp/types.py index 7af39fdfd..3be6b0e26 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -5,7 +5,6 @@ from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel from pydantic.networks import AnyUrl, UrlConstraints from typing_extensions import deprecated -from typing import Any """ Model Context Protocol bindings for Python @@ -1147,6 +1146,10 @@ class ToolResultContent(BaseModel): SamplingMessageContentBlock: TypeAlias = TextContent | ImageContent | AudioContent | ToolUseContent | ToolResultContent """Content block types allowed in sampling messages.""" +SamplingContent: TypeAlias = TextContent | ImageContent | AudioContent +"""Basic content types for sampling responses (without tool use). +Used for backwards-compatible CreateMessageResult when tools are not used.""" + class SamplingMessage(BaseModel): """Describes a message issued to or received from an LLM API.""" @@ -1544,7 +1547,27 @@ class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling class CreateMessageResult(Result): - """The client's response to a sampling/create_message request from the server.""" + """The client's response to a sampling/create_message request from the server. + + This is the backwards-compatible version that returns single content (no arrays). + Used when the request does not include tools. + """ + + role: Role + """The role of the message sender (typically 'assistant' for LLM responses).""" + content: SamplingContent + """Response content. Single content block (text, image, or audio).""" + model: str + """The name of the model that generated the message.""" + stopReason: StopReason | None = None + """The reason why sampling stopped, if known.""" + + +class CreateMessageResultWithTools(Result): + """The client's response to a sampling/create_message request when tools were provided. + + This version supports array content for tool use flows. + """ role: Role """The role of the message sender (typically 'assistant' for LLM responses).""" From dc6a24267bfe1f6e76530e00f325014519c93549 Mon Sep 17 00:00:00 2001 From: "Swathi Murali(sm3223)" Date: Thu, 4 Dec 2025 22:48:42 -0500 Subject: [PATCH 4/5] uncomment test changes --- src/mcp/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index bf0b78c54..fbec40d0a 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -13,7 +13,7 @@ CompleteRequest, CreateMessageRequest, CreateMessageResult, - #CreateMessageResultWithTools, + CreateMessageResultWithTools, ErrorData, GetPromptRequest, GetPromptResult, @@ -43,7 +43,7 @@ ResourceUpdatedNotification, RootsCapability, SamplingCapability, - #SamplingContent, + SamplingContent, SamplingContextCapability, SamplingMessage, SamplingMessageContentBlock, From b6a5ffe0def7825c57eb8fe485225b1d2f93126a Mon Sep 17 00:00:00 2001 From: "Swathi Murali(sm3223)" Date: Fri, 5 Dec 2025 09:14:01 -0500 Subject: [PATCH 5/5] fix fastmcp --- src/mcp/server/fastmcp/server.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 93147ee3d..b269542ec 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -204,6 +204,7 @@ def __init__( # noqa: PLR0913 transport_security=transport_security, ) + self._checkpoint_backend = checkpoint_backend self._mcp_server = MCPServer( name=name or "FastMCP", instructions=instructions, @@ -212,6 +213,7 @@ def __init__( # noqa: PLR0913 # TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an FastMCP and Server. # We need to create a Lifespan type that is a generic on the server type, like Starlette does. lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore + checkpoint_backend=self._checkpoint_backend, ) self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) @@ -232,7 +234,6 @@ def __init__( # noqa: PLR0913 if auth_server_provider and not token_verifier: # pragma: no cover self._token_verifier = ProviderTokenVerifier(auth_server_provider) self._event_store = event_store - self._checkpoint_backend = checkpoint_backend self._retry_interval = retry_interval self._custom_starlette_routes: list[Route] = [] self.dependencies = self.settings.dependencies @@ -278,11 +279,6 @@ def session_manager(self) -> StreamableHTTPSessionManager: "to avoid unnecessary initialization." ) return self._session_manager # pragma: no cover - - @property - def checkpoint_backend(self) -> CheckpointBackend | None: - """Return the checkpoint backend (if any) attached to this server.""" - return self._checkpoint_backend def run( self,