Skip to content

Commit

Permalink
Initial chat memory implementation (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
ekzhu authored Jun 8, 2024
1 parent 37cc6bc commit e99ad51
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 24 deletions.
3 changes: 3 additions & 0 deletions examples/chess_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from agnext.application import SingleThreadedAgentRuntime
from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent
from agnext.chat.memory import BufferedChatMemory
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
from agnext.chat.patterns.two_agent_chat import TwoAgentChat
from agnext.chat.types import TextMessage
Expand Down Expand Up @@ -175,6 +176,7 @@ def get_board_text() -> Annotated[str, "The current board state"]:
"Think about your strategy and call make_move(thinking, move) to make a move."
),
],
memory=BufferedChatMemory(buffer_size=10),
model_client=OpenAI(model="gpt-4-turbo"),
tools=black_tools,
)
Expand All @@ -190,6 +192,7 @@ def get_board_text() -> Annotated[str, "The current board state"]:
"Think about your strategy and call make_move(thinking, move) to make a move."
),
],
memory=BufferedChatMemory(buffer_size=10),
model_client=OpenAI(model="gpt-4-turbo"),
tools=white_tools,
)
Expand Down
5 changes: 5 additions & 0 deletions examples/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.memory import BufferedChatMemory
from agnext.chat.patterns.orchestrator_chat import OrchestratorChat
from agnext.chat.types import TextMessage
from agnext.components.models import OpenAI, SystemMessage
Expand Down Expand Up @@ -83,6 +84,7 @@ def software_development(runtime: AgentRuntime) -> OrchestratorChat: # type: ig
description="A developer that writes code.",
runtime=runtime,
system_messages=[SystemMessage("You are a Python developer.")],
memory=BufferedChatMemory(buffer_size=10),
model_client=OpenAI(model="gpt-4-turbo"),
)

Expand All @@ -109,6 +111,7 @@ def software_development(runtime: AgentRuntime) -> OrchestratorChat: # type: ig
SystemMessage("You are a product manager good at translating customer needs into software specifications."),
SystemMessage("You can use the search tool to find information on the web."),
],
memory=BufferedChatMemory(buffer_size=10),
model_client=OpenAI(model="gpt-4-turbo"),
tools=[SearchTool()],
)
Expand All @@ -118,6 +121,7 @@ def software_development(runtime: AgentRuntime) -> OrchestratorChat: # type: ig
description="A planner that organizes and schedules tasks.",
runtime=runtime,
system_messages=[SystemMessage("You are a planner of complex tasks.")],
memory=BufferedChatMemory(buffer_size=10),
model_client=OpenAI(model="gpt-4-turbo"),
)

Expand All @@ -128,6 +132,7 @@ def software_development(runtime: AgentRuntime) -> OrchestratorChat: # type: ig
system_messages=[
SystemMessage("You are an orchestrator that coordinates the team to complete a complex task.")
],
memory=BufferedChatMemory(buffer_size=10),
model_client=OpenAI(model="gpt-4-turbo"),
)

Expand Down
50 changes: 26 additions & 24 deletions src/agnext/chat/agents/chat_completion_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,30 @@
import json
from typing import Any, Coroutine, Dict, List, Mapping, Sequence, Tuple

from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.types import (
FunctionCallMessage,
Message,
Reset,
RespondNow,
ResponseFormat,
TextMessage,
)
from agnext.chat.utils import convert_messages_to_llm_messages
from agnext.components import (
from ...components import (
FunctionCall,
TypeRoutedAgent,
message_handler,
)
from agnext.components.models import (
from ...components.models import (
ChatCompletionClient,
FunctionExecutionResult,
FunctionExecutionResultMessage,
SystemMessage,
)
from agnext.components.tools import Tool
from agnext.core import AgentRuntime, CancellationToken
from ...components.tools import Tool
from ...core import AgentRuntime, CancellationToken
from ..memory import ChatMemory
from ..types import (
FunctionCallMessage,
Message,
Reset,
RespondNow,
ResponseFormat,
TextMessage,
)
from ..utils import convert_messages_to_llm_messages
from .base import BaseChatAgent


class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
Expand All @@ -34,32 +35,33 @@ def __init__(
description: str,
runtime: AgentRuntime,
system_messages: List[SystemMessage],
memory: ChatMemory,
model_client: ChatCompletionClient,
tools: Sequence[Tool] = [],
) -> None:
super().__init__(name, description, runtime)
self._system_messages = system_messages
self._client = model_client
self._chat_messages: List[Message] = []
self._memory = memory
self._tools = tools

@message_handler()
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
# Add a user message.
self._chat_messages.append(message)
self._memory.add_message(message)

@message_handler()
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
# Reset the chat messages.
self._chat_messages = []
self._memory.clear()

@message_handler()
async def on_respond_now(
self, message: RespondNow, cancellation_token: CancellationToken
) -> TextMessage | FunctionCallMessage:
# Get a response from the model.
response = await self._client.create(
self._system_messages + convert_messages_to_llm_messages(self._chat_messages, self.name),
self._system_messages + convert_messages_to_llm_messages(self._memory.get_messages(), self.name),
tools=self._tools,
json_output=message.response_format == ResponseFormat.json_object,
)
Expand All @@ -80,7 +82,7 @@ async def on_respond_now(
)
# Make an assistant message from the response.
response = await self._client.create(
self._system_messages + convert_messages_to_llm_messages(self._chat_messages, self.name),
self._system_messages + convert_messages_to_llm_messages(self._memory.get_messages(), self.name),
tools=self._tools,
json_output=message.response_format == ResponseFormat.json_object,
)
Expand All @@ -96,7 +98,7 @@ async def on_respond_now(
raise ValueError(f"Unexpected response: {response.content}")

# Add the response to the chat messages.
self._chat_messages.append(final_response)
self._memory.add_message(final_response)

# Return the response.
return final_response
Expand All @@ -109,7 +111,7 @@ async def on_tool_call_message(
raise ValueError("No tools available")

# Add a tool call message.
self._chat_messages.append(message)
self._memory.add_message(message)

# Execute the tool calls.
results: List[FunctionExecutionResult] = []
Expand Down Expand Up @@ -146,7 +148,7 @@ async def on_tool_call_message(
tool_call_result_msg = FunctionExecutionResultMessage(content=results)

# Add tool call result message.
self._chat_messages.append(tool_call_result_msg)
self._memory.add_message(tool_call_result_msg)

# Return the results.
return tool_call_result_msg
Expand All @@ -172,11 +174,11 @@ async def execute_function(
def save_state(self) -> Mapping[str, Any]:
return {
"description": self.description,
"chat_messages": self._chat_messages,
"memory": self._memory.save_state(),
"system_messages": self._system_messages,
}

def load_state(self, state: Mapping[str, Any]) -> None:
self._chat_messages = state["chat_messages"]
self._memory.load_state(state["memory"])
self._system_messages = state["system_messages"]
self._description = state["description"]
5 changes: 5 additions & 0 deletions src/agnext/chat/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ._base import ChatMemory
from ._buffered import BufferedChatMemory
from ._full import FullChatMemory

__all__ = ["ChatMemory", "FullChatMemory", "BufferedChatMemory"]
15 changes: 15 additions & 0 deletions src/agnext/chat/memory/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Any, List, Mapping, Protocol

from ..types import Message


class ChatMemory(Protocol):
def add_message(self, message: Message) -> None: ...

def get_messages(self) -> List[Message]: ...

def clear(self) -> None: ...

def save_state(self) -> Mapping[str, Any]: ...

def load_state(self, state: Mapping[str, Any]) -> None: ...
29 changes: 29 additions & 0 deletions src/agnext/chat/memory/_buffered.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Any, List, Mapping

from ..types import Message
from ._base import ChatMemory


class BufferedChatMemory(ChatMemory):
def __init__(self, buffer_size: int) -> None:
self._messages: List[Message] = []
self._buffer_size = buffer_size

def add_message(self, message: Message) -> None:
self._messages.append(message)

def get_messages(self) -> List[Message]:
return self._messages[-self._buffer_size :]

def clear(self) -> None:
self._messages = []

def save_state(self) -> Mapping[str, Any]:
return {
"messages": [message for message in self._messages],
"buffer_size": self._buffer_size,
}

def load_state(self, state: Mapping[str, Any]) -> None:
self._messages = state["messages"]
self._buffer_size = state["buffer_size"]
24 changes: 24 additions & 0 deletions src/agnext/chat/memory/_full.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Any, List, Mapping

from ..types import Message
from ._base import ChatMemory


class FullChatMemory(ChatMemory):
def __init__(self) -> None:
self._messages: List[Message] = []

def add_message(self, message: Message) -> None:
self._messages.append(message)

def get_messages(self) -> List[Message]:
return self._messages

def clear(self) -> None:
self._messages = []

def save_state(self) -> Mapping[str, Any]:
return {"messages": [message for message in self._messages]}

def load_state(self, state: Mapping[str, Any]) -> None:
self._messages = state["messages"]

0 comments on commit e99ad51

Please sign in to comment.