62 changes: 41 additions & 21 deletions pydantic_ai_slim/pydantic_ai/models/google.py
@@ -549,15 +549,25 @@ async def _map_messages(
                 elif isinstance(part, UserPromptPart):
                     message_parts.extend(await self._map_user_prompt(part))
                 elif isinstance(part, ToolReturnPart):
-                    message_parts.append(
-                        {
-                            'function_response': {
-                                'name': part.tool_name,
-                                'response': part.model_response_object(),
-                                'id': part.tool_call_id,
-                            }
-                        }
-                    )
+                    if self.profile.supports_tools:
+                        message_parts.append(
+                            {
+                                'function_response': {
+                                    'name': part.tool_name,
+                                    'response': part.model_response_object(),
+                                    'id': part.tool_call_id,
+                                }
+                            }
+                        )
+                    else:
+                        text = '\n'.join(
+                            [
+                                f'-----BEGIN TOOL RETURN name="{part.tool_name}" id="{part.tool_call_id}"-----',
+                                f'response: {part.model_response_object()}',
+                                f'-----END TOOL RETURN id="{part.tool_call_id}"-----',
+                            ]
+                        )
+                        message_parts.append({'text': text})
                 elif isinstance(part, RetryPromptPart):
                     if part.tool_name is None:
                         message_parts.append({'text': part.model_response()})
@@ -577,7 +587,7 @@ async def _map_messages(
             if message_parts:
                 contents.append({'role': 'user', 'parts': message_parts})
         elif isinstance(m, ModelResponse):
-            maybe_content = _content_model_response(m, self.system)
+            maybe_content = _content_model_response(m, self.system, self.profile.supports_tools)
             if maybe_content:
                 contents.append(maybe_content)
             else:
@@ -786,7 +796,7 @@ def timestamp(self) -> datetime:
         return self._timestamp


-def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict | None:  # noqa: C901
+def _content_model_response(m: ModelResponse, provider_name: str, supports_tools: bool) -> ContentDict | None:  # noqa: C901
     parts: list[PartDict] = []
     thinking_part_signature: str | None = None
     function_call_requires_signature: bool = True
@@ -803,17 +813,27 @@ def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict
             thinking_part_signature = None

         if isinstance(item, ToolCallPart):
-            function_call = FunctionCallDict(name=item.tool_name, args=item.args_as_dict(), id=item.tool_call_id)
-            part['function_call'] = function_call
-            if function_call_requires_signature and not part.get('thought_signature'):
-                # Per https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#migrating_from_other_models:
-                # > If you are transferring a conversation trace from another model (e.g., Gemini 2.5) or injecting
-                # > a custom function call that was not generated by Gemini 3, you will not have a valid signature.
-                # > To bypass strict validation in these specific scenarios, populate the field with this specific
-                # > dummy string: "thoughtSignature": "context_engineering_is_the_way_to_go"
-                part['thought_signature'] = b'context_engineering_is_the_way_to_go'
-                # Only the first function call requires a signature
-                function_call_requires_signature = False
+            if supports_tools:
+                function_call = FunctionCallDict(name=item.tool_name, args=item.args_as_dict(), id=item.tool_call_id)
+                part['function_call'] = function_call
+                if function_call_requires_signature and not part.get('thought_signature'):
+                    # Per https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#migrating_from_other_models:
+                    # > If you are transferring a conversation trace from another model (e.g., Gemini 2.5) or injecting
+                    # > a custom function call that was not generated by Gemini 3, you will not have a valid signature.
+                    # > To bypass strict validation in these specific scenarios, populate the field with this specific
+                    # > dummy string: "thoughtSignature": "context_engineering_is_the_way_to_go"
+                    part['thought_signature'] = b'context_engineering_is_the_way_to_go'
+                    # Only the first function call requires a signature
+                    function_call_requires_signature = False
+            else:
+                text = '\n'.join(
+                    [
+                        f'-----BEGIN TOOL CALL name="{item.tool_name}" id="{item.tool_call_id}"-----',
+                        f'args: {item.args_as_json_str()}',
+                        f'-----END TOOL CALL id="{item.tool_call_id}"-----',
+                    ]
+                )
+                part['text'] = text
         elif isinstance(item, TextPart):
             part['text'] = item.content
         elif isinstance(item, ThinkingPart):
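The diff above falls back to a delimited plain-text encoding of tool calls and tool returns whenever the model profile reports `supports_tools=False`. A minimal standalone sketch of that encoding, runnable outside the library — the helper names `render_tool_call` and `render_tool_return` are illustrative, not pydantic-ai APIs:

```python
import json
from typing import Any


def render_tool_call(tool_name: str, args: dict[str, Any], tool_call_id: str) -> str:
    """Render a tool call as delimited text, mirroring the ToolCallPart fallback above."""
    return '\n'.join(
        [
            f'-----BEGIN TOOL CALL name="{tool_name}" id="{tool_call_id}"-----',
            f'args: {json.dumps(args)}',
            f'-----END TOOL CALL id="{tool_call_id}"-----',
        ]
    )


def render_tool_return(tool_name: str, response: Any, tool_call_id: str) -> str:
    """Render a tool result the same way, mirroring the ToolReturnPart fallback above."""
    return '\n'.join(
        [
            f'-----BEGIN TOOL RETURN name="{tool_name}" id="{tool_call_id}"-----',
            f'response: {response}',
            f'-----END TOOL RETURN id="{tool_call_id}"-----',
        ]
    )


print(render_tool_call('get_user_country', {}, 'call_1'))
print(render_tool_return('get_user_country', {'return_value': 'Mexico'}, 'call_1'))
```

The BEGIN/END delimiters carry the tool name and call id inline, so a model without native function calling can still correlate a call with its result purely from the message text.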
111 changes: 111 additions & 0 deletions tests/models/test_google.py
@@ -4307,6 +4307,7 @@ def test_google_thought_signature_on_thinking_part():
             provider_name='google-gla',
         ),
         'google-gla',
+        True,
     )
     new_google_response = _content_model_response(
         ModelResponse(
@@ -4318,6 +4319,7 @@ def test_google_thought_signature_on_thinking_part():
             provider_name='google-gla',
         ),
         'google-gla',
+        True,
     )
     assert old_google_response == snapshot(
         {
@@ -4342,6 +4344,7 @@ def test_google_thought_signature_on_thinking_part():
             provider_name='google-gla',
         ),
         'google-gla',
+        True,
     )
     new_google_response = _content_model_response(
         ModelResponse(
@@ -4352,6 +4355,7 @@ def test_google_thought_signature_on_thinking_part():
             provider_name='google-gla',
         ),
         'google-gla',
+        True,
     )
     assert old_google_response == snapshot(
         {
@@ -4376,6 +4380,7 @@ def test_google_thought_signature_on_thinking_part():
             provider_name='google-gla',
         ),
         'google-gla',
+        True,
     )
     new_google_response = _content_model_response(
         ModelResponse(
@@ -4386,6 +4391,7 @@ def test_google_thought_signature_on_thinking_part():
             provider_name='google-gla',
         ),
         'google-gla',
+        True,
     )
     assert old_google_response == snapshot(
         {
@@ -4412,6 +4418,7 @@ def test_google_missing_tool_call_thought_signature():
             provider_name='openai',
         ),
         'google-gla',
+        True,
     )
     assert google_response == snapshot(
         {
@@ -4425,3 +4432,107 @@ def test_google_missing_tool_call_thought_signature():
             ],
         }
     )
+
+
+async def test_google_mapping_messages_no_tool_support(google_provider: GoogleProvider):
+    old_messages = [
+        ModelRequest(
+            parts=[
+                UserPromptPart(
+                    content='What is the largest city in the user country?',
+                    timestamp=IsDatetime(),
+                )
+            ],
+            run_id=IsStr(),
+        ),
+        ModelResponse(parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())]),
+        ModelRequest(
+            parts=[
+                ToolReturnPart(
+                    tool_name='get_user_country',
+                    content='Mexico',
+                    tool_call_id=IsStr(),
+                    timestamp=IsDatetime(),
+                )
+            ],
+            run_id=IsStr(),
+        ),
+        ModelResponse(
+            parts=[
+                ToolCallPart(
+                    tool_name='final_result',
+                    args={'city': 'Mexico City', 'country': 'Mexico'},
+                    tool_call_id=IsStr(),
+                )
+            ],
+        ),
+        ModelRequest(
+            parts=[
+                ToolReturnPart(
+                    tool_name='final_result',
+                    content='Final result processed.',
+                    tool_call_id=IsStr(),
+                    timestamp=IsDatetime(),
+                )
+            ],
+            run_id=IsStr(),
+        ),
+    ]
+    model = GoogleModel('gemini-2.5-flash-image-preview', provider=google_provider)
+    new_messages = await model._map_messages(old_messages, ModelRequestParameters())  # pyright: ignore[reportPrivateUsage]
+    assert new_messages == snapshot(
+        (
+            None,
+            [
+                {'role': 'user', 'parts': [{'text': 'What is the largest city in the user country?'}]},
+                {
+                    'role': 'model',
+                    'parts': [
+                        {
+                            'text': """\
+-----BEGIN TOOL CALL name="get_user_country" id="IsStr()"-----
+args: {}
+-----END TOOL CALL id="IsStr()"-----\
+"""
+                        }
+                    ],
+                },
+                {
+                    'role': 'user',
+                    'parts': [
+                        {
+                            'text': """\
+-----BEGIN TOOL RETURN name="get_user_country" id="IsStr()"-----
+response: {'return_value': 'Mexico'}
+-----END TOOL RETURN id="IsStr()"-----\
+"""
+                        }
+                    ],
+                },
+                {
+                    'role': 'model',
+                    'parts': [
+                        {
+                            'text': """\
+-----BEGIN TOOL CALL name="final_result" id="IsStr()"-----
+args: {"city":"Mexico City","country":"Mexico"}
+-----END TOOL CALL id="IsStr()"-----\
+"""
+                        }
+                    ],
+                },
+                {
+                    'role': 'user',
+                    'parts': [
+                        {
+                            'text': """\
+-----BEGIN TOOL RETURN name="final_result" id="IsStr()"-----
+response: {'return_value': 'Final result processed.'}
+-----END TOOL RETURN id="IsStr()"-----\
+"""
+                        }
+                    ],
+                },
+            ],
+        )
+    )
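The test drives the fallback by choosing a model whose profile reports no tool support (`gemini-2.5-flash-image-preview`, per the snapshot output above). A hedged sketch of the dispatch that `_map_messages` performs for a `ToolReturnPart` under this flag — `Profile` and `map_tool_return` are stand-ins for illustration, not the library's real `ModelProfile` or method names:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Profile:
    # Stand-in for the model profile; only the flag exercised by the diff.
    supports_tools: bool


def map_tool_return(profile: Profile, tool_name: str, response: Any, tool_call_id: str) -> dict[str, Any]:
    if profile.supports_tools:
        # Native path: emit a Gemini function_response part, as in the supports_tools branch.
        return {'function_response': {'name': tool_name, 'response': response, 'id': tool_call_id}}
    # Fallback path: the delimited-text encoding consumable by tool-less models.
    text = '\n'.join(
        [
            f'-----BEGIN TOOL RETURN name="{tool_name}" id="{tool_call_id}"-----',
            f'response: {response}',
            f'-----END TOOL RETURN id="{tool_call_id}"-----',
        ]
    )
    return {'text': text}


# The same inputs produce a structured part or a text part depending on the profile.
print(map_tool_return(Profile(supports_tools=True), 'get_user_country', {'return_value': 'Mexico'}, 'c1'))
print(map_tool_return(Profile(supports_tools=False), 'get_user_country', {'return_value': 'Mexico'}, 'c1'))
```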