diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/BinaryEmbedding.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/BinaryEmbedding.cs new file mode 100644 index 00000000000..2261fd97949 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/BinaryEmbedding.cs @@ -0,0 +1,111 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Buffers; +using System.Collections; +using System.ComponentModel; +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// <summary>Represents an embedding composed of a bit vector.</summary> +public sealed class BinaryEmbedding : Embedding +{ + /// <summary>The embedding vector this embedding represents.</summary> + private BitArray _vector; + + /// <summary>Initializes a new instance of the <see cref="BinaryEmbedding"/> class with the embedding vector.</summary> + /// <param name="vector">The embedding vector this embedding represents.</param> + /// <exception cref="ArgumentNullException"><paramref name="vector"/> is <see langword="null"/>.</exception> + public BinaryEmbedding(BitArray vector) + { + _vector = Throw.IfNull(vector); + } + + /// <summary>Gets or sets the embedding vector this embedding represents.</summary> + [JsonConverter(typeof(VectorConverter))] + public BitArray Vector + { + get => _vector; + set => _vector = Throw.IfNull(value); + } + + /// <inheritdoc /> + [JsonIgnore] + public override int Dimensions => _vector.Length; + + /// <summary>Provides a <see cref="JsonConverter{T}"/> for serializing <see cref="BitArray"/> instances.</summary> + [EditorBrowsable(EditorBrowsableState.Never)] + public sealed class VectorConverter : JsonConverter<BitArray> + { + /// <inheritdoc /> + public override BitArray Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + _ = Throw.IfNull(typeToConvert); + _ = Throw.IfNull(options); + + if (reader.TokenType != JsonTokenType.String) + { + throw new JsonException("Expected string property."); + } + + ReadOnlySpan<byte> utf8; + byte[]? 
tmpArray = null; + if (!reader.HasValueSequence && !reader.ValueIsEscaped) + { + utf8 = reader.ValueSpan; + } + else + { + // This path should be rare. + int length = reader.HasValueSequence ? checked((int)reader.ValueSequence.Length) : reader.ValueSpan.Length; + tmpArray = ArrayPool<byte>.Shared.Rent(length); + utf8 = tmpArray.AsSpan(0, reader.CopyString(tmpArray)); + } + + BitArray result = new(utf8.Length); + + for (int i = 0; i < utf8.Length; i++) + { + result[i] = utf8[i] switch + { + (byte)'0' => false, + (byte)'1' => true, + _ => throw new JsonException("Expected binary character sequence.") + }; + } + + if (tmpArray is not null) + { + ArrayPool<byte>.Shared.Return(tmpArray); + } + + return result; + } + + /// <inheritdoc /> + public override void Write(Utf8JsonWriter writer, BitArray value, JsonSerializerOptions options) + { + _ = Throw.IfNull(writer); + _ = Throw.IfNull(value); + _ = Throw.IfNull(options); + + int length = value.Length; + + byte[] tmpArray = ArrayPool<byte>.Shared.Rent(length); + + Span<byte> utf8 = tmpArray.AsSpan(0, length); + for (int i = 0; i < utf8.Length; i++) + { + utf8[i] = value[i] ? (byte)'1' : (byte)'0'; + } + + writer.WriteStringValue(utf8); + + ArrayPool<byte>.Shared.Return(tmpArray); + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding.cs index 19b8feaa182..d6596e1e53e 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Diagnostics; using System.Text.Json.Serialization; namespace Microsoft.Extensions.AI; @@ -9,13 +10,15 @@ namespace Microsoft.Extensions.AI; /// Represents an embedding generated by a . /// This base class provides metadata about the embedding. 
Derived types provide the concrete data contained in the embedding. [JsonPolymorphic(TypeDiscriminatorPropertyName = "$type")] +[JsonDerivedType(typeof(BinaryEmbedding), typeDiscriminator: "binary")] +[JsonDerivedType(typeof(Embedding<byte>), typeDiscriminator: "uint8")] +[JsonDerivedType(typeof(Embedding<sbyte>), typeDiscriminator: "int8")] #if NET -[JsonDerivedType(typeof(Embedding<Half>), typeDiscriminator: "halves")] +[JsonDerivedType(typeof(Embedding<Half>), typeDiscriminator: "float16")] #endif -[JsonDerivedType(typeof(Embedding<float>), typeDiscriminator: "floats")] -[JsonDerivedType(typeof(Embedding<double>), typeDiscriminator: "doubles")] -[JsonDerivedType(typeof(Embedding<byte>), typeDiscriminator: "bytes")] -[JsonDerivedType(typeof(Embedding<sbyte>), typeDiscriminator: "sbytes")] +[JsonDerivedType(typeof(Embedding<float>), typeDiscriminator: "float32")] +[JsonDerivedType(typeof(Embedding<double>), typeDiscriminator: "float64")] +[DebuggerDisplay("Dimensions = {Dimensions}")] public class Embedding { /// Initializes a new instance of the class. @@ -26,6 +29,13 @@ protected Embedding() /// Gets or sets a timestamp at which the embedding was created. public DateTimeOffset? CreatedAt { get; set; } + /// <summary>Gets the dimensionality of the embedding vector.</summary> + /// <remarks> + /// This value corresponds to the number of elements in the embedding vector. + /// </remarks> + [JsonIgnore] + public virtual int Dimensions { get; } + /// Gets or sets the model ID using in the creation of the embedding. public string? ModelId { get; set; } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding{T}.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding{T}.cs index c80e20dfda4..22bc02f2f3f 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding{T}.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding{T}.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. 
using System; +using System.Text.Json.Serialization; namespace Microsoft.Extensions.AI; @@ -19,4 +20,8 @@ public Embedding(ReadOnlyMemory<T> vector) /// Gets or sets the embedding vector this embedding represents. public ReadOnlyMemory<T> Vector { get; set; } + + /// <inheritdoc /> + [JsonIgnore] + public override int Dimensions => Vector.Length; } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/BinaryEmbeddingTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/BinaryEmbeddingTests.cs new file mode 100644 index 00000000000..c75d715466e --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/BinaryEmbeddingTests.cs @@ -0,0 +1,95 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections; +using System.Linq; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class BinaryEmbeddingTests +{ + [Fact] + public void Ctor_Roundtrips() + { + BitArray vector = new BitArray(new bool[] { false, true, false, true }); + + BinaryEmbedding e = new(vector); + Assert.Same(vector, e.Vector); + Assert.Null(e.ModelId); + Assert.Null(e.CreatedAt); + Assert.Null(e.AdditionalProperties); + } + + [Fact] + public void Properties_Roundtrips() + { + BitArray vector = new BitArray(new bool[] { false, true, false, true }); + + BinaryEmbedding e = new(vector); + + Assert.Same(vector, e.Vector); + BitArray newVector = new BitArray(new bool[] { true, false, true, false }); + e.Vector = newVector; + Assert.Same(newVector, e.Vector); + + Assert.Null(e.ModelId); + e.ModelId = "text-embedding-3-small"; + Assert.Equal("text-embedding-3-small", e.ModelId); + + Assert.Null(e.CreatedAt); + DateTimeOffset createdAt = DateTimeOffset.Parse("2022-01-01T00:00:00Z"); + e.CreatedAt = createdAt; + Assert.Equal(createdAt, e.CreatedAt); + + Assert.Null(e.AdditionalProperties); + 
AdditionalPropertiesDictionary props = new(); + e.AdditionalProperties = props; + Assert.Same(props, e.AdditionalProperties); + } + + [Fact] + public void Serialization_Roundtrips() + { + foreach (int length in Enumerable.Range(0, 64).Concat(new[] { 10_000 })) + { + bool[] bools = new bool[length]; + Random r = new(42); + for (int i = 0; i < length; i++) + { + bools[i] = r.Next(2) != 0; + } + + BitArray vector = new BitArray(bools); + BinaryEmbedding e = new(vector); + + string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding); + Assert.Equal($$"""{"$type":"binary","vector":"{{string.Concat(vector.Cast<bool>().Select(b => b ? '1' : '0'))}}"}""", json); + + BinaryEmbedding result = Assert.IsType<BinaryEmbedding>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding)); + Assert.Equal(e.Vector, result.Vector); + } + } + + [Fact] + public void Deserialization_SupportsEncodedBits() + { + BinaryEmbedding result = Assert.IsType<BinaryEmbedding>(JsonSerializer.Deserialize( + """{"$type":"binary","vector":"\u0030\u0031\u0030\u0031\u0030\u0031"}""", + TestJsonSerializerContext.Default.Embedding)); + + Assert.Equal(new BitArray(new[] { false, true, false, true, false, true }), result.Vector); + } + + [Theory] + [InlineData("""{"$type":"binary","vector":"\u0030\u0032"}""")] + [InlineData("""{"$type":"binary","vector":"02"}""")] + [InlineData("""{"$type":"binary","vector":" "}""")] + [InlineData("""{"$type":"binary","vector":10101}""")] + public void Deserialization_InvalidBinaryEmbedding_Throws(string json) + { + Assert.Throws<JsonException>(() => JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding)); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingTests.cs index 45fcce8ba63..c3809782006 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingTests.cs +++ 
b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingTests.cs @@ -14,7 +14,7 @@ public class EmbeddingTests public void Embedding_Ctor_Roundtrips() { float[] floats = [1f, 2f, 3f]; - UsageDetails usage = new(); + AdditionalPropertiesDictionary props = []; var createdAt = DateTimeOffset.Parse("2022-01-01T00:00:00Z"); const string Model = "text-embedding-3-small"; @@ -35,6 +35,32 @@ public void Embedding_Ctor_Roundtrips() Assert.Same(floats, array.Array); } + [Fact] + public void Embedding_Byte_SerializationRoundtrips() + { + byte[] bytes = [1, 2, 3]; + Embedding<byte> e = new(bytes); + + string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding); + Assert.Equal("""{"$type":"uint8","vector":"AQID"}""", json); + + Embedding<byte> result = Assert.IsType<Embedding<byte>>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding)); + Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray()); + } + + [Fact] + public void Embedding_SByte_SerializationRoundtrips() + { + sbyte[] bytes = [1, 2, 3]; + Embedding<sbyte> e = new(bytes); + + string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding); + Assert.Equal("""{"$type":"int8","vector":[1,2,3]}""", json); + + Embedding<sbyte> result = Assert.IsType<Embedding<sbyte>>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding)); + Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray()); + } + #if NET [Fact] public void Embedding_Half_SerializationRoundtrips() @@ -43,7 +69,7 @@ public void Embedding_Half_SerializationRoundtrips() Embedding<Half> e = new(halfs); string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding); - Assert.Equal("""{"$type":"halves","vector":[1,2,3]}""", json); + Assert.Equal("""{"$type":"float16","vector":[1,2,3]}""", json); Embedding<Half> result = Assert.IsType<Embedding<Half>>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding)); Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray()); @@ -57,7 +83,7 @@ public 
void Embedding_Single_SerializationRoundtrips() Embedding<float> e = new(floats); string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding); - Assert.Equal("""{"$type":"floats","vector":[1,2,3]}""", json); + Assert.Equal("""{"$type":"float32","vector":[1,2,3]}""", json); Embedding<float> result = Assert.IsType<Embedding<float>>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding)); Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray()); @@ -70,7 +96,7 @@ public void Embedding_Double_SerializationRoundtrips() Embedding<double> e = new(floats); string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding); - Assert.Equal("""{"$type":"doubles","vector":[1,2,3]}""", json); + Assert.Equal("""{"$type":"float64","vector":[1,2,3]}""", json); Embedding<double> result = Assert.IsType<Embedding<double>>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding)); Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray()); diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/BinaryEmbedding.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/BinaryEmbedding.cs deleted file mode 100644 index f538d1476b0..00000000000 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/BinaryEmbedding.cs +++ /dev/null @@ -1,16 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. 
- -using System; - -namespace Microsoft.Extensions.AI; - -internal sealed class BinaryEmbedding : Embedding -{ - public BinaryEmbedding(ReadOnlyMemory<byte> bits) - { - Bits = bits; - } - - public ReadOnlyMemory<byte> Bits { get; } -} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs index 1188e899e4d..1504d0d2488 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs @@ -2,6 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +#if NET +using System.Collections; +#endif using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; @@ -148,7 +151,14 @@ public async Task Quantization_Binary_EmbeddingsCompareSuccessfully() { for (int j = 0; j < embeddings.Count; j++) { - distances[i, j] = TensorPrimitives.HammingBitDistance(embeddings[i].Bits.Span, embeddings[j].Bits.Span); + distances[i, j] = TensorPrimitives.HammingBitDistance<byte>(ToArray(embeddings[i].Vector), ToArray(embeddings[j].Vector)); + + static byte[] ToArray(BitArray array) + { + byte[] result = new byte[(array.Length + 7) / 8]; + array.CopyTo(result, 0); + return result; + } } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs index 3bf33988146..ea87408da38 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. 
using System; +using System.Collections; using System.Collections.Generic; using System.Linq; #if NET @@ -46,12 +47,12 @@ private static BinaryEmbedding QuantizeToBinary(Embedding<float> embedding) { ReadOnlySpan<float> vector = embedding.Vector.Span; - var result = new byte[(int)Math.Ceiling(vector.Length / 8.0)]; + var result = new BitArray(vector.Length); for (int i = 0; i < vector.Length; i++) { if (vector[i] > 0) { - result[i / 8] |= (byte)(1 << (i % 8)); + result[i] = true; /* BitArray is indexed per bit: dimension i maps to bit i, not byte i / 8. */ } }