diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index 7d75f077ac..0fa6473de0 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -162,7 +162,10 @@ static Task GetResponseAsync(IChatClient chatClient, List RunCoreStreami { var inputMessages = Throw.IfNull(messages) as IReadOnlyCollection ?? messages.ToList(); - (ChatClientAgentThread safeThread, ChatOptions? chatOptions, List inputMessagesForChatClient, IList? aiContextProviderMessages, IList? chatMessageStoreMessages) = + (ChatClientAgentThread safeThread, + ChatOptions? chatOptions, + List inputMessagesForChatClient, + IList? aiContextProviderMessages, + IList? chatMessageStoreMessages, + ChatClientAgentContinuationToken? continuationToken) = await this.PrepareThreadAndMessagesAsync(thread, inputMessages, options, cancellationToken).ConfigureAwait(false); - ValidateStreamResumptionAllowed(chatOptions?.ContinuationToken, safeThread); - var chatClient = this.ChatClient; chatClient = ApplyRunOptionsTransformations(options, chatClient); @@ -214,7 +220,7 @@ protected override async IAsyncEnumerable RunCoreStreami this._logger.LogAgentChatClientInvokingAgent(nameof(RunStreamingAsync), this.Id, loggingAgentName, this._chatClientType); - List responseUpdates = []; + List responseUpdates = GetResponseUpdates(continuationToken); IAsyncEnumerator responseUpdatesEnumerator; @@ -225,8 +231,8 @@ protected override async IAsyncEnumerable RunCoreStreami } catch (Exception ex) { - await NotifyMessageStoreOfFailureAsync(safeThread, ex, inputMessages, chatMessageStoreMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false); - await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await NotifyMessageStoreOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), chatMessageStoreMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await NotifyAIContextProviderOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); throw; } @@ -240,8 +246,8 @@ protected override async IAsyncEnumerable RunCoreStreami } catch (Exception ex) { - await NotifyMessageStoreOfFailureAsync(safeThread, ex, inputMessages, chatMessageStoreMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false); - await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await NotifyMessageStoreOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), chatMessageStoreMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await NotifyAIContextProviderOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); throw; } @@ -253,7 +259,12 @@ protected override async IAsyncEnumerable RunCoreStreami update.AuthorName ??= this.Name; responseUpdates.Add(update); - yield return new(update) { AgentId = this.Id }; + + yield return new(update) + { + AgentId = this.Id, + ContinuationToken = WrapContinuationToken(update.ContinuationToken, GetInputMessages(inputMessages, continuationToken), responseUpdates) + }; } try @@ -262,8 +273,8 @@ protected override async IAsyncEnumerable RunCoreStreami } catch (Exception ex) { - await NotifyMessageStoreOfFailureAsync(safeThread, ex, inputMessages, chatMessageStoreMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false); - await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await NotifyMessageStoreOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), chatMessageStoreMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await NotifyAIContextProviderOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); throw; } } @@ -275,10 +286,10 @@ protected override async IAsyncEnumerable RunCoreStreami this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId); // To avoid inconsistent state we only notify the thread of the input messages if no error occurs after the initial request. - await NotifyMessageStoreOfNewMessagesAsync(safeThread, inputMessages, chatMessageStoreMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); + await NotifyMessageStoreOfNewMessagesAsync(safeThread, GetInputMessages(inputMessages, continuationToken), chatMessageStoreMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); // Notify the AIContextProvider of all new messages. - await NotifyAIContextProviderOfSuccessAsync(safeThread, inputMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); + await NotifyAIContextProviderOfSuccessAsync(safeThread, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); } /// @@ -382,7 +393,12 @@ private async Task RunCoreAsync ?? messages.ToList(); - (ChatClientAgentThread safeThread, ChatOptions? chatOptions, List inputMessagesForChatClient, IList? aiContextProviderMessages, IList? chatMessageStoreMessages) = + (ChatClientAgentThread safeThread, + ChatOptions? chatOptions, + List inputMessagesForChatClient, + IList? aiContextProviderMessages, + IList? chatMessageStoreMessages, + ChatClientAgentContinuationToken? _) = await this.PrepareThreadAndMessagesAsync(thread, inputMessages, options, cancellationToken).ConfigureAwait(false); var chatClient = this.ChatClient; @@ -474,20 +490,20 @@ await thread.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvider /// Optional run options that may include specific chat configuration settings. /// A object representing the merged chat configuration, or if /// neither the run options nor the agent's chat options are available. - private ChatOptions? CreateConfiguredChatOptions(AgentRunOptions? runOptions) + private (ChatOptions?, ChatClientAgentContinuationToken?) CreateConfiguredChatOptions(AgentRunOptions? runOptions) { ChatOptions? requestChatOptions = (runOptions as ChatClientAgentRunOptions)?.ChatOptions?.Clone(); // If no agent chat options were provided, return the request chat options as is. if (this._agentOptions?.ChatOptions is null) { - return ApplyBackgroundResponsesProperties(requestChatOptions, runOptions); + return GetContinuationTokenAndApplyBackgroundResponsesProperties(requestChatOptions, runOptions); } // If no request chat options were provided, use the agent's chat options clone. if (requestChatOptions is null) { - return ApplyBackgroundResponsesProperties(this._agentOptions?.ChatOptions.Clone(), runOptions); + return GetContinuationTokenAndApplyBackgroundResponsesProperties(this._agentOptions?.ChatOptions.Clone(), runOptions); } // If both are present, we need to merge them. @@ -583,19 +599,26 @@ await thread.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvider } } - return ApplyBackgroundResponsesProperties(requestChatOptions, runOptions); + return GetContinuationTokenAndApplyBackgroundResponsesProperties(requestChatOptions, runOptions); - static ChatOptions? ApplyBackgroundResponsesProperties(ChatOptions? chatOptions, AgentRunOptions? agentRunOptions) + static (ChatOptions?, ChatClientAgentContinuationToken?) GetContinuationTokenAndApplyBackgroundResponsesProperties(ChatOptions? chatOptions, AgentRunOptions? agentRunOptions) { - // If any of the background response properties are set in the run options, we should apply both to the chat options. - if (agentRunOptions?.AllowBackgroundResponses is not null || agentRunOptions?.ContinuationToken is not null) + if (agentRunOptions?.AllowBackgroundResponses is not null) { chatOptions ??= new ChatOptions(); chatOptions.AllowBackgroundResponses = agentRunOptions.AllowBackgroundResponses; - chatOptions.ContinuationToken = agentRunOptions.ContinuationToken; } - return chatOptions; + ChatClientAgentContinuationToken? agentContinuationToken = null; + + if ((agentRunOptions?.ContinuationToken ?? chatOptions?.ContinuationToken) is { } continuationToken) + { + agentContinuationToken = ChatClientAgentContinuationToken.FromToken(continuationToken); + chatOptions ??= new ChatOptions(); + chatOptions.ContinuationToken = agentContinuationToken!.InnerToken; + } + + return (chatOptions, agentContinuationToken); } } @@ -606,21 +629,22 @@ await thread.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvider /// The input messages to use. /// Optional parameters for agent invocation. /// The to monitor for cancellation requests. The default is . - /// A tuple containing the thread, chat options, and thread messages. + /// A tuple containing the thread, chat options, messages and continuation token. private async Task <( ChatClientAgentThread AgentThread, ChatOptions? ChatOptions, List InputMessagesForChatClient, IList? AIContextProviderMessages, - IList? ChatMessageStoreMessages + IList? ChatMessageStoreMessages, + ChatClientAgentContinuationToken? ContinuationToken )> PrepareThreadAndMessagesAsync( AgentThread? thread, IEnumerable inputMessages, AgentRunOptions? runOptions, CancellationToken cancellationToken) { - ChatOptions? chatOptions = this.CreateConfiguredChatOptions(runOptions); + (ChatOptions? chatOptions, ChatClientAgentContinuationToken? continuationToken) = this.CreateConfiguredChatOptions(runOptions); // Supplying a thread for background responses is required to prevent inconsistent experience // for callers if they forget to provide the thread for initial or follow-up runs. @@ -641,11 +665,6 @@ private async Task throw new InvalidOperationException("Input messages are not allowed when continuing a background response using a continuation token."); } - if (chatOptions?.ContinuationToken is not null && typedThread.ConversationId is null && typedThread.MessageStore is null) - { - throw new InvalidOperationException("Continuation tokens are not allowed to be used for initial runs."); - } - List inputMessagesForChatClient = []; IList? aiContextProviderMessages = null; IList? chatMessageStoreMessages = null; @@ -713,7 +732,7 @@ private async Task chatOptions.ConversationId = typedThread.ConversationId; } - return (typedThread, chatOptions, inputMessagesForChatClient, aiContextProviderMessages, chatMessageStoreMessages); + return (typedThread, chatOptions, inputMessagesForChatClient, aiContextProviderMessages, chatMessageStoreMessages, continuationToken); } private void UpdateThreadWithTypeAndConversationId(ChatClientAgentThread thread, string? responseConversationId) @@ -791,26 +810,43 @@ private static Task NotifyMessageStoreOfNewMessagesAsync( return Task.CompletedTask; } - private static void ValidateStreamResumptionAllowed(ResponseContinuationToken? continuationToken, ChatClientAgentThread safeThread) + private static ChatClientAgentContinuationToken? WrapContinuationToken(ResponseContinuationToken? continuationToken, IEnumerable? inputMessages = null, List? responseUpdates = null) { if (continuationToken is null) { - return; + return null; } - // Streaming resumption is only supported with chat history managed by the agent service because, currently, there's no good solution - // to collect updates received in failed runs and pass them to the last successful run so it can store them to the message store. - if (safeThread.ConversationId is null) + return new(continuationToken) { - throw new NotSupportedException("Streaming resumption is only supported when chat history is stored and managed by the underlying AI service."); - } + // Save input messages to the continuation token so they can be added to the thread and + // provided to the context provider in the last successful streaming resumption run. + // That's necessary for scenarios where initial streaming run is interrupted and streaming is resumed later. + InputMessages = inputMessages?.Any() is true ? inputMessages : null, + + // Save all updates received so far to the continuation token so they can be provided to the + // message store and context provider in the last successful streaming resumption run. + // That's necessary for scenarios where a streaming run is interrupted after some updates were received. + ResponseUpdates = responseUpdates?.Count > 0 ? responseUpdates : null + }; + } - // Similarly, streaming resumption is not supported when a context provider is used because, currently, there's no good solution - // to collect updates received in failed runs and pass them to the last successful run so it can notify the context provider of the updates. - if (safeThread.AIContextProvider is not null) + private static IEnumerable GetInputMessages(IReadOnlyCollection inputMessages, ChatClientAgentContinuationToken? token) + { + // First, use input messages if provided. + if (inputMessages.Count > 0) { - throw new NotSupportedException("Using context provider with streaming resumption is not supported."); + return inputMessages; } + + // Fallback to messages saved in the continuation token if available. + return token?.InputMessages ?? []; + } + + private static List GetResponseUpdates(ChatClientAgentContinuationToken? token) + { + // Restore any previously received updates from the continuation token. + return token?.ResponseUpdates?.ToList() ?? []; } private string GetLoggingAgentName() => this.Name ?? "UnnamedAgent"; diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentContinuationToken.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentContinuationToken.cs new file mode 100644 index 0000000000..aa5659b1d1 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentContinuationToken.cs @@ -0,0 +1,170 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.Extensions.AI; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI; + +/// +/// Represents a continuation token for ChatClientAgent operations. +/// +internal class ChatClientAgentContinuationToken : ResponseContinuationToken +{ + private const string TokenTypeName = "chatClientAgentContinuationToken"; + private const string TypeDiscriminator = "type"; + + /// + /// Initializes a new instance of the class. + /// + /// A continuation token provided by the underlying . + [JsonConstructor] + internal ChatClientAgentContinuationToken(ResponseContinuationToken innerToken) + { + this.InnerToken = innerToken; + } + + public override ReadOnlyMemory ToBytes() + { + using MemoryStream stream = new(); + using Utf8JsonWriter writer = new(stream); + + writer.WriteStartObject(); + + // This property should be the first one written to identify the type during deserialization. + writer.WriteString(TypeDiscriminator, TokenTypeName); + + writer.WriteString("innerToken", JsonSerializer.Serialize(this.InnerToken, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ResponseContinuationToken)))); + + if (this.InputMessages?.Any() is true) + { + writer.WriteString("inputMessages", JsonSerializer.Serialize(this.InputMessages, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IEnumerable)))); + } + + if (this.ResponseUpdates?.Count > 0) + { + writer.WriteString("responseUpdates", JsonSerializer.Serialize(this.ResponseUpdates, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IReadOnlyList)))); + } + + writer.WriteEndObject(); + + writer.Flush(); + + return stream.ToArray(); + } + + /// + /// Create a new instance of from the provided . + /// + /// The token to create the from. + /// A equivalent of the provided . + internal static ChatClientAgentContinuationToken FromToken(ResponseContinuationToken token) + { + if (token is ChatClientAgentContinuationToken chatClientContinuationToken) + { + return chatClientContinuationToken; + } + + ReadOnlyMemory data = token.ToBytes(); + + if (data.Length == 0) + { + Throw.ArgumentException(nameof(token), "Failed to create ChatClientAgentContinuationToken from provided token because it does not contain any data."); + } + + Utf8JsonReader reader = new(data.Span); + + // Move to the start object token. + _ = reader.Read(); + + // Validate that the token is of this type. + ValidateTokenType(reader, token); + + ResponseContinuationToken? innerToken = null; + IEnumerable? inputMessages = null; + IReadOnlyList? responseUpdates = null; + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + { + break; + } + + if (reader.TokenType != JsonTokenType.PropertyName) + { + continue; + } + switch (reader.GetString()) + { + case "innerToken": + _ = reader.Read(); + var innerTokenJson = reader.GetString() ?? throw new ArgumentException("No content for innerToken property.", nameof(token)); + innerToken = (ResponseContinuationToken?)JsonSerializer.Deserialize(innerTokenJson, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ResponseContinuationToken))); + break; + case "inputMessages": + _ = reader.Read(); + var innerMessagesJson = reader.GetString() ?? throw new ArgumentException("No content for inputMessages property.", nameof(token)); + inputMessages = (IEnumerable?)JsonSerializer.Deserialize(innerMessagesJson, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IEnumerable))); + break; + case "responseUpdates": + _ = reader.Read(); + var responseUpdatesJson = reader.GetString() ?? throw new ArgumentException("No content for responseUpdates property.", nameof(token)); + responseUpdates = (IReadOnlyList?)JsonSerializer.Deserialize(responseUpdatesJson, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IReadOnlyList))); + break; + default: + break; + } + } + + if (innerToken is null) + { + Throw.ArgumentException(nameof(token), "Failed to create ChatClientAgentContinuationToken from provided token because it does not contain an inner token."); + } + + return new ChatClientAgentContinuationToken(innerToken) + { + InputMessages = inputMessages, + ResponseUpdates = responseUpdates + }; + } + + private static void ValidateTokenType(Utf8JsonReader reader, ResponseContinuationToken token) + { + try + { + // Move to the first property. + _ = reader.Read(); + + // If the first property name is not "type", or its value does not match this token type name, then we know its not this token type. + if (reader.GetString() != TypeDiscriminator || !reader.Read() || reader.GetString() != TokenTypeName) + { + Throw.ArgumentException(nameof(token), "Failed to create ChatClientAgentContinuationToken from provided token because it is not of the correct type."); + } + } + catch (JsonException ex) + { + Throw.ArgumentException(nameof(token), "Failed to create ChatClientAgentContinuationToken from provided token because it could not be parsed.", ex); + } + } + + /// + /// Gets a continuation token provided by the underlying . + /// + internal ResponseContinuationToken InnerToken { get; } + + /// + /// Gets or sets the input messages used for streaming run. + /// + internal IEnumerable? InputMessages { get; set; } + + /// + /// Gets or sets the response updates received so far. + /// + internal IReadOnlyList? ResponseUpdates { get; set; } +} diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentStructuredOutput.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentStructuredOutput.cs index 913be969c6..9a535cd645 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentStructuredOutput.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentStructuredOutput.cs @@ -154,7 +154,10 @@ async Task> GetResponseAsync(IChatClient chatClient, List CreateResponse(ChatResponse chatResponse) { - return new ChatClientAgentRunResponse(chatResponse); + return new ChatClientAgentRunResponse(chatResponse) + { + ContinuationToken = WrapContinuationToken(chatResponse.ContinuationToken) + }; } return this.RunCoreAsync(GetResponseAsync, CreateResponse, messages, thread, options, cancellationToken); diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseTests.cs index a81446d062..8e39b4c4fa 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunResponseTests.cs @@ -57,7 +57,7 @@ public void ConstructorWithChatResponseRoundtrips() RawRepresentation = new object(), ResponseId = "responseId", Usage = new UsageDetails(), - ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }), + ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }; AgentRunResponse response = new(chatResponse); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentContinuationTokenTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentContinuationTokenTests.cs new file mode 100644 index 0000000000..a2add9634b --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentContinuationTokenTests.cs @@ -0,0 +1,128 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Text.Json; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.UnitTests.ChatClient; + +public class ChatClientAgentContinuationTokenTests +{ + [Fact] + public void ToBytes_Roundtrip() + { + // Arrange + ResponseContinuationToken originalToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3, 4, 5 }); + + ChatClientAgentContinuationToken chatClientToken = new(originalToken) + { + InputMessages = + [ + new ChatMessage(ChatRole.User, "Hello!"), + new ChatMessage(ChatRole.User, "How are you?") + ], + ResponseUpdates = + [ + new ChatResponseUpdate(ChatRole.Assistant, "I'm fine, thank you."), + new ChatResponseUpdate(ChatRole.Assistant, "How can I assist you today?") + ] + }; + + // Act + ReadOnlyMemory bytes = chatClientToken.ToBytes(); + + ChatClientAgentContinuationToken tokenFromBytes = ChatClientAgentContinuationToken.FromToken(ResponseContinuationToken.FromBytes(bytes)); + + // Assert + Assert.NotNull(tokenFromBytes); + Assert.Equal(chatClientToken.ToBytes().ToArray(), tokenFromBytes.ToBytes().ToArray()); + + // Verify InnerToken + Assert.Equal(chatClientToken.InnerToken.ToBytes().ToArray(), tokenFromBytes.InnerToken.ToBytes().ToArray()); + + // Verify InputMessages + Assert.NotNull(tokenFromBytes.InputMessages); + Assert.Equal(chatClientToken.InputMessages.Count(), tokenFromBytes.InputMessages.Count()); + for (int i = 0; i < chatClientToken.InputMessages.Count(); i++) + { + Assert.Equal(chatClientToken.InputMessages.ElementAt(i).Role, tokenFromBytes.InputMessages.ElementAt(i).Role); + Assert.Equal(chatClientToken.InputMessages.ElementAt(i).Text, tokenFromBytes.InputMessages.ElementAt(i).Text); + } + + // Verify ResponseUpdates + Assert.NotNull(tokenFromBytes.ResponseUpdates); + Assert.Equal(chatClientToken.ResponseUpdates.Count, tokenFromBytes.ResponseUpdates.Count); + for (int i = 0; i < chatClientToken.ResponseUpdates.Count; i++) + { + Assert.Equal(chatClientToken.ResponseUpdates.ElementAt(i).Role, tokenFromBytes.ResponseUpdates.ElementAt(i).Role); + Assert.Equal(chatClientToken.ResponseUpdates.ElementAt(i).Text, tokenFromBytes.ResponseUpdates.ElementAt(i).Text); + } + } + + [Fact] + public void Serialization_Roundtrip() + { + // Arrange + ResponseContinuationToken originalToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3, 4, 5 }); + + ChatClientAgentContinuationToken chatClientToken = new(originalToken) + { + InputMessages = + [ + new ChatMessage(ChatRole.User, "Hello!"), + new ChatMessage(ChatRole.User, "How are you?") + ], + ResponseUpdates = + [ + new ChatResponseUpdate(ChatRole.Assistant, "I'm fine, thank you."), + new ChatResponseUpdate(ChatRole.Assistant, "How can I assist you today?") + ] + }; + + // Act + string json = JsonSerializer.Serialize(chatClientToken, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ResponseContinuationToken))); + + ResponseContinuationToken? deserializedToken = (ResponseContinuationToken?)JsonSerializer.Deserialize(json, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ResponseContinuationToken))); + + ChatClientAgentContinuationToken deserializedChatClientToken = ChatClientAgentContinuationToken.FromToken(deserializedToken!); + + // Assert + Assert.NotNull(deserializedChatClientToken); + Assert.Equal(chatClientToken.ToBytes().ToArray(), deserializedChatClientToken.ToBytes().ToArray()); + + // Verify InnerToken + Assert.Equal(chatClientToken.InnerToken.ToBytes().ToArray(), deserializedChatClientToken.InnerToken.ToBytes().ToArray()); + + // Verify InputMessages + Assert.NotNull(deserializedChatClientToken.InputMessages); + Assert.Equal(chatClientToken.InputMessages.Count(), deserializedChatClientToken.InputMessages.Count()); + for (int i = 0; i < chatClientToken.InputMessages.Count(); i++) + { + Assert.Equal(chatClientToken.InputMessages.ElementAt(i).Role, deserializedChatClientToken.InputMessages.ElementAt(i).Role); + Assert.Equal(chatClientToken.InputMessages.ElementAt(i).Text, deserializedChatClientToken.InputMessages.ElementAt(i).Text); + } + + // Verify ResponseUpdates + Assert.NotNull(deserializedChatClientToken.ResponseUpdates); + Assert.Equal(chatClientToken.ResponseUpdates.Count, deserializedChatClientToken.ResponseUpdates.Count); + for (int i = 0; i < chatClientToken.ResponseUpdates.Count; i++) + { + Assert.Equal(chatClientToken.ResponseUpdates.ElementAt(i).Role, deserializedChatClientToken.ResponseUpdates.ElementAt(i).Role); + Assert.Equal(chatClientToken.ResponseUpdates.ElementAt(i).Text, deserializedChatClientToken.ResponseUpdates.ElementAt(i).Text); + } + } + + [Fact] + public void FromToken_WithChatClientAgentContinuationToken_ReturnsSameInstance() + { + // Arrange + ChatClientAgentContinuationToken originalToken = new(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3, 4, 5 })); + + // Act + ChatClientAgentContinuationToken fromToken = ChatClientAgentContinuationToken.FromToken(originalToken); + + // Assert + Assert.Same(originalToken, fromToken); + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs index 3bc28ee12f..cfccb7267a 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs @@ -18,10 +18,10 @@ public class ChatClientAgent_BackgroundResponsesTests [Theory] [InlineData(true)] [InlineData(false)] - public async Task RunAsyncPropagatesBackgroundResponsesPropertiesToChatClientAsync(bool providePropsViaChatOptions) + public async Task RunAsync_PropagatesBackgroundResponsesPropertiesToChatClientAsync(bool providePropsViaChatOptions) { // Arrange - var continuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }); + var continuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })); ChatOptions? capturedChatOptions = null; Mock mockChatClient = new(); mockChatClient @@ -63,15 +63,15 @@ public async Task RunAsyncPropagatesBackgroundResponsesPropertiesToChatClientAsy // Assert Assert.NotNull(capturedChatOptions); Assert.True(capturedChatOptions.AllowBackgroundResponses); - Assert.Same(continuationToken, capturedChatOptions.ContinuationToken); + Assert.Same(continuationToken.InnerToken, capturedChatOptions.ContinuationToken); } [Fact] - public async Task RunAsyncPrioritizesBackgroundResponsesPropertiesFromAgentRunOptionsOverOnesFromChatOptionsAsync() + public async Task RunAsync_WhenPropertiesSetInBothLocations_PrioritizesAgentRunOptionsOverChatOptionsAsync() { // Arrange - var continuationToken1 = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }); - var continuationToken2 = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }); + var continuationToken1 = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })); + var continuationToken2 = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })); ChatOptions? capturedChatOptions = null; Mock mockChatClient = new(); mockChatClient @@ -104,13 +104,13 @@ public async Task RunAsyncPrioritizesBackgroundResponsesPropertiesFromAgentRunOp // Assert Assert.NotNull(capturedChatOptions); Assert.False(capturedChatOptions.AllowBackgroundResponses); - Assert.Same(continuationToken2, capturedChatOptions.ContinuationToken); + Assert.Same(continuationToken2.InnerToken, capturedChatOptions.ContinuationToken); } [Theory] [InlineData(true)] [InlineData(false)] - public async Task RunStreamingAsyncPropagatesBackgroundResponsesPropertiesToChatClientAsync(bool providePropsViaChatOptions) + public async Task RunStreamingAsync_PropagatesBackgroundResponsesPropertiesToChatClientAsync(bool providePropsViaChatOptions) { // Arrange ChatResponseUpdate[] returnUpdates = @@ -119,7 +119,7 @@ public async Task RunStreamingAsyncPropagatesBackgroundResponsesPropertiesToChat new ChatResponseUpdate(role: ChatRole.Assistant, content: "at?") { ConversationId = "conversation-id" }, ]; - var continuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }); + var continuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) { InputMessages = [new ChatMessage()] }; ChatOptions? capturedChatOptions = null; Mock mockChatClient = new(); mockChatClient @@ -164,11 +164,11 @@ public async Task RunStreamingAsyncPropagatesBackgroundResponsesPropertiesToChat Assert.NotNull(capturedChatOptions); Assert.True(capturedChatOptions.AllowBackgroundResponses); - Assert.Same(continuationToken, capturedChatOptions.ContinuationToken); + Assert.Same(continuationToken.InnerToken, capturedChatOptions.ContinuationToken); } [Fact] - public async Task RunStreamingAsyncPrioritizesBackgroundResponsesPropertiesFromAgentRunOptionsOverOnesFromChatOptionsAsync() + public async Task RunStreamingAsync_WhenPropertiesSetInBothLocations_PrioritizesAgentRunOptionsOverChatOptionsAsync() { // Arrange ChatResponseUpdate[] returnUpdates = @@ -176,8 +176,8 @@ public async Task RunStreamingAsyncPrioritizesBackgroundResponsesPropertiesFromA new ChatResponseUpdate(role: ChatRole.Assistant, content: "wh") { ConversationId = "conversation-id" }, ]; - var continuationToken1 = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }); - var continuationToken2 = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }); + var continuationToken1 = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) { InputMessages = [new ChatMessage()] }; + var continuationToken2 = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) { InputMessages = [new ChatMessage()] }; ChatOptions? capturedChatOptions = null; Mock mockChatClient = new(); mockChatClient @@ -212,11 +212,11 @@ public async Task RunStreamingAsyncPrioritizesBackgroundResponsesPropertiesFromA // Assert Assert.NotNull(capturedChatOptions); Assert.False(capturedChatOptions.AllowBackgroundResponses); - Assert.Same(continuationToken2, capturedChatOptions.ContinuationToken); + Assert.Same(continuationToken2.InnerToken, capturedChatOptions.ContinuationToken); } [Fact] - public async Task RunAsyncPropagatesContinuationTokenFromChatResponseToAgentRunResponseAsync() + public async Task RunAsync_WhenContinuationTokenReceivedFromChatResponse_WrapsContinuationTokenAsync() { // Arrange var continuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }); @@ -237,11 +237,11 @@ public async Task RunAsyncPropagatesContinuationTokenFromChatResponseToAgentRunR var response = await agent.RunAsync([new(ChatRole.User, "hi")], thread, options: runOptions); // Assert - Assert.Same(continuationToken, response.ContinuationToken); + Assert.Same(continuationToken, (response.ContinuationToken as ChatClientAgentContinuationToken)?.InnerToken); } [Fact] - public async Task RunStreamingAsyncPropagatesContinuationTokensFromUpdatesAsync() + public async Task RunStreamingAsync_WhenContinuationTokenReceived_WrapsContinuationTokenAsync() { // Arrange var token1 = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }); @@ -272,19 +272,19 @@ public async Task RunStreamingAsyncPropagatesContinuationTokensFromUpdatesAsync( // Assert Assert.Equal(2, actualUpdates.Count); - Assert.Same(token1, actualUpdates[0].ContinuationToken); + Assert.Same(token1, (actualUpdates[0].ContinuationToken as ChatClientAgentContinuationToken)?.InnerToken); Assert.Null(actualUpdates[1].ContinuationToken); // last update has null token } [Fact] - public async Task RunAsyncThrowsWhenMessagesProvidedWithContinuationTokenAsync() + public async Task RunAsync_WhenMessagesProvidedWithContinuationToken_ThrowsInvalidOperationExceptionAsync() { // Arrange Mock mockChatClient = new(); ChatClientAgent agent = new(mockChatClient.Object); - AgentRunOptions runOptions = new() { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }; + AgentRunOptions runOptions = new() { ContinuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) }; IEnumerable inputMessages = [new ChatMessage(ChatRole.User, "test message")]; @@ -301,14 +301,14 @@ public async Task RunAsyncThrowsWhenMessagesProvidedWithContinuationTokenAsync() } [Fact] - public async Task RunStreamingAsyncThrowsWhenMessagesProvidedWithContinuationTokenAsync() + public async Task RunStreamingAsync_WhenMessagesProvidedWithContinuationToken_ThrowsInvalidOperationExceptionAsync() { // Arrange Mock mockChatClient = new(); ChatClientAgent agent = new(mockChatClient.Object); - AgentRunOptions runOptions = new() { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }; + AgentRunOptions runOptions = new() { ContinuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) }; IEnumerable inputMessages = [new ChatMessage(ChatRole.User, "test message")]; @@ -331,7 +331,7 @@ await Assert.ThrowsAsync(async () => } [Fact] - public async Task RunAsyncSkipsThreadMessagePopulationWithContinuationTokenAsync() + public async Task RunAsync_WhenContinuationTokenProvided_SkipsThreadMessagePopulationAsync() { // Arrange List capturedMessages = []; @@ -371,7 +371,10 @@ public async Task RunAsyncSkipsThreadMessagePopulationWithContinuationTokenAsync AIContextProvider = mockContextProvider.Object }; - AgentRunOptions runOptions = new() { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }; + AgentRunOptions runOptions = new() + { + ContinuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) + }; // Act await agent.RunAsync([], thread, options: runOptions); @@ -393,7 +396,7 @@ public async Task RunAsyncSkipsThreadMessagePopulationWithContinuationTokenAsync } [Fact] - public async Task RunStreamingAsyncSkipsThreadMessagePopulationWithContinuationTokenAsync() + public async Task RunStreamingAsync_WhenContinuationTokenProvided_SkipsThreadMessagePopulationAsync() { // Arrange List capturedMessages = []; @@ -433,14 +436,15 @@ public async Task RunStreamingAsyncSkipsThreadMessagePopulationWithContinuationT AIContextProvider = mockContextProvider.Object }; - AgentRunOptions runOptions = new() { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }; + AgentRunOptions runOptions = new() + { + ContinuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) { InputMessages = [new ChatMessage()] } + }; // Act - var exception = await Assert.ThrowsAsync(async () => await agent.RunStreamingAsync(thread, options: runOptions).ToListAsync()); + await agent.RunStreamingAsync(thread, options: runOptions).ToListAsync(); // Assert - Assert.Equal("Streaming resumption is only supported when chat history is stored and managed by the underlying AI service.", exception.Message); - // With continuation token, thread message population should be skipped Assert.Empty(capturedMessages); @@ -456,7 +460,7 @@ public async Task RunStreamingAsyncSkipsThreadMessagePopulationWithContinuationT } [Fact] - public async Task RunAsyncThrowsWhenNoThreadProvideForBackgroundResponsesAsync() + public async Task RunAsync_WhenNoThreadProvidedForBackgroundResponses_ThrowsInvalidOperationExceptionAsync() { // Arrange Mock mockChatClient = new(); @@ -480,7 +484,7 @@ public async Task RunAsyncThrowsWhenNoThreadProvideForBackgroundResponsesAsync() } [Fact] - public async Task RunStreamingAsyncThrowsWhenNoThreadProvideForBackgroundResponsesAsync() + public async Task RunStreamingAsync_WhenNoThreadProvidedForBackgroundResponses_ThrowsInvalidOperationExceptionAsync() { // Arrange Mock mockChatClient = new(); @@ -510,126 +514,287 @@ await Assert.ThrowsAsync(async () => } [Fact] - public async Task RunAsyncThrowsWhenContinuationTokenProvidedForInitialRunAsync() + public async Task RunStreamingAsync_WhenInputMessagesPresentInContinuationToken_ResumesStreamingAsync() { // Arrange + ChatResponseUpdate[] returnUpdates = + [ + new ChatResponseUpdate(role: ChatRole.Assistant, content: "continuation") { ConversationId = "conversation-id" }, + ]; + Mock mockChatClient = new(); + mockChatClient + .Setup(c => c.GetStreamingResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Returns(ToAsyncEnumerableAsync(returnUpdates)); ChatClientAgent agent = new(mockChatClient.Object); - // Create a new thread with no ConversationId and no MessageStore (initial run state) - ChatClientAgentThread thread = new(); + ChatClientAgentThread thread = new() { ConversationId = "conversation-id" }; - AgentRunOptions runOptions = new() { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }; + AgentRunOptions runOptions = new() + { + ContinuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) + { + InputMessages = [new ChatMessage(ChatRole.User, "previous message")] + } + }; - // Act & Assert - var exception = await Assert.ThrowsAsync(() => agent.RunAsync(thread: thread, options: runOptions)); - Assert.Equal("Continuation tokens are not allowed to be used for initial runs.", exception.Message); + // Act + var updates = new List(); + await foreach (var update in agent.RunStreamingAsync(thread, options: runOptions)) + { + updates.Add(update); + } - // Verify that the IChatClient was never called due to early validation + // Assert + Assert.Single(updates); + + // Verify that the IChatClient was called mockChatClient.Verify( - c => c.GetResponseAsync( + c => c.GetStreamingResponseAsync( It.IsAny>(), It.IsAny(), It.IsAny()), - Times.Never); + Times.Once); } [Fact] - public async Task RunStreamingAsyncThrowsWhenContinuationTokenProvidedForInitialRunAsync() + public async Task RunStreamingAsync_WhenResponseUpdatesPresentInContinuationToken_ResumesStreamingAsync() { // Arrange + ChatResponseUpdate[] returnUpdates = + [ + new ChatResponseUpdate(role: ChatRole.Assistant, content: "continuation") { ConversationId = "conversation-id" }, + ]; + Mock mockChatClient = new(); + mockChatClient + .Setup(c => c.GetStreamingResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Returns(ToAsyncEnumerableAsync(returnUpdates)); ChatClientAgent agent = new(mockChatClient.Object); - // Create a new thread with no ConversationId and no MessageStore (initial run state) - ChatClientAgentThread thread = new(); + ChatClientAgentThread thread = new() { ConversationId = "conversation-id" }; - AgentRunOptions runOptions = new() { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }; + AgentRunOptions runOptions = new() + { + ContinuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) + { + ResponseUpdates = [new ChatResponseUpdate(ChatRole.Assistant, "previous update")] + } + }; - // Act & Assert - var exception = await Assert.ThrowsAsync(async () => await agent.RunStreamingAsync(thread: thread, options: runOptions).ToListAsync()); - Assert.Equal("Continuation tokens are not allowed to be used for initial runs.", exception.Message); + // Act + var updates = new List(); + await foreach (var update in agent.RunStreamingAsync(thread, options: runOptions)) + { + updates.Add(update); + } - // Verify that the IChatClient was never called due to early validation + // Assert + Assert.Single(updates); + + // Verify that the IChatClient was called mockChatClient.Verify( c => c.GetStreamingResponseAsync( It.IsAny>(), It.IsAny(), It.IsAny()), - Times.Never); + Times.Once); } [Fact] - public async Task RunStreamingAsyncThrowsWhenContinuationTokenUsedWithClientSideManagedChatHistoryAsync() + public async Task RunStreamingAsync_WhenResumingStreaming_UsesUpdatesFromInitialRunForContextProviderAndMessageStoreAsync() { // Arrange + ChatResponseUpdate[] returnUpdates = + [ + new ChatResponseUpdate(role: ChatRole.Assistant, content: "upon"), + new ChatResponseUpdate(role: ChatRole.Assistant, content: " a"), + new ChatResponseUpdate(role: ChatRole.Assistant, content: " time"), + ]; + Mock mockChatClient = new(); + mockChatClient + .Setup(c => c.GetStreamingResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Returns(ToAsyncEnumerableAsync(returnUpdates)); ChatClientAgent agent = new(mockChatClient.Object); - // Create a thread with a MessageStore + List capturedMessagesAddedToStore = []; + var mockMessageStore = new Mock(); + mockMessageStore + .Setup(ms => ms.InvokedAsync(It.IsAny(), It.IsAny())) + .Callback((ctx, ct) => capturedMessagesAddedToStore.AddRange(ctx.ResponseMessages ?? [])) + .Returns(new ValueTask()); + + AIContextProvider.InvokedContext? capturedInvokedContext = null; + var mockContextProvider = new Mock(); + mockContextProvider + .Setup(cp => cp.InvokedAsync(It.IsAny(), It.IsAny())) + .Callback((context, ct) => capturedInvokedContext = context) + .Returns(new ValueTask()); + ChatClientAgentThread thread = new() { - MessageStore = new InMemoryChatMessageStore(), // Setting a message store to skip checking the continuation token in the initial run - ConversationId = null, // No conversation ID to simulate client-side managed chat history + MessageStore = mockMessageStore.Object, + AIContextProvider = mockContextProvider.Object }; - // Create run options with a continuation token - AgentRunOptions runOptions = new() { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }; + AgentRunOptions runOptions = new() + { + ContinuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) + { + ResponseUpdates = [new ChatResponseUpdate(ChatRole.Assistant, "once ")] + } + }; - // Act & Assert - var exception = await Assert.ThrowsAsync(async () => await agent.RunStreamingAsync(thread: thread, options: runOptions).ToListAsync()); - Assert.Equal("Streaming resumption is only supported when chat history is stored and managed by the underlying AI service.", exception.Message); + // Act + await agent.RunStreamingAsync(thread, options: runOptions).ToListAsync(); - // Verify that the IChatClient was never called due to early validation - mockChatClient.Verify( - c => c.GetStreamingResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny()), - Times.Never); + // Assert + mockMessageStore.Verify(ms => ms.InvokedAsync(It.IsAny(), It.IsAny()), Times.Once); + Assert.Single(capturedMessagesAddedToStore); + Assert.Contains("once upon a time", capturedMessagesAddedToStore[0].Text); + + mockContextProvider.Verify(cp => cp.InvokedAsync(It.IsAny(), It.IsAny()), Times.Once); + Assert.NotNull(capturedInvokedContext?.ResponseMessages); + Assert.Single(capturedInvokedContext.ResponseMessages); + Assert.Contains("once upon a time", capturedInvokedContext.ResponseMessages.ElementAt(0).Text); } [Fact] - public async Task RunStreamingAsyncThrowsWhenContinuationTokenUsedWithAIContextProviderAsync() + public async Task RunStreamingAsync_WhenResumingStreaming_UsesInputMessagesFromInitialRunForContextProviderAndMessageStoreAsync() { // Arrange Mock mockChatClient = new(); + mockChatClient + .Setup(c => c.GetStreamingResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Returns(ToAsyncEnumerableAsync(Array.Empty())); ChatClientAgent agent = new(mockChatClient.Object); - // Create a mock AIContextProvider + List capturedMessagesAddedToStore = []; + var mockMessageStore = new Mock(); + mockMessageStore + .Setup(ms => ms.InvokedAsync(It.IsAny(), It.IsAny())) + .Callback((ctx, ct) => capturedMessagesAddedToStore.AddRange(ctx.RequestMessages)) + .Returns(new ValueTask()); + + AIContextProvider.InvokedContext? capturedInvokedContext = null; var mockContextProvider = new Mock(); mockContextProvider - .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) - .ReturnsAsync(new AIContext()); - mockContextProvider - .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) + .Setup(cp => cp.InvokedAsync(It.IsAny(), It.IsAny())) + .Callback((context, ct) => capturedInvokedContext = context) .Returns(new ValueTask()); - // Create a thread with an AIContextProvider and conversation ID to simulate non-initial run ChatClientAgentThread thread = new() { - ConversationId = "existing-conversation-id", + MessageStore = mockMessageStore.Object, AIContextProvider = mockContextProvider.Object }; - AgentRunOptions runOptions = new() { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }; + AgentRunOptions runOptions = new() + { + ContinuationToken = new ChatClientAgentContinuationToken(ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 })) + { + InputMessages = [new ChatMessage(ChatRole.User, "Tell me a story")], + } + }; - // Act & Assert - var exception = await Assert.ThrowsAsync(async () => await agent.RunStreamingAsync(thread: thread, options: runOptions).ToListAsync()); + // Act + await agent.RunStreamingAsync(thread, options: runOptions).ToListAsync(); - Assert.Equal("Using context provider with streaming resumption is not supported.", exception.Message); + // Assert + mockMessageStore.Verify(ms => ms.InvokedAsync(It.IsAny(), It.IsAny()), Times.Once); + Assert.Single(capturedMessagesAddedToStore); + Assert.Contains("Tell me a story", capturedMessagesAddedToStore[0].Text); + + mockContextProvider.Verify(cp => cp.InvokedAsync(It.IsAny(), It.IsAny()), Times.Once); + Assert.NotNull(capturedInvokedContext?.RequestMessages); + Assert.Single(capturedInvokedContext.RequestMessages); + Assert.Contains("Tell me a story", capturedInvokedContext.RequestMessages.ElementAt(0).Text); + } - // Verify that the IChatClient was never called due to early validation - mockChatClient.Verify( - c => c.GetStreamingResponseAsync( + [Fact] + public async Task RunStreamingAsync_WhenResumingStreaming_SavesInputMessagesAndUpdatesInContinuationTokenAsync() + { + // Arrange + List returnUpdates = + [ + new ChatResponseUpdate(role: ChatRole.Assistant, content: "Once") { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }, + new ChatResponseUpdate(role: ChatRole.Assistant, content: " upon") { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }, + new ChatResponseUpdate(role: ChatRole.Assistant, content: " a") { ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }, + new ChatResponseUpdate(role: ChatRole.Assistant, content: " time"){ ContinuationToken = ResponseContinuationToken.FromBytes(new byte[] { 1, 2, 3 }) }, + ]; + + Mock mockChatClient = new(); + mockChatClient + .Setup(c => c.GetStreamingResponseAsync( It.IsAny>(), It.IsAny(), - It.IsAny()), - Times.Never); + It.IsAny())) + .Returns(ToAsyncEnumerableAsync(returnUpdates)); + + ChatClientAgent agent = new(mockChatClient.Object); + + ChatClientAgentThread thread = new() { }; + + List capturedContinuationTokens = []; + + ChatMessage userMessage = new(ChatRole.User, "Tell me a story"); + + // Act + + // Do the initial run + await foreach (var update in agent.RunStreamingAsync(userMessage, thread)) + { + capturedContinuationTokens.Add(Assert.IsType(update.ContinuationToken)); + break; + } + + // Now resume the run using the captured continuation token + returnUpdates.RemoveAt(0); // remove the first mock update as it was already processed + var options = new AgentRunOptions { ContinuationToken = capturedContinuationTokens[0] }; + await foreach (var update in agent.RunStreamingAsync(thread, options: options)) + { + capturedContinuationTokens.Add(Assert.IsType(update.ContinuationToken)); + } + + // Assert + Assert.Equal(4, capturedContinuationTokens.Count); + + // Verify that the first continuation token has the initial input and first update + Assert.NotNull(capturedContinuationTokens[0].InputMessages); + Assert.Single(capturedContinuationTokens[0].InputMessages!); + Assert.Equal("Tell me a story", capturedContinuationTokens[0].InputMessages!.Last().Text); + Assert.NotNull(capturedContinuationTokens[0].ResponseUpdates); + Assert.Single(capturedContinuationTokens[0].ResponseUpdates!); + Assert.Equal("Once", capturedContinuationTokens[0].ResponseUpdates![0].Text); + + // Verify the last continuation token has the input and all updates + var lastToken = capturedContinuationTokens[^1]; + Assert.NotNull(lastToken.InputMessages); + Assert.Single(lastToken.InputMessages!); + Assert.Equal("Tell me a story", lastToken.InputMessages!.Last().Text); + Assert.NotNull(lastToken.ResponseUpdates); + Assert.Equal(4, lastToken.ResponseUpdates!.Count); + Assert.Equal("Once", lastToken.ResponseUpdates!.ElementAt(0).Text); + Assert.Equal(" upon", lastToken.ResponseUpdates!.ElementAt(1).Text); + Assert.Equal(" a", lastToken.ResponseUpdates!.ElementAt(2).Text); + Assert.Equal(" time", lastToken.ResponseUpdates!.ElementAt(3).Text); } private static async IAsyncEnumerable ToAsyncEnumerableAsync(IEnumerable values)