Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
Description = key.Description ?? key.Method.GetCustomAttribute<DescriptionAttribute>(inherit: true)?.Description ?? string.Empty;
JsonSerializerOptions = serializerOptions;
ReturnJsonSchema = returnType is null || key.ExcludeResultSchema ? null : AIJsonUtilities.CreateJsonSchema(
returnType,
NormalizeReturnType(returnType, serializerOptions),
description: key.Method.ReturnParameter.GetCustomAttribute<DescriptionAttribute>(inherit: true)?.Description,
serializerOptions: serializerOptions,
inferenceOptions: schemaOptions);
Expand Down Expand Up @@ -978,6 +978,7 @@ static void ThrowNullServices(string parameterName) =>
MethodInfo taskResultGetter = GetMethodFromGenericMethodDefinition(returnType, _taskGetResult);
returnType = taskResultGetter.ReturnType;

// If a MarshalResult delegate is provided, use it.
if (marshalResult is not null)
{
return async (taskObj, cancellationToken) =>
Expand All @@ -988,6 +989,18 @@ static void ThrowNullServices(string parameterName) =>
};
}

// Special-case AIContent results to not be serialized, so that IChatClients can type test and handle them
// specially, such as by returning content to the model/service in a manner appropriate to the content type.
if (IsAIContentRelatedType(returnType))
{
return async (taskObj, cancellationToken) =>
{
await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(true);
return ReflectionInvoke(taskResultGetter, taskObj, null);
};
}

// For everything else, just serialize the result as-is.
returnTypeInfo = serializerOptions.GetTypeInfo(returnType);
return async (taskObj, cancellationToken) =>
{
Expand All @@ -1004,6 +1017,7 @@ static void ThrowNullServices(string parameterName) =>
MethodInfo asTaskResultGetter = GetMethodFromGenericMethodDefinition(valueTaskAsTask.ReturnType, _taskGetResult);
returnType = asTaskResultGetter.ReturnType;

// If a MarshalResult delegate is provided, use it.
if (marshalResult is not null)
{
return async (taskObj, cancellationToken) =>
Expand All @@ -1015,6 +1029,19 @@ static void ThrowNullServices(string parameterName) =>
};
}

// Special-case AIContent results to not be serialized, so that IChatClients can type test and handle them
// specially, such as by returning content to the model/service in a manner appropriate to the content type.
if (IsAIContentRelatedType(returnType))
{
return async (taskObj, cancellationToken) =>
{
var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!;
await task.ConfigureAwait(true);
return ReflectionInvoke(asTaskResultGetter, task, null);
};
}

// For everything else, just serialize the result as-is.
returnTypeInfo = serializerOptions.GetTypeInfo(returnType);
return async (taskObj, cancellationToken) =>
{
Expand All @@ -1026,13 +1053,21 @@ static void ThrowNullServices(string parameterName) =>
}
}

// For everything else, just serialize the result as-is.
// If a MarshalResult delegate is provided, use it.
if (marshalResult is not null)
{
Type returnTypeCopy = returnType;
return (result, cancellationToken) => marshalResult(result, returnTypeCopy, cancellationToken);
}

// Special-case AIContent results to not be serialized, so that IChatClients can type test and handle them
// specially, such as by returning content to the model/service in a manner appropriate to the content type.
if (IsAIContentRelatedType(returnType))
{
return static (result, _) => new ValueTask<object?>(result);
}

// For everything else, just serialize the result as-is.
returnTypeInfo = serializerOptions.GetTypeInfo(returnType);
return (result, cancellationToken) => SerializeResultAsync(result, returnTypeInfo, cancellationToken);

Expand Down Expand Up @@ -1069,6 +1104,41 @@ private static MethodInfo GetMethodFromGenericMethodDefinition(Type specializedT
#endif
}

private static bool IsAIContentRelatedType(Type type) =>
typeof(AIContent).IsAssignableFrom(type) ||
typeof(IEnumerable<AIContent>).IsAssignableFrom(type);

private static Type NormalizeReturnType(Type type, JsonSerializerOptions? options)
{
options ??= AIJsonUtilities.DefaultOptions;

if (options == AIJsonUtilities.DefaultOptions && !options.TryGetTypeInfo(type, out _))
{
// GetTypeInfo is not polymorphic, so attempts to look up derived types will fail even if the
// base type is registered. In some cases, though, we can fall back to using interfaces
// we know we have contracts for in AIJsonUtilities.DefaultOptions where the semantics of using
// that interface will be reasonable. This should really only affect situations where
// reflection-based serialization is disabled.

if (typeof(IEnumerable<AIContent>).IsAssignableFrom(type))
{
return typeof(IEnumerable<AIContent>);
}

if (typeof(IEnumerable<ChatMessage>).IsAssignableFrom(type))
{
return typeof(IEnumerable<ChatMessage>);
}

if (typeof(IEnumerable<string>).IsAssignableFrom(type))
{
return typeof(IEnumerable<string>);
}
}

return type;
}

private record struct DescriptorKey(
MethodInfo Method,
string? Name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,39 +75,23 @@ private static JsonSerializerOptions CreateDefaultOptions()
UseStringEnumConverter = true,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = true)]
[JsonSerializable(typeof(SpeechToTextOptions))]
[JsonSerializable(typeof(SpeechToTextClientMetadata))]
[JsonSerializable(typeof(SpeechToTextResponse))]
[JsonSerializable(typeof(SpeechToTextResponseUpdate))]
[JsonSerializable(typeof(IReadOnlyList<SpeechToTextResponseUpdate>))]
[JsonSerializable(typeof(ImageGenerationOptions))]
[JsonSerializable(typeof(ImageGenerationResponse))]
[JsonSerializable(typeof(IList<ChatMessage>))]
[JsonSerializable(typeof(IEnumerable<ChatMessage>))]
[JsonSerializable(typeof(ChatMessage[]))]
[JsonSerializable(typeof(ChatOptions))]
[JsonSerializable(typeof(EmbeddingGenerationOptions))]
[JsonSerializable(typeof(ChatClientMetadata))]
[JsonSerializable(typeof(EmbeddingGeneratorMetadata))]
[JsonSerializable(typeof(ChatResponse))]
[JsonSerializable(typeof(ChatResponseUpdate))]
[JsonSerializable(typeof(IReadOnlyList<ChatResponseUpdate>))]
[JsonSerializable(typeof(Dictionary<string, object>))]
[JsonSerializable(typeof(IDictionary<string, object?>))]

// JSON
[JsonSerializable(typeof(JsonDocument))]
[JsonSerializable(typeof(JsonElement))]
[JsonSerializable(typeof(JsonNode))]
[JsonSerializable(typeof(JsonObject))]
[JsonSerializable(typeof(JsonValue))]
[JsonSerializable(typeof(JsonArray))]
[JsonSerializable(typeof(IEnumerable<string>))]
[JsonSerializable(typeof(char))]

// Primitives
[JsonSerializable(typeof(string))]
[JsonSerializable(typeof(int))]
[JsonSerializable(typeof(char))]
[JsonSerializable(typeof(short))]
[JsonSerializable(typeof(long))]
[JsonSerializable(typeof(uint))]
[JsonSerializable(typeof(ushort))]
[JsonSerializable(typeof(int))]
[JsonSerializable(typeof(uint))]
[JsonSerializable(typeof(long))]
[JsonSerializable(typeof(ulong))]
[JsonSerializable(typeof(float))]
[JsonSerializable(typeof(double))]
Expand All @@ -116,26 +100,58 @@ private static JsonSerializerOptions CreateDefaultOptions()
[JsonSerializable(typeof(TimeSpan))]
[JsonSerializable(typeof(DateTime))]
[JsonSerializable(typeof(DateTimeOffset))]
[JsonSerializable(typeof(Embedding))]
[JsonSerializable(typeof(Embedding<byte>))]
[JsonSerializable(typeof(Embedding<int>))]
#if NET
[JsonSerializable(typeof(Embedding<Half>))]
#endif
[JsonSerializable(typeof(Embedding<float>))]
[JsonSerializable(typeof(Embedding<double>))]
[JsonSerializable(typeof(AIContent))]

// AIFunction
[JsonSerializable(typeof(AIFunctionArguments))]

// Temporary workaround:
// These should be added in once they're no longer [Experimental] and included via [JsonDerivedType] on AIContent.
// IChatClient
[JsonSerializable(typeof(IEnumerable<ChatMessage>))]
[JsonSerializable(typeof(IList<ChatMessage>))]
[JsonSerializable(typeof(ChatMessage[]))]
[JsonSerializable(typeof(ChatOptions))]
[JsonSerializable(typeof(ChatClientMetadata))]
[JsonSerializable(typeof(ChatResponse))]
[JsonSerializable(typeof(ChatResponseUpdate))]
[JsonSerializable(typeof(IReadOnlyList<ChatResponseUpdate>))]
[JsonSerializable(typeof(Dictionary<string, object>))]
[JsonSerializable(typeof(IDictionary<string, object?>))]
[JsonSerializable(typeof(IEnumerable<string>))]
[JsonSerializable(typeof(AIContent))]
[JsonSerializable(typeof(IEnumerable<AIContent>))]

// Temporary workaround: These should be implicitly added in once they're no longer [Experimental]
// and are included via [JsonDerivedType] on AIContent.
[JsonSerializable(typeof(FunctionApprovalRequestContent))]
[JsonSerializable(typeof(FunctionApprovalResponseContent))]
[JsonSerializable(typeof(McpServerToolCallContent))]
[JsonSerializable(typeof(McpServerToolResultContent))]
[JsonSerializable(typeof(McpServerToolApprovalRequestContent))]
[JsonSerializable(typeof(McpServerToolApprovalResponseContent))]
[JsonSerializable(typeof(ResponseContinuationToken))]

// IEmbeddingGenerator
[JsonSerializable(typeof(EmbeddingGenerationOptions))]
[JsonSerializable(typeof(EmbeddingGeneratorMetadata))]
[JsonSerializable(typeof(Embedding))]
[JsonSerializable(typeof(Embedding<byte>))]
[JsonSerializable(typeof(Embedding<int>))]
#if NET
[JsonSerializable(typeof(Embedding<Half>))]
#endif
[JsonSerializable(typeof(Embedding<float>))]
[JsonSerializable(typeof(Embedding<double>))]

// ISpeechToTextClient
[JsonSerializable(typeof(SpeechToTextOptions))]
[JsonSerializable(typeof(SpeechToTextClientMetadata))]
[JsonSerializable(typeof(SpeechToTextResponse))]
[JsonSerializable(typeof(SpeechToTextResponseUpdate))]
[JsonSerializable(typeof(IReadOnlyList<SpeechToTextResponseUpdate>))]

// IImageGenerator
[JsonSerializable(typeof(ImageGenerationOptions))]
[JsonSerializable(typeof(ImageGenerationResponse))]

[EditorBrowsable(EditorBrowsableState.Never)] // Never use JsonContext directly, use DefaultOptions instead.
private sealed partial class JsonContext : JsonSerializerContext;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -760,15 +760,27 @@ internal static IEnumerable<ResponseItem> ToOpenAIResponseItems(IEnumerable<Chat

case FunctionResultContent resultContent:
string? result = resultContent.Result as string;
if (result is null && resultContent.Result is not null)
if (result is null && resultContent.Result is { } resultObj)
{
try
switch (resultObj)
{
result = JsonSerializer.Serialize(resultContent.Result, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object)));
}
catch (NotSupportedException)
{
// If the type can't be serialized, skip it.
// https://github.com/openai/openai-dotnet/issues/759
// Once OpenAI supports other forms of tool call outputs, special-case various AIContent types here, e.g.
// case DataContent
// case HostedFileContent
// case IEnumerable<AIContent>
// etc.

default:
try
{
result = JsonSerializer.Serialize(resultContent.Result, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object)));
}
catch (NotSupportedException)
{
// If the type can't be serialized, skip it.
}
break;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics.CodeAnalysis;
using System.Net.Http;
using System.Text;
using System.Text.Json.Nodes;
Expand Down Expand Up @@ -85,12 +86,24 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
return new() { Content = new StringContent(_expectedOutput) };
}

public static string? RemoveWhiteSpace(string? text) =>
text is null ? null :
Regex.Replace(text, @"\s*", string.Empty);
[return: NotNullIfNotNull(nameof(text))]
public static string? RemoveWhiteSpace(string? text)
{
if (text is null)
{
return null;
}

text = text.Replace("\\r", "").Replace("\\n", "").Replace("\\t", "");

return Regex.Replace(text, @"\s*", string.Empty);
}

private static void AssertEqualNormalized(string expected, string actual)
{
expected = RemoveWhiteSpace(expected);
actual = RemoveWhiteSpace(actual);

// First try to compare as JSON.
JsonNode? expectedNode = null;
JsonNode? actualNode = null;
Expand All @@ -114,10 +127,7 @@ private static void AssertEqualNormalized(string expected, string actual)
}

// Legitimately may not have been JSON. Fall back to whitespace normalization.
if (RemoveWhiteSpace(expected) != RemoveWhiteSpace(actual))
{
FailNotEqual(expected, actual);
}
FailNotEqual(expected, actual);
}

private static void FailNotEqual(string expected, string actual) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1383,26 +1383,26 @@ public async Task AssistantMessageWithBothToolsAndContent_NonStreaming()
"tool_calls": [
{
"id": "12345",
"type": "function",
"function": {
"name": "SayHello",
"arguments": "null"
},
"type": "function"
}
},
{
"id": "12346",
"type": "function",
"function": {
"name": "SayHi",
"arguments": "null"
},
"type": "function"
}
}
]
},
{
"role": "tool",
"tool_call_id": "12345",
"content": "Said hello"
"content": "{ \"$type\": \"text\", \"text\": \"Said hello\" }"
},
{
"role":"tool",
Expand Down Expand Up @@ -1471,7 +1471,7 @@ public async Task AssistantMessageWithBothToolsAndContent_NonStreaming()
]),
new (ChatRole.Tool,
[
new FunctionResultContent("12345", "Said hello"),
new FunctionResultContent("12345", new TextContent("Said hello")),
new FunctionResultContent("12346", "Said hi"),
]),
new(ChatRole.Assistant, "You are great."),
Expand Down
Loading
Loading