Skip to content

Commit 84fdd7a

Browse files
authored
Support tool and resource caching for MCP servers that support change notifications (#3560)
1 parent 212f935 commit 84fdd7a

File tree

4 files changed

+311
-19
lines changed

4 files changed

+311
-19
lines changed

pydantic_ai_slim/pydantic_ai/_mcp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def add_msg(
9191
'user',
9292
mcp_types.ImageContent(
9393
type='image',
94-
data=base64.b64decode(chunk.data).decode(),
94+
data=base64.b64encode(chunk.data).decode(),
9595
mimeType=chunk.media_type,
9696
),
9797
)

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 115 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from mcp.shared import exceptions as mcp_exceptions
3737
from mcp.shared.context import RequestContext
3838
from mcp.shared.message import SessionMessage
39+
from mcp.shared.session import RequestResponder
3940
except ImportError as _import_error:
4041
raise ImportError(
4142
'Please install the `mcp` package to use the MCP server, '
@@ -226,12 +227,21 @@ class ServerCapabilities:
226227
prompts: bool = False
227228
"""Whether the server offers any prompt templates."""
228229

230+
prompts_list_changed: bool = False
231+
"""Whether the server will emit notifications when the list of prompts changes."""
232+
229233
resources: bool = False
230234
"""Whether the server offers any resources to read."""
231235

236+
resources_list_changed: bool = False
237+
"""Whether the server will emit notifications when the list of resources changes."""
238+
232239
tools: bool = False
233240
"""Whether the server offers any tools to call."""
234241

242+
tools_list_changed: bool = False
243+
"""Whether the server will emit notifications when the list of tools changes."""
244+
235245
completions: bool = False
236246
"""Whether the server offers autocompletion suggestions for prompts and resources."""
237247

@@ -244,12 +254,18 @@ def from_mcp_sdk(cls, mcp_capabilities: mcp_types.ServerCapabilities) -> ServerC
244254
Args:
245255
mcp_capabilities: The MCP SDK ServerCapabilities object.
246256
"""
257+
prompts_cap = mcp_capabilities.prompts
258+
resources_cap = mcp_capabilities.resources
259+
tools_cap = mcp_capabilities.tools
247260
return cls(
248261
experimental=list(mcp_capabilities.experimental.keys()) if mcp_capabilities.experimental else None,
249262
logging=mcp_capabilities.logging is not None,
250-
prompts=mcp_capabilities.prompts is not None,
251-
resources=mcp_capabilities.resources is not None,
252-
tools=mcp_capabilities.tools is not None,
263+
prompts=prompts_cap is not None,
264+
prompts_list_changed=bool(prompts_cap.listChanged) if prompts_cap else False,
265+
resources=resources_cap is not None,
266+
resources_list_changed=bool(resources_cap.listChanged) if resources_cap else False,
267+
tools=tools_cap is not None,
268+
tools_list_changed=bool(tools_cap.listChanged) if tools_cap else False,
253269
completions=mcp_capabilities.completions is not None,
254270
)
255271

@@ -319,6 +335,26 @@ class MCPServer(AbstractToolset[Any], ABC):
319335
elicitation_callback: ElicitationFnT | None = None
320336
"""Callback function to handle elicitation requests from the server."""
321337

338+
cache_tools: bool
339+
"""Whether to cache the list of tools.
340+
341+
When enabled (default), tools are fetched once and cached until either:
342+
- The server sends a `notifications/tools/list_changed` notification
343+
- The connection is closed
344+
345+
Set to `False` for servers that change tools dynamically without sending notifications.
346+
"""
347+
348+
cache_resources: bool
349+
"""Whether to cache the list of resources.
350+
351+
When enabled (default), resources are fetched once and cached until either:
352+
- The server sends a `notifications/resources/list_changed` notification
353+
- The connection is closed
354+
355+
Set to `False` for servers that change resources dynamically without sending notifications.
356+
"""
357+
322358
_id: str | None
323359

324360
_enter_lock: Lock = field(compare=False)
@@ -332,6 +368,9 @@ class MCPServer(AbstractToolset[Any], ABC):
332368
_server_capabilities: ServerCapabilities
333369
_instructions: str | None
334370

371+
_cached_tools: list[mcp_types.Tool] | None
372+
_cached_resources: list[Resource] | None
373+
335374
def __init__(
336375
self,
337376
tool_prefix: str | None = None,
@@ -344,6 +383,8 @@ def __init__(
344383
sampling_model: models.Model | None = None,
345384
max_retries: int = 1,
346385
elicitation_callback: ElicitationFnT | None = None,
386+
cache_tools: bool = True,
387+
cache_resources: bool = True,
347388
*,
348389
id: str | None = None,
349390
):
@@ -357,6 +398,8 @@ def __init__(
357398
self.sampling_model = sampling_model
358399
self.max_retries = max_retries
359400
self.elicitation_callback = elicitation_callback
401+
self.cache_tools = cache_tools
402+
self.cache_resources = cache_resources
360403

361404
self._id = id or tool_prefix
362405

@@ -366,6 +409,8 @@ def __post_init__(self):
366409
self._enter_lock = Lock()
367410
self._running_count = 0
368411
self._exit_stack = None
412+
self._cached_tools = None
413+
self._cached_resources = None
369414

370415
@abstractmethod
371416
@asynccontextmanager
@@ -430,13 +475,22 @@ def instructions(self) -> str | None:
430475
async def list_tools(self) -> list[mcp_types.Tool]:
431476
"""Retrieve tools that are currently active on the server.
432477
433-
Note:
434-
- We don't cache tools as they might change.
435-
- We also don't subscribe to the server to avoid complexity.
478+
Tools are cached by default, with cache invalidation on:
479+
- `notifications/tools/list_changed` notifications from the server
480+
- Connection close (cache is cleared in `__aexit__`)
481+
482+
Set `cache_tools=False` for servers that change tools without sending notifications.
436483
"""
437-
async with self: # Ensure server is running
438-
result = await self._client.list_tools()
439-
return result.tools
484+
async with self:
485+
if self.cache_tools:
486+
if self._cached_tools is not None:
487+
return self._cached_tools
488+
result = await self._client.list_tools()
489+
self._cached_tools = result.tools
490+
return result.tools
491+
else:
492+
result = await self._client.list_tools()
493+
return result.tools
440494

441495
async def direct_call_tool(
442496
self,
@@ -542,21 +596,31 @@ def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[Any]:
542596
async def list_resources(self) -> list[Resource]:
543597
"""Retrieve resources that are currently present on the server.
544598
545-
Note:
546-
- We don't cache resources as they might change.
547-
- We also don't subscribe to resource changes to avoid complexity.
599+
Resources are cached by default, with cache invalidation on:
600+
- `notifications/resources/list_changed` notifications from the server
601+
- Connection close (cache is cleared in `__aexit__`)
602+
603+
Set `cache_resources=False` for servers that change resources without sending notifications.
548604
549605
Raises:
550606
MCPError: If the server returns an error.
551607
"""
552-
async with self: # Ensure server is running
608+
async with self:
553609
if not self.capabilities.resources:
554610
return []
555611
try:
556-
result = await self._client.list_resources()
612+
if self.cache_resources:
613+
if self._cached_resources is not None:
614+
return self._cached_resources
615+
result = await self._client.list_resources()
616+
resources = [Resource.from_mcp_sdk(r) for r in result.resources]
617+
self._cached_resources = resources
618+
return resources
619+
else:
620+
result = await self._client.list_resources()
621+
return [Resource.from_mcp_sdk(r) for r in result.resources]
557622
except mcp_exceptions.McpError as e:
558623
raise MCPError.from_mcp_sdk(e) from e
559-
return [Resource.from_mcp_sdk(r) for r in result.resources]
560624

561625
async def list_resource_templates(self) -> list[ResourceTemplate]:
562626
"""Retrieve resource templates that are currently present on the server.
@@ -628,6 +692,7 @@ async def __aenter__(self) -> Self:
628692
elicitation_callback=self.elicitation_callback,
629693
logging_callback=self.log_handler,
630694
read_timeout_seconds=timedelta(seconds=self.read_timeout),
695+
message_handler=self._handle_notification,
631696
)
632697
self._client = await exit_stack.enter_async_context(client)
633698

@@ -651,6 +716,8 @@ async def __aexit__(self, *args: Any) -> bool | None:
651716
if self._running_count == 0 and self._exit_stack is not None:
652717
await self._exit_stack.aclose()
653718
self._exit_stack = None
719+
self._cached_tools = None
720+
self._cached_resources = None
654721

655722
@property
656723
def is_running(self) -> bool:
@@ -680,6 +747,19 @@ async def _sampling_callback(
680747
model=self.sampling_model.model_name,
681748
)
682749

750+
async def _handle_notification(
751+
self,
752+
message: RequestResponder[mcp_types.ServerRequest, mcp_types.ClientResult]
753+
| mcp_types.ServerNotification
754+
| Exception,
755+
) -> None:
756+
"""Handle notifications from the MCP server, invalidating caches as needed."""
757+
if isinstance(message, mcp_types.ServerNotification): # pragma: no branch
758+
if isinstance(message.root, mcp_types.ToolListChangedNotification):
759+
self._cached_tools = None
760+
elif isinstance(message.root, mcp_types.ResourceListChangedNotification):
761+
self._cached_resources = None
762+
683763
async def _map_tool_result_part(
684764
self, part: mcp_types.ContentBlock
685765
) -> str | messages.BinaryContent | dict[str, Any] | list[Any]:
@@ -776,6 +856,8 @@ class MCPServerStdio(MCPServer):
776856
sampling_model: models.Model | None
777857
max_retries: int
778858
elicitation_callback: ElicitationFnT | None = None
859+
cache_tools: bool
860+
cache_resources: bool
779861

780862
def __init__(
781863
self,
@@ -794,6 +876,8 @@ def __init__(
794876
sampling_model: models.Model | None = None,
795877
max_retries: int = 1,
796878
elicitation_callback: ElicitationFnT | None = None,
879+
cache_tools: bool = True,
880+
cache_resources: bool = True,
797881
id: str | None = None,
798882
):
799883
"""Build a new MCP server.
@@ -813,6 +897,10 @@ def __init__(
813897
sampling_model: The model to use for sampling.
814898
max_retries: The maximum number of times to retry a tool call.
815899
elicitation_callback: Callback function to handle elicitation requests from the server.
900+
cache_tools: Whether to cache the list of tools.
901+
See [`MCPServer.cache_tools`][pydantic_ai.mcp.MCPServer.cache_tools].
902+
cache_resources: Whether to cache the list of resources.
903+
See [`MCPServer.cache_resources`][pydantic_ai.mcp.MCPServer.cache_resources].
816904
id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow.
817905
"""
818906
self.command = command
@@ -831,6 +919,8 @@ def __init__(
831919
sampling_model,
832920
max_retries,
833921
elicitation_callback,
922+
cache_tools,
923+
cache_resources,
834924
id=id,
835925
)
836926

@@ -930,6 +1020,8 @@ class _MCPServerHTTP(MCPServer):
9301020
sampling_model: models.Model | None
9311021
max_retries: int
9321022
elicitation_callback: ElicitationFnT | None = None
1023+
cache_tools: bool
1024+
cache_resources: bool
9331025

9341026
def __init__(
9351027
self,
@@ -948,6 +1040,8 @@ def __init__(
9481040
sampling_model: models.Model | None = None,
9491041
max_retries: int = 1,
9501042
elicitation_callback: ElicitationFnT | None = None,
1043+
cache_tools: bool = True,
1044+
cache_resources: bool = True,
9511045
**_deprecated_kwargs: Any,
9521046
):
9531047
"""Build a new MCP server.
@@ -967,6 +1061,10 @@ def __init__(
9671061
sampling_model: The model to use for sampling.
9681062
max_retries: The maximum number of times to retry a tool call.
9691063
elicitation_callback: Callback function to handle elicitation requests from the server.
1064+
cache_tools: Whether to cache the list of tools.
1065+
See [`MCPServer.cache_tools`][pydantic_ai.mcp.MCPServer.cache_tools].
1066+
cache_resources: Whether to cache the list of resources.
1067+
See [`MCPServer.cache_resources`][pydantic_ai.mcp.MCPServer.cache_resources].
9701068
"""
9711069
if 'sse_read_timeout' in _deprecated_kwargs:
9721070
if read_timeout is not None:
@@ -997,6 +1095,8 @@ def __init__(
9971095
sampling_model,
9981096
max_retries,
9991097
elicitation_callback,
1098+
cache_tools,
1099+
cache_resources,
10001100
id=id,
10011101
)
10021102

tests/mcp_server.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,19 @@ async def use_elicitation(ctx: Context[ServerSession, None], question: str) -> s
235235
return f'User {result.action}ed the elicitation'
236236

237237

238+
async def hidden_tool() -> str:
239+
"""A tool that is hidden by default."""
240+
return 'I was hidden!'
241+
242+
243+
@mcp.tool()
244+
async def enable_hidden_tool(ctx: Context[ServerSession, None]) -> str:
245+
"""Enable the hidden tool, triggering a ToolListChangedNotification."""
246+
mcp._tool_manager.add_tool(hidden_tool) # pyright: ignore[reportPrivateUsage]
247+
await ctx.session.send_tool_list_changed()
248+
return 'Hidden tool enabled'
249+
250+
238251
@mcp._mcp_server.set_logging_level() # pyright: ignore[reportPrivateUsage]
239252
async def set_logging_level(level: str) -> None:
240253
global log_level

0 commit comments

Comments
 (0)