Skip to content

Commit

Permalink
Add missing model context attribute (#4848)
Browse files Browse the repository at this point in the history
* Add missing model context attribute

* fix type

* Add test

* imports

---------

Co-authored-by: Eric Zhu <[email protected]>
  • Loading branch information
Leon0402 and ekzhu authored Dec 29, 2024
1 parent 6e0f65b commit 23dbb6a
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,9 @@ def __init__(
raise ValueError(
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
)
if not model_context:
if model_context is not None:
self._model_context = model_context
else:
self._model_context = UnboundedChatCompletionContext()
self._reflect_on_tool_use = reflect_on_tool_use
self._tool_call_summary_format = tool_call_summary_format
Expand Down
44 changes: 44 additions & 0 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
ToolCallSummaryMessage,
)
from autogen_core import Image
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.models import LLMMessage
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
from openai.resources.chat.completions import AsyncCompletions
Expand All @@ -39,10 +41,12 @@ class _MockChatCompletion:
def __init__(self, chat_completions: List[ChatCompletion]) -> None:
self._saved_chat_completions = chat_completions
self.curr_index = 0
self.calls: List[List[LLMMessage]] = []

async def mock_create(
self, *args: Any, **kwargs: Any
) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
self.calls.append(kwargs["messages"]) # Save the call
await asyncio.sleep(0.1)
completion = self._saved_chat_completions[self.curr_index]
self.curr_index += 1
Expand Down Expand Up @@ -468,3 +472,43 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
else:
assert message == result.messages[index]
index += 1


@pytest.mark.asyncio
async def test_model_context(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="Response to message 3", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
model_context = BufferedChatCompletionContext(buffer_size=2)
agent = AssistantAgent(
"test_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_context=model_context,
)

messages = [
TextMessage(content="Message 1", source="user"),
TextMessage(content="Message 2", source="user"),
TextMessage(content="Message 3", source="user"),
]
await agent.run(task=messages)

# Check if the mock client is called with only the last two messages.
assert len(mock.calls) == 1
assert len(mock.calls[0]) == 3 # 2 message from the context + 1 system message

0 comments on commit 23dbb6a

Please sign in to comment.