Skip to content

Commit

Permalink
Support tools for Anthropic models
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidLuong98 committed Jun 14, 2024
1 parent 6d4cf40 commit 39c0d9a
Show file tree
Hide file tree
Showing 10 changed files with 365 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// TypeSafeFunctionCallCodeSnippet.cs

using System.Text.Json;
using System.Text.Json.Serialization;
using AutoGen.OpenAI.Extension;
using Azure.AI.OpenAI;
#region weather_report_using_statement
Expand All @@ -11,6 +12,15 @@
#region weather_report
public partial class TypeSafeFunctionCall
{
/// <summary>
/// Shape of the JSON arguments consumed by <c>GetWeatherReportWrapper</c>.
/// </summary>
/// <remarks>
/// Both properties start out as an empty string so that a payload which omits a key
/// does not leave a non-nullable <see cref="string"/> property holding <c>null</c>.
/// </remarks>
private class GetWeatherSchema
{
    /// <summary>Value of the <c>city</c> JSON key.</summary>
    [JsonPropertyName(@"city")]
    public string city { get; set; } = string.Empty;

    /// <summary>Value of the <c>date</c> JSON key.</summary>
    [JsonPropertyName(@"date")]
    public string date { get; set; } = string.Empty;
}

/// <summary>
/// Get weather report
/// </summary>
Expand All @@ -21,7 +31,20 @@ public async Task<string> WeatherReport(string city, string date)
{
return $"Weather report for {city} on {date} is sunny";
}

/// <summary>
/// Deserializes JSON-encoded arguments and forwards them to <c>WeatherReport</c>.
/// </summary>
/// <param name="arguments">JSON text carrying the <c>city</c> and <c>date</c> keys.</param>
/// <returns>The task returned by <c>WeatherReport</c>.</returns>
/// <exception cref="System.ArgumentException">
/// <paramref name="arguments"/> deserializes to <c>null</c> (for example the JSON literal <c>null</c>).
/// </exception>
public Task<string> GetWeatherReportWrapper(string arguments)
{
    var serializerOptions = new JsonSerializerOptions
    {
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
    };

    // Deserialize yields null for a JSON "null" payload; fail with a descriptive error
    // instead of a NullReferenceException on the property reads below.
    var schema = JsonSerializer.Deserialize<GetWeatherSchema>(arguments, serializerOptions)
        ?? throw new System.ArgumentException(
            "Arguments must be a JSON object, but the payload deserialized to null.",
            nameof(arguments));

    return WeatherReport(schema.city, schema.date);
}
}

#endregion weather_report

public partial class TypeSafeFunctionCall
Expand Down
7 changes: 6 additions & 1 deletion dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -16,21 +17,24 @@ public class AnthropicClientAgent : IStreamingAgent
private readonly string _systemMessage;
private readonly decimal _temperature;
private readonly int _maxTokens;
private readonly Tool[]? _tools;

public AnthropicClientAgent(
AnthropicClient anthropicClient,
string name,
string modelName,
string systemMessage = "You are a helpful AI assistant",
decimal temperature = 0.7m,
int maxTokens = 1024)
int maxTokens = 1024,
Tool[]? tools = null)
{
Name = name;
_anthropicClient = anthropicClient;
_modelName = modelName;
_systemMessage = systemMessage;
_temperature = temperature;
_maxTokens = maxTokens;
_tools = tools;
}

public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null,
Expand Down Expand Up @@ -59,6 +63,7 @@ private ChatCompletionRequest CreateParameters(IEnumerable<IMessage> messages, G
Model = _modelName,
Stream = shouldStream,
Temperature = (decimal?)options?.Temperature ?? _temperature,
Tools = _tools?.ToList()
};

chatCompletionRequest.Messages = BuildMessages(messages);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ public override ContentBase Read(ref Utf8JsonReader reader, Type typeToConvert,
return JsonSerializer.Deserialize<TextContent>(text, options) ?? throw new InvalidOperationException();
case "image":
return JsonSerializer.Deserialize<ImageContent>(text, options) ?? throw new InvalidOperationException();
case "tool_use":
return JsonSerializer.Deserialize<ToolUseContent>(text, options) ?? throw new InvalidOperationException();
case "tool_result":
return JsonSerializer.Deserialize<ToolResultContent>(text, options) ?? throw new InvalidOperationException();
}
}

Expand Down
5 changes: 5 additions & 0 deletions dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ public class ChatCompletionRequest
[JsonPropertyName("top_p")]
public decimal? TopP { get; set; }

[JsonPropertyName("tools")]
public List<Tool>? Tools { get; set; }

public ChatCompletionRequest()
{
Messages = new List<ChatMessage>();
Expand All @@ -62,4 +65,6 @@ public ChatMessage(string role, List<ContentBase> content)
Role = role;
Content = content;
}

/// <summary>
/// Appends a content block to this message's <c>Content</c> list.
/// </summary>
/// <param name="content">The block to append.</param>
public void AddContent(ContentBase content)
{
    Content.Add(content);
}
}
28 changes: 28 additions & 0 deletions dotnet/src/AutoGen.Anthropic/DTO/Content.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Content.cs

using System.Text.Json.Nodes;
using System.Text.Json.Serialization;

namespace AutoGen.Anthropic.DTO;
Expand Down Expand Up @@ -40,3 +41,30 @@ public class ImageSource
[JsonPropertyName("data")]
public string? Data { get; set; }
}

/// <summary>
/// Content block whose <c>type</c> discriminator is <c>tool_use</c>.
/// </summary>
public class ToolUseContent : ContentBase
{
    /// <summary>Discriminator value; always <c>tool_use</c>.</summary>
    [JsonPropertyName("type")]
    public override string Type
    {
        get { return "tool_use"; }
    }

    /// <summary>Identifier of the tool call, serialized as <c>id</c>.</summary>
    [JsonPropertyName("id")]
    public string? Id { get; set; }

    /// <summary>Name of the tool being called, serialized as <c>name</c>.</summary>
    [JsonPropertyName("name")]
    public string? Name { get; set; }

    /// <summary>Arguments for the call as a JSON node, serialized as <c>input</c>.</summary>
    [JsonPropertyName("input")]
    public JsonNode? Input { get; set; }
}

/// <summary>
/// Content block whose <c>type</c> discriminator is <c>tool_result</c>.
/// </summary>
public class ToolResultContent : ContentBase
{
    /// <summary>Discriminator value; always <c>tool_result</c>.</summary>
    [JsonPropertyName("type")]
    public override string Type
    {
        get { return "tool_result"; }
    }

    /// <summary>
    /// Serialized as <c>tool_use_id</c>.
    /// NOTE(review): presumably the <c>id</c> of the tool-use block this result answers; confirm against the API.
    /// </summary>
    [JsonPropertyName("tool_use_id")]
    public string? Id { get; set; }

    /// <summary>Result text of the tool call, serialized as <c>content</c>.</summary>
    [JsonPropertyName("content")]
    public string? Content { get; set; }
}
40 changes: 40 additions & 0 deletions dotnet/src/AutoGen.Anthropic/DTO/Tool.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Tool.cs

using System.Collections.Generic;
using System.Text.Json.Serialization;

namespace AutoGen.Anthropic.DTO;

/// <summary>
/// Tool definition DTO with a name, a description and an input schema.
/// </summary>
public class Tool
{
/// <summary>Tool name, serialized as <c>name</c>.</summary>
[JsonPropertyName("name")]
public string? Name { get; set; }

/// <summary>Tool description, serialized as <c>description</c>.</summary>
[JsonPropertyName("description")]
public string? Description { get; set; }

/// <summary>Schema of the tool's input, serialized as <c>input_schema</c>.</summary>
[JsonPropertyName("input_schema")]
public InputSchema? InputSchema { get; set; }
}

/// <summary>
/// Describes the input a <see cref="Tool"/> accepts.
/// NOTE(review): looks like a JSON-Schema-style object description; confirm against the API reference.
/// </summary>
public class InputSchema
{
/// <summary>Schema type, serialized as <c>type</c>.</summary>
[JsonPropertyName("type")]
public string? Type { get; set; }

/// <summary>Per-property descriptions keyed by property name, serialized as <c>properties</c>.</summary>
[JsonPropertyName("properties")]
public Dictionary<string, SchemaProperty>? Properties { get; set; }

/// <summary>
/// Serialized as <c>required</c>.
/// NOTE(review): presumably the names of the mandatory properties; confirm.
/// </summary>
[JsonPropertyName("required")]
public List<string>? Required { get; set; }
}

/// <summary>
/// Description of a single entry in <see cref="InputSchema.Properties"/>.
/// </summary>
public class SchemaProperty
{
/// <summary>Property type, serialized as <c>type</c>.</summary>
[JsonPropertyName("type")]
public string? Type { get; set; }

/// <summary>Property description, serialized as <c>description</c>.</summary>
[JsonPropertyName("description")]
public string? Description { get; set; }
}
103 changes: 91 additions & 12 deletions dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Text.Json.Nodes;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.Anthropic.DTO;
Expand Down Expand Up @@ -71,16 +72,20 @@ private async Task<IEnumerable<IMessage>> ProcessMessageAsync(IEnumerable<IMessa
TextMessage textMessage => ProcessTextMessage(textMessage, agent),

ImageMessage imageMessage =>
new MessageEnvelope<ChatMessage>(new ChatMessage("user",
(MessageEnvelope<ChatMessage>[])[new MessageEnvelope<ChatMessage>(new ChatMessage("user",
new ContentBase[] { new ImageContent { Source = await ProcessImageSourceAsync(imageMessage) } }
.ToList()),
from: agent.Name),
from: agent.Name)],

MultiModalMessage multiModalMessage => await ProcessMultiModalMessageAsync(multiModalMessage, agent),
_ => message,

ToolCallMessage toolCallMessage => ProcessToolCallMessage(toolCallMessage, agent),
ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage),
AggregateMessage<ToolCallMessage, ToolCallResultMessage> toolCallAggregateMessage => ProcessToolCallAggregateMessage(toolCallAggregateMessage, agent),
_ => [message],
};

processedMessages.Add(processedMessage);
processedMessages.AddRange(processedMessage);
}

return processedMessages;
Expand All @@ -93,15 +98,42 @@ private IMessage PostProcessMessage(ChatCompletionResponse response, IAgent from
throw new ArgumentNullException(nameof(response.Content));
}

if (response.Content.Count != 1)
// When expecting a tool call, sometimes the response will contain two messages, one chat and one tool.
// The first message is typically a TextContent, of the LLM explaining what it is trying to do.
// The second message contains the tool call.
if (response.Content.Count > 1)
{
throw new NotSupportedException($"{nameof(response.Content)} != 1");
if (response.Content.Count == 2 && response.Content[0] is TextContent &&
response.Content[1] is ToolUseContent toolUseContent)
{
return new ToolCallMessage(toolUseContent.Name ?? string.Empty,
toolUseContent.Input?.ToJsonString() ?? string.Empty,
from: from.Name);
}

throw new NotSupportedException($"Expected {nameof(response.Content)} to have one output");
}

return new TextMessage(Role.Assistant, ((TextContent)response.Content[0]).Text ?? string.Empty, from: from.Name);
var content = response.Content[0];
switch (content)
{
case TextContent textContent:
return new TextMessage(Role.Assistant, textContent.Text ?? string.Empty, from: from.Name);

case ToolUseContent toolUseContent:
return new ToolCallMessage(toolUseContent.Name ?? string.Empty,
toolUseContent.Input?.ToJsonString() ?? string.Empty,
from: from.Name);

case ImageContent:
throw new InvalidOperationException(
"Claude is an image understanding model only. It can interpret and analyze images, but it cannot generate, produce, edit, manipulate or create images");
default:
throw new ArgumentOutOfRangeException(nameof(content));
}
}

private IMessage<ChatMessage> ProcessTextMessage(TextMessage textMessage, IAgent agent)
private IEnumerable<IMessage<ChatMessage>> ProcessTextMessage(TextMessage textMessage, IAgent agent)
{
ChatMessage messages;

Expand Down Expand Up @@ -139,10 +171,10 @@ private IMessage<ChatMessage> ProcessTextMessage(TextMessage textMessage, IAgent
"user", textMessage.Content);
}

return new MessageEnvelope<ChatMessage>(messages, from: textMessage.From);
return [new MessageEnvelope<ChatMessage>(messages, from: textMessage.From)];
}

private async Task<IMessage> ProcessMultiModalMessageAsync(MultiModalMessage multiModalMessage, IAgent agent)
private async Task<IEnumerable<IMessage>> ProcessMultiModalMessageAsync(MultiModalMessage multiModalMessage, IAgent agent)
{
var content = new List<ContentBase>();
foreach (var message in multiModalMessage.Content)
Expand All @@ -158,8 +190,7 @@ private async Task<IMessage> ProcessMultiModalMessageAsync(MultiModalMessage mul
}
}

var chatMessage = new ChatMessage("user", content);
return MessageEnvelope.Create(chatMessage, agent.Name);
return [MessageEnvelope.Create(new ChatMessage("user", content), agent.Name)];
}

private async Task<ImageSource> ProcessImageSourceAsync(ImageMessage imageMessage)
Expand Down Expand Up @@ -192,4 +223,52 @@ private async Task<ImageSource> ProcessImageSourceAsync(ImageMessage imageMessag
Data = Convert.ToBase64String(await response.Content.ReadAsByteArrayAsync())
};
}

/// <summary>
/// Converts a <see cref="ToolCallMessage"/> into one assistant <see cref="ChatMessage"/>
/// that holds a <see cref="ToolUseContent"/> block per tool call.
/// </summary>
/// <param name="toolCallMessage">The tool call message to convert.</param>
/// <param name="agent">Not read by this method.</param>
/// <returns>A one-element sequence wrapping the assistant chat message.</returns>
/// <exception cref="System.Text.Json.JsonException">
/// A tool call carries non-blank arguments that are not valid JSON.
/// </exception>
private IEnumerable<IMessage> ProcessToolCallMessage(ToolCallMessage toolCallMessage, IAgent agent)
{
    var chatMessage = new ChatMessage("assistant", new List<ContentBase>());
    foreach (var toolCall in toolCallMessage.ToolCalls)
    {
        // PostProcessMessage maps a missing tool input to string.Empty, and JsonNode.Parse
        // throws on empty text, so blank arguments become an empty JSON object instead.
        JsonNode? input = string.IsNullOrWhiteSpace(toolCall.FunctionArguments)
            ? new JsonObject()
            : JsonNode.Parse(toolCall.FunctionArguments);

        chatMessage.AddContent(new ToolUseContent
        {
            Id = toolCall.ToolCallId,
            Name = toolCall.FunctionName,
            Input = input,
        });
    }

    return [MessageEnvelope.Create(chatMessage, toolCallMessage.From)];
}

/// <summary>
/// Converts a <see cref="ToolCallResultMessage"/> into one user <see cref="ChatMessage"/>
/// that holds a <see cref="ToolResultContent"/> block per tool call.
/// </summary>
/// <param name="toolCallResultMessage">The tool call results to convert.</param>
/// <returns>A one-element sequence wrapping the user chat message.</returns>
private IEnumerable<IMessage> ProcessToolCallResultMessage(ToolCallResultMessage toolCallResultMessage)
{
    // Project every tool call onto a result block; a missing call id becomes an empty string.
    var resultBlocks = toolCallResultMessage.ToolCalls
        .Select(call => (ContentBase)new ToolResultContent
        {
            Id = call.ToolCallId ?? string.Empty,
            Content = call.Result,
        })
        .ToList();

    var userMessage = new ChatMessage("user", resultBlocks);
    return [MessageEnvelope.Create(userMessage, toolCallResultMessage.From)];
}

/// <summary>
/// Converts an aggregate of a tool call and its result into chat messages.
/// </summary>
/// <remarks>
/// When the aggregate has a sender other than <paramref name="agent"/>, each tool result is
/// emitted as an assistant message attributed to that sender. Otherwise the tool call and the
/// tool result are converted separately and returned one after the other.
/// </remarks>
/// <param name="aggregateMessage">The tool call paired with its result.</param>
/// <param name="agent">The agent whose name is compared with the aggregate's sender.</param>
/// <returns>The converted messages.</returns>
private IEnumerable<IMessage> ProcessToolCallAggregateMessage(AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage, IAgent agent)
{
    if (aggregateMessage.From is { } sender && sender != agent.Name)
    {
        // Lazily evaluated: a null result raises ArgumentNullException on enumeration.
        return aggregateMessage.Message2.ToolCalls
            .Select(t => t.Result)
            .Select(c => new MessageEnvelope<ChatMessage>(
                new ChatMessage("assistant", c ?? throw new ArgumentNullException(nameof(c))),
                from: sender));
    }

    return ProcessToolCallMessage(aggregateMessage.Message1, agent)
        .Concat(ProcessToolCallResultMessage(aggregateMessage.Message2));
}
}
Loading

0 comments on commit 39c0d9a

Please sign in to comment.