-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f2bd067
commit 8e03bc0
Showing
12 changed files
with
261 additions
and
125 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,47 +1,104 @@ | ||
import openai | ||
openai.api_key = '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' | ||
import json | ||
import yaml | ||
|
||
# Load the configuration | ||
with open('config.yml', 'r') as f: | ||
config = yaml.safe_load(f) | ||
|
||
# Set the wandb_enabled flag | ||
wandb_enabled = config['wandb_enabled'] | ||
openai.api_key = config['openai_api_key'] | ||
|
||
# Check if the API key works | ||
try: | ||
openai.Model.list() | ||
except openai.error.AuthenticationError: | ||
raise ValueError("Invalid OpenAI API key") | ||
|
||
from wandb.sdk.data_types.trace_tree import Trace | ||
import datetime | ||
import helpers as h | ||
import globals | ||
import yaml | ||
|
||
class AI: | ||
def __init__(self, system="", model = 'gpt-4', openai=openai): | ||
self.system = system | ||
def __init__(self, module_name, model = 'gpt-4', temperature=0.7, openai=openai): | ||
self.model = model | ||
self.openai = openai | ||
self.messages = [{"role": "system", "content": system}] | ||
self.temperature = temperature | ||
self.module_name = module_name | ||
self.system = h.load_system_prompt(module_name) | ||
self.messages = [{"role": "system", "content": self.system}] | ||
|
||
|
||
def generate_response(self, prompt): | ||
self.messages.append({"role": "user", "content": prompt}) | ||
response = self.openai.ChatCompletion.create( | ||
model=self.model, | ||
stream=True, | ||
messages=self.messages, | ||
) | ||
|
||
chat = [] | ||
for chunk in response: | ||
delta = chunk["choices"][0]["delta"] | ||
msg = delta.get("content", "") | ||
print(msg, end="") | ||
chat.append(msg) | ||
try: | ||
response = self.openai.ChatCompletion.create( | ||
model=self.model, | ||
stream=True, | ||
messages=self.messages, | ||
temperature=self.temperature, | ||
) | ||
|
||
chat = [] | ||
token_count = 0 | ||
for chunk in response: | ||
delta = chunk["choices"][0]["delta"] | ||
msg = delta.get("content", "") | ||
print(msg, end="") | ||
chat.append(msg) | ||
token_count += len(msg.split()) # estimate token usage | ||
|
||
print() | ||
|
||
response_text = "".join(chat) | ||
llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000) # logged in milliseconds | ||
status_code="success" | ||
status_message=None, | ||
token_usage = {"total_tokens": token_count} | ||
|
||
except Exception as e: | ||
llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000) # logged in milliseconds | ||
status_code="error" | ||
status_message=str(e) | ||
response_text = "" | ||
token_usage = {} | ||
|
||
if wandb_enabled: | ||
# calculate the runtime of the LLM | ||
runtime = llm_end_time_ms - globals.agent_start_time_ms | ||
# create a child span in wandb | ||
llm_span = Trace( | ||
name="child_span", | ||
kind="llm", # kind can be "llm", "chain", "agent" or "tool" | ||
status_code=status_code, | ||
status_message=status_message, | ||
metadata={"temperature": self.temperature, | ||
"token_usage": token_usage, | ||
"runtime_ms": runtime, | ||
"module_name": self.module_name, | ||
"model_name": self.model}, | ||
start_time_ms=globals.agent_start_time_ms, | ||
end_time_ms=llm_end_time_ms, | ||
inputs={"system_prompt": self.system, "query": prompt}, | ||
outputs={"response": response_text} | ||
) | ||
|
||
# add the child span to the root span | ||
globals.chain_span.add_child(llm_span) | ||
|
||
print() | ||
# update the end time of the Chain span | ||
globals.chain_span.add_inputs_and_outputs( | ||
inputs={"query": prompt}, | ||
outputs={"response": response_text}) | ||
|
||
response_text = "".join(chat) | ||
# update the Chain span's end time | ||
globals.chain_span._span.end_time_ms = llm_end_time_ms | ||
|
||
# log the child span to wandb | ||
llm_span.log(name="pipeline_trace") | ||
|
||
self.messages.append({"role": "assistant", "content": response_text}) | ||
return response_text, self.messages | ||
|
||
def generate_image(self, prompt, n=1, size="1024x1024", response_format="url"): | ||
"""Generate an image using DALL·E given a prompt. | ||
Arguments: | ||
prompt (str): A text description of the desired image(s). | ||
n (int, optional): The number of images to generate. Defaults to 1. | ||
size (str, optional): The size of the generated images. Defaults to "1024x1024". | ||
response_format (str, optional): The format in which the generated images are returned. Defaults to "url". | ||
Returns: | ||
dict: The response from the OpenAI API. | ||
""" | ||
return openai.Image.create(prompt=prompt, n=n, size=size, response_format=response_format) | ||
return response_text, self.messages |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
openai_api_key: 'OPENAI_API_KEY' | ||
wandb_enabled: false | ||
pipeline: "engineering_pipeline" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.