diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index ff03460904..a7f35696f3 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -155,11 +155,20 @@ 'tool_use': 'tool_call', } -_AWS_BEDROCK_INFERENCE_GEO_PREFIXES: tuple[str, ...] = ('us.', 'eu.', 'apac.', 'jp.', 'au.', 'ca.', 'global.') -"""Geo prefixes for Bedrock inference profile IDs (e.g., 'eu.', 'us.'). +_AWS_BEDROCK_INFERENCE_GEO_PREFIXES: tuple[str, ...] = ( + 'us.', + 'eu.', + 'apac.', + 'jp.', + 'au.', + 'ca.', + 'global.', + 'us-gov.', +) +"""Geo prefixes for Bedrock inference profile IDs (e.g., 'eu.', 'us.', 'us-gov.'). Used to strip the geo prefix so we can pass a pure foundation model ID/ARN to CountTokens, -which does not accept profile IDs. Extend if new geos appear (e.g., 'global.', 'us-gov.'). +which does not accept profile IDs. """ diff --git a/pydantic_ai_slim/pydantic_ai/providers/bedrock.py b/pydantic_ai_slim/pydantic_ai/providers/bedrock.py index f6fac74fae..2c48cfa1fe 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/providers/bedrock.py @@ -58,6 +58,23 @@ def bedrock_deepseek_model_profile(model_name: str) -> ModelProfile | None: return profile # pragma: no cover +# Known geo prefixes for cross-region inference profile IDs +_BEDROCK_GEO_PREFIXES: tuple[str, ...] = ('us.', 'eu.', 'apac.', 'jp.', 'au.', 'ca.', 'global.', 'us-gov.') + + +def _strip_geo_prefix(model_name: str) -> str: + """Strip geographic/regional prefix from model name if present. + + AWS Bedrock cross-region inference uses prefixes like 'us.', 'eu.', 'us-gov.', 'global.' + to route requests to specific regions. This function strips those prefixes so we can + identify the underlying provider and model. + """ + for prefix in _BEDROCK_GEO_PREFIXES: + if model_name.startswith(prefix): + return model_name.removeprefix(prefix) + return model_name + + class BedrockProvider(Provider[BaseClient]): """Provider for AWS Bedrock.""" @@ -87,13 +104,11 @@ def model_profile(self, model_name: str) -> ModelProfile | None: 'deepseek': bedrock_deepseek_model_profile, } - # Split the model name into parts - parts = model_name.split('.', 2) - - # Handle regional prefixes (e.g. "us.") - if len(parts) > 2 and len(parts[0]) == 2: - parts = parts[1:] + # Strip regional/geo prefix if present (e.g. "us.", "eu.", "us-gov.", "global.") + model_name = _strip_geo_prefix(model_name) + # Split the model name into provider and model parts + parts = model_name.split('.', 1) if len(parts) < 2: return None diff --git a/tests/providers/test_bedrock.py b/tests/providers/test_bedrock.py index e791fe7a43..c249d5e134 100644 --- a/tests/providers/test_bedrock.py +++ b/tests/providers/test_bedrock.py @@ -100,3 +100,50 @@ def test_bedrock_provider_model_profile(env: TestEnv, mocker: MockerFixture): unknown_model = provider.model_profile('unknown.unknown-model') assert unknown_model is None + + +@pytest.mark.parametrize('prefix', ['us.', 'eu.', 'apac.', 'jp.', 'au.', 'ca.', 'global.', 'us-gov.']) +def test_bedrock_provider_model_profile_all_geo_prefixes(env: TestEnv, prefix: str): + """Test that all cross-region inference geo prefixes are correctly handled. + + This is critical for AWS GovCloud (us-gov.) and other regional deployments + where models use prefixes longer than 2 characters. + """ + env.set('AWS_DEFAULT_REGION', 'us-east-1') + provider = BedrockProvider() + + # Test Anthropic model with geo prefix + model_name = f'{prefix}anthropic.claude-sonnet-4-5-20250929-v1:0' + profile = provider.model_profile(model_name) + + assert profile is not None, f'model_profile returned None for {model_name}' + assert isinstance(profile, BedrockModelProfile) + assert profile.bedrock_supports_tool_choice is True + assert profile.bedrock_send_back_thinking_parts is True + + +def test_bedrock_provider_model_profile_us_gov_anthropic(env: TestEnv, mocker: MockerFixture): + """Test that us-gov. prefixed Anthropic models get the correct profile. + + This specifically tests the us-gov. prefix which was previously broken + because the provider only handled 2-character prefixes. + """ + env.set('AWS_DEFAULT_REGION', 'us-east-1') + provider = BedrockProvider() + + ns = 'pydantic_ai.providers.bedrock' + anthropic_model_profile_mock = mocker.patch(f'{ns}.anthropic_model_profile', wraps=anthropic_model_profile) + + # Test us-gov. prefix (AWS GovCloud cross-region inference) + profile = provider.model_profile('us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0') + anthropic_model_profile_mock.assert_called_with('claude-sonnet-4-5-20250929') + assert isinstance(profile, BedrockModelProfile) + assert profile.bedrock_supports_tool_choice is True + assert profile.bedrock_send_back_thinking_parts is True + + # Test global. prefix + profile = provider.model_profile('global.anthropic.claude-opus-4-5-20251101-v1:0') + anthropic_model_profile_mock.assert_called_with('claude-opus-4-5-20251101') + assert isinstance(profile, BedrockModelProfile) + assert profile.bedrock_supports_tool_choice is True + assert profile.bedrock_send_back_thinking_parts is True