From 455626df5864074af805bf39b80d6e8976bdd1eb Mon Sep 17 00:00:00 2001 From: markwallace-microsoft <127216156+markwallace-microsoft@users.noreply.github.com> Date: Mon, 24 Jun 2024 11:30:54 +0100 Subject: [PATCH] Try a few different techniques and compare outcomes --- .../OpenAI_ReasonedFunctionCalling.cs | 141 +++++++++++++++++- .../samples/InternalUtilities/BaseTest.cs | 2 + 2 files changed, 140 insertions(+), 3 deletions(-) diff --git a/dotnet/samples/Concepts/ChatCompletion/OpenAI_ReasonedFunctionCalling.cs b/dotnet/samples/Concepts/ChatCompletion/OpenAI_ReasonedFunctionCalling.cs index 48001a17d2e0..28d5d8506a1f 100644 --- a/dotnet/samples/Concepts/ChatCompletion/OpenAI_ReasonedFunctionCalling.cs +++ b/dotnet/samples/Concepts/ChatCompletion/OpenAI_ReasonedFunctionCalling.cs @@ -8,12 +8,15 @@ namespace ChatCompletion; /// -/// +/// Samples showing how to get the LLM to provide the reason is using function calling. /// public sealed class OpenAI_ReasonedFunctionCalling(ITestOutputHelper output) : BaseTest(output) { + /// + /// Using the system prompt to explain function calls doesn't work work with gpt-4o. + /// [Fact] - public async Task AskingAssistantToExplainFunctionCallsAsync() + public async Task UseSystemPromptToExplainFunctionCallsAsync() { // Create a kernel with MistralAI chat completion and WeatherPlugin Kernel kernel = CreateKernelWithPlugin(); @@ -26,6 +29,27 @@ public async Task AskingAssistantToExplainFunctionCallsAsync() new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?") }; var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }; + var result = await service.GetChatMessageContentAsync(chatHistory, executionSettings, kernel); + chatHistory.Add(result); + Console.WriteLine(result); + } + + /// + /// Asking the model to explain function calls after execution works well but may be too late depending on your use case. + /// + [Fact] + public async Task AskAssistantToExplainFunctionCallsAfterExecutionAsync() + { + // Create a kernel with MistralAI chat completion and WeatherPlugin + Kernel kernel = CreateKernelWithPlugin(); + var service = kernel.GetRequiredService(); + + // Invoke chat prompt with auto invocation of functions enabled + var chatHistory = new ChatHistory + { + new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?") + }; + var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }; var result1 = await service.GetChatMessageContentAsync(chatHistory, executionSettings, kernel); chatHistory.Add(result1); Console.WriteLine(result1); @@ -35,7 +59,118 @@ public async Task AskingAssistantToExplainFunctionCallsAsync() Console.WriteLine(result2); } - public sealed class WeatherPlugin + /// + /// Asking the model to explain function calls in response to each function call can work but the model may also + /// get confused and treat the request to explain the function calls as an error response from the function calls. + /// + [Fact] + public async Task AskAssistantToExplainFunctionCallsBeforeExecutionAsync() + { + // Create a kernel with MistralAI chat completion and WeatherPlugin + Kernel kernel = CreateKernelWithPlugin(); + kernel.AutoFunctionInvocationFilters.Add(new RespondExplainFunctionInvocationFilter()); + var service = kernel.GetRequiredService(); + + // Invoke chat prompt with auto invocation of functions enabled + var chatHistory = new ChatHistory + { + new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?") + }; + var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }; + var result = await service.GetChatMessageContentAsync(chatHistory, executionSettings, kernel); + chatHistory.Add(result); + Console.WriteLine(result); + } + + /// + /// Asking to the model to explain function calls using a separate conversation i.e. chat history seems to provide the + /// best results. This may be because the model can focus on explaining the function calls without being confused by other + /// messages in the chat history. + /// + [Fact] + public async Task QueryAssistantToExplainFunctionCallsBeforeExecutionAsync() + { + // Create a kernel with MistralAI chat completion and WeatherPlugin + Kernel kernel = CreateKernelWithPlugin(); + kernel.AutoFunctionInvocationFilters.Add(new QueryExplainFunctionInvocationFilter(this.Output)); + var service = kernel.GetRequiredService(); + + // Invoke chat prompt with auto invocation of functions enabled + var chatHistory = new ChatHistory + { + new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?") + }; + var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }; + var result = await service.GetChatMessageContentAsync(chatHistory, executionSettings, kernel); + chatHistory.Add(result); + Console.WriteLine(result); + } + + /// + /// This will respond to function call requests and ask the model to explain why it is + /// calling the function(s). It is only suitable for transient use because it stores information about the functions that have been + /// called for a single chat history. + /// + private sealed class RespondExplainFunctionInvocationFilter : IAutoFunctionInvocationFilter + { + private readonly HashSet _functionNames = []; + + public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next) + { + // Get the function calls for which we need an explanation + var functionCalls = FunctionCallContent.GetFunctionCalls(context.ChatHistory.Last()); + var needExplanation = 0; + foreach (var functionCall in functionCalls) + { + var functionName = $"{functionCall.PluginName}-{functionCall.FunctionName}"; + if (_functionNames.Add(functionName)) + { + needExplanation++; + } + } + + if (needExplanation > 0) + { + // Create a response asking why these functions are being called + context.Result = new FunctionResult(context.Result, $"Provide an explanation why you are calling function {string.Join(',', _functionNames)} and try again"); + return; + } + + // Invoke the functions + await next(context); + } + } + + /// + /// This uses the currently available to query the model + /// to find out what certain functions are being called. + /// + private sealed class QueryExplainFunctionInvocationFilter(ITestOutputHelper output) : IAutoFunctionInvocationFilter + { + private readonly ITestOutputHelper _output = output; + + public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next) + { + // Invoke the model to explain why the functions are being called + var message = context.ChatHistory[^2]; + var functionCalls = FunctionCallContent.GetFunctionCalls(context.ChatHistory.Last()); + var functionNames = functionCalls.Select(fc => $"{fc.PluginName}-{fc.FunctionName}").ToList(); + var service = context.Kernel.GetRequiredService(); + + var chatHistory = new ChatHistory + { + new ChatMessageContent(AuthorRole.User, $"Provide an explanation why these functions: {string.Join(',', functionNames)} need to be called to answer this query: {message.Content}") + }; + var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.EnableKernelFunctions }; + var result = await service.GetChatMessageContentAsync(chatHistory, executionSettings, context.Kernel); + this._output.WriteLine(result); + + // Invoke the functions + await next(context); + } + } + + private sealed class WeatherPlugin { [KernelFunction] [Description("Get the current weather in a given location.")] diff --git a/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs b/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs index f1d84d0eb22b..c846fc87c463 100644 --- a/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs +++ b/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs @@ -114,6 +114,7 @@ protected override async Task SendAsync(HttpRequestMessage { var content = await request.Content.ReadAsStringAsync(cancellationToken); string formattedContent = JsonSerializer.Serialize(JsonSerializer.Deserialize(content), s_jsonSerializerOptions); + this._output.WriteLine("=== REQUEST ==="); this._output.WriteLine(formattedContent); this._output.WriteLine(string.Empty); } @@ -125,6 +126,7 @@ protected override async Task SendAsync(HttpRequestMessage { // Log the response details var responseContent = await response.Content.ReadAsStringAsync(cancellationToken); + this._output.WriteLine("=== RESPONSE ==="); this._output.WriteLine(responseContent); this._output.WriteLine(string.Empty); }