Memory Interface in AgentChat #4438

Open: wants to merge 22 commits into main from agentchat_memory_vd. The diff shown below covers changes from 2 of the 22 commits.

Commits
48d7ecb
initial base memory impl
victordibia Nov 30, 2024
f70f61e
update, add example with chromadb
victordibia Dec 1, 2024
24fa684
include mimetype consideration
victordibia Dec 1, 2024
9e94ec8
Merge remote-tracking branch 'origin/main' into agentchat_memory_vd
victordibia Dec 19, 2024
0b7469e
add transform method
victordibia Dec 20, 2024
138ee05
update to address feedback, will update after 4681 is merged
victordibia Dec 20, 2024
a94634b
Merge remote-tracking branch 'origin/main' into agentchat_memory_vd
victordibia Dec 20, 2024
675924c
update memory impl,
victordibia Dec 25, 2024
b1da7e2
remove chroma db, typing fixes
victordibia Jan 3, 2025
f0812a3
Merge remote-tracking branch 'origin/main' into agentchat_memory_vd
victordibia Jan 3, 2025
32701db
format, add test
victordibia Jan 4, 2025
d7bf4d2
update uv lock
victordibia Jan 4, 2025
afbef4d
update docs
victordibia Jan 4, 2025
003bb2e
format updates
victordibia Jan 4, 2025
7b15c2e
update notebook
victordibia Jan 4, 2025
b353110
add memoryqueryevent message, yield message for observability.
victordibia Jan 4, 2025
e1a9be2
Merge branch 'main' into agentchat_memory_vd
victordibia Jan 4, 2025
c797f6a
minor fixes, make score optional/none
victordibia Jan 4, 2025
dfb1da6
Merge branch 'agentchat_memory_vd' of github.com:microsoft/autogen in…
victordibia Jan 4, 2025
97ed7f5
Update python/packages/autogen-agentchat/src/autogen_agentchat/agents…
victordibia Jan 6, 2025
5a74fdf
Merge branch 'main' into agentchat_memory_vd
victordibia Jan 6, 2025
24bd81e
update tests to improve cov
victordibia Jan 7, 2025

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
@@ -29,6 +29,7 @@
ToolCallResultMessage,
)
from ._base_chat_agent import BaseChatAgent
from ..memory._base_memory import Memory, MemoryQueryResult

event_logger = logging.getLogger(EVENT_LOGGER_NAME)

@@ -60,10 +61,12 @@ def set_defaults(cls, values: Dict[str, Any]) -> Dict[str, Any]:
else:
name = values["name"]
if not isinstance(name, str):
raise ValueError(f"Handoff name must be a string: {values['name']}")
raise ValueError(
f"Handoff name must be a string: {values['name']}")
# Check if name is a valid identifier.
if not name.isidentifier():
raise ValueError(f"Handoff name must be a valid identifier: {values['name']}")
raise ValueError(
f"Handoff name must be a valid identifier: {values['name']}")
if values.get("message") is None:
values["message"] = (
f"Transferred to {values['target']}, adopting the role of {values['target']} immediately."
@@ -203,22 +206,29 @@ def __init__(
name: str,
model_client: ChatCompletionClient,
*,
tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
handoffs: List[Handoff | str] | None = None,
memory: Memory | None = None,
description: str = "An agent that provides assistance with ability to use tools.",
system_message: str
| None = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
):
super().__init__(name=name, description=description)
self._model_client = model_client
self._memory = memory

self._system_messages: List[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage] = []
if system_message is None:
self._system_messages = []
else:
self._system_messages = [SystemMessage(content=system_message)]
self._tools: List[Tool] = []
if tools is not None:
if model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling.")
raise ValueError(
"The model does not support function calling.")
for tool in tools:
if isinstance(tool, Tool):
self._tools.append(tool)
@@ -227,7 +237,8 @@ def __init__(
description = tool.__doc__
else:
description = ""
self._tools.append(FunctionTool(tool, description=description))
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
# Check if tool names are unique.
@@ -239,26 +250,42 @@ def __init__(
self._handoffs: Dict[str, Handoff] = {}
if handoffs is not None:
if model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling, which is needed for handoffs.")
raise ValueError(
"The model does not support function calling, which is needed for handoffs.")
for handoff in handoffs:
if isinstance(handoff, str):
handoff = Handoff(target=handoff)
if isinstance(handoff, Handoff):
self._handoff_tools.append(handoff.handoff_tool)
self._handoffs[handoff.name] = handoff
else:
raise ValueError(f"Unsupported handoff type: {type(handoff)}")
raise ValueError(
f"Unsupported handoff type: {type(handoff)}")
# Check if handoff tool names are unique.
handoff_tool_names = [tool.name for tool in self._handoff_tools]
if len(handoff_tool_names) != len(set(handoff_tool_names)):
raise ValueError(f"Handoff names must be unique: {handoff_tool_names}")
raise ValueError(
f"Handoff names must be unique: {handoff_tool_names}")
# Check that handoff tool names are not in tool names.
if any(name in tool_names for name in handoff_tool_names):
raise ValueError(
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
)
self._model_context: List[LLMMessage] = []

def _format_memory_context(self, results: List[MemoryQueryResult]) -> str:
if not results or not self._memory: # Guard against no memory
return ""

context_lines = []
for i, result in enumerate(results, 1):
context_lines.append(
self._memory.config.context_format.format(
i=i, content=result.entry.content, score=result.score)
)

return "".join(context_lines)

@property
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the assistant agent produces."""
@@ -270,44 +297,70 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
async for message in self.on_messages_stream(messages, cancellation_token):
if isinstance(message, Response):
return message
raise AssertionError("The stream should have returned the final result.")
raise AssertionError(
"The stream should have returned the final result.")

async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentMessage | Response, None]:
# Query memory if available with the last message
memory_context = ""
if self._memory is not None and messages:
try:
last_message = messages[-1]
# ensure the last message is a text message or multimodal message
if not isinstance(last_message, TextMessage) and not isinstance(last_message, MultiModalMessage):
raise ValueError(
"Memory query failed: Last message must be a text message or multimodal message.")
results: List[MemoryQueryResult] = await self._memory.query(messages[-1].content, cancellation_token=cancellation_token)
memory_context = self._format_memory_context(results)
except Exception as e:
event_logger.warning(f"Memory query failed: {e}")

# Add messages to the model context.
for msg in messages:
if isinstance(msg, MultiModalMessage) and self._model_client.capabilities["vision"] is False:
raise ValueError("The model does not support vision.")
self._model_context.append(UserMessage(content=msg.content, source=msg.source))

# Inner messages.
inner_messages: List[AgentMessage] = []

# Prepare messages for model with memory context if available
llm_messages = self._system_messages
if memory_context:
llm_messages = llm_messages + [SystemMessage(content=memory_context)]
llm_messages = llm_messages + self._model_context

# Generate inference result
result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)

# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))

# Run tool calls until the model produces a string response.
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
tool_call_msg = ToolCallMessage(content=result.content, source=self.name, models_usage=result.usage)
event_logger.debug(tool_call_msg)
# Add the tool call message to the output.
inner_messages.append(tool_call_msg)
yield tool_call_msg

# Execute the tool calls.
execution_results = await asyncio.gather(
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
)
tool_call_result_msg = ToolCallResultMessage(content=execution_results, source=self.name)
event_logger.debug(tool_call_result_msg)
self._model_context.append(FunctionExecutionResultMessage(content=execution_results))
inner_messages.append(tool_call_result_msg)
yield tool_call_result_msg

@@ -318,7 +371,8 @@ async def on_messages_stream(
handoffs.append(self._handoffs[call.name])
if len(handoffs) > 0:
if len(handoffs) > 1:
raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
raise ValueError(
f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
# Return the output messages to signal the handoff.
yield Response(
chat_message=HandoffMessage(
@@ -329,15 +383,22 @@ async def on_messages_stream(
return

# Generate an inference result based on the current model context.
llm_messages = (
self._system_messages
+ ([SystemMessage(content=memory_context)] if memory_context else [])
+ self._model_context
)
result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)
self._model_context.append(AssistantMessage(content=result.content, source=self.name))

assert isinstance(result.content, str)
yield Response(
chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
inner_messages=inner_messages,
)

@@ -348,9 +409,11 @@ async def _execute_tool_call(
try:
if not self._tools + self._handoff_tools:
raise ValueError("No tools are available.")
tool = next((t for t in self._tools + self._handoff_tools if t.name == tool_call.name), None)
if tool is None:
raise ValueError(f"The tool '{tool_call.name}' is not available.")
raise ValueError(
f"The tool '{tool_call.name}' is not available.")
arguments = json.loads(tool_call.arguments)
result = await tool.run_json(arguments, cancellation_token)
result_as_str = tool.return_value_as_string(result)
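
To make the memory wiring above concrete, here is a small illustrative sketch of the context string that ends up being injected as an extra SystemMessage. It uses the default context_format from BaseMemoryConfig (defined in the new _base_memory.py file below); the import path, entries, and scores are assumptions for illustration only.

# Illustrative only: the import path is assumed from the relative import above.
from autogen_agentchat.memory._base_memory import BaseMemoryConfig, MemoryEntry, MemoryQueryResult

config = BaseMemoryConfig()
results = [
    MemoryQueryResult(entry=MemoryEntry(content="The user prefers concise answers."), score=0.91),
    MemoryQueryResult(entry=MemoryEntry(content="The user's timezone is UTC+1."), score=0.78),
]

# Mirrors AssistantAgent._format_memory_context: one formatted block per result, concatenated.
memory_context = "".join(
    config.context_format.format(i=i, content=r.entry.content, score=r.score)
    for i, r in enumerate(results, 1)
)
print(memory_context)
# Context 1: The user prefers concise answers. (score: 0.91)
#  Use this information to address relevant tasks.Context 2: The user's timezone is UTC+1. (score: 0.78)
#  Use this information to address relevant tasks.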

python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py
@@ -0,0 +1,101 @@
from datetime import datetime
from typing import Any, Dict, List, Protocol, Union, runtime_checkable

from autogen_core.base import CancellationToken
from autogen_core.components import Image
from pydantic import BaseModel, ConfigDict, Field


class BaseMemoryConfig(BaseModel):
"""Base configuration for memory implementations."""

k: int = Field(default=5, description="Number of results to return")
score_threshold: float | None = Field(default=None, description="Minimum relevance score")
context_format: str = Field(
default="Context {i}: {content} (score: {score:.2f})\n Use this information to address relevant tasks.",
description="Format string for memory results in prompt",
)

model_config = ConfigDict(arbitrary_types_allowed=True)
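# For example (illustrative values): BaseMemoryConfig(k=3, score_threshold=0.5) asks an
# implementation to return at most three results and to skip matches scored below 0.5.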


class MemoryEntry(BaseModel):
"""A memory entry containing content and metadata."""

content: Union[str, List[Union[str, Image]]]
"""The content of the memory entry - can be text or multimodal."""

metadata: Dict[str, Any] = Field(default_factory=dict)
"""Optional metadata associated with the memory entry."""

timestamp: datetime = Field(default_factory=datetime.now)
"""When the memory was created."""

source: str | None = None
"""Optional source identifier for the memory."""

model_config = ConfigDict(arbitrary_types_allowed=True)


class MemoryQueryResult(BaseModel):
"""Result from a memory query including the entry and its relevance score."""

entry: MemoryEntry
"""The memory entry."""

score: float
"""Relevance score for this result. Higher means more relevant."""

model_config = ConfigDict(arbitrary_types_allowed=True)


@runtime_checkable
class Memory(Protocol):
"""Protocol defining the interface for memory implementations."""

@property
def name(self) -> str | None:
"""The name of this memory implementation."""
...

@property
def config(self) -> BaseMemoryConfig:
"""The configuration for this memory implementation."""
...

async def query(
self,
query: Union[str, Image, List[Union[str, Image]]],
cancellation_token: CancellationToken | None = None,
**kwargs: Any,
) -> List[MemoryQueryResult]:
"""
Query the memory store and return relevant entries.

Args:
query: Text, image or multimodal query
cancellation_token: Optional token to cancel operation
**kwargs: Additional implementation-specific parameters

Returns:
List of memory entries with relevance scores
"""
...

async def add(self, entry: MemoryEntry, cancellation_token: CancellationToken | None = None) -> None:
"""
Add a new entry to memory.

Args:
entry: The memory entry to add
cancellation_token: Optional token to cancel operation
"""
...

async def clear(self) -> None:
"""Clear all entries from memory."""
...

async def cleanup(self) -> None:
"""Clean up any resources used by the memory implementation."""
...
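
The protocol above leaves storage and scoring entirely to implementations (a ChromaDB-backed example was added and later removed in this PR's commit history). As a rough sketch of what a minimal implementation could look like, and not part of this change, the list-backed memory below scores entries by keyword overlap with the query; the module import path is an assumption.

from typing import Any, List, Union

from autogen_core.base import CancellationToken
from autogen_core.components import Image

# Assumed import path, mirroring the relative import used in _assistant_agent.py above.
from autogen_agentchat.memory._base_memory import BaseMemoryConfig, MemoryEntry, MemoryQueryResult


class ListMemory:
    """Naive list-backed memory: scores entries by keyword overlap with the query."""

    def __init__(self, config: BaseMemoryConfig | None = None) -> None:
        self._config = config or BaseMemoryConfig()
        self._entries: List[MemoryEntry] = []

    @property
    def name(self) -> str | None:
        return "list_memory"

    @property
    def config(self) -> BaseMemoryConfig:
        return self._config

    async def add(self, entry: MemoryEntry, cancellation_token: CancellationToken | None = None) -> None:
        self._entries.append(entry)

    async def query(
        self,
        query: Union[str, Image, List[Union[str, Image]]],
        cancellation_token: CancellationToken | None = None,
        **kwargs: Any,
    ) -> List[MemoryQueryResult]:
        if not isinstance(query, str):
            return []  # This sketch only handles text queries.
        terms = set(query.lower().split())
        scored: List[MemoryQueryResult] = []
        for entry in self._entries:
            if not isinstance(entry.content, str):
                continue
            overlap = len(terms & set(entry.content.lower().split()))
            score = overlap / max(len(terms), 1)
            if self._config.score_threshold is not None and score < self._config.score_threshold:
                continue
            scored.append(MemoryQueryResult(entry=entry, score=score))
        scored.sort(key=lambda r: r.score, reverse=True)
        return scored[: self._config.k]

    async def clear(self) -> None:
        self._entries.clear()

    async def cleanup(self) -> None:
        pass  # Nothing to release for an in-process list.

An instance of such a class could then be passed as AssistantAgent(..., memory=ListMemory()), so that on_messages_stream queries it with the last text or multimodal message and injects the formatted results as a SystemMessage before calling the model.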