"""Groq class for text generation."""

from __future__ import annotations

from typing import cast

import instructor
import openai

from pydantic import BaseModel
from typeguard import typechecked

from rago.generation.base import GenerationBase


@typechecked
class GroqGen(GenerationBase):
    """Groq generation model for text generation."""

    default_model_name = 'gemma2-9b-it'
    default_api_params = {  # noqa: RUF012
        'top_p': 1.0,
    }

    def _setup(self) -> None:
        """Set up the Groq client."""
        groq_api_key = self.api_key
        if not groq_api_key:
            raise Exception('GROQ_API_KEY environment variable is not set')

        # Groq's API is OpenAI-compatible, so the OpenAI client works
        # here; the official Groq client could be used as well.
        groq_client = openai.OpenAI(
            base_url='https://api.groq.com/openai/v1', api_key=groq_api_key
        )

        # Wrap the client with instructor only when structured
        # (Pydantic) output is requested.
        self.model = (
            instructor.from_openai(groq_client)
            if self.structured_output
            else groq_client
        )

    def generate(
        self,
        query: str,
        context: list[str],
    ) -> str | BaseModel:
        """Generate text using the Groq API."""
        input_text = self.prompt_template.format(
            query=query, context=' '.join(context)
        )

        if not self.model:
            raise Exception('The model was not created.')

        api_params = (
            self.api_params if self.api_params else self.default_api_params
        )

        messages = []
        if self.system_message:
            messages.append({'role': 'system', 'content': self.system_message})
        messages.append({'role': 'user', 'content': input_text})

        model_params = dict(
            model=self.model_name,
            messages=messages,
            max_completion_tokens=self.output_max_length,
            temperature=self.temperature,
            **api_params,
        )

        # instructor consumes `response_model` and returns a parsed
        # Pydantic object instead of a raw chat completion.
        if self.structured_output:
            model_params['response_model'] = self.structured_output

        response = self.model.chat.completions.create(**model_params)
        self.logs['model_params'] = model_params

        # Plain completions expose `choices`; instructor returns the
        # structured model directly.
        if hasattr(response, 'choices') and isinstance(response.choices, list):
            return cast(str, response.choices[0].message.content.strip())

        return cast(BaseModel, response)
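
# Usage sketch (illustrative only): this assumes GenerationBase accepts
# `api_key` and `structured_output` keyword arguments, inferred from the
# attributes referenced above rather than from a documented constructor.
# `Answer` is a hypothetical Pydantic schema used only for illustration.
#
#   import os
#   from pydantic import BaseModel
#
#   class Answer(BaseModel):
#       answer: str
#
#   gen = GroqGen(api_key=os.environ['GROQ_API_KEY'])
#   text = gen.generate(
#       query='What is retrieval-augmented generation?',
#       context=['RAG augments prompts with retrieved documents.'],
#   )
#
#   structured = GroqGen(
#       api_key=os.environ['GROQ_API_KEY'],
#       structured_output=Answer,
#   )
#   parsed = structured.generate(
#       query='Define RAG.',
#       context=['RAG augments prompts with retrieved documents.'],
#   )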