diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 0fe332a47..c44717041 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -103,6 +103,7 @@ def __init__( boto_session: Optional[boto3.Session] = None, boto_client_config: Optional[BotocoreConfig] = None, region_name: Optional[str] = None, + endpoint_url: Optional[str] = None, **model_config: Unpack[BedrockConfig], ): """Initialize provider instance. @@ -112,6 +113,7 @@ def __init__( boto_client_config: Configuration to use when creating the Bedrock-Runtime Boto Client. region_name: AWS region to use for the Bedrock service. Defaults to the AWS_REGION environment variable if set, or "us-west-2" if not set. + endpoint_url: Custom endpoint URL for VPC endpoints (PrivateLink) **model_config: Configuration options for the Bedrock model. """ if region_name and boto_session: @@ -143,6 +145,7 @@ def __init__( self.client = session.client( service_name="bedrock-runtime", config=client_config, + endpoint_url=endpoint_url, region_name=resolved_region, ) diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 09e508845..f1a2250e4 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -129,7 +129,7 @@ def test__init__with_default_region(session_cls, mock_client_method): with unittest.mock.patch.object(os, "environ", {}): BedrockModel() session_cls.return_value.client.assert_called_with( - region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY + region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY, endpoint_url=None ) @@ -139,14 +139,14 @@ def test__init__with_session_region(session_cls, mock_client_method): BedrockModel() - mock_client_method.assert_called_with(region_name="eu-blah-1", config=ANY, service_name=ANY) + mock_client_method.assert_called_with(region_name="eu-blah-1", config=ANY, service_name=ANY, endpoint_url=None) def test__init__with_custom_region(mock_client_method): """Test that BedrockModel uses the provided region.""" custom_region = "us-east-1" BedrockModel(region_name=custom_region) - mock_client_method.assert_called_with(region_name=custom_region, config=ANY, service_name=ANY) + mock_client_method.assert_called_with(region_name=custom_region, config=ANY, service_name=ANY, endpoint_url=None) def test__init__with_default_environment_variable_region(mock_client_method): @@ -154,7 +154,7 @@ def test__init__with_default_environment_variable_region(mock_client_method): with unittest.mock.patch.object(os, "environ", {"AWS_REGION": "eu-west-2"}): BedrockModel() - mock_client_method.assert_called_with(region_name="eu-west-2", config=ANY, service_name=ANY) + mock_client_method.assert_called_with(region_name="eu-west-2", config=ANY, service_name=ANY, endpoint_url=None) def test__init__region_precedence(mock_client_method, session_cls): @@ -164,21 +164,38 @@ def test__init__region_precedence(mock_client_method, session_cls): # specifying a region always wins out BedrockModel(region_name="us-specified-1") - mock_client_method.assert_called_with(region_name="us-specified-1", config=ANY, service_name=ANY) + mock_client_method.assert_called_with( + region_name="us-specified-1", config=ANY, service_name=ANY, endpoint_url=None + ) # other-wise uses the session's BedrockModel() - mock_client_method.assert_called_with(region_name="us-session-1", config=ANY, service_name=ANY) + mock_client_method.assert_called_with( + region_name="us-session-1", config=ANY, service_name=ANY, endpoint_url=None + ) # environment variable next session_cls.return_value.region_name = None BedrockModel() - mock_client_method.assert_called_with(region_name="us-environment-1", config=ANY, service_name=ANY) + mock_client_method.assert_called_with( + region_name="us-environment-1", config=ANY, service_name=ANY, endpoint_url=None + ) mock_os_environ.pop("AWS_REGION") session_cls.return_value.region_name = None # No session region BedrockModel() - mock_client_method.assert_called_with(region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY) + mock_client_method.assert_called_with( + region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY, endpoint_url=None + ) + + +def test__init__with_endpoint_url(mock_client_method): + """Test that BedrockModel uses the provided endpoint_url for VPC endpoints.""" + custom_endpoint = "https://vpce-12345-abcde.bedrock-runtime.us-west-2.vpce.amazonaws.com" + BedrockModel(endpoint_url=custom_endpoint) + mock_client_method.assert_called_with( + region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY, endpoint_url=custom_endpoint + ) def test__init__with_region_and_session_raises_value_error():