Skip to content

Commit

Permalink
.Net: Support records with no key, fix issue with Relevance (#6393)
Browse files Browse the repository at this point in the history
### Motivation and Context

When using `SemanticTextMemory.SaveReferenceAsync`, no value is passed for
the `Key` to use. In the Cosmos DB Memory Connector, we treat the "Key"
as both the Id and the PartitionKey, so it must have a value. Address this
by generating a `GUID` to use when no key is provided.

### Description

* Assign a GUID key when the passed-in `MemoryRecord.Key` is null or
empty so that Cosmos DB always has a valid PartitionKey.
* While adding test coverage for this I noticed two things:
1. `withEmbedding` is ignored in
`AzureCosmosDBNoSQLMemoryStore.GetAsync`. It's not clear whether it's
worth changing this, as point reads like the current one tend to be less
expensive than queries in Cosmos DB.
2. The SimilarityScore returned from `GetNearestMatchesAsync` wasn't actually a
similarity; it was a cosine distance. Normalize it with `(SimilarityScore + 1) / 2` before applying `minRelevanceScore`.

Fixes #6379.

### Contribution Checklist

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

---------

Co-authored-by: Dmytro Struk <[email protected]>
  • Loading branch information
Pilchie and dmytrostruk committed Jun 12, 2024
1 parent ddf1d46 commit 4adf3fe
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Linq;
using System.Runtime.CompilerServices;
Expand All @@ -22,11 +23,62 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL;
/// </summary>
public class AzureCosmosDBNoSQLMemoryStore : IMemoryStore, IDisposable
{
private const string EmbeddingPath = "/embedding";

private readonly CosmosClient _cosmosClient;
private readonly VectorEmbeddingPolicy _vectorEmbeddingPolicy;
private readonly IndexingPolicy _indexingPolicy;
private readonly string _databaseName;

/// <summary>
/// Creates an <see cref="AzureCosmosDBNoSQLMemoryStore"/> from an Azure Cosmos DB connection string,
/// building the vector embedding policy and the indexing policy from the supplied vector settings.
/// Both policies target the <c>/embedding</c> path and the embedding uses the cosine distance function.
/// </summary>
/// <param name="connectionString">Connection string used to connect to Azure Cosmos DB.</param>
/// <param name="databaseName">Name of the database to connect to.</param>
/// <param name="dimensions">Number of dimensions of the embedding vectors to be stored.</param>
/// <param name="vectorDataType">Data type of the embedding vectors to be stored.</param>
/// <param name="vectorIndexType">Type of index to use for the embedding vectors to be stored.</param>
/// <param name="applicationName">Application name to use in requests; a default user agent value is used when null.</param>
public AzureCosmosDBNoSQLMemoryStore(
    string connectionString,
    string databaseName,
    ulong dimensions,
    VectorDataType vectorDataType,
    VectorIndexType vectorIndexType,
    string? applicationName = null)
    : this(
        new CosmosClient(
            connectionString,
            new CosmosClientOptions
            {
                ApplicationName = applicationName ?? HttpHeaderConstant.Values.UserAgent,
                Serializer = new CosmosSystemTextJsonSerializer(JsonSerializerOptions.Default),
            }),
        databaseName,
        // Single embedding definition, always stored at the shared embedding path.
        new VectorEmbeddingPolicy(
            new Collection<Embedding>
            {
                new()
                {
                    Path = EmbeddingPath,
                    DataType = vectorDataType,
                    Dimensions = dimensions,
                    DistanceFunction = DistanceFunction.Cosine,
                },
            }),
        // Index the same path with the caller-selected index type.
        new IndexingPolicy
        {
            VectorIndexes = new Collection<VectorIndexPath>
            {
                new VectorIndexPath { Path = EmbeddingPath, Type = vectorIndexType },
            },
        })
{
}

/// <summary>
/// Initiates a AzureCosmosDBNoSQLMemoryStore instance using a Azure Cosmos DB connection string
/// and other properties required for vector search.
Expand Down Expand Up @@ -71,14 +123,29 @@ public class AzureCosmosDBNoSQLMemoryStore : IMemoryStore, IDisposable
VectorEmbeddingPolicy vectorEmbeddingPolicy,
IndexingPolicy indexingPolicy)
{
if (!vectorEmbeddingPolicy.Embeddings.Any(e => e.Path == "/embedding"))
var embedding = vectorEmbeddingPolicy.Embeddings.FirstOrDefault(e => e.Path == EmbeddingPath);
if (embedding is null)
{
throw new InvalidOperationException($"""
In order for {nameof(GetNearestMatchAsync)} to function, {nameof(vectorEmbeddingPolicy)} should
contain an embedding path at /embedding. It's also recommended to include a that path in the
contain an embedding path at {EmbeddingPath}. It's also recommended to include that path in the
{nameof(indexingPolicy)} to improve performance and reduce cost for searches.
""");
}
else if (embedding.DistanceFunction != DistanceFunction.Cosine)
{
throw new InvalidOperationException($"""
In order for {nameof(GetNearestMatchAsync)} to reliably return relevance information, the {nameof(DistanceFunction)} should
be specified as {nameof(DistanceFunction)}.{nameof(DistanceFunction.Cosine)}.
""");
}
else if (embedding.DataType != VectorDataType.Float16 && embedding.DataType != VectorDataType.Float32)
{
throw new NotSupportedException($"""
Only {nameof(VectorDataType)}.{nameof(VectorDataType.Float16)} and {nameof(VectorDataType)}.{nameof(VectorDataType.Float32)}
are supported.
""");
}
this._cosmosClient = cosmosClient;
this._databaseName = databaseName;
this._vectorEmbeddingPolicy = vectorEmbeddingPolicy;
Expand Down Expand Up @@ -164,6 +231,12 @@ await this._cosmosClient
MemoryRecord record,
CancellationToken cancellationToken = default)
{
// In some cases we're expected to generate the key to use. Do so if one isn't provided.
if (string.IsNullOrEmpty(record.Key))
{
record.Key = Guid.NewGuid().ToString();
}

var result = await this._cosmosClient
.GetDatabase(this._databaseName)
.GetContainer(collectionName)
Expand Down Expand Up @@ -193,6 +266,7 @@ await this._cosmosClient
bool withEmbedding = false,
CancellationToken cancellationToken = default)
{
// TODO: Consider using a query when `withEmbedding` is false to avoid passing it over the wire.
var result = await this._cosmosClient
.GetDatabase(this._databaseName)
.GetContainer(collectionName)
Expand Down Expand Up @@ -330,9 +404,10 @@ FROM x
{
foreach (var memoryRecord in await feedIterator.ReadNextAsync(cancellationToken).ConfigureAwait(false))
{
if (memoryRecord.SimilarityScore >= minRelevanceScore)
var relevanceScore = (memoryRecord.SimilarityScore + 1) / 2;
if (relevanceScore >= minRelevanceScore)
{
yield return (memoryRecord, memoryRecord.SimilarityScore);
yield return (memoryRecord, relevanceScore);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL;
using Microsoft.SemanticKernel.Embeddings;
using Microsoft.SemanticKernel.Memory;
using MongoDB.Driver;
using Xunit;
Expand Down Expand Up @@ -117,6 +122,54 @@ public async Task ItCanGetNearestMatchesAsync(int limit, bool withEmbeddings)
await memoryStore.DeleteCollectionAsync(collectionName);
}

/// <summary>
/// Verifies that a reference saved through <see cref="SemanticTextMemory"/> without an explicit key
/// can be fetched by the returned id and found again through text search.
/// </summary>
/// <param name="withEmbedding">Whether embeddings are requested from the get/search calls and compared.</param>
[Theory(Skip = SkipReason)]
[InlineData(true)]
[InlineData(false)]
public async Task ItCanSaveReferenceGetTextAndSearchTextAsync(bool withEmbedding)
{
    // Arrange
    var collectionName = this._fixture.CollectionName;
    var memoryStore = this._fixture.MemoryStore;
    var textMemory = new SemanticTextMemory(memoryStore, new MockTextEmbeddingGenerationService());
    var textToStore = "SampleText";
    var id = "MyExternalId";
    var source = "MyExternalSource";

    var expectedQueryResult = new MemoryQueryResult(
        new MemoryRecordMetadata(isReference: true, id, text: "", description: "", source, additionalMetadata: ""),
        1.0,
        withEmbedding ? DataHelper.VectorSearchTestEmbedding : null);

    // Act
    var refId = await textMemory.SaveReferenceAsync(collectionName, textToStore, id, source);
    Assert.NotNull(refId);

    try
    {
        // Assert: point read by the generated key.
        var queryResult = await textMemory.GetAsync(collectionName, refId, withEmbedding);
        AssertQueryResultEqual(expectedQueryResult, queryResult, withEmbedding);

        // Assert: search. Embeddings are only compared when they were requested; the expected
        // embedding is null otherwise, so comparing unconditionally would always fail.
        var searchResults = await textMemory.SearchAsync(collectionName, textToStore, withEmbeddings: withEmbedding).ToListAsync();
        var searchResult = Assert.Single(searchResults);
        AssertQueryResultEqual(expectedQueryResult, searchResult, compareEmbeddings: withEmbedding);
    }
    finally
    {
        // Clean up even when an assertion fails so the shared fixture collection is not polluted.
        await textMemory.RemoveAsync(collectionName, refId);
    }
}

/// <summary>
/// Asserts that <paramref name="actual"/> is not null and that its relevance and metadata fields
/// equal those of <paramref name="expected"/>. When <paramref name="compareEmbeddings"/> is true,
/// both embeddings must be present and their spans must be equal.
/// </summary>
private static void AssertQueryResultEqual(MemoryQueryResult expected, MemoryQueryResult? actual, bool compareEmbeddings)
{
    Assert.NotNull(actual);
    Assert.Equal(expected.Relevance, actual.Relevance);

    var expectedMetadata = expected.Metadata;
    var actualMetadata = actual.Metadata;
    Assert.Equal(expectedMetadata.Id, actualMetadata.Id);
    Assert.Equal(expectedMetadata.Text, actualMetadata.Text);
    Assert.Equal(expectedMetadata.Description, actualMetadata.Description);
    Assert.Equal(expectedMetadata.ExternalSourceName, actualMetadata.ExternalSourceName);
    Assert.Equal(expectedMetadata.AdditionalMetadata, actualMetadata.AdditionalMetadata);
    Assert.Equal(expectedMetadata.IsReference, actualMetadata.IsReference);

    // Embedding comparison is opt-in; nothing further to verify otherwise.
    if (!compareEmbeddings)
    {
        return;
    }

    Assert.NotNull(expected.Embedding);
    Assert.NotNull(actual.Embedding);
    Assert.Equal(expected.Embedding.Value.Span, actual.Embedding.Value.Span);
}

private static void AssertMemoryRecordEqual(
MemoryRecord expectedRecord,
MemoryRecord actualRecord,
Expand Down Expand Up @@ -147,4 +200,15 @@ public async Task ItCanGetNearestMatchesAsync(int limit, bool withEmbeddings)
Assert.True(actualRecord.Embedding.Span.IsEmpty);
}
}

/// <summary>
/// Test double for <see cref="ITextEmbeddingGenerationService"/> that ignores its input and always
/// produces a one-element list holding <c>DataHelper.VectorSearchTestEmbedding</c>.
/// </summary>
private sealed class MockTextEmbeddingGenerationService : ITextEmbeddingGenerationService
{
    /// <summary>Attributes of the service; always the shared empty dictionary.</summary>
    public IReadOnlyDictionary<string, object?> Attributes { get; } = ReadOnlyDictionary<string, object?>.Empty;

    /// <summary>
    /// Returns an already-completed task with a single fixed embedding. The <paramref name="data"/>,
    /// <paramref name="kernel"/> and <paramref name="cancellationToken"/> arguments are not used.
    /// </summary>
    public Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, Kernel? kernel = null, CancellationToken cancellationToken = default)
        => Task.FromResult<IList<ReadOnlyMemory<float>>>(
            new List<ReadOnlyMemory<float>> { DataHelper.VectorSearchTestEmbedding });
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.ObjectModel;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos;
using Microsoft.Extensions.Configuration;
Expand Down Expand Up @@ -35,28 +34,9 @@ public AzureCosmosDBNoSQLMemoryStoreTestsFixture()
this.MemoryStore = new AzureCosmosDBNoSQLMemoryStore(
connectionString,
this.DatabaseName,
new VectorEmbeddingPolicy(
new Collection<Embedding>
{
new()
{
DataType = VectorDataType.Float32,
Dimensions = 3,
DistanceFunction = DistanceFunction.Cosine,
Path = "/embedding"
}
}),
new()
{
VectorIndexes = new Collection<VectorIndexPath> {
new()
{
Path = "/embedding",
Type = VectorIndexType.Flat,
},
},
}
);
dimensions: 3,
VectorDataType.Float32,
VectorIndexType.Flat);
}

public Task InitializeAsync()
Expand Down

0 comments on commit 4adf3fe

Please sign in to comment.