33 changes: 32 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -74,6 +74,7 @@
BetaCodeExecutionToolResultBlockContent,
BetaCodeExecutionToolResultBlockParam,
BetaCodeExecutionToolResultBlockParamContentParam,
BetaContainerParams,
BetaContentBlock,
BetaContentBlockParam,
BetaImageBlockParam,
@@ -200,6 +201,16 @@ class AnthropicModelSettings(ModelSettings, total=False):
See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching for more information.
"""

anthropic_container: BetaContainerParams | Literal[False]
"""Container configuration for multi-turn conversations.

By default, if previous messages contain a container_id (from a prior response),
it will be reused automatically.

Set to `False` to force a fresh container (ignore any container_id from history).
Set to a dict (e.g. `{'id': 'container_xxx'}`) to explicitly specify a container.
"""


@dataclass(init=False)
class AnthropicModel(Model):
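To make the new setting concrete for reviewers, here is a minimal usage sketch (the container id is a placeholder; the model name is taken from the tests below):

```python
from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings

agent = Agent(AnthropicModel('claude-sonnet-4-20250514'))

# Pin the request to a specific, already-allocated container.
pinned = agent.run_sync(
    'Read data.csv again',
    model_settings=AnthropicModelSettings(anthropic_container={'id': 'container_abc123'}),
)

# Force a fresh container even though the history now carries a container_id.
fresh = agent.run_sync(
    'Start over in a clean sandbox',
    message_history=pinned.all_messages(),
    model_settings=AnthropicModelSettings(anthropic_container=False),
)
```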
@@ -385,6 +396,7 @@ async def _messages_create(
output_format = self._native_output_format(model_request_parameters)
betas, extra_headers = self._get_betas_and_extra_headers(tools, model_request_parameters, model_settings)
betas.update(builtin_tool_betas)
container = self._get_container(messages, model_settings)
try:
return await self.client.beta.messages.create(
max_tokens=model_settings.get('max_tokens', 4096),
@@ -403,6 +415,7 @@
top_p=model_settings.get('top_p', OMIT),
timeout=model_settings.get('timeout', NOT_GIVEN),
metadata=model_settings.get('anthropic_metadata', OMIT),
container=container or OMIT,
extra_headers=extra_headers,
extra_body=model_settings.get('extra_body'),
)
@@ -439,6 +452,18 @@ def _get_betas_and_extra_headers(

return betas, extra_headers

def _get_container(
self, messages: list[ModelMessage], model_settings: AnthropicModelSettings
) -> BetaContainerParams | None:
"""Get container config for the API request."""
if (container := model_settings.get('anthropic_container')) is not None:
return None if container is False else container
for m in reversed(messages):
if isinstance(m, ModelResponse) and m.provider_details:
if cid := m.provider_details.get('container_id'):
[Collaborator review comment] We should verify that the provider_name matches self.system as well

return BetaContainerParams(id=cid)
return None

async def _messages_count_tokens(
self,
messages: list[ModelMessage],
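The precedence `_get_container` implements — an explicit setting (dict or `False`) always wins, otherwise the newest `container_id` found in the history is reused — can be sketched as a standalone function (`resolve_container` is hypothetical, for discussion only):

```python
from typing import Literal


def resolve_container(
    setting: dict | Literal[False] | None,
    history_container_id: str | None,
) -> dict | None:
    """Simplified mirror of _get_container's precedence."""
    if setting is not None:
        # An explicit setting always wins; False means "no container".
        return None if setting is False else setting
    if history_container_id:
        # Otherwise reuse the most recent container id from prior responses.
        return {'id': history_container_id}
    # No setting, no history: let the API allocate a fresh container.
    return None


assert resolve_container({'id': 'c1'}, 'c2') == {'id': 'c1'}
assert resolve_container(False, 'c2') is None
assert resolve_container(None, 'c2') == {'id': 'c2'}
assert resolve_container(None, None) is None
```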
@@ -526,6 +551,9 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
if raw_finish_reason := response.stop_reason: # pragma: no branch
provider_details = {'finish_reason': raw_finish_reason}
finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
if response.container:
provider_details = provider_details or {}
provider_details['container_id'] = response.container.id

return ModelResponse(
parts=items,
@@ -1125,6 +1153,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
if isinstance(event, BetaRawMessageStartEvent):
self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name)
self.provider_response_id = event.message.id
if event.message.container:
self.provider_details = self.provider_details or {}
self.provider_details['container_id'] = event.message.container.id

elif isinstance(event, BetaRawContentBlockStartEvent):
current_block = event.content_block
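On the streaming path the container id arrives with the `message_start` event. From user code, the captured value would be read roughly like this (a sketch; model name illustrative):

```python
import asyncio

from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel


async def main():
    agent = Agent(AnthropicModel('claude-sonnet-4-20250514'))
    async with agent.run_stream('hello') as result:
        async for text in result.stream_text():
            print(text)
        # Once the stream is exhausted, the final ModelResponse should carry the id.
        print(result.all_messages()[-1].provider_details)


asyncio.run(main())
```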
@@ -1239,7 +1270,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
elif isinstance(event, BetaRawMessageDeltaEvent):
self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name, self._usage)
if raw_finish_reason := event.delta.stop_reason: # pragma: no branch
self.provider_details = {'finish_reason': raw_finish_reason}
self.provider_details = {**(self.provider_details or {}), 'finish_reason': raw_finish_reason}
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)

elif isinstance(event, BetaRawContentBlockStopEvent): # pragma: no branch
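Taken together, the intended multi-turn flow with the code execution tool would look roughly like this (a sketch, assuming the builtin tool is enabled so file state persists in the reused container):

```python
from pydantic_ai import Agent
from pydantic_ai.builtin_tools import CodeExecutionTool
from pydantic_ai.models.anthropic import AnthropicModel

agent = Agent(
    AnthropicModel('claude-sonnet-4-20250514'),
    builtin_tools=[CodeExecutionTool()],
)

# Turn 1 allocates a container; its id lands in provider_details.
first = agent.run_sync('Write 10 random rows to data.csv')
print(first.all_messages()[-1].provider_details)

# Turn 2 reuses the same container by default, so data.csv still exists.
second = agent.run_sync('Summarize data.csv', message_history=first.all_messages())
```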
133 changes: 128 additions & 5 deletions tests/models/test_anthropic.py
@@ -89,11 +89,14 @@
BetaRawMessageStreamEvent,
BetaServerToolUseBlock,
BetaTextBlock,
BetaTextDelta,
BetaToolUseBlock,
BetaUsage,
BetaWebSearchResultBlock,
BetaWebSearchToolResultBlock,
)
from anthropic.types.beta.beta_container import BetaContainer
from anthropic.types.beta.beta_container_params import BetaContainerParams
from anthropic.types.beta.beta_raw_message_delta_event import Delta

from pydantic_ai.models.anthropic import (
@@ -169,9 +172,7 @@ async def messages_create(
if isinstance(self.stream[0], Sequence):
response = MockAsyncStream(iter(cast(list[MockRawMessageStreamEvent], self.stream[self.index])))
else:
response = MockAsyncStream( # pragma: no cover
iter(cast(list[MockRawMessageStreamEvent], self.stream))
)
response = MockAsyncStream(iter(cast(list[MockRawMessageStreamEvent], self.stream)))
else:
assert self.messages_ is not None, '`messages` must be provided'
if isinstance(self.messages_, Sequence):
@@ -5512,7 +5513,7 @@ async def test_anthropic_code_execution_tool(allow_model_requests: None, anthrop
model_name='claude-sonnet-4-20250514',
timestamp=IsDatetime(),
provider_name='anthropic',
provider_details={'finish_reason': 'end_turn'},
provider_details={'finish_reason': 'end_turn', 'container_id': 'container_011CTCwceSoRxi8Pf16Fb7Tn'},
provider_response_id='msg_018bVTPr9khzuds31rFDuqW4',
finish_reason='stop',
run_id=IsStr(),
@@ -5579,7 +5580,7 @@ async def test_anthropic_code_execution_tool(allow_model_requests: None, anthrop
model_name='claude-sonnet-4-20250514',
timestamp=IsDatetime(),
provider_name='anthropic',
provider_details={'finish_reason': 'end_turn'},
provider_details={'finish_reason': 'end_turn', 'container_id': 'container_011CTCwdXe48NC7LaX3rxQ4d'},
provider_response_id='msg_01VngRFBcNddwrYQoKUmdePY',
finish_reason='stop',
run_id=IsStr(),
@@ -7858,3 +7859,125 @@ async def test_anthropic_cache_messages_real_api(allow_model_requests: None, ant
assert usage2.cache_read_tokens > 0
assert usage2.cache_write_tokens > 0
assert usage2.output_tokens > 0


async def test_anthropic_container_setting_explicit(allow_model_requests: None):
"""Test that anthropic_container setting passes explicit container config to API."""
c = completion_message([BetaTextBlock(text='world', type='text')], BetaUsage(input_tokens=5, output_tokens=10))
mock_client = MockAnthropic.create_mock(c)
m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client))
agent = Agent(m)

# Test with explicit container config
await agent.run('hello', model_settings=AnthropicModelSettings(anthropic_container={'id': 'container_abc123'}))

completion_kwargs = get_mock_chat_completion_kwargs(mock_client)[0]
assert completion_kwargs['container'] == BetaContainerParams(id='container_abc123')


async def test_anthropic_container_from_message_history(allow_model_requests: None):
"""Test that container_id from message history is passed to subsequent requests."""
c = completion_message([BetaTextBlock(text='world', type='text')], BetaUsage(input_tokens=5, output_tokens=10))
mock_client = MockAnthropic.create_mock([c, c])
m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client))
agent = Agent(m)

# Create a message history with a container_id in provider_details
history: list[ModelMessage] = [
ModelRequest(parts=[UserPromptPart(content='hello')]),
ModelResponse(
parts=[TextPart(content='world')],
provider_details={'container_id': 'container_from_history'},
),
]

# Run with the message history
await agent.run('follow up', message_history=history)

completion_kwargs = get_mock_chat_completion_kwargs(mock_client)[0]
assert completion_kwargs['container'] == BetaContainerParams(id='container_from_history')


async def test_anthropic_container_setting_false_ignores_history(allow_model_requests: None):
"""Test that anthropic_container=False ignores container_id from history."""
c = completion_message([BetaTextBlock(text='world', type='text')], BetaUsage(input_tokens=5, output_tokens=10))
mock_client = MockAnthropic.create_mock(c)
m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client))
agent = Agent(m)

# Create a message history with a container_id
history: list[ModelMessage] = [
ModelRequest(parts=[UserPromptPart(content='hello')]),
ModelResponse(
parts=[TextPart(content='world')],
provider_details={'container_id': 'container_should_be_ignored'},
),
]

# Run with anthropic_container=False to force fresh container
await agent.run(
'follow up', message_history=history, model_settings=AnthropicModelSettings(anthropic_container=False)
)

completion_kwargs = get_mock_chat_completion_kwargs(mock_client)[0]
# When anthropic_container=False, container should be OMIT (filtered out before sending to API)
from anthropic import omit as OMIT

assert completion_kwargs.get('container') is OMIT


async def test_anthropic_container_id_from_stream_response(allow_model_requests: None):
"""Test that container_id is extracted from streamed response and stored in provider_details."""
from datetime import datetime

stream_events: list[BetaRawMessageStreamEvent] = [
BetaRawMessageStartEvent(
type='message_start',
message=BetaMessage(
id='msg_123',
content=[],
model='claude-3-5-haiku-123',
role='assistant',
stop_reason=None,
type='message',
usage=BetaUsage(input_tokens=5, output_tokens=0),
container=BetaContainer(
id='container_from_stream',
expires_at=datetime(2025, 1, 1, 0, 0, 0),
),
),
),
BetaRawContentBlockStartEvent(
type='content_block_start',
index=0,
content_block=BetaTextBlock(text='', type='text'),
),
BetaRawContentBlockDeltaEvent(
type='content_block_delta',
index=0,
delta=BetaTextDelta(type='text_delta', text='hello'),
),
BetaRawContentBlockStopEvent(type='content_block_stop', index=0),
BetaRawMessageDeltaEvent(
type='message_delta',
delta=Delta(stop_reason='end_turn', stop_sequence=None),
usage=BetaMessageDeltaUsage(output_tokens=5),
),
BetaRawMessageStopEvent(type='message_stop'),
]

mock_client = MockAnthropic.create_stream_mock(stream_events)
m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client))
agent = Agent(m)

async with agent.run_stream('hello') as result:
response = await result.get_output()
assert response == 'hello'

# Check that container_id was captured in the response
messages = result.all_messages()
model_response = messages[-1]
assert isinstance(model_response, ModelResponse)
assert model_response.provider_details is not None
assert model_response.provider_details.get('container_id') == 'container_from_stream'
assert model_response.provider_details.get('finish_reason') == 'end_turn'