@@ -1,7 +1,7 @@
 from __future__ import annotations as _annotations
 
 import json
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from datetime import timezone
 from functools import cached_property
 from typing import Any, cast
@@ -22,11 +22,12 @@
     UserPromptPart,
 )
 from pydantic_ai.result import Usage
+from pydantic_ai.settings import ModelSettings
 
 from ..conftest import IsNow, try_import
 
 with try_import() as imports_successful:
-    from anthropic import AsyncAnthropic
+    from anthropic import NOT_GIVEN, AsyncAnthropic
     from anthropic.types import (
         ContentBlock,
         Message as AnthropicMessage,
@@ -53,6 +54,7 @@ def test_init():
 class MockAnthropic:
     messages_: AnthropicMessage | list[AnthropicMessage] | None = None
     index = 0
+    chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list)
 
     @cached_property
     def messages(self) -> Any:
@@ -62,7 +64,9 @@ def messages(self) -> Any:
     def create_mock(cls, messages_: AnthropicMessage | list[AnthropicMessage]) -> AsyncAnthropic:
         return cast(AsyncAnthropic, cls(messages_=messages_))
 
-    async def messages_create(self, *_args: Any, **_kwargs: Any) -> AnthropicMessage:
+    async def messages_create(self, *_args: Any, **kwargs: Any) -> AnthropicMessage:
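+        # Record the kwargs of each request (dropping NOT_GIVEN values) so tests can assert on what was sent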
+        self.chat_completion_kwargs.append({k: v for k, v in kwargs.items() if v is not NOT_GIVEN})
+
         assert self.messages_ is not None, '`messages` must be provided'
         if isinstance(self.messages_, list):
             response = self.messages_[self.index]
@@ -257,3 +261,40 @@ async def get_location(loc_name: str) -> str:
             ),
         ]
     )
+
+
+def get_mock_chat_completion_kwargs(async_anthropic: AsyncAnthropic) -> list[dict[str, Any]]:
+    if isinstance(async_anthropic, MockAnthropic):
+        return async_anthropic.chat_completion_kwargs
+    else:  # pragma: no cover
+        raise RuntimeError('Not a MockAnthropic instance')
+
+
+@pytest.mark.parametrize('parallel_tool_calls', [True, False])
+async def test_parallel_tool_calls(allow_model_requests: None, parallel_tool_calls: bool) -> None:
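+    # Two canned responses: first a tool-use block, then the final text reply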
+    responses = [
+        completion_message(
+            [ToolUseBlock(id='1', input={'loc_name': 'San Francisco'}, name='get_location', type='tool_use')],
+            usage=AnthropicUsage(input_tokens=2, output_tokens=1),
+        ),
+        completion_message(
+            [TextBlock(text='final response', type='text')],
+            usage=AnthropicUsage(input_tokens=3, output_tokens=5),
+        ),
+    ]
+
+    mock_client = MockAnthropic.create_mock(responses)
+    m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
+    agent = Agent(m, model_settings=ModelSettings(parallel_tool_calls=parallel_tool_calls))
+
+    @agent.tool_plain
+    async def get_location(loc_name: str) -> str:
+        if loc_name == 'London':
+            return json.dumps({'lat': 51, 'lng': 0})
+        else:
+            raise ModelRetry('Wrong location, please try again')
+
+    await agent.run('hello')
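+    # parallel_tool_calls=False is expected to be sent as tool_choice['disable_parallel_tool_use']=True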
+    assert get_mock_chat_completion_kwargs(mock_client)[0]['tool_choice']['disable_parallel_tool_use'] == (
+        not parallel_tool_calls
+    )