From 6b67d173f034333b4e34acc9068fd124545c7981 Mon Sep 17 00:00:00 2001
From: Ian
Date: Sat, 11 Nov 2023 21:54:18 +0800
Subject: [PATCH] Introducing Experimental GPT Assistant Agent in AutoGen
(#616)
* add gpt assistant agent
* complete code
* Inherit class ConversableAgent
* format code
* add code comments
* add test case
* format code
* fix test
* format code
* Improve GPTAssistant
* Use OpenAIWrapper to create client
* Implement clear_history()
* Reply message formatting improvements
* Handle the case when content contains image files
* README update
* Fix doc string of methods
* add multiple conversations support
* Add GPT Assistant Agent into README
* fix test
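
A minimal usage sketch for the new agent (the config loading and the
message below are illustrative, not part of this patch):

    import autogen
    from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent

    # llm_config accepts the options documented in GPTAssistantAgent.__init__;
    # set "assistant_id" to reuse an existing OpenAI assistant instead of creating one
    assistant = GPTAssistantAgent(
        name="assistant",
        instructions="You are a helpful GPT Assistant.",
        llm_config={"config_list": autogen.config_list_from_json("OAI_CONFIG_LIST")},
    )

    user_proxy = autogen.UserProxyAgent(
        name="user_proxy",
        human_input_mode="NEVER",
        max_consecutive_auto_reply=1,
        code_execution_config=False,
    )
    user_proxy.initiate_chat(assistant, message="Say hello.")

    # custom functions can be exposed via llm_config["tools"] and registered with
    # assistant.register_function(...), as exercised in the new test below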
---------
Co-authored-by: gagb
Co-authored-by: Beibin Li
---
README.md | 4 +-
 .../agentchat/contrib/gpt_assistant_agent.py | 307 ++++++++++++++++++
test/agentchat/contrib/test_gpt_assistant.py | 66 ++++
 3 files changed, 376 insertions(+), 1 deletion(-)
create mode 100644 autogen/agentchat/contrib/gpt_assistant_agent.py
create mode 100644 test/agentchat/contrib/test_gpt_assistant.py
diff --git a/README.md b/README.md
index 8bb66f7ce8f5..089f3f7ab872 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,9 @@ This project is a spinoff from [FLAML](https://github.com/microsoft/FLAML).
-->
-:fire: Nov 8: AutoGen is selected into [Open100: Top 100 Open Source achievements](https://www.benchcouncil.org/evaluation/opencs/annual.html) 35 days after spinoff.
+:fire: Nov 11: AutoGen experimentally supports OpenAI's Assistants! Check out the [GPT Assistant Agent](autogen/agentchat/contrib/gpt_assistant_agent.py).
+
+:fire: Nov 8: AutoGen is selected into [Open100: Top 100 Open Source achievements](https://www.benchcouncil.org/evaluation/opencs/annual.html) 40 days after release.
:fire: Nov 6: AutoGen is mentioned by Satya Nadella in a [fireside chat](https://youtu.be/0pLBvgYtv6U) around 13:20.
diff --git a/autogen/agentchat/contrib/gpt_assistant_agent.py b/autogen/agentchat/contrib/gpt_assistant_agent.py
new file mode 100644
index 000000000000..50f7d4799a46
--- /dev/null
+++ b/autogen/agentchat/contrib/gpt_assistant_agent.py
@@ -0,0 +1,307 @@
+from collections import defaultdict
+import openai
+import json
+import time
+import logging
+
+from autogen import OpenAIWrapper
+from autogen.agentchat.agent import Agent
+from autogen.agentchat.assistant_agent import ConversableAgent
+from typing import Dict, Optional, Union, List, Tuple, Any
+
+logger = logging.getLogger(__name__)
+
+
+class GPTAssistantAgent(ConversableAgent):
+ """
+ An experimental AutoGen agent class that leverages the OpenAI Assistant API for conversational capabilities.
+    Unlike other agents such as ConversableAgent, it relies on the OpenAI Assistant API for state management.
+ """
+
+ def __init__(
+ self,
+ name="GPT Assistant",
+ instructions: Optional[str] = "You are a helpful GPT Assistant.",
+ llm_config: Optional[Union[Dict, bool]] = None,
+ ):
+ """
+ Args:
+ name (str): name of the agent.
+ instructions (str): instructions for the OpenAI assistant configuration.
+ llm_config (dict or False): llm inference configuration.
+ - model: Model to use for the assistant (gpt-4-1106-preview, gpt-3.5-turbo-1106).
+                - check_every_ms: interval in milliseconds at which to poll the thread run status.
+                - tools: Give Assistants access to OpenAI-hosted tools like Code Interpreter and Knowledge Retrieval,
+                    or build your own tools using Function calling. See https://platform.openai.com/docs/assistants/tools.
+                - file_ids: IDs of files used by the retrieval tool during runs.
+ """
+ super().__init__(
+ name=name,
+ system_message=instructions,
+ human_input_mode="NEVER",
+ llm_config=llm_config,
+ )
+
+ # Use AutoGen OpenAIWrapper to create a client
+ oai_wrapper = OpenAIWrapper(**self.llm_config)
+ if len(oai_wrapper._clients) > 1:
+ logger.warning("GPT Assistant only supports one OpenAI client. Using the first client in the list.")
+ self._openai_client = oai_wrapper._clients[0]
+
+ openai_assistant_id = llm_config.get("assistant_id", None)
+ if openai_assistant_id is None:
+ # create a new assistant
+ self._openai_assistant = self._openai_client.beta.assistants.create(
+ name=name,
+ instructions=instructions,
+ tools=self.llm_config.get("tools", []),
+ model=self.llm_config.get("model", "gpt-4-1106-preview"),
+ )
+ else:
+ # retrieve an existing assistant
+ self._openai_assistant = self._openai_client.beta.assistants.retrieve(openai_assistant_id)
+
+        # lazily create threads (one per conversation partner)
+ self._openai_threads = {}
+ self._unread_index = defaultdict(int)
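+        # route replies to messages from any Agent through the OpenAI assistant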
+ self.register_reply(Agent, GPTAssistantAgent._invoke_assistant)
+
+ def _invoke_assistant(
+ self,
+ messages: Optional[List[Dict]] = None,
+ sender: Optional[Agent] = None,
+ config: Optional[Any] = None,
+ ) -> Tuple[bool, Union[str, Dict, None]]:
+ """
+ Invokes the OpenAI assistant to generate a reply based on the given messages.
+
+ Args:
+ messages: A list of messages in the conversation history with the sender.
+ sender: The agent instance that sent the message.
+ config: Optional configuration for message processing.
+
+ Returns:
+ A tuple containing a boolean indicating success and the assistant's reply.
+ """
+
+ if messages is None:
+ messages = self._oai_messages[sender]
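+        # forward only the messages the assistant has not seen yet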
+ unread_index = self._unread_index[sender] or 0
+ pending_messages = messages[unread_index:]
+
+ # Check and initiate a new thread if necessary
+ if self._openai_threads.get(sender, None) is None:
+ self._openai_threads[sender] = self._openai_client.beta.threads.create(
+ messages=[],
+ )
+ assistant_thread = self._openai_threads[sender]
+ # Process each unread message
+ for message in pending_messages:
+ self._openai_client.beta.threads.messages.create(
+ thread_id=assistant_thread.id,
+ content=message["content"],
+ role=message["role"],
+ )
+
+ # Create a new run to get responses from the assistant
+ run = self._openai_client.beta.threads.runs.create(
+ thread_id=assistant_thread.id,
+ assistant_id=self._openai_assistant.id,
+ )
+
+ run_response_messages = self._get_run_response(assistant_thread, run)
+ assert len(run_response_messages) > 0, "No response from the assistant."
+
+ response = {
+ "role": run_response_messages[-1]["role"],
+ "content": "",
+ }
+ for message in run_response_messages:
+            # TODO: decide whether to log or otherwise surface intermediate messages
+            # if the current response is non-empty and more messages follow, separate them with a blank line
+ if len(response["content"]) > 0:
+ response["content"] += "\n\n"
+ response["content"] += message["content"]
+
+ self._unread_index[sender] = len(self._oai_messages[sender]) + 1
+ return True, response
+
+    def _get_run_response(self, thread, run):
+        """
+        Waits for and processes the response of a run from the OpenAI assistant.
+
+        Args:
+            thread: The thread the run belongs to.
+            run: The run object initiated with the OpenAI assistant.
+
+        Returns:
+            A list of new response messages produced by the run.
+        """
+ while True:
+ run = self._wait_for_run(run.id, thread.id)
+ if run.status == "completed":
+ response_messages = self._openai_client.beta.threads.messages.list(thread.id, order="asc")
+
+ new_messages = []
+ for msg in response_messages:
+ if msg.run_id == run.id:
+ for content in msg.content:
+ if content.type == "text":
+ new_messages.append(
+ {"role": msg.role, "content": self._format_assistant_message(content.text)}
+ )
+ elif content.type == "image_file":
+ new_messages.append(
+ {
+ "role": msg.role,
+ "content": f"Recieved file id={content.image_file.file_id}",
+ }
+ )
+ return new_messages
+ elif run.status == "requires_action":
+ actions = []
+ for tool_call in run.required_action.submit_tool_outputs.tool_calls:
+ function = tool_call.function
+ is_exec_success, tool_response = self.execute_function(function.dict())
+ tool_response["metadata"] = {
+ "tool_call_id": tool_call.id,
+ "run_id": run.id,
+ "thread_id": thread.id,
+ }
+
+ logger.info(
+ "Intermediate executing(%s, Sucess: %s) : %s",
+ tool_response["name"],
+ is_exec_success,
+ tool_response["content"],
+ )
+ actions.append(tool_response)
+
+ submit_tool_outputs = {
+ "tool_outputs": [
+ {"output": action["content"], "tool_call_id": action["metadata"]["tool_call_id"]}
+ for action in actions
+ ],
+ "run_id": run.id,
+ "thread_id": thread.id,
+ }
+
+ run = self._openai_client.beta.threads.runs.submit_tool_outputs(**submit_tool_outputs)
+ else:
+ run_info = json.dumps(run.dict(), indent=2)
+ raise ValueError(f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})")
+
+ def _wait_for_run(self, run_id: str, thread_id: str) -> Any:
+ """
+ Waits for a run to complete or reach a final state.
+
+ Args:
+ run_id: The ID of the run.
+ thread_id: The ID of the thread associated with the run.
+
+ Returns:
+ The updated run object after completion or reaching a final state.
+ """
+ in_progress = True
+ while in_progress:
+ run = self._openai_client.beta.threads.runs.retrieve(run_id, thread_id=thread_id)
+ in_progress = run.status in ("in_progress", "queued")
+ if in_progress:
+ time.sleep(self.llm_config.get("check_every_ms", 1000) / 1000)
+ return run
+
+ def _format_assistant_message(self, message_content):
+ """
+ Formats the assistant's message to include annotations and citations.
+ """
+
+ annotations = message_content.annotations
+ citations = []
+
+ # Iterate over the annotations and add footnotes
+ for index, annotation in enumerate(annotations):
+ # Replace the text with a footnote
+ message_content.value = message_content.value.replace(annotation.text, f" [{index}]")
+
+ # Gather citations based on annotation attributes
+ if file_citation := getattr(annotation, "file_citation", None):
+ try:
+ cited_file = self._openai_client.files.retrieve(file_citation.file_id)
+ citations.append(f"[{index}] {cited_file.filename}: {file_citation.quote}")
+ except Exception as e:
+ logger.error(f"Error retrieving file citation: {e}")
+ elif file_path := getattr(annotation, "file_path", None):
+ try:
+ cited_file = self._openai_client.files.retrieve(file_path.file_id)
+ citations.append(f"[{index}] Click to download {cited_file.filename}")
+ except Exception as e:
+ logger.error(f"Error retrieving file citation: {e}")
+ # Note: File download functionality not implemented above for brevity
+
+ # Add footnotes to the end of the message before displaying to user
+ message_content.value += "\n" + "\n".join(citations)
+ return message_content.value
+
+ def can_execute_function(self, name: str) -> bool:
+ """Whether the agent can execute the function."""
+ return False
+
+ def reset(self):
+ """
+ Resets the agent, clearing any existing conversation thread and unread message indices.
+ """
+ super().reset()
+ for thread in self._openai_threads.values():
+ # Delete the existing thread to start fresh in the next conversation
+ self._openai_client.beta.threads.delete(thread.id)
+ self._openai_threads = {}
+ # Clear the record of unread messages
+ self._unread_index.clear()
+
+ def clear_history(self, agent: Optional[Agent] = None):
+ """Clear the chat history of the agent.
+
+ Args:
+            agent: the agent with whom to clear the chat history. If None, clear the chat history with all agents.
+ """
+ super().clear_history(agent)
+ if self._openai_threads.get(agent, None) is not None:
+ # Delete the existing thread to start fresh in the next conversation
+ thread = self._openai_threads[agent]
+ logger.info("Clearing thread %s", thread.id)
+ self._openai_client.beta.threads.delete(thread.id)
+ self._openai_threads.pop(agent)
+ self._unread_index[agent] = 0
+
+ def pretty_print_thread(self, thread):
+ """Pretty print the thread."""
+ if thread is None:
+ print("No thread to print")
+ return
+        # NOTE: the returned list may not be ordered; sort by created_at to restore chronology
+ messages = self._openai_client.beta.threads.messages.list(
+ thread_id=thread.id,
+ )
+ messages = sorted(messages.data, key=lambda x: x.created_at)
+ print("~~~~~~~THREAD CONTENTS~~~~~~~")
+ for message in messages:
+ content_types = [content.type for content in message.content]
+ print(f"[{message.created_at}]", message.role, ": [", ", ".join(content_types), "]")
+ for content in message.content:
+ content_type = content.type
+ if content_type == "text":
+ print(content.type, ": ", content.text.value)
+ elif content_type == "image_file":
+ print(content.type, ": ", content.image_file.file_id)
+ else:
+ print(content.type, ": ", content)
+ print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
+
+ @property
+ def oai_threads(self) -> Dict[Agent, Any]:
+ """Return the threads of the agent."""
+ return self._openai_threads
diff --git a/test/agentchat/contrib/test_gpt_assistant.py b/test/agentchat/contrib/test_gpt_assistant.py
new file mode 100644
index 000000000000..bf8d2e227b7f
--- /dev/null
+++ b/test/agentchat/contrib/test_gpt_assistant.py
@@ -0,0 +1,66 @@
+import pytest
+import os
+import sys
+import autogen
+
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402
+
+try:
+ from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent
+
+ skip_test = False
+except ImportError:
+ skip_test = True
+
+
+def ask_ossinsight(question):
+ return f"That is a good question, but I don't know the answer yet. Please ask your human developer friend to help you. \n\n{question}"
+
+
+@pytest.mark.skipif(
+ sys.platform in ["darwin", "win32"] or skip_test,
+ reason="do not run on MacOS or windows or dependency is not installed",
+)
+def test_gpt_assistant_chat():
+ ossinsight_api_schema = {
+ "name": "ossinsight_data_api",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "question": {
+ "type": "string",
+ "description": "Enter your GitHub data question in the form of a clear and specific question to ensure the returned data is accurate and valuable. For optimal results, specify the desired format for the data table in your request.",
+ }
+ },
+ "required": ["question"],
+ },
+ "description": "This is an API endpoint allowing users (analysts) to input question about GitHub in text format to retrieve the realted and structured data.",
+ }
+
+ analyst = GPTAssistantAgent(
+ name="Open_Source_Project_Analyst",
+ llm_config={"tools": [{"type": "function", "function": ossinsight_api_schema}]},
+ instructions="Hello, Open Source Project Analyst. You'll conduct comprehensive evaluations of open source projects or organizations on the GitHub platform",
+ )
+ analyst.register_function(
+ function_map={
+ "ossinsight_data_api": ask_ossinsight,
+ }
+ )
+
+ ok, response = analyst._invoke_assistant(
+ [{"role": "user", "content": "What is the most popular open source project on GitHub?"}]
+ )
+ assert ok is True
+ assert response.get("role", "") == "assistant"
+ assert len(response.get("content", "")) > 0
+
+ assert analyst.can_execute_function("ossinsight_data_api") is False
+
+ analyst.reset()
+ assert len(analyst._openai_threads) == 0
+
+
+if __name__ == "__main__":
+ test_gpt_assistant_chat()