Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 53 additions & 5 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import logging
import os
import warnings
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, cast

import boto3
Expand All @@ -29,7 +30,9 @@

logger = logging.getLogger(__name__)

# See: `BedrockModel._get_default_model_with_warning` for why we need both
DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0"
_DEFAULT_BEDROCK_MODEL_ID = "{}.anthropic.claude-sonnet-4-20250514-v1:0"
DEFAULT_BEDROCK_REGION = "us-west-2"

BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [
Expand All @@ -47,6 +50,7 @@

DEFAULT_READ_TIMEOUT = 120


class BedrockModel(Model):
"""AWS Bedrock model provider implementation.

Expand Down Expand Up @@ -129,13 +133,16 @@ def __init__(
if region_name and boto_session:
raise ValueError("Cannot specify both `region_name` and `boto_session`.")

self.config = BedrockModel.BedrockConfig(model_id=DEFAULT_BEDROCK_MODEL_ID, include_tool_result_status="auto")
session = boto_session or boto3.Session()
resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION
self.config = BedrockModel.BedrockConfig(
model_id=BedrockModel._get_default_model_with_warning(resolved_region, model_config),
include_tool_result_status="auto",
)
self.update_config(**model_config)

logger.debug("config=<%s> | initializing", self.config)

session = boto_session or boto3.Session()

# Add strands-agents to the request user agent
if boto_client_config:
existing_user_agent = getattr(boto_client_config, "user_agent_extra", None)
Expand All @@ -150,8 +157,6 @@ def __init__(
else:
client_config = BotocoreConfig(user_agent_extra="strands-agents", read_timeout=DEFAULT_READ_TIMEOUT)

resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION

self.client = session.client(
service_name="bedrock-runtime",
config=client_config,
Expand Down Expand Up @@ -770,3 +775,46 @@ async def structured_output(
raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")

yield {"output": output_model(**output_response)}

@staticmethod
def _get_default_model_with_warning(region_name: str, model_config: Optional[BedrockConfig] = None) -> str:
"""Get the default Bedrock modelId based on region.

If the region is not **known** to support inference then we show a helpful warning
that compliments the exception that Bedrock will throw.
If the customer provided a model_id in their config or they overrode the `DEFAULT_BEDROCK_MODEL_ID`
then we should not process further.

Args:
region_name (str): region for bedrock model
model_config (Optional[dict[str, Any]]): Model Config that caller passes in on init
"""
if DEFAULT_BEDROCK_MODEL_ID != _DEFAULT_BEDROCK_MODEL_ID.format("us"):
return DEFAULT_BEDROCK_MODEL_ID

model_config = model_config or {}
if model_config.get("model_id"):
return model_config["model_id"]

prefix_inference_map = {"ap": "apac"} # some inference endpoints can be a bit different than the region prefix

prefix = "-".join(region_name.split("-")[:-2]).lower() # handles `us-east-1` or `us-gov-east-1`
if prefix not in {"us", "eu", "ap", "us-gov"}:
warnings.warn(
f"""
================== WARNING ==================

This region {region_name} does not support
our default inference endpoint: {_DEFAULT_BEDROCK_MODEL_ID.format(prefix)}.
Update the agent to pass in a 'model_id' like so:
```
Agent(..., model='valid_model_id', ...)
````
Documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html

==================================================
""",
stacklevel=2,
)

return _DEFAULT_BEDROCK_MODEL_ID.format(prefix_inference_map.get(prefix, prefix))
7 changes: 5 additions & 2 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from tests.fixtures.mock_session_repository import MockedSessionRepository
from tests.fixtures.mocked_model_provider import MockedModelProvider

# For unit testing we will use the the us inference
FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us")


@pytest.fixture
def mock_randint():
Expand Down Expand Up @@ -211,7 +214,7 @@ def test_agent__init__with_default_model():
agent = Agent()

assert isinstance(agent.model, BedrockModel)
assert agent.model.config["model_id"] == DEFAULT_BEDROCK_MODEL_ID
assert agent.model.config["model_id"] == FORMATTED_DEFAULT_MODEL_ID


def test_agent__init__with_explicit_model(mock_model):
Expand Down Expand Up @@ -891,7 +894,7 @@ def test_agent__del__(agent):
def test_agent_init_with_no_model_or_model_id():
agent = Agent()
assert agent.model is not None
assert agent.model.get_config().get("model_id") == DEFAULT_BEDROCK_MODEL_ID
assert agent.model.get_config().get("model_id") == FORMATTED_DEFAULT_MODEL_ID


def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, agenerator):
Expand Down
96 changes: 94 additions & 2 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,17 @@

import strands
from strands.models import BedrockModel
from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION, DEFAULT_READ_TIMEOUT
from strands.models.bedrock import (
_DEFAULT_BEDROCK_MODEL_ID,
DEFAULT_BEDROCK_MODEL_ID,
DEFAULT_BEDROCK_REGION,
DEFAULT_READ_TIMEOUT,
)
from strands.types.exceptions import ModelThrottledException
from strands.types.tools import ToolSpec

FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us")


@pytest.fixture
def session_cls():
Expand Down Expand Up @@ -119,7 +126,7 @@ def test__init__default_model_id(bedrock_client):
model = BedrockModel()

tru_model_id = model.get_config().get("model_id")
exp_model_id = DEFAULT_BEDROCK_MODEL_ID
exp_model_id = FORMATTED_DEFAULT_MODEL_ID

assert tru_model_id == exp_model_id

Expand Down Expand Up @@ -1543,3 +1550,88 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings):
model.format_request(messages, tool_choice=None)

assert len(captured_warnings) == 0


def test_get_default_model_with_warning_supported_regions_shows_no_warning(captured_warnings):
"""Test get_model_prefix_with_warning doesn't warn for supported region prefixes."""
BedrockModel._get_default_model_with_warning("us-west-2")
BedrockModel._get_default_model_with_warning("eu-west-2")
assert len(captured_warnings) == 0


def test_get_default_model_for_supported_eu_region_returns_correct_model_id(captured_warnings):
model_id = BedrockModel._get_default_model_with_warning("eu-west-1")
assert model_id == "eu.anthropic.claude-sonnet-4-20250514-v1:0"
assert len(captured_warnings) == 0


def test_get_default_model_for_supported_us_region_returns_correct_model_id(captured_warnings):
model_id = BedrockModel._get_default_model_with_warning("us-east-1")
assert model_id == "us.anthropic.claude-sonnet-4-20250514-v1:0"
assert len(captured_warnings) == 0


def test_get_default_model_for_supported_gov_region_returns_correct_model_id(captured_warnings):
model_id = BedrockModel._get_default_model_with_warning("us-gov-west-1")
assert model_id == "us-gov.anthropic.claude-sonnet-4-20250514-v1:0"
assert len(captured_warnings) == 0


def test_get_model_prefix_for_ap_region_converts_to_apac_endpoint(captured_warnings):
"""Test _get_default_model_with_warning warns for APAC regions since 'ap' is not in supported prefixes."""
model_id = BedrockModel._get_default_model_with_warning("ap-southeast-1")
assert model_id == "apac.anthropic.claude-sonnet-4-20250514-v1:0"


def test_get_default_model_with_warning_unsupported_region_warns(captured_warnings):
"""Test _get_default_model_with_warning warns for unsupported regions."""
BedrockModel._get_default_model_with_warning("ca-central-1")
assert len(captured_warnings) == 1
assert "This region ca-central-1 does not support" in str(captured_warnings[0].message)
assert "our default inference endpoint" in str(captured_warnings[0].message)


def test_get_default_model_with_warning_no_warning_with_custom_model_id(captured_warnings):
"""Test _get_default_model_with_warning doesn't warn when custom model_id provided."""
model_config = {"model_id": "custom-model"}
model_id = BedrockModel._get_default_model_with_warning("ca-central-1", model_config)

assert model_id == "custom-model"
assert len(captured_warnings) == 0


def test_init_with_unsupported_region_warns(session_cls, captured_warnings):
"""Test BedrockModel initialization warns for unsupported regions."""
BedrockModel(region_name="ca-central-1")

assert len(captured_warnings) == 1
assert "This region ca-central-1 does not support" in str(captured_warnings[0].message)


def test_init_with_unsupported_region_custom_model_no_warning(session_cls, captured_warnings):
"""Test BedrockModel initialization doesn't warn when custom model_id provided."""
BedrockModel(region_name="ca-central-1", model_id="custom-model")
assert len(captured_warnings) == 0


def test_override_default_model_id_uses_the_overriden_value(captured_warnings):
with unittest.mock.patch("strands.models.bedrock.DEFAULT_BEDROCK_MODEL_ID", "custom-overridden-model"):
model_id = BedrockModel._get_default_model_with_warning("us-east-1")
assert model_id == "custom-overridden-model"


def test_no_override_uses_formatted_default_model_id(captured_warnings):
model_id = BedrockModel._get_default_model_with_warning("us-east-1")
assert model_id == "us.anthropic.claude-sonnet-4-20250514-v1:0"
assert model_id != _DEFAULT_BEDROCK_MODEL_ID
assert len(captured_warnings) == 0


def test_custom_model_id_not_overridden_by_region_formatting(session_cls):
"""Test that custom model_id is not overridden by region formatting."""
custom_model_id = "custom.model.id"

model = BedrockModel(model_id=custom_model_id)
model_id = model.get_config().get("model_id")

assert model_id == custom_model_id
Loading