diff --git a/eng/packages/General.props b/eng/packages/General.props
index b66f1a4ffa8..441a30afa73 100644
--- a/eng/packages/General.props
+++ b/eng/packages/General.props
@@ -29,6 +29,7 @@
+
diff --git a/eng/packages/TestOnly.props b/eng/packages/TestOnly.props
index 5875491f919..603e31c0e5b 100644
--- a/eng/packages/TestOnly.props
+++ b/eng/packages/TestOnly.props
@@ -21,7 +21,6 @@
-
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ToolReduction/IToolReductionStrategy.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ToolReduction/IToolReductionStrategy.cs
new file mode 100644
index 00000000000..029eeae47a1
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ToolReduction/IToolReductionStrategy.cs
@@ -0,0 +1,41 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.Extensions.AI;
+
+///
+/// Represents a strategy capable of selecting a reduced set of tools for a chat request.
+///
+///
+/// A tool reduction strategy is invoked prior to sending a request to an underlying ,
+/// enabling scenarios where a large tool catalog must be trimmed to fit provider limits or to improve model
+/// tool selection quality.
+///
+/// The implementation should return a non- enumerable. Returning the original
+/// instance indicates no change. Returning a different enumerable indicates
+/// the caller may replace the existing tool list.
+///
+///
+[Experimental("MEAI001")]
+public interface IToolReductionStrategy
+{
+ ///
+ /// Selects the tools that should be included for a specific request.
+ ///
+ /// The chat messages for the request. This is an to avoid premature materialization.
+ /// The chat options for the request (may be ).
+ /// A token to observe cancellation.
+ ///
+ /// A (possibly reduced) enumerable of instances. Must never be .
+ /// Returning the same instance referenced by . signals no change.
+ ///
+ Task> SelectToolsForRequestAsync(
+ IEnumerable messages,
+ ChatOptions? options,
+ CancellationToken cancellationToken = default);
+}
diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs
index 554918b0a8e..8217ea49da5 100644
--- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs
@@ -6,6 +6,8 @@
using Microsoft.Shared.Collections;
using Microsoft.Shared.Diagnostics;
+#pragma warning disable IDE0032 // Use auto property, suppressed until repo updates to C# 14
+
namespace Microsoft.Extensions.AI;
/// Provides context for an in-flight function invocation.
diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
index fb821c984df..c1cff1ab554 100644
--- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
@@ -284,7 +284,7 @@ public override async Task GetResponseAsync(
bool lastIterationHadConversationId = false; // whether the last iteration's response had a ConversationId set
int consecutiveErrorCount = 0;
- (Dictionary? toolMap, bool anyToolsRequireApproval) = CreateToolsMap(AdditionalTools, options?.Tools); // all available tools, indexed by name
+ (Dictionary? toolMap, bool anyToolsRequireApproval) = await CreateToolsMapAsync([AdditionalTools, options?.Tools], cancellationToken); // all available tools, indexed by name
if (HasAnyApprovalContent(originalMessages))
{
@@ -424,7 +424,7 @@ public override async IAsyncEnumerable GetStreamingResponseA
List updates = []; // updates from the current response
int consecutiveErrorCount = 0;
- (Dictionary? toolMap, bool anyToolsRequireApproval) = CreateToolsMap(AdditionalTools, options?.Tools); // all available tools, indexed by name
+ (Dictionary? toolMap, bool anyToolsRequireApproval) = await CreateToolsMapAsync([AdditionalTools, options?.Tools], cancellationToken); // all available tools, indexed by name
// This is a synthetic ID since we're generating the tool messages instead of getting them from
// the underlying provider. When emitting the streamed chunks, it's perfectly valid for us to
@@ -624,7 +624,13 @@ public override async IAsyncEnumerable GetStreamingResponseA
AddUsageTags(activity, totalUsage);
}
- private static ChatResponseUpdate ConvertToolResultMessageToUpdate(ChatMessage message, string? conversationId, string? messageId) =>
+ ///
+ /// Converts a tool result into a for streaming scenarios.
+ ///
+ /// The tool result message.
+ /// The conversation ID.
+ /// The message ID.
+ internal static ChatResponseUpdate ConvertToolResultMessageToUpdate(ChatMessage message, string? conversationId, string? messageId) =>
new()
{
AdditionalProperties = message.AdditionalProperties,
@@ -662,7 +668,7 @@ private static void AddUsageTags(Activity? activity, UsageDetails? usage)
/// The most recent response being handled.
/// A list of all response messages received up until this point.
/// Whether the previous iteration's response had a conversation ID.
- private static void FixupHistories(
+ internal static void FixupHistories(
IEnumerable originalMessages,
ref IEnumerable messages,
[NotNull] ref List? augmentedHistory,
@@ -722,26 +728,51 @@ private static void FixupHistories(
/// The lists of tools to combine into a single dictionary. Tools from later lists are preferred
/// over tools from earlier lists if they have the same name.
///
- private static (Dictionary? ToolMap, bool AnyRequireApproval) CreateToolsMap(params ReadOnlySpan?> toolLists)
+ /// The to monitor for cancellation requests.
+ private static async ValueTask<(Dictionary? ToolMap, bool AnyRequireApproval)> CreateToolsMapAsync(IList?[] toolLists, CancellationToken cancellationToken)
{
Dictionary? map = null;
bool anyRequireApproval = false;
foreach (var toolList in toolLists)
{
- if (toolList?.Count is int count && count > 0)
+ if (toolList is not null)
+ {
+ map ??= [];
+ var anyInListRequireApproval = await AddToolListAsync(map, toolList, cancellationToken).ConfigureAwait(false);
+ anyRequireApproval |= anyInListRequireApproval;
+ }
+ }
+
+ return (map, anyRequireApproval);
+
+ static async ValueTask AddToolListAsync(Dictionary map, IEnumerable tools, CancellationToken cancellationToken)
+ {
+#if NET
+ if (tools.TryGetNonEnumeratedCount(out var count) && count == 0)
{
- map ??= new(StringComparer.Ordinal);
- for (int i = 0; i < count; i++)
+ return false;
+ }
+#endif
+ var anyRequireApproval = false;
+
+ foreach (var tool in tools)
+ {
+ if (tool is AIToolGroup toolGroup)
+ {
+ var nestedTools = await toolGroup.GetToolsAsync(cancellationToken).ConfigureAwait(false);
+ var nestedToolsRequireApproval = await AddToolListAsync(map, nestedTools, cancellationToken).ConfigureAwait(false);
+ anyRequireApproval |= nestedToolsRequireApproval;
+ }
+ else
{
- AITool tool = toolList[i];
anyRequireApproval |= tool.GetService() is not null;
map[tool.Name] = tool;
}
}
- }
- return (map, anyRequireApproval);
+ return anyRequireApproval;
+ }
}
///
diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj
index 36e6bb00562..54cbcc99754 100644
--- a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj
+++ b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj
@@ -44,6 +44,7 @@
+
diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/AIToolGroup.cs b/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/AIToolGroup.cs
new file mode 100644
index 00000000000..1706861b5a1
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/AIToolGroup.cs
@@ -0,0 +1,90 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Shared.Diagnostics;
+
+namespace Microsoft.Extensions.AI;
+
+///
+/// Represents a logical grouping of tools that can be dynamically expanded.
+///
+///
+///
+/// A is an that supplies an ordered list of instances
+/// via the method. This enables grouping tools together for organizational purposes
+/// and allows for dynamic tool selection based on context.
+///
+///
+/// Tool groups can be used independently or in conjunction with to implement
+/// hierarchical tool selection, where groups are initially collapsed and can be expanded on demand.
+///
+///
+[Experimental("MEAI001")]
+public abstract class AIToolGroup : AITool
+{
+ private readonly string _name;
+ private readonly string _description;
+
+ /// Initializes a new instance of the class.
+ /// Group name (identifier used by the expansion function).
+ /// Human readable description of the group.
+ /// is .
+ protected AIToolGroup(string name, string description)
+ {
+ _name = Throw.IfNull(name);
+ _description = Throw.IfNull(description);
+ }
+
+ /// Gets the group name.
+ public override string Name => _name;
+
+ /// Gets the group description.
+ public override string Description => _description;
+
+ /// Creates a tool group with a static list of tools.
+ /// Group name (identifier used by the expansion function).
+ /// Human readable description of the group.
+ /// Ordered tools contained in the group.
+ /// An instance containing the specified tools.
+ /// or is .
+ public static AIToolGroup Create(string name, string description, IReadOnlyList tools)
+ {
+ _ = Throw.IfNull(name);
+ _ = Throw.IfNull(tools);
+ return new StaticAIToolGroup(name, description, tools);
+ }
+
+ ///
+ /// Asynchronously retrieves the ordered list of tools belonging to this group.
+ ///
+ /// The to monitor for cancellation requests.
+ /// A representing the asynchronous operation, containing the ordered list of tools in the group.
+ ///
+ /// The returned list may contain other instances, enabling hierarchical tool organization.
+ /// Implementations should ensure the returned list is stable and deterministic for a given group instance.
+ ///
+ public abstract ValueTask> GetToolsAsync(CancellationToken cancellationToken = default);
+
+ /// A tool group implementation that returns a static list of tools.
+ private sealed class StaticAIToolGroup : AIToolGroup
+ {
+ private readonly IReadOnlyList _tools;
+
+ public StaticAIToolGroup(string name, string description, IReadOnlyList tools)
+ : base(name, description)
+ {
+ _tools = tools;
+ }
+
+ public override ValueTask> GetToolsAsync(CancellationToken cancellationToken = default)
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+ return new ValueTask>(_tools);
+ }
+ }
+}
diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/ChatClientBuilderToolGroupingExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/ChatClientBuilderToolGroupingExtensions.cs
new file mode 100644
index 00000000000..6c79a30cbba
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/ChatClientBuilderToolGroupingExtensions.cs
@@ -0,0 +1,26 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Diagnostics.CodeAnalysis;
+using Microsoft.Shared.Diagnostics;
+
+namespace Microsoft.Extensions.AI;
+
+/// Builder extensions for .
+[Experimental("MEAI001")]
+public static class ChatClientBuilderToolGroupingExtensions
+{
+ /// Adds tool grouping middleware to the pipeline.
+ /// Chat client builder.
+ /// Configuration delegate.
+ /// The builder for chaining.
+ /// Should appear before tool reduction and function invocation middleware.
+ public static ChatClientBuilder UseToolGrouping(this ChatClientBuilder builder, Action? configure = null)
+ {
+ _ = Throw.IfNull(builder);
+ var options = new ToolGroupingOptions();
+ configure?.Invoke(options);
+ return builder.Use(inner => new ToolGroupingChatClient(inner, options));
+ }
+}
diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/ToolGroupingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/ToolGroupingChatClient.cs
new file mode 100644
index 00000000000..684528f795d
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/ToolGroupingChatClient.cs
@@ -0,0 +1,516 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Shared.Diagnostics;
+
+namespace Microsoft.Extensions.AI;
+
+#pragma warning disable S103 // Lines should not be too long
+#pragma warning disable IDE0058 // Expression value is never used
+
+///
+/// A chat client that enables tool groups (see ) to be dynamically expanded.
+///
+///
+///
+/// On each request, this chat client initially presents a minimal tool surface consisting of: (a) a function
+/// returning the current list of available groups plus (b) a synthetic expansion function plus (c) tools in
+/// that are not instances.
+/// If the model calls the expansion function with a valid group name, the
+/// client issues another request with that group's tools visible.
+/// Only one group may be expanded per top-level request, and by default at most three expansion loops are performed.
+///
+///
+/// This client should typically appear in the pipeline before tool reduction middleware and function invocation
+/// middleware. Example order: .UseToolGrouping(...).UseToolReduction(...).UseFunctionInvocation().
+///
+///
+[Experimental("MEAI001")]
+public sealed class ToolGroupingChatClient : DelegatingChatClient
+{
+ private const string ExpansionFunctionGroupNameParameter = "groupName";
+ private static readonly Delegate _expansionFunctionDelegate = static string (string groupName)
+ => throw new InvalidOperationException("The tool expansion function should not be invoked directly.");
+
+ private readonly int _maxExpansionsPerRequest;
+ private readonly AIFunctionDeclaration _expansionFunction;
+ private readonly string _listGroupsFunctionName;
+ private readonly string _listGroupsFunctionDescription;
+
+ /// Initializes a new instance of the class.
+ /// Inner client.
+ /// Grouping options.
+ public ToolGroupingChatClient(IChatClient innerClient, ToolGroupingOptions options)
+ : base(innerClient)
+ {
+ _ = Throw.IfNull(options);
+
+ _maxExpansionsPerRequest = options.MaxExpansionsPerRequest;
+ _listGroupsFunctionName = options.ListGroupsFunctionName;
+ _listGroupsFunctionDescription = options.ListGroupsFunctionDescription
+ ?? "Returns the list of available tool groups that can be expanded.";
+
+ var expansionFunctionName = options.ExpansionFunctionName;
+ var expansionDescription = options.ExpansionFunctionDescription
+ ?? $"Expands a tool group to make its tools available. Use the '{_listGroupsFunctionName}' function to see available groups.";
+
+ _expansionFunction = AIFunctionFactory.Create(
+ method: _expansionFunctionDelegate,
+ name: expansionFunctionName,
+ description: expansionDescription).AsDeclarationOnly();
+ }
+
+ ///
+ public override async Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
+ {
+ _ = Throw.IfNull(messages);
+
+ var toolGroups = ExtractToolGroups(options);
+ if (toolGroups is not { Count: > 0 })
+ {
+ // If there are no tool groups, then tool expansion isn't possible.
+ // We'll just call directly through to the inner chat client.
+ return await base.GetResponseAsync(messages, options, cancellationToken);
+ }
+
+ // Copy the original messages in order to avoid enumerating the original messages multiple times.
+ // The IEnumerable can represent an arbitrary amount of work.
+ List originalMessages = [.. messages];
+ messages = originalMessages;
+
+ // Build top-level groups dictionary
+ var topLevelToolGroupsByName = toolGroups.ToDictionary(g => g.Name, StringComparer.Ordinal);
+
+ // Track the currently-expanded group and all its constituent tools
+ AIToolGroup? expandedGroup = null;
+ List? expandedGroupToolGroups = null; // tool groups within the currently-expanded tool group
+ List? expandedGroupTools = null; // non-group tools within the currently-expanded tool group
+
+ // Create the "list groups" function. Its behavior is controlled by values captured in the lambda below.
+ var listGroupsFunction = AIFunctionFactory.Create(
+ method: () => CreateListGroupsResult(expandedGroup, toolGroups, expandedGroupToolGroups),
+ name: _listGroupsFunctionName,
+ description: _listGroupsFunctionDescription);
+
+ // Construct new chat options containing ungrouped tools and utility functions.
+ List baseTools = ComputeBaseTools(options, listGroupsFunction);
+ ChatOptions modifiedOptions = options?.Clone() ?? new();
+ modifiedOptions.Tools = baseTools;
+
+ List? augmentedHistory = null; // the actual history of messages sent on turns other than the first
+ ChatResponse? response = null; // the response from the inner client, which is possibly modified and then eventually returned
+ List? responseMessages = null; // tracked list of messages, across multiple turns, to be used for the final response
+ List? expansionRequests = null; // expansion requests that need responding to in the current turn
+ UsageDetails? totalUsage = null; // tracked usage across all turns, to be used for the final response
+ bool lastIterationHadConversationId = false; // whether the last iteration's response had a ConversationId set
+ List? modifiedTools = null; // the modified tools list containing the current tool group
+
+ for (var expansionIterationCount = 0; ; expansionIterationCount++)
+ {
+ expansionRequests?.Clear();
+
+ // Make the call to the inner client.
+ response = await base.GetResponseAsync(messages, modifiedOptions, cancellationToken).ConfigureAwait(false);
+ if (response is null)
+ {
+ Throw.InvalidOperationException("Inner client returned null ChatResponse.");
+ }
+
+ // Any expansions to perform? If yes, ensure we're tracking that work in expansionRequests.
+ bool requiresExpansion =
+ expansionIterationCount < _maxExpansionsPerRequest &&
+ CopyExpansionRequests(response.Messages, ref expansionRequests);
+
+ if (!requiresExpansion && expansionIterationCount == 0)
+ {
+ // Fast path: no function calling work required
+ return response;
+ }
+
+ // Track aggregate details from the response
+ (responseMessages ??= []).AddRange(response.Messages);
+ if (response.Usage is not null)
+ {
+ if (totalUsage is not null)
+ {
+ totalUsage.Add(response.Usage);
+ }
+ else
+ {
+ totalUsage = response.Usage;
+ }
+ }
+
+ if (!requiresExpansion)
+ {
+ // No more work to do.
+ break;
+ }
+
+ // Prepare the history for the next iteration.
+ FunctionInvokingChatClient.FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadConversationId);
+
+ expandedGroupTools ??= [];
+ expandedGroupToolGroups ??= [];
+ (var addedMessages, expandedGroup) = await ProcessExpansionsAsync(
+ expansionRequests!,
+ topLevelToolGroupsByName,
+ expandedGroupTools,
+ expandedGroupToolGroups,
+ expandedGroup,
+ cancellationToken);
+
+ augmentedHistory.AddRange(addedMessages);
+ responseMessages.AddRange(addedMessages);
+
+ (modifiedTools ??= []).Clear();
+ modifiedTools.AddRange(baseTools);
+ modifiedTools.AddRange(expandedGroupTools);
+ modifiedOptions.Tools = modifiedTools;
+ modifiedOptions.ConversationId = response.ConversationId;
+ }
+
+ Debug.Assert(responseMessages is not null, "Expected to only be here if we have response messages.");
+ response.Messages = responseMessages!;
+ response.Usage = totalUsage;
+
+ return response;
+ }
+
+ ///
+ public override async IAsyncEnumerable GetStreamingResponseAsync(
+ IEnumerable messages,
+ ChatOptions? options = null,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ _ = Throw.IfNull(messages);
+
+ var toolGroups = ExtractToolGroups(options);
+ if (toolGroups is not { Count: > 0 })
+ {
+ // No tool groups, just call through
+ await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false))
+ {
+ yield return update;
+ }
+
+ yield break;
+ }
+
+ List originalMessages = [.. messages];
+ messages = originalMessages;
+
+ // Build top-level groups dictionary
+ var topLevelToolGroupsByName = toolGroups.ToDictionary(g => g.Name, StringComparer.Ordinal);
+
+ // Track the currently-expanded group and all its constituent tools
+ AIToolGroup? expandedGroup = null;
+ List? expandedGroupToolGroups = null; // tool groups within the currently-expanded tool group
+ List? expandedGroupTools = null; // non-group tools within the currently-expanded tool group
+
+ // Create the "list groups" function. Its behavior is controlled by values captured in the lambda below.
+ var listGroupsFunction = AIFunctionFactory.Create(
+ method: () => CreateListGroupsResult(expandedGroup, toolGroups, expandedGroupToolGroups),
+ name: _listGroupsFunctionName,
+ description: _listGroupsFunctionDescription);
+
+ // Construct new chat options containing ungrouped tools and utility functions.
+ List baseTools = ComputeBaseTools(options, listGroupsFunction);
+ ChatOptions modifiedOptions = options?.Clone() ?? new();
+ modifiedOptions.Tools = baseTools;
+
+ List? augmentedHistory = null; // the actual history of messages sent on turns other than the first
+ List? responseMessages = null; // tracked list of messages, across multiple turns, to be used for the final response
+ List? expansionRequests = null; // expansion requests that need responding to in the current turn
+ bool lastIterationHadConversationId = false; // whether the last iteration's response had a ConversationId set
+ List updates = []; // collected updates from the inner client for the current iteration
+ List? modifiedTools = null;
+ string toolMessageId = Guid.NewGuid().ToString("N"); // stable id for synthetic tool result updates emitted per iteration
+
+ for (int expansionIterationCount = 0; ; expansionIterationCount++)
+ {
+ // Reset any state accumulated from the prior iteration before calling the inner client again.
+ updates.Clear();
+ expansionRequests?.Clear();
+
+ await foreach (var update in base.GetStreamingResponseAsync(messages, modifiedOptions, cancellationToken).ConfigureAwait(false))
+ {
+ if (update is null)
+ {
+ Throw.InvalidOperationException("Inner client returned null ChatResponseUpdate.");
+ }
+
+ updates.Add(update);
+
+ _ = CopyExpansionRequests(update.Contents, ref expansionRequests);
+
+ yield return update;
+ }
+
+ if (expansionIterationCount >= _maxExpansionsPerRequest || expansionRequests is not { Count: > 0 })
+ {
+ // We've either hit the expansion iteration limit or no expansion function calls were made,
+ // so we're done streaming the response.
+ break;
+ }
+
+ // Materialize the collected updates into a ChatResponse so the rest of the logic can share code paths
+ // with the non-streaming implementation.
+ var response = updates.ToChatResponse();
+ (responseMessages ??= []).AddRange(response.Messages);
+
+ // Prepare the history for the next iteration.
+ FunctionInvokingChatClient.FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadConversationId);
+
+ // Add the responses from the group expansions into the augmented history and also into the tracked
+ // list of response messages.
+ expandedGroupTools ??= [];
+ expandedGroupToolGroups ??= [];
+ (var addedMessages, expandedGroup) = await ProcessExpansionsAsync(
+ expansionRequests!,
+ topLevelToolGroupsByName,
+ expandedGroupTools,
+ expandedGroupToolGroups,
+ expandedGroup,
+ cancellationToken);
+
+ augmentedHistory!.AddRange(addedMessages);
+ responseMessages.AddRange(addedMessages);
+
+ // Surface the expansion results to the caller as additional streaming updates.
+ foreach (var message in addedMessages)
+ {
+ yield return FunctionInvokingChatClient.ConvertToolResultMessageToUpdate(message, response.ConversationId, toolMessageId);
+ }
+
+ // If a valid group was requested for expansion, and it does not match the currently-expanded group,
+ // update the tools list to contain the newly-expanded tool group.
+ (modifiedTools ??= []).Clear();
+ modifiedTools.AddRange(baseTools);
+ modifiedTools.AddRange(expandedGroupTools);
+ modifiedOptions.Tools = modifiedTools;
+ modifiedOptions.ConversationId = response.ConversationId;
+ }
+ }
+
+ /// Extracts instances from the provided options.
+ private static List? ExtractToolGroups(ChatOptions? options)
+ {
+ if (options?.Tools is not { Count: > 0 })
+ {
+ return null;
+ }
+
+ List? groups = null;
+ foreach (var tool in options.Tools)
+ {
+ if (tool is AIToolGroup group)
+ {
+ (groups ??= []).Add(group);
+ }
+ }
+
+ return groups;
+ }
+
+ /// Creates a function that returns the list of available groups.
+ private static string CreateListGroupsResult(
+ AIToolGroup? expandedToolGroup,
+ List topLevelGroups,
+ List? nestedGroups)
+ {
+ var allToolGroups = nestedGroups is null
+ ? topLevelGroups
+ : topLevelGroups.Concat(nestedGroups);
+
+ allToolGroups = allToolGroups.Where(g => g != expandedToolGroup);
+
+ if (!allToolGroups.Any())
+ {
+ return "No tool groups are currently available.";
+ }
+
+ var sb = new StringBuilder();
+ sb.Append("Available tool groups:");
+ AppendAIToolList(sb, allToolGroups);
+ return sb.ToString();
+ }
+
+ /// Processes expansion requests and returns messages to add, termination flag, and updated group state.
+ private static async Task<(IList messagesToAdd, AIToolGroup? expandedGroup)> ProcessExpansionsAsync(
+ List expansionRequests,
+ Dictionary topLevelGroupsByName,
+ List expandedGroupTools,
+ List expandedGroupToolGroups,
+ AIToolGroup? expandedGroup,
+ CancellationToken cancellationToken)
+ {
+ Debug.Assert(expansionRequests.Count != 0, "Expected at least one expansion request.");
+
+ var contents = new List(expansionRequests.Count);
+
+ foreach (var expansionRequest in expansionRequests)
+ {
+ if (expansionRequest.Arguments is not { Count: > 0 } arguments ||
+ !arguments.TryGetValue(ExpansionFunctionGroupNameParameter, out var groupNameArg) ||
+ groupNameArg is null)
+ {
+ contents.Add(new FunctionResultContent(
+ callId: expansionRequest.CallId,
+ result: "No group name was specified; ignoring expansion request."));
+ continue;
+ }
+
+ bool TryGetValidToolGroup(string groupName, [NotNullWhen(true)] out AIToolGroup? group)
+ {
+ if (topLevelGroupsByName.TryGetValue(groupName, out group))
+ {
+ return true;
+ }
+
+ group = expandedGroupToolGroups.FirstOrDefault(g => string.Equals(g.Name, groupName, StringComparison.Ordinal));
+ return group is not null;
+ }
+
+ var groupName = groupNameArg.ToString();
+ if (groupName is null || !TryGetValidToolGroup(groupName, out var group))
+ {
+ contents.Add(new FunctionResultContent(
+ callId: expansionRequest.CallId,
+ result: $"The specific group name '{groupName}' was invalid; ignoring expansion request."));
+ continue;
+ }
+
+ if (group == expandedGroup)
+ {
+ contents.Add(new FunctionResultContent(
+ callId: expansionRequest.CallId,
+ result: $"Ignoring duplicate expansion of group '{groupName}'."));
+ continue;
+ }
+
+ // Expand the group
+ expandedGroup = group;
+ var groupTools = await group.GetToolsAsync(cancellationToken).ConfigureAwait(false);
+
+ expandedGroupTools.Clear();
+ expandedGroupToolGroups.Clear();
+
+ foreach (var tool in groupTools)
+ {
+ if (tool is AIToolGroup toolGroup)
+ {
+ expandedGroupToolGroups.Add(toolGroup);
+ }
+ else
+ {
+ expandedGroupTools.Add(tool);
+ }
+ }
+
+ // Build success message
+ var sb = new StringBuilder();
+ sb.Append("Successfully expanded group '");
+ sb.Append(groupName);
+ sb.Append("'.");
+
+ if (expandedGroupTools.Count > 0)
+ {
+ sb.Append(" Only this group's tools are now available:");
+ AppendAIToolList(sb, expandedGroupTools);
+ }
+
+ if (expandedGroupToolGroups.Count > 0)
+ {
+ sb.AppendLine();
+ sb.Append("Additional groups available for expansion:");
+ AppendAIToolList(sb, expandedGroupToolGroups);
+ }
+
+ contents.Add(new FunctionResultContent(
+ callId: expansionRequest.CallId,
+ result: sb.ToString()));
+ }
+
+ return (messagesToAdd: [new ChatMessage(ChatRole.Tool, contents)], expandedGroup);
+ }
+
+ /// Appends a formatted list of AI tools to the specified .
+ private static void AppendAIToolList(StringBuilder sb, IEnumerable tools)
+ {
+ foreach (var tool in tools)
+ {
+ sb.AppendLine();
+ sb.Append("- ");
+ sb.Append(tool.Name);
+ sb.Append(": ");
+ sb.Append(tool.Description);
+ }
+ }
+
+ /// Copies expansion requests from messages.
+ private bool CopyExpansionRequests(IList messages, [NotNullWhen(true)] ref List? expansionRequests)
+ {
+ var any = false;
+ foreach (var message in messages)
+ {
+ any |= CopyExpansionRequests(message.Contents, ref expansionRequests);
+ }
+
+ return any;
+ }
+
+ /// Copies expansion requests from contents.
+ private bool CopyExpansionRequests(
+ IList contents,
+ [NotNullWhen(true)] ref List? expansionRequests)
+ {
+ var any = false;
+ foreach (var content in contents)
+ {
+ if (content is FunctionCallContent functionCall &&
+ string.Equals(functionCall.Name, _expansionFunction.Name, StringComparison.Ordinal))
+ {
+ (expansionRequests ??= []).Add(functionCall);
+ any = true;
+ }
+ }
+
+ return any;
+ }
+
+ ///
+ /// Generates a list of base AI tools by combining the default expansion function with additional tools specified in
+ /// the provided chat options, excluding any tools that are grouped.
+ ///
+ private List ComputeBaseTools(ChatOptions? options, AIFunction listGroupsFunction)
+ {
+ List baseTools = [listGroupsFunction, _expansionFunction];
+
+ foreach (var tool in options?.Tools ?? [])
+ {
+ if (tool is not AIToolGroup)
+ {
+ if (string.Equals(tool.Name, _expansionFunction.Name, StringComparison.Ordinal) ||
+ string.Equals(tool.Name, listGroupsFunction.Name, StringComparison.Ordinal))
+ {
+ throw new InvalidOperationException(
+ $"The group expansion tool with name '{tool.Name}' collides with a registered tool of the same name.");
+ }
+
+ baseTools.Add(tool);
+ }
+ }
+
+ return baseTools;
+ }
+}
diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/ToolGroupingOptions.cs b/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/ToolGroupingOptions.cs
new file mode 100644
index 00000000000..fe6854c2378
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI/ToolGrouping/ToolGroupingOptions.cs
@@ -0,0 +1,59 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Diagnostics.CodeAnalysis;
+using Microsoft.Shared.Diagnostics;
+
+#pragma warning disable IDE0032 // Use auto property, suppressed until repo updates to C# 14
+
+namespace Microsoft.Extensions.AI;
+
+/// Options controlling tool grouping / expansion behavior.
+[Experimental("MEAI001")]
+public sealed class ToolGroupingOptions
+{
+ private const string DefaultExpansionFunctionName = "__expand_tool_group";
+ private const string DefaultListGroupsFunctionName = "__list_tool_groups";
+
+ private string _expansionFunctionName = DefaultExpansionFunctionName;
+ private string? _expansionFunctionDescription;
+ private string _listGroupsFunctionName = DefaultListGroupsFunctionName;
+ private string? _listGroupsFunctionDescription;
+ private int _maxExpansionsPerRequest = 3;
+
+ /// Gets or sets the name of the synthetic expansion function tool.
+ public string ExpansionFunctionName
+ {
+ get => _expansionFunctionName;
+ set => _expansionFunctionName = Throw.IfNull(value);
+ }
+
+ /// Gets or sets the description of the synthetic expansion function tool.
+ public string? ExpansionFunctionDescription
+ {
+ get => _expansionFunctionDescription;
+ set => _expansionFunctionDescription = value;
+ }
+
+ /// Gets or sets the name of the synthetic list groups function tool.
+ public string ListGroupsFunctionName
+ {
+ get => _listGroupsFunctionName;
+ set => _listGroupsFunctionName = Throw.IfNull(value);
+ }
+
+ /// Gets or sets the description of the synthetic list groups function tool.
+ public string? ListGroupsFunctionDescription
+ {
+ get => _listGroupsFunctionDescription;
+ set => _listGroupsFunctionDescription = value;
+ }
+
+ /// Gets or sets the maximum number of expansions allowed within a single request.
+ /// Defaults to 3.
+ public int MaxExpansionsPerRequest
+ {
+ get => _maxExpansionsPerRequest;
+ set => _maxExpansionsPerRequest = Throw.IfLessThan(value, 1);
+ }
+}
diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ChatClientBuilderToolReductionExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ChatClientBuilderToolReductionExtensions.cs
new file mode 100644
index 00000000000..5a644267328
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ChatClientBuilderToolReductionExtensions.cs
@@ -0,0 +1,32 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Diagnostics.CodeAnalysis;
+using Microsoft.Shared.Diagnostics;
+
+namespace Microsoft.Extensions.AI;
+
+/// Extension methods for adding tool reduction middleware to a chat client pipeline.
+[Experimental("MEAI001")]
+public static class ChatClientBuilderToolReductionExtensions
+{
+ ///
+ /// Adds tool reduction to the chat client pipeline using the specified .
+ ///
+ /// The chat client builder.
+ /// The reduction strategy.
+ /// The original builder for chaining.
+ /// If or is .
+ ///
+ /// This should typically appear in the pipeline before function invocation middleware so that only the reduced tools
+ /// are exposed to the underlying provider.
+ ///
+ public static ChatClientBuilder UseToolReduction(this ChatClientBuilder builder, IToolReductionStrategy strategy)
+ {
+ _ = Throw.IfNull(builder);
+ _ = Throw.IfNull(strategy);
+
+ return builder.Use(inner => new ToolReducingChatClient(inner, strategy));
+ }
+}
diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs
new file mode 100644
index 00000000000..f9e4c60995a
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs
@@ -0,0 +1,330 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Numerics.Tensors;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Shared.Diagnostics;
+
+namespace Microsoft.Extensions.AI;
+
+#pragma warning disable IDE0032 // Use auto property, suppressed until repo updates to C# 14
+
+///
+/// A tool reduction strategy that ranks tools by embedding similarity to the current conversation context.
+///
+///
+/// The strategy embeds each tool (name + description by default) once (cached) and embeds the current
+/// conversation content each request. It then selects the top toolLimit tools by similarity.
+///
+[Experimental("MEAI001")]
+public sealed class EmbeddingToolReductionStrategy : IToolReductionStrategy
+{
+ private readonly ConditionalWeakTable> _toolEmbeddingsCache = new();
+ private readonly IEmbeddingGenerator> _embeddingGenerator;
+ private readonly int _toolLimit;
+
+ private Func _toolEmbeddingTextSelector = static t =>
+ {
+ if (string.IsNullOrWhiteSpace(t.Name))
+ {
+ return t.Description;
+ }
+
+ if (string.IsNullOrWhiteSpace(t.Description))
+ {
+ return t.Name;
+ }
+
+ return t.Name + Environment.NewLine + t.Description;
+ };
+
+ private Func, ValueTask> _messagesEmbeddingTextSelector = static messages =>
+ {
+ var sb = new StringBuilder();
+ foreach (var message in messages)
+ {
+ var contents = message.Contents;
+ for (var i = 0; i < contents.Count; i++)
+ {
+ string text;
+ switch (contents[i])
+ {
+ case TextContent content:
+ text = content.Text;
+ break;
+ case TextReasoningContent content:
+ text = content.Text;
+ break;
+ default:
+ continue;
+ }
+
+ _ = sb.AppendLine(text);
+ }
+ }
+
+ return new ValueTask(sb.ToString());
+ };
+
+ private Func, ReadOnlyMemory, float> _similarity = static (a, b) => TensorPrimitives.CosineSimilarity(a.Span, b.Span);
+
+ private Func _isRequiredTool = static _ => false;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// Embedding generator used to produce embeddings.
+ /// Maximum number of tools to return, excluding required tools. Must be greater than zero.
+ public EmbeddingToolReductionStrategy(
+ IEmbeddingGenerator> embeddingGenerator,
+ int toolLimit)
+ {
+ _embeddingGenerator = Throw.IfNull(embeddingGenerator);
+ _toolLimit = Throw.IfLessThanOrEqual(toolLimit, min: 0);
+ }
+
+ ///
+ /// Gets or sets the selector used to generate a single text string from a tool.
+ ///
+ ///
+ /// Defaults to: Name + "\n" + Description (omitting empty parts).
+ ///
+ public Func ToolEmbeddingTextSelector
+ {
+ get => _toolEmbeddingTextSelector;
+ set => _toolEmbeddingTextSelector = Throw.IfNull(value);
+ }
+
+ ///
+ /// Gets or sets the selector used to generate a single text string from a collection of chat messages for
+ /// embedding purposes.
+ ///
+ public Func, ValueTask> MessagesEmbeddingTextSelector
+ {
+ get => _messagesEmbeddingTextSelector;
+ set => _messagesEmbeddingTextSelector = Throw.IfNull(value);
+ }
+
+ ///
+ /// Gets or sets a similarity function applied to (query, tool) embedding vectors.
+ ///
+ ///
+ /// Defaults to cosine similarity.
+ ///
+ public Func, ReadOnlyMemory, float> Similarity
+ {
+ get => _similarity;
+ set => _similarity = Throw.IfNull(value);
+ }
+
+ ///
+ /// Gets or sets a function that determines whether a tool is required (always included).
+ ///
+ ///
+ /// If this returns , the tool is included regardless of ranking and does not count against
+ /// the configured non-required tool limit. A tool explicitly named by (when
+ /// is non-null) is also treated as required, independent
+ /// of this delegate's result.
+ ///
+ public Func IsRequiredTool
+ {
+ get => _isRequiredTool;
+ set => _isRequiredTool = Throw.IfNull(value);
+ }
+
+ ///
+ /// Gets or sets a value indicating whether to preserve original ordering of selected tools.
+ /// If (default), tools are ordered by descending similarity.
+ /// If , the top-N tools by similarity are re-emitted in their original order.
+ ///
+ public bool PreserveOriginalOrdering { get; set; }
+
+ ///
+ public async Task> SelectToolsForRequestAsync(
+ IEnumerable messages,
+ ChatOptions? options,
+ CancellationToken cancellationToken = default)
+ {
+ _ = Throw.IfNull(messages);
+
+ if (options?.Tools is not { Count: > 0 } tools)
+ {
+ // Prefer the original tools list reference if possible.
+ // This allows ToolReducingChatClient to avoid unnecessarily copying ChatOptions.
+ // When no reduction is performed.
+ return options?.Tools ?? [];
+ }
+
+ Debug.Assert(_toolLimit > 0, "Expected the tool count limit to be greater than zero.");
+
+ if (tools.Count <= _toolLimit)
+ {
+ // Since the total number of tools doesn't exceed the configured tool limit,
+ // there's no need to determine which tools are optional, i.e., subject to reduction.
+ // We can return the original tools list early.
+ return tools;
+ }
+
+ var toolRankingInfoArray = ArrayPool.Shared.Rent(tools.Count);
+ try
+ {
+ var toolRankingInfoMemory = toolRankingInfoArray.AsMemory(start: 0, length: tools.Count);
+
+ // We allocate tool rankings in a contiguous chunk of memory, but partition them such that
+ // required tools come first and are immediately followed by optional tools.
+ // This allows us to separately rank optional tools by similarity score, but then later re-order
+ // the top N tools (including required tools) to preserve their original relative order.
+ var (requiredTools, optionalTools) = PartitionToolRankings(toolRankingInfoMemory, tools, options.ToolMode);
+
+ if (optionalTools.Length <= _toolLimit)
+ {
+ // There aren't enough optional tools to require reduction, so we'll return the original
+ // tools list.
+ return tools;
+ }
+
+ // Build query text from recent messages.
+ var queryText = await MessagesEmbeddingTextSelector(messages).ConfigureAwait(false);
+ if (string.IsNullOrWhiteSpace(queryText))
+ {
+ // We couldn't build a meaningful query, likely because the message list was empty.
+ // We'll just return the original tools list.
+ return tools;
+ }
+
+ var queryEmbedding = await _embeddingGenerator.GenerateAsync(queryText, cancellationToken: cancellationToken).ConfigureAwait(false);
+
+ // Compute and populate similarity scores in the tool ranking info.
+ await ComputeSimilarityScoresAsync(optionalTools, queryEmbedding, cancellationToken);
+
+ var topTools = toolRankingInfoMemory.Slice(start: 0, length: requiredTools.Length + _toolLimit);
+#if NET
+ optionalTools.Span.Sort(AIToolRankingInfo.CompareByDescendingSimilarityScore);
+ if (PreserveOriginalOrdering)
+ {
+ topTools.Span.Sort(AIToolRankingInfo.CompareByOriginalIndex);
+ }
+#else
+ Array.Sort(toolRankingInfoArray, index: requiredTools.Length, length: optionalTools.Length, AIToolRankingInfo.CompareByDescendingSimilarityScore);
+ if (PreserveOriginalOrdering)
+ {
+ Array.Sort(toolRankingInfoArray, index: 0, length: topTools.Length, AIToolRankingInfo.CompareByOriginalIndex);
+ }
+#endif
+ return ToToolList(topTools.Span);
+
+ static List ToToolList(ReadOnlySpan toolInfo)
+ {
+ var result = new List(capacity: toolInfo.Length);
+ foreach (var info in toolInfo)
+ {
+ result.Add(info.Tool);
+ }
+
+ return result;
+ }
+ }
+ finally
+ {
+ ArrayPool.Shared.Return(toolRankingInfoArray);
+ }
+ }
+
+ private (Memory RequiredTools, Memory OptionalTools) PartitionToolRankings(
+ Memory toolRankingInfo, IList tools, ChatToolMode? toolMode)
+ {
+ // Always include a tool if its name matches the required function name.
+ var requiredFunctionName = (toolMode as RequiredChatToolMode)?.RequiredFunctionName;
+ var nextRequiredToolIndex = 0;
+ var nextOptionalToolIndex = tools.Count - 1;
+ for (var i = 0; i < toolRankingInfo.Length; i++)
+ {
+ var tool = tools[i];
+ var isRequiredByToolMode = requiredFunctionName is not null && string.Equals(requiredFunctionName, tool.Name, StringComparison.Ordinal);
+ var toolIndex = isRequiredByToolMode || IsRequiredTool(tool)
+ ? nextRequiredToolIndex++
+ : nextOptionalToolIndex--;
+ toolRankingInfo.Span[toolIndex] = new AIToolRankingInfo(tool, originalIndex: i);
+ }
+
+ return (
+ RequiredTools: toolRankingInfo.Slice(0, nextRequiredToolIndex),
+ OptionalTools: toolRankingInfo.Slice(nextRequiredToolIndex));
+ }
+
+ private async Task ComputeSimilarityScoresAsync(Memory toolInfo, Embedding queryEmbedding, CancellationToken cancellationToken)
+ {
+ var anyCacheMisses = false;
+ List cacheMissToolEmbeddingTexts = null!;
+ List cacheMissToolInfoIndexes = null!;
+ for (var i = 0; i < toolInfo.Length; i++)
+ {
+ ref var info = ref toolInfo.Span[i];
+ if (_toolEmbeddingsCache.TryGetValue(info.Tool, out var toolEmbedding))
+ {
+ info.SimilarityScore = Similarity(queryEmbedding.Vector, toolEmbedding.Vector);
+ }
+ else
+ {
+ if (!anyCacheMisses)
+ {
+ anyCacheMisses = true;
+ cacheMissToolEmbeddingTexts = [];
+ cacheMissToolInfoIndexes = [];
+ }
+
+ var text = ToolEmbeddingTextSelector(info.Tool);
+ cacheMissToolEmbeddingTexts.Add(text);
+ cacheMissToolInfoIndexes.Add(i);
+ }
+ }
+
+ if (!anyCacheMisses)
+ {
+ // There were no cache misses; no more work to do.
+ return;
+ }
+
+ var uncachedEmbeddings = await _embeddingGenerator.GenerateAsync(cacheMissToolEmbeddingTexts, cancellationToken: cancellationToken).ConfigureAwait(false);
+ if (uncachedEmbeddings.Count != cacheMissToolEmbeddingTexts.Count)
+ {
+ throw new InvalidOperationException($"Expected {cacheMissToolEmbeddingTexts.Count} embeddings, got {uncachedEmbeddings.Count}.");
+ }
+
+ for (var i = 0; i < uncachedEmbeddings.Count; i++)
+ {
+ var toolInfoIndex = cacheMissToolInfoIndexes[i];
+ var toolEmbedding = uncachedEmbeddings[i];
+ ref var info = ref toolInfo.Span[toolInfoIndex];
+ info.SimilarityScore = Similarity(queryEmbedding.Vector, toolEmbedding.Vector);
+ _toolEmbeddingsCache.Add(info.Tool, toolEmbedding);
+ }
+ }
+
+ private struct AIToolRankingInfo(AITool tool, int originalIndex)
+ {
+ public static readonly Comparer CompareByDescendingSimilarityScore
+ = Comparer.Create(static (a, b) =>
+ {
+ var result = b.SimilarityScore.CompareTo(a.SimilarityScore);
+ return result != 0
+ ? result
+ : a.OriginalIndex.CompareTo(b.OriginalIndex); // Stabilize ties.
+ });
+
+ public static readonly Comparer CompareByOriginalIndex
+ = Comparer.Create(static (a, b) => a.OriginalIndex.CompareTo(b.OriginalIndex));
+
+ public AITool Tool { get; } = tool;
+ public int OriginalIndex { get; } = originalIndex;
+ public float SimilarityScore { get; set; }
+ }
+}
diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs
new file mode 100644
index 00000000000..6a5d6d925fc
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs
@@ -0,0 +1,89 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Shared.Diagnostics;
+
+namespace Microsoft.Extensions.AI;
+
+///
+/// A delegating chat client that applies a tool reduction strategy before invoking the inner client.
+///
+///
+/// Insert this into a pipeline (typically before function invocation middleware) to automatically
+/// reduce the tool list carried on for each request.
+///
+[Experimental("MEAI001")]
+public sealed class ToolReducingChatClient : DelegatingChatClient
+{
+ private readonly IToolReductionStrategy _strategy;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The inner client.
+ /// The tool reduction strategy to apply.
+ /// Thrown if any argument is .
+ public ToolReducingChatClient(IChatClient innerClient, IToolReductionStrategy strategy)
+ : base(innerClient)
+ {
+ _strategy = Throw.IfNull(strategy);
+ }
+
+ ///
+ public override async Task GetResponseAsync(
+ IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
+ {
+ options = await ApplyReductionAsync(messages, options, cancellationToken).ConfigureAwait(false);
+ return await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
+ }
+
+ ///
+ public override async IAsyncEnumerable GetStreamingResponseAsync(
+ IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ options = await ApplyReductionAsync(messages, options, cancellationToken).ConfigureAwait(false);
+
+ await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false))
+ {
+ yield return update;
+ }
+ }
+
+ private async Task ApplyReductionAsync(
+ IEnumerable messages,
+ ChatOptions? options,
+ CancellationToken cancellationToken)
+ {
+ // If there are no options or no tools, skip.
+ if (options?.Tools is not { Count: > 0 })
+ {
+ return options;
+ }
+
+ var reduced = await _strategy.SelectToolsForRequestAsync(messages, options, cancellationToken).ConfigureAwait(false);
+
+ // If strategy returned the same list instance (or reference equality), assume no change.
+ if (ReferenceEquals(reduced, options.Tools))
+ {
+ return options;
+ }
+
+ // Materialize and compare counts; if unchanged and tools have identical ordering and references, keep original.
+ if (reduced is not IList reducedList)
+ {
+ reducedList = reduced.ToList();
+ }
+
+ // Clone options to avoid mutating a possibly shared instance.
+ var cloned = options.Clone();
+ cloned.Tools = reducedList;
+ return cloned;
+ }
+}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs
index 448de8d11df..06a16abc2a4 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs
@@ -41,6 +41,8 @@ protected ChatClientIntegrationTests()
protected IChatClient? ChatClient { get; }
+ protected IEmbeddingGenerator>? EmbeddingGenerator { get; private set; }
+
public void Dispose()
{
ChatClient?.Dispose();
@@ -49,6 +51,13 @@ public void Dispose()
protected abstract IChatClient? CreateChatClient();
+ ///
+ /// Optionally supplies an embedding generator for integration tests that exercise
+ /// embedding-based components (e.g., tool reduction). Default returns null and
+ /// tests depending on embeddings will skip if not overridden.
+ ///
+ protected virtual IEmbeddingGenerator>? CreateEmbeddingGenerator() => null;
+
[ConditionalFact]
public virtual async Task GetResponseAsync_SingleRequestMessage()
{
@@ -1346,6 +1355,398 @@ public virtual async Task SummarizingChatReducer_CustomPrompt()
Assert.Contains("5", response.Text);
}
+ [ConditionalFact]
+ public virtual async Task ToolGrouping_LlmExpandsGroupAndInvokesTool_NonStreaming()
+ {
+ SkipIfNotEnabled();
+
+ const string TravelToken = "TRAVEL-PLAN-TOKEN-4826";
+
+ List> availableToolsPerInvocation = [];
+ string expansionFunctionName = "__expand_tool_group"; // Default expansion function name
+
+ var generateItinerary = AIFunctionFactory.Create(
+ (string city) => $"{TravelToken}::{city.ToUpperInvariant()}::DAY-PLAN",
+ new AIFunctionFactoryOptions
+ {
+ Name = "GenerateItinerary",
+ Description = "Produces a detailed itinerary token. Always repeat the returned token verbatim in your summary."
+ });
+
+ var summarizePacking = AIFunctionFactory.Create(
+ () => "Pack light layers and comfortable shoes.",
+ new AIFunctionFactoryOptions
+ {
+ Name = "PackingSummary",
+ Description = "Provides a short packing reminder."
+ });
+
+ using var client = ChatClient!
+ .AsBuilder()
+ .UseToolGrouping()
+ .Use((messages, options, next, cancellationToken) =>
+ {
+ if (options?.Tools is { Count: > 0 } tools)
+ {
+ availableToolsPerInvocation.Add([.. tools.Select(static tool => tool.Name)]);
+ }
+ else
+ {
+ availableToolsPerInvocation.Add([]);
+ }
+
+ return next(messages, options, cancellationToken);
+ })
+ .UseFunctionInvocation()
+ .Build();
+
+ List messages =
+ [
+ new(ChatRole.System, "You are a helpful assistant. Use the available tools to assist the user."),
+ new(ChatRole.User, "Plan a two-day cultural trip to Rome and tell me the itinerary token."),
+ ];
+
+ var response = await client.GetResponseAsync(messages, new ChatOptions
+ {
+ Tools = [AIToolGroup.Create("TravelUtilities", "Travel planning helpers that generate itinerary tokens.", [generateItinerary, summarizePacking])]
+ });
+
+ Assert.Contains(response.Messages, message =>
+ message.Role == ChatRole.Tool &&
+ message.Contents.OfType().Any(content =>
+ content.Result?.ToString()?.IndexOf("Successfully expanded group 'TravelUtilities'", StringComparison.OrdinalIgnoreCase) >= 0));
+
+ Assert.Contains(response.Messages, message =>
+ message.Contents.OfType().Any(content =>
+ string.Equals(content.Name, "GenerateItinerary", StringComparison.Ordinal)));
+
+ var nonStreamingText = response.Text ?? string.Empty;
+ Assert.Contains(TravelToken.ToUpperInvariant(), nonStreamingText.ToUpperInvariant());
+
+ Assert.NotEmpty(availableToolsPerInvocation);
+ var firstInvocationTools = availableToolsPerInvocation[0];
+ Assert.Contains(expansionFunctionName, firstInvocationTools);
+ Assert.DoesNotContain(generateItinerary.Name, firstInvocationTools);
+
+ var expandedInvocationIndex = availableToolsPerInvocation.FindIndex(tools => tools.Contains(generateItinerary.Name));
+ Assert.True(expandedInvocationIndex >= 0, "GenerateItinerary was never exposed to the model.");
+ Assert.True(expandedInvocationIndex > 0, "GenerateItinerary was visible before the expansion function executed.");
+ }
+
+ [ConditionalFact]
+ public virtual async Task ToolGrouping_LlmExpandsGroupAndInvokesTool_Streaming()
+ {
+ SkipIfNotEnabled();
+
+ const string LodgingToken = "LODGING-TOKEN-3895";
+
+ List> availableToolsPerInvocation = [];
+ string expansionFunctionName = "__expand_tool_group"; // Default expansion function name
+
+ var suggestLodging = AIFunctionFactory.Create(
+ (string city, string budget) => $"{LodgingToken}::{city.ToUpperInvariant()}::{budget}",
+ new AIFunctionFactoryOptions
+ {
+ Name = "SuggestLodging",
+ Description = "Returns hotel recommendations along with a lodging token. Repeat the token verbatim in your narrative."
+ });
+
+ using var client = ChatClient!
+ .AsBuilder()
+ .UseToolGrouping()
+ .Use((messages, options, next, cancellationToken) =>
+ {
+ if (options?.Tools is { Count: > 0 } tools)
+ {
+ availableToolsPerInvocation.Add([.. tools.Select(static tool => tool.Name)]);
+ }
+ else
+ {
+ availableToolsPerInvocation.Add([]);
+ }
+
+ return next(messages, options, cancellationToken);
+ })
+ .UseFunctionInvocation()
+ .Build();
+
+ List messages =
+ [
+ new(ChatRole.System, "You are a helpful assistant. Use the available tools to assist the user."),
+ new(ChatRole.User, "We're visiting Paris with a nightly budget of 150 USD. Stream the suggestions as they arrive and repeat the lodging token."),
+ ];
+
+ var response = await client.GetStreamingResponseAsync(messages, new ChatOptions
+ {
+ Tools = [AIToolGroup.Create("TravelUtilities", "Travel helpers used for lodging recommendations.", [suggestLodging])]
+ }).ToChatResponseAsync();
+
+ Assert.Contains(response.Messages, message =>
+ message.Role == ChatRole.Tool &&
+ message.Contents.OfType().Any(content =>
+ content.Result?.ToString()?.IndexOf("Successfully expanded group 'TravelUtilities'", StringComparison.OrdinalIgnoreCase) >= 0));
+
+ Assert.Contains(response.Messages, message =>
+ message.Contents.OfType().Any(content =>
+ string.Equals(content.Name, "SuggestLodging", StringComparison.Ordinal)));
+
+ var streamingText = response.Text ?? string.Empty;
+ Assert.Contains(LodgingToken.ToUpperInvariant(), streamingText.ToUpperInvariant());
+
+ Assert.NotEmpty(availableToolsPerInvocation);
+ var firstInvocationTools = availableToolsPerInvocation[0];
+ Assert.Contains(expansionFunctionName, firstInvocationTools);
+ Assert.DoesNotContain(suggestLodging.Name, firstInvocationTools);
+
+ var expandedInvocationIndex = availableToolsPerInvocation.FindIndex(tools => tools.Contains(suggestLodging.Name));
+ Assert.True(expandedInvocationIndex >= 0, "SuggestLodging was never exposed to the model.");
+ Assert.True(expandedInvocationIndex > 0, "SuggestLodging was visible before the expansion function executed.");
+ }
+
+ [ConditionalFact]
+ public virtual async Task ToolGrouping_NestedGroups_LlmExpandsHierarchyAndInvokesTool()
+ {
+ SkipIfNotEnabled();
+
+ const string BookingToken = "BOOKING-CONFIRMED-7291";
+
+ List> availableToolsPerInvocation = [];
+ string expansionFunctionName = "__expand_tool_group"; // Default expansion function name
+
+ // Leaf-level tools in nested groups
+ var bookFlight = AIFunctionFactory.Create(
+ (string origin, string destination, string date) => $"{BookingToken}::FLIGHT::{origin}-{destination}::{date}",
+ new AIFunctionFactoryOptions
+ {
+ Name = "BookFlight",
+ Description = "Books a flight and returns a booking confirmation token. Always repeat the token verbatim."
+ });
+
+ var bookHotel = AIFunctionFactory.Create(
+ (string city, string checkIn) => $"{BookingToken}::HOTEL::{city}::{checkIn}",
+ new AIFunctionFactoryOptions
+ {
+ Name = "BookHotel",
+ Description = "Books a hotel and returns a booking confirmation token. Always repeat the token verbatim."
+ });
+
+ var getCurrency = AIFunctionFactory.Create(
+ (string country) => $"The currency in {country} is EUR.",
+ new AIFunctionFactoryOptions
+ {
+ Name = "GetCurrency",
+ Description = "Gets currency information for a country."
+ });
+
+ // Create nested group structure: TravelServices -> Booking -> FlightBooking
+ var flightBookingGroup = AIToolGroup.Create("FlightBooking", "Flight booking services", [bookFlight]);
+ var bookingGroup = AIToolGroup.Create("Booking", "All booking services including flights and hotels", [flightBookingGroup, bookHotel]);
+ var travelServicesGroup = AIToolGroup.Create("TravelServices", "Complete travel services including booking and information", [bookingGroup, getCurrency]);
+
+ using var client = ChatClient!
+ .AsBuilder()
+ .UseToolGrouping(options =>
+ {
+ options.MaxExpansionsPerRequest = 5;
+ })
+ .Use((messages, options, next, cancellationToken) =>
+ {
+ if (options?.Tools is { Count: > 0 } tools)
+ {
+ availableToolsPerInvocation.Add([.. tools.Select(static tool => tool.Name)]);
+ }
+ else
+ {
+ availableToolsPerInvocation.Add([]);
+ }
+
+ return next(messages, options, cancellationToken);
+ })
+ .UseFunctionInvocation()
+ .Build();
+
+ List messages =
+ [
+ new(ChatRole.System, "You are a helpful assistant. Use the available tools to assist the user. Explore nested groups to find the right tool."),
+ new(ChatRole.User, "I need to book a flight from Seattle to Paris for December 15th, 2025. Please provide the booking token."),
+ ];
+
+ var response = await client.GetResponseAsync(messages, new ChatOptions
+ {
+ Tools = [travelServicesGroup]
+ });
+
+ // Verify the nested group expansions occurred
+ Assert.Contains(response.Messages, message =>
+ message.Role == ChatRole.Tool &&
+ message.Contents.OfType().Any(content =>
+ content.Result?.ToString()?.IndexOf("Successfully expanded group 'TravelServices'", StringComparison.OrdinalIgnoreCase) >= 0));
+
+ Assert.Contains(response.Messages, message =>
+ message.Role == ChatRole.Tool &&
+ message.Contents.OfType().Any(content =>
+ content.Result?.ToString()?.IndexOf("Successfully expanded group 'Booking'", StringComparison.OrdinalIgnoreCase) >= 0));
+
+ Assert.Contains(response.Messages, message =>
+ message.Role == ChatRole.Tool &&
+ message.Contents.OfType().Any(content =>
+ content.Result?.ToString()?.IndexOf("Successfully expanded group 'FlightBooking'", StringComparison.OrdinalIgnoreCase) >= 0));
+
+ // Verify the actual tool was invoked
+ Assert.Contains(response.Messages, message =>
+ message.Contents.OfType().Any(content =>
+ string.Equals(content.Name, "BookFlight", StringComparison.Ordinal)));
+
+ // Verify the booking token appears in the response
+ var responseText = response.Text ?? string.Empty;
+ Assert.Contains(BookingToken, responseText);
+ Assert.Contains("SEATTLE", responseText.ToUpperInvariant());
+ Assert.Contains("PARIS", responseText.ToUpperInvariant());
+
+ // Verify progressive expansion: first only expansion function, then groups appear, finally leaf tools
+ Assert.NotEmpty(availableToolsPerInvocation);
+ var firstInvocationTools = availableToolsPerInvocation[0];
+ Assert.Contains(expansionFunctionName, firstInvocationTools);
+ Assert.DoesNotContain("BookFlight", firstInvocationTools);
+ Assert.DoesNotContain("Booking", firstInvocationTools);
+
+ // Verify BookFlight was not available until after all necessary expansions
+ var bookFlightInvocationIndex = availableToolsPerInvocation.FindIndex(tools => tools.Contains("BookFlight"));
+ Assert.True(bookFlightInvocationIndex >= 0, "BookFlight was never exposed to the model.");
+ Assert.True(bookFlightInvocationIndex > 2, "BookFlight was visible before completing the nested expansion hierarchy.");
+ }
+
+ [ConditionalFact]
+ public virtual async Task ToolGrouping_MultipleNestedGroups_LlmSelectsCorrectPathAndInvokesTool()
+ {
+ SkipIfNotEnabled();
+
+ const string DiagnosticToken = "DIAGNOSTIC-REPORT-5483";
+
+ List> availableToolsPerInvocation = [];
+ string expansionFunctionName = "__expand_tool_group"; // Default expansion function name
+
+ // Healthcare nested tools
+ var runBloodTest = AIFunctionFactory.Create(
+ (string patientId) => $"{DiagnosticToken}::BLOOD-TEST::{patientId}::COMPLETE",
+ new AIFunctionFactoryOptions
+ {
+ Name = "RunBloodTest",
+ Description = "Orders a blood test and returns a diagnostic token. Always repeat the token verbatim."
+ });
+
+ var scheduleXRay = AIFunctionFactory.Create(
+ (string patientId, string bodyPart) => $"X-Ray scheduled for {bodyPart}",
+ new AIFunctionFactoryOptions
+ {
+ Name = "ScheduleXRay",
+ Description = "Schedules an X-ray appointment."
+ });
+
+ // Financial nested tools
+ var processPayment = AIFunctionFactory.Create(
+ (string accountId, decimal amount) => $"Payment of ${amount} processed for account {accountId}",
+ new AIFunctionFactoryOptions
+ {
+ Name = "ProcessPayment",
+ Description = "Processes a payment transaction."
+ });
+
+ var generateInvoice = AIFunctionFactory.Create(
+ (string customerId) => $"Invoice generated for customer {customerId}",
+ new AIFunctionFactoryOptions
+ {
+ Name = "GenerateInvoice",
+ Description = "Generates an invoice document."
+ });
+
+ // Create two separate nested hierarchies
+ var diagnosticsGroup = AIToolGroup.Create("Diagnostics", "Medical diagnostic services", [runBloodTest]);
+ var imagingGroup = AIToolGroup.Create("Imaging", "Medical imaging services", [scheduleXRay]);
+ var healthcareGroup = AIToolGroup.Create("Healthcare", "All healthcare services including diagnostics and imaging", [diagnosticsGroup, imagingGroup]);
+
+ var paymentsGroup = AIToolGroup.Create("Payments", "Payment processing services", [processPayment]);
+ var billingGroup = AIToolGroup.Create("Billing", "Billing and invoicing services", [generateInvoice]);
+ var financialGroup = AIToolGroup.Create("Financial", "All financial services including payments and billing", [paymentsGroup, billingGroup]);
+
+ using var client = ChatClient!
+ .AsBuilder()
+ .UseToolGrouping(options =>
+ {
+ options.MaxExpansionsPerRequest = 5;
+ })
+ .Use((messages, options, next, cancellationToken) =>
+ {
+ if (options?.Tools is { Count: > 0 } tools)
+ {
+ availableToolsPerInvocation.Add([.. tools.Select(static tool => tool.Name)]);
+ }
+ else
+ {
+ availableToolsPerInvocation.Add([]);
+ }
+
+ return next(messages, options, cancellationToken);
+ })
+ .UseFunctionInvocation()
+ .Build();
+
+ List messages =
+ [
+ new(ChatRole.System, "You are a helpful assistant. Use the available tools to assist the user. Navigate through nested groups to find the appropriate tool."),
+ new(ChatRole.User, "Patient P-42 needs a blood test ordered. Please provide the diagnostic token."),
+ ];
+
+ var response = await client.GetResponseAsync(messages, new ChatOptions
+ {
+ Tools = [healthcareGroup, financialGroup]
+ });
+
+ // Verify the correct nested path was taken (Healthcare -> Diagnostics)
+ Assert.Contains(response.Messages, message =>
+ message.Role == ChatRole.Tool &&
+ message.Contents.OfType().Any(content =>
+ content.Result?.ToString()?.IndexOf("Successfully expanded group 'Healthcare'", StringComparison.OrdinalIgnoreCase) >= 0));
+
+ Assert.Contains(response.Messages, message =>
+ message.Role == ChatRole.Tool &&
+ message.Contents.OfType().Any(content =>
+ content.Result?.ToString()?.IndexOf("Successfully expanded group 'Diagnostics'", StringComparison.OrdinalIgnoreCase) >= 0));
+
+ // Verify the incorrect path was NOT taken (Financial group should not be expanded)
+ Assert.DoesNotContain(response.Messages, message =>
+ message.Role == ChatRole.Tool &&
+ message.Contents.OfType().Any(content =>
+ content.Result?.ToString()?.IndexOf("Successfully expanded group 'Financial'", StringComparison.OrdinalIgnoreCase) >= 0));
+
+ // Verify the correct leaf tool was invoked
+ Assert.Contains(response.Messages, message =>
+ message.Contents.OfType().Any(content =>
+ string.Equals(content.Name, "RunBloodTest", StringComparison.Ordinal)));
+
+ // Verify wrong tools were not invoked
+ Assert.DoesNotContain(response.Messages, message =>
+ message.Contents.OfType().Any(content =>
+ string.Equals(content.Name, "ProcessPayment", StringComparison.Ordinal) ||
+ string.Equals(content.Name, "GenerateInvoice", StringComparison.Ordinal)));
+
+ // Verify the diagnostic token appears in the response
+ var responseText = response.Text ?? string.Empty;
+ Assert.Contains(DiagnosticToken, responseText);
+ Assert.Contains("P-42", responseText);
+
+ // Verify progressive expansion behavior
+ Assert.NotEmpty(availableToolsPerInvocation);
+ var firstInvocationTools = availableToolsPerInvocation[0];
+ Assert.Contains(expansionFunctionName, firstInvocationTools);
+ Assert.DoesNotContain("RunBloodTest", firstInvocationTools);
+
+ // Verify RunBloodTest only became available after the correct nested expansions
+ var runBloodTestInvocationIndex = availableToolsPerInvocation.FindIndex(tools => tools.Contains("RunBloodTest"));
+ Assert.True(runBloodTestInvocationIndex >= 0, "RunBloodTest was never exposed to the model.");
+ Assert.True(runBloodTestInvocationIndex > 1, "RunBloodTest was visible before completing necessary nested expansions.");
+ }
+
private sealed class TestSummarizingChatClient : IChatClient
{
private IChatClient _summarizerChatClient;
@@ -1395,6 +1796,343 @@ public void Dispose()
}
}
+ [ConditionalFact]
+ public virtual async Task ToolReduction_DynamicSelection_RespectsConversationHistory()
+ {
+ SkipIfNotEnabled();
+ EnsureEmbeddingGenerator();
+
+ // Limit to 2 so that, once the conversation references both weather and translation,
+ // both tools can be included even if the latest user turn only mentions one of them.
+ var strategy = new EmbeddingToolReductionStrategy(EmbeddingGenerator, toolLimit: 2);
+
+ var weatherTool = AIFunctionFactory.Create(
+ () => "Weather data",
+ new AIFunctionFactoryOptions
+ {
+ Name = "GetWeatherForecast",
+ Description = "Returns weather forecast and temperature for a given city."
+ });
+
+ var translateTool = AIFunctionFactory.Create(
+ () => "Translated text",
+ new AIFunctionFactoryOptions
+ {
+ Name = "TranslateText",
+ Description = "Translates text between human languages."
+ });
+
+ var mathTool = AIFunctionFactory.Create(
+ () => 42,
+ new AIFunctionFactoryOptions
+ {
+ Name = "SolveMath",
+ Description = "Solves basic math problems."
+ });
+
+ var allTools = new List { weatherTool, translateTool, mathTool };
+
+ IList? firstTurnTools = null;
+ IList? secondTurnTools = null;
+
+ using var client = ChatClient!
+ .AsBuilder()
+ .UseToolReduction(strategy)
+ .Use(async (messages, options, next, ct) =>
+ {
+ // Capture the (possibly reduced) tool list for each turn.
+ if (firstTurnTools is null)
+ {
+ firstTurnTools = options?.Tools;
+ }
+ else
+ {
+ secondTurnTools ??= options?.Tools;
+ }
+
+ await next(messages, options, ct);
+ })
+ .UseFunctionInvocation()
+ .Build();
+
+ // Maintain chat history across turns.
+ List history = [];
+
+ // Turn 1: Ask a weather question.
+ history.Add(new ChatMessage(ChatRole.User, "What will the weather be in Seattle tomorrow?"));
+ var firstResponse = await client.GetResponseAsync(history, new ChatOptions { Tools = allTools });
+ history.AddMessages(firstResponse); // Append assistant reply.
+
+ Assert.NotNull(firstTurnTools);
+ Assert.Contains(firstTurnTools, t => t.Name == "GetWeatherForecast");
+
+ // Turn 2: Ask a translation question. Even though only translation is mentioned now,
+ // conversation history still contains a weather request. Expect BOTH weather + translation tools.
+ history.Add(new ChatMessage(ChatRole.User, "Please translate 'good evening' into French."));
+ var secondResponse = await client.GetResponseAsync(history, new ChatOptions { Tools = allTools });
+ history.AddMessages(secondResponse);
+
+ Assert.NotNull(secondTurnTools);
+ Assert.Equal(2, secondTurnTools.Count); // Should have filled both slots with the two relevant domains.
+ Assert.Contains(secondTurnTools, t => t.Name == "GetWeatherForecast");
+ Assert.Contains(secondTurnTools, t => t.Name == "TranslateText");
+
+ // Ensure unrelated tool was excluded.
+ Assert.DoesNotContain(secondTurnTools, t => t.Name == "SolveMath");
+ }
+
+ [ConditionalFact]
+ public virtual async Task ToolReduction_RequireSpecificToolPreservedAndOrdered()
+ {
+ SkipIfNotEnabled();
+ EnsureEmbeddingGenerator();
+
+ // Limit would normally reduce to 1, but required tool plus another should remain.
+ var strategy = new EmbeddingToolReductionStrategy(EmbeddingGenerator, toolLimit: 1);
+
+ var translateTool = AIFunctionFactory.Create(
+ () => "Translated text",
+ new AIFunctionFactoryOptions
+ {
+ Name = "TranslateText",
+ Description = "Translates phrases between languages."
+ });
+
+ var weatherTool = AIFunctionFactory.Create(
+ () => "Weather data",
+ new AIFunctionFactoryOptions
+ {
+ Name = "GetWeatherForecast",
+ Description = "Returns forecast data for a city."
+ });
+
+ var tools = new List { translateTool, weatherTool };
+
+ IList? captured = null;
+
+ using var client = ChatClient!
+ .AsBuilder()
+ .UseToolReduction(strategy)
+ .UseFunctionInvocation()
+ .Use((messages, options, next, ct) =>
+ {
+ captured = options?.Tools;
+ return next(messages, options, ct);
+ })
+ .Build();
+
+ var history = new List
+ {
+ new(ChatRole.User, "What will the weather be like in Redmond next week?")
+ };
+
+ var response = await client.GetResponseAsync(history, new ChatOptions
+ {
+ Tools = tools,
+ ToolMode = ChatToolMode.RequireSpecific(translateTool.Name)
+ });
+ history.AddMessages(response);
+
+ Assert.NotNull(captured);
+ Assert.Equal(2, captured!.Count);
+ Assert.Equal("TranslateText", captured[0].Name); // Required should appear first.
+ Assert.Equal("GetWeatherForecast", captured[1].Name);
+ }
+
+ [ConditionalFact]
+ public virtual async Task ToolReduction_ToolRemovedAfterFirstUse_NotInvokedAgain()
+ {
+ SkipIfNotEnabled();
+ EnsureEmbeddingGenerator();
+
+ int weatherInvocationCount = 0;
+
+ var weatherTool = AIFunctionFactory.Create(
+ () =>
+ {
+ weatherInvocationCount++;
+ return "Sunny and dry.";
+ },
+ new AIFunctionFactoryOptions
+ {
+ Name = "GetWeather",
+ Description = "Gets the weather forecast for a given location."
+ });
+
+ // Strategy exposes tools only on the first request, then removes them.
+ var removalStrategy = new RemoveToolAfterFirstUseStrategy();
+
+ IList? firstTurnTools = null;
+ IList? secondTurnTools = null;
+
+ using var client = ChatClient!
+ .AsBuilder()
+ // Place capture immediately after reduction so it's invoked exactly once per user request.
+ .UseToolReduction(removalStrategy)
+ .Use((messages, options, next, ct) =>
+ {
+ if (firstTurnTools is null)
+ {
+ firstTurnTools = options?.Tools;
+ }
+ else
+ {
+ secondTurnTools ??= options?.Tools;
+ }
+
+ return next(messages, options, ct);
+ })
+ .UseFunctionInvocation()
+ .Build();
+
+ List history = [];
+
+ // Turn 1
+ history.Add(new ChatMessage(ChatRole.User, "What's the weather like tomorrow in Seattle?"));
+ var firstResponse = await client.GetResponseAsync(history, new ChatOptions
+ {
+ Tools = [weatherTool],
+ ToolMode = ChatToolMode.RequireAny
+ });
+ history.AddMessages(firstResponse);
+
+ Assert.Equal(1, weatherInvocationCount);
+ Assert.NotNull(firstTurnTools);
+ Assert.Contains(firstTurnTools!, t => t.Name == "GetWeather");
+
+ // Turn 2 (tool removed by strategy even though caller supplies it again)
+ history.Add(new ChatMessage(ChatRole.User, "And what about next week?"));
+ var secondResponse = await client.GetResponseAsync(history, new ChatOptions
+ {
+ Tools = [weatherTool]
+ });
+ history.AddMessages(secondResponse);
+
+ Assert.Equal(1, weatherInvocationCount); // Not invoked again.
+ Assert.NotNull(secondTurnTools);
+ Assert.Empty(secondTurnTools!); // Strategy removed the tool set.
+
+ // Response text shouldn't just echo the tool's stub output.
+ Assert.DoesNotContain("Sunny and dry.", secondResponse.Text, StringComparison.OrdinalIgnoreCase);
+ }
+
+ [ConditionalFact]
+ public virtual async Task ToolReduction_MessagesEmbeddingTextSelector_UsesChatClientToAnalyzeConversation()
+ {
+ SkipIfNotEnabled();
+ EnsureEmbeddingGenerator();
+
+ // Create tools for different domains.
+ var weatherTool = AIFunctionFactory.Create(
+ () => "Weather data",
+ new AIFunctionFactoryOptions
+ {
+ Name = "GetWeatherForecast",
+ Description = "Returns weather forecast and temperature for a given city."
+ });
+
+ var translateTool = AIFunctionFactory.Create(
+ () => "Translated text",
+ new AIFunctionFactoryOptions
+ {
+ Name = "TranslateText",
+ Description = "Translates text between human languages."
+ });
+
+ var mathTool = AIFunctionFactory.Create(
+ () => 42,
+ new AIFunctionFactoryOptions
+ {
+ Name = "SolveMath",
+ Description = "Solves basic math problems."
+ });
+
+ var allTools = new List { weatherTool, translateTool, mathTool };
+
+ // Track the analysis result from the chat client used in the selector.
+ string? capturedAnalysis = null;
+
+ var strategy = new EmbeddingToolReductionStrategy(EmbeddingGenerator, toolLimit: 2)
+ {
+ // Use a chat client to analyze the conversation and extract relevant tool categories.
+ MessagesEmbeddingTextSelector = async messages =>
+ {
+ var conversationText = string.Join("\n", messages.Select(m => $"{m.Role}: {m.Text}"));
+
+ var analysisPrompt = $"""
+ Analyze the following conversation and identify what kinds of tools would be most helpful.
+ Focus on the key topics and tasks being discussed.
+ Respond with a brief summary of the relevant tool categories (e.g., "weather", "translation", "math").
+
+ Conversation:
+ {conversationText}
+
+ Relevant tool categories:
+ """;
+
+ var response = await ChatClient!.GetResponseAsync(analysisPrompt);
+ capturedAnalysis = response.Text;
+
+ // Return the analysis as the query text for embedding-based tool selection.
+ return capturedAnalysis;
+ }
+ };
+
+ IList? selectedTools = null;
+
+ using var client = ChatClient!
+ .AsBuilder()
+ .UseToolReduction(strategy)
+ .Use(async (messages, options, next, ct) =>
+ {
+ selectedTools = options?.Tools;
+ await next(messages, options, ct);
+ })
+ .UseFunctionInvocation()
+ .Build();
+
+ // Conversation that clearly indicates weather-related needs.
+ List history = [];
+ history.Add(new ChatMessage(ChatRole.User, "What will the weather be like in London tomorrow?"));
+
+ var response = await client.GetResponseAsync(history, new ChatOptions { Tools = allTools });
+ history.AddMessages(response);
+
+ // Verify that the chat client was used to analyze the conversation.
+ Assert.NotNull(capturedAnalysis);
+ Assert.True(
+ capturedAnalysis.IndexOf("weather", StringComparison.OrdinalIgnoreCase) >= 0 ||
+ capturedAnalysis.IndexOf("forecast", StringComparison.OrdinalIgnoreCase) >= 0,
+ $"Expected analysis to mention weather or forecast: {capturedAnalysis}");
+
+ // Verify that the tool selection was influenced by the analysis.
+ Assert.NotNull(selectedTools);
+ Assert.True(selectedTools.Count <= 2, $"Expected at most 2 tools, got {selectedTools.Count}");
+ Assert.Contains(selectedTools, t => t.Name == "GetWeatherForecast");
+ }
+
+ // Test-only custom strategy: include tools on first request, then remove them afterward.
+ private sealed class RemoveToolAfterFirstUseStrategy : IToolReductionStrategy
+ {
+ private bool _used;
+
+ public Task> SelectToolsForRequestAsync(
+ IEnumerable messages,
+ ChatOptions? options,
+ CancellationToken cancellationToken = default)
+ {
+ if (!_used && options?.Tools is { Count: > 0 })
+ {
+ _used = true;
+ // Returning the same instance signals no change.
+ return Task.FromResult>(options.Tools);
+ }
+
+ // After first use, remove all tools.
+ return Task.FromResult>(Array.Empty());
+ }
+ }
+
[MemberNotNull(nameof(ChatClient))]
protected void SkipIfNotEnabled()
{
@@ -1405,4 +2143,15 @@ protected void SkipIfNotEnabled()
throw new SkipTestException("Client is not enabled.");
}
}
+
+ [MemberNotNull(nameof(EmbeddingGenerator))]
+ protected void EnsureEmbeddingGenerator()
+ {
+ EmbeddingGenerator ??= CreateEmbeddingGenerator();
+
+ if (EmbeddingGenerator is null)
+ {
+ throw new SkipTestException("Embedding generator is not enabled.");
+ }
+ }
}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolGroupingTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolGroupingTests.cs
new file mode 100644
index 00000000000..7352814e91c
--- /dev/null
+++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolGroupingTests.cs
@@ -0,0 +1,1301 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace Microsoft.Extensions.AI;
+
+public class ToolGroupingTests
+{
+ private const string DefaultExpansionFunctionName = "__expand_tool_group";
+ private const string DefaultListGroupsFunctionName = "__list_tool_groups";
+
+ [Fact]
+ public async Task ToolGroupingChatClient_Collapsed_IncludesUtilityAndUngroupedToolsOnly()
+ {
+ var ungrouped = new SimpleTool("Basic", "basic");
+ var groupedA = new SimpleTool("A1", "a1");
+ var groupedB = new SimpleTool("B1", "b1");
+
+ ToolGroupingTestScenario CreateScenario(List?> observedTools) => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "hello")],
+ Options = new ChatOptions { Tools = [ungrouped, AIToolGroup.Create("GroupA", "Group A", [groupedA]), AIToolGroup.Create("GroupB", "Group B", [groupedB])] },
+ ConfigureToolGroupingOptions = options => { },
+ Turns =
+ [
+ new DownstreamTurn
+ {
+ ResponseMessage = new(ChatRole.Assistant, "Hi"),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ }
+ ]
+ };
+
+ List?> observedNonStreaming = [];
+ List?> observedStreaming = [];
+
+ var result = await InvokeAndAssertAsync(CreateScenario(observedNonStreaming));
+ var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario(observedStreaming));
+
+ void AssertResponse(ToolGroupingTestResult testResult) => Assert.Equal("Hi", testResult.Response.Text);
+
+ AssertResponse(result);
+ AssertResponse(streamingResult);
+
+ void AssertObservedTools(List?> observedTools)
+ {
+ var tools = Assert.Single(observedTools);
+ Assert.NotNull(tools);
+ Assert.Contains(tools!, t => t.Name == ungrouped.Name);
+ Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName);
+ Assert.Contains(tools, t => t.Name == DefaultListGroupsFunctionName);
+ Assert.DoesNotContain(tools, t => t.Name == groupedA.Name);
+ Assert.DoesNotContain(tools, t => t.Name == groupedB.Name);
+ }
+
+ AssertObservedTools(observedNonStreaming);
+ AssertObservedTools(observedStreaming);
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_ExpansionLoop_ExpandsSingleGroup()
+ {
+ var groupedA1 = new SimpleTool("A1", "a1");
+ var groupedA2 = new SimpleTool("A2", "a2");
+ var groupedB = new SimpleTool("B1", "b1");
+ var ungrouped = new SimpleTool("Common", "c");
+
+ ToolGroupingTestScenario CreateScenario(List?> observedTools) => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "go")],
+ Options = new ChatOptions { Tools = [ungrouped, AIToolGroup.Create("GroupA", "Group A", [groupedA1, groupedA2]), AIToolGroup.Create("GroupB", "Group B", [groupedB])] },
+ ConfigureToolGroupingOptions = options =>
+ {
+ options.MaxExpansionsPerRequest = 1;
+ },
+ Turns =
+ [
+ new DownstreamTurn
+ {
+ ResponseMessage = CreateExpansionCall("call1", "GroupA"),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ },
+ new DownstreamTurn
+ {
+ ResponseMessage = new(ChatRole.Assistant, "Done"),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ }
+ ]
+ };
+
+ List?> observedNonStreaming = [];
+ List?> observedStreaming = [];
+
+ var result = await InvokeAndAssertAsync(CreateScenario(observedNonStreaming));
+ var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario(observedStreaming));
+
+ void AssertResponse(ToolGroupingTestResult testResult) => Assert.Equal("Done", testResult.Response.Text);
+
+ AssertResponse(result);
+ AssertResponse(streamingResult);
+
+ void AssertObservedTools(List?> observed) => Assert.Collection(observed,
+ tools =>
+ {
+ Assert.NotNull(tools);
+ Assert.Contains(tools!, t => t.Name == ungrouped.Name);
+ Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName);
+ Assert.DoesNotContain(tools, t => t.Name == groupedA1.Name);
+ Assert.DoesNotContain(tools, t => t.Name == groupedB.Name);
+ },
+ tools =>
+ {
+ Assert.NotNull(tools);
+ Assert.Contains(tools!, t => t.Name == ungrouped.Name);
+ Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName);
+ Assert.Contains(tools, t => t.Name == groupedA1.Name);
+ Assert.Contains(tools, t => t.Name == groupedA2.Name);
+ Assert.DoesNotContain(tools, t => t.Name == groupedB.Name);
+ });
+
+ AssertObservedTools(observedNonStreaming);
+ AssertObservedTools(observedStreaming);
+
+ AssertContainsResultMessage(result.Response, "Successfully expanded group 'GroupA'");
+ AssertContainsResultMessage(streamingResult.Response, "Successfully expanded group 'GroupA'");
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_NoGroups_BypassesMiddleware()
+ {
+ var tool = new SimpleTool("Standalone", "s");
+
+ ToolGroupingTestScenario CreateScenario(ChatOptions options, List observedOptions, List?> observedTools) => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "hello")],
+ Options = options,
+ ConfigureToolGroupingOptions = _ => { },
+ Turns =
+ [
+ new DownstreamTurn
+ {
+ ResponseMessage = new(ChatRole.Assistant, "ok"),
+ AssertInvocation = ctx =>
+ {
+ observedOptions.Add(ctx.Options);
+ observedTools.Add(ctx.Options?.Tools?.ToList());
+ }
+ }
+ ]
+ };
+
+ List observedOptionsNonStreaming = [];
+ List?> observedToolsNonStreaming = [];
+ ChatOptions nonStreamingOptions = new() { Tools = [tool] };
+ var result = await InvokeAndAssertAsync(CreateScenario(nonStreamingOptions, observedOptionsNonStreaming, observedToolsNonStreaming));
+
+ List observedOptionsStreaming = [];
+ List?> observedToolsStreaming = [];
+ ChatOptions streamingOptions = new() { Tools = [tool] };
+ var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario(streamingOptions, observedOptionsStreaming, observedToolsStreaming));
+
+ void AssertResponse(ToolGroupingTestResult testResult) => Assert.Equal("ok", testResult.Response.Text);
+
+ AssertResponse(result);
+ AssertResponse(streamingResult);
+
+ static void AssertObservedOptions(ChatOptions expected, List observed) =>
+ Assert.Same(expected, Assert.Single(observed));
+
+ static void AssertObservedTools(List?> observed)
+ {
+ var tools = Assert.Single(observed);
+ Assert.NotNull(tools);
+ Assert.DoesNotContain(tools!, t => t.Name == DefaultExpansionFunctionName);
+ Assert.DoesNotContain(tools!, t => t.Name == DefaultListGroupsFunctionName);
+ }
+
+ AssertObservedOptions(nonStreamingOptions, observedOptionsNonStreaming);
+ AssertObservedOptions(streamingOptions, observedOptionsStreaming);
+
+ AssertObservedTools(observedToolsNonStreaming);
+ AssertObservedTools(observedToolsStreaming);
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_InvalidGroupRequest_ReturnsResultMessage()
+ {
+ var groupedA = new SimpleTool("A1", "a1");
+
+ ToolGroupingTestScenario CreateScenario() => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "go")],
+ Options = new ChatOptions { Tools = [AIToolGroup.Create("GroupA", "Group A", [groupedA])] },
+ ConfigureToolGroupingOptions = options =>
+ {
+ options.MaxExpansionsPerRequest = 2;
+ },
+ Turns =
+ [
+ new DownstreamTurn
+ {
+ ResponseMessage = new ChatMessage(ChatRole.Assistant,
+ [new FunctionCallContent(Guid.NewGuid().ToString("N"), DefaultExpansionFunctionName, new Dictionary { ["groupName"] = "Unknown" })])
+ },
+ new DownstreamTurn
+ {
+ ResponseMessage = new(ChatRole.Assistant, "Oops!"),
+ }
+ ]
+ };
+
+ var result = await InvokeAndAssertAsync(CreateScenario());
+ var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario());
+
+ void AssertResponse(ToolGroupingTestResult testResult) =>
+ AssertContainsResultMessage(testResult.Response, "was invalid; ignoring expansion request");
+
+ AssertResponse(result);
+ AssertResponse(streamingResult);
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_MissingGroupName_ReturnsNotice()
+ {
+ var groupedA = new SimpleTool("A1", "a1");
+
+ ToolGroupingTestScenario CreateScenario() => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "go")],
+ Options = new ChatOptions { Tools = [AIToolGroup.Create("GroupA", "Group A", [groupedA])] },
+ ConfigureToolGroupingOptions = options =>
+ {
+ options.MaxExpansionsPerRequest = 2;
+ },
+ Turns =
+ [
+ new DownstreamTurn
+ {
+ ResponseMessage = new ChatMessage(ChatRole.Assistant,
+ [new FunctionCallContent(Guid.NewGuid().ToString("N"), DefaultExpansionFunctionName)])
+ },
+ new DownstreamTurn
+ {
+ ResponseMessage = new(ChatRole.Assistant, "Oops!"),
+ }
+ ]
+ };
+
+ var result = await InvokeAndAssertAsync(CreateScenario());
+ var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario());
+
+ void AssertResponse(ToolGroupingTestResult testResult) =>
+ AssertContainsResultMessage(testResult.Response, "No group name was specified");
+
+ AssertResponse(result);
+ AssertResponse(streamingResult);
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_GroupNameReadsJsonElement()
+ {
+ var groupedA = new SimpleTool("A1", "a1");
+ var jsonValue = JsonDocument.Parse("\"GroupA\"").RootElement;
+
+ ToolGroupingTestScenario CreateScenario() => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "go")],
+ Options = new ChatOptions { Tools = [AIToolGroup.Create("GroupA", "Group A", [groupedA])] },
+ ConfigureToolGroupingOptions = _ => { },
+ Turns =
+ [
+ new DownstreamTurn
+ {
+ ResponseMessage = new ChatMessage(ChatRole.Assistant,
+ [new FunctionCallContent(Guid.NewGuid().ToString("N"), DefaultExpansionFunctionName, new Dictionary { ["groupName"] = jsonValue })])
+ },
+ new DownstreamTurn
+ {
+ ResponseMessage = new(ChatRole.Assistant, "done")
+ }
+ ]
+ };
+
+ var result = await InvokeAndAssertAsync(CreateScenario());
+ var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario());
+
+ void AssertResponse(ToolGroupingTestResult testResult) =>
+ AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'GroupA'");
+
+ AssertResponse(result);
+ AssertResponse(streamingResult);
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_MultipleValidExpansions_LastWins()
+ {
+ var groupedA = new SimpleTool("A1", "a1");
+ var groupedB = new SimpleTool("B1", "b1");
+ var alwaysOn = new SimpleTool("Common", "c");
+
+ ToolGroupingTestScenario CreateScenario(List?> observedTools) => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "go")],
+ Options = new ChatOptions { Tools = [alwaysOn, AIToolGroup.Create("GroupA", "Group A", [groupedA]), AIToolGroup.Create("GroupB", "Group B", [groupedB])] },
+ ConfigureToolGroupingOptions = options =>
+ {
+ options.MaxExpansionsPerRequest = 2;
+ },
+ Turns =
+ [
+ new DownstreamTurn
+ {
+ ResponseMessage = new ChatMessage(ChatRole.Assistant,
+ [
+ new FunctionCallContent("call1", DefaultExpansionFunctionName, new Dictionary { ["groupName"] = "GroupA" }),
+ new FunctionCallContent("call2", DefaultExpansionFunctionName, new Dictionary { ["groupName"] = "GroupB" })
+ ]),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ },
+ new DownstreamTurn
+ {
+ ResponseMessage = new(ChatRole.Assistant, "done"),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ }
+ ]
+ };
+
+ List?> observedToolsNonStreaming = [];
+ var result = await InvokeAndAssertAsync(CreateScenario(observedToolsNonStreaming));
+
+ List?> observedToolsStreaming = [];
+ var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario(observedToolsStreaming));
+
+ void AssertObservedTools(List?> observed) => Assert.Collection(observed,
+ tools =>
+ {
+ Assert.NotNull(tools);
+ Assert.Contains(tools!, t => t.Name == alwaysOn.Name);
+ Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName);
+ },
+ tools =>
+ {
+ Assert.NotNull(tools);
+ Assert.Contains(tools!, t => t.Name == alwaysOn.Name);
+ Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName);
+ Assert.DoesNotContain(tools, t => t.Name == groupedA.Name);
+ Assert.Contains(tools, t => t.Name == groupedB.Name);
+ });
+
+ AssertObservedTools(observedToolsNonStreaming);
+ AssertObservedTools(observedToolsStreaming);
+
+ void AssertResponse(ToolGroupingTestResult testResult)
+ {
+ AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'GroupA'");
+ AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'GroupB'");
+ }
+
+ AssertResponse(result);
+ AssertResponse(streamingResult);
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_DuplicateExpansionSameIteration_Reported()
+ {
+ var groupedA = new SimpleTool("A1", "a1");
+
+ ToolGroupingTestScenario CreateScenario() => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "go")],
+ Options = new ChatOptions { Tools = [AIToolGroup.Create("GroupA", "Group A", [groupedA])] },
+ ConfigureToolGroupingOptions = options =>
+ {
+ options.MaxExpansionsPerRequest = 2;
+ },
+ Turns =
+ [
+ new DownstreamTurn
+ {
+ ResponseMessage = new ChatMessage(ChatRole.Assistant,
+ [
+ new FunctionCallContent("call1", DefaultExpansionFunctionName, new Dictionary { ["groupName"] = "GroupA" }),
+ new FunctionCallContent("call2", DefaultExpansionFunctionName, new Dictionary { ["groupName"] = "GroupA" })
+ ])
+ },
+ new DownstreamTurn { ResponseMessage = new(ChatRole.Assistant, "done") }
+ ]
+ };
+
+ var result = await InvokeAndAssertAsync(CreateScenario());
+ var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario());
+
+ void AssertResponse(ToolGroupingTestResult testResult) =>
+ AssertContainsResultMessage(testResult.Response, "Ignoring duplicate expansion");
+
+ AssertResponse(result);
+ AssertResponse(streamingResult);
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_ReexpandingSameGroupDoesNotTerminateLoop()
+ {
+ var groupedA = new SimpleTool("A1", "a1");
+
+ ToolGroupingTestScenario CreateScenario() => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "go")],
+ Options = new ChatOptions { Tools = [AIToolGroup.Create("GroupA", "Group A", [groupedA])] },
+ ConfigureToolGroupingOptions = _ => { },
+ Turns =
+ [
+ new DownstreamTurn { ResponseMessage = CreateExpansionCall("call1", "GroupA") },
+ new DownstreamTurn { ResponseMessage = CreateExpansionCall("call2", "GroupA") },
+ new DownstreamTurn { ResponseMessage = new ChatMessage(ChatRole.Assistant, "Oops!") }
+ ]
+ };
+
+ var result = await InvokeAndAssertAsync(CreateScenario());
+ var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario());
+
+ void AssertResponse(ToolGroupingTestResult testResult) =>
+ AssertContainsResultMessage(testResult.Response, "Ignoring duplicate expansion of group 'GroupA'.");
+
+ AssertResponse(result);
+ AssertResponse(streamingResult);
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_PropagatesConversationIdBetweenIterations()
+ {
+ var groupedA = new SimpleTool("A1", "a1");
+
+ ToolGroupingTestScenario CreateScenario(List observedConversationIds) => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "go")],
+ Options = new ChatOptions { Tools = [AIToolGroup.Create("GroupA", "Group A", [groupedA])] },
+ ConfigureToolGroupingOptions = _ => { },
+ Turns =
+ [
+ new DownstreamTurn
+ {
+ ResponseMessage = CreateExpansionCall("call1", "GroupA"),
+ ConversationId = "conv-1"
+ },
+ new DownstreamTurn
+ {
+ ResponseMessage = new(ChatRole.Assistant, "done"),
+ AssertInvocation = ctx => observedConversationIds.Add(ctx.Options?.ConversationId)
+ }
+ ]
+ };
+
+ List observedNonStreaming = [];
+ var result = await InvokeAndAssertAsync(CreateScenario(observedNonStreaming));
+
+ List observedStreaming = [];
+ var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario(observedStreaming));
+
+ static void AssertConversationIds(List observed) => Assert.Equal("conv-1", Assert.Single(observed));
+
+ AssertConversationIds(observedNonStreaming);
+ AssertConversationIds(observedStreaming);
+
+ void AssertResponse(ToolGroupingTestResult testResult) => Assert.Equal("done", testResult.Response.Text);
+
+ AssertResponse(result);
+ AssertResponse(streamingResult);
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_NestedGroups_ExpandsParentThenChild()
+ {
+ var nestedTool1 = new SimpleTool("NestedTool1", "nested tool 1");
+ var nestedTool2 = new SimpleTool("NestedTool2", "nested tool 2");
+ var parentTool = new SimpleTool("ParentTool", "parent tool");
+ var ungrouped = new SimpleTool("Ungrouped", "ungrouped");
+
+ var nestedGroup = AIToolGroup.Create("NestedGroup", "Nested group", [nestedTool1, nestedTool2]);
+ var parentGroup = AIToolGroup.Create("ParentGroup", "Parent group", [parentTool, nestedGroup]);
+
+ ToolGroupingTestScenario CreateScenario(List?> observedTools) => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "expand parent then nested")],
+ Options = new ChatOptions { Tools = [ungrouped, parentGroup] },
+ ConfigureToolGroupingOptions = options => options.MaxExpansionsPerRequest = 2,
+ Turns =
+ [
+ new DownstreamTurn
+ {
+ ResponseMessage = CreateExpansionCall("call1", "ParentGroup"),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ },
+ new DownstreamTurn
+ {
+ ResponseMessage = CreateExpansionCall("call2", "NestedGroup"),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ },
+ new DownstreamTurn
+ {
+ ResponseMessage = new(ChatRole.Assistant, "done"),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ }
+ ]
+ };
+
+ List?> observedNonStreaming = [];
+ List?> observedStreaming = [];
+
+ var result = await InvokeAndAssertAsync(CreateScenario(observedNonStreaming));
+ var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario(observedStreaming));
+
+ void AssertObservedTools(List?> observed) => Assert.Collection(observed,
+ tools =>
+ {
+ // First iteration: collapsed state
+ Assert.NotNull(tools);
+ Assert.Contains(tools!, t => t.Name == ungrouped.Name);
+ Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName);
+ Assert.DoesNotContain(tools, t => t.Name == parentTool.Name);
+ Assert.DoesNotContain(tools, t => t.Name == nestedTool1.Name);
+ },
+ tools =>
+ {
+ // Second iteration: parent expanded
+ Assert.NotNull(tools);
+ Assert.Contains(tools!, t => t.Name == ungrouped.Name);
+ Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName);
+ Assert.Contains(tools, t => t.Name == parentTool.Name);
+ Assert.DoesNotContain(tools, t => t.Name == nestedTool1.Name);
+
+ // NestedGroup should NOT be in tools list (only actual tools)
+ Assert.DoesNotContain(tools, t => t.Name == "NestedGroup");
+ },
+ tools =>
+ {
+ // Third iteration: nested group expanded
+ Assert.NotNull(tools);
+ Assert.Contains(tools!, t => t.Name == ungrouped.Name);
+ Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName);
+ Assert.Contains(tools, t => t.Name == nestedTool1.Name);
+ Assert.Contains(tools, t => t.Name == nestedTool2.Name);
+ Assert.DoesNotContain(tools, t => t.Name == parentTool.Name);
+ });
+
+ AssertObservedTools(observedNonStreaming);
+ AssertObservedTools(observedStreaming);
+
+ void AssertResponse(ToolGroupingTestResult testResult)
+ {
+ AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'ParentGroup'");
+ AssertContainsResultMessage(testResult.Response, "Additional groups available for expansion");
+ AssertContainsResultMessage(testResult.Response, "- NestedGroup:");
+ AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'NestedGroup'");
+ }
+
+ AssertResponse(result);
+ AssertResponse(streamingResult);
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_NestedGroups_CannotExpandNestedFromDifferentParent()
+ {
+ var toolA = new SimpleTool("ToolA", "tool a");
+ var nestedATool = new SimpleTool("NestedATool", "nested a tool");
+ var nestedA = AIToolGroup.Create("NestedA", "Nested A", [nestedATool]);
+ var groupA = AIToolGroup.Create("GroupA", "Group A", [toolA, nestedA]);
+
+ var toolB = new SimpleTool("ToolB", "tool b");
+ var nestedBTool = new SimpleTool("NestedBTool", "nested b tool");
+ var nestedB = AIToolGroup.Create("NestedB", "Nested B", [nestedBTool]);
+ var groupB = AIToolGroup.Create("GroupB", "Group B", [toolB, nestedB]);
+
+ ToolGroupingTestScenario CreateScenario() => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "try to expand wrong nested group")],
+ Options = new ChatOptions { Tools = [groupA, groupB] },
+ ConfigureToolGroupingOptions = options => options.MaxExpansionsPerRequest = 2,
+ Turns =
+ [
+ new DownstreamTurn { ResponseMessage = CreateExpansionCall("call1", "GroupA") },
+ new DownstreamTurn
+ {
+ // Try to expand NestedB (belongs to unexpanded GroupB) - should fail
+ ResponseMessage = CreateExpansionCall("call2", "NestedB")
+ },
+ new DownstreamTurn
+ {
+ ResponseMessage = new(ChatRole.Assistant, "Oops!"),
+ }
+ ]
+ };
+
+ var result = await InvokeAndAssertAsync(CreateScenario());
+ var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario());
+
+ void AssertResponse(ToolGroupingTestResult testResult)
+ {
+ AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'GroupA'");
+ AssertContainsResultMessage(testResult.Response, "group name 'NestedB' was invalid");
+ }
+
+ AssertResponse(result);
+ AssertResponse(streamingResult);
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_NestedGroups_MultiLevelNesting()
+ {
+ var deeplyNestedTool = new SimpleTool("DeeplyNestedTool", "deeply nested tool");
+ var deeplyNested = AIToolGroup.Create("DeeplyNested", "Deeply nested group", [deeplyNestedTool]);
+
+ var nestedTool = new SimpleTool("NestedTool", "nested tool");
+ var nested = AIToolGroup.Create("Nested", "Nested group", [nestedTool, deeplyNested]);
+
+ var topTool = new SimpleTool("TopTool", "top tool");
+ var topGroup = AIToolGroup.Create("TopGroup", "Top group", [topTool, nested]);
+
+ ToolGroupingTestScenario CreateScenario(List?> observedTools) => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "three level nesting")],
+ Options = new ChatOptions { Tools = [topGroup] },
+ ConfigureToolGroupingOptions = options => options.MaxExpansionsPerRequest = 3,
+ Turns =
+ [
+ new DownstreamTurn
+ {
+ ResponseMessage = CreateExpansionCall("call1", "TopGroup"),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ },
+ new DownstreamTurn
+ {
+ ResponseMessage = CreateExpansionCall("call2", "Nested"),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ },
+ new DownstreamTurn
+ {
+ ResponseMessage = CreateExpansionCall("call3", "DeeplyNested"),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ },
+ new DownstreamTurn
+ {
+ ResponseMessage = new(ChatRole.Assistant, "done"),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ }
+ ]
+ };
+
+ List?> observedNonStreaming = [];
+ List?> observedStreaming = [];
+
+ var result = await InvokeAndAssertAsync(CreateScenario(observedNonStreaming));
+ var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario(observedStreaming));
+
+ void AssertObservedTools(List?> observed) => Assert.Collection(observed,
+ tools =>
+ {
+ // Collapsed: only expansion function
+ Assert.NotNull(tools);
+ Assert.Contains(tools!, t => t.Name == DefaultExpansionFunctionName);
+ Assert.DoesNotContain(tools, t => t.Name == topTool.Name);
+ },
+ tools =>
+ {
+ // TopGroup expanded: topTool + Nested available
+ Assert.NotNull(tools);
+ Assert.Contains(tools!, t => t.Name == topTool.Name);
+ Assert.DoesNotContain(tools, t => t.Name == nestedTool.Name);
+ },
+ tools =>
+ {
+ // Nested expanded: nestedTool + DeeplyNested available
+ Assert.NotNull(tools);
+ Assert.Contains(tools!, t => t.Name == nestedTool.Name);
+ Assert.DoesNotContain(tools, t => t.Name == topTool.Name);
+ Assert.DoesNotContain(tools, t => t.Name == deeplyNestedTool.Name);
+ },
+ tools =>
+ {
+ // DeeplyNested expanded: only deeplyNestedTool
+ Assert.NotNull(tools);
+ Assert.Contains(tools!, t => t.Name == deeplyNestedTool.Name);
+ Assert.DoesNotContain(tools, t => t.Name == nestedTool.Name);
+ });
+
+ AssertObservedTools(observedNonStreaming);
+ AssertObservedTools(observedStreaming);
+
+ void AssertResponse(ToolGroupingTestResult testResult)
+ {
+ AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'TopGroup'");
+ AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'Nested'");
+ AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'DeeplyNested'");
+ }
+
+ AssertResponse(result);
+ AssertResponse(streamingResult);
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_ToolNameCollision_WithExpansionFunction()
+ {
+ var collisionTool = new SimpleTool(DefaultExpansionFunctionName, "collision");
+ var group = AIToolGroup.Create("Group", "group", [new SimpleTool("Tool", "tool")]);
+
+ using TestChatClient inner = new();
+ using IChatClient client = inner.AsBuilder().UseToolGrouping(_ => { }).Build();
+
+ var options = new ChatOptions { Tools = [collisionTool, group] };
+
+ var exception = await Assert.ThrowsAsync(async () =>
+ {
+ try
+ {
+ await client.GetResponseAsync([new ChatMessage(ChatRole.User, "test")], options);
+ }
+ catch (NotSupportedException)
+ {
+ // Inner client throws NotSupportedException, but we should hit InvalidOperationException first
+ throw;
+ }
+ });
+
+ Assert.Contains(DefaultExpansionFunctionName, exception.Message);
+ Assert.Contains("collides", exception.Message);
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_ToolNameCollision_WithListGroupsFunction()
+ {
+ var collisionTool = new SimpleTool(DefaultListGroupsFunctionName, "collision");
+ var group = AIToolGroup.Create("Group", "group", [new SimpleTool("Tool", "tool")]);
+
+ using TestChatClient inner = new();
+ using IChatClient client = inner.AsBuilder().UseToolGrouping(_ => { }).Build();
+
+ var options = new ChatOptions { Tools = [collisionTool, group] };
+
+ var exception = await Assert.ThrowsAsync(async () =>
+ {
+ try
+ {
+ await client.GetResponseAsync([new ChatMessage(ChatRole.User, "test")], options);
+ }
+ catch (NotSupportedException)
+ {
+ throw;
+ }
+ });
+
+ Assert.Contains(DefaultListGroupsFunctionName, exception.Message);
+ Assert.Contains("collides", exception.Message);
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_DynamicToolGroup_GetToolsAsyncCalled()
+ {
+ var tool = new SimpleTool("DynamicTool", "dynamic tool");
+ bool getToolsAsyncCalled = false;
+
+ var dynamicGroup = new DynamicToolGroup(
+ "DynamicGroup",
+ "Dynamic group",
+ async ct =>
+ {
+ getToolsAsyncCalled = true;
+ await Task.Yield();
+ return [tool];
+ });
+
+ ToolGroupingTestScenario CreateScenario() => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "expand dynamic group")],
+ Options = new ChatOptions { Tools = [dynamicGroup] },
+ ConfigureToolGroupingOptions = _ => { },
+ Turns =
+ [
+ new DownstreamTurn { ResponseMessage = CreateExpansionCall("call1", "DynamicGroup") },
+ new DownstreamTurn { ResponseMessage = new(ChatRole.Assistant, "done") }
+ ]
+ };
+
+ var result = await InvokeAndAssertAsync(CreateScenario());
+
+ Assert.True(getToolsAsyncCalled, "GetToolsAsync should have been called");
+ AssertContainsResultMessage(result.Response, "Successfully expanded group 'DynamicGroup'");
+
+ // Reset for streaming test
+ getToolsAsyncCalled = false;
+ var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario());
+
+ Assert.True(getToolsAsyncCalled, "GetToolsAsync should have been called in streaming");
+ AssertContainsResultMessage(streamingResult.Response, "Successfully expanded group 'DynamicGroup'");
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_DynamicToolGroup_ThrowsException()
+ {
+ var dynamicGroup = new DynamicToolGroup(
+ "FailingGroup",
+ "Failing group",
+ ct => throw new InvalidOperationException("Simulated failure"));
+
+ ToolGroupingTestScenario CreateScenario() => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "expand failing group")],
+ Options = new ChatOptions { Tools = [dynamicGroup] },
+ ConfigureToolGroupingOptions = _ => { },
+ Turns =
+ [
+ new DownstreamTurn { ResponseMessage = CreateExpansionCall("call1", "FailingGroup") }
+ ]
+ };
+
+ await Assert.ThrowsAsync(async () =>
+ await InvokeAndAssertAsync(CreateScenario()));
+
+ await Assert.ThrowsAsync(async () =>
+ await InvokeAndAssertStreamingAsync(CreateScenario()));
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_EmptyGroupExpansion_ReturnsNoTools()
+ {
+ var emptyGroup = AIToolGroup.Create("EmptyGroup", "Empty group", []);
+ var ungroupedTool = new SimpleTool("Ungrouped", "ungrouped");
+
+ ToolGroupingTestScenario CreateScenario(List?> observedTools) => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "expand empty group")],
+ Options = new ChatOptions { Tools = [ungroupedTool, emptyGroup] },
+ ConfigureToolGroupingOptions = _ => { },
+ Turns =
+ [
+ new DownstreamTurn
+ {
+ ResponseMessage = CreateExpansionCall("call1", "EmptyGroup"),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ },
+ new DownstreamTurn
+ {
+ ResponseMessage = new(ChatRole.Assistant, "done"),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ }
+ ]
+ };
+
+ List?> observedNonStreaming = [];
+ List?> observedStreaming = [];
+
+ var result = await InvokeAndAssertAsync(CreateScenario(observedNonStreaming));
+ var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario(observedStreaming));
+
+ void AssertObservedTools(List?> observed) => Assert.Collection(observed,
+ tools =>
+ {
+ Assert.NotNull(tools);
+ Assert.Contains(tools!, t => t.Name == ungroupedTool.Name);
+ Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName);
+ },
+ tools =>
+ {
+ // After expanding empty group, only ungrouped tool + expansion/list functions remain
+ Assert.NotNull(tools);
+ Assert.Contains(tools!, t => t.Name == ungroupedTool.Name);
+ Assert.Contains(tools, t => t.Name == DefaultExpansionFunctionName);
+
+ // No group-specific tools should be added
+ Assert.Equal(3, tools.Count); // ungrouped + expansion + list
+ });
+
+ AssertObservedTools(observedNonStreaming);
+ AssertObservedTools(observedStreaming);
+
+ void AssertResponse(ToolGroupingTestResult testResult) =>
+ AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'EmptyGroup'");
+
+ AssertResponse(result);
+ AssertResponse(streamingResult);
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_CustomExpansionFunctionName()
+ {
+ var tool = new SimpleTool("Tool", "tool");
+ var group = AIToolGroup.Create("Group", "group", [tool]);
+
+ const string CustomExpansionName = "my_custom_expand";
+
+ ToolGroupingTestScenario CreateScenario(List?> observedTools) => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "test custom name")],
+ Options = new ChatOptions { Tools = [group] },
+ ConfigureToolGroupingOptions = options =>
+ {
+ options.ExpansionFunctionName = CustomExpansionName;
+ },
+ Turns =
+ [
+ new DownstreamTurn
+ {
+ ResponseMessage = new ChatMessage(ChatRole.Assistant,
+ [new FunctionCallContent("call1", CustomExpansionName, new Dictionary { ["groupName"] = "Group" })]),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ },
+ new DownstreamTurn
+ {
+ ResponseMessage = new(ChatRole.Assistant, "done"),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ }
+ ]
+ };
+
+ List?> observedNonStreaming = [];
+ List?> observedStreaming = [];
+
+ var result = await InvokeAndAssertAsync(CreateScenario(observedNonStreaming));
+ var streamingResult = await InvokeAndAssertStreamingAsync(CreateScenario(observedStreaming));
+
+ void AssertObservedTools(List?> observed)
+ {
+ Assert.Equal(2, observed.Count);
+
+ // All iterations should have custom expansion function
+ Assert.All(observed, tools =>
+ {
+ Assert.NotNull(tools);
+ Assert.Contains(tools!, t => t.Name == CustomExpansionName);
+ Assert.DoesNotContain(tools, t => t.Name == DefaultExpansionFunctionName);
+ });
+ }
+
+ AssertObservedTools(observedNonStreaming);
+ AssertObservedTools(observedStreaming);
+
+ void AssertResponse(ToolGroupingTestResult testResult) =>
+ AssertContainsResultMessage(testResult.Response, "Successfully expanded group 'Group'");
+
+ AssertResponse(result);
+ AssertResponse(streamingResult);
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_CustomExpansionFunctionDescription()
+ {
+ var tool = new SimpleTool("Tool", "tool");
+ var group = AIToolGroup.Create("Group", "group", [tool]);
+
+ const string CustomDescription = "Use this custom function to expand a tool group";
+
+ List?> observedTools = [];
+
+ ToolGroupingTestScenario CreateScenario() => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "test custom description")],
+ Options = new ChatOptions { Tools = [group] },
+ ConfigureToolGroupingOptions = options =>
+ {
+ options.ExpansionFunctionDescription = CustomDescription;
+ },
+ Turns =
+ [
+ new DownstreamTurn
+ {
+ ResponseMessage = new(ChatRole.Assistant, "ok"),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ }
+ ]
+ };
+
+ await InvokeAndAssertAsync(CreateScenario());
+
+ var tools = Assert.Single(observedTools);
+ Assert.NotNull(tools);
+ var expansionTool = tools!.FirstOrDefault(t => t.Name == DefaultExpansionFunctionName);
+ Assert.NotNull(expansionTool);
+ Assert.Equal(CustomDescription, expansionTool!.Description);
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_CustomListGroupsFunctionName()
+ {
+ var tool = new SimpleTool("Tool", "tool");
+ var group = AIToolGroup.Create("Group", "group", [tool]);
+
+ const string CustomListName = "my_list_groups";
+
+ List?> observedTools = [];
+
+ ToolGroupingTestScenario CreateScenario() => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "test custom list name")],
+ Options = new ChatOptions { Tools = [group] },
+ ConfigureToolGroupingOptions = options =>
+ {
+ options.ListGroupsFunctionName = CustomListName;
+ },
+ Turns =
+ [
+ new DownstreamTurn
+ {
+ ResponseMessage = new(ChatRole.Assistant, "ok"),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ }
+ ]
+ };
+
+ await InvokeAndAssertAsync(CreateScenario());
+
+ var tools = Assert.Single(observedTools);
+ Assert.NotNull(tools);
+ Assert.Contains(tools!, t => t.Name == CustomListName);
+ Assert.DoesNotContain(tools, t => t.Name == DefaultListGroupsFunctionName);
+ }
+
+ [Fact]
+ public async Task ToolGroupingChatClient_CustomListGroupsFunctionDescription()
+ {
+ var tool = new SimpleTool("Tool", "tool");
+ var group = AIToolGroup.Create("Group", "group", [tool]);
+
+ const string CustomDescription = "Custom description for listing groups";
+
+ List?> observedTools = [];
+
+ ToolGroupingTestScenario CreateScenario() => new()
+ {
+ InitialMessages = [new ChatMessage(ChatRole.User, "test custom list description")],
+ Options = new ChatOptions { Tools = [group] },
+ ConfigureToolGroupingOptions = options =>
+ {
+ options.ListGroupsFunctionDescription = CustomDescription;
+ },
+ Turns =
+ [
+ new DownstreamTurn
+ {
+ ResponseMessage = new(ChatRole.Assistant, "ok"),
+ AssertInvocation = ctx => observedTools.Add(ctx.Options?.Tools?.ToList()),
+ }
+ ]
+ };
+
+ await InvokeAndAssertAsync(CreateScenario());
+
+ var tools = Assert.Single(observedTools);
+ Assert.NotNull(tools);
+ var listTool = tools!.FirstOrDefault(t => t.Name == DefaultListGroupsFunctionName);
+ Assert.NotNull(listTool);
+ Assert.Equal(CustomDescription, listTool!.Description);
+ }
+
+ private sealed class DynamicToolGroup : AIToolGroup
+ {
+ private readonly Func>> _getToolsFunc;
+
+ public DynamicToolGroup(string name, string description, Func>> getToolsFunc)
+ : base(name, description)
+ {
+ _getToolsFunc = getToolsFunc;
+ }
+
+ public override async ValueTask> GetToolsAsync(CancellationToken cancellationToken = default)
+ {
+ var tools = await _getToolsFunc(cancellationToken);
+ return tools.ToList();
+ }
+ }
+
+ private static async Task InvokeAndAssertAsync(ToolGroupingTestScenario scenario)
+ {
+ if (scenario.InitialMessages.Count == 0)
+ {
+ throw new InvalidOperationException("Scenario must include at least one initial message.");
+ }
+
+ List turns = scenario.Turns;
+ long expectedTotalTokenCounts = 0;
+ int iteration = 0;
+
+ using TestChatClient inner = new();
+
+ inner.GetResponseAsyncCallback = (messages, options, cancellationToken) =>
+ {
+ var materialized = messages.ToList();
+ if (iteration >= turns.Count)
+ {
+ throw new InvalidOperationException("Unexpected additional iteration.");
+ }
+
+ var turn = turns[iteration];
+ turn.AssertInvocation?.Invoke(new ToolGroupingInvocationContext(iteration, materialized, options));
+
+ UsageDetails usage = CreateRandomUsage();
+ expectedTotalTokenCounts += usage.InputTokenCount!.Value;
+
+ var response = new ChatResponse(turn.ResponseMessage)
+ {
+ Usage = usage,
+ ConversationId = turn.ConversationId,
+ };
+ iteration++;
+ return Task.FromResult(response);
+ };
+
+ using IChatClient client = inner.AsBuilder().UseToolGrouping(scenario.ConfigureToolGroupingOptions).Build();
+
+ var request = new EnumeratedOnceEnumerable(scenario.InitialMessages);
+ ChatResponse response = await client.GetResponseAsync(request, scenario.Options, CancellationToken.None);
+
+ Assert.Equal(turns.Count, iteration);
+
+ // Usage should be aggregated over all responses, including AdditionalUsage
+ var actualUsage = response.Usage!;
+ Assert.Equal(expectedTotalTokenCounts, actualUsage.InputTokenCount);
+ Assert.Equal(expectedTotalTokenCounts, actualUsage.OutputTokenCount);
+ Assert.Equal(expectedTotalTokenCounts, actualUsage.TotalTokenCount);
+ Assert.Equal(2, actualUsage.AdditionalCounts!.Count);
+ Assert.Equal(expectedTotalTokenCounts, actualUsage.AdditionalCounts["firstValue"]);
+ Assert.Equal(expectedTotalTokenCounts, actualUsage.AdditionalCounts["secondValue"]);
+
+ return new ToolGroupingTestResult(response);
+ }
+
+ private static async Task InvokeAndAssertStreamingAsync(ToolGroupingTestScenario scenario)
+ {
+ if (scenario.InitialMessages.Count == 0)
+ {
+ throw new InvalidOperationException("Scenario must include at least one initial message.");
+ }
+
+ List turns = scenario.Turns;
+ int iteration = 0;
+
+ using TestChatClient inner = new();
+
+ inner.GetStreamingResponseAsyncCallback = (messages, options, cancellationToken) =>
+ {
+ var materialized = messages.ToList();
+ if (iteration >= turns.Count)
+ {
+ throw new InvalidOperationException("Unexpected additional iteration.");
+ }
+
+ var turn = turns[iteration];
+ turn.AssertInvocation?.Invoke(new ToolGroupingInvocationContext(iteration, materialized, options));
+
+ var response = new ChatResponse(turn.ResponseMessage)
+ {
+ ConversationId = turn.ConversationId,
+ };
+ iteration++;
+ return YieldAsync(response.ToChatResponseUpdates());
+ };
+
+ using IChatClient client = inner.AsBuilder().UseToolGrouping(scenario.ConfigureToolGroupingOptions).Build();
+
+ var request = new EnumeratedOnceEnumerable(scenario.InitialMessages);
+ ChatResponse response = await client.GetStreamingResponseAsync(request, scenario.Options, CancellationToken.None).ToChatResponseAsync();
+
+ Assert.Equal(turns.Count, iteration);
+
+ return new ToolGroupingTestResult(response);
+ }
+
+ private static UsageDetails CreateRandomUsage()
+ {
+ // We'll set the same random number on all the properties so that, when determining the
+ // correct sum in tests, we only have to total the values once
+ var value = new Random().Next(100);
+ return new UsageDetails
+ {
+ InputTokenCount = value,
+ OutputTokenCount = value,
+ TotalTokenCount = value,
+ AdditionalCounts = new() { ["firstValue"] = value, ["secondValue"] = value },
+ };
+ }
+
+ private static ChatMessage CreateExpansionCall(string callId, string groupName) =>
+ new(ChatRole.Assistant, [new FunctionCallContent(callId, DefaultExpansionFunctionName, new Dictionary { ["groupName"] = groupName })]);
+
+ private static void AssertContainsResultMessage(ChatResponse response, string substring)
+ {
+ var toolMessages = response.Messages.Where(m => m.Role == ChatRole.Tool).ToList();
+ Assert.NotEmpty(toolMessages);
+ Assert.Contains(toolMessages.SelectMany(m => m.Contents.OfType()), r =>
+ {
+ var text = r.Result?.ToString() ?? string.Empty;
+ return text.Contains(substring);
+ });
+ }
+
+ private static async IAsyncEnumerable YieldAsync(IEnumerable updates)
+ {
+ foreach (var update in updates)
+ {
+ yield return update;
+ await Task.Yield();
+ }
+ }
+
+ private sealed class SimpleTool : AITool
+ {
+ private readonly string _name;
+ private readonly string _description;
+
+ public SimpleTool(string name, string description)
+ {
+ _name = name;
+ _description = description;
+ }
+
+ public override string Name => _name;
+ public override string Description => _description;
+ }
+
+ private sealed class TestChatClient : IChatClient
+ {
+ public Func, ChatOptions?, CancellationToken, Task>? GetResponseAsyncCallback { get; set; }
+ public Func, ChatOptions?, CancellationToken, IAsyncEnumerable>? GetStreamingResponseAsyncCallback { get; set; }
+
+ public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
+ {
+ if (GetResponseAsyncCallback is null)
+ {
+ throw new NotSupportedException();
+ }
+
+ return GetResponseAsyncCallback(messages, options, cancellationToken);
+ }
+
+ public IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
+ {
+ if (GetStreamingResponseAsyncCallback is null)
+ {
+ throw new NotSupportedException();
+ }
+
+ return GetStreamingResponseAsyncCallback(messages, options, cancellationToken);
+ }
+
+ public object? GetService(Type serviceType, object? serviceKey = null) => null;
+
+ public void Dispose()
+ {
+ // No-op
+ }
+ }
+
+ private sealed class ToolGroupingTestScenario
+ {
+ public List InitialMessages { get; init; } = [];
+ public Action ConfigureToolGroupingOptions { get; init; } = _ => { };
+ public List Turns { get; init; } = [];
+ public ChatOptions? Options { get; init; }
+ }
+
+ private sealed class DownstreamTurn
+ {
+ public ChatMessage ResponseMessage { get; init; } = new(ChatRole.Assistant, string.Empty);
+ public string? ConversationId { get; init; }
+ public Action? AssertInvocation { get; init; }
+ }
+
+ private sealed record ToolGroupingInvocationContext(int Iteration, IReadOnlyList Messages, ChatOptions? Options);
+
+ private sealed record ToolGroupingTestResult(ChatResponse Response);
+
+ private sealed class EnumeratedOnceEnumerable : IEnumerable
+ {
+ private readonly IEnumerable _items;
+ private bool _enumerated;
+
+ public EnumeratedOnceEnumerable(IEnumerable items)
+ {
+ _items = items;
+ }
+
+ public IEnumerator GetEnumerator()
+ {
+ if (_enumerated)
+ {
+ throw new InvalidOperationException("Sequence may only be enumerated once.");
+ }
+
+ _enumerated = true;
+ return _items.GetEnumerator();
+ }
+
+ IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
+ }
+}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs
new file mode 100644
index 00000000000..96c9adc6311
--- /dev/null
+++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs
@@ -0,0 +1,663 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace Microsoft.Extensions.AI;
+
+public class ToolReductionTests
+{
+ [Fact]
+ public void EmbeddingToolReductionStrategy_Constructor_ThrowsWhenToolLimitIsLessThanOrEqualToZero()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ Assert.Throws(() => new EmbeddingToolReductionStrategy(gen, toolLimit: 0));
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_NoReduction_WhenToolsBelowLimit()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 5);
+
+ var tools = CreateTools("Weather", "Math");
+ var options = new ChatOptions { Tools = tools };
+
+ var result = await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "Tell me about weather") },
+ options);
+
+ Assert.Same(tools, result);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_NoReduction_WhenOptionalToolsBelowLimit()
+ {
+ // 1 required + 2 optional, limit = 2 (optional count == limit) => original list returned
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2)
+ {
+ IsRequiredTool = t => t.Name == "Req"
+ };
+
+ var tools = CreateTools("Req", "Opt1", "Opt2");
+ var result = await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "anything") },
+ new ChatOptions { Tools = tools });
+
+ Assert.Same(tools, result);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_Reduces_ToLimit_BySimilarity()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2);
+
+ var tools = CreateTools("Weather", "Translate", "Math", "Jokes");
+ var options = new ChatOptions { Tools = tools };
+
+ var messages = new[]
+ {
+ new ChatMessage(ChatRole.User, "Can you do some weather math for forecasting?")
+ };
+
+ var reduced = (await strategy.SelectToolsForRequestAsync(messages, options)).ToList();
+
+ Assert.Equal(2, reduced.Count);
+ Assert.Contains(reduced, t => t.Name == "Weather");
+ Assert.Contains(reduced, t => t.Name == "Math");
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_PreserveOriginalOrdering_ReordersAfterSelection()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2)
+ {
+ PreserveOriginalOrdering = true
+ };
+
+ var tools = CreateTools("Math", "Translate", "Weather");
+ var reduced = (await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "Explain weather math please") },
+ new ChatOptions { Tools = tools })).ToList();
+
+ Assert.Equal(2, reduced.Count);
+ Assert.Equal("Math", reduced[0].Name);
+ Assert.Equal("Weather", reduced[1].Name);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_Caching_AvoidsReEmbeddingTools()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1);
+
+ var tools = CreateTools("Weather", "Math", "Jokes");
+ var messages = new[] { new ChatMessage(ChatRole.User, "weather") };
+
+ _ = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools });
+ int afterFirst = gen.TotalValueInputs;
+
+ _ = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools });
+ int afterSecond = gen.TotalValueInputs;
+
+ // +1 for second query embedding only
+ Assert.Equal(afterFirst + 1, afterSecond);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_OptionsNullOrNoTools_ReturnsEmptyOrOriginal()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2);
+
+ var empty = await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "anything") }, null);
+ Assert.Empty(empty);
+
+ var options = new ChatOptions { Tools = [] };
+ var result = await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "weather") }, options);
+ Assert.Same(options.Tools, result);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_CustomSimilarity_InvertsOrdering()
+ {
+ using var gen = new VectorBasedTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1)
+ {
+ Similarity = (q, t) => -t.Span[0]
+ };
+
+ var highTool = new SimpleTool("HighScore", "alpha");
+ var lowTool = new SimpleTool("LowScore", "beta");
+ gen.VectorSelector = text => text.Contains("alpha") ? 10f : 1f;
+
+ var reduced = (await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "Pick something") },
+ new ChatOptions { Tools = [highTool, lowTool] })).ToList();
+
+ Assert.Single(reduced);
+ Assert.Equal("LowScore", reduced[0].Name);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_TieDeterminism_PrefersLowerOriginalIndex()
+ {
+ // Generator returns identical vectors so similarity ties; we expect original order preserved
+ using var gen = new ConstantEmbeddingGenerator(3);
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2);
+
+ var tools = CreateTools("T1", "T2", "T3", "T4");
+ var reduced = (await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "any") },
+ new ChatOptions { Tools = tools })).ToList();
+
+ Assert.Equal(2, reduced.Count);
+ Assert.Equal("T1", reduced[0].Name);
+ Assert.Equal("T2", reduced[1].Name);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_DefaultEmbeddingTextSelector_EmptyDescription_UsesNameOnly()
+ {
+ using var recorder = new RecordingEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1);
+
+ var target = new SimpleTool("ComputeSum", description: "");
+ var filler = new SimpleTool("Other", "Unrelated");
+ _ = await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "math") },
+ new ChatOptions { Tools = [target, filler] });
+
+ Assert.Contains("ComputeSum", recorder.Inputs);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_DefaultEmbeddingTextSelector_EmptyName_UsesDescriptionOnly()
+ {
+ using var recorder = new RecordingEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1);
+
+ var target = new SimpleTool("", description: "Translates between languages.");
+ var filler = new SimpleTool("Other", "Unrelated");
+ _ = await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "translate") },
+ new ChatOptions { Tools = [target, filler] });
+
+ Assert.Contains("Translates between languages.", recorder.Inputs);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_CustomEmbeddingTextSelector_Applied()
+ {
+ using var recorder = new RecordingEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1)
+ {
+ ToolEmbeddingTextSelector = t => $"NAME:{t.Name}|DESC:{t.Description}"
+ };
+
+ var target = new SimpleTool("WeatherTool", "Gets forecast.");
+ var filler = new SimpleTool("Other", "Irrelevant");
+ _ = await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "weather") },
+ new ChatOptions { Tools = [target, filler] });
+
+ Assert.Contains("NAME:WeatherTool|DESC:Gets forecast.", recorder.Inputs);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextSelector_CustomFiltersMessages()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1);
+
+ var tools = CreateTools("Weather", "Math", "Translate");
+
+ var messages = new[]
+ {
+ new ChatMessage(ChatRole.User, "Please tell me the weather tomorrow."),
+ new ChatMessage(ChatRole.Assistant, "Sure, I can help."),
+ new ChatMessage(ChatRole.User, "Now instead solve a math problem.")
+ };
+
+ strategy.MessagesEmbeddingTextSelector = msgs => new ValueTask(msgs.LastOrDefault()?.Text ?? string.Empty);
+
+ var reduced = (await strategy.SelectToolsForRequestAsync(
+ messages,
+ new ChatOptions { Tools = tools })).ToList();
+
+ Assert.Single(reduced);
+ Assert.Equal("Math", reduced[0].Name);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextSelector_InvokedOnce()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1);
+
+ var tools = CreateTools("Weather", "Math");
+ int invocationCount = 0;
+
+ strategy.MessagesEmbeddingTextSelector = msgs =>
+ {
+ invocationCount++;
+ return new ValueTask(string.Join("\n", msgs.Select(m => m.Text)));
+ };
+
+ _ = await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "weather and math") },
+ new ChatOptions { Tools = tools });
+
+ Assert.Equal(1, invocationCount);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_DefaultMessagesEmbeddingTextSelector_IncludesReasoningContent()
+ {
+ using var recorder = new RecordingEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1);
+ var tools = CreateTools("Weather", "Math");
+
+ var reasoningLine = "Thinking about the best way to get tomorrow's forecast...";
+ var answerLine = "Tomorrow will be sunny.";
+ var userLine = "What's the weather tomorrow?";
+
+ var messages = new[]
+ {
+ new ChatMessage(ChatRole.User, userLine),
+ new ChatMessage(ChatRole.Assistant,
+ [
+ new TextReasoningContent(reasoningLine),
+ new TextContent(answerLine)
+ ])
+ };
+
+ _ = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools });
+
+ string queryInput = recorder.Inputs[0];
+
+ Assert.Contains(userLine, queryInput);
+ Assert.Contains(reasoningLine, queryInput);
+ Assert.Contains(answerLine, queryInput);
+
+ var userIndex = queryInput.IndexOf(userLine, StringComparison.Ordinal);
+ var reasoningIndex = queryInput.IndexOf(reasoningLine, StringComparison.Ordinal);
+ var answerIndex = queryInput.IndexOf(answerLine, StringComparison.Ordinal);
+ Assert.True(userIndex >= 0 && reasoningIndex > userIndex && answerIndex > reasoningIndex);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_DefaultMessagesEmbeddingTextSelector_SkipsNonTextContent()
+ {
+ using var recorder = new RecordingEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1);
+ var tools = CreateTools("Alpha", "Beta");
+
+ var textOnly = "Provide translation.";
+ var messages = new[]
+ {
+ new ChatMessage(ChatRole.User,
+ [
+ new DataContent(new byte[] { 1, 2, 3 }, "application/octet-stream"),
+ new TextContent(textOnly)
+ ])
+ };
+
+ _ = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools });
+
+ var queryInput = recorder.Inputs[0];
+ Assert.Contains(textOnly, queryInput);
+ Assert.DoesNotContain("application/octet-stream", queryInput, StringComparison.OrdinalIgnoreCase);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_RequiredToolAlwaysIncluded()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1)
+ {
+ IsRequiredTool = t => t.Name == "Core"
+ };
+
+ var tools = CreateTools("Core", "Weather", "Math");
+ var reduced = (await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "math") },
+ new ChatOptions { Tools = tools })).ToList();
+
+ Assert.Equal(2, reduced.Count); // required + one optional (limit=1)
+ Assert.Contains(reduced, t => t.Name == "Core");
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_MultipleRequiredTools_ExceedLimit_AllRequiredIncluded()
+ {
+ // 3 required, limit=1 => expect 3 required + 1 ranked optional = 4 total
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1)
+ {
+ IsRequiredTool = t => t.Name.StartsWith("R", StringComparison.Ordinal)
+ };
+
+ var tools = CreateTools("R1", "R2", "R3", "Weather", "Math");
+ var reduced = (await strategy.SelectToolsForRequestAsync(
+ new[] { new ChatMessage(ChatRole.User, "weather math") },
+ new ChatOptions { Tools = tools })).ToList();
+
+ Assert.Equal(4, reduced.Count);
+ Assert.Equal(3, reduced.Count(t => t.Name.StartsWith("R")));
+ }
+
+ [Fact]
+ public async Task ToolReducingChatClient_ReducesTools_ForGetResponseAsync()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2);
+ var tools = CreateTools("Weather", "Math", "Translate", "Jokes");
+
+ IList? observedTools = null;
+
+ using var inner = new TestChatClient
+ {
+ GetResponseAsyncCallback = (messages, options, ct) =>
+ {
+ observedTools = options?.Tools;
+ return Task.FromResult(new ChatResponse());
+ }
+ };
+
+ using var client = inner.AsBuilder().UseToolReduction(strategy).Build();
+
+ await client.GetResponseAsync(
+ new[] { new ChatMessage(ChatRole.User, "weather math please") },
+ new ChatOptions { Tools = tools });
+
+ Assert.NotNull(observedTools);
+ Assert.Equal(2, observedTools!.Count);
+ Assert.Contains(observedTools, t => t.Name == "Weather");
+ Assert.Contains(observedTools, t => t.Name == "Math");
+ }
+
+ [Fact]
+ public async Task ToolReducingChatClient_ReducesTools_ForStreaming()
+ {
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1);
+ var tools = CreateTools("Weather", "Math");
+
+ IList? observedTools = null;
+
+ using var inner = new TestChatClient
+ {
+ GetStreamingResponseAsyncCallback = (messages, options, ct) =>
+ {
+ observedTools = options?.Tools;
+ return EmptyAsyncEnumerable();
+ }
+ };
+
+ using var client = inner.AsBuilder().UseToolReduction(strategy).Build();
+
+ await foreach (var _ in client.GetStreamingResponseAsync(
+ new[] { new ChatMessage(ChatRole.User, "math") },
+ new ChatOptions { Tools = tools }))
+ {
+ // Consume
+ }
+
+ Assert.NotNull(observedTools);
+ Assert.Single(observedTools!);
+ Assert.Equal("Math", observedTools![0].Name);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_EmptyQuery_NoReduction()
+ {
+ // Arrange: more tools than limit so we'd normally reduce, but query is empty -> return full list unchanged.
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1);
+
+ var tools = CreateTools("ToolA", "ToolB", "ToolC");
+ var options = new ChatOptions { Tools = tools };
+
+ // Empty / whitespace message text produces empty query.
+ var messages = new[] { new ChatMessage(ChatRole.User, " ") };
+
+ // Act
+ var result = await strategy.SelectToolsForRequestAsync(messages, options);
+
+ // Assert: same reference (no reduction), and generator not invoked at all.
+ Assert.Same(tools, result);
+ Assert.Equal(0, gen.TotalValueInputs);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_EmptyQuery_NoReduction_WithRequiredTool()
+ {
+ // Arrange: required tool + optional tools; still should return original set when query is empty.
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1)
+ {
+ IsRequiredTool = t => t.Name == "Req"
+ };
+
+ var tools = CreateTools("Req", "Optional1", "Optional2");
+ var options = new ChatOptions { Tools = tools };
+
+ var messages = new[] { new ChatMessage(ChatRole.User, " ") };
+
+ // Act
+ var result = await strategy.SelectToolsForRequestAsync(messages, options);
+
+ // Assert
+ Assert.Same(tools, result);
+ Assert.Equal(0, gen.TotalValueInputs);
+ }
+
+ [Fact]
+ public async Task EmbeddingToolReductionStrategy_EmptyQuery_ViaCustomMessagesSelector_NoReduction()
+ {
+ // Arrange: force empty query through custom selector returning whitespace.
+ using var gen = new DeterministicTestEmbeddingGenerator();
+ var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1)
+ {
+ MessagesEmbeddingTextSelector = _ => new ValueTask(" ")
+ };
+
+ var tools = CreateTools("One", "Two");
+ var messages = new[]
+ {
+ new ChatMessage(ChatRole.User, "This content will be ignored by custom selector.")
+ };
+
+ // Act
+ var result = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools });
+
+ // Assert: no reduction and no embeddings generated.
+ Assert.Same(tools, result);
+ Assert.Equal(0, gen.TotalValueInputs);
+ }
+
+ private static List CreateTools(params string[] names) =>
+ names.Select(n => (AITool)new SimpleTool(n, $"Description about {n}")).ToList();
+
+#pragma warning disable CS1998
+ private static async IAsyncEnumerable EmptyAsyncEnumerable()
+ {
+ yield break;
+ }
+#pragma warning restore CS1998
+
+ private sealed class SimpleTool : AITool
+ {
+ private readonly string _name;
+ private readonly string _description;
+
+ public SimpleTool(string name, string description)
+ {
+ _name = name;
+ _description = description;
+ }
+
+ public override string Name => _name;
+ public override string Description => _description;
+ }
+
+ ///
+ /// Deterministic embedding generator producing sparse keyword indicator vectors.
+ /// Each dimension corresponds to a known keyword. Cosine similarity then reflects
+ /// pure keyword overlap (non-overlapping keywords contribute nothing), avoiding
+ /// false ties for tools unrelated to the query.
+ ///
+ private sealed class DeterministicTestEmbeddingGenerator : IEmbeddingGenerator>
+ {
+ private static readonly string[] _keywords =
+ [
+ "weather","forecast","temperature","math","calculate","sum","translate","language","joke"
+ ];
+
+ // +1 bias dimension (last) to avoid zero magnitude vectors when no keywords present.
+ private static int VectorLength => _keywords.Length + 1;
+
+ public int TotalValueInputs { get; private set; }
+
+ public Task>> GenerateAsync(
+ IEnumerable values,
+ EmbeddingGenerationOptions? options = null,
+ CancellationToken cancellationToken = default)
+ {
+ var list = new List>();
+
+ foreach (var v in values)
+ {
+ TotalValueInputs++;
+ var vec = new float[VectorLength];
+ if (!string.IsNullOrWhiteSpace(v))
+ {
+ var lower = v.ToLowerInvariant();
+ for (int i = 0; i < _keywords.Length; i++)
+ {
+ if (lower.Contains(_keywords[i]))
+ {
+ vec[i] = 1f;
+ }
+ }
+ }
+
+ vec[^1] = 1f; // bias
+ list.Add(new Embedding(vec));
+ }
+
+ return Task.FromResult(new GeneratedEmbeddings>(list));
+ }
+
+ public object? GetService(Type serviceType, object? serviceKey = null) => null;
+
+ public void Dispose()
+ {
+ // No-op
+ }
+ }
+
+ private sealed class RecordingEmbeddingGenerator : IEmbeddingGenerator>
+ {
+ public List