Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:

for part in parts:
provider_details: dict[str, Any] | None = None
thought_signature: str | None = None
if part.thought_signature:
# Per https://ai.google.dev/gemini-api/docs/function-calling?example=meeting#thought-signatures:
# - Always send the thought_signature back to the model inside its original Part.
Expand All @@ -718,12 +719,28 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
thought_signature = base64.b64encode(part.thought_signature).decode('utf-8')
provider_details = {'thought_signature': thought_signature}

# Google returns thought_signature on the part FOLLOWING the thinking part.
# Apply it to the previous ThinkingPart if this is a non-thinking part with a signature.
if thought_signature and not part.thought:
# Only apply signature if the latest part is a ThinkingPart
parts = self._parts_manager.get_parts()
if parts and isinstance(parts[-1], ThinkingPart):
for event in self._parts_manager.handle_thinking_delta(
vendor_part_id=None,
signature=thought_signature,
):
yield event

if part.text is not None:
if len(part.text) == 0 and not provider_details:
continue
if part.thought:
for event in self._parts_manager.handle_thinking_delta(
vendor_part_id=None, content=part.text, provider_details=provider_details
vendor_part_id=None,
content=part.text,
signature=thought_signature,
provider_name=self._provider_name,
provider_details=provider_details,
):
yield event
else:
Expand Down Expand Up @@ -878,8 +895,10 @@ def _process_response_from_parts(

item: ModelResponsePart | None = None
code_execution_tool_call_id: str | None = None
last_thinking_part: ThinkingPart | None = None
for part in parts:
provider_details: dict[str, Any] | None = None
thought_signature: str | None = None
if part.thought_signature:
# Per https://ai.google.dev/gemini-api/docs/function-calling?example=meeting#thought-signatures:
# - Always send the thought_signature back to the model inside its original Part.
Expand All @@ -899,7 +918,8 @@ def _process_response_from_parts(
if len(part.text) == 0 and not provider_details:
continue
if part.thought:
item = ThinkingPart(content=part.text)
item = ThinkingPart(content=part.text, signature=thought_signature, provider_name=provider_name)
last_thinking_part = item
else:
item = TextPart(content=part.text)
elif part.function_call:
Expand All @@ -916,6 +936,12 @@ def _process_response_from_parts(
else: # pragma: no cover
raise UnexpectedModelBehavior(f'Unsupported response from Gemini: {part!r}')

# Google returns thought_signature on the part FOLLOWING the thinking part.
# Apply it to the previous ThinkingPart if this is a non-thinking part with a signature.
if thought_signature and last_thinking_part and not part.thought:
last_thinking_part.signature = thought_signature
last_thinking_part = None # Only apply once

if provider_details:
item.provider_details = {**(item.provider_details or {}), **provider_details}

Expand Down
29 changes: 20 additions & 9 deletions tests/models/test_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -1931,7 +1931,11 @@ def dummy() -> None: ... # pragma: no cover
),
ModelResponse(
parts=[
ThinkingPart(content=IsStr()),
ThinkingPart(
content=IsStr(),
signature=IsStr(),
provider_name='google-gla',
),
TextPart(
content=IsStr(),
provider_details={'thought_signature': IsStr()},
Expand Down Expand Up @@ -1968,7 +1972,11 @@ def dummy() -> None: ... # pragma: no cover
),
ModelResponse(
parts=[
ThinkingPart(content=IsStr()),
ThinkingPart(
content=IsStr(),
signature=IsStr(),
provider_name='google-gla',
),
TextPart(
content=IsStr(),
provider_details={'thought_signature': IsStr()},
Expand Down Expand Up @@ -2077,7 +2085,7 @@ def dummy() -> None: ... # pragma: no cover
),
ModelResponse(
parts=[
ThinkingPart(content=IsStr()),
ThinkingPart(content=IsStr(), signature=IsStr(), provider_name='google-gla'),
TextPart(
content=IsStr(),
provider_details={'thought_signature': IsStr()},
Expand Down Expand Up @@ -2133,7 +2141,7 @@ def dummy() -> None: ... # pragma: no cover
),
ModelResponse(
parts=[
ThinkingPart(content=IsStr()),
ThinkingPart(content=IsStr(), signature=IsStr(), provider_name='google-gla'),
TextPart(
content=IsStr(),
provider_details={'thought_signature': IsStr()},
Expand All @@ -2155,10 +2163,11 @@ def dummy() -> None: ... # pragma: no cover

assert event_parts == snapshot(
[
PartStartEvent(index=0, part=ThinkingPart(content=IsStr())),
PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=IsStr())),
PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=IsStr())),
PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=IsStr())),
PartStartEvent(index=0, part=ThinkingPart(content=IsStr(), provider_name='google-gla')),
PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=IsStr(), provider_name='google-gla')),
PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=IsStr(), provider_name='google-gla')),
PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=IsStr(), provider_name='google-gla')),
PartDeltaEvent(index=0, delta=ThinkingPartDelta(signature_delta=IsStr())),
PartEndEvent(
index=0,
part=ThinkingPart(
Expand All @@ -2183,7 +2192,9 @@ def dummy() -> None: ... # pragma: no cover
I've identified the core user intent: to learn safe street-crossing. Now, I'm focusing on crafting universally applicable steps. Finding safe crossing locations and looking-listening for traffic remain paramount. I'm prioritizing direct, clear language, addressing my limitations as an AI. I'm crafting advice that works generally, regardless of specific circumstances or locations.


"""
""",
signature=IsStr(),
provider_name='google-gla',
),
next_part_kind='text',
),
Expand Down
Loading