From 017e906b35ce1702a5ced57cdb413bfc0f92e03d Mon Sep 17 00:00:00 2001
From: Mark Sze
Date: Sun, 23 Jun 2024 06:32:08 +0000
Subject: [PATCH 1/5] Groq Client Class - main class and setup, except tests

---
 .github/workflows/contrib-tests.yml |  40 +++++
 autogen/logger/file_logger.py       |   3 +-
 autogen/logger/sqlite_logger.py     |   3 +-
 autogen/oai/client.py               |  11 ++
 autogen/oai/groq.py                 | 269 ++++++++++++++++++++++++++++
 autogen/runtime_logging.py          |   3 +-
 setup.py                            |   1 +
 7 files changed, 327 insertions(+), 3 deletions(-)
 create mode 100644 autogen/oai/groq.py

diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml
index 7d8a932b0254..895e810022de 100644
--- a/.github/workflows/contrib-tests.yml
+++ b/.github/workflows/contrib-tests.yml
@@ -598,3 +598,43 @@ jobs:
       with:
         file: ./coverage.xml
         flags: unittests
+
+  GroqTest:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest, macos-latest, windows-2019]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        exclude:
+          - os: macos-latest
+            python-version: "3.9"
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          lfs: true
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install packages and dependencies for all tests
+        run: |
+          python -m pip install --upgrade pip wheel
+          pip install pytest-cov>=5
+      - name: Install packages and dependencies for Groq
+        run: |
+          pip install -e .[groq,test]
+      - name: Set AUTOGEN_USE_DOCKER based on OS
+        shell: bash
+        run: |
+          if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
+            echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
+          fi
+      - name: Coverage
+        run: |
+          pytest test/oai/test_groq.py --skip-openai
+      - name: Upload coverage to Codecov
+        uses: codecov/codecov-action@v3
+        with:
+          file: ./coverage.xml
+          flags: unittests

diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py
index af5583587f66..cdebbdc0eb79 100644
--- a/autogen/logger/file_logger.py
+++ b/autogen/logger/file_logger.py
@@ -19,6 +19,7 @@
     from autogen import Agent, ConversableAgent, OpenAIWrapper
     from autogen.oai.anthropic import AnthropicClient
     from autogen.oai.gemini import GeminiClient
+    from autogen.oai.groq import GroqClient
     from autogen.oai.mistral import MistralAIClient
     from autogen.oai.together import TogetherClient

@@ -204,7 +205,7 @@ def log_new_wrapper(

     def log_new_client(
         self,
-        client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient,
+        client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient | GroqClient,
         wrapper: OpenAIWrapper,
         init_args: Dict[str, Any],
     ) -> None:

diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py
index 969a943017e3..ccde6bd1d81b 100644
--- a/autogen/logger/sqlite_logger.py
+++ b/autogen/logger/sqlite_logger.py
@@ -20,6 +20,7 @@
     from autogen import Agent, ConversableAgent, OpenAIWrapper
     from autogen.oai.anthropic import AnthropicClient
     from autogen.oai.gemini import GeminiClient
+    from autogen.oai.groq import GroqClient
    from autogen.oai.mistral import MistralAIClient
     from autogen.oai.together import TogetherClient

@@ -391,7 +392,7 @@ def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[st

     def log_new_client(
         self,
-        client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient],
+        client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient],
         wrapper: OpenAIWrapper,
         init_args: Dict[str, Any],
     ) -> None:

diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index 2c14ca0d4a0c..e8bb0535b481 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -70,6 +70,13 @@
 except ImportError as e:
     together_import_exception = e

+try:
+    from autogen.oai.groq import GroqClient
+
+    groq_import_exception: Optional[ImportError] = None
+except ImportError as e:
+    groq_import_exception = e
+
 logger = logging.getLogger(__name__)
 if not logger.handlers:
     # Add the console handler.
@@ -484,6 +491,10 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
             if together_import_exception:
                 raise ImportError("Please install `together` to use the Together.AI API.")
             self._clients.append(TogetherClient(**config))
+        elif api_type is not None and api_type.startswith("groq"):
+            if groq_import_exception:
+                raise ImportError("Please install `groq` to use the Groq API.")
+            self._clients.append(GroqClient(**config))
         else:
             client = OpenAI(**openai_config)
             self._clients.append(OpenAIClient(client))

diff --git a/autogen/oai/groq.py b/autogen/oai/groq.py
new file mode 100644
index 000000000000..a05315f24dcc
--- /dev/null
+++ b/autogen/oai/groq.py
@@ -0,0 +1,269 @@
+"""Create an OpenAI-compatible client using Groq's API.
+
+Example:
+    llm_config={
+        "config_list": [{
+            "api_type": "groq",
+            "model": "mixtral-8x7b-32768",
+            "api_key": os.environ.get("GROQ_API_KEY")
+            }
+    ]}
+
+    agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
+
+Install Groq's Python library using: pip install --upgrade groq
+
+Resources:
+- https://console.groq.com/docs/quickstart
+"""
+
+from __future__ import annotations
+
+import copy
+import os
+import time
+import warnings
+from typing import Any, Dict, List
+
+from groq import Groq
+from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+from openai.types.completion_usage import CompletionUsage
+
+from autogen.oai.client_utils import should_hide_tools, validate_parameter
+
+
+class GroqClient:
+    """Client for Groq's API."""
+
+    def __init__(self, **kwargs):
+        """Requires api_key or the GROQ_API_KEY environment variable to be set.
+
+        Args:
+            api_key (str): The API key for using Groq (or the environment variable GROQ_API_KEY needs to be set)
+        """
+        # Ensure we have the api_key upon instantiation
+        self.api_key = kwargs.get("api_key", None)
+        if not self.api_key:
+            self.api_key = os.getenv("GROQ_API_KEY")
+
+        assert (
+            self.api_key
+        ), "Please include the api_key in your config list entry for Groq or set the GROQ_API_KEY env variable."
+
+    def message_retrieval(self, response) -> List:
+        """
+        Retrieve and return a list of strings or a list of Choice.Message from the response.
+
+        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
+        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
+        """
+        return [choice.message for choice in response.choices]
+
+    def cost(self, response) -> float:
+        return response.cost
+
+    @staticmethod
+    def get_usage(response) -> Dict:
+        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
+        return {
+            "prompt_tokens": response.usage.prompt_tokens,
+            "completion_tokens": response.usage.completion_tokens,
+            "total_tokens": response.usage.total_tokens,
+            "cost": response.cost,
+            "model": response.model,
+        }
+
+    def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+        """Loads the parameters for the Groq API from the passed-in parameters and returns a validated set. Checks types, ranges, and sets defaults."""
+        groq_params = {}
+
+        # Check that we have what we need to use Groq's API
+        # We won't enforce the available models as they are likely to change
+        groq_params["model"] = params.get("model", None)
+        assert groq_params[
+            "model"
+        ], "Please specify the 'model' in your config list entry to nominate the Groq model to use."
+
+        # Validate allowed Groq parameters
+        # https://console.groq.com/docs/api-reference#chat
+        groq_params["frequency_penalty"] = validate_parameter(
+            params, "frequency_penalty", (int, float), True, None, (-2, 2), None
+        )
+        groq_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
+        groq_params["presence_penalty"] = validate_parameter(
+            params, "presence_penalty", (int, float), True, None, (-2, 2), None
+        )
+        groq_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
+        groq_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None)
+        groq_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 1, (0, 2), None)
+        groq_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
+
+        # Groq parameters not supported by their models yet, ignoring:
+        # logit_bias, logprobs, top_logprobs
+
+        # Groq parameters we are ignoring:
+        # n (must be 1), response_format (to enforce JSON but needs prompting as well), user,
+        # parallel_tool_calls (defaults to True), stop
+        # function_call (deprecated), functions (deprecated)
+        # tool_choice (none if no tools, auto if there are tools)
+
+        # Check if they want to stream and use tools, which isn't currently supported (TODO)
+        if groq_params["stream"] and "tools" in params:
+            warnings.warn(
+                "Streaming is not supported when using tools, streaming will be disabled.",
+                UserWarning,
+            )
+
+            groq_params["stream"] = False
+
+        return groq_params
+
+    def create(self, params: Dict) -> ChatCompletion:
+
+        messages = params.get("messages", [])
+
+        # Convert AutoGen messages to Groq messages
+        groq_messages = oai_messages_to_groq_messages(messages)
+
+        # Parse parameters to the Groq API's parameters
+        groq_params = self.parse_params(params)
+
+        # Add tools to the call if we have them and aren't hiding them
+        if "tools" in params:
+            hide_tools = validate_parameter(
+                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
+            )
+            if not should_hide_tools(groq_messages, params["tools"], hide_tools):
+                groq_params["tools"] = params["tools"]
+
+        groq_params["messages"] = groq_messages
+
+        # We use chat model by default
+        client = Groq(api_key=self.api_key)
+
+        # Token counts will be returned
+        prompt_tokens = 0
+        completion_tokens = 0
+        total_tokens = 0
+
+        max_retries = 5
+        for attempt in range(max_retries):
+            ans = None
+            try:
+                response = client.chat.completions.create(**groq_params)
+            except Exception as e:
+                raise RuntimeError(f"Groq exception occurred: {e}")
+            else:
+
+                if groq_params["stream"]:
+                    # Read in the chunks as they stream
+                    ans = ""
+                    for chunk in response:
+                        ans = ans + (chunk.choices[0].delta.content or "")
+
+                        prompt_tokens = chunk.usage.prompt_tokens
+                        completion_tokens = chunk.usage.completion_tokens
+                        total_tokens = chunk.usage.total_tokens
+                else:
+                    # Non-streaming
+                    ans: str = response.choices[0].message.content
+
+                    prompt_tokens = response.usage.prompt_tokens
+                    completion_tokens = response.usage.completion_tokens
+                    total_tokens = response.usage.total_tokens
+                break
+
+        if response is not None:
+            # If we have tool calls as the response, populate completed tool calls for our return OAI response
+            if response.choices[0].finish_reason == "tool_calls":
+                groq_finish = "tool_calls"
+                tool_calls = []
+                for tool_call in response.choices[0].message.tool_calls:
+                    tool_calls.append(
+                        ChatCompletionMessageToolCall(
+                            id=tool_call.id,
+                            function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
+                            type="function",
+                        )
+                    )
+            else:
+                groq_finish = "stop"
+                tool_calls = None
+
+        else:
+            raise RuntimeError(f"Failed to get response from Groq after retrying {attempt + 1} times.")
+
+        # 3. convert output
+        message = ChatCompletionMessage(
+            role="assistant",
+            content=response.choices[0].message.content,
+            function_call=None,
+            tool_calls=tool_calls,
+        )
+        choices = [Choice(finish_reason=groq_finish, index=0, message=message)]
+
+        response_oai = ChatCompletion(
+            id=response.id,
+            model=groq_params["model"],
+            created=int(time.time()),
+            object="chat.completion",
+            choices=choices,
+            usage=CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            ),
+            cost=calculate_groq_cost(prompt_tokens, completion_tokens, groq_params["model"]),
+        )
+
+        return response_oai
+
+
+def oai_messages_to_groq_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
+    """Convert messages from OAI format to Groq's format.
+    We correct for any specific role orders and types.
+ """ + + groq_messages = copy.deepcopy(messages) + + # If we have a message with role='tool', which occurs when a function is executed, change it to 'user' + """ + for msg in together_messages: + if "role" in msg and msg["role"] == "tool": + msg["role"] = "user" + """ + + # Remove the name field + for message in groq_messages: + if "name" in message: + message.pop("name", None) + + return groq_messages + + +# PRICING + +# Cost per million tokens - Input / Output +GROQ_PRICING_1M = { + "llama3-70b-8192": (0.59, 0.79), + "mixtral-8x7b-32768": (0.24, 0.24), + "llama3-8b-8192": (0.05, 0.08), + "gemma-7b-it": (0.07, 0.07), +} + + +def calculate_groq_cost(input_tokens: int, output_tokens: int, model: str) -> float: + """Calculate the cost of the completion using the Groq pricing.""" + total = 0.0 + + if model in GROQ_PRICING_1M: + input_cost_per_mil, output_cost_per_mil = GROQ_PRICING_1M[model] + input_cost = (input_tokens / 1000000) * input_cost_per_mil + output_cost = (output_tokens / 1000000) * output_cost_per_mil + total = input_cost + output_cost + else: + warnings.warn(f"Cost calculation not available for model {model}", UserWarning) + + return total diff --git a/autogen/runtime_logging.py b/autogen/runtime_logging.py index adb55ba63b4f..4ad76cf5b7da 100644 --- a/autogen/runtime_logging.py +++ b/autogen/runtime_logging.py @@ -15,6 +15,7 @@ from autogen import Agent, ConversableAgent, OpenAIWrapper from autogen.oai.anthropic import AnthropicClient from autogen.oai.gemini import GeminiClient + from autogen.oai.groq import GroqClient from autogen.oai.mistral import MistralAIClient from autogen.oai.together import TogetherClient @@ -110,7 +111,7 @@ def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig def log_new_client( - client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient], + client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient], wrapper: OpenAIWrapper, init_args: Dict[str, Any], ) -> None: diff --git a/setup.py b/setup.py index 738e09d9061c..9a67c70f49de 100644 --- a/setup.py +++ b/setup.py @@ -91,6 +91,7 @@ "long-context": ["llmlingua<0.3"], "anthropic": ["anthropic>=0.23.1"], "mistral": ["mistralai>=0.2.0"], + "groq": ["groq>=0.9.0"], } setuptools.setup( From 76b8ef1ff41887acd318428d4898f365750dd08f Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Sun, 23 Jun 2024 20:12:53 +0000 Subject: [PATCH 2/5] Change pricing per K, added tests --- autogen/oai/groq.py | 27 ++--- test/oai/test_groq.py | 249 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 261 insertions(+), 15 deletions(-) create mode 100644 test/oai/test_groq.py diff --git a/autogen/oai/groq.py b/autogen/oai/groq.py index a05315f24dcc..20dd558de3c9 100644 --- a/autogen/oai/groq.py +++ b/autogen/oai/groq.py @@ -32,6 +32,14 @@ from autogen.oai.client_utils import should_hide_tools, validate_parameter +# Cost per thousand tokens - Input / Output (NOTE: Convert $/Million to $/K) +GROQ_PRICING_1K = { + "llama3-70b-8192": (0.00059, 0.00079), + "mixtral-8x7b-32768": (0.00024, 0.00024), + "llama3-8b-8192": (0.00005, 0.00008), + "gemma-7b-it": (0.00007, 0.00007), +} + class GroqClient: """Client for Groq's API.""" @@ -243,25 +251,14 @@ def oai_messages_to_groq_messages(messages: list[Dict[str, Any]]) -> list[dict[s return groq_messages -# PRICING - -# Cost per million tokens - Input / Output -GROQ_PRICING_1M = { - "llama3-70b-8192": (0.59, 0.79), - "mixtral-8x7b-32768": (0.24, 0.24), - "llama3-8b-8192": 
-    "gemma-7b-it": (0.07, 0.07),
-}
-
-
 def calculate_groq_cost(input_tokens: int, output_tokens: int, model: str) -> float:
     """Calculate the cost of the completion using the Groq pricing."""
     total = 0.0

-    if model in GROQ_PRICING_1M:
-        input_cost_per_mil, output_cost_per_mil = GROQ_PRICING_1M[model]
-        input_cost = (input_tokens / 1000000) * input_cost_per_mil
-        output_cost = (output_tokens / 1000000) * output_cost_per_mil
+    if model in GROQ_PRICING_1K:
+        input_cost_per_k, output_cost_per_k = GROQ_PRICING_1K[model]
+        input_cost = (input_tokens / 1000) * input_cost_per_k
+        output_cost = (output_tokens / 1000) * output_cost_per_k
         total = input_cost + output_cost
     else:
         warnings.warn(f"Cost calculation not available for model {model}", UserWarning)

diff --git a/test/oai/test_groq.py b/test/oai/test_groq.py
new file mode 100644
index 000000000000..f55edbd8c7a6
--- /dev/null
+++ b/test/oai/test_groq.py
@@ -0,0 +1,249 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+try:
+    from autogen.oai.groq import GroqClient, calculate_groq_cost
+
+    skip = False
+except ImportError:
+    GroqClient = object
+    InternalServerError = object
+    skip = True
+
+
+# Fixtures for mock data
+@pytest.fixture
+def mock_response():
+    class MockResponse:
+        def __init__(self, text, choices, usage, cost, model):
+            self.text = text
+            self.choices = choices
+            self.usage = usage
+            self.cost = cost
+            self.model = model
+
+    return MockResponse
+
+
+@pytest.fixture
+def groq_client():
+    return GroqClient(api_key="fake_api_key")
+
+
+skip_reason = "Groq dependency is not installed"
+
+
+# Test initialization and configuration
+@pytest.mark.skipif(skip, reason=skip_reason)
+def test_initialization():
+
+    # Missing any api_key
+    with pytest.raises(AssertionError) as assertinfo:
+        GroqClient()  # Should raise an AssertionError due to missing api_key
+
+    assert "Please include the api_key in your config list entry for Groq or set the GROQ_API_KEY env variable." in str(
+        assertinfo.value
+    )
+
+    # Creation works
+    GroqClient(api_key="fake_api_key")  # Should create okay now.
+
+
+# Test standard initialization
+@pytest.mark.skipif(skip, reason=skip_reason)
+def test_valid_initialization(groq_client):
+    assert groq_client.api_key == "fake_api_key", "Config api_key should be correctly set"
+
+
+# Test parameters
+@pytest.mark.skipif(skip, reason=skip_reason)
+def test_parsing_params(groq_client):
+    # All parameters
+    params = {
+        "model": "llama3-8b-8192",
+        "frequency_penalty": 1.5,
+        "presence_penalty": 1.5,
+        "max_tokens": 1000,
+        "seed": 42,
+        "stream": False,
+        "temperature": 1,
+        "top_p": 0.8,
+    }
+    expected_params = {
+        "model": "llama3-8b-8192",
+        "frequency_penalty": 1.5,
+        "presence_penalty": 1.5,
+        "max_tokens": 1000,
+        "seed": 42,
+        "stream": False,
+        "temperature": 1,
+        "top_p": 0.8,
+    }
+    result = groq_client.parse_params(params)
+    assert result == expected_params
+
+    # Only model, others set as defaults
+    params = {
+        "model": "llama3-8b-8192",
+    }
+    expected_params = {
+        "model": "llama3-8b-8192",
+        "frequency_penalty": None,
+        "presence_penalty": None,
+        "max_tokens": None,
+        "seed": None,
+        "stream": False,
+        "temperature": 1,
+        "top_p": None,
+    }
+    result = groq_client.parse_params(params)
+    assert result == expected_params
+
+    # Incorrect types, defaults should be set, will show warnings but not trigger assertions
+    params = {
+        "model": "llama3-8b-8192",
+        "frequency_penalty": "1.5",
+        "presence_penalty": "1.5",
+        "max_tokens": "1000",
+        "seed": "42",
+        "stream": "False",
+        "temperature": "1",
+        "top_p": "0.8",
+    }
+    result = groq_client.parse_params(params)
+    assert result == expected_params
+
+    # Values outside bounds, should warn and set to defaults
+    params = {
+        "model": "llama3-8b-8192",
+        "frequency_penalty": 5000,
+        "presence_penalty": -500,
+        "temperature": 3,
+    }
+    result = groq_client.parse_params(params)
+    assert result == expected_params
+
+    # No model
+    params = {
+        "frequency_penalty": 1,
+    }
+
+    with pytest.raises(AssertionError) as assertinfo:
+        result = groq_client.parse_params(params)
+
+    assert "Please specify the 'model' in your config list entry to nominate the Groq model to use." in str(
+        assertinfo.value
+    )
+
+
+# Test cost calculation
+@pytest.mark.skipif(skip, reason=skip_reason)
+def test_cost_calculation(mock_response):
+    response = mock_response(
+        text="Example response",
+        choices=[{"message": "Test message 1"}],
+        usage={"prompt_tokens": 500, "completion_tokens": 300, "total_tokens": 800},
+        cost=None,
+        model="llama3-70b-8192",
+    )
+    assert (
+        calculate_groq_cost(response.usage["prompt_tokens"], response.usage["completion_tokens"], response.model)
+        == 0.000532
+    ), "Cost for this should be $0.000532"
+
+
+# Test text generation
+@pytest.mark.skipif(skip, reason=skip_reason)
+@patch("autogen.oai.groq.GroqClient.create")
+def test_create_response(mock_chat, groq_client):
+    # Mock GroqClient.chat response
+    mock_groq_response = MagicMock()
+    mock_groq_response.choices = [
+        MagicMock(finish_reason="stop", message=MagicMock(content="Example Groq response", tool_calls=None))
+    ]
+    mock_groq_response.id = "mock_groq_response_id"
+    mock_groq_response.model = "llama3-70b-8192"
+    mock_groq_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20)  # Example token usage
+
+    mock_chat.return_value = mock_groq_response
+
+    # Test parameters
+    params = {
+        "messages": [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "World"}],
+        "model": "llama3-70b-8192",
+    }
+
+    # Call the create method
+    response = groq_client.create(params)
+
+    # Assertions to check if response is structured as expected
+    assert (
+        response.choices[0].message.content == "Example Groq response"
+    ), "Response content should match expected output"
+    assert response.id == "mock_groq_response_id", "Response ID should match the mocked response ID"
+    assert response.model == "llama3-70b-8192", "Response model should match the mocked response model"
+    assert response.usage.prompt_tokens == 10, "Response prompt tokens should match the mocked response usage"
+    assert response.usage.completion_tokens == 20, "Response completion tokens should match the mocked response usage"
+
+
+# Test functions/tools
+@pytest.mark.skipif(skip, reason=skip_reason)
+@patch("autogen.oai.groq.GroqClient.create")
+def test_create_response_with_tool_call(mock_chat, groq_client):
+    # Mock `groq_response = client.chat(**groq_params)`
+    mock_function = MagicMock(name="currency_calculator")
+    mock_function.name = "currency_calculator"
+    mock_function.arguments = '{"base_currency": "EUR", "quote_currency": "USD", "base_amount": 123.45}'
+
+    mock_function_2 = MagicMock(name="get_weather")
+    mock_function_2.name = "get_weather"
+    mock_function_2.arguments = '{"location": "Chicago"}'
+
+    mock_chat.return_value = MagicMock(
+        choices=[
+            MagicMock(
+                finish_reason="tool_calls",
+                message=MagicMock(
+                    content="Sample text about the functions",
+                    tool_calls=[
+                        MagicMock(id="gdRdrvnHh", function=mock_function),
+                        MagicMock(id="abRdrvnHh", function=mock_function_2),
+                    ],
+                ),
+            )
+        ],
+        id="mock_groq_response_id",
+        model="llama3-70b-8192",
+        usage=MagicMock(prompt_tokens=10, completion_tokens=20),
+    )
+
+    # Construct parameters
+    converted_functions = [
+        {
+            "type": "function",
+            "function": {
+                "description": "Currency exchange calculator.",
+                "name": "currency_calculator",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "base_amount": {"type": "number", "description": "Amount of currency in base_currency"},
+                    },
+                    "required": ["base_amount"],
+                },
+            },
+        }
+    ]
+    groq_messages = [
+        {"role": "user", "content": "How much is 123.45 EUR in USD?"},
+        {"role": "assistant", "content": "World"},
+    ]
+
+    # Call the create method
+    response = groq_client.create({"messages": groq_messages, "tools": converted_functions, "model": "llama3-70b-8192"})
+
+    # Assertions to check if the functions and content are included in the response
+    assert response.choices[0].message.content == "Sample text about the functions"
+    assert response.choices[0].message.tool_calls[0].function.name == "currency_calculator"
+    assert response.choices[0].message.tool_calls[1].function.name == "get_weather"

From e89231e2aa2646f5951fd56b20e16940c98c33e2 Mon Sep 17 00:00:00 2001
From: Mark Sze
Date: Sun, 23 Jun 2024 20:49:16 +0000
Subject: [PATCH 3/5] Streaming support, including with tool calling

---
 autogen/oai/groq.py | 88 +++++++++++++++++++++++++++++----------------
 1 file changed, 57 insertions(+), 31 deletions(-)

diff --git a/autogen/oai/groq.py b/autogen/oai/groq.py
index 20dd558de3c9..412370ebbe0c 100644
--- a/autogen/oai/groq.py
+++ b/autogen/oai/groq.py
@@ -25,7 +25,7 @@
 import warnings
 from typing import Any, Dict, List

-from groq import Groq
+from groq import Groq, Stream
 from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
 from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
 from openai.types.completion_usage import CompletionUsage
@@ -117,15 +117,6 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
         # function_call (deprecated), functions (deprecated)
         # tool_choice (none if no tools, auto if there are tools)

-        # Check if they want to stream and use tools, which isn't currently supported (TODO)
-        if groq_params["stream"] and "tools" in params:
-            warnings.warn(
-                "Streaming is not supported when using tools, streaming will be disabled.",
-                UserWarning,
-            )
-
-            groq_params["stream"] = False
-
         return groq_params

     def create(self, params: Dict) -> ChatCompletion:
@@ -156,6 +147,9 @@ def create(self, params: Dict) -> ChatCompletion:
         completion_tokens = 0
         total_tokens = 0

+        # Streaming tool call recommendations
+        streaming_tool_calls = []
+
         max_retries = 5
         for attempt in range(max_retries):
             ans = None
@@ -166,16 +160,32 @@ def create(self, params: Dict) -> ChatCompletion:
             else:

                 if groq_params["stream"]:
-                    # Read in the chunks as they stream
+                    # Read in the chunks as they stream, taking in tool_calls which may be across
+                    # multiple chunks if more than one suggested
                     ans = ""
                     for chunk in response:
                         ans = ans + (chunk.choices[0].delta.content or "")

-                        prompt_tokens = chunk.usage.prompt_tokens
-                        completion_tokens = chunk.usage.completion_tokens
-                        total_tokens = chunk.usage.total_tokens
+                        if chunk.choices[0].delta.tool_calls:
+                            # We have a tool call recommendation
+                            for tool_call in chunk.choices[0].delta.tool_calls:
+                                streaming_tool_calls.append(
+                                    ChatCompletionMessageToolCall(
+                                        id=tool_call.id,
+                                        function={
+                                            "name": tool_call.function.name,
+                                            "arguments": tool_call.function.arguments,
+                                        },
+                                        type="function",
+                                    )
+                                )
+
+                        if chunk.choices[0].finish_reason:
+                            prompt_tokens = chunk.x_groq.usage.prompt_tokens
+                            completion_tokens = chunk.x_groq.usage.completion_tokens
+                            total_tokens = chunk.x_groq.usage.total_tokens
                 else:
-                    # Non-streaming
+                    # Non-streaming finished
                     ans: str = response.choices[0].message.content

                     prompt_tokens = response.usage.prompt_tokens
@@ -184,36 +194,52 @@ def create(self, params: Dict) -> ChatCompletion:
                 break

         if response is not None:
-            # If we have tool calls as the response, populate completed tool calls for our return OAI response
-            if response.choices[0].finish_reason == "tool_calls":
-                groq_finish = "tool_calls"
-                tool_calls = []
-                for tool_call in response.choices[0].message.tool_calls:
-                    tool_calls.append(
-                        ChatCompletionMessageToolCall(
-                            id=tool_call.id,
-                            function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
-                            type="function",
-                        )
-                    )
+
+            if isinstance(response, Stream):
+                # Streaming response
+                if chunk.choices[0].finish_reason == "tool_calls":
+                    groq_finish = "tool_calls"
+                    tool_calls = streaming_tool_calls
+                else:
+                    groq_finish = "stop"
+                    tool_calls = None
+
+                response_content = ans
+                response_id = chunk.id
             else:
-                groq_finish = "stop"
-                tool_calls = None
+                # Non-streaming response
+                # If we have tool calls as the response, populate completed tool calls for our return OAI response
+                if response.choices[0].finish_reason == "tool_calls":
+                    groq_finish = "tool_calls"
+                    tool_calls = []
+                    for tool_call in response.choices[0].message.tool_calls:
+                        tool_calls.append(
+                            ChatCompletionMessageToolCall(
+                                id=tool_call.id,
+                                function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
+                                type="function",
+                            )
+                        )
+                else:
+                    groq_finish = "stop"
+                    tool_calls = None
+                response_content = response.choices[0].message.content
+                response_id = response.id

         else:
             raise RuntimeError(f"Failed to get response from Groq after retrying {attempt + 1} times.")

         # 3. convert output
         message = ChatCompletionMessage(
             role="assistant",
-            content=response.choices[0].message.content,
+            content=response_content,
             function_call=None,
             tool_calls=tool_calls,
         )
         choices = [Choice(finish_reason=groq_finish, index=0, message=message)]

         response_oai = ChatCompletion(
-            id=response.id,
+            id=response_id,
             model=groq_params["model"],
             created=int(time.time()),
             object="chat.completion",

From a181aa1ffc3b79129ec0bf145129ccf63a18a513 Mon Sep 17 00:00:00 2001
From: Mark Sze
Date: Wed, 26 Jun 2024 02:32:44 +0000
Subject: [PATCH 4/5] Used Groq retries instead of loop, thanks Gal-Gilor!
---
 autogen/oai/groq.py | 79 ++++++++++++++++++++++-----------------------
 1 file changed, 38 insertions(+), 41 deletions(-)

diff --git a/autogen/oai/groq.py b/autogen/oai/groq.py
index 412370ebbe0c..a97240887c8e 100644
--- a/autogen/oai/groq.py
+++ b/autogen/oai/groq.py
@@ -139,8 +139,8 @@ def create(self, params: Dict) -> ChatCompletion:

         groq_params["messages"] = groq_messages

-        # We use chat model by default
-        client = Groq(api_key=self.api_key)
+        # We use chat model by default, and set max_retries to 5 (in line with typical retries loop)
+        client = Groq(api_key=self.api_key, max_retries=5)

         # Token counts will be returned
         prompt_tokens = 0
@@ -150,48 +150,45 @@ def create(self, params: Dict) -> ChatCompletion:
         # Streaming tool call recommendations
         streaming_tool_calls = []

-        max_retries = 5
-        for attempt in range(max_retries):
-            ans = None
-            try:
-                response = client.chat.completions.create(**groq_params)
-            except Exception as e:
-                raise RuntimeError(f"Groq exception occurred: {e}")
-            else:
+        ans = None
+        try:
+            response = client.chat.completions.create(**groq_params)
+        except Exception as e:
+            raise RuntimeError(f"Groq exception occurred: {e}")
+        else:

-                if groq_params["stream"]:
-                    # Read in the chunks as they stream, taking in tool_calls which may be across
-                    # multiple chunks if more than one suggested
-                    ans = ""
-                    for chunk in response:
-                        ans = ans + (chunk.choices[0].delta.content or "")
-
-                        if chunk.choices[0].delta.tool_calls:
-                            # We have a tool call recommendation
-                            for tool_call in chunk.choices[0].delta.tool_calls:
-                                streaming_tool_calls.append(
-                                    ChatCompletionMessageToolCall(
-                                        id=tool_call.id,
-                                        function={
-                                            "name": tool_call.function.name,
-                                            "arguments": tool_call.function.arguments,
-                                        },
-                                        type="function",
-                                    )
+            if groq_params["stream"]:
+                # Read in the chunks as they stream, taking in tool_calls which may be across
+                # multiple chunks if more than one suggested
+                ans = ""
+                for chunk in response:
+                    ans = ans + (chunk.choices[0].delta.content or "")
+
+                    if chunk.choices[0].delta.tool_calls:
+                        # We have a tool call recommendation
+                        for tool_call in chunk.choices[0].delta.tool_calls:
+                            streaming_tool_calls.append(
+                                ChatCompletionMessageToolCall(
+                                    id=tool_call.id,
+                                    function={
+                                        "name": tool_call.function.name,
+                                        "arguments": tool_call.function.arguments,
+                                    },
+                                    type="function",
                                 )
+                            )

-                        if chunk.choices[0].finish_reason:
-                            prompt_tokens = chunk.x_groq.usage.prompt_tokens
-                            completion_tokens = chunk.x_groq.usage.completion_tokens
-                            total_tokens = chunk.x_groq.usage.total_tokens
-                else:
-                    # Non-streaming finished
-                    ans: str = response.choices[0].message.content
+                    if chunk.choices[0].finish_reason:
+                        prompt_tokens = chunk.x_groq.usage.prompt_tokens
+                        completion_tokens = chunk.x_groq.usage.completion_tokens
+                        total_tokens = chunk.x_groq.usage.total_tokens
+            else:
+                # Non-streaming finished
+                ans: str = response.choices[0].message.content

-                    prompt_tokens = response.usage.prompt_tokens
-                    completion_tokens = response.usage.completion_tokens
-                    total_tokens = response.usage.total_tokens
-                break
+                prompt_tokens = response.usage.prompt_tokens
+                completion_tokens = response.usage.completion_tokens
+                total_tokens = response.usage.total_tokens

         if response is not None:
@@ -227,7 +224,7 @@ def create(self, params: Dict) -> ChatCompletion:
                 response_content = response.choices[0].message.content
                 response_id = response.id
         else:
-            raise RuntimeError(f"Failed to get response from Groq after retrying {attempt + 1} times.")
+            raise RuntimeError("Failed to get response from Groq after retrying 5 times.")

From 6f09479838ea42b3735e18600eb8832fa3effed7 Mon Sep 17 00:00:00 2001
From: Mark Sze
Date: Thu, 27 Jun 2024 19:23:29 +0000
Subject: [PATCH 5/5] Fixed bug when using logging.

---
 autogen/oai/client.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index e8bb0535b481..a7b12ce83dad 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -490,11 +490,13 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
         elif api_type is not None and api_type.startswith("together"):
             if together_import_exception:
                 raise ImportError("Please install `together` to use the Together.AI API.")
-            self._clients.append(TogetherClient(**config))
+            client = TogetherClient(**openai_config)
+            self._clients.append(client)
         elif api_type is not None and api_type.startswith("groq"):
             if groq_import_exception:
                 raise ImportError("Please install `groq` to use the Groq API.")
-            self._clients.append(GroqClient(**config))
+            client = GroqClient(**openai_config)
+            self._clients.append(client)
         else:
             client = OpenAI(**openai_config)
             self._clients.append(OpenAIClient(client))
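
---

Usage note (not part of the patches): a minimal end-to-end sketch of the client this series adds. It assumes the series is applied, the extra is installed via `pip install -e .[groq]`, and GROQ_API_KEY is set; the agent names, prompt, and reply settings below are illustrative only, not taken from the patches.

    import os

    import autogen

    config_list = [
        {
            # api_type "groq" routes OpenAIWrapper to the new GroqClient
            # (see _register_default_client in autogen/oai/client.py above).
            "api_type": "groq",
            "model": "llama3-8b-8192",
            "api_key": os.environ.get("GROQ_API_KEY"),
            # Optional parameters validated by GroqClient.parse_params:
            "temperature": 0.5,
            "max_tokens": 1000,
            "seed": 42,
            # "stream": True,  # supported as of PATCH 3, including with tool calls
        }
    ]

    assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
    user_proxy = autogen.UserProxyAgent(
        "user_proxy",
        human_input_mode="NEVER",
        max_consecutive_auto_reply=0,  # end the chat after the assistant's first reply
        code_execution_config=False,
    )

    # One round trip through the Groq client.
    user_proxy.initiate_chat(assistant, message="Write a haiku about fast inference.")

As a cross-check on GROQ_PRICING_1K, the figure asserted in test_cost_calculation follows directly from the table: 500 prompt tokens and 300 completion tokens on llama3-70b-8192 cost (500 / 1000) * 0.00059 + (300 / 1000) * 0.00079 = 0.000295 + 0.000237 = $0.000532.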