Skip to content

Commit

Permalink
Update termination logic (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
thinkall authored Oct 9, 2023
1 parent 46ab5b8 commit 95e4c58
Showing 1 changed file with 29 additions and 25 deletions.
54 changes: 29 additions & 25 deletions autogen/agentchat/contrib/retrieve_user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,26 +62,10 @@ def colored(x, *args, **kwargs):
"""


def _is_termination_msg_retrievechat(message):
"""Check if a message is a termination message."""
if isinstance(message, dict):
message = message.get("content")
if message is None:
return False
cb = extract_code(message)
contain_code = False
for c in cb:
if c[0] == "python":
contain_code = True
break
return not contain_code


class RetrieveUserProxyAgent(UserProxyAgent):
def __init__(
self,
name="RetrieveChatAgent", # default set to RetrieveChatAgent
is_termination_msg: Optional[Callable[[Dict], bool]] = _is_termination_msg_retrievechat,
human_input_mode: Optional[str] = "ALWAYS",
retrieve_config: Optional[Dict] = None, # config for the retrieve agent
**kwargs,
Expand Down Expand Up @@ -135,7 +119,6 @@ def __init__(
"""
super().__init__(
name=name,
is_termination_msg=is_termination_msg,
human_input_mode=human_input_mode,
**kwargs,
)
Expand Down Expand Up @@ -164,7 +147,27 @@ def __init__(
self._intermediate_answers = set() # the intermediate answers
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc
self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply)
self._is_termination_msg = self._is_termination_msg_retrievechat # update the termination message function
self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply, position=1)

def _is_termination_msg_retrievechat(self, message):
"""Check if a message is a termination message.
For code generation, terminate when no code block is detected. Currently only detect python code blocks.
For question answering, terminate when don't update context, i.e., answer is given.
"""
if isinstance(message, dict):
message = message.get("content")
if message is None:
return False
cb = extract_code(message)
contain_code = False
for c in cb:
# todo: support more languages
if c[0] == "python":
contain_code = True
break
update_context_case1, update_context_case2 = self._check_update_context(message)
return not (contain_code or update_context_case1 or update_context_case2)

@staticmethod
def get_max_tokens(model="gpt-3.5-turbo"):
Expand Down Expand Up @@ -231,6 +234,13 @@ def _generate_message(self, doc_contents, task="default"):
raise NotImplementedError(f"task {task} is not implemented.")
return message

def _check_update_context(self, message):
if isinstance(message, dict):
message = message.get("content", "")
update_context_case1 = "UPDATE CONTEXT" in message[-20:].upper() or "UPDATE CONTEXT" in message[:20].upper()
update_context_case2 = self.customized_answer_prefix and self.customized_answer_prefix not in message.upper()
return update_context_case1, update_context_case2

def _generate_retrieve_user_reply(
self,
messages: Optional[List[Dict]] = None,
Expand All @@ -247,13 +257,7 @@ def _generate_retrieve_user_reply(
if messages is None:
messages = self._oai_messages[sender]
message = messages[-1]
update_context_case1 = (
"UPDATE CONTEXT" in message.get("content", "")[-20:].upper()
or "UPDATE CONTEXT" in message.get("content", "")[:20].upper()
)
update_context_case2 = (
self.customized_answer_prefix and self.customized_answer_prefix not in message.get("content", "").upper()
)
update_context_case1, update_context_case2 = self._check_update_context(message)
if (update_context_case1 or update_context_case2) and self.update_context:
print(colored("Updating context and resetting conversation.", "green"), flush=True)
# extract the first sentence in the response as the intermediate answer
Expand Down

0 comments on commit 95e4c58

Please sign in to comment.