diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/BinaryEmbedding.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/BinaryEmbedding.cs
new file mode 100644
index 00000000000..2261fd97949
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/BinaryEmbedding.cs
@@ -0,0 +1,111 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Buffers;
+using System.Collections;
+using System.ComponentModel;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using Microsoft.Shared.Diagnostics;
+
+namespace Microsoft.Extensions.AI;
+
+/// <summary>Represents an embedding composed of a bit vector.</summary>
+public sealed class BinaryEmbedding : Embedding
+{
+    /// <summary>The embedding vector this embedding represents.</summary>
+    private BitArray _vector;
+
+    /// <summary>Initializes a new instance of the <see cref="BinaryEmbedding"/> class with the embedding vector.</summary>
+    /// <param name="vector">The embedding vector this embedding represents.</param>
+    /// <exception cref="ArgumentNullException"><paramref name="vector"/> is <see langword="null"/>.</exception>
+    public BinaryEmbedding(BitArray vector)
+    {
+        _vector = Throw.IfNull(vector);
+    }
+
+    /// <summary>Gets or sets the embedding vector this embedding represents.</summary>
+    [JsonConverter(typeof(VectorConverter))]
+    public BitArray Vector
+    {
+        get => _vector;
+        set => _vector = Throw.IfNull(value);
+    }
+
+    /// <inheritdoc />
+    [JsonIgnore]
+    public override int Dimensions => _vector.Length;
+
+    /// <summary>Provides a <see cref="JsonConverter{BitArray}"/> that reads and writes a <see cref="BitArray"/> as a JSON string of '0' and '1' characters.</summary>
+    [EditorBrowsable(EditorBrowsableState.Never)]
+    public sealed class VectorConverter : JsonConverter<BitArray>
+    {
+        /// <inheritdoc/>
+        public override BitArray Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+        {
+            _ = Throw.IfNull(typeToConvert);
+            _ = Throw.IfNull(options);
+
+            if (reader.TokenType != JsonTokenType.String)
+            {
+                throw new JsonException("Expected string property.");
+            }
+
+            ReadOnlySpan<byte> utf8;
+            byte[]? tmpArray = null;
+            if (!reader.HasValueSequence && !reader.ValueIsEscaped)
+            {
+                utf8 = reader.ValueSpan; // Contiguous and unescaped: read the raw UTF-8 in place.
+            }
+            else
+            {
+                // This path should be rare: the value is segmented or escaped, so copy the unescaped text into a pooled buffer.
+                int length = reader.HasValueSequence ? checked((int)reader.ValueSequence.Length) : reader.ValueSpan.Length;
+                tmpArray = ArrayPool<byte>.Shared.Rent(length);
+                utf8 = tmpArray.AsSpan(0, reader.CopyString(tmpArray));
+            }
+
+            BitArray result = new(utf8.Length); // One bit per character of the string.
+
+            for (int i = 0; i < utf8.Length; i++)
+            {
+                result[i] = utf8[i] switch
+                {
+                    (byte)'0' => false,
+                    (byte)'1' => true,
+                    _ => throw new JsonException("Expected binary character sequence.")
+                };
+            }
+
+            if (tmpArray is not null)
+            {
+                ArrayPool<byte>.Shared.Return(tmpArray);
+            }
+
+            return result;
+        }
+
+        /// <inheritdoc/>
+        public override void Write(Utf8JsonWriter writer, BitArray value, JsonSerializerOptions options)
+        {
+            _ = Throw.IfNull(writer);
+            _ = Throw.IfNull(value);
+            _ = Throw.IfNull(options);
+
+            int length = value.Length;
+
+            byte[] tmpArray = ArrayPool<byte>.Shared.Rent(length);
+
+            Span<byte> utf8 = tmpArray.AsSpan(0, length); // Rent may return a larger array; only use the first 'length' bytes.
+            for (int i = 0; i < utf8.Length; i++)
+            {
+                utf8[i] = value[i] ? (byte)'1' : (byte)'0';
+            }
+
+            writer.WriteStringValue(utf8);
+
+            ArrayPool<byte>.Shared.Return(tmpArray);
+        }
+    }
+}
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding.cs
index 19b8feaa182..d6596e1e53e 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding.cs
@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System;
+using System.Diagnostics;
using System.Text.Json.Serialization;
namespace Microsoft.Extensions.AI;
@@ -9,13 +10,15 @@ namespace Microsoft.Extensions.AI;
/// Represents an embedding generated by a .
/// This base class provides metadata about the embedding. Derived types provide the concrete data contained in the embedding.
[JsonPolymorphic(TypeDiscriminatorPropertyName = "$type")]
+[JsonDerivedType(typeof(BinaryEmbedding), typeDiscriminator: "binary")]
+[JsonDerivedType(typeof(Embedding<byte>), typeDiscriminator: "uint8")]
+[JsonDerivedType(typeof(Embedding<sbyte>), typeDiscriminator: "int8")]
#if NET
-[JsonDerivedType(typeof(Embedding<Half>), typeDiscriminator: "halves")]
+[JsonDerivedType(typeof(Embedding<Half>), typeDiscriminator: "float16")]
#endif
-[JsonDerivedType(typeof(Embedding<float>), typeDiscriminator: "floats")]
-[JsonDerivedType(typeof(Embedding<double>), typeDiscriminator: "doubles")]
-[JsonDerivedType(typeof(Embedding<byte>), typeDiscriminator: "bytes")]
-[JsonDerivedType(typeof(Embedding<sbyte>), typeDiscriminator: "sbytes")]
+[JsonDerivedType(typeof(Embedding<float>), typeDiscriminator: "float32")]
+[JsonDerivedType(typeof(Embedding<double>), typeDiscriminator: "float64")]
+[DebuggerDisplay("Dimensions = {Dimensions}")]
public class Embedding
{
/// Initializes a new instance of the class.
@@ -26,6 +29,13 @@ protected Embedding()
/// Gets or sets a timestamp at which the embedding was created.
public DateTimeOffset? CreatedAt { get; set; }
+ /// <summary>Gets the dimensionality of the embedding vector.</summary>
+ /// <remarks>
+ /// This value corresponds to the number of elements in the embedding vector.
+ /// </remarks>
+ [JsonIgnore]
+ public virtual int Dimensions { get; }
+
/// Gets or sets the model ID using in the creation of the embedding.
public string? ModelId { get; set; }
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding{T}.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding{T}.cs
index c80e20dfda4..22bc02f2f3f 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding{T}.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding{T}.cs
@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System;
+using System.Text.Json.Serialization;
namespace Microsoft.Extensions.AI;
@@ -19,4 +20,8 @@ public Embedding(ReadOnlyMemory vector)
/// Gets or sets the embedding vector this embedding represents.
public ReadOnlyMemory Vector { get; set; }
+
+ /// <inheritdoc />
+ [JsonIgnore]
+ public override int Dimensions => Vector.Length;
}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/BinaryEmbeddingTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/BinaryEmbeddingTests.cs
new file mode 100644
index 00000000000..c75d715466e
--- /dev/null
+++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/BinaryEmbeddingTests.cs
@@ -0,0 +1,95 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections;
+using System.Linq;
+using System.Text.Json;
+using Xunit;
+
+namespace Microsoft.Extensions.AI;
+
+public class BinaryEmbeddingTests
+{
+    [Fact]
+    public void Ctor_Roundtrips()
+    {
+        BitArray vector = new BitArray(new bool[] { false, true, false, true });
+
+        BinaryEmbedding e = new(vector);
+        Assert.Same(vector, e.Vector);
+        Assert.Null(e.ModelId);
+        Assert.Null(e.CreatedAt);
+        Assert.Null(e.AdditionalProperties);
+    }
+
+    [Fact]
+    public void Properties_Roundtrips()
+    {
+        BitArray vector = new BitArray(new bool[] { false, true, false, true });
+
+        BinaryEmbedding e = new(vector);
+
+        Assert.Same(vector, e.Vector);
+        BitArray newVector = new BitArray(new bool[] { true, false, true, false });
+        e.Vector = newVector;
+        Assert.Same(newVector, e.Vector);
+
+        Assert.Null(e.ModelId);
+        e.ModelId = "text-embedding-3-small";
+        Assert.Equal("text-embedding-3-small", e.ModelId);
+
+        Assert.Null(e.CreatedAt);
+        DateTimeOffset createdAt = DateTimeOffset.Parse("2022-01-01T00:00:00Z");
+        e.CreatedAt = createdAt;
+        Assert.Equal(createdAt, e.CreatedAt);
+
+        Assert.Null(e.AdditionalProperties);
+        AdditionalPropertiesDictionary props = new();
+        e.AdditionalProperties = props;
+        Assert.Same(props, e.AdditionalProperties);
+    }
+
+    [Fact]
+    public void Serialization_Roundtrips()
+    {
+        foreach (int length in Enumerable.Range(0, 64).Concat(new[] { 10_000 }))
+        {
+            bool[] bools = new bool[length];
+            Random r = new(42); // Fixed seed keeps the generated bit pattern deterministic.
+            for (int i = 0; i < length; i++)
+            {
+                bools[i] = r.Next(2) != 0;
+            }
+
+            BitArray vector = new BitArray(bools);
+            BinaryEmbedding e = new(vector);
+
+            string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding);
+            Assert.Equal($$"""{"$type":"binary","vector":"{{string.Concat(vector.Cast<bool>().Select(b => b ? '1' : '0'))}}"}""", json);
+
+            BinaryEmbedding result = Assert.IsType<BinaryEmbedding>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding));
+            Assert.Equal(e.Vector, result.Vector);
+        }
+    }
+
+    [Fact]
+    public void Derialization_SupportsEncodedBits()
+    {
+        BinaryEmbedding result = Assert.IsType<BinaryEmbedding>(JsonSerializer.Deserialize(
+            """{"$type":"binary","vector":"\u0030\u0031\u0030\u0031\u0030\u0031"}""",
+            TestJsonSerializerContext.Default.Embedding));
+
+        Assert.Equal(new BitArray(new[] { false, true, false, true, false, true }), result.Vector);
+    }
+
+    [Theory]
+    [InlineData("""{"$type":"binary","vector":"\u0030\u0032"}""")] // escaped form of "02"
+    [InlineData("""{"$type":"binary","vector":"02"}""")] // '2' is not a bit
+    [InlineData("""{"$type":"binary","vector":" "}""")] // whitespace is not a bit
+    [InlineData("""{"$type":"binary","vector":10101}""")] // number token instead of string
+    public void Derialization_InvalidBinaryEmbedding_Throws(string json)
+    {
+        Assert.Throws<JsonException>(() => JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding));
+    }
+}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingTests.cs
index 45fcce8ba63..c3809782006 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingTests.cs
@@ -14,7 +14,7 @@ public class EmbeddingTests
public void Embedding_Ctor_Roundtrips()
{
float[] floats = [1f, 2f, 3f];
- UsageDetails usage = new();
+
AdditionalPropertiesDictionary props = [];
var createdAt = DateTimeOffset.Parse("2022-01-01T00:00:00Z");
const string Model = "text-embedding-3-small";
@@ -35,6 +35,32 @@ public void Embedding_Ctor_Roundtrips()
Assert.Same(floats, array.Array);
}
+ [Fact]
+ public void Embedding_Byte_SerializationRoundtrips()
+ {
+     byte[] bytes = [1, 2, 3];
+     Embedding<byte> e = new(bytes);
+
+     string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding);
+     Assert.Equal("""{"$type":"uint8","vector":"AQID"}""", json); // "AQID" is the Base64 form of bytes 1, 2, 3.
+
+     Embedding<byte> result = Assert.IsType<Embedding<byte>>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding));
+     Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray());
+ }
+
+ [Fact]
+ public void Embedding_SByte_SerializationRoundtrips()
+ {
+     sbyte[] bytes = [1, 2, 3];
+     Embedding<sbyte> e = new(bytes);
+
+     string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding);
+     Assert.Equal("""{"$type":"int8","vector":[1,2,3]}""", json); // Signed bytes are written as a JSON number array.
+
+     Embedding<sbyte> result = Assert.IsType<Embedding<sbyte>>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding));
+     Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray());
+ }
+
#if NET
[Fact]
public void Embedding_Half_SerializationRoundtrips()
@@ -43,7 +69,7 @@ public void Embedding_Half_SerializationRoundtrips()
Embedding e = new(halfs);
string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding);
- Assert.Equal("""{"$type":"halves","vector":[1,2,3]}""", json);
+ Assert.Equal("""{"$type":"float16","vector":[1,2,3]}""", json);
Embedding result = Assert.IsType>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding));
Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray());
@@ -57,7 +83,7 @@ public void Embedding_Single_SerializationRoundtrips()
Embedding e = new(floats);
string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding);
- Assert.Equal("""{"$type":"floats","vector":[1,2,3]}""", json);
+ Assert.Equal("""{"$type":"float32","vector":[1,2,3]}""", json);
Embedding result = Assert.IsType>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding));
Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray());
@@ -70,7 +96,7 @@ public void Embedding_Double_SerializationRoundtrips()
Embedding e = new(floats);
string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding);
- Assert.Equal("""{"$type":"doubles","vector":[1,2,3]}""", json);
+ Assert.Equal("""{"$type":"float64","vector":[1,2,3]}""", json);
Embedding result = Assert.IsType>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding));
Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray());
diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/BinaryEmbedding.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/BinaryEmbedding.cs
deleted file mode 100644
index f538d1476b0..00000000000
--- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/BinaryEmbedding.cs
+++ /dev/null
@@ -1,16 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-
-using System;
-
-namespace Microsoft.Extensions.AI;
-
-internal sealed class BinaryEmbedding : Embedding
-{
- public BinaryEmbedding(ReadOnlyMemory bits)
- {
- Bits = bits;
- }
-
- public ReadOnlyMemory Bits { get; }
-}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs
index 1188e899e4d..1504d0d2488 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs
@@ -2,6 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System;
+#if NET
+using System.Collections;
+#endif
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
@@ -148,7 +151,14 @@ public async Task Quantization_Binary_EmbeddingsCompareSuccessfully()
{
for (int j = 0; j < embeddings.Count; j++)
{
- distances[i, j] = TensorPrimitives.HammingBitDistance(embeddings[i].Bits.Span, embeddings[j].Bits.Span);
+ distances[i, j] = TensorPrimitives.HammingBitDistance(ToArray(embeddings[i].Vector), ToArray(embeddings[j].Vector));
+
+ // Copies the bits of 'array' into a newly allocated byte array.
+ static byte[] ToArray(BitArray array)
+ {
+     // Round the bit count up to whole bytes so a trailing partial byte is kept.
+     int byteCount = (array.Length + 7) / 8;
+     byte[] packed = new byte[byteCount];
+     array.CopyTo(packed, 0);
+     return packed;
+ }
}
}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs
index 3bf33988146..ea87408da38 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs
@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System;
+using System.Collections;
using System.Collections.Generic;
using System.Linq;
#if NET
@@ -46,12 +47,12 @@ private static BinaryEmbedding QuantizeToBinary(Embedding embedding)
{
ReadOnlySpan vector = embedding.Vector.Span;
- var result = new byte[(int)Math.Ceiling(vector.Length / 8.0)];
+ var result = new BitArray(vector.Length);
for (int i = 0; i < vector.Length; i++)
{
if (vector[i] > 0)
{
- result[i / 8] |= (byte)(1 << (i % 8));
+ // BitArray is indexed per bit, so dimension i maps to bit i. The old "i / 8" byte index no longer applies.
+ result[i] = true;
}
}