import json
from collections.abc import Sequence
from dataclasses import dataclass, field
from datetime import datetime, timezone
from functools import cached_property
from typing import Any, Literal, cast
2828from .mock_async_stream import MockAsyncStream
2929
3030with try_import () as imports_successful :
31- from openai import AsyncOpenAI
31+ from openai import NOT_GIVEN , AsyncOpenAI
3232 from openai .types import chat
3333 from openai .types .chat .chat_completion import Choice
3434 from openai .types .chat .chat_completion_chunk import (
4141 from openai .types .chat .chat_completion_message_tool_call import Function
4242 from openai .types .completion_usage import CompletionUsage , PromptTokensDetails
4343
44- from pydantic_ai .models .openai import OpenAIModel
44+ from pydantic_ai .models .openai import OpenAIModel , OpenAISystemPromptRole
4545
4646pytestmark = [
4747 pytest .mark .skipif (not imports_successful (), reason = 'openai not installed' ),
5050
5151
def test_init():
    """A bare model name + api_key should configure the default OpenAI endpoint."""
    m = OpenAIModel('gpt-4o', api_key='foobar')
    # Default client points at the public OpenAI API with the given key.
    assert str(m.client.base_url) == 'https://api.openai.com/v1/'
    assert m.client.api_key == 'foobar'
    # `name()` is prefixed with the provider for disambiguation.
    assert m.name() == 'openai:gpt-4o'
5858
def test_init_with_base_url():
    """An explicit `base_url` must override the default endpoint but keep the api_key."""
    m = OpenAIModel('gpt-4o', base_url='https://example.com/v1', api_key='foobar')
    # The client normalizes the base URL with a trailing slash.
    assert str(m.client.base_url) == 'https://example.com/v1/'
    assert m.client.api_key == 'foobar'
    assert m.name() == 'openai:gpt-4o'
6666
6767@dataclass
6868class MockOpenAI :
6969 completions : chat .ChatCompletion | list [chat .ChatCompletion ] | None = None
7070 stream : list [chat .ChatCompletionChunk ] | list [list [chat .ChatCompletionChunk ]] | None = None
71- index = 0
71+ index : int = 0
72+ chat_completion_kwargs : list [dict [str , Any ]] = field (default_factory = list )
7273
7374 @cached_property
7475 def chat (self ) -> Any :
@@ -86,8 +87,10 @@ def create_mock_stream(
8687 return cast (AsyncOpenAI , cls (stream = list (stream ))) # pyright: ignore[reportArgumentType]
8788
8889 async def chat_completions_create ( # pragma: no cover
89- self , * _args : Any , stream : bool = False , ** _kwargs : Any
90+ self , * _args : Any , stream : bool = False , ** kwargs : Any
9091 ) -> chat .ChatCompletion | MockAsyncStream [chat .ChatCompletionChunk ]:
92+ self .chat_completion_kwargs .append ({k : v for k , v in kwargs .items () if v is not NOT_GIVEN })
93+
9194 if stream :
9295 assert self .stream is not None , 'you can only used `stream=True` if `stream` is provided'
9396 # noinspection PyUnresolvedReferences
@@ -106,12 +109,19 @@ async def chat_completions_create( # pragma: no cover
106109 return response
107110
108111
def get_mock_chat_completion_kwargs(async_open_ai: AsyncOpenAI) -> list[dict[str, Any]]:
    """Return the kwargs recorded for each `chat.completions.create` call on a mock client.

    Raises:
        RuntimeError: if the client is a real `AsyncOpenAI` rather than a `MockOpenAI`.
    """
    if isinstance(async_open_ai, MockOpenAI):
        return async_open_ai.chat_completion_kwargs
    else:  # pragma: no cover
        raise RuntimeError('Not a MockOpenAI instance')
118+
def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage | None = None) -> chat.ChatCompletion:
    """Wrap a single assistant message in a minimal `ChatCompletion` response fixture."""
    return chat.ChatCompletion(
        id='123',
        choices=[Choice(finish_reason='stop', index=0, message=message)],
        created=1704067200,  # 2024-01-01
        model='gpt-4o',
        object='chat.completion',
        usage=usage,
    )
@@ -120,7 +130,7 @@ def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage
120130async def test_request_simple_success (allow_model_requests : None ):
121131 c = completion_message (ChatCompletionMessage (content = 'world' , role = 'assistant' ))
122132 mock_client = MockOpenAI .create_mock (c )
123- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
133+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
124134 agent = Agent (m )
125135
126136 result = await agent .run ('hello' )
@@ -138,17 +148,29 @@ async def test_request_simple_success(allow_model_requests: None):
138148 ModelRequest (parts = [UserPromptPart (content = 'hello' , timestamp = IsNow (tz = timezone .utc ))]),
139149 ModelResponse (
140150 parts = [TextPart (content = 'world' )],
141- model_name = 'gpt-4 ' ,
151+ model_name = 'gpt-4o ' ,
142152 timestamp = datetime (2024 , 1 , 1 , 0 , 0 , tzinfo = timezone .utc ),
143153 ),
144154 ModelRequest (parts = [UserPromptPart (content = 'hello' , timestamp = IsNow (tz = timezone .utc ))]),
145155 ModelResponse (
146156 parts = [TextPart (content = 'world' )],
147- model_name = 'gpt-4 ' ,
157+ model_name = 'gpt-4o ' ,
148158 timestamp = datetime (2024 , 1 , 1 , 0 , 0 , tzinfo = timezone .utc ),
149159 ),
150160 ]
151161 )
162+ assert get_mock_chat_completion_kwargs (mock_client ) == [
163+ {'messages' : [{'content' : 'hello' , 'role' : 'user' }], 'model' : 'gpt-4o' , 'n' : 1 },
164+ {
165+ 'messages' : [
166+ {'content' : 'hello' , 'role' : 'user' },
167+ {'content' : 'world' , 'role' : 'assistant' },
168+ {'content' : 'hello' , 'role' : 'user' },
169+ ],
170+ 'model' : 'gpt-4o' ,
171+ 'n' : 1 ,
172+ },
173+ ]
152174
153175
154176async def test_request_simple_usage (allow_model_requests : None ):
@@ -157,7 +179,7 @@ async def test_request_simple_usage(allow_model_requests: None):
157179 usage = CompletionUsage (completion_tokens = 1 , prompt_tokens = 2 , total_tokens = 3 ),
158180 )
159181 mock_client = MockOpenAI .create_mock (c )
160- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
182+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
161183 agent = Agent (m )
162184
163185 result = await agent .run ('Hello' )
@@ -180,7 +202,7 @@ async def test_request_structured_response(allow_model_requests: None):
180202 )
181203 )
182204 mock_client = MockOpenAI .create_mock (c )
183- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
205+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
184206 agent = Agent (m , result_type = list [int ])
185207
186208 result = await agent .run ('Hello' )
@@ -196,7 +218,7 @@ async def test_request_structured_response(allow_model_requests: None):
196218 tool_call_id = '123' ,
197219 )
198220 ],
199- model_name = 'gpt-4 ' ,
221+ model_name = 'gpt-4o ' ,
200222 timestamp = datetime (2024 , 1 , 1 , tzinfo = timezone .utc ),
201223 ),
202224 ModelRequest (
@@ -256,7 +278,7 @@ async def test_request_tool_call(allow_model_requests: None):
256278 completion_message (ChatCompletionMessage (content = 'final response' , role = 'assistant' )),
257279 ]
258280 mock_client = MockOpenAI .create_mock (responses )
259- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
281+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
260282 agent = Agent (m , system_prompt = 'this is the system prompt' )
261283
262284 @agent .tool_plain
@@ -284,7 +306,7 @@ async def get_location(loc_name: str) -> str:
284306 tool_call_id = '1' ,
285307 )
286308 ],
287- model_name = 'gpt-4 ' ,
309+ model_name = 'gpt-4o ' ,
288310 timestamp = datetime (2024 , 1 , 1 , tzinfo = timezone .utc ),
289311 ),
290312 ModelRequest (
@@ -305,7 +327,7 @@ async def get_location(loc_name: str) -> str:
305327 tool_call_id = '2' ,
306328 )
307329 ],
308- model_name = 'gpt-4 ' ,
330+ model_name = 'gpt-4o ' ,
309331 timestamp = datetime (2024 , 1 , 1 , tzinfo = timezone .utc ),
310332 ),
311333 ModelRequest (
@@ -320,7 +342,7 @@ async def get_location(loc_name: str) -> str:
320342 ),
321343 ModelResponse (
322344 parts = [TextPart (content = 'final response' )],
323- model_name = 'gpt-4 ' ,
345+ model_name = 'gpt-4o ' ,
324346 timestamp = datetime (2024 , 1 , 1 , tzinfo = timezone .utc ),
325347 ),
326348 ]
@@ -346,7 +368,7 @@ def chunk(delta: list[ChoiceDelta], finish_reason: FinishReason | None = None) -
346368 ChunkChoice (index = index , delta = delta , finish_reason = finish_reason ) for index , delta in enumerate (delta )
347369 ],
348370 created = 1704067200 , # 2024-01-01
349- model = 'gpt-4 ' ,
371+ model = 'gpt-4o ' ,
350372 object = 'chat.completion.chunk' ,
351373 usage = CompletionUsage (completion_tokens = 1 , prompt_tokens = 2 , total_tokens = 3 ),
352374 )
@@ -359,7 +381,7 @@ def text_chunk(text: str, finish_reason: FinishReason | None = None) -> chat.Cha
359381async def test_stream_text (allow_model_requests : None ):
360382 stream = text_chunk ('hello ' ), text_chunk ('world' ), chunk ([])
361383 mock_client = MockOpenAI .create_mock_stream (stream )
362- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
384+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
363385 agent = Agent (m )
364386
365387 async with agent .run_stream ('' ) as result :
@@ -372,7 +394,7 @@ async def test_stream_text(allow_model_requests: None):
372394async def test_stream_text_finish_reason (allow_model_requests : None ):
373395 stream = text_chunk ('hello ' ), text_chunk ('world' ), text_chunk ('.' , finish_reason = 'stop' )
374396 mock_client = MockOpenAI .create_mock_stream (stream )
375- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
397+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
376398 agent = Agent (m )
377399
378400 async with agent .run_stream ('' ) as result :
@@ -419,7 +441,7 @@ async def test_stream_structured(allow_model_requests: None):
419441 chunk ([]),
420442 )
421443 mock_client = MockOpenAI .create_mock_stream (stream )
422- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
444+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
423445 agent = Agent (m , result_type = MyTypedDict )
424446
425447 async with agent .run_stream ('' ) as result :
@@ -447,7 +469,7 @@ async def test_stream_structured_finish_reason(allow_model_requests: None):
447469 struc_chunk (None , None , finish_reason = 'stop' ),
448470 )
449471 mock_client = MockOpenAI .create_mock_stream (stream )
450- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
472+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
451473 agent = Agent (m , result_type = MyTypedDict )
452474
453475 async with agent .run_stream ('' ) as result :
@@ -467,7 +489,7 @@ async def test_stream_structured_finish_reason(allow_model_requests: None):
467489async def test_no_content (allow_model_requests : None ):
468490 stream = chunk ([ChoiceDelta ()]), chunk ([ChoiceDelta ()])
469491 mock_client = MockOpenAI .create_mock_stream (stream )
470- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
492+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
471493 agent = Agent (m , result_type = MyTypedDict )
472494
473495 with pytest .raises (UnexpectedModelBehavior , match = 'Received empty model response' ):
@@ -482,11 +504,38 @@ async def test_no_delta(allow_model_requests: None):
482504 text_chunk ('world' ),
483505 )
484506 mock_client = MockOpenAI .create_mock_stream (stream )
485- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
507+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
486508 agent = Agent (m )
487509
488510 async with agent .run_stream ('' ) as result :
489511 assert not result .is_complete
490512 assert [c async for c in result .stream_text (debounce_by = None )] == snapshot (['hello ' , 'hello world' ])
491513 assert result .is_complete
492514 assert result .usage () == snapshot (Usage (requests = 1 , request_tokens = 6 , response_tokens = 3 , total_tokens = 9 ))
515+
516+
@pytest.mark.parametrize('system_prompt_role', ['system', 'developer', None])
async def test_system_prompt_role(
    allow_model_requests: None, system_prompt_role: OpenAISystemPromptRole | None
) -> None:
    """Testing the system prompt role for OpenAI models is properly set / inferred."""

    c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
    mock_client = MockOpenAI.create_mock(c)
    m = OpenAIModel('gpt-4o', system_prompt_role=system_prompt_role, openai_client=mock_client)
    assert m.system_prompt_role == system_prompt_role

    agent = Agent(m, system_prompt='some instructions')
    result = await agent.run('hello')
    assert result.data == 'world'

    # When no role is given, the system prompt defaults to the 'system' role.
    assert get_mock_chat_completion_kwargs(mock_client) == [
        {
            'messages': [
                {'content': 'some instructions', 'role': system_prompt_role or 'system'},
                {'content': 'hello', 'role': 'user'},
            ],
            'model': 'gpt-4o',
            'n': 1,
        }
    ]
0 commit comments