Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing model context attribute #4848

Merged
merged 4 commits into from
Dec 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,9 @@ def __init__(
raise ValueError(
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
)
if not model_context:
if model_context is not None:
self._model_context = model_context
else:
self._model_context = UnboundedChatCompletionContext()
self._reflect_on_tool_use = reflect_on_tool_use
self._tool_call_summary_format = tool_call_summary_format
Expand Down
44 changes: 44 additions & 0 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
ToolCallSummaryMessage,
)
from autogen_core import Image
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.models import LLMMessage
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
from openai.resources.chat.completions import AsyncCompletions
Expand All @@ -39,10 +41,12 @@ class _MockChatCompletion:
def __init__(self, chat_completions: List[ChatCompletion]) -> None:
self._saved_chat_completions = chat_completions
self.curr_index = 0
self.calls: List[List[LLMMessage]] = []

async def mock_create(
self, *args: Any, **kwargs: Any
) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
self.calls.append(kwargs["messages"]) # Save the call
await asyncio.sleep(0.1)
completion = self._saved_chat_completions[self.curr_index]
self.curr_index += 1
Expand Down Expand Up @@ -468,3 +472,43 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
else:
assert message == result.messages[index]
index += 1


@pytest.mark.asyncio
async def test_model_context(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="Response to message 3", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
model_context = BufferedChatCompletionContext(buffer_size=2)
agent = AssistantAgent(
"test_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_context=model_context,
)

messages = [
TextMessage(content="Message 1", source="user"),
TextMessage(content="Message 2", source="user"),
TextMessage(content="Message 3", source="user"),
]
await agent.run(task=messages)

# Check if the mock client is called with only the last two messages.
assert len(mock.calls) == 1
assert len(mock.calls[0]) == 3 # 2 message from the context + 1 system message
Loading