diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py index a165ac592afe..8c4809343495 100644 --- a/autogen/agentchat/groupchat.py +++ b/autogen/agentchat/groupchat.py @@ -107,11 +107,11 @@ def custom_speaker_selection_func( admin_name: Optional[str] = "Admin" func_call_filter: Optional[bool] = True speaker_selection_method: Union[Literal["auto", "manual", "random", "round_robin"], Callable] = "auto" - max_retries_for_selecting_speaker: Optional[int] = 2 + max_retries_for_selecting_speaker: int = 2 allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = None allowed_or_disallowed_speaker_transitions: Optional[Dict] = None speaker_transitions_type: Literal["allowed", "disallowed", None] = None - enable_clear_history: Optional[bool] = False + enable_clear_history: bool = False send_introductions: bool = False select_speaker_message_template: str = """You are in a role play game. The following roles are available: {roles}. diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py new file mode 100644 index 000000000000..f85784749586 --- /dev/null +++ b/autogen/logger/file_logger.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import json +import logging +import os +import threading +import uuid +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +from openai import AzureOpenAI, OpenAI +from openai.types.chat import ChatCompletion + +from autogen.logger.base_logger import BaseLogger +from autogen.logger.logger_utils import get_current_ts, to_dict + +from .base_logger import LLMConfig + +if TYPE_CHECKING: + from autogen import Agent, ConversableAgent, OpenAIWrapper + +logger = logging.getLogger(__name__) + + +class FileLogger(BaseLogger): + def __init__(self, config: Dict[str, Any]): + self.config = config + self.session_id = str(uuid.uuid4()) + + curr_dir = os.getcwd() + self.log_dir = os.path.join(curr_dir, "autogen_logs") + os.makedirs(self.log_dir, exist_ok=True) + + self.log_file = os.path.join(self.log_dir, self.config.get("filename", "runtime.log")) + try: + with open(self.log_file, "a"): + pass + except Exception as e: + logger.error(f"[file_logger] Failed to create logging file: {e}") + + self.logger = logging.getLogger(__name__) + self.logger.setLevel(logging.INFO) + file_handler = logging.FileHandler(self.log_file) + self.logger.addHandler(file_handler) + + def start(self) -> str: + """Start the logger and return the session_id.""" + try: + self.logger.info(f"Started new session with Session ID: {self.session_id}") + except Exception as e: + logger.error(f"[file_logger] Failed to create logging file: {e}") + finally: + return self.session_id + + def log_chat_completion( + self, + invocation_id: uuid.UUID, + client_id: int, + wrapper_id: int, + request: Dict[str, Union[float, str, List[Dict[str, str]]]], + response: Union[str, ChatCompletion], + is_cached: int, + cost: float, + start_time: str, + ) -> None: + """ + Log a chat completion. + """ + thread_id = threading.get_ident() + try: + log_data = json.dumps( + { + "invocation_id": str(invocation_id), + "client_id": client_id, + "wrapper_id": wrapper_id, + "request": to_dict(request), + "response": str(response), + "is_cached": is_cached, + "cost": cost, + "start_time": start_time, + "end_time": get_current_ts(), + "thread_id": thread_id, + } + ) + + self.logger.info(log_data) + except Exception as e: + self.logger.error(f"[file_logger] Failed to log chat completion: {e}") + + def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any] = {}) -> None: + """ + Log a new agent instance. + """ + thread_id = threading.get_ident() + + try: + log_data = json.dumps( + { + "id": id(agent), + "agent_name": agent.name if hasattr(agent, "name") and agent.name is not None else "", + "wrapper_id": to_dict( + agent.client.wrapper_id if hasattr(agent, "client") and agent.client is not None else "" + ), + "session_id": self.session_id, + "current_time": get_current_ts(), + "agent_type": type(agent).__name__, + "args": to_dict(init_args), + "thread_id": thread_id, + } + ) + self.logger.info(log_data) + except Exception as e: + self.logger.error(f"[file_logger] Failed to log new agent: {e}") + + def log_event(self, source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]) -> None: + """ + Log an event from an agent or a string source. + """ + from autogen import Agent + + # This takes an object o as input and returns a string. If the object o cannot be serialized, instead of raising an error, + # it returns a string indicating that the object is non-serializable, along with its type's qualified name obtained using __qualname__. + json_args = json.dumps(kwargs, default=lambda o: f"<>") + thread_id = threading.get_ident() + + if isinstance(source, Agent): + try: + log_data = json.dumps( + { + "source_id": id(source), + "source_name": str(source.name) if hasattr(source, "name") else source, + "event_name": name, + "agent_module": source.__module__, + "agent_class": source.__class__.__name__, + "json_state": json_args, + "timestamp": get_current_ts(), + "thread_id": thread_id, + } + ) + self.logger.info(log_data) + except Exception as e: + self.logger.error(f"[file_logger] Failed to log event {e}") + else: + try: + log_data = json.dumps( + { + "source_id": id(source), + "source_name": str(source.name) if hasattr(source, "name") else source, + "event_name": name, + "json_state": json_args, + "timestamp": get_current_ts(), + "thread_id": thread_id, + } + ) + self.logger.info(log_data) + except Exception as e: + self.logger.error(f"[file_logger] Failed to log event {e}") + + def log_new_wrapper( + self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]] = {} + ) -> None: + """ + Log a new wrapper instance. + """ + thread_id = threading.get_ident() + + try: + log_data = json.dumps( + { + "wrapper_id": id(wrapper), + "session_id": self.session_id, + "json_state": json.dumps(init_args), + "timestamp": get_current_ts(), + "thread_id": thread_id, + } + ) + self.logger.info(log_data) + except Exception as e: + self.logger.error(f"[file_logger] Failed to log event {e}") + + def log_new_client(self, client: AzureOpenAI | OpenAI, wrapper: OpenAIWrapper, init_args: Dict[str, Any]) -> None: + """ + Log a new client instance. + """ + thread_id = threading.get_ident() + + try: + log_data = json.dumps( + { + "client_id": id(client), + "wrapper_id": id(wrapper), + "session_id": self.session_id, + "class": type(client).__name__, + "json_state": json.dumps(init_args), + "timestamp": get_current_ts(), + "thread_id": thread_id, + } + ) + self.logger.info(log_data) + except Exception as e: + self.logger.error(f"[file_logger] Failed to log event {e}") + + def get_connection(self) -> None: + """Method is intentionally left blank because there is no specific connection needed for the FileLogger.""" + pass + + def stop(self) -> None: + """Close the file handler and remove it from the logger.""" + for handler in self.logger.handlers: + if isinstance(handler, logging.FileHandler): + handler.close() + self.logger.removeHandler(handler) diff --git a/autogen/logger/logger_factory.py b/autogen/logger/logger_factory.py index 8073c0c07d3e..ed9567977bb0 100644 --- a/autogen/logger/logger_factory.py +++ b/autogen/logger/logger_factory.py @@ -1,6 +1,7 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Literal, Optional from autogen.logger.base_logger import BaseLogger +from autogen.logger.file_logger import FileLogger from autogen.logger.sqlite_logger import SqliteLogger __all__ = ("LoggerFactory",) @@ -8,11 +9,15 @@ class LoggerFactory: @staticmethod - def get_logger(logger_type: str = "sqlite", config: Optional[Dict[str, Any]] = None) -> BaseLogger: + def get_logger( + logger_type: Literal["sqlite", "file"] = "sqlite", config: Optional[Dict[str, Any]] = None + ) -> BaseLogger: if config is None: config = {} if logger_type == "sqlite": return SqliteLogger(config) + elif logger_type == "file": + return FileLogger(config) else: raise ValueError(f"[logger_factory] Unknown logger type: {logger_type}") diff --git a/autogen/runtime_logging.py b/autogen/runtime_logging.py index 1b9835eaa4b0..d848ca3645e1 100644 --- a/autogen/runtime_logging.py +++ b/autogen/runtime_logging.py @@ -3,12 +3,12 @@ import logging import sqlite3 import uuid -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union from openai import AzureOpenAI, OpenAI from openai.types.chat import ChatCompletion -from autogen.logger.base_logger import LLMConfig +from autogen.logger.base_logger import BaseLogger, LLMConfig from autogen.logger.logger_factory import LoggerFactory if TYPE_CHECKING: @@ -20,11 +20,27 @@ is_logging = False -def start(logger_type: str = "sqlite", config: Optional[Dict[str, Any]] = None) -> str: +def start( + logger: Optional[BaseLogger] = None, + logger_type: Literal["sqlite", "file"] = "sqlite", + config: Optional[Dict[str, Any]] = None, +) -> str: + """ + Start logging for the runtime. + Args: + logger (BaseLogger): A logger instance + logger_type (str): The type of logger to use (default: sqlite) + config (dict): Configuration for the logger + Returns: + session_id (str(uuid.uuid4)): a unique id for the logging session + """ global autogen_logger global is_logging - autogen_logger = LoggerFactory.get_logger(logger_type=logger_type, config=config) + if logger: + autogen_logger = logger + else: + autogen_logger = LoggerFactory.get_logger(logger_type=logger_type, config=config) try: session_id = autogen_logger.start() diff --git a/autogen/token_count_utils.py b/autogen/token_count_utils.py index 589d7b404a7d..b71dbc428a13 100644 --- a/autogen/token_count_utils.py +++ b/autogen/token_count_utils.py @@ -34,6 +34,8 @@ def get_max_token_limit(model: str = "gpt-3.5-turbo-0613") -> int: "gpt-4-0125-preview": 128000, "gpt-4-turbo-preview": 128000, "gpt-4-vision-preview": 128000, + "gpt-4o": 128000, + "gpt-4o-2024-05-13": 128000, } return max_token_limit[model] diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln index 7648cd1196a8..b29e5e21e950 100644 --- a/dotnet/AutoGen.sln +++ b/dotnet/AutoGen.sln @@ -35,6 +35,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Mistral.Tests", "te EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SemanticKernel.Tests", "test\AutoGen.SemanticKernel.Tests\AutoGen.SemanticKernel.Tests.csproj", "{1DFABC4A-8458-4875-8DCB-59F3802DAC65}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.OpenAI.Tests", "test\AutoGen.OpenAI.Tests\AutoGen.OpenAI.Tests.csproj", "{D36A85F9-C172-487D-8192-6BFE5D05B4A7}" Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.DotnetInteractive.Tests", "test\AutoGen.DotnetInteractive.Tests\AutoGen.DotnetInteractive.Tests.csproj", "{B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}" EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Autogen.Ollama", "src\Autogen.Ollama\Autogen.Ollama.csproj", "{A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}" @@ -107,6 +108,10 @@ Global {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.Build.0 = Debug|Any CPU {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Release|Any CPU.ActiveCfg = Release|Any CPU {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Release|Any CPU.Build.0 = Release|Any CPU + {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Release|Any CPU.Build.0 = Release|Any CPU {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}.Debug|Any CPU.Build.0 = Debug|Any CPU {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -131,6 +136,7 @@ Global {A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {C24FDE63-952D-4F8E-A807-AF31D43AD675} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {1DFABC4A-8458-4875-8DCB-59F3802DAC65} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} + {D36A85F9-C172-487D-8192-6BFE5D05B4A7} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution diff --git a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs index 2bd9470ffa72..2925a43e16f2 100644 --- a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs +++ b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs @@ -19,7 +19,6 @@ namespace AutoGen.OpenAI; /// - /// - /// - -/// - /// - where T is /// - where TMessage1 is and TMessage2 is /// @@ -27,6 +26,11 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa { private bool strictMode = false; + /// + /// Create a new instance of . + /// + /// If true, will throw an + /// When the message type is not supported. If false, it will ignore the unsupported message type. public OpenAIChatRequestMessageConnector(bool strictMode = false) { this.strictMode = strictMode; @@ -36,8 +40,7 @@ public OpenAIChatRequestMessageConnector(bool strictMode = false) public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) { - var chatMessages = ProcessIncomingMessages(agent, context.Messages) - .Select(m => new MessageEnvelope(m)); + var chatMessages = ProcessIncomingMessages(agent, context.Messages); var reply = await agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken); @@ -49,8 +52,7 @@ public async IAsyncEnumerable InvokeAsync( IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - var chatMessages = ProcessIncomingMessages(agent, context.Messages) - .Select(m => new MessageEnvelope(m)); + var chatMessages = ProcessIncomingMessages(agent, context.Messages); var streamingReply = agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken); string? currentToolName = null; await foreach (var reply in streamingReply) @@ -73,7 +75,14 @@ public async IAsyncEnumerable InvokeAsync( } else { - yield return reply; + if (this.strictMode) + { + throw new InvalidOperationException($"Invalid streaming message type {reply.GetType().Name}"); + } + else + { + yield return reply; + } } } } @@ -82,16 +91,10 @@ public IMessage PostProcessMessage(IMessage message) { return message switch { - TextMessage => message, - ImageMessage => message, - MultiModalMessage => message, - ToolCallMessage => message, - ToolCallResultMessage => message, - Message => message, - AggregateMessage => message, - IMessage m => PostProcessMessage(m), - IMessage m => PostProcessMessage(m), - _ => throw new InvalidOperationException("The type of message is not supported. Must be one of TextMessage, ImageMessage, MultiModalMessage, ToolCallMessage, ToolCallResultMessage, Message, IMessage, AggregateMessage"), + IMessage m => PostProcessChatResponseMessage(m.Content, m.From), + IMessage m => PostProcessChatCompletions(m), + _ when strictMode is false => message, + _ => throw new InvalidOperationException($"Invalid return message type {message.GetType().Name}"), }; } @@ -120,12 +123,7 @@ public IMessage PostProcessMessage(IMessage message) } } - private IMessage PostProcessMessage(IMessage message) - { - return PostProcessMessage(message.Content, message.From); - } - - private IMessage PostProcessMessage(IMessage message) + private IMessage PostProcessChatCompletions(IMessage message) { // throw exception if prompt filter results is not null if (message.Content.Choices[0].FinishReason == CompletionsFinishReason.ContentFiltered) @@ -133,12 +131,12 @@ private IMessage PostProcessMessage(IMessage message) throw new InvalidOperationException("The content is filtered because its potential risk. Please try another input."); } - return PostProcessMessage(message.Content.Choices[0].Message, message.From); + return PostProcessChatResponseMessage(message.Content.Choices[0].Message, message.From); } - private IMessage PostProcessMessage(ChatResponseMessage chatResponseMessage, string? from) + private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponseMessage, string? from) { - if (chatResponseMessage.Content is string content) + if (chatResponseMessage.Content is string content && !string.IsNullOrEmpty(content)) { return new TextMessage(Role.Assistant, content, from); } @@ -162,112 +160,41 @@ private IMessage PostProcessMessage(ChatResponseMessage chatResponseMessage, str throw new InvalidOperationException("Invalid ChatResponseMessage"); } - public IEnumerable ProcessIncomingMessages(IAgent agent, IEnumerable messages) + public IEnumerable ProcessIncomingMessages(IAgent agent, IEnumerable messages) { - return messages.SelectMany(m => + return messages.SelectMany(m => { - if (m.From == null) + if (m is IMessage crm) { - return ProcessIncomingMessagesWithEmptyFrom(m); - } - else if (m.From == agent.Name) - { - return ProcessIncomingMessagesForSelf(m); + return [crm]; } else { - return ProcessIncomingMessagesForOther(m); + var chatRequestMessages = m switch + { + TextMessage textMessage => ProcessTextMessage(agent, textMessage), + ImageMessage imageMessage when (imageMessage.From is null || imageMessage.From != agent.Name) => ProcessImageMessage(agent, imageMessage), + MultiModalMessage multiModalMessage when (multiModalMessage.From is null || multiModalMessage.From != agent.Name) => ProcessMultiModalMessage(agent, multiModalMessage), + ToolCallMessage toolCallMessage when (toolCallMessage.From is null || toolCallMessage.From == agent.Name) => ProcessToolCallMessage(agent, toolCallMessage), + ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage), + AggregateMessage aggregateMessage => ProcessFunctionCallMiddlewareMessage(agent, aggregateMessage), + Message msg => ProcessMessage(agent, msg), + _ when strictMode is false => [], + _ => throw new InvalidOperationException($"Invalid message type: {m.GetType().Name}"), + }; + + if (chatRequestMessages.Any()) + { + return chatRequestMessages.Select(cm => MessageEnvelope.Create(cm, m.From)); + } + else + { + return [m]; + } } }); } - private IEnumerable ProcessIncomingMessagesForSelf(IMessage message) - { - return message switch - { - TextMessage textMessage => ProcessIncomingMessagesForSelf(textMessage), - ImageMessage imageMessage => ProcessIncomingMessagesForSelf(imageMessage), - MultiModalMessage multiModalMessage => ProcessIncomingMessagesForSelf(multiModalMessage), - ToolCallMessage toolCallMessage => ProcessIncomingMessagesForSelf(toolCallMessage), - ToolCallResultMessage toolCallResultMessage => ProcessIncomingMessagesForSelf(toolCallResultMessage), - Message msg => ProcessIncomingMessagesForSelf(msg), - IMessage crm => ProcessIncomingMessagesForSelf(crm), - AggregateMessage aggregateMessage => ProcessIncomingMessagesForSelf(aggregateMessage), - _ => throw new NotImplementedException(), - }; - } - - private IEnumerable ProcessIncomingMessagesWithEmptyFrom(IMessage message) - { - return message switch - { - TextMessage textMessage => ProcessIncomingMessagesWithEmptyFrom(textMessage), - ImageMessage imageMessage => ProcessIncomingMessagesWithEmptyFrom(imageMessage), - MultiModalMessage multiModalMessage => ProcessIncomingMessagesWithEmptyFrom(multiModalMessage), - ToolCallMessage toolCallMessage => ProcessIncomingMessagesWithEmptyFrom(toolCallMessage), - ToolCallResultMessage toolCallResultMessage => ProcessIncomingMessagesWithEmptyFrom(toolCallResultMessage), - Message msg => ProcessIncomingMessagesWithEmptyFrom(msg), - IMessage crm => ProcessIncomingMessagesWithEmptyFrom(crm), - AggregateMessage aggregateMessage => ProcessIncomingMessagesWithEmptyFrom(aggregateMessage), - _ => throw new NotImplementedException(), - }; - } - - private IEnumerable ProcessIncomingMessagesForOther(IMessage message) - { - return message switch - { - TextMessage textMessage => ProcessIncomingMessagesForOther(textMessage), - ImageMessage imageMessage => ProcessIncomingMessagesForOther(imageMessage), - MultiModalMessage multiModalMessage => ProcessIncomingMessagesForOther(multiModalMessage), - ToolCallMessage toolCallMessage => ProcessIncomingMessagesForOther(toolCallMessage), - ToolCallResultMessage toolCallResultMessage => ProcessIncomingMessagesForOther(toolCallResultMessage), - Message msg => ProcessIncomingMessagesForOther(msg), - IMessage crm => ProcessIncomingMessagesForOther(crm), - AggregateMessage aggregateMessage => ProcessIncomingMessagesForOther(aggregateMessage), - _ => throw new NotImplementedException(), - }; - } - - private IEnumerable ProcessIncomingMessagesForSelf(TextMessage message) - { - if (message.Role == Role.System) - { - return new[] { new ChatRequestSystemMessage(message.Content) }; - } - else - { - return new[] { new ChatRequestAssistantMessage(message.Content) }; - } - } - - private IEnumerable ProcessIncomingMessagesForSelf(ImageMessage _) - { - return [new ChatRequestAssistantMessage("// Image Message is not supported")]; - } - - private IEnumerable ProcessIncomingMessagesForSelf(MultiModalMessage _) - { - return [new ChatRequestAssistantMessage("// MultiModal Message is not supported")]; - } - - private IEnumerable ProcessIncomingMessagesForSelf(ToolCallMessage message) - { - var toolCall = message.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments)); - var chatRequestMessage = new ChatRequestAssistantMessage(string.Empty); - foreach (var tc in toolCall) - { - chatRequestMessage.ToolCalls.Add(tc); - } - - return new[] { chatRequestMessage }; - } - - private IEnumerable ProcessIncomingMessagesForSelf(ToolCallResultMessage message) - { - return message.ToolCalls.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName)); - } - private IEnumerable ProcessIncomingMessagesForSelf(Message message) { if (message.Role == Role.System) @@ -303,151 +230,145 @@ private IEnumerable ProcessIncomingMessagesForSelf(Message m } } - private IEnumerable ProcessIncomingMessagesForSelf(IMessage message) - { - return new[] { message.Content }; - } - - private IEnumerable ProcessIncomingMessagesForSelf(AggregateMessage aggregateMessage) + private IEnumerable ProcessIncomingMessagesForOther(Message message) { - var toolCallMessage1 = aggregateMessage.Message1; - var toolCallResultMessage = aggregateMessage.Message2; - - var assistantMessage = new ChatRequestAssistantMessage(string.Empty); - var toolCalls = toolCallMessage1.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments)); - foreach (var tc in toolCalls) + if (message.Role == Role.System) { - assistantMessage.ToolCalls.Add(tc); + return [new ChatRequestSystemMessage(message.Content) { Name = message.From }]; } + else if (message.Content is string content && content is { Length: > 0 }) + { + if (message.FunctionName is not null) + { + return new[] { new ChatRequestToolMessage(content, message.FunctionName) }; + } - var toolCallResults = toolCallResultMessage.ToolCalls.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName)); - - // return assistantMessage and tool call result messages - var messages = new List { assistantMessage }; - messages.AddRange(toolCallResults); - - return messages; + return [new ChatRequestUserMessage(message.Content) { Name = message.From }]; + } + else if (message.FunctionName is string _) + { + return [new ChatRequestUserMessage("// Message type is not supported") { Name = message.From }]; + } + else + { + throw new InvalidOperationException("Invalid Message as message from other."); + } } - private IEnumerable ProcessIncomingMessagesForOther(TextMessage message) + private IEnumerable ProcessTextMessage(IAgent agent, TextMessage message) { if (message.Role == Role.System) { - return new[] { new ChatRequestSystemMessage(message.Content) }; + return [new ChatRequestSystemMessage(message.Content) { Name = message.From }]; + } + + if (agent.Name == message.From) + { + return [new ChatRequestAssistantMessage(message.Content) { Name = agent.Name }]; } else { - return new[] { new ChatRequestUserMessage(message.Content) }; + return message.From switch + { + null when message.Role == Role.User => [new ChatRequestUserMessage(message.Content)], + null when message.Role == Role.Assistant => [new ChatRequestAssistantMessage(message.Content)], + null => throw new InvalidOperationException("Invalid Role"), + _ => [new ChatRequestUserMessage(message.Content) { Name = message.From }] + }; } } - private IEnumerable ProcessIncomingMessagesForOther(ImageMessage message) + private IEnumerable ProcessImageMessage(IAgent agent, ImageMessage message) { - return new[] { new ChatRequestUserMessage([ - new ChatMessageImageContentItem(new Uri(message.Url ?? message.BuildDataUri())), - ])}; + if (agent.Name == message.From) + { + // image message from assistant is not supported + throw new ArgumentException("ImageMessage is not supported when message.From is the same with agent"); + } + + var imageContentItem = this.CreateChatMessageImageContentItemFromImageMessage(message); + return [new ChatRequestUserMessage([imageContentItem]) { Name = message.From }]; } - private IEnumerable ProcessIncomingMessagesForOther(MultiModalMessage message) + private IEnumerable ProcessMultiModalMessage(IAgent agent, MultiModalMessage message) { + if (agent.Name == message.From) + { + // image message from assistant is not supported + throw new ArgumentException("MultiModalMessage is not supported when message.From is the same with agent"); + } + IEnumerable items = message.Content.Select(ci => ci switch { TextMessage text => new ChatMessageTextContentItem(text.Content), - ImageMessage image => new ChatMessageImageContentItem(new Uri(image.Url ?? image.BuildDataUri())), + ImageMessage image => this.CreateChatMessageImageContentItemFromImageMessage(image), _ => throw new NotImplementedException(), }); - return new[] { new ChatRequestUserMessage(items) }; + return [new ChatRequestUserMessage(items) { Name = message.From }]; } - private IEnumerable ProcessIncomingMessagesForOther(ToolCallMessage msg) + private ChatMessageImageContentItem CreateChatMessageImageContentItemFromImageMessage(ImageMessage message) { - throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent"); + return message.Data is null + ? new ChatMessageImageContentItem(new Uri(message.Url)) + : new ChatMessageImageContentItem(message.Data, message.Data.MediaType); } - private IEnumerable ProcessIncomingMessagesForOther(ToolCallResultMessage message) + private IEnumerable ProcessToolCallMessage(IAgent agent, ToolCallMessage message) { - return message.ToolCalls.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName)); - } - - private IEnumerable ProcessIncomingMessagesForOther(Message message) - { - if (message.Role == Role.System) + if (message.From is not null && message.From != agent.Name) { - return new[] { new ChatRequestSystemMessage(message.Content) }; + throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent"); } - else if (message.Content is string content && content is { Length: > 0 }) - { - if (message.FunctionName is not null) - { - return new[] { new ChatRequestToolMessage(content, message.FunctionName) }; - } - return new[] { new ChatRequestUserMessage(message.Content) }; - } - else if (message.FunctionName is string _) - { - return new[] - { - new ChatRequestUserMessage("// Message type is not supported"), - }; - } - else + var toolCall = message.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments)); + var chatRequestMessage = new ChatRequestAssistantMessage(string.Empty) { Name = message.From }; + foreach (var tc in toolCall) { - throw new InvalidOperationException("Invalid Message as message from other."); + chatRequestMessage.ToolCalls.Add(tc); } - } - private IEnumerable ProcessIncomingMessagesForOther(IMessage message) - { - return new[] { message.Content }; - } - - private IEnumerable ProcessIncomingMessagesForOther(AggregateMessage aggregateMessage) - { - // convert as user message - var resultMessage = aggregateMessage.Message2; - - return resultMessage.ToolCalls.Select(tc => new ChatRequestUserMessage(tc.Result)); + return [chatRequestMessage]; } - private IEnumerable ProcessIncomingMessagesWithEmptyFrom(TextMessage message) + private IEnumerable ProcessToolCallResultMessage(ToolCallResultMessage message) { - return ProcessIncomingMessagesForOther(message); + return message.ToolCalls + .Where(tc => tc.Result is not null) + .Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName)); } - private IEnumerable ProcessIncomingMessagesWithEmptyFrom(ImageMessage message) + private IEnumerable ProcessMessage(IAgent agent, Message message) { - return ProcessIncomingMessagesForOther(message); - } - - private IEnumerable ProcessIncomingMessagesWithEmptyFrom(MultiModalMessage message) - { - return ProcessIncomingMessagesForOther(message); - } - - private IEnumerable ProcessIncomingMessagesWithEmptyFrom(ToolCallMessage message) - { - return ProcessIncomingMessagesForSelf(message); + if (message.From is not null && message.From != agent.Name) + { + return ProcessIncomingMessagesForOther(message); + } + else + { + return ProcessIncomingMessagesForSelf(message); + } } - private IEnumerable ProcessIncomingMessagesWithEmptyFrom(ToolCallResultMessage message) + private IEnumerable ProcessFunctionCallMiddlewareMessage(IAgent agent, AggregateMessage aggregateMessage) { - return ProcessIncomingMessagesForOther(message); - } + if (aggregateMessage.From is not null && aggregateMessage.From != agent.Name) + { + // convert as user message + var resultMessage = aggregateMessage.Message2; - private IEnumerable ProcessIncomingMessagesWithEmptyFrom(Message message) - { - return ProcessIncomingMessagesForOther(message); - } + return resultMessage.ToolCalls.Select(tc => new ChatRequestUserMessage(tc.Result) { Name = aggregateMessage.From }); + } + else + { + var toolCallMessage1 = aggregateMessage.Message1; + var toolCallResultMessage = aggregateMessage.Message2; - private IEnumerable ProcessIncomingMessagesWithEmptyFrom(IMessage message) - { - return new[] { message.Content }; - } + var assistantMessage = this.ProcessToolCallMessage(agent, toolCallMessage1); + var toolCallResults = this.ProcessToolCallResultMessage(toolCallResultMessage); - private IEnumerable ProcessIncomingMessagesWithEmptyFrom(AggregateMessage aggregateMessage) - { - return ProcessIncomingMessagesForOther(aggregateMessage); + return assistantMessage.Concat(toolCallResults); + } } } diff --git a/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt b/dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt similarity index 94% rename from dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt rename to dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt index 2cb58f4d88c0..d17de56e1295 100644 --- a/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt +++ b/dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt @@ -3,6 +3,7 @@ "OriginalMessage": "TextMessage(system, You are a helpful AI assistant, )", "ConvertedMessages": [ { + "Name": null, "Role": "system", "Content": "You are a helpful AI assistant" } @@ -14,6 +15,7 @@ { "Role": "user", "Content": "Hello", + "Name": "user", "MultiModaItem": null } ] @@ -24,6 +26,7 @@ { "Role": "assistant", "Content": "How can I help you?", + "Name": "assistant", "TooCall": [], "FunctionCallName": null, "FunctionCallArguments": null @@ -34,6 +37,7 @@ "OriginalMessage": "Message(system, You are a helpful AI assistant, , , )", "ConvertedMessages": [ { + "Name": null, "Role": "system", "Content": "You are a helpful AI assistant" } @@ -45,6 +49,7 @@ { "Role": "user", "Content": "Hello", + "Name": "user", "MultiModaItem": null } ] @@ -55,6 +60,7 @@ { "Role": "assistant", "Content": "How can I help you?", + "Name": null, "TooCall": [], "FunctionCallName": null, "FunctionCallArguments": null @@ -67,6 +73,7 @@ { "Role": "user", "Content": "result", + "Name": "user", "MultiModaItem": null } ] @@ -77,6 +84,7 @@ { "Role": "assistant", "Content": null, + "Name": null, "TooCall": [], "FunctionCallName": "functionName", "FunctionCallArguments": "functionArguments" @@ -89,6 +97,7 @@ { "Role": "user", "Content": null, + "Name": "user", "MultiModaItem": [ { "Type": "Image", @@ -107,6 +116,7 @@ { "Role": "user", "Content": null, + "Name": "user", "MultiModaItem": [ { "Type": "Text", @@ -129,6 +139,7 @@ { "Role": "assistant", "Content": "", + "Name": "assistant", "TooCall": [ { "Type": "Function", @@ -173,6 +184,7 @@ { "Role": "assistant", "Content": "", + "Name": "assistant", "TooCall": [ { "Type": "Function", @@ -198,6 +210,7 @@ { "Role": "assistant", "Content": "", + "Name": "assistant", "TooCall": [ { "Type": "Function", diff --git a/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj b/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj new file mode 100644 index 000000000000..044975354b80 --- /dev/null +++ b/dotnet/test/AutoGen.OpenAI.Tests/AutoGen.OpenAI.Tests.csproj @@ -0,0 +1,32 @@ + + + + $(TestTargetFramework) + false + True + + + + + + + + + + + + + + + + + + + + $([System.String]::Copy('%(FileName)').Split('.')[0]) + $(ProjectExt.Replace('proj', '')) + %(ParentFile)%(ParentExtension) + + + + diff --git a/dotnet/test/AutoGen.OpenAI.Tests/GlobalUsing.cs b/dotnet/test/AutoGen.OpenAI.Tests/GlobalUsing.cs new file mode 100644 index 000000000000..d66bf001ed5e --- /dev/null +++ b/dotnet/test/AutoGen.OpenAI.Tests/GlobalUsing.cs @@ -0,0 +1,4 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// GlobalUsing.cs + +global using AutoGen.Core; diff --git a/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs similarity index 100% rename from dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs rename to dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs diff --git a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs new file mode 100644 index 000000000000..a8c1d3f7860d --- /dev/null +++ b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs @@ -0,0 +1,612 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// OpenAIMessageTests.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text.Json; +using System.Threading.Tasks; +using ApprovalTests; +using ApprovalTests.Namers; +using ApprovalTests.Reporters; +using AutoGen.OpenAI; +using Azure.AI.OpenAI; +using FluentAssertions; +using Xunit; + +namespace AutoGen.Tests; + +public class OpenAIMessageTests +{ + private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions + { + WriteIndented = true, + IgnoreReadOnlyProperties = false, + }; + + [Fact] + [UseReporter(typeof(DiffReporter))] + [UseApprovalSubdirectory("ApprovalTests")] + public void BasicMessageTest() + { + IMessage[] messages = [ + new TextMessage(Role.System, "You are a helpful AI assistant"), + new TextMessage(Role.User, "Hello", "user"), + new TextMessage(Role.Assistant, "How can I help you?", from: "assistant"), + new Message(Role.System, "You are a helpful AI assistant"), + new Message(Role.User, "Hello", "user"), + new Message(Role.Assistant, "How can I help you?", from: "assistant"), + new Message(Role.Function, "result", "user"), + new Message(Role.Assistant, null, "assistant") + { + FunctionName = "functionName", + FunctionArguments = "functionArguments", + }, + new ImageMessage(Role.User, "https://example.com/image.png", "user"), + new MultiModalMessage(Role.Assistant, + [ + new TextMessage(Role.User, "Hello", "user"), + new ImageMessage(Role.User, "https://example.com/image.png", "user"), + ], "user"), + new ToolCallMessage("test", "test", "assistant"), + new ToolCallResultMessage("result", "test", "test", "user"), + new ToolCallResultMessage( + [ + new ToolCall("result", "test", "test"), + new ToolCall("result", "test", "test"), + ], "user"), + new ToolCallMessage( + [ + new ToolCall("test", "test"), + new ToolCall("test", "test"), + ], "assistant"), + new AggregateMessage( + message1: new ToolCallMessage("test", "test", "assistant"), + message2: new ToolCallResultMessage("result", "test", "test", "assistant"), "assistant"), + ]; + var openaiMessageConnectorMiddleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant"); + + var oaiMessages = messages.Select(m => (m, openaiMessageConnectorMiddleware.ProcessIncomingMessages(agent, [m]))); + VerifyOAIMessages(oaiMessages); + } + + [Fact] + public async Task ItProcessUserTextMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("Hello"); + chatRequestMessage.Name.Should().Be("user"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + IMessage message = new TextMessage(Role.User, "Hello", "user"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItShortcutChatRequestMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + + var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("hello"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + var userMessage = new ChatRequestUserMessage("hello"); + var chatRequestMessage = MessageEnvelope.Create(userMessage); + await agent.GenerateReplyAsync([chatRequestMessage]); + } + + [Fact] + public async Task ItShortcutMessageWhenStrictModelIsFalseAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + + var chatRequestMessage = ((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Should().Be("hello"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + var userMessage = "hello"; + var chatRequestMessage = MessageEnvelope.Create(userMessage); + await agent.GenerateReplyAsync([chatRequestMessage]); + } + + [Fact] + public async Task ItThrowExceptionWhenStrictModeIsTrueAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(true); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + // user message + var userMessage = "hello"; + var chatRequestMessage = MessageEnvelope.Create(userMessage); + Func action = async () => await agent.GenerateReplyAsync([chatRequestMessage]); + + await action.Should().ThrowAsync().WithMessage("Invalid message type: MessageEnvelope`1"); + } + + [Fact] + public async Task ItProcessAssistantTextMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("How can I help you?"); + chatRequestMessage.Name.Should().Be("assistant"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // assistant message + IMessage message = new TextMessage(Role.Assistant, "How can I help you?", "assistant"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItProcessSystemTextMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestSystemMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("You are a helpful AI assistant"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // system message + IMessage message = new TextMessage(Role.System, "You are a helpful AI assistant"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItProcessImageMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().BeNullOrEmpty(); + chatRequestMessage.Name.Should().Be("user"); + chatRequestMessage.MultimodalContentItems.Count().Should().Be(1); + chatRequestMessage.MultimodalContentItems.First().Should().BeOfType(); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + IMessage message = new ImageMessage(Role.User, "https://example.com/image.png", "user"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItThrowExceptionWhenProcessingImageMessageFromSelfAndStrictModeIsTrueAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(true); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + var imageMessage = new ImageMessage(Role.Assistant, "https://example.com/image.png", "assistant"); + Func action = async () => await agent.GenerateReplyAsync([imageMessage]); + + await action.Should().ThrowAsync().WithMessage("Invalid message type: ImageMessage"); + } + + [Fact] + public async Task ItProcessMultiModalMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().BeNullOrEmpty(); + chatRequestMessage.Name.Should().Be("user"); + chatRequestMessage.MultimodalContentItems.Count().Should().Be(2); + chatRequestMessage.MultimodalContentItems.First().Should().BeOfType(); + chatRequestMessage.MultimodalContentItems.Last().Should().BeOfType(); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + IMessage message = new MultiModalMessage( + Role.User, + [ + new TextMessage(Role.User, "Hello", "user"), + new ImageMessage(Role.User, "https://example.com/image.png", "user"), + ], "user"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItThrowExceptionWhenProcessingMultiModalMessageFromSelfAndStrictModeIsTrueAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(true); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + var multiModalMessage = new MultiModalMessage( + Role.Assistant, + [ + new TextMessage(Role.User, "Hello", "assistant"), + new ImageMessage(Role.User, "https://example.com/image.png", "assistant"), + ], "assistant"); + + Func action = async () => await agent.GenerateReplyAsync([multiModalMessage]); + + await action.Should().ThrowAsync().WithMessage("Invalid message type: MultiModalMessage"); + } + + [Fact] + public async Task ItProcessToolCallMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().BeNullOrEmpty(); + chatRequestMessage.Name.Should().Be("assistant"); + chatRequestMessage.ToolCalls.Count().Should().Be(1); + chatRequestMessage.ToolCalls.First().Should().BeOfType(); + var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.First(); + functionToolCall.Name.Should().Be("test"); + functionToolCall.Arguments.Should().Be("test"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + IMessage message = new ToolCallMessage("test", "test", "assistant"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItThrowExceptionWhenProcessingToolCallMessageFromUserAndStrictModeIsTrueAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(strictMode: true); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + var toolCallMessage = new ToolCallMessage("test", "test", "user"); + Func action = async () => await agent.GenerateReplyAsync([toolCallMessage]); + await action.Should().ThrowAsync().WithMessage("Invalid message type: ToolCallMessage"); + } + + [Fact] + public async Task ItProcessToolCallResultMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("result"); + chatRequestMessage.ToolCallId.Should().Be("test"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + IMessage message = new ToolCallResultMessage("result", "test", "test", "user"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItProcessFunctionCallMiddlewareMessageFromUserAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + msgs.Count().Should().Be(1); + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("result"); + chatRequestMessage.Name.Should().Be("user"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + var toolCallMessage = new ToolCallMessage("test", "test", "user"); + var toolCallResultMessage = new ToolCallResultMessage("result", "test", "test", "user"); + var aggregateMessage = new AggregateMessage(toolCallMessage, toolCallResultMessage, "user"); + await agent.GenerateReplyAsync([aggregateMessage]); + } + + [Fact] + public async Task ItProcessFunctionCallMiddlewareMessageFromAssistantAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + msgs.Count().Should().Be(2); + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("result"); + + var toolCallMessage = msgs.First(); + toolCallMessage!.Should().BeOfType>(); + var toolCallRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope)toolCallMessage!).Content; + toolCallRequestMessage.Content.Should().BeNullOrEmpty(); + toolCallRequestMessage.ToolCalls.Count().Should().Be(1); + toolCallRequestMessage.ToolCalls.First().Should().BeOfType(); + var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.First(); + functionToolCall.Name.Should().Be("test"); + functionToolCall.Arguments.Should().Be("test"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + var toolCallMessage = new ToolCallMessage("test", "test", "assistant"); + var toolCallResultMessage = new ToolCallResultMessage("result", "test", "test", "assistant"); + var aggregateMessage = new AggregateMessage(toolCallMessage, toolCallResultMessage, "assistant"); + await agent.GenerateReplyAsync([aggregateMessage]); + } + + [Fact] + public async Task ItConvertChatResponseMessageToTextMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + // text message + var textMessage = CreateInstance(ChatRole.Assistant, "hello"); + var chatRequestMessage = MessageEnvelope.Create(textMessage); + + var message = await agent.GenerateReplyAsync([chatRequestMessage]); + message.Should().BeOfType(); + message.GetContent().Should().Be("hello"); + message.GetRole().Should().Be(Role.Assistant); + } + + [Fact] + public async Task ItConvertChatResponseMessageToToolCallMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + // tool call message + var toolCallMessage = CreateInstance(ChatRole.Assistant, "", new[] { new ChatCompletionsFunctionToolCall("test", "test", "test") }, new FunctionCall("test", "test"), CreateInstance(), new Dictionary()); + var chatRequestMessage = MessageEnvelope.Create(toolCallMessage); + var message = await agent.GenerateReplyAsync([chatRequestMessage]); + message.Should().BeOfType(); + message.GetToolCalls()!.Count().Should().Be(1); + message.GetToolCalls()!.First().FunctionName.Should().Be("test"); + message.GetToolCalls()!.First().FunctionArguments.Should().Be("test"); + } + + [Fact] + public async Task ItReturnOriginalMessageWhenStrictModeIsFalseAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + // text message + var textMessage = "hello"; + var messageToSend = MessageEnvelope.Create(textMessage); + + var message = await agent.GenerateReplyAsync([messageToSend]); + message.Should().BeOfType>(); + } + + [Fact] + public async Task ItThrowInvalidOperationExceptionWhenStrictModeIsTrueAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(true); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + // text message + var textMessage = new ChatRequestUserMessage("hello"); + var messageToSend = MessageEnvelope.Create(textMessage); + Func action = async () => await agent.GenerateReplyAsync([messageToSend]); + + await action.Should().ThrowAsync().WithMessage("Invalid return message type MessageEnvelope`1"); + } + + [Fact] + public void ToOpenAIChatRequestMessageShortCircuitTest() + { + var agent = new EchoAgent("assistant"); + var middleware = new OpenAIChatRequestMessageConnector(); + ChatRequestMessage[] messages = + [ + new ChatRequestUserMessage("Hello"), + new ChatRequestAssistantMessage("How can I help you?"), + new ChatRequestSystemMessage("You are a helpful AI assistant"), + new ChatRequestFunctionMessage("result", "functionName"), + new ChatRequestToolMessage("test", "test"), + ]; + + foreach (var oaiMessage in messages) + { + IMessage message = new MessageEnvelope(oaiMessage); + var oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); + oaiMessages.Count().Should().Be(1); + //oaiMessages.First().Should().BeOfType>(); + if (oaiMessages.First() is IMessage chatRequestMessage) + { + chatRequestMessage.Content.Should().Be(oaiMessage); + } + else + { + // fail the test + Assert.True(false); + } + } + } + private void VerifyOAIMessages(IEnumerable<(IMessage, IEnumerable)> messages) + { + var jsonObjects = messages.Select(pair => + { + var (originalMessage, ms) = pair; + var objs = new List(); + foreach (var m in ms) + { + object? obj = null; + var chatRequestMessage = (m as IMessage)?.Content; + if (chatRequestMessage is ChatRequestUserMessage userMessage) + { + obj = new + { + Role = userMessage.Role.ToString(), + Content = userMessage.Content, + Name = userMessage.Name, + MultiModaItem = userMessage.MultimodalContentItems?.Select(item => + { + return item switch + { + ChatMessageImageContentItem imageContentItem => new + { + Type = "Image", + ImageUrl = GetImageUrlFromContent(imageContentItem), + } as object, + ChatMessageTextContentItem textContentItem => new + { + Type = "Text", + Text = textContentItem.Text, + } as object, + _ => throw new System.NotImplementedException(), + }; + }), + }; + } + + if (chatRequestMessage is ChatRequestAssistantMessage assistantMessage) + { + obj = new + { + Role = assistantMessage.Role.ToString(), + Content = assistantMessage.Content, + Name = assistantMessage.Name, + TooCall = assistantMessage.ToolCalls.Select(tc => + { + return tc switch + { + ChatCompletionsFunctionToolCall functionToolCall => new + { + Type = "Function", + Name = functionToolCall.Name, + Arguments = functionToolCall.Arguments, + Id = functionToolCall.Id, + } as object, + _ => throw new System.NotImplementedException(), + }; + }), + FunctionCallName = assistantMessage.FunctionCall?.Name, + FunctionCallArguments = assistantMessage.FunctionCall?.Arguments, + }; + } + + if (chatRequestMessage is ChatRequestSystemMessage systemMessage) + { + obj = new + { + Name = systemMessage.Name, + Role = systemMessage.Role.ToString(), + Content = systemMessage.Content, + }; + } + + if (chatRequestMessage is ChatRequestFunctionMessage functionMessage) + { + obj = new + { + Role = functionMessage.Role.ToString(), + Content = functionMessage.Content, + Name = functionMessage.Name, + }; + } + + if (chatRequestMessage is ChatRequestToolMessage toolCallMessage) + { + obj = new + { + Role = toolCallMessage.Role.ToString(), + Content = toolCallMessage.Content, + ToolCallId = toolCallMessage.ToolCallId, + }; + } + + objs.Add(obj ?? throw new System.NotImplementedException()); + } + + return new + { + OriginalMessage = originalMessage.ToString(), + ConvertedMessages = objs, + }; + }); + + var json = JsonSerializer.Serialize(jsonObjects, this.jsonSerializerOptions); + Approvals.Verify(json); + } + + private object? GetImageUrlFromContent(ChatMessageImageContentItem content) + { + return content.GetType().GetProperty("ImageUrl", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)?.GetValue(content); + } + + private static T CreateInstance(params object[] args) + { + var type = typeof(T); + var instance = type.Assembly.CreateInstance( + type.FullName!, false, + BindingFlags.Instance | BindingFlags.NonPublic, + null, args, null, null); + return (T)instance!; + } +} diff --git a/dotnet/test/AutoGen.Tests/EchoAgent.cs b/dotnet/test/AutoGen.Tests/EchoAgent.cs index 28a7b91bad58..9cead5ad2516 100644 --- a/dotnet/test/AutoGen.Tests/EchoAgent.cs +++ b/dotnet/test/AutoGen.Tests/EchoAgent.cs @@ -3,12 +3,13 @@ using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; namespace AutoGen.Tests { - internal class EchoAgent : IAgent + public class EchoAgent : IStreamingAgent { public EchoAgent(string name) { @@ -27,5 +28,14 @@ public Task GenerateReplyAsync( return Task.FromResult(lastMessage); } + + public async IAsyncEnumerable GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + foreach (var message in messages) + { + message.From = this.Name; + yield return message; + } + } } } diff --git a/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs b/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs deleted file mode 100644 index 6e9cd28c4cbd..000000000000 --- a/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs +++ /dev/null @@ -1,382 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// OpenAIMessageTests.cs - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text.Json; -using ApprovalTests; -using ApprovalTests.Namers; -using ApprovalTests.Reporters; -using AutoGen.OpenAI; -using Azure.AI.OpenAI; -using FluentAssertions; -using Xunit; - -namespace AutoGen.Tests; - -public class OpenAIMessageTests -{ - private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions - { - WriteIndented = true, - IgnoreReadOnlyProperties = false, - }; - - [Fact] - [UseReporter(typeof(DiffReporter))] - [UseApprovalSubdirectory("ApprovalTests")] - public void BasicMessageTest() - { - IMessage[] messages = [ - new TextMessage(Role.System, "You are a helpful AI assistant"), - new TextMessage(Role.User, "Hello", "user"), - new TextMessage(Role.Assistant, "How can I help you?", from: "assistant"), - new Message(Role.System, "You are a helpful AI assistant"), - new Message(Role.User, "Hello", "user"), - new Message(Role.Assistant, "How can I help you?", from: "assistant"), - new Message(Role.Function, "result", "user"), - new Message(Role.Assistant, null, "assistant") - { - FunctionName = "functionName", - FunctionArguments = "functionArguments", - }, - new ImageMessage(Role.User, "https://example.com/image.png", "user"), - new MultiModalMessage(Role.Assistant, - [ - new TextMessage(Role.User, "Hello", "user"), - new ImageMessage(Role.User, "https://example.com/image.png", "user"), - ], "user"), - new ToolCallMessage("test", "test", "assistant"), - new ToolCallResultMessage("result", "test", "test", "user"), - new ToolCallResultMessage( - [ - new ToolCall("result", "test", "test"), - new ToolCall("result", "test", "test"), - ], "user"), - new ToolCallMessage( - [ - new ToolCall("test", "test"), - new ToolCall("test", "test"), - ], "assistant"), - new AggregateMessage( - message1: new ToolCallMessage("test", "test", "assistant"), - message2: new ToolCallResultMessage("result", "test", "test", "assistant"), "assistant"), - ]; - var openaiMessageConnectorMiddleware = new OpenAIChatRequestMessageConnector(); - var agent = new EchoAgent("assistant"); - - var oaiMessages = messages.Select(m => (m, openaiMessageConnectorMiddleware.ProcessIncomingMessages(agent, [m]))); - VerifyOAIMessages(oaiMessages); - } - - [Fact] - public void ToOpenAIChatRequestMessageTest() - { - var agent = new EchoAgent("assistant"); - var middleware = new OpenAIChatRequestMessageConnector(); - - // user message - IMessage message = new TextMessage(Role.User, "Hello", "user"); - var oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - var userMessage = (ChatRequestUserMessage)oaiMessages.First(); - userMessage.Content.Should().Be("Hello"); - - // user message test 2 - // even if Role is assistant, it should be converted to user message because it is from the user - message = new TextMessage(Role.Assistant, "Hello", "user"); - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - userMessage = (ChatRequestUserMessage)oaiMessages.First(); - userMessage.Content.Should().Be("Hello"); - - // user message with multimodal content - // image - message = new ImageMessage(Role.User, "https://example.com/image.png", "user"); - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - userMessage = (ChatRequestUserMessage)oaiMessages.First(); - userMessage.Content.Should().BeNullOrEmpty(); - userMessage.MultimodalContentItems.Count().Should().Be(1); - userMessage.MultimodalContentItems.First().Should().BeOfType(); - - // text and image - message = new MultiModalMessage( - Role.User, - [ - new TextMessage(Role.User, "Hello", "user"), - new ImageMessage(Role.User, "https://example.com/image.png", "user"), - ], "user"); - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - userMessage = (ChatRequestUserMessage)oaiMessages.First(); - userMessage.Content.Should().BeNullOrEmpty(); - userMessage.MultimodalContentItems.Count().Should().Be(2); - userMessage.MultimodalContentItems.First().Should().BeOfType(); - - // assistant text message - message = new TextMessage(Role.Assistant, "How can I help you?", "assistant"); - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - var assistantMessage = (ChatRequestAssistantMessage)oaiMessages.First(); - assistantMessage.Content.Should().Be("How can I help you?"); - - // assistant text message with single tool call - message = new ToolCallMessage("test", "test", "assistant"); - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - assistantMessage = (ChatRequestAssistantMessage)oaiMessages.First(); - assistantMessage.Content.Should().BeNullOrEmpty(); - assistantMessage.ToolCalls.Count().Should().Be(1); - assistantMessage.ToolCalls.First().Should().BeOfType(); - - // user should not suppose to send tool call message - message = new ToolCallMessage("test", "test", "user"); - Func action = () => middleware.ProcessIncomingMessages(agent, [message]).First(); - action.Should().Throw().WithMessage("ToolCallMessage is not supported when message.From is not the same with agent"); - - // assistant text message with multiple tool calls - message = new ToolCallMessage( - toolCalls: - [ - new ToolCall("test", "test"), - new ToolCall("test", "test"), - ], "assistant"); - - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - assistantMessage = (ChatRequestAssistantMessage)oaiMessages.First(); - assistantMessage.Content.Should().BeNullOrEmpty(); - assistantMessage.ToolCalls.Count().Should().Be(2); - - // tool call result message - message = new ToolCallResultMessage("result", "test", "test", "user"); - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - var toolCallMessage = (ChatRequestToolMessage)oaiMessages.First(); - toolCallMessage.Content.Should().Be("result"); - - // tool call result message with multiple tool calls - message = new ToolCallResultMessage( - toolCalls: - [ - new ToolCall("result", "test", "test"), - new ToolCall("result", "test", "test"), - ], "user"); - - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(2); - oaiMessages.First().Should().BeOfType(); - toolCallMessage = (ChatRequestToolMessage)oaiMessages.First(); - toolCallMessage.Content.Should().Be("test"); - oaiMessages.Last().Should().BeOfType(); - toolCallMessage = (ChatRequestToolMessage)oaiMessages.Last(); - toolCallMessage.Content.Should().Be("test"); - - // aggregate message test - // aggregate message with tool call and tool call result will be returned by GPT agent if the tool call is automatically invoked inside agent - message = new AggregateMessage( - message1: new ToolCallMessage("test", "test", "assistant"), - message2: new ToolCallResultMessage("result", "test", "test", "assistant"), "assistant"); - - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(2); - oaiMessages.First().Should().BeOfType(); - assistantMessage = (ChatRequestAssistantMessage)oaiMessages.First(); - assistantMessage.Content.Should().BeNullOrEmpty(); - assistantMessage.ToolCalls.Count().Should().Be(1); - - oaiMessages.Last().Should().BeOfType(); - toolCallMessage = (ChatRequestToolMessage)oaiMessages.Last(); - toolCallMessage.Content.Should().Be("result"); - - // aggregate message test 2 - // if the aggregate message is from user, it should be converted to user message - message = new AggregateMessage( - message1: new ToolCallMessage("test", "test", "user"), - message2: new ToolCallResultMessage("result", "test", "test", "user"), "user"); - - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - userMessage = (ChatRequestUserMessage)oaiMessages.First(); - userMessage.Content.Should().Be("result"); - - // aggregate message test 3 - // if the aggregate message is from user and contains multiple tool call results, it should be converted to user message - message = new AggregateMessage( - message1: new ToolCallMessage( - toolCalls: - [ - new ToolCall("test", "test"), - new ToolCall("test", "test"), - ], from: "user"), - message2: new ToolCallResultMessage( - toolCalls: - [ - new ToolCall("result", "test", "test"), - new ToolCall("result", "test", "test"), - ], from: "user"), "user"); - - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - oaiMessages.Count().Should().Be(2); - oaiMessages.First().Should().BeOfType(); - oaiMessages.Last().Should().BeOfType(); - - // system message - message = new TextMessage(Role.System, "You are a helpful AI assistant"); - oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().BeOfType(); - } - - [Fact] - public void ToOpenAIChatRequestMessageShortCircuitTest() - { - var agent = new EchoAgent("assistant"); - var middleware = new OpenAIChatRequestMessageConnector(); - ChatRequestMessage[] messages = - [ - new ChatRequestUserMessage("Hello"), - new ChatRequestAssistantMessage("How can I help you?"), - new ChatRequestSystemMessage("You are a helpful AI assistant"), - new ChatRequestFunctionMessage("result", "functionName"), - new ChatRequestToolMessage("test", "test"), - ]; - - foreach (var oaiMessage in messages) - { - IMessage message = new MessageEnvelope(oaiMessage); - var oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); - oaiMessages.Count().Should().Be(1); - oaiMessages.First().Should().Be(oaiMessage); - } - } - private void VerifyOAIMessages(IEnumerable<(IMessage, IEnumerable)> messages) - { - var jsonObjects = messages.Select(pair => - { - var (originalMessage, ms) = pair; - var objs = new List(); - foreach (var m in ms) - { - object? obj = null; - if (m is ChatRequestUserMessage userMessage) - { - obj = new - { - Role = userMessage.Role.ToString(), - Content = userMessage.Content, - MultiModaItem = userMessage.MultimodalContentItems?.Select(item => - { - return item switch - { - ChatMessageImageContentItem imageContentItem => new - { - Type = "Image", - ImageUrl = GetImageUrlFromContent(imageContentItem), - } as object, - ChatMessageTextContentItem textContentItem => new - { - Type = "Text", - Text = textContentItem.Text, - } as object, - _ => throw new System.NotImplementedException(), - }; - }), - }; - } - - if (m is ChatRequestAssistantMessage assistantMessage) - { - obj = new - { - Role = assistantMessage.Role.ToString(), - Content = assistantMessage.Content, - TooCall = assistantMessage.ToolCalls.Select(tc => - { - return tc switch - { - ChatCompletionsFunctionToolCall functionToolCall => new - { - Type = "Function", - Name = functionToolCall.Name, - Arguments = functionToolCall.Arguments, - Id = functionToolCall.Id, - } as object, - _ => throw new System.NotImplementedException(), - }; - }), - FunctionCallName = assistantMessage.FunctionCall?.Name, - FunctionCallArguments = assistantMessage.FunctionCall?.Arguments, - }; - } - - if (m is ChatRequestSystemMessage systemMessage) - { - obj = new - { - Role = systemMessage.Role.ToString(), - Content = systemMessage.Content, - }; - } - - if (m is ChatRequestFunctionMessage functionMessage) - { - obj = new - { - Role = functionMessage.Role.ToString(), - Content = functionMessage.Content, - Name = functionMessage.Name, - }; - } - - if (m is ChatRequestToolMessage toolCallMessage) - { - obj = new - { - Role = toolCallMessage.Role.ToString(), - Content = toolCallMessage.Content, - ToolCallId = toolCallMessage.ToolCallId, - }; - } - - objs.Add(obj ?? throw new System.NotImplementedException()); - } - - return new - { - OriginalMessage = originalMessage.ToString(), - ConvertedMessages = objs, - }; - }); - - var json = JsonSerializer.Serialize(jsonObjects, this.jsonSerializerOptions); - Approvals.Verify(json); - } - - private object? GetImageUrlFromContent(ChatMessageImageContentItem content) - { - return content.GetType().GetProperty("ImageUrl", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)?.GetValue(content); - } -} diff --git a/notebook/agentchat_logging.ipynb b/notebook/agentchat_logging.ipynb index 2ad19e7995a5..7eb4138b4cc1 100644 --- a/notebook/agentchat_logging.ipynb +++ b/notebook/agentchat_logging.ipynb @@ -8,6 +8,10 @@ "\n", "AutoGen offers utilities to log data for debugging and performance analysis. This notebook demonstrates how to use them. \n", "\n", + "we log data in different modes:\n", + "- SQlite Database\n", + "- File \n", + "\n", "In general, users can initiate logging by calling `autogen.runtime_logging.start()` and stop logging by calling `autogen.runtime_logging.stop()`" ] }, @@ -287,6 +291,82 @@ " + str(round(session_cost, 4))\n", ")" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Log data in File mode\n", + "\n", + "By default, the log type is set to `sqlite` as shown above, but we introduced a new parameter for the `autogen.runtime_logging.start()`\n", + "\n", + "the `logger_type = \"file\"` will start to log data in the File mode." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Logging session ID: ed493ebf-d78e-49f0-b832-69557276d557\n", + "\u001b[33muser_proxy\u001b[0m (to assistant):\n", + "\n", + "What is the height of the Eiffel Tower? Only respond with the answer and terminate\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33massistant\u001b[0m (to user_proxy):\n", + "\n", + "The height of the Eiffel Tower is 330 meters.\n", + "TERMINATE\n", + "\n", + "--------------------------------------------------------------------------------\n" + ] + } + ], + "source": [ + "\n", + "import pandas as pd\n", + "\n", + "import autogen\n", + "from autogen import AssistantAgent, UserProxyAgent\n", + "\n", + "# Setup API key. Add your own API key to config file or environment variable\n", + "llm_config = {\n", + " \"config_list\": autogen.config_list_from_json(\n", + " env_or_file=\"OAI_CONFIG_LIST\",\n", + " ),\n", + " \"temperature\": 0.9,\n", + "}\n", + "\n", + "# Start logging with logger_type and the filename to log to\n", + "logging_session_id = autogen.runtime_logging.start(logger_type=\"file\", config={\"filename\": \"runtime.log\"})\n", + "print(\"Logging session ID: \" + str(logging_session_id))\n", + "\n", + "# Create an agent workflow and run it\n", + "assistant = AssistantAgent(name=\"assistant\", llm_config=llm_config)\n", + "user_proxy = UserProxyAgent(\n", + " name=\"user_proxy\",\n", + " code_execution_config=False,\n", + " human_input_mode=\"NEVER\",\n", + " is_termination_msg=lambda msg: \"TERMINATE\" in msg[\"content\"],\n", + ")\n", + "\n", + "user_proxy.initiate_chat(\n", + " assistant, message=\"What is the height of the Eiffel Tower? Only respond with the answer and terminate\"\n", + ")\n", + "autogen.runtime_logging.stop()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This should create a `runtime.log` file in your current directory. " + ] } ], "metadata": { @@ -312,7 +392,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.9.13" } }, "nbformat": 4, diff --git a/samples/apps/cap/README.md b/samples/apps/cap/README.md index 7fe7a6c7469b..c8d1945eae2e 100644 --- a/samples/apps/cap/README.md +++ b/samples/apps/cap/README.md @@ -1,11 +1,13 @@ # Composable Actor Platform (CAP) for AutoGen -## I just want to run the demo! +## I just want to run the remote AutoGen agents! *Python Instructions (Windows, Linux, MacOS):* 0) cd py 1) pip install -r autogencap/requirements.txt 2) python ./demo/App.py +3) Choose (5) and follow instructions to run standalone Agents +4) Choose other options for other demos *Demo Notes:* 1) Options involving AutoGen require OAI_CONFIG_LIST. @@ -15,14 +17,15 @@ *Demo Reference:* ``` - Select the Composable Actor Platform (CAP) demo app to run: - (enter anything else to quit) - 1. Hello World Actor - 2. Complex Actor Graph - 3. AutoGen Pair - 4. AutoGen GroupChat - 5. AutoGen Agents in different processes - Enter your choice (1-5): +Select the Composable Actor Platform (CAP) demo app to run: +(enter anything else to quit) +1. Hello World +2. Complex Agent (e.g. Name or Quit) +3. AutoGen Pair +4. AutoGen GroupChat +5. AutoGen Agents in different processes +6. List Actors in CAP (Registry) +Enter your choice (1-6): ``` ## What is Composable Actor Platform (CAP)? diff --git a/samples/apps/cap/py/README.md b/samples/apps/cap/py/README.md new file mode 100644 index 000000000000..e11fa3d048fc --- /dev/null +++ b/samples/apps/cap/py/README.md @@ -0,0 +1,39 @@ +# Composable Actor Platform (CAP) for AutoGen + +## I just want to run the remote AutoGen agents! +*Python Instructions (Windows, Linux, MacOS):* + +pip install autogencap + +1) AutoGen require OAI_CONFIG_LIST. + AutoGen python requirements: 3.8 <= python <= 3.11 + +``` + +## What is Composable Actor Platform (CAP)? +AutoGen is about Agents and Agent Orchestration. CAP extends AutoGen to allows Agents to communicate via a message bus. CAP, therefore, deals with the space between these components. CAP is a message based, actor platform that allows actors to be composed into arbitrary graphs. + +Actors can register themselves with CAP, find other agents, construct arbitrary graphs, send and receive messages independently and many, many, many other things. +```python + # CAP Platform + network = LocalActorNetwork() + # Register an agent + network.register(GreeterAgent()) + # Tell agents to connect to other agents + network.connect() + # Get a channel to the agent + greeter_link = network.lookup_agent("Greeter") + # Send a message to the agent + greeter_link.send_txt_msg("Hello World!") + # Cleanup + greeter_link.close() + network.disconnect() +``` +### Check out other demos in the `py/demo` directory. We show the following: ### +1) Hello World shown above +2) Many CAP Actors interacting with each other +3) A pair of interacting AutoGen Agents wrapped in CAP Actors +4) CAP wrapped AutoGen Agents in a group chat +5) Two AutoGen Agents running in different processes and communicating through CAP +6) List all registered agents in CAP +7) AutoGen integration to list all registered agents diff --git a/samples/apps/cap/py/autogencap/ActorConnector.py b/samples/apps/cap/py/autogencap/ActorConnector.py index c7b16157dc6a..e2ddbfa4fb69 100644 --- a/samples/apps/cap/py/autogencap/ActorConnector.py +++ b/samples/apps/cap/py/autogencap/ActorConnector.py @@ -29,8 +29,11 @@ def _connect_pub_socket(self): evt: Dict[str, Any] = {} mon_evt = recv_monitor_message(monitor) evt.update(mon_evt) - if evt["event"] == zmq.EVENT_MONITOR_STOPPED or evt["event"] == zmq.EVENT_HANDSHAKE_SUCCEEDED: - Debug("ActorSender", "Handshake received (Or Monitor stopped)") + if evt["event"] == zmq.EVENT_HANDSHAKE_SUCCEEDED: + Debug("ActorSender", "Handshake received") + break + elif evt["event"] == zmq.EVENT_MONITOR_STOPPED: + Debug("ActorSender", "Monitor stopped") break self._pub_socket.disable_monitor() monitor.close() @@ -117,32 +120,33 @@ def send_txt_msg(self, msg): def send_bin_msg(self, msg_type: str, msg): self._sender.send_bin_msg(msg_type, msg) - def binary_request(self, msg_type: str, msg, retry=5): + def binary_request(self, msg_type: str, msg, num_attempts=5): original_timeout: int = 0 - if retry == -1: + if num_attempts == -1: original_timeout = self._resp_socket.getsockopt(zmq.RCVTIMEO) self._resp_socket.setsockopt(zmq.RCVTIMEO, 1000) try: self._sender.send_bin_request_msg(msg_type, msg, self._resp_topic) - while retry == -1 or retry > 0: + while num_attempts == -1 or num_attempts > 0: try: topic, resp_msg_type, _, resp = self._resp_socket.recv_multipart() return topic, resp_msg_type, resp except zmq.Again: Debug( - "ActorConnector", f"{self._topic}: No response received. retry_count={retry}, max_retry={retry}" + "ActorConnector", + f"{self._topic}: No response received. retry_count={num_attempts}, max_retry={num_attempts}", ) time.sleep(0.01) - if retry != -1: - retry -= 1 + if num_attempts != -1: + num_attempts -= 1 finally: - if retry == -1: + if num_attempts == -1: self._resp_socket.setsockopt(zmq.RCVTIMEO, original_timeout) Error("ActorConnector", f"{self._topic}: No response received. Giving up.") return None, None, None def close(self): - self._sender.close() + self._pub_socket.close() self._resp_socket.close() diff --git a/samples/apps/cap/py/autogencap/DebugLog.py b/samples/apps/cap/py/autogencap/DebugLog.py index e03712355853..d3be81fe24e6 100644 --- a/samples/apps/cap/py/autogencap/DebugLog.py +++ b/samples/apps/cap/py/autogencap/DebugLog.py @@ -15,42 +15,58 @@ LEVEL_NAMES = ["DBG", "INF", "WRN", "ERR"] LEVEL_COLOR = ["dark_grey", "green", "yellow", "red"] -console_lock = threading.Lock() - - -def Log(level, context, msg): - # Check if the current level meets the threshold - if level >= Config.LOG_LEVEL: # Use the LOG_LEVEL from the Config module - # Check if the context is in the list of ignored contexts - if context in Config.IGNORED_LOG_CONTEXTS: - return - with console_lock: - timestamp = colored(datetime.datetime.now().strftime("%m/%d/%y %H:%M:%S"), "dark_grey") - # Translate level number to name and color - level_name = colored(LEVEL_NAMES[level], LEVEL_COLOR[level]) - # Left justify the context and color it blue - context = colored(context.ljust(14), "blue") - # Left justify the threadid and color it blue - thread_id = colored(str(threading.get_ident()).ljust(5), "blue") - # color the msg based on the level - msg = colored(msg, LEVEL_COLOR[level]) - print(f"{thread_id} {timestamp} {level_name}: [{context}] {msg}") + +class BaseLogger: + def __init__(self): + self._lock = threading.Lock() + + def Log(self, level, context, msg): + # Check if the current level meets the threshold + if level >= Config.LOG_LEVEL: # Use the LOG_LEVEL from the Config module + # Check if the context is in the list of ignored contexts + if context in Config.IGNORED_LOG_CONTEXTS: + return + with self._lock: + self.WriteLog(level, context, msg) + + def WriteLog(self, level, context, msg): + raise NotImplementedError("Subclasses must implement this method") + + +class ConsoleLogger(BaseLogger): + def __init__(self): + super().__init__() + + def WriteLog(self, level, context, msg): + timestamp = colored(datetime.datetime.now().strftime("%m/%d/%y %H:%M:%S"), "pink") + # Translate level number to name and color + level_name = colored(LEVEL_NAMES[level], LEVEL_COLOR[level]) + # Left justify the context and color it blue + context = colored(context.ljust(14), "blue") + # Left justify the threadid and color it blue + thread_id = colored(str(threading.get_ident()).ljust(5), "blue") + # color the msg based on the level + msg = colored(msg, LEVEL_COLOR[level]) + print(f"{thread_id} {timestamp} {level_name}: [{context}] {msg}") + + +LOGGER = ConsoleLogger() def Debug(context, message): - Log(DEBUG, context, message) + LOGGER.Log(DEBUG, context, message) def Info(context, message): - Log(INFO, context, message) + LOGGER.Log(INFO, context, message) def Warn(context, message): - Log(WARN, context, message) + LOGGER.Log(WARN, context, message) def Error(context, message): - Log(ERROR, context, message) + LOGGER.Log(ERROR, context, message) def shorten(msg, num_parts=5, max_len=100): diff --git a/samples/apps/cap/py/autogencap/ag_adapter/AutoGenConnector.py b/samples/apps/cap/py/autogencap/ag_adapter/AutoGenConnector.py index 3fbb0db64fdd..ce81e7e945d3 100644 --- a/samples/apps/cap/py/autogencap/ag_adapter/AutoGenConnector.py +++ b/samples/apps/cap/py/autogencap/ag_adapter/AutoGenConnector.py @@ -1,3 +1,4 @@ +import json from typing import Dict, Optional, Union from autogen import Agent @@ -37,7 +38,7 @@ def send_gen_reply_req(self): # Setting retry to -1 to keep trying until a response is received # This normal AutoGen behavior but does not handle the case when an AutoGen agent # is not running. In that case, the connector will keep trying indefinitely. - _, _, resp = self._can_channel.binary_request(type(msg).__name__, serialized_msg, retry=-1) + _, _, resp = self._can_channel.binary_request(type(msg).__name__, serialized_msg, num_attempts=-1) gen_reply_resp = GenReplyResp() gen_reply_resp.ParseFromString(resp) return gen_reply_resp.data @@ -55,7 +56,8 @@ def send_receive_req( msg = ReceiveReq() if isinstance(message, dict): for key, value in message.items(): - msg.data_map.data[key] = value + json_serialized_value = json.dumps(value) + msg.data_map.data[key] = json_serialized_value elif isinstance(message, str): msg.data = message msg.sender = sender.name diff --git a/samples/apps/cap/py/autogencap/ag_adapter/CAP2AG.py b/samples/apps/cap/py/autogencap/ag_adapter/CAP2AG.py index 50a0a4751ea4..25cd7093ba79 100644 --- a/samples/apps/cap/py/autogencap/ag_adapter/CAP2AG.py +++ b/samples/apps/cap/py/autogencap/ag_adapter/CAP2AG.py @@ -1,3 +1,4 @@ +import json from enum import Enum from typing import Optional @@ -72,7 +73,11 @@ def _call_agent_receive(self, receive_params: ReceiveReq): save_name = self._ag2can_other_agent.name self._ag2can_other_agent.set_name(receive_params.sender) if receive_params.HasField("data_map"): - data = dict(receive_params.data_map.data) + json_data = dict(receive_params.data_map.data) + data = {} + for key, json_value in json_data.items(): + value = json.loads(json_value) + data[key] = value else: data = receive_params.data self._the_ag_agent.receive(data, self._ag2can_other_agent, request_reply, silent) diff --git a/samples/apps/cap/py/autogencap/ag_adapter/agent.py b/samples/apps/cap/py/autogencap/ag_adapter/agent.py new file mode 100644 index 000000000000..219bb7297c18 --- /dev/null +++ b/samples/apps/cap/py/autogencap/ag_adapter/agent.py @@ -0,0 +1,22 @@ +import time + +from autogen import ConversableAgent + +from ..DebugLog import Info, Warn +from .CAP2AG import CAP2AG + + +class Agent: + def __init__(self, agent: ConversableAgent, counter_party_name="user_proxy", init_chat=False): + self._agent = agent + self._the_other_name = counter_party_name + self._agent_adptr = CAP2AG( + ag_agent=self._agent, the_other_name=self._the_other_name, init_chat=init_chat, self_recursive=True + ) + + def register(self, network): + Info("Agent", f"Running Standalone {self._agent.name}") + network.register(self._agent_adptr) + + def running(self): + return self._agent_adptr.run diff --git a/samples/apps/cap/py/demo/App.py b/samples/apps/cap/py/demo/App.py index 19411a9b315c..8af8c97b0e5a 100644 --- a/samples/apps/cap/py/demo/App.py +++ b/samples/apps/cap/py/demo/App.py @@ -45,7 +45,7 @@ def main(): print("3. AutoGen Pair") print("4. AutoGen GroupChat") print("5. AutoGen Agents in different processes") - print("6. List Actors in CAP") + print("6. List Actors in CAP (Registry)") choice = input("Enter your choice (1-6): ") if choice == "1": diff --git a/samples/apps/cap/py/demo/RemoteAGDemo.py b/samples/apps/cap/py/demo/RemoteAGDemo.py index 0c2a946c0a42..5e7f2f0f1efe 100644 --- a/samples/apps/cap/py/demo/RemoteAGDemo.py +++ b/samples/apps/cap/py/demo/RemoteAGDemo.py @@ -5,13 +5,12 @@ def remote_ag_demo(): print("Remote Agent Demo") instructions = """ - In this demo, Broker, Assistant, and UserProxy are running in separate processes. - demo/standalone/UserProxy.py will initiate a conversation by sending UserProxy a message. + In this demo, Assistant, and UserProxy are running in separate processes. + demo/standalone/user_proxy.py will initiate a conversation by sending UserProxy Agent a message. Please do the following: - 1) Start Broker (python demo/standalone/Broker.py) - 2) Start Assistant (python demo/standalone/Assistant.py) - 3) Start UserProxy (python demo/standalone/UserProxy.py) + 1) Start Assistant (python demo/standalone/assistant.py) + 2) Start UserProxy (python demo/standalone/user_proxy.py) """ print(instructions) input("Press Enter to return to demo menu...") diff --git a/samples/apps/cap/py/demo/standalone/DirectorySvc.py b/samples/apps/cap/py/demo/standalone/directory_svc.py similarity index 100% rename from samples/apps/cap/py/demo/standalone/DirectorySvc.py rename to samples/apps/cap/py/demo/standalone/directory_svc.py diff --git a/samples/apps/cap/py/demo/standalone/user_proxy.py b/samples/apps/cap/py/demo/standalone/user_proxy.py new file mode 100644 index 000000000000..3ce4dac79276 --- /dev/null +++ b/samples/apps/cap/py/demo/standalone/user_proxy.py @@ -0,0 +1,57 @@ +import time + +import _paths +from autogencap.ag_adapter.agent import Agent +from autogencap.Config import IGNORED_LOG_CONTEXTS +from autogencap.LocalActorNetwork import LocalActorNetwork + +from autogen import UserProxyAgent + +# Filter out some Log message contexts +IGNORED_LOG_CONTEXTS.extend(["BROKER"]) + + +def main(): + # Standard AutoGen + user_proxy = UserProxyAgent( + "user_proxy", + code_execution_config={"work_dir": "coding"}, + is_termination_msg=lambda x: "TERMINATE" in x.get("content"), + ) + + # Wrap AutoGen Agent in CAP + cap_user_proxy = Agent(user_proxy, counter_party_name="assistant", init_chat=True) + # Create the message bus + network = LocalActorNetwork() + # Add the user_proxy to the message bus + cap_user_proxy.register(network) + # Start message processing + network.connect() + + # Wait for the user_proxy to finish + interact_with_user(network, cap_user_proxy) + # Cleanup + network.disconnect() + + +# Starts the Broker and the Assistant. The UserProxy is started separately. +def interact_with_user(network, cap_assistant): + user_proxy_conn = network.lookup_actor("user_proxy") + example = "Plot a chart of MSFT daily closing prices for last 1 Month." + print(f"Example: {example}") + try: + user_input = input("Please enter your command: ") + if user_input == "": + user_input = example + print(f"Sending: {user_input}") + user_proxy_conn.send_txt_msg(user_input) + + # Hang around for a while + while cap_assistant.running(): + time.sleep(0.5) + except KeyboardInterrupt: + print("Interrupted by user, shutting down.") + + +if __name__ == "__main__": + main() diff --git a/samples/apps/cap/py/pyproject.toml b/samples/apps/cap/py/pyproject.toml new file mode 100644 index 000000000000..51024bb8b279 --- /dev/null +++ b/samples/apps/cap/py/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "autogencap_rajan.jedi" +version = "0.0.7" +authors = [ + { name="Rajan Chari", email="rajan.jedi@gmail.com" }, +] +dependencies = [ + "pyzmq >= 25.1.2", + "protobuf >= 4.25.3", + "termcolor >= 2.4.0", + "pyautogen >= 0.2.23", +] +description = "CAP w/ autogen bindings" +readme = "README.md" +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] + +[project.urls] +"Homepage" = "https://github.com/microsoft/autogen" +"Bug Tracker" = "https://github.com/microsoft/autogen/issues" + +[tool.hatch.build.targets.sdist] +packages = ["autogencap"] +only-packages = true + +[tool.hatch.build.targets.wheel] +packages = ["autogencap"] +only-packages = true diff --git a/test/agentchat/test_agent_file_logging.py b/test/agentchat/test_agent_file_logging.py new file mode 100644 index 000000000000..9b930ba4ac91 --- /dev/null +++ b/test/agentchat/test_agent_file_logging.py @@ -0,0 +1,127 @@ +import json +import os +import sys +import tempfile +import uuid + +import pytest + +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) + +from conftest import skip_openai # noqa: E402 +from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402 + +import autogen +import autogen.runtime_logging +from autogen.logger.file_logger import FileLogger + +is_windows = sys.platform.startswith("win") + + +@pytest.mark.skipif(is_windows, reason="Skipping file logging tests on Windows") +@pytest.fixture +def logger() -> FileLogger: + current_dir = os.path.dirname(os.path.abspath(__file__)) + with tempfile.TemporaryDirectory(dir=current_dir) as temp_dir: + log_file = os.path.join(temp_dir, "test_log.log") + config = {"filename": log_file} + logger = FileLogger(config) + yield logger + + logger.stop() + + +@pytest.mark.skipif(is_windows, reason="Skipping file logging tests on Windows") +def test_start(logger: FileLogger): + session_id = logger.start() + assert isinstance(session_id, str) + assert len(session_id) == 36 + + +@pytest.mark.skipif(is_windows, reason="Skipping file logging tests on Windows") +def test_log_chat_completion(logger: FileLogger): + invocation_id = uuid.uuid4() + client_id = 123456789 + wrapper_id = 987654321 + request = {"messages": [{"content": "Test message", "role": "user"}]} + response = "Test response" + is_cached = 0 + cost = 0.5 + start_time = "2024-05-06 15:20:21.263231" + + logger.log_chat_completion(invocation_id, client_id, wrapper_id, request, response, is_cached, cost, start_time) + + with open(logger.log_file, "r") as f: + lines = f.readlines() + assert len(lines) == 1 + log_data = json.loads(lines[0]) + assert log_data["invocation_id"] == str(invocation_id) + assert log_data["client_id"] == client_id + assert log_data["wrapper_id"] == wrapper_id + assert log_data["response"] == response + assert log_data["is_cached"] == is_cached + assert log_data["cost"] == cost + assert log_data["start_time"] == start_time + assert isinstance(log_data["thread_id"], int) + + +class TestWrapper: + def __init__(self, init_args): + self.init_args = init_args + + +@pytest.mark.skipif(is_windows, reason="Skipping file logging tests on Windows") +def test_log_new_agent(logger: FileLogger): + agent = autogen.UserProxyAgent(name="user_proxy", code_execution_config=False) + logger.log_new_agent(agent) + + with open(logger.log_file, "r") as f: + lines = f.readlines() + log_data = json.loads(lines[0]) # the first line is the session id + assert log_data["agent_name"] == "user_proxy" + + +@pytest.mark.skipif(is_windows, reason="Skipping file logging tests on Windows") +def test_log_event(logger: FileLogger): + source = autogen.AssistantAgent(name="TestAgent", code_execution_config=False) + name = "TestEvent" + kwargs = {"key": "value"} + logger.log_event(source, name, **kwargs) + + with open(logger.log_file, "r") as f: + lines = f.readlines() + log_data = json.loads(lines[0]) + assert log_data["source_name"] == "TestAgent" + assert log_data["event_name"] == name + assert log_data["json_state"] == json.dumps(kwargs) + assert isinstance(log_data["thread_id"], int) + + +@pytest.mark.skipif(is_windows, reason="Skipping file logging tests on Windows") +def test_log_new_wrapper(logger: FileLogger): + wrapper = TestWrapper(init_args={"foo": "bar"}) + logger.log_new_wrapper(wrapper, wrapper.init_args) + + with open(logger.log_file, "r") as f: + lines = f.readlines() + log_data = json.loads(lines[0]) + assert log_data["wrapper_id"] == id(wrapper) + assert log_data["json_state"] == json.dumps(wrapper.init_args) + assert isinstance(log_data["thread_id"], int) + + +@pytest.mark.skipif(is_windows, reason="Skipping file logging tests on Windows") +def test_log_new_client(logger: FileLogger): + client = autogen.UserProxyAgent(name="user_proxy", code_execution_config=False) + wrapper = TestWrapper(init_args={"foo": "bar"}) + init_args = {"foo": "bar"} + logger.log_new_client(client, wrapper, init_args) + + with open(logger.log_file, "r") as f: + lines = f.readlines() + log_data = json.loads(lines[0]) + assert log_data["client_id"] == id(client) + assert log_data["wrapper_id"] == id(wrapper) + assert log_data["json_state"] == json.dumps(init_args) + assert isinstance(log_data["thread_id"], int) diff --git a/test/agentchat/test_agent_logging.py b/test/agentchat/test_agent_logging.py index 47798fbe0f6d..8375a08444bf 100644 --- a/test/agentchat/test_agent_logging.py +++ b/test/agentchat/test_agent_logging.py @@ -6,14 +6,16 @@ import pytest -import autogen -import autogen.runtime_logging - sys.path.append(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) + + from conftest import skip_openai # noqa: E402 from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402 +import autogen +import autogen.runtime_logging + TEACHER_MESSAGE = """ You are roleplaying a math teacher, and your job is to help your students with linear algebra. Keep your explanations short.