diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
index b23da276e2..532b7d7dcd 100644
--- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py
+++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -74,6 +74,7 @@
     BetaCodeExecutionToolResultBlockContent,
     BetaCodeExecutionToolResultBlockParam,
     BetaCodeExecutionToolResultBlockParamContentParam,
+    BetaContainerParams,
     BetaContentBlock,
     BetaContentBlockParam,
     BetaImageBlockParam,
@@ -200,6 +201,16 @@ class AnthropicModelSettings(ModelSettings, total=False):
     See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching for more information.
     """
 
+    anthropic_container: BetaContainerParams | Literal[False]
+    """Container configuration for multi-turn conversations.
+
+    By default, if previous messages contain a container_id (from a prior response),
+    it will be reused automatically.
+
+    Set to `False` to force a fresh container (ignore any container_id from history).
+    Set to a dict (e.g. `{'id': 'container_xxx'}`) to explicitly specify a container.
+    """
+
 
 @dataclass(init=False)
 class AnthropicModel(Model):
@@ -385,6 +396,7 @@ async def _messages_create(
         output_format = self._native_output_format(model_request_parameters)
         betas, extra_headers = self._get_betas_and_extra_headers(tools, model_request_parameters, model_settings)
         betas.update(builtin_tool_betas)
+        container = self._get_container(messages, model_settings)
         try:
             return await self.client.beta.messages.create(
                 max_tokens=model_settings.get('max_tokens', 4096),
@@ -403,6 +415,7 @@
                 top_p=model_settings.get('top_p', OMIT),
                 timeout=model_settings.get('timeout', NOT_GIVEN),
                 metadata=model_settings.get('anthropic_metadata', OMIT),
+                container=container or OMIT,
                 extra_headers=extra_headers,
                 extra_body=model_settings.get('extra_body'),
             )
@@ -439,6 +452,18 @@
         return betas, extra_headers
 
+    def _get_container(
+        self, messages: list[ModelMessage], model_settings: AnthropicModelSettings
+    ) -> BetaContainerParams | None:
+        """Get container config for the API request."""
+        if (container := model_settings.get('anthropic_container')) is not None:
+            return None if container is False else container
+        for m in reversed(messages):
+            if isinstance(m, ModelResponse) and m.provider_name == self.system and m.provider_details:
+                if cid := m.provider_details.get('container_id'):
+                    return BetaContainerParams(id=cid)
+        return None
+
     async def _messages_count_tokens(
         self,
         messages: list[ModelMessage],
@@ -526,6 +551,9 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
         if raw_finish_reason := response.stop_reason:  # pragma: no branch
             provider_details = {'finish_reason': raw_finish_reason}
             finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+        if response.container:
+            provider_details = provider_details or {}
+            provider_details['container_id'] = response.container.id
 
         return ModelResponse(
             parts=items,
@@ -1125,6 +1153,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
             if isinstance(event, BetaRawMessageStartEvent):
                 self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name)
                 self.provider_response_id = event.message.id
+                if event.message.container:
+                    self.provider_details = self.provider_details or {}
+                    self.provider_details['container_id'] = event.message.container.id
 
             elif isinstance(event, BetaRawContentBlockStartEvent):
                 current_block = event.content_block
@@ -1239,7 +1270,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
             elif isinstance(event, BetaRawMessageDeltaEvent):
                 self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name, self._usage)
                 if raw_finish_reason := event.delta.stop_reason:  # pragma: no branch
-                    self.provider_details = {'finish_reason': raw_finish_reason}
+                    self.provider_details = {**(self.provider_details or {}), 'finish_reason': raw_finish_reason}
                     self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
 
             elif isinstance(event, BetaRawContentBlockStopEvent):  # pragma: no branch
diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py
index ad65735d38..2212587ff8 100644
--- a/tests/models/test_anthropic.py
+++ b/tests/models/test_anthropic.py
@@ -89,11 +89,14 @@
         BetaRawMessageStreamEvent,
         BetaServerToolUseBlock,
         BetaTextBlock,
+        BetaTextDelta,
         BetaToolUseBlock,
         BetaUsage,
         BetaWebSearchResultBlock,
         BetaWebSearchToolResultBlock,
     )
+    from anthropic.types.beta.beta_container import BetaContainer
+    from anthropic.types.beta.beta_container_params import BetaContainerParams
     from anthropic.types.beta.beta_raw_message_delta_event import Delta
 
     from pydantic_ai.models.anthropic import (
@@ -169,9 +172,7 @@ async def messages_create(
             if isinstance(self.stream[0], Sequence):
                 response = MockAsyncStream(iter(cast(list[MockRawMessageStreamEvent], self.stream[self.index])))
             else:
-                response = MockAsyncStream(  # pragma: no cover
-                    iter(cast(list[MockRawMessageStreamEvent], self.stream))
-                )
+                response = MockAsyncStream(iter(cast(list[MockRawMessageStreamEvent], self.stream)))
         else:
             assert self.messages_ is not None, '`messages` must be provided'
             if isinstance(self.messages_, Sequence):
@@ -5512,7 +5513,7 @@ async def test_anthropic_code_execution_tool(allow_model_requests: None, anthrop
                 model_name='claude-sonnet-4-20250514',
                 timestamp=IsDatetime(),
                 provider_name='anthropic',
-                provider_details={'finish_reason': 'end_turn'},
+                provider_details={'finish_reason': 'end_turn', 'container_id': 'container_011CTCwceSoRxi8Pf16Fb7Tn'},
                 provider_response_id='msg_018bVTPr9khzuds31rFDuqW4',
                 finish_reason='stop',
                 run_id=IsStr(),
@@ -5579,7 +5580,7 @@ async def test_anthropic_code_execution_tool(allow_model_requests: None, anthrop
                 model_name='claude-sonnet-4-20250514',
                 timestamp=IsDatetime(),
                 provider_name='anthropic',
-                provider_details={'finish_reason': 'end_turn'},
+                provider_details={'finish_reason': 'end_turn', 'container_id': 'container_011CTCwdXe48NC7LaX3rxQ4d'},
                 provider_response_id='msg_01VngRFBcNddwrYQoKUmdePY',
                 finish_reason='stop',
                 run_id=IsStr(),
@@ -7858,3 +7859,127 @@ async def test_anthropic_cache_messages_real_api(allow_model_requests: None, ant
     assert usage2.cache_read_tokens > 0
     assert usage2.cache_write_tokens > 0
     assert usage2.output_tokens > 0
+
+
+async def test_anthropic_container_setting_explicit(allow_model_requests: None):
+    """Test that anthropic_container setting passes explicit container config to API."""
+    c = completion_message([BetaTextBlock(text='world', type='text')], BetaUsage(input_tokens=5, output_tokens=10))
+    mock_client = MockAnthropic.create_mock(c)
+    m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client))
+    agent = Agent(m)
+
+    # Test with explicit container config
+    await agent.run('hello', model_settings=AnthropicModelSettings(anthropic_container={'id': 'container_abc123'}))
+
+    completion_kwargs = get_mock_chat_completion_kwargs(mock_client)[0]
+    assert completion_kwargs['container'] == BetaContainerParams(id='container_abc123')
+
+
+async def test_anthropic_container_from_message_history(allow_model_requests: None):
+    """Test that container_id from message history is passed to subsequent requests."""
+    c = completion_message([BetaTextBlock(text='world', type='text')], BetaUsage(input_tokens=5, output_tokens=10))
+    mock_client = MockAnthropic.create_mock([c, c])
+    m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client))
+    agent = Agent(m)
+
+    # Create a message history with a container_id in provider_details
+    history: list[ModelMessage] = [
+        ModelRequest(parts=[UserPromptPart(content='hello')]),
+        ModelResponse(
+            parts=[TextPart(content='world')],
+            provider_name='anthropic',
+            provider_details={'container_id': 'container_from_history'},
+        ),
+    ]
+
+    # Run with the message history
+    await agent.run('follow up', message_history=history)
+
+    completion_kwargs = get_mock_chat_completion_kwargs(mock_client)[0]
+    assert completion_kwargs['container'] == BetaContainerParams(id='container_from_history')
+
+
+async def test_anthropic_container_setting_false_ignores_history(allow_model_requests: None):
+    """Test that anthropic_container=False ignores container_id from history."""
+    c = completion_message([BetaTextBlock(text='world', type='text')], BetaUsage(input_tokens=5, output_tokens=10))
+    mock_client = MockAnthropic.create_mock(c)
+    m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client))
+    agent = Agent(m)
+
+    # Create a message history with a container_id
+    history: list[ModelMessage] = [
+        ModelRequest(parts=[UserPromptPart(content='hello')]),
+        ModelResponse(
+            parts=[TextPart(content='world')],
+            provider_name='anthropic',
+            provider_details={'container_id': 'container_should_be_ignored'},
+        ),
+    ]
+
+    # Run with anthropic_container=False to force fresh container
+    await agent.run(
+        'follow up', message_history=history, model_settings=AnthropicModelSettings(anthropic_container=False)
+    )
+
+    completion_kwargs = get_mock_chat_completion_kwargs(mock_client)[0]
+    # When anthropic_container=False, container should be OMIT (filtered out before sending to API)
+    from anthropic import omit as OMIT
+
+    assert completion_kwargs.get('container') is OMIT
+
+
+async def test_anthropic_container_id_from_stream_response(allow_model_requests: None):
+    """Test that container_id is extracted from streamed response and stored in provider_details."""
+    from datetime import datetime
+
+    stream_events: list[BetaRawMessageStreamEvent] = [
+        BetaRawMessageStartEvent(
+            type='message_start',
+            message=BetaMessage(
+                id='msg_123',
+                content=[],
+                model='claude-3-5-haiku-123',
+                role='assistant',
+                stop_reason=None,
+                type='message',
+                usage=BetaUsage(input_tokens=5, output_tokens=0),
+                container=BetaContainer(
+                    id='container_from_stream',
+                    expires_at=datetime(2025, 1, 1, 0, 0, 0),
+                ),
+            ),
+        ),
+        BetaRawContentBlockStartEvent(
+            type='content_block_start',
+            index=0,
+            content_block=BetaTextBlock(text='', type='text'),
+        ),
+        BetaRawContentBlockDeltaEvent(
+            type='content_block_delta',
+            index=0,
+            delta=BetaTextDelta(type='text_delta', text='hello'),
+        ),
+        BetaRawContentBlockStopEvent(type='content_block_stop', index=0),
+        BetaRawMessageDeltaEvent(
+            type='message_delta',
+            delta=Delta(stop_reason='end_turn', stop_sequence=None),
+            usage=BetaMessageDeltaUsage(output_tokens=5),
+        ),
+        BetaRawMessageStopEvent(type='message_stop'),
+    ]
+
+    mock_client = MockAnthropic.create_stream_mock(stream_events)
+    m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client))
+    agent = Agent(m)
+
+    async with agent.run_stream('hello') as result:
+        response = await result.get_output()
+        assert response == 'hello'
+
+    # Check that container_id was captured in the response
+    messages = result.all_messages()
+    model_response = messages[-1]
+    assert isinstance(model_response, ModelResponse)
+    assert model_response.provider_details is not None
+    assert model_response.provider_details.get('container_id') == 'container_from_stream'
+    assert model_response.provider_details.get('finish_reason') == 'end_turn'
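Usage note (not part of the patch): a minimal sketch of how the new anthropic_container setting would be exercised from application code. The model name and the CodeExecutionTool builtin are illustrative assumptions, as is the need for ANTHROPIC_API_KEY in the environment; only the setting itself and the automatic container reuse come from this diff.

    from pydantic_ai import Agent, CodeExecutionTool
    from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings

    # Anthropic's code execution runs in a server-side container; with this
    # patch the container id is recorded in provider_details['container_id'].
    agent = Agent(
        AnthropicModel('claude-sonnet-4-20250514'),  # assumed model name
        builtin_tools=[CodeExecutionTool()],
    )

    result1 = agent.run_sync('Use Python to compute 2**64.')

    # Default behaviour: passing the history back reuses the recorded container,
    # so state from the first turn (variables, files) is still available.
    result2 = agent.run_sync('Now halve that.', message_history=result1.all_messages())

    # Opt out: anthropic_container=False ignores the container_id in the
    # history and forces a fresh container.
    result3 = agent.run_sync(
        'Start over in a clean sandbox.',
        message_history=result2.all_messages(),
        model_settings=AnthropicModelSettings(anthropic_container=False),
    )

    # Pin explicitly by passing a container mapping, e.g.
    # AnthropicModelSettings(anthropic_container={'id': 'container_xxx'}).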