Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[.Net]: Introduce ChatCompletionAgent to AutoGen.SemanticKernel package #2584

Merged
merged 16 commits into from
May 9, 2024
Merged
4 changes: 2 additions & 2 deletions dotnet/eng/Version.props
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<PropertyGroup>
<AzureOpenAIVersion>1.0.0-beta.17</AzureOpenAIVersion>
<SemanticKernelVersion>1.7.1</SemanticKernelVersion>
<SemanticKernelExperimentalVersion>1.7.1-alpha</SemanticKernelExperimentalVersion>
<SemanticKernelVersion>1.10.0</SemanticKernelVersion>
<SemanticKernelExperimentalVersion>1.10.0-alpha</SemanticKernelExperimentalVersion>
<SystemCodeDomVersion>5.0.0</SystemCodeDomVersion>
<MicrosoftCodeAnalysisVersion>4.3.0</MicrosoftCodeAnalysisVersion>
<ApprovalTestVersion>6.0.0</ApprovalTestVersion>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
<NoWarn>$(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050</NoWarn>
<NoWarn>$(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110</NoWarn>
</PropertyGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,4 @@ public async Task SemanticKernelChatMessageContentConnector()
}
#endregion register_semantic_kernel_chat_message_content_connector
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
<RootNamespace>AutoGen.SemanticKernel</RootNamespace>
</PropertyGroup>

<PropertyGroup>
<NoWarn>$(NoWarn);SKEXP0110</NoWarn>
</PropertyGroup>

<Import Project="$(RepoRoot)/nuget/nuget-package.props" />

<PropertyGroup>
Expand All @@ -18,6 +22,7 @@
<ItemGroup>
<PackageReference Include="Azure.AI.OpenAI" Version="$(AzureOpenAIVersion)" />
<PackageReference Include="Microsoft.SemanticKernel" Version="$(SemanticKernelVersion)" />
<PackageReference Include="Microsoft.SemanticKernel.Agents.Core" Version="$(SemanticKernelExperimentalVersion)" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SemanticKernelChatCompletionAgent.cs

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Microsoft.SemanticKernel.ChatCompletion;

namespace AutoGen.SemanticKernel;

public class SemanticKernelChatCompletionAgent : IAgent
{
public string Name { get; }
private readonly ChatCompletionAgent _chatCompletionAgent;

public SemanticKernelChatCompletionAgent(ChatCompletionAgent chatCompletionAgent)
{
this.Name = chatCompletionAgent.Name ?? throw new ArgumentNullException(nameof(chatCompletionAgent.Name));
this._chatCompletionAgent = chatCompletionAgent;
}

public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
ChatMessageContent[] reply = await _chatCompletionAgent
.InvokeAsync(BuildChatHistory(messages), cancellationToken)
.ToArrayAsync(cancellationToken: cancellationToken);

return reply.Length > 1
? throw new InvalidOperationException("ResultsPerPrompt greater than 1 is not supported in this semantic kernel agent")
: new MessageEnvelope<ChatMessageContent>(reply[0], from: this.Name);
}

private ChatHistory BuildChatHistory(IEnumerable<IMessage> messages)
{
return new ChatHistory(ProcessMessage(messages));
}

private IEnumerable<ChatMessageContent> ProcessMessage(IEnumerable<IMessage> messages)
{
return messages.Select(m => m switch
{
IMessage<ChatMessageContent> cmc => cmc.Content,
_ => throw new ArgumentException("Invalid message type")
});
}
}
2 changes: 1 addition & 1 deletion dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<PropertyGroup>
<TargetFramework>$(TestTargetFramework)</TargetFramework>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
<NoWarn>$(NoWarn);xUnit1013</NoWarn>
<NoWarn>$(NoWarn);xUnit1013;SKEXP0110</NoWarn>
</PropertyGroup>

<ItemGroup>
Expand Down
111 changes: 109 additions & 2 deletions dotnet/test/AutoGen.Tests/SemanticKernelAgentTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
using AutoGen.SemanticKernel.Extension;
using FluentAssertions;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;

namespace AutoGen.Tests;

Expand Down Expand Up @@ -69,8 +71,7 @@ public async Task SemanticKernelChatMessageContentConnectorTestAsync()
var messages = new IMessage[]
{
MessageEnvelope.Create(new ChatMessageContent(AuthorRole.Assistant, "Hello")),
new TextMessage(Role.Assistant, "Hello", from: "user"),
new MultiModalMessage(Role.Assistant,
new TextMessage(Role.Assistant, "Hello", from: "user"), new MultiModalMessage(Role.Assistant,
[
new TextMessage(Role.Assistant, "Hello", from: "user"),
],
Expand Down Expand Up @@ -128,4 +129,110 @@ public async Task SemanticKernelPluginTestAsync()
reply.GetContent()!.ToLower().Should().Contain("seattle");
reply.GetContent()!.ToLower().Should().Contain("sunny");
}


[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")]
public async Task BasicSkChatCompletionAgentConversationTestAsync()
{
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
var builder = Kernel.CreateBuilder()
.AddAzureOpenAIChatCompletion("gpt-35-turbo-16k", endpoint, key);

var kernel = builder.Build();
var agent = new ChatCompletionAgent()
{
Kernel = kernel,
Name = "assistant",
Instructions = "You are a helpful AI assistant"
};

var skAgent = new SemanticKernelChatCompletionAgent(agent);

var chatMessageContent = MessageEnvelope.Create(new ChatMessageContent(AuthorRole.Assistant, "Hello"));
var reply = await skAgent.SendAsync(chatMessageContent);

reply.Should().BeOfType<MessageEnvelope<ChatMessageContent>>();
reply.As<MessageEnvelope<ChatMessageContent>>().From.Should().Be("assistant");
}

[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")]
public async Task SkChatCompletionAgentChatMessageContentConnectorTestAsync()
{
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
var builder = Kernel.CreateBuilder()
.AddAzureOpenAIChatCompletion("gpt-35-turbo-16k", endpoint, key);

var kernel = builder.Build();

var connector = new SemanticKernelChatMessageContentConnector();
var agent = new ChatCompletionAgent()
{
Kernel = kernel,
Name = "assistant",
Instructions = "You are a helpful AI assistant"
};
var skAgent = new SemanticKernelChatCompletionAgent(agent)
.RegisterMiddleware(connector);

var messages = new IMessage[]
{
MessageEnvelope.Create(new ChatMessageContent(AuthorRole.Assistant, "Hello")),
new TextMessage(Role.Assistant, "Hello", from: "user"), new MultiModalMessage(Role.Assistant,
[
new TextMessage(Role.Assistant, "Hello", from: "user"),
],
from: "user"),
};

foreach (var message in messages)
{
var reply = await skAgent.SendAsync(message);

reply.Should().BeOfType<TextMessage>();
reply.As<TextMessage>().From.Should().Be("assistant");
}
}

[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")]
public async Task SkChatCompletionAgentPluginTestAsync()
{
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
var builder = Kernel.CreateBuilder()
.AddAzureOpenAIChatCompletion("gpt-35-turbo-16k", endpoint, key);

var parameters = this.GetWeatherAsyncFunctionContract.Parameters!.Select(p => new KernelParameterMetadata(p.Name!)
{
Description = p.Description,
DefaultValue = p.DefaultValue,
IsRequired = p.IsRequired,
ParameterType = p.ParameterType,
});
var function = KernelFunctionFactory.CreateFromMethod(this.GetWeatherAsync, this.GetWeatherAsyncFunctionContract.Name, this.GetWeatherAsyncFunctionContract.Description, parameters);
builder.Plugins.AddFromFunctions("plugins", [function]);
var kernel = builder.Build();

var agent = new ChatCompletionAgent()
{
Kernel = kernel,
Name = "assistant",
Instructions = "You are a helpful AI assistant",
ExecutionSettings =
new OpenAIPromptExecutionSettings()
{
ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
}
};
var skAgent =
new SemanticKernelChatCompletionAgent(agent).RegisterMiddleware(
new SemanticKernelChatMessageContentConnector());

var question = "What is the weather in Seattle?";
var reply = await skAgent.SendAsync(question);

reply.GetContent()!.ToLower().Should().Contain("seattle");
reply.GetContent()!.ToLower().Should().Contain("sunny");
}
}
Loading