From 101f482213712227d4b5ded2a9aca4eee44c743e Mon Sep 17 00:00:00 2001 From: David Luong Date: Sun, 9 Jun 2024 10:43:38 -0400 Subject: [PATCH] Squash changes --- .../src/AutoGen.Anthropic/AnthropicClient.cs | 3 +- .../DTO/ChatCompletionRequest.cs | 11 +- .../Middleware/AnthropicMessageConnector.cs | 112 ++++++++++++++---- .../Agent/MiddlewareStreamingAgent.cs | 1 - .../AnthropicClientAgentTest.cs | 95 +++++++++++++-- .../AnthropicClientTest.cs | 37 +++++- .../AnthropicTestUtils.cs | 8 +- .../AutoGen.Anthropic.Tests.csproj | 6 + .../images/.gitattributes | 1 + .../AutoGen.Anthropic.Tests/images/square.png | 3 + 10 files changed, 240 insertions(+), 37 deletions(-) create mode 100644 dotnet/test/AutoGen.Anthropic.Tests/images/.gitattributes create mode 100644 dotnet/test/AutoGen.Anthropic.Tests/images/square.png diff --git a/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs b/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs index 8ea0bef86e2c..90bd33683f20 100644 --- a/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs +++ b/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs @@ -23,7 +23,8 @@ public sealed class AnthropicClient : IDisposable private static readonly JsonSerializerOptions JsonSerializerOptions = new() { - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + Converters = { new ContentBaseConverter() } }; private static readonly JsonSerializerOptions JsonDeserializerOptions = new() diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs index fa1654bc11d0..36cc1bb8e3e3 100644 --- a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs +++ b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs @@ -1,11 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. using System.Text.Json.Serialization; +using System.Collections.Generic; namespace AutoGen.Anthropic.DTO; -using System.Collections.Generic; - public class ChatCompletionRequest { [JsonPropertyName("model")] @@ -50,9 +49,15 @@ public class ChatMessage public string Role { get; set; } [JsonPropertyName("content")] - public string Content { get; set; } + public List Content { get; set; } public ChatMessage(string role, string content) + { + Role = role; + Content = new List() { new TextContent { Text = content } }; + } + + public ChatMessage(string role, List content) { Role = role; Content = content; diff --git a/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs b/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs index bfe79190925f..bb2f5820f74c 100644 --- a/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs +++ b/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Net.Http; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -19,7 +20,7 @@ public class AnthropicMessageConnector : IStreamingMiddleware public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) { var messages = context.Messages; - var chatMessages = ProcessMessage(messages, agent); + var chatMessages = await ProcessMessageAsync(messages, agent); var response = await agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken); return response is IMessage chatMessage @@ -31,7 +32,7 @@ public async IAsyncEnumerable InvokeAsync(MiddlewareContext c [EnumeratorCancellation] CancellationToken cancellationToken = default) { var messages = context.Messages; - var chatMessages = ProcessMessage(messages, agent); + var chatMessages = await ProcessMessageAsync(messages, agent); await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken)) { @@ -53,60 +54,78 @@ public async IAsyncEnumerable InvokeAsync(MiddlewareContext c private IStreamingMessage? ProcessChatCompletionResponse(IStreamingMessage chatMessage, IStreamingAgent agent) { - Delta? delta = chatMessage.Content.Delta; + var delta = chatMessage.Content.Delta; return delta != null && !string.IsNullOrEmpty(delta.Text) ? new TextMessageUpdate(role: Role.Assistant, delta.Text, from: agent.Name) : null; } - private IEnumerable ProcessMessage(IEnumerable messages, IAgent agent) + private async Task> ProcessMessageAsync(IEnumerable messages, IAgent agent) { - return messages.SelectMany(m => + var processedMessages = new List(); + + foreach (var message in messages) { - return m switch + var processedMessage = message switch { TextMessage textMessage => ProcessTextMessage(textMessage, agent), - _ => [m], + + ImageMessage imageMessage => + new MessageEnvelope(new ChatMessage("user", + new ContentBase[] { new ImageContent { Source = await ProcessImageSourceAsync(imageMessage) } } + .ToList()), + from: agent.Name), + + MultiModalMessage multiModalMessage => await ProcessMultiModalMessageAsync(multiModalMessage, agent), + _ => message, }; - }); + + processedMessages.Add(processedMessage); + } + + return processedMessages; } private IMessage PostProcessMessage(ChatCompletionResponse response, IAgent from) { if (response.Content is null) + { throw new ArgumentNullException(nameof(response.Content)); + } if (response.Content.Count != 1) + { throw new NotSupportedException($"{nameof(response.Content)} != 1"); + } return new TextMessage(Role.Assistant, ((TextContent)response.Content[0]).Text ?? string.Empty, from: from.Name); } - private IEnumerable> ProcessTextMessage(TextMessage textMessage, IAgent agent) + private IMessage ProcessTextMessage(TextMessage textMessage, IAgent agent) { - IEnumerable messages; + ChatMessage messages; if (textMessage.From == agent.Name) { - messages = [new ChatMessage( - "assistant", textMessage.Content)]; + messages = new ChatMessage( + "assistant", textMessage.Content); } else if (textMessage.From is null) { if (textMessage.Role == Role.User) { - messages = [new ChatMessage( - "user", textMessage.Content)]; + messages = new ChatMessage( + "user", textMessage.Content); } else if (textMessage.Role == Role.Assistant) { - messages = [new ChatMessage( - "assistant", textMessage.Content)]; + messages = new ChatMessage( + "assistant", textMessage.Content); } else if (textMessage.Role == Role.System) { - messages = [new ChatMessage( - "system", textMessage.Content)]; + messages = new ChatMessage( + "system", textMessage.Content); } else { @@ -116,10 +135,61 @@ private IEnumerable> ProcessTextMessage(TextMessage textMe else { // if from is not null, then the message is from user - messages = [new ChatMessage( - "user", textMessage.Content)]; + messages = new ChatMessage( + "user", textMessage.Content); } - return messages.Select(m => new MessageEnvelope(m, from: textMessage.From)); + return new MessageEnvelope(messages, from: textMessage.From); + } + + private async Task ProcessMultiModalMessageAsync(MultiModalMessage multiModalMessage, IAgent agent) + { + var content = new List(); + foreach (var message in multiModalMessage.Content) + { + switch (message) + { + case TextMessage textMessage when textMessage.GetContent() is not null: + content.Add(new TextContent { Text = textMessage.GetContent() }); + break; + case ImageMessage imageMessage: + content.Add(new ImageContent() { Source = await ProcessImageSourceAsync(imageMessage) }); + break; + } + } + + var chatMessage = new ChatMessage("user", content); + return MessageEnvelope.Create(chatMessage, agent.Name); + } + + private async Task ProcessImageSourceAsync(ImageMessage imageMessage) + { + if (imageMessage.Data != null) + { + return new ImageSource + { + MediaType = imageMessage.Data.MediaType, + Data = Convert.ToBase64String(imageMessage.Data.ToArray()) + }; + } + + if (imageMessage.Url is null) + { + throw new InvalidOperationException("Invalid ImageMessage, the data or url must be provided"); + } + + var uri = new Uri(imageMessage.Url); + using var client = new HttpClient(); + var response = client.GetAsync(uri).Result; + if (!response.IsSuccessStatusCode) + { + throw new HttpRequestException($"Failed to download the image from {uri}"); + } + + return new ImageSource + { + MediaType = "image/jpeg", + Data = Convert.ToBase64String(await response.Content.ReadAsByteArrayAsync()) + }; } } diff --git a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs index 251d3c110f98..52967d6ff1ce 100644 --- a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs @@ -49,7 +49,6 @@ public Task GenerateReplyAsync(IEnumerable messages, Generat public IAsyncEnumerable GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - return _agent.GenerateStreamingReplyAsync(messages, options, cancellationToken); } diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs index ba31f2297ba8..d29025b44aff 100644 --- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs +++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs @@ -1,31 +1,108 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // AnthropicClientAgentTest.cs +using AutoGen.Anthropic.DTO; using AutoGen.Anthropic.Extensions; using AutoGen.Anthropic.Utils; +using AutoGen.Core; using AutoGen.Tests; -using Xunit.Abstractions; +using FluentAssertions; -namespace AutoGen.Anthropic; +namespace AutoGen.Anthropic.Tests; public class AnthropicClientAgentTest { - private readonly ITestOutputHelper _output; - - public AnthropicClientAgentTest(ITestOutputHelper output) => _output = output; - [ApiKeyFact("ANTHROPIC_API_KEY")] public async Task AnthropicAgentChatCompletionTestAsync() { var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey); + var agent = new AnthropicClientAgent( + client, + name: "AnthropicAgent", + AnthropicConstants.Claude3Haiku, + systemMessage: "You are a helpful AI assistant that convert user message to upper case") + .RegisterMessageConnector(); + + var uppCaseMessage = new TextMessage(Role.User, "abcdefg"); + + var reply = await agent.SendAsync(chatHistory: new[] { uppCaseMessage }); + + reply.GetContent().Should().Contain("ABCDEFG"); + reply.From.Should().Be(agent.Name); + } + + [ApiKeyFact("ANTHROPIC_API_KEY")] + public async Task AnthropicAgentTestProcessImageAsync() + { + var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey); var agent = new AnthropicClientAgent( client, name: "AnthropicAgent", AnthropicConstants.Claude3Haiku).RegisterMessageConnector(); - var singleAgentTest = new SingleAgentTest(_output); - await singleAgentTest.UpperCaseTestAsync(agent); - await singleAgentTest.UpperCaseStreamingTestAsync(agent); + var base64Image = await AnthropicTestUtils.Base64FromImageAsync("square.png"); + var imageMessage = new ChatMessage("user", + [new ImageContent { Source = new ImageSource { MediaType = "image/png", Data = base64Image } }]); + + var messages = new IMessage[] { MessageEnvelope.Create(imageMessage) }; + + // test streaming + foreach (var message in messages) + { + var reply = agent.GenerateStreamingReplyAsync([message]); + + await foreach (var streamingMessage in reply) + { + streamingMessage.Should().BeOfType(); + streamingMessage.As().From.Should().Be(agent.Name); + } + } + } + + [ApiKeyFact("ANTHROPIC_API_KEY")] + public async Task AnthropicAgentTestMultiModalAsync() + { + var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey); + var agent = new AnthropicClientAgent( + client, + name: "AnthropicAgent", + AnthropicConstants.Claude3Haiku) + .RegisterMessageConnector(); + + var image = Path.Combine("images", "square.png"); + var binaryData = BinaryData.FromBytes(await File.ReadAllBytesAsync(image), "image/png"); + var imageMessage = new ImageMessage(Role.User, binaryData); + var textMessage = new TextMessage(Role.User, "What's in this image?"); + var multiModalMessage = new MultiModalMessage(Role.User, [textMessage, imageMessage]); + + var reply = await agent.SendAsync(multiModalMessage); + reply.Should().BeOfType(); + reply.GetRole().Should().Be(Role.Assistant); + reply.GetContent().Should().NotBeNullOrEmpty(); + reply.From.Should().Be(agent.Name); + } + + [ApiKeyFact("ANTHROPIC_API_KEY")] + public async Task AnthropicAgentTestImageMessageAsync() + { + var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey); + var agent = new AnthropicClientAgent( + client, + name: "AnthropicAgent", + AnthropicConstants.Claude3Haiku, + systemMessage: "You are a helpful AI assistant that is capable of determining what an image is. Tell me a brief description of the image." + ) + .RegisterMessageConnector(); + + var image = Path.Combine("images", "square.png"); + var binaryData = BinaryData.FromBytes(await File.ReadAllBytesAsync(image), "image/png"); + var imageMessage = new ImageMessage(Role.User, binaryData); + + var reply = await agent.SendAsync(imageMessage); + reply.Should().BeOfType(); + reply.GetRole().Should().Be(Role.Assistant); + reply.GetContent().Should().NotBeNullOrEmpty(); + reply.From.Should().Be(agent.Name); } } diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs index 0b64c9e4e3c2..f62c062f3066 100644 --- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs +++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs @@ -7,7 +7,7 @@ using FluentAssertions; using Xunit; -namespace AutoGen.Anthropic; +namespace AutoGen.Anthropic.Tests; public class AnthropicClientTests { @@ -73,6 +73,41 @@ public async Task AnthropicClientStreamingChatCompletionTestAsync() results.First().streamingMessage!.Role.Should().Be("assistant"); } + [ApiKeyFact("ANTHROPIC_API_KEY")] + public async Task AnthropicClientImageChatCompletionTestAsync() + { + var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey); + + var request = new ChatCompletionRequest(); + request.Model = AnthropicConstants.Claude3Haiku; + request.Stream = false; + request.MaxTokens = 100; + request.SystemMessage = "You are a LLM that is suppose to describe the content of the image. Give me a description of the provided image."; + + var base64Image = await AnthropicTestUtils.Base64FromImageAsync("square.png"); + var messages = new List + { + new("user", + [ + new ImageContent { Source = new ImageSource {MediaType = "image/png", Data = base64Image} } + ]) + }; + + request.Messages = messages; + + var response = await anthropicClient.CreateChatCompletionsAsync(request, CancellationToken.None); + + Assert.NotNull(response); + Assert.NotNull(response.Content); + Assert.NotEmpty(response.Content); + response.Content.Count.Should().Be(1); + response.Content.First().Should().BeOfType(); + var textContent = (TextContent)response.Content.First(); + Assert.Equal("text", textContent.Type); + Assert.NotNull(response.Usage); + response.Usage.OutputTokens.Should().BeGreaterThan(0); + } + private sealed class Person { [JsonPropertyName("name")] diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs index a5b80eee3bdf..de630da6d87c 100644 --- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs +++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs @@ -1,10 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // AnthropicTestUtils.cs -namespace AutoGen.Anthropic; +namespace AutoGen.Anthropic.Tests; public static class AnthropicTestUtils { public static string ApiKey => Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") ?? throw new Exception("Please set ANTHROPIC_API_KEY environment variable."); + + public static async Task Base64FromImageAsync(string imageName) + { + return Convert.ToBase64String( + await File.ReadAllBytesAsync(Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "images", imageName))); + } } diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj b/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj index 8cd1e3003b0e..9f30d7357d07 100644 --- a/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj +++ b/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj @@ -20,4 +20,10 @@ + + + + PreserveNewest + + diff --git a/dotnet/test/AutoGen.Anthropic.Tests/images/.gitattributes b/dotnet/test/AutoGen.Anthropic.Tests/images/.gitattributes new file mode 100644 index 000000000000..56e7c34d4989 --- /dev/null +++ b/dotnet/test/AutoGen.Anthropic.Tests/images/.gitattributes @@ -0,0 +1 @@ +square.png filter=lfs diff=lfs merge=lfs -text diff --git a/dotnet/test/AutoGen.Anthropic.Tests/images/square.png b/dotnet/test/AutoGen.Anthropic.Tests/images/square.png new file mode 100644 index 000000000000..5c2b3ed820b1 --- /dev/null +++ b/dotnet/test/AutoGen.Anthropic.Tests/images/square.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8341030e5b93aab2c55dcd40ffa26ced8e42cc15736a8348176ffd155ad2d937 +size 8167