From 8d52d651154b9f021f3038f6b449a940e8debeb4 Mon Sep 17 00:00:00 2001
From: David <64162682+dsfaccini@users.noreply.github.com>
Date: Fri, 28 Nov 2025 12:42:02 -0500
Subject: [PATCH 1/9] Clarify usage of agent factories

Clarify that agents can be produced by a factory function if preferred.
---
 docs/agents.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/agents.md b/docs/agents.md
index 0633fb88ba..a0dc22e5b1 100644
--- a/docs/agents.md
+++ b/docs/agents.md
@@ -57,7 +57,7 @@ print(result.output)
 4. `result.output` will be a boolean indicating if the square is a winner. Pydantic performs the output validation, and it'll be typed as a `bool` since its type is derived from the `output_type` generic parameter of the agent.
 
 !!! tip "Agents are designed for reuse, like FastAPI Apps"
-    Agents are intended to be instantiated once (frequently as module globals) and reused throughout your application, similar to a small [FastAPI][fastapi.FastAPI] app or an [APIRouter][fastapi.APIRouter].
+    Agents can be instantiated once (often as module globals) and reused throughout your application, similar to a small [FastAPI][fastapi.FastAPI] app or an [APIRouter][fastapi.APIRouter]. If you prefer, they can also be created dynamically by a factory function like `get_agent('agent-type')`.
 
 ## Running Agents
 

From 6d942f5f3b31fcf44e57dc0077700d155d3e2933 Mon Sep 17 00:00:00 2001
From: David Sanchez <64162682+dsfaccini@users.noreply.github.com>
Date: Mon, 1 Dec 2025 17:40:21 -0500
Subject: [PATCH 2/9] implement tool choice resolution per model

---
 .../pydantic_ai/models/anthropic.py           |  89 ++++++++-
 .../pydantic_ai/models/bedrock.py             |  59 +++++-
 pydantic_ai_slim/pydantic_ai/models/google.py |  84 +++++++-
 pydantic_ai_slim/pydantic_ai/models/groq.py   |  77 +++++++-
 .../pydantic_ai/models/huggingface.py         |  77 +++++++-
 .../pydantic_ai/models/mistral.py             |  59 +++++-
 pydantic_ai_slim/pydantic_ai/models/openai.py | 183 ++++++++++++++++--
 pydantic_ai_slim/pydantic_ai/settings.py      |  29 +++
 tests/models/test_anthropic.py                | 138 +++++++++++++
 tests/models/test_groq.py                     |  75 ++++++-
 tests/models/test_huggingface.py              |   2 +-
 tests/models/test_openai.py                   | 131 +++++++++++++
 12 files changed, 946 insertions(+), 57 deletions(-)

diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
index ecdb9fe61f..9929feef1b 100644
--- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py
+++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations
 
 import io
+import warnings
 from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field, replace
@@ -640,18 +641,96 @@ def _infer_tool_choice(
     ) -> BetaToolChoiceParam | None:
         if not tools:
             return None
-        else:
-            tool_choice: BetaToolChoiceParam
+        user_tool_choice = model_settings.get('tool_choice')
+        thinking_enabled = model_settings.get('anthropic_thinking') is not None
+        tool_choice: BetaToolChoiceParam
+
+        # Handle explicit user-provided tool_choice
+        if user_tool_choice is not None:
+            if user_tool_choice == 'none':
+                # If output tools exist, we can't truly disable all tools
+                if model_request_parameters.output_tools:
+                    warnings.warn(
+                        "tool_choice='none' is set but output tools are required for structured output. "
+                        'The output tools will remain available. 
Consider using native or prompted output modes ' + "if you need tool_choice='none' with structured output.", + UserWarning, + stacklevel=6, + ) + # Allow only output tools (Anthropic only supports one tool at a time) + output_tool_names = [t.name for t in model_request_parameters.output_tools] + if len(output_tool_names) == 1: + tool_choice = {'type': 'tool', 'name': output_tool_names[0]} + else: + # Multiple output tools - fall back to 'auto' and warn + warnings.warn( + 'Anthropic only supports forcing a single tool. ' + "Falling back to 'auto' for multiple output tools.", + UserWarning, + stacklevel=6, + ) + tool_choice = {'type': 'auto'} + else: + tool_choice = {'type': 'none'} + + elif user_tool_choice == 'auto': + tool_choice = {'type': 'auto'} + + elif user_tool_choice == 'required': + if thinking_enabled: + warnings.warn( + "tool_choice='required' is not supported with Anthropic thinking mode. Falling back to 'auto'.", + UserWarning, + stacklevel=6, + ) + tool_choice = {'type': 'auto'} + else: + tool_choice = {'type': 'any'} + + elif isinstance(user_tool_choice, list): + # Validate tool names exist in function_tools + function_tool_names = {t.name for t in model_request_parameters.function_tools} + invalid_names = set(user_tool_choice) - function_tool_names + if invalid_names: + raise UserError( + f'Invalid tool names in tool_choice: {invalid_names}. ' + f'Available function tools: {function_tool_names or "none"}' + ) + + if thinking_enabled: + warnings.warn( + "Forcing specific tools is not supported with Anthropic thinking mode. Falling back to 'auto'.", + UserWarning, + stacklevel=6, + ) + tool_choice = {'type': 'auto'} + elif len(user_tool_choice) == 1: + tool_choice = {'type': 'tool', 'name': user_tool_choice[0]} + else: + # Anthropic only supports one tool at a time + warnings.warn( + 'Anthropic only supports forcing a single tool. 
' + "Falling back to 'any' (required) for multiple tools.", + UserWarning, + stacklevel=6, + ) + tool_choice = {'type': 'any'} + else: + tool_choice = {'type': 'auto'} + + else: + # Default behavior: infer from allow_text_output if not model_request_parameters.allow_text_output: tool_choice = {'type': 'any'} else: tool_choice = {'type': 'auto'} - if 'parallel_tool_calls' in model_settings: - tool_choice['disable_parallel_tool_use'] = not model_settings['parallel_tool_calls'] + if 'parallel_tool_calls' in model_settings and tool_choice.get('type') != 'none': + # only `BetaToolChoiceNoneParam` doesn't have this field + tool_choice['disable_parallel_tool_use'] = not model_settings['parallel_tool_calls'] # pyright: ignore[reportGeneralTypeIssues] - return tool_choice + return tool_choice async def _map_message( # noqa: C901 self, diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index ff03460904..8927d4be43 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -2,6 +2,7 @@ import functools import typing +import warnings from collections.abc import AsyncIterator, Iterable, Iterator, Mapping from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -429,7 +430,7 @@ async def _messages_create( 'inferenceConfig': inference_config, } - tool_config = self._map_tool_config(model_request_parameters) + tool_config = self._map_tool_config(model_request_parameters, model_settings) if tool_config: params['toolConfig'] = tool_config @@ -485,16 +486,64 @@ def _map_inference_config( return inference_config - def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> ToolConfigurationTypeDef | None: + def _map_tool_config( + self, + model_request_parameters: ModelRequestParameters, + model_settings: BedrockModelSettings | None, + ) -> ToolConfigurationTypeDef | None: tools = self._get_tools(model_request_parameters) if not tools: return None + user_tool_choice = model_settings.get('tool_choice') if model_settings else None tool_choice: ToolChoiceTypeDef - if not model_request_parameters.allow_text_output: - tool_choice = {'any': {}} + + # Handle explicit user-provided tool_choice + if user_tool_choice is not None: + if user_tool_choice == 'none': + # Bedrock doesn't support 'none', fall back to 'auto' with warning + warnings.warn( + "Bedrock does not support tool_choice='none'. Falling back to 'auto'.", + UserWarning, + stacklevel=6, + ) + tool_choice = {'auto': {}} + + elif user_tool_choice == 'auto': + tool_choice = {'auto': {}} + + elif user_tool_choice == 'required': + tool_choice = {'any': {}} + + elif isinstance(user_tool_choice, list): + # Validate tool names exist in function_tools + function_tool_names = {t.name for t in model_request_parameters.function_tools} + invalid_names = set(user_tool_choice) - function_tool_names + if invalid_names: + raise UserError( + f'Invalid tool names in tool_choice: {invalid_names}. ' + f'Available function tools: {function_tool_names or "none"}' + ) + + if len(user_tool_choice) == 1: + tool_choice = {'tool': {'name': user_tool_choice[0]}} + else: + # Bedrock only supports single tool choice, fall back to any + warnings.warn( + 'Bedrock only supports forcing a single tool. 
' + "Falling back to 'any' (required) for multiple tools.", + UserWarning, + stacklevel=6, + ) + tool_choice = {'any': {}} + else: + tool_choice = {'auto': {}} else: - tool_choice = {'auto': {}} + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output: + tool_choice = {'any': {}} + else: + tool_choice = {'auto': {}} tool_config: ToolConfigurationTypeDef = {'tools': tools} if tool_choice and BedrockModelProfile.from_profile(self.profile).bedrock_supports_tool_choice: diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 89290ea3ce..5558dd37b9 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations import base64 +import warnings from collections.abc import AsyncIterator, Awaitable from contextlib import asynccontextmanager from dataclasses import dataclass, field, replace @@ -363,18 +364,85 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T ) return tools or None - def _get_tool_config( - self, model_request_parameters: ModelRequestParameters, tools: list[ToolDict] | None + def _get_tool_config( # noqa: C901 + self, + model_request_parameters: ModelRequestParameters, + tools: list[ToolDict] | None, + model_settings: GoogleModelSettings, ) -> ToolConfigDict | None: - if not model_request_parameters.allow_text_output and tools: - names: list[str] = [] + if not tools: + return None + + user_tool_choice = model_settings.get('tool_choice') + + # Handle explicit user-provided tool_choice + if user_tool_choice is not None: + if user_tool_choice == 'none': + # If output tools exist, we can't truly disable all tools + if model_request_parameters.output_tools: + warnings.warn( + "tool_choice='none' is set but output tools are required for structured output. " + 'The output tools will remain available. Consider using native or prompted output modes ' + "if you need tool_choice='none' with structured output.", + UserWarning, + stacklevel=6, + ) + # Allow only output tools + output_tool_names = [t.name for t in model_request_parameters.output_tools] + return ToolConfigDict( + function_calling_config=FunctionCallingConfigDict( + mode=FunctionCallingConfigMode.ANY, + allowed_function_names=output_tool_names, + ) + ) + return ToolConfigDict( + function_calling_config=FunctionCallingConfigDict(mode=FunctionCallingConfigMode.NONE) + ) + + if user_tool_choice == 'auto': + return ToolConfigDict( + function_calling_config=FunctionCallingConfigDict(mode=FunctionCallingConfigMode.AUTO) + ) + + if user_tool_choice == 'required': + # Get all tool names + names: list[str] = [] + for tool in tools: + for function_declaration in tool.get('function_declarations') or []: + if name := function_declaration.get('name'): + names.append(name) + return ToolConfigDict( + function_calling_config=FunctionCallingConfigDict( + mode=FunctionCallingConfigMode.ANY, + allowed_function_names=names, + ) + ) + + if isinstance(user_tool_choice, list): + # Validate tool names exist in function_tools + function_tool_names = {t.name for t in model_request_parameters.function_tools} + invalid_names = set(user_tool_choice) - function_tool_names + if invalid_names: + raise UserError( + f'Invalid tool names in tool_choice: {invalid_names}. 
' + f'Available function tools: {function_tool_names or "none"}' + ) + return ToolConfigDict( + function_calling_config=FunctionCallingConfigDict( + mode=FunctionCallingConfigMode.ANY, + allowed_function_names=list(user_tool_choice), + ) + ) + + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output: + names = [] for tool in tools: for function_declaration in tool.get('function_declarations') or []: - if name := function_declaration.get('name'): # pragma: no branch + if name := function_declaration.get('name'): names.append(name) return _tool_config(names) - else: - return None + return None @overload async def _generate_content( @@ -440,7 +508,7 @@ async def _build_content_and_config( raise UserError('JSON output is not supported by this model.') response_mime_type = 'application/json' - tool_config = self._get_tool_config(model_request_parameters, tools) + tool_config = self._get_tool_config(model_request_parameters, tools, model_settings) system_instruction, contents = await self._map_messages(messages, model_request_parameters) modalities = [Modality.TEXT.value] diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 64f7ddcf85..9138c8f789 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import warnings from collections.abc import AsyncIterable, AsyncIterator, Iterable from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -56,6 +57,8 @@ from groq.types import chat from groq.types.chat.chat_completion_content_part_image_param import ImageURL from groq.types.chat.chat_completion_message import ExecutedTool + from groq.types.chat.chat_completion_named_tool_choice_param import ChatCompletionNamedToolChoiceParam + from groq.types.chat.chat_completion_tool_choice_option_param import ChatCompletionToolChoiceOptionParam except ImportError as _import_error: raise ImportError( 'Please install `groq` to use the Groq model, ' @@ -265,12 +268,7 @@ async def _completions_create( ) -> chat.ChatCompletion | AsyncStream[chat.ChatCompletionChunk]: tools = self._get_tools(model_request_parameters) tools += self._get_builtin_tools(model_request_parameters) - if not tools: - tool_choice: Literal['none', 'required', 'auto'] | None = None - elif not model_request_parameters.allow_text_output: - tool_choice = 'required' - else: - tool_choice = 'auto' + tool_choice = self._get_tool_choice(tools, model_settings, model_request_parameters) groq_messages = self._map_messages(messages, model_request_parameters) @@ -376,6 +374,73 @@ async def _process_streamed_response( def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]: return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] + def _get_tool_choice( + self, + tools: list[chat.ChatCompletionToolParam], + model_settings: GroqModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> ChatCompletionToolChoiceOptionParam | None: + user_tool_choice = model_settings.get('tool_choice') + + if not tools: + return None + + # Handle explicit user-provided tool_choice + if user_tool_choice is not None: + if user_tool_choice == 'none': + # If output tools exist, we can't truly disable all tools + if model_request_parameters.output_tools: + warnings.warn( + "tool_choice='none' is set but output tools are required for 
structured output. " + 'The output tools will remain available. Consider using native or prompted output modes ' + "if you need tool_choice='none' with structured output.", + UserWarning, + stacklevel=6, + ) + # Allow only output tools (force first one since Groq only supports single tool) + output_tool_names = [t.name for t in model_request_parameters.output_tools] + return ChatCompletionNamedToolChoiceParam( + type='function', + function={'name': output_tool_names[0]}, + ) + return 'none' + + if user_tool_choice == 'auto': + return 'auto' + + if user_tool_choice == 'required': + return 'required' + + # Handle list of specific tool names + if isinstance(user_tool_choice, list): + # Validate tool names exist in function_tools + function_tool_names = {t.name for t in model_request_parameters.function_tools} + invalid_names = set(user_tool_choice) - function_tool_names + if invalid_names: + raise UserError( + f'Invalid tool names in tool_choice: {invalid_names}. ' + f'Available function tools: {function_tool_names or "none"}' + ) + + if len(user_tool_choice) == 1: + return ChatCompletionNamedToolChoiceParam( + type='function', + function={'name': user_tool_choice[0]}, + ) + else: + # Groq only supports single tool choice, fall back to required + warnings.warn( + "Groq only supports forcing a single tool. Falling back to 'required' for multiple tools.", + UserWarning, + stacklevel=6, + ) + return 'required' + + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output: + return 'required' + return 'auto' + def _get_builtin_tools( self, model_request_parameters: ModelRequestParameters ) -> list[chat.ChatCompletionToolParam]: diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 790b30bec3..375c3e95d5 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import warnings from collections.abc import AsyncIterable, AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -52,10 +53,12 @@ import aiohttp from huggingface_hub import ( AsyncInferenceClient, + ChatCompletionInputFunctionName, ChatCompletionInputMessage, ChatCompletionInputMessageChunk, ChatCompletionInputTool, ChatCompletionInputToolCall, + ChatCompletionInputToolChoiceClass, ChatCompletionInputURL, ChatCompletionOutput, ChatCompletionOutputMessage, @@ -221,13 +224,7 @@ async def _completions_create( model_request_parameters: ModelRequestParameters, ) -> ChatCompletionOutput | AsyncIterable[ChatCompletionStreamOutput]: tools = self._get_tools(model_request_parameters) - - if not tools: - tool_choice: Literal['none', 'required', 'auto'] | None = None - elif not model_request_parameters.allow_text_output: - tool_choice = 'required' - else: - tool_choice = 'auto' + tool_choice = self._get_tool_choice(tools, model_settings, model_request_parameters) if model_request_parameters.builtin_tools: raise UserError('HuggingFace does not support built-in tools') @@ -322,6 +319,72 @@ async def _process_streamed_response( def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]: return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] + def _get_tool_choice( + self, + tools: list[ChatCompletionInputTool], + model_settings: HuggingFaceModelSettings, + model_request_parameters: 
ModelRequestParameters, + ) -> Literal['none', 'required', 'auto'] | ChatCompletionInputToolChoiceClass | None: + user_tool_choice = model_settings.get('tool_choice') + + if not tools: + return None + + # Handle explicit user-provided tool_choice + if user_tool_choice is not None: + if user_tool_choice == 'none': + # If output tools exist, we can't truly disable all tools + if model_request_parameters.output_tools: + warnings.warn( + "tool_choice='none' is set but output tools are required for structured output. " + 'The output tools will remain available. Consider using native or prompted output modes ' + "if you need tool_choice='none' with structured output.", + UserWarning, + stacklevel=6, + ) + # Allow only output tools (force first one) + output_tool_names = [t.name for t in model_request_parameters.output_tools] + return ChatCompletionInputToolChoiceClass( + function=ChatCompletionInputFunctionName(name=output_tool_names[0]) # pyright: ignore[reportCallIssue] + ) + return 'none' + + if user_tool_choice == 'auto': + return 'auto' + + if user_tool_choice == 'required': + return 'required' + + # Handle list of specific tool names + if isinstance(user_tool_choice, list): + # Validate tool names exist in function_tools + function_tool_names = {t.name for t in model_request_parameters.function_tools} + invalid_names = set(user_tool_choice) - function_tool_names + if invalid_names: + raise UserError( + f'Invalid tool names in tool_choice: {invalid_names}. ' + f'Available function tools: {function_tool_names or "none"}' + ) + + if len(user_tool_choice) == 1: + return ChatCompletionInputToolChoiceClass( + function=ChatCompletionInputFunctionName(name=user_tool_choice[0]) # pyright: ignore[reportCallIssue] + ) + else: + # HuggingFace only supports single tool choice, fall back to required + warnings.warn( + 'HuggingFace only supports forcing a single tool. 
' + "Falling back to 'required' for multiple tools.", + UserWarning, + stacklevel=6, + ) + return 'required' + + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output: + return 'required' + return 'auto' + async def _map_messages( self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters ) -> list[ChatCompletionInputMessage | ChatCompletionOutputMessage]: diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index cefa28e9dc..ae87c8fc78 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import warnings from collections.abc import AsyncIterable, AsyncIterator, Iterable from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -233,7 +234,7 @@ async def _completions_create( messages=self._map_messages(messages, model_request_parameters), n=1, tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET, - tool_choice=self._get_tool_choice(model_request_parameters), + tool_choice=self._get_tool_choice(model_request_parameters, model_settings), stream=False, max_tokens=model_settings.get('max_tokens', UNSET), temperature=model_settings.get('temperature', UNSET), @@ -273,7 +274,7 @@ async def _stream_completions_create( messages=mistral_messages, n=1, tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET, - tool_choice=self._get_tool_choice(model_request_parameters), + tool_choice=self._get_tool_choice(model_request_parameters, model_settings), temperature=model_settings.get('temperature', UNSET), top_p=model_settings.get('top_p', 1), max_tokens=model_settings.get('max_tokens', UNSET), @@ -312,7 +313,11 @@ async def _stream_completions_create( assert response, 'A unexpected empty response from Mistral.' return response - def _get_tool_choice(self, model_request_parameters: ModelRequestParameters) -> MistralToolChoiceEnum | None: + def _get_tool_choice( + self, + model_request_parameters: ModelRequestParameters, + model_settings: MistralModelSettings, + ) -> MistralToolChoiceEnum | None: """Get tool choice for the model. - "auto": Default mode. Model decides if it uses the tool or not. @@ -322,10 +327,52 @@ def _get_tool_choice(self, model_request_parameters: ModelRequestParameters) -> """ if not model_request_parameters.function_tools and not model_request_parameters.output_tools: return None - elif not model_request_parameters.allow_text_output: + + user_tool_choice = model_settings.get('tool_choice') + + # Handle explicit user-provided tool_choice + if user_tool_choice is not None: + if user_tool_choice == 'none': + # If output tools exist, we can't truly disable all tools + if model_request_parameters.output_tools: + warnings.warn( + "tool_choice='none' is set but output tools are required for structured output. " + 'The output tools will remain available. 
Consider using native or prompted output modes ' + "if you need tool_choice='none' with structured output.", + UserWarning, + stacklevel=6, + ) + return 'required' + return 'none' + + if user_tool_choice == 'auto': + return 'auto' + + if user_tool_choice == 'required': + return 'required' + + # Handle list of specific tool names + if isinstance(user_tool_choice, list): + # Validate tool names exist in function_tools + function_tool_names = {t.name for t in model_request_parameters.function_tools} + invalid_names = set(user_tool_choice) - function_tool_names + if invalid_names: + raise UserError( + f'Invalid tool names in tool_choice: {invalid_names}. ' + f'Available function tools: {function_tool_names or "none"}' + ) + # Mistral doesn't support forcing specific tools, fall back to required + warnings.warn( + "Mistral does not support forcing specific tools. Falling back to 'required'.", + UserWarning, + stacklevel=6, + ) + return 'required' + + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output: return 'required' - else: - return 'auto' + return 'auto' def _map_function_and_output_tools_definition( self, model_request_parameters: ModelRequestParameters diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 10af284ee8..c4a8da4db3 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -52,7 +52,16 @@ from ..providers import Provider, infer_provider from ..settings import ModelSettings from ..tools import ToolDefinition -from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent +from . import ( + Model, + ModelRequestParameters, + ResolvedToolChoice, + StreamedResponse, + check_allow_model_requests, + download_item, + get_user_agent, + resolve_tool_choice, +) try: from openai import NOT_GIVEN, APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream @@ -67,6 +76,8 @@ chat_completion_chunk, chat_completion_token_logprob, ) + from openai.types.chat.chat_completion_allowed_tool_choice_param import ChatCompletionAllowedToolChoiceParam + from openai.types.chat.chat_completion_allowed_tools_param import ChatCompletionAllowedToolsParam from openai.types.chat.chat_completion_content_part_image_param import ImageURL from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio from openai.types.chat.chat_completion_content_part_param import File, FileFile @@ -75,16 +86,21 @@ from openai.types.chat.chat_completion_message_function_tool_call_param import ( ChatCompletionMessageFunctionToolCallParam, ) + from openai.types.chat.chat_completion_named_tool_choice_param import ChatCompletionNamedToolChoiceParam from openai.types.chat.chat_completion_prediction_content_param import ChatCompletionPredictionContentParam + from openai.types.chat.chat_completion_tool_choice_option_param import ChatCompletionToolChoiceOptionParam from openai.types.chat.completion_create_params import ( WebSearchOptions, WebSearchOptionsUserLocation, WebSearchOptionsUserLocationApproximate, ) from openai.types.responses import ComputerToolParam, FileSearchToolParam, WebSearchToolParam + from openai.types.responses.response_create_params import ToolChoice as ResponsesToolChoice from openai.types.responses.response_input_param import FunctionCallOutput, Message from openai.types.responses.response_reasoning_item_param import Summary from openai.types.responses.response_status 
import ResponseStatus + from openai.types.responses.tool_choice_allowed_param import ToolChoiceAllowedParam + from openai.types.responses.tool_choice_function_param import ToolChoiceFunctionParam from openai.types.shared import ReasoningEffort from openai.types.shared_params import Reasoning except ImportError as _import_error: @@ -493,15 +509,7 @@ async def _completions_create( tools = self._get_tools(model_request_parameters) web_search_options = self._get_web_search_options(model_request_parameters) - if not tools: - tool_choice: Literal['none', 'required', 'auto'] | None = None - elif ( - not model_request_parameters.allow_text_output - and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required - ): - tool_choice = 'required' - else: - tool_choice = 'auto' + tool_choice = self._get_tool_choice(tools, model_settings, model_request_parameters) openai_messages = await self._map_messages(messages, model_request_parameters) @@ -691,6 +699,85 @@ def _map_usage(self, response: chat.ChatCompletion) -> usage.RequestUsage: def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]: return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] + def _get_tool_choice( + self, + tools: list[chat.ChatCompletionToolParam], + model_settings: OpenAIChatModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> ChatCompletionToolChoiceOptionParam | None: + user_tool_choice = model_settings.get('tool_choice') + + if not tools: + return None + + # Handle explicit user-provided tool_choice + if user_tool_choice is not None: + if user_tool_choice == 'none': + # If output tools exist, we can't truly disable all tools + if model_request_parameters.output_tools: + warnings.warn( + "tool_choice='none' is set but output tools are required for structured output. " + 'The output tools will remain available. Consider using native or prompted output modes ' + "if you need tool_choice='none' with structured output.", + UserWarning, + stacklevel=6, + ) + # Allow only output tools + output_tool_names = [t.name for t in model_request_parameters.output_tools] + if len(output_tool_names) == 1: + return ChatCompletionNamedToolChoiceParam( + type='function', + function={'name': output_tool_names[0]}, + ) + else: + return ChatCompletionAllowedToolChoiceParam( + type='allowed_tools', + allowed_tools=ChatCompletionAllowedToolsParam( + mode='required' if not model_request_parameters.allow_text_output else 'auto', + tools=[{'type': 'function', 'function': {'name': n}} for n in output_tool_names], + ), + ) + return 'none' + + if user_tool_choice == 'auto': + return 'auto' + + if user_tool_choice == 'required': + return 'required' + + # Handle list of specific tool names + if isinstance(user_tool_choice, list): + # Validate tool names exist in function_tools + function_tool_names = {t.name for t in model_request_parameters.function_tools} + invalid_names = set(user_tool_choice) - function_tool_names + if invalid_names: + raise UserError( + f'Invalid tool names in tool_choice: {invalid_names}. 
' + f'Available function tools: {function_tool_names or "none"}' + ) + + if len(user_tool_choice) == 1: + return ChatCompletionNamedToolChoiceParam( + type='function', + function={'name': user_tool_choice[0]}, + ) + else: + return ChatCompletionAllowedToolChoiceParam( + type='allowed_tools', + allowed_tools=ChatCompletionAllowedToolsParam( + mode='required' if not model_request_parameters.allow_text_output else 'auto', + tools=[{'type': 'function', 'function': {'name': n}} for n in user_tool_choice], + ), + ) + + # Default behavior: infer from allow_text_output + if ( + not model_request_parameters.allow_text_output + and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required + ): + return 'required' + return 'auto' + def _get_web_search_options(self, model_request_parameters: ModelRequestParameters) -> WebSearchOptions | None: for tool in model_request_parameters.builtin_tools: if isinstance(tool, WebSearchTool): # pragma: no branch @@ -1272,7 +1359,7 @@ async def _responses_create( model_request_parameters: ModelRequestParameters, ) -> AsyncStream[responses.ResponseStreamEvent]: ... - async def _responses_create( # noqa: C901 + async def _responses_create( self, messages: list[ModelRequest | ModelResponse], stream: bool, @@ -1285,12 +1372,7 @@ async def _responses_create( # noqa: C901 + self._get_tools(model_request_parameters) ) profile = OpenAIModelProfile.from_profile(self.profile) - if not tools: - tool_choice: Literal['none', 'required', 'auto'] | None = None - elif not model_request_parameters.allow_text_output and profile.openai_supports_tool_choice_required: - tool_choice = 'required' - else: - tool_choice = 'auto' + tool_choice = self._get_responses_tool_choice(tools, model_settings, model_request_parameters) previous_response_id = model_settings.get('openai_previous_response_id') if previous_response_id == 'auto': @@ -1401,6 +1483,73 @@ def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reason def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.FunctionToolParam]: return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] + def _get_responses_tool_choice( + self, + tools: list[responses.ToolParam], + model_settings: OpenAIResponsesModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> ResponsesToolChoice | None: + user_tool_choice = model_settings.get('tool_choice') + profile = OpenAIModelProfile.from_profile(self.profile) + + if not tools: + return None + + # Handle explicit user-provided tool_choice + if user_tool_choice is not None: + if user_tool_choice == 'none': + # If output tools exist, we can't truly disable all tools + if model_request_parameters.output_tools: + warnings.warn( + "tool_choice='none' is set but output tools are required for structured output. " + 'The output tools will remain available. 
Consider using native or prompted output modes ' + "if you need tool_choice='none' with structured output.", + UserWarning, + stacklevel=6, + ) + # Allow only output tools + output_tool_names = [t.name for t in model_request_parameters.output_tools] + if len(output_tool_names) == 1: + return ToolChoiceFunctionParam(type='function', name=output_tool_names[0]) + else: + return ToolChoiceAllowedParam( + type='allowed_tools', + mode='required' if not model_request_parameters.allow_text_output else 'auto', + tools=[{'type': 'function', 'name': n} for n in output_tool_names], + ) + return 'none' + + if user_tool_choice == 'auto': + return 'auto' + + if user_tool_choice == 'required': + return 'required' + + # Handle list of specific tool names + if isinstance(user_tool_choice, list): + # Validate tool names exist in function_tools + function_tool_names = {t.name for t in model_request_parameters.function_tools} + invalid_names = set(user_tool_choice) - function_tool_names + if invalid_names: + raise UserError( + f'Invalid tool names in tool_choice: {invalid_names}. ' + f'Available function tools: {function_tool_names or "none"}' + ) + + if len(user_tool_choice) == 1: + return ToolChoiceFunctionParam(type='function', name=user_tool_choice[0]) + else: + return ToolChoiceAllowedParam( + type='allowed_tools', + mode='required' if not model_request_parameters.allow_text_output else 'auto', + tools=[{'type': 'function', 'name': n} for n in user_tool_choice], + ) + + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output and profile.openai_supports_tool_choice_required: + return 'required' + return 'auto' + def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.ToolParam]: tools: list[responses.ToolParam] = [] has_image_generating_tool = False diff --git a/pydantic_ai_slim/pydantic_ai/settings.py b/pydantic_ai_slim/pydantic_ai/settings.py index 6941eb1ab3..5f0fde9312 100644 --- a/pydantic_ai_slim/pydantic_ai/settings.py +++ b/pydantic_ai_slim/pydantic_ai/settings.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Literal + from httpx import Timeout from typing_extensions import TypedDict @@ -88,6 +90,33 @@ class ModelSettings(TypedDict, total=False): * Anthropic """ + tool_choice: Literal['none', 'required', 'auto'] | list[str] | None + """Control which function tools the model can use. + + This setting only affects function tools registered on the agent, not output tools + used for structured output. + + * `None` (default): Automatically determined based on output configuration + * `'auto'`: Model decides whether to use function tools + * `'required'`: Model must use one of the available function tools + * `'none'`: Model cannot use function tools (output tools remain available if needed) + * `list[str]`: Model must use one of the specified function tools (validated against registered tools) + + If the agent has a structured output type that requires an output tool and `tool_choice='none'` + is set, the output tool will still be available and a warning will be logged. Consider using + native or prompted output modes if you need `tool_choice='none'` with structured output. 
+ + Supported by: + + * OpenAI + * Anthropic (note: `'required'` and specific tools not supported with thinking/extended thinking) + * Gemini + * Groq + * Mistral + * HuggingFace + * Bedrock (note: `'none'` not supported, will fall back to `'auto'` with a warning) + """ + seed: int """The random seed to use for the model, theoretically allowing for deterministic results. diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index f770b157af..ae96155baa 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -7742,3 +7742,141 @@ async def test_anthropic_cache_messages_real_api(allow_model_requests: None, ant assert usage2.cache_read_tokens > 0 assert usage2.cache_write_tokens > 0 assert usage2.output_tokens > 0 + + +# Tests for tool_choice ModelSettings + + +@pytest.mark.parametrize( + 'tool_choice,expected_type', + [ + pytest.param('none', 'none', id='none'), + pytest.param('auto', 'auto', id='auto'), + pytest.param('required', 'any', id='required-maps-to-any'), + ], +) +async def test_tool_choice_string_values(allow_model_requests: None, tool_choice: str, expected_type: str) -> None: + """Test that tool_choice string values are correctly mapped to Anthropic's format.""" + c = completion_message([BetaTextBlock(text='ok', type='text')], BetaUsage(input_tokens=5, output_tokens=10)) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent(m) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': tool_choice}) # type: ignore + + kwargs = mock_client.chat_completion_kwargs[0] # type: ignore + assert kwargs['tool_choice']['type'] == expected_type + + +async def test_tool_choice_specific_tool_single(allow_model_requests: None) -> None: + """Test tool_choice with a single specific tool name maps to Anthropic's 'tool' type.""" + c = completion_message([BetaTextBlock(text='ok', type='text')], BetaUsage(input_tokens=5, output_tokens=10)) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent(m) + + @agent.tool_plain + def tool_a(x: int) -> str: + return str(x) # pragma: no cover + + @agent.tool_plain + def tool_b(x: int) -> str: + return str(x) # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': ['tool_a']}) + + kwargs = mock_client.chat_completion_kwargs[0] # type: ignore + assert kwargs['tool_choice'] == {'type': 'tool', 'name': 'tool_a'} + + +async def test_tool_choice_multiple_tools_falls_back_to_any(allow_model_requests: None) -> None: + """Test tool_choice with multiple tools falls back to 'any' with warning (Anthropic limitation).""" + c = completion_message([BetaTextBlock(text='ok', type='text')], BetaUsage(input_tokens=5, output_tokens=10)) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent(m) + + @agent.tool_plain + def tool_a(x: int) -> str: + return str(x) # pragma: no cover + + @agent.tool_plain + def tool_b(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match='Anthropic only supports forcing a single tool'): + await agent.run('hello', model_settings={'tool_choice': ['tool_a', 'tool_b']}) + + kwargs = mock_client.chat_completion_kwargs[0] # type: ignore + assert 
kwargs['tool_choice']['type'] == 'any' + + +async def test_tool_choice_invalid_tool_name(allow_model_requests: None) -> None: + """Test that invalid tool names in tool_choice raise UserError.""" + c = completion_message([BetaTextBlock(text='ok', type='text')], BetaUsage(input_tokens=5, output_tokens=10)) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent(m) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.raises(UserError, match='Invalid tool names in tool_choice'): + await agent.run('hello', model_settings={'tool_choice': ['nonexistent_tool']}) + + +async def test_tool_choice_none_with_output_tools_warns(allow_model_requests: None) -> None: + """Test that tool_choice='none' with output tools emits a warning and preserves output tools.""" + + class Location(BaseModel): + city: str + country: str + + c = completion_message( + [BetaToolUseBlock(id='1', type='tool_use', name='final_result', input={'city': 'Paris', 'country': 'France'})], + BetaUsage(input_tokens=5, output_tokens=10), + ) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent(m, output_type=Location) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): + result = await agent.run('hello', model_settings={'tool_choice': 'none'}) + + assert result.output == Location(city='Paris', country='France') + kwargs = mock_client.chat_completion_kwargs[0] # type: ignore + # Output tool should be preserved (single output tool -> 'tool' type) + assert kwargs['tool_choice'] == {'type': 'tool', 'name': 'final_result'} + + +async def test_tool_choice_required_with_thinking_falls_back_to_auto(allow_model_requests: None) -> None: + """Test that tool_choice='required' with thinking mode falls back to 'auto' with warning.""" + c = completion_message([BetaTextBlock(text='ok', type='text')], BetaUsage(input_tokens=5, output_tokens=10)) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent(m) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match="tool_choice='required' is not supported with Anthropic thinking mode"): + await agent.run( + 'hello', + model_settings={ + 'tool_choice': 'required', + 'anthropic_thinking': {'type': 'enabled', 'budget_tokens': 1000}, + }, # type: ignore + ) + + kwargs = mock_client.chat_completion_kwargs[0] # type: ignore + assert kwargs['tool_choice']['type'] == 'auto' diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index dd3395750e..9e44bd29e0 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -3,7 +3,7 @@ import json import os from collections.abc import Sequence -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime, timezone from functools import cached_property from typing import Any, Literal, cast @@ -45,6 +45,7 @@ BuiltinToolCallEvent, # pyright: ignore[reportDeprecated] BuiltinToolResultEvent, # pyright: ignore[reportDeprecated] ) +from pydantic_ai.exceptions import UserError from pydantic_ai.output import NativeOutput, PromptedOutput from 
pydantic_ai.usage import RequestUsage, RunUsage @@ -97,6 +98,7 @@ class MockGroq: completions: MockChatCompletion | Sequence[MockChatCompletion] | None = None stream: Sequence[MockChatCompletionChunk] | Sequence[Sequence[MockChatCompletionChunk]] | None = None index: int = 0 + chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list) @cached_property def chat(self) -> Any: @@ -115,8 +117,9 @@ def create_mock_stream( return cast(AsyncGroq, cls(stream=stream)) async def chat_completions_create( - self, *_args: Any, stream: bool = False, **_kwargs: Any + self, *_args: Any, stream: bool = False, **kwargs: Any ) -> chat.ChatCompletion | MockAsyncStream[MockChatCompletionChunk]: + self.chat_completion_kwargs.append(kwargs) if stream: assert self.stream is not None, 'you can only used `stream=True` if `stream` is provided' if isinstance(self.stream[0], Sequence): @@ -137,6 +140,13 @@ async def chat_completions_create( return response +def get_mock_chat_completion_kwargs(groq_client: AsyncGroq) -> list[dict[str, Any]]: + if isinstance(groq_client, MockGroq): + return groq_client.chat_completion_kwargs + else: # pragma: no cover + raise RuntimeError('Not a MockGroq instance') + + def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage | None = None) -> chat.ChatCompletion: return chat.ChatCompletion( id='123', @@ -5623,3 +5633,64 @@ class CityLocation(BaseModel): ), ] ) + + +# Tests for tool_choice ModelSettings + + +@pytest.mark.parametrize( + 'tool_choice,expected', + [ + pytest.param('none', 'none', id='none'), + pytest.param('auto', 'auto', id='auto'), + pytest.param('required', 'required', id='required'), + ], +) +async def test_tool_choice_string_values(allow_model_requests: None, tool_choice: str, expected: str) -> None: + """Test that tool_choice string values are correctly passed to the API.""" + mock_client = MockGroq.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) + m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client)) + agent = Agent(m) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': tool_choice}) # type: ignore + + kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + assert kwargs['tool_choice'] == expected + + +async def test_tool_choice_specific_tool_single(allow_model_requests: None) -> None: + """Test tool_choice with a single specific tool name.""" + mock_client = MockGroq.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) + m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client)) + agent = Agent(m) + + @agent.tool_plain + def tool_a(x: int) -> str: + return str(x) # pragma: no cover + + @agent.tool_plain + def tool_b(x: int) -> str: + return str(x) # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': ['tool_a']}) + + kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + assert kwargs['tool_choice'] == {'type': 'function', 'function': {'name': 'tool_a'}} + + +async def test_tool_choice_invalid_tool_name(allow_model_requests: None) -> None: + """Test that invalid tool names in tool_choice raise UserError.""" + mock_client = MockGroq.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) + m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client)) + agent = Agent(m) + + @agent.tool_plain + def my_tool(x: 
int) -> str: + return str(x) # pragma: no cover + + with pytest.raises(UserError, match='Invalid tool names in tool_choice'): + await agent.run('hello', model_settings={'tool_choice': ['nonexistent_tool']}) diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 56d74ed619..7a66bfee43 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -49,7 +49,7 @@ UserPromptPart, VideoUrl, ) -from pydantic_ai.exceptions import ModelHTTPError +from pydantic_ai.exceptions import ModelHTTPError, UserError from pydantic_ai.models.huggingface import HuggingFaceModel from pydantic_ai.providers.huggingface import HuggingFaceProvider from pydantic_ai.result import RunUsage diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index b6c16b0f3e..fd8e9bc097 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -3201,3 +3201,134 @@ async def test_cache_point_filtering_responses_model(): assert len(msg['content']) == 2 assert msg['content'][0]['text'] == 'text before' # type: ignore[reportUnknownArgumentType] assert msg['content'][1]['text'] == 'text after' # type: ignore[reportUnknownArgumentType] + + +# Tests for tool_choice ModelSettings + + +@pytest.mark.parametrize( + 'tool_choice,expected', + [ + pytest.param('none', 'none', id='none'), + pytest.param('auto', 'auto', id='auto'), + pytest.param('required', 'required', id='required'), + ], +) +async def test_tool_choice_string_values(allow_model_requests: None, tool_choice: str, expected: str) -> None: + """Test that tool_choice string values are correctly passed to the API.""" + mock_client = MockOpenAI.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) + model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(model) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': tool_choice}) # type: ignore + + kwargs = get_mock_chat_completion_kwargs(mock_client) + assert kwargs[0]['tool_choice'] == expected + + +async def test_tool_choice_specific_tool_single(allow_model_requests: None) -> None: + """Test tool_choice with a single specific tool name.""" + mock_client = MockOpenAI.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) + model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(model) + + @agent.tool_plain + def tool_a(x: int) -> str: + return str(x) # pragma: no cover + + @agent.tool_plain + def tool_b(x: int) -> str: + return str(x) # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': ['tool_a']}) + + kwargs = get_mock_chat_completion_kwargs(mock_client) + assert kwargs[0]['tool_choice'] == {'type': 'function', 'function': {'name': 'tool_a'}} + + +async def test_tool_choice_specific_tools_multiple(allow_model_requests: None) -> None: + """Test tool_choice with multiple specific tool names.""" + mock_client = MockOpenAI.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) + model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(model) + + @agent.tool_plain + def tool_a(x: int) -> str: + return str(x) # pragma: no cover + + @agent.tool_plain + def tool_b(x: int) -> str: + return str(x) # pragma: no cover + + @agent.tool_plain + def tool_c(x: int) -> str: + return str(x) # pragma: no cover + + 
await agent.run('hello', model_settings={'tool_choice': ['tool_a', 'tool_b']}) + + kwargs = get_mock_chat_completion_kwargs(mock_client) + tool_choice = kwargs[0]['tool_choice'] + assert tool_choice['type'] == 'allowed_tools' + assert tool_choice['allowed_tools']['mode'] == 'auto' + assert len(tool_choice['allowed_tools']['tools']) == 2 + tool_names = {t['function']['name'] for t in tool_choice['allowed_tools']['tools']} + assert tool_names == {'tool_a', 'tool_b'} + + +async def test_tool_choice_invalid_tool_name(allow_model_requests: None) -> None: + """Test that invalid tool names in tool_choice raise UserError.""" + mock_client = MockOpenAI.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) + model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(model) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.raises(UserError, match='Invalid tool names in tool_choice'): + await agent.run('hello', model_settings={'tool_choice': ['nonexistent_tool']}) + + +async def test_tool_choice_none_with_output_tools_warns(allow_model_requests: None) -> None: + """Test that tool_choice='none' with output tools emits a warning and preserves output tools.""" + + class Location(BaseModel): + city: str + country: str + + mock_client = MockOpenAI.create_mock( + completion_message( + ChatCompletionMessage( + content=None, + role='assistant', + tool_calls=[ + ChatCompletionMessageFunctionToolCall( + id='1', + type='function', + function=Function( + name='final_result', + arguments='{"city": "Paris", "country": "France"}', + ), + ), + ], + ) + ) + ) + model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(model, output_type=Location) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): + result = await agent.run('hello', model_settings={'tool_choice': 'none'}) + + assert result.output == Location(city='Paris', country='France') + kwargs = get_mock_chat_completion_kwargs(mock_client) + # Output tool should be preserved (single output tool -> named tool choice) + assert kwargs[0]['tool_choice'] == {'type': 'function', 'function': {'name': 'final_result'}} From 05853479845c8bcb0e69e93354ba5fdb5b6571a2 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Mon, 1 Dec 2025 19:56:06 -0500 Subject: [PATCH 3/9] - centralize logic in utility and add tests for all providers - pending: centralize tests? 
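
For context, an end-to-end example of the user-facing `tool_choice` setting that
the centralized resolution serves. This sketch is illustrative only: the agent,
tool, and model name here are made up, but the `model_settings={'tool_choice': ...}`
shape matches the tests added in this series.

    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')

    @agent.tool_plain
    def get_time() -> str:
        """Return the current UTC time as an ISO 8601 string."""
        from datetime import datetime, timezone

        return datetime.now(timezone.utc).isoformat()

    # Restrict the model to the listed function tools; names are validated against
    # the registered function tools, and unknown names raise UserError.
    result = agent.run_sync(
        'What time is it?',
        model_settings={'tool_choice': ['get_time']},
    )
    print(result.output)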
---
 .../pydantic_ai/models/__init__.py            |  81 ++++++++
 .../pydantic_ai/models/anthropic.py           |  50 ++---
 .../pydantic_ai/models/bedrock.py             |  30 +--
 pydantic_ai_slim/pydantic_ai/models/google.py |  41 ++--
 pydantic_ai_slim/pydantic_ai/models/groq.py   |  42 ++---
 .../pydantic_ai/models/huggingface.py         |  42 ++---
 .../pydantic_ai/models/mistral.py             |  36 +---
 pydantic_ai_slim/pydantic_ai/models/openai.py |  43 ++---
 tests/models/test_anthropic.py                |  60 ++++++
 tests/models/test_bedrock.py                  | 178 +++++++++++++++++-
 tests/models/test_google.py                   | 136 +++++++++++++
 tests/models/test_groq.py                     |  62 +++++-
 tests/models/test_huggingface.py              | 159 +++++++++++++++-
 tests/models/test_mistral.py                  |  96 +++++++++-
 14 files changed, 841 insertions(+), 215 deletions(-)

diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py
index 3bc676559a..ee53778cab 100644
--- a/pydantic_ai_slim/pydantic_ai/models/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -349,6 +349,87 @@ def prompted_output_instructions(self) -> str | None:
     __repr__ = _utils.dataclasses_no_defaults_repr
 
 
+@dataclass
+class ResolvedToolChoice:
+    """Provider-agnostic resolved tool choice.
+
+    This is the result of validating and resolving the user's `tool_choice` setting.
+    Providers should map this to their API-specific format.
+    """
+
+    mode: Literal['none', 'auto', 'required', 'specific']
+    """The resolved tool choice mode."""
+
+    tool_names: list[str] | None = None
+    """For 'specific' mode, the list of tool names to force."""
+
+    output_tools_fallback: bool = False
+    """True if we need to fall back to output tools only (when 'none' was requested but output tools exist)."""
+
+
+# Warning message used when tool_choice='none' conflicts with output tools
+_TOOL_CHOICE_NONE_WITH_OUTPUT_TOOLS_WARNING = (
+    "tool_choice='none' is set but output tools are required for structured output. "
+    'The output tools will remain available. Consider using native or prompted output modes '
+    "if you need tool_choice='none' with structured output."
+)
+
+
+def resolve_tool_choice(
+    model_settings: ModelSettings | None,
+    model_request_parameters: ModelRequestParameters,
+    *,
+    stacklevel: int = 6,
+) -> ResolvedToolChoice | None:
+    """Resolve and validate tool_choice from model settings.
+
+    This centralizes the common logic for handling tool_choice across all providers:
+    - Validates tool names in list[str] against available function_tools
+    - Issues warnings for conflicting settings (tool_choice='none' with output tools)
+    - Returns a provider-agnostic ResolvedToolChoice for the provider to map to their API format
+
+    Args:
+        model_settings: The model settings containing tool_choice.
+        model_request_parameters: The request parameters containing tool definitions.
+        stacklevel: The stack level for warnings (default 6 works for most provider call stacks).
+
+    Returns:
+        ResolvedToolChoice if an explicit tool_choice was provided and validated,
+        None if tool_choice was not set (provider should use default behavior based on allow_text_output).
+
+    Raises:
+        UserError: If tool names in list[str] are invalid. 
+ """ + user_tool_choice = (model_settings or {}).get('tool_choice') + + if user_tool_choice is None: + return None + + if user_tool_choice == 'none': + if model_request_parameters.output_tools: + warnings.warn(_TOOL_CHOICE_NONE_WITH_OUTPUT_TOOLS_WARNING, UserWarning, stacklevel=stacklevel) + return ResolvedToolChoice(mode='none', output_tools_fallback=True) + return ResolvedToolChoice(mode='none') + + if user_tool_choice == 'auto': + return ResolvedToolChoice(mode='auto') + + if user_tool_choice == 'required': + return ResolvedToolChoice(mode='required') + + if isinstance(user_tool_choice, list): + function_tool_names = {t.name for t in model_request_parameters.function_tools} + invalid_names = set(user_tool_choice) - function_tool_names + if invalid_names: + raise UserError( + f'Invalid tool names in tool_choice: {invalid_names}. ' + f'Available function tools: {function_tool_names or "none"}' + ) + return ResolvedToolChoice(mode='specific', tool_names=list(user_tool_choice)) + + return None + + class Model(ABC): """Abstract class for a model.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 9929feef1b..9ffcff8fcc 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -43,7 +43,15 @@ from ..providers.anthropic import AsyncAnthropicClient from ..settings import ModelSettings, merge_model_settings from ..tools import ToolDefinition -from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent +from . import ( + Model, + ModelRequestParameters, + StreamedResponse, + check_allow_model_requests, + download_item, + get_user_agent, + resolve_tool_choice, +) _FINISH_REASON_MAP: dict[BetaStopReason, FinishReason] = { 'end_turn': 'stop', @@ -642,28 +650,18 @@ def _infer_tool_choice( if not tools: return None - user_tool_choice = model_settings.get('tool_choice') thinking_enabled = model_settings.get('anthropic_thinking') is not None tool_choice: BetaToolChoiceParam - # Handle explicit user-provided tool_choice - if user_tool_choice is not None: - if user_tool_choice == 'none': - # If output tools exist, we can't truly disable all tools - if model_request_parameters.output_tools: - warnings.warn( - "tool_choice='none' is set but output tools are required for structured output. " - 'The output tools will remain available. Consider using native or prompted output modes ' - "if you need tool_choice='none' with structured output.", - UserWarning, - stacklevel=6, - ) - # Allow only output tools (Anthropic only supports one tool at a time) + resolved = resolve_tool_choice(model_settings, model_request_parameters) + + if resolved is not None: + if resolved.mode == 'none': + if resolved.output_tools_fallback: output_tool_names = [t.name for t in model_request_parameters.output_tools] if len(output_tool_names) == 1: tool_choice = {'type': 'tool', 'name': output_tool_names[0]} else: - # Multiple output tools - fall back to 'auto' and warn warnings.warn( 'Anthropic only supports forcing a single tool. ' "Falling back to 'auto' for multiple output tools.", @@ -674,10 +672,10 @@ def _infer_tool_choice( else: tool_choice = {'type': 'none'} - elif user_tool_choice == 'auto': + elif resolved.mode == 'auto': tool_choice = {'type': 'auto'} - elif user_tool_choice == 'required': + elif resolved.mode == 'required': if thinking_enabled: warnings.warn( "tool_choice='required' is not supported with Anthropic thinking mode. 
Falling back to 'auto'.", @@ -688,16 +686,7 @@ def _infer_tool_choice( else: tool_choice = {'type': 'any'} - elif isinstance(user_tool_choice, list): - # Validate tool names exist in function_tools - function_tool_names = {t.name for t in model_request_parameters.function_tools} - invalid_names = set(user_tool_choice) - function_tool_names - if invalid_names: - raise UserError( - f'Invalid tool names in tool_choice: {invalid_names}. ' - f'Available function tools: {function_tool_names or "none"}' - ) - + elif resolved.mode == 'specific' and resolved.tool_names: if thinking_enabled: warnings.warn( "Forcing specific tools is not supported with Anthropic thinking mode. Falling back to 'auto'.", @@ -705,10 +694,9 @@ def _infer_tool_choice( stacklevel=6, ) tool_choice = {'type': 'auto'} - elif len(user_tool_choice) == 1: - tool_choice = {'type': 'tool', 'name': user_tool_choice[0]} + elif len(resolved.tool_names) == 1: + tool_choice = {'type': 'tool', 'name': resolved.tool_names[0]} else: - # Anthropic only supports one tool at a time warnings.warn( 'Anthropic only supports forcing a single tool. ' "Falling back to 'any' (required) for multiple tools.", diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 8927d4be43..9ae625c700 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -42,7 +42,7 @@ ) from pydantic_ai._run_context import RunContext from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UserError -from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item +from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item, resolve_tool_choice from pydantic_ai.providers import Provider, infer_provider from pydantic_ai.providers.bedrock import BedrockModelProfile from pydantic_ai.settings import ModelSettings @@ -495,13 +495,11 @@ def _map_tool_config( if not tools: return None - user_tool_choice = model_settings.get('tool_choice') if model_settings else None + resolved = resolve_tool_choice(model_settings, model_request_parameters) tool_choice: ToolChoiceTypeDef - # Handle explicit user-provided tool_choice - if user_tool_choice is not None: - if user_tool_choice == 'none': - # Bedrock doesn't support 'none', fall back to 'auto' with warning + if resolved is not None: + if resolved.mode == 'none': warnings.warn( "Bedrock does not support tool_choice='none'. Falling back to 'auto'.", UserWarning, @@ -509,26 +507,16 @@ def _map_tool_config( ) tool_choice = {'auto': {}} - elif user_tool_choice == 'auto': + elif resolved.mode == 'auto': tool_choice = {'auto': {}} - elif user_tool_choice == 'required': + elif resolved.mode == 'required': tool_choice = {'any': {}} - elif isinstance(user_tool_choice, list): - # Validate tool names exist in function_tools - function_tool_names = {t.name for t in model_request_parameters.function_tools} - invalid_names = set(user_tool_choice) - function_tool_names - if invalid_names: - raise UserError( - f'Invalid tool names in tool_choice: {invalid_names}. 
' - f'Available function tools: {function_tool_names or "none"}' - ) - - if len(user_tool_choice) == 1: - tool_choice = {'tool': {'name': user_tool_choice[0]}} + elif resolved.mode == 'specific' and resolved.tool_names: + if len(resolved.tool_names) == 1: + tool_choice = {'tool': {'name': resolved.tool_names[0]}} else: - # Bedrock only supports single tool choice, fall back to any warnings.warn( 'Bedrock only supports forcing a single tool. ' "Falling back to 'any' (required) for multiple tools.", diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 5558dd37b9..4903b8ed98 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -1,7 +1,6 @@ from __future__ import annotations as _annotations import base64 -import warnings from collections.abc import AsyncIterator, Awaitable from contextlib import asynccontextmanager from dataclasses import dataclass, field, replace @@ -50,6 +49,7 @@ check_allow_model_requests, download_item, get_user_agent, + resolve_tool_choice, ) try: @@ -364,7 +364,7 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T ) return tools or None - def _get_tool_config( # noqa: C901 + def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: list[ToolDict] | None, @@ -373,21 +373,11 @@ def _get_tool_config( # noqa: C901 if not tools: return None - user_tool_choice = model_settings.get('tool_choice') - - # Handle explicit user-provided tool_choice - if user_tool_choice is not None: - if user_tool_choice == 'none': - # If output tools exist, we can't truly disable all tools - if model_request_parameters.output_tools: - warnings.warn( - "tool_choice='none' is set but output tools are required for structured output. " - 'The output tools will remain available. Consider using native or prompted output modes ' - "if you need tool_choice='none' with structured output.", - UserWarning, - stacklevel=6, - ) - # Allow only output tools + resolved = resolve_tool_choice(model_settings, model_request_parameters) + + if resolved is not None: + if resolved.mode == 'none': + if resolved.output_tools_fallback: output_tool_names = [t.name for t in model_request_parameters.output_tools] return ToolConfigDict( function_calling_config=FunctionCallingConfigDict( @@ -399,13 +389,12 @@ def _get_tool_config( # noqa: C901 function_calling_config=FunctionCallingConfigDict(mode=FunctionCallingConfigMode.NONE) ) - if user_tool_choice == 'auto': + if resolved.mode == 'auto': return ToolConfigDict( function_calling_config=FunctionCallingConfigDict(mode=FunctionCallingConfigMode.AUTO) ) - if user_tool_choice == 'required': - # Get all tool names + if resolved.mode == 'required': names: list[str] = [] for tool in tools: for function_declaration in tool.get('function_declarations') or []: @@ -418,19 +407,11 @@ def _get_tool_config( # noqa: C901 ) ) - if isinstance(user_tool_choice, list): - # Validate tool names exist in function_tools - function_tool_names = {t.name for t in model_request_parameters.function_tools} - invalid_names = set(user_tool_choice) - function_tool_names - if invalid_names: - raise UserError( - f'Invalid tool names in tool_choice: {invalid_names}. 
' - f'Available function tools: {function_tool_names or "none"}' - ) + if resolved.mode == 'specific' and resolved.tool_names: return ToolConfigDict( function_calling_config=FunctionCallingConfigDict( mode=FunctionCallingConfigMode.ANY, - allowed_function_names=list(user_tool_choice), + allowed_function_names=resolved.tool_names, ) ) diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 9138c8f789..9fc3d69d9d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -50,6 +50,7 @@ StreamedResponse, check_allow_model_requests, get_user_agent, + resolve_tool_choice, ) try: @@ -380,24 +381,14 @@ def _get_tool_choice( model_settings: GroqModelSettings, model_request_parameters: ModelRequestParameters, ) -> ChatCompletionToolChoiceOptionParam | None: - user_tool_choice = model_settings.get('tool_choice') - if not tools: return None - # Handle explicit user-provided tool_choice - if user_tool_choice is not None: - if user_tool_choice == 'none': - # If output tools exist, we can't truly disable all tools - if model_request_parameters.output_tools: - warnings.warn( - "tool_choice='none' is set but output tools are required for structured output. " - 'The output tools will remain available. Consider using native or prompted output modes ' - "if you need tool_choice='none' with structured output.", - UserWarning, - stacklevel=6, - ) - # Allow only output tools (force first one since Groq only supports single tool) + resolved = resolve_tool_choice(model_settings, model_request_parameters) + + if resolved is not None: + if resolved.mode == 'none': + if resolved.output_tools_fallback: output_tool_names = [t.name for t in model_request_parameters.output_tools] return ChatCompletionNamedToolChoiceParam( type='function', @@ -405,30 +396,19 @@ def _get_tool_choice( ) return 'none' - if user_tool_choice == 'auto': + if resolved.mode == 'auto': return 'auto' - if user_tool_choice == 'required': + if resolved.mode == 'required': return 'required' - # Handle list of specific tool names - if isinstance(user_tool_choice, list): - # Validate tool names exist in function_tools - function_tool_names = {t.name for t in model_request_parameters.function_tools} - invalid_names = set(user_tool_choice) - function_tool_names - if invalid_names: - raise UserError( - f'Invalid tool names in tool_choice: {invalid_names}. ' - f'Available function tools: {function_tool_names or "none"}' - ) - - if len(user_tool_choice) == 1: + if resolved.mode == 'specific' and resolved.tool_names: + if len(resolved.tool_names) == 1: return ChatCompletionNamedToolChoiceParam( type='function', - function={'name': user_tool_choice[0]}, + function={'name': resolved.tool_names[0]}, ) else: - # Groq only supports single tool choice, fall back to required warnings.warn( "Groq only supports forcing a single tool. 
Falling back to 'required' for multiple tools.", UserWarning, diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 375c3e95d5..7656ff59dc 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -47,6 +47,7 @@ ModelRequestParameters, StreamedResponse, check_allow_model_requests, + resolve_tool_choice, ) try: @@ -325,53 +326,32 @@ def _get_tool_choice( model_settings: HuggingFaceModelSettings, model_request_parameters: ModelRequestParameters, ) -> Literal['none', 'required', 'auto'] | ChatCompletionInputToolChoiceClass | None: - user_tool_choice = model_settings.get('tool_choice') - if not tools: return None - # Handle explicit user-provided tool_choice - if user_tool_choice is not None: - if user_tool_choice == 'none': - # If output tools exist, we can't truly disable all tools - if model_request_parameters.output_tools: - warnings.warn( - "tool_choice='none' is set but output tools are required for structured output. " - 'The output tools will remain available. Consider using native or prompted output modes ' - "if you need tool_choice='none' with structured output.", - UserWarning, - stacklevel=6, - ) - # Allow only output tools (force first one) + resolved = resolve_tool_choice(model_settings, model_request_parameters) + + if resolved is not None: + if resolved.mode == 'none': + if resolved.output_tools_fallback: output_tool_names = [t.name for t in model_request_parameters.output_tools] return ChatCompletionInputToolChoiceClass( function=ChatCompletionInputFunctionName(name=output_tool_names[0]) # pyright: ignore[reportCallIssue] ) return 'none' - if user_tool_choice == 'auto': + if resolved.mode == 'auto': return 'auto' - if user_tool_choice == 'required': + if resolved.mode == 'required': return 'required' - # Handle list of specific tool names - if isinstance(user_tool_choice, list): - # Validate tool names exist in function_tools - function_tool_names = {t.name for t in model_request_parameters.function_tools} - invalid_names = set(user_tool_choice) - function_tool_names - if invalid_names: - raise UserError( - f'Invalid tool names in tool_choice: {invalid_names}. ' - f'Available function tools: {function_tool_names or "none"}' - ) - - if len(user_tool_choice) == 1: + if resolved.mode == 'specific' and resolved.tool_names: + if len(resolved.tool_names) == 1: return ChatCompletionInputToolChoiceClass( - function=ChatCompletionInputFunctionName(name=user_tool_choice[0]) # pyright: ignore[reportCallIssue] + function=ChatCompletionInputFunctionName(name=resolved.tool_names[0]) # pyright: ignore[reportCallIssue] ) else: - # HuggingFace only supports single tool choice, fall back to required warnings.warn( 'HuggingFace only supports forcing a single tool. 
' "Falling back to 'required' for multiple tools.", diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index ae87c8fc78..dfdd0528c3 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -48,6 +48,7 @@ StreamedResponse, check_allow_model_requests, get_user_agent, + resolve_tool_choice, ) try: @@ -328,40 +329,21 @@ def _get_tool_choice( if not model_request_parameters.function_tools and not model_request_parameters.output_tools: return None - user_tool_choice = model_settings.get('tool_choice') - - # Handle explicit user-provided tool_choice - if user_tool_choice is not None: - if user_tool_choice == 'none': - # If output tools exist, we can't truly disable all tools - if model_request_parameters.output_tools: - warnings.warn( - "tool_choice='none' is set but output tools are required for structured output. " - 'The output tools will remain available. Consider using native or prompted output modes ' - "if you need tool_choice='none' with structured output.", - UserWarning, - stacklevel=6, - ) + resolved = resolve_tool_choice(model_settings, model_request_parameters) + + if resolved is not None: + if resolved.mode == 'none': + if resolved.output_tools_fallback: return 'required' return 'none' - if user_tool_choice == 'auto': + if resolved.mode == 'auto': return 'auto' - if user_tool_choice == 'required': + if resolved.mode == 'required': return 'required' - # Handle list of specific tool names - if isinstance(user_tool_choice, list): - # Validate tool names exist in function_tools - function_tool_names = {t.name for t in model_request_parameters.function_tools} - invalid_names = set(user_tool_choice) - function_tool_names - if invalid_names: - raise UserError( - f'Invalid tool names in tool_choice: {invalid_names}. ' - f'Available function tools: {function_tool_names or "none"}' - ) - # Mistral doesn't support forcing specific tools, fall back to required + if resolved.mode == 'specific': warnings.warn( "Mistral does not support forcing specific tools. Falling back to 'required'.", UserWarning, diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index c4a8da4db3..f1428c4f84 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -55,7 +55,6 @@ from . import ( Model, ModelRequestParameters, - ResolvedToolChoice, StreamedResponse, check_allow_model_requests, download_item, @@ -705,24 +704,14 @@ def _get_tool_choice( model_settings: OpenAIChatModelSettings, model_request_parameters: ModelRequestParameters, ) -> ChatCompletionToolChoiceOptionParam | None: - user_tool_choice = model_settings.get('tool_choice') - if not tools: return None - # Handle explicit user-provided tool_choice - if user_tool_choice is not None: - if user_tool_choice == 'none': - # If output tools exist, we can't truly disable all tools - if model_request_parameters.output_tools: - warnings.warn( - "tool_choice='none' is set but output tools are required for structured output. " - 'The output tools will remain available. 
Consider using native or prompted output modes ' - "if you need tool_choice='none' with structured output.", - UserWarning, - stacklevel=6, - ) - # Allow only output tools + resolved = resolve_tool_choice(model_settings, model_request_parameters) + + if resolved is not None: + if resolved.mode == 'none': + if resolved.output_tools_fallback: output_tool_names = [t.name for t in model_request_parameters.output_tools] if len(output_tool_names) == 1: return ChatCompletionNamedToolChoiceParam( @@ -739,34 +728,24 @@ def _get_tool_choice( ) return 'none' - if user_tool_choice == 'auto': + if resolved.mode == 'auto': return 'auto' - if user_tool_choice == 'required': + if resolved.mode == 'required': return 'required' - # Handle list of specific tool names - if isinstance(user_tool_choice, list): - # Validate tool names exist in function_tools - function_tool_names = {t.name for t in model_request_parameters.function_tools} - invalid_names = set(user_tool_choice) - function_tool_names - if invalid_names: - raise UserError( - f'Invalid tool names in tool_choice: {invalid_names}. ' - f'Available function tools: {function_tool_names or "none"}' - ) - - if len(user_tool_choice) == 1: + if resolved.mode == 'specific' and resolved.tool_names: + if len(resolved.tool_names) == 1: return ChatCompletionNamedToolChoiceParam( type='function', - function={'name': user_tool_choice[0]}, + function={'name': resolved.tool_names[0]}, ) else: return ChatCompletionAllowedToolChoiceParam( type='allowed_tools', allowed_tools=ChatCompletionAllowedToolsParam( mode='required' if not model_request_parameters.allow_text_output else 'auto', - tools=[{'type': 'function', 'function': {'name': n}} for n in user_tool_choice], + tools=[{'type': 'function', 'function': {'name': n}} for n in resolved.tool_names], ), ) diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index ae96155baa..e92bb7694a 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -7880,3 +7880,63 @@ def my_tool(x: int) -> str: kwargs = mock_client.chat_completion_kwargs[0] # type: ignore assert kwargs['tool_choice']['type'] == 'auto' + + +async def test_tool_choice_specific_with_thinking_falls_back_to_auto(allow_model_requests: None) -> None: + """Test that specific tool_choice with thinking mode falls back to 'auto' with warning.""" + c = completion_message([BetaTextBlock(text='ok', type='text')], BetaUsage(input_tokens=5, output_tokens=10)) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent(m) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match='Forcing specific tools is not supported with Anthropic thinking mode'): + await agent.run( + 'hello', + model_settings={ + 'tool_choice': ['my_tool'], + 'anthropic_thinking': {'type': 'enabled', 'budget_tokens': 1000}, + }, # type: ignore + ) + + kwargs = mock_client.chat_completion_kwargs[0] # type: ignore + assert kwargs['tool_choice']['type'] == 'auto' + + +async def test_tool_choice_none_with_multiple_output_tools_falls_back_to_auto(allow_model_requests: None) -> None: + """Test that tool_choice='none' with multiple output tools falls back to 'auto' with warning.""" + import warnings as warn_module + + class LocationA(BaseModel): + city: str + + class LocationB(BaseModel): + country: str + + c = completion_message( + [BetaToolUseBlock(id='1', type='tool_use', 
name='final_result_LocationA', input={'city': 'Paris'})], + BetaUsage(input_tokens=5, output_tokens=10), + ) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent: Agent[None, LocationA | LocationB] = Agent(m, output_type=[LocationA, LocationB]) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + # Expect two warnings: one from resolve_tool_choice about output tools, one from Anthropic about multiple tools + with warn_module.catch_warnings(record=True) as w: + warn_module.simplefilter('always') + await agent.run('hello', model_settings={'tool_choice': 'none'}) + + # Check that we got the Anthropic-specific warning about multiple tools + warning_messages = [str(warning.message) for warning in w] + assert any("tool_choice='none' is set but output tools are required" in msg for msg in warning_messages) + assert any('Anthropic only supports forcing a single tool' in msg for msg in warning_messages) + + kwargs = mock_client.chat_completion_kwargs[0] # type: ignore + assert kwargs['tool_choice']['type'] == 'auto' diff --git a/tests/models/test_bedrock.py b/tests/models/test_bedrock.py index f13aaff4fb..ad91381551 100644 --- a/tests/models/test_bedrock.py +++ b/tests/models/test_bedrock.py @@ -35,7 +35,7 @@ VideoUrl, ) from pydantic_ai.agent import Agent -from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, ModelRetry, UsageLimitExceeded +from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, ModelRetry, UsageLimitExceeded, UserError from pydantic_ai.messages import AgentStreamEvent from pydantic_ai.models import ModelRequestParameters from pydantic_ai.models.bedrock import BedrockConverseModel, BedrockModelSettings @@ -1324,7 +1324,7 @@ async def test_bedrock_group_consecutive_tool_return_parts(bedrock_provider: Bed ] # Call the mapping function directly - _, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # type: ignore[reportPrivateUsage] + _, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # pyright: ignore[reportPrivateUsage] assert bedrock_messages == snapshot( [ @@ -1445,7 +1445,7 @@ async def test_bedrock_mistral_tool_result_format(bedrock_provider: BedrockProvi # Models other than Mistral support toolResult.content with text, not json model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) # Call the mapping function directly - _, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # type: ignore[reportPrivateUsage] + _, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # pyright: ignore[reportPrivateUsage,reportArgumentType] assert bedrock_messages == snapshot( [ @@ -1461,7 +1461,7 @@ async def test_bedrock_mistral_tool_result_format(bedrock_provider: BedrockProvi # Mistral requires toolResult.content to hold json, not text model = BedrockConverseModel('mistral.mistral-7b-instruct-v0:2', provider=bedrock_provider) # Call the mapping function directly - _, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # type: ignore[reportPrivateUsage] + _, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # pyright: ignore[reportPrivateUsage,reportArgumentType] assert bedrock_messages == snapshot( [ @@ -1485,7 +1485,7 @@ async def test_bedrock_no_tool_choice(bedrock_provider: BedrockProvider): # Amazon Nova supports tool_choice model = 
BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) - tool_config = model._map_tool_config(mrp) # type: ignore[reportPrivateUsage] + tool_config = model._map_tool_config(mrp, None) # pyright: ignore[reportPrivateUsage] assert tool_config == snapshot( { @@ -1506,7 +1506,7 @@ async def test_bedrock_no_tool_choice(bedrock_provider: BedrockProvider): # Anthropic supports tool_choice model = BedrockConverseModel('us.anthropic.claude-3-7-sonnet-20250219-v1:0', provider=bedrock_provider) - tool_config = model._map_tool_config(mrp) # type: ignore[reportPrivateUsage] + tool_config = model._map_tool_config(mrp, None) # pyright: ignore[reportPrivateUsage] assert tool_config == snapshot( { @@ -1527,7 +1527,7 @@ async def test_bedrock_no_tool_choice(bedrock_provider: BedrockProvider): # Other models don't support tool_choice model = BedrockConverseModel('us.meta.llama4-maverick-17b-instruct-v1:0', provider=bedrock_provider) - tool_config = model._map_tool_config(mrp) # type: ignore[reportPrivateUsage] + tool_config = model._map_tool_config(mrp, None) # pyright: ignore[reportPrivateUsage] assert tool_config == snapshot( { @@ -1624,3 +1624,167 @@ async def test_cache_point_filtering(): # CachePoint should be filtered out, message should still be valid assert len(messages) == 1 assert messages[0]['role'] == 'user' + + +# tool_choice tests + + +@pytest.mark.parametrize( + 'tool_choice,expected_tool_choice', + [ + pytest.param('auto', {'auto': {}}, id='auto'), + pytest.param('required', {'any': {}}, id='required-maps-to-any'), + ], +) +async def test_tool_choice_string_values( + bedrock_provider: BedrockProvider, tool_choice: str, expected_tool_choice: dict[str, Any] +) -> None: + """Test that tool_choice string values are correctly mapped.""" + my_tool = ToolDefinition( + name='my_tool', + description='Test tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) + + model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) + settings: BedrockModelSettings = {'tool_choice': tool_choice} # type: ignore[assignment] + tool_config = model._map_tool_config(mrp, settings) # pyright: ignore[reportPrivateUsage] + + assert tool_config is not None + assert tool_config.get('toolChoice') == expected_tool_choice + + +async def test_tool_choice_none_falls_back_to_auto(bedrock_provider: BedrockProvider) -> None: + """Test that tool_choice='none' falls back to 'auto' with warning since Bedrock doesn't support it.""" + my_tool = ToolDefinition( + name='my_tool', + description='Test tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) + + model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) + + settings: BedrockModelSettings = {'tool_choice': 'none'} + with pytest.warns(UserWarning, match="Bedrock does not support tool_choice='none'. 
Falling back to 'auto'"): + tool_config = model._map_tool_config(mrp, settings) # pyright: ignore[reportPrivateUsage] + + assert tool_config == snapshot( + { + 'tools': [ + { + 'toolSpec': { + 'name': 'my_tool', + 'description': 'Test tool', + 'inputSchema': {'json': {'type': 'object', 'properties': {}}}, + } + } + ], + 'toolChoice': {'auto': {}}, + } + ) + + +async def test_tool_choice_specific_tool_single(bedrock_provider: BedrockProvider) -> None: + """Test tool_choice with a single specific tool name.""" + tool_a = ToolDefinition( + name='tool_a', + description='Test tool A', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + tool_b = ToolDefinition( + name='tool_b', + description='Test tool B', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters( + output_mode='tool', function_tools=[tool_a, tool_b], allow_text_output=True, output_tools=[] + ) + + model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) + settings: BedrockModelSettings = {'tool_choice': ['tool_a']} + tool_config = model._map_tool_config(mrp, settings) # pyright: ignore[reportPrivateUsage] + + assert tool_config == snapshot( + { + 'tools': [ + { + 'toolSpec': { + 'name': 'tool_a', + 'description': 'Test tool A', + 'inputSchema': {'json': {'type': 'object', 'properties': {}}}, + } + }, + { + 'toolSpec': { + 'name': 'tool_b', + 'description': 'Test tool B', + 'inputSchema': {'json': {'type': 'object', 'properties': {}}}, + } + }, + ], + 'toolChoice': {'tool': {'name': 'tool_a'}}, + } + ) + + +async def test_tool_choice_multiple_tools_falls_back_to_any(bedrock_provider: BedrockProvider) -> None: + """Test that multiple tools in tool_choice falls back to 'any' with warning.""" + tool_a = ToolDefinition( + name='tool_a', + description='Test tool A', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + tool_b = ToolDefinition( + name='tool_b', + description='Test tool B', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters( + output_mode='tool', function_tools=[tool_a, tool_b], allow_text_output=True, output_tools=[] + ) + + model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) + settings: BedrockModelSettings = {'tool_choice': ['tool_a', 'tool_b']} + + with pytest.warns(UserWarning, match='Bedrock only supports forcing a single tool'): + tool_config = model._map_tool_config(mrp, settings) # pyright: ignore[reportPrivateUsage] + + assert tool_config == snapshot( + { + 'tools': [ + { + 'toolSpec': { + 'name': 'tool_a', + 'description': 'Test tool A', + 'inputSchema': {'json': {'type': 'object', 'properties': {}}}, + } + }, + { + 'toolSpec': { + 'name': 'tool_b', + 'description': 'Test tool B', + 'inputSchema': {'json': {'type': 'object', 'properties': {}}}, + } + }, + ], + 'toolChoice': {'any': {}}, + } + ) + + +async def test_tool_choice_invalid_tool_name(bedrock_provider: BedrockProvider) -> None: + """Test that invalid tool names raise UserError.""" + my_tool = ToolDefinition( + name='my_tool', + description='Test tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) + + model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) + settings: BedrockModelSettings = {'tool_choice': ['nonexistent_tool']} + + with pytest.raises(UserError, match='Invalid tool names in tool_choice'): + 
model._map_tool_config(mrp, settings) # pyright: ignore[reportPrivateUsage] diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 3ef8cd5dda..234be077b7 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -59,6 +59,7 @@ from pydantic_ai.models import ModelRequestParameters from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput from pydantic_ai.settings import ModelSettings +from pydantic_ai.tools import ToolDefinition from pydantic_ai.usage import RequestUsage, RunUsage, UsageLimits from ..conftest import IsBytes, IsDatetime, IsInstance, IsStr, try_import @@ -68,12 +69,15 @@ from google.genai import errors from google.genai.types import ( FinishReason as GoogleFinishReason, + FunctionCallingConfigMode, + FunctionDeclarationDict, GenerateContentResponse, GenerateContentResponseUsageMetadata, HarmBlockThreshold, HarmCategory, MediaModality, ModalityTokenCount, + ToolDict, ) from pydantic_ai.models.google import ( @@ -4425,3 +4429,135 @@ def test_google_missing_tool_call_thought_signature(): ], } ) + + +# tool_choice tests + + +@pytest.mark.parametrize( + 'tool_choice,expected_mode', + [ + pytest.param('none', FunctionCallingConfigMode.NONE, id='none'), + pytest.param('auto', FunctionCallingConfigMode.AUTO, id='auto'), + ], +) +def test_tool_choice_string_values(google_provider: GoogleProvider, tool_choice: str, expected_mode: str) -> None: + """Test that tool_choice string values are correctly mapped to Google's FunctionCallingConfigMode.""" + my_tool = ToolDefinition( + name='my_tool', + description='Test tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) + tools = [ToolDict(function_declarations=[FunctionDeclarationDict(name='my_tool', description='Test tool')])] + + model = GoogleModel('gemini-1.5-flash', provider=google_provider) + settings: GoogleModelSettings = {'tool_choice': tool_choice} # type: ignore[assignment] + result = model._get_tool_config(mrp, tools, settings) # pyright: ignore[reportPrivateUsage] + + assert result is not None + fcc = result.get('function_calling_config') + assert fcc is not None + assert fcc.get('mode') == expected_mode + + +def test_tool_choice_required_maps_to_any(google_provider: GoogleProvider) -> None: + """Test that 'required' maps to ANY mode with all tool names.""" + my_tool = ToolDefinition( + name='my_tool', + description='Test tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) + tools = [ToolDict(function_declarations=[FunctionDeclarationDict(name='my_tool', description='Test tool')])] + + model = GoogleModel('gemini-1.5-flash', provider=google_provider) + settings: GoogleModelSettings = {'tool_choice': 'required'} + result = model._get_tool_config(mrp, tools, settings) # pyright: ignore[reportPrivateUsage] + + assert result is not None + fcc = result.get('function_calling_config') + assert fcc is not None + assert fcc.get('mode') == FunctionCallingConfigMode.ANY + assert fcc.get('allowed_function_names') == ['my_tool'] + + +def test_tool_choice_specific_tool_single(google_provider: GoogleProvider) -> None: + """Test tool_choice with a single specific tool name.""" + tool_a = ToolDefinition( + name='tool_a', + description='Test tool A', + parameters_json_schema={'type': 'object', 'properties': 
{}}, + ) + tool_b = ToolDefinition( + name='tool_b', + description='Test tool B', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters( + output_mode='tool', function_tools=[tool_a, tool_b], allow_text_output=True, output_tools=[] + ) + tools = [ + ToolDict(function_declarations=[FunctionDeclarationDict(name='tool_a', description='Test tool A')]), + ToolDict(function_declarations=[FunctionDeclarationDict(name='tool_b', description='Test tool B')]), + ] + + model = GoogleModel('gemini-1.5-flash', provider=google_provider) + settings: GoogleModelSettings = {'tool_choice': ['tool_a']} + result = model._get_tool_config(mrp, tools, settings) # pyright: ignore[reportPrivateUsage] + + assert result is not None + fcc = result.get('function_calling_config') + assert fcc is not None + assert fcc.get('mode') == FunctionCallingConfigMode.ANY + assert fcc.get('allowed_function_names') == ['tool_a'] + + +def test_tool_choice_invalid_tool_name(google_provider: GoogleProvider) -> None: + """Test that invalid tool names raise UserError.""" + my_tool = ToolDefinition( + name='my_tool', + description='Test tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) + tools = [ToolDict(function_declarations=[FunctionDeclarationDict(name='my_tool', description='Test tool')])] + + model = GoogleModel('gemini-1.5-flash', provider=google_provider) + settings: GoogleModelSettings = {'tool_choice': ['nonexistent_tool']} + + with pytest.raises(UserError, match='Invalid tool names in tool_choice'): + model._get_tool_config(mrp, tools, settings) # pyright: ignore[reportPrivateUsage] + + +def test_tool_choice_none_with_output_tools_warns(google_provider: GoogleProvider) -> None: + """Test that tool_choice='none' with output tools warns and allows output tools.""" + func_tool = ToolDefinition( + name='func_tool', + description='Function tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + output_tool = ToolDefinition( + name='output_tool', + description='Output tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters( + output_mode='tool', function_tools=[func_tool], allow_text_output=False, output_tools=[output_tool] + ) + tools = [ + ToolDict(function_declarations=[FunctionDeclarationDict(name='func_tool', description='Function tool')]), + ToolDict(function_declarations=[FunctionDeclarationDict(name='output_tool', description='Output tool')]), + ] + + model = GoogleModel('gemini-1.5-flash', provider=google_provider) + settings: GoogleModelSettings = {'tool_choice': 'none'} + + with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): + result = model._get_tool_config(mrp, tools, settings) # pyright: ignore[reportPrivateUsage] + + assert result is not None + fcc = result.get('function_calling_config') + assert fcc is not None + assert fcc.get('mode') == FunctionCallingConfigMode.ANY + assert fcc.get('allowed_function_names') == ['output_tool'] diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index 9e44bd29e0..bedee6e88f 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -41,11 +41,11 @@ UserPromptPart, ) from pydantic_ai.builtin_tools import WebSearchTool +from pydantic_ai.exceptions import UserError from pydantic_ai.messages import ( BuiltinToolCallEvent, # pyright: ignore[reportDeprecated] 
BuiltinToolResultEvent, # pyright: ignore[reportDeprecated] ) -from pydantic_ai.exceptions import UserError from pydantic_ai.output import NativeOutput, PromptedOutput from pydantic_ai.usage import RequestUsage, RunUsage @@ -5694,3 +5694,63 @@ def my_tool(x: int) -> str: with pytest.raises(UserError, match='Invalid tool names in tool_choice'): await agent.run('hello', model_settings={'tool_choice': ['nonexistent_tool']}) + + +async def test_tool_choice_multiple_tools_falls_back_to_required(allow_model_requests: None) -> None: + """Test that multiple tools in tool_choice falls back to 'required' with warning.""" + mock_client = MockGroq.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) + m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client)) + agent = Agent(m) + + @agent.tool_plain + def tool_a(x: int) -> str: + return str(x) # pragma: no cover + + @agent.tool_plain + def tool_b(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match='Groq only supports forcing a single tool'): + await agent.run('hello', model_settings={'tool_choice': ['tool_a', 'tool_b']}) + + kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + assert kwargs['tool_choice'] == 'required' + + +async def test_tool_choice_none_with_output_tools(allow_model_requests: None) -> None: + """Test that tool_choice='none' with output tools warns and uses output tool.""" + + class MyOutput(BaseModel): + result: str + + # Tool call response that returns final_result tool + tool_call_response = completion_message( + ChatCompletionMessage( + content=None, + role='assistant', + tool_calls=[ + chat.ChatCompletionMessageToolCall( + id='call_1', + type='function', + function=chat.chat_completion_message_tool_call.Function( + name='final_result', arguments='{"result": "done"}' + ), + ) + ], + ) + ) + + mock_client = MockGroq.create_mock(tool_call_response) + m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client)) + agent: Agent[None, MyOutput] = Agent(m, output_type=MyOutput) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): + await agent.run('hello', model_settings={'tool_choice': 'none'}) + + kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + # When tool_choice='none' but output tools exist, it should force the output tool + assert kwargs['tool_choice'] == {'type': 'function', 'function': {'name': 'final_result'}} diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 7a66bfee43..ccf46bb340 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -12,7 +12,10 @@ import pytest from huggingface_hub import ( AsyncInferenceClient, + ChatCompletionInputFunctionName, ChatCompletionInputMessage, + ChatCompletionInputTool, + ChatCompletionInputToolChoiceClass, ChatCompletionOutput, ChatCompletionOutputComplete, ChatCompletionOutputFunctionDefinition, @@ -50,12 +53,13 @@ VideoUrl, ) from pydantic_ai.exceptions import ModelHTTPError, UserError -from pydantic_ai.models.huggingface import HuggingFaceModel +from pydantic_ai.models import ModelRequestParameters +from pydantic_ai.models.huggingface import HuggingFaceModel, HuggingFaceModelSettings from pydantic_ai.providers.huggingface import HuggingFaceProvider from pydantic_ai.result import RunUsage from pydantic_ai.run import AgentRunResult, 
AgentRunResultEvent from pydantic_ai.settings import ModelSettings -from pydantic_ai.tools import RunContext +from pydantic_ai.tools import RunContext, ToolDefinition from pydantic_ai.usage import RequestUsage from ..conftest import IsDatetime, IsInstance, IsNow, IsStr, raise_if_exception, try_import @@ -1026,3 +1030,154 @@ async def test_cache_point_filtering(): # CachePoint should be filtered out assert msg['role'] == 'user' assert len(msg['content']) == 1 # pyright: ignore[reportUnknownArgumentType] + + +# tool_choice tests + + +@pytest.mark.parametrize( + 'tool_choice,expected', + [ + pytest.param('none', 'none', id='none'), + pytest.param('auto', 'auto', id='auto'), + pytest.param('required', 'required', id='required'), + ], +) +def test_tool_choice_string_values(tool_choice: str, expected: str) -> None: + """Test that tool_choice string values are correctly passed through.""" + my_tool = ToolDefinition( + name='my_tool', + description='Test tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) + tools: list[ChatCompletionInputTool] = [ + ChatCompletionInputTool(type='function', function={'name': 'my_tool', 'description': 'Test tool'}) # pyright: ignore[reportCallIssue] + ] + + mock_client = MockHuggingFace.create_mock( + completion_message(ChatCompletionOutputMessage.parse_obj_as_instance({'content': 'ok', 'role': 'assistant'})) # type: ignore + ) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + settings: HuggingFaceModelSettings = {'tool_choice': tool_choice} # type: ignore[assignment] + result = model._get_tool_choice(tools, settings, mrp) # pyright: ignore[reportPrivateUsage] + + assert result == expected + + +def test_tool_choice_specific_tool_single() -> None: + """Test tool_choice with a single specific tool name.""" + tool_a = ToolDefinition( + name='tool_a', + description='Test tool A', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + tool_b = ToolDefinition( + name='tool_b', + description='Test tool B', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters( + output_mode='tool', function_tools=[tool_a, tool_b], allow_text_output=True, output_tools=[] + ) + tools: list[ChatCompletionInputTool] = [ + ChatCompletionInputTool(type='function', function={'name': 'tool_a', 'description': 'Test tool A'}), # pyright: ignore[reportCallIssue] + ChatCompletionInputTool(type='function', function={'name': 'tool_b', 'description': 'Test tool B'}), # pyright: ignore[reportCallIssue] + ] + + mock_client = MockHuggingFace.create_mock( + completion_message(ChatCompletionOutputMessage.parse_obj_as_instance({'content': 'ok', 'role': 'assistant'})) # type: ignore + ) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + settings: HuggingFaceModelSettings = {'tool_choice': ['tool_a']} + result = model._get_tool_choice(tools, settings, mrp) # pyright: ignore[reportPrivateUsage] + + assert isinstance(result, ChatCompletionInputToolChoiceClass) + assert result.function == ChatCompletionInputFunctionName(name='tool_a') # type: ignore[call-arg] + + +def test_tool_choice_multiple_tools_falls_back_to_required() -> None: + """Test that multiple tools in tool_choice falls back to 'required' with warning.""" + tool_a = ToolDefinition( + name='tool_a', + description='Test tool A', + 
parameters_json_schema={'type': 'object', 'properties': {}}, + ) + tool_b = ToolDefinition( + name='tool_b', + description='Test tool B', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters( + output_mode='tool', function_tools=[tool_a, tool_b], allow_text_output=True, output_tools=[] + ) + tools: list[ChatCompletionInputTool] = [ + ChatCompletionInputTool(type='function', function={'name': 'tool_a', 'description': 'Test tool A'}), # pyright: ignore[reportCallIssue] + ChatCompletionInputTool(type='function', function={'name': 'tool_b', 'description': 'Test tool B'}), # pyright: ignore[reportCallIssue] + ] + + mock_client = MockHuggingFace.create_mock( + completion_message(ChatCompletionOutputMessage.parse_obj_as_instance({'content': 'ok', 'role': 'assistant'})) # type: ignore + ) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + settings: HuggingFaceModelSettings = {'tool_choice': ['tool_a', 'tool_b']} + + with pytest.warns(UserWarning, match='HuggingFace only supports forcing a single tool'): + result = model._get_tool_choice(tools, settings, mrp) # pyright: ignore[reportPrivateUsage] + + assert result == 'required' + + +def test_tool_choice_invalid_tool_name() -> None: + """Test that invalid tool names raise UserError.""" + my_tool = ToolDefinition( + name='my_tool', + description='Test tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) + tools: list[ChatCompletionInputTool] = [ + ChatCompletionInputTool(type='function', function={'name': 'my_tool', 'description': 'Test tool'}) # pyright: ignore[reportCallIssue] + ] + + mock_client = MockHuggingFace.create_mock( + completion_message(ChatCompletionOutputMessage.parse_obj_as_instance({'content': 'ok', 'role': 'assistant'})) # type: ignore + ) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + settings: HuggingFaceModelSettings = {'tool_choice': ['nonexistent_tool']} + + with pytest.raises(UserError, match='Invalid tool names in tool_choice'): + model._get_tool_choice(tools, settings, mrp) # pyright: ignore[reportPrivateUsage] + + +def test_tool_choice_none_with_output_tools_warns() -> None: + """Test that tool_choice='none' with output tools warns and allows output tools.""" + func_tool = ToolDefinition( + name='func_tool', + description='Function tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + output_tool = ToolDefinition( + name='output_tool', + description='Output tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters( + output_mode='tool', function_tools=[func_tool], allow_text_output=False, output_tools=[output_tool] + ) + tools: list[ChatCompletionInputTool] = [ + ChatCompletionInputTool(type='function', function={'name': 'func_tool', 'description': 'Function tool'}), # pyright: ignore[reportCallIssue] + ChatCompletionInputTool(type='function', function={'name': 'output_tool', 'description': 'Output tool'}), # pyright: ignore[reportCallIssue] + ] + + mock_client = MockHuggingFace.create_mock( + completion_message(ChatCompletionOutputMessage.parse_obj_as_instance({'content': 'ok', 'role': 'assistant'})) # type: ignore + ) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + settings: HuggingFaceModelSettings = 
{'tool_choice': 'none'} + + with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): + result = model._get_tool_choice(tools, settings, mrp) # pyright: ignore[reportPrivateUsage] + + assert isinstance(result, ChatCompletionInputToolChoiceClass) + assert result.function == ChatCompletionInputFunctionName(name='output_tool') # type: ignore[call-arg] diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 4ae21ad221..1faa96f2a8 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -28,7 +28,9 @@ VideoUrl, ) from pydantic_ai.agent import Agent -from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, ModelRetry +from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, ModelRetry, UserError +from pydantic_ai.models import ModelRequestParameters +from pydantic_ai.tools import ToolDefinition from pydantic_ai.usage import RequestUsage from ..conftest import IsDatetime, IsNow, IsStr, raise_if_exception, try_import @@ -55,7 +57,7 @@ ) from mistralai.types.basemodel import Unset as MistralUnset - from pydantic_ai.models.mistral import MistralModel, MistralStreamedResponse + from pydantic_ai.models.mistral import MistralModel, MistralModelSettings, MistralStreamedResponse from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings from pydantic_ai.providers.mistral import MistralProvider from pydantic_ai.providers.openai import OpenAIProvider @@ -2361,3 +2363,93 @@ async def test_mistral_model_thinking_part_iter(allow_model_requests: None, mist ), ] ) + + +# tool_choice tests + + +@pytest.mark.parametrize( + 'tool_choice,expected', + [ + pytest.param('none', 'none', id='none'), + pytest.param('auto', 'auto', id='auto'), + pytest.param('required', 'required', id='required'), + ], +) +def test_tool_choice_string_values(tool_choice: str, expected: str) -> None: + """Test that tool_choice string values are correctly passed through.""" + my_tool = ToolDefinition( + name='my_tool', + description='Test tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) + + mock_client = MockMistralAI.create_mock(completion_message(MistralAssistantMessage(content='ok', role='assistant'))) + model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client)) + settings: MistralModelSettings = {'tool_choice': tool_choice} # type: ignore[assignment] + result = model._get_tool_choice(mrp, settings) # pyright: ignore[reportPrivateUsage] + + assert result == expected + + +def test_tool_choice_specific_tool_falls_back_to_required() -> None: + """Test that specific tool falls back to 'required' with warning since Mistral doesn't support it.""" + tool_a = ToolDefinition( + name='tool_a', + description='Test tool A', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[tool_a], allow_text_output=True, output_tools=[]) + + mock_client = MockMistralAI.create_mock(completion_message(MistralAssistantMessage(content='ok', role='assistant'))) + model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client)) + settings: MistralModelSettings = {'tool_choice': ['tool_a']} + + with pytest.warns(UserWarning, match="Mistral does not support forcing specific tools. 
Falling back to 'required'"): + result = model._get_tool_choice(mrp, settings) # pyright: ignore[reportPrivateUsage] + + assert result == 'required' + + +def test_tool_choice_invalid_tool_name() -> None: + """Test that invalid tool names raise UserError.""" + my_tool = ToolDefinition( + name='my_tool', + description='Test tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) + + mock_client = MockMistralAI.create_mock(completion_message(MistralAssistantMessage(content='ok', role='assistant'))) + model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client)) + settings: MistralModelSettings = {'tool_choice': ['nonexistent_tool']} + + with pytest.raises(UserError, match='Invalid tool names in tool_choice'): + model._get_tool_choice(mrp, settings) # pyright: ignore[reportPrivateUsage] + + +def test_tool_choice_none_with_output_tools_warns() -> None: + """Test that tool_choice='none' with output tools warns and returns 'required'.""" + func_tool = ToolDefinition( + name='func_tool', + description='Function tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + output_tool = ToolDefinition( + name='output_tool', + description='Output tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters( + output_mode='tool', function_tools=[func_tool], allow_text_output=False, output_tools=[output_tool] + ) + + mock_client = MockMistralAI.create_mock(completion_message(MistralAssistantMessage(content='ok', role='assistant'))) + model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client)) + settings: MistralModelSettings = {'tool_choice': 'none'} + + with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): + result = model._get_tool_choice(mrp, settings) # pyright: ignore[reportPrivateUsage] + + assert result == 'required' From 96681ac215719c6f15dcbaab64b3fbd78323d2a1 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Mon, 1 Dec 2025 22:22:48 -0500 Subject: [PATCH 4/9] coverage? --- .../pydantic_ai/models/__init__.py | 4 +- .../pydantic_ai/models/anthropic.py | 5 +- .../pydantic_ai/models/bedrock.py | 5 +- pydantic_ai_slim/pydantic_ai/models/google.py | 2 +- pydantic_ai_slim/pydantic_ai/models/groq.py | 2 +- .../pydantic_ai/models/huggingface.py | 2 +- .../pydantic_ai/models/mistral.py | 2 +- pydantic_ai_slim/pydantic_ai/models/openai.py | 50 ++-- tests/models/test_anthropic.py | 15 - tests/models/test_bedrock.py | 18 +- tests/models/test_google.py | 17 -- tests/models/test_groq.py | 15 - tests/models/test_huggingface.py | 24 +- tests/models/test_mistral.py | 19 +- tests/models/test_openai.py | 261 +++++++++++++++++- tests/models/test_resolve_tool_choice.py | 154 +++++++++++ 16 files changed, 433 insertions(+), 162 deletions(-) create mode 100644 tests/models/test_resolve_tool_choice.py diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index ee53778cab..b42ab1f7bf 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -418,6 +418,8 @@ def resolve_tool_choice( return ResolvedToolChoice(mode='required') if isinstance(user_tool_choice, list): + if not user_tool_choice: + raise UserError('tool_choice cannot be an empty list. 
Use None for default behavior.') function_tool_names = {t.name for t in model_request_parameters.function_tools} invalid_names = set(user_tool_choice) - function_tool_names if invalid_names: @@ -427,7 +429,7 @@ def resolve_tool_choice( ) return ResolvedToolChoice(mode='specific', tool_names=list(user_tool_choice)) - return None + return None # pragma: no cover class Model(ABC): diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 9ffcff8fcc..331a8917d2 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -686,7 +686,8 @@ def _infer_tool_choice( else: tool_choice = {'type': 'any'} - elif resolved.mode == 'specific' and resolved.tool_names: + elif resolved.mode == 'specific': + assert resolved.tool_names # Guaranteed non-empty by resolve_tool_choice() if thinking_enabled: warnings.warn( "Forcing specific tools is not supported with Anthropic thinking mode. Falling back to 'auto'.", @@ -705,7 +706,7 @@ def _infer_tool_choice( ) tool_choice = {'type': 'any'} else: - tool_choice = {'type': 'auto'} + assert_never(resolved.mode) else: # Default behavior: infer from allow_text_output diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 9ae625c700..3a1f6d043f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -513,7 +513,8 @@ def _map_tool_config( elif resolved.mode == 'required': tool_choice = {'any': {}} - elif resolved.mode == 'specific' and resolved.tool_names: + elif resolved.mode == 'specific': + assert resolved.tool_names # Guaranteed non-empty by resolve_tool_choice() if len(resolved.tool_names) == 1: tool_choice = {'tool': {'name': resolved.tool_names[0]}} else: @@ -525,7 +526,7 @@ def _map_tool_config( ) tool_choice = {'any': {}} else: - tool_choice = {'auto': {}} + assert_never(resolved.mode) else: # Default behavior: infer from allow_text_output if not model_request_parameters.allow_text_output: diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 4903b8ed98..c04dec6a58 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -407,7 +407,7 @@ def _get_tool_config( ) ) - if resolved.mode == 'specific' and resolved.tool_names: + if resolved.mode == 'specific' and resolved.tool_names: # pragma: no branch return ToolConfigDict( function_calling_config=FunctionCallingConfigDict( mode=FunctionCallingConfigMode.ANY, diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 9fc3d69d9d..d12f4eadf9 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -402,7 +402,7 @@ def _get_tool_choice( if resolved.mode == 'required': return 'required' - if resolved.mode == 'specific' and resolved.tool_names: + if resolved.mode == 'specific' and resolved.tool_names: # pragma: no branch if len(resolved.tool_names) == 1: return ChatCompletionNamedToolChoiceParam( type='function', diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 7656ff59dc..d3706267a1 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -346,7 +346,7 @@ def _get_tool_choice( if resolved.mode == 'required': return 'required' - if 
resolved.mode == 'specific' and resolved.tool_names: + if resolved.mode == 'specific' and resolved.tool_names: # pragma: no branch if len(resolved.tool_names) == 1: return ChatCompletionInputToolChoiceClass( function=ChatCompletionInputFunctionName(name=resolved.tool_names[0]) # pyright: ignore[reportCallIssue] diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index dfdd0528c3..08d4b3fc5f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -343,7 +343,7 @@ def _get_tool_choice( if resolved.mode == 'required': return 'required' - if resolved.mode == 'specific': + if resolved.mode == 'specific': # pragma: no branch warnings.warn( "Mistral does not support forcing specific tools. Falling back to 'required'.", UserWarning, diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index f1428c4f84..e1816e180e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -734,7 +734,7 @@ def _get_tool_choice( if resolved.mode == 'required': return 'required' - if resolved.mode == 'specific' and resolved.tool_names: + if resolved.mode == 'specific' and resolved.tool_names: # pragma: no branch if len(resolved.tool_names) == 1: return ChatCompletionNamedToolChoiceParam( type='function', @@ -1468,25 +1468,14 @@ def _get_responses_tool_choice( model_settings: OpenAIResponsesModelSettings, model_request_parameters: ModelRequestParameters, ) -> ResponsesToolChoice | None: - user_tool_choice = model_settings.get('tool_choice') - profile = OpenAIModelProfile.from_profile(self.profile) - if not tools: return None - # Handle explicit user-provided tool_choice - if user_tool_choice is not None: - if user_tool_choice == 'none': - # If output tools exist, we can't truly disable all tools - if model_request_parameters.output_tools: - warnings.warn( - "tool_choice='none' is set but output tools are required for structured output. " - 'The output tools will remain available. Consider using native or prompted output modes ' - "if you need tool_choice='none' with structured output.", - UserWarning, - stacklevel=6, - ) - # Allow only output tools + resolved = resolve_tool_choice(model_settings, model_request_parameters) + + if resolved is not None: + if resolved.mode == 'none': + if resolved.output_tools_fallback: output_tool_names = [t.name for t in model_request_parameters.output_tools] if len(output_tool_names) == 1: return ToolChoiceFunctionParam(type='function', name=output_tool_names[0]) @@ -1498,34 +1487,27 @@ def _get_responses_tool_choice( ) return 'none' - if user_tool_choice == 'auto': + if resolved.mode == 'auto': return 'auto' - if user_tool_choice == 'required': + if resolved.mode == 'required': return 'required' - # Handle list of specific tool names - if isinstance(user_tool_choice, list): - # Validate tool names exist in function_tools - function_tool_names = {t.name for t in model_request_parameters.function_tools} - invalid_names = set(user_tool_choice) - function_tool_names - if invalid_names: - raise UserError( - f'Invalid tool names in tool_choice: {invalid_names}. 
' - f'Available function tools: {function_tool_names or "none"}' - ) - - if len(user_tool_choice) == 1: - return ToolChoiceFunctionParam(type='function', name=user_tool_choice[0]) + if resolved.mode == 'specific' and resolved.tool_names: # pragma: no branch + if len(resolved.tool_names) == 1: + return ToolChoiceFunctionParam(type='function', name=resolved.tool_names[0]) else: return ToolChoiceAllowedParam( type='allowed_tools', mode='required' if not model_request_parameters.allow_text_output else 'auto', - tools=[{'type': 'function', 'name': n} for n in user_tool_choice], + tools=[{'type': 'function', 'name': n} for n in resolved.tool_names], ) # Default behavior: infer from allow_text_output - if not model_request_parameters.allow_text_output and profile.openai_supports_tool_choice_required: + if ( + not model_request_parameters.allow_text_output + and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required + ): return 'required' return 'auto' diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index e92bb7694a..548ce5d0cd 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -7815,21 +7815,6 @@ def tool_b(x: int) -> str: assert kwargs['tool_choice']['type'] == 'any' -async def test_tool_choice_invalid_tool_name(allow_model_requests: None) -> None: - """Test that invalid tool names in tool_choice raise UserError.""" - c = completion_message([BetaTextBlock(text='ok', type='text')], BetaUsage(input_tokens=5, output_tokens=10)) - mock_client = MockAnthropic.create_mock(c) - m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) - agent = Agent(m) - - @agent.tool_plain - def my_tool(x: int) -> str: - return str(x) # pragma: no cover - - with pytest.raises(UserError, match='Invalid tool names in tool_choice'): - await agent.run('hello', model_settings={'tool_choice': ['nonexistent_tool']}) - - async def test_tool_choice_none_with_output_tools_warns(allow_model_requests: None) -> None: """Test that tool_choice='none' with output tools emits a warning and preserves output tools.""" diff --git a/tests/models/test_bedrock.py b/tests/models/test_bedrock.py index ad91381551..af7083dd45 100644 --- a/tests/models/test_bedrock.py +++ b/tests/models/test_bedrock.py @@ -35,7 +35,7 @@ VideoUrl, ) from pydantic_ai.agent import Agent -from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, ModelRetry, UsageLimitExceeded, UserError +from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, ModelRetry, UsageLimitExceeded from pydantic_ai.messages import AgentStreamEvent from pydantic_ai.models import ModelRequestParameters from pydantic_ai.models.bedrock import BedrockConverseModel, BedrockModelSettings @@ -1772,19 +1772,3 @@ async def test_tool_choice_multiple_tools_falls_back_to_any(bedrock_provider: Be 'toolChoice': {'any': {}}, } ) - - -async def test_tool_choice_invalid_tool_name(bedrock_provider: BedrockProvider) -> None: - """Test that invalid tool names raise UserError.""" - my_tool = ToolDefinition( - name='my_tool', - description='Test tool', - parameters_json_schema={'type': 'object', 'properties': {}}, - ) - mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) - - model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) - settings: BedrockModelSettings = {'tool_choice': ['nonexistent_tool']} - - with pytest.raises(UserError, match='Invalid tool names in tool_choice'): 
- model._map_tool_config(mrp, settings) # pyright: ignore[reportPrivateUsage] diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 234be077b7..e0f9254fe7 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -4513,23 +4513,6 @@ def test_tool_choice_specific_tool_single(google_provider: GoogleProvider) -> No assert fcc.get('allowed_function_names') == ['tool_a'] -def test_tool_choice_invalid_tool_name(google_provider: GoogleProvider) -> None: - """Test that invalid tool names raise UserError.""" - my_tool = ToolDefinition( - name='my_tool', - description='Test tool', - parameters_json_schema={'type': 'object', 'properties': {}}, - ) - mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) - tools = [ToolDict(function_declarations=[FunctionDeclarationDict(name='my_tool', description='Test tool')])] - - model = GoogleModel('gemini-1.5-flash', provider=google_provider) - settings: GoogleModelSettings = {'tool_choice': ['nonexistent_tool']} - - with pytest.raises(UserError, match='Invalid tool names in tool_choice'): - model._get_tool_config(mrp, tools, settings) # pyright: ignore[reportPrivateUsage] - - def test_tool_choice_none_with_output_tools_warns(google_provider: GoogleProvider) -> None: """Test that tool_choice='none' with output tools warns and allows output tools.""" func_tool = ToolDefinition( diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index bedee6e88f..defd517886 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -41,7 +41,6 @@ UserPromptPart, ) from pydantic_ai.builtin_tools import WebSearchTool -from pydantic_ai.exceptions import UserError from pydantic_ai.messages import ( BuiltinToolCallEvent, # pyright: ignore[reportDeprecated] BuiltinToolResultEvent, # pyright: ignore[reportDeprecated] @@ -5682,20 +5681,6 @@ def tool_b(x: int) -> str: assert kwargs['tool_choice'] == {'type': 'function', 'function': {'name': 'tool_a'}} -async def test_tool_choice_invalid_tool_name(allow_model_requests: None) -> None: - """Test that invalid tool names in tool_choice raise UserError.""" - mock_client = MockGroq.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) - m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client)) - agent = Agent(m) - - @agent.tool_plain - def my_tool(x: int) -> str: - return str(x) # pragma: no cover - - with pytest.raises(UserError, match='Invalid tool names in tool_choice'): - await agent.run('hello', model_settings={'tool_choice': ['nonexistent_tool']}) - - async def test_tool_choice_multiple_tools_falls_back_to_required(allow_model_requests: None) -> None: """Test that multiple tools in tool_choice falls back to 'required' with warning.""" mock_client = MockGroq.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index ccf46bb340..1cca3f0ed6 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -52,7 +52,7 @@ UserPromptPart, VideoUrl, ) -from pydantic_ai.exceptions import ModelHTTPError, UserError +from pydantic_ai.exceptions import ModelHTTPError from pydantic_ai.models import ModelRequestParameters from pydantic_ai.models.huggingface import HuggingFaceModel, HuggingFaceModelSettings from pydantic_ai.providers.huggingface import HuggingFaceProvider @@ -1128,28 +1128,6 @@ def 
test_tool_choice_multiple_tools_falls_back_to_required() -> None: assert result == 'required' -def test_tool_choice_invalid_tool_name() -> None: - """Test that invalid tool names raise UserError.""" - my_tool = ToolDefinition( - name='my_tool', - description='Test tool', - parameters_json_schema={'type': 'object', 'properties': {}}, - ) - mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) - tools: list[ChatCompletionInputTool] = [ - ChatCompletionInputTool(type='function', function={'name': 'my_tool', 'description': 'Test tool'}) # pyright: ignore[reportCallIssue] - ] - - mock_client = MockHuggingFace.create_mock( - completion_message(ChatCompletionOutputMessage.parse_obj_as_instance({'content': 'ok', 'role': 'assistant'})) # type: ignore - ) - model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) - settings: HuggingFaceModelSettings = {'tool_choice': ['nonexistent_tool']} - - with pytest.raises(UserError, match='Invalid tool names in tool_choice'): - model._get_tool_choice(tools, settings, mrp) # pyright: ignore[reportPrivateUsage] - - def test_tool_choice_none_with_output_tools_warns() -> None: """Test that tool_choice='none' with output tools warns and allows output tools.""" func_tool = ToolDefinition( diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 1faa96f2a8..dfdbd54d2e 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -28,7 +28,7 @@ VideoUrl, ) from pydantic_ai.agent import Agent -from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, ModelRetry, UserError +from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, ModelRetry from pydantic_ai.models import ModelRequestParameters from pydantic_ai.tools import ToolDefinition from pydantic_ai.usage import RequestUsage @@ -2412,23 +2412,6 @@ def test_tool_choice_specific_tool_falls_back_to_required() -> None: assert result == 'required' -def test_tool_choice_invalid_tool_name() -> None: - """Test that invalid tool names raise UserError.""" - my_tool = ToolDefinition( - name='my_tool', - description='Test tool', - parameters_json_schema={'type': 'object', 'properties': {}}, - ) - mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) - - mock_client = MockMistralAI.create_mock(completion_message(MistralAssistantMessage(content='ok', role='assistant'))) - model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client)) - settings: MistralModelSettings = {'tool_choice': ['nonexistent_tool']} - - with pytest.raises(UserError, match='Invalid tool names in tool_choice'): - model._get_tool_choice(mrp, settings) # pyright: ignore[reportPrivateUsage] - - def test_tool_choice_none_with_output_tools_warns() -> None: """Test that tool_choice='none' with output tools warns and returns 'required'.""" func_tool = ToolDefinition( diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index fd8e9bc097..39e959b6ef 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -71,6 +71,8 @@ from openai.types.chat.chat_completion_message_tool_call import Function from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob from openai.types.completion_usage import CompletionUsage, PromptTokensDetails + from openai.types.responses import ResponseFunctionToolCall + from openai.types.responses.response_output_message 
import ResponseOutputMessage, ResponseOutputText from pydantic_ai.models.google import GoogleModel from pydantic_ai.models.openai import ( @@ -3279,20 +3281,6 @@ def tool_c(x: int) -> str: assert tool_names == {'tool_a', 'tool_b'} -async def test_tool_choice_invalid_tool_name(allow_model_requests: None) -> None: - """Test that invalid tool names in tool_choice raise UserError.""" - mock_client = MockOpenAI.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) - model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) - agent = Agent(model) - - @agent.tool_plain - def my_tool(x: int) -> str: - return str(x) # pragma: no cover - - with pytest.raises(UserError, match='Invalid tool names in tool_choice'): - await agent.run('hello', model_settings={'tool_choice': ['nonexistent_tool']}) - - async def test_tool_choice_none_with_output_tools_warns(allow_model_requests: None) -> None: """Test that tool_choice='none' with output tools emits a warning and preserves output tools.""" @@ -3332,3 +3320,248 @@ def my_tool(x: int) -> str: kwargs = get_mock_chat_completion_kwargs(mock_client) # Output tool should be preserved (single output tool -> named tool choice) assert kwargs[0]['tool_choice'] == {'type': 'function', 'function': {'name': 'final_result'}} + + +async def test_tool_choice_none_with_multiple_output_tools(allow_model_requests: None) -> None: + """Test that tool_choice='none' with multiple output tools uses allowed_tools.""" + + class LocationA(BaseModel): + city: str + + class LocationB(BaseModel): + country: str + + mock_client = MockOpenAI.create_mock( + completion_message( + ChatCompletionMessage( + content=None, + role='assistant', + tool_calls=[ + ChatCompletionMessageFunctionToolCall( + id='1', + type='function', + function=Function( + name='final_result_LocationA', + arguments='{"city": "Paris"}', + ), + ), + ], + ) + ) + ) + model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent: Agent[None, LocationA | LocationB] = Agent(model, output_type=[LocationA, LocationB]) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): + result = await agent.run('hello', model_settings={'tool_choice': 'none'}) + + assert result.output == LocationA(city='Paris') + kwargs = get_mock_chat_completion_kwargs(mock_client) + # Multiple output tools -> allowed_tools + assert kwargs[0]['tool_choice'] == { + 'type': 'allowed_tools', + 'allowed_tools': { + 'mode': 'required', + 'tools': [ + {'type': 'function', 'function': {'name': 'final_result_LocationA'}}, + {'type': 'function', 'function': {'name': 'final_result_LocationB'}}, + ], + }, + } + + +# OpenAI Responses API tool_choice tests + + +@pytest.mark.parametrize( + 'tool_choice,expected', + [ + pytest.param('none', 'none', id='none'), + pytest.param('auto', 'auto', id='auto'), + pytest.param('required', 'required', id='required'), + ], +) +async def test_responses_tool_choice_string_values(allow_model_requests: None, tool_choice: str, expected: str) -> None: + """Test that tool_choice string values are correctly passed to the Responses API.""" + mock_client = MockOpenAIResponses.create_mock( + response_message( + [ + ResponseOutputMessage( + id='msg_123', + content=[ResponseOutputText(text='ok', type='output_text', annotations=[])], + role='assistant', + status='completed', + type='message', + ) + ] + ) + ) + model = 
OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(model) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': tool_choice}) # type: ignore + + kwargs = get_mock_responses_kwargs(mock_client) + assert kwargs[0]['tool_choice'] == expected + + +async def test_responses_tool_choice_specific_tool_single(allow_model_requests: None) -> None: + """Test Responses API tool_choice with a single specific tool name.""" + mock_client = MockOpenAIResponses.create_mock( + response_message( + [ + ResponseOutputMessage( + id='msg_123', + content=[ResponseOutputText(text='ok', type='output_text', annotations=[])], + role='assistant', + status='completed', + type='message', + ) + ] + ) + ) + model = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(model) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + @agent.tool_plain + def other_tool(y: str) -> str: + return y # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': ['my_tool']}) + + kwargs = get_mock_responses_kwargs(mock_client) + assert kwargs[0]['tool_choice'] == {'type': 'function', 'name': 'my_tool'} + + +async def test_responses_tool_choice_specific_tool_multiple(allow_model_requests: None) -> None: + """Test Responses API tool_choice with multiple specific tool names.""" + mock_client = MockOpenAIResponses.create_mock( + response_message( + [ + ResponseOutputMessage( + id='msg_123', + content=[ResponseOutputText(text='ok', type='output_text', annotations=[])], + role='assistant', + status='completed', + type='message', + ) + ] + ) + ) + model = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(model) + + @agent.tool_plain + def tool_a(x: int) -> str: + return str(x) # pragma: no cover + + @agent.tool_plain + def tool_b(y: str) -> str: + return y # pragma: no cover + + @agent.tool_plain + def tool_c(z: float) -> str: + return str(z) # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': ['tool_a', 'tool_b']}) + + kwargs = get_mock_responses_kwargs(mock_client) + # mode='auto' because allow_text_output=True (no output_type specified) + assert kwargs[0]['tool_choice'] == { + 'type': 'allowed_tools', + 'mode': 'auto', + 'tools': [{'type': 'function', 'name': 'tool_a'}, {'type': 'function', 'name': 'tool_b'}], + } + + +async def test_responses_tool_choice_none_with_output_tools_warns(allow_model_requests: None) -> None: + """Test that Responses API tool_choice='none' with output tools emits warning.""" + + class Location(BaseModel): + city: str + country: str + + mock_client = MockOpenAIResponses.create_mock( + response_message( + [ + ResponseFunctionToolCall( + id='call_123', + call_id='call_123', + name='final_result', + arguments='{"city": "Paris", "country": "France"}', + type='function_call', + status='completed', + ) + ] + ) + ) + model = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(model, output_type=Location) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): + result = await agent.run('hello', model_settings={'tool_choice': 'none'}) + + assert result.output == Location(city='Paris', country='France') + kwargs = 
get_mock_responses_kwargs(mock_client) + assert kwargs[0]['tool_choice'] == {'type': 'function', 'name': 'final_result'} + + +async def test_responses_tool_choice_none_with_multiple_output_tools(allow_model_requests: None) -> None: + """Test that Responses API tool_choice='none' with multiple output tools uses allowed_tools.""" + + class LocationA(BaseModel): + city: str + + class LocationB(BaseModel): + country: str + + mock_client = MockOpenAIResponses.create_mock( + response_message( + [ + ResponseFunctionToolCall( + id='call_123', + call_id='call_123', + name='final_result_LocationA', + arguments='{"city": "Paris"}', + type='function_call', + status='completed', + ) + ] + ) + ) + model = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent: Agent[None, LocationA | LocationB] = Agent(model, output_type=[LocationA, LocationB]) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): + result = await agent.run('hello', model_settings={'tool_choice': 'none'}) + + assert result.output == LocationA(city='Paris') + kwargs = get_mock_responses_kwargs(mock_client) + assert kwargs[0]['tool_choice'] == { + 'type': 'allowed_tools', + 'mode': 'required', + 'tools': [ + {'type': 'function', 'name': 'final_result_LocationA'}, + {'type': 'function', 'name': 'final_result_LocationB'}, + ], + } diff --git a/tests/models/test_resolve_tool_choice.py b/tests/models/test_resolve_tool_choice.py new file mode 100644 index 0000000000..a41cc72df6 --- /dev/null +++ b/tests/models/test_resolve_tool_choice.py @@ -0,0 +1,154 @@ +"""Tests for the centralized resolve_tool_choice() function. + +These tests cover the common logic shared across all providers: +- String value resolution ('none', 'auto', 'required') +- List[str] validation and resolution +- Warning emission for tool_choice='none' with output tools +- Invalid tool name detection + +Provider-specific tests (API format mapping) remain in their respective test files. 
+""" + +from __future__ import annotations + +import warnings + +import pytest +from inline_snapshot import snapshot + +from pydantic_ai.exceptions import UserError +from pydantic_ai.models import ( + ModelRequestParameters, + ResolvedToolChoice, + resolve_tool_choice, +) +from pydantic_ai.settings import ModelSettings +from pydantic_ai.tools import ToolDefinition + + +def make_tool(name: str) -> ToolDefinition: + """Create a simple tool definition for testing.""" + return ToolDefinition( + name=name, + description=f'Tool {name}', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + + +class TestResolveToolChoiceNone: + """Tests for when tool_choice is not set.""" + + def test_none_model_settings_returns_none(self) -> None: + """When model_settings is None, resolve_tool_choice returns None.""" + params = ModelRequestParameters() + result = resolve_tool_choice(None, params) + assert result is None + + def test_empty_model_settings_returns_none(self) -> None: + """When model_settings is empty, resolve_tool_choice returns None.""" + params = ModelRequestParameters() + settings: ModelSettings = {} + result = resolve_tool_choice(settings, params) + assert result is None + + def test_tool_choice_not_set_returns_none(self) -> None: + """When tool_choice is not in model_settings, resolve_tool_choice returns None.""" + params = ModelRequestParameters() + settings: ModelSettings = {'temperature': 0.5} + result = resolve_tool_choice(settings, params) + assert result is None + + +class TestResolveToolChoiceStringValues: + """Tests for string tool_choice values.""" + + @pytest.mark.parametrize( + 'tool_choice,expected', + [ + pytest.param('none', snapshot(ResolvedToolChoice(mode='none')), id='none'), + pytest.param('auto', snapshot(ResolvedToolChoice(mode='auto')), id='auto'), + pytest.param('required', snapshot(ResolvedToolChoice(mode='required')), id='required'), + ], + ) + def test_string_values(self, tool_choice: str, expected: ResolvedToolChoice) -> None: + """Test that string values are correctly resolved.""" + params = ModelRequestParameters(function_tools=[make_tool('my_tool')]) + settings: ModelSettings = {'tool_choice': tool_choice} # type: ignore + result = resolve_tool_choice(settings, params) + assert result == expected + + +class TestResolveToolChoiceSpecificTools: + """Tests for list[str] tool_choice values.""" + + def test_single_valid_tool(self) -> None: + """Test tool_choice with a single valid tool name.""" + params = ModelRequestParameters(function_tools=[make_tool('tool_a'), make_tool('tool_b')]) + settings: ModelSettings = {'tool_choice': ['tool_a']} + result = resolve_tool_choice(settings, params) + assert result == snapshot(ResolvedToolChoice(mode='specific', tool_names=['tool_a'])) + + def test_multiple_valid_tools(self) -> None: + """Test tool_choice with multiple valid tool names.""" + params = ModelRequestParameters(function_tools=[make_tool('tool_a'), make_tool('tool_b'), make_tool('tool_c')]) + settings: ModelSettings = {'tool_choice': ['tool_a', 'tool_b']} + result = resolve_tool_choice(settings, params) + assert result == snapshot(ResolvedToolChoice(mode='specific', tool_names=['tool_a', 'tool_b'])) + + def test_invalid_tool_name_raises_user_error(self) -> None: + """Test that invalid tool names raise UserError.""" + params = ModelRequestParameters(function_tools=[make_tool('my_tool')]) + settings: ModelSettings = {'tool_choice': ['nonexistent_tool']} + + with pytest.raises(UserError, match='Invalid tool names in tool_choice'): + 
resolve_tool_choice(settings, params) + + def test_mixed_valid_and_invalid_tools(self) -> None: + """Test that mix of valid and invalid tool names raises error.""" + params = ModelRequestParameters(function_tools=[make_tool('valid_tool')]) + settings: ModelSettings = {'tool_choice': ['valid_tool', 'invalid_tool']} + + with pytest.raises(UserError, match='invalid_tool'): + resolve_tool_choice(settings, params) + + def test_no_function_tools_available(self) -> None: + """Test error when specifying tools but none are registered.""" + params = ModelRequestParameters() + settings: ModelSettings = {'tool_choice': ['some_tool']} + + with pytest.raises(UserError, match='Available function tools: none'): + resolve_tool_choice(settings, params) + + def test_empty_list_raises_user_error(self) -> None: + """Test tool_choice=[] raises UserError.""" + params = ModelRequestParameters(function_tools=[make_tool('my_tool')]) + settings: ModelSettings = {'tool_choice': []} + + with pytest.raises(UserError, match='tool_choice cannot be an empty list'): + resolve_tool_choice(settings, params) + + +class TestResolveToolChoiceOutputToolsWarning: + """Tests for tool_choice='none' with output tools.""" + + def test_none_with_output_tools_warns(self) -> None: + """Test that tool_choice='none' with output tools emits warning and sets fallback.""" + output_tool = make_tool('final_result') + params = ModelRequestParameters(output_tools=[output_tool]) + settings: ModelSettings = {'tool_choice': 'none'} + + with pytest.warns(UserWarning, match='tool_choice=.none. is set but output tools are required'): + result = resolve_tool_choice(settings, params, stacklevel=2) + + assert result == snapshot(ResolvedToolChoice(mode='none', output_tools_fallback=True)) + + def test_none_without_output_tools_no_warning(self) -> None: + """Test that tool_choice='none' without output tools does not warn.""" + params = ModelRequestParameters(function_tools=[make_tool('my_tool')]) + settings: ModelSettings = {'tool_choice': 'none'} + + with warnings.catch_warnings(): + warnings.simplefilter('error') + result = resolve_tool_choice(settings, params, stacklevel=2) + + assert result == snapshot(ResolvedToolChoice(mode='none')) From e71dc863495738aaea7e5c504114b8dbfb920794 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Tue, 2 Dec 2025 09:24:05 -0500 Subject: [PATCH 5/9] coverage --- pydantic_ai_slim/pydantic_ai/models/google.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index c04dec6a58..b48336a4e2 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -398,7 +398,7 @@ def _get_tool_config( names: list[str] = [] for tool in tools: for function_declaration in tool.get('function_declarations') or []: - if name := function_declaration.get('name'): + if name := function_declaration.get('name'): # pragma: no branch names.append(name) return ToolConfigDict( function_calling_config=FunctionCallingConfigDict( @@ -420,7 +420,7 @@ def _get_tool_config( names = [] for tool in tools: for function_declaration in tool.get('function_declarations') or []: - if name := function_declaration.get('name'): + if name := function_declaration.get('name'): # pragma: no branch names.append(name) return _tool_config(names) return None From 5c387fdf58e2664c5c59d66bd57e53b786204ddf Mon Sep 17 00:00:00 2001 From: David Sanchez 
<64162682+dsfaccini@users.noreply.github.com> Date: Tue, 2 Dec 2025 16:25:28 -0500 Subject: [PATCH 6/9] imrpove tests --- .../pydantic_ai/models/__init__.py | 1 - .../pydantic_ai/models/anthropic.py | 5 +- tests/models/test_anthropic.py | 36 +++++------ tests/models/test_bedrock.py | 11 ++-- tests/models/test_google.py | 58 +++++++++-------- tests/models/test_groq.py | 12 ++-- tests/models/test_huggingface.py | 11 ++-- tests/models/test_mistral.py | 9 +-- tests/models/test_openai.py | 64 +++++++++---------- tests/models/test_resolve_tool_choice.py | 36 +++++------ 10 files changed, 114 insertions(+), 129 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index b42ab1f7bf..9621c6ba54 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -367,7 +367,6 @@ class ResolvedToolChoice: """True if we need to fall back to output tools only (when 'none' was requested but output tools exist).""" -# Warning message used when tool_choice='none' conflicts with output tools _TOOL_CHOICE_NONE_WITH_OUTPUT_TOOLS_WARNING = ( "tool_choice='none' is set but output tools are required for structured output. " 'The output tools will remain available. Consider using native or prompted output modes ' diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 331a8917d2..8d54a0a9b5 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -924,9 +924,10 @@ async def _map_message( # noqa: C901 system_prompt_parts.insert(0, instructions) system_prompt = '\n\n'.join(system_prompt_parts) + ttl: Literal['5m', '1h'] # Add cache_control to the last message content if anthropic_cache_messages is enabled if anthropic_messages and (cache_messages := model_settings.get('anthropic_cache_messages')): - ttl: Literal['5m', '1h'] = '5m' if cache_messages is True else cache_messages + ttl = '5m' if cache_messages is True else cache_messages m = anthropic_messages[-1] content = m['content'] if isinstance(content, str): @@ -946,7 +947,7 @@ async def _map_message( # noqa: C901 # If anthropic_cache_instructions is enabled, return system prompt as a list with cache_control if system_prompt and (cache_instructions := model_settings.get('anthropic_cache_instructions')): # If True, use '5m'; otherwise use the specified ttl value - ttl: Literal['5m', '1h'] = '5m' if cache_instructions is True else cache_instructions + ttl = '5m' if cache_instructions is True else cache_instructions system_prompt_blocks = [ BetaTextBlockParam( type='text', diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 548ce5d0cd..958d0aa012 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -7744,9 +7744,6 @@ async def test_anthropic_cache_messages_real_api(allow_model_requests: None, ant assert usage2.output_tokens > 0 -# Tests for tool_choice ModelSettings - - @pytest.mark.parametrize( 'tool_choice,expected_type', [ @@ -7756,7 +7753,7 @@ async def test_anthropic_cache_messages_real_api(allow_model_requests: None, ant ], ) async def test_tool_choice_string_values(allow_model_requests: None, tool_choice: str, expected_type: str) -> None: - """Test that tool_choice string values are correctly mapped to Anthropic's format.""" + """Ensure Anthropic string values map to the expected schema.""" c = completion_message([BetaTextBlock(text='ok', type='text')], 
BetaUsage(input_tokens=5, output_tokens=10)) mock_client = MockAnthropic.create_mock(c) m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) @@ -7773,7 +7770,7 @@ def my_tool(x: int) -> str: async def test_tool_choice_specific_tool_single(allow_model_requests: None) -> None: - """Test tool_choice with a single specific tool name maps to Anthropic's 'tool' type.""" + """Single Anthropic tools should emit the 'tool' choice payload.""" c = completion_message([BetaTextBlock(text='ok', type='text')], BetaUsage(input_tokens=5, output_tokens=10)) mock_client = MockAnthropic.create_mock(c) m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) @@ -7794,7 +7791,7 @@ def tool_b(x: int) -> str: async def test_tool_choice_multiple_tools_falls_back_to_any(allow_model_requests: None) -> None: - """Test tool_choice with multiple tools falls back to 'any' with warning (Anthropic limitation).""" + """Multiple specific tools fall back to 'any' with a warning.""" c = completion_message([BetaTextBlock(text='ok', type='text')], BetaUsage(input_tokens=5, output_tokens=10)) mock_client = MockAnthropic.create_mock(c) m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) @@ -7812,11 +7809,11 @@ def tool_b(x: int) -> str: await agent.run('hello', model_settings={'tool_choice': ['tool_a', 'tool_b']}) kwargs = mock_client.chat_completion_kwargs[0] # type: ignore - assert kwargs['tool_choice']['type'] == 'any' + assert kwargs['tool_choice'] == snapshot({'type': 'any'}) async def test_tool_choice_none_with_output_tools_warns(allow_model_requests: None) -> None: - """Test that tool_choice='none' with output tools emits a warning and preserves output tools.""" + """Structured output must remain available even with tool_choice='none'.""" class Location(BaseModel): city: str @@ -7837,14 +7834,13 @@ def my_tool(x: int) -> str: with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): result = await agent.run('hello', model_settings={'tool_choice': 'none'}) - assert result.output == Location(city='Paris', country='France') + assert result.output == snapshot(Location(city='Paris', country='France')) kwargs = mock_client.chat_completion_kwargs[0] # type: ignore - # Output tool should be preserved (single output tool -> 'tool' type) - assert kwargs['tool_choice'] == {'type': 'tool', 'name': 'final_result'} + assert kwargs['tool_choice'] == snapshot({'type': 'tool', 'name': 'final_result'}) async def test_tool_choice_required_with_thinking_falls_back_to_auto(allow_model_requests: None) -> None: - """Test that tool_choice='required' with thinking mode falls back to 'auto' with warning.""" + """Thinking mode overrides 'required' to 'auto'.""" c = completion_message([BetaTextBlock(text='ok', type='text')], BetaUsage(input_tokens=5, output_tokens=10)) mock_client = MockAnthropic.create_mock(c) m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) @@ -7868,7 +7864,7 @@ def my_tool(x: int) -> str: async def test_tool_choice_specific_with_thinking_falls_back_to_auto(allow_model_requests: None) -> None: - """Test that specific tool_choice with thinking mode falls back to 'auto' with warning.""" + """Specific tool forcing is incompatible with thinking mode.""" c = completion_message([BetaTextBlock(text='ok', type='text')], BetaUsage(input_tokens=5, output_tokens=10)) mock_client = MockAnthropic.create_mock(c) m = 
AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) @@ -7892,7 +7888,7 @@ def my_tool(x: int) -> str: async def test_tool_choice_none_with_multiple_output_tools_falls_back_to_auto(allow_model_requests: None) -> None: - """Test that tool_choice='none' with multiple output tools falls back to 'auto' with warning.""" + """Multiple output tools force a fallback to 'auto'.""" import warnings as warn_module class LocationA(BaseModel): @@ -7913,15 +7909,17 @@ class LocationB(BaseModel): def my_tool(x: int) -> str: return str(x) # pragma: no cover - # Expect two warnings: one from resolve_tool_choice about output tools, one from Anthropic about multiple tools with warn_module.catch_warnings(record=True) as w: warn_module.simplefilter('always') await agent.run('hello', model_settings={'tool_choice': 'none'}) - # Check that we got the Anthropic-specific warning about multiple tools - warning_messages = [str(warning.message) for warning in w] - assert any("tool_choice='none' is set but output tools are required" in msg for msg in warning_messages) - assert any('Anthropic only supports forcing a single tool' in msg for msg in warning_messages) + warning_messages = {str(warning.message) for warning in w} + assert { + "tool_choice='none' is set but output tools are required for structured output. " + 'The output tools will remain available. Consider using native or prompted output modes ' + "if you need tool_choice='none' with structured output.", + "Anthropic only supports forcing a single tool. Falling back to 'auto' for multiple output tools.", + } <= warning_messages kwargs = mock_client.chat_completion_kwargs[0] # type: ignore assert kwargs['tool_choice']['type'] == 'auto' diff --git a/tests/models/test_bedrock.py b/tests/models/test_bedrock.py index af7083dd45..fc034758a5 100644 --- a/tests/models/test_bedrock.py +++ b/tests/models/test_bedrock.py @@ -1626,9 +1626,6 @@ async def test_cache_point_filtering(): assert messages[0]['role'] == 'user' -# tool_choice tests - - @pytest.mark.parametrize( 'tool_choice,expected_tool_choice', [ @@ -1639,7 +1636,7 @@ async def test_cache_point_filtering(): async def test_tool_choice_string_values( bedrock_provider: BedrockProvider, tool_choice: str, expected_tool_choice: dict[str, Any] ) -> None: - """Test that tool_choice string values are correctly mapped.""" + """Ensure simple string tool_choice values map to Bedrock's schema.""" my_tool = ToolDefinition( name='my_tool', description='Test tool', @@ -1656,7 +1653,7 @@ async def test_tool_choice_string_values( async def test_tool_choice_none_falls_back_to_auto(bedrock_provider: BedrockProvider) -> None: - """Test that tool_choice='none' falls back to 'auto' with warning since Bedrock doesn't support it.""" + """Bedrock lacks 'none' support, so we fall back to auto with a warning.""" my_tool = ToolDefinition( name='my_tool', description='Test tool', @@ -1687,7 +1684,7 @@ async def test_tool_choice_none_falls_back_to_auto(bedrock_provider: BedrockProv async def test_tool_choice_specific_tool_single(bedrock_provider: BedrockProvider) -> None: - """Test tool_choice with a single specific tool name.""" + """Single tool names should emit the {tool: {name}} payload.""" tool_a = ToolDefinition( name='tool_a', description='Test tool A', @@ -1730,7 +1727,7 @@ async def test_tool_choice_specific_tool_single(bedrock_provider: BedrockProvide async def test_tool_choice_multiple_tools_falls_back_to_any(bedrock_provider: BedrockProvider) -> None: - """Test that multiple tools in 
tool_choice falls back to 'any' with warning.""" + """Multiple tool names fall back to the 'any' configuration.""" tool_a = ToolDefinition( name='tool_a', description='Test tool A', diff --git a/tests/models/test_google.py b/tests/models/test_google.py index e0f9254fe7..a196229458 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -4431,18 +4431,27 @@ def test_google_missing_tool_call_thought_signature(): ) -# tool_choice tests +def test_tool_choice_string_value_none(google_provider: GoogleProvider) -> None: + """Test that tool_choice='none' maps to FunctionCallingConfigMode.NONE.""" + my_tool = ToolDefinition( + name='my_tool', + description='Test tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) + tools = [ToolDict(function_declarations=[FunctionDeclarationDict(name='my_tool', description='Test tool')])] + model = GoogleModel('gemini-2.5-flash', provider=google_provider) + settings: GoogleModelSettings = {'tool_choice': 'none'} + result = model._get_tool_config(mrp, tools, settings) # pyright: ignore[reportPrivateUsage] -@pytest.mark.parametrize( - 'tool_choice,expected_mode', - [ - pytest.param('none', FunctionCallingConfigMode.NONE, id='none'), - pytest.param('auto', FunctionCallingConfigMode.AUTO, id='auto'), - ], -) -def test_tool_choice_string_values(google_provider: GoogleProvider, tool_choice: str, expected_mode: str) -> None: - """Test that tool_choice string values are correctly mapped to Google's FunctionCallingConfigMode.""" + assert result is not None + fcc = result.get('function_calling_config') + assert fcc == snapshot({'mode': FunctionCallingConfigMode.NONE}) + + +def test_tool_choice_string_value_auto(google_provider: GoogleProvider) -> None: + """Test that tool_choice='auto' maps to FunctionCallingConfigMode.AUTO.""" my_tool = ToolDefinition( name='my_tool', description='Test tool', @@ -4451,14 +4460,13 @@ def test_tool_choice_string_values(google_provider: GoogleProvider, tool_choice: mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) tools = [ToolDict(function_declarations=[FunctionDeclarationDict(name='my_tool', description='Test tool')])] - model = GoogleModel('gemini-1.5-flash', provider=google_provider) - settings: GoogleModelSettings = {'tool_choice': tool_choice} # type: ignore[assignment] + model = GoogleModel('gemini-2.5-flash', provider=google_provider) + settings: GoogleModelSettings = {'tool_choice': 'auto'} result = model._get_tool_config(mrp, tools, settings) # pyright: ignore[reportPrivateUsage] assert result is not None fcc = result.get('function_calling_config') - assert fcc is not None - assert fcc.get('mode') == expected_mode + assert fcc == snapshot({'mode': FunctionCallingConfigMode.AUTO}) def test_tool_choice_required_maps_to_any(google_provider: GoogleProvider) -> None: @@ -4471,19 +4479,17 @@ def test_tool_choice_required_maps_to_any(google_provider: GoogleProvider) -> No mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) tools = [ToolDict(function_declarations=[FunctionDeclarationDict(name='my_tool', description='Test tool')])] - model = GoogleModel('gemini-1.5-flash', provider=google_provider) + model = GoogleModel('gemini-2.5-flash', provider=google_provider) settings: GoogleModelSettings = {'tool_choice': 'required'} result = 
model._get_tool_config(mrp, tools, settings) # pyright: ignore[reportPrivateUsage] assert result is not None fcc = result.get('function_calling_config') - assert fcc is not None - assert fcc.get('mode') == FunctionCallingConfigMode.ANY - assert fcc.get('allowed_function_names') == ['my_tool'] + assert fcc == snapshot({'mode': FunctionCallingConfigMode.ANY, 'allowed_function_names': ['my_tool']}) def test_tool_choice_specific_tool_single(google_provider: GoogleProvider) -> None: - """Test tool_choice with a single specific tool name.""" + """Specific tool names become allowed_function_names.""" tool_a = ToolDefinition( name='tool_a', description='Test tool A', @@ -4502,19 +4508,17 @@ def test_tool_choice_specific_tool_single(google_provider: GoogleProvider) -> No ToolDict(function_declarations=[FunctionDeclarationDict(name='tool_b', description='Test tool B')]), ] - model = GoogleModel('gemini-1.5-flash', provider=google_provider) + model = GoogleModel('gemini-2.5-flash', provider=google_provider) settings: GoogleModelSettings = {'tool_choice': ['tool_a']} result = model._get_tool_config(mrp, tools, settings) # pyright: ignore[reportPrivateUsage] assert result is not None fcc = result.get('function_calling_config') - assert fcc is not None - assert fcc.get('mode') == FunctionCallingConfigMode.ANY - assert fcc.get('allowed_function_names') == ['tool_a'] + assert fcc == snapshot({'mode': FunctionCallingConfigMode.ANY, 'allowed_function_names': ['tool_a']}) def test_tool_choice_none_with_output_tools_warns(google_provider: GoogleProvider) -> None: - """Test that tool_choice='none' with output tools warns and allows output tools.""" + """tool_choice='none' still allows the required output tool.""" func_tool = ToolDefinition( name='func_tool', description='Function tool', @@ -4533,7 +4537,7 @@ def test_tool_choice_none_with_output_tools_warns(google_provider: GoogleProvide ToolDict(function_declarations=[FunctionDeclarationDict(name='output_tool', description='Output tool')]), ] - model = GoogleModel('gemini-1.5-flash', provider=google_provider) + model = GoogleModel('gemini-2.5-flash', provider=google_provider) settings: GoogleModelSettings = {'tool_choice': 'none'} with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): @@ -4541,6 +4545,4 @@ def test_tool_choice_none_with_output_tools_warns(google_provider: GoogleProvide assert result is not None fcc = result.get('function_calling_config') - assert fcc is not None - assert fcc.get('mode') == FunctionCallingConfigMode.ANY - assert fcc.get('allowed_function_names') == ['output_tool'] + assert fcc == snapshot({'mode': FunctionCallingConfigMode.ANY, 'allowed_function_names': ['output_tool']}) diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index defd517886..11f4669259 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -5634,9 +5634,6 @@ class CityLocation(BaseModel): ) -# Tests for tool_choice ModelSettings - - @pytest.mark.parametrize( 'tool_choice,expected', [ @@ -5646,7 +5643,7 @@ class CityLocation(BaseModel): ], ) async def test_tool_choice_string_values(allow_model_requests: None, tool_choice: str, expected: str) -> None: - """Test that tool_choice string values are correctly passed to the API.""" + """Ensure Groq string values are forwarded unchanged.""" mock_client = MockGroq.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client)) agent = 
Agent(m) @@ -5662,7 +5659,7 @@ def my_tool(x: int) -> str: async def test_tool_choice_specific_tool_single(allow_model_requests: None) -> None: - """Test tool_choice with a single specific tool name.""" + """Single tool choices should use the named tool payload.""" mock_client = MockGroq.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client)) agent = Agent(m) @@ -5682,7 +5679,7 @@ def tool_b(x: int) -> str: async def test_tool_choice_multiple_tools_falls_back_to_required(allow_model_requests: None) -> None: - """Test that multiple tools in tool_choice falls back to 'required' with warning.""" + """Multiple specific tools fall back to 'required'.""" mock_client = MockGroq.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client)) agent = Agent(m) @@ -5703,7 +5700,7 @@ def tool_b(x: int) -> str: async def test_tool_choice_none_with_output_tools(allow_model_requests: None) -> None: - """Test that tool_choice='none' with output tools warns and uses output tool.""" + """tool_choice='none' still allows output tools to execute.""" class MyOutput(BaseModel): result: str @@ -5737,5 +5734,4 @@ def my_tool(x: int) -> str: await agent.run('hello', model_settings={'tool_choice': 'none'}) kwargs = get_mock_chat_completion_kwargs(mock_client)[0] - # When tool_choice='none' but output tools exist, it should force the output tool assert kwargs['tool_choice'] == {'type': 'function', 'function': {'name': 'final_result'}} diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 1cca3f0ed6..4161f71605 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -1032,9 +1032,6 @@ async def test_cache_point_filtering(): assert len(msg['content']) == 1 # pyright: ignore[reportUnknownArgumentType] -# tool_choice tests - - @pytest.mark.parametrize( 'tool_choice,expected', [ @@ -1044,7 +1041,7 @@ async def test_cache_point_filtering(): ], ) def test_tool_choice_string_values(tool_choice: str, expected: str) -> None: - """Test that tool_choice string values are correctly passed through.""" + """Ensure HuggingFace string values pass through unchanged.""" my_tool = ToolDefinition( name='my_tool', description='Test tool', @@ -1066,7 +1063,7 @@ def test_tool_choice_string_values(tool_choice: str, expected: str) -> None: def test_tool_choice_specific_tool_single() -> None: - """Test tool_choice with a single specific tool name.""" + """Single tool entries should use ChatCompletionInputToolChoiceClass.""" tool_a = ToolDefinition( name='tool_a', description='Test tool A', @@ -1097,7 +1094,7 @@ def test_tool_choice_specific_tool_single() -> None: def test_tool_choice_multiple_tools_falls_back_to_required() -> None: - """Test that multiple tools in tool_choice falls back to 'required' with warning.""" + """Multiple specific tools fall back to 'required'.""" tool_a = ToolDefinition( name='tool_a', description='Test tool A', @@ -1129,7 +1126,7 @@ def test_tool_choice_multiple_tools_falls_back_to_required() -> None: def test_tool_choice_none_with_output_tools_warns() -> None: - """Test that tool_choice='none' with output tools warns and allows output tools.""" + """tool_choice='none' should not disable mandatory output tools.""" func_tool = ToolDefinition( name='func_tool', description='Function tool', diff --git 
a/tests/models/test_mistral.py b/tests/models/test_mistral.py index dfdbd54d2e..50d64597d5 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -2365,9 +2365,6 @@ async def test_mistral_model_thinking_part_iter(allow_model_requests: None, mist ) -# tool_choice tests - - @pytest.mark.parametrize( 'tool_choice,expected', [ @@ -2377,7 +2374,7 @@ async def test_mistral_model_thinking_part_iter(allow_model_requests: None, mist ], ) def test_tool_choice_string_values(tool_choice: str, expected: str) -> None: - """Test that tool_choice string values are correctly passed through.""" + """Ensure Mistral string values pass through untouched.""" my_tool = ToolDefinition( name='my_tool', description='Test tool', @@ -2394,7 +2391,7 @@ def test_tool_choice_string_values(tool_choice: str, expected: str) -> None: def test_tool_choice_specific_tool_falls_back_to_required() -> None: - """Test that specific tool falls back to 'required' with warning since Mistral doesn't support it.""" + """Specific tool forcing is unsupported and falls back to required.""" tool_a = ToolDefinition( name='tool_a', description='Test tool A', @@ -2413,7 +2410,7 @@ def test_tool_choice_specific_tool_falls_back_to_required() -> None: def test_tool_choice_none_with_output_tools_warns() -> None: - """Test that tool_choice='none' with output tools warns and returns 'required'.""" + """tool_choice='none' still forces required when output tools exist.""" func_tool = ToolDefinition( name='func_tool', description='Function tool', diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 39e959b6ef..d9cf297290 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -3185,10 +3185,9 @@ async def test_cache_point_filtering(allow_model_requests: None): msg = await m._map_user_prompt(UserPromptPart(content=['text before', CachePoint(), 'text after'])) # pyright: ignore[reportPrivateUsage] # CachePoint should be filtered out, only text content should remain - assert msg['role'] == 'user' - assert len(msg['content']) == 2 # type: ignore[reportUnknownArgumentType] - assert msg['content'][0]['text'] == 'text before' # type: ignore[reportUnknownArgumentType] - assert msg['content'][1]['text'] == 'text after' # type: ignore[reportUnknownArgumentType] + assert msg == snapshot( + {'role': 'user', 'content': [{'text': 'text before', 'type': 'text'}, {'text': 'text after', 'type': 'text'}]} + ) async def test_cache_point_filtering_responses_model(): @@ -3199,13 +3198,12 @@ async def test_cache_point_filtering_responses_model(): ) # CachePoint should be filtered out, only text content should remain - assert msg['role'] == 'user' - assert len(msg['content']) == 2 - assert msg['content'][0]['text'] == 'text before' # type: ignore[reportUnknownArgumentType] - assert msg['content'][1]['text'] == 'text after' # type: ignore[reportUnknownArgumentType] - - -# Tests for tool_choice ModelSettings + assert msg == snapshot( + { + 'role': 'user', + 'content': [{'text': 'text before', 'type': 'input_text'}, {'text': 'text after', 'type': 'input_text'}], + } + ) @pytest.mark.parametrize( @@ -3217,7 +3215,7 @@ async def test_cache_point_filtering_responses_model(): ], ) async def test_tool_choice_string_values(allow_model_requests: None, tool_choice: str, expected: str) -> None: - """Test that tool_choice string values are correctly passed to the API.""" + """Ensure Chat tool_choice strings flow through unchanged.""" mock_client = 
MockOpenAI.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) agent = Agent(model) @@ -3233,7 +3231,7 @@ def my_tool(x: int) -> str: async def test_tool_choice_specific_tool_single(allow_model_requests: None) -> None: - """Test tool_choice with a single specific tool name.""" + """Force the Chat API to call a specific tool.""" mock_client = MockOpenAI.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) agent = Agent(model) @@ -3253,7 +3251,7 @@ def tool_b(x: int) -> str: async def test_tool_choice_specific_tools_multiple(allow_model_requests: None) -> None: - """Test tool_choice with multiple specific tool names.""" + """Multiple Chat tools should produce an allowed_tools payload.""" mock_client = MockOpenAI.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) agent = Agent(model) @@ -3273,16 +3271,22 @@ def tool_c(x: int) -> str: await agent.run('hello', model_settings={'tool_choice': ['tool_a', 'tool_b']}) kwargs = get_mock_chat_completion_kwargs(mock_client) - tool_choice = kwargs[0]['tool_choice'] - assert tool_choice['type'] == 'allowed_tools' - assert tool_choice['allowed_tools']['mode'] == 'auto' - assert len(tool_choice['allowed_tools']['tools']) == 2 - tool_names = {t['function']['name'] for t in tool_choice['allowed_tools']['tools']} - assert tool_names == {'tool_a', 'tool_b'} + assert kwargs[0]['tool_choice'] == snapshot( + { + 'type': 'allowed_tools', + 'allowed_tools': { + 'mode': 'auto', + 'tools': [ + {'type': 'function', 'function': {'name': 'tool_a'}}, + {'type': 'function', 'function': {'name': 'tool_b'}}, + ], + }, + } + ) async def test_tool_choice_none_with_output_tools_warns(allow_model_requests: None) -> None: - """Test that tool_choice='none' with output tools emits a warning and preserves output tools.""" + """Structured output tools persist even with tool_choice='none'.""" class Location(BaseModel): city: str @@ -3318,12 +3322,11 @@ def my_tool(x: int) -> str: assert result.output == Location(city='Paris', country='France') kwargs = get_mock_chat_completion_kwargs(mock_client) - # Output tool should be preserved (single output tool -> named tool choice) assert kwargs[0]['tool_choice'] == {'type': 'function', 'function': {'name': 'final_result'}} async def test_tool_choice_none_with_multiple_output_tools(allow_model_requests: None) -> None: - """Test that tool_choice='none' with multiple output tools uses allowed_tools.""" + """Multiple output tools fall back to allowed_tools when forcing 'none'.""" class LocationA(BaseModel): city: str @@ -3361,7 +3364,6 @@ def my_tool(x: int) -> str: assert result.output == LocationA(city='Paris') kwargs = get_mock_chat_completion_kwargs(mock_client) - # Multiple output tools -> allowed_tools assert kwargs[0]['tool_choice'] == { 'type': 'allowed_tools', 'allowed_tools': { @@ -3374,9 +3376,6 @@ def my_tool(x: int) -> str: } -# OpenAI Responses API tool_choice tests - - @pytest.mark.parametrize( 'tool_choice,expected', [ @@ -3386,7 +3385,7 @@ def my_tool(x: int) -> str: ], ) async def test_responses_tool_choice_string_values(allow_model_requests: None, tool_choice: str, expected: str) -> None: - """Test that tool_choice string values are correctly passed to the Responses 
API.""" + """Ensure Responses tool_choice strings pass through untouched.""" mock_client = MockOpenAIResponses.create_mock( response_message( [ @@ -3414,7 +3413,7 @@ def my_tool(x: int) -> str: async def test_responses_tool_choice_specific_tool_single(allow_model_requests: None) -> None: - """Test Responses API tool_choice with a single specific tool name.""" + """Force a single tool when using the Responses API.""" mock_client = MockOpenAIResponses.create_mock( response_message( [ @@ -3446,7 +3445,7 @@ def other_tool(y: str) -> str: async def test_responses_tool_choice_specific_tool_multiple(allow_model_requests: None) -> None: - """Test Responses API tool_choice with multiple specific tool names.""" + """Multiple Responses tools rely on the allowed_tools payload.""" mock_client = MockOpenAIResponses.create_mock( response_message( [ @@ -3478,7 +3477,6 @@ def tool_c(z: float) -> str: await agent.run('hello', model_settings={'tool_choice': ['tool_a', 'tool_b']}) kwargs = get_mock_responses_kwargs(mock_client) - # mode='auto' because allow_text_output=True (no output_type specified) assert kwargs[0]['tool_choice'] == { 'type': 'allowed_tools', 'mode': 'auto', @@ -3487,7 +3485,7 @@ def tool_c(z: float) -> str: async def test_responses_tool_choice_none_with_output_tools_warns(allow_model_requests: None) -> None: - """Test that Responses API tool_choice='none' with output tools emits warning.""" + """tool_choice='none' cannot disable required Responses output tools.""" class Location(BaseModel): city: str @@ -3523,7 +3521,7 @@ def my_tool(x: int) -> str: async def test_responses_tool_choice_none_with_multiple_output_tools(allow_model_requests: None) -> None: - """Test that Responses API tool_choice='none' with multiple output tools uses allowed_tools.""" + """Multiple Responses output tools still use allowed_tools when forced to 'none'.""" class LocationA(BaseModel): city: str diff --git a/tests/models/test_resolve_tool_choice.py b/tests/models/test_resolve_tool_choice.py index a41cc72df6..5a6f5faf6c 100644 --- a/tests/models/test_resolve_tool_choice.py +++ b/tests/models/test_resolve_tool_choice.py @@ -1,4 +1,4 @@ -"""Tests for the centralized resolve_tool_choice() function. +"""Tests for the centralized `resolve_tool_choice()` function. 
These tests cover the common logic shared across all providers: - String value resolution ('none', 'auto', 'required') @@ -27,7 +27,7 @@ def make_tool(name: str) -> ToolDefinition: - """Create a simple tool definition for testing.""" + """Return a minimal `ToolDefinition` used throughout the tests.""" return ToolDefinition( name=name, description=f'Tool {name}', @@ -36,23 +36,23 @@ def make_tool(name: str) -> ToolDefinition: class TestResolveToolChoiceNone: - """Tests for when tool_choice is not set.""" + """Cases where `tool_choice` is unset in the settings.""" def test_none_model_settings_returns_none(self) -> None: - """When model_settings is None, resolve_tool_choice returns None.""" + """`resolve_tool_choice` returns None when `model_settings` is None.""" params = ModelRequestParameters() result = resolve_tool_choice(None, params) assert result is None def test_empty_model_settings_returns_none(self) -> None: - """When model_settings is empty, resolve_tool_choice returns None.""" + """Empty `model_settings` dict should also yield None.""" params = ModelRequestParameters() settings: ModelSettings = {} result = resolve_tool_choice(settings, params) assert result is None def test_tool_choice_not_set_returns_none(self) -> None: - """When tool_choice is not in model_settings, resolve_tool_choice returns None.""" + """`tool_choice` missing from settings keeps provider defaults.""" params = ModelRequestParameters() settings: ModelSettings = {'temperature': 0.5} result = resolve_tool_choice(settings, params) @@ -60,7 +60,7 @@ def test_tool_choice_not_set_returns_none(self) -> None: class TestResolveToolChoiceStringValues: - """Tests for string tool_choice values.""" + """String-valued `tool_choice` entries.""" @pytest.mark.parametrize( 'tool_choice,expected', @@ -71,7 +71,7 @@ class TestResolveToolChoiceStringValues: ], ) def test_string_values(self, tool_choice: str, expected: ResolvedToolChoice) -> None: - """Test that string values are correctly resolved.""" + """Valid string entries map directly to their resolved form.""" params = ModelRequestParameters(function_tools=[make_tool('my_tool')]) settings: ModelSettings = {'tool_choice': tool_choice} # type: ignore result = resolve_tool_choice(settings, params) @@ -79,24 +79,24 @@ def test_string_values(self, tool_choice: str, expected: ResolvedToolChoice) -> class TestResolveToolChoiceSpecificTools: - """Tests for list[str] tool_choice values.""" + """List-based tool_choice entries.""" def test_single_valid_tool(self) -> None: - """Test tool_choice with a single valid tool name.""" + """Single tool names remain in the returned result.""" params = ModelRequestParameters(function_tools=[make_tool('tool_a'), make_tool('tool_b')]) settings: ModelSettings = {'tool_choice': ['tool_a']} result = resolve_tool_choice(settings, params) assert result == snapshot(ResolvedToolChoice(mode='specific', tool_names=['tool_a'])) def test_multiple_valid_tools(self) -> None: - """Test tool_choice with multiple valid tool names.""" + """Multiple valid names stay in insertion order.""" params = ModelRequestParameters(function_tools=[make_tool('tool_a'), make_tool('tool_b'), make_tool('tool_c')]) settings: ModelSettings = {'tool_choice': ['tool_a', 'tool_b']} result = resolve_tool_choice(settings, params) assert result == snapshot(ResolvedToolChoice(mode='specific', tool_names=['tool_a', 'tool_b'])) def test_invalid_tool_name_raises_user_error(self) -> None: - """Test that invalid tool names raise UserError.""" + """Unknown names raise a UserError.""" params = 
ModelRequestParameters(function_tools=[make_tool('my_tool')]) settings: ModelSettings = {'tool_choice': ['nonexistent_tool']} @@ -104,7 +104,7 @@ def test_invalid_tool_name_raises_user_error(self) -> None: resolve_tool_choice(settings, params) def test_mixed_valid_and_invalid_tools(self) -> None: - """Test that mix of valid and invalid tool names raises error.""" + """Mixed valid/invalid names still raise.""" params = ModelRequestParameters(function_tools=[make_tool('valid_tool')]) settings: ModelSettings = {'tool_choice': ['valid_tool', 'invalid_tool']} @@ -112,7 +112,7 @@ def test_mixed_valid_and_invalid_tools(self) -> None: resolve_tool_choice(settings, params) def test_no_function_tools_available(self) -> None: - """Test error when specifying tools but none are registered.""" + """Requesting specific tools without registered ones errors.""" params = ModelRequestParameters() settings: ModelSettings = {'tool_choice': ['some_tool']} @@ -120,7 +120,7 @@ def test_no_function_tools_available(self) -> None: resolve_tool_choice(settings, params) def test_empty_list_raises_user_error(self) -> None: - """Test tool_choice=[] raises UserError.""" + """Empty lists are not allowed.""" params = ModelRequestParameters(function_tools=[make_tool('my_tool')]) settings: ModelSettings = {'tool_choice': []} @@ -129,10 +129,10 @@ def test_empty_list_raises_user_error(self) -> None: class TestResolveToolChoiceOutputToolsWarning: - """Tests for tool_choice='none' with output tools.""" + """Safety checks when `tool_choice='none'` conflicts with output tools.""" def test_none_with_output_tools_warns(self) -> None: - """Test that tool_choice='none' with output tools emits warning and sets fallback.""" + """`tool_choice='none'` issues a warning when output tools exist.""" output_tool = make_tool('final_result') params = ModelRequestParameters(output_tools=[output_tool]) settings: ModelSettings = {'tool_choice': 'none'} @@ -143,7 +143,7 @@ def test_none_with_output_tools_warns(self) -> None: assert result == snapshot(ResolvedToolChoice(mode='none', output_tools_fallback=True)) def test_none_without_output_tools_no_warning(self) -> None: - """Test that tool_choice='none' without output tools does not warn.""" + """No warning when `tool_choice='none'` and no output tools exist.""" params = ModelRequestParameters(function_tools=[make_tool('my_tool')]) settings: ModelSettings = {'tool_choice': 'none'} From 363c718f8ee886600ae0c06eb7a6beed9d011f2c Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Fri, 5 Dec 2025 12:24:18 -0500 Subject: [PATCH 7/9] improvde code quality --- .../pydantic_ai/models/__init__.py | 14 +- .../pydantic_ai/models/anthropic.py | 98 ++++++------ .../pydantic_ai/models/bedrock.py | 59 +++---- pydantic_ai_slim/pydantic_ai/models/google.py | 76 ++++----- pydantic_ai_slim/pydantic_ai/models/groq.py | 58 ++++--- .../pydantic_ai/models/huggingface.py | 56 ++++--- .../pydantic_ai/models/mistral.py | 37 +++-- pydantic_ai_slim/pydantic_ai/models/openai.py | 146 +++++++++--------- tests/models/test_anthropic.py | 2 +- 9 files changed, 271 insertions(+), 275 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 57fa384cf7..340081bbec 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -557,8 +557,8 @@ class ResolvedToolChoice: mode: Literal['none', 'auto', 'required', 'specific'] """The resolved tool choice mode.""" - 
tool_names: list[str] | None = None - """For 'specific' mode, the list of tool names to force.""" + tool_names: list[str] = field(default_factory=list) + """For 'specific' mode, the list of tool names to force. Empty for other modes.""" output_tools_fallback: bool = False """True if we need to fall back to output tools only (when 'none' was requested but output tools exist).""" @@ -566,11 +566,12 @@ class ResolvedToolChoice: _TOOL_CHOICE_NONE_WITH_OUTPUT_TOOLS_WARNING = ( "tool_choice='none' is set but output tools are required for structured output. " - 'The output tools will remain available. Consider using native or prompted output modes ' + 'The output tools will remain available. Consider using `NativeOutput` or `PromptedOutput` ' "if you need tool_choice='none' with structured output." ) +# NOTE: for PR discussion: should this be a private method? a Model.method? Perhaps a ModelRequestParameters.method? def resolve_tool_choice( model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, @@ -607,11 +608,8 @@ def resolve_tool_choice( return ResolvedToolChoice(mode='none', output_tools_fallback=True) return ResolvedToolChoice(mode='none') - if user_tool_choice == 'auto': - return ResolvedToolChoice(mode='auto') - - if user_tool_choice == 'required': - return ResolvedToolChoice(mode='required') + if user_tool_choice in ('auto', 'required'): + return ResolvedToolChoice(mode=user_tool_choice) if isinstance(user_tool_choice, list): if not user_tool_choice: diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 132ea4c1a8..24cf6fb4f4 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -658,65 +658,69 @@ def _infer_tool_choice( resolved = resolve_tool_choice(model_settings, model_request_parameters) - if resolved is not None: - if resolved.mode == 'none': - if resolved.output_tools_fallback: - output_tool_names = [t.name for t in model_request_parameters.output_tools] - if len(output_tool_names) == 1: - tool_choice = {'type': 'tool', 'name': output_tool_names[0]} - else: - warnings.warn( - 'Anthropic only supports forcing a single tool. ' - "Falling back to 'auto' for multiple output tools.", - UserWarning, - stacklevel=6, - ) - tool_choice = {'type': 'auto'} - else: - tool_choice = {'type': 'none'} - - elif resolved.mode == 'auto': + if resolved is None: + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output: + tool_choice = {'type': 'any'} + else: tool_choice = {'type': 'auto'} - elif resolved.mode == 'required': - if thinking_enabled: - warnings.warn( - "tool_choice='required' is not supported with Anthropic thinking mode. Falling back to 'auto'.", - UserWarning, - stacklevel=6, - ) - tool_choice = {'type': 'auto'} - else: - tool_choice = {'type': 'any'} + elif resolved.mode == 'auto': + tool_choice = {'type': 'auto'} - elif resolved.mode == 'specific': - assert resolved.tool_names # Guaranteed non-empty by resolve_tool_choice() - if thinking_enabled: - warnings.warn( - "Forcing specific tools is not supported with Anthropic thinking mode. Falling back to 'auto'.", - UserWarning, - stacklevel=6, - ) - tool_choice = {'type': 'auto'} - elif len(resolved.tool_names) == 1: - tool_choice = {'type': 'tool', 'name': resolved.tool_names[0]} + elif resolved.mode == 'required': + if thinking_enabled: + warnings.warn( + "tool_choice='required' is not supported with Anthropic thinking mode. 
Falling back to 'auto'.", + UserWarning, + stacklevel=6, + ) + tool_choice = {'type': 'auto'} + else: + tool_choice = {'type': 'any'} + + elif resolved.mode == 'none': + if not resolved.output_tools_fallback: + tool_choice = {'type': 'none'} + else: + output_tool_names = [t.name for t in model_request_parameters.output_tools] + if len(output_tool_names) == 1: + tool_choice = {'type': 'tool', 'name': output_tool_names[0]} else: + # Anthropic's tool_choice only supports forcing a single tool via {"type": "tool", "name": "..."} + # See: https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use#forcing-tool-use warnings.warn( 'Anthropic only supports forcing a single tool. ' - "Falling back to 'any' (required) for multiple tools.", + "Falling back to 'auto' for multiple output tools.", UserWarning, stacklevel=6, ) - tool_choice = {'type': 'any'} + tool_choice = {'type': 'auto'} + + elif resolved.mode == 'specific': + if not resolved.tool_names: # pragma: no cover + # tool_names will always be filled out when mode=='specific' i.e. 'specific' will only be set when there are tool names + raise RuntimeError('Internal error: resolved.tool_names is empty for specific tool choice.') + if thinking_enabled: + warnings.warn( + "Forcing specific tools is not supported with Anthropic thinking mode. Falling back to 'auto'.", + UserWarning, + stacklevel=6, + ) + tool_choice = {'type': 'auto'} + elif len(resolved.tool_names) == 1: + tool_choice = {'type': 'tool', 'name': resolved.tool_names[0]} else: - assert_never(resolved.mode) + warnings.warn( + 'Anthropic only supports forcing a single tool. ' + "Falling back to 'any' (required) for multiple function tools.", + UserWarning, + stacklevel=6, + ) + tool_choice = {'type': 'any'} else: - # Default behavior: infer from allow_text_output - if not model_request_parameters.allow_text_output: - tool_choice = {'type': 'any'} - else: - tool_choice = {'type': 'auto'} + assert_never(resolved.mode) if 'parallel_tool_calls' in model_settings and tool_choice.get('type') != 'none': # only `BetaToolChoiceNoneParam` doesn't have this field diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 3a1f6d043f..4837400e96 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -498,41 +498,44 @@ def _map_tool_config( resolved = resolve_tool_choice(model_settings, model_request_parameters) tool_choice: ToolChoiceTypeDef - if resolved is not None: - if resolved.mode == 'none': + if resolved is None: + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output: + tool_choice = {'any': {}} + else: + tool_choice = {'auto': {}} + + elif resolved.mode == 'auto': + tool_choice = {'auto': {}} + + elif resolved.mode == 'required': + tool_choice = {'any': {}} + + elif resolved.mode == 'none': + warnings.warn( + "Bedrock does not support tool_choice='none'. Falling back to 'auto'.", + UserWarning, + stacklevel=6, + ) + tool_choice = {'auto': {}} + + elif resolved.mode == 'specific': + if not resolved.tool_names: # pragma: no cover + # tool_names will always be filled out when mode=='specific' i.e. 'specific' will only be set when there are tool names + raise RuntimeError('Internal error: resolved.tool_names is empty for specific tool choice.') + if len(resolved.tool_names) == 1: + tool_choice = {'tool': {'name': resolved.tool_names[0]}} + else: warnings.warn( - "Bedrock does not support tool_choice='none'. 
Falling back to 'auto'.", + 'Bedrock only supports forcing a single tool. ' + "Falling back to 'any' (required) for multiple function tools.", UserWarning, stacklevel=6, ) - tool_choice = {'auto': {}} - - elif resolved.mode == 'auto': - tool_choice = {'auto': {}} - - elif resolved.mode == 'required': tool_choice = {'any': {}} - elif resolved.mode == 'specific': - assert resolved.tool_names # Guaranteed non-empty by resolve_tool_choice() - if len(resolved.tool_names) == 1: - tool_choice = {'tool': {'name': resolved.tool_names[0]}} - else: - warnings.warn( - 'Bedrock only supports forcing a single tool. ' - "Falling back to 'any' (required) for multiple tools.", - UserWarning, - stacklevel=6, - ) - tool_choice = {'any': {}} - else: - assert_never(resolved.mode) else: - # Default behavior: infer from allow_text_output - if not model_request_parameters.allow_text_output: - tool_choice = {'any': {}} - else: - tool_choice = {'auto': {}} + assert_never(resolved.mode) tool_config: ToolConfigurationTypeDef = {'tools': tools} if tool_choice and BedrockModelProfile.from_profile(self.profile).bedrock_supports_tool_choice: diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index b48336a4e2..f601585af8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -375,55 +375,57 @@ def _get_tool_config( resolved = resolve_tool_choice(model_settings, model_request_parameters) - if resolved is not None: - if resolved.mode == 'none': - if resolved.output_tools_fallback: - output_tool_names = [t.name for t in model_request_parameters.output_tools] - return ToolConfigDict( - function_calling_config=FunctionCallingConfigDict( - mode=FunctionCallingConfigMode.ANY, - allowed_function_names=output_tool_names, - ) - ) - return ToolConfigDict( - function_calling_config=FunctionCallingConfigDict(mode=FunctionCallingConfigMode.NONE) - ) - - if resolved.mode == 'auto': - return ToolConfigDict( - function_calling_config=FunctionCallingConfigDict(mode=FunctionCallingConfigMode.AUTO) - ) - - if resolved.mode == 'required': + if resolved is None: + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output: names: list[str] = [] for tool in tools: for function_declaration in tool.get('function_declarations') or []: if name := function_declaration.get('name'): # pragma: no branch names.append(name) - return ToolConfigDict( - function_calling_config=FunctionCallingConfigDict( - mode=FunctionCallingConfigMode.ANY, - allowed_function_names=names, - ) - ) + return _tool_config(names) + return None - if resolved.mode == 'specific' and resolved.tool_names: # pragma: no branch - return ToolConfigDict( - function_calling_config=FunctionCallingConfigDict( - mode=FunctionCallingConfigMode.ANY, - allowed_function_names=resolved.tool_names, - ) - ) + if resolved.mode == 'auto': + return ToolConfigDict( + function_calling_config=FunctionCallingConfigDict(mode=FunctionCallingConfigMode.AUTO) + ) - # Default behavior: infer from allow_text_output - if not model_request_parameters.allow_text_output: + if resolved.mode == 'required': names = [] for tool in tools: for function_declaration in tool.get('function_declarations') or []: if name := function_declaration.get('name'): # pragma: no branch names.append(name) - return _tool_config(names) - return None + return ToolConfigDict( + function_calling_config=FunctionCallingConfigDict( + mode=FunctionCallingConfigMode.ANY, + 
allowed_function_names=names, + ) + ) + + if resolved.mode == 'none': + if not resolved.output_tools_fallback: + return ToolConfigDict( + function_calling_config=FunctionCallingConfigDict(mode=FunctionCallingConfigMode.NONE) + ) + output_tool_names = [t.name for t in model_request_parameters.output_tools] + return ToolConfigDict( + function_calling_config=FunctionCallingConfigDict( + mode=FunctionCallingConfigMode.ANY, + allowed_function_names=output_tool_names, + ) + ) + + if resolved.tool_names: + return ToolConfigDict( + function_calling_config=FunctionCallingConfigDict( + mode=FunctionCallingConfigMode.ANY, + allowed_function_names=resolved.tool_names, + ) + ) + + return None # pragma: no cover @overload async def _generate_content( diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index d12f4eadf9..15de3fb066 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -386,40 +386,38 @@ def _get_tool_choice( resolved = resolve_tool_choice(model_settings, model_request_parameters) - if resolved is not None: - if resolved.mode == 'none': - if resolved.output_tools_fallback: - output_tool_names = [t.name for t in model_request_parameters.output_tools] - return ChatCompletionNamedToolChoiceParam( - type='function', - function={'name': output_tool_names[0]}, - ) - return 'none' - - if resolved.mode == 'auto': - return 'auto' - - if resolved.mode == 'required': + if resolved is None: + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output: return 'required' + return 'auto' - if resolved.mode == 'specific' and resolved.tool_names: # pragma: no branch - if len(resolved.tool_names) == 1: - return ChatCompletionNamedToolChoiceParam( - type='function', - function={'name': resolved.tool_names[0]}, - ) - else: - warnings.warn( - "Groq only supports forcing a single tool. Falling back to 'required' for multiple tools.", - UserWarning, - stacklevel=6, - ) - return 'required' + if resolved.mode in ('auto', 'required'): + return resolved.mode - # Default behavior: infer from allow_text_output - if not model_request_parameters.allow_text_output: + if resolved.mode == 'none': + if not resolved.output_tools_fallback: + return 'none' + output_tool_names = [t.name for t in model_request_parameters.output_tools] + return ChatCompletionNamedToolChoiceParam( + type='function', + function={'name': output_tool_names[0]}, + ) + + if resolved.tool_names: + if len(resolved.tool_names) == 1: + return ChatCompletionNamedToolChoiceParam( + type='function', + function={'name': resolved.tool_names[0]}, + ) + warnings.warn( + "Groq only supports forcing a single tool. 
Falling back to 'required' for multiple function tools.", + UserWarning, + stacklevel=6, + ) return 'required' - return 'auto' + + return None # pragma: no cover def _get_builtin_tools( self, model_request_parameters: ModelRequestParameters diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index d3706267a1..83d947afa6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -331,39 +331,37 @@ def _get_tool_choice( resolved = resolve_tool_choice(model_settings, model_request_parameters) - if resolved is not None: - if resolved.mode == 'none': - if resolved.output_tools_fallback: - output_tool_names = [t.name for t in model_request_parameters.output_tools] - return ChatCompletionInputToolChoiceClass( - function=ChatCompletionInputFunctionName(name=output_tool_names[0]) # pyright: ignore[reportCallIssue] - ) - return 'none' - - if resolved.mode == 'auto': - return 'auto' - - if resolved.mode == 'required': + if resolved is None: + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output: return 'required' + return 'auto' - if resolved.mode == 'specific' and resolved.tool_names: # pragma: no branch - if len(resolved.tool_names) == 1: - return ChatCompletionInputToolChoiceClass( - function=ChatCompletionInputFunctionName(name=resolved.tool_names[0]) # pyright: ignore[reportCallIssue] - ) - else: - warnings.warn( - 'HuggingFace only supports forcing a single tool. ' - "Falling back to 'required' for multiple tools.", - UserWarning, - stacklevel=6, - ) - return 'required' + if resolved.mode in ('auto', 'required'): + return resolved.mode - # Default behavior: infer from allow_text_output - if not model_request_parameters.allow_text_output: + if resolved.mode == 'none': + if not resolved.output_tools_fallback: + return 'none' + output_tool_names = [t.name for t in model_request_parameters.output_tools] + return ChatCompletionInputToolChoiceClass( + function=ChatCompletionInputFunctionName(name=output_tool_names[0]) # pyright: ignore[reportCallIssue] + ) + + if resolved.tool_names: + if len(resolved.tool_names) == 1: + return ChatCompletionInputToolChoiceClass( + function=ChatCompletionInputFunctionName(name=resolved.tool_names[0]) # pyright: ignore[reportCallIssue] + ) + warnings.warn( + 'HuggingFace only supports forcing a single tool. 
' + "Falling back to 'required' for multiple function tools.", + UserWarning, + stacklevel=6, + ) return 'required' - return 'auto' + + return None # pragma: no cover async def _map_messages( self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 08d4b3fc5f..590d9ce364 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -331,30 +331,29 @@ def _get_tool_choice( resolved = resolve_tool_choice(model_settings, model_request_parameters) - if resolved is not None: - if resolved.mode == 'none': - if resolved.output_tools_fallback: - return 'required' - return 'none' - - if resolved.mode == 'auto': - return 'auto' - - if resolved.mode == 'required': + if resolved is None: + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output: return 'required' + return 'auto' - if resolved.mode == 'specific': # pragma: no branch - warnings.warn( - "Mistral does not support forcing specific tools. Falling back to 'required'.", - UserWarning, - stacklevel=6, - ) + if resolved.mode in ('auto', 'required'): + return resolved.mode + + if resolved.mode == 'none': + if resolved.output_tools_fallback: return 'required' + return 'none' - # Default behavior: infer from allow_text_output - if not model_request_parameters.allow_text_output: + if resolved.tool_names: + warnings.warn( + "Mistral does not support forcing specific tools. Falling back to 'required'.", + UserWarning, + stacklevel=6, + ) return 'required' - return 'auto' + + return None # pragma: no cover def _map_function_and_output_tools_definition( self, model_request_parameters: ModelRequestParameters diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 658c0fcded..735bdf3e86 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -717,53 +717,50 @@ def _get_tool_choice( resolved = resolve_tool_choice(model_settings, model_request_parameters) - if resolved is not None: - if resolved.mode == 'none': - if resolved.output_tools_fallback: - output_tool_names = [t.name for t in model_request_parameters.output_tools] - if len(output_tool_names) == 1: - return ChatCompletionNamedToolChoiceParam( - type='function', - function={'name': output_tool_names[0]}, - ) - else: - return ChatCompletionAllowedToolChoiceParam( - type='allowed_tools', - allowed_tools=ChatCompletionAllowedToolsParam( - mode='required' if not model_request_parameters.allow_text_output else 'auto', - tools=[{'type': 'function', 'function': {'name': n}} for n in output_tool_names], - ), - ) - return 'none' + if resolved is None: + # Default behavior: infer from allow_text_output + if ( + not model_request_parameters.allow_text_output + and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required + ): + return 'required' + return 'auto' - if resolved.mode == 'auto': - return 'auto' + if resolved.mode in ('auto', 'required'): + return resolved.mode - if resolved.mode == 'required': - return 'required' + if resolved.mode == 'none': + if not resolved.output_tools_fallback: + return 'none' + output_tool_names = [t.name for t in model_request_parameters.output_tools] + if len(output_tool_names) == 1: + return ChatCompletionNamedToolChoiceParam( + type='function', + function={'name': output_tool_names[0]}, + ) + 
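        # Sketch of the Chat Completions allowed_tools payload built below; the tool
        # names and the mode depend on the resolved tools and on allow_text_output
        # (compare the snapshots in tests/models/test_openai.py):
        # {
        #     'type': 'allowed_tools',
        #     'allowed_tools': {
        #         'mode': 'auto' or 'required',
        #         'tools': [{'type': 'function', 'function': {'name': '<tool name>'}}, ...],
        #     },
        # }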
return ChatCompletionAllowedToolChoiceParam( + type='allowed_tools', + allowed_tools=ChatCompletionAllowedToolsParam( + mode='required' if not model_request_parameters.allow_text_output else 'auto', + tools=[{'type': 'function', 'function': {'name': n}} for n in output_tool_names], + ), + ) - if resolved.mode == 'specific' and resolved.tool_names: # pragma: no branch - if len(resolved.tool_names) == 1: - return ChatCompletionNamedToolChoiceParam( - type='function', - function={'name': resolved.tool_names[0]}, - ) - else: - return ChatCompletionAllowedToolChoiceParam( - type='allowed_tools', - allowed_tools=ChatCompletionAllowedToolsParam( - mode='required' if not model_request_parameters.allow_text_output else 'auto', - tools=[{'type': 'function', 'function': {'name': n}} for n in resolved.tool_names], - ), - ) + if resolved.tool_names: + if len(resolved.tool_names) == 1: + return ChatCompletionNamedToolChoiceParam( + type='function', + function={'name': resolved.tool_names[0]}, + ) + return ChatCompletionAllowedToolChoiceParam( + type='allowed_tools', + allowed_tools=ChatCompletionAllowedToolsParam( + mode='required' if not model_request_parameters.allow_text_output else 'auto', + tools=[{'type': 'function', 'function': {'name': n}} for n in resolved.tool_names], + ), + ) - # Default behavior: infer from allow_text_output - if ( - not model_request_parameters.allow_text_output - and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required - ): - return 'required' - return 'auto' + return None # pragma: no cover def _get_web_search_options(self, model_request_parameters: ModelRequestParameters) -> WebSearchOptions | None: for tool in model_request_parameters.builtin_tools: @@ -1489,43 +1486,40 @@ def _get_responses_tool_choice( resolved = resolve_tool_choice(model_settings, model_request_parameters) - if resolved is not None: - if resolved.mode == 'none': - if resolved.output_tools_fallback: - output_tool_names = [t.name for t in model_request_parameters.output_tools] - if len(output_tool_names) == 1: - return ToolChoiceFunctionParam(type='function', name=output_tool_names[0]) - else: - return ToolChoiceAllowedParam( - type='allowed_tools', - mode='required' if not model_request_parameters.allow_text_output else 'auto', - tools=[{'type': 'function', 'name': n} for n in output_tool_names], - ) - return 'none' + if resolved is None: + # Default behavior: infer from allow_text_output + if ( + not model_request_parameters.allow_text_output + and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required + ): + return 'required' + return 'auto' - if resolved.mode == 'auto': - return 'auto' + if resolved.mode in ('auto', 'required'): + return resolved.mode - if resolved.mode == 'required': - return 'required' + if resolved.mode == 'none': + if not resolved.output_tools_fallback: + return 'none' + output_tool_names = [t.name for t in model_request_parameters.output_tools] + if len(output_tool_names) == 1: + return ToolChoiceFunctionParam(type='function', name=output_tool_names[0]) + return ToolChoiceAllowedParam( + type='allowed_tools', + mode='required' if not model_request_parameters.allow_text_output else 'auto', + tools=[{'type': 'function', 'name': n} for n in output_tool_names], + ) - if resolved.mode == 'specific' and resolved.tool_names: # pragma: no branch - if len(resolved.tool_names) == 1: - return ToolChoiceFunctionParam(type='function', name=resolved.tool_names[0]) - else: - return ToolChoiceAllowedParam( - type='allowed_tools', - 
mode='required' if not model_request_parameters.allow_text_output else 'auto', - tools=[{'type': 'function', 'name': n} for n in resolved.tool_names], - ) + if resolved.tool_names: + if len(resolved.tool_names) == 1: + return ToolChoiceFunctionParam(type='function', name=resolved.tool_names[0]) + return ToolChoiceAllowedParam( + type='allowed_tools', + mode='required' if not model_request_parameters.allow_text_output else 'auto', + tools=[{'type': 'function', 'name': n} for n in resolved.tool_names], + ) - # Default behavior: infer from allow_text_output - if ( - not model_request_parameters.allow_text_output - and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required - ): - return 'required' - return 'auto' + return None # pragma: no cover def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.ToolParam]: tools: list[responses.ToolParam] = [] diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index b57cc19504..25cc3483b1 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -8075,7 +8075,7 @@ def my_tool(x: int) -> str: warning_messages = {str(warning.message) for warning in w} assert { "tool_choice='none' is set but output tools are required for structured output. " - 'The output tools will remain available. Consider using native or prompted output modes ' + 'The output tools will remain available. Consider using `NativeOutput` or `PromptedOutput` ' "if you need tool_choice='none' with structured output.", "Anthropic only supports forcing a single tool. Falling back to 'auto' for multiple output tools.", } <= warning_messages From 31bb4e1cb4f5915e932729bc42372580ce36502e Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Fri, 5 Dec 2025 13:59:18 -0500 Subject: [PATCH 8/9] deduplicate openai logic --- pydantic_ai_slim/pydantic_ai/models/openai.py | 157 ++++++++++-------- tests/models/test_openai.py | 42 +++++ 2 files changed, 133 insertions(+), 66 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 735bdf3e86..37d9227286 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -172,6 +172,75 @@ } +@dataclass +class _OpenAIToolChoiceResult: + """Intermediate representation of resolved tool choice for OpenAI APIs.""" + + kind: Literal['literal', 'single_tool', 'allowed_tools'] + """The kind of tool choice result.""" + + value: str | None = None + """For 'literal': the literal value ('auto', 'required', 'none'). For 'single_tool': the tool name.""" + + tool_names: list[str] | None = None + """For 'allowed_tools': the list of allowed tool names.""" + + allowed_mode: Literal['auto', 'required'] = 'required' + """For 'allowed_tools': whether the model must use one of the tools or can choose not to.""" + + +def _resolve_openai_tool_choice( + model_settings: ModelSettings, + model_request_parameters: ModelRequestParameters, + profile: ModelProfile, +) -> _OpenAIToolChoiceResult | None: + """Resolve tool choice settings into an intermediate representation for OpenAI APIs. + + This centralizes the logic shared between Chat Completions and Responses APIs. + Returns None if there are no tools, otherwise returns an _OpenAIToolChoiceResult. 
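
    For example (a sketch, assuming two registered function tools and text output allowed):

        _resolve_openai_tool_choice({'tool_choice': ['tool_a', 'tool_b']}, params, profile)
        # -> _OpenAIToolChoiceResult(kind='allowed_tools', tool_names=['tool_a', 'tool_b'], allowed_mode='auto')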
+ """ + resolved = resolve_tool_choice(model_settings, model_request_parameters) + openai_profile = OpenAIModelProfile.from_profile(profile) + + if resolved is None: + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output and openai_profile.openai_supports_tool_choice_required: + return _OpenAIToolChoiceResult(kind='literal', value='required') + return _OpenAIToolChoiceResult(kind='literal', value='auto') + + if resolved.mode == 'auto': + return _OpenAIToolChoiceResult(kind='literal', value='auto') + + if resolved.mode == 'required': + if not openai_profile.openai_supports_tool_choice_required: + warnings.warn( + "tool_choice='required' is not supported by this model. Falling back to 'auto'.", + UserWarning, + stacklevel=7, + ) + return _OpenAIToolChoiceResult(kind='literal', value='auto') + return _OpenAIToolChoiceResult(kind='literal', value='required') + + if resolved.mode == 'none': + if not resolved.output_tools_fallback: + return _OpenAIToolChoiceResult(kind='literal', value='none') + output_tool_names = [t.name for t in model_request_parameters.output_tools] + allowed_mode: Literal['auto', 'required'] = ( + 'required' if not model_request_parameters.allow_text_output else 'auto' + ) + if len(output_tool_names) == 1: + return _OpenAIToolChoiceResult(kind='single_tool', value=output_tool_names[0]) + return _OpenAIToolChoiceResult(kind='allowed_tools', tool_names=output_tool_names, allowed_mode=allowed_mode) + + if resolved.tool_names: + allowed_mode = 'required' if not model_request_parameters.allow_text_output else 'auto' + if len(resolved.tool_names) == 1: + return _OpenAIToolChoiceResult(kind='single_tool', value=resolved.tool_names[0]) + return _OpenAIToolChoiceResult(kind='allowed_tools', tool_names=resolved.tool_names, allowed_mode=allowed_mode) + + return None # pragma: no cover + + class OpenAIChatModelSettings(ModelSettings, total=False): """Settings used for an OpenAI model request.""" @@ -715,47 +784,21 @@ def _get_tool_choice( if not tools: return None - resolved = resolve_tool_choice(model_settings, model_request_parameters) - - if resolved is None: - # Default behavior: infer from allow_text_output - if ( - not model_request_parameters.allow_text_output - and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required - ): - return 'required' - return 'auto' - - if resolved.mode in ('auto', 'required'): - return resolved.mode - - if resolved.mode == 'none': - if not resolved.output_tools_fallback: - return 'none' - output_tool_names = [t.name for t in model_request_parameters.output_tools] - if len(output_tool_names) == 1: - return ChatCompletionNamedToolChoiceParam( - type='function', - function={'name': output_tool_names[0]}, - ) - return ChatCompletionAllowedToolChoiceParam( - type='allowed_tools', - allowed_tools=ChatCompletionAllowedToolsParam( - mode='required' if not model_request_parameters.allow_text_output else 'auto', - tools=[{'type': 'function', 'function': {'name': n}} for n in output_tool_names], - ), - ) + resolved = _resolve_openai_tool_choice(model_settings, model_request_parameters, self.profile) + if resolved is None: # pragma: no cover + return None - if resolved.tool_names: - if len(resolved.tool_names) == 1: - return ChatCompletionNamedToolChoiceParam( - type='function', - function={'name': resolved.tool_names[0]}, - ) + if resolved.kind == 'literal': + return cast(ChatCompletionToolChoiceOptionParam, resolved.value) + + if resolved.kind == 'single_tool': + return 
ChatCompletionNamedToolChoiceParam(type='function', function={'name': resolved.value or ''}) + + if resolved.kind == 'allowed_tools' and resolved.tool_names: return ChatCompletionAllowedToolChoiceParam( type='allowed_tools', allowed_tools=ChatCompletionAllowedToolsParam( - mode='required' if not model_request_parameters.allow_text_output else 'auto', + mode=resolved.allowed_mode, tools=[{'type': 'function', 'function': {'name': n}} for n in resolved.tool_names], ), ) @@ -1484,38 +1527,20 @@ def _get_responses_tool_choice( if not tools: return None - resolved = resolve_tool_choice(model_settings, model_request_parameters) - - if resolved is None: - # Default behavior: infer from allow_text_output - if ( - not model_request_parameters.allow_text_output - and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required - ): - return 'required' - return 'auto' - - if resolved.mode in ('auto', 'required'): - return resolved.mode - - if resolved.mode == 'none': - if not resolved.output_tools_fallback: - return 'none' - output_tool_names = [t.name for t in model_request_parameters.output_tools] - if len(output_tool_names) == 1: - return ToolChoiceFunctionParam(type='function', name=output_tool_names[0]) - return ToolChoiceAllowedParam( - type='allowed_tools', - mode='required' if not model_request_parameters.allow_text_output else 'auto', - tools=[{'type': 'function', 'name': n} for n in output_tool_names], - ) + resolved = _resolve_openai_tool_choice(model_settings, model_request_parameters, self.profile) + if resolved is None: # pragma: no cover + return None + + if resolved.kind == 'literal': + return cast(ResponsesToolChoice, resolved.value) + + if resolved.kind == 'single_tool': + return ToolChoiceFunctionParam(type='function', name=resolved.value or '') - if resolved.tool_names: - if len(resolved.tool_names) == 1: - return ToolChoiceFunctionParam(type='function', name=resolved.tool_names[0]) + if resolved.kind == 'allowed_tools' and resolved.tool_names: return ToolChoiceAllowedParam( type='allowed_tools', - mode='required' if not model_request_parameters.allow_text_output else 'auto', + mode=resolved.allowed_mode, tools=[{'type': 'function', 'name': n} for n in resolved.tool_names], ) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index cfb8b8181f..fd4501c60f 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -3131,6 +3131,48 @@ async def test_tool_choice_fallback_response_api(allow_model_requests: None) -> assert get_mock_responses_kwargs(mock_client)[0]['tool_choice'] == 'auto' +async def test_tool_choice_required_explicit_unsupported(allow_model_requests: None) -> None: + """Ensure explicit tool_choice='required' warns and falls back to 'auto' when unsupported.""" + profile = OpenAIModelProfile(openai_supports_tool_choice_required=False).update(openai_model_profile('stub')) + + mock_client = MockOpenAI.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) + model = OpenAIChatModel('stub', provider=OpenAIProvider(openai_client=mock_client), profile=profile) + + params = ModelRequestParameters(function_tools=[ToolDefinition(name='x')], allow_text_output=True) + settings: OpenAIChatModelSettings = {'tool_choice': 'required'} + + with pytest.warns(UserWarning, match="tool_choice='required' is not supported by this model"): + await model._completions_create( # pyright: ignore[reportPrivateUsage] + messages=[], + stream=False, + model_settings=settings, + model_request_parameters=params, + 
) + + assert get_mock_chat_completion_kwargs(mock_client)[0]['tool_choice'] == 'auto' + + +async def test_tool_choice_required_explicit_unsupported_responses_api(allow_model_requests: None) -> None: + """Ensure explicit tool_choice='required' warns and falls back for Responses API when unsupported.""" + profile = OpenAIModelProfile(openai_supports_tool_choice_required=False).update(openai_model_profile('stub')) + + mock_client = MockOpenAIResponses.create_mock(response_message([])) + model = OpenAIResponsesModel('openai/gpt-oss', provider=OpenAIProvider(openai_client=mock_client), profile=profile) + + params = ModelRequestParameters(function_tools=[ToolDefinition(name='x')], allow_text_output=True) + settings: OpenAIResponsesModelSettings = {'tool_choice': 'required'} + + with pytest.warns(UserWarning, match="tool_choice='required' is not supported by this model"): + await model._responses_create( # pyright: ignore[reportPrivateUsage] + messages=[], + stream=False, + model_settings=settings, + model_request_parameters=params, + ) + + assert get_mock_responses_kwargs(mock_client)[0]['tool_choice'] == 'auto' + + async def test_openai_model_settings_temperature_ignored_on_gpt_5(allow_model_requests: None, openai_api_key: str): m = OpenAIChatModel('gpt-5', provider=OpenAIProvider(api_key=openai_api_key)) agent = Agent(m) From 338a073a55e39bfb43c61d684205ccf938ecc1a8 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Fri, 5 Dec 2025 15:38:27 -0500 Subject: [PATCH 9/9] remove cast --- pydantic_ai_slim/pydantic_ai/models/openai.py | 35 ++++++++++--------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 37d9227286..991632da74 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -179,14 +179,17 @@ class _OpenAIToolChoiceResult: kind: Literal['literal', 'single_tool', 'allowed_tools'] """The kind of tool choice result.""" - value: str | None = None - """For 'literal': the literal value ('auto', 'required', 'none'). 
For 'single_tool': the tool name.""" + literal_value: Literal['auto', 'required', 'none'] | None = None + """For 'literal' kind: the literal value to pass to the API.""" + + tool_name: str | None = None + """For 'single_tool' kind: the specific tool name.""" tool_names: list[str] | None = None - """For 'allowed_tools': the list of allowed tool names.""" + """For 'allowed_tools' kind: the list of allowed tool names.""" allowed_mode: Literal['auto', 'required'] = 'required' - """For 'allowed_tools': whether the model must use one of the tools or can choose not to.""" + """For 'allowed_tools' kind: whether the model must use one of the tools or can choose not to.""" def _resolve_openai_tool_choice( @@ -205,11 +208,11 @@ def _resolve_openai_tool_choice( if resolved is None: # Default behavior: infer from allow_text_output if not model_request_parameters.allow_text_output and openai_profile.openai_supports_tool_choice_required: - return _OpenAIToolChoiceResult(kind='literal', value='required') - return _OpenAIToolChoiceResult(kind='literal', value='auto') + return _OpenAIToolChoiceResult(kind='literal', literal_value='required') + return _OpenAIToolChoiceResult(kind='literal', literal_value='auto') if resolved.mode == 'auto': - return _OpenAIToolChoiceResult(kind='literal', value='auto') + return _OpenAIToolChoiceResult(kind='literal', literal_value='auto') if resolved.mode == 'required': if not openai_profile.openai_supports_tool_choice_required: @@ -218,24 +221,24 @@ def _resolve_openai_tool_choice( UserWarning, stacklevel=7, ) - return _OpenAIToolChoiceResult(kind='literal', value='auto') - return _OpenAIToolChoiceResult(kind='literal', value='required') + return _OpenAIToolChoiceResult(kind='literal', literal_value='auto') + return _OpenAIToolChoiceResult(kind='literal', literal_value='required') if resolved.mode == 'none': if not resolved.output_tools_fallback: - return _OpenAIToolChoiceResult(kind='literal', value='none') + return _OpenAIToolChoiceResult(kind='literal', literal_value='none') output_tool_names = [t.name for t in model_request_parameters.output_tools] allowed_mode: Literal['auto', 'required'] = ( 'required' if not model_request_parameters.allow_text_output else 'auto' ) if len(output_tool_names) == 1: - return _OpenAIToolChoiceResult(kind='single_tool', value=output_tool_names[0]) + return _OpenAIToolChoiceResult(kind='single_tool', tool_name=output_tool_names[0]) return _OpenAIToolChoiceResult(kind='allowed_tools', tool_names=output_tool_names, allowed_mode=allowed_mode) if resolved.tool_names: allowed_mode = 'required' if not model_request_parameters.allow_text_output else 'auto' if len(resolved.tool_names) == 1: - return _OpenAIToolChoiceResult(kind='single_tool', value=resolved.tool_names[0]) + return _OpenAIToolChoiceResult(kind='single_tool', tool_name=resolved.tool_names[0]) return _OpenAIToolChoiceResult(kind='allowed_tools', tool_names=resolved.tool_names, allowed_mode=allowed_mode) return None # pragma: no cover @@ -789,10 +792,10 @@ def _get_tool_choice( return None if resolved.kind == 'literal': - return cast(ChatCompletionToolChoiceOptionParam, resolved.value) + return resolved.literal_value if resolved.kind == 'single_tool': - return ChatCompletionNamedToolChoiceParam(type='function', function={'name': resolved.value or ''}) + return ChatCompletionNamedToolChoiceParam(type='function', function={'name': resolved.tool_name or ''}) if resolved.kind == 'allowed_tools' and resolved.tool_names: return ChatCompletionAllowedToolChoiceParam( @@ -1532,10 +1535,10 @@ def 
_get_responses_tool_choice( return None if resolved.kind == 'literal': - return cast(ResponsesToolChoice, resolved.value) + return resolved.literal_value if resolved.kind == 'single_tool': - return ToolChoiceFunctionParam(type='function', name=resolved.value or '') + return ToolChoiceFunctionParam(type='function', name=resolved.tool_name or '') if resolved.kind == 'allowed_tools' and resolved.tool_names: return ToolChoiceAllowedParam(