Skip to content

Commit 3cda11b

Browse files
authored
feat(generation): add backend support for Groq (osl-incubator#86)
1 parent 22105c2 commit 3cda11b

File tree

5 files changed

+104
-0
lines changed

5 files changed

+104
-0
lines changed

src/rago/generation/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from rago.generation.deepseek import DeepSeekGen
88
from rago.generation.fireworks import FireworksGen
99
from rago.generation.gemini import GeminiGen
10+
from rago.generation.groq import GroqGen
1011
from rago.generation.hugging_face import HuggingFaceGen
1112
from rago.generation.hugging_face_inf import HuggingFaceInfGen
1213
from rago.generation.llama import LlamaGen
@@ -19,6 +20,7 @@
1920
'FireworksGen',
2021
'GeminiGen',
2122
'GenerationBase',
23+
'GroqGen',
2224
'HuggingFaceGen',
2325
'HuggingFaceInfGen',
2426
'LlamaGen',

src/rago/generation/groq.py

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""Groq class for text generation."""
2+
3+
from __future__ import annotations
4+
5+
from typing import cast
6+
7+
import instructor
8+
import openai
9+
10+
from pydantic import BaseModel
11+
from typeguard import typechecked
12+
13+
from rago.generation.base import GenerationBase
14+
15+
16+
@typechecked
17+
class GroqGen(GenerationBase):
18+
"""Groq generation model for text generation."""
19+
20+
default_model_name = 'gemma2-9b-it'
21+
default_api_params = { # noqa: RUF012
22+
'top_p': 1.0,
23+
}
24+
25+
def _setup(self) -> None:
26+
"""Set up the Groq client."""
27+
groq_api_key = self.api_key
28+
if not groq_api_key:
29+
raise Exception('GROQ_API_KEY environment variable is not set')
30+
31+
# Can use Groq client as well.
32+
groq_client = openai.OpenAI(
33+
base_url='https://api.groq.com/openai/v1', api_key=groq_api_key
34+
)
35+
36+
# Optionally use instructor if structured output is needed
37+
self.model = (
38+
instructor.from_openai(groq_client)
39+
if self.structured_output
40+
else groq_client
41+
)
42+
43+
def generate(
44+
self,
45+
query: str,
46+
context: list[str],
47+
) -> str | BaseModel:
48+
"""Generate text using the Groq AP."""
49+
input_text = self.prompt_template.format(
50+
query=query, context=' '.join(context)
51+
)
52+
53+
if not self.model:
54+
raise Exception('The model was not created.')
55+
56+
api_params = (
57+
self.api_params if self.api_params else self.default_api_params
58+
)
59+
60+
messages = []
61+
if self.system_message:
62+
messages.append({'role': 'system', 'content': self.system_message})
63+
messages.append({'role': 'user', 'content': input_text})
64+
65+
model_params = dict(
66+
model=self.model_name,
67+
messages=messages,
68+
max_completion_tokens=self.output_max_length,
69+
temperature=self.temperature,
70+
**api_params,
71+
)
72+
73+
if self.structured_output:
74+
model_params['response_model'] = self.structured_output
75+
76+
response = self.model.chat.completions.create(**model_params)
77+
self.logs['model_params'] = model_params
78+
79+
if hasattr(response, 'choices') and isinstance(response.choices, list):
80+
return cast(str, response.choices[0].message.content.strip())
81+
82+
return cast(BaseModel, response)

tests/.env.tpl

+1
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ TOKENIZERS_PARALLELISM=false
55
COHERE_API_KEY=${COHERE_API_KEY}
66
FIREWORKS_API_KEY=${FIREWORKS_API_KEY}
77
TOGETHER_API_KEY=${TOGETHER_API_KEY}
8+
GROQ_API_KEY=${GROQ_API_KEY}

tests/conftest.py

+11
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,14 @@ def api_key_together(env) -> str:
9494
'Please set the TOGETHER_API_KEY environment variable.'
9595
)
9696
return key
97+
98+
99+
@pytest.fixture
100+
def api_key_groq(env) -> str:
101+
"""Fixture for GROQ API key from environment."""
102+
key = os.getenv('GROQ_API_KEY')
103+
if not key:
104+
raise EnvironmentError(
105+
'Please set the GROQ_API_KEY environment variable.'
106+
)
107+
return key

tests/test_generation.py

+8
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
DeepSeekGen,
1111
FireworksGen,
1212
GeminiGen,
13+
GroqGen,
1314
HuggingFaceGen,
1415
HuggingFaceInfGen,
1516
LlamaGen,
@@ -32,6 +33,7 @@
3233
CohereGen: 'api_key_cohere',
3334
FireworksGen: 'api_key_fireworks',
3435
TogetherGen: 'api_key_together',
36+
GroqGen: 'api_key_groq',
3537
}
3638

3739
gen_models = [
@@ -85,6 +87,10 @@
8587
partial(
8688
HuggingFaceInfGen,
8789
),
90+
# model 9
91+
partial(
92+
GroqGen,
93+
),
8894
]
8995

9096

@@ -98,6 +104,7 @@ def test_generation_simple_output(
98104
api_key_gemini: str,
99105
api_key_together: str,
100106
api_key_hugging_face: str,
107+
api_key_groq: str,
101108
partial_model: partial,
102109
) -> None:
103110
"""Test RAG pipeline with model generation."""
@@ -150,6 +157,7 @@ def test_generation_structure_output(
150157
api_key_gemini: str,
151158
api_key_together: str,
152159
api_key_hugging_face: str,
160+
api_key_groq: str,
153161
animals_data: list[str],
154162
question: str,
155163
partial_model: partial,

0 commit comments

Comments
 (0)