
Commit 32701db

format, add test
1 parent f0812a3 commit 32701db

7 files changed: +315 −184 lines changed

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py (+21 −37)

````diff
@@ -32,6 +32,7 @@
 from .. import EVENT_LOGGER_NAME
 from ..base import Handoff as HandoffBase
 from ..base import Response
+from ..memory._base_memory import Memory
 from ..messages import (
     AgentEvent,
     ChatMessage,
@@ -44,7 +45,6 @@
 )
 from ..state import AssistantAgentState
 from ._base_chat_agent import BaseChatAgent
-from ..memory._base_memory import Memory
 
 event_logger = logging.getLogger(EVENT_LOGGER_NAME)
 
@@ -245,8 +245,7 @@ def __init__(
         name: str,
         model_client: ChatCompletionClient,
         *,
-        tools: List[Tool | Callable[..., Any] |
-                    Callable[..., Awaitable[Any]]] | None = None,
+        tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
         handoffs: List[HandoffBase | str] | None = None,
         model_context: ChatCompletionContext | None = None,
         description: str = "An agent that provides assistance with ability to use tools.",
@@ -266,20 +265,19 @@ def __init__(
         elif isinstance(memory, list):
             self._memory = memory
         else:
-            raise TypeError(
-                f"Expected Memory, List[Memory], or None, got {type(memory)}")
+            raise TypeError(f"Expected Memory, List[Memory], or None, got {type(memory)}")
 
-        self._system_messages: List[SystemMessage | UserMessage |
-                                    AssistantMessage | FunctionExecutionResultMessage] = []
+        self._system_messages: List[
+            SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage
+        ] = []
         if system_message is None:
             self._system_messages = []
         else:
             self._system_messages = [SystemMessage(content=system_message)]
         self._tools: List[Tool] = []
         if tools is not None:
             if model_client.model_info["function_calling"] is False:
-                raise ValueError(
-                    "The model does not support function calling.")
+                raise ValueError("The model does not support function calling.")
             for tool in tools:
                 if isinstance(tool, Tool):
                     self._tools.append(tool)
@@ -288,8 +286,7 @@ def __init__(
                         description = tool.__doc__
                     else:
                         description = ""
-                    self._tools.append(FunctionTool(
-                        tool, description=description))
+                    self._tools.append(FunctionTool(tool, description=description))
                 else:
                     raise ValueError(f"Unsupported tool type: {type(tool)}")
         # Check if tool names are unique.
@@ -301,22 +298,19 @@ def __init__(
         self._handoffs: Dict[str, HandoffBase] = {}
         if handoffs is not None:
             if model_client.model_info["function_calling"] is False:
-                raise ValueError(
-                    "The model does not support function calling, which is needed for handoffs.")
+                raise ValueError("The model does not support function calling, which is needed for handoffs.")
             for handoff in handoffs:
                 if isinstance(handoff, str):
                     handoff = HandoffBase(target=handoff)
                 if isinstance(handoff, HandoffBase):
                     self._handoff_tools.append(handoff.handoff_tool)
                     self._handoffs[handoff.name] = handoff
                 else:
-                    raise ValueError(
-                        f"Unsupported handoff type: {type(handoff)}")
+                    raise ValueError(f"Unsupported handoff type: {type(handoff)}")
         # Check if handoff tool names are unique.
         handoff_tool_names = [tool.name for tool in self._handoff_tools]
         if len(handoff_tool_names) != len(set(handoff_tool_names)):
-            raise ValueError(
-                f"Handoff names must be unique: {handoff_tool_names}")
+            raise ValueError(f"Handoff names must be unique: {handoff_tool_names}")
         # Check if handoff tool names not in tool names.
         if any(name in tool_names for name in handoff_tool_names):
             raise ValueError(
@@ -344,8 +338,7 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
         async for message in self.on_messages_stream(messages, cancellation_token):
             if isinstance(message, Response):
                 return message
-        raise AssertionError(
-            "The stream should have returned the final result.")
+        raise AssertionError("The stream should have returned the final result.")
 
     async def on_messages_stream(
         self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
@@ -377,26 +370,22 @@ async def on_messages_stream(
         # Check if the response is a string and return it.
         if isinstance(result.content, str):
             yield Response(
-                chat_message=TextMessage(
-                    content=result.content, source=self.name, models_usage=result.usage),
+                chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
                 inner_messages=inner_messages,
             )
             return
 
         # Process tool calls.
-        assert isinstance(result.content, list) and all(
-            isinstance(item, FunctionCall) for item in result.content)
-        tool_call_msg = ToolCallRequestEvent(
-            content=result.content, source=self.name, models_usage=result.usage)
+        assert isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content)
+        tool_call_msg = ToolCallRequestEvent(content=result.content, source=self.name, models_usage=result.usage)
         event_logger.debug(tool_call_msg)
         # Add the tool call message to the output.
         inner_messages.append(tool_call_msg)
         yield tool_call_msg
 
         # Execute the tool calls.
         results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content])
-        tool_call_result_msg = ToolCallExecutionEvent(
-            content=results, source=self.name)
+        tool_call_result_msg = ToolCallExecutionEvent(content=results, source=self.name)
         event_logger.debug(tool_call_result_msg)
         await self._model_context.add_message(FunctionExecutionResultMessage(content=results))
         inner_messages.append(tool_call_result_msg)
@@ -416,8 +405,7 @@ async def on_messages_stream(
             )
             # Return the output messages to signal the handoff.
             yield Response(
-                chat_message=HandoffMessage(
-                    content=handoffs[0].message, target=handoffs[0].target, source=self.name),
+                chat_message=HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name),
                 inner_messages=inner_messages,
             )
             return
@@ -431,8 +419,7 @@ async def on_messages_stream(
             await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
             # Yield the response.
             yield Response(
-                chat_message=TextMessage(
-                    content=result.content, source=self.name, models_usage=result.usage),
+                chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
                 inner_messages=inner_messages,
             )
         else:
@@ -448,8 +435,7 @@ async def on_messages_stream(
                 )
             tool_call_summary = "\n".join(tool_call_summaries)
             yield Response(
-                chat_message=ToolCallSummaryMessage(
-                    content=tool_call_summary, source=self.name),
+                chat_message=ToolCallSummaryMessage(content=tool_call_summary, source=self.name),
                 inner_messages=inner_messages,
             )
 
@@ -460,11 +446,9 @@ async def _execute_tool_call(
         try:
             if not self._tools + self._handoff_tools:
                 raise ValueError("No tools are available.")
-            tool = next((t for t in self._tools +
-                         self._handoff_tools if t.name == tool_call.name), None)
+            tool = next((t for t in self._tools + self._handoff_tools if t.name == tool_call.name), None)
             if tool is None:
-                raise ValueError(
-                    f"The tool '{tool_call.name}' is not available.")
+                raise ValueError(f"The tool '{tool_call.name}' is not available.")
             arguments = json.loads(tool_call.arguments)
             result = await tool.run_json(arguments, cancellation_token)
             result_as_str = tool.return_value_as_string(result)
````
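
Most hunks in this file are line joins from the formatter; the logic they preserve is the constructor's validation: `memory` must be a `Memory`, a `List[Memory]`, or `None` (else `TypeError`), plain callables in `tools` are wrapped in `FunctionTool` with their docstring as the description, and a model whose `model_info` lacks function calling rejects both `tools` and `handoffs`. A minimal call-site sketch; the OpenAI client class and its import path are assumptions (they move between preview releases), not part of this commit:

```python
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.memory import ListMemory, ListMemoryConfig
from autogen_ext.models import OpenAIChatCompletionClient  # assumed import path; adjust to your installed version


async def get_weather(city: str) -> str:
    """Get the weather for a city (illustrative tool; the docstring becomes the FunctionTool description)."""
    return f"It is sunny in {city}."


# The client's model_info must report function_calling=True, or the
# constructor raises ValueError for both tools and handoffs.
model_client = OpenAIChatCompletionClient(model="gpt-4o")

agent = AssistantAgent(
    name="assistant",
    model_client=model_client,
    tools=[get_weather],  # a plain callable is wrapped as FunctionTool(get_weather, description=...)
    memory=ListMemory(name="chat_history", config=ListMemoryConfig(k=3)),  # Memory | List[Memory] | None
)
```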

python/packages/autogen-agentchat/src/autogen_agentchat/memory/__init__.py (new file, +10)

````diff
@@ -0,0 +1,10 @@
+from ._base_memory import Memory, MemoryContent, MemoryMimeType
+from ._list_memory import ListMemory, ListMemoryConfig
+
+__all__ = [
+    "Memory",
+    "MemoryContent",
+    "MemoryMimeType",
+    "ListMemory",
+    "ListMemoryConfig",
+]
````
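
The new `__init__.py` gives the memory subpackage a public surface, so downstream code can import from `autogen_agentchat.memory` instead of reaching into the private `_base_memory` and `_list_memory` modules:

```python
from autogen_agentchat.memory import (
    ListMemory,
    ListMemoryConfig,
    Memory,
    MemoryContent,
    MemoryMimeType,
)

# Memory is a runtime-checkable Protocol (note the runtime_checkable import
# in _base_memory.py), so conformance can be verified with isinstance:
memory = ListMemory(name="notes", config=ListMemoryConfig(k=3))
assert isinstance(memory, Memory)
```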

python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py (+3 −4)

````diff
@@ -3,8 +3,8 @@
 from typing import Any, Dict, List, Protocol, Union, runtime_checkable
 
 from autogen_core import CancellationToken, Image
-from pydantic import BaseModel, ConfigDict, Field
 from autogen_core.model_context import ChatCompletionContext
+from pydantic import BaseModel, ConfigDict, Field
 
 
 class MemoryMimeType(Enum):
@@ -22,7 +22,7 @@ class MemoryMimeType(Enum):
 
 class MemoryContent(BaseModel):
     content: ContentType
-    mime_type: MemoryMimeType
+    mime_type: MemoryMimeType | str
     metadata: Dict[str, Any] | None = None
     timestamp: datetime | None = None
     source: str | None = None
@@ -35,8 +35,7 @@ class BaseMemoryConfig(BaseModel):
     """Base configuration for memory implementations."""
 
     k: int = Field(default=5, description="Number of results to return")
-    score_threshold: float | None = Field(
-        default=None, description="Minimum relevance score")
+    score_threshold: float | None = Field(default=None, description="Minimum relevance score")
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
````
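
The substantive change here is the widened annotation `mime_type: MemoryMimeType | str`, which lets arbitrary MIME strings pass pydantic validation alongside the enum. A quick sketch:

```python
from autogen_agentchat.memory import MemoryContent, MemoryMimeType

# Enum member, as before:
note = MemoryContent(content="User prefers formal language", mime_type=MemoryMimeType.TEXT)

# A plain MIME string is now also accepted by the MemoryMimeType | str union:
table = MemoryContent(content="a,b\n1,2", mime_type="text/csv", metadata={"rows": 1})
```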

python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py (+25 −38)

````diff
@@ -1,17 +1,16 @@
-from difflib import SequenceMatcher
 import logging
+from difflib import SequenceMatcher
 from typing import Any, List
 
 from autogen_core import CancellationToken, Image
-from pydantic import Field
-
-from ._base_memory import BaseMemoryConfig, MemoryContent, Memory, MemoryMimeType
 from autogen_core.model_context import ChatCompletionContext
 from autogen_core.models import (
     SystemMessage,
 )
+from pydantic import Field
 
 from .. import EVENT_LOGGER_NAME
+from ._base_memory import BaseMemoryConfig, Memory, MemoryContent, MemoryMimeType
 
 event_logger = logging.getLogger(EVENT_LOGGER_NAME)
 
@@ -34,19 +33,10 @@ class ListMemory(Memory):
     Example:
         ```python
         # Initialize memory with custom config
-        memory = ListMemory(
-            name="chat_history",
-            config=ListMemoryConfig(
-                similarity_threshold=0.7,
-                k=3
-            )
-        )
+        memory = ListMemory(name="chat_history", config=ListMemoryConfig(similarity_threshold=0.7, k=3))
 
         # Add memory content
-        content = MemoryContent(
-            content="User prefers formal language",
-            mime_type=MemoryMimeType.TEXT
-        )
+        content = MemoryContent(content="User prefers formal language", mime_type=MemoryMimeType.TEXT)
         await memory.add(content)
 
         # Transform a model context with memory
@@ -106,26 +96,23 @@ async def transform(
 
         # Extract query from last message
         last_message = messages[-1]
-        query_text = last_message.content if isinstance(
-            last_message.content, str) else str(last_message)
-        query = MemoryContent(content=query_text,
-                              mime_type=MemoryMimeType.TEXT)
+        query_text = last_message.content if isinstance(last_message.content, str) else str(last_message)
+        query = MemoryContent(content=query_text, mime_type=MemoryMimeType.TEXT)
 
         # Query memory and format results
         results: List[str] = []
         query_results = await self.query(query)
         for i, result in enumerate(query_results, 1):
             if isinstance(result.content, str):
                 results.append(f"{i}. {result.content}")
-                event_logger.debug(
-                    f"Retrieved memory {i}. {result.content}, score: {result.score}"
-                )
+                event_logger.debug(f"Retrieved memory {i}. {result.content}, score: {result.score}")
 
         # Add memory results to context
         if results:
             memory_context = (
-                "\n The following results were retrieved from memory for this task. You may choose to use them or not. :\n" +
-                "\n".join(results) + "\n"
+                "\n The following results were retrieved from memory for this task. You may choose to use them or not. :\n"
+                + "\n".join(results)
+                + "\n"
             )
             await model_context.add_message(SystemMessage(content=memory_context))
 
@@ -159,10 +146,7 @@ async def query(
         Example:
             ```python
             # Query memories similar to some text
-            query = MemoryContent(
-                content="What's the weather?",
-                mime_type=MemoryMimeType.TEXT
-            )
+            query = MemoryContent(content="What's the weather?", mime_type=MemoryMimeType.TEXT)
             results = await memory.query(query)
 
             # Check similarity scores
@@ -172,8 +156,8 @@ async def query(
         """
         try:
             query_text = self._extract_text(query)
-        except ValueError:
-            raise ValueError("Query must contain text content")
+        except ValueError as e:
+            raise ValueError("Query must contain text content") from e
 
         results: List[MemoryContent] = []
 
@@ -207,7 +191,7 @@ def _calculate_similarity(self, text1: str, text2: str) -> float:
 
         Note:
             Uses difflib's SequenceMatcher for basic text similarity.
-            For production use cases, consider using more sophisticated 
+            For production use cases, consider using more sophisticated
             similarity metrics or embeddings.
         """
         return SequenceMatcher(None, text1.lower(), text2.lower()).ratio()
@@ -242,14 +226,9 @@ def _extract_text(self, content_item: MemoryContent) -> str:
         elif isinstance(content, Image):
             raise ValueError("Image content cannot be converted to text")
         else:
-            raise ValueError(
-                f"Unsupported content type: {content_item.mime_type}")
+            raise ValueError(f"Unsupported content type: {content_item.mime_type}")
 
-    async def add(
-        self,
-        content: MemoryContent,
-        cancellation_token: CancellationToken | None = None
-    ) -> None:
+    async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
         """Add new content to memory.
 
         Args:
@@ -262,3 +241,11 @@ async def add(
             deduplication or content-based filtering.
         """
         self._contents.append(content)
+
+    async def clear(self) -> None:
+        """Clear all memory content."""
+        self._contents = []
+
+    async def cleanup(self) -> None:
+        """Cleanup resources if needed."""
+        pass
````
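
Beyond import sorting and line joins, this file gains exception chaining in `query()` and the new `clear()`/`cleanup()` methods at the end. A usage sketch of the full lifecycle; the `name`, threshold, and `k` values are illustrative:

```python
import asyncio

from autogen_agentchat.memory import ListMemory, ListMemoryConfig, MemoryContent, MemoryMimeType


async def main() -> None:
    memory = ListMemory(name="prefs", config=ListMemoryConfig(similarity_threshold=0.3, k=3))
    await memory.add(MemoryContent(content="User prefers formal language", mime_type=MemoryMimeType.TEXT))

    # query() extracts text from the query content, scores stored entries with
    # difflib's SequenceMatcher, and returns the top matches with scores.
    query = MemoryContent(content="preferred language style", mime_type=MemoryMimeType.TEXT)
    for result in await memory.query(query):
        print(result.content, result.score)

    await memory.clear()    # new in this commit: drops all stored content
    await memory.cleanup()  # new in this commit: no-op hook for resource teardown


asyncio.run(main())
```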
