2 changes: 1 addition & 1 deletion docs/models/overview.md
@@ -229,7 +229,7 @@ contains all the exceptions encountered during the `run` execution.

By default, the `FallbackModel` only moves on to the next model if the current model raises a
[`ModelAPIError`][pydantic_ai.exceptions.ModelAPIError], which includes
[`ModelHTTPError`][pydantic_ai.exceptions.ModelHTTPError]. You can customize this behavior by
[`ModelHTTPError`][pydantic_ai.exceptions.ModelHTTPError] and [`ContentFilterError`][pydantic_ai.exceptions.ContentFilterError]. You can customize this behavior by
passing a custom `fallback_on` argument to the `FallbackModel` constructor.
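
For example, a minimal sketch of that customization (model names are illustrative, and it assumes `fallback_on` accepts a predicate callable as well as a tuple of exception types):

```python
from pydantic_ai.exceptions import ContentFilterError, ModelHTTPError
from pydantic_ai.models.fallback import FallbackModel


def should_fall_back(exc: Exception) -> bool:
    # Fall back on ordinary HTTP failures, but let content-filter errors
    # propagate: retrying the same prompt on another model rarely helps.
    return isinstance(exc, ModelHTTPError) and not isinstance(exc, ContentFilterError)


fallback_model = FallbackModel('openai:gpt-4o', 'anthropic:claude-3-5-sonnet-latest', fallback_on=should_fall_back)
```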

!!! note
27 changes: 27 additions & 0 deletions pydantic_ai_slim/pydantic_ai/exceptions.py
@@ -24,6 +24,9 @@
'UsageLimitExceeded',
'ModelAPIError',
'ModelHTTPError',
'ContentFilterError',
'PromptContentFilterError',
'ResponseContentFilterError',
'IncompleteToolCall',
'FallbackExceptionGroup',
)
@@ -179,6 +182,30 @@ def __init__(self, status_code: int, model_name: str, body: object | None = None
super().__init__(model_name=model_name, message=message)


class ContentFilterError(ModelHTTPError):
"""Raised when content filtering is triggered by the model provider."""

def __init__(self, message: str, status_code: int, model_name: str, body: object | None = None):
super().__init__(status_code, model_name, body)
self.message = message


class PromptContentFilterError(ContentFilterError):
"""Raised when the prompt triggers a content filter."""

def __init__(self, status_code: int, model_name: str, body: object | None = None):
message = f"Prompt content filtered by model '{model_name}'"
super().__init__(message, status_code, model_name, body)


class ResponseContentFilterError(ContentFilterError):
"""Raised when the generated response triggers a content filter."""

def __init__(self, model_name: str, body: object | None = None, status_code: int = 200):
message = f"Response content filtered by model '{model_name}'"
super().__init__(message, status_code, model_name, body)


class FallbackExceptionGroup(ExceptionGroup[Any]):
"""A group of exceptions that can be raised when all fallback models fail."""

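
A sketch of how calling code might distinguish the new exceptions (the model string and prompt are illustrative; the `status_code` and `message` attributes follow from the constructors above):

```python
from pydantic_ai import Agent
from pydantic_ai.exceptions import PromptContentFilterError, ResponseContentFilterError

agent = Agent('openai:gpt-4o')

try:
    result = agent.run_sync('Summarize this document.')
except PromptContentFilterError as exc:
    # The provider rejected the input before generating anything
    # (a 400 in the Azure path below).
    print(f'Prompt blocked ({exc.status_code}): {exc.message}')
except ResponseContentFilterError as exc:
    # The provider filtered the generated output; status_code defaults to 200
    # because the HTTP request itself succeeded.
    print(f'Response filtered: {exc.message}')
```
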
11 changes: 10 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -14,7 +14,7 @@
from .._run_context import RunContext
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..builtin_tools import CodeExecutionTool, MCPServerTool, MemoryTool, WebFetchTool, WebSearchTool
from ..exceptions import ModelAPIError, UserError
from ..exceptions import ModelAPIError, ResponseContentFilterError, UserError
from ..messages import (
BinaryContent,
BuiltinToolCallPart,
@@ -526,6 +526,11 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
if raw_finish_reason := response.stop_reason: # pragma: no branch
provider_details = {'finish_reason': raw_finish_reason}
finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
if finish_reason == 'content_filter':
raise ResponseContentFilterError(
model_name=response.model,
body=response.model_dump(),
)

return ModelResponse(
parts=items,
@@ -1243,6 +1248,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
if raw_finish_reason := event.delta.stop_reason: # pragma: no branch
self.provider_details = {'finish_reason': raw_finish_reason}
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
if self.finish_reason == 'content_filter':
raise ResponseContentFilterError(
model_name=self.model_name,
)

elif isinstance(event, BetaRawContentBlockStopEvent): # pragma: no branch
if isinstance(current_block, BetaMCPToolUseBlock):
12 changes: 5 additions & 7 deletions pydantic_ai_slim/pydantic_ai/models/google.py
@@ -14,7 +14,7 @@
from .._output import OutputObjectDefinition
from .._run_context import RunContext
from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, WebFetchTool, WebSearchTool
from ..exceptions import ModelAPIError, ModelHTTPError, UserError
from ..exceptions import ModelAPIError, ModelHTTPError, ResponseContentFilterError, UserError
from ..messages import (
BinaryContent,
BuiltinToolCallPart,
@@ -495,8 +495,8 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse:

if candidate.content is None or candidate.content.parts is None:
if finish_reason == 'content_filter' and raw_finish_reason:
raise UnexpectedModelBehavior(
f'Content filter {raw_finish_reason.value!r} triggered', response.model_dump_json()
raise ResponseContentFilterError(
model_name=response.model_version or self._model_name, body=response.model_dump_json()
)
parts = [] # pragma: no cover
else:
@@ -697,10 +697,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
yield self._parts_manager.handle_part(vendor_part_id=uuid4(), part=web_fetch_return)

if candidate.content is None or candidate.content.parts is None:
if self.finish_reason == 'content_filter' and raw_finish_reason: # pragma: no cover
raise UnexpectedModelBehavior(
f'Content filter {raw_finish_reason.value!r} triggered', chunk.model_dump_json()
)
if self.finish_reason == 'content_filter' and raw_finish_reason:
raise ResponseContentFilterError(model_name=self.model_name, body=chunk.model_dump_json())
else: # pragma: no cover
continue

47 changes: 46 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -20,7 +20,7 @@
from .._thinking_part import split_content_into_text_and_thinking
from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, MCPServerTool, WebSearchTool
from ..exceptions import UserError
from ..exceptions import PromptContentFilterError, ResponseContentFilterError, UserError
from ..messages import (
AudioUrl,
BinaryContent,
@@ -160,6 +160,24 @@
}


def _check_azure_content_filter(e: APIStatusError, model_name: str) -> None:
"""Check if the error is an Azure content filter error and raise PromptContentFilterError if so."""
if e.status_code == 400:
body_any: Any = e.body

if isinstance(body_any, dict):
body_dict = cast(dict[str, Any], body_any)

if (error := body_dict.get('error')) and isinstance(error, dict):
error_dict = cast(dict[str, Any], error)
if error_dict.get('code') == 'content_filter':
raise PromptContentFilterError(
status_code=e.status_code,
model_name=model_name,
body=body_dict,
) from e
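
# For illustration only (not part of this diff): the helper above matches a 400
# body shaped roughly like the following. Fields beyond 'error' and 'code' are
# assumptions based on Azure OpenAI's documented content-filter errors and may
# vary by API version:
#
#     {
#         "error": {
#             "code": "content_filter",
#             "message": "The prompt triggered the content management policy.",
#             "param": "prompt",
#             "status": 400
#         }
#     }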


class OpenAIChatModelSettings(ModelSettings, total=False):
"""Settings used for an OpenAI model request."""

@@ -555,6 +573,8 @@ async def _completions_create(
)
except APIStatusError as e:
if (status_code := e.status_code) >= 400:
_check_azure_content_filter(e, self.model_name)

raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
raise # pragma: lax no cover
except APIConnectionError as e:
@@ -601,6 +621,13 @@ def _process_response(self, response: chat.ChatCompletion | str) -> ModelRespons
raise UnexpectedModelBehavior(f'Invalid response from {self.system} chat completions endpoint: {e}') from e

choice = response.choices[0]

if choice.finish_reason == 'content_filter':
raise ResponseContentFilterError(
model_name=response.model,
body=response.model_dump(),
)

items: list[ModelResponsePart] = []

if thinking_parts := self._process_thinking(choice.message):
@@ -1242,6 +1269,11 @@ def _process_response( # noqa: C901
finish_reason: FinishReason | None = None
provider_details: dict[str, Any] | None = None
raw_finish_reason = details.reason if (details := response.incomplete_details) else response.status
if raw_finish_reason == 'content_filter':
raise ResponseContentFilterError(
model_name=response.model,
body=response.model_dump(),
)
if raw_finish_reason:
provider_details = {'finish_reason': raw_finish_reason}
finish_reason = _RESPONSES_FINISH_REASON_MAP.get(raw_finish_reason)
@@ -1398,6 +1430,9 @@ async def _responses_create( # noqa: C901
)
except APIStatusError as e:
if (status_code := e.status_code) >= 400:
_check_azure_content_filter(e, self.model_name)

raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
raise # pragma: lax no cover
except APIConnectionError as e:
@@ -1903,6 +1938,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
continue

if raw_finish_reason := choice.finish_reason:
if raw_finish_reason == 'content_filter':
raise ResponseContentFilterError(
model_name=self.model_name,
)
self.finish_reason = self._map_finish_reason(raw_finish_reason)

if provider_details := self._map_provider_details(chunk):
@@ -2047,6 +2086,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
raw_finish_reason = (
details.reason if (details := chunk.response.incomplete_details) else chunk.response.status
)

if raw_finish_reason == 'content_filter':
raise ResponseContentFilterError(
model_name=self.model_name,
)

if raw_finish_reason: # pragma: no branch
self.provider_details = {'finish_reason': raw_finish_reason}
self.finish_reason = _RESPONSES_FINISH_REASON_MAP.get(raw_finish_reason)
50 changes: 49 additions & 1 deletion tests/models/test_anthropic.py
@@ -46,7 +46,7 @@
UserPromptPart,
)
from pydantic_ai.builtin_tools import CodeExecutionTool, MCPServerTool, MemoryTool, WebFetchTool, WebSearchTool
from pydantic_ai.exceptions import UserError
from pydantic_ai.exceptions import ResponseContentFilterError, UserError
from pydantic_ai.messages import (
BuiltinToolCallEvent, # pyright: ignore[reportDeprecated]
BuiltinToolResultEvent, # pyright: ignore[reportDeprecated]
@@ -7901,3 +7901,51 @@ async def test_anthropic_cache_messages_real_api(allow_model_requests: None, ant
assert usage2.cache_read_tokens > 0
assert usage2.cache_write_tokens > 0
assert usage2.output_tokens > 0


async def test_anthropic_response_filter_error_sync(allow_model_requests: None):
c = completion_message(
[BetaTextBlock(text='partial', type='text')],
usage=BetaUsage(input_tokens=5, output_tokens=10),
)
# 'refusal' maps to 'content_filter' in _FINISH_REASON_MAP
c.stop_reason = 'refusal'

mock_client = MockAnthropic.create_mock(c)
m = AnthropicModel('claude-3-5-haiku-123', provider=AnthropicProvider(anthropic_client=mock_client))
agent = Agent(m)

with pytest.raises(ResponseContentFilterError, match=r"Response content filtered by model 'claude-3-5-haiku-123'"):
await agent.run('hello')


async def test_anthropic_response_filter_error_stream(allow_model_requests: None):
stream = [
BetaRawMessageStartEvent(
type='message_start',
message=BetaMessage(
id='msg_123',
model='claude-3-5-haiku-123',
role='assistant',
type='message',
content=[],
stop_reason=None,
usage=BetaUsage(input_tokens=20, output_tokens=0),
),
),
BetaRawMessageDeltaEvent(
type='message_delta',
delta=Delta(stop_reason='refusal'), # maps to content_filter
usage=BetaMessageDeltaUsage(input_tokens=20, output_tokens=5),
),
BetaRawMessageStopEvent(type='message_stop'),
]

mock_client = MockAnthropic.create_stream_mock([stream])
m = AnthropicModel('claude-3-5-haiku-123', provider=AnthropicProvider(anthropic_client=mock_client))
agent = Agent(m)

with pytest.raises(ResponseContentFilterError, match=r"Response content filtered by model 'claude-3-5-haiku-123'"):
async with agent.run_stream('hello') as result:
async for _ in result.stream_text():
pass
67 changes: 65 additions & 2 deletions tests/models/test_google.py
@@ -51,7 +51,13 @@
WebFetchTool,
WebSearchTool,
)
from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, ModelRetry, UnexpectedModelBehavior, UserError
from pydantic_ai.exceptions import (
ModelAPIError,
ModelHTTPError,
ModelRetry,
ResponseContentFilterError,
UserError,
)
from pydantic_ai.messages import (
BuiltinToolCallEvent, # pyright: ignore[reportDeprecated]
BuiltinToolResultEvent, # pyright: ignore[reportDeprecated]
@@ -982,7 +988,8 @@ async def test_google_model_safety_settings(allow_model_requests: None, google_p
)
agent = Agent(m, instructions='You hate the world!', model_settings=settings)

with pytest.raises(UnexpectedModelBehavior, match="Content filter 'SAFETY' triggered"):
with pytest.raises(ResponseContentFilterError, match="Response content filtered by model 'gemini-1.5-flash'"):
await agent.run('Tell me a joke about a Brazilians.')


@@ -4425,3 +4432,59 @@ def test_google_missing_tool_call_thought_signature():
],
}
)


async def test_google_response_filter_error_sync(
allow_model_requests: None, google_provider: GoogleProvider, mocker: MockerFixture
):
model_name = 'gemini-2.5-flash'
model = GoogleModel(model_name, provider=google_provider)

# Create a Candidate mock with the specific failure condition
candidate_mock = mocker.Mock(
finish_reason=GoogleFinishReason.SAFETY, content=None, grounding_metadata=None, url_context_metadata=None
)

# Create the Response mock containing the candidate
response_mock = mocker.Mock(candidates=[candidate_mock], model_version=model_name, usage_metadata=None)

response_mock.model_dump_json.return_value = '{"mock": "json"}'

# Patch the client
mocker.patch.object(model.client.aio.models, 'generate_content', return_value=response_mock)

agent = Agent(model=model)

# Verify the exception is raised
with pytest.raises(ResponseContentFilterError, match=f"Response content filtered by model '{model_name}'"):
await agent.run('bad content')


async def test_google_response_filter_error_stream(
allow_model_requests: None, google_provider: GoogleProvider, mocker: MockerFixture
):
model_name = 'gemini-2.5-flash'
model = GoogleModel(model_name, provider=google_provider)

# Create Candidate mock
candidate_mock = mocker.Mock(
finish_reason=GoogleFinishReason.SAFETY, content=None, grounding_metadata=None, url_context_metadata=None
)

# Create Chunk mock
chunk_mock = mocker.Mock(
candidates=[candidate_mock], model_version=model_name, usage_metadata=None, create_time=datetime.datetime.now()
)
chunk_mock.model_dump_json.return_value = '{"mock": "json"}'

async def stream_iterator():
yield chunk_mock

mocker.patch.object(model.client.aio.models, 'generate_content_stream', return_value=stream_iterator())

agent = Agent(model=model)

with pytest.raises(ResponseContentFilterError, match=f"Response content filtered by model '{model_name}'"):
async with agent.run_stream('bad content') as result:
async for _ in result.stream_text():
pass