-
Notifications
You must be signed in to change notification settings - Fork 0
/
gpt_call.py
49 lines (42 loc) · 1.11 KB
/
gpt_call.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from groq import Groq
from openai import OpenAI
from dotenv import load_dotenv
import os
import json
# Call load_dotenv() before reading the keys so that any values it adds to the
# process environment are visible to the os.getenv() lookups below.
load_dotenv()

# os.getenv returns None when a variable is unset; the value is handed to the
# client constructor as-is.
groq_api_key = os.getenv("GROQ_API_KEY")
openai_api_key = os.getenv("OPENAI_API_KEY")

# Module-level clients shared by the call helpers in this file.
client_groq = Groq(api_key=groq_api_key)
client_openai = OpenAI(api_key=openai_api_key)
def groq_call(prompt):
    """Send ``prompt`` to Groq as a single user message and return the reply.

    The request goes through the module-level ``client_groq`` client using the
    ``llama3-8b-8192`` model; the content of the first returned choice is
    handed back to the caller.
    """
    user_message = {"role": "user", "content": prompt}
    response = client_groq.chat.completions.create(
        messages=[user_message],
        model="llama3-8b-8192",
    )
    first_choice = response.choices[0]
    return first_choice.message.content
def gpt_call(prompt):
    """Send ``prompt`` to OpenAI as a single user message and return the reply.

    Args:
        prompt: Text placed in the ``content`` field of the one user message.

    Returns:
        The ``message.content`` of the first choice in the completion.
    """
    chat_completion = client_openai.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": prompt,
            }
        ],
        # The model id ends in the letter "o" ("gpt-4o"); the previous value
        # "gpt-40" (digit zero) is not a valid model name and made every
        # request fail.
        model="gpt-4o",
    )
    return chat_completion.choices[0].message.content
def generate_variation(prompt, model='groq'):
    """Generate a variation of ``prompt`` using the selected backend.

    Args:
        prompt: Text passed through to the backend call function.
        model: Backend selector. ``'groq'`` (the default) dispatches to
            ``groq_call``; ``'gpt'`` dispatches to ``gpt_call``.

    Returns:
        Whatever the selected backend function returns.

    Raises:
        ValueError: If ``model`` is neither ``'groq'`` nor ``'gpt'``.
    """
    if model == 'groq':
        return groq_call(prompt)
    if model == 'gpt':
        return gpt_call(prompt)
    # Fail loudly on an unknown selector instead of silently returning None,
    # which would hide a typo in the caller until much later.
    raise ValueError(
        f"Unknown model {model!r}; expected 'groq' or 'gpt'."
    )