From 0a041c525cedba6393cc3dea308ae8ce6060f87d Mon Sep 17 00:00:00 2001 From: Xiaoyun Zhang Date: Tue, 25 Jun 2024 09:52:55 -0700 Subject: [PATCH] fix #2975 (#3012) --- .../AutoGen.Core/Message/ToolCallMessage.cs | 13 +++++++++- .../OpenAIChatRequestMessageConnector.cs | 24 ++++++++++++------- .../OpenAIMessageTests.cs | 10 +++++--- 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs b/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs index 396dba3d3a17..d0f89e1ecdde 100644 --- a/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs +++ b/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs @@ -36,7 +36,7 @@ public override string ToString() } } -public class ToolCallMessage : IMessage, ICanGetToolCalls +public class ToolCallMessage : IMessage, ICanGetToolCalls, ICanGetTextContent { public ToolCallMessage(IEnumerable toolCalls, string? from = null) { @@ -80,6 +80,12 @@ public void Update(ToolCallMessageUpdate update) public string? From { get; set; } + /// + /// Some LLMs might also include text content in a tool call response, like GPT. + /// This field is used to store the text content in that case. + /// + public string? Content { get; set; } + public override string ToString() { var sb = new StringBuilder(); @@ -96,6 +102,11 @@ public IEnumerable GetToolCalls() { return this.ToolCalls; } + + public string? GetContent() + { + return this.Content; + } } public class ToolCallMessageUpdate : IStreamingMessage diff --git a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs index 246e50cc6c59..c1dc2caa99fb 100644 --- a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs +++ b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs @@ -136,14 +136,13 @@ private IMessage PostProcessChatCompletions(IMessage message) private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponseMessage, string? from) { - if (chatResponseMessage.Content is string content && !string.IsNullOrEmpty(content)) - { - return new TextMessage(Role.Assistant, content, from); - } - + var textContent = chatResponseMessage.Content; if (chatResponseMessage.FunctionCall is FunctionCall functionCall) { - return new ToolCallMessage(functionCall.Name, functionCall.Arguments, from); + return new ToolCallMessage(functionCall.Name, functionCall.Arguments, from) + { + Content = textContent, + }; } if (chatResponseMessage.ToolCalls.Where(tc => tc is ChatCompletionsFunctionToolCall).Any()) @@ -154,7 +153,15 @@ private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponse var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments) { ToolCallId = tc.Id }); - return new ToolCallMessage(toolCalls, from); + return new ToolCallMessage(toolCalls, from) + { + Content = textContent, + }; + } + + if (textContent is string content && !string.IsNullOrEmpty(content)) + { + return new TextMessage(Role.Assistant, content, from); } throw new InvalidOperationException("Invalid ChatResponseMessage"); @@ -327,7 +334,8 @@ private IEnumerable ProcessToolCallMessage(IAgent agent, Too } var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments)); - var chatRequestMessage = new ChatRequestAssistantMessage(string.Empty) { Name = message.From }; + var textContent = message.GetContent() ?? string.Empty; + var chatRequestMessage = new ChatRequestAssistantMessage(textContent) { Name = message.From }; foreach (var tc in toolCall) { chatRequestMessage.ToolCalls.Add(tc); diff --git a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs index 81581d068ee7..a9b852e0d8c1 100644 --- a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs +++ b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs @@ -278,9 +278,9 @@ public async Task ItProcessToolCallMessageAsync() var innerMessage = msgs.Last(); innerMessage!.Should().BeOfType>(); var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope)innerMessage!).Content; - chatRequestMessage.Content.Should().BeNullOrEmpty(); chatRequestMessage.Name.Should().Be("assistant"); chatRequestMessage.ToolCalls.Count().Should().Be(1); + chatRequestMessage.Content.Should().Be("textContent"); chatRequestMessage.ToolCalls.First().Should().BeOfType(); var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.First(); functionToolCall.Name.Should().Be("test"); @@ -291,7 +291,10 @@ public async Task ItProcessToolCallMessageAsync() .RegisterMiddleware(middleware); // user message - IMessage message = new ToolCallMessage("test", "test", "assistant"); + IMessage message = new ToolCallMessage("test", "test", "assistant") + { + Content = "textContent", + }; await agent.GenerateReplyAsync([message]); } @@ -526,13 +529,14 @@ public async Task ItConvertChatResponseMessageToToolCallMessageAsync() .RegisterMiddleware(middleware); // tool call message - var toolCallMessage = CreateInstance(ChatRole.Assistant, "", new[] { new ChatCompletionsFunctionToolCall("test", "test", "test") }, new FunctionCall("test", "test"), CreateInstance(), new Dictionary()); + var toolCallMessage = CreateInstance(ChatRole.Assistant, "textContent", new[] { new ChatCompletionsFunctionToolCall("test", "test", "test") }, new FunctionCall("test", "test"), CreateInstance(), new Dictionary()); var chatRequestMessage = MessageEnvelope.Create(toolCallMessage); var message = await agent.GenerateReplyAsync([chatRequestMessage]); message.Should().BeOfType(); message.GetToolCalls()!.Count().Should().Be(1); message.GetToolCalls()!.First().FunctionName.Should().Be("test"); message.GetToolCalls()!.First().FunctionArguments.Should().Be("test"); + message.GetContent().Should().Be("textContent"); } [Fact]