diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 89290ea3ce..73c5d12953 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -549,15 +549,25 @@ async def _map_messages( elif isinstance(part, UserPromptPart): message_parts.extend(await self._map_user_prompt(part)) elif isinstance(part, ToolReturnPart): - message_parts.append( - { - 'function_response': { - 'name': part.tool_name, - 'response': part.model_response_object(), - 'id': part.tool_call_id, + if self.profile.supports_tools: + message_parts.append( + { + 'function_response': { + 'name': part.tool_name, + 'response': part.model_response_object(), + 'id': part.tool_call_id, + } } - } - ) + ) + else: + text = '\n'.join( + [ + f'-----BEGIN TOOL RETURN name="{part.tool_name}" id="{part.tool_call_id}"-----', + f'response: {part.model_response_object()}', + f'-----END TOOL RETURN id="{part.tool_call_id}"-----', + ] + ) + message_parts.append({'text': text}) elif isinstance(part, RetryPromptPart): if part.tool_name is None: message_parts.append({'text': part.model_response()}) @@ -577,7 +587,7 @@ async def _map_messages( if message_parts: contents.append({'role': 'user', 'parts': message_parts}) elif isinstance(m, ModelResponse): - maybe_content = _content_model_response(m, self.system) + maybe_content = _content_model_response(m, self.system, self.profile.supports_tools) if maybe_content: contents.append(maybe_content) else: @@ -786,7 +796,7 @@ def timestamp(self) -> datetime: return self._timestamp -def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict | None: # noqa: C901 +def _content_model_response(m: ModelResponse, provider_name: str, supports_tools: bool) -> ContentDict | None: # noqa: C901 parts: list[PartDict] = [] thinking_part_signature: str | None = None function_call_requires_signature: bool = True @@ -803,17 +813,27 @@ def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict thinking_part_signature = None if isinstance(item, ToolCallPart): - function_call = FunctionCallDict(name=item.tool_name, args=item.args_as_dict(), id=item.tool_call_id) - part['function_call'] = function_call - if function_call_requires_signature and not part.get('thought_signature'): - # Per https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#migrating_from_other_models: - # > If you are transferring a conversation trace from another model (e.g., Gemini 2.5) or injecting - # > a custom function call that was not generated by Gemini 3, you will not have a valid signature. - # > To bypass strict validation in these specific scenarios, populate the field with this specific - # > dummy string: "thoughtSignature": "context_engineering_is_the_way_to_go" - part['thought_signature'] = b'context_engineering_is_the_way_to_go' - # Only the first function call requires a signature - function_call_requires_signature = False + if supports_tools: + function_call = FunctionCallDict(name=item.tool_name, args=item.args_as_dict(), id=item.tool_call_id) + part['function_call'] = function_call + if function_call_requires_signature and not part.get('thought_signature'): + # Per https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#migrating_from_other_models: + # > If you are transferring a conversation trace from another model (e.g., Gemini 2.5) or injecting + # > a custom function call that was not generated by Gemini 3, you will not have a valid signature. + # > To bypass strict validation in these specific scenarios, populate the field with this specific + # > dummy string: "thoughtSignature": "context_engineering_is_the_way_to_go" + part['thought_signature'] = b'context_engineering_is_the_way_to_go' + # Only the first function call requires a signature + function_call_requires_signature = False + else: + text = '\n'.join( + [ + f'-----BEGIN TOOL CALL name="{item.tool_name} "id="{item.tool_call_id}""-----', + f'args: {item.args_as_json_str()}', + f'-----END TOOL CALL id="{item.tool_call_id}"-----', + ] + ) + part['text'] = text elif isinstance(item, TextPart): part['text'] = item.content elif isinstance(item, ThinkingPart): diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 3ef8cd5dda..18158526a3 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -4307,6 +4307,7 @@ def test_google_thought_signature_on_thinking_part(): provider_name='google-gla', ), 'google-gla', + True, ) new_google_response = _content_model_response( ModelResponse( @@ -4318,6 +4319,7 @@ def test_google_thought_signature_on_thinking_part(): provider_name='google-gla', ), 'google-gla', + True, ) assert old_google_response == snapshot( { @@ -4342,6 +4344,7 @@ def test_google_thought_signature_on_thinking_part(): provider_name='google-gla', ), 'google-gla', + True, ) new_google_response = _content_model_response( ModelResponse( @@ -4352,6 +4355,7 @@ def test_google_thought_signature_on_thinking_part(): provider_name='google-gla', ), 'google-gla', + True, ) assert old_google_response == snapshot( { @@ -4376,6 +4380,7 @@ def test_google_thought_signature_on_thinking_part(): provider_name='google-gla', ), 'google-gla', + True, ) new_google_response = _content_model_response( ModelResponse( @@ -4386,6 +4391,7 @@ def test_google_thought_signature_on_thinking_part(): provider_name='google-gla', ), 'google-gla', + True, ) assert old_google_response == snapshot( { @@ -4412,6 +4418,7 @@ def test_google_missing_tool_call_thought_signature(): provider_name='openai', ), 'google-gla', + True, ) assert google_response == snapshot( { @@ -4425,3 +4432,107 @@ def test_google_missing_tool_call_thought_signature(): ], } ) + + +async def test_google_mapping_messages_no_tool_support(google_provider: GoogleProvider): + old_messages = [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ], + run_id=IsStr(), + ), + ModelResponse(parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())]), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args={'city': 'Mexico City', 'country': 'Mexico'}, + tool_call_id=IsStr(), + ) + ], + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ], + run_id=IsStr(), + ), + ] + model = GoogleModel('gemini-2.5-flash-image-preview', provider=google_provider) + new_messages = await model._map_messages(old_messages, ModelRequestParameters()) # pyright: ignore[reportPrivateUsage] + assert new_messages == snapshot( + ( + None, + [ + {'role': 'user', 'parts': [{'text': 'What is the largest city in the user country?'}]}, + { + 'role': 'model', + 'parts': [ + { + 'text': """\ +-----BEGIN TOOL CALL name="get_user_country "id="IsStr()""----- +args: {} +-----END TOOL CALL id="IsStr()"-----\ +""" + } + ], + }, + { + 'role': 'user', + 'parts': [ + { + 'text': """\ +-----BEGIN TOOL RETURN name="get_user_country" id="IsStr()"----- +response: {'return_value': 'Mexico'} +-----END TOOL RETURN id="IsStr()"-----\ +""" + } + ], + }, + { + 'role': 'model', + 'parts': [ + { + 'text': """\ +-----BEGIN TOOL CALL name="final_result "id="IsStr()""----- +args: {"city":"Mexico City","country":"Mexico"} +-----END TOOL CALL id="IsStr()"-----\ +""" + } + ], + }, + { + 'role': 'user', + 'parts': [ + { + 'text': """\ +-----BEGIN TOOL RETURN name="final_result" id="IsStr()"----- +response: {'return_value': 'Final result processed.'} +-----END TOOL RETURN id="IsStr()"-----\ +""" + } + ], + }, + ], + ) + )