2 changes: 1 addition & 1 deletion docs/install.md
@@ -50,7 +50,7 @@ pip/uv-add "pydantic-ai-slim[openai]"
* `mistral` — installs [Mistral Model](models/mistral.md) dependency `mistralai` [PyPI ↗](https://pypi.org/project/mistralai){:target="_blank"}
* `cohere` - installs [Cohere Model](models/cohere.md) dependency `cohere` [PyPI ↗](https://pypi.org/project/cohere){:target="_blank"}
* `bedrock` - installs [Bedrock Model](models/bedrock.md) dependency `boto3` [PyPI ↗](https://pypi.org/project/boto3){:target="_blank"}
* `huggingface` - installs [Hugging Face Model](models/huggingface.md) dependency `huggingface-hub[inference]` [PyPI ↗](https://pypi.org/project/huggingface-hub){:target="_blank"}
* `huggingface` - installs [Hugging Face Model](models/huggingface.md) dependency `huggingface-hub` [PyPI ↗](https://pypi.org/project/huggingface-hub){:target="_blank"}
* `outlines-transformers` - installs [Outlines Model](models/outlines.md) dependency `outlines[transformers]` [PyPI ↗](https://pypi.org/project/outlines){:target="_blank"}
* `outlines-llamacpp` - installs [Outlines Model](models/outlines.md) dependency `outlines[llamacpp]` [PyPI ↗](https://pypi.org/project/outlines){:target="_blank"}
* `outlines-mlxlm` - installs [Outlines Model](models/outlines.md) dependency `outlines[mlxlm]` [PyPI ↗](https://pypi.org/project/outlines){:target="_blank"}
9 changes: 5 additions & 4 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -210,14 +210,15 @@
'heroku:gpt-oss-120b',
'heroku:nova-lite',
'heroku:nova-pro',
'huggingface:Qwen/QwQ-32B',
'huggingface:MiniMaxAI/MiniMax-M2',
'huggingface:Qwen/Qwen2.5-72B-Instruct',
'huggingface:Qwen/Qwen3-235B-A22B',
'huggingface:Qwen/Qwen3-32B',
'huggingface:Qwen/Qwen3-Coder-30B-A3B-Instruct',
'huggingface:deepseek-ai/DeepSeek-R1',
'huggingface:meta-llama/Llama-3.3-70B-Instruct',
'huggingface:meta-llama/Llama-4-Maverick-17B-128E-Instruct',
'huggingface:meta-llama/Llama-4-Scout-17B-16E-Instruct',
'huggingface:meta-llama/Llama-3.1-8B-Instruct',
'huggingface:openai/gpt-oss-120b',
'huggingface:openai/gpt-oss-20b',
'mistral:codestral-latest',
'mistral:mistral-large-latest',
'mistral:mistral-moderation-latest',
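These `KnownModelName` literals are what users pass as model-name strings. A minimal usage sketch for one of the newly added entries (assuming the `huggingface` optional group is installed and `HF_TOKEN` is set in the environment):

```python
from pydantic_ai import Agent

# 'huggingface:<model id>' strings resolve to HuggingFaceModel; this needs the
# `huggingface` optional dependency group and an HF_TOKEN in the environment.
agent = Agent('huggingface:Qwen/Qwen3-Coder-30B-A3B-Instruct')

result = agent.run_sync('Say hello in one short sentence.')
print(result.output)
```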
50 changes: 37 additions & 13 deletions pydantic_ai_slim/pydantic_ai/models/huggingface.py
@@ -9,6 +9,7 @@
from typing_extensions import assert_never

from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
from .._run_context import RunContext
from .._thinking_part import split_content_into_text_and_thinking
from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc
@@ -41,19 +42,17 @@
from ..providers import Provider, infer_provider
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import (
Model,
ModelRequestParameters,
StreamedResponse,
check_allow_model_requests,
)
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests

try:
import aiohttp
from huggingface_hub import (
AsyncInferenceClient,
ChatCompletionInputGrammarType,
ChatCompletionInputMessage,
ChatCompletionInputMessageChunk,
ChatCompletionInputResponseFormatJSONObject,
ChatCompletionInputResponseFormatJSONSchema,
ChatCompletionInputTool,
ChatCompletionInputToolCall,
ChatCompletionInputURL,
@@ -80,10 +79,11 @@

LatestHuggingFaceModelNames = Literal[
'deepseek-ai/DeepSeek-R1',
'meta-llama/Llama-3.3-70B-Instruct',
'meta-llama/Llama-4-Maverick-17B-128E-Instruct',
'meta-llama/Llama-4-Scout-17B-16E-Instruct',
'Qwen/QwQ-32B',
'meta-llama/Llama-3.1-8B-Instruct',
'MiniMaxAI/MiniMax-M2',
'openai/gpt-oss-20b',
'openai/gpt-oss-120b',
'Qwen/Qwen3-Coder-30B-A3B-Instruct',
'Qwen/Qwen2.5-72B-Instruct',
'Qwen/Qwen3-235B-A22B',
'Qwen/Qwen3-32B',
@@ -142,13 +142,14 @@ def __init__(
profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
settings: Model-specific settings that will be used as defaults for this model.
"""
self._model_name = model_name
if isinstance(provider, str):
provider = infer_provider(provider)
self._provider = provider
self.client = provider.client
provider_profile = provider.model_profile(model_name)
self._model_name = model_name.rsplit(':', 1)[0]

super().__init__(settings=settings, profile=profile or provider.model_profile)
super().__init__(settings=settings, profile=profile or provider_profile)
self.client = provider.client

@property
def model_name(self) -> HuggingFaceModelName:
@@ -233,6 +234,15 @@ async def _completions_create(
raise UserError('HuggingFace does not support built-in tools')

hf_messages = await self._map_messages(messages, model_request_parameters)
response_format: ChatCompletionInputGrammarType | None = None
if model_request_parameters.output_mode == 'native':
output_object = model_request_parameters.output_object
assert output_object is not None
response_format = self._map_json_schema(output_object)
elif (
model_request_parameters.output_mode == 'prompted' and self.profile.supports_json_object_output
): # pragma: no branch
response_format = ChatCompletionInputResponseFormatJSONObject.parse_obj_as_instance({'type': 'json_object'}) # type: ignore

try:
return await self.client.chat.completions.create( # type: ignore
@@ -245,6 +255,7 @@
temperature=model_settings.get('temperature', None),
top_p=model_settings.get('top_p', None),
seed=model_settings.get('seed', None),
response_format=response_format or None,
presence_penalty=model_settings.get('presence_penalty', None),
frequency_penalty=model_settings.get('frequency_penalty', None),
logit_bias=model_settings.get('logit_bias', None), # type: ignore
@@ -377,6 +388,19 @@ def _map_tool_call(t: ToolCallPart) -> ChatCompletionInputToolCall:
}
)

def _map_json_schema(self, o: OutputObjectDefinition) -> ChatCompletionInputGrammarType:
response_format_param: ChatCompletionInputResponseFormatJSONSchema = { # type: ignore
'type': 'json_schema',
'json_schema': {
'name': o.name or DEFAULT_OUTPUT_TOOL_NAME,
'schema': o.json_schema,
'strict': o.strict,
},
}
if o.description: # pragma: no branch
response_format_param['json_schema']['description'] = o.description
return response_format_param

@staticmethod
def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
tool_param: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance( # type: ignore
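The `response_format` plumbing above is what enables native structured output for Hugging Face models. A rough usage sketch, assuming the public `NativeOutput` marker and a Pydantic model for the output type (the schema class below is illustrative, not from this PR):

```python
from pydantic import BaseModel

from pydantic_ai import Agent, NativeOutput


class CityInfo(BaseModel):  # illustrative output schema
    city: str
    country: str


# With native output, HuggingFaceModel builds a json_schema response_format
# via _map_json_schema from the OutputObjectDefinition.
agent = Agent(
    'huggingface:Qwen/Qwen3-32B',
    output_type=NativeOutput(CityInfo),
)

result = agent.run_sync('Tell me about Paris.')
print(result.output)
```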
140 changes: 133 additions & 7 deletions pydantic_ai_slim/pydantic_ai/providers/huggingface.py
@@ -1,9 +1,15 @@
from __future__ import annotations as _annotations

import logging
import os
from dataclasses import replace
from functools import lru_cache
from typing import overload

import httpx
from httpx import AsyncClient
from pydantic import TypeAdapter, ValidationError
from typing_extensions import TypedDict

from pydantic_ai import ModelProfile
from pydantic_ai.exceptions import UserError
@@ -14,6 +20,8 @@
from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
from pydantic_ai.profiles.qwen import qwen_model_profile

from . import Provider

try:
from huggingface_hub import AsyncInferenceClient
except ImportError as _import_error: # pragma: no cover
@@ -22,7 +30,75 @@
"you can use the `huggingface` optional group — `pip install 'pydantic-ai-slim[huggingface]'`"
) from _import_error

from . import Provider
_logger = logging.getLogger(__name__)

HF_ROUTER_MODELS_URL = 'https://router.huggingface.co/v1/models'


class HfRouterModel(TypedDict):
"""Hugging Face router model definition."""

id: str


class HfRouterResponse(TypedDict):
"""Hugging Face router response."""

data: list[HfRouterModel]


class HfRouterProvider(TypedDict):
"""Hugging Face router provider definition."""

provider: str
status: str
supports_tools: bool
supports_structured_output: bool


class HfRouterModelInfo(TypedDict):
"""Hugging Face router model info."""

id: str
providers: list[HfRouterProvider]


class HfRouterResponseData(TypedDict):
"""Hugging Face router response data."""

data: HfRouterModelInfo


@lru_cache(maxsize=128)
def _get_router_info(model_id: str) -> HfRouterModelInfo | None:
try:
resp = httpx.get(f'{HF_ROUTER_MODELS_URL}/{model_id}', timeout=5.0, follow_redirects=True)
if resp.status_code != 200:
return None
payload = TypeAdapter(HfRouterResponseData).validate_json(resp.content)
return payload['data']
except (httpx.HTTPError, ValidationError, Exception):
return None


def select_provider(providers: list[HfRouterProvider]) -> HfRouterProvider | None:
"""Select the best provider based on capabilities."""
live_providers = [p for p in providers if p['status'] == 'live']
if not live_providers:
live_providers = providers

# 1 - supports_tools=True AND supports_structured_output=True
both = [p for p in live_providers if p['supports_tools'] and p['supports_structured_output']]
if both:
return both[0]

# 2 - supports_tools=True OR supports_structured_output=True
either = [p for p in live_providers if p['supports_tools'] or p['supports_structured_output']]
if either:
return either[0]

# 3 - Any
return live_providers[0] if live_providers else None


class HuggingFaceProvider(Provider[AsyncInferenceClient]):
@@ -54,11 +130,55 @@ def model_profile(self, model_name: str) -> ModelProfile | None:
return None

model_name = model_name.lower()
provider, model_name = model_name.split('/', 1)
if provider in provider_to_profile:
return provider_to_profile[provider](model_name)
model_prefix, model_suffix = model_name.split('/', 1)

base_profile: ModelProfile | None = None
if model_prefix in provider_to_profile:
base_profile = provider_to_profile[model_prefix](model_suffix)

# fetch model capabilities
router_info = _get_router_info(model_name)
Review comment (collaborator):
Unfortunately I think it breaks expectations too much for the sync Provider.model_profile method to be doing an HTTP request 😬 We could make it a bit less problematic by moving the request to a thread using _utils.run_in_executor but I'm still uncomfortable with this to the point where I can't see myself merging this...

Overriding self._client below is also problematic because these providers are often reused across multiple model instances.

Is there any way we could achieve your goals by waiting to do this capabilities checking / provider selection until Model.request?
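One hypothetical shape of that suggestion, for illustration only (it assumes the `_get_router_info` helper and `HfRouterModelInfo` type defined earlier in this diff; it is not code from this PR):

```python
import asyncio

# Hypothetical sketch: keep model_profile() synchronous and cheap, and resolve the
# router capabilities from an async call site such as Model.request, running the
# blocking httpx.get in a worker thread so the event loop is never blocked.
async def _get_router_info_async(model_id: str) -> HfRouterModelInfo | None:
    # _get_router_info keeps its lru_cache, so repeated lookups stay fast.
    return await asyncio.to_thread(_get_router_info, model_id)
```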


selected_provider_info: HfRouterProvider | None = None
if router_info:
providers = router_info['providers']
if self._provider:
for p in providers:
if p['provider'] == self._provider:
selected_provider_info = p
break
if selected_provider_info is None:
selected_provider_info = select_provider(providers)
else:
# Auto select using router preference
selected_provider_info = select_provider(providers)

if selected_provider_info:
if base_profile is None:
base_profile = ModelProfile()

# Update the client to use the selected provider
self._client = AsyncInferenceClient(
token=self.api_key,
provider=selected_provider_info['provider'], # type: ignore
)

return None
provider_name = selected_provider_info['provider']
if not selected_provider_info['supports_structured_output']:
_logger.warning(
f'Provider {provider_name} does not support structured output (NativeOutput).',
)
if not selected_provider_info['supports_tools']:
_logger.warning(f"Provider '{provider_name}' does not support tools.")

return replace(
base_profile,
supports_tools=selected_provider_info['supports_tools'],
supports_json_schema_output=selected_provider_info['supports_structured_output'],
supports_json_object_output=selected_provider_info['supports_structured_output'],
)

return base_profile

@overload
def __init__(self, *, base_url: str, api_key: str | None = None) -> None: ...
@@ -96,8 +216,12 @@ def __init__(
If `base_url` is passed, then `provider_name` is not used.
"""
api_key = api_key or os.getenv('HF_TOKEN')
if api_key is None and hf_client is not None:
api_key = getattr(hf_client, 'token', None)

if api_key is None:
self.api_key = api_key

if self.api_key is None:
raise UserError(
'Set the `HF_TOKEN` environment variable or pass it via `HuggingFaceProvider(api_key=...)`'
'to use the HuggingFace provider.'
Expand All @@ -110,6 +234,8 @@ def __init__(
raise ValueError('Cannot provide both `base_url` and `provider_name`.')

if hf_client is None:
self._client = AsyncInferenceClient(api_key=api_key, provider=provider_name, base_url=base_url) # type: ignore
self._client = AsyncInferenceClient(api_key=self.api_key, provider=provider_name, base_url=base_url) # type: ignore
else:
self._client = hf_client

self._provider: str | None = provider_name
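For completeness, a rough end-to-end sketch of using the provider with `HuggingFaceModel` (constructor parameters are taken from the signatures in this diff; the API key and inference provider name are illustrative):

```python
from pydantic_ai import Agent
from pydantic_ai.models.huggingface import HuggingFaceModel
from pydantic_ai.providers.huggingface import HuggingFaceProvider

# provider_name pins a specific Hugging Face inference provider; when it is omitted,
# the router capabilities fetched by _get_router_info/select_provider choose one.
provider = HuggingFaceProvider(api_key='hf_...', provider_name='nebius')  # illustrative values
model = HuggingFaceModel('Qwen/Qwen3-235B-A22B', provider=provider)

agent = Agent(model)
print(agent.run_sync('What is the capital of France?').output)
```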