diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs index 50bcd8a8048e..5dc2cbe701ad 100644 --- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs +++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs @@ -2,6 +2,7 @@ // TypeSafeFunctionCallCodeSnippet.cs using System.Text.Json; +using System.Text.Json.Serialization; using AutoGen.OpenAI.Extension; using Azure.AI.OpenAI; #region weather_report_using_statement @@ -11,6 +12,15 @@ #region weather_report public partial class TypeSafeFunctionCall { + private class GetWeatherSchema + { + [JsonPropertyName(@"city")] + public string city { get; set; } + + [JsonPropertyName(@"date")] + public string date { get; set; } + } + /// /// Get weather report /// @@ -21,7 +31,20 @@ public async Task WeatherReport(string city, string date) { return $"Weather report for {city} on {date} is sunny"; } + + public Task GetWeatherReportWrapper(string arguments) + { + var schema = JsonSerializer.Deserialize( + arguments, + new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }); + + return WeatherReport(schema.city, schema.date); + } } + #endregion weather_report public partial class TypeSafeFunctionCall diff --git a/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs b/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs index e395bb4a225f..75940cc8741c 100644 --- a/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs +++ b/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -16,6 +17,7 @@ public class AnthropicClientAgent : IStreamingAgent private readonly string _systemMessage; private readonly decimal _temperature; private readonly int _maxTokens; + private readonly Tool[]? _tools; public AnthropicClientAgent( AnthropicClient anthropicClient, @@ -23,7 +25,8 @@ public AnthropicClientAgent( string modelName, string systemMessage = "You are a helpful AI assistant", decimal temperature = 0.7m, - int maxTokens = 1024) + int maxTokens = 1024, + Tool[]? tools = null) { Name = name; _anthropicClient = anthropicClient; @@ -31,6 +34,7 @@ public AnthropicClientAgent( _systemMessage = systemMessage; _temperature = temperature; _maxTokens = maxTokens; + _tools = tools; } public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, @@ -59,6 +63,7 @@ private ChatCompletionRequest CreateParameters(IEnumerable messages, G Model = _modelName, Stream = shouldStream, Temperature = (decimal?)options?.Temperature ?? _temperature, + Tools = _tools?.ToList() }; chatCompletionRequest.Messages = BuildMessages(messages); diff --git a/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs b/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs index 4cb8fdbb34e0..b41a761dc4d3 100644 --- a/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs +++ b/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs @@ -24,6 +24,10 @@ public override ContentBase Read(ref Utf8JsonReader reader, Type typeToConvert, return JsonSerializer.Deserialize(text, options) ?? throw new InvalidOperationException(); case "image": return JsonSerializer.Deserialize(text, options) ?? throw new InvalidOperationException(); + case "tool_use": + return JsonSerializer.Deserialize(text, options) ?? throw new InvalidOperationException(); + case "tool_result": + return JsonSerializer.Deserialize(text, options) ?? throw new InvalidOperationException(); } } diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs index 0c1749eaa989..1d45a62caf80 100644 --- a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs +++ b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs @@ -37,6 +37,9 @@ public class ChatCompletionRequest [JsonPropertyName("top_p")] public decimal? TopP { get; set; } + [JsonPropertyName("tools")] + public List? Tools { get; set; } + public ChatCompletionRequest() { Messages = new List(); @@ -62,4 +65,6 @@ public ChatMessage(string role, List content) Role = role; Content = content; } + + public void AddContent(ContentBase content) => Content.Add(content); } diff --git a/dotnet/src/AutoGen.Anthropic/DTO/Content.cs b/dotnet/src/AutoGen.Anthropic/DTO/Content.cs index dd2481bd58f3..ee7a745a1416 100644 --- a/dotnet/src/AutoGen.Anthropic/DTO/Content.cs +++ b/dotnet/src/AutoGen.Anthropic/DTO/Content.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Content.cs +using System.Text.Json.Nodes; using System.Text.Json.Serialization; namespace AutoGen.Anthropic.DTO; @@ -40,3 +41,30 @@ public class ImageSource [JsonPropertyName("data")] public string? Data { get; set; } } + +public class ToolUseContent : ContentBase +{ + [JsonPropertyName("type")] + public override string Type => "tool_use"; + + [JsonPropertyName("id")] + public string? Id { get; set; } + + [JsonPropertyName("name")] + public string? Name { get; set; } + + [JsonPropertyName("input")] + public JsonNode? Input { get; set; } +} + +public class ToolResultContent : ContentBase +{ + [JsonPropertyName("type")] + public override string Type => "tool_result"; + + [JsonPropertyName("tool_use_id")] + public string? Id { get; set; } + + [JsonPropertyName("content")] + public string? Content { get; set; } +} diff --git a/dotnet/src/AutoGen.Anthropic/DTO/Tool.cs b/dotnet/src/AutoGen.Anthropic/DTO/Tool.cs new file mode 100644 index 000000000000..41c20dc2a42d --- /dev/null +++ b/dotnet/src/AutoGen.Anthropic/DTO/Tool.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Tool.cs + +using System.Collections.Generic; +using System.Text.Json.Serialization; + +namespace AutoGen.Anthropic.DTO; + +public class Tool +{ + [JsonPropertyName("name")] + public string? Name { get; set; } + + [JsonPropertyName("description")] + public string? Description { get; set; } + + [JsonPropertyName("input_schema")] + public InputSchema? InputSchema { get; set; } +} + +public class InputSchema +{ + [JsonPropertyName("type")] + public string? Type { get; set; } + + [JsonPropertyName("properties")] + public Dictionary? Properties { get; set; } + + [JsonPropertyName("required")] + public List? Required { get; set; } +} + +public class SchemaProperty +{ + [JsonPropertyName("type")] + public string? Type { get; set; } + + [JsonPropertyName("description")] + public string? Description { get; set; } +} diff --git a/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs b/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs index bb2f5820f74c..1191f84d28ec 100644 --- a/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs +++ b/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Net.Http; using System.Runtime.CompilerServices; +using System.Text.Json.Nodes; using System.Threading; using System.Threading.Tasks; using AutoGen.Anthropic.DTO; @@ -71,16 +72,20 @@ private async Task> ProcessMessageAsync(IEnumerable ProcessTextMessage(textMessage, agent), ImageMessage imageMessage => - new MessageEnvelope(new ChatMessage("user", + (MessageEnvelope[])[new MessageEnvelope(new ChatMessage("user", new ContentBase[] { new ImageContent { Source = await ProcessImageSourceAsync(imageMessage) } } .ToList()), - from: agent.Name), + from: agent.Name)], MultiModalMessage multiModalMessage => await ProcessMultiModalMessageAsync(multiModalMessage, agent), - _ => message, + + ToolCallMessage toolCallMessage => ProcessToolCallMessage(toolCallMessage, agent), + ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage), + AggregateMessage toolCallAggregateMessage => ProcessToolCallAggregateMessage(toolCallAggregateMessage, agent), + _ => [message], }; - processedMessages.Add(processedMessage); + processedMessages.AddRange(processedMessage); } return processedMessages; @@ -93,15 +98,42 @@ private IMessage PostProcessMessage(ChatCompletionResponse response, IAgent from throw new ArgumentNullException(nameof(response.Content)); } - if (response.Content.Count != 1) + // When expecting a tool call, sometimes the response will contain two messages, one chat and one tool. + // The first message is typically a TextContent, of the LLM explaining what it is trying to do. + // The second message contains the tool call. + if (response.Content.Count > 1) { - throw new NotSupportedException($"{nameof(response.Content)} != 1"); + if (response.Content.Count == 2 && response.Content[0] is TextContent && + response.Content[1] is ToolUseContent toolUseContent) + { + return new ToolCallMessage(toolUseContent.Name ?? string.Empty, + toolUseContent.Input?.ToJsonString() ?? string.Empty, + from: from.Name); + } + + throw new NotSupportedException($"Expected {nameof(response.Content)} to have one output"); } - return new TextMessage(Role.Assistant, ((TextContent)response.Content[0]).Text ?? string.Empty, from: from.Name); + var content = response.Content[0]; + switch (content) + { + case TextContent textContent: + return new TextMessage(Role.Assistant, textContent.Text ?? string.Empty, from: from.Name); + + case ToolUseContent toolUseContent: + return new ToolCallMessage(toolUseContent.Name ?? string.Empty, + toolUseContent.Input?.ToJsonString() ?? string.Empty, + from: from.Name); + + case ImageContent: + throw new InvalidOperationException( + "Claude is an image understanding model only. It can interpret and analyze images, but it cannot generate, produce, edit, manipulate or create images"); + default: + throw new ArgumentOutOfRangeException(nameof(content)); + } } - private IMessage ProcessTextMessage(TextMessage textMessage, IAgent agent) + private IEnumerable> ProcessTextMessage(TextMessage textMessage, IAgent agent) { ChatMessage messages; @@ -139,10 +171,10 @@ private IMessage ProcessTextMessage(TextMessage textMessage, IAgent "user", textMessage.Content); } - return new MessageEnvelope(messages, from: textMessage.From); + return [new MessageEnvelope(messages, from: textMessage.From)]; } - private async Task ProcessMultiModalMessageAsync(MultiModalMessage multiModalMessage, IAgent agent) + private async Task> ProcessMultiModalMessageAsync(MultiModalMessage multiModalMessage, IAgent agent) { var content = new List(); foreach (var message in multiModalMessage.Content) @@ -158,8 +190,7 @@ private async Task ProcessMultiModalMessageAsync(MultiModalMessage mul } } - var chatMessage = new ChatMessage("user", content); - return MessageEnvelope.Create(chatMessage, agent.Name); + return [MessageEnvelope.Create(new ChatMessage("user", content), agent.Name)]; } private async Task ProcessImageSourceAsync(ImageMessage imageMessage) @@ -192,4 +223,52 @@ private async Task ProcessImageSourceAsync(ImageMessage imageMessag Data = Convert.ToBase64String(await response.Content.ReadAsByteArrayAsync()) }; } + + private IEnumerable ProcessToolCallMessage(ToolCallMessage toolCallMessage, IAgent agent) + { + var chatMessage = new ChatMessage("assistant", new List()); + foreach (var toolCall in toolCallMessage.ToolCalls) + { + chatMessage.AddContent(new ToolUseContent + { + Id = toolCall.ToolCallId, + Name = toolCall.FunctionName, + Input = JsonNode.Parse(toolCall.FunctionArguments) + }); + } + + return [MessageEnvelope.Create(chatMessage, toolCallMessage.From)]; + } + + private IEnumerable ProcessToolCallResultMessage(ToolCallResultMessage toolCallResultMessage) + { + var chatMessage = new ChatMessage("user", new List()); + foreach (var toolCall in toolCallResultMessage.ToolCalls) + { + chatMessage.AddContent(new ToolResultContent + { + Id = toolCall.ToolCallId ?? string.Empty, + Content = toolCall.Result, + }); + } + + return [MessageEnvelope.Create(chatMessage, toolCallResultMessage.From)]; + } + + private IEnumerable ProcessToolCallAggregateMessage(AggregateMessage aggregateMessage, IAgent agent) + { + if (aggregateMessage.From is { } from && from != agent.Name) + { + var contents = aggregateMessage.Message2.ToolCalls.Select(t => t.Result); + var messages = contents.Select(c => + new ChatMessage("assistant", c ?? throw new ArgumentNullException(nameof(c)))); + + return messages.Select(m => new MessageEnvelope(m, from: from)); + } + + var toolCallMessage = ProcessToolCallMessage(aggregateMessage.Message1, agent); + var toolCallResult = ProcessToolCallResultMessage(aggregateMessage.Message2); + + return toolCallMessage.Concat(toolCallResult); + } } diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs index d29025b44aff..2c6e86d270c5 100644 --- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs +++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs @@ -7,6 +7,7 @@ using AutoGen.Core; using AutoGen.Tests; using FluentAssertions; +using Xunit; namespace AutoGen.Anthropic.Tests; @@ -105,4 +106,101 @@ public async Task AnthropicAgentTestImageMessageAsync() reply.GetContent().Should().NotBeNullOrEmpty(); reply.From.Should().Be(agent.Name); } + + [ApiKeyFact("ANTHROPIC_API_KEY")] + public async Task AnthropicAgentTestToolAsync() + { + var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey); + + var function = new TypeSafeFunctionCall(); + var functionCallMiddleware = new FunctionCallMiddleware( + functions: new[] { function.WeatherReportFunctionContract }, + functionMap: new Dictionary>> + { + { function.WeatherReportFunctionContract.Name ?? string.Empty, function.WeatherReportWrapper }, + }); + + var agent = new AnthropicClientAgent( + client, + name: "AnthropicAgent", + AnthropicConstants.Claude3Haiku, + systemMessage: "You are an LLM that is specialized in finding the weather !", + tools: [AnthropicTestUtils.WeatherTool] + ) + .RegisterMessageConnector() + .RegisterStreamingMiddleware(functionCallMiddleware); + + var reply = await agent.SendAsync("What is the weather in Philadelphia?"); + reply.GetContent().Should().Be("Weather report for Philadelphia on today is sunny"); + } + + [ApiKeyFact("ANTHROPIC_API_KEY")] + public async Task AnthropicAgentFunctionCallMessageTest() + { + var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey); + var agent = new AnthropicClientAgent( + client, + name: "AnthropicAgent", + AnthropicConstants.Claude3Haiku, + systemMessage: "You are a helpful AI assistant.", + tools: [AnthropicTestUtils.WeatherTool] + ) + .RegisterMessageConnector(); + + var weatherFunctionArgumets = """ + { + "city": "Philadelphia", + "date": "6/14/2024" + } + """; + + var function = new TypeSafeFunctionCall(); + var functionCallResult = await function.GetWeatherReportWrapper(weatherFunctionArgumets); + var toolCall = new ToolCall(function.WeatherReportFunctionContract.Name!, weatherFunctionArgumets) + { + ToolCallId = "get_weather", + Result = functionCallResult, + }; + + IMessage[] chatHistory = [ + new TextMessage(Role.User, "what's the weather in Philadelphia?"), + new ToolCallMessage([toolCall], from: "assistant"), + new ToolCallResultMessage([toolCall], from: "user" ), + ]; + + var reply = await agent.SendAsync(chatHistory: chatHistory); + + reply.Should().BeOfType(); + reply.GetContent().Should().Be("The weather report for Philadelphia on 6/14/2024 is sunny."); + } + + [ApiKeyFact("ANTHROPIC_API_KEY")] + public async Task AnthropicAgentFunctionCallMiddlewareMessageTest() + { + var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey); + var function = new TypeSafeFunctionCall(); + var functionCallMiddleware = new FunctionCallMiddleware( + functions: [function.WeatherReportFunctionContract], + functionMap: new Dictionary>> + { + { function.WeatherReportFunctionContract.Name!, function.GetWeatherReportWrapper } + }); + + var functionCallAgent = new AnthropicClientAgent( + client, + name: "AnthropicAgent", + AnthropicConstants.Claude3Haiku, + systemMessage: "You are a helpful AI assistant.", + tools: [AnthropicTestUtils.WeatherTool] + ) + .RegisterMessageConnector() + .RegisterStreamingMiddleware(functionCallMiddleware); + + var question = new TextMessage(Role.User, "what's the weather in Philadelphia?"); + var reply = await functionCallAgent.SendAsync(question); + + var finalReply = await functionCallAgent.SendAsync(chatHistory: [question, reply]); + finalReply.Should().BeOfType(); + finalReply.GetContent()!.ToLower().Should().Contain("sunny"); + } } diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs index a0b1f60cfb95..ce07a25f7fc3 100644 --- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs +++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs @@ -1,5 +1,6 @@ using System.Text; using System.Text.Json; +using System.Text.Json.Nodes; using System.Text.Json.Serialization; using AutoGen.Anthropic.DTO; using AutoGen.Anthropic.Utils; @@ -108,6 +109,52 @@ public async Task AnthropicClientImageChatCompletionTestAsync() response.Usage.OutputTokens.Should().BeGreaterThan(0); } + [ApiKeyFact("ANTHROPIC_API_KEY")] + public async Task AnthropicClientTestToolsAsync() + { + var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey); + + var stockPriceTool = new Tool + { + Name = "get_stock_price", + Description = "Get the current stock price for a given ticker symbol.", + InputSchema = new InputSchema + { + Type = "object", + Properties = new Dictionary + { + { + "ticker", new SchemaProperty + { + Type = "string", + Description = "The stock ticker symbol, e.g. AAPL for Apple Inc." + } + } + }, + Required = new List { "ticker" } + } + }; + + var request = new ChatCompletionRequest(); + request.Model = AnthropicConstants.Claude3Haiku; + request.Stream = false; + request.MaxTokens = 100; + request.Messages = new List() { new("user", "Use the stock price tool to look for MSFT. Your response should only be the tool.") }; + request.Tools = new List() { stockPriceTool }; + + ChatCompletionResponse response = + await anthropicClient.CreateChatCompletionsAsync(request, CancellationToken.None); + + Assert.NotNull(response.Content); + Assert.True(response.Content.First() is ToolUseContent); + ToolUseContent toolUseContent = ((ToolUseContent)response.Content.First()); + Assert.Equal("get_stock_price", toolUseContent.Name); + Assert.NotNull(toolUseContent.Input); + Assert.True(toolUseContent.Input is JsonNode); + JsonNode jsonNode = toolUseContent.Input; + Assert.Equal("{\"ticker\":\"MSFT\"}", jsonNode.ToJsonString()); + } + private sealed class Person { [JsonPropertyName("name")] diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs index de630da6d87c..a0f821784a5d 100644 --- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs +++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // AnthropicTestUtils.cs +using AutoGen.Anthropic.DTO; + namespace AutoGen.Anthropic.Tests; public static class AnthropicTestUtils @@ -13,4 +15,25 @@ public static async Task Base64FromImageAsync(string imageName) return Convert.ToBase64String( await File.ReadAllBytesAsync(Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "images", imageName))); } + + public static Tool WeatherTool + { + get + { + return new Tool + { + Name = "WeatherReport", + Description = "Get the current weather", + InputSchema = new InputSchema + { + Type = "object", + Properties = new Dictionary + { + { "city", new SchemaProperty {Type = "string", Description = "The name of the city"} }, + { "date", new SchemaProperty {Type = "string", Description = "date of the day"} } + } + } + }; + } + } }