-
Notifications
You must be signed in to change notification settings - Fork 5.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* intial setup for the anthropic client with cost config * update: other methods added * fix: formatting * fix: config unused * update: changes made in the client * update: test added to the workflow * update: added tests to the anthropic client * fix: errors in workflows and client * fix * fix: anthropic tools type * update: notebook anthropic * Nonetype fixed * fix-tests config * update: tests and client issues * logger support * remove sys path * updated the functioning of the client * update: type hints and stream * skip tests- importerror * fix: anthropic client and tests * none fix * Alternating roles, parameter keywords, cost on response, * update: anthropic notebook * update: notebook with more details * devcontainer * update: added validate_params from the client_utils * fix: formatting * fix: minor comment --------- Co-authored-by: Mark Sze <[email protected]>
- Loading branch information
1 parent
8d2cfaf
commit 5a496b3
Showing
10 changed files
with
541 additions
and
346 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,274 @@ | ||
""" | ||
Create an OpenAI-compatible client for the Anthropic API. | ||
Example usage: | ||
Install the `anthropic` package by running `pip install --upgrade anthropic`. | ||
- https://docs.anthropic.com/en/docs/quickstart-guide | ||
import autogen | ||
config_list = [ | ||
{ | ||
"model": "claude-3-sonnet-20240229", | ||
"api_key": os.getenv("ANTHROPIC_API_KEY"), | ||
"api_type": "anthropic", | ||
} | ||
] | ||
assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list}) | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
import copy | ||
import inspect | ||
import json | ||
import os | ||
import warnings | ||
from typing import Any, Dict, List, Tuple, Union | ||
|
||
from anthropic import Anthropic | ||
from anthropic import __version__ as anthropic_version | ||
from anthropic.types import Completion, Message | ||
from client_utils import validate_parameter | ||
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall | ||
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice | ||
from typing_extensions import Annotated | ||
|
||
TOOL_ENABLED = anthropic_version >= "0.23.1" | ||
if TOOL_ENABLED: | ||
from anthropic.types.tool_use_block_param import ( | ||
ToolUseBlockParam, | ||
) | ||
|
||
|
||
ANTHROPIC_PRICING_1k = { | ||
"claude-3-sonnet-20240229": (0.003, 0.015), | ||
"claude-3-opus-20240229": (0.015, 0.075), | ||
"claude-2.0": (0.008, 0.024), | ||
"claude-2.1": (0.008, 0.024), | ||
"claude-3.0-opus": (0.015, 0.075), | ||
"claude-3.0-haiku": (0.00025, 0.00125), | ||
} | ||
|
||
|
||
class AnthropicClient: | ||
def __init__(self, **kwargs: Any): | ||
""" | ||
Initialize the Anthropic API client. | ||
Args: | ||
api_key (str): The API key for the Anthropic API or set the `ANTHROPIC_API_KEY` environment variable. | ||
""" | ||
self._api_key = kwargs.get("api_key", None) | ||
|
||
if not self._api_key: | ||
self._api_key = os.getenv("ANTHROPIC_API_KEY") | ||
|
||
if self._api_key is None: | ||
raise ValueError("API key is required to use the Anthropic API.") | ||
|
||
self._client = Anthropic(api_key=self._api_key) | ||
self._last_tooluse_status = {} | ||
|
||
def load_config(self, params: Dict[str, Any]): | ||
"""Load the configuration for the Anthropic API client.""" | ||
anthropic_params = {} | ||
|
||
anthropic_params["model"] = params.get("model", None) | ||
assert anthropic_params["model"], "Please provide a `model` in the config_list to use the Anthropic API." | ||
|
||
anthropic_params["temperature"] = validate_parameter( | ||
params, "temperature", (float, int), False, 1.0, (0.0, 1.0), None | ||
) | ||
anthropic_params["max_tokens"] = validate_parameter(params, "max_tokens", int, False, 4096, (1, None), None) | ||
anthropic_params["top_k"] = validate_parameter(params, "top_k", int, True, None, (1, None), None) | ||
anthropic_params["top_p"] = validate_parameter(params, "top_p", (float, int), True, None, (0.0, 1.0), None) | ||
anthropic_params["stop_sequences"] = validate_parameter(params, "stop_sequences", list, True, None, None, None) | ||
anthropic_params["stream"] = validate_parameter(params, "stream", bool, False, False, None, None) | ||
|
||
if anthropic_params["stream"]: | ||
warnings.warn( | ||
"Streaming is not currently supported, streaming will be disabled.", | ||
UserWarning, | ||
) | ||
anthropic_params["stream"] = False | ||
|
||
return anthropic_params | ||
|
||
def cost(self, response) -> float: | ||
"""Calculate the cost of the completion using the Anthropic pricing.""" | ||
return response.cost | ||
|
||
@property | ||
def api_key(self): | ||
return self._api_key | ||
|
||
def create(self, params: Dict[str, Any]) -> Completion: | ||
"""Create a completion for a given config. | ||
Args: | ||
params: The params for the completion. | ||
Returns: | ||
The completion. | ||
""" | ||
if "tools" in params: | ||
converted_functions = self.convert_tools_to_functions(params["tools"]) | ||
params["functions"] = params.get("functions", []) + converted_functions | ||
|
||
raw_contents = params["messages"] | ||
anthropic_params = self.load_config(params) | ||
|
||
processed_messages = [] | ||
for message in raw_contents: | ||
|
||
if message["role"] == "system": | ||
params["system"] = message["content"] | ||
elif message["role"] == "function": | ||
processed_messages.append(self.return_function_call_result(message["content"])) | ||
elif "function_call" in message: | ||
processed_messages.append(self.restore_last_tooluse_status()) | ||
elif message["content"] == "": | ||
message["content"] = "I'm done. Please send TERMINATE" # Not sure about this one. | ||
processed_messages.append(message) | ||
else: | ||
processed_messages.append(message) | ||
|
||
# Check for interleaving roles and correct, for Anthropic must be: user, assistant, user, etc. | ||
for i, message in enumerate(processed_messages): | ||
if message["role"] is not ("user" if i % 2 == 0 else "assistant"): | ||
message["role"] = "user" if i % 2 == 0 else "assistant" | ||
|
||
# Note: When using reflection_with_llm we may end up with an "assistant" message as the last message | ||
if processed_messages[-1]["role"] != "user": | ||
# If the last role is not user, add a continue message at the end | ||
continue_message = {"content": "continue", "role": "user"} | ||
processed_messages.append(continue_message) | ||
|
||
params["messages"] = processed_messages | ||
|
||
# TODO: support stream | ||
params = params.copy() | ||
if "functions" in params: | ||
tools_configs = params.pop("functions") | ||
tools_configs = [self.openai_func_to_anthropic(tool) for tool in tools_configs] | ||
params["tools"] = tools_configs | ||
|
||
# Anthropic doesn't accept None values, so we need to use keyword argument unpacking instead of setting parameters. | ||
# Copy params we need into anthropic_params | ||
# Remove any that don't have values | ||
anthropic_params["messages"] = params["messages"] | ||
if "system" in params: | ||
anthropic_params["system"] = params["system"] | ||
if "tools" in params: | ||
anthropic_params["tools"] = params["tools"] | ||
if anthropic_params["top_k"] is None: | ||
del anthropic_params["top_k"] | ||
if anthropic_params["top_p"] is None: | ||
del anthropic_params["top_p"] | ||
if anthropic_params["stop_sequences"] is None: | ||
del anthropic_params["stop_sequences"] | ||
|
||
response = self._client.messages.create(**anthropic_params) | ||
|
||
# Calculate and save the cost onto the response | ||
prompt_tokens = response.usage.input_tokens | ||
completion_tokens = response.usage.output_tokens | ||
response.cost = _calculate_cost(prompt_tokens, completion_tokens, anthropic_params["model"]) | ||
|
||
return response | ||
|
||
def message_retrieval(self, response: Union[Message]) -> Union[List[str], List[ChatCompletionMessage]]: | ||
"""Retrieve the messages from the response.""" | ||
messages = response.content | ||
if len(messages) == 0: | ||
return [None] | ||
res = [] | ||
if TOOL_ENABLED: | ||
for choice in messages: | ||
if choice.type == "tool_use": | ||
res.insert(0, self.response_to_openai_message(choice)) | ||
self._last_tooluse_status["tool_use"] = choice.model_dump() | ||
else: | ||
res.append(choice.text) | ||
self._last_tooluse_status["think"] = choice.text | ||
|
||
return res | ||
|
||
else: | ||
return [ # type: ignore [return-value] | ||
choice.text if choice.message.function_call is not None else choice.message.content # type: ignore [union-attr] | ||
for choice in messages | ||
] | ||
|
||
def response_to_openai_message(self, response) -> ChatCompletionMessage: | ||
"""Convert the client response to OpenAI ChatCompletion Message""" | ||
dict_response = response.model_dump() | ||
return ChatCompletionMessage( | ||
content=None, | ||
role="assistant", | ||
function_call={"name": dict_response["name"], "arguments": json.dumps(dict_response["input"])}, | ||
) | ||
|
||
def restore_last_tooluse_status(self) -> Dict: | ||
cached_content = [] | ||
if "think" in self._last_tooluse_status: | ||
cached_content.append({"type": "text", "text": self._last_tooluse_status["think"]}) | ||
cached_content.append(self._last_tooluse_status["tool_use"]) | ||
res = {"role": "assistant", "content": cached_content} | ||
return res | ||
|
||
def return_function_call_result(self, result: str) -> Dict: | ||
return { | ||
"role": "user", | ||
"content": [ | ||
{ | ||
"type": "tool_result", | ||
"tool_use_id": self._last_tooluse_status["tool_use"]["id"], | ||
"content": result, | ||
} | ||
], | ||
} | ||
|
||
@staticmethod | ||
def openai_func_to_anthropic(openai_func: dict) -> dict: | ||
res = openai_func.copy() | ||
res["input_schema"] = res.pop("parameters") | ||
return res | ||
|
||
@staticmethod | ||
def get_usage(response: Message) -> Dict: | ||
"""Get the usage of tokens and their cost information.""" | ||
return { | ||
"prompt_tokens": response.usage.input_tokens if response.usage is not None else 0, | ||
"completion_tokens": response.usage.output_tokens if response.usage is not None else 0, | ||
"total_tokens": ( | ||
response.usage.input_tokens + response.usage.output_tokens if response.usage is not None else 0 | ||
), | ||
"cost": response.cost if hasattr(response, "cost") else 0.0, | ||
"model": response.model, | ||
} | ||
|
||
@staticmethod | ||
def convert_tools_to_functions(tools: List) -> List: | ||
functions = [] | ||
for tool in tools: | ||
if tool.get("type") == "function" and "function" in tool: | ||
functions.append(tool["function"]) | ||
|
||
return functions | ||
|
||
|
||
def _calculate_cost(input_tokens: int, output_tokens: int, model: str) -> float: | ||
"""Calculate the cost of the completion using the Anthropic pricing.""" | ||
total = 0.0 | ||
|
||
if model in ANTHROPIC_PRICING_1k: | ||
input_cost_per_1k, output_cost_per_1k = ANTHROPIC_PRICING_1k[model] | ||
input_cost = (input_tokens / 1000) * input_cost_per_1k | ||
output_cost = (output_tokens / 1000) * output_cost_per_1k | ||
total = input_cost + output_cost | ||
else: | ||
warnings.warn(f"Cost calculation not available for model {model}", UserWarning) | ||
|
||
return total |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.