diff --git a/eng/packages/General.props b/eng/packages/General.props index b66f1a4ffa8..441a30afa73 100644 --- a/eng/packages/General.props +++ b/eng/packages/General.props @@ -29,6 +29,7 @@ + diff --git a/eng/packages/TestOnly.props b/eng/packages/TestOnly.props index 5875491f919..603e31c0e5b 100644 --- a/eng/packages/TestOnly.props +++ b/eng/packages/TestOnly.props @@ -21,7 +21,6 @@ - diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ToolReduction/IToolReductionStrategy.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ToolReduction/IToolReductionStrategy.cs new file mode 100644 index 00000000000..029eeae47a1 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ToolReduction/IToolReductionStrategy.cs @@ -0,0 +1,41 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.AI; + +/// +/// Represents a strategy capable of selecting a reduced set of tools for a chat request. +/// +/// +/// A tool reduction strategy is invoked prior to sending a request to an underlying , +/// enabling scenarios where a large tool catalog must be trimmed to fit provider limits or to improve model +/// tool selection quality. +/// +/// The implementation should return a non- enumerable. Returning the original +/// instance indicates no change. Returning a different enumerable indicates +/// the caller may replace the existing tool list. +/// +/// +[Experimental("MEAI001")] +public interface IToolReductionStrategy +{ + /// + /// Selects the tools that should be included for a specific request. + /// + /// The chat messages for the request. This is an to avoid premature materialization. + /// The chat options for the request (may be ). + /// A token to observe cancellation. + /// + /// A (possibly reduced) enumerable of instances. Must never be . + /// Returning the same instance referenced by . signals no change. + /// + Task> SelectToolsForRequestAsync( + IEnumerable messages, + ChatOptions? options, + CancellationToken cancellationToken = default); +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs index 554918b0a8e..8217ea49da5 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs @@ -6,6 +6,8 @@ using Microsoft.Shared.Collections; using Microsoft.Shared.Diagnostics; +#pragma warning disable IDE0032 // Use auto property, suppressed until repo updates to C# 14 + namespace Microsoft.Extensions.AI; /// Provides context for an in-flight function invocation. diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index fb821c984df..c1cff1ab554 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -284,7 +284,7 @@ public override async Task GetResponseAsync( bool lastIterationHadConversationId = false; // whether the last iteration's response had a ConversationId set int consecutiveErrorCount = 0; - (Dictionary? toolMap, bool anyToolsRequireApproval) = CreateToolsMap(AdditionalTools, options?.Tools); // all available tools, indexed by name + (Dictionary? toolMap, bool anyToolsRequireApproval) = await CreateToolsMapAsync([AdditionalTools, options?.Tools], cancellationToken); // all available tools, indexed by name if (HasAnyApprovalContent(originalMessages)) { @@ -424,7 +424,7 @@ public override async IAsyncEnumerable GetStreamingResponseA List updates = []; // updates from the current response int consecutiveErrorCount = 0; - (Dictionary? toolMap, bool anyToolsRequireApproval) = CreateToolsMap(AdditionalTools, options?.Tools); // all available tools, indexed by name + (Dictionary? toolMap, bool anyToolsRequireApproval) = await CreateToolsMapAsync([AdditionalTools, options?.Tools], cancellationToken); // all available tools, indexed by name // This is a synthetic ID since we're generating the tool messages instead of getting them from // the underlying provider. When emitting the streamed chunks, it's perfectly valid for us to @@ -624,7 +624,13 @@ public override async IAsyncEnumerable GetStreamingResponseA AddUsageTags(activity, totalUsage); } - private static ChatResponseUpdate ConvertToolResultMessageToUpdate(ChatMessage message, string? conversationId, string? messageId) => + /// + /// Converts a tool result into a for streaming scenarios. + /// + /// The tool result message. + /// The conversation ID. + /// The message ID. + internal static ChatResponseUpdate ConvertToolResultMessageToUpdate(ChatMessage message, string? conversationId, string? messageId) => new() { AdditionalProperties = message.AdditionalProperties, @@ -662,7 +668,7 @@ private static void AddUsageTags(Activity? activity, UsageDetails? usage) /// The most recent response being handled. /// A list of all response messages received up until this point. /// Whether the previous iteration's response had a conversation ID. - private static void FixupHistories( + internal static void FixupHistories( IEnumerable originalMessages, ref IEnumerable messages, [NotNull] ref List? augmentedHistory, @@ -722,26 +728,51 @@ private static void FixupHistories( /// The lists of tools to combine into a single dictionary. Tools from later lists are preferred /// over tools from earlier lists if they have the same name. /// - private static (Dictionary? ToolMap, bool AnyRequireApproval) CreateToolsMap(params ReadOnlySpan?> toolLists) + /// The to monitor for cancellation requests. + private static async ValueTask<(Dictionary? ToolMap, bool AnyRequireApproval)> CreateToolsMapAsync(IList?[] toolLists, CancellationToken cancellationToken) { Dictionary? map = null; bool anyRequireApproval = false; foreach (var toolList in toolLists) { - if (toolList?.Count is int count && count > 0) + if (toolList is not null) + { + map ??= []; + var anyInListRequireApproval = await AddToolListAsync(map, toolList, cancellationToken).ConfigureAwait(false); + anyRequireApproval |= anyInListRequireApproval; + } + } + + return (map, anyRequireApproval); + + static async ValueTask AddToolListAsync(Dictionary map, IEnumerable tools, CancellationToken cancellationToken) + { +#if NET + if (tools.TryGetNonEnumeratedCount(out var count) && count == 0) { - map ??= new(StringComparer.Ordinal); - for (int i = 0; i < count; i++) + return false; + } +#endif + var anyRequireApproval = false; + + foreach (var tool in tools) + { + if (tool is AIToolGroup toolGroup) + { + var nestedTools = await toolGroup.GetToolsAsync(cancellationToken).ConfigureAwait(false); + var nestedToolsRequireApproval = await AddToolListAsync(map, nestedTools, cancellationToken).ConfigureAwait(false); + anyRequireApproval |= nestedToolsRequireApproval; + } + else { - AITool tool = toolList[i]; anyRequireApproval |= tool.GetService() is not null; map[tool.Name] = tool; } } - } - return (map, anyRequireApproval); + return anyRequireApproval; + } } /// diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj index 36e6bb00562..54cbcc99754 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj +++ b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj @@ -44,6 +44,7 @@ + diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/AIToolGroup.cs b/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/AIToolGroup.cs new file mode 100644 index 00000000000..1706861b5a1 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/AIToolGroup.cs @@ -0,0 +1,90 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Represents a logical grouping of tools that can be dynamically expanded. +/// +/// +/// +/// A is an that supplies an ordered list of instances +/// via the method. This enables grouping tools together for organizational purposes +/// and allows for dynamic tool selection based on context. +/// +/// +/// Tool groups can be used independently or in conjunction with to implement +/// hierarchical tool selection, where groups are initially collapsed and can be expanded on demand. +/// +/// +[Experimental("MEAI001")] +public abstract class AIToolGroup : AITool +{ + private readonly string _name; + private readonly string _description; + + /// Initializes a new instance of the class. + /// Group name (identifier used by the expansion function). + /// Human readable description of the group. + /// is . + protected AIToolGroup(string name, string description) + { + _name = Throw.IfNull(name); + _description = Throw.IfNull(description); + } + + /// Gets the group name. + public override string Name => _name; + + /// Gets the group description. + public override string Description => _description; + + /// Creates a tool group with a static list of tools. + /// Group name (identifier used by the expansion function). + /// Human readable description of the group. + /// Ordered tools contained in the group. + /// An instance containing the specified tools. + /// or is . + public static AIToolGroup Create(string name, string description, IReadOnlyList tools) + { + _ = Throw.IfNull(name); + _ = Throw.IfNull(tools); + return new StaticAIToolGroup(name, description, tools); + } + + /// + /// Asynchronously retrieves the ordered list of tools belonging to this group. + /// + /// The to monitor for cancellation requests. + /// A representing the asynchronous operation, containing the ordered list of tools in the group. + /// + /// The returned list may contain other instances, enabling hierarchical tool organization. + /// Implementations should ensure the returned list is stable and deterministic for a given group instance. + /// + public abstract ValueTask> GetToolsAsync(CancellationToken cancellationToken = default); + + /// A tool group implementation that returns a static list of tools. + private sealed class StaticAIToolGroup : AIToolGroup + { + private readonly IReadOnlyList _tools; + + public StaticAIToolGroup(string name, string description, IReadOnlyList tools) + : base(name, description) + { + _tools = tools; + } + + public override ValueTask> GetToolsAsync(CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + return new ValueTask>(_tools); + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/ChatClientBuilderToolGroupingExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/ChatClientBuilderToolGroupingExtensions.cs new file mode 100644 index 00000000000..6c79a30cbba --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/ChatClientBuilderToolGroupingExtensions.cs @@ -0,0 +1,26 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Builder extensions for . +[Experimental("MEAI001")] +public static class ChatClientBuilderToolGroupingExtensions +{ + /// Adds tool grouping middleware to the pipeline. + /// Chat client builder. + /// Configuration delegate. + /// The builder for chaining. + /// Should appear before tool reduction and function invocation middleware. + public static ChatClientBuilder UseToolGrouping(this ChatClientBuilder builder, Action? configure = null) + { + _ = Throw.IfNull(builder); + var options = new ToolGroupingOptions(); + configure?.Invoke(options); + return builder.Use(inner => new ToolGroupingChatClient(inner, options)); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/ToolGroupingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/ToolGroupingChatClient.cs new file mode 100644 index 00000000000..684528f795d --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/ToolGroupingChatClient.cs @@ -0,0 +1,516 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +#pragma warning disable S103 // Lines should not be too long +#pragma warning disable IDE0058 // Expression value is never used + +/// +/// A chat client that enables tool groups (see ) to be dynamically expanded. +/// +/// +/// +/// On each request, this chat client initially presents a minimal tool surface consisting of: (a) a function +/// returning the current list of available groups plus (b) a synthetic expansion function plus (c) tools in +/// that are not instances. +/// If the model calls the expansion function with a valid group name, the +/// client issues another request with that group's tools visible. +/// Only one group may be expanded per top-level request, and by default at most three expansion loops are performed. +/// +/// +/// This client should typically appear in the pipeline before tool reduction middleware and function invocation +/// middleware. Example order: .UseToolGrouping(...).UseToolReduction(...).UseFunctionInvocation(). +/// +/// +[Experimental("MEAI001")] +public sealed class ToolGroupingChatClient : DelegatingChatClient +{ + private const string ExpansionFunctionGroupNameParameter = "groupName"; + private static readonly Delegate _expansionFunctionDelegate = static string (string groupName) + => throw new InvalidOperationException("The tool expansion function should not be invoked directly."); + + private readonly int _maxExpansionsPerRequest; + private readonly AIFunctionDeclaration _expansionFunction; + private readonly string _listGroupsFunctionName; + private readonly string _listGroupsFunctionDescription; + + /// Initializes a new instance of the class. + /// Inner client. + /// Grouping options. + public ToolGroupingChatClient(IChatClient innerClient, ToolGroupingOptions options) + : base(innerClient) + { + _ = Throw.IfNull(options); + + _maxExpansionsPerRequest = options.MaxExpansionsPerRequest; + _listGroupsFunctionName = options.ListGroupsFunctionName; + _listGroupsFunctionDescription = options.ListGroupsFunctionDescription + ?? "Returns the list of available tool groups that can be expanded."; + + var expansionFunctionName = options.ExpansionFunctionName; + var expansionDescription = options.ExpansionFunctionDescription + ?? $"Expands a tool group to make its tools available. Use the '{_listGroupsFunctionName}' function to see available groups."; + + _expansionFunction = AIFunctionFactory.Create( + method: _expansionFunctionDelegate, + name: expansionFunctionName, + description: expansionDescription).AsDeclarationOnly(); + } + + /// + public override async Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(messages); + + var toolGroups = ExtractToolGroups(options); + if (toolGroups is not { Count: > 0 }) + { + // If there are no tool groups, then tool expansion isn't possible. + // We'll just call directly through to the inner chat client. + return await base.GetResponseAsync(messages, options, cancellationToken); + } + + // Copy the original messages in order to avoid enumerating the original messages multiple times. + // The IEnumerable can represent an arbitrary amount of work. + List originalMessages = [.. messages]; + messages = originalMessages; + + // Build top-level groups dictionary + var topLevelToolGroupsByName = toolGroups.ToDictionary(g => g.Name, StringComparer.Ordinal); + + // Track the currently-expanded group and all its constituent tools + AIToolGroup? expandedGroup = null; + List? expandedGroupToolGroups = null; // tool groups within the currently-expanded tool group + List? expandedGroupTools = null; // non-group tools within the currently-expanded tool group + + // Create the "list groups" function. Its behavior is controlled by values captured in the lambda below. + var listGroupsFunction = AIFunctionFactory.Create( + method: () => CreateListGroupsResult(expandedGroup, toolGroups, expandedGroupToolGroups), + name: _listGroupsFunctionName, + description: _listGroupsFunctionDescription); + + // Construct new chat options containing ungrouped tools and utility functions. + List baseTools = ComputeBaseTools(options, listGroupsFunction); + ChatOptions modifiedOptions = options?.Clone() ?? new(); + modifiedOptions.Tools = baseTools; + + List? augmentedHistory = null; // the actual history of messages sent on turns other than the first + ChatResponse? response = null; // the response from the inner client, which is possibly modified and then eventually returned + List? responseMessages = null; // tracked list of messages, across multiple turns, to be used for the final response + List? expansionRequests = null; // expansion requests that need responding to in the current turn + UsageDetails? totalUsage = null; // tracked usage across all turns, to be used for the final response + bool lastIterationHadConversationId = false; // whether the last iteration's response had a ConversationId set + List? modifiedTools = null; // the modified tools list containing the current tool group + + for (var expansionIterationCount = 0; ; expansionIterationCount++) + { + expansionRequests?.Clear(); + + // Make the call to the inner client. + response = await base.GetResponseAsync(messages, modifiedOptions, cancellationToken).ConfigureAwait(false); + if (response is null) + { + Throw.InvalidOperationException("Inner client returned null ChatResponse."); + } + + // Any expansions to perform? If yes, ensure we're tracking that work in expansionRequests. + bool requiresExpansion = + expansionIterationCount < _maxExpansionsPerRequest && + CopyExpansionRequests(response.Messages, ref expansionRequests); + + if (!requiresExpansion && expansionIterationCount == 0) + { + // Fast path: no function calling work required + return response; + } + + // Track aggregate details from the response + (responseMessages ??= []).AddRange(response.Messages); + if (response.Usage is not null) + { + if (totalUsage is not null) + { + totalUsage.Add(response.Usage); + } + else + { + totalUsage = response.Usage; + } + } + + if (!requiresExpansion) + { + // No more work to do. + break; + } + + // Prepare the history for the next iteration. + FunctionInvokingChatClient.FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadConversationId); + + expandedGroupTools ??= []; + expandedGroupToolGroups ??= []; + (var addedMessages, expandedGroup) = await ProcessExpansionsAsync( + expansionRequests!, + topLevelToolGroupsByName, + expandedGroupTools, + expandedGroupToolGroups, + expandedGroup, + cancellationToken); + + augmentedHistory.AddRange(addedMessages); + responseMessages.AddRange(addedMessages); + + (modifiedTools ??= []).Clear(); + modifiedTools.AddRange(baseTools); + modifiedTools.AddRange(expandedGroupTools); + modifiedOptions.Tools = modifiedTools; + modifiedOptions.ConversationId = response.ConversationId; + } + + Debug.Assert(responseMessages is not null, "Expected to only be here if we have response messages."); + response.Messages = responseMessages!; + response.Usage = totalUsage; + + return response; + } + + /// + public override async IAsyncEnumerable GetStreamingResponseAsync( + IEnumerable messages, + ChatOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(messages); + + var toolGroups = ExtractToolGroups(options); + if (toolGroups is not { Count: > 0 }) + { + // No tool groups, just call through + await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + { + yield return update; + } + + yield break; + } + + List originalMessages = [.. messages]; + messages = originalMessages; + + // Build top-level groups dictionary + var topLevelToolGroupsByName = toolGroups.ToDictionary(g => g.Name, StringComparer.Ordinal); + + // Track the currently-expanded group and all its constituent tools + AIToolGroup? expandedGroup = null; + List? expandedGroupToolGroups = null; // tool groups within the currently-expanded tool group + List? expandedGroupTools = null; // non-group tools within the currently-expanded tool group + + // Create the "list groups" function. Its behavior is controlled by values captured in the lambda below. + var listGroupsFunction = AIFunctionFactory.Create( + method: () => CreateListGroupsResult(expandedGroup, toolGroups, expandedGroupToolGroups), + name: _listGroupsFunctionName, + description: _listGroupsFunctionDescription); + + // Construct new chat options containing ungrouped tools and utility functions. + List baseTools = ComputeBaseTools(options, listGroupsFunction); + ChatOptions modifiedOptions = options?.Clone() ?? new(); + modifiedOptions.Tools = baseTools; + + List? augmentedHistory = null; // the actual history of messages sent on turns other than the first + List? responseMessages = null; // tracked list of messages, across multiple turns, to be used for the final response + List? expansionRequests = null; // expansion requests that need responding to in the current turn + bool lastIterationHadConversationId = false; // whether the last iteration's response had a ConversationId set + List updates = []; // collected updates from the inner client for the current iteration + List? modifiedTools = null; + string toolMessageId = Guid.NewGuid().ToString("N"); // stable id for synthetic tool result updates emitted per iteration + + for (int expansionIterationCount = 0; ; expansionIterationCount++) + { + // Reset any state accumulated from the prior iteration before calling the inner client again. + updates.Clear(); + expansionRequests?.Clear(); + + await foreach (var update in base.GetStreamingResponseAsync(messages, modifiedOptions, cancellationToken).ConfigureAwait(false)) + { + if (update is null) + { + Throw.InvalidOperationException("Inner client returned null ChatResponseUpdate."); + } + + updates.Add(update); + + _ = CopyExpansionRequests(update.Contents, ref expansionRequests); + + yield return update; + } + + if (expansionIterationCount >= _maxExpansionsPerRequest || expansionRequests is not { Count: > 0 }) + { + // We've either hit the expansion iteration limit or no expansion function calls were made, + // so we're done streaming the response. + break; + } + + // Materialize the collected updates into a ChatResponse so the rest of the logic can share code paths + // with the non-streaming implementation. + var response = updates.ToChatResponse(); + (responseMessages ??= []).AddRange(response.Messages); + + // Prepare the history for the next iteration. + FunctionInvokingChatClient.FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadConversationId); + + // Add the responses from the group expansions into the augmented history and also into the tracked + // list of response messages. + expandedGroupTools ??= []; + expandedGroupToolGroups ??= []; + (var addedMessages, expandedGroup) = await ProcessExpansionsAsync( + expansionRequests!, + topLevelToolGroupsByName, + expandedGroupTools, + expandedGroupToolGroups, + expandedGroup, + cancellationToken); + + augmentedHistory!.AddRange(addedMessages); + responseMessages.AddRange(addedMessages); + + // Surface the expansion results to the caller as additional streaming updates. + foreach (var message in addedMessages) + { + yield return FunctionInvokingChatClient.ConvertToolResultMessageToUpdate(message, response.ConversationId, toolMessageId); + } + + // If a valid group was requested for expansion, and it does not match the currently-expanded group, + // update the tools list to contain the newly-expanded tool group. + (modifiedTools ??= []).Clear(); + modifiedTools.AddRange(baseTools); + modifiedTools.AddRange(expandedGroupTools); + modifiedOptions.Tools = modifiedTools; + modifiedOptions.ConversationId = response.ConversationId; + } + } + + /// Extracts instances from the provided options. + private static List? ExtractToolGroups(ChatOptions? options) + { + if (options?.Tools is not { Count: > 0 }) + { + return null; + } + + List? groups = null; + foreach (var tool in options.Tools) + { + if (tool is AIToolGroup group) + { + (groups ??= []).Add(group); + } + } + + return groups; + } + + /// Creates a function that returns the list of available groups. + private static string CreateListGroupsResult( + AIToolGroup? expandedToolGroup, + List topLevelGroups, + List? nestedGroups) + { + var allToolGroups = nestedGroups is null + ? topLevelGroups + : topLevelGroups.Concat(nestedGroups); + + allToolGroups = allToolGroups.Where(g => g != expandedToolGroup); + + if (!allToolGroups.Any()) + { + return "No tool groups are currently available."; + } + + var sb = new StringBuilder(); + sb.Append("Available tool groups:"); + AppendAIToolList(sb, allToolGroups); + return sb.ToString(); + } + + /// Processes expansion requests and returns messages to add, termination flag, and updated group state. + private static async Task<(IList messagesToAdd, AIToolGroup? expandedGroup)> ProcessExpansionsAsync( + List expansionRequests, + Dictionary topLevelGroupsByName, + List expandedGroupTools, + List expandedGroupToolGroups, + AIToolGroup? expandedGroup, + CancellationToken cancellationToken) + { + Debug.Assert(expansionRequests.Count != 0, "Expected at least one expansion request."); + + var contents = new List(expansionRequests.Count); + + foreach (var expansionRequest in expansionRequests) + { + if (expansionRequest.Arguments is not { Count: > 0 } arguments || + !arguments.TryGetValue(ExpansionFunctionGroupNameParameter, out var groupNameArg) || + groupNameArg is null) + { + contents.Add(new FunctionResultContent( + callId: expansionRequest.CallId, + result: "No group name was specified; ignoring expansion request.")); + continue; + } + + bool TryGetValidToolGroup(string groupName, [NotNullWhen(true)] out AIToolGroup? group) + { + if (topLevelGroupsByName.TryGetValue(groupName, out group)) + { + return true; + } + + group = expandedGroupToolGroups.FirstOrDefault(g => string.Equals(g.Name, groupName, StringComparison.Ordinal)); + return group is not null; + } + + var groupName = groupNameArg.ToString(); + if (groupName is null || !TryGetValidToolGroup(groupName, out var group)) + { + contents.Add(new FunctionResultContent( + callId: expansionRequest.CallId, + result: $"The specific group name '{groupName}' was invalid; ignoring expansion request.")); + continue; + } + + if (group == expandedGroup) + { + contents.Add(new FunctionResultContent( + callId: expansionRequest.CallId, + result: $"Ignoring duplicate expansion of group '{groupName}'.")); + continue; + } + + // Expand the group + expandedGroup = group; + var groupTools = await group.GetToolsAsync(cancellationToken).ConfigureAwait(false); + + expandedGroupTools.Clear(); + expandedGroupToolGroups.Clear(); + + foreach (var tool in groupTools) + { + if (tool is AIToolGroup toolGroup) + { + expandedGroupToolGroups.Add(toolGroup); + } + else + { + expandedGroupTools.Add(tool); + } + } + + // Build success message + var sb = new StringBuilder(); + sb.Append("Successfully expanded group '"); + sb.Append(groupName); + sb.Append("'."); + + if (expandedGroupTools.Count > 0) + { + sb.Append(" Only this group's tools are now available:"); + AppendAIToolList(sb, expandedGroupTools); + } + + if (expandedGroupToolGroups.Count > 0) + { + sb.AppendLine(); + sb.Append("Additional groups available for expansion:"); + AppendAIToolList(sb, expandedGroupToolGroups); + } + + contents.Add(new FunctionResultContent( + callId: expansionRequest.CallId, + result: sb.ToString())); + } + + return (messagesToAdd: [new ChatMessage(ChatRole.Tool, contents)], expandedGroup); + } + + /// Appends a formatted list of AI tools to the specified . + private static void AppendAIToolList(StringBuilder sb, IEnumerable tools) + { + foreach (var tool in tools) + { + sb.AppendLine(); + sb.Append("- "); + sb.Append(tool.Name); + sb.Append(": "); + sb.Append(tool.Description); + } + } + + /// Copies expansion requests from messages. + private bool CopyExpansionRequests(IList messages, [NotNullWhen(true)] ref List? expansionRequests) + { + var any = false; + foreach (var message in messages) + { + any |= CopyExpansionRequests(message.Contents, ref expansionRequests); + } + + return any; + } + + /// Copies expansion requests from contents. + private bool CopyExpansionRequests( + IList contents, + [NotNullWhen(true)] ref List? expansionRequests) + { + var any = false; + foreach (var content in contents) + { + if (content is FunctionCallContent functionCall && + string.Equals(functionCall.Name, _expansionFunction.Name, StringComparison.Ordinal)) + { + (expansionRequests ??= []).Add(functionCall); + any = true; + } + } + + return any; + } + + /// + /// Generates a list of base AI tools by combining the default expansion function with additional tools specified in + /// the provided chat options, excluding any tools that are grouped. + /// + private List ComputeBaseTools(ChatOptions? options, AIFunction listGroupsFunction) + { + List baseTools = [listGroupsFunction, _expansionFunction]; + + foreach (var tool in options?.Tools ?? []) + { + if (tool is not AIToolGroup) + { + if (string.Equals(tool.Name, _expansionFunction.Name, StringComparison.Ordinal) || + string.Equals(tool.Name, listGroupsFunction.Name, StringComparison.Ordinal)) + { + throw new InvalidOperationException( + $"The group expansion tool with name '{tool.Name}' collides with a registered tool of the same name."); + } + + baseTools.Add(tool); + } + } + + return baseTools; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/ToolGroupingOptions.cs b/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/ToolGroupingOptions.cs new file mode 100644 index 00000000000..fe6854c2378 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/ToolGroupingOptions.cs @@ -0,0 +1,59 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable IDE0032 // Use auto property, suppressed until repo updates to C# 14 + +namespace Microsoft.Extensions.AI; + +/// Options controlling tool grouping / expansion behavior. +[Experimental("MEAI001")] +public sealed class ToolGroupingOptions +{ + private const string DefaultExpansionFunctionName = "__expand_tool_group"; + private const string DefaultListGroupsFunctionName = "__list_tool_groups"; + + private string _expansionFunctionName = DefaultExpansionFunctionName; + private string? _expansionFunctionDescription; + private string _listGroupsFunctionName = DefaultListGroupsFunctionName; + private string? _listGroupsFunctionDescription; + private int _maxExpansionsPerRequest = 3; + + /// Gets or sets the name of the synthetic expansion function tool. + public string ExpansionFunctionName + { + get => _expansionFunctionName; + set => _expansionFunctionName = Throw.IfNull(value); + } + + /// Gets or sets the description of the synthetic expansion function tool. + public string? ExpansionFunctionDescription + { + get => _expansionFunctionDescription; + set => _expansionFunctionDescription = value; + } + + /// Gets or sets the name of the synthetic list groups function tool. + public string ListGroupsFunctionName + { + get => _listGroupsFunctionName; + set => _listGroupsFunctionName = Throw.IfNull(value); + } + + /// Gets or sets the description of the synthetic list groups function tool. + public string? ListGroupsFunctionDescription + { + get => _listGroupsFunctionDescription; + set => _listGroupsFunctionDescription = value; + } + + /// Gets or sets the maximum number of expansions allowed within a single request. + /// Defaults to 3. + public int MaxExpansionsPerRequest + { + get => _maxExpansionsPerRequest; + set => _maxExpansionsPerRequest = Throw.IfLessThan(value, 1); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ChatClientBuilderToolReductionExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ChatClientBuilderToolReductionExtensions.cs new file mode 100644 index 00000000000..5a644267328 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ChatClientBuilderToolReductionExtensions.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Extension methods for adding tool reduction middleware to a chat client pipeline. +[Experimental("MEAI001")] +public static class ChatClientBuilderToolReductionExtensions +{ + /// + /// Adds tool reduction to the chat client pipeline using the specified . + /// + /// The chat client builder. + /// The reduction strategy. + /// The original builder for chaining. + /// If or is . + /// + /// This should typically appear in the pipeline before function invocation middleware so that only the reduced tools + /// are exposed to the underlying provider. + /// + public static ChatClientBuilder UseToolReduction(this ChatClientBuilder builder, IToolReductionStrategy strategy) + { + _ = Throw.IfNull(builder); + _ = Throw.IfNull(strategy); + + return builder.Use(inner => new ToolReducingChatClient(inner, strategy)); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs new file mode 100644 index 00000000000..f9e4c60995a --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs @@ -0,0 +1,330 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Numerics.Tensors; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +#pragma warning disable IDE0032 // Use auto property, suppressed until repo updates to C# 14 + +/// +/// A tool reduction strategy that ranks tools by embedding similarity to the current conversation context. +/// +/// +/// The strategy embeds each tool (name + description by default) once (cached) and embeds the current +/// conversation content each request. It then selects the top toolLimit tools by similarity. +/// +[Experimental("MEAI001")] +public sealed class EmbeddingToolReductionStrategy : IToolReductionStrategy +{ + private readonly ConditionalWeakTable> _toolEmbeddingsCache = new(); + private readonly IEmbeddingGenerator> _embeddingGenerator; + private readonly int _toolLimit; + + private Func _toolEmbeddingTextSelector = static t => + { + if (string.IsNullOrWhiteSpace(t.Name)) + { + return t.Description; + } + + if (string.IsNullOrWhiteSpace(t.Description)) + { + return t.Name; + } + + return t.Name + Environment.NewLine + t.Description; + }; + + private Func, ValueTask> _messagesEmbeddingTextSelector = static messages => + { + var sb = new StringBuilder(); + foreach (var message in messages) + { + var contents = message.Contents; + for (var i = 0; i < contents.Count; i++) + { + string text; + switch (contents[i]) + { + case TextContent content: + text = content.Text; + break; + case TextReasoningContent content: + text = content.Text; + break; + default: + continue; + } + + _ = sb.AppendLine(text); + } + } + + return new ValueTask(sb.ToString()); + }; + + private Func, ReadOnlyMemory, float> _similarity = static (a, b) => TensorPrimitives.CosineSimilarity(a.Span, b.Span); + + private Func _isRequiredTool = static _ => false; + + /// + /// Initializes a new instance of the class. + /// + /// Embedding generator used to produce embeddings. + /// Maximum number of tools to return, excluding required tools. Must be greater than zero. + public EmbeddingToolReductionStrategy( + IEmbeddingGenerator> embeddingGenerator, + int toolLimit) + { + _embeddingGenerator = Throw.IfNull(embeddingGenerator); + _toolLimit = Throw.IfLessThanOrEqual(toolLimit, min: 0); + } + + /// + /// Gets or sets the selector used to generate a single text string from a tool. + /// + /// + /// Defaults to: Name + "\n" + Description (omitting empty parts). + /// + public Func ToolEmbeddingTextSelector + { + get => _toolEmbeddingTextSelector; + set => _toolEmbeddingTextSelector = Throw.IfNull(value); + } + + /// + /// Gets or sets the selector used to generate a single text string from a collection of chat messages for + /// embedding purposes. + /// + public Func, ValueTask> MessagesEmbeddingTextSelector + { + get => _messagesEmbeddingTextSelector; + set => _messagesEmbeddingTextSelector = Throw.IfNull(value); + } + + /// + /// Gets or sets a similarity function applied to (query, tool) embedding vectors. + /// + /// + /// Defaults to cosine similarity. + /// + public Func, ReadOnlyMemory, float> Similarity + { + get => _similarity; + set => _similarity = Throw.IfNull(value); + } + + /// + /// Gets or sets a function that determines whether a tool is required (always included). + /// + /// + /// If this returns , the tool is included regardless of ranking and does not count against + /// the configured non-required tool limit. A tool explicitly named by (when + /// is non-null) is also treated as required, independent + /// of this delegate's result. + /// + public Func IsRequiredTool + { + get => _isRequiredTool; + set => _isRequiredTool = Throw.IfNull(value); + } + + /// + /// Gets or sets a value indicating whether to preserve original ordering of selected tools. + /// If (default), tools are ordered by descending similarity. + /// If , the top-N tools by similarity are re-emitted in their original order. + /// + public bool PreserveOriginalOrdering { get; set; } + + /// + public async Task> SelectToolsForRequestAsync( + IEnumerable messages, + ChatOptions? options, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(messages); + + if (options?.Tools is not { Count: > 0 } tools) + { + // Prefer the original tools list reference if possible. + // This allows ToolReducingChatClient to avoid unnecessarily copying ChatOptions. + // When no reduction is performed. + return options?.Tools ?? []; + } + + Debug.Assert(_toolLimit > 0, "Expected the tool count limit to be greater than zero."); + + if (tools.Count <= _toolLimit) + { + // Since the total number of tools doesn't exceed the configured tool limit, + // there's no need to determine which tools are optional, i.e., subject to reduction. + // We can return the original tools list early. + return tools; + } + + var toolRankingInfoArray = ArrayPool.Shared.Rent(tools.Count); + try + { + var toolRankingInfoMemory = toolRankingInfoArray.AsMemory(start: 0, length: tools.Count); + + // We allocate tool rankings in a contiguous chunk of memory, but partition them such that + // required tools come first and are immediately followed by optional tools. + // This allows us to separately rank optional tools by similarity score, but then later re-order + // the top N tools (including required tools) to preserve their original relative order. + var (requiredTools, optionalTools) = PartitionToolRankings(toolRankingInfoMemory, tools, options.ToolMode); + + if (optionalTools.Length <= _toolLimit) + { + // There aren't enough optional tools to require reduction, so we'll return the original + // tools list. + return tools; + } + + // Build query text from recent messages. + var queryText = await MessagesEmbeddingTextSelector(messages).ConfigureAwait(false); + if (string.IsNullOrWhiteSpace(queryText)) + { + // We couldn't build a meaningful query, likely because the message list was empty. + // We'll just return the original tools list. + return tools; + } + + var queryEmbedding = await _embeddingGenerator.GenerateAsync(queryText, cancellationToken: cancellationToken).ConfigureAwait(false); + + // Compute and populate similarity scores in the tool ranking info. + await ComputeSimilarityScoresAsync(optionalTools, queryEmbedding, cancellationToken); + + var topTools = toolRankingInfoMemory.Slice(start: 0, length: requiredTools.Length + _toolLimit); +#if NET + optionalTools.Span.Sort(AIToolRankingInfo.CompareByDescendingSimilarityScore); + if (PreserveOriginalOrdering) + { + topTools.Span.Sort(AIToolRankingInfo.CompareByOriginalIndex); + } +#else + Array.Sort(toolRankingInfoArray, index: requiredTools.Length, length: optionalTools.Length, AIToolRankingInfo.CompareByDescendingSimilarityScore); + if (PreserveOriginalOrdering) + { + Array.Sort(toolRankingInfoArray, index: 0, length: topTools.Length, AIToolRankingInfo.CompareByOriginalIndex); + } +#endif + return ToToolList(topTools.Span); + + static List ToToolList(ReadOnlySpan toolInfo) + { + var result = new List(capacity: toolInfo.Length); + foreach (var info in toolInfo) + { + result.Add(info.Tool); + } + + return result; + } + } + finally + { + ArrayPool.Shared.Return(toolRankingInfoArray); + } + } + + private (Memory RequiredTools, Memory OptionalTools) PartitionToolRankings( + Memory toolRankingInfo, IList tools, ChatToolMode? toolMode) + { + // Always include a tool if its name matches the required function name. + var requiredFunctionName = (toolMode as RequiredChatToolMode)?.RequiredFunctionName; + var nextRequiredToolIndex = 0; + var nextOptionalToolIndex = tools.Count - 1; + for (var i = 0; i < toolRankingInfo.Length; i++) + { + var tool = tools[i]; + var isRequiredByToolMode = requiredFunctionName is not null && string.Equals(requiredFunctionName, tool.Name, StringComparison.Ordinal); + var toolIndex = isRequiredByToolMode || IsRequiredTool(tool) + ? nextRequiredToolIndex++ + : nextOptionalToolIndex--; + toolRankingInfo.Span[toolIndex] = new AIToolRankingInfo(tool, originalIndex: i); + } + + return ( + RequiredTools: toolRankingInfo.Slice(0, nextRequiredToolIndex), + OptionalTools: toolRankingInfo.Slice(nextRequiredToolIndex)); + } + + private async Task ComputeSimilarityScoresAsync(Memory toolInfo, Embedding queryEmbedding, CancellationToken cancellationToken) + { + var anyCacheMisses = false; + List cacheMissToolEmbeddingTexts = null!; + List cacheMissToolInfoIndexes = null!; + for (var i = 0; i < toolInfo.Length; i++) + { + ref var info = ref toolInfo.Span[i]; + if (_toolEmbeddingsCache.TryGetValue(info.Tool, out var toolEmbedding)) + { + info.SimilarityScore = Similarity(queryEmbedding.Vector, toolEmbedding.Vector); + } + else + { + if (!anyCacheMisses) + { + anyCacheMisses = true; + cacheMissToolEmbeddingTexts = []; + cacheMissToolInfoIndexes = []; + } + + var text = ToolEmbeddingTextSelector(info.Tool); + cacheMissToolEmbeddingTexts.Add(text); + cacheMissToolInfoIndexes.Add(i); + } + } + + if (!anyCacheMisses) + { + // There were no cache misses; no more work to do. + return; + } + + var uncachedEmbeddings = await _embeddingGenerator.GenerateAsync(cacheMissToolEmbeddingTexts, cancellationToken: cancellationToken).ConfigureAwait(false); + if (uncachedEmbeddings.Count != cacheMissToolEmbeddingTexts.Count) + { + throw new InvalidOperationException($"Expected {cacheMissToolEmbeddingTexts.Count} embeddings, got {uncachedEmbeddings.Count}."); + } + + for (var i = 0; i < uncachedEmbeddings.Count; i++) + { + var toolInfoIndex = cacheMissToolInfoIndexes[i]; + var toolEmbedding = uncachedEmbeddings[i]; + ref var info = ref toolInfo.Span[toolInfoIndex]; + info.SimilarityScore = Similarity(queryEmbedding.Vector, toolEmbedding.Vector); + _toolEmbeddingsCache.Add(info.Tool, toolEmbedding); + } + } + + private struct AIToolRankingInfo(AITool tool, int originalIndex) + { + public static readonly Comparer CompareByDescendingSimilarityScore + = Comparer.Create(static (a, b) => + { + var result = b.SimilarityScore.CompareTo(a.SimilarityScore); + return result != 0 + ? result + : a.OriginalIndex.CompareTo(b.OriginalIndex); // Stabilize ties. + }); + + public static readonly Comparer CompareByOriginalIndex + = Comparer.Create(static (a, b) => a.OriginalIndex.CompareTo(b.OriginalIndex)); + + public AITool Tool { get; } = tool; + public int OriginalIndex { get; } = originalIndex; + public float SimilarityScore { get; set; } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs new file mode 100644 index 00000000000..6a5d6d925fc --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs @@ -0,0 +1,89 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// A delegating chat client that applies a tool reduction strategy before invoking the inner client. +/// +/// +/// Insert this into a pipeline (typically before function invocation middleware) to automatically +/// reduce the tool list carried on for each request. +/// +[Experimental("MEAI001")] +public sealed class ToolReducingChatClient : DelegatingChatClient +{ + private readonly IToolReductionStrategy _strategy; + + /// + /// Initializes a new instance of the class. + /// + /// The inner client. + /// The tool reduction strategy to apply. + /// Thrown if any argument is . + public ToolReducingChatClient(IChatClient innerClient, IToolReductionStrategy strategy) + : base(innerClient) + { + _strategy = Throw.IfNull(strategy); + } + + /// + public override async Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + options = await ApplyReductionAsync(messages, options, cancellationToken).ConfigureAwait(false); + return await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + } + + /// + public override async IAsyncEnumerable GetStreamingResponseAsync( + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + options = await ApplyReductionAsync(messages, options, cancellationToken).ConfigureAwait(false); + + await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + { + yield return update; + } + } + + private async Task ApplyReductionAsync( + IEnumerable messages, + ChatOptions? options, + CancellationToken cancellationToken) + { + // If there are no options or no tools, skip. + if (options?.Tools is not { Count: > 0 }) + { + return options; + } + + var reduced = await _strategy.SelectToolsForRequestAsync(messages, options, cancellationToken).ConfigureAwait(false); + + // If strategy returned the same list instance (or reference equality), assume no change. + if (ReferenceEquals(reduced, options.Tools)) + { + return options; + } + + // Materialize and compare counts; if unchanged and tools have identical ordering and references, keep original. + if (reduced is not IList reducedList) + { + reducedList = reduced.ToList(); + } + + // Clone options to avoid mutating a possibly shared instance. + var cloned = options.Clone(); + cloned.Tools = reducedList; + return cloned; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs index 448de8d11df..06a16abc2a4 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs @@ -41,6 +41,8 @@ protected ChatClientIntegrationTests() protected IChatClient? ChatClient { get; } + protected IEmbeddingGenerator>? EmbeddingGenerator { get; private set; } + public void Dispose() { ChatClient?.Dispose(); @@ -49,6 +51,13 @@ public void Dispose() protected abstract IChatClient? CreateChatClient(); + /// + /// Optionally supplies an embedding generator for integration tests that exercise + /// embedding-based components (e.g., tool reduction). Default returns null and + /// tests depending on embeddings will skip if not overridden. + /// + protected virtual IEmbeddingGenerator>? CreateEmbeddingGenerator() => null; + [ConditionalFact] public virtual async Task GetResponseAsync_SingleRequestMessage() { @@ -1346,6 +1355,398 @@ public virtual async Task SummarizingChatReducer_CustomPrompt() Assert.Contains("5", response.Text); } + [ConditionalFact] + public virtual async Task ToolGrouping_LlmExpandsGroupAndInvokesTool_NonStreaming() + { + SkipIfNotEnabled(); + + const string TravelToken = "TRAVEL-PLAN-TOKEN-4826"; + + List> availableToolsPerInvocation = []; + string expansionFunctionName = "__expand_tool_group"; // Default expansion function name + + var generateItinerary = AIFunctionFactory.Create( + (string city) => $"{TravelToken}::{city.ToUpperInvariant()}::DAY-PLAN", + new AIFunctionFactoryOptions + { + Name = "GenerateItinerary", + Description = "Produces a detailed itinerary token. Always repeat the returned token verbatim in your summary." + }); + + var summarizePacking = AIFunctionFactory.Create( + () => "Pack light layers and comfortable shoes.", + new AIFunctionFactoryOptions + { + Name = "PackingSummary", + Description = "Provides a short packing reminder." + }); + + using var client = ChatClient! + .AsBuilder() + .UseToolGrouping() + .Use((messages, options, next, cancellationToken) => + { + if (options?.Tools is { Count: > 0 } tools) + { + availableToolsPerInvocation.Add([.. tools.Select(static tool => tool.Name)]); + } + else + { + availableToolsPerInvocation.Add([]); + } + + return next(messages, options, cancellationToken); + }) + .UseFunctionInvocation() + .Build(); + + List messages = + [ + new(ChatRole.System, "You are a helpful assistant. Use the available tools to assist the user."), + new(ChatRole.User, "Plan a two-day cultural trip to Rome and tell me the itinerary token."), + ]; + + var response = await client.GetResponseAsync(messages, new ChatOptions + { + Tools = [AIToolGroup.Create("TravelUtilities", "Travel planning helpers that generate itinerary tokens.", [generateItinerary, summarizePacking])] + }); + + Assert.Contains(response.Messages, message => + message.Role == ChatRole.Tool && + message.Contents.OfType().Any(content => + content.Result?.ToString()?.IndexOf("Successfully expanded group 'TravelUtilities'", StringComparison.OrdinalIgnoreCase) >= 0)); + + Assert.Contains(response.Messages, message => + message.Contents.OfType().Any(content => + string.Equals(content.Name, "GenerateItinerary", StringComparison.Ordinal))); + + var nonStreamingText = response.Text ?? string.Empty; + Assert.Contains(TravelToken.ToUpperInvariant(), nonStreamingText.ToUpperInvariant()); + + Assert.NotEmpty(availableToolsPerInvocation); + var firstInvocationTools = availableToolsPerInvocation[0]; + Assert.Contains(expansionFunctionName, firstInvocationTools); + Assert.DoesNotContain(generateItinerary.Name, firstInvocationTools); + + var expandedInvocationIndex = availableToolsPerInvocation.FindIndex(tools => tools.Contains(generateItinerary.Name)); + Assert.True(expandedInvocationIndex >= 0, "GenerateItinerary was never exposed to the model."); + Assert.True(expandedInvocationIndex > 0, "GenerateItinerary was visible before the expansion function executed."); + } + + [ConditionalFact] + public virtual async Task ToolGrouping_LlmExpandsGroupAndInvokesTool_Streaming() + { + SkipIfNotEnabled(); + + const string LodgingToken = "LODGING-TOKEN-3895"; + + List> availableToolsPerInvocation = []; + string expansionFunctionName = "__expand_tool_group"; // Default expansion function name + + var suggestLodging = AIFunctionFactory.Create( + (string city, string budget) => $"{LodgingToken}::{city.ToUpperInvariant()}::{budget}", + new AIFunctionFactoryOptions + { + Name = "SuggestLodging", + Description = "Returns hotel recommendations along with a lodging token. Repeat the token verbatim in your narrative." + }); + + using var client = ChatClient! + .AsBuilder() + .UseToolGrouping() + .Use((messages, options, next, cancellationToken) => + { + if (options?.Tools is { Count: > 0 } tools) + { + availableToolsPerInvocation.Add([.. tools.Select(static tool => tool.Name)]); + } + else + { + availableToolsPerInvocation.Add([]); + } + + return next(messages, options, cancellationToken); + }) + .UseFunctionInvocation() + .Build(); + + List messages = + [ + new(ChatRole.System, "You are a helpful assistant. Use the available tools to assist the user."), + new(ChatRole.User, "We're visiting Paris with a nightly budget of 150 USD. Stream the suggestions as they arrive and repeat the lodging token."), + ]; + + var response = await client.GetStreamingResponseAsync(messages, new ChatOptions + { + Tools = [AIToolGroup.Create("TravelUtilities", "Travel helpers used for lodging recommendations.", [suggestLodging])] + }).ToChatResponseAsync(); + + Assert.Contains(response.Messages, message => + message.Role == ChatRole.Tool && + message.Contents.OfType().Any(content => + content.Result?.ToString()?.IndexOf("Successfully expanded group 'TravelUtilities'", StringComparison.OrdinalIgnoreCase) >= 0)); + + Assert.Contains(response.Messages, message => + message.Contents.OfType().Any(content => + string.Equals(content.Name, "SuggestLodging", StringComparison.Ordinal))); + + var streamingText = response.Text ?? string.Empty; + Assert.Contains(LodgingToken.ToUpperInvariant(), streamingText.ToUpperInvariant()); + + Assert.NotEmpty(availableToolsPerInvocation); + var firstInvocationTools = availableToolsPerInvocation[0]; + Assert.Contains(expansionFunctionName, firstInvocationTools); + Assert.DoesNotContain(suggestLodging.Name, firstInvocationTools); + + var expandedInvocationIndex = availableToolsPerInvocation.FindIndex(tools => tools.Contains(suggestLodging.Name)); + Assert.True(expandedInvocationIndex >= 0, "SuggestLodging was never exposed to the model."); + Assert.True(expandedInvocationIndex > 0, "SuggestLodging was visible before the expansion function executed."); + } + + [ConditionalFact] + public virtual async Task ToolGrouping_NestedGroups_LlmExpandsHierarchyAndInvokesTool() + { + SkipIfNotEnabled(); + + const string BookingToken = "BOOKING-CONFIRMED-7291"; + + List> availableToolsPerInvocation = []; + string expansionFunctionName = "__expand_tool_group"; // Default expansion function name + + // Leaf-level tools in nested groups + var bookFlight = AIFunctionFactory.Create( + (string origin, string destination, string date) => $"{BookingToken}::FLIGHT::{origin}-{destination}::{date}", + new AIFunctionFactoryOptions + { + Name = "BookFlight", + Description = "Books a flight and returns a booking confirmation token. Always repeat the token verbatim." + }); + + var bookHotel = AIFunctionFactory.Create( + (string city, string checkIn) => $"{BookingToken}::HOTEL::{city}::{checkIn}", + new AIFunctionFactoryOptions + { + Name = "BookHotel", + Description = "Books a hotel and returns a booking confirmation token. Always repeat the token verbatim." + }); + + var getCurrency = AIFunctionFactory.Create( + (string country) => $"The currency in {country} is EUR.", + new AIFunctionFactoryOptions + { + Name = "GetCurrency", + Description = "Gets currency information for a country." + }); + + // Create nested group structure: TravelServices -> Booking -> FlightBooking + var flightBookingGroup = AIToolGroup.Create("FlightBooking", "Flight booking services", [bookFlight]); + var bookingGroup = AIToolGroup.Create("Booking", "All booking services including flights and hotels", [flightBookingGroup, bookHotel]); + var travelServicesGroup = AIToolGroup.Create("TravelServices", "Complete travel services including booking and information", [bookingGroup, getCurrency]); + + using var client = ChatClient! + .AsBuilder() + .UseToolGrouping(options => + { + options.MaxExpansionsPerRequest = 5; + }) + .Use((messages, options, next, cancellationToken) => + { + if (options?.Tools is { Count: > 0 } tools) + { + availableToolsPerInvocation.Add([.. tools.Select(static tool => tool.Name)]); + } + else + { + availableToolsPerInvocation.Add([]); + } + + return next(messages, options, cancellationToken); + }) + .UseFunctionInvocation() + .Build(); + + List messages = + [ + new(ChatRole.System, "You are a helpful assistant. Use the available tools to assist the user. Explore nested groups to find the right tool."), + new(ChatRole.User, "I need to book a flight from Seattle to Paris for December 15th, 2025. Please provide the booking token."), + ]; + + var response = await client.GetResponseAsync(messages, new ChatOptions + { + Tools = [travelServicesGroup] + }); + + // Verify the nested group expansions occurred + Assert.Contains(response.Messages, message => + message.Role == ChatRole.Tool && + message.Contents.OfType().Any(content => + content.Result?.ToString()?.IndexOf("Successfully expanded group 'TravelServices'", StringComparison.OrdinalIgnoreCase) >= 0)); + + Assert.Contains(response.Messages, message => + message.Role == ChatRole.Tool && + message.Contents.OfType().Any(content => + content.Result?.ToString()?.IndexOf("Successfully expanded group 'Booking'", StringComparison.OrdinalIgnoreCase) >= 0)); + + Assert.Contains(response.Messages, message => + message.Role == ChatRole.Tool && + message.Contents.OfType().Any(content => + content.Result?.ToString()?.IndexOf("Successfully expanded group 'FlightBooking'", StringComparison.OrdinalIgnoreCase) >= 0)); + + // Verify the actual tool was invoked + Assert.Contains(response.Messages, message => + message.Contents.OfType().Any(content => + string.Equals(content.Name, "BookFlight", StringComparison.Ordinal))); + + // Verify the booking token appears in the response + var responseText = response.Text ?? string.Empty; + Assert.Contains(BookingToken, responseText); + Assert.Contains("SEATTLE", responseText.ToUpperInvariant()); + Assert.Contains("PARIS", responseText.ToUpperInvariant()); + + // Verify progressive expansion: first only expansion function, then groups appear, finally leaf tools + Assert.NotEmpty(availableToolsPerInvocation); + var firstInvocationTools = availableToolsPerInvocation[0]; + Assert.Contains(expansionFunctionName, firstInvocationTools); + Assert.DoesNotContain("BookFlight", firstInvocationTools); + Assert.DoesNotContain("Booking", firstInvocationTools); + + // Verify BookFlight was not available until after all necessary expansions + var bookFlightInvocationIndex = availableToolsPerInvocation.FindIndex(tools => tools.Contains("BookFlight")); + Assert.True(bookFlightInvocationIndex >= 0, "BookFlight was never exposed to the model."); + Assert.True(bookFlightInvocationIndex > 2, "BookFlight was visible before completing the nested expansion hierarchy."); + } + + [ConditionalFact] + public virtual async Task ToolGrouping_MultipleNestedGroups_LlmSelectsCorrectPathAndInvokesTool() + { + SkipIfNotEnabled(); + + const string DiagnosticToken = "DIAGNOSTIC-REPORT-5483"; + + List> availableToolsPerInvocation = []; + string expansionFunctionName = "__expand_tool_group"; // Default expansion function name + + // Healthcare nested tools + var runBloodTest = AIFunctionFactory.Create( + (string patientId) => $"{DiagnosticToken}::BLOOD-TEST::{patientId}::COMPLETE", + new AIFunctionFactoryOptions + { + Name = "RunBloodTest", + Description = "Orders a blood test and returns a diagnostic token. Always repeat the token verbatim." + }); + + var scheduleXRay = AIFunctionFactory.Create( + (string patientId, string bodyPart) => $"X-Ray scheduled for {bodyPart}", + new AIFunctionFactoryOptions + { + Name = "ScheduleXRay", + Description = "Schedules an X-ray appointment." + }); + + // Financial nested tools + var processPayment = AIFunctionFactory.Create( + (string accountId, decimal amount) => $"Payment of ${amount} processed for account {accountId}", + new AIFunctionFactoryOptions + { + Name = "ProcessPayment", + Description = "Processes a payment transaction." + }); + + var generateInvoice = AIFunctionFactory.Create( + (string customerId) => $"Invoice generated for customer {customerId}", + new AIFunctionFactoryOptions + { + Name = "GenerateInvoice", + Description = "Generates an invoice document." + }); + + // Create two separate nested hierarchies + var diagnosticsGroup = AIToolGroup.Create("Diagnostics", "Medical diagnostic services", [runBloodTest]); + var imagingGroup = AIToolGroup.Create("Imaging", "Medical imaging services", [scheduleXRay]); + var healthcareGroup = AIToolGroup.Create("Healthcare", "All healthcare services including diagnostics and imaging", [diagnosticsGroup, imagingGroup]); + + var paymentsGroup = AIToolGroup.Create("Payments", "Payment processing services", [processPayment]); + var billingGroup = AIToolGroup.Create("Billing", "Billing and invoicing services", [generateInvoice]); + var financialGroup = AIToolGroup.Create("Financial", "All financial services including payments and billing", [paymentsGroup, billingGroup]); + + using var client = ChatClient! + .AsBuilder() + .UseToolGrouping(options => + { + options.MaxExpansionsPerRequest = 5; + }) + .Use((messages, options, next, cancellationToken) => + { + if (options?.Tools is { Count: > 0 } tools) + { + availableToolsPerInvocation.Add([.. tools.Select(static tool => tool.Name)]); + } + else + { + availableToolsPerInvocation.Add([]); + } + + return next(messages, options, cancellationToken); + }) + .UseFunctionInvocation() + .Build(); + + List messages = + [ + new(ChatRole.System, "You are a helpful assistant. Use the available tools to assist the user. Navigate through nested groups to find the appropriate tool."), + new(ChatRole.User, "Patient P-42 needs a blood test ordered. Please provide the diagnostic token."), + ]; + + var response = await client.GetResponseAsync(messages, new ChatOptions + { + Tools = [healthcareGroup, financialGroup] + }); + + // Verify the correct nested path was taken (Healthcare -> Diagnostics) + Assert.Contains(response.Messages, message => + message.Role == ChatRole.Tool && + message.Contents.OfType().Any(content => + content.Result?.ToString()?.IndexOf("Successfully expanded group 'Healthcare'", StringComparison.OrdinalIgnoreCase) >= 0)); + + Assert.Contains(response.Messages, message => + message.Role == ChatRole.Tool && + message.Contents.OfType().Any(content => + content.Result?.ToString()?.IndexOf("Successfully expanded group 'Diagnostics'", StringComparison.OrdinalIgnoreCase) >= 0)); + + // Verify the incorrect path was NOT taken (Financial group should not be expanded) + Assert.DoesNotContain(response.Messages, message => + message.Role == ChatRole.Tool && + message.Contents.OfType().Any(content => + content.Result?.ToString()?.IndexOf("Successfully expanded group 'Financial'", StringComparison.OrdinalIgnoreCase) >= 0)); + + // Verify the correct leaf tool was invoked + Assert.Contains(response.Messages, message => + message.Contents.OfType().Any(content => + string.Equals(content.Name, "RunBloodTest", StringComparison.Ordinal))); + + // Verify wrong tools were not invoked + Assert.DoesNotContain(response.Messages, message => + message.Contents.OfType().Any(content => + string.Equals(content.Name, "ProcessPayment", StringComparison.Ordinal) || + string.Equals(content.Name, "GenerateInvoice", StringComparison.Ordinal))); + + // Verify the diagnostic token appears in the response + var responseText = response.Text ?? string.Empty; + Assert.Contains(DiagnosticToken, responseText); + Assert.Contains("P-42", responseText); + + // Verify progressive expansion behavior + Assert.NotEmpty(availableToolsPerInvocation); + var firstInvocationTools = availableToolsPerInvocation[0]; + Assert.Contains(expansionFunctionName, firstInvocationTools); + Assert.DoesNotContain("RunBloodTest", firstInvocationTools); + + // Verify RunBloodTest only became available after the correct nested expansions + var runBloodTestInvocationIndex = availableToolsPerInvocation.FindIndex(tools => tools.Contains("RunBloodTest")); + Assert.True(runBloodTestInvocationIndex >= 0, "RunBloodTest was never exposed to the model."); + Assert.True(runBloodTestInvocationIndex > 1, "RunBloodTest was visible before completing necessary nested expansions."); + } + private sealed class TestSummarizingChatClient : IChatClient { private IChatClient _summarizerChatClient; @@ -1395,6 +1796,343 @@ public void Dispose() } } + [ConditionalFact] + public virtual async Task ToolReduction_DynamicSelection_RespectsConversationHistory() + { + SkipIfNotEnabled(); + EnsureEmbeddingGenerator(); + + // Limit to 2 so that, once the conversation references both weather and translation, + // both tools can be included even if the latest user turn only mentions one of them. + var strategy = new EmbeddingToolReductionStrategy(EmbeddingGenerator, toolLimit: 2); + + var weatherTool = AIFunctionFactory.Create( + () => "Weather data", + new AIFunctionFactoryOptions + { + Name = "GetWeatherForecast", + Description = "Returns weather forecast and temperature for a given city." + }); + + var translateTool = AIFunctionFactory.Create( + () => "Translated text", + new AIFunctionFactoryOptions + { + Name = "TranslateText", + Description = "Translates text between human languages." + }); + + var mathTool = AIFunctionFactory.Create( + () => 42, + new AIFunctionFactoryOptions + { + Name = "SolveMath", + Description = "Solves basic math problems." + }); + + var allTools = new List { weatherTool, translateTool, mathTool }; + + IList? firstTurnTools = null; + IList? secondTurnTools = null; + + using var client = ChatClient! + .AsBuilder() + .UseToolReduction(strategy) + .Use(async (messages, options, next, ct) => + { + // Capture the (possibly reduced) tool list for each turn. + if (firstTurnTools is null) + { + firstTurnTools = options?.Tools; + } + else + { + secondTurnTools ??= options?.Tools; + } + + await next(messages, options, ct); + }) + .UseFunctionInvocation() + .Build(); + + // Maintain chat history across turns. + List history = []; + + // Turn 1: Ask a weather question. + history.Add(new ChatMessage(ChatRole.User, "What will the weather be in Seattle tomorrow?")); + var firstResponse = await client.GetResponseAsync(history, new ChatOptions { Tools = allTools }); + history.AddMessages(firstResponse); // Append assistant reply. + + Assert.NotNull(firstTurnTools); + Assert.Contains(firstTurnTools, t => t.Name == "GetWeatherForecast"); + + // Turn 2: Ask a translation question. Even though only translation is mentioned now, + // conversation history still contains a weather request. Expect BOTH weather + translation tools. + history.Add(new ChatMessage(ChatRole.User, "Please translate 'good evening' into French.")); + var secondResponse = await client.GetResponseAsync(history, new ChatOptions { Tools = allTools }); + history.AddMessages(secondResponse); + + Assert.NotNull(secondTurnTools); + Assert.Equal(2, secondTurnTools.Count); // Should have filled both slots with the two relevant domains. + Assert.Contains(secondTurnTools, t => t.Name == "GetWeatherForecast"); + Assert.Contains(secondTurnTools, t => t.Name == "TranslateText"); + + // Ensure unrelated tool was excluded. + Assert.DoesNotContain(secondTurnTools, t => t.Name == "SolveMath"); + } + + [ConditionalFact] + public virtual async Task ToolReduction_RequireSpecificToolPreservedAndOrdered() + { + SkipIfNotEnabled(); + EnsureEmbeddingGenerator(); + + // Limit would normally reduce to 1, but required tool plus another should remain. + var strategy = new EmbeddingToolReductionStrategy(EmbeddingGenerator, toolLimit: 1); + + var translateTool = AIFunctionFactory.Create( + () => "Translated text", + new AIFunctionFactoryOptions + { + Name = "TranslateText", + Description = "Translates phrases between languages." + }); + + var weatherTool = AIFunctionFactory.Create( + () => "Weather data", + new AIFunctionFactoryOptions + { + Name = "GetWeatherForecast", + Description = "Returns forecast data for a city." + }); + + var tools = new List { translateTool, weatherTool }; + + IList? captured = null; + + using var client = ChatClient! + .AsBuilder() + .UseToolReduction(strategy) + .UseFunctionInvocation() + .Use((messages, options, next, ct) => + { + captured = options?.Tools; + return next(messages, options, ct); + }) + .Build(); + + var history = new List + { + new(ChatRole.User, "What will the weather be like in Redmond next week?") + }; + + var response = await client.GetResponseAsync(history, new ChatOptions + { + Tools = tools, + ToolMode = ChatToolMode.RequireSpecific(translateTool.Name) + }); + history.AddMessages(response); + + Assert.NotNull(captured); + Assert.Equal(2, captured!.Count); + Assert.Equal("TranslateText", captured[0].Name); // Required should appear first. + Assert.Equal("GetWeatherForecast", captured[1].Name); + } + + [ConditionalFact] + public virtual async Task ToolReduction_ToolRemovedAfterFirstUse_NotInvokedAgain() + { + SkipIfNotEnabled(); + EnsureEmbeddingGenerator(); + + int weatherInvocationCount = 0; + + var weatherTool = AIFunctionFactory.Create( + () => + { + weatherInvocationCount++; + return "Sunny and dry."; + }, + new AIFunctionFactoryOptions + { + Name = "GetWeather", + Description = "Gets the weather forecast for a given location." + }); + + // Strategy exposes tools only on the first request, then removes them. + var removalStrategy = new RemoveToolAfterFirstUseStrategy(); + + IList? firstTurnTools = null; + IList? secondTurnTools = null; + + using var client = ChatClient! + .AsBuilder() + // Place capture immediately after reduction so it's invoked exactly once per user request. + .UseToolReduction(removalStrategy) + .Use((messages, options, next, ct) => + { + if (firstTurnTools is null) + { + firstTurnTools = options?.Tools; + } + else + { + secondTurnTools ??= options?.Tools; + } + + return next(messages, options, ct); + }) + .UseFunctionInvocation() + .Build(); + + List history = []; + + // Turn 1 + history.Add(new ChatMessage(ChatRole.User, "What's the weather like tomorrow in Seattle?")); + var firstResponse = await client.GetResponseAsync(history, new ChatOptions + { + Tools = [weatherTool], + ToolMode = ChatToolMode.RequireAny + }); + history.AddMessages(firstResponse); + + Assert.Equal(1, weatherInvocationCount); + Assert.NotNull(firstTurnTools); + Assert.Contains(firstTurnTools!, t => t.Name == "GetWeather"); + + // Turn 2 (tool removed by strategy even though caller supplies it again) + history.Add(new ChatMessage(ChatRole.User, "And what about next week?")); + var secondResponse = await client.GetResponseAsync(history, new ChatOptions + { + Tools = [weatherTool] + }); + history.AddMessages(secondResponse); + + Assert.Equal(1, weatherInvocationCount); // Not invoked again. + Assert.NotNull(secondTurnTools); + Assert.Empty(secondTurnTools!); // Strategy removed the tool set. + + // Response text shouldn't just echo the tool's stub output. + Assert.DoesNotContain("Sunny and dry.", secondResponse.Text, StringComparison.OrdinalIgnoreCase); + } + + [ConditionalFact] + public virtual async Task ToolReduction_MessagesEmbeddingTextSelector_UsesChatClientToAnalyzeConversation() + { + SkipIfNotEnabled(); + EnsureEmbeddingGenerator(); + + // Create tools for different domains. + var weatherTool = AIFunctionFactory.Create( + () => "Weather data", + new AIFunctionFactoryOptions + { + Name = "GetWeatherForecast", + Description = "Returns weather forecast and temperature for a given city." + }); + + var translateTool = AIFunctionFactory.Create( + () => "Translated text", + new AIFunctionFactoryOptions + { + Name = "TranslateText", + Description = "Translates text between human languages." + }); + + var mathTool = AIFunctionFactory.Create( + () => 42, + new AIFunctionFactoryOptions + { + Name = "SolveMath", + Description = "Solves basic math problems." + }); + + var allTools = new List { weatherTool, translateTool, mathTool }; + + // Track the analysis result from the chat client used in the selector. + string? capturedAnalysis = null; + + var strategy = new EmbeddingToolReductionStrategy(EmbeddingGenerator, toolLimit: 2) + { + // Use a chat client to analyze the conversation and extract relevant tool categories. + MessagesEmbeddingTextSelector = async messages => + { + var conversationText = string.Join("\n", messages.Select(m => $"{m.Role}: {m.Text}")); + + var analysisPrompt = $""" + Analyze the following conversation and identify what kinds of tools would be most helpful. + Focus on the key topics and tasks being discussed. + Respond with a brief summary of the relevant tool categories (e.g., "weather", "translation", "math"). + + Conversation: + {conversationText} + + Relevant tool categories: + """; + + var response = await ChatClient!.GetResponseAsync(analysisPrompt); + capturedAnalysis = response.Text; + + // Return the analysis as the query text for embedding-based tool selection. + return capturedAnalysis; + } + }; + + IList? selectedTools = null; + + using var client = ChatClient! + .AsBuilder() + .UseToolReduction(strategy) + .Use(async (messages, options, next, ct) => + { + selectedTools = options?.Tools; + await next(messages, options, ct); + }) + .UseFunctionInvocation() + .Build(); + + // Conversation that clearly indicates weather-related needs. + List history = []; + history.Add(new ChatMessage(ChatRole.User, "What will the weather be like in London tomorrow?")); + + var response = await client.GetResponseAsync(history, new ChatOptions { Tools = allTools }); + history.AddMessages(response); + + // Verify that the chat client was used to analyze the conversation. + Assert.NotNull(capturedAnalysis); + Assert.True( + capturedAnalysis.IndexOf("weather", StringComparison.OrdinalIgnoreCase) >= 0 || + capturedAnalysis.IndexOf("forecast", StringComparison.OrdinalIgnoreCase) >= 0, + $"Expected analysis to mention weather or forecast: {capturedAnalysis}"); + + // Verify that the tool selection was influenced by the analysis. + Assert.NotNull(selectedTools); + Assert.True(selectedTools.Count <= 2, $"Expected at most 2 tools, got {selectedTools.Count}"); + Assert.Contains(selectedTools, t => t.Name == "GetWeatherForecast"); + } + + // Test-only custom strategy: include tools on first request, then remove them afterward. + private sealed class RemoveToolAfterFirstUseStrategy : IToolReductionStrategy + { + private bool _used; + + public Task> SelectToolsForRequestAsync( + IEnumerable messages, + ChatOptions? options, + CancellationToken cancellationToken = default) + { + if (!_used && options?.Tools is { Count: > 0 }) + { + _used = true; + // Returning the same instance signals no change. + return Task.FromResult>(options.Tools); + } + + // After first use, remove all tools. + return Task.FromResult>(Array.Empty()); + } + } + [MemberNotNull(nameof(ChatClient))] protected void SkipIfNotEnabled() { @@ -1405,4 +2143,15 @@ protected void SkipIfNotEnabled() throw new SkipTestException("Client is not enabled."); } } + + [MemberNotNull(nameof(EmbeddingGenerator))] + protected void EnsureEmbeddingGenerator() + { + EmbeddingGenerator ??= CreateEmbeddingGenerator(); + + if (EmbeddingGenerator is null) + { + throw new SkipTestException("Embedding generator is not enabled."); + } + } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolGroupingTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolGroupingTests.cs new file mode 100644 index 00000000000..7352814e91c --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolGroupingTests.cs @@ -0,0 +1,1301 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ToolGroupingTests +{ + private const string DefaultExpansionFunctionName = "__expand_tool_group"; + private const string DefaultListGroupsFunctionName = "__list_tool_groups"; + + [Fact] + public async Task ToolGroupingChatClient_Collapsed_IncludesUtilityAndUngroupedToolsOnly() + { + var ungrouped = new SimpleTool("Basic", "basic"); + var groupedA = new SimpleTool("A1", "a1"); + var groupedB = new SimpleTool("B1", "b1"); + + ToolGroupingTestScenario CreateScenario(List?> observedTools) => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "hello")], + Options = new ChatOptions { Tools = [ungrouped, AIToolGroup.Create("GroupA", "Group A", [groupedA]), AIToolGroup.Create("GroupB", "Group B", [groupedB])] }, + ConfigureToolGroupingOptions = options => { }, + Turns = + [ + new DownstreamTurn + { + ResponseMessage = new(ChatRole.Assistant, "Hi"), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + } + ] + }; + + List?> observedNonStreaming = []; + List?> observedStreaming = []; + + var result = await InvokeAndAssertAsync(CreateScenario(observedNonStreaming)); + var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario(observedStreaming)); + + void AssertResponse(ToolGroupingTestResult testResult) => Assert.Equal("Hi", testResult.Response.Text); + + AssertResponse(result); + AssertResponse(streamingResult); + + void AssertObservedTools(List?> observedTools) + { + var tools = Assert.Single(observedTools); + Assert.NotNull(tools); + Assert.Contains(tools!, t => t.Name == ungrouped.Name); + Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName); + Assert.Contains(tools, t => t.Name == DefaultListGroupsFunctionName); + Assert.DoesNotContain(tools, t => t.Name == groupedA.Name); + Assert.DoesNotContain(tools, t => t.Name == groupedB.Name); + } + + AssertObservedTools(observedNonStreaming); + AssertObservedTools(observedStreaming); + } + + [Fact] + public async Task ToolGroupingChatClient_ExpansionLoop_ExpandsSingleGroup() + { + var groupedA1 = new SimpleTool("A1", "a1"); + var groupedA2 = new SimpleTool("A2", "a2"); + var groupedB = new SimpleTool("B1", "b1"); + var ungrouped = new SimpleTool("Common", "c"); + + ToolGroupingTestScenario CreateScenario(List?> observedTools) => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "go")], + Options = new ChatOptions { Tools = [ungrouped, AIToolGroup.Create("GroupA", "Group A", [groupedA1, groupedA2]), AIToolGroup.Create("GroupB", "Group B", [groupedB])] }, + ConfigureToolGroupingOptions = options => + { + options.MaxExpansionsPerRequest = 1; + }, + Turns = + [ + new DownstreamTurn + { + ResponseMessage = CreateExpansionCall("call1", "GroupA"), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + }, + new DownstreamTurn + { + ResponseMessage = new(ChatRole.Assistant, "Done"), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + } + ] + }; + + List?> observedNonStreaming = []; + List?> observedStreaming = []; + + var result = await InvokeAndAssertAsync(CreateScenario(observedNonStreaming)); + var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario(observedStreaming)); + + void AssertResponse(ToolGroupingTestResult testResult) => Assert.Equal("Done", testResult.Response.Text); + + AssertResponse(result); + AssertResponse(streamingResult); + + void AssertObservedTools(List?> observed) => Assert.Collection(observed, + tools => + { + Assert.NotNull(tools); + Assert.Contains(tools!, t => t.Name == ungrouped.Name); + Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName); + Assert.DoesNotContain(tools, t => t.Name == groupedA1.Name); + Assert.DoesNotContain(tools, t => t.Name == groupedB.Name); + }, + tools => + { + Assert.NotNull(tools); + Assert.Contains(tools!, t => t.Name == ungrouped.Name); + Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName); + Assert.Contains(tools, t => t.Name == groupedA1.Name); + Assert.Contains(tools, t => t.Name == groupedA2.Name); + Assert.DoesNotContain(tools, t => t.Name == groupedB.Name); + }); + + AssertObservedTools(observedNonStreaming); + AssertObservedTools(observedStreaming); + + AssertContainsResultMessage(result.Response, "Successfully expanded group 'GroupA'"); + AssertContainsResultMessage(streamingResult.Response, "Successfully expanded group 'GroupA'"); + } + + [Fact] + public async Task ToolGroupingChatClient_NoGroups_BypassesMiddleware() + { + var tool = new SimpleTool("Standalone", "s"); + + ToolGroupingTestScenario CreateScenario(ChatOptions options, List observedOptions, List?> observedTools) => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "hello")], + Options = options, + ConfigureToolGroupingOptions = _ => { }, + Turns = + [ + new DownstreamTurn + { + ResponseMessage = new(ChatRole.Assistant, "ok"), + AssertInvocation = ctx => + { + observedOptions.Add(ctx.Options); + observedTools.Add(ctx.Options?.Tools?.ToList()); + } + } + ] + }; + + List observedOptionsNonStreaming = []; + List?> observedToolsNonStreaming = []; + ChatOptions nonStreamingOptions = new() { Tools = [tool] }; + var result = await InvokeAndAssertAsync(CreateScenario(nonStreamingOptions, observedOptionsNonStreaming, observedToolsNonStreaming)); + + List observedOptionsStreaming = []; + List?> observedToolsStreaming = []; + ChatOptions streamingOptions = new() { Tools = [tool] }; + var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario(streamingOptions, observedOptionsStreaming, observedToolsStreaming)); + + void AssertResponse(ToolGroupingTestResult testResult) => Assert.Equal("ok", testResult.Response.Text); + + AssertResponse(result); + AssertResponse(streamingResult); + + static void AssertObservedOptions(ChatOptions expected, List observed) => + Assert.Same(expected, Assert.Single(observed)); + + static void AssertObservedTools(List?> observed) + { + var tools = Assert.Single(observed); + Assert.NotNull(tools); + Assert.DoesNotContain(tools!, t => t.Name == DefaultExpansionFunctionName); + Assert.DoesNotContain(tools!, t => t.Name == DefaultListGroupsFunctionName); + } + + AssertObservedOptions(nonStreamingOptions, observedOptionsNonStreaming); + AssertObservedOptions(streamingOptions, observedOptionsStreaming); + + AssertObservedTools(observedToolsNonStreaming); + AssertObservedTools(observedToolsStreaming); + } + + [Fact] + public async Task ToolGroupingChatClient_InvalidGroupRequest_ReturnsResultMessage() + { + var groupedA = new SimpleTool("A1", "a1"); + + ToolGroupingTestScenario CreateScenario() => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "go")], + Options = new ChatOptions { Tools = [AIToolGroup.Create("GroupA", "Group A", [groupedA])] }, + ConfigureToolGroupingOptions = options => + { + options.MaxExpansionsPerRequest = 2; + }, + Turns = + [ + new DownstreamTurn + { + ResponseMessage = new ChatMessage(ChatRole.Assistant, + [new FunctionCallContent(Guid.NewGuid().ToString("N"), DefaultExpansionFunctionName, new Dictionary { ["groupName"] = "Unknown" })]) + }, + new DownstreamTurn + { + ResponseMessage = new(ChatRole.Assistant, "Oops!"), + } + ] + }; + + var result = await InvokeAndAssertAsync(CreateScenario()); + var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario()); + + void AssertResponse(ToolGroupingTestResult testResult) => + AssertContainsResultMessage(testResult.Response, "was invalid; ignoring expansion request"); + + AssertResponse(result); + AssertResponse(streamingResult); + } + + [Fact] + public async Task ToolGroupingChatClient_MissingGroupName_ReturnsNotice() + { + var groupedA = new SimpleTool("A1", "a1"); + + ToolGroupingTestScenario CreateScenario() => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "go")], + Options = new ChatOptions { Tools = [AIToolGroup.Create("GroupA", "Group A", [groupedA])] }, + ConfigureToolGroupingOptions = options => + { + options.MaxExpansionsPerRequest = 2; + }, + Turns = + [ + new DownstreamTurn + { + ResponseMessage = new ChatMessage(ChatRole.Assistant, + [new FunctionCallContent(Guid.NewGuid().ToString("N"), DefaultExpansionFunctionName)]) + }, + new DownstreamTurn + { + ResponseMessage = new(ChatRole.Assistant, "Oops!"), + } + ] + }; + + var result = await InvokeAndAssertAsync(CreateScenario()); + var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario()); + + void AssertResponse(ToolGroupingTestResult testResult) => + AssertContainsResultMessage(testResult.Response, "No group name was specified"); + + AssertResponse(result); + AssertResponse(streamingResult); + } + + [Fact] + public async Task ToolGroupingChatClient_GroupNameReadsJsonElement() + { + var groupedA = new SimpleTool("A1", "a1"); + var jsonValue = JsonDocument.Parse("\"GroupA\"").RootElement; + + ToolGroupingTestScenario CreateScenario() => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "go")], + Options = new ChatOptions { Tools = [AIToolGroup.Create("GroupA", "Group A", [groupedA])] }, + ConfigureToolGroupingOptions = _ => { }, + Turns = + [ + new DownstreamTurn + { + ResponseMessage = new ChatMessage(ChatRole.Assistant, + [new FunctionCallContent(Guid.NewGuid().ToString("N"), DefaultExpansionFunctionName, new Dictionary { ["groupName"] = jsonValue })]) + }, + new DownstreamTurn + { + ResponseMessage = new(ChatRole.Assistant, "done") + } + ] + }; + + var result = await InvokeAndAssertAsync(CreateScenario()); + var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario()); + + void AssertResponse(ToolGroupingTestResult testResult) => + AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'GroupA'"); + + AssertResponse(result); + AssertResponse(streamingResult); + } + + [Fact] + public async Task ToolGroupingChatClient_MultipleValidExpansions_LastWins() + { + var groupedA = new SimpleTool("A1", "a1"); + var groupedB = new SimpleTool("B1", "b1"); + var alwaysOn = new SimpleTool("Common", "c"); + + ToolGroupingTestScenario CreateScenario(List?> observedTools) => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "go")], + Options = new ChatOptions { Tools = [alwaysOn, AIToolGroup.Create("GroupA", "Group A", [groupedA]), AIToolGroup.Create("GroupB", "Group B", [groupedB])] }, + ConfigureToolGroupingOptions = options => + { + options.MaxExpansionsPerRequest = 2; + }, + Turns = + [ + new DownstreamTurn + { + ResponseMessage = new ChatMessage(ChatRole.Assistant, + [ + new FunctionCallContent("call1", DefaultExpansionFunctionName, new Dictionary { ["groupName"] = "GroupA" }), + new FunctionCallContent("call2", DefaultExpansionFunctionName, new Dictionary { ["groupName"] = "GroupB" }) + ]), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + }, + new DownstreamTurn + { + ResponseMessage = new(ChatRole.Assistant, "done"), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + } + ] + }; + + List?> observedToolsNonStreaming = []; + var result = await InvokeAndAssertAsync(CreateScenario(observedToolsNonStreaming)); + + List?> observedToolsStreaming = []; + var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario(observedToolsStreaming)); + + void AssertObservedTools(List?> observed) => Assert.Collection(observed, + tools => + { + Assert.NotNull(tools); + Assert.Contains(tools!, t => t.Name == alwaysOn.Name); + Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName); + }, + tools => + { + Assert.NotNull(tools); + Assert.Contains(tools!, t => t.Name == alwaysOn.Name); + Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName); + Assert.DoesNotContain(tools, t => t.Name == groupedA.Name); + Assert.Contains(tools, t => t.Name == groupedB.Name); + }); + + AssertObservedTools(observedToolsNonStreaming); + AssertObservedTools(observedToolsStreaming); + + void AssertResponse(ToolGroupingTestResult testResult) + { + AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'GroupA'"); + AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'GroupB'"); + } + + AssertResponse(result); + AssertResponse(streamingResult); + } + + [Fact] + public async Task ToolGroupingChatClient_DuplicateExpansionSameIteration_Reported() + { + var groupedA = new SimpleTool("A1", "a1"); + + ToolGroupingTestScenario CreateScenario() => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "go")], + Options = new ChatOptions { Tools = [AIToolGroup.Create("GroupA", "Group A", [groupedA])] }, + ConfigureToolGroupingOptions = options => + { + options.MaxExpansionsPerRequest = 2; + }, + Turns = + [ + new DownstreamTurn + { + ResponseMessage = new ChatMessage(ChatRole.Assistant, + [ + new FunctionCallContent("call1", DefaultExpansionFunctionName, new Dictionary { ["groupName"] = "GroupA" }), + new FunctionCallContent("call2", DefaultExpansionFunctionName, new Dictionary { ["groupName"] = "GroupA" }) + ]) + }, + new DownstreamTurn { ResponseMessage = new(ChatRole.Assistant, "done") } + ] + }; + + var result = await InvokeAndAssertAsync(CreateScenario()); + var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario()); + + void AssertResponse(ToolGroupingTestResult testResult) => + AssertContainsResultMessage(testResult.Response, "Ignoring duplicate expansion"); + + AssertResponse(result); + AssertResponse(streamingResult); + } + + [Fact] + public async Task ToolGroupingChatClient_ReexpandingSameGroupDoesNotTerminateLoop() + { + var groupedA = new SimpleTool("A1", "a1"); + + ToolGroupingTestScenario CreateScenario() => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "go")], + Options = new ChatOptions { Tools = [AIToolGroup.Create("GroupA", "Group A", [groupedA])] }, + ConfigureToolGroupingOptions = _ => { }, + Turns = + [ + new DownstreamTurn { ResponseMessage = CreateExpansionCall("call1", "GroupA") }, + new DownstreamTurn { ResponseMessage = CreateExpansionCall("call2", "GroupA") }, + new DownstreamTurn { ResponseMessage = new ChatMessage(ChatRole.Assistant, "Oops!") } + ] + }; + + var result = await InvokeAndAssertAsync(CreateScenario()); + var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario()); + + void AssertResponse(ToolGroupingTestResult testResult) => + AssertContainsResultMessage(testResult.Response, "Ignoring duplicate expansion of group 'GroupA'."); + + AssertResponse(result); + AssertResponse(streamingResult); + } + + [Fact] + public async Task ToolGroupingChatClient_PropagatesConversationIdBetweenIterations() + { + var groupedA = new SimpleTool("A1", "a1"); + + ToolGroupingTestScenario CreateScenario(List observedConversationIds) => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "go")], + Options = new ChatOptions { Tools = [AIToolGroup.Create("GroupA", "Group A", [groupedA])] }, + ConfigureToolGroupingOptions = _ => { }, + Turns = + [ + new DownstreamTurn + { + ResponseMessage = CreateExpansionCall("call1", "GroupA"), + ConversationId = "conv-1" + }, + new DownstreamTurn + { + ResponseMessage = new(ChatRole.Assistant, "done"), + AssertInvocation = ctx => observedConversationIds.Add(ctx.Options?.ConversationId) + } + ] + }; + + List observedNonStreaming = []; + var result = await InvokeAndAssertAsync(CreateScenario(observedNonStreaming)); + + List observedStreaming = []; + var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario(observedStreaming)); + + static void AssertConversationIds(List observed) => Assert.Equal("conv-1", Assert.Single(observed)); + + AssertConversationIds(observedNonStreaming); + AssertConversationIds(observedStreaming); + + void AssertResponse(ToolGroupingTestResult testResult) => Assert.Equal("done", testResult.Response.Text); + + AssertResponse(result); + AssertResponse(streamingResult); + } + + [Fact] + public async Task ToolGroupingChatClient_NestedGroups_ExpandsParentThenChild() + { + var nestedTool1 = new SimpleTool("NestedTool1", "nested tool 1"); + var nestedTool2 = new SimpleTool("NestedTool2", "nested tool 2"); + var parentTool = new SimpleTool("ParentTool", "parent tool"); + var ungrouped = new SimpleTool("Ungrouped", "ungrouped"); + + var nestedGroup = AIToolGroup.Create("NestedGroup", "Nested group", [nestedTool1, nestedTool2]); + var parentGroup = AIToolGroup.Create("ParentGroup", "Parent group", [parentTool, nestedGroup]); + + ToolGroupingTestScenario CreateScenario(List?> observedTools) => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "expand parent then nested")], + Options = new ChatOptions { Tools = [ungrouped, parentGroup] }, + ConfigureToolGroupingOptions = options => options.MaxExpansionsPerRequest = 2, + Turns = + [ + new DownstreamTurn + { + ResponseMessage = CreateExpansionCall("call1", "ParentGroup"), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + }, + new DownstreamTurn + { + ResponseMessage = CreateExpansionCall("call2", "NestedGroup"), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + }, + new DownstreamTurn + { + ResponseMessage = new(ChatRole.Assistant, "done"), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + } + ] + }; + + List?> observedNonStreaming = []; + List?> observedStreaming = []; + + var result = await InvokeAndAssertAsync(CreateScenario(observedNonStreaming)); + var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario(observedStreaming)); + + void AssertObservedTools(List?> observed) => Assert.Collection(observed, + tools => + { + // First iteration: collapsed state + Assert.NotNull(tools); + Assert.Contains(tools!, t => t.Name == ungrouped.Name); + Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName); + Assert.DoesNotContain(tools, t => t.Name == parentTool.Name); + Assert.DoesNotContain(tools, t => t.Name == nestedTool1.Name); + }, + tools => + { + // Second iteration: parent expanded + Assert.NotNull(tools); + Assert.Contains(tools!, t => t.Name == ungrouped.Name); + Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName); + Assert.Contains(tools, t => t.Name == parentTool.Name); + Assert.DoesNotContain(tools, t => t.Name == nestedTool1.Name); + + // NestedGroup should NOT be in tools list (only actual tools) + Assert.DoesNotContain(tools, t => t.Name == "NestedGroup"); + }, + tools => + { + // Third iteration: nested group expanded + Assert.NotNull(tools); + Assert.Contains(tools!, t => t.Name == ungrouped.Name); + Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName); + Assert.Contains(tools, t => t.Name == nestedTool1.Name); + Assert.Contains(tools, t => t.Name == nestedTool2.Name); + Assert.DoesNotContain(tools, t => t.Name == parentTool.Name); + }); + + AssertObservedTools(observedNonStreaming); + AssertObservedTools(observedStreaming); + + void AssertResponse(ToolGroupingTestResult testResult) + { + AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'ParentGroup'"); + AssertContainsResultMessage(testResult.Response, "Additional groups available for expansion"); + AssertContainsResultMessage(testResult.Response, "- NestedGroup:"); + AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'NestedGroup'"); + } + + AssertResponse(result); + AssertResponse(streamingResult); + } + + [Fact] + public async Task ToolGroupingChatClient_NestedGroups_CannotExpandNestedFromDifferentParent() + { + var toolA = new SimpleTool("ToolA", "tool a"); + var nestedATool = new SimpleTool("NestedATool", "nested a tool"); + var nestedA = AIToolGroup.Create("NestedA", "Nested A", [nestedATool]); + var groupA = AIToolGroup.Create("GroupA", "Group A", [toolA, nestedA]); + + var toolB = new SimpleTool("ToolB", "tool b"); + var nestedBTool = new SimpleTool("NestedBTool", "nested b tool"); + var nestedB = AIToolGroup.Create("NestedB", "Nested B", [nestedBTool]); + var groupB = AIToolGroup.Create("GroupB", "Group B", [toolB, nestedB]); + + ToolGroupingTestScenario CreateScenario() => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "try to expand wrong nested group")], + Options = new ChatOptions { Tools = [groupA, groupB] }, + ConfigureToolGroupingOptions = options => options.MaxExpansionsPerRequest = 2, + Turns = + [ + new DownstreamTurn { ResponseMessage = CreateExpansionCall("call1", "GroupA") }, + new DownstreamTurn + { + // Try to expand NestedB (belongs to unexpanded GroupB) - should fail + ResponseMessage = CreateExpansionCall("call2", "NestedB") + }, + new DownstreamTurn + { + ResponseMessage = new(ChatRole.Assistant, "Oops!"), + } + ] + }; + + var result = await InvokeAndAssertAsync(CreateScenario()); + var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario()); + + void AssertResponse(ToolGroupingTestResult testResult) + { + AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'GroupA'"); + AssertContainsResultMessage(testResult.Response, "group name 'NestedB' was invalid"); + } + + AssertResponse(result); + AssertResponse(streamingResult); + } + + [Fact] + public async Task ToolGroupingChatClient_NestedGroups_MultiLevelNesting() + { + var deeplyNestedTool = new SimpleTool("DeeplyNestedTool", "deeply nested tool"); + var deeplyNested = AIToolGroup.Create("DeeplyNested", "Deeply nested group", [deeplyNestedTool]); + + var nestedTool = new SimpleTool("NestedTool", "nested tool"); + var nested = AIToolGroup.Create("Nested", "Nested group", [nestedTool, deeplyNested]); + + var topTool = new SimpleTool("TopTool", "top tool"); + var topGroup = AIToolGroup.Create("TopGroup", "Top group", [topTool, nested]); + + ToolGroupingTestScenario CreateScenario(List?> observedTools) => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "three level nesting")], + Options = new ChatOptions { Tools = [topGroup] }, + ConfigureToolGroupingOptions = options => options.MaxExpansionsPerRequest = 3, + Turns = + [ + new DownstreamTurn + { + ResponseMessage = CreateExpansionCall("call1", "TopGroup"), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + }, + new DownstreamTurn + { + ResponseMessage = CreateExpansionCall("call2", "Nested"), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + }, + new DownstreamTurn + { + ResponseMessage = CreateExpansionCall("call3", "DeeplyNested"), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + }, + new DownstreamTurn + { + ResponseMessage = new(ChatRole.Assistant, "done"), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + } + ] + }; + + List?> observedNonStreaming = []; + List?> observedStreaming = []; + + var result = await InvokeAndAssertAsync(CreateScenario(observedNonStreaming)); + var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario(observedStreaming)); + + void AssertObservedTools(List?> observed) => Assert.Collection(observed, + tools => + { + // Collapsed: only expansion function + Assert.NotNull(tools); + Assert.Contains(tools!, t => t.Name == DefaultExpansionFunctionName); + Assert.DoesNotContain(tools, t => t.Name == topTool.Name); + }, + tools => + { + // TopGroup expanded: topTool + Nested available + Assert.NotNull(tools); + Assert.Contains(tools!, t => t.Name == topTool.Name); + Assert.DoesNotContain(tools, t => t.Name == nestedTool.Name); + }, + tools => + { + // Nested expanded: nestedTool + DeeplyNested available + Assert.NotNull(tools); + Assert.Contains(tools!, t => t.Name == nestedTool.Name); + Assert.DoesNotContain(tools, t => t.Name == topTool.Name); + Assert.DoesNotContain(tools, t => t.Name == deeplyNestedTool.Name); + }, + tools => + { + // DeeplyNested expanded: only deeplyNestedTool + Assert.NotNull(tools); + Assert.Contains(tools!, t => t.Name == deeplyNestedTool.Name); + Assert.DoesNotContain(tools, t => t.Name == nestedTool.Name); + }); + + AssertObservedTools(observedNonStreaming); + AssertObservedTools(observedStreaming); + + void AssertResponse(ToolGroupingTestResult testResult) + { + AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'TopGroup'"); + AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'Nested'"); + AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'DeeplyNested'"); + } + + AssertResponse(result); + AssertResponse(streamingResult); + } + + [Fact] + public async Task ToolGroupingChatClient_ToolNameCollision_WithExpansionFunction() + { + var collisionTool = new SimpleTool(DefaultExpansionFunctionName, "collision"); + var group = AIToolGroup.Create("Group", "group", [new SimpleTool("Tool", "tool")]); + + using TestChatClient inner = new(); + using IChatClient client = inner.AsBuilder().UseToolGrouping(_ => { }).Build(); + + var options = new ChatOptions { Tools = [collisionTool, group] }; + + var exception = await Assert.ThrowsAsync(async () => + { + try + { + await client.GetResponseAsync([new ChatMessage(ChatRole.User, "test")], options); + } + catch (NotSupportedException) + { + // Inner client throws NotSupportedException, but we should hit InvalidOperationException first + throw; + } + }); + + Assert.Contains(DefaultExpansionFunctionName, exception.Message); + Assert.Contains("collides", exception.Message); + } + + [Fact] + public async Task ToolGroupingChatClient_ToolNameCollision_WithListGroupsFunction() + { + var collisionTool = new SimpleTool(DefaultListGroupsFunctionName, "collision"); + var group = AIToolGroup.Create("Group", "group", [new SimpleTool("Tool", "tool")]); + + using TestChatClient inner = new(); + using IChatClient client = inner.AsBuilder().UseToolGrouping(_ => { }).Build(); + + var options = new ChatOptions { Tools = [collisionTool, group] }; + + var exception = await Assert.ThrowsAsync(async () => + { + try + { + await client.GetResponseAsync([new ChatMessage(ChatRole.User, "test")], options); + } + catch (NotSupportedException) + { + throw; + } + }); + + Assert.Contains(DefaultListGroupsFunctionName, exception.Message); + Assert.Contains("collides", exception.Message); + } + + [Fact] + public async Task ToolGroupingChatClient_DynamicToolGroup_GetToolsAsyncCalled() + { + var tool = new SimpleTool("DynamicTool", "dynamic tool"); + bool getToolsAsyncCalled = false; + + var dynamicGroup = new DynamicToolGroup( + "DynamicGroup", + "Dynamic group", + async ct => + { + getToolsAsyncCalled = true; + await Task.Yield(); + return [tool]; + }); + + ToolGroupingTestScenario CreateScenario() => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "expand dynamic group")], + Options = new ChatOptions { Tools = [dynamicGroup] }, + ConfigureToolGroupingOptions = _ => { }, + Turns = + [ + new DownstreamTurn { ResponseMessage = CreateExpansionCall("call1", "DynamicGroup") }, + new DownstreamTurn { ResponseMessage = new(ChatRole.Assistant, "done") } + ] + }; + + var result = await InvokeAndAssertAsync(CreateScenario()); + + Assert.True(getToolsAsyncCalled, "GetToolsAsync should have been called"); + AssertContainsResultMessage(result.Response, "Successfully expanded group 'DynamicGroup'"); + + // Reset for streaming test + getToolsAsyncCalled = false; + var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario()); + + Assert.True(getToolsAsyncCalled, "GetToolsAsync should have been called in streaming"); + AssertContainsResultMessage(streamingResult.Response, "Successfully expanded group 'DynamicGroup'"); + } + + [Fact] + public async Task ToolGroupingChatClient_DynamicToolGroup_ThrowsException() + { + var dynamicGroup = new DynamicToolGroup( + "FailingGroup", + "Failing group", + ct => throw new InvalidOperationException("Simulated failure")); + + ToolGroupingTestScenario CreateScenario() => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "expand failing group")], + Options = new ChatOptions { Tools = [dynamicGroup] }, + ConfigureToolGroupingOptions = _ => { }, + Turns = + [ + new DownstreamTurn { ResponseMessage = CreateExpansionCall("call1", "FailingGroup") } + ] + }; + + await Assert.ThrowsAsync(async () => + await InvokeAndAssertAsync(CreateScenario())); + + await Assert.ThrowsAsync(async () => + await InvokeAndAssertStreamingAsync(CreateScenario())); + } + + [Fact] + public async Task ToolGroupingChatClient_EmptyGroupExpansion_ReturnsNoTools() + { + var emptyGroup = AIToolGroup.Create("EmptyGroup", "Empty group", []); + var ungroupedTool = new SimpleTool("Ungrouped", "ungrouped"); + + ToolGroupingTestScenario CreateScenario(List?> observedTools) => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "expand empty group")], + Options = new ChatOptions { Tools = [ungroupedTool, emptyGroup] }, + ConfigureToolGroupingOptions = _ => { }, + Turns = + [ + new DownstreamTurn + { + ResponseMessage = CreateExpansionCall("call1", "EmptyGroup"), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + }, + new DownstreamTurn + { + ResponseMessage = new(ChatRole.Assistant, "done"), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + } + ] + }; + + List?> observedNonStreaming = []; + List?> observedStreaming = []; + + var result = await InvokeAndAssertAsync(CreateScenario(observedNonStreaming)); + var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario(observedStreaming)); + + void AssertObservedTools(List?> observed) => Assert.Collection(observed, + tools => + { + Assert.NotNull(tools); + Assert.Contains(tools!, t => t.Name == ungroupedTool.Name); + Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName); + }, + tools => + { + // After expanding empty group, only ungrouped tool + expansion/list functions remain + Assert.NotNull(tools); + Assert.Contains(tools!, t => t.Name == ungroupedTool.Name); + Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName); + + // No group-specific tools should be added + Assert.Equal(3, tools.Count); // ungrouped + expansion + list + }); + + AssertObservedTools(observedNonStreaming); + AssertObservedTools(observedStreaming); + + void AssertResponse(ToolGroupingTestResult testResult) => + AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'EmptyGroup'"); + + AssertResponse(result); + AssertResponse(streamingResult); + } + + [Fact] + public async Task ToolGroupingChatClient_CustomExpansionFunctionName() + { + var tool = new SimpleTool("Tool", "tool"); + var group = AIToolGroup.Create("Group", "group", [tool]); + + const string CustomExpansionName = "my_custom_expand"; + + ToolGroupingTestScenario CreateScenario(List?> observedTools) => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "test custom name")], + Options = new ChatOptions { Tools = [group] }, + ConfigureToolGroupingOptions = options => + { + options.ExpansionFunctionName = CustomExpansionName; + }, + Turns = + [ + new DownstreamTurn + { + ResponseMessage = new ChatMessage(ChatRole.Assistant, + [new FunctionCallContent("call1", CustomExpansionName, new Dictionary { ["groupName"] = "Group" })]), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + }, + new DownstreamTurn + { + ResponseMessage = new(ChatRole.Assistant, "done"), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + } + ] + }; + + List?> observedNonStreaming = []; + List?> observedStreaming = []; + + var result = await InvokeAndAssertAsync(CreateScenario(observedNonStreaming)); + var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario(observedStreaming)); + + void AssertObservedTools(List?> observed) + { + Assert.Equal(2, observed.Count); + + // All iterations should have custom expansion function + Assert.All(observed, tools => + { + Assert.NotNull(tools); + Assert.Contains(tools!, t => t.Name == CustomExpansionName); + Assert.DoesNotContain(tools, t => t.Name == DefaultExpansionFunctionName); + }); + } + + AssertObservedTools(observedNonStreaming); + AssertObservedTools(observedStreaming); + + void AssertResponse(ToolGroupingTestResult testResult) => + AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'Group'"); + + AssertResponse(result); + AssertResponse(streamingResult); + } + + [Fact] + public async Task ToolGroupingChatClient_CustomExpansionFunctionDescription() + { + var tool = new SimpleTool("Tool", "tool"); + var group = AIToolGroup.Create("Group", "group", [tool]); + + const string CustomDescription = "Use this custom function to expand a tool group"; + + List?> observedTools = []; + + ToolGroupingTestScenario CreateScenario() => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "test custom description")], + Options = new ChatOptions { Tools = [group] }, + ConfigureToolGroupingOptions = options => + { + options.ExpansionFunctionDescription = CustomDescription; + }, + Turns = + [ + new DownstreamTurn + { + ResponseMessage = new(ChatRole.Assistant, "ok"), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + } + ] + }; + + await InvokeAndAssertAsync(CreateScenario()); + + var tools = Assert.Single(observedTools); + Assert.NotNull(tools); + var expansionTool = tools!.FirstOrDefault(t => t.Name == DefaultExpansionFunctionName); + Assert.NotNull(expansionTool); + Assert.Equal(CustomDescription, expansionTool!.Description); + } + + [Fact] + public async Task ToolGroupingChatClient_CustomListGroupsFunctionName() + { + var tool = new SimpleTool("Tool", "tool"); + var group = AIToolGroup.Create("Group", "group", [tool]); + + const string CustomListName = "my_list_groups"; + + List?> observedTools = []; + + ToolGroupingTestScenario CreateScenario() => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "test custom list name")], + Options = new ChatOptions { Tools = [group] }, + ConfigureToolGroupingOptions = options => + { + options.ListGroupsFunctionName = CustomListName; + }, + Turns = + [ + new DownstreamTurn + { + ResponseMessage = new(ChatRole.Assistant, "ok"), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + } + ] + }; + + await InvokeAndAssertAsync(CreateScenario()); + + var tools = Assert.Single(observedTools); + Assert.NotNull(tools); + Assert.Contains(tools!, t => t.Name == CustomListName); + Assert.DoesNotContain(tools, t => t.Name == DefaultListGroupsFunctionName); + } + + [Fact] + public async Task ToolGroupingChatClient_CustomListGroupsFunctionDescription() + { + var tool = new SimpleTool("Tool", "tool"); + var group = AIToolGroup.Create("Group", "group", [tool]); + + const string CustomDescription = "Custom description for listing groups"; + + List?> observedTools = []; + + ToolGroupingTestScenario CreateScenario() => new() + { + InitialMessages = [new ChatMessage(ChatRole.User, "test custom list description")], + Options = new ChatOptions { Tools = [group] }, + ConfigureToolGroupingOptions = options => + { + options.ListGroupsFunctionDescription = CustomDescription; + }, + Turns = + [ + new DownstreamTurn + { + ResponseMessage = new(ChatRole.Assistant, "ok"), + AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()), + } + ] + }; + + await InvokeAndAssertAsync(CreateScenario()); + + var tools = Assert.Single(observedTools); + Assert.NotNull(tools); + var listTool = tools!.FirstOrDefault(t => t.Name == DefaultListGroupsFunctionName); + Assert.NotNull(listTool); + Assert.Equal(CustomDescription, listTool!.Description); + } + + private sealed class DynamicToolGroup : AIToolGroup + { + private readonly Func>> _getToolsFunc; + + public DynamicToolGroup(string name, string description, Func>> getToolsFunc) + : base(name, description) + { + _getToolsFunc = getToolsFunc; + } + + public override async ValueTask> GetToolsAsync(CancellationToken cancellationToken = default) + { + var tools = await _getToolsFunc(cancellationToken); + return tools.ToList(); + } + } + + private static async Task InvokeAndAssertAsync(ToolGroupingTestScenario scenario) + { + if (scenario.InitialMessages.Count == 0) + { + throw new InvalidOperationException("Scenario must include at least one initial message."); + } + + List turns = scenario.Turns; + long expectedTotalTokenCounts = 0; + int iteration = 0; + + using TestChatClient inner = new(); + + inner.GetResponseAsyncCallback = (messages, options, cancellationToken) => + { + var materialized = messages.ToList(); + if (iteration >= turns.Count) + { + throw new InvalidOperationException("Unexpected additional iteration."); + } + + var turn = turns[iteration]; + turn.AssertInvocation?.Invoke(new ToolGroupingInvocationContext(iteration, materialized, options)); + + UsageDetails usage = CreateRandomUsage(); + expectedTotalTokenCounts += usage.InputTokenCount!.Value; + + var response = new ChatResponse(turn.ResponseMessage) + { + Usage = usage, + ConversationId = turn.ConversationId, + }; + iteration++; + return Task.FromResult(response); + }; + + using IChatClient client = inner.AsBuilder().UseToolGrouping(scenario.ConfigureToolGroupingOptions).Build(); + + var request = new EnumeratedOnceEnumerable(scenario.InitialMessages); + ChatResponse response = await client.GetResponseAsync(request, scenario.Options, CancellationToken.None); + + Assert.Equal(turns.Count, iteration); + + // Usage should be aggregated over all responses, including AdditionalUsage + var actualUsage = response.Usage!; + Assert.Equal(expectedTotalTokenCounts, actualUsage.InputTokenCount); + Assert.Equal(expectedTotalTokenCounts, actualUsage.OutputTokenCount); + Assert.Equal(expectedTotalTokenCounts, actualUsage.TotalTokenCount); + Assert.Equal(2, actualUsage.AdditionalCounts!.Count); + Assert.Equal(expectedTotalTokenCounts, actualUsage.AdditionalCounts["firstValue"]); + Assert.Equal(expectedTotalTokenCounts, actualUsage.AdditionalCounts["secondValue"]); + + return new ToolGroupingTestResult(response); + } + + private static async Task InvokeAndAssertStreamingAsync(ToolGroupingTestScenario scenario) + { + if (scenario.InitialMessages.Count == 0) + { + throw new InvalidOperationException("Scenario must include at least one initial message."); + } + + List turns = scenario.Turns; + int iteration = 0; + + using TestChatClient inner = new(); + + inner.GetStreamingResponseAsyncCallback = (messages, options, cancellationToken) => + { + var materialized = messages.ToList(); + if (iteration >= turns.Count) + { + throw new InvalidOperationException("Unexpected additional iteration."); + } + + var turn = turns[iteration]; + turn.AssertInvocation?.Invoke(new ToolGroupingInvocationContext(iteration, materialized, options)); + + var response = new ChatResponse(turn.ResponseMessage) + { + ConversationId = turn.ConversationId, + }; + iteration++; + return YieldAsync(response.ToChatResponseUpdates()); + }; + + using IChatClient client = inner.AsBuilder().UseToolGrouping(scenario.ConfigureToolGroupingOptions).Build(); + + var request = new EnumeratedOnceEnumerable(scenario.InitialMessages); + ChatResponse response = await client.GetStreamingResponseAsync(request, scenario.Options, CancellationToken.None).ToChatResponseAsync(); + + Assert.Equal(turns.Count, iteration); + + return new ToolGroupingTestResult(response); + } + + private static UsageDetails CreateRandomUsage() + { + // We'll set the same random number on all the properties so that, when determining the + // correct sum in tests, we only have to total the values once + var value = new Random().Next(100); + return new UsageDetails + { + InputTokenCount = value, + OutputTokenCount = value, + TotalTokenCount = value, + AdditionalCounts = new() { ["firstValue"] = value, ["secondValue"] = value }, + }; + } + + private static ChatMessage CreateExpansionCall(string callId, string groupName) => + new(ChatRole.Assistant, [new FunctionCallContent(callId, DefaultExpansionFunctionName, new Dictionary { ["groupName"] = groupName })]); + + private static void AssertContainsResultMessage(ChatResponse response, string substring) + { + var toolMessages = response.Messages.Where(m => m.Role == ChatRole.Tool).ToList(); + Assert.NotEmpty(toolMessages); + Assert.Contains(toolMessages.SelectMany(m => m.Contents.OfType()), r => + { + var text = r.Result?.ToString() ?? string.Empty; + return text.Contains(substring); + }); + } + + private static async IAsyncEnumerable YieldAsync(IEnumerable updates) + { + foreach (var update in updates) + { + yield return update; + await Task.Yield(); + } + } + + private sealed class SimpleTool : AITool + { + private readonly string _name; + private readonly string _description; + + public SimpleTool(string name, string description) + { + _name = name; + _description = description; + } + + public override string Name => _name; + public override string Description => _description; + } + + private sealed class TestChatClient : IChatClient + { + public Func, ChatOptions?, CancellationToken, Task>? GetResponseAsyncCallback { get; set; } + public Func, ChatOptions?, CancellationToken, IAsyncEnumerable>? GetStreamingResponseAsyncCallback { get; set; } + + public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + if (GetResponseAsyncCallback is null) + { + throw new NotSupportedException(); + } + + return GetResponseAsyncCallback(messages, options, cancellationToken); + } + + public IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + if (GetStreamingResponseAsyncCallback is null) + { + throw new NotSupportedException(); + } + + return GetStreamingResponseAsyncCallback(messages, options, cancellationToken); + } + + public object? GetService(Type serviceType, object? serviceKey = null) => null; + + public void Dispose() + { + // No-op + } + } + + private sealed class ToolGroupingTestScenario + { + public List InitialMessages { get; init; } = []; + public Action ConfigureToolGroupingOptions { get; init; } = _ => { }; + public List Turns { get; init; } = []; + public ChatOptions? Options { get; init; } + } + + private sealed class DownstreamTurn + { + public ChatMessage ResponseMessage { get; init; } = new(ChatRole.Assistant, string.Empty); + public string? ConversationId { get; init; } + public Action? AssertInvocation { get; init; } + } + + private sealed record ToolGroupingInvocationContext(int Iteration, IReadOnlyList Messages, ChatOptions? Options); + + private sealed record ToolGroupingTestResult(ChatResponse Response); + + private sealed class EnumeratedOnceEnumerable : IEnumerable + { + private readonly IEnumerable _items; + private bool _enumerated; + + public EnumeratedOnceEnumerable(IEnumerable items) + { + _items = items; + } + + public IEnumerator GetEnumerator() + { + if (_enumerated) + { + throw new InvalidOperationException("Sequence may only be enumerated once."); + } + + _enumerated = true; + return _items.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs new file mode 100644 index 00000000000..96c9adc6311 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs @@ -0,0 +1,663 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ToolReductionTests +{ + [Fact] + public void EmbeddingToolReductionStrategy_Constructor_ThrowsWhenToolLimitIsLessThanOrEqualToZero() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + Assert.Throws(() => new EmbeddingToolReductionStrategy(gen, toolLimit: 0)); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_NoReduction_WhenToolsBelowLimit() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 5); + + var tools = CreateTools("Weather", "Math"); + var options = new ChatOptions { Tools = tools }; + + var result = await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "Tell me about weather") }, + options); + + Assert.Same(tools, result); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_NoReduction_WhenOptionalToolsBelowLimit() + { + // 1 required + 2 optional, limit = 2 (optional count == limit) => original list returned + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2) + { + IsRequiredTool = t => t.Name == "Req" + }; + + var tools = CreateTools("Req", "Opt1", "Opt2"); + var result = await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "anything") }, + new ChatOptions { Tools = tools }); + + Assert.Same(tools, result); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_Reduces_ToLimit_BySimilarity() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2); + + var tools = CreateTools("Weather", "Translate", "Math", "Jokes"); + var options = new ChatOptions { Tools = tools }; + + var messages = new[] + { + new ChatMessage(ChatRole.User, "Can you do some weather math for forecasting?") + }; + + var reduced = (await strategy.SelectToolsForRequestAsync(messages, options)).ToList(); + + Assert.Equal(2, reduced.Count); + Assert.Contains(reduced, t => t.Name == "Weather"); + Assert.Contains(reduced, t => t.Name == "Math"); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_PreserveOriginalOrdering_ReordersAfterSelection() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2) + { + PreserveOriginalOrdering = true + }; + + var tools = CreateTools("Math", "Translate", "Weather"); + var reduced = (await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "Explain weather math please") }, + new ChatOptions { Tools = tools })).ToList(); + + Assert.Equal(2, reduced.Count); + Assert.Equal("Math", reduced[0].Name); + Assert.Equal("Weather", reduced[1].Name); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_Caching_AvoidsReEmbeddingTools() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1); + + var tools = CreateTools("Weather", "Math", "Jokes"); + var messages = new[] { new ChatMessage(ChatRole.User, "weather") }; + + _ = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools }); + int afterFirst = gen.TotalValueInputs; + + _ = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools }); + int afterSecond = gen.TotalValueInputs; + + // +1 for second query embedding only + Assert.Equal(afterFirst + 1, afterSecond); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_OptionsNullOrNoTools_ReturnsEmptyOrOriginal() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2); + + var empty = await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "anything") }, null); + Assert.Empty(empty); + + var options = new ChatOptions { Tools = [] }; + var result = await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "weather") }, options); + Assert.Same(options.Tools, result); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_CustomSimilarity_InvertsOrdering() + { + using var gen = new VectorBasedTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1) + { + Similarity = (q, t) => -t.Span[0] + }; + + var highTool = new SimpleTool("HighScore", "alpha"); + var lowTool = new SimpleTool("LowScore", "beta"); + gen.VectorSelector = text => text.Contains("alpha") ? 10f : 1f; + + var reduced = (await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "Pick something") }, + new ChatOptions { Tools = [highTool, lowTool] })).ToList(); + + Assert.Single(reduced); + Assert.Equal("LowScore", reduced[0].Name); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_TieDeterminism_PrefersLowerOriginalIndex() + { + // Generator returns identical vectors so similarity ties; we expect original order preserved + using var gen = new ConstantEmbeddingGenerator(3); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2); + + var tools = CreateTools("T1", "T2", "T3", "T4"); + var reduced = (await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "any") }, + new ChatOptions { Tools = tools })).ToList(); + + Assert.Equal(2, reduced.Count); + Assert.Equal("T1", reduced[0].Name); + Assert.Equal("T2", reduced[1].Name); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_DefaultEmbeddingTextSelector_EmptyDescription_UsesNameOnly() + { + using var recorder = new RecordingEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1); + + var target = new SimpleTool("ComputeSum", description: ""); + var filler = new SimpleTool("Other", "Unrelated"); + _ = await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "math") }, + new ChatOptions { Tools = [target, filler] }); + + Assert.Contains("ComputeSum", recorder.Inputs); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_DefaultEmbeddingTextSelector_EmptyName_UsesDescriptionOnly() + { + using var recorder = new RecordingEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1); + + var target = new SimpleTool("", description: "Translates between languages."); + var filler = new SimpleTool("Other", "Unrelated"); + _ = await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "translate") }, + new ChatOptions { Tools = [target, filler] }); + + Assert.Contains("Translates between languages.", recorder.Inputs); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_CustomEmbeddingTextSelector_Applied() + { + using var recorder = new RecordingEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1) + { + ToolEmbeddingTextSelector = t => $"NAME:{t.Name}|DESC:{t.Description}" + }; + + var target = new SimpleTool("WeatherTool", "Gets forecast."); + var filler = new SimpleTool("Other", "Irrelevant"); + _ = await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "weather") }, + new ChatOptions { Tools = [target, filler] }); + + Assert.Contains("NAME:WeatherTool|DESC:Gets forecast.", recorder.Inputs); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextSelector_CustomFiltersMessages() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1); + + var tools = CreateTools("Weather", "Math", "Translate"); + + var messages = new[] + { + new ChatMessage(ChatRole.User, "Please tell me the weather tomorrow."), + new ChatMessage(ChatRole.Assistant, "Sure, I can help."), + new ChatMessage(ChatRole.User, "Now instead solve a math problem.") + }; + + strategy.MessagesEmbeddingTextSelector = msgs => new ValueTask(msgs.LastOrDefault()?.Text ?? string.Empty); + + var reduced = (await strategy.SelectToolsForRequestAsync( + messages, + new ChatOptions { Tools = tools })).ToList(); + + Assert.Single(reduced); + Assert.Equal("Math", reduced[0].Name); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextSelector_InvokedOnce() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1); + + var tools = CreateTools("Weather", "Math"); + int invocationCount = 0; + + strategy.MessagesEmbeddingTextSelector = msgs => + { + invocationCount++; + return new ValueTask(string.Join("\n", msgs.Select(m => m.Text))); + }; + + _ = await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "weather and math") }, + new ChatOptions { Tools = tools }); + + Assert.Equal(1, invocationCount); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_DefaultMessagesEmbeddingTextSelector_IncludesReasoningContent() + { + using var recorder = new RecordingEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1); + var tools = CreateTools("Weather", "Math"); + + var reasoningLine = "Thinking about the best way to get tomorrow's forecast..."; + var answerLine = "Tomorrow will be sunny."; + var userLine = "What's the weather tomorrow?"; + + var messages = new[] + { + new ChatMessage(ChatRole.User, userLine), + new ChatMessage(ChatRole.Assistant, + [ + new TextReasoningContent(reasoningLine), + new TextContent(answerLine) + ]) + }; + + _ = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools }); + + string queryInput = recorder.Inputs[0]; + + Assert.Contains(userLine, queryInput); + Assert.Contains(reasoningLine, queryInput); + Assert.Contains(answerLine, queryInput); + + var userIndex = queryInput.IndexOf(userLine, StringComparison.Ordinal); + var reasoningIndex = queryInput.IndexOf(reasoningLine, StringComparison.Ordinal); + var answerIndex = queryInput.IndexOf(answerLine, StringComparison.Ordinal); + Assert.True(userIndex >= 0 && reasoningIndex > userIndex && answerIndex > reasoningIndex); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_DefaultMessagesEmbeddingTextSelector_SkipsNonTextContent() + { + using var recorder = new RecordingEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1); + var tools = CreateTools("Alpha", "Beta"); + + var textOnly = "Provide translation."; + var messages = new[] + { + new ChatMessage(ChatRole.User, + [ + new DataContent(new byte[] { 1, 2, 3 }, "application/octet-stream"), + new TextContent(textOnly) + ]) + }; + + _ = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools }); + + var queryInput = recorder.Inputs[0]; + Assert.Contains(textOnly, queryInput); + Assert.DoesNotContain("application/octet-stream", queryInput, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_RequiredToolAlwaysIncluded() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1) + { + IsRequiredTool = t => t.Name == "Core" + }; + + var tools = CreateTools("Core", "Weather", "Math"); + var reduced = (await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "math") }, + new ChatOptions { Tools = tools })).ToList(); + + Assert.Equal(2, reduced.Count); // required + one optional (limit=1) + Assert.Contains(reduced, t => t.Name == "Core"); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_MultipleRequiredTools_ExceedLimit_AllRequiredIncluded() + { + // 3 required, limit=1 => expect 3 required + 1 ranked optional = 4 total + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1) + { + IsRequiredTool = t => t.Name.StartsWith("R", StringComparison.Ordinal) + }; + + var tools = CreateTools("R1", "R2", "R3", "Weather", "Math"); + var reduced = (await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "weather math") }, + new ChatOptions { Tools = tools })).ToList(); + + Assert.Equal(4, reduced.Count); + Assert.Equal(3, reduced.Count(t => t.Name.StartsWith("R"))); + } + + [Fact] + public async Task ToolReducingChatClient_ReducesTools_ForGetResponseAsync() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2); + var tools = CreateTools("Weather", "Math", "Translate", "Jokes"); + + IList? observedTools = null; + + using var inner = new TestChatClient + { + GetResponseAsyncCallback = (messages, options, ct) => + { + observedTools = options?.Tools; + return Task.FromResult(new ChatResponse()); + } + }; + + using var client = inner.AsBuilder().UseToolReduction(strategy).Build(); + + await client.GetResponseAsync( + new[] { new ChatMessage(ChatRole.User, "weather math please") }, + new ChatOptions { Tools = tools }); + + Assert.NotNull(observedTools); + Assert.Equal(2, observedTools!.Count); + Assert.Contains(observedTools, t => t.Name == "Weather"); + Assert.Contains(observedTools, t => t.Name == "Math"); + } + + [Fact] + public async Task ToolReducingChatClient_ReducesTools_ForStreaming() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1); + var tools = CreateTools("Weather", "Math"); + + IList? observedTools = null; + + using var inner = new TestChatClient + { + GetStreamingResponseAsyncCallback = (messages, options, ct) => + { + observedTools = options?.Tools; + return EmptyAsyncEnumerable(); + } + }; + + using var client = inner.AsBuilder().UseToolReduction(strategy).Build(); + + await foreach (var _ in client.GetStreamingResponseAsync( + new[] { new ChatMessage(ChatRole.User, "math") }, + new ChatOptions { Tools = tools })) + { + // Consume + } + + Assert.NotNull(observedTools); + Assert.Single(observedTools!); + Assert.Equal("Math", observedTools![0].Name); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_EmptyQuery_NoReduction() + { + // Arrange: more tools than limit so we'd normally reduce, but query is empty -> return full list unchanged. + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1); + + var tools = CreateTools("ToolA", "ToolB", "ToolC"); + var options = new ChatOptions { Tools = tools }; + + // Empty / whitespace message text produces empty query. + var messages = new[] { new ChatMessage(ChatRole.User, " ") }; + + // Act + var result = await strategy.SelectToolsForRequestAsync(messages, options); + + // Assert: same reference (no reduction), and generator not invoked at all. + Assert.Same(tools, result); + Assert.Equal(0, gen.TotalValueInputs); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_EmptyQuery_NoReduction_WithRequiredTool() + { + // Arrange: required tool + optional tools; still should return original set when query is empty. + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1) + { + IsRequiredTool = t => t.Name == "Req" + }; + + var tools = CreateTools("Req", "Optional1", "Optional2"); + var options = new ChatOptions { Tools = tools }; + + var messages = new[] { new ChatMessage(ChatRole.User, " ") }; + + // Act + var result = await strategy.SelectToolsForRequestAsync(messages, options); + + // Assert + Assert.Same(tools, result); + Assert.Equal(0, gen.TotalValueInputs); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_EmptyQuery_ViaCustomMessagesSelector_NoReduction() + { + // Arrange: force empty query through custom selector returning whitespace. + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1) + { + MessagesEmbeddingTextSelector = _ => new ValueTask(" ") + }; + + var tools = CreateTools("One", "Two"); + var messages = new[] + { + new ChatMessage(ChatRole.User, "This content will be ignored by custom selector.") + }; + + // Act + var result = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools }); + + // Assert: no reduction and no embeddings generated. + Assert.Same(tools, result); + Assert.Equal(0, gen.TotalValueInputs); + } + + private static List CreateTools(params string[] names) => + names.Select(n => (AITool)new SimpleTool(n, $"Description about {n}")).ToList(); + +#pragma warning disable CS1998 + private static async IAsyncEnumerable EmptyAsyncEnumerable() + { + yield break; + } +#pragma warning restore CS1998 + + private sealed class SimpleTool : AITool + { + private readonly string _name; + private readonly string _description; + + public SimpleTool(string name, string description) + { + _name = name; + _description = description; + } + + public override string Name => _name; + public override string Description => _description; + } + + /// + /// Deterministic embedding generator producing sparse keyword indicator vectors. + /// Each dimension corresponds to a known keyword. Cosine similarity then reflects + /// pure keyword overlap (non-overlapping keywords contribute nothing), avoiding + /// false ties for tools unrelated to the query. + /// + private sealed class DeterministicTestEmbeddingGenerator : IEmbeddingGenerator> + { + private static readonly string[] _keywords = + [ + "weather","forecast","temperature","math","calculate","sum","translate","language","joke" + ]; + + // +1 bias dimension (last) to avoid zero magnitude vectors when no keywords present. + private static int VectorLength => _keywords.Length + 1; + + public int TotalValueInputs { get; private set; } + + public Task>> GenerateAsync( + IEnumerable values, + EmbeddingGenerationOptions? options = null, + CancellationToken cancellationToken = default) + { + var list = new List>(); + + foreach (var v in values) + { + TotalValueInputs++; + var vec = new float[VectorLength]; + if (!string.IsNullOrWhiteSpace(v)) + { + var lower = v.ToLowerInvariant(); + for (int i = 0; i < _keywords.Length; i++) + { + if (lower.Contains(_keywords[i])) + { + vec[i] = 1f; + } + } + } + + vec[^1] = 1f; // bias + list.Add(new Embedding(vec)); + } + + return Task.FromResult(new GeneratedEmbeddings>(list)); + } + + public object? GetService(Type serviceType, object? serviceKey = null) => null; + + public void Dispose() + { + // No-op + } + } + + private sealed class RecordingEmbeddingGenerator : IEmbeddingGenerator> + { + public List Inputs { get; } = new(); + + public Task>> GenerateAsync( + IEnumerable values, + EmbeddingGenerationOptions? options = null, + CancellationToken cancellationToken = default) + { + var list = new List>(); + foreach (var v in values) + { + Inputs.Add(v); + + // Basic 2-dim vector (length encodes a bit of variability) + list.Add(new Embedding(new float[] { v.Length, 1f })); + } + + return Task.FromResult(new GeneratedEmbeddings>(list)); + } + + public object? GetService(Type serviceType, object? serviceKey = null) => null; + public void Dispose() + { + // No-op + } + } + + private sealed class VectorBasedTestEmbeddingGenerator : IEmbeddingGenerator> + { + public Func VectorSelector { get; set; } = _ => 1f; + public Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + var list = new List>(); + foreach (var v in values) + { + list.Add(new Embedding(new float[] { VectorSelector(v), 1f })); + } + + return Task.FromResult(new GeneratedEmbeddings>(list)); + } + + public object? GetService(Type serviceType, object? serviceKey = null) => null; + public void Dispose() + { + // No-op + } + } + + private sealed class ConstantEmbeddingGenerator : IEmbeddingGenerator> + { + private readonly float[] _vector; + public ConstantEmbeddingGenerator(int dims) + { + _vector = Enumerable.Repeat(1f, dims).ToArray(); + } + + public Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + var list = new List>(); + foreach (var _ in values) + { + list.Add(new Embedding(_vector)); + } + + return Task.FromResult(new GeneratedEmbeddings>(list)); + } + + public object? GetService(Type serviceType, object? serviceKey = null) => null; + public void Dispose() + { + // No-op + } + } + + private sealed class TestChatClient : IChatClient + { + public Func, ChatOptions?, CancellationToken, Task>? GetResponseAsyncCallback { get; set; } + public Func, ChatOptions?, CancellationToken, IAsyncEnumerable>? GetStreamingResponseAsyncCallback { get; set; } + + public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => + (GetResponseAsyncCallback ?? throw new InvalidOperationException())(messages, options, cancellationToken); + + public IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => + (GetStreamingResponseAsyncCallback ?? throw new InvalidOperationException())(messages, options, cancellationToken); + + public object? GetService(Type serviceType, object? serviceKey = null) => null; + public void Dispose() + { + // No-op + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientIntegrationTests.cs index 6322e3d6b64..a9e08a58e52 100644 --- a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientIntegrationTests.cs @@ -8,4 +8,8 @@ public class OpenAIChatClientIntegrationTests : ChatClientIntegrationTests protected override IChatClient? CreateChatClient() => IntegrationTestHelpers.GetOpenAIClient() ?.GetChatClient(TestRunnerConfiguration.Instance["OpenAI:ChatModel"] ?? "gpt-4o-mini").AsIChatClient(); + + protected override IEmbeddingGenerator>? CreateEmbeddingGenerator() => + IntegrationTestHelpers.GetOpenAIClient() + ?.GetEmbeddingClient(TestRunnerConfiguration.Instance["OpenAI:EmbeddingModel"] ?? "text-embedding-3-small").AsIEmbeddingGenerator(); } diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs index 2308a921ab3..e6aef075770 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs @@ -1232,6 +1232,49 @@ public async Task ClonesChatOptionsAndResetContinuationTokenForBackgroundRespons Assert.Null(actualChatOptions!.ContinuationToken); } + [Fact] + public async Task ToolGroups_GetExpandedAutomatically() + { + var innerGroup = AIToolGroup.Create( + "InnerGroup", + "Inner group of tools", + new List + { + AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), + }); + + var outerGroup = AIToolGroup.Create( + "OuterGroup", + "Outer group of tools", + new List + { + AIFunctionFactory.Create(() => "Result 1", "Func1"), + innerGroup, + AIFunctionFactory.Create((int i) => { }, "VoidReturn"), + }); + + ChatOptions options = new() + { + Tools = [outerGroup] + }; + + List plan = + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Success: Function completed.")]), + new ChatMessage(ChatRole.Assistant, "world"), + ]; + + await InvokeAndAssertAsync(options, plan); + + await InvokeAndAssertStreamingAsync(options, plan); + } + private sealed class CustomSynchronizationContext : SynchronizationContext { public override void Post(SendOrPostCallback d, object? state)