Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
LittleLittleCloud authored and luxzoli committed Jun 27, 2024
1 parent 1ea5537 commit 0a041c5
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 12 deletions.
13 changes: 12 additions & 1 deletion dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public override string ToString()
}
}

public class ToolCallMessage : IMessage, ICanGetToolCalls
public class ToolCallMessage : IMessage, ICanGetToolCalls, ICanGetTextContent
{
public ToolCallMessage(IEnumerable<ToolCall> toolCalls, string? from = null)
{
Expand Down Expand Up @@ -80,6 +80,12 @@ public void Update(ToolCallMessageUpdate update)

public string? From { get; set; }

/// <summary>
/// Some LLMs might also include text content in a tool call response, like GPT.
/// This field is used to store the text content in that case.
/// </summary>
public string? Content { get; set; }

public override string ToString()
{
var sb = new StringBuilder();
Expand All @@ -96,6 +102,11 @@ public IEnumerable<ToolCall> GetToolCalls()
{
return this.ToolCalls;
}

public string? GetContent()
{
return this.Content;
}
}

public class ToolCallMessageUpdate : IStreamingMessage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,13 @@ private IMessage PostProcessChatCompletions(IMessage<ChatCompletions> message)

private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponseMessage, string? from)
{
if (chatResponseMessage.Content is string content && !string.IsNullOrEmpty(content))
{
return new TextMessage(Role.Assistant, content, from);
}

var textContent = chatResponseMessage.Content;
if (chatResponseMessage.FunctionCall is FunctionCall functionCall)
{
return new ToolCallMessage(functionCall.Name, functionCall.Arguments, from);
return new ToolCallMessage(functionCall.Name, functionCall.Arguments, from)
{
Content = textContent,
};
}

if (chatResponseMessage.ToolCalls.Where(tc => tc is ChatCompletionsFunctionToolCall).Any())
Expand All @@ -154,7 +153,15 @@ private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponse

var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments) { ToolCallId = tc.Id });

return new ToolCallMessage(toolCalls, from);
return new ToolCallMessage(toolCalls, from)
{
Content = textContent,
};
}

if (textContent is string content && !string.IsNullOrEmpty(content))
{
return new TextMessage(Role.Assistant, content, from);
}

throw new InvalidOperationException("Invalid ChatResponseMessage");
Expand Down Expand Up @@ -327,7 +334,8 @@ private IEnumerable<ChatRequestMessage> ProcessToolCallMessage(IAgent agent, Too
}

var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
var chatRequestMessage = new ChatRequestAssistantMessage(string.Empty) { Name = message.From };
var textContent = message.GetContent() ?? string.Empty;
var chatRequestMessage = new ChatRequestAssistantMessage(textContent) { Name = message.From };
foreach (var tc in toolCall)
{
chatRequestMessage.ToolCalls.Add(tc);
Expand Down
10 changes: 7 additions & 3 deletions dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,9 @@ public async Task ItProcessToolCallMessageAsync()
var innerMessage = msgs.Last();
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().BeNullOrEmpty();
chatRequestMessage.Name.Should().Be("assistant");
chatRequestMessage.ToolCalls.Count().Should().Be(1);
chatRequestMessage.Content.Should().Be("textContent");
chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.First();
functionToolCall.Name.Should().Be("test");
Expand All @@ -291,7 +291,10 @@ public async Task ItProcessToolCallMessageAsync()
.RegisterMiddleware(middleware);

// user message
IMessage message = new ToolCallMessage("test", "test", "assistant");
IMessage message = new ToolCallMessage("test", "test", "assistant")
{
Content = "textContent",
};
await agent.GenerateReplyAsync([message]);
}

Expand Down Expand Up @@ -526,13 +529,14 @@ public async Task ItConvertChatResponseMessageToToolCallMessageAsync()
.RegisterMiddleware(middleware);

// tool call message
var toolCallMessage = CreateInstance<ChatResponseMessage>(ChatRole.Assistant, "", new[] { new ChatCompletionsFunctionToolCall("test", "test", "test") }, new FunctionCall("test", "test"), CreateInstance<AzureChatExtensionsMessageContext>(), new Dictionary<string, BinaryData>());
var toolCallMessage = CreateInstance<ChatResponseMessage>(ChatRole.Assistant, "textContent", new[] { new ChatCompletionsFunctionToolCall("test", "test", "test") }, new FunctionCall("test", "test"), CreateInstance<AzureChatExtensionsMessageContext>(), new Dictionary<string, BinaryData>());
var chatRequestMessage = MessageEnvelope.Create(toolCallMessage);
var message = await agent.GenerateReplyAsync([chatRequestMessage]);
message.Should().BeOfType<ToolCallMessage>();
message.GetToolCalls()!.Count().Should().Be(1);
message.GetToolCalls()!.First().FunctionName.Should().Be("test");
message.GetToolCalls()!.First().FunctionArguments.Should().Be("test");
message.GetContent().Should().Be("textContent");
}

[Fact]
Expand Down

0 comments on commit 0a041c5

Please sign in to comment.