From 406fa6ecac3f93c7636f3a1d3769613887f8e58d Mon Sep 17 00:00:00 2001 From: Hk669 Date: Sun, 23 Jun 2024 22:33:18 +0530 Subject: [PATCH 01/17] initial setup for cohere client --- autogen/logger/file_logger.py | 3 ++- autogen/logger/sqlite_logger.py | 5 +++- autogen/oai/cohere.py | 45 +++++++++++++++++++++++++++++++++ autogen/runtime_logging.py | 3 ++- setup.py | 1 + 5 files changed, 54 insertions(+), 3 deletions(-) create mode 100644 autogen/oai/cohere.py diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py index af5583587f66..d94d1d8a21d8 100644 --- a/autogen/logger/file_logger.py +++ b/autogen/logger/file_logger.py @@ -18,6 +18,7 @@ if TYPE_CHECKING: from autogen import Agent, ConversableAgent, OpenAIWrapper from autogen.oai.anthropic import AnthropicClient + from autogen.oai.cohere import CohereClient from autogen.oai.gemini import GeminiClient from autogen.oai.mistral import MistralAIClient from autogen.oai.together import TogetherClient @@ -204,7 +205,7 @@ def log_new_wrapper( def log_new_client( self, - client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient, + client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient | CohereClient, wrapper: OpenAIWrapper, init_args: Dict[str, Any], ) -> None: diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py index 969a943017e3..6f132eed86ec 100644 --- a/autogen/logger/sqlite_logger.py +++ b/autogen/logger/sqlite_logger.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: from autogen import Agent, ConversableAgent, OpenAIWrapper from autogen.oai.anthropic import AnthropicClient + from autogen.oai.cohere import CohereClient from autogen.oai.gemini import GeminiClient from autogen.oai.mistral import MistralAIClient from autogen.oai.together import TogetherClient @@ -391,7 +392,9 @@ def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[st def log_new_client( self, - client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient], + client: Union[ + AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, CohereClient + ], wrapper: OpenAIWrapper, init_args: Dict[str, Any], ) -> None: diff --git a/autogen/oai/cohere.py b/autogen/oai/cohere.py new file mode 100644 index 000000000000..63895600e626 --- /dev/null +++ b/autogen/oai/cohere.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import copy +import inspect +import json +import os +import time +import warnings +from typing import Any, Dict, List, Tuple, Union + +from cohere import Client as Cohere +from cohere.types import ChatMessage, ToolCallDelta +from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall +from openai.types.chat.chat_completion import ChatCompletionMessage, Choice +from openai.types.completion_usage import CompletionUsage +from typing_extensions import Annotated + +from autogen.oai.client_utils import validate_parameter + +COHERE_PRICING_1K = { + "command-r-plus": (0.003, 0.015), + "command-r": (0.0005, 0.0015), + "command-nightly": (0.00025, 0.00125), + "command": (0.015, 0.075), + "command-light": (0.008, 0.024), + "ccommand-light-nightly": (0.008, 0.024), +} + + +class CohereClient: + def __init__(self, **kwargs): + self.api_key = kwargs.get("api_key", None) + + if not self.api_key: + self.api_key = os.getenv("COHERE_API_KEY") + + if not self.api_key: + raise ValueError("API key is required") + + self.client = Cohere(self.api_key) + self.last_tool_use_status = {} + + @property + def api_key(self) -> str: + return self.api_key diff --git a/autogen/runtime_logging.py b/autogen/runtime_logging.py index adb55ba63b4f..1c6535f2b4cf 100644 --- a/autogen/runtime_logging.py +++ b/autogen/runtime_logging.py @@ -14,6 +14,7 @@ if TYPE_CHECKING: from autogen import Agent, ConversableAgent, OpenAIWrapper from autogen.oai.anthropic import AnthropicClient + from autogen.oai.cohere import CohereClient from autogen.oai.gemini import GeminiClient from autogen.oai.mistral import MistralAIClient from autogen.oai.together import TogetherClient @@ -110,7 +111,7 @@ def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig def log_new_client( - client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient], + client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, CohereClient], wrapper: OpenAIWrapper, init_args: Dict[str, Any], ) -> None: diff --git a/setup.py b/setup.py index 738e09d9061c..8c3a20ba8a63 100644 --- a/setup.py +++ b/setup.py @@ -91,6 +91,7 @@ "long-context": ["llmlingua<0.3"], "anthropic": ["anthropic>=0.23.1"], "mistral": ["mistralai>=0.2.0"], + "cohere": ["cohere"], } setuptools.setup( From 288b8748950b3b352b2dcbd284ffc4c029a79742 Mon Sep 17 00:00:00 2001 From: Hk669 Date: Sun, 23 Jun 2024 22:43:21 +0530 Subject: [PATCH 02/17] client update --- autogen/oai/client.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 2c14ca0d4a0c..edb846f2b3f7 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -70,6 +70,13 @@ except ImportError as e: together_import_exception = e +try: + from autogen.oai.cohere import CohereClient + + cohere_import_exception: Optional[ImportError] = None +except ImportError as e: + cohere_import_exception = e + logger = logging.getLogger(__name__) if not logger.handlers: # Add the console handler. @@ -484,6 +491,9 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s if together_import_exception: raise ImportError("Please install `together` to use the Together.AI API.") self._clients.append(TogetherClient(**config)) + elif api_type is not None and api_type.startswith("cohere"): + client = CohereClient(**openai_config) + self._clients.append(client) else: client = OpenAI(**openai_config) self._clients.append(OpenAIClient(client)) From 80d61552287f2d2eff50b4c0f1a4adfd97233aa3 Mon Sep 17 00:00:00 2001 From: Hk669 Date: Mon, 24 Jun 2024 10:42:13 +0530 Subject: [PATCH 03/17] changes: ClintType added to the utils --- autogen/logger/file_logger.py | 8 ++------ autogen/logger/sqlite_logger.py | 10 ++-------- autogen/oai/client_utils.py | 14 +++++++++++++- autogen/runtime_logging.py | 8 ++------ 4 files changed, 19 insertions(+), 21 deletions(-) diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py index d94d1d8a21d8..82aa2c48a1e1 100644 --- a/autogen/logger/file_logger.py +++ b/autogen/logger/file_logger.py @@ -17,11 +17,7 @@ if TYPE_CHECKING: from autogen import Agent, ConversableAgent, OpenAIWrapper - from autogen.oai.anthropic import AnthropicClient - from autogen.oai.cohere import CohereClient - from autogen.oai.gemini import GeminiClient - from autogen.oai.mistral import MistralAIClient - from autogen.oai.together import TogetherClient + from autogen.oai.client_utils import ClientType logger = logging.getLogger(__name__) @@ -205,7 +201,7 @@ def log_new_wrapper( def log_new_client( self, - client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient | CohereClient, + client: ClientType, wrapper: OpenAIWrapper, init_args: Dict[str, Any], ) -> None: diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py index 6f132eed86ec..09b813eb1761 100644 --- a/autogen/logger/sqlite_logger.py +++ b/autogen/logger/sqlite_logger.py @@ -18,11 +18,7 @@ if TYPE_CHECKING: from autogen import Agent, ConversableAgent, OpenAIWrapper - from autogen.oai.anthropic import AnthropicClient - from autogen.oai.cohere import CohereClient - from autogen.oai.gemini import GeminiClient - from autogen.oai.mistral import MistralAIClient - from autogen.oai.together import TogetherClient + from autogen.oai.client_utils import ClientType logger = logging.getLogger(__name__) lock = threading.Lock() @@ -392,9 +388,7 @@ def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[st def log_new_client( self, - client: Union[ - AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, CohereClient - ], + client: ClientType, wrapper: OpenAIWrapper, init_args: Dict[str, Any], ) -> None: diff --git a/autogen/oai/client_utils.py b/autogen/oai/client_utils.py index 55730485b40c..16336ba6a6e6 100644 --- a/autogen/oai/client_utils.py +++ b/autogen/oai/client_utils.py @@ -1,7 +1,19 @@ """Utilities for client classes""" import warnings -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +from openai import AzureOpenAI, OpenAI + +if TYPE_CHECKING: + from autogen.oai.anthropic import AnthropicClient + from autogen.oai.cohere import CohereClient + from autogen.oai.gemini import GeminiClient + from autogen.oai.mistral import MistralAIClient + from autogen.oai.together import TogetherClient + + +ClientType = Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, CohereClient] def validate_parameter( diff --git a/autogen/runtime_logging.py b/autogen/runtime_logging.py index 1c6535f2b4cf..e08d99e299f6 100644 --- a/autogen/runtime_logging.py +++ b/autogen/runtime_logging.py @@ -13,11 +13,7 @@ if TYPE_CHECKING: from autogen import Agent, ConversableAgent, OpenAIWrapper - from autogen.oai.anthropic import AnthropicClient - from autogen.oai.cohere import CohereClient - from autogen.oai.gemini import GeminiClient - from autogen.oai.mistral import MistralAIClient - from autogen.oai.together import TogetherClient + from autogen.oai.client_utils import ClientType logger = logging.getLogger(__name__) @@ -111,7 +107,7 @@ def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig def log_new_client( - client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, CohereClient], + client: ClientType, wrapper: OpenAIWrapper, init_args: Dict[str, Any], ) -> None: From 7cf0e7d924d721c9195517cf54c32ca3750bff29 Mon Sep 17 00:00:00 2001 From: Hk669 Date: Mon, 24 Jun 2024 11:35:35 +0530 Subject: [PATCH 04/17] Revert "changes: ClintType added to the utils" This reverts commit 80d61552287f2d2eff50b4c0f1a4adfd97233aa3. --- autogen/logger/file_logger.py | 8 ++++++-- autogen/logger/sqlite_logger.py | 10 ++++++++-- autogen/oai/client_utils.py | 14 +------------- autogen/runtime_logging.py | 8 ++++++-- 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py index 82aa2c48a1e1..d94d1d8a21d8 100644 --- a/autogen/logger/file_logger.py +++ b/autogen/logger/file_logger.py @@ -17,7 +17,11 @@ if TYPE_CHECKING: from autogen import Agent, ConversableAgent, OpenAIWrapper - from autogen.oai.client_utils import ClientType + from autogen.oai.anthropic import AnthropicClient + from autogen.oai.cohere import CohereClient + from autogen.oai.gemini import GeminiClient + from autogen.oai.mistral import MistralAIClient + from autogen.oai.together import TogetherClient logger = logging.getLogger(__name__) @@ -201,7 +205,7 @@ def log_new_wrapper( def log_new_client( self, - client: ClientType, + client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient | CohereClient, wrapper: OpenAIWrapper, init_args: Dict[str, Any], ) -> None: diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py index 09b813eb1761..6f132eed86ec 100644 --- a/autogen/logger/sqlite_logger.py +++ b/autogen/logger/sqlite_logger.py @@ -18,7 +18,11 @@ if TYPE_CHECKING: from autogen import Agent, ConversableAgent, OpenAIWrapper - from autogen.oai.client_utils import ClientType + from autogen.oai.anthropic import AnthropicClient + from autogen.oai.cohere import CohereClient + from autogen.oai.gemini import GeminiClient + from autogen.oai.mistral import MistralAIClient + from autogen.oai.together import TogetherClient logger = logging.getLogger(__name__) lock = threading.Lock() @@ -388,7 +392,9 @@ def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[st def log_new_client( self, - client: ClientType, + client: Union[ + AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, CohereClient + ], wrapper: OpenAIWrapper, init_args: Dict[str, Any], ) -> None: diff --git a/autogen/oai/client_utils.py b/autogen/oai/client_utils.py index 16336ba6a6e6..55730485b40c 100644 --- a/autogen/oai/client_utils.py +++ b/autogen/oai/client_utils.py @@ -1,19 +1,7 @@ """Utilities for client classes""" import warnings -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union - -from openai import AzureOpenAI, OpenAI - -if TYPE_CHECKING: - from autogen.oai.anthropic import AnthropicClient - from autogen.oai.cohere import CohereClient - from autogen.oai.gemini import GeminiClient - from autogen.oai.mistral import MistralAIClient - from autogen.oai.together import TogetherClient - - -ClientType = Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, CohereClient] +from typing import Any, Dict, List, Optional, Tuple def validate_parameter( diff --git a/autogen/runtime_logging.py b/autogen/runtime_logging.py index e08d99e299f6..1c6535f2b4cf 100644 --- a/autogen/runtime_logging.py +++ b/autogen/runtime_logging.py @@ -13,7 +13,11 @@ if TYPE_CHECKING: from autogen import Agent, ConversableAgent, OpenAIWrapper - from autogen.oai.client_utils import ClientType + from autogen.oai.anthropic import AnthropicClient + from autogen.oai.cohere import CohereClient + from autogen.oai.gemini import GeminiClient + from autogen.oai.mistral import MistralAIClient + from autogen.oai.together import TogetherClient logger = logging.getLogger(__name__) @@ -107,7 +111,7 @@ def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig def log_new_client( - client: ClientType, + client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, CohereClient], wrapper: OpenAIWrapper, init_args: Dict[str, Any], ) -> None: From 967e6fad9b1c96816c50bec0ce55f20f166295ff Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Mon, 24 Jun 2024 07:28:08 +0000 Subject: [PATCH 05/17] Message conversion to Cohere, Parameter handling, cost calculation, streaming, tool calling --- autogen/oai/cohere.py | 402 ++++++++++++++++++++++++++++++++++++++++-- setup.py | 2 +- 2 files changed, 393 insertions(+), 11 deletions(-) diff --git a/autogen/oai/cohere.py b/autogen/oai/cohere.py index 63895600e626..2d885ae26c25 100644 --- a/autogen/oai/cohere.py +++ b/autogen/oai/cohere.py @@ -1,15 +1,35 @@ +"""Create an OpenAI-compatible client using Cohere's API. + +Example: + llm_config={ + "config_list": [{ + "api_type": "cohere", + "model": "command-r-plus", + "api_key": os.environ.get("COHERE_API_KEY") + } + ]} + + agent = autogen.AssistantAgent("my_agent", llm_config=llm_config) + +Install Groq's python library using: pip install --upgrade cohere + +Resources: +- https://docs.cohere.com/reference/chat +""" + from __future__ import annotations import copy import inspect import json import os +import random import time import warnings from typing import Any, Dict, List, Tuple, Union from cohere import Client as Cohere -from cohere.types import ChatMessage, ToolCallDelta +from cohere.types import ToolParameterDefinitionsValue, ToolResult from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import ChatCompletionMessage, Choice from openai.types.completion_usage import CompletionUsage @@ -23,23 +43,385 @@ "command-nightly": (0.00025, 0.00125), "command": (0.015, 0.075), "command-light": (0.008, 0.024), - "ccommand-light-nightly": (0.008, 0.024), + "command-light-nightly": (0.008, 0.024), } class CohereClient: + """Client for Cohere's API.""" + def __init__(self, **kwargs): - self.api_key = kwargs.get("api_key", None) + """Requires api_key or environment variable to be set + Args: + api_key (str): The API key for using Cohere (or environment variable COHERE_API_KEY needs to be set) + """ + # Ensure we have the api_key upon instantiation + self.api_key = kwargs.get("api_key", None) if not self.api_key: self.api_key = os.getenv("COHERE_API_KEY") - if not self.api_key: - raise ValueError("API key is required") + assert ( + self.api_key + ), "Please include the api_key in your config list entry for Cohere or set the COHERE_API_KEY env variable." + + def message_retrieval(self, response) -> List: + """ + Retrieve and return a list of strings or a list of Choice.Message from the response. + + NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object, + since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used. + """ + return [choice.message for choice in response.choices] + + def cost(self, response) -> float: + return response.cost + + @staticmethod + def get_usage(response) -> Dict: + """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" + # ... # pragma: no cover + return { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens, + "cost": response.cost, + "model": response.model, + } + + def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: + """Loads the parameters for Cohere API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults""" + cohere_params = {} + + # Check that we have what we need to use Cohere's API + # We won't enforce the available models as they are likely to change + cohere_params["model"] = params.get("model", None) + assert cohere_params[ + "model" + ], "Please specify the 'model' in your config list entry to nominate the Cohere model to use." + + # Validate allowed Cohere parameters + # https://docs.cohere.com/reference/chat + cohere_params["temperature"] = validate_parameter( + params, "temperature", (int, float), False, 0.3, (0, None), None + ) + cohere_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None) + cohere_params["k"] = validate_parameter(params, "k", int, False, 0, (0, 500), None) + cohere_params["p"] = validate_parameter(params, "p", (int, float), False, 0.75, (0.01, 0.99), None) + cohere_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None) + cohere_params["frequency_penalty"] = validate_parameter( + params, "frequency_penalty", (int, float), True, 0, (0, 1), None + ) + cohere_params["presence_penalty"] = validate_parameter( + params, "presence_penalty", (int, float), True, 0, (0, 1), None + ) + + # Cohere parameters we are ignoring: + # preamble - we will put the system prompt in here. + # parallel_tool_calls (defaults to True), stop + # conversation_id - allows resuming a previous conversation, we don't support this. + # connectors - allows web search or other custom connectors, not implementing for now but could be useful in the future. + # search_queries_only - to control whether only search queries are used, we're not using connectors so ignoring. + # documents - a list of documents that can be used to support the chat. Perhaps useful in the future for RAG. + # citation_quality - used for RAG flows and dependent on other parameters we're ignoring. + # max_input_tokens - limits input tokens, not needed. + # stop_sequences - used to stop generation, not needed. + + return cohere_params + + def create(self, params: Dict) -> ChatCompletion: + + messages = params.get("messages", []) + + # Parse parameters to the Cohere API's parameters + cohere_params = self.parse_params(params) + + # Convert AutoGen messages to Cohere messages + cohere_messages, preamble, final_message = oai_messages_to_cohere_messages(messages, params, cohere_params) + + cohere_params["chat_history"] = cohere_messages + cohere_params["message"] = final_message + cohere_params["preamble"] = preamble + + # We use chat model by default + client = Cohere(api_key=self.api_key) + + # Token counts will be returned + prompt_tokens = 0 + completion_tokens = 0 + total_tokens = 0 + + # Streaming tool call recommendations + # streaming_tool_calls = [] + + # Stream if in parameters + streaming = True if "stream" in params and params["stream"] else False + cohere_finish = "" + + max_retries = 5 + for attempt in range(max_retries): + ans = None + try: + if streaming: + response = client.chat_stream(**cohere_params) + else: + response = client.chat(**cohere_params) + except Exception as e: + raise RuntimeError(f"Cohere exception occurred: {e}") + else: + + if streaming: + # Streaming... + ans = "" + for event in response: + if event.event_type == "text-generation": + ans = ans + event.text + elif event.event_type == "tool-calls-generation": + # When streaming, tool calls are compiled at the end into a single event_type + ans = event.text + cohere_finish = "tool_calls" + tool_calls = [] + for tool_call in event.tool_calls: + tool_calls.append( + ChatCompletionMessageToolCall( + id=str(random.randint(0, 100000)), + function={ + "name": tool_call.name, + "arguments": ( + "" if tool_call.parameters is None else json.dumps(tool_call.parameters) + ), + }, + type="function", + ) + ) + + # Not using billed_units, but that may be better for cost purposes + prompt_tokens = event.response.meta.tokens.input_tokens + completion_tokens = event.response.meta.tokens.output_tokens + total_tokens = prompt_tokens + completion_tokens + + response_id = event.response.response_id + else: + # Non-streaming finished + ans: str = response.text + + # Not using billed_units, but that may be better for cost purposes + prompt_tokens = response.meta.tokens.input_tokens + completion_tokens = response.meta.tokens.output_tokens + total_tokens = prompt_tokens + completion_tokens + break + + if response is not None: + + if streaming: + # Streaming response + response_content = ans + + if cohere_finish == "": + cohere_finish = "stop" + tool_calls = None + else: + # Non-streaming response + # If we have tool calls as the response, populate completed tool calls for our return OAI response + if response.tool_calls is not None: + cohere_finish = "tool_calls" + tool_calls = [] + for tool_call in response.tool_calls: + + # if parameters are null, clear them out (Cohere can return a string "null" if no parameter values) + + tool_calls.append( + ChatCompletionMessageToolCall( + id=str(random.randint(0, 100000)), + function={ + "name": tool_call.name, + "arguments": ( + "" if tool_call.parameters is None else json.dumps(tool_call.parameters) + ), + }, + type="function", + ) + ) + else: + cohere_finish = "stop" + tool_calls = None + else: + raise RuntimeError(f"Failed to get response from Groq after retrying {attempt + 1} times.") + + # 3. convert output + message = ChatCompletionMessage( + role="assistant", + content=response_content, + function_call=None, + tool_calls=tool_calls, + ) + choices = [Choice(finish_reason=cohere_finish, index=0, message=message)] + + response_oai = ChatCompletion( + id=response_id, + model=cohere_params["model"], + created=int(time.time()), + object="chat.completion", + choices=choices, + usage=CompletionUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ), + cost=calculate_cohere_cost(prompt_tokens, completion_tokens, cohere_params["model"]), + ) + + return response_oai + + +def oai_messages_to_cohere_messages( + messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any] +) -> tuple[list[dict[str, Any]], str, str]: + """Convert messages from OAI format to Cohere's format. + We correct for any specific role orders and types. + + Parameters: + messages: list[Dict[str, Any]]: AutoGen messages + params: Dict[str, Any]: AutoGen parameters dictionary + cohere_params: Dict[str, Any]: Cohere parameters dictionary + + Returns: + List[Dict[str, Any]]: Chat History messages + str: Preamble (system message) + str: Message (the final user message) + """ + + cohere_messages = [] + preamble = "" + + # Tools + if "tools" in params: + cohere_tools = [] + for tool in params["tools"]: + + # build list of properties + parameters = {} + + for key, value in tool["function"]["parameters"]["properties"].items(): + type_str = value["type"] + required = True # Defaults to False, we could consider leaving it as default. + description = value["description"] + + # If we have an 'enum' key, add that to the description (as not allowed to pass in enum as a field) + if "enum" in value: + # Access the enum list + enum_values = value["enum"] + enum_strings = [str(value) for value in enum_values] + enum_string = ", ".join(enum_strings) + description = description + ". Possible values are " + enum_string + "." + + parameters[key] = ToolParameterDefinitionsValue( + description=description, type=type_str, required=required + ) + + cohere_tool = { + "name": tool["function"]["name"], + "description": tool["function"]["description"], + "parameter_definitions": parameters, + } + + cohere_tools.append(cohere_tool) + + if len(cohere_tools) > 0: + cohere_params["tools"] = cohere_tools + + tool_calls = [] + tool_results = [] + + # Rules for cohere messages: + # no 'name' field + # 'system' messages go into the preamble parameter + # user role = 'USER' + # assistant role = 'CHATBOT' + # 'content' field renamed to 'message' + # tools go into tools parameter + # tool_results go into tool_results parameter + for message in messages: + + if "role" in message and message["role"] == "system": + # System message + if preamble == "": + preamble = message["content"] + else: + preamble = preamble + "\n" + message["content"] + elif "tool_calls" in message: + # Suggested tool calls, build up the list before we put it into the tool_results + for tool_call in message["tool_calls"]: + tool_calls.append(tool_call) + + # We also add the suggested tool call as a message + new_message = { + "role": "CHATBOT", + "message": message["content"], + # Not including tools in this message, may need to. Testing required. + } + + cohere_messages.append(new_message) + elif "role" in message and message["role"] == "tool": + if "tool_call_id" in message: + # Convert the tool call to a result + + tool_call_id = message["tool_call_id"] + content_output = message["content"] + + # Find the original tool + for tool_call in tool_calls: + if tool_call["id"] == tool_call_id: + + call = { + "name": tool_call["function"]["name"], + "parameters": json.loads(tool_call["function"]["arguments"]), + } + output = [{"value": content_output}] + + tool_results.append(ToolResult(call=call, outputs=output)) + + break + elif "content" in message and isinstance(message["content"], str): + # Standard text message + new_message = { + "role": "USER" if message["role"] == "user" else "CHATBOT", + "message": message["content"], + } + + cohere_messages.append(new_message) + + # Append any Tool Results + if len(tool_results) != 0: + cohere_params["tool_results"] = tool_results + + # Enable multi-step tool use: https://docs.cohere.com/docs/multi-step-tool-use + cohere_params["force_single_step"] = False + + # We return a blank message when we have tool results + # TODO: Check what happens if tool_results aren't the latest message + return cohere_messages, preamble, "" + + else: + + # We need to get the last message to assign to the message field for Cohere, + # if the last message is a user message, use that, otherwise put in 'continue'. + if cohere_messages[-1]["role"] == "USER": + return cohere_messages[0:-1], preamble, cohere_messages[-1]["message"] + else: + return cohere_messages, preamble, "Please continue." + + +def calculate_cohere_cost(input_tokens: int, output_tokens: int, model: str) -> float: + """Calculate the cost of the completion using the Cohere pricing.""" + total = 0.0 - self.client = Cohere(self.api_key) - self.last_tool_use_status = {} + if model in COHERE_PRICING_1K: + input_cost_per_k, output_cost_per_k = COHERE_PRICING_1K[model] + input_cost = (input_tokens / 1000) * input_cost_per_k + output_cost = (output_tokens / 1000) * output_cost_per_k + total = input_cost + output_cost + else: + warnings.warn(f"Cost calculation not available for model {model}", UserWarning) - @property - def api_key(self) -> str: - return self.api_key + return total diff --git a/setup.py b/setup.py index 8c3a20ba8a63..b3ff6b94b4af 100644 --- a/setup.py +++ b/setup.py @@ -91,7 +91,7 @@ "long-context": ["llmlingua<0.3"], "anthropic": ["anthropic>=0.23.1"], "mistral": ["mistralai>=0.2.0"], - "cohere": ["cohere"], + "cohere": ["cohere>=5.5.8"], } setuptools.setup( From 2ecbadfe436496a3d9b9663f224216d4ef6a525c Mon Sep 17 00:00:00 2001 From: Mark Sze <66362098+marklysze@users.noreply.github.com> Date: Mon, 24 Jun 2024 17:37:37 +1000 Subject: [PATCH 06/17] Changed Groq references. --- autogen/oai/cohere.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/autogen/oai/cohere.py b/autogen/oai/cohere.py index 2d885ae26c25..2d9dd3da2fe0 100644 --- a/autogen/oai/cohere.py +++ b/autogen/oai/cohere.py @@ -11,7 +11,7 @@ agent = autogen.AssistantAgent("my_agent", llm_config=llm_config) -Install Groq's python library using: pip install --upgrade cohere +Install Cohere's python library using: pip install --upgrade cohere Resources: - https://docs.cohere.com/reference/chat @@ -118,7 +118,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: # Cohere parameters we are ignoring: # preamble - we will put the system prompt in here. - # parallel_tool_calls (defaults to True), stop + # parallel_tool_calls (defaults to True), perfect as is. # conversation_id - allows resuming a previous conversation, we don't support this. # connectors - allows web search or other custom connectors, not implementing for now but could be useful in the future. # search_queries_only - to control whether only search queries are used, we're not using connectors so ignoring. @@ -246,7 +246,7 @@ def create(self, params: Dict) -> ChatCompletion: cohere_finish = "stop" tool_calls = None else: - raise RuntimeError(f"Failed to get response from Groq after retrying {attempt + 1} times.") + raise RuntimeError(f"Failed to get response from Cohere after retrying {attempt + 1} times.") # 3. convert output message = ChatCompletionMessage( From c0e7401a9d4de2f2c9ca6d3b6914b757618c7232 Mon Sep 17 00:00:00 2001 From: Hk669 Date: Mon, 24 Jun 2024 21:27:52 +0530 Subject: [PATCH 07/17] minor fix --- autogen/oai/cohere.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autogen/oai/cohere.py b/autogen/oai/cohere.py index 2d9dd3da2fe0..fc3b95720265 100644 --- a/autogen/oai/cohere.py +++ b/autogen/oai/cohere.py @@ -422,6 +422,6 @@ def calculate_cohere_cost(input_tokens: int, output_tokens: int, model: str) -> output_cost = (output_tokens / 1000) * output_cost_per_k total = input_cost + output_cost else: - warnings.warn(f"Cost calculation not available for model {model}", UserWarning) + warnings.warn(f"Cost calculation not available for {model} model", UserWarning) return total From f4e5e515d9a6f11b54d48cc5f3aa17cbfc97cc9e Mon Sep 17 00:00:00 2001 From: Hk669 Date: Mon, 24 Jun 2024 21:30:14 +0530 Subject: [PATCH 08/17] tests added --- test/oai/test_cohere.py | 69 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 test/oai/test_cohere.py diff --git a/test/oai/test_cohere.py b/test/oai/test_cohere.py new file mode 100644 index 000000000000..f6d26b07da44 --- /dev/null +++ b/test/oai/test_cohere.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 -m pytest + +import os + +import pytest + +try: + from autogen.oai.cohere import CohereClient, calculate_cohere_cost + + skip = False +except ImportError: + CohereClient = object + skip = True + + +reason = "Cohere dependency not installed!" + + +@pytest.fixture() +def cohere_client(): + return CohereClient(api_key="dummy_api_key") + + +@pytest.mark.skipif(skip, reason=reason) +def test_initialization_missing_api_key(): + os.environ.pop("ANTHROPIC_API_KEY", None) + with pytest.raises( + AssertionError, + match="Please include the api_key in your config list entry for Cohere or set the COHERE_API_KEY env variable.", + ): + CohereClient() + + CohereClient(api_key="dummy_api_key") + + +@pytest.mark.skipif(skip, reason=reason) +def test_intialization(cohere_client): + assert cohere_client.api_key == "dummy_api_key", "`api_key` should be correctly set in the config" + + +@pytest.mark.skipif(skip, reason=reason) +def test_calculate_cohere_cost(): + assert ( + calculate_cohere_cost(0, 0, model="command-r") == 0.0 + ), "Cost should be 0 for 0 input_tokens and 0 output_tokens" + assert calculate_cohere_cost(100, 200, model="command-r-plus") == 0.0033 + + +@pytest.mark.skipif(skip, reason=reason) +def test_load_config(cohere_client): + params = { + "model": "command-r-plus", + "stream": False, + "temperature": 1, + "p": 0.8, + "max_tokens": 100, + } + expected_params = { + "model": "command-r-plus", + "temperature": 1, + "p": 0.8, + "seed": None, + "max_tokens": 100, + "frequency_penalty": 0, + "presence_penalty": 0, + "k": 0, + } + result = cohere_client.parse_params(params) + assert result == expected_params, "Config should be correctly loaded" From 2370f7f9717855be770c270d32d75f936d2a025b Mon Sep 17 00:00:00 2001 From: Hk669 Date: Mon, 24 Jun 2024 21:31:34 +0530 Subject: [PATCH 09/17] ref fix --- test/oai/test_cohere.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/oai/test_cohere.py b/test/oai/test_cohere.py index f6d26b07da44..83ef56b17087 100644 --- a/test/oai/test_cohere.py +++ b/test/oai/test_cohere.py @@ -23,7 +23,7 @@ def cohere_client(): @pytest.mark.skipif(skip, reason=reason) def test_initialization_missing_api_key(): - os.environ.pop("ANTHROPIC_API_KEY", None) + os.environ.pop("COHERE_API_KEY", None) with pytest.raises( AssertionError, match="Please include the api_key in your config list entry for Cohere or set the COHERE_API_KEY env variable.", From 42ab555ce2459c3b27ee8d27da77fd233dcea824 Mon Sep 17 00:00:00 2001 From: Hk669 Date: Mon, 24 Jun 2024 21:37:34 +0530 Subject: [PATCH 10/17] added in the workflows --- .github/workflows/contrib-tests.yml | 36 +++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml index 7d8a932b0254..c1f61bf25d32 100644 --- a/.github/workflows/contrib-tests.yml +++ b/.github/workflows/contrib-tests.yml @@ -598,3 +598,39 @@ jobs: with: file: ./coverage.xml flags: unittests + + CohereTest: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.9", "3.10", "3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + with: + lfs: true + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install packages and dependencies for all tests + run: | + python -m pip install --upgrade pip wheel + pip install pytest-cov>=5 + - name: Install packages and dependencies for Cohere + run: | + pip install -e .[cohere,test] + - name: Set AUTOGEN_USE_DOCKER based on OS + shell: bash + run: | + if [[ ${{ matrix.os }} != ubuntu-latest ]]; then + echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV + fi + - name: Coverage + run: | + pytest test/oai/test_cohere.py --skip-openai + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + file: ./coverage.xml + flags: unittests From a598ce6f911879711c08e923abdc9faf577e5991 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Mon, 24 Jun 2024 20:53:02 +0000 Subject: [PATCH 11/17] Fixed bug on non-streaming text generation --- autogen/oai/cohere.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/autogen/oai/cohere.py b/autogen/oai/cohere.py index fc3b95720265..b47b7ca81b54 100644 --- a/autogen/oai/cohere.py +++ b/autogen/oai/cohere.py @@ -19,21 +19,18 @@ from __future__ import annotations -import copy -import inspect import json import os import random import time import warnings -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List from cohere import Client as Cohere from cohere.types import ToolParameterDefinitionsValue, ToolResult from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import ChatCompletionMessage, Choice from openai.types.completion_usage import CompletionUsage -from typing_extensions import Annotated from autogen.oai.client_utils import validate_parameter @@ -151,9 +148,6 @@ def create(self, params: Dict) -> ChatCompletion: completion_tokens = 0 total_tokens = 0 - # Streaming tool call recommendations - # streaming_tool_calls = [] - # Stream if in parameters streaming = True if "stream" in params and params["stream"] else False cohere_finish = "" @@ -209,14 +203,16 @@ def create(self, params: Dict) -> ChatCompletion: prompt_tokens = response.meta.tokens.input_tokens completion_tokens = response.meta.tokens.output_tokens total_tokens = prompt_tokens + completion_tokens + + response_id = response.response_id break if response is not None: + response_content = ans + if streaming: # Streaming response - response_content = ans - if cohere_finish == "": cohere_finish = "stop" tool_calls = None @@ -375,7 +371,11 @@ def oai_messages_to_cohere_messages( call = { "name": tool_call["function"]["name"], - "parameters": json.loads(tool_call["function"]["arguments"]), + "parameters": json.loads( + tool_call["function"]["arguments"] + if not tool_call["function"]["arguments"] == "" + else "{}" + ), } output = [{"value": content_output}] From a52323afed0a4c9fcbc73a62c1a5b93718b06189 Mon Sep 17 00:00:00 2001 From: Hk669 Date: Sat, 29 Jun 2024 10:28:31 +0530 Subject: [PATCH 12/17] fix: formatting --- autogen/logger/file_logger.py | 11 ++++++++++- autogen/logger/sqlite_logger.py | 11 ++++++++++- autogen/runtime_logging.py | 4 +++- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py index 9e73aa2eac7c..61a8a6335284 100644 --- a/autogen/logger/file_logger.py +++ b/autogen/logger/file_logger.py @@ -206,7 +206,16 @@ def log_new_wrapper( def log_new_client( self, - client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient | GroqClient | CohereClient, + client: ( + AzureOpenAI + | OpenAI + | GeminiClient + | AnthropicClient + | MistralAIClient + | TogetherClient + | GroqClient + | CohereClient + ), wrapper: OpenAIWrapper, init_args: Dict[str, Any], ) -> None: diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py index 6edcb9945d7e..2cf176ebb8f2 100644 --- a/autogen/logger/sqlite_logger.py +++ b/autogen/logger/sqlite_logger.py @@ -393,7 +393,16 @@ def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[st def log_new_client( self, - client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient, CohereClient], + client: Union[ + AzureOpenAI, + OpenAI, + GeminiClient, + AnthropicClient, + MistralAIClient, + TogetherClient, + GroqClient, + CohereClient, + ], wrapper: OpenAIWrapper, init_args: Dict[str, Any], ) -> None: diff --git a/autogen/runtime_logging.py b/autogen/runtime_logging.py index ec287dfc8b4c..1ffc8b622f0a 100644 --- a/autogen/runtime_logging.py +++ b/autogen/runtime_logging.py @@ -112,7 +112,9 @@ def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig def log_new_client( - client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient, CohereClient], + client: Union[ + AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient, CohereClient + ], wrapper: OpenAIWrapper, init_args: Dict[str, Any], ) -> None: From 550a75946a6b7de378142f9c62b20a17220ea9e9 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Tue, 2 Jul 2024 08:52:48 +0000 Subject: [PATCH 13/17] Support Cohere rule for last message not USER when tool_results exist --- autogen/oai/cohere.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/autogen/oai/cohere.py b/autogen/oai/cohere.py index b47b7ca81b54..681852d0fad6 100644 --- a/autogen/oai/cohere.py +++ b/autogen/oai/cohere.py @@ -398,6 +398,11 @@ def oai_messages_to_cohere_messages( # Enable multi-step tool use: https://docs.cohere.com/docs/multi-step-tool-use cohere_params["force_single_step"] = False + # If we're adding tool_results, like we are, the last message can't be a USER message + # So, we add a CHATBOT 'continue' message, if so. + if cohere_messages[-1]["role"] == "USER": + cohere_messages.append({"role": "CHATBOT", "content": "Please continue."}) + # We return a blank message when we have tool results # TODO: Check what happens if tool_results aren't the latest message return cohere_messages, preamble, "" From df33229a96920bca1dc4040b93326a91d2f37a8d Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Tue, 2 Jul 2024 09:51:12 +0000 Subject: [PATCH 14/17] Added Cohere to documentation --- .../non-openai-models/cloud-cohere.ipynb | 534 ++++++++++++++++++ 1 file changed, 534 insertions(+) create mode 100644 website/docs/topics/non-openai-models/cloud-cohere.ipynb diff --git a/website/docs/topics/non-openai-models/cloud-cohere.ipynb b/website/docs/topics/non-openai-models/cloud-cohere.ipynb new file mode 100644 index 000000000000..202a3c5e74eb --- /dev/null +++ b/website/docs/topics/non-openai-models/cloud-cohere.ipynb @@ -0,0 +1,534 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Cohere\n", + "\n", + "[Cohere](https://cohere.com/) is a cloud based platform serving their own LLMs, in particular the Command family of models.\n", + "\n", + "Cohere's API differs from OpenAI's, which is the native API used by AutoGen, so to use Cohere's LLMs you need to use this library.\n", + "\n", + "You will need a Cohere account and create an API key. [See their website for further details](https://cohere.com/)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Features\n", + "\n", + "When using this client class, AutoGen's messages are automatically tailored to accommodate the specific requirements of Cohere's API.\n", + "\n", + "Additionally, this client class provides support for function/tool calling and will track token usage and cost correctly as per Cohere's API costs (as of July 2024).\n", + "\n", + "## Getting started\n", + "\n", + "First you need to install the `pyautogen` package to use AutoGen with the Cohere API library.\n", + "\n", + "``` bash\n", + "pip install pyautogen[cohere]\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Cohere provides a number of models to use, included below. See the list of [models here](https://docs.cohere.com/docs/models).\n", + "\n", + "See the sample `OAI_CONFIG_LIST` below showing how the Cohere client class is used by specifying the `api_type` as `cohere`.\n", + "\n", + "```python\n", + "[\n", + " {\n", + " \"model\": \"gpt-35-turbo\",\n", + " \"api_key\": \"your OpenAI Key goes here\",\n", + " },\n", + " {\n", + " \"model\": \"gpt-4-vision-preview\",\n", + " \"api_key\": \"your OpenAI Key goes here\",\n", + " },\n", + " {\n", + " \"model\": \"dalle\",\n", + " \"api_key\": \"your OpenAI Key goes here\",\n", + " },\n", + " {\n", + " \"model\": \"command-r-plus\",\n", + " \"api_key\": \"your Cohere API Key goes here\",\n", + " \"api_type\": \"cohere\"\n", + " },\n", + " {\n", + " \"model\": \"command-r\",\n", + " \"api_key\": \"your Cohere API Key goes here\",\n", + " \"api_type\": \"cohere\"\n", + " },\n", + " {\n", + " \"model\": \"command\",\n", + " \"api_key\": \"your Cohere API Key goes here\",\n", + " \"api_type\": \"cohere\"\n", + " }\n", + "]\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As an alternative to the `api_key` key and value in the config, you can set the environment variable `COHERE_API_KEY` to your Cohere key.\n", + "\n", + "Linux/Mac:\n", + "``` bash\n", + "export COHERE_API_KEY=\"your_cohere_api_key_here\"\n", + "```\n", + "\n", + "Windows:\n", + "``` bash\n", + "set COHERE_API_KEY=your_cohere_api_key_here\n", + "```\n", + "\n", + "## API parameters\n", + "\n", + "The following parameters can be added to your config for the Cohere API. See [this link](https://docs.cohere.com/reference/chat) for further information on them and their default values.\n", + "\n", + "- temperature (number > 0)\n", + "- p (number 0.01..0.99)\n", + "- k (number 0..500)\n", + "- max_tokens (null, integer >= 0)\n", + "- seed (null, integer)\n", + "- frequency_penalty (number 0..1)\n", + "- presence_penalty (number 0..1)\n", + "\n", + "Example:\n", + "```python\n", + "[\n", + " {\n", + " \"model\": \"command-r\",\n", + " \"api_key\": \"your Cohere API Key goes here\",\n", + " \"api_type\": \"cohere\",\n", + " \"temperature\": 0.5,\n", + " \"p\": 0.2,\n", + " \"k\": 100,\n", + " \"max_tokens\": 2048,\n", + " \"seed\": 42,\n", + " \"frequency_penalty\": 0.5,\n", + " \"presence_penalty\": 0.2\n", + " }\n", + "]\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Two-Agent Coding Example\n", + "\n", + "In this example, we run a two-agent chat with an AssistantAgent (primarily a coding agent) to generate code to count the number of prime numbers between 1 and 10,000 and then it will be executed.\n", + "\n", + "We'll use Cohere's Command R model which is suitable for coding." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "config_list = [\n", + " {\n", + " # Let's choose the Command-R model\n", + " \"model\": \"command-r\",\n", + " # Provide your Cohere's API key here or put it into the COHERE_API_KEY environment variable.\n", + " \"api_key\": os.environ.get(\"COHERE_API_KEY\"),\n", + " # We specify the API Type as 'cohere' so it uses the Cohere client class\n", + " \"api_type\": \"cohere\",\n", + " }\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Importantly, we have tweaked the system message so that the model doesn't return the termination keyword, which we've changed to FINISH, with the code block." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from pathlib import Path\n", + "\n", + "from autogen import AssistantAgent, UserProxyAgent\n", + "from autogen.coding import LocalCommandLineCodeExecutor\n", + "\n", + "# Setting up the code executor\n", + "workdir = Path(\"coding\")\n", + "workdir.mkdir(exist_ok=True)\n", + "code_executor = LocalCommandLineCodeExecutor(work_dir=workdir)\n", + "\n", + "# Setting up the agents\n", + "\n", + "# The UserProxyAgent will execute the code that the AssistantAgent provides\n", + "user_proxy_agent = UserProxyAgent(\n", + " name=\"User\",\n", + " code_execution_config={\"executor\": code_executor},\n", + " is_termination_msg=lambda msg: \"FINISH\" in msg.get(\"content\"),\n", + ")\n", + "\n", + "system_message = \"\"\"You are a helpful AI assistant who writes code and the user executes it.\n", + "Solve tasks using your coding and language skills.\n", + "In the following cases, suggest python code (in a python coding block) for the user to execute.\n", + "Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.\n", + "When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.\n", + "Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.\n", + "If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.\n", + "When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.\n", + "IMPORTANT: Wait for the user to execute your code and then you can reply with the word \"FINISH\". DO NOT OUTPUT \"FINISH\" after your code block.\"\"\"\n", + "\n", + "# The AssistantAgent, using Cohere's model, will take the coding request and return code\n", + "assistant_agent = AssistantAgent(\n", + " name=\"Cohere Assistant\",\n", + " system_message=system_message,\n", + " llm_config={\"config_list\": config_list},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mUser\u001b[0m (to Cohere Assistant):\n", + "\n", + "Provide code to count the number of prime numbers from 1 to 10000.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mCohere Assistant\u001b[0m (to User):\n", + "\n", + "Here's the code to count the number of prime numbers from 1 to 10,000:\n", + "```python\n", + "# Prime Number Counter\n", + "count = 0\n", + "for num in range(2, 10001):\n", + " if num > 1:\n", + " for div in range(2, num):\n", + " if (num % div) == 0:\n", + " break\n", + " else:\n", + " count += 1\n", + "print(count)\n", + "```\n", + "\n", + "My plan is to use two nested loops. The outer loop iterates through numbers from 2 to 10,000. The inner loop checks if there's any divisor for the current number in the range from 2 to the number itself. If there's no such divisor, the number is prime and the counter is incremented.\n", + "\n", + "Please execute the code and let me know the output.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[31m\n", + ">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n", + "\u001b[31m\n", + ">>>>>>>> USING AUTO REPLY...\u001b[0m\n", + "\u001b[31m\n", + ">>>>>>>> EXECUTING CODE BLOCK (inferred language is python)...\u001b[0m\n", + "\u001b[33mUser\u001b[0m (to Cohere Assistant):\n", + "\n", + "exitcode: 0 (execution succeeded)\n", + "Code output: 1229\n", + "\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mCohere Assistant\u001b[0m (to User):\n", + "\n", + "That's correct! The code you executed successfully found 1229 prime numbers within the specified range.\n", + "\n", + "FINISH.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[31m\n", + ">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n" + ] + } + ], + "source": [ + "# Start the chat, with the UserProxyAgent asking the AssistantAgent the message\n", + "chat_result = user_proxy_agent.initiate_chat(\n", + " assistant_agent,\n", + " message=\"Provide code to count the number of prime numbers from 1 to 10000.\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tool Call Example\n", + "\n", + "In this example, instead of writing code, we will show how Cohere's Command R+ model can perform parallel tool calling, where it recommends calling more than one tool at a time.\n", + "\n", + "We'll use a simple travel agent assistant program where we have a couple of tools for weather and currency conversion.\n", + "\n", + "We start by importing libraries and setting up our configuration to use Command R+ and the `cohere` client class." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import os\n", + "from typing import Literal\n", + "\n", + "from typing_extensions import Annotated\n", + "\n", + "import autogen\n", + "\n", + "config_list = [\n", + " {\"api_type\": \"cohere\", \"model\": \"command-r-plus\", \"api_key\": os.getenv(\"COHERE_API_KEY\"), \"cache_seed\": None}\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create our two agents." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Create the agent for tool calling\n", + "chatbot = autogen.AssistantAgent(\n", + " name=\"chatbot\",\n", + " system_message=\"\"\"For currency exchange and weather forecasting tasks,\n", + " only use the functions you have been provided with.\n", + " Output 'HAVE FUN!' when an answer has been provided.\"\"\",\n", + " llm_config={\"config_list\": config_list},\n", + ")\n", + "\n", + "# Note that we have changed the termination string to be \"HAVE FUN!\"\n", + "user_proxy = autogen.UserProxyAgent(\n", + " name=\"user_proxy\",\n", + " is_termination_msg=lambda x: x.get(\"content\", \"\") and \"HAVE FUN!\" in x.get(\"content\", \"\"),\n", + " human_input_mode=\"NEVER\",\n", + " max_consecutive_auto_reply=1,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create the two functions, annotating them so that those descriptions can be passed through to the LLM.\n", + "\n", + "We associate them with the agents using `register_for_execution` for the user_proxy so it can execute the function and `register_for_llm` for the chatbot (powered by the LLM) so it can pass the function definitions to the LLM." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Currency Exchange function\n", + "\n", + "CurrencySymbol = Literal[\"USD\", \"EUR\"]\n", + "\n", + "# Define our function that we expect to call\n", + "\n", + "\n", + "def exchange_rate(base_currency: CurrencySymbol, quote_currency: CurrencySymbol) -> float:\n", + " if base_currency == quote_currency:\n", + " return 1.0\n", + " elif base_currency == \"USD\" and quote_currency == \"EUR\":\n", + " return 1 / 1.1\n", + " elif base_currency == \"EUR\" and quote_currency == \"USD\":\n", + " return 1.1\n", + " else:\n", + " raise ValueError(f\"Unknown currencies {base_currency}, {quote_currency}\")\n", + "\n", + "\n", + "# Register the function with the agent\n", + "\n", + "\n", + "@user_proxy.register_for_execution()\n", + "@chatbot.register_for_llm(description=\"Currency exchange calculator.\")\n", + "def currency_calculator(\n", + " base_amount: Annotated[float, \"Amount of currency in base_currency\"],\n", + " base_currency: Annotated[CurrencySymbol, \"Base currency\"] = \"USD\",\n", + " quote_currency: Annotated[CurrencySymbol, \"Quote currency\"] = \"EUR\",\n", + ") -> str:\n", + " quote_amount = exchange_rate(base_currency, quote_currency) * base_amount\n", + " return f\"{format(quote_amount, '.2f')} {quote_currency}\"\n", + "\n", + "\n", + "# Weather function\n", + "\n", + "\n", + "# Example function to make available to model\n", + "def get_current_weather(location, unit=\"fahrenheit\"):\n", + " \"\"\"Get the weather for some location\"\"\"\n", + " if \"chicago\" in location.lower():\n", + " return json.dumps({\"location\": \"Chicago\", \"temperature\": \"13\", \"unit\": unit})\n", + " elif \"san francisco\" in location.lower():\n", + " return json.dumps({\"location\": \"San Francisco\", \"temperature\": \"55\", \"unit\": unit})\n", + " elif \"new york\" in location.lower():\n", + " return json.dumps({\"location\": \"New York\", \"temperature\": \"11\", \"unit\": unit})\n", + " else:\n", + " return json.dumps({\"location\": location, \"temperature\": \"unknown\"})\n", + "\n", + "\n", + "# Register the function with the agent\n", + "\n", + "\n", + "@user_proxy.register_for_execution()\n", + "@chatbot.register_for_llm(description=\"Weather forecast for US cities.\")\n", + "def weather_forecast(\n", + " location: Annotated[str, \"City name\"],\n", + ") -> str:\n", + " weather_details = get_current_weather(location=location)\n", + " weather = json.loads(weather_details)\n", + " return f\"{weather['location']} will be {weather['temperature']} degrees {weather['unit']}\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We pass through our customers message and run the chat.\n", + "\n", + "Finally, we ask the LLM to summarise the chat and print that out." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "What's the weather in New York and can you tell me how much is 123.45 EUR in USD so I can spend it on my holiday? Throw a few holiday tips in as well.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", + "\n", + "I will use the weather_forecast function to find out the weather in New York, and the currency_calculator function to convert 123.45 EUR to USD. I will then search for 'holiday tips' to find some extra information to include in my answer.\n", + "\u001b[32m***** Suggested tool call (45212): weather_forecast *****\u001b[0m\n", + "Arguments: \n", + "{\"location\": \"New York\"}\n", + "\u001b[32m*********************************************************\u001b[0m\n", + "\u001b[32m***** Suggested tool call (16564): currency_calculator *****\u001b[0m\n", + "Arguments: \n", + "{\"base_amount\": 123.45, \"base_currency\": \"EUR\", \"quote_currency\": \"USD\"}\n", + "\u001b[32m************************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[35m\n", + ">>>>>>>> EXECUTING FUNCTION weather_forecast...\u001b[0m\n", + "\u001b[35m\n", + ">>>>>>>> EXECUTING FUNCTION currency_calculator...\u001b[0m\n", + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "\u001b[32m***** Response from calling tool (45212) *****\u001b[0m\n", + "New York will be 11 degrees fahrenheit\n", + "\u001b[32m**********************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "\u001b[32m***** Response from calling tool (16564) *****\u001b[0m\n", + "135.80 USD\n", + "\u001b[32m**********************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", + "\n", + "The weather in New York is 11 degrees Fahrenheit. \n", + "\n", + "€123.45 is worth $135.80. \n", + "\n", + "Here are some holiday tips:\n", + "- Make sure to pack layers for the cold weather\n", + "- Try the local cuisine, New York is famous for its pizza\n", + "- Visit Central Park and take in the views from the top of the Rockefeller Centre\n", + "\n", + "HAVE FUN!\n", + "\n", + "--------------------------------------------------------------------------------\n", + "LLM SUMMARY: The weather in New York is 11 degrees Fahrenheit. 123.45 EUR is worth 135.80 USD. Holiday tips: make sure to pack warm clothes and have a great time!\n" + ] + } + ], + "source": [ + "# start the conversation\n", + "res = user_proxy.initiate_chat(\n", + " chatbot,\n", + " message=\"What's the weather in New York and can you tell me how much is 123.45 EUR in USD so I can spend it on my holiday? Throw a few holiday tips in as well.\",\n", + " summary_method=\"reflection_with_llm\",\n", + ")\n", + "\n", + "print(f\"LLM SUMMARY: {res.summary['content']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that Command R+ recommended we call both tools and passed through the right parameters. The user_proxy executed them and this was passed back to Command R+ to interpret them and respond. Finally, Command R+ was asked to summarise the whole conversation." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "autogen", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 1ffd0ba0742ec94fd868581c6c5feaad1263bed7 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Tue, 2 Jul 2024 20:50:53 +0000 Subject: [PATCH 15/17] fixed client.py merge, removed unnecessary comments in groq.py, updated Cohere documentation, added Groq documentation --- autogen/oai/client.py | 3 +- autogen/oai/groq.py | 7 - .../non-openai-models/cloud-cohere.ipynb | 2 +- .../topics/non-openai-models/cloud-groq.ipynb | 524 ++++++++++++++++++ 4 files changed, 527 insertions(+), 9 deletions(-) create mode 100644 website/docs/topics/non-openai-models/cloud-groq.ipynb diff --git a/autogen/oai/client.py b/autogen/oai/client.py index cd73d487642e..ef3a3fd2b1b3 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -504,8 +504,9 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s raise ImportError("Please install `groq` to use the Groq API.") client = GroqClient(**openai_config) self._clients.append(client) - self._clients.append(TogetherClient(**config)) elif api_type is not None and api_type.startswith("cohere"): + if cohere_import_exception: + raise ImportError("Please install `cohere` to use the Groq API.") client = CohereClient(**openai_config) self._clients.append(client) else: diff --git a/autogen/oai/groq.py b/autogen/oai/groq.py index a97240887c8e..d2abe5116a25 100644 --- a/autogen/oai/groq.py +++ b/autogen/oai/groq.py @@ -259,13 +259,6 @@ def oai_messages_to_groq_messages(messages: list[Dict[str, Any]]) -> list[dict[s groq_messages = copy.deepcopy(messages) - # If we have a message with role='tool', which occurs when a function is executed, change it to 'user' - """ - for msg in together_messages: - if "role" in msg and msg["role"] == "tool": - msg["role"] = "user" - """ - # Remove the name field for message in groq_messages: if "name" in message: diff --git a/website/docs/topics/non-openai-models/cloud-cohere.ipynb b/website/docs/topics/non-openai-models/cloud-cohere.ipynb index 202a3c5e74eb..fed5911475f4 100644 --- a/website/docs/topics/non-openai-models/cloud-cohere.ipynb +++ b/website/docs/topics/non-openai-models/cloud-cohere.ipynb @@ -506,7 +506,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can see that Command R+ recommended we call both tools and passed through the right parameters. The user_proxy executed them and this was passed back to Command R+ to interpret them and respond. Finally, Command R+ was asked to summarise the whole conversation." + "We can see that Command R+ recommended we call both tools and passed through the right parameters. The `user_proxy` executed them and this was passed back to Command R+ to interpret them and respond. Finally, Command R+ was asked to summarise the whole conversation." ] } ], diff --git a/website/docs/topics/non-openai-models/cloud-groq.ipynb b/website/docs/topics/non-openai-models/cloud-groq.ipynb new file mode 100644 index 000000000000..d2289cbdcd45 --- /dev/null +++ b/website/docs/topics/non-openai-models/cloud-groq.ipynb @@ -0,0 +1,524 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Groq\n", + "\n", + "[Groq](https://groq.com/) is a cloud based platform serving a number of popular open weight models at high inference speeds. Models include Meta's Llama 3, Mistral AI's Mixtral, and Google's Gemma.\n", + "\n", + "Although Groq's API is aligned well with OpenAI's, which is the native API used by AutoGen, this library provides the ability to set specific parameters as well as track API costs.\n", + "\n", + "You will need a Groq account and create an API key. [See their website for further details](https://groq.com/)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Groq provides a number of models to use, included below. See the list of [models here (requires login)](https://console.groq.com/docs/models).\n", + "\n", + "See the sample `OAI_CONFIG_LIST` below showing how the Groq client class is used by specifying the `api_type` as `groq`.\n", + "\n", + "```python\n", + "[\n", + " {\n", + " \"model\": \"gpt-35-turbo\",\n", + " \"api_key\": \"your OpenAI Key goes here\",\n", + " },\n", + " {\n", + " \"model\": \"gpt-4-vision-preview\",\n", + " \"api_key\": \"your OpenAI Key goes here\",\n", + " },\n", + " {\n", + " \"model\": \"dalle\",\n", + " \"api_key\": \"your OpenAI Key goes here\",\n", + " },\n", + " {\n", + " \"model\": \"llama3-8b-8192\",\n", + " \"api_key\": \"your Groq API Key goes here\",\n", + " \"api_type\": \"groq\"\n", + " },\n", + " {\n", + " \"model\": \"llama3-70b-8192\",\n", + " \"api_key\": \"your Groq API Key goes here\",\n", + " \"api_type\": \"groq\"\n", + " },\n", + " {\n", + " \"model\": \"Mixtral 8x7b\",\n", + " \"api_key\": \"your Groq API Key goes here\",\n", + " \"api_type\": \"groq\"\n", + " },\n", + " {\n", + " \"model\": \"gemma-7b-it\",\n", + " \"api_key\": \"your Groq API Key goes here\",\n", + " \"api_type\": \"groq\"\n", + " }\n", + "]\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As an alternative to the `api_key` key and value in the config, you can set the environment variable `GROQ_API_KEY` to your Groq key.\n", + "\n", + "Linux/Mac:\n", + "``` bash\n", + "export GROQ_API_KEY=\"your_groq_api_key_here\"\n", + "```\n", + "\n", + "Windows:\n", + "``` bash\n", + "set GROQ_API_KEY=your_groq_api_key_here\n", + "```\n", + "\n", + "## API parameters\n", + "\n", + "The following parameters can be added to your config for the Groq API. See [this link](https://console.groq.com/docs/text-chat) for further information on them.\n", + "\n", + "- frequency_penalty (number 0..1)\n", + "- max_tokens (integer >= 0)\n", + "- presence_penalty (number -2..2)\n", + "- seed (integer)\n", + "- temperature (number 0..2)\n", + "- top_p (number)\n", + "\n", + "Example:\n", + "```python\n", + "[\n", + " {\n", + " \"model\": \"llama3-8b-8192\",\n", + " \"api_key\": \"your Groq API Key goes here\",\n", + " \"api_type\": \"groq\",\n", + " \"frequency_penalty\": 0.5,\n", + " \"max_tokens\": 2048,\n", + " \"presence_penalty\": 0.2,\n", + " \"seed\": 42,\n", + " \"temperature\": 0.5,\n", + " \"top_p\": 0.2\n", + " }\n", + "]\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Two-Agent Coding Example\n", + "\n", + "In this example, we run a two-agent chat with an AssistantAgent (primarily a coding agent) to generate code to count the number of prime numbers between 1 and 10,000 and then it will be executed.\n", + "\n", + "We'll use Meta's Llama 3 model which is suitable for coding." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "config_list = [\n", + " {\n", + " # Let's choose the Llama 3 model\n", + " \"model\": \"llama3-8b-8192\",\n", + " # Put your Groq API key here or put it into the GROQ_API_KEY environment variable.\n", + " \"api_key\": os.environ.get(\"GROQ_API_KEY\"),\n", + " # We specify the API Type as 'groq' so it uses the Groq client class\n", + " \"api_type\": \"groq\",\n", + " }\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Importantly, we have tweaked the system message so that the model doesn't return the termination keyword, which we've changed to FINISH, with the code block." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from pathlib import Path\n", + "\n", + "from autogen import AssistantAgent, UserProxyAgent\n", + "from autogen.coding import LocalCommandLineCodeExecutor\n", + "\n", + "# Setting up the code executor\n", + "workdir = Path(\"coding\")\n", + "workdir.mkdir(exist_ok=True)\n", + "code_executor = LocalCommandLineCodeExecutor(work_dir=workdir)\n", + "\n", + "# Setting up the agents\n", + "\n", + "# The UserProxyAgent will execute the code that the AssistantAgent provides\n", + "user_proxy_agent = UserProxyAgent(\n", + " name=\"User\",\n", + " code_execution_config={\"executor\": code_executor},\n", + " is_termination_msg=lambda msg: \"FINISH\" in msg.get(\"content\"),\n", + ")\n", + "\n", + "system_message = \"\"\"You are a helpful AI assistant who writes code and the user executes it.\n", + "Solve tasks using your coding and language skills.\n", + "In the following cases, suggest python code (in a python coding block) for the user to execute.\n", + "Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.\n", + "When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.\n", + "Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.\n", + "If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.\n", + "When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.\n", + "IMPORTANT: Wait for the user to execute your code and then you can reply with the word \"FINISH\". DO NOT OUTPUT \"FINISH\" after your code block.\"\"\"\n", + "\n", + "# The AssistantAgent, using Groq's model, will take the coding request and return code\n", + "assistant_agent = AssistantAgent(\n", + " name=\"Groq Assistant\",\n", + " system_message=system_message,\n", + " llm_config={\"config_list\": config_list},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mUser\u001b[0m (to Groq Assistant):\n", + "\n", + "Provide code to count the number of prime numbers from 1 to 10000.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mGroq Assistant\u001b[0m (to User):\n", + "\n", + "Here's the plan to count the number of prime numbers from 1 to 10000:\n", + "\n", + "First, we need to write a helper function to check if a number is prime. A prime number is a number that is divisible only by 1 and itself.\n", + "\n", + "Then, we can use a loop to iterate through all numbers from 1 to 10000, check if each number is prime using our helper function, and count the number of prime numbers found.\n", + "\n", + "Here's the Python code to implement this plan:\n", + "```python\n", + "def is_prime(n):\n", + " if n <= 1:\n", + " return False\n", + " for i in range(2, int(n ** 0.5) + 1):\n", + " if n % i == 0:\n", + " return False\n", + " return True\n", + "\n", + "count = 0\n", + "for i in range(2, 10001):\n", + " if is_prime(i):\n", + " count += 1\n", + "\n", + "print(count)\n", + "```\n", + "Please execute this code, and I'll wait for the result.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[31m\n", + ">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n", + "\u001b[31m\n", + ">>>>>>>> USING AUTO REPLY...\u001b[0m\n", + "\u001b[31m\n", + ">>>>>>>> EXECUTING CODE BLOCK (inferred language is python)...\u001b[0m\n", + "\u001b[33mUser\u001b[0m (to Groq Assistant):\n", + "\n", + "exitcode: 0 (execution succeeded)\n", + "Code output: 1229\n", + "\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mGroq Assistant\u001b[0m (to User):\n", + "\n", + "FINISH\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[31m\n", + ">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n" + ] + } + ], + "source": [ + "# Start the chat, with the UserProxyAgent asking the AssistantAgent the message\n", + "chat_result = user_proxy_agent.initiate_chat(\n", + " assistant_agent,\n", + " message=\"Provide code to count the number of prime numbers from 1 to 10000.\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tool Call Example\n", + "\n", + "In this example, instead of writing code, we will show how we can use Meta's Llama 3 model to perform parallel tool calling, where it recommends calling more than one tool at a time, using Groq's cloud inference.\n", + "\n", + "We'll use a simple travel agent assistant program where we have a couple of tools for weather and currency conversion.\n", + "\n", + "We start by importing libraries and setting up our configuration to use Meta's Llama 3 model and the `groq` client class." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import os\n", + "from typing import Literal\n", + "\n", + "from typing_extensions import Annotated\n", + "\n", + "import autogen\n", + "\n", + "config_list = [\n", + " {\"api_type\": \"groq\", \"model\": \"llama3-8b-8192\", \"api_key\": os.getenv(\"GROQ_API_KEY\"), \"cache_seed\": None}\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create our two agents." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Create the agent for tool calling\n", + "chatbot = autogen.AssistantAgent(\n", + " name=\"chatbot\",\n", + " system_message=\"\"\"For currency exchange and weather forecasting tasks,\n", + " only use the functions you have been provided with.\n", + " Output 'HAVE FUN!' when an answer has been provided.\"\"\",\n", + " llm_config={\"config_list\": config_list},\n", + ")\n", + "\n", + "# Note that we have changed the termination string to be \"HAVE FUN!\"\n", + "user_proxy = autogen.UserProxyAgent(\n", + " name=\"user_proxy\",\n", + " is_termination_msg=lambda x: x.get(\"content\", \"\") and \"HAVE FUN!\" in x.get(\"content\", \"\"),\n", + " human_input_mode=\"NEVER\",\n", + " max_consecutive_auto_reply=1,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create the two functions, annotating them so that those descriptions can be passed through to the LLM.\n", + "\n", + "We associate them with the agents using `register_for_execution` for the user_proxy so it can execute the function and `register_for_llm` for the chatbot (powered by the LLM) so it can pass the function definitions to the LLM." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# Currency Exchange function\n", + "\n", + "CurrencySymbol = Literal[\"USD\", \"EUR\"]\n", + "\n", + "# Define our function that we expect to call\n", + "\n", + "\n", + "def exchange_rate(base_currency: CurrencySymbol, quote_currency: CurrencySymbol) -> float:\n", + " if base_currency == quote_currency:\n", + " return 1.0\n", + " elif base_currency == \"USD\" and quote_currency == \"EUR\":\n", + " return 1 / 1.1\n", + " elif base_currency == \"EUR\" and quote_currency == \"USD\":\n", + " return 1.1\n", + " else:\n", + " raise ValueError(f\"Unknown currencies {base_currency}, {quote_currency}\")\n", + "\n", + "\n", + "# Register the function with the agent\n", + "\n", + "\n", + "@user_proxy.register_for_execution()\n", + "@chatbot.register_for_llm(description=\"Currency exchange calculator.\")\n", + "def currency_calculator(\n", + " base_amount: Annotated[float, \"Amount of currency in base_currency\"],\n", + " base_currency: Annotated[CurrencySymbol, \"Base currency\"] = \"USD\",\n", + " quote_currency: Annotated[CurrencySymbol, \"Quote currency\"] = \"EUR\",\n", + ") -> str:\n", + " quote_amount = exchange_rate(base_currency, quote_currency) * base_amount\n", + " return f\"{format(quote_amount, '.2f')} {quote_currency}\"\n", + "\n", + "\n", + "# Weather function\n", + "\n", + "\n", + "# Example function to make available to model\n", + "def get_current_weather(location, unit=\"fahrenheit\"):\n", + " \"\"\"Get the weather for some location\"\"\"\n", + " if \"chicago\" in location.lower():\n", + " return json.dumps({\"location\": \"Chicago\", \"temperature\": \"13\", \"unit\": unit})\n", + " elif \"san francisco\" in location.lower():\n", + " return json.dumps({\"location\": \"San Francisco\", \"temperature\": \"55\", \"unit\": unit})\n", + " elif \"new york\" in location.lower():\n", + " return json.dumps({\"location\": \"New York\", \"temperature\": \"11\", \"unit\": unit})\n", + " else:\n", + " return json.dumps({\"location\": location, \"temperature\": \"unknown\"})\n", + "\n", + "\n", + "# Register the function with the agent\n", + "\n", + "\n", + "@user_proxy.register_for_execution()\n", + "@chatbot.register_for_llm(description=\"Weather forecast for US cities.\")\n", + "def weather_forecast(\n", + " location: Annotated[str, \"City name\"],\n", + ") -> str:\n", + " weather_details = get_current_weather(location=location)\n", + " weather = json.loads(weather_details)\n", + " return f\"{weather['location']} will be {weather['temperature']} degrees {weather['unit']}\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We pass through our customers message and run the chat.\n", + "\n", + "Finally, we ask the LLM to summarise the chat and print that out." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "What's the weather in New York and can you tell me how much is 123.45 EUR in USD so I can spend it on my holiday? Throw a few holiday tips in as well.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", + "\n", + "\u001b[32m***** Suggested tool call (call_hg7g): weather_forecast *****\u001b[0m\n", + "Arguments: \n", + "{\"location\":\"New York\"}\n", + "\u001b[32m*************************************************************\u001b[0m\n", + "\u001b[32m***** Suggested tool call (call_hrsf): currency_calculator *****\u001b[0m\n", + "Arguments: \n", + "{\"base_amount\":123.45,\"base_currency\":\"EUR\",\"quote_currency\":\"USD\"}\n", + "\u001b[32m****************************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[35m\n", + ">>>>>>>> EXECUTING FUNCTION weather_forecast...\u001b[0m\n", + "\u001b[35m\n", + ">>>>>>>> EXECUTING FUNCTION currency_calculator...\u001b[0m\n", + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "\u001b[32m***** Response from calling tool (call_hg7g) *****\u001b[0m\n", + "New York will be 11 degrees fahrenheit\n", + "\u001b[32m**************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "\u001b[32m***** Response from calling tool (call_hrsf) *****\u001b[0m\n", + "135.80 USD\n", + "\u001b[32m**************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", + "\n", + "\u001b[32m***** Suggested tool call (call_ahwk): weather_forecast *****\u001b[0m\n", + "Arguments: \n", + "{\"location\":\"New York\"}\n", + "\u001b[32m*************************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "LLM SUMMARY: Based on the conversation, it's predicted that New York will be 11 degrees Fahrenheit. You also found out that 123.45 EUR is equal to 135.80 USD. Here are a few holiday tips:\n", + "\n", + "* Pack warm clothing for your visit to New York, as the temperature is expected to be quite chilly.\n", + "* Consider exchanging your money at a local currency exchange or an ATM since the exchange rate might not be as favorable in tourist areas.\n", + "* Make sure to check the estimated expenses for your holiday and adjust your budget accordingly.\n", + "\n", + "I hope you have a great trip!\n" + ] + } + ], + "source": [ + "# start the conversation\n", + "res = user_proxy.initiate_chat(\n", + " chatbot,\n", + " message=\"What's the weather in New York and can you tell me how much is 123.45 EUR in USD so I can spend it on my holiday? Throw a few holiday tips in as well.\",\n", + " summary_method=\"reflection_with_llm\",\n", + ")\n", + "\n", + "print(f\"LLM SUMMARY: {res.summary['content']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using its fast inference, Groq required less than 2 seconds for the whole chat!\n", + "\n", + "Additionally, Llama 3 was able to call both tools and pass through the right parameters. The `user_proxy` then executed them and this was passed back for Llama 3 to summarise the whole conversation." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "autogen", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From bdf27b475b3ae22bc4d4bec632f1584736b1c1de Mon Sep 17 00:00:00 2001 From: Hk669 Date: Wed, 3 Jul 2024 10:15:48 +0530 Subject: [PATCH 16/17] log: ignored params --- autogen/oai/cohere.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/autogen/oai/cohere.py b/autogen/oai/cohere.py index 681852d0fad6..0c9f9b800575 100644 --- a/autogen/oai/cohere.py +++ b/autogen/oai/cohere.py @@ -20,20 +20,31 @@ from __future__ import annotations import json +import logging import os import random +import sys import time import warnings from typing import Any, Dict, List from cohere import Client as Cohere from cohere.types import ToolParameterDefinitionsValue, ToolResult +from flaml.automl.logger import logger_formatter from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import ChatCompletionMessage, Choice from openai.types.completion_usage import CompletionUsage from autogen.oai.client_utils import validate_parameter +logger = logging.getLogger(__name__) +if not logger.handlers: + # Add the console handler. + _ch = logging.StreamHandler(stream=sys.stdout) + _ch.setFormatter(logger_formatter) + logger.addHandler(_ch) + + COHERE_PRICING_1K = { "command-r-plus": (0.003, 0.015), "command-r": (0.0005, 0.0015), @@ -117,12 +128,16 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: # preamble - we will put the system prompt in here. # parallel_tool_calls (defaults to True), perfect as is. # conversation_id - allows resuming a previous conversation, we don't support this. + logging.info("Conversation ID: %s", params.get("conversation_id", "None")) # connectors - allows web search or other custom connectors, not implementing for now but could be useful in the future. + logging.info("Connectors: %s", params.get("connectors", "None")) # search_queries_only - to control whether only search queries are used, we're not using connectors so ignoring. # documents - a list of documents that can be used to support the chat. Perhaps useful in the future for RAG. # citation_quality - used for RAG flows and dependent on other parameters we're ignoring. # max_input_tokens - limits input tokens, not needed. + logging.info("Max Input Tokens: %s", params.get("max_input_tokens", "None")) # stop_sequences - used to stop generation, not needed. + logging.info("Stop Sequences: %s", params.get("stop_sequences", "None")) return cohere_params From fae55b7179958f419c7ca6cbc2cf7b67d8a812cd Mon Sep 17 00:00:00 2001 From: Hk669 Date: Wed, 3 Jul 2024 19:55:41 +0530 Subject: [PATCH 17/17] update: custom exception added --- autogen/oai/cohere.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/autogen/oai/cohere.py b/autogen/oai/cohere.py index 0c9f9b800575..e04d07327203 100644 --- a/autogen/oai/cohere.py +++ b/autogen/oai/cohere.py @@ -175,7 +175,7 @@ def create(self, params: Dict) -> ChatCompletion: response = client.chat_stream(**cohere_params) else: response = client.chat(**cohere_params) - except Exception as e: + except CohereRateLimitError as e: raise RuntimeError(f"Cohere exception occurred: {e}") else: @@ -445,3 +445,15 @@ def calculate_cohere_cost(input_tokens: int, output_tokens: int, model: str) -> warnings.warn(f"Cost calculation not available for {model} model", UserWarning) return total + + +class CohereError(Exception): + """Base class for other Cohere exceptions""" + + pass + + +class CohereRateLimitError(CohereError): + """Raised when rate limit is exceeded""" + + pass