From 06d6b82c10d21ff7d6c82c240708d2193b00f5d3 Mon Sep 17 00:00:00 2001 From: Xiaoyun Zhang Date: Thu, 8 Aug 2024 16:33:14 -0700 Subject: [PATCH] [.Net] Add AutoGen.AzureAIInference (#3332) * add AutoGen.AzureAIInference * add tests * update readme * fix format --- dotnet/AutoGen.sln | 21 + dotnet/eng/Version.props | 1 + .../Agent/ChatCompletionsClientAgent.cs | 202 +++++++ .../AutoGen.AzureAIInference.csproj | 25 + .../ChatComptionClientAgentExtension.cs | 39 ++ .../Extension/FunctionContractExtension.cs | 64 ++ ...eAIInferenceChatRequestMessageConnector.cs | 302 ++++++++++ dotnet/src/AutoGen/AutoGen.csproj | 2 + .../AutoGen.AzureAIInference.Tests.csproj | 16 + .../ChatCompletionClientAgentTests.cs | 533 ++++++++++++++++ .../ChatRequestMessageTests.cs | 568 ++++++++++++++++++ .../EnvironmentSpecificFactAttribute.cs | 31 + .../Attribute/OpenAIFact.cs | 22 + .../AutoGen.Tests.Share.csproj | 15 + dotnet/test/AutoGen.Test.Share/EchoAgent.cs | 37 ++ .../EnvironmentSpecificFactAttribute.cs | 33 - .../AutoGen.Tests/Attribute/OpenAIFact.cs | 26 - .../test/AutoGen.Tests/AutoGen.Tests.csproj | 1 + dotnet/test/AutoGen.Tests/EchoAgent.cs | 41 -- .../Orchestrator/RolePlayOrchestratorTests.cs | 19 + dotnet/website/articles/Agent-overview.md | 1 - dotnet/website/articles/Installation.md | 3 +- 22 files changed, 1900 insertions(+), 102 deletions(-) create mode 100644 dotnet/src/AutoGen.AzureAIInference/Agent/ChatCompletionsClientAgent.cs create mode 100644 dotnet/src/AutoGen.AzureAIInference/AutoGen.AzureAIInference.csproj create mode 100644 dotnet/src/AutoGen.AzureAIInference/Extension/ChatComptionClientAgentExtension.cs create mode 100644 dotnet/src/AutoGen.AzureAIInference/Extension/FunctionContractExtension.cs create mode 100644 dotnet/src/AutoGen.AzureAIInference/Middleware/AzureAIInferenceChatRequestMessageConnector.cs create mode 100644 dotnet/test/AutoGen.AzureAIInference.Tests/AutoGen.AzureAIInference.Tests.csproj create mode 100644 dotnet/test/AutoGen.AzureAIInference.Tests/ChatCompletionClientAgentTests.cs create mode 100644 dotnet/test/AutoGen.AzureAIInference.Tests/ChatRequestMessageTests.cs create mode 100644 dotnet/test/AutoGen.Test.Share/Attribute/EnvironmentSpecificFactAttribute.cs create mode 100644 dotnet/test/AutoGen.Test.Share/Attribute/OpenAIFact.cs create mode 100644 dotnet/test/AutoGen.Test.Share/AutoGen.Tests.Share.csproj create mode 100644 dotnet/test/AutoGen.Test.Share/EchoAgent.cs delete mode 100644 dotnet/test/AutoGen.Tests/Attribute/EnvironmentSpecificFactAttribute.cs delete mode 100644 dotnet/test/AutoGen.Tests/Attribute/OpenAIFact.cs delete mode 100644 dotnet/test/AutoGen.Tests/EchoAgent.cs diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln index 1218cf129821..0fcaf15ceb2a 100644 --- a/dotnet/AutoGen.sln +++ b/dotnet/AutoGen.sln @@ -68,6 +68,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI.Sample", "sa EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.WebAPI.Sample", "sample\AutoGen.WebAPI.Sample\AutoGen.WebAPI.Sample.csproj", "{12079C18-A519-403F-BBFD-200A36A0C083}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.AzureAIInference", "src\AutoGen.AzureAIInference\AutoGen.AzureAIInference.csproj", "{5C45981D-1319-4C25-935C-83D411CB28DF}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.AzureAIInference.Tests", "test\AutoGen.AzureAIInference.Tests\AutoGen.AzureAIInference.Tests.csproj", "{5970868F-831E-418F-89A9-4EC599563E16}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.Tests.Share", "test\AutoGen.Test.Share\AutoGen.Tests.Share.csproj", "{143725E2-206C-4D37-93E4-9EDF699826B2}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -194,6 +200,18 @@ Global {12079C18-A519-403F-BBFD-200A36A0C083}.Debug|Any CPU.Build.0 = Debug|Any CPU {12079C18-A519-403F-BBFD-200A36A0C083}.Release|Any CPU.ActiveCfg = Release|Any CPU {12079C18-A519-403F-BBFD-200A36A0C083}.Release|Any CPU.Build.0 = Release|Any CPU + {5C45981D-1319-4C25-935C-83D411CB28DF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {5C45981D-1319-4C25-935C-83D411CB28DF}.Debug|Any CPU.Build.0 = Debug|Any CPU + {5C45981D-1319-4C25-935C-83D411CB28DF}.Release|Any CPU.ActiveCfg = Release|Any CPU + {5C45981D-1319-4C25-935C-83D411CB28DF}.Release|Any CPU.Build.0 = Release|Any CPU + {5970868F-831E-418F-89A9-4EC599563E16}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {5970868F-831E-418F-89A9-4EC599563E16}.Debug|Any CPU.Build.0 = Debug|Any CPU + {5970868F-831E-418F-89A9-4EC599563E16}.Release|Any CPU.ActiveCfg = Release|Any CPU + {5970868F-831E-418F-89A9-4EC599563E16}.Release|Any CPU.Build.0 = Release|Any CPU + {143725E2-206C-4D37-93E4-9EDF699826B2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {143725E2-206C-4D37-93E4-9EDF699826B2}.Debug|Any CPU.Build.0 = Debug|Any CPU + {143725E2-206C-4D37-93E4-9EDF699826B2}.Release|Any CPU.ActiveCfg = Release|Any CPU + {143725E2-206C-4D37-93E4-9EDF699826B2}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -229,6 +247,9 @@ Global {6B82F26D-5040-4453-B21B-C8D1F913CE4C} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {0E635268-351C-4A6B-A28D-593D868C2CA4} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9} {12079C18-A519-403F-BBFD-200A36A0C083} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9} + {5C45981D-1319-4C25-935C-83D411CB28DF} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} + {5970868F-831E-418F-89A9-4EC599563E16} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} + {143725E2-206C-4D37-93E4-9EDF699826B2} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {93384647-528D-46C8-922C-8DB36A382F0B} diff --git a/dotnet/eng/Version.props b/dotnet/eng/Version.props index c78ce4b415fc..d90e8bc76c80 100644 --- a/dotnet/eng/Version.props +++ b/dotnet/eng/Version.props @@ -15,6 +15,7 @@ 8.0.4 3.0.0 4.3.0.2 + 1.0.0-beta.1 7.4.4 \ No newline at end of file diff --git a/dotnet/src/AutoGen.AzureAIInference/Agent/ChatCompletionsClientAgent.cs b/dotnet/src/AutoGen.AzureAIInference/Agent/ChatCompletionsClientAgent.cs new file mode 100644 index 000000000000..452c5b1c3079 --- /dev/null +++ b/dotnet/src/AutoGen.AzureAIInference/Agent/ChatCompletionsClientAgent.cs @@ -0,0 +1,202 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ChatCompletionsClientAgent.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using AutoGen.AzureAIInference.Extension; +using AutoGen.Core; +using Azure.AI.Inference; + +namespace AutoGen.AzureAIInference; + +/// +/// ChatCompletions client agent. This agent is a thin wrapper around to provide a simple interface for chat completions. +/// supports the following message types: +/// +/// +/// where T is : chat request message. +/// +/// +/// returns the following message types: +/// +/// +/// where T is : chat response message. +/// where T is : streaming chat completions update. +/// +/// +/// +public class ChatCompletionsClientAgent : IStreamingAgent +{ + private readonly ChatCompletionsClient chatCompletionsClient; + private readonly ChatCompletionsOptions options; + private readonly string systemMessage; + + /// + /// Create a new instance of . + /// + /// chat completions client + /// agent name + /// model name. e.g. gpt-turbo-3.5 + /// system message + /// temperature + /// max tokens to generated + /// response format, set it to to enable json mode. + /// seed to use, set it to enable deterministic output + /// functions + public ChatCompletionsClientAgent( + ChatCompletionsClient chatCompletionsClient, + string name, + string modelName, + string systemMessage = "You are a helpful AI assistant", + float temperature = 0.7f, + int maxTokens = 1024, + int? seed = null, + ChatCompletionsResponseFormat? responseFormat = null, + IEnumerable? functions = null) + : this( + chatCompletionsClient: chatCompletionsClient, + name: name, + options: CreateChatCompletionOptions(modelName, temperature, maxTokens, seed, responseFormat, functions), + systemMessage: systemMessage) + { + } + + /// + /// Create a new instance of . + /// + /// chat completions client + /// agent name + /// system message + /// chat completion option. The option can't contain messages + public ChatCompletionsClientAgent( + ChatCompletionsClient chatCompletionsClient, + string name, + ChatCompletionsOptions options, + string systemMessage = "You are a helpful AI assistant") + { + if (options.Messages is { Count: > 0 }) + { + throw new ArgumentException("Messages should not be provided in options"); + } + + this.chatCompletionsClient = chatCompletionsClient; + this.Name = name; + this.options = options; + this.systemMessage = systemMessage; + } + + public string Name { get; } + + public async Task GenerateReplyAsync( + IEnumerable messages, + GenerateReplyOptions? options = null, + CancellationToken cancellationToken = default) + { + var settings = this.CreateChatCompletionsOptions(options, messages); + var reply = await this.chatCompletionsClient.CompleteAsync(settings, cancellationToken: cancellationToken); + + return new MessageEnvelope(reply, from: this.Name); + } + + public async IAsyncEnumerable GenerateStreamingReplyAsync( + IEnumerable messages, + GenerateReplyOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var settings = this.CreateChatCompletionsOptions(options, messages); + var response = await this.chatCompletionsClient.CompleteStreamingAsync(settings, cancellationToken); + await foreach (var update in response.WithCancellation(cancellationToken)) + { + yield return new MessageEnvelope(update, from: this.Name); + } + } + + private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions? options, IEnumerable messages) + { + var oaiMessages = messages.Select(m => m switch + { + IMessage chatRequestMessage => chatRequestMessage.Content, + _ => throw new ArgumentException("Invalid message type") + }); + + // add system message if there's no system message in messages + if (!oaiMessages.Any(m => m is ChatRequestSystemMessage)) + { + oaiMessages = new[] { new ChatRequestSystemMessage(systemMessage) }.Concat(oaiMessages); + } + + // clone the options by serializing and deserializing + var json = JsonSerializer.Serialize(this.options); + var settings = JsonSerializer.Deserialize(json) ?? throw new InvalidOperationException("Failed to clone options"); + + foreach (var m in oaiMessages) + { + settings.Messages.Add(m); + } + + settings.Temperature = options?.Temperature ?? settings.Temperature; + settings.MaxTokens = options?.MaxToken ?? settings.MaxTokens; + + foreach (var functions in this.options.Tools) + { + settings.Tools.Add(functions); + } + + foreach (var stopSequence in this.options.StopSequences) + { + settings.StopSequences.Add(stopSequence); + } + + var openAIFunctionDefinitions = options?.Functions?.Select(f => f.ToAzureAIInferenceFunctionDefinition()).ToList(); + if (openAIFunctionDefinitions is { Count: > 0 }) + { + foreach (var f in openAIFunctionDefinitions) + { + settings.Tools.Add(new ChatCompletionsFunctionToolDefinition(f)); + } + } + + if (options?.StopSequence is var sequence && sequence is { Length: > 0 }) + { + foreach (var seq in sequence) + { + settings.StopSequences.Add(seq); + } + } + + return settings; + } + + private static ChatCompletionsOptions CreateChatCompletionOptions( + string modelName, + float temperature = 0.7f, + int maxTokens = 1024, + int? seed = null, + ChatCompletionsResponseFormat? responseFormat = null, + IEnumerable? functions = null) + { + var options = new ChatCompletionsOptions() + { + Model = modelName, + Temperature = temperature, + MaxTokens = maxTokens, + Seed = seed, + ResponseFormat = responseFormat, + }; + + if (functions is not null) + { + foreach (var f in functions) + { + options.Tools.Add(new ChatCompletionsFunctionToolDefinition(f)); + } + } + + return options; + } +} diff --git a/dotnet/src/AutoGen.AzureAIInference/AutoGen.AzureAIInference.csproj b/dotnet/src/AutoGen.AzureAIInference/AutoGen.AzureAIInference.csproj new file mode 100644 index 000000000000..e9401bc4bc22 --- /dev/null +++ b/dotnet/src/AutoGen.AzureAIInference/AutoGen.AzureAIInference.csproj @@ -0,0 +1,25 @@ + + + $(PackageTargetFrameworks) + AutoGen.AzureAIInference + + + + + + + AutoGen.AzureAIInference + + Azure AI Inference Intergration for AutoGen. + + + + + + + + + + + + diff --git a/dotnet/src/AutoGen.AzureAIInference/Extension/ChatComptionClientAgentExtension.cs b/dotnet/src/AutoGen.AzureAIInference/Extension/ChatComptionClientAgentExtension.cs new file mode 100644 index 000000000000..8faf29604ed1 --- /dev/null +++ b/dotnet/src/AutoGen.AzureAIInference/Extension/ChatComptionClientAgentExtension.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ChatComptionClientAgentExtension.cs + +using AutoGen.Core; + +namespace AutoGen.AzureAIInference.Extension; + +public static class ChatComptionClientAgentExtension +{ + /// + /// Register an to the + /// + /// the connector to use. If null, a new instance of will be created. + public static MiddlewareStreamingAgent RegisterMessageConnector( + this ChatCompletionsClientAgent agent, AzureAIInferenceChatRequestMessageConnector? connector = null) + { + if (connector == null) + { + connector = new AzureAIInferenceChatRequestMessageConnector(); + } + + return agent.RegisterStreamingMiddleware(connector); + } + + /// + /// Register an to the where T is + /// + /// the connector to use. If null, a new instance of will be created. + public static MiddlewareStreamingAgent RegisterMessageConnector( + this MiddlewareStreamingAgent agent, AzureAIInferenceChatRequestMessageConnector? connector = null) + { + if (connector == null) + { + connector = new AzureAIInferenceChatRequestMessageConnector(); + } + + return agent.RegisterStreamingMiddleware(connector); + } +} diff --git a/dotnet/src/AutoGen.AzureAIInference/Extension/FunctionContractExtension.cs b/dotnet/src/AutoGen.AzureAIInference/Extension/FunctionContractExtension.cs new file mode 100644 index 000000000000..4cd7b3864f95 --- /dev/null +++ b/dotnet/src/AutoGen.AzureAIInference/Extension/FunctionContractExtension.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// FunctionContractExtension.cs + +using System; +using System.Collections.Generic; +using AutoGen.Core; +using Azure.AI.Inference; +using Json.Schema; +using Json.Schema.Generation; + +namespace AutoGen.AzureAIInference.Extension; + +public static class FunctionContractExtension +{ + /// + /// Convert a to a that can be used in gpt funciton call. + /// + /// function contract + /// + public static FunctionDefinition ToAzureAIInferenceFunctionDefinition(this FunctionContract functionContract) + { + var functionDefinition = new FunctionDefinition + { + Name = functionContract.Name, + Description = functionContract.Description, + }; + var requiredParameterNames = new List(); + var propertiesSchemas = new Dictionary(); + var propertySchemaBuilder = new JsonSchemaBuilder().Type(SchemaValueType.Object); + foreach (var param in functionContract.Parameters ?? []) + { + if (param.Name is null) + { + throw new InvalidOperationException("Parameter name cannot be null"); + } + + var schemaBuilder = new JsonSchemaBuilder().FromType(param.ParameterType ?? throw new ArgumentNullException(nameof(param.ParameterType))); + if (param.Description != null) + { + schemaBuilder = schemaBuilder.Description(param.Description); + } + + if (param.IsRequired) + { + requiredParameterNames.Add(param.Name); + } + + var schema = schemaBuilder.Build(); + propertiesSchemas[param.Name] = schema; + + } + propertySchemaBuilder = propertySchemaBuilder.Properties(propertiesSchemas); + propertySchemaBuilder = propertySchemaBuilder.Required(requiredParameterNames); + + var option = new System.Text.Json.JsonSerializerOptions() + { + PropertyNamingPolicy = System.Text.Json.JsonNamingPolicy.CamelCase + }; + + functionDefinition.Parameters = BinaryData.FromObjectAsJson(propertySchemaBuilder.Build(), option); + + return functionDefinition; + } +} diff --git a/dotnet/src/AutoGen.AzureAIInference/Middleware/AzureAIInferenceChatRequestMessageConnector.cs b/dotnet/src/AutoGen.AzureAIInference/Middleware/AzureAIInferenceChatRequestMessageConnector.cs new file mode 100644 index 000000000000..9c5d22e2e7e7 --- /dev/null +++ b/dotnet/src/AutoGen.AzureAIInference/Middleware/AzureAIInferenceChatRequestMessageConnector.cs @@ -0,0 +1,302 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// AzureAIInferenceChatRequestMessageConnector.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using AutoGen.Core; +using Azure.AI.Inference; + +namespace AutoGen.AzureAIInference; + +/// +/// This middleware converts the incoming to where T is before sending to agent. And converts the output to after receiving from agent. +/// Supported are +/// - +/// - +/// - +/// - +/// - +/// - where T is +/// - where TMessage1 is and TMessage2 is +/// +public class AzureAIInferenceChatRequestMessageConnector : IStreamingMiddleware +{ + private bool strictMode = false; + + /// + /// Create a new instance of . + /// + /// If true, will throw an + /// When the message type is not supported. If false, it will ignore the unsupported message type. + public AzureAIInferenceChatRequestMessageConnector(bool strictMode = false) + { + this.strictMode = strictMode; + } + + public string? Name => nameof(AzureAIInferenceChatRequestMessageConnector); + + public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) + { + var chatMessages = ProcessIncomingMessages(agent, context.Messages); + + var reply = await agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken); + + return PostProcessMessage(reply); + } + + public async IAsyncEnumerable InvokeAsync( + MiddlewareContext context, + IStreamingAgent agent, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var chatMessages = ProcessIncomingMessages(agent, context.Messages); + var streamingReply = agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken); + string? currentToolName = null; + await foreach (var reply in streamingReply) + { + if (reply is IMessage update) + { + if (update.Content.FunctionName is string functionName) + { + currentToolName = functionName; + } + else if (update.Content.ToolCallUpdate is StreamingFunctionToolCallUpdate toolCallUpdate && toolCallUpdate.Name is string toolCallName) + { + currentToolName = toolCallName; + } + var postProcessMessage = PostProcessStreamingMessage(update, currentToolName); + if (postProcessMessage != null) + { + yield return postProcessMessage; + } + } + else + { + if (this.strictMode) + { + throw new InvalidOperationException($"Invalid streaming message type {reply.GetType().Name}"); + } + else + { + yield return reply; + } + } + } + } + + public IMessage PostProcessMessage(IMessage message) + { + return message switch + { + IMessage m => PostProcessChatResponseMessage(m.Content, m.From), + IMessage m => PostProcessChatCompletions(m), + _ when strictMode is false => message, + _ => throw new InvalidOperationException($"Invalid return message type {message.GetType().Name}"), + }; + } + + public IMessage? PostProcessStreamingMessage(IMessage update, string? currentToolName) + { + if (update.Content.ContentUpdate is string contentUpdate && string.IsNullOrEmpty(contentUpdate) == false) + { + // text message + return new TextMessageUpdate(Role.Assistant, contentUpdate, from: update.From); + } + else if (update.Content.FunctionName is string functionName) + { + return new ToolCallMessageUpdate(functionName, string.Empty, from: update.From); + } + else if (update.Content.FunctionArgumentsUpdate is string functionArgumentsUpdate && currentToolName is string) + { + return new ToolCallMessageUpdate(currentToolName, functionArgumentsUpdate, from: update.From); + } + else if (update.Content.ToolCallUpdate is StreamingFunctionToolCallUpdate tooCallUpdate && currentToolName is string) + { + return new ToolCallMessageUpdate(tooCallUpdate.Name ?? currentToolName, tooCallUpdate.ArgumentsUpdate, from: update.From); + } + else + { + return null; + } + } + + private IMessage PostProcessChatCompletions(IMessage message) + { + // throw exception if prompt filter results is not null + if (message.Content.Choices[0].FinishReason == CompletionsFinishReason.ContentFiltered) + { + throw new InvalidOperationException("The content is filtered because its potential risk. Please try another input."); + } + + return PostProcessChatResponseMessage(message.Content.Choices[0].Message, message.From); + } + + private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponseMessage, string? from) + { + var textContent = chatResponseMessage.Content; + if (chatResponseMessage.ToolCalls.Where(tc => tc is ChatCompletionsFunctionToolCall).Any()) + { + var functionToolCalls = chatResponseMessage.ToolCalls + .Where(tc => tc is ChatCompletionsFunctionToolCall) + .Select(tc => (ChatCompletionsFunctionToolCall)tc); + + var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments) { ToolCallId = tc.Id }); + + return new ToolCallMessage(toolCalls, from) + { + Content = textContent, + }; + } + + if (textContent is string content && !string.IsNullOrEmpty(content)) + { + return new TextMessage(Role.Assistant, content, from); + } + + throw new InvalidOperationException("Invalid ChatResponseMessage"); + } + + public IEnumerable ProcessIncomingMessages(IAgent agent, IEnumerable messages) + { + return messages.SelectMany(m => + { + if (m is IMessage crm) + { + return [crm]; + } + else + { + var chatRequestMessages = m switch + { + TextMessage textMessage => ProcessTextMessage(agent, textMessage), + ImageMessage imageMessage when (imageMessage.From is null || imageMessage.From != agent.Name) => ProcessImageMessage(agent, imageMessage), + MultiModalMessage multiModalMessage when (multiModalMessage.From is null || multiModalMessage.From != agent.Name) => ProcessMultiModalMessage(agent, multiModalMessage), + ToolCallMessage toolCallMessage when (toolCallMessage.From is null || toolCallMessage.From == agent.Name) => ProcessToolCallMessage(agent, toolCallMessage), + ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage), + AggregateMessage aggregateMessage => ProcessFunctionCallMiddlewareMessage(agent, aggregateMessage), + _ when strictMode is false => [], + _ => throw new InvalidOperationException($"Invalid message type: {m.GetType().Name}"), + }; + + if (chatRequestMessages.Any()) + { + return chatRequestMessages.Select(cm => MessageEnvelope.Create(cm, m.From)); + } + else + { + return [m]; + } + } + }); + } + + private IEnumerable ProcessTextMessage(IAgent agent, TextMessage message) + { + if (message.Role == Role.System) + { + return [new ChatRequestSystemMessage(message.Content)]; + } + + if (agent.Name == message.From) + { + return [new ChatRequestAssistantMessage { Content = message.Content }]; + } + else + { + return message.From switch + { + null when message.Role == Role.User => [new ChatRequestUserMessage(message.Content)], + null when message.Role == Role.Assistant => [new ChatRequestAssistantMessage() { Content = message.Content }], + null => throw new InvalidOperationException("Invalid Role"), + _ => [new ChatRequestUserMessage(message.Content)] + }; + } + } + + private IEnumerable ProcessImageMessage(IAgent agent, ImageMessage message) + { + if (agent.Name == message.From) + { + // image message from assistant is not supported + throw new ArgumentException("ImageMessage is not supported when message.From is the same with agent"); + } + + var imageContentItem = this.CreateChatMessageImageContentItemFromImageMessage(message); + return [new ChatRequestUserMessage([imageContentItem])]; + } + + private IEnumerable ProcessMultiModalMessage(IAgent agent, MultiModalMessage message) + { + if (agent.Name == message.From) + { + // image message from assistant is not supported + throw new ArgumentException("MultiModalMessage is not supported when message.From is the same with agent"); + } + + IEnumerable items = message.Content.Select(ci => ci switch + { + TextMessage text => new ChatMessageTextContentItem(text.Content), + ImageMessage image => this.CreateChatMessageImageContentItemFromImageMessage(image), + _ => throw new NotImplementedException(), + }); + + return [new ChatRequestUserMessage(items)]; + } + + private ChatMessageImageContentItem CreateChatMessageImageContentItemFromImageMessage(ImageMessage message) + { + return message.Data is null && message.Url is not null + ? new ChatMessageImageContentItem(new Uri(message.Url)) + : new ChatMessageImageContentItem(message.Data, message.Data?.MediaType); + } + + private IEnumerable ProcessToolCallMessage(IAgent agent, ToolCallMessage message) + { + if (message.From is not null && message.From != agent.Name) + { + throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent"); + } + + var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments)); + var textContent = message.GetContent() ?? string.Empty; + var chatRequestMessage = new ChatRequestAssistantMessage() { Content = textContent }; + foreach (var tc in toolCall) + { + chatRequestMessage.ToolCalls.Add(tc); + } + + return [chatRequestMessage]; + } + + private IEnumerable ProcessToolCallResultMessage(ToolCallResultMessage message) + { + return message.ToolCalls + .Where(tc => tc.Result is not null) + .Select((tc, i) => new ChatRequestToolMessage(tc.Result, tc.ToolCallId ?? $"{tc.FunctionName}_{i}")); + } + + private IEnumerable ProcessFunctionCallMiddlewareMessage(IAgent agent, AggregateMessage aggregateMessage) + { + if (aggregateMessage.From is not null && aggregateMessage.From != agent.Name) + { + // convert as user message + var resultMessage = aggregateMessage.Message2; + + return resultMessage.ToolCalls.Select(tc => new ChatRequestUserMessage(tc.Result)); + } + else + { + var toolCallMessage1 = aggregateMessage.Message1; + var toolCallResultMessage = aggregateMessage.Message2; + + var assistantMessage = this.ProcessToolCallMessage(agent, toolCallMessage1); + var toolCallResults = this.ProcessToolCallResultMessage(toolCallResultMessage); + + return assistantMessage.Concat(toolCallResults); + } + } +} diff --git a/dotnet/src/AutoGen/AutoGen.csproj b/dotnet/src/AutoGen/AutoGen.csproj index 3cb5a23da14c..88d9fca19ca2 100644 --- a/dotnet/src/AutoGen/AutoGen.csproj +++ b/dotnet/src/AutoGen/AutoGen.csproj @@ -15,6 +15,8 @@ + + diff --git a/dotnet/test/AutoGen.AzureAIInference.Tests/AutoGen.AzureAIInference.Tests.csproj b/dotnet/test/AutoGen.AzureAIInference.Tests/AutoGen.AzureAIInference.Tests.csproj new file mode 100644 index 000000000000..0eaebd1da0cb --- /dev/null +++ b/dotnet/test/AutoGen.AzureAIInference.Tests/AutoGen.AzureAIInference.Tests.csproj @@ -0,0 +1,16 @@ + + + + $(TestTargetFrameworks) + false + True + True + + + + + + + + + diff --git a/dotnet/test/AutoGen.AzureAIInference.Tests/ChatCompletionClientAgentTests.cs b/dotnet/test/AutoGen.AzureAIInference.Tests/ChatCompletionClientAgentTests.cs new file mode 100644 index 000000000000..d81b8881ac55 --- /dev/null +++ b/dotnet/test/AutoGen.AzureAIInference.Tests/ChatCompletionClientAgentTests.cs @@ -0,0 +1,533 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ChatCompletionClientAgentTests.cs + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using AutoGen.AzureAIInference.Extension; +using AutoGen.Core; +using AutoGen.Tests; +using Azure.AI.Inference; +using FluentAssertions; +using Xunit; + +namespace AutoGen.AzureAIInference.Tests; + +public partial class ChatCompletionClientAgentTests +{ + /// + /// Get the weather for a location. + /// + /// location + /// + [Function] + public async Task GetWeatherAsync(string location) + { + return $"The weather in {location} is sunny."; + } + + [ApiKeyFact("GH_API_KEY")] + public async Task ChatCompletionAgent_LLaMA3_1() + { + var client = CreateChatCompletionClient(); + var model = "meta-llama-3-8b-instruct"; + + var agent = new ChatCompletionsClientAgent(client, "assistant", model) + .RegisterMessageConnector(); + + var reply = await this.BasicChatAsync(agent); + reply.Should().BeOfType(); + + reply = await this.BasicChatWithContinuousMessageFromSameSenderAsync(agent); + reply.Should().BeOfType(); + } + + [ApiKeyFact("GH_API_KEY")] + public async Task BasicConversation_Mistra_Small() + { + var deployName = "Mistral-small"; + var client = CreateChatCompletionClient(); + var openAIChatAgent = new ChatCompletionsClientAgent( + chatCompletionsClient: client, + name: "assistant", + modelName: deployName); + + // By default, ChatCompletionClientAgent supports the following message types + // - IMessage + var chatMessageContent = MessageEnvelope.Create(new ChatRequestUserMessage("Hello")); + var reply = await openAIChatAgent.SendAsync(chatMessageContent); + + reply.Should().BeOfType>(); + reply.As>().From.Should().Be("assistant"); + reply.As>().Content.Choices.First().Message.Role.Should().Be(ChatRole.Assistant); + reply.As>().Content.Usage.TotalTokens.Should().BeGreaterThan(0); + + // test streaming + var streamingReply = openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); + + await foreach (var streamingMessage in streamingReply) + { + streamingMessage.Should().BeOfType>(); + streamingMessage.As>().From.Should().Be("assistant"); + } + } + + [ApiKeyFact("GH_API_KEY")] + public async Task ChatCompletionsMessageContentConnector_Phi3_Mini() + { + var deployName = "Phi-3-mini-4k-instruct"; + var openaiClient = CreateChatCompletionClient(); + var chatCompletionAgent = new ChatCompletionsClientAgent( + chatCompletionsClient: openaiClient, + name: "assistant", + modelName: deployName); + + MiddlewareStreamingAgent assistant = chatCompletionAgent + .RegisterMessageConnector(); + + var messages = new IMessage[] + { + MessageEnvelope.Create(new ChatRequestUserMessage("Hello")), + new TextMessage(Role.Assistant, "Hello", from: "user"), + new MultiModalMessage(Role.Assistant, + [ + new TextMessage(Role.Assistant, "Hello", from: "user"), + ], + from: "user"), + }; + + foreach (var message in messages) + { + var reply = await assistant.SendAsync(message); + + reply.Should().BeOfType(); + reply.As().From.Should().Be("assistant"); + } + + // test streaming + foreach (var message in messages) + { + var reply = assistant.GenerateStreamingReplyAsync([message]); + + await foreach (var streamingMessage in reply) + { + streamingMessage.Should().BeOfType(); + streamingMessage.As().From.Should().Be("assistant"); + } + } + } + + [ApiKeyFact("GH_API_KEY")] + public async Task ChatCompletionClientAgentToolCall_Mistral_Nemo() + { + var deployName = "Mistral-nemo"; + var chatCompletionClient = CreateChatCompletionClient(); + var agent = new ChatCompletionsClientAgent( + chatCompletionsClient: chatCompletionClient, + name: "assistant", + modelName: deployName); + + var functionCallMiddleware = new FunctionCallMiddleware( + functions: [this.GetWeatherAsyncFunctionContract]); + MiddlewareStreamingAgent assistant = agent + .RegisterMessageConnector(); + + assistant.StreamingMiddlewares.Count().Should().Be(1); + var functionCallAgent = assistant + .RegisterStreamingMiddleware(functionCallMiddleware); + + var question = "What's the weather in Seattle"; + var messages = new IMessage[] + { + MessageEnvelope.Create(new ChatRequestUserMessage(question)), + new TextMessage(Role.Assistant, question, from: "user"), + new MultiModalMessage(Role.Assistant, + [ + new TextMessage(Role.Assistant, question, from: "user"), + ], + from: "user"), + }; + + foreach (var message in messages) + { + var reply = await functionCallAgent.SendAsync(message); + + reply.Should().BeOfType(); + reply.As().From.Should().Be("assistant"); + reply.As().ToolCalls.Count().Should().Be(1); + reply.As().ToolCalls.First().FunctionName.Should().Be(this.GetWeatherAsyncFunctionContract.Name); + } + + // test streaming + foreach (var message in messages) + { + var reply = functionCallAgent.GenerateStreamingReplyAsync([message]); + ToolCallMessage? toolCallMessage = null; + await foreach (var streamingMessage in reply) + { + streamingMessage.Should().BeOfType(); + streamingMessage.As().From.Should().Be("assistant"); + if (toolCallMessage is null) + { + toolCallMessage = new ToolCallMessage(streamingMessage.As()); + } + else + { + toolCallMessage.Update(streamingMessage.As()); + } + } + + toolCallMessage.Should().NotBeNull(); + toolCallMessage!.From.Should().Be("assistant"); + toolCallMessage.ToolCalls.Count().Should().Be(1); + toolCallMessage.ToolCalls.First().FunctionName.Should().Be(this.GetWeatherAsyncFunctionContract.Name); + } + } + + [ApiKeyFact("GH_API_KEY")] + public async Task ChatCompletionClientAgentToolCallInvoking_gpt_4o_mini() + { + var deployName = "gpt-4o-mini"; + var client = CreateChatCompletionClient(); + var agent = new ChatCompletionsClientAgent( + chatCompletionsClient: client, + name: "assistant", + modelName: deployName); + + var functionCallMiddleware = new FunctionCallMiddleware( + functions: [this.GetWeatherAsyncFunctionContract], + functionMap: new Dictionary>> { { this.GetWeatherAsyncFunctionContract.Name!, this.GetWeatherAsyncWrapper } }); + MiddlewareStreamingAgent assistant = agent + .RegisterMessageConnector(); + + var functionCallAgent = assistant + .RegisterStreamingMiddleware(functionCallMiddleware); + + var question = "What's the weather in Seattle"; + var messages = new IMessage[] + { + MessageEnvelope.Create(new ChatRequestUserMessage(question)), + new TextMessage(Role.Assistant, question, from: "user"), + new MultiModalMessage(Role.Assistant, + [ + new TextMessage(Role.Assistant, question, from: "user"), + ], + from: "user"), + }; + + foreach (var message in messages) + { + var reply = await functionCallAgent.SendAsync(message); + + reply.Should().BeOfType(); + reply.From.Should().Be("assistant"); + reply.GetToolCalls()!.Count().Should().Be(1); + reply.GetToolCalls()!.First().FunctionName.Should().Be(this.GetWeatherAsyncFunctionContract.Name); + reply.GetContent()!.ToLower().Should().Contain("seattle"); + } + + // test streaming + foreach (var message in messages) + { + var reply = functionCallAgent.GenerateStreamingReplyAsync([message]); + await foreach (var streamingMessage in reply) + { + if (streamingMessage is not IMessage) + { + streamingMessage.Should().BeOfType(); + streamingMessage.As().From.Should().Be("assistant"); + } + else + { + streamingMessage.Should().BeOfType(); + streamingMessage.As().GetContent()!.ToLower().Should().Contain("seattle"); + } + } + } + } + + [ApiKeyFact("GH_API_KEY")] + public async Task ItCreateChatCompletionClientAgentWithChatCompletionOption_AI21_Jamba_Instruct() + { + var deployName = "AI21-Jamba-Instruct"; + var chatCompletionsClient = CreateChatCompletionClient(); + var options = new ChatCompletionsOptions() + { + Model = deployName, + Temperature = 0.7f, + MaxTokens = 1, + }; + + var openAIChatAgent = new ChatCompletionsClientAgent( + chatCompletionsClient: chatCompletionsClient, + name: "assistant", + options: options) + .RegisterMessageConnector(); + + var respond = await openAIChatAgent.SendAsync("hello"); + respond.GetContent()?.Should().NotBeNullOrEmpty(); + } + + [Fact] + public async Task ItThrowExceptionWhenChatCompletionOptionContainsMessages() + { + var client = new ChatCompletionsClient(new Uri("https://dummy.com"), new Azure.AzureKeyCredential("dummy")); + var options = new ChatCompletionsOptions([new ChatRequestUserMessage("hi")]) + { + Model = "dummy", + Temperature = 0.7f, + MaxTokens = 1, + }; + + var action = () => new ChatCompletionsClientAgent( + chatCompletionsClient: client, + name: "assistant", + options: options) + .RegisterMessageConnector(); + + action.Should().ThrowExactly().WithMessage("Messages should not be provided in options"); + } + + private ChatCompletionsClient CreateChatCompletionClient() + { + var apiKey = Environment.GetEnvironmentVariable("GH_API_KEY") ?? throw new Exception("Please set GH_API_KEY environment variable."); + var endpoint = "https://models.inference.ai.azure.com"; + return new ChatCompletionsClient(new Uri(endpoint), new Azure.AzureKeyCredential(apiKey)); + } + + /// + /// The agent should return a text message based on the chat history. + /// + /// + /// + private async Task BasicChatEndWithSelfMessageAsync(IAgent agent) + { + IMessage[] chatHistory = [ + new TextMessage(Role.Assistant, "Hello", from: "user"), + new TextMessage(Role.Assistant, "Hello", from: "user2"), + new TextMessage(Role.Assistant, "Hello", from: "user3"), + new TextMessage(Role.Assistant, "Hello", from: agent.Name), + ]; + + return await agent.GenerateReplyAsync(chatHistory); + } + + /// + /// The agent should return a text message based on the chat history. + /// + /// + /// + private async Task BasicChatAsync(IAgent agent) + { + IMessage[] chatHistory = [ + new TextMessage(Role.Assistant, "Hello", from: agent.Name), + new TextMessage(Role.Assistant, "Hello", from: "user"), + new TextMessage(Role.Assistant, "Hello", from: "user1"), + ]; + + return await agent.GenerateReplyAsync(chatHistory); + } + + /// + /// The agent should return a text message based on the chat history. This test the generate reply with continuous message from the same sender. + /// + private async Task BasicChatWithContinuousMessageFromSameSenderAsync(IAgent agent) + { + IMessage[] chatHistory = [ + new TextMessage(Role.Assistant, "Hello", from: "user"), + new TextMessage(Role.Assistant, "Hello", from: "user"), + new TextMessage(Role.Assistant, "Hello", from: agent.Name), + new TextMessage(Role.Assistant, "Hello", from: agent.Name), + ]; + + return await agent.GenerateReplyAsync(chatHistory); + } + + /// + /// The agent should return a text message based on the chat history. + /// + /// + /// + private async Task ImageChatAsync(IAgent agent) + { + var image = Path.Join("testData", "images", "square.png"); + var binaryData = File.ReadAllBytes(image); + var imageMessage = new ImageMessage(Role.Assistant, BinaryData.FromBytes(binaryData, "image/png"), from: "user"); + + IMessage[] chatHistory = [ + imageMessage, + new TextMessage(Role.Assistant, "What's in the picture", from: "user"), + ]; + + return await agent.GenerateReplyAsync(chatHistory); + } + + /// + /// The agent should return a text message based on the chat history. This test the generate reply with continuous image messages. + /// + /// + /// + private async Task MultipleImageChatAsync(IAgent agent) + { + var image1 = Path.Join("testData", "images", "square.png"); + var image2 = Path.Join("testData", "images", "background.png"); + var binaryData1 = File.ReadAllBytes(image1); + var binaryData2 = File.ReadAllBytes(image2); + var imageMessage1 = new ImageMessage(Role.Assistant, BinaryData.FromBytes(binaryData1, "image/png"), from: "user"); + var imageMessage2 = new ImageMessage(Role.Assistant, BinaryData.FromBytes(binaryData2, "image/png"), from: "user"); + + IMessage[] chatHistory = [ + imageMessage1, + imageMessage2, + new TextMessage(Role.Assistant, "What's in the picture", from: "user"), + ]; + + return await agent.GenerateReplyAsync(chatHistory); + } + + /// + /// The agent should return a text message based on the chat history. + /// + /// + /// + private async Task MultiModalChatAsync(IAgent agent) + { + var image = Path.Join("testData", "images", "square.png"); + var binaryData = File.ReadAllBytes(image); + var question = "What's in the picture"; + var imageMessage = new ImageMessage(Role.Assistant, BinaryData.FromBytes(binaryData, "image/png"), from: "user"); + var textMessage = new TextMessage(Role.Assistant, question, from: "user"); + + IMessage[] chatHistory = [ + new MultiModalMessage(Role.Assistant, [imageMessage, textMessage], from: "user"), + ]; + + return await agent.GenerateReplyAsync(chatHistory); + } + + /// + /// The agent should return a tool call message based on the chat history. + /// + /// + /// + private async Task ToolCallChatAsync(IAgent agent) + { + var question = "What's the weather in Seattle"; + var messages = new IMessage[] + { + new TextMessage(Role.Assistant, question, from: "user"), + }; + + return await agent.GenerateReplyAsync(messages); + } + + /// + /// The agent should throw an exception because tool call result is not available. + /// + private async Task ToolCallFromSelfChatAsync(IAgent agent) + { + var question = "What's the weather in Seattle"; + var messages = new IMessage[] + { + new TextMessage(Role.Assistant, question, from: "user"), + new ToolCallMessage("GetWeatherAsync", "Seattle", from: agent.Name), + }; + + return await agent.GenerateReplyAsync(messages); + } + + /// + /// mimic the further chat after tool call. The agent should return a text message based on the tool call result. + /// + private async Task ToolCallWithResultChatAsync(IAgent agent) + { + var question = "What's the weather in Seattle"; + var messages = new IMessage[] + { + new TextMessage(Role.Assistant, question, from: "user"), + new ToolCallMessage("GetWeatherAsync", "Seattle", from: "user"), + new ToolCallResultMessage("sunny", "GetWeatherAsync", "Seattle", from: agent.Name), + }; + + return await agent.GenerateReplyAsync(messages); + } + + /// + /// the agent should return a text message based on the tool call result. + /// + /// + /// + private async Task AggregateToolCallFromSelfChatAsync(IAgent agent) + { + var textMessage = new TextMessage(Role.Assistant, "What's the weather in Seattle", from: "user"); + var toolCallMessage = new ToolCallMessage("GetWeatherAsync", "Seattle", from: agent.Name); + var toolCallResultMessage = new ToolCallResultMessage("sunny", "GetWeatherAsync", "Seattle", from: agent.Name); + var aggregateToolCallMessage = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage, from: agent.Name); + + var messages = new IMessage[] + { + textMessage, + aggregateToolCallMessage, + }; + + return await agent.GenerateReplyAsync(messages); + } + + /// + /// the agent should return a text message based on the tool call result. Because the aggregate tool call message is from other, the message would be treated as an ordinary text message. + /// + private async Task AggregateToolCallFromOtherChatWithContinuousMessageAsync(IAgent agent) + { + var textMessage = new TextMessage(Role.Assistant, "What's the weather in Seattle", from: "user"); + var toolCallMessage = new ToolCallMessage("GetWeatherAsync", "Seattle", from: "other"); + var toolCallResultMessage = new ToolCallResultMessage("sunny", "GetWeatherAsync", "Seattle", from: "other"); + var aggregateToolCallMessage = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage, "other"); + + var messages = new IMessage[] + { + textMessage, + aggregateToolCallMessage, + }; + + return await agent.GenerateReplyAsync(messages); + } + + /// + /// The agent should throw an exception because tool call message from other is not allowed. + /// + private async Task ToolCallMessaageFromOtherChatAsync(IAgent agent) + { + var textMessage = new TextMessage(Role.Assistant, "What's the weather in Seattle", from: "user"); + var toolCallMessage = new ToolCallMessage("GetWeatherAsync", "Seattle", from: "other"); + + var messages = new IMessage[] + { + textMessage, + toolCallMessage, + }; + + return await agent.GenerateReplyAsync(messages); + } + + /// + /// The agent should throw an exception because multi-modal message from self is not allowed. + /// + /// + /// + private async Task MultiModalMessageFromSelfChatAsync(IAgent agent) + { + var image = Path.Join("testData", "images", "square.png"); + var binaryData = File.ReadAllBytes(image); + var question = "What's in the picture"; + var imageMessage = new ImageMessage(Role.Assistant, BinaryData.FromBytes(binaryData, "image/png"), from: agent.Name); + var textMessage = new TextMessage(Role.Assistant, question, from: agent.Name); + + IMessage[] chatHistory = [ + new MultiModalMessage(Role.Assistant, [imageMessage, textMessage], from: agent.Name), + ]; + + return await agent.GenerateReplyAsync(chatHistory); + } +} diff --git a/dotnet/test/AutoGen.AzureAIInference.Tests/ChatRequestMessageTests.cs b/dotnet/test/AutoGen.AzureAIInference.Tests/ChatRequestMessageTests.cs new file mode 100644 index 000000000000..d6e5c5283932 --- /dev/null +++ b/dotnet/test/AutoGen.AzureAIInference.Tests/ChatRequestMessageTests.cs @@ -0,0 +1,568 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ChatRequestMessageTests.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text.Json; +using System.Threading.Tasks; +using AutoGen.Core; +using AutoGen.Tests; +using Azure.AI.Inference; +using FluentAssertions; +using Xunit; + +namespace AutoGen.AzureAIInference.Tests; + +public class ChatRequestMessageTests +{ + private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions + { + WriteIndented = true, + IgnoreReadOnlyProperties = false, + }; + + [Fact] + public async Task ItProcessUserTextMessageAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("Hello"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + IMessage message = new TextMessage(Role.User, "Hello", "user"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItShortcutChatRequestMessageAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + + var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("hello"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + var userMessage = new ChatRequestUserMessage("hello"); + var chatRequestMessage = MessageEnvelope.Create(userMessage); + await agent.GenerateReplyAsync([chatRequestMessage]); + } + + [Fact] + public async Task ItShortcutMessageWhenStrictModelIsFalseAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + + var chatRequestMessage = ((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Should().Be("hello"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + var userMessage = "hello"; + var chatRequestMessage = MessageEnvelope.Create(userMessage); + await agent.GenerateReplyAsync([chatRequestMessage]); + } + + [Fact] + public async Task ItThrowExceptionWhenStrictModeIsTrueAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(true); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + // user message + var userMessage = "hello"; + var chatRequestMessage = MessageEnvelope.Create(userMessage); + Func action = async () => await agent.GenerateReplyAsync([chatRequestMessage]); + + await action.Should().ThrowAsync().WithMessage("Invalid message type: MessageEnvelope`1"); + } + + [Fact] + public async Task ItProcessAssistantTextMessageAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("How can I help you?"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // assistant message + IMessage message = new TextMessage(Role.Assistant, "How can I help you?", "assistant"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItProcessSystemTextMessageAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestSystemMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("You are a helpful AI assistant"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // system message + IMessage message = new TextMessage(Role.System, "You are a helpful AI assistant"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItProcessImageMessageAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().BeNullOrEmpty(); + chatRequestMessage.MultimodalContentItems.Count().Should().Be(1); + chatRequestMessage.MultimodalContentItems.First().Should().BeOfType(); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + IMessage message = new ImageMessage(Role.User, "https://example.com/image.png", "user"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItThrowExceptionWhenProcessingImageMessageFromSelfAndStrictModeIsTrueAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(true); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + var imageMessage = new ImageMessage(Role.Assistant, "https://example.com/image.png", "assistant"); + Func action = async () => await agent.GenerateReplyAsync([imageMessage]); + + await action.Should().ThrowAsync().WithMessage("Invalid message type: ImageMessage"); + } + + [Fact] + public async Task ItProcessMultiModalMessageAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().BeNullOrEmpty(); + chatRequestMessage.MultimodalContentItems.Count().Should().Be(2); + chatRequestMessage.MultimodalContentItems.First().Should().BeOfType(); + chatRequestMessage.MultimodalContentItems.Last().Should().BeOfType(); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + IMessage message = new MultiModalMessage( + Role.User, + [ + new TextMessage(Role.User, "Hello", "user"), + new ImageMessage(Role.User, "https://example.com/image.png", "user"), + ], "user"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItThrowExceptionWhenProcessingMultiModalMessageFromSelfAndStrictModeIsTrueAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(true); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + var multiModalMessage = new MultiModalMessage( + Role.Assistant, + [ + new TextMessage(Role.User, "Hello", "assistant"), + new ImageMessage(Role.User, "https://example.com/image.png", "assistant"), + ], "assistant"); + + Func action = async () => await agent.GenerateReplyAsync([multiModalMessage]); + + await action.Should().ThrowAsync().WithMessage("Invalid message type: MultiModalMessage"); + } + + [Fact] + public async Task ItProcessToolCallMessageAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.ToolCalls.Count().Should().Be(1); + chatRequestMessage.Content.Should().Be("textContent"); + chatRequestMessage.ToolCalls.First().Should().BeOfType(); + var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.First(); + functionToolCall.Name.Should().Be("test"); + functionToolCall.Id.Should().Be("test"); + functionToolCall.Arguments.Should().Be("test"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + IMessage message = new ToolCallMessage("test", "test", "assistant") + { + Content = "textContent", + }; + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItProcessParallelToolCallMessageAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().BeNullOrEmpty(); + chatRequestMessage.ToolCalls.Count().Should().Be(2); + for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++) + { + chatRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType(); + var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.ElementAt(i); + functionToolCall.Name.Should().Be("test"); + functionToolCall.Id.Should().Be($"test_{i}"); + functionToolCall.Arguments.Should().Be("test"); + } + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + var toolCalls = new[] + { + new ToolCall("test", "test"), + new ToolCall("test", "test"), + }; + IMessage message = new ToolCallMessage(toolCalls, "assistant"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItThrowExceptionWhenProcessingToolCallMessageFromUserAndStrictModeIsTrueAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(strictMode: true); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + var toolCallMessage = new ToolCallMessage("test", "test", "user"); + Func action = async () => await agent.GenerateReplyAsync([toolCallMessage]); + await action.Should().ThrowAsync().WithMessage("Invalid message type: ToolCallMessage"); + } + + [Fact] + public async Task ItProcessToolCallResultMessageAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("result"); + chatRequestMessage.ToolCallId.Should().Be("test"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + IMessage message = new ToolCallResultMessage("result", "test", "test", "user"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItProcessParallelToolCallResultMessageAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + msgs.Count().Should().Be(2); + + for (int i = 0; i < msgs.Count(); i++) + { + var innerMessage = msgs.ElementAt(i); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("result"); + chatRequestMessage.ToolCallId.Should().Be($"test_{i}"); + } + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + var toolCalls = new[] + { + new ToolCall("test", "test", "result"), + new ToolCall("test", "test", "result"), + }; + IMessage message = new ToolCallResultMessage(toolCalls, "user"); + await agent.GenerateReplyAsync([message]); + } + + [Fact] + public async Task ItProcessFunctionCallMiddlewareMessageFromUserAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + msgs.Count().Should().Be(1); + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestUserMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("result"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + var toolCallMessage = new ToolCallMessage("test", "test", "user"); + var toolCallResultMessage = new ToolCallResultMessage("result", "test", "test", "user"); + var aggregateMessage = new AggregateMessage(toolCallMessage, toolCallResultMessage, "user"); + await agent.GenerateReplyAsync([aggregateMessage]); + } + + [Fact] + public async Task ItProcessFunctionCallMiddlewareMessageFromAssistantAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + msgs.Count().Should().Be(2); + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("result"); + chatRequestMessage.ToolCallId.Should().Be("test"); + + var toolCallMessage = msgs.First(); + toolCallMessage!.Should().BeOfType>(); + var toolCallRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope)toolCallMessage!).Content; + toolCallRequestMessage.Content.Should().BeNullOrEmpty(); + toolCallRequestMessage.ToolCalls.Count().Should().Be(1); + toolCallRequestMessage.ToolCalls.First().Should().BeOfType(); + var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.First(); + functionToolCall.Name.Should().Be("test"); + functionToolCall.Id.Should().Be("test"); + functionToolCall.Arguments.Should().Be("test"); + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + var toolCallMessage = new ToolCallMessage("test", "test", "assistant"); + var toolCallResultMessage = new ToolCallResultMessage("result", "test", "test", "assistant"); + var aggregateMessage = new ToolCallAggregateMessage(toolCallMessage, toolCallResultMessage, "assistant"); + await agent.GenerateReplyAsync([aggregateMessage]); + } + + [Fact] + public async Task ItProcessParallelFunctionCallMiddlewareMessageFromAssistantAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + msgs.Count().Should().Be(3); + var toolCallMessage = msgs.First(); + toolCallMessage!.Should().BeOfType>(); + var toolCallRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope)toolCallMessage!).Content; + toolCallRequestMessage.Content.Should().BeNullOrEmpty(); + toolCallRequestMessage.ToolCalls.Count().Should().Be(2); + + for (int i = 0; i < toolCallRequestMessage.ToolCalls.Count(); i++) + { + toolCallRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType(); + var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.ElementAt(i); + functionToolCall.Name.Should().Be("test"); + functionToolCall.Id.Should().Be($"test_{i}"); + functionToolCall.Arguments.Should().Be("test"); + } + + for (int i = 1; i < msgs.Count(); i++) + { + var toolCallResultMessage = msgs.ElementAt(i); + toolCallResultMessage!.Should().BeOfType>(); + var toolCallResultRequestMessage = (ChatRequestToolMessage)((MessageEnvelope)toolCallResultMessage!).Content; + toolCallResultRequestMessage.Content.Should().Be("result"); + toolCallResultRequestMessage.ToolCallId.Should().Be($"test_{i - 1}"); + } + + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + var toolCalls = new[] + { + new ToolCall("test", "test", "result"), + new ToolCall("test", "test", "result"), + }; + var toolCallMessage = new ToolCallMessage(toolCalls, "assistant"); + var toolCallResultMessage = new ToolCallResultMessage(toolCalls, "assistant"); + var aggregateMessage = new AggregateMessage(toolCallMessage, toolCallResultMessage, "assistant"); + await agent.GenerateReplyAsync([aggregateMessage]); + } + + [Fact] + public async Task ItConvertChatResponseMessageToTextMessageAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + // text message + var textMessage = CreateInstance(ChatRole.Assistant, "hello"); + var chatRequestMessage = MessageEnvelope.Create(textMessage); + + var message = await agent.GenerateReplyAsync([chatRequestMessage]); + message.Should().BeOfType(); + message.GetContent().Should().Be("hello"); + message.GetRole().Should().Be(Role.Assistant); + } + + [Fact] + public async Task ItConvertChatResponseMessageToToolCallMessageAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + // tool call message + var toolCallMessage = CreateInstance(ChatRole.Assistant, "textContent", new[] { new ChatCompletionsFunctionToolCall("test", "test", "test") }, new Dictionary()); + var chatRequestMessage = MessageEnvelope.Create(toolCallMessage); + var message = await agent.GenerateReplyAsync([chatRequestMessage]); + message.Should().BeOfType(); + message.GetToolCalls()!.Count().Should().Be(1); + message.GetToolCalls()!.First().FunctionName.Should().Be("test"); + message.GetToolCalls()!.First().FunctionArguments.Should().Be("test"); + message.GetContent().Should().Be("textContent"); + } + + [Fact] + public async Task ItReturnOriginalMessageWhenStrictModeIsFalseAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + // text message + var textMessage = "hello"; + var messageToSend = MessageEnvelope.Create(textMessage); + + var message = await agent.GenerateReplyAsync([messageToSend]); + message.Should().BeOfType>(); + } + + [Fact] + public async Task ItThrowInvalidOperationExceptionWhenStrictModeIsTrueAsync() + { + var middleware = new AzureAIInferenceChatRequestMessageConnector(true); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(middleware); + + // text message + var textMessage = new ChatRequestUserMessage("hello"); + var messageToSend = MessageEnvelope.Create(textMessage); + Func action = async () => await agent.GenerateReplyAsync([messageToSend]); + + await action.Should().ThrowAsync().WithMessage("Invalid return message type MessageEnvelope`1"); + } + + [Fact] + public void ToOpenAIChatRequestMessageShortCircuitTest() + { + var agent = new EchoAgent("assistant"); + var middleware = new AzureAIInferenceChatRequestMessageConnector(); + ChatRequestMessage[] messages = + [ + new ChatRequestUserMessage("Hello"), + new ChatRequestAssistantMessage() + { + Content = "How can I help you?", + }, + new ChatRequestSystemMessage("You are a helpful AI assistant"), + new ChatRequestToolMessage("test", "test"), + ]; + + foreach (var oaiMessage in messages) + { + IMessage message = new MessageEnvelope(oaiMessage); + var oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); + oaiMessages.Count().Should().Be(1); + //oaiMessages.First().Should().BeOfType>(); + if (oaiMessages.First() is IMessage chatRequestMessage) + { + chatRequestMessage.Content.Should().Be(oaiMessage); + } + else + { + // fail the test + Assert.True(false); + } + } + } + + private static T CreateInstance(params object[] args) + { + var type = typeof(T); + var instance = type.Assembly.CreateInstance( + type.FullName!, false, + BindingFlags.Instance | BindingFlags.NonPublic, + null, args, null, null); + return (T)instance!; + } +} diff --git a/dotnet/test/AutoGen.Test.Share/Attribute/EnvironmentSpecificFactAttribute.cs b/dotnet/test/AutoGen.Test.Share/Attribute/EnvironmentSpecificFactAttribute.cs new file mode 100644 index 000000000000..1361531cc9ed --- /dev/null +++ b/dotnet/test/AutoGen.Test.Share/Attribute/EnvironmentSpecificFactAttribute.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// EnvironmentSpecificFactAttribute.cs + +using Xunit; + +namespace AutoGen.Tests; + +/// +/// A base class for environment-specific fact attributes. +/// +[AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = true)] +public abstract class EnvironmentSpecificFactAttribute : FactAttribute +{ + private readonly string _skipMessage; + + /// + /// Creates a new instance of the class. + /// + /// The message to be used when skipping the test marked with this attribute. + protected EnvironmentSpecificFactAttribute(string skipMessage) + { + _skipMessage = skipMessage ?? throw new ArgumentNullException(nameof(skipMessage)); + } + + public sealed override string Skip => IsEnvironmentSupported() ? string.Empty : _skipMessage; + + /// + /// A method used to evaluate whether to skip a test marked with this attribute. Skips iff this method evaluates to false. + /// + protected abstract bool IsEnvironmentSupported(); +} diff --git a/dotnet/test/AutoGen.Test.Share/Attribute/OpenAIFact.cs b/dotnet/test/AutoGen.Test.Share/Attribute/OpenAIFact.cs new file mode 100644 index 000000000000..54d72cd61ab7 --- /dev/null +++ b/dotnet/test/AutoGen.Test.Share/Attribute/OpenAIFact.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// OpenAIFact.cs + +namespace AutoGen.Tests; + +/// +/// A fact for tests requiring OPENAI_API_KEY env. +/// +public sealed class ApiKeyFactAttribute : EnvironmentSpecificFactAttribute +{ + private readonly string[] _envVariableNames; + public ApiKeyFactAttribute(params string[] envVariableNames) : base($"{envVariableNames} is not found in env") + { + _envVariableNames = envVariableNames; + } + + /// + protected override bool IsEnvironmentSupported() + { + return _envVariableNames.All(Environment.GetEnvironmentVariables().Contains); + } +} diff --git a/dotnet/test/AutoGen.Test.Share/AutoGen.Tests.Share.csproj b/dotnet/test/AutoGen.Test.Share/AutoGen.Tests.Share.csproj new file mode 100644 index 000000000000..21c71896ddc7 --- /dev/null +++ b/dotnet/test/AutoGen.Test.Share/AutoGen.Tests.Share.csproj @@ -0,0 +1,15 @@ + + + + $(TestTargetFrameworks) + enable + false + True + enable + + + + + + + diff --git a/dotnet/test/AutoGen.Test.Share/EchoAgent.cs b/dotnet/test/AutoGen.Test.Share/EchoAgent.cs new file mode 100644 index 000000000000..010b72d2add0 --- /dev/null +++ b/dotnet/test/AutoGen.Test.Share/EchoAgent.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// EchoAgent.cs + +using System.Runtime.CompilerServices; +using AutoGen.Core; + +namespace AutoGen.Tests; + +public class EchoAgent : IStreamingAgent +{ + public EchoAgent(string name) + { + Name = name; + } + public string Name { get; } + + public Task GenerateReplyAsync( + IEnumerable conversation, + GenerateReplyOptions? options = null, + CancellationToken ct = default) + { + // return the most recent message + var lastMessage = conversation.Last(); + lastMessage.From = this.Name; + + return Task.FromResult(lastMessage); + } + + public async IAsyncEnumerable GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + foreach (var message in messages) + { + message.From = this.Name; + yield return message; + } + } +} diff --git a/dotnet/test/AutoGen.Tests/Attribute/EnvironmentSpecificFactAttribute.cs b/dotnet/test/AutoGen.Tests/Attribute/EnvironmentSpecificFactAttribute.cs deleted file mode 100644 index 1042dec6f271..000000000000 --- a/dotnet/test/AutoGen.Tests/Attribute/EnvironmentSpecificFactAttribute.cs +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// EnvironmentSpecificFactAttribute.cs - -using System; -using Xunit; - -namespace AutoGen.Tests -{ - /// - /// A base class for environment-specific fact attributes. - /// - [AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = true)] - public abstract class EnvironmentSpecificFactAttribute : FactAttribute - { - private readonly string _skipMessage; - - /// - /// Creates a new instance of the class. - /// - /// The message to be used when skipping the test marked with this attribute. - protected EnvironmentSpecificFactAttribute(string skipMessage) - { - _skipMessage = skipMessage ?? throw new ArgumentNullException(nameof(skipMessage)); - } - - public sealed override string Skip => IsEnvironmentSupported() ? string.Empty : _skipMessage; - - /// - /// A method used to evaluate whether to skip a test marked with this attribute. Skips iff this method evaluates to false. - /// - protected abstract bool IsEnvironmentSupported(); - } -} diff --git a/dotnet/test/AutoGen.Tests/Attribute/OpenAIFact.cs b/dotnet/test/AutoGen.Tests/Attribute/OpenAIFact.cs deleted file mode 100644 index 44457d8f571c..000000000000 --- a/dotnet/test/AutoGen.Tests/Attribute/OpenAIFact.cs +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// OpenAIFact.cs - -using System; -using System.Linq; - -namespace AutoGen.Tests -{ - /// - /// A fact for tests requiring OPENAI_API_KEY env. - /// - public sealed class ApiKeyFactAttribute : EnvironmentSpecificFactAttribute - { - private readonly string[] _envVariableNames; - public ApiKeyFactAttribute(params string[] envVariableNames) : base($"{envVariableNames} is not found in env") - { - _envVariableNames = envVariableNames; - } - - /// - protected override bool IsEnvironmentSupported() - { - return _envVariableNames.All(Environment.GetEnvironmentVariables().Contains); - } - } -} diff --git a/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj b/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj index ce968b91f556..a0c3b815f22b 100644 --- a/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj +++ b/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj @@ -12,6 +12,7 @@ + diff --git a/dotnet/test/AutoGen.Tests/EchoAgent.cs b/dotnet/test/AutoGen.Tests/EchoAgent.cs deleted file mode 100644 index af5490218e8d..000000000000 --- a/dotnet/test/AutoGen.Tests/EchoAgent.cs +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// EchoAgent.cs - -using System.Collections.Generic; -using System.Linq; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; - -namespace AutoGen.Tests -{ - public class EchoAgent : IStreamingAgent - { - public EchoAgent(string name) - { - Name = name; - } - public string Name { get; } - - public Task GenerateReplyAsync( - IEnumerable conversation, - GenerateReplyOptions? options = null, - CancellationToken ct = default) - { - // return the most recent message - var lastMessage = conversation.Last(); - lastMessage.From = this.Name; - - return Task.FromResult(lastMessage); - } - - public async IAsyncEnumerable GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - foreach (var message in messages) - { - message.From = this.Name; - yield return message; - } - } - } -} diff --git a/dotnet/test/AutoGen.Tests/Orchestrator/RolePlayOrchestratorTests.cs b/dotnet/test/AutoGen.Tests/Orchestrator/RolePlayOrchestratorTests.cs index 5a2cebb66cff..f9ab09716b94 100644 --- a/dotnet/test/AutoGen.Tests/Orchestrator/RolePlayOrchestratorTests.cs +++ b/dotnet/test/AutoGen.Tests/Orchestrator/RolePlayOrchestratorTests.cs @@ -10,11 +10,14 @@ using AutoGen.Anthropic; using AutoGen.Anthropic.Extensions; using AutoGen.Anthropic.Utils; +using AutoGen.AzureAIInference; +using AutoGen.AzureAIInference.Extension; using AutoGen.Gemini; using AutoGen.Mistral; using AutoGen.Mistral.Extension; using AutoGen.OpenAI; using AutoGen.OpenAI.Extension; +using Azure.AI.Inference; using Azure.AI.OpenAI; using FluentAssertions; using Moq; @@ -304,6 +307,22 @@ public async Task Mistra_7b_CoderReviewerRunnerTestAsync() await CoderReviewerRunnerTestAsync(agent); } + [ApiKeyFact("GH_API_KEY")] + public async Task LLaMA_3_1_CoderReviewerRunnerTestAsync() + { + var apiKey = Environment.GetEnvironmentVariable("GH_API_KEY") ?? throw new InvalidOperationException("GH_API_KEY is not set."); + var endPoint = "https://models.inference.ai.azure.com"; + + var chatCompletionClient = new ChatCompletionsClient(new Uri(endPoint), new Azure.AzureKeyCredential(apiKey)); + var agent = new ChatCompletionsClientAgent( + chatCompletionsClient: chatCompletionClient, + name: "assistant", + modelName: "Meta-Llama-3.1-70B-Instruct") + .RegisterMessageConnector(); + + await CoderReviewerRunnerTestAsync(agent); + } + /// /// This test is to mimic the conversation among coder, reviewer and runner. /// The coder will write the code, the reviewer will review the code, and the runner will run the code. diff --git a/dotnet/website/articles/Agent-overview.md b/dotnet/website/articles/Agent-overview.md index 0b84cdc49ac7..586d231a6e7d 100644 --- a/dotnet/website/articles/Agent-overview.md +++ b/dotnet/website/articles/Agent-overview.md @@ -8,7 +8,6 @@ - Create an @AutoGen.OpenAI.OpenAIChatAgent: [Create an OpenAI chat agent](./OpenAIChatAgent-simple-chat.md) - Create a @AutoGen.SemanticKernel.SemanticKernelAgent: [Create a semantic kernel agent](./AutoGen.SemanticKernel/SemanticKernelAgent-simple-chat.md) - Create a @AutoGen.LMStudio.LMStudioAgent: [Connect to LM Studio](./Consume-LLM-server-from-LM-Studio.md) -- Create your own agent: [Create your own agent](./Create-your-own-agent.md) ## Chat with an agent To chat with an agent, typically you can invoke @AutoGen.Core.IAgent.GenerateReplyAsync*. On top of that, you can also use one of the extension methods like @AutoGen.Core.AgentExtension.SendAsync* as shortcuts. diff --git a/dotnet/website/articles/Installation.md b/dotnet/website/articles/Installation.md index 3ec5d3a470f4..30b55442d246 100644 --- a/dotnet/website/articles/Installation.md +++ b/dotnet/website/articles/Installation.md @@ -13,8 +13,9 @@ AutoGen.Net provides the following packages, you can choose to install one or mo - `AutoGen.LMStudio`: This package provides the integration agents from LM Studio. - `AutoGen.SemanticKernel`: This package provides the integration agents over semantic kernel. - `AutoGen.Gemini`: This package provides the integration agents from [Google Gemini](https://gemini.google.com/). +- `AutoGen.AzureAIInference`: This package provides the integration agents for [Azure AI Inference](https://www.nuget.org/packages/Azure.AI.Inference). - `AutoGen.SourceGenerator`: This package carries a source generator that adds support for type-safe function definition generation. -- `AutoGen.DotnetInteractive`: This packages carries dotnet interactive support to execute dotnet code snippet. +- `AutoGen.DotnetInteractive`: This packages carries dotnet interactive support to execute code snippets. The current supported language is C#, F#, powershell and python. >[!Note] > Help me choose