Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma warning disable CA2213 // Disposable fields should be disposed
#pragma warning disable S2219 // Runtime type checking should be simplified
#pragma warning disable S3353 // Unchanged local variables should be "const"
#pragma warning disable SA1204 // Static members should appear before non-static members

namespace Microsoft.Extensions.AI;

Expand Down Expand Up @@ -934,6 +935,15 @@ private bool ShouldTerminateLoopBasedOnHandleableFunctions(List<FunctionCallCont
// There are functions to call but we have no tools, so we can't handle them.
// If we're configured to terminate on unknown call requests, do so now.
// Otherwise, ProcessFunctionCallsAsync will handle it by creating a NotFound response message.
if (TerminateOnUnknownCalls)
{
// Log each function call that would cause termination
foreach (var fcc in functionCalls)
{
LogFunctionNotFound(fcc.Name);
}
}

return TerminateOnUnknownCalls;
}

Expand All @@ -948,6 +958,7 @@ private bool ShouldTerminateLoopBasedOnHandleableFunctions(List<FunctionCallCont
{
// The tool was found but it's not invocable. Regardless of TerminateOnUnknownCallRequests,
// we need to break out of the loop so that callers can handle all the call requests.
LogNonInvocableFunction(fcc.Name);
return true;
}
}
Expand All @@ -958,6 +969,7 @@ private bool ShouldTerminateLoopBasedOnHandleableFunctions(List<FunctionCallCont
// creating a NotFound response message.
if (TerminateOnUnknownCalls)
{
LogFunctionNotFound(fcc.Name);
return true;
}
}
Expand Down Expand Up @@ -1063,6 +1075,8 @@ private void UpdateConsecutiveErrorCountOrThrow(IList<ChatMessage> added, ref in
consecutiveErrorCount++;
if (consecutiveErrorCount > MaximumConsecutiveErrorsPerRequest)
{
LogMaxConsecutiveErrorsExceeded(MaximumConsecutiveErrorsPerRequest);

var allExceptionsArray = added
.SelectMany(m => m.Contents.OfType<FunctionResultContent>())
.Select(frc => frc.Exception!)
Expand Down Expand Up @@ -1115,6 +1129,15 @@ private async Task<FunctionInvocationResult> ProcessFunctionCallAsync(
AIFunctionDeclaration? tool = FindTool(callContent.Name, options?.Tools, AdditionalTools);
if (tool is not AIFunction aiFunction)
{
if (tool is null)
{
LogFunctionNotFound(callContent.Name);
}
else
{
LogNonInvocableFunction(callContent.Name);
}

return new(terminate: false, FunctionInvocationStatus.NotFound, callContent, result: null, exception: null);
}

Expand Down Expand Up @@ -1151,6 +1174,11 @@ private async Task<FunctionInvocationResult> ProcessFunctionCallAsync(
exception: e);
}

if (context.Terminate)
{
LogFunctionRequestedTermination(callContent.Name);
}

return new(
terminate: context.Terminate,
FunctionInvocationStatus.RanToCompletion,
Expand Down Expand Up @@ -1345,10 +1373,10 @@ private static bool CurrentActivityIsInvokeAgent
/// <summary>
/// 1. Remove all <see cref="FunctionApprovalRequestContent"/> and <see cref="FunctionApprovalResponseContent"/> from the <paramref name="originalMessages"/>.
/// 2. Recreate <see cref="FunctionCallContent"/> for any <see cref="FunctionApprovalResponseContent"/> that haven't been executed yet.
/// 3. Genreate failed <see cref="FunctionResultContent"/> for any rejected <see cref="FunctionApprovalResponseContent"/>.
/// 3. Generate failed <see cref="FunctionResultContent"/> for any rejected <see cref="FunctionApprovalResponseContent"/>.
/// 4. add all the new content items to <paramref name="originalMessages"/> and return them as the pre-invocation history.
/// </summary>
private static (List<ChatMessage>? preDownstreamCallHistory, List<ApprovalResultWithRequestMessage>? approvals) ProcessFunctionApprovalResponses(
private (List<ChatMessage>? preDownstreamCallHistory, List<ApprovalResultWithRequestMessage>? approvals) ProcessFunctionApprovalResponses(
List<ChatMessage> originalMessages, bool hasConversationId, string? toolMessageId, string? functionCallContentFallbackMessageId)
{
// Extract any approval responses where we need to execute or reject the function calls.
Expand Down Expand Up @@ -1399,7 +1427,7 @@ private static (List<ChatMessage>? preDownstreamCallHistory, List<ApprovalResult
/// We can then use the metadata from these messages when we re-create the FunctionCallContent messages/updates to return to the caller. This way, when we finally do return
/// the FuncionCallContent to users it's part of a message/update that contains the same metadata as originally returned to the downstream service.
/// </remarks>
private static (List<ApprovalResultWithRequestMessage>? approvals, List<ApprovalResultWithRequestMessage>? rejections) ExtractAndRemoveApprovalRequestsAndResponses(
private (List<ApprovalResultWithRequestMessage>? approvals, List<ApprovalResultWithRequestMessage>? rejections) ExtractAndRemoveApprovalRequestsAndResponses(
List<ChatMessage> messages)
{
Dictionary<string, ChatMessage>? allApprovalRequestsMessages = null;
Expand Down Expand Up @@ -1498,6 +1526,8 @@ private static (List<ApprovalResultWithRequestMessage>? approvals, List<Approval
continue;
}

LogProcessingApprovalResponse(approvalResponse.FunctionCall.Name, approvalResponse.Approved);

// Split the responses into approved and rejected.
ref List<ApprovalResultWithRequestMessage>? targetList = ref approvalResponse.Approved ? ref approvedFunctionCalls : ref rejectedFunctionCalls;

Expand All @@ -1516,10 +1546,12 @@ private static (List<ApprovalResultWithRequestMessage>? approvals, List<Approval
/// </summary>
/// <param name="rejections">Any rejected approval responses.</param>
/// <returns>The <see cref="AIContent"/> for the rejected function calls.</returns>
private static List<AIContent>? GenerateRejectedFunctionResults(List<ApprovalResultWithRequestMessage>? rejections) =>
private List<AIContent>? GenerateRejectedFunctionResults(List<ApprovalResultWithRequestMessage>? rejections) =>
rejections is { Count: > 0 } ?
rejections.ConvertAll(m =>
{
LogFunctionRejected(m.Response.FunctionCall.Name, m.Response.Reason);

string result = "Tool call invocation rejected.";
if (!string.IsNullOrWhiteSpace(m.Response.Reason))
{
Expand Down Expand Up @@ -1679,7 +1711,7 @@ private static bool TryReplaceFunctionCallsWithApprovalRequests(IList<AIContent>
/// Replaces all <see cref="FunctionCallContent"/> from <paramref name="messages"/> with <see cref="FunctionApprovalRequestContent"/>
/// if any one of them requires approval.
/// </summary>
private static IList<ChatMessage> ReplaceFunctionCallsWithApprovalRequests(
private IList<ChatMessage> ReplaceFunctionCallsWithApprovalRequests(
IList<ChatMessage> messages,
params ReadOnlySpan<IList<AITool>?> toolLists)
{
Expand Down Expand Up @@ -1721,6 +1753,7 @@ private static IList<ChatMessage> ReplaceFunctionCallsWithApprovalRequests(
message.Contents = [.. message.Contents];

var functionCall = (FunctionCallContent)message.Contents[contentIndex];
LogFunctionRequiresApproval(functionCall.Name);
message.Contents[contentIndex] = new FunctionApprovalRequestContent(functionCall.CallId, functionCall);
outputMessages[messageIndex] = message;

Expand Down Expand Up @@ -1785,6 +1818,27 @@ private static TimeSpan GetElapsedTime(long startingTimestamp) =>
[LoggerMessage(LogLevel.Debug, "Reached maximum iteration count of {MaximumIterationsPerRequest}. Stopping function invocation loop.")]
private partial void LogMaximumIterationsReached(int maximumIterationsPerRequest);

[LoggerMessage(LogLevel.Debug, "Function '{FunctionName}' requires approval. Converting to approval request.")]
private partial void LogFunctionRequiresApproval(string functionName);

[LoggerMessage(LogLevel.Debug, "Processing approval response for '{FunctionName}'. Approved: {Approved}")]
private partial void LogProcessingApprovalResponse(string functionName, bool approved);

[LoggerMessage(LogLevel.Debug, "Function '{FunctionName}' was rejected. Reason: {Reason}")]
private partial void LogFunctionRejected(string functionName, string? reason);

[LoggerMessage(LogLevel.Warning, "Maximum consecutive errors ({MaxErrors}) exceeded. Throwing aggregated exceptions.")]
private partial void LogMaxConsecutiveErrorsExceeded(int maxErrors);

[LoggerMessage(LogLevel.Warning, "Function '{FunctionName}' not found.")]
private partial void LogFunctionNotFound(string functionName);

[LoggerMessage(LogLevel.Debug, "Function '{FunctionName}' is not invocable (declaration only). Terminating loop.")]
private partial void LogNonInvocableFunction(string functionName);

[LoggerMessage(LogLevel.Debug, "Function '{FunctionName}' requested termination of the processing loop.")]
private partial void LogFunctionRequestedTermination(string functionName);

/// <summary>Provides information about the invocation of a function call.</summary>
public sealed class FunctionInvocationResult
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2692,4 +2692,215 @@ public async Task RespectsChatOptionsToolsModificationsByFunctionTool_ReplaceWit

Assert.Equal(2, callCount);
}

[Fact]
public async Task LogsFunctionNotFound()
{
var collector = new FakeLogCollector();
ServiceCollection c = new();
c.AddLogging(b => b.AddProvider(new FakeLoggerProvider(collector)).SetMinimumLevel(LogLevel.Debug));

var options = new ChatOptions
{
Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")]
};

List<ChatMessage> plan =
[
new ChatMessage(ChatRole.User, "hello"),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "UnknownFunc")]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Error: Requested function \"UnknownFunc\" not found.")]),
new ChatMessage(ChatRole.Assistant, "world"),
];

Func<ChatClientBuilder, ChatClientBuilder> configure = b =>
b.Use((c, services) => new FunctionInvokingChatClient(c, services.GetRequiredService<ILoggerFactory>()));

await InvokeAndAssertAsync(options, plan, null, configure, c.BuildServiceProvider());

var logs = collector.GetSnapshot();
Assert.Contains(logs, e => e.Message.Contains("Function 'UnknownFunc' not found") && e.Level == LogLevel.Warning);
}

[Fact]
public async Task LogsNonInvocableFunction()
{
var collector = new FakeLogCollector();
ServiceCollection c = new();
c.AddLogging(b => b.AddProvider(new FakeLoggerProvider(collector)).SetMinimumLevel(LogLevel.Debug));

var declarationOnly = AIFunctionFactory.Create(() => "Result 1", "Func1").AsDeclarationOnly();
var options = new ChatOptions
{
Tools = [declarationOnly]
};

List<ChatMessage> plan =
[
new ChatMessage(ChatRole.User, "hello"),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Should not be produced")]),
new ChatMessage(ChatRole.Assistant, "world"),
];

List<ChatMessage> expected = plan.Take(2).ToList();

Func<ChatClientBuilder, ChatClientBuilder> configure = b =>
b.Use((c, services) => new FunctionInvokingChatClient(c, services.GetRequiredService<ILoggerFactory>()));

await InvokeAndAssertAsync(options, plan, expected, configure, c.BuildServiceProvider());

var logs = collector.GetSnapshot();
Assert.Contains(logs, e => e.Message.Contains("Function 'Func1' is not invocable (declaration only)") && e.Level == LogLevel.Debug);
}

[Fact]
public async Task LogsFunctionRequestedTermination()
{
var collector = new FakeLogCollector();
ServiceCollection c = new();
c.AddLogging(b => b.AddProvider(new FakeLoggerProvider(collector)).SetMinimumLevel(LogLevel.Debug));

var options = new ChatOptions
{
Tools = [AIFunctionFactory.Create(() =>
{
var context = FunctionInvokingChatClient.CurrentContext!;
context.Terminate = true;
return "Terminated";
}, "TerminatingFunc")]
};

List<ChatMessage> plan =
[
new ChatMessage(ChatRole.User, "hello"),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "TerminatingFunc")]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Terminated")]),
];

Func<ChatClientBuilder, ChatClientBuilder> configure = b =>
b.Use((c, services) => new FunctionInvokingChatClient(c, services.GetRequiredService<ILoggerFactory>()));

await InvokeAndAssertAsync(options, plan, null, configure, c.BuildServiceProvider());

var logs = collector.GetSnapshot();
Assert.Contains(logs, e => e.Message.Contains("Function 'TerminatingFunc' requested termination of the processing loop") && e.Level == LogLevel.Debug);
}

[Fact]
public async Task LogsFunctionRequiresApproval()
{
var collector = new FakeLogCollector();
ServiceCollection c = new();
c.AddLogging(b => b.AddProvider(new FakeLoggerProvider(collector)).SetMinimumLevel(LogLevel.Debug));

var approvalFunc = new ApprovalRequiredAIFunction(AIFunctionFactory.Create(() => "Result 1", "Func1"));
var options = new ChatOptions
{
Tools = [approvalFunc]
};

List<ChatMessage> plan =
[
new ChatMessage(ChatRole.User, "hello"),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]),
];

// Expected output includes the user message and the approval request
List<ChatMessage> expected =
[
new ChatMessage(ChatRole.User, "hello"),
new ChatMessage(ChatRole.Assistant,
[
new FunctionApprovalRequestContent("callId1", new FunctionCallContent("callId1", "Func1"))
])
];

Func<ChatClientBuilder, ChatClientBuilder> configure = b =>
b.Use((c, services) => new FunctionInvokingChatClient(c, services.GetRequiredService<ILoggerFactory>()));

await InvokeAndAssertAsync(options, plan, expected, configure, c.BuildServiceProvider());

var logs = collector.GetSnapshot();
Assert.Contains(logs, e => e.Message.Contains("Function 'Func1' requires approval") && e.Level == LogLevel.Debug);
}

[Fact]
public async Task LogsProcessingApprovalResponse()
{
var collector = new FakeLogCollector();
using var loggerFactory = LoggerFactory.Create(b => b.AddProvider(new FakeLoggerProvider(collector)).SetMinimumLevel(LogLevel.Debug));

var approvalFunc = new ApprovalRequiredAIFunction(AIFunctionFactory.Create(() => "Result 1", "Func1"));

using var innerClient = new TestChatClient
{
GetResponseAsyncCallback = (messages, opts, ct) =>
Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "world")))
};

using var client = new FunctionInvokingChatClient(innerClient, loggerFactory);

var options = new ChatOptions { Tools = [approvalFunc] };

var messages = new List<ChatMessage>
{
new ChatMessage(ChatRole.User, "hello"),
new ChatMessage(ChatRole.Assistant,
[
new FunctionApprovalRequestContent("callId1", new FunctionCallContent("callId1", "Func1"))
]),
new ChatMessage(ChatRole.User,
[
new FunctionApprovalResponseContent("callId1", true, new FunctionCallContent("callId1", "Func1"))
])
};

await client.GetResponseAsync(messages, options);

var logs = collector.GetSnapshot();
Assert.Contains(logs, e => e.Message.Contains("Processing approval response for 'Func1'. Approved: True") && e.Level == LogLevel.Debug);
}

[Fact]
public async Task LogsFunctionRejected()
{
var collector = new FakeLogCollector();
using var loggerFactory = LoggerFactory.Create(b => b.AddProvider(new FakeLoggerProvider(collector)).SetMinimumLevel(LogLevel.Debug));

var approvalFunc = new ApprovalRequiredAIFunction(AIFunctionFactory.Create(() => "Result 1", "Func1"));

using var innerClient = new TestChatClient
{
GetResponseAsyncCallback = (messages, opts, ct) =>
Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "world")))
};

using var client = new FunctionInvokingChatClient(innerClient, loggerFactory);

var options = new ChatOptions { Tools = [approvalFunc] };

var messages = new List<ChatMessage>
{
new ChatMessage(ChatRole.User, "hello"),
new ChatMessage(ChatRole.Assistant,
[
new FunctionApprovalRequestContent("callId1", new FunctionCallContent("callId1", "Func1"))
]),
new ChatMessage(ChatRole.User,
[
new FunctionApprovalResponseContent("callId1", false, new FunctionCallContent("callId1", "Func1")) { Reason = "User denied" }
])
};

await client.GetResponseAsync(messages, options);

var logs = collector.GetSnapshot();
Assert.Contains(logs, e => e.Message.Contains("Function 'Func1' was rejected. Reason: User denied") && e.Level == LogLevel.Debug);
}

// Note: LogMaxConsecutiveErrorsExceeded is exercised by the existing
// ContinuesWithFailingCallsUntilMaximumConsecutiveErrors test which triggers
// the threshold condition. The logging call is at line 1078 and will execute
// when MaximumConsecutiveErrorsPerRequest is exceeded.
}
Loading