diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContentRegistry.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContentRegistry.cs new file mode 100644 index 00000000000..f430608b034 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContentRegistry.cs @@ -0,0 +1,173 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Reflection; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable SA1204 // Static elements should appear before instance elements + +namespace Microsoft.Extensions.AI.Contents; + +/// +/// Provides a global registry for custom AI content types and their +/// discriminator IDs for use in System.Text.Json polymorphic serialization. +/// +public static class AIContentRegistry +{ + private static readonly ConcurrentDictionary _registry = new(); + private static readonly Dictionary _discriminatorIdToType = typeof(AIContent) + .GetCustomAttributes() + .ToDictionary(attr => (string)attr.TypeDiscriminator!, attr => attr.DerivedType); + + /// + /// Registers a custom AI content type with a discriminator ID. + /// + /// The custom content type to be generated. + /// The type discriminator associated with the type. + /// The contract resolver used for the specified derived type. + public static void RegisterCustomAIContentType(string typeDiscriminatorId, IJsonTypeInfoResolver? resolver = null) + where TContent : AIContent + { + _ = Throw.IfNull(typeDiscriminatorId); + RegisterCore(typeof(TContent), typeDiscriminatorId, resolver); + } + + /// + /// Registers a custom AI content type with a discriminator ID. + /// + /// The custom content type to be generated. + /// The type discriminator associated with the type. + /// The contract resolver used for the specified derived type. + public static void RegisterCustomAIContentType(Type contentType, string typeDiscriminatorId, IJsonTypeInfoResolver? resolver = null) + { + _ = Throw.IfNull(contentType); + _ = Throw.IfNull(typeDiscriminatorId); + + if (!typeof(AIContent).IsAssignableFrom(contentType)) + { + Throw.ArgumentException(nameof(contentType), "The content type must derive from AIContent."); + } + + RegisterCore(contentType, typeDiscriminatorId, resolver); + } + + /// + /// Creates a wrapper that applies the configuration of the registry over the specified resolver. + /// + /// The underlying resolver over which to apply configuration from the registry. + /// A new that applies the configuration from the registry. + public static IJsonTypeInfoResolver ApplyAIContentRegistry(this IJsonTypeInfoResolver resolver) + { + _ = Throw.IfNull(resolver); + return new AIContentRegistryResolver(resolver); + } + + private static void RegisterCore(Type contentType, string typeDiscriminatorId, IJsonTypeInfoResolver? resolver) + { + if (contentType.Assembly == typeof(AIContent).Assembly) + { + Throw.ArgumentException(nameof(contentType), "Cannot register built-in AI content types."); + } + + ValidateConfiguration(contentType, typeDiscriminatorId, resolver, out bool alreadyRegistered); + if (alreadyRegistered) + { + return; + } + + lock (_registry) + { + ValidateConfiguration(contentType, typeDiscriminatorId, resolver, out alreadyRegistered); + if (alreadyRegistered) + { + return; + } + + bool success = _registry.TryAdd(contentType, (typeDiscriminatorId, resolver)); + _discriminatorIdToType.Add(typeDiscriminatorId, contentType); + Debug.Assert(success, "must not conflict with other entries."); + } + + static void ValidateConfiguration(Type contentType, string typeDiscriminatorId, IJsonTypeInfoResolver? resolver, out bool alreadyRegistered) + { + alreadyRegistered = false; + if (_registry.TryGetValue(contentType, out var existing)) + { + if (existing == (typeDiscriminatorId, resolver)) + { + // We have an equivalent registration, return early. + alreadyRegistered = true; + return; + } + + throw new InvalidOperationException($"The content type '{contentType.FullName}' has already been registered with conflicting configuration."); + } + + if (_discriminatorIdToType.TryGetValue(typeDiscriminatorId, out Type? existingType)) + { + throw new InvalidOperationException($"The discriminator ID '{typeDiscriminatorId}' conflicts with that of '{existingType}'."); + } + } + } + + private sealed class AIContentRegistryResolver(IJsonTypeInfoResolver underlying) : IJsonTypeInfoResolver + { + public JsonTypeInfo? GetTypeInfo(Type type, JsonSerializerOptions options) + { + JsonTypeInfo? typeInfo = GetTypeInfoCore(type, options); + + if (typeInfo is not null && typeInfo.Type == typeof(AIContent)) + { + ModifyAIContentTypeInfo(typeInfo); + } + + return typeInfo; + } + + private JsonTypeInfo? GetTypeInfoCore(Type type, JsonSerializerOptions options) + { + JsonTypeInfo? typeInfo = underlying.GetTypeInfo(type, options); + if (typeInfo is not null) + { + return typeInfo; + } + + foreach (var kvp in _registry) + { + if (kvp.Value.Resolver is { } resolver) + { + typeInfo = resolver.GetTypeInfo(type, options); + if (typeInfo is not null) + { + return typeInfo; + } + } + } + + return null; + } + + private static void ModifyAIContentTypeInfo(JsonTypeInfo typeInfo) + { + Debug.Assert(typeInfo.Type == typeof(AIContent), "Should only be used for AIContent types."); + if (typeInfo.PolymorphismOptions is null) + { + Debug.Assert(typeInfo.Kind is JsonTypeInfoKind.None, "A custom converter should have been applied for the type."); + return; + } + + foreach (var entry in _registry) + { + typeInfo.PolymorphismOptions.DerivedTypes.Add(new(entry.Key, entry.Value.DiscriminatorId)); + } + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs index de2c2a695b6..aba0bcc65a8 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs @@ -8,6 +8,7 @@ using System.Text.Json.Nodes; using System.Text.Json.Serialization; using System.Text.Json.Serialization.Metadata; +using Microsoft.Extensions.AI.Contents; namespace Microsoft.Extensions.AI; @@ -25,24 +26,29 @@ private static JsonSerializerOptions CreateDefaultOptions() // and we want to be flexible in terms of what can be put into the various collections in the object model. // Otherwise, use the source-generated options to enable trimming and Native AOT. + JsonSerializerOptions options; + if (JsonSerializer.IsReflectionEnabledByDefault) { // Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext below. - JsonSerializerOptions options = new(JsonSerializerDefaults.Web) + options = new(JsonSerializerDefaults.Web) { - TypeInfoResolver = new DefaultJsonTypeInfoResolver(), + TypeInfoResolver = new DefaultJsonTypeInfoResolver().ApplyAIContentRegistry(), Converters = { new JsonStringEnumConverter() }, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, WriteIndented = true, }; - - options.MakeReadOnly(); - return options; } else { - return JsonContext.Default.Options; + options = new(JsonContext.Default.Options) + { + TypeInfoResolver = JsonContext.Default.ApplyAIContentRegistry() + }; } + + options.MakeReadOnly(); + return options; } // Keep in sync with CreateDefaultOptions above. diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/AIContentRegistryTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/AIContentRegistryTests.cs new file mode 100644 index 00000000000..1bddd848486 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/AIContentRegistryTests.cs @@ -0,0 +1,72 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; +using Xunit; + +namespace Microsoft.Extensions.AI.Contents; + +public static partial class AIContentRegistryTests +{ + [Fact] + public static void DerivedAIContent_SerializeUsingRegistry() + { + JsonSerializerOptions options = AIJsonUtilities.DefaultOptions; + + AIContentRegistry.RegisterCustomAIContentType("derivativeContent", DerivedAIContentContext.Default); + AIContent c = new DerivedAIContent(); + + JsonElement expectedJson = JsonDocument.Parse("""{"$type":"derivativeContent"}""").RootElement; + JsonElement json = JsonSerializer.SerializeToElement(c, options); + Assert.True(JsonElement.DeepEquals(expectedJson, json)); + + AIContent? deserialized = JsonSerializer.Deserialize(json, options); + Assert.IsType(deserialized); + } + + private sealed class DerivedAIContent : AIContent; + + [JsonSerializable(typeof(DerivedAIContent))] + private partial class DerivedAIContentContext : JsonSerializerContext; + + [Fact] + public static void RegisterCustomAIContentType_NonAIContent_ThrowsArgumentException() + { + Assert.Throws(() => AIContentRegistry.RegisterCustomAIContentType(typeof(int), "discriminator")); + Assert.Throws(() => AIContentRegistry.RegisterCustomAIContentType(typeof(object), "discriminator")); + Assert.Throws(() => AIContentRegistry.RegisterCustomAIContentType(typeof(ChatMessage), "discriminator")); + } + + [Fact] + public static void RegisterCustomAIContentType_BuildInAIContent_ThrowsArgumentException() + { + Assert.Throws(() => AIContentRegistry.RegisterCustomAIContentType("discriminator")); + Assert.Throws(() => AIContentRegistry.RegisterCustomAIContentType("discriminator")); + } + + [Fact] + public static void RegisterCustomAIContentType_ConflictingIdentifier_ThrowsInvalidOperationException() + { + Assert.Throws(() => AIContentRegistry.RegisterCustomAIContentType("text")); + Assert.Throws(() => AIContentRegistry.RegisterCustomAIContentType("audio")); + + AIContentRegistry.RegisterCustomAIContentType("discriminator"); + AIContentRegistry.RegisterCustomAIContentType("discriminator"); // Matching configurations are idempotent. + Assert.Throws(() => AIContentRegistry.RegisterCustomAIContentType("discriminator2")); + } + + private sealed class DerivedAIContent2 : AIContent; + + [Fact] + public static void NullArguments_ThrowsArgumentNullException() + { + Assert.Throws(() => AIContentRegistry.RegisterCustomAIContentType(null!)); + Assert.Throws(() => AIContentRegistry.RegisterCustomAIContentType(typeof(DerivedAIContent3), null!)); + Assert.Throws(() => AIContentRegistry.RegisterCustomAIContentType(null!, "discriminator")); + Assert.Throws(() => AIContentRegistry.ApplyAIContentRegistry(null!)); + } + + private sealed class DerivedAIContent3 : AIContent; +}