diff --git a/docs/models/openai.md b/docs/models/openai.md
index 2eae27bd03..403cdf6696 100644
--- a/docs/models/openai.md
+++ b/docs/models/openai.md
@@ -730,3 +730,65 @@ result = agent.run_sync('What is the capital of France?')
 print(result.output)
 #> The capital of France is Paris.
 ```
+
+### Qwen
+
+To use Qwen models via the OpenAI-compatible API from [Alibaba Cloud DashScope](https://www.alibabacloud.com/help/doc-detail/2712576.html), set the `QWEN_API_KEY` (or `DASHSCOPE_API_KEY`) environment variable and use [`QwenProvider`][pydantic_ai.providers.qwen.QwenProvider] by name:
+
+```python
+from pydantic_ai import Agent
+
+agent = Agent('qwen:qwen-max')
+...
+```
+
+Or initialise the model and provider directly:
+
+```python
+from pydantic_ai import Agent
+from pydantic_ai.models.openai import OpenAIChatModel
+from pydantic_ai.providers.qwen import QwenProvider
+
+model = OpenAIChatModel(
+    'qwen-max',
+    provider=QwenProvider(api_key='your-qwen-api-key'),
+)
+agent = Agent(model)
+...
+```
+
+By default, `QwenProvider` uses the international DashScope OpenAI-compatible endpoint, `https://dashscope-intl.aliyuncs.com/compatible-mode/v1`.
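+
+If your account is on a different DashScope region, you can override the endpoint by passing a pre-configured client. A minimal sketch (the mainland China URL below is an assumption; confirm the correct regional endpoint in your DashScope console):
+
+```python
+from openai import AsyncOpenAI
+
+from pydantic_ai import Agent
+from pydantic_ai.models.openai import OpenAIChatModel
+from pydantic_ai.providers.qwen import QwenProvider
+
+client = AsyncOpenAI(
+    base_url='https://dashscope.aliyuncs.com/compatible-mode/v1',  # assumed mainland endpoint
+    api_key='your-qwen-api-key',
+)
+model = OpenAIChatModel('qwen-max', provider=QwenProvider(openai_client=client))
+agent = Agent(model)
+...
+```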
+
+When using **Qwen Omni** models (e.g. `qwen-omni-turbo`), this provider automatically handles audio input using the Data URI format required by the DashScope API.
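+
+For example, to send audio to an Omni model (a minimal sketch; the accepted audio formats depend on the specific model):
+
+```python {test="skip"}
+from pathlib import Path
+
+from pydantic_ai import Agent, BinaryContent
+
+agent = Agent('qwen:qwen-omni-turbo')
+
+audio = BinaryContent(data=Path('speech.wav').read_bytes(), media_type='audio/wav')
+result = agent.run_sync(['Please transcribe this audio.', audio])
+print(result.output)
+```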
diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py
index 3bc676559a..078dd38464 100644
--- a/pydantic_ai_slim/pydantic_ai/models/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -817,6 +817,7 @@ def infer_model(  # noqa: C901
         'litellm',
         'nebius',
         'ovhcloud',
+        'qwen',
     ):
         model_kind = 'openai-chat'
     elif model_kind in ('google-gla', 'google-vertex'):
diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py
index 10af284ee8..5280e037d6 100644
--- a/pydantic_ai_slim/pydantic_ai/models/openai.py
+++ b/pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -939,7 +939,11 @@ async def _map_user_prompt(self, part: UserPromptPart) -> chat.ChatCompletionUse
                     content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
                 elif item.is_audio:
                     assert item.format in ('wav', 'mp3')
-                    audio = InputAudio(data=base64.b64encode(item.data).decode('utf-8'), format=item.format)
+                    profile = OpenAIModelProfile.from_profile(self.profile)
+                    if profile.openai_chat_audio_input_encoding == 'uri':
+                        audio = InputAudio(data=item.data_uri, format=item.format)
+                    else:
+                        audio = InputAudio(data=base64.b64encode(item.data).decode('utf-8'), format=item.format)
                     content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio'))
                 elif item.is_document:
                     content.append(
@@ -959,7 +963,13 @@ async def _map_user_prompt(self, part: UserPromptPart) -> chat.ChatCompletionUse
                     'wav',
                     'mp3',
                 ), f'Unsupported audio format: {downloaded_item["data_type"]}'
-                audio = InputAudio(data=downloaded_item['data'], format=downloaded_item['data_type'])
+                profile = OpenAIModelProfile.from_profile(self.profile)
+                if profile.openai_chat_audio_input_encoding == 'uri':
+                    mime_type = item.media_type or f'audio/{downloaded_item["data_type"]}'
+                    data_uri = f'data:{mime_type};base64,{downloaded_item["data"]}'
+                    audio = InputAudio(data=data_uri, format=downloaded_item['data_type'])
+                else:
+                    audio = InputAudio(data=downloaded_item['data'], format=downloaded_item['data_type'])
                 content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio'))
             elif isinstance(item, DocumentUrl):
                 if self._is_text_like_media_type(item.media_type):
diff --git a/pydantic_ai_slim/pydantic_ai/profiles/openai.py b/pydantic_ai_slim/pydantic_ai/profiles/openai.py
index 37c0316c34..1c1caf6f89 100644
--- a/pydantic_ai_slim/pydantic_ai/profiles/openai.py
+++ b/pydantic_ai_slim/pydantic_ai/profiles/openai.py
@@ -41,6 +41,13 @@ class OpenAIModelProfile(ModelProfile):
     openai_chat_supports_web_search: bool = False
     """Whether the model supports web search in Chat Completions API."""
 
+    openai_chat_audio_input_encoding: Literal['base64', 'uri'] = 'base64'
+    """The encoding to use for audio input in Chat Completions requests.
+
+    - `'base64'`: a raw base64-encoded string (the default, used by OpenAI).
+    - `'uri'`: a data URI (e.g. `data:audio/wav;base64,...`), used by Qwen Omni.
+    """
+
     openai_supports_encrypted_reasoning_content: bool = False
     """Whether the model supports including encrypted reasoning content in the response."""
 
diff --git a/pydantic_ai_slim/pydantic_ai/providers/__init__.py b/pydantic_ai_slim/pydantic_ai/providers/__init__.py
index 9557e8e87b..b1fed5c3a5 100644
--- a/pydantic_ai_slim/pydantic_ai/providers/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/providers/__init__.py
@@ -145,6 +145,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .ovhcloud import OVHcloudProvider
 
         return OVHcloudProvider
+    elif provider == 'qwen':
+        from .qwen import QwenProvider
+
+        return QwenProvider
     elif provider == 'outlines':
         from .outlines import OutlinesProvider
 
diff --git a/pydantic_ai_slim/pydantic_ai/providers/qwen.py b/pydantic_ai_slim/pydantic_ai/providers/qwen.py
new file mode 100644
index 0000000000..29361851b9
--- /dev/null
+++ b/pydantic_ai_slim/pydantic_ai/providers/qwen.py
@@ -0,0 +1,85 @@
+from __future__ import annotations as _annotations
+
+import os
+from typing import overload
+
+import httpx
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
+from pydantic_ai.profiles.qwen import qwen_model_profile
+from pydantic_ai.providers import Provider
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `openai` package to use the Qwen provider, '
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+    ) from _import_error
+
+
+class QwenProvider(Provider[AsyncOpenAI]):
+    """Provider for the Qwen / DashScope OpenAI-compatible API."""
+
+    @property
+    def name(self) -> str:
+        return 'qwen'
+
+    @property
+    def base_url(self) -> str:
+        # Use the international endpoint by default, as it's the standard choice for global users.
+        # Users in the China region can override it by passing a pre-configured `openai_client`.
+        return 'https://dashscope-intl.aliyuncs.com/compatible-mode/v1'
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        return self._client
+
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        base_profile = qwen_model_profile(model_name)
+
+        # Wrap/merge the base profile into an OpenAIModelProfile.
+        openai_profile = OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(base_profile)
+
+        # For Qwen Omni models, force data URI audio input encoding.
+        if 'omni' in model_name.lower():
+            openai_profile = OpenAIModelProfile(openai_chat_audio_input_encoding='uri').update(openai_profile)
+
+        return openai_profile
+
+    @overload
+    def __init__(self) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...
+
+    @overload
+    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: httpx.AsyncClient | None = None,
+    ) -> None:
+        api_key = api_key or os.getenv('QWEN_API_KEY') or os.getenv('DASHSCOPE_API_KEY')
+        if not api_key and openai_client is None:
+            raise UserError(
+                'Set the `QWEN_API_KEY` (or `DASHSCOPE_API_KEY`) environment variable or pass it via '
+                '`QwenProvider(api_key=...)` to use the Qwen provider.'
+            )
+
+        if openai_client is not None:
+            self._client = openai_client
+        elif http_client is not None:
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
+        else:
+            http_client = cached_async_http_client(provider='qwen')
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
diff --git a/tests/models/test_openai_audio.py b/tests/models/test_openai_audio.py
new file mode 100644
index 0000000000..4dd310af50
--- /dev/null
+++ b/tests/models/test_openai_audio.py
@@ -0,0 +1,167 @@
+from __future__ import annotations as _annotations
+
+import base64
+from unittest.mock import patch
+
+import pytest
+
+from pydantic_ai import Agent, AudioUrl, BinaryContent
+from pydantic_ai.models.openai import OpenAIChatModel
+from pydantic_ai.profiles.openai import OpenAIModelProfile
+from pydantic_ai.providers.openai import OpenAIProvider
+
+from ..conftest import try_import
+from .mock_openai import MockOpenAI, completion_message, get_mock_chat_completion_kwargs
+
+with try_import() as imports_successful:
+    from openai.types.chat.chat_completion_message import ChatCompletionMessage
+
+pytestmark = [
+    pytest.mark.skipif(not imports_successful(), reason='openai not installed'),
+    pytest.mark.anyio,
+]
+
+
+def test_openai_chat_audio_default_base64(allow_model_requests: None):
+    c = completion_message(ChatCompletionMessage(content='success', role='assistant'))
+    mock_client = MockOpenAI.create_mock(c)
+    model = OpenAIChatModel('gpt-4o-audio-preview', provider=OpenAIProvider(openai_client=mock_client))
+    agent = Agent(model)
+
+    # BinaryContent
+    audio_data = b'fake_audio_data'
+    binary_audio = BinaryContent(audio_data, media_type='audio/wav')
+
+    agent.run_sync(['Process this audio', binary_audio])
+
+    request_kwargs = get_mock_chat_completion_kwargs(mock_client)
+    messages = request_kwargs[0]['messages']
+    user_message = messages[0]
+
+    # Find the input_audio part
+    audio_part = next(part for part in user_message['content'] if part['type'] == 'input_audio')
+
+    # Expect raw base64
+    expected_data = base64.b64encode(audio_data).decode('utf-8')
+    assert audio_part['input_audio']['data'] == expected_data
+    assert audio_part['input_audio']['format'] == 'wav'
+
+
+def test_openai_chat_audio_uri_encoding(allow_model_requests: None):
+    c = completion_message(ChatCompletionMessage(content='success', role='assistant'))
+    mock_client = MockOpenAI.create_mock(c)
+
+    # Set profile to use URI encoding
+    profile = OpenAIModelProfile(openai_chat_audio_input_encoding='uri')
+    model = OpenAIChatModel('gpt-4o-audio-preview', provider=OpenAIProvider(openai_client=mock_client), profile=profile)
+    agent = Agent(model)
+
+    # BinaryContent
+    audio_data = b'fake_audio_data'
+    binary_audio = BinaryContent(audio_data, media_type='audio/wav')
+
+    agent.run_sync(['Process this audio', binary_audio])
+
+    request_kwargs = get_mock_chat_completion_kwargs(mock_client)
+    messages = request_kwargs[0]['messages']
+    user_message = messages[0]
+
+    # Find the input_audio part
+    audio_part = next(part for part in user_message['content'] if part['type'] == 'input_audio')
+
+    # Expect Data URI
+    expected_data = f'data:audio/wav;base64,{base64.b64encode(audio_data).decode("utf-8")}'
+    assert audio_part['input_audio']['data'] == expected_data
+    assert audio_part['input_audio']['format'] == 'wav'
+
+
+async def test_openai_chat_audio_url_default_base64(allow_model_requests: None):
+    c = completion_message(ChatCompletionMessage(content='success', role='assistant'))
+    mock_client = MockOpenAI.create_mock(c)
+    model = OpenAIChatModel('gpt-4o-audio-preview', provider=OpenAIProvider(openai_client=mock_client))
+    agent = Agent(model)
+
+    audio_url = AudioUrl('https://example.com/audio.mp3')
+
+    # Mock download_item to return base64 data
+    fake_base64_data = base64.b64encode(b'fake_downloaded_audio').decode('utf-8')
+
+    with patch('pydantic_ai.models.openai.download_item') as mock_download:
+        mock_download.return_value = {'data': fake_base64_data, 'data_type': 'mp3'}
+
+        await agent.run(['Process this audio url', audio_url])
+
+    request_kwargs = get_mock_chat_completion_kwargs(mock_client)
+    messages = request_kwargs[0]['messages']
+    user_message = messages[0]
+
+    # Find the input_audio part
+    audio_part = next(part for part in user_message['content'] if part['type'] == 'input_audio')
+
+    # Expect raw base64 (which is what download_item returns in this mock)
+    assert audio_part['input_audio']['data'] == fake_base64_data
+    assert audio_part['input_audio']['format'] == 'mp3'
+
+
+async def test_openai_chat_audio_url_uri_encoding(allow_model_requests: None):
+    c = completion_message(ChatCompletionMessage(content='success', role='assistant'))
+    mock_client = MockOpenAI.create_mock(c)
+
+    # Set profile to use URI encoding
+    profile = OpenAIModelProfile(openai_chat_audio_input_encoding='uri')
+    model = OpenAIChatModel('gpt-4o-audio-preview', provider=OpenAIProvider(openai_client=mock_client), profile=profile)
+    agent = Agent(model)
+
+    audio_url = AudioUrl('https://example.com/audio.mp3')
+
+    # Mock download_item to return base64 data
+    fake_base64_data = base64.b64encode(b'fake_downloaded_audio').decode('utf-8')
+
+    with patch('pydantic_ai.models.openai.download_item') as mock_download:
+        mock_download.return_value = {'data': fake_base64_data, 'data_type': 'mp3'}
+
+        await agent.run(['Process this audio url', audio_url])
+
+    request_kwargs = get_mock_chat_completion_kwargs(mock_client)
+    messages = request_kwargs[0]['messages']
+    user_message = messages[0]
+
+    # Find the input_audio part
+    audio_part = next(part for part in user_message['content'] if part['type'] == 'input_audio')
+
+    # Expect Data URI with correct MIME type for mp3
+    expected_data = f'data:audio/mpeg;base64,{fake_base64_data}'
+    assert audio_part['input_audio']['data'] == expected_data
+    assert audio_part['input_audio']['format'] == 'mp3'
+
+
+async def test_openai_chat_audio_url_custom_media_type(allow_model_requests: None):
+    c = completion_message(ChatCompletionMessage(content='success', role='assistant'))
+    mock_client = MockOpenAI.create_mock(c)
+
+    # Set profile to use URI encoding
+    profile = OpenAIModelProfile(openai_chat_audio_input_encoding='uri')
+    model = OpenAIChatModel('gpt-4o-audio-preview', provider=OpenAIProvider(openai_client=mock_client), profile=profile)
+    agent = Agent(model)
+
+    # AudioUrl with explicit media_type that differs from default extension mapping
+    # e.g., .mp3 extension but we want to force a specific weird mime type
+    audio_url = AudioUrl('https://example.com/audio.mp3', media_type='audio/custom-weird-format')
+
+    fake_base64_data = base64.b64encode(b'fake_downloaded_audio').decode('utf-8')
+
+    with patch('pydantic_ai.models.openai.download_item') as mock_download:
+        mock_download.return_value = {'data': fake_base64_data, 'data_type': 'mp3'}
+
+        await agent.run(['Process this audio url', audio_url])
+
+    request_kwargs = get_mock_chat_completion_kwargs(mock_client)
+    messages = request_kwargs[0]['messages']
+    user_message = messages[0]
+
+    audio_part = next(part for part in user_message['content'] if part['type'] == 'input_audio')
+
+    # Expect Data URI with the CUSTOM MIME type
+    expected_data = f'data:audio/custom-weird-format;base64,{fake_base64_data}'
+    assert audio_part['input_audio']['data'] == expected_data
+    assert audio_part['input_audio']['format'] == 'mp3'
diff --git a/tests/providers/test_qwen_provider.py b/tests/providers/test_qwen_provider.py
new file mode 100644
index 0000000000..b114a13725
--- /dev/null
+++ b/tests/providers/test_qwen_provider.py
@@ -0,0 +1,80 @@
+import httpx
+import pytest
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.profiles.openai import OpenAIModelProfile
+
+from ..conftest import TestEnv, try_import
+
+with try_import() as imports_successful:
+    import openai
+
+    from pydantic_ai.providers import infer_provider
+    from pydantic_ai.providers.qwen import QwenProvider
+
+pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed')
+
+
+def test_qwen_provider_init():
+    provider = QwenProvider(api_key='test-key')
+    assert provider.name == 'qwen'
+    assert provider.base_url == 'https://dashscope-intl.aliyuncs.com/compatible-mode/v1'
+    assert isinstance(provider.client, openai.AsyncOpenAI)
+    assert provider.client.api_key == 'test-key'
+
+
+def test_qwen_provider_env_key(env: TestEnv):
+    env.set('QWEN_API_KEY', 'env-key')
+    provider = QwenProvider()
+    assert provider.client.api_key == 'env-key'
+
+
+def test_qwen_provider_dashscope_env_key_fallback(env: TestEnv):
+    env.remove('QWEN_API_KEY')
+    env.set('DASHSCOPE_API_KEY', 'dash-key')
+    provider = QwenProvider()
+    assert provider.client.api_key == 'dash-key'
+
+
+def test_qwen_provider_missing_key(env: TestEnv):
+    env.remove('QWEN_API_KEY')
+    env.remove('DASHSCOPE_API_KEY')
+    with pytest.raises(UserError, match='Set the `QWEN_API_KEY`'):
+        QwenProvider()
+
+
+def test_infer_provider(env: TestEnv):
+    # infer_provider instantiates the class, so we need an env var or it raises UserError
+    env.set('QWEN_API_KEY', 'key')
+    provider = infer_provider('qwen')
+    assert isinstance(provider, QwenProvider)
+
+
+def test_qwen_omni_profile_audio_uri():
+    provider = QwenProvider(api_key='key')
+    # Omni model -> expect 'uri' encoding
+    profile = provider.model_profile('qwen-omni-turbo')
+    assert isinstance(profile, OpenAIModelProfile)
+    assert profile.openai_chat_audio_input_encoding == 'uri'
+
+
+def test_qwen_non_omni_profile_default():
+    provider = QwenProvider(api_key='key')
+    # Non-omni model -> expect default (base64)
+    profile = provider.model_profile('qwen-max')
+    assert isinstance(profile, OpenAIModelProfile)
+    assert profile.openai_chat_audio_input_encoding == 'base64'
+
+
+def test_qwen_provider_with_openai_client():
+    client = openai.AsyncOpenAI(api_key='foo')
+    provider = QwenProvider(openai_client=client)
+    assert provider.client is client
+
+
+def test_qwen_provider_with_http_client():
+    http_client = httpx.AsyncClient()
+    provider = QwenProvider(api_key='foo', http_client=http_client)
+    assert provider.client.api_key == 'foo'
+    # Exercises the `http_client is not None` branch of `__init__`, which builds
+    # the AsyncOpenAI client around the supplied httpx client.
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 3490f0dd3e..7ca5fee68e 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -179,6 +179,7 @@ def print(self, *args: Any, **kwargs: Any) -> None:
     env.set('MOONSHOTAI_API_KEY', 'testing')
     env.set('DEEPSEEK_API_KEY', 'testing')
     env.set('OVHCLOUD_API_KEY', 'testing')
+    env.set('QWEN_API_KEY', 'testing')
     env.set('PYDANTIC_AI_GATEWAY_API_KEY', 'testing')
 
     prefix_settings = example.prefix_settings()