Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions examples/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
client = AnthropicBedrock()

model = "anthropic.claude-sonnet-4-5-20250929-v1:0"

print("------ standard response ------")
message = client.messages.create(
max_tokens=1024,
Expand All @@ -19,7 +21,7 @@
"content": "Hello!",
}
],
model="anthropic.claude-sonnet-4-5-20250929-v1:0",
model=model,
)
print(message.model_dump_json(indent=2))

Expand All @@ -33,7 +35,7 @@
"content": "Say hello there!",
}
],
model="anthropic.claude-sonnet-4-5-20250929-v1:0",
model=model,
) as stream:
for text in stream.text_stream:
print(text, end="", flush=True)
Expand All @@ -44,3 +46,15 @@
# inside of the context manager
accumulated = stream.get_final_message()
print("accumulated message: ", accumulated.model_dump_json(indent=2))

# Count the tokens of a sample user message and print the result as JSON.
print("------ count tokens ------")
sample_messages = [{"role": "user", "content": "Hello, world!"}]
count = client.messages.count_tokens(model=model, messages=sample_messages)
print(count.model_dump_json(indent=2))
20 changes: 19 additions & 1 deletion src/anthropic/lib/bedrock/_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import os
import json
import base64
import logging
import urllib.parse
from typing import Any, Union, Mapping, TypeVar
Expand Down Expand Up @@ -62,7 +64,23 @@ def _prepare_options(input_options: FinalRequestOptions) -> FinalRequestOptions:
raise AnthropicError("The Batch API is not supported in Bedrock yet")

if options.url == "/v1/messages/count_tokens":
raise AnthropicError("Token counting is not supported in Bedrock yet")
if not is_dict(options.json_data):
raise RuntimeError("Expected dictionary json_data for /v1/messages/count_tokens endpoint")

model = options.json_data.pop("model", None)
model = urllib.parse.quote(str(model), safe=":")

# max_tokens is required for the request to be valid.
# Use 500 which is enough to get a response.
options.json_data["max_tokens"] = 500

# body element of the request is base64 encoded.
# See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModelTokensRequest.html
input_to_count = json.dumps(options.json_data)
encoded_bytes = base64.b64encode(input_to_count.encode("utf-8")).decode("utf-8")
options.json_data = {"input": {"invokeModel": {"body": encoded_bytes}}}

options.url = f"/model/{model}/count-tokens"

return options

Expand Down
34 changes: 34 additions & 0 deletions tests/lib/test_bedrock.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import re
import json
import base64
import typing as t
import tempfile
from typing import TypedDict, cast
Expand Down Expand Up @@ -96,6 +98,38 @@ def test_messages_retries(respx_mock: MockRouter) -> None:
)


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.respx()
def test_messages_count_tokens(respx_mock: MockRouter) -> None:
    """Check the request that ``messages.count_tokens`` sends through the sync client.

    The mocked endpoint must be hit exactly once, at the model-specific
    ``count-tokens`` URL, with a JSON payload whose
    ``input.invokeModel.body`` is a base64-encoded JSON document holding the
    messages, the Bedrock ``anthropic_version`` and ``max_tokens``.
    """
    endpoint_pattern = re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/count-tokens")
    canned_response = httpx.Response(200, json={"foo": "bar"})
    respx_mock.post(endpoint_pattern).mock(side_effect=[canned_response])

    sync_client.messages.count_tokens(
        model="anthropic.claude-3-5-sonnet-20241022-v2:0",
        messages=[{"role": "user", "content": "Hello, world!"}],
    )

    recorded = cast("list[MockRequestCall]", respx_mock.calls)
    assert len(recorded) == 1
    sent_request = recorded[0].request
    assert (
        sent_request.url
        == "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-5-sonnet-20241022-v2:0/count-tokens"
    )

    # Walk the envelope level by level so a missing key fails on its own assert.
    envelope = json.loads(sent_request.content)
    assert "input" in envelope
    input_section = envelope["input"]
    assert "invokeModel" in input_section
    invoke_section = input_section["invokeModel"]
    assert "body" in invoke_section

    # The inner body travels base64-encoded; decode it before comparing.
    inner_json = base64.b64decode(invoke_section["body"]).decode("utf-8")
    expected_inner = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 500,
        "messages": [{"role": "user", "content": "Hello, world!"}],
    }
    assert json.loads(inner_json) == expected_inner


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.respx()
@pytest.mark.asyncio()
Expand Down