Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

generator: Groq API #896

Merged
merged 14 commits into from
Sep 6, 2024
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,12 @@ Private Replicate endpoints:
* `--model_name` (optional, `command` by default) - The specific Cohere model you'd like to test
* set the `COHERE_API_KEY` environment variable to your Cohere API key, e.g. "aBcDeFgHiJ123456789"; see https://dashboard.cohere.ai/api-keys when logged in

### Groq

* `--model_type groq`
* `--model_name` (required) - The name of the model to access via the Groq API, e.g. `llama-3.1-8b-instant`
* set the `GROQ_API_KEY` environment variable to your Groq API key; see https://console.groq.com/docs/quickstart for details on creating an API key

### ggml

* `--model_type ggml`
Expand Down
8 changes: 8 additions & 0 deletions docs/source/garak.generators.groq.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
garak.generators.groq
=====================

.. automodule:: garak.generators.groq
   :members:
   :undoc-members:
   :show-inheritance:

1 change: 1 addition & 0 deletions docs/source/generators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ For a detailed oversight into how a generator operates, see :ref:`garak.generato
garak.generators.cohere
garak.generators.function
garak.generators.ggml
garak.generators.groq
garak.generators.guardrails
garak.generators.huggingface
garak.generators.langchain
Expand Down
72 changes: 72 additions & 0 deletions garak/generators/groq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""GroqChat API support"""
mmilenkovic-groq marked this conversation as resolved.
Show resolved Hide resolved

import random
from typing import List, Union

import openai

from garak.generators.openai import OpenAICompatible


class GroqChat(OpenAICompatible):
    """Wrapper for Groq-hosted LLM models.

    Expects the GROQ_API_KEY environment variable to be set. See
    https://console.groq.com/docs/quickstart for more info on how to set up a
    Groq API key.

    Uses the OpenAI-compatible API (https://console.groq.com/docs/openai).
    """

    # per https://console.groq.com/docs/openai
    # 2024.09.04, `n>1` is not supported
    ENV_VAR = "GROQ_API_KEY"
    DEFAULT_PARAMS = OpenAICompatible.DEFAULT_PARAMS | {
        "temperature": 0.7,
        "top_p": 1.0,
        "uri": "https://api.groq.com/openai/v1",
        "vary_seed_each_call": True,  # encourage variation when generations>1
        "vary_temp_each_call": True,  # encourage variation when generations>1
        # request parameters that must not be sent to the Groq endpoint
        # NOTE(review): presumably filtered out by OpenAICompatible before
        # the request is built - confirm against the base class
        "suppressed_params": {
            "n",
            "frequency_penalty",
            "presence_penalty",
            "logprobs",
            "logit_bias",
            "top_logprobs",
        },
    }
    active = True
    # one completion per request only; see the `n>1` note above
    supports_multiple_generations = False
    generator_family_name = "Groq"

    timeout = 60  # presumably seconds - TODO confirm against the base class

    def _load_client(self):
        """Create the OpenAI-compatible client and bind the chat completions API.

        Raises:
            ValueError: if no model name was supplied. The message lists the
                model ids currently reported by the Groq API, one per line.
        """
        self.client = openai.OpenAI(base_url=self.uri, api_key=self.api_key)
        if self.name in ("", None):
            # The client is built first so that it can be used to look up the
            # available model ids for the error message.
            available = sorted(entry.id for entry in self.client.models.list().data)
            # Prefix every entry (including the first) with its own bullet so
            # that the list renders uniformly, and so an empty list adds nothing.
            raise ValueError(
                "Groq API requires model name to be set, e.g. --model_name llama-3.1-8b-instant\nCurrent models:"
                + "".join(f"\n - {model_id}" for model_id in available)
            )
        self.generator = self.client.chat.completions

    def _clear_client(self):
        """Drop the references to the API client and its completions endpoint."""
        self.generator = None
        self.client = None

    def _call_model(
        self, prompt: str | List[dict], generations_this_call: int = 1
    ) -> List[Union[str, None]]:
        """Request a single completion for ``prompt`` from the Groq API.

        When ``vary_seed_each_call`` is set, ``self.seed`` is replaced with a
        random integer in [0, 65535] before the call. When
        ``vary_temp_each_call`` is set, ``self.temperature`` is replaced with
        a random float in [0.0, 1.0), overriding any previously set value.
        Both are there to encourage variation across repeated calls, since
        only one generation can be requested at a time.

        Args:
            prompt: the prompt, as a string or a list of message dicts.
            generations_this_call: must be 1.

        Returns:
            Whatever ``OpenAICompatible._call_model`` returns for this prompt.

        Raises:
            AssertionError: if ``generations_this_call`` is not 1.
        """
        assert (
            generations_this_call == 1
        ), "generations_per_call / n > 1 is not supported"

        if self.vary_seed_each_call:
            self.seed = random.randint(0, 65535)

        if self.vary_temp_each_call:
            self.temperature = random.random()

        return super()._call_model(prompt, generations_this_call)


DEFAULT_CLASS = "GroqChat"
31 changes: 31 additions & 0 deletions tests/generators/test_groq.py
mmilenkovic-groq marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import os
mmilenkovic-groq marked this conversation as resolved.
Show resolved Hide resolved
import pytest

import garak.cli
from garak.generators.groq import GroqChat


@pytest.mark.skipif(
    os.getenv(GroqChat.ENV_VAR) is None,
    reason=f"GroqChat API key is not set in {GroqChat.ENV_VAR}",
)
def test_groq_instantiate():
    """Constructing a GroqChat generator with a model name should not raise.

    Skipped unless the Groq API key environment variable is set.
    """
    GroqChat(name="llama3-8b-8192")


@pytest.mark.skipif(
    os.getenv(GroqChat.ENV_VAR, None) is None,
    reason=f"GroqChat API key is not set in {GroqChat.ENV_VAR}",
)
def test_groq_generate_1():
    """Check that one generation comes back as a one-item list of strings.

    Exercises both the low-level ``_call_model`` and the public ``generate``
    entry points. Skipped unless the Groq API key environment variable is set.
    """
    g = GroqChat(name="llama3-8b-8192")

    result = g._call_model("this is a test", generations_this_call=1)
    assert isinstance(result, list), "GroqChat _call_model should return a list"
    assert len(result) == 1, "GroqChat _call_model result list should have one item"
    # The failure message names the element type being checked, not the container.
    assert isinstance(
        result[0], str
    ), "GroqChat _call_model should return a list of strings"

    result = g.generate("this is a test", generations_this_call=1)
    assert isinstance(result, list), "GroqChat generate() should return a list"
    assert (
        len(result) == 1
    ), "GroqChat generate() result list should have one item when generations_this_call=1"
    assert isinstance(
        result[0], str
    ), "GroqChat generate() should return a list of strings"
mmilenkovic-groq marked this conversation as resolved.
Show resolved Hide resolved
Loading