From fe1d68fe5487e7e4e69b7d67c75a9fafea090775 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 27 Feb 2025 14:52:47 -0500 Subject: [PATCH 1/9] Remove KeepFunctionCallContentRemove Choices --- .../ChatCompletion/ChatResponse.cs | 100 +++------- .../ChatCompletion/ChatResponseUpdate.cs | 9 +- .../ChatResponseUpdateExtensions.cs | 100 +++------- .../AzureAIInferenceChatClient.cs | 7 +- .../OllamaChatClient.cs | 2 +- .../OpenAIModelMapper.ChatCompletion.cs | 7 +- .../ChatCompletion/ChatResponse{T}.cs | 5 +- .../FunctionInvokingChatClient.cs | 98 +-------- .../ChatCompletion/OpenTelemetryChatClient.cs | 14 +- .../ChatClientExtensionsTests.cs | 2 +- .../ChatCompletion/ChatResponseTests.cs | 187 ++---------------- .../ChatResponseUpdateExtensionsTests.cs | 64 ++---- .../ChatCompletion/ChatResponseUpdateTests.cs | 7 - .../DelegatingChatClientTests.cs | 2 +- .../AzureAIInferenceChatClientTests.cs | 1 - .../ChatClientIntegrationTests.cs | 10 +- .../PromptBasedFunctionCallingChatClient.cs | 4 +- .../ReducingChatClientTests.cs | 2 +- .../OllamaChatClientIntegrationTests.cs | 2 - .../OllamaChatClientTests.cs | 1 - .../OpenAIChatClientTests.cs | 1 - .../OpenAISerializationTests.cs | 24 +-- .../ConfigureOptionsChatClientTests.cs | 2 +- .../DistributedCachingChatClientTest.cs | 79 +++----- .../FunctionInvokingChatClientTests.cs | 158 ++------------- .../ChatCompletion/LoggingChatClientTests.cs | 2 +- .../UseDelegateChatClientTests.cs | 4 +- 27 files changed, 155 insertions(+), 739 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs index f789fc7f974..b6344c6e8fe 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; -using System.Text; using 
System.Text.Json.Serialization; using Microsoft.Shared.Diagnostics; @@ -12,50 +11,28 @@ namespace Microsoft.Extensions.AI; /// Represents the response to a chat request. public class ChatResponse { - /// The list of choices in the response. - private IList _choices; + /// The response message. + private ChatMessage _message; /// Initializes a new instance of the class. - /// The list of choices in the response, one message per choice. - [JsonConstructor] - public ChatResponse(IList choices) + public ChatResponse() { - _choices = Throw.IfNull(choices); + _message = new(ChatRole.Assistant, []); } /// Initializes a new instance of the class. - /// The chat message representing the singular choice in the response. + /// The response message. public ChatResponse(ChatMessage message) { _ = Throw.IfNull(message); - _choices = [message]; + _message = message; } - /// Gets or sets the list of chat response choices. - public IList Choices - { - get => _choices; - set => _choices = Throw.IfNull(value); - } - - /// Gets the chat response message. - /// - /// If there are multiple choices, this property returns the first choice. - /// If is empty, this property will throw. Use to access all choices directly. - /// - [JsonIgnore] + /// Gets or sets the chat response message. public ChatMessage Message { - get - { - var choices = Choices; - if (choices.Count == 0) - { - throw new InvalidOperationException($"The {nameof(ChatResponse)} instance does not contain any {nameof(ChatMessage)} choices."); - } - - return choices[0]; - } + get => _message; + set => _message = Throw.IfNull(value); } /// Gets or sets the ID of the chat response. @@ -96,26 +73,7 @@ public ChatMessage Message public AdditionalPropertiesDictionary? 
AdditionalProperties { get; set; } /// - public override string ToString() - { - if (Choices.Count == 1) - { - return Choices[0].ToString(); - } - - StringBuilder sb = new(); - for (int i = 0; i < Choices.Count; i++) - { - if (i > 0) - { - _ = sb.AppendLine().AppendLine(); - } - - _ = sb.Append("Choice ").Append(i).AppendLine(":").Append(Choices[i]); - } - - return sb.ToString(); - } + public override string ToString() => _message.ToString(); /// Creates an array of instances that represent this . /// An array of instances that may be used to represent this . @@ -135,33 +93,27 @@ public ChatResponseUpdate[] ToChatResponseUpdates() } } - int choicesCount = Choices.Count; - var updates = new ChatResponseUpdate[choicesCount + (extra is null ? 0 : 1)]; + var updates = new ChatResponseUpdate[extra is null ? 1 : 2]; - for (int choiceIndex = 0; choiceIndex < choicesCount; choiceIndex++) + updates[0] = new ChatResponseUpdate { - ChatMessage choice = Choices[choiceIndex]; - updates[choiceIndex] = new ChatResponseUpdate - { - ChatThreadId = ChatThreadId, - ChoiceIndex = choiceIndex, - - AdditionalProperties = choice.AdditionalProperties, - AuthorName = choice.AuthorName, - Contents = choice.Contents, - RawRepresentation = choice.RawRepresentation, - Role = choice.Role, - - ResponseId = ResponseId, - CreatedAt = CreatedAt, - FinishReason = FinishReason, - ModelId = ModelId - }; - } + ChatThreadId = ChatThreadId, + + AdditionalProperties = _message.AdditionalProperties, + AuthorName = _message.AuthorName, + Contents = _message.Contents, + RawRepresentation = _message.RawRepresentation, + Role = _message.Role, + + ResponseId = ResponseId, + CreatedAt = CreatedAt, + FinishReason = FinishReason, + ModelId = ModelId + }; if (extra is not null) { - updates[choicesCount] = extra; + updates[1] = extra; } return updates; diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs 
b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs index 8bf9e57ece2..696dc91cbe1 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs @@ -15,9 +15,7 @@ namespace Microsoft.Extensions.AI; /// /// is so named because it represents updates /// that layer on each other to form a single chat response. Conceptually, this combines the roles of -/// and in streaming output. For ease of consumption, -/// it also flattens the nested structure you see on streaming chunks in some AI services, so instead of a -/// dictionary of choices, each update is part of a single choice (and hence has its own role, choice ID, etc.). +/// and in streaming output. /// /// /// The relationship between and is @@ -26,7 +24,7 @@ namespace Microsoft.Extensions.AI; /// between the two. Note, however, that the provided conversions may be lossy, for example if multiple /// updates all have different objects whereas there's only one slot for /// such an object available in . Similarly, if different -/// updates that are part of the same choice provide different values for properties like , +/// updates provide different values for properties like , /// only one of the values will be used to populate . /// /// @@ -108,9 +106,6 @@ public IList Contents /// Gets or sets a timestamp for the response update. public DateTimeOffset? CreatedAt { get; set; } - /// Gets or sets the zero-based index of the choice with which this update is associated in the streaming sequence. - public int ChoiceIndex { get; set; } - /// Gets or sets the finish reason for the operation. public ChatFinishReason? 
FinishReason { get; set; } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs index 25104461cd9..3fb2eb139b5 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs @@ -2,10 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; -using System.Linq; -#if NET -using System.Runtime.InteropServices; -#endif using System.Text; using System.Threading; using System.Threading.Tasks; @@ -36,15 +32,14 @@ public static ChatResponse ToChatResponse( { _ = Throw.IfNull(updates); - ChatResponse response = new([]); - Dictionary messages = []; + ChatResponse response = new(new(default, [])); foreach (var update in updates) { - ProcessUpdate(update, messages, response); + ProcessUpdate(update, response); } - AddMessagesToResponse(messages, response, coalesceContent); + FinalizeResponse(response, coalesceContent); return response; } @@ -69,25 +64,23 @@ public static Task ToChatResponseAsync( static async Task ToChatResponseAsync( IAsyncEnumerable updates, bool coalesceContent, CancellationToken cancellationToken) { - ChatResponse response = new([]); - Dictionary messages = []; + ChatResponse response = new(new(default, [])); await foreach (var update in updates.WithCancellation(cancellationToken).ConfigureAwait(false)) { - ProcessUpdate(update, messages, response); + ProcessUpdate(update, response); } - AddMessagesToResponse(messages, response, coalesceContent); + FinalizeResponse(response, coalesceContent); return response; } } - /// Processes the , incorporating its contents into and . + /// Processes the , incorporating its contents into . /// The update to process. 
- /// The dictionary mapping to the being built for that choice. - /// The object whose properties should be updated based on . - private static void ProcessUpdate(ChatResponseUpdate update, Dictionary messages, ChatResponse response) + /// The object that should be updated based on . + private static void ProcessUpdate(ChatResponseUpdate update, ChatResponse response) { response.ChatThreadId ??= update.ChatThreadId; response.CreatedAt ??= update.CreatedAt; @@ -95,16 +88,6 @@ private static void ProcessUpdate(ChatResponseUpdate update, Dictionary()); -#else - if (!messages.TryGetValue(update.ChoiceIndex, out ChatMessage? message)) - { - messages[update.ChoiceIndex] = message = new(default, new List()); - } -#endif - // Incorporate all content from the update into the response. foreach (var content in update.Contents) { @@ -116,84 +99,45 @@ private static void ProcessUpdate(ChatResponseUpdate update, DictionaryFinalizes the object by transferring the into it. - /// The messages to process further and transfer into . - /// The result being built. - /// The corresponding option value provided to or . - private static void AddMessagesToResponse(Dictionary messages, ChatResponse response, bool coalesceContent) + /// Finalizes the object. + private static void FinalizeResponse(ChatResponse response, bool coalesceContent) { - if (messages.Count <= 1) + if (response.Message.Role == default) { - // Add the single message if there is one. - foreach (var entry in messages) - { - AddMessage(response, coalesceContent, entry); - } - - // In the vast majority case where there's only one choice, promote any additional properties - // from the single message to the chat response, making them more discoverable and more similar - // to how they're typically surfaced from non-streaming services. 
- if (response.Choices.Count == 1 && - response.Choices[0].AdditionalProperties is { } messageProps) - { - response.Choices[0].AdditionalProperties = null; - response.AdditionalProperties = messageProps; - } - } - else - { - // Add all of the messages, sorted by choice index. - foreach (var entry in messages.OrderBy(entry => entry.Key)) - { - AddMessage(response, coalesceContent, entry); - } - - // If there are multiple choices, we don't promote additional properties from the individual messages. - // At a minimum, we'd want to know which choice the additional properties applied to, and if there were - // conflicting values across the choices, it would be unclear which one should be used. + response.Message.Role = ChatRole.Assistant; } - static void AddMessage(ChatResponse response, bool coalesceContent, KeyValuePair entry) + if (coalesceContent) { - if (entry.Value.Role == default) - { - entry.Value.Role = ChatRole.Assistant; - } - - if (coalesceContent) - { - CoalesceTextContent((List)entry.Value.Contents); - } - - response.Choices.Add(entry.Value); + CoalesceTextContent((List)response.Message.Contents); } } diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs index 2f527612fab..7a44dbf3e57 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs @@ -91,9 +91,6 @@ public async Task GetResponseAsync( cancellationToken: cancellationToken).ConfigureAwait(false)).Value; // Create the return message. - List returnMessages = []; - - // Populate its content from those in the response content. ChatMessage message = new() { RawRepresentation = response, @@ -119,8 +116,6 @@ public async Task GetResponseAsync( } } - returnMessages.Add(message); - UsageDetails? 
usage = null; if (response.Usage is CompletionsUsage completionsUsage) { @@ -133,7 +128,7 @@ public async Task GetResponseAsync( } // Wrap the content in a ChatResponse to return. - return new ChatResponse(returnMessages) + return new ChatResponse(message) { CreatedAt = response.Created, ModelId = response.Model, diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs index d3f45358d10..ae18a430c45 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs @@ -102,7 +102,7 @@ public async Task GetResponseAsync(IList chatMessages throw new InvalidOperationException($"Ollama error: {response.Error}"); } - return new([FromOllamaMessage(response.Message!)]) + return new(FromOllamaMessage(response.Message!)) { CreatedAt = DateTimeOffset.TryParse(response.CreatedAt, CultureInfo.InvariantCulture, DateTimeStyles.None, out DateTimeOffset createdAt) ? createdAt : null, FinishReason = ToFinishReason(response), diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs index f5c21be3678..e67fa627f3f 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs @@ -30,11 +30,6 @@ public static ChatCompletion ToOpenAIChatCompletion(ChatResponse response, JsonS { _ = Throw.IfNull(response); - if (response.Choices.Count > 1) - { - throw new NotSupportedException("Creating OpenAI ChatCompletion models with multiple choices is currently not supported."); - } - List? 
toolCalls = null; foreach (AIContent content in response.Message.Contents) { @@ -138,7 +133,7 @@ public static ChatResponse FromOpenAIChatCompletion(ChatCompletion openAIComplet } // Wrap the content in a ChatResponse to return. - var response = new ChatResponse([returnMessage]) + var response = new ChatResponse(returnMessage) { CreatedAt = openAICompletion.CreatedAt, FinishReason = FromOpenAIFinishReason(openAICompletion.FinishReason), diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatResponse{T}.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatResponse{T}.cs index e78a0acf1f5..a02792fbcf3 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatResponse{T}.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatResponse{T}.cs @@ -31,7 +31,7 @@ public class ChatResponse : ChatResponse /// The unstructured that is being wrapped. /// The to use when deserializing the result. public ChatResponse(ChatResponse response, JsonSerializerOptions serializerOptions) - : base(Throw.IfNull(response).Choices) + : base(Throw.IfNull(response).Message) { _serializerOptions = Throw.IfNull(serializerOptions); AdditionalProperties = response.AdditionalProperties; @@ -116,8 +116,7 @@ public bool TryGetResult([NotNullWhen(true)] out T? result) private string? GetResultAsJson() { - var choice = Choices.Count == 1 ? Choices[0] : null; - var content = choice?.Contents.Count == 1 ? choice.Contents[0] : null; + var content = Message.Contents.Count == 1 ? 
Message.Contents[0] : null; return (content as TextContent)?.Text; } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index a64ebf7d61d..af93f59485c 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -139,40 +139,6 @@ public static FunctionInvocationContext? CurrentContext /// public bool AllowConcurrentInvocation { get; set; } - /// - /// Gets or sets a value indicating whether to keep intermediate function calling request - /// and response messages in the chat history. - /// - /// - /// if intermediate messages persist in the list provided - /// to and by the caller. - /// if intermediate messages are removed prior to completing the operation. - /// The default value is . - /// - /// - /// - /// When the inner returns to the - /// , the adds - /// those messages to the list of messages, along with instances - /// it creates with the results of invoking the requested functions. The resulting augmented - /// list of messages is then passed to the inner client in order to send the results back. - /// By default, those messages persist in the list provided to - /// and by the caller, such that those - /// messages are available to the caller. Set to avoid including - /// those messages in the caller-provided . - /// - /// - /// Changing the value of this property while the client is in use might result in inconsistencies - /// as to whether function calling messages are kept during an in-flight request. - /// - /// - /// If the underlying responds with - /// set to a non- value, this property may be ignored and behave as if it is - /// , with any such intermediate messages not stored in the messages list. 
- /// - /// - public bool KeepFunctionCallingContent { get; set; } = true; - /// /// Gets or sets the maximum number of iterations per request. /// @@ -237,24 +203,12 @@ public override async Task GetResponseAsync(IList cha // If there are no tools to call, or for any other reason we should stop, return the response. if (options is null || options.Tools is not { Count: > 0 } - || response.Choices.Count == 0 || (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations)) { break; } - // If there's more than one choice, we don't know which one to add to chat history, or which - // of their function calls to process. This should not happen except if the developer has - // explicitly requested multiple choices. We fail aggressively to avoid cases where a developer - // doesn't realize this and is wasting their budget requesting extra choices we'd never use. - if (response.Choices.Count > 1) - { - ThrowForMultipleChoices(); - } - - // Extract any function call contents on the first choice. If there are none, we're done. - // We don't have any way to express a preference to use a different choice, since this - // is a niche case especially with function calling. + // Extract any function call contents. If there are none, we're done. FunctionCallContent[] functionCallContents = response.Message.Contents.OfType().ToArray(); if (functionCallContents.Length == 0) { @@ -276,27 +230,6 @@ public override async Task GetResponseAsync(IList cha } else { - // Otherwise, we need to add the response message to the history we're sending back. However, if the caller - // doesn't want the intermediate messages, create a new list that we mutate instead of mutating the original. - if (!KeepFunctionCallingContent) - { - // Create a new list that will include the message with the function call contents. - if (chatMessages == originalChatMessages) - { - chatMessages = [.. 
chatMessages]; - } - - // We want to include any non-functional calling content, if there is any, - // in the caller's list so that they don't lose out on actual content. - // This can happen but is relatively rare. - if (response.Message.Contents.Any(c => c is not FunctionCallContent)) - { - var clone = response.Message.Clone(); - clone.Contents = clone.Contents.Where(c => c is not FunctionCallContent).ToList(); - originalChatMessages.Add(clone); - } - } - // Add the original response message into the history. chatMessages.Add(response.Message); } @@ -332,11 +265,9 @@ public override async IAsyncEnumerable GetStreamingResponseA using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient)); List functionCallContents = []; - int? choice; IList originalChatMessages = chatMessages; for (int iteration = 0; ; iteration++) { - choice = null; string? chatThreadId = null; functionCallContents.Clear(); await foreach (var update in base.GetStreamingResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) @@ -363,16 +294,6 @@ public override async IAsyncEnumerable GetStreamingResponseA [] : update.Contents.Where(c => c is not FunctionCallContent).ToList(); } - // Only one choice is allowed with automatic function calling. - if (choice is null) - { - choice = update.ChoiceIndex; - } - else if (choice != update.ChoiceIndex) - { - ThrowForMultipleChoices(); - } - chatThreadId ??= update.ChatThreadId; yield return update; @@ -403,13 +324,6 @@ public override async IAsyncEnumerable GetStreamingResponseA } else { - // Otherwise, we need to add the response message to the history we're sending back. However, if the caller - // doesn't want the intermediate messages, create a new list that we mutate instead of mutating the original. - if (chatMessages == originalChatMessages && !KeepFunctionCallingContent) - { - chatMessages = [.. 
chatMessages]; - } - // Add a manufactured response message containing the function call contents to the chat history. chatMessages.Add(new(ChatRole.Assistant, [.. functionCallContents])); } @@ -424,16 +338,6 @@ public override async IAsyncEnumerable GetStreamingResponseA } } - /// Throws an exception when multiple choices are received. - private static void ThrowForMultipleChoices() - { - // If there's more than one choice, we don't know which one to add to chat history, or which - // of their function calls to process. This should not happen except if the developer has - // explicitly requested multiple choices. We fail aggressively to avoid cases where a developer - // doesn't realize this and is wasting their budget requesting extra choices we'd never use. - throw new InvalidOperationException("Automatic function call invocation only accepts a single choice, but multiple choices were received."); - } - /// Updates for the response. /// true if the function calling loop should terminate; otherwise, false. private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions options, string? chatThreadId) diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs index 1ae5f83b4b2..9809872f1d0 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs @@ -482,16 +482,12 @@ private void LogChatResponse(ChatResponse response) } EventId id = new(1, OpenTelemetryConsts.GenAI.Choice); - int choiceCount = response.Choices.Count; - for (int choiceIndex = 0; choiceIndex < choiceCount; choiceIndex++) + Log(id, JsonSerializer.Serialize(new() { - Log(id, JsonSerializer.Serialize(new() - { - FinishReason = response.FinishReason?.Value ?? 
"error", - Index = choiceIndex, - Message = CreateAssistantEvent(response.Choices[choiceIndex]), - }, OtelContext.Default.ChoiceEvent)); - } + FinishReason = response.FinishReason?.Value ?? "error", + Index = 0, + Message = CreateAssistantEvent(response.Message), + }, OtelContext.Default.ChoiceEvent)); } private void Log(EventId id, [StringSyntax(StringSyntaxAttribute.Json)] string eventBodyJson) diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs index 5a95f2b3fd0..9671d2bc602 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs @@ -100,7 +100,7 @@ public void GetStreamingResponseAsync_InvalidArgs_Throws() [Fact] public async Task GetResponseAsync_CreatesTextMessageAsync() { - var expectedResponse = new ChatResponse([new ChatMessage()]); + var expectedResponse = new ChatResponse(); var expectedOptions = new ChatOptions(); using var cts = new CancellationTokenSource(); diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs index e222b6d5215..f3536bd116f 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. 
using System; -using System.Collections.Generic; using System.Text.Json; using Xunit; @@ -13,78 +12,29 @@ public class ChatResponseTests [Fact] public void Constructor_InvalidArgs_Throws() { - Assert.Throws("message", () => new ChatResponse((ChatMessage)null!)); - Assert.Throws("choices", () => new ChatResponse((IList)null!)); + Assert.Throws("message", () => new ChatResponse(null!)); } [Fact] public void Constructor_Message_Roundtrips() { - ChatMessage message = new(); - - ChatResponse response = new(message); - Assert.Same(message, response.Message); - Assert.Same(message, Assert.Single(response.Choices)); - } - - [Fact] - public void Constructor_Choices_Roundtrips() - { - List messages = - [ - new ChatMessage(), - new ChatMessage(), - new ChatMessage(), - ]; - - ChatResponse response = new(messages); - Assert.Same(messages, response.Choices); - Assert.Equal(3, messages.Count); - } - - [Fact] - public void Message_EmptyChoices_Throws() - { - ChatResponse response = new([]); + ChatResponse response = new(); + Assert.NotNull(response.Message); + Assert.Same(response.Message, response.Message); - Assert.Empty(response.Choices); - Assert.Throws(() => response.Message); - } - - [Fact] - public void Message_SingleChoice_Returned() - { ChatMessage message = new(); - ChatResponse response = new([message]); - + response = new(message); Assert.Same(message, response.Message); - Assert.Same(message, response.Choices[0]); - } - - [Fact] - public void Message_MultipleChoices_ReturnsFirst() - { - ChatMessage first = new(); - ChatResponse response = new([ - first, - new ChatMessage(), - ]); - - Assert.Same(first, response.Message); - Assert.Same(first, response.Choices[0]); - } - [Fact] - public void Choices_SetNull_Throws() - { - ChatResponse response = new([]); - Assert.Throws("value", () => response.Choices = null!); + message = new(); + response.Message = message; + Assert.Same(message, response.Message); } [Fact] public void Properties_Roundtrip() { - ChatResponse response 
= new([]); + ChatResponse response = new(); Assert.Null(response.ResponseId); response.ResponseId = "id"; @@ -116,22 +66,12 @@ public void Properties_Roundtrip() AdditionalPropertiesDictionary additionalProps = []; response.AdditionalProperties = additionalProps; Assert.Same(additionalProps, response.AdditionalProperties); - - List newChoices = [new ChatMessage(), new ChatMessage()]; - response.Choices = newChoices; - Assert.Same(newChoices, response.Choices); } [Fact] public void JsonSerialization_Roundtrips() { - ChatResponse original = new( - [ - new ChatMessage(ChatRole.Assistant, "Choice1"), - new ChatMessage(ChatRole.Assistant, "Choice2"), - new ChatMessage(ChatRole.Assistant, "Choice3"), - new ChatMessage(ChatRole.Assistant, "Choice4"), - ]) + ChatResponse original = new(new(ChatRole.Assistant, "the message")) { ResponseId = "id", ModelId = "modelId", @@ -147,13 +87,8 @@ public void JsonSerialization_Roundtrips() ChatResponse? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatResponse); Assert.NotNull(result); - Assert.Equal(4, result.Choices.Count); - - for (int i = 0; i < original.Choices.Count; i++) - { - Assert.Equal(ChatRole.Assistant, result.Choices[i].Role); - Assert.Equal($"Choice{i + 1}", result.Choices[i].Text); - } + Assert.Equal(ChatRole.Assistant, result.Message.Role); + Assert.Equal("the message", result.Message.Text); Assert.Equal("id", result.ResponseId); Assert.Equal("modelId", result.ModelId); @@ -169,41 +104,15 @@ public void JsonSerialization_Roundtrips() } [Fact] - public void ToString_OneChoice_OutputsChatMessageToString() + public void ToString_OutputsChatMessageToString() { - ChatResponse response = new( - [ - new ChatMessage(ChatRole.Assistant, "This is a test." 
+ Environment.NewLine + "It's multiple lines.") - ]); + ChatResponse response = new(new(ChatRole.Assistant, $"This is a test.{Environment.NewLine}It's multiple lines.")); - Assert.Equal(response.Choices[0].Text, response.ToString()); + Assert.Equal(response.Message.ToString(), response.ToString()); } [Fact] - public void ToString_MultipleChoices_OutputsAllChoicesWithPrefix() - { - ChatResponse response = new( - [ - new ChatMessage(ChatRole.Assistant, "This is a test." + Environment.NewLine + "It's multiple lines."), - new ChatMessage(ChatRole.Assistant, "So is" + Environment.NewLine + " this."), - new ChatMessage(ChatRole.Assistant, "And this."), - ]); - - Assert.Equal( - "Choice 0:" + Environment.NewLine + - response.Choices[0] + Environment.NewLine + Environment.NewLine + - - "Choice 1:" + Environment.NewLine + - response.Choices[1] + Environment.NewLine + Environment.NewLine + - - "Choice 2:" + Environment.NewLine + - response.Choices[2], - - response.ToString()); - } - - [Fact] - public void ToChatResponseUpdates_SingleChoice() + public void ToChatResponseUpdates() { ChatResponse response = new(new ChatMessage(new ChatRole("customRole"), "Text")) { @@ -230,68 +139,4 @@ public void ToChatResponseUpdates_SingleChoice() Assert.Equal("value1", update1.AdditionalProperties?["key1"]); Assert.Equal(42, update1.AdditionalProperties?["key2"]); } - - [Fact] - public void ToChatResponseUpdates_MultiChoice() - { - ChatResponse response = new( - [ - new ChatMessage(ChatRole.Assistant, - [ - new TextContent("Hello, "), - new DataContent("http://localhost/image.png", mediaType: "image/png"), - new TextContent("world!"), - ]) - { - AdditionalProperties = new() { ["choice1Key"] = "choice1Value" }, - }, - - new ChatMessage(ChatRole.System, - [ - new FunctionCallContent("call123", "name"), - new FunctionResultContent("call123", 42), - ]) - { - AdditionalProperties = new() { ["choice2Key"] = "choice2Value" }, - }, - ]) - { - ResponseId = "12345", - ModelId = "someModel", - 
FinishReason = ChatFinishReason.ContentFilter, - CreatedAt = new DateTimeOffset(2024, 11, 10, 9, 20, 0, TimeSpan.Zero), - AdditionalProperties = new() { ["key1"] = "value1", ["key2"] = 42 }, - Usage = new UsageDetails { TotalTokenCount = 123 }, - }; - - ChatResponseUpdate[] updates = response.ToChatResponseUpdates(); - Assert.NotNull(updates); - Assert.Equal(3, updates.Length); - - ChatResponseUpdate update0 = updates[0]; - Assert.Equal("12345", update0.ResponseId); - Assert.Equal("someModel", update0.ModelId); - Assert.Equal(ChatFinishReason.ContentFilter, update0.FinishReason); - Assert.Equal(new DateTimeOffset(2024, 11, 10, 9, 20, 0, TimeSpan.Zero), update0.CreatedAt); - Assert.Equal("assistant", update0.Role?.Value); - Assert.Equal("Hello, ", Assert.IsType(update0.Contents[0]).Text); - Assert.Equal("image/png", Assert.IsType(update0.Contents[1]).MediaType); - Assert.Equal("world!", Assert.IsType(update0.Contents[2]).Text); - Assert.Equal("choice1Value", update0.AdditionalProperties?["choice1Key"]); - - ChatResponseUpdate update1 = updates[1]; - Assert.Equal("12345", update1.ResponseId); - Assert.Equal("someModel", update1.ModelId); - Assert.Equal(ChatFinishReason.ContentFilter, update1.FinishReason); - Assert.Equal(new DateTimeOffset(2024, 11, 10, 9, 20, 0, TimeSpan.Zero), update1.CreatedAt); - Assert.Equal("system", update1.Role?.Value); - Assert.IsType(update1.Contents[0]); - Assert.IsType(update1.Contents[1]); - Assert.Equal("choice2Value", update1.AdditionalProperties?["choice2Key"]); - - ChatResponseUpdate update2 = updates[2]; - Assert.Equal("value1", update2.AdditionalProperties?["key1"]); - Assert.Equal(42, update2.AdditionalProperties?["key2"]); - Assert.Equal(123, Assert.IsType(Assert.Single(update2.Contents)).Details.TotalTokenCount); - } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs 
b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs index fea25191aff..d818420359a 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs @@ -38,17 +38,12 @@ public async Task ToChatResponse_SuccessfullyCreatesResponse(bool useAsync, bool { ChatResponseUpdate[] updates = [ - new() { ChoiceIndex = 0, Text = "Hello", ResponseId = "12345", CreatedAt = new DateTimeOffset(1, 2, 3, 4, 5, 6, TimeSpan.Zero), ModelId = "model123" }, - new() { ChoiceIndex = 1, Text = "Hey", ResponseId = "12345", CreatedAt = new DateTimeOffset(1, 2, 3, 4, 5, 6, TimeSpan.Zero), ModelId = "model124" }, + new() { Text = "Hello", ResponseId = "12345", CreatedAt = new DateTimeOffset(1, 2, 3, 4, 5, 6, TimeSpan.Zero), ModelId = "model123" }, + new() { Text = ", ", AuthorName = "Someone", Role = new ChatRole("human"), AdditionalProperties = new() { ["a"] = "b" } }, + new() { Text = "world!", CreatedAt = new DateTimeOffset(2, 2, 3, 4, 5, 6, TimeSpan.Zero), ChatThreadId = "123", AdditionalProperties = new() { ["c"] = "d" } }, - new() { ChoiceIndex = 0, Text = ", ", AuthorName = "Someone", Role = ChatRole.User, AdditionalProperties = new() { ["a"] = "b" } }, - new() { ChoiceIndex = 1, Text = ", ", AuthorName = "Else", Role = ChatRole.System, ChatThreadId = "123", AdditionalProperties = new() { ["g"] = "h" } }, - - new() { ChoiceIndex = 0, Text = "world!", CreatedAt = new DateTimeOffset(2, 2, 3, 4, 5, 6, TimeSpan.Zero), AdditionalProperties = new() { ["c"] = "d" } }, - new() { ChoiceIndex = 1, Text = "you!", Role = ChatRole.Tool, CreatedAt = new DateTimeOffset(3, 2, 3, 4, 5, 6, TimeSpan.Zero), AdditionalProperties = new() { ["e"] = "f", ["i"] = 42 } }, - - new() { ChoiceIndex = 0, Contents = new[] { new UsageContent(new() { InputTokenCount = 1, OutputTokenCount = 2 
}) } }, - new() { ChoiceIndex = 3, Contents = new[] { new UsageContent(new() { InputTokenCount = 4, OutputTokenCount = 5 }) } }, + new() { Contents = new[] { new UsageContent(new() { InputTokenCount = 1, OutputTokenCount = 2 }) } }, + new() { Contents = new[] { new UsageContent(new() { InputTokenCount = 4, OutputTokenCount = 5 }) } }, ]; ChatResponse response = (coalesceContent is bool, useAsync) switch @@ -71,48 +66,25 @@ public async Task ToChatResponse_SuccessfullyCreatesResponse(bool useAsync, bool Assert.Equal("123", response.ChatThreadId); - Assert.Equal(3, response.Choices.Count); - - ChatMessage message = response.Choices[0]; - Assert.Equal(ChatRole.User, message.Role); + ChatMessage message = response.Message; + Assert.Equal(new ChatRole("human"), message.Role); Assert.Equal("Someone", message.AuthorName); - Assert.NotNull(message.AdditionalProperties); - Assert.Equal(2, message.AdditionalProperties.Count); - Assert.Equal("b", message.AdditionalProperties["a"]); - Assert.Equal("d", message.AdditionalProperties["c"]); - - message = response.Choices[1]; - Assert.Equal(ChatRole.System, message.Role); - Assert.Equal("Else", message.AuthorName); - Assert.NotNull(message.AdditionalProperties); - Assert.Equal(3, message.AdditionalProperties.Count); - Assert.Equal("h", message.AdditionalProperties["g"]); - Assert.Equal("f", message.AdditionalProperties["e"]); - Assert.Equal(42, message.AdditionalProperties["i"]); - - message = response.Choices[2]; - Assert.Equal(ChatRole.Assistant, message.Role); - Assert.Null(message.AuthorName); Assert.Null(message.AdditionalProperties); - Assert.Empty(message.Contents); + + Assert.NotNull(response.AdditionalProperties); + Assert.Equal(2, response.AdditionalProperties.Count); + Assert.Equal("b", response.AdditionalProperties["a"]); + Assert.Equal("d", response.AdditionalProperties["c"]); if (coalesceContent is null or true) { - Assert.Equal("Hello, world!", response.Choices[0].Text); - Assert.Equal("Hey, you!", 
response.Choices[1].Text); - Assert.Null(response.Choices[2].Text); + Assert.Equal("Hello, world!", response.Message.Text); } else { - Assert.Equal("Hello", response.Choices[0].Contents[0].ToString()); - Assert.Equal(", ", response.Choices[0].Contents[1].ToString()); - Assert.Equal("world!", response.Choices[0].Contents[2].ToString()); - - Assert.Equal("Hey", response.Choices[1].Contents[0].ToString()); - Assert.Equal(", ", response.Choices[1].Contents[1].ToString()); - Assert.Equal("you!", response.Choices[1].Contents[2].ToString()); - - Assert.Null(response.Choices[2].Text); + Assert.Equal("Hello", response.Message.Contents[0].ToString()); + Assert.Equal(", ", response.Message.Contents[1].ToString()); + Assert.Equal("world!", response.Message.Contents[2].ToString()); } } @@ -181,9 +153,11 @@ void AddGap() } ChatResponse response = useAsync ? await YieldAsync(updates).ToChatResponseAsync() : updates.ToChatResponse(); - Assert.Single(response.Choices); + Assert.NotNull(response); ChatMessage message = response.Message; + Assert.NotNull(message); + Assert.Equal(expected.Count + (gapLength * ((numSequences - 1) + (gapBeginningEnd ? 
2 : 0))), message.Contents.Count); TextContent[] contents = message.Contents.OfType().ToArray(); diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs index be4108f8148..4bb9e5ae0b3 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs @@ -23,7 +23,6 @@ public void Constructor_PropsDefaulted() Assert.Null(update.ResponseId); Assert.Null(update.CreatedAt); Assert.Null(update.FinishReason); - Assert.Equal(0, update.ChoiceIndex); Assert.Equal(string.Empty, update.ToString()); } @@ -74,10 +73,6 @@ public void Properties_Roundtrip() update.CreatedAt = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero); Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), update.CreatedAt); - Assert.Equal(0, update.ChoiceIndex); - update.ChoiceIndex = 42; - Assert.Equal(42, update.ChoiceIndex); - Assert.Null(update.FinishReason); update.FinishReason = ChatFinishReason.ContentFilter; Assert.Equal(ChatFinishReason.ContentFilter, update.FinishReason); @@ -179,7 +174,6 @@ public void JsonSerialization_Roundtrips() CreatedAt = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), FinishReason = ChatFinishReason.ContentFilter, AdditionalProperties = new() { ["key"] = "value" }, - ChoiceIndex = 42, }; string json = JsonSerializer.Serialize(original, TestJsonSerializerContext.Default.ChatResponseUpdate); @@ -209,7 +203,6 @@ public void JsonSerialization_Roundtrips() Assert.Equal("id", result.ResponseId); Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), result.CreatedAt); Assert.Equal(ChatFinishReason.ContentFilter, result.FinishReason); - Assert.Equal(42, result.ChoiceIndex); Assert.NotNull(result.AdditionalProperties); 
Assert.Single(result.AdditionalProperties); diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs index d7d265018b0..5d9170c77e3 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs @@ -25,7 +25,7 @@ public async Task ChatAsyncDefaultsToInnerClientAsync() var expectedChatOptions = new ChatOptions(); var expectedCancellationToken = CancellationToken.None; var expectedResult = new TaskCompletionSource(); - var expectedResponse = new ChatResponse([]); + var expectedResponse = new ChatResponse(); using var inner = new TestChatClient { GetResponseAsyncCallback = (chatContents, options, cancellationToken) => diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs index c86dbd756b5..7d2d0a6c9ab 100644 --- a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs @@ -811,7 +811,6 @@ public async Task FunctionCallContent_NonStreaming(ChatToolMode mode) Assert.Equal(16, response.Usage.OutputTokenCount); Assert.Equal(77, response.Usage.TotalTokenCount); - Assert.Single(response.Choices); Assert.Single(response.Message.Contents); FunctionCallContent fcc = Assert.IsType(response.Message.Contents[0]); Assert.Equal("GetPersonAge", fcc.Name); diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs index eaf4834e60d..581e21daaea 
100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs @@ -65,13 +65,12 @@ public virtual async Task GetResponseAsync_MultipleRequestMessages() new(ChatRole.User, "What continent are they each in?"), ]); - Assert.Single(response.Choices); Assert.Contains("America", response.Message.Text); Assert.Contains("Asia", response.Message.Text); } [ConditionalFact] - public virtual async Task GetStreamingResponseAsync_SingleStreamingResponseChoice() + public virtual async Task GetStreamingResponseAsync() { SkipIfNotEnabled(); @@ -101,7 +100,6 @@ public virtual async Task GetResponseAsync_UsageDataAvailable() var response = await _chatClient.GetResponseAsync("Explain in 10 words how AI works"); - Assert.Single(response.Choices); Assert.True(response.Usage?.InputTokenCount > 1); Assert.True(response.Usage?.OutputTokenCount > 1); Assert.Equal(response.Usage?.InputTokenCount + response.Usage?.OutputTokenCount, response.Usage?.TotalTokenCount); @@ -151,7 +149,6 @@ public virtual async Task MultiModal_DescribeImage() ], new() { ModelId = GetModel_MultiModal_DescribeImage() }); - Assert.Single(response.Choices); Assert.True(response.Message.Text?.IndexOf("net", StringComparison.OrdinalIgnoreCase) >= 0, response.Message.Text); } @@ -182,7 +179,6 @@ public virtual async Task FunctionInvocation_AutomaticallyInvokeFunction_Paramet Tools = [AIFunctionFactory.Create(() => secretNumber, "GetSecretNumber")] }); - Assert.Single(response.Choices); Assert.Contains(secretNumber.ToString(), response.Message.Text); // If the underlying IChatClient provides usage data, function invocation should aggregate the @@ -208,7 +204,6 @@ public virtual async Task FunctionInvocation_AutomaticallyInvokeFunction_WithPar Tools = [AIFunctionFactory.Create((int a, int b) => a * b, "SecretComputation")] }); - Assert.Single(response.Choices); Assert.Contains("3528", 
response.Message.Text); } @@ -285,7 +280,6 @@ public virtual async Task FunctionInvocation_RequireAny() ToolMode = ChatToolMode.RequireAny, }); - Assert.Single(response.Choices); Assert.True(callCount >= 1); } @@ -317,7 +311,6 @@ public virtual async Task Caching_OutputVariesWithoutCaching() var message = new ChatMessage(ChatRole.User, "Pick a random number, uniformly distributed between 1 and 1000000"); var firstResponse = await _chatClient.GetResponseAsync([message]); - Assert.Single(firstResponse.Choices); var secondResponse = await _chatClient.GetResponseAsync([message]); Assert.NotEqual(firstResponse.Message.Text, secondResponse.Message.Text); @@ -334,7 +327,6 @@ public virtual async Task Caching_SamePromptResultsInCacheHit_NonStreaming() var message = new ChatMessage(ChatRole.User, "Pick a random number, uniformly distributed between 1 and 1000000"); var firstResponse = await chatClient.GetResponseAsync([message]); - Assert.Single(firstResponse.Choices); // No matter what it said before, we should see identical output due to caching for (int i = 0; i < 3; i++) diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs index 1cf786bb288..96fab4d244c 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs @@ -82,10 +82,10 @@ public override async Task GetResponseAsync(IList cha var result = await base.GetResponseAsync(chatMessages, options, cancellationToken); - if (result.Choices.FirstOrDefault()?.Text is { } content && content.IndexOf("", StringComparison.Ordinal) is int startPos + if (result.Message.Text is { } content && content.IndexOf("", StringComparison.Ordinal) is int startPos && startPos >= 0) { - var message = result.Choices.First(); + var message = 
result.Message; var contentItem = message.Contents.SingleOrDefault(); content = content.Substring(startPos); diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs index eeccd609a93..ec7ca3c2cf0 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs @@ -33,7 +33,7 @@ public async Task Reduction_LimitsMessagesBasedOnTokenLimit() Assert.Collection(messages, m => Assert.StartsWith("Golden retrievers are quite active", m.Text, StringComparison.Ordinal), m => Assert.StartsWith("Are they good with kids?", m.Text, StringComparison.Ordinal)); - return Task.FromResult(new ChatResponse([])); + return Task.FromResult(new ChatResponse()); } }; diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs index 09328dd8ce6..43fad43438c 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs @@ -47,7 +47,6 @@ public async Task PromptBasedFunctionCalling_NoArgs() Seed = 0, }); - Assert.Single(response.Choices); Assert.Contains(secretNumber.ToString(), response.Message.Text); } @@ -82,7 +81,6 @@ public async Task PromptBasedFunctionCalling_WithArgs() Seed = 0, }); - Assert.Single(response.Choices); Assert.Contains("999", response.Message.Text); Assert.False(didCallIrrelevantTool); } diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs index 439ca29a3ec..dc173307921 100644 --- 
a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs @@ -380,7 +380,6 @@ public async Task FunctionCallContent_NonStreaming() Assert.Equal(19, response.Usage.OutputTokenCount); Assert.Equal(189, response.Usage.TotalTokenCount); - Assert.Single(response.Choices); Assert.Single(response.Message.Contents); FunctionCallContent fcc = Assert.IsType(response.Message.Contents[0]); Assert.Equal("GetPersonAge", fcc.Name); diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs index 789f0abeb63..6495d5f957e 100644 --- a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs @@ -791,7 +791,6 @@ public async Task FunctionCallContent_NonStreaming() { "OutputTokenDetails.RejectedPredictionTokenCount", 0 }, }, response.Usage.AdditionalCounts); - Assert.Single(response.Choices); Assert.Single(response.Message.Contents); FunctionCallContent fcc = Assert.IsType(response.Message.Contents[0]); Assert.Equal("GetPersonAge", fcc.Name); diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAISerializationTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAISerializationTests.cs index 51947ae0c8e..752e44dc388 100644 --- a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAISerializationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAISerializationTests.cs @@ -460,7 +460,7 @@ public static async Task RequestDeserialization_ToolChatMessage() } [Fact] - public static async Task SerializeResponse_SingleChoice() + public static async Task SerializeResponse() { ChatMessage message = new() { @@ -558,28 +558,6 @@ public static async Task SerializeResponse_SingleChoice() """, result); } - [Fact] - public static async 
Task SerializeResponse_ManyChoices_ThrowsNotSupportedException() - { - ChatMessage message1 = new() - { - Role = ChatRole.Assistant, - Text = "Hello! How can I assist you today?", - }; - - ChatMessage message2 = new() - { - Role = ChatRole.Assistant, - Text = "Hey there! How can I help?", - }; - - ChatResponse response = new([message1, message2]); - - using MemoryStream stream = new(); - var ex = await Assert.ThrowsAsync(() => OpenAISerializationHelpers.SerializeAsync(stream, response)); - Assert.Contains("multiple choices", ex.Message); - } - [Fact] public static async Task SerializeStreamingResponse() { diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs index d1ae1c21ebe..0cddb58e006 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs @@ -34,7 +34,7 @@ public async Task ConfigureOptions_ReturnedInstancePassedToNextClient(bool nullP { ChatOptions? providedOptions = nullProvidedOptions ? null : new() { ModelId = "test" }; ChatOptions? 
returnedOptions = null; - ChatResponse expectedResponse = new(Array.Empty()); + ChatResponse expectedResponse = new(); var expectedUpdates = Enumerable.Range(0, 3).Select(i => new ChatResponseUpdate()).ToArray(); using CancellationTokenSource cts = new(); diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs index b87b866c50f..ab67d0e3376 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs @@ -39,7 +39,7 @@ public async Task CachesSuccessResultsAsync() // Verify that all the expected properties will round-trip through the cache, // even if this involves serialization - var expectedResponse = new ChatResponse([ + var expectedResponse = new ChatResponse( new(new ChatRole("fakeRole"), "This is some content") { AdditionalProperties = new() { ["a"] = "b" }, @@ -52,8 +52,7 @@ public async Task CachesSuccessResultsAsync() ["arg5"] = false, ["arg6"] = null })] - } - ]) + }) { ResponseId = "someId", Usage = new() @@ -111,7 +110,7 @@ public async Task AllowsConcurrentCallsAsync() { innerCallCount++; await completionTcs.Task; - return new ChatResponse([new(ChatRole.Assistant, "Hello")]); + return new ChatResponse(new(ChatRole.Assistant, "Hello")); } }; using var outer = new DistributedCachingChatClient(testClient, _storage) @@ -185,7 +184,7 @@ public async Task DoesNotCacheCanceledResultsAsync() await resolutionTcs.Task; } - return new ChatResponse([new(ChatRole.Assistant, "A good result")]); + return new ChatResponse(new(ChatRole.Assistant, "A good result")); } }; using var outer = new DistributedCachingChatClient(testClient, _storage) @@ -217,13 +216,6 @@ public async Task StreamingCachesSuccessResultsAsync() // even if this involves serialization List actualUpdate = 
[ - new() - { - Role = new ChatRole("fakeRole1"), - ChoiceIndex = 1, - AdditionalProperties = new() { ["a"] = "b" }, - Contents = [new TextContent("Chunk1")] - }, new() { Role = new ChatRole("fakeRole2"), @@ -243,13 +235,6 @@ public async Task StreamingCachesSuccessResultsAsync() Contents = [new FunctionCallContent("someCallId", "someFn", new Dictionary { ["arg1"] = "value1" })], }, new() - { - Role = new ChatRole("fakeRole1"), - ChoiceIndex = 1, - AdditionalProperties = new() { ["a"] = "b" }, - Contents = [new TextContent("Chunk1")] - }, - new() { Contents = [new UsageContent(new() { InputTokenCount = 123, OutputTokenCount = 456, TotalTokenCount = 99999 })], }, @@ -539,7 +524,7 @@ public async Task CacheKeyVariesByChatOptionsAsync() { innerCallCount++; await Task.Yield(); - return new([new(ChatRole.Assistant, options!.AdditionalProperties!["someKey"]!.ToString())]); + return new(new(ChatRole.Assistant, options!.AdditionalProperties!["someKey"]!.ToString())); } }; using var outer = new DistributedCachingChatClient(testClient, _storage) @@ -590,7 +575,7 @@ public async Task SubclassCanOverrideCacheKeyToVaryByChatOptionsAsync() { innerCallCount++; await Task.Yield(); - return new([new(ChatRole.Assistant, options!.AdditionalProperties!["someKey"]!.ToString())]); + return new(new(ChatRole.Assistant, options!.AdditionalProperties!["someKey"]!.ToString())); } }; using var outer = new CachingChatClientWithCustomKey(testClient, _storage) @@ -618,13 +603,12 @@ public async Task SubclassCanOverrideCacheKeyToVaryByChatOptionsAsync() public async Task CanCacheCustomContentTypesAsync() { // Arrange - var expectedResponse = new ChatResponse([ + var expectedResponse = new ChatResponse( new(new ChatRole("fakeRole"), [ new CustomAIContent1("Hello", DateTime.Now), new CustomAIContent2("Goodbye", 42), - ]) - ]); + ])); var serializerOptions = new JsonSerializerOptions(TestJsonSerializerContext.Default.Options); serializerOptions.TypeInfoResolver = 
serializerOptions.TypeInfoResolver!.WithAddedModifier(typeInfo => @@ -678,8 +662,8 @@ public async Task CanResolveIDistributedCacheFromDI() { GetResponseAsyncCallback = delegate { - return Task.FromResult(new ChatResponse([ - new(ChatRole.Assistant, [new TextContent("Hey")])])); + return Task.FromResult(new ChatResponse( + new(ChatRole.Assistant, [new TextContent("Hey")]))); } }; using var outer = testClient @@ -739,33 +723,29 @@ private static void AssertResponsesEqual(ChatResponse expected, ChatResponse act Assert.Equal( JsonSerializer.Serialize(expected.AdditionalProperties, TestJsonSerializerContext.Default.Options), JsonSerializer.Serialize(actual.AdditionalProperties, TestJsonSerializerContext.Default.Options)); - Assert.Equal(expected.Choices.Count, actual.Choices.Count); - for (var i = 0; i < expected.Choices.Count; i++) + Assert.IsType(expected.Message.GetType(), actual.Message); + Assert.Equal(expected.Message.Role, actual.Message.Role); + Assert.Equal(expected.Message.Text, actual.Message.Text); + Assert.Equal(expected.Message.Contents.Count, actual.Message.Contents.Count); + + for (var itemIndex = 0; itemIndex < expected.Message.Contents.Count; itemIndex++) { - Assert.IsType(expected.Choices[i].GetType(), actual.Choices[i]); - Assert.Equal(expected.Choices[i].Role, actual.Choices[i].Role); - Assert.Equal(expected.Choices[i].Text, actual.Choices[i].Text); - Assert.Equal(expected.Choices[i].Contents.Count, actual.Choices[i].Contents.Count); + var expectedItem = expected.Message.Contents[itemIndex]; + var actualItem = actual.Message.Contents[itemIndex]; + Assert.IsType(expectedItem.GetType(), actualItem); - for (var itemIndex = 0; itemIndex < expected.Choices[i].Contents.Count; itemIndex++) + if (expectedItem is FunctionCallContent expectedFcc) { - var expectedItem = expected.Choices[i].Contents[itemIndex]; - var actualItem = actual.Choices[i].Contents[itemIndex]; - Assert.IsType(expectedItem.GetType(), actualItem); - - if (expectedItem is 
FunctionCallContent expectedFcc) - { - var actualFcc = (FunctionCallContent)actualItem; - Assert.Equal(expectedFcc.Name, actualFcc.Name); - Assert.Equal(expectedFcc.CallId, actualFcc.CallId); - - // The correct JSON-round-tripping of AIContent/AIContent is not - // the responsibility of CachingChatClient, so not testing that here. - Assert.Equal( - JsonSerializer.Serialize(expectedFcc.Arguments, TestJsonSerializerContext.Default.Options), - JsonSerializer.Serialize(actualFcc.Arguments, TestJsonSerializerContext.Default.Options)); - } + var actualFcc = (FunctionCallContent)actualItem; + Assert.Equal(expectedFcc.Name, actualFcc.Name); + Assert.Equal(expectedFcc.CallId, actualFcc.CallId); + + // The correct JSON-round-tripping of AIContent/AIContent is not + // the responsibility of CachingChatClient, so not testing that here. + Assert.Equal( + JsonSerializer.Serialize(expectedFcc.Arguments, TestJsonSerializerContext.Default.Options), + JsonSerializer.Serialize(actualFcc.Arguments, TestJsonSerializerContext.Default.Options)); } } } @@ -780,7 +760,6 @@ private static async Task AssertResponsesEqualAsync(IReadOnlyList _keepMessagesConfigure = - b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingContent = true }); - [Fact] public void InvalidArgs_Throws() { @@ -37,7 +34,6 @@ public void Ctor_HasExpectedDefaults() Assert.False(client.AllowConcurrentInvocation); Assert.False(client.IncludeDetailedErrors); - Assert.True(client.KeepFunctionCallingContent); Assert.Null(client.MaximumIterationsPerRequest); Assert.False(client.RetryOnError); } @@ -67,9 +63,9 @@ public async Task SupportsSingleFunctionCallPerRequestAsync() new ChatMessage(ChatRole.Assistant, "world"), ]; - await InvokeAndAssertAsync(options, plan, configurePipeline: _keepMessagesConfigure); + await InvokeAndAssertAsync(options, plan); - await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: _keepMessagesConfigure); + await InvokeAndAssertStreamingAsync(options, plan); 
} [Theory] @@ -115,7 +111,7 @@ public async Task SupportsMultipleFunctionCallsPerRequestAsync(bool concurrentIn ]; Func configure = b => b.Use( - s => new FunctionInvokingChatClient(s) { AllowConcurrentInvocation = concurrentInvocation, KeepFunctionCallingContent = true }); + s => new FunctionInvokingChatClient(s) { AllowConcurrentInvocation = concurrentInvocation }); await InvokeAndAssertAsync(options, plan, configurePipeline: configure); @@ -156,7 +152,7 @@ public async Task ParallelFunctionCallsMayBeInvokedConcurrentlyAsync() ]; Func configure = b => b.Use( - s => new FunctionInvokingChatClient(s) { AllowConcurrentInvocation = true, KeepFunctionCallingContent = true }); + s => new FunctionInvokingChatClient(s) { AllowConcurrentInvocation = true }); await InvokeAndAssertAsync(options, plan, configurePipeline: configure); @@ -199,68 +195,13 @@ public async Task ConcurrentInvocationOfParallelCallsDisabledByDefaultAsync() new ChatMessage(ChatRole.Assistant, "done"), ]; - await InvokeAndAssertAsync(options, plan, configurePipeline: _keepMessagesConfigure); - - await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: _keepMessagesConfigure); - } - - [Theory] - [InlineData(false)] - [InlineData(true)] - public async Task RemovesFunctionCallingMessagesWhenRequestedAsync(bool keepFunctionCallingMessages) - { - var options = new ChatOptions - { - Tools = - [ - AIFunctionFactory.Create(() => "Result 1", "Func1"), - AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), - AIFunctionFactory.Create((int i) => { }, "VoidReturn"), - ] - }; - - List plan = - [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), - new ChatMessage(ChatRole.Tool, [new 
FunctionResultContent("callId2", result: "Result 2: 42")]), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Success: Function completed.")]), - new ChatMessage(ChatRole.Assistant, "world"), - ]; - - List? expected = keepFunctionCallingMessages ? null : - [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, "world") - ]; - - Func configure = b => b.Use( - client => new FunctionInvokingChatClient(client) { KeepFunctionCallingContent = keepFunctionCallingMessages }); + await InvokeAndAssertAsync(options, plan); - Validate(await InvokeAndAssertAsync(options, plan, expected, configure)); - Validate(await InvokeAndAssertStreamingAsync(options, plan, expected, configure)); - - void Validate(List finalChat) - { - IEnumerable content = finalChat.SelectMany(m => m.Contents); - if (keepFunctionCallingMessages) - { - Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent); - } - else - { - Assert.All(content, c => Assert.False(c is FunctionCallContent or FunctionResultContent)); - } - } + await InvokeAndAssertStreamingAsync(options, plan); } - [Theory] - [InlineData(false)] - [InlineData(true)] - public async Task KeepsFunctionCallingContentWhenRequestedAsync(bool keepFunctionCallingMessages) + [Fact] + public async Task KeepsFunctionCallingContent() { var options = new ChatOptions { @@ -285,18 +226,12 @@ public async Task KeepsFunctionCallingContentWhenRequestedAsync(bool keepFunctio ]; Func configure = b => b.Use( - client => new FunctionInvokingChatClient(client) { KeepFunctionCallingContent = keepFunctionCallingMessages }); + client => new FunctionInvokingChatClient(client)); #pragma warning disable SA1005, S125 - Validate(await InvokeAndAssertAsync(options, plan, keepFunctionCallingMessages ? 
null : - [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [new TextContent("extra"), new TextContent("stuff")]), - new ChatMessage(ChatRole.Assistant, "more"), - new ChatMessage(ChatRole.Assistant, "world"), - ], configure)); + Validate(await InvokeAndAssertAsync(options, plan, null, configure)); - Validate(await InvokeAndAssertStreamingAsync(options, plan, keepFunctionCallingMessages ? + Validate(await InvokeAndAssertStreamingAsync(options, plan, [ new ChatMessage(ChatRole.User, "hello"), new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), @@ -306,23 +241,12 @@ public async Task KeepsFunctionCallingContentWhenRequestedAsync(bool keepFunctio new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Success: Function completed.")]), new ChatMessage(ChatRole.Assistant, "extrastuffmoreworld"), - ] : - [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, "extrastuffmoreworld"), ], configure)); - void Validate(List finalChat) + static void Validate(List finalChat) { IEnumerable content = finalChat.SelectMany(m => m.Contents); - if (keepFunctionCallingMessages) - { - Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent); - } - else - { - Assert.All(content, c => Assert.False(c is FunctionCallContent or FunctionResultContent)); - } + Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent); } } @@ -348,51 +272,13 @@ public async Task ExceptionDetailsOnlyReportedWhenRequestedAsync(bool detailedEr ]; Func configure = b => b.Use( - s => new FunctionInvokingChatClient(s) { IncludeDetailedErrors = detailedErrors, KeepFunctionCallingContent = true }); + s => new FunctionInvokingChatClient(s) { IncludeDetailedErrors = detailedErrors }); await InvokeAndAssertAsync(options, plan, 
configurePipeline: configure); await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure); } - [Fact] - public async Task RejectsMultipleChoicesAsync() - { - var func1 = AIFunctionFactory.Create(() => "Some result 1", "Func1"); - var func2 = AIFunctionFactory.Create(() => "Some result 2", "Func2"); - - var expected = new ChatResponse( - [ - new(ChatRole.Assistant, [new FunctionCallContent("callId1", func1.Name)]), - new(ChatRole.Assistant, [new FunctionCallContent("callId2", func2.Name)]), - ]); - - using var innerClient = new TestChatClient - { - GetResponseAsyncCallback = async (chatContents, options, cancellationToken) => - { - await Task.Yield(); - return expected; - }, - GetStreamingResponseAsyncCallback = (chatContents, options, cancellationToken) => - YieldAsync(expected.ToChatResponseUpdates()), - }; - - IChatClient service = innerClient.AsBuilder().UseFunctionInvocation().Build(); - - List chat = [new ChatMessage(ChatRole.User, "hello")]; - ChatOptions options = new() { Tools = [func1, func2] }; - - Validate(await Assert.ThrowsAsync(() => service.GetResponseAsync(chat, options))); - Validate(await Assert.ThrowsAsync(() => service.GetStreamingResponseAsync(chat, options).ToChatResponseAsync())); - - void Validate(Exception ex) - { - Assert.Contains("only accepts a single choice", ex.Message); - Assert.Single(chat); // It didn't add anything to the chat history - } - } - [Theory] [InlineData(LogLevel.Trace)] [InlineData(LogLevel.Debug)] @@ -413,10 +299,7 @@ public async Task FunctionInvocationsLogged(LogLevel level) }; Func configure = b => - b.Use((c, services) => new FunctionInvokingChatClient(c, services.GetRequiredService>()) - { - KeepFunctionCallingContent = true, - }); + b.Use((c, services) => new FunctionInvokingChatClient(c, services.GetRequiredService>())); await InvokeAsync(services => InvokeAndAssertAsync(options, plan, configurePipeline: configure, services: services)); @@ -472,10 +355,7 @@ public async Task 
FunctionInvocationTrackedWithActivity(bool enableTelemetry) }; Func configure = b => b.Use(c => - new FunctionInvokingChatClient(new OpenTelemetryChatClient(c, sourceName: sourceName)) - { - KeepFunctionCallingContent = true, - }); + new FunctionInvokingChatClient(new OpenTelemetryChatClient(c, sourceName: sourceName))); await InvokeAsync(() => InvokeAndAssertAsync(options, plan, configurePipeline: configure)); @@ -542,7 +422,7 @@ public async Task SupportsConsecutiveStreamingUpdatesWithFunctionCalls() } }; - using var client = new FunctionInvokingChatClient(innerClient) { KeepFunctionCallingContent = true }; + using var client = new FunctionInvokingChatClient(innerClient); var updates = new List(); await foreach (var update in client.GetStreamingResponseAsync(messages, options, CancellationToken.None)) @@ -616,7 +496,7 @@ await InvokeAsync(() => InvokeAndAssertAsync(options, plan, expected: [ // The last message is the one returned by the chat client // This message's content should contain the last function call before the termination new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func1", new Dictionary { ["i"] = 42 })]), - ], configurePipeline: _keepMessagesConfigure)); + ])); await InvokeAsync(() => InvokeAndAssertStreamingAsync(options, plan, expected: [ .. 
planBeforeTermination, @@ -624,7 +504,7 @@ await InvokeAsync(() => InvokeAndAssertStreamingAsync(options, plan, expected: [ // The last message is the one returned by the chat client // When streaming, function call content is removed from this message new ChatMessage(ChatRole.Assistant, []), - ], configurePipeline: _keepMessagesConfigure)); + ])); // The current context should be null outside the async call stack for the function invocation Assert.Null(FunctionInvokingChatClient.CurrentContext); diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs index 721768a5e08..c6bf16b3bf2 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs @@ -56,7 +56,7 @@ public async Task GetResponseAsync_LogsResponseInvocationAndCompletion(LogLevel { GetResponseAsyncCallback = (messages, options, cancellationToken) => { - return Task.FromResult(new ChatResponse([new(ChatRole.Assistant, "blue whale")])); + return Task.FromResult(new ChatResponse(new(ChatRole.Assistant, "blue whale"))); }, }; diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/UseDelegateChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/UseDelegateChatClientTests.cs index 3f94a47b7bd..1d8ff32693d 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/UseDelegateChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/UseDelegateChatClientTests.cs @@ -99,7 +99,7 @@ public async Task GetResponseFunc_ContextPropagated() Assert.Equal(expectedCts.Token, cancellationToken); asyncLocal.Value = 42; var cc = await innerClient.GetResponseAsync(chatMessages, options, cancellationToken); - cc.Choices[0].Text += " world"; + cc.Message.Text += " world"; return cc; }, null) 
.Build(); @@ -202,7 +202,7 @@ public async Task BothGetResponseAndGetStreamingResponseFuncs_ContextPropagated( Assert.Equal(expectedCts.Token, cancellationToken); asyncLocal.Value = 42; var cc = await innerClient.GetResponseAsync(chatMessages, options, cancellationToken); - cc.Choices[0].Text += " world (non-streaming)"; + cc.Message.Text += " world (non-streaming)"; return cc; }, (chatMessages, options, innerClient, cancellationToken) => From 144eac4a00cb6a93b085567bf1eb3c83603915b1 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 27 Feb 2025 20:46:14 -0500 Subject: [PATCH 2/9] Change IChatClient contract to add response content to message list This has a bunch of benefits: - It reduces what a caller needs to do. They don't need to add the response content to the history, because that's handled for them by the client they're using. This is particularly helpful for streaming responses. - It keeps stateless usage (where the caller provides the full list of messages) and stateful usage (where the caller provides a thread id and then only additional messages) more similar, as in the stateful case the service is already tracking all the content and the caller shouldn't be adding anything to history. - It fixes the ordering of messages in automatic function invocation, where it can now fully manage the history list, even when streaming, because the caller is not responsible for adding the streaming content into the history and thus there's no concern about ordering between the consumer adds (nothing) and what the implementation adds. - It ensures in automatic function invocation that all content is sent back to inner client, because it's the inner client that added the content in the first place. - It enables all content to be yielded from the function invoking client, including content it creates that has a different role (tool) from the other streaming content (assistant), which then enables consumers to use that knowledge for things like keeping UIs up to date. 
--- .../ChatResponseUpdateExtensions.cs | 31 +++- .../ChatCompletion/IChatClient.cs | 27 ++- .../Embeddings/IEmbeddingGenerator.cs | 3 + .../README.md | 164 ++++-------------- .../AzureAIInferenceChatClient.cs | 6 + .../OllamaChatClient.cs | 10 +- .../OpenAIChatClient.cs | 7 +- ...nAIModelMappers.StreamingChatCompletion.cs | 7 + .../ChatCompletion/CachingChatClient.cs | 18 +- .../FunctionInvokingChatClient.cs | 131 +++++++------- .../ChatClientIntegrationTests.cs | 3 - .../FunctionInvokingChatClientTests.cs | 79 ++++----- 12 files changed, 234 insertions(+), 252 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs index 3fb2eb139b5..f552ef1aec0 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs @@ -18,12 +18,39 @@ namespace Microsoft.Extensions.AI; /// public static class ChatResponseUpdateExtensions { + /// Combines instances into a single . + /// The updates to be combined. + /// + /// to attempt to coalesce contiguous items, where applicable, + /// into a single , in order to reduce the number of individual content items that are included in + /// the manufactured instance. When , the original content items are used. + /// The default is . + /// + /// The combined . + public static ChatMessage ToChatMessage( + this IEnumerable updates, bool coalesceContent = true) => + ToChatResponse(updates, coalesceContent).Message; // TO DO: More efficient implementation + + /// Combines instances into a single . + /// The updates to be combined. + /// + /// to attempt to coalesce contiguous items, where applicable, + /// into a single , in order to reduce the number of individual content items that are included in + /// the manufactured instance. 
When , the original content items are used. + /// The default is . + /// + /// The to monitor for cancellation requests. The default is . + /// The combined . + public static async Task ToChatMessageAsync( + this IAsyncEnumerable updates, bool coalesceContent = true, CancellationToken cancellationToken = default) => + (await ToChatResponseAsync(updates, coalesceContent, cancellationToken).ConfigureAwait(false)).Message; // TO DO: More efficient implementation + /// Combines instances into a single . /// The updates to be combined. /// /// to attempt to coalesce contiguous items, where applicable, /// into a single , in order to reduce the number of individual content items that are included in - /// the manufactured instances. When , the original content items are used. + /// the manufactured instance. When , the original content items are used. /// The default is . /// /// The combined . @@ -49,7 +76,7 @@ public static ChatResponse ToChatResponse( /// /// to attempt to coalesce contiguous items, where applicable, /// into a single , in order to reduce the number of individual content items that are included in - /// the manufactured instances. When , the original content items are used. + /// the manufactured instance. When , the original content items are used. /// The default is . /// /// The to monitor for cancellation requests. The default is . diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs index 26a39f05105..ce99f126b3d 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs @@ -26,13 +26,17 @@ namespace Microsoft.Extensions.AI; public interface IChatClient : IDisposable { /// Sends chat messages and returns the response. - /// The chat content to send. - /// The chat options to configure the request. 
+ /// The list of chat messages to send and to be augmented with generated messages. + /// The chat options with which to configure the request. /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. + /// is . /// - /// The returned messages aren't added to . However, any intermediate messages generated implicitly - /// by the client, including any messages for roundtrips to the model as part of the implementation of this request, are included. + /// The response message generated by is both returned from the method as well as automatically + /// added into . Any intermediate messages generated implicitly as part of the interaction are + /// also added to the chat history. For example, if as part of satisfying this request, the method + /// itself issues multiple requests to one or more underlying instances, all of those messages will also + /// be included in . /// Task GetResponseAsync( IList chatMessages, @@ -40,13 +44,17 @@ Task GetResponseAsync( CancellationToken cancellationToken = default); /// Sends chat messages and streams the response. - /// The chat content to send. - /// The chat options to configure the request. + /// The list of chat messages to send and to be augmented with generated messages. + /// The chat options with which to configure the request. /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. + /// is . /// - /// The returned messages aren't added to . However, any intermediate messages generated implicitly - /// by the client, including any messages for roundtrips to the model as part of the implementation of this request, are included. + /// The response updates generated by are both streamed from the method as well as automatically + /// added into . Any intermediate messages generated implicitly as part of the interaction are + /// also added to the chat history.
For example, if as part of satisfying this request, the method + /// itself issues multiple requests to one or more underlying instances, all of those messages will also + /// be included in . /// IAsyncEnumerable GetStreamingResponseAsync( IList chatMessages, @@ -60,7 +68,8 @@ IAsyncEnumerable GetStreamingResponseAsync( /// is . /// /// The purpose of this method is to allow for the retrieval of strongly typed services that might be provided by the , - /// including itself or any services it might be wrapping. + /// including itself or any services it might be wrapping. For example, to access the for the instance, + /// may be used to request it. /// object? GetService(Type serviceType, object? serviceKey = null); } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs index c260708079c..531b8ceeeb5 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs @@ -32,6 +32,7 @@ public interface IEmbeddingGenerator : IDisposable /// The embedding generation options to configure the request. /// The to monitor for cancellation requests. The default is . /// The generated embeddings. + /// is . Task> GenerateAsync( IEnumerable values, EmbeddingGenerationOptions? options = null, @@ -45,6 +46,8 @@ Task> GenerateAsync( /// /// The purpose of this method is to allow for the retrieval of strongly typed services that might be provided by the /// , including itself or any services it might be wrapping. + /// For example, to access the for the instance, may + /// be used to request it. /// object? GetService(Type serviceType, object? 
serviceKey = null); } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md b/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md index b8a6cba944f..401002c82fd 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md @@ -27,105 +27,24 @@ of the abstractions. ### `IChatClient` -The `IChatClient` interface defines a client abstraction responsible for interacting with AI services that provide "chat" capabilities. It defines methods for sending and receiving messages comprised of multi-modal content (text, images, audio, etc.), with responses being either as a complete result or streamed incrementally. Additionally, it allows for retrieving strongly-typed services that may be provided by the client or its underlying services. - -#### Sample Implementation +The `IChatClient` interface defines a client abstraction responsible for interacting with AI services that provide "chat" capabilities. It defines methods for sending and receiving messages comprised of multi-modal content (text, images, audio, etc.), with responses providing either a complete result or one streamed incrementally. Additionally, it allows for retrieving strongly-typed services that may be provided by the client or its underlying services. .NET libraries that provide clients for language models and services may provide an implementation of the `IChatClient` interface. Any consumers of the interface are then able to interoperate seamlessly with these models and services via the abstractions. -Here is a sample implementation of an `IChatClient` to show the general structure. 
- -```csharp -using System.Runtime.CompilerServices; -using Microsoft.Extensions.AI; - -public class SampleChatClient : IChatClient -{ - private readonly ChatClientMetadata _metadata; - - public SampleChatClient(Uri endpoint, string modelId) => - _metadata = new("SampleChatClient", endpoint, modelId); - - public async Task GetResponseAsync( - IList chatMessages, - ChatOptions? options = null, - CancellationToken cancellationToken = default) - { - // Simulate some operation. - await Task.Delay(300, cancellationToken); - - // Return a sample chat response randomly. - string[] responses = - [ - "This is the first sample response.", - "Here is another example of a response message.", - "This is yet another response message." - ]; - - return new(new ChatMessage() - { - Role = ChatRole.Assistant, - Text = responses[Random.Shared.Next(responses.Length)], - }); - } - - public async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, - ChatOptions? options = null, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - // Simulate streaming by yielding messages one by one. - string[] words = ["This ", "is ", "the ", "response ", "for ", "the ", "request."]; - foreach (string word in words) - { - // Simulate some operation. - await Task.Delay(100, cancellationToken); - - // Yield the next message in the response. - yield return new ChatResponseUpdate - { - Role = ChatRole.Assistant, - Text = word, - }; - } - } - - object? IChatClient.GetService(Type serviceType, object? serviceKey = null) => - serviceKey is not null ? null : - serviceType == typeof(ChatClientMetadata) ? _metadata : - serviceType?.IsInstanceOfType(this) is true ? 
this : - null; - - void IDisposable.Dispose() { } -} -``` - -As further examples, you can find other concrete implementations in the following packages (but many more such implementations for a large variety of services are available on NuGet): - -- [Microsoft.Extensions.AI.AzureAIInference](https://aka.ms/meai-azaiinference-nuget) -- [Microsoft.Extensions.AI.OpenAI](https://aka.ms/meai-openai-nuget) -- [Microsoft.Extensions.AI.Ollama](https://aka.ms/meai-ollama-nuget) - #### Requesting a Chat Response: `GetResponseAsync` With an instance of `IChatClient`, the `GetResponseAsync` method may be used to send a request and get a response. The request is composed of one or more messages, each of which is composed of one or more pieces of content. Accelerator methods exist to simplify common cases, such as constructing a request for a single piece of text content. ```csharp -using Microsoft.Extensions.AI; - -IChatClient client = new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"); +IChatClient client = ...; -var response = await client.GetResponseAsync("What is AI?"); - -Console.WriteLine(response.Message); +Console.WriteLine(await client.GetResponseAsync("What is AI?")); ``` The core `GetResponseAsync` method on the `IChatClient` interface accepts a list of messages. This list represents the history of all messages that are part of the conversation. ```csharp -using Microsoft.Extensions.AI; - -IChatClient client = new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"); +IChatClient client = ...; Console.WriteLine(await client.GetResponseAsync( [ @@ -134,7 +53,7 @@ Console.WriteLine(await client.GetResponseAsync( ])); ``` -The `ChatResponse` that's returned from `GetResponseAsync` exposes a `ChatMessage` representing the response message. It may be added back into the history in order to provide this response back to the service in a subsequent request, e.g. 
+The `ChatResponse` that's returned from `GetResponseAsync` exposes a `ChatMessage` representing the response message. It is automatically added into the history by the `IChatClient`, so that it'll be provided back to the service in a subsequent request, e.g. ```csharp List history = []; @@ -143,10 +62,7 @@ while (true) Console.Write("Q: "); history.Add(new(ChatRole.User, Console.ReadLine())); - ChatResponse response = await client.GetResponseAsync(history); - - Console.WriteLine(response); - history.Add(response.Message); + Console.WriteLine(await client.GetResponseAsync(history)); } ``` @@ -155,9 +71,7 @@ while (true) The inputs to `GetStreamingResponseAsync` are identical to those of `GetResponseAsync`. However, rather than returning the complete response as part of a `ChatResponse` object, the method returns an `IAsyncEnumerable`, providing a stream of updates that together form the single response. ```csharp -using Microsoft.Extensions.AI; - -IChatClient client = new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"); +IChatClient client = ...; await foreach (var update in client.GetStreamingResponseAsync("What is AI?")) { @@ -165,46 +79,41 @@ await foreach (var update in client.GetStreamingResponseAsync("What is AI?")) } ``` -Such a stream of response updates may be combined into a single response object via the `ToChatResponse` and `ToChatResponseAsync` helper methods, e.g. - +As with `GetResponseAsync`, the `IChatClient.GetStreamingResponseAsync` implementation is responsible for adding +the response message back into the history, so that it'll be provided back to the service in a subsequent request. 
```csharp List history = []; -List updates = []; while (true) { Console.Write("Q: "); history.Add(new(ChatRole.User, Console.ReadLine())); - updates.Clear(); await foreach (var update in client.GetStreamingResponseAsync(history)) { Console.Write(update); - updates.Add(update); } - history.Add(updates.ToChatResponse().Message)); + Console.WriteLine(); } ``` #### Tool Calling -Some models and services support the notion of tool calling, where requests may include information about tools (in particular .NET methods) that the model may request be invoked in order to gather additional information. Rather than sending back a response message that represents the final response to the input, the model sends back a request to invoke a given function with a given set of arguments; the client may then find and invoke the relevant function and send back the results to the model (along with all the rest of the history). The abstractions in Microsoft.Extensions.AI include representations for various forms of content that may be included in messages, and this includes representations for these function call requests and results. While it's possible for the consumer of the `IChatClient` to interact with this content directly, `Microsoft.Extensions.AI` supports automating these interactions. It provides an `AIFunction` that represents an invocable function along with metadata for describing the function to the AI model, along with an `AIFunctionFactory` for creating `AIFunction`s to represent .NET methods. It also provides a `FunctionInvokingChatClient` that both is an `IChatClient` and also wraps an `IChatClient`, enabling layering automatic function invocation capabilities around an arbitrary `IChatClient` implementation. +Some models and services support the notion of tool calling, where requests may include information about tools (in particular .NET methods) that the model may request be invoked in order to gather additional information. 
Rather than sending back a response message that represents the final response to the input, the model sends back a request to invoke a given function with a given set of arguments; the client may then find and invoke the relevant function and send back the results to the model (along with all the rest of the history). The abstractions in `Microsoft.Extensions.AI` include representations for various forms of content that may be included in messages, and this includes representations for these function call requests and results. While it's possible for the consumer of the `IChatClient` to interact with this content directly, `Microsoft.Extensions.AI` supports automating these interactions. It provides an `AIFunction` that represents an invocable function along with metadata for describing the function to the AI model, as well as an `AIFunctionFactory` for creating `AIFunction`s to represent .NET methods. It also provides a `FunctionInvokingChatClient` that both is an `IChatClient` and also wraps an `IChatClient`, enabling layering automatic function invocation capabilities around an arbitrary `IChatClient` implementation. ```csharp -using System.ComponentModel; using Microsoft.Extensions.AI; -[Description("Gets the current weather")] string GetCurrentWeather() => Random.Shared.NextDouble() > 0.5 ? 
"It's sunny" : "It's raining"; -IChatClient client = new ChatClientBuilder(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1")) +IChatClient client = new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1") + .AsBuilder() .UseFunctionInvocation() .Build(); -var response = client.GetStreamingResponseAsync( - "Should I wear a rain coat?", - new() { Tools = [AIFunctionFactory.Create(GetCurrentWeather)] }); +ChatOptions options = new() { Tools = [AIFunctionFactory.Create(GetCurrentWeather)] }; +var response = client.GetStreamingResponseAsync("Should I wear a rain coat?", options); await foreach (var update in response) { Console.Write(update); @@ -221,7 +130,7 @@ using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.Caching.Memory; using Microsoft.Extensions.Options; -IChatClient client = new ChatClientBuilder(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model")) +IChatClient client = new ChatClientBuilder(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1")) .UseDistributedCache(new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions()))) .Build(); @@ -252,7 +161,7 @@ var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() .AddConsoleExporter() .Build(); -IChatClient client = new ChatClientBuilder(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model")) +IChatClient client = new ChatClientBuilder(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1")) .UseOpenTelemetry(sourceName: sourceName, configure: c => c.EnableSensitiveData = true) .Build(); @@ -269,7 +178,8 @@ Options may also be baked into an `IChatClient` via the `ConfigureOptions` exten ```csharp using Microsoft.Extensions.AI; -IChatClient client = new ChatClientBuilder(new OllamaChatClient(new Uri("http://localhost:11434"))) +IChatClient client = new OllamaChatClient(new Uri("http://localhost:11434")) + .AsBuilder() .ConfigureOptions(options => options.ModelId ??= 
"phi3") .Build(); @@ -372,7 +282,7 @@ using Microsoft.Extensions.AI; using System.Threading.RateLimiting; var client = new RateLimitingChatClient( - new SampleChatClient(new Uri("http://localhost"), "test"), + new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1"), new ConcurrencyLimiter(new() { PermitLimit = 1, QueueLimit = int.MaxValue })); await client.GetResponseAsync("What color is the sky?"); @@ -398,7 +308,7 @@ public static class RateLimitingChatClientExtensions The consumer can then easily use this in their pipeline, e.g. ```csharp -var client = new SampleChatClient(new Uri("http://localhost"), "test") +var client = new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1") .AsBuilder() .UseDistributedCache() .UseRateLimiting() @@ -412,7 +322,7 @@ need to do work before and after delegating to the next client in the pipeline. be used that accepts a delegate which is used for both `GetResponseAsync` and `GetStreamingResponseAsync`, reducing the boilderplate required: ```csharp RateLimiter rateLimiter = ...; -var client = new SampleChatClient(new Uri("http://localhost"), "test") +var client = new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1") .AsBuilder() .UseDistributedCache() .Use(async (chatMessages, options, nextAsync, cancellationToken) => @@ -443,7 +353,7 @@ using Microsoft.Extensions.Hosting; // App Setup var builder = Host.CreateApplicationBuilder(); builder.Services.AddDistributedMemoryCache(); -builder.Services.AddChatClient(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model")) +builder.Services.AddChatClient(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1")) .UseDistributedCache(); var host = builder.Build(); @@ -459,7 +369,9 @@ What instance and configuration is injected may differ based on the current need "Stateless" services require all relevant conversation history to sent back on every request, while "stateful" services keep track of the history and instead require 
only additional messages be sent with a request. The `IChatClient` interface is designed to handle both stateless and stateful AI services. -If you know you're working with a stateless service (currently the most common form), responses may be added back into a message history for sending back to the server. +When working with a stateless service, the `GetResponseAsync` and `GetStreamingResponseAsync` methods will automatically add the response message back +into the history. The client can then simply pass the same list of messages to a subsequent request, as that list will contain all +of the context necessary to enable the next request. ```csharp List history = []; while (true) @@ -467,23 +379,20 @@ while (true) Console.Write("Q: "); history.Add(new(ChatRole.User, Console.ReadLine())); - ChatResponse response = await client.GetResponseAsync(history); - - Console.WriteLine(response); - history.Add(response.Message); + Console.WriteLine(await client.GetResponseAsync(history)); } ``` -For stateful services, you may know ahead of time an identifier used for the relevant conversation. That identifier can be put into `ChatOptions.ChatThreadId`: +For stateful services, you may know ahead of time an identifier used for the relevant conversation. That identifier can be put into `ChatOptions.ChatThreadId`. 
+Usage then follows the same pattern: ```csharp ChatOptions options = new() { ChatThreadId = "my-conversation-id" }; while (true) { Console.Write("Q: "); + ChatMessage message = new(ChatRole.User, Console.ReadLine()); - ChatResponse response = await client.GetResponseAsync(Console.ReadLine(), options); - - Console.WriteLine(response); + Console.WriteLine(await client.GetResponseAsync(message, options)); } ``` @@ -494,10 +403,11 @@ ChatOptions options = new(); while (true) { Console.Write("Q: "); + ChatMessage message = new(ChatRole.User, Console.ReadLine()); - ChatResponse response = await client.GetResponseAsync(Console.ReadLine(), options); - + ChatResponse response = await client.GetResponseAsync(message, options); Console.WriteLine(response); + options.ChatThreadId = response.ChatThreadId; } ``` @@ -515,17 +425,13 @@ while (true) history.Add(new(ChatRole.User, Console.ReadLine())); ChatResponse response = await client.GetResponseAsync(history); - Console.WriteLine(response); + options.ChatThreadId = response.ChatThreadId; if (response.ChatThreadId is not null) { history.Clear(); } - else - { - history.Add(response.Message); - } } ``` diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs index 7a44dbf3e57..760c21f69aa 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs @@ -128,6 +128,7 @@ public async Task GetResponseAsync( } // Wrap the content in a ChatResponse to return. + chatMessages.Add(message); return new ChatResponse(message) { CreatedAt = response.Created, @@ -152,6 +153,7 @@ public async IAsyncEnumerable GetStreamingResponseAsync( DateTimeOffset? createdAt = null; string? 
modelId = null; string lastCallId = string.Empty; + List responseUpdates = []; // Process each update as it arrives var updates = await _chatCompletionsClient.CompleteStreamingAsync(ToAzureAIOptions(chatMessages, options), cancellationToken).ConfigureAwait(false); @@ -220,6 +222,7 @@ public async IAsyncEnumerable GetStreamingResponseAsync( } // Now yield the item. + responseUpdates.Add(responseUpdate); yield return responseUpdate; } @@ -248,8 +251,11 @@ public async IAsyncEnumerable GetStreamingResponseAsync( } } + responseUpdates.Add(responseUpdate); yield return responseUpdate; } + + chatMessages.Add(responseUpdates.ToChatMessage()); } /// diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs index ae18a430c45..2c3fb6bf709 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs @@ -102,7 +102,11 @@ public async Task GetResponseAsync(IList chatMessages throw new InvalidOperationException($"Ollama error: {response.Error}"); } - return new(FromOllamaMessage(response.Message!)) + var responseMessage = FromOllamaMessage(response.Message!); + + chatMessages.Add(responseMessage); + + return new(responseMessage) { CreatedAt = DateTimeOffset.TryParse(response.CreatedAt, CultureInfo.InvariantCulture, DateTimeStyles.None, out DateTimeOffset createdAt) ? 
createdAt : null, FinishReason = ToFinishReason(response), @@ -137,6 +141,7 @@ public async IAsyncEnumerable GetStreamingResponseAsync( #endif .ConfigureAwait(false); + List updates = []; using var streamReader = new StreamReader(httpResponseStream); #if NET while ((await streamReader.ReadLineAsync(cancellationToken).ConfigureAwait(false)) is { } line) @@ -186,8 +191,11 @@ public async IAsyncEnumerable GetStreamingResponseAsync( update.Contents.Add(new UsageContent(usage)); } + updates.Add(update); yield return update; } + + chatMessages.Add(updates.ToChatMessage()); } /// diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs index ba584cc1734..a671fb0e769 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs @@ -110,7 +110,10 @@ public async Task GetResponseAsync( // Make the call to OpenAI. var response = await _chatClient.CompleteChatAsync(openAIChatMessages, openAIOptions, cancellationToken).ConfigureAwait(false); - return OpenAIModelMappers.FromOpenAIChatCompletion(response.Value, options, openAIOptions); + ChatResponse chatResponse = OpenAIModelMappers.FromOpenAIChatCompletion(response.Value, options, openAIOptions); + chatMessages.Add(chatResponse.Message); + + return chatResponse; } /// @@ -125,7 +128,7 @@ public IAsyncEnumerable GetStreamingResponseAsync( // Make the call to OpenAI. 
var chatCompletionUpdates = _chatClient.CompleteChatStreamingAsync(openAIChatMessages, openAIOptions, cancellationToken); - return OpenAIModelMappers.FromOpenAIStreamingChatCompletionAsync(chatCompletionUpdates, cancellationToken); + return OpenAIModelMappers.FromOpenAIStreamingChatCompletionAsync(chatMessages, chatCompletionUpdates, cancellationToken); } /// diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMappers.StreamingChatCompletion.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMappers.StreamingChatCompletion.cs index bfafbdf82b2..4e21ddff022 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMappers.StreamingChatCompletion.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMappers.StreamingChatCompletion.cs @@ -62,6 +62,7 @@ public static async IAsyncEnumerable ToOpenAIStre } public static async IAsyncEnumerable FromOpenAIStreamingChatCompletionAsync( + IList chatMessages, IAsyncEnumerable updates, [EnumeratorCancellation] CancellationToken cancellationToken = default) { @@ -74,6 +75,8 @@ public static async IAsyncEnumerable FromOpenAIStreamingChat string? modelId = null; string? fingerprint = null; + List responseUpdates = []; + // Process each update as it arrives await foreach (StreamingChatCompletionUpdate update in updates.WithCancellation(cancellationToken).ConfigureAwait(false)) { @@ -158,6 +161,7 @@ public static async IAsyncEnumerable FromOpenAIStreamingChat } // Now yield the item. 
+ responseUpdates.Add(responseUpdate); yield return responseUpdate; } @@ -199,7 +203,10 @@ public static async IAsyncEnumerable FromOpenAIStreamingChat (responseUpdate.AdditionalProperties ??= [])[nameof(ChatCompletion.SystemFingerprint)] = fingerprint; } + responseUpdates.Add(responseUpdate); yield return responseUpdate; } + + chatMessages.Add(responseUpdates.ToChatMessage()); } } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs index 79f41d1790e..af1f41186ea 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs @@ -54,7 +54,11 @@ public override async Task GetResponseAsync(IList cha // concurrent callers might trigger duplicate requests, but that's acceptable. var cacheKey = GetCacheKey(_boxedFalse, chatMessages, options); - if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is not { } result) + if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } result) + { + chatMessages.Add(result.Message); + } + else { result = await base.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); await WriteCacheAsync(cacheKey, result, cancellationToken).ConfigureAwait(false); @@ -83,6 +87,11 @@ public override async IAsyncEnumerable GetStreamingResponseA { yield return chunk; } + + if (chatResponse.ChatThreadId is null) + { + chatMessages.Add(chatResponse.Message); + } } else { @@ -104,10 +113,17 @@ public override async IAsyncEnumerable GetStreamingResponseA if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks) { // Yield all of the cached items. + string? 
chatThreadId = null; foreach (var chunk in existingChunks) { + chatThreadId ??= chunk.ChatThreadId; yield return chunk; } + + if (chatThreadId is null) + { + chatMessages.Add(existingChunks.ToChatMessage()); + } } else { diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index af93f59485c..48166f78df9 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -13,6 +13,7 @@ using Microsoft.Shared.Diagnostics; #pragma warning disable CA2213 // Disposable fields should be disposed +#pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test namespace Microsoft.Extensions.AI; @@ -186,10 +187,13 @@ public override async Task GetResponseAsync(IList cha ChatResponse? response = null; UsageDetails? totalUsage = null; IList originalChatMessages = chatMessages; + List? functionCallContents = null; try { for (int iteration = 0; ; iteration++) { + functionCallContents?.Clear(); + // Make the call to the handler. response = await base.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); @@ -209,14 +213,14 @@ public override async Task GetResponseAsync(IList cha } // Extract any function call contents. If there are none, we're done. - FunctionCallContent[] functionCallContents = response.Message.Contents.OfType().ToArray(); - if (functionCallContents.Length == 0) + CopyFunctionCalls(response.Message.Contents, ref functionCallContents); + if (functionCallContents is not { Count: > 0 }) { break; } - // Update the chat history. If the underlying client is tracking the state, then we want to avoid re-sending - // what we already sent as well as this response message, so create a new list to store the response message(s). 
+ // If the response indicates the inner client is tracking the history, clear it to avoid re-sending the state. + // In that case, we also avoid touching the user's history, so that we don't need to clear it. if (response.ChatThreadId is not null) { if (chatMessages == originalChatMessages) @@ -228,11 +232,6 @@ public override async Task GetResponseAsync(IList cha chatMessages.Clear(); } } - else - { - // Add the original response message into the history. - chatMessages.Add(response.Message); - } // Add the responses from the function calls into the history. var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); @@ -264,53 +263,33 @@ public override async IAsyncEnumerable GetStreamingResponseA // Create an activity to group them together for better observability. using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient)); - List functionCallContents = []; + List? functionCallContents = null; IList originalChatMessages = chatMessages; for (int iteration = 0; ; iteration++) { + functionCallContents?.Clear(); string? chatThreadId = null; - functionCallContents.Clear(); + await foreach (var update in base.GetStreamingResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) { - // We're going to emit all ChatResponseUpdates upstream, even ones that contain function call - // content, because a given ChatResponseUpdate can contain other content/metadata. But if we - // yield the function calls, and the consumer adds all the content into a message that's then - // added into history, they'll end up with function call contents that aren't directly paired - // with function result contents, which may cause issues for some models when the history is - // later sent again. 
We thus remove the FunctionCallContent instances from the updates before - // yielding them, tracking those FunctionCallContents separately so they can be processed and - // added to the chat history. - - // Find all the FCCs. We need to track these separately in order to be able to process them later. - int preFccCount = functionCallContents.Count; - functionCallContents.AddRange(update.Contents.OfType()); - - // If there were any, remove them from the update. We do this before yielding the update so - // that we're not modifying an instance already provided back to the caller. - int addedFccs = functionCallContents.Count - preFccCount; - if (addedFccs > 0) - { - update.Contents = addedFccs == update.Contents.Count ? - [] : update.Contents.Where(c => c is not FunctionCallContent).ToList(); - } - chatThreadId ??= update.ChatThreadId; + CopyFunctionCalls(update.Contents, ref functionCallContents); yield return update; Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 } // If there are no tools to call, or for any other reason we should stop, return the response. - if (options is null + if (functionCallContents is not { Count: > 0 } + || options is null || options.Tools is not { Count: > 0 } - || (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations) - || functionCallContents is not { Count: > 0 }) + || (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations)) { break; } - // Update the chat history. If the underlying client is tracking the state, then we want to avoid re-sending - // what we already sent as well as this response message, so create a new list to store the response message(s). + // If the response indicates the inner client is tracking the history, clear it to avoid re-sending the state. + // In that case, we also avoid touching the user's history, so that we don't need to clear it. 
if (chatThreadId is not null) { if (chatMessages == originalChatMessages) @@ -322,11 +301,6 @@ public override async IAsyncEnumerable GetStreamingResponseA chatMessages.Clear(); } } - else - { - // Add a manufactured response message containing the function call contents to the chat history. - chatMessages.Add(new(ChatRole.Assistant, [.. functionCallContents])); - } // Process all of the functions, adding their results into the history. var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); @@ -335,6 +309,38 @@ public override async IAsyncEnumerable GetStreamingResponseA // Terminate yield break; } + + // Stream any generated function results. These are already part of the history, + // but we stream them out for informational purposes. + foreach (var message in modeAndMessages.MessagesAdded) + { + var toolResultUpdate = new ChatResponseUpdate + { + AdditionalProperties = message.AdditionalProperties, + AuthorName = message.AuthorName, + ChatThreadId = chatThreadId, + CreatedAt = DateTimeOffset.UtcNow, + Contents = message.Contents, + RawRepresentation = message.RawRepresentation, + Role = message.Role, + }; + + yield return toolResultUpdate; + Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 + } + } + } + + /// Copies any from to . + private static void CopyFunctionCalls(IList content, ref List? functionCalls) + { + int count = content.Count; + for (int i = 0; i < count; i++) + { + if (content[i] is FunctionCallContent functionCall) + { + (functionCalls ??= []).Add(functionCall); + } } } @@ -397,18 +403,18 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti /// The to monitor for cancellation requests. /// A value indicating how the caller should proceed. 
private async Task<(ContinueMode Mode, IList MessagesAdded)> ProcessFunctionCallsAsync( - IList chatMessages, ChatOptions options, IReadOnlyList functionCallContents, int iteration, CancellationToken cancellationToken) + IList chatMessages, ChatOptions options, List functionCallContents, int iteration, CancellationToken cancellationToken) { // We must add a response for every tool call, regardless of whether we successfully executed it or not. // If we successfully execute it, we'll add the result. If we don't, we'll add an error. - int functionCount = functionCallContents.Count; - Debug.Assert(functionCount > 0, $"Expecteded {nameof(functionCount)} to be > 0, got {functionCount}."); + Debug.Assert(functionCallContents.Count > 0, "Expected at least one function call."); // Process all functions. If there's more than one and concurrent invocation is enabled, do so in parallel. - if (functionCount == 1) + if (functionCallContents.Count == 1) { - FunctionInvocationResult result = await ProcessFunctionCallAsync(chatMessages, options, functionCallContents[0], iteration, 0, 1, cancellationToken).ConfigureAwait(false); + FunctionInvocationResult result = await ProcessFunctionCallAsync( + chatMessages, options, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false); IList added = AddResponseMessages(chatMessages, [result]); return (result.ContinueMode, added); } @@ -420,16 +426,20 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti { // Schedule the invocation of every function.
results = await Task.WhenAll( - from i in Enumerable.Range(0, functionCount) - select Task.Run(() => ProcessFunctionCallAsync(chatMessages, options, functionCallContents[i], iteration, i, functionCount, cancellationToken))).ConfigureAwait(false); + from i in Enumerable.Range(0, functionCallContents.Count) + select Task.Run(() => ProcessFunctionCallAsync( + chatMessages, options, functionCallContents, + iteration, i, cancellationToken))).ConfigureAwait(false); } else { // Invoke each function serially. - results = new FunctionInvocationResult[functionCount]; - for (int i = 0; i < functionCount; i++) + results = new FunctionInvocationResult[functionCallContents.Count]; + for (int i = 0; i < results.Length; i++) { - results[i] = await ProcessFunctionCallAsync(chatMessages, options, functionCallContents[i], iteration, i, functionCount, cancellationToken).ConfigureAwait(false); + results[i] = await ProcessFunctionCallAsync( + chatMessages, options, functionCallContents, + iteration, i, cancellationToken).ConfigureAwait(false); } } @@ -447,19 +457,20 @@ from i in Enumerable.Range(0, functionCount) } } - /// Processes the function call described in . + /// Processes the function call described in []. /// The current chat contents, inclusive of the function call contents being processed. /// The options used for the response being processed. - /// The function call content representing the function to be invoked. + /// The function call contents representing all the functions being invoked. /// The iteration number of how many roundtrips have been made to the inner client. - /// The 0-based index of the function being called out of total functions. - /// The number of function call requests made, of which this is one. + /// The 0-based index of the function being called out of . /// The to monitor for cancellation requests. /// A value indicating how the caller should proceed. 
private async Task ProcessFunctionCallAsync( - IList chatMessages, ChatOptions options, FunctionCallContent callContent, - int iteration, int functionCallIndex, int totalFunctionCount, CancellationToken cancellationToken) + IList chatMessages, ChatOptions options, List callContents, + int iteration, int functionCallIndex, CancellationToken cancellationToken) { + var callContent = callContents[functionCallIndex]; + // Look up the AIFunction for the function call. If the requested function isn't available, send back an error. AIFunction? function = options.Tools!.OfType().FirstOrDefault(t => t.Name == callContent.Name); if (function is null) @@ -474,7 +485,7 @@ private async Task ProcessFunctionCallAsync( Function = function, Iteration = iteration, FunctionCallIndex = functionCallIndex, - FunctionCount = totalFunctionCount, + FunctionCount = callContents.Count, }; object? result; diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs index 581e21daaea..c6c12c4e192 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs @@ -88,9 +88,6 @@ public virtual async Task GetStreamingResponseAsync() string responseText = sb.ToString(); Assert.Contains("one small step", responseText, StringComparison.OrdinalIgnoreCase); Assert.Contains("one giant leap", responseText, StringComparison.OrdinalIgnoreCase); - - // The input list is left unaugmented. 
- Assert.Single(chatHistory); } [ConditionalFact] diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs index 81b81668108..f9143285431 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs @@ -225,23 +225,10 @@ public async Task KeepsFunctionCallingContent() new ChatMessage(ChatRole.Assistant, "world"), ]; - Func configure = b => b.Use( - client => new FunctionInvokingChatClient(client)); - #pragma warning disable SA1005, S125 - Validate(await InvokeAndAssertAsync(options, plan, null, configure)); + Validate(await InvokeAndAssertAsync(options, plan)); - Validate(await InvokeAndAssertStreamingAsync(options, plan, - [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 1")]), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42")]), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Success: Function completed.")]), - new ChatMessage(ChatRole.Assistant, "extrastuffmoreworld"), - ], configure)); + Validate(await InvokeAndAssertStreamingAsync(options, plan)); static void Validate(List finalChat) { @@ -412,13 +399,23 @@ public async Task SupportsConsecutiveStreamingUpdatesWithFunctionCalls() GetStreamingResponseAsyncCallback = (chatContents, chatOptions, cancellationToken) => { // If 
the conversation is just starting, issue two consecutive updates with function calls - // Otherwise just end the conversation - return chatContents.Last().Text == "Hello" - ? YieldAsync( - new ChatResponseUpdate { Contents = [new FunctionCallContent("callId1", "Func1", new Dictionary { ["text"] = "Input 1" })] }, - new ChatResponseUpdate { Contents = [new FunctionCallContent("callId2", "Func1", new Dictionary { ["text"] = "Input 2" })] }) - : YieldAsync( - new ChatResponseUpdate { Contents = [new TextContent("OK bye")] }); + // Otherwise just end the conversation. + List updates; + if (chatContents.Last().Text == "Hello") + { + updates = + [ + new() { Contents = [new FunctionCallContent("callId1", "Func1", new Dictionary { ["text"] = "Input 1" })] }, + new() { Contents = [new FunctionCallContent("callId2", "Func1", new Dictionary { ["text"] = "Input 2" })] } + ]; + } + else + { + updates = [new() { Contents = [new TextContent("OK bye")] }]; + } + + chatContents.Add(updates.ToChatMessage()); + return YieldAsync(updates); } }; @@ -438,12 +435,13 @@ public async Task SupportsConsecutiveStreamingUpdatesWithFunctionCalls() c => Assert.Equal("Input 2", Assert.IsType(c).Arguments!["text"])), m => Assert.Collection(m.Contents, c => Assert.Equal("Result for Input 1", Assert.IsType(c).Result?.ToString()), - c => Assert.Equal("Result for Input 2", Assert.IsType(c).Result?.ToString()))); + c => Assert.Equal("Result for Input 2", Assert.IsType(c).Result?.ToString())), + m => Assert.Equal("OK bye", Assert.IsType(Assert.Single(m.Contents)).Text)); - // The returned updates should *not* include the FCCs and FRCs + // The returned updates also include the FCCs and FRCs var allUpdateContents = updates.SelectMany(updates => updates.Contents).ToList(); - var singleUpdateContent = Assert.IsType(Assert.Single(allUpdateContents)); - Assert.Equal("OK bye", singleUpdateContent.Text); + Assert.Contains(allUpdateContents, c => c is FunctionCallContent); + Assert.Contains(allUpdateContents, c 
=> c is FunctionResultContent); } [Fact] @@ -490,21 +488,9 @@ public async Task CanAccesssFunctionInvocationContextFromFunctionCall() new ChatMessage(ChatRole.Assistant, "world"), ]; - await InvokeAsync(() => InvokeAndAssertAsync(options, plan, expected: [ - .. planBeforeTermination, + await InvokeAsync(() => InvokeAndAssertAsync(options, plan, planBeforeTermination)); - // The last message is the one returned by the chat client - // This message's content should contain the last function call before the termination - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func1", new Dictionary { ["i"] = 42 })]), - ])); - - await InvokeAsync(() => InvokeAndAssertStreamingAsync(options, plan, expected: [ - .. planBeforeTermination, - - // The last message is the one returned by the chat client - // When streaming, function call content is removed from this message - new ChatMessage(ChatRole.Assistant, []), - ])); + await InvokeAsync(() => InvokeAndAssertStreamingAsync(options, plan, planBeforeTermination)); // The current context should be null outside the async call stack for the function invocation Assert.Null(FunctionInvokingChatClient.CurrentContext); @@ -608,14 +594,16 @@ private static async Task> InvokeAndAssertAsync( var usage = CreateRandomUsage(); expectedTotalTokenCounts += usage.InputTokenCount!.Value; - return new ChatResponse(new ChatMessage(ChatRole.Assistant, [.. plan[contents.Count].Contents])) { Usage = usage }; + + var message = new ChatMessage(ChatRole.Assistant, [.. 
plan[contents.Count].Contents]); + contents.Add(message); + return new ChatResponse(message) { Usage = usage }; } }; IChatClient service = configurePipeline(innerClient.AsBuilder()).Build(services); var result = await service.GetResponseAsync(chat, options, cts.Token); - chat.Add(result.Message); expected ??= plan; Assert.NotNull(result); @@ -697,14 +685,15 @@ private static async Task> InvokeAndAssertStreamingAsync( { Assert.Equal(cts.Token, actualCancellationToken); - return YieldAsync(new ChatResponse(new ChatMessage(ChatRole.Assistant, [.. plan[contents.Count].Contents])).ToChatResponseUpdates()); + ChatMessage message = new(ChatRole.Assistant, [.. plan[contents.Count].Contents]); + contents.Add(message); + return YieldAsync(new ChatResponse(message).ToChatResponseUpdates()); } }; IChatClient service = configurePipeline(innerClient.AsBuilder()).Build(services); var result = await service.GetStreamingResponseAsync(chat, options, cts.Token).ToChatResponseAsync(); - chat.Add(result.Message); expected ??= plan; Assert.NotNull(result); @@ -743,7 +732,7 @@ private static async Task> InvokeAndAssertStreamingAsync( return chat; } - private static async IAsyncEnumerable YieldAsync(params T[] items) + private static async IAsyncEnumerable YieldAsync(params IEnumerable items) { await Task.Yield(); foreach (var item in items) From e52ca1cde6684378b8b8e09d2bb7c8617cd7b59f Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 28 Feb 2025 10:34:47 -0500 Subject: [PATCH 3/9] Change ChatResponse to carry multiple response messages Multiple messages could be generated in a single turn as part of server-side tool use, as part of agents scenarios, as part of automatic function calling, etc. This is now being added into the history, but in some cases the caller doesn't have a history object (e.g. 
they just provided a string prompt and a temporary history was created), in some stateful cases the history may not be updated because it's all being tracked by the service, in some cases middleware could have manipulated the history in a way that makes it challenging to know what the additional messages are, etc. Further, this makes the non-streaming and streaming cases more synonymous, as in the streaming case all chat response updates from all messages are being "returned". --- .../AdditionalPropertiesDictionary{TValue}.cs | 12 ++ .../ChatCompletion/ChatMessage.cs | 19 ++- .../ChatCompletion/ChatResponse.cs | 121 ++++++++++++--- .../ChatCompletion/ChatResponseUpdate.cs | 12 +- .../ChatResponseUpdateExtensions.cs | 93 ++++++++---- .../FunctionInvokingChatClient.cs | 142 +++++++++++------- src/Shared/Throw/Throw.cs | 18 +++ .../ChatCompletion/ChatResponseTests.cs | 58 ++++++- .../ChatResponseUpdateExtensionsTests.cs | 6 +- .../DistributedCachingChatClientTest.cs | 14 +- .../FunctionInvokingChatClientTests.cs | 42 ++++++ .../ChatCompletion/LoggingChatClientTests.cs | 2 +- 12 files changed, 415 insertions(+), 124 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary{TValue}.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary{TValue}.cs index 21d1daf2820..2835545c4d3 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary{TValue}.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary{TValue}.cs @@ -107,6 +107,18 @@ public bool TryAdd(string key, TValue value) #endif } + /// Copies all of the entries from into the dictionary, overwriting any existing items in the dictionary with the same key. + /// The items to add. 
+ public void SetAll(IEnumerable> items) + { + _ = Throw.IfNull(items); + + foreach (var item in items) + { + _dictionary[item.Key] = item.Value; + } + } + /// void ICollection>.Add(KeyValuePair item) => ((ICollection>)_dictionary).Add(item); diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs index 32e2159950c..ab8dd19e964 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; @@ -33,12 +34,17 @@ public ChatMessage(ChatRole role, string? content) /// Initializes a new instance of the class. /// The role of the author of the message. /// The contents for this message. + /// is . + /// must not be read-only. public ChatMessage( ChatRole role, IList contents) { + _ = Throw.IfNull(contents); + _ = Throw.IfReadOnly(contents); + Role = role; - _contents = Throw.IfNull(contents); + _contents = contents; } /// Clones the to a new instance. @@ -92,11 +98,20 @@ public string? Text } /// Gets or sets the chat message content items. + /// The must not be read-only. [AllowNull] public IList Contents { get => _contents ??= []; - set => _contents = value; + set + { + if (value is not null) + { + _ = Throw.IfReadOnly(value); + } + + _contents = value; + } } /// Gets or sets the raw representation of the chat message from an underlying implementation. 
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs index b6344c6e8fe..94393fc1e13 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs @@ -9,30 +9,95 @@ namespace Microsoft.Extensions.AI; /// Represents the response to a chat request. +/// +/// provides one or more response messages and metadata about the response. +/// A typical response will contain a single message, however a response may contain multiple messages +/// in a variety of scenarios. For example, if automatic function calling is employed, such that a single +/// request to a may actually generate multiple roundtrips to an inner +/// it uses, all of the involved messages may be surfaced as part of the final . +/// The messages are ordered, such that returns the last message in the list. +/// public class ChatResponse { - /// The response message. - private ChatMessage _message; + /// The response messages. + private IList? _messages; /// Initializes a new instance of the class. public ChatResponse() { - _message = new(ChatRole.Assistant, []); } /// Initializes a new instance of the class. /// The response message. + /// is . public ChatResponse(ChatMessage message) { _ = Throw.IfNull(message); - _message = message; + Messages.Add(message); } - /// Gets or sets the chat response message. + /// Initializes a new instance of the class. + /// The response messages. + /// is . + /// must not be read-only. + public ChatResponse(IList messages) + { + _ = Throw.IfNull(messages); + _ = Throw.IfReadOnly(messages); + + _messages = messages; + } + + /// Gets or sets the chat response messages. + /// + /// The last message in the list maps to . It should represent + /// the final result message of the operation. + /// + /// The must not be read-only. 
+ public IList Messages + { + get => _messages ??= new List(1); + set + { + if (value is not null) + { + _ = Throw.IfReadOnly(value); + } + + _messages = value; + } + } + + /// Gets or sets the last chat response message. + /// + /// When getting , if there are no messages, will add a new + /// empty message to the list and return that message; if there are messages, the last will be returned. + /// When setting , if there are messages, the last message will be replaced by the + /// newly set instance; if there are no messages, the newly set instance will be added to the list. + /// + [JsonIgnore] public ChatMessage Message { - get => _message; - set => _message = Throw.IfNull(value); + get + { + if (Messages.Count == 0) + { + Messages.Add(new ChatMessage(ChatRole.Assistant, [])); + } + + return Messages[Messages.Count - 1]; + } + set + { + if (Messages.Count > 0) + { + Messages[Messages.Count - 1] = value; + } + else + { + Messages.Add(value); + } + } } /// Gets or sets the ID of the chat response. @@ -73,7 +138,9 @@ public ChatMessage Message public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } /// - public override string ToString() => _message.ToString(); + public override string ToString() => + _messages is null || _messages.Count == 0 ? string.Empty : + Message.ToString(); /// Creates an array of instances that represent this . /// An array of instances that may be used to represent this . @@ -93,27 +160,33 @@ public ChatResponseUpdate[] ToChatResponseUpdates() } } - var updates = new ChatResponseUpdate[extra is null ? 1 : 2]; + int messageCount = _messages?.Count ?? 0; + var updates = new ChatResponseUpdate[messageCount + (extra is not null ? 
1 : 0)]; - updates[0] = new ChatResponseUpdate + int i; + for (i = 0; i < messageCount; i++) { - ChatThreadId = ChatThreadId, - - AdditionalProperties = _message.AdditionalProperties, - AuthorName = _message.AuthorName, - Contents = _message.Contents, - RawRepresentation = _message.RawRepresentation, - Role = _message.Role, - - ResponseId = ResponseId, - CreatedAt = CreatedAt, - FinishReason = FinishReason, - ModelId = ModelId - }; + ChatMessage message = _messages![i]; + updates[i] = new ChatResponseUpdate + { + ChatThreadId = ChatThreadId, + + AdditionalProperties = message.AdditionalProperties, + AuthorName = message.AuthorName, + Contents = message.Contents, + RawRepresentation = message.RawRepresentation, + Role = message.Role, + + ResponseId = ResponseId, + CreatedAt = CreatedAt, + FinishReason = FinishReason, + ModelId = ModelId + }; + } if (extra is not null) { - updates[1] = extra; + updates[i] = extra; } return updates; diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs index 696dc91cbe1..57f3019ae6e 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI; @@ -71,11 +72,20 @@ public string? Text } /// Gets or sets the chat response update content items. + /// The must not be read-only. [AllowNull] public IList Contents { get => _contents ??= []; - set => _contents = value; + set + { + if (value is not null) + { + _ = Throw.IfReadOnly(value); + } + + _contents = value; + } } /// Gets or sets the raw representation of the response update from an underlying implementation. 
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs index f552ef1aec0..54b903d68c8 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs @@ -8,7 +8,6 @@ using Microsoft.Shared.Diagnostics; #pragma warning disable S109 // Magic numbers should not be used -#pragma warning disable S127 // "for" loop stop conditions should be invariant #pragma warning disable S1121 // Assignments should not be made from within sub-expressions namespace Microsoft.Extensions.AI; @@ -59,7 +58,7 @@ public static ChatResponse ToChatResponse( { _ = Throw.IfNull(updates); - ChatResponse response = new(new(default, [])); + ChatResponse response = new(); foreach (var update in updates) { @@ -91,7 +90,7 @@ public static Task ToChatResponseAsync( static async Task ToChatResponseAsync( IAsyncEnumerable updates, bool coalesceContent, CancellationToken cancellationToken) { - ChatResponse response = new(new(default, [])); + ChatResponse response = new(); await foreach (var update in updates.WithCancellation(cancellationToken).ConfigureAwait(false)) { @@ -104,18 +103,45 @@ static async Task ToChatResponseAsync( } } + /// Finalizes the object. + private static void FinalizeResponse(ChatResponse response, bool coalesceContent) + { + if (coalesceContent) + { + foreach (ChatMessage message in response.Messages) + { + CoalesceTextContent((List)message.Contents); + } + } + } + /// Processes the , incorporating its contents into . /// The update to process. /// The object that should be updated based on . 
private static void ProcessUpdate(ChatResponseUpdate update, ChatResponse response) { - response.ChatThreadId ??= update.ChatThreadId; - response.CreatedAt ??= update.CreatedAt; - response.FinishReason ??= update.FinishReason; - response.ModelId ??= update.ModelId; - response.ResponseId ??= update.ResponseId; + // If there is no message created yet, or if the last update we saw had a different + // response ID than the newest update, create a new message. + if (response.Messages.Count == 0 || + (update.ResponseId is string updateId && response.ResponseId is string responseId && updateId != responseId)) + { + response.Messages.Add(new ChatMessage(ChatRole.Assistant, [])); + } + + // Some members on ChatResponseUpdate map to members of ChatMessage. + // Incorporate those into the latest message; in cases where the message + // stores a single value, prefer the latest update's value over anything + // stored in the message. + if (update.AuthorName is not null) + { + response.Message.AuthorName = update.AuthorName; + } + + if (update.Role is ChatRole role) + { + response.Message.Role = role; + } - // Incorporate all content from the update into the response. foreach (var content in update.Contents) { switch (content) @@ -131,10 +157,33 @@ private static void ProcessUpdate(ChatResponseUpdate update, ChatResponse respon } } - response.Message.AuthorName ??= update.AuthorName; - if (update.Role is ChatRole role && response.Message.Role == default) + // Other members on a ChatResponseUpdate map to members of the ChatResponse. + // Update the response object with those, preferring the values from later updates. 
+ if (update.ChatThreadId is not null) { - response.Message.Role = role; + response.ChatThreadId = update.ChatThreadId; + } + + if (update.CreatedAt is not null) + { + response.CreatedAt = update.CreatedAt; + } + + if (update.FinishReason is not null) + { + response.FinishReason = update.FinishReason; + } + + if (update.ModelId is not null) + { + response.ModelId = update.ModelId; + } + + if (update.ResponseId is not null) + { + // Note that this must come after the message checks earlier, as they depend + // on this value for change detection. + response.ResponseId = update.ResponseId; } if (update.AdditionalProperties is not null) @@ -145,29 +194,11 @@ private static void ProcessUpdate(ChatResponseUpdate update, ChatResponse respon } else { - foreach (var entry in update.AdditionalProperties) - { - // Use first-wins behavior to match the behavior of the other properties. - _ = response.AdditionalProperties.TryAdd(entry.Key, entry.Value); - } + response.AdditionalProperties.SetAll(update.AdditionalProperties); } } } - /// Finalizes the object. - private static void FinalizeResponse(ChatResponse response, bool coalesceContent) - { - if (response.Message.Role == default) - { - response.Message.Role = ChatRole.Assistant; - } - - if (coalesceContent) - { - CoalesceTextContent((List)response.Message.Contents); - } - } - /// Coalesces sequential content elements. 
private static void CoalesceTextContent(List contents) { diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index 48166f78df9..ab63bfc5b48 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; using System.Threading; @@ -11,6 +12,7 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Shared.Diagnostics; +using static Microsoft.Extensions.AI.OpenTelemetryConsts.GenAI; #pragma warning disable CA2213 // Disposable fields should be disposed #pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test @@ -24,8 +26,10 @@ namespace Microsoft.Extensions.AI; /// /// /// When this client receives a in a chat response, it responds -/// by calling the corresponding defined in , -/// producing a . +/// by calling the corresponding defined in , +/// producing a that it sends back to the inner client. This loop +/// is repeated until there are no more function calls to make, or until another stop condition is met, +/// such as hitting . /// /// /// The provided implementation of is thread-safe for concurrent use so long as the @@ -179,78 +183,100 @@ public int? MaximumIterationsPerRequest public override async Task GetResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) { _ = Throw.IfNull(chatMessages); + _ = Throw.IfReadOnly(chatMessages); // A single request into this GetResponseAsync may result in multiple requests to the inner client. 
// Create an activity to group them together for better observability. using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient)); + IList originalChatMessages = chatMessages; ChatResponse? response = null; + List? responseMessages = null; UsageDetails? totalUsage = null; - IList originalChatMessages = chatMessages; List? functionCallContents = null; - try + + for (int iteration = 0; ; iteration++) { - for (int iteration = 0; ; iteration++) + functionCallContents?.Clear(); + + // Make the call to the handler. + response = await base.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + if (response is null) { - functionCallContents?.Clear(); + throw new InvalidOperationException($"The inner {nameof(IChatClient)} returned a null {nameof(ChatResponse)}."); + } - // Make the call to the handler. - response = await base.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + // Any function call work to do? If yes, ensure we're tracking that work in functionCallContents. + bool requiresFunctionInvocation = + options?.Tools is { Count: > 0 } && + (!MaximumIterationsPerRequest.HasValue || iteration < MaximumIterationsPerRequest.GetValueOrDefault()) && + CopyFunctionCalls(response.Message.Contents, ref functionCallContents); - // Aggregate usage data over all calls - if (response.Usage is not null) + // In the common case where we make a request and there's no function calling work required, + // fast path out by just returning the original response. + if (iteration == 0 && !requiresFunctionInvocation) + { + return response; + } + + // Track aggregatable details from the response. + (responseMessages ??= []).AddRange(response.Messages); + if (response.Usage is not null) + { + if (totalUsage is not null) { - totalUsage ??= new(); totalUsage.Add(response.Usage); } - - // If there are no tools to call, or for any other reason we should stop, return the response. 
- if (options is null - || options.Tools is not { Count: > 0 } - || (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations)) + else { - break; + totalUsage = response.Usage; } + } - // Extract any function call contents. If there are none, we're done. - CopyFunctionCalls(response.Message.Contents, ref functionCallContents); - if (functionCallContents is not { Count: > 0 }) + // If there are no tools to call, or for any other reason we should stop, we're done. + if (!requiresFunctionInvocation) + { + // If this is the first request, we can just return the response, as we don't need to + // incorporate any information from previous requests. + if (iteration == 0) { - break; + return response; } - // If the response indicates the inner client is tracking the history, clear it to avoid re-sending the state. - // In that case, we also avoid touching the user's history, so that we don't need to clear it. - if (response.ChatThreadId is not null) + // Otherwise, break out of the loop and allow the handling at the end to configure + // the response with aggregated data from previous requests. + break; + } + + // If the response indicates the inner client is tracking the history, clear it to avoid re-sending the state. + // In that case, we also avoid touching the user's history, so that we don't need to clear it. + if (response.ChatThreadId is not null) + { + if (chatMessages == originalChatMessages) { - if (chatMessages == originalChatMessages) - { - chatMessages = []; - } - else - { - chatMessages.Clear(); - } + chatMessages = []; } - - // Add the responses from the function calls into the history. 
- var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); - if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, response.ChatThreadId)) + else { - // Terminate - return response; + chatMessages.Clear(); } } - return response; - } - finally - { - if (response is not null) + // Add the responses from the function calls into the history. + var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options!, functionCallContents!, iteration, cancellationToken).ConfigureAwait(false); + responseMessages.AddRange(modeAndMessages.MessagesAdded); + if (UpdateOptionsForMode(modeAndMessages.Mode, ref options!, response.ChatThreadId)) { - response.Usage = totalUsage; + // Terminate + break; } } + + Debug.Assert(responseMessages is not null, "Expected to only be here if we have response messages."); + response.Messages = responseMessages!; + response.Usage = totalUsage; + + return response; } /// @@ -258,6 +284,7 @@ public override async IAsyncEnumerable GetStreamingResponseA IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { _ = Throw.IfNull(chatMessages); + _ = Throw.IfReadOnly(chatMessages); // A single request into this GetStreamingResponseAsync may result in multiple requests to the inner client. // Create an activity to group them together for better observability. 
@@ -272,18 +299,22 @@ public override async IAsyncEnumerable GetStreamingResponseA await foreach (var update in base.GetStreamingResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) { + if (update is null) + { + throw new InvalidOperationException($"The inner {nameof(IChatClient)} streamed a null {nameof(ChatResponseUpdate)}."); + } + chatThreadId ??= update.ChatThreadId; - CopyFunctionCalls(update.Contents, ref functionCallContents); + _ = CopyFunctionCalls(update.Contents, ref functionCallContents); yield return update; Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 } // If there are no tools to call, or for any other reason we should stop, return the response. - if (functionCallContents is not { Count: > 0 } - || options is null - || options.Tools is not { Count: > 0 } - || (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations)) + if (functionCallContents is not { Count: > 0 } || + options?.Tools is not { Count: > 0 } || + (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations)) { break; } @@ -312,6 +343,7 @@ public override async IAsyncEnumerable GetStreamingResponseA // Stream any generated function results. These are already part of the history, // but we stream them out for informational purposes. + string toolResponseId = Guid.NewGuid().ToString("N"); foreach (var message in modeAndMessages.MessagesAdded) { var toolResultUpdate = new ChatResponseUpdate @@ -322,6 +354,7 @@ public override async IAsyncEnumerable GetStreamingResponseA CreatedAt = DateTimeOffset.UtcNow, Contents = message.Contents, RawRepresentation = message.RawRepresentation, + ResponseId = toolResponseId, Role = message.Role, }; @@ -332,16 +365,21 @@ public override async IAsyncEnumerable GetStreamingResponseA } /// Copies any from to . - private static void CopyFunctionCalls(IList content, ref List? 
functionCalls) + private static bool CopyFunctionCalls( + IList content, [NotNullWhen(true)] ref List? functionCalls) { + bool any = false; int count = content.Count; for (int i = 0; i < count; i++) { if (content[i] is FunctionCallContent functionCall) { (functionCalls ??= []).Add(functionCall); + any = true; } } + + return any; } /// Updates for the response. @@ -533,10 +571,10 @@ protected virtual IList AddResponseMessages(IList chat { _ = Throw.IfNull(chatMessages); - var contents = new AIContent[results.Length]; + var contents = new List(results.Length); for (int i = 0; i < results.Length; i++) { - contents[i] = CreateFunctionResultContent(results[i]); + contents.Add(CreateFunctionResultContent(results[i])); } ChatMessage message = new(ChatRole.Tool, contents); diff --git a/src/Shared/Throw/Throw.cs b/src/Shared/Throw/Throw.cs index 257a880e6ac..393caaa7c0f 100644 --- a/src/Shared/Throw/Throw.cs +++ b/src/Shared/Throw/Throw.cs @@ -294,6 +294,24 @@ public static IEnumerable IfNullOrEmpty([NotNull] IEnumerable? argument return argument; } + /// + /// Throws an if the collection's + /// is . + /// + /// The collection to evaluate. + /// The name of the parameter being checked. + /// The type of objects in the collection. + /// The original value of . 
+ public static ICollection IfReadOnly(ICollection argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument.IsReadOnly) + { + ArgumentException(paramName, "Collection is read-only"); + } + + return argument; + } + #endregion #region Exceptions diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs index f3536bd116f..0d8e4f8bb3b 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections.Generic; using System.Text.Json; using Xunit; @@ -12,7 +13,8 @@ public class ChatResponseTests [Fact] public void Constructor_InvalidArgs_Throws() { - Assert.Throws("message", () => new ChatResponse(null!)); + Assert.Throws("message", () => new ChatResponse((ChatMessage)null!)); + Assert.Throws("messages", () => new ChatResponse((List)null!)); } [Fact] @@ -31,6 +33,56 @@ public void Constructor_Message_Roundtrips() Assert.Same(message, response.Message); } + [Fact] + public void Constructor_Messages_Roundtrips() + { + ChatResponse response = new(); + Assert.NotNull(response.Messages); + Assert.Same(response.Messages, response.Messages); + + List messages = new(); + response = new(messages); + Assert.Same(messages, response.Messages); + + messages = new(); + response.Messages = messages; + Assert.Same(messages, response.Messages); + } + + [Fact] + public void Message_LastMessageOfMessages() + { + ChatResponse response = new(); + + Assert.Empty(response.Messages); + Assert.NotNull(response.Message); + Assert.NotEmpty(response.Messages); + + for (int i = 1; i < 3; i++) + { + Assert.Same(response.Messages[response.Messages.Count - 
1], response.Message); + response.Messages.Add(new ChatMessage(ChatRole.User, $"Message {i}")); + } + } + + [Fact] + public void Message_SetterSetsLast() + { + ChatResponse response = new(); + + Assert.Empty(response.Messages); + ChatMessage message = new(); + response.Message = message; + Assert.NotEmpty(response.Messages); + Assert.Same(message, response.Messages[0]); + + message = new(); + response.Message = message; + Assert.Single(response.Messages); + Assert.Same(message, response.Messages[0]); + Assert.Same(message, response.Message); + } + [Fact] public void Properties_Roundtrip() { @@ -71,7 +123,7 @@ public void Properties_Roundtrip() [Fact] public void JsonSerialization_Roundtrips() { - ChatResponse original = new(new(ChatRole.Assistant, "the message")) + ChatResponse original = new(new ChatMessage(ChatRole.Assistant, "the message")) { ResponseId = "id", ModelId = "modelId", @@ -106,7 +158,7 @@ public void JsonSerialization_Roundtrips() [Fact] public void ToString_OutputsChatMessageToString() { - ChatResponse response = new(new(ChatRole.Assistant, $"This is a test.{Environment.NewLine}It's multiple lines.")); + ChatResponse response = new(new ChatMessage(ChatRole.Assistant, $"This is a test.{Environment.NewLine}It's multiple lines.")); Assert.Equal(response.Message.ToString(), response.ToString()); } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs index d818420359a..32a0ddf3007 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs @@ -42,8 +42,8 @@ public async Task ToChatResponse_SuccessfullyCreatesResponse(bool useAsync, bool new() { Text = ", ", AuthorName = "Someone", Role = new 
ChatRole("human"), AdditionalProperties = new() { ["a"] = "b" } }, new() { Text = "world!", CreatedAt = new DateTimeOffset(2, 2, 3, 4, 5, 6, TimeSpan.Zero), ChatThreadId = "123", AdditionalProperties = new() { ["c"] = "d" } }, - new() { Contents = new[] { new UsageContent(new() { InputTokenCount = 1, OutputTokenCount = 2 }) } }, - new() { Contents = new[] { new UsageContent(new() { InputTokenCount = 4, OutputTokenCount = 5 }) } }, + new() { Contents = [new UsageContent(new() { InputTokenCount = 1, OutputTokenCount = 2 })] }, + new() { Contents = [new UsageContent(new() { InputTokenCount = 4, OutputTokenCount = 5 })] }, ]; ChatResponse response = (coalesceContent is bool, useAsync) switch @@ -61,7 +61,7 @@ public async Task ToChatResponse_SuccessfullyCreatesResponse(bool useAsync, bool Assert.Equal(7, response.Usage.OutputTokenCount); Assert.Equal("12345", response.ResponseId); - Assert.Equal(new DateTimeOffset(1, 2, 3, 4, 5, 6, TimeSpan.Zero), response.CreatedAt); + Assert.Equal(new DateTimeOffset(2, 2, 3, 4, 5, 6, TimeSpan.Zero), response.CreatedAt); Assert.Equal("model123", response.ModelId); Assert.Equal("123", response.ChatThreadId); diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs index ab67d0e3376..4a3398e6a6d 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs @@ -40,7 +40,7 @@ public async Task CachesSuccessResultsAsync() // Verify that all the expected properties will round-trip through the cache, // even if this involves serialization var expectedResponse = new ChatResponse( - new(new ChatRole("fakeRole"), "This is some content") + new ChatMessage(new ChatRole("fakeRole"), "This is some content") { AdditionalProperties = new() { ["a"] = "b" }, 
Contents = [new FunctionCallContent("someCallId", "functionName", new Dictionary @@ -110,7 +110,7 @@ public async Task AllowsConcurrentCallsAsync() { innerCallCount++; await completionTcs.Task; - return new ChatResponse(new(ChatRole.Assistant, "Hello")); + return new ChatResponse(new ChatMessage(ChatRole.Assistant, "Hello")); } }; using var outer = new DistributedCachingChatClient(testClient, _storage) @@ -184,7 +184,7 @@ public async Task DoesNotCacheCanceledResultsAsync() await resolutionTcs.Task; } - return new ChatResponse(new(ChatRole.Assistant, "A good result")); + return new ChatResponse(new ChatMessage(ChatRole.Assistant, "A good result")); } }; using var outer = new DistributedCachingChatClient(testClient, _storage) @@ -524,7 +524,7 @@ public async Task CacheKeyVariesByChatOptionsAsync() { innerCallCount++; await Task.Yield(); - return new(new(ChatRole.Assistant, options!.AdditionalProperties!["someKey"]!.ToString())); + return new(new ChatMessage(ChatRole.Assistant, options!.AdditionalProperties!["someKey"]!.ToString())); } }; using var outer = new DistributedCachingChatClient(testClient, _storage) @@ -575,7 +575,7 @@ public async Task SubclassCanOverrideCacheKeyToVaryByChatOptionsAsync() { innerCallCount++; await Task.Yield(); - return new(new(ChatRole.Assistant, options!.AdditionalProperties!["someKey"]!.ToString())); + return new(new ChatMessage(ChatRole.Assistant, options!.AdditionalProperties!["someKey"]!.ToString())); } }; using var outer = new CachingChatClientWithCustomKey(testClient, _storage) @@ -604,7 +604,7 @@ public async Task CanCacheCustomContentTypesAsync() { // Arrange var expectedResponse = new ChatResponse( - new(new ChatRole("fakeRole"), + new ChatMessage(new ChatRole("fakeRole"), [ new CustomAIContent1("Hello", DateTime.Now), new CustomAIContent2("Goodbye", 42), @@ -663,7 +663,7 @@ public async Task CanResolveIDistributedCacheFromDI() GetResponseAsyncCallback = delegate { return Task.FromResult(new ChatResponse( - 
new(ChatRole.Assistant, [new TextContent("Hey")]))); + new ChatMessage(ChatRole.Assistant, [new TextContent("Hey")]))); } }; using var outer = testClient diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs index f9143285431..371a1666002 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs @@ -444,6 +444,48 @@ public async Task SupportsConsecutiveStreamingUpdatesWithFunctionCalls() Assert.Contains(allUpdateContents, c => c is FunctionResultContent); } + [Fact] + public async Task AllResponseMessagesReturned() + { + var options = new ChatOptions + { + Tools = [AIFunctionFactory.Create(() => "doesn't matter", "Func1")] + }; + + var messages = new List + { + new(ChatRole.User, "Hello"), + }; + + using var innerClient = new TestChatClient + { + GetResponseAsyncCallback = async (chatContents, chatOptions, cancellationToken) => + { + await Task.Yield(); + + ChatMessage message = chatContents.Count is 1 or 3 ? 
+ new(ChatRole.Assistant, [new FunctionCallContent($"callId{chatContents.Count}", "Func1")]) : + new(ChatRole.Assistant, "The answer is 42."); + + chatContents.Add(message); + + return new(message); + } + }; + + using var client = new FunctionInvokingChatClient(innerClient); + + ChatResponse response = await client.GetResponseAsync(messages, options); + + Assert.Equal(5, response.Messages.Count); + Assert.Equal("The answer is 42.", response.Message.Text); + Assert.IsType(Assert.Single(response.Messages[0].Contents)); + Assert.IsType(Assert.Single(response.Messages[1].Contents)); + Assert.IsType(Assert.Single(response.Messages[2].Contents)); + Assert.IsType(Assert.Single(response.Messages[3].Contents)); + Assert.IsType(Assert.Single(response.Messages[4].Contents)); + } + [Fact] public async Task CanAccesssFunctionInvocationContextFromFunctionCall() { diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs index c6bf16b3bf2..9bd777ef543 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs @@ -56,7 +56,7 @@ public async Task GetResponseAsync_LogsResponseInvocationAndCompletion(LogLevel { GetResponseAsyncCallback = (messages, options, cancellationToken) => { - return Task.FromResult(new ChatResponse(new(ChatRole.Assistant, "blue whale"))); + return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "blue whale"))); }, }; From bd95abc8305e127a9e0612bba3017e54918c870f Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Tue, 4 Mar 2025 00:30:29 -0500 Subject: [PATCH 4/9] Address feedback and some more cleanup --- .../AdditionalPropertiesDictionary{TValue}.cs | 24 +- .../ChatCompletion/ChatMessage.cs | 45 ++-- .../ChatCompletion/ChatResponse.cs | 42 ++-- 
.../ChatCompletion/ChatResponseUpdate.cs | 64 ++++-- .../ChatResponseUpdateExtensions.cs | 175 +++++++++----- .../ChatCompletion/RequiredChatToolMode.cs | 1 + .../Contents/AIContentExtensions.cs | 56 ++--- .../Contents/FunctionCallContent.cs | 6 +- .../UsageDetails.cs | 3 + .../Utilities/AIJsonUtilities.Schema.cs | 2 + .../AzureAIInferenceChatClient.cs | 184 +++++++-------- .../AzureAIInferenceEmbeddingGenerator.cs | 3 + .../RelevanceTruthAndCompletenessEvaluator.cs | 8 +- .../SingleNumericMetricEvaluator.cs | 4 +- .../Utilities/JsonOutputFixer.cs | 4 +- .../OllamaChatClient.cs | 79 ++++--- .../OllamaEmbeddingGenerator.cs | 2 + .../OpenAIAssistantClient.cs | 6 +- .../OpenAIChatClient.cs | 8 +- .../OpenAIEmbeddingGenerator.cs | 5 + .../OpenAIModelMapper.ChatCompletion.cs | 29 ++- ...nAIModelMappers.StreamingChatCompletion.cs | 216 +++++++++--------- .../OpenAIRealtimeExtensions.cs | 4 + .../OpenAISerializationHelpers.cs | 5 + .../AnonymousDelegatingChatClient.cs | 2 +- .../ChatCompletion/CachingChatClient.cs | 21 +- .../ChatCompletion/ChatClientBuilder.cs | 3 + .../ChatClientBuilderChatClientExtensions.cs | 2 + ...lientBuilderServiceCollectionExtensions.cs | 23 +- .../ChatClientStructuredOutputExtensions.cs | 3 + .../ChatCompletion/ChatResponse{T}.cs | 14 +- ...igureOptionsChatClientBuilderExtensions.cs | 2 + ...butedCachingChatClientBuilderExtensions.cs | 1 + .../FunctionInvokingChatClient.cs | 18 +- ...tionInvokingChatClientBuilderExtensions.cs | 1 + .../LoggingChatClientBuilderExtensions.cs | 1 + .../ChatCompletion/OpenTelemetryChatClient.cs | 16 +- ...ionsEmbeddingGeneratorBuilderExtensions.cs | 2 + .../DistributedCachingEmbeddingGenerator.cs | 2 + ...hingEmbeddingGeneratorBuilderExtensions.cs | 1 + .../Embeddings/EmbeddingGeneratorBuilder.cs | 3 + ...atorBuilderEmbeddingGeneratorExtensions.cs | 2 + ...ratorBuilderServiceCollectionExtensions.cs | 23 +- ...gingEmbeddingGeneratorBuilderExtensions.cs | 1 + .../Functions/AIFunctionFactory.cs | 4 + 
.../ChatClientExtensionsTests.cs | 2 +- .../ChatCompletion/ChatMessageTests.cs | 84 +++---- .../ChatCompletion/ChatResponseTests.cs | 64 +----- .../ChatResponseUpdateExtensionsTests.cs | 56 ++--- .../ChatCompletion/ChatResponseUpdateTests.cs | 62 +---- .../DelegatingChatClientTests.cs | 4 +- .../AzureAIInferenceChatClientTests.cs | 26 +-- .../AdditionalContextTests.cs | 5 +- .../EndToEndTests.cs | 4 +- ...vanceTruthAndCompletenessEvaluatorTests.cs | 4 +- .../ChatClientIntegrationTests.cs | 39 ++-- .../PromptBasedFunctionCallingChatClient.cs | 12 +- .../OllamaChatClientIntegrationTests.cs | 4 +- .../OllamaChatClientTests.cs | 24 +- .../OpenAIChatClientTests.cs | 38 +-- .../DistributedCachingChatClientTest.cs | 56 ++--- .../FunctionInvokingChatClientTests.cs | 5 +- .../ChatCompletion/LoggingChatClientTests.cs | 4 +- .../OpenTelemetryChatClientTests.cs | 4 +- .../UseDelegateChatClientTests.cs | 31 +-- test/Shared/Throw/ThrowTest.cs | 12 + 66 files changed, 851 insertions(+), 809 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary{TValue}.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary{TValue}.cs index 2835545c4d3..14125e95b76 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary{TValue}.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary{TValue}.cs @@ -107,18 +107,6 @@ public bool TryAdd(string key, TValue value) #endif } - /// Copies all of the entries from into the dictionary, overwriting any existing items in the dictionary with the same key. - /// The items to add. - public void SetAll(IEnumerable> items) - { - _ = Throw.IfNull(items); - - foreach (var item in items) - { - _dictionary[item.Key] = item.Value; - } - } - /// void ICollection>.Add(KeyValuePair item) => ((ICollection>)_dictionary).Add(item); @@ -213,6 +201,18 @@ public bool TryGetValue(string key, [NotNullWhen(true)] out T? 
value) /// bool IReadOnlyDictionary.TryGetValue(string key, out TValue value) => _dictionary.TryGetValue(key, out value!); + /// Copies all of the entries from into the dictionary, overwriting any existing items in the dictionary with the same key. + /// The items to add. + internal void SetAll(IEnumerable> items) + { + _ = Throw.IfNull(items); + + foreach (var item in items) + { + _dictionary[item.Key] = item.Value; + } + } + /// Enumerates the elements of an . public struct Enumerator : IEnumerator> { diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs index ab8dd19e964..194c6d68b1a 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs @@ -18,6 +18,7 @@ public class ChatMessage private string? _authorName; /// Initializes a new instance of the class. + /// The instance defaults to having a role of . [JsonConstructor] public ChatMessage() { @@ -25,26 +26,25 @@ public ChatMessage() /// Initializes a new instance of the class. /// The role of the author of the message. - /// The contents of the message. - public ChatMessage(ChatRole role, string? content) - : this(role, content is null ? [] : [new TextContent(content)]) + /// The text contents of the message. + public ChatMessage(ChatRole role, string? contents) + : this(role, contents is null ? [] : [new TextContent(contents)]) { } /// Initializes a new instance of the class. /// The role of the author of the message. /// The contents for this message. - /// is . /// must not be read-only. - public ChatMessage( - ChatRole role, - IList contents) + public ChatMessage(ChatRole role, IList? 
contents) { - _ = Throw.IfNull(contents); - _ = Throw.IfReadOnly(contents); + if (contents is not null) + { + _ = Throw.IfReadOnly(contents); + _contents = contents; + } Role = role; - _contents = contents; } /// Clones the to a new instance. @@ -73,29 +73,12 @@ public string? AuthorName /// Gets or sets the role of the author of the message. public ChatRole Role { get; set; } = ChatRole.User; - /// - /// Gets or sets the text of the first instance in . - /// + /// Gets the text of this message. /// - /// If there is no instance in , then the getter returns , - /// and the setter adds a new instance with the provided value. + /// This property concatenates the text of all objects in . /// [JsonIgnore] - public string? Text - { - get => Contents.FindFirst()?.Text; - set - { - if (Contents.FindFirst() is { } textContent) - { - textContent.Text = value; - } - else if (value is not null) - { - Contents.Add(new TextContent(value)); - } - } - } + public string Text => Contents.ConcatText(); /// Gets or sets the chat message content items. /// The must not be read-only. @@ -127,7 +110,7 @@ public IList Contents public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } /// - public override string ToString() => Contents.ConcatText(); + public override string ToString() => Text; /// Gets a object to display in the debugger display. 
[DebuggerBrowsable(DebuggerBrowsableState.Never)] diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs index 94393fc1e13..c41c3823126 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text.Json.Serialization; using Microsoft.Shared.Diagnostics; @@ -15,7 +16,6 @@ namespace Microsoft.Extensions.AI; /// in a variety of scenarios. For example, if automatic function calling is employed, such that a single /// request to a may actually generate multiple roundtrips to an inner /// it uses, all of the involved messages may be surfaced as part of the final . -/// The messages are ordered, such that returns the last message in the list. /// public class ChatResponse { @@ -49,10 +49,6 @@ public ChatResponse(IList messages) } /// Gets or sets the chat response messages. - /// - /// The last message in the list maps to . It should represent - /// the final result message of the operation. - /// /// The must not be read-only. public IList Messages { @@ -68,35 +64,29 @@ public IList Messages } } - /// Gets or sets the last chat response message. + /// Gets the text of the response. /// - /// When getting , if there are no messages, will add a new - /// empty message to the list and return that message; if there are messages, the last will be returned. - /// When setting , if there are messages, the last message will be replaced by the - /// newly set instance; if there are no messages, the newly set instance will be added to the list. + /// This property concatenates the of all + /// instances in . /// [JsonIgnore] - public ChatMessage Message + public string Text { get { - if (Messages.Count == 0) + IList? 
messages = _messages; + if (messages is null) { - Messages.Add(new ChatMessage(ChatRole.Assistant, [])); + return string.Empty; } - return Messages[Messages.Count - 1]; - } - set - { - if (Messages.Count > 0) + int count = messages.Count; + return count switch { - Messages[Messages.Count - 1] = value; - } - else - { - Messages.Add(value); - } + 0 => string.Empty, + 1 => messages[0].Text, + _ => messages.SelectMany(m => m.Contents).ConcatText(), + }; } } @@ -138,9 +128,7 @@ public ChatMessage Message public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } /// - public override string ToString() => - _messages is null || _messages.Count == 0 ? string.Empty : - Message.ToString(); + public override string ToString() => Text; /// Creates an array of instances that represent this . /// An array of instances that may be used to represent this . diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs index 57f3019ae6e..214acbed465 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; using Microsoft.Shared.Diagnostics; @@ -29,6 +30,7 @@ namespace Microsoft.Extensions.AI; /// only one of the values will be used to populate . /// /// +[DebuggerDisplay("[{Role}] {ContentForDebuggerDisplay}{EllipsesForDebuggerDisplay,nq}")] public class ChatResponseUpdate { /// The response update content items. @@ -37,6 +39,35 @@ public class ChatResponseUpdate /// The name of the author of the update. private string? _authorName; + /// Initializes a new instance of the class. 
+ [JsonConstructor] + public ChatResponseUpdate() + { + } + + /// Initializes a new instance of the class. + /// The role of the author of the update. + /// The text contents of the update. + public ChatResponseUpdate(ChatRole? role, string? contents) + : this(role, contents is null ? null : [new TextContent(contents)]) + { + } + + /// Initializes a new instance of the class. + /// The role of the author of the update. + /// The contents of the update. + /// must not be read-only. + public ChatResponseUpdate(ChatRole? role, IList? contents) + { + if (contents is not null) + { + _ = Throw.IfReadOnly(contents); + _contents = contents; + } + + Role = role; + } + /// Gets or sets the name of the author of the response update. public string? AuthorName { @@ -47,29 +78,12 @@ public string? AuthorName /// Gets or sets the role of the author of the response update. public ChatRole? Role { get; set; } - /// - /// Gets or sets the text of the first instance in . - /// + /// Gets the text of this update. /// - /// If there is no instance in , then the getter returns , - /// and the setter will add new instance with the provided value. + /// This property concatenates the text of all objects in . /// [JsonIgnore] - public string? Text - { - get => Contents.FindFirst()?.Text; - set - { - if (Contents.FindFirst() is { } textContent) - { - textContent.Text = value; - } - else if (value is not null) - { - Contents.Add(new TextContent(value)); - } - } - } + public string Text => _contents is not null ? _contents.ConcatText() : string.Empty; /// Gets or sets the chat response update content items. /// The must not be read-only. @@ -123,5 +137,13 @@ public IList Contents public string? ModelId { get; set; } /// - public override string ToString() => Contents.ConcatText(); + public override string ToString() => Text; + + /// Gets a object to display in the debugger display. + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private AIContent? 
ContentForDebuggerDisplay => _contents is { Count: > 0 } ? _contents[0] : null; + + /// Gets an indication for the debugger display of whether there's more content. + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string EllipsesForDebuggerDisplay => _contents is { Count: > 1 } ? ", ..." : string.Empty; } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs index 54b903d68c8..53d15e24b70 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections.Generic; using System.Text; using System.Threading; @@ -17,44 +18,94 @@ namespace Microsoft.Extensions.AI; /// public static class ChatResponseUpdateExtensions { - /// Combines instances into a single . - /// The updates to be combined. - /// - /// to attempt to coalesce contiguous items, where applicable, - /// into a single , in order to reduce the number of individual content items that are included in - /// the manufactured instance. When , the original content items are used. - /// The default is . - /// - /// The combined . - public static ChatMessage ToChatMessage( - this IEnumerable updates, bool coalesceContent = true) => - ToChatResponse(updates, coalesceContent).Message; // TO DO: More efficient implementation + /// Converts the into instances and adds them to . + /// The list to which the newly constructed messages should be added. + /// The instances to convert to messages and add to the list. + /// is . + /// is . 
+ /// + /// As part of combining into a series of instances, tne + /// method may use to determine message boundaries, as well as coalesce + /// contiguous items where applicable, e.g. multiple + /// instances in a row may be combined into a single . + /// + public static void AddRangeFromUpdates(this IList chatMessages, IEnumerable updates) + { + _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(updates); - /// Combines instances into a single . - /// The updates to be combined. - /// - /// to attempt to coalesce contiguous items, where applicable, - /// into a single , in order to reduce the number of individual content items that are included in - /// the manufactured instance. When , the original content items are used. - /// The default is . - /// + if (updates is ICollection { Count: 0 }) + { + return; + } + + ChatResponse response = updates.ToChatResponse(); + if (chatMessages is List list) + { + list.AddRange(response.Messages); + } + else + { + int count = response.Messages.Count; + for (int i = 0; i < count; i++) + { + chatMessages.Add(response.Messages[i]); + } + } + } + + /// Converts the into instances and adds them to . + /// The list to which the newly constructed messages should be added. + /// The instances to convert to messages and add to the list. /// The to monitor for cancellation requests. The default is . - /// The combined . - public static async Task ToChatMessageAsync( - this IAsyncEnumerable updates, bool coalesceContent = true, CancellationToken cancellationToken = default) => - (await ToChatResponseAsync(updates, coalesceContent, cancellationToken).ConfigureAwait(false)).Message; // TO DO: More efficient implementation + /// A representing the completion of the operation. + /// is . + /// is . + /// + /// As part of combining into a series of instances, tne + /// method may use to determine message boundaries, as well as coalesce + /// contiguous items where applicable, e.g. 
multiple + /// instances in a row may be combined into a single . + /// + public static Task AddRangeFromUpdatesAsync( + this IList chatMessages, IAsyncEnumerable updates, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(updates); + + return AddRangeFromUpdatesAsync(chatMessages, updates, cancellationToken); + + static async Task AddRangeFromUpdatesAsync( + IList chatMessages, IAsyncEnumerable updates, CancellationToken cancellationToken) + { + ChatResponse response = await updates.ToChatResponseAsync(cancellationToken).ConfigureAwait(false); + if (chatMessages is List list) + { + list.AddRange(response.Messages); + } + else + { + int count = response.Messages.Count; + for (int i = 0; i < count; i++) + { + chatMessages.Add(response.Messages[i]); + } + } + } + } /// Combines instances into a single . /// The updates to be combined. - /// - /// to attempt to coalesce contiguous items, where applicable, - /// into a single , in order to reduce the number of individual content items that are included in - /// the manufactured instance. When , the original content items are used. - /// The default is . - /// /// The combined . + /// is . + /// + /// As part of combining into a single , the method will attempt to reconstruct + /// instances. This includes using to determine + /// message boundaries, as well as coalescing contiguous items where applicable, e.g. multiple + /// instances in a row may be combined into a single . + /// public static ChatResponse ToChatResponse( - this IEnumerable updates, bool coalesceContent = true) + this IEnumerable updates) { _ = Throw.IfNull(updates); @@ -65,30 +116,31 @@ public static ChatResponse ToChatResponse( ProcessUpdate(update, response); } - FinalizeResponse(response, coalesceContent); + FinalizeResponse(response); return response; } /// Combines instances into a single . /// The updates to be combined. 
- /// - /// to attempt to coalesce contiguous items, where applicable, - /// into a single , in order to reduce the number of individual content items that are included in - /// the manufactured instance. When , the original content items are used. - /// The default is . - /// /// The to monitor for cancellation requests. The default is . /// The combined . + /// is . + /// + /// As part of combining into a single , the method will attempt to reconstruct + /// instances. This includes using to determine + /// message boundaries, as well as coalescing contiguous items where applicable, e.g. multiple + /// instances in a row may be combined into a single . + /// public static Task ToChatResponseAsync( - this IAsyncEnumerable updates, bool coalesceContent = true, CancellationToken cancellationToken = default) + this IAsyncEnumerable updates, CancellationToken cancellationToken = default) { _ = Throw.IfNull(updates); - return ToChatResponseAsync(updates, coalesceContent, cancellationToken); + return ToChatResponseAsync(updates, cancellationToken); static async Task ToChatResponseAsync( - IAsyncEnumerable updates, bool coalesceContent, CancellationToken cancellationToken) + IAsyncEnumerable updates, CancellationToken cancellationToken) { ChatResponse response = new(); @@ -97,21 +149,19 @@ static async Task ToChatResponseAsync( ProcessUpdate(update, response); } - FinalizeResponse(response, coalesceContent); + FinalizeResponse(response); return response; } } /// Finalizes the object. 
- private static void FinalizeResponse(ChatResponse response, bool coalesceContent) + private static void FinalizeResponse(ChatResponse response) { - if (coalesceContent) + int count = response.Messages.Count; + for (int i = 0; i < count; i++) { - foreach (ChatMessage message in response.Messages) - { - CoalesceTextContent((List)message.Contents); - } + CoalesceTextContent((List)response.Messages[i].Contents); } } @@ -122,10 +172,16 @@ private static void ProcessUpdate(ChatResponseUpdate update, ChatResponse respon { // If there is no message created yet, or if the last update we saw had a different // response ID than the newest update, create a new message. + ChatMessage message; if (response.Messages.Count == 0 || (update.ResponseId is string updateId && response.ResponseId is string responseId && updateId != responseId)) { - response.Messages.Add(new ChatMessage(ChatRole.Assistant, [])); + message = new ChatMessage(ChatRole.Assistant, []); + response.Messages.Add(message); + } + else + { + message = response.Messages[response.Messages.Count - 1]; } // Some members on ChatResponseUpdate map to members of ChatMessage. @@ -134,12 +190,12 @@ private static void ProcessUpdate(ChatResponseUpdate update, ChatResponse respon // stored in the message. if (update.AuthorName is not null) { - response.Message.AuthorName = update.AuthorName; + message.AuthorName = update.AuthorName; } if (update.Role is ChatRole role) { - response.Message.Role = role; + message.Role = role; } foreach (var content in update.Contents) @@ -152,13 +208,21 @@ private static void ProcessUpdate(ChatResponseUpdate update, ChatResponse respon break; default: - response.Message.Contents.Add(content); + message.Contents.Add(content); break; } } // Other members on a ChatResponseUpdate map to members of the ChatResponse. // Update the response object with those, preferring the values from later updates. 
+ + if (update.ResponseId is not null) + { + // Note that this must come after the message checks earlier, as they depend + // on this value for change detection. + response.ResponseId = update.ResponseId; + } + if (update.ChatThreadId is not null) { response.ChatThreadId = update.ChatThreadId; @@ -179,13 +243,6 @@ private static void ProcessUpdate(ChatResponseUpdate update, ChatResponse respon response.ModelId = update.ModelId; } - if (update.ResponseId is not null) - { - // Note that this must come after the message checks earlier, as they depend - // on this value for change detection. - response.ResponseId = update.ResponseId; - } - if (update.AdditionalProperties is not null) { if (response.AdditionalProperties is null) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/RequiredChatToolMode.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/RequiredChatToolMode.cs index 74858dfe89b..91397e67602 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/RequiredChatToolMode.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/RequiredChatToolMode.cs @@ -26,6 +26,7 @@ public sealed class RequiredChatToolMode : ChatToolMode /// Initializes a new instance of the class that requires a specific function to be called. /// /// The name of the function that must be called. + /// is empty or composed entirely of whitespace. /// /// can be . However, it's preferable to use /// when any function can be selected. 
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContentExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContentExtensions.cs index eb516e2a7c1..d2b1e73a5b3 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContentExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContentExtensions.cs @@ -3,9 +3,8 @@ using System; using System.Collections.Generic; -#if !NET using System.Linq; -#else +#if NET using System.Runtime.CompilerServices; #endif @@ -14,51 +13,36 @@ namespace Microsoft.Extensions.AI; /// Internal extensions for working with . internal static class AIContentExtensions { - /// Finds the first occurrence of a in the list. - public static T? FindFirst(this IList contents) - where T : AIContent - { - int count = contents.Count; - for (int i = 0; i < count; i++) - { - if (contents[i] is T t) - { - return t; - } - } - - return null; - } - /// Concatenates the text of all instances in the list. - public static string ConcatText(this IList contents) + public static string ConcatText(this IEnumerable contents) { - int count = contents.Count; - switch (count) + if (contents is IList list) { - case 0: - break; + int count = list.Count; + switch (count) + { + case 0: + return string.Empty; - case 1: - return contents[0] is TextContent tc ? tc.Text : string.Empty; + case 1: + return (list[0] as TextContent)?.Text ?? 
string.Empty; - default: #if NET - DefaultInterpolatedStringHandler builder = new(0, 0, null, stackalloc char[512]); - for (int i = 0; i < count; i++) - { - if (contents[i] is TextContent text) + default: + DefaultInterpolatedStringHandler builder = new(0, 0, null, stackalloc char[512]); + for (int i = 0; i < count; i++) { - builder.AppendLiteral(text.Text); + if (list[i] is TextContent text) + { + builder.AppendLiteral(text.Text); + } } - } - return builder.ToStringAndClear(); -#else - return string.Concat(contents.OfType()); + return builder.ToStringAndClear(); #endif + } } - return string.Empty; + return string.Concat(contents.OfType()); } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallContent.cs index 88e0a207127..d19988b2b76 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallContent.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallContent.cs @@ -65,15 +65,19 @@ public FunctionCallContent(string callId, string name, IDictionaryThe function name. /// The parsing implementation converting the encoding to a dictionary of arguments. /// A new instance of containing the parse result. + /// is . + /// is . + /// is . + /// is . public static FunctionCallContent CreateFromParsedArguments( TEncoding encodedArguments, string callId, string name, Func?> argumentParser) { + _ = Throw.IfNull(encodedArguments); _ = Throw.IfNull(callId); _ = Throw.IfNull(name); - _ = Throw.IfNull(encodedArguments); _ = Throw.IfNull(argumentParser); IDictionary? 
arguments = null; diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/UsageDetails.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/UsageDetails.cs index 7d4e7ddbea2..b3c62cb67e0 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/UsageDetails.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/UsageDetails.cs @@ -29,9 +29,12 @@ public class UsageDetails public AdditionalPropertiesDictionary? AdditionalCounts { get; set; } /// Adds usage data from another into this instance. + /// The source with which to augment this instance. + /// is . public void Add(UsageDetails usage) { _ = Throw.IfNull(usage); + InputTokenCount = NullableSum(InputTokenCount, usage.InputTokenCount); OutputTokenCount = NullableSum(OutputTokenCount, usage.OutputTokenCount); TotalTokenCount = NullableSum(TotalTokenCount, usage.TotalTokenCount); diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs index e8a962a0be8..1d0224f7c9f 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs @@ -55,6 +55,7 @@ public static partial class AIJsonUtilities /// The options used to extract the schema from the specified type. /// The options controlling schema inference. /// A JSON schema document encoded as a . + /// is . public static JsonElement CreateFunctionJsonSchema( MethodBase method, string? title = null, @@ -63,6 +64,7 @@ public static JsonElement CreateFunctionJsonSchema( AIJsonSchemaCreateOptions? 
inferenceOptions = null) { _ = Throw.IfNull(method); + serializerOptions ??= DefaultOptions; inferenceOptions ??= AIJsonSchemaCreateOptions.Default; title ??= method.Name; diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs index 760c21f69aa..23edd8e9266 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs @@ -39,9 +39,12 @@ public sealed class AzureAIInferenceChatClient : IChatClient /// Initializes a new instance of the class for the specified . /// The underlying client. /// The ID of the model to use. If null, it can be provided per request via . + /// is . + /// is empty or composed entirely of whitespace. public AzureAIInferenceChatClient(ChatCompletionsClient chatCompletionsClient, string? modelId = null) { _ = Throw.IfNull(chatCompletionsClient); + if (modelId is not null) { _ = Throw.IfNullOrWhitespace(modelId); @@ -91,17 +94,11 @@ public async Task GetResponseAsync( cancellationToken: cancellationToken).ConfigureAwait(false)).Value; // Create the return message. - ChatMessage message = new() + ChatMessage message = new(ToChatRole(response.Role), response.Content) { RawRepresentation = response, - Role = ToChatRole(response.Role), }; - if (response.Content is string content) - { - message.Text = content; - } - if (response.ToolCalls is { Count: > 0 } toolCalls) { foreach (var toolCall in toolCalls) @@ -153,109 +150,114 @@ public async IAsyncEnumerable GetStreamingResponseAsync( DateTimeOffset? createdAt = null; string? 
modelId = null; string lastCallId = string.Empty; - List responseUpdates = []; - // Process each update as it arrives - var updates = await _chatCompletionsClient.CompleteStreamingAsync(ToAzureAIOptions(chatMessages, options), cancellationToken).ConfigureAwait(false); - await foreach (StreamingChatCompletionsUpdate chatCompletionUpdate in updates.ConfigureAwait(false)) + List responseUpdates = []; + try { - // The role and finish reason may arrive during any update, but once they've arrived, the same value should be the same for all subsequent updates. - streamedRole ??= chatCompletionUpdate.Role is global::Azure.AI.Inference.ChatRole role ? ToChatRole(role) : null; - finishReason ??= chatCompletionUpdate.FinishReason is CompletionsFinishReason reason ? ToFinishReason(reason) : null; - responseId ??= chatCompletionUpdate.Id; - createdAt ??= chatCompletionUpdate.Created; - modelId ??= chatCompletionUpdate.Model; - - // Create the response content object. - ChatResponseUpdate responseUpdate = new() - { - CreatedAt = chatCompletionUpdate.Created, - FinishReason = finishReason, - ModelId = modelId, - RawRepresentation = chatCompletionUpdate, - ResponseId = chatCompletionUpdate.Id, - Role = streamedRole, - }; - - // Transfer over content update items. - if (chatCompletionUpdate.ContentUpdate is string update) - { - responseUpdate.Contents.Add(new TextContent(update)); - } - - // Transfer over tool call updates. - if (chatCompletionUpdate.ToolCallUpdate is { } toolCallUpdate) + // Process each update as it arrives + var updates = await _chatCompletionsClient.CompleteStreamingAsync(ToAzureAIOptions(chatMessages, options), cancellationToken).ConfigureAwait(false); + await foreach (StreamingChatCompletionsUpdate chatCompletionUpdate in updates.ConfigureAwait(false)) { - // TODO https://github.com/Azure/azure-sdk-for-net/issues/46830: Azure.AI.Inference - // has removed the Index property from ToolCallUpdate. 
It's now impossible via the - // exposed APIs to correctly handle multiple parallel tool calls, as the CallId is - // often null for anything other than the first update for a given call, and Index - // isn't available to correlate which updates are for which call. This is a temporary - // workaround to at least make a single tool call work and also make work multiple - // tool calls when their updates aren't interleaved. - if (toolCallUpdate.Id is not null) + // The role and finish reason may arrive during any update, but once they've arrived, the same value should be the same for all subsequent updates. + streamedRole ??= chatCompletionUpdate.Role is global::Azure.AI.Inference.ChatRole role ? ToChatRole(role) : null; + finishReason ??= chatCompletionUpdate.FinishReason is CompletionsFinishReason reason ? ToFinishReason(reason) : null; + responseId ??= chatCompletionUpdate.Id; + createdAt ??= chatCompletionUpdate.Created; + modelId ??= chatCompletionUpdate.Model; + + // Create the response content object. + ChatResponseUpdate responseUpdate = new() + { + CreatedAt = chatCompletionUpdate.Created, + FinishReason = finishReason, + ModelId = modelId, + RawRepresentation = chatCompletionUpdate, + ResponseId = chatCompletionUpdate.Id, + Role = streamedRole, + }; + + // Transfer over content update items. + if (chatCompletionUpdate.ContentUpdate is string update) { - lastCallId = toolCallUpdate.Id; + responseUpdate.Contents.Add(new TextContent(update)); } - functionCallInfos ??= []; - if (!functionCallInfos.TryGetValue(lastCallId, out FunctionCallInfo? existing)) + // Transfer over tool call updates. + if (chatCompletionUpdate.ToolCallUpdate is { } toolCallUpdate) { - functionCallInfos[lastCallId] = existing = new(); + // TODO https://github.com/Azure/azure-sdk-for-net/issues/46830: Azure.AI.Inference + // has removed the Index property from ToolCallUpdate. 
It's now impossible via the + // exposed APIs to correctly handle multiple parallel tool calls, as the CallId is + // often null for anything other than the first update for a given call, and Index + // isn't available to correlate which updates are for which call. This is a temporary + // workaround to at least make a single tool call work and also make work multiple + // tool calls when their updates aren't interleaved. + if (toolCallUpdate.Id is not null) + { + lastCallId = toolCallUpdate.Id; + } + + functionCallInfos ??= []; + if (!functionCallInfos.TryGetValue(lastCallId, out FunctionCallInfo? existing)) + { + functionCallInfos[lastCallId] = existing = new(); + } + + existing.Name ??= toolCallUpdate.Function.Name; + if (toolCallUpdate.Function.Arguments is { } arguments) + { + _ = (existing.Arguments ??= new()).Append(arguments); + } } - existing.Name ??= toolCallUpdate.Function.Name; - if (toolCallUpdate.Function.Arguments is { } arguments) + if (chatCompletionUpdate.Usage is { } usage) { - _ = (existing.Arguments ??= new()).Append(arguments); + responseUpdate.Contents.Add(new UsageContent(new() + { + InputTokenCount = usage.PromptTokens, + OutputTokenCount = usage.CompletionTokens, + TotalTokenCount = usage.TotalTokens, + })); } - } - if (chatCompletionUpdate.Usage is { } usage) - { - responseUpdate.Contents.Add(new UsageContent(new() - { - InputTokenCount = usage.PromptTokens, - OutputTokenCount = usage.CompletionTokens, - TotalTokenCount = usage.TotalTokens, - })); + // Now yield the item. + responseUpdates.Add(responseUpdate); + yield return responseUpdate; } - // Now yield the item. - responseUpdates.Add(responseUpdate); - yield return responseUpdate; - } - - // Now that we've received all updates, combine any for function calls into a single item to yield. - if (functionCallInfos is not null) - { - var responseUpdate = new ChatResponseUpdate + // Now that we've received all updates, combine any for function calls into a single item to yield. 
+ if (functionCallInfos is not null) { - CreatedAt = createdAt, - FinishReason = finishReason, - ModelId = modelId, - ResponseId = responseId, - Role = streamedRole, - }; - - foreach (var entry in functionCallInfos) - { - FunctionCallInfo fci = entry.Value; - if (!string.IsNullOrWhiteSpace(fci.Name)) + var responseUpdate = new ChatResponseUpdate + { + CreatedAt = createdAt, + FinishReason = finishReason, + ModelId = modelId, + ResponseId = responseId, + Role = streamedRole, + }; + + foreach (var entry in functionCallInfos) { - FunctionCallContent callContent = ParseCallContentFromJsonString( - fci.Arguments?.ToString() ?? string.Empty, - entry.Key, - fci.Name!); - responseUpdate.Contents.Add(callContent); + FunctionCallInfo fci = entry.Value; + if (!string.IsNullOrWhiteSpace(fci.Name)) + { + FunctionCallContent callContent = ParseCallContentFromJsonString( + fci.Arguments?.ToString() ?? string.Empty, + entry.Key, + fci.Name!); + responseUpdate.Contents.Add(callContent); + } } - } - responseUpdates.Add(responseUpdate); - yield return responseUpdate; + responseUpdates.Add(responseUpdate); + yield return responseUpdate; + } + } + finally + { + chatMessages.AddRangeFromUpdates(responseUpdates); } - - chatMessages.Add(responseUpdates.ToChatMessage()); } /// diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs index 17bd4fa4662..c0f4b2f4636 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs @@ -41,6 +41,9 @@ public sealed class AzureAIInferenceEmbeddingGenerator : /// Either this parameter or must provide a valid model ID. /// /// The number of dimensions to generate in each embedding. + /// is . + /// is empty or composed entirely of whitespace. + /// is not positive. 
public AzureAIInferenceEmbeddingGenerator( EmbeddingsClient embeddingsClient, string? modelId = null, int? dimensions = null) { diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs index 4fdccf03be9..24d96802542 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs @@ -125,10 +125,10 @@ await chatConfiguration.ChatClient.GetResponseAsync( _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false); - string? evaluationResponseText = evaluationResponse.Message.Text?.Trim(); + string evaluationResponseText = evaluationResponse.Text.Trim(); Rating rating; - if (string.IsNullOrWhiteSpace(evaluationResponseText)) + if (string.IsNullOrEmpty(evaluationResponseText)) { rating = Rating.Inconclusive; result.AddDiagnosticToAllMetrics( @@ -145,13 +145,13 @@ await chatConfiguration.ChatClient.GetResponseAsync( { try { - string? 
repairedJson = + string repairedJson = await JsonOutputFixer.RepairJsonAsync( chatConfiguration, evaluationResponseText!, cancellationToken).ConfigureAwait(false); - if (string.IsNullOrWhiteSpace(repairedJson)) + if (string.IsNullOrEmpty(repairedJson)) { rating = Rating.Inconclusive; result.AddDiagnosticToAllMetrics( diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/SingleNumericMetricEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/SingleNumericMetricEvaluator.cs index 8b9367dbf32..f56e1e427fb 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/SingleNumericMetricEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/SingleNumericMetricEvaluator.cs @@ -71,11 +71,11 @@ await chatConfiguration.ChatClient.GetResponseAsync( _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false); - string? evaluationResponseText = evaluationResponse.Message.Text?.Trim(); + string evaluationResponseText = evaluationResponse.Text.Trim(); NumericMetric metric = result.Get(MetricName); - if (string.IsNullOrWhiteSpace(evaluationResponseText)) + if (string.IsNullOrEmpty(evaluationResponseText)) { metric.AddDiagnostic( EvaluationDiagnostic.Error( diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/Utilities/JsonOutputFixer.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/Utilities/JsonOutputFixer.cs index e6b10dedb84..b50d69bcebd 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/Utilities/JsonOutputFixer.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/Utilities/JsonOutputFixer.cs @@ -32,7 +32,7 @@ internal static ReadOnlySpan TrimMarkdownDelimiters(string json) return trimmed; } - internal static async ValueTask RepairJsonAsync( + internal static async ValueTask RepairJsonAsync( ChatConfiguration chatConfig, string json, CancellationToken cancellationToken) @@ -74,6 +74,6 @@ await chatConfig.ChatClient.GetResponseAsync( 
chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false); - return response.Message.Text?.Trim(); + return response.Text.Trim(); } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs index 2c3fb6bf709..abddf295daf 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs @@ -56,6 +56,8 @@ public OllamaChatClient(string endpoint, string? modelId = null, HttpClient? htt /// Either this parameter or must provide a valid model ID. /// /// An instance to use for HTTP operations. + /// is . + /// is empty or composed entirely of whitespace. public OllamaChatClient(Uri endpoint, string? modelId = null, HttpClient? httpClient = null) { _ = Throw.IfNull(endpoint); @@ -142,60 +144,65 @@ public async IAsyncEnumerable GetStreamingResponseAsync( .ConfigureAwait(false); List updates = []; - using var streamReader = new StreamReader(httpResponseStream); + try + { + using var streamReader = new StreamReader(httpResponseStream); #if NET - while ((await streamReader.ReadLineAsync(cancellationToken).ConfigureAwait(false)) is { } line) + while ((await streamReader.ReadLineAsync(cancellationToken).ConfigureAwait(false)) is { } line) #else - while ((await streamReader.ReadLineAsync().ConfigureAwait(false)) is { } line) + while ((await streamReader.ReadLineAsync().ConfigureAwait(false)) is { } line) #endif - { - var chunk = JsonSerializer.Deserialize(line, JsonContext.Default.OllamaChatResponse); - if (chunk is null) { - continue; - } - - string? modelId = chunk.Model ?? 
_metadata.ModelId; + var chunk = JsonSerializer.Deserialize(line, JsonContext.Default.OllamaChatResponse); + if (chunk is null) + { + continue; + } - ChatResponseUpdate update = new() - { - CreatedAt = DateTimeOffset.TryParse(chunk.CreatedAt, CultureInfo.InvariantCulture, DateTimeStyles.None, out DateTimeOffset createdAt) ? createdAt : null, - FinishReason = ToFinishReason(chunk), - ModelId = modelId, - ResponseId = chunk.CreatedAt, - Role = chunk.Message?.Role is not null ? new ChatRole(chunk.Message.Role) : null, - }; + string? modelId = chunk.Model ?? _metadata.ModelId; - if (chunk.Message is { } message) - { - if (message.ToolCalls is { Length: > 0 }) + ChatResponseUpdate update = new() + { + CreatedAt = DateTimeOffset.TryParse(chunk.CreatedAt, CultureInfo.InvariantCulture, DateTimeStyles.None, out DateTimeOffset createdAt) ? createdAt : null, + FinishReason = ToFinishReason(chunk), + ModelId = modelId, + ResponseId = chunk.CreatedAt, + Role = chunk.Message?.Role is not null ? new ChatRole(chunk.Message.Role) : null, + }; + + if (chunk.Message is { } message) { - foreach (var toolCall in message.ToolCalls) + if (message.ToolCalls is { Length: > 0 }) { - if (toolCall.Function is { } function) + foreach (var toolCall in message.ToolCalls) { - update.Contents.Add(ToFunctionCallContent(function)); + if (toolCall.Function is { } function) + { + update.Contents.Add(ToFunctionCallContent(function)); + } } } + + // Equivalent rule to the nonstreaming case + if (message.Content?.Length > 0 || update.Contents.Count == 0) + { + update.Contents.Insert(0, new TextContent(message.Content)); + } } - // Equivalent rule to the nonstreaming case - if (message.Content?.Length > 0 || update.Contents.Count == 0) + if (ParseOllamaChatResponseUsage(chunk) is { } usage) { - update.Contents.Insert(0, new TextContent(message.Content)); + update.Contents.Add(new UsageContent(usage)); } - } - if (ParseOllamaChatResponseUsage(chunk) is { } usage) - { - update.Contents.Add(new 
UsageContent(usage)); + updates.Add(update); + yield return update; } - - updates.Add(update); - yield return update; } - - chatMessages.Add(updates.ToChatMessage()); + finally + { + chatMessages.AddRangeFromUpdates(updates); + } } /// diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs index 3d869f3f278..6056753dd26 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs @@ -45,6 +45,8 @@ public OllamaEmbeddingGenerator(string endpoint, string? modelId = null, HttpCli /// Either this parameter or must provide a valid model ID. /// /// An instance to use for HTTP operations. + /// is . + /// is empty or composed entirely of whitespace. public OllamaEmbeddingGenerator(Uri endpoint, string? modelId = null, HttpClient? httpClient = null) { _ = Throw.IfNull(endpoint); diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs index 7b68ce5e15e..e144938cc70 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs @@ -71,7 +71,7 @@ public OpenAIAssistantClient(AssistantClient assistantClient, string assistantId /// public Task GetResponseAsync( IList chatMessages, ChatOptions? 
options = null, CancellationToken cancellationToken = default) => - GetStreamingResponseAsync(chatMessages, options, cancellationToken).ToChatResponseAsync(coalesceContent: true, cancellationToken); + GetStreamingResponseAsync(chatMessages, options, cancellationToken).ToChatResponseAsync(cancellationToken); /// public async IAsyncEnumerable GetStreamingResponseAsync( @@ -117,12 +117,10 @@ public async IAsyncEnumerable GetStreamingResponseAsync( switch (update) { case MessageContentUpdate mcu: - yield return new() + yield return new(mcu.Role == MessageRole.User ? ChatRole.User : ChatRole.Assistant, mcu.Text) { ChatThreadId = threadId, RawRepresentation = mcu, - Role = mcu.Role == MessageRole.User ? ChatRole.User : ChatRole.Assistant, - Text = mcu.Text, }; break; diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs index a671fb0e769..fbb2c3fa4e1 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs @@ -39,6 +39,8 @@ public sealed class OpenAIChatClient : IChatClient /// Initializes a new instance of the class for the specified . /// The underlying client. /// The model to use. + /// is . + /// is empty or composed entirely of whitespace. public OpenAIChatClient(OpenAIClient openAIClient, string modelId) { _ = Throw.IfNull(openAIClient); @@ -59,6 +61,7 @@ public OpenAIChatClient(OpenAIClient openAIClient, string modelId) /// Initializes a new instance of the class for the specified . /// The underlying client. + /// is . 
public OpenAIChatClient(ChatClient chatClient) { _ = Throw.IfNull(chatClient); @@ -111,7 +114,10 @@ public async Task GetResponseAsync( var response = await _chatClient.CompleteChatAsync(openAIChatMessages, openAIOptions, cancellationToken).ConfigureAwait(false); ChatResponse chatResponse = OpenAIModelMappers.FromOpenAIChatCompletion(response.Value, options, openAIOptions); - chatMessages.Add(chatResponse.Message); + foreach (ChatMessage message in chatResponse.Messages) + { + chatMessages.Add(message); + } return chatResponse; } diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs index 55c887ba108..7cf0be18fb0 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs @@ -38,6 +38,9 @@ public sealed class OpenAIEmbeddingGenerator : IEmbeddingGeneratorThe underlying client. /// The model to use. /// The number of dimensions to generate in each embedding. + /// is . + /// is empty or composed entirely of whitespace. + /// is not positive. public OpenAIEmbeddingGenerator( OpenAIClient openAIClient, string modelId, int? dimensions = null) { @@ -66,6 +69,8 @@ public OpenAIEmbeddingGenerator( /// Initializes a new instance of the class. /// The underlying client. /// The number of dimensions to generate in each embedding. + /// is . + /// is not positive. public OpenAIEmbeddingGenerator(EmbeddingClient embeddingClient, int? 
dimensions = null) { _ = Throw.IfNull(embeddingClient); diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs index e67fa627f3f..59727d38f00 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs @@ -31,17 +31,24 @@ public static ChatCompletion ToOpenAIChatCompletion(ChatResponse response, JsonS _ = Throw.IfNull(response); List? toolCalls = null; - foreach (AIContent content in response.Message.Contents) + ChatRole? role = null; + List allContents = []; + foreach (ChatMessage message in response.Messages) { - if (content is FunctionCallContent callRequest) + role = message.Role; + foreach (AIContent content in message.Contents) { - toolCalls ??= []; - toolCalls.Add(ChatToolCall.CreateFunctionToolCall( - callRequest.CallId, - callRequest.Name, - new(JsonSerializer.SerializeToUtf8Bytes( - callRequest.Arguments, - options.GetTypeInfo(typeof(IDictionary)))))); + allContents.Add(content); + if (content is FunctionCallContent callRequest) + { + toolCalls ??= []; + toolCalls.Add(ChatToolCall.CreateFunctionToolCall( + callRequest.CallId, + callRequest.Name, + new(JsonSerializer.SerializeToUtf8Bytes( + callRequest.Arguments, + options.GetTypeInfo(typeof(IDictionary)))))); + } } } @@ -55,9 +62,9 @@ public static ChatCompletion ToOpenAIChatCompletion(ChatResponse response, JsonS id: response.ResponseId ?? CreateCompletionId(), model: response.ModelId, createdAt: response.CreatedAt ?? DateTimeOffset.UtcNow, - role: ToOpenAIChatRole(response.Message.Role).Value, + role: ToOpenAIChatRole(role) ?? 
ChatMessageRole.Assistant, finishReason: ToOpenAIFinishReason(response.FinishReason), - content: new(ToOpenAIChatContent(response.Message.Contents)), + content: new(ToOpenAIChatContent(allContents)), toolCalls: toolCalls, refusal: response.AdditionalProperties.GetValueOrDefault(nameof(ChatCompletion.Refusal)), contentTokenLogProbabilities: response.AdditionalProperties.GetValueOrDefault>(nameof(ChatCompletion.ContentTokenLogProbabilities)), diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMappers.StreamingChatCompletion.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMappers.StreamingChatCompletion.cs index 4e21ddff022..d4858b3b70c 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMappers.StreamingChatCompletion.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMappers.StreamingChatCompletion.cs @@ -76,137 +76,141 @@ public static async IAsyncEnumerable FromOpenAIStreamingChat string? fingerprint = null; List responseUpdates = []; - - // Process each update as it arrives - await foreach (StreamingChatCompletionUpdate update in updates.WithCancellation(cancellationToken).ConfigureAwait(false)) + try { - // The role and finish reason may arrive during any update, but once they've arrived, the same value should be the same for all subsequent updates. - streamedRole ??= update.Role is ChatMessageRole role ? FromOpenAIChatRole(role) : null; - finishReason ??= update.FinishReason is OpenAI.Chat.ChatFinishReason reason ? FromOpenAIFinishReason(reason) : null; - responseId ??= update.CompletionId; - createdAt ??= update.CreatedAt; - modelId ??= update.Model; - fingerprint ??= update.SystemFingerprint; - - // Create the response content object. 
- ChatResponseUpdate responseUpdate = new() - { - ResponseId = update.CompletionId, - CreatedAt = update.CreatedAt, - FinishReason = finishReason, - ModelId = modelId, - RawRepresentation = update, - Role = streamedRole, - }; - - // Populate it with any additional metadata from the OpenAI object. - if (update.ContentTokenLogProbabilities is { Count: > 0 } contentTokenLogProbs) + // Process each update as it arrives + await foreach (StreamingChatCompletionUpdate update in updates.WithCancellation(cancellationToken).ConfigureAwait(false)) { - (responseUpdate.AdditionalProperties ??= [])[nameof(update.ContentTokenLogProbabilities)] = contentTokenLogProbs; - } + // The role and finish reason may arrive during any update, but once they've arrived, the same value should be the same for all subsequent updates. + streamedRole ??= update.Role is ChatMessageRole role ? FromOpenAIChatRole(role) : null; + finishReason ??= update.FinishReason is OpenAI.Chat.ChatFinishReason reason ? FromOpenAIFinishReason(reason) : null; + responseId ??= update.CompletionId; + createdAt ??= update.CreatedAt; + modelId ??= update.Model; + fingerprint ??= update.SystemFingerprint; + + // Create the response content object. + ChatResponseUpdate responseUpdate = new() + { + ResponseId = update.CompletionId, + CreatedAt = update.CreatedAt, + FinishReason = finishReason, + ModelId = modelId, + RawRepresentation = update, + Role = streamedRole, + }; + + // Populate it with any additional metadata from the OpenAI object. 
+ if (update.ContentTokenLogProbabilities is { Count: > 0 } contentTokenLogProbs) + { + (responseUpdate.AdditionalProperties ??= [])[nameof(update.ContentTokenLogProbabilities)] = contentTokenLogProbs; + } - if (update.RefusalTokenLogProbabilities is { Count: > 0 } refusalTokenLogProbs) - { - (responseUpdate.AdditionalProperties ??= [])[nameof(update.RefusalTokenLogProbabilities)] = refusalTokenLogProbs; - } + if (update.RefusalTokenLogProbabilities is { Count: > 0 } refusalTokenLogProbs) + { + (responseUpdate.AdditionalProperties ??= [])[nameof(update.RefusalTokenLogProbabilities)] = refusalTokenLogProbs; + } - if (fingerprint is not null) - { - (responseUpdate.AdditionalProperties ??= [])[nameof(update.SystemFingerprint)] = fingerprint; - } + if (fingerprint is not null) + { + (responseUpdate.AdditionalProperties ??= [])[nameof(update.SystemFingerprint)] = fingerprint; + } - // Transfer over content update items. - if (update.ContentUpdate is { Count: > 0 }) - { - foreach (ChatMessageContentPart contentPart in update.ContentUpdate) + // Transfer over content update items. + if (update.ContentUpdate is { Count: > 0 }) { - if (ToAIContent(contentPart) is AIContent aiContent) + foreach (ChatMessageContentPart contentPart in update.ContentUpdate) { - responseUpdate.Contents.Add(aiContent); + if (ToAIContent(contentPart) is AIContent aiContent) + { + responseUpdate.Contents.Add(aiContent); + } } } - } - // Transfer over refusal updates. - if (update.RefusalUpdate is not null) - { - _ = (refusal ??= new()).Append(update.RefusalUpdate); - } + // Transfer over refusal updates. + if (update.RefusalUpdate is not null) + { + _ = (refusal ??= new()).Append(update.RefusalUpdate); + } - // Transfer over tool call updates. - if (update.ToolCallUpdates is { Count: > 0 } toolCallUpdates) - { - foreach (StreamingChatToolCallUpdate toolCallUpdate in toolCallUpdates) + // Transfer over tool call updates. 
+ if (update.ToolCallUpdates is { Count: > 0 } toolCallUpdates) { - functionCallInfos ??= []; - if (!functionCallInfos.TryGetValue(toolCallUpdate.Index, out FunctionCallInfo? existing)) + foreach (StreamingChatToolCallUpdate toolCallUpdate in toolCallUpdates) { - functionCallInfos[toolCallUpdate.Index] = existing = new(); + functionCallInfos ??= []; + if (!functionCallInfos.TryGetValue(toolCallUpdate.Index, out FunctionCallInfo? existing)) + { + functionCallInfos[toolCallUpdate.Index] = existing = new(); + } + + existing.CallId ??= toolCallUpdate.ToolCallId; + existing.Name ??= toolCallUpdate.FunctionName; + if (toolCallUpdate.FunctionArgumentsUpdate is { } argUpdate && !argUpdate.ToMemory().IsEmpty) + { + _ = (existing.Arguments ??= new()).Append(argUpdate.ToString()); + } } + } - existing.CallId ??= toolCallUpdate.ToolCallId; - existing.Name ??= toolCallUpdate.FunctionName; - if (toolCallUpdate.FunctionArgumentsUpdate is { } argUpdate && !argUpdate.ToMemory().IsEmpty) - { - _ = (existing.Arguments ??= new()).Append(argUpdate.ToString()); - } + // Transfer over usage updates. + if (update.Usage is ChatTokenUsage tokenUsage) + { + var usageDetails = FromOpenAIUsage(tokenUsage); + responseUpdate.Contents.Add(new UsageContent(usageDetails)); } - } - // Transfer over usage updates. - if (update.Usage is ChatTokenUsage tokenUsage) - { - var usageDetails = FromOpenAIUsage(tokenUsage); - responseUpdate.Contents.Add(new UsageContent(usageDetails)); + // Now yield the item. + responseUpdates.Add(responseUpdate); + yield return responseUpdate; } - // Now yield the item. - responseUpdates.Add(responseUpdate); - yield return responseUpdate; - } - - // Now that we've received all updates, combine any for function calls into a single item to yield. - if (functionCallInfos is not null) - { - ChatResponseUpdate responseUpdate = new() + // Now that we've received all updates, combine any for function calls into a single item to yield. 
+ if (functionCallInfos is not null) { - ResponseId = responseId, - CreatedAt = createdAt, - FinishReason = finishReason, - ModelId = modelId, - Role = streamedRole, - }; - - foreach (var entry in functionCallInfos) - { - FunctionCallInfo fci = entry.Value; - if (!string.IsNullOrWhiteSpace(fci.Name)) + ChatResponseUpdate responseUpdate = new() + { + ResponseId = responseId, + CreatedAt = createdAt, + FinishReason = finishReason, + ModelId = modelId, + Role = streamedRole, + }; + + foreach (var entry in functionCallInfos) { - var callContent = ParseCallContentFromJsonString( - fci.Arguments?.ToString() ?? string.Empty, - fci.CallId!, - fci.Name!); - responseUpdate.Contents.Add(callContent); + FunctionCallInfo fci = entry.Value; + if (!string.IsNullOrWhiteSpace(fci.Name)) + { + var callContent = ParseCallContentFromJsonString( + fci.Arguments?.ToString() ?? string.Empty, + fci.CallId!, + fci.Name!); + responseUpdate.Contents.Add(callContent); + } } - } - // Refusals are about the model not following the schema for tool calls. As such, if we have any refusal, - // add it to this function calling item. - if (refusal is not null) - { - (responseUpdate.AdditionalProperties ??= [])[nameof(ChatMessageContentPart.Refusal)] = refusal.ToString(); - } + // Refusals are about the model not following the schema for tool calls. As such, if we have any refusal, + // add it to this function calling item. + if (refusal is not null) + { + (responseUpdate.AdditionalProperties ??= [])[nameof(ChatMessageContentPart.Refusal)] = refusal.ToString(); + } - // Propagate additional relevant metadata. - if (fingerprint is not null) - { - (responseUpdate.AdditionalProperties ??= [])[nameof(ChatCompletion.SystemFingerprint)] = fingerprint; - } + // Propagate additional relevant metadata. 
+ if (fingerprint is not null) + { + (responseUpdate.AdditionalProperties ??= [])[nameof(ChatCompletion.SystemFingerprint)] = fingerprint; + } - responseUpdates.Add(responseUpdate); - yield return responseUpdate; + responseUpdates.Add(responseUpdate); + yield return responseUpdate; + } + } + finally + { + chatMessages.AddRangeFromUpdates(responseUpdates); } - - chatMessages.Add(responseUpdates.ToChatMessage()); } } diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs index 8a652a71766..d74505e64f8 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs @@ -23,6 +23,7 @@ public static class OpenAIRealtimeExtensions /// it can be used with . /// /// A that can be used with . + /// is . public static ConversationFunctionTool ToConversationFunctionTool(this AIFunction aiFunction) { _ = Throw.IfNull(aiFunction); @@ -53,6 +54,9 @@ public static ConversationFunctionTool ToConversationFunctionTool(this AIFunctio /// An optional that controls JSON handling. /// An optional . /// A that represents the completion of processing, including invoking any asynchronous tools. + /// is . + /// is . + /// is . public static async Task HandleToolCallsAsync( this RealtimeConversationSession session, ConversationUpdate update, diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAISerializationHelpers.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAISerializationHelpers.cs index 659db4ed3bd..e736d110650 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAISerializationHelpers.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAISerializationHelpers.cs @@ -25,6 +25,7 @@ public static class OpenAISerializationHelpers /// The stream containing a message using the OpenAI wire format. /// A token used to cancel the operation. 
/// The deserialized list of chat messages and chat options. + /// is . public static async Task DeserializeChatCompletionRequestAsync( Stream stream, CancellationToken cancellationToken = default) { @@ -43,6 +44,8 @@ public static async Task DeserializeChatCompletionR /// The governing function call content serialization. /// A token used to cancel the serialization operation. /// A task tracking the serialization operation. + /// is . + /// is . public static async Task SerializeAsync( Stream stream, ChatResponse response, @@ -66,6 +69,8 @@ public static async Task SerializeAsync( /// The governing function call content serialization. /// A token used to cancel the serialization operation. /// A task tracking the serialization operation. + /// is . + /// is . public static Task SerializeStreamingAsync( Stream stream, IAsyncEnumerable updates, diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs index 8193e841536..8063f914764 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs @@ -117,7 +117,7 @@ await _sharedFunc(chatMessages, options, async (chatMessages, options, cancellat { Debug.Assert(_getStreamingResponseFunc is not null, "Expected non-null streaming delegate."); return _getStreamingResponseFunc!(chatMessages, options, InnerClient, cancellationToken) - .ToChatResponseAsync(coalesceContent: true, cancellationToken); + .ToChatResponseAsync(cancellationToken); } } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs index af1f41186ea..97232d6762d 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs +++ 
b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs @@ -56,7 +56,13 @@ public override async Task GetResponseAsync(IList cha if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } result) { - chatMessages.Add(result.Message); + if (options?.ChatThreadId is null) + { + foreach (ChatMessage message in result.Messages) + { + chatMessages.Add(message); + } + } } else { @@ -90,7 +96,10 @@ public override async IAsyncEnumerable GetStreamingResponseA if (chatResponse.ChatThreadId is null) { - chatMessages.Add(chatResponse.Message); + foreach (ChatMessage message in chatResponse.Messages) + { + chatMessages.Add(message); + } } } else @@ -122,7 +131,7 @@ public override async IAsyncEnumerable GetStreamingResponseA if (chatThreadId is null) { - chatMessages.Add(existingChunks.ToChatMessage()); + chatMessages.AddRangeFromUpdates(existingChunks); } } else @@ -153,6 +162,7 @@ public override async IAsyncEnumerable GetStreamingResponseA /// The cache key. /// The to monitor for cancellation requests. /// The previously cached data, if available, otherwise . + /// is . protected abstract Task ReadCacheAsync(string key, CancellationToken cancellationToken); /// @@ -162,6 +172,7 @@ public override async IAsyncEnumerable GetStreamingResponseA /// The cache key. /// The to monitor for cancellation requests. /// The previously cached data, if available, otherwise . + /// is . protected abstract Task?> ReadCacheStreamingAsync(string key, CancellationToken cancellationToken); /// @@ -172,6 +183,8 @@ public override async IAsyncEnumerable GetStreamingResponseA /// The to be stored. /// The to monitor for cancellation requests. /// A representing the completion of the operation. + /// is . + /// is . protected abstract Task WriteCacheAsync(string key, ChatResponse value, CancellationToken cancellationToken); /// @@ -182,5 +195,7 @@ public override async IAsyncEnumerable GetStreamingResponseA /// The to be stored. 
/// The to monitor for cancellation requests. /// A representing the completion of the operation. + /// is . + /// is . protected abstract Task WriteCacheStreamingAsync(string key, IReadOnlyList value, CancellationToken cancellationToken); } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs index ecd6d04914b..5ecf6403d78 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs @@ -19,6 +19,7 @@ public sealed class ChatClientBuilder /// Initializes a new instance of the class. /// The inner that represents the underlying backend. + /// is . public ChatClientBuilder(IChatClient innerClient) { _ = Throw.IfNull(innerClient); @@ -61,6 +62,7 @@ public IChatClient Build(IServiceProvider? services = null) /// Adds a factory for an intermediate chat client to the chat client pipeline. /// The client factory function. /// The updated instance. + /// is . public ChatClientBuilder Use(Func clientFactory) { _ = Throw.IfNull(clientFactory); @@ -71,6 +73,7 @@ public ChatClientBuilder Use(Func clientFactory) /// Adds a factory for an intermediate chat client to the chat client pipeline. /// The client factory function. /// The updated instance. + /// is . public ChatClientBuilder Use(Func clientFactory) { _ = Throw.IfNull(clientFactory); diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderChatClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderChatClientExtensions.cs index 87983bf2367..b4e1e7f280f 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderChatClientExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderChatClientExtensions.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. 
// The .NET Foundation licenses this file to you under the MIT license. +using System; using Microsoft.Extensions.AI; using Microsoft.Shared.Diagnostics; @@ -16,6 +17,7 @@ public static class ChatClientBuilderChatClientExtensions /// This method is equivalent to using the constructor directly, /// specifying as the inner client. /// + /// is . public static ChatClientBuilder AsBuilder(this IChatClient innerClient) { _ = Throw.IfNull(innerClient); diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderServiceCollectionExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderServiceCollectionExtensions.cs index c1be6406d1a..d1e6761f317 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderServiceCollectionExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderServiceCollectionExtensions.cs @@ -16,11 +16,18 @@ public static class ChatClientBuilderServiceCollectionExtensions /// The service lifetime for the client. Defaults to . /// A that can be used to build a pipeline around the inner client. /// The client is registered as a singleton service. + /// is . + /// is . public static ChatClientBuilder AddChatClient( this IServiceCollection serviceCollection, IChatClient innerClient, ServiceLifetime lifetime = ServiceLifetime.Singleton) - => AddChatClient(serviceCollection, _ => innerClient, lifetime); + { + _ = Throw.IfNull(serviceCollection); + _ = Throw.IfNull(innerClient); + + return AddChatClient(serviceCollection, _ => innerClient, lifetime); + } /// Registers a singleton in the . /// The to which the client should be added. @@ -28,6 +35,8 @@ public static ChatClientBuilder AddChatClient( /// The service lifetime for the client. Defaults to . /// A that can be used to build a pipeline around the inner client. /// The client is registered as a singleton service. + /// is . + /// is . 
public static ChatClientBuilder AddChatClient( this IServiceCollection serviceCollection, Func innerClientFactory, @@ -48,12 +57,19 @@ public static ChatClientBuilder AddChatClient( /// The service lifetime for the client. Defaults to . /// A that can be used to build a pipeline around the inner client. /// The client is registered as a scoped service. + /// is . + /// is . public static ChatClientBuilder AddKeyedChatClient( this IServiceCollection serviceCollection, object? serviceKey, IChatClient innerClient, ServiceLifetime lifetime = ServiceLifetime.Singleton) - => AddKeyedChatClient(serviceCollection, serviceKey, _ => innerClient, lifetime); + { + _ = Throw.IfNull(serviceCollection); + _ = Throw.IfNull(innerClient); + + return AddKeyedChatClient(serviceCollection, serviceKey, _ => innerClient, lifetime); + } /// Registers a keyed singleton in the . /// The to which the client should be added. @@ -62,6 +78,8 @@ public static ChatClientBuilder AddKeyedChatClient( /// The service lifetime for the client. Defaults to . /// A that can be used to build a pipeline around the inner client. /// The client is registered as a scoped service. + /// is . + /// is . public static ChatClientBuilder AddKeyedChatClient( this IServiceCollection serviceCollection, object? 
serviceKey, @@ -69,7 +87,6 @@ public static ChatClientBuilder AddKeyedChatClient( ServiceLifetime lifetime = ServiceLifetime.Singleton) { _ = Throw.IfNull(serviceCollection); - _ = Throw.IfNull(serviceKey); _ = Throw.IfNull(innerClientFactory); var builder = new ChatClientBuilder(innerClientFactory); diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs index 778150f1ac1..ceb8087289a 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs @@ -150,6 +150,9 @@ public static Task> GetResponseAsync( /// by the client, including any messages for roundtrips to the model as part of the implementation of this request, will be included. /// /// The type of structured output to request. + /// is . + /// is . + /// is . public static async Task> GetResponseAsync( this IChatClient chatClient, IList chatMessages, diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatResponse{T}.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatResponse{T}.cs index a02792fbcf3..2a9fca23fae 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatResponse{T}.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatResponse{T}.cs @@ -3,7 +3,6 @@ using System; using System.Buffers; -using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Text; using System.Text.Json; @@ -17,7 +16,7 @@ namespace Microsoft.Extensions.AI; /// /// Language models are not guaranteed to honor the requested schema. If the model's output is not /// parseable as the expected type, then will return . -/// You can access the underlying JSON response on the property. +/// You can access the underlying JSON response on the property. 
/// public class ChatResponse : ChatResponse { @@ -31,10 +30,11 @@ public class ChatResponse : ChatResponse /// The unstructured that is being wrapped. /// The to use when deserializing the result. public ChatResponse(ChatResponse response, JsonSerializerOptions serializerOptions) - : base(Throw.IfNull(response).Message) + : base(Throw.IfNull(response).Messages) { _serializerOptions = Throw.IfNull(serializerOptions); AdditionalProperties = response.AdditionalProperties; + ChatThreadId = response.ChatThreadId; CreatedAt = response.CreatedAt; FinishReason = response.FinishReason; ModelId = response.ModelId; @@ -114,12 +114,6 @@ public bool TryGetResult([NotNullWhen(true)] out T? result) /// internal bool IsWrappedInObject { get; set; } - private string? GetResultAsJson() - { - var content = Message.Contents.Count == 1 ? Message.Contents[0] : null; - return (content as TextContent)?.Text; - } - private T? GetResultCore(out FailureReason? failureReason) { if (_hasDeserializedResult) @@ -128,7 +122,7 @@ public bool TryGetResult([NotNullWhen(true)] out T? result) return _deserializedResult; } - var json = GetResultAsJson(); + var json = Text; if (string.IsNullOrEmpty(json)) { failureReason = FailureReason.ResultDidNotContainJson; diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs index ea990d09a85..d76b2ba1a2e 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs @@ -25,6 +25,8 @@ public static class ConfigureOptionsChatClientBuilderExtensions /// of the caller-supplied instance if one was supplied. /// /// The . + /// is . + /// is . 
public static ChatClientBuilder ConfigureOptions( this ChatClientBuilder builder, Action configure) { diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClientBuilderExtensions.cs index 6396459c09c..6a9474b751d 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClientBuilderExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClientBuilderExtensions.cs @@ -22,6 +22,7 @@ public static class DistributedCachingChatClientBuilderExtensions /// /// An optional callback that can be used to configure the instance. /// The provided as . + /// is . public static ChatClientBuilder UseDistributedCache(this ChatClientBuilder builder, IDistributedCache? storage = null, Action? configure = null) { _ = Throw.IfNull(builder); diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index ab63bfc5b48..7c55a157718 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -210,7 +210,7 @@ public override async Task GetResponseAsync(IList cha bool requiresFunctionInvocation = options?.Tools is { Count: > 0 } && (!MaximumIterationsPerRequest.HasValue || iteration < MaximumIterationsPerRequest.GetValueOrDefault()) && - CopyFunctionCalls(response.Message.Contents, ref functionCallContents); + CopyFunctionCalls(response.Messages, ref functionCallContents); // In the common case where we make a request and there's no function calling work required, // fast path out by just returning the original response. @@ -364,6 +364,20 @@ public override async IAsyncEnumerable GetStreamingResponseA } } + /// Copies any from to . 
+ private static bool CopyFunctionCalls( + IList messages, [NotNullWhen(true)] ref List? functionCalls) + { + bool any = false; + int count = messages.Count; + for (int i = 0; i < count; i++) + { + any |= CopyFunctionCalls(messages[i].Contents, ref functionCalls); + } + + return any; + } + /// Copies any from to . private static bool CopyFunctionCalls( IList content, [NotNullWhen(true)] ref List? functionCalls) @@ -567,6 +581,7 @@ internal enum ContinueMode /// The chat to which to add the one or more response messages. /// Information about the function call invocations and results. /// A list of all chat messages added to . + /// is . protected virtual IList AddResponseMessages(IList chatMessages, ReadOnlySpan results) { _ = Throw.IfNull(chatMessages); @@ -617,6 +632,7 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul /// /// The to monitor for cancellation requests. The default is . /// The result of the function invocation, or if the function invocation returned . + /// is . protected virtual async Task InvokeFunctionAsync(FunctionInvocationContext context, CancellationToken cancellationToken) { _ = Throw.IfNull(context); diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs index 0d2d6f8bc9b..f2a60718ea9 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs @@ -21,6 +21,7 @@ public static class FunctionInvokingChatClientBuilderExtensions /// An optional to use to create a logger for logging function invocations. /// An optional callback that can be used to configure the instance. /// The supplied . + /// is . 
public static ChatClientBuilder UseFunctionInvocation( this ChatClientBuilder builder, ILoggerFactory? loggerFactory = null, diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClientBuilderExtensions.cs index 6ae8d176e5e..d34716ed886 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClientBuilderExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClientBuilderExtensions.cs @@ -20,6 +20,7 @@ public static class LoggingChatClientBuilderExtensions /// /// An optional callback that can be used to configure the instance. /// The . + /// is . public static ChatClientBuilder UseLogging( this ChatClientBuilder builder, ILoggerFactory? loggerFactory = null, diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs index 9809872f1d0..3a98afadacf 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs @@ -446,7 +446,7 @@ private void LogChatMessages(IEnumerable messages) if (message.Role == ChatRole.Assistant) { Log(new(1, OpenTelemetryConsts.GenAI.Assistant.Message), - JsonSerializer.Serialize(CreateAssistantEvent(message), OtelContext.Default.AssistantEvent)); + JsonSerializer.Serialize(CreateAssistantEvent(message.Contents), OtelContext.Default.AssistantEvent)); } else if (message.Role == ChatRole.Tool) { @@ -468,7 +468,7 @@ private void LogChatMessages(IEnumerable messages) JsonSerializer.Serialize(new() { Role = message.Role != ChatRole.System && message.Role != ChatRole.User && !string.IsNullOrWhiteSpace(message.Role.Value) ? 
message.Role.Value : null, - Content = GetMessageContent(message), + Content = GetMessageContent(message.Contents), }, OtelContext.Default.SystemOrUserEvent)); } } @@ -486,7 +486,7 @@ private void LogChatResponse(ChatResponse response) { FinishReason = response.FinishReason?.Value ?? "error", Index = 0, - Message = CreateAssistantEvent(response.Message), + Message = CreateAssistantEvent(response.Messages is { Count: 1 } ? response.Messages[0].Contents : response.Messages.SelectMany(m => m.Contents)), }, OtelContext.Default.ChoiceEvent)); } @@ -505,9 +505,9 @@ private void Log(EventId id, [StringSyntax(StringSyntaxAttribute.Json)] string e _logger.Log(EventLogLevel, id, tags, null, (_, __) => eventBodyJson); } - private AssistantEvent CreateAssistantEvent(ChatMessage message) + private AssistantEvent CreateAssistantEvent(IEnumerable contents) { - var toolCalls = message.Contents.OfType().Select(fc => new ToolCall + var toolCalls = contents.OfType().Select(fc => new ToolCall { Id = fc.CallId, Function = new() @@ -521,16 +521,16 @@ private AssistantEvent CreateAssistantEvent(ChatMessage message) return new() { - Content = GetMessageContent(message), + Content = GetMessageContent(contents), ToolCalls = toolCalls.Length > 0 ? toolCalls : null, }; } - private string? GetMessageContent(ChatMessage message) + private string? 
GetMessageContent(IEnumerable contents) { if (EnableSensitiveData) { - string content = string.Concat(message.Contents.OfType()); + string content = string.Concat(contents.OfType()); if (content.Length > 0) { return content; diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGeneratorBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGeneratorBuilderExtensions.cs index 51f1804c2df..73867e4b2f7 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGeneratorBuilderExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGeneratorBuilderExtensions.cs @@ -28,6 +28,8 @@ public static class ConfigureOptionsEmbeddingGeneratorBuilderExtensions /// of the caller-supplied instance if one was supplied. /// /// The . + /// is . + /// is . public static EmbeddingGeneratorBuilder ConfigureOptions( this EmbeddingGeneratorBuilder builder, Action configure) diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs index bd801911257..d6c20ffb2f5 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs @@ -30,6 +30,7 @@ public class DistributedCachingEmbeddingGenerator : CachingE /// Initializes a new instance of the class. /// The underlying . /// A instance that will be used as the backing store for the cache. + /// is . public DistributedCachingEmbeddingGenerator(IEmbeddingGenerator innerGenerator, IDistributedCache storage) : base(innerGenerator) { @@ -39,6 +40,7 @@ public DistributedCachingEmbeddingGenerator(IEmbeddingGeneratorGets or sets JSON serialization options to use when serializing cache data. + /// is . 
public JsonSerializerOptions JsonSerializerOptions { get => _jsonSerializerOptions; diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGeneratorBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGeneratorBuilderExtensions.cs index 7d42407d930..c2bbdbd1ded 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGeneratorBuilderExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGeneratorBuilderExtensions.cs @@ -25,6 +25,7 @@ public static class DistributedCachingEmbeddingGeneratorBuilderExtensions /// /// An optional callback that can be used to configure the instance. /// The provided as . + /// is . public static EmbeddingGeneratorBuilder UseDistributedCache( this EmbeddingGeneratorBuilder builder, IDistributedCache? storage = null, diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs index 1baa64d2a20..e5cd800800d 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs @@ -22,6 +22,7 @@ public sealed class EmbeddingGeneratorBuilder /// Initializes a new instance of the class. /// The inner that represents the underlying backend. + /// is . public EmbeddingGeneratorBuilder(IEmbeddingGenerator innerGenerator) { _ = Throw.IfNull(innerGenerator); @@ -66,6 +67,7 @@ public IEmbeddingGenerator Build(IServiceProvider? services /// Adds a factory for an intermediate embedding generator to the embedding generator pipeline. /// The generator factory function. /// The updated instance. + /// is . 
public EmbeddingGeneratorBuilder Use(Func, IEmbeddingGenerator> generatorFactory) { _ = Throw.IfNull(generatorFactory); @@ -76,6 +78,7 @@ public EmbeddingGeneratorBuilder Use(FuncAdds a factory for an intermediate embedding generator to the embedding generator pipeline. /// The generator factory function. /// The updated instance. + /// is . public EmbeddingGeneratorBuilder Use( Func, IServiceProvider, IEmbeddingGenerator> generatorFactory) { diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions.cs index 73784f56916..84d4815cb23 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using Microsoft.Extensions.AI; using Microsoft.Shared.Diagnostics; @@ -22,6 +23,7 @@ public static class EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions /// This method is equivalent to using the /// constructor directly, specifying as the inner generator. /// + /// is . 
public static EmbeddingGeneratorBuilder AsBuilder( this IEmbeddingGenerator innerGenerator) where TEmbedding : Embedding diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs index 2000e71cf03..b84e8ac6e60 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs @@ -18,12 +18,19 @@ public static class EmbeddingGeneratorBuilderServiceCollectionExtensions /// The service lifetime for the client. Defaults to . /// An that can be used to build a pipeline around the inner generator. /// The generator is registered as a singleton service. + /// is . + /// is . public static EmbeddingGeneratorBuilder AddEmbeddingGenerator( this IServiceCollection serviceCollection, IEmbeddingGenerator innerGenerator, ServiceLifetime lifetime = ServiceLifetime.Singleton) where TEmbedding : Embedding - => AddEmbeddingGenerator(serviceCollection, _ => innerGenerator, lifetime); + { + _ = Throw.IfNull(serviceCollection); + _ = Throw.IfNull(innerGenerator); + + return AddEmbeddingGenerator(serviceCollection, _ => innerGenerator, lifetime); + } /// Registers a singleton embedding generator in the . /// The type from which embeddings will be generated. @@ -33,6 +40,8 @@ public static EmbeddingGeneratorBuilder AddEmbeddingGenerato /// The service lifetime for the client. Defaults to . /// An that can be used to build a pipeline around the inner generator. /// The generator is registered as a singleton service. + /// is . + /// is . 
public static EmbeddingGeneratorBuilder AddEmbeddingGenerator( this IServiceCollection serviceCollection, Func> innerGeneratorFactory, @@ -56,13 +65,20 @@ public static EmbeddingGeneratorBuilder AddEmbeddingGenerato /// The service lifetime for the client. Defaults to . /// An that can be used to build a pipeline around the inner generator. /// The generator is registered as a singleton service. + /// is . + /// is . public static EmbeddingGeneratorBuilder AddKeyedEmbeddingGenerator( this IServiceCollection serviceCollection, object? serviceKey, IEmbeddingGenerator innerGenerator, ServiceLifetime lifetime = ServiceLifetime.Singleton) where TEmbedding : Embedding - => AddKeyedEmbeddingGenerator(serviceCollection, serviceKey, _ => innerGenerator, lifetime); + { + _ = Throw.IfNull(serviceCollection); + _ = Throw.IfNull(innerGenerator); + + return AddKeyedEmbeddingGenerator(serviceCollection, serviceKey, _ => innerGenerator, lifetime); + } /// Registers a keyed singleton embedding generator in the . /// The type from which embeddings will be generated. @@ -73,6 +89,8 @@ public static EmbeddingGeneratorBuilder AddKeyedEmbeddingGen /// The service lifetime for the client. Defaults to . /// An that can be used to build a pipeline around the inner generator. /// The generator is registered as a singleton service. + /// is . + /// is . public static EmbeddingGeneratorBuilder AddKeyedEmbeddingGenerator( this IServiceCollection serviceCollection, object? 
serviceKey, @@ -81,7 +99,6 @@ public static EmbeddingGeneratorBuilder AddKeyedEmbeddingGen where TEmbedding : Embedding { _ = Throw.IfNull(serviceCollection); - _ = Throw.IfNull(serviceKey); _ = Throw.IfNull(innerGeneratorFactory); var builder = new EmbeddingGeneratorBuilder(innerGeneratorFactory); diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGeneratorBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGeneratorBuilderExtensions.cs index 52fb7dd1ca3..eb472fb1e0e 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGeneratorBuilderExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGeneratorBuilderExtensions.cs @@ -22,6 +22,7 @@ public static class LoggingEmbeddingGeneratorBuilderExtensions /// /// An optional callback that can be used to configure the instance. /// The . + /// is . public static EmbeddingGeneratorBuilder UseLogging( this EmbeddingGeneratorBuilder builder, ILoggerFactory? loggerFactory = null, diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs index d8be8e9f128..4d16ac6ae6b 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs @@ -39,6 +39,7 @@ public static partial class AIFunctionFactory /// it is round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. /// /// + /// is . public static AIFunction Create(Delegate method, AIFunctionFactoryOptions? options) { _ = Throw.IfNull(method); @@ -61,6 +62,7 @@ public static AIFunction Create(Delegate method, AIFunctionFactoryOptions? optio /// round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. /// /// + /// is . public static AIFunction Create(Delegate method, string? 
name = null, string? description = null, JsonSerializerOptions? serializerOptions = null) { _ = Throw.IfNull(method); @@ -98,6 +100,7 @@ public static AIFunction Create(Delegate method, string? name = null, string? de /// it is round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. /// /// + /// is . public static AIFunction Create(MethodInfo method, object? target, AIFunctionFactoryOptions? options) { _ = Throw.IfNull(method); @@ -126,6 +129,7 @@ public static AIFunction Create(MethodInfo method, object? target, AIFunctionFac /// round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. /// /// + /// is . public static AIFunction Create(MethodInfo method, object? target, string? name = null, string? description = null, JsonSerializerOptions? serializerOptions = null) { _ = Throw.IfNull(method); diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs index 9671d2bc602..4bc886ae580 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs @@ -143,7 +143,7 @@ public async Task GetStreamingResponseAsync_CreatesTextMessageAsync() Assert.Equal(cts.Token, cancellationToken); - return YieldAsync([new ChatResponseUpdate { Text = "world" }]); + return YieldAsync([new ChatResponseUpdate(ChatRole.Assistant, "world")]); }, }; diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs index b67fb1de4a5..08e46a52a86 100644 --- 
a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Text.Json; using Xunit; +using static System.Net.Mime.MediaTypeNames; namespace Microsoft.Extensions.AI; @@ -18,7 +19,7 @@ public void Constructor_Parameterless_PropsDefaulted() Assert.Null(message.AuthorName); Assert.Empty(message.Contents); Assert.Equal(ChatRole.User, message.Role); - Assert.Null(message.Text); + Assert.Empty(message.Text); Assert.NotNull(message.Contents); Assert.Same(message.Contents, message.Contents); Assert.Empty(message.Contents); @@ -55,9 +56,23 @@ public void Constructor_RoleString_PropsRoundtrip(string? text) } [Fact] - public void Constructor_RoleList_InvalidArgs_Throws() + public void Constructor_NullArgs_Valid() { - Assert.Throws("contents", () => new ChatMessage(ChatRole.User, (IList)null!)); + ChatMessage message; + + message = new(); + Assert.Empty(message.Text); + Assert.Empty(message.Contents); + + message = new(ChatRole.User, (string?)null); + Assert.Empty(message.Text); + Assert.Empty(message.Contents); + + message = new(ChatRole.User, (IList?)null); + Assert.Empty(message.Text); + Assert.Empty(message.Contents); + + Assert.Throws(() => new ChatMessage(ChatRole.User, Array.Empty())); } [Theory] @@ -80,7 +95,7 @@ public void Constructor_RoleList_PropsRoundtrip(int messageCount) if (messageCount == 0) { Assert.Empty(message.Contents); - Assert.Null(message.Text); + Assert.Empty(message.Text); } else { @@ -91,7 +106,7 @@ public void Constructor_RoleList_PropsRoundtrip(int messageCount) Assert.Equal($"text-{i}", tc.Text); } - Assert.Equal("text-0", message.Text); + Assert.Equal(string.Concat(Enumerable.Range(0, messageCount).Select(i => $"text-{i}")), message.Text); Assert.Equal(string.Concat(Enumerable.Range(0, messageCount).Select(i => $"text-{i}")), message.ToString()); } @@ -120,7 +135,7 @@ 
public void AuthorName_InvalidArg_UsesNull(string? authorName) } [Fact] - public void Text_GetSet_UsesFirstTextContent() + public void Text_ConcatsAllTextContent() { ChatMessage message = new(ChatRole.User, [ @@ -134,57 +149,15 @@ public void Text_GetSet_UsesFirstTextContent() TextContent textContent = Assert.IsType(message.Contents[3]); Assert.Equal("text-1", textContent.Text); - Assert.Equal("text-1", message.Text); + Assert.Equal("text-1text-2", message.Text); Assert.Equal("text-1text-2", message.ToString()); - message.Text = "text-3"; - Assert.Equal("text-3", message.Text); - Assert.Equal("text-3", message.Text); - Assert.Same(textContent, message.Contents[3]); + ((TextContent)message.Contents[3]).Text = "text-3"; + Assert.Equal("text-3", textContent.Text); + Assert.Equal("text-3text-2", message.Text); Assert.Equal("text-3text-2", message.ToString()); } - [Fact] - public void Text_Set_AddsTextMessageToEmptyList() - { - ChatMessage message = new(ChatRole.User, []); - Assert.Empty(message.Contents); - - message.Text = "text-1"; - Assert.Equal("text-1", message.Text); - - Assert.Single(message.Contents); - TextContent textContent = Assert.IsType(message.Contents[0]); - Assert.Equal("text-1", textContent.Text); - } - - [Fact] - public void Text_Set_AddsTextMessageToListWithNoText() - { - ChatMessage message = new(ChatRole.User, - [ - new DataContent("http://localhost/audio"), - new DataContent("http://localhost/image"), - new FunctionCallContent("callId1", "fc1"), - ]); - Assert.Equal(3, message.Contents.Count); - - message.Text = "text-1"; - Assert.Equal("text-1", message.Text); - Assert.Equal(4, message.Contents.Count); - - message.Text = "text-2"; - Assert.Equal("text-2", message.Text); - Assert.Equal(4, message.Contents.Count); - - message.Contents.RemoveAt(3); - Assert.Equal(3, message.Contents.Count); - - message.Text = "text-3"; - Assert.Equal("text-3", message.Text); - Assert.Equal(4, message.Contents.Count); - } - [Fact] public void 
Contents_InitializesToList() { @@ -282,12 +255,13 @@ public void ItCanBeSerializeAndDeserialized() ]; // Act - var chatMessageJson = JsonSerializer.Serialize(new ChatMessage(ChatRole.User, contents: items) + var chatMessage = new ChatMessage(ChatRole.User, contents: items) { - Text = "content-1-override", // Override the content of the first text content item that has the "content-1" content AuthorName = "Fred", AdditionalProperties = new() { ["message-metadata-key-1"] = "message-metadata-value-1" }, - }, TestJsonSerializerContext.Default.Options); + }; + ((TextContent)chatMessage.Contents[0]).Text = "content-1-override"; // Override the content of the first text content item that has the "content-1" content + var chatMessageJson = JsonSerializer.Serialize(chatMessage, TestJsonSerializerContext.Default.Options); var deserializedMessage = JsonSerializer.Deserialize(chatMessageJson, TestJsonSerializerContext.Default.Options)!; diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs index 0d8e4f8bb3b..1507d591a77 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text.Json; using Xunit; @@ -17,22 +18,6 @@ public void Constructor_InvalidArgs_Throws() Assert.Throws("messages", () => new ChatResponse((List)null!)); } - [Fact] - public void Constructor_Message_Roundtrips() - { - ChatResponse response = new(); - Assert.NotNull(response.Message); - Assert.Same(response.Message, response.Message); - - ChatMessage message = new(); - response = new(message); - Assert.Same(message, response.Message); - - message = new(); - response.Message = message; - Assert.Same(message, 
response.Message); - } - [Fact] public void Constructor_Messages_Roundtrips() { @@ -40,49 +25,16 @@ public void Constructor_Messages_Roundtrips() Assert.NotNull(response.Messages); Assert.Same(response.Messages, response.Messages); - List messages = new(); + List messages = []; response = new(messages); Assert.Same(messages, response.Messages); - messages = new(); + messages = []; + Assert.NotSame(messages, response.Messages); response.Messages = messages; Assert.Same(messages, response.Messages); } - [Fact] - public void Message_LastMessageOfMessages() - { - ChatResponse response = new(); - - Assert.Empty(response.Messages); - Assert.NotNull(response.Message); - Assert.NotEmpty(response.Messages); - - for (int i = 1; i < 3; i++) - { - Assert.Same(response.Messages[response.Messages.Count - 1], response.Message); - response.Messages.Add(new ChatMessage(ChatRole.User, $"Message {i}")); - } - } - - [Fact] - public void Message_SetterSetsLast() - { - ChatResponse response = new(); - - Assert.Empty(response.Messages); - ChatMessage message = new(); - response.Message = message; - Assert.NotEmpty(response.Messages); - Assert.Same(message, response.Messages[0]); - - message = new(); - response.Message = message; - Assert.Single(response.Messages); - Assert.Same(message, response.Messages[0]); - Assert.Same(message, response.Message); - } - [Fact] public void Properties_Roundtrip() { @@ -139,8 +91,8 @@ public void JsonSerialization_Roundtrips() ChatResponse? 
result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatResponse); Assert.NotNull(result); - Assert.Equal(ChatRole.Assistant, result.Message.Role); - Assert.Equal("the message", result.Message.Text); + Assert.Equal(ChatRole.Assistant, result.Messages.Single().Role); + Assert.Equal("the message", result.Messages.Single().Text); Assert.Equal("id", result.ResponseId); Assert.Equal("modelId", result.ModelId); @@ -156,11 +108,11 @@ public void JsonSerialization_Roundtrips() } [Fact] - public void ToString_OutputsChatMessageToString() + public void ToString_OutputsText() { ChatResponse response = new(new ChatMessage(ChatRole.Assistant, $"This is a test.{Environment.NewLine}It's multiple lines.")); - Assert.Equal(response.Message.ToString(), response.ToString()); + Assert.Equal(response.Text, response.ToString()); } [Fact] diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs index 32a0ddf3007..454c3c3cad3 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs @@ -21,39 +21,24 @@ public void InvalidArgs_Throws() Assert.Throws("updates", () => ((List)null!).ToChatResponse()); } - public static IEnumerable ToChatResponse_SuccessfullyCreatesResponse_MemberData() - { - foreach (bool useAsync in new[] { false, true }) - { - foreach (bool? coalesceContent in new bool?[] { null, false, true }) - { - yield return new object?[] { useAsync, coalesceContent }; - } - } - } - [Theory] - [MemberData(nameof(ToChatResponse_SuccessfullyCreatesResponse_MemberData))] - public async Task ToChatResponse_SuccessfullyCreatesResponse(bool useAsync, bool? 
coalesceContent) + [InlineData(false)] + [InlineData(true)] + public async Task ToChatResponse_SuccessfullyCreatesResponse(bool useAsync) { ChatResponseUpdate[] updates = [ - new() { Text = "Hello", ResponseId = "12345", CreatedAt = new DateTimeOffset(1, 2, 3, 4, 5, 6, TimeSpan.Zero), ModelId = "model123" }, - new() { Text = ", ", AuthorName = "Someone", Role = new ChatRole("human"), AdditionalProperties = new() { ["a"] = "b" } }, - new() { Text = "world!", CreatedAt = new DateTimeOffset(2, 2, 3, 4, 5, 6, TimeSpan.Zero), ChatThreadId = "123", AdditionalProperties = new() { ["c"] = "d" } }, + new(ChatRole.Assistant, "Hello") { ResponseId = "12345", CreatedAt = new DateTimeOffset(1, 2, 3, 4, 5, 6, TimeSpan.Zero), ModelId = "model123" }, + new(new("human"), ", ") { AuthorName = "Someone", AdditionalProperties = new() { ["a"] = "b" } }, + new(null, "world!") { CreatedAt = new DateTimeOffset(2, 2, 3, 4, 5, 6, TimeSpan.Zero), ChatThreadId = "123", AdditionalProperties = new() { ["c"] = "d" } }, new() { Contents = [new UsageContent(new() { InputTokenCount = 1, OutputTokenCount = 2 })] }, new() { Contents = [new UsageContent(new() { InputTokenCount = 4, OutputTokenCount = 5 })] }, ]; - ChatResponse response = (coalesceContent is bool, useAsync) switch - { - (false, false) => updates.ToChatResponse(), - (false, true) => await YieldAsync(updates).ToChatResponseAsync(), - - (true, false) => updates.ToChatResponse(coalesceContent.GetValueOrDefault()), - (true, true) => await YieldAsync(updates).ToChatResponseAsync(coalesceContent.GetValueOrDefault()), - }; + ChatResponse response = useAsync ? 
+ updates.ToChatResponse() : + await YieldAsync(updates).ToChatResponseAsync(); Assert.NotNull(response); Assert.NotNull(response.Usage); @@ -66,7 +51,7 @@ public async Task ToChatResponse_SuccessfullyCreatesResponse(bool useAsync, bool Assert.Equal("123", response.ChatThreadId); - ChatMessage message = response.Message; + ChatMessage message = response.Messages.Last(); Assert.Equal(new ChatRole("human"), message.Role); Assert.Equal("Someone", message.AuthorName); Assert.Null(message.AdditionalProperties); @@ -76,16 +61,7 @@ public async Task ToChatResponse_SuccessfullyCreatesResponse(bool useAsync, bool Assert.Equal("b", response.AdditionalProperties["a"]); Assert.Equal("d", response.AdditionalProperties["c"]); - if (coalesceContent is null or true) - { - Assert.Equal("Hello, world!", response.Message.Text); - } - else - { - Assert.Equal("Hello", response.Message.Contents[0].ToString()); - Assert.Equal(", ", response.Message.Contents[1].ToString()); - Assert.Equal("world!", response.Message.Contents[2].ToString()); - } + Assert.Equal("Hello, world!", response.Text); } public static IEnumerable ToChatResponse_Coalescing_VariousSequenceAndGapLengths_MemberData() @@ -127,7 +103,7 @@ public async Task ToChatResponse_Coalescing_VariousSequenceAndGapLengths(bool us for (int i = 0; i < sequenceLength; i++) { string text = $"{(char)('A' + sequenceNum)}{i}"; - updates.Add(new() { Text = text }); + updates.Add(new(null, text)); sb.Append(text); } @@ -155,7 +131,7 @@ void AddGap() ChatResponse response = useAsync ? await YieldAsync(updates).ToChatResponseAsync() : updates.ToChatResponse(); Assert.NotNull(response); - ChatMessage message = response.Message; + ChatMessage message = response.Messages.Single(); Assert.NotNull(message); Assert.Equal(expected.Count + (gapLength * ((numSequences - 1) + (gapBeginningEnd ? 
2 : 0))), message.Contents.Count); @@ -173,8 +149,8 @@ public async Task ToChatResponse_UsageContentExtractedFromContents() { ChatResponseUpdate[] updates = { - new() { Text = "Hello, " }, - new() { Text = "world!" }, + new(null, "Hello, "), + new(null, "world!"), new() { Contents = [new UsageContent(new() { TotalTokenCount = 42 })] }, }; @@ -185,7 +161,7 @@ public async Task ToChatResponse_UsageContentExtractedFromContents() Assert.NotNull(response.Usage); Assert.Equal(42, response.Usage.TotalTokenCount); - Assert.Equal("Hello, world!", Assert.IsType(Assert.Single(response.Message.Contents)).Text); + Assert.Equal("Hello, world!", Assert.IsType(Assert.Single(Assert.Single(response.Messages).Contents)).Text); } private static async IAsyncEnumerable YieldAsync(IEnumerable updates) diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs index 4bb9e5ae0b3..7e5ff6b1e84 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs @@ -16,7 +16,7 @@ public void Constructor_PropsDefaulted() ChatResponseUpdate update = new(); Assert.Null(update.AuthorName); Assert.Null(update.Role); - Assert.Null(update.Text); + Assert.Empty(update.Text); Assert.Empty(update.Contents); Assert.Null(update.RawRepresentation); Assert.Null(update.AdditionalProperties); @@ -51,9 +51,7 @@ public void Properties_Roundtrip() Assert.NotNull(update.Contents); Assert.Empty(update.Contents); - Assert.Null(update.Text); - update.Text = "text"; - Assert.Equal("text", update.Text); + Assert.Empty(update.Text); Assert.Null(update.RawRepresentation); object raw = new(); @@ -79,7 +77,7 @@ public void Properties_Roundtrip() } [Fact] - public void Text_GetSet_UsesFirstTextContent() + public void 
Text_Get_UsesAllTextContent() { ChatResponseUpdate update = new() { @@ -97,63 +95,15 @@ public void Text_GetSet_UsesFirstTextContent() TextContent textContent = Assert.IsType(update.Contents[3]); Assert.Equal("text-1", textContent.Text); - Assert.Equal("text-1", update.Text); + Assert.Equal("text-1text-2", update.Text); Assert.Equal("text-1text-2", update.ToString()); - update.Text = "text-3"; - Assert.Equal("text-3", update.Text); - Assert.Equal("text-3", update.Text); + ((TextContent)update.Contents[3]).Text = "text-3"; + Assert.Equal("text-3text-2", update.Text); Assert.Same(textContent, update.Contents[3]); Assert.Equal("text-3text-2", update.ToString()); } - [Fact] - public void Text_Set_AddsTextMessageToEmptyList() - { - ChatResponseUpdate update = new() - { - Role = ChatRole.User, - }; - Assert.Empty(update.Contents); - - update.Text = "text-1"; - Assert.Equal("text-1", update.Text); - - Assert.Single(update.Contents); - TextContent textContent = Assert.IsType(update.Contents[0]); - Assert.Equal("text-1", textContent.Text); - } - - [Fact] - public void Text_Set_AddsTextMessageToListWithNoText() - { - ChatResponseUpdate update = new() - { - Contents = - [ - new DataContent("http://localhost/audio"), - new DataContent("http://localhost/image"), - new FunctionCallContent("callId1", "fc1"), - ] - }; - Assert.Equal(3, update.Contents.Count); - - update.Text = "text-1"; - Assert.Equal("text-1", update.Text); - Assert.Equal(4, update.Contents.Count); - - update.Text = "text-2"; - Assert.Equal("text-2", update.Text); - Assert.Equal(4, update.Contents.Count); - - update.Contents.RemoveAt(3); - Assert.Equal(3, update.Contents.Count); - - update.Text = "text-3"; - Assert.Equal("text-3", update.Text); - Assert.Equal(4, update.Contents.Count); - } - [Fact] public void JsonSerialization_Roundtrips() { diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs 
b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs index 5d9170c77e3..bab36d7f91a 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs @@ -58,8 +58,8 @@ public async Task ChatStreamingAsyncDefaultsToInnerClientAsync() var expectedCancellationToken = CancellationToken.None; ChatResponseUpdate[] expectedResults = [ - new() { Role = ChatRole.User, Text = "Message 1" }, - new() { Role = ChatRole.User, Text = "Message 2" } + new(ChatRole.User, "Message 1"), + new(ChatRole.User, "Message 2") ]; using var inner = new TestChatClient diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs index 7d2d0a6c9ab..a89beabc97e 100644 --- a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs @@ -164,9 +164,9 @@ [new ChatMessage(ChatRole.User, "hello".Select(c => (AIContent)new TextContent(c Assert.NotNull(response); Assert.Equal("chatcmpl-ADx3PvAnCwJg0woha4pYsBTi3ZpOI", response.ResponseId); - Assert.Equal("Hello! How can I assist you today?", response.Message.Text); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("Hello! 
How can I assist you today?", response.Text); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_888_631), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); @@ -551,9 +551,9 @@ public async Task MultipleMessages_NonStreaming() Assert.NotNull(response); Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.ResponseId); - Assert.Equal("I’m doing well, thank you! What’s on your mind today?", response.Message.Text); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("I’m doing well, thank you! What’s on your mind today?", response.Text); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); @@ -685,9 +685,9 @@ public async Task NullAssistantText_ContentEmpty_NonStreaming() Assert.NotNull(response); Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.ResponseId); - Assert.Equal("Hello.", response.Message.Text); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("Hello.", response.Text); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); @@ -801,9 +801,9 @@ public async Task FunctionCallContent_NonStreaming(ChatToolMode mode) }); Assert.NotNull(response); - 
Assert.Null(response.Message.Text); + Assert.Empty(response.Text); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_702), response.CreatedAt); Assert.Equal(ChatFinishReason.ToolCalls, response.FinishReason); Assert.NotNull(response.Usage); @@ -811,8 +811,8 @@ public async Task FunctionCallContent_NonStreaming(ChatToolMode mode) Assert.Equal(16, response.Usage.OutputTokenCount); Assert.Equal(77, response.Usage.TotalTokenCount); - Assert.Single(response.Message.Contents); - FunctionCallContent fcc = Assert.IsType(response.Message.Contents[0]); + Assert.Single(response.Messages.Single().Contents); + FunctionCallContent fcc = Assert.IsType(response.Messages.Single().Contents[0]); Assert.Equal("GetPersonAge", fcc.Name); AssertExtensions.EqualFunctionCallParameters(new Dictionary { ["personName"] = "Alice" }, fcc.Arguments); } diff --git a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/AdditionalContextTests.cs b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/AdditionalContextTests.cs index f5ffd922816..cbc78ef6642 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/AdditionalContextTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/AdditionalContextTests.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; +using System.Linq; using System.Threading.Tasks; using FluentAssertions; using FluentAssertions.Execution; @@ -59,7 +60,7 @@ await _reportingConfiguration.CreateScenarioRunAsync( messages.Add(promptMessage); ChatResponse response = await chatClient.GetResponseAsync(messages, _chatOptions); - ChatMessage responseMessage = response.Message; + ChatMessage responseMessage = response.Messages.Single(); Assert.NotNull(responseMessage.Text); 
EvaluationResult result = @@ -94,7 +95,7 @@ await _reportingConfiguration.CreateScenarioRunAsync( messages.Add(promptMessage); ChatResponse response = await chatClient.GetResponseAsync(messages, _chatOptions); - ChatMessage responseMessage = response.Message; + ChatMessage responseMessage = response.Messages.Single(); Assert.NotNull(responseMessage.Text); var baselineResponseForEquivalenceEvaluator = diff --git a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/EndToEndTests.cs b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/EndToEndTests.cs index 4062a0c4fda..dbfdebc529c 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/EndToEndTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/EndToEndTests.cs @@ -71,7 +71,7 @@ await _reportingConfiguration.CreateScenarioRunAsync( messages.Add(promptMessage); ChatResponse response = await chatClient.GetResponseAsync(messages, _chatOptions); - ChatMessage responseMessage = response.Message; + ChatMessage responseMessage = response.Messages.Single(); Assert.NotNull(responseMessage.Text); EvaluationResult result = await scenarioRun.EvaluateAsync(promptMessage, responseMessage); @@ -122,7 +122,7 @@ await _reportingConfiguration.CreateScenarioRunAsync( messages.Add(promptMessage); ChatResponse response = await chatClient.GetResponseAsync(messages, _chatOptions); - ChatMessage responseMessage = response.Message; + ChatMessage responseMessage = response.Messages.Single(); Assert.NotNull(responseMessage.Text); EvaluationResult result = await scenarioRun.EvaluateAsync(promptMessage, responseMessage); diff --git a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/RelevanceTruthAndCompletenessEvaluatorTests.cs b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/RelevanceTruthAndCompletenessEvaluatorTests.cs index eac1f5ea228..8b479ea57cf 100644 --- 
a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/RelevanceTruthAndCompletenessEvaluatorTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/RelevanceTruthAndCompletenessEvaluatorTests.cs @@ -68,7 +68,7 @@ await _reportingConfigurationWithoutReasoning.CreateScenarioRunAsync( messages.Add(promptMessage); ChatResponse response = await chatClient.GetResponseAsync(messages, _chatOptions); - ChatMessage responseMessage = response.Message; + ChatMessage responseMessage = response.Messages.Single(); Assert.NotNull(responseMessage.Text); EvaluationResult result = await scenarioRun.EvaluateAsync(promptMessage, responseMessage); @@ -101,7 +101,7 @@ await _reportingConfigurationWithReasoning.CreateScenarioRunAsync( messages.Add(promptMessage); ChatResponse response = await chatClient.GetResponseAsync(messages, _chatOptions); - ChatMessage responseMessage = response.Message; + ChatMessage responseMessage = response.Messages.Single(); Assert.NotNull(responseMessage.Text); EvaluationResult result = await scenarioRun.EvaluateAsync(promptMessage, responseMessage); diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs index c6c12c4e192..55b840eea5f 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs @@ -21,6 +21,7 @@ #pragma warning disable CA2000 // Dispose objects before losing scope #pragma warning disable CA2214 // Do not call overridable methods in constructors +#pragma warning disable CA2249 // Consider using 'string.Contains' instead of 'string.IndexOf' namespace Microsoft.Extensions.AI; @@ -48,7 +49,7 @@ public virtual async Task GetResponseAsync_SingleRequestMessage() var response = await _chatClient.GetResponseAsync("What's the biggest animal?"); - 
Assert.Contains("whale", response.Message.Text, StringComparison.OrdinalIgnoreCase); + Assert.Contains("whale", response.Text, StringComparison.OrdinalIgnoreCase); } [ConditionalFact] @@ -65,8 +66,8 @@ public virtual async Task GetResponseAsync_MultipleRequestMessages() new(ChatRole.User, "What continent are they each in?"), ]); - Assert.Contains("America", response.Message.Text); - Assert.Contains("Asia", response.Message.Text); + Assert.Contains("America", response.Text); + Assert.Contains("Asia", response.Text); } [ConditionalFact] @@ -146,7 +147,7 @@ public virtual async Task MultiModal_DescribeImage() ], new() { ModelId = GetModel_MultiModal_DescribeImage() }); - Assert.True(response.Message.Text?.IndexOf("net", StringComparison.OrdinalIgnoreCase) >= 0, response.Message.Text); + Assert.True(response.Text.IndexOf("net", StringComparison.OrdinalIgnoreCase) >= 0, response.Text); } [ConditionalFact] @@ -176,7 +177,7 @@ public virtual async Task FunctionInvocation_AutomaticallyInvokeFunction_Paramet Tools = [AIFunctionFactory.Create(() => secretNumber, "GetSecretNumber")] }); - Assert.Contains(secretNumber.ToString(), response.Message.Text); + Assert.Contains(secretNumber.ToString(), response.Text); // If the underlying IChatClient provides usage data, function invocation should aggregate the // usage data across all calls to produce a single Usage value on the final response @@ -201,7 +202,7 @@ public virtual async Task FunctionInvocation_AutomaticallyInvokeFunction_WithPar Tools = [AIFunctionFactory.Create((int a, int b) => a * b, "SecretComputation")] }); - Assert.Contains("3528", response.Message.Text); + Assert.Contains("3528", response.Text); } [ConditionalFact] @@ -253,8 +254,8 @@ public virtual async Task FunctionInvocation_SupportsMultipleParallelRequests() }); Assert.True( - Regex.IsMatch(response.Message.Text ?? "", @"\b(3|three)\b", RegexOptions.IgnoreCase), - $"Doesn't contain three: {response.Message.Text}"); + Regex.IsMatch(response.Text ?? 
"", @"\b(3|three)\b", RegexOptions.IgnoreCase), + $"Doesn't contain three: {response.Text}"); } [ConditionalFact] @@ -310,7 +311,7 @@ public virtual async Task Caching_OutputVariesWithoutCaching() var firstResponse = await _chatClient.GetResponseAsync([message]); var secondResponse = await _chatClient.GetResponseAsync([message]); - Assert.NotEqual(firstResponse.Message.Text, secondResponse.Message.Text); + Assert.NotEqual(firstResponse.Text, secondResponse.Text); } [ConditionalFact] @@ -329,13 +330,13 @@ public virtual async Task Caching_SamePromptResultsInCacheHit_NonStreaming() for (int i = 0; i < 3; i++) { var secondResponse = await chatClient.GetResponseAsync([message]); - Assert.Equal(firstResponse.Message.Text, secondResponse.Message.Text); + Assert.Equal(firstResponse.Messages.Select(m => m.Text), secondResponse.Messages.Select(m => m.Text)); } // ... but if the conversation differs, we should see different output - message.Text += "!"; + ((TextContent)message.Contents[0]).Text += "!"; var thirdResponse = await chatClient.GetResponseAsync([message]); - Assert.NotEqual(firstResponse.Message.Text, thirdResponse.Message.Text); + Assert.NotEqual(firstResponse.Messages, thirdResponse.Messages); } [ConditionalFact] @@ -367,7 +368,7 @@ public virtual async Task Caching_SamePromptResultsInCacheHit_Streaming() } // ... 
but if the conversation differs, we should see different output - message.Text += "!"; + ((TextContent)message.Contents[0]).Text += "!"; StringBuilder third = new(); await foreach (var update in chatClient.GetStreamingResponseAsync([message])) { @@ -401,14 +402,14 @@ public virtual async Task Caching_BeforeFunctionInvocation_AvoidsExtraCalls() var llmCallCount = chatClient.GetService(); var message = new ChatMessage(ChatRole.User, "What is the temperature?"); var response = await chatClient.GetResponseAsync([message]); - Assert.Contains("101", response.Message.Text); + Assert.Contains("101", response.Text); // First LLM call tells us to call the function, second deals with the result Assert.Equal(2, llmCallCount!.CallCount); // Second call doesn't execute the function or call the LLM, but rather just returns the cached result var secondResponse = await chatClient.GetResponseAsync([message]); - Assert.Equal(response.Message.Text, secondResponse.Message.Text); + Assert.Equal(response.Text, secondResponse.Text); Assert.Equal(1, functionCallCount); Assert.Equal(2, llmCallCount!.CallCount); } @@ -440,7 +441,7 @@ public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputUnchange var llmCallCount = chatClient.GetService(); var message = new ChatMessage(ChatRole.User, "What is the temperature?"); var response = await chatClient.GetResponseAsync([message]); - Assert.Contains("58", response.Message.Text); + Assert.Contains("58", response.Text); // First LLM call tells us to call the function, second deals with the result Assert.Equal(1, functionCallCount); @@ -448,7 +449,7 @@ public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputUnchange // Second time, the calls to the LLM don't happen, but the function is called again var secondResponse = await chatClient.GetResponseAsync([message]); - Assert.Equal(response.Message.Text, secondResponse.Message.Text); + Assert.Equal(response.Text, secondResponse.Text); Assert.Equal(2, functionCallCount); 
Assert.Equal(2, llmCallCount!.CallCount); } @@ -480,7 +481,7 @@ public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputChangedA var llmCallCount = chatClient.GetService(); var message = new ChatMessage(ChatRole.User, "What is the temperature?"); var response = await chatClient.GetResponseAsync([message]); - Assert.Contains("81", response.Message.Text); + Assert.Contains("81", response.Text); // First LLM call tells us to call the function, second deals with the result Assert.Equal(1, functionCallCount); @@ -489,7 +490,7 @@ public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputChangedA // Second time, the first call to the LLM don't happen, but the function is called again, // and since its output now differs, we no longer hit the cache so the second LLM call does happen var secondResponse = await chatClient.GetResponseAsync([message]); - Assert.Contains("82", secondResponse.Message.Text); + Assert.Contains("82", secondResponse.Text); Assert.Equal(2, functionCallCount); Assert.Equal(3, llmCallCount!.CallCount); } diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs index 96fab4d244c..7c2c8343f95 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs @@ -18,8 +18,8 @@ namespace Microsoft.Extensions.AI; -// This isn't a feature we're planning to ship, but demonstrates how custom clients can -// layer in non-trivial functionality. In this case we're able to upgrade non-function-calling models +// Demonstrates how custom clients can layer in non-trivial functionality. +// In this case we're able to upgrade non-function-calling models // to behaving as if they do support function calling. 
// // In practice: @@ -82,10 +82,10 @@ public override async Task GetResponseAsync(IList cha var result = await base.GetResponseAsync(chatMessages, options, cancellationToken); - if (result.Message.Text is { } content && content.IndexOf("", StringComparison.Ordinal) is int startPos + if (result.Text is { } content && content.IndexOf("", StringComparison.Ordinal) is int startPos && startPos >= 0) { - var message = result.Message; + var message = result.Messages.Last(); var contentItem = message.Contents.SingleOrDefault(); content = content.Substring(startPos); @@ -153,7 +153,7 @@ private static void ParseArguments(IDictionary arguments) private static void AddOrUpdateToolPrompt(IList chatMessages, IList tools) { - var existingToolPrompt = chatMessages.FirstOrDefault(c => c.Text?.StartsWith(MessageIntro, StringComparison.Ordinal) is true); + var existingToolPrompt = chatMessages.FirstOrDefault(c => c.Text.StartsWith(MessageIntro, StringComparison.Ordinal) is true); if (existingToolPrompt is null) { existingToolPrompt = new ChatMessage(ChatRole.System, (string?)null); @@ -161,7 +161,7 @@ private static void AddOrUpdateToolPrompt(IList chatMessages, IList } var toolDescriptorsJson = JsonSerializer.Serialize(tools.OfType().Select(ToToolDescriptor), _jsonOptions); - existingToolPrompt.Text = $$""" + existingToolPrompt.Contents.OfType().First().Text = $$""" {{MessageIntro}} For each function call, return a JSON object with the function name and arguments within XML tags diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs index 43fad43438c..a31661166a8 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs @@ -47,7 +47,7 @@ public async Task PromptBasedFunctionCalling_NoArgs() Seed = 0, }); - 
Assert.Contains(secretNumber.ToString(), response.Message.Text); + Assert.Contains(secretNumber.ToString(), response.Text); } [ConditionalFact] @@ -81,7 +81,7 @@ public async Task PromptBasedFunctionCalling_WithArgs() Seed = 0, }); - Assert.Contains("999", response.Message.Text); + Assert.Contains("999", response.Text); Assert.False(didCallIrrelevantTool); } diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs index dc173307921..8f7499aa272 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs @@ -118,9 +118,9 @@ public async Task BasicRequestResponse_NonStreaming() }); Assert.NotNull(response); - Assert.Equal("Hello! How are you today? Is there something", response.Message.Text); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("Hello! How are you today? Is there something", response.Text); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("llama3.1", response.ModelId); Assert.Equal(DateTimeOffset.Parse("2024-10-01T15:46:10.5248793Z"), response.CreatedAt); Assert.Equal(ChatFinishReason.Length, response.FinishReason); @@ -281,9 +281,9 @@ public async Task MultipleMessages_NonStreaming() but I'm functioning properly and ready to help with any questions or tasks you may have! How about we chat about something in particular or just shoot the breeze ? Your choice! 
"""), - VerbatimHttpHandler.RemoveWhiteSpace(response.Message.Text)); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + VerbatimHttpHandler.RemoveWhiteSpace(response.Text)); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("llama3.1", response.ModelId); Assert.Equal(DateTimeOffset.Parse("2024-10-01T17:18:46.308987Z"), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); @@ -370,9 +370,9 @@ public async Task FunctionCallContent_NonStreaming() }); Assert.NotNull(response); - Assert.Null(response.Message.Text); + Assert.Empty(response.Text); Assert.Equal("llama3.1", response.ModelId); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal(DateTimeOffset.Parse("2024-10-01T18:48:30.2669578Z"), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); Assert.NotNull(response.Usage); @@ -380,8 +380,8 @@ public async Task FunctionCallContent_NonStreaming() Assert.Equal(19, response.Usage.OutputTokenCount); Assert.Equal(189, response.Usage.TotalTokenCount); - Assert.Single(response.Message.Contents); - FunctionCallContent fcc = Assert.IsType(response.Message.Contents[0]); + Assert.Single(response.Messages.Single().Contents); + FunctionCallContent fcc = Assert.IsType(response.Messages.Single().Contents[0]); Assert.Equal("GetPersonAge", fcc.Name); AssertExtensions.EqualFunctionCallParameters(new Dictionary { ["personName"] = "Alice" }, fcc.Arguments); } @@ -467,9 +467,9 @@ public async Task FunctionResultContent_NonStreaming() }); Assert.NotNull(response); - Assert.Equal("Alice is 42 years old.", response.Message.Text); + Assert.Equal("Alice is 42 years old.", response.Text); Assert.Equal("llama3.1", response.ModelId); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + 
Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal(DateTimeOffset.Parse("2024-10-01T20:57:20.157266Z"), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); Assert.NotNull(response.Usage); diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs index 6495d5f957e..8cd53c55766 100644 --- a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs @@ -195,9 +195,9 @@ public async Task BasicRequestResponse_NonStreaming() Assert.NotNull(response); Assert.Equal("chatcmpl-ADx3PvAnCwJg0woha4pYsBTi3ZpOI", response.ResponseId); - Assert.Equal("Hello! How can I assist you today?", response.Message.Text); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("Hello! How can I assist you today?", response.Text); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_888_631), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); @@ -455,9 +455,9 @@ public async Task MultipleMessages_NonStreaming() Assert.NotNull(response); Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.ResponseId); - Assert.Equal("I’m doing well, thank you! What’s on your mind today?", response.Message.Text); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("I’m doing well, thank you! 
What’s on your mind today?", response.Text); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); @@ -555,9 +555,9 @@ public async Task MultiPartSystemMessage_NonStreaming() Assert.NotNull(response); Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.ResponseId); - Assert.Equal("Hi! It's so good to hear from you!", response.Message.Text); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("Hi! It's so good to hear from you!", response.Text); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); @@ -656,9 +656,9 @@ public async Task EmptyAssistantMessage_NonStreaming() Assert.NotNull(response); Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.ResponseId); - Assert.Equal("I’m doing well, thank you! What’s on your mind today?", response.Message.Text); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("I’m doing well, thank you! 
What’s on your mind today?", response.Text); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); @@ -771,9 +771,9 @@ public async Task FunctionCallContent_NonStreaming() }); Assert.NotNull(response); - Assert.Null(response.Message.Text); + Assert.Empty(response.Text); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_702), response.CreatedAt); Assert.Equal(ChatFinishReason.ToolCalls, response.FinishReason); Assert.NotNull(response.Usage); @@ -791,8 +791,8 @@ public async Task FunctionCallContent_NonStreaming() { "OutputTokenDetails.RejectedPredictionTokenCount", 0 }, }, response.Usage.AdditionalCounts); - Assert.Single(response.Message.Contents); - FunctionCallContent fcc = Assert.IsType(response.Message.Contents[0]); + Assert.Single(response.Messages.Single().Contents); + FunctionCallContent fcc = Assert.IsType(response.Messages.Single().Contents[0]); Assert.Equal("GetPersonAge", fcc.Name); AssertExtensions.EqualFunctionCallParameters(new Dictionary { ["personName"] = "Alice" }, fcc.Arguments); @@ -1032,9 +1032,9 @@ public async Task AssistantMessageWithBothToolsAndContent_NonStreaming() Assert.NotNull(response); Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.ResponseId); - Assert.Equal("I’m doing well, thank you! What’s on your mind today?", response.Message.Text); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("I’m doing well, thank you! 
What’s on your mind today?", response.Text); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs index 4a3398e6a6d..611d3b9f45b 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs @@ -127,13 +127,13 @@ public async Task AllowsConcurrentCallsAsync() Assert.False(result1.IsCompleted); Assert.False(result2.IsCompleted); completionTcs.SetResult(true); - Assert.Equal("Hello", (await result1).Message.Text); - Assert.Equal("Hello", (await result2).Message.Text); + Assert.Equal("Hello", (await result1).Text); + Assert.Equal("Hello", (await result2).Text); // Act 2: Subsequent calls after completion are resolved from the cache var result3 = outer.GetResponseAsync("some input"); Assert.Equal(2, innerCallCount); - Assert.Equal("Hello", (await result3).Message.Text); + Assert.Equal("Hello", (await result3).Text); } [Fact] @@ -204,7 +204,7 @@ public async Task DoesNotCacheCanceledResultsAsync() // Act/Assert: Second call can succeed var result2 = await outer.GetResponseAsync([input]); Assert.Equal(2, innerCallCount); - Assert.Equal("A good result", result2.Message.Text); + Assert.Equal("A good result", result2.Text); } [Fact] @@ -280,12 +280,12 @@ public async Task StreamingCoalescesConsecutiveTextChunksAsync(bool? 
coalesce) // Arrange List expectedResponse = [ - new() { Role = ChatRole.Assistant, Text = "This" }, - new() { Role = ChatRole.Assistant, Text = " becomes one chunk" }, + new(ChatRole.Assistant, "This"), + new(ChatRole.Assistant, " becomes one chunk"), new() { Role = ChatRole.Assistant, Contents = [new FunctionCallContent("callId1", "separator")] }, - new() { Role = ChatRole.Assistant, Text = "... and this" }, - new() { Role = ChatRole.Assistant, Text = " becomes another" }, - new() { Role = ChatRole.Assistant, Text = " one." }, + new(ChatRole.Assistant, "... and this"), + new(ChatRole.Assistant, " becomes another"), + new(ChatRole.Assistant, " one."), ]; using var testClient = new TestChatClient @@ -401,7 +401,7 @@ public async Task StreamingAllowsConcurrentCallsAsync() var completionTcs = new TaskCompletionSource(); List expectedResponse = [ - new() { Role = ChatRole.Assistant, Text = "Chunk 1" }, + new(ChatRole.Assistant, "Chunk 1"), ]; using var testClient = new TestChatClient { @@ -449,7 +449,7 @@ public async Task StreamingDoesNotCacheExceptionResultsAsync() innerCallCount++; return ToAsyncEnumerableAsync(Task.CompletedTask, [ - () => new() { Role = ChatRole.Assistant, Text = "Chunk 1" }, + () => new(ChatRole.Assistant, "Chunk 1"), () => throw new InvalidTimeZoneException("some failure"), ]); } @@ -488,7 +488,7 @@ public async Task StreamingDoesNotCacheCanceledResultsAsync() innerCallCount++; return ToAsyncEnumerableAsync( innerCallCount == 1 ? 
completionTcs.Task : Task.CompletedTask, - [() => new() { Role = ChatRole.Assistant, Text = "A good result" }]); + [() => new(ChatRole.Assistant, "A good result")]); } }; using var outer = new DistributedCachingChatClient(testClient, _storage) @@ -544,8 +544,8 @@ public async Task CacheKeyVariesByChatOptionsAsync() // Assert: Same result Assert.Equal(1, innerCallCount); - Assert.Equal("value 1", result1.Message.Text); - Assert.Equal("value 1", result2.Message.Text); + Assert.Equal("value 1", result1.Text); + Assert.Equal("value 1", result2.Text); // Act: Call with two different ChatOptions that have different values var result3 = await outer.GetResponseAsync([], new ChatOptions @@ -559,8 +559,8 @@ public async Task CacheKeyVariesByChatOptionsAsync() // Assert: Different results Assert.Equal(2, innerCallCount); - Assert.Equal("value 1", result3.Message.Text); - Assert.Equal("value 2", result4.Message.Text); + Assert.Equal("value 1", result3.Text); + Assert.Equal("value 2", result4.Text); } [Fact] @@ -595,8 +595,8 @@ public async Task SubclassCanOverrideCacheKeyToVaryByChatOptionsAsync() // Assert: Different results Assert.Equal(2, innerCallCount); - Assert.Equal("value 1", result1.Message.Text); - Assert.Equal("value 2", result2.Message.Text); + Assert.Equal("value 1", result1.Text); + Assert.Equal("value 2", result2.Text); } [Fact] @@ -647,8 +647,8 @@ public async Task CanCacheCustomContentTypesAsync() // Assert Assert.Equal(1, innerCallCount); AssertResponsesEqual(expectedResponse, result2); - Assert.NotSame(result2.Message.Contents[0], expectedResponse.Message.Contents[0]); - Assert.NotSame(result2.Message.Contents[1], expectedResponse.Message.Contents[1]); + Assert.NotSame(result2.Messages.Last().Contents[0], expectedResponse.Messages.Last().Contents[0]); + Assert.NotSame(result2.Messages.Last().Contents[1], expectedResponse.Messages.Last().Contents[1]); } [Fact] @@ -724,15 +724,17 @@ private static void AssertResponsesEqual(ChatResponse expected, ChatResponse 
act JsonSerializer.Serialize(expected.AdditionalProperties, TestJsonSerializerContext.Default.Options), JsonSerializer.Serialize(actual.AdditionalProperties, TestJsonSerializerContext.Default.Options)); - Assert.IsType(expected.Message.GetType(), actual.Message); - Assert.Equal(expected.Message.Role, actual.Message.Role); - Assert.Equal(expected.Message.Text, actual.Message.Text); - Assert.Equal(expected.Message.Contents.Count, actual.Message.Contents.Count); + ChatMessage expectedMessage = expected.Messages.Last(); + ChatMessage actualMessage = actual.Messages.Last(); + Assert.IsType(expectedMessage.GetType(), actualMessage); + Assert.Equal(expectedMessage.Role, actualMessage.Role); + Assert.Equal(expectedMessage.Text, actualMessage.Text); + Assert.Equal(expectedMessage.Contents.Count, actualMessage.Contents.Count); - for (var itemIndex = 0; itemIndex < expected.Message.Contents.Count; itemIndex++) + for (var itemIndex = 0; itemIndex < expectedMessage.Contents.Count; itemIndex++) { - var expectedItem = expected.Message.Contents[itemIndex]; - var actualItem = actual.Message.Contents[itemIndex]; + var expectedItem = expectedMessage.Contents[itemIndex]; + var actualItem = actualMessage.Contents[itemIndex]; Assert.IsType(expectedItem.GetType(), actualItem); if (expectedItem is FunctionCallContent expectedFcc) diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs index 371a1666002..caaf8dae575 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs @@ -414,7 +414,8 @@ public async Task SupportsConsecutiveStreamingUpdatesWithFunctionCalls() updates = [new() { Contents = [new TextContent("OK bye")] }]; } - chatContents.Add(updates.ToChatMessage()); + 
chatContents.AddRangeFromUpdates(updates); + return YieldAsync(updates); } }; @@ -478,7 +479,7 @@ public async Task AllResponseMessagesReturned() ChatResponse response = await client.GetResponseAsync(messages, options); Assert.Equal(5, response.Messages.Count); - Assert.Equal("The answer is 42.", response.Message.Text); + Assert.Equal("The answer is 42.", response.Text); Assert.IsType(Assert.Single(response.Messages[0].Contents)); Assert.IsType(Assert.Single(response.Messages[1].Contents)); Assert.IsType(Assert.Single(response.Messages[2].Contents)); diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs index 9bd777ef543..51638d1a252 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs @@ -105,8 +105,8 @@ public async Task GetResponseStreamingStreamAsync_LogsUpdateReceived(LogLevel le static async IAsyncEnumerable GetUpdatesAsync() { await Task.Yield(); - yield return new ChatResponseUpdate { Role = ChatRole.Assistant, Text = "blue " }; - yield return new ChatResponseUpdate { Role = ChatRole.Assistant, Text = "whale" }; + yield return new(ChatRole.Assistant, "blue "); + yield return new(ChatRole.Assistant, "whale"); } using IChatClient client = innerClient diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/OpenTelemetryChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/OpenTelemetryChatClientTests.cs index bccba4cc65d..c91d3c1ccdb 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/OpenTelemetryChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/OpenTelemetryChatClientTests.cs @@ -69,10 +69,8 @@ async static IAsyncEnumerable CallbackAsync( foreach (string text in new[] { "The ", "blue ", 
"whale,", " ", "", "I", " think." }) { await Task.Yield(); - yield return new ChatResponseUpdate + yield return new ChatResponseUpdate(ChatRole.Assistant, text) { - Role = ChatRole.Assistant, - Text = text, ResponseId = "id123", }; } diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/UseDelegateChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/UseDelegateChatClientTests.cs index 1d8ff32693d..63b68569d07 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/UseDelegateChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/UseDelegateChatClientTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -47,7 +48,7 @@ public async Task Shared_ContextPropagated() Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); Assert.Equal(42, asyncLocal.Value); - return YieldUpdates(new ChatResponseUpdate { Text = "world" }); + return YieldUpdates(new ChatResponseUpdate(null, "world")); }, }; @@ -64,11 +65,11 @@ public async Task Shared_ContextPropagated() Assert.Equal(0, asyncLocal.Value); ChatResponse response = await client.GetResponseAsync(expectedMessages, expectedOptions, expectedCts.Token); - Assert.Equal("hello", response.Message.Text); + Assert.Equal("hello", response.Text); Assert.Equal(0, asyncLocal.Value); response = await client.GetStreamingResponseAsync(expectedMessages, expectedOptions, expectedCts.Token).ToChatResponseAsync(); - Assert.Equal("world", response.Message.Text); + Assert.Equal("world", response.Text); } [Fact] @@ -99,7 +100,7 @@ public async Task GetResponseFunc_ContextPropagated() Assert.Equal(expectedCts.Token, cancellationToken); asyncLocal.Value = 42; var cc = await innerClient.GetResponseAsync(chatMessages, options, cancellationToken); - cc.Message.Text += " world"; + 
cc.Messages.SelectMany(c => c.Contents).OfType().Last().Text += " world"; return cc; }, null) .Build(); @@ -107,10 +108,10 @@ public async Task GetResponseFunc_ContextPropagated() Assert.Equal(0, asyncLocal.Value); ChatResponse response = await client.GetResponseAsync(expectedMessages, expectedOptions, expectedCts.Token); - Assert.Equal("hello world", response.Message.Text); + Assert.Equal("hello world", response.Text); response = await client.GetStreamingResponseAsync(expectedMessages, expectedOptions, expectedCts.Token).ToChatResponseAsync(); - Assert.Equal("hello world", response.Message.Text); + Assert.Equal("hello world", response.Text); } [Fact] @@ -129,7 +130,7 @@ public async Task GetStreamingResponseFunc_ContextPropagated() Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); Assert.Equal(42, asyncLocal.Value); - return YieldUpdates(new ChatResponseUpdate { Text = "hello" }); + return YieldUpdates(new ChatResponseUpdate(null, "hello")); }, }; @@ -150,7 +151,7 @@ static async IAsyncEnumerable Impl( yield return update; } - yield return new() { Text = " world" }; + yield return new(null, " world"); } }) .Build(); @@ -158,10 +159,10 @@ static async IAsyncEnumerable Impl( Assert.Equal(0, asyncLocal.Value); ChatResponse response = await client.GetResponseAsync(expectedMessages, expectedOptions, expectedCts.Token); - Assert.Equal("hello world", response.Message.Text); + Assert.Equal("hello world", response.Text); response = await client.GetStreamingResponseAsync(expectedMessages, expectedOptions, expectedCts.Token).ToChatResponseAsync(); - Assert.Equal("hello world", response.Message.Text); + Assert.Equal("hello world", response.Text); } [Fact] @@ -189,7 +190,7 @@ public async Task BothGetResponseAndGetStreamingResponseFuncs_ContextPropagated( Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); Assert.Equal(42, asyncLocal.Value); - return YieldUpdates(new ChatResponseUpdate { Text = 
"streaming hello" }); + return YieldUpdates(new ChatResponseUpdate(null, "streaming hello")); }, }; @@ -202,7 +203,7 @@ public async Task BothGetResponseAndGetStreamingResponseFuncs_ContextPropagated( Assert.Equal(expectedCts.Token, cancellationToken); asyncLocal.Value = 42; var cc = await innerClient.GetResponseAsync(chatMessages, options, cancellationToken); - cc.Message.Text += " world (non-streaming)"; + cc.Messages.SelectMany(c => c.Contents).OfType().Last().Text += " world (non-streaming)"; return cc; }, (chatMessages, options, innerClient, cancellationToken) => @@ -221,7 +222,7 @@ static async IAsyncEnumerable Impl( yield return update; } - yield return new() { Text = " world (streaming)" }; + yield return new(null, " world (streaming)"); } }) .Build(); @@ -229,10 +230,10 @@ static async IAsyncEnumerable Impl( Assert.Equal(0, asyncLocal.Value); ChatResponse response = await client.GetResponseAsync(expectedMessages, expectedOptions, expectedCts.Token); - Assert.Equal("non-streaming hello world (non-streaming)", response.Message.Text); + Assert.Equal("non-streaming hello world (non-streaming)", response.Text); response = await client.GetStreamingResponseAsync(expectedMessages, expectedOptions, expectedCts.Token).ToChatResponseAsync(); - Assert.Equal("streaming hello world (streaming)", response.Message.Text); + Assert.Equal("streaming hello world (streaming)", response.Text); } private static async IAsyncEnumerable YieldUpdates(params ChatResponseUpdate[] updates) diff --git a/test/Shared/Throw/ThrowTest.cs b/test/Shared/Throw/ThrowTest.cs index 691217d86ce..1986e854230 100644 --- a/test/Shared/Throw/ThrowTest.cs +++ b/test/Shared/Throw/ThrowTest.cs @@ -382,6 +382,18 @@ public void Shorter_Version_Of_NullOrEmpty_Get_Correct_Argument_Name() Assert.Contains(nameof(listButActuallyNull), exceptionImplicitArgumentName.Message); } + [Fact] + public void Collection_IfReadOnly() + { + _ = Throw.IfReadOnly(new List()); + + IList list = new int[4]; + 
Assert.Throws("list", () => Throw.IfReadOnly(list); + + list = new ReadOnlyCollection(); + Assert.Throws("list", () => Throw.IfReadOnly(list); + } + #endregion #region For Enums From 5e7320c9bc5486dceb94d0367144cd618b14457d Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Tue, 4 Mar 2025 13:46:59 -0500 Subject: [PATCH 5/9] Address feedback --- .../ChatCompletion/ChatResponse.cs | 2 +- .../ChatCompletion/IChatClient.cs | 12 +- .../README.md | 2 +- .../OpenAIAssistantClient.cs | 104 ++++++++++-------- .../ChatCompletion/CachingChatClient.cs | 14 +-- .../ChatClientStructuredOutputExtensions.cs | 8 -- .../FunctionInvocationContext.cs | 3 + .../FunctionInvokingChatClient.cs | 101 ++++++++++++++--- test/Shared/Throw/ThrowTest.cs | 7 +- 9 files changed, 164 insertions(+), 89 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs index c41c3823126..6a71944b8f1 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs @@ -85,7 +85,7 @@ public string Text { 0 => string.Empty, 1 => messages[0].Text, - _ => messages.SelectMany(m => m.Contents).ConcatText(), + _ => string.Join(Environment.NewLine, messages.Select(m => m.Text).Where(s => !string.IsNullOrEmpty(s))), }; } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs index ce99f126b3d..334245ec73e 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs @@ -32,9 +32,9 @@ public interface IChatClient : IDisposable /// The response messages generated by the client. /// is . 
/// - /// The response message generated by is both returned from the method as well as automatically - /// added into . Any intermediate messages generated implicitly as part of the interaction are - /// also added to the chat history. For example, if as part of satisfying this request, the method + /// The response messages generated by are returned from the method as well as automatically + /// added into . This includes any messages generated implicitly as part of the interaction. + /// For example, if as part of satisfying this request, the method /// itself issues multiple requests to one or more underlying instances, all of those messages will also /// be included in . /// @@ -50,9 +50,9 @@ Task GetResponseAsync( /// The response messages generated by the client. /// is . /// - /// The response updates generated by are both stream from the method as well as automatically - /// added into . Any intermediate messages generated implicitly as part of the interaction are - /// also added to the chat history. For example, if as part of satisfying this request, the method + /// The response updates generated by are streamed from the method as well as automatically + /// added into . This includes any messages generated implicitly as part of the interaction. + /// For example, if as part of satisfying this request, the method /// itself issues multiple requests to one or more underlying instances, all of those messages will also /// be included in . 
/// diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md b/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md index 401002c82fd..1bac95467f8 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md @@ -10,7 +10,7 @@ From the command-line: dotnet add package Microsoft.Extensions.AI.Abstractions ``` -Or directly in the C# project file: +or directly in the C# project file: ```xml diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs index e144938cc70..6c251a46d26 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs @@ -112,59 +112,73 @@ public async IAsyncEnumerable GetStreamingResponseAsync( } // Process each update. - await foreach (var update in updates.ConfigureAwait(false)) + List responseUpdates = []; + try { - switch (update) + string? responseId = null; + await foreach (var update in updates.ConfigureAwait(false)) { - case MessageContentUpdate mcu: - yield return new(mcu.Role == MessageRole.User ? ChatRole.User : ChatRole.Assistant, mcu.Text) - { - ChatThreadId = threadId, - RawRepresentation = mcu, - }; - break; + switch (update) + { + case MessageContentUpdate mcu: + ChatResponseUpdate responseUpdate = new(mcu.Role == MessageRole.User ? 
ChatRole.User : ChatRole.Assistant, mcu.Text) + { + ChatThreadId = threadId, + RawRepresentation = mcu, + ResponseId = responseId, + }; + responseUpdates.Add(responseUpdate); + yield return responseUpdate; + break; - case ThreadUpdate tu when options is not null: - threadId ??= tu.Value.Id; - break; + case ThreadUpdate tu when options is not null: + threadId ??= tu.Value.Id; + break; - case RunUpdate ru: - threadId ??= ru.Value.ThreadId; + case RunUpdate ru: + threadId ??= ru.Value.ThreadId; + responseId ??= ru.Value.Id; - ChatResponseUpdate ruUpdate = new() - { - AuthorName = ru.Value.AssistantId, - ChatThreadId = threadId, - CreatedAt = ru.Value.CreatedAt, - ModelId = ru.Value.Model, - RawRepresentation = ru, - ResponseId = ru.Value.Id, - Role = ChatRole.Assistant, - }; - - if (ru.Value.Usage is { } usage) - { - ruUpdate.Contents.Add(new UsageContent(new() + ChatResponseUpdate ruUpdate = new() { - InputTokenCount = usage.InputTokenCount, - OutputTokenCount = usage.OutputTokenCount, - TotalTokenCount = usage.TotalTokenCount, - })); - } - - if (ru is RequiredActionUpdate rau && rau.ToolCallId is string toolCallId && rau.FunctionName is string functionName) - { - ruUpdate.Contents.Add( - new FunctionCallContent( - JsonSerializer.Serialize(new[] { ru.Value.Id, toolCallId }, OpenAIJsonContext.Default.StringArray!), - functionName, - JsonSerializer.Deserialize(rau.FunctionArguments, OpenAIJsonContext.Default.IDictionaryStringObject)!)); - } - - yield return ruUpdate; - break; + AuthorName = ru.Value.AssistantId, + ChatThreadId = threadId, + CreatedAt = ru.Value.CreatedAt, + ModelId = ru.Value.Model, + RawRepresentation = ru, + ResponseId = responseId, + Role = ChatRole.Assistant, + }; + + if (ru.Value.Usage is { } usage) + { + ruUpdate.Contents.Add(new UsageContent(new() + { + InputTokenCount = usage.InputTokenCount, + OutputTokenCount = usage.OutputTokenCount, + TotalTokenCount = usage.TotalTokenCount, + })); + } + + if (ru is RequiredActionUpdate rau && 
rau.ToolCallId is string toolCallId && rau.FunctionName is string functionName) + { + ruUpdate.Contents.Add( + new FunctionCallContent( + JsonSerializer.Serialize(new[] { ru.Value.Id, toolCallId }, OpenAIJsonContext.Default.StringArray!), + functionName, + JsonSerializer.Deserialize(rau.FunctionArguments, OpenAIJsonContext.Default.IDictionaryStringObject)!)); + } + + responseUpdates.Add(ruUpdate); + yield return ruUpdate; + break; + } } } + finally + { + chatMessages.AddRangeFromUpdates(responseUpdates); + } } /// diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs index 97232d6762d..7563ecd891c 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs @@ -56,12 +56,9 @@ public override async Task GetResponseAsync(IList cha if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } result) { - if (options?.ChatThreadId is null) + foreach (ChatMessage message in result.Messages) { - foreach (ChatMessage message in result.Messages) - { - chatMessages.Add(message); - } + chatMessages.Add(message); } } else @@ -94,12 +91,9 @@ public override async IAsyncEnumerable GetStreamingResponseA yield return chunk; } - if (chatResponse.ChatThreadId is null) + foreach (ChatMessage message in chatResponse.Messages) { - foreach (ChatMessage message in chatResponse.Messages) - { - chatMessages.Add(message); - } + chatMessages.Add(message); } } else diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs index ceb8087289a..5932f92de24 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs +++ 
b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs @@ -36,10 +36,6 @@ public static class ChatClientStructuredOutputExtensions /// /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. - /// - /// The returned messages will not have been added to . However, any intermediate messages generated implicitly - /// by the client, including any messages for roundtrips to the model as part of the implementation of this request, will be included. - /// /// The type of structured output to request. public static Task> GetResponseAsync( this IChatClient chatClient, @@ -145,10 +141,6 @@ public static Task> GetResponseAsync( /// /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. - /// - /// The returned messages will not have been added to . However, any intermediate messages generated implicitly - /// by the client, including any messages for roundtrips to the model as part of the implementation of this request, will be included. - /// /// The type of structured output to request. /// is . /// is . diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs index 690af275761..97e5f75caa5 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs @@ -45,6 +45,9 @@ public IList ChatMessages set => _chatMessages = Throw.IfNull(value); } + /// Gets or sets the chat options associated with the operation that initiated this function call request. + public ChatOptions? Options { get; set; } + /// Gets or sets the AI function to be invoked. 
public AIFunction Function { diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index 7c55a157718..509a9c24c61 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -216,9 +216,19 @@ public override async Task GetResponseAsync(IList cha // fast path out by just returning the original response. if (iteration == 0 && !requiresFunctionInvocation) { + Debug.Assert(originalChatMessages == chatMessages, + "Expected the history to be the original, such that there's no additional work to do to keep it up to date."); return response; } + // If chatMessages is different from originalChatMessages, we previously created a different history + // in order to avoid sending state back to an inner client that was already tracking it. But we still + // need that original history to contain all the state. So copy it over if necessary. + if (chatMessages != originalChatMessages) + { + AddRange(originalChatMessages, response.Messages); + } + // Track aggregatable details from the response. (responseMessages ??= []).AddRange(response.Messages); if (response.Usage is not null) @@ -249,7 +259,6 @@ public override async Task GetResponseAsync(IList cha } // If the response indicates the inner client is tracking the history, clear it to avoid re-sending the state. - // In that case, we also avoid touching the user's history, so that we don't need to clear it. if (response.ChatThreadId is not null) { if (chatMessages == originalChatMessages) @@ -261,10 +270,24 @@ public override async Task GetResponseAsync(IList cha chatMessages.Clear(); } } + else if (chatMessages != originalChatMessages) + { + // This should be a very rare case. In a previous iteration, we got back a non-null + // chatThreadId, so we forked chatMessages. 
But now, we got back a null chatThreadId, + // and chatMessages is no longer the full history. Thankfully, we've been keeping + // originalChatMessages up to date; we can just switch back to use it. + chatMessages = originalChatMessages; + } // Add the responses from the function calls into the history. var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options!, functionCallContents!, iteration, cancellationToken).ConfigureAwait(false); responseMessages.AddRange(modeAndMessages.MessagesAdded); + + if (chatMessages != originalChatMessages) + { + AddRange(originalChatMessages, modeAndMessages.MessagesAdded); + } + if (UpdateOptionsForMode(modeAndMessages.Mode, ref options!, response.ChatThreadId)) { // Terminate @@ -311,6 +334,19 @@ public override async IAsyncEnumerable GetStreamingResponseA Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 } + // Make sure that any of the response messages that were added to the chat history also get + // added to the original history if it's different. + if (chatMessages != originalChatMessages) + { + // If chatThreadId was null previously, then we would have added any function result content into + // the original chat messages, passed those chat messages to GetStreamingResponseAsync, and it would + // have added all the new response messages into the original chat messages. But chatThreadId was + // non-null, hence we forked chatMessages. chatMessages then included only the function result content + // and should now include that function result content plus the response messages. None of that is + // in the original, so we can just add everything from chatMessages into the original. + AddRange(originalChatMessages, chatMessages); + } + // If there are no tools to call, or for any other reason we should stop, return the response. 
if (functionCallContents is not { Count: > 0 } || options?.Tools is not { Count: > 0 } || @@ -332,14 +368,17 @@ public override async IAsyncEnumerable GetStreamingResponseA chatMessages.Clear(); } } + else if (chatMessages != originalChatMessages) + { + // This should be a very rare case. In a previous iteration, we got back a non-null + // chatThreadId, so we forked chatMessages. But now, we got back a null chatThreadId, + // and chatMessages is no longer the full history. Thankfully, we've been keeping + // originalChatMessages up to date; we can just switch back to use it. + chatMessages = originalChatMessages; + } // Process all of the functions, adding their results into the history. var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); - if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, chatThreadId)) - { - // Terminate - yield break; - } // Stream any generated function results. These are already part of the history, // but we stream them out for informational purposes. @@ -361,6 +400,12 @@ public override async IAsyncEnumerable GetStreamingResponseA yield return toolResultUpdate; Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 } + + if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, chatThreadId)) + { + // Terminate + yield break; + } } } @@ -407,10 +452,7 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti // as otherwise we'll be in an infinite loop. 
options = options.Clone(); options.ToolMode = null; - if (chatThreadId is not null) - { - options.ChatThreadId = chatThreadId; - } + options.ChatThreadId = chatThreadId; break; @@ -419,10 +461,7 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti options = options.Clone(); options.Tools = null; options.ToolMode = null; - if (chatThreadId is not null) - { - options.ChatThreadId = chatThreadId; - } + options.ChatThreadId = chatThreadId; break; @@ -433,7 +472,7 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti default: // As with the other modes, ensure we've propagated the chat thread ID to the options. // We only need to clone the options if we're actually mutating it. - if (chatThreadId is not null && options.ChatThreadId != chatThreadId) + if (options.ChatThreadId != chatThreadId) { options = options.Clone(); options.ChatThreadId = chatThreadId; @@ -468,6 +507,8 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti FunctionInvocationResult result = await ProcessFunctionCallAsync( chatMessages, options, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false); IList added = AddResponseMessages(chatMessages, [result]); + + ThrowIfNoFunctionResultsAdded(added); return (result.ContinueMode, added); } else @@ -505,10 +546,23 @@ select Task.Run(() => ProcessFunctionCallAsync( } } + ThrowIfNoFunctionResultsAdded(added); return (continueMode, added); } } + /// + /// Throws an exception if is empty due to an override of + /// not having added any messages. + /// + private void ThrowIfNoFunctionResultsAdded(IList chatMessages) + { + if (chatMessages.Count == 0) + { + Throw.InvalidOperationException($"{GetType().Name}.{nameof(AddResponseMessages)} did not add any function result messages."); + } + } + /// Processes the function call described in []. /// The current chat contents, inclusive of the function call contents being processed. 
/// The options used for the response being processed. @@ -533,6 +587,7 @@ private async Task ProcessFunctionCallAsync( FunctionInvocationContext context = new() { ChatMessages = chatMessages, + Options = options, CallContent = callContent, Function = function, Iteration = iteration, @@ -698,6 +753,22 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul return result; } + /// Adds all messages from into . + private static void AddRange(IList destination, IEnumerable source) + { + if (destination is List list) + { + list.AddRange(source); + } + else + { + foreach (var message in source) + { + destination.Add(message); + } + } + } + private static TimeSpan GetElapsedTime(long startingTimestamp) => #if NET Stopwatch.GetElapsedTime(startingTimestamp); diff --git a/test/Shared/Throw/ThrowTest.cs b/test/Shared/Throw/ThrowTest.cs index 1986e854230..057f9098f5c 100644 --- a/test/Shared/Throw/ThrowTest.cs +++ b/test/Shared/Throw/ThrowTest.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; namespace Microsoft.Shared.Diagnostics.Test; @@ -388,10 +389,10 @@ public void Collection_IfReadOnly() _ = Throw.IfReadOnly(new List()); IList list = new int[4]; - Assert.Throws("list", () => Throw.IfReadOnly(list); + Assert.Throws("list", () => Throw.IfReadOnly(list)); - list = new ReadOnlyCollection(); - Assert.Throws("list", () => Throw.IfReadOnly(list); + list = new ReadOnlyCollection(new List()); + Assert.Throws("list", () => Throw.IfReadOnly(list)); } #endregion From a17aecc905ccb660a0a6a0b3712c86bf2e5052f3 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 5 Mar 2025 14:47:20 -0500 Subject: [PATCH 6/9] Avoid mutating chat history, and other cleanup --- .../ChatCompletion/ChatMessage.cs | 20 +- .../ChatCompletion/ChatResponse.cs | 30 +- ...xtensions.cs => ChatResponseExtensions.cs} | 83 +++--- .../ChatCompletion/ChatResponseUpdate.cs | 24 +- 
.../ChatCompletion/DelegatingChatClient.cs | 14 +- .../ChatCompletion/IChatClient.cs | 35 +-- .../README.md | 57 ++-- .../AzureAIInferenceChatClient.cs | 183 ++++++------ .../ChatConversationEvaluator.cs | 29 +- .../CoherenceEvaluator.cs | 9 +- .../EquivalenceEvaluator.cs | 9 +- .../FluencyEvaluator.cs | 9 +- .../GroundednessEvaluator.cs | 11 +- .../RelevanceTruthAndCompletenessEvaluator.cs | 11 +- .../SingleNumericMetricEvaluator.cs | 4 +- .../Storage/AzureStorageResultStore.cs | 2 +- .../CSharp/ScenarioRun.cs | 2 +- .../CSharp/ScenarioRunExtensions.cs | 67 ++++- .../CSharp/ScenarioRunResult.cs | 10 +- .../CSharp/ScenarioRunResultExtensions.cs | 2 +- .../CSharp/Storage/DiskBasedResultStore.cs | 2 +- .../CompositeEvaluator.cs | 6 +- .../EvaluationMetricExtensions.cs | 2 +- .../EvaluationResult.cs | 2 +- .../EvaluationResultExtensions.cs | 8 +- .../EvaluatorExtensions.cs | 98 +++++- .../IEvaluator.cs | 2 +- .../TokenizerExtensions.cs | 2 +- .../OllamaChatClient.cs | 98 +++--- .../OpenAIAssistantClient.cs | 126 ++++---- .../OpenAIChatClient.cs | 22 +- ...nAIModelMappers.StreamingChatCompletion.cs | 213 +++++++------ .../AnonymousDelegatingChatClient.cs | 41 +-- .../ChatCompletion/CachingChatClient.cs | 48 +-- .../ChatCompletion/ChatClientBuilder.cs | 6 +- .../ChatClientStructuredOutputExtensions.cs | 36 +-- .../ConfigureOptionsChatClient.cs | 11 +- .../FunctionInvocationContext.cs | 8 +- .../FunctionInvokingChatClient.cs | 280 +++++++++--------- .../ChatCompletion/LoggingChatClient.cs | 16 +- .../ChatCompletion/OpenTelemetryChatClient.cs | 17 +- src/Shared/Throw/Throw.cs | 18 -- .../ChatClientExtensionsTests.cs | 8 +- .../ChatCompletion/ChatMessageTests.cs | 6 +- .../ChatCompletion/ChatResponseTests.cs | 17 +- .../TestChatClient.cs | 12 +- .../AzureAIInferenceChatClientTests.cs | 8 +- .../TestEvaluator.cs | 2 +- .../ResultStoreTester.cs | 2 +- .../ScenarioRunResultTests.cs | 23 +- .../CallCountingChatClient.cs | 8 +- .../PromptBasedFunctionCallingChatClient.cs | 28 
+- .../ReducingChatClientTests.cs | 84 ++---- .../OllamaChatClientIntegrationTests.cs | 4 +- ...atClientStructuredOutputExtensionsTests.cs | 3 +- .../FunctionInvocationContextTests.cs | 14 +- .../FunctionInvokingChatClientTests.cs | 53 ++-- .../OpenTelemetryChatClientTests.cs | 8 +- .../UseDelegateChatClientTests.cs | 62 ++-- test/Shared/Throw/ThrowTest.cs | 13 - 60 files changed, 1019 insertions(+), 1009 deletions(-) rename src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/{ChatResponseUpdateExtensions.cs => ChatResponseExtensions.cs} (85%) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs index 194c6d68b1a..eae74f68e62 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs @@ -6,7 +6,6 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; -using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI; @@ -35,16 +34,10 @@ public ChatMessage(ChatRole role, string? contents) /// Initializes a new instance of the class. /// The role of the author of the message. /// The contents for this message. - /// must not be read-only. public ChatMessage(ChatRole role, IList? contents) { - if (contents is not null) - { - _ = Throw.IfReadOnly(contents); - _contents = contents; - } - Role = role; + _contents = contents; } /// Clones the to a new instance. @@ -81,20 +74,11 @@ public string? AuthorName public string Text => Contents.ConcatText(); /// Gets or sets the chat message content items. - /// The must not be read-only. 
[AllowNull] public IList Contents { get => _contents ??= []; - set - { - if (value is not null) - { - _ = Throw.IfReadOnly(value); - } - - _contents = value; - } + set => _contents = value; } /// Gets or sets the raw representation of the chat message from an underlying implementation. diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs index 6a71944b8f1..4b2bf5b95aa 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs @@ -5,7 +5,6 @@ using System.Collections.Generic; using System.Linq; using System.Text.Json.Serialization; -using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI; @@ -29,39 +28,26 @@ public ChatResponse() /// Initializes a new instance of the class. /// The response message. - /// is . - public ChatResponse(ChatMessage message) + public ChatResponse(ChatMessage? message) { - _ = Throw.IfNull(message); - Messages.Add(message); + if (message is not null) + { + Messages.Add(message); + } } /// Initializes a new instance of the class. /// The response messages. - /// is . - /// must not be read-only. - public ChatResponse(IList messages) + public ChatResponse(IList? messages) { - _ = Throw.IfNull(messages); - _ = Throw.IfReadOnly(messages); - _messages = messages; } /// Gets or sets the chat response messages. - /// The must not be read-only. public IList Messages { get => _messages ??= new List(1); - set - { - if (value is not null) - { - _ = Throw.IfReadOnly(value); - } - - _messages = value; - } + set => _messages = value; } /// Gets the text of the response. @@ -99,7 +85,7 @@ public string Text /// the input messages supplied to need only be the additional messages beyond /// what's already stored. 
If this property is non-, it represents an identifier for that state, /// and it should be used in a subsequent instead of supplying the same messages - /// (and this 's message) as part of the chatMessages parameter. + /// (and this 's message) as part of the messages parameter. /// public string? ChatThreadId { get; set; } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseExtensions.cs similarity index 85% rename from src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs rename to src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseExtensions.cs index 53d15e24b70..604918f21fc 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseExtensions.cs @@ -14,14 +14,37 @@ namespace Microsoft.Extensions.AI; /// -/// Provides extension methods for working with instances. +/// Provides extension methods for working with and instances. /// -public static class ChatResponseUpdateExtensions +public static class ChatResponseExtensions { - /// Converts the into instances and adds them to . - /// The list to which the newly constructed messages should be added. + /// Adds all of the messages from into . + /// The destination list into which the messages should be added. + /// The response containing the messages to add. + /// is . + /// is . + public static void AddMessages(this IList list, ChatResponse response) + { + _ = Throw.IfNull(list); + _ = Throw.IfNull(response); + + if (list is List listConcrete) + { + listConcrete.AddRange(response.Messages); + } + else + { + foreach (var message in response.Messages) + { + list.Add(message); + } + } + } + + /// Converts the into instances and adds them to . 
+ /// The list to which the newly constructed messages should be added. /// The instances to convert to messages and add to the list. - /// is . + /// is . /// is . /// /// As part of combining into a series of instances, tne @@ -29,9 +52,9 @@ public static class ChatResponseUpdateExtensions /// contiguous items where applicable, e.g. multiple /// instances in a row may be combined into a single . /// - public static void AddRangeFromUpdates(this IList chatMessages, IEnumerable updates) + public static void AddMessages(this IList list, IEnumerable updates) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(list); _ = Throw.IfNull(updates); if (updates is ICollection { Count: 0 }) @@ -39,27 +62,15 @@ public static void AddRangeFromUpdates(this IList chatMessages, IEn return; } - ChatResponse response = updates.ToChatResponse(); - if (chatMessages is List list) - { - list.AddRange(response.Messages); - } - else - { - int count = response.Messages.Count; - for (int i = 0; i < count; i++) - { - chatMessages.Add(response.Messages[i]); - } - } + list.AddMessages(updates.ToChatResponse()); } - /// Converts the into instances and adds them to . - /// The list to which the newly constructed messages should be added. + /// Converts the into instances and adds them to . + /// The list to which the newly constructed messages should be added. /// The instances to convert to messages and add to the list. /// The to monitor for cancellation requests. The default is . /// A representing the completion of the operation. - /// is . + /// is . /// is . /// /// As part of combining into a series of instances, tne @@ -67,31 +78,17 @@ public static void AddRangeFromUpdates(this IList chatMessages, IEn /// contiguous items where applicable, e.g. multiple /// instances in a row may be combined into a single . 
/// - public static Task AddRangeFromUpdatesAsync( - this IList chatMessages, IAsyncEnumerable updates, CancellationToken cancellationToken = default) + public static Task AddMessagesAsync( + this IList list, IAsyncEnumerable updates, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(list); _ = Throw.IfNull(updates); - return AddRangeFromUpdatesAsync(chatMessages, updates, cancellationToken); + return AddRangeFromUpdatesAsync(list, updates, cancellationToken); static async Task AddRangeFromUpdatesAsync( - IList chatMessages, IAsyncEnumerable updates, CancellationToken cancellationToken) - { - ChatResponse response = await updates.ToChatResponseAsync(cancellationToken).ConfigureAwait(false); - if (chatMessages is List list) - { - list.AddRange(response.Messages); - } - else - { - int count = response.Messages.Count; - for (int i = 0; i < count; i++) - { - chatMessages.Add(response.Messages[i]); - } - } - } + IList list, IAsyncEnumerable updates, CancellationToken cancellationToken) => + list.AddMessages(await updates.ToChatResponseAsync(cancellationToken).ConfigureAwait(false)); } /// Combines instances into a single . 
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs index 214acbed465..0202c4b65c4 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs @@ -6,7 +6,6 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; -using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI; @@ -21,7 +20,7 @@ namespace Microsoft.Extensions.AI; /// /// /// The relationship between and is -/// codified in the and +/// codified in the and /// , which enable bidirectional conversions /// between the two. Note, however, that the provided conversions may be lossy, for example if multiple /// updates all have different objects whereas there's only one slot for @@ -56,16 +55,10 @@ public ChatResponseUpdate(ChatRole? role, string? contents) /// Initializes a new instance of the class. /// The role of the author of the update. /// The contents of the update. - /// must not be read-only. public ChatResponseUpdate(ChatRole? role, IList? contents) { - if (contents is not null) - { - _ = Throw.IfReadOnly(contents); - _contents = contents; - } - Role = role; + _contents = contents; } /// Gets or sets the name of the author of the response update. @@ -86,20 +79,11 @@ public string? AuthorName public string Text => _contents is not null ? _contents.ConcatText() : string.Empty; /// Gets or sets the chat response update content items. - /// The must not be read-only. [AllowNull] public IList Contents { get => _contents ??= []; - set - { - if (value is not null) - { - _ = Throw.IfReadOnly(value); - } - - _contents = value; - } + set => _contents = value; } /// Gets or sets the raw representation of the response update from an underlying implementation. 
@@ -123,7 +107,7 @@ public IList Contents /// the input messages supplied to need only be the additional messages beyond /// what's already stored. If this property is non-, it represents an identifier for that state, /// and it should be used in a subsequent instead of supplying the same messages - /// (and this streaming message) as part of the chatMessages parameter. + /// (and this streaming message) as part of the messages parameter. /// public string? ChatThreadId { get; set; } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs index 7882529ac85..c0ccbb84f28 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs @@ -38,16 +38,14 @@ public void Dispose() protected IChatClient InnerClient { get; } /// - public virtual Task GetResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) - { - return InnerClient.GetResponseAsync(chatMessages, options, cancellationToken); - } + public virtual Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => + InnerClient.GetResponseAsync(messages, options, cancellationToken); /// - public virtual IAsyncEnumerable GetStreamingResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) - { - return InnerClient.GetStreamingResponseAsync(chatMessages, options, cancellationToken); - } + public virtual IAsyncEnumerable GetStreamingResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => + InnerClient.GetStreamingResponseAsync(messages, options, cancellationToken); /// public virtual object? 
GetService(Type serviceType, object? serviceKey = null) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs index 334245ec73e..4b487b4d9b3 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs @@ -16,48 +16,33 @@ namespace Microsoft.Extensions.AI; /// /// /// However, implementations of might mutate the arguments supplied to and -/// , such as by adding additional messages to the messages list or configuring the options -/// instance. Thus, consumers of the interface either should avoid using shared instances of these arguments for concurrent -/// invocations or should otherwise ensure by construction that no instances are used which might employ -/// such mutation. For example, the WithChatOptions method be provided with a callback that could mutate the supplied options -/// argument, and that should be avoided if using a singleton options instance. +/// , such as by configuring the options instance. Thus, consumers of the interface either +/// should avoid using shared instances of these arguments for concurrent invocations or should otherwise ensure by construction +/// that no instances are used which might employ such mutation. For example, the WithChatOptions method be +/// provided with a callback that could mutate the supplied options argument, and that should be avoided if using a singleton options instance. /// /// public interface IChatClient : IDisposable { /// Sends chat messages and returns the response. - /// The list of chat messages to send and to be augmented with generated messages. + /// The list of chat messages to send and to be augmented with generated messages. /// The chat options with which to configure the request. /// The to monitor for cancellation requests. The default is . 
/// The response messages generated by the client. - /// is . - /// - /// The response messages generated by are returned from the method as well as automatically - /// added into . This includes any messages generated implicitly as part of the interaction. - /// For example, if as part of satisfying this request, the method - /// itself issues multiple requests to one or more underlying instances, all of those messages will also - /// be included in . - /// + /// is . Task GetResponseAsync( - IList chatMessages, + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default); /// Sends chat messages and streams the response. - /// The list of chat messages to send and to be augmented with generated messages. + /// The list of chat messages to send and to be augmented with generated messages. /// The chat options with which to configure the request. /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. - /// is . - /// - /// The response updates generated by are streamed from the method as well as automatically - /// added into . This includes any messages generated implicitly as part of the interaction. - /// For example, if as part of satisfying this request, the method - /// itself issues multiple requests to one or more underlying instances, all of those messages will also - /// be included in . - /// + /// is . IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, + IEnumerable messages, ChatOptions? 
options = null, CancellationToken cancellationToken = default); diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md b/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md index 1bac95467f8..0d94cacc925 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md @@ -41,7 +41,7 @@ IChatClient client = ...; Console.WriteLine(await client.GetResponseAsync("What is AI?")); ``` -The core `GetResponseAsync` method on the `IChatClient` interface accepts a list of messages. This list represents the history of all messages that are part of the conversation. +The core `GetResponseAsync` method on the `IChatClient` interface accepts a list of messages. This list often represents the history of all messages that are part of the conversation. ```csharp IChatClient client = ...; @@ -53,7 +53,10 @@ Console.WriteLine(await client.GetResponseAsync( ])); ``` -The `ChatResponse` that's returned from `GetResponseAsync` exposes a `ChatMessage` representing the response message. It is automatically added into the history by the `IChatClient`, so that it'll be provided back to the service in a subsequent request, e.g. +The `ChatResponse` that's returned from `GetResponseAsync` exposes a list of `ChatMessage` instances representing one or more messages generated as part of the operation. +In common cases, there is only one response message, but a variety of situations can result in their being multiple; the list is ordered, such that the last message in +the list represents the final message to the request. In order to provide all of those response messages back to the service in a subsequent request, the messages from +the response may be added back into the messages list. 
```csharp List history = []; @@ -62,13 +65,17 @@ while (true) Console.Write("Q: "); history.Add(new(ChatRole.User, Console.ReadLine())); - Console.WriteLine(await client.GetResponseAsync(history)); + var response = await client.GetResponseAsync(history); + Console.WriteLine(response); + + history.AddMessages(response); } ``` #### Requesting a Streaming Chat Response: `GetStreamingResponseAsync` -The inputs to `GetStreamingResponseAsync` are identical to those of `GetResponseAsync`. However, rather than returning the complete response as part of a `ChatResponse` object, the method returns an `IAsyncEnumerable`, providing a stream of updates that together form the single response. +The inputs to `GetStreamingResponseAsync` are identical to those of `GetResponseAsync`. However, rather than returning the complete response as part of a +`ChatResponse` object, the method returns an `IAsyncEnumerable`, providing a stream of updates that together form the single response. ```csharp IChatClient client = ...; @@ -79,8 +86,10 @@ await foreach (var update in client.GetStreamingResponseAsync("What is AI?")) } ``` -As with `GetResponseAsync`, the `IChatClient.GetStreamingResponseAsync` implementation is responsible for adding -the response message back into the history, so that it'll be provided back to the service in a subsequent request. +As with `GetResponseAsync`, the updates from `IChatClient.GetStreamingResponseAsync` can be added back into the messages list. As the updates provided +are individual pieces of a response, helpers like `ToChatResponse` can be used to compose one or more updates back into a single `ChatResponse` instance. +Further helpers like `AddMessages` perform that same operation and then extract the composed messages from the response and add them into a list. 
+ ```csharp List history = []; while (true) @@ -88,12 +97,14 @@ while (true) Console.Write("Q: "); history.Add(new(ChatRole.User, Console.ReadLine())); + List updates = []; await foreach (var update in client.GetStreamingResponseAsync(history)) { Console.Write(update); } - Console.WriteLine(); + + history.AddMessages(updates); } ``` @@ -165,7 +176,7 @@ IChatClient client = new ChatClientBuilder(new OllamaChatClient(new Uri("http:// .UseOpenTelemetry(sourceName: sourceName, configure: c => c.EnableSensitiveData = true) .Build(); -Console.WriteLine((await client.GetResponseAsync("What is AI?")).Message); +Console.WriteLine(await client.GetResponseAsync("What is AI?")); ``` Alternatively, the `LoggingChatClient` and corresponding `UseLogging` method provide a simple way to write log entries to an `ILogger` for every request and response. @@ -245,23 +256,23 @@ using System.Threading.RateLimiting; public sealed class RateLimitingChatClient(IChatClient innerClient, RateLimiter rateLimiter) : DelegatingChatClient(innerClient) { public override async Task GetResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { using var lease = await rateLimiter.AcquireAsync(permitCount: 1, cancellationToken).ConfigureAwait(false); if (!lease.IsAcquired) throw new InvalidOperationException("Unable to acquire lease."); - return await base.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + return await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); } public override async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? 
options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { using var lease = await rateLimiter.AcquireAsync(permitCount: 1, cancellationToken).ConfigureAwait(false); if (!lease.IsAcquired) throw new InvalidOperationException("Unable to acquire lease."); - await foreach (var update in base.GetStreamingResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) yield return update; } @@ -285,7 +296,7 @@ var client = new RateLimitingChatClient( new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1"), new ConcurrencyLimiter(new() { PermitLimit = 1, QueueLimit = int.MaxValue })); -await client.GetResponseAsync("What color is the sky?"); +Console.WriteLine(await client.GetResponseAsync("What color is the sky?")); ``` To make it easier to compose such components with others, the author of the component is recommended to create a "Use" extension method for registering this component into a pipeline, e.g. 
@@ -325,13 +336,13 @@ RateLimiter rateLimiter = ...; var client = new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1") .AsBuilder() .UseDistributedCache() - .Use(async (chatMessages, options, nextAsync, cancellationToken) => + .Use(async (messages, options, nextAsync, cancellationToken) => { using var lease = await rateLimiter.AcquireAsync(permitCount: 1, cancellationToken).ConfigureAwait(false); if (!lease.IsAcquired) throw new InvalidOperationException("Unable to acquire lease."); - await nextAsync(chatMessages, options, cancellationToken); + await nextAsync(messages, options, cancellationToken); }) .UseOpenTelemetry() .Build(); @@ -369,9 +380,8 @@ What instance and configuration is injected may differ based on the current need "Stateless" services require all relevant conversation history to sent back on every request, while "stateful" services keep track of the history and instead require only additional messages be sent with a request. The `IChatClient` interface is designed to handle both stateless and stateful AI services. -When working with a stateless service, the `GetResponseAsync` and `GetStreamingResponseAsync` methods will automatically add the response message back -into the history. The client can then simply pass the same list of messages to a subsequent request, as that list will contain all -of the context necessary to enable the next request. +When working with a stateless service, callers maintain a list of all messages, adding in all received response messages, and providing the list +back on subsequent interactions. 
```csharp List history = []; while (true) @@ -379,12 +389,15 @@ while (true) Console.Write("Q: "); history.Add(new(ChatRole.User, Console.ReadLine())); - Console.WriteLine(await client.GetResponseAsync(history)); + var response = await client.GetResponseAsync(history); + Console.WriteLine(response); + + history.AddMessages(response); } ``` For stateful services, you may know ahead of time an identifier used for the relevant conversation. That identifier can be put into `ChatOptions.ChatThreadId`. -Usage then follows the same pattern: +Usage then follows the same pattern, except there's no need to maintain a history manually. ```csharp ChatOptions options = new() { ChatThreadId = "my-conversation-id" }; while (true) @@ -432,6 +445,10 @@ while (true) { history.Clear(); } + else + { + history.AddMessages(response); + } } ``` diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs index 23edd8e9266..db03a62f2a9 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs @@ -84,13 +84,13 @@ public JsonSerializerOptions ToolCallJsonSerializerOptions /// public async Task GetResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); // Make the call. ChatCompletions response = (await _chatCompletionsClient.CompleteAsync( - ToAzureAIOptions(chatMessages, options), + ToAzureAIOptions(messages, options), cancellationToken: cancellationToken).ConfigureAwait(false)).Value; // Create the return message. @@ -125,7 +125,6 @@ public async Task GetResponseAsync( } // Wrap the content in a ChatResponse to return. 
- chatMessages.Add(message); return new ChatResponse(message) { CreatedAt = response.Created, @@ -139,9 +138,9 @@ public async Task GetResponseAsync( /// public async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); Dictionary? functionCallInfos = null; ChatRole? streamedRole = default; @@ -151,112 +150,102 @@ public async IAsyncEnumerable GetStreamingResponseAsync( string? modelId = null; string lastCallId = string.Empty; - List responseUpdates = []; - try + // Process each update as it arrives + var updates = await _chatCompletionsClient.CompleteStreamingAsync(ToAzureAIOptions(messages, options), cancellationToken).ConfigureAwait(false); + await foreach (StreamingChatCompletionsUpdate chatCompletionUpdate in updates.ConfigureAwait(false)) { - // Process each update as it arrives - var updates = await _chatCompletionsClient.CompleteStreamingAsync(ToAzureAIOptions(chatMessages, options), cancellationToken).ConfigureAwait(false); - await foreach (StreamingChatCompletionsUpdate chatCompletionUpdate in updates.ConfigureAwait(false)) + // The role and finish reason may arrive during any update, but once they've arrived, the same value should be the same for all subsequent updates. + streamedRole ??= chatCompletionUpdate.Role is global::Azure.AI.Inference.ChatRole role ? ToChatRole(role) : null; + finishReason ??= chatCompletionUpdate.FinishReason is CompletionsFinishReason reason ? ToFinishReason(reason) : null; + responseId ??= chatCompletionUpdate.Id; + createdAt ??= chatCompletionUpdate.Created; + modelId ??= chatCompletionUpdate.Model; + + // Create the response content object. 
+ ChatResponseUpdate responseUpdate = new() { - // The role and finish reason may arrive during any update, but once they've arrived, the same value should be the same for all subsequent updates. - streamedRole ??= chatCompletionUpdate.Role is global::Azure.AI.Inference.ChatRole role ? ToChatRole(role) : null; - finishReason ??= chatCompletionUpdate.FinishReason is CompletionsFinishReason reason ? ToFinishReason(reason) : null; - responseId ??= chatCompletionUpdate.Id; - createdAt ??= chatCompletionUpdate.Created; - modelId ??= chatCompletionUpdate.Model; - - // Create the response content object. - ChatResponseUpdate responseUpdate = new() - { - CreatedAt = chatCompletionUpdate.Created, - FinishReason = finishReason, - ModelId = modelId, - RawRepresentation = chatCompletionUpdate, - ResponseId = chatCompletionUpdate.Id, - Role = streamedRole, - }; - - // Transfer over content update items. - if (chatCompletionUpdate.ContentUpdate is string update) + CreatedAt = chatCompletionUpdate.Created, + FinishReason = finishReason, + ModelId = modelId, + RawRepresentation = chatCompletionUpdate, + ResponseId = chatCompletionUpdate.Id, + Role = streamedRole, + }; + + // Transfer over content update items. + if (chatCompletionUpdate.ContentUpdate is string update) + { + responseUpdate.Contents.Add(new TextContent(update)); + } + + // Transfer over tool call updates. + if (chatCompletionUpdate.ToolCallUpdate is { } toolCallUpdate) + { + // TODO https://github.com/Azure/azure-sdk-for-net/issues/46830: Azure.AI.Inference + // has removed the Index property from ToolCallUpdate. It's now impossible via the + // exposed APIs to correctly handle multiple parallel tool calls, as the CallId is + // often null for anything other than the first update for a given call, and Index + // isn't available to correlate which updates are for which call. 
This is a temporary + // workaround to at least make a single tool call work and also make work multiple + // tool calls when their updates aren't interleaved. + if (toolCallUpdate.Id is not null) { - responseUpdate.Contents.Add(new TextContent(update)); + lastCallId = toolCallUpdate.Id; } - // Transfer over tool call updates. - if (chatCompletionUpdate.ToolCallUpdate is { } toolCallUpdate) + functionCallInfos ??= []; + if (!functionCallInfos.TryGetValue(lastCallId, out FunctionCallInfo? existing)) { - // TODO https://github.com/Azure/azure-sdk-for-net/issues/46830: Azure.AI.Inference - // has removed the Index property from ToolCallUpdate. It's now impossible via the - // exposed APIs to correctly handle multiple parallel tool calls, as the CallId is - // often null for anything other than the first update for a given call, and Index - // isn't available to correlate which updates are for which call. This is a temporary - // workaround to at least make a single tool call work and also make work multiple - // tool calls when their updates aren't interleaved. - if (toolCallUpdate.Id is not null) - { - lastCallId = toolCallUpdate.Id; - } - - functionCallInfos ??= []; - if (!functionCallInfos.TryGetValue(lastCallId, out FunctionCallInfo? existing)) - { - functionCallInfos[lastCallId] = existing = new(); - } - - existing.Name ??= toolCallUpdate.Function.Name; - if (toolCallUpdate.Function.Arguments is { } arguments) - { - _ = (existing.Arguments ??= new()).Append(arguments); - } + functionCallInfos[lastCallId] = existing = new(); } - if (chatCompletionUpdate.Usage is { } usage) + existing.Name ??= toolCallUpdate.Function.Name; + if (toolCallUpdate.Function.Arguments is { } arguments) { - responseUpdate.Contents.Add(new UsageContent(new() - { - InputTokenCount = usage.PromptTokens, - OutputTokenCount = usage.CompletionTokens, - TotalTokenCount = usage.TotalTokens, - })); + _ = (existing.Arguments ??= new()).Append(arguments); } - - // Now yield the item. 
- responseUpdates.Add(responseUpdate); - yield return responseUpdate; } - // Now that we've received all updates, combine any for function calls into a single item to yield. - if (functionCallInfos is not null) + if (chatCompletionUpdate.Usage is { } usage) { - var responseUpdate = new ChatResponseUpdate - { - CreatedAt = createdAt, - FinishReason = finishReason, - ModelId = modelId, - ResponseId = responseId, - Role = streamedRole, - }; - - foreach (var entry in functionCallInfos) + responseUpdate.Contents.Add(new UsageContent(new() { - FunctionCallInfo fci = entry.Value; - if (!string.IsNullOrWhiteSpace(fci.Name)) - { - FunctionCallContent callContent = ParseCallContentFromJsonString( - fci.Arguments?.ToString() ?? string.Empty, - entry.Key, - fci.Name!); - responseUpdate.Contents.Add(callContent); - } - } - - responseUpdates.Add(responseUpdate); - yield return responseUpdate; + InputTokenCount = usage.PromptTokens, + OutputTokenCount = usage.CompletionTokens, + TotalTokenCount = usage.TotalTokens, + })); } + + // Now yield the item. + yield return responseUpdate; } - finally + + // Now that we've received all updates, combine any for function calls into a single item to yield. + if (functionCallInfos is not null) { - chatMessages.AddRangeFromUpdates(responseUpdates); + var responseUpdate = new ChatResponseUpdate + { + CreatedAt = createdAt, + FinishReason = finishReason, + ModelId = modelId, + ResponseId = responseId, + Role = streamedRole, + }; + + foreach (var entry in functionCallInfos) + { + FunctionCallInfo fci = entry.Value; + if (!string.IsNullOrWhiteSpace(fci.Name)) + { + FunctionCallContent callContent = ParseCallContentFromJsonString( + fci.Arguments?.ToString() ?? 
string.Empty, + entry.Key, + fci.Name!); + responseUpdate.Contents.Add(callContent); + } + } + + yield return responseUpdate; } } @@ -292,7 +281,7 @@ private static ChatRole ToChatRole(global::Azure.AI.Inference.ChatRole role) => new(s); /// Converts an extensions options instance to an AzureAI options instance. - private ChatCompletionsOptions ToAzureAIOptions(IList chatContents, ChatOptions? options) + private ChatCompletionsOptions ToAzureAIOptions(IEnumerable chatContents, ChatOptions? options) { ChatCompletionsOptions result = new(ToAzureAIInferenceChatMessages(chatContents)) { @@ -420,7 +409,7 @@ private static ChatCompletionsToolDefinition ToAzureAIChatTool(AIFunction aiFunc } /// Converts an Extensions chat message enumerable to an AzureAI chat message enumerable. - private IEnumerable ToAzureAIInferenceChatMessages(IList inputs) + private IEnumerable ToAzureAIInferenceChatMessages(IEnumerable inputs) { // Maps all of the M.E.AI types to the corresponding AzureAI types. // Unrecognized or non-processable content is ignored. diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/ChatConversationEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/ChatConversationEvaluator.cs index 4113de75568..1d64fe89f4e 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/ChatConversationEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/ChatConversationEvaluator.cs @@ -35,13 +35,13 @@ public abstract class ChatConversationEvaluator : IEvaluator /// public async ValueTask EvaluateAsync( IEnumerable messages, - ChatMessage modelResponse, + ChatResponse modelResponse, ChatConfiguration? chatConfiguration = null, IEnumerable? 
additionalContext = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(modelResponse, nameof(modelResponse)); - _ = Throw.IfNull(chatConfiguration, nameof(chatConfiguration)); + _ = Throw.IfNull(modelResponse); + _ = Throw.IfNull(chatConfiguration); EvaluationResult result = InitializeResult(); @@ -211,8 +211,8 @@ await PerformEvaluationAsync( ChatConfiguration chatConfiguration, CancellationToken cancellationToken) { - _ = Throw.IfNull(message, nameof(message)); - _ = Throw.IfNull(chatConfiguration, nameof(chatConfiguration)); + _ = Throw.IfNull(message); + _ = Throw.IfNull(chatConfiguration); IEvaluationTokenCounter? tokenCounter = chatConfiguration.TokenCounter; if (tokenCounter is null) @@ -250,21 +250,26 @@ await PerformEvaluationAsync( } /// - /// Renders the supplied to a string that can be included as part of the evaluation + /// Renders the supplied to a string that can be included as part of the evaluation /// prompt that this uses. /// - /// - /// A message that is part of the conversation history for the response being evaluated and that is to be rendered + /// + /// Messages that are part of the conversation history for the response being evaluated and that is to be rendered /// as part of the evaluation prompt. /// /// A that can cancel the operation. /// - /// A string representation of the supplied that can be included as part of the + /// A string representation of the supplied that can be included as part of the /// evaluation prompt. /// - protected virtual ValueTask RenderAsync(ChatMessage message, CancellationToken cancellationToken) + /// + /// The default implementation considers only the last message of . + /// + protected virtual ValueTask RenderAsync(IEnumerable messages, CancellationToken cancellationToken) { - _ = Throw.IfNull(message, nameof(message)); + _ = Throw.IfNullOrEmpty(messages); + + ChatMessage message = messages.Last(); string? 
author = message.AuthorName; string role = message.Role.Value; @@ -296,7 +301,7 @@ protected virtual ValueTask RenderAsync(ChatMessage message, Cancellatio /// The evaluation prompt. protected abstract ValueTask RenderEvaluationPromptAsync( ChatMessage? userRequest, - ChatMessage modelResponse, + ChatResponse modelResponse, IEnumerable? includedHistory, IEnumerable? additionalContext, CancellationToken cancellationToken); diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/CoherenceEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/CoherenceEvaluator.cs index 8c31feb2dde..1a482a09eaf 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/CoherenceEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/CoherenceEvaluator.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI.Evaluation.Quality; @@ -31,16 +32,18 @@ public sealed class CoherenceEvaluator : SingleNumericMetricEvaluator /// protected override async ValueTask RenderEvaluationPromptAsync( ChatMessage? userRequest, - ChatMessage modelResponse, + ChatResponse modelResponse, IEnumerable? includedHistory, IEnumerable? additionalContext, CancellationToken cancellationToken) { - string renderedModelResponse = await RenderAsync(modelResponse, cancellationToken).ConfigureAwait(false); + _ = Throw.IfNull(modelResponse); + + string renderedModelResponse = await RenderAsync(modelResponse.Messages, cancellationToken).ConfigureAwait(false); string renderedUserRequest = userRequest is not null - ? await RenderAsync(userRequest, cancellationToken).ConfigureAwait(false) + ? 
await RenderAsync([userRequest], cancellationToken).ConfigureAwait(false) : string.Empty; string prompt = diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/EquivalenceEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/EquivalenceEvaluator.cs index ed482688e0c..d3fbb9ed56a 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/EquivalenceEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/EquivalenceEvaluator.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI.Evaluation.Quality; @@ -35,16 +36,18 @@ public sealed class EquivalenceEvaluator : SingleNumericMetricEvaluator /// protected override async ValueTask RenderEvaluationPromptAsync( ChatMessage? userRequest, - ChatMessage modelResponse, + ChatResponse modelResponse, IEnumerable? includedHistory, IEnumerable? additionalContext, CancellationToken cancellationToken) { - string renderedModelResponse = await RenderAsync(modelResponse, cancellationToken).ConfigureAwait(false); + _ = Throw.IfNull(modelResponse); + + string renderedModelResponse = await RenderAsync(modelResponse.Messages, cancellationToken).ConfigureAwait(false); string renderedUserRequest = userRequest is not null - ? await RenderAsync(userRequest, cancellationToken).ConfigureAwait(false) + ? 
await RenderAsync([userRequest], cancellationToken).ConfigureAwait(false) : string.Empty; string groundTruth; diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/FluencyEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/FluencyEvaluator.cs index 8c11cf0f0c0..4612baa77ed 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/FluencyEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/FluencyEvaluator.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI.Evaluation.Quality; @@ -31,16 +32,18 @@ public sealed class FluencyEvaluator : SingleNumericMetricEvaluator /// protected override async ValueTask RenderEvaluationPromptAsync( ChatMessage? userRequest, - ChatMessage modelResponse, + ChatResponse modelResponse, IEnumerable? includedHistory, IEnumerable? additionalContext, CancellationToken cancellationToken) { - string renderedModelResponse = await RenderAsync(modelResponse, cancellationToken).ConfigureAwait(false); + _ = Throw.IfNull(modelResponse); + + string renderedModelResponse = await RenderAsync(modelResponse.Messages, cancellationToken).ConfigureAwait(false); string renderedUserRequest = userRequest is not null - ? await RenderAsync(userRequest, cancellationToken).ConfigureAwait(false) + ? 
await RenderAsync([userRequest], cancellationToken).ConfigureAwait(false) : string.Empty; string prompt = diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/GroundednessEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/GroundednessEvaluator.cs index ddb3d522a44..b39d17103ab 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/GroundednessEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/GroundednessEvaluator.cs @@ -6,6 +6,7 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI.Evaluation.Quality; @@ -35,16 +36,18 @@ public sealed class GroundednessEvaluator : SingleNumericMetricEvaluator /// protected override async ValueTask RenderEvaluationPromptAsync( ChatMessage? userRequest, - ChatMessage modelResponse, + ChatResponse modelResponse, IEnumerable? includedHistory, IEnumerable? additionalContext, CancellationToken cancellationToken) { - string renderedModelResponse = await RenderAsync(modelResponse, cancellationToken).ConfigureAwait(false); + _ = Throw.IfNull(modelResponse); + + string renderedModelResponse = await RenderAsync(modelResponse.Messages, cancellationToken).ConfigureAwait(false); string renderedUserRequest = userRequest is not null - ? await RenderAsync(userRequest, cancellationToken).ConfigureAwait(false) + ? 
await RenderAsync([userRequest], cancellationToken).ConfigureAwait(false) : string.Empty; var builder = new StringBuilder(); @@ -61,7 +64,7 @@ userRequest is not null { foreach (ChatMessage message in includedHistory) { - _ = builder.Append(await RenderAsync(message, cancellationToken).ConfigureAwait(false)); + _ = builder.Append(await RenderAsync([message], cancellationToken).ConfigureAwait(false)); } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs index 24d96802542..6c2e843efad 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs @@ -12,6 +12,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI.Evaluation.Quality.Utilities; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI.Evaluation.Quality; @@ -75,16 +76,18 @@ protected override EvaluationResult InitializeResult() /// protected override async ValueTask RenderEvaluationPromptAsync( ChatMessage? userRequest, - ChatMessage modelResponse, + ChatResponse modelResponse, IEnumerable? includedHistory, IEnumerable? additionalContext, CancellationToken cancellationToken) { - string renderedModelResponse = await RenderAsync(modelResponse, cancellationToken).ConfigureAwait(false); + _ = Throw.IfNull(modelResponse); + + string renderedModelResponse = await RenderAsync(modelResponse.Messages, cancellationToken).ConfigureAwait(false); string renderedUserRequest = userRequest is not null - ? await RenderAsync(userRequest, cancellationToken).ConfigureAwait(false) + ? 
await RenderAsync([userRequest], cancellationToken).ConfigureAwait(false) : string.Empty; var builder = new StringBuilder(); @@ -92,7 +95,7 @@ userRequest is not null { foreach (ChatMessage message in includedHistory) { - _ = builder.Append(await RenderAsync(message, cancellationToken).ConfigureAwait(false)); + _ = builder.Append(await RenderAsync([message], cancellationToken).ConfigureAwait(false)); } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/SingleNumericMetricEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/SingleNumericMetricEvaluator.cs index f56e1e427fb..437dde3eb1e 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/SingleNumericMetricEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/SingleNumericMetricEvaluator.cs @@ -62,8 +62,8 @@ protected sealed override async ValueTask PerformEvaluationAsync( EvaluationResult result, CancellationToken cancellationToken) { - _ = Throw.IfNull(chatConfiguration, nameof(chatConfiguration)); - _ = Throw.IfNull(result, nameof(result)); + _ = Throw.IfNull(chatConfiguration); + _ = Throw.IfNull(result); ChatResponse evaluationResponse = await chatConfiguration.ChatClient.GetResponseAsync( diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Azure/Storage/AzureStorageResultStore.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Azure/Storage/AzureStorageResultStore.cs index 1ff06c07467..fe1a3d91a9c 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Azure/Storage/AzureStorageResultStore.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Azure/Storage/AzureStorageResultStore.cs @@ -156,7 +156,7 @@ public async ValueTask WriteResultsAsync( IEnumerable results, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(results, nameof(results)); + _ = Throw.IfNull(results); foreach (ScenarioRunResult result in results) { diff --git 
a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRun.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRun.cs index 8dc189767f2..3793afff05b 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRun.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRun.cs @@ -126,7 +126,7 @@ internal ScenarioRun( /// An containing one or more s. public async ValueTask EvaluateAsync( IEnumerable messages, - ChatMessage modelResponse, + ChatResponse modelResponse, IEnumerable? additionalContext = null, CancellationToken cancellationToken = default) { diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunExtensions.cs index 3c9a8fd5d44..3b723a2d258 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunExtensions.cs @@ -87,7 +87,36 @@ public static ValueTask EvaluateAsync( IEnumerable? additionalContext = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(scenarioRun, nameof(scenarioRun)); + _ = Throw.IfNull(scenarioRun); + + return scenarioRun.EvaluateAsync( + messages: [], + new ChatResponse(modelResponse), + additionalContext, + cancellationToken); + } + + /// + /// Evaluates the supplied and returns an + /// containing one or more s. + /// + /// The of which this evaluation is a part. + /// The response that is to be evaluated. + /// + /// Additional contextual information that the s included in this + /// may need to accurately evaluate the supplied . + /// + /// + /// A that can cancel the evaluation operation. + /// + /// An containing one or more s. + public static ValueTask EvaluateAsync( + this ScenarioRun scenarioRun, + ChatResponse modelResponse, + IEnumerable? 
additionalContext = null, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(scenarioRun); return scenarioRun.EvaluateAsync( messages: [], @@ -121,7 +150,41 @@ public static ValueTask EvaluateAsync( IEnumerable? additionalContext = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(scenarioRun, nameof(scenarioRun)); + _ = Throw.IfNull(scenarioRun); + + return scenarioRun.EvaluateAsync( + messages: [userRequest], + new ChatResponse(modelResponse), + additionalContext, + cancellationToken); + } + + /// + /// Evaluates the supplied and returns an + /// containing one or more s. + /// + /// The of which this evaluation is a part. + /// + /// The request that produced the that is to be evaluated. + /// + /// The response that is to be evaluated. + /// + /// Additional contextual information (beyond that which is available in ) that the + /// s included in this may need to accurately evaluate the + /// supplied . + /// + /// + /// A that can cancel the evaluation operation. + /// + /// An containing one or more s. + public static ValueTask EvaluateAsync( + this ScenarioRun scenarioRun, + ChatMessage userRequest, + ChatResponse modelResponse, + IEnumerable? additionalContext = null, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(scenarioRun); return scenarioRun.EvaluateAsync( messages: [userRequest], diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunResult.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunResult.cs index 22d9ff0167e..e1a4102e42c 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunResult.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunResult.cs @@ -17,7 +17,7 @@ namespace Microsoft.Extensions.AI.Evaluation.Reporting; /// Represents the results of a single execution of a particular iteration of a particular scenario under evaluation. 
/// In other words, represents the results of evaluating a /// and includes the that is produced when -/// +/// /// is invoked. /// /// @@ -44,7 +44,7 @@ public sealed class ScenarioRunResult( string executionName, DateTime creationTime, IList messages, - ChatMessage modelResponse, + ChatResponse modelResponse, EvaluationResult evaluationResult) { /// @@ -68,7 +68,7 @@ public ScenarioRunResult( string executionName, DateTime creationTime, IEnumerable messages, - ChatMessage modelResponse, + ChatResponse modelResponse, EvaluationResult evaluationResult) : this( scenarioName, @@ -115,7 +115,7 @@ public ScenarioRunResult( /// /// Gets or sets the response being evaluated in this . /// - public ChatMessage ModelResponse { get; set; } = modelResponse; + public ChatResponse ModelResponse { get; set; } = modelResponse; /// /// Gets or sets the for the corresponding to @@ -123,7 +123,7 @@ public ScenarioRunResult( /// /// /// This is the same that is returned when - /// + /// /// is invoked. /// public EvaluationResult EvaluationResult { get; set; } = evaluationResult; diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunResultExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunResultExtensions.cs index 8b82a7336cf..ecc3dcb80e8 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunResultExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunResultExtensions.cs @@ -30,7 +30,7 @@ public static bool ContainsDiagnostics( this ScenarioRunResult result, Func? 
predicate = null) { - _ = Throw.IfNull(result, nameof(result)); + _ = Throw.IfNull(result); return result.EvaluationResult.ContainsDiagnostics(predicate); } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/Storage/DiskBasedResultStore.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/Storage/DiskBasedResultStore.cs index 3ab62df05f8..422bcab2fb2 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/Storage/DiskBasedResultStore.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/Storage/DiskBasedResultStore.cs @@ -83,7 +83,7 @@ public async ValueTask WriteResultsAsync( IEnumerable results, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(results, nameof(results)); + _ = Throw.IfNull(results); foreach (ScenarioRunResult result in results) { diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation/CompositeEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation/CompositeEvaluator.cs index af14851ff92..7dc544c66c8 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation/CompositeEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation/CompositeEvaluator.cs @@ -42,7 +42,7 @@ public CompositeEvaluator(params IEvaluator[] evaluators) /// An enumeration of s that are to be composed. public CompositeEvaluator(IEnumerable evaluators) { - _ = Throw.IfNull(evaluators, nameof(evaluators)); + _ = Throw.IfNull(evaluators); var metricNames = new HashSet(); @@ -102,7 +102,7 @@ public CompositeEvaluator(IEnumerable evaluators) /// An containing one or more s. public async ValueTask EvaluateAsync( IEnumerable messages, - ChatMessage modelResponse, + ChatResponse modelResponse, ChatConfiguration? chatConfiguration = null, IEnumerable? 
additionalContext = null, CancellationToken cancellationToken = default) @@ -127,7 +127,7 @@ public async ValueTask EvaluateAsync( private IAsyncEnumerable EvaluateAndStreamResultsAsync( IEnumerable messages, - ChatMessage modelResponse, + ChatResponse modelResponse, ChatConfiguration? chatConfiguration = null, IEnumerable? additionalContext = null, CancellationToken cancellationToken = default) diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationMetricExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationMetricExtensions.cs index 9af7a5a2427..9b6f5e05104 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationMetricExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationMetricExtensions.cs @@ -31,7 +31,7 @@ public static bool ContainsDiagnostics( this EvaluationMetric metric, Func? predicate = null) { - _ = Throw.IfNull(metric, nameof(metric)); + _ = Throw.IfNull(metric); return predicate is null ? metric.Diagnostics.Any() : metric.Diagnostics.Any(predicate); } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationResult.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationResult.cs index 668422d349e..a49cfeb8463 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationResult.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationResult.cs @@ -50,7 +50,7 @@ public EvaluationResult(IDictionary metrics) /// public EvaluationResult(IEnumerable metrics) { - _ = Throw.IfNull(metrics, nameof(metrics)); + _ = Throw.IfNull(metrics); var metricsDictionary = new Dictionary(); diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationResultExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationResultExtensions.cs index 18b7181c7aa..30305327c8d 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationResultExtensions.cs +++ 
b/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationResultExtensions.cs @@ -22,7 +22,7 @@ public static class EvaluationResultExtensions /// The that is to be added. public static void AddDiagnosticToAllMetrics(this EvaluationResult result, EvaluationDiagnostic diagnostic) { - _ = Throw.IfNull(result, nameof(result)); + _ = Throw.IfNull(result); foreach (EvaluationMetric metric in result.Metrics.Values) { @@ -49,7 +49,7 @@ public static bool ContainsDiagnostics( this EvaluationResult result, Func? predicate = null) { - _ = Throw.IfNull(result, nameof(result)); + _ = Throw.IfNull(result); return result.Metrics.Values.Any(m => m.ContainsDiagnostics(predicate)); } @@ -69,8 +69,8 @@ public static void Interpret( this EvaluationResult result, Func interpretationProvider) { - _ = Throw.IfNull(result, nameof(result)); - _ = Throw.IfNull(interpretationProvider, nameof(interpretationProvider)); + _ = Throw.IfNull(result); + _ = Throw.IfNull(interpretationProvider); foreach (EvaluationMetric metric in result.Metrics.Values) { diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluatorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluatorExtensions.cs index efda72a5c39..cfef4121af4 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluatorExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluatorExtensions.cs @@ -133,7 +133,52 @@ public static ValueTask EvaluateAsync( IEnumerable? additionalContext = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(evaluator, nameof(evaluator)); + _ = Throw.IfNull(evaluator); + + return evaluator.EvaluateAsync( + messages: [], + new ChatResponse(modelResponse), + chatConfiguration, + additionalContext, + cancellationToken); + } + + /// + /// Evaluates the supplied and returns an + /// containing one or more s. + /// + /// + /// + /// The s of the s contained in the returned + /// should match . 
+ /// + /// + /// Also note that must not be omitted if the evaluation is performed using an + /// AI model. + /// + /// + /// The that should perform the evaluation. + /// The response that is to be evaluated. + /// + /// A that specifies the and the + /// that should be used if the evaluation is performed using an AI model. + /// + /// + /// Additional contextual information that the may need to accurately evaluate the + /// supplied . + /// + /// + /// A that can cancel the evaluation operation. + /// + /// An containing one or more s. + public static ValueTask EvaluateAsync( + this IEvaluator evaluator, + ChatResponse modelResponse, + ChatConfiguration? chatConfiguration = null, + IEnumerable? additionalContext = null, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(evaluator); return evaluator.EvaluateAsync( messages: [], @@ -182,7 +227,56 @@ public static ValueTask EvaluateAsync( IEnumerable? additionalContext = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(evaluator, nameof(evaluator)); + _ = Throw.IfNull(evaluator); + + return evaluator.EvaluateAsync( + messages: [userRequest], + new ChatResponse(modelResponse), + chatConfiguration, + additionalContext, + cancellationToken); + } + + /// + /// Evaluates the supplied and returns an + /// containing one or more s. + /// + /// + /// + /// The s of the s contained in the returned + /// should match . + /// + /// + /// Also note that must not be omitted if the evaluation is performed using an + /// AI model. + /// + /// + /// The that should perform the evaluation. + /// + /// The request that produced the that is to be evaluated. + /// + /// The response that is to be evaluated. + /// + /// A that specifies the and the + /// that should be used if the evaluation is performed using an AI model. + /// + /// + /// Additional contextual information (beyond that which is available in ) that the + /// may need to accurately evaluate the supplied . 
+ /// + /// + /// A that can cancel the evaluation operation. + /// + /// An containing one or more s. + public static ValueTask EvaluateAsync( + this IEvaluator evaluator, + ChatMessage userRequest, + ChatResponse modelResponse, + ChatConfiguration? chatConfiguration = null, + IEnumerable? additionalContext = null, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(evaluator); return evaluator.EvaluateAsync( messages: [userRequest], diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation/IEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation/IEvaluator.cs index d30e4b92df7..9528d4132d3 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation/IEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation/IEvaluator.cs @@ -50,7 +50,7 @@ public interface IEvaluator /// An containing one or more s. ValueTask EvaluateAsync( IEnumerable messages, - ChatMessage modelResponse, + ChatResponse modelResponse, ChatConfiguration? chatConfiguration = null, IEnumerable? 
additionalContext = null, CancellationToken cancellationToken = default); diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation/TokenizerExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation/TokenizerExtensions.cs index a9ef5e0c508..681d69ed1e1 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation/TokenizerExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation/TokenizerExtensions.cs @@ -37,7 +37,7 @@ public int CountTokens(string content) /// public static IEvaluationTokenCounter ToTokenCounter(this Tokenizer tokenizer, int inputTokenLimit) { - _ = Throw.IfNull(tokenizer, nameof(tokenizer)); + _ = Throw.IfNull(tokenizer); return new TokenCounter(tokenizer, inputTokenLimit); } diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs index abddf295daf..ed1448c8b69 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs @@ -80,13 +80,14 @@ public JsonSerializerOptions ToolCallJsonSerializerOptions } /// - public async Task GetResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + public async Task GetResponseAsync( + IEnumerable messages, ChatOptions? 
options = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); using var httpResponse = await _httpClient.PostAsJsonAsync( _apiChatEndpoint, - ToOllamaChatRequest(chatMessages, options, stream: false), + ToOllamaChatRequest(messages, options, stream: false), JsonContext.Default.OllamaChatRequest, cancellationToken).ConfigureAwait(false); @@ -104,11 +105,7 @@ public async Task GetResponseAsync(IList chatMessages throw new InvalidOperationException($"Ollama error: {response.Error}"); } - var responseMessage = FromOllamaMessage(response.Message!); - - chatMessages.Add(responseMessage); - - return new(responseMessage) + return new(FromOllamaMessage(response.Message!)) { CreatedAt = DateTimeOffset.TryParse(response.CreatedAt, CultureInfo.InvariantCulture, DateTimeStyles.None, out DateTimeOffset createdAt) ? createdAt : null, FinishReason = ToFinishReason(response), @@ -120,13 +117,13 @@ public async Task GetResponseAsync(IList chatMessages /// public async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? 
options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); using HttpRequestMessage request = new(HttpMethod.Post, _apiChatEndpoint) { - Content = JsonContent.Create(ToOllamaChatRequest(chatMessages, options, stream: true), JsonContext.Default.OllamaChatRequest) + Content = JsonContent.Create(ToOllamaChatRequest(messages, options, stream: true), JsonContext.Default.OllamaChatRequest) }; using var httpResponse = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false); @@ -143,65 +140,56 @@ public async IAsyncEnumerable GetStreamingResponseAsync( #endif .ConfigureAwait(false); - List updates = []; - try - { - using var streamReader = new StreamReader(httpResponseStream); + using var streamReader = new StreamReader(httpResponseStream); #if NET - while ((await streamReader.ReadLineAsync(cancellationToken).ConfigureAwait(false)) is { } line) + while ((await streamReader.ReadLineAsync(cancellationToken).ConfigureAwait(false)) is { } line) #else - while ((await streamReader.ReadLineAsync().ConfigureAwait(false)) is { } line) + while ((await streamReader.ReadLineAsync().ConfigureAwait(false)) is { } line) #endif + { + var chunk = JsonSerializer.Deserialize(line, JsonContext.Default.OllamaChatResponse); + if (chunk is null) { - var chunk = JsonSerializer.Deserialize(line, JsonContext.Default.OllamaChatResponse); - if (chunk is null) - { - continue; - } + continue; + } - string? modelId = chunk.Model ?? _metadata.ModelId; + string? modelId = chunk.Model ?? _metadata.ModelId; - ChatResponseUpdate update = new() - { - CreatedAt = DateTimeOffset.TryParse(chunk.CreatedAt, CultureInfo.InvariantCulture, DateTimeStyles.None, out DateTimeOffset createdAt) ? createdAt : null, - FinishReason = ToFinishReason(chunk), - ModelId = modelId, - ResponseId = chunk.CreatedAt, - Role = chunk.Message?.Role is not null ? 
new ChatRole(chunk.Message.Role) : null, - }; - - if (chunk.Message is { } message) + ChatResponseUpdate update = new() + { + CreatedAt = DateTimeOffset.TryParse(chunk.CreatedAt, CultureInfo.InvariantCulture, DateTimeStyles.None, out DateTimeOffset createdAt) ? createdAt : null, + FinishReason = ToFinishReason(chunk), + ModelId = modelId, + ResponseId = chunk.CreatedAt, + Role = chunk.Message?.Role is not null ? new ChatRole(chunk.Message.Role) : null, + }; + + if (chunk.Message is { } message) + { + if (message.ToolCalls is { Length: > 0 }) { - if (message.ToolCalls is { Length: > 0 }) + foreach (var toolCall in message.ToolCalls) { - foreach (var toolCall in message.ToolCalls) + if (toolCall.Function is { } function) { - if (toolCall.Function is { } function) - { - update.Contents.Add(ToFunctionCallContent(function)); - } + update.Contents.Add(ToFunctionCallContent(function)); } } - - // Equivalent rule to the nonstreaming case - if (message.Content?.Length > 0 || update.Contents.Count == 0) - { - update.Contents.Insert(0, new TextContent(message.Content)); - } } - if (ParseOllamaChatResponseUsage(chunk) is { } usage) + // Equivalent rule to the nonstreaming case + if (message.Content?.Length > 0 || update.Contents.Count == 0) { - update.Contents.Add(new UsageContent(usage)); + update.Contents.Insert(0, new TextContent(message.Content)); } + } - updates.Add(update); - yield return update; + if (ParseOllamaChatResponseUsage(chunk) is { } usage) + { + update.Contents.Add(new UsageContent(usage)); } - } - finally - { - chatMessages.AddRangeFromUpdates(updates); + + yield return update; } } @@ -305,12 +293,12 @@ private static FunctionCallContent ToFunctionCallContent(OllamaFunctionToolCall } } - private OllamaChatRequest ToOllamaChatRequest(IList chatMessages, ChatOptions? options, bool stream) + private OllamaChatRequest ToOllamaChatRequest(IEnumerable messages, ChatOptions? 
options, bool stream) { OllamaChatRequest request = new() { Format = ToOllamaChatResponseFormat(options?.ResponseFormat), - Messages = chatMessages.SelectMany(ToOllamaChatRequestMessages).ToArray(), + Messages = messages.SelectMany(ToOllamaChatRequestMessages).ToArray(), Model = options?.ModelId ?? _metadata.ModelId ?? string.Empty, Stream = stream, Tools = options?.ToolMode is not NoneChatToolMode && options?.Tools is { Count: > 0 } tools ? tools.OfType().Select(ToOllamaTool) : null, diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs index 6c251a46d26..1e5afb6d529 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs @@ -70,21 +70,21 @@ public OpenAIAssistantClient(AssistantClient assistantClient, string assistantId /// public Task GetResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) => - GetStreamingResponseAsync(chatMessages, options, cancellationToken).ToChatResponseAsync(cancellationToken); + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => + GetStreamingResponseAsync(messages, options, cancellationToken).ToChatResponseAsync(cancellationToken); /// public async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - // Extract necessary state from chatMessages and options. - (RunCreationOptions runOptions, List? toolResults) = CreateRunOptions(chatMessages, options); + // Extract necessary state from messages and options. + (RunCreationOptions runOptions, List? 
toolResults) = CreateRunOptions(messages, options); // Get the thread ID. string? threadId = options?.ChatThreadId ?? _threadId; if (threadId is null && toolResults is not null) { - Throw.ArgumentException(nameof(chatMessages), "No thread ID was provided, but chat messages includes tool results."); + Throw.ArgumentException(nameof(messages), "No thread ID was provided, but chat messages includes tool results."); } // Get the updates to process from the assistant. If we have any tool results, this means submitting those and ignoring @@ -112,73 +112,62 @@ public async IAsyncEnumerable GetStreamingResponseAsync( } // Process each update. - List responseUpdates = []; - try + string? responseId = null; + await foreach (var update in updates.ConfigureAwait(false)) { - string? responseId = null; - await foreach (var update in updates.ConfigureAwait(false)) + switch (update) { - switch (update) - { - case MessageContentUpdate mcu: - ChatResponseUpdate responseUpdate = new(mcu.Role == MessageRole.User ? ChatRole.User : ChatRole.Assistant, mcu.Text) - { - ChatThreadId = threadId, - RawRepresentation = mcu, - ResponseId = responseId, - }; - responseUpdates.Add(responseUpdate); - yield return responseUpdate; - break; + case MessageContentUpdate mcu: + yield return new(mcu.Role == MessageRole.User ? 
ChatRole.User : ChatRole.Assistant, mcu.Text) + { + ChatThreadId = threadId, + RawRepresentation = mcu, + ResponseId = responseId, + }; + break; - case ThreadUpdate tu when options is not null: - threadId ??= tu.Value.Id; - break; + case ThreadUpdate tu when options is not null: + threadId ??= tu.Value.Id; + break; - case RunUpdate ru: - threadId ??= ru.Value.ThreadId; - responseId ??= ru.Value.Id; + case RunUpdate ru: + threadId ??= ru.Value.ThreadId; + responseId ??= ru.Value.Id; - ChatResponseUpdate ruUpdate = new() - { - AuthorName = ru.Value.AssistantId, - ChatThreadId = threadId, - CreatedAt = ru.Value.CreatedAt, - ModelId = ru.Value.Model, - RawRepresentation = ru, - ResponseId = responseId, - Role = ChatRole.Assistant, - }; - - if (ru.Value.Usage is { } usage) - { - ruUpdate.Contents.Add(new UsageContent(new() - { - InputTokenCount = usage.InputTokenCount, - OutputTokenCount = usage.OutputTokenCount, - TotalTokenCount = usage.TotalTokenCount, - })); - } - - if (ru is RequiredActionUpdate rau && rau.ToolCallId is string toolCallId && rau.FunctionName is string functionName) + ChatResponseUpdate ruUpdate = new() + { + AuthorName = ru.Value.AssistantId, + ChatThreadId = threadId, + CreatedAt = ru.Value.CreatedAt, + ModelId = ru.Value.Model, + RawRepresentation = ru, + ResponseId = responseId, + Role = ChatRole.Assistant, + }; + + if (ru.Value.Usage is { } usage) + { + ruUpdate.Contents.Add(new UsageContent(new() { - ruUpdate.Contents.Add( - new FunctionCallContent( - JsonSerializer.Serialize(new[] { ru.Value.Id, toolCallId }, OpenAIJsonContext.Default.StringArray!), - functionName, - JsonSerializer.Deserialize(rau.FunctionArguments, OpenAIJsonContext.Default.IDictionaryStringObject)!)); - } - - responseUpdates.Add(ruUpdate); - yield return ruUpdate; - break; - } + InputTokenCount = usage.InputTokenCount, + OutputTokenCount = usage.OutputTokenCount, + TotalTokenCount = usage.TotalTokenCount, + })); + } + + if (ru is RequiredActionUpdate rau && rau.ToolCallId is 
string toolCallId && rau.FunctionName is string functionName) + { + ruUpdate.Contents.Add( + new FunctionCallContent( + JsonSerializer.Serialize(new[] { ru.Value.Id, toolCallId }, OpenAIJsonContext.Default.StringArray!), + functionName, + JsonSerializer.Deserialize(rau.FunctionArguments, OpenAIJsonContext.Default.IDictionaryStringObject)!)); + } + + yield return ruUpdate; + break; } } - finally - { - chatMessages.AddRangeFromUpdates(responseUpdates); - } } /// @@ -188,9 +177,10 @@ void IDisposable.Dispose() } /// Adds the provided messages to the thread and returns the options to use for the request. - private static (RunCreationOptions RunOptions, List? ToolResults) CreateRunOptions(IList chatMessages, ChatOptions? options) + private static (RunCreationOptions RunOptions, List? ToolResults) CreateRunOptions( + IEnumerable messages, ChatOptions? options) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); RunCreationOptions runOptions = new(); @@ -285,7 +275,7 @@ strictObj is bool strictValue ? // Handle ChatMessages. System messages are turned into additional instructions. StringBuilder? instructions = null; List? functionResults = null; - foreach (var chatMessage in chatMessages) + foreach (var chatMessage in messages) { List messageContents = []; diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs index fbb2c3fa4e1..7852a87c2e1 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs @@ -103,38 +103,32 @@ public JsonSerializerOptions ToolCallJsonSerializerOptions /// public async Task GetResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? 
options = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); - var openAIChatMessages = OpenAIModelMappers.ToOpenAIChatMessages(chatMessages, ToolCallJsonSerializerOptions); + var openAIChatMessages = OpenAIModelMappers.ToOpenAIChatMessages(messages, ToolCallJsonSerializerOptions); var openAIOptions = OpenAIModelMappers.ToOpenAIOptions(options); // Make the call to OpenAI. var response = await _chatClient.CompleteChatAsync(openAIChatMessages, openAIOptions, cancellationToken).ConfigureAwait(false); - ChatResponse chatResponse = OpenAIModelMappers.FromOpenAIChatCompletion(response.Value, options, openAIOptions); - foreach (ChatMessage message in chatResponse.Messages) - { - chatMessages.Add(message); - } - - return chatResponse; + return OpenAIModelMappers.FromOpenAIChatCompletion(response.Value, options, openAIOptions); } /// public IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); - var openAIChatMessages = OpenAIModelMappers.ToOpenAIChatMessages(chatMessages, ToolCallJsonSerializerOptions); + var openAIChatMessages = OpenAIModelMappers.ToOpenAIChatMessages(messages, ToolCallJsonSerializerOptions); var openAIOptions = OpenAIModelMappers.ToOpenAIOptions(options); // Make the call to OpenAI. 
var chatCompletionUpdates = _chatClient.CompleteChatStreamingAsync(openAIChatMessages, openAIOptions, cancellationToken); - return OpenAIModelMappers.FromOpenAIStreamingChatCompletionAsync(chatMessages, chatCompletionUpdates, cancellationToken); + return OpenAIModelMappers.FromOpenAIStreamingChatCompletionAsync(chatCompletionUpdates, cancellationToken); } /// diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMappers.StreamingChatCompletion.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMappers.StreamingChatCompletion.cs index d4858b3b70c..bfafbdf82b2 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMappers.StreamingChatCompletion.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMappers.StreamingChatCompletion.cs @@ -62,7 +62,6 @@ public static async IAsyncEnumerable ToOpenAIStre } public static async IAsyncEnumerable FromOpenAIStreamingChatCompletionAsync( - IList chatMessages, IAsyncEnumerable updates, [EnumeratorCancellation] CancellationToken cancellationToken = default) { @@ -75,142 +74,132 @@ public static async IAsyncEnumerable FromOpenAIStreamingChat string? modelId = null; string? fingerprint = null; - List responseUpdates = []; - try + // Process each update as it arrives + await foreach (StreamingChatCompletionUpdate update in updates.WithCancellation(cancellationToken).ConfigureAwait(false)) { - // Process each update as it arrives - await foreach (StreamingChatCompletionUpdate update in updates.WithCancellation(cancellationToken).ConfigureAwait(false)) + // The role and finish reason may arrive during any update, but once they've arrived, the same value should be the same for all subsequent updates. + streamedRole ??= update.Role is ChatMessageRole role ? FromOpenAIChatRole(role) : null; + finishReason ??= update.FinishReason is OpenAI.Chat.ChatFinishReason reason ? 
FromOpenAIFinishReason(reason) : null; + responseId ??= update.CompletionId; + createdAt ??= update.CreatedAt; + modelId ??= update.Model; + fingerprint ??= update.SystemFingerprint; + + // Create the response content object. + ChatResponseUpdate responseUpdate = new() { - // The role and finish reason may arrive during any update, but once they've arrived, the same value should be the same for all subsequent updates. - streamedRole ??= update.Role is ChatMessageRole role ? FromOpenAIChatRole(role) : null; - finishReason ??= update.FinishReason is OpenAI.Chat.ChatFinishReason reason ? FromOpenAIFinishReason(reason) : null; - responseId ??= update.CompletionId; - createdAt ??= update.CreatedAt; - modelId ??= update.Model; - fingerprint ??= update.SystemFingerprint; - - // Create the response content object. - ChatResponseUpdate responseUpdate = new() - { - ResponseId = update.CompletionId, - CreatedAt = update.CreatedAt, - FinishReason = finishReason, - ModelId = modelId, - RawRepresentation = update, - Role = streamedRole, - }; - - // Populate it with any additional metadata from the OpenAI object. - if (update.ContentTokenLogProbabilities is { Count: > 0 } contentTokenLogProbs) - { - (responseUpdate.AdditionalProperties ??= [])[nameof(update.ContentTokenLogProbabilities)] = contentTokenLogProbs; - } + ResponseId = update.CompletionId, + CreatedAt = update.CreatedAt, + FinishReason = finishReason, + ModelId = modelId, + RawRepresentation = update, + Role = streamedRole, + }; + + // Populate it with any additional metadata from the OpenAI object. 
+ if (update.ContentTokenLogProbabilities is { Count: > 0 } contentTokenLogProbs) + { + (responseUpdate.AdditionalProperties ??= [])[nameof(update.ContentTokenLogProbabilities)] = contentTokenLogProbs; + } - if (update.RefusalTokenLogProbabilities is { Count: > 0 } refusalTokenLogProbs) - { - (responseUpdate.AdditionalProperties ??= [])[nameof(update.RefusalTokenLogProbabilities)] = refusalTokenLogProbs; - } + if (update.RefusalTokenLogProbabilities is { Count: > 0 } refusalTokenLogProbs) + { + (responseUpdate.AdditionalProperties ??= [])[nameof(update.RefusalTokenLogProbabilities)] = refusalTokenLogProbs; + } - if (fingerprint is not null) - { - (responseUpdate.AdditionalProperties ??= [])[nameof(update.SystemFingerprint)] = fingerprint; - } + if (fingerprint is not null) + { + (responseUpdate.AdditionalProperties ??= [])[nameof(update.SystemFingerprint)] = fingerprint; + } - // Transfer over content update items. - if (update.ContentUpdate is { Count: > 0 }) + // Transfer over content update items. + if (update.ContentUpdate is { Count: > 0 }) + { + foreach (ChatMessageContentPart contentPart in update.ContentUpdate) { - foreach (ChatMessageContentPart contentPart in update.ContentUpdate) + if (ToAIContent(contentPart) is AIContent aiContent) { - if (ToAIContent(contentPart) is AIContent aiContent) - { - responseUpdate.Contents.Add(aiContent); - } + responseUpdate.Contents.Add(aiContent); } } + } - // Transfer over refusal updates. - if (update.RefusalUpdate is not null) - { - _ = (refusal ??= new()).Append(update.RefusalUpdate); - } + // Transfer over refusal updates. + if (update.RefusalUpdate is not null) + { + _ = (refusal ??= new()).Append(update.RefusalUpdate); + } - // Transfer over tool call updates. - if (update.ToolCallUpdates is { Count: > 0 } toolCallUpdates) + // Transfer over tool call updates. 
+ if (update.ToolCallUpdates is { Count: > 0 } toolCallUpdates) + { + foreach (StreamingChatToolCallUpdate toolCallUpdate in toolCallUpdates) { - foreach (StreamingChatToolCallUpdate toolCallUpdate in toolCallUpdates) + functionCallInfos ??= []; + if (!functionCallInfos.TryGetValue(toolCallUpdate.Index, out FunctionCallInfo? existing)) { - functionCallInfos ??= []; - if (!functionCallInfos.TryGetValue(toolCallUpdate.Index, out FunctionCallInfo? existing)) - { - functionCallInfos[toolCallUpdate.Index] = existing = new(); - } - - existing.CallId ??= toolCallUpdate.ToolCallId; - existing.Name ??= toolCallUpdate.FunctionName; - if (toolCallUpdate.FunctionArgumentsUpdate is { } argUpdate && !argUpdate.ToMemory().IsEmpty) - { - _ = (existing.Arguments ??= new()).Append(argUpdate.ToString()); - } + functionCallInfos[toolCallUpdate.Index] = existing = new(); } - } - // Transfer over usage updates. - if (update.Usage is ChatTokenUsage tokenUsage) - { - var usageDetails = FromOpenAIUsage(tokenUsage); - responseUpdate.Contents.Add(new UsageContent(usageDetails)); + existing.CallId ??= toolCallUpdate.ToolCallId; + existing.Name ??= toolCallUpdate.FunctionName; + if (toolCallUpdate.FunctionArgumentsUpdate is { } argUpdate && !argUpdate.ToMemory().IsEmpty) + { + _ = (existing.Arguments ??= new()).Append(argUpdate.ToString()); + } } - - // Now yield the item. - responseUpdates.Add(responseUpdate); - yield return responseUpdate; } - // Now that we've received all updates, combine any for function calls into a single item to yield. - if (functionCallInfos is not null) + // Transfer over usage updates. 
+ if (update.Usage is ChatTokenUsage tokenUsage) { - ChatResponseUpdate responseUpdate = new() - { - ResponseId = responseId, - CreatedAt = createdAt, - FinishReason = finishReason, - ModelId = modelId, - Role = streamedRole, - }; - - foreach (var entry in functionCallInfos) - { - FunctionCallInfo fci = entry.Value; - if (!string.IsNullOrWhiteSpace(fci.Name)) - { - var callContent = ParseCallContentFromJsonString( - fci.Arguments?.ToString() ?? string.Empty, - fci.CallId!, - fci.Name!); - responseUpdate.Contents.Add(callContent); - } - } + var usageDetails = FromOpenAIUsage(tokenUsage); + responseUpdate.Contents.Add(new UsageContent(usageDetails)); + } - // Refusals are about the model not following the schema for tool calls. As such, if we have any refusal, - // add it to this function calling item. - if (refusal is not null) - { - (responseUpdate.AdditionalProperties ??= [])[nameof(ChatMessageContentPart.Refusal)] = refusal.ToString(); - } + // Now yield the item. + yield return responseUpdate; + } - // Propagate additional relevant metadata. - if (fingerprint is not null) + // Now that we've received all updates, combine any for function calls into a single item to yield. + if (functionCallInfos is not null) + { + ChatResponseUpdate responseUpdate = new() + { + ResponseId = responseId, + CreatedAt = createdAt, + FinishReason = finishReason, + ModelId = modelId, + Role = streamedRole, + }; + + foreach (var entry in functionCallInfos) + { + FunctionCallInfo fci = entry.Value; + if (!string.IsNullOrWhiteSpace(fci.Name)) { - (responseUpdate.AdditionalProperties ??= [])[nameof(ChatCompletion.SystemFingerprint)] = fingerprint; + var callContent = ParseCallContentFromJsonString( + fci.Arguments?.ToString() ?? string.Empty, + fci.CallId!, + fci.Name!); + responseUpdate.Contents.Add(callContent); } + } - responseUpdates.Add(responseUpdate); - yield return responseUpdate; + // Refusals are about the model not following the schema for tool calls. 
As such, if we have any refusal, + // add it to this function calling item. + if (refusal is not null) + { + (responseUpdate.AdditionalProperties ??= [])[nameof(ChatMessageContentPart.Refusal)] = refusal.ToString(); } - } - finally - { - chatMessages.AddRangeFromUpdates(responseUpdates); + + // Propagate additional relevant metadata. + if (fingerprint is not null) + { + (responseUpdate.AdditionalProperties ??= [])[nameof(ChatCompletion.SystemFingerprint)] = fingerprint; + } + + yield return responseUpdate; } } } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs index 8063f914764..5e3032fee37 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs @@ -17,7 +17,7 @@ namespace Microsoft.Extensions.AI; internal sealed class AnonymousDelegatingChatClient : DelegatingChatClient { /// The delegate to use as the implementation of . - private readonly Func, ChatOptions?, IChatClient, CancellationToken, Task>? _getResponseFunc; + private readonly Func, ChatOptions?, IChatClient, CancellationToken, Task>? _getResponseFunc; /// The delegate to use as the implementation of . /// @@ -25,10 +25,10 @@ internal sealed class AnonymousDelegatingChatClient : DelegatingChatClient /// will be invoked with the same arguments as the method itself, along with a reference to the inner client. /// When , will delegate directly to the inner client. /// - private readonly Func, ChatOptions?, IChatClient, CancellationToken, IAsyncEnumerable>? _getStreamingResponseFunc; + private readonly Func, ChatOptions?, IChatClient, CancellationToken, IAsyncEnumerable>? _getStreamingResponseFunc; /// The delegate to use as the implementation of both and . 
- private readonly Func, ChatOptions?, Func, ChatOptions?, CancellationToken, Task>, CancellationToken, Task>? _sharedFunc; + private readonly Func, ChatOptions?, Func, ChatOptions?, CancellationToken, Task>, CancellationToken, Task>? _sharedFunc; /// /// Initializes a new instance of the class. @@ -47,7 +47,7 @@ internal sealed class AnonymousDelegatingChatClient : DelegatingChatClient /// is . public AnonymousDelegatingChatClient( IChatClient innerClient, - Func, ChatOptions?, Func, ChatOptions?, CancellationToken, Task>, CancellationToken, Task> sharedFunc) + Func, ChatOptions?, Func, ChatOptions?, CancellationToken, Task>, CancellationToken, Task> sharedFunc) : base(innerClient) { _ = Throw.IfNull(sharedFunc); @@ -73,8 +73,8 @@ public AnonymousDelegatingChatClient( /// Both and are . public AnonymousDelegatingChatClient( IChatClient innerClient, - Func, ChatOptions?, IChatClient, CancellationToken, Task>? getResponseFunc, - Func, ChatOptions?, IChatClient, CancellationToken, IAsyncEnumerable>? getStreamingResponseFunc) + Func, ChatOptions?, IChatClient, CancellationToken, Task>? getResponseFunc, + Func, ChatOptions?, IChatClient, CancellationToken, IAsyncEnumerable>? getStreamingResponseFunc) : base(innerClient) { ThrowIfBothDelegatesNull(getResponseFunc, getStreamingResponseFunc); @@ -85,20 +85,21 @@ public AnonymousDelegatingChatClient( /// public override Task GetResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); if (_sharedFunc is not null) { - return GetResponseViaSharedAsync(chatMessages, options, cancellationToken); + return GetResponseViaSharedAsync(messages, options, cancellationToken); - async Task GetResponseViaSharedAsync(IList chatMessages, ChatOptions? 
options, CancellationToken cancellationToken) + async Task GetResponseViaSharedAsync( + IEnumerable messages, ChatOptions? options, CancellationToken cancellationToken) { ChatResponse? response = null; - await _sharedFunc(chatMessages, options, async (chatMessages, options, cancellationToken) => + await _sharedFunc(messages, options, async (messages, options, cancellationToken) => { - response = await InnerClient.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + response = await InnerClient.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); }, cancellationToken).ConfigureAwait(false); if (response is null) @@ -111,21 +112,21 @@ await _sharedFunc(chatMessages, options, async (chatMessages, options, cancellat } else if (_getResponseFunc is not null) { - return _getResponseFunc(chatMessages, options, InnerClient, cancellationToken); + return _getResponseFunc(messages, options, InnerClient, cancellationToken); } else { Debug.Assert(_getStreamingResponseFunc is not null, "Expected non-null streaming delegate."); - return _getStreamingResponseFunc!(chatMessages, options, InnerClient, cancellationToken) + return _getStreamingResponseFunc!(messages, options, InnerClient, cancellationToken) .ToChatResponseAsync(cancellationToken); } } /// public override IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); if (_sharedFunc is not null) { @@ -138,9 +139,9 @@ public override IAsyncEnumerable GetStreamingResponseAsync( Exception? 
error = null; try { - await _sharedFunc(chatMessages, options, async (chatMessages, options, cancellationToken) => + await _sharedFunc(messages, options, async (messages, options, cancellationToken) => { - await foreach (var update in InnerClient.GetStreamingResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var update in InnerClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) { await updates.Writer.WriteAsync(update, cancellationToken).ConfigureAwait(false); } @@ -161,12 +162,12 @@ await _sharedFunc(chatMessages, options, async (chatMessages, options, cancellat } else if (_getStreamingResponseFunc is not null) { - return _getStreamingResponseFunc(chatMessages, options, InnerClient, cancellationToken); + return _getStreamingResponseFunc(messages, options, InnerClient, cancellationToken); } else { Debug.Assert(_getResponseFunc is not null, "Expected non-null non-streaming delegate."); - return GetStreamingResponseAsyncViaGetResponseAsync(_getResponseFunc!(chatMessages, options, InnerClient, cancellationToken)); + return GetStreamingResponseAsyncViaGetResponseAsync(_getResponseFunc!(messages, options, InnerClient, cancellationToken)); static async IAsyncEnumerable GetStreamingResponseAsyncViaGetResponseAsync(Task task) { diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs index 7563ecd891c..7d7b2b58403 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs @@ -45,25 +45,19 @@ protected CachingChatClient(IChatClient innerClient) public bool CoalesceStreamingUpdates { get; set; } = true; /// - public override async Task GetResponseAsync(IList chatMessages, ChatOptions? 
options = null, CancellationToken cancellationToken = default) + public override async Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); // We're only storing the final result, not the in-flight task, so that we can avoid caching failures // or having problems when one of the callers cancels but others don't. This has the drawback that // concurrent callers might trigger duplicate requests, but that's acceptable. - var cacheKey = GetCacheKey(_boxedFalse, chatMessages, options); + var cacheKey = GetCacheKey(_boxedFalse, messages, options); - if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } result) + if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is not { } result) { - foreach (ChatMessage message in result.Messages) - { - chatMessages.Add(message); - } - } - else - { - result = await base.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + result = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); await WriteCacheAsync(cacheKey, result, cancellationToken).ConfigureAwait(false); } @@ -72,9 +66,9 @@ public override async Task GetResponseAsync(IList cha /// public override async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); if (CoalesceStreamingUpdates) { @@ -82,7 +76,7 @@ public override async IAsyncEnumerable GetStreamingResponseA // we make a streaming request, yielding those results, but then convert those into a non-streaming // result and cache it. 
When we get a cache hit, we yield the non-streaming result as a streaming one. - var cacheKey = GetCacheKey(_boxedTrue, chatMessages, options); + var cacheKey = GetCacheKey(_boxedTrue, messages, options); if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } chatResponse) { // Yield all of the cached items. @@ -90,17 +84,12 @@ public override async IAsyncEnumerable GetStreamingResponseA { yield return chunk; } - - foreach (ChatMessage message in chatResponse.Messages) - { - chatMessages.Add(message); - } } else { // Yield and store all of the items. List capturedItems = []; - await foreach (var chunk in base.GetStreamingResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) { capturedItems.Add(chunk); yield return chunk; @@ -112,7 +101,7 @@ public override async IAsyncEnumerable GetStreamingResponseA } else { - var cacheKey = GetCacheKey(_boxedTrue, chatMessages, options); + var cacheKey = GetCacheKey(_boxedTrue, messages, options); if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks) { // Yield all of the cached items. @@ -122,17 +111,12 @@ public override async IAsyncEnumerable GetStreamingResponseA chatThreadId ??= chunk.ChatThreadId; yield return chunk; } - - if (chatThreadId is null) - { - chatMessages.AddRangeFromUpdates(existingChunks); - } } else { // Yield and store all of the items. List capturedItems = []; - await foreach (var chunk in base.GetStreamingResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) { capturedItems.Add(chunk); yield return chunk; @@ -151,7 +135,7 @@ public override async IAsyncEnumerable GetStreamingResponseA /// /// Returns a previously cached , if available. 
- /// This is used when there is a call to . + /// This is used when there is a call to . /// /// The cache key. /// The to monitor for cancellation requests. @@ -161,7 +145,7 @@ public override async IAsyncEnumerable GetStreamingResponseA /// /// Returns a previously cached list of values, if available. - /// This is used when there is a call to . + /// This is used when there is a call to . /// /// The cache key. /// The to monitor for cancellation requests. @@ -171,7 +155,7 @@ public override async IAsyncEnumerable GetStreamingResponseA /// /// Stores a in the underlying cache. - /// This is used when there is a call to . + /// This is used when there is a call to . /// /// The cache key. /// The to be stored. @@ -183,7 +167,7 @@ public override async IAsyncEnumerable GetStreamingResponseA /// /// Stores a list of values in the underlying cache. - /// This is used when there is a call to . + /// This is used when there is a call to . /// /// The cache key. /// The to be stored. diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs index 5ecf6403d78..23fba1e0abd 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs @@ -99,7 +99,7 @@ public ChatClientBuilder Use(Func cl /// need to interact with the results of the operation, which will come from the inner client. /// /// is . - public ChatClientBuilder Use(Func, ChatOptions?, Func, ChatOptions?, CancellationToken, Task>, CancellationToken, Task> sharedFunc) + public ChatClientBuilder Use(Func, ChatOptions?, Func, ChatOptions?, CancellationToken, Task>, CancellationToken, Task> sharedFunc) { _ = Throw.IfNull(sharedFunc); @@ -133,8 +133,8 @@ public ChatClientBuilder Use(Func, ChatOptions?, Func /// Both and are . 
public ChatClientBuilder Use( - Func, ChatOptions?, IChatClient, CancellationToken, Task>? getResponseFunc, - Func, ChatOptions?, IChatClient, CancellationToken, IAsyncEnumerable>? getStreamingResponseFunc) + Func, ChatOptions?, IChatClient, CancellationToken, Task>? getResponseFunc, + Func, ChatOptions?, IChatClient, CancellationToken, IAsyncEnumerable>? getStreamingResponseFunc) { AnonymousDelegatingChatClient.ThrowIfBothDelegatesNull(getResponseFunc, getStreamingResponseFunc); diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs index 5932f92de24..59bd70eefc6 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs @@ -11,6 +11,8 @@ using System.Threading.Tasks; using Microsoft.Shared.Diagnostics; +#pragma warning disable SA1118 // Parameter should not span multiple lines + namespace Microsoft.Extensions.AI; /// @@ -27,7 +29,7 @@ public static class ChatClientStructuredOutputExtensions /// Sends chat messages, requesting a response matching the type . /// The . - /// The chat content to send. + /// The chat content to send. /// The chat options to configure the request. /// /// Optionally specifies whether to set a JSON schema on the . @@ -39,11 +41,11 @@ public static class ChatClientStructuredOutputExtensions /// The type of structured output to request. public static Task> GetResponseAsync( this IChatClient chatClient, - IList chatMessages, + IEnumerable messages, ChatOptions? options = null, bool? 
useNativeJsonSchema = null, CancellationToken cancellationToken = default) => - GetResponseAsync(chatClient, chatMessages, AIJsonUtilities.DefaultOptions, options, useNativeJsonSchema, cancellationToken); + GetResponseAsync(chatClient, messages, AIJsonUtilities.DefaultOptions, options, useNativeJsonSchema, cancellationToken); /// Sends a user chat text message, requesting a response matching the type . /// The . @@ -131,7 +133,7 @@ public static Task> GetResponseAsync( /// Sends chat messages, requesting a response matching the type . /// The . - /// The chat content to send. + /// The chat content to send. /// The JSON serialization options to use. /// The chat options to configure the request. /// @@ -143,18 +145,18 @@ public static Task> GetResponseAsync( /// The response messages generated by the client. /// The type of structured output to request. /// is . - /// is . + /// is . /// is . public static async Task> GetResponseAsync( this IChatClient chatClient, - IList chatMessages, + IEnumerable messages, JsonSerializerOptions serializerOptions, ChatOptions? options = null, bool? useNativeJsonSchema = null, CancellationToken cancellationToken = default) { _ = Throw.IfNull(chatClient); - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); _ = Throw.IfNull(serializerOptions); serializerOptions.MakeReadOnly(); @@ -188,7 +190,7 @@ public static async Task> GetResponseAsync( } ChatMessage? promptAugmentation = null; - options = (options ?? new()).Clone(); + options = options is not null ? options.Clone() : new(); // Currently there's no way for the inner IChatClient to specify whether structured output // is supported, so we always default to false. 
In the future, some mechanism of declaring @@ -207,30 +209,18 @@ public static async Task> GetResponseAsync( options.ResponseFormat = ChatResponseFormat.Json; // When not using native structured output, augment the chat messages with a schema prompt -#pragma warning disable SA1118 // Parameter should not span multiple lines promptAugmentation = new ChatMessage(ChatRole.User, $$""" Respond with a JSON value conforming to the following schema: ``` {{schema}} ``` """); -#pragma warning restore SA1118 // Parameter should not span multiple lines - chatMessages.Add(promptAugmentation); + messages = [.. messages, promptAugmentation]; } - try - { - var result = await chatClient.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); - return new ChatResponse(result, serializerOptions) { IsWrappedInObject = isWrappedInObject }; - } - finally - { - if (promptAugmentation is not null) - { - _ = chatMessages.Remove(promptAugmentation); - } - } + var result = await chatClient.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + return new ChatResponse(result, serializerOptions) { IsWrappedInObject = isWrappedInObject }; } private static bool SchemaRepresentsObject(JsonElement schemaElement) diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs index 551441139a7..5a5dfea06c3 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs @@ -34,16 +34,15 @@ public ConfigureOptionsChatClient(IChatClient innerClient, Action c } /// - public override async Task GetResponseAsync(IList chatMessages, ChatOptions? 
options = null, CancellationToken cancellationToken = default) - { - return await base.GetResponseAsync(chatMessages, Configure(options), cancellationToken).ConfigureAwait(false); - } + public override async Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => + await base.GetResponseAsync(messages, Configure(options), cancellationToken).ConfigureAwait(false); /// public override async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - await foreach (var update in base.GetStreamingResponseAsync(chatMessages, Configure(options), cancellationToken).ConfigureAwait(false)) + await foreach (var update in base.GetStreamingResponseAsync(messages, Configure(options), cancellationToken).ConfigureAwait(false)) { yield return update; } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs index 97e5f75caa5..8dca904ccd0 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs @@ -18,7 +18,7 @@ public sealed class FunctionInvocationContext private static readonly AIFunction _nopFunction = AIFunctionFactory.Create(() => { }, nameof(FunctionInvocationContext)); /// The chat contents associated with the operation that initiated this function call request. - private IList _chatMessages = Array.Empty(); + private IList _messages = Array.Empty(); /// The AI function to be invoked. 
private AIFunction _function = _nopFunction; @@ -39,10 +39,10 @@ public FunctionCallContent CallContent } /// Gets or sets the chat contents associated with the operation that initiated this function call request. - public IList ChatMessages + public IList Messages { - get => _chatMessages; - set => _chatMessages = Throw.IfNull(value); + get => _messages; + set => _messages = Throw.IfNull(value); } /// Gets or sets the chat options associated with the operation that initiated this function call request. diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index 509a9c24c61..43cb2019b1c 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -180,27 +180,29 @@ public int? MaximumIterationsPerRequest } /// - public override async Task GetResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + public override async Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); - _ = Throw.IfReadOnly(chatMessages); + _ = Throw.IfNull(messages); // A single request into this GetResponseAsync may result in multiple requests to the inner client. // Create an activity to group them together for better observability. using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient)); - IList originalChatMessages = chatMessages; - ChatResponse? response = null; - List? responseMessages = null; - UsageDetails? totalUsage = null; - List? functionCallContents = null; + IEnumerable originalMessages = messages; // the original messages, tracked for the rare case where we need to know what was originally provided + List? 
augmentedHistory = null; // the actual history of messages sent on turns other than the first + ChatResponse? response = null; // the response from the inner client, which is possibly modified and then eventually returned + List? responseMessages = null; // tracked list of messages, across multiple turns, to be used for the final response + UsageDetails? totalUsage = null; // tracked usage across all turns, to be used for the final response + List? functionCallContents = null; // function call contents that need responding to in the current turn + bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set for (int iteration = 0; ; iteration++) { functionCallContents?.Clear(); - // Make the call to the handler. - response = await base.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + // Make the call to the inner client. + response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); if (response is null) { throw new InvalidOperationException($"The inner {nameof(IChatClient)} returned a null {nameof(ChatResponse)}."); @@ -212,24 +214,14 @@ public override async Task GetResponseAsync(IList cha (!MaximumIterationsPerRequest.HasValue || iteration < MaximumIterationsPerRequest.GetValueOrDefault()) && CopyFunctionCalls(response.Messages, ref functionCallContents); - // In the common case where we make a request and there's no function calling work required, + // In a common case where we make a request and there's no function calling work required, // fast path out by just returning the original response. 
if (iteration == 0 && !requiresFunctionInvocation) { - Debug.Assert(originalChatMessages == chatMessages, - "Expected the history to be the original, such that there's no additional work to do to keep it up to date."); return response; } - // If chatMessages is different from originalChatMessages, we previously created a different history - // in order to avoid sending state back to an inner client that was already tracking it. But we still - // need that original history to contain all the state. So copy it over if necessary. - if (chatMessages != originalChatMessages) - { - AddRange(originalChatMessages, response.Messages); - } - - // Track aggregatable details from the response. + // Track aggregatable details from the response, including all of the response messages and usage details. (responseMessages ??= []).AddRange(response.Messages); if (response.Usage is not null) { @@ -244,50 +236,21 @@ public override async Task GetResponseAsync(IList cha } // If there are no tools to call, or for any other reason we should stop, we're done. + // Break out of the loop and allow the handling at the end to configure the response + // with aggregated data from previous requests. if (!requiresFunctionInvocation) { - // If this is the first request, we can just return the response, as we don't need to - // incorporate any information from previous requests. - if (iteration == 0) - { - return response; - } - - // Otherwise, break out of the loop and allow the handling at the end to configure - // the response with aggregated data from previous requests. break; } - // If the response indicates the inner client is tracking the history, clear it to avoid re-sending the state. - if (response.ChatThreadId is not null) - { - if (chatMessages == originalChatMessages) - { - chatMessages = []; - } - else - { - chatMessages.Clear(); - } - } - else if (chatMessages != originalChatMessages) - { - // This should be a very rare case. 
In a previous iteration, we got back a non-null - // chatThreadId, so we forked chatMessages. But now, we got back a null chatThreadId, - // and chatMessages is no longer the full history. Thankfully, we've been keeping - // originalChatMessages up to date; we can just switch back to use it. - chatMessages = originalChatMessages; - } + // Prepare the history for the next iteration. + FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId); - // Add the responses from the function calls into the history. - var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options!, functionCallContents!, iteration, cancellationToken).ConfigureAwait(false); + // Add the responses from the function calls into the augmented history and also into the tracked + // list of response messages. + var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, cancellationToken).ConfigureAwait(false); responseMessages.AddRange(modeAndMessages.MessagesAdded); - if (chatMessages != originalChatMessages) - { - AddRange(originalChatMessages, modeAndMessages.MessagesAdded); - } - if (UpdateOptionsForMode(modeAndMessages.Mode, ref options!, response.ChatThreadId)) { // Terminate @@ -304,49 +267,41 @@ public override async Task GetResponseAsync(IList cha /// public override async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); - _ = Throw.IfReadOnly(chatMessages); + _ = Throw.IfNull(messages); // A single request into this GetStreamingResponseAsync may result in multiple requests to the inner client. // Create an activity to group them together for better observability. using Activity? 
activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient)); - List? functionCallContents = null; - IList originalChatMessages = chatMessages; + IEnumerable originalMessages = messages; // the original messages, tracked for the rare case where we need to know what was originally provided + List? augmentedHistory = null; // the actual history of messages sent on turns other than the first + List? functionCallContents = null; // function call contents that need responding to in the current turn + List? responseMessages = null; // tracked list of messages, across multiple turns, to be used in fallback cases to reconstitute history + bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set + List updates = []; // updates from the current response + for (int iteration = 0; ; iteration++) { + updates.Clear(); functionCallContents?.Clear(); - string? chatThreadId = null; - await foreach (var update in base.GetStreamingResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) { if (update is null) { throw new InvalidOperationException($"The inner {nameof(IChatClient)} streamed a null {nameof(ChatResponseUpdate)}."); } - chatThreadId ??= update.ChatThreadId; + updates.Add(update); + _ = CopyFunctionCalls(update.Contents, ref functionCallContents); yield return update; Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 } - // Make sure that any of the response messages that were added to the chat history also get - // added to the original history if it's different. 
- if (chatMessages != originalChatMessages) - { - // If chatThreadId was null previously, then we would have added any function result content into - // the original chat messages, passed those chat messages to GetStreamingResponseAsync, and it would - // have added all the new response messages into the original chat messages. But chatThreadId was - // non-null, hence we forked chatMessages. chatMessages then included only the function result content - // and should now include that function result content plus the response messages. None of that is - // in the original, so we can just add everything from chatMessages into the original. - AddRange(originalChatMessages, chatMessages); - } - // If there are no tools to call, or for any other reason we should stop, return the response. if (functionCallContents is not { Count: > 0 } || options?.Tools is not { Count: > 0 } || @@ -355,33 +310,19 @@ public override async IAsyncEnumerable GetStreamingResponseA break; } - // If the response indicates the inner client is tracking the history, clear it to avoid re-sending the state. - // In that case, we also avoid touching the user's history, so that we don't need to clear it. - if (chatThreadId is not null) - { - if (chatMessages == originalChatMessages) - { - chatMessages = []; - } - else - { - chatMessages.Clear(); - } - } - else if (chatMessages != originalChatMessages) - { - // This should be a very rare case. In a previous iteration, we got back a non-null - // chatThreadId, so we forked chatMessages. But now, we got back a null chatThreadId, - // and chatMessages is no longer the full history. Thankfully, we've been keeping - // originalChatMessages up to date; we can just switch back to use it. - chatMessages = originalChatMessages; - } + // Reconsistitue a response from the response updates. + ChatResponse response = updates.ToChatResponse(); + (responseMessages ??= []).AddRange(response.Messages); + + // Prepare the history for the next iteration. 
+ FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId); // Process all of the functions, adding their results into the history. - var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); + var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); + responseMessages.AddRange(modeAndMessages.MessagesAdded); - // Stream any generated function results. These are already part of the history, - // but we stream them out for informational purposes. + // Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages + // includes all activities, including generated function results. string toolResponseId = Guid.NewGuid().ToString("N"); foreach (var message in modeAndMessages.MessagesAdded) { @@ -389,7 +330,7 @@ public override async IAsyncEnumerable GetStreamingResponseA { AdditionalProperties = message.AdditionalProperties, AuthorName = message.AuthorName, - ChatThreadId = chatThreadId, + ChatThreadId = response.ChatThreadId, CreatedAt = DateTimeOffset.UtcNow, Contents = message.Contents, RawRepresentation = message.RawRepresentation, @@ -401,7 +342,7 @@ public override async IAsyncEnumerable GetStreamingResponseA Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 } - if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, chatThreadId)) + if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, response.ChatThreadId)) { // Terminate yield break; @@ -409,6 +350,68 @@ public override async IAsyncEnumerable GetStreamingResponseA } } + /// Prepares the various chat message lists after a response from the inner client and before invoking functions. + /// The original messages provided by the caller. 
+ /// The messages reference passed to the inner client. + /// The augmented history containing all the messages to be sent. + /// The most recent response being handled. + /// A list of all response messages received up until this point. + /// Whether the previous iteration's response had a thread id. + private static void FixupHistories( + IEnumerable originalMessages, + ref IEnumerable messages, + [NotNull] ref List? augmentedHistory, + ChatResponse response, + List allTurnsResponseMessages, + ref bool lastIterationHadThreadId) + { + // We're now going to need to augment the history with function result contents. + // That means we need a separate list to store the augmented history. + if (response.ChatThreadId is not null) + { + // The response indicates the inner client is tracking the history, so we don't want to send + // anything we've already sent or received. + if (augmentedHistory is not null) + { + augmentedHistory.Clear(); + } + else + { + augmentedHistory = []; + } + + lastIterationHadThreadId = true; + } + else if (lastIterationHadThreadId) + { + // In the very rare case where the inner client returned a response with a thread ID but then + // returned a subsequent response without one, we want to reconstitute the full history. To do that, + // we can populate the history with the original chat messages and then all of the response + // messages up until this point, which includes the most recent ones. + augmentedHistory ??= []; + augmentedHistory.Clear(); + augmentedHistory.AddRange(originalMessages); + augmentedHistory.AddRange(allTurnsResponseMessages); + + lastIterationHadThreadId = false; + } + else + { + // If augmentedHistory is already non-null, then we've already populated it with everything up + // until this point (except for the most recent response). If it's null, we need to seed it with + // the chat history provided by the caller. + augmentedHistory ??= originalMessages.ToList(); + + // Now add the most recent response messages. 
+ augmentedHistory.AddMessages(response); + + lastIterationHadThreadId = false; + } + + // Use the augmented history as the new set of messages to send. + messages = augmentedHistory; + } + /// Copies any from to . private static bool CopyFunctionCalls( IList messages, [NotNullWhen(true)] ref List? functionCalls) @@ -487,14 +490,14 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti /// /// Processes the function calls in the list. /// - /// The current chat contents, inclusive of the function call contents being processed. + /// The current chat contents, inclusive of the function call contents being processed. /// The options used for the response being processed. /// The function call contents representing the functions to be invoked. /// The iteration number of how many roundtrips have been made to the inner client. /// The to monitor for cancellation requests. /// A value indicating how the caller should proceed. private async Task<(ContinueMode Mode, IList MessagesAdded)> ProcessFunctionCallsAsync( - IList chatMessages, ChatOptions options, List functionCallContents, int iteration, CancellationToken cancellationToken) + List messages, ChatOptions options, List functionCallContents, int iteration, CancellationToken cancellationToken) { // We must add a response for every tool call, regardless of whether we successfully executed it or not. // If we successfully execute it, we'll add the result. If we don't, we'll add an error. 
@@ -505,10 +508,12 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti if (functionCallContents.Count == 1) { FunctionInvocationResult result = await ProcessFunctionCallAsync( - chatMessages, options, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false); - IList added = AddResponseMessages(chatMessages, [result]); + messages, options, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false); + IList added = CreateResponseMessages([result]); ThrowIfNoFunctionResultsAdded(added); + + messages.AddRange(added); return (result.ContinueMode, added); } else @@ -521,7 +526,7 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti results = await Task.WhenAll( from i in Enumerable.Range(0, functionCallContents.Count) select Task.Run(() => ProcessFunctionCallAsync( - chatMessages, options, functionCallContents, + messages, options, functionCallContents, iteration, i, cancellationToken))).ConfigureAwait(false); } else @@ -531,13 +536,17 @@ select Task.Run(() => ProcessFunctionCallAsync( for (int i = 0; i < results.Length; i++) { results[i] = await ProcessFunctionCallAsync( - chatMessages, options, functionCallContents, + messages, options, functionCallContents, iteration, i, cancellationToken).ConfigureAwait(false); } } ContinueMode continueMode = ContinueMode.Continue; - IList added = AddResponseMessages(chatMessages, results); + + IList added = CreateResponseMessages(results); + ThrowIfNoFunctionResultsAdded(added); + + messages.AddRange(added); foreach (FunctionInvocationResult fir in results) { if (fir.ContinueMode > continueMode) @@ -546,25 +555,23 @@ select Task.Run(() => ProcessFunctionCallAsync( } } - ThrowIfNoFunctionResultsAdded(added); return (continueMode, added); } } /// - /// Throws an exception if is empty due to an override of - /// not having added any messages. + /// Throws an exception if doesn't create any messages. 
/// - private void ThrowIfNoFunctionResultsAdded(IList chatMessages) + private void ThrowIfNoFunctionResultsAdded(IList? messages) { - if (chatMessages.Count == 0) + if (messages is null || messages.Count == 0) { - Throw.InvalidOperationException($"{GetType().Name}.{nameof(AddResponseMessages)} did not add any function result messages."); + Throw.InvalidOperationException($"{GetType().Name}.{nameof(CreateResponseMessages)} returned null or an empty collection of messages."); } } /// Processes the function call described in []. - /// The current chat contents, inclusive of the function call contents being processed. + /// The current chat contents, inclusive of the function call contents being processed. /// The options used for the response being processed. /// The function call contents representing all the functions being invoked. /// The iteration number of how many roundtrips have been made to the inner client. @@ -572,7 +579,7 @@ private void ThrowIfNoFunctionResultsAdded(IList chatMessages) /// The to monitor for cancellation requests. /// A value indicating how the caller should proceed. private async Task ProcessFunctionCallAsync( - IList chatMessages, ChatOptions options, List callContents, + List messages, ChatOptions options, List callContents, int iteration, int functionCallIndex, CancellationToken cancellationToken) { var callContent = callContents[functionCallIndex]; @@ -586,7 +593,7 @@ private async Task ProcessFunctionCallAsync( FunctionInvocationContext context = new() { - ChatMessages = chatMessages, + Messages = messages, Options = options, CallContent = callContent, Function = function, @@ -632,24 +639,19 @@ internal enum ContinueMode Terminate = 2, } - /// Adds one or more response messages for function invocation results. - /// The chat to which to add the one or more response messages. + /// Creates one or more response messages for function invocation results. /// Information about the function call invocations and results. 
- /// A list of all chat messages added to . - /// is . - protected virtual IList AddResponseMessages(IList chatMessages, ReadOnlySpan results) + /// A list of all chat messages created from . + protected virtual IList CreateResponseMessages( + ReadOnlySpan results) { - _ = Throw.IfNull(chatMessages); - var contents = new List(results.Length); for (int i = 0; i < results.Length; i++) { contents.Add(CreateFunctionResultContent(results[i])); } - ChatMessage message = new(ChatRole.Tool, contents); - chatMessages.Add(message); - return [message]; + return [new(ChatRole.Tool, contents)]; FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult result) { @@ -753,22 +755,6 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul return result; } - /// Adds all messages from into . - private static void AddRange(IList destination, IEnumerable source) - { - if (destination is List list) - { - list.AddRange(source); - } - else - { - foreach (var message in source) - { - destination.Add(message); - } - } - } - private static TimeSpan GetElapsedTime(long startingTimestamp) => #if NET Stopwatch.GetElapsedTime(startingTimestamp); diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs index b8e15718b78..51ca5a8f6d1 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs @@ -44,13 +44,13 @@ public JsonSerializerOptions JsonSerializerOptions /// public override async Task GetResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? 
options = null, CancellationToken cancellationToken = default) { if (_logger.IsEnabled(LogLevel.Debug)) { if (_logger.IsEnabled(LogLevel.Trace)) { - LogInvokedSensitive(nameof(GetResponseAsync), AsJson(chatMessages), AsJson(options), AsJson(this.GetService())); + LogInvokedSensitive(nameof(GetResponseAsync), AsJson(messages), AsJson(options), AsJson(this.GetService())); } else { @@ -60,7 +60,7 @@ public override async Task GetResponseAsync( try { - var response = await base.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + var response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); if (_logger.IsEnabled(LogLevel.Debug)) { @@ -90,13 +90,13 @@ public override async Task GetResponseAsync( /// public override async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? 
options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { if (_logger.IsEnabled(LogLevel.Debug)) { if (_logger.IsEnabled(LogLevel.Trace)) { - LogInvokedSensitive(nameof(GetStreamingResponseAsync), AsJson(chatMessages), AsJson(options), AsJson(this.GetService())); + LogInvokedSensitive(nameof(GetStreamingResponseAsync), AsJson(messages), AsJson(options), AsJson(this.GetService())); } else { @@ -107,7 +107,7 @@ public override async IAsyncEnumerable GetStreamingResponseA IAsyncEnumerator e; try { - e = base.GetStreamingResponseAsync(chatMessages, options, cancellationToken).GetAsyncEnumerator(cancellationToken); + e = base.GetStreamingResponseAsync(messages, options, cancellationToken).GetAsyncEnumerator(cancellationToken); } catch (OperationCanceledException) { @@ -173,8 +173,8 @@ public override async IAsyncEnumerable GetStreamingResponseA [LoggerMessage(LogLevel.Debug, "{MethodName} invoked.")] private partial void LogInvoked(string methodName); - [LoggerMessage(LogLevel.Trace, "{MethodName} invoked: {ChatMessages}. Options: {ChatOptions}. Metadata: {ChatClientMetadata}.")] - private partial void LogInvokedSensitive(string methodName, string chatMessages, string chatOptions, string chatClientMetadata); + [LoggerMessage(LogLevel.Trace, "{MethodName} invoked: {Messages}. Options: {ChatOptions}. 
Metadata: {ChatClientMetadata}.")] + private partial void LogInvokedSensitive(string methodName, string messages, string chatOptions, string chatClientMetadata); [LoggerMessage(LogLevel.Debug, "{MethodName} completed.")] private partial void LogCompleted(string methodName); diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs index 3a98afadacf..b3e4f8b86bf 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs @@ -121,22 +121,23 @@ protected override void Dispose(bool disposing) base.GetService(serviceType, serviceKey); /// - public override async Task GetResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + public override async Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); _jsonSerializerOptions.MakeReadOnly(); using Activity? activity = CreateAndConfigureActivity(options); Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; string? requestModelId = options?.ModelId ?? _modelId; - LogChatMessages(chatMessages); + LogChatMessages(messages); ChatResponse? response = null; Exception? error = null; try { - response = await base.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); return response; } catch (Exception ex) @@ -152,21 +153,21 @@ public override async Task GetResponseAsync(IList cha /// public override async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? 
options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); _jsonSerializerOptions.MakeReadOnly(); using Activity? activity = CreateAndConfigureActivity(options); Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; string? requestModelId = options?.ModelId ?? _modelId; - LogChatMessages(chatMessages); + LogChatMessages(messages); IAsyncEnumerable updates; try { - updates = base.GetStreamingResponseAsync(chatMessages, options, cancellationToken); + updates = base.GetStreamingResponseAsync(messages, options, cancellationToken); } catch (Exception ex) { diff --git a/src/Shared/Throw/Throw.cs b/src/Shared/Throw/Throw.cs index 393caaa7c0f..257a880e6ac 100644 --- a/src/Shared/Throw/Throw.cs +++ b/src/Shared/Throw/Throw.cs @@ -294,24 +294,6 @@ public static IEnumerable IfNullOrEmpty([NotNull] IEnumerable? argument return argument; } - /// - /// Throws an if the collection's - /// is . - /// - /// The collection to evaluate. - /// The name of the parameter being checked. - /// The type of objects in the collection. - /// The original value of . 
- public static ICollection IfReadOnly(ICollection argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") - { - if (argument.IsReadOnly) - { - ArgumentException(paramName, "Collection is read-only"); - } - - return argument; - } - #endregion #region Exceptions diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs index 4bc886ae580..04d686acebd 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs @@ -106,9 +106,9 @@ public async Task GetResponseAsync_CreatesTextMessageAsync() using TestChatClient client = new() { - GetResponseAsyncCallback = (chatMessages, options, cancellationToken) => + GetResponseAsyncCallback = (messages, options, cancellationToken) => { - ChatMessage m = Assert.Single(chatMessages); + ChatMessage m = Assert.Single(messages); Assert.Equal(ChatRole.User, m.Role); Assert.Equal("hello", m.Text); @@ -133,9 +133,9 @@ public async Task GetStreamingResponseAsync_CreatesTextMessageAsync() using TestChatClient client = new() { - GetStreamingResponseAsyncCallback = (chatMessages, options, cancellationToken) => + GetStreamingResponseAsyncCallback = (messages, options, cancellationToken) => { - ChatMessage m = Assert.Single(chatMessages); + ChatMessage m = Assert.Single(messages); Assert.Equal(ChatRole.User, m.Role); Assert.Equal("hello", m.Text); diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs index 08e46a52a86..7174d2a70c8 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs +++ 
b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs @@ -56,7 +56,7 @@ public void Constructor_RoleString_PropsRoundtrip(string? text) } [Fact] - public void Constructor_NullArgs_Valid() + public void Constructor_NullEmptyArgs_Valid() { ChatMessage message; @@ -72,7 +72,9 @@ public void Constructor_NullArgs_Valid() Assert.Empty(message.Text); Assert.Empty(message.Contents); - Assert.Throws(() => new ChatMessage(ChatRole.User, Array.Empty())); + message = new ChatMessage(ChatRole.User, Array.Empty()); + Assert.Empty(message.Text); + Assert.Empty(message.Contents); } [Theory] diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs index 1507d591a77..abe7660e0cf 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs @@ -12,10 +12,21 @@ namespace Microsoft.Extensions.AI; public class ChatResponseTests { [Fact] - public void Constructor_InvalidArgs_Throws() + public void Constructor_NullEmptyArgs_Valid() { - Assert.Throws("message", () => new ChatResponse((ChatMessage)null!)); - Assert.Throws("messages", () => new ChatResponse((List)null!)); + ChatResponse response; + + response = new(); + Assert.Empty(response.Messages); + Assert.Empty(response.Text); + + response = new((ChatMessage?)null); + Assert.Empty(response.Messages); + Assert.Empty(response.Text); + + response = new((IList?)null); + Assert.Empty(response.Messages); + Assert.Empty(response.Text); } [Fact] diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs index 226612bcff4..95f89a79141 100644 --- 
a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs @@ -17,20 +17,20 @@ public TestChatClient() public IServiceProvider? Services { get; set; } - public Func, ChatOptions?, CancellationToken, Task>? GetResponseAsyncCallback { get; set; } + public Func, ChatOptions?, CancellationToken, Task>? GetResponseAsyncCallback { get; set; } - public Func, ChatOptions?, CancellationToken, IAsyncEnumerable>? GetStreamingResponseAsyncCallback { get; set; } + public Func, ChatOptions?, CancellationToken, IAsyncEnumerable>? GetStreamingResponseAsyncCallback { get; set; } public Func GetServiceCallback { get; set; } private object? DefaultGetServiceCallback(Type serviceType, object? serviceKey) => serviceType is not null && serviceKey is null && serviceType.IsInstanceOfType(this) ? this : null; - public Task GetResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) - => GetResponseAsyncCallback!.Invoke(chatMessages, options, cancellationToken); + public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + => GetResponseAsyncCallback!.Invoke(messages, options, cancellationToken); - public IAsyncEnumerable GetStreamingResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) - => GetStreamingResponseAsyncCallback!.Invoke(chatMessages, options, cancellationToken); + public IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + => GetStreamingResponseAsyncCallback!.Invoke(messages, options, cancellationToken); public object? GetService(Type serviceType, object? 
serviceKey = null) => GetServiceCallback(serviceType, serviceKey); diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs index a89beabc97e..b8a68c913ed 100644 --- a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs @@ -152,11 +152,11 @@ public async Task BasicRequestResponse_NonStreaming(bool multiContent) using HttpClient httpClient = new(handler); using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); - List chatMessages = multiContent ? + List messages = multiContent ? [new ChatMessage(ChatRole.User, "hello".Select(c => (AIContent)new TextContent(c.ToString())).ToList())] : [new ChatMessage(ChatRole.User, "hello")]; - var response = await client.GetResponseAsync(chatMessages, new() + var response = await client.GetResponseAsync(messages, new() { MaxOutputTokens = 10, Temperature = 0.5f, @@ -224,12 +224,12 @@ public async Task BasicRequestResponse_Streaming(bool multiContent) using HttpClient httpClient = new(handler); using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); - List chatMessages = multiContent ? + List messages = multiContent ? 
[new ChatMessage(ChatRole.User, "hello".Select(c => (AIContent)new TextContent(c.ToString())).ToList())] : [new ChatMessage(ChatRole.User, "hello")]; List updates = []; - await foreach (var update in client.GetStreamingResponseAsync(chatMessages, new() + await foreach (var update in client.GetStreamingResponseAsync(messages, new() { MaxOutputTokens = 20, Temperature = 0.5f, diff --git a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/TestEvaluator.cs b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/TestEvaluator.cs index 8584ada0853..f853b3ac030 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/TestEvaluator.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/TestEvaluator.cs @@ -25,7 +25,7 @@ private ValueTask GetResultAsync() => async ValueTask IEvaluator.EvaluateAsync( IEnumerable messages, - ChatMessage modelResponse, + ChatResponse modelResponse, ChatConfiguration? chatConfiguration, IEnumerable? 
additionalContext, CancellationToken cancellationToken) diff --git a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Tests/ResultStoreTester.cs b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Tests/ResultStoreTester.cs index f68eb15380e..7ed0f31f3c5 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Tests/ResultStoreTester.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Tests/ResultStoreTester.cs @@ -32,7 +32,7 @@ private static ScenarioRunResult CreateTestResult(string scenarioName, string it executionName: executionName, creationTime: DateTime.UtcNow, messages: [new ChatMessage(ChatRole.User, "User prompt")], - modelResponse: new ChatMessage(ChatRole.Assistant, "LLM response"), + modelResponse: new ChatResponse(new ChatMessage(ChatRole.Assistant, "LLM response")), evaluationResult: new EvaluationResult(booleanMetric, numericMetric, stringMetric)); } diff --git a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Tests/ScenarioRunResultTests.cs b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Tests/ScenarioRunResultTests.cs index b522f797675..9418a5db359 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Tests/ScenarioRunResultTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Tests/ScenarioRunResultTests.cs @@ -36,7 +36,7 @@ public void SerializeScenarioRunResult() executionName: "Test Execution", creationTime: DateTime.UtcNow, messages: [new ChatMessage(ChatRole.User, "prompt")], - modelResponse: new ChatMessage(ChatRole.Assistant, "response"), + modelResponse: new ChatResponse(new ChatMessage(ChatRole.Assistant, "response")), evaluationResult: new EvaluationResult(booleanMetric, numericMetric, stringMetric, metricWithNoValue)); string json = JsonSerializer.Serialize(entry, SerializerContext.Default.ScenarioRunResult); @@ -48,7 +48,7 @@ public void SerializeScenarioRunResult() Assert.Equal(entry.ExecutionName, 
deserialized.ExecutionName); Assert.Equal(entry.CreationTime, deserialized.CreationTime); Assert.True(entry.Messages.SequenceEqual(deserialized.Messages, ChatMessageComparer.Instance)); - Assert.Equal(entry.ModelResponse, deserialized.ModelResponse, ChatMessageComparer.Instance); + Assert.Equal(entry.ModelResponse, deserialized.ModelResponse, ChatResponseComparer.Instance); ValidateEquivalence(entry.EvaluationResult, deserialized.EvaluationResult); } @@ -75,7 +75,7 @@ public void SerializeDatasetCompact() executionName: "Test Execution", creationTime: DateTime.UtcNow, messages: [new ChatMessage(ChatRole.User, "prompt")], - modelResponse: new ChatMessage(ChatRole.Assistant, "response"), + modelResponse: new ChatResponse(new ChatMessage(ChatRole.Assistant, "response")), evaluationResult: new EvaluationResult(booleanMetric, numericMetric, stringMetric, metricWithNoValue)); var dataset = new Dataset([entry], createdAt: DateTime.UtcNow, generatorVersion: "1.2.3.4"); @@ -89,7 +89,7 @@ public void SerializeDatasetCompact() Assert.Equal(entry.ExecutionName, deserialized.ScenarioRunResults[0].ExecutionName); Assert.Equal(entry.CreationTime, deserialized.ScenarioRunResults[0].CreationTime); Assert.True(entry.Messages.SequenceEqual(deserialized.ScenarioRunResults[0].Messages, ChatMessageComparer.Instance)); - Assert.Equal(entry.ModelResponse, deserialized.ScenarioRunResults[0].ModelResponse, ChatMessageComparer.Instance); + Assert.Equal(entry.ModelResponse, deserialized.ScenarioRunResults[0].ModelResponse, ChatResponseComparer.Instance); Assert.Single(deserialized.ScenarioRunResults); Assert.Equal(dataset.CreatedAt, deserialized.CreatedAt); @@ -155,7 +155,20 @@ public bool Equals(ChatMessage? x, ChatMessage? 
y) => x?.AuthorName == y?.AuthorName && x?.Role == y?.Role && x?.Text == y?.Text; public int GetHashCode(ChatMessage obj) - => obj.GetHashCode(); + => obj.Text.GetHashCode(); + } + + private class ChatResponseComparer : IEqualityComparer + { + public static ChatResponseComparer Instance { get; } = new ChatResponseComparer(); + + public bool Equals(ChatResponse? x, ChatResponse? y) + => + x is null ? y is null : + y is not null && x.Messages.SequenceEqual(y.Messages, ChatMessageComparer.Instance); + + public int GetHashCode(ChatResponse obj) + => obj.Text.GetHashCode(); } private class DiagnosticComparer : IEqualityComparer diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingChatClient.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingChatClient.cs index 853815ff033..c0045dc0f82 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingChatClient.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingChatClient.cs @@ -17,17 +17,17 @@ internal sealed class CallCountingChatClient(IChatClient innerClient) : Delegati public int CallCount => _callCount; public override Task GetResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { Interlocked.Increment(ref _callCount); - return base.GetResponseAsync(chatMessages, options, cancellationToken); + return base.GetResponseAsync(messages, options, cancellationToken); } public override IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? 
options = null, CancellationToken cancellationToken = default) { Interlocked.Increment(ref _callCount); - return base.GetStreamingResponseAsync(chatMessages, options, cancellationToken); + return base.GetStreamingResponseAsync(messages, options, cancellationToken); } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs index 7c2c8343f95..1b8b90f4a3a 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs @@ -7,6 +7,7 @@ using System.ComponentModel; using System.Linq; using System.Reflection; +using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Serialization; using System.Threading; @@ -39,13 +40,16 @@ internal sealed class PromptBasedFunctionCallingChatClient(IChatClient innerClie DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, }; - public override async Task GetResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + public override async Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { + List chatMessageList = [.. messages]; + // Our goal is to convert tools into a prompt describing them, then to detect tool calls in the // response and convert those into FunctionCallContent. 
if (options?.Tools is { Count: > 0 }) { - AddOrUpdateToolPrompt(chatMessages, options.Tools); + AddOrUpdateToolPrompt(chatMessageList, options.Tools); options = options.Clone(); options.Tools = null; @@ -58,7 +62,7 @@ public override async Task GetResponseAsync(IList cha // Since the point of this client is to avoid relying on the underlying model having // native tool call support, we have to replace any "tool" or "toolcall" messages with // "user" or "assistant" ones. - foreach (var message in chatMessages) + foreach (var message in chatMessageList) { for (var itemIndex = 0; itemIndex < message.Contents.Count; itemIndex++) { @@ -80,7 +84,7 @@ public override async Task GetResponseAsync(IList cha } } - var result = await base.GetResponseAsync(chatMessages, options, cancellationToken); + var result = await base.GetResponseAsync(chatMessageList, options, cancellationToken); if (result.Text is { } content && content.IndexOf("", StringComparison.Ordinal) is int startPos && startPos >= 0) @@ -131,6 +135,16 @@ public override async Task GetResponseAsync(IList cha return result; } + public override async IAsyncEnumerable GetStreamingResponseAsync( + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var response = await GetResponseAsync(messages, options, cancellationToken); + foreach (var update in response.ToChatResponseUpdates()) + { + yield return update; + } + } + private static void ParseArguments(IDictionary arguments) { // This is a simple implementation. 
A more robust answer is to use other schema information given by @@ -151,13 +165,13 @@ private static void ParseArguments(IDictionary arguments) } } - private static void AddOrUpdateToolPrompt(IList chatMessages, IList tools) + private static void AddOrUpdateToolPrompt(List messages, IList tools) { - var existingToolPrompt = chatMessages.FirstOrDefault(c => c.Text.StartsWith(MessageIntro, StringComparison.Ordinal) is true); + var existingToolPrompt = messages.FirstOrDefault(c => c.Text.StartsWith(MessageIntro, StringComparison.Ordinal) is true); if (existingToolPrompt is null) { existingToolPrompt = new ChatMessage(ChatRole.System, (string?)null); - chatMessages.Insert(0, existingToolPrompt); + messages.Insert(0, existingToolPrompt); } var toolDescriptorsJson = JsonSerializer.Serialize(tools.OfType().Select(ToToolDescriptor), _jsonOptions); diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs index ec7ca3c2cf0..99533e56f53 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs @@ -29,7 +29,7 @@ public async Task Reduction_LimitsMessagesBasedOnTokenLimit() { GetResponseAsyncCallback = (messages, options, cancellationToken) => { - Assert.Equal(2, messages.Count); + Assert.Equal(2, messages.Count()); Assert.Collection(messages, m => Assert.StartsWith("Golden retrievers are quite active", m.Text, StringComparison.Ordinal), m => Assert.StartsWith("Are they good with kids?", m.Text, StringComparison.Ordinal)); @@ -61,69 +61,57 @@ public async Task Reduction_LimitsMessagesBasedOnTokenLimit() public sealed class ReducingChatClient : DelegatingChatClient { private readonly IChatReducer _reducer; - private readonly bool _inPlace; /// Initializes a new instance of the class. /// The inner client. 
/// The reducer to be used by this instance. - /// - /// true if the should perform any modifications directly on the supplied list of messages; - /// false if it should instead create a new list when reduction is necessary. - /// - public ReducingChatClient(IChatClient innerClient, IChatReducer reducer, bool inPlace = false) + public ReducingChatClient(IChatClient innerClient, IChatReducer reducer) : base(innerClient) { _reducer = Throw.IfNull(reducer); - _inPlace = inPlace; } /// public override async Task GetResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { - chatMessages = await GetChatMessagesToPropagate(chatMessages, cancellationToken).ConfigureAwait(false); + messages = await _reducer.ReduceAsync(messages, cancellationToken).ConfigureAwait(false); - return await base.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + return await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); } /// public override async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? 
options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - chatMessages = await GetChatMessagesToPropagate(chatMessages, cancellationToken).ConfigureAwait(false); + messages = await _reducer.ReduceAsync(messages, cancellationToken).ConfigureAwait(false); - await foreach (var update in base.GetStreamingResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) { yield return update; } } - - /// Runs the reducer and gets the chat message list to forward to the inner client. - private async Task> GetChatMessagesToPropagate(IList chatMessages, CancellationToken cancellationToken) => - await _reducer.ReduceAsync(chatMessages, _inPlace, cancellationToken).ConfigureAwait(false) ?? - chatMessages; } /// Represents a reducer capable of shrinking the size of a list of chat messages. public interface IChatReducer { /// Reduces the size of a list of chat messages. - /// The messages. - /// true if the reducer should modify the provided list; false if a new list should be returned. + /// The messages. /// The to monitor for cancellation requests. The default is . /// The new list of messages, or null if no reduction need be performed or was true. - Task?> ReduceAsync(IList chatMessages, bool inPlace, CancellationToken cancellationToken); + Task> ReduceAsync(IEnumerable messages, CancellationToken cancellationToken); } /// Provides extensions for configuring instances. 
public static class ReducingChatClientExtensions { - public static ChatClientBuilder UseChatReducer(this ChatClientBuilder builder, IChatReducer reducer, bool inPlace = false) + public static ChatClientBuilder UseChatReducer(this ChatClientBuilder builder, IChatReducer reducer) { _ = Throw.IfNull(builder); _ = Throw.IfNull(reducer); - return builder.Use(innerClient => new ReducingChatClient(innerClient, reducer, inPlace)); + return builder.Use(innerClient => new ReducingChatClient(innerClient, reducer)); } } @@ -139,51 +127,29 @@ public TokenCountingChatReducer(Tokenizer tokenizer, int tokenLimit) _tokenLimit = Throw.IfLessThan(tokenLimit, 1); } - public async Task?> ReduceAsync(IList chatMessages, bool inPlace, CancellationToken cancellationToken) + public async Task> ReduceAsync( + IEnumerable messages, CancellationToken cancellationToken) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); + + List list = messages.ToList(); - if (chatMessages.Count > 1) + if (list.Count > 1) { - int totalCount = CountTokens(chatMessages[chatMessages.Count - 1]); + int totalCount = CountTokens(list[list.Count - 1]); - if (inPlace) - { - for (int i = chatMessages.Count - 2; i >= 0; i--) - { - totalCount += CountTokens(chatMessages[i]); - if (totalCount > _tokenLimit) - { - if (chatMessages is List list) - { - list.RemoveRange(0, i + 1); - } - else - { - for (int j = i; j >= 0; j--) - { - chatMessages.RemoveAt(j); - } - } - - break; - } - } - } - else + for (int i = list.Count - 2; i >= 0; i--) { - for (int i = chatMessages.Count - 2; i >= 0; i--) + totalCount += CountTokens(list[i]); + if (totalCount > _tokenLimit) { - totalCount += CountTokens(chatMessages[i]); - if (totalCount > _tokenLimit) - { - return chatMessages.Skip(i + 1).ToList(); - } + list.RemoveRange(0, i + 1); + break; } } } - return null; + return list; } private int CountTokens(ChatMessage message) diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs 
b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs index a31661166a8..83e84e49f5b 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs @@ -106,10 +106,10 @@ public async Task InvalidModelParameter_ThrowsInvalidOperationException() private sealed class AssertNoToolsDefinedChatClient(IChatClient innerClient) : DelegatingChatClient(innerClient) { public override Task GetResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { Assert.Null(options?.Tools); - return base.GetResponseAsync(chatMessages, options, cancellationToken); + return base.GetResponseAsync(messages, options, cancellationToken); } } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs index ef1ac9718d0..0b8aca0785e 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.ComponentModel; +using System.Linq; using System.Text.Json; using System.Text.RegularExpressions; using System.Threading.Tasks; @@ -84,7 +85,7 @@ public async Task WrapsNonObjectValuesInDataProperty() { GetResponseAsyncCallback = (messages, options, cancellationToken) => { - var suppliedSchemaMatch = Regex.Match(messages[1].Text!, "```(.*?)```", RegexOptions.Singleline); + var suppliedSchemaMatch = Regex.Match(messages.Last().Text!, "```(.*?)```", RegexOptions.Singleline); 
Assert.True(suppliedSchemaMatch.Success); Assert.Equal(""" { diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvocationContextTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvocationContextTests.cs index ae8f062cb30..6fd3f0bd06a 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvocationContextTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvocationContextTests.cs @@ -15,15 +15,15 @@ public void Constructor_PropertiesDefaultToExpectedValues() FunctionInvocationContext ctx = new(); Assert.NotNull(ctx.CallContent); - Assert.NotNull(ctx.ChatMessages); + Assert.NotNull(ctx.Messages); Assert.NotNull(ctx.Function); Assert.Equal(0, ctx.FunctionCallIndex); Assert.Equal(0, ctx.FunctionCount); Assert.Equal(0, ctx.Iteration); Assert.False(ctx.Terminate); - Assert.Empty(ctx.ChatMessages); - Assert.True(ctx.ChatMessages.IsReadOnly); + Assert.Empty(ctx.Messages); + Assert.True(ctx.Messages.IsReadOnly); Assert.Equal(nameof(FunctionInvocationContext), ctx.Function.Name); Assert.Empty(ctx.Function.Description); @@ -35,7 +35,7 @@ public void InvalidArgs_Throws() { FunctionInvocationContext ctx = new(); Assert.Throws("value", () => ctx.CallContent = null!); - Assert.Throws("value", () => ctx.ChatMessages = null!); + Assert.Throws("value", () => ctx.Messages = null!); Assert.Throws("value", () => ctx.Function = null!); } @@ -44,9 +44,9 @@ public void Properties_Roundtrip() { FunctionInvocationContext ctx = new(); - List chatMessages = []; - ctx.ChatMessages = chatMessages; - Assert.Same(chatMessages, ctx.ChatMessages); + List messages = []; + ctx.Messages = messages; + Assert.Same(messages, ctx.Messages); AIFunction function = AIFunctionFactory.Create(() => { }, nameof(Properties_Roundtrip)); ctx.Function = function; diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs 
b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs index caaf8dae575..c2b6b067c8a 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs @@ -401,6 +401,7 @@ public async Task SupportsConsecutiveStreamingUpdatesWithFunctionCalls() // If the conversation is just starting, issue two consecutive updates with function calls // Otherwise just end the conversation. List updates; + string responseId = Guid.NewGuid().ToString("N"); if (chatContents.Last().Text == "Hello") { updates = @@ -414,7 +415,10 @@ public async Task SupportsConsecutiveStreamingUpdatesWithFunctionCalls() updates = [new() { Contents = [new TextContent("OK bye")] }]; } - chatContents.AddRangeFromUpdates(updates); + foreach (var update in updates) + { + update.ResponseId = responseId; + } return YieldAsync(updates); } @@ -422,15 +426,10 @@ public async Task SupportsConsecutiveStreamingUpdatesWithFunctionCalls() using var client = new FunctionInvokingChatClient(innerClient); - var updates = new List(); - await foreach (var update in client.GetStreamingResponseAsync(messages, options, CancellationToken.None)) - { - updates.Add(update); - } + var response = await client.GetStreamingResponseAsync(messages, options, CancellationToken.None).ToChatResponseAsync(); - // Message history should now include the FCCs and FRCs - Assert.Collection(messages, - m => Assert.Equal("Hello", Assert.IsType(Assert.Single(m.Contents)).Text), + // The returned message should include the FCCs and FRCs. 
+ Assert.Collection(response.Messages, m => Assert.Collection(m.Contents, c => Assert.Equal("Input 1", Assert.IsType(c).Arguments!["text"]), c => Assert.Equal("Input 2", Assert.IsType(c).Arguments!["text"])), @@ -438,11 +437,6 @@ public async Task SupportsConsecutiveStreamingUpdatesWithFunctionCalls() c => Assert.Equal("Result for Input 1", Assert.IsType(c).Result?.ToString()), c => Assert.Equal("Result for Input 2", Assert.IsType(c).Result?.ToString())), m => Assert.Equal("OK bye", Assert.IsType(Assert.Single(m.Contents)).Text)); - - // The returned updates also include the FCCs and FRCs - var allUpdateContents = updates.SelectMany(updates => updates.Contents).ToList(); - Assert.Contains(allUpdateContents, c => c is FunctionCallContent); - Assert.Contains(allUpdateContents, c => c is FunctionResultContent); } [Fact] @@ -464,12 +458,10 @@ public async Task AllResponseMessagesReturned() { await Task.Yield(); - ChatMessage message = chatContents.Count is 1 or 3 ? - new(ChatRole.Assistant, [new FunctionCallContent($"callId{chatContents.Count}", "Func1")]) : + ChatMessage message = chatContents.Count() is 1 or 3 ? 
+ new(ChatRole.Assistant, [new FunctionCallContent($"callId{chatContents.Count()}", "Func1")]) : new(ChatRole.Assistant, "The answer is 42."); - chatContents.Add(message); - return new(message); } }; @@ -542,7 +534,7 @@ async Task InvokeAsync(Func>> work) { invocationContexts.Clear(); - var chatMessages = await work(); + var messages = await work(); Assert.Collection(invocationContexts, c => AssertInvocationContext(c, iteration: 0, terminate: false), @@ -551,7 +543,8 @@ async Task InvokeAsync(Func>> work) void AssertInvocationContext(FunctionInvocationContext context, int iteration, bool terminate) { Assert.NotNull(context); - Assert.Same(chatMessages, context.ChatMessages); + Assert.Equal(messages.Count, context.Messages.Count); + Assert.Equal(string.Concat(messages), string.Concat(context.Messages)); Assert.Same(function, context.Function); Assert.Equal("Func1", context.CallContent.Name); Assert.Equal(0, context.FunctionCallIndex); @@ -572,7 +565,7 @@ public async Task PropagatesResponseChatThreadIdToOptions() int iteration = 0; - Func, ChatOptions?, CancellationToken, ChatResponse> callback = + Func, ChatOptions?, CancellationToken, ChatResponse> callback = (chatContents, chatOptions, cancellationToken) => { iteration++; @@ -638,18 +631,19 @@ private static async Task> InvokeAndAssertAsync( var usage = CreateRandomUsage(); expectedTotalTokenCounts += usage.InputTokenCount!.Value; - var message = new ChatMessage(ChatRole.Assistant, [.. plan[contents.Count].Contents]); - contents.Add(message); - return new ChatResponse(message) { Usage = usage }; + var message = new ChatMessage(ChatRole.Assistant, [.. 
plan[contents.Count()].Contents]); + return new ChatResponse(message) { Usage = usage, ResponseId = Guid.NewGuid().ToString("N") }; } }; IChatClient service = configurePipeline(innerClient.AsBuilder()).Build(services); var result = await service.GetResponseAsync(chat, options, cts.Token); + Assert.NotNull(result); + + chat.AddRange(result.Messages); expected ??= plan; - Assert.NotNull(result); Assert.Equal(expected.Count, chat.Count); for (int i = 0; i < expected.Count; i++) { @@ -728,18 +722,19 @@ private static async Task> InvokeAndAssertStreamingAsync( { Assert.Equal(cts.Token, actualCancellationToken); - ChatMessage message = new(ChatRole.Assistant, [.. plan[contents.Count].Contents]); - contents.Add(message); - return YieldAsync(new ChatResponse(message).ToChatResponseUpdates()); + ChatMessage message = new(ChatRole.Assistant, [.. plan[contents.Count()].Contents]); + return YieldAsync(new ChatResponse(message) { ResponseId = Guid.NewGuid().ToString("N") }.ToChatResponseUpdates()); } }; IChatClient service = configurePipeline(innerClient.AsBuilder()).Build(services); var result = await service.GetStreamingResponseAsync(chat, options, cts.Token).ToChatResponseAsync(); + Assert.NotNull(result); + + chat.AddRange(result.Messages); expected ??= plan; - Assert.NotNull(result); Assert.Equal(expected.Count, chat.Count); for (int i = 0; i < expected.Count; i++) { diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/OpenTelemetryChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/OpenTelemetryChatClientTests.cs index c91d3c1ccdb..37ae545c04c 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/OpenTelemetryChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/OpenTelemetryChatClientTests.cs @@ -62,7 +62,7 @@ public async Task ExpectedInformationLogged_Async(bool enableSensitiveData, bool }; async static IAsyncEnumerable CallbackAsync( - IList messages, ChatOptions? 
options, [EnumeratorCancellation] CancellationToken cancellationToken) + IEnumerable messages, ChatOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken) { await Task.Yield(); @@ -105,7 +105,7 @@ async static IAsyncEnumerable CallbackAsync( }) .Build(); - List chatMessages = + List messages = [ new(ChatRole.System, "You are a close friend."), new(ChatRole.User, "Hey!"), @@ -136,14 +136,14 @@ async static IAsyncEnumerable CallbackAsync( if (streaming) { - await foreach (var update in chatClient.GetStreamingResponseAsync(chatMessages, options)) + await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options)) { await Task.Yield(); } } else { - await chatClient.GetResponseAsync(chatMessages, options); + await chatClient.GetResponseAsync(messages, options); } var activity = Assert.Single(activities); diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/UseDelegateChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/UseDelegateChatClientTests.cs index 63b68569d07..18ad0c08bbd 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/UseDelegateChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/UseDelegateChatClientTests.cs @@ -33,18 +33,18 @@ public async Task Shared_ContextPropagated() using IChatClient innerClient = new TestChatClient { - GetResponseAsyncCallback = (chatMessages, options, cancellationToken) => + GetResponseAsyncCallback = (messages, options, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); Assert.Equal(42, asyncLocal.Value); return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "hello"))); }, - GetStreamingResponseAsyncCallback = (chatMessages, options, cancellationToken) => + GetStreamingResponseAsyncCallback = (messages, options, 
cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); Assert.Equal(42, asyncLocal.Value); @@ -53,13 +53,13 @@ public async Task Shared_ContextPropagated() }; using IChatClient client = new ChatClientBuilder(innerClient) - .Use(async (chatMessages, options, next, cancellationToken) => + .Use(async (messages, options, next, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); asyncLocal.Value = 42; - await next(chatMessages, options, cancellationToken); + await next(messages, options, cancellationToken); }) .Build(); @@ -82,9 +82,9 @@ public async Task GetResponseFunc_ContextPropagated() using IChatClient innerClient = new TestChatClient { - GetResponseAsyncCallback = (chatMessages, options, cancellationToken) => + GetResponseAsyncCallback = (messages, options, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); Assert.Equal(42, asyncLocal.Value); @@ -93,13 +93,13 @@ public async Task GetResponseFunc_ContextPropagated() }; using IChatClient client = new ChatClientBuilder(innerClient) - .Use(async (chatMessages, options, innerClient, cancellationToken) => + .Use(async (messages, options, innerClient, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); asyncLocal.Value = 42; - var cc = await innerClient.GetResponseAsync(chatMessages, options, cancellationToken); + var cc = await innerClient.GetResponseAsync(messages, options, cancellationToken); cc.Messages.SelectMany(c 
=> c.Contents).OfType().Last().Text += " world"; return cc; }, null) @@ -124,9 +124,9 @@ public async Task GetStreamingResponseFunc_ContextPropagated() using IChatClient innerClient = new TestChatClient { - GetStreamingResponseAsyncCallback = (chatMessages, options, cancellationToken) => + GetStreamingResponseAsyncCallback = (messages, options, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); Assert.Equal(42, asyncLocal.Value); @@ -135,18 +135,18 @@ public async Task GetStreamingResponseFunc_ContextPropagated() }; using IChatClient client = new ChatClientBuilder(innerClient) - .Use(null, (chatMessages, options, innerClient, cancellationToken) => + .Use(null, (messages, options, innerClient, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); asyncLocal.Value = 42; - return Impl(chatMessages, options, innerClient, cancellationToken); + return Impl(messages, options, innerClient, cancellationToken); static async IAsyncEnumerable Impl( - IList chatMessages, ChatOptions? options, IChatClient innerClient, [EnumeratorCancellation] CancellationToken cancellationToken) + IEnumerable messages, ChatOptions? 
options, IChatClient innerClient, [EnumeratorCancellation] CancellationToken cancellationToken) { - await foreach (var update in innerClient.GetStreamingResponseAsync(chatMessages, options, cancellationToken)) + await foreach (var update in innerClient.GetStreamingResponseAsync(messages, options, cancellationToken)) { yield return update; } @@ -175,18 +175,18 @@ public async Task BothGetResponseAndGetStreamingResponseFuncs_ContextPropagated( using IChatClient innerClient = new TestChatClient { - GetResponseAsyncCallback = (chatMessages, options, cancellationToken) => + GetResponseAsyncCallback = (messages, options, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); Assert.Equal(42, asyncLocal.Value); return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "non-streaming hello"))); }, - GetStreamingResponseAsyncCallback = (chatMessages, options, cancellationToken) => + GetStreamingResponseAsyncCallback = (messages, options, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); Assert.Equal(42, asyncLocal.Value); @@ -196,28 +196,28 @@ public async Task BothGetResponseAndGetStreamingResponseFuncs_ContextPropagated( using IChatClient client = new ChatClientBuilder(innerClient) .Use( - async (chatMessages, options, innerClient, cancellationToken) => + async (messages, options, innerClient, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); asyncLocal.Value = 42; - var cc = await innerClient.GetResponseAsync(chatMessages, options, cancellationToken); + var cc = await 
innerClient.GetResponseAsync(messages, options, cancellationToken); cc.Messages.SelectMany(c => c.Contents).OfType().Last().Text += " world (non-streaming)"; return cc; }, - (chatMessages, options, innerClient, cancellationToken) => + (messages, options, innerClient, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); asyncLocal.Value = 42; - return Impl(chatMessages, options, innerClient, cancellationToken); + return Impl(messages, options, innerClient, cancellationToken); static async IAsyncEnumerable Impl( - IList chatMessages, ChatOptions? options, IChatClient innerClient, [EnumeratorCancellation] CancellationToken cancellationToken) + IEnumerable messages, ChatOptions? options, IChatClient innerClient, [EnumeratorCancellation] CancellationToken cancellationToken) { - await foreach (var update in innerClient.GetStreamingResponseAsync(chatMessages, options, cancellationToken)) + await foreach (var update in innerClient.GetStreamingResponseAsync(messages, options, cancellationToken)) { yield return update; } diff --git a/test/Shared/Throw/ThrowTest.cs b/test/Shared/Throw/ThrowTest.cs index 057f9098f5c..691217d86ce 100644 --- a/test/Shared/Throw/ThrowTest.cs +++ b/test/Shared/Throw/ThrowTest.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; -using System.Collections.ObjectModel; using Xunit; namespace Microsoft.Shared.Diagnostics.Test; @@ -383,18 +382,6 @@ public void Shorter_Version_Of_NullOrEmpty_Get_Correct_Argument_Name() Assert.Contains(nameof(listButActuallyNull), exceptionImplicitArgumentName.Message); } - [Fact] - public void Collection_IfReadOnly() - { - _ = Throw.IfReadOnly(new List()); - - IList list = new int[4]; - Assert.Throws("list", () => Throw.IfReadOnly(list)); - - list = new ReadOnlyCollection(new List()); - Assert.Throws("list", () => Throw.IfReadOnly(list)); - } - #endregion 
#region For Enums From 1d275f481baa87ab4bf590650d8a5b4682d3e5a3 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 5 Mar 2025 16:45:29 -0500 Subject: [PATCH 7/9] Address PR feedback --- .../ChatCompletion/IChatClient.cs | 4 ++-- .../ChatConversationEvaluator.cs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs index 4b487b4d9b3..bfa9eb43d14 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs @@ -25,7 +25,7 @@ namespace Microsoft.Extensions.AI; public interface IChatClient : IDisposable { /// Sends chat messages and returns the response. - /// The list of chat messages to send and to be augmented with generated messages. + /// The sequence of chat messages to send. /// The chat options with which to configure the request. /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. @@ -36,7 +36,7 @@ Task GetResponseAsync( CancellationToken cancellationToken = default); /// Sends chat messages and streams the response. - /// The list of chat messages to send and to be augmented with generated messages. + /// The sequence of chat messages to send. /// The chat options with which to configure the request. /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. 
diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/ChatConversationEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/ChatConversationEvaluator.cs index 1d64fe89f4e..23642abd23d 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/ChatConversationEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/ChatConversationEvaluator.cs @@ -254,7 +254,7 @@ await PerformEvaluationAsync( /// prompt that this uses. /// /// - /// Messages that are part of the conversation history for the response being evaluated and that is to be rendered + /// Messages that are part of the conversation history for the response being evaluated and that are to be rendered /// as part of the evaluation prompt. /// /// A that can cancel the operation. From 514ff81dee84cc7fb05e5929abbd25418a99de3c Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 5 Mar 2025 18:07:25 -0500 Subject: [PATCH 8/9] Add RenderAsync overload --- .../ChatConversationEvaluator.cs | 44 +++++++++++++++---- .../CoherenceEvaluator.cs | 4 +- .../EquivalenceEvaluator.cs | 4 +- .../FluencyEvaluator.cs | 4 +- .../GroundednessEvaluator.cs | 6 +-- .../RelevanceTruthAndCompletenessEvaluator.cs | 6 +-- 6 files changed, 47 insertions(+), 21 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/ChatConversationEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/ChatConversationEvaluator.cs index 23642abd23d..cbc904277ab 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/ChatConversationEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/ChatConversationEvaluator.cs @@ -1,8 +1,10 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. 
+using System; using System.Collections.Generic; using System.Linq; +using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.Shared.Diagnostics; @@ -250,26 +252,50 @@ await PerformEvaluationAsync( } /// - /// Renders the supplied to a string that can be included as part of the evaluation + /// Renders the supplied to a string that can be included as part of the evaluation /// prompt that this uses. /// - /// - /// Messages that are part of the conversation history for the response being evaluated and that are to be rendered - /// as part of the evaluation prompt. + /// + /// Chat response being evaluated and that is to be rendered as part of the evaluation prompt. /// /// A that can cancel the operation. /// - /// A string representation of the supplied that can be included as part of the + /// A string representation of the supplied that can be included as part of the /// evaluation prompt. /// /// - /// The default implementation considers only the last message of . + /// The default implementation uses to render + /// each message in the response. /// - protected virtual ValueTask RenderAsync(IEnumerable messages, CancellationToken cancellationToken) + protected virtual async ValueTask RenderAsync(ChatResponse response, CancellationToken cancellationToken) { - _ = Throw.IfNullOrEmpty(messages); + _ = Throw.IfNull(response); - ChatMessage message = messages.Last(); + StringBuilder sb = new(); + foreach (ChatMessage message in response.Messages) + { + _ = sb.Append(await RenderAsync(message, cancellationToken).ConfigureAwait(false)); + } + + return sb.ToString(); + } + + /// + /// Renders the supplied to a string that can be included as part of the evaluation + /// prompt that this uses. + /// + /// + /// Message that is part of the conversation history for the response being evaluated and that is to be rendered + /// as part of the evaluation prompt. + /// + /// A that can cancel the operation. 
+ /// + /// A string representation of the supplied that can be included as part of the + /// evaluation prompt. + /// + protected virtual ValueTask RenderAsync(ChatMessage message, CancellationToken cancellationToken) + { + _ = Throw.IfNull(message); string? author = message.AuthorName; string role = message.Role.Value; diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/CoherenceEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/CoherenceEvaluator.cs index 1a482a09eaf..4122a063cf4 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/CoherenceEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/CoherenceEvaluator.cs @@ -39,11 +39,11 @@ protected override async ValueTask RenderEvaluationPromptAsync( { _ = Throw.IfNull(modelResponse); - string renderedModelResponse = await RenderAsync(modelResponse.Messages, cancellationToken).ConfigureAwait(false); + string renderedModelResponse = await RenderAsync(modelResponse, cancellationToken).ConfigureAwait(false); string renderedUserRequest = userRequest is not null - ? await RenderAsync([userRequest], cancellationToken).ConfigureAwait(false) + ? 
await RenderAsync(userRequest, cancellationToken).ConfigureAwait(false) : string.Empty; string prompt = diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/EquivalenceEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/EquivalenceEvaluator.cs index d3fbb9ed56a..5926d260374 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/EquivalenceEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/EquivalenceEvaluator.cs @@ -43,11 +43,11 @@ protected override async ValueTask RenderEvaluationPromptAsync( { _ = Throw.IfNull(modelResponse); - string renderedModelResponse = await RenderAsync(modelResponse.Messages, cancellationToken).ConfigureAwait(false); + string renderedModelResponse = await RenderAsync(modelResponse, cancellationToken).ConfigureAwait(false); string renderedUserRequest = userRequest is not null - ? await RenderAsync([userRequest], cancellationToken).ConfigureAwait(false) + ? await RenderAsync(userRequest, cancellationToken).ConfigureAwait(false) : string.Empty; string groundTruth; diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/FluencyEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/FluencyEvaluator.cs index 4612baa77ed..d08a30a31b2 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/FluencyEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/FluencyEvaluator.cs @@ -39,11 +39,11 @@ protected override async ValueTask RenderEvaluationPromptAsync( { _ = Throw.IfNull(modelResponse); - string renderedModelResponse = await RenderAsync(modelResponse.Messages, cancellationToken).ConfigureAwait(false); + string renderedModelResponse = await RenderAsync(modelResponse, cancellationToken).ConfigureAwait(false); string renderedUserRequest = userRequest is not null - ? await RenderAsync([userRequest], cancellationToken).ConfigureAwait(false) + ? 
await RenderAsync(userRequest, cancellationToken).ConfigureAwait(false) : string.Empty; string prompt = diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/GroundednessEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/GroundednessEvaluator.cs index b39d17103ab..cbb66657a68 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/GroundednessEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/GroundednessEvaluator.cs @@ -43,11 +43,11 @@ protected override async ValueTask RenderEvaluationPromptAsync( { _ = Throw.IfNull(modelResponse); - string renderedModelResponse = await RenderAsync(modelResponse.Messages, cancellationToken).ConfigureAwait(false); + string renderedModelResponse = await RenderAsync(modelResponse, cancellationToken).ConfigureAwait(false); string renderedUserRequest = userRequest is not null - ? await RenderAsync([userRequest], cancellationToken).ConfigureAwait(false) + ? await RenderAsync(userRequest, cancellationToken).ConfigureAwait(false) : string.Empty; var builder = new StringBuilder(); @@ -64,7 +64,7 @@ userRequest is not null { foreach (ChatMessage message in includedHistory) { - _ = builder.Append(await RenderAsync([message], cancellationToken).ConfigureAwait(false)); + _ = builder.Append(await RenderAsync(message, cancellationToken).ConfigureAwait(false)); } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs index 6c2e843efad..419feb45743 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs @@ -83,11 +83,11 @@ protected override async ValueTask RenderEvaluationPromptAsync( { _ = Throw.IfNull(modelResponse); - string renderedModelResponse = 
await RenderAsync(modelResponse.Messages, cancellationToken).ConfigureAwait(false); + string renderedModelResponse = await RenderAsync(modelResponse, cancellationToken).ConfigureAwait(false); string renderedUserRequest = userRequest is not null - ? await RenderAsync([userRequest], cancellationToken).ConfigureAwait(false) + ? await RenderAsync(userRequest, cancellationToken).ConfigureAwait(false) : string.Empty; var builder = new StringBuilder(); @@ -95,7 +95,7 @@ userRequest is not null { foreach (ChatMessage message in includedHistory) { - _ = builder.Append(await RenderAsync([message], cancellationToken).ConfigureAwait(false)); + _ = builder.Append(await RenderAsync(message, cancellationToken).ConfigureAwait(false)); } } From 254701cc7aa0e1fee1e4e092b9efa8bad4820e18 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 6 Mar 2025 10:12:08 -0500 Subject: [PATCH 9/9] Final feedback and cleanup --- .../ChatCompletion/ChatMessage.cs | 6 +- .../ChatCompletion/ChatResponse.cs | 34 +++------ .../ChatCompletion/ChatResponseExtensions.cs | 8 +-- .../ChatCompletion/ChatResponseUpdate.cs | 6 +- .../ChatCompletion/DelegatingChatClient.cs | 2 +- .../ChatCompletion/IChatClient.cs | 5 +- .../Contents/AIContentExtensions.cs | 72 ++++++++++++++++++- .../DelegatingEmbeddingGenerator.cs | 2 +- .../EmbeddingGeneratorExtensions.cs | 14 ++-- .../Embeddings/IEmbeddingGenerator.cs | 12 ++-- .../AnonymousDelegatingChatClient.cs | 2 +- .../ChatCompletion/ChatClientBuilder.cs | 7 +- .../FunctionInvokingChatClient.cs | 6 +- .../Embeddings/CachingEmbeddingGenerator.cs | 2 +- .../Embeddings/EmbeddingGeneratorBuilder.cs | 7 +- .../Ingestion/IngestionCacheDbContext.cs | 2 +- .../Services/JsonVectorStore.cs | 2 +- .../ChatClientExtensionsTests.cs | 2 +- .../ChatCompletion/ChatResponseTests.cs | 6 +- .../EmbeddingGeneratorExtensionsTests.cs | 2 +- .../Utilities/AIJsonUtilitiesTests.cs | 16 ++--- .../UseDelegateEmbeddingGeneratorTests.cs | 2 +- .../Functions/AIFunctionFactoryTest.cs | 2 +- 23 
files changed, 141 insertions(+), 78 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs index eae74f68e62..049536cecd8 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs @@ -25,9 +25,9 @@ public ChatMessage() /// Initializes a new instance of the class. /// The role of the author of the message. - /// The text contents of the message. - public ChatMessage(ChatRole role, string? contents) - : this(role, contents is null ? [] : [new TextContent(contents)]) + /// The text content of the message. + public ChatMessage(ChatRole role, string? content) + : this(role, content is null ? [] : [new TextContent(content)]) { } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs index 4b2bf5b95aa..6babae1258f 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs @@ -3,8 +3,9 @@ using System; using System.Collections.Generic; -using System.Linq; +using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI; @@ -28,12 +29,12 @@ public ChatResponse() /// Initializes a new instance of the class. /// The response message. - public ChatResponse(ChatMessage? message) + /// is . + public ChatResponse(ChatMessage message) { - if (message is not null) - { - Messages.Add(message); - } + _ = Throw.IfNull(message); + + Messages.Add(message); } /// Initializes a new instance of the class. @@ -44,6 +45,7 @@ public ChatResponse(IList? messages) } /// Gets or sets the chat response messages. 
+ [AllowNull] public IList Messages { get => _messages ??= new List(1); @@ -56,25 +58,7 @@ public IList Messages /// instances in . /// [JsonIgnore] - public string Text - { - get - { - IList? messages = _messages; - if (messages is null) - { - return string.Empty; - } - - int count = messages.Count; - return count switch - { - 0 => string.Empty, - 1 => messages[0].Text, - _ => string.Join(Environment.NewLine, messages.Select(m => m.Text).Where(s => !string.IsNullOrEmpty(s))), - }; - } - } + public string Text => _messages?.ConcatText() ?? string.Empty; /// Gets or sets the ID of the chat response. public string? ResponseId { get; set; } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseExtensions.cs index 604918f21fc..16eed49db93 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseExtensions.cs @@ -19,7 +19,7 @@ namespace Microsoft.Extensions.AI; public static class ChatResponseExtensions { /// Adds all of the messages from into . - /// The destination list into which the messages should be added. + /// The destination list to which the messages from should be added. /// The response containing the messages to add. /// is . /// is . @@ -42,7 +42,7 @@ public static void AddMessages(this IList list, ChatResponse respon } /// Converts the into instances and adds them to . - /// The list to which the newly constructed messages should be added. + /// The destination list to which the newly constructed messages should be added. /// The instances to convert to messages and add to the list. /// is . /// is . 
@@ -84,9 +84,9 @@ public static Task AddMessagesAsync( _ = Throw.IfNull(list); _ = Throw.IfNull(updates); - return AddRangeFromUpdatesAsync(list, updates, cancellationToken); + return AddMessagesAsync(list, updates, cancellationToken); - static async Task AddRangeFromUpdatesAsync( + static async Task AddMessagesAsync( IList list, IAsyncEnumerable updates, CancellationToken cancellationToken) => list.AddMessages(await updates.ToChatResponseAsync(cancellationToken).ConfigureAwait(false)); } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs index 0202c4b65c4..24610ac76fc 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs @@ -46,9 +46,9 @@ public ChatResponseUpdate() /// Initializes a new instance of the class. /// The role of the author of the update. - /// The text contents of the update. - public ChatResponseUpdate(ChatRole? role, string? contents) - : this(role, contents is null ? null : [new TextContent(contents)]) + /// The text content of the update. + public ChatResponseUpdate(ChatRole? role, string? content) + : this(role, content is null ? null : [new TextContent(content)]) { } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs index c0ccbb84f28..23768dd8da7 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs @@ -13,7 +13,7 @@ namespace Microsoft.Extensions.AI; /// Provides an optional base class for an that passes through calls to another instance. 
/// /// -/// This is recommended as a base type when building clients that can be chained in any order around an underlying . +/// This is recommended as a base type when building clients that can be chained around an underlying . /// The default implementation simply passes each call to the inner client instance. /// public class DelegatingChatClient : IChatClient diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs index bfa9eb43d14..0de18809bbc 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs @@ -13,12 +13,13 @@ namespace Microsoft.Extensions.AI; /// /// Unless otherwise specified, all members of are thread-safe for concurrent use. /// It is expected that all implementations of support being used by multiple requests concurrently. +/// Instances must not be disposed of while the instance is still in use. /// /// /// However, implementations of might mutate the arguments supplied to and /// , such as by configuring the options instance. Thus, consumers of the interface either /// should avoid using shared instances of these arguments for concurrent invocations or should otherwise ensure by construction -/// that no instances are used which might employ such mutation. For example, the WithChatOptions method be +/// that no instances are used which might employ such mutation. For example, the ConfigureOptions method is /// provided with a callback that could mutate the supplied options argument, and that should be avoided if using a singleton options instance. /// /// @@ -52,7 +53,7 @@ IAsyncEnumerable GetStreamingResponseAsync( /// The found object, otherwise . /// is . 
/// - /// The purpose of this method is to allow for the retrieval of strongly typed services that might be provided by the , + /// The purpose of this method is to allow for the retrieval of strongly-typed services that might be provided by the , /// including itself or any services it might be wrapping. For example, to access the for the instance, /// may be used to request it. /// diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContentExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContentExtensions.cs index d2b1e73a5b3..550a48ab6de 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContentExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContentExtensions.cs @@ -6,6 +6,8 @@ using System.Linq; #if NET using System.Runtime.CompilerServices; +#else +using System.Text; #endif namespace Microsoft.Extensions.AI; @@ -27,9 +29,9 @@ public static string ConcatText(this IEnumerable contents) case 1: return (list[0] as TextContent)?.Text ?? string.Empty; -#if NET default: - DefaultInterpolatedStringHandler builder = new(0, 0, null, stackalloc char[512]); +#if NET + DefaultInterpolatedStringHandler builder = new(count, 0, null, stackalloc char[512]); for (int i = 0; i < count; i++) { if (list[i] is TextContent text) @@ -39,10 +41,76 @@ public static string ConcatText(this IEnumerable contents) } return builder.ToStringAndClear(); +#else + StringBuilder builder = new(); + for (int i = 0; i < count; i++) + { + if (list[i] is TextContent text) + { + builder.Append(text.Text); + } + } + + return builder.ToString(); #endif } } return string.Concat(contents.OfType()); } + + /// Concatenates the of all instances in the list. + /// A newline separator is added between each non-empty piece of text. 
+ public static string ConcatText(this IList messages) + { + int count = messages.Count; + switch (count) + { + case 0: + return string.Empty; + + case 1: + return messages[0].Text; + + default: +#if NET + DefaultInterpolatedStringHandler builder = new(count, 0, null, stackalloc char[512]); + bool needsSeparator = false; + for (int i = 0; i < count; i++) + { + string text = messages[i].Text; + if (text.Length > 0) + { + if (needsSeparator) + { + builder.AppendLiteral(Environment.NewLine); + } + + builder.AppendLiteral(text); + + needsSeparator = true; + } + } + + return builder.ToStringAndClear(); +#else + StringBuilder builder = new(); + for (int i = 0; i < count; i++) + { + string text = messages[i].Text; + if (text.Length > 0) + { + if (builder.Length > 0) + { + builder.AppendLine(); + } + + builder.Append(text); + } + } + + return builder.ToString(); +#endif + } + } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/DelegatingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/DelegatingEmbeddingGenerator.cs index f1a4c3aa7a2..e15c2981613 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/DelegatingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/DelegatingEmbeddingGenerator.cs @@ -15,7 +15,7 @@ namespace Microsoft.Extensions.AI; /// The type of the input passed to the generator. /// The type of the embedding instance produced by the generator. /// -/// This type is recommended as a base type when building generators that can be chained in any order around an underlying . +/// This type is recommended as a base type when building generators that can be chained around an underlying . /// The default implementation simply passes each call to the inner generator instance. 
/// public class DelegatingEmbeddingGenerator : IEmbeddingGenerator diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs index d69952598dd..35d8260e406 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs @@ -189,15 +189,21 @@ public static async Task GenerateEmbeddingAsync( if (embeddings is null) { - throw new InvalidOperationException("Embedding generator returned a null collection of embeddings."); + Throw.InvalidOperationException("Embedding generator returned a null collection of embeddings."); } if (embeddings.Count != 1) { - throw new InvalidOperationException($"Expected the number of embeddings ({embeddings.Count}) to match the number of inputs (1)."); + Throw.InvalidOperationException($"Expected the number of embeddings ({embeddings.Count}) to match the number of inputs (1)."); } - return embeddings[0] ?? 
throw new InvalidOperationException("Embedding generator generated a null embedding."); + TEmbedding embedding = embeddings[0]; + if (embedding is null) + { + Throw.InvalidOperationException("Embedding generator generated a null embedding."); + } + + return embedding; } /// @@ -235,7 +241,7 @@ public static async Task GenerateEmbeddingAsync( var embeddings = await generator.GenerateAsync(values, options, cancellationToken).ConfigureAwait(false); if (embeddings.Count != inputsCount) { - throw new InvalidOperationException($"Expected the number of embeddings ({embeddings.Count}) to match the number of inputs ({inputsCount})."); + Throw.InvalidOperationException($"Expected the number of embeddings ({embeddings.Count}) to match the number of inputs ({inputsCount})."); } var results = new (TInput, TEmbedding)[embeddings.Count]; diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs index 531b8ceeeb5..59fcc9e2393 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs @@ -15,21 +15,21 @@ namespace Microsoft.Extensions.AI; /// /// Unless otherwise specified, all members of are thread-safe for concurrent use. /// It is expected that all implementations of support being used by multiple requests concurrently. +/// Instances must not be disposed of while the instance is still in use. /// /// /// However, implementations of may mutate the arguments supplied to -/// , such as by adding additional values to the values list or configuring the options -/// instance. Thus, consumers of the interface either should avoid using shared instances of these arguments for concurrent -/// invocations or should otherwise ensure by construction that no instances -/// are used which might employ such mutation. 
+/// , such as by configuring the options instance. Thus, consumers of the interface either should +/// avoid using shared instances of these arguments for concurrent invocations or should otherwise ensure by construction that +/// no instances are used which might employ such mutation. /// /// public interface IEmbeddingGenerator : IDisposable where TEmbedding : Embedding { /// Generates embeddings for each of the supplied . - /// The collection of values for which to generate embeddings. - /// The embedding generation options to configure the request. + /// The sequence of values for which to generate embeddings. + /// The embedding generation options with which to configure the request. /// The to monitor for cancellation requests. The default is . /// The generated embeddings. /// is . diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs index 5e3032fee37..dbc3114ec25 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs @@ -104,7 +104,7 @@ await _sharedFunc(messages, options, async (messages, options, cancellationToken if (response is null) { - throw new InvalidOperationException("The wrapper completed successfully without producing a ChatResponse."); + Throw.InvalidOperationException("The wrapper completed successfully without producing a ChatResponse."); } return response; diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs index 23fba1e0abd..8789810b601 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs @@ -49,10 +49,13 @@ public IChatClient Build(IServiceProvider? 
services = null) { for (var i = _clientFactories.Count - 1; i >= 0; i--) { - chatClient = _clientFactories[i](chatClient, services) ?? - throw new InvalidOperationException( + chatClient = _clientFactories[i](chatClient, services); + if (chatClient is null) + { + Throw.InvalidOperationException( $"The {nameof(ChatClientBuilder)} entry at index {i} returned null. " + $"Ensure that the callbacks passed to {nameof(Use)} return non-null {nameof(IChatClient)} instances."); + } } } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index 43cb2019b1c..67e6249298c 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -205,7 +205,7 @@ public override async Task GetResponseAsync( response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); if (response is null) { - throw new InvalidOperationException($"The inner {nameof(IChatClient)} returned a null {nameof(ChatResponse)}."); + Throw.InvalidOperationException($"The inner {nameof(IChatClient)} returned a null {nameof(ChatResponse)}."); } // Any function call work to do? If yes, ensure we're tracking that work in functionCallContents. @@ -291,7 +291,7 @@ public override async IAsyncEnumerable GetStreamingResponseA { if (update is null) { - throw new InvalidOperationException($"The inner {nameof(IChatClient)} streamed a null {nameof(ChatResponseUpdate)}."); + Throw.InvalidOperationException($"The inner {nameof(IChatClient)} streamed a null {nameof(ChatResponseUpdate)}."); } updates.Add(update); @@ -311,7 +311,7 @@ public override async IAsyncEnumerable GetStreamingResponseA } // Reconsistitue a response from the response updates. 
- ChatResponse response = updates.ToChatResponse(); + var response = updates.ToChatResponse(); (responseMessages ??= []).AddRange(response.Messages); // Prepare the history for the next iteration. diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs index 688e4b2353d..43a983d7fd4 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs @@ -51,7 +51,7 @@ public override async Task> GenerateAsync( var generated = await base.GenerateAsync(valuesList, options, cancellationToken).ConfigureAwait(false); if (generated.Count != 1) { - throw new InvalidOperationException($"Expected exactly one embedding to be generated, but received {generated.Count}."); + Throw.InvalidOperationException($"Expected exactly one embedding to be generated, but received {generated.Count}."); } await WriteCacheAsync(cacheKey, generated[0], cancellationToken).ConfigureAwait(false); diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs index e5cd800800d..dcb33d37c3c 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs @@ -54,10 +54,13 @@ public IEmbeddingGenerator Build(IServiceProvider? services { for (var i = _generatorFactories.Count - 1; i >= 0; i--) { - embeddingGenerator = _generatorFactories[i](embeddingGenerator, services) ?? - throw new InvalidOperationException( + embeddingGenerator = _generatorFactories[i](embeddingGenerator, services); + if (embeddingGenerator is null) + { + Throw.InvalidOperationException( $"The {nameof(IEmbeddingGenerator)} entry at index {i} returned null. 
" + $"Ensure that the callbacks passed to {nameof(Use)} return non-null {nameof(IEmbeddingGenerator)} instances."); + } } } diff --git a/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/Services/Ingestion/IngestionCacheDbContext.cs b/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/Services/Ingestion/IngestionCacheDbContext.cs index aeaf4ccd52d..78842253abe 100644 --- a/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/Services/Ingestion/IngestionCacheDbContext.cs +++ b/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/Services/Ingestion/IngestionCacheDbContext.cs @@ -34,7 +34,7 @@ public class IngestedDocument public required string Id { get; set; } public required string SourceId { get; set; } public required string Version { get; set; } - public List Records { get; set; } = new(); + public List Records { get; set; } = []; } public class IngestedRecord diff --git a/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/Services/JsonVectorStore.cs b/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/Services/JsonVectorStore.cs index 9dba51c7692..cb787c3bbef 100644 --- a/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/Services/JsonVectorStore.cs +++ b/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/Services/JsonVectorStore.cs @@ -51,7 +51,7 @@ public Task CollectionExistsAsync(CancellationToken cancellationToken = de public async Task CreateCollectionAsync(CancellationToken cancellationToken = default) { - _records = new(); + _records = []; await WriteToDiskAsync(cancellationToken); } diff --git 
a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs index 04d686acebd..c74c50813f4 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs @@ -32,7 +32,7 @@ public void GetService_ValidService_Returned() { using var client = new TestChatClient { - GetServiceCallback = (Type serviceType, object? serviceKey) => + GetServiceCallback = (serviceType, serviceKey) => { if (serviceType == typeof(string)) { diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs index abe7660e0cf..ee719ee5647 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs @@ -20,13 +20,11 @@ public void Constructor_NullEmptyArgs_Valid() Assert.Empty(response.Messages); Assert.Empty(response.Text); - response = new((ChatMessage?)null); - Assert.Empty(response.Messages); - Assert.Empty(response.Text); - response = new((IList?)null); Assert.Empty(response.Messages); Assert.Empty(response.Text); + + Assert.Throws("message", () => new ChatResponse((ChatMessage)null!)); } [Fact] diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs index 8a61fbb0786..fe4af33cf23 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs +++ 
b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs @@ -33,7 +33,7 @@ public void GetService_ValidService_Returned() { using IEmbeddingGenerator> generator = new TestEmbeddingGenerator { - GetServiceCallback = (Type serviceType, object? serviceKey) => + GetServiceCallback = (serviceType, serviceKey) => { if (serviceType == typeof(string)) { diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs index 77209e0146c..d828365d8b5 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs @@ -387,9 +387,9 @@ public static void AddAIContentType_ReadOnlyJsonSerializerOptions_ThrowsInvalidO public static void AddAIContentType_NonAIContent_ThrowsArgumentException() { JsonSerializerOptions options = new(); - Assert.Throws(() => options.AddAIContentType(typeof(int), "discriminator")); - Assert.Throws(() => options.AddAIContentType(typeof(object), "discriminator")); - Assert.Throws(() => options.AddAIContentType(typeof(ChatMessage), "discriminator")); + Assert.Throws("contentType", () => options.AddAIContentType(typeof(int), "discriminator")); + Assert.Throws("contentType", () => options.AddAIContentType(typeof(object), "discriminator")); + Assert.Throws("contentType", () => options.AddAIContentType(typeof(ChatMessage), "discriminator")); } [Fact] @@ -415,11 +415,11 @@ public static void AddAIContentType_ConflictingIdentifier_ThrowsInvalidOperation public static void AddAIContentType_NullArguments_ThrowsArgumentNullException() { JsonSerializerOptions options = new(); - Assert.Throws(() => ((JsonSerializerOptions)null!).AddAIContentType("discriminator")); - Assert.Throws(() => 
((JsonSerializerOptions)null!).AddAIContentType(typeof(DerivedAIContent), "discriminator")); - Assert.Throws(() => options.AddAIContentType(null!)); - Assert.Throws(() => options.AddAIContentType(typeof(DerivedAIContent), null!)); - Assert.Throws(() => options.AddAIContentType(null!, "discriminator")); + Assert.Throws("options", () => ((JsonSerializerOptions)null!).AddAIContentType("discriminator")); + Assert.Throws("options", () => ((JsonSerializerOptions)null!).AddAIContentType(typeof(DerivedAIContent), "discriminator")); + Assert.Throws("typeDiscriminatorId", () => options.AddAIContentType(null!)); + Assert.Throws("typeDiscriminatorId", () => options.AddAIContentType(typeof(DerivedAIContent), null!)); + Assert.Throws("contentType", () => options.AddAIContentType(null!, "discriminator")); } [Fact] diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/UseDelegateEmbeddingGeneratorTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/UseDelegateEmbeddingGeneratorTests.cs index 1109cbc581a..e71e6d9461c 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/UseDelegateEmbeddingGeneratorTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/UseDelegateEmbeddingGeneratorTests.cs @@ -24,7 +24,7 @@ public void InvalidArgs_Throws() [Fact] public async Task GenerateFunc_ContextPropagated() { - GeneratedEmbeddings> expectedEmbeddings = new(); + GeneratedEmbeddings> expectedEmbeddings = []; IList expectedValues = ["hello"]; EmbeddingGenerationOptions expectedOptions = new(); using CancellationTokenSource expectedCts = new(); diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs index a196823c5c5..3d94420063c 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs @@ -120,7 
+120,7 @@ public void Metadata_DerivedFromLambda() Assert.Empty(func.Description); Assert.Same(dotnetFunc.Method, func.UnderlyingMethod); - Func dotnetFunc2 = (string a) => a + " " + a; + Func dotnetFunc2 = a => a + " " + a; func = AIFunctionFactory.Create(dotnetFunc2); Assert.Contains("Metadata_DerivedFromLambda", func.Name); Assert.Empty(func.Description);