Skip to content

Commit df2cd36

Browse files
gagbsonichiyiranwu0
authored
Refactor GPTAssistantAgent (#632)
* Refactor GPTAssistantAgent constructor to handle instructions and overwrite_instructions flag - Ensure that `system_message` is always consistent with `instructions` - Ensure provided instructions are always used - Add option to permanently modify the instructions of the assistant * Improve default behavior * Add a test; add method to delete assistant * Add a new test for overwriting instructions * Add test case for when no instructions are given for existing assistant * Add pytest markers to test_gpt_assistant.py * add test in workflow * update * fix test_client_stream * comment out test_hierarchy_ --------- Co-authored-by: Chi Wang <[email protected]> Co-authored-by: kevin666aa <[email protected]>
1 parent ff41489 commit df2cd36

File tree

6 files changed

+246
-23
lines changed

6 files changed

+246
-23
lines changed

.github/workflows/contrib-openai.yml

+41
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,44 @@ jobs:
9797
with:
9898
file: ./coverage.xml
9999
flags: unittests
100+
GPTAssistantAgent:
101+
strategy:
102+
matrix:
103+
os: [ubuntu-latest]
104+
python-version: ["3.11"]
105+
runs-on: ${{ matrix.os }}
106+
environment: openai1
107+
steps:
108+
# checkout to pr branch
109+
- name: Checkout
110+
uses: actions/checkout@v3
111+
with:
112+
ref: ${{ github.event.pull_request.head.sha }}
113+
- name: Set up Python ${{ matrix.python-version }}
114+
uses: actions/setup-python@v4
115+
with:
116+
python-version: ${{ matrix.python-version }}
117+
- name: Install packages and dependencies
118+
run: |
119+
docker --version
120+
python -m pip install --upgrade pip wheel
121+
pip install -e .
122+
python -c "import autogen"
123+
pip install coverage pytest-asyncio
124+
- name: Install packages for test when needed
125+
run: |
126+
pip install docker
127+
- name: Coverage
128+
env:
129+
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
130+
AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
131+
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
132+
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
133+
run: |
134+
coverage run -a -m pytest test/agentchat/contrib/test_gpt_assistant.py
135+
coverage xml
136+
- name: Upload coverage to Codecov
137+
uses: codecov/codecov-action@v3
138+
with:
139+
file: ./coverage.xml
140+
flags: unittests

.github/workflows/contrib-tests.yml

+26
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,29 @@ jobs:
8282
if: matrix.python-version != '3.10'
8383
run: |
8484
pytest test/agentchat/contrib/test_compressible_agent.py
85+
86+
GPTAssistantAgent:
87+
runs-on: ${{ matrix.os }}
88+
strategy:
89+
fail-fast: false
90+
matrix:
91+
os: [ubuntu-latest, macos-latest, windows-2019]
92+
python-version: ["3.8", "3.9", "3.10", "3.11"]
93+
steps:
94+
- uses: actions/checkout@v3
95+
- name: Set up Python ${{ matrix.python-version }}
96+
uses: actions/setup-python@v4
97+
with:
98+
python-version: ${{ matrix.python-version }}
99+
- name: Install packages and dependencies for all tests
100+
run: |
101+
python -m pip install --upgrade pip wheel
102+
pip install pytest
103+
- name: Install packages and dependencies for GPTAssistantAgent
104+
run: |
105+
pip install -e .
106+
pip uninstall -y openai
107+
- name: Test GPTAssistantAgent
108+
if: matrix.python-version != '3.10'
109+
run: |
110+
pytest test/agentchat/contrib/test_gpt_assistant.py

autogen/agentchat/contrib/gpt_assistant_agent.py

+59-12
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from autogen import OpenAIWrapper
88
from autogen.agentchat.agent import Agent
99
from autogen.agentchat.assistant_agent import ConversableAgent
10+
from autogen.agentchat.assistant_agent import AssistantAgent
1011
from typing import Dict, Optional, Union, List, Tuple, Any
1112

1213
logger = logging.getLogger(__name__)
@@ -21,45 +22,76 @@ class GPTAssistantAgent(ConversableAgent):
2122
def __init__(
2223
self,
2324
name="GPT Assistant",
24-
instructions: Optional[str] = "You are a helpful GPT Assistant.",
25+
instructions: Optional[str] = None,
2526
llm_config: Optional[Union[Dict, bool]] = None,
27+
overwrite_instructions: bool = False,
2628
):
2729
"""
2830
Args:
2931
name (str): name of the agent.
3032
instructions (str): instructions for the OpenAI assistant configuration.
33+
When instructions is not None, the system message of the agent will be
34+
set to the provided instructions and used in the assistant run, irrespective
35+
of the overwrite_instructions flag. But when instructions is None,
36+
and the assistant does not exist, the system message will be set to
37+
AssistantAgent.DEFAULT_SYSTEM_MESSAGE. If the assistant exists, the
38+
system message will be set to the existing assistant instructions.
3139
llm_config (dict or False): llm inference configuration.
40+
- assistant_id: ID of the assistant to use. If None, a new assistant will be created.
3241
- model: Model to use for the assistant (gpt-4-1106-preview, gpt-3.5-turbo-1106).
3342
- check_every_ms: check thread run status interval
3443
- tools: Give Assistants access to OpenAI-hosted tools like Code Interpreter and Knowledge Retrieval,
3544
or build your own tools using Function calling. ref https://platform.openai.com/docs/assistants/tools
3645
- file_ids: files used by retrieval in run
46+
overwrite_instructions (bool): whether to overwrite the instructions of an existing assistant.
3747
"""
38-
super().__init__(
39-
name=name,
40-
system_message=instructions,
41-
human_input_mode="NEVER",
42-
llm_config=llm_config,
43-
)
44-
4548
# Use AutoGen OpenAIWrapper to create a client
46-
oai_wrapper = OpenAIWrapper(**self.llm_config)
49+
oai_wrapper = OpenAIWrapper(**llm_config)
4750
if len(oai_wrapper._clients) > 1:
4851
logger.warning("GPT Assistant only supports one OpenAI client. Using the first client in the list.")
4952
self._openai_client = oai_wrapper._clients[0]
50-
5153
openai_assistant_id = llm_config.get("assistant_id", None)
5254
if openai_assistant_id is None:
5355
# create a new assistant
56+
if instructions is None:
57+
logger.warning(
58+
"No instructions were provided for new assistant. Using default instructions from AssistantAgent.DEFAULT_SYSTEM_MESSAGE."
59+
)
60+
instructions = AssistantAgent.DEFAULT_SYSTEM_MESSAGE
5461
self._openai_assistant = self._openai_client.beta.assistants.create(
5562
name=name,
5663
instructions=instructions,
57-
tools=self.llm_config.get("tools", []),
58-
model=self.llm_config.get("model", "gpt-4-1106-preview"),
64+
tools=llm_config.get("tools", []),
65+
model=llm_config.get("model", "gpt-4-1106-preview"),
5966
)
6067
else:
6168
# retrieve an existing assistant
6269
self._openai_assistant = self._openai_client.beta.assistants.retrieve(openai_assistant_id)
70+
# if no instructions are provided, set the instructions to the existing instructions
71+
if instructions is None:
72+
logger.warning(
73+
"No instructions were provided for given assistant. Using existing instructions from assistant API."
74+
)
75+
instructions = self.get_assistant_instructions()
76+
elif overwrite_instructions is True:
77+
logger.warning(
78+
"overwrite_instructions is True. Provided instructions will be used and will modify the assistant in the API"
79+
)
80+
self._openai_assistant = self._openai_client.beta.assistants.update(
81+
assistant_id=openai_assistant_id,
82+
instructions=instructions,
83+
)
84+
else:
85+
logger.warning(
86+
"overwrite_instructions is False. Provided instructions will be used without permanently modifying the assistant in the API."
87+
)
88+
89+
super().__init__(
90+
name=name,
91+
system_message=instructions,
92+
human_input_mode="NEVER",
93+
llm_config=llm_config,
94+
)
6395

6496
# lazly create thread
6597
self._openai_threads = {}
@@ -107,6 +139,8 @@ def _invoke_assistant(
107139
run = self._openai_client.beta.threads.runs.create(
108140
thread_id=assistant_thread.id,
109141
assistant_id=self._openai_assistant.id,
142+
# pass the latest system message as instructions
143+
instructions=self.system_message,
110144
)
111145

112146
run_response_messages = self._get_run_response(assistant_thread, run)
@@ -300,3 +334,16 @@ def pretty_print_thread(self, thread):
300334
def oai_threads(self) -> Dict[Agent, Any]:
301335
"""Return the threads of the agent."""
302336
return self._openai_threads
337+
338+
@property
339+
def assistant_id(self):
340+
"""Return the assistant id"""
341+
return self._openai_assistant.id
342+
343+
def get_assistant_instructions(self):
344+
"""Return the assistant instructions from OAI assistant API"""
345+
return self._openai_assistant.instructions
346+
347+
def delete_assistant(self):
348+
"""Delete the assistant from OAI assistant API"""
349+
self._openai_client.beta.assistants.delete(self.assistant_id)

test/agentchat/contrib/test_gpt_assistant.py

+111-1
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,10 @@ def test_gpt_assistant_chat():
3838
"description": "This is an API endpoint allowing users (analysts) to input question about GitHub in text format to retrieve the realted and structured data.",
3939
}
4040

41+
config_list = autogen.config_list_from_json(OAI_CONFIG_LIST, file_location=KEY_LOC)
4142
analyst = GPTAssistantAgent(
4243
name="Open_Source_Project_Analyst",
43-
llm_config={"tools": [{"type": "function", "function": ossinsight_api_schema}]},
44+
llm_config={"tools": [{"type": "function", "function": ossinsight_api_schema}], "config_list": config_list},
4445
instructions="Hello, Open Source Project Analyst. You'll conduct comprehensive evaluations of open source projects or organizations on the GitHub platform",
4546
)
4647
analyst.register_function(
@@ -62,5 +63,114 @@ def test_gpt_assistant_chat():
6263
assert len(analyst._openai_threads) == 0
6364

6465

66+
@pytest.mark.skipif(
67+
sys.platform in ["darwin", "win32"] or skip_test,
68+
reason="do not run on MacOS or windows or dependency is not installed",
69+
)
70+
def test_get_assistant_instructions():
71+
"""
72+
Test function to create a new GPTAssistantAgent, set its instructions, retrieve the instructions,
73+
and assert that the retrieved instructions match the set instructions.
74+
"""
75+
76+
config_list = autogen.config_list_from_json(OAI_CONFIG_LIST, file_location=KEY_LOC)
77+
assistant = GPTAssistantAgent(
78+
"assistant",
79+
instructions="This is a test",
80+
llm_config={
81+
"config_list": config_list,
82+
},
83+
)
84+
85+
instruction_match = assistant.get_assistant_instructions() == "This is a test"
86+
assistant.delete_assistant()
87+
88+
assert instruction_match is True
89+
90+
91+
@pytest.mark.skipif(
92+
sys.platform in ["darwin", "win32"] or skip_test,
93+
reason="do not run on MacOS or windows or dependency is not installed",
94+
)
95+
def test_gpt_assistant_instructions_overwrite():
96+
"""
97+
Test that the instructions of a GPTAssistantAgent can be overwritten or not depending on the value of the
98+
`overwrite_instructions` parameter when creating a new assistant with the same ID.
99+
100+
Steps:
101+
1. Create a new GPTAssistantAgent with some instructions.
102+
2. Get the ID of the assistant.
103+
3. Create a new GPTAssistantAgent with the same ID but different instructions and `overwrite_instructions=True`.
104+
4. Check that the instructions of the assistant have been overwritten with the new ones.
105+
"""
106+
107+
instructions1 = "This is a test #1"
108+
instructions2 = "This is a test #2"
109+
110+
config_list = autogen.config_list_from_json(OAI_CONFIG_LIST, file_location=KEY_LOC)
111+
assistant = GPTAssistantAgent(
112+
"assistant",
113+
instructions=instructions1,
114+
llm_config={
115+
"config_list": config_list,
116+
},
117+
)
118+
119+
assistant_id = assistant.assistant_id
120+
assistant = GPTAssistantAgent(
121+
"assistant",
122+
instructions=instructions2,
123+
llm_config={
124+
"config_list": config_list,
125+
"assistant_id": assistant_id,
126+
},
127+
overwrite_instructions=True,
128+
)
129+
130+
instruction_match = assistant.get_assistant_instructions() == instructions2
131+
assistant.delete_assistant()
132+
133+
assert instruction_match is True
134+
135+
136+
@pytest.mark.skipif(
137+
sys.platform in ["darwin", "win32"] or skip_test,
138+
reason="do not run on MacOS or windows or dependency is not installed",
139+
)
140+
def test_gpt_assistant_existing_no_instructions():
141+
"""
142+
Test function to check if the GPTAssistantAgent can retrieve instructions for an existing assistant
143+
even if the assistant was created with no instructions initially.
144+
"""
145+
instructions = "This is a test #1"
146+
147+
config_list = autogen.config_list_from_json(OAI_CONFIG_LIST, file_location=KEY_LOC)
148+
assistant = GPTAssistantAgent(
149+
"assistant",
150+
instructions=instructions,
151+
llm_config={
152+
"config_list": config_list,
153+
},
154+
)
155+
156+
assistant_id = assistant.assistant_id
157+
158+
# create a new assistant with the same ID but no instructions
159+
assistant = GPTAssistantAgent(
160+
"assistant",
161+
llm_config={
162+
"config_list": config_list,
163+
"assistant_id": assistant_id,
164+
},
165+
)
166+
167+
instruction_match = assistant.get_assistant_instructions() == instructions
168+
assistant.delete_assistant()
169+
assert instruction_match is True
170+
171+
65172
if __name__ == "__main__":
66173
test_gpt_assistant_chat()
174+
test_get_assistant_instructions()
175+
test_gpt_assistant_instructions_overwrite()
176+
test_gpt_assistant_existing_no_instructions()

test/oai/test_client_stream.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_aoai_chat_completion_stream():
1818
filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo"]},
1919
)
2020
client = OpenAIWrapper(config_list=config_list)
21-
response = client.create(messages=[{"role": "user", "content": "2+2="}], seed=None, stream=True)
21+
response = client.create(messages=[{"role": "user", "content": "2+2="}], stream=True)
2222
print(response)
2323
print(client.extract_text_or_function_call(response))
2424

@@ -31,7 +31,7 @@ def test_chat_completion_stream():
3131
filter_dict={"model": ["gpt-3.5-turbo"]},
3232
)
3333
client = OpenAIWrapper(config_list=config_list)
34-
response = client.create(messages=[{"role": "user", "content": "1+1="}], seed=None, stream=True)
34+
response = client.create(messages=[{"role": "user", "content": "1+1="}], stream=True)
3535
print(response)
3636
print(client.extract_text_or_function_call(response))
3737

@@ -63,7 +63,6 @@ def test_chat_functions_stream():
6363
response = client.create(
6464
messages=[{"role": "user", "content": "What's the weather like today in San Francisco?"}],
6565
functions=functions,
66-
seed=None,
6766
stream=True,
6867
)
6968
print(response)
@@ -74,7 +73,7 @@ def test_chat_functions_stream():
7473
def test_completion_stream():
7574
config_list = config_list_openai_aoai(KEY_LOC)
7675
client = OpenAIWrapper(config_list=config_list)
77-
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", seed=None, stream=True)
76+
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)
7877
print(response)
7978
print(client.extract_text_or_function_call(response))
8079

test/test_notebook.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,12 @@ def _test_oai_chatgpt_gpt4(save=False):
8484
run_notebook("oai_chatgpt_gpt4.ipynb", save=save)
8585

8686

87-
@pytest.mark.skipif(
88-
skip or not sys.version.startswith("3.10"),
89-
reason="do not run if openai is not installed or py!=3.10",
90-
)
91-
def test_hierarchy_flow_using_select_speaker(save=False):
92-
run_notebook("agentchat_hierarchy_flow_using_select_speaker.ipynb", save=save)
87+
# @pytest.mark.skipif(
88+
# skip or not sys.version.startswith("3.10"),
89+
# reason="do not run if openai is not installed or py!=3.10",
90+
# )
91+
# def test_hierarchy_flow_using_select_speaker(save=False):
92+
# run_notebook("agentchat_hierarchy_flow_using_select_speaker.ipynb", save=save)
9393

9494

9595
if __name__ == "__main__":

0 commit comments

Comments
 (0)