diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs index 448c81f61f4..3910040d0a0 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs @@ -38,6 +38,13 @@ public abstract class AIFunction : AITool /// public virtual JsonElement JsonSchema => AIJsonUtilities.DefaultJsonSchema; + /// Gets a JSON Schema describing the function's return value. + /// + /// A typically reflects a function that doesn't specify a return schema + /// or a function that returns , , or . + /// + public virtual JsonElement? ReturnJsonSchema => null; + /// /// Gets the underlying that this might be wrapping. /// diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs index ffd47eb08fc..320df4098a3 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs @@ -540,6 +540,7 @@ private ReflectionAIFunction( public override string Description => FunctionDescriptor.Description; public override MethodInfo UnderlyingMethod => FunctionDescriptor.Method; public override JsonElement JsonSchema => FunctionDescriptor.JsonSchema; + public override JsonElement? ReturnJsonSchema => FunctionDescriptor.ReturnJsonSchema; public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions; protected override async ValueTask InvokeCoreAsync( @@ -683,13 +684,17 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, options, parameters[i]); } - // Get a marshaling delegate for the return value. - ReturnParameterMarshaller = GetReturnParameterMarshaller(key, serializerOptions); - + ReturnParameterMarshaller = GetReturnParameterMarshaller(key, serializerOptions, out Type? returnType); Method = key.Method; Name = key.Name ?? GetFunctionName(key.Method); Description = key.Description ?? key.Method.GetCustomAttribute(inherit: true)?.Description ?? string.Empty; JsonSerializerOptions = serializerOptions; + ReturnJsonSchema = returnType is null ? null : AIJsonUtilities.CreateJsonSchema( + returnType, + description: key.Method.ReturnParameter.GetCustomAttribute(inherit: true)?.Description, + serializerOptions: serializerOptions, + inferenceOptions: schemaOptions); + JsonSchema = AIJsonUtilities.CreateFunctionJsonSchema( key.Method, title: string.Empty, // Forces skipping of the title keyword @@ -703,6 +708,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions public MethodInfo Method { get; } public JsonSerializerOptions JsonSerializerOptions { get; } public JsonElement JsonSchema { get; } + public JsonElement? ReturnJsonSchema { get; } public Func[] ParameterMarshallers { get; } public Func> ReturnParameterMarshaller { get; } public ReflectionAIFunction? CachedDefaultInstance { get; set; } @@ -849,15 +855,16 @@ static void ThrowNullServices(string parameterName) => /// Gets a delegate for handling the result value of a method, converting it into the to return from the invocation. /// private static Func> GetReturnParameterMarshaller( - DescriptorKey key, JsonSerializerOptions serializerOptions) + DescriptorKey key, JsonSerializerOptions serializerOptions, out Type? returnType) { - Type returnType = key.Method.ReturnType; + returnType = key.Method.ReturnType; JsonTypeInfo returnTypeInfo; Func>? marshalResult = key.MarshalResult; // Void if (returnType == typeof(void)) { + returnType = null; if (marshalResult is not null) { return (result, cancellationToken) => marshalResult(null, null, cancellationToken); @@ -869,6 +876,7 @@ static void ThrowNullServices(string parameterName) => // Task if (returnType == typeof(Task)) { + returnType = null; if (marshalResult is not null) { return async (result, cancellationToken) => @@ -888,6 +896,7 @@ static void ThrowNullServices(string parameterName) => // ValueTask if (returnType == typeof(ValueTask)) { + returnType = null; if (marshalResult is not null) { return async (result, cancellationToken) => @@ -910,6 +919,8 @@ static void ThrowNullServices(string parameterName) => if (returnType.GetGenericTypeDefinition() == typeof(Task<>)) { MethodInfo taskResultGetter = GetMethodFromGenericMethodDefinition(returnType, _taskGetResult); + returnType = taskResultGetter.ReturnType; + if (marshalResult is not null) { return async (taskObj, cancellationToken) => @@ -920,7 +931,7 @@ static void ThrowNullServices(string parameterName) => }; } - returnTypeInfo = serializerOptions.GetTypeInfo(taskResultGetter.ReturnType); + returnTypeInfo = serializerOptions.GetTypeInfo(returnType); return async (taskObj, cancellationToken) => { await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(true); @@ -934,6 +945,7 @@ static void ThrowNullServices(string parameterName) => { MethodInfo valueTaskAsTask = GetMethodFromGenericMethodDefinition(returnType, _valueTaskAsTask); MethodInfo asTaskResultGetter = GetMethodFromGenericMethodDefinition(valueTaskAsTask.ReturnType, _taskGetResult); + returnType = asTaskResultGetter.ReturnType; if (marshalResult is not null) { @@ -946,7 +958,7 @@ static void ThrowNullServices(string parameterName) => }; } - returnTypeInfo = serializerOptions.GetTypeInfo(asTaskResultGetter.ReturnType); + returnTypeInfo = serializerOptions.GetTypeInfo(returnType); return async (taskObj, cancellationToken) => { var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!; @@ -960,7 +972,8 @@ static void ThrowNullServices(string parameterName) => // For everything else, just serialize the result as-is. if (marshalResult is not null) { - return (result, cancellationToken) => marshalResult(result, returnType, cancellationToken); + Type returnTypeCopy = returnType; + return (result, cancellationToken) => marshalResult(result, returnTypeCopy, cancellationToken); } returnTypeInfo = serializerOptions.GetTypeInfo(returnType); diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json index 25fe7307655..499ff4b4a71 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json @@ -169,6 +169,10 @@ "Member": "virtual System.Text.Json.JsonSerializerOptions Microsoft.Extensions.AI.AIFunction.JsonSerializerOptions { get; }", "Stage": "Stable" }, + { + "Member": "virtual System.Text.Json.JsonElement? Microsoft.Extensions.AI.AIFunction.ReturnJsonSchema { get; }", + "Stage": "Stable" + }, { "Member": "virtual System.Reflection.MethodInfo? Microsoft.Extensions.AI.AIFunction.UnderlyingMethod { get; }", "Stage": "Stable" diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs index ad4b897cef2..e182d4149bb 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs @@ -227,7 +227,7 @@ private static JsonNode CreateJsonSchemaCore( (schemaObj ??= [])[DescriptionPropertyName] = description; } - return schemaObj ?? (JsonNode)true; + return schemaObj ?? new JsonObject(); } if (type == typeof(void)) diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs index 6d448efb710..84298788e8c 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs @@ -100,22 +100,27 @@ public async Task Returns_AsyncReturnTypesSupported_Async() AIFunction func; func = AIFunctionFactory.Create(Task (string a) => Task.FromResult(a + " " + a)); + Assert.Equal("""{"type":"string"}""", func.ReturnJsonSchema.ToString()); AssertExtensions.EqualFunctionCallResults("test test", await func.InvokeAsync(new() { ["a"] = "test" })); func = AIFunctionFactory.Create(ValueTask (string a, string b) => new ValueTask(b + " " + a)); + Assert.Equal("""{"type":"string"}""", func.ReturnJsonSchema.ToString()); AssertExtensions.EqualFunctionCallResults("hello world", await func.InvokeAsync(new() { ["b"] = "hello", ["a"] = "world" })); long result = 0; func = AIFunctionFactory.Create(async Task (int a, long b) => { result = a + b; await Task.Yield(); }); + Assert.Null(func.ReturnJsonSchema); AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync(new() { ["a"] = 1, ["b"] = 2L })); Assert.Equal(3, result); result = 0; func = AIFunctionFactory.Create(async ValueTask (int a, long b) => { result = a + b; await Task.Yield(); }); + Assert.Null(func.ReturnJsonSchema); AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync(new() { ["a"] = 1, ["b"] = 2L })); Assert.Equal(3, result); func = AIFunctionFactory.Create((int count) => SimpleIAsyncEnumerable(count), serializerOptions: JsonContext.Default.Options); + Assert.Equal("""{"type":"array","items":{"type":"integer"}}""", func.ReturnJsonSchema.ToString()); AssertExtensions.EqualFunctionCallResults(new int[] { 0, 1, 2, 3, 4 }, await func.InvokeAsync(new() { ["count"] = 5 }), JsonContext.Default.Options); static async IAsyncEnumerable SimpleIAsyncEnumerable(int count) @@ -220,6 +225,8 @@ public async Task AIFunctionFactoryOptions_SupportsSkippingParameters() Assert.DoesNotContain("firstParameter", func.JsonSchema.ToString()); Assert.Contains("secondParameter", func.JsonSchema.ToString()); + Assert.Equal("""{"type":"string"}""", func.ReturnJsonSchema.ToString()); + var result = (JsonElement?)await func.InvokeAsync(new() { ["firstParameter"] = "test", @@ -265,6 +272,8 @@ public async Task AIFunctionArguments_SatisfiesParameters() Assert.DoesNotContain("services", func.JsonSchema.ToString()); Assert.DoesNotContain("arguments", func.JsonSchema.ToString()); + Assert.Equal("""{"type":"integer"}""", func.ReturnJsonSchema.ToString()); + await Assert.ThrowsAsync("arguments.Services", () => func.InvokeAsync(arguments).AsTask()); arguments.Services = sp; @@ -430,6 +439,8 @@ public async Task FromKeyedServices_ResolvesFromServiceProvider() Assert.Contains("myInteger", f.JsonSchema.ToString()); Assert.DoesNotContain("service", f.JsonSchema.ToString()); + Assert.Equal("""{"type":"integer"}""", f.ReturnJsonSchema.ToString()); + Exception e = await Assert.ThrowsAsync("arguments.Services", () => f.InvokeAsync(new() { ["myInteger"] = 1 }).AsTask()); var result = await f.InvokeAsync(new() { ["myInteger"] = 1, Services = sp }); @@ -451,6 +462,8 @@ public async Task FromKeyedServices_NullKeysBindToNonKeyedServices() Assert.Contains("myInteger", f.JsonSchema.ToString()); Assert.DoesNotContain("service", f.JsonSchema.ToString()); + Assert.Equal("""{"type":"integer"}""", f.ReturnJsonSchema.ToString()); + Exception e = await Assert.ThrowsAsync("arguments.Services", () => f.InvokeAsync(new() { ["myInteger"] = 1 }).AsTask()); var result = await f.InvokeAsync(new() { ["myInteger"] = 1, Services = sp }); @@ -743,6 +756,7 @@ public async Task MarshalResult_TypeIsDeclaredTypeEvenWhenDerivedTypeReturned() Assert.Equal(cts.Token, cancellationToken); return "marshalResultInvoked"; }, + SerializerOptions = JsonContext.Default.Options, }); object? result = await f.InvokeAsync(new() { ["i"] = 42 }, cts.Token); @@ -760,6 +774,17 @@ public async Task AIFunctionFactory_DefaultDefaultParameter() Assert.Contains("00000000-0000-0000-0000-000000000000,0", result?.ToString()); } + [Fact] + public void AIFunctionFactory_ReturnTypeWithDescriptionAttribute() + { + AIFunction f = AIFunctionFactory.Create(Add, serializerOptions: JsonContext.Default.Options); + + Assert.Equal("""{"description":"The summed result","type":"integer"}""", f.ReturnJsonSchema.ToString()); + + [return: Description("The summed result")] + static int Add(int a, int b) => a + b; + } + private sealed class MyService(int value) { public int Value => value; @@ -853,5 +878,6 @@ private static AIFunctionFactoryOptions CreateKeyedServicesSupportOptions() => [JsonSerializable(typeof(string))] [JsonSerializable(typeof(Guid))] [JsonSerializable(typeof(StructWithDefaultCtor))] + [JsonSerializable(typeof(B))] private partial class JsonContext : JsonSerializerContext; }