diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs
index 50bcd8a8048e..5dc2cbe701ad 100644
--- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs
@@ -2,6 +2,7 @@
// TypeSafeFunctionCallCodeSnippet.cs
using System.Text.Json;
+using System.Text.Json.Serialization;
using AutoGen.OpenAI.Extension;
using Azure.AI.OpenAI;
#region weather_report_using_statement
@@ -11,6 +12,15 @@
#region weather_report
public partial class TypeSafeFunctionCall
{
+ private class GetWeatherSchema
+ {
+ [JsonPropertyName(@"city")]
+ public string city { get; set; }
+
+ [JsonPropertyName(@"date")]
+ public string date { get; set; }
+ }
+
///
/// Get weather report
///
@@ -21,7 +31,20 @@ public async Task WeatherReport(string city, string date)
{
return $"Weather report for {city} on {date} is sunny";
}
+
+ public Task GetWeatherReportWrapper(string arguments)
+ {
+ var schema = JsonSerializer.Deserialize(
+ arguments,
+ new JsonSerializerOptions
+ {
+ PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
+ });
+
+ return WeatherReport(schema.city, schema.date);
+ }
}
+
#endregion weather_report
public partial class TypeSafeFunctionCall
diff --git a/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs b/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs
index e395bb4a225f..75940cc8741c 100644
--- a/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs
+++ b/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
@@ -16,6 +17,7 @@ public class AnthropicClientAgent : IStreamingAgent
private readonly string _systemMessage;
private readonly decimal _temperature;
private readonly int _maxTokens;
+ private readonly Tool[]? _tools;
public AnthropicClientAgent(
AnthropicClient anthropicClient,
@@ -23,7 +25,8 @@ public AnthropicClientAgent(
string modelName,
string systemMessage = "You are a helpful AI assistant",
decimal temperature = 0.7m,
- int maxTokens = 1024)
+ int maxTokens = 1024,
+ Tool[]? tools = null)
{
Name = name;
_anthropicClient = anthropicClient;
@@ -31,6 +34,7 @@ public AnthropicClientAgent(
_systemMessage = systemMessage;
_temperature = temperature;
_maxTokens = maxTokens;
+ _tools = tools;
}
public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null,
@@ -59,6 +63,7 @@ private ChatCompletionRequest CreateParameters(IEnumerable messages, G
Model = _modelName,
Stream = shouldStream,
Temperature = (decimal?)options?.Temperature ?? _temperature,
+ Tools = _tools?.ToList()
};
chatCompletionRequest.Messages = BuildMessages(messages);
diff --git a/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs b/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs
index 4cb8fdbb34e0..b41a761dc4d3 100644
--- a/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs
+++ b/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs
@@ -24,6 +24,10 @@ public override ContentBase Read(ref Utf8JsonReader reader, Type typeToConvert,
return JsonSerializer.Deserialize(text, options) ?? throw new InvalidOperationException();
case "image":
return JsonSerializer.Deserialize(text, options) ?? throw new InvalidOperationException();
+ case "tool_use":
+ return JsonSerializer.Deserialize(text, options) ?? throw new InvalidOperationException();
+ case "tool_result":
+ return JsonSerializer.Deserialize(text, options) ?? throw new InvalidOperationException();
}
}
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs
index 0c1749eaa989..1d45a62caf80 100644
--- a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs
+++ b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs
@@ -37,6 +37,9 @@ public class ChatCompletionRequest
[JsonPropertyName("top_p")]
public decimal? TopP { get; set; }
+ [JsonPropertyName("tools")]
+ public List? Tools { get; set; }
+
public ChatCompletionRequest()
{
Messages = new List();
@@ -62,4 +65,6 @@ public ChatMessage(string role, List content)
Role = role;
Content = content;
}
+
+ public void AddContent(ContentBase content) => Content.Add(content);
}
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/Content.cs b/dotnet/src/AutoGen.Anthropic/DTO/Content.cs
index dd2481bd58f3..ee7a745a1416 100644
--- a/dotnet/src/AutoGen.Anthropic/DTO/Content.cs
+++ b/dotnet/src/AutoGen.Anthropic/DTO/Content.cs
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Content.cs
+using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
namespace AutoGen.Anthropic.DTO;
@@ -40,3 +41,30 @@ public class ImageSource
[JsonPropertyName("data")]
public string? Data { get; set; }
}
+
+public class ToolUseContent : ContentBase
+{
+ [JsonPropertyName("type")]
+ public override string Type => "tool_use";
+
+ [JsonPropertyName("id")]
+ public string? Id { get; set; }
+
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ [JsonPropertyName("input")]
+ public JsonNode? Input { get; set; }
+}
+
+public class ToolResultContent : ContentBase
+{
+ [JsonPropertyName("type")]
+ public override string Type => "tool_result";
+
+ [JsonPropertyName("tool_use_id")]
+ public string? Id { get; set; }
+
+ [JsonPropertyName("content")]
+ public string? Content { get; set; }
+}
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/Tool.cs b/dotnet/src/AutoGen.Anthropic/DTO/Tool.cs
new file mode 100644
index 000000000000..41c20dc2a42d
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/DTO/Tool.cs
@@ -0,0 +1,40 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Tool.cs
+
+using System.Collections.Generic;
+using System.Text.Json.Serialization;
+
+namespace AutoGen.Anthropic.DTO;
+
+public class Tool
+{
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ [JsonPropertyName("description")]
+ public string? Description { get; set; }
+
+ [JsonPropertyName("input_schema")]
+ public InputSchema? InputSchema { get; set; }
+}
+
+public class InputSchema
+{
+ [JsonPropertyName("type")]
+ public string? Type { get; set; }
+
+ [JsonPropertyName("properties")]
+ public Dictionary? Properties { get; set; }
+
+ [JsonPropertyName("required")]
+ public List? Required { get; set; }
+}
+
+public class SchemaProperty
+{
+ [JsonPropertyName("type")]
+ public string? Type { get; set; }
+
+ [JsonPropertyName("description")]
+ public string? Description { get; set; }
+}
diff --git a/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs b/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs
index bb2f5820f74c..1191f84d28ec 100644
--- a/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs
+++ b/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs
@@ -6,6 +6,7 @@
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
+using System.Text.Json.Nodes;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.Anthropic.DTO;
@@ -71,16 +72,20 @@ private async Task> ProcessMessageAsync(IEnumerable ProcessTextMessage(textMessage, agent),
ImageMessage imageMessage =>
- new MessageEnvelope(new ChatMessage("user",
+ (MessageEnvelope[])[new MessageEnvelope(new ChatMessage("user",
new ContentBase[] { new ImageContent { Source = await ProcessImageSourceAsync(imageMessage) } }
.ToList()),
- from: agent.Name),
+ from: agent.Name)],
MultiModalMessage multiModalMessage => await ProcessMultiModalMessageAsync(multiModalMessage, agent),
- _ => message,
+
+ ToolCallMessage toolCallMessage => ProcessToolCallMessage(toolCallMessage, agent),
+ ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage),
+ AggregateMessage toolCallAggregateMessage => ProcessToolCallAggregateMessage(toolCallAggregateMessage, agent),
+ _ => [message],
};
- processedMessages.Add(processedMessage);
+ processedMessages.AddRange(processedMessage);
}
return processedMessages;
@@ -93,15 +98,42 @@ private IMessage PostProcessMessage(ChatCompletionResponse response, IAgent from
throw new ArgumentNullException(nameof(response.Content));
}
- if (response.Content.Count != 1)
+ // When expecting a tool call, sometimes the response will contain two messages, one chat and one tool.
+ // The first message is typically a TextContent, of the LLM explaining what it is trying to do.
+ // The second message contains the tool call.
+ if (response.Content.Count > 1)
{
- throw new NotSupportedException($"{nameof(response.Content)} != 1");
+ if (response.Content.Count == 2 && response.Content[0] is TextContent &&
+ response.Content[1] is ToolUseContent toolUseContent)
+ {
+ return new ToolCallMessage(toolUseContent.Name ?? string.Empty,
+ toolUseContent.Input?.ToJsonString() ?? string.Empty,
+ from: from.Name);
+ }
+
+ throw new NotSupportedException($"Expected {nameof(response.Content)} to have one output");
}
- return new TextMessage(Role.Assistant, ((TextContent)response.Content[0]).Text ?? string.Empty, from: from.Name);
+ var content = response.Content[0];
+ switch (content)
+ {
+ case TextContent textContent:
+ return new TextMessage(Role.Assistant, textContent.Text ?? string.Empty, from: from.Name);
+
+ case ToolUseContent toolUseContent:
+ return new ToolCallMessage(toolUseContent.Name ?? string.Empty,
+ toolUseContent.Input?.ToJsonString() ?? string.Empty,
+ from: from.Name);
+
+ case ImageContent:
+ throw new InvalidOperationException(
+ "Claude is an image understanding model only. It can interpret and analyze images, but it cannot generate, produce, edit, manipulate or create images");
+ default:
+ throw new ArgumentOutOfRangeException(nameof(content));
+ }
}
- private IMessage ProcessTextMessage(TextMessage textMessage, IAgent agent)
+ private IEnumerable> ProcessTextMessage(TextMessage textMessage, IAgent agent)
{
ChatMessage messages;
@@ -139,10 +171,10 @@ private IMessage ProcessTextMessage(TextMessage textMessage, IAgent
"user", textMessage.Content);
}
- return new MessageEnvelope(messages, from: textMessage.From);
+ return [new MessageEnvelope(messages, from: textMessage.From)];
}
- private async Task ProcessMultiModalMessageAsync(MultiModalMessage multiModalMessage, IAgent agent)
+ private async Task> ProcessMultiModalMessageAsync(MultiModalMessage multiModalMessage, IAgent agent)
{
var content = new List();
foreach (var message in multiModalMessage.Content)
@@ -158,8 +190,7 @@ private async Task ProcessMultiModalMessageAsync(MultiModalMessage mul
}
}
- var chatMessage = new ChatMessage("user", content);
- return MessageEnvelope.Create(chatMessage, agent.Name);
+ return [MessageEnvelope.Create(new ChatMessage("user", content), agent.Name)];
}
private async Task ProcessImageSourceAsync(ImageMessage imageMessage)
@@ -192,4 +223,52 @@ private async Task ProcessImageSourceAsync(ImageMessage imageMessag
Data = Convert.ToBase64String(await response.Content.ReadAsByteArrayAsync())
};
}
+
+ private IEnumerable ProcessToolCallMessage(ToolCallMessage toolCallMessage, IAgent agent)
+ {
+ var chatMessage = new ChatMessage("assistant", new List());
+ foreach (var toolCall in toolCallMessage.ToolCalls)
+ {
+ chatMessage.AddContent(new ToolUseContent
+ {
+ Id = toolCall.ToolCallId,
+ Name = toolCall.FunctionName,
+ Input = JsonNode.Parse(toolCall.FunctionArguments)
+ });
+ }
+
+ return [MessageEnvelope.Create(chatMessage, toolCallMessage.From)];
+ }
+
+ private IEnumerable ProcessToolCallResultMessage(ToolCallResultMessage toolCallResultMessage)
+ {
+ var chatMessage = new ChatMessage("user", new List());
+ foreach (var toolCall in toolCallResultMessage.ToolCalls)
+ {
+ chatMessage.AddContent(new ToolResultContent
+ {
+ Id = toolCall.ToolCallId ?? string.Empty,
+ Content = toolCall.Result,
+ });
+ }
+
+ return [MessageEnvelope.Create(chatMessage, toolCallResultMessage.From)];
+ }
+
+ private IEnumerable ProcessToolCallAggregateMessage(AggregateMessage aggregateMessage, IAgent agent)
+ {
+ if (aggregateMessage.From is { } from && from != agent.Name)
+ {
+ var contents = aggregateMessage.Message2.ToolCalls.Select(t => t.Result);
+ var messages = contents.Select(c =>
+ new ChatMessage("assistant", c ?? throw new ArgumentNullException(nameof(c))));
+
+ return messages.Select(m => new MessageEnvelope(m, from: from));
+ }
+
+ var toolCallMessage = ProcessToolCallMessage(aggregateMessage.Message1, agent);
+ var toolCallResult = ProcessToolCallResultMessage(aggregateMessage.Message2);
+
+ return toolCallMessage.Concat(toolCallResult);
+ }
}
diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs
index d29025b44aff..2c6e86d270c5 100644
--- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs
+++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs
@@ -7,6 +7,7 @@
using AutoGen.Core;
using AutoGen.Tests;
using FluentAssertions;
+using Xunit;
namespace AutoGen.Anthropic.Tests;
@@ -105,4 +106,101 @@ public async Task AnthropicAgentTestImageMessageAsync()
reply.GetContent().Should().NotBeNullOrEmpty();
reply.From.Should().Be(agent.Name);
}
+
+ [ApiKeyFact("ANTHROPIC_API_KEY")]
+ public async Task AnthropicAgentTestToolAsync()
+ {
+ var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
+
+ var function = new TypeSafeFunctionCall();
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: new[] { function.WeatherReportFunctionContract },
+ functionMap: new Dictionary>>
+ {
+ { function.WeatherReportFunctionContract.Name ?? string.Empty, function.WeatherReportWrapper },
+ });
+
+ var agent = new AnthropicClientAgent(
+ client,
+ name: "AnthropicAgent",
+ AnthropicConstants.Claude3Haiku,
+ systemMessage: "You are an LLM that is specialized in finding the weather !",
+ tools: [AnthropicTestUtils.WeatherTool]
+ )
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(functionCallMiddleware);
+
+ var reply = await agent.SendAsync("What is the weather in Philadelphia?");
+ reply.GetContent().Should().Be("Weather report for Philadelphia on today is sunny");
+ }
+
+ [ApiKeyFact("ANTHROPIC_API_KEY")]
+ public async Task AnthropicAgentFunctionCallMessageTest()
+ {
+ var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
+ var agent = new AnthropicClientAgent(
+ client,
+ name: "AnthropicAgent",
+ AnthropicConstants.Claude3Haiku,
+ systemMessage: "You are a helpful AI assistant.",
+ tools: [AnthropicTestUtils.WeatherTool]
+ )
+ .RegisterMessageConnector();
+
+ var weatherFunctionArgumets = """
+ {
+ "city": "Philadelphia",
+ "date": "6/14/2024"
+ }
+ """;
+
+ var function = new TypeSafeFunctionCall();
+ var functionCallResult = await function.GetWeatherReportWrapper(weatherFunctionArgumets);
+ var toolCall = new ToolCall(function.WeatherReportFunctionContract.Name!, weatherFunctionArgumets)
+ {
+ ToolCallId = "get_weather",
+ Result = functionCallResult,
+ };
+
+ IMessage[] chatHistory = [
+ new TextMessage(Role.User, "what's the weather in Philadelphia?"),
+ new ToolCallMessage([toolCall], from: "assistant"),
+ new ToolCallResultMessage([toolCall], from: "user" ),
+ ];
+
+ var reply = await agent.SendAsync(chatHistory: chatHistory);
+
+ reply.Should().BeOfType();
+ reply.GetContent().Should().Be("The weather report for Philadelphia on 6/14/2024 is sunny.");
+ }
+
+ [ApiKeyFact("ANTHROPIC_API_KEY")]
+ public async Task AnthropicAgentFunctionCallMiddlewareMessageTest()
+ {
+ var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
+ var function = new TypeSafeFunctionCall();
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [function.WeatherReportFunctionContract],
+ functionMap: new Dictionary>>
+ {
+ { function.WeatherReportFunctionContract.Name!, function.GetWeatherReportWrapper }
+ });
+
+ var functionCallAgent = new AnthropicClientAgent(
+ client,
+ name: "AnthropicAgent",
+ AnthropicConstants.Claude3Haiku,
+ systemMessage: "You are a helpful AI assistant.",
+ tools: [AnthropicTestUtils.WeatherTool]
+ )
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(functionCallMiddleware);
+
+ var question = new TextMessage(Role.User, "what's the weather in Philadelphia?");
+ var reply = await functionCallAgent.SendAsync(question);
+
+ var finalReply = await functionCallAgent.SendAsync(chatHistory: [question, reply]);
+ finalReply.Should().BeOfType();
+ finalReply.GetContent()!.ToLower().Should().Contain("sunny");
+ }
}
diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs
index a0b1f60cfb95..ce07a25f7fc3 100644
--- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs
+++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs
@@ -1,5 +1,6 @@
using System.Text;
using System.Text.Json;
+using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using AutoGen.Anthropic.DTO;
using AutoGen.Anthropic.Utils;
@@ -108,6 +109,52 @@ public async Task AnthropicClientImageChatCompletionTestAsync()
response.Usage.OutputTokens.Should().BeGreaterThan(0);
}
+ [ApiKeyFact("ANTHROPIC_API_KEY")]
+ public async Task AnthropicClientTestToolsAsync()
+ {
+ var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
+
+ var stockPriceTool = new Tool
+ {
+ Name = "get_stock_price",
+ Description = "Get the current stock price for a given ticker symbol.",
+ InputSchema = new InputSchema
+ {
+ Type = "object",
+ Properties = new Dictionary
+ {
+ {
+ "ticker", new SchemaProperty
+ {
+ Type = "string",
+ Description = "The stock ticker symbol, e.g. AAPL for Apple Inc."
+ }
+ }
+ },
+ Required = new List { "ticker" }
+ }
+ };
+
+ var request = new ChatCompletionRequest();
+ request.Model = AnthropicConstants.Claude3Haiku;
+ request.Stream = false;
+ request.MaxTokens = 100;
+ request.Messages = new List() { new("user", "Use the stock price tool to look for MSFT. Your response should only be the tool.") };
+ request.Tools = new List() { stockPriceTool };
+
+ ChatCompletionResponse response =
+ await anthropicClient.CreateChatCompletionsAsync(request, CancellationToken.None);
+
+ Assert.NotNull(response.Content);
+ Assert.True(response.Content.First() is ToolUseContent);
+ ToolUseContent toolUseContent = ((ToolUseContent)response.Content.First());
+ Assert.Equal("get_stock_price", toolUseContent.Name);
+ Assert.NotNull(toolUseContent.Input);
+ Assert.True(toolUseContent.Input is JsonNode);
+ JsonNode jsonNode = toolUseContent.Input;
+ Assert.Equal("{\"ticker\":\"MSFT\"}", jsonNode.ToJsonString());
+ }
+
private sealed class Person
{
[JsonPropertyName("name")]
diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs
index de630da6d87c..a0f821784a5d 100644
--- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs
+++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AnthropicTestUtils.cs
+using AutoGen.Anthropic.DTO;
+
namespace AutoGen.Anthropic.Tests;
public static class AnthropicTestUtils
@@ -13,4 +15,25 @@ public static async Task Base64FromImageAsync(string imageName)
return Convert.ToBase64String(
await File.ReadAllBytesAsync(Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "images", imageName)));
}
+
+ public static Tool WeatherTool
+ {
+ get
+ {
+ return new Tool
+ {
+ Name = "WeatherReport",
+ Description = "Get the current weather",
+ InputSchema = new InputSchema
+ {
+ Type = "object",
+ Properties = new Dictionary
+ {
+ { "city", new SchemaProperty {Type = "string", Description = "The name of the city"} },
+ { "date", new SchemaProperty {Type = "string", Description = "date of the day"} }
+ }
+ }
+ };
+ }
+ }
}