diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs similarity index 82% rename from src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs rename to src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs index 4878239f35b..3f090a2ac3b 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs @@ -2,11 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Buffers; using System.Collections.Concurrent; using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.IO; #if !NET using System.Linq; #endif @@ -15,23 +17,26 @@ using System.Text.Json; using System.Text.Json.Nodes; using System.Text.Json.Serialization.Metadata; +using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; -using Microsoft.Extensions.DependencyInjection; using Microsoft.Shared.Collections; using Microsoft.Shared.Diagnostics; #pragma warning disable CA1031 // Do not catch general exception types +#pragma warning disable S2333 // Redundant modifiers should not be used #pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields -#pragma warning disable SA1118 // Parameter should not span multiple lines -#pragma warning disable SA1500 // Braces for multi-line statements should not share line namespace Microsoft.Extensions.AI; -/// Provides factory methods for creating commonly used implementations of . +/// Provides factory methods for creating commonly-used implementations of . /// Invoke .NET functions using an AI model. public static partial class AIFunctionFactory { + // NOTE: + // Unlike most library code, AIFunctionFactory uses ConfigureAwait(true) rather than ConfigureAwait(false). This is to + // enable AIFunctionFactory to be used with methods that might be context-aware, such as those employing a UI framework. + /// Holds the default options instance used when creating function. private static readonly AIFunctionFactoryOptions _defaultOptions = new(); @@ -71,25 +76,6 @@ public static partial class AIFunctionFactory /// . /// /// - /// - /// - /// By default, parameters attributed with are resolved from the - /// property and are not included in the JSON schema. If the parameter is optional, such that a default value is provided, - /// is allowed to be ; otherwise, - /// must be non-, or else the invocation will fail with an exception due to the required nature of the parameter. - /// The handling of such parameters may be overridden via . - /// - /// - /// - /// - /// When the is constructed, it may be passed an via - /// . Any parameter that can be satisfied by that - /// according to will not be included in the generated JSON schema and will be resolved - /// from the provided to via , - /// rather than from the argument collection. The handling of such parameters may be overridden via - /// . - /// - /// /// /// All other parameter types are, by default, bound from the dictionary passed into /// and are included in the generated JSON schema. This may be overridden by the provided @@ -170,23 +156,6 @@ public static AIFunction Create(Delegate method, AIFunctionFactoryOptions? optio /// optional or not. /// /// - /// - /// - /// By default, parameters attributed with are resolved from the - /// property and are not included in the JSON schema. If the parameter is optional, such that a default value is provided, - /// is allowed to be ; otherwise, - /// must be non-, or else the invocation will fail with an exception due to the required nature of the parameter. - /// - /// - /// - /// - /// When the is constructed, it may be passed an via - /// . Any parameter that can be satisfied by that - /// according to will not be included in the generated JSON schema and will be resolved - /// from the provided to via , - /// rather than from the argument collection. - /// - /// /// /// All other parameter types are bound from the dictionary passed into /// and are included in the generated JSON schema. @@ -270,25 +239,6 @@ public static AIFunction Create(Delegate method, string? name = null, string? de /// . /// /// - /// - /// - /// By default, parameters attributed with are resolved from the - /// property and are not included in the JSON schema. If the parameter is optional, such that a default value is provided, - /// is allowed to be ; otherwise, - /// must be non-, or else the invocation will fail with an exception due to the required nature of the parameter. - /// The handling of such parameters may be overridden via . - /// - /// - /// - /// - /// When the is constructed, it may be passed an via - /// . Any parameter that can be satisfied by that - /// according to will not be included in the generated JSON schema and will be resolved - /// from the provided to via , - /// rather than from the argument collection. The handling of such parameters may be overridden via - /// . - /// - /// /// /// All other parameter types are, by default, bound from the dictionary passed into /// and are included in the generated JSON schema. This may be overridden by the provided @@ -379,23 +329,6 @@ public static AIFunction Create(MethodInfo method, object? target, AIFunctionFac /// optional or not. /// /// - /// - /// - /// By default, parameters attributed with are resolved from the - /// property and are not included in the JSON schema. If the parameter is optional, such that a default value is provided, - /// is allowed to be ; otherwise, - /// must be non-, or else the invocation will fail with an exception due to the required nature of the parameter. - /// - /// - /// - /// When the is constructed, it may be passed an via - /// . Any parameter that can be satisfied by that - /// according to will not be included in the generated JSON schema and will be resolved - /// from the provided to via , - /// rather than from the argument collection. - /// - /// - /// /// /// All other parameter types are bound from the dictionary passed into /// and are included in the generated JSON schema. @@ -447,10 +380,9 @@ public static AIFunction Create(MethodInfo method, object? target, string? name /// The instance method to be represented via the created . /// /// The to construct an instance of on which to invoke when - /// the resulting is invoked. If is provided, - /// will be used to construct the instance using those services; otherwise, - /// is used, utilizing the type's public parameterless constructor. - /// If an instance can't be constructed, an exception is thrown during the function's invocation. + /// the resulting is invoked. is used, + /// utilizing the type's public parameterless constructor. If an instance can't be constructed, an exception is + /// thrown during the function's invocation. /// /// Metadata to use to override defaults inferred from . /// The created for invoking . @@ -494,25 +426,6 @@ public static AIFunction Create(MethodInfo method, object? target, string? name /// . /// /// - /// - /// - /// By default, parameters attributed with are resolved from the - /// property and are not included in the JSON schema. If the parameter is optional, such that a default value is provided, - /// is allowed to be ; otherwise, - /// must be non-, or else the invocation will fail with an exception due to the required nature of the parameter. - /// The handling of such parameters may be overridden via . - /// - /// - /// - /// - /// When the is constructed, it may be passed an via - /// . Any parameter that can be satisfied by that - /// according to will not be included in the generated JSON schema and will be resolved - /// from the provided to via , - /// rather than from the argument collection. The handling of such parameters may be overridden via - /// . - /// - /// /// /// All other parameter types are, by default, bound from the dictionary passed into /// and are included in the generated JSON schema. This may be overridden by the provided @@ -627,6 +540,7 @@ private ReflectionAIFunction( { FunctionDescriptor = functionDescriptor; TargetType = targetType; + CreateInstance = options.CreateInstance; AdditionalProperties = options.AdditionalProperties ?? EmptyReadOnlyDictionary.Instance; } @@ -634,6 +548,8 @@ private ReflectionAIFunction( public object? Target { get; } [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] public Type? TargetType { get; } + public Func? CreateInstance { get; } + public override IReadOnlyDictionary AdditionalProperties { get; } public override string Name => FunctionDescriptor.Name; public override string Description => FunctionDescriptor.Description; @@ -654,9 +570,14 @@ private ReflectionAIFunction( Debug.Assert(target is null, "Expected target to be null when we have a non-null target type"); Debug.Assert(!FunctionDescriptor.Method.IsStatic, "Expected an instance method"); - target = arguments.Services is { } services ? - ActivatorUtilities.CreateInstance(services, targetType!) : + target = CreateInstance is not null ? + CreateInstance(targetType, arguments) : Activator.CreateInstance(targetType); + if (target is null) + { + Throw.InvalidOperationException("Unable to create an instance of the target type."); + } + disposeTarget = true; } @@ -669,7 +590,7 @@ private ReflectionAIFunction( } return await FunctionDescriptor.ReturnParameterMarshaller( - ReflectionInvoke(FunctionDescriptor.Method, target, args), cancellationToken); + ReflectionInvoke(FunctionDescriptor.Method, target, args), cancellationToken).ConfigureAwait(true); } finally { @@ -677,7 +598,7 @@ private ReflectionAIFunction( { if (target is IAsyncDisposable ad) { - await ad.DisposeAsync(); + await ad.DisposeAsync().ConfigureAwait(true); } else if (target is IDisposable d) { @@ -709,7 +630,7 @@ public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFu serializerOptions.MakeReadOnly(); ConcurrentDictionary innerCache = _descriptorCache.GetOrCreateValue(serializerOptions); - DescriptorKey key = new(method, options.Name, options.Description, options.ConfigureParameterBinding, options.MarshalResult, options.Services, schemaOptions); + DescriptorKey key = new(method, options.Name, options.Description, options.ConfigureParameterBinding, options.MarshalResult, schemaOptions); if (innerCache.TryGetValue(key, out ReflectionAIFunctionDescriptor? descriptor)) { return descriptor; @@ -736,8 +657,6 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions } } - IServiceProviderIsService? serviceProviderIsService = key.Services?.GetService(); - // Use that binding information to impact the schema generation. AIJsonSchemaCreateOptions schemaOptions = key.SchemaOptions with { @@ -757,21 +676,6 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions return false; } - // If the parameter is attributed as [FromKeyedServices], exclude it, as we'll instead - // get its value from the IServiceProvider. - if (parameterInfo.GetCustomAttribute(inherit: true) is not null) - { - return false; - } - - // We assume that if the services used to create the function support a particular type, - // so too do the services that will be passed into InvokeAsync. This is the same basic assumption - // made in ASP.NET. - if (serviceProviderIsService?.IsService(parameterInfo.ParameterType) is true) - { - return false; - } - // If there was an existing IncludeParameter delegate, now defer to it as we've // excluded everything we need to exclude. if (key.SchemaOptions.IncludeParameter is { } existingIncludeParameter) @@ -793,7 +697,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions options = default; } - ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, options, parameters[i], serviceProviderIsService); + ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, options, parameters[i]); } // Get a marshaling delegate for the return value. @@ -863,8 +767,7 @@ static bool IsAsyncMethod(MethodInfo method) private static Func GetParameterMarshaller( JsonSerializerOptions serializerOptions, AIFunctionFactoryOptions.ParameterBindingOptions bindingOptions, - ParameterInfo parameter, - IServiceProviderIsService? serviceProviderIsService) + ParameterInfo parameter) { if (string.IsNullOrWhiteSpace(parameter.Name)) { @@ -911,56 +814,6 @@ static bool IsAsyncMethod(MethodInfo method) }; } - // For [FromKeyedServices] parameters, we resolve from the services passed to InvokeAsync via AIFunctionArguments. - if (parameter.GetCustomAttribute(inherit: true) is { } keyedAttr) - { - return (arguments, _) => - { - if ((arguments.Services as IKeyedServiceProvider)?.GetKeyedService(parameterType, keyedAttr.Key) is { } service) - { - return service; - } - - if (!parameter.HasDefaultValue) - { - if (arguments.Services is null) - { - ThrowNullServices(parameter.Name); - } - - Throw.ArgumentException(nameof(arguments), $"No service of type '{parameterType}' with key '{keyedAttr.Key}' was found for parameter '{parameter.Name}'."); - } - - return parameter.DefaultValue; - }; - } - - // For any parameters that are satisfiable from the IServiceProvider, we resolve from the services passed to InvokeAsync - // via AIFunctionArguments. This is determined by the same same IServiceProviderIsService instance used to determine whether - // the parameter should be included in the schema. - if (serviceProviderIsService?.IsService(parameterType) is true) - { - return (arguments, _) => - { - if (arguments.Services?.GetService(parameterType) is { } service) - { - return service; - } - - if (!parameter.HasDefaultValue) - { - if (arguments.Services is null) - { - ThrowNullServices(parameter.Name); - } - - Throw.ArgumentException(nameof(arguments), $"No service of type '{parameterType}' was found for parameter '{parameter.Name}'."); - } - - return parameter.DefaultValue; - }; - } - // For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary. // Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found. JsonTypeInfo? typeInfo = serializerOptions.GetTypeInfo(parameterType); @@ -1037,14 +890,14 @@ static void ThrowNullServices(string parameterName) => { return async (result, cancellationToken) => { - await ((Task)ThrowIfNullResult(result)); - return await marshalResult(null, null, cancellationToken); + await ((Task)ThrowIfNullResult(result)).ConfigureAwait(true); + return await marshalResult(null, null, cancellationToken).ConfigureAwait(true); }; } return async static (result, _) => { - await ((Task)ThrowIfNullResult(result)); + await ((Task)ThrowIfNullResult(result)).ConfigureAwait(true); return null; }; } @@ -1056,14 +909,14 @@ static void ThrowNullServices(string parameterName) => { return async (result, cancellationToken) => { - await ((ValueTask)ThrowIfNullResult(result)); - return await marshalResult(null, null, cancellationToken); + await ((ValueTask)ThrowIfNullResult(result)).ConfigureAwait(true); + return await marshalResult(null, null, cancellationToken).ConfigureAwait(true); }; } return async static (result, _) => { - await ((ValueTask)ThrowIfNullResult(result)); + await ((ValueTask)ThrowIfNullResult(result)).ConfigureAwait(true); return null; }; } @@ -1078,18 +931,18 @@ static void ThrowNullServices(string parameterName) => { return async (taskObj, cancellationToken) => { - await ((Task)ThrowIfNullResult(taskObj)); + await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(true); object? result = ReflectionInvoke(taskResultGetter, taskObj, null); - return await marshalResult(result, taskResultGetter.ReturnType, cancellationToken); + return await marshalResult(result, taskResultGetter.ReturnType, cancellationToken).ConfigureAwait(true); }; } returnTypeInfo = serializerOptions.GetTypeInfo(taskResultGetter.ReturnType); return async (taskObj, cancellationToken) => { - await ((Task)ThrowIfNullResult(taskObj)); + await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(true); object? result = ReflectionInvoke(taskResultGetter, taskObj, null); - return await SerializeResultAsync(result, returnTypeInfo, cancellationToken); + return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(true); }; } @@ -1104,9 +957,9 @@ static void ThrowNullServices(string parameterName) => return async (taskObj, cancellationToken) => { var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!; - await task; + await task.ConfigureAwait(true); object? result = ReflectionInvoke(asTaskResultGetter, task, null); - return await marshalResult(result, asTaskResultGetter.ReturnType, cancellationToken); + return await marshalResult(result, asTaskResultGetter.ReturnType, cancellationToken).ConfigureAwait(true); }; } @@ -1114,9 +967,9 @@ static void ThrowNullServices(string parameterName) => return async (taskObj, cancellationToken) => { var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!; - await task; + await task.ConfigureAwait(true); object? result = ReflectionInvoke(asTaskResultGetter, task, null); - return await SerializeResultAsync(result, returnTypeInfo, cancellationToken); + return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(true); }; } } @@ -1140,7 +993,7 @@ static void ThrowNullServices(string parameterName) => // Serialize asynchronously to support potential IAsyncEnumerable responses. using PooledMemoryStream stream = new(); - await JsonSerializer.SerializeAsync(stream, result, returnTypeInfo, cancellationToken); + await JsonSerializer.SerializeAsync(stream, result, returnTypeInfo, cancellationToken).ConfigureAwait(true); Utf8JsonReader reader = new(stream.GetBuffer()); return JsonElement.ParseValue(ref reader); } @@ -1169,7 +1022,126 @@ private record struct DescriptorKey( string? Description, Func? GetBindParameterOptions, Func>? MarshalResult, - IServiceProvider? Services, AIJsonSchemaCreateOptions SchemaOptions); } + + /// + /// Removes characters from a .NET member name that shouldn't be used in an AI function name. + /// + /// The .NET member name that should be sanitized. + /// + /// Replaces non-alphanumeric characters in the identifier with the underscore character. + /// Primarily intended to remove characters produced by compiler-generated method name mangling. + /// + private static string SanitizeMemberName(string memberName) => + InvalidNameCharsRegex().Replace(memberName, "_"); + + /// Regex that flags any character other than ASCII digits or letters or the underscore. +#if NET + [GeneratedRegex("[^0-9A-Za-z_]")] + private static partial Regex InvalidNameCharsRegex(); +#else + private static Regex InvalidNameCharsRegex() => _invalidNameCharsRegex; + private static readonly Regex _invalidNameCharsRegex = new("[^0-9A-Za-z_]", RegexOptions.Compiled); +#endif + + /// Invokes the MethodInfo with the specified target object and arguments. + private static object? ReflectionInvoke(MethodInfo method, object? target, object?[]? arguments) + { +#if NET + return method.Invoke(target, BindingFlags.DoNotWrapExceptions, binder: null, arguments, culture: null); +#else + try + { + return method.Invoke(target, BindingFlags.Default, binder: null, arguments, culture: null); + } + catch (TargetInvocationException e) when (e.InnerException is not null) + { + // If we're targeting .NET Framework, such that BindingFlags.DoNotWrapExceptions + // is ignored, the original exception will be wrapped in a TargetInvocationException. + // Unwrap it and throw that original exception, maintaining its stack information. + System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(e.InnerException).Throw(); + throw; + } +#endif + } + + /// + /// Implements a simple write-only memory stream that uses pooled buffers. + /// + private sealed class PooledMemoryStream : Stream + { + private const int DefaultBufferSize = 4096; + private byte[] _buffer; + private int _position; + + public PooledMemoryStream(int initialCapacity = DefaultBufferSize) + { + _buffer = ArrayPool.Shared.Rent(initialCapacity); + _position = 0; + } + + public ReadOnlySpan GetBuffer() => _buffer.AsSpan(0, _position); + public override bool CanWrite => true; + public override bool CanRead => false; + public override bool CanSeek => false; + public override long Length => _position; + public override long Position + { + get => _position; + set => throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + EnsureNotDisposed(); + EnsureCapacity(_position + count); + + Buffer.BlockCopy(buffer, offset, _buffer, _position, count); + _position += count; + } + + public override void Flush() + { + } + + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + + protected override void Dispose(bool disposing) + { + if (_buffer is not null) + { + ArrayPool.Shared.Return(_buffer); + _buffer = null!; + } + + base.Dispose(disposing); + } + + private void EnsureCapacity(int requiredCapacity) + { + if (requiredCapacity <= _buffer.Length) + { + return; + } + + int newCapacity = Math.Max(requiredCapacity, _buffer.Length * 2); + byte[] newBuffer = ArrayPool.Shared.Rent(newCapacity); + Buffer.BlockCopy(_buffer, 0, newBuffer, 0, _position); + + ArrayPool.Shared.Return(_buffer); + _buffer = newBuffer; + } + + private void EnsureNotDisposed() + { + if (_buffer is null) + { + Throw(); + static void Throw() => throw new ObjectDisposedException(nameof(PooledMemoryStream)); + } + } + } } diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactoryOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactoryOptions.cs similarity index 87% rename from src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactoryOptions.cs rename to src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactoryOptions.cs index 80c8b485c59..80ff394359d 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactoryOptions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactoryOptions.cs @@ -107,14 +107,22 @@ public AIFunctionFactoryOptions() public Func>? MarshalResult { get; set; } /// - /// Gets or sets optional services used in the construction of the . + /// Gets or sets a delegate used with to create the receiver instance. /// /// - /// These services will be used to determine which parameters should be satisifed from dependency injection. As such, - /// what services are satisfied via this provider should match what's satisfied via the provider passed into - /// via . + /// + /// creates instances that invoke an + /// instance method on the specified . This delegate is used to create the instance of the type that will be used to invoke the method. + /// By default if is , is used. If + /// is non-, the delegate is invoked with the to be instantiated and the + /// provided to the method. + /// + /// + /// Each created instance will be used for a single invocation. If the object is or , it will + /// be disposed of after the invocation completes. + /// /// - public IServiceProvider? Services { get; set; } + public Func? CreateInstance { get; set; } /// Provides configuration options produced by the delegate. public readonly record struct ParameterBindingOptions diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs index d7bc12a1a41..e35f8b87949 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs @@ -7,11 +7,13 @@ using System.Reflection; using System.Text.Json; using System.Text.Json.Nodes; +using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; using Microsoft.Shared.Diagnostics; #pragma warning disable SA1118 // Parameter should not span multiple lines +#pragma warning disable S2333 // Redundant modifiers should not be used namespace Microsoft.Extensions.AI; @@ -19,7 +21,7 @@ namespace Microsoft.Extensions.AI; /// Provides extension methods on that simplify working with structured output. /// /// Request a response with structured output. -public static class ChatClientStructuredOutputExtensions +public static partial class ChatClientStructuredOutputExtensions { private static readonly AIJsonSchemaCreateOptions _inferenceOptions = new() { @@ -197,7 +199,7 @@ public static async Task> GetResponseAsync( // the LLM backend is meant to do whatever's needed to explain the schema to the LLM. options.ResponseFormat = ChatResponseFormat.ForJsonSchema( schema, - schemaName: AIFunctionFactory.SanitizeMemberName(typeof(T).Name), + schemaName: SanitizeMemberName(typeof(T).Name), schemaDescription: typeof(T).GetCustomAttribute()?.Description); } else @@ -246,4 +248,24 @@ private static bool SchemaRepresentsObject(JsonElement schemaElement) _ => JsonValue.Create(element) }; } + + /// + /// Removes characters from a .NET member name that shouldn't be used in an AI function name. + /// + /// The .NET member name that should be sanitized. + /// + /// Replaces non-alphanumeric characters in the identifier with the underscore character. + /// Primarily intended to remove characters produced by compiler-generated method name mangling. + /// + private static string SanitizeMemberName(string memberName) => + InvalidNameCharsRegex().Replace(memberName, "_"); + + /// Regex that flags any character other than ASCII digits or letters or the underscore. +#if NET + [GeneratedRegex("[^0-9A-Za-z_]")] + private static partial Regex InvalidNameCharsRegex(); +#else + private static Regex InvalidNameCharsRegex() => _invalidNameCharsRegex; + private static readonly Regex _invalidNameCharsRegex = new("[^0-9A-Za-z_]", RegexOptions.Compiled); +#endif } diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.Utilities.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.Utilities.cs deleted file mode 100644 index cbafe78e5d3..00000000000 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.Utilities.cs +++ /dev/null @@ -1,137 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System; -using System.Buffers; -using System.IO; -using System.Reflection; -using System.Text.RegularExpressions; -using Microsoft.Shared.Diagnostics; - -namespace Microsoft.Extensions.AI; - -public static partial class AIFunctionFactory -{ - /// - /// Removes characters from a .NET member name that shouldn't be used in an AI function name. - /// - /// The .NET member name that should be sanitized. - /// - /// Replaces non-alphanumeric characters in the identifier with the underscore character. - /// Primarily intended to remove characters produced by compiler-generated method name mangling. - /// - internal static string SanitizeMemberName(string memberName) - { - _ = Throw.IfNull(memberName); - return InvalidNameCharsRegex().Replace(memberName, "_"); - } - - /// Regex that flags any character other than ASCII digits or letters or the underscore. -#if NET - [GeneratedRegex("[^0-9A-Za-z_]")] - private static partial Regex InvalidNameCharsRegex(); -#else - private static Regex InvalidNameCharsRegex() => _invalidNameCharsRegex; - private static readonly Regex _invalidNameCharsRegex = new("[^0-9A-Za-z_]", RegexOptions.Compiled); -#endif - - /// Invokes the MethodInfo with the specified target object and arguments. - private static object? ReflectionInvoke(MethodInfo method, object? target, object?[]? arguments) - { -#if NET - return method.Invoke(target, BindingFlags.DoNotWrapExceptions, binder: null, arguments, culture: null); -#else - try - { - return method.Invoke(target, BindingFlags.Default, binder: null, arguments, culture: null); - } - catch (TargetInvocationException e) when (e.InnerException is not null) - { - // If we're targeting .NET Framework, such that BindingFlags.DoNotWrapExceptions - // is ignored, the original exception will be wrapped in a TargetInvocationException. - // Unwrap it and throw that original exception, maintaining its stack information. - System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(e.InnerException).Throw(); - throw; - } -#endif - } - - /// - /// Implements a simple write-only memory stream that uses pooled buffers. - /// - private sealed class PooledMemoryStream : Stream - { - private const int DefaultBufferSize = 4096; - private byte[] _buffer; - private int _position; - - public PooledMemoryStream(int initialCapacity = DefaultBufferSize) - { - _buffer = ArrayPool.Shared.Rent(initialCapacity); - _position = 0; - } - - public ReadOnlySpan GetBuffer() => _buffer.AsSpan(0, _position); - public override bool CanWrite => true; - public override bool CanRead => false; - public override bool CanSeek => false; - public override long Length => _position; - public override long Position - { - get => _position; - set => throw new NotSupportedException(); - } - - public override void Write(byte[] buffer, int offset, int count) - { - EnsureNotDisposed(); - EnsureCapacity(_position + count); - - Buffer.BlockCopy(buffer, offset, _buffer, _position, count); - _position += count; - } - - public override void Flush() - { - } - - public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); - public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); - public override void SetLength(long value) => throw new NotSupportedException(); - - protected override void Dispose(bool disposing) - { - if (_buffer is not null) - { - ArrayPool.Shared.Return(_buffer); - _buffer = null!; - } - - base.Dispose(disposing); - } - - private void EnsureCapacity(int requiredCapacity) - { - if (requiredCapacity <= _buffer.Length) - { - return; - } - - int newCapacity = Math.Max(requiredCapacity, _buffer.Length * 2); - byte[] newBuffer = ArrayPool.Shared.Rent(newCapacity); - Buffer.BlockCopy(_buffer, 0, newBuffer, 0, _position); - - ArrayPool.Shared.Return(_buffer); - _buffer = newBuffer; - } - - private void EnsureNotDisposed() - { - if (_buffer is null) - { - Throw(); - static void Throw() => throw new ObjectDisposedException(nameof(PooledMemoryStream)); - } - } - } -} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs index 4b5ff9a0600..4f5037fc92d 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs @@ -299,55 +299,6 @@ public async Task AIFunctionArguments_MissingServicesMayBeOptional() Assert.Equal("", result?.ToString()); } - [Fact] - public async Task IServiceProvider_ServicesInOptionsImpactsFunctionCreation() - { - ServiceCollection sc = new(); - sc.AddSingleton(new MyService(123)); - IServiceProvider sp = sc.BuildServiceProvider(); - - AIFunction func; - - // Services not provided to Create, non-optional argument - if (JsonSerializer.IsReflectionEnabledByDefault) - { - func = AIFunctionFactory.Create((MyService myService) => myService.Value); - Assert.Contains("myService", func.JsonSchema.ToString()); - await Assert.ThrowsAsync("arguments", () => func.InvokeAsync(new()).AsTask()); - await Assert.ThrowsAsync("arguments", () => func.InvokeAsync(new() { Services = sp }).AsTask()); - } - else - { - Assert.Throws(() => AIFunctionFactory.Create((MyService myService) => myService.Value)); - } - - // Services not provided to Create, optional argument - if (JsonSerializer.IsReflectionEnabledByDefault) - { - func = AIFunctionFactory.Create((MyService? myService = null) => myService?.Value ?? 456); - Assert.Contains("myService", func.JsonSchema.ToString()); - Assert.Contains("456", (await func.InvokeAsync(new()))?.ToString()); - Assert.Contains("456", (await func.InvokeAsync(new() { Services = sp }))?.ToString()); - } - else - { - Assert.Throws(() => AIFunctionFactory.Create((MyService myService) => myService.Value)); - } - - // Services provided to Create, non-optional argument - func = AIFunctionFactory.Create((MyService myService) => myService.Value, new() { Services = sp }); - Assert.DoesNotContain("myService", func.JsonSchema.ToString()); - await Assert.ThrowsAsync("arguments.Services", () => func.InvokeAsync(new()).AsTask()); - await Assert.ThrowsAsync("arguments", () => func.InvokeAsync(new() { Services = new ServiceCollection().BuildServiceProvider() }).AsTask()); - Assert.Contains("123", (await func.InvokeAsync(new() { Services = sp }))?.ToString()); - - // Services provided to Create, optional argument - func = AIFunctionFactory.Create((MyService? myService = null) => myService?.Value ?? 456, new() { Services = sp }); - Assert.DoesNotContain("myService", func.JsonSchema.ToString()); - Assert.Contains("456", (await func.InvokeAsync(new()))?.ToString()); - Assert.Contains("123", (await func.InvokeAsync(new() { Services = sp }))?.ToString()); - } - [Fact] public async Task Create_NoInstance_UsesActivatorUtilitiesWhenServicesAvailable() { @@ -364,6 +315,11 @@ public async Task Create_NoInstance_UsesActivatorUtilitiesWhenServicesAvailable( typeof(MyFunctionTypeWithOneArg), new() { + CreateInstance = (type, arguments) => + { + Assert.NotNull(arguments.Services); + return ActivatorUtilities.CreateInstance(arguments.Services, type); + }, MarshalResult = (result, type, cancellationToken) => new ValueTask(result), }); @@ -398,7 +354,7 @@ public async Task Create_NoInstance_ThrowsWhenCantConstructInstance() typeof(MyFunctionTypeWithOneArg)); Assert.NotNull(func); - await Assert.ThrowsAsync(async () => await func.InvokeAsync(new() { Services = sp })); + await Assert.ThrowsAsync(async () => await func.InvokeAsync(new() { Services = sp })); } [Fact] @@ -485,13 +441,13 @@ public async Task FromKeyedServices_ResolvesFromServiceProvider() sc.AddKeyedSingleton("key", service); IServiceProvider sp = sc.BuildServiceProvider(); - AIFunction f = AIFunctionFactory.Create(([FromKeyedServices("key")] MyService service, int myInteger) => service.Value + myInteger); + AIFunction f = AIFunctionFactory.Create(([FromKeyedServices("key")] MyService service, int myInteger) => service.Value + myInteger, + CreateKeyedServicesSupportOptions()); Assert.Contains("myInteger", f.JsonSchema.ToString()); Assert.DoesNotContain("service", f.JsonSchema.ToString()); - Exception e = await Assert.ThrowsAsync("arguments.Services", () => f.InvokeAsync(new() { ["myInteger"] = 1 }).AsTask()); - Assert.Contains("Services are required", e.Message); + Exception e = await Assert.ThrowsAsync("arguments.Services", () => f.InvokeAsync(new() { ["myInteger"] = 1 }).AsTask()); var result = await f.InvokeAsync(new() { ["myInteger"] = 1, Services = sp }); Assert.Contains("43", result?.ToString()); @@ -506,13 +462,13 @@ public async Task FromKeyedServices_NullKeysBindToNonKeyedServices() sc.AddSingleton(service); IServiceProvider sp = sc.BuildServiceProvider(); - AIFunction f = AIFunctionFactory.Create(([FromKeyedServices(null!)] MyService service, int myInteger) => service.Value + myInteger); + AIFunction f = AIFunctionFactory.Create(([FromKeyedServices(null!)] MyService service, int myInteger) => service.Value + myInteger, + CreateKeyedServicesSupportOptions()); Assert.Contains("myInteger", f.JsonSchema.ToString()); Assert.DoesNotContain("service", f.JsonSchema.ToString()); - Exception e = await Assert.ThrowsAsync("arguments.Services", () => f.InvokeAsync(new() { ["myInteger"] = 1 }).AsTask()); - Assert.Contains("Services are required", e.Message); + Exception e = await Assert.ThrowsAsync("arguments.Services", () => f.InvokeAsync(new() { ["myInteger"] = 1 }).AsTask()); var result = await f.InvokeAsync(new() { ["myInteger"] = 1, Services = sp }); Assert.Contains("43", result?.ToString()); @@ -528,7 +484,8 @@ public async Task FromKeyedServices_OptionalDefaultsToNull() IServiceProvider sp = sc.BuildServiceProvider(); AIFunction f = AIFunctionFactory.Create(([FromKeyedServices("key")] MyService? service = null, int myInteger = 0) => - service is null ? "null " + 1 : (service.Value + myInteger).ToString()); + service is null ? "null " + 1 : (service.Value + myInteger).ToString(), + CreateKeyedServicesSupportOptions()); Assert.Contains("myInteger", f.JsonSchema.ToString()); Assert.DoesNotContain("service", f.JsonSchema.ToString()); @@ -891,6 +848,27 @@ public StructWithDefaultCtor() } } + private static AIFunctionFactoryOptions CreateKeyedServicesSupportOptions() => + new AIFunctionFactoryOptions + { + ConfigureParameterBinding = p => + { + if (p.GetCustomAttribute() is { } attr) + { + return new() + { + BindParameter = (p, a) => + (a.Services as IKeyedServiceProvider)?.GetKeyedService(p.ParameterType, attr.Key) is { } s ? s : + p.HasDefaultValue ? p.DefaultValue : + throw new ArgumentException($"Unable to resolve argument for '{p.Name}'.", "arguments.Services"), + ExcludeFromSchema = true + }; + } + + return default; + }, + }; + [JsonSerializable(typeof(IAsyncEnumerable))] [JsonSerializable(typeof(int[]))] [JsonSerializable(typeof(string))]