diff --git a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs index 220737be2749..90c18ad5aada 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs @@ -155,16 +155,16 @@ public static IServiceCollection AddOllamaChatCompletion( { var loggerFactory = serviceProvider.GetService(); - var builder = ((IChatClient)new OllamaApiClient(endpoint, modelId)) - .AsBuilder() - .UseFunctionInvocation(loggerFactory, config => config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + var ollamaClient = (IChatClient)new OllamaApiClient(endpoint, modelId); if (loggerFactory is not null) { - builder.UseLogging(loggerFactory); + ollamaClient.AsBuilder().UseLogging(loggerFactory).Build(); } - return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider); + return ollamaClient + .AsKernelFunctionInvokingChatClient(loggerFactory) + .AsChatCompletionService(); }); } @@ -190,16 +190,16 @@ public static IServiceCollection AddOllamaChatCompletion( var loggerFactory = serviceProvider.GetService(); - var builder = ((IChatClient)new OllamaApiClient(httpClient, modelId)) - .AsBuilder() - .UseFunctionInvocation(loggerFactory, config => config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + var ollamaClient = (IChatClient)new OllamaApiClient(httpClient, modelId); if (loggerFactory is not null) { - builder.UseLogging(loggerFactory); + ollamaClient.AsBuilder().UseLogging(loggerFactory).Build(); } - return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider); + return ollamaClient + .AsKernelFunctionInvokingChatClient(loggerFactory) + .AsChatCompletionService(); }); } @@ -231,15 +231,16 @@ public static IServiceCollection AddOllamaChatCompletion( } var builder = ((IChatClient)ollamaClient) - .AsBuilder() - .UseFunctionInvocation(loggerFactory, config => config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + .AsKernelFunctionInvokingChatClient(loggerFactory) + .AsBuilder(); if (loggerFactory is not null) { builder.UseLogging(loggerFactory); } - return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider); + return builder.Build(serviceProvider) + .AsChatCompletionService(serviceProvider); }); } @@ -355,26 +356,4 @@ public static IServiceCollection AddOllamaTextEmbeddingGeneration( } #endregion - - #region Private - - /// - /// The maximum number of auto-invokes that can be in-flight at any given time as part of the current - /// asynchronous chain of execution. - /// - /// - /// This is a fail-safe mechanism. If someone accidentally manages to set up execution settings in such a way that - /// auto-invocation is invoked recursively, and in particular where a prompt function is able to auto-invoke itself, - /// we could end up in an infinite loop. This const is a backstop against that happening. We should never come close - /// to this limit, but if we do, auto-invoke will be disabled for the current flow in order to prevent runaway execution. - /// With the current setup, the way this could possibly happen is if a prompt function is configured with built-in - /// execution settings that opt-in to auto-invocation of everything in the kernel, in which case the invocation of that - /// prompt function could advertize itself as a candidate for auto-invocation. We don't want to outright block that, - /// if that's something a developer has asked to do (e.g. it might be invoked with different arguments than its parent - /// was invoked with), but we do want to limit it. This limit is arbitrary and can be tweaked in the future and/or made - /// configurable should need arise. - /// - private const int MaxInflightAutoInvokes = 128; - - #endregion } diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs index dd8d94c99824..bb22b67b49a8 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs @@ -736,7 +736,7 @@ public void Dispose() private static object? GetLastFunctionResultFromChatResponse(ChatResponse chatResponse) { Assert.NotEmpty(chatResponse.Messages); - var chatMessage = chatResponse.Messages[^1]; + var chatMessage = chatResponse.Messages.Where(m => m.Role == ChatRole.Tool).Last(); Assert.NotEmpty(chatMessage.Contents); Assert.Contains(chatMessage.Contents, c => c is Microsoft.Extensions.AI.FunctionResultContent); diff --git a/dotnet/src/Functions/Functions.OpenApi/RestApiOperationRunner.cs b/dotnet/src/Functions/Functions.OpenApi/RestApiOperationRunner.cs index 28305a5ec8a4..cf07d36be1f5 100644 --- a/dotnet/src/Functions/Functions.OpenApi/RestApiOperationRunner.cs +++ b/dotnet/src/Functions/Functions.OpenApi/RestApiOperationRunner.cs @@ -13,6 +13,8 @@ using System.Threading.Tasks; using Microsoft.SemanticKernel.Http; +#pragma warning disable CA1859 // Use concrete types when possible for improved performance + namespace Microsoft.SemanticKernel.Plugins.OpenApi; /// diff --git a/dotnet/src/IntegrationTests/Plugins/Core/SessionsPythonPluginTests.cs b/dotnet/src/IntegrationTests/Plugins/Core/SessionsPythonPluginTests.cs index f76bff8901fd..994f986fb068 100644 --- a/dotnet/src/IntegrationTests/Plugins/Core/SessionsPythonPluginTests.cs +++ b/dotnet/src/IntegrationTests/Plugins/Core/SessionsPythonPluginTests.cs @@ -1,19 +1,19 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Threading.Tasks; -using Xunit; -using Microsoft.SemanticKernel.Plugins.Core.CodeInterpreter; -using Microsoft.Extensions.Configuration; -using SemanticKernel.IntegrationTests.TestSettings; +using System.Collections.Generic; using System.Net.Http; -using Azure.Identity; +using System.Threading.Tasks; using Azure.Core; -using System.Collections.Generic; -using Microsoft.SemanticKernel; +using Azure.Identity; +using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.AzureOpenAI; +using Microsoft.SemanticKernel.Plugins.Core.CodeInterpreter; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; namespace SemanticKernel.IntegrationTests.Plugins.Core; diff --git a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs index 88f3da9d6a53..55a300769812 100644 --- a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs +++ b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs @@ -51,7 +51,7 @@ internal sealed class FunctionCallsProcessor /// will be disabled. This is a safeguard against possible runaway execution if the model routinely re-requests /// the same function over and over. /// - private const int MaximumAutoInvokeAttempts = 128; + internal const int MaximumAutoInvokeAttempts = 128; /// Tracking for . /// diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs index 68540a1c32d8..1179de5f99b7 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs @@ -14,7 +14,6 @@ namespace Microsoft.SemanticKernel.ChatCompletion; internal static class ChatOptionsExtensions { internal const string KernelKey = "AutoInvokingKernel"; - internal const string IsStreamingKey = "AutoInvokingIsStreaming"; internal const string ChatMessageContentKey = "AutoInvokingChatCompletionContent"; internal const string PromptExecutionSettingsKey = "AutoInvokingPromptExecutionSettings"; diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs index ea2dce48fc62..babef7736ebd 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -1,821 +1,39 @@ // Copyright (c) Microsoft. All rights reserved. using System; -#pragma warning restore IDE0073 // The file header does not match the required text using System.Collections.Generic; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; using System.Linq; -using System.Runtime.CompilerServices; -using System.Runtime.ExceptionServices; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; -#pragma warning disable IDE1006 // Naming Styles -#pragma warning disable IDE0009 // This -#pragma warning disable CA2213 // Disposable fields should be disposed -#pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test -#pragma warning disable SA1202 // 'protected' members should come before 'private' members - -// Modified source from 2025-04-07 -// https://raw.githubusercontent.com/dotnet/extensions/84d09b794d994435568adcbb85a981143d4f15cb/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs - namespace Microsoft.Extensions.AI; /// -/// A delegating chat client that invokes functions defined on . -/// Include this in a chat pipeline to resolve function calls automatically. +/// Specialization of that uses and supports . /// -/// -/// -/// When this client receives a in a chat response, it responds -/// by calling the corresponding defined in , -/// producing a that it sends back to the inner client. This loop -/// is repeated until there are no more function calls to make, or until another stop condition is met, -/// such as hitting . -/// -/// -/// The provided implementation of is thread-safe for concurrent use so long as the -/// instances employed as part of the supplied are also safe. -/// The property can be used to control whether multiple function invocation -/// requests as part of the same request are invocable concurrently, but even with that set to -/// (the default), multiple concurrent requests to this same instance and using the same tools could result in those -/// tools being used concurrently (one per request). For example, a function that accesses the HttpContext of a specific -/// ASP.NET web request should only be used as part of a single at a time, and only with -/// set to , in case the inner client decided to issue multiple -/// invocation requests to that same function. -/// -/// -public partial class KernelFunctionInvokingChatClient : DelegatingChatClient +internal sealed class KernelFunctionInvokingChatClient : FunctionInvokingChatClient { - /// The for the current function invocation. - private static readonly AsyncLocal _currentContext = new(); - - /// Optional services used for function invocation. - private readonly IServiceProvider? _functionInvocationServices; - - /// The logger to use for logging information about function invocation. - private readonly ILogger _logger; - - /// The to use for telemetry. - /// This component does not own the instance and should not dispose it. - private readonly ActivitySource? _activitySource; - - /// Maximum number of roundtrips allowed to the inner client. - private int _maximumIterationsPerRequest = 10; - - /// Maximum number of consecutive iterations that are allowed contain at least one exception result. If the limit is exceeded, we rethrow the exception instead of continuing. - private int _maximumConsecutiveErrorsPerRequest = 3; - - /// - /// Initializes a new instance of the class. - /// - /// The underlying , or the next instance in a chain of clients. - /// An to use for logging information about function invocation. - /// An optional to use for resolving services required by the instances being invoked. - public KernelFunctionInvokingChatClient(IChatClient innerClient, ILoggerFactory? loggerFactory = null, IServiceProvider? functionInvocationServices = null) - : base(innerClient) - { - _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; - _activitySource = innerClient.GetService(); - _functionInvocationServices = functionInvocationServices; - } - - /// - /// Gets or sets the for the current function invocation. - /// - /// - /// This value flows across async calls. - /// - public static AutoFunctionInvocationContext? CurrentContext - { - get => _currentContext.Value; - protected set => _currentContext.Value = value; - } - - /// - /// Gets or sets a value indicating whether detailed exception information should be included - /// in the chat history when calling the underlying . - /// - /// - /// if the full exception message is added to the chat history - /// when calling the underlying . - /// if a generic error message is included in the chat history. - /// The default value is . - /// - /// - /// - /// Setting the value to prevents the underlying language model from disclosing - /// raw exception details to the end user, since it doesn't receive that information. Even in this - /// case, the raw object is available to application code by inspecting - /// the property. - /// - /// - /// Setting the value to can help the underlying bypass problems on - /// its own, for example by retrying the function call with different arguments. However, it might - /// result in disclosing the raw exception information to external users, which can be a security - /// concern depending on the application scenario. - /// - /// - /// Changing the value of this property while the client is in use might result in inconsistencies - /// whether detailed errors are provided during an in-flight request. - /// - /// - public bool IncludeDetailedErrors { get; set; } - - /// - /// Gets or sets a value indicating whether to allow concurrent invocation of functions. - /// - /// - /// if multiple function calls can execute in parallel. - /// if function calls are processed serially. - /// The default value is . - /// - /// - /// An individual response from the inner client might contain multiple function call requests. - /// By default, such function calls are processed serially. Set to - /// to enable concurrent invocation such that multiple function calls can execute in parallel. - /// - public bool AllowConcurrentInvocation { get; set; } - - /// - /// Gets or sets the maximum number of iterations per request. - /// - /// - /// The maximum number of iterations per request. - /// The default value is 10. - /// - /// - /// - /// Each request to this might end up making - /// multiple requests to the inner client. Each time the inner client responds with - /// a function call request, this client might perform that invocation and send the results - /// back to the inner client in a new request. This property limits the number of times - /// such a roundtrip is performed. The value must be at least one, as it includes the initial request. - /// - /// - /// Changing the value of this property while the client is in use might result in inconsistencies - /// as to how many iterations are allowed for an in-flight request. - /// - /// - public int MaximumIterationsPerRequest - { - get => _maximumIterationsPerRequest; - set - { - if (value < 1) - { - throw new ArgumentOutOfRangeException(nameof(value)); - } - - _maximumIterationsPerRequest = value; - } - } - - /// - /// Gets or sets the maximum number of consecutive iterations that are allowed to fail with an error. - /// - /// - /// The maximum number of consecutive iterations that are allowed to fail with an error. - /// The default value is 3. - /// - /// - /// - /// When function invocations fail with an exception, the - /// continues to make requests to the inner client, optionally supplying exception information (as - /// controlled by ). This allows the to - /// recover from errors by trying other function parameters that may succeed. - /// - /// - /// However, in case function invocations continue to produce exceptions, this property can be used to - /// limit the number of consecutive failing attempts. When the limit is reached, the exception will be - /// rethrown to the caller. - /// - /// - /// If the value is set to zero, all function calling exceptions immediately terminate the function - /// invocation loop and the exception will be rethrown to the caller. - /// - /// - /// Changing the value of this property while the client is in use might result in inconsistencies - /// as to how many iterations are allowed for an in-flight request. - /// - /// - public int MaximumConsecutiveErrorsPerRequest - { - get => _maximumConsecutiveErrorsPerRequest; - set - { - if (value < 0) - { - throw new ArgumentOutOfRangeException(nameof(value), "Argument less than minimum value 0"); - } - _maximumConsecutiveErrorsPerRequest = value; - } - } - /// - public override async Task GetResponseAsync( - IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) - { - Verify.NotNull(messages); - - // A single request into this GetResponseAsync may result in multiple requests to the inner client. - // Create an activity to group them together for better observability. - using Activity? activity = _activitySource?.StartActivity(nameof(KernelFunctionInvokingChatClient)); - - // Copy the original messages in order to avoid enumerating the original messages multiple times. - // The IEnumerable can represent an arbitrary amount of work. - List originalMessages = [.. messages]; - messages = originalMessages; - - List? augmentedHistory = null; // the actual history of messages sent on turns other than the first - ChatResponse? response = null; // the response from the inner client, which is possibly modified and then eventually returned - List? responseMessages = null; // tracked list of messages, across multiple turns, to be used for the final response - UsageDetails? totalUsage = null; // tracked usage across all turns, to be used for the final response - List? functionCallContents = null; // function call contents that need responding to in the current turn - bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set - int consecutiveErrorCount = 0; - - for (int iteration = 0; ; iteration++) - { - functionCallContents?.Clear(); - - // Make the call to the inner client. - response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); - if (response is null) - { - throw new InvalidOperationException($"The inner {nameof(IChatClient)} returned a null {nameof(ChatResponse)}."); - } - - // Any function call work to do? If yes, ensure we're tracking that work in functionCallContents. - bool requiresFunctionInvocation = - options?.Tools is { Count: > 0 } && - iteration < MaximumIterationsPerRequest && - CopyFunctionCalls(response.Messages, ref functionCallContents); - - // In a common case where we make a request and there's no function calling work required, - // fast path out by just returning the original response. - if (iteration == 0 && !requiresFunctionInvocation) - { - return response; - } - - // Track aggregate details from the response, including all the response messages and usage details. - (responseMessages ??= []).AddRange(response.Messages); - if (response.Usage is not null) - { - if (totalUsage is not null) - { - totalUsage.Add(response.Usage); - } - else - { - totalUsage = response.Usage; - } - } - - // If there are no tools to call, or for any other reason we should stop, we're done. - // Break out of the loop and allow the handling at the end to configure the response - // with aggregated data from previous requests. - if (!requiresFunctionInvocation) - { - break; - } - - // Prepare the history for the next iteration. - FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId); - - // Prepare the options for the next auto function invocation iteration. - UpdateOptionsForAutoFunctionInvocation(ref options!, response.Messages.Last().ToChatMessageContent(), isStreaming: false); - - // Add the responses from the function calls into the augmented history and also into the tracked - // list of response messages. - var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, isStreaming: false, cancellationToken).ConfigureAwait(false); - responseMessages.AddRange(modeAndMessages.MessagesAdded); - consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; - - if (modeAndMessages.ShouldTerminate) - { - break; - } - - // Clear the auto function invocation options. - ClearOptionsForAutoFunctionInvocation(ref options); - - UpdateOptionsForNextIteration(ref options!, response.ChatThreadId); - } - - Debug.Assert(responseMessages is not null, "Expected to only be here if we have response messages."); - response.Messages = responseMessages!; - response.Usage = totalUsage; - - return response; - } - - /// - public override async IAsyncEnumerable GetStreamingResponseAsync( - IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - Verify.NotNull(messages); - - // A single request into this GetStreamingResponseAsync may result in multiple requests to the inner client. - // Create an activity to group them together for better observability. - using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient)); - - // Copy the original messages in order to avoid enumerating the original messages multiple times. - // The IEnumerable can represent an arbitrary amount of work. - List originalMessages = [.. messages]; - messages = originalMessages; - - List? augmentedHistory = null; // the actual history of messages sent on turns other than the first - List? functionCallContents = null; // function call contents that need responding to in the current turn - List? responseMessages = null; // tracked list of messages, across multiple turns, to be used in fallback cases to reconstitute history - bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set - List updates = []; // updates from the current response - int consecutiveErrorCount = 0; - - for (int iteration = 0; ; iteration++) - { - updates.Clear(); - functionCallContents?.Clear(); - - await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) - { - if (update is null) - { - throw new InvalidOperationException($"The inner {nameof(IChatClient)} streamed a null {nameof(ChatResponseUpdate)}."); - } - - updates.Add(update); - - _ = CopyFunctionCalls(update.Contents, ref functionCallContents); - - yield return update; - Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 - } - - // If there are no tools to call, or for any other reason we should stop, return the response. - if (functionCallContents is not { Count: > 0 } || - options?.Tools is not { Count: > 0 } || - iteration >= _maximumIterationsPerRequest) - { - break; - } - - // Reconstitute a response from the response updates. - var response = updates.ToChatResponse(); - (responseMessages ??= []).AddRange(response.Messages); - - // Prepare the history for the next iteration. - FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId); - - // Prepare the options for the next auto function invocation iteration. - UpdateOptionsForAutoFunctionInvocation(ref options, response.Messages.Last().ToChatMessageContent(), isStreaming: true); - - // Process all the functions, adding their results into the history. - var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, isStreaming: true, cancellationToken).ConfigureAwait(false); - responseMessages.AddRange(modeAndMessages.MessagesAdded); - consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; - - // Clear the auto function invocation options. - ClearOptionsForAutoFunctionInvocation(ref options); - - // This is a synthetic ID since we're generating the tool messages instead of getting them from - // the underlying provider. When emitting the streamed chunks, it's perfectly valid for us to - // use the same message ID for all of them within a given iteration, as this is a single logical - // message with multiple content items. We could also use different message IDs per tool content, - // but there's no benefit to doing so. - string toolResponseId = Guid.NewGuid().ToString("N"); - - // Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages - // include all activity, including generated function results. - foreach (var message in modeAndMessages.MessagesAdded) - { - var toolResultUpdate = new ChatResponseUpdate - { - AdditionalProperties = message.AdditionalProperties, - AuthorName = message.AuthorName, - ChatThreadId = response.ChatThreadId, - CreatedAt = DateTimeOffset.UtcNow, - Contents = message.Contents, - RawRepresentation = message.RawRepresentation, - ResponseId = toolResponseId, - MessageId = toolResponseId, // See above for why this can be the same as ResponseId - Role = message.Role, - }; - - yield return toolResultUpdate; - Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 - } - - if (modeAndMessages.ShouldTerminate) - { - yield break; - } - - UpdateOptionsForNextIteration(ref options, response.ChatThreadId); - } - } - - /// Prepares the various chat message lists after a response from the inner client and before invoking functions. - /// The original messages provided by the caller. - /// The messages reference passed to the inner client. - /// The augmented history containing all the messages to be sent. - /// The most recent response being handled. - /// A list of all response messages received up until this point. - /// Whether the previous iteration's response had a thread id. - private static void FixupHistories( - IEnumerable originalMessages, - ref IEnumerable messages, - [NotNull] ref List? augmentedHistory, - ChatResponse response, - List allTurnsResponseMessages, - ref bool lastIterationHadThreadId) - { - // We're now going to need to augment the history with function result contents. - // That means we need a separate list to store the augmented history. - if (response.ChatThreadId is not null) - { - // The response indicates the inner client is tracking the history, so we don't want to send - // anything we've already sent or received. - if (augmentedHistory is not null) - { - augmentedHistory.Clear(); - } - else - { - augmentedHistory = []; - } - - lastIterationHadThreadId = true; - } - else if (lastIterationHadThreadId) - { - // In the very rare case where the inner client returned a response with a thread ID but then - // returned a subsequent response without one, we want to reconstitute the full history. To do that, - // we can populate the history with the original chat messages and then all the response - // messages up until this point, which includes the most recent ones. - augmentedHistory ??= []; - augmentedHistory.Clear(); - augmentedHistory.AddRange(originalMessages); - augmentedHistory.AddRange(allTurnsResponseMessages); - - lastIterationHadThreadId = false; - } - else - { - // If augmentedHistory is already non-null, then we've already populated it with everything up - // until this point (except for the most recent response). If it's null, we need to seed it with - // the chat history provided by the caller. - augmentedHistory ??= originalMessages.ToList(); - - // Now add the most recent response messages. - augmentedHistory.AddMessages(response); - - lastIterationHadThreadId = false; - } - - // Use the augmented history as the new set of messages to send. - messages = augmentedHistory; - } - - /// Copies any from to . - private static bool CopyFunctionCalls( - IList messages, [NotNullWhen(true)] ref List? functionCalls) - { - bool any = false; - int count = messages.Count; - for (int i = 0; i < count; i++) - { - any |= CopyFunctionCalls(messages[i].Contents, ref functionCalls); - } - - return any; - } - - /// Copies any from to . - private static bool CopyFunctionCalls( - IList content, [NotNullWhen(true)] ref List? functionCalls) + public KernelFunctionInvokingChatClient(IChatClient innerClient, ILoggerFactory? loggerFactory = null, IServiceProvider? functionInvocationServices = null) + : base(innerClient, loggerFactory, functionInvocationServices) { - bool any = false; - int count = content.Count; - for (int i = 0; i < count; i++) - { - if (content[i] is FunctionCallContent functionCall) - { - (functionCalls ??= []).Add(functionCall); - any = true; - } - } - - return any; + this.MaximumIterationsPerRequest = 128; } - private static void UpdateOptionsForAutoFunctionInvocation(ref ChatOptions options, ChatMessageContent content, bool isStreaming) + private static void UpdateOptionsForAutoFunctionInvocation(ChatOptions options, ChatMessageContent content) { - if (options.AdditionalProperties?.ContainsKey(ChatOptionsExtensions.IsStreamingKey) ?? false) - { - throw new KernelException($"The reserved key name '{ChatOptionsExtensions.IsStreamingKey}' is already specified in the options. Avoid using this key name."); - } - if (options.AdditionalProperties?.ContainsKey(ChatOptionsExtensions.ChatMessageContentKey) ?? false) { - throw new KernelException($"The reserved key name '{ChatOptionsExtensions.ChatMessageContentKey}' is already specified in the options. Avoid using this key name."); + return; } options.AdditionalProperties ??= []; - - options.AdditionalProperties[ChatOptionsExtensions.IsStreamingKey] = isStreaming; options.AdditionalProperties[ChatOptionsExtensions.ChatMessageContentKey] = content; } - private static void ClearOptionsForAutoFunctionInvocation(ref ChatOptions options) - { - if (options.AdditionalProperties?.ContainsKey(ChatOptionsExtensions.IsStreamingKey) ?? false) - { - options.AdditionalProperties.Remove(ChatOptionsExtensions.IsStreamingKey); - } - - if (options.AdditionalProperties?.ContainsKey(ChatOptionsExtensions.ChatMessageContentKey) ?? false) - { - options.AdditionalProperties.Remove(ChatOptionsExtensions.ChatMessageContentKey); - } - } - - private static void UpdateOptionsForNextIteration(ref ChatOptions options, string? chatThreadId) - { - if (options.ToolMode is RequiredChatToolMode) - { - // We have to reset the tool mode to be non-required after the first iteration, - // as otherwise we'll be in an infinite loop. - options = options.Clone(); - options.ToolMode = null; - options.ChatThreadId = chatThreadId; - } - else if (options.ChatThreadId != chatThreadId) - { - // As with the other modes, ensure we've propagated the chat thread ID to the options. - // We only need to clone the options if we're actually mutating it. - options = options.Clone(); - options.ChatThreadId = chatThreadId; - } - } - - /// - /// Processes the function calls in the list. - /// - /// The current chat contents, inclusive of the function call contents being processed. - /// The options used for the response being processed. - /// The function call contents representing the functions to be invoked. - /// The iteration number of how many roundtrips have been made to the inner client. - /// The number of consecutive iterations, prior to this one, that were recorded as having function invocation errors. - /// Whether the function calls are being processed in a streaming context. - /// The to monitor for cancellation requests. - /// A value indicating how the caller should proceed. - private async Task<(bool ShouldTerminate, int NewConsecutiveErrorCount, IList MessagesAdded)> ProcessFunctionCallsAsync( - List messages, ChatOptions options, List functionCallContents, - int iteration, int consecutiveErrorCount, bool isStreaming, CancellationToken cancellationToken) - { - // We must add a response for every tool call, regardless of whether we successfully executed it or not. - // If we successfully execute it, we'll add the result. If we don't, we'll add an error. - - Debug.Assert(functionCallContents.Count > 0, "Expected at least one function call."); - var shouldTerminate = false; - - var captureCurrentIterationExceptions = consecutiveErrorCount < _maximumConsecutiveErrorsPerRequest; - - // Process all functions. If there's more than one and concurrent invocation is enabled, do so in parallel. - if (functionCallContents.Count == 1) - { - FunctionInvocationResult result = await ProcessFunctionCallAsync( - messages, options, functionCallContents, iteration, 0, captureCurrentIterationExceptions, isStreaming, cancellationToken).ConfigureAwait(false); - - IList added = CreateResponseMessages([result]); - ThrowIfNoFunctionResultsAdded(added); - UpdateConsecutiveErrorCountOrThrow(added, ref consecutiveErrorCount); - - messages.AddRange(added); - return (result.ShouldTerminate, consecutiveErrorCount, added); - } - else - { - List results = []; - - var terminationRequested = false; - if (AllowConcurrentInvocation) - { - // Rather than awaiting each function before invoking the next, invoke all of them - // and then await all of them. We avoid forcibly introducing parallelism via Task.Run, - // but if a function invocation completes asynchronously, its processing can overlap - // with the processing of other the other invocation invocations. - results.AddRange(await Task.WhenAll( - from i in Enumerable.Range(0, functionCallContents.Count) - select ProcessFunctionCallAsync( - messages, options, functionCallContents, - iteration, i, captureExceptions: true, isStreaming, cancellationToken)).ConfigureAwait(false)); - - terminationRequested = results.Any(r => r.ShouldTerminate); - } - else - { - // Invoke each function serially. - for (int i = 0; i < functionCallContents.Count; i++) - { - var result = await ProcessFunctionCallAsync( - messages, options, functionCallContents, - iteration, i, captureCurrentIterationExceptions, isStreaming, cancellationToken).ConfigureAwait(false); - - results.Add(result); - - if (result.ShouldTerminate) - { - shouldTerminate = true; - terminationRequested = true; - break; - } - } - } - - IList added = CreateResponseMessages(results); - ThrowIfNoFunctionResultsAdded(added); - UpdateConsecutiveErrorCountOrThrow(added, ref consecutiveErrorCount); - - messages.AddRange(added); - - if (!terminationRequested) - { - // If any function requested termination, we'll terminate. - shouldTerminate = false; - foreach (FunctionInvocationResult fir in results) - { - shouldTerminate = shouldTerminate || fir.ShouldTerminate; - } - } - - return (shouldTerminate, consecutiveErrorCount, added); - } - } - - private void UpdateConsecutiveErrorCountOrThrow(IList added, ref int consecutiveErrorCount) - { - var allExceptions = added.SelectMany(m => m.Contents.OfType()) - .Select(frc => frc.Exception!) - .Where(e => e is not null); - -#pragma warning disable CA1851 // Possible multiple enumerations of 'IEnumerable' collection - if (allExceptions.Any()) - { - consecutiveErrorCount++; - if (consecutiveErrorCount > _maximumConsecutiveErrorsPerRequest) - { - var allExceptionsArray = allExceptions.ToArray(); - if (allExceptionsArray.Length == 1) - { - ExceptionDispatchInfo.Capture(allExceptionsArray[0]).Throw(); - } - else - { - throw new AggregateException(allExceptionsArray); - } - } - } - else - { - consecutiveErrorCount = 0; - } -#pragma warning restore CA1851 // Possible multiple enumerations of 'IEnumerable' collection - } - - /// - /// Throws an exception if doesn't create any messages. - /// - private void ThrowIfNoFunctionResultsAdded(IList? messages) - { - if (messages is null || messages.Count == 0) - { - throw new InvalidOperationException($"{this.GetType().Name}.{nameof(this.CreateResponseMessages)} returned null or an empty collection of messages."); - } - } - - /// Processes the function call described in []. - /// The current chat contents, inclusive of the function call contents being processed. - /// The options used for the response being processed. - /// The function call contents representing all the functions being invoked. - /// The iteration number of how many roundtrips have been made to the inner client. - /// The 0-based index of the function being called out of . - /// If true, handles function-invocation exceptions by returning a value with . Otherwise, rethrows. - /// Whether the function calls are being processed in a streaming context. - /// The to monitor for cancellation requests. - /// A value indicating how the caller should proceed. - private async Task ProcessFunctionCallAsync( - List messages, ChatOptions options, List callContents, - int iteration, int functionCallIndex, bool captureExceptions, bool isStreaming, CancellationToken cancellationToken) - { - var callContent = callContents[functionCallIndex]; - - // Look up the AIFunction for the function call. If the requested function isn't available, send back an error. - AIFunction? function = options.Tools!.OfType().FirstOrDefault(t => t.Name == callContent.Name); - if (function is null) - { - return new(shouldTerminate: false, FunctionInvokingChatClient.FunctionInvocationStatus.NotFound, callContent, result: null, exception: null); - } - - if (callContent.Arguments is not null) - { - callContent.Arguments = new KernelArguments(callContent.Arguments); - } - - var context = new AutoFunctionInvocationContext(new() - { - Function = function, - Arguments = new(callContent.Arguments) { Services = _functionInvocationServices }, - - Messages = messages, - Options = options, - - CallContent = callContent, - Iteration = iteration, - FunctionCallIndex = functionCallIndex, - FunctionCount = callContents.Count, - }) - { IsStreaming = isStreaming }; - - object? result; - try - { - result = await InvokeFunctionAsync(context, cancellationToken).ConfigureAwait(false); - } - catch (Exception e) when (!cancellationToken.IsCancellationRequested) - { - if (!captureExceptions) - { - throw; - } - - return new( - shouldTerminate: false, - FunctionInvokingChatClient.FunctionInvocationStatus.Exception, - callContent, - result: null, - exception: e); - } - - return new( - shouldTerminate: context.Terminate, - FunctionInvokingChatClient.FunctionInvocationStatus.RanToCompletion, - callContent, - result, - exception: null); - } - - /// Creates one or more response messages for function invocation results. - /// Information about the function call invocations and results. - /// A list of all chat messages created from . - private IList CreateResponseMessages(List results) - { - var contents = new List(results.Count); - for (int i = 0; i < results.Count; i++) - { - contents.Add(CreateFunctionResultContent(results[i])); - } - - return [new(ChatRole.Tool, contents)]; - - FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult result) - { - Verify.NotNull(result); - - object? functionResult; - if (result.Status == FunctionInvokingChatClient.FunctionInvocationStatus.RanToCompletion) - { - functionResult = result.Result ?? "Success: Function completed."; - } - else - { - string message = result.Status switch - { - FunctionInvokingChatClient.FunctionInvocationStatus.NotFound => $"Error: Requested function \"{result.CallContent.Name}\" not found.", - FunctionInvokingChatClient.FunctionInvocationStatus.Exception => "Error: Function failed.", - _ => "Error: Unknown error.", - }; - - if (IncludeDetailedErrors && result.Exception is not null) - { - message = $"{message} Exception: {result.Exception.Message}"; - } - - functionResult = message; - } - - return new FunctionResultContent(result.CallContent.CallId, functionResult) { Exception = result.Exception }; - } - } - /// /// Invokes the auto function invocation filters. /// @@ -858,166 +76,49 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( } } - /// Invokes the function asynchronously. - /// - /// The function invocation context detailing the function to be invoked and its arguments along with additional request information. - /// - /// The to monitor for cancellation requests. The default is . - /// The result of the function invocation, or if the function invocation returned . - /// is . - private async Task InvokeFunctionAsync(AutoFunctionInvocationContext context, CancellationToken cancellationToken) + /// + protected override async ValueTask InvokeFunctionAsync(Microsoft.Extensions.AI.FunctionInvocationContext context, CancellationToken cancellationToken) { - Verify.NotNull(context); - - using Activity? activity = _activitySource?.StartActivity(context.Function.Name); - - long startingTimestamp = 0; - if (_logger.IsEnabled(LogLevel.Debug)) + if (context.Options is null) { - startingTimestamp = Stopwatch.GetTimestamp(); - if (_logger.IsEnabled(LogLevel.Trace)) - { - LogInvokingSensitive(context.Function.Name, LoggingAsJson(context.CallContent.Arguments, context.AIFunction.JsonSerializerOptions)); - } - else - { - LogInvoking(context.Function.Name); - } + return await context.Function.InvokeAsync(context.Arguments, cancellationToken).ConfigureAwait(false); } object? result = null; - try - { - CurrentContext = context; // doesn't need to be explicitly reset after, as that's handled automatically at async method exit - context = await this.OnAutoFunctionInvocationAsync( - context, - async (ctx) => - { - // Check if filter requested termination - if (ctx.Terminate) - { - return; - } - - // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any - // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here, - // as the called function could in turn telling the model about itself as a possible candidate for invocation. - result = await context.AIFunction.InvokeAsync(new(context.Arguments), cancellationToken).ConfigureAwait(false); - ctx.Result = new FunctionResult(ctx.Function, result); - }).ConfigureAwait(false); - result = context.Result.GetValue(); - } - catch (Exception e) - { - if (activity is not null) - { - _ = activity.SetTag("error.type", e.GetType().FullName) - .SetStatus(ActivityStatusCode.Error, e.Message); - } - if (e is OperationCanceledException) - { - LogInvocationCanceled(context.Function.Name); - } - else - { - LogInvocationFailed(context.Function.Name, e); - } - - throw; - } - finally - { - if (_logger.IsEnabled(LogLevel.Debug)) - { - TimeSpan elapsed = GetElapsedTime(startingTimestamp); - - if (result is not null && _logger.IsEnabled(LogLevel.Trace)) - { - LogInvocationCompletedSensitive(context.Function.Name, elapsed, LoggingAsJson(result, context.AIFunction.JsonSerializerOptions)); - } - else + UpdateOptionsForAutoFunctionInvocation(context.Options, context.Messages.Last().ToChatMessageContent()); + var autoContext = new AutoFunctionInvocationContext(context.Options) + { + AIFunction = context.Function, + Arguments = new KernelArguments(context.Arguments) { Services = this.FunctionInvocationServices }, + Messages = context.Messages, + CallContent = context.CallContent, + Iteration = context.Iteration, + FunctionCallIndex = context.FunctionCallIndex, + FunctionCount = context.FunctionCount, + IsStreaming = context.IsStreaming + }; + + autoContext = await this.OnAutoFunctionInvocationAsync( + autoContext, + async (ctx) => + { + // Check if filter requested termination + if (ctx.Terminate) { - LogInvocationCompleted(context.Function.Name, elapsed); + return; } - } - } - - return result; - } - - /// Serializes as JSON for logging purposes. - private static string LoggingAsJson(T value, JsonSerializerOptions? options) - { - if (options?.TryGetTypeInfo(typeof(T), out var typeInfo) is true || - AIJsonUtilities.DefaultOptions.TryGetTypeInfo(typeof(T), out typeInfo)) - { -#pragma warning disable CA1031 // Do not catch general exception types - try - { - return JsonSerializer.Serialize(value, typeInfo); - } - catch - { - } -#pragma warning restore CA1031 // Do not catch general exception types - } - - // If we're unable to get a type info for the value, or if we fail to serialize, - // return an empty JSON object. We do not want lack of type info to disrupt application behavior with exceptions. - return "{}"; - } - - private static TimeSpan GetElapsedTime(long startingTimestamp) => -#if NET - Stopwatch.GetElapsedTime(startingTimestamp); -#else - new((long)((Stopwatch.GetTimestamp() - startingTimestamp) * ((double)TimeSpan.TicksPerSecond / Stopwatch.Frequency))); -#endif - - [LoggerMessage(LogLevel.Debug, "Invoking {MethodName}.", SkipEnabledCheck = true)] - private partial void LogInvoking(string methodName); - - [LoggerMessage(LogLevel.Trace, "Invoking {MethodName}({Arguments}).", SkipEnabledCheck = true)] - private partial void LogInvokingSensitive(string methodName, string arguments); - - [LoggerMessage(LogLevel.Debug, "{MethodName} invocation completed. Duration: {Duration}", SkipEnabledCheck = true)] - private partial void LogInvocationCompleted(string methodName, TimeSpan duration); - - [LoggerMessage(LogLevel.Trace, "{MethodName} invocation completed. Duration: {Duration}. Result: {Result}", SkipEnabledCheck = true)] - private partial void LogInvocationCompletedSensitive(string methodName, TimeSpan duration, string result); - [LoggerMessage(LogLevel.Debug, "{MethodName} invocation canceled.")] - private partial void LogInvocationCanceled(string methodName); + // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any + // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here, + // as the called function could in turn telling the model about itself as a possible candidate for invocation. + result = await autoContext.AIFunction.InvokeAsync(autoContext.Arguments, cancellationToken).ConfigureAwait(false); + ctx.Result = new FunctionResult(ctx.Function, result); + }).ConfigureAwait(false); + result = autoContext.Result.GetValue(); - [LoggerMessage(LogLevel.Error, "{MethodName} invocation failed.")] - private partial void LogInvocationFailed(string methodName, Exception error); + context.Terminate = autoContext.Terminate; - /// Provides information about the invocation of a function call. - public sealed class FunctionInvocationResult - { - internal FunctionInvocationResult(bool shouldTerminate, FunctionInvokingChatClient.FunctionInvocationStatus status, FunctionCallContent callContent, object? result, Exception? exception) - { - ShouldTerminate = shouldTerminate; - Status = status; - CallContent = callContent; - Result = result; - Exception = exception; - } - - /// Gets status about how the function invocation completed. - public FunctionInvokingChatClient.FunctionInvocationStatus Status { get; } - - /// Gets the function call content information associated with this invocation. - public FunctionCallContent CallContent { get; } - - /// Gets the result of the function call. - public object? Result { get; } - - /// Gets any exception the function call threw. - public Exception? Exception { get; } - - /// Gets a value indicating whether the caller should terminate the processing loop. - internal bool ShouldTerminate { get; } + return result; } } diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs index bc8dd0c3490c..b58aadb73ddc 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs @@ -13,33 +13,34 @@ namespace Microsoft.SemanticKernel; /// /// Class with data related to automatic function invocation. /// -public class AutoFunctionInvocationContext +public class AutoFunctionInvocationContext : Microsoft.Extensions.AI.FunctionInvocationContext { private ChatHistory? _chatHistory; private KernelFunction? _kernelFunction; - private readonly Microsoft.Extensions.AI.FunctionInvocationContext _invocationContext = new(); /// /// Initializes a new instance of the class from an existing . /// - internal AutoFunctionInvocationContext(Microsoft.Extensions.AI.FunctionInvocationContext invocationContext) + internal AutoFunctionInvocationContext(ChatOptions options) { - Verify.NotNull(invocationContext); - Verify.NotNull(invocationContext.Options); + Verify.NotNull(options); // the ChatOptions must be provided with AdditionalProperties. - Verify.NotNull(invocationContext.Options.AdditionalProperties); + Verify.NotNull(options.AdditionalProperties); - invocationContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.KernelKey, out var kernel); + // The ChatOptions must be provided with the kernel. + options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.KernelKey, out var kernel); Verify.NotNull(kernel); - invocationContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.ChatMessageContentKey, out var chatMessageContent); + // The ChatOptions must be provided with the chat message content. + options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.ChatMessageContentKey, out var chatMessageContent); Verify.NotNull(chatMessageContent); - invocationContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.PromptExecutionSettingsKey, out var executionSettings); - this.ExecutionSettings = executionSettings; - this._invocationContext = invocationContext; + // The ChatOptions can be provided with the execution settings. + options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.PromptExecutionSettingsKey, out var executionSettings); + this.ExecutionSettings = executionSettings; + this.Options = options; this.Result = new FunctionResult(this.Function) { Culture = kernel.Culture }; } @@ -64,7 +65,7 @@ public AutoFunctionInvocationContext( Verify.NotNull(chatHistory); Verify.NotNull(chatMessageContent); - this._invocationContext.Options = new() + this.Options = new() { AdditionalProperties = new() { @@ -75,9 +76,9 @@ public AutoFunctionInvocationContext( this._kernelFunction = function; this._chatHistory = chatHistory; - this._invocationContext.Messages = chatHistory.ToChatMessageList(); - chatHistory.SetChatMessageHandlers(this._invocationContext.Messages); - this._invocationContext.Function = function.AsAIFunction(); + this.Messages = chatHistory.ToChatMessageList(); + chatHistory.SetChatMessageHandlers(this.Messages); + base.Function = function.AsAIFunction(); this.Result = result; } @@ -88,17 +89,25 @@ public AutoFunctionInvocationContext( public CancellationToken CancellationToken { get; init; } /// - /// Boolean flag which indicates whether a filter is invoked within streaming or non-streaming mode. + /// Gets the specialized version of associated with the operation. /// - public bool IsStreaming { get; init; } - - /// - /// Gets the arguments associated with the operation. - /// - public KernelArguments? Arguments + /// + /// Due to a clash with the as a type, this property hides + /// it to not break existing code that relies on the as a type. + /// + /// Attempting to access the property when the arguments is not a class. + public new KernelArguments? Arguments { - get => this._invocationContext.CallContent.Arguments is KernelArguments kernelArguments ? kernelArguments : null; - init => this._invocationContext.CallContent.Arguments = value; + get + { + if (base.Arguments is KernelArguments kernelArguments) + { + return kernelArguments; + } + + throw new InvalidOperationException($"The arguments provided in the initialization must be of type {nameof(KernelArguments)}."); + } + init => base.Arguments = value ?? new(); } /// @@ -106,8 +115,8 @@ public KernelArguments? Arguments /// public int RequestSequenceIndex { - get => this._invocationContext.Iteration; - init => this._invocationContext.Iteration = value; + get => this.Iteration; + init => this.Iteration = value; } /// @@ -115,19 +124,8 @@ public int RequestSequenceIndex /// public int FunctionSequenceIndex { - get => this._invocationContext.FunctionCallIndex; - init => this._invocationContext.FunctionCallIndex = value; - } - - /// Gets or sets the total number of function call requests within the iteration. - /// - /// The response from the underlying client might include multiple function call requests. - /// This count indicates how many there were. - /// - public int FunctionCount - { - get => this._invocationContext.FunctionCount; - init => this._invocationContext.FunctionCount = value; + get => this.FunctionCallIndex; + init => this.FunctionCallIndex = value; } /// @@ -135,13 +133,13 @@ public int FunctionCount /// public string? ToolCallId { - get => this._invocationContext.CallContent.CallId; + get => this.CallContent.CallId; init { - this._invocationContext.CallContent = new Microsoft.Extensions.AI.FunctionCallContent( + this.CallContent = new Microsoft.Extensions.AI.FunctionCallContent( callId: value ?? string.Empty, - name: this._invocationContext.CallContent.Name, - arguments: this._invocationContext.CallContent.Arguments); + name: this.CallContent.Name, + arguments: this.CallContent.Arguments); } } @@ -149,40 +147,44 @@ public string? ToolCallId /// The chat message content associated with automatic function invocation. /// public ChatMessageContent ChatMessageContent - => (this._invocationContext.Options?.AdditionalProperties?[ChatOptionsExtensions.ChatMessageContentKey] as ChatMessageContent)!; + => (this.Options?.AdditionalProperties?[ChatOptionsExtensions.ChatMessageContentKey] as ChatMessageContent)!; /// /// The execution settings associated with the operation. /// public PromptExecutionSettings? ExecutionSettings { - get => this._invocationContext.Options?.AdditionalProperties?[ChatOptionsExtensions.PromptExecutionSettingsKey] as PromptExecutionSettings; + get => this.Options?.AdditionalProperties?[ChatOptionsExtensions.PromptExecutionSettingsKey] as PromptExecutionSettings; init { - this._invocationContext.Options ??= new(); - this._invocationContext.Options.AdditionalProperties ??= []; - this._invocationContext.Options.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = value; + this.Options ??= new(); + this.Options.AdditionalProperties ??= []; + this.Options.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = value; } } /// /// Gets the associated with automatic function invocation. /// - public ChatHistory ChatHistory => this._chatHistory ??= new ChatMessageHistory(this._invocationContext.Messages); + public ChatHistory ChatHistory => this._chatHistory ??= new ChatMessageHistory(this.Messages); /// /// Gets the with which this filter is associated. /// - public KernelFunction Function + /// + /// Due to a clash with the as a type, this property hides + /// it to not break existing code that relies on the as a type. + /// + public new KernelFunction Function { get { if (this._kernelFunction is null // If the schemas are different, // AIFunction reference potentially was modified and the kernel function should be regenerated. - || !IsSameSchema(this._kernelFunction, this._invocationContext.Function)) + || !IsSameSchema(this._kernelFunction, base.Function)) { - this._kernelFunction = this._invocationContext.Function.AsKernelFunction(); + this._kernelFunction = base.Function.AsKernelFunction(); } return this._kernelFunction; @@ -197,7 +199,7 @@ public Kernel Kernel get { Kernel? kernel = null; - this._invocationContext.Options?.AdditionalProperties?.TryGetValue(ChatOptionsExtensions.KernelKey, out kernel); + this.Options?.AdditionalProperties?.TryGetValue(ChatOptionsExtensions.KernelKey, out kernel); // To avoid exception from properties, when attempting to retrieve a kernel from a non-ready context, it will give a null. return kernel!; @@ -209,30 +211,13 @@ public Kernel Kernel /// public FunctionResult Result { get; set; } - /// Gets or sets a value indicating whether to terminate the request. - /// - /// In response to a function call request, the function might be invoked, its result added to the chat contents, - /// and a new request issued to the wrapped client. If this property is set to , that subsequent request - /// will not be issued and instead the loop immediately terminated rather than continuing until there are no - /// more function call requests in responses. - /// - public bool Terminate - { - get => this._invocationContext.Terminate; - set => this._invocationContext.Terminate = value; - } - - /// Gets or sets the function call content information associated with this invocation. - internal Microsoft.Extensions.AI.FunctionCallContent CallContent - { - get => this._invocationContext.CallContent; - set => this._invocationContext.CallContent = value; - } - + /// + /// Gets or sets the with which this filter is associated. + /// internal AIFunction AIFunction { - get => this._invocationContext.Function; - set => this._invocationContext.Function = value; + get => base.Function; + set => base.Function = value; } private static bool IsSameSchema(KernelFunction kernelFunction, AIFunction aiFunction) diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelArguments.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelArguments.cs index eda736b3f583..419a12039049 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelArguments.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelArguments.cs @@ -1,9 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections; using System.Collections.Generic; using System.Text.Json.Serialization; +using Microsoft.Extensions.AI; #pragma warning disable CA1710 // Identifiers should have correct suffix @@ -17,10 +17,8 @@ namespace Microsoft.SemanticKernel; /// A is a dictionary of argument names and values. It also carries a /// , accessible via the property. /// -public sealed class KernelArguments : IDictionary, IReadOnlyDictionary +public sealed class KernelArguments : AIFunctionArguments { - /// Dictionary of name/values for all the arguments in the instance. - private readonly Dictionary _arguments; private IReadOnlyDictionary? _executionSettings; /// @@ -28,8 +26,8 @@ public sealed class KernelArguments : IDictionary, IReadOnlyDic /// [JsonConstructor] public KernelArguments() + : base(StringComparer.OrdinalIgnoreCase) { - this._arguments = new(StringComparer.OrdinalIgnoreCase); } /// @@ -37,7 +35,7 @@ public KernelArguments() /// /// The prompt execution settings. public KernelArguments(PromptExecutionSettings? executionSettings) - : this(executionSettings is null ? null : [executionSettings]) + : this(executionSettings: executionSettings is null ? null : [executionSettings]) { } @@ -46,8 +44,8 @@ public KernelArguments(PromptExecutionSettings? executionSettings) /// /// The prompt execution settings. public KernelArguments(IEnumerable? executionSettings) + : base(StringComparer.OrdinalIgnoreCase) { - this._arguments = new(StringComparer.OrdinalIgnoreCase); if (executionSettings is not null) { var newExecutionSettings = new Dictionary(); @@ -80,10 +78,8 @@ public KernelArguments(IEnumerable? executionSettings) /// Otherwise, if the source is a , its are used. /// public KernelArguments(IDictionary source, Dictionary? executionSettings = null) + : base(source, StringComparer.OrdinalIgnoreCase) { - Verify.NotNull(source); - - this._arguments = new(source, StringComparer.OrdinalIgnoreCase); this.ExecutionSettings = executionSettings ?? (source as KernelArguments)?.ExecutionSettings; } @@ -115,37 +111,6 @@ public IReadOnlyDictionary? ExecutionSettings } } - /// - /// Gets the number of arguments contained in the . - /// - public int Count => this._arguments.Count; - - /// Adds the specified argument name and value to the . - /// The name of the argument to add. - /// The value of the argument to add. - /// is null. - /// An argument with the same name already exists in the . - public void Add(string name, object? value) - { - Verify.NotNull(name); - this._arguments.Add(name, value); - } - - /// Removes the argument value with the specified name from the . - /// The name of the argument value to remove. - /// is null. - public bool Remove(string name) - { - Verify.NotNull(name); - return this._arguments.Remove(name); - } - - /// Removes all arguments names and values from the . - /// - /// This does not affect the property. To clear it as well, set it to null. - /// - public void Clear() => this._arguments.Clear(); - /// Determines whether the contains an argument with the specified name. /// The name of the argument to locate. /// true if the arguments contains an argument with the specified named; otherwise, false. @@ -153,103 +118,9 @@ public bool Remove(string name) public bool ContainsName(string name) { Verify.NotNull(name); - return this._arguments.ContainsKey(name); - } - - /// Gets the value associated with the specified argument name. - /// The name of the argument value to get. - /// - /// When this method returns, contains the value associated with the specified name, - /// if the name is found; otherwise, null. - /// - /// true if the arguments contains an argument with the specified name; otherwise, false. - /// is null. - public bool TryGetValue(string name, out object? value) - { - Verify.NotNull(name); - return this._arguments.TryGetValue(name, out value); - } - - /// Gets or sets the value associated with the specified argument name. - /// The name of the argument value to get or set. - /// is null. - public object? this[string name] - { - get - { - Verify.NotNull(name); - return this._arguments[name]; - } - set - { - Verify.NotNull(name); - this._arguments[name] = value; - } - } - - /// Gets an of all of the arguments' names. - public ICollection Names => this._arguments.Keys; - - /// Gets an of all of the arguments' values. - public ICollection Values => this._arguments.Values; - - #region Interface implementations - /// - ICollection IDictionary.Keys => this._arguments.Keys; - - /// - IEnumerable IReadOnlyDictionary.Keys => this._arguments.Keys; - - /// - IEnumerable IReadOnlyDictionary.Values => this._arguments.Values; - - /// - bool ICollection>.IsReadOnly => false; - - /// - object? IReadOnlyDictionary.this[string key] => this._arguments[key]; - - /// - object? IDictionary.this[string key] - { - get => this._arguments[key]; - set => this._arguments[key] = value; + return base.ContainsKey(name); } - /// - void IDictionary.Add(string key, object? value) => this._arguments.Add(key, value); - - /// - bool IDictionary.ContainsKey(string key) => this._arguments.ContainsKey(key); - - /// - bool IDictionary.Remove(string key) => this._arguments.Remove(key); - - /// - bool IDictionary.TryGetValue(string key, out object? value) => this._arguments.TryGetValue(key, out value); - - /// - void ICollection>.Add(KeyValuePair item) => this._arguments.Add(item.Key, item.Value); - - /// - bool ICollection>.Contains(KeyValuePair item) => ((ICollection>)this._arguments).Contains(item); - - /// - void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) => ((ICollection>)this._arguments).CopyTo(array, arrayIndex); - - /// - bool ICollection>.Remove(KeyValuePair item) => this._arguments.Remove(item.Key); - - /// - IEnumerator> IEnumerable>.GetEnumerator() => this._arguments.GetEnumerator(); - - /// - IEnumerator IEnumerable.GetEnumerator() => this._arguments.GetEnumerator(); - - /// - bool IReadOnlyDictionary.ContainsKey(string key) => this._arguments.ContainsKey(key); - - /// - bool IReadOnlyDictionary.TryGetValue(string key, out object? value) => this._arguments.TryGetValue(key, out value); - #endregion + /// Gets an of all of the arguments names. + public ICollection Names => this.Keys; } diff --git a/dotnet/src/SemanticKernel.UnitTests/Filters/AutoFunctionInvocation/AutoFunctionInvocationContextTests.cs b/dotnet/src/SemanticKernel.UnitTests/Filters/AutoFunctionInvocation/AutoFunctionInvocationContextTests.cs new file mode 100644 index 000000000000..b3971d6472ca --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Filters/AutoFunctionInvocation/AutoFunctionInvocationContextTests.cs @@ -0,0 +1,405 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Xunit; + +namespace SemanticKernel.UnitTests.Filters.AutoFunctionInvocation; + +public class AutoFunctionInvocationContextTests +{ + [Fact] + public void ConstructorWithValidParametersCreatesInstance() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Act + var context = new AutoFunctionInvocationContext( + kernel, + function, + result, + chatHistory, + chatMessageContent); + + // Assert + Assert.NotNull(context); + Assert.Same(kernel, context.Kernel); + Assert.Same(function, context.Function); + Assert.Same(result, context.Result); + Assert.Same(chatHistory, context.ChatHistory); + Assert.Same(chatMessageContent, context.ChatMessageContent); + } + + [Fact] + public void ConstructorWithNullKernelThrowsException() + { + // Arrange + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Act & Assert + Assert.Throws(() => new AutoFunctionInvocationContext( + null!, + function, + result, + chatHistory, + chatMessageContent)); + } + + [Fact] + public void ConstructorWithNullFunctionThrowsException() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Act & Assert + Assert.Throws(() => new AutoFunctionInvocationContext( + kernel, + null!, + result, + chatHistory, + chatMessageContent)); + } + + [Fact] + public void ConstructorWithNullResultThrowsException() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Act & Assert + Assert.Throws(() => new AutoFunctionInvocationContext( + kernel, + function, + null!, + chatHistory, + chatMessageContent)); + } + + [Fact] + public void ConstructorWithNullChatHistoryThrowsException() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Act & Assert + Assert.Throws(() => new AutoFunctionInvocationContext( + kernel, + function, + result, + null!, + chatMessageContent)); + } + + [Fact] + public void ConstructorWithNullChatMessageContentThrowsException() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + + // Act & Assert + Assert.Throws(() => new AutoFunctionInvocationContext( + kernel, + function, + result, + chatHistory, + null!)); + } + + [Fact] + public void PropertiesReturnCorrectValues() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Act + var context = new AutoFunctionInvocationContext( + kernel, + function, + result, + chatHistory, + chatMessageContent); + + // Assert + Assert.Same(kernel, context.Kernel); + Assert.Same(function, context.Function); + Assert.Same(result, context.Result); + Assert.Same(chatHistory, context.ChatHistory); + Assert.Same(chatMessageContent, context.ChatMessageContent); + } + + [Fact] + public async Task AutoFunctionInvocationContextCanBeUsedInFilter() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + var context = new AutoFunctionInvocationContext( + kernel, + function, + result, + chatHistory, + chatMessageContent); + + bool filterWasCalled = false; + + // Create a simple filter that just sets a flag + async Task FilterMethod(AutoFunctionInvocationContext ctx, Func next) + { + filterWasCalled = true; + Assert.Same(context, ctx); + await next(ctx); + } + + // Act + await FilterMethod(context, _ => Task.CompletedTask); + + // Assert + Assert.True(filterWasCalled); + } + + [Fact] + public void ExecutionSettingsCanBeSetAndRetrieved() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + var executionSettings = new PromptExecutionSettings(); + + // Create options with execution settings + var additionalProperties = new AdditionalPropertiesDictionary + { + [ChatOptionsExtensions.PromptExecutionSettingsKey] = executionSettings, + [ChatOptionsExtensions.KernelKey] = kernel, + [ChatOptionsExtensions.ChatMessageContentKey] = chatMessageContent + }; + + var options = new ChatOptions + { + AdditionalProperties = additionalProperties + }; + + // Act + var context = new AutoFunctionInvocationContext(options); + + // Assert + Assert.Same(executionSettings, context.ExecutionSettings); + } + + [Fact] + public async Task KernelFunctionCloneWithKernelUsesProvidedKernel() + { + // Arrange + var originalKernel = new Kernel(); + var newKernel = new Kernel(); + + // Create a function that returns the kernel's hash code + var function = KernelFunctionFactory.CreateFromMethod( + (Kernel k) => k.GetHashCode().ToString(), + "GetKernelHashCode"); + + // Act + // Create AIFunctions with different kernels + var aiFunction1 = function.AsAIFunction(originalKernel); + var aiFunction2 = function.AsAIFunction(newKernel); + + // Invoke both functions + var args = new AIFunctionArguments(); + var result1 = await aiFunction1.InvokeAsync(args, default); + var result2 = await aiFunction2.InvokeAsync(args, default); + + // Assert + // The results should be different because they use different kernels + Assert.NotNull(result1); + Assert.NotNull(result2); + Assert.NotEqual(result1, result2); + Assert.Equal(originalKernel.GetHashCode().ToString(), result1.ToString()); + Assert.Equal(newKernel.GetHashCode().ToString(), result2.ToString()); + } + + // Let's simplify our approach and use a different testing strategy + [Fact] + public void ArgumentsPropertyHandlesKernelArguments() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Create KernelArguments and set them via the init property + var kernelArgs = new KernelArguments { ["test"] = "value" }; + + // Set the arguments via the init property + var contextWithArgs = new AutoFunctionInvocationContext( + kernel, + function, + result, + chatHistory, + chatMessageContent) + { + Arguments = kernelArgs + }; + + // Act & Assert + Assert.Same(kernelArgs, contextWithArgs.Arguments); + } + + [Fact] + public void ArgumentsPropertyInitializesEmptyKernelArgumentsWhenSetToNull() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Set the arguments to null via the init property + var contextWithNullArgs = new AutoFunctionInvocationContext( + kernel, + function, + result, + chatHistory, + chatMessageContent) + { + Arguments = null + }; + + // Act & Assert + Assert.NotNull(contextWithNullArgs.Arguments); + Assert.IsType(contextWithNullArgs.Arguments); + Assert.Empty(contextWithNullArgs.Arguments); + } + + [Fact] + public void ArgumentsPropertyCanBeSetWithMultipleValues() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Create KernelArguments with multiple values + var kernelArgs = new KernelArguments + { + ["string"] = "value", + ["int"] = 42, + ["bool"] = true, + ["object"] = new object() + }; + + // Set the arguments via the init property + var context = new AutoFunctionInvocationContext( + kernel, + function, + result, + chatHistory, + chatMessageContent) + { + Arguments = kernelArgs + }; + + // Act & Assert + Assert.Same(kernelArgs, context.Arguments); + Assert.Equal(4, context.Arguments.Count); + Assert.Equal("value", context.Arguments["string"]); + Assert.Equal(42, context.Arguments["int"]); + Assert.Equal(true, context.Arguments["bool"]); + Assert.NotNull(context.Arguments["object"]); + } + + [Fact] + public void ArgumentsPropertyCanBeSetWithExecutionSettings() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + var executionSettings = new PromptExecutionSettings(); + + // Create KernelArguments with execution settings + var kernelArgs = new KernelArguments(executionSettings) + { + ["test"] = "value" + }; + + // Set the arguments via the init property + var context = new AutoFunctionInvocationContext( + kernel, + function, + result, + chatHistory, + chatMessageContent) + { + Arguments = kernelArgs + }; + + // Act & Assert + Assert.Same(kernelArgs, context.Arguments); + Assert.Equal("value", context.Arguments["test"]); + Assert.Same(executionSettings, context.Arguments.ExecutionSettings?[PromptExecutionSettings.DefaultServiceId]); + } + + [Fact] + public void ArgumentsPropertyThrowsWhenBaseArgumentsIsNotKernelArguments() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + var context = new AutoFunctionInvocationContext( + kernel, + function, + result, + chatHistory, + chatMessageContent); + + ((Microsoft.Extensions.AI.FunctionInvocationContext)context).Arguments = new AIFunctionArguments(); + + // Act & Assert + Assert.Throws(() => context.Arguments); + } +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionCloneTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionCloneTests.cs new file mode 100644 index 000000000000..c6a0bc8dcb97 --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionCloneTests.cs @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using Xunit; + +namespace SemanticKernel.UnitTests.Functions; + +public class KernelFunctionCloneTests +{ + [Fact] + public async Task ClonedKernelFunctionUsesProvidedKernelWhenInvokingAsAIFunction() + { + // Arrange + var originalKernel = new Kernel(); + var newKernel = new Kernel(); + + // Create a function that returns the kernel's hash code + var function = KernelFunctionFactory.CreateFromMethod( + (Kernel k) => k.GetHashCode().ToString(), + "GetKernelHashCode"); + + // Create an AIFunction from the KernelFunction with the original kernel + var aiFunction = function.AsAIFunction(originalKernel); + + // Act + // Clone the function and create a new AIFunction with the new kernel + var clonedFunction = function.Clone("TestPlugin"); + var clonedAIFunction = clonedFunction.AsAIFunction(newKernel); + + // Invoke both functions + var originalResult = await aiFunction.InvokeAsync(new AIFunctionArguments(), default); + var clonedResult = await clonedAIFunction.InvokeAsync(new AIFunctionArguments(), default); + + // Assert + // The results should be different because they use different kernels + Assert.NotNull(originalResult); + Assert.NotNull(clonedResult); + Assert.NotEqual(originalResult, clonedResult); + Assert.Equal(originalKernel.GetHashCode().ToString(), originalResult.ToString()); + Assert.Equal(newKernel.GetHashCode().ToString(), clonedResult.ToString()); + } + + [Fact] + public async Task KernelAIFunctionUsesProvidedKernelWhenInvoking() + { + // Arrange + var kernel1 = new Kernel(); + var kernel2 = new Kernel(); + + // Create a function that returns the kernel's hash code + var function = KernelFunctionFactory.CreateFromMethod( + (Kernel k) => k.GetHashCode().ToString(), + "GetKernelHashCode"); + + // Act + // Create AIFunctions with different kernels + var aiFunction1 = function.AsAIFunction(kernel1); + var aiFunction2 = function.AsAIFunction(kernel2); + + // Invoke both functions + var result1 = await aiFunction1.InvokeAsync(new AIFunctionArguments(), default); + var result2 = await aiFunction2.InvokeAsync(new AIFunctionArguments(), default); + + // Assert + // The results should be different because they use different kernels + Assert.NotNull(result1); + Assert.NotNull(result2); + Assert.NotEqual(result1, result2); + Assert.Equal(kernel1.GetHashCode().ToString(), result1.ToString()); + Assert.Equal(kernel2.GetHashCode().ToString(), result2.ToString()); + } + + [Fact] + public void AsAIFunctionStoresKernelForLaterUse() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + + // Act + var aiFunction = function.AsAIFunction(kernel); + + // Assert + // We can't directly access the private _kernel field, but we can verify it's used + // by checking that the AIFunction has the correct name format + Assert.Equal("TestFunction", aiFunction.Name); + } + + [Fact] + public void ClonePreservesMetadataButChangesPluginName() + { + // Arrange + var function = KernelFunctionFactory.CreateFromMethod( + () => "Test", + "TestFunction", + "Test description"); + + // Act + var clonedFunction = function.Clone("NewPlugin"); + + // Assert + Assert.Equal("TestFunction", clonedFunction.Name); + Assert.Equal("NewPlugin", clonedFunction.PluginName); + Assert.Equal("Test description", clonedFunction.Description); + } +}