Skip to content

Commit ce9d15d

Browse files
committed
feat(models): expose endpoint_url on BedrockModel constructor
1 parent 57cb9e5 commit ce9d15d

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

src/strands/models/bedrock.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ class BedrockConfig(TypedDict, total=False):
7575
streaming: Flag to enable/disable streaming. Defaults to True.
7676
temperature: Controls randomness in generation (higher = more random)
7777
top_p: Controls diversity via nucleus sampling (alternative to temperature)
78-
endpoint_url: Custom endpoint URL for VPC endpoints (PrivateLink)
7978
"""
8079

8180
additional_args: Optional[dict[str, Any]]
@@ -97,14 +96,14 @@ class BedrockConfig(TypedDict, total=False):
9796
streaming: Optional[bool]
9897
temperature: Optional[float]
9998
top_p: Optional[float]
100-
endpoint_url: Optional[str] #Adding Endpoint URL
10199

102100
def __init__(
103101
self,
104102
*,
105103
boto_session: Optional[boto3.Session] = None,
106104
boto_client_config: Optional[BotocoreConfig] = None,
107105
region_name: Optional[str] = None,
106+
endpoint_url: Optional[str] = None,
108107
**model_config: Unpack[BedrockConfig],
109108
):
110109
"""Initialize provider instance.
@@ -114,6 +113,7 @@ def __init__(
114113
boto_client_config: Configuration to use when creating the Bedrock-Runtime Boto Client.
115114
region_name: AWS region to use for the Bedrock service.
116115
Defaults to the AWS_REGION environment variable if set, or "us-west-2" if not set.
116+
endpoint_url: Custom endpoint URL for VPC endpoints (PrivateLink)
117117
**model_config: Configuration options for the Bedrock model.
118118
Use endpoint_url for VPC endpoint connectivity.
119119
"""
@@ -146,7 +146,7 @@ def __init__(
146146
self.client = session.client(
147147
service_name="bedrock-runtime",
148148
config=client_config,
149-
endpoint_url=self.config.get("endpoint_url"),
149+
endpoint_url=endpoint_url,
150150
region_name=resolved_region,
151151
)
152152

tests/strands/models/test_bedrock.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def test__init__with_default_region(session_cls, mock_client_method):
129129
with unittest.mock.patch.object(os, "environ", {}):
130130
BedrockModel()
131131
session_cls.return_value.client.assert_called_with(
132-
region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY
132+
region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY, endpoint_url=None
133133
)
134134

135135

@@ -139,22 +139,22 @@ def test__init__with_session_region(session_cls, mock_client_method):
139139

140140
BedrockModel()
141141

142-
mock_client_method.assert_called_with(region_name="eu-blah-1", config=ANY, service_name=ANY)
142+
mock_client_method.assert_called_with(region_name="eu-blah-1", config=ANY, service_name=ANY, endpoint_url=None)
143143

144144

145145
def test__init__with_custom_region(mock_client_method):
146146
"""Test that BedrockModel uses the provided region."""
147147
custom_region = "us-east-1"
148148
BedrockModel(region_name=custom_region)
149-
mock_client_method.assert_called_with(region_name=custom_region, config=ANY, service_name=ANY)
149+
mock_client_method.assert_called_with(region_name=custom_region, config=ANY, service_name=ANY, endpoint_url=None)
150150

151151

152152
def test__init__with_default_environment_variable_region(mock_client_method):
153153
"""Test that BedrockModel uses the AWS_REGION since we code that in."""
154154
with unittest.mock.patch.object(os, "environ", {"AWS_REGION": "eu-west-2"}):
155155
BedrockModel()
156156

157-
mock_client_method.assert_called_with(region_name="eu-west-2", config=ANY, service_name=ANY)
157+
mock_client_method.assert_called_with(region_name="eu-west-2", config=ANY, service_name=ANY, endpoint_url=None)
158158

159159

160160
def test__init__region_precedence(mock_client_method, session_cls):
@@ -164,21 +164,28 @@ def test__init__region_precedence(mock_client_method, session_cls):
164164

165165
# specifying a region always wins out
166166
BedrockModel(region_name="us-specified-1")
167-
mock_client_method.assert_called_with(region_name="us-specified-1", config=ANY, service_name=ANY)
167+
mock_client_method.assert_called_with(region_name="us-specified-1", config=ANY, service_name=ANY, endpoint_url=None)
168168

169169
# other-wise uses the session's
170170
BedrockModel()
171-
mock_client_method.assert_called_with(region_name="us-session-1", config=ANY, service_name=ANY)
171+
mock_client_method.assert_called_with(region_name="us-session-1", config=ANY, service_name=ANY, endpoint_url=None)
172172

173173
# environment variable next
174174
session_cls.return_value.region_name = None
175175
BedrockModel()
176-
mock_client_method.assert_called_with(region_name="us-environment-1", config=ANY, service_name=ANY)
176+
mock_client_method.assert_called_with(region_name="us-environment-1", config=ANY, service_name=ANY, endpoint_url=None)
177177

178178
mock_os_environ.pop("AWS_REGION")
179179
session_cls.return_value.region_name = None # No session region
180180
BedrockModel()
181-
mock_client_method.assert_called_with(region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY)
181+
mock_client_method.assert_called_with(region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY, endpoint_url=None)
182+
183+
184+
def test__init__with_endpoint_url(mock_client_method):
185+
"""Test that BedrockModel uses the provided endpoint_url for VPC endpoints."""
186+
custom_endpoint = "https://vpce-12345-abcde.bedrock-runtime.us-west-2.vpce.amazonaws.com"
187+
BedrockModel(endpoint_url=custom_endpoint)
188+
mock_client_method.assert_called_with(region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY, endpoint_url=custom_endpoint)
182189

183190

184191
def test__init__with_region_and_session_raises_value_error():

0 commit comments

Comments
 (0)