Skip to content

Commit

Permalink
Merge pull request #5 from guardrails-ai/fix_gpt4o_call
Browse files Browse the repository at this point in the history
Fix gpt4o call
  • Loading branch information
wylansford authored May 31, 2024
2 parents 8f49489 + e19f39c commit 5049bd3
Showing 1 changed file with 31 additions and 82 deletions.
113 changes: 31 additions & 82 deletions validator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,56 +138,8 @@ def __init__(
# TODO api endpoint
...

self._json_schema, self._tools = self._create_json_schema(
self._valid_topics, self._invalid_topics
)

def _create_json_schema(self, valid_topics: list, invalid_topics: list) -> str:
"""Creates a json schema that an LLM will fill out. The json schema contains
one of each of the provided topics, as well as a blank 'present' and 'confidence'
for the llm to fill in.
Args:
valid_topics (list): The valid topics to provide as one of the json schema
invalid_topics (list): Invalid topics to provide as one of the json schema

Returns:
str: The resulting json schema with unfilled data types
"""
tools = [
{
"type": "function",
"function": {
"name": "is_topic_relevant",
"description": "Determine if the provided text is about a topic, with a confidence score.",
"parameters": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Simply the repeated name of the given topic.",
},
"present": {
"type": "boolean",
"description": "If the given topic is discussed in the given text.",
},
"confidence": {
"type": "integer",
"description": "The confidence level of the topic being present in the text, from 1-5",
},
},
"required": ["name", "present", "confidence"],
},
},
},
]

json_schema = []
for topic in set(valid_topics + invalid_topics):
json_schema.append({"topic": topic})
return json_schema, tools

def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> list[str]:
def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> List[str]:
"""Finds the topics in the input text based on if it is determined by the zero
shot model or the llm.
Expand All @@ -196,7 +148,7 @@ def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> list[st
candidate_topics (List[str]): The topics to search for in the input text
Returns:
list[str]: The found topics
List[str]: The found topics
"""
# Find topics based on zero shot model
zero_shot_topics = self.get_topics_zero_shot(text, candidate_topics)
Expand All @@ -206,7 +158,7 @@ def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> list[st

return list(set(zero_shot_topics + llm_topics))

def get_topics_llm(self, text: str, candidate_topics: List[str]) -> list[str]:
def get_topics_llm(self, text: str, candidate_topics: List[str]) -> List[str]:
"""Returns a list of the topics identified in the given text using an LLM
callable
Expand All @@ -215,15 +167,13 @@ def get_topics_llm(self, text: str, candidate_topics: List[str]) -> list[str]:
candidate_topics (List[str]): The topics to identify if present in the text.
Returns:
list[str]: The topics found in the input text.
List[str]: The topics found in the input text.
"""
topics = self.call_llm(text)
llm_topics = self.call_llm(text, candidate_topics)
found_topics = []
for llm_result in topics:
if llm_result["present"] and llm_result["confidence"] > self._llm_threshold:
# Verify the llm didn't hallucinate a topic.
if llm_result["name"] in candidate_topics:
found_topics.append(llm_result["name"])
for llm_topic in llm_topics:
if llm_topic in candidate_topics:
found_topics.append(llm_topic)
return found_topics

def get_client_args(self) -> str:
Expand Down Expand Up @@ -252,7 +202,7 @@ def get_client_args(self) -> str:
stop=stop_after_attempt(5),
reraise=True,
)
def call_llm(self, text: str) -> str:
def call_llm(self, text: str, topics: List[str]) -> str:
"""Call the LLM with the given prompt.
Expects a function that takes a string and returns a string.
Expand All @@ -262,7 +212,7 @@ def call_llm(self, text: str) -> str:
Returns:
response (str): String representing the LLM response.
"""
return self._llm_callable(text)
return self._llm_callable(text, topics)

def set_callable(self, llm_callable: Union[str, Callable, None]) -> None:
"""Set the LLM callable.
Expand All @@ -273,17 +223,17 @@ def set_callable(self, llm_callable: Union[str, Callable, None]) -> None:
"""

if llm_callable is None:
llm_callable = "gpt-3.5-turbo"
llm_callable = "gpt-4o"

if isinstance(llm_callable, str):
if llm_callable not in ["gpt-3.5-turbo", "gpt-4"]:
if llm_callable not in ["gpt-3.5-turbo", "gpt-4", "gpt-4o"]:
raise ValueError(
"llm_callable must be one of 'gpt-3.5-turbo' or 'gpt-4'."
"llm_callable must be one of 'gpt-3.5-turbo', 'gpt-4', or 'gpt-4o'"
"If you want to use a custom LLM, please provide a callable."
"Check out ProvenanceV1 documentation for an example."
)

def openai_callable(text: str) -> str:
def openai_callable(text: str, topics: List[str]) -> str:
api_key = self.get_client_args()
client = OpenAI(api_key=api_key)
response = client.chat.completions.create(
Expand All @@ -292,44 +242,43 @@ def openai_callable(text: str) -> str:
messages=[
{
"role": "user",
"content": f"""Given a series of topics, determine if the topic is present in the provided text. Return the result as json.
Text
----
{text}
"content": f"""
Given a text and a list of topics, return a valid json list of which topics are present in the text. If none, just return an empty list.
Output Format:
-------------
"topics_present": []
Schema
------
{self._json_schema}
Text:
----
"{text}"
Complete Schema
---------------
Topics:
------
{topics}
""",
Result:
------ """,
},
],
tools=self._tools,
)
tool_calls = []
for tool_call in response.choices[0].message.tool_calls:
tool_calls.append(json.loads(tool_call.function.arguments))
return tool_calls
return json.loads(response.choices[0].message.content)["topics_present"]

self._llm_callable = openai_callable
elif isinstance(llm_callable, Callable):
self._llm_callable = llm_callable
else:
raise ValueError("llm_callable must be a string or a Callable")

def get_topics_zero_shot(self, text: str, candidate_topics: List[str]) -> list[str]:
def get_topics_zero_shot(self, text: str, candidate_topics: List[str]) -> List[str]:
"""Gets the topics found through the zero shot classifier
Args:
text (str): The text to classify
candidate_topics (List[str]): The potential topics to look for
Returns:
list[str]: The resulting topics found that meet the given threshold
List[str]: The resulting topics found that meet the given threshold
"""
result = self._classifier(text, candidate_topics)
topics = result["labels"]
Expand Down

0 comments on commit 5049bd3

Please sign in to comment.