Skip to content

process message before send #1783

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion autogen/agentchat/contrib/capabilities/context_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def add_to_agent(self, agent: ConversableAgent):
"""
Adds TransformChatHistory capability to the given agent.
"""
agent.register_hook(hookable_method="process_all_messages", hook=self._transform_messages)
agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)

def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
"""
Expand Down
4 changes: 2 additions & 2 deletions autogen/agentchat/contrib/capabilities/teachability.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def add_to_agent(self, agent: ConversableAgent):
self.teachable_agent = agent

# Register a hook for processing the last message.
agent.register_hook(hookable_method="process_last_message", hook=self.process_last_message)
agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)

# Was an llm_config passed to the constructor?
if self.llm_config is None:
Expand All @@ -82,7 +82,7 @@ def prepopulate_db(self):
"""Adds a few arbitrary memos to the DB."""
self.memo_store.prepopulate()

def process_last_message(self, text):
def process_last_received_message(self, text):
"""
Appends any relevant memos to the message text, and stores any apparent teachings in new memos.
Uses TextAnalyzerAgent to make decisions about memo storage and retrieval.
Expand Down
33 changes: 24 additions & 9 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,11 @@ def __init__(

# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
# New hookable methods should be added to this list as required to support new agent capabilities.
self.hook_lists = {"process_last_message": [], "process_all_messages": []}
self.hook_lists = {
"process_last_received_message": [],
"process_all_messages_before_reply": [],
"process_message_before_send": [],
}

@property
def name(self) -> str:
Expand Down Expand Up @@ -467,6 +471,15 @@ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id:
self._oai_messages[conversation_id].append(oai_message)
return True

def _process_message_before_send(
self, message: Union[Dict, str], recipient: Agent, silent: bool
) -> Union[Dict, str]:
"""Process the message before sending it to the recipient."""
hook_list = self.hook_lists["process_message_before_send"]
for hook in hook_list:
message = hook(message, recipient, silent)
return message

def send(
self,
message: Union[Dict, str],
Expand Down Expand Up @@ -509,6 +522,7 @@ def send(
Returns:
ChatResult: a ChatResult object.
"""
message = self._process_message_before_send(message, recipient, silent)
# When the agent composes and sends the message, the role of the message is "assistant"
# unless it's "function".
valid = self._append_oai_message(message, "assistant", recipient)
Expand Down Expand Up @@ -561,6 +575,7 @@ async def a_send(
Returns:
ChatResult: an ChatResult object.
"""
message = self._process_message_before_send(message, recipient, silent)
# When the agent composes and sends the message, the role of the message is "assistant"
# unless it's "function".
valid = self._append_oai_message(message, "assistant", recipient)
Expand Down Expand Up @@ -1634,11 +1649,11 @@ def generate_reply(

# Call the hookable method that gives registered hooks a chance to process all messages.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_all_messages(messages)
messages = self.process_all_messages_before_reply(messages)

# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_last_message(messages)
messages = self.process_last_received_message(messages)

for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
Expand Down Expand Up @@ -1695,11 +1710,11 @@ async def a_generate_reply(

# Call the hookable method that gives registered hooks a chance to process all messages.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_all_messages(messages)
messages = self.process_all_messages_before_reply(messages)

# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_last_message(messages)
messages = self.process_last_received_message(messages)

for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
Expand Down Expand Up @@ -2333,11 +2348,11 @@ def register_hook(self, hookable_method: str, hook: Callable):
assert hook not in hook_list, f"{hook} is already registered as a hook."
hook_list.append(hook)

def process_all_messages(self, messages: List[Dict]) -> List[Dict]:
def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]:
"""
Calls any registered capability hooks to process all messages, potentially modifying the messages.
"""
hook_list = self.hook_lists["process_all_messages"]
hook_list = self.hook_lists["process_all_messages_before_reply"]
# If no hooks are registered, or if there are no messages to process, return the original message list.
if len(hook_list) == 0 or messages is None:
return messages
Expand All @@ -2348,14 +2363,14 @@ def process_all_messages(self, messages: List[Dict]) -> List[Dict]:
processed_messages = hook(processed_messages)
return processed_messages

def process_last_message(self, messages):
def process_last_received_message(self, messages):
"""
Calls any registered capability hooks to use and potentially modify the text of the last message,
as long as the last message is not a function call or exit command.
"""

# If any required condition is not met, return the original message list.
hook_list = self.hook_lists["process_last_message"]
hook_list = self.hook_lists["process_last_received_message"]
if len(hook_list) == 0:
return messages # No hooks registered.
if messages is None:
Expand Down
21 changes: 20 additions & 1 deletion test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,11 +1074,30 @@ def test_max_turn():
assert len(res.chat_history) <= 6


def test_process_before_send():
print_mock = unittest.mock.MagicMock()

def send_to_frontend(message, recipient, silent):
if not silent:
print(f"Message sent to {recipient.name}: {message}")
print_mock(message=message)
return message

dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
dummy_agent_2 = ConversableAgent(name="dummy_agent_2", llm_config=False, human_input_mode="NEVER")
dummy_agent_1.register_hook("process_message_before_send", send_to_frontend)
dummy_agent_1.send("hello", dummy_agent_2)
print_mock.assert_called_once_with(message="hello")
dummy_agent_1.send("silent hello", dummy_agent_2, silent=True)
print_mock.assert_called_once_with(message="hello")


if __name__ == "__main__":
# test_trigger()
# test_context()
# test_max_consecutive_auto_reply()
# test_generate_code_execution_reply()
# test_conversable_agent()
# test_no_llm_config()
test_max_turn()
# test_max_turn()
test_process_before_send()
Loading