diff --git a/autogen/oai/anthropic.py b/autogen/oai/anthropic.py index e2448929e618..62078d42631d 100644 --- a/autogen/oai/anthropic.py +++ b/autogen/oai/anthropic.py @@ -16,6 +16,27 @@ ] assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list}) + +Example usage for Anthropic Bedrock: + +Install the `anthropic` package by running `pip install --upgrade anthropic`. +- https://docs.anthropic.com/en/docs/quickstart-guide + +import autogen + +config_list = [ + { + "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "aws_access_key":, + "aws_secret_key":, + "aws_session_token":, + "aws_region":"us-east-1", + "api_type": "anthropic", + } +] + +assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list}) + """ from __future__ import annotations @@ -28,7 +49,7 @@ import warnings from typing import Any, Dict, List, Tuple, Union -from anthropic import Anthropic +from anthropic import Anthropic, AnthropicBedrock from anthropic import __version__ as anthropic_version from anthropic.types import Completion, Message, TextBlock, ToolUseBlock from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall @@ -64,14 +85,44 @@ def __init__(self, **kwargs: Any): api_key (str): The API key for the Anthropic API or set the `ANTHROPIC_API_KEY` environment variable. """ self._api_key = kwargs.get("api_key", None) + self._aws_access_key = kwargs.get("aws_access_key", None) + self._aws_secret_key = kwargs.get("aws_secret_key", None) + self._aws_session_token = kwargs.get("aws_session_token", None) + self._aws_region = kwargs.get("aws_region", None) if not self._api_key: self._api_key = os.getenv("ANTHROPIC_API_KEY") - if self._api_key is None: - raise ValueError("API key is required to use the Anthropic API.") + if not self._aws_access_key: + self._aws_access_key = os.getenv("AWS_ACCESS_KEY") + + if not self._aws_secret_key: + self._aws_secret_key = os.getenv("AWS_SECRET_KEY") + + if not self._aws_session_token: + self._aws_session_token = os.getenv("AWS_SESSION_TOKEN") + + if not self._aws_region: + self._aws_region = os.getenv("AWS_REGION") + + if self._api_key is None and ( + self._aws_access_key is None + or self._aws_secret_key is None + or self._aws_session_token is None + or self._aws_region is None + ): + raise ValueError("API key or AWS credentials are required to use the Anthropic API.") + + if self._api_key is not None: + self._client = Anthropic(api_key=self._api_key) + else: + self._client = AnthropicBedrock( + aws_access_key=self._aws_access_key, + aws_secret_key=self._aws_secret_key, + aws_session_token=self._aws_session_token, + aws_region=self._aws_region, + ) - self._client = Anthropic(api_key=self._api_key) self._last_tooluse_status = {} def load_config(self, params: Dict[str, Any]): @@ -107,6 +158,22 @@ def cost(self, response) -> float: def api_key(self): return self._api_key + @property + def aws_access_key(self): + return self._aws_access_key + + @property + def aws_secret_key(self): + return self._aws_secret_key + + @property + def aws_session_token(self): + return self._aws_session_token + + @property + def aws_region(self): + return self._aws_region + def create(self, params: Dict[str, Any]) -> Completion: if "tools" in params: converted_functions = self.convert_tools_to_functions(params["tools"]) diff --git a/test/oai/test_anthropic.py b/test/oai/test_anthropic.py index 379ab47f6756..53926dbd18d6 100644 --- a/test/oai/test_anthropic.py +++ b/test/oai/test_anthropic.py @@ -50,17 +50,47 @@ def anthropic_client(): @pytest.mark.skipif(skip, reason=reason) def test_initialization_missing_api_key(): os.environ.pop("ANTHROPIC_API_KEY", None) - with pytest.raises(ValueError, match="API key is required to use the Anthropic API."): + os.environ.pop("AWS_ACCESS_KEY", None) + os.environ.pop("AWS_SECRET_KEY", None) + os.environ.pop("AWS_SESSION_TOKEN", None) + os.environ.pop("AWS_REGION", None) + with pytest.raises(ValueError, match="API key or AWS credentials are required to use the Anthropic API."): AnthropicClient() AnthropicClient(api_key="dummy_api_key") +@pytest.fixture() +def anthropic_client_with_aws_credentials(): + return AnthropicClient( + aws_access_key="dummy_access_key", + aws_secret_key="dummy_secret_key", + aws_session_token="dummy_session_token", + aws_region="us-west-2", + ) + + @pytest.mark.skipif(skip, reason=reason) def test_intialization(anthropic_client): assert anthropic_client.api_key == "dummy_api_key", "`api_key` should be correctly set in the config" +@pytest.mark.skipif(skip, reason=reason) +def test_intialization_with_aws_credentials(anthropic_client_with_aws_credentials): + assert ( + anthropic_client_with_aws_credentials.aws_access_key == "dummy_access_key" + ), "`aws_access_key` should be correctly set in the config" + assert ( + anthropic_client_with_aws_credentials.aws_secret_key == "dummy_secret_key" + ), "`aws_secret_key` should be correctly set in the config" + assert ( + anthropic_client_with_aws_credentials.aws_session_token == "dummy_session_token" + ), "`aws_session_token` should be correctly set in the config" + assert ( + anthropic_client_with_aws_credentials.aws_region == "us-west-2" + ), "`aws_region` should be correctly set in the config" + + # Test cost calculation @pytest.mark.skipif(skip, reason=reason) def test_cost_calculation(mock_completion):