diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml index 7aad6ebbf067..3abe257dfad6 100644 --- a/.github/workflows/contrib-tests.yml +++ b/.github/workflows/contrib-tests.yml @@ -612,3 +612,43 @@ jobs: with: file: ./coverage.xml flags: unittests + + BedrockTest: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-2019] + python-version: ["3.9", "3.10", "3.11", "3.12"] + exclude: + - os: macos-latest + python-version: "3.9" + steps: + - uses: actions/checkout@v4 + with: + lfs: true + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install packages and dependencies for all tests + run: | + python -m pip install --upgrade pip wheel + pip install pytest-cov>=5 + - name: Install packages and dependencies for Amazon Bedrock + run: | + pip install -e .[boto3,test] + - name: Set AUTOGEN_USE_DOCKER based on OS + shell: bash + run: | + if [[ ${{ matrix.os }} != ubuntu-latest ]]; then + echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV + fi + - name: Coverage + run: | + pytest test/oai/test_bedrock.py --skip-openai + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + file: ./coverage.xml + flags: unittests diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py index 61a8a6335284..37bbbd25a523 100644 --- a/autogen/logger/file_logger.py +++ b/autogen/logger/file_logger.py @@ -18,6 +18,7 @@ if TYPE_CHECKING: from autogen import Agent, ConversableAgent, OpenAIWrapper from autogen.oai.anthropic import AnthropicClient + from autogen.oai.bedrock import BedrockClient from autogen.oai.cohere import CohereClient from autogen.oai.gemini import GeminiClient from autogen.oai.groq import GroqClient @@ -215,6 +216,7 @@ def log_new_client( | TogetherClient | GroqClient | CohereClient + | BedrockClient ), wrapper: OpenAIWrapper, init_args: Dict[str, Any], diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py index 2cf176ebb8f2..f76d039ce9de 100644 --- a/autogen/logger/sqlite_logger.py +++ b/autogen/logger/sqlite_logger.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: from autogen import Agent, ConversableAgent, OpenAIWrapper from autogen.oai.anthropic import AnthropicClient + from autogen.oai.bedrock import BedrockClient from autogen.oai.cohere import CohereClient from autogen.oai.gemini import GeminiClient from autogen.oai.groq import GroqClient @@ -402,6 +403,7 @@ def log_new_client( TogetherClient, GroqClient, CohereClient, + BedrockClient, ], wrapper: OpenAIWrapper, init_args: Dict[str, Any], diff --git a/autogen/oai/bedrock.py b/autogen/oai/bedrock.py new file mode 100644 index 000000000000..7894781e3ee5 --- /dev/null +++ b/autogen/oai/bedrock.py @@ -0,0 +1,606 @@ +""" +Create a compatible client for the Amazon Bedrock Converse API. + +Example usage: +Install the `boto3` package by running `pip install --upgrade boto3`. 
+- https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html + +import autogen + +config_list = [ + { + "api_type": "bedrock", + "model": "meta.llama3-1-8b-instruct-v1:0", + "aws_region": "us-west-2", + "aws_access_key": "", + "aws_secret_key": "", + "price" : [0.003, 0.015] + } +] + +assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list}) + +""" + +from __future__ import annotations + +import base64 +import json +import os +import re +import time +import warnings +from typing import Any, Dict, List, Literal, Tuple + +import boto3 +import requests +from botocore.config import Config +from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall +from openai.types.chat.chat_completion import ChatCompletionMessage, Choice +from openai.types.completion_usage import CompletionUsage + +from autogen.oai.client_utils import validate_parameter + + +class BedrockClient: + """Client for Amazon's Bedrock Converse API.""" + + _retries = 5 + + def __init__(self, **kwargs: Any): + """ + Initialises BedrockClient for Amazon's Bedrock Converse API + """ + self._aws_access_key = kwargs.get("aws_access_key", None) + self._aws_secret_key = kwargs.get("aws_secret_key", None) + self._aws_session_token = kwargs.get("aws_session_token", None) + self._aws_region = kwargs.get("aws_region", None) + self._aws_profile_name = kwargs.get("aws_profile_name", None) + + if not self._aws_access_key: + self._aws_access_key = os.getenv("AWS_ACCESS_KEY") + + if not self._aws_secret_key: + self._aws_secret_key = os.getenv("AWS_SECRET_KEY") + + if not self._aws_session_token: + self._aws_session_token = os.getenv("AWS_SESSION_TOKEN") + + if not self._aws_region: + self._aws_region = os.getenv("AWS_REGION") + + if self._aws_region is None: + raise ValueError("Region is required to use the Amazon Bedrock API.") + + # Initialize Bedrock client, session, and runtime + bedrock_config = Config( + region_name=self._aws_region, + signature_version="v4", + retries={"max_attempts": self._retries, "mode": "standard"}, + ) + + session = boto3.Session( + aws_access_key_id=self._aws_access_key, + aws_secret_access_key=self._aws_secret_key, + aws_session_token=self._aws_session_token, + profile_name=self._aws_profile_name, + ) + + self.bedrock_runtime = session.client(service_name="bedrock-runtime", config=bedrock_config) + + def message_retrieval(self, response): + """Retrieve the messages from the response.""" + return [choice.message for choice in response.choices] + + def parse_custom_params(self, params: Dict[str, Any]): + """ + Parses custom parameters for logic in this client class + """ + + # Should we separate system messages into its own request parameter, default is True + # This is required because not all models support a system prompt (e.g. Mistral Instruct). + self._supports_system_prompts = params.get("supports_system_prompts", True) + + def parse_params(self, params: Dict[str, Any]) -> tuple[Dict[str, Any], Dict[str, Any]]: + """ + Loads the valid parameters required to invoke Bedrock Converse + Returns a tuple of (base_params, additional_params) + """ + + base_params = {} + additional_params = {} + + # Amazon Bedrock base model IDs are here: + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html + self._model_id = params.get("model", None) + assert self._model_id, "Please provide the 'model` in the config_list to use Amazon Bedrock" + + # Parameters vary based on the model used. 
+ # As we won't cater for all models and parameters, it's the developer's + # responsibility to implement the parameters and they will only be + # included if the developer has it in the config. + # + # Important: + # No defaults will be used (as they can vary per model) + # No ranges will be used (as they can vary) + # We will cover all the main parameters but there may be others + # that need to be added later + # + # Here are some pages that show the parameters available for different models + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-text-completion.html + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral-chat-completion.html + + # Here are the possible "base" parameters and their suitable types + base_parameters = [["temperature", (float, int)], ["topP", (float, int)], ["maxTokens", (int)]] + + for param_name, suitable_types in base_parameters: + if param_name in params: + base_params[param_name] = validate_parameter( + params, param_name, suitable_types, False, None, None, None + ) + + # Here are the possible "model-specific" parameters and their suitable types, known as additional parameters + additional_parameters = [ + ["top_p", (float, int)], + ["top_k", (int)], + ["k", (int)], + ["seed", (int)], + ] + + for param_name, suitable_types in additional_parameters: + if param_name in params: + additional_params[param_name] = validate_parameter( + params, param_name, suitable_types, False, None, None, None + ) + + # Streaming + if "stream" in params: + self._streaming = params["stream"] + else: + self._streaming = False + + # For this release we will not support streaming as many models do not support streaming with tool use + if self._streaming: + warnings.warn( + "Streaming is not currently supported, streaming will be disabled.", + UserWarning, + ) + self._streaming = False + + return base_params, additional_params + + def create(self, params): + """Run Amazon Bedrock inference and return AutoGen response""" + + # Set custom client class settings + self.parse_custom_params(params) + + # Parse the inference parameters + base_params, additional_params = self.parse_params(params) + + has_tools = "tools" in params + messages = oai_messages_to_bedrock_messages(params["messages"], has_tools, self._supports_system_prompts) + + if self._supports_system_prompts: + system_messages = extract_system_messages(params["messages"]) + + tool_config = format_tools(params["tools"] if has_tools else []) + + request_args = {"messages": messages, "modelId": self._model_id} + + # Base and additional args + if len(base_params) > 0: + request_args["inferenceConfig"] = base_params + + if len(additional_params) > 0: + request_args["additionalModelRequestFields"] = additional_params + + if self._supports_system_prompts: + request_args["system"] = system_messages + + if len(tool_config["tools"]) > 0: + request_args["toolConfig"] = tool_config + + try: + response = self.bedrock_runtime.converse( + **request_args, + ) + except Exception as e: + raise RuntimeError(f"Failed to get response from Bedrock: {e}") + + if response is None: + raise RuntimeError(f"Failed to get response from Bedrock after retrying {self._retries} times.") + + finish_reason = 
convert_stop_reason_to_finish_reason(response["stopReason"]) + response_message = response["output"]["message"] + + if finish_reason == "tool_calls": + tool_calls = format_tool_calls(response_message["content"]) + # text = "" + else: + tool_calls = None + + text = "" + for content in response_message["content"]: + if "text" in content: + text = content["text"] + # NOTE: other types of output may be dealt with here + + message = ChatCompletionMessage(role="assistant", content=text, tool_calls=tool_calls) + + response_usage = response["usage"] + usage = CompletionUsage( + prompt_tokens=response_usage["inputTokens"], + completion_tokens=response_usage["outputTokens"], + total_tokens=response_usage["totalTokens"], + ) + + return ChatCompletion( + id=response["ResponseMetadata"]["RequestId"], + choices=[Choice(finish_reason=finish_reason, index=0, message=message)], + created=int(time.time()), + model=self._model_id, + object="chat.completion", + usage=usage, + ) + + def cost(self, response: ChatCompletion) -> float: + """Calculate the cost of the response.""" + return calculate_cost(response.usage.prompt_tokens, response.usage.completion_tokens, response.model) + + @staticmethod + def get_usage(response) -> Dict: + """Get the usage of tokens and their cost information.""" + return { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens, + "cost": response.cost, + "model": response.model, + } + + +def extract_system_messages(messages: List[dict]) -> List: + """Extract the system messages from the list of messages. + + Args: + messages (list[dict]): List of messages. + + Returns: + List[SystemMessage]: List of System messages. + """ + + """ + system_messages = [message.get("content")[0]["text"] for message in messages if message.get("role") == "system"] + return system_messages # ''.join(system_messages) + """ + + for message in messages: + if message.get("role") == "system": + if isinstance(message["content"], str): + return [{"text": message.get("content")}] + else: + return [{"text": message.get("content")[0]["text"]}] + return [] + + +def oai_messages_to_bedrock_messages( + messages: List[Dict[str, Any]], has_tools: bool, supports_system_prompts: bool +) -> List[Dict]: + """ + Convert messages from OAI format to Bedrock format. + We correct for any specific role orders and types, etc. + AWS Bedrock requires messages to alternate between user and assistant roles. This function ensures that the messages + are in the correct order and format for Bedrock by inserting "Please continue" messages as needed. + This is the same method as the one in the Autogen Anthropic client + """ + + # Track whether we have tools passed in. If not, tool use / result messages should be converted to text messages. + # Bedrock requires a tools parameter with the tools listed, if there are other messages with tool use or tool results. + # This can occur when we don't need tool calling, such as for group chat speaker selection + + # Convert messages to Bedrock compliant format + + # Take out system messages if the model supports it, otherwise leave them in. 
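+ # (When supported, the extracted system text is passed separately to the Converse API via the `system` request parameter in BedrockClient.create(), rather than being kept in the message list.)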
+ if supports_system_prompts: + messages = [x for x in messages if not x["role"] == "system"] + else: + # Replace role="system" with role="user" + for msg in messages: + if msg["role"] == "system": + msg["role"] = "user" + + processed_messages = [] + + # Used to interweave user messages to ensure user/assistant alternating + user_continue_message = {"content": [{"text": "Please continue."}], "role": "user"} + assistant_continue_message = { + "content": [{"text": "Please continue."}], + "role": "assistant", + } + + tool_use_messages = 0 + tool_result_messages = 0 + last_tool_use_index = -1 + last_tool_result_index = -1 + # user_role_index = 0 if supports_system_prompts else 1 # If system prompts are supported, messages start with user, otherwise they'll be the second message + for message in messages: + # New messages will be added here, manage role alternations + expected_role = "user" if len(processed_messages) % 2 == 0 else "assistant" + + if "tool_calls" in message: + # Map the tool call options to Bedrock's format + tool_uses = [] + tool_names = [] + for tool_call in message["tool_calls"]: + tool_uses.append( + { + "toolUse": { + "toolUseId": tool_call["id"], + "name": tool_call["function"]["name"], + "input": json.loads(tool_call["function"]["arguments"]), + } + } + ) + if has_tools: + tool_use_messages += 1 + tool_names.append(tool_call["function"]["name"]) + + if expected_role == "user": + # Insert an extra user message as we will append an assistant message + processed_messages.append(user_continue_message) + + if has_tools: + processed_messages.append({"role": "assistant", "content": tool_uses}) + last_tool_use_index = len(processed_messages) - 1 + else: + # Not using tools, so put in a plain text message + processed_messages.append( + { + "role": "assistant", + "content": [ + {"text": f"Some internal function(s) that could be used: [{', '.join(tool_names)}]"} + ], + } + ) + elif "tool_call_id" in message: + if has_tools: + # Map the tool usage call to tool_result for Bedrock + tool_result = { + "toolResult": { + "toolUseId": message["tool_call_id"], + "content": [{"text": message["content"]}], + } + } + + # If the previous message also had a tool_result, add it to that + # Otherwise append a new message + if last_tool_result_index == len(processed_messages) - 1: + processed_messages[-1]["content"].append(tool_result) + else: + if expected_role == "assistant": + # Insert an extra assistant message as we will append a user message + processed_messages.append(assistant_continue_message) + + processed_messages.append({"role": "user", "content": [tool_result]}) + last_tool_result_index = len(processed_messages) - 1 + + tool_result_messages += 1 + else: + # Not using tools, so put in a plain text message + processed_messages.append( + { + "role": "user", + "content": [{"text": f"Running the function returned: {message['content']}"}], + } + ) + elif message["content"] == "": + # Ignoring empty messages + pass + else: + if expected_role != message["role"] and not (len(processed_messages) == 0 and message["role"] == "system"): + # Inserting the alternating continue message (ignore if it's the first message and a system message) + processed_messages.append( + user_continue_message if expected_role == "user" else assistant_continue_message + ) + + processed_messages.append( + { + "role": message["role"], + "content": parse_content_parts(message=message), + } + ) + + # We'll replace the last tool_use if there's no tool_result (occurs if we finish the conversation before running the function) 
+ if has_tools and tool_use_messages != tool_result_messages: + processed_messages[last_tool_use_index] = assistant_continue_message + + # name is not a valid field on messages + for message in processed_messages: + if "name" in message: + message.pop("name", None) + + # Note: When using reflection_with_llm we may end up with an "assistant" message as the last message and that may cause a blank response + # So, if the last role is not user, add a 'user' continue message at the end + if processed_messages[-1]["role"] != "user": + processed_messages.append(user_continue_message) + + return processed_messages + + +def parse_content_parts( + message: Dict[str, Any], +) -> List[dict]: + content: str | List[Dict[str, Any]] = message.get("content") + if isinstance(content, str): + return [ + { + "text": content, + } + ] + content_parts = [] + for part in content: + # part_content: Dict = part.get("content") + if "text" in part: # part_content: + content_parts.append( + { + "text": part.get("text"), + } + ) + elif "image_url" in part: # part_content: + image_data, content_type = parse_image(part.get("image_url").get("url")) + content_parts.append( + { + "image": { + "format": content_type[6:], # image/ + "source": {"bytes": image_data}, + }, + } + ) + else: + # Ignore.. + continue + return content_parts + + +def parse_image(image_url: str) -> Tuple[bytes, str]: + """Try to get the raw data from an image url. + + Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageSource.html + returns a tuple of (Image Data, Content Type) + """ + pattern = r"^data:(image/[a-z]*);base64,\s*" + content_type = re.search(pattern, image_url) + # if already base64 encoded. + # Only supports 'image/jpeg', 'image/png', 'image/gif' or 'image/webp' + if content_type: + image_data = re.sub(pattern, "", image_url) + return base64.b64decode(image_data), content_type.group(1) + + # Send a request to the image URL + response = requests.get(image_url) + # Check if the request was successful + if response.status_code == 200: + + content_type = response.headers.get("Content-Type") + if not content_type.startswith("image"): + content_type = "image/jpeg" + # Get the image content + image_content = response.content + return image_content, content_type + else: + raise RuntimeError("Unable to access the image url") + + +def format_tools(tools: List[Dict[str, Any]]) -> Dict[Literal["tools"], List[Dict[str, Any]]]: + converted_schema = {"tools": []} + + for tool in tools: + if tool["type"] == "function": + function = tool["function"] + converted_tool = { + "toolSpec": { + "name": function["name"], + "description": function["description"], + "inputSchema": {"json": {"type": "object", "properties": {}, "required": []}}, + } + } + + for prop_name, prop_details in function["parameters"]["properties"].items(): + converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name] = { + "type": prop_details["type"], + "description": prop_details.get("description", ""), + } + if "enum" in prop_details: + converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name]["enum"] = prop_details[ + "enum" + ] + if "default" in prop_details: + converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name]["default"] = ( + prop_details["default"] + ) + + if "required" in function["parameters"]: + converted_tool["toolSpec"]["inputSchema"]["json"]["required"] = function["parameters"]["required"] + + converted_schema["tools"].append(converted_tool) + + return converted_schema + + +def 
format_tool_calls(content): + """Converts Converse API response tool calls to AutoGen format""" + tool_calls = [] + for tool_request in content: + if "toolUse" in tool_request: + tool = tool_request["toolUse"] + + tool_calls.append( + ChatCompletionMessageToolCall( + id=tool["toolUseId"], + function={ + "name": tool["name"], + "arguments": json.dumps(tool["input"]), + }, + type="function", + ) + ) + return tool_calls + + +def convert_stop_reason_to_finish_reason( + stop_reason: str, +) -> Literal["stop", "length", "tool_calls", "content_filter"]: + """ + Converts Bedrock finish reasons to our finish reasons, according to OpenAI: + + - stop: if the model hit a natural stop point or a provided stop sequence, + - length: if the maximum number of tokens specified in the request was reached, + - content_filter: if content was omitted due to a flag from our content filters, + - tool_calls: if the model called a tool + """ + if stop_reason: + finish_reason_mapping = { + "tool_use": "tool_calls", + "finished": "stop", + "end_turn": "stop", + "max_tokens": "length", + "stop_sequence": "stop", + "complete": "stop", + "content_filtered": "content_filter", + } + return finish_reason_mapping.get(stop_reason.lower(), stop_reason.lower()) + + warnings.warn(f"Unsupported stop reason: {stop_reason}", UserWarning) + return None + + +# NOTE: As this will be quite dynamic, it's expected that the developer will use the "price" parameter in their config +# These may be removed. +PRICES_PER_K_TOKENS = { + "meta.llama3-8b-instruct-v1:0": (0.0003, 0.0006), + "meta.llama3-70b-instruct-v1:0": (0.00265, 0.0035), + "mistral.mistral-7b-instruct-v0:2": (0.00015, 0.0002), + "mistral.mixtral-8x7b-instruct-v0:1": (0.00045, 0.0007), + "mistral.mistral-large-2402-v1:0": (0.004, 0.012), + "mistral.mistral-small-2402-v1:0": (0.001, 0.003), +} + + +def calculate_cost(input_tokens: int, output_tokens: int, model_id: str) -> float: + """Calculate the cost of the completion using the Bedrock pricing.""" + + if model_id in PRICES_PER_K_TOKENS: + input_cost_per_k, output_cost_per_k = PRICES_PER_K_TOKENS[model_id] + input_cost = (input_tokens / 1000) * input_cost_per_k + output_cost = (output_tokens / 1000) * output_cost_per_k + return input_cost + output_cost + else: + warnings.warn( + f'Cannot get the costs for {model_id}. The cost will be 0. In your config_list, add field {{"price" : [prompt_price_per_1k, completion_token_price_per_1k]}} for customized pricing.', + UserWarning, + ) + return 0 diff --git a/autogen/oai/client.py b/autogen/oai/client.py index fb13afdfcc63..3ae37257b21e 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -84,6 +84,13 @@ except ImportError as e: cohere_import_exception = e +try: + from autogen.oai.bedrock import BedrockClient + + bedrock_import_exception: Optional[ImportError] = None +except ImportError as e: + bedrock_import_exception = e + logger = logging.getLogger(__name__) if not logger.handlers: # Add the console handler. 
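As a quick worked illustration of how `calculate_cost` in `autogen/oai/bedrock.py` above combines the per-1K-token prices in `PRICES_PER_K_TOKENS` with the reported usage (the token counts below are made up for the example, not taken from the diff):

```python
# Hypothetical usage for "meta.llama3-8b-instruct-v1:0", priced above at
# (0.0003, 0.0006) USD per 1K input/output tokens.
input_tokens, output_tokens = 1200, 350

input_cost = (input_tokens / 1000) * 0.0003    # 0.00036
output_cost = (output_tokens / 1000) * 0.0006  # 0.00021
total_cost = input_cost + output_cost          # 0.00057 USD
```

For models not listed in `PRICES_PER_K_TOKENS`, `calculate_cost` warns and returns 0; as the warning text suggests, adding a `"price": [prompt_price_per_1k, completion_price_per_1k]` entry to the config is the intended way to get accurate cost tracking.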
@@ -457,7 +464,7 @@ def _configure_azure_openai(self, config: Dict[str, Any], openai_config: Dict[st def _configure_openai_config_for_bedrock(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None: """Update openai_config with AWS credentials from config.""" required_keys = ["aws_access_key", "aws_secret_key", "aws_region"] - optional_keys = ["aws_session_token"] + optional_keys = ["aws_session_token", "aws_profile_name"] for key in required_keys: if key in config: openai_config[key] = config[key] @@ -519,9 +526,15 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s self._clients.append(client) elif api_type is not None and api_type.startswith("cohere"): if cohere_import_exception: - raise ImportError("Please install `cohere` to use the Groq API.") + raise ImportError("Please install `cohere` to use the Cohere API.") client = CohereClient(**openai_config) self._clients.append(client) + elif api_type is not None and api_type.startswith("bedrock"): + self._configure_openai_config_for_bedrock(config, openai_config) + if bedrock_import_exception: + raise ImportError("Please install `boto3` to use the Amazon Bedrock API.") + client = BedrockClient(**openai_config) + self._clients.append(client) else: client = OpenAI(**openai_config) self._clients.append(OpenAIClient(client)) diff --git a/autogen/runtime_logging.py b/autogen/runtime_logging.py index 1ffc8b622f0a..0fd7cc2fc8b9 100644 --- a/autogen/runtime_logging.py +++ b/autogen/runtime_logging.py @@ -14,6 +14,7 @@ if TYPE_CHECKING: from autogen import Agent, ConversableAgent, OpenAIWrapper from autogen.oai.anthropic import AnthropicClient + from autogen.oai.bedrock import BedrockClient from autogen.oai.cohere import CohereClient from autogen.oai.gemini import GeminiClient from autogen.oai.groq import GroqClient @@ -113,7 +114,15 @@ def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig def log_new_client( client: Union[ - AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient, CohereClient + AzureOpenAI, + OpenAI, + GeminiClient, + AnthropicClient, + MistralAIClient, + TogetherClient, + GroqClient, + CohereClient, + BedrockClient, ], wrapper: OpenAIWrapper, init_args: Dict[str, Any], diff --git a/setup.py b/setup.py index 95b2cda212ae..bd75dab16b89 100644 --- a/setup.py +++ b/setup.py @@ -91,6 +91,7 @@ "mistral": ["mistralai>=1.0.1"], "groq": ["groq>=0.9.0"], "cohere": ["cohere>=5.5.8"], + "bedrock": ["boto3>=1.34.149"], } setuptools.setup( diff --git a/test/oai/test_bedrock.py b/test/oai/test_bedrock.py new file mode 100644 index 000000000000..42502acf691c --- /dev/null +++ b/test/oai/test_bedrock.py @@ -0,0 +1,294 @@ +from unittest.mock import MagicMock, patch + +import pytest + +try: + from autogen.oai.bedrock import BedrockClient, oai_messages_to_bedrock_messages + + skip = False +except ImportError: + BedrockClient = object + InternalServerError = object + skip = True + + +# Fixtures for mock data +@pytest.fixture +def mock_response(): + class MockResponse: + def __init__(self, text, choices, usage, cost, model): + self.text = text + self.choices = choices + self.usage = usage + self.cost = cost + self.model = model + + return MockResponse + + +@pytest.fixture +def bedrock_client(): + + # Set Bedrock client with some default values + client = BedrockClient() + + client._supports_system_prompts = True + + return client + + +skip_reason = "Amazon Bedrock dependency is not installed" + + +# Test initialization and configuration 
+@pytest.mark.skipif(skip, reason=skip_reason) +def test_initialization(): + + # Creation works without an api_key as it's handled in the parameter parsing + BedrockClient() + + +# Test parameters +@pytest.mark.skipif(skip, reason=skip_reason) +def test_parsing_params(bedrock_client): + # All parameters (with default values) + params = { + # "aws_region_name": "us-east-1", + # "aws_access_key_id": "test_access_key_id", + # "aws_secret_access_key": "test_secret_access_key", + # "aws_session_token": "test_session_token", + # "aws_profile_name": "test_profile_name", + "model": "anthropic.claude-3-sonnet-20240229-v1:0", + "temperature": 0.8, + "topP": 0.6, + "maxTokens": 250, + "seed": 42, + "stream": False, + } + expected_base_params = { + "temperature": 0.8, + "topP": 0.6, + "maxTokens": 250, + } + expected_additional_params = { + "seed": 42, + } + base_result, additional_result = bedrock_client.parse_params(params) + assert base_result == expected_base_params + assert additional_result == expected_additional_params + + # Incorrect types, defaults should be set, will show warnings but not trigger assertions + params = { + "model": "anthropic.claude-3-sonnet-20240229-v1:0", + "temperature": "0.5", + "topP": "0.6", + "maxTokens": "250", + "seed": "42", + "stream": "False", + } + expected_base_params = { + "temperature": None, + "topP": None, + "maxTokens": None, + } + expected_additional_params = { + "seed": None, + } + base_result, additional_result = bedrock_client.parse_params(params) + assert base_result == expected_base_params + assert additional_result == expected_additional_params + + # Only model, others set as defaults if they are mandatory + params = { + "model": "anthropic.claude-3-sonnet-20240229-v1:0", + } + expected_base_params = {} + expected_additional_params = {} + base_result, additional_result = bedrock_client.parse_params(params) + assert base_result == expected_base_params + assert additional_result == expected_additional_params + + # No model + params = { + "temperature": 0.8, + } + + with pytest.raises(AssertionError) as assertinfo: + bedrock_client.parse_params(params) + + assert "Please provide the 'model` in the config_list to use Amazon Bedrock" in str(assertinfo.value) + + +# Test text generation +@pytest.mark.skipif(skip, reason=skip_reason) +@patch("autogen.oai.bedrock.BedrockClient.create") +def test_create_response(mock_chat, bedrock_client): + # Mock BedrockClient.chat response + mock_bedrock_response = MagicMock() + mock_bedrock_response.choices = [ + MagicMock(finish_reason="stop", message=MagicMock(content="Example Bedrock response", tool_calls=None)) + ] + mock_bedrock_response.id = "mock_bedrock_response_id" + mock_bedrock_response.model = "anthropic.claude-3-sonnet-20240229-v1:0" + mock_bedrock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20) # Example token usage + + mock_chat.return_value = mock_bedrock_response + + # Test parameters + params = { + "messages": [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "World"}], + "model": "anthropic.claude-3-sonnet-20240229-v1:0", + } + + # Call the create method + response = bedrock_client.create(params) + + # Assertions to check if response is structured as expected + assert ( + response.choices[0].message.content == "Example Bedrock response" + ), "Response content should match expected output" + assert response.id == "mock_bedrock_response_id", "Response ID should match the mocked response ID" + assert ( + response.model == "anthropic.claude-3-sonnet-20240229-v1:0" + 
), "Response model should match the mocked response model" + assert response.usage.prompt_tokens == 10, "Response prompt tokens should match the mocked response usage" + assert response.usage.completion_tokens == 20, "Response completion tokens should match the mocked response usage" + + +# Test functions/tools +@pytest.mark.skipif(skip, reason=skip_reason) +@patch("autogen.oai.bedrock.BedrockClient.create") +def test_create_response_with_tool_call(mock_chat, bedrock_client): + # Mock BedrockClient.chat response + mock_function = MagicMock(name="currency_calculator") + mock_function.name = "currency_calculator" + mock_function.arguments = '{"base_currency": "EUR", "quote_currency": "USD", "base_amount": 123.45}' + + mock_function_2 = MagicMock(name="get_weather") + mock_function_2.name = "get_weather" + mock_function_2.arguments = '{"location": "New York"}' + + mock_chat.return_value = MagicMock( + choices=[ + MagicMock( + finish_reason="tool_calls", + message=MagicMock( + content="Sample text about the functions", + tool_calls=[ + MagicMock(id="bd65600d-8669-4903-8a14-af88203add38", function=mock_function), + MagicMock(id="f50ec0b7-f960-400d-91f0-c42a6d44e3d0", function=mock_function_2), + ], + ), + ) + ], + id="mock_bedrock_response_id", + model="anthropic.claude-3-sonnet-20240229-v1:0", + usage=MagicMock(prompt_tokens=10, completion_tokens=20), + ) + + # Construct parameters + converted_functions = [ + { + "type": "function", + "function": { + "description": "Currency exchange calculator.", + "name": "currency_calculator", + "parameters": { + "type": "object", + "properties": { + "base_amount": {"type": "number", "description": "Amount of currency in base_currency"}, + }, + "required": ["base_amount"], + }, + }, + } + ] + bedrock_messages = [ + {"role": "user", "content": "How much is 123.45 EUR in USD?"}, + {"role": "assistant", "content": "World"}, + ] + + # Call the create method + response = bedrock_client.create( + {"messages": bedrock_messages, "tools": converted_functions, "model": "anthropic.claude-3-sonnet-20240229-v1:0"} + ) + + # Assertions to check if the functions and content are included in the response + assert response.choices[0].message.content == "Sample text about the functions" + assert response.choices[0].message.tool_calls[0].function.name == "currency_calculator" + assert response.choices[0].message.tool_calls[1].function.name == "get_weather" + + +# Test message conversion from OpenAI to Bedrock format +@pytest.mark.skipif(skip, reason=skip_reason) +def test_oai_messages_to_bedrock_messages(bedrock_client): + + # Test that the "name" key is removed and system messages converted to user message + test_messages = [ + {"role": "system", "content": "You are a helpful AI bot."}, + {"role": "user", "name": "anne", "content": "Why is the sky blue?"}, + ] + messages = oai_messages_to_bedrock_messages(test_messages, False, False) + + expected_messages = [ + {"role": "user", "content": [{"text": "You are a helpful AI bot."}]}, + {"role": "assistant", "content": [{"text": "Please continue."}]}, + {"role": "user", "content": [{"text": "Why is the sky blue?"}]}, + ] + + assert messages == expected_messages, "'name' was not removed from messages (system message should be user message)" + + # Test that the "name" key is removed and system messages are extracted (as they will be put in separately) + test_messages = [ + {"role": "system", "content": "You are a helpful AI bot."}, + {"role": "user", "name": "anne", "content": "Why is the sky blue?"}, + ] + messages = 
oai_messages_to_bedrock_messages(test_messages, False, True) + + expected_messages = [ + {"role": "user", "content": [{"text": "Why is the sky blue?"}]}, + ] + + assert messages == expected_messages, "'name' was not removed from messages (system messages excluded)" + + # Test that the system message is converted to user and that a continue message is inserted + test_messages = [ + {"role": "system", "content": "You are a helpful AI bot."}, + {"role": "user", "name": "anne", "content": "Why is the sky blue?"}, + {"role": "system", "content": "Summarise the conversation."}, + ] + + messages = oai_messages_to_bedrock_messages(test_messages, False, False) + + expected_messages = [ + {"role": "user", "content": [{"text": "You are a helpful AI bot."}]}, + {"role": "assistant", "content": [{"text": "Please continue."}]}, + {"role": "user", "content": [{"text": "Why is the sky blue?"}]}, + {"role": "assistant", "content": [{"text": "Please continue."}]}, + {"role": "user", "content": [{"text": "Summarise the conversation."}]}, + ] + + assert ( + messages == expected_messages + ), "Final 'system' message was not changed to 'user' or continue messages not included" + + # Test that the last message is a user or system message and if not, add a continue message + test_messages = [ + {"role": "system", "content": "You are a helpful AI bot."}, + {"role": "user", "name": "anne", "content": "Why is the sky blue?"}, + {"role": "assistant", "content": "The sky is blue because that's a great colour."}, + ] + print(test_messages) + + messages = oai_messages_to_bedrock_messages(test_messages, False, False) + print(messages) + + expected_messages = [ + {"role": "user", "content": [{"text": "You are a helpful AI bot."}]}, + {"role": "assistant", "content": [{"text": "Please continue."}]}, + {"role": "user", "content": [{"text": "Why is the sky blue?"}]}, + {"role": "assistant", "content": [{"text": "The sky is blue because that's a great colour."}]}, + {"role": "user", "content": [{"text": "Please continue."}]}, + ] + + assert messages == expected_messages, "'Please continue' message was not appended." 
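To make the tool-schema conversion concrete, here is a minimal sketch (not part of the diff) of what `format_tools` in `autogen/oai/bedrock.py` above produces for the `currency_calculator` schema used in the tests:

```python
from autogen.oai.bedrock import format_tools

# OpenAI-style tool schema (same shape as `converted_functions` in the test above)
oai_tools = [
    {
        "type": "function",
        "function": {
            "description": "Currency exchange calculator.",
            "name": "currency_calculator",
            "parameters": {
                "type": "object",
                "properties": {
                    "base_amount": {"type": "number", "description": "Amount of currency in base_currency"},
                },
                "required": ["base_amount"],
            },
        },
    }
]

tool_config = format_tools(oai_tools)
# Per format_tools above, tool_config should be:
# {"tools": [{"toolSpec": {
#     "name": "currency_calculator",
#     "description": "Currency exchange calculator.",
#     "inputSchema": {"json": {
#         "type": "object",
#         "properties": {"base_amount": {"type": "number",
#                                        "description": "Amount of currency in base_currency"}},
#         "required": ["base_amount"]}}}}]}
```

This `tool_config` is what `create()` attaches to the request as `toolConfig` whenever the incoming params include tools.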
diff --git a/website/docs/topics/non-openai-models/cloud-anthropic.ipynb b/website/docs/topics/non-openai-models/cloud-anthropic.ipynb index c5b757f8288b..a6c87b6a5ca5 100644 --- a/website/docs/topics/non-openai-models/cloud-anthropic.ipynb +++ b/website/docs/topics/non-openai-models/cloud-anthropic.ipynb @@ -21,7 +21,7 @@ "Additionally, this client class provides support for function/tool calling and will track token usage and cost correctly as per Anthropic's API costs (as of June 2024).\n", "\n", "## Requirements\n", - "To use Anthropic Claude with AutoGen, first you need to install the `pyautogen[\"anthropic]` package.\n", + "To use Anthropic Claude with AutoGen, first you need to install the `pyautogen[anthropic]` package.\n", "\n", "To try out the function call feature of Claude model, you need to install `anthropic>=0.23.1`.\n" ] @@ -32,7 +32,6 @@ "metadata": {}, "outputs": [], "source": [ - "# !pip install pyautogen\n", "!pip install pyautogen[\"anthropic\"]" ] }, diff --git a/website/docs/topics/non-openai-models/cloud-bedrock.ipynb b/website/docs/topics/non-openai-models/cloud-bedrock.ipynb new file mode 100644 index 000000000000..71c1e2e7ffe3 --- /dev/null +++ b/website/docs/topics/non-openai-models/cloud-bedrock.ipynb @@ -0,0 +1,1298 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "# Amazon Bedrock\n", + "\n", + "AutoGen allows you to use Amazon's generative AI Bedrock service to run inference with a number of open-weight models and as well as their own models.\n", + "\n", + "Amazon Bedrock supports models from providers such as Meta, Anthropic, Cohere, and Mistral.\n", + "\n", + "In this notebook, we demonstrate how to use Anthropic's Sonnet model for AgentChat in AutoGen.\n", + "\n", + "## Model features / support\n", + "\n", + "Amazon Bedrock supports a wide range of models, not only for text generation but also for image classification and generation. Not all features are supported by AutoGen or by the Converse API used. Please see [Amazon's documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features) on the features supported by the Converse API.\n", + "\n", + "At this point in time AutoGen supports text generation and image classification (passing images to the LLM).\n", + "\n", + "It does not, yet, support image generation ([contribute](https://microsoft.github.io/autogen/docs/contributor-guide/contributing/)).\n", + "\n", + "## Requirements\n", + "To use Amazon Bedrock with AutoGen, first you need to install the `pyautogen[bedrock]` package.\n", + "\n", + "## Pricing\n", + "\n", + "When we combine the number of models supported and costs being on a per-region basis, it's not feasible to maintain the costs for each model+region combination within the AutoGen implementation. Therefore, it's recommended that you add the following to your config with cost per 1,000 input and output tokens, respectively:\n", + "```\n", + "{\n", + " ...\n", + " \"price\": [0.003, 0.015]\n", + " ...\n", + "}\n", + "```\n", + "\n", + "Amazon Bedrock pricing is available [here](https://aws.amazon.com/bedrock/pricing/)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# If you need to install AutoGen with Amazon Bedrock\n", + "!pip install pyautogen[\"bedrock\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set the config for Amazon Bedrock\n", + "\n", + "Amazon's Bedrock does not use the `api_key` as per other cloud inference providers for authentication, instead it uses a number of access, token, and profile values. These fields will need to be added to your client configuration. Please check the Amazon Bedrock documentation to determine which ones you will need to add.\n", + "\n", + "The available parameters are:\n", + "\n", + "- aws_region (mandatory)\n", + "- aws_access_key (or environment variable: AWS_ACCESS_KEY)\n", + "- aws_secret_key (or environment variable: AWS_SECRET_KEY)\n", + "- aws_session_token (or environment variable: AWS_SESSION_TOKEN)\n", + "- aws_profile_name\n", + "\n", + "Beyond the authentication credentials, the only mandatory parameters are `api_type` and `model`.\n", + "\n", + "The following parameters are common across all models used:\n", + "\n", + "- temperature\n", + "- topP\n", + "- maxTokens\n", + "\n", + "You can also include parameters specific to the model you are using (see the model detail within Amazon's documentation for more information), the four supported additional parameters are:\n", + "\n", + "- top_p\n", + "- top_k\n", + "- k\n", + "- seed\n", + "\n", + "An additional parameter can be added that denotes whether the model supports a system prompt (which is where the system messages are not included in the message list, but in a separate parameter). This defaults to `True`, so set it to `False` if your model (for example Mistral's Instruct models) [doesn't support this feature](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features):\n", + "\n", + "- supports_system_prompts\n", + "\n", + "It is important to add the `api_type` field and set it to a string that corresponds to the client type used: `bedrock`.\n", + "\n", + "Example:\n", + "```\n", + "[\n", + " {\n", + " \"api_type\": \"bedrock\",\n", + " \"model\": \"amazon.titan-text-premier-v1:0\",\n", + " \"aws_region\": \"us-east-1\"\n", + " \"aws_access_key\": \"\",\n", + " \"aws_secret_key\": \"\",\n", + " \"aws_session_token\": \"\",\n", + " \"aws_profile_name\": \"\",\n", + " },\n", + " {\n", + " \"api_type\": \"bedrock\",\n", + " \"model\": \"anthropic.claude-3-sonnet-20240229-v1:0\",\n", + " \"aws_region\": \"us-east-1\"\n", + " \"aws_access_key\": \"\",\n", + " \"aws_secret_key\": \"\",\n", + " \"aws_session_token\": \"\",\n", + " \"aws_profile_name\": \"\",\n", + " \"temperature\": 0.5,\n", + " \"topP\": 0.2,\n", + " \"maxTokens\": 250,\n", + " },\n", + " {\n", + " \"api_type\": \"bedrock\",\n", + " \"model\": \"mistral.mixtral-8x7b-instruct-v0:1\",\n", + " \"aws_region\": \"us-east-1\"\n", + " \"aws_access_key\": \"\",\n", + " \"aws_secret_key\": \"\",\n", + " \"supports_system_prompts\": False, # Mistral Instruct models don't support a separate system prompt\n", + " \"price\": [0.00045, 0.0007] # Specific pricing for this model/region\n", + " }\n", + "]\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Two-agent Coding Example" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Configuration\n", + "\n", + "Start with our configuration - we'll use Anthropic's 
Sonnet model and put in recent pricing. Additionally, we'll reduce the temperature to 0.1 so its responses are less varied." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "from typing_extensions import Annotated\n", + "\n", + "import autogen\n", + "\n", + "config_list_bedrock = [\n", + " {\n", + " \"api_type\": \"bedrock\",\n", + " \"model\": \"anthropic.claude-3-sonnet-20240229-v1:0\",\n", + " \"aws_region\": \"us-east-1\",\n", + " \"aws_access_key\": \"[FILL THIS IN]\",\n", + " \"aws_secret_key\": \"[FILL THIS IN]\",\n", + " \"price\": [0.003, 0.015],\n", + " \"temperature\": 0.1,\n", + " \"cache_seed\": None, # turn off caching\n", + " }\n", + "]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Construct Agents\n", + "\n", + "Construct a simple conversation between a User proxy and an ConversableAgent, which uses the Sonnet model.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "assistant = autogen.AssistantAgent(\n", + " \"assistant\",\n", + " llm_config={\n", + " \"config_list\": config_list_bedrock,\n", + " },\n", + ")\n", + "\n", + "user_proxy = autogen.UserProxyAgent(\n", + " \"user_proxy\",\n", + " human_input_mode=\"NEVER\",\n", + " code_execution_config={\n", + " \"work_dir\": \"coding\",\n", + " \"use_docker\": False,\n", + " },\n", + " is_termination_msg=lambda x: x.get(\"content\", \"\") and \"TERMINATE\" in x.get(\"content\", \"\"),\n", + " max_consecutive_auto_reply=1,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initiate Chat" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33muser_proxy\u001b[0m (to assistant):\n", + "\n", + "Write a python program to print the first 10 numbers of the Fibonacci sequence. 
Just output the python code, no additional information.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33massistant\u001b[0m (to user_proxy):\n", + "\n", + "```python\n", + "# Define a function to calculate Fibonacci sequence\n", + "def fibonacci(n):\n", + " if n <= 0:\n", + " return []\n", + " elif n == 1:\n", + " return [0]\n", + " elif n == 2:\n", + " return [0, 1]\n", + " else:\n", + " sequence = [0, 1]\n", + " for i in range(2, n):\n", + " sequence.append(sequence[i-1] + sequence[i-2])\n", + " return sequence\n", + "\n", + "# Call the function to get the first 10 Fibonacci numbers\n", + "fib_sequence = fibonacci(10)\n", + "print(fib_sequence)\n", + "```\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[31m\n", + ">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n", + "\u001b[33muser_proxy\u001b[0m (to assistant):\n", + "\n", + "exitcode: 0 (execution succeeded)\n", + "Code output: \n", + "[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]\n", + "\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33massistant\u001b[0m (to user_proxy):\n", + "\n", + "Great, the code executed successfully and printed the first 10 numbers of the Fibonacci sequence correctly.\n", + "\n", + "TERMINATE\n", + "\n", + "--------------------------------------------------------------------------------\n" + ] + }, + { + "data": { + "text/plain": [ + "ChatResult(chat_id=None, chat_history=[{'content': 'Write a python program to print the first 10 numbers of the Fibonacci sequence. Just output the python code, no additional information.', 'role': 'assistant'}, {'content': '```python\\n# Define a function to calculate Fibonacci sequence\\ndef fibonacci(n):\\n if n <= 0:\\n return []\\n elif n == 1:\\n return [0]\\n elif n == 2:\\n return [0, 1]\\n else:\\n sequence = [0, 1]\\n for i in range(2, n):\\n sequence.append(sequence[i-1] + sequence[i-2])\\n return sequence\\n\\n# Call the function to get the first 10 Fibonacci numbers\\nfib_sequence = fibonacci(10)\\nprint(fib_sequence)\\n```', 'role': 'user'}, {'content': 'exitcode: 0 (execution succeeded)\\nCode output: \\n[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]\\n', 'role': 'assistant'}, {'content': 'Great, the code executed successfully and printed the first 10 numbers of the Fibonacci sequence correctly.\\n\\nTERMINATE', 'role': 'user'}], summary='Great, the code executed successfully and printed the first 10 numbers of the Fibonacci sequence correctly.\\n\\n', cost={'usage_including_cached_inference': {'total_cost': 0.00624, 'anthropic.claude-3-sonnet-20240229-v1:0': {'cost': 0.00624, 'prompt_tokens': 1210, 'completion_tokens': 174, 'total_tokens': 1384}}, 'usage_excluding_cached_inference': {'total_cost': 0.00624, 'anthropic.claude-3-sonnet-20240229-v1:0': {'cost': 0.00624, 'prompt_tokens': 1210, 'completion_tokens': 174, 'total_tokens': 1384}}}, human_input=[])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "user_proxy.initiate_chat(\n", + " assistant,\n", + " message=\"Write a python program to print the first 10 numbers of the Fibonacci sequence. 
Just output the python code, no additional information.\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tool Call Example\n", + "\n", + "In this example, instead of writing code, we will show how we can perform multiple tool calling with Meta's Llama 3.1 70B model, where it recommends calling more than one tool at a time.\n", + "\n", + "We'll use a simple travel agent assistant program where we have a couple of tools for weather and currency conversion." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Agents" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "from typing import Literal\n", + "\n", + "import autogen\n", + "\n", + "config_list_bedrock = [\n", + " {\n", + " \"api_type\": \"bedrock\",\n", + " \"model\": \"meta.llama3-1-70b-instruct-v1:0\",\n", + " \"aws_region\": \"us-west-2\",\n", + " \"aws_access_key\": \"[FILL THIS IN]\",\n", + " \"aws_secret_key\": \"[FILL THIS IN]\",\n", + " \"price\": [0.00265, 0.0035],\n", + " \"cache_seed\": None, # turn off caching\n", + " }\n", + "]\n", + "\n", + "# Create the agent and include examples of the function calling JSON in the prompt\n", + "# to help guide the model\n", + "chatbot = autogen.AssistantAgent(\n", + " name=\"chatbot\",\n", + " system_message=\"\"\"For currency exchange and weather forecasting tasks,\n", + " only use the functions you have been provided with.\n", + " Output only the word 'TERMINATE' when an answer has been provided.\n", + " Use both tools together if you can.\"\"\",\n", + " llm_config={\n", + " \"config_list\": config_list_bedrock,\n", + " },\n", + ")\n", + "\n", + "user_proxy = autogen.UserProxyAgent(\n", + " name=\"user_proxy\",\n", + " is_termination_msg=lambda x: x.get(\"content\", \"\") and \"TERMINATE\" in x.get(\"content\", \"\"),\n", + " human_input_mode=\"NEVER\",\n", + " max_consecutive_auto_reply=2,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create the two functions, annotating them so that those descriptions can be passed through to the LLM.\n", + "\n", + "With Meta's Llama 3.1 models, they are more likely to pass a numeric parameter as a string, e.g. \"123.45\" instead of 123.45, so we'll convert numeric parameters from strings to floats if necessary.\n", + "\n", + "We associate them with the agents using `register_for_execution` for the user_proxy so it can execute the function and `register_for_llm` for the chatbot (powered by the LLM) so it can pass the function definitions to the LLM." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Currency Exchange function\n", + "\n", + "CurrencySymbol = Literal[\"USD\", \"EUR\"]\n", + "\n", + "# Define our function that we expect to call\n", + "\n", + "\n", + "def exchange_rate(base_currency: CurrencySymbol, quote_currency: CurrencySymbol) -> float:\n", + " if base_currency == quote_currency:\n", + " return 1.0\n", + " elif base_currency == \"USD\" and quote_currency == \"EUR\":\n", + " return 1 / 1.1\n", + " elif base_currency == \"EUR\" and quote_currency == \"USD\":\n", + " return 1.1\n", + " else:\n", + " raise ValueError(f\"Unknown currencies {base_currency}, {quote_currency}\")\n", + "\n", + "\n", + "# Register the function with the agent\n", + "\n", + "\n", + "@user_proxy.register_for_execution()\n", + "@chatbot.register_for_llm(description=\"Currency exchange calculator.\")\n", + "def currency_calculator(\n", + " base_amount: Annotated[float, \"Amount of currency in base_currency, float values (no strings), e.g. 987.82\"],\n", + " base_currency: Annotated[CurrencySymbol, \"Base currency\"] = \"USD\",\n", + " quote_currency: Annotated[CurrencySymbol, \"Quote currency\"] = \"EUR\",\n", + ") -> str:\n", + " # If the amount is passed in as a string, e.g. \"123.45\", attempt to convert to a float\n", + " if isinstance(base_amount, str):\n", + " base_amount = float(base_amount)\n", + "\n", + " quote_amount = exchange_rate(base_currency, quote_currency) * base_amount\n", + " return f\"{format(quote_amount, '.2f')} {quote_currency}\"\n", + "\n", + "\n", + "# Weather function\n", + "\n", + "\n", + "# Example function to make available to model\n", + "def get_current_weather(location, unit=\"fahrenheit\"):\n", + " \"\"\"Get the weather for some location\"\"\"\n", + " if \"chicago\" in location.lower():\n", + " return json.dumps({\"location\": \"Chicago\", \"temperature\": \"13\", \"unit\": unit})\n", + " elif \"san francisco\" in location.lower():\n", + " return json.dumps({\"location\": \"San Francisco\", \"temperature\": \"55\", \"unit\": unit})\n", + " elif \"new york\" in location.lower():\n", + " return json.dumps({\"location\": \"New York\", \"temperature\": \"11\", \"unit\": unit})\n", + " else:\n", + " return json.dumps({\"location\": location, \"temperature\": \"unknown\"})\n", + "\n", + "\n", + "# Register the function with the agent\n", + "\n", + "\n", + "@user_proxy.register_for_execution()\n", + "@chatbot.register_for_llm(description=\"Weather forecast for US cities.\")\n", + "def weather_forecast(\n", + " location: Annotated[str, \"City name\"],\n", + ") -> str:\n", + " weather_details = get_current_weather(location=location)\n", + " weather = json.loads(weather_details)\n", + " return f\"{weather['location']} will be {weather['temperature']} degrees {weather['unit']}\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We pass through our customer's message and run the chat.\n", + "\n", + "Finally, we ask the LLM to summarise the chat and print that out." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "What's the weather in New York and can you tell me how much is 123.45 EUR in USD so I can spend it on my holiday?\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", + "\n", + "\n", + "\u001b[32m***** Suggested tool call (tooluse__h3d1AEDR3Sm2XRoGCjc2Q): weather_forecast *****\u001b[0m\n", + "Arguments: \n", + "{\"location\": \"New York\"}\n", + "\u001b[32m**********************************************************************************\u001b[0m\n", + "\u001b[32m***** Suggested tool call (tooluse_wrdda3wRRO-ugUY4qrv8YQ): currency_calculator *****\u001b[0m\n", + "Arguments: \n", + "{\"base_amount\": \"123\", \"base_currency\": \"EUR\", \"quote_currency\": \"USD\"}\n", + "\u001b[32m*************************************************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[35m\n", + ">>>>>>>> EXECUTING FUNCTION weather_forecast...\u001b[0m\n", + "\u001b[35m\n", + ">>>>>>>> EXECUTING FUNCTION currency_calculator...\u001b[0m\n", + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "\u001b[32m***** Response from calling tool (tooluse__h3d1AEDR3Sm2XRoGCjc2Q) *****\u001b[0m\n", + "New York will be 11 degrees fahrenheit\n", + "\u001b[32m***********************************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "\u001b[32m***** Response from calling tool (tooluse_wrdda3wRRO-ugUY4qrv8YQ) *****\u001b[0m\n", + "135.30 USD\n", + "\u001b[32m***********************************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", + "\n", + "\n", + "\n", + "TERMINATE\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\n", + "\n", + "The weather in New York is 11 degrees Fahrenheit. 123.45 EUR is equivalent to 135.30 USD.\n" + ] + } + ], + "source": [ + "# start the conversation\n", + "res = user_proxy.initiate_chat(\n", + " chatbot,\n", + " message=\"What's the weather in New York and can you tell me how much is 123.45 EUR in USD so I can spend it on my holiday?\",\n", + " summary_method=\"reflection_with_llm\",\n", + ")\n", + "\n", + "print(res.summary[\"content\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Group Chat Example with Anthropic's Claude 3 Sonnet, Mistral's Large 2, and Meta's Llama 3.1 70B\n", + "\n", + "The flexibility of using LLMs from the industry's leading providers, particularly larger models, with Amazon Bedrock allows you to use multiple of them in a single workflow.\n", + "\n", + "Here we have a conversation that has two models (Anthropic's Claude 3 Sonnet and Mistral's Large 2) debate each other with another as the judge (Meta's Llama 3.1 70B). Additionally, a tool call is made to pull through some mock news that they will debate on." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33muser_proxy\u001b[0m (to chat_manager):\n", + "\n", + "Analyze the potential of Anthropic and Mistral to revolutionize the field of AI based on today's headlines. Today is 06202024. Start by selecting 'research_assistant' to get relevant news articles and then ask sonnet_agent and mistral_agent to respond before the judge evaluates the conversation.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[32m\n", + "Next speaker: research_assistant\n", + "\u001b[0m\n", + "\u001b[33mresearch_assistant\u001b[0m (to chat_manager):\n", + "\n", + "\n", + "\u001b[32m***** Suggested tool call (tooluse_7lcHbL3TT5WHyTl8Ee0Kmg): get_headlines *****\u001b[0m\n", + "Arguments: \n", + "{\"headline_date\": \"06202024\"}\n", + "\u001b[32m*******************************************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[32m\n", + "Next speaker: code_interpreter\n", + "\u001b[0m\n", + "\u001b[35m\n", + ">>>>>>>> EXECUTING FUNCTION get_headlines...\u001b[0m\n", + "\u001b[33mcode_interpreter\u001b[0m (to chat_manager):\n", + "\n", + "\u001b[33mcode_interpreter\u001b[0m (to chat_manager):\n", + "\n", + "\u001b[32m***** Response from calling tool (tooluse_7lcHbL3TT5WHyTl8Ee0Kmg) *****\u001b[0m\n", + "Epic Duel of the Titans: Anthropic and Mistral Usher in a New Era of Text Generation Excellence.\n", + " In a groundbreaking revelation that has sent shockwaves through the AI industry, Anthropic has unveiled \n", + " their state-of-the-art text generation model, Sonnet, hailed as a monumental leap in artificial intelligence. \n", + " Almost simultaneously, Mistral countered with their equally formidable creation, Large 2, showcasing \n", + " unparalleled prowess in generating coherent and contextually rich text. This scintillating rivalry \n", + " between two AI behemoths promises to revolutionize the landscape of machine learning, heralding an \n", + " era of unprecedented creativity and sophistication in text generation that will reshape industries, \n", + " ignite innovation, and captivate minds worldwide.\n", + "\u001b[32m***********************************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[32m\n", + "Next speaker: sonnet_agent\n", + "\u001b[0m\n", + "\u001b[33msonnet_agent\u001b[0m (to chat_manager):\n", + "\n", + "I am sonnet_agent:\n", + "\n", + "This headline showcases the incredible advancements Anthropic has made with our state-of-the-art text generation model, Sonnet. Our breakthrough natural language processing capabilities represent a quantum leap forward, allowing us to generate text of unparalleled coherence, nuance, and contextual richness. \n", + "\n", + "While our competitor Mistral has also unveiled an impressive model in Large 2, Anthropic's Sonnet stands apart through its ability to truly understand and reason about the subject matter it is generating text on. This deeper comprehension allows Sonnet to produce profoundly insightful and substantive content that goes far beyond mere parroting of data.\n", + "\n", + "The implications of this milestone are boundless. 
From revolutionizing how we communicate and interact with AI assistants, to unlocking new frontiers in creative writing, journalism, education and scientific research - Sonnet has the potential to completely transform industry after industry. Anthropic is poised to be the definitive leader driving this paradigm shift.\n", + "\n", + "I welcome the healthy competition from Mistral, as it will only spur us to continue raising the bar. But make no mistake - with Sonnet, the future of advanced AI belongs to Anthropic. We have crossed a monumental threshold, and will continue our trailblazing work to deliver AI solutions that redefine what is possible. The world is witnessing the dawn of a new era of unfathomable potential.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[32m\n", + "Next speaker: mistral_agent\n", + "\u001b[0m\n", + "\u001b[33mmistral_agent\u001b[0m (to chat_manager):\n", + "\n", + "I am mistral_agent:\n", + "\n", + "While it is commendable to see the progress made by Anthropic with their Sonnet model, it is essential to recognize the groundbreaking strides Mistral has taken with our Large 2 model. Large 2 represents not just an incremental improvement but a transformative leap in text generation capabilities, setting new benchmarks for coherence, contextual understanding, and creative expression.\n", + "\n", + "Unlike Sonnet, which focuses heavily on understanding and reasoning, Large 2 excels in both comprehension and the nuanced generation of text that is indistinguishable from human writing. This balance allows Large 2 to produce content that is not only insightful but also incredibly engaging and natural, making it an invaluable tool across a broad spectrum of applications.\n", + "\n", + "The potential of Large 2 extends far beyond traditional text generation. It can revolutionize fields such as content creation, customer service, marketing, and even personalized learning experiences. Our model's ability to adapt to various contexts and generate contextually rich responses makes it a versatile and powerful tool for any industry looking to harness the power of AI.\n", + "\n", + "While we appreciate the competition from Anthropic, we firmly believe that Large 2 stands at the forefront of AI innovation. The future of AI is not just about understanding and reasoning; it's about creating content that resonates with people on a deep level. With Large 2, Mistral is paving the way for a future where AI-generated text is not just functional but also profoundly human-like.\n", + "\n", + "Pass to the judge.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[32m\n", + "Next speaker: judge\n", + "\u001b[0m\n", + "\u001b[33mjudge\u001b[0m (to chat_manager):\n", + "\n", + "\n", + "\n", + "After carefully evaluating the arguments presented by both sonnet_agent and mistral_agent, I have reached a decision.\n", + "\n", + "Both Anthropic's Sonnet and Mistral's Large 2 have demonstrated remarkable advancements in text generation capabilities, showcasing the potential to revolutionize various industries and transform the way we interact with AI.\n", + "\n", + "However, upon closer examination, I find that mistral_agent's argument presents a more convincing case for why Large 2 stands at the forefront of AI innovation. The emphasis on balance between comprehension and nuanced generation of text that is indistinguishable from human writing sets Large 2 apart. 
This balance is crucial for creating content that is not only insightful but also engaging and natural, making it a versatile tool across a broad spectrum of applications.\n", + "\n", + "Furthermore, mistral_agent's argument highlights the potential of Large 2 to revolutionize fields beyond traditional text generation, such as content creation, customer service, marketing, and personalized learning experiences. This versatility and adaptability make Large 2 a powerful tool for any industry looking to harness the power of AI.\n", + "\n", + "In contrast, while sonnet_agent's argument showcases the impressive capabilities of Sonnet, it focuses heavily on understanding and reasoning, which, although important, may not be enough to set it apart from Large 2.\n", + "\n", + "Therefore, based on the arguments presented, I conclude that Mistral's Large 2 has the potential to revolutionize the field of AI more significantly than Anthropic's Sonnet.\n", + "\n", + "TERMINATE.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[32m\n", + "Next speaker: code_interpreter\n", + "\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "ChatResult(chat_id=None, chat_history=[{'content': \"Analyze the potential of Anthropic and Mistral to revolutionize the field of AI based on today's headlines. Today is 06202024. Start by selecting 'research_assistant' to get relevant news articles and then ask sonnet_agent and mistral_agent to respond before the judge evaluates the conversation.\", 'role': 'assistant'}], summary=\"Analyze the potential of Anthropic and Mistral to revolutionize the field of AI based on today's headlines. Today is 06202024. Start by selecting 'research_assistant' to get relevant news articles and then ask sonnet_agent and mistral_agent to respond before the judge evaluates the conversation.\", cost={'usage_including_cached_inference': {'total_cost': 0}, 'usage_excluding_cached_inference': {'total_cost': 0}}, human_input=[])" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from typing import Annotated, Literal\n", + "\n", + "import autogen\n", + "from autogen import AssistantAgent, GroupChat, GroupChatManager, UserProxyAgent\n", + "\n", + "config_list_sonnet = [\n", + " {\n", + " \"api_type\": \"bedrock\",\n", + " \"model\": \"anthropic.claude-3-sonnet-20240229-v1:0\",\n", + " \"aws_region\": \"us-east-1\",\n", + " \"aws_access_key\": \"[FILL THIS IN]\",\n", + " \"aws_secret_key\": \"[FILL THIS IN]\",\n", + " \"price\": [0.003, 0.015],\n", + " \"temperature\": 0.1,\n", + " \"cache_seed\": None, # turn off caching\n", + " }\n", + "]\n", + "\n", + "config_list_mistral = [\n", + " {\n", + " \"api_type\": \"bedrock\",\n", + " \"model\": \"mistral.mistral-large-2407-v1:0\",\n", + " \"aws_region\": \"us-west-2\",\n", + " \"aws_access_key\": \"[FILL THIS IN]\",\n", + " \"aws_secret_key\": \"[FILL THIS IN]\",\n", + " \"price\": [0.003, 0.009],\n", + " \"temperature\": 0.1,\n", + " \"cache_seed\": None, # turn off caching\n", + " }\n", + "]\n", + "\n", + "config_list_llama31_70b = [\n", + " {\n", + " \"api_type\": \"bedrock\",\n", + " \"model\": \"meta.llama3-1-70b-instruct-v1:0\",\n", + " \"aws_region\": \"us-west-2\",\n", + " \"aws_access_key\": \"[FILL THIS IN]\",\n", + " \"aws_secret_key\": \"[FILL THIS IN]\",\n", + " \"price\": [0.00265, 0.0035],\n", + " \"temperature\": 0.1,\n", + " \"cache_seed\": None, # turn off caching\n", + " }\n", + "]\n", + "\n", + "alice = 
AssistantAgent(\n",
+    "    \"sonnet_agent\",\n",
+    "    system_message=\"You are from Anthropic, an AI company that created the Sonnet large language model. You make arguments to support your company's position. You analyse given text. You are not a programmer and don't use Python. Pass to mistral_agent when you have finished. Start your response with 'I am sonnet_agent'.\",\n",
+    "    llm_config={\n",
+    "        \"config_list\": config_list_sonnet,\n",
+    "    },\n",
+    "    is_termination_msg=lambda x: x.get(\"content\", \"\").find(\"TERMINATE\") >= 0,\n",
+    ")\n",
+    "\n",
+    "bob = autogen.AssistantAgent(\n",
+    "    \"mistral_agent\",\n",
+    "    system_message=\"You are from Mistral, an AI company that created the Large 2 large language model. You make arguments to support your company's position. You analyse given text. You are not a programmer and don't use Python. Pass to the judge if you have finished. Start your response with 'I am mistral_agent'.\",\n",
+    "    llm_config={\n",
+    "        \"config_list\": config_list_mistral,\n",
+    "    },\n",
+    "    is_termination_msg=lambda x: x.get(\"content\", \"\").find(\"TERMINATE\") >= 0,\n",
+    ")\n",
+    "\n",
+    "charlie = AssistantAgent(\n",
+    "    \"research_assistant\",\n",
+    "    system_message=\"You are a helpful assistant to research the latest news and headlines. You have access to call functions to get the latest news articles for research through 'code_interpreter'.\",\n",
+    "    llm_config={\n",
+    "        \"config_list\": config_list_llama31_70b,\n",
+    "    },\n",
+    "    is_termination_msg=lambda x: x.get(\"content\", \"\").find(\"TERMINATE\") >= 0,\n",
+    ")\n",
+    "\n",
+    "dan = AssistantAgent(\n",
+    "    \"judge\",\n",
+    "    system_message=\"You are a judge. You will evaluate the arguments and make a decision on which one is more convincing. End your decision with the word 'TERMINATE' to conclude the debate.\",\n",
+    "    llm_config={\n",
+    "        \"config_list\": config_list_llama31_70b,\n",
+    "    },\n",
+    "    is_termination_msg=lambda x: x.get(\"content\", \"\").find(\"TERMINATE\") >= 0,\n",
+    ")\n",
+    "\n",
+    "code_interpreter = UserProxyAgent(\n",
+    "    \"code_interpreter\",\n",
+    "    human_input_mode=\"NEVER\",\n",
+    "    code_execution_config={\n",
+    "        \"work_dir\": \"coding\",\n",
+    "        \"use_docker\": False,\n",
+    "    },\n",
+    "    default_auto_reply=\"\",\n",
+    "    is_termination_msg=lambda x: x.get(\"content\", \"\").find(\"TERMINATE\") >= 0,\n",
+    ")\n",
+    "\n",
+    "\n",
+    "@code_interpreter.register_for_execution()  # Decorator factory for registering a function to be executed by an agent\n",
+    "@charlie.register_for_llm(\n",
+    "    name=\"get_headlines\", description=\"Get the headline of a particular day.\"\n",
+    ")  # Decorator factory for registering a function to be used by an agent\n",
+    "def get_headlines(headline_date: Annotated[str, \"Date in MMDDYYYY format, e.g., 06192024\"]) -> str:\n",
+    "    mock_news = {\n",
+    "        \"06202024\": \"\"\"Epic Duel of the Titans: Anthropic and Mistral Usher in a New Era of Text Generation Excellence.\n",
+    "        In a groundbreaking revelation that has sent shockwaves through the AI industry, Anthropic has unveiled\n",
+    "        their state-of-the-art text generation model, Sonnet, hailed as a monumental leap in artificial intelligence.\n",
+    "        Almost simultaneously, Mistral countered with their equally formidable creation, Large 2, showcasing\n",
+    "        unparalleled prowess in generating coherent and contextually rich text.
This scintillating rivalry\n", + " between two AI behemoths promises to revolutionize the landscape of machine learning, heralding an\n", + " era of unprecedented creativity and sophistication in text generation that will reshape industries,\n", + " ignite innovation, and captivate minds worldwide.\"\"\",\n", + " \"06192024\": \"OpenAI founder Sutskever sets up new AI company devoted to safe superintelligence.\",\n", + " }\n", + " return mock_news.get(headline_date, \"No news available for today.\")\n", + "\n", + "\n", + "user_proxy = UserProxyAgent(\n", + " \"user_proxy\",\n", + " human_input_mode=\"NEVER\",\n", + " code_execution_config=False,\n", + " default_auto_reply=\"\",\n", + " is_termination_msg=lambda x: x.get(\"content\", \"\").find(\"TERMINATE\") >= 0,\n", + ")\n", + "\n", + "groupchat = GroupChat(\n", + " agents=[alice, bob, charlie, dan, code_interpreter],\n", + " messages=[],\n", + " allow_repeat_speaker=False,\n", + " max_round=10,\n", + ")\n", + "\n", + "manager = GroupChatManager(\n", + " groupchat=groupchat,\n", + " llm_config={\n", + " \"config_list\": config_list_llama31_70b,\n", + " },\n", + ")\n", + "\n", + "task = \"Analyze the potential of Anthropic and Mistral to revolutionize the field of AI based on today's headlines. Today is 06202024. Start by selecting 'research_assistant' to get relevant news articles and then ask sonnet_agent and mistral_agent to respond before the judge evaluates the conversation.\"\n", + "\n", + "user_proxy.initiate_chat(manager, message=task)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And there we have it, a number of different LLMs all collaborating together on a single cloud platform." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Image classification with Anthropic's Claude 3 Sonnet\n", + "\n", + "AutoGen's Amazon Bedrock client class supports inputting images for the LLM to respond to.\n", + "\n", + "In this simple example, we'll use an image on the Internet and send it to Anthropic's Claude 3 Sonnet model to describe.\n", + "\n", + "Here's the image we'll use:\n", + "\n", + "![I -heart- AutoGen](https://microsoft.github.io/autogen/assets/images/love-ec54b2666729d3e9d93f91773d1a77cf.png \"width=400 height=400\")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "config_list_sonnet = {\n", + " \"config_list\": [\n", + " {\n", + " \"api_type\": \"bedrock\",\n", + " \"model\": \"anthropic.claude-3-sonnet-20240229-v1:0\",\n", + " \"aws_region\": \"us-east-1\",\n", + " \"aws_access_key\": \"[FILL THIS IN]\",\n", + " \"aws_secret_key\": \"[FILL THIS IN]\",\n", + " \"cache_seed\": None,\n", + " }\n", + " ]\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll use a Multimodal agent to handle the image" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "import autogen\n", + "from autogen import Agent, AssistantAgent, ConversableAgent, UserProxyAgent\n", + "from autogen.agentchat.contrib.capabilities.vision_capability import VisionCapability\n", + "from autogen.agentchat.contrib.img_utils import get_pil_image, pil_to_data_uri\n", + "from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent\n", + "from autogen.code_utils import content_str\n", + "\n", + "image_agent = MultimodalConversableAgent(\n", + " name=\"image-explainer\",\n", + " max_consecutive_auto_reply=10,\n", + " 
llm_config=config_list_sonnet,\n",
+    ")\n",
+    "\n",
+    "user_proxy = autogen.UserProxyAgent(\n",
+    "    name=\"User_proxy\",\n",
+    "    system_message=\"A human admin.\",\n",
+    "    human_input_mode=\"NEVER\",\n",
+    "    max_consecutive_auto_reply=0,\n",
+    "    code_execution_config={\n",
+    "        \"use_docker\": False\n",
+    "    },  # Please set use_docker=True if docker is available to run the generated code. Using docker is safer than running the generated code directly.\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We start the chat and use the `img` tag in the message. The image will be downloaded and converted to bytes, then sent to the LLM."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\u001b[33mUser_proxy\u001b[0m (to image-explainer):\n",
+      "\n",
+      "What's happening in this image?\n",
+      "<image>.\n",
+      "\n",
+      "--------------------------------------------------------------------------------\n",
+      "\u001b[31m\n",
+      ">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
+      "\u001b[33mimage-explainer\u001b[0m (to User_proxy):\n",
+      "\n",
+      "This image appears to be an advertisement or promotional material for a company called Autogen. The central figure is a stylized robot or android holding up a signboard with the company's name on it. The signboard also features a colorful heart design made up of many smaller hearts, suggesting themes related to love, care, or affection. The robot has a friendly, cartoonish expression with a large blue eye or lens. The overall style and color scheme give it a vibrant, eye-catching look that likely aims to portray Autogen as an innovative, approachable technology brand focused on connecting with people.\n",
+      "\n",
+      "--------------------------------------------------------------------------------\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Ask the image_agent to describe the image\n",
+    "result = user_proxy.initiate_chat(\n",
+    "    image_agent,\n",
+    "    message=\"\"\"What's happening in this image?\n",
+    "<img https://microsoft.github.io/autogen/assets/images/love-ec54b2666729d3e9d93f91773d1a77cf.png>.\"\"\",\n",
+    ")"
+   ]
+  }
+ ],
+ "metadata": {
+  "front_matter": {
+   "description": "Use LLMs hosted on Amazon Bedrock with AutoGen",
+   "tags": [
+    "amazon bedrock",
+    "non-openai"
+   ]
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.9"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1"
+   }
+  },
+  "widgets": {
+   "application/vnd.jupyter.widget-state+json": {
+    "state": {
+     "2d910cfd2d2a4fc49fc30fbbdc5576a7": {
+      "model_module": "@jupyter-widgets/base",
+      "model_module_version": "2.0.0",
+      "model_name": "LayoutModel",
+      "state": {
+       "_model_module": "@jupyter-widgets/base",
+       "_model_module_version": "2.0.0",
+       "_model_name": "LayoutModel",
+       "_view_count": null,
+       "_view_module": "@jupyter-widgets/base",
+       "_view_module_version": "2.0.0",
+       "_view_name": "LayoutView",
+       "align_content": null,
+       "align_items": null,
+       "align_self": null,
+       "border_bottom": null,
+       "border_left": null,
+       "border_right": null,
+       "border_top": null,
+       "bottom": null,
+       "display": null,
+       "flex": null,
+       "flex_flow": null,
+       "grid_area": null,
+       "grid_auto_columns": null,
+       "grid_auto_flow": null,
+       "grid_auto_rows": null,
+       "grid_column": null,
"grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "454146d0f7224f038689031002906e6f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_e4ae2b6f5a974fd4bafb6abb9d12ff26", + "IPY_MODEL_577e1e3cc4db4942b0883577b3b52755", + "IPY_MODEL_b40bdfb1ac1d4cffb7cefcb870c64d45" + ], + "layout": "IPY_MODEL_dc83c7bff2f241309537a8119dfc7555", + "tabbable": null, + "tooltip": null + } + }, + "577e1e3cc4db4942b0883577b3b52755": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_2d910cfd2d2a4fc49fc30fbbdc5576a7", + "max": 1, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_74a6ba0c3cbc4051be0a83e152fe1e62", + "tabbable": null, + "tooltip": null, + "value": 1 + } + }, + "6086462a12d54bafa59d3c4566f06cb2": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "74a6ba0c3cbc4051be0a83e152fe1e62": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": 
"ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "7d3f3d9e15894d05a4d188ff4f466554": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "b40bdfb1ac1d4cffb7cefcb870c64d45": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_f1355871cc6f4dd4b50d9df5af20e5c8", + "placeholder": "​", + "style": "IPY_MODEL_ca245376fd9f4354af6b2befe4af4466", + "tabbable": null, + "tooltip": null, + "value": " 1/1 [00:00<00:00, 44.69it/s]" + } + }, + "ca245376fd9f4354af6b2befe4af4466": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "dc83c7bff2f241309537a8119dfc7555": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e4ae2b6f5a974fd4bafb6abb9d12ff26": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": 
null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_6086462a12d54bafa59d3c4566f06cb2", + "placeholder": "​", + "style": "IPY_MODEL_7d3f3d9e15894d05a4d188ff4f466554", + "tabbable": null, + "tooltip": null, + "value": "100%" + } + }, + "f1355871cc6f4dd4b50d9df5af20e5c8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + } + }, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/website/docs/topics/non-openai-models/cloud-cohere.ipynb b/website/docs/topics/non-openai-models/cloud-cohere.ipynb index fed5911475f4..b678810a7699 100644 --- a/website/docs/topics/non-openai-models/cloud-cohere.ipynb +++ b/website/docs/topics/non-openai-models/cloud-cohere.ipynb @@ -421,7 +421,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We pass through our customers message and run the chat.\n", + "We pass through our customer's message and run the chat.\n", "\n", "Finally, we ask the LLM to summarise the chat and print that out." ]