diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/CHANGELOG.md b/src/Libraries/Microsoft.Extensions.AI.Abstractions/CHANGELOG.md index bd2643ec060..34c83f701c1 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/CHANGELOG.md +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/CHANGELOG.md @@ -3,6 +3,7 @@ ## NOT YET RELEASED - Added non-invocable `AIFunctionDeclaration` (base class for `AIFunction`), `AIFunctionFactory.CreateDeclaration`, and `AIFunction.AsDeclarationOnly`. +- Added `[Experimental]` support for user approval of function invocations via `ApprovalRequiredAIFunction`, `FunctionApprovalRequestContent`, and friends. ## 9.8.0 diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContent.cs index f9c7603d02a..e3ee10ad50a 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContent.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContent.cs @@ -18,6 +18,13 @@ namespace Microsoft.Extensions.AI; [JsonDerivedType(typeof(TextReasoningContent), typeDiscriminator: "reasoning")] [JsonDerivedType(typeof(UriContent), typeDiscriminator: "uri")] [JsonDerivedType(typeof(UsageContent), typeDiscriminator: "usage")] + +// These should be added in once they're no longer [Experimental]. If they're included while still +// experimental, any JsonSerializerContext that includes AIContent will incur errors about using +// experimental types in its source generated files. +// [JsonDerivedType(typeof(FunctionApprovalRequestContent), typeDiscriminator: "functionApprovalRequest")] +// [JsonDerivedType(typeof(FunctionApprovalResponseContent), typeDiscriminator: "functionApprovalResponse")] + public class AIContent { /// diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionApprovalRequestContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionApprovalRequestContent.cs new file mode 100644 index 00000000000..d3ec7ab8f0b --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionApprovalRequestContent.cs @@ -0,0 +1,41 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Represents a request for user approval of a function call. +/// +[Experimental("MEAI001")] +public sealed class FunctionApprovalRequestContent : UserInputRequestContent +{ + /// + /// Initializes a new instance of the class. + /// + /// The ID that uniquely identifies the function approval request/response pair. + /// The function call that requires user approval. + /// is . + /// is empty or composed entirely of whitespace. + /// is . + public FunctionApprovalRequestContent(string id, FunctionCallContent functionCall) + : base(id) + { + FunctionCall = Throw.IfNull(functionCall); + } + + /// + /// Gets the function call that pre-invoke approval is required for. + /// + public FunctionCallContent FunctionCall { get; } + + /// + /// Creates a to indicate whether the function call is approved or rejected based on the value of . + /// + /// if the function call is approved; otherwise, . + /// The representing the approval response. + public FunctionApprovalResponseContent CreateResponse(bool approved) => new(Id, approved, FunctionCall); +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionApprovalResponseContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionApprovalResponseContent.cs new file mode 100644 index 00000000000..948dc6a1347 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionApprovalResponseContent.cs @@ -0,0 +1,41 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Represents a response to a function approval request. +/// +[Experimental("MEAI001")] +public sealed class FunctionApprovalResponseContent : UserInputResponseContent +{ + /// + /// Initializes a new instance of the class. + /// + /// The ID that uniquely identifies the function approval request/response pair. + /// if the function call is approved; otherwise, . + /// The function call that requires user approval. + /// is . + /// is empty or composed entirely of whitespace. + /// is . + public FunctionApprovalResponseContent(string id, bool approved, FunctionCallContent functionCall) + : base(id) + { + Approved = approved; + FunctionCall = Throw.IfNull(functionCall); + } + + /// + /// Gets a value indicating whether the user approved the request. + /// + public bool Approved { get; } + + /// + /// Gets the function call for which approval was requested. + /// + public FunctionCallContent FunctionCall { get; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/UserInputRequestContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/UserInputRequestContent.cs new file mode 100644 index 00000000000..c30cc3351df --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/UserInputRequestContent.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Represents a request for user input. +/// +[Experimental("MEAI001")] +public class UserInputRequestContent : AIContent +{ + /// + /// Initializes a new instance of the class. + /// + /// The ID that uniquely identifies the user input request/response pair. + /// is . + /// is empty or composed entirely of whitespace. + protected UserInputRequestContent(string id) + { + Id = Throw.IfNullOrWhitespace(id); + } + + /// + /// Gets the ID that uniquely identifies the user input request/response pair. + /// + public string Id { get; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/UserInputResponseContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/UserInputResponseContent.cs new file mode 100644 index 00000000000..2d436f99bb8 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/UserInputResponseContent.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Represents the response to a request for user input. +/// +[Experimental("MEAI001")] +public class UserInputResponseContent : AIContent +{ + /// + /// Initializes a new instance of the class. + /// + /// The ID that uniquely identifies the user input request/response pair. + /// is . + /// is empty or composed entirely of whitespace. + protected UserInputResponseContent(string id) + { + Id = Throw.IfNullOrWhitespace(id); + } + + /// + /// Gets the ID that uniquely identifies the user input request/response pair. + /// + public string Id { get; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/ApprovalRequiredAIFunction.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/ApprovalRequiredAIFunction.cs new file mode 100644 index 00000000000..994e4660ac1 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/ApprovalRequiredAIFunction.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.Extensions.AI; + +/// +/// Represents an that can be described to an AI service and invoked, but for which +/// the invoker should obtain user approval before the function is actually invoked. +/// +/// +/// This class simply augments an with an indication that approval is required before invocation. +/// It does not enforce the requirement for user approval; it is the responsibility of the invoker to obtain that approval before invoking the function. +/// +[Experimental("MEAI001")] +public sealed class ApprovalRequiredAIFunction : DelegatingAIFunction +{ + /// + /// Initializes a new instance of the class. + /// + /// The represented by this instance. + /// is . + public ApprovalRequiredAIFunction(AIFunction innerFunction) + : base(innerFunction) + { + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/CHANGELOG.md b/src/Libraries/Microsoft.Extensions.AI/CHANGELOG.md index 6d0c1a3818d..a7c46b1d56d 100644 --- a/src/Libraries/Microsoft.Extensions.AI/CHANGELOG.md +++ b/src/Libraries/Microsoft.Extensions.AI/CHANGELOG.md @@ -3,6 +3,7 @@ ## NOT YET RELEASED - Added `FunctionInvokingChatClient` support for non-invocable tools and `TerminateOnUnknownCalls` property. +- Added support to `FunctionInvokingChatClient` for user approval of function invocations. - Updated the Open Telemetry instrumentation to conform to the latest 1.37.0 draft specification of the Semantic Conventions for Generative AI systems. - Fixed `GetResponseAsync` to only look at the contents of the last message in the response. diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index d503ef84630..e907078d535 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -14,10 +14,15 @@ using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Shared.Diagnostics; +#pragma warning disable CA1508 // Avoid dead conditional code #pragma warning disable CA2213 // Disposable fields should be disposed #pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test #pragma warning disable SA1202 // 'protected' members should come before 'private' members #pragma warning disable S107 // Methods should not have too many parameters +#pragma warning disable S907 // "goto" statement should not be used +#pragma warning disable S1659 // Multiple variables should not be declared on the same line +#pragma warning disable S3353 // Unchanged local variables should be "const" +#pragma warning disable IDE0031 // Use null propagation, suppressed until repo updates to C# 14 #pragma warning disable IDE0032 // Use auto property, suppressed until repo updates to C# 14 namespace Microsoft.Extensions.AI; @@ -28,14 +33,33 @@ namespace Microsoft.Extensions.AI; /// /// /// -/// When this client receives a in a chat response, it responds -/// by calling the corresponding defined in , -/// producing a that it sends back to the inner client. This loop -/// is repeated until there are no more function calls to make, or until another stop condition is met, -/// such as hitting . +/// When this client receives a in a chat response from its inner +/// , it responds by invoking the corresponding defined +/// in (or in ), producing a +/// that it sends back to the inner client. This loop is repeated until there are no more function calls to make, or until +/// another stop condition is met, such as hitting . /// /// -/// The provided implementation of is thread-safe for concurrent use so long as the +/// If a requested function is an but not an , the +/// will not attempt to invoke it, and instead allow that +/// to pass back out to the caller. It is then that caller's responsibility to create the appropriate +/// for that call and send it back as part of a subsequent request. +/// +/// +/// Further, if a requested function is an , the will not +/// attempt to invoke it directly. Instead, it will replace that with a +/// that wraps the and indicates that the function requires approval before it can be invoked. The caller is then +/// responsible for responding to that approval request by sending a corresponding in a subsequent +/// request. The will then process that approval response and invoke the function as appropriate. +/// +/// +/// Due to the nature of interactions with an underlying , if any is received +/// for a function that requires approval, all received in that same response will also require approval, +/// even if they were not instances. If this is a concern, consider requesting that multiple tool call +/// requests not be made in a single response, by setting to . +/// +/// +/// A instance is thread-safe for concurrent use so long as the /// instances employed as part of the supplied are also safe. /// The property can be used to control whether multiple function invocation /// requests as part of the same request are invocable concurrently, but even with that set to @@ -264,7 +288,6 @@ public override async Task GetResponseAsync( List originalMessages = [.. messages]; messages = originalMessages; - Dictionary? toolMap = null; // all available tools, indexed by name List? augmentedHistory = null; // the actual history of messages sent on turns other than the first ChatResponse? response = null; // the response from the inner client, which is possibly modified and then eventually returned List? responseMessages = null; // tracked list of messages, across multiple turns, to be used for the final response @@ -273,6 +296,35 @@ public override async Task GetResponseAsync( bool lastIterationHadConversationId = false; // whether the last iteration's response had a ConversationId set int consecutiveErrorCount = 0; + (Dictionary? toolMap, bool anyToolsRequireApproval) = CreateToolsMap(AdditionalTools, options?.Tools); // all available tools, indexed by name + + if (HasAnyApprovalContent(originalMessages)) + { + // A previous turn may have translated FunctionCallContents from the inner client into approval requests sent back to the caller, + // for any AIFunctions that were actually ApprovalRequiredAIFunctions. If the incoming chat messages include responses to those + // approval requests, we need to process them now. This entails removing these manufactured approval requests from the chat message + // list and replacing them with the appropriate FunctionCallContents and FunctionResultContents that would have been generated if + // the inner client had returned them directly. + (responseMessages, var notInvokedApprovals) = ProcessFunctionApprovalResponses( + originalMessages, !string.IsNullOrWhiteSpace(options?.ConversationId), toolMessageId: null, functionCallContentFallbackMessageId: null); + (IList? invokedApprovedFunctionApprovalResponses, bool shouldTerminate, consecutiveErrorCount) = + await InvokeApprovedFunctionApprovalResponsesAsync(notInvokedApprovals, toolMap, originalMessages, options, consecutiveErrorCount, isStreaming: false, cancellationToken); + + if (invokedApprovedFunctionApprovalResponses is not null) + { + // Add any generated FRCs to the list we'll return to callers as part of the next response. + (responseMessages ??= []).AddRange(invokedApprovedFunctionApprovalResponses); + } + + if (shouldTerminate) + { + return new ChatResponse(responseMessages); + } + } + + // At this point, we've fully handled all approval responses that were part of the original messages, + // and we can now enter the main function calling loop. + for (int iteration = 0; ; iteration++) { functionCallContents?.Clear(); @@ -284,19 +336,31 @@ public override async Task GetResponseAsync( Throw.InvalidOperationException($"The inner {nameof(IChatClient)} returned a null {nameof(ChatResponse)}."); } + // Before we do any function execution, make sure that any functions that require approval have been turned into + // approval requests so that they don't get executed here. + if (anyToolsRequireApproval) + { + Debug.Assert(toolMap is not null, "anyToolsRequireApproval can only be true if there are tools"); + response.Messages = ReplaceFunctionCallsWithApprovalRequests(response.Messages, toolMap!); + } + // Any function call work to do? If yes, ensure we're tracking that work in functionCallContents. bool requiresFunctionInvocation = iteration < MaximumIterationsPerRequest && CopyFunctionCalls(response.Messages, ref functionCallContents); - if (requiresFunctionInvocation) - { - toolMap ??= CreateToolsDictionary(AdditionalTools, options?.Tools); - } - else if (iteration == 0) + if (!requiresFunctionInvocation && iteration == 0) { - // In a common case where we make a request and there's no function calling work required, - // fast path out by just returning the original response. + // In a common case where we make an initial request and there's no function calling work required, + // fast path out by just returning the original response. We may already have some messages + // in responseMessages from processing function approval responses, and we need to ensure + // those are included in the final response, too. + if (responseMessages is { Count: > 0 }) + { + responseMessages.AddRange(response.Messages); + response.Messages = responseMessages; + } + return response; } @@ -364,7 +428,7 @@ public override async IAsyncEnumerable GetStreamingResponseA List originalMessages = [.. messages]; messages = originalMessages; - Dictionary? toolMap = null; // all available tools, indexed by name + ApprovalRequiredAIFunction[]? approvalRequiredFunctions = null; // available tools that require approval List? augmentedHistory = null; // the actual history of messages sent on turns other than the first List? functionCallContents = null; // function call contents that need responding to in the current turn List? responseMessages = null; // tracked list of messages, across multiple turns, to be used in fallback cases to reconstitute history @@ -372,11 +436,69 @@ public override async IAsyncEnumerable GetStreamingResponseA List updates = []; // updates from the current response int consecutiveErrorCount = 0; + (Dictionary? toolMap, bool anyToolsRequireApproval) = CreateToolsMap(AdditionalTools, options?.Tools); // all available tools, indexed by name + + // This is a synthetic ID since we're generating the tool messages instead of getting them from + // the underlying provider. When emitting the streamed chunks, it's perfectly valid for us to + // use the same message ID for all of them within a given iteration, as this is a single logical + // message with multiple content items. We could also use different message IDs per tool content, + // but there's no benefit to doing so. + string toolMessageId = Guid.NewGuid().ToString("N"); + + if (HasAnyApprovalContent(originalMessages)) + { + // We also need a synthetic ID for the function call content for approved function calls + // where we don't know what the original message id of the function call was. + string functionCallContentFallbackMessageId = Guid.NewGuid().ToString("N"); + + // A previous turn may have translated FunctionCallContents from the inner client into approval requests sent back to the caller, + // for any AIFunctions that were actually ApprovalRequiredAIFunctions. If the incoming chat messages include responses to those + // approval requests, we need to process them now. This entails removing these manufactured approval requests from the chat message + // list and replacing them with the appropriate FunctionCallContents and FunctionResultContents that would have been generated if + // the inner client had returned them directly. + var (preDownstreamCallHistory, notInvokedApprovals) = ProcessFunctionApprovalResponses( + originalMessages, !string.IsNullOrWhiteSpace(options?.ConversationId), toolMessageId, functionCallContentFallbackMessageId); + if (preDownstreamCallHistory is not null) + { + foreach (var message in preDownstreamCallHistory) + { + yield return ConvertToolResultMessageToUpdate(message, options?.ConversationId, message.MessageId); + Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 + } + } + + // Invoke approved approval responses, which generates some additional FRC wrapped in ChatMessage. + (IList? invokedApprovedFunctionApprovalResponses, bool shouldTerminate, consecutiveErrorCount) = + await InvokeApprovedFunctionApprovalResponsesAsync(notInvokedApprovals, toolMap, originalMessages, options, consecutiveErrorCount, isStreaming: true, cancellationToken); + + if (invokedApprovedFunctionApprovalResponses is not null) + { + foreach (var message in invokedApprovedFunctionApprovalResponses) + { + message.MessageId = toolMessageId; + yield return ConvertToolResultMessageToUpdate(message, options?.ConversationId, message.MessageId); + Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 + } + + if (shouldTerminate) + { + yield break; + } + } + } + + // At this point, we've fully handled all approval responses that were part of the original messages, + // and we can now enter the main function calling loop. + for (int iteration = 0; ; iteration++) { updates.Clear(); functionCallContents?.Clear(); + bool hasApprovalRequiringFcc = false; + int lastApprovalCheckedFCCIndex = 0; + int lastYieldedUpdateIndex = 0; + await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken)) { if (update is null) @@ -401,18 +523,80 @@ public override async IAsyncEnumerable GetStreamingResponseA } } - yield return update; - Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 + // We're streaming updates back to the caller. However, approvals requires extra handling. We should not yield any + // FunctionCallContents back to the caller if approvals might be required, because if any actually are, we need to convert + // all FunctionCallContents into approval requests, even those that don't require approval (we otherwise don't have a way + // to track the FCCs to a later turn, in particular when the conversation history is managed by the service / inner client). + // So, if there are no functions that need approval, we can yield updates with FCCs as they arrive. But if any FCC _might_ + // require approval (which just means that any AIFunction we can possibly invoke requires approval), then we need to hold off + // on yielding any FCCs until we know whether any of them actually require approval, which is either at the end of the stream + // or the first time we get an FCC that requires approval. At that point, we can yield all of the updates buffered thus far + // and anything further, replacing FCCs with approval if any required it, or yielding them as is. + if (anyToolsRequireApproval && approvalRequiredFunctions is null && functionCallContents is { Count: > 0 }) + { + approvalRequiredFunctions = + (options?.Tools ?? Enumerable.Empty()) + .Concat(AdditionalTools ?? Enumerable.Empty()) + .OfType() + .ToArray(); + } + + if (approvalRequiredFunctions is not { Length: > 0 }) + { + // If there are no function calls to make yet, or if none of the functions require approval at all, + // we can yield the update as-is. + lastYieldedUpdateIndex++; + yield return update; + Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 + + continue; + } + + // There are function calls to make, some of which _may_ require approval. + Debug.Assert(functionCallContents is { Count: > 0 }, "Expected to have function call contents to check for approval requiring functions."); + Debug.Assert(approvalRequiredFunctions is { Length: > 0 }, "Expected to have approval requiring functions to check against function call contents."); + + // Check if any of the function call contents in this update requires approval. + (hasApprovalRequiringFcc, lastApprovalCheckedFCCIndex) = CheckForApprovalRequiringFCC( + functionCallContents, approvalRequiredFunctions!, hasApprovalRequiringFcc, lastApprovalCheckedFCCIndex); + if (hasApprovalRequiringFcc) + { + // If we've encountered a function call content that requires approval, + // we need to ask for approval for all functions, since we cannot mix and match. + // Convert all function call contents into approval requests from the last yielded update index + // and yield all those updates. + for (; lastYieldedUpdateIndex < updates.Count; lastYieldedUpdateIndex++) + { + var updateToYield = updates[lastYieldedUpdateIndex]; + if (TryReplaceFunctionCallsWithApprovalRequests(updateToYield.Contents, out var updatedContents)) + { + updateToYield.Contents = updatedContents; + } + + yield return updateToYield; + Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 + } + + continue; + } + + // We don't have any approval requiring function calls yet, but we may receive some in future + // so we cannot yield the updates yet. We'll just keep them in the updates list for later. + // We will yield the updates as soon as we receive a function call content that requires approval + // or when we reach the end of the updates stream. } // If there's nothing more to do, break out of the loop and allow the handling at the // end to configure the response with aggregated data from previous requests. if (iteration >= MaximumIterationsPerRequest || - ShouldTerminateLoopBasedOnHandleableFunctions(functionCallContents, toolMap ??= CreateToolsDictionary(AdditionalTools, options?.Tools))) + hasApprovalRequiringFcc || + ShouldTerminateLoopBasedOnHandleableFunctions(functionCallContents, toolMap)) { break; } + // We need to invoke functions. + // Reconstitute a response from the response updates. var response = updates.ToChatResponse(); (responseMessages ??= []).AddRange(response.Messages); @@ -425,31 +609,11 @@ public override async IAsyncEnumerable GetStreamingResponseA responseMessages.AddRange(modeAndMessages.MessagesAdded); consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; - // This is a synthetic ID since we're generating the tool messages instead of getting them from - // the underlying provider. When emitting the streamed chunks, it's perfectly valid for us to - // use the same message ID for all of them within a given iteration, as this is a single logical - // message with multiple content items. We could also use different message IDs per tool content, - // but there's no benefit to doing so. - string toolResponseId = Guid.NewGuid().ToString("N"); - // Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages // includes all activities, including generated function results. foreach (var message in modeAndMessages.MessagesAdded) { - var toolResultUpdate = new ChatResponseUpdate - { - AdditionalProperties = message.AdditionalProperties, - AuthorName = message.AuthorName, - ConversationId = response.ConversationId, - CreatedAt = DateTimeOffset.UtcNow, - Contents = message.Contents, - RawRepresentation = message.RawRepresentation, - ResponseId = toolResponseId, - MessageId = toolResponseId, // See above for why this can be the same as ResponseId - Role = message.Role, - }; - - yield return toolResultUpdate; + yield return ConvertToolResultMessageToUpdate(message, response.ConversationId, toolMessageId); Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 } @@ -464,6 +628,20 @@ public override async IAsyncEnumerable GetStreamingResponseA AddUsageTags(activity, totalUsage); } + private static ChatResponseUpdate ConvertToolResultMessageToUpdate(ChatMessage message, string? conversationId, string? messageId) => + new() + { + AdditionalProperties = message.AdditionalProperties, + AuthorName = message.AuthorName, + ConversationId = conversationId, + CreatedAt = DateTimeOffset.UtcNow, + Contents = message.Contents, + RawRepresentation = message.RawRepresentation, + ResponseId = messageId, + MessageId = messageId, + Role = message.Role, + }; + /// Adds tags to for usage details in . private static void AddUsageTags(Activity? activity, UsageDetails? usage) { @@ -543,31 +721,39 @@ private static void FixupHistories( messages = augmentedHistory; } - /// Creates a dictionary mapping tool names to the corresponding tools. + /// Creates a mapping from tool names to the corresponding tools. /// /// The lists of tools to combine into a single dictionary. Tools from later lists are preferred /// over tools from earlier lists if they have the same name. /// - private static Dictionary? CreateToolsDictionary(params ReadOnlySpan?> toolLists) + private static (Dictionary? ToolMap, bool AnyRequireApproval) CreateToolsMap(params ReadOnlySpan?> toolLists) { - Dictionary? tools = null; + Dictionary? map = null; + bool anyRequireApproval = false; foreach (var toolList in toolLists) { if (toolList?.Count is int count && count > 0) { - tools ??= new(StringComparer.Ordinal); + map ??= new(StringComparer.Ordinal); for (int i = 0; i < count; i++) { AITool tool = toolList[i]; - tools[tool.Name] = tool; + anyRequireApproval |= tool is ApprovalRequiredAIFunction; + map[tool.Name] = tool; } } } - return tools; + return (map, anyRequireApproval); } + /// + /// Gets whether contains any or instances. + /// + private static bool HasAnyApprovalContent(List messages) => + messages.Any(static m => m.Contents.Any(static c => c is FunctionApprovalRequestContent or FunctionApprovalResponseContent)); + /// Copies any from to . private static bool CopyFunctionCalls( IList messages, [NotNullWhen(true)] ref List? functionCalls) @@ -758,7 +944,6 @@ select ProcessFunctionCallAsync( } } -#pragma warning disable CA1851 // Possible multiple enumerations of 'IEnumerable' collection /// /// Updates the consecutive error count, and throws an exception if the count exceeds the maximum. /// @@ -767,24 +952,23 @@ select ProcessFunctionCallAsync( /// Thrown if the maximum consecutive error count is exceeded. private void UpdateConsecutiveErrorCountOrThrow(IList added, ref int consecutiveErrorCount) { - var allExceptions = added.SelectMany(m => m.Contents.OfType()) - .Select(frc => frc.Exception!) - .Where(e => e is not null); - - if (allExceptions.Any()) + if (added.Any(static m => m.Contents.Any(static c => c is FunctionResultContent { Exception: not null }))) { consecutiveErrorCount++; if (consecutiveErrorCount > MaximumConsecutiveErrorsPerRequest) { - var allExceptionsArray = allExceptions.ToArray(); + var allExceptionsArray = added + .SelectMany(m => m.Contents.OfType()) + .Select(frc => frc.Exception!) + .Where(e => e is not null) + .ToArray(); + if (allExceptionsArray.Length == 1) { ExceptionDispatchInfo.Capture(allExceptionsArray[0]).Throw(); } - else - { - throw new AggregateException(allExceptionsArray); - } + + throw new AggregateException(allExceptionsArray); } } else @@ -792,14 +976,13 @@ private void UpdateConsecutiveErrorCountOrThrow(IList added, ref in consecutiveErrorCount = 0; } } -#pragma warning restore CA1851 /// /// Throws an exception if doesn't create any messages. /// private void ThrowIfNoFunctionResultsAdded(IList? messages) { - if (messages is null || messages.Count == 0) + if (messages is not { Count: > 0 }) { Throw.InvalidOperationException($"{GetType().Name}.{nameof(CreateResponseMessages)} returned null or an empty collection of messages."); } @@ -1011,6 +1194,384 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul context.Function.InvokeAsync(context.Arguments, cancellationToken); } + /// + /// 1. Remove all and from the . + /// 2. Recreate for any that haven't been executed yet. + /// 3. Genreate failed for any rejected . + /// 4. add all the new content items to and return them as the pre-invocation history. + /// + private static (List? preDownstreamCallHistory, List? approvals) ProcessFunctionApprovalResponses( + List originalMessages, bool hasConversationId, string? toolMessageId, string? functionCallContentFallbackMessageId) + { + // Extract any approval responses where we need to execute or reject the function calls. + // The original messages are also modified to remove all approval requests and responses. + var notInvokedResponses = ExtractAndRemoveApprovalRequestsAndResponses(originalMessages); + + // Wrap the function call content in message(s). + ICollection? allPreDownstreamCallMessages = ConvertToFunctionCallContentMessages( + [.. notInvokedResponses.rejections ?? Enumerable.Empty(), .. notInvokedResponses.approvals ?? Enumerable.Empty()], + functionCallContentFallbackMessageId); + + // Generate failed function result contents for any rejected requests and wrap it in a message. + List? rejectedFunctionCallResults = GenerateRejectedFunctionResults(notInvokedResponses.rejections); + ChatMessage? rejectedPreDownstreamCallResultsMessage = rejectedFunctionCallResults is not null ? + new ChatMessage(ChatRole.Tool, rejectedFunctionCallResults) { MessageId = toolMessageId } : + null; + + // Add all the FCC that we generated to the pre-downstream-call history so that they can be returned to the caller as part of the next response. + // Also, if we are not dealing with a service thread (i.e. we don't have a conversation ID), add them + // into the original messages list so that they are passed to the inner client and can be used to generate a result. + List? preDownstreamCallHistory = null; + if (allPreDownstreamCallMessages is not null) + { + preDownstreamCallHistory = [.. allPreDownstreamCallMessages]; + if (!hasConversationId) + { + originalMessages.AddRange(preDownstreamCallHistory); + } + } + + // Add all the FRC that we generated to the pre-downstream-call history so that they can be returned to the caller as part of the next response. + // Also, add them into the original messages list so that they are passed to the inner client and can be used to generate a result. + if (rejectedPreDownstreamCallResultsMessage is not null) + { + (preDownstreamCallHistory ??= []).Add(rejectedPreDownstreamCallResultsMessage); + originalMessages.Add(rejectedPreDownstreamCallResultsMessage); + } + + return (preDownstreamCallHistory, notInvokedResponses.approvals); + } + + /// + /// This method extracts the approval requests and responses from the provided list of messages, + /// validates them, filters them to ones that require execution, and splits them into approved and rejected. + /// + /// + /// We return the messages containing the approval requests since these are the same messages that originally contained the FunctionCallContent from the downstream service. + /// We can then use the metadata from these messages when we re-create the FunctionCallContent messages/updates to return to the caller. This way, when we finally do return + /// the FuncionCallContent to users it's part of a message/update that contains the same metadata as originally returned to the downstream service. + /// + private static (List? approvals, List? rejections) ExtractAndRemoveApprovalRequestsAndResponses( + List messages) + { + Dictionary? allApprovalRequestsMessages = null; + List? allApprovalResponses = null; + HashSet? approvalRequestCallIds = null; + HashSet? functionResultCallIds = null; + + // 1st iteration, over all messages and content: + // - Build a list of all function call ids that are already executed. + // - Build a list of all function approval requests and responses. + // - Build a list of the content we want to keep (everything except approval requests and responses) and create a new list of messages for those. + // - Validate that we have an approval response for each approval request. + bool anyRemoved = false; + int i = 0; + for (; i < messages.Count; i++) + { + var message = messages[i]; + + List? keptContents = null; + + // Examine all content to populate our various collections. + for (int j = 0; j < message.Contents.Count; j++) + { + var content = message.Contents[j]; + switch (content) + { + case FunctionApprovalRequestContent farc: + // Validation: Capture each call id for each approval request to ensure later we have a matching response. + _ = (approvalRequestCallIds ??= []).Add(farc.FunctionCall.CallId); + (allApprovalRequestsMessages ??= []).Add(farc.Id, message); + break; + + case FunctionApprovalResponseContent farc: + // Validation: Remove the call id for each approval response, to check it off the list of requests we need responses for. + _ = approvalRequestCallIds?.Remove(farc.FunctionCall.CallId); + (allApprovalResponses ??= []).Add(farc); + break; + + case FunctionResultContent frc: + // Maintain a list of function calls that have already been invoked to avoid invoking them twice. + _ = (functionResultCallIds ??= []).Add(frc.CallId); + goto default; + + default: + // Content to keep. + (keptContents ??= []).Add(content); + break; + } + } + + // If any contents were filtered out, we need to either remove the message entirely (if no contents remain) or create a new message with the filtered contents. + if (keptContents?.Count != message.Contents.Count) + { + if (keptContents is { Count: > 0 }) + { + // Create a new replacement message to store the filtered contents. + var newMessage = message.Clone(); + newMessage.Contents = keptContents; + messages[i] = newMessage; + } + else + { + // Remove the message entirely since it has no contents left. Rather than doing an O(N) removal, which could possibly + // result in an O(N^2) overall operation, we mark the message as null and then do a single pass removal of all nulls after the loop. + anyRemoved = true; + messages[i] = null!; + } + } + } + + // Clean up any messages that were marked for removal during the iteration. + if (anyRemoved) + { + _ = messages.RemoveAll(static m => m is null); + } + + // Validation: If we got an approval for each request, we should have no call ids left. + if (approvalRequestCallIds is { Count: > 0 }) + { + Throw.InvalidOperationException( + $"FunctionApprovalRequestContent found with FunctionCall.CallId(s) '{string.Join(", ", approvalRequestCallIds)}' that have no matching FunctionApprovalResponseContent."); + } + + // 2nd iteration, over all approval responses: + // - Filter out any approval responses that already have a matching function result (i.e. already executed). + // - Find the matching function approval request for any response (where available). + // - Split the approval responses into two lists: approved and rejected, with their request messages (where available). + List? approvedFunctionCalls = null, rejectedFunctionCalls = null; + if (allApprovalResponses is { Count: > 0 }) + { + foreach (var approvalResponse in allApprovalResponses) + { + // Skip any approval responses that have already been processed. + if (functionResultCallIds?.Contains(approvalResponse.FunctionCall.CallId) is true) + { + continue; + } + + // Split the responses into approved and rejected. + ref List? targetList = ref approvalResponse.Approved ? ref approvedFunctionCalls : ref rejectedFunctionCalls; + + ChatMessage? requestMessage = null; + _ = allApprovalRequestsMessages?.TryGetValue(approvalResponse.FunctionCall.CallId, out requestMessage); + + (targetList ??= []).Add(new() { Response = approvalResponse, RequestMessage = requestMessage }); + } + } + + return (approvedFunctionCalls, rejectedFunctionCalls); + } + + /// + /// If we have any rejected approval responses, we need to generate failed function results for them. + /// + /// Any rejected approval responses. + /// The for the rejected function calls. + private static List? GenerateRejectedFunctionResults(List? rejections) => + rejections is { Count: > 0 } ? + rejections.ConvertAll(static m => (AIContent)new FunctionResultContent(m.Response.FunctionCall.CallId, "Error: Tool call invocation was rejected by user.")) : + null; + + /// + /// Extracts the from the provided to recreate the original function call messages. + /// The output messages tries to mimic the original messages that contained the , e.g. if the + /// had been split into separate messages, this method will recreate similarly split messages, each with their own . + /// + private static ICollection? ConvertToFunctionCallContentMessages( + List? resultWithRequestMessages, string? fallbackMessageId) + { + if (resultWithRequestMessages is not null) + { + ChatMessage? currentMessage = null; + Dictionary? messagesById = null; + + foreach (var resultWithRequestMessage in resultWithRequestMessages) + { + // Don't need to create a dictionary if we already have one or if it's the first iteration. + if (messagesById is null && currentMessage is not null + + // Everywhere we have no RequestMessage we use the fallbackMessageId, so in this case there is only one message. + && !(resultWithRequestMessage.RequestMessage is null && currentMessage.MessageId == fallbackMessageId) + + // Where we do have a RequestMessage, we can check if its message id differs from the current one. + && (resultWithRequestMessage.RequestMessage is not null && currentMessage.MessageId != resultWithRequestMessage.RequestMessage.MessageId)) + { + // The majority of the time, all FCC would be part of a single message, so no need to create a dictionary for this case. + // If we are dealing with multiple messages though, we need to keep track of them by their message ID. + messagesById = []; + messagesById[currentMessage.MessageId ?? string.Empty] = currentMessage; + } + + _ = messagesById?.TryGetValue(resultWithRequestMessage.RequestMessage?.MessageId ?? string.Empty, out currentMessage); + + if (currentMessage is null) + { + currentMessage = ConvertToFunctionCallContentMessage(resultWithRequestMessage, fallbackMessageId); + } + else + { + currentMessage.Contents.Add(resultWithRequestMessage.Response.FunctionCall); + } + + if (messagesById is not null) + { + messagesById[currentMessage.MessageId ?? string.Empty] = currentMessage; + } + } + + if (messagesById?.Values is ICollection cm) + { + return cm; + } + + if (currentMessage is not null) + { + return [currentMessage]; + } + } + + return null; + } + + /// + /// Takes the from the and wraps it in a + /// using the same message id that the was originally returned with from the downstream . + /// + private static ChatMessage ConvertToFunctionCallContentMessage(ApprovalResultWithRequestMessage resultWithRequestMessage, string? fallbackMessageId) + { + ChatMessage functionCallMessage = resultWithRequestMessage.RequestMessage?.Clone() ?? new() { Role = ChatRole.Assistant }; + functionCallMessage.Contents = [resultWithRequestMessage.Response.FunctionCall]; + functionCallMessage.MessageId ??= fallbackMessageId; + return functionCallMessage; + } + + /// + /// Check if any of the provided require approval. + /// Supports checking from a provided index up to the end of the list, to allow efficient incremental checking + /// when streaming. + /// + private static (bool hasApprovalRequiringFcc, int lastApprovalCheckedFCCIndex) CheckForApprovalRequiringFCC( + List? functionCallContents, + ApprovalRequiredAIFunction[] approvalRequiredFunctions, + bool hasApprovalRequiringFcc, + int lastApprovalCheckedFCCIndex) + { + // If we already found an approval requiring FCC, we can skip checking the rest. + if (hasApprovalRequiringFcc) + { + Debug.Assert(functionCallContents is not null, "functionCallContents must not be null here, since we have already encountered approval requiring functionCallContents"); + return (true, functionCallContents!.Count); + } + + if (functionCallContents is not null) + { + for (; lastApprovalCheckedFCCIndex < functionCallContents.Count; lastApprovalCheckedFCCIndex++) + { + var fcc = functionCallContents![lastApprovalCheckedFCCIndex]; + foreach (var arf in approvalRequiredFunctions) + { + if (arf.Name == fcc.Name) + { + hasApprovalRequiringFcc = true; + break; + } + } + } + } + + return (hasApprovalRequiringFcc, lastApprovalCheckedFCCIndex); + } + + /// + /// Replaces all with and ouputs a new list if any of them were replaced. + /// + /// true if any was replaced, false otherwise. + private static bool TryReplaceFunctionCallsWithApprovalRequests(IList content, out List? updatedContent) + { + updatedContent = null; + + if (content is { Count: > 0 }) + { + for (int i = 0; i < content.Count; i++) + { + if (content[i] is FunctionCallContent fcc) + { + updatedContent ??= [.. content]; // Clone the list if we haven't already + updatedContent[i] = new FunctionApprovalRequestContent(fcc.CallId, fcc); + } + } + } + + return updatedContent is not null; + } + + /// + /// Replaces all from with + /// if any one of them requires approval. + /// + private static IList ReplaceFunctionCallsWithApprovalRequests( + IList messages, + Dictionary toolMap) + { + var outputMessages = messages; + + bool anyApprovalRequired = false; + List<(int, int)>? allFunctionCallContentIndices = null; + + // Build a list of the indices of all FunctionCallContent items. + // Also check if any of them require approval. + for (int i = 0; i < messages.Count; i++) + { + var content = messages[i].Contents; + for (int j = 0; j < content.Count; j++) + { + if (content[j] is FunctionCallContent functionCall) + { + (allFunctionCallContentIndices ??= []).Add((i, j)); + + if (!anyApprovalRequired) + { + foreach (var t in toolMap) + { + if (t.Value is ApprovalRequiredAIFunction araf && araf.Name == functionCall.Name) + { + anyApprovalRequired = true; + break; + } + } + } + } + } + } + + // If any function calls were found, and any of them required approval, we should replace all of them with approval requests. + // This is because we do not have a way to deal with cases where some function calls require approval and others do not, so we just replace all of them. + if (anyApprovalRequired) + { + Debug.Assert(allFunctionCallContentIndices is not null, "We have already encountered function call contents that require approval."); + + // Clone the list so, we don't mutate the input. + outputMessages = [.. messages]; + int lastMessageIndex = -1; + + foreach (var (messageIndex, contentIndex) in allFunctionCallContentIndices!) + { + // Clone the message if we didn't already clone it in a previous iteration. + var message = lastMessageIndex != messageIndex ? outputMessages[messageIndex].Clone() : outputMessages[messageIndex]; + message.Contents = [.. message.Contents]; + + var functionCall = (FunctionCallContent)message.Contents[contentIndex]; + message.Contents[contentIndex] = new FunctionApprovalRequestContent(functionCall.CallId, functionCall); + outputMessages[messageIndex] = message; + + lastMessageIndex = messageIndex; + } + } + + return outputMessages; + } + private static TimeSpan GetElapsedTime(long startingTimestamp) => #if NET Stopwatch.GetElapsedTime(startingTimestamp); @@ -1018,6 +1579,33 @@ private static TimeSpan GetElapsedTime(long startingTimestamp) => new((long)((Stopwatch.GetTimestamp() - startingTimestamp) * ((double)TimeSpan.TicksPerSecond / Stopwatch.Frequency))); #endif + /// + /// Execute the provided and return the resulting + /// wrapped in objects. + /// + private async Task<(IList? FunctionResultContentMessages, bool ShouldTerminate, int ConsecutiveErrorCount)> InvokeApprovedFunctionApprovalResponsesAsync( + List? notInvokedApprovals, + Dictionary? toolMap, + List originalMessages, + ChatOptions? options, + int consecutiveErrorCount, + bool isStreaming, + CancellationToken cancellationToken) + { + // Check if there are any function calls to do for any approved functions and execute them. + if (notInvokedApprovals is { Count: > 0 }) + { + // The FRC that is generated here is already added to originalMessages by ProcessFunctionCallsAsync. + var modeAndMessages = await ProcessFunctionCallsAsync( + originalMessages, options, toolMap, notInvokedApprovals.Select(x => x.Response.FunctionCall).ToList(), 0, consecutiveErrorCount, isStreaming, cancellationToken); + consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; + + return (modeAndMessages.MessagesAdded, modeAndMessages.ShouldTerminate, consecutiveErrorCount); + } + + return (null, false, consecutiveErrorCount); + } + [LoggerMessage(LogLevel.Debug, "Invoking {MethodName}.", SkipEnabledCheck = true)] private partial void LogInvoking(string methodName); @@ -1084,4 +1672,10 @@ public enum FunctionInvocationStatus /// The function call failed with an exception. Exception, } + + private struct ApprovalResultWithRequestMessage + { + public FunctionApprovalResponseContent Response { get; set; } + public ChatMessage? RequestMessage { get; set; } + } } diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj index fe431ca21e5..0d77680cbc4 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj +++ b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj @@ -14,7 +14,7 @@ $(TargetFrameworks);netstandard2.0 - $(NoWarn);CA2227;CA1034;SA1316;S1067;S1121;S1994;S3253 + $(NoWarn);CA2227;CA1034;SA1316;S1067;S1121;S1994;S3253;MEAI001