from dataclasses import dataclass, field
from datetime import timezone
from functools import cached_property
- from typing import Any, cast
+ from typing import Any, TypeVar, cast

import pytest
from inline_snapshot import snapshot
from pydantic_ai.settings import ModelSettings

from ..conftest import IsNow, try_import
+ from .mock_async_stream import MockAsyncStream

with try_import() as imports_successful:
    from anthropic import NOT_GIVEN, AsyncAnthropic
    from anthropic.types import (
        ContentBlock,
+         InputJSONDelta,
        Message as AnthropicMessage,
+         MessageDeltaUsage,
+         RawContentBlockDeltaEvent,
+         RawContentBlockStartEvent,
+         RawContentBlockStopEvent,
+         RawMessageDeltaEvent,
+         RawMessageStartEvent,
+         RawMessageStopEvent,
+         RawMessageStreamEvent,
        TextBlock,
        ToolUseBlock,
        Usage as AnthropicUsage,
    )
+     from anthropic.types.raw_message_delta_event import Delta

    from pydantic_ai.models.anthropic import AnthropicModel

    pytest.mark.anyio,
]

+ # Type variable for generic AsyncStream
+ T = TypeVar('T')
+

def test_init():
    m = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar')
@@ -53,6 +67,7 @@ def test_init():
@dataclass
class MockAnthropic:
    messages_: AnthropicMessage | list[AnthropicMessage] | None = None
+     stream: list[RawMessageStreamEvent] | list[list[RawMessageStreamEvent]] | None = None
    index = 0
    chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list)

@@ -64,14 +79,31 @@ def messages(self) -> Any:
    def create_mock(cls, messages_: AnthropicMessage | list[AnthropicMessage]) -> AsyncAnthropic:
        return cast(AsyncAnthropic, cls(messages_=messages_))

-     async def messages_create(self, *_args: Any, **kwargs: Any) -> AnthropicMessage:
+     @classmethod
+     def create_stream_mock(
+         cls, stream: list[RawMessageStreamEvent] | list[list[RawMessageStreamEvent]]
+     ) -> AsyncAnthropic:
+         return cast(AsyncAnthropic, cls(stream=stream))
+
+     async def messages_create(
+         self, *_args: Any, stream: bool = False, **kwargs: Any
+     ) -> AnthropicMessage | MockAsyncStream[RawMessageStreamEvent]:
        self.chat_completion_kwargs.append({k: v for k, v in kwargs.items() if v is not NOT_GIVEN})

-         assert self.messages_ is not None, '`messages` must be provided'
-         if isinstance(self.messages_, list):
-             response = self.messages_[self.index]
+         if stream:
+             assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided'
+             # noinspection PyUnresolvedReferences
+             if isinstance(self.stream[0], list):
+                 indexed_stream = cast(list[RawMessageStreamEvent], self.stream[self.index])
+                 response = MockAsyncStream(iter(indexed_stream))
+             else:
+                 response = MockAsyncStream(iter(cast(list[RawMessageStreamEvent], self.stream)))
        else:
-             response = self.messages_
+             assert self.messages_ is not None, '`messages` must be provided'
+             if isinstance(self.messages_, list):
+                 response = self.messages_[self.index]
+             else:
+                 response = self.messages_
        self.index += 1
        return response

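
The MockAsyncStream helper imported at the top of this diff comes from a sibling module (.mock_async_stream) that is not shown here. The following is a minimal, hypothetical sketch of what such a helper could look like, assuming it only needs to wrap a pre-built synchronous iterator of stream events in the async-iterator protocol so code expecting an async stream can consume it with `async for`; the real implementation may differ:

    # Hypothetical sketch only: the actual MockAsyncStream in .mock_async_stream may differ.
    from collections.abc import Iterator
    from dataclasses import dataclass
    from typing import Generic, TypeVar

    T = TypeVar('T')


    @dataclass
    class MockAsyncStream(Generic[T]):
        # Wraps a plain iterator of pre-built events so code that expects an async
        # stream (e.g. `async for event in response`) can iterate over it.
        _iter: Iterator[T]

        def __aiter__(self) -> 'MockAsyncStream[T]':
            return self

        async def __anext__(self) -> T:
            # Convert iterator exhaustion into StopAsyncIteration so `async for`
            # terminates cleanly at the end of the mocked stream.
            try:
                return next(self._iter)
            except StopIteration as e:
                raise StopAsyncIteration() from e
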
@@ -298,3 +330,112 @@ async def get_location(loc_name: str) -> str:
    assert get_mock_chat_completion_kwargs(mock_client)[0]['tool_choice']['disable_parallel_tool_use'] == (
        not parallel_tool_calls
    )
+
+
+ async def test_stream_structured(allow_model_requests: None):
+     """Test streaming structured responses with Anthropic's API.
+
+     This test simulates how Anthropic streams tool calls:
+     1. Message start
+     2. Tool block start with initial data
+     3. Tool block delta with additional data
+     4. Tool block stop
+     5. Update usage
+     6. Message stop
+     """
+     stream: list[RawMessageStreamEvent] = [
+         RawMessageStartEvent(
+             type='message_start',
+             message=AnthropicMessage(
+                 id='msg_123',
+                 model='claude-3-5-haiku-latest',
+                 role='assistant',
+                 type='message',
+                 content=[],
+                 stop_reason=None,
+                 usage=AnthropicUsage(input_tokens=20, output_tokens=0),
+             ),
+         ),
+         # Start tool block with initial data
+         RawContentBlockStartEvent(
+             type='content_block_start',
+             index=0,
+             content_block=ToolUseBlock(type='tool_use', id='tool_1', name='my_tool', input={'first': 'One'}),
+         ),
+         # Add more data through an incomplete JSON delta
+         RawContentBlockDeltaEvent(
+             type='content_block_delta',
+             index=0,
+             delta=InputJSONDelta(type='input_json_delta', partial_json='{"second":'),
+         ),
+         RawContentBlockDeltaEvent(
+             type='content_block_delta',
+             index=0,
+             delta=InputJSONDelta(type='input_json_delta', partial_json='"Two"}'),
+         ),
+         # Mark tool block as complete
+         RawContentBlockStopEvent(type='content_block_stop', index=0),
+         # Update the top-level message with usage
+         RawMessageDeltaEvent(
+             type='message_delta',
+             delta=Delta(
+                 stop_reason='end_turn',
+             ),
+             usage=MessageDeltaUsage(
+                 output_tokens=5,
+             ),
+         ),
+         # Mark message as complete
+         RawMessageStopEvent(type='message_stop'),
+     ]
+
+     done_stream: list[RawMessageStreamEvent] = [
+         RawMessageStartEvent(
+             type='message_start',
+             message=AnthropicMessage(
+                 id='msg_123',
+                 model='claude-3-5-haiku-latest',
+                 role='assistant',
+                 type='message',
+                 content=[],
+                 stop_reason=None,
+                 usage=AnthropicUsage(input_tokens=0, output_tokens=0),
+             ),
+         ),
+         # Text block with final data
+         RawContentBlockStartEvent(
+             type='content_block_start',
+             index=0,
+             content_block=TextBlock(type='text', text='FINAL_PAYLOAD'),
+         ),
+         RawContentBlockStopEvent(type='content_block_stop', index=0),
+         RawMessageStopEvent(type='message_stop'),
+     ]
+
+     mock_client = MockAnthropic.create_stream_mock([stream, done_stream])
+     m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
+     agent = Agent(m)
+
+     tool_called = False
+
+     @agent.tool_plain
+     async def my_tool(first: str, second: str) -> int:
+         nonlocal tool_called
+         tool_called = True
+         return len(first) + len(second)
+
+     async with agent.run_stream('') as result:
+         assert not result.is_complete
+         chunks = [c async for c in result.stream(debounce_by=None)]
+
+         # The tool output doesn't echo any content to the stream, so we only get the final payload once when
+         # the block starts and once when it ends.
+         assert chunks == snapshot(
+             [
+                 'FINAL_PAYLOAD',
+                 'FINAL_PAYLOAD',
+             ]
+         )
+         assert result.is_complete
+         assert result.usage() == snapshot(Usage(requests=2, request_tokens=20, response_tokens=5, total_tokens=25))
+         assert tool_called