diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index d04d9bb9fb..049322785c 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -289,6 +289,7 @@ public override async IAsyncEnumerable RunStreamingAsync public override AgentThread GetNewThread() => new ChatClientAgentThread { + MessageStore = this._agentOptions?.ChatMessageStoreFactory?.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }), AIContextProvider = this._agentOptions?.AIContextProviderFactory?.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }) }; @@ -316,6 +317,34 @@ public AgentThread GetNewThread(string conversationId) AIContextProvider = this._agentOptions?.AIContextProviderFactory?.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }) }; + /// + /// Creates a new agent thread instance using an existing to continue a conversation. + /// + /// The instance to use for managing the conversation's message history. + /// + /// A new instance configured to work with the provided . + /// + /// + /// + /// This method creates threads that do not support server-side conversation storage. + /// Some AI services require server-side conversation storage to function properly, and creating a thread + /// with a may not be compatible with these services. + /// + /// + /// Where a service requires server-side conversation storage, use . + /// + /// + /// If the agent detects, during the first run, that the underlying AI service requires server-side conversation storage, + /// the thread will throw an exception to indicate that it cannot continue using the provided . + /// + /// + public AgentThread GetNewThread(ChatMessageStore chatMessageStore) + => new ChatClientAgentThread() + { + MessageStore = Throw.IfNull(chatMessageStore), + AIContextProvider = this._agentOptions?.AIContextProviderFactory?.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }) + }; + /// public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) { diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs index e9d458aba4..920f9f82ee 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs @@ -459,10 +459,10 @@ public async Task RunAsyncUsesChatMessageStoreWhenNoConversationIdReturnedByChat } /// - /// Verify that RunAsync doesn't use the ChatMessageStore factory when the chat client returns a conversation id. + /// Verify that RunAsync uses the default InMemoryChatMessageStore when the chat client returns no conversation id. /// [Fact] - public async Task RunAsyncIgnoresChatMessageStoreWhenConversationIdReturnedByChatClientAsync() + public async Task RunAsyncUsesDefaultInMemoryChatMessageStoreWhenNoConversationIdReturnedByChatClientAsync() { // Arrange Mock mockService = new(); @@ -470,9 +470,42 @@ public async Task RunAsyncIgnoresChatMessageStoreWhenConversationIdReturnedByCha s => s.GetResponseAsync( It.IsAny>(), It.IsAny(), - It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]) { ConversationId = "ConvId" }); + It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); + ChatClientAgent agent = new(mockService.Object, options: new() + { + Instructions = "test instructions", + }); + + // Act + ChatClientAgentThread? thread = agent.GetNewThread() as ChatClientAgentThread; + await agent.RunAsync([new(ChatRole.User, "test")], thread); + + // Assert + var messageStore = Assert.IsType(thread!.MessageStore); + Assert.Equal(2, messageStore.Count); + Assert.Equal("test", messageStore[0].Text); + Assert.Equal("response", messageStore[1].Text); + } + + /// + /// Verify that RunAsync uses the ChatMessageStore factory when the chat client returns no conversation id. + /// + [Fact] + public async Task RunAsyncUsesChatMessageStoreFactoryWhenProvidedAndNoConversationIdReturnedByChatClientAsync() + { + // Arrange + Mock mockService = new(); + mockService.Setup( + s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); + + Mock mockChatMessageStore = new(); + Mock> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny())).Returns(new InMemoryChatMessageStore()); + mockFactory.Setup(f => f(It.IsAny())).Returns(mockChatMessageStore.Object); + ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "test instructions", @@ -484,8 +517,36 @@ public async Task RunAsyncIgnoresChatMessageStoreWhenConversationIdReturnedByCha await agent.RunAsync([new(ChatRole.User, "test")], thread); // Assert - Assert.Equal("ConvId", thread!.ConversationId); - mockFactory.Verify(f => f(It.IsAny()), Times.Never); + Assert.IsType(thread!.MessageStore, exactMatch: false); + mockChatMessageStore.Verify(s => s.AddMessagesAsync(It.Is>(x => x.Count() == 2), It.IsAny()), Times.Once); + mockFactory.Verify(f => f(It.IsAny()), Times.Once); + } + + /// + /// Verify that RunAsync throws when a ChatMessageStore Factory is provided and the chat client returns a conversation id. + /// + [Fact] + public async Task RunAsyncThrowsWhenChatMessageStoreFactoryProvidedAndConversationIdReturnedByChatClientAsync() + { + // Arrange + Mock mockService = new(); + mockService.Setup( + s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]) { ConversationId = "ConvId" }); + Mock> mockFactory = new(); + mockFactory.Setup(f => f(It.IsAny())).Returns(new InMemoryChatMessageStore()); + ChatClientAgent agent = new(mockService.Object, options: new() + { + Instructions = "test instructions", + ChatMessageStoreFactory = mockFactory.Object + }); + + // Act & Assert + ChatClientAgentThread? thread = agent.GetNewThread() as ChatClientAgentThread; + var exception = await Assert.ThrowsAsync(() => agent.RunAsync([new(ChatRole.User, "test")], thread)); + Assert.Equal("Only the ConversationId or MessageStore may be set, but not both and switching from one to another is not supported.", exception.Message); } /// @@ -1914,10 +1975,10 @@ public async Task RunStreamingAsyncUsesChatMessageStoreWhenNoConversationIdRetur } /// - /// Verify that RunStreamingAsync doesn't use the ChatMessageStore factory when the chat client returns a conversation id. + /// Verify that RunStreamingAsync throws when a ChatMessageStore factory is provided and the chat client returns a conversation id. /// [Fact] - public async Task RunStreamingAsyncIgnoresChatMessageStoreWhenConversationIdReturnedByChatClientAsync() + public async Task RunStreamingAsyncThrowsWhenChatMessageStoreFactoryProvidedAndConversationIdReturnedByChatClientAsync() { // Arrange Mock mockService = new(); @@ -1939,13 +2000,10 @@ public async Task RunStreamingAsyncIgnoresChatMessageStoreWhenConversationIdRetu ChatMessageStoreFactory = mockFactory.Object }); - // Act + // Act & Assert ChatClientAgentThread? thread = agent.GetNewThread() as ChatClientAgentThread; - await agent.RunStreamingAsync([new(ChatRole.User, "test")], thread).ToListAsync(); - - // Assert - Assert.Equal("ConvId", thread!.ConversationId); - mockFactory.Verify(f => f(It.IsAny()), Times.Never); + var exception = await Assert.ThrowsAsync(async () => await agent.RunStreamingAsync([new(ChatRole.User, "test")], thread).ToListAsync()); + Assert.Equal("Only the ConversationId or MessageStore may be set, but not both and switching from one to another is not supported.", exception.Message); } /// @@ -2074,37 +2132,6 @@ await Assert.ThrowsAsync(async () => #endregion - #region GetNewThread Tests - - [Fact] - public void GetNewThreadUsesAIContextProviderFactoryIfProvided() - { - // Arrange - var mockChatClient = new Mock(); - var mockContextProvider = new Mock(); - var factoryCalled = false; - var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions - { - Instructions = "Test instructions", - AIContextProviderFactory = _ => - { - factoryCalled = true; - return mockContextProvider.Object; - } - }); - - // Act - var thread = agent.GetNewThread(); - - // Assert - Assert.True(factoryCalled, "AIContextProviderFactory was not called."); - Assert.IsType(thread); - var typedThread = (ChatClientAgentThread)thread; - Assert.Same(mockContextProvider.Object, typedThread.AIContextProvider); - } - - #endregion - #region Background Responses Tests [Theory] diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs new file mode 100644 index 0000000000..1fd9a71b98 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using Microsoft.Extensions.AI; +using Moq; + +namespace Microsoft.Agents.AI.UnitTests.ChatClient; + +/// +/// Contains unit tests for the ChatClientAgent.DeserializeThread methods. +/// +public class ChatClientAgent_DeserializeThreadTests +{ + [Fact] + public void DeserializeThread_UsesAIContextProviderFactory_IfProvided() + { + // Arrange + var mockChatClient = new Mock(); + var mockContextProvider = new Mock(); + var factoryCalled = false; + var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions + { + Instructions = "Test instructions", + AIContextProviderFactory = _ => + { + factoryCalled = true; + return mockContextProvider.Object; + } + }); + + var json = JsonSerializer.Deserialize(""" + { + "aiContextProviderState": ["CP1"] + } + """, TestJsonSerializerContext.Default.JsonElement); + + // Act + var thread = agent.DeserializeThread(json); + + // Assert + Assert.True(factoryCalled, "AIContextProviderFactory was not called."); + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Same(mockContextProvider.Object, typedThread.AIContextProvider); + } + + [Fact] + public void DeserializeThread_UsesChatMessageStoreFactory_IfProvided() + { + // Arrange + var mockChatClient = new Mock(); + var mockMessageStore = new Mock(); + var factoryCalled = false; + var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions + { + Instructions = "Test instructions", + ChatMessageStoreFactory = _ => + { + factoryCalled = true; + return mockMessageStore.Object; + } + }); + + var json = JsonSerializer.Deserialize(""" + { + "storeState": { } + } + """, TestJsonSerializerContext.Default.JsonElement); + + // Act + var thread = agent.DeserializeThread(json); + + // Assert + Assert.True(factoryCalled, "ChatMessageStoreFactory was not called."); + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Same(mockMessageStore.Object, typedThread.MessageStore); + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_GetNewThreadTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_GetNewThreadTests.cs new file mode 100644 index 0000000000..43e0bef8bc --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_GetNewThreadTests.cs @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; +using Moq; + +namespace Microsoft.Agents.AI.UnitTests.ChatClient; + +/// +/// Contains unit tests for the ChatClientAgent.GetNewThread methods. +/// +public class ChatClientAgent_GetNewThreadTests +{ + [Fact] + public void GetNewThread_UsesAIContextProviderFactory_IfProvided() + { + // Arrange + var mockChatClient = new Mock(); + var mockContextProvider = new Mock(); + var factoryCalled = false; + var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions + { + Instructions = "Test instructions", + AIContextProviderFactory = _ => + { + factoryCalled = true; + return mockContextProvider.Object; + } + }); + + // Act + var thread = agent.GetNewThread(); + + // Assert + Assert.True(factoryCalled, "AIContextProviderFactory was not called."); + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Same(mockContextProvider.Object, typedThread.AIContextProvider); + } + + [Fact] + public void GetNewThread_UsesChatMessageStoreFactory_IfProvided() + { + // Arrange + var mockChatClient = new Mock(); + var mockMessageStore = new Mock(); + var factoryCalled = false; + var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions + { + Instructions = "Test instructions", + ChatMessageStoreFactory = _ => + { + factoryCalled = true; + return mockMessageStore.Object; + } + }); + + // Act + var thread = agent.GetNewThread(); + + // Assert + Assert.True(factoryCalled, "ChatMessageStoreFactory was not called."); + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Same(mockMessageStore.Object, typedThread.MessageStore); + } + + [Fact] + public void GetNewThread_UsesChatMessageStore_FromTypedOverload() + { + // Arrange + var mockChatClient = new Mock(); + var mockMessageStore = new Mock(); + var agent = new ChatClientAgent(mockChatClient.Object); + + // Act + var thread = agent.GetNewThread(mockMessageStore.Object); + + // Assert + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Same(mockMessageStore.Object, typedThread.MessageStore); + } + + [Fact] + public void GetNewThread_UsesConversationId_FromTypedOverload() + { + // Arrange + var mockChatClient = new Mock(); + const string TestConversationId = "test_conversation_id"; + var agent = new ChatClientAgent(mockChatClient.Object); + + // Act + var thread = agent.GetNewThread(TestConversationId); + + // Assert + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Equal(TestConversationId, typedThread.ConversationId); + } +}