Skip to content

Commit 192f3eb

Browse files
committed
Squash commits : support anthropic tools
1 parent 03259b2 commit 192f3eb

16 files changed

+573
-29
lines changed

dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
<ProjectReference Include="..\..\src\AutoGen.DotnetInteractive\AutoGen.DotnetInteractive.csproj" />
1414
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
1515
<ProjectReference Include="..\..\src\AutoGen\AutoGen.csproj" />
16+
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
1617
</ItemGroup>
1718

1819
</Project>

dotnet/sample/AutoGen.Anthropic.Samples/AnthropicSamples.cs renamed to dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
namespace AutoGen.Anthropic.Samples;
99

10-
public static class AnthropicSamples
10+
public static class Create_Anthropic_Agent
1111
{
1212
public static async Task RunAsync()
1313
{
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Single_Anthropic_Tool.cs
3+
4+
using AutoGen.Anthropic.DTO;
5+
using AutoGen.Anthropic.Extensions;
6+
using AutoGen.Anthropic.Utils;
7+
using AutoGen.Core;
8+
using FluentAssertions;
9+
10+
namespace AutoGen.Anthropic.Samples;
11+
12+
#region WeatherFunction
13+
14+
public partial class WeatherFunction
15+
{
16+
/// <summary>
17+
/// Gets the weather based on the location and the unit
18+
/// </summary>
19+
/// <param name="location"></param>
20+
/// <param name="unit"></param>
21+
/// <returns></returns>
22+
[Function]
23+
public async Task<string> GetWeather(string location, string unit)
24+
{
25+
// dummy implementation
26+
return $"The weather in {location} is currently sunny with a tempature of {unit} (s)";
27+
}
28+
}
29+
#endregion
30+
public class Create_Anthropic_Agent_With_Tool
31+
{
32+
public static async Task RunAsync()
33+
{
34+
#region define_tool
35+
var tool = new Tool
36+
{
37+
Name = "GetWeather",
38+
Description = "Get the current weather in a given location",
39+
InputSchema = new InputSchema
40+
{
41+
Type = "object",
42+
Properties = new Dictionary<string, SchemaProperty>
43+
{
44+
{ "location", new SchemaProperty { Type = "string", Description = "The city and state, e.g. San Francisco, CA" } },
45+
{ "unit", new SchemaProperty { Type = "string", Description = "The unit of temperature, either \"celsius\" or \"fahrenheit\"" } }
46+
},
47+
Required = new List<string> { "location" }
48+
}
49+
};
50+
51+
var weatherFunction = new WeatherFunction();
52+
var functionMiddleware = new FunctionCallMiddleware(
53+
functions: [
54+
weatherFunction.GetWeatherFunctionContract,
55+
],
56+
functionMap: new Dictionary<string, Func<string, Task<string>>>
57+
{
58+
{ weatherFunction.GetWeatherFunctionContract.Name!, weatherFunction.GetWeatherWrapper },
59+
});
60+
61+
#endregion
62+
63+
#region create_anthropic_agent
64+
65+
var apiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") ??
66+
throw new Exception("Missing ANTHROPIC_API_KEY environment variable.");
67+
68+
var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, apiKey);
69+
var agent = new AnthropicClientAgent(anthropicClient, "assistant", AnthropicConstants.Claude3Haiku,
70+
tools: [tool]); // Define tools for AnthropicClientAgent
71+
#endregion
72+
73+
#region register_middleware
74+
75+
var agentWithConnector = agent
76+
.RegisterMessageConnector()
77+
.RegisterPrintMessage()
78+
.RegisterStreamingMiddleware(functionMiddleware);
79+
#endregion register_middleware
80+
81+
#region single_turn
82+
var question = new TextMessage(Role.Assistant,
83+
"What is the weather like in San Francisco?",
84+
from: "user");
85+
var functionCallReply = await agentWithConnector.SendAsync(question);
86+
#endregion
87+
88+
#region Single_turn_verify_reply
89+
functionCallReply.Should().BeOfType<ToolCallAggregateMessage>();
90+
#endregion Single_turn_verify_reply
91+
92+
#region Multi_turn
93+
var finalReply = await agentWithConnector.SendAsync(chatHistory: [question, functionCallReply]);
94+
#endregion Multi_turn
95+
96+
#region Multi_turn_verify_reply
97+
finalReply.Should().BeOfType<TextMessage>();
98+
#endregion Multi_turn_verify_reply
99+
}
100+
}

dotnet/sample/AutoGen.Anthropic.Samples/Program.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ internal static class Program
77
{
88
public static async Task Main(string[] args)
99
{
10-
await AnthropicSamples.RunAsync();
10+
await Create_Anthropic_Agent_With_Tool.RunAsync();
1111
}
1212
}

dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs

+23
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// TypeSafeFunctionCallCodeSnippet.cs
33

44
using System.Text.Json;
5+
using System.Text.Json.Serialization;
56
using AutoGen.OpenAI.Extension;
67
using Azure.AI.OpenAI;
78
#region weather_report_using_statement
@@ -11,6 +12,15 @@
1112
#region weather_report
1213
public partial class TypeSafeFunctionCall
1314
{
15+
private class GetWeatherSchema
16+
{
17+
[JsonPropertyName(@"city")]
18+
public string city { get; set; }
19+
20+
[JsonPropertyName(@"date")]
21+
public string date { get; set; }
22+
}
23+
1424
/// <summary>
1525
/// Get weather report
1626
/// </summary>
@@ -21,7 +31,20 @@ public async Task<string> WeatherReport(string city, string date)
2131
{
2232
return $"Weather report for {city} on {date} is sunny";
2333
}
34+
35+
public Task<string> GetWeatherReportWrapper(string arguments)
36+
{
37+
var schema = JsonSerializer.Deserialize<GetWeatherSchema>(
38+
arguments,
39+
new JsonSerializerOptions
40+
{
41+
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
42+
});
43+
44+
return WeatherReport(schema.city, schema.date);
45+
}
2446
}
47+
2548
#endregion weather_report
2649

2750
public partial class TypeSafeFunctionCall

dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Runtime.CompilerServices;
45
using System.Threading;
56
using System.Threading.Tasks;
@@ -16,21 +17,24 @@ public class AnthropicClientAgent : IStreamingAgent
1617
private readonly string _systemMessage;
1718
private readonly decimal _temperature;
1819
private readonly int _maxTokens;
20+
private readonly Tool[]? _tools;
1921

2022
public AnthropicClientAgent(
2123
AnthropicClient anthropicClient,
2224
string name,
2325
string modelName,
2426
string systemMessage = "You are a helpful AI assistant",
2527
decimal temperature = 0.7m,
26-
int maxTokens = 1024)
28+
int maxTokens = 1024,
29+
Tool[]? tools = null)
2730
{
2831
Name = name;
2932
_anthropicClient = anthropicClient;
3033
_modelName = modelName;
3134
_systemMessage = systemMessage;
3235
_temperature = temperature;
3336
_maxTokens = maxTokens;
37+
_tools = tools;
3438
}
3539

3640
public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null,
@@ -59,6 +63,7 @@ private ChatCompletionRequest CreateParameters(IEnumerable<IMessage> messages, G
5963
Model = _modelName,
6064
Stream = shouldStream,
6165
Temperature = (decimal?)options?.Temperature ?? _temperature,
66+
Tools = _tools?.ToList()
6267
};
6368

6469
chatCompletionRequest.Messages = BuildMessages(messages);

dotnet/src/AutoGen.Anthropic/AnthropicClient.cs

+92-11
Original file line numberDiff line numberDiff line change
@@ -61,33 +61,75 @@ public async IAsyncEnumerable<ChatCompletionResponse> StreamingChatCompletionsAs
6161
using var reader = new StreamReader(await httpResponseMessage.Content.ReadAsStreamAsync());
6262

6363
var currentEvent = new SseEvent();
64+
6465
while (await reader.ReadLineAsync() is { } line)
6566
{
6667
if (!string.IsNullOrEmpty(line))
6768
{
68-
currentEvent.Data = line.Substring("data:".Length).Trim();
69+
if (line.StartsWith("event:"))
70+
{
71+
currentEvent.EventType = line.Substring("event:".Length).Trim();
72+
}
73+
else if (line.StartsWith("data:"))
74+
{
75+
currentEvent.Data = line.Substring("data:".Length).Trim();
76+
}
6977
}
70-
else
78+
else // an empty line indicates the end of an event
7179
{
72-
if (currentEvent.Data == "[DONE]")
73-
continue;
80+
if (currentEvent.EventType == "content_block_start" && !string.IsNullOrEmpty(currentEvent.Data))
81+
{
82+
var dataBlock = JsonSerializer.Deserialize<DataBlock>(currentEvent.Data!);
83+
if (dataBlock != null && dataBlock.ContentBlock?.Type == "tool_use")
84+
{
85+
currentEvent.ContentBlock = dataBlock.ContentBlock;
86+
}
87+
}
7488

75-
if (currentEvent.Data != null)
89+
if (currentEvent.EventType is "message_start" or "content_block_delta" or "message_delta" && currentEvent.Data != null)
7690
{
77-
yield return await JsonSerializer.DeserializeAsync<ChatCompletionResponse>(
91+
var res = await JsonSerializer.DeserializeAsync<ChatCompletionResponse>(
7892
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)),
79-
cancellationToken: cancellationToken) ?? throw new Exception("Failed to deserialize response");
93+
cancellationToken: cancellationToken);
94+
95+
if (res == null)
96+
{
97+
throw new Exception("Failed to deserialize response");
98+
}
99+
100+
if (res.Delta?.Type == "input_json_delta" && !string.IsNullOrEmpty(res.Delta.PartialJson) &&
101+
currentEvent.ContentBlock != null)
102+
{
103+
currentEvent.ContentBlock.AppendDeltaParameters(res.Delta.PartialJson!);
104+
}
105+
else if (res.Delta is { StopReason: "tool_use" } && currentEvent.ContentBlock != null)
106+
{
107+
if (res.Content == null)
108+
{
109+
res.Content = [currentEvent.ContentBlock.CreateToolUseContent()];
110+
}
111+
else
112+
{
113+
res.Content.Add(currentEvent.ContentBlock.CreateToolUseContent());
114+
}
115+
116+
currentEvent = new SseEvent();
117+
}
118+
119+
yield return res;
80120
}
81-
else if (currentEvent.Data != null)
121+
else if (currentEvent.EventType == "error" && currentEvent.Data != null)
82122
{
83123
var res = await JsonSerializer.DeserializeAsync<ErrorResponse>(
84124
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), cancellationToken: cancellationToken);
85125

86126
throw new Exception(res?.Error?.Message);
87127
}
88128

89-
// Reset the current event for the next one
90-
currentEvent = new SseEvent();
129+
if (currentEvent.ContentBlock == null)
130+
{
131+
currentEvent = new SseEvent();
132+
}
91133
}
92134
}
93135
}
@@ -113,11 +155,50 @@ public void Dispose()
113155

114156
private struct SseEvent
115157
{
158+
public string EventType { get; set; }
116159
public string? Data { get; set; }
160+
public ContentBlock? ContentBlock { get; set; }
117161

118-
public SseEvent(string? data = null)
162+
public SseEvent(string eventType, string? data = null, ContentBlock? contentBlock = null)
119163
{
164+
EventType = eventType;
120165
Data = data;
166+
ContentBlock = contentBlock;
121167
}
122168
}
169+
170+
private class ContentBlock
171+
{
172+
[JsonPropertyName("type")]
173+
public string? Type { get; set; }
174+
175+
[JsonPropertyName("id")]
176+
public string? Id { get; set; }
177+
178+
[JsonPropertyName("name")]
179+
public string? Name { get; set; }
180+
181+
[JsonPropertyName("input")]
182+
public object? Input { get; set; }
183+
184+
public string? parameters { get; set; }
185+
186+
public void AppendDeltaParameters(string deltaParams)
187+
{
188+
StringBuilder sb = new StringBuilder(parameters);
189+
sb.Append(deltaParams);
190+
parameters = sb.ToString();
191+
}
192+
193+
public ToolUseContent CreateToolUseContent()
194+
{
195+
return new ToolUseContent { Id = Id, Name = Name, Input = parameters };
196+
}
197+
}
198+
199+
private class DataBlock
200+
{
201+
[JsonPropertyName("content_block")]
202+
public ContentBlock? ContentBlock { get; set; }
203+
}
123204
}

dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs

+4
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ public override ContentBase Read(ref Utf8JsonReader reader, Type typeToConvert,
2424
return JsonSerializer.Deserialize<TextContent>(text, options) ?? throw new InvalidOperationException();
2525
case "image":
2626
return JsonSerializer.Deserialize<ImageContent>(text, options) ?? throw new InvalidOperationException();
27+
case "tool_use":
28+
return JsonSerializer.Deserialize<ToolUseContent>(text, options) ?? throw new InvalidOperationException();
29+
case "tool_result":
30+
return JsonSerializer.Deserialize<ToolResultContent>(text, options) ?? throw new InvalidOperationException();
2731
}
2832
}
2933

dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs

+5
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ public class ChatCompletionRequest
3737
[JsonPropertyName("top_p")]
3838
public decimal? TopP { get; set; }
3939

40+
[JsonPropertyName("tools")]
41+
public List<Tool>? Tools { get; set; }
42+
4043
public ChatCompletionRequest()
4144
{
4245
Messages = new List<ChatMessage>();
@@ -62,4 +65,6 @@ public ChatMessage(string role, List<ContentBase> content)
6265
Role = role;
6366
Content = content;
6467
}
68+
69+
public void AddContent(ContentBase content) => Content.Add(content);
6570
}

dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionResponse.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,6 @@ public class StreamingMessage
4949
[JsonPropertyName("role")]
5050
public string? Role { get; set; }
5151

52-
[JsonPropertyName("content")]
53-
public List<object>? Content { get; set; }
54-
5552
[JsonPropertyName("model")]
5653
public string? Model { get; set; }
5754

@@ -85,6 +82,9 @@ public class Delta
8582
[JsonPropertyName("text")]
8683
public string? Text { get; set; }
8784

85+
[JsonPropertyName("partial_json")]
86+
public string? PartialJson { get; set; }
87+
8888
[JsonPropertyName("usage")]
8989
public Usage? Usage { get; set; }
9090
}

0 commit comments

Comments
 (0)