Skip to content

Commit a9b3385

Browse files
authored
Fix bug to yield remaining buffered FCC (#6903)
1 parent 01354ad commit a9b3385

File tree

2 files changed

+116
-9
lines changed

2 files changed

+116
-9
lines changed

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
529529
.ToArray();
530530
}
531531

532-
if (approvalRequiredFunctions is not { Length: > 0 })
532+
if (approvalRequiredFunctions is not { Length: > 0 } || functionCallContents is not { Count: > 0 })
533533
{
534534
// If there are no function calls to make yet, or if none of the functions require approval at all,
535535
// we can yield the update as-is.
@@ -574,6 +574,14 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
574574
// or when we reach the end of the updates stream.
575575
}
576576

577+
// We need to yield any remaining updates that were not yielded while looping through the streamed updates.
578+
for (; lastYieldedUpdateIndex < updates.Count; lastYieldedUpdateIndex++)
579+
{
580+
var updateToYield = updates[lastYieldedUpdateIndex];
581+
yield return updateToYield;
582+
Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802
583+
}
584+
577585
// If there's nothing more to do, break out of the loop and allow the handling at the
578586
// end to configure the response with aggregated data from previous requests.
579587
if (iteration >= MaximumIterationsPerRequest ||

test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientApprovalsTests.cs

Lines changed: 107 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,69 @@ public async Task AlreadyExecutedApprovalsAreIgnoredAsync()
485485
await InvokeAndAssertStreamingAsync(options, input, downstreamClientOutput, output, expectedDownstreamClientInput);
486486
}
487487

488+
/// <summary>
489+
/// This verifies the following scenario:
490+
/// 1. We are streaming (also including non-streaming in the test for completeness).
491+
/// 2. There is one function that requires approval and one that does not.
492+
/// 3. We only get back FCC for the function that does not require approval.
493+
/// 4. This means that once we receive this FCC, we need to buffer all updates until the end, because we might receive more FCCs and some may require approval.
494+
/// 5. We then need to verify that we will still stream all updates once we reach the end, including the buffered FCC.
495+
/// </summary>
496+
[Fact]
497+
public async Task MixedApprovalRequiredToolsWithNonApprovalRequiringFunctionCallAsync()
498+
{
499+
var options = new ChatOptions
500+
{
501+
Tools =
502+
[
503+
new ApprovalRequiredAIFunction(AIFunctionFactory.Create(() => "Result 1", "Func1")),
504+
AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"),
505+
]
506+
};
507+
508+
List<ChatMessage> input =
509+
[
510+
new ChatMessage(ChatRole.User, "hello"),
511+
];
512+
513+
Func<Queue<List<ChatMessage>>> expectedDownstreamClientInput = () => new Queue<List<ChatMessage>>(
514+
[
515+
new List<ChatMessage>
516+
{
517+
new ChatMessage(ChatRole.User, "hello"),
518+
},
519+
new List<ChatMessage>
520+
{
521+
new ChatMessage(ChatRole.User, "hello"),
522+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary<string, object?> { { "i", 42 } })]),
523+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42")])
524+
}
525+
]);
526+
527+
Func<Queue<List<ChatMessage>>> downstreamClientOutput = () => new Queue<List<ChatMessage>>(
528+
[
529+
new List<ChatMessage>
530+
{
531+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary<string, object?> { { "i", 42 } })]),
532+
},
533+
new List<ChatMessage>
534+
{
535+
new ChatMessage(ChatRole.Assistant, "World again"),
536+
}
537+
]);
538+
539+
List<ChatMessage> output =
540+
[
541+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary<string, object?> { { "i", 42 } })]),
542+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42")]),
543+
new ChatMessage(ChatRole.Assistant, "World again"),
544+
];
545+
546+
await InvokeAndAssertMultiRoundAsync(options, input, downstreamClientOutput(), output, expectedDownstreamClientInput());
547+
548+
await InvokeAndAssertStreamingMultiRoundAsync(options, input, downstreamClientOutput(), output, expectedDownstreamClientInput());
549+
}
550+
488551
[Fact]
489552
public async Task ApprovalRequestWithoutApprovalResponseThrowsAsync()
490553
{
@@ -781,14 +844,31 @@ async IAsyncEnumerable<ChatResponseUpdate> YieldInnerClientUpdates(
781844
}
782845
}
783846

784-
private static async Task<List<ChatMessage>> InvokeAndAssertAsync(
847+
private static Task<List<ChatMessage>> InvokeAndAssertAsync(
785848
ChatOptions? options,
786849
List<ChatMessage> input,
787850
List<ChatMessage> downstreamClientOutput,
788851
List<ChatMessage> expectedOutput,
789852
List<ChatMessage>? expectedDownstreamClientInput = null,
790853
Func<ChatClientBuilder, ChatClientBuilder>? configurePipeline = null,
791854
AITool[]? additionalTools = null)
855+
=> InvokeAndAssertMultiRoundAsync(
856+
options,
857+
input,
858+
new Queue<List<ChatMessage>>(new[] { downstreamClientOutput }),
859+
expectedOutput,
860+
expectedDownstreamClientInput is null ? null : new Queue<List<ChatMessage>>(new[] { expectedDownstreamClientInput }),
861+
configurePipeline,
862+
additionalTools);
863+
864+
private static async Task<List<ChatMessage>> InvokeAndAssertMultiRoundAsync(
865+
ChatOptions? options,
866+
List<ChatMessage> input,
867+
Queue<List<ChatMessage>> downstreamClientOutput,
868+
List<ChatMessage> expectedOutput,
869+
Queue<List<ChatMessage>>? expectedDownstreamClientInput = null,
870+
Func<ChatClientBuilder, ChatClientBuilder>? configurePipeline = null,
871+
AITool[]? additionalTools = null)
792872
{
793873
Assert.NotEmpty(input);
794874

@@ -804,16 +884,17 @@ private static async Task<List<ChatMessage>> InvokeAndAssertAsync(
804884
Assert.Equal(cts.Token, actualCancellationToken);
805885
if (expectedDownstreamClientInput is not null)
806886
{
807-
AssertExtensions.EqualMessageLists(expectedDownstreamClientInput, contents.ToList());
887+
AssertExtensions.EqualMessageLists(expectedDownstreamClientInput.Dequeue(), contents.ToList());
808888
}
809889

810890
await Task.Yield();
811891

812892
var usage = CreateRandomUsage();
813893
expectedTotalTokenCounts += usage.InputTokenCount!.Value;
814894

815-
downstreamClientOutput.ForEach(m => m.MessageId = Guid.NewGuid().ToString("N"));
816-
return new ChatResponse(downstreamClientOutput) { Usage = usage };
895+
var output = downstreamClientOutput.Dequeue();
896+
output.ForEach(m => m.MessageId = Guid.NewGuid().ToString("N"));
897+
return new ChatResponse(output) { Usage = usage };
817898
}
818899
};
819900

@@ -851,14 +932,31 @@ private static UsageDetails CreateRandomUsage()
851932
};
852933
}
853934

854-
private static async Task<List<ChatMessage>> InvokeAndAssertStreamingAsync(
935+
private static Task<List<ChatMessage>> InvokeAndAssertStreamingAsync(
855936
ChatOptions? options,
856937
List<ChatMessage> input,
857938
List<ChatMessage> downstreamClientOutput,
858939
List<ChatMessage> expectedOutput,
859940
List<ChatMessage>? expectedDownstreamClientInput = null,
860941
Func<ChatClientBuilder, ChatClientBuilder>? configurePipeline = null,
861942
AITool[]? additionalTools = null)
943+
=> InvokeAndAssertStreamingMultiRoundAsync(
944+
options,
945+
input,
946+
new Queue<List<ChatMessage>>(new[] { downstreamClientOutput }),
947+
expectedOutput,
948+
expectedDownstreamClientInput is null ? null : new Queue<List<ChatMessage>>(new[] { expectedDownstreamClientInput }),
949+
configurePipeline,
950+
additionalTools);
951+
952+
private static async Task<List<ChatMessage>> InvokeAndAssertStreamingMultiRoundAsync(
953+
ChatOptions? options,
954+
List<ChatMessage> input,
955+
Queue<List<ChatMessage>> downstreamClientOutput,
956+
List<ChatMessage> expectedOutput,
957+
Queue<List<ChatMessage>>? expectedDownstreamClientInput = null,
958+
Func<ChatClientBuilder, ChatClientBuilder>? configurePipeline = null,
959+
AITool[]? additionalTools = null)
862960
{
863961
Assert.NotEmpty(input);
864962

@@ -873,11 +971,12 @@ private static async Task<List<ChatMessage>> InvokeAndAssertStreamingAsync(
873971
Assert.Equal(cts.Token, actualCancellationToken);
874972
if (expectedDownstreamClientInput is not null)
875973
{
876-
AssertExtensions.EqualMessageLists(expectedDownstreamClientInput, contents.ToList());
974+
AssertExtensions.EqualMessageLists(expectedDownstreamClientInput.Dequeue(), contents.ToList());
877975
}
878976

879-
downstreamClientOutput.ForEach(m => m.MessageId = Guid.NewGuid().ToString("N"));
880-
return YieldAsync(new ChatResponse(downstreamClientOutput).ToChatResponseUpdates());
977+
var output = downstreamClientOutput.Dequeue();
978+
output.ForEach(m => m.MessageId = Guid.NewGuid().ToString("N"));
979+
return YieldAsync(new ChatResponse(output).ToChatResponseUpdates());
881980
}
882981
};
883982

0 commit comments

Comments
 (0)