diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py index 9a9f93e1ff..8369d76913 100644 --- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py @@ -14,7 +14,7 @@ from . import messages as _messages from ._instrumentation import InstrumentationNames from ._run_context import AgentDepsT, RunContext -from .exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior +from .exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior, UnknownToolNameRetry from .messages import ToolCallPart from .tools import ToolDefinition from .toolsets.abstract import AbstractToolset, ToolsetTool @@ -35,6 +35,8 @@ class ToolManager(Generic[AgentDepsT]): """The cached tools for this run step.""" failed_tools: set[str] = field(default_factory=set) """Names of tools that failed in this run step.""" + max_unknown_tool_retries: int = 1 + """Maximum number of times to retry after an unknown tool is proposed""" @classmethod @contextmanager @@ -146,7 +148,7 @@ async def _call_tool( msg = f'Available tools: {", ".join(f"{name!r}" for name in self.tools.keys())}' else: msg = 'No tools available.' - raise ModelRetry(f'Unknown tool name: {name!r}. {msg}') + raise UnknownToolNameRetry(name, msg) if tool.tool_def.kind == 'external': raise RuntimeError('External tools cannot be called') @@ -176,7 +178,10 @@ async def _call_tool( return result except (ValidationError, ModelRetry) as e: - max_retries = tool.max_retries if tool is not None else 1 + if isinstance(e, UnknownToolNameRetry): + max_retries = self.max_unknown_tool_retries + else: + max_retries = tool.max_retries if tool is not None else 1 current_retry = self.ctx.retries.get(name, 0) if current_retry == max_retries: diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index c8208ac9e6..9234976355 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -147,6 +147,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]): _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) _max_tool_retries: int = dataclasses.field(repr=False) + _max_unknown_tool_retries: int = dataclasses.field(repr=False) _validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = dataclasses.field(repr=False) _event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False) @@ -318,6 +319,7 @@ def __init__( self._max_result_retries = output_retries if output_retries is not None else retries self._max_tool_retries = retries + self._max_unknown_tool_retries = retries self._validation_context = validation_context @@ -569,7 +571,7 @@ async def main(): output_toolset.max_retries = self._max_result_retries output_toolset.output_validators = output_validators toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) - tool_manager = ToolManager[AgentDepsT](toolset) + tool_manager = ToolManager[AgentDepsT](toolset, max_unknown_tool_retries=self._max_unknown_tool_retries) # Build the graph graph = _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py index 0b4500502c..7d3cd2f478 100644 --- a/pydantic_ai_slim/pydantic_ai/exceptions.py +++ b/pydantic_ai_slim/pydantic_ai/exceptions.py @@ -67,6 +67,13 @@ def __get_pydantic_core_schema__(cls, _: Any, __: Any) -> core_schema.CoreSchema ) +class UnknownToolNameRetry(ModelRetry): + """Exception to raise when a tool name is not recognized.""" + + def __init__(self, name: str, msg: str): + super().__init__(f'Unknown tool name: {name!r}. {msg}') + + class CallDeferred(Exception): """Exception to raise when a tool call should be deferred.