
Commit 455626d

Try a few different techniques and compare outcomes
markwallace-microsoft committed Jun 24, 2024
1 parent 5a93550 commit 455626d
Showing 2 changed files with 140 additions and 3 deletions.
@@ -8,12 +8,15 @@
namespace ChatCompletion;

/// <summary>
/// Samples showing how to get the LLM to provide the reason it is using function calling.
/// </summary>
public sealed class OpenAI_ReasonedFunctionCalling(ITestOutputHelper output) : BaseTest(output)
{
/// <summary>
/// Using the system prompt to explain function calls doesn't work with gpt-4o.
/// </summary>
[Fact]
public async Task UseSystemPromptToExplainFunctionCallsAsync()
{
// Create a kernel with OpenAI chat completion and WeatherPlugin
Kernel kernel = CreateKernelWithPlugin<WeatherPlugin>();
@@ -26,6 +29,27 @@ public async Task AskingAssistantToExplainFunctionCallsAsync()
new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?")
};
var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };
var result = await service.GetChatMessageContentAsync(chatHistory, executionSettings, kernel);
chatHistory.Add(result);
Console.WriteLine(result);
}
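
// The chat-history setup for this test sits in the collapsed lines above, so the exact
// system prompt is not visible in this diff. A minimal sketch of the system-prompt
// technique the summary describes, assuming wording along these lines (not the
// committed text):
//
//   var chatHistory = new ChatHistory(
//       "For every function you call, first explain why you are calling it.");
//   chatHistory.AddUserMessage("What is the weather like in Paris?");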

/// <summary>
/// Asking the model to explain function calls after execution works well but may be too late depending on your use case.
/// </summary>
[Fact]
public async Task AskAssistantToExplainFunctionCallsAfterExecutionAsync()
{
// Create a kernel with OpenAI chat completion and WeatherPlugin
Kernel kernel = CreateKernelWithPlugin<WeatherPlugin>();
var service = kernel.GetRequiredService<IChatCompletionService>();

// Invoke chat prompt with auto invocation of functions enabled
var chatHistory = new ChatHistory
{
new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?")
};
var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };
var result1 = await service.GetChatMessageContentAsync(chatHistory, executionSettings, kernel);
chatHistory.Add(result1);
Console.WriteLine(result1);
@@ -35,7 +59,118 @@ public async Task AskingAssistantToExplainFunctionCallsAsync()
Console.WriteLine(result2);
}
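
// The follow-up request that produces result2 sits in the collapsed lines above. A
// minimal sketch of the ask-after-execution technique, assuming a follow-up message
// along these lines (not the committed text):
//
//   chatHistory.AddUserMessage("Explain why you called those functions.");
//   var result2 = await service.GetChatMessageContentAsync(chatHistory, executionSettings, kernel);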

/// <summary>
/// Asking the model to explain function calls in response to each function call can work, but the model may also
/// get confused and treat the request for an explanation as an error response from the function calls.
/// </summary>
[Fact]
public async Task AskAssistantToExplainFunctionCallsBeforeExecutionAsync()
{
// Create a kernel with OpenAI chat completion and WeatherPlugin
Kernel kernel = CreateKernelWithPlugin<WeatherPlugin>();
kernel.AutoFunctionInvocationFilters.Add(new RespondExplainFunctionInvocationFilter());
var service = kernel.GetRequiredService<IChatCompletionService>();

// Invoke chat prompt with auto invocation of functions enabled
var chatHistory = new ChatHistory
{
new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?")
};
var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };
var result = await service.GetChatMessageContentAsync(chatHistory, executionSettings, kernel);
chatHistory.Add(result);
Console.WriteLine(result);
}

/// <summary>
/// Asking the model to explain function calls using a separate conversation, i.e. a separate chat history, seems to provide the
/// best results. This may be because the model can focus on explaining the function calls without being confused by other
/// messages in the chat history.
/// </summary>
[Fact]
public async Task QueryAssistantToExplainFunctionCallsBeforeExecutionAsync()
{
// Create a kernel with OpenAI chat completion and WeatherPlugin
Kernel kernel = CreateKernelWithPlugin<WeatherPlugin>();
kernel.AutoFunctionInvocationFilters.Add(new QueryExplainFunctionInvocationFilter(this.Output));
var service = kernel.GetRequiredService<IChatCompletionService>();

// Invoke chat prompt with auto invocation of functions enabled
var chatHistory = new ChatHistory
{
new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?")
};
var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };
var result = await service.GetChatMessageContentAsync(chatHistory, executionSettings, kernel);
chatHistory.Add(result);
Console.WriteLine(result);
}

/// <summary>
/// This <see cref="IAutoFunctionInvocationFilter"/> will respond to function call requests and ask the model to explain why it is
/// calling the function(s). It is only suitable for transient use because it stores information about the functions that have been
/// called for a single chat history.
/// </summary>
private sealed class RespondExplainFunctionInvocationFilter : IAutoFunctionInvocationFilter
{
private readonly HashSet<string> _functionNames = [];

public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func<AutoFunctionInvocationContext, Task> next)
{
// Get the function calls for which we need an explanation
var functionCalls = FunctionCallContent.GetFunctionCalls(context.ChatHistory.Last());
var needExplanation = 0;
foreach (var functionCall in functionCalls)
{
var functionName = $"{functionCall.PluginName}-{functionCall.FunctionName}";
if (_functionNames.Add(functionName))
{
needExplanation++;
}
}

if (needExplanation > 0)
{
// Create a response asking why these functions are being called
context.Result = new FunctionResult(context.Result, $"Provide an explanation why you are calling function {string.Join(',', _functionNames)} and try again");
return;
}

// Invoke the functions
await next(context);
}
}
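
// Design note: setting context.Result and returning without awaiting next(context)
// short-circuits the automatic invocation, so the message above goes back to the model
// as the function result, prompting it to explain itself and retry the call. Because
// _functionNames persists across invocations, each function is intercepted only the
// first time it is requested; on the retry the filter falls through to next(context).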

/// <summary>
/// This <see cref="IAutoFunctionInvocationFilter"/> uses the currently available <see cref="IChatCompletionService"/> to query the model
/// to find out why certain functions are being called.
/// </summary>
private sealed class QueryExplainFunctionInvocationFilter(ITestOutputHelper output) : IAutoFunctionInvocationFilter
{
private readonly ITestOutputHelper _output = output;

public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func<AutoFunctionInvocationContext, Task> next)
{
// Invoke the model to explain why the functions are being called
var message = context.ChatHistory[^2];
var functionCalls = FunctionCallContent.GetFunctionCalls(context.ChatHistory.Last());
var functionNames = functionCalls.Select(fc => $"{fc.PluginName}-{fc.FunctionName}").ToList();
var service = context.Kernel.GetRequiredService<IChatCompletionService>();

var chatHistory = new ChatHistory
{
new ChatMessageContent(AuthorRole.User, $"Provide an explanation why these functions: {string.Join(',', functionNames)} need to be called to answer this query: {message.Content}")
};
var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.EnableKernelFunctions };
var result = await service.GetChatMessageContentAsync(chatHistory, executionSettings, context.Kernel);
this._output.WriteLine(result);

// Invoke the functions
await next(context);
}
}
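
// Design note: context.ChatHistory[^2] assumes the user's query is the second-to-last
// message, immediately followed by the assistant's function-call request. Using
// ToolCallBehavior.EnableKernelFunctions advertises the functions to the model without
// auto-invoking them, so this explanation query cannot recursively trigger the filter;
// the original invocation then proceeds via next(context).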

private sealed class WeatherPlugin
{
[KernelFunction]
[Description("Get the current weather in a given location.")]
@@ -114,6 +114,7 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
{
var content = await request.Content.ReadAsStringAsync(cancellationToken);
string formattedContent = JsonSerializer.Serialize(JsonSerializer.Deserialize<JsonElement>(content), s_jsonSerializerOptions);
this._output.WriteLine("=== REQUEST ===");
this._output.WriteLine(formattedContent);
this._output.WriteLine(string.Empty);
}
@@ -125,6 +126,7 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
{
// Log the response details
var responseContent = await response.Content.ReadAsStringAsync(cancellationToken);
this._output.WriteLine("=== RESPONSE ===");
this._output.WriteLine(responseContent);
this._output.WriteLine(string.Empty);
}
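
// With the markers added above, the interleaved request/response JSON in the test
// output is easier to scan, e.g. (shapes illustrative only):
//
//   === REQUEST ===
//   { "messages": [ ... ], "tools": [ ... ] }
//
//   === RESPONSE ===
//   { "choices": [ ... ] }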
