Skip to content

[.Net] fix #2975 #3012

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public override string ToString()
}
}

public class ToolCallMessage : IMessage, ICanGetToolCalls
public class ToolCallMessage : IMessage, ICanGetToolCalls, ICanGetTextContent
{
public ToolCallMessage(IEnumerable<ToolCall> toolCalls, string? from = null)
{
Expand Down Expand Up @@ -80,6 +80,12 @@ public void Update(ToolCallMessageUpdate update)

public string? From { get; set; }

/// <summary>
/// Some LLMs might also include text content in a tool call response, like GPT.
/// This field is used to store the text content in that case.
/// </summary>
public string? Content { get; set; }

public override string ToString()
{
var sb = new StringBuilder();
Expand All @@ -96,6 +102,11 @@ public IEnumerable<ToolCall> GetToolCalls()
{
return this.ToolCalls;
}

public string? GetContent()
{
return this.Content;
}
}

public class ToolCallMessageUpdate : IStreamingMessage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,13 @@ private IMessage PostProcessChatCompletions(IMessage<ChatCompletions> message)

private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponseMessage, string? from)
{
if (chatResponseMessage.Content is string content && !string.IsNullOrEmpty(content))
{
return new TextMessage(Role.Assistant, content, from);
}

var textContent = chatResponseMessage.Content;
if (chatResponseMessage.FunctionCall is FunctionCall functionCall)
{
return new ToolCallMessage(functionCall.Name, functionCall.Arguments, from);
return new ToolCallMessage(functionCall.Name, functionCall.Arguments, from)
{
Content = textContent,
};
}

if (chatResponseMessage.ToolCalls.Where(tc => tc is ChatCompletionsFunctionToolCall).Any())
Expand All @@ -154,7 +153,15 @@ private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponse

var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments) { ToolCallId = tc.Id });

return new ToolCallMessage(toolCalls, from);
return new ToolCallMessage(toolCalls, from)
{
Content = textContent,
};
}

if (textContent is string content && !string.IsNullOrEmpty(content))
{
return new TextMessage(Role.Assistant, content, from);
}

throw new InvalidOperationException("Invalid ChatResponseMessage");
Expand Down Expand Up @@ -327,7 +334,8 @@ private IEnumerable<ChatRequestMessage> ProcessToolCallMessage(IAgent agent, Too
}

var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
var chatRequestMessage = new ChatRequestAssistantMessage(string.Empty) { Name = message.From };
var textContent = message.GetContent() ?? string.Empty;
var chatRequestMessage = new ChatRequestAssistantMessage(textContent) { Name = message.From };
foreach (var tc in toolCall)
{
chatRequestMessage.ToolCalls.Add(tc);
Expand Down
10 changes: 7 additions & 3 deletions dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,9 @@ public async Task ItProcessToolCallMessageAsync()
var innerMessage = msgs.Last();
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().BeNullOrEmpty();
chatRequestMessage.Name.Should().Be("assistant");
chatRequestMessage.ToolCalls.Count().Should().Be(1);
chatRequestMessage.Content.Should().Be("textContent");
chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.First();
functionToolCall.Name.Should().Be("test");
Expand All @@ -291,7 +291,10 @@ public async Task ItProcessToolCallMessageAsync()
.RegisterMiddleware(middleware);

// user message
IMessage message = new ToolCallMessage("test", "test", "assistant");
IMessage message = new ToolCallMessage("test", "test", "assistant")
{
Content = "textContent",
};
await agent.GenerateReplyAsync([message]);
}

Expand Down Expand Up @@ -526,13 +529,14 @@ public async Task ItConvertChatResponseMessageToToolCallMessageAsync()
.RegisterMiddleware(middleware);

// tool call message
var toolCallMessage = CreateInstance<ChatResponseMessage>(ChatRole.Assistant, "", new[] { new ChatCompletionsFunctionToolCall("test", "test", "test") }, new FunctionCall("test", "test"), CreateInstance<AzureChatExtensionsMessageContext>(), new Dictionary<string, BinaryData>());
var toolCallMessage = CreateInstance<ChatResponseMessage>(ChatRole.Assistant, "textContent", new[] { new ChatCompletionsFunctionToolCall("test", "test", "test") }, new FunctionCall("test", "test"), CreateInstance<AzureChatExtensionsMessageContext>(), new Dictionary<string, BinaryData>());
var chatRequestMessage = MessageEnvelope.Create(toolCallMessage);
var message = await agent.GenerateReplyAsync([chatRequestMessage]);
message.Should().BeOfType<ToolCallMessage>();
message.GetToolCalls()!.Count().Should().Be(1);
message.GetToolCalls()!.First().FunctionName.Should().Be("test");
message.GetToolCalls()!.First().FunctionArguments.Should().Be("test");
message.GetContent().Should().Be("textContent");
}

[Fact]
Expand Down
Loading