From 53a3930077bc9d32e906c653f7655b9f048011b5 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Thu, 27 Nov 2025 18:10:12 +0100 Subject: [PATCH 1/5] add structured output support for hugging face --- docs/install.md | 2 +- .../pydantic_ai/models/huggingface.py | 50 +++++-- .../pydantic_ai/providers/huggingface.py | 140 +++++++++++++++++- tests/models/test_huggingface.py | 45 ++++++ uv.lock | 80 +++------- 5 files changed, 235 insertions(+), 82 deletions(-) diff --git a/docs/install.md b/docs/install.md index 77ff4e56c1..e91ab8723c 100644 --- a/docs/install.md +++ b/docs/install.md @@ -50,7 +50,7 @@ pip/uv-add "pydantic-ai-slim[openai]" * `mistral` — installs [Mistral Model](models/mistral.md) dependency `mistralai` [PyPI ↗](https://pypi.org/project/mistralai){:target="_blank"} * `cohere` - installs [Cohere Model](models/cohere.md) dependency `cohere` [PyPI ↗](https://pypi.org/project/cohere){:target="_blank"} * `bedrock` - installs [Bedrock Model](models/bedrock.md) dependency `boto3` [PyPI ↗](https://pypi.org/project/boto3){:target="_blank"} -* `huggingface` - installs [Hugging Face Model](models/huggingface.md) dependency `huggingface-hub[inference]` [PyPI ↗](https://pypi.org/project/huggingface-hub){:target="_blank"} +* `huggingface` - installs [Hugging Face Model](models/huggingface.md) dependency `huggingface-hub` [PyPI ↗](https://pypi.org/project/huggingface-hub){:target="_blank"} * `outlines-transformers` - installs [Outlines Model](models/outlines.md) dependency `outlines[transformers]` [PyPI ↗](https://pypi.org/project/outlines){:target="_blank"} * `outlines-llamacpp` - installs [Outlines Model](models/outlines.md) dependency `outlines[llamacpp]` [PyPI ↗](https://pypi.org/project/outlines){:target="_blank"} * `outlines-mlxlm` - installs [Outlines Model](models/outlines.md) dependency `outlines[mlxlm]` [PyPI ↗](https://pypi.org/project/outlines){:target="_blank"} diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 790b30bec3..4762f3cac4 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -9,6 +9,7 @@ from typing_extensions import assert_never from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition from .._run_context import RunContext from .._thinking_part import split_content_into_text_and_thinking from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc @@ -41,19 +42,17 @@ from ..providers import Provider, infer_provider from ..settings import ModelSettings from ..tools import ToolDefinition -from . import ( - Model, - ModelRequestParameters, - StreamedResponse, - check_allow_model_requests, -) +from . 
import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests try: import aiohttp from huggingface_hub import ( AsyncInferenceClient, + ChatCompletionInputGrammarType, ChatCompletionInputMessage, ChatCompletionInputMessageChunk, + ChatCompletionInputResponseFormatJSONObject, + ChatCompletionInputResponseFormatJSONSchema, ChatCompletionInputTool, ChatCompletionInputToolCall, ChatCompletionInputURL, @@ -80,10 +79,11 @@ LatestHuggingFaceModelNames = Literal[ 'deepseek-ai/DeepSeek-R1', - 'meta-llama/Llama-3.3-70B-Instruct', - 'meta-llama/Llama-4-Maverick-17B-128E-Instruct', - 'meta-llama/Llama-4-Scout-17B-16E-Instruct', - 'Qwen/QwQ-32B', + 'meta-llama/Llama-3.1-8B-Instruct', + 'MiniMaxAI/MiniMax-M2', + 'openai/gpt-oss-20b', + 'openai/gpt-oss-120b', + 'Qwen/Qwen3-Coder-30B-A3B-Instruct', 'Qwen/Qwen2.5-72B-Instruct', 'Qwen/Qwen3-235B-A22B', 'Qwen/Qwen3-32B', @@ -142,13 +142,14 @@ def __init__( profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. settings: Model-specific settings that will be used as defaults for this model. """ - self._model_name = model_name if isinstance(provider, str): provider = infer_provider(provider) self._provider = provider - self.client = provider.client + provider_profile = provider.model_profile(model_name) + self._model_name = model_name.rsplit(':', 1)[0] - super().__init__(settings=settings, profile=profile or provider.model_profile) + super().__init__(settings=settings, profile=profile or provider_profile) + self.client = provider.client @property def model_name(self) -> HuggingFaceModelName: @@ -233,6 +234,15 @@ async def _completions_create( raise UserError('HuggingFace does not support built-in tools') hf_messages = await self._map_messages(messages, model_request_parameters) + response_format: ChatCompletionInputGrammarType | None = None + if model_request_parameters.output_mode == 'native': + output_object = model_request_parameters.output_object + assert output_object is not None + response_format = self._map_json_schema(output_object) + elif ( + model_request_parameters.output_mode == 'prompted' and self.profile.supports_json_object_output + ): # pragma: no branch + response_format = ChatCompletionInputResponseFormatJSONObject.parse_obj_as_instance({'type': 'json_object'}) # type: ignore try: return await self.client.chat.completions.create( # type: ignore @@ -245,6 +255,7 @@ async def _completions_create( temperature=model_settings.get('temperature', None), top_p=model_settings.get('top_p', None), seed=model_settings.get('seed', None), + response_format=response_format or None, presence_penalty=model_settings.get('presence_penalty', None), frequency_penalty=model_settings.get('frequency_penalty', None), logit_bias=model_settings.get('logit_bias', None), # type: ignore @@ -377,6 +388,19 @@ def _map_tool_call(t: ToolCallPart) -> ChatCompletionInputToolCall: } ) + def _map_json_schema(self, o: OutputObjectDefinition) -> ChatCompletionInputGrammarType: + response_format_param: ChatCompletionInputResponseFormatJSONSchema = { # type: ignore + 'type': 'json_schema', + 'json_schema': { + 'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, + 'schema': o.json_schema, + 'strict': o.strict, + }, + } + if o.description: # pragma: no branch + response_format_param['json_schema']['description'] = o.description + return response_format_param + @staticmethod def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool: tool_param: ChatCompletionInputTool = 
ChatCompletionInputTool.parse_obj_as_instance( # type: ignore diff --git a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py index 45e1f9ef78..efc38b1773 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py @@ -1,9 +1,14 @@ from __future__ import annotations as _annotations import os +from dataclasses import replace +from functools import lru_cache from typing import overload +import httpx from httpx import AsyncClient +from pydantic import TypeAdapter, ValidationError +from typing_extensions import TypedDict from pydantic_ai import ModelProfile from pydantic_ai.exceptions import UserError @@ -14,6 +19,8 @@ from pydantic_ai.profiles.moonshotai import moonshotai_model_profile from pydantic_ai.profiles.qwen import qwen_model_profile +from . import Provider + try: from huggingface_hub import AsyncInferenceClient except ImportError as _import_error: # pragma: no cover @@ -22,7 +29,74 @@ "you can use the `huggingface` optional group — `pip install 'pydantic-ai-slim[huggingface]'`" ) from _import_error -from . import Provider + +HF_ROUTER_MODELS_URL = 'https://router.huggingface.co/v1/models' + + +class HfRouterModel(TypedDict): + """Hugging Face router model definition.""" + + id: str + + +class HfRouterResponse(TypedDict): + """Hugging Face router response.""" + + data: list[HfRouterModel] + + +class HfRouterProvider(TypedDict): + """Hugging Face router provider definition.""" + + provider: str + status: str + supports_tools: bool + supports_structured_output: bool + + +class HfRouterModelInfo(TypedDict): + """Hugging Face router model info.""" + + id: str + providers: list[HfRouterProvider] + + +class HfRouterResponseData(TypedDict): + """Hugging Face router response data.""" + + data: HfRouterModelInfo + + +@lru_cache(maxsize=128) +def _get_router_info(model_id: str) -> HfRouterModelInfo | None: + try: + resp = httpx.get(f'{HF_ROUTER_MODELS_URL}/{model_id}', timeout=5.0, follow_redirects=True) + if resp.status_code != 200: + return None + payload = TypeAdapter(HfRouterResponseData).validate_json(resp.content) + return payload['data'] + except (httpx.HTTPError, ValidationError, Exception): + return None + + +def select_provider(providers: list[HfRouterProvider]) -> HfRouterProvider | None: + """Select the best provider based on capabilities.""" + live_providers = [p for p in providers if p['status'] == 'live'] + if not live_providers: + live_providers = providers + + # 1 - supports_tools=True AND supports_structured_output=True + both = [p for p in live_providers if p['supports_tools'] and p['supports_structured_output']] + if both: + return both[0] + + # 2 - supports_tools=True OR supports_structured_output=True + either = [p for p in live_providers if p['supports_tools'] or p['supports_structured_output']] + if either: + return either[0] + + # 3 - Any + return live_providers[0] if live_providers else None class HuggingFaceProvider(Provider[AsyncInferenceClient]): @@ -53,12 +127,58 @@ def model_profile(self, model_name: str) -> ModelProfile | None: if '/' not in model_name: return None + provider_override = None + if ':' in model_name: + model_name, provider_override = model_name.rsplit(':', 1) + model_name = model_name.lower() - provider, model_name = model_name.split('/', 1) - if provider in provider_to_profile: - return provider_to_profile[provider](model_name) + model_prefix, _ = model_name.split('/', 1) + + base_profile: ModelProfile | None = None + if 
model_prefix in provider_to_profile: + base_profile = provider_to_profile[model_prefix](model_name) + + # fetch model capabilities + router_info = _get_router_info(model_name) + + selected_provider_info: HfRouterProvider | None = None + if router_info: + providers = router_info['providers'] + if provider_override: + # Find specific provider from model suffix + for p in providers: + if p['provider'] == provider_override: + selected_provider_info = p + break + elif self._provider: + for p in providers: + if p['provider'] == self._provider: + selected_provider_info = p + break + if selected_provider_info is None: + selected_provider_info = select_provider(providers) + else: + # Auto select using router preference + selected_provider_info = select_provider(providers) + + if selected_provider_info: + if base_profile is None: + base_profile = ModelProfile() + + # Update the client to use the selected provider + self._client = AsyncInferenceClient( + token=self.api_key, + provider=selected_provider_info['provider'], # type: ignore + ) - return None + return replace( + base_profile, + supports_tools=selected_provider_info['supports_tools'], + supports_json_schema_output=selected_provider_info['supports_structured_output'], + supports_json_object_output=selected_provider_info['supports_structured_output'], + ) + + return base_profile @overload def __init__(self, *, base_url: str, api_key: str | None = None) -> None: ... @@ -96,8 +216,12 @@ def __init__( If `base_url` is passed, then `provider_name` is not used. """ api_key = api_key or os.getenv('HF_TOKEN') + if api_key is None and hf_client is not None: + api_key = getattr(hf_client, 'token', None) - if api_key is None: + self.api_key = api_key + + if self.api_key is None: raise UserError( 'Set the `HF_TOKEN` environment variable or pass it via `HuggingFaceProvider(api_key=...)`' 'to use the HuggingFace provider.' 
@@ -110,6 +234,8 @@ def __init__( raise ValueError('Cannot provide both `base_url` and `provider_name`.') if hf_client is None: - self._client = AsyncInferenceClient(api_key=api_key, provider=provider_name, base_url=base_url) # type: ignore + self._client = AsyncInferenceClient(api_key=self.api_key, provider=provider_name, base_url=base_url) # type: ignore else: self._client = hf_client + + self._provider: str | None = provider_name diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 56d74ed619..e62cc5c40e 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -51,6 +51,8 @@ ) from pydantic_ai.exceptions import ModelHTTPError from pydantic_ai.models.huggingface import HuggingFaceModel +from pydantic_ai.output import NativeOutput +from pydantic_ai.profiles import ModelProfile from pydantic_ai.providers.huggingface import HuggingFaceProvider from pydantic_ai.result import RunUsage from pydantic_ai.run import AgentRunResult, AgentRunResultEvent @@ -1026,3 +1028,46 @@ async def test_cache_point_filtering(): # CachePoint should be filtered out assert msg['role'] == 'user' assert len(msg['content']) == 1 # pyright: ignore[reportUnknownArgumentType] + + +async def test_native_output_structured_response(allow_model_requests: None): + completion = completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance( # type: ignore + { + 'content': '{"first": "One", "second": "Two"}', + 'role': 'assistant', + } + ) + ) + mock_client = MockHuggingFace.create_mock(completion) + model = HuggingFaceModel( + 'hf-model', + provider=HuggingFaceProvider(hf_client=mock_client, api_key='x'), + profile=ModelProfile( + supports_json_schema_output=True, + supports_json_object_output=True, + ), + ) + agent = Agent( + model, + output_type=NativeOutput( + MyTypedDict, + name='final_result', + description='Return the first and second values.', + ), + ) + + result = await agent.run('Hello') + assert result.output == snapshot({'first': 'One', 'second': 'Two'}) + + kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + response_format = kwargs['response_format'] + assert response_format is not None + assert response_format['type'] == 'json_schema' + json_schema = response_format['json_schema'] + assert json_schema['name'] == 'final_result' + assert json_schema['description'] == 'Return the first and second values.' 
+ schema = json_schema['schema'] + assert schema['type'] == 'object' + assert schema['properties']['first']['type'] == 'string' + assert schema['properties']['second']['type'] == 'string' diff --git a/uv.lock b/uv.lock index f91cf46244..e571a220e2 100644 --- a/uv.lock +++ b/uv.lock @@ -27,8 +27,7 @@ name = "accelerate" version = "1.11.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "huggingface-hub", version = "0.33.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, - { name = "huggingface-hub", version = "0.36.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, + { name = "huggingface-hub" }, { name = "numpy" }, { name = "packaging" }, { name = "psutil" }, @@ -1338,8 +1337,7 @@ dependencies = [ { name = "dill" }, { name = "filelock" }, { name = "fsspec", extra = ["http"] }, - { name = "huggingface-hub", version = "0.33.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, - { name = "huggingface-hub", version = "0.36.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, + { name = "huggingface-hub" }, { name = "multiprocess" }, { name = "numpy" }, { name = "packaging" }, @@ -2037,8 +2035,7 @@ dependencies = [ { name = "ffmpy" }, { name = "gradio-client" }, { name = "httpx" }, - { name = "huggingface-hub", version = "0.33.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, - { name = "huggingface-hub", version = "0.36.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, + { name = "huggingface-hub" }, { name = "jinja2" }, { name = "markupsafe" }, { name = "numpy" }, @@ -2071,8 +2068,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "fsspec" }, { name = "httpx" }, - { name = "huggingface-hub", version = "0.33.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, - { name = "huggingface-hub", version = "0.36.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, + { name = "huggingface-hub" }, { name = "packaging" }, { name = "typing-extensions" }, { name = "websockets" }, @@ -2334,55 +2330,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e1/9b/a181f281f65d776426002f330c31849b86b31fc9d848db62e16f03ff739f/httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f", size = 7819, upload-time = "2023-12-22T08:01:19.89Z" }, ] -[[package]] -name = "huggingface-hub" -version = "0.33.5" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.13' and platform_python_implementation == 'PyPy'", - "python_full_version >= '3.13' and platform_python_implementation != 'PyPy'", - "python_full_version == '3.12.*' and platform_python_implementation == 'PyPy'", - "python_full_version == '3.12.*' and platform_python_implementation != 'PyPy'", -] -dependencies = [ - { name = "filelock", marker = "python_full_version >= '3.12'" }, - { name = "fsspec", marker = "python_full_version >= '3.12'" }, - { name = "hf-xet", marker = "(python_full_version >= '3.12' and platform_machine == 'aarch64') or (python_full_version >= '3.12' and platform_machine == 'amd64') or (python_full_version >= '3.12' and platform_machine == 'arm64') or (python_full_version >= '3.12' and platform_machine 
== 'x86_64')" }, - { name = "packaging", marker = "python_full_version >= '3.12'" }, - { name = "pyyaml", marker = "python_full_version >= '3.12'" }, - { name = "requests", marker = "python_full_version >= '3.12'" }, - { name = "tqdm", marker = "python_full_version >= '3.12'" }, - { name = "typing-extensions", marker = "python_full_version >= '3.12'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/02/16/5716d03e2b48bcc8e32d9b18ed7e55d2ae52e3d5df146cced9fe0581b5ff/huggingface_hub-0.33.5.tar.gz", hash = "sha256:814097e475646d170c44be4c38f7d381ccc4539156a5ac62a54f53aaf1602ed8", size = 427075, upload-time = "2025-07-24T12:30:31.449Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/33/d5/d9e9b75d8dc9cf125fff16fb0cd51d864a29e8b46b6880d8808940989405/huggingface_hub-0.33.5-py3-none-any.whl", hash = "sha256:29b4e64982c2064006021af297e1b17d44c85a8aaf90a0d7efeff7e7d2426296", size = 515705, upload-time = "2025-07-24T12:30:29.55Z" }, -] - -[package.optional-dependencies] -inference = [ - { name = "aiohttp", marker = "python_full_version >= '3.12'" }, -] - [[package]] name = "huggingface-hub" version = "0.36.0" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version == '3.11.*' and platform_python_implementation == 'PyPy'", - "python_full_version == '3.11.*' and platform_python_implementation != 'PyPy'", - "python_full_version < '3.11' and platform_python_implementation != 'PyPy'", - "python_full_version < '3.11' and platform_python_implementation == 'PyPy'", -] dependencies = [ - { name = "filelock", marker = "python_full_version < '3.12'" }, - { name = "fsspec", marker = "python_full_version < '3.12'" }, - { name = "hf-xet", marker = "(python_full_version < '3.12' and platform_machine == 'aarch64') or (python_full_version < '3.12' and platform_machine == 'amd64') or (python_full_version < '3.12' and platform_machine == 'arm64') or (python_full_version < '3.12' and platform_machine == 'x86_64')" }, - { name = "packaging", marker = "python_full_version < '3.12'" }, - { name = "pyyaml", marker = "python_full_version < '3.12'" }, - { name = "requests", marker = "python_full_version < '3.12'" }, - { name = "tqdm", marker = "python_full_version < '3.12'" }, - { name = "typing-extensions", marker = "python_full_version < '3.12'" }, + { name = "filelock" }, + { name = "fsspec" }, + { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "tqdm" }, + { name = "typing-extensions" }, ] sdist = { url = "https://files.pythonhosted.org/packages/98/63/4910c5fa9128fdadf6a9c5ac138e8b1b6cee4ca44bf7915bbfbce4e355ee/huggingface_hub-0.36.0.tar.gz", hash = "sha256:47b3f0e2539c39bf5cde015d63b72ec49baff67b6931c3d97f3f84532e2b8d25", size = 463358, upload-time = "2025-10-23T12:12:01.413Z" } wheels = [ @@ -2391,7 +2351,7 @@ wheels = [ [package.optional-dependencies] inference = [ - { name = "aiohttp", marker = "python_full_version < '3.12'" }, + { name = "aiohttp" }, ] [[package]] @@ -4452,8 +4412,7 @@ wheels = [ [package.optional-dependencies] llamacpp = [ - { name = "huggingface-hub", version = "0.33.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, - { name = "huggingface-hub", version = "0.36.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, + { name = "huggingface-hub" }, { 
name = "llama-cpp-python" }, { name = "numba", version = "0.61.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, { name = "numba", version = "0.62.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, @@ -5594,8 +5553,7 @@ groq = [ { name = "groq" }, ] huggingface = [ - { name = "huggingface-hub", version = "0.33.5", source = { registry = "https://pypi.org/simple" }, extra = ["inference"], marker = "python_full_version >= '3.12'" }, - { name = "huggingface-hub", version = "0.36.0", source = { registry = "https://pypi.org/simple" }, extra = ["inference"], marker = "python_full_version < '3.12'" }, + { name = "huggingface-hub", extra = ["inference"] }, ] logfire = [ { name = "logfire", extra = ["httpx"] }, @@ -7459,7 +7417,7 @@ resolution-markers = [ "python_full_version == '3.12.*' and platform_python_implementation != 'PyPy'", ] dependencies = [ - { name = "huggingface-hub", version = "0.33.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "huggingface-hub", marker = "python_full_version >= '3.12'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/20/41/c2be10975ca37f6ec40d7abd7e98a5213bb04f284b869c1a24e6504fd94d/tokenizers-0.21.0.tar.gz", hash = "sha256:ee0894bf311b75b0c03079f33859ae4b2334d675d4e93f5a4132e1eae2834fe4", size = 343021, upload-time = "2024-11-27T13:11:23.89Z" } wheels = [ @@ -7490,7 +7448,7 @@ resolution-markers = [ "python_full_version < '3.11' and platform_python_implementation == 'PyPy'", ] dependencies = [ - { name = "huggingface-hub", version = "0.36.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, + { name = "huggingface-hub", marker = "python_full_version < '3.12'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/1c/46/fb6854cec3278fbfa4a75b50232c77622bc517ac886156e6afbfa4d8fc6e/tokenizers-0.22.1.tar.gz", hash = "sha256:61de6522785310a309b3407bac22d99c4db5dba349935e99e4d15ea2226af2d9", size = 363123, upload-time = "2025-09-19T09:49:23.424Z" } wheels = [ @@ -7776,7 +7734,7 @@ resolution-markers = [ ] dependencies = [ { name = "filelock", marker = "python_full_version >= '3.12'" }, - { name = "huggingface-hub", version = "0.33.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "huggingface-hub", marker = "python_full_version >= '3.12'" }, { name = "numpy", marker = "python_full_version >= '3.12'" }, { name = "packaging", marker = "python_full_version >= '3.12'" }, { name = "pyyaml", marker = "python_full_version >= '3.12'" }, @@ -7803,7 +7761,7 @@ resolution-markers = [ ] dependencies = [ { name = "filelock", marker = "python_full_version < '3.12'" }, - { name = "huggingface-hub", version = "0.36.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, + { name = "huggingface-hub", marker = "python_full_version < '3.12'" }, { name = "numpy", marker = "python_full_version < '3.12'" }, { name = "packaging", marker = "python_full_version < '3.12'" }, { name = "pyyaml", marker = "python_full_version < '3.12'" }, From 5578e04adbc0d1d80ce5d5f7bcdd69942ea8b74a Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Thu, 27 Nov 2025 18:21:52 +0100 Subject: [PATCH 2/5] better --- .../pydantic_ai/providers/huggingface.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py 
b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py index efc38b1773..a19f814dab 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py @@ -127,10 +127,6 @@ def model_profile(self, model_name: str) -> ModelProfile | None: if '/' not in model_name: return None - provider_override = None - if ':' in model_name: - model_name, provider_override = model_name.rsplit(':', 1) - model_name = model_name.lower() model_prefix, _ = model_name.split('/', 1) @@ -144,13 +140,7 @@ def model_profile(self, model_name: str) -> ModelProfile | None: selected_provider_info: HfRouterProvider | None = None if router_info: providers = router_info['providers'] - if provider_override: - # Find specific provider from model suffix - for p in providers: - if p['provider'] == provider_override: - selected_provider_info = p - break - elif self._provider: + if self._provider: for p in providers: if p['provider'] == self._provider: selected_provider_info = p From d50e3cc17adb6a3e6b8d1f0fe8d216eb28a34de2 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Thu, 27 Nov 2025 18:28:46 +0100 Subject: [PATCH 3/5] add logging --- pydantic_ai_slim/pydantic_ai/providers/huggingface.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py index a19f814dab..a2a37b852a 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import logging import os from dataclasses import replace from functools import lru_cache @@ -29,6 +30,7 @@ "you can use the `huggingface` optional group — `pip install 'pydantic-ai-slim[huggingface]'`" ) from _import_error +_logger = logging.getLogger(__name__) HF_ROUTER_MODELS_URL = 'https://router.huggingface.co/v1/models' @@ -161,6 +163,14 @@ def model_profile(self, model_name: str) -> ModelProfile | None: provider=selected_provider_info['provider'], # type: ignore ) + provider_name = selected_provider_info['provider'] + if not selected_provider_info['supports_structured_output']: + _logger.warning( + f'Provider {provider_name} does not support structured output (NativeOutput).', + ) + if not selected_provider_info['supports_tools']: + _logger.warning(f"Provider '{provider_name}' does not support tools.") + return replace( base_profile, supports_tools=selected_provider_info['supports_tools'], From bfb835e7ddfa458efa40b1a41dd553f8d70a552d Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Thu, 27 Nov 2025 19:20:35 +0100 Subject: [PATCH 4/5] more tests --- .../pydantic_ai/models/__init__.py | 9 +- .../pydantic_ai/providers/huggingface.py | 4 +- tests/models/test_huggingface.py | 33 ++- tests/providers/test_huggingface.py | 244 +++++++++++++++++- 4 files changed, 280 insertions(+), 10 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index d7e9402b9e..6b52db4865 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -210,14 +210,15 @@ 'heroku:gpt-oss-120b', 'heroku:nova-lite', 'heroku:nova-pro', - 'huggingface:Qwen/QwQ-32B', + 'huggingface:MiniMaxAI/MiniMax-M2', 'huggingface:Qwen/Qwen2.5-72B-Instruct', 'huggingface:Qwen/Qwen3-235B-A22B', 'huggingface:Qwen/Qwen3-32B', + 'huggingface:Qwen/Qwen3-Coder-30B-A3B-Instruct', 'huggingface:deepseek-ai/DeepSeek-R1', 
- 'huggingface:meta-llama/Llama-3.3-70B-Instruct', - 'huggingface:meta-llama/Llama-4-Maverick-17B-128E-Instruct', - 'huggingface:meta-llama/Llama-4-Scout-17B-16E-Instruct', + 'huggingface:meta-llama/Llama-3.1-8B-Instruct', + 'huggingface:openai/gpt-oss-120b', + 'huggingface:openai/gpt-oss-20b', 'mistral:codestral-latest', 'mistral:mistral-large-latest', 'mistral:mistral-moderation-latest', diff --git a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py index a2a37b852a..d402c97d35 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py @@ -130,11 +130,11 @@ def model_profile(self, model_name: str) -> ModelProfile | None: return None model_name = model_name.lower() - model_prefix, _ = model_name.split('/', 1) + model_prefix, model_suffix = model_name.split('/', 1) base_profile: ModelProfile | None = None if model_prefix in provider_to_profile: - base_profile = provider_to_profile[model_prefix](model_name) + base_profile = provider_to_profile[model_prefix](model_suffix) # fetch model capabilities router_info = _get_router_info(model_name) diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index e62cc5c40e..a30288f8ba 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -51,7 +51,7 @@ ) from pydantic_ai.exceptions import ModelHTTPError from pydantic_ai.models.huggingface import HuggingFaceModel -from pydantic_ai.output import NativeOutput +from pydantic_ai.output import NativeOutput, PromptedOutput from pydantic_ai.profiles import ModelProfile from pydantic_ai.providers.huggingface import HuggingFaceProvider from pydantic_ai.result import RunUsage @@ -1071,3 +1071,34 @@ async def test_native_output_structured_response(allow_model_requests: None): assert schema['type'] == 'object' assert schema['properties']['first']['type'] == 'string' assert schema['properties']['second']['type'] == 'string' + + +async def test_prompted_output_json_object_response(allow_model_requests: None): + """Test that prompted output uses json_object response format when supported.""" + completion = completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance( # type: ignore + { + 'content': '{"first": "One", "second": "Two"}', + 'role': 'assistant', + } + ) + ) + mock_client = MockHuggingFace.create_mock(completion) + model = HuggingFaceModel( + 'hf-model', + provider=HuggingFaceProvider(hf_client=mock_client, api_key='x'), + profile=ModelProfile( + supports_json_schema_output=False, + supports_json_object_output=True, + ), + ) + # Using PromptedOutput triggers 'prompted' mode + agent = Agent(model, output_type=PromptedOutput(MyTypedDict)) + + result = await agent.run('Hello') + assert result.output == snapshot({'first': 'One', 'second': 'Two'}) + + kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + response_format = kwargs['response_format'] + assert response_format is not None + assert response_format['type'] == 'json_object' diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py index 50c1c1e9e6..a58c92cd11 100644 --- a/tests/providers/test_huggingface.py +++ b/tests/providers/test_huggingface.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import logging import re from unittest.mock import MagicMock, Mock, patch @@ -14,14 +15,18 @@ from pydantic_ai.profiles.meta import meta_model_profile from pydantic_ai.profiles.mistral import mistral_model_profile 
from pydantic_ai.profiles.qwen import qwen_model_profile +from pydantic_ai.providers.huggingface import ( + HfRouterProvider, + HuggingFaceProvider, + _get_router_info, # type: ignore + select_provider, +) from ..conftest import TestEnv, try_import with try_import() as imports_successful: from huggingface_hub import AsyncInferenceClient - from pydantic_ai.providers.huggingface import HuggingFaceProvider - pytestmark = pytest.mark.skipif(not imports_successful(), reason='huggingface_hub not installed') @@ -150,10 +155,18 @@ def test_huggingface_provider_base_url(): def test_huggingface_provider_model_profile(mocker: MockerFixture): + from pydantic_ai.providers.huggingface import _get_router_info # type: ignore + + # Clear lru_cache before mocking to ensure no cached results interfere + _get_router_info.cache_clear() + + ns = 'pydantic_ai.providers.huggingface' + # Mock _get_router_info to return None (no network calls) + mocker.patch(f'{ns}._get_router_info', return_value=None) + mock_client = Mock(spec=AsyncInferenceClient) provider = HuggingFaceProvider(hf_client=mock_client, api_key='test-api-key') - ns = 'pydantic_ai.providers.huggingface' deepseek_model_profile_mock = mocker.patch(f'{ns}.deepseek_model_profile', wraps=deepseek_model_profile) meta_model_profile_mock = mocker.patch(f'{ns}.meta_model_profile', wraps=meta_model_profile) qwen_model_profile_mock = mocker.patch(f'{ns}.qwen_model_profile', wraps=qwen_model_profile) @@ -187,3 +200,228 @@ def test_huggingface_provider_model_profile(mocker: MockerFixture): unknown_profile = provider.model_profile('unknown/model') assert unknown_profile is None + + +def test_select_provider_both_capabilities(): + """Test select_provider prefers providers with both tools and structured output.""" + providers: list[HfRouterProvider] = [ + {'provider': 'p1', 'status': 'live', 'supports_tools': False, 'supports_structured_output': False}, + {'provider': 'p2', 'status': 'live', 'supports_tools': True, 'supports_structured_output': True}, + {'provider': 'p3', 'status': 'live', 'supports_tools': True, 'supports_structured_output': False}, + ] + result = select_provider(providers) + assert result is not None + assert result['provider'] == 'p2' + + +def test_select_provider_either_capability(): + """Test select_provider falls back to providers with either capability.""" + providers: list[HfRouterProvider] = [ + {'provider': 'p1', 'status': 'live', 'supports_tools': False, 'supports_structured_output': False}, + {'provider': 'p2', 'status': 'live', 'supports_tools': True, 'supports_structured_output': False}, + ] + result = select_provider(providers) + assert result is not None + assert result['provider'] == 'p2' + + +def test_select_provider_any(): + """Test select_provider falls back to any provider.""" + providers: list[HfRouterProvider] = [ + {'provider': 'p1', 'status': 'live', 'supports_tools': False, 'supports_structured_output': False}, + ] + result = select_provider(providers) + assert result is not None + assert result['provider'] == 'p1' + + +def test_select_provider_empty(): + """Test select_provider returns None for empty list.""" + result = select_provider([]) + assert result is None + + +def test_select_provider_no_live_fallback(): + """Test select_provider falls back to non-live providers if no live ones.""" + providers: list[HfRouterProvider] = [ + {'provider': 'p1', 'status': 'pending', 'supports_tools': True, 'supports_structured_output': True}, + ] + result = select_provider(providers) + assert result is not None + assert 
result['provider'] == 'p1' + + +def test_get_router_info_success(mocker: MockerFixture): + """Test _get_router_info successfully parses response.""" + _get_router_info.cache_clear() + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.content = b'{"data": {"id": "test/model", "providers": []}}' + + mocker.patch('pydantic_ai.providers.huggingface.httpx.get', return_value=mock_response) + + result = _get_router_info('test/model') + assert result is not None + assert result['id'] == 'test/model' + + +def test_get_router_info_http_error(mocker: MockerFixture): + """Test _get_router_info handles HTTP errors.""" + _get_router_info.cache_clear() + + mocker.patch('pydantic_ai.providers.huggingface.httpx.get', side_effect=httpx.HTTPError('error')) + + result = _get_router_info('test/model') + assert result is None + + +def test_get_router_info_non_200(mocker: MockerFixture): + """Test _get_router_info handles non-200 status.""" + _get_router_info.cache_clear() + + mock_response = Mock() + mock_response.status_code = 404 + + mocker.patch('pydantic_ai.providers.huggingface.httpx.get', return_value=mock_response) + + result = _get_router_info('test/model') + assert result is None + + +def test_model_profile_with_router_info(mocker: MockerFixture): + """Test model_profile uses router info to select provider and set capabilities. + + This also tests that when base_profile is None (unknown prefix), a new ModelProfile is created. + """ + _get_router_info.cache_clear() + + ns = 'pydantic_ai.providers.huggingface' + router_info = { + 'id': 'unknown/model', + 'providers': [ + {'provider': 'test-provider', 'status': 'live', 'supports_tools': True, 'supports_structured_output': True}, + ], + } + mocker.patch(f'{ns}._get_router_info', return_value=router_info) + mock_client_class = mocker.patch(f'{ns}.AsyncInferenceClient') + + mock_client = Mock(spec=AsyncInferenceClient) + provider = HuggingFaceProvider(hf_client=mock_client, api_key='test-api-key') + + # 'unknown' prefix doesn't match any known provider, so base_profile starts as None + # Router info is found, so a fresh ModelProfile is created + profile = provider.model_profile('unknown/model') + + assert profile is not None + assert profile.supports_tools is True + assert profile.supports_json_schema_output is True + assert profile.supports_json_object_output is True + # Verify the client was updated with the selected provider + mock_client_class.assert_called_with(token='test-api-key', provider='test-provider') + # Verify the provider's client was updated + assert provider.client is mock_client_class.return_value + + +def test_model_profile_with_provider_name_override(mocker: MockerFixture): + """Test model_profile respects provider_name override.""" + _get_router_info.cache_clear() + + ns = 'pydantic_ai.providers.huggingface' + router_info = { + 'id': 'unknown/model', + 'providers': [ + {'provider': 'default', 'status': 'live', 'supports_tools': True, 'supports_structured_output': True}, + {'provider': 'override', 'status': 'live', 'supports_tools': False, 'supports_structured_output': False}, + ], + } + mocker.patch(f'{ns}._get_router_info', return_value=router_info) + mock_client_class = mocker.patch(f'{ns}.AsyncInferenceClient') + + provider = HuggingFaceProvider(api_key='test-api-key', provider_name='override') + + profile = provider.model_profile('unknown/model') + + assert profile is not None + assert profile.supports_tools is False + mock_client_class.assert_called_with(token='test-api-key', provider='override') + + +def 
test_model_profile_provider_name_not_found_fallback(mocker: MockerFixture): + """Test model_profile falls back to select_provider when provider_name not found.""" + _get_router_info.cache_clear() + + ns = 'pydantic_ai.providers.huggingface' + router_info = { + 'id': 'unknown/model', + 'providers': [ + {'provider': 'available', 'status': 'live', 'supports_tools': True, 'supports_structured_output': True}, + ], + } + mocker.patch(f'{ns}._get_router_info', return_value=router_info) + mock_client_class = mocker.patch(f'{ns}.AsyncInferenceClient') + + provider = HuggingFaceProvider(api_key='test-api-key', provider_name='nonexistent') + + profile = provider.model_profile('unknown/model') + + assert profile is not None + # Falls back to 'available' since 'nonexistent' not found + mock_client_class.assert_called_with(token='test-api-key', provider='available') + + +def test_model_profile_logs_warning_no_structured_output(mocker: MockerFixture, caplog: pytest.LogCaptureFixture): + """Test model_profile logs warning when provider doesn't support structured output.""" + _get_router_info.cache_clear() + + ns = 'pydantic_ai.providers.huggingface' + router_info = { + 'id': 'unknown/model', + 'providers': [ + {'provider': 'limited', 'status': 'live', 'supports_tools': True, 'supports_structured_output': False}, + ], + } + mocker.patch(f'{ns}._get_router_info', return_value=router_info) + mocker.patch(f'{ns}.AsyncInferenceClient') + + mock_client = Mock(spec=AsyncInferenceClient) + provider = HuggingFaceProvider(hf_client=mock_client, api_key='test-api-key') + + with caplog.at_level(logging.WARNING): + provider.model_profile('unknown/model') + + assert 'Provider limited does not support structured output' in caplog.text + + +def test_model_profile_logs_warning_no_tools(mocker: MockerFixture, caplog: pytest.LogCaptureFixture): + """Test model_profile logs warning when provider doesn't support tools.""" + _get_router_info.cache_clear() + + ns = 'pydantic_ai.providers.huggingface' + router_info = { + 'id': 'unknown/model', + 'providers': [ + {'provider': 'limited', 'status': 'live', 'supports_tools': False, 'supports_structured_output': True}, + ], + } + mocker.patch(f'{ns}._get_router_info', return_value=router_info) + mocker.patch(f'{ns}.AsyncInferenceClient') + + mock_client = Mock(spec=AsyncInferenceClient) + provider = HuggingFaceProvider(hf_client=mock_client, api_key='test-api-key') + + with caplog.at_level(logging.WARNING): + provider.model_profile('unknown/model') + + assert "Provider 'limited' does not support tools" in caplog.text + + +def test_huggingface_provider_api_key_from_hf_client(monkeypatch: pytest.MonkeyPatch): + """Test api_key is extracted from hf_client.token when not provided.""" + monkeypatch.delenv('HF_TOKEN', raising=False) + + mock_client = Mock(spec=AsyncInferenceClient) + mock_client.token = 'client-token' + + provider = HuggingFaceProvider(hf_client=mock_client) + assert provider.api_key == 'client-token' From 3c8f0ffa894fbc894074fd746cbadd79f0af8c2d Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Thu, 27 Nov 2025 19:41:19 +0100 Subject: [PATCH 5/5] one additional test --- tests/providers/test_huggingface.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py index a58c92cd11..88096a4800 100644 --- a/tests/providers/test_huggingface.py +++ b/tests/providers/test_huggingface.py @@ -425,3 +425,32 @@ def test_huggingface_provider_api_key_from_hf_client(monkeypatch: 
pytest.MonkeyP provider = HuggingFaceProvider(hf_client=mock_client) assert provider.api_key == 'client-token' + + +def test_model_profile_with_router_info_and_known_prefix(mocker: MockerFixture): + """Test model_profile when base_profile exists (known prefix) AND router info is found. + + This covers the branch where base_profile is NOT None but we still update capabilities from router. + """ + _get_router_info.cache_clear() + + ns = 'pydantic_ai.providers.huggingface' + router_info = { + 'id': 'qwen/qwen-model', + 'providers': [ + {'provider': 'test-provider', 'status': 'live', 'supports_tools': True, 'supports_structured_output': True}, + ], + } + mocker.patch(f'{ns}._get_router_info', return_value=router_info) + mock_client_class = mocker.patch(f'{ns}.AsyncInferenceClient') + + mock_client = Mock(spec=AsyncInferenceClient) + provider = HuggingFaceProvider(hf_client=mock_client, api_key='test-api-key') + + profile = provider.model_profile('Qwen/qwen-model') + + assert profile is not None + assert profile.supports_tools is True + assert profile.supports_json_schema_output is True + assert profile.json_schema_transformer == InlineDefsJsonSchemaTransformer + mock_client_class.assert_called_with(token='test-api-key', provider='test-provider')
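
Usage notes (illustrative; not part of the patches above):

A minimal end-to-end sketch of the structured output path this series adds. The model
id, the token, and the provider availability are assumptions; any model whose selected
inference provider reports supports_structured_output on the HF router should behave
the same way:

    import asyncio

    from pydantic import BaseModel

    from pydantic_ai import Agent
    from pydantic_ai.models.huggingface import HuggingFaceModel
    from pydantic_ai.output import NativeOutput
    from pydantic_ai.providers.huggingface import HuggingFaceProvider


    class CityInfo(BaseModel):
        city: str
        country: str


    async def main() -> None:
        # HuggingFaceProvider.model_profile() queries the HF router for the model's
        # providers and picks one that supports tools / structured output;
        # HuggingFaceModel then sends the output schema via the `response_format`
        # parameter introduced in patch 1.
        model = HuggingFaceModel(
            'Qwen/Qwen3-235B-A22B',  # assumed to be served by a structured-output-capable provider
            provider=HuggingFaceProvider(api_key='hf_...'),  # placeholder token
        )
        agent = Agent(model, output_type=NativeOutput(CityInfo))
        result = await agent.run('What is the capital of France?')
        print(result.output)  # e.g. CityInfo(city='Paris', country='France')


    asyncio.run(main())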
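The provider-selection preference order (live status first, then tools plus structured
output, then either capability, then any provider) can also be exercised directly
against select_provider; the candidate entries below are made up for illustration:

    from pydantic_ai.providers.huggingface import HfRouterProvider, select_provider

    candidates: list[HfRouterProvider] = [
        {'provider': 'a', 'status': 'live', 'supports_tools': False, 'supports_structured_output': True},
        {'provider': 'b', 'status': 'live', 'supports_tools': True, 'supports_structured_output': True},
        {'provider': 'c', 'status': 'error', 'supports_tools': True, 'supports_structured_output': True},
    ]
    # 'c' is skipped (not live); 'b' wins because it supports both capabilities.
    selected = select_provider(candidates)
    assert selected is not None and selected['provider'] == 'b'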