diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs
index 448c81f61f4..3910040d0a0 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs
@@ -38,6 +38,13 @@ public abstract class AIFunction : AITool
///
public virtual JsonElement JsonSchema => AIJsonUtilities.DefaultJsonSchema;
+ /// Gets a JSON Schema describing the function's return value.
+ ///
+ /// A typically reflects a function that doesn't specify a return schema
+ /// or a function that returns , , or .
+ ///
+ public virtual JsonElement? ReturnJsonSchema => null;
+
///
/// Gets the underlying that this might be wrapping.
///
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs
index ffd47eb08fc..320df4098a3 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs
@@ -540,6 +540,7 @@ private ReflectionAIFunction(
public override string Description => FunctionDescriptor.Description;
public override MethodInfo UnderlyingMethod => FunctionDescriptor.Method;
public override JsonElement JsonSchema => FunctionDescriptor.JsonSchema;
+ public override JsonElement? ReturnJsonSchema => FunctionDescriptor.ReturnJsonSchema;
public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions;
protected override async ValueTask InvokeCoreAsync(
@@ -683,13 +684,17 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, options, parameters[i]);
}
- // Get a marshaling delegate for the return value.
- ReturnParameterMarshaller = GetReturnParameterMarshaller(key, serializerOptions);
-
+ ReturnParameterMarshaller = GetReturnParameterMarshaller(key, serializerOptions, out Type? returnType);
Method = key.Method;
Name = key.Name ?? GetFunctionName(key.Method);
Description = key.Description ?? key.Method.GetCustomAttribute(inherit: true)?.Description ?? string.Empty;
JsonSerializerOptions = serializerOptions;
+ ReturnJsonSchema = returnType is null ? null : AIJsonUtilities.CreateJsonSchema(
+ returnType,
+ description: key.Method.ReturnParameter.GetCustomAttribute(inherit: true)?.Description,
+ serializerOptions: serializerOptions,
+ inferenceOptions: schemaOptions);
+
JsonSchema = AIJsonUtilities.CreateFunctionJsonSchema(
key.Method,
title: string.Empty, // Forces skipping of the title keyword
@@ -703,6 +708,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
public MethodInfo Method { get; }
public JsonSerializerOptions JsonSerializerOptions { get; }
public JsonElement JsonSchema { get; }
+ public JsonElement? ReturnJsonSchema { get; }
public Func[] ParameterMarshallers { get; }
public Func> ReturnParameterMarshaller { get; }
public ReflectionAIFunction? CachedDefaultInstance { get; set; }
@@ -849,15 +855,16 @@ static void ThrowNullServices(string parameterName) =>
/// Gets a delegate for handling the result value of a method, converting it into the to return from the invocation.
///
private static Func> GetReturnParameterMarshaller(
- DescriptorKey key, JsonSerializerOptions serializerOptions)
+ DescriptorKey key, JsonSerializerOptions serializerOptions, out Type? returnType)
{
- Type returnType = key.Method.ReturnType;
+ returnType = key.Method.ReturnType;
JsonTypeInfo returnTypeInfo;
Func>? marshalResult = key.MarshalResult;
// Void
if (returnType == typeof(void))
{
+ returnType = null;
if (marshalResult is not null)
{
return (result, cancellationToken) => marshalResult(null, null, cancellationToken);
@@ -869,6 +876,7 @@ static void ThrowNullServices(string parameterName) =>
// Task
if (returnType == typeof(Task))
{
+ returnType = null;
if (marshalResult is not null)
{
return async (result, cancellationToken) =>
@@ -888,6 +896,7 @@ static void ThrowNullServices(string parameterName) =>
// ValueTask
if (returnType == typeof(ValueTask))
{
+ returnType = null;
if (marshalResult is not null)
{
return async (result, cancellationToken) =>
@@ -910,6 +919,8 @@ static void ThrowNullServices(string parameterName) =>
if (returnType.GetGenericTypeDefinition() == typeof(Task<>))
{
MethodInfo taskResultGetter = GetMethodFromGenericMethodDefinition(returnType, _taskGetResult);
+ returnType = taskResultGetter.ReturnType;
+
if (marshalResult is not null)
{
return async (taskObj, cancellationToken) =>
@@ -920,7 +931,7 @@ static void ThrowNullServices(string parameterName) =>
};
}
- returnTypeInfo = serializerOptions.GetTypeInfo(taskResultGetter.ReturnType);
+ returnTypeInfo = serializerOptions.GetTypeInfo(returnType);
return async (taskObj, cancellationToken) =>
{
await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(true);
@@ -934,6 +945,7 @@ static void ThrowNullServices(string parameterName) =>
{
MethodInfo valueTaskAsTask = GetMethodFromGenericMethodDefinition(returnType, _valueTaskAsTask);
MethodInfo asTaskResultGetter = GetMethodFromGenericMethodDefinition(valueTaskAsTask.ReturnType, _taskGetResult);
+ returnType = asTaskResultGetter.ReturnType;
if (marshalResult is not null)
{
@@ -946,7 +958,7 @@ static void ThrowNullServices(string parameterName) =>
};
}
- returnTypeInfo = serializerOptions.GetTypeInfo(asTaskResultGetter.ReturnType);
+ returnTypeInfo = serializerOptions.GetTypeInfo(returnType);
return async (taskObj, cancellationToken) =>
{
var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!;
@@ -960,7 +972,8 @@ static void ThrowNullServices(string parameterName) =>
// For everything else, just serialize the result as-is.
if (marshalResult is not null)
{
- return (result, cancellationToken) => marshalResult(result, returnType, cancellationToken);
+ Type returnTypeCopy = returnType;
+ return (result, cancellationToken) => marshalResult(result, returnTypeCopy, cancellationToken);
}
returnTypeInfo = serializerOptions.GetTypeInfo(returnType);
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json
index 25fe7307655..499ff4b4a71 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json
@@ -169,6 +169,10 @@
"Member": "virtual System.Text.Json.JsonSerializerOptions Microsoft.Extensions.AI.AIFunction.JsonSerializerOptions { get; }",
"Stage": "Stable"
},
+ {
+ "Member": "virtual System.Text.Json.JsonElement? Microsoft.Extensions.AI.AIFunction.ReturnJsonSchema { get; }",
+ "Stage": "Stable"
+ },
{
"Member": "virtual System.Reflection.MethodInfo? Microsoft.Extensions.AI.AIFunction.UnderlyingMethod { get; }",
"Stage": "Stable"
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs
index ad4b897cef2..e182d4149bb 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs
@@ -227,7 +227,7 @@ private static JsonNode CreateJsonSchemaCore(
(schemaObj ??= [])[DescriptionPropertyName] = description;
}
- return schemaObj ?? (JsonNode)true;
+ return schemaObj ?? new JsonObject();
}
if (type == typeof(void))
diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs
index 6d448efb710..84298788e8c 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs
@@ -100,22 +100,27 @@ public async Task Returns_AsyncReturnTypesSupported_Async()
AIFunction func;
func = AIFunctionFactory.Create(Task (string a) => Task.FromResult(a + " " + a));
+ Assert.Equal("""{"type":"string"}""", func.ReturnJsonSchema.ToString());
AssertExtensions.EqualFunctionCallResults("test test", await func.InvokeAsync(new() { ["a"] = "test" }));
func = AIFunctionFactory.Create(ValueTask (string a, string b) => new ValueTask(b + " " + a));
+ Assert.Equal("""{"type":"string"}""", func.ReturnJsonSchema.ToString());
AssertExtensions.EqualFunctionCallResults("hello world", await func.InvokeAsync(new() { ["b"] = "hello", ["a"] = "world" }));
long result = 0;
func = AIFunctionFactory.Create(async Task (int a, long b) => { result = a + b; await Task.Yield(); });
+ Assert.Null(func.ReturnJsonSchema);
AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync(new() { ["a"] = 1, ["b"] = 2L }));
Assert.Equal(3, result);
result = 0;
func = AIFunctionFactory.Create(async ValueTask (int a, long b) => { result = a + b; await Task.Yield(); });
+ Assert.Null(func.ReturnJsonSchema);
AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync(new() { ["a"] = 1, ["b"] = 2L }));
Assert.Equal(3, result);
func = AIFunctionFactory.Create((int count) => SimpleIAsyncEnumerable(count), serializerOptions: JsonContext.Default.Options);
+ Assert.Equal("""{"type":"array","items":{"type":"integer"}}""", func.ReturnJsonSchema.ToString());
AssertExtensions.EqualFunctionCallResults(new int[] { 0, 1, 2, 3, 4 }, await func.InvokeAsync(new() { ["count"] = 5 }), JsonContext.Default.Options);
static async IAsyncEnumerable SimpleIAsyncEnumerable(int count)
@@ -220,6 +225,8 @@ public async Task AIFunctionFactoryOptions_SupportsSkippingParameters()
Assert.DoesNotContain("firstParameter", func.JsonSchema.ToString());
Assert.Contains("secondParameter", func.JsonSchema.ToString());
+ Assert.Equal("""{"type":"string"}""", func.ReturnJsonSchema.ToString());
+
var result = (JsonElement?)await func.InvokeAsync(new()
{
["firstParameter"] = "test",
@@ -265,6 +272,8 @@ public async Task AIFunctionArguments_SatisfiesParameters()
Assert.DoesNotContain("services", func.JsonSchema.ToString());
Assert.DoesNotContain("arguments", func.JsonSchema.ToString());
+ Assert.Equal("""{"type":"integer"}""", func.ReturnJsonSchema.ToString());
+
await Assert.ThrowsAsync("arguments.Services", () => func.InvokeAsync(arguments).AsTask());
arguments.Services = sp;
@@ -430,6 +439,8 @@ public async Task FromKeyedServices_ResolvesFromServiceProvider()
Assert.Contains("myInteger", f.JsonSchema.ToString());
Assert.DoesNotContain("service", f.JsonSchema.ToString());
+ Assert.Equal("""{"type":"integer"}""", f.ReturnJsonSchema.ToString());
+
Exception e = await Assert.ThrowsAsync("arguments.Services", () => f.InvokeAsync(new() { ["myInteger"] = 1 }).AsTask());
var result = await f.InvokeAsync(new() { ["myInteger"] = 1, Services = sp });
@@ -451,6 +462,8 @@ public async Task FromKeyedServices_NullKeysBindToNonKeyedServices()
Assert.Contains("myInteger", f.JsonSchema.ToString());
Assert.DoesNotContain("service", f.JsonSchema.ToString());
+ Assert.Equal("""{"type":"integer"}""", f.ReturnJsonSchema.ToString());
+
Exception e = await Assert.ThrowsAsync("arguments.Services", () => f.InvokeAsync(new() { ["myInteger"] = 1 }).AsTask());
var result = await f.InvokeAsync(new() { ["myInteger"] = 1, Services = sp });
@@ -743,6 +756,7 @@ public async Task MarshalResult_TypeIsDeclaredTypeEvenWhenDerivedTypeReturned()
Assert.Equal(cts.Token, cancellationToken);
return "marshalResultInvoked";
},
+ SerializerOptions = JsonContext.Default.Options,
});
object? result = await f.InvokeAsync(new() { ["i"] = 42 }, cts.Token);
@@ -760,6 +774,17 @@ public async Task AIFunctionFactory_DefaultDefaultParameter()
Assert.Contains("00000000-0000-0000-0000-000000000000,0", result?.ToString());
}
+ [Fact]
+ public void AIFunctionFactory_ReturnTypeWithDescriptionAttribute()
+ {
+ AIFunction f = AIFunctionFactory.Create(Add, serializerOptions: JsonContext.Default.Options);
+
+ Assert.Equal("""{"description":"The summed result","type":"integer"}""", f.ReturnJsonSchema.ToString());
+
+ [return: Description("The summed result")]
+ static int Add(int a, int b) => a + b;
+ }
+
private sealed class MyService(int value)
{
public int Value => value;
@@ -853,5 +878,6 @@ private static AIFunctionFactoryOptions CreateKeyedServicesSupportOptions() =>
[JsonSerializable(typeof(string))]
[JsonSerializable(typeof(Guid))]
[JsonSerializable(typeof(StructWithDefaultCtor))]
+ [JsonSerializable(typeof(B))]
private partial class JsonContext : JsonSerializerContext;
}