diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.cs new file mode 100644 index 00000000000..2596b86cd48 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.cs @@ -0,0 +1,72 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable S1121 // Assignments should not be made from within sub-expressions + +namespace Microsoft.Extensions.AI; + +public static partial class AIJsonUtilities +{ + /// + /// Adds a custom content type to the polymorphic configuration for . + /// + /// The custom content type to configure. + /// The options instance to configure. + /// The type discriminator id for the content type. + /// or is . + /// is a built-in content type. + /// is a read-only instance. + public static void AddAIContentType(this JsonSerializerOptions options, string typeDiscriminatorId) + where TContent : AIContent + { + _ = Throw.IfNull(options); + _ = Throw.IfNull(typeDiscriminatorId); + + AddAIContentTypeCore(options, typeof(TContent), typeDiscriminatorId); + } + + /// + /// Adds a custom content type to the polymorphic configuration for . + /// + /// The options instance to configure. + /// The custom content type to configure. + /// The type discriminator id for the content type. + /// , , or is . + /// is a built-in content type or does not derived from . + /// is a read-only instance. + public static void AddAIContentType(this JsonSerializerOptions options, Type contentType, string typeDiscriminatorId) + { + _ = Throw.IfNull(options); + _ = Throw.IfNull(contentType); + _ = Throw.IfNull(typeDiscriminatorId); + + if (!typeof(AIContent).IsAssignableFrom(contentType)) + { + Throw.ArgumentException(nameof(contentType), "The content type must derive from AIContent."); + } + + AddAIContentTypeCore(options, contentType, typeDiscriminatorId); + } + + private static void AddAIContentTypeCore(JsonSerializerOptions options, Type contentType, string typeDiscriminatorId) + { + if (contentType.Assembly == typeof(AIContent).Assembly) + { + Throw.ArgumentException(nameof(contentType), "Cannot register built-in AI content types."); + } + + IJsonTypeInfoResolver resolver = options.TypeInfoResolver ?? DefaultOptions.TypeInfoResolver!; + options.TypeInfoResolver = resolver.WithAddedModifier(typeInfo => + { + if (typeInfo.Type == typeof(AIContent)) + { + (typeInfo.PolymorphismOptions ??= new()).DerivedTypes.Add(new(contentType, typeDiscriminatorId)); + } + }); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs index b156de3f18e..e79d2c9034e 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs @@ -296,4 +296,68 @@ public static void CreateJsonSchema_ValidateWithTestData(ITestData testData) JsonNode? serializedValue = JsonSerializer.SerializeToNode(testData.Value, testData.Type, options); SchemaTestHelpers.AssertDocumentMatchesSchema(schemaAsNode, serializedValue); } + + [Fact] + public static void AddAIContentType_DerivedAIContent() + { + JsonSerializerOptions options = new(); + options.AddAIContentType("derivativeContent"); + + AIContent c = new DerivedAIContent { DerivedValue = 42 }; + string json = JsonSerializer.Serialize(c, options); + Assert.Equal("""{"$type":"derivativeContent","DerivedValue":42,"AdditionalProperties":null}""", json); + + AIContent? deserialized = JsonSerializer.Deserialize(json, options); + Assert.IsType(deserialized); + } + + [Fact] + public static void AddAIContentType_ReadOnlyJsonSerializerOptions_ThrowsInvalidOperationException() + { + Assert.Throws(() => AIJsonUtilities.DefaultOptions.AddAIContentType("derivativeContent")); + } + + [Fact] + public static void AddAIContentType_NonAIContent_ThrowsArgumentException() + { + JsonSerializerOptions options = new(); + Assert.Throws(() => options.AddAIContentType(typeof(int), "discriminator")); + Assert.Throws(() => options.AddAIContentType(typeof(object), "discriminator")); + Assert.Throws(() => options.AddAIContentType(typeof(ChatMessage), "discriminator")); + } + + [Fact] + public static void AddAIContentType_BuiltInAIContent_ThrowsArgumentException() + { + JsonSerializerOptions options = new(); + Assert.Throws(() => options.AddAIContentType("discriminator")); + Assert.Throws(() => options.AddAIContentType("discriminator")); + } + + [Fact] + public static void AddAIContentType_ConflictingIdentifier_ThrowsInvalidOperationException() + { + JsonSerializerOptions options = new(); + options.AddAIContentType("text"); + options.AddAIContentType("audio"); + + AIContent c = new DerivedAIContent(); + Assert.Throws(() => JsonSerializer.Serialize(c, options)); + } + + [Fact] + public static void AddAIContentType_NullArguments_ThrowsArgumentNullException() + { + JsonSerializerOptions options = new(); + Assert.Throws(() => ((JsonSerializerOptions)null!).AddAIContentType("discriminator")); + Assert.Throws(() => ((JsonSerializerOptions)null!).AddAIContentType(typeof(DerivedAIContent), "discriminator")); + Assert.Throws(() => options.AddAIContentType(null!)); + Assert.Throws(() => options.AddAIContentType(typeof(DerivedAIContent), null!)); + Assert.Throws(() => options.AddAIContentType(null!, "discriminator")); + } + + private class DerivedAIContent : AIContent + { + public int DerivedValue { get; set; } + } }