Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ public abstract class AIFunction : AITool
/// </remarks>
public virtual JsonElement JsonSchema => AIJsonUtilities.DefaultJsonSchema;

/// <summary>Gets a JSON Schema describing the function's return value.</summary>
/// <remarks>
/// A <see langword="null"/> typically reflects a function that doesn't specify a return schema
/// or a function that returns <see cref="void"/>, <see cref="Task"/>, or <see cref="ValueTask"/>.
/// </remarks>
public virtual JsonElement? ReturnJsonSchema => null;

/// <summary>
/// Gets the underlying <see cref="MethodInfo"/> that this <see cref="AIFunction"/> might be wrapping.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ private ReflectionAIFunction(
public override string Description => FunctionDescriptor.Description;
public override MethodInfo UnderlyingMethod => FunctionDescriptor.Method;
public override JsonElement JsonSchema => FunctionDescriptor.JsonSchema;
public override JsonElement? ReturnJsonSchema => FunctionDescriptor.ReturnJsonSchema;
public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions;

protected override async ValueTask<object?> InvokeCoreAsync(
Expand Down Expand Up @@ -683,13 +684,17 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, options, parameters[i]);
}

// Get a marshaling delegate for the return value.
ReturnParameterMarshaller = GetReturnParameterMarshaller(key, serializerOptions);

ReturnParameterMarshaller = GetReturnParameterMarshaller(key, serializerOptions, out Type? returnType);
Method = key.Method;
Name = key.Name ?? GetFunctionName(key.Method);
Description = key.Description ?? key.Method.GetCustomAttribute<DescriptionAttribute>(inherit: true)?.Description ?? string.Empty;
JsonSerializerOptions = serializerOptions;
ReturnJsonSchema = returnType is null ? null : AIJsonUtilities.CreateJsonSchema(
returnType,
description: key.Method.ReturnParameter.GetCustomAttribute<DescriptionAttribute>(inherit: true)?.Description,
serializerOptions: serializerOptions,
inferenceOptions: schemaOptions);

JsonSchema = AIJsonUtilities.CreateFunctionJsonSchema(
key.Method,
title: string.Empty, // Forces skipping of the title keyword
Expand All @@ -703,6 +708,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
public MethodInfo Method { get; }
public JsonSerializerOptions JsonSerializerOptions { get; }
public JsonElement JsonSchema { get; }
public JsonElement? ReturnJsonSchema { get; }
public Func<AIFunctionArguments, CancellationToken, object?>[] ParameterMarshallers { get; }
public Func<object?, CancellationToken, ValueTask<object?>> ReturnParameterMarshaller { get; }
public ReflectionAIFunction? CachedDefaultInstance { get; set; }
Expand Down Expand Up @@ -849,15 +855,16 @@ static void ThrowNullServices(string parameterName) =>
/// Gets a delegate for handling the result value of a method, converting it into the <see cref="Task{FunctionResult}"/> to return from the invocation.
/// </summary>
private static Func<object?, CancellationToken, ValueTask<object?>> GetReturnParameterMarshaller(
DescriptorKey key, JsonSerializerOptions serializerOptions)
DescriptorKey key, JsonSerializerOptions serializerOptions, out Type? returnType)
{
Type returnType = key.Method.ReturnType;
returnType = key.Method.ReturnType;
JsonTypeInfo returnTypeInfo;
Func<object?, Type?, CancellationToken, ValueTask<object?>>? marshalResult = key.MarshalResult;

// Void
if (returnType == typeof(void))
{
returnType = null;
if (marshalResult is not null)
{
return (result, cancellationToken) => marshalResult(null, null, cancellationToken);
Expand All @@ -869,6 +876,7 @@ static void ThrowNullServices(string parameterName) =>
// Task
if (returnType == typeof(Task))
{
returnType = null;
if (marshalResult is not null)
{
return async (result, cancellationToken) =>
Expand All @@ -888,6 +896,7 @@ static void ThrowNullServices(string parameterName) =>
// ValueTask
if (returnType == typeof(ValueTask))
{
returnType = null;
if (marshalResult is not null)
{
return async (result, cancellationToken) =>
Expand All @@ -910,6 +919,8 @@ static void ThrowNullServices(string parameterName) =>
if (returnType.GetGenericTypeDefinition() == typeof(Task<>))
{
MethodInfo taskResultGetter = GetMethodFromGenericMethodDefinition(returnType, _taskGetResult);
returnType = taskResultGetter.ReturnType;

if (marshalResult is not null)
{
return async (taskObj, cancellationToken) =>
Expand All @@ -920,7 +931,7 @@ static void ThrowNullServices(string parameterName) =>
};
}

returnTypeInfo = serializerOptions.GetTypeInfo(taskResultGetter.ReturnType);
returnTypeInfo = serializerOptions.GetTypeInfo(returnType);
return async (taskObj, cancellationToken) =>
{
await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(true);
Expand All @@ -934,6 +945,7 @@ static void ThrowNullServices(string parameterName) =>
{
MethodInfo valueTaskAsTask = GetMethodFromGenericMethodDefinition(returnType, _valueTaskAsTask);
MethodInfo asTaskResultGetter = GetMethodFromGenericMethodDefinition(valueTaskAsTask.ReturnType, _taskGetResult);
returnType = asTaskResultGetter.ReturnType;

if (marshalResult is not null)
{
Expand All @@ -946,7 +958,7 @@ static void ThrowNullServices(string parameterName) =>
};
}

returnTypeInfo = serializerOptions.GetTypeInfo(asTaskResultGetter.ReturnType);
returnTypeInfo = serializerOptions.GetTypeInfo(returnType);
return async (taskObj, cancellationToken) =>
{
var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!;
Expand All @@ -960,7 +972,8 @@ static void ThrowNullServices(string parameterName) =>
// For everything else, just serialize the result as-is.
if (marshalResult is not null)
{
return (result, cancellationToken) => marshalResult(result, returnType, cancellationToken);
Type returnTypeCopy = returnType;
return (result, cancellationToken) => marshalResult(result, returnTypeCopy, cancellationToken);
}

returnTypeInfo = serializerOptions.GetTypeInfo(returnType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@
"Member": "virtual System.Text.Json.JsonSerializerOptions Microsoft.Extensions.AI.AIFunction.JsonSerializerOptions { get; }",
"Stage": "Stable"
},
{
"Member": "virtual System.Text.Json.JsonElement? Microsoft.Extensions.AI.AIFunction.ReturnJsonSchema { get; }",
"Stage": "Stable"
},
{
"Member": "virtual System.Reflection.MethodInfo? Microsoft.Extensions.AI.AIFunction.UnderlyingMethod { get; }",
"Stage": "Stable"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ private static JsonNode CreateJsonSchemaCore(
(schemaObj ??= [])[DescriptionPropertyName] = description;
}

return schemaObj ?? (JsonNode)true;
return schemaObj ?? new JsonObject();
}

if (type == typeof(void))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,22 +100,27 @@ public async Task Returns_AsyncReturnTypesSupported_Async()
AIFunction func;

func = AIFunctionFactory.Create(Task<string> (string a) => Task.FromResult(a + " " + a));
Assert.Equal("""{"type":"string"}""", func.ReturnJsonSchema.ToString());
AssertExtensions.EqualFunctionCallResults("test test", await func.InvokeAsync(new() { ["a"] = "test" }));

func = AIFunctionFactory.Create(ValueTask<string> (string a, string b) => new ValueTask<string>(b + " " + a));
Assert.Equal("""{"type":"string"}""", func.ReturnJsonSchema.ToString());
AssertExtensions.EqualFunctionCallResults("hello world", await func.InvokeAsync(new() { ["b"] = "hello", ["a"] = "world" }));

long result = 0;
func = AIFunctionFactory.Create(async Task (int a, long b) => { result = a + b; await Task.Yield(); });
Assert.Null(func.ReturnJsonSchema);
AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync(new() { ["a"] = 1, ["b"] = 2L }));
Assert.Equal(3, result);

result = 0;
func = AIFunctionFactory.Create(async ValueTask (int a, long b) => { result = a + b; await Task.Yield(); });
Assert.Null(func.ReturnJsonSchema);
AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync(new() { ["a"] = 1, ["b"] = 2L }));
Assert.Equal(3, result);

func = AIFunctionFactory.Create((int count) => SimpleIAsyncEnumerable(count), serializerOptions: JsonContext.Default.Options);
Assert.Equal("""{"type":"array","items":{"type":"integer"}}""", func.ReturnJsonSchema.ToString());
AssertExtensions.EqualFunctionCallResults(new int[] { 0, 1, 2, 3, 4 }, await func.InvokeAsync(new() { ["count"] = 5 }), JsonContext.Default.Options);

static async IAsyncEnumerable<int> SimpleIAsyncEnumerable(int count)
Expand Down Expand Up @@ -220,6 +225,8 @@ public async Task AIFunctionFactoryOptions_SupportsSkippingParameters()
Assert.DoesNotContain("firstParameter", func.JsonSchema.ToString());
Assert.Contains("secondParameter", func.JsonSchema.ToString());

Assert.Equal("""{"type":"string"}""", func.ReturnJsonSchema.ToString());

var result = (JsonElement?)await func.InvokeAsync(new()
{
["firstParameter"] = "test",
Expand Down Expand Up @@ -265,6 +272,8 @@ public async Task AIFunctionArguments_SatisfiesParameters()
Assert.DoesNotContain("services", func.JsonSchema.ToString());
Assert.DoesNotContain("arguments", func.JsonSchema.ToString());

Assert.Equal("""{"type":"integer"}""", func.ReturnJsonSchema.ToString());

await Assert.ThrowsAsync<ArgumentNullException>("arguments.Services", () => func.InvokeAsync(arguments).AsTask());

arguments.Services = sp;
Expand Down Expand Up @@ -430,6 +439,8 @@ public async Task FromKeyedServices_ResolvesFromServiceProvider()
Assert.Contains("myInteger", f.JsonSchema.ToString());
Assert.DoesNotContain("service", f.JsonSchema.ToString());

Assert.Equal("""{"type":"integer"}""", f.ReturnJsonSchema.ToString());

Exception e = await Assert.ThrowsAsync<ArgumentException>("arguments.Services", () => f.InvokeAsync(new() { ["myInteger"] = 1 }).AsTask());

var result = await f.InvokeAsync(new() { ["myInteger"] = 1, Services = sp });
Expand All @@ -451,6 +462,8 @@ public async Task FromKeyedServices_NullKeysBindToNonKeyedServices()
Assert.Contains("myInteger", f.JsonSchema.ToString());
Assert.DoesNotContain("service", f.JsonSchema.ToString());

Assert.Equal("""{"type":"integer"}""", f.ReturnJsonSchema.ToString());

Exception e = await Assert.ThrowsAsync<ArgumentException>("arguments.Services", () => f.InvokeAsync(new() { ["myInteger"] = 1 }).AsTask());

var result = await f.InvokeAsync(new() { ["myInteger"] = 1, Services = sp });
Expand Down Expand Up @@ -743,6 +756,7 @@ public async Task MarshalResult_TypeIsDeclaredTypeEvenWhenDerivedTypeReturned()
Assert.Equal(cts.Token, cancellationToken);
return "marshalResultInvoked";
},
SerializerOptions = JsonContext.Default.Options,
});

object? result = await f.InvokeAsync(new() { ["i"] = 42 }, cts.Token);
Expand All @@ -760,6 +774,17 @@ public async Task AIFunctionFactory_DefaultDefaultParameter()
Assert.Contains("00000000-0000-0000-0000-000000000000,0", result?.ToString());
}

[Fact]
public void AIFunctionFactory_ReturnTypeWithDescriptionAttribute()
{
AIFunction f = AIFunctionFactory.Create(Add, serializerOptions: JsonContext.Default.Options);

Assert.Equal("""{"description":"The summed result","type":"integer"}""", f.ReturnJsonSchema.ToString());

[return: Description("The summed result")]
static int Add(int a, int b) => a + b;
}

private sealed class MyService(int value)
{
public int Value => value;
Expand Down Expand Up @@ -853,5 +878,6 @@ private static AIFunctionFactoryOptions CreateKeyedServicesSupportOptions() =>
[JsonSerializable(typeof(string))]
[JsonSerializable(typeof(Guid))]
[JsonSerializable(typeof(StructWithDefaultCtor))]
[JsonSerializable(typeof(B))]
private partial class JsonContext : JsonSerializerContext;
}
Loading