Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,14 @@ public override async Task<ChatResponse> GetResponseAsync(
}

// Any function call work to do? If yes, ensure we're tracking that work in functionCallContents.
bool requiresFunctionInvocation =
iteration < MaximumIterationsPerRequest &&
CopyFunctionCalls(response.Messages, ref functionCallContents);
// We also need to filter out any FCCs that already have a corresponding FRC in the response,
// as those have already been handled (e.g., by the inner client or an upstream FunctionInvokingChatClient).
bool requiresFunctionInvocation = false;
if (iteration < MaximumIterationsPerRequest && CopyFunctionCalls(response.Messages, ref functionCallContents))
{
RemoveAlreadyHandledFunctionCalls(response.Messages, functionCallContents);
requiresFunctionInvocation = functionCallContents!.Count > 0;
}

if (!requiresFunctionInvocation && iteration == 0)
{
Expand Down Expand Up @@ -550,41 +555,48 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
// Check if any of the function call contents in this update requires approval.
(hasApprovalRequiringFcc, lastApprovalCheckedFCCIndex) = CheckForApprovalRequiringFCC(
functionCallContents, approvalRequiredFunctions!, hasApprovalRequiringFcc, lastApprovalCheckedFCCIndex);
if (hasApprovalRequiringFcc)

// Even if we've found an approval-requiring FCC, we continue buffering updates
// so we can properly filter out FCCs that have matching FRCs before yielding.
// We will yield the updates as soon as we receive a function call content that requires approval
// or when we reach the end of the updates stream.
}

// Collect all FunctionResultContent CallIds to identify already-handled function calls.
HashSet<string>? handledCallIds = null;
for (int i = 0; i < updates.Count; i++)
{
IList<AIContent> contents = updates[i].Contents;
int contentCount = contents.Count;
for (int j = 0; j < contentCount; j++)
{
// If we've encountered a function call content that requires approval,
// we need to ask for approval for all functions, since we cannot mix and match.
// Convert all function call contents into approval requests from the last yielded update index
// and yield all those updates.
for (; lastYieldedUpdateIndex < updates.Count; lastYieldedUpdateIndex++)
if (contents[j] is FunctionResultContent frc)
{
var updateToYield = updates[lastYieldedUpdateIndex];
if (TryReplaceFunctionCallsWithApprovalRequests(updateToYield.Contents, out var updatedContents))
{
updateToYield.Contents = updatedContents;
}

yield return updateToYield;
Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802
_ = (handledCallIds ??= new(StringComparer.Ordinal)).Add(frc.CallId);
}

continue;
}

// We don't have any approval requiring function calls yet, but we may receive some in future
// so we cannot yield the updates yet. We'll just keep them in the updates list for later.
// We will yield the updates as soon as we receive a function call content that requires approval
// or when we reach the end of the updates stream.
}

// We need to yield any remaining updates that were not yielded while looping through the streamed updates.
// If we have approval-requiring FCCs, we need to convert them to approval requests,
// but only for FCCs that don't have matching FRCs.
for (; lastYieldedUpdateIndex < updates.Count; lastYieldedUpdateIndex++)
{
var updateToYield = updates[lastYieldedUpdateIndex];

if (hasApprovalRequiringFcc && TryReplaceFunctionCallsWithApprovalRequests(updateToYield.Contents, handledCallIds, out var updatedContents))
{
updateToYield.Contents = updatedContents;
}

yield return updateToYield;
Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802
}

// Filter out any FCCs that already have a corresponding FRC in the updates,
// as those have already been handled (e.g., by the inner client or an upstream FunctionInvokingChatClient).
RemoveAlreadyHandledFunctionCalls(updates, functionCallContents);

// If there's nothing more to do, break out of the loop and allow the handling at the
// end to configure the response with aggregated data from previous requests.
if (iteration >= MaximumIterationsPerRequest ||
Expand Down Expand Up @@ -785,6 +797,102 @@ private static bool CopyFunctionCalls(
return any;
}

/// <summary>
/// Removes any <see cref="FunctionCallContent"/> from <paramref name="functionCalls"/> that have a corresponding
/// <see cref="FunctionResultContent"/> with the same CallId in <paramref name="messages"/>.
/// </summary>
/// <remarks>
/// This handles scenarios where the inner <see cref="IChatClient"/> handles function invocation itself and returns
/// both the <see cref="FunctionCallContent"/> (indicating what was called) and the <see cref="FunctionResultContent"/>
/// (indicating the result). In such cases, the <see cref="FunctionInvokingChatClient"/> should not attempt to
/// re-invoke those functions.
/// </remarks>
private static void RemoveAlreadyHandledFunctionCalls(
IList<ChatMessage> messages, List<FunctionCallContent>? functionCalls)
{
if (functionCalls is not { Count: > 0 })
{
return;
}

// Collect all FunctionResultContent CallIds from the messages into a HashSet for O(1) lookup.
HashSet<string>? handledCallIds = null;
int messageCount = messages.Count;
for (int i = 0; i < messageCount; i++)
{
IList<AIContent> contents = messages[i].Contents;
int contentCount = contents.Count;
for (int j = 0; j < contentCount; j++)
{
if (contents[j] is FunctionResultContent frc)
{
_ = (handledCallIds ??= new(StringComparer.Ordinal)).Add(frc.CallId);
}
}
}

// If there are no FRCs, nothing to filter.
if (handledCallIds is null)
{
return;
}

// Remove any FCCs that have a matching FRC CallId.
// We iterate backwards to avoid index shifting issues when removing.
for (int i = functionCalls.Count - 1; i >= 0; i--)
{
if (handledCallIds.Contains(functionCalls[i].CallId))
{
functionCalls.RemoveAt(i);
}
}
}

/// <summary>
/// Removes any <see cref="FunctionCallContent"/> from <paramref name="functionCalls"/> that have a corresponding
/// <see cref="FunctionResultContent"/> with the same CallId in <paramref name="updates"/>.
/// </summary>
private static void RemoveAlreadyHandledFunctionCalls(
List<ChatResponseUpdate> updates, List<FunctionCallContent>? functionCalls)
{
if (functionCalls is not { Count: > 0 })
{
return;
}

// Collect all FunctionResultContent CallIds from the updates into a HashSet for O(1) lookup.
HashSet<string>? handledCallIds = null;
int updateCount = updates.Count;
for (int i = 0; i < updateCount; i++)
{
IList<AIContent> contents = updates[i].Contents;
int contentCount = contents.Count;
for (int j = 0; j < contentCount; j++)
{
if (contents[j] is FunctionResultContent frc)
{
_ = (handledCallIds ??= new(StringComparer.Ordinal)).Add(frc.CallId);
}
}
}

// If there are no FRCs, nothing to filter.
if (handledCallIds is null)
{
return;
}

// Remove any FCCs that have a matching FRC CallId.
// We iterate backwards to avoid index shifting issues when removing.
for (int i = functionCalls.Count - 1; i >= 0; i--)
{
if (handledCallIds.Contains(functionCalls[i].CallId))
{
functionCalls.RemoveAt(i);
}
}
}

private static void UpdateOptionsForNextIteration(ref ChatOptions? options, string? conversationId)
{
if (options is null)
Expand Down Expand Up @@ -1531,10 +1639,14 @@ private static (bool hasApprovalRequiringFcc, int lastApprovalCheckedFCCIndex) C
}

/// <summary>
/// Replaces all <see cref="FunctionCallContent"/> with <see cref="FunctionApprovalRequestContent"/> and ouputs a new list if any of them were replaced.
/// Replaces <see cref="FunctionCallContent"/> with <see cref="FunctionApprovalRequestContent"/> and outputs a new list if any of them were replaced.
/// Excludes any FCCs that have CallIds in <paramref name="excludedCallIds"/>.
/// </summary>
/// <returns>true if any <see cref="FunctionCallContent"/> was replaced, false otherwise.</returns>
private static bool TryReplaceFunctionCallsWithApprovalRequests(IList<AIContent> content, out List<AIContent>? updatedContent)
private static bool TryReplaceFunctionCallsWithApprovalRequests(
IList<AIContent> content,
HashSet<string>? excludedCallIds,
out List<AIContent>? updatedContent)
{
updatedContent = null;

Expand All @@ -1544,6 +1656,12 @@ private static bool TryReplaceFunctionCallsWithApprovalRequests(IList<AIContent>
{
if (content[i] is FunctionCallContent fcc)
{
// Skip FCCs that are in the excluded set (already have matching FRCs)
if (excludedCallIds?.Contains(fcc.CallId) is true)
{
continue;
}

updatedContent ??= [.. content]; // Clone the list if we haven't already
updatedContent[i] = new FunctionApprovalRequestContent(fcc.CallId, fcc);
}
Expand All @@ -1555,18 +1673,35 @@ private static bool TryReplaceFunctionCallsWithApprovalRequests(IList<AIContent>

/// <summary>
/// Replaces all <see cref="FunctionCallContent"/> from <paramref name="messages"/> with <see cref="FunctionApprovalRequestContent"/>
/// if any one of them requires approval.
/// if any one of them requires approval. Function calls that already have a corresponding <see cref="FunctionResultContent"/>
/// are not replaced, as they have already been handled.
/// </summary>
private static IList<ChatMessage> ReplaceFunctionCallsWithApprovalRequests(
IList<ChatMessage> messages,
Dictionary<string, AITool> toolMap)
{
var outputMessages = messages;

// First, collect all FunctionResultContent CallIds to identify already-handled function calls.
HashSet<string>? handledCallIds = null;
for (int i = 0; i < messages.Count; i++)
{
IList<AIContent> contents = messages[i].Contents;
int contentCount = contents.Count;
for (int j = 0; j < contentCount; j++)
{
if (contents[j] is FunctionResultContent frc)
{
_ = (handledCallIds ??= new(StringComparer.Ordinal)).Add(frc.CallId);
}
}
}

bool anyApprovalRequired = false;
List<(int, int)>? allFunctionCallContentIndices = null;

// Build a list of the indices of all FunctionCallContent items.
// Build a list of the indices of all FunctionCallContent items that need to be handled.
// Exclude any that have matching FunctionResultContent (already handled).
// Also check if any of them require approval.
for (int i = 0; i < messages.Count; i++)
{
Expand All @@ -1575,6 +1710,12 @@ private static IList<ChatMessage> ReplaceFunctionCallsWithApprovalRequests(
{
if (content[j] is FunctionCallContent functionCall)
{
// Skip FCCs that already have a matching FRC
if (handledCallIds?.Contains(functionCall.CallId) is true)
{
continue;
}

(allFunctionCallContentIndices ??= []).Add((i, j));

if (!anyApprovalRequired)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -819,23 +819,24 @@ async IAsyncEnumerable<ChatResponseUpdate> YieldInnerClientUpdates(
Assert.Equal("callId1", approvalRequest1.FunctionCall.CallId);
Assert.Equal("Func1", approvalRequest1.FunctionCall.Name);

// Third content should have been buffered, since we have not yet encountered a function call that requires approval.
Assert.Equal(4, updateYieldCount);
// Third content is now yielded after the full stream is collected
// (to properly filter FCCs that have matching FRCs).
Assert.Equal(5, updateYieldCount);
break;
case 3:
var approvalRequest2 = update.Contents.OfType<FunctionApprovalRequestContent>().First();
Assert.Equal("callId2", approvalRequest2.FunctionCall.CallId);
Assert.Equal("Func2", approvalRequest2.FunctionCall.Name);

// Fourth content can be yielded immediately, since it is the first function call that requires approval.
Assert.Equal(4, updateYieldCount);
// Fourth content is yielded after the full stream is collected.
Assert.Equal(5, updateYieldCount);
break;
case 4:
var approvalRequest3 = update.Contents.OfType<FunctionApprovalRequestContent>().First();
Assert.Equal("callId1", approvalRequest3.FunctionCall.CallId);
Assert.Equal("Func3", approvalRequest3.FunctionCall.Name);

// Fifth content can be yielded immediately, since we previously encountered a function call that requires approval.
// Fifth content is yielded after the full stream is collected.
Assert.Equal(5, updateYieldCount);
break;
}
Expand All @@ -844,6 +845,94 @@ async IAsyncEnumerable<ChatResponseUpdate> YieldInnerClientUpdates(
}
}

[Fact]
public async Task IgnoresApprovalRequiredFunctionCallsWithMatchingFunctionResults_NonStreaming()
{
// When an approval-required function has already been handled (FRC exists),
// it should not be converted to an approval request.
int funcInvocationCount = 0;

var options = new ChatOptions
{
Tools =
[
new ApprovalRequiredAIFunction(AIFunctionFactory.Create(() => { funcInvocationCount++; return "should not be invoked"; }, "Func1")),
]
};

using var innerClient = new TestChatClient
{
GetResponseAsyncCallback = (chatContents, chatOptions, cancellationToken) =>
{
// Inner client handles function calling itself and returns both FCC and FRC
var response = new ChatResponse(new ChatMessage(ChatRole.Assistant,
[
new FunctionCallContent("callId1", "Func1"),
new FunctionResultContent("callId1", result: "Already handled"),
new TextContent("Response based on result")
]));

return Task.FromResult(response);
}
};

using var client = new FunctionInvokingChatClient(innerClient);

var result = await client.GetResponseAsync([new ChatMessage(ChatRole.User, "hello")], options);

// The function should NOT have been invoked
Assert.Equal(0, funcInvocationCount);

// The response should NOT contain approval requests since the function was already handled
Assert.DoesNotContain(result.Messages.SelectMany(m => m.Contents), c => c is FunctionApprovalRequestContent);

// Should contain the original FCC and FRC
Assert.Contains(result.Messages.SelectMany(m => m.Contents), c => c is FunctionCallContent);
Assert.Contains(result.Messages.SelectMany(m => m.Contents), c => c is FunctionResultContent);
}

[Fact]
public async Task IgnoresApprovalRequiredFunctionCallsWithMatchingFunctionResults_Streaming()
{
// When an approval-required function has already been handled (FRC exists),
// it should not be converted to an approval request.
int funcInvocationCount = 0;

var options = new ChatOptions
{
Tools =
[
new ApprovalRequiredAIFunction(AIFunctionFactory.Create(() => { funcInvocationCount++; return "should not be invoked"; }, "Func1")),
]
};

using var innerClient = new TestChatClient
{
GetStreamingResponseAsyncCallback = (chatContents, chatOptions, cancellationToken) =>
{
// Inner client handles function calling itself and returns both FCC and FRC in streaming updates
ChatResponseUpdate[] updates =
[
new() { Contents = [new FunctionCallContent("callId1", "Func1")], MessageId = "msg1", Role = ChatRole.Assistant },
new() { Contents = [new FunctionResultContent("callId1", result: "Already handled")], MessageId = "msg1", Role = ChatRole.Assistant },
new() { Contents = [new TextContent("Response based on result")], MessageId = "msg1", Role = ChatRole.Assistant }
];

return YieldAsync(updates);
}
};

using var client = new FunctionInvokingChatClient(innerClient);

var result = await client.GetStreamingResponseAsync([new ChatMessage(ChatRole.User, "hello")], options).ToChatResponseAsync();

// The function should NOT have been invoked
Assert.Equal(0, funcInvocationCount);

// The response should NOT contain approval requests since the function was already handled
Assert.DoesNotContain(result.Messages.SelectMany(m => m.Contents), c => c is FunctionApprovalRequestContent);
}

private static Task<List<ChatMessage>> InvokeAndAssertAsync(
ChatOptions? options,
List<ChatMessage> input,
Expand Down
Loading
Loading