Skip to content

Commit

Permalink
.Net: First step to make ChatMessageContent more serialization-friend…
Browse files Browse the repository at this point in the history
…ly (#5131)

### Motivation and Context
Today, it is not possible to deserialize ChatMessageContent that has at
least one content item in the Items collection. The reason for this is
that the content items in the Items collection are referenced
polymorphically through the KernelContent abstract class. As a result,
the deserialization process fails because it cannot create an instance
of the abstract class. To solve the problem, the serialization process
should save type information when serializing the Items collection so
that the deserialization process can use this information to find the
type that was serialized and create its instance. Therefore, this PR
leverages [type
discriminators](https://learn.microsoft.com/en-us/dotnet/standard/serialization/system-text-json/polymorphism?pivots=dotnet-8-0)
to save the type information and use it during deserialization.

### Description
This PR registers the type discriminator for the KernelContent class and
whitelists its subclasses to participate in polymorphic deserialization.
On top of that, a few content types are modified to be
serializable/deserializable.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

---------

Co-authored-by: Dmytro Struk <[email protected]>
  • Loading branch information
SergeyMenshykh and dmytrostruk committed Feb 26, 2024
1 parent 8b16af6 commit f40ea59
Show file tree
Hide file tree
Showing 9 changed files with 278 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Xunit;
using Xunit.Abstractions;

namespace Examples;

public class Example85_ChatHistorySerialization : BaseTest
{
/// <summary>
/// Demonstrates how to serialize and deserialize <see cref="ChatHistory"/> class
/// with <see cref="ChatMessageContent"/> having SK various content types as items.
/// </summary>
[Fact]
public void SerializeChatHistoryWithSKContentTypes()
{
var data = new[] { 1, 2, 3 };

var message = new ChatMessageContent(AuthorRole.User, "Describe the factors contributing to climate change.");
message.Items = new ChatMessageContentItemCollection
{
new TextContent("Discuss the potential long-term consequences for the Earth's ecosystem as well."),
new ImageContent(new Uri("https://fake-random-test-host:123")),
new BinaryContent(new BinaryData(data)),
#pragma warning disable SKEXP0005
new AudioContent(new BinaryData(data))
#pragma warning restore SKEXP0005
};

var chatHistory = new ChatHistory(new[] { message });

var chatHistoryJson = JsonSerializer.Serialize(chatHistory);

var deserializedHistory = JsonSerializer.Deserialize<ChatHistory>(chatHistoryJson);

var deserializedMessage = deserializedHistory!.Single();

WriteLine($"Content: {deserializedMessage.Content}");
WriteLine($"Role: {deserializedMessage.Role.Label}");

WriteLine($"Text content: {(deserializedMessage.Items![0]! as TextContent)!.Text}");

WriteLine($"Image content: {(deserializedMessage.Items![1]! as ImageContent)!.Uri}");

WriteLine($"Binary content: {(deserializedMessage.Items![2]! as BinaryContent)!.Content}");

WriteLine($"Audio content: {(deserializedMessage.Items![3]! as AudioContent)!.Data}");
}

/// <summary>
/// Shows how to serialize and deserialize <see cref="ChatHistory"/> class with <see cref="ChatMessageContent"/> having custom content type as item.
/// </summary>
[Fact]
public void SerializeChatWithHistoryWithCustomContentType()
{
var message = new ChatMessageContent(AuthorRole.User, "Describe the factors contributing to climate change.");
message.Items = new ChatMessageContentItemCollection
{
new TextContent("Discuss the potential long-term consequences for the Earth's ecosystem as well."),
new CustomContent("Some custom content"),
};

var chatHistory = new ChatHistory(new[] { message });

// The custom resolver should be used to serialize and deserialize the chat history with custom .
var options = new JsonSerializerOptions
{
TypeInfoResolver = new CustomResolver()
};

var chatHistoryJson = JsonSerializer.Serialize(chatHistory, options);

var deserializedHistory = JsonSerializer.Deserialize<ChatHistory>(chatHistoryJson, options);

var deserializedMessage = deserializedHistory!.Single();

WriteLine($"Content: {deserializedMessage.Content}");
WriteLine($"Role: {deserializedMessage.Role.Label}");

WriteLine($"Text content: {(deserializedMessage.Items![0]! as TextContent)!.Text}");

WriteLine($"Custom content: {(deserializedMessage.Items![1]! as CustomContent)!.Content}");
}

public Example85_ChatHistorySerialization(ITestOutputHelper output) : base(output)
{
}

private sealed class CustomContent : KernelContent
{
public CustomContent(string content) : base(content)
{
Content = content;
}

public string Content { get; }
}

/// <summary>
/// The TypeResolver is used to serialize and deserialize custom content types polymorphically.
/// For more details, refer to the <see href="https://learn.microsoft.com/en-us/dotnet/standard/serialization/system-text-json/polymorphism?pivots=dotnet-8-0"/> article.
/// </summary>
private sealed class CustomResolver : DefaultJsonTypeInfoResolver
{
public override JsonTypeInfo GetTypeInfo(Type type, JsonSerializerOptions options)
{
var jsonTypeInfo = base.GetTypeInfo(type, options);

if (jsonTypeInfo.Type != typeof(KernelContent))
{
return jsonTypeInfo;
}

// It's possible to completely override the polymorphic configuration specified in the KernelContent class
// by using the '=' assignment operator instead of the ??= compound assignment one in the line below.
jsonTypeInfo.PolymorphismOptions ??= new JsonPolymorphismOptions();

// Add custom content type to the list of derived types declared on KernelContent class.
jsonTypeInfo.PolymorphismOptions.DerivedTypes.Add(new JsonDerivedType(typeof(CustomContent), "customContent"));

// Override type discriminator declared on KernelContent class as "$type", if needed.
jsonTypeInfo.PolymorphismOptions.TypeDiscriminatorPropertyName = "name";

return jsonTypeInfo;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
<Right>lib/netstandard2.0/Microsoft.SemanticKernel.Abstractions.dll</Right>
<IsBaselineSuppression>true</IsBaselineSuppression>
</Suppression>
<Suppression>
<DiagnosticId>CP0002</DiagnosticId>
<Target>M:Microsoft.SemanticKernel.ImageContent.#ctor(System.Uri,System.String,System.Object,System.Text.Encoding,System.Collections.Generic.IReadOnlyDictionary{System.String,System.Object})</Target>
<Left>lib/netstandard2.0/Microsoft.SemanticKernel.Abstractions.dll</Left>
<Right>lib/netstandard2.0/Microsoft.SemanticKernel.Abstractions.dll</Right>
<IsBaselineSuppression>true</IsBaselineSuppression>
</Suppression>
<Suppression>
<DiagnosticId>CP0002</DiagnosticId>
<Target>M:Microsoft.SemanticKernel.TextToAudio.ITextToAudioService.GetAudioContentAsync(System.String,Microsoft.SemanticKernel.PromptExecutionSettings,Microsoft.SemanticKernel.Kernel,System.Threading.CancellationToken)</Target>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json.Serialization;

namespace Microsoft.SemanticKernel;

Expand All @@ -24,6 +25,7 @@ public class AudioContent : KernelContent
/// <param name="modelId">The model ID used to generate the content.</param>
/// <param name="innerContent">Inner content,</param>
/// <param name="metadata">Additional metadata</param>
[JsonConstructor]
public AudioContent(
BinaryData data,
string? modelId = null,
Expand Down
18 changes: 12 additions & 6 deletions dotnet/src/SemanticKernel.Abstractions/Contents/BinaryContent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Text.Json.Serialization;
using System.Threading.Tasks;

namespace Microsoft.SemanticKernel;
Expand All @@ -13,7 +14,11 @@ namespace Microsoft.SemanticKernel;
public class BinaryContent : KernelContent
{
private readonly Func<Task<Stream>>? _streamProvider;
private readonly BinaryData? _content;

/// <summary>
/// The binary content.
/// </summary>
public BinaryData? Content { get; set; }

/// <summary>
/// Initializes a new instance of the <see cref="BinaryContent"/> class.
Expand All @@ -22,6 +27,7 @@ public class BinaryContent : KernelContent
/// <param name="modelId">The model ID used to generate the content</param>
/// <param name="innerContent">Inner content</param>
/// <param name="metadata">Additional metadata</param>
[JsonConstructor]
public BinaryContent(
BinaryData content,
string? modelId = null,
Expand All @@ -31,7 +37,7 @@ public class BinaryContent : KernelContent
{
Verify.NotNull(content, nameof(content));

this._content = content;
this.Content = content;
}

/// <summary>
Expand Down Expand Up @@ -71,9 +77,9 @@ public async Task<Stream> GetStreamAsync()
return await this._streamProvider.Invoke().ConfigureAwait(false);
}

if (this._content != null)
if (this.Content != null)
{
return this._content.ToStream();
return this.Content.ToStream();
}

throw new KernelException("Null content");
Expand All @@ -90,9 +96,9 @@ public async Task<BinaryData> GetContentAsync()
return await BinaryData.FromStreamAsync(stream).ConfigureAwait(false);
}

if (this._content != null)
if (this.Content != null)
{
return this._content;
return this.Content;
}

throw new KernelException("Null content");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

using System;
using System.Collections.Generic;
using System.Text;
using System.Text.Json.Serialization;

namespace Microsoft.SemanticKernel;

Expand All @@ -19,21 +19,20 @@ public sealed class ImageContent : KernelContent
/// <summary>
/// The image binary data.
/// </summary>
public BinaryData? Data { get; }
public BinaryData? Data { get; set; }

/// <summary>
/// Initializes a new instance of the <see cref="ImageContent"/> class.
/// </summary>
/// <param name="uri">The URI of image.</param>
/// <param name="modelId">The model ID used to generate the content</param>
/// <param name="innerContent">Inner content</param>
/// <param name="encoding">Encoding of the text</param>
/// <param name="metadata">Additional metadata</param>
[JsonConstructor]
public ImageContent(
Uri uri,
string? modelId = null,
object? innerContent = null,
Encoding? encoding = null,
IReadOnlyDictionary<string, object?>? metadata = null)
: base(innerContent, modelId, metadata)
{
Expand All @@ -46,13 +45,11 @@ public sealed class ImageContent : KernelContent
/// <param name="data">The Data used as DataUri for the image.</param>
/// <param name="modelId">The model ID used to generate the content</param>
/// <param name="innerContent">Inner content</param>
/// <param name="encoding">Encoding of the text</param>
/// <param name="metadata">Additional metadata</param>
public ImageContent(
BinaryData data,
string? modelId = null,
object? innerContent = null,
Encoding? encoding = null,
IReadOnlyDictionary<string, object?>? metadata = null)
: base(innerContent, modelId, metadata)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ namespace Microsoft.SemanticKernel;
/// <summary>
/// Base class for all AI non-streaming results
/// </summary>
[JsonPolymorphic(TypeDiscriminatorPropertyName = "$type")]
[JsonDerivedType(typeof(TextContent), typeDiscriminator: nameof(TextContent))]
[JsonDerivedType(typeof(ImageContent), typeDiscriminator: nameof(ImageContent))]
[JsonDerivedType(typeof(BinaryContent), typeDiscriminator: nameof(BinaryContent))]
#pragma warning disable SKEXP0005
[JsonDerivedType(typeof(AudioContent), typeDiscriminator: nameof(AudioContent))]
#pragma warning restore SKEXP0005
public abstract class KernelContent
{
/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public sealed class TextContent : KernelContent
/// <param name="innerContent">Inner content</param>
/// <param name="encoding">Encoding of the text</param>
/// <param name="metadata">Additional metadata</param>
[JsonConstructor]
public TextContent(string? text, string? modelId = null, object? innerContent = null, Encoding? encoding = null, IReadOnlyDictionary<string, object?>? metadata = null) : base(innerContent, modelId, metadata)
{
this.Text = text;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace SemanticKernel.UnitTests.AI.ChatCompletion;
public class ChatHistoryTests
{
[Fact]
public void ItCanBeSerialised()
public void ItCanBeSerialized()
{
// Arrange
var options = new JsonSerializerOptions();
Expand All @@ -29,7 +29,7 @@ public void ItCanBeSerialised()
}

[Fact]
public void ItCanBeDeserialised()
public void ItCanBeDeserialized()
{
// Arrange
var options = new JsonSerializerOptions();
Expand All @@ -39,15 +39,15 @@ public void ItCanBeDeserialised()
var chatHistoryJson = JsonSerializer.Serialize(chatHistory, options);

// Act
var chatHistoryDeserialised = JsonSerializer.Deserialize<ChatHistory>(chatHistoryJson, options);
var chatHistoryDeserialized = JsonSerializer.Deserialize<ChatHistory>(chatHistoryJson, options);

// Assert
Assert.NotNull(chatHistoryDeserialised);
Assert.Equal(chatHistory.Count, chatHistoryDeserialised.Count);
Assert.NotNull(chatHistoryDeserialized);
Assert.Equal(chatHistory.Count, chatHistoryDeserialized.Count);
for (var i = 0; i < chatHistory.Count; i++)
{
Assert.Equal(chatHistory[i].Role.Label, chatHistoryDeserialised[i].Role.Label);
Assert.Equal(chatHistory[i].Content, chatHistoryDeserialised[i].Content);
Assert.Equal(chatHistory[i].Role.Label, chatHistoryDeserialized[i].Role.Label);
Assert.Equal(chatHistory[i].Content, chatHistoryDeserialized[i].Content);
}
}
}
Loading

0 comments on commit f40ea59

Please sign in to comment.