.Net: Added support for multiple chat and text results from Kernel (#6704)

### Motivation and Context


**Note**: This PR changes the behavior of `kernel.InvokeAsync` when the AI connector returns multiple results. Instead of throwing an exception, it now returns all of the results. The behavior for a single result is unchanged.

Fixes: #6434

When executing a prompt function through the kernel and requesting multiple results per prompt, an exception is currently thrown:
```csharp
var arguments = new KernelArguments(new OpenAIPromptExecutionSettings
{
    ResultsPerPrompt = 3
});

var result = await this._kernel.InvokePromptAsync("Hi, can you help me today?", arguments); // this will throw an exception
```

The current `KernelFunctionFromPrompt` implementation expects only a single result from the AI connector, even though the connector API can return multiple results per prompt/request.

This PR updates `KernelFunctionFromPrompt` in the following way:
1. If the AI connector returns a single item, the behavior is unchanged: `FunctionResult` contains an instance of that item, so it is possible to access its properties, call `ToString()`, and so on.
2. If the AI connector returns multiple items, all items are returned to the caller as a collection, which should be handled appropriately, e.g. by looping over it or accessing a specific item by index (see the sketch below).
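
For illustration, here is a minimal sketch of consuming the new behavior. It mirrors the updated samples in this PR; the prompt text and the `ResultsPerPrompt` value are only placeholders:

```csharp
var arguments = new KernelArguments(new OpenAIPromptExecutionSettings
{
    ResultsPerPrompt = 3 // request three completions for the same prompt
});

// No exception anymore: the generic overload returns all items at once.
var contents = await kernel.InvokePromptAsync<IReadOnlyList<KernelContent>>(
    "Hi, can you help me today?", arguments);

foreach (var content in contents!)
{
    Console.WriteLine(content.ToString());
}
```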

One of the examples shows how to select a single result when a prompt function is invoked inside another prompt function via the prompt template engine. In this case, a filter can be registered that receives the multiple results produced by the function, selects one of them, and returns it to the prompt rendering operation. A condensed sketch of that pattern follows.
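
The sketch below is condensed from the `FunctionResultSelectionFilter` example in the diff further down; the `FirstResultSelectionFilter` name and the selection rule (always take the first item) are illustrative:

```csharp
sealed class FirstResultSelectionFilter : IFunctionInvocationFilter
{
    public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func<FunctionInvocationContext, Task> next)
    {
        // Let the function run first, then inspect its result.
        await next(context);

        // If the function produced multiple results, keep only the first one
        // so the prompt template receives a single value to render.
        if (context.Result.GetValue<object>() is IReadOnlyList<KernelContent> { Count: > 1 } contents)
        {
            var selected = contents[0];
            context.Result = new FunctionResult(context.Function, selected, context.Kernel.Culture, selected.Metadata);
        }
    }
}

// Registration:
kernel.FunctionInvocationFilters.Add(new FirstResultSelectionFilter());
```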

### Contribution Checklist


- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
dmytrostruk committed Jun 18, 2024
1 parent d30250f commit 4b8a526
Showing 6 changed files with 404 additions and 59 deletions.
@@ -1,60 +1,133 @@
 // Copyright (c) Microsoft. All rights reserved.
 
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.ChatCompletion;
 using Microsoft.SemanticKernel.Connectors.OpenAI;
 
 namespace ChatCompletion;
 
-// The following example shows how to use Semantic Kernel with streaming Multiple Results Chat Completion.
+/// <summary>
+/// The following example shows how to use Semantic Kernel with multiple chat completion results.
+/// </summary>
 public class OpenAI_ChatCompletionMultipleChoices(ITestOutputHelper output) : BaseTest(output)
 {
+    /// <summary>
+    /// Example with multiple chat completion results using <see cref="Kernel"/>.
+    /// </summary>
     [Fact]
-    public Task AzureOpenAIMultiChatCompletionAsync()
+    public async Task MultipleChatCompletionResultsUsingKernelAsync()
     {
-        Console.WriteLine("======== Azure OpenAI - Multiple Chat Completion ========");
+        var kernel = Kernel
+            .CreateBuilder()
+            .AddOpenAIChatCompletion(
+                modelId: TestConfiguration.OpenAI.ChatModelId,
+                apiKey: TestConfiguration.OpenAI.ApiKey)
+            .Build();
 
-        var chatCompletionService = new AzureOpenAIChatCompletionService(
-            deploymentName: TestConfiguration.AzureOpenAI.ChatDeploymentName,
-            endpoint: TestConfiguration.AzureOpenAI.Endpoint,
-            apiKey: TestConfiguration.AzureOpenAI.ApiKey,
-            modelId: TestConfiguration.AzureOpenAI.ChatModelId);
+        // Execution settings with configured ResultsPerPrompt property.
+        var executionSettings = new OpenAIPromptExecutionSettings { MaxTokens = 200, ResultsPerPrompt = 3 };
 
-        return ChatCompletionAsync(chatCompletionService);
+        var contents = await kernel.InvokePromptAsync<IReadOnlyList<KernelContent>>("Write a paragraph about why AI is awesome", new(executionSettings));
+
+        foreach (var content in contents!)
+        {
+            Console.Write(content.ToString() ?? string.Empty);
+            Console.WriteLine("\n-------------\n");
+        }
     }
 
+    /// <summary>
+    /// Example with multiple chat completion results using <see cref="IChatCompletionService"/>.
+    /// </summary>
     [Fact]
-    public Task OpenAIMultiChatCompletionAsync()
+    public async Task MultipleChatCompletionResultsUsingChatCompletionServiceAsync()
     {
-        Console.WriteLine("======== Open AI - Multiple Chat Completion ========");
+        var kernel = Kernel
+            .CreateBuilder()
+            .AddOpenAIChatCompletion(
+                modelId: TestConfiguration.OpenAI.ChatModelId,
+                apiKey: TestConfiguration.OpenAI.ApiKey)
+            .Build();
+
+        // Execution settings with configured ResultsPerPrompt property.
+        var executionSettings = new OpenAIPromptExecutionSettings { MaxTokens = 200, ResultsPerPrompt = 3 };
+
+        var chatHistory = new ChatHistory();
+        chatHistory.AddUserMessage("Write a paragraph about why AI is awesome");
 
-        var chatCompletionService = new OpenAIChatCompletionService(
-            TestConfiguration.OpenAI.ChatModelId,
-            TestConfiguration.OpenAI.ApiKey);
+        var chatCompletionService = kernel.GetRequiredService<IChatCompletionService>();
 
-        return ChatCompletionAsync(chatCompletionService);
+        foreach (var chatMessageContent in await chatCompletionService.GetChatMessageContentsAsync(chatHistory, executionSettings))
+        {
+            Console.Write(chatMessageContent.Content ?? string.Empty);
+            Console.WriteLine("\n-------------\n");
+        }
     }
 
-    private async Task ChatCompletionAsync(IChatCompletionService chatCompletionService)
+    /// <summary>
+    /// This example shows how to handle multiple results in case if prompt template contains a call to another prompt function.
+    /// <see cref="FunctionResultSelectionFilter"/> is used for result selection.
+    /// </summary>
+    [Fact]
+    public async Task MultipleChatCompletionResultsInPromptTemplateAsync()
     {
-        var executionSettings = new OpenAIPromptExecutionSettings()
-        {
-            MaxTokens = 200,
-            FrequencyPenalty = 0,
-            PresencePenalty = 0,
-            Temperature = 1,
-            TopP = 0.5,
-            ResultsPerPrompt = 2,
-        };
+        var kernel = Kernel
+            .CreateBuilder()
+            .AddOpenAIChatCompletion(
+                modelId: TestConfiguration.OpenAI.ChatModelId,
+                apiKey: TestConfiguration.OpenAI.ApiKey)
+            .Build();
 
-        var chatHistory = new ChatHistory();
-        chatHistory.AddUserMessage("Write one paragraph about why AI is awesome");
+        var executionSettings = new OpenAIPromptExecutionSettings { MaxTokens = 200, ResultsPerPrompt = 3 };
+
+        // Initializing a function with execution settings for multiple results.
+        // We ask AI to write one paragraph, but in execution settings we specified that we want 3 different results for this request.
+        var function = KernelFunctionFactory.CreateFromPrompt("Write a paragraph about why AI is awesome", executionSettings, "GetParagraph");
+        var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function]);
 
-        foreach (var chatMessageChoice in await chatCompletionService.GetChatMessageContentsAsync(chatHistory, executionSettings))
+        kernel.Plugins.Add(plugin);
+
+        // Add function result selection filter.
+        kernel.FunctionInvocationFilters.Add(new FunctionResultSelectionFilter(this.Output));
+
+        // Inside our main request, we call MyPlugin.GetParagraph function for text summarization.
+        // Taking into account that MyPlugin.GetParagraph function produces 3 results, for text summarization we need to choose only one of them.
+        // Registered filter will be invoked during execution, which will select and return only 1 result, and this result will be inserted in our main request for summarization.
+        var result = await kernel.InvokePromptAsync("Summarize this text: {{MyPlugin.GetParagraph}}");
+
+        // It's possible to check what prompt was rendered for our main request.
+        Console.WriteLine($"Rendered prompt: '{result.RenderedPrompt}'");
+
+        // Output:
+        // Rendered prompt: 'Summarize this text: AI is awesome because...'
+    }
+
+    /// <summary>
+    /// Example of filter which is responsible for result selection in case if some function produces multiple results.
+    /// </summary>
+    private sealed class FunctionResultSelectionFilter(ITestOutputHelper output) : IFunctionInvocationFilter
+    {
+        public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func<FunctionInvocationContext, Task> next)
         {
-            Console.Write(chatMessageChoice.Content ?? string.Empty);
-            Console.WriteLine("\n-------------\n");
-        }
+            await next(context);
+
+            // Selection logic for function which is expected to produce multiple results.
+            if (context.Function.Name == "GetParagraph")
+            {
+                // Get multiple results from function invocation
+                var contents = context.Result.GetValue<IReadOnlyList<KernelContent>>()!;
+
+                Console.WriteLine();
+                output.WriteLine("Multiple results:");
+
+                foreach (var content in contents)
+                {
+                    output.WriteLine(content.ToString());
+                }
+
+                // Select first result for correct prompt rendering
+                var selectedContent = contents[0];
+                context.Result = new FunctionResult(context.Function, selectedContent, context.Kernel.Culture, selectedContent.Metadata);
+            }
+        }
     }
 }
@@ -397,6 +397,38 @@ public async Task AzureOpenAIInvokePromptTestAsync()

         // Assert
         Assert.Contains("Pike Place", actual.GetValue<string>(), StringComparison.OrdinalIgnoreCase);
+        Assert.NotNull(actual.Metadata);
     }
 
+    [Fact]
+    public async Task AzureOpenAIInvokePromptWithMultipleResultsTestAsync()
+    {
+        // Arrange
+        this._kernelBuilder.Services.AddSingleton<ILoggerFactory>(this._logger);
+        var builder = this._kernelBuilder;
+        this.ConfigureAzureOpenAIChatAsText(builder);
+        Kernel target = builder.Build();
+
+        var prompt = "Where is the most famous fish market in Seattle, Washington, USA?";
+
+        var executionSettings = new OpenAIPromptExecutionSettings() { MaxTokens = 150, ResultsPerPrompt = 3 };
+
+        // Act
+        FunctionResult actual = await target.InvokePromptAsync(prompt, new(executionSettings));
+
+        // Assert
+        Assert.Null(actual.Metadata);
+
+        var chatMessageContents = actual.GetValue<IReadOnlyList<ChatMessageContent>>();
+
+        Assert.NotNull(chatMessageContents);
+        Assert.Equal(executionSettings.ResultsPerPrompt, chatMessageContents.Count);
+
+        foreach (var chatMessageContent in chatMessageContents)
+        {
+            Assert.NotNull(chatMessageContent.Metadata);
+            Assert.Contains("Pike Place", chatMessageContent.Content, StringComparison.OrdinalIgnoreCase);
+        }
+    }
+
     [Fact]
@@ -53,7 +53,7 @@ public static class ChatCompletionServiceExtensions
     /// <param name="executionSettings">The AI execution settings (optional).</param>
     /// <param name="kernel">The <see cref="Kernel"/> containing services, plugins, and other state for use throughout the operation.</param>
     /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
-    /// <returns>List of different chat results generated by the remote model</returns>
+    /// <returns>Single chat message content generated by the remote model.</returns>
     public static async Task<ChatMessageContent> GetChatMessageContentAsync(
         this IChatCompletionService chatCompletionService,
         string prompt,
@@ -23,7 +23,7 @@ public static class TextGenerationExtensions
     /// <param name="executionSettings">The AI execution settings (optional).</param>
     /// <param name="kernel">The <see cref="Kernel"/> containing services, plugins, and other state for use throughout the operation.</param>
     /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
-    /// <returns>List of different text results generated by the remote model</returns>
+    /// <returns>Single text content generated by the remote model.</returns>
    public static async Task<TextContent> GetTextContentAsync(
        this ITextGenerationService textGenerationService,
        string prompt,
@@ -34,15 +34,15 @@ public static class TextGenerationExtensions
         .Single();
 
     /// <summary>
-    /// Get a single text generation result for the standardized prompt and settings.
+    /// Get a text generation results for the standardized prompt and settings.
     /// </summary>
     /// <param name="textGenerationService">Text generation service</param>
     /// <param name="prompt">The standardized prompt input.</param>
     /// <param name="executionSettings">The AI execution settings (optional).</param>
     /// <param name="kernel">The <see cref="Kernel"/> containing services, plugins, and other state for use throughout the operation.</param>
     /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
     /// <returns>List of different text results generated by the remote model</returns>
-    internal static async Task<TextContent> GetTextContentWithDefaultParserAsync(
+    internal static async Task<IReadOnlyList<TextContent>> GetTextContentsWithDefaultParserAsync(
         this ITextGenerationService textGenerationService,
         string prompt,
         PromptExecutionSettings? executionSettings = null,
Expand All @@ -52,12 +52,14 @@ public static class TextGenerationExtensions
         if (textGenerationService is IChatCompletionService chatCompletion
             && ChatPromptParser.TryParse(prompt, out var chatHistory))
         {
-            var chatMessage = await chatCompletion.GetChatMessageContentAsync(chatHistory, executionSettings, kernel, cancellationToken).ConfigureAwait(false);
-            return new TextContent(chatMessage.Content, chatMessage.ModelId, chatMessage.InnerContent, chatMessage.Encoding, chatMessage.Metadata);
+            var chatMessages = await chatCompletion.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken).ConfigureAwait(false);
+            return chatMessages
+                .Select(chatMessage => new TextContent(chatMessage.Content, chatMessage.ModelId, chatMessage.InnerContent, chatMessage.Encoding, chatMessage.Metadata))
+                .ToArray();
         }
 
         // When using against text generations, the prompt will be used as is.
-        return await textGenerationService.GetTextContentAsync(prompt, executionSettings, kernel, cancellationToken).ConfigureAwait(false);
+        return await textGenerationService.GetTextContentsAsync(prompt, executionSettings, kernel, cancellationToken).ConfigureAwait(false);
     }
 
     /// <summary>
@@ -123,38 +123,29 @@ internal sealed class KernelFunctionFromPrompt : KernelFunction
     {
         this.AddDefaultValues(arguments);
 
-        var result = await this.RenderPromptAsync(kernel, arguments, cancellationToken).ConfigureAwait(false);
+        var promptRenderingResult = await this.RenderPromptAsync(kernel, arguments, cancellationToken).ConfigureAwait(false);
 
 #pragma warning disable CS0612 // Events are deprecated
-        if (result.RenderedEventArgs?.Cancel is true)
+        if (promptRenderingResult.RenderedEventArgs?.Cancel is true)
         {
             throw new OperationCanceledException($"A {nameof(Kernel)}.{nameof(Kernel.PromptRendered)} event handler requested cancellation after prompt rendering.");
         }
 #pragma warning restore CS0612 // Events are deprecated
 
         // Return function result if it was set in prompt filter.
-        if (result.FunctionResult is not null)
+        if (promptRenderingResult.FunctionResult is not null)
        {
-            result.FunctionResult.RenderedPrompt = result.RenderedPrompt;
-            return result.FunctionResult;
+            promptRenderingResult.FunctionResult.RenderedPrompt = promptRenderingResult.RenderedPrompt;
+            return promptRenderingResult.FunctionResult;
         }
 
-        if (result.AIService is IChatCompletionService chatCompletion)
+        return promptRenderingResult.AIService switch
         {
-            var chatContent = await chatCompletion.GetChatMessageContentAsync(result.RenderedPrompt, result.ExecutionSettings, kernel, cancellationToken).ConfigureAwait(false);
-            this.CaptureUsageDetails(chatContent.ModelId, chatContent.Metadata, this._logger);
-            return new FunctionResult(this, chatContent, kernel.Culture, chatContent.Metadata) { RenderedPrompt = result.RenderedPrompt };
-        }
-
-        if (result.AIService is ITextGenerationService textGeneration)
-        {
-            var textContent = await textGeneration.GetTextContentWithDefaultParserAsync(result.RenderedPrompt, result.ExecutionSettings, kernel, cancellationToken).ConfigureAwait(false);
-            this.CaptureUsageDetails(textContent.ModelId, textContent.Metadata, this._logger);
-            return new FunctionResult(this, textContent, kernel.Culture, textContent.Metadata) { RenderedPrompt = result.RenderedPrompt };
-        }
-
-        // The service selector didn't find an appropriate service. This should only happen with a poorly implemented selector.
-        throw new NotSupportedException($"The AI service {result.AIService.GetType()} is not supported. Supported services are {typeof(IChatCompletionService)} and {typeof(ITextGenerationService)}");
+            IChatCompletionService chatCompletion => await this.GetChatCompletionResultAsync(chatCompletion, kernel, promptRenderingResult, cancellationToken).ConfigureAwait(false),
+            ITextGenerationService textGeneration => await this.GetTextGenerationResultAsync(textGeneration, kernel, promptRenderingResult, cancellationToken).ConfigureAwait(false),
+            // The service selector didn't find an appropriate service. This should only happen with a poorly implemented selector.
+            _ => throw new NotSupportedException($"The AI service {promptRenderingResult.AIService.GetType()} is not supported. Supported services are {typeof(IChatCompletionService)} and {typeof(ITextGenerationService)}")
+        };
     }
 
     /// <inheritdoc/>
@@ -449,5 +440,67 @@ private void CaptureUsageDetails(string? modelId, IReadOnlyDictionary<string, ob
         }
     }
 
+    private async Task<FunctionResult> GetChatCompletionResultAsync(
+        IChatCompletionService chatCompletion,
+        Kernel kernel,
+        PromptRenderingResult promptRenderingResult,
+        CancellationToken cancellationToken)
+    {
+        var chatContents = await chatCompletion.GetChatMessageContentsAsync(
+            promptRenderingResult.RenderedPrompt,
+            promptRenderingResult.ExecutionSettings,
+            kernel,
+            cancellationToken).ConfigureAwait(false);
+
+        if (chatContents is { Count: 0 })
+        {
+            return new FunctionResult(this, culture: kernel.Culture) { RenderedPrompt = promptRenderingResult.RenderedPrompt };
+        }
+
+        // Usage details are global and duplicated for each chat message content, use first one to get usage information
+        var chatContent = chatContents[0];
+        this.CaptureUsageDetails(chatContent.ModelId, chatContent.Metadata, this._logger);
+
+        // If collection has one element, return single result
+        if (chatContents.Count == 1)
+        {
+            return new FunctionResult(this, chatContent, kernel.Culture, chatContent.Metadata) { RenderedPrompt = promptRenderingResult.RenderedPrompt };
+        }
+
+        // Otherwise, return multiple results
+        return new FunctionResult(this, chatContents, kernel.Culture) { RenderedPrompt = promptRenderingResult.RenderedPrompt };
+    }
+
+    private async Task<FunctionResult> GetTextGenerationResultAsync(
+        ITextGenerationService textGeneration,
+        Kernel kernel,
+        PromptRenderingResult promptRenderingResult,
+        CancellationToken cancellationToken)
+    {
+        var textContents = await textGeneration.GetTextContentsWithDefaultParserAsync(
+            promptRenderingResult.RenderedPrompt,
+            promptRenderingResult.ExecutionSettings,
+            kernel,
+            cancellationToken).ConfigureAwait(false);
+
+        if (textContents is { Count: 0 })
+        {
+            return new FunctionResult(this, culture: kernel.Culture) { RenderedPrompt = promptRenderingResult.RenderedPrompt };
+        }
+
+        // Usage details are global and duplicated for each text content, use first one to get usage information
+        var textContent = textContents[0];
+        this.CaptureUsageDetails(textContent.ModelId, textContent.Metadata, this._logger);
+
+        // If collection has one element, return single result
+        if (textContents.Count == 1)
+        {
+            return new FunctionResult(this, textContent, kernel.Culture, textContent.Metadata) { RenderedPrompt = promptRenderingResult.RenderedPrompt };
+        }
+
+        // Otherwise, return multiple results
+        return new FunctionResult(this, textContents, kernel.Culture) { RenderedPrompt = promptRenderingResult.RenderedPrompt };
+    }
+
     #endregion
 }
