2 changes: 1 addition & 1 deletion docs/models/overview.md
@@ -229,7 +229,7 @@ contains all the exceptions encountered during the `run` execution.

By default, the `FallbackModel` only moves on to the next model if the current model raises a
[`ModelAPIError`][pydantic_ai.exceptions.ModelAPIError], which includes
-[`ModelHTTPError`][pydantic_ai.exceptions.ModelHTTPError]. You can customize this behavior by
+[`ModelHTTPError`][pydantic_ai.exceptions.ModelHTTPError] and [`ContentFilterError`][pydantic_ai.exceptions.ContentFilterError]. You can customize this behavior by
passing a custom `fallback_on` argument to the `FallbackModel` constructor.

!!! note
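A usage sketch of the `fallback_on` customization described in the docs above (not part of this diff). It assumes `fallback_on` accepts either a tuple of exception types or a predicate, as the surrounding documentation states, and the model names are illustrative:

from pydantic_ai import Agent
from pydantic_ai.exceptions import ContentFilterError, ModelAPIError
from pydantic_ai.models.fallback import FallbackModel


def should_fall_back(exc: Exception) -> bool:
    # Fall back on transient API failures, but let content-filter hits
    # propagate to the caller instead of silently retrying another model.
    return isinstance(exc, ModelAPIError) and not isinstance(exc, ContentFilterError)


model = FallbackModel('openai:gpt-4o', 'anthropic:claude-3-5-sonnet-latest', fallback_on=should_fall_back)
agent = Agent(model)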
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/_utils.py
@@ -215,7 +215,7 @@ async def async_iter_groups() -> AsyncIterator[list[T]]:

    try:
        yield async_iter_groups()
-    finally:  # pragma: no cover
+    finally:
        # after iteration, if a task still exists, cancel it; this will only happen if an error occurred
        if task:
            task.cancel('Cancelling due to error in iterator')
27 changes: 27 additions & 0 deletions pydantic_ai_slim/pydantic_ai/exceptions.py
@@ -24,6 +24,9 @@
    'UsageLimitExceeded',
    'ModelAPIError',
    'ModelHTTPError',
+    'ContentFilterError',
+    'PromptContentFilterError',
+    'ResponseContentFilterError',
    'IncompleteToolCall',
    'FallbackExceptionGroup',
)
@@ -179,6 +182,30 @@ def __init__(self, status_code: int, model_name: str, body: object | None = None):
        super().__init__(model_name=model_name, message=message)


+class ContentFilterError(ModelHTTPError):
+    """Raised when content filtering is triggered by the model provider."""
+
+    def __init__(self, message: str, status_code: int, model_name: str, body: object | None = None):
+        super().__init__(status_code, model_name, body)
+        self.message = message
+
+
+class PromptContentFilterError(ContentFilterError):
+    """Raised when the prompt triggers a content filter."""
+
+    def __init__(self, status_code: int, model_name: str, body: object | None = None):
+        message = f"Model '{model_name}' content filter was triggered by the user's prompt"
+        super().__init__(message, status_code, model_name, body)
+
+
+class ResponseContentFilterError(ContentFilterError):
+    """Raised when the generated response triggers a content filter."""
+
+    def __init__(self, model_name: str, body: object | None = None, status_code: int = 200):
+        message = f"Model '{model_name}' triggered its content filter while generating a response"
+        super().__init__(message, status_code, model_name, body)
+
+
class FallbackExceptionGroup(ExceptionGroup[Any]):
    """A group of exceptions that can be raised when all fallback models fail."""
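A minimal sketch (not part of this diff) of how a caller might distinguish the two new subclasses; the model string and the `None` fallback are illustrative:

from pydantic_ai import Agent
from pydantic_ai.exceptions import PromptContentFilterError, ResponseContentFilterError

agent = Agent('openai:gpt-4o')  # illustrative model choice


async def run_filtered(prompt: str) -> str | None:
    try:
        result = await agent.run(prompt)
    except PromptContentFilterError:
        # The provider rejected the input before generating anything.
        return None
    except ResponseContentFilterError:
        # The provider cut off its own output; note the default status_code=200,
        # since the HTTP request itself succeeded.
        return None
    return result.output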
11 changes: 10 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -14,7 +14,7 @@
from .._run_context import RunContext
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..builtin_tools import CodeExecutionTool, MCPServerTool, MemoryTool, WebFetchTool, WebSearchTool
-from ..exceptions import ModelAPIError, UserError
+from ..exceptions import ModelAPIError, ResponseContentFilterError, UserError
from ..messages import (
    BinaryContent,
    BuiltinToolCallPart,
@@ -526,6 +526,11 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
        if raw_finish_reason := response.stop_reason:  # pragma: no branch
            provider_details = {'finish_reason': raw_finish_reason}
            finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+            if finish_reason == 'content_filter':
+                raise ResponseContentFilterError(
+                    model_name=response.model,
+                    body=response.model_dump(),
+                )

        return ModelResponse(
            parts=items,
@@ -1243,6 +1248,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                if raw_finish_reason := event.delta.stop_reason:  # pragma: no branch
                    self.provider_details = {'finish_reason': raw_finish_reason}
                    self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+                    if self.finish_reason == 'content_filter':
+                        raise ResponseContentFilterError(
+                            model_name=self.model_name,
+                        )

            elif isinstance(event, BetaRawContentBlockStopEvent):  # pragma: no branch
                if isinstance(current_block, BetaMCPToolUseBlock):
12 changes: 5 additions & 7 deletions pydantic_ai_slim/pydantic_ai/models/google.py
@@ -14,7 +14,7 @@
from .._output import OutputObjectDefinition
from .._run_context import RunContext
from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, WebFetchTool, WebSearchTool
-from ..exceptions import ModelAPIError, ModelHTTPError, UserError
+from ..exceptions import ModelAPIError, ModelHTTPError, ResponseContentFilterError, UserError
from ..messages import (
    BinaryContent,
    BuiltinToolCallPart,
@@ -495,8 +495,8 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse:

        if candidate.content is None or candidate.content.parts is None:
            if finish_reason == 'content_filter' and raw_finish_reason:
-                raise UnexpectedModelBehavior(
-                    f'Content filter {raw_finish_reason.value!r} triggered', response.model_dump_json()
+                raise ResponseContentFilterError(
+                    model_name=response.model_version or self._model_name, body=response.model_dump_json()
                )
            parts = []  # pragma: no cover
        else:
@@ -697,10 +697,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                    yield self._parts_manager.handle_part(vendor_part_id=uuid4(), part=web_fetch_return)

            if candidate.content is None or candidate.content.parts is None:
-                if self.finish_reason == 'content_filter' and raw_finish_reason:  # pragma: no cover
-                    raise UnexpectedModelBehavior(
-                        f'Content filter {raw_finish_reason.value!r} triggered', chunk.model_dump_json()
-                    )
+                if self.finish_reason == 'content_filter' and raw_finish_reason:
+                    raise ResponseContentFilterError(model_name=self.model_name, body=chunk.model_dump_json())
            else:  # pragma: no cover
                continue

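Since this path attaches the serialized response as `body`, a handler can log the blocked payload. A minimal sketch, assuming the standard `ModelHTTPError` attributes (`model_name`, `status_code`, `body`) and an agent configured with a `GoogleModel`:

from pydantic_ai import Agent
from pydantic_ai.exceptions import ResponseContentFilterError


async def run_and_log(agent: Agent, prompt: str) -> str | None:
    try:
        result = await agent.run(prompt)
    except ResponseContentFilterError as exc:
        # On this path, body is the model_dump_json() string of the blocked
        # response; model_name falls back to the configured model name.
        print(f'Blocked by {exc.model_name}: {exc.body}')
        return None
    return result.output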
46 changes: 45 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -20,7 +20,7 @@
from .._thinking_part import split_content_into_text_and_thinking
from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, MCPServerTool, WebSearchTool
-from ..exceptions import UserError
+from ..exceptions import PromptContentFilterError, ResponseContentFilterError, UserError
from ..messages import (
    AudioUrl,
    BinaryContent,
@@ -160,6 +160,24 @@
}


+def _check_azure_content_filter(e: APIStatusError, model_name: str) -> None:
+    """Check if the error is an Azure content filter error and raise PromptContentFilterError if so."""
+    if e.status_code == 400:
+        body_any: Any = e.body
+
+        if isinstance(body_any, dict):
+            body_dict = cast(dict[str, Any], body_any)
+
+            if (error := body_dict.get('error')) and isinstance(error, dict):
+                error_dict = cast(dict[str, Any], error)
+                if error_dict.get('code') == 'content_filter':
+                    raise PromptContentFilterError(
+                        status_code=e.status_code,
+                        model_name=model_name,
+                        body=body_dict,
+                    ) from e
+
+
class OpenAIChatModelSettings(ModelSettings, total=False):
    """Settings used for an OpenAI model request."""
@@ -555,6 +573,8 @@ async def _completions_create(
            )
        except APIStatusError as e:
            if (status_code := e.status_code) >= 400:
+                _check_azure_content_filter(e, self.model_name)
+
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
            raise  # pragma: lax no cover
        except APIConnectionError as e:
@@ -601,6 +621,13 @@ def _process_response(self, response: chat.ChatCompletion | str) -> ModelResponse:
            raise UnexpectedModelBehavior(f'Invalid response from {self.system} chat completions endpoint: {e}') from e

        choice = response.choices[0]
+
+        if choice.finish_reason == 'content_filter':
+            raise ResponseContentFilterError(
+                model_name=response.model,
+                body=response.model_dump(),
+            )
+
        items: list[ModelResponsePart] = []

        if thinking_parts := self._process_thinking(choice.message):
@@ -1242,6 +1269,11 @@ def _process_response(  # noqa: C901
        finish_reason: FinishReason | None = None
        provider_details: dict[str, Any] | None = None
        raw_finish_reason = details.reason if (details := response.incomplete_details) else response.status
+        if raw_finish_reason == 'content_filter':
+            raise ResponseContentFilterError(
+                model_name=response.model,
+                body=response.model_dump(),
+            )
        if raw_finish_reason:
            provider_details = {'finish_reason': raw_finish_reason}
            finish_reason = _RESPONSES_FINISH_REASON_MAP.get(raw_finish_reason)
@@ -1398,6 +1430,8 @@ async def _responses_create(  # noqa: C901
            )
        except APIStatusError as e:
            if (status_code := e.status_code) >= 400:
+                _check_azure_content_filter(e, self.model_name)
+
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
            raise  # pragma: lax no cover
        except APIConnectionError as e:
@@ -1903,6 +1937,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                continue

            if raw_finish_reason := choice.finish_reason:
+                if raw_finish_reason == 'content_filter':
+                    raise ResponseContentFilterError(
+                        model_name=self.model_name,
+                    )
                self.finish_reason = self._map_finish_reason(raw_finish_reason)

            if provider_details := self._map_provider_details(chunk):
@@ -2047,6 +2085,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                raw_finish_reason = (
                    details.reason if (details := chunk.response.incomplete_details) else chunk.response.status
                )
+
+                if raw_finish_reason == 'content_filter':
+                    raise ResponseContentFilterError(
+                        model_name=self.model_name,
+                    )
+
                if raw_finish_reason:  # pragma: no branch
                    self.provider_details = {'finish_reason': raw_finish_reason}
                    self.finish_reason = _RESPONSES_FINISH_REASON_MAP.get(raw_finish_reason)
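Because the streaming paths above raise mid-iteration, the `except` has to sit outside the stream context manager. A minimal sketch mirroring the tests further down (the model choice is illustrative):

from pydantic_ai import Agent
from pydantic_ai.exceptions import ResponseContentFilterError

agent = Agent('openai:gpt-4o')  # illustrative model choice


async def stream_or_none(prompt: str) -> str | None:
    chunks: list[str] = []
    try:
        async with agent.run_stream(prompt) as result:
            async for text in result.stream_text(delta=True):
                chunks.append(text)
    except ResponseContentFilterError:
        # Raised as soon as a chunk reports finish_reason == 'content_filter';
        # anything streamed before that point is still in chunks.
        return None
    return ''.join(chunks)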
56 changes: 55 additions & 1 deletion tests/models/test_anthropic.py
@@ -46,7 +46,7 @@
    UserPromptPart,
)
from pydantic_ai.builtin_tools import CodeExecutionTool, MCPServerTool, MemoryTool, WebFetchTool, WebSearchTool
-from pydantic_ai.exceptions import UserError
+from pydantic_ai.exceptions import ResponseContentFilterError, UserError
from pydantic_ai.messages import (
    BuiltinToolCallEvent,  # pyright: ignore[reportDeprecated]
    BuiltinToolResultEvent,  # pyright: ignore[reportDeprecated]
@@ -7901,3 +7901,57 @@ async def test_anthropic_cache_messages_real_api(allow_model_requests: None, ant
    assert usage2.cache_read_tokens > 0
    assert usage2.cache_write_tokens > 0
    assert usage2.output_tokens > 0
+
+
+async def test_anthropic_response_filter_error_sync(allow_model_requests: None):
+    c = completion_message(
+        [BetaTextBlock(text='partial', type='text')],
+        usage=BetaUsage(input_tokens=5, output_tokens=10),
+    )
+    # 'refusal' maps to 'content_filter' in _FINISH_REASON_MAP
+    c.stop_reason = 'refusal'
+
+    mock_client = MockAnthropic.create_mock(c)
+    m = AnthropicModel('claude-3-5-haiku-123', provider=AnthropicProvider(anthropic_client=mock_client))
+    agent = Agent(m)
+
+    with pytest.raises(
+        ResponseContentFilterError,
+        match=r"Model 'claude-3-5-haiku-123' triggered its content filter while generating a response",
+    ):
+        await agent.run('hello')
+
+
+async def test_anthropic_response_filter_error_stream(allow_model_requests: None):
+    stream = [
+        BetaRawMessageStartEvent(
+            type='message_start',
+            message=BetaMessage(
+                id='msg_123',
+                model='claude-3-5-haiku-123',
+                role='assistant',
+                type='message',
+                content=[],
+                stop_reason=None,
+                usage=BetaUsage(input_tokens=20, output_tokens=0),
+            ),
+        ),
+        BetaRawMessageDeltaEvent(
+            type='message_delta',
+            delta=Delta(stop_reason='refusal'),  # maps to content_filter
+            usage=BetaMessageDeltaUsage(input_tokens=20, output_tokens=5),
+        ),
+        BetaRawMessageStopEvent(type='message_stop'),
+    ]
+
+    mock_client = MockAnthropic.create_stream_mock([stream])
+    m = AnthropicModel('claude-3-5-haiku-123', provider=AnthropicProvider(anthropic_client=mock_client))
+    agent = Agent(m)
+
+    with pytest.raises(
+        ResponseContentFilterError,
+        match=r"Model 'claude-3-5-haiku-123' triggered its content filter while generating a response",
+    ):
+        async with agent.run_stream('hello') as result:
+            async for _ in result.stream_text():
+                pass
76 changes: 74 additions & 2 deletions tests/models/test_google.py
@@ -51,7 +51,13 @@
    WebFetchTool,
    WebSearchTool,
)
-from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, ModelRetry, UnexpectedModelBehavior, UserError
+from pydantic_ai.exceptions import (
+    ModelAPIError,
+    ModelHTTPError,
+    ModelRetry,
+    ResponseContentFilterError,
+    UserError,
+)
from pydantic_ai.messages import (
    BuiltinToolCallEvent,  # pyright: ignore[reportDeprecated]
    BuiltinToolResultEvent,  # pyright: ignore[reportDeprecated]
@@ -982,7 +988,11 @@ async def test_google_model_safety_settings(allow_model_requests: None, google_p
    )
    agent = Agent(m, instructions='You hate the world!', model_settings=settings)

-    with pytest.raises(UnexpectedModelBehavior, match="Content filter 'SAFETY' triggered"):
+    # Changed expected exception from UnexpectedModelBehavior to ResponseContentFilterError
+    with pytest.raises(
+        ResponseContentFilterError,
+        match="Model 'gemini-1.5-flash' triggered its content filter while generating a response",
+    ):
        await agent.run('Tell me a joke about a Brazilians.')


@@ -4425,3 +4435,65 @@ def test_google_missing_tool_call_thought_signature():
        ],
    }
)
+
+
+async def test_google_response_filter_error_sync(
+    allow_model_requests: None, google_provider: GoogleProvider, mocker: MockerFixture
+):
+    model_name = 'gemini-2.5-flash'
+    model = GoogleModel(model_name, provider=google_provider)
+
+    # Create a Candidate mock with the specific failure condition
+    candidate_mock = mocker.Mock(
+        finish_reason=GoogleFinishReason.SAFETY, content=None, grounding_metadata=None, url_context_metadata=None
+    )
+
+    # Create the Response mock containing the candidate
+    response_mock = mocker.Mock(candidates=[candidate_mock], model_version=model_name, usage_metadata=None)
+
+    response_mock.model_dump_json.return_value = '{"mock": "json"}'
+
+    # Patch the client
+    mocker.patch.object(model.client.aio.models, 'generate_content', return_value=response_mock)
+
+    agent = Agent(model=model)
+
+    # Verify the exception is raised
+    with pytest.raises(
+        ResponseContentFilterError,
+        match="Model 'gemini-2.5-flash' triggered its content filter while generating a response",
+    ):
+        await agent.run('bad content')
+
+
+async def test_google_response_filter_error_stream(
+    allow_model_requests: None, google_provider: GoogleProvider, mocker: MockerFixture
+):
+    model_name = 'gemini-2.5-flash'
+    model = GoogleModel(model_name, provider=google_provider)
+
+    # Create Candidate mock
+    candidate_mock = mocker.Mock(
+        finish_reason=GoogleFinishReason.SAFETY, content=None, grounding_metadata=None, url_context_metadata=None
+    )
+
+    # Create Chunk mock
+    chunk_mock = mocker.Mock(
+        candidates=[candidate_mock], model_version=model_name, usage_metadata=None, create_time=datetime.datetime.now()
+    )
+    chunk_mock.model_dump_json.return_value = '{"mock": "json"}'
+
+    async def stream_iterator():
+        yield chunk_mock
+
+    mocker.patch.object(model.client.aio.models, 'generate_content_stream', return_value=stream_iterator())
+
+    agent = Agent(model=model)
+
+    with pytest.raises(
+        ResponseContentFilterError,
+        match="Model 'gemini-2.5-flash' triggered its content filter while generating a response",
+    ):
+        async with agent.run_stream('bad content') as result:
+            async for _ in result.stream_text():
+                pass