Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[.Net] Support image input for Anthropic Models #2849

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dotnet/src/AutoGen.Anthropic/AnthropicClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ public sealed class AnthropicClient : IDisposable

private static readonly JsonSerializerOptions JsonSerializerOptions = new()
{
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
Converters = { new ContentBaseConverter() }
};

private static readonly JsonSerializerOptions JsonDeserializerOptions = new()
Expand Down
11 changes: 8 additions & 3 deletions dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

using System.Text.Json.Serialization;
using System.Collections.Generic;

namespace AutoGen.Anthropic.DTO;

using System.Collections.Generic;

public class ChatCompletionRequest
{
[JsonPropertyName("model")]
Expand Down Expand Up @@ -50,9 +49,15 @@ public class ChatMessage
public string Role { get; set; }

[JsonPropertyName("content")]
public string Content { get; set; }
public List<ContentBase> Content { get; set; }

public ChatMessage(string role, string content)
{
Role = role;
Content = new List<ContentBase>() { new TextContent { Text = content } };
}

public ChatMessage(string role, List<ContentBase> content)
{
Role = role;
Content = content;
Expand Down
112 changes: 91 additions & 21 deletions dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -19,7 +20,7 @@ public class AnthropicMessageConnector : IStreamingMiddleware
public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var chatMessages = ProcessMessage(messages, agent);
var chatMessages = await ProcessMessageAsync(messages, agent);
var response = await agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken);

return response is IMessage<ChatCompletionResponse> chatMessage
Expand All @@ -31,7 +32,7 @@ public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext c
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var chatMessages = ProcessMessage(messages, agent);
var chatMessages = await ProcessMessageAsync(messages, agent);

await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken))
{
Expand All @@ -53,60 +54,78 @@ public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext c
private IStreamingMessage? ProcessChatCompletionResponse(IStreamingMessage<ChatCompletionResponse> chatMessage,
IStreamingAgent agent)
{
Delta? delta = chatMessage.Content.Delta;
var delta = chatMessage.Content.Delta;
return delta != null && !string.IsNullOrEmpty(delta.Text)
? new TextMessageUpdate(role: Role.Assistant, delta.Text, from: agent.Name)
: null;
}

private IEnumerable<IMessage> ProcessMessage(IEnumerable<IMessage> messages, IAgent agent)
private async Task<IEnumerable<IMessage>> ProcessMessageAsync(IEnumerable<IMessage> messages, IAgent agent)
{
return messages.SelectMany<IMessage, IMessage>(m =>
var processedMessages = new List<IMessage>();

foreach (var message in messages)
{
return m switch
var processedMessage = message switch
{
TextMessage textMessage => ProcessTextMessage(textMessage, agent),
_ => [m],

ImageMessage imageMessage =>
new MessageEnvelope<ChatMessage>(new ChatMessage("user",
new ContentBase[] { new ImageContent { Source = await ProcessImageSourceAsync(imageMessage) } }
.ToList()),
from: agent.Name),

MultiModalMessage multiModalMessage => await ProcessMultiModalMessageAsync(multiModalMessage, agent),
_ => message,
};
});

processedMessages.Add(processedMessage);
}

return processedMessages;
}

/// <summary>
/// Converts a completed Anthropic response into a <see cref="TextMessage"/> with the assistant role.
/// </summary>
/// <param name="response">The response whose single content entry is read.</param>
/// <param name="from">The agent whose name is stamped on the resulting message.</param>
/// <returns>A text message carrying the entry's text, or an empty string when that text is null.</returns>
/// <exception cref="ArgumentNullException">The response has no content list.</exception>
/// <exception cref="NotSupportedException">The response does not contain exactly one content entry.</exception>
private IMessage PostProcessMessage(ChatCompletionResponse response, IAgent from)
{
    var entries = response.Content ?? throw new ArgumentNullException(nameof(response.Content));

    if (entries.Count != 1)
    {
        throw new NotSupportedException($"{nameof(response.Content)} != 1");
    }

    // Hard cast on purpose: a non-text entry surfaces as an InvalidCastException, as before.
    var onlyEntry = (TextContent)entries[0];
    var replyText = onlyEntry.Text ?? string.Empty;

    return new TextMessage(Role.Assistant, replyText, from: from.Name);
}

private IEnumerable<IMessage<ChatMessage>> ProcessTextMessage(TextMessage textMessage, IAgent agent)
private IMessage<ChatMessage> ProcessTextMessage(TextMessage textMessage, IAgent agent)
{
IEnumerable<ChatMessage> messages;
ChatMessage messages;

if (textMessage.From == agent.Name)
{
messages = [new ChatMessage(
"assistant", textMessage.Content)];
messages = new ChatMessage(
"assistant", textMessage.Content);
}
else if (textMessage.From is null)
{
if (textMessage.Role == Role.User)
{
messages = [new ChatMessage(
"user", textMessage.Content)];
messages = new ChatMessage(
"user", textMessage.Content);
}
else if (textMessage.Role == Role.Assistant)
{
messages = [new ChatMessage(
"assistant", textMessage.Content)];
messages = new ChatMessage(
"assistant", textMessage.Content);
}
else if (textMessage.Role == Role.System)
{
messages = [new ChatMessage(
"system", textMessage.Content)];
messages = new ChatMessage(
"system", textMessage.Content);
}
else
{
Expand All @@ -116,10 +135,61 @@ private IEnumerable<IMessage<ChatMessage>> ProcessTextMessage(TextMessage textMe
else
{
// if from is not null, then the message is from user
messages = [new ChatMessage(
"user", textMessage.Content)];
messages = new ChatMessage(
"user", textMessage.Content);
}

return messages.Select(m => new MessageEnvelope<ChatMessage>(m, from: textMessage.From));
return new MessageEnvelope<ChatMessage>(messages, from: textMessage.From);
}

/// <summary>
/// Flattens a <see cref="MultiModalMessage"/> into a single "user" <see cref="ChatMessage"/>
/// whose content list holds one entry per supported part.
/// </summary>
/// <param name="multiModalMessage">The message whose parts are converted in order.</param>
/// <param name="agent">The agent whose name is used as the envelope's sender.</param>
/// <returns>A message envelope wrapping the assembled chat message.</returns>
private async Task<IMessage> ProcessMultiModalMessageAsync(MultiModalMessage multiModalMessage, IAgent agent)
{
    var parts = new List<ContentBase>();

    foreach (var part in multiModalMessage.Content)
    {
        if (part is TextMessage text && text.GetContent() is not null)
        {
            parts.Add(new TextContent { Text = text.GetContent() });
        }
        else if (part is ImageMessage image)
        {
            var source = await ProcessImageSourceAsync(image);
            parts.Add(new ImageContent() { Source = source });
        }

        // Any other part (including a text message without content) is skipped.
    }

    return MessageEnvelope.Create(new ChatMessage("user", parts), agent.Name);
}

/// <summary>
/// Builds the <see cref="ImageSource"/> payload for an <see cref="ImageMessage"/>.
/// Inline binary data is used when present; otherwise the image is downloaded from the message's URL.
/// </summary>
/// <param name="imageMessage">The image message carrying either binary data or a URL.</param>
/// <returns>An image source holding the media type and the base64-encoded image bytes.</returns>
/// <exception cref="InvalidOperationException">Neither data nor a URL is set on the message.</exception>
/// <exception cref="HttpRequestException">The download returned a non-success status code.</exception>
private async Task<ImageSource> ProcessImageSourceAsync(ImageMessage imageMessage)
{
    if (imageMessage.Data != null)
    {
        return new ImageSource
        {
            MediaType = imageMessage.Data.MediaType,
            Data = Convert.ToBase64String(imageMessage.Data.ToArray())
        };
    }

    if (imageMessage.Url is null)
    {
        throw new InvalidOperationException("Invalid ImageMessage, the data or url must be provided");
    }

    var uri = new Uri(imageMessage.Url);
    using var client = new HttpClient();

    // Await the request instead of blocking on .Result: blocking inside an async method
    // can deadlock under a synchronization context and starves the thread pool.
    using var response = await client.GetAsync(uri);
    if (!response.IsSuccessStatusCode)
    {
        throw new HttpRequestException($"Failed to download the image from {uri}");
    }

    // Use the server-reported Content-Type when it names an image type, so that e.g. PNG
    // downloads are not mislabeled; otherwise keep the previous "image/jpeg" default.
    var reportedMediaType = response.Content.Headers.ContentType?.MediaType;
    var mediaType = reportedMediaType is not null
                    && reportedMediaType.StartsWith("image/", StringComparison.OrdinalIgnoreCase)
        ? reportedMediaType
        : "image/jpeg";

    return new ImageSource
    {
        MediaType = mediaType,
        Data = Convert.ToBase64String(await response.Content.ReadAsByteArrayAsync())
    };
}
}
1 change: 0 additions & 1 deletion dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, Generat

/// <summary>
/// Forwards the streaming reply request, with all arguments unchanged, to the <c>_agent</c> field.
/// </summary>
/// <param name="messages">The messages passed through to the underlying call.</param>
/// <param name="options">Optional generation options passed through as-is.</param>
/// <param name="cancellationToken">Token passed through to the underlying call.</param>
/// <returns>The stream returned by <c>_agent.GenerateStreamingReplyAsync</c>.</returns>
public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
    => _agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);

Expand Down
95 changes: 86 additions & 9 deletions dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs
Original file line number Diff line number Diff line change
@@ -1,31 +1,108 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AnthropicClientAgentTest.cs

using AutoGen.Anthropic.DTO;
using AutoGen.Anthropic.Extensions;
using AutoGen.Anthropic.Utils;
using AutoGen.Core;
using AutoGen.Tests;
using Xunit.Abstractions;
using FluentAssertions;

namespace AutoGen.Anthropic;
namespace AutoGen.Anthropic.Tests;

public class AnthropicClientAgentTest
{
private readonly ITestOutputHelper _output;

public AnthropicClientAgentTest(ITestOutputHelper output) => _output = output;

/// <summary>
/// Sends a lower-case prompt to an agent instructed to upper-case its input and
/// checks both the reply text and the reply's sender.
/// </summary>
[ApiKeyFact("ANTHROPIC_API_KEY")]
public async Task AnthropicAgentChatCompletionTestAsync()
{
    // Arrange
    var httpClient = new HttpClient();
    var anthropicClient = new AnthropicClient(httpClient, AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
    var rawAgent = new AnthropicClientAgent(
        anthropicClient,
        name: "AnthropicAgent",
        AnthropicConstants.Claude3Haiku,
        systemMessage: "You are a helpful AI assistant that convert user message to upper case");
    var agent = rawAgent.RegisterMessageConnector();
    var prompt = new TextMessage(Role.User, "abcdefg");

    // Act
    var answer = await agent.SendAsync(chatHistory: new[] { prompt });

    // Assert
    answer.GetContent().Should().Contain("ABCDEFG");
    answer.From.Should().Be(agent.Name);
}

[ApiKeyFact("ANTHROPIC_API_KEY")]
public async Task AnthropicAgentTestProcessImageAsync()
{
var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
var agent = new AnthropicClientAgent(
client,
name: "AnthropicAgent",
AnthropicConstants.Claude3Haiku).RegisterMessageConnector();

var singleAgentTest = new SingleAgentTest(_output);
await singleAgentTest.UpperCaseTestAsync(agent);
await singleAgentTest.UpperCaseStreamingTestAsync(agent);
var base64Image = await AnthropicTestUtils.Base64FromImageAsync("square.png");
var imageMessage = new ChatMessage("user",
DavidLuong98 marked this conversation as resolved.
Show resolved Hide resolved
[new ImageContent { Source = new ImageSource { MediaType = "image/png", Data = base64Image } }]);

var messages = new IMessage[] { MessageEnvelope.Create(imageMessage) };

// test streaming
foreach (var message in messages)
{
var reply = agent.GenerateStreamingReplyAsync([message]);

await foreach (var streamingMessage in reply)
{
streamingMessage.Should().BeOfType<TextMessageUpdate>();
streamingMessage.As<TextMessageUpdate>().From.Should().Be(agent.Name);
}
}
}

/// <summary>
/// Sends a multi-modal message (a text question plus a PNG image read from disk) and
/// checks that the reply is a non-empty assistant <see cref="TextMessage"/> from the agent.
/// </summary>
[ApiKeyFact("ANTHROPIC_API_KEY")]
public async Task AnthropicAgentTestMultiModalAsync()
{
    // Arrange: agent wired with the message connector.
    var httpClient = new HttpClient();
    var anthropicClient = new AnthropicClient(httpClient, AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
    var agent = new AnthropicClientAgent(
            anthropicClient,
            name: "AnthropicAgent",
            AnthropicConstants.Claude3Haiku)
        .RegisterMessageConnector();

    // Arrange: the image part and the text part of the prompt.
    var imagePath = Path.Combine("images", "square.png");
    var imageBytes = await File.ReadAllBytesAsync(imagePath);
    var picture = new ImageMessage(Role.User, BinaryData.FromBytes(imageBytes, "image/png"));
    var question = new TextMessage(Role.User, "What's in this image?");
    var prompt = new MultiModalMessage(Role.User, [question, picture]);

    // Act
    var answer = await agent.SendAsync(prompt);

    // Assert
    answer.Should().BeOfType<TextMessage>();
    answer.GetRole().Should().Be(Role.Assistant);
    answer.GetContent().Should().NotBeNullOrEmpty();
    answer.From.Should().Be(agent.Name);
}

/// <summary>
/// Sends a standalone <see cref="ImageMessage"/> built from a PNG read from disk and
/// checks that the reply is a non-empty assistant <see cref="TextMessage"/> from the agent.
/// </summary>
[ApiKeyFact("ANTHROPIC_API_KEY")]
public async Task AnthropicAgentTestImageMessageAsync()
{
    // Arrange: agent with an image-description system prompt.
    var httpClient = new HttpClient();
    var anthropicClient = new AnthropicClient(httpClient, AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
    var agent = new AnthropicClientAgent(
            anthropicClient,
            name: "AnthropicAgent",
            AnthropicConstants.Claude3Haiku,
            systemMessage: "You are a helpful AI assistant that is capable of determining what an image is. Tell me a brief description of the image.")
        .RegisterMessageConnector();

    // Arrange: the image payload.
    var imagePath = Path.Combine("images", "square.png");
    var imageBytes = await File.ReadAllBytesAsync(imagePath);
    var picture = new ImageMessage(Role.User, BinaryData.FromBytes(imageBytes, "image/png"));

    // Act
    var answer = await agent.SendAsync(picture);

    // Assert
    answer.Should().BeOfType<TextMessage>();
    answer.GetRole().Should().Be(Role.Assistant);
    answer.GetContent().Should().NotBeNullOrEmpty();
    answer.From.Should().Be(agent.Name);
}
}
Loading
Loading