Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[.Net] In AutoGen.OpenAI and AutoGen.OpenAI.V1, stop setting name field when assistant message contains tool call #3481

Merged
merged 1 commit into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,10 @@ private IEnumerable<ChatRequestMessage> ProcessToolCallMessage(IAgent agent, Too

var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
var textContent = message.GetContent() ?? string.Empty;
var chatRequestMessage = new ChatRequestAssistantMessage(textContent) { Name = message.From };

// don't include the name field when it's tool call message.
// fix https://github.com/microsoft/autogen/issues/3437
var chatRequestMessage = new ChatRequestAssistantMessage(textContent);
foreach (var tc in toolCall)
{
chatRequestMessage.ToolCalls.Add(tc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,10 @@ private IEnumerable<ChatMessage> ProcessToolCallMessage(IAgent agent, ToolCallMe

var toolCallParts = message.ToolCalls.Select((tc, i) => ChatToolCall.CreateFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
var textContent = message.GetContent() ?? null;
var chatRequestMessage = new AssistantChatMessage(toolCallParts, textContent) { ParticipantName = message.From };

// Don't set participant name for assistant when it is tool call
// fix https://github.com/microsoft/autogen/issues/3437
var chatRequestMessage = new AssistantChatMessage(toolCallParts, textContent);

return [chatRequestMessage];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@
{
"Role": "assistant",
"Content": [],
"Name": "assistant",
"Name": null,
"TooCall": [
{
"Type": "Function",
Expand Down Expand Up @@ -184,7 +184,7 @@
{
"Role": "assistant",
"Content": [],
"Name": "assistant",
"Name": null,
"TooCall": [
{
"Type": "Function",
Expand All @@ -210,7 +210,7 @@
{
"Role": "assistant",
"Content": [],
"Name": "assistant",
"Name": null,
"TooCall": [
{
"Type": "Function",
Expand Down
65 changes: 65 additions & 0 deletions dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ public async Task<string> GetWeatherAsync(string location)
return $"The weather in {location} is sunny.";
}

[Function]
public async Task<string> CalculateTaxAsync(string location, double income)
{
return $"[CalculateTax] The tax in {location} for income {income} is 1000.";
}

[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task BasicConversationTestAsync()
{
Expand Down Expand Up @@ -246,6 +252,65 @@ public async Task ItCreateOpenAIChatAgentWithChatCompletionOptionAsync()
respond.GetContent()?.Should().NotBeNullOrEmpty();
}


[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task ItProduceValidContentAfterFunctionCall()
{
// https://github.com/microsoft/autogen/issues/3437
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
var openaiClient = CreateOpenAIClientFromAzureOpenAI();
var options = new ChatCompletionOptions()
{
Temperature = 0.7f,
MaxTokens = 1,
};

var agentName = "assistant";

var getWeatherToolCall = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}");
var getWeatherToolCallResult = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}", "The weather in Seattle is sunny.");
var getWeatherToolCallMessage = new ToolCallMessage([getWeatherToolCall], from: agentName);
var getWeatherToolCallResultMessage = new ToolCallResultMessage([getWeatherToolCallResult], from: agentName);
var getWeatherAggregateMessage = new ToolCallAggregateMessage(getWeatherToolCallMessage, getWeatherToolCallResultMessage, from: agentName);

var calculateTaxToolCall = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}");
var calculateTaxToolCallResult = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}", "The tax in Seattle for income 1000 is 1000.");
var calculateTaxToolCallMessage = new ToolCallMessage([calculateTaxToolCall], from: agentName);
var calculateTaxToolCallResultMessage = new ToolCallResultMessage([calculateTaxToolCallResult], from: agentName);
var calculateTaxAggregateMessage = new ToolCallAggregateMessage(calculateTaxToolCallMessage, calculateTaxToolCallResultMessage, from: agentName);

var chatHistory = new List<IMessage>()
{
new TextMessage(Role.User, "What's the weather in Seattle", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in Seattle is sunny, now check the tax in seattle", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in Paris", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in Paris is sunny, now check the tax in Paris", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in New York", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in New York is sunny, now check the tax in New York", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in London", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in London is sunny, now check the tax in London", from: "admin"),
};

var agent = new OpenAIChatAgent(
chatClient: openaiClient.GetChatClient(deployName),
name: "assistant",
options: options)
.RegisterMessageConnector();

var res = await agent.GenerateReplyAsync(chatHistory, new GenerateReplyOptions
{
MaxToken = 1024,
Functions = [this.GetWeatherAsyncFunctionContract, this.CalculateTaxAsyncFunctionContract],
});
}

private OpenAIClient CreateOpenAIClientFromAzureOpenAI()
{
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
Expand Down
10 changes: 8 additions & 2 deletions dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,10 @@ public async Task ItProcessToolCallMessageAsync()
var innerMessage = msgs.Last();
innerMessage!.Should().BeOfType<MessageEnvelope<ChatMessage>>();
var chatRequestMessage = (AssistantChatMessage)((MessageEnvelope<ChatMessage>)innerMessage!).Content;
chatRequestMessage.ParticipantName.Should().Be("assistant");
// when the message is a tool call message
// the name field should not be set
// please visit OpenAIChatRequestMessageConnector class for more information
chatRequestMessage.ParticipantName.Should().BeNullOrEmpty();
chatRequestMessage.ToolCalls.Count().Should().Be(1);
chatRequestMessage.Content.First().Text.Should().Be("textContent");
chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatToolCall>();
Expand Down Expand Up @@ -307,7 +310,10 @@ public async Task ItProcessParallelToolCallMessageAsync()
innerMessage!.Should().BeOfType<MessageEnvelope<ChatMessage>>();
var chatRequestMessage = (AssistantChatMessage)((MessageEnvelope<ChatMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().BeNullOrEmpty();
chatRequestMessage.ParticipantName.Should().Be("assistant");
// when the message is a tool call message
// the name field should not be set
// please visit OpenAIChatRequestMessageConnector class for more information
chatRequestMessage.ParticipantName.Should().BeNullOrEmpty();
chatRequestMessage.ToolCalls.Count().Should().Be(2);
for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
{
"Role": "assistant",
"Content": "",
"Name": "assistant",
"Name": null,
"TooCall": [
{
"Type": "Function",
Expand Down Expand Up @@ -126,7 +126,7 @@
{
"Role": "assistant",
"Content": "",
"Name": "assistant",
"Name": null,
"TooCall": [
{
"Type": "Function",
Expand All @@ -152,7 +152,7 @@
{
"Role": "assistant",
"Content": "",
"Name": "assistant",
"Name": null,
"TooCall": [
{
"Type": "Function",
Expand Down
13 changes: 9 additions & 4 deletions dotnet/test/AutoGen.OpenAI.V1.Tests/MathClassTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.OpenAI.V1.Extension;
Expand Down Expand Up @@ -45,7 +46,11 @@ private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOption
_output.WriteLine($"agent name: {agent.Name}");
foreach (var message in messages)
{
_output.WriteLine(message.FormatMessage());
if (message is IMessage<object> envelope)
{
var json = JsonSerializer.Serialize(envelope.Content, new JsonSerializerOptions { WriteIndented = true });
_output.WriteLine(json);
}
}

throw;
Expand Down Expand Up @@ -149,9 +154,9 @@ You create math question and ask student to answer it.
Then you check if the answer is correct.
If the answer is wrong, you ask student to fix it",
modelName: model)
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
.RegisterMiddleware(Print)
.RegisterMiddleware(new OpenAIChatRequestMessageConnector())
.RegisterMiddleware(functionCallMiddleware);

return teacher;
}
Expand Down
66 changes: 65 additions & 1 deletion dotnet/test/AutoGen.OpenAI.V1.Tests/OpenAIChatAgentTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@ public partial class OpenAIChatAgentTest
[Function]
public async Task<string> GetWeatherAsync(string location)
{
return $"The weather in {location} is sunny.";
return $"[GetWeather] The weather in {location} is sunny.";
}

[Function]
public async Task<string> CalculateTaxAsync(string location, double income)
{
return $"[CalculateTax] The tax in {location} for income {income} is 1000.";
}

[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
Expand Down Expand Up @@ -270,6 +276,64 @@ public async Task ItThrowExceptionWhenChatCompletionOptionContainsMessages()
action.Should().ThrowExactly<ArgumentException>().WithMessage("Messages should not be provided in options");
}

[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task ItProduceValidContentAfterFunctionCall()
{
// https://github.com/microsoft/autogen/issues/3437
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
var openaiClient = CreateOpenAIClientFromAzureOpenAI();
var options = new ChatCompletionsOptions(deployName, [])
{
Temperature = 0.7f,
MaxTokens = 1,
};

var agentName = "assistant";

var getWeatherToolCall = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}");
var getWeatherToolCallResult = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}", "The weather in Seattle is sunny.");
var getWeatherToolCallMessage = new ToolCallMessage([getWeatherToolCall], from: agentName);
var getWeatherToolCallResultMessage = new ToolCallResultMessage([getWeatherToolCallResult], from: agentName);
var getWeatherAggregateMessage = new ToolCallAggregateMessage(getWeatherToolCallMessage, getWeatherToolCallResultMessage, from: agentName);

var calculateTaxToolCall = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}");
var calculateTaxToolCallResult = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}", "The tax in Seattle for income 1000 is 1000.");
var calculateTaxToolCallMessage = new ToolCallMessage([calculateTaxToolCall], from: agentName);
var calculateTaxToolCallResultMessage = new ToolCallResultMessage([calculateTaxToolCallResult], from: agentName);
var calculateTaxAggregateMessage = new ToolCallAggregateMessage(calculateTaxToolCallMessage, calculateTaxToolCallResultMessage, from: agentName);

var chatHistory = new List<IMessage>()
{
new TextMessage(Role.User, "What's the weather in Seattle", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in Seattle is sunny, now check the tax in seattle", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in Paris", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in Paris is sunny, now check the tax in Paris", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in New York", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in New York is sunny, now check the tax in New York", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in London", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in London is sunny, now check the tax in London", from: "admin"),
};

var agent = new OpenAIChatAgent(
openAIClient: openaiClient,
name: "assistant",
options: options)
.RegisterMessageConnector();

var res = await agent.GenerateReplyAsync(chatHistory, new GenerateReplyOptions
{
MaxToken = 1024,
Functions = [this.GetWeatherAsyncFunctionContract, this.CalculateTaxAsyncFunctionContract],
});
}

private OpenAIClient CreateOpenAIClientFromAzureOpenAI()
{
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
Expand Down
11 changes: 9 additions & 2 deletions dotnet/test/AutoGen.OpenAI.V1.Tests/OpenAIMessageTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,10 @@ public async Task ItProcessToolCallMessageAsync()
var innerMessage = msgs.Last();
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Name.Should().Be("assistant");
// when the message is a tool call message
// the name field should not be set
// please visit OpenAIChatRequestMessageConnector class for more information
chatRequestMessage.Name.Should().BeNullOrEmpty();
chatRequestMessage.ToolCalls.Count().Should().Be(1);
chatRequestMessage.Content.Should().Be("textContent");
chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
Expand Down Expand Up @@ -309,7 +312,11 @@ public async Task ItProcessParallelToolCallMessageAsync()
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().BeNullOrEmpty();
chatRequestMessage.Name.Should().Be("assistant");

// when the message is a tool call message
// the name field should not be set
// please visit OpenAIChatRequestMessageConnector class for more information
chatRequestMessage.Name.Should().BeNullOrEmpty();
chatRequestMessage.ToolCalls.Count().Should().Be(2);
for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++)
{
Expand Down
Loading