Add configurable audio encoding for OpenAI models (Data URI support) #3596

Status: Open. Pavanmanikanta98 wants to merge 4 commits into pydantic:main from Pavanmanikanta98:fix/qwen-omni-audio-encoding.
New file (86 lines added): a `QwenProvider` for the Qwen / DashScope OpenAI-compatible API.

```python
from __future__ import annotations as _annotations

import os
from typing import overload

import httpx

from pydantic_ai import ModelProfile
from pydantic_ai.exceptions import UserError
from pydantic_ai.models import cached_async_http_client
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
from pydantic_ai.profiles.qwen import qwen_model_profile
from pydantic_ai.providers import Provider

try:
    from openai import AsyncOpenAI
except ImportError as _import_error:  # pragma: no cover
    raise ImportError(
        'Please install the `openai` package to use the Qwen provider, '
        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
    ) from _import_error


class QwenProvider(Provider[AsyncOpenAI]):
    """Provider for the Qwen / DashScope OpenAI-compatible API."""

    @property
    def name(self) -> str:
        return 'qwen'

    @property
    def base_url(self) -> str:
        # Use the international endpoint by default, as it's the standard choice for global users.
        # Users in the China region can override it by passing their own `openai_client`.
        return 'https://dashscope-intl.aliyuncs.com/compatible-mode/v1'

    @property
    def client(self) -> AsyncOpenAI:
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        base_profile = qwen_model_profile(model_name)

        # Wrap/merge into OpenAIModelProfile
        openai_profile = OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(base_profile)

        # For Qwen Omni models, force data-URI audio input encoding
        if 'omni' in model_name.lower():
            openai_profile = OpenAIModelProfile(openai_chat_audio_input_encoding='uri').update(openai_profile)

        return openai_profile

    @overload
    def __init__(self) -> None: ...

    @overload
    def __init__(self, *, api_key: str) -> None: ...

    @overload
    def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        api_key = api_key or os.getenv('QWEN_API_KEY') or os.getenv('DASHSCOPE_API_KEY')
        if not api_key and openai_client is None:
            raise UserError(
                'Set the `QWEN_API_KEY` (or `DASHSCOPE_API_KEY`) environment variable or pass it via '
                '`QwenProvider(api_key=...)` to use the Qwen provider.'
            )

        if openai_client is not None:
            self._client = openai_client
        elif http_client is not None:
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
        else:
            http_client = cached_async_http_client(provider='qwen')
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
```
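The behavior the `openai_chat_audio_input_encoding` flag controls can be illustrated with a stdlib-only sketch. The helper name below is hypothetical (not pydantic-ai API); it only shows the difference between the two wire formats:

```python
import base64


def encode_audio(data: bytes, media_type: str, encoding: str) -> str:
    """Hypothetical helper: produce the audio payload for the two encodings.

    'base64' is the raw base64 string the OpenAI Chat Completions API expects;
    'uri' wraps the same bytes in a data URI, which some OpenAI-compatible
    endpoints (e.g. Qwen Omni on DashScope) require instead.
    """
    b64 = base64.b64encode(data).decode('utf-8')
    if encoding == 'uri':
        return f'data:{media_type};base64,{b64}'
    return b64


raw = encode_audio(b'fake_audio_data', 'audio/wav', 'base64')
uri = encode_audio(b'fake_audio_data', 'audio/wav', 'uri')
# The data URI is just the base64 payload with a MIME-type prefix.
assert uri == f'data:audio/wav;base64,{raw}'
```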
New test file (167 lines added), covering both encodings for `BinaryContent` and `AudioUrl` inputs:

```python
from __future__ import annotations as _annotations

import base64
from unittest.mock import patch

import pytest

from pydantic_ai import Agent, AudioUrl, BinaryContent
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.profiles.openai import OpenAIModelProfile
from pydantic_ai.providers.openai import OpenAIProvider

from ..conftest import try_import
from .mock_openai import MockOpenAI, completion_message, get_mock_chat_completion_kwargs

with try_import() as imports_successful:
    from openai.types.chat.chat_completion_message import ChatCompletionMessage

pytestmark = [
    pytest.mark.skipif(not imports_successful(), reason='openai not installed'),
    pytest.mark.anyio,
]


def test_openai_chat_audio_default_base64(allow_model_requests: None):
    c = completion_message(ChatCompletionMessage(content='success', role='assistant'))
    mock_client = MockOpenAI.create_mock(c)
    model = OpenAIChatModel('gpt-4o-audio-preview', provider=OpenAIProvider(openai_client=mock_client))
    agent = Agent(model)

    # BinaryContent
    audio_data = b'fake_audio_data'
    binary_audio = BinaryContent(audio_data, media_type='audio/wav')

    agent.run_sync(['Process this audio', binary_audio])

    request_kwargs = get_mock_chat_completion_kwargs(mock_client)
    messages = request_kwargs[0]['messages']
    user_message = messages[0]

    # Find the input_audio part
    audio_part = next(part for part in user_message['content'] if part['type'] == 'input_audio')

    # Expect raw base64
    expected_data = base64.b64encode(audio_data).decode('utf-8')
    assert audio_part['input_audio']['data'] == expected_data
    assert audio_part['input_audio']['format'] == 'wav'


def test_openai_chat_audio_uri_encoding(allow_model_requests: None):
    c = completion_message(ChatCompletionMessage(content='success', role='assistant'))
    mock_client = MockOpenAI.create_mock(c)

    # Set profile to use URI encoding
    profile = OpenAIModelProfile(openai_chat_audio_input_encoding='uri')
    model = OpenAIChatModel('gpt-4o-audio-preview', provider=OpenAIProvider(openai_client=mock_client), profile=profile)
    agent = Agent(model)

    # BinaryContent
    audio_data = b'fake_audio_data'
    binary_audio = BinaryContent(audio_data, media_type='audio/wav')

    agent.run_sync(['Process this audio', binary_audio])

    request_kwargs = get_mock_chat_completion_kwargs(mock_client)
    messages = request_kwargs[0]['messages']
    user_message = messages[0]

    # Find the input_audio part
    audio_part = next(part for part in user_message['content'] if part['type'] == 'input_audio')

    # Expect Data URI
    expected_data = f'data:audio/wav;base64,{base64.b64encode(audio_data).decode("utf-8")}'
    assert audio_part['input_audio']['data'] == expected_data
    assert audio_part['input_audio']['format'] == 'wav'


async def test_openai_chat_audio_url_default_base64(allow_model_requests: None):
    c = completion_message(ChatCompletionMessage(content='success', role='assistant'))
    mock_client = MockOpenAI.create_mock(c)
    model = OpenAIChatModel('gpt-4o-audio-preview', provider=OpenAIProvider(openai_client=mock_client))
    agent = Agent(model)

    audio_url = AudioUrl('https://example.com/audio.mp3')

    # Mock download_item to return base64 data
    fake_base64_data = base64.b64encode(b'fake_downloaded_audio').decode('utf-8')

    with patch('pydantic_ai.models.openai.download_item') as mock_download:
        mock_download.return_value = {'data': fake_base64_data, 'data_type': 'mp3'}

        await agent.run(['Process this audio url', audio_url])

    request_kwargs = get_mock_chat_completion_kwargs(mock_client)
    messages = request_kwargs[0]['messages']
    user_message = messages[0]

    # Find the input_audio part
    audio_part = next(part for part in user_message['content'] if part['type'] == 'input_audio')

    # Expect raw base64 (which is what download_item returns in this mock)
    assert audio_part['input_audio']['data'] == fake_base64_data
    assert audio_part['input_audio']['format'] == 'mp3'


async def test_openai_chat_audio_url_uri_encoding(allow_model_requests: None):
    c = completion_message(ChatCompletionMessage(content='success', role='assistant'))
    mock_client = MockOpenAI.create_mock(c)

    # Set profile to use URI encoding
    profile = OpenAIModelProfile(openai_chat_audio_input_encoding='uri')
    model = OpenAIChatModel('gpt-4o-audio-preview', provider=OpenAIProvider(openai_client=mock_client), profile=profile)
    agent = Agent(model)

    audio_url = AudioUrl('https://example.com/audio.mp3')

    # Mock download_item to return base64 data
    fake_base64_data = base64.b64encode(b'fake_downloaded_audio').decode('utf-8')

    with patch('pydantic_ai.models.openai.download_item') as mock_download:
        mock_download.return_value = {'data': fake_base64_data, 'data_type': 'mp3'}

        await agent.run(['Process this audio url', audio_url])

    request_kwargs = get_mock_chat_completion_kwargs(mock_client)
    messages = request_kwargs[0]['messages']
    user_message = messages[0]

    # Find the input_audio part
    audio_part = next(part for part in user_message['content'] if part['type'] == 'input_audio')

    # Expect Data URI with correct MIME type for mp3
    expected_data = f'data:audio/mpeg;base64,{fake_base64_data}'
    assert audio_part['input_audio']['data'] == expected_data
    assert audio_part['input_audio']['format'] == 'mp3'


async def test_openai_chat_audio_url_custom_media_type(allow_model_requests: None):
    c = completion_message(ChatCompletionMessage(content='success', role='assistant'))
    mock_client = MockOpenAI.create_mock(c)

    # Set profile to use URI encoding
    profile = OpenAIModelProfile(openai_chat_audio_input_encoding='uri')
    model = OpenAIChatModel('gpt-4o-audio-preview', provider=OpenAIProvider(openai_client=mock_client), profile=profile)
    agent = Agent(model)

    # AudioUrl with an explicit media_type that differs from the default
    # extension mapping: .mp3 extension, but force an unusual MIME type
    audio_url = AudioUrl('https://example.com/audio.mp3', media_type='audio/custom-weird-format')

    fake_base64_data = base64.b64encode(b'fake_downloaded_audio').decode('utf-8')

    with patch('pydantic_ai.models.openai.download_item') as mock_download:
        mock_download.return_value = {'data': fake_base64_data, 'data_type': 'mp3'}

        await agent.run(['Process this audio url', audio_url])

    request_kwargs = get_mock_chat_completion_kwargs(mock_client)
    messages = request_kwargs[0]['messages']
    user_message = messages[0]

    audio_part = next(part for part in user_message['content'] if part['type'] == 'input_audio')

    # Expect Data URI with the CUSTOM MIME type
    expected_data = f'data:audio/custom-weird-format;base64,{fake_base64_data}'
    assert audio_part['input_audio']['data'] == expected_data
    assert audio_part['input_audio']['format'] == 'mp3'
```
Review comment: We should still make it so that this is used automatically for Qwen Omni. If that's only a requirement of Qwen's own ChatCompletions-compatible API, we may want a new provider class that can define its own `model_profile` method and be used with `OpenAIChatModel`. We shouldn't set this in the existing `qwen_model_profile` method, as Qwen can also be used with providers that probably don't have this quirk.
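The decision the comment proposes to keep at the provider level can be sketched as a pure function of the model name (an illustration only; the function name is hypothetical and nothing here is pydantic-ai API):

```python
def audio_input_encoding(model_name: str) -> str:
    """Hypothetical: pick the chat audio input encoding per model.

    Qwen Omni models on DashScope's ChatCompletions-compatible endpoint
    need data-URI audio; other models keep the default raw base64. A
    provider that knows about this quirk applies it in its own
    model_profile hook, rather than in the shared qwen_model_profile.
    """
    return 'uri' if 'omni' in model_name.lower() else 'base64'


assert audio_input_encoding('qwen2.5-omni-7b') == 'uri'
assert audio_input_encoding('qwen-max') == 'base64'
```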