Skip to content

Commit 14fea31

Browse files
marklysze and qingyun-wu
authored and committed
Groq Client (#3003)
* Groq Client Class - main class and setup, except tests * Change pricing per K, added tests * Streaming support, including with tool calling * Used Groq retries instead of loop, thanks Gal-Gilor! * Fixed bug when using logging. --------- Co-authored-by: Qingyun Wu <[email protected]>
1 parent b48ec2c commit 14fea31

File tree

8 files changed

+599
-4
lines changed

8 files changed

+599
-4
lines changed

.github/workflows/contrib-tests.yml

+40
Original file line numberDiff line numberDiff line change
@@ -598,3 +598,43 @@ jobs:
598598
with:
599599
file: ./coverage.xml
600600
flags: unittests
601+
602+
GroqTest:
603+
runs-on: ${{ matrix.os }}
604+
strategy:
605+
fail-fast: false
606+
matrix:
607+
os: [ubuntu-latest, macos-latest, windows-2019]
608+
python-version: ["3.9", "3.10", "3.11", "3.12"]
609+
exclude:
610+
- os: macos-latest
611+
python-version: "3.9"
612+
steps:
613+
- uses: actions/checkout@v4
614+
with:
615+
lfs: true
616+
- name: Set up Python ${{ matrix.python-version }}
617+
uses: actions/setup-python@v5
618+
with:
619+
python-version: ${{ matrix.python-version }}
620+
- name: Install packages and dependencies for all tests
621+
run: |
622+
python -m pip install --upgrade pip wheel
623+
pip install pytest-cov>=5
624+
- name: Install packages and dependencies for Groq
625+
run: |
626+
pip install -e .[groq,test]
627+
- name: Set AUTOGEN_USE_DOCKER based on OS
628+
shell: bash
629+
run: |
630+
if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
631+
echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
632+
fi
633+
- name: Coverage
634+
run: |
635+
pytest test/oai/test_groq.py --skip-openai
636+
- name: Upload coverage to Codecov
637+
uses: codecov/codecov-action@v3
638+
with:
639+
file: ./coverage.xml
640+
flags: unittests

autogen/logger/file_logger.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from autogen import Agent, ConversableAgent, OpenAIWrapper
2020
from autogen.oai.anthropic import AnthropicClient
2121
from autogen.oai.gemini import GeminiClient
22+
from autogen.oai.groq import GroqClient
2223
from autogen.oai.mistral import MistralAIClient
2324
from autogen.oai.together import TogetherClient
2425

@@ -204,7 +205,7 @@ def log_new_wrapper(
204205

205206
def log_new_client(
206207
self,
207-
client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient,
208+
client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient | GroqClient,
208209
wrapper: OpenAIWrapper,
209210
init_args: Dict[str, Any],
210211
) -> None:

autogen/logger/sqlite_logger.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from autogen import Agent, ConversableAgent, OpenAIWrapper
2121
from autogen.oai.anthropic import AnthropicClient
2222
from autogen.oai.gemini import GeminiClient
23+
from autogen.oai.groq import GroqClient
2324
from autogen.oai.mistral import MistralAIClient
2425
from autogen.oai.together import TogetherClient
2526

@@ -391,7 +392,7 @@ def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[st
391392

392393
def log_new_client(
393394
self,
394-
client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient],
395+
client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient],
395396
wrapper: OpenAIWrapper,
396397
init_args: Dict[str, Any],
397398
) -> None:

autogen/oai/client.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,13 @@
7070
except ImportError as e:
7171
together_import_exception = e
7272

73+
try:
74+
from autogen.oai.groq import GroqClient
75+
76+
groq_import_exception: Optional[ImportError] = None
77+
except ImportError as e:
78+
groq_import_exception = e
79+
7380
logger = logging.getLogger(__name__)
7481
if not logger.handlers:
7582
# Add the console handler.
@@ -483,7 +490,13 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
483490
elif api_type is not None and api_type.startswith("together"):
484491
if together_import_exception:
485492
raise ImportError("Please install `together` to use the Together.AI API.")
486-
self._clients.append(TogetherClient(**config))
493+
client = TogetherClient(**openai_config)
494+
self._clients.append(client)
495+
elif api_type is not None and api_type.startswith("groq"):
496+
if groq_import_exception:
497+
raise ImportError("Please install `groq` to use the Groq API.")
498+
client = GroqClient(**openai_config)
499+
self._clients.append(client)
487500
else:
488501
client = OpenAI(**openai_config)
489502
self._clients.append(OpenAIClient(client))

autogen/oai/groq.py

+289
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
"""Create an OpenAI-compatible client using Groq's API.
2+
3+
Example:
4+
llm_config={
5+
"config_list": [{
6+
"api_type": "groq",
7+
"model": "mixtral-8x7b-32768",
8+
"api_key": os.environ.get("GROQ_API_KEY")
9+
}
10+
]}
11+
12+
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
13+
14+
Install Groq's python library using: pip install --upgrade groq
15+
16+
Resources:
17+
- https://console.groq.com/docs/quickstart
18+
"""
19+
20+
from __future__ import annotations
21+
22+
import copy
23+
import os
24+
import time
25+
import warnings
26+
from typing import Any, Dict, List
27+
28+
from groq import Groq, Stream
29+
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
30+
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
31+
from openai.types.completion_usage import CompletionUsage
32+
33+
from autogen.oai.client_utils import should_hide_tools, validate_parameter
34+
35+
# Cost per thousand tokens - Input / Output (NOTE: Convert $/Million to $/K)
# Maps model name -> (input cost per 1K tokens, output cost per 1K tokens).
# NOTE(review): pricing changes over time — keep in sync with Groq's published rates.
GROQ_PRICING_1K = {
    "llama3-70b-8192": (0.00059, 0.00079),
    "mixtral-8x7b-32768": (0.00024, 0.00024),
    "llama3-8b-8192": (0.00005, 0.00008),
    "gemma-7b-it": (0.00007, 0.00007),
}
42+
43+
44+
class GroqClient:
    """Client for Groq's OpenAI-compatible chat completions API.

    Implements the AutoGen model-client interface: create(), message_retrieval(),
    cost() and get_usage().
    """

    def __init__(self, **kwargs):
        """Requires api_key or environment variable to be set

        Args:
            api_key (str): The API key for using Groq (or environment variable GROQ_API_KEY needs to be set)
        """
        # Ensure we have the api_key upon instantiation
        self.api_key = kwargs.get("api_key", None)
        if not self.api_key:
            self.api_key = os.getenv("GROQ_API_KEY")

        assert (
            self.api_key
        ), "Please include the api_key in your config list entry for Groq or set the GROQ_API_KEY env variable."

    def message_retrieval(self, response) -> List:
        """
        Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message for choice in response.choices]

    def cost(self, response) -> float:
        """Return the cost that create() attached to the response object."""
        return response.cost

    @staticmethod
    def get_usage(response) -> Dict:
        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
        return {
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens,
            "cost": response.cost,
            "model": response.model,
        }

    def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Loads the parameters for Groq API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
        groq_params = {}

        # Check that we have what we need to use Groq's API
        # We won't enforce the available models as they are likely to change
        groq_params["model"] = params.get("model", None)
        assert groq_params[
            "model"
        ], "Please specify the 'model' in your config list entry to nominate the Groq model to use."

        # Validate allowed Groq parameters
        # https://console.groq.com/docs/api-reference#chat
        groq_params["frequency_penalty"] = validate_parameter(
            params, "frequency_penalty", (int, float), True, None, (-2, 2), None
        )
        groq_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
        groq_params["presence_penalty"] = validate_parameter(
            params, "presence_penalty", (int, float), True, None, (-2, 2), None
        )
        groq_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
        groq_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None)
        groq_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 1, (0, 2), None)
        groq_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)

        # Groq parameters not supported by their models yet, ignoring:
        # logit_bias, logprobs, top_logprobs

        # Groq parameters we are ignoring:
        # n (must be 1), response_format (to enforce JSON but needs prompting as well), user,
        # parallel_tool_calls (defaults to True), stop
        # function_call (deprecated), functions (deprecated)
        # tool_choice (none if no tools, auto if there are tools)

        return groq_params

    def create(self, params: Dict) -> ChatCompletion:
        """Run a chat completion on Groq and return it as an OpenAI ChatCompletion.

        Args:
            params (Dict): AutoGen config entry; must contain "model" and may
                contain "messages", "tools", "stream" and the sampling
                parameters validated in parse_params().

        Raises:
            RuntimeError: if the Groq API call fails (after the client's
                built-in retries) or a streamed response yields no chunks.
        """
        messages = params.get("messages", [])

        # Convert AutoGen messages to Groq messages
        groq_messages = oai_messages_to_groq_messages(messages)

        # Parse parameters to the Groq API's parameters
        groq_params = self.parse_params(params)

        # Add tools to the call if we have them and aren't hiding them
        if "tools" in params:
            hide_tools = validate_parameter(
                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
            )
            if not should_hide_tools(groq_messages, params["tools"], hide_tools):
                groq_params["tools"] = params["tools"]

        groq_params["messages"] = groq_messages

        # We use chat model by default, and set max_retries to 5 (in line with typical retries loop)
        client = Groq(api_key=self.api_key, max_retries=5)

        # Token counts will be returned
        prompt_tokens = 0
        completion_tokens = 0
        total_tokens = 0

        # Streaming tool call recommendations
        streaming_tool_calls = []

        try:
            response = client.chat.completions.create(**groq_params)
        except Exception as e:
            raise RuntimeError(f"Groq exception occurred: {e}")

        if groq_params["stream"]:
            # Read in the chunks as they stream, taking in tool_calls which may be across
            # multiple chunks if more than one suggested
            ans = ""
            chunk = None  # FIX: was unbound (NameError) when the stream yielded no chunks
            for chunk in response:
                ans = ans + (chunk.choices[0].delta.content or "")

                if chunk.choices[0].delta.tool_calls:
                    # We have a tool call recommendation
                    for tool_call in chunk.choices[0].delta.tool_calls:
                        streaming_tool_calls.append(
                            ChatCompletionMessageToolCall(
                                id=tool_call.id,
                                function={
                                    "name": tool_call.function.name,
                                    "arguments": tool_call.function.arguments,
                                },
                                type="function",
                            )
                        )

                if chunk.choices[0].finish_reason:
                    # Usage arrives on the final chunk of a streamed Groq response
                    prompt_tokens = chunk.x_groq.usage.prompt_tokens
                    completion_tokens = chunk.x_groq.usage.completion_tokens
                    total_tokens = chunk.x_groq.usage.total_tokens

            if chunk is None:
                raise RuntimeError("Failed to get response from Groq after retrying 5 times.")

            if chunk.choices[0].finish_reason == "tool_calls":
                groq_finish = "tool_calls"
                tool_calls = streaming_tool_calls
            else:
                groq_finish = "stop"
                tool_calls = None

            response_content = ans
            response_id = chunk.id
        else:
            # Non-streaming response
            prompt_tokens = response.usage.prompt_tokens
            completion_tokens = response.usage.completion_tokens
            total_tokens = response.usage.total_tokens

            # If we have tool calls as the response, populate completed tool calls for our return OAI response
            if response.choices[0].finish_reason == "tool_calls":
                groq_finish = "tool_calls"
                tool_calls = []
                for tool_call in response.choices[0].message.tool_calls:
                    tool_calls.append(
                        ChatCompletionMessageToolCall(
                            id=tool_call.id,
                            function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
                            type="function",
                        )
                    )
            else:
                groq_finish = "stop"
                tool_calls = None

            response_content = response.choices[0].message.content
            response_id = response.id

        # Convert the output into an OpenAI-compatible ChatCompletion
        message = ChatCompletionMessage(
            role="assistant",
            content=response_content,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=groq_finish, index=0, message=message)]

        response_oai = ChatCompletion(
            id=response_id,
            model=groq_params["model"],
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            ),
            cost=calculate_groq_cost(prompt_tokens, completion_tokens, groq_params["model"]),
        )

        return response_oai
253+
254+
255+
def oai_messages_to_groq_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert messages from OAI format to Groq's format.

    The only correction currently applied is removing the optional "name"
    field from each message. The input list is not modified; a deep copy is
    returned.

    Args:
        messages: Messages in OpenAI chat format (dicts with "role"/"content"/...).

    Returns:
        A deep-copied list of messages with any "name" fields removed.
    """
    groq_messages = copy.deepcopy(messages)

    # FIX: removed dead commented-out code left over from the Together client
    # (it referenced `together_messages`, a variable that does not exist here).

    # Remove the name field; pop with a default needs no membership check.
    for message in groq_messages:
        message.pop("name", None)

    return groq_messages
275+
276+
277+
def calculate_groq_cost(input_tokens: int, output_tokens: int, model: str) -> float:
    """Calculate the cost of the completion using the Groq pricing.

    Args:
        input_tokens: Number of prompt tokens consumed.
        output_tokens: Number of completion tokens produced.
        model: Groq model name, looked up in GROQ_PRICING_1K.

    Returns:
        Cost in dollars, or 0.0 (with a UserWarning) for unknown models.
    """
    # Unknown model: warn and report zero cost rather than failing the call.
    if model not in GROQ_PRICING_1K:
        warnings.warn(f"Cost calculation not available for model {model}", UserWarning)
        return 0.0

    in_rate_per_k, out_rate_per_k = GROQ_PRICING_1K[model]
    return (input_tokens / 1000) * in_rate_per_k + (output_tokens / 1000) * out_rate_per_k

0 commit comments

Comments
 (0)