Handoff termination and show how to use it for asking user input (#4128)
* Handoff termination and show how to use it for asking user input

* lint
ekzhu authored Nov 11, 2024
1 parent 9f17508 commit 4786f18
Showing 4 changed files with 204 additions and 6 deletions.
@@ -1,10 +1,17 @@
from ._console import Console
from ._terminations import MaxMessageTermination, StopMessageTermination, TextMentionTermination, TokenUsageTermination
from ._terminations import (
    HandoffTermination,
    MaxMessageTermination,
    StopMessageTermination,
    TextMentionTermination,
    TokenUsageTermination,
)

__all__ = [
    "MaxMessageTermination",
    "TextMentionTermination",
    "StopMessageTermination",
    "TokenUsageTermination",
    "HandoffTermination",
    "Console",
]
@@ -1,7 +1,7 @@
from typing import Sequence

from ..base import TerminatedException, TerminationCondition
from ..messages import AgentMessage, MultiModalMessage, StopMessage, TextMessage
from ..messages import AgentMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage


class StopMessageTermination(TerminationCondition):
@@ -144,3 +144,34 @@ async def reset(self) -> None:
        self._total_token_count = 0
        self._prompt_token_count = 0
        self._completion_token_count = 0


class HandoffTermination(TerminationCondition):
    """Terminate the conversation if a :class:`~autogen_agentchat.messages.HandoffMessage`
    with the given target is received.

    Args:
        target (str): The target of the handoff message.
    """

    def __init__(self, target: str) -> None:
        self._terminated = False
        self._target = target

    @property
    def terminated(self) -> bool:
        return self._terminated

    async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
        if self._terminated:
            raise TerminatedException("Termination condition has already been reached")
        for message in messages:
            if isinstance(message, HandoffMessage) and message.target == self._target:
                self._terminated = True
                return StopMessage(
                    content=f"Handoff to {self._target} from {message.source} detected.", source="HandoffTermination"
                )
        return None

    async def reset(self) -> None:
        self._terminated = False
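For quick reference, here is a minimal standalone sketch of how the new condition behaves on its own. It mirrors the unit test and notebook example added in this commit; the script itself is illustrative rather than part of the diff, and it assumes only the `autogen_agentchat.task` import path and the message constructors that appear elsewhere in this change.

```python
import asyncio

from autogen_agentchat.messages import HandoffMessage, TextMessage
from autogen_agentchat.task import HandoffTermination


async def main() -> None:
    # Stop only when a handoff to "user" is observed.
    termination = HandoffTermination(target="user")

    # A plain text message does not trigger the condition.
    print(await termination([TextMessage(content="Hello", source="assistant")]))  # None

    # A handoff to the configured target produces a StopMessage.
    stop = await termination([HandoffMessage(target="user", source="assistant", content="Transfer to user.")])
    print(stop.content if stop else None)  # Handoff to user from assistant detected.

    # Reset before reusing the condition for another run.
    await termination.reset()


asyncio.run(main())
```

As the notebook change below demonstrates, the condition also composes with other conditions, e.g. `HandoffTermination(target="user") | TextMentionTermination("TERMINATE")`, so a team stops either when the agent hands off to the user or when it declares the task complete.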
@@ -1,6 +1,7 @@
import pytest
from autogen_agentchat.messages import StopMessage, TextMessage
from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage
from autogen_agentchat.task import (
    HandoffTermination,
    MaxMessageTermination,
    StopMessageTermination,
    TextMentionTermination,
@@ -9,6 +10,32 @@
from autogen_core.components.models import RequestUsage


@pytest.mark.asyncio
async def test_handoff_termination() -> None:
    termination = HandoffTermination("target")
    assert await termination([]) is None
    await termination.reset()
    assert await termination([TextMessage(content="Hello", source="user")]) is None
    await termination.reset()
    assert await termination([HandoffMessage(target="target", source="user", content="Hello")]) is not None
    assert termination.terminated
    await termination.reset()
    assert await termination([HandoffMessage(target="another", source="user", content="Hello")]) is None
    assert not termination.terminated
    await termination.reset()
    assert (
        await termination(
            [
                TextMessage(content="Hello", source="user"),
                HandoffMessage(target="target", source="user", content="Hello"),
            ]
        )
        is not None
    )
    assert termination.terminated
    await termination.reset()


@pytest.mark.asyncio
async def test_stop_message_termination() -> None:
    termination = StopMessageTermination()
@@ -49,7 +49,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -404,8 +404,6 @@
}
],
"source": [
"from autogen_agentchat.task import Console\n",
"\n",
"# Use `asyncio.run(Console(reflection_team.run_stream(task=\"Write a short poem about fall season.\")))` when running in a script.\n",
"await Console(\n",
" reflection_team.run_stream(task=\"Write a short poem about fall season.\")\n",
@@ -593,6 +591,141 @@
"# Use the `asyncio.run(Console(reflection_team.run_stream()))` when running in a script.\n",
"await Console(reflection_team.run_stream())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pause for User Input\n",
"\n",
"Often times, team needs additional input from the application (i.e., user)\n",
"to continue processing the task.\n",
"You can use the {py:class}`~autogen_agentchat.task.HandoffTermination` termination condition\n",
"to stop the team when an agent sends a {py:class}`~autogen_agentchat.messages.HandoffMessage` message.\n",
"\n",
"Let's create a team with a single {py:class}`~autogen_agentchat.agents.AssistantAgent` agent\n",
"with a handoff setting.\n",
"\n",
"```{note}\n",
"The model used with {py:class}`~autogen_agentchat.agents.AssistantAgent`must support tool call\n",
"to use the handoff feature.\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"from autogen_agentchat.agents import AssistantAgent, Handoff\n",
"from autogen_agentchat.task import HandoffTermination, TextMentionTermination\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_ext.models import OpenAIChatCompletionClient\n",
"\n",
"# Create an OpenAI model client.\n",
"model_client = OpenAIChatCompletionClient(\n",
" model=\"gpt-4o-2024-08-06\",\n",
" # api_key=\"sk-...\", # Optional if you have an OPENAI_API_KEY env variable set.\n",
")\n",
"\n",
"# Create a lazy assistant agent that always hands off to the user.\n",
"lazy_agent = AssistantAgent(\n",
" \"lazy_assistant\",\n",
" model_client=model_client,\n",
" handoffs=[Handoff(target=\"user\", message=\"Transfer to user.\")],\n",
" system_message=\"Always transfer to user when you don't know the answer. Respond 'TERMINATE' when task is complete.\",\n",
")\n",
"\n",
"# Define a termination condition that checks for handoff message targetting helper and text \"TERMINATE\".\n",
"handoff_termination = HandoffTermination(target=\"user\")\n",
"text_termination = TextMentionTermination(\"TERMINATE\")\n",
"termination = handoff_termination | text_termination\n",
"\n",
"# Create a single-agent team.\n",
"lazy_agent_team = RoundRobinGroupChat([lazy_agent], termination_condition=termination)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's run the team with a task that requires additional input from the user\n",
"because the agent doesn't have relevant tools to continue processing the task."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- user ----------\n",
"What is the weather in New York?\n",
"---------- lazy_assistant ----------\n",
"[FunctionCall(id='call_YHm4KPjFIWZE95YrJWlJwcv4', arguments='{}', name='transfer_to_user')]\n",
"[Prompt tokens: 68, Completion tokens: 11]\n",
"---------- lazy_assistant ----------\n",
"[FunctionExecutionResult(content='Transfer to user.', call_id='call_YHm4KPjFIWZE95YrJWlJwcv4')]\n",
"---------- lazy_assistant ----------\n",
"Transfer to user.\n",
"---------- Summary ----------\n",
"Number of messages: 4\n",
"Finish reason: Handoff to user from lazy_assistant detected.\n",
"Total prompt tokens: 68\n",
"Total completion tokens: 11\n",
"Duration: 0.73 seconds\n"
]
}
],
"source": [
"from autogen_agentchat.task import Console\n",
"\n",
"# Use `asyncio.run(Console(lazy_agent_team.run_stream(task=\"What is the weather in New York?\")))` when running in a script.\n",
"await Console(lazy_agent_team.run_stream(task=\"What is the weather in New York?\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can see the team stopped due to the handoff message was detected.\n",
"Let's continue the team by providing the information the agent needs."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- user ----------\n",
"It is raining in New York.\n",
"---------- lazy_assistant ----------\n",
"I hope you stay dry! Is there anything else you would like to know or do?\n",
"[Prompt tokens: 108, Completion tokens: 19]\n",
"---------- lazy_assistant ----------\n",
"TERMINATE\n",
"[Prompt tokens: 134, Completion tokens: 4]\n",
"---------- Summary ----------\n",
"Number of messages: 3\n",
"Finish reason: Text 'TERMINATE' mentioned\n",
"Total prompt tokens: 242\n",
"Total completion tokens: 23\n",
"Duration: 6.77 seconds\n"
]
}
],
"source": [
"# Use `asyncio.run(Console(lazy_agent_team.run_stream(task=\"It is raining in New York.\")))` when running in a script.\n",
"await Console(lazy_agent_team.run_stream(task=\"It is raining in New York.\"))"
]
}
],
"metadata": {
