Skip to content

Commit

Permalink
fix some bugs (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
ekzhu authored May 24, 2024
1 parent 52f6f79 commit d941a0a
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 17 deletions.
23 changes: 13 additions & 10 deletions examples/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@

import openai
from agnext.agent_components.models_clients.openai_client import OpenAI
from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime
from agnext.application_components.single_threaded_agent_runtime import (
SingleThreadedAgentRuntime,
)
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.messages import ChatMessage
from agnext.chat.patterns.group_chat import GroupChat
from agnext.chat.patterns.orchestrator import Orchestrator


async def group_chat() -> None:
async def group_chat(message: str) -> None:
runtime = SingleThreadedAgentRuntime()

joe_oai_assistant = openai.beta.assistants.create(
Expand Down Expand Up @@ -44,22 +46,22 @@ async def group_chat() -> None:
)

chat = GroupChat(
"chat_room",
"Host",
"A round-robin chat room.",
runtime,
[joe, cathy],
num_rounds=5,
)

response = runtime.send_message(ChatMessage(body="Run a show!", sender="external"), chat)
response = runtime.send_message(ChatMessage(body=message, sender="host"), chat)

while not response.done():
await runtime.process_next()

print((await response).body) # type: ignore


async def orchestrator() -> None:
async def orchestrator(message: str) -> None:
runtime = SingleThreadedAgentRuntime()

developer_oai_assistant = openai.beta.assistants.create(
Expand Down Expand Up @@ -93,16 +95,16 @@ async def orchestrator() -> None:
)

chat = Orchestrator(
"Team",
"A software development team.",
"Manager",
"A software development team manager.",
runtime,
[developer, product_manager],
model_client=OpenAI(model="gpt-3.5-turbo"),
)

response = runtime.send_message(
ChatMessage(
body="Write a simple FastAPI webapp for showing the current time.",
body=message,
sender="customer",
),
chat,
Expand All @@ -122,11 +124,12 @@ async def orchestrator() -> None:
choices=chocies,
help="The pattern to demo.",
)
parser.add_argument("--message", help="The message to send.")
args = parser.parse_args()

if args.pattern == "group_chat":
asyncio.run(group_chat())
asyncio.run(group_chat(args.message))
elif args.pattern == "orchestrator":
asyncio.run(orchestrator())
asyncio.run(orchestrator(args.message))
else:
raise ValueError(f"Invalid pattern: {args.pattern}")
10 changes: 9 additions & 1 deletion src/agnext/chat/patterns/group_chat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import List, Sequence

from ...agent_components.type_routed_agent import message_handler
from ...core.agent_runtime import AgentRuntime
from ...core.cancellation_token import CancellationToken
from ..agents.base import BaseChatAgent
from ..messages import ChatMessage

Expand All @@ -19,7 +21,13 @@ def __init__(
self._num_rounds = num_rounds
self._history: List[ChatMessage] = []

async def on_chat_message(self, message: ChatMessage) -> ChatMessage:
@message_handler(ChatMessage)
async def on_chat_message(
self,
message: ChatMessage,
require_response: bool,
cancellation_token: CancellationToken,
) -> ChatMessage | None:
if message.reset:
# Reset the history.
self._history = []
Expand Down
12 changes: 7 additions & 5 deletions src/agnext/chat/patterns/orchestrator.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import json
from typing import Any, List, Sequence, Tuple

from agnext.core.agent_runtime import AgentRuntime
from agnext.core.cancellation_token import CancellationToken

from ...agent_components.model_client import ModelClient
from ...agent_components.type_routed_agent import message_handler
from ...agent_components.types import AssistantMessage, LLMMessage, UserMessage
from ...core.agent_runtime import AgentRuntime
from ...core.cancellation_token import CancellationToken
from ..agents.base import BaseChatAgent
from ..messages import ChatMessage

Expand All @@ -33,8 +32,11 @@ def __init__(

@message_handler(ChatMessage)
async def on_chat_message(
self, message: ChatMessage, require_response: bool, cancellation_token: CancellationToken
) -> ChatMessage:
self,
message: ChatMessage,
require_response: bool,
cancellation_token: CancellationToken,
) -> ChatMessage | None:
# A task is received.
task = message.body

Expand Down
3 changes: 2 additions & 1 deletion src/agnext/chat/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import List, Optional, Union

from typing_extensions import Literal

from agnext.agent_components.types import AssistantMessage, LLMMessage, UserMessage
from agnext.chat.types import FunctionCallMessage, Message, MultiModalMessage, TextMessage
from typing_extensions import Literal


def convert_content_message_to_assistant_message(
Expand Down

0 comments on commit d941a0a

Please sign in to comment.