Skip to content

Commit 49d6b40

Browse files
committed
feat: add a global timeout
1 parent 25b62b7 commit 49d6b40

File tree

3 files changed

+108
-6
lines changed

3 files changed

+108
-6
lines changed

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
147147
_prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
148148
_max_result_retries: int = dataclasses.field(repr=False)
149149
_max_tool_retries: int = dataclasses.field(repr=False)
150+
_tool_timeout: float | None = dataclasses.field(repr=False)
150151
_validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = dataclasses.field(repr=False)
151152

152153
_event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False)
@@ -179,6 +180,7 @@ def __init__(
179180
instrument: InstrumentationSettings | bool | None = None,
180181
history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
181182
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
183+
tool_timeout: float | None = None,
182184
) -> None: ...
183185

184186
@overload
@@ -206,6 +208,7 @@ def __init__(
206208
instrument: InstrumentationSettings | bool | None = None,
207209
history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
208210
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
211+
tool_timeout: float | None = None,
209212
) -> None: ...
210213

211214
def __init__(
@@ -231,6 +234,7 @@ def __init__(
231234
instrument: InstrumentationSettings | bool | None = None,
232235
history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
233236
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
237+
tool_timeout: float | None = None,
234238
**_deprecated_kwargs: Any,
235239
):
236240
"""Create an agent.
@@ -285,6 +289,9 @@ def __init__(
285289
Each processor takes a list of messages and returns a modified list of messages.
286290
Processors can be sync or async and are applied in sequence.
287291
event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools.
292+
tool_timeout: Default timeout in seconds for tool execution. If a tool takes longer than this,
293+
a retry prompt is returned to the model. Individual tools can override this with their own timeout.
294+
Defaults to None (no timeout).
288295
"""
289296
if model is None or defer_model_check:
290297
self._model = model
@@ -318,6 +325,7 @@ def __init__(
318325

319326
self._max_result_retries = output_retries if output_retries is not None else retries
320327
self._max_tool_retries = retries
328+
self._tool_timeout = tool_timeout
321329

322330
self._validation_context = validation_context
323331

@@ -331,7 +339,10 @@ def __init__(
331339
self._output_toolset.max_retries = self._max_result_retries
332340

333341
self._function_toolset = _AgentFunctionToolset(
334-
tools, max_retries=self._max_tool_retries, output_schema=self._output_schema
342+
tools,
343+
max_retries=self._max_tool_retries,
344+
default_timeout=self._tool_timeout,
345+
output_schema=self._output_schema,
335346
)
336347
self._dynamic_toolsets = [
337348
DynamicToolset[AgentDepsT](toolset_func=toolset)
@@ -1101,7 +1112,7 @@ async def spam(ctx: RunContext[str], y: float) -> float:
11011112
See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
11021113
metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
11031114
timeout: Timeout in seconds for tool execution. If the tool takes longer, a retry prompt is returned to the model.
1104-
Defaults to None (no timeout).
1115+
Overrides the agent-level `tool_timeout` if set. Defaults to None (no timeout).
11051116
"""
11061117

11071118
def tool_decorator(
@@ -1217,7 +1228,7 @@ async def spam(ctx: RunContext[str]) -> float:
12171228
See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
12181229
metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
12191230
timeout: Timeout in seconds for tool execution. If the tool takes longer, a retry prompt is returned to the model.
1220-
Defaults to None (no timeout).
1231+
Overrides the agent-level `tool_timeout` if set. Defaults to None (no timeout).
12211232
"""
12221233

12231234
def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
@@ -1414,7 +1425,10 @@ def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
14141425

14151426
if some_tools := self._override_tools.get():
14161427
function_toolset = _AgentFunctionToolset(
1417-
some_tools.value, max_retries=self._max_tool_retries, output_schema=self._output_schema
1428+
some_tools.value,
1429+
max_retries=self._max_tool_retries,
1430+
default_timeout=self._tool_timeout,
1431+
output_schema=self._output_schema,
14181432
)
14191433
else:
14201434
function_toolset = self._function_toolset
@@ -1521,11 +1535,12 @@ def __init__(
15211535
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [],
15221536
*,
15231537
max_retries: int = 1,
1538+
default_timeout: float | None = None,
15241539
id: str | None = None,
15251540
output_schema: _output.OutputSchema[Any],
15261541
):
15271542
self.output_schema = output_schema
1528-
super().__init__(tools, max_retries=max_retries, id=id)
1543+
super().__init__(tools, max_retries=max_retries, default_timeout=default_timeout, id=id)
15291544

15301545
@property
15311546
def id(self) -> str:

pydantic_ai_slim/pydantic_ai/toolsets/function.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class FunctionToolset(AbstractToolset[AgentDepsT]):
3737

3838
tools: dict[str, Tool[Any]]
3939
max_retries: int
40+
default_timeout: float | None
4041
_id: str | None
4142
docstring_format: DocstringFormat
4243
require_parameter_descriptions: bool
@@ -47,6 +48,7 @@ def __init__(
4748
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [],
4849
*,
4950
max_retries: int = 1,
51+
default_timeout: float | None = None,
5052
docstring_format: DocstringFormat = 'auto',
5153
require_parameter_descriptions: bool = False,
5254
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
@@ -62,6 +64,9 @@ def __init__(
6264
tools: The tools to add to the toolset.
6365
max_retries: The maximum number of retries for each tool during a run.
6466
Applies to all tools, unless overridden when adding a tool.
67+
default_timeout: Default timeout in seconds for tool execution. If a tool takes longer than this,
68+
a retry prompt is returned to the model. Individual tools can override this with their own timeout.
69+
Defaults to None (no timeout).
6570
docstring_format: Format of tool docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
6671
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
6772
Applies to all tools, unless overridden when adding a tool.
@@ -82,6 +87,7 @@ def __init__(
8287
in which case the ID will be used to identify the toolset's activities within the workflow.
8388
"""
8489
self.max_retries = max_retries
90+
self.default_timeout = default_timeout
8591
self._id = id
8692
self.docstring_format = docstring_format
8793
self.require_parameter_descriptions = require_parameter_descriptions
@@ -360,7 +366,8 @@ async def call_tool(
360366
) -> Any:
361367
assert isinstance(tool, FunctionToolsetTool)
362368

363-
timeout = tool.timeout
369+
# Per-tool timeout takes precedence over default timeout
370+
timeout = tool.timeout if tool.timeout is not None else self.default_timeout
364371
if timeout is not None:
365372
try:
366373
with anyio.fail_after(timeout):

tests/test_tools.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2677,3 +2677,83 @@ async def always_slow_tool() -> str:
26772677

26782678
with pytest.raises(UnexpectedModelBehavior, match='exceeded max retries'):
26792679
await agent.run('call always_slow_tool')
2680+
2681+
2682+
@pytest.mark.anyio
2683+
async def test_agent_level_tool_timeout():
2684+
"""Test that agent-level tool_timeout applies to all tools."""
2685+
import asyncio
2686+
2687+
call_count = 0
2688+
2689+
async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
2690+
nonlocal call_count
2691+
call_count += 1
2692+
if call_count == 1:
2693+
return ModelResponse(parts=[ToolCallPart(tool_name='slow_tool', args={}, tool_call_id='call-1')])
2694+
return ModelResponse(parts=[TextPart(content='done')])
2695+
2696+
# Set global tool_timeout on Agent
2697+
agent = Agent(FunctionModel(model_logic), tool_timeout=0.1)
2698+
2699+
@agent.tool_plain
2700+
async def slow_tool() -> str:
2701+
await asyncio.sleep(1.0) # 1 second, but agent timeout is 0.1s
2702+
return 'done' # pragma: no cover
2703+
2704+
result = await agent.run('call slow_tool')
2705+
2706+
# Check that retry prompt was sent
2707+
retry_parts = [
2708+
part
2709+
for msg in result.all_messages()
2710+
if isinstance(msg, ModelRequest)
2711+
for part in msg.parts
2712+
if isinstance(part, RetryPromptPart) and 'Timed out' in str(part.content)
2713+
]
2714+
assert len(retry_parts) == 1
2715+
assert 'Timed out after 0.1 seconds' in retry_parts[0].content
2716+
2717+
2718+
@pytest.mark.anyio
2719+
async def test_per_tool_timeout_overrides_agent_timeout():
2720+
"""Test that per-tool timeout overrides agent-level timeout."""
2721+
import asyncio
2722+
2723+
call_count = 0
2724+
2725+
async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
2726+
nonlocal call_count
2727+
call_count += 1
2728+
if call_count == 1:
2729+
return ModelResponse(parts=[ToolCallPart(tool_name='fast_timeout_tool', args={}, tool_call_id='call-1')])
2730+
return ModelResponse(parts=[TextPart(content='done')])
2731+
2732+
# Agent has generous 10s timeout, but per-tool timeout is only 0.1s
2733+
agent = Agent(FunctionModel(model_logic), tool_timeout=10.0)
2734+
2735+
@agent.tool_plain(timeout=0.1) # Per-tool timeout overrides agent timeout
2736+
async def fast_timeout_tool() -> str:
2737+
await asyncio.sleep(1.0) # 1 second, per-tool timeout is 0.1s
2738+
return 'done' # pragma: no cover
2739+
2740+
result = await agent.run('call fast_timeout_tool')
2741+
2742+
# Should timeout because per-tool timeout (0.1s) is applied, not agent timeout (10s)
2743+
retry_parts = [
2744+
part
2745+
for msg in result.all_messages()
2746+
if isinstance(msg, ModelRequest)
2747+
for part in msg.parts
2748+
if isinstance(part, RetryPromptPart) and 'Timed out' in str(part.content)
2749+
]
2750+
assert len(retry_parts) == 1
2751+
assert 'Timed out after 0.1 seconds' in retry_parts[0].content
2752+
2753+
2754+
def test_agent_tool_timeout_passed_to_toolset():
2755+
"""Test that agent-level tool_timeout is passed to FunctionToolset as default_timeout."""
2756+
agent = Agent(TestModel(), tool_timeout=30.0)
2757+
2758+
# The agent's tool_timeout should be passed to the toolset as default_timeout
2759+
assert agent._function_toolset.default_timeout == 30.0

0 commit comments

Comments
 (0)