Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion Microsoft.Azure.Cosmos/src/Resource/Settings/Embedding.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,19 @@ public class Embedding : IEquatable<Embedding>
[JsonConverter(typeof(StringEnumConverter))]
public DistanceFunction DistanceFunction { get; set; }

/// <summary>
/// Gets or sets the optional <see cref="Cosmos.EmbeddingSource"/> describing the source
/// document paths and embedding service that the Cosmos DB service should use to
/// generate the vector value for this embedding.
/// </summary>
/// <remarks>
/// The value is serialized under the JSON name <c>embeddingSource</c> and is left out of
/// the payload when it is <c>null</c>. The property is public only when the
/// <c>PREVIEW</c> compilation symbol is defined; otherwise it is internal.
/// </remarks>
[JsonProperty(PropertyName = "embeddingSource", NullValueHandling = NullValueHandling.Ignore)]
#if PREVIEW
public
#else
internal
#endif
EmbeddingSource EmbeddingSource { get; set; }

/// <summary>
/// This contains additional values for scenarios where the SDK is not aware of new fields.
/// This ensures that if resource is read and updated none of the fields will be lost in the process.
Expand All @@ -68,10 +81,16 @@ public void ValidateEmbeddingPath()
/// <summary>
/// Determines whether this embedding equals another one by comparing the path, data type,
/// dimensions, distance function and embedding source.
/// </summary>
/// <param name="that">The embedding to compare with; may be <c>null</c>.</param>
/// <returns>
/// <c>true</c> when <paramref name="that"/> is not <c>null</c> and every compared member
/// matches; otherwise <c>false</c>.
/// </returns>
public bool Equals(Embedding that)
{
    if (that is null)
    {
        return false;
    }

    // The static object.Equals is null-safe on both sides, so a Path or EmbeddingSource
    // that has not been assigned yet is compared instead of causing a NullReferenceException.
    return object.Equals(this.Path, that.Path)
        && this.DataType.Equals(that.DataType)
        && this.Dimensions == that.Dimensions
        && this.DistanceFunction.Equals(that.DistanceFunction)
        && object.Equals(this.EmbeddingSource, that.EmbeddingSource);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//------------------------------------------------------------

namespace Microsoft.Azure.Cosmos
{
using System.Runtime.Serialization;

/// <summary>
/// Defines the authentication type used by the Azure Cosmos DB service to call the
/// embedding service referenced from a <see cref="EmbeddingSource"/>.
/// </summary>
/// <remarks>
/// Every member declares its string form through <see cref="EnumMemberAttribute"/>.
/// The numeric values are spelled out explicitly so that adding or reordering members
/// cannot silently change them. The type is public only when the <c>PREVIEW</c>
/// compilation symbol is defined; otherwise it is internal.
/// </remarks>
#if PREVIEW
public
#else
internal
#endif
enum EmbeddingAuthType
{
    /// <summary>
    /// Default sentinel — indicates no authentication type has been configured.
    /// </summary>
    [EnumMember(Value = "Unknown")]
    Unknown = 0,

    /// <summary>
    /// Authenticate to the embedding service using Microsoft Entra ID (managed identity / token credential).
    /// </summary>
    [EnumMember(Value = "Entra")]
    Entra = 1,

    /// <summary>
    /// Authenticate to the embedding service using an API key.
    /// </summary>
    [EnumMember(Value = "ApiKey")]
    ApiKey = 2,
}
}
118 changes: 118 additions & 0 deletions Microsoft.Azure.Cosmos/src/Resource/Settings/EmbeddingSource.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
//------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//------------------------------------------------------------

namespace Microsoft.Azure.Cosmos
{
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using Newtonsoft.Json;
using Newtonsoft.Json.Converters;
using Newtonsoft.Json.Linq;

/// <summary>
/// Describes the source document paths and the embedding service that the Azure Cosmos DB
/// service should use to generate the vector value for an <see cref="Embedding"/>.
/// </summary>
/// <remarks>
/// When present on an <see cref="Embedding"/>, this block tells the SDK (and the Cosmos
/// DB embedding provider) where and how to call the embedding service for the vector
/// path in question. Equality and hashing are value-based over <see cref="SourcePaths"/>,
/// <see cref="AuthType"/>, <see cref="DeploymentName"/>, <see cref="Endpoint"/> and
/// <see cref="ModelName"/>; <see cref="AdditionalProperties"/> does not take part.
/// </remarks>
#if PREVIEW
public
#else
internal
#endif
sealed class EmbeddingSource : IEquatable<EmbeddingSource>
{
    /// <summary>
    /// Gets or sets the list of document paths whose values are concatenated and sent to
    /// the embedding service to generate the vector.
    /// </summary>
    [JsonProperty(PropertyName = "sourcePaths")]
    public Collection<string> SourcePaths { get; set; }

    /// <summary>
    /// Gets or sets the deployment name of the embedding model on the embedding service.
    /// </summary>
    [JsonProperty(PropertyName = "deploymentName")]
    public string DeploymentName { get; set; }

    /// <summary>
    /// Gets or sets the name of the embedding model.
    /// </summary>
    [JsonProperty(PropertyName = "modelName")]
    public string ModelName { get; set; }

    /// <summary>
    /// Gets or sets the endpoint of the embedding service.
    /// </summary>
    [JsonProperty(PropertyName = "endpoint")]
    public string Endpoint { get; set; }

    /// <summary>
    /// Gets or sets the <see cref="Cosmos.EmbeddingAuthType"/> used to authenticate to the
    /// embedding service.
    /// </summary>
    [JsonProperty(PropertyName = "authType")]
    [JsonConverter(typeof(StringEnumConverter))]
    public EmbeddingAuthType AuthType { get; set; }

    /// <summary>
    /// This contains additional values for scenarios where the SDK is not aware of new fields.
    /// This ensures that if resource is read and updated none of the fields will be lost in the process.
    /// </summary>
    [JsonExtensionData]
    internal IDictionary<string, JToken> AdditionalProperties { get; private set; }

    /// <summary>
    /// Determines whether this instance and <paramref name="that"/> describe the same
    /// embedding source.
    /// </summary>
    /// <param name="that">The instance to compare with; may be <c>null</c>.</param>
    /// <returns>
    /// <c>true</c> when both instances hold the same source paths in the same order and the
    /// same auth type, deployment name, endpoint and model name; otherwise <c>false</c>.
    /// </returns>
    public bool Equals(EmbeddingSource that)
    {
        if (that is null)
        {
            return false;
        }

        if (object.ReferenceEquals(this, that))
        {
            return true;
        }

        return EmbeddingSource.SourcePathsEqual(this.SourcePaths, that.SourcePaths)
            && this.AuthType == that.AuthType
            && this.DeploymentName == that.DeploymentName
            && this.Endpoint == that.Endpoint
            && this.ModelName == that.ModelName;
    }

    /// <summary>
    /// Determines whether <paramref name="obj"/> is an <see cref="EmbeddingSource"/> equal to this one.
    /// </summary>
    /// <param name="obj">The object to compare with; may be <c>null</c> or of another type.</param>
    /// <returns><c>true</c> when <paramref name="obj"/> is an equal <see cref="EmbeddingSource"/>; otherwise <c>false</c>.</returns>
    public override bool Equals(object obj)
    {
        // A null or differently-typed object becomes null here and is rejected by the typed overload.
        return this.Equals(obj as EmbeddingSource);
    }

    /// <summary>
    /// Computes a hash code from the same members that <see cref="Equals(EmbeddingSource)"/> compares.
    /// </summary>
    /// <returns>The combined hash code.</returns>
    public override int GetHashCode()
    {
        // The multiply-accumulate is expected to wrap around; unchecked keeps it from
        // throwing an OverflowException when the assembly is built with checked arithmetic.
        unchecked
        {
            int hashCode = 1265339359;

            if (this.SourcePaths != null)
            {
                foreach (string sourcePath in this.SourcePaths)
                {
                    hashCode = (hashCode * -1521134295) + EqualityComparer<string>.Default.GetHashCode(sourcePath);
                }
            }

            hashCode = (hashCode * -1521134295) + this.AuthType.GetHashCode();
            hashCode = (hashCode * -1521134295) + EqualityComparer<string>.Default.GetHashCode(this.DeploymentName);
            hashCode = (hashCode * -1521134295) + EqualityComparer<string>.Default.GetHashCode(this.Endpoint);
            hashCode = (hashCode * -1521134295) + EqualityComparer<string>.Default.GetHashCode(this.ModelName);
            return hashCode;
        }
    }

    /// <summary>
    /// Compares two source-path collections: two nulls are equal, a null and a non-null are
    /// not, and two non-null collections are equal when their elements match in order.
    /// </summary>
    private static bool SourcePathsEqual(Collection<string> left, Collection<string> right)
    {
        if (left == null || right == null)
        {
            return left == null && right == null;
        }

        return Enumerable.SequenceEqual(left, right);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests
[TestClass]
public class ContainerSettingsTests : BaseCosmosClientHelper
{
/// <summary>
/// Converts a <see cref="DateTime"/> into the number of whole seconds elapsed since
/// 1970-01-01 00:00:00.
/// </summary>
/// <param name="dateTime">The point in time to convert; it is used as given, with no time-zone conversion.</param>
/// <returns>The elapsed seconds, with any fractional part truncated by the cast to <see cref="long"/>.</returns>
private static long ToEpoch(DateTime dateTime)
{
    return (long)(dateTime - new DateTime(1970, 1, 1)).TotalSeconds;
}

[TestInitialize]
public async Task TestInitialize()
{
Expand Down Expand Up @@ -629,6 +632,105 @@ await databaseForVectorEmbedding.DefineContainer(containerName, partitionKeyPath
}
}

/// <summary>
/// Creates a container whose vector embedding policy carries an <see cref="EmbeddingSource"/>
/// and verifies that the source is returned unchanged both in the create response and in a
/// subsequent container read.
/// </summary>
/// <remarks>
/// The test is ignored by default: the endpoint and key constants are empty and must be
/// filled in before it can run.
/// </remarks>
[TestMethod]
[Ignore("Requires a real Cosmos DB account with vector-embedding preview enabled. Fill in the endpoint and key before running.")]
public async Task TestVectorEmbeddingPolicyWithEmbeddingSource()
{
    const string accountEndpoint = "";
    const string accountKey = "";

    const string databaseId = "embeddingSourceIntegrationDb";
    const string containerId = "embeddingSourceIntegrationContainer";
    const string partitionKeyPath = "/pk";
    const string embeddingPath = "/embedding";

    // The using statement disposes the client on every exit path, including a failure
    // while creating the database or while deleting it during cleanup.
    using (CosmosClient client = new CosmosClient(accountEndpoint, accountKey))
    {
        Database database = await client.CreateDatabaseIfNotExistsAsync(databaseId);

        try
        {
            EmbeddingSource embeddingSource = new EmbeddingSource()
            {
                SourcePaths = new Collection<string>
                {
                    "/journal_title",
                    "/title",
                    "/toc_abstract",
                    "/abstract",
                    "/full_text",
                },
                DeploymentName = "text-embedding-3-small",
                ModelName = "text-embedding-3-small",
                Endpoint = "https://embedding-south-central.cognitiveservices.azure.com/",
                AuthType = EmbeddingAuthType.ApiKey,
            };

            Collection<Embedding> embeddings = new Collection<Embedding>()
            {
                new Embedding()
                {
                    Path = embeddingPath,
                    DataType = VectorDataType.Float32,
                    DistanceFunction = DistanceFunction.Cosine,
                    Dimensions = 1536,
                    EmbeddingSource = embeddingSource,
                },
            };

            try
            {
                await database.GetContainer(containerId).DeleteContainerAsync();
            }
            catch (CosmosException ex) when (ex.StatusCode == HttpStatusCode.NotFound)
            {
                // Best-effort cleanup of a leftover container: nothing to delete is fine.
            }

            ContainerResponse containerResponse =
                await database.DefineContainer(containerId, partitionKeyPath)
                    .WithVectorEmbeddingPolicy(embeddings)
                    .Attach()
                    .CreateAsync();

            Assert.AreEqual(HttpStatusCode.Created, containerResponse.StatusCode);
            Assert.AreEqual(containerId, containerResponse.Resource.Id);
            Assert.AreEqual(partitionKeyPath, containerResponse.Resource.PartitionKey.Paths.First());

            this.AssertEmbeddingSourceRoundTrip(containerResponse.Resource.VectorEmbeddingPolicy, embeddingPath, embeddingSource);

            ContainerResponse readResponse = await containerResponse.Container.ReadContainerAsync();
            Assert.AreEqual(HttpStatusCode.OK, readResponse.StatusCode);
            Assert.AreEqual(containerId, readResponse.Resource.Id);
            Assert.AreEqual(partitionKeyPath, readResponse.Resource.PartitionKey.Paths.First());

            this.AssertEmbeddingSourceRoundTrip(readResponse.Resource.VectorEmbeddingPolicy, embeddingPath, embeddingSource);
        }
        finally
        {
            await database.DeleteAsync();
        }
    }
}

/// <summary>
/// Asserts that <paramref name="policy"/> holds exactly one embedding with the expected
/// path and the fixed Float32 / Cosine / 1536 settings, and that its embedding source
/// matches <paramref name="expected"/> member by member.
/// </summary>
/// <param name="policy">The vector embedding policy returned by the service.</param>
/// <param name="expectedEmbeddingPath">The path the single embedding must have.</param>
/// <param name="expected">The embedding source that was sent when the container was created.</param>
private void AssertEmbeddingSourceRoundTrip(VectorEmbeddingPolicy policy, string expectedEmbeddingPath, EmbeddingSource expected)
{
    Assert.IsNotNull(policy);
    Assert.AreEqual(1, policy.Embeddings.Count());

    // Core embedding settings.
    Embedding actualEmbedding = policy.Embeddings.Single();
    Assert.AreEqual(expectedEmbeddingPath, actualEmbedding.Path);
    Assert.AreEqual(VectorDataType.Float32, actualEmbedding.DataType);
    Assert.AreEqual(DistanceFunction.Cosine, actualEmbedding.DistanceFunction);
    Assert.AreEqual(1536, actualEmbedding.Dimensions);

    // Embedding source, compared one member at a time so a failure names the mismatch.
    EmbeddingSource actualSource = actualEmbedding.EmbeddingSource;
    Assert.IsNotNull(actualSource, "EmbeddingSource should be returned by the server.");

    string[] expectedPaths = expected.SourcePaths.ToArray();
    string[] actualPaths = actualSource.SourcePaths.ToArray();
    CollectionAssert.AreEqual(expectedPaths, actualPaths);

    Assert.AreEqual(expected.DeploymentName, actualSource.DeploymentName);
    Assert.AreEqual(expected.ModelName, actualSource.ModelName);
    Assert.AreEqual(expected.Endpoint, actualSource.Endpoint);
    Assert.AreEqual(expected.AuthType, actualSource.AuthType);
}

[TestMethod]
public async Task WithIndexingPolicy()
{
Expand Down
Loading