Skip to content

Commit

Permalink
Added anthropic bedrock (#3103)
Browse files Browse the repository at this point in the history
* Added anthropic bedrock

* Code format and fixed import

* Added tests for anthropic bedrock

* tests update

---------

Co-authored-by: Chi Wang <[email protected]>
Co-authored-by: HRUSHIKESH DOKALA <[email protected]>
  • Loading branch information
3 people authored and victordibia committed Jul 30, 2024
1 parent 7d26c69 commit 4a4eae6
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 5 deletions.
75 changes: 71 additions & 4 deletions autogen/oai/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,27 @@
]
assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
Example usage for Anthropic Bedrock:
Install the `anthropic` package by running `pip install --upgrade anthropic`.
- https://docs.anthropic.com/en/docs/quickstart-guide
import autogen
config_list = [
{
"model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"aws_access_key":<accessKey>,
"aws_secret_key":<secretKey>,
"aws_session_token":<sessionTok>,
"aws_region":"us-east-1",
"api_type": "anthropic",
}
]
assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
"""

from __future__ import annotations
Expand All @@ -28,7 +49,7 @@
import warnings
from typing import Any, Dict, List, Tuple, Union

from anthropic import Anthropic
from anthropic import Anthropic, AnthropicBedrock
from anthropic import __version__ as anthropic_version
from anthropic.types import Completion, Message, TextBlock, ToolUseBlock
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
Expand Down Expand Up @@ -64,14 +85,44 @@ def __init__(self, **kwargs: Any):
api_key (str): The API key for the Anthropic API or set the `ANTHROPIC_API_KEY` environment variable.
"""
self._api_key = kwargs.get("api_key", None)
self._aws_access_key = kwargs.get("aws_access_key", None)
self._aws_secret_key = kwargs.get("aws_secret_key", None)
self._aws_session_token = kwargs.get("aws_session_token", None)
self._aws_region = kwargs.get("aws_region", None)

if not self._api_key:
self._api_key = os.getenv("ANTHROPIC_API_KEY")

if self._api_key is None:
raise ValueError("API key is required to use the Anthropic API.")
if not self._aws_access_key:
self._aws_access_key = os.getenv("AWS_ACCESS_KEY")

if not self._aws_secret_key:
self._aws_secret_key = os.getenv("AWS_SECRET_KEY")

if not self._aws_session_token:
self._aws_session_token = os.getenv("AWS_SESSION_TOKEN")

if not self._aws_region:
self._aws_region = os.getenv("AWS_REGION")

if self._api_key is None and (
self._aws_access_key is None
or self._aws_secret_key is None
or self._aws_session_token is None
or self._aws_region is None
):
raise ValueError("API key or AWS credentials are required to use the Anthropic API.")

if self._api_key is not None:
self._client = Anthropic(api_key=self._api_key)
else:
self._client = AnthropicBedrock(
aws_access_key=self._aws_access_key,
aws_secret_key=self._aws_secret_key,
aws_session_token=self._aws_session_token,
aws_region=self._aws_region,
)

self._client = Anthropic(api_key=self._api_key)
self._last_tooluse_status = {}

def load_config(self, params: Dict[str, Any]):
Expand Down Expand Up @@ -107,6 +158,22 @@ def cost(self, response) -> float:
def api_key(self):
return self._api_key

@property
def aws_access_key(self):
return self._aws_access_key

@property
def aws_secret_key(self):
return self._aws_secret_key

@property
def aws_session_token(self):
return self._aws_session_token

@property
def aws_region(self):
return self._aws_region

def create(self, params: Dict[str, Any]) -> Completion:
if "tools" in params:
converted_functions = self.convert_tools_to_functions(params["tools"])
Expand Down
32 changes: 31 additions & 1 deletion test/oai/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,47 @@ def anthropic_client():
@pytest.mark.skipif(skip, reason=reason)
def test_initialization_missing_api_key():
os.environ.pop("ANTHROPIC_API_KEY", None)
with pytest.raises(ValueError, match="API key is required to use the Anthropic API."):
os.environ.pop("AWS_ACCESS_KEY", None)
os.environ.pop("AWS_SECRET_KEY", None)
os.environ.pop("AWS_SESSION_TOKEN", None)
os.environ.pop("AWS_REGION", None)
with pytest.raises(ValueError, match="API key or AWS credentials are required to use the Anthropic API."):
AnthropicClient()

AnthropicClient(api_key="dummy_api_key")


@pytest.fixture()
def anthropic_client_with_aws_credentials():
return AnthropicClient(
aws_access_key="dummy_access_key",
aws_secret_key="dummy_secret_key",
aws_session_token="dummy_session_token",
aws_region="us-west-2",
)


@pytest.mark.skipif(skip, reason=reason)
def test_intialization(anthropic_client):
assert anthropic_client.api_key == "dummy_api_key", "`api_key` should be correctly set in the config"


@pytest.mark.skipif(skip, reason=reason)
def test_intialization_with_aws_credentials(anthropic_client_with_aws_credentials):
assert (
anthropic_client_with_aws_credentials.aws_access_key == "dummy_access_key"
), "`aws_access_key` should be correctly set in the config"
assert (
anthropic_client_with_aws_credentials.aws_secret_key == "dummy_secret_key"
), "`aws_secret_key` should be correctly set in the config"
assert (
anthropic_client_with_aws_credentials.aws_session_token == "dummy_session_token"
), "`aws_session_token` should be correctly set in the config"
assert (
anthropic_client_with_aws_credentials.aws_region == "us-west-2"
), "`aws_region` should be correctly set in the config"


# Test cost calculation
@pytest.mark.skipif(skip, reason=reason)
def test_cost_calculation(mock_completion):
Expand Down

0 comments on commit 4a4eae6

Please sign in to comment.