Skip to content

Commit

Permalink
stop setting name field when assistant message contains tool call (#3481
Browse files Browse the repository at this point in the history
)
  • Loading branch information
LittleLittleCloud authored Sep 5, 2024
1 parent 40cfe07 commit a44b86f
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,10 @@ private IEnumerable<ChatRequestMessage> ProcessToolCallMessage(IAgent agent, Too

var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
var textContent = message.GetContent() ?? string.Empty;
var chatRequestMessage = new ChatRequestAssistantMessage(textContent) { Name = message.From };

// don't include the name field when it's tool call message.
// fix https://github.com/microsoft/autogen/issues/3437
var chatRequestMessage = new ChatRequestAssistantMessage(textContent);
foreach (var tc in toolCall)
{
chatRequestMessage.ToolCalls.Add(tc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,10 @@ private IEnumerable<ChatMessage> ProcessToolCallMessage(IAgent agent, ToolCallMe

var toolCallParts = message.ToolCalls.Select((tc, i) => ChatToolCall.CreateFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
var textContent = message.GetContent() ?? null;
var chatRequestMessage = new AssistantChatMessage(toolCallParts, textContent) { ParticipantName = message.From };

// Don't set participant name for assistant when it is tool call
// fix https://github.com/microsoft/autogen/issues/3437
var chatRequestMessage = new AssistantChatMessage(toolCallParts, textContent);

return [chatRequestMessage];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@
{
"Role": "assistant",
"Content": [],
"Name": "assistant",
"Name": null,
"TooCall": [
{
"Type": "Function",
Expand Down Expand Up @@ -184,7 +184,7 @@
{
"Role": "assistant",
"Content": [],
"Name": "assistant",
"Name": null,
"TooCall": [
{
"Type": "Function",
Expand All @@ -210,7 +210,7 @@
{
"Role": "assistant",
"Content": [],
"Name": "assistant",
"Name": null,
"TooCall": [
{
"Type": "Function",
Expand Down
65 changes: 65 additions & 0 deletions dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ public async Task<string> GetWeatherAsync(string location)
return $"The weather in {location} is sunny.";
}

[Function]
public async Task<string> CalculateTaxAsync(string location, double income)
{
return $"[CalculateTax] The tax in {location} for income {income} is 1000.";
}

[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task BasicConversationTestAsync()
{
Expand Down Expand Up @@ -246,6 +252,65 @@ public async Task ItCreateOpenAIChatAgentWithChatCompletionOptionAsync()
respond.GetContent()?.Should().NotBeNullOrEmpty();
}


[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task ItProduceValidContentAfterFunctionCall()
{
// https://github.com/microsoft/autogen/issues/3437
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
var openaiClient = CreateOpenAIClientFromAzureOpenAI();
var options = new ChatCompletionOptions()
{
Temperature = 0.7f,
MaxTokens = 1,
};

var agentName = "assistant";

var getWeatherToolCall = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}");
var getWeatherToolCallResult = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}", "The weather in Seattle is sunny.");
var getWeatherToolCallMessage = new ToolCallMessage([getWeatherToolCall], from: agentName);
var getWeatherToolCallResultMessage = new ToolCallResultMessage([getWeatherToolCallResult], from: agentName);
var getWeatherAggregateMessage = new ToolCallAggregateMessage(getWeatherToolCallMessage, getWeatherToolCallResultMessage, from: agentName);

var calculateTaxToolCall = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}");
var calculateTaxToolCallResult = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}", "The tax in Seattle for income 1000 is 1000.");
var calculateTaxToolCallMessage = new ToolCallMessage([calculateTaxToolCall], from: agentName);
var calculateTaxToolCallResultMessage = new ToolCallResultMessage([calculateTaxToolCallResult], from: agentName);
var calculateTaxAggregateMessage = new ToolCallAggregateMessage(calculateTaxToolCallMessage, calculateTaxToolCallResultMessage, from: agentName);

var chatHistory = new List<IMessage>()
{
new TextMessage(Role.User, "What's the weather in Seattle", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in Seattle is sunny, now check the tax in seattle", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in Paris", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in Paris is sunny, now check the tax in Paris", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in New York", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in New York is sunny, now check the tax in New York", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in London", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in London is sunny, now check the tax in London", from: "admin"),
};

var agent = new OpenAIChatAgent(
chatClient: openaiClient.GetChatClient(deployName),
name: "assistant",
options: options)
.RegisterMessageConnector();

var res = await agent.GenerateReplyAsync(chatHistory, new GenerateReplyOptions
{
MaxToken = 1024,
Functions = [this.GetWeatherAsyncFunctionContract, this.CalculateTaxAsyncFunctionContract],
});
}

private OpenAIClient CreateOpenAIClientFromAzureOpenAI()
{
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
Expand Down
10 changes: 8 additions & 2 deletions dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,10 @@ public async Task ItProcessToolCallMessageAsync()
var innerMessage = msgs.Last();
innerMessage!.Should().BeOfType<MessageEnvelope<ChatMessage>>();
var chatRequestMessage = (AssistantChatMessage)((MessageEnvelope<ChatMessage>)innerMessage!).Content;
chatRequestMessage.ParticipantName.Should().Be("assistant");
// when the message is a tool call message
// the name field should not be set
// please visit OpenAIChatRequestMessageConnector class for more information
chatRequestMessage.ParticipantName.Should().BeNullOrEmpty();
chatRequestMessage.ToolCalls.Count().Should().Be(1);
chatRequestMessage.Content.First().Text.Should().Be("textContent");
chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatToolCall>();
Expand Down Expand Up @@ -307,7 +310,10 @@ public async Task ItProcessParallelToolCallMessageAsync()
innerMessage!.Should().BeOfType<MessageEnvelope<ChatMessage>>();
var chatRequestMessage = (AssistantChatMessage)((MessageEnvelope<ChatMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().BeNullOrEmpty();
chatRequestMessage.ParticipantName.Should().Be("assistant");
// when the message is a tool call message
// the name field should not be set
// please visit OpenAIChatRequestMessageConnector class for more information
chatRequestMessage.ParticipantName.Should().BeNullOrEmpty();
chatRequestMessage.ToolCalls.Count().Should().Be(2);
for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
{
"Role": "assistant",
"Content": "",
"Name": "assistant",
"Name": null,
"TooCall": [
{
"Type": "Function",
Expand Down Expand Up @@ -126,7 +126,7 @@
{
"Role": "assistant",
"Content": "",
"Name": "assistant",
"Name": null,
"TooCall": [
{
"Type": "Function",
Expand All @@ -152,7 +152,7 @@
{
"Role": "assistant",
"Content": "",
"Name": "assistant",
"Name": null,
"TooCall": [
{
"Type": "Function",
Expand Down
13 changes: 9 additions & 4 deletions dotnet/test/AutoGen.OpenAI.V1.Tests/MathClassTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.OpenAI.V1.Extension;
Expand Down Expand Up @@ -45,7 +46,11 @@ private Task<IMessage> Print(IEnumerable<IMessage> messages, GenerateReplyOption
_output.WriteLine($"agent name: {agent.Name}");
foreach (var message in messages)
{
_output.WriteLine(message.FormatMessage());
if (message is IMessage<object> envelope)
{
var json = JsonSerializer.Serialize(envelope.Content, new JsonSerializerOptions { WriteIndented = true });
_output.WriteLine(json);
}
}

throw;
Expand Down Expand Up @@ -149,9 +154,9 @@ You create math question and ask student to answer it.
Then you check if the answer is correct.
If the answer is wrong, you ask student to fix it",
modelName: model)
.RegisterMessageConnector()
.RegisterStreamingMiddleware(functionCallMiddleware)
.RegisterMiddleware(Print);
.RegisterMiddleware(Print)
.RegisterMiddleware(new OpenAIChatRequestMessageConnector())
.RegisterMiddleware(functionCallMiddleware);

return teacher;
}
Expand Down
66 changes: 65 additions & 1 deletion dotnet/test/AutoGen.OpenAI.V1.Tests/OpenAIChatAgentTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@ public partial class OpenAIChatAgentTest
[Function]
public async Task<string> GetWeatherAsync(string location)
{
return $"The weather in {location} is sunny.";
return $"[GetWeather] The weather in {location} is sunny.";
}

[Function]
public async Task<string> CalculateTaxAsync(string location, double income)
{
return $"[CalculateTax] The tax in {location} for income {income} is 1000.";
}

[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
Expand Down Expand Up @@ -270,6 +276,64 @@ public async Task ItThrowExceptionWhenChatCompletionOptionContainsMessages()
action.Should().ThrowExactly<ArgumentException>().WithMessage("Messages should not be provided in options");
}

[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task ItProduceValidContentAfterFunctionCall()
{
// https://github.com/microsoft/autogen/issues/3437
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
var openaiClient = CreateOpenAIClientFromAzureOpenAI();
var options = new ChatCompletionsOptions(deployName, [])
{
Temperature = 0.7f,
MaxTokens = 1,
};

var agentName = "assistant";

var getWeatherToolCall = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}");
var getWeatherToolCallResult = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}", "The weather in Seattle is sunny.");
var getWeatherToolCallMessage = new ToolCallMessage([getWeatherToolCall], from: agentName);
var getWeatherToolCallResultMessage = new ToolCallResultMessage([getWeatherToolCallResult], from: agentName);
var getWeatherAggregateMessage = new ToolCallAggregateMessage(getWeatherToolCallMessage, getWeatherToolCallResultMessage, from: agentName);

var calculateTaxToolCall = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}");
var calculateTaxToolCallResult = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}", "The tax in Seattle for income 1000 is 1000.");
var calculateTaxToolCallMessage = new ToolCallMessage([calculateTaxToolCall], from: agentName);
var calculateTaxToolCallResultMessage = new ToolCallResultMessage([calculateTaxToolCallResult], from: agentName);
var calculateTaxAggregateMessage = new ToolCallAggregateMessage(calculateTaxToolCallMessage, calculateTaxToolCallResultMessage, from: agentName);

var chatHistory = new List<IMessage>()
{
new TextMessage(Role.User, "What's the weather in Seattle", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in Seattle is sunny, now check the tax in seattle", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in Paris", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in Paris is sunny, now check the tax in Paris", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in New York", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in New York is sunny, now check the tax in New York", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in London", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in London is sunny, now check the tax in London", from: "admin"),
};

var agent = new OpenAIChatAgent(
openAIClient: openaiClient,
name: "assistant",
options: options)
.RegisterMessageConnector();

var res = await agent.GenerateReplyAsync(chatHistory, new GenerateReplyOptions
{
MaxToken = 1024,
Functions = [this.GetWeatherAsyncFunctionContract, this.CalculateTaxAsyncFunctionContract],
});
}

private OpenAIClient CreateOpenAIClientFromAzureOpenAI()
{
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
Expand Down
11 changes: 9 additions & 2 deletions dotnet/test/AutoGen.OpenAI.V1.Tests/OpenAIMessageTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,10 @@ public async Task ItProcessToolCallMessageAsync()
var innerMessage = msgs.Last();
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Name.Should().Be("assistant");
// when the message is a tool call message
// the name field should not be set
// please visit OpenAIChatRequestMessageConnector class for more information
chatRequestMessage.Name.Should().BeNullOrEmpty();
chatRequestMessage.ToolCalls.Count().Should().Be(1);
chatRequestMessage.Content.Should().Be("textContent");
chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
Expand Down Expand Up @@ -309,7 +312,11 @@ public async Task ItProcessParallelToolCallMessageAsync()
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().BeNullOrEmpty();
chatRequestMessage.Name.Should().Be("assistant");

// when the message is a tool call message
// the name field should not be set
// please visit OpenAIChatRequestMessageConnector class for more information
chatRequestMessage.Name.Should().BeNullOrEmpty();
chatRequestMessage.ToolCalls.Count().Should().Be(2);
for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++)
{
Expand Down

0 comments on commit a44b86f

Please sign in to comment.