diff --git a/validator/main.py b/validator/main.py index dcbfac5..ae52a83 100644 --- a/validator/main.py +++ b/validator/main.py @@ -138,56 +138,8 @@ def __init__( # TODO api endpoint ... - self._json_schema, self._tools = self._create_json_schema( - self._valid_topics, self._invalid_topics - ) - - def _create_json_schema(self, valid_topics: list, invalid_topics: list) -> str: - """Creates a json schema that an LLM will fill out. The json schema contains - one of each of the provided topics, as well as a blank 'present' and 'confidence' - for the llm to fill in. - - Args: - valid_topics (list): The valid topics to provide as one of the json schema - invalid_topics (list): Invalid topics to provide as one of the json schema - Returns: - str: The resulting json schema with unfilled data types - """ - tools = [ - { - "type": "function", - "function": { - "name": "is_topic_relevant", - "description": "Determine if the provided text is about a topic, with a confidence score.", - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "Simply the repeated name of the given topic.", - }, - "present": { - "type": "boolean", - "description": "If the given topic is discussed in the given text.", - }, - "confidence": { - "type": "integer", - "description": "The confidence level of the topic being present in the text, from 1-5", - }, - }, - "required": ["name", "present", "confidence"], - }, - }, - }, - ] - - json_schema = [] - for topic in set(valid_topics + invalid_topics): - json_schema.append({"topic": topic}) - return json_schema, tools - - def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> list[str]: + def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> List[str]: """Finds the topics in the input text based on if it is determined by the zero shot model or the llm. @@ -196,7 +148,7 @@ def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> list[st candidate_topics (List[str]): The topics to search for in the input text Returns: - list[str]: The found topics + List[str]: The found topics """ # Find topics based on zero shot model zero_shot_topics = self.get_topics_zero_shot(text, candidate_topics) @@ -206,7 +158,7 @@ def get_topics_ensemble(self, text: str, candidate_topics: List[str]) -> list[st return list(set(zero_shot_topics + llm_topics)) - def get_topics_llm(self, text: str, candidate_topics: List[str]) -> list[str]: + def get_topics_llm(self, text: str, candidate_topics: List[str]) -> List[str]: """Returns a list of the topics identified in the given text using an LLM callable @@ -215,15 +167,13 @@ def get_topics_llm(self, text: str, candidate_topics: List[str]) -> list[str]: candidate_topics (List[str]): The topics to identify if present in the text. Returns: - list[str]: The topics found in the input text. + List[str]: The topics found in the input text. """ - topics = self.call_llm(text) + llm_topics = self.call_llm(text, candidate_topics) found_topics = [] - for llm_result in topics: - if llm_result["present"] and llm_result["confidence"] > self._llm_threshold: - # Verify the llm didn't hallucinate a topic. - if llm_result["name"] in candidate_topics: - found_topics.append(llm_result["name"]) + for llm_topic in llm_topics: + if llm_topic in candidate_topics: + found_topics.append(llm_topic) return found_topics def get_client_args(self) -> str: @@ -252,7 +202,7 @@ def get_client_args(self) -> str: stop=stop_after_attempt(5), reraise=True, ) - def call_llm(self, text: str) -> str: + def call_llm(self, text: str, topics: List[str]) -> str: """Call the LLM with the given prompt. Expects a function that takes a string and returns a string. @@ -262,7 +212,7 @@ def call_llm(self, text: str) -> str: Returns: response (str): String representing the LLM response. """ - return self._llm_callable(text) + return self._llm_callable(text, topics) def set_callable(self, llm_callable: Union[str, Callable, None]) -> None: """Set the LLM callable. @@ -273,17 +223,17 @@ def set_callable(self, llm_callable: Union[str, Callable, None]) -> None: """ if llm_callable is None: - llm_callable = "gpt-3.5-turbo" + llm_callable = "gpt-4o" if isinstance(llm_callable, str): - if llm_callable not in ["gpt-3.5-turbo", "gpt-4"]: + if llm_callable not in ["gpt-3.5-turbo", "gpt-4", "gpt-4o"]: raise ValueError( - "llm_callable must be one of 'gpt-3.5-turbo' or 'gpt-4'." + "llm_callable must be one of 'gpt-3.5-turbo', 'gpt-4', or 'gpt-4o'" "If you want to use a custom LLM, please provide a callable." "Check out ProvenanceV1 documentation for an example." ) - def openai_callable(text: str) -> str: + def openai_callable(text: str, topics: List[str]) -> str: api_key = self.get_client_args() client = OpenAI(api_key=api_key) response = client.chat.completions.create( @@ -292,28 +242,27 @@ def openai_callable(text: str) -> str: messages=[ { "role": "user", - "content": f"""Given a series of topics, determine if the topic is present in the provided text. Return the result as json. - - Text - ---- - {text} + "content": f""" + Given a text and a list of topics, return a valid json list of which topics are present in the text. If none, just return an empty list. + + Output Format: + ------------- + "topics_present": [] - Schema - ------ - {self._json_schema} + Text: + ---- + "{text}" - Complete Schema - --------------- + Topics: + ------ + {topics} - """, + Result: + ------ """, }, ], - tools=self._tools, ) - tool_calls = [] - for tool_call in response.choices[0].message.tool_calls: - tool_calls.append(json.loads(tool_call.function.arguments)) - return tool_calls + return json.loads(response.choices[0].message.content)["topics_present"] self._llm_callable = openai_callable elif isinstance(llm_callable, Callable): @@ -321,7 +270,7 @@ def openai_callable(text: str) -> str: else: raise ValueError("llm_callable must be a string or a Callable") - def get_topics_zero_shot(self, text: str, candidate_topics: List[str]) -> list[str]: + def get_topics_zero_shot(self, text: str, candidate_topics: List[str]) -> List[str]: """Gets the topics found through the zero shot classifier Args: @@ -329,7 +278,7 @@ def get_topics_zero_shot(self, text: str, candidate_topics: List[str]) -> list[s candidate_topics (List[str]): The potential topics to look for Returns: - list[str]: The resulting topics found that meet the given threshold + List[str]: The resulting topics found that meet the given threshold """ result = self._classifier(text, candidate_topics) topics = result["labels"]