diff --git a/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs b/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs index cbafeddc3f4e..0394f7590b24 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs @@ -397,6 +397,47 @@ public async Task ValidateGetChatMessageContentsWithAutoFunctionInvocationFilter Assert.Contains("GetWeather", invokedFunctions); } + [Fact] + public async Task ValidateGetStreamingChatMessageContentWithAutoFunctionInvocationFilterTerminateAsync() + { + // Arrange + var client = this.CreateMistralClientStreaming("mistral-tiny", "https://api.mistral.ai/v1/chat/completions", "chat_completions_streaming_function_call_response.txt"); + + var kernel = new Kernel(); + kernel.Plugins.AddFromType(); + + var filter = new FakeAutoFunctionFilter(async (context, next) => + { + await next(context); + context.Terminate = true; + }); + kernel.AutoFunctionInvocationFilters.Add(filter); + + var executionSettings = new MistralAIPromptExecutionSettings { ToolCallBehavior = MistralAIToolCallBehavior.AutoInvokeKernelFunctions }; + var chatHistory = new ChatHistory + { + new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?") + }; + + List streamingContent = []; + + // Act + await foreach (var item in client.GetStreamingChatMessageContentsAsync(chatHistory, default, executionSettings, kernel)) + { + streamingContent.Add(item); + } + + // Assert + // Results of function invoked before termination should be returned + Assert.Equal(3, streamingContent.Count); + + var lastMessageContent = streamingContent[^1] as StreamingChatMessageContent; + Assert.NotNull(lastMessageContent); + + Assert.Equal("12°C\nWind: 11 KMPH\nHumidity: 48%\nMostly cloudy", lastMessageContent.Content); + Assert.Equal(AuthorRole.Tool, lastMessageContent.Role); + } + [Theory] [InlineData("system", "System Content")] [InlineData("user", "User Content")] diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs index 78c9e6dce33f..cdd9c33f4789 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs @@ -447,7 +447,7 @@ internal async IAsyncEnumerable GetStreamingChatMes this.AddResponseMessage(chatRequest, chatHistory, toolCall, result: stringResult, errorMessage: null); - // If filter requested termination, breaking request iteration loop. + // If filter requested termination, returning latest function result and breaking request iteration loop. if (invocationContext.Terminate) { if (this._logger.IsEnabled(LogLevel.Debug)) @@ -455,6 +455,9 @@ internal async IAsyncEnumerable GetStreamingChatMes this._logger.LogDebug("Filter requested termination of automatic function invocation."); } + var lastChatMessage = chatHistory.Last(); + + yield return new StreamingChatMessageContent(lastChatMessage.Role, lastChatMessage.Content); yield break; } } diff --git a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs index 60124db2c1e9..b985c529764c 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs @@ -859,7 +859,7 @@ internal async IAsyncEnumerable GetStreamingC AddResponseMessage(chatOptions, chat, streamedRole, toolCall, metadata, stringResult, errorMessage: null, this.Logger); - // If filter requested termination, breaking request iteration loop. + // If filter requested termination, returning latest function result and breaking request iteration loop. if (invocationContext.Terminate) { if (this.Logger.IsEnabled(LogLevel.Debug)) @@ -867,6 +867,9 @@ internal async IAsyncEnumerable GetStreamingC this.Logger.LogDebug("Filter requested termination of automatic function invocation."); } + var lastChatMessage = chat.Last(); + + yield return new OpenAIStreamingChatMessageContent(lastChatMessage.Role, lastChatMessage.Content); yield break; } diff --git a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/FunctionCalling/AutoFunctionInvocationFilterTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/FunctionCalling/AutoFunctionInvocationFilterTests.cs index b16bf02b6cb0..1151ea41bc9b 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/FunctionCalling/AutoFunctionInvocationFilterTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/FunctionCalling/AutoFunctionInvocationFilterTests.cs @@ -497,7 +497,7 @@ public async Task PostFilterCanTerminateOperationAsync() this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); // Act - await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings + var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions })); @@ -507,6 +507,13 @@ public async Task PostFilterCanTerminateOperationAsync() Assert.Equal(0, secondFunctionInvocations); Assert.Equal([0], requestSequenceNumbers); Assert.Equal([0], functionSequenceNumbers); + + // Results of function invoked before termination should be returned + var lastMessageContent = result.GetValue(); + Assert.NotNull(lastMessageContent); + + Assert.Equal("function1-value", lastMessageContent.Content); + Assert.Equal(AuthorRole.Tool, lastMessageContent.Role); } [Fact] @@ -538,15 +545,28 @@ public async Task PostFilterCanTerminateOperationOnStreamingAsync() var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }; + List streamingContent = []; + // Act await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings))) - { } + { + streamingContent.Add(item); + } // Assert Assert.Equal(1, firstFunctionInvocations); Assert.Equal(0, secondFunctionInvocations); Assert.Equal([0], requestSequenceNumbers); Assert.Equal([0], functionSequenceNumbers); + + // Results of function invoked before termination should be returned + Assert.Equal(3, streamingContent.Count); + + var lastMessageContent = streamingContent[^1] as StreamingChatMessageContent; + Assert.NotNull(lastMessageContent); + + Assert.Equal("function1-value", lastMessageContent.Content); + Assert.Equal(AuthorRole.Tool, lastMessageContent.Role); } public void Dispose()