Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[.Net] Support tools for AnthropicClient and AnthropicAgent #2944

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
<ProjectReference Include="..\..\src\AutoGen.DotnetInteractive\AutoGen.DotnetInteractive.csproj" />
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
<ProjectReference Include="..\..\src\AutoGen\AutoGen.csproj" />
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

namespace AutoGen.Anthropic.Samples;

public static class AnthropicSamples
public static class Create_Anthropic_Agent
{
public static async Task RunAsync()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Single_Anthropic_Tool.cs

using AutoGen.Anthropic.DTO;
using AutoGen.Anthropic.Extensions;
using AutoGen.Anthropic.Utils;
using AutoGen.Core;
using FluentAssertions;

namespace AutoGen.Anthropic.Samples;

#region WeatherFunction

public partial class WeatherFunction
{
/// <summary>
/// Gets the weather based on the location and the unit
/// </summary>
/// <param name="location"></param>
/// <param name="unit"></param>
/// <returns></returns>
[Function]
public async Task<string> GetWeather(string location, string unit)
{
// dummy implementation
return $"The weather in {location} is currently sunny with a tempature of {unit} (s)";
}
}
#endregion
public class Create_Anthropic_Agent_With_Tool
{
public static async Task RunAsync()
{
#region define_tool
var tool = new Tool
{
Name = "GetWeather",
Description = "Get the current weather in a given location",
InputSchema = new InputSchema
{
Type = "object",
Properties = new Dictionary<string, SchemaProperty>
{
{ "location", new SchemaProperty { Type = "string", Description = "The city and state, e.g. San Francisco, CA" } },
{ "unit", new SchemaProperty { Type = "string", Description = "The unit of temperature, either \"celsius\" or \"fahrenheit\"" } }
},
Required = new List<string> { "location" }
}
};

var weatherFunction = new WeatherFunction();
var functionMiddleware = new FunctionCallMiddleware(
functions: [
weatherFunction.GetWeatherFunctionContract,
],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ weatherFunction.GetWeatherFunctionContract.Name!, weatherFunction.GetWeatherWrapper },
});

#endregion

#region create_anthropic_agent

var apiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") ??
throw new Exception("Missing ANTHROPIC_API_KEY environment variable.");

var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, apiKey);
var agent = new AnthropicClientAgent(anthropicClient, "assistant", AnthropicConstants.Claude3Haiku,
tools: [tool]); // Define tools for AnthropicClientAgent
#endregion

#region register_middleware

var agentWithConnector = agent
.RegisterMessageConnector()
.RegisterPrintMessage()
.RegisterStreamingMiddleware(functionMiddleware);
#endregion register_middleware

#region single_turn
var question = new TextMessage(Role.Assistant,
"What is the weather like in San Francisco?",
from: "user");
var functionCallReply = await agentWithConnector.SendAsync(question);
#endregion

#region Single_turn_verify_reply
functionCallReply.Should().BeOfType<ToolCallAggregateMessage>();
#endregion Single_turn_verify_reply

#region Multi_turn
var finalReply = await agentWithConnector.SendAsync(chatHistory: [question, functionCallReply]);
#endregion Multi_turn

#region Multi_turn_verify_reply
finalReply.Should().BeOfType<TextMessage>();
#endregion Multi_turn_verify_reply
}
}
2 changes: 1 addition & 1 deletion dotnet/sample/AutoGen.Anthropic.Samples/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ internal static class Program
{
public static async Task Main(string[] args)
{
await AnthropicSamples.RunAsync();
await Create_Anthropic_Agent_With_Tool.RunAsync();
}
}
11 changes: 10 additions & 1 deletion dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -16,21 +17,27 @@ public class AnthropicClientAgent : IStreamingAgent
private readonly string _systemMessage;
private readonly decimal _temperature;
private readonly int _maxTokens;
private readonly Tool[]? _tools;
private readonly ToolChoice? _toolChoice;

public AnthropicClientAgent(
AnthropicClient anthropicClient,
string name,
string modelName,
string systemMessage = "You are a helpful AI assistant",
decimal temperature = 0.7m,
int maxTokens = 1024)
int maxTokens = 1024,
Tool[]? tools = null,
ToolChoice? toolChoice = null)
{
Name = name;
_anthropicClient = anthropicClient;
_modelName = modelName;
_systemMessage = systemMessage;
_temperature = temperature;
_maxTokens = maxTokens;
_tools = tools;
_toolChoice = toolChoice;
}

public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null,
Expand Down Expand Up @@ -59,6 +66,8 @@ private ChatCompletionRequest CreateParameters(IEnumerable<IMessage> messages, G
Model = _modelName,
Stream = shouldStream,
Temperature = (decimal?)options?.Temperature ?? _temperature,
Tools = _tools?.ToList(),
ToolChoice = _toolChoice ?? ToolChoice.Auto
};

chatCompletionRequest.Messages = BuildMessages(messages);
Expand Down
107 changes: 94 additions & 13 deletions dotnet/src/AutoGen.Anthropic/AnthropicClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ public sealed class AnthropicClient : IDisposable
private static readonly JsonSerializerOptions JsonSerializerOptions = new()
{
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
Converters = { new ContentBaseConverter() }
Converters = { new ContentBaseConverter(), new JsonPropertyNameEnumConverter<ToolChoiceType>() }
};

private static readonly JsonSerializerOptions JsonDeserializerOptions = new()
{
Converters = { new ContentBaseConverter() }
Converters = { new ContentBaseConverter(), new JsonPropertyNameEnumConverter<ToolChoiceType>() }
};

public AnthropicClient(HttpClient httpClient, string baseUrl, string apiKey)
Expand Down Expand Up @@ -61,33 +61,75 @@ public async IAsyncEnumerable<ChatCompletionResponse> StreamingChatCompletionsAs
using var reader = new StreamReader(await httpResponseMessage.Content.ReadAsStreamAsync());

var currentEvent = new SseEvent();

while (await reader.ReadLineAsync() is { } line)
{
if (!string.IsNullOrEmpty(line))
{
currentEvent.Data = line.Substring("data:".Length).Trim();
if (line.StartsWith("event:"))
LittleLittleCloud marked this conversation as resolved.
Show resolved Hide resolved
{
currentEvent.EventType = line.Substring("event:".Length).Trim();
}
else if (line.StartsWith("data:"))
{
currentEvent.Data = line.Substring("data:".Length).Trim();
}
}
else
else // an empty line indicates the end of an event
{
if (currentEvent.Data == "[DONE]")
continue;
if (currentEvent.EventType == "content_block_start" && !string.IsNullOrEmpty(currentEvent.Data))
{
var dataBlock = JsonSerializer.Deserialize<DataBlock>(currentEvent.Data!);
if (dataBlock != null && dataBlock.ContentBlock?.Type == "tool_use")
{
currentEvent.ContentBlock = dataBlock.ContentBlock;
}
}

if (currentEvent.Data != null)
if (currentEvent.EventType is "message_start" or "content_block_delta" or "message_delta" && currentEvent.Data != null)
{
yield return await JsonSerializer.DeserializeAsync<ChatCompletionResponse>(
var res = await JsonSerializer.DeserializeAsync<ChatCompletionResponse>(
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)),
cancellationToken: cancellationToken) ?? throw new Exception("Failed to deserialize response");
cancellationToken: cancellationToken);

if (res == null)
{
throw new Exception("Failed to deserialize response");
}

if (res.Delta?.Type == "input_json_delta" && !string.IsNullOrEmpty(res.Delta.PartialJson) &&
currentEvent.ContentBlock != null)
{
currentEvent.ContentBlock.AppendDeltaParameters(res.Delta.PartialJson!);
}
else if (res.Delta is { StopReason: "tool_use" } && currentEvent.ContentBlock != null)
{
if (res.Content == null)
{
res.Content = [currentEvent.ContentBlock.CreateToolUseContent()];
}
else
{
res.Content.Add(currentEvent.ContentBlock.CreateToolUseContent());
}

currentEvent = new SseEvent();
}

yield return res;
}
else if (currentEvent.Data != null)
else if (currentEvent.EventType == "error" && currentEvent.Data != null)
{
var res = await JsonSerializer.DeserializeAsync<ErrorResponse>(
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), cancellationToken: cancellationToken);

throw new Exception(res?.Error?.Message);
}

// Reset the current event for the next one
currentEvent = new SseEvent();
if (currentEvent.ContentBlock == null)
{
currentEvent = new SseEvent();
}
}
}
}
Expand All @@ -113,11 +155,50 @@ public void Dispose()

private struct SseEvent
{
public string EventType { get; set; }
public string? Data { get; set; }
public ContentBlock? ContentBlock { get; set; }

public SseEvent(string? data = null)
public SseEvent(string eventType, string? data = null, ContentBlock? contentBlock = null)
{
EventType = eventType;
Data = data;
ContentBlock = contentBlock;
}
}

private class ContentBlock
{
[JsonPropertyName("type")]
public string? Type { get; set; }

[JsonPropertyName("id")]
public string? Id { get; set; }

[JsonPropertyName("name")]
public string? Name { get; set; }

[JsonPropertyName("input")]
public object? Input { get; set; }

public string? parameters { get; set; }

public void AppendDeltaParameters(string deltaParams)
{
StringBuilder sb = new StringBuilder(parameters);
sb.Append(deltaParams);
parameters = sb.ToString();
}

public ToolUseContent CreateToolUseContent()
{
return new ToolUseContent { Id = Id, Name = Name, Input = parameters };
}
}

private class DataBlock
{
[JsonPropertyName("content_block")]
public ContentBlock? ContentBlock { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ public override ContentBase Read(ref Utf8JsonReader reader, Type typeToConvert,
return JsonSerializer.Deserialize<TextContent>(text, options) ?? throw new InvalidOperationException();
case "image":
return JsonSerializer.Deserialize<ImageContent>(text, options) ?? throw new InvalidOperationException();
case "tool_use":
return JsonSerializer.Deserialize<ToolUseContent>(text, options) ?? throw new InvalidOperationException();
case "tool_result":
return JsonSerializer.Deserialize<ToolResultContent>(text, options) ?? throw new InvalidOperationException();
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// JsonPropertyNameEnumCoverter.cs

using System;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace AutoGen.Anthropic.Converters;

internal class JsonPropertyNameEnumConverter<T> : JsonConverter<T> where T : struct, Enum
{
public override T Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
string value = reader.GetString() ?? throw new JsonException("Value was null.");

foreach (var field in typeToConvert.GetFields())
{
var attribute = field.GetCustomAttribute<JsonPropertyNameAttribute>();
if (attribute?.Name == value)
{
return (T)Enum.Parse(typeToConvert, field.Name);
}
}

throw new JsonException($"Unable to convert \"{value}\" to enum {typeToConvert}.");
}

public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options)
{
var field = value.GetType().GetField(value.ToString());
var attribute = field.GetCustomAttribute<JsonPropertyNameAttribute>();

if (attribute != null)
{
writer.WriteStringValue(attribute.Name);
}
else
{
writer.WriteStringValue(value.ToString());
}
}
}

8 changes: 8 additions & 0 deletions dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ public class ChatCompletionRequest
[JsonPropertyName("top_p")]
public decimal? TopP { get; set; }

[JsonPropertyName("tools")]
public List<Tool>? Tools { get; set; }

[JsonPropertyName("tool_choice")]
public ToolChoice? ToolChoice { get; set; }

public ChatCompletionRequest()
{
Messages = new List<ChatMessage>();
Expand All @@ -62,4 +68,6 @@ public ChatMessage(string role, List<ContentBase> content)
Role = role;
Content = content;
}

public void AddContent(ContentBase content) => Content.Add(content);
}
Loading
Loading