@@ -66,6 +66,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
 usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
 thought="Calling pass function",
 cached=False,
+raw_response={"id": "mock-id", "provider": "replay"},
 ),
 "pass",
 "TERMINATE",
@@ -144,18 +145,21 @@ async def test_run_with_tools_and_reflection() -> None:
 content=[FunctionCall(id="1", arguments=json.dumps({"input": "task"}), name="_pass_function")],
 usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
 cached=False,
+raw_response={"id": "mock-id", "provider": "replay"},
 ),
 CreateResult(
 finish_reason="stop",
 content="Hello",
 usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
 cached=False,
+raw_response={"id": "mock-id", "provider": "replay"},
 ),
 CreateResult(
 finish_reason="stop",
 content="TERMINATE",
 usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
 cached=False,
+raw_response={"id": "mock-id", "provider": "replay"},
 ),
 ],
 model_info={
@@ -246,6 +250,7 @@ async def test_run_with_parallel_tools() -> None:
 usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
 thought="Calling pass and echo functions",
 cached=False,
+raw_response={"id": "mock-id", "provider": "replay"},
 ),
 "pass",
 "TERMINATE",
@@ -331,6 +336,7 @@ async def test_run_with_parallel_tools_with_empty_call_ids() -> None:
 ],
 usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
 cached=False,
+raw_response={"id": "mock-id", "provider": "replay"},
 ),
 "pass",
 "TERMINATE",
@@ -672,6 +678,7 @@ async def test_handoffs() -> None:
 ],
 usage=RequestUsage(prompt_tokens=42, completion_tokens=43),
 cached=False,
+raw_response={"id": "mock-id", "provider": "replay"},
 thought="Calling handoff function",
 )
 ],
@@ -1064,6 +1071,7 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
 content="Response to message 1",
 usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
 cached=False,
+raw_response={"id": "mock-id", "provider": "replay"},
 )
 ]
 )
@@ -1269,6 +1277,7 @@ async def test_model_client_stream_with_tool_calls() -> None:
 finish_reason="function_calls",
 usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
 cached=False,
+raw_response={"id": "mock-id", "provider": "replay"},
 ),
 "Example response 2 to task",
 ]
@@ -142,11 +142,6 @@ async def test_self_debugging_loop() -> None:
 numbers = [10, 20, 30, 40, 50]
 mean = sum(numbers) / len(numbers
 print("The mean is:", mean)
 """.strip()
-incorrect_code_result = """
-mean = sum(numbers) / len(numbers
-    ^
-SyntaxError: '(' was never closed
-""".strip()
 correct_code_block = """
 numbers = [10, 20, 30, 40, 50]
@@ -218,8 +213,8 @@ async def test_self_debugging_loop() -> None:
 elif isinstance(message, CodeExecutionEvent) and message_id == 1:
 # Step 2: First code execution
 assert (
-incorrect_code_result in message.to_text().strip()
-), f"Expected {incorrect_code_result} in execution result, got: {message.to_text().strip()}"
+"SyntaxError: '(' was never closed" in message.to_text()
+), f"Expected SyntaxError in execution result, got: {message.to_text().strip()}"
 incorrect_code_execution_event = message

 elif isinstance(message, CodeGenerationEvent) and message_id == 2:
3 changes: 3 additions & 0 deletions python/packages/autogen-agentchat/tests/test_group_chat.py
@@ -450,6 +450,7 @@ async def test_round_robin_group_chat_with_tools(runtime: AgentRuntime | None) -
 content=[FunctionCall(id="1", name="pass", arguments=json.dumps({"input": "pass"}))],
 usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
 cached=False,
+raw_response={"id": "mock-id", "provider": "replay"},
 ),
 "Hello",
 "TERMINATE",
@@ -1267,6 +1268,7 @@ async def test_swarm_handoff_using_tool_calls(runtime: AgentRuntime | None) -> N
 content=[FunctionCall(id="1", name="handoff_to_agent2", arguments=json.dumps({}))],
 usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
 cached=False,
+raw_response={"id": "mock-id", "provider": "replay"},
 ),
 "Hello",
 "TERMINATE",
@@ -1367,6 +1369,7 @@ async def test_swarm_with_parallel_tool_calls(runtime: AgentRuntime | None) -> N
 ],
 usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
 cached=False,
+raw_response={"id": "mock-id", "provider": "replay"},
 ),
 "Hello",
 "TERMINATE",
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union

 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
@@ -125,3 +125,6 @@ class CreateResult(BaseModel):
 thought: Optional[str] = None
 """The reasoning text for the completion if available. Used for reasoning models
 and additional text content besides function calls."""
+
+raw_response: Optional[Dict[str, Any]] = None
+"""Raw response from the model API, useful for custom field access."""
2 changes: 2 additions & 0 deletions python/packages/autogen-core/tests/test_tool_agent.py
@@ -113,13 +113,15 @@ async def create(
 usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
 cached=False,
 logprobs=None,
+raw_response={"id": "mock-id", "provider": "replay"},
 )
 return CreateResult(
 content="Done",
 finish_reason="stop",
 usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
 cached=False,
 logprobs=None,
+raw_response={"id": "mock-id", "provider": "replay"},
 )

 def create_stream(
@@ -141,6 +141,7 @@ async def create(
 finish_reason=data.get("finish_reason", "stop"),
 usage=data.get("usage", RequestUsage(prompt_tokens=0, completion_tokens=0)),
 cached=True,
+raw_response=data.get("raw_response", {"id": "mock-id", "provider": "replay"}),
 )
 return result
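
Since this hunk restores results from a cache, it also shows how the new field round-trips. A minimal sketch of the same pattern; the helper name and dict keys are hypothetical, mirroring the data.get calls above.

from typing import Any, Dict
from autogen_core.models import CreateResult, RequestUsage

def result_from_cache(data: Dict[str, Any]) -> CreateResult:
    # Hypothetical helper: rebuild a CreateResult from a cached dict,
    # leaving raw_response as None when the entry predates the field.
    return CreateResult(
        content=data.get("content", ""),
        finish_reason=data.get("finish_reason", "stop"),
        usage=data.get("usage", RequestUsage(prompt_tokens=0, completion_tokens=0)),
        cached=True,
        raw_response=data.get("raw_response"),
    )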

@@ -646,6 +646,7 @@ async def create(
 usage=usage,
 cached=False,
 thought=thought,
+raw_response=result,
 )

 # Update usage statistics
@@ -863,13 +864,20 @@ async def create_stream(
 # Just text content
 content = "".join(text_content)

+future: asyncio.Task[Message] = asyncio.ensure_future(
+self._client.messages.create(**request_args)  # type: ignore
+)
+
+message_result: Message = cast(Message, await future)
+
 # Create the final result
 result = CreateResult(
 finish_reason=normalize_stop_reason(stop_reason),
 content=content,
 usage=usage,
 cached=False,
 thought=thought,
+raw_response=message_result,
 )

 # Emit the end event.
@@ -440,6 +440,7 @@ async def create(
 usage=usage,
 cached=False,
 thought=thought,
+raw_response=result,
 )

 self.add_usage(usage)
@@ -357,7 +357,11 @@ async def create(
 if not response_tool_calls and not response_text:
 logger.debug("DEBUG: No response text found. Returning empty response.")
 return CreateResult(
-content="", usage=RequestUsage(prompt_tokens=0, completion_tokens=0), finish_reason="stop", cached=False
+content="",
+usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
+finish_reason="stop",
+cached=False,
+raw_response=response,
 )

 # Create a CreateResult object
@@ -373,6 +377,7 @@ async def create(
 usage=cast(RequestUsage, response["usage"]),
 finish_reason=normalize_stop_reason(finish_reason),  # type: ignore
 cached=False,
+raw_response=response,
 )

 # If we are running in the context of a handler we can get the agent_id
@@ -691,6 +691,7 @@ async def create(
 usage=usage,
 cached=False,
 logprobs=None,
+raw_response=result,
 thought=thought,
 )

@@ -827,6 +828,7 @@ async def create_stream(
 usage=usage,
 cached=False,
 logprobs=None,
+raw_response=None,
 thought=thought,
 )

@@ -722,6 +722,7 @@ async def create(
 cached=False,
 logprobs=logprobs,
 thought=thought,
+raw_response=result,
 )

 self._total_usage = _add_usage(self._total_usage, usage)
@@ -956,6 +957,28 @@ async def create_stream(
 if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1 and thought is None:
 thought, content = parse_r1_content(content)

+create_params = self._process_create_args(
+messages,
+tools,
+json_output,
+extra_create_args,
+)
+
+if create_params.response_format is not None:
+result = await self._client.beta.chat.completions.parse(
+messages=create_params.messages,
+tools=(create_params.tools if len(create_params.tools) > 0 else NOT_GIVEN),
+response_format=create_params.response_format,
+**create_params.create_args,
+)
+else:
+result = await self._client.chat.completions.create(
+messages=create_params.messages,
+stream=False,
+tools=(create_params.tools if len(create_params.tools) > 0 else NOT_GIVEN),
+**create_params.create_args,
+)
+
 # Create the result.
 result = CreateResult(
 finish_reason=normalize_stop_reason(stop_reason),
@@ -964,6 +987,7 @@
 cached=False,
 logprobs=logprobs,
 thought=thought,
+raw_response=result,
 )

 # Log the end of the stream.
@@ -176,7 +176,11 @@ async def create(
 _, output_token_count = self._tokenize(response)
 self._cur_usage = RequestUsage(prompt_tokens=prompt_token_count, completion_tokens=output_token_count)
 response = CreateResult(
-finish_reason="stop", content=response, usage=self._cur_usage, cached=self._cached_bool_value
+finish_reason="stop",
+content=response,
+usage=self._cur_usage,
+cached=self._cached_bool_value,
+raw_response=response,
 )
 else:
 self._cur_usage = RequestUsage(
@@ -221,7 +225,11 @@ async def create_stream(
 else:
 yield token
 yield CreateResult(
-finish_reason="stop", content=response, usage=self._cur_usage, cached=self._cached_bool_value
+finish_reason="stop",
+content=response,
+usage=self._cur_usage,
+cached=self._cached_bool_value,
+raw_response=response,
 )
 self._update_total_usage()
 else:
@@ -521,6 +521,7 @@ async def create(
 usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
 cached=False,
 thought=thought,
+raw_response=result,
 )

 @staticmethod
@@ -676,6 +677,7 @@ async def create_stream(
 finish_reason="function_calls",
 usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
 cached=False,
+raw_response=None,
 )
 return

@@ -698,6 +700,7 @@ async def create_stream(
 usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
 cached=False,
 thought=thought,
+raw_response=None,
 )

 # Emit the end event.
@@ -211,7 +211,7 @@ async def mock_get_streaming_chat_message_contents(
 created=1736674044,
 model="gpt-4o-mini-2024-07-18",
 object="chat.completion.chunk",
-service_tier="scale",
+service_tier="default",
 system_fingerprint="fingerprint",
 usage=CompletionUsage(prompt_tokens=20, completion_tokens=9, total_tokens=29),
 ),
@@ -232,7 +232,7 @@ async def mock_get_streaming_chat_message_contents(
 created=1736674044,
 model="gpt-4o-mini-2024-07-18",
 object="chat.completion.chunk",
-service_tier="scale",
+service_tier="default",
 system_fingerprint="fingerprint",
 usage=CompletionUsage(prompt_tokens=20, completion_tokens=9, total_tokens=29),
 ),
@@ -253,7 +253,7 @@ async def mock_get_streaming_chat_message_contents(
 created=1736674044,
 model="gpt-4o-mini-2024-07-18",
 object="chat.completion.chunk",
-service_tier="scale",
+service_tier="default",
 system_fingerprint="fingerprint",
 usage=CompletionUsage(prompt_tokens=20, completion_tokens=9, total_tokens=29),
 ),
@@ -280,7 +280,7 @@ async def mock_get_streaming_chat_message_contents(
 created=1736674044,
 model="gpt-4o-mini-2024-07-18",
 object="chat.completion.chunk",
-service_tier="scale",
+service_tier="default",
 system_fingerprint="fingerprint",
 usage=CompletionUsage(prompt_tokens=20, completion_tokens=9, total_tokens=29),
 ),
@@ -503,7 +503,7 @@ async def mock_get_streaming_chat_message_contents(
 created=1736674044,
 model="r1",
 object="chat.completion.chunk",
-service_tier="scale",
+service_tier="default",
 system_fingerprint="fingerprint",
 usage=CompletionUsage(prompt_tokens=20, completion_tokens=9, total_tokens=29),
 ),
@@ -10,13 +10,21 @@
 import pytest
 from autogen_agentchat.messages import BaseChatMessage, TextMessage, ToolCallRequestEvent
 from autogen_core import CancellationToken
+from autogen_core.models import UserMessage
 from autogen_core.tools._base import BaseTool, Tool
 from autogen_ext.agents.openai import OpenAIAssistantAgent
 from azure.identity import DefaultAzureCredential, get_bearer_token_provider
 from openai import AsyncAzureOpenAI, AsyncOpenAI
 from pydantic import BaseModel


+def fake_to_model_message(self):
+    return UserMessage(content=self.content, source=self.source)
+
+
+TextMessage.to_model_message = fake_to_model_message
+
+
 class QuestionType(str, Enum):
 MULTIPLE_CHOICE = "MULTIPLE_CHOICE"
 FREE_RESPONSE = "FREE_RESPONSE"