Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ToChatCompletion{Async} methods for combining StreamingChatCompletionUpdates #5605

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
#if NET
using System.Runtime.InteropServices;
#endif
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;

#pragma warning disable S109 // Magic numbers should not be used
#pragma warning disable S127 // "for" loop stop conditions should be invariant

namespace Microsoft.Extensions.AI;

/// <summary>
/// Provides extension methods for working with <see cref="StreamingChatCompletionUpdate"/> instances.
/// </summary>
public static class StreamingChatCompletionUpdateExtensions
{
/// <summary>Combines <see cref="StreamingChatCompletionUpdate"/> instances into a single <see cref="ChatCompletion"/>.</summary>
/// <param name="updates">The updates to be combined.</param>
/// <param name="coalesceContent">
/// <see langword="true"/> to attempt to coalesce contiguous <see cref="AIContent"/> items, where applicable,
/// into a single <see cref="AIContent"/>, in order to reduce the number of individual content items that are included in
/// the manufactured <see cref="ChatMessage"/> instances. When <see langword="false"/>, the original content items are used.
/// The default is <see langword="true"/>.
/// </param>
/// <returns>The combined <see cref="ChatCompletion"/>.</returns>
public static ChatCompletion ToChatCompletion(
this IEnumerable<StreamingChatCompletionUpdate> updates, bool coalesceContent = true)
{
_ = Throw.IfNull(updates);

ChatCompletion completion = new([]);
Dictionary<int, ChatMessage> messages = [];

foreach (var update in updates)
{
ProcessUpdate(update, messages, completion);
}

AddMessagesToCompletion(messages, completion, coalesceContent);

return completion;
}

/// <summary>Combines <see cref="StreamingChatCompletionUpdate"/> instances into a single <see cref="ChatCompletion"/>.</summary>
/// <param name="updates">The updates to be combined.</param>
/// <param name="coalesceContent">
/// <see langword="true"/> to attempt to coalesce contiguous <see cref="AIContent"/> items, where applicable,
/// into a single <see cref="AIContent"/>, in order to reduce the number of individual content items that are included in
/// the manufactured <see cref="ChatMessage"/> instances. When <see langword="false"/>, the original content items are used.
/// The default is <see langword="true"/>.
/// </param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The combined <see cref="ChatCompletion"/>.</returns>
public static Task<ChatCompletion> ToChatCompletionAsync(
this IAsyncEnumerable<StreamingChatCompletionUpdate> updates, bool coalesceContent = true, CancellationToken cancellationToken = default)
{
_ = Throw.IfNull(updates);

return ToChatCompletionAsync(updates, coalesceContent, cancellationToken);

static async Task<ChatCompletion> ToChatCompletionAsync(
IAsyncEnumerable<StreamingChatCompletionUpdate> updates, bool coalesceContent, CancellationToken cancellationToken)
{
ChatCompletion completion = new([]);
Dictionary<int, ChatMessage> messages = [];

await foreach (var update in updates.WithCancellation(cancellationToken).ConfigureAwait(false))
{
ProcessUpdate(update, messages, completion);
}

AddMessagesToCompletion(messages, completion, coalesceContent);

return completion;
}
}

/// <summary>Processes the <see cref="StreamingChatCompletionUpdate"/>, incorporating its contents into <paramref name="messages"/> and <paramref name="completion"/>.</summary>
/// <param name="update">The update to process.</param>
/// <param name="messages">The dictionary mapping <see cref="StreamingChatCompletionUpdate.ChoiceIndex"/> to the <see cref="ChatMessage"/> being built for that choice.</param>
/// <param name="completion">The <see cref="ChatCompletion"/> object whose properties should be updated based on <paramref name="update"/>.</param>
private static void ProcessUpdate(StreamingChatCompletionUpdate update, Dictionary<int, ChatMessage> messages, ChatCompletion completion)
{
completion.CompletionId ??= update.CompletionId;
completion.CreatedAt ??= update.CreatedAt;
completion.FinishReason ??= update.FinishReason;
completion.ModelId ??= update.ModelId;

#if NET
ChatMessage message = CollectionsMarshal.GetValueRefOrAddDefault(messages, update.ChoiceIndex, out _) ??=
new(default, new List<AIContent>());
#else
if (!messages.TryGetValue(update.ChoiceIndex, out ChatMessage? message))
{
messages[update.ChoiceIndex] = message = new(default, new List<AIContent>());
}
#endif

((List<AIContent>)message.Contents).AddRange(update.Contents);

message.AuthorName ??= update.AuthorName;
if (update.Role is ChatRole role && message.Role == default)
{
message.Role = role;
}

if (update.AdditionalProperties is not null)
{
if (message.AdditionalProperties is null)
{
message.AdditionalProperties = new(update.AdditionalProperties);
}
else
{
foreach (var entry in update.AdditionalProperties)
{
// Use first-wins behavior to match the behavior of the other properties.
_ = message.AdditionalProperties.TryAdd(entry.Key, entry.Value);
}
}
}
}

/// <summary>Finalizes the <paramref name="completion"/> object by transferring the <paramref name="messages"/> into it.</summary>
/// <param name="messages">The messages to process further and transfer into <paramref name="completion"/>.</param>
/// <param name="completion">The result <see cref="ChatCompletion"/> being built.</param>
/// <param name="coalesceContent">The corresponding option value provided to <see cref="ToChatCompletion"/> or <see cref="ToChatCompletionAsync"/>.</param>
private static void AddMessagesToCompletion(Dictionary<int, ChatMessage> messages, ChatCompletion completion, bool coalesceContent)
{
foreach (var entry in messages)
{
if (entry.Value.Role == default)
{
entry.Value.Role = ChatRole.Assistant;
}

if (coalesceContent)
{
CoalesceTextContent((List<AIContent>)entry.Value.Contents);
}

completion.Choices.Add(entry.Value);

if (completion.Usage is null)
{
foreach (var content in entry.Value.Contents)
{
if (content is UsageContent c)
{
completion.Usage = c.Details;
break;
}
}
}
}
}

/// <summary>Coalesces sequential <see cref="TextContent"/> content elements.</summary>
private static void CoalesceTextContent(List<AIContent> contents)
{
StringBuilder? coalescedText = null;

// Iterate through all of the items in the list looking for contiguous items that can be coalesced.
int start = 0;
while (start < contents.Count - 1)
{
// We need at least two TextContents in a row to be able to coalesce.
if (contents[start] is not TextContent firstText)
{
start++;
continue;
}

if (contents[start + 1] is not TextContent secondText)
{
start += 2;
continue;
}

// Append the text from those nodes and continue appending subsequent TextContents until we run out.
// We null out nodes as their text is appended so that we can later remove them all in one O(N) operation.
coalescedText ??= new();
_ = coalescedText.Clear().Append(firstText.Text).Append(secondText.Text);
contents[start + 1] = null!;
int i = start + 2;
for (; i < contents.Count && contents[i] is TextContent next; i++)
{
_ = coalescedText.Append(next.Text);
contents[i] = null!;
}

// Store the replacement node.
contents[start] = new TextContent(coalescedText.ToString())
{
// We inherit the properties of the first text node. We don't currently propagate additional
// properties from the subsequent nodes. If we ever need to, we can add that here.
AdditionalProperties = firstText.AdditionalProperties?.Clone(),
};

start = i;
}

// Remove all of the null slots left over from the coalescing process.
_ = contents.RemoveAll(u => u is null);
}
}
Loading
Loading