diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs
similarity index 82%
rename from src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs
rename to src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs
index 4878239f35b..3f090a2ac3b 100644
--- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs
@@ -2,11 +2,13 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System;
+using System.Buffers;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
+using System.IO;
#if !NET
using System.Linq;
#endif
@@ -15,23 +17,26 @@
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization.Metadata;
+using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
-using Microsoft.Extensions.DependencyInjection;
using Microsoft.Shared.Collections;
using Microsoft.Shared.Diagnostics;
#pragma warning disable CA1031 // Do not catch general exception types
+#pragma warning disable S2333 // Redundant modifiers should not be used
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
-#pragma warning disable SA1118 // Parameter should not span multiple lines
-#pragma warning disable SA1500 // Braces for multi-line statements should not share line
namespace Microsoft.Extensions.AI;
-/// Provides factory methods for creating commonly used implementations of .
+/// Provides factory methods for creating commonly-used implementations of .
/// Invoke .NET functions using an AI model.
public static partial class AIFunctionFactory
{
+ // NOTE:
+ // Unlike most library code, AIFunctionFactory uses ConfigureAwait(true) rather than ConfigureAwait(false). This is to
+ // enable AIFunctionFactory to be used with methods that might be context-aware, such as those employing a UI framework.
+
/// Holds the default options instance used when creating function.
private static readonly AIFunctionFactoryOptions _defaultOptions = new();
@@ -71,25 +76,6 @@ public static partial class AIFunctionFactory
/// .
///
///
- ///
- ///
- /// By default, parameters attributed with are resolved from the
- /// property and are not included in the JSON schema. If the parameter is optional, such that a default value is provided,
- /// is allowed to be ; otherwise,
- /// must be non-, or else the invocation will fail with an exception due to the required nature of the parameter.
- /// The handling of such parameters may be overridden via .
- ///
- ///
- ///
- ///
- /// When the is constructed, it may be passed an via
- /// . Any parameter that can be satisfied by that
- /// according to will not be included in the generated JSON schema and will be resolved
- /// from the provided to via ,
- /// rather than from the argument collection. The handling of such parameters may be overridden via
- /// .
- ///
- ///
///
/// All other parameter types are, by default, bound from the dictionary passed into
/// and are included in the generated JSON schema. This may be overridden by the provided
@@ -170,23 +156,6 @@ public static AIFunction Create(Delegate method, AIFunctionFactoryOptions? optio
/// optional or not.
///
///
- ///
- ///
- /// By default, parameters attributed with are resolved from the
- /// property and are not included in the JSON schema. If the parameter is optional, such that a default value is provided,
- /// is allowed to be ; otherwise,
- /// must be non-, or else the invocation will fail with an exception due to the required nature of the parameter.
- ///
- ///
- ///
- ///
- /// When the is constructed, it may be passed an via
- /// . Any parameter that can be satisfied by that
- /// according to will not be included in the generated JSON schema and will be resolved
- /// from the provided to via ,
- /// rather than from the argument collection.
- ///
- ///
///
/// All other parameter types are bound from the dictionary passed into
/// and are included in the generated JSON schema.
@@ -270,25 +239,6 @@ public static AIFunction Create(Delegate method, string? name = null, string? de
/// .
///
///
- ///
- ///
- /// By default, parameters attributed with are resolved from the
- /// property and are not included in the JSON schema. If the parameter is optional, such that a default value is provided,
- /// is allowed to be ; otherwise,
- /// must be non-, or else the invocation will fail with an exception due to the required nature of the parameter.
- /// The handling of such parameters may be overridden via .
- ///
- ///
- ///
- ///
- /// When the is constructed, it may be passed an via
- /// . Any parameter that can be satisfied by that
- /// according to will not be included in the generated JSON schema and will be resolved
- /// from the provided to via ,
- /// rather than from the argument collection. The handling of such parameters may be overridden via
- /// .
- ///
- ///
///
/// All other parameter types are, by default, bound from the dictionary passed into
/// and are included in the generated JSON schema. This may be overridden by the provided
@@ -379,23 +329,6 @@ public static AIFunction Create(MethodInfo method, object? target, AIFunctionFac
/// optional or not.
///
///
- ///
- ///
- /// By default, parameters attributed with are resolved from the
- /// property and are not included in the JSON schema. If the parameter is optional, such that a default value is provided,
- /// is allowed to be ; otherwise,
- /// must be non-, or else the invocation will fail with an exception due to the required nature of the parameter.
- ///
- ///
- ///
- /// When the is constructed, it may be passed an via
- /// . Any parameter that can be satisfied by that
- /// according to will not be included in the generated JSON schema and will be resolved
- /// from the provided to via ,
- /// rather than from the argument collection.
- ///
- ///
- ///
///
/// All other parameter types are bound from the dictionary passed into
/// and are included in the generated JSON schema.
@@ -447,10 +380,9 @@ public static AIFunction Create(MethodInfo method, object? target, string? name
/// The instance method to be represented via the created .
///
/// The to construct an instance of on which to invoke when
- /// the resulting is invoked. If is provided,
- /// will be used to construct the instance using those services; otherwise,
- /// is used, utilizing the type's public parameterless constructor.
- /// If an instance can't be constructed, an exception is thrown during the function's invocation.
+ /// the resulting is invoked. is used,
+ /// utilizing the type's public parameterless constructor. If an instance can't be constructed, an exception is
+ /// thrown during the function's invocation.
///
/// Metadata to use to override defaults inferred from .
/// The created for invoking .
@@ -494,25 +426,6 @@ public static AIFunction Create(MethodInfo method, object? target, string? name
/// .
///
///
- ///
- ///
- /// By default, parameters attributed with are resolved from the
- /// property and are not included in the JSON schema. If the parameter is optional, such that a default value is provided,
- /// is allowed to be ; otherwise,
- /// must be non-, or else the invocation will fail with an exception due to the required nature of the parameter.
- /// The handling of such parameters may be overridden via .
- ///
- ///
- ///
- ///
- /// When the is constructed, it may be passed an via
- /// . Any parameter that can be satisfied by that
- /// according to will not be included in the generated JSON schema and will be resolved
- /// from the provided to via ,
- /// rather than from the argument collection. The handling of such parameters may be overridden via
- /// .
- ///
- ///
///
/// All other parameter types are, by default, bound from the dictionary passed into
/// and are included in the generated JSON schema. This may be overridden by the provided
@@ -627,6 +540,7 @@ private ReflectionAIFunction(
{
FunctionDescriptor = functionDescriptor;
TargetType = targetType;
+ CreateInstance = options.CreateInstance;
AdditionalProperties = options.AdditionalProperties ?? EmptyReadOnlyDictionary.Instance;
}
@@ -634,6 +548,8 @@ private ReflectionAIFunction(
public object? Target { get; }
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)]
public Type? TargetType { get; }
+ public Func? CreateInstance { get; }
+
public override IReadOnlyDictionary AdditionalProperties { get; }
public override string Name => FunctionDescriptor.Name;
public override string Description => FunctionDescriptor.Description;
@@ -654,9 +570,14 @@ private ReflectionAIFunction(
Debug.Assert(target is null, "Expected target to be null when we have a non-null target type");
Debug.Assert(!FunctionDescriptor.Method.IsStatic, "Expected an instance method");
- target = arguments.Services is { } services ?
- ActivatorUtilities.CreateInstance(services, targetType!) :
+ target = CreateInstance is not null ?
+ CreateInstance(targetType, arguments) :
Activator.CreateInstance(targetType);
+ if (target is null)
+ {
+ Throw.InvalidOperationException("Unable to create an instance of the target type.");
+ }
+
disposeTarget = true;
}
@@ -669,7 +590,7 @@ private ReflectionAIFunction(
}
return await FunctionDescriptor.ReturnParameterMarshaller(
- ReflectionInvoke(FunctionDescriptor.Method, target, args), cancellationToken);
+ ReflectionInvoke(FunctionDescriptor.Method, target, args), cancellationToken).ConfigureAwait(true);
}
finally
{
@@ -677,7 +598,7 @@ private ReflectionAIFunction(
{
if (target is IAsyncDisposable ad)
{
- await ad.DisposeAsync();
+ await ad.DisposeAsync().ConfigureAwait(true);
}
else if (target is IDisposable d)
{
@@ -709,7 +630,7 @@ public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFu
serializerOptions.MakeReadOnly();
ConcurrentDictionary innerCache = _descriptorCache.GetOrCreateValue(serializerOptions);
- DescriptorKey key = new(method, options.Name, options.Description, options.ConfigureParameterBinding, options.MarshalResult, options.Services, schemaOptions);
+ DescriptorKey key = new(method, options.Name, options.Description, options.ConfigureParameterBinding, options.MarshalResult, schemaOptions);
if (innerCache.TryGetValue(key, out ReflectionAIFunctionDescriptor? descriptor))
{
return descriptor;
@@ -736,8 +657,6 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
}
}
- IServiceProviderIsService? serviceProviderIsService = key.Services?.GetService();
-
// Use that binding information to impact the schema generation.
AIJsonSchemaCreateOptions schemaOptions = key.SchemaOptions with
{
@@ -757,21 +676,6 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
return false;
}
- // If the parameter is attributed as [FromKeyedServices], exclude it, as we'll instead
- // get its value from the IServiceProvider.
- if (parameterInfo.GetCustomAttribute(inherit: true) is not null)
- {
- return false;
- }
-
- // We assume that if the services used to create the function support a particular type,
- // so too do the services that will be passed into InvokeAsync. This is the same basic assumption
- // made in ASP.NET.
- if (serviceProviderIsService?.IsService(parameterInfo.ParameterType) is true)
- {
- return false;
- }
-
// If there was an existing IncludeParameter delegate, now defer to it as we've
// excluded everything we need to exclude.
if (key.SchemaOptions.IncludeParameter is { } existingIncludeParameter)
@@ -793,7 +697,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
options = default;
}
- ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, options, parameters[i], serviceProviderIsService);
+ ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, options, parameters[i]);
}
// Get a marshaling delegate for the return value.
@@ -863,8 +767,7 @@ static bool IsAsyncMethod(MethodInfo method)
private static Func GetParameterMarshaller(
JsonSerializerOptions serializerOptions,
AIFunctionFactoryOptions.ParameterBindingOptions bindingOptions,
- ParameterInfo parameter,
- IServiceProviderIsService? serviceProviderIsService)
+ ParameterInfo parameter)
{
if (string.IsNullOrWhiteSpace(parameter.Name))
{
@@ -911,56 +814,6 @@ static bool IsAsyncMethod(MethodInfo method)
};
}
- // For [FromKeyedServices] parameters, we resolve from the services passed to InvokeAsync via AIFunctionArguments.
- if (parameter.GetCustomAttribute(inherit: true) is { } keyedAttr)
- {
- return (arguments, _) =>
- {
- if ((arguments.Services as IKeyedServiceProvider)?.GetKeyedService(parameterType, keyedAttr.Key) is { } service)
- {
- return service;
- }
-
- if (!parameter.HasDefaultValue)
- {
- if (arguments.Services is null)
- {
- ThrowNullServices(parameter.Name);
- }
-
- Throw.ArgumentException(nameof(arguments), $"No service of type '{parameterType}' with key '{keyedAttr.Key}' was found for parameter '{parameter.Name}'.");
- }
-
- return parameter.DefaultValue;
- };
- }
-
- // For any parameters that are satisfiable from the IServiceProvider, we resolve from the services passed to InvokeAsync
- // via AIFunctionArguments. This is determined by the same same IServiceProviderIsService instance used to determine whether
- // the parameter should be included in the schema.
- if (serviceProviderIsService?.IsService(parameterType) is true)
- {
- return (arguments, _) =>
- {
- if (arguments.Services?.GetService(parameterType) is { } service)
- {
- return service;
- }
-
- if (!parameter.HasDefaultValue)
- {
- if (arguments.Services is null)
- {
- ThrowNullServices(parameter.Name);
- }
-
- Throw.ArgumentException(nameof(arguments), $"No service of type '{parameterType}' was found for parameter '{parameter.Name}'.");
- }
-
- return parameter.DefaultValue;
- };
- }
-
// For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary.
// Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found.
JsonTypeInfo? typeInfo = serializerOptions.GetTypeInfo(parameterType);
@@ -1037,14 +890,14 @@ static void ThrowNullServices(string parameterName) =>
{
return async (result, cancellationToken) =>
{
- await ((Task)ThrowIfNullResult(result));
- return await marshalResult(null, null, cancellationToken);
+ await ((Task)ThrowIfNullResult(result)).ConfigureAwait(true);
+ return await marshalResult(null, null, cancellationToken).ConfigureAwait(true);
};
}
return async static (result, _) =>
{
- await ((Task)ThrowIfNullResult(result));
+ await ((Task)ThrowIfNullResult(result)).ConfigureAwait(true);
return null;
};
}
@@ -1056,14 +909,14 @@ static void ThrowNullServices(string parameterName) =>
{
return async (result, cancellationToken) =>
{
- await ((ValueTask)ThrowIfNullResult(result));
- return await marshalResult(null, null, cancellationToken);
+ await ((ValueTask)ThrowIfNullResult(result)).ConfigureAwait(true);
+ return await marshalResult(null, null, cancellationToken).ConfigureAwait(true);
};
}
return async static (result, _) =>
{
- await ((ValueTask)ThrowIfNullResult(result));
+ await ((ValueTask)ThrowIfNullResult(result)).ConfigureAwait(true);
return null;
};
}
@@ -1078,18 +931,18 @@ static void ThrowNullServices(string parameterName) =>
{
return async (taskObj, cancellationToken) =>
{
- await ((Task)ThrowIfNullResult(taskObj));
+ await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(true);
object? result = ReflectionInvoke(taskResultGetter, taskObj, null);
- return await marshalResult(result, taskResultGetter.ReturnType, cancellationToken);
+ return await marshalResult(result, taskResultGetter.ReturnType, cancellationToken).ConfigureAwait(true);
};
}
returnTypeInfo = serializerOptions.GetTypeInfo(taskResultGetter.ReturnType);
return async (taskObj, cancellationToken) =>
{
- await ((Task)ThrowIfNullResult(taskObj));
+ await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(true);
object? result = ReflectionInvoke(taskResultGetter, taskObj, null);
- return await SerializeResultAsync(result, returnTypeInfo, cancellationToken);
+ return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(true);
};
}
@@ -1104,9 +957,9 @@ static void ThrowNullServices(string parameterName) =>
return async (taskObj, cancellationToken) =>
{
var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!;
- await task;
+ await task.ConfigureAwait(true);
object? result = ReflectionInvoke(asTaskResultGetter, task, null);
- return await marshalResult(result, asTaskResultGetter.ReturnType, cancellationToken);
+ return await marshalResult(result, asTaskResultGetter.ReturnType, cancellationToken).ConfigureAwait(true);
};
}
@@ -1114,9 +967,9 @@ static void ThrowNullServices(string parameterName) =>
return async (taskObj, cancellationToken) =>
{
var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!;
- await task;
+ await task.ConfigureAwait(true);
object? result = ReflectionInvoke(asTaskResultGetter, task, null);
- return await SerializeResultAsync(result, returnTypeInfo, cancellationToken);
+ return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(true);
};
}
}
@@ -1140,7 +993,7 @@ static void ThrowNullServices(string parameterName) =>
// Serialize asynchronously to support potential IAsyncEnumerable responses.
using PooledMemoryStream stream = new();
- await JsonSerializer.SerializeAsync(stream, result, returnTypeInfo, cancellationToken);
+ await JsonSerializer.SerializeAsync(stream, result, returnTypeInfo, cancellationToken).ConfigureAwait(true);
Utf8JsonReader reader = new(stream.GetBuffer());
return JsonElement.ParseValue(ref reader);
}
@@ -1169,7 +1022,126 @@ private record struct DescriptorKey(
string? Description,
Func? GetBindParameterOptions,
Func