diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py
index 255ce07ede..340081bbec 100644
--- a/pydantic_ai_slim/pydantic_ai/models/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -546,6 +546,86 @@ def prompted_output_instructions(self) -> str | None:
     __repr__ = _utils.dataclasses_no_defaults_repr
 
 
+@dataclass
+class ResolvedToolChoice:
+    """Provider-agnostic resolved tool choice.
+
+    This is the result of validating and resolving the user's `tool_choice` setting.
+    Providers should map this to their API-specific format.
+    """
+
+    mode: Literal['none', 'auto', 'required', 'specific']
+    """The resolved tool choice mode."""
+
+    tool_names: list[str] = field(default_factory=list)
+    """For 'specific' mode, the list of tool names to force. Empty for other modes."""
+
+    output_tools_fallback: bool = False
+    """True if we need to fall back to output tools only (when 'none' was requested but output tools exist)."""
+
+
+_TOOL_CHOICE_NONE_WITH_OUTPUT_TOOLS_WARNING = (
+    "tool_choice='none' is set but output tools are required for structured output. "
+    'The output tools will remain available. Consider using `NativeOutput` or `PromptedOutput` '
+    "if you need tool_choice='none' with structured output."
+)
+
+
+# NOTE: for PR discussion: should this be a private method? a Model.method? Perhaps a ModelRequestParameters.method?
+def resolve_tool_choice(
+    model_settings: ModelSettings | None,
+    model_request_parameters: ModelRequestParameters,
+    *,
+    stacklevel: int = 6,
+) -> ResolvedToolChoice | None:
+    """Resolve and validate tool_choice from model settings.
+
+    This centralizes the common logic for handling tool_choice across all providers:
+    - Validates tool names in list[str] against available function_tools
+    - Issues warnings for conflicting settings (tool_choice='none' with output tools)
+    - Returns a provider-agnostic ResolvedToolChoice for the provider to map to their API format
+
+    Args:
+        model_settings: The model settings containing tool_choice.
+        model_request_parameters: The request parameters containing tool definitions.
+        stacklevel: The stack level for warnings (default 6 works for most provider call stacks).
+
+    Returns:
+        ResolvedToolChoice if an explicit tool_choice was provided and validated,
+        None if tool_choice was not set (provider should use default behavior based on allow_text_output).
+
+    Raises:
+        UserError: If tool names in list[str] are invalid.
+    """
+    user_tool_choice = (model_settings or {}).get('tool_choice')
+
+    if user_tool_choice is None:
+        return None
+
+    if user_tool_choice == 'none':
+        if model_request_parameters.output_tools:
+            warnings.warn(_TOOL_CHOICE_NONE_WITH_OUTPUT_TOOLS_WARNING, UserWarning, stacklevel=stacklevel)
+            return ResolvedToolChoice(mode='none', output_tools_fallback=True)
+        return ResolvedToolChoice(mode='none')
+
+    if user_tool_choice in ('auto', 'required'):
+        return ResolvedToolChoice(mode=user_tool_choice)
+
+    if isinstance(user_tool_choice, list):
+        if not user_tool_choice:
+            raise UserError('tool_choice cannot be an empty list. Use None for default behavior.')
+        function_tool_names = {t.name for t in model_request_parameters.function_tools}
+        invalid_names = set(user_tool_choice) - function_tool_names
+        if invalid_names:
+            raise UserError(
+                f'Invalid tool names in tool_choice: {invalid_names}. 
' + f'Available function tools: {function_tool_names or "none"}' + ) + return ResolvedToolChoice(mode='specific', tool_names=list(user_tool_choice)) + + return None # pragma: no cover + + class Model(ABC): """Abstract class for a model.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 586f9762f1..24cf6fb4f4 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations import io +import warnings from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass, field, replace @@ -42,7 +43,15 @@ from ..providers.anthropic import AsyncAnthropicClient from ..settings import ModelSettings, merge_model_settings from ..tools import ToolDefinition -from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent +from . import ( + Model, + ModelRequestParameters, + StreamedResponse, + check_allow_model_requests, + download_item, + get_user_agent, + resolve_tool_choice, +) _FINISH_REASON_MAP: dict[BetaStopReason, FinishReason] = { 'end_turn': 'stop', @@ -643,18 +652,81 @@ def _infer_tool_choice( ) -> BetaToolChoiceParam | None: if not tools: return None - else: - tool_choice: BetaToolChoiceParam + thinking_enabled = model_settings.get('anthropic_thinking') is not None + tool_choice: BetaToolChoiceParam + + resolved = resolve_tool_choice(model_settings, model_request_parameters) + + if resolved is None: + # Default behavior: infer from allow_text_output if not model_request_parameters.allow_text_output: tool_choice = {'type': 'any'} else: tool_choice = {'type': 'auto'} - if 'parallel_tool_calls' in model_settings: - tool_choice['disable_parallel_tool_use'] = not model_settings['parallel_tool_calls'] + elif resolved.mode == 'auto': + tool_choice = {'type': 'auto'} + + elif resolved.mode == 'required': + if thinking_enabled: + warnings.warn( + "tool_choice='required' is not supported with Anthropic thinking mode. Falling back to 'auto'.", + UserWarning, + stacklevel=6, + ) + tool_choice = {'type': 'auto'} + else: + tool_choice = {'type': 'any'} + + elif resolved.mode == 'none': + if not resolved.output_tools_fallback: + tool_choice = {'type': 'none'} + else: + output_tool_names = [t.name for t in model_request_parameters.output_tools] + if len(output_tool_names) == 1: + tool_choice = {'type': 'tool', 'name': output_tool_names[0]} + else: + # Anthropic's tool_choice only supports forcing a single tool via {"type": "tool", "name": "..."} + # See: https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use#forcing-tool-use + warnings.warn( + 'Anthropic only supports forcing a single tool. ' + "Falling back to 'auto' for multiple output tools.", + UserWarning, + stacklevel=6, + ) + tool_choice = {'type': 'auto'} + + elif resolved.mode == 'specific': + if not resolved.tool_names: # pragma: no cover + # tool_names will always be filled out when mode=='specific' i.e. 'specific' will only be set when there are tool names + raise RuntimeError('Internal error: resolved.tool_names is empty for specific tool choice.') + if thinking_enabled: + warnings.warn( + "Forcing specific tools is not supported with Anthropic thinking mode. 
Falling back to 'auto'.", + UserWarning, + stacklevel=6, + ) + tool_choice = {'type': 'auto'} + elif len(resolved.tool_names) == 1: + tool_choice = {'type': 'tool', 'name': resolved.tool_names[0]} + else: + warnings.warn( + 'Anthropic only supports forcing a single tool. ' + "Falling back to 'any' (required) for multiple function tools.", + UserWarning, + stacklevel=6, + ) + tool_choice = {'type': 'any'} + + else: + assert_never(resolved.mode) + + if 'parallel_tool_calls' in model_settings and tool_choice.get('type') != 'none': + # only `BetaToolChoiceNoneParam` doesn't have this field + tool_choice['disable_parallel_tool_use'] = not model_settings['parallel_tool_calls'] # pyright: ignore[reportGeneralTypeIssues] - return tool_choice + return tool_choice async def _map_message( # noqa: C901 self, @@ -859,9 +931,10 @@ async def _map_message( # noqa: C901 system_prompt_parts.insert(0, instructions) system_prompt = '\n\n'.join(system_prompt_parts) + ttl: Literal['5m', '1h'] # Add cache_control to the last message content if anthropic_cache_messages is enabled if anthropic_messages and (cache_messages := model_settings.get('anthropic_cache_messages')): - ttl: Literal['5m', '1h'] = '5m' if cache_messages is True else cache_messages + ttl = '5m' if cache_messages is True else cache_messages m = anthropic_messages[-1] content = m['content'] if isinstance(content, str): @@ -881,7 +954,7 @@ async def _map_message( # noqa: C901 # If anthropic_cache_instructions is enabled, return system prompt as a list with cache_control if system_prompt and (cache_instructions := model_settings.get('anthropic_cache_instructions')): # If True, use '5m'; otherwise use the specified ttl value - ttl: Literal['5m', '1h'] = '5m' if cache_instructions is True else cache_instructions + ttl = '5m' if cache_instructions is True else cache_instructions system_prompt_blocks = [ BetaTextBlockParam( type='text', diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index ff03460904..4837400e96 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -2,6 +2,7 @@ import functools import typing +import warnings from collections.abc import AsyncIterator, Iterable, Iterator, Mapping from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -41,7 +42,7 @@ ) from pydantic_ai._run_context import RunContext from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UserError -from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item +from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item, resolve_tool_choice from pydantic_ai.providers import Provider, infer_provider from pydantic_ai.providers.bedrock import BedrockModelProfile from pydantic_ai.settings import ModelSettings @@ -429,7 +430,7 @@ async def _messages_create( 'inferenceConfig': inference_config, } - tool_config = self._map_tool_config(model_request_parameters) + tool_config = self._map_tool_config(model_request_parameters, model_settings) if tool_config: params['toolConfig'] = tool_config @@ -485,17 +486,57 @@ def _map_inference_config( return inference_config - def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> ToolConfigurationTypeDef | None: + def _map_tool_config( + self, + model_request_parameters: ModelRequestParameters, + model_settings: BedrockModelSettings | None, + ) -> ToolConfigurationTypeDef | None: tools = 
self._get_tools(model_request_parameters) if not tools: return None + resolved = resolve_tool_choice(model_settings, model_request_parameters) tool_choice: ToolChoiceTypeDef - if not model_request_parameters.allow_text_output: + + if resolved is None: + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output: + tool_choice = {'any': {}} + else: + tool_choice = {'auto': {}} + + elif resolved.mode == 'auto': + tool_choice = {'auto': {}} + + elif resolved.mode == 'required': tool_choice = {'any': {}} - else: + + elif resolved.mode == 'none': + warnings.warn( + "Bedrock does not support tool_choice='none'. Falling back to 'auto'.", + UserWarning, + stacklevel=6, + ) tool_choice = {'auto': {}} + elif resolved.mode == 'specific': + if not resolved.tool_names: # pragma: no cover + # tool_names will always be filled out when mode=='specific' i.e. 'specific' will only be set when there are tool names + raise RuntimeError('Internal error: resolved.tool_names is empty for specific tool choice.') + if len(resolved.tool_names) == 1: + tool_choice = {'tool': {'name': resolved.tool_names[0]}} + else: + warnings.warn( + 'Bedrock only supports forcing a single tool. ' + "Falling back to 'any' (required) for multiple function tools.", + UserWarning, + stacklevel=6, + ) + tool_choice = {'any': {}} + + else: + assert_never(resolved.mode) + tool_config: ToolConfigurationTypeDef = {'tools': tools} if tool_choice and BedrockModelProfile.from_profile(self.profile).bedrock_supports_tool_choice: tool_config['toolChoice'] = tool_choice diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 89290ea3ce..f601585af8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -49,6 +49,7 @@ check_allow_model_requests, download_item, get_user_agent, + resolve_tool_choice, ) try: @@ -364,17 +365,67 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T return tools or None def _get_tool_config( - self, model_request_parameters: ModelRequestParameters, tools: list[ToolDict] | None + self, + model_request_parameters: ModelRequestParameters, + tools: list[ToolDict] | None, + model_settings: GoogleModelSettings, ) -> ToolConfigDict | None: - if not model_request_parameters.allow_text_output and tools: - names: list[str] = [] + if not tools: + return None + + resolved = resolve_tool_choice(model_settings, model_request_parameters) + + if resolved is None: + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output: + names: list[str] = [] + for tool in tools: + for function_declaration in tool.get('function_declarations') or []: + if name := function_declaration.get('name'): # pragma: no branch + names.append(name) + return _tool_config(names) + return None + + if resolved.mode == 'auto': + return ToolConfigDict( + function_calling_config=FunctionCallingConfigDict(mode=FunctionCallingConfigMode.AUTO) + ) + + if resolved.mode == 'required': + names = [] for tool in tools: for function_declaration in tool.get('function_declarations') or []: if name := function_declaration.get('name'): # pragma: no branch names.append(name) - return _tool_config(names) - else: - return None + return ToolConfigDict( + function_calling_config=FunctionCallingConfigDict( + mode=FunctionCallingConfigMode.ANY, + allowed_function_names=names, + ) + ) + + if resolved.mode == 'none': + if not resolved.output_tools_fallback: + 
return ToolConfigDict( + function_calling_config=FunctionCallingConfigDict(mode=FunctionCallingConfigMode.NONE) + ) + output_tool_names = [t.name for t in model_request_parameters.output_tools] + return ToolConfigDict( + function_calling_config=FunctionCallingConfigDict( + mode=FunctionCallingConfigMode.ANY, + allowed_function_names=output_tool_names, + ) + ) + + if resolved.tool_names: + return ToolConfigDict( + function_calling_config=FunctionCallingConfigDict( + mode=FunctionCallingConfigMode.ANY, + allowed_function_names=resolved.tool_names, + ) + ) + + return None # pragma: no cover @overload async def _generate_content( @@ -440,7 +491,7 @@ async def _build_content_and_config( raise UserError('JSON output is not supported by this model.') response_mime_type = 'application/json' - tool_config = self._get_tool_config(model_request_parameters, tools) + tool_config = self._get_tool_config(model_request_parameters, tools, model_settings) system_instruction, contents = await self._map_messages(messages, model_request_parameters) modalities = [Modality.TEXT.value] diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 64f7ddcf85..15de3fb066 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import warnings from collections.abc import AsyncIterable, AsyncIterator, Iterable from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -49,6 +50,7 @@ StreamedResponse, check_allow_model_requests, get_user_agent, + resolve_tool_choice, ) try: @@ -56,6 +58,8 @@ from groq.types import chat from groq.types.chat.chat_completion_content_part_image_param import ImageURL from groq.types.chat.chat_completion_message import ExecutedTool + from groq.types.chat.chat_completion_named_tool_choice_param import ChatCompletionNamedToolChoiceParam + from groq.types.chat.chat_completion_tool_choice_option_param import ChatCompletionToolChoiceOptionParam except ImportError as _import_error: raise ImportError( 'Please install `groq` to use the Groq model, ' @@ -265,12 +269,7 @@ async def _completions_create( ) -> chat.ChatCompletion | AsyncStream[chat.ChatCompletionChunk]: tools = self._get_tools(model_request_parameters) tools += self._get_builtin_tools(model_request_parameters) - if not tools: - tool_choice: Literal['none', 'required', 'auto'] | None = None - elif not model_request_parameters.allow_text_output: - tool_choice = 'required' - else: - tool_choice = 'auto' + tool_choice = self._get_tool_choice(tools, model_settings, model_request_parameters) groq_messages = self._map_messages(messages, model_request_parameters) @@ -376,6 +375,50 @@ async def _process_streamed_response( def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]: return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] + def _get_tool_choice( + self, + tools: list[chat.ChatCompletionToolParam], + model_settings: GroqModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> ChatCompletionToolChoiceOptionParam | None: + if not tools: + return None + + resolved = resolve_tool_choice(model_settings, model_request_parameters) + + if resolved is None: + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output: + return 'required' + return 'auto' + + if resolved.mode in ('auto', 'required'): + return 
resolved.mode + + if resolved.mode == 'none': + if not resolved.output_tools_fallback: + return 'none' + output_tool_names = [t.name for t in model_request_parameters.output_tools] + return ChatCompletionNamedToolChoiceParam( + type='function', + function={'name': output_tool_names[0]}, + ) + + if resolved.tool_names: + if len(resolved.tool_names) == 1: + return ChatCompletionNamedToolChoiceParam( + type='function', + function={'name': resolved.tool_names[0]}, + ) + warnings.warn( + "Groq only supports forcing a single tool. Falling back to 'required' for multiple function tools.", + UserWarning, + stacklevel=6, + ) + return 'required' + + return None # pragma: no cover + def _get_builtin_tools( self, model_request_parameters: ModelRequestParameters ) -> list[chat.ChatCompletionToolParam]: diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 790b30bec3..83d947afa6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import warnings from collections.abc import AsyncIterable, AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -46,16 +47,19 @@ ModelRequestParameters, StreamedResponse, check_allow_model_requests, + resolve_tool_choice, ) try: import aiohttp from huggingface_hub import ( AsyncInferenceClient, + ChatCompletionInputFunctionName, ChatCompletionInputMessage, ChatCompletionInputMessageChunk, ChatCompletionInputTool, ChatCompletionInputToolCall, + ChatCompletionInputToolChoiceClass, ChatCompletionInputURL, ChatCompletionOutput, ChatCompletionOutputMessage, @@ -221,13 +225,7 @@ async def _completions_create( model_request_parameters: ModelRequestParameters, ) -> ChatCompletionOutput | AsyncIterable[ChatCompletionStreamOutput]: tools = self._get_tools(model_request_parameters) - - if not tools: - tool_choice: Literal['none', 'required', 'auto'] | None = None - elif not model_request_parameters.allow_text_output: - tool_choice = 'required' - else: - tool_choice = 'auto' + tool_choice = self._get_tool_choice(tools, model_settings, model_request_parameters) if model_request_parameters.builtin_tools: raise UserError('HuggingFace does not support built-in tools') @@ -322,6 +320,49 @@ async def _process_streamed_response( def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]: return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] + def _get_tool_choice( + self, + tools: list[ChatCompletionInputTool], + model_settings: HuggingFaceModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> Literal['none', 'required', 'auto'] | ChatCompletionInputToolChoiceClass | None: + if not tools: + return None + + resolved = resolve_tool_choice(model_settings, model_request_parameters) + + if resolved is None: + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output: + return 'required' + return 'auto' + + if resolved.mode in ('auto', 'required'): + return resolved.mode + + if resolved.mode == 'none': + if not resolved.output_tools_fallback: + return 'none' + output_tool_names = [t.name for t in model_request_parameters.output_tools] + return ChatCompletionInputToolChoiceClass( + function=ChatCompletionInputFunctionName(name=output_tool_names[0]) # pyright: ignore[reportCallIssue] + ) + + if 
resolved.tool_names: + if len(resolved.tool_names) == 1: + return ChatCompletionInputToolChoiceClass( + function=ChatCompletionInputFunctionName(name=resolved.tool_names[0]) # pyright: ignore[reportCallIssue] + ) + warnings.warn( + 'HuggingFace only supports forcing a single tool. ' + "Falling back to 'required' for multiple function tools.", + UserWarning, + stacklevel=6, + ) + return 'required' + + return None # pragma: no cover + async def _map_messages( self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters ) -> list[ChatCompletionInputMessage | ChatCompletionOutputMessage]: diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index cefa28e9dc..590d9ce364 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import warnings from collections.abc import AsyncIterable, AsyncIterator, Iterable from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -47,6 +48,7 @@ StreamedResponse, check_allow_model_requests, get_user_agent, + resolve_tool_choice, ) try: @@ -233,7 +235,7 @@ async def _completions_create( messages=self._map_messages(messages, model_request_parameters), n=1, tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET, - tool_choice=self._get_tool_choice(model_request_parameters), + tool_choice=self._get_tool_choice(model_request_parameters, model_settings), stream=False, max_tokens=model_settings.get('max_tokens', UNSET), temperature=model_settings.get('temperature', UNSET), @@ -273,7 +275,7 @@ async def _stream_completions_create( messages=mistral_messages, n=1, tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET, - tool_choice=self._get_tool_choice(model_request_parameters), + tool_choice=self._get_tool_choice(model_request_parameters, model_settings), temperature=model_settings.get('temperature', UNSET), top_p=model_settings.get('top_p', 1), max_tokens=model_settings.get('max_tokens', UNSET), @@ -312,7 +314,11 @@ async def _stream_completions_create( assert response, 'A unexpected empty response from Mistral.' return response - def _get_tool_choice(self, model_request_parameters: ModelRequestParameters) -> MistralToolChoiceEnum | None: + def _get_tool_choice( + self, + model_request_parameters: ModelRequestParameters, + model_settings: MistralModelSettings, + ) -> MistralToolChoiceEnum | None: """Get tool choice for the model. - "auto": Default mode. Model decides if it uses the tool or not. @@ -322,11 +328,33 @@ def _get_tool_choice(self, model_request_parameters: ModelRequestParameters) -> """ if not model_request_parameters.function_tools and not model_request_parameters.output_tools: return None - elif not model_request_parameters.allow_text_output: - return 'required' - else: + + resolved = resolve_tool_choice(model_settings, model_request_parameters) + + if resolved is None: + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output: + return 'required' return 'auto' + if resolved.mode in ('auto', 'required'): + return resolved.mode + + if resolved.mode == 'none': + if resolved.output_tools_fallback: + return 'required' + return 'none' + + if resolved.tool_names: + warnings.warn( + "Mistral does not support forcing specific tools. 
Falling back to 'required'.", + UserWarning, + stacklevel=6, + ) + return 'required' + + return None # pragma: no cover + def _map_function_and_output_tools_definition( self, model_request_parameters: ModelRequestParameters ) -> list[MistralTool] | None: diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 3c5c184a76..991632da74 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -52,7 +52,15 @@ from ..providers import Provider, infer_provider from ..settings import ModelSettings from ..tools import ToolDefinition -from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent +from . import ( + Model, + ModelRequestParameters, + StreamedResponse, + check_allow_model_requests, + download_item, + get_user_agent, + resolve_tool_choice, +) try: from openai import NOT_GIVEN, APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream @@ -67,6 +75,8 @@ chat_completion_chunk, chat_completion_token_logprob, ) + from openai.types.chat.chat_completion_allowed_tool_choice_param import ChatCompletionAllowedToolChoiceParam + from openai.types.chat.chat_completion_allowed_tools_param import ChatCompletionAllowedToolsParam from openai.types.chat.chat_completion_content_part_image_param import ImageURL from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio from openai.types.chat.chat_completion_content_part_param import File, FileFile @@ -75,16 +85,21 @@ from openai.types.chat.chat_completion_message_function_tool_call_param import ( ChatCompletionMessageFunctionToolCallParam, ) + from openai.types.chat.chat_completion_named_tool_choice_param import ChatCompletionNamedToolChoiceParam from openai.types.chat.chat_completion_prediction_content_param import ChatCompletionPredictionContentParam + from openai.types.chat.chat_completion_tool_choice_option_param import ChatCompletionToolChoiceOptionParam from openai.types.chat.completion_create_params import ( WebSearchOptions, WebSearchOptionsUserLocation, WebSearchOptionsUserLocationApproximate, ) from openai.types.responses import ComputerToolParam, FileSearchToolParam, WebSearchToolParam + from openai.types.responses.response_create_params import ToolChoice as ResponsesToolChoice from openai.types.responses.response_input_param import FunctionCallOutput, Message from openai.types.responses.response_reasoning_item_param import Summary from openai.types.responses.response_status import ResponseStatus + from openai.types.responses.tool_choice_allowed_param import ToolChoiceAllowedParam + from openai.types.responses.tool_choice_function_param import ToolChoiceFunctionParam from openai.types.shared import ReasoningEffort from openai.types.shared_params import Reasoning except ImportError as _import_error: @@ -157,6 +172,78 @@ } +@dataclass +class _OpenAIToolChoiceResult: + """Intermediate representation of resolved tool choice for OpenAI APIs.""" + + kind: Literal['literal', 'single_tool', 'allowed_tools'] + """The kind of tool choice result.""" + + literal_value: Literal['auto', 'required', 'none'] | None = None + """For 'literal' kind: the literal value to pass to the API.""" + + tool_name: str | None = None + """For 'single_tool' kind: the specific tool name.""" + + tool_names: list[str] | None = None + """For 'allowed_tools' kind: the list of allowed tool names.""" + + allowed_mode: Literal['auto', 'required'] = 'required' + """For 'allowed_tools' 
kind: whether the model must use one of the tools or can choose not to.""" + + +def _resolve_openai_tool_choice( + model_settings: ModelSettings, + model_request_parameters: ModelRequestParameters, + profile: ModelProfile, +) -> _OpenAIToolChoiceResult | None: + """Resolve tool choice settings into an intermediate representation for OpenAI APIs. + + This centralizes the logic shared between Chat Completions and Responses APIs. + Returns None if there are no tools, otherwise returns an _OpenAIToolChoiceResult. + """ + resolved = resolve_tool_choice(model_settings, model_request_parameters) + openai_profile = OpenAIModelProfile.from_profile(profile) + + if resolved is None: + # Default behavior: infer from allow_text_output + if not model_request_parameters.allow_text_output and openai_profile.openai_supports_tool_choice_required: + return _OpenAIToolChoiceResult(kind='literal', literal_value='required') + return _OpenAIToolChoiceResult(kind='literal', literal_value='auto') + + if resolved.mode == 'auto': + return _OpenAIToolChoiceResult(kind='literal', literal_value='auto') + + if resolved.mode == 'required': + if not openai_profile.openai_supports_tool_choice_required: + warnings.warn( + "tool_choice='required' is not supported by this model. Falling back to 'auto'.", + UserWarning, + stacklevel=7, + ) + return _OpenAIToolChoiceResult(kind='literal', literal_value='auto') + return _OpenAIToolChoiceResult(kind='literal', literal_value='required') + + if resolved.mode == 'none': + if not resolved.output_tools_fallback: + return _OpenAIToolChoiceResult(kind='literal', literal_value='none') + output_tool_names = [t.name for t in model_request_parameters.output_tools] + allowed_mode: Literal['auto', 'required'] = ( + 'required' if not model_request_parameters.allow_text_output else 'auto' + ) + if len(output_tool_names) == 1: + return _OpenAIToolChoiceResult(kind='single_tool', tool_name=output_tool_names[0]) + return _OpenAIToolChoiceResult(kind='allowed_tools', tool_names=output_tool_names, allowed_mode=allowed_mode) + + if resolved.tool_names: + allowed_mode = 'required' if not model_request_parameters.allow_text_output else 'auto' + if len(resolved.tool_names) == 1: + return _OpenAIToolChoiceResult(kind='single_tool', tool_name=resolved.tool_names[0]) + return _OpenAIToolChoiceResult(kind='allowed_tools', tool_names=resolved.tool_names, allowed_mode=allowed_mode) + + return None # pragma: no cover + + class OpenAIChatModelSettings(ModelSettings, total=False): """Settings used for an OpenAI model request.""" @@ -493,15 +580,7 @@ async def _completions_create( tools = self._get_tools(model_request_parameters) web_search_options = self._get_web_search_options(model_request_parameters) - if not tools: - tool_choice: Literal['none', 'required', 'auto'] | None = None - elif ( - not model_request_parameters.allow_text_output - and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required - ): - tool_choice = 'required' - else: - tool_choice = 'auto' + tool_choice = self._get_tool_choice(tools, model_settings, model_request_parameters) openai_messages = await self._map_messages(messages, model_request_parameters) @@ -699,6 +778,36 @@ def _map_usage(self, response: chat.ChatCompletion) -> usage.RequestUsage: def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]: return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] + def _get_tool_choice( + self, + tools: 
list[chat.ChatCompletionToolParam], + model_settings: OpenAIChatModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> ChatCompletionToolChoiceOptionParam | None: + if not tools: + return None + + resolved = _resolve_openai_tool_choice(model_settings, model_request_parameters, self.profile) + if resolved is None: # pragma: no cover + return None + + if resolved.kind == 'literal': + return resolved.literal_value + + if resolved.kind == 'single_tool': + return ChatCompletionNamedToolChoiceParam(type='function', function={'name': resolved.tool_name or ''}) + + if resolved.kind == 'allowed_tools' and resolved.tool_names: + return ChatCompletionAllowedToolChoiceParam( + type='allowed_tools', + allowed_tools=ChatCompletionAllowedToolsParam( + mode=resolved.allowed_mode, + tools=[{'type': 'function', 'function': {'name': n}} for n in resolved.tool_names], + ), + ) + + return None # pragma: no cover + def _get_web_search_options(self, model_request_parameters: ModelRequestParameters) -> WebSearchOptions | None: for tool in model_request_parameters.builtin_tools: if isinstance(tool, WebSearchTool): # pragma: no branch @@ -1288,7 +1397,7 @@ async def _responses_create( model_request_parameters: ModelRequestParameters, ) -> AsyncStream[responses.ResponseStreamEvent]: ... - async def _responses_create( # noqa: C901 + async def _responses_create( self, messages: list[ModelRequest | ModelResponse], stream: bool, @@ -1301,12 +1410,7 @@ async def _responses_create( # noqa: C901 + self._get_tools(model_request_parameters) ) profile = OpenAIModelProfile.from_profile(self.profile) - if not tools: - tool_choice: Literal['none', 'required', 'auto'] | None = None - elif not model_request_parameters.allow_text_output and profile.openai_supports_tool_choice_required: - tool_choice = 'required' - else: - tool_choice = 'auto' + tool_choice = self._get_responses_tool_choice(tools, model_settings, model_request_parameters) previous_response_id = model_settings.get('openai_previous_response_id') if previous_response_id == 'auto': @@ -1417,6 +1521,34 @@ def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reason def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.FunctionToolParam]: return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] + def _get_responses_tool_choice( + self, + tools: list[responses.ToolParam], + model_settings: OpenAIResponsesModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> ResponsesToolChoice | None: + if not tools: + return None + + resolved = _resolve_openai_tool_choice(model_settings, model_request_parameters, self.profile) + if resolved is None: # pragma: no cover + return None + + if resolved.kind == 'literal': + return resolved.literal_value + + if resolved.kind == 'single_tool': + return ToolChoiceFunctionParam(type='function', name=resolved.tool_name or '') + + if resolved.kind == 'allowed_tools' and resolved.tool_names: + return ToolChoiceAllowedParam( + type='allowed_tools', + mode=resolved.allowed_mode, + tools=[{'type': 'function', 'name': n} for n in resolved.tool_names], + ) + + return None # pragma: no cover + def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.ToolParam]: tools: list[responses.ToolParam] = [] has_image_generating_tool = False diff --git a/pydantic_ai_slim/pydantic_ai/settings.py b/pydantic_ai_slim/pydantic_ai/settings.py index 6941eb1ab3..5f0fde9312 100644 --- 
a/pydantic_ai_slim/pydantic_ai/settings.py +++ b/pydantic_ai_slim/pydantic_ai/settings.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Literal + from httpx import Timeout from typing_extensions import TypedDict @@ -88,6 +90,33 @@ class ModelSettings(TypedDict, total=False): * Anthropic """ + tool_choice: Literal['none', 'required', 'auto'] | list[str] | None + """Control which function tools the model can use. + + This setting only affects function tools registered on the agent, not output tools + used for structured output. + + * `None` (default): Automatically determined based on output configuration + * `'auto'`: Model decides whether to use function tools + * `'required'`: Model must use one of the available function tools + * `'none'`: Model cannot use function tools (output tools remain available if needed) + * `list[str]`: Model must use one of the specified function tools (validated against registered tools) + + If the agent has a structured output type that requires an output tool and `tool_choice='none'` + is set, the output tool will still be available and a warning will be logged. Consider using + native or prompted output modes if you need `tool_choice='none'` with structured output. + + Supported by: + + * OpenAI + * Anthropic (note: `'required'` and specific tools not supported with thinking/extended thinking) + * Gemini + * Groq + * Mistral + * HuggingFace + * Bedrock (note: `'none'` not supported, will fall back to `'auto'` with a warning) + """ + seed: int """The random seed to use for the model, theoretically allowing for deterministic results. diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 2df96b0278..25cc3483b1 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -7901,3 +7901,184 @@ async def test_anthropic_cache_messages_real_api(allow_model_requests: None, ant assert usage2.cache_read_tokens > 0 assert usage2.cache_write_tokens > 0 assert usage2.output_tokens > 0 + + +@pytest.mark.parametrize( + 'tool_choice,expected_type', + [ + pytest.param('none', 'none', id='none'), + pytest.param('auto', 'auto', id='auto'), + pytest.param('required', 'any', id='required-maps-to-any'), + ], +) +async def test_tool_choice_string_values(allow_model_requests: None, tool_choice: str, expected_type: str) -> None: + """Ensure Anthropic string values map to the expected schema.""" + c = completion_message([BetaTextBlock(text='ok', type='text')], BetaUsage(input_tokens=5, output_tokens=10)) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent(m) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': tool_choice}) # type: ignore + + kwargs = mock_client.chat_completion_kwargs[0] # type: ignore + assert kwargs['tool_choice']['type'] == expected_type + + +async def test_tool_choice_specific_tool_single(allow_model_requests: None) -> None: + """Single Anthropic tools should emit the 'tool' choice payload.""" + c = completion_message([BetaTextBlock(text='ok', type='text')], BetaUsage(input_tokens=5, output_tokens=10)) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent(m) + + @agent.tool_plain + def tool_a(x: int) -> str: + return str(x) # pragma: no cover + + @agent.tool_plain + def tool_b(x: int) 
-> str: + return str(x) # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': ['tool_a']}) + + kwargs = mock_client.chat_completion_kwargs[0] # type: ignore + assert kwargs['tool_choice'] == {'type': 'tool', 'name': 'tool_a'} + + +async def test_tool_choice_multiple_tools_falls_back_to_any(allow_model_requests: None) -> None: + """Multiple specific tools fall back to 'any' with a warning.""" + c = completion_message([BetaTextBlock(text='ok', type='text')], BetaUsage(input_tokens=5, output_tokens=10)) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent(m) + + @agent.tool_plain + def tool_a(x: int) -> str: + return str(x) # pragma: no cover + + @agent.tool_plain + def tool_b(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match='Anthropic only supports forcing a single tool'): + await agent.run('hello', model_settings={'tool_choice': ['tool_a', 'tool_b']}) + + kwargs = mock_client.chat_completion_kwargs[0] # type: ignore + assert kwargs['tool_choice'] == snapshot({'type': 'any'}) + + +async def test_tool_choice_none_with_output_tools_warns(allow_model_requests: None) -> None: + """Structured output must remain available even with tool_choice='none'.""" + + class Location(BaseModel): + city: str + country: str + + c = completion_message( + [BetaToolUseBlock(id='1', type='tool_use', name='final_result', input={'city': 'Paris', 'country': 'France'})], + BetaUsage(input_tokens=5, output_tokens=10), + ) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent(m, output_type=Location) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): + result = await agent.run('hello', model_settings={'tool_choice': 'none'}) + + assert result.output == snapshot(Location(city='Paris', country='France')) + kwargs = mock_client.chat_completion_kwargs[0] # type: ignore + assert kwargs['tool_choice'] == snapshot({'type': 'tool', 'name': 'final_result'}) + + +async def test_tool_choice_required_with_thinking_falls_back_to_auto(allow_model_requests: None) -> None: + """Thinking mode overrides 'required' to 'auto'.""" + c = completion_message([BetaTextBlock(text='ok', type='text')], BetaUsage(input_tokens=5, output_tokens=10)) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent(m) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match="tool_choice='required' is not supported with Anthropic thinking mode"): + await agent.run( + 'hello', + model_settings={ + 'tool_choice': 'required', + 'anthropic_thinking': {'type': 'enabled', 'budget_tokens': 1000}, + }, # type: ignore + ) + + kwargs = mock_client.chat_completion_kwargs[0] # type: ignore + assert kwargs['tool_choice']['type'] == 'auto' + + +async def test_tool_choice_specific_with_thinking_falls_back_to_auto(allow_model_requests: None) -> None: + """Specific tool forcing is incompatible with thinking mode.""" + c = completion_message([BetaTextBlock(text='ok', type='text')], BetaUsage(input_tokens=5, output_tokens=10)) + mock_client = MockAnthropic.create_mock(c) + m = 
AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent = Agent(m) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match='Forcing specific tools is not supported with Anthropic thinking mode'): + await agent.run( + 'hello', + model_settings={ + 'tool_choice': ['my_tool'], + 'anthropic_thinking': {'type': 'enabled', 'budget_tokens': 1000}, + }, # type: ignore + ) + + kwargs = mock_client.chat_completion_kwargs[0] # type: ignore + assert kwargs['tool_choice']['type'] == 'auto' + + +async def test_tool_choice_none_with_multiple_output_tools_falls_back_to_auto(allow_model_requests: None) -> None: + """Multiple output tools force a fallback to 'auto'.""" + import warnings as warn_module + + class LocationA(BaseModel): + city: str + + class LocationB(BaseModel): + country: str + + c = completion_message( + [BetaToolUseBlock(id='1', type='tool_use', name='final_result_LocationA', input={'city': 'Paris'})], + BetaUsage(input_tokens=5, output_tokens=10), + ) + mock_client = MockAnthropic.create_mock(c) + m = AnthropicModel('claude-haiku-4-5', provider=AnthropicProvider(anthropic_client=mock_client)) + agent: Agent[None, LocationA | LocationB] = Agent(m, output_type=[LocationA, LocationB]) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with warn_module.catch_warnings(record=True) as w: + warn_module.simplefilter('always') + await agent.run('hello', model_settings={'tool_choice': 'none'}) + + warning_messages = {str(warning.message) for warning in w} + assert { + "tool_choice='none' is set but output tools are required for structured output. " + 'The output tools will remain available. Consider using `NativeOutput` or `PromptedOutput` ' + "if you need tool_choice='none' with structured output.", + "Anthropic only supports forcing a single tool. 
Falling back to 'auto' for multiple output tools.", + } <= warning_messages + + kwargs = mock_client.chat_completion_kwargs[0] # type: ignore + assert kwargs['tool_choice']['type'] == 'auto' diff --git a/tests/models/test_bedrock.py b/tests/models/test_bedrock.py index f13aaff4fb..fc034758a5 100644 --- a/tests/models/test_bedrock.py +++ b/tests/models/test_bedrock.py @@ -1324,7 +1324,7 @@ async def test_bedrock_group_consecutive_tool_return_parts(bedrock_provider: Bed ] # Call the mapping function directly - _, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # type: ignore[reportPrivateUsage] + _, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # pyright: ignore[reportPrivateUsage] assert bedrock_messages == snapshot( [ @@ -1445,7 +1445,7 @@ async def test_bedrock_mistral_tool_result_format(bedrock_provider: BedrockProvi # Models other than Mistral support toolResult.content with text, not json model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) # Call the mapping function directly - _, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # type: ignore[reportPrivateUsage] + _, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # pyright: ignore[reportPrivateUsage,reportArgumentType] assert bedrock_messages == snapshot( [ @@ -1461,7 +1461,7 @@ async def test_bedrock_mistral_tool_result_format(bedrock_provider: BedrockProvi # Mistral requires toolResult.content to hold json, not text model = BedrockConverseModel('mistral.mistral-7b-instruct-v0:2', provider=bedrock_provider) # Call the mapping function directly - _, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # type: ignore[reportPrivateUsage] + _, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # pyright: ignore[reportPrivateUsage,reportArgumentType] assert bedrock_messages == snapshot( [ @@ -1485,7 +1485,7 @@ async def test_bedrock_no_tool_choice(bedrock_provider: BedrockProvider): # Amazon Nova supports tool_choice model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) - tool_config = model._map_tool_config(mrp) # type: ignore[reportPrivateUsage] + tool_config = model._map_tool_config(mrp, None) # pyright: ignore[reportPrivateUsage] assert tool_config == snapshot( { @@ -1506,7 +1506,7 @@ async def test_bedrock_no_tool_choice(bedrock_provider: BedrockProvider): # Anthropic supports tool_choice model = BedrockConverseModel('us.anthropic.claude-3-7-sonnet-20250219-v1:0', provider=bedrock_provider) - tool_config = model._map_tool_config(mrp) # type: ignore[reportPrivateUsage] + tool_config = model._map_tool_config(mrp, None) # pyright: ignore[reportPrivateUsage] assert tool_config == snapshot( { @@ -1527,7 +1527,7 @@ async def test_bedrock_no_tool_choice(bedrock_provider: BedrockProvider): # Other models don't support tool_choice model = BedrockConverseModel('us.meta.llama4-maverick-17b-instruct-v1:0', provider=bedrock_provider) - tool_config = model._map_tool_config(mrp) # type: ignore[reportPrivateUsage] + tool_config = model._map_tool_config(mrp, None) # pyright: ignore[reportPrivateUsage] assert tool_config == snapshot( { @@ -1624,3 +1624,148 @@ async def test_cache_point_filtering(): # CachePoint should be filtered out, message should still be valid assert len(messages) == 1 assert messages[0]['role'] == 'user' + + +@pytest.mark.parametrize( + 'tool_choice,expected_tool_choice', + [ + pytest.param('auto', {'auto': 
{}}, id='auto'), + pytest.param('required', {'any': {}}, id='required-maps-to-any'), + ], +) +async def test_tool_choice_string_values( + bedrock_provider: BedrockProvider, tool_choice: str, expected_tool_choice: dict[str, Any] +) -> None: + """Ensure simple string tool_choice values map to Bedrock's schema.""" + my_tool = ToolDefinition( + name='my_tool', + description='Test tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) + + model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) + settings: BedrockModelSettings = {'tool_choice': tool_choice} # type: ignore[assignment] + tool_config = model._map_tool_config(mrp, settings) # pyright: ignore[reportPrivateUsage] + + assert tool_config is not None + assert tool_config.get('toolChoice') == expected_tool_choice + + +async def test_tool_choice_none_falls_back_to_auto(bedrock_provider: BedrockProvider) -> None: + """Bedrock lacks 'none' support, so we fall back to auto with a warning.""" + my_tool = ToolDefinition( + name='my_tool', + description='Test tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) + + model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) + + settings: BedrockModelSettings = {'tool_choice': 'none'} + with pytest.warns(UserWarning, match="Bedrock does not support tool_choice='none'. Falling back to 'auto'"): + tool_config = model._map_tool_config(mrp, settings) # pyright: ignore[reportPrivateUsage] + + assert tool_config == snapshot( + { + 'tools': [ + { + 'toolSpec': { + 'name': 'my_tool', + 'description': 'Test tool', + 'inputSchema': {'json': {'type': 'object', 'properties': {}}}, + } + } + ], + 'toolChoice': {'auto': {}}, + } + ) + + +async def test_tool_choice_specific_tool_single(bedrock_provider: BedrockProvider) -> None: + """Single tool names should emit the {tool: {name}} payload.""" + tool_a = ToolDefinition( + name='tool_a', + description='Test tool A', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + tool_b = ToolDefinition( + name='tool_b', + description='Test tool B', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters( + output_mode='tool', function_tools=[tool_a, tool_b], allow_text_output=True, output_tools=[] + ) + + model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) + settings: BedrockModelSettings = {'tool_choice': ['tool_a']} + tool_config = model._map_tool_config(mrp, settings) # pyright: ignore[reportPrivateUsage] + + assert tool_config == snapshot( + { + 'tools': [ + { + 'toolSpec': { + 'name': 'tool_a', + 'description': 'Test tool A', + 'inputSchema': {'json': {'type': 'object', 'properties': {}}}, + } + }, + { + 'toolSpec': { + 'name': 'tool_b', + 'description': 'Test tool B', + 'inputSchema': {'json': {'type': 'object', 'properties': {}}}, + } + }, + ], + 'toolChoice': {'tool': {'name': 'tool_a'}}, + } + ) + + +async def test_tool_choice_multiple_tools_falls_back_to_any(bedrock_provider: BedrockProvider) -> None: + """Multiple tool names fall back to the 'any' configuration.""" + tool_a = ToolDefinition( + name='tool_a', + description='Test tool A', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + tool_b = ToolDefinition( + name='tool_b', + 
description='Test tool B', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters( + output_mode='tool', function_tools=[tool_a, tool_b], allow_text_output=True, output_tools=[] + ) + + model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) + settings: BedrockModelSettings = {'tool_choice': ['tool_a', 'tool_b']} + + with pytest.warns(UserWarning, match='Bedrock only supports forcing a single tool'): + tool_config = model._map_tool_config(mrp, settings) # pyright: ignore[reportPrivateUsage] + + assert tool_config == snapshot( + { + 'tools': [ + { + 'toolSpec': { + 'name': 'tool_a', + 'description': 'Test tool A', + 'inputSchema': {'json': {'type': 'object', 'properties': {}}}, + } + }, + { + 'toolSpec': { + 'name': 'tool_b', + 'description': 'Test tool B', + 'inputSchema': {'json': {'type': 'object', 'properties': {}}}, + } + }, + ], + 'toolChoice': {'any': {}}, + } + ) diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 3ef8cd5dda..a196229458 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -59,6 +59,7 @@ from pydantic_ai.models import ModelRequestParameters from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput from pydantic_ai.settings import ModelSettings +from pydantic_ai.tools import ToolDefinition from pydantic_ai.usage import RequestUsage, RunUsage, UsageLimits from ..conftest import IsBytes, IsDatetime, IsInstance, IsStr, try_import @@ -68,12 +69,15 @@ from google.genai import errors from google.genai.types import ( FinishReason as GoogleFinishReason, + FunctionCallingConfigMode, + FunctionDeclarationDict, GenerateContentResponse, GenerateContentResponseUsageMetadata, HarmBlockThreshold, HarmCategory, MediaModality, ModalityTokenCount, + ToolDict, ) from pydantic_ai.models.google import ( @@ -4425,3 +4429,120 @@ def test_google_missing_tool_call_thought_signature(): ], } ) + + +def test_tool_choice_string_value_none(google_provider: GoogleProvider) -> None: + """Test that tool_choice='none' maps to FunctionCallingConfigMode.NONE.""" + my_tool = ToolDefinition( + name='my_tool', + description='Test tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) + tools = [ToolDict(function_declarations=[FunctionDeclarationDict(name='my_tool', description='Test tool')])] + + model = GoogleModel('gemini-2.5-flash', provider=google_provider) + settings: GoogleModelSettings = {'tool_choice': 'none'} + result = model._get_tool_config(mrp, tools, settings) # pyright: ignore[reportPrivateUsage] + + assert result is not None + fcc = result.get('function_calling_config') + assert fcc == snapshot({'mode': FunctionCallingConfigMode.NONE}) + + +def test_tool_choice_string_value_auto(google_provider: GoogleProvider) -> None: + """Test that tool_choice='auto' maps to FunctionCallingConfigMode.AUTO.""" + my_tool = ToolDefinition( + name='my_tool', + description='Test tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) + tools = [ToolDict(function_declarations=[FunctionDeclarationDict(name='my_tool', description='Test tool')])] + + model = GoogleModel('gemini-2.5-flash', provider=google_provider) + settings: GoogleModelSettings = {'tool_choice': 'auto'} + result = 
model._get_tool_config(mrp, tools, settings) # pyright: ignore[reportPrivateUsage] + + assert result is not None + fcc = result.get('function_calling_config') + assert fcc == snapshot({'mode': FunctionCallingConfigMode.AUTO}) + + +def test_tool_choice_required_maps_to_any(google_provider: GoogleProvider) -> None: + """Test that 'required' maps to ANY mode with all tool names.""" + my_tool = ToolDefinition( + name='my_tool', + description='Test tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) + tools = [ToolDict(function_declarations=[FunctionDeclarationDict(name='my_tool', description='Test tool')])] + + model = GoogleModel('gemini-2.5-flash', provider=google_provider) + settings: GoogleModelSettings = {'tool_choice': 'required'} + result = model._get_tool_config(mrp, tools, settings) # pyright: ignore[reportPrivateUsage] + + assert result is not None + fcc = result.get('function_calling_config') + assert fcc == snapshot({'mode': FunctionCallingConfigMode.ANY, 'allowed_function_names': ['my_tool']}) + + +def test_tool_choice_specific_tool_single(google_provider: GoogleProvider) -> None: + """Specific tool names become allowed_function_names.""" + tool_a = ToolDefinition( + name='tool_a', + description='Test tool A', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + tool_b = ToolDefinition( + name='tool_b', + description='Test tool B', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters( + output_mode='tool', function_tools=[tool_a, tool_b], allow_text_output=True, output_tools=[] + ) + tools = [ + ToolDict(function_declarations=[FunctionDeclarationDict(name='tool_a', description='Test tool A')]), + ToolDict(function_declarations=[FunctionDeclarationDict(name='tool_b', description='Test tool B')]), + ] + + model = GoogleModel('gemini-2.5-flash', provider=google_provider) + settings: GoogleModelSettings = {'tool_choice': ['tool_a']} + result = model._get_tool_config(mrp, tools, settings) # pyright: ignore[reportPrivateUsage] + + assert result is not None + fcc = result.get('function_calling_config') + assert fcc == snapshot({'mode': FunctionCallingConfigMode.ANY, 'allowed_function_names': ['tool_a']}) + + +def test_tool_choice_none_with_output_tools_warns(google_provider: GoogleProvider) -> None: + """tool_choice='none' still allows the required output tool.""" + func_tool = ToolDefinition( + name='func_tool', + description='Function tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + output_tool = ToolDefinition( + name='output_tool', + description='Output tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters( + output_mode='tool', function_tools=[func_tool], allow_text_output=False, output_tools=[output_tool] + ) + tools = [ + ToolDict(function_declarations=[FunctionDeclarationDict(name='func_tool', description='Function tool')]), + ToolDict(function_declarations=[FunctionDeclarationDict(name='output_tool', description='Output tool')]), + ] + + model = GoogleModel('gemini-2.5-flash', provider=google_provider) + settings: GoogleModelSettings = {'tool_choice': 'none'} + + with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): + result = model._get_tool_config(mrp, tools, settings) # pyright: ignore[reportPrivateUsage] + + assert result is not None + fcc = 
result.get('function_calling_config') + assert fcc == snapshot({'mode': FunctionCallingConfigMode.ANY, 'allowed_function_names': ['output_tool']}) diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index dd3395750e..11f4669259 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -3,7 +3,7 @@ import json import os from collections.abc import Sequence -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime, timezone from functools import cached_property from typing import Any, Literal, cast @@ -97,6 +97,7 @@ class MockGroq: completions: MockChatCompletion | Sequence[MockChatCompletion] | None = None stream: Sequence[MockChatCompletionChunk] | Sequence[Sequence[MockChatCompletionChunk]] | None = None index: int = 0 + chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list) @cached_property def chat(self) -> Any: @@ -115,8 +116,9 @@ def create_mock_stream( return cast(AsyncGroq, cls(stream=stream)) async def chat_completions_create( - self, *_args: Any, stream: bool = False, **_kwargs: Any + self, *_args: Any, stream: bool = False, **kwargs: Any ) -> chat.ChatCompletion | MockAsyncStream[MockChatCompletionChunk]: + self.chat_completion_kwargs.append(kwargs) if stream: assert self.stream is not None, 'you can only used `stream=True` if `stream` is provided' if isinstance(self.stream[0], Sequence): @@ -137,6 +139,13 @@ async def chat_completions_create( return response +def get_mock_chat_completion_kwargs(groq_client: AsyncGroq) -> list[dict[str, Any]]: + if isinstance(groq_client, MockGroq): + return groq_client.chat_completion_kwargs + else: # pragma: no cover + raise RuntimeError('Not a MockGroq instance') + + def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage | None = None) -> chat.ChatCompletion: return chat.ChatCompletion( id='123', @@ -5623,3 +5632,106 @@ class CityLocation(BaseModel): ), ] ) + + +@pytest.mark.parametrize( + 'tool_choice,expected', + [ + pytest.param('none', 'none', id='none'), + pytest.param('auto', 'auto', id='auto'), + pytest.param('required', 'required', id='required'), + ], +) +async def test_tool_choice_string_values(allow_model_requests: None, tool_choice: str, expected: str) -> None: + """Ensure Groq string values are forwarded unchanged.""" + mock_client = MockGroq.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) + m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client)) + agent = Agent(m) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': tool_choice}) # type: ignore + + kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + assert kwargs['tool_choice'] == expected + + +async def test_tool_choice_specific_tool_single(allow_model_requests: None) -> None: + """Single tool choices should use the named tool payload.""" + mock_client = MockGroq.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) + m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client)) + agent = Agent(m) + + @agent.tool_plain + def tool_a(x: int) -> str: + return str(x) # pragma: no cover + + @agent.tool_plain + def tool_b(x: int) -> str: + return str(x) # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': ['tool_a']}) + + kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + assert 
kwargs['tool_choice'] == {'type': 'function', 'function': {'name': 'tool_a'}} + + +async def test_tool_choice_multiple_tools_falls_back_to_required(allow_model_requests: None) -> None: + """Multiple specific tools fall back to 'required'.""" + mock_client = MockGroq.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) + m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client)) + agent = Agent(m) + + @agent.tool_plain + def tool_a(x: int) -> str: + return str(x) # pragma: no cover + + @agent.tool_plain + def tool_b(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match='Groq only supports forcing a single tool'): + await agent.run('hello', model_settings={'tool_choice': ['tool_a', 'tool_b']}) + + kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + assert kwargs['tool_choice'] == 'required' + + +async def test_tool_choice_none_with_output_tools(allow_model_requests: None) -> None: + """tool_choice='none' still allows output tools to execute.""" + + class MyOutput(BaseModel): + result: str + + # Tool call response that returns final_result tool + tool_call_response = completion_message( + ChatCompletionMessage( + content=None, + role='assistant', + tool_calls=[ + chat.ChatCompletionMessageToolCall( + id='call_1', + type='function', + function=chat.chat_completion_message_tool_call.Function( + name='final_result', arguments='{"result": "done"}' + ), + ) + ], + ) + ) + + mock_client = MockGroq.create_mock(tool_call_response) + m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client)) + agent: Agent[None, MyOutput] = Agent(m, output_type=MyOutput) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): + await agent.run('hello', model_settings={'tool_choice': 'none'}) + + kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + assert kwargs['tool_choice'] == {'type': 'function', 'function': {'name': 'final_result'}} diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 56d74ed619..4161f71605 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -12,7 +12,10 @@ import pytest from huggingface_hub import ( AsyncInferenceClient, + ChatCompletionInputFunctionName, ChatCompletionInputMessage, + ChatCompletionInputTool, + ChatCompletionInputToolChoiceClass, ChatCompletionOutput, ChatCompletionOutputComplete, ChatCompletionOutputFunctionDefinition, @@ -50,12 +53,13 @@ VideoUrl, ) from pydantic_ai.exceptions import ModelHTTPError -from pydantic_ai.models.huggingface import HuggingFaceModel +from pydantic_ai.models import ModelRequestParameters +from pydantic_ai.models.huggingface import HuggingFaceModel, HuggingFaceModelSettings from pydantic_ai.providers.huggingface import HuggingFaceProvider from pydantic_ai.result import RunUsage from pydantic_ai.run import AgentRunResult, AgentRunResultEvent from pydantic_ai.settings import ModelSettings -from pydantic_ai.tools import RunContext +from pydantic_ai.tools import RunContext, ToolDefinition from pydantic_ai.usage import RequestUsage from ..conftest import IsDatetime, IsInstance, IsNow, IsStr, raise_if_exception, try_import @@ -1026,3 +1030,129 @@ async def test_cache_point_filtering(): # CachePoint should be filtered out assert msg['role'] == 'user' assert len(msg['content']) == 1 # pyright: 
ignore[reportUnknownArgumentType] + + +@pytest.mark.parametrize( + 'tool_choice,expected', + [ + pytest.param('none', 'none', id='none'), + pytest.param('auto', 'auto', id='auto'), + pytest.param('required', 'required', id='required'), + ], +) +def test_tool_choice_string_values(tool_choice: str, expected: str) -> None: + """Ensure HuggingFace string values pass through unchanged.""" + my_tool = ToolDefinition( + name='my_tool', + description='Test tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) + tools: list[ChatCompletionInputTool] = [ + ChatCompletionInputTool(type='function', function={'name': 'my_tool', 'description': 'Test tool'}) # pyright: ignore[reportCallIssue] + ] + + mock_client = MockHuggingFace.create_mock( + completion_message(ChatCompletionOutputMessage.parse_obj_as_instance({'content': 'ok', 'role': 'assistant'})) # type: ignore + ) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + settings: HuggingFaceModelSettings = {'tool_choice': tool_choice} # type: ignore[assignment] + result = model._get_tool_choice(tools, settings, mrp) # pyright: ignore[reportPrivateUsage] + + assert result == expected + + +def test_tool_choice_specific_tool_single() -> None: + """Single tool entries should use ChatCompletionInputToolChoiceClass.""" + tool_a = ToolDefinition( + name='tool_a', + description='Test tool A', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + tool_b = ToolDefinition( + name='tool_b', + description='Test tool B', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters( + output_mode='tool', function_tools=[tool_a, tool_b], allow_text_output=True, output_tools=[] + ) + tools: list[ChatCompletionInputTool] = [ + ChatCompletionInputTool(type='function', function={'name': 'tool_a', 'description': 'Test tool A'}), # pyright: ignore[reportCallIssue] + ChatCompletionInputTool(type='function', function={'name': 'tool_b', 'description': 'Test tool B'}), # pyright: ignore[reportCallIssue] + ] + + mock_client = MockHuggingFace.create_mock( + completion_message(ChatCompletionOutputMessage.parse_obj_as_instance({'content': 'ok', 'role': 'assistant'})) # type: ignore + ) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + settings: HuggingFaceModelSettings = {'tool_choice': ['tool_a']} + result = model._get_tool_choice(tools, settings, mrp) # pyright: ignore[reportPrivateUsage] + + assert isinstance(result, ChatCompletionInputToolChoiceClass) + assert result.function == ChatCompletionInputFunctionName(name='tool_a') # type: ignore[call-arg] + + +def test_tool_choice_multiple_tools_falls_back_to_required() -> None: + """Multiple specific tools fall back to 'required'.""" + tool_a = ToolDefinition( + name='tool_a', + description='Test tool A', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + tool_b = ToolDefinition( + name='tool_b', + description='Test tool B', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters( + output_mode='tool', function_tools=[tool_a, tool_b], allow_text_output=True, output_tools=[] + ) + tools: list[ChatCompletionInputTool] = [ + ChatCompletionInputTool(type='function', function={'name': 'tool_a', 'description': 'Test tool A'}), # pyright: ignore[reportCallIssue] + 
ChatCompletionInputTool(type='function', function={'name': 'tool_b', 'description': 'Test tool B'}), # pyright: ignore[reportCallIssue] + ] + + mock_client = MockHuggingFace.create_mock( + completion_message(ChatCompletionOutputMessage.parse_obj_as_instance({'content': 'ok', 'role': 'assistant'})) # type: ignore + ) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + settings: HuggingFaceModelSettings = {'tool_choice': ['tool_a', 'tool_b']} + + with pytest.warns(UserWarning, match='HuggingFace only supports forcing a single tool'): + result = model._get_tool_choice(tools, settings, mrp) # pyright: ignore[reportPrivateUsage] + + assert result == 'required' + + +def test_tool_choice_none_with_output_tools_warns() -> None: + """tool_choice='none' should not disable mandatory output tools.""" + func_tool = ToolDefinition( + name='func_tool', + description='Function tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + output_tool = ToolDefinition( + name='output_tool', + description='Output tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters( + output_mode='tool', function_tools=[func_tool], allow_text_output=False, output_tools=[output_tool] + ) + tools: list[ChatCompletionInputTool] = [ + ChatCompletionInputTool(type='function', function={'name': 'func_tool', 'description': 'Function tool'}), # pyright: ignore[reportCallIssue] + ChatCompletionInputTool(type='function', function={'name': 'output_tool', 'description': 'Output tool'}), # pyright: ignore[reportCallIssue] + ] + + mock_client = MockHuggingFace.create_mock( + completion_message(ChatCompletionOutputMessage.parse_obj_as_instance({'content': 'ok', 'role': 'assistant'})) # type: ignore + ) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + settings: HuggingFaceModelSettings = {'tool_choice': 'none'} + + with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): + result = model._get_tool_choice(tools, settings, mrp) # pyright: ignore[reportPrivateUsage] + + assert isinstance(result, ChatCompletionInputToolChoiceClass) + assert result.function == ChatCompletionInputFunctionName(name='output_tool') # type: ignore[call-arg] diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 4ae21ad221..50d64597d5 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -29,6 +29,8 @@ ) from pydantic_ai.agent import Agent from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, ModelRetry +from pydantic_ai.models import ModelRequestParameters +from pydantic_ai.tools import ToolDefinition from pydantic_ai.usage import RequestUsage from ..conftest import IsDatetime, IsNow, IsStr, raise_if_exception, try_import @@ -55,7 +57,7 @@ ) from mistralai.types.basemodel import Unset as MistralUnset - from pydantic_ai.models.mistral import MistralModel, MistralStreamedResponse + from pydantic_ai.models.mistral import MistralModel, MistralModelSettings, MistralStreamedResponse from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings from pydantic_ai.providers.mistral import MistralProvider from pydantic_ai.providers.openai import OpenAIProvider @@ -2361,3 +2363,73 @@ async def test_mistral_model_thinking_part_iter(allow_model_requests: None, mist ), ] ) + + +@pytest.mark.parametrize( + 'tool_choice,expected', + [ + pytest.param('none', 'none', id='none'), + 
pytest.param('auto', 'auto', id='auto'), + pytest.param('required', 'required', id='required'), + ], +) +def test_tool_choice_string_values(tool_choice: str, expected: str) -> None: + """Ensure Mistral string values pass through untouched.""" + my_tool = ToolDefinition( + name='my_tool', + description='Test tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=True, output_tools=[]) + + mock_client = MockMistralAI.create_mock(completion_message(MistralAssistantMessage(content='ok', role='assistant'))) + model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client)) + settings: MistralModelSettings = {'tool_choice': tool_choice} # type: ignore[assignment] + result = model._get_tool_choice(mrp, settings) # pyright: ignore[reportPrivateUsage] + + assert result == expected + + +def test_tool_choice_specific_tool_falls_back_to_required() -> None: + """Specific tool forcing is unsupported and falls back to required.""" + tool_a = ToolDefinition( + name='tool_a', + description='Test tool A', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[tool_a], allow_text_output=True, output_tools=[]) + + mock_client = MockMistralAI.create_mock(completion_message(MistralAssistantMessage(content='ok', role='assistant'))) + model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client)) + settings: MistralModelSettings = {'tool_choice': ['tool_a']} + + with pytest.warns(UserWarning, match="Mistral does not support forcing specific tools. Falling back to 'required'"): + result = model._get_tool_choice(mrp, settings) # pyright: ignore[reportPrivateUsage] + + assert result == 'required' + + +def test_tool_choice_none_with_output_tools_warns() -> None: + """tool_choice='none' still forces required when output tools exist.""" + func_tool = ToolDefinition( + name='func_tool', + description='Function tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + output_tool = ToolDefinition( + name='output_tool', + description='Output tool', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + mrp = ModelRequestParameters( + output_mode='tool', function_tools=[func_tool], allow_text_output=False, output_tools=[output_tool] + ) + + mock_client = MockMistralAI.create_mock(completion_message(MistralAssistantMessage(content='ok', role='assistant'))) + model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client)) + settings: MistralModelSettings = {'tool_choice': 'none'} + + with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): + result = model._get_tool_choice(mrp, settings) # pyright: ignore[reportPrivateUsage] + + assert result == 'required' diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index ed68edd94f..fd4501c60f 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -71,6 +71,8 @@ from openai.types.chat.chat_completion_message_tool_call import Function from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob from openai.types.completion_usage import CompletionUsage, PromptTokensDetails + from openai.types.responses import ResponseFunctionToolCall + from openai.types.responses.response_output_message import ResponseOutputMessage, ResponseOutputText from 
pydantic_ai.models.google import GoogleModel from pydantic_ai.models.openai import ( @@ -3129,6 +3131,48 @@ async def test_tool_choice_fallback_response_api(allow_model_requests: None) -> assert get_mock_responses_kwargs(mock_client)[0]['tool_choice'] == 'auto' +async def test_tool_choice_required_explicit_unsupported(allow_model_requests: None) -> None: + """Ensure explicit tool_choice='required' warns and falls back to 'auto' when unsupported.""" + profile = OpenAIModelProfile(openai_supports_tool_choice_required=False).update(openai_model_profile('stub')) + + mock_client = MockOpenAI.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) + model = OpenAIChatModel('stub', provider=OpenAIProvider(openai_client=mock_client), profile=profile) + + params = ModelRequestParameters(function_tools=[ToolDefinition(name='x')], allow_text_output=True) + settings: OpenAIChatModelSettings = {'tool_choice': 'required'} + + with pytest.warns(UserWarning, match="tool_choice='required' is not supported by this model"): + await model._completions_create( # pyright: ignore[reportPrivateUsage] + messages=[], + stream=False, + model_settings=settings, + model_request_parameters=params, + ) + + assert get_mock_chat_completion_kwargs(mock_client)[0]['tool_choice'] == 'auto' + + +async def test_tool_choice_required_explicit_unsupported_responses_api(allow_model_requests: None) -> None: + """Ensure explicit tool_choice='required' warns and falls back for Responses API when unsupported.""" + profile = OpenAIModelProfile(openai_supports_tool_choice_required=False).update(openai_model_profile('stub')) + + mock_client = MockOpenAIResponses.create_mock(response_message([])) + model = OpenAIResponsesModel('openai/gpt-oss', provider=OpenAIProvider(openai_client=mock_client), profile=profile) + + params = ModelRequestParameters(function_tools=[ToolDefinition(name='x')], allow_text_output=True) + settings: OpenAIResponsesModelSettings = {'tool_choice': 'required'} + + with pytest.warns(UserWarning, match="tool_choice='required' is not supported by this model"): + await model._responses_create( # pyright: ignore[reportPrivateUsage] + messages=[], + stream=False, + model_settings=settings, + model_request_parameters=params, + ) + + assert get_mock_responses_kwargs(mock_client)[0]['tool_choice'] == 'auto' + + async def test_openai_model_settings_temperature_ignored_on_gpt_5(allow_model_requests: None, openai_api_key: str): m = OpenAIChatModel('gpt-5', provider=OpenAIProvider(api_key=openai_api_key)) agent = Agent(m) @@ -3183,10 +3227,9 @@ async def test_cache_point_filtering(allow_model_requests: None): msg = await m._map_user_prompt(UserPromptPart(content=['text before', CachePoint(), 'text after'])) # pyright: ignore[reportPrivateUsage] # CachePoint should be filtered out, only text content should remain - assert msg['role'] == 'user' - assert len(msg['content']) == 2 # type: ignore[reportUnknownArgumentType] - assert msg['content'][0]['text'] == 'text before' # type: ignore[reportUnknownArgumentType] - assert msg['content'][1]['text'] == 'text after' # type: ignore[reportUnknownArgumentType] + assert msg == snapshot( + {'role': 'user', 'content': [{'text': 'text before', 'type': 'text'}, {'text': 'text after', 'type': 'text'}]} + ) async def test_cache_point_filtering_responses_model(): @@ -3197,10 +3240,371 @@ async def test_cache_point_filtering_responses_model(): ) # CachePoint should be filtered out, only text content should remain - assert msg['role'] == 'user' - assert 
len(msg['content']) == 2 - assert msg['content'][0]['text'] == 'text before' # type: ignore[reportUnknownArgumentType] - assert msg['content'][1]['text'] == 'text after' # type: ignore[reportUnknownArgumentType] + assert msg == snapshot( + { + 'role': 'user', + 'content': [{'text': 'text before', 'type': 'input_text'}, {'text': 'text after', 'type': 'input_text'}], + } + ) + + +@pytest.mark.parametrize( + 'tool_choice,expected', + [ + pytest.param('none', 'none', id='none'), + pytest.param('auto', 'auto', id='auto'), + pytest.param('required', 'required', id='required'), + ], +) +async def test_tool_choice_string_values(allow_model_requests: None, tool_choice: str, expected: str) -> None: + """Ensure Chat tool_choice strings flow through unchanged.""" + mock_client = MockOpenAI.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) + model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(model) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': tool_choice}) # type: ignore + + kwargs = get_mock_chat_completion_kwargs(mock_client) + assert kwargs[0]['tool_choice'] == expected + + +async def test_tool_choice_specific_tool_single(allow_model_requests: None) -> None: + """Force the Chat API to call a specific tool.""" + mock_client = MockOpenAI.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) + model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(model) + + @agent.tool_plain + def tool_a(x: int) -> str: + return str(x) # pragma: no cover + + @agent.tool_plain + def tool_b(x: int) -> str: + return str(x) # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': ['tool_a']}) + + kwargs = get_mock_chat_completion_kwargs(mock_client) + assert kwargs[0]['tool_choice'] == {'type': 'function', 'function': {'name': 'tool_a'}} + + +async def test_tool_choice_specific_tools_multiple(allow_model_requests: None) -> None: + """Multiple Chat tools should produce an allowed_tools payload.""" + mock_client = MockOpenAI.create_mock(completion_message(ChatCompletionMessage(content='ok', role='assistant'))) + model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(model) + + @agent.tool_plain + def tool_a(x: int) -> str: + return str(x) # pragma: no cover + + @agent.tool_plain + def tool_b(x: int) -> str: + return str(x) # pragma: no cover + + @agent.tool_plain + def tool_c(x: int) -> str: + return str(x) # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': ['tool_a', 'tool_b']}) + + kwargs = get_mock_chat_completion_kwargs(mock_client) + assert kwargs[0]['tool_choice'] == snapshot( + { + 'type': 'allowed_tools', + 'allowed_tools': { + 'mode': 'auto', + 'tools': [ + {'type': 'function', 'function': {'name': 'tool_a'}}, + {'type': 'function', 'function': {'name': 'tool_b'}}, + ], + }, + } + ) + + +async def test_tool_choice_none_with_output_tools_warns(allow_model_requests: None) -> None: + """Structured output tools persist even with tool_choice='none'.""" + + class Location(BaseModel): + city: str + country: str + + mock_client = MockOpenAI.create_mock( + completion_message( + ChatCompletionMessage( + content=None, + role='assistant', + tool_calls=[ + ChatCompletionMessageFunctionToolCall( + id='1', + type='function', + function=Function( + name='final_result', + 
arguments='{"city": "Paris", "country": "France"}', + ), + ), + ], + ) + ) + ) + model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(model, output_type=Location) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): + result = await agent.run('hello', model_settings={'tool_choice': 'none'}) + + assert result.output == Location(city='Paris', country='France') + kwargs = get_mock_chat_completion_kwargs(mock_client) + assert kwargs[0]['tool_choice'] == {'type': 'function', 'function': {'name': 'final_result'}} + + +async def test_tool_choice_none_with_multiple_output_tools(allow_model_requests: None) -> None: + """Multiple output tools fall back to allowed_tools when forcing 'none'.""" + + class LocationA(BaseModel): + city: str + + class LocationB(BaseModel): + country: str + + mock_client = MockOpenAI.create_mock( + completion_message( + ChatCompletionMessage( + content=None, + role='assistant', + tool_calls=[ + ChatCompletionMessageFunctionToolCall( + id='1', + type='function', + function=Function( + name='final_result_LocationA', + arguments='{"city": "Paris"}', + ), + ), + ], + ) + ) + ) + model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent: Agent[None, LocationA | LocationB] = Agent(model, output_type=[LocationA, LocationB]) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): + result = await agent.run('hello', model_settings={'tool_choice': 'none'}) + + assert result.output == LocationA(city='Paris') + kwargs = get_mock_chat_completion_kwargs(mock_client) + assert kwargs[0]['tool_choice'] == { + 'type': 'allowed_tools', + 'allowed_tools': { + 'mode': 'required', + 'tools': [ + {'type': 'function', 'function': {'name': 'final_result_LocationA'}}, + {'type': 'function', 'function': {'name': 'final_result_LocationB'}}, + ], + }, + } + + +@pytest.mark.parametrize( + 'tool_choice,expected', + [ + pytest.param('none', 'none', id='none'), + pytest.param('auto', 'auto', id='auto'), + pytest.param('required', 'required', id='required'), + ], +) +async def test_responses_tool_choice_string_values(allow_model_requests: None, tool_choice: str, expected: str) -> None: + """Ensure Responses tool_choice strings pass through untouched.""" + mock_client = MockOpenAIResponses.create_mock( + response_message( + [ + ResponseOutputMessage( + id='msg_123', + content=[ResponseOutputText(text='ok', type='output_text', annotations=[])], + role='assistant', + status='completed', + type='message', + ) + ] + ) + ) + model = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(model) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': tool_choice}) # type: ignore + + kwargs = get_mock_responses_kwargs(mock_client) + assert kwargs[0]['tool_choice'] == expected + + +async def test_responses_tool_choice_specific_tool_single(allow_model_requests: None) -> None: + """Force a single tool when using the Responses API.""" + mock_client = MockOpenAIResponses.create_mock( + response_message( + [ + ResponseOutputMessage( + id='msg_123', + content=[ResponseOutputText(text='ok', type='output_text', annotations=[])], + 
role='assistant', + status='completed', + type='message', + ) + ] + ) + ) + model = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(model) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + @agent.tool_plain + def other_tool(y: str) -> str: + return y # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': ['my_tool']}) + + kwargs = get_mock_responses_kwargs(mock_client) + assert kwargs[0]['tool_choice'] == {'type': 'function', 'name': 'my_tool'} + + +async def test_responses_tool_choice_specific_tool_multiple(allow_model_requests: None) -> None: + """Multiple Responses tools rely on the allowed_tools payload.""" + mock_client = MockOpenAIResponses.create_mock( + response_message( + [ + ResponseOutputMessage( + id='msg_123', + content=[ResponseOutputText(text='ok', type='output_text', annotations=[])], + role='assistant', + status='completed', + type='message', + ) + ] + ) + ) + model = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(model) + + @agent.tool_plain + def tool_a(x: int) -> str: + return str(x) # pragma: no cover + + @agent.tool_plain + def tool_b(y: str) -> str: + return y # pragma: no cover + + @agent.tool_plain + def tool_c(z: float) -> str: + return str(z) # pragma: no cover + + await agent.run('hello', model_settings={'tool_choice': ['tool_a', 'tool_b']}) + + kwargs = get_mock_responses_kwargs(mock_client) + assert kwargs[0]['tool_choice'] == { + 'type': 'allowed_tools', + 'mode': 'auto', + 'tools': [{'type': 'function', 'name': 'tool_a'}, {'type': 'function', 'name': 'tool_b'}], + } + + +async def test_responses_tool_choice_none_with_output_tools_warns(allow_model_requests: None) -> None: + """tool_choice='none' cannot disable required Responses output tools.""" + + class Location(BaseModel): + city: str + country: str + + mock_client = MockOpenAIResponses.create_mock( + response_message( + [ + ResponseFunctionToolCall( + id='call_123', + call_id='call_123', + name='final_result', + arguments='{"city": "Paris", "country": "France"}', + type='function_call', + status='completed', + ) + ] + ) + ) + model = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(model, output_type=Location) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): + result = await agent.run('hello', model_settings={'tool_choice': 'none'}) + + assert result.output == Location(city='Paris', country='France') + kwargs = get_mock_responses_kwargs(mock_client) + assert kwargs[0]['tool_choice'] == {'type': 'function', 'name': 'final_result'} + + +async def test_responses_tool_choice_none_with_multiple_output_tools(allow_model_requests: None) -> None: + """Multiple Responses output tools still use allowed_tools when forced to 'none'.""" + + class LocationA(BaseModel): + city: str + + class LocationB(BaseModel): + country: str + + mock_client = MockOpenAIResponses.create_mock( + response_message( + [ + ResponseFunctionToolCall( + id='call_123', + call_id='call_123', + name='final_result_LocationA', + arguments='{"city": "Paris"}', + type='function_call', + status='completed', + ) + ] + ) + ) + model = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent: Agent[None, LocationA | LocationB] = Agent(model, output_type=[LocationA, 
LocationB]) + + @agent.tool_plain + def my_tool(x: int) -> str: + return str(x) # pragma: no cover + + with pytest.warns(UserWarning, match="tool_choice='none' is set but output tools are required"): + result = await agent.run('hello', model_settings={'tool_choice': 'none'}) + + assert result.output == LocationA(city='Paris') + kwargs = get_mock_responses_kwargs(mock_client) + assert kwargs[0]['tool_choice'] == { + 'type': 'allowed_tools', + 'mode': 'required', + 'tools': [ + {'type': 'function', 'name': 'final_result_LocationA'}, + {'type': 'function', 'name': 'final_result_LocationB'}, + ], + } async def test_openai_custom_reasoning_field_sending_back_in_thinking_tags(allow_model_requests: None): diff --git a/tests/models/test_resolve_tool_choice.py b/tests/models/test_resolve_tool_choice.py new file mode 100644 index 0000000000..5a6f5faf6c --- /dev/null +++ b/tests/models/test_resolve_tool_choice.py @@ -0,0 +1,154 @@ +"""Tests for the centralized `resolve_tool_choice()` function. + +These tests cover the common logic shared across all providers: +- String value resolution ('none', 'auto', 'required') +- List[str] validation and resolution +- Warning emission for tool_choice='none' with output tools +- Invalid tool name detection + +Provider-specific tests (API format mapping) remain in their respective test files. +""" + +from __future__ import annotations + +import warnings + +import pytest +from inline_snapshot import snapshot + +from pydantic_ai.exceptions import UserError +from pydantic_ai.models import ( + ModelRequestParameters, + ResolvedToolChoice, + resolve_tool_choice, +) +from pydantic_ai.settings import ModelSettings +from pydantic_ai.tools import ToolDefinition + + +def make_tool(name: str) -> ToolDefinition: + """Return a minimal `ToolDefinition` used throughout the tests.""" + return ToolDefinition( + name=name, + description=f'Tool {name}', + parameters_json_schema={'type': 'object', 'properties': {}}, + ) + + +class TestResolveToolChoiceNone: + """Cases where `tool_choice` is unset in the settings.""" + + def test_none_model_settings_returns_none(self) -> None: + """`resolve_tool_choice` returns None when `model_settings` is None.""" + params = ModelRequestParameters() + result = resolve_tool_choice(None, params) + assert result is None + + def test_empty_model_settings_returns_none(self) -> None: + """Empty `model_settings` dict should also yield None.""" + params = ModelRequestParameters() + settings: ModelSettings = {} + result = resolve_tool_choice(settings, params) + assert result is None + + def test_tool_choice_not_set_returns_none(self) -> None: + """`tool_choice` missing from settings keeps provider defaults.""" + params = ModelRequestParameters() + settings: ModelSettings = {'temperature': 0.5} + result = resolve_tool_choice(settings, params) + assert result is None + + +class TestResolveToolChoiceStringValues: + """String-valued `tool_choice` entries.""" + + @pytest.mark.parametrize( + 'tool_choice,expected', + [ + pytest.param('none', snapshot(ResolvedToolChoice(mode='none')), id='none'), + pytest.param('auto', snapshot(ResolvedToolChoice(mode='auto')), id='auto'), + pytest.param('required', snapshot(ResolvedToolChoice(mode='required')), id='required'), + ], + ) + def test_string_values(self, tool_choice: str, expected: ResolvedToolChoice) -> None: + """Valid string entries map directly to their resolved form.""" + params = ModelRequestParameters(function_tools=[make_tool('my_tool')]) + settings: ModelSettings = {'tool_choice': tool_choice} # type: ignore + 
result = resolve_tool_choice(settings, params) + assert result == expected + + +class TestResolveToolChoiceSpecificTools: + """List-based tool_choice entries.""" + + def test_single_valid_tool(self) -> None: + """Single tool names remain in the returned result.""" + params = ModelRequestParameters(function_tools=[make_tool('tool_a'), make_tool('tool_b')]) + settings: ModelSettings = {'tool_choice': ['tool_a']} + result = resolve_tool_choice(settings, params) + assert result == snapshot(ResolvedToolChoice(mode='specific', tool_names=['tool_a'])) + + def test_multiple_valid_tools(self) -> None: + """Multiple valid names stay in insertion order.""" + params = ModelRequestParameters(function_tools=[make_tool('tool_a'), make_tool('tool_b'), make_tool('tool_c')]) + settings: ModelSettings = {'tool_choice': ['tool_a', 'tool_b']} + result = resolve_tool_choice(settings, params) + assert result == snapshot(ResolvedToolChoice(mode='specific', tool_names=['tool_a', 'tool_b'])) + + def test_invalid_tool_name_raises_user_error(self) -> None: + """Unknown names raise a UserError.""" + params = ModelRequestParameters(function_tools=[make_tool('my_tool')]) + settings: ModelSettings = {'tool_choice': ['nonexistent_tool']} + + with pytest.raises(UserError, match='Invalid tool names in tool_choice'): + resolve_tool_choice(settings, params) + + def test_mixed_valid_and_invalid_tools(self) -> None: + """Mixed valid/invalid names still raise.""" + params = ModelRequestParameters(function_tools=[make_tool('valid_tool')]) + settings: ModelSettings = {'tool_choice': ['valid_tool', 'invalid_tool']} + + with pytest.raises(UserError, match='invalid_tool'): + resolve_tool_choice(settings, params) + + def test_no_function_tools_available(self) -> None: + """Requesting specific tools without registered ones errors.""" + params = ModelRequestParameters() + settings: ModelSettings = {'tool_choice': ['some_tool']} + + with pytest.raises(UserError, match='Available function tools: none'): + resolve_tool_choice(settings, params) + + def test_empty_list_raises_user_error(self) -> None: + """Empty lists are not allowed.""" + params = ModelRequestParameters(function_tools=[make_tool('my_tool')]) + settings: ModelSettings = {'tool_choice': []} + + with pytest.raises(UserError, match='tool_choice cannot be an empty list'): + resolve_tool_choice(settings, params) + + +class TestResolveToolChoiceOutputToolsWarning: + """Safety checks when `tool_choice='none'` conflicts with output tools.""" + + def test_none_with_output_tools_warns(self) -> None: + """`tool_choice='none'` issues a warning when output tools exist.""" + output_tool = make_tool('final_result') + params = ModelRequestParameters(output_tools=[output_tool]) + settings: ModelSettings = {'tool_choice': 'none'} + + with pytest.warns(UserWarning, match='tool_choice=.none. is set but output tools are required'): + result = resolve_tool_choice(settings, params, stacklevel=2) + + assert result == snapshot(ResolvedToolChoice(mode='none', output_tools_fallback=True)) + + def test_none_without_output_tools_no_warning(self) -> None: + """No warning when `tool_choice='none'` and no output tools exist.""" + params = ModelRequestParameters(function_tools=[make_tool('my_tool')]) + settings: ModelSettings = {'tool_choice': 'none'} + + with warnings.catch_warnings(): + warnings.simplefilter('error') + result = resolve_tool_choice(settings, params, stacklevel=2) + + assert result == snapshot(ResolvedToolChoice(mode='none'))
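
For reference, a minimal usage sketch (not part of the diff) of how the centralized helper resolves a setting, driven the same way as the new unit tests in tests/models/test_resolve_tool_choice.py above; the tool names and settings values here are illustrative only:

    from pydantic_ai.models import ModelRequestParameters, resolve_tool_choice
    from pydantic_ai.settings import ModelSettings
    from pydantic_ai.tools import ToolDefinition

    # Two plain function tools, mirroring the fixtures used in the tests above.
    tools = [
        ToolDefinition(
            name=name,
            description=f'Tool {name}',
            parameters_json_schema={'type': 'object', 'properties': {}},
        )
        for name in ('tool_a', 'tool_b')
    ]
    params = ModelRequestParameters(function_tools=tools)

    # Forcing a subset of tools resolves to mode='specific' with the validated names;
    # each provider then maps this to its own API payload.
    settings: ModelSettings = {'tool_choice': ['tool_a']}
    resolved = resolve_tool_choice(settings, params)
    assert resolved is not None
    assert resolved.mode == 'specific' and resolved.tool_names == ['tool_a']

    # Leaving tool_choice unset returns None, so providers keep their default behavior.
    assert resolve_tool_choice({}, params) is None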