-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Add gateway/...:... to Known Model Names
#3593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
293b3f8
00d05bb
90680fc
e3b090a
4af4513
9d213ce
6e05859
dc3fd1c
0359386
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,7 +3,7 @@ | |
| from __future__ import annotations as _annotations | ||
|
|
||
| import os | ||
| from collections.abc import Awaitable, Callable | ||
| from collections.abc import Awaitable, Callable, Mapping | ||
| from typing import TYPE_CHECKING, Any, Literal, overload | ||
|
|
||
| import httpx | ||
|
|
@@ -93,13 +93,38 @@ def gateway_provider( | |
| ) -> Provider[Any]: ... | ||
|
|
||
|
|
||
| UpstreamProvider = Literal[ | ||
| ModelProviders = Literal[ | ||
| 'openai', | ||
| 'groq', | ||
| 'anthropic', | ||
| 'bedrock', | ||
| 'google-vertex', | ||
| # Those are only API formats, but we still support them for convenience. | ||
| ] | ||
|
|
||
|
|
||
| def gateway_provider_to_model_names() -> Mapping[ModelProviders, object]: | ||
|
||
| """Get the gateway model name for a given provider. | ||
|
|
||
| Args: | ||
| provider: The provider to get the model name for. | ||
| """ | ||
| from pydantic_ai.models.anthropic import AnthropicModelName | ||
| from pydantic_ai.models.bedrock import BedrockModelName | ||
| from pydantic_ai.models.google import GoogleModelName | ||
| from pydantic_ai.models.groq import GroqModelName | ||
| from pydantic_ai.models.openai import OpenAIModelName | ||
|
|
||
| return { | ||
|
||
| 'openai': OpenAIModelName, | ||
| 'groq': GroqModelName, | ||
| 'anthropic': AnthropicModelName, | ||
| 'bedrock': BedrockModelName, | ||
| 'google-vertex': GoogleModelName, | ||
| } | ||
|
|
||
|
|
||
| # These are only API formats, but we still support them for convenience. | ||
| ApiFormatProviders = Literal[ | ||
|
||
| 'openai-chat', | ||
| 'openai-responses', | ||
| 'chat', | ||
|
|
@@ -108,6 +133,8 @@ def gateway_provider( | |
| 'gemini', | ||
| ] | ||
|
|
||
| UpstreamProvider = ModelProviders | ApiFormatProviders | ||
|
|
||
|
|
||
| def gateway_provider( | ||
| upstream_provider: UpstreamProvider | str, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,7 @@ | |
| from pydantic_ai.models.huggingface import HuggingFaceModelName | ||
| from pydantic_ai.models.mistral import MistralModelName | ||
| from pydantic_ai.models.openai import OpenAIModelName | ||
| from pydantic_ai.providers.gateway import gateway_provider_to_model_names | ||
| from pydantic_ai.providers.grok import GrokModelName | ||
| from pydantic_ai.providers.moonshotai import MoonshotAIModelName | ||
|
|
||
|
|
@@ -70,6 +71,11 @@ def get_model_names(model_name_type: Any) -> Iterator[str]: | |
| openai_names = [f'openai:{n}' for n in get_model_names(OpenAIModelName)] | ||
| bedrock_names = [f'bedrock:{n}' for n in get_model_names(BedrockModelName)] | ||
| deepseek_names = ['deepseek:deepseek-chat', 'deepseek:deepseek-reasoner'] | ||
| gateway_names = [ | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I realize you couldn't see this comment in a private Slack channel, but I responded to Samuel (and he agreed):
So we should NOT hard-code this, but dynamically build this based on the known model names of the providers that are known to work with
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. this is now done |
||
| f'gateway/{provider}:{model_name}' | ||
| for provider, model_names in gateway_provider_to_model_names().items() | ||
| for model_name in get_model_names(model_names) | ||
| ] | ||
| huggingface_names = [f'huggingface:{n}' for n in get_model_names(HuggingFaceModelName)] | ||
| heroku_names = get_heroku_model_names() | ||
| cerebras_names = get_cerebras_model_names() | ||
|
|
@@ -86,6 +92,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]: | |
| + openai_names | ||
| + bedrock_names | ||
| + deepseek_names | ||
| + gateway_names | ||
| + huggingface_names | ||
| + heroku_names | ||
| + cerebras_names | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a type, so singular :)