Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,7 @@ async def _map_message( # noqa: C901
else:
assert_never(m)
if instructions := self._get_instructions(messages, model_request_parameters):
system_prompt_parts.insert(0, instructions)
system_prompt_parts.append(instructions)
system_prompt = '\n\n'.join(system_prompt_parts)

# Add cache_control to the last message content if anthropic_cache_messages is enabled
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ async def _map_messages( # noqa: C901
last_message = cast(dict[str, Any], current_message)

if instructions := self._get_instructions(messages, model_request_parameters):
system_prompt.insert(0, {'text': instructions})
system_prompt.append({'text': instructions})

return system_prompt, processed_messages

Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ def _map_messages(
else:
assert_never(message)
if instructions := self._get_instructions(messages, model_request_parameters):
cohere_messages.insert(0, SystemChatMessageV2(role='system', content=instructions))
system_prompt_count = sum(1 for m in cohere_messages if isinstance(m, SystemChatMessageV2))
cohere_messages.insert(system_prompt_count, SystemChatMessageV2(role='system', content=instructions))
return cohere_messages

def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolV2]:
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ async def _message_to_gemini_content(
else:
assert_never(m)
if instructions := self._get_instructions(messages, model_request_parameters):
sys_prompt_parts.insert(0, _GeminiTextPart(text=instructions))
sys_prompt_parts.append(_GeminiTextPart(text=instructions))
return sys_prompt_parts, contents

async def _map_user_prompt(self, part: UserPromptPart) -> list[_GeminiPartUnion]:
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ async def _map_messages(
contents = [{'role': 'user', 'parts': [{'text': ''}]}]

if instructions := self._get_instructions(messages, model_request_parameters):
system_parts.insert(0, {'text': instructions})
system_parts.append({'text': instructions})
system_instruction = ContentDict(role='user', parts=system_parts) if system_parts else None

return system_instruction, contents
Expand Down
5 changes: 4 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,10 @@ def _map_messages(
else:
assert_never(message)
if instructions := self._get_instructions(messages, model_request_parameters):
groq_messages.insert(0, chat.ChatCompletionSystemMessageParam(role='system', content=instructions))
system_prompt_count = sum(1 for m in groq_messages if m.get('role') == 'system')
groq_messages.insert(
system_prompt_count, chat.ChatCompletionSystemMessageParam(role='system', content=instructions)
)
return groq_messages

@staticmethod
Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,8 @@ async def _map_messages(
else:
assert_never(message)
if instructions := self._get_instructions(messages, model_request_parameters):
hf_messages.insert(0, ChatCompletionInputMessage(content=instructions, role='system')) # type: ignore
system_prompt_count = sum(1 for m in hf_messages if getattr(m, 'role', None) == 'system')
hf_messages.insert(system_prompt_count, ChatCompletionInputMessage(content=instructions, role='system')) # type: ignore
return hf_messages

@staticmethod
Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,8 @@ def _map_messages(
else:
assert_never(message)
if instructions := self._get_instructions(messages, model_request_parameters):
mistral_messages.insert(0, MistralSystemMessage(content=instructions))
system_prompt_count = sum(1 for m in mistral_messages if isinstance(m, MistralSystemMessage))
mistral_messages.insert(system_prompt_count, MistralSystemMessage(content=instructions))

# Post-process messages to insert fake assistant message after tool message if followed by user message
# to work around `Unexpected role 'user' after role 'tool'` error.
Expand Down
10 changes: 8 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,10 @@ async def _map_messages(
else:
assert_never(message)
if instructions := self._get_instructions(messages, model_request_parameters):
openai_messages.insert(0, chat.ChatCompletionSystemMessageParam(content=instructions, role='system'))
system_prompt_count = sum(1 for m in openai_messages if m.get('role') == 'system')
openai_messages.insert(
system_prompt_count, chat.ChatCompletionSystemMessageParam(content=instructions, role='system')
)
return openai_messages

@staticmethod
Expand Down Expand Up @@ -1313,7 +1316,10 @@ async def _responses_create( # noqa: C901
# > Response input messages must contain the word 'json' in some form to use 'text.format' of type 'json_object'.
# Apparently they're only checking input messages for "JSON", not instructions.
assert isinstance(instructions, str)
openai_messages.insert(0, responses.EasyInputMessageParam(role='system', content=instructions))
system_prompt_count = sum(1 for m in openai_messages if m.get('role') == 'system')
openai_messages.insert(
system_prompt_count, responses.EasyInputMessageParam(role='system', content=instructions)
)
instructions = OMIT

if verbosity := model_settings.get('openai_text_verbosity'):
Expand Down
40 changes: 40 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6313,3 +6313,43 @@ def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
]
)
assert run.all_messages_json().startswith(b'[{"parts":[{"content":"Hello",')


def test_instructions_inserted_after_system_prompt():
    """Tests that instructions are inserted after system prompts.

    Verifies that `system_prompt` parts keep their registration order in the first
    `ModelRequest` and that `instructions` remains a separate field on the request
    rather than being merged into the message parts.
    """
    # NOTE(review): this asserts the pydantic_ai message structure only; the PR's
    # actual change (where each provider places instructions relative to system
    # prompts when building the provider payload) is not exercised here — that
    # ordering should be covered by provider-specific tests against the mapped
    # provider messages.
    agent = Agent('test')

    @agent.system_prompt
    def system_prompt_1() -> str:
        return 'System prompt 1'

    @agent.system_prompt
    def system_prompt_2() -> str:
        return 'System prompt 2'

    @agent.instructions
    def instructions() -> str:
        return 'Instructions'

    result = agent.run_sync('Hello')
    assert result.all_messages()[0] == snapshot(
        ModelRequest(
            parts=[
                SystemPromptPart(
                    content='System prompt 1',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                SystemPromptPart(
                    content='System prompt 2',
                    timestamp=IsNow(tz=timezone.utc),
                ),
                UserPromptPart(
                    content='Hello',
                    timestamp=IsNow(tz=timezone.utc),
                ),
            ],
            instructions='Instructions',
            run_id=IsStr(),
        )
    )