diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs index a4e588f347..6beef64405 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs @@ -44,11 +44,19 @@ protected override async Task RunCoreAsync(IEnumerable responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); // Notify the thread of the input and output messages. - await typedThread.MessageStore.AddMessagesAsync(messages.Concat(responseMessages), cancellationToken); + var invokedContext = new ChatMessageStore.InvokedContext(messages, storeMessages) + { + ResponseMessages = responseMessages + }; + await typedThread.MessageStore.InvokedAsync(invokedContext, cancellationToken); return new AgentRunResponse { @@ -68,11 +76,19 @@ protected override async IAsyncEnumerable RunCoreStreami throw new ArgumentException($"The provided thread is not of type {nameof(CustomAgentThread)}.", nameof(thread)); } + // Get existing messages from the store + var invokingContext = new ChatMessageStore.InvokingContext(messages); + var storeMessages = await typedThread.MessageStore.InvokingAsync(invokingContext, cancellationToken); + // Clone the input messages and turn them into response messages with upper case text. List responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); // Notify the thread of the input and output messages. - await typedThread.MessageStore.AddMessagesAsync(messages.Concat(responseMessages), cancellationToken); + var invokedContext = new ChatMessageStore.InvokedContext(messages, storeMessages) + { + ResponseMessages = responseMessages + }; + await typedThread.MessageStore.InvokedAsync(invokedContext, cancellationToken); foreach (var message in responseMessages) { diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs index 42015d87cd..9207a08182 100644 --- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs @@ -62,7 +62,12 @@ .CreateAIAgent(new ChatClientAgentOptions { ChatOptions = new() { Instructions = "You are a helpful support specialist for Contoso Outdoors. Answer questions using the provided context and cite the source document when available." }, - AIContextProviderFactory = ctx => new TextSearchProvider(SearchAdapter, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions) + AIContextProviderFactory = ctx => new TextSearchProvider(SearchAdapter, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions), + // Since we are using ChatCompletion which stores chat history locally, we can also add a message removal policy + // that removes messages produced by the TextSearchProvider before they are added to the chat history, so that + // we don't bloat chat history with all the search result messages. + ChatMessageStoreFactory = ctx => new InMemoryChatMessageStore(ctx.SerializedState, ctx.JsonSerializerOptions) + .WithAIContextProviderMessageRemoval(), }); AgentThread thread = agent.GetNewThread(); diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs index e9794e871a..280c84dc0d 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs @@ -89,24 +89,7 @@ public VectorChatMessageStore(VectorStore vectorStore, JsonElement serializedSto public string? ThreadDbKey { get; private set; } - public override async Task AddMessagesAsync(IEnumerable messages, CancellationToken cancellationToken = default) - { - this.ThreadDbKey ??= Guid.NewGuid().ToString("N"); - - var collection = this._vectorStore.GetCollection("ChatHistory"); - await collection.EnsureCollectionExistsAsync(cancellationToken); - - await collection.UpsertAsync(messages.Select(x => new ChatHistoryItem() - { - Key = this.ThreadDbKey + x.MessageId, - Timestamp = DateTimeOffset.UtcNow, - ThreadId = this.ThreadDbKey, - SerializedMessage = JsonSerializer.Serialize(x), - MessageText = x.Text - }), cancellationToken); - } - - public override async Task> GetMessagesAsync(CancellationToken cancellationToken = default) + public override async ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) { var collection = this._vectorStore.GetCollection("ChatHistory"); await collection.EnsureCollectionExistsAsync(cancellationToken); @@ -124,6 +107,33 @@ public override async Task> GetMessagesAsync(Cancellati return messages; } + public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + { + // Don't store messages if the request failed. + if (context.InvokeException is not null) + { + return; + } + + this.ThreadDbKey ??= Guid.NewGuid().ToString("N"); + + var collection = this._vectorStore.GetCollection("ChatHistory"); + await collection.EnsureCollectionExistsAsync(cancellationToken); + + // Add both request and response messages to the store + // Optionally messages produced by the AIContextProvider can also be persisted (not shown). + var allNewMessages = context.RequestMessages.Concat(context.AIContextProviderMessages ?? []).Concat(context.ResponseMessages ?? []); + + await collection.UpsertAsync(allNewMessages.Select(x => new ChatHistoryItem() + { + Key = this.ThreadDbKey + x.MessageId, + Timestamp = DateTimeOffset.UtcNow, + ThreadId = this.ThreadDbKey, + SerializedMessage = JsonSerializer.Serialize(x), + MessageText = x.Text + }), cancellationToken); + } + public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) => // We have to serialize the thread id, so that on deserialization we can retrieve the messages using the same thread id. JsonSerializer.SerializeToElement(this.ThreadDbKey); diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageStore.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageStore.cs index 9f89031464..d28cd191b7 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageStore.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageStore.cs @@ -32,8 +32,9 @@ namespace Microsoft.Agents.AI; public abstract class ChatMessageStore { /// - /// Asynchronously retrieves all messages from the store that should be provided as context for the next agent invocation. + /// Called at the start of agent invocation to retrieve all messages from the store that should be provided as context for the next agent invocation. /// + /// Contains the request context including the caller provided messages that will be used by the agent for this invocation. /// The to monitor for cancellation requests. The default is . /// /// A task that represents the asynchronous operation. The task result contains a collection of @@ -59,20 +60,19 @@ public abstract class ChatMessageStore /// and context management. /// /// - public abstract Task> GetMessagesAsync(CancellationToken cancellationToken = default); + public abstract ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default); /// - /// Asynchronously adds new messages to the store. + /// Called at the end of the agent invocation to add new messages to the store. /// - /// The collection of chat messages to add to the store. + /// Contains the invocation context including request messages, response messages, and any exception that occurred. /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous add operation. - /// is . /// /// /// Messages should be added in the order they were generated to maintain proper chronological sequence. /// The store is responsible for preserving message ordering and ensuring that subsequent calls to - /// return messages in the correct chronological order. + /// return messages in the correct chronological order. /// /// /// Implementations may perform additional processing during message addition, such as: @@ -83,8 +83,12 @@ public abstract class ChatMessageStore /// Updating indices or search capabilities /// /// + /// + /// This method is called regardless of whether the invocation succeeded or failed. + /// To check if the invocation was successful, inspect the property. + /// /// - public abstract Task AddMessagesAsync(IEnumerable messages, CancellationToken cancellationToken = default); + public abstract ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default); /// /// Serializes the current object's state to a using the specified serialization options. @@ -121,4 +125,100 @@ public abstract class ChatMessageStore /// public TService? GetService(object? serviceKey = null) => this.GetService(typeof(TService), serviceKey) is TService service ? service : default; + + /// + /// Contains the context information provided to . + /// + /// + /// This class provides context about the invocation before the messages are retrieved from the store, + /// including the new messages that will be used. Stores can use this information to determine what + /// messages should be retrieved for the invocation. + /// + public sealed class InvokingContext + { + /// + /// Initializes a new instance of the class with the specified request messages. + /// + /// The new messages to be used by the agent for this invocation. + /// is . + public InvokingContext(IEnumerable requestMessages) + { + this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages)); + } + + /// + /// Gets the caller provided messages that will be used by the agent for this invocation. + /// + /// + /// A collection of instances representing new messages that were provided by the caller. + /// + public IEnumerable RequestMessages { get; } + } + + /// + /// Contains the context information provided to . + /// + /// + /// This class provides context about a completed agent invocation, including both the + /// request messages that were used and the response messages that were generated. It also indicates + /// whether the invocation succeeded or failed. + /// + public sealed class InvokedContext + { + /// + /// Initializes a new instance of the class with the specified request messages. + /// + /// The caller provided messages that were used by the agent for this invocation. + /// The messages retrieved from the for this invocation. + /// is . + public InvokedContext(IEnumerable requestMessages, IEnumerable chatMessageStoreMessages) + { + this.RequestMessages = Throw.IfNull(requestMessages); + this.ChatMessageStoreMessages = chatMessageStoreMessages; + } + + /// + /// Gets the caller provided messages that were used by the agent for this invocation. + /// + /// + /// A collection of instances representing new messages that were provided by the caller. + /// This does not include any supplied messages. + /// + public IEnumerable RequestMessages { get; } + + /// + /// Gets the messages retrieved from the for this invocation, if any. + /// + /// + /// A collection of instances that were retrieved from the , + /// and were used by the agent as part of the invocation. + /// + public IEnumerable ChatMessageStoreMessages { get; } + + /// + /// Gets or sets the messages provided by the for this invocation, if any. + /// + /// + /// A collection of instances that were provided by the , + /// and were used by the agent as part of the invocation. + /// + public IEnumerable? AIContextProviderMessages { get; set; } + + /// + /// Gets the collection of response messages generated during this invocation if the invocation succeeded. + /// + /// + /// A collection of instances representing the response, + /// or if the invocation failed or did not produce response messages. + /// + public IEnumerable? ResponseMessages { get; set; } + + /// + /// Gets the that was thrown during the invocation, if the invocation failed. + /// + /// + /// The exception that caused the invocation to fail, or if the invocation succeeded. + /// + public Exception? InvokeException { get; set; } + } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageStoreExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageStoreExtensions.cs new file mode 100644 index 0000000000..a205fc1d9e --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageStoreExtensions.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI; + +/// +/// Contains extension methods for the class. +/// +public static class ChatMessageStoreExtensions +{ + /// + /// Adds message filtering to an existing store, so that messages passed to the store and messages produced by the store + /// can be filtered, updated or replaced. + /// + /// The store to add the message filter to. + /// An optional filter function to apply to messages produced by the store. If null, no filter is applied at this + /// stage. + /// An optional filter function to apply to the invoked context messages before they are passed to the store. If null, no + /// filter is applied at this stage. + /// The with filtering applied. + public static ChatMessageStore WithMessageFilters( + this ChatMessageStore store, + Func, IEnumerable>? invokingMessagesFilter = null, + Func? invokedMessagesFilter = null) + { + return new ChatMessageStoreMessageFilter( + innerChatMessageStore: store, + invokingMessagesFilter: invokingMessagesFilter, + invokedMessagesFilter: invokedMessagesFilter); + } + + /// + /// Decorates the provided chat message store so that it does not store messages produced by any . + /// + /// The store to add the message filter to. + /// A new instance that filters out messages so they do not get stored. + public static ChatMessageStore WithAIContextProviderMessageRemoval(this ChatMessageStore store) + { + return new ChatMessageStoreMessageFilter( + innerChatMessageStore: store, + invokedMessagesFilter: (ctx) => + { + ctx.AIContextProviderMessages = null; + return ctx; + }); + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageStoreMessageFilter.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageStoreMessageFilter.cs new file mode 100644 index 0000000000..e58f233067 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageStoreMessageFilter.cs @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI; + +/// +/// A decorator that allows filtering the messages +/// passed into and out of an inner . +/// +public sealed class ChatMessageStoreMessageFilter : ChatMessageStore +{ + private readonly ChatMessageStore _innerChatMessageStore; + private readonly Func, IEnumerable>? _invokingMessagesFilter; + private readonly Func? _invokedMessagesFilter; + + /// + /// Initializes a new instance of the class. + /// + /// Use this constructor to customize how messages are filtered before and after invocation by + /// providing appropriate filter functions. If no filters are provided, the message store operates without + /// additional filtering. + /// The underlying chat message store to be wrapped. Cannot be null. + /// An optional filter function to apply to messages before they are invoked. If null, no filter is applied at this + /// stage. + /// An optional filter function to apply to the invocation context after messages have been invoked. If null, no + /// filter is applied at this stage. + /// Thrown if innerChatMessageStore is null. + public ChatMessageStoreMessageFilter( + ChatMessageStore innerChatMessageStore, + Func, IEnumerable>? invokingMessagesFilter = null, + Func? invokedMessagesFilter = null) + { + this._innerChatMessageStore = Throw.IfNull(innerChatMessageStore); + + if (invokingMessagesFilter == null && invokedMessagesFilter == null) + { + throw new ArgumentException("At least one filter function, invokingMessagesFilter or invokedMessagesFilter, must be provided."); + } + + this._invokingMessagesFilter = invokingMessagesFilter; + this._invokedMessagesFilter = invokedMessagesFilter; + } + + /// + public override async ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + { + var messages = await this._innerChatMessageStore.InvokingAsync(context, cancellationToken).ConfigureAwait(false); + return this._invokingMessagesFilter != null ? this._invokingMessagesFilter(messages) : messages; + } + + /// + public override ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + { + if (this._invokedMessagesFilter != null) + { + context = this._invokedMessagesFilter(context); + } + + return this._innerChatMessageStore.InvokedAsync(context, cancellationToken); + } + + /// + public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + { + return this._innerChatMessageStore.Serialize(jsonSerializerOptions); + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatMessageStore.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatMessageStore.cs index 79d303207c..f7f4522f8f 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatMessageStore.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatMessageStore.cs @@ -134,27 +134,36 @@ public ChatMessage this[int index] } /// - public override async Task AddMessagesAsync(IEnumerable messages, CancellationToken cancellationToken = default) + public override async ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(messages); + _ = Throw.IfNull(context); - this._messages.AddRange(messages); - - if (this.ReducerTriggerEvent is ChatReducerTriggerEvent.AfterMessageAdded && this.ChatReducer is not null) + if (this.ReducerTriggerEvent is ChatReducerTriggerEvent.BeforeMessagesRetrieval && this.ChatReducer is not null) { this._messages = (await this.ChatReducer.ReduceAsync(this._messages, cancellationToken).ConfigureAwait(false)).ToList(); } + + return this._messages; } /// - public override async Task> GetMessagesAsync(CancellationToken cancellationToken = default) + public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) { - if (this.ReducerTriggerEvent is ChatReducerTriggerEvent.BeforeMessagesRetrieval && this.ChatReducer is not null) + _ = Throw.IfNull(context); + + if (context.InvokeException is not null) { - this._messages = (await this.ChatReducer.ReduceAsync(this._messages, cancellationToken).ConfigureAwait(false)).ToList(); + return; } - return this._messages; + // Add request, AI context provider, and response messages to the store + var allNewMessages = context.RequestMessages.Concat(context.AIContextProviderMessages ?? []).Concat(context.ResponseMessages ?? []); + this._messages.AddRange(allNewMessages); + + if (this.ReducerTriggerEvent is ChatReducerTriggerEvent.AfterMessageAdded && this.ChatReducer is not null) + { + this._messages = (await this.ChatReducer.ReduceAsync(this._messages, cancellationToken).ConfigureAwait(false)).ToList(); + } } /// @@ -221,7 +230,7 @@ public enum ChatReducerTriggerEvent { /// /// Trigger the reducer when a new message is added. - /// will only complete when reducer processing is done. + /// will only complete when reducer processing is done. /// AfterMessageAdded, diff --git a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatMessageStore.cs b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatMessageStore.cs index 03334d90f9..5c2c23ff9e 100644 --- a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatMessageStore.cs +++ b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatMessageStore.cs @@ -287,7 +287,7 @@ public static CosmosChatMessageStore CreateFromSerializedState(CosmosClient cosm } /// - public override async Task> GetMessagesAsync(CancellationToken cancellationToken = default) + public override async ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) { #pragma warning disable CA1513 // Use ObjectDisposedException.ThrowIf - not available on all target frameworks if (this._disposed) @@ -347,11 +347,14 @@ public override async Task> GetMessagesAsync(Cancellati } /// - public override async Task AddMessagesAsync(IEnumerable messages, CancellationToken cancellationToken = default) + public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) { - if (messages is null) + Throw.IfNull(context); + + if (context.InvokeException is not null) { - throw new ArgumentNullException(nameof(messages)); + // Do not store messages if there was an exception during invocation + return; } #pragma warning disable CA1513 // Use ObjectDisposedException.ThrowIf - not available on all target frameworks @@ -361,7 +364,7 @@ public override async Task AddMessagesAsync(IEnumerable messages, C } #pragma warning restore CA1513 - var messageList = messages as IReadOnlyCollection ?? messages.ToList(); + var messageList = context.RequestMessages.Concat(context.AIContextProviderMessages ?? []).Concat(context.ResponseMessages ?? []).ToList(); if (messageList.Count == 0) { return; @@ -381,7 +384,7 @@ public override async Task AddMessagesAsync(IEnumerable messages, C /// /// Adds multiple messages using transactional batch operations for atomicity. /// - private async Task AddMessagesInBatchAsync(IReadOnlyCollection messages, CancellationToken cancellationToken) + private async Task AddMessagesInBatchAsync(List messages, CancellationToken cancellationToken) { var currentTimestamp = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs index 7c0479b85e..8e8012f5bb 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs @@ -66,7 +66,7 @@ private async ValueTask ValidateWorkflowAsync() public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) => new WorkflowThread(this._workflow, serializedThread, this._executionEnvironment, this._checkpointManager, jsonSerializerOptions); - private async ValueTask UpdateThreadAsync(IEnumerable messages, AgentThread? thread = null, CancellationToken cancellationToken = default) + private ValueTask UpdateThreadAsync(IEnumerable messages, AgentThread? thread = null, CancellationToken cancellationToken = default) { thread ??= this.GetNewThread(); @@ -75,8 +75,10 @@ private async ValueTask UpdateThreadAsync(IEnumerable(workflowThread); } protected override async diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowMessageStore.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowMessageStore.cs index 39c83bcadf..87cef04e76 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowMessageStore.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowMessageStore.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Linq; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -45,14 +46,21 @@ internal sealed class StoreState internal void AddMessages(params IEnumerable messages) => this._chatMessages.AddRange(messages); - public override Task AddMessagesAsync(IEnumerable messages, CancellationToken cancellationToken = default) + public override ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + => new(this._chatMessages.AsReadOnly()); + + public override ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) { - this._chatMessages.AddRange(messages); + if (context.InvokeException is not null) + { + return default; + } - return Task.CompletedTask; - } + var allNewMessages = context.RequestMessages.Concat(context.AIContextProviderMessages ?? []).Concat(context.ResponseMessages ?? []); + this._chatMessages.AddRange(allNewMessages); - public override Task> GetMessagesAsync(CancellationToken cancellationToken = default) => Task.FromResult>(this._chatMessages.AsReadOnly()); + return default; + } public IEnumerable GetFromBookmark() { diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index f4a7fcd9c2..9c5858b8e2 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -201,7 +201,7 @@ protected override async IAsyncEnumerable RunCoreStreami { var inputMessages = Throw.IfNull(messages) as IReadOnlyCollection ?? messages.ToList(); - (ChatClientAgentThread safeThread, ChatOptions? chatOptions, List inputMessagesForChatClient, IList? aiContextProviderMessages) = + (ChatClientAgentThread safeThread, ChatOptions? chatOptions, List inputMessagesForChatClient, IList? aiContextProviderMessages, IList? chatMessageStoreMessages) = await this.PrepareThreadAndMessagesAsync(thread, inputMessages, options, cancellationToken).ConfigureAwait(false); ValidateStreamResumptionAllowed(chatOptions?.ContinuationToken, safeThread); @@ -225,6 +225,7 @@ protected override async IAsyncEnumerable RunCoreStreami } catch (Exception ex) { + await NotifyMessageStoreOfFailureAsync(safeThread, ex, inputMessages, chatMessageStoreMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false); await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false); throw; } @@ -239,6 +240,7 @@ protected override async IAsyncEnumerable RunCoreStreami } catch (Exception ex) { + await NotifyMessageStoreOfFailureAsync(safeThread, ex, inputMessages, chatMessageStoreMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false); await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false); throw; } @@ -260,6 +262,7 @@ protected override async IAsyncEnumerable RunCoreStreami } catch (Exception ex) { + await NotifyMessageStoreOfFailureAsync(safeThread, ex, inputMessages, chatMessageStoreMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false); await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false); throw; } @@ -272,7 +275,7 @@ protected override async IAsyncEnumerable RunCoreStreami this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId); // To avoid inconsistent state we only notify the thread of the input messages if no error occurs after the initial request. - await NotifyMessageStoreOfNewMessagesAsync(safeThread, inputMessages.Concat(aiContextProviderMessages ?? []).Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false); + await NotifyMessageStoreOfNewMessagesAsync(safeThread, inputMessages, chatMessageStoreMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); // Notify the AIContextProvider of all new messages. await NotifyAIContextProviderOfSuccessAsync(safeThread, inputMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); @@ -379,7 +382,7 @@ private async Task RunCoreAsync ?? messages.ToList(); - (ChatClientAgentThread safeThread, ChatOptions? chatOptions, List inputMessagesForChatClient, IList? aiContextProviderMessages) = + (ChatClientAgentThread safeThread, ChatOptions? chatOptions, List inputMessagesForChatClient, IList? aiContextProviderMessages, IList? chatMessageStoreMessages) = await this.PrepareThreadAndMessagesAsync(thread, inputMessages, options, cancellationToken).ConfigureAwait(false); var chatClient = this.ChatClient; @@ -398,6 +401,7 @@ private async Task RunCoreAsync RunCoreAsyncOptional parameters for agent invocation. /// The to monitor for cancellation requests. The default is . /// A tuple containing the thread, chat options, and thread messages. - private async Task<(ChatClientAgentThread AgentThread, ChatOptions? ChatOptions, List InputMessagesForChatClient, IList? AIContextProviderMessages)> PrepareThreadAndMessagesAsync( + private async Task + <( + ChatClientAgentThread AgentThread, + ChatOptions? ChatOptions, + List InputMessagesForChatClient, + IList? AIContextProviderMessages, + IList? ChatMessageStoreMessages + )> PrepareThreadAndMessagesAsync( AgentThread? thread, IEnumerable inputMessages, AgentRunOptions? runOptions, @@ -637,6 +648,7 @@ await thread.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvider List inputMessagesForChatClient = []; IList? aiContextProviderMessages = null; + IList? chatMessageStoreMessages = null; // Populate the thread messages only if we are not continuing an existing response as it's not allowed if (chatOptions?.ContinuationToken is null) @@ -644,7 +656,10 @@ await thread.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvider // Add any existing messages from the thread to the messages to be sent to the chat client. if (typedThread.MessageStore is not null) { - inputMessagesForChatClient.AddRange(await typedThread.MessageStore.GetMessagesAsync(cancellationToken).ConfigureAwait(false)); + var invokingContext = new ChatMessageStore.InvokingContext(inputMessages); + var storeMessages = await typedThread.MessageStore.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); + inputMessagesForChatClient.AddRange(storeMessages); + chatMessageStoreMessages = storeMessages as IList ?? storeMessages.ToList(); } // If we have an AIContextProvider, we should get context from it, and update our @@ -698,7 +713,7 @@ await thread.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvider chatOptions.ConversationId = typedThread.ConversationId; } - return (typedThread, chatOptions, inputMessagesForChatClient, aiContextProviderMessages); + return (typedThread, chatOptions, inputMessagesForChatClient, aiContextProviderMessages, chatMessageStoreMessages); } private void UpdateThreadWithTypeAndConversationId(ChatClientAgentThread thread, string? responseConversationId) @@ -725,7 +740,13 @@ private void UpdateThreadWithTypeAndConversationId(ChatClientAgentThread thread, } } - private static Task NotifyMessageStoreOfNewMessagesAsync(ChatClientAgentThread thread, IEnumerable newMessages, CancellationToken cancellationToken) + private static Task NotifyMessageStoreOfFailureAsync( + ChatClientAgentThread thread, + Exception ex, + IEnumerable requestMessages, + IEnumerable? chatMessageStoreMessages, + IEnumerable? aiContextProviderMessages, + CancellationToken cancellationToken) { var messageStore = thread.MessageStore; @@ -733,7 +754,38 @@ private static Task NotifyMessageStoreOfNewMessagesAsync(ChatClientAgentThread t // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. if (messageStore is not null) { - return messageStore.AddMessagesAsync(newMessages, cancellationToken); + var invokedContext = new ChatMessageStore.InvokedContext(requestMessages, chatMessageStoreMessages!) + { + AIContextProviderMessages = aiContextProviderMessages, + InvokeException = ex + }; + + return messageStore.InvokedAsync(invokedContext, cancellationToken).AsTask(); + } + + return Task.CompletedTask; + } + + private static Task NotifyMessageStoreOfNewMessagesAsync( + ChatClientAgentThread thread, + IEnumerable requestMessages, + IEnumerable? chatMessageStoreMessages, + IEnumerable? aiContextProviderMessages, + IEnumerable responseMessages, + CancellationToken cancellationToken) + { + var messageStore = thread.MessageStore; + + // Only notify the message store if we have one. + // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. + if (messageStore is not null) + { + var invokedContext = new ChatMessageStore.InvokedContext(requestMessages, chatMessageStoreMessages!) + { + AIContextProviderMessages = aiContextProviderMessages, + ResponseMessages = responseMessages + }; + return messageStore.InvokedAsync(invokedContext, cancellationToken).AsTask(); } return Task.CompletedTask; diff --git a/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs b/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs index 72c0b14ae2..2bec0b366e 100644 --- a/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs +++ b/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs @@ -39,7 +39,12 @@ public async Task> GetChatHistoryAsync(AgentThread thread) { var typedThread = (ChatClientAgentThread)thread; - return typedThread.MessageStore is null ? [] : (await typedThread.MessageStore.GetMessagesAsync()).ToList(); + if (typedThread.MessageStore is null) + { + return []; + } + + return (await typedThread.MessageStore.InvokingAsync(new([]))).ToList(); } public Task CreateChatClientAgentAsync( diff --git a/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs b/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs index 883b317f5e..ddb015eb17 100644 --- a/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs +++ b/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs @@ -48,7 +48,12 @@ public async Task> GetChatHistoryAsync(AgentThread thread) return await this.GetChatHistoryFromResponsesChainAsync(chatClientThread.ConversationId); } - return chatClientThread.MessageStore is null ? [] : (await chatClientThread.MessageStore.GetMessagesAsync()).ToList(); + if (chatClientThread.MessageStore is null) + { + return []; + } + + return (await chatClientThread.MessageStore.InvokingAsync(new([]))).ToList(); } private async Task> GetChatHistoryFromResponsesChainAsync(string conversationId) diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageStoreMessageFilterTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageStoreMessageFilterTests.cs new file mode 100644 index 0000000000..ab10c377ae --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageStoreMessageFilterTests.cs @@ -0,0 +1,205 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Moq; + +namespace Microsoft.Agents.AI.Abstractions.UnitTests; + +/// +/// Contains tests for the class. +/// +public sealed class ChatMessageStoreMessageFilterTests +{ + [Fact] + public void Constructor_WithNullInnerStore_ThrowsArgumentNullException() + { + // Arrange, Act & Assert + Assert.Throws(() => new ChatMessageStoreMessageFilter(null!)); + } + + [Fact] + public void Constructor_WithOnlyInnerStore_Throws() + { + // Arrange + var innerStoreMock = new Mock(); + + // Act & Assert + Assert.Throws(() => new ChatMessageStoreMessageFilter(innerStoreMock.Object)); + } + + [Fact] + public void Constructor_WithAllParameters_CreatesInstance() + { + // Arrange + var innerStoreMock = new Mock(); + + IEnumerable InvokingFilter(IEnumerable msgs) => msgs; + ChatMessageStore.InvokedContext InvokedFilter(ChatMessageStore.InvokedContext ctx) => ctx; + + // Act + var filter = new ChatMessageStoreMessageFilter(innerStoreMock.Object, InvokingFilter, InvokedFilter); + + // Assert + Assert.NotNull(filter); + } + + [Fact] + public async Task InvokingAsync_WithNoOpFilters_ReturnsInnerStoreMessagesAsync() + { + // Arrange + var innerStoreMock = new Mock(); + var expectedMessages = new List + { + new(ChatRole.User, "Hello"), + new(ChatRole.Assistant, "Hi there!") + }; + var context = new ChatMessageStore.InvokingContext([new ChatMessage(ChatRole.User, "Test")]); + + innerStoreMock + .Setup(s => s.InvokingAsync(context, It.IsAny())) + .ReturnsAsync(expectedMessages); + + var filter = new ChatMessageStoreMessageFilter(innerStoreMock.Object, x => x, x => x); + + // Act + var result = (await filter.InvokingAsync(context, CancellationToken.None)).ToList(); + + // Assert + Assert.Equal(2, result.Count); + Assert.Equal("Hello", result[0].Text); + Assert.Equal("Hi there!", result[1].Text); + innerStoreMock.Verify(s => s.InvokingAsync(context, It.IsAny()), Times.Once); + } + + [Fact] + public async Task InvokingAsync_WithInvokingFilter_AppliesFilterAsync() + { + // Arrange + var innerStoreMock = new Mock(); + var innerMessages = new List + { + new(ChatRole.User, "Hello"), + new(ChatRole.Assistant, "Hi there!"), + new(ChatRole.User, "How are you?") + }; + var context = new ChatMessageStore.InvokingContext([new ChatMessage(ChatRole.User, "Test")]); + + innerStoreMock + .Setup(s => s.InvokingAsync(context, It.IsAny())) + .ReturnsAsync(innerMessages); + + // Filter to only user messages + IEnumerable InvokingFilter(IEnumerable msgs) => msgs.Where(m => m.Role == ChatRole.User); + + var filter = new ChatMessageStoreMessageFilter(innerStoreMock.Object, InvokingFilter); + + // Act + var result = (await filter.InvokingAsync(context, CancellationToken.None)).ToList(); + + // Assert + Assert.Equal(2, result.Count); + Assert.All(result, msg => Assert.Equal(ChatRole.User, msg.Role)); + innerStoreMock.Verify(s => s.InvokingAsync(context, It.IsAny()), Times.Once); + } + + [Fact] + public async Task InvokingAsync_WithInvokingFilter_CanModifyMessagesAsync() + { + // Arrange + var innerStoreMock = new Mock(); + var innerMessages = new List + { + new(ChatRole.User, "Hello"), + new(ChatRole.Assistant, "Hi there!") + }; + var context = new ChatMessageStore.InvokingContext([new ChatMessage(ChatRole.User, "Test")]); + + innerStoreMock + .Setup(s => s.InvokingAsync(context, It.IsAny())) + .ReturnsAsync(innerMessages); + + // Filter that transforms messages + IEnumerable InvokingFilter(IEnumerable msgs) => + msgs.Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}")); + + var filter = new ChatMessageStoreMessageFilter(innerStoreMock.Object, InvokingFilter); + + // Act + var result = (await filter.InvokingAsync(context, CancellationToken.None)).ToList(); + + // Assert + Assert.Equal(2, result.Count); + Assert.Equal("[FILTERED] Hello", result[0].Text); + Assert.Equal("[FILTERED] Hi there!", result[1].Text); + } + + [Fact] + public async Task InvokedAsync_WithInvokedFilter_AppliesFilterAsync() + { + // Arrange + var innerStoreMock = new Mock(); + var requestMessages = new List { new(ChatRole.User, "Hello") }; + var chatMessageStoreMessages = new List { new(ChatRole.System, "System") }; + var responseMessages = new List { new(ChatRole.Assistant, "Response") }; + var context = new ChatMessageStore.InvokedContext(requestMessages, chatMessageStoreMessages) + { + ResponseMessages = responseMessages + }; + + ChatMessageStore.InvokedContext? capturedContext = null; + innerStoreMock + .Setup(s => s.InvokedAsync(It.IsAny(), It.IsAny())) + .Callback((ctx, ct) => capturedContext = ctx) + .Returns(default(ValueTask)); + + // Filter that modifies the context + ChatMessageStore.InvokedContext InvokedFilter(ChatMessageStore.InvokedContext ctx) + { + var modifiedRequestMessages = ctx.RequestMessages.Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}")).ToList(); + return new ChatMessageStore.InvokedContext(modifiedRequestMessages, ctx.ChatMessageStoreMessages) + { + ResponseMessages = ctx.ResponseMessages, + AIContextProviderMessages = ctx.AIContextProviderMessages, + InvokeException = ctx.InvokeException + }; + } + + var filter = new ChatMessageStoreMessageFilter(innerStoreMock.Object, invokedMessagesFilter: InvokedFilter); + + // Act + await filter.InvokedAsync(context, CancellationToken.None); + + // Assert + Assert.NotNull(capturedContext); + Assert.Single(capturedContext.RequestMessages); + Assert.Equal("[FILTERED] Hello", capturedContext.RequestMessages.First().Text); + innerStoreMock.Verify(s => s.InvokedAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public void Serialize_DelegatesToInnerStore() + { + // Arrange + var innerStoreMock = new Mock(); + var expectedJson = JsonSerializer.SerializeToElement("data", TestJsonSerializerContext.Default.String); + + innerStoreMock + .Setup(s => s.Serialize(It.IsAny())) + .Returns(expectedJson); + + var filter = new ChatMessageStoreMessageFilter(innerStoreMock.Object, x => x, x => x); + + // Act + var result = filter.Serialize(); + + // Assert + Assert.Equal(expectedJson.GetRawText(), result.GetRawText()); + innerStoreMock.Verify(s => s.Serialize(null), Times.Once); + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageStoreTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageStoreTests.cs index 4100b20f5a..883941458c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageStoreTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageStoreTests.cs @@ -78,11 +78,11 @@ public void GetService_Generic_ReturnsNullForUnrelatedType() private sealed class TestChatMessageStore : ChatMessageStore { - public override Task> GetMessagesAsync(CancellationToken cancellationToken = default) - => Task.FromResult>([]); + public override ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + => new(Array.Empty()); - public override Task AddMessagesAsync(IEnumerable messages, CancellationToken cancellationToken = default) - => Task.CompletedTask; + public override ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + => default; public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) => default; diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatMessageStoreTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatMessageStoreTests.cs index 824fb62f6d..43bfacca79 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatMessageStoreTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatMessageStoreTests.cs @@ -47,34 +47,54 @@ public void Constructor_Arguments_SetOnPropertiesCorrectly() } [Fact] - public async Task AddMessagesAsyncAddsMessagesAndReturnsNullThreadIdAsync() + public async Task InvokedAsyncAddsMessagesAsync() { - var store = new InMemoryChatMessageStore(); - var messages = new List + var requestMessages = new List + { + new(ChatRole.User, "Hello") + }; + var responseMessages = new List { - new(ChatRole.User, "Hello"), new(ChatRole.Assistant, "Hi there!") }; + var messageStoreMessages = new List() + { + new(ChatRole.System, "original instructions") + }; + var aiContextProviderMessages = new List() + { + new(ChatRole.System, "additional context") + }; - await store.AddMessagesAsync(messages, CancellationToken.None); + var store = new InMemoryChatMessageStore(); + store.Add(messageStoreMessages[0]); + var context = new ChatMessageStore.InvokedContext(requestMessages, messageStoreMessages) + { + AIContextProviderMessages = aiContextProviderMessages, + ResponseMessages = responseMessages + }; + await store.InvokedAsync(context, CancellationToken.None); - Assert.Equal(2, store.Count); - Assert.Equal("Hello", store[0].Text); - Assert.Equal("Hi there!", store[1].Text); + Assert.Equal(4, store.Count); + Assert.Equal("original instructions", store[0].Text); + Assert.Equal("Hello", store[1].Text); + Assert.Equal("additional context", store[2].Text); + Assert.Equal("Hi there!", store[3].Text); } [Fact] - public async Task AddMessagesAsyncWithEmptyDoesNotFailAsync() + public async Task InvokedAsyncWithEmptyDoesNotFailAsync() { var store = new InMemoryChatMessageStore(); - await store.AddMessagesAsync([], CancellationToken.None); + var context = new ChatMessageStore.InvokedContext([], []); + await store.InvokedAsync(context, CancellationToken.None); Assert.Empty(store); } [Fact] - public async Task GetMessagesAsyncReturnsAllMessagesAsync() + public async Task InvokingAsyncReturnsAllMessagesAsync() { var store = new InMemoryChatMessageStore { @@ -82,7 +102,8 @@ public async Task GetMessagesAsyncReturnsAllMessagesAsync() new ChatMessage(ChatRole.Assistant, "Test2") }; - var result = (await store.GetMessagesAsync(CancellationToken.None)).ToList(); + var context = new ChatMessageStore.InvokingContext([]); + var result = (await store.InvokingAsync(context, CancellationToken.None)).ToList(); Assert.Equal(2, result.Count); Assert.Contains(result, m => m.Text == "Test1"); @@ -157,24 +178,25 @@ public async Task SerializeAndDeserializeWorksWithExperimentalContentTypesAsync( } [Fact] - public async Task AddMessagesAsyncWithEmptyMessagesDoesNotChangeStoreAsync() + public async Task InvokedAsyncWithEmptyMessagesDoesNotChangeStoreAsync() { var store = new InMemoryChatMessageStore(); var messages = new List(); - await store.AddMessagesAsync(messages, CancellationToken.None); + var context = new ChatMessageStore.InvokedContext(messages, []); + await store.InvokedAsync(context, CancellationToken.None); Assert.Empty(store); } [Fact] - public async Task AddMessagesAsync_WithNullMessages_ThrowsArgumentNullExceptionAsync() + public async Task InvokedAsync_WithNullContext_ThrowsArgumentNullExceptionAsync() { // Arrange var store = new InMemoryChatMessageStore(); // Act & Assert - await Assert.ThrowsAsync(() => store.AddMessagesAsync(null!, CancellationToken.None)); + await Assert.ThrowsAsync(() => store.InvokedAsync(null!, CancellationToken.None).AsTask()); } [Fact] @@ -498,7 +520,8 @@ public async Task AddMessagesAsync_WithReducer_AfterMessageAdded_InvokesReducerA var store = new InMemoryChatMessageStore(reducerMock.Object, InMemoryChatMessageStore.ChatReducerTriggerEvent.AfterMessageAdded); // Act - await store.AddMessagesAsync(originalMessages, CancellationToken.None); + var context = new ChatMessageStore.InvokedContext(originalMessages, []); + await store.InvokedAsync(context, CancellationToken.None); // Assert Assert.Single(store); @@ -526,10 +549,15 @@ public async Task GetMessagesAsync_WithReducer_BeforeMessagesRetrieval_InvokesRe .ReturnsAsync(reducedMessages); var store = new InMemoryChatMessageStore(reducerMock.Object, InMemoryChatMessageStore.ChatReducerTriggerEvent.BeforeMessagesRetrieval); - await store.AddMessagesAsync(originalMessages, CancellationToken.None); + // Add messages directly to the store for this test + foreach (var msg in originalMessages) + { + store.Add(msg); + } // Act - var result = (await store.GetMessagesAsync(CancellationToken.None)).ToList(); + var invokingContext = new ChatMessageStore.InvokingContext(Array.Empty()); + var result = (await store.InvokingAsync(invokingContext, CancellationToken.None)).ToList(); // Assert Assert.Single(result); @@ -551,7 +579,8 @@ public async Task AddMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu var store = new InMemoryChatMessageStore(reducerMock.Object, InMemoryChatMessageStore.ChatReducerTriggerEvent.BeforeMessagesRetrieval); // Act - await store.AddMessagesAsync(originalMessages, CancellationToken.None); + var context = new ChatMessageStore.InvokedContext(originalMessages, []); + await store.InvokedAsync(context, CancellationToken.None); // Assert Assert.Single(store); @@ -576,7 +605,8 @@ public async Task GetMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu }; // Act - var result = (await store.GetMessagesAsync(CancellationToken.None)).ToList(); + var invokingContext = new ChatMessageStore.InvokingContext(Array.Empty()); + var result = (await store.InvokingAsync(invokingContext, CancellationToken.None)).ToList(); // Assert Assert.Single(result); diff --git a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatMessageStoreTests.cs b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatMessageStoreTests.cs index 3dbd3ec367..9410e68f1b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatMessageStoreTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatMessageStoreTests.cs @@ -202,11 +202,11 @@ public void Constructor_WithEmptyConversationId_ShouldThrowArgumentException() #endregion - #region AddMessagesAsync Tests + #region InvokedAsync Tests [SkippableFact] [Trait("Category", "CosmosDB")] - public async Task AddMessagesAsync_WithSingleMessage_ShouldAddMessageAsync() + public async Task InvokedAsync_WithSingleMessage_ShouldAddMessageAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); @@ -214,14 +214,20 @@ public async Task AddMessagesAsync_WithSingleMessage_ShouldAddMessageAsync() using var store = new CosmosChatMessageStore(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); var message = new ChatMessage(ChatRole.User, "Hello, world!"); + var context = new ChatMessageStore.InvokedContext([message], []) + { + ResponseMessages = [] + }; + // Act - await store.AddMessagesAsync([message]); + await store.InvokedAsync(context); // Wait a moment for eventual consistency await Task.Delay(100); // Assert - var messages = await store.GetMessagesAsync(); + var invokingContext = new ChatMessageStore.InvokingContext([]); + var messages = await store.InvokingAsync(invokingContext); var messageList = messages.ToList(); // Simple assertion - if this fails, we know the deserialization is the issue @@ -256,7 +262,7 @@ public async Task AddMessagesAsync_WithSingleMessage_ShouldAddMessageAsync() } string rawJson = rawResults.Count > 0 ? Newtonsoft.Json.JsonConvert.SerializeObject(rawResults[0], Newtonsoft.Json.Formatting.Indented) : "null"; - Assert.Fail($"GetMessagesAsync returned 0 messages, but direct count query found {count} items for conversation {conversationId}. Raw document: {rawJson}"); + Assert.Fail($"InvokingAsync returned 0 messages, but direct count query found {count} items for conversation {conversationId}. Raw document: {rawJson}"); } Assert.Single(messageList); @@ -266,45 +272,63 @@ public async Task AddMessagesAsync_WithSingleMessage_ShouldAddMessageAsync() [SkippableFact] [Trait("Category", "CosmosDB")] - public async Task AddMessagesAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() + public async Task InvokedAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); var conversationId = Guid.NewGuid().ToString(); using var store = new CosmosChatMessageStore(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); - var messages = new[] + var requestMessages = new[] { new ChatMessage(ChatRole.User, "First message"), new ChatMessage(ChatRole.Assistant, "Second message"), new ChatMessage(ChatRole.User, "Third message") }; + var aiContextProviderMessages = new[] + { + new ChatMessage(ChatRole.System, "System context message") + }; + var responseMessages = new[] + { + new ChatMessage(ChatRole.Assistant, "Response message") + }; + + var context = new ChatMessageStore.InvokedContext(requestMessages, []) + { + AIContextProviderMessages = aiContextProviderMessages, + ResponseMessages = responseMessages + }; // Act - await store.AddMessagesAsync(messages); + await store.InvokedAsync(context); // Assert - var retrievedMessages = await store.GetMessagesAsync(); + var invokingContext = new ChatMessageStore.InvokingContext([]); + var retrievedMessages = await store.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); - Assert.Equal(3, messageList.Count); + Assert.Equal(5, messageList.Count); Assert.Equal("First message", messageList[0].Text); Assert.Equal("Second message", messageList[1].Text); Assert.Equal("Third message", messageList[2].Text); + Assert.Equal("System context message", messageList[3].Text); + Assert.Equal("Response message", messageList[4].Text); } #endregion - #region GetMessagesAsync Tests + #region InvokingAsync Tests [SkippableFact] [Trait("Category", "CosmosDB")] - public async Task GetMessagesAsync_WithNoMessages_ShouldReturnEmptyAsync() + public async Task InvokingAsync_WithNoMessages_ShouldReturnEmptyAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); using var store = new CosmosChatMessageStore(this._connectionString, s_testDatabaseId, TestContainerId, Guid.NewGuid().ToString()); // Act - var messages = await store.GetMessagesAsync(); + var invokingContext = new ChatMessageStore.InvokingContext([]); + var messages = await store.InvokingAsync(invokingContext); // Assert Assert.Empty(messages); @@ -312,7 +336,7 @@ public async Task GetMessagesAsync_WithNoMessages_ShouldReturnEmptyAsync() [SkippableFact] [Trait("Category", "CosmosDB")] - public async Task GetMessagesAsync_WithConversationIsolation_ShouldOnlyReturnMessagesForConversationAsync() + public async Task InvokingAsync_WithConversationIsolation_ShouldOnlyReturnMessagesForConversationAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); @@ -322,12 +346,18 @@ public async Task GetMessagesAsync_WithConversationIsolation_ShouldOnlyReturnMes using var store1 = new CosmosChatMessageStore(this._connectionString, s_testDatabaseId, TestContainerId, conversation1); using var store2 = new CosmosChatMessageStore(this._connectionString, s_testDatabaseId, TestContainerId, conversation2); - await store1.AddMessagesAsync([new ChatMessage(ChatRole.User, "Message for conversation 1")]); - await store2.AddMessagesAsync([new ChatMessage(ChatRole.User, "Message for conversation 2")]); + var context1 = new ChatMessageStore.InvokedContext([new ChatMessage(ChatRole.User, "Message for conversation 1")], []); + var context2 = new ChatMessageStore.InvokedContext([new ChatMessage(ChatRole.User, "Message for conversation 2")], []); + + await store1.InvokedAsync(context1); + await store2.InvokedAsync(context2); // Act - var messages1 = await store1.GetMessagesAsync(); - var messages2 = await store2.GetMessagesAsync(); + var invokingContext1 = new ChatMessageStore.InvokingContext([]); + var invokingContext2 = new ChatMessageStore.InvokingContext([]); + + var messages1 = await store1.InvokingAsync(invokingContext1); + var messages2 = await store2.InvokingAsync(invokingContext2); // Assert var messageList1 = messages1.ToList(); @@ -361,16 +391,18 @@ public async Task FullWorkflow_AddAndGet_ShouldWorkCorrectlyAsync() }; // Act 1: Add messages - await originalStore.AddMessagesAsync(messages); + var invokedContext = new ChatMessageStore.InvokedContext(messages, []); + await originalStore.InvokedAsync(invokedContext); // Act 2: Verify messages were added - var retrievedMessages = await originalStore.GetMessagesAsync(); + var invokingContext = new ChatMessageStore.InvokingContext([]); + var retrievedMessages = await originalStore.InvokingAsync(invokingContext); var retrievedList = retrievedMessages.ToList(); Assert.Equal(5, retrievedList.Count); // Act 3: Create new store instance for same conversation (test persistence) using var newStore = new CosmosChatMessageStore(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); - var persistedMessages = await newStore.GetMessagesAsync(); + var persistedMessages = await newStore.InvokingAsync(invokingContext); var persistedList = persistedMessages.ToList(); // Assert final state @@ -502,7 +534,7 @@ public void Constructor_WithHierarchicalWhitespaceSessionId_ShouldThrowArgumentE [SkippableFact] [Trait("Category", "CosmosDB")] - public async Task AddMessagesAsync_WithHierarchicalPartitioning_ShouldAddMessageWithMetadataAsync() + public async Task InvokedAsync_WithHierarchicalPartitioning_ShouldAddMessageWithMetadataAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); @@ -513,14 +545,17 @@ public async Task AddMessagesAsync_WithHierarchicalPartitioning_ShouldAddMessage using var store = new CosmosChatMessageStore(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); var message = new ChatMessage(ChatRole.User, "Hello from hierarchical partitioning!"); + var context = new ChatMessageStore.InvokedContext([message], []); + // Act - await store.AddMessagesAsync([message]); + await store.InvokedAsync(context); // Wait a moment for eventual consistency await Task.Delay(100); // Assert - var messages = await store.GetMessagesAsync(); + var invokingContext = new ChatMessageStore.InvokingContext([]); + var messages = await store.InvokingAsync(invokingContext); var messageList = messages.ToList(); Assert.Single(messageList); @@ -551,7 +586,7 @@ public async Task AddMessagesAsync_WithHierarchicalPartitioning_ShouldAddMessage [SkippableFact] [Trait("Category", "CosmosDB")] - public async Task AddMessagesAsync_WithHierarchicalMultipleMessages_ShouldAddAllMessagesAsync() + public async Task InvokedAsync_WithHierarchicalMultipleMessages_ShouldAddAllMessagesAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); @@ -567,14 +602,17 @@ public async Task AddMessagesAsync_WithHierarchicalMultipleMessages_ShouldAddAll new ChatMessage(ChatRole.User, "Third hierarchical message") }; + var context = new ChatMessageStore.InvokedContext(messages, []); + // Act - await store.AddMessagesAsync(messages); + await store.InvokedAsync(context); // Wait a moment for eventual consistency await Task.Delay(100); // Assert - var retrievedMessages = await store.GetMessagesAsync(); + var invokingContext = new ChatMessageStore.InvokingContext([]); + var retrievedMessages = await store.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); Assert.Equal(3, messageList.Count); @@ -585,7 +623,7 @@ public async Task AddMessagesAsync_WithHierarchicalMultipleMessages_ShouldAddAll [SkippableFact] [Trait("Category", "CosmosDB")] - public async Task GetMessagesAsync_WithHierarchicalPartitionIsolation_ShouldIsolateMessagesByUserIdAsync() + public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolateMessagesByUserIdAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); @@ -599,17 +637,23 @@ public async Task GetMessagesAsync_WithHierarchicalPartitionIsolation_ShouldIsol using var store2 = new CosmosChatMessageStore(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId2, SessionId); // Add messages to both stores - await store1.AddMessagesAsync([new ChatMessage(ChatRole.User, "Message from user 1")]); - await store2.AddMessagesAsync([new ChatMessage(ChatRole.User, "Message from user 2")]); + var context1 = new ChatMessageStore.InvokedContext([new ChatMessage(ChatRole.User, "Message from user 1")], []); + var context2 = new ChatMessageStore.InvokedContext([new ChatMessage(ChatRole.User, "Message from user 2")], []); + + await store1.InvokedAsync(context1); + await store2.InvokedAsync(context2); // Wait a moment for eventual consistency await Task.Delay(100); // Act & Assert - var messages1 = await store1.GetMessagesAsync(); + var invokingContext1 = new ChatMessageStore.InvokingContext([]); + var invokingContext2 = new ChatMessageStore.InvokingContext([]); + + var messages1 = await store1.InvokingAsync(invokingContext1); var messageList1 = messages1.ToList(); - var messages2 = await store2.GetMessagesAsync(); + var messages2 = await store2.InvokingAsync(invokingContext2); var messageList2 = messages2.ToList(); // With true hierarchical partitioning, each user sees only their own messages @@ -630,7 +674,9 @@ public async Task SerializeDeserialize_WithHierarchicalPartitioning_ShouldPreser const string SessionId = "session-serialize"; using var originalStore = new CosmosChatMessageStore(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); - await originalStore.AddMessagesAsync([new ChatMessage(ChatRole.User, "Test serialization message")]); + + var context = new ChatMessageStore.InvokedContext([new ChatMessage(ChatRole.User, "Test serialization message")], []); + await originalStore.InvokedAsync(context); // Act - Serialize the store state var serializedState = originalStore.Serialize(); @@ -647,7 +693,8 @@ public async Task SerializeDeserialize_WithHierarchicalPartitioning_ShouldPreser await Task.Delay(100); // Assert - The deserialized store should have the same functionality - var messages = await deserializedStore.GetMessagesAsync(); + var invokingContext = new ChatMessageStore.InvokingContext([]); + var messages = await deserializedStore.InvokingAsync(invokingContext); var messageList = messages.ToList(); Assert.Single(messageList); @@ -670,17 +717,22 @@ public async Task HierarchicalAndSimplePartitioning_ShouldCoexistAsync() using var hierarchicalStore = new CosmosChatMessageStore(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-coexist", "user-coexist", SessionId); // Add messages to both - await simpleStore.AddMessagesAsync([new ChatMessage(ChatRole.User, "Simple partitioning message")]); - await hierarchicalStore.AddMessagesAsync([new ChatMessage(ChatRole.User, "Hierarchical partitioning message")]); + var simpleContext = new ChatMessageStore.InvokedContext([new ChatMessage(ChatRole.User, "Simple partitioning message")], []); + var hierarchicalContext = new ChatMessageStore.InvokedContext([new ChatMessage(ChatRole.User, "Hierarchical partitioning message")], []); + + await simpleStore.InvokedAsync(simpleContext); + await hierarchicalStore.InvokedAsync(hierarchicalContext); // Wait a moment for eventual consistency await Task.Delay(100); // Act & Assert - var simpleMessages = await simpleStore.GetMessagesAsync(); + var invokingContext = new ChatMessageStore.InvokingContext([]); + + var simpleMessages = await simpleStore.InvokingAsync(invokingContext); var simpleMessageList = simpleMessages.ToList(); - var hierarchicalMessages = await hierarchicalStore.GetMessagesAsync(); + var hierarchicalMessages = await hierarchicalStore.InvokingAsync(invokingContext); var hierarchicalMessageList = hierarchicalMessages.ToList(); // Each should only see its own messages since they use different containers @@ -707,14 +759,17 @@ public async Task MaxMessagesToRetrieve_ShouldLimitAndReturnMostRecentAsync() messages.Add(new ChatMessage(ChatRole.User, $"Message {i}")); await Task.Delay(10); // Small delay to ensure different timestamps } - await store.AddMessagesAsync(messages); + + var context = new ChatMessageStore.InvokedContext(messages, []); + await store.InvokedAsync(context); // Wait for eventual consistency await Task.Delay(100); // Act - Set max to 5 and retrieve store.MaxMessagesToRetrieve = 5; - var retrievedMessages = await store.GetMessagesAsync(); + var invokingContext = new ChatMessageStore.InvokingContext([]); + var retrievedMessages = await store.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); // Assert - Should get the 5 most recent messages (6-10) in ascending order @@ -742,13 +797,16 @@ public async Task MaxMessagesToRetrieve_Null_ShouldReturnAllMessagesAsync() { messages.Add(new ChatMessage(ChatRole.User, $"Message {i}")); } - await store.AddMessagesAsync(messages); + + var context = new ChatMessageStore.InvokedContext(messages, []); + await store.InvokedAsync(context); // Wait for eventual consistency await Task.Delay(100); // Act - No limit set (default null) - var retrievedMessages = await store.GetMessagesAsync(); + var invokingContext = new ChatMessageStore.InvokingContext([]); + var retrievedMessages = await store.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); // Assert - Should get all 10 messages diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs index 6e9d952b57..5850bc56ba 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs @@ -502,6 +502,12 @@ public async Task RunAsyncUsesChatMessageStoreFactoryWhenProvidedAndNoConversati It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); Mock mockChatMessageStore = new(); + mockChatMessageStore.Setup(s => s.InvokingAsync( + It.IsAny(), + It.IsAny())).ReturnsAsync([new ChatMessage(ChatRole.User, "Existing Chat History")]); + mockChatMessageStore.Setup(s => s.InvokedAsync( + It.IsAny(), + It.IsAny())).Returns(new ValueTask()); Mock> mockFactory = new(); mockFactory.Setup(f => f(It.IsAny())).Returns(mockChatMessageStore.Object); @@ -518,7 +524,58 @@ public async Task RunAsyncUsesChatMessageStoreFactoryWhenProvidedAndNoConversati // Assert Assert.IsType(thread!.MessageStore, exactMatch: false); - mockChatMessageStore.Verify(s => s.AddMessagesAsync(It.Is>(x => x.Count() == 2), It.IsAny()), Times.Once); + mockService.Verify( + x => x.GetResponseAsync( + It.Is>(msgs => msgs.Count() == 2 && msgs.Any(m => m.Text == "Existing Chat History") && msgs.Any(m => m.Text == "test")), + It.IsAny(), + It.IsAny()), + Times.Once); + mockChatMessageStore.Verify(s => s.InvokingAsync( + It.Is(x => x.RequestMessages.Count() == 1), + It.IsAny()), + Times.Once); + mockChatMessageStore.Verify(s => s.InvokedAsync( + It.Is(x => x.RequestMessages.Count() == 1 && x.ChatMessageStoreMessages.Count() == 1 && x.ResponseMessages!.Count() == 1), + It.IsAny()), + Times.Once); + mockFactory.Verify(f => f(It.IsAny()), Times.Once); + } + + /// + /// Verify that RunAsync notifies the ChatMessageStore on failure. + /// + [Fact] + public async Task RunAsyncNotifiesChatMessageStoreOnFailureAsync() + { + // Arrange + Mock mockService = new(); + mockService.Setup( + s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())).Throws(new InvalidOperationException("Test Error")); + + Mock mockChatMessageStore = new(); + + Mock> mockFactory = new(); + mockFactory.Setup(f => f(It.IsAny())).Returns(mockChatMessageStore.Object); + + ChatClientAgent agent = new(mockService.Object, options: new() + { + ChatOptions = new() { Instructions = "test instructions" }, + ChatMessageStoreFactory = mockFactory.Object + }); + + // Act + ChatClientAgentThread? thread = agent.GetNewThread() as ChatClientAgentThread; + await Assert.ThrowsAsync(() => agent.RunAsync([new(ChatRole.User, "test")], thread)); + + // Assert + Assert.IsType(thread!.MessageStore, exactMatch: false); + mockChatMessageStore.Verify(s => s.InvokedAsync( + It.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages == null && x.InvokeException!.Message == "Test Error"), + It.IsAny()), + Times.Once); mockFactory.Verify(f => f(It.IsAny()), Times.Once); } @@ -610,7 +667,7 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync() Assert.Contains(capturedTools, t => t.Name == "base function"); Assert.Contains(capturedTools, t => t.Name == "context provider function"); - // Verify that the thread was updated with the input, ai context and response messages + // Verify that the thread was updated with the ai context provider, input and response messages var messageStore = Assert.IsType(thread!.MessageStore); Assert.Equal(3, messageStore.Count); Assert.Equal("user message", messageStore[0].Text); @@ -2067,7 +2124,7 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync() Assert.Contains(capturedTools, t => t.Name == "base function"); Assert.Contains(capturedTools, t => t.Name == "context provider function"); - // Verify that the thread was updated with the input, ai context and response messages + // Verify that the thread was updated with the input, ai context provider, and response messages var messageStore = Assert.IsType(thread!.MessageStore); Assert.Equal(3, messageStore.Count); Assert.Equal("user message", messageStore[0].Text); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs index 583a0815ca..3bc28ee12f 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs @@ -339,7 +339,7 @@ public async Task RunAsyncSkipsThreadMessagePopulationWithContinuationTokenAsync // Create a mock message store that would normally provide messages var mockMessageStore = new Mock(); mockMessageStore - .Setup(ms => ms.GetMessagesAsync(It.IsAny())) + .Setup(ms => ms.InvokingAsync(It.IsAny(), It.IsAny())) .ReturnsAsync([new(ChatRole.User, "Message from message store")]); // Create a mock AI context provider that would normally provide context @@ -383,7 +383,7 @@ public async Task RunAsyncSkipsThreadMessagePopulationWithContinuationTokenAsync // Verify that message store was never called due to continuation token mockMessageStore.Verify( - ms => ms.GetMessagesAsync(It.IsAny()), + ms => ms.InvokingAsync(It.IsAny(), It.IsAny()), Times.Never); // Verify that AI context provider was never called due to continuation token @@ -401,7 +401,7 @@ public async Task RunStreamingAsyncSkipsThreadMessagePopulationWithContinuationT // Create a mock message store that would normally provide messages var mockMessageStore = new Mock(); mockMessageStore - .Setup(ms => ms.GetMessagesAsync(It.IsAny())) + .Setup(ms => ms.InvokingAsync(It.IsAny(), It.IsAny())) .ReturnsAsync([new(ChatRole.User, "Message from message store")]); // Create a mock AI context provider that would normally provide context @@ -446,7 +446,7 @@ public async Task RunStreamingAsyncSkipsThreadMessagePopulationWithContinuationT // Verify that message store was never called due to continuation token mockMessageStore.Verify( - ms => ms.GetMessagesAsync(It.IsAny()), + ms => ms.InvokingAsync(It.IsAny(), It.IsAny()), Times.Never); // Verify that AI context provider was never called due to continuation token diff --git a/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs b/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs index 656d310ddf..0fb9745d2d 100644 --- a/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs +++ b/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs @@ -32,7 +32,12 @@ public async Task> GetChatHistoryAsync(AgentThread thread) { var typedThread = (ChatClientAgentThread)thread; - return typedThread.MessageStore is null ? [] : (await typedThread.MessageStore.GetMessagesAsync()).ToList(); + if (typedThread.MessageStore is null) + { + return []; + } + + return (await typedThread.MessageStore.InvokingAsync(new([]))).ToList(); } public Task CreateChatClientAgentAsync( diff --git a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs index c6c84db569..c57e1c460d 100644 --- a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs +++ b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs @@ -50,7 +50,12 @@ public async Task> GetChatHistoryAsync(AgentThread thread) return [.. previousMessages, responseMessage]; } - return typedThread.MessageStore is null ? [] : (await typedThread.MessageStore.GetMessagesAsync()).ToList(); + if (typedThread.MessageStore is null) + { + return []; + } + + return (await typedThread.MessageStore.InvokingAsync(new([]))).ToList(); } private static ChatMessage ConvertToChatMessage(ResponseItem item)