Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions pydantic_ai_slim/pydantic_ai/_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from . import messages as _messages
from ._instrumentation import InstrumentationNames
from ._run_context import AgentDepsT, RunContext
from .exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior
from .exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior, UnknownToolNameRetry
from .messages import ToolCallPart
from .tools import ToolDefinition
from .toolsets.abstract import AbstractToolset, ToolsetTool
Expand All @@ -35,6 +35,8 @@ class ToolManager(Generic[AgentDepsT]):
"""The cached tools for this run step."""
failed_tools: set[str] = field(default_factory=set)
"""Names of tools that failed in this run step."""
max_unknown_tool_retries: int = 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think of calling it default_max_retries?

"""Maximum number of times to retry after an unknown tool is proposed"""

@classmethod
@contextmanager
Expand Down Expand Up @@ -146,7 +148,7 @@ async def _call_tool(
msg = f'Available tools: {", ".join(f"{name!r}" for name in self.tools.keys())}'
else:
msg = 'No tools available.'
raise ModelRetry(f'Unknown tool name: {name!r}. {msg}')
raise UnknownToolNameRetry(name, msg)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to keep this as it was for now.


if tool.tool_def.kind == 'external':
raise RuntimeError('External tools cannot be called')
Expand Down Expand Up @@ -176,7 +178,10 @@ async def _call_tool(

return result
except (ValidationError, ModelRetry) as e:
max_retries = tool.max_retries if tool is not None else 1
if isinstance(e, UnknownToolNameRetry):
max_retries = self.max_unknown_tool_retries
else:
max_retries = tool.max_retries if tool is not None else 1
current_retry = self.ctx.retries.get(name, 0)

if current_retry == max_retries:
Expand Down
4 changes: 3 additions & 1 deletion pydantic_ai_slim/pydantic_ai/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
_prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
_max_result_retries: int = dataclasses.field(repr=False)
_max_tool_retries: int = dataclasses.field(repr=False)
_max_unknown_tool_retries: int = dataclasses.field(repr=False)
_validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = dataclasses.field(repr=False)

_event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False)
Expand Down Expand Up @@ -318,6 +319,7 @@ def __init__(

self._max_result_retries = output_retries if output_retries is not None else retries
self._max_tool_retries = retries
self._max_unknown_tool_retries = retries
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can simplify this a bit by dropping this variable as well


self._validation_context = validation_context

Expand Down Expand Up @@ -569,7 +571,7 @@ async def main():
output_toolset.max_retries = self._max_result_retries
output_toolset.output_validators = output_validators
toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
tool_manager = ToolManager[AgentDepsT](toolset)
tool_manager = ToolManager[AgentDepsT](toolset, max_unknown_tool_retries=self._max_unknown_tool_retries)

# Build the graph
graph = _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_)
Expand Down
7 changes: 7 additions & 0 deletions pydantic_ai_slim/pydantic_ai/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ def __get_pydantic_core_schema__(cls, _: Any, __: Any) -> core_schema.CoreSchema
)


class UnknownToolNameRetry(ModelRetry):
"""Exception to raise when a tool name is not recognized."""

def __init__(self, name: str, msg: str):
super().__init__(f'Unknown tool name: {name!r}. {msg}')


class CallDeferred(Exception):
"""Exception to raise when a tool call should be deferred.

Expand Down