Skip to content

Commit

Permalink
.Net: Return result of the function executed before termination for s…
Browse files Browse the repository at this point in the history
…treaming API (#6428)

### Motivation, Context and Description
Fixes: #6404

Today, the SK chat completion streaming API does not return the result
of a function executed before termination, whereas the non-streaming
version does return the result. This PR resolves this issue by returning
the result of a function executed before termination was requested.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
SergeyMenshykh authored May 29, 2024
1 parent 246b843 commit 02145d9
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,47 @@ public async Task ValidateGetChatMessageContentsWithAutoFunctionInvocationFilter
Assert.Contains("GetWeather", invokedFunctions);
}

[Fact]
public async Task ValidateGetStreamingChatMessageContentWithAutoFunctionInvocationFilterTerminateAsync()
{
// Arrange
var client = this.CreateMistralClientStreaming("mistral-tiny", "https://api.mistral.ai/v1/chat/completions", "chat_completions_streaming_function_call_response.txt");

var kernel = new Kernel();
kernel.Plugins.AddFromType<WeatherPlugin>();

var filter = new FakeAutoFunctionFilter(async (context, next) =>
{
await next(context);
context.Terminate = true;
});
kernel.AutoFunctionInvocationFilters.Add(filter);

var executionSettings = new MistralAIPromptExecutionSettings { ToolCallBehavior = MistralAIToolCallBehavior.AutoInvokeKernelFunctions };
var chatHistory = new ChatHistory
{
new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?")
};

List<StreamingKernelContent> streamingContent = [];

// Act
await foreach (var item in client.GetStreamingChatMessageContentsAsync(chatHistory, default, executionSettings, kernel))
{
streamingContent.Add(item);
}

// Assert
// Results of function invoked before termination should be returned
Assert.Equal(3, streamingContent.Count);

var lastMessageContent = streamingContent[^1] as StreamingChatMessageContent;
Assert.NotNull(lastMessageContent);

Assert.Equal("12°C\nWind: 11 KMPH\nHumidity: 48%\nMostly cloudy", lastMessageContent.Content);
Assert.Equal(AuthorRole.Tool, lastMessageContent.Role);
}

[Theory]
[InlineData("system", "System Content")]
[InlineData("user", "User Content")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -447,14 +447,17 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMes

this.AddResponseMessage(chatRequest, chatHistory, toolCall, result: stringResult, errorMessage: null);

// If filter requested termination, breaking request iteration loop.
// If filter requested termination, returning latest function result and breaking request iteration loop.
if (invocationContext.Terminate)
{
if (this._logger.IsEnabled(LogLevel.Debug))
{
this._logger.LogDebug("Filter requested termination of automatic function invocation.");
}

var lastChatMessage = chatHistory.Last();

yield return new StreamingChatMessageContent(lastChatMessage.Role, lastChatMessage.Content);
yield break;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -859,14 +859,17 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC

AddResponseMessage(chatOptions, chat, streamedRole, toolCall, metadata, stringResult, errorMessage: null, this.Logger);

// If filter requested termination, breaking request iteration loop.
// If filter requested termination, returning latest function result and breaking request iteration loop.
if (invocationContext.Terminate)
{
if (this.Logger.IsEnabled(LogLevel.Debug))
{
this.Logger.LogDebug("Filter requested termination of automatic function invocation.");
}

var lastChatMessage = chat.Last();

yield return new OpenAIStreamingChatMessageContent(lastChatMessage.Role, lastChatMessage.Content);
yield break;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ public async Task PostFilterCanTerminateOperationAsync()
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();

// Act
await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings
var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings
{
ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
}));
Expand All @@ -507,6 +507,13 @@ public async Task PostFilterCanTerminateOperationAsync()
Assert.Equal(0, secondFunctionInvocations);
Assert.Equal([0], requestSequenceNumbers);
Assert.Equal([0], functionSequenceNumbers);

// Results of function invoked before termination should be returned
var lastMessageContent = result.GetValue<ChatMessageContent>();
Assert.NotNull(lastMessageContent);

Assert.Equal("function1-value", lastMessageContent.Content);
Assert.Equal(AuthorRole.Tool, lastMessageContent.Role);
}

[Fact]
Expand Down Expand Up @@ -538,15 +545,28 @@ public async Task PostFilterCanTerminateOperationOnStreamingAsync()

var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };

List<StreamingKernelContent> streamingContent = [];

// Act
await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings)))
{ }
{
streamingContent.Add(item);
}

// Assert
Assert.Equal(1, firstFunctionInvocations);
Assert.Equal(0, secondFunctionInvocations);
Assert.Equal([0], requestSequenceNumbers);
Assert.Equal([0], functionSequenceNumbers);

// Results of function invoked before termination should be returned
Assert.Equal(3, streamingContent.Count);

var lastMessageContent = streamingContent[^1] as StreamingChatMessageContent;
Assert.NotNull(lastMessageContent);

Assert.Equal("function1-value", lastMessageContent.Content);
Assert.Equal(AuthorRole.Tool, lastMessageContent.Role);
}

public void Dispose()
Expand Down

0 comments on commit 02145d9

Please sign in to comment.