Skip to content

Commit b713ad3

Browse files
committed
pass container id back
1 parent 3b6dc5e commit b713ad3

File tree

2 files changed

+160
-6
lines changed

2 files changed

+160
-6
lines changed

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
BetaCodeExecutionToolResultBlockContent,
7575
BetaCodeExecutionToolResultBlockParam,
7676
BetaCodeExecutionToolResultBlockParamContentParam,
77+
BetaContainerParams,
7778
BetaContentBlock,
7879
BetaContentBlockParam,
7980
BetaImageBlockParam,
@@ -200,6 +201,16 @@ class AnthropicModelSettings(ModelSettings, total=False):
200201
See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching for more information.
201202
"""
202203

204+
anthropic_container: BetaContainerParams | Literal[False]
205+
"""Container configuration for multi-turn conversations.
206+
207+
By default, if previous messages contain a container_id (from a prior response),
208+
it will be reused automatically.
209+
210+
Set to `False` to force a fresh container (ignore any container_id from history).
211+
Set to a dict (e.g. `{'id': 'container_xxx'}`) to explicitly specify a container.
212+
"""
213+
203214

204215
@dataclass(init=False)
205216
class AnthropicModel(Model):
@@ -385,6 +396,7 @@ async def _messages_create(
385396
output_format = self._native_output_format(model_request_parameters)
386397
betas, extra_headers = self._get_betas_and_extra_headers(tools, model_request_parameters, model_settings)
387398
betas.update(builtin_tool_betas)
399+
container = self._get_container(messages, model_settings)
388400
try:
389401
return await self.client.beta.messages.create(
390402
max_tokens=model_settings.get('max_tokens', 4096),
@@ -403,6 +415,7 @@ async def _messages_create(
403415
top_p=model_settings.get('top_p', OMIT),
404416
timeout=model_settings.get('timeout', NOT_GIVEN),
405417
metadata=model_settings.get('anthropic_metadata', OMIT),
418+
container=container or OMIT,
406419
extra_headers=extra_headers,
407420
extra_body=model_settings.get('extra_body'),
408421
)
@@ -439,6 +452,18 @@ def _get_betas_and_extra_headers(
439452

440453
return betas, extra_headers
441454

455+
def _get_container(
    self, messages: list[ModelMessage], model_settings: AnthropicModelSettings
) -> BetaContainerParams | None:
    """Resolve the container configuration to send with the API request.

    Precedence: an explicit `anthropic_container` setting wins, with `False`
    meaning "send no container at all". Otherwise the most recent response in
    `messages` that recorded a `container_id` in its provider details is
    reused. If neither applies, `None` is returned and no container is sent.
    """
    explicit = model_settings.get('anthropic_container')
    if explicit is not None:
        # `False` is the sentinel for "force a fresh container".
        if explicit is False:
            return None
        return explicit
    # Walk the history newest-first and reuse the latest known container id.
    for message in reversed(messages):
        if not isinstance(message, ModelResponse):
            continue
        details = message.provider_details
        if details and (container_id := details.get('container_id')):
            return BetaContainerParams(id=container_id)
    return None
466+
442467
async def _messages_count_tokens(
443468
self,
444469
messages: list[ModelMessage],
@@ -526,6 +551,9 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
526551
if raw_finish_reason := response.stop_reason: # pragma: no branch
527552
provider_details = {'finish_reason': raw_finish_reason}
528553
finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
554+
if response.container:
555+
provider_details = provider_details or {}
556+
provider_details['container_id'] = response.container.id
529557

530558
return ModelResponse(
531559
parts=items,
@@ -1125,6 +1153,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
11251153
if isinstance(event, BetaRawMessageStartEvent):
11261154
self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name)
11271155
self.provider_response_id = event.message.id
1156+
if event.message.container:
1157+
self.provider_details = self.provider_details or {}
1158+
self.provider_details['container_id'] = event.message.container.id
11281159

11291160
elif isinstance(event, BetaRawContentBlockStartEvent):
11301161
current_block = event.content_block
@@ -1239,7 +1270,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
12391270
elif isinstance(event, BetaRawMessageDeltaEvent):
12401271
self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name, self._usage)
12411272
if raw_finish_reason := event.delta.stop_reason: # pragma: no branch
1242-
self.provider_details = {'finish_reason': raw_finish_reason}
1273+
self.provider_details = {**(self.provider_details or {}), 'finish_reason': raw_finish_reason}
12431274
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
12441275

12451276
elif isinstance(event, BetaRawContentBlockStopEvent): # pragma: no branch

tests/models/test_anthropic.py

Lines changed: 128 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,14 @@
8989
BetaRawMessageStreamEvent,
9090
BetaServerToolUseBlock,
9191
BetaTextBlock,
92+
BetaTextDelta,
9293
BetaToolUseBlock,
9394
BetaUsage,
9495
BetaWebSearchResultBlock,
9596
BetaWebSearchToolResultBlock,
9697
)
98+
from anthropic.types.beta.beta_container import BetaContainer
99+
from anthropic.types.beta.beta_container_params import BetaContainerParams
97100
from anthropic.types.beta.beta_raw_message_delta_event import Delta
98101

99102
from pydantic_ai.models.anthropic import (
@@ -169,9 +172,7 @@ async def messages_create(
169172
if isinstance(self.stream[0], Sequence):
170173
response = MockAsyncStream(iter(cast(list[MockRawMessageStreamEvent], self.stream[self.index])))
171174
else:
172-
response = MockAsyncStream( # pragma: no cover
173-
iter(cast(list[MockRawMessageStreamEvent], self.stream))
174-
)
175+
response = MockAsyncStream(iter(cast(list[MockRawMessageStreamEvent], self.stream)))
175176
else:
176177
assert self.messages_ is not None, '`messages` must be provided'
177178
if isinstance(self.messages_, Sequence):
@@ -5512,7 +5513,7 @@ async def test_anthropic_code_execution_tool(allow_model_requests: None, anthrop
55125513
model_name='claude-sonnet-4-20250514',
55135514
timestamp=IsDatetime(),
55145515
provider_name='anthropic',
5515-
provider_details={'finish_reason': 'end_turn'},
5516+
provider_details={'finish_reason': 'end_turn', 'container_id': 'container_011CTCwceSoRxi8Pf16Fb7Tn'},
55165517
provider_response_id='msg_018bVTPr9khzuds31rFDuqW4',
55175518
finish_reason='stop',
55185519
run_id=IsStr(),
@@ -5579,7 +5580,7 @@ async def test_anthropic_code_execution_tool(allow_model_requests: None, anthrop
55795580
model_name='claude-sonnet-4-20250514',
55805581
timestamp=IsDatetime(),
55815582
provider_name='anthropic',
5582-
provider_details={'finish_reason': 'end_turn'},
5583+
provider_details={'finish_reason': 'end_turn', 'container_id': 'container_011CTCwdXe48NC7LaX3rxQ4d'},
55835584
provider_response_id='msg_01VngRFBcNddwrYQoKUmdePY',
55845585
finish_reason='stop',
55855586
run_id=IsStr(),
@@ -7858,3 +7859,125 @@ async def test_anthropic_cache_messages_real_api(allow_model_requests: None, ant
78587859
assert usage2.cache_read_tokens > 0
78597860
assert usage2.cache_write_tokens > 0
78607861
assert usage2.output_tokens > 0
7862+
7863+
7864+
async def test_anthropic_container_setting_explicit(allow_model_requests: None):
    """Test that anthropic_container setting passes explicit container config to API."""
    response = completion_message(
        [BetaTextBlock(text='world', type='text')], BetaUsage(input_tokens=5, output_tokens=10)
    )
    client = MockAnthropic.create_mock(response)
    model = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=client))
    agent = Agent(model)

    # An explicit container dict in the settings should be forwarded verbatim.
    settings = AnthropicModelSettings(anthropic_container={'id': 'container_abc123'})
    await agent.run('hello', model_settings=settings)

    sent_kwargs = get_mock_chat_completion_kwargs(client)[0]
    assert sent_kwargs['container'] == BetaContainerParams(id='container_abc123')
7876+
7877+
7878+
async def test_anthropic_container_from_message_history(allow_model_requests: None):
    """Test that container_id from message history is passed to subsequent requests."""
    response = completion_message(
        [BetaTextBlock(text='world', type='text')], BetaUsage(input_tokens=5, output_tokens=10)
    )
    client = MockAnthropic.create_mock([response, response])
    model = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=client))
    agent = Agent(model)

    # A prior response carrying a container_id in provider_details should be
    # picked up automatically on the next request.
    prior_response = ModelResponse(
        parts=[TextPart(content='world')],
        provider_details={'container_id': 'container_from_history'},
    )
    history: list[ModelMessage] = [
        ModelRequest(parts=[UserPromptPart(content='hello')]),
        prior_response,
    ]

    await agent.run('follow up', message_history=history)

    sent_kwargs = get_mock_chat_completion_kwargs(client)[0]
    assert sent_kwargs['container'] == BetaContainerParams(id='container_from_history')
assert completion_kwargs['container'] == BetaContainerParams(id='container_from_history')
7899+
7900+
7901+
async def test_anthropic_container_setting_false_ignores_history(allow_model_requests: None):
    """Test that anthropic_container=False ignores container_id from history."""
    from anthropic import omit as OMIT

    response = completion_message(
        [BetaTextBlock(text='world', type='text')], BetaUsage(input_tokens=5, output_tokens=10)
    )
    client = MockAnthropic.create_mock(response)
    model = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=client))
    agent = Agent(model)

    # History carries a container id that the False setting must override.
    history: list[ModelMessage] = [
        ModelRequest(parts=[UserPromptPart(content='hello')]),
        ModelResponse(
            parts=[TextPart(content='world')],
            provider_details={'container_id': 'container_should_be_ignored'},
        ),
    ]

    await agent.run(
        'follow up', message_history=history, model_settings=AnthropicModelSettings(anthropic_container=False)
    )

    sent_kwargs = get_mock_chat_completion_kwargs(client)[0]
    # With anthropic_container=False the container kwarg must be OMIT, i.e.
    # filtered out before the request reaches the API.
    assert sent_kwargs.get('container') is OMIT
7927+
7928+
7929+
async def test_anthropic_container_id_from_stream_response(allow_model_requests: None):
    """Test that container_id is extracted from streamed response and stored in provider_details."""
    from datetime import datetime

    # The message_start event is the one that carries the container metadata;
    # the remaining events stream a single short text block to completion.
    start_event = BetaRawMessageStartEvent(
        type='message_start',
        message=BetaMessage(
            id='msg_123',
            content=[],
            model='claude-3-5-haiku-123',
            role='assistant',
            stop_reason=None,
            type='message',
            usage=BetaUsage(input_tokens=5, output_tokens=0),
            container=BetaContainer(
                id='container_from_stream',
                expires_at=datetime(2025, 1, 1, 0, 0, 0),
            ),
        ),
    )
    events: list[BetaRawMessageStreamEvent] = [
        start_event,
        BetaRawContentBlockStartEvent(
            type='content_block_start',
            index=0,
            content_block=BetaTextBlock(text='', type='text'),
        ),
        BetaRawContentBlockDeltaEvent(
            type='content_block_delta',
            index=0,
            delta=BetaTextDelta(type='text_delta', text='hello'),
        ),
        BetaRawContentBlockStopEvent(type='content_block_stop', index=0),
        BetaRawMessageDeltaEvent(
            type='message_delta',
            delta=Delta(stop_reason='end_turn', stop_sequence=None),
            usage=BetaMessageDeltaUsage(output_tokens=5),
        ),
        BetaRawMessageStopEvent(type='message_stop'),
    ]

    client = MockAnthropic.create_stream_mock(events)
    model = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=client))
    agent = Agent(model)

    async with agent.run_stream('hello') as run:
        output = await run.get_output()
        assert output == 'hello'

    # The final ModelResponse must expose both the container id and the
    # finish reason via provider_details.
    final = run.all_messages()[-1]
    assert isinstance(final, ModelResponse)
    assert final.provider_details is not None
    assert final.provider_details.get('container_id') == 'container_from_stream'
    assert final.provider_details.get('finish_reason') == 'end_turn'

0 commit comments

Comments
 (0)