Memory Interface in AgentChat (#4438)
* initial base memory impl

* update, add example with chromadb

* include mimetype consideration

* add transform method

* update to address feedback, will update after 4681 is merged

* update memory impl,

* remove chroma db, typing fixes

* format, add test

* update uv lock

* update docs

* format updates

* update notebook

* add MemoryQueryEvent message, yield message for observability.

* minor fixes, make score optional/none

* Update python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py

Co-authored-by: Eric Zhu <[email protected]>

* update tests to improve cov

* refactor, move memory to core.

* format fixes

* format updates

* format updates

* fix azure notebook import, other fixes

* update notebook, support str query in Memory protocol

* update test

* update cells

* add specific extensible return types to memory query and update_context

---------

Co-authored-by: Eric Zhu <[email protected]>
victordibia and ekzhu authored Jan 14, 2025
1 parent d883e3d commit abbdbb2
Showing 11 changed files with 995 additions and 279 deletions.
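
For orientation before the diff hunks: the commit threads a new `memory` parameter through `AssistantAgent` and surfaces memory hits as `MemoryQueryEvent` messages. The following is a minimal usage sketch assembled from the test code further down; the `OpenAIChatCompletionClient` import path is an assumption and is not part of this diff.

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import MemoryQueryEvent
from autogen_core.memory import ListMemory, MemoryContent, MemoryMimeType
from autogen_ext.models.openai import OpenAIChatCompletionClient  # assumed import path


async def main() -> None:
    # Seed a simple list-backed memory store with one text item.
    memory = ListMemory()
    await memory.add(MemoryContent(content="The user prefers metric units.", mime_type=MemoryMimeType.TEXT))

    # Pass the store to the agent via the new `memory` parameter.
    agent = AssistantAgent(
        "assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-05-13"),
        memory=[memory],
    )

    # Memory hits are injected into the model context and also emitted
    # as MemoryQueryEvent messages for observability.
    result = await agent.run(task="How far is Paris from Berlin?")
    for message in result.messages:
        if isinstance(message, MemoryQueryEvent):
            print("memory used:", message.content)


asyncio.run(main())
```
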
@@ -14,6 +14,7 @@
)

from autogen_core import CancellationToken, FunctionCall
from autogen_core.memory import Memory
from autogen_core.model_context import (
ChatCompletionContext,
UnboundedChatCompletionContext,
@@ -35,6 +36,7 @@
AgentEvent,
ChatMessage,
HandoffMessage,
MemoryQueryEvent,
MultiModalMessage,
TextMessage,
ToolCallExecutionEvent,
@@ -120,6 +122,7 @@ class AssistantAgent(BaseChatAgent):
will be returned as the response.
Available variables: `{tool_name}`, `{arguments}`, `{result}`.
For example, `"{tool_name}: {result}"` will create a summary like `"tool_name: result"`.
memory (Sequence[Memory] | None, optional): The memory store to use for the agent. Defaults to `None`.
Raises:
ValueError: If tool names are not unique.
@@ -240,9 +243,20 @@ def __init__(
) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
reflect_on_tool_use: bool = False,
tool_call_summary_format: str = "{result}",
memory: Sequence[Memory] | None = None,
):
super().__init__(name=name, description=description)
self._model_client = model_client
self._memory = None
if memory is not None:
if isinstance(memory, list):
self._memory = memory
else:
raise TypeError(f"Expected Memory, List[Memory], or None, got {type(memory)}")

self._system_messages: List[
SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage
] = []
if system_message is None:
self._system_messages = []
else:
@@ -325,6 +339,17 @@ async def on_messages_stream(
# Inner messages.
inner_messages: List[AgentEvent | ChatMessage] = []

# Update the model context with memory content.
if self._memory:
for memory in self._memory:
update_context_result = await memory.update_context(self._model_context)
if update_context_result and len(update_context_result.memories.results) > 0:
memory_query_event_msg = MemoryQueryEvent(
content=update_context_result.memories.results, source=self.name
)
inner_messages.append(memory_query_event_msg)
yield memory_query_event_msg

# Generate an inference result based on the current model context.
llm_messages = self._system_messages + await self._model_context.get_messages()
result = await self._model_client.create(
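
The hunk above touches the `Memory` protocol only through `update_context`, which must return an object exposing `memories.results` (a list of `MemoryContent`). Below is a rough sketch of a custom store shaped to fit that call site. The method set is inferred from the tests in this commit (`add`, `query`, `clear`, `close`, a `name` attribute), while the `UpdateContextResult` class name and the keyword constructors are assumptions rather than a verified copy of the `autogen_core.memory` API; the real protocol may require additional members.

```python
from typing import Any, List

from autogen_core.memory import (
    MemoryContent,
    MemoryQueryResult,
    UpdateContextResult,  # assumed class name; the hunk above only shows `.memories.results`
)
from autogen_core.model_context import ChatCompletionContext
from autogen_core.models import SystemMessage


class StaticMemory:
    """Toy store that ignores the query and returns everything it holds. Illustrative only."""

    def __init__(self, name: str = "static_memory") -> None:
        self.name = name
        self._items: List[MemoryContent] = []

    async def add(self, content: MemoryContent) -> None:
        self._items.append(content)

    async def query(self, query: MemoryContent | str, **kwargs: Any) -> MemoryQueryResult:
        # A real store would rank items against the query; this sketch returns everything.
        return MemoryQueryResult(results=list(self._items))

    async def update_context(self, model_context: ChatCompletionContext) -> UpdateContextResult:
        # Inject stored items into the model context, then report what was injected
        # so AssistantAgent can surface it as a MemoryQueryEvent.
        if self._items:
            text = "\n".join(str(item.content) for item in self._items)
            await model_context.add_message(SystemMessage(content=f"Relevant memory:\n{text}"))
        return UpdateContextResult(memories=MemoryQueryResult(results=list(self._items)))

    async def clear(self) -> None:
        self._items.clear()

    async def close(self) -> None:
        pass
```

Whether this passes the runtime `isinstance(..., Memory)` check depends on the full protocol definition; it is meant only to show the shape that `on_messages_stream` consumes.
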
@@ -8,6 +8,7 @@ class and includes specific fields relevant to the type of message being sent.
from typing import List, Literal

from autogen_core import FunctionCall, Image
from autogen_core.memory import MemoryContent
from autogen_core.models import FunctionExecutionResult, RequestUsage
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
@@ -115,14 +116,24 @@ class UserInputRequestedEvent(BaseAgentEvent):
type: Literal["UserInputRequestedEvent"] = "UserInputRequestedEvent"


class MemoryQueryEvent(BaseAgentEvent):
"""An event signaling the results of memory queries."""

content: List[MemoryContent]
"""The memory query results."""

type: Literal["MemoryQueryEvent"] = "MemoryQueryEvent"


ChatMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
]
"""Messages for agent-to-agent communication only."""


AgentEvent = Annotated[
ToolCallRequestEvent | ToolCallExecutionEvent | UserInputRequestedEvent, Field(discriminator="type")
ToolCallRequestEvent | ToolCallExecutionEvent | MemoryQueryEvent | UserInputRequestedEvent,
Field(discriminator="type"),
]
"""Events emitted by agents and teams when they work, not used for agent-to-agent communication."""

@@ -138,5 +149,6 @@ class UserInputRequestedEvent(BaseAgentEvent):
"ToolCallExecutionEvent",
"ToolCallRequestEvent",
"ToolCallSummaryMessage",
"MemoryQueryEvent",
"UserInputRequestedEvent",
]
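
Because `MemoryQueryEvent` joins the `AgentEvent` discriminated union, consumers can dispatch on it with `isinstance` or via its `type` discriminator and serialize it like any other pydantic message. A small illustrative sketch; the `source` field comes from `BaseAgentEvent`, matching how `AssistantAgent` constructs the event in the first hunk.

```python
from autogen_agentchat.messages import AgentEvent, ChatMessage, MemoryQueryEvent
from autogen_core.memory import MemoryContent, MemoryMimeType

event = MemoryQueryEvent(
    content=[MemoryContent(content="user prefers short answers", mime_type=MemoryMimeType.TEXT)],
    source="assistant",
)

# The discriminator field lets consumers dispatch without importing every event class.
assert event.type == "MemoryQueryEvent"


def describe(message: AgentEvent | ChatMessage) -> str:
    if isinstance(message, MemoryQueryEvent):
        return f"{message.source} recalled {len(message.content)} memory item(s)"
    return type(message).__name__


print(describe(event))
print(event.model_dump())  # standard pydantic serialization
```
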
python/packages/autogen-agentchat/tests/test_assistant_agent.py: 85 changes (84 additions, 1 deletion)
@@ -10,13 +10,15 @@
from autogen_agentchat.messages import (
ChatMessage,
HandoffMessage,
MemoryQueryEvent,
MultiModalMessage,
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from autogen_core import Image
from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.models import LLMMessage
from autogen_core.models._model_client import ModelFamily
@@ -508,4 +510,85 @@ async def test_model_context(monkeypatch: pytest.MonkeyPatch) -> None:

# Check if the mock client is called with only the last two messages.
assert len(mock.calls) == 1
assert len(mock.calls[0]) == 3 # 2 messages from the context + 1 system message
# 2 messages from the context + 1 system message
assert len(mock.calls[0]) == 3


@pytest.mark.asyncio
async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="Hello", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
b64_image_str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

# Test basic memory properties and empty context
memory = ListMemory(name="test_memory")
assert memory.name == "test_memory"

empty_context = BufferedChatCompletionContext(buffer_size=2)
empty_results = await memory.update_context(empty_context)
assert len(empty_results.memories.results) == 0

# Test various content types
memory = ListMemory()
await memory.add(MemoryContent(content="text content", mime_type=MemoryMimeType.TEXT))
await memory.add(MemoryContent(content={"key": "value"}, mime_type=MemoryMimeType.JSON))
await memory.add(MemoryContent(content=Image.from_base64(b64_image_str), mime_type=MemoryMimeType.IMAGE))

# Test query functionality
query_result = await memory.query(MemoryContent(content="", mime_type=MemoryMimeType.TEXT))
assert isinstance(query_result, MemoryQueryResult)
# Should have all three memories we added
assert len(query_result.results) == 3

# Test clear and cleanup
await memory.clear()
empty_query = await memory.query(MemoryContent(content="", mime_type=MemoryMimeType.TEXT))
assert len(empty_query.results) == 0
await memory.close() # Should not raise

# Test invalid memory type
with pytest.raises(TypeError):
AssistantAgent(
"test_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
memory="invalid", # type: ignore
)

# Test with agent
memory2 = ListMemory()
await memory2.add(MemoryContent(content="test instruction", mime_type=MemoryMimeType.TEXT))

agent = AssistantAgent(
"test_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), memory=[memory2]
)

result = await agent.run(task="test task")
assert len(result.messages) > 0
memory_event = next((msg for msg in result.messages if isinstance(msg, MemoryQueryEvent)), None)
assert memory_event is not None
assert len(memory_event.content) > 0
assert isinstance(memory_event.content[0], MemoryContent)

# Test memory protocol
class BadMemory:
pass

assert not isinstance(BadMemory(), Memory)
assert isinstance(ListMemory(), Memory)
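
The test drives the agent with `run(...)` and inspects the final message list; the 'yield message for observability' item in the commit message refers to the streaming path, where the same event arrives before the model response. A hedged sketch using `run_stream`; the `OpenAIChatCompletionClient` and `TaskResult` import paths are assumptions based on the package layout and are not shown in this diff.

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import TaskResult  # assumed import path
from autogen_agentchat.messages import MemoryQueryEvent
from autogen_core.memory import ListMemory, MemoryContent, MemoryMimeType
from autogen_ext.models.openai import OpenAIChatCompletionClient  # assumed import path


async def main() -> None:
    memory = ListMemory()
    await memory.add(MemoryContent(content="Always answer in one sentence.", mime_type=MemoryMimeType.TEXT))

    agent = AssistantAgent(
        "assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-05-13"),
        memory=[memory],
    )

    async for item in agent.run_stream(task="Summarize the memory feature."):
        if isinstance(item, MemoryQueryEvent):
            # Emitted before the model call, so memory usage is visible as it happens.
            print("memory injected:", [m.content for m in item.content])
        elif isinstance(item, TaskResult):
            print("final messages:", len(item.messages))


asyncio.run(main())
```
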
@@ -92,6 +92,7 @@ tutorial/termination
tutorial/custom-agents
tutorial/state
tutorial/declarative
tutorial/memory
```

```{toctree}