diff --git a/dotnet/samples/KernelSyntaxExamples/Example39_Postgres.cs b/dotnet/samples/KernelSyntaxExamples/Example39_Postgres.cs index f1ea702cffae..af9a2b7ecffa 100644 --- a/dotnet/samples/KernelSyntaxExamples/Example39_Postgres.cs +++ b/dotnet/samples/KernelSyntaxExamples/Example39_Postgres.cs @@ -5,22 +5,29 @@ using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.Memory.Postgres; using Microsoft.SemanticKernel.Memory; +using Npgsql; +using Pgvector.Npgsql; using RepoUtils; // ReSharper disable once InconsistentNaming public static class Example39_Postgres { - private const string MemoryCollectionName = "postgres-test"; + private const string MemoryCollectionName = "postgres_test"; public static async Task RunAsync() { - string connectionString = Env.Var("POSTGRES_CONNECTIONSTRING"); - using PostgresMemoryStore memoryStore = await PostgresMemoryStore.ConnectAsync(connectionString, vectorSize: 1536); + NpgsqlDataSourceBuilder dataSourceBuilder = new(Env.Var("POSTGRES_CONNECTIONSTRING")); + dataSourceBuilder.UseVector(); + using NpgsqlDataSource dataSource = dataSourceBuilder.Build(); + + PostgresMemoryStore memoryStore = new(dataSource, vectorSize: 1536, schema: "public", numberOfLists: 100); + IKernel kernel = Kernel.Builder .WithLogger(ConsoleLogger.Log) .WithOpenAITextCompletionService("text-davinci-003", Env.Var("OPENAI_API_KEY")) .WithOpenAITextEmbeddingGenerationService("text-embedding-ada-002", Env.Var("OPENAI_API_KEY")) .WithMemoryStorage(memoryStore) + //.WithPostgresMemoryStore(dataSource, vectorSize: 1536, schema: "public") // This method offers an alternative approach to registering Postgres memory store. 
.Build(); Console.WriteLine("== Printing Collections in DB =="); diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs new file mode 100644 index 000000000000..59d69fbb0da5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Pgvector; + +namespace Microsoft.SemanticKernel.Connectors.Memory.Postgres; + +/// +/// Interface for client managing postgres database operations. +/// +public interface IPostgresDbClient +{ + /// + /// Check if a collection exists. + /// + /// The name assigned to a collection of entries. + /// The to monitor for cancellation requests. The default is . + /// + Task DoesCollectionExistsAsync(string collectionName, CancellationToken cancellationToken = default); + + /// + /// Create a collection. + /// + /// The name assigned to a collection of entries. + /// The to monitor for cancellation requests. The default is . + /// + Task CreateCollectionAsync(string collectionName, CancellationToken cancellationToken = default); + + /// + /// Get all collections. + /// + /// The to monitor for cancellation requests. The default is . + /// + IAsyncEnumerable GetCollectionsAsync(CancellationToken cancellationToken = default); + + /// + /// Delete a collection. + /// + /// The name assigned to a collection of entries. + /// The to monitor for cancellation requests. The default is . + /// + Task DeleteCollectionAsync(string collectionName, CancellationToken cancellationToken = default); + + /// + /// Upsert entry into a collection. + /// + /// The name assigned to a collection of entries. + /// The key of the entry to upsert. + /// The metadata of the entry. + /// The embedding of the entry. 
+ /// The timestamp of the entry + /// The to monitor for cancellation requests. The default is . + /// + Task UpsertAsync(string collectionName, string key, string? metadata, Vector? embedding, long? timestamp, CancellationToken cancellationToken = default); + + /// + /// Gets the nearest matches to the . + /// + /// The name assigned to a collection of entries. + /// The to compare the collection's embeddings with. + /// The maximum number of similarity results to return. + /// The minimum relevance threshold for returned results. + /// If true, the embeddings will be returned in the entries. + /// The to monitor for cancellation requests. The default is . + /// + IAsyncEnumerable<(PostgresMemoryEntry, double)> GetNearestMatchesAsync(string collectionName, Vector embeddingFilter, int limit, double minRelevanceScore = 0, bool withEmbeddings = false, CancellationToken cancellationToken = default); + + /// + /// Read a entry by its key. + /// + /// The name assigned to a collection of entries. + /// The key of the entry to read. + /// If true, the embeddings will be returned in the entries. + /// The to monitor for cancellation requests. The default is . + /// + Task ReadAsync(string collectionName, string key, bool withEmbeddings = false, CancellationToken cancellationToken = default); + + /// + /// Delete a entry by its key. + /// + /// The name assigned to a collection of entries. + /// The key of the entry to delete. + /// The to monitor for cancellation requests. The default is . 
+ /// + Task DeleteAsync(string collectionName, string key, CancellationToken cancellationToken = default); +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/Database.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs similarity index 52% rename from dotnet/src/Connectors/Connectors.Memory.Postgres/Database.cs rename to dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs index 59d0d30ca5c5..e957f01f7668 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/Database.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Linq; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -12,119 +11,114 @@ namespace Microsoft.SemanticKernel.Connectors.Memory.Postgres; /// -/// A postgres memory entry. +/// An implementation of a client for Postgres. This class is used to managing postgres database operations. /// -internal struct DatabaseEntry +[System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "We need to build the full table name using schema and collection, it does not support parameterized passing.")] +public class PostgresDbClient : IPostgresDbClient { /// - /// Unique identifier of the memory entry. + /// Initializes a new instance of the class. /// - public string Key { get; set; } - - /// - /// Metadata as a string. - /// - public string MetadataString { get; set; } - - /// - /// The embedding data as a . - /// - public Vector? Embedding { get; set; } - - /// - /// Optional timestamp. - /// - public long? Timestamp { get; set; } -} - -/// -/// The class for managing postgres database operations. -/// -internal sealed class Database -{ - private const string TableName = "sk_memory_table"; + /// Postgres data source. + /// Schema of collection tables. + /// Embedding vector size. 
+    /// <param name="numberOfLists">Specifies the number of lists for indexing. Higher values can improve recall but may impact performance. The default value is 1000. More info <see href="https://github.com/pgvector/pgvector#indexing"/></param>
+    public PostgresDbClient(NpgsqlDataSource dataSource, string schema, int vectorSize, int numberOfLists)
+    {
+        this._dataSource = dataSource;
+        this._schema = schema;
+        this._vectorSize = vectorSize;
+        this._numberOfLists = numberOfLists;
+    }
 
     /// <summary>
-    /// Create pgvector extensions.
+    /// Check if a collection exists.
     /// </summary>
-    /// <param name="conn">An opened <see cref="NpgsqlConnection"/> instance.</param>
+    /// <param name="collectionName">The name assigned to a collection of entries.</param>
     /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
-    /// <returns></returns>
-    public async Task CreatePgVectorExtensionAsync(NpgsqlConnection conn, CancellationToken cancellationToken = default)
+    /// <returns><c>true</c> if information_schema.tables lists a base table named <paramref name="collectionName"/> in the configured schema; otherwise <c>false</c>.</returns>
+    public async Task<bool> DoesCollectionExistsAsync(
+        string collectionName,
+        CancellationToken cancellationToken = default)
     {
-        using NpgsqlCommand cmd = conn.CreateCommand();
-        cmd.CommandText = "CREATE EXTENSION IF NOT EXISTS vector";
-        await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
-        await conn.ReloadTypesAsync().ConfigureAwait(false);
+        using NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
+
+        using NpgsqlCommand cmd = connection.CreateCommand();
+        cmd.CommandText = @"
+            SELECT table_name FROM information_schema.tables
+            WHERE table_schema = @schema
+                AND table_type = 'BASE TABLE'
+                AND table_name = @table_name";
+        cmd.Parameters.AddWithValue("@schema", this._schema);
+        cmd.Parameters.AddWithValue("@table_name", collectionName); // compared as a value here, so it is bound rather than interpolated into the SQL text
+
+        using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
+        if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
+        {
+            return dataReader.GetString(dataReader.GetOrdinal("table_name")) == collectionName;
+        }
+
+        return false;
     }
 
     /// <summary>
-    /// Create memory table.
+    /// Create a collection.
     /// </summary>
-    /// <param name="conn">An opened <see cref="NpgsqlConnection"/> instance.</param>
-    /// <param name="vectorSize">Vector size of embedding column</param>
+    /// <param name="collectionName">The name assigned to a collection of entries.</param>
/// The to monitor for cancellation requests. The default is . /// - public async Task CreateTableAsync(NpgsqlConnection conn, int vectorSize, CancellationToken cancellationToken = default) + public async Task CreateCollectionAsync(string collectionName, CancellationToken cancellationToken = default) { - await this.CreatePgVectorExtensionAsync(conn, cancellationToken).ConfigureAwait(false); + using NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - using NpgsqlCommand cmd = conn.CreateCommand(); -#pragma warning disable CA2100 // Review SQL queries for security vulnerabilities - cmd.CommandText = $@" - CREATE TABLE IF NOT EXISTS {TableName} ( - collection TEXT, - key TEXT, - metadata TEXT, - embedding vector({vectorSize}), - timestamp BIGINT, - PRIMARY KEY(collection, key))"; -#pragma warning restore CA2100 // Review SQL queries for security vulnerabilities - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + await this.CreateTableAsync(connection, collectionName, cancellationToken).ConfigureAwait(false); + + await this.CreateIndexAsync(connection, collectionName, cancellationToken).ConfigureAwait(false); } /// - /// Create index for memory table. + /// Get all collections. /// - /// An opened instance. /// The to monitor for cancellation requests. The default is . 
/// - public async Task CreateIndexAsync(NpgsqlConnection conn, CancellationToken cancellationToken = default) + public async IAsyncEnumerable GetCollectionsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) { - using NpgsqlCommand cmd = conn.CreateCommand(); - cmd.CommandText = $@" - CREATE INDEX IF NOT EXISTS {TableName}_ivfflat_embedding_vector_cosine_ops_idx - ON {TableName} USING ivfflat (embedding vector_cosine_ops) WITH (lists = 1000)"; - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + using NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = @" + SELECT table_name + FROM information_schema.tables + WHERE table_schema = @schema + AND table_type = 'BASE TABLE'"; + cmd.Parameters.AddWithValue("@schema", this._schema); + + using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + yield return dataReader.GetString(dataReader.GetOrdinal("table_name")); + } } /// - /// Create a collection. + /// Delete a collection. /// - /// An opened instance. /// The name assigned to a collection of entries. /// The to monitor for cancellation requests. The default is . 
/// - public async Task CreateCollectionAsync(NpgsqlConnection conn, string collectionName, CancellationToken cancellationToken = default) + public async Task DeleteCollectionAsync(string collectionName, CancellationToken cancellationToken = default) { - if (await this.DoesCollectionExistsAsync(conn, collectionName, cancellationToken).ConfigureAwait(false)) - { - // Collection already exists - return; - } + using NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = $"DROP TABLE IF EXISTS {this.GetTableName(collectionName)}"; - using NpgsqlCommand cmd = conn.CreateCommand(); - cmd.CommandText = $@" - INSERT INTO {TableName} (collection, key) - VALUES(@collection, '')"; - cmd.Parameters.AddWithValue("@collection", collectionName); await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } /// /// Upsert entry into a collection. /// - /// An opened instance. /// The name assigned to a collection of entries. /// The key of the entry to upsert. /// The metadata of the entry. @@ -132,63 +126,28 @@ public async Task CreateCollectionAsync(NpgsqlConnection conn, string collection /// The timestamp of the entry /// The to monitor for cancellation requests. The default is . /// - public async Task UpsertAsync(NpgsqlConnection conn, - string collectionName, string key, string? metadata, Vector? embedding, long? timestamp, CancellationToken cancellationToken = default) + public async Task UpsertAsync(string collectionName, string key, + string? metadata, Vector? embedding, long? 
timestamp, CancellationToken cancellationToken = default) { - using NpgsqlCommand cmd = conn.CreateCommand(); + using NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + using NpgsqlCommand cmd = connection.CreateCommand(); cmd.CommandText = $@" - INSERT INTO {TableName} (collection, key, metadata, embedding, timestamp) - VALUES(@collection, @key, @metadata, @embedding, @timestamp) - ON CONFLICT (collection, key) + INSERT INTO {this.GetTableName(collectionName)} (key, metadata, embedding, timestamp) + VALUES(@key, @metadata, @embedding, @timestamp) + ON CONFLICT (key) DO UPDATE SET metadata=@metadata, embedding=@embedding, timestamp=@timestamp"; - cmd.Parameters.AddWithValue("@collection", collectionName); cmd.Parameters.AddWithValue("@key", key); - cmd.Parameters.AddWithValue("@metadata", metadata ?? string.Empty); + cmd.Parameters.AddWithValue("@metadata", NpgsqlTypes.NpgsqlDbType.Jsonb, metadata ?? (object)DBNull.Value); cmd.Parameters.AddWithValue("@embedding", embedding ?? (object)DBNull.Value); cmd.Parameters.AddWithValue("@timestamp", timestamp ?? (object)DBNull.Value); - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } - - /// - /// Check if a collection exists. - /// - /// An opened instance. - /// The name assigned to a collection of entries. - /// The to monitor for cancellation requests. The default is . - /// - public async Task DoesCollectionExistsAsync(NpgsqlConnection conn, - string collectionName, - CancellationToken cancellationToken = default) - { - var collections = await this.GetCollectionsAsync(conn, cancellationToken).ToListAsync(cancellationToken).ConfigureAwait(false); - return collections.Contains(collectionName); - } - - /// - /// Get all collections. - /// - /// An opened instance. - /// The to monitor for cancellation requests. The default is . 
- /// - public async IAsyncEnumerable GetCollectionsAsync(NpgsqlConnection conn, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - using NpgsqlCommand cmd = conn.CreateCommand(); - cmd.CommandText = $@" - SELECT DISTINCT(collection) - FROM {TableName}"; - using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) - { - yield return dataReader.GetString(dataReader.GetOrdinal("collection")); - } + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } /// /// Gets the nearest matches to the . /// - /// An opened instance. /// The name assigned to a collection of entries. /// The to compare the collection's embeddings with. /// The maximum number of similarity results to return. @@ -196,20 +155,21 @@ SELECT DISTINCT(collection) /// If true, the embeddings will be returned in the entries. /// The to monitor for cancellation requests. The default is . 
/// - public async IAsyncEnumerable<(DatabaseEntry, double)> GetNearestMatchesAsync(NpgsqlConnection conn, + public async IAsyncEnumerable<(PostgresMemoryEntry, double)> GetNearestMatchesAsync( string collectionName, Vector embeddingFilter, int limit, double minRelevanceScore = 0, bool withEmbeddings = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - var queryColumns = "collection, key, metadata, timestamp"; + var queryColumns = "key, metadata, timestamp"; if (withEmbeddings) { queryColumns = "*"; } - using NpgsqlCommand cmd = conn.CreateCommand(); + using NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + using NpgsqlCommand cmd = connection.CreateCommand(); cmd.CommandText = @$" - SELECT * FROM (SELECT {queryColumns}, 1 - (embedding <=> @embedding) AS cosine_similarity FROM {TableName} - WHERE collection = @collection + SELECT * FROM (SELECT {queryColumns}, 1 - (embedding <=> @embedding) AS cosine_similarity FROM {this.GetTableName(collectionName)} ) AS sk_memory_cosine_similarity_table WHERE cosine_similarity >= @min_relevance_score ORDER BY cosine_similarity DESC @@ -229,120 +189,122 @@ SELECT DISTINCT(collection) } /// - /// Read all entries from a collection + /// Read a entry by its key. /// - /// An opened instance. /// The name assigned to a collection of entries. + /// The key of the entry to read. /// If true, the embeddings will be returned in the entries. /// The to monitor for cancellation requests. The default is . 
/// - public async IAsyncEnumerable ReadAllAsync(NpgsqlConnection conn, - string collectionName, bool withEmbeddings = false, - [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async Task ReadAsync(string collectionName, string key, + bool withEmbeddings = false, CancellationToken cancellationToken = default) { - var queryColumns = "collection, key, metadata, timestamp"; + var queryColumns = "key, metadata, timestamp"; if (withEmbeddings) { queryColumns = "*"; } - using NpgsqlCommand cmd = conn.CreateCommand(); - cmd.CommandText = $@" - SELECT {queryColumns} FROM {TableName} - WHERE collection=@collection"; - cmd.Parameters.AddWithValue("@collection", collectionName); + using NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = $"SELECT {queryColumns} FROM {this.GetTableName(collectionName)} WHERE key=@key"; + cmd.Parameters.AddWithValue("@key", key); using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) { - yield return await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false); + return await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false); } + + return null; } /// - /// Read a entry by its key. + /// Delete a entry by its key. /// - /// An opened instance. /// The name assigned to a collection of entries. - /// The key of the entry to read. - /// If true, the embeddings will be returned in the entries. + /// The key of the entry to delete. /// The to monitor for cancellation requests. The default is . 
/// - public async Task ReadAsync(NpgsqlConnection conn, - string collectionName, string key, bool withEmbeddings = false, - CancellationToken cancellationToken = default) + public async Task DeleteAsync(string collectionName, string key, CancellationToken cancellationToken = default) { - var queryColumns = "collection, key, metadata, timestamp"; - if (withEmbeddings) - { - queryColumns = "*"; - } + using NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - using NpgsqlCommand cmd = conn.CreateCommand(); - cmd.CommandText = $@" - SELECT {queryColumns} FROM {TableName} - WHERE collection=@collection AND key=@key"; - cmd.Parameters.AddWithValue("@collection", collectionName); + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = $"DELETE FROM {this.GetTableName(collectionName)} WHERE key=@key"; cmd.Parameters.AddWithValue("@key", key); - using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); - if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) - { - return await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false); - } + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } - return null; + #region private ================================================================================ + + private readonly NpgsqlDataSource _dataSource; + private readonly int _vectorSize; + private readonly string _schema; + private readonly int _numberOfLists; + + /// + /// Read a entry. + /// + /// The to read. + /// If true, the embeddings will be returned in the entries. + /// The to monitor for cancellation requests. The default is . 
+    /// <returns>A <see cref="PostgresMemoryEntry"/> built from the row the reader is currently positioned on. A NULL metadata column is returned as an empty string.</returns>
+    private async Task<PostgresMemoryEntry> ReadEntryAsync(NpgsqlDataReader dataReader, bool withEmbeddings = false, CancellationToken cancellationToken = default)
+    {
+        string key = dataReader.GetString(dataReader.GetOrdinal("key"));
+        string metadata = dataReader.IsDBNull(dataReader.GetOrdinal("metadata")) ? string.Empty : dataReader.GetString(dataReader.GetOrdinal("metadata")); // metadata is a nullable JSONB column and UpsertAsync stores DBNull for a null value
+        Vector? embedding = withEmbeddings ? await dataReader.GetFieldValueAsync<Vector>(dataReader.GetOrdinal("embedding"), cancellationToken).ConfigureAwait(false) : null;
+        long? timestamp = await dataReader.GetFieldValueAsync<long?>(dataReader.GetOrdinal("timestamp"), cancellationToken).ConfigureAwait(false);
+        return new PostgresMemoryEntry() { Key = key, MetadataString = metadata, Embedding = embedding, Timestamp = timestamp };
     }
 
     /// <summary>
-    /// Delete a collection.
+    /// Create a collection as table.
     /// </summary>
-    /// <param name="conn">An opened <see cref="NpgsqlConnection"/> instance.</param>
+    /// <param name="connection">An opened <see cref="NpgsqlConnection"/> instance.</param>
     /// <param name="collectionName">The name assigned to a collection of entries.</param>
     /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
     /// <returns></returns>
-    public Task DeleteCollectionAsync(NpgsqlConnection conn, string collectionName, CancellationToken cancellationToken = default)
+    private async Task CreateTableAsync(NpgsqlConnection connection, string collectionName, CancellationToken cancellationToken = default)
     {
-        using NpgsqlCommand cmd = conn.CreateCommand();
+        using NpgsqlCommand cmd = connection.CreateCommand();
         cmd.CommandText = $@"
-            DELETE FROM {TableName}
-            WHERE collection=@collection";
-        cmd.Parameters.AddWithValue("@collection", collectionName);
-        return cmd.ExecuteNonQueryAsync(cancellationToken);
+            CREATE TABLE IF NOT EXISTS {this.GetTableName(collectionName)} (
+                key TEXT NOT NULL,
+                metadata JSONB,
+                embedding vector({this._vectorSize}),
+                timestamp BIGINT,
+                PRIMARY KEY (key))";
+        await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
     }
 
     /// <summary>
-    /// Delete a entry by its key.
+    /// Create index for collection table.
     /// </summary>
-    /// <param name="conn">An opened <see cref="NpgsqlConnection"/> instance.</param>
+    /// <param name="connection">An opened <see cref="NpgsqlConnection"/> instance.</param>
     /// <param name="collectionName">The name assigned to a collection of entries.</param>
-    /// <param name="key">The key of the entry to delete.</param>
     /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
     /// <returns></returns>
-    public Task DeleteAsync(NpgsqlConnection conn, string collectionName, string key, CancellationToken cancellationToken = default)
+    private async Task CreateIndexAsync(NpgsqlConnection connection, string collectionName, CancellationToken cancellationToken = default)
     {
-        using NpgsqlCommand cmd = conn.CreateCommand();
+        using NpgsqlCommand cmd = connection.CreateCommand(); // the index name below is double-quoted like the table name, so names that are only valid as quoted identifiers still work
         cmd.CommandText = $@"
-            DELETE FROM {TableName}
-            WHERE collection=@collection AND key=@key ";
-        cmd.Parameters.AddWithValue("@collection", collectionName);
-        cmd.Parameters.AddWithValue("@key", key);
-        return cmd.ExecuteNonQueryAsync(cancellationToken);
+            CREATE INDEX IF NOT EXISTS ""{collectionName}_ix""
+            ON {this.GetTableName(collectionName)} USING ivfflat (embedding vector_cosine_ops) WITH (lists = {this._numberOfLists})";
+        await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
     }
 
     /// <summary>
-    /// Read a entry.
+    /// Get table name from collection name.
     /// </summary>
-    /// <param name="dataReader">The <see cref="NpgsqlDataReader"/> to read.</param>
-    /// <param name="withEmbeddings">If true, the embeddings will be returned in the entries.</param>
-    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
+    /// <param name="collectionName"></param>
     /// <returns></returns>
-    private async Task<DatabaseEntry> ReadEntryAsync(NpgsqlDataReader dataReader, bool withEmbeddings = false, CancellationToken cancellationToken = default)
+    private string GetTableName(string collectionName)
     {
-        string key = dataReader.GetString(dataReader.GetOrdinal("key"));
-        string metadata = dataReader.GetString(dataReader.GetOrdinal("metadata"));
-        Vector? embedding = withEmbeddings ? await dataReader.GetFieldValueAsync<Vector>(dataReader.GetOrdinal("embedding"), cancellationToken).ConfigureAwait(false) : null;
-        long? 
timestamp = await dataReader.GetFieldValueAsync(dataReader.GetOrdinal("timestamp"), cancellationToken).ConfigureAwait(false); - return new DatabaseEntry() { Key = key, MetadataString = metadata, Embedding = embedding, Timestamp = timestamp }; + return $"{this._schema}.\"{collectionName}\""; } + #endregion } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresKernelBuilderExtensions.cs new file mode 100644 index 000000000000..440dfcd9dcf8 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresKernelBuilderExtensions.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel.Connectors.Memory.Postgres; +using Npgsql; + +#pragma warning disable IDE0130 +namespace Microsoft.SemanticKernel; +#pragma warning restore IDE0130 + +/// +/// Provides extension methods for the class to configure Postgres connectors. +/// +public static class PostgresKernelBuilderExtensions +{ + /// + /// Registers Postgres Memory Store. + /// + /// The instance + /// Postgres data source. + /// Embedding vector size. + /// Schema of collection tables. + /// Specifies the number of lists for indexing. Higher values can improve recall but may impact performance. The default value is 1000. More info + /// Self instance + public static KernelBuilder WithPostgresMemoryStore(this KernelBuilder builder, + NpgsqlDataSource dataSource, + int vectorSize, + string schema = PostgresMemoryStore.DefaultSchema, + int numberOfLists = PostgresMemoryStore.DefaultNumberOfLists) + { + builder.WithMemoryStorage((parameters) => + { + return new PostgresMemoryStore(dataSource, vectorSize, schema, numberOfLists); + }); + + return builder; + } + + /// + /// Registers Postgres Memory Store. + /// + /// The instance + /// Postgres database client. 
+ /// Self instance + public static KernelBuilder WithPostgresMemoryStore(this KernelBuilder builder, IPostgresDbClient postgresDbClient) + { + builder.WithMemoryStorage((parameters) => + { + return new PostgresMemoryStore(postgresDbClient); + }); + + return builder; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryEntry.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryEntry.cs new file mode 100644 index 000000000000..29a4b8f31f90 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryEntry.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Pgvector; + +namespace Microsoft.SemanticKernel.Connectors.Memory.Postgres; + +/// +/// A postgres memory entry. +/// +public record struct PostgresMemoryEntry +{ + /// + /// Unique identifier of the memory entry. + /// + public string Key { get; set; } + + /// + /// Metadata as a string. + /// + public string MetadataString { get; set; } + + /// + /// The embedding data as a . + /// + public Vector? Embedding { get; set; } + + /// + /// Optional timestamp. + /// + public long? Timestamp { get; set; } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryStore.cs index 1846c946cf95..269624651ec5 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresMemoryStore.cs @@ -7,58 +7,61 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.SemanticKernel.AI.Embeddings; +using Microsoft.SemanticKernel.Diagnostics; using Microsoft.SemanticKernel.Memory; using Npgsql; using Pgvector; -using Pgvector.Npgsql; namespace Microsoft.SemanticKernel.Connectors.Memory.Postgres; /// /// An implementation of backed by a Postgres database with pgvector extension. 
/// -public class PostgresMemoryStore : IMemoryStore, IDisposable +/// The embedded data is saved to the Postgres database specified in the constructor. +/// Similarity search capability is provided through the pgvector extension. Use Postgres's "Table" to implement "Collection". +/// +public class PostgresMemoryStore : IMemoryStore { + internal const string DefaultSchema = "public"; + internal const int DefaultNumberOfLists = 1000; + /// - /// Connect a Postgres database + /// Initializes a new instance of the class. /// - /// Database connection string. If table does not exist, it will be created. - /// Embedding vector size - /// The to monitor for cancellation requests. The default is . - public static async Task ConnectAsync(string connectionString, int vectorSize, - CancellationToken cancellationToken = default) + /// Postgres data source. + /// Embedding vector size. + /// Database schema of collection tables. The default value is "public". + /// Specifies the number of lists for indexing. Higher values can improve recall but may impact performance. The default value is 1000. 
More info + public PostgresMemoryStore(NpgsqlDataSource dataSource, int vectorSize, string schema = DefaultSchema, int numberOfLists = DefaultNumberOfLists) + : this(new PostgresDbClient(dataSource, schema, vectorSize, numberOfLists)) + { + } + + public PostgresMemoryStore(IPostgresDbClient postgresDbClient) { - var dataSourceBuilder = new NpgsqlDataSourceBuilder(connectionString); - // Use pgvector - dataSourceBuilder.UseVector(); - - var memoryStore = new PostgresMemoryStore(dataSourceBuilder.Build()); - using NpgsqlConnection dbConnection = await memoryStore._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - await memoryStore._dbConnector.CreatePgVectorExtensionAsync(dbConnection, cancellationToken).ConfigureAwait(false); - await memoryStore._dbConnector.CreateTableAsync(dbConnection, vectorSize, cancellationToken).ConfigureAwait(false); - await memoryStore._dbConnector.CreateIndexAsync(dbConnection, cancellationToken).ConfigureAwait(false); - return memoryStore; + this._postgresDbClient = postgresDbClient; } /// public async Task CreateCollectionAsync(string collectionName, CancellationToken cancellationToken = default) { - using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - await this._dbConnector.CreateCollectionAsync(dbConnection, collectionName, cancellationToken).ConfigureAwait(false); + Verify.NotNullOrWhiteSpace(collectionName); + + await this._postgresDbClient.CreateCollectionAsync(collectionName, cancellationToken).ConfigureAwait(false); } /// public async Task DoesCollectionExistAsync(string collectionName, CancellationToken cancellationToken = default) { - using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - return await this._dbConnector.DoesCollectionExistsAsync(dbConnection, collectionName, cancellationToken).ConfigureAwait(false); + Verify.NotNullOrWhiteSpace(collectionName); + + 
return await this._postgresDbClient.DoesCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false); } /// public async IAsyncEnumerable GetCollectionsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) { - using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - await foreach (var collection in this._dbConnector.GetCollectionsAsync(dbConnection, cancellationToken).ConfigureAwait(false)) + await foreach (var collection in this._postgresDbClient.GetCollectionsAsync(cancellationToken).ConfigureAwait(false)) { yield return collection; } @@ -67,68 +70,71 @@ await foreach (var collection in this._dbConnector.GetCollectionsAsync(dbConnect /// public async Task DeleteCollectionAsync(string collectionName, CancellationToken cancellationToken = default) { - using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - await this._dbConnector.DeleteCollectionAsync(dbConnection, collectionName, cancellationToken).ConfigureAwait(false); + Verify.NotNullOrWhiteSpace(collectionName); + + await this._postgresDbClient.DeleteCollectionAsync(collectionName, cancellationToken).ConfigureAwait(false); } /// public async Task UpsertAsync(string collectionName, MemoryRecord record, CancellationToken cancellationToken = default) { - using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - return await this.InternalUpsertAsync(dbConnection, collectionName, record, cancellationToken).ConfigureAwait(false); + Verify.NotNullOrWhiteSpace(collectionName); + + return await this.InternalUpsertAsync(collectionName, record, cancellationToken).ConfigureAwait(false); } /// public async IAsyncEnumerable UpsertBatchAsync(string collectionName, IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - using NpgsqlConnection 
dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + Verify.NotNullOrWhiteSpace(collectionName); + foreach (var record in records) { - yield return await this.InternalUpsertAsync(dbConnection, collectionName, record, cancellationToken).ConfigureAwait(false); + yield return await this.InternalUpsertAsync(collectionName, record, cancellationToken).ConfigureAwait(false); } } /// public async Task GetAsync(string collectionName, string key, bool withEmbedding = false, CancellationToken cancellationToken = default) { - using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - return await this.InternalGetAsync(dbConnection, collectionName, key, withEmbedding, cancellationToken).ConfigureAwait(false); + Verify.NotNullOrWhiteSpace(collectionName); + + return await this.InternalGetAsync(collectionName, key, withEmbedding, cancellationToken).ConfigureAwait(false); } /// public async IAsyncEnumerable GetBatchAsync(string collectionName, IEnumerable keys, bool withEmbeddings = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + Verify.NotNullOrWhiteSpace(collectionName); + foreach (var key in keys) { - var result = await this.InternalGetAsync(dbConnection, collectionName, key, withEmbeddings, cancellationToken).ConfigureAwait(false); + var result = await this.InternalGetAsync(collectionName, key, withEmbeddings, cancellationToken).ConfigureAwait(false); if (result != null) { yield return result; } - else - { - yield break; - } } } /// public async Task RemoveAsync(string collectionName, string key, CancellationToken cancellationToken = default) { - using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - await 
this._dbConnector.DeleteAsync(dbConnection, collectionName, key, cancellationToken).ConfigureAwait(false); + Verify.NotNullOrWhiteSpace(collectionName); + + await this._postgresDbClient.DeleteAsync(collectionName, key, cancellationToken).ConfigureAwait(false); } /// public async Task RemoveBatchAsync(string collectionName, IEnumerable keys, CancellationToken cancellationToken = default) { - using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + Verify.NotNullOrWhiteSpace(collectionName); + foreach (var key in keys) { - await this._dbConnector.DeleteAsync(dbConnection, collectionName, key, cancellationToken).ConfigureAwait(false); + await this._postgresDbClient.DeleteAsync(collectionName, key, cancellationToken).ConfigureAwait(false); } } @@ -141,15 +147,14 @@ public async Task RemoveBatchAsync(string collectionName, IEnumerable ke bool withEmbeddings = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) { + Verify.NotNullOrWhiteSpace(collectionName); + if (limit <= 0) { yield break; } - using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - - IAsyncEnumerable<(DatabaseEntry, double)> results = this._dbConnector.GetNearestMatchesAsync( - dbConnection, + IAsyncEnumerable<(PostgresMemoryEntry, double)> results = this._postgresDbClient.GetNearestMatchesAsync( collectionName: collectionName, embeddingFilter: new Vector(embedding.Vector.ToArray()), limit: limit, @@ -161,7 +166,7 @@ await foreach (var (entry, cosineSimilarity) in results.ConfigureAwait(false)) { MemoryRecord record = MemoryRecord.FromJsonMetadata( json: entry.MetadataString, - withEmbeddings && entry.Embedding != null ? 
new Embedding(entry.Embedding!.ToArray()) : Embedding.Empty, + this.GetEmbeddingForEntry(entry), entry.Key, ParseTimestamp(entry.Timestamp)); yield return (record, cosineSimilarity); @@ -181,46 +186,9 @@ await foreach (var (entry, cosineSimilarity) in results.ConfigureAwait(false)) cancellationToken: cancellationToken).FirstOrDefaultAsync(cancellationToken).ConfigureAwait(false); } - /// - public void Dispose() - { - this.Dispose(true); - GC.SuppressFinalize(this); - } - - #region protected ================================================================================ - - protected virtual void Dispose(bool disposing) - { - if (!this._disposedValue) - { - if (disposing) - { - this._dataSource.Dispose(); - } - - this._disposedValue = true; - } - } - - #endregion - #region private ================================================================================ - private readonly Database _dbConnector; - private readonly NpgsqlDataSource _dataSource; - private bool _disposedValue; - - /// - /// Constructor - /// - /// Postgres data source. - private PostgresMemoryStore(NpgsqlDataSource dataSource) - { - this._dataSource = dataSource; - this._dbConnector = new Database(); - this._disposedValue = false; - } + private readonly IPostgresDbClient _postgresDbClient; private static long? ToTimestampLong(DateTimeOffset? 
timestamp) { @@ -237,12 +205,11 @@ private PostgresMemoryStore(NpgsqlDataSource dataSource) return null; } - private async Task InternalUpsertAsync(NpgsqlConnection connection, string collectionName, MemoryRecord record, CancellationToken cancellationToken) + private async Task InternalUpsertAsync(string collectionName, MemoryRecord record, CancellationToken cancellationToken) { record.Key = record.Metadata.Id; - await this._dbConnector.UpsertAsync( - conn: connection, + await this._postgresDbClient.UpsertAsync( collectionName: collectionName, key: record.Key, metadata: record.GetSerializedMetadata(), @@ -253,27 +220,23 @@ private async Task InternalUpsertAsync(NpgsqlConnection connection, stri return record.Key; } - private async Task InternalGetAsync(NpgsqlConnection connection, string collectionName, string key, bool withEmbedding, CancellationToken cancellationToken) + private async Task InternalGetAsync(string collectionName, string key, bool withEmbedding, CancellationToken cancellationToken) { - DatabaseEntry? entry = await this._dbConnector.ReadAsync(connection, collectionName, key, withEmbedding, cancellationToken).ConfigureAwait(false); + PostgresMemoryEntry? entry = await this._postgresDbClient.ReadAsync(collectionName, key, withEmbedding, cancellationToken).ConfigureAwait(false); if (!entry.HasValue) { return null; } - if (withEmbedding) - { - return MemoryRecord.FromJsonMetadata( - json: entry.Value.MetadataString, - embedding: entry.Value.Embedding != null ? new Embedding(entry.Value.Embedding.ToArray()) : Embedding.Empty, - entry.Value.Key, - ParseTimestamp(entry.Value.Timestamp)); - } - return MemoryRecord.FromJsonMetadata( json: entry.Value.MetadataString, - Embedding.Empty, + embedding: this.GetEmbeddingForEntry(entry.Value), entry.Value.Key, ParseTimestamp(entry.Value.Timestamp)); } + private Embedding GetEmbeddingForEntry(PostgresMemoryEntry entry) + { + return entry.Embedding != null ? 
new Embedding(entry.Embedding!.ToArray()) : Embedding.Empty; + } + #endregion } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md b/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md index 6685e3e0fd77..4ba3424224ae 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md @@ -8,6 +8,10 @@ This connector uses Postgres to implement Semantic Memory. It requires the [pgve How to install the pgvector extension, please refer to its [documentation](https://github.com/pgvector/pgvector#installation). +This extension is also available for **Azure Database for PostgreSQL - Flexible Server** and **Azure Cosmos DB for PostgreSQL**. +- [Azure Database for Postgres](https://learn.microsoft.com/en-us/azure/postgresql/flexible-server/how-to-use-pgvector) +- [Azure Cosmos DB for PostgreSQL](https://learn.microsoft.com/en-us/azure/cosmos-db/postgresql/howto-use-pgvector) + ## Quick start 1. To install pgvector using Docker: @@ -16,15 +20,84 @@ How to install the pgvector extension, please refer to its [documentation](https docker run -d --name postgres-pgvector -p 5432:5432 -e POSTGRES_PASSWORD=mysecretpassword ankane/pgvector ``` -2. To use Postgres as a semantic memory store: +2. Create a database and enable pgvector extension on this database + +```bash +docker exec -it postgres-pgvector psql -U postgres + +postgres=# CREATE DATABASE sk_demo; +postgres=# \c sk_demo +sk_demo=# CREATE EXTENSION vector; +``` + +> Note, "Azure Cosmos DB for PostgreSQL" uses `SELECT CREATE_EXTENSION('vector');` to enable the extension. + +3. 
To use Postgres as a semantic memory store: ```csharp -using PostgresMemoryStore memoryStore = await PostgresMemoryStore.ConnectAsync("Host=localhost;Port=5432;Database=sk_memory;User Id=postgres;Password=mysecretpassword", vectorSize: 1536); +NpgsqlDataSourceBuilder dataSourceBuilder = new NpgsqlDataSourceBuilder("Host=localhost;Port=5432;Database=sk_demo;User Id=postgres;Password=mysecretpassword"); +dataSourceBuilder.UseVector(); +NpgsqlDataSource dataSource = dataSourceBuilder.Build(); + +PostgresMemoryStore memoryStore = new PostgresMemoryStore(dataSource, vectorSize: 1536/*, schema: "public", numberOfLists: 1000 */); IKernel kernel = Kernel.Builder .WithLogger(ConsoleLogger.Log) - .Configure(c => c.AddOpenAITextEmbeddingGenerationService("text-embedding-ada-002", Env.Var("OPENAI_API_KEY"))) + .WithOpenAITextEmbeddingGenerationService("text-embedding-ada-002", Env.Var("OPENAI_API_KEY")) .WithMemoryStorage(memoryStore) + //.WithPostgresMemoryStore(dataSource, vectorSize: 1536, schema: "public", numberOfLists: 1000) // This method offers an alternative approach to registering Postgres memory store. .Build(); ``` +## Migration from older versions +The Postgres Memory connector has been re-implemented: the new implementation stores each collection in its own table. + +We provide the following migration script to help you migrate to the new structure. However, because collection names are now used as table names, make sure that all collection names conform to the [Postgres naming convention](https://www.postgresql.org/docs/15/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS) before migrating. + +- Table names may only consist of ASCII letters, digits, and underscores. +- Table names must start with a letter or an underscore. +- Table names may not exceed 63 characters in length. +- Table names are case-insensitive, but it is recommended to use lowercase letters.
+ +```sql +-- Create one new table per collection, named after the value of the collection column +DO $$ +DECLARE + r record; +BEGIN + FOR r IN SELECT DISTINCT collection FROM sk_memory_table LOOP + + -- Drop Table (This deletes the table if it already exists. Think carefully before uncommenting the next line!) + -- EXECUTE format('DROP TABLE IF EXISTS %I;', r.collection); + + -- Create Table (Adjust the vector size to match your embeddings) + EXECUTE format('CREATE TABLE public.%I ( + key TEXT NOT NULL, + metadata JSONB, + embedding vector(1536), + timestamp BIGINT, + PRIMARY KEY (key) + );', r.collection); + + -- Create Index (You can adjust the number of lists to suit your data. The connector's default value is 1000.) + EXECUTE format('CREATE INDEX %I + ON public.%I USING ivfflat (embedding vector_cosine_ops) WITH (lists = 1000);', + r.collection || '_ix', r.collection); + END LOOP; +END $$; + +-- Copy data from the old table to the new tables +DO $$ +DECLARE + r record; +BEGIN + FOR r IN SELECT DISTINCT collection FROM sk_memory_table LOOP + EXECUTE format('INSERT INTO public.%I (key, metadata, embedding, timestamp) + SELECT key, metadata::JSONB, embedding, timestamp + FROM sk_memory_table WHERE collection = %L AND key <> '''';', r.collection, r.collection); + END LOOP; +END $$; + +-- Drop old table (After verifying that the migration succeeded, you can uncomment the following line to drop sk_memory_table.)
+-- DROP TABLE IF EXISTS sk_memory_table; +``` diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj b/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj index 9fb6b4c67fa5..2174e8393a68 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.UnitTests/Connectors.UnitTests.csproj @@ -32,6 +32,7 @@ + diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresMemoryStoreTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresMemoryStoreTests.cs new file mode 100644 index 000000000000..581d623a0c54 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresMemoryStoreTests.cs @@ -0,0 +1,311 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.AI.Embeddings; +using Microsoft.SemanticKernel.Connectors.Memory.Postgres; +using Microsoft.SemanticKernel.Memory; +using Moq; +using Xunit; + +namespace SemanticKernel.Connectors.UnitTests.Memory.Postgres; + +/// +/// Unit tests for class. 
+/// +public class PostgresMemoryStoreTests +{ + private const string CollectionName = "fake-collection-name"; + + private readonly Mock _postgresDbClientMock; + + public PostgresMemoryStoreTests() + { + this._postgresDbClientMock = new Mock(); + this._postgresDbClientMock + .Setup(client => client.DoesCollectionExistsAsync(CollectionName, CancellationToken.None)) + .ReturnsAsync(true); + } + + [Fact] + public async Task ItCanCreateCollectionAsync() + { + // Arrange + var store = new PostgresMemoryStore(this._postgresDbClientMock.Object); + + // Act + await store.CreateCollectionAsync(CollectionName); + + // Assert + this._postgresDbClientMock.Verify(client => client.CreateCollectionAsync(CollectionName, CancellationToken.None), Times.Once()); + } + + [Fact] + public async Task ItCanDeleteCollectionAsync() + { + // Arrange + var store = new PostgresMemoryStore(this._postgresDbClientMock.Object); + + // Act + await store.DeleteCollectionAsync(CollectionName); + + // Assert + this._postgresDbClientMock.Verify(client => client.DeleteCollectionAsync(CollectionName, CancellationToken.None), Times.Once()); + } + + [Fact] + public async Task ItReturnsTrueWhenCollectionExistsAsync() + { + // Arrange + var store = new PostgresMemoryStore(this._postgresDbClientMock.Object); + + // Act + var doesCollectionExist = await store.DoesCollectionExistAsync(CollectionName); + + // Assert + Assert.True(doesCollectionExist); + } + + [Fact] + public async Task ItReturnsFalseWhenCollectionDoesNotExistAsync() + { + // Arrange + const string collectionName = "non-existent-collection"; + + this._postgresDbClientMock + .Setup(client => client.DoesCollectionExistsAsync(collectionName, CancellationToken.None)) + .ReturnsAsync(false); + + var store = new PostgresMemoryStore(this._postgresDbClientMock.Object); + + // Act + var doesCollectionExist = await store.DoesCollectionExistAsync(collectionName); + + // Assert + Assert.False(doesCollectionExist); + } + + [Fact] + public async Task 
ItCanUpsertAsync() + { + // Arrange + var expectedMemoryRecord = this.GetRandomMemoryRecord(); + var postgresMemoryEntry = this.GetPostgresMemoryEntryFromMemoryRecord(expectedMemoryRecord)!; + + var store = new PostgresMemoryStore(this._postgresDbClientMock.Object); + + // Act + var actualMemoryRecordKey = await store.UpsertAsync(CollectionName, expectedMemoryRecord); + + // Assert + this._postgresDbClientMock.Verify(client => client.UpsertAsync(CollectionName, postgresMemoryEntry.Key, postgresMemoryEntry.MetadataString, It.Is(x => x.ToArray().SequenceEqual(postgresMemoryEntry.Embedding!.ToArray())), postgresMemoryEntry.Timestamp, CancellationToken.None), Times.Once()); + Assert.Equal(expectedMemoryRecord.Key, actualMemoryRecordKey); + } + + [Fact] + public async Task ItCanUpsertBatchAsyncAsync() + { + // Arrange + var memoryRecord1 = this.GetRandomMemoryRecord(); + var memoryRecord2 = this.GetRandomMemoryRecord(); + var memoryRecord3 = this.GetRandomMemoryRecord(); + + var batchUpsertMemoryRecords = new[] { memoryRecord1, memoryRecord2, memoryRecord3 }; + var expectedMemoryRecordKeys = batchUpsertMemoryRecords.Select(l => l.Key).ToList(); + + var store = new PostgresMemoryStore(this._postgresDbClientMock.Object); + + // Act + var actualMemoryRecordKeys = await store.UpsertBatchAsync(CollectionName, batchUpsertMemoryRecords).ToListAsync(); + + // Assert + foreach (var memoryRecord in batchUpsertMemoryRecords) + { + var postgresMemoryEntry = this.GetPostgresMemoryEntryFromMemoryRecord(memoryRecord)!; + this._postgresDbClientMock.Verify(client => client.UpsertAsync(CollectionName, postgresMemoryEntry.Key, postgresMemoryEntry.MetadataString, It.Is(x => x.ToArray().SequenceEqual(postgresMemoryEntry.Embedding!.ToArray())), postgresMemoryEntry.Timestamp, CancellationToken.None), Times.Once()); + } + + for (int i = 0; i < expectedMemoryRecordKeys.Count; i++) + { + Assert.Equal(expectedMemoryRecordKeys[i], actualMemoryRecordKeys[i]); + } + } + + [Fact] + public async Task 
ItCanGetMemoryRecordFromCollectionAsync() + { + // Arrange + var expectedMemoryRecord = this.GetRandomMemoryRecord(); + var postgresMemoryEntry = this.GetPostgresMemoryEntryFromMemoryRecord(expectedMemoryRecord); + + this._postgresDbClientMock + .Setup(client => client.ReadAsync(CollectionName, expectedMemoryRecord.Key, true, CancellationToken.None)) + .ReturnsAsync(postgresMemoryEntry); + + var store = new PostgresMemoryStore(this._postgresDbClientMock.Object); + + // Act + var actualMemoryRecord = await store.GetAsync(CollectionName, expectedMemoryRecord.Key, withEmbedding: true); + + // Assert + Assert.NotNull(actualMemoryRecord); + this.AssertMemoryRecordEqual(expectedMemoryRecord, actualMemoryRecord); + } + + [Fact] + public async Task ItReturnsNullWhenMemoryRecordDoesNotExistAsync() + { + // Arrange + const string memoryRecordKey = "fake-record-key"; + + this._postgresDbClientMock + .Setup(client => client.ReadAsync(CollectionName, memoryRecordKey, true, CancellationToken.None)) + .ReturnsAsync((PostgresMemoryEntry?)null); + + var store = new PostgresMemoryStore(this._postgresDbClientMock.Object); + + // Act + var actualMemoryRecord = await store.GetAsync(CollectionName, memoryRecordKey, withEmbedding: true); + + // Assert + Assert.Null(actualMemoryRecord); + } + + [Fact] + public async Task ItCanGetMemoryRecordBatchFromCollectionAsync() + { + // Arrange + var memoryRecord1 = this.GetRandomMemoryRecord(); + var memoryRecord2 = this.GetRandomMemoryRecord(); + var memoryRecord3 = this.GetRandomMemoryRecord(); + + var expectedMemoryRecords = new[] { memoryRecord1, memoryRecord2, memoryRecord3 }; + var memoryRecordKeys = expectedMemoryRecords.Select(l => l.Key).ToList(); + + foreach (var memoryRecord in expectedMemoryRecords) + { + this._postgresDbClientMock + .Setup(client => client.ReadAsync(CollectionName, memoryRecord.Key, true, CancellationToken.None)) + .ReturnsAsync(this.GetPostgresMemoryEntryFromMemoryRecord(memoryRecord)); + } + + var 
doesNotExistMemoryKey = "fake-record-key"; + this._postgresDbClientMock + .Setup(client => client.ReadAsync(CollectionName, doesNotExistMemoryKey, true, CancellationToken.None)) + .ReturnsAsync((PostgresMemoryEntry?)null); + + memoryRecordKeys.Add(doesNotExistMemoryKey); + + var store = new PostgresMemoryStore(this._postgresDbClientMock.Object); + + // Act + var actualMemoryRecords = await store.GetBatchAsync(CollectionName, memoryRecordKeys, withEmbeddings: true).ToListAsync(); + + // Assert + Assert.Equal(expectedMemoryRecords.Length, actualMemoryRecords.Count); + + for (var i = 0; i < expectedMemoryRecords.Length; i++) + { + this.AssertMemoryRecordEqual(expectedMemoryRecords[i], actualMemoryRecords[i]); + } + } + + [Fact] + public async Task ItCanReturnCollectionsAsync() + { + // Arrange + var expectedCollections = new List { "fake-collection-1", "fake-collection-2", "fake-collection-3" }; + + this._postgresDbClientMock + .Setup(client => client.GetCollectionsAsync(CancellationToken.None)) + .Returns(expectedCollections.ToAsyncEnumerable()); + + var store = new PostgresMemoryStore(this._postgresDbClientMock.Object); + + // Act + var actualCollections = await store.GetCollectionsAsync().ToListAsync(); + + // Assert + Assert.Equal(expectedCollections.Count, actualCollections.Count); + + for (var i = 0; i < expectedCollections.Count; i++) + { + Assert.Equal(expectedCollections[i], actualCollections[i]); + } + } + + [Fact] + public async Task ItCanRemoveAsync() + { + // Arrange + const string memoryRecordKey = "fake-record-key"; + var store = new PostgresMemoryStore(this._postgresDbClientMock.Object); + + // Act + await store.RemoveAsync(CollectionName, memoryRecordKey); + + // Assert + this._postgresDbClientMock.Verify(client => client.DeleteAsync(CollectionName, memoryRecordKey, CancellationToken.None), Times.Once()); + } + + [Fact] + public async Task ItCanRemoveBatchAsync() + { + // Arrange + string[] memoryRecordKeys = new string[] { "fake-record-key1", 
"fake-record-key2", "fake-record-key3" }; + var store = new PostgresMemoryStore(this._postgresDbClientMock.Object); + + // Act + await store.RemoveBatchAsync(CollectionName, memoryRecordKeys); + + // Assert + foreach (var memoryRecordKey in memoryRecordKeys) + { + this._postgresDbClientMock.Verify(client => client.DeleteAsync(CollectionName, memoryRecordKey, CancellationToken.None), Times.Once()); + } + } + + #region private ================================================================================ + + private void AssertMemoryRecordEqual(MemoryRecord expectedRecord, MemoryRecord actualRecord) + { + Assert.Equal(expectedRecord.Key, actualRecord.Key); + Assert.Equal(expectedRecord.Embedding.Vector, actualRecord.Embedding.Vector); + Assert.Equal(expectedRecord.Metadata.Id, actualRecord.Metadata.Id); + Assert.Equal(expectedRecord.Metadata.Text, actualRecord.Metadata.Text); + Assert.Equal(expectedRecord.Metadata.Description, actualRecord.Metadata.Description); + Assert.Equal(expectedRecord.Metadata.AdditionalMetadata, actualRecord.Metadata.AdditionalMetadata); + Assert.Equal(expectedRecord.Metadata.IsReference, actualRecord.Metadata.IsReference); + Assert.Equal(expectedRecord.Metadata.ExternalSourceName, actualRecord.Metadata.ExternalSourceName); + } + + private MemoryRecord GetRandomMemoryRecord(Embedding? embedding = null) + { + var id = Guid.NewGuid().ToString(); + var memoryEmbedding = embedding ?? 
new Embedding(new[] { 1f, 3f, 5f }); + + return MemoryRecord.LocalRecord( + id: id, + text: "text-" + Guid.NewGuid().ToString(), + description: "description-" + Guid.NewGuid().ToString(), + embedding: memoryEmbedding, + additionalMetadata: "metadata-" + Guid.NewGuid().ToString(), + key: id); + } + + private PostgresMemoryEntry GetPostgresMemoryEntryFromMemoryRecord(MemoryRecord memoryRecord) + { + return new PostgresMemoryEntry() + { + Key = memoryRecord.Key, + Embedding = new Pgvector.Vector(memoryRecord.Embedding.Vector.ToArray()), + MetadataString = memoryRecord.GetSerializedMetadata(), + Timestamp = memoryRecord.Timestamp?.ToUnixTimeMilliseconds() + }; + } + + #endregion +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs index f5a0c680846b..dabea03027c8 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs @@ -3,13 +3,14 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; -using System.Globalization; using System.Linq; using System.Threading.Tasks; +using Microsoft.Extensions.Configuration; using Microsoft.SemanticKernel.AI.Embeddings; using Microsoft.SemanticKernel.Connectors.Memory.Postgres; using Microsoft.SemanticKernel.Memory; using Npgsql; +using Pgvector.Npgsql; using Xunit; namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; @@ -17,216 +18,109 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; /// /// Integration tests of . /// -public class PostgresMemoryStoreTests : IDisposable +public class PostgresMemoryStoreTests : IAsyncLifetime { - // Set null enable tests - private const string SkipOrNot = "Required posgres with pgvector extension"; + // If null, all tests will be enabled + private const string? 
SkipReason = "Required postgres with pgvector extension"; - private const string ConnectionString = "Host=localhost;Database={0};User Id=postgres"; - private readonly string _databaseName; - - private bool _disposedValue = false; - - public PostgresMemoryStoreTests() + public async Task InitializeAsync() { -#pragma warning disable CA5394 - this._databaseName = $"sk_pgvector_dotnet_it_{Random.Shared.Next(0, 1000)}"; -#pragma warning restore CA5394 - } + // Load configuration + var configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); - public void Dispose() - { - // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method - this.Dispose(disposing: true); - GC.SuppressFinalize(this); - } + var connectionString = configuration["Postgres:ConnectionString"]; - protected virtual void Dispose(bool disposing) - { - if (!this._disposedValue) + if (string.IsNullOrWhiteSpace(connectionString)) { - if (disposing) - { - using NpgsqlConnection conn = new(string.Format(CultureInfo.CurrentCulture, ConnectionString, "postgres")); - conn.Open(); -#pragma warning disable CA2100 // Review SQL queries for security vulnerabilities - using NpgsqlCommand command = new($"DROP DATABASE IF EXISTS \"{this._databaseName}\"", conn); -#pragma warning restore CA2100 // Review SQL queries for security vulnerabilities - command.ExecuteNonQuery(); - } - - this._disposedValue = true; + throw new ArgumentNullException("Postgres memory connection string is not configured"); } - } - private int _collectionNum = 0; + this._connectionString = connectionString; + this._databaseName = $"sk_it_{Guid.NewGuid():N}"; - private async Task TryCreateDatabaseAsync() - { - using NpgsqlConnection conn = new(string.Format(CultureInfo.CurrentCulture, ConnectionString, 
"postgres")); - await conn.OpenAsync(); - using NpgsqlCommand checkCmd = new("SELECT COUNT(*) FROM pg_database WHERE datname = @databaseName", conn); - checkCmd.Parameters.AddWithValue("@databaseName", this._databaseName); + NpgsqlConnectionStringBuilder connectionStringBuilder = new(this._connectionString); + connectionStringBuilder.Database = this._databaseName; - var count = (long?)await checkCmd.ExecuteScalarAsync(); - if (count == 0) - { -#pragma warning disable CA2100 // Review SQL queries for security vulnerabilities - using var createCmd = new NpgsqlCommand($"CREATE DATABASE \"{this._databaseName}\"", conn); -#pragma warning restore CA2100 // Review SQL queries for security vulnerabilities - await createCmd.ExecuteNonQueryAsync(); - } - } + NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionStringBuilder.ToString()); + dataSourceBuilder.UseVector(); - private async Task CreateMemoryStoreAsync() - { - await this.TryCreateDatabaseAsync(); - return await PostgresMemoryStore.ConnectAsync(string.Format(CultureInfo.CurrentCulture, ConnectionString, this._databaseName), vectorSize: 3); + this._dataSource = dataSourceBuilder.Build(); + + await this.CreateDatabaseAsync(); } - private IEnumerable CreateBatchRecords(int numRecords) + public async Task DisposeAsync() { - Assert.True(numRecords % 2 == 0, "Number of records must be even"); - Assert.True(numRecords > 0, "Number of records must be greater than 0"); - - IEnumerable records = new List(numRecords); - for (int i = 0; i < numRecords / 2; i++) - { - var testRecord = MemoryRecord.LocalRecord( - id: "test" + i, - text: "text" + i, - description: "description" + i, - embedding: new Embedding(new float[] { 1, 1, 1 })); - records = records.Append(testRecord); - } - - for (int i = numRecords / 2; i < numRecords; i++) - { - var testRecord = MemoryRecord.ReferenceRecord( - externalId: "test" + i, - sourceName: "sourceName" + i, - description: "description" + i, - embedding: new Embedding(new float[] { 1, 2, 3 
})); - records = records.Append(testRecord); - } + await this._dataSource.DisposeAsync(); - return records; + await this.DropDatabaseAsync(); } - [Fact(Skip = SkipOrNot)] - public async Task InitializeDbConnectionSucceedsAsync() + [Fact(Skip = SkipReason)] + public void InitializeDbConnectionSucceeds() { - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); // Assert - Assert.NotNull(db); + Assert.NotNull(memoryStore); } - [Fact(Skip = SkipOrNot)] + [Fact(Skip = SkipReason)] public async Task ItCanCreateAndGetCollectionAsync() { // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); - string collection = "test_collection" + this._collectionNum; - this._collectionNum++; + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); + string collection = "test_collection"; // Act - await db.CreateCollectionAsync(collection); - var collections = db.GetCollectionsAsync(); + await memoryStore.CreateCollectionAsync(collection); + var collections = memoryStore.GetCollectionsAsync(); // Assert Assert.NotEmpty(collections.ToEnumerable()); Assert.True(await collections.ContainsAsync(collection)); } - [Fact(Skip = SkipOrNot)] + [Fact(Skip = SkipReason)] public async Task ItCanCheckIfCollectionExistsAsync() { // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); string collection = "my_collection"; - this._collectionNum++; - - // Act - await db.CreateCollectionAsync(collection); - - // Assert - Assert.True(await db.DoesCollectionExistAsync("my_collection")); - Assert.False(await db.DoesCollectionExistAsync("my_collection2")); - } - - [Fact(Skip = SkipOrNot)] - public async Task CreatingDuplicateCollectionDoesNothingAsync() - { - // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); - string collection = "test_collection" + this._collectionNum; - 
this._collectionNum++; // Act - await db.CreateCollectionAsync(collection); - var collections = db.GetCollectionsAsync(); - await db.CreateCollectionAsync(collection); + await memoryStore.CreateCollectionAsync(collection); // Assert - var collections2 = db.GetCollectionsAsync(); - Assert.Equal(await collections.CountAsync(), await collections.CountAsync()); + Assert.True(await memoryStore.DoesCollectionExistAsync("my_collection")); + Assert.False(await memoryStore.DoesCollectionExistAsync("my_collection2")); } - [Fact(Skip = SkipOrNot)] + [Fact(Skip = SkipReason)] public async Task CollectionsCanBeDeletedAsync() { // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); - string collection = "test_collection" + this._collectionNum; - this._collectionNum++; - await db.CreateCollectionAsync(collection); - var collections = await db.GetCollectionsAsync().ToListAsync(); - Assert.True(collections.Count > 0); + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); + string collection = "test_collection"; + await memoryStore.CreateCollectionAsync(collection); + Assert.True(await memoryStore.DoesCollectionExistAsync(collection)); // Act - foreach (var c in collections) - { - await db.DeleteCollectionAsync(c); - } + await memoryStore.DeleteCollectionAsync(collection); // Assert - var collections2 = db.GetCollectionsAsync(); - Assert.True(await collections2.CountAsync() == 0); + Assert.False(await memoryStore.DoesCollectionExistAsync(collection)); } - [Fact(Skip = SkipOrNot)] - public async Task ItCanInsertIntoNonExistentCollectionAsync() - { - // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); - MemoryRecord testRecord = MemoryRecord.LocalRecord( - id: "test", - text: "text", - description: "description", - embedding: new Embedding(new float[] { 1, 2, 3 }), - key: null, - timestamp: null); - - // Arrange - var key = await db.UpsertAsync("random collection", testRecord); - var actual = await db.GetAsync("random 
collection", key, true); - - // Assert - Assert.NotNull(actual); - Assert.Equal(testRecord.Metadata.Id, key); - Assert.Equal(testRecord.Metadata.Id, actual.Key); - Assert.Equal(testRecord.Embedding.Vector, actual.Embedding.Vector); - Assert.Equal(testRecord.Metadata.Text, actual.Metadata.Text); - Assert.Equal(testRecord.Metadata.Description, actual.Metadata.Description); - Assert.Equal(testRecord.Metadata.ExternalSourceName, actual.Metadata.ExternalSourceName); - Assert.Equal(testRecord.Metadata.Id, actual.Metadata.Id); - } - - [Fact(Skip = SkipOrNot)] + [Fact(Skip = SkipReason)] public async Task GetAsyncReturnsEmptyEmbeddingUnlessSpecifiedAsync() { // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); MemoryRecord testRecord = MemoryRecord.LocalRecord( id: "test", text: "text", @@ -234,14 +128,13 @@ public async Task GetAsyncReturnsEmptyEmbeddingUnlessSpecifiedAsync() embedding: new Embedding(new float[] { 1, 2, 3 }), key: null, timestamp: null); - string collection = "test_collection" + this._collectionNum; - this._collectionNum++; + string collection = "test_collection"; // Act - await db.CreateCollectionAsync(collection); - var key = await db.UpsertAsync(collection, testRecord); - var actualDefault = await db.GetAsync(collection, key); - var actualWithEmbedding = await db.GetAsync(collection, key, true); + await memoryStore.CreateCollectionAsync(collection); + var key = await memoryStore.UpsertAsync(collection, testRecord); + var actualDefault = await memoryStore.GetAsync(collection, key); + var actualWithEmbedding = await memoryStore.GetAsync(collection, key, true); // Assert Assert.NotNull(actualDefault); @@ -250,11 +143,11 @@ public async Task GetAsyncReturnsEmptyEmbeddingUnlessSpecifiedAsync() Assert.NotEmpty(actualWithEmbedding.Embedding.Vector); } - [Fact(Skip = SkipOrNot)] + [Fact(Skip = SkipReason)] public async Task 
ItCanUpsertAndRetrieveARecordWithNoTimestampAsync() { // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); MemoryRecord testRecord = MemoryRecord.LocalRecord( id: "test", text: "text", @@ -262,13 +155,12 @@ public async Task ItCanUpsertAndRetrieveARecordWithNoTimestampAsync() embedding: new Embedding(new float[] { 1, 2, 3 }), key: null, timestamp: null); - string collection = "test_collection" + this._collectionNum; - this._collectionNum++; + string collection = "test_collection"; // Act - await db.CreateCollectionAsync(collection); - var key = await db.UpsertAsync(collection, testRecord); - var actual = await db.GetAsync(collection, key, true); + await memoryStore.CreateCollectionAsync(collection); + var key = await memoryStore.UpsertAsync(collection, testRecord); + var actual = await memoryStore.GetAsync(collection, key, true); // Assert Assert.NotNull(actual); @@ -281,25 +173,24 @@ public async Task ItCanUpsertAndRetrieveARecordWithNoTimestampAsync() Assert.Equal(testRecord.Metadata.Id, actual.Metadata.Id); } - [Fact(Skip = SkipOrNot)] + [Fact(Skip = SkipReason)] public async Task ItCanUpsertAndRetrieveARecordWithTimestampAsync() { // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); MemoryRecord testRecord = MemoryRecord.LocalRecord( id: "test", text: "text", description: "description", embedding: new Embedding(new float[] { 1, 2, 3 }), key: null, - timestamp: DateTimeOffset.UtcNow); - string collection = "test_collection" + this._collectionNum; - this._collectionNum++; + timestamp: DateTimeOffset.FromUnixTimeMilliseconds(DateTimeOffset.UtcNow.ToUnixTimeMilliseconds())); + string collection = "test_collection"; // Act - await db.CreateCollectionAsync(collection); - var key = await db.UpsertAsync(collection, testRecord); - var actual = await db.GetAsync(collection, key, true); + await 
memoryStore.CreateCollectionAsync(collection); + var key = await memoryStore.UpsertAsync(collection, testRecord); + var actual = await memoryStore.GetAsync(collection, key, true); // Assert Assert.NotNull(actual); @@ -310,13 +201,14 @@ public async Task ItCanUpsertAndRetrieveARecordWithTimestampAsync() Assert.Equal(testRecord.Metadata.Description, actual.Metadata.Description); Assert.Equal(testRecord.Metadata.ExternalSourceName, actual.Metadata.ExternalSourceName); Assert.Equal(testRecord.Metadata.Id, actual.Metadata.Id); + Assert.Equal(testRecord.Timestamp, actual.Timestamp); } - [Fact(Skip = SkipOrNot)] + [Fact(Skip = SkipReason)] public async Task UpsertReplacesExistingRecordWithSameIdAsync() { // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); string commonId = "test"; MemoryRecord testRecord = MemoryRecord.LocalRecord( id: commonId, @@ -328,14 +220,13 @@ public async Task UpsertReplacesExistingRecordWithSameIdAsync() text: "text2", description: "description2", embedding: new Embedding(new float[] { 1, 2, 4 })); - string collection = "test_collection" + this._collectionNum; - this._collectionNum++; + string collection = "test_collection"; // Act - await db.CreateCollectionAsync(collection); - var key = await db.UpsertAsync(collection, testRecord); - var key2 = await db.UpsertAsync(collection, testRecord2); - var actual = await db.GetAsync(collection, key, true); + await memoryStore.CreateCollectionAsync(collection); + var key = await memoryStore.UpsertAsync(collection, testRecord); + var key2 = await memoryStore.UpsertAsync(collection, testRecord2); + var actual = await memoryStore.GetAsync(collection, key, true); // Assert Assert.NotNull(actual); @@ -347,64 +238,63 @@ public async Task UpsertReplacesExistingRecordWithSameIdAsync() Assert.Equal(testRecord2.Metadata.Description, actual.Metadata.Description); } - [Fact(Skip = SkipOrNot)] + [Fact(Skip = SkipReason)] public 
async Task ExistingRecordCanBeRemovedAsync() { // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); MemoryRecord testRecord = MemoryRecord.LocalRecord( id: "test", text: "text", description: "description", embedding: new Embedding(new float[] { 1, 2, 3 })); - string collection = "test_collection" + this._collectionNum; - this._collectionNum++; + string collection = "test_collection"; // Act - await db.CreateCollectionAsync(collection); - var key = await db.UpsertAsync(collection, testRecord); - await db.RemoveAsync(collection, key); - var actual = await db.GetAsync(collection, key); + await memoryStore.CreateCollectionAsync(collection); + var key = await memoryStore.UpsertAsync(collection, testRecord); + var upsertedRecord = await memoryStore.GetAsync(collection, key); + await memoryStore.RemoveAsync(collection, key); + var actual = await memoryStore.GetAsync(collection, key); // Assert + Assert.NotNull(upsertedRecord); Assert.Null(actual); } - [Fact(Skip = SkipOrNot)] + [Fact(Skip = SkipReason)] public async Task RemovingNonExistingRecordDoesNothingAsync() { // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); - string collection = "test_collection" + this._collectionNum; - this._collectionNum++; + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); + string collection = "test_collection"; // Act - await db.CreateCollectionAsync(collection); - await db.RemoveAsync(collection, "key"); - var actual = await db.GetAsync(collection, "key"); + await memoryStore.CreateCollectionAsync(collection); + await memoryStore.RemoveAsync(collection, "key"); + var actual = await memoryStore.GetAsync(collection, "key"); // Assert Assert.Null(actual); } - [Fact(Skip = SkipOrNot)] + [Fact(Skip = SkipReason)] public async Task ItCanListAllDatabaseCollectionsAsync() { // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); + 
PostgresMemoryStore memoryStore = this.CreateMemoryStore(); string[] testCollections = { "random_collection1", "random_collection2", "random_collection3" }; - this._collectionNum += 3; - await db.CreateCollectionAsync(testCollections[0]); - await db.CreateCollectionAsync(testCollections[1]); - await db.CreateCollectionAsync(testCollections[2]); + await memoryStore.CreateCollectionAsync(testCollections[0]); + await memoryStore.CreateCollectionAsync(testCollections[1]); + await memoryStore.CreateCollectionAsync(testCollections[2]); // Act - var collections = await db.GetCollectionsAsync().ToListAsync(); + var collections = await memoryStore.GetCollectionsAsync().ToListAsync(); // Assert foreach (var collection in testCollections) { - Assert.True(await db.DoesCollectionExistAsync(collection)); + Assert.True(await memoryStore.DoesCollectionExistAsync(collection)); } Assert.NotNull(collections); @@ -418,23 +308,22 @@ public async Task ItCanListAllDatabaseCollectionsAsync() $"Collections does not contain the newly-created collection {testCollections[2]}"); } - [Fact(Skip = SkipOrNot)] + [Fact(Skip = SkipReason)] public async Task GetNearestMatchesReturnsAllResultsWithNoMinScoreAsync() { // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); var compareEmbedding = new Embedding(new float[] { 1, 1, 1 }); int topN = 4; - string collection = "test_collection" + this._collectionNum; - this._collectionNum++; - await db.CreateCollectionAsync(collection); + string collection = "test_collection"; + await memoryStore.CreateCollectionAsync(collection); int i = 0; MemoryRecord testRecord = MemoryRecord.LocalRecord( id: "test" + i, text: "text" + i, description: "description" + i, embedding: new Embedding(new float[] { 1, 1, 1 })); - _ = await db.UpsertAsync(collection, testRecord); + _ = await memoryStore.UpsertAsync(collection, testRecord); i++; testRecord = MemoryRecord.LocalRecord( @@ -442,7 
+331,7 @@ public async Task GetNearestMatchesReturnsAllResultsWithNoMinScoreAsync() text: "text" + i, description: "description" + i, embedding: new Embedding(new float[] { -1, -1, -1 })); - _ = await db.UpsertAsync(collection, testRecord); + _ = await memoryStore.UpsertAsync(collection, testRecord); i++; testRecord = MemoryRecord.LocalRecord( @@ -450,7 +339,7 @@ public async Task GetNearestMatchesReturnsAllResultsWithNoMinScoreAsync() text: "text" + i, description: "description" + i, embedding: new Embedding(new float[] { 1, 2, 3 })); - _ = await db.UpsertAsync(collection, testRecord); + _ = await memoryStore.UpsertAsync(collection, testRecord); i++; testRecord = MemoryRecord.LocalRecord( @@ -458,7 +347,7 @@ public async Task GetNearestMatchesReturnsAllResultsWithNoMinScoreAsync() text: "text" + i, description: "description" + i, embedding: new Embedding(new float[] { -1, -2, -3 })); - _ = await db.UpsertAsync(collection, testRecord); + _ = await memoryStore.UpsertAsync(collection, testRecord); i++; testRecord = MemoryRecord.LocalRecord( @@ -466,11 +355,11 @@ public async Task GetNearestMatchesReturnsAllResultsWithNoMinScoreAsync() text: "text" + i, description: "description" + i, embedding: new Embedding(new float[] { 1, -1, -2 })); - _ = await db.UpsertAsync(collection, testRecord); + _ = await memoryStore.UpsertAsync(collection, testRecord); // Act double threshold = -1; - var topNResults = db.GetNearestMatchesAsync(collection, compareEmbedding, limit: topN, minRelevanceScore: threshold).ToEnumerable().ToArray(); + var topNResults = memoryStore.GetNearestMatchesAsync(collection, compareEmbedding, limit: topN, minRelevanceScore: threshold).ToEnumerable().ToArray(); // Assert Assert.Equal(topN, topNResults.Length); @@ -481,22 +370,21 @@ public async Task GetNearestMatchesReturnsAllResultsWithNoMinScoreAsync() } } - [Fact(Skip = SkipOrNot)] + [Fact(Skip = SkipReason)] public async Task GetNearestMatchAsyncReturnsEmptyEmbeddingUnlessSpecifiedAsync() { // Arrange - 
using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); var compareEmbedding = new Embedding(new float[] { 1, 1, 1 }); - string collection = "test_collection" + this._collectionNum; - this._collectionNum++; - await db.CreateCollectionAsync(collection); + string collection = "test_collection"; + await memoryStore.CreateCollectionAsync(collection); int i = 0; MemoryRecord testRecord = MemoryRecord.LocalRecord( id: "test" + i, text: "text" + i, description: "description" + i, embedding: new Embedding(new float[] { 1, 1, 1 })); - _ = await db.UpsertAsync(collection, testRecord); + _ = await memoryStore.UpsertAsync(collection, testRecord); i++; testRecord = MemoryRecord.LocalRecord( @@ -504,7 +392,7 @@ public async Task GetNearestMatchAsyncReturnsEmptyEmbeddingUnlessSpecifiedAsync( text: "text" + i, description: "description" + i, embedding: new Embedding(new float[] { -1, -1, -1 })); - _ = await db.UpsertAsync(collection, testRecord); + _ = await memoryStore.UpsertAsync(collection, testRecord); i++; testRecord = MemoryRecord.LocalRecord( @@ -512,7 +400,7 @@ public async Task GetNearestMatchAsyncReturnsEmptyEmbeddingUnlessSpecifiedAsync( text: "text" + i, description: "description" + i, embedding: new Embedding(new float[] { 1, 2, 3 })); - _ = await db.UpsertAsync(collection, testRecord); + _ = await memoryStore.UpsertAsync(collection, testRecord); i++; testRecord = MemoryRecord.LocalRecord( @@ -520,7 +408,7 @@ public async Task GetNearestMatchAsyncReturnsEmptyEmbeddingUnlessSpecifiedAsync( text: "text" + i, description: "description" + i, embedding: new Embedding(new float[] { -1, -2, -3 })); - _ = await db.UpsertAsync(collection, testRecord); + _ = await memoryStore.UpsertAsync(collection, testRecord); i++; testRecord = MemoryRecord.LocalRecord( @@ -528,12 +416,12 @@ public async Task GetNearestMatchAsyncReturnsEmptyEmbeddingUnlessSpecifiedAsync( text: "text" + i, description: "description" + i, 
embedding: new Embedding(new float[] { 1, -1, -2 })); - _ = await db.UpsertAsync(collection, testRecord); + _ = await memoryStore.UpsertAsync(collection, testRecord); // Act double threshold = 0.75; - var topNResultDefault = await db.GetNearestMatchAsync(collection, compareEmbedding, minRelevanceScore: threshold); - var topNResultWithEmbedding = await db.GetNearestMatchAsync(collection, compareEmbedding, minRelevanceScore: threshold, withEmbedding: true); + var topNResultDefault = await memoryStore.GetNearestMatchAsync(collection, compareEmbedding, minRelevanceScore: threshold); + var topNResultWithEmbedding = await memoryStore.GetNearestMatchAsync(collection, compareEmbedding, minRelevanceScore: threshold, withEmbedding: true); // Assert Assert.NotNull(topNResultDefault); @@ -542,22 +430,21 @@ public async Task GetNearestMatchAsyncReturnsEmptyEmbeddingUnlessSpecifiedAsync( Assert.NotEmpty(topNResultWithEmbedding.Value.Item1.Embedding.Vector); } - [Fact(Skip = SkipOrNot)] + [Fact(Skip = SkipReason)] public async Task GetNearestMatchAsyncReturnsExpectedAsync() { // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); var compareEmbedding = new Embedding(new float[] { 1, 1, 1 }); - string collection = "test_collection" + this._collectionNum; - this._collectionNum++; - await db.CreateCollectionAsync(collection); + string collection = "test_collection"; + await memoryStore.CreateCollectionAsync(collection); int i = 0; MemoryRecord testRecord = MemoryRecord.LocalRecord( id: "test" + i, text: "text" + i, description: "description" + i, embedding: new Embedding(new float[] { 1, 1, 1 })); - _ = await db.UpsertAsync(collection, testRecord); + _ = await memoryStore.UpsertAsync(collection, testRecord); i++; testRecord = MemoryRecord.LocalRecord( @@ -565,7 +452,7 @@ public async Task GetNearestMatchAsyncReturnsExpectedAsync() text: "text" + i, description: "description" + i, embedding: new 
Embedding(new float[] { -1, -1, -1 })); - _ = await db.UpsertAsync(collection, testRecord); + _ = await memoryStore.UpsertAsync(collection, testRecord); i++; testRecord = MemoryRecord.LocalRecord( @@ -573,7 +460,7 @@ public async Task GetNearestMatchAsyncReturnsExpectedAsync() text: "text" + i, description: "description" + i, embedding: new Embedding(new float[] { 1, 2, 3 })); - _ = await db.UpsertAsync(collection, testRecord); + _ = await memoryStore.UpsertAsync(collection, testRecord); i++; testRecord = MemoryRecord.LocalRecord( @@ -581,7 +468,7 @@ public async Task GetNearestMatchAsyncReturnsExpectedAsync() text: "text" + i, description: "description" + i, embedding: new Embedding(new float[] { -1, -2, -3 })); - _ = await db.UpsertAsync(collection, testRecord); + _ = await memoryStore.UpsertAsync(collection, testRecord); i++; testRecord = MemoryRecord.LocalRecord( @@ -589,11 +476,11 @@ public async Task GetNearestMatchAsyncReturnsExpectedAsync() text: "text" + i, description: "description" + i, embedding: new Embedding(new float[] { 1, -1, -2 })); - _ = await db.UpsertAsync(collection, testRecord); + _ = await memoryStore.UpsertAsync(collection, testRecord); // Act double threshold = 0.75; - var topNResult = await db.GetNearestMatchAsync(collection, compareEmbedding, minRelevanceScore: threshold); + var topNResult = await memoryStore.GetNearestMatchAsync(collection, compareEmbedding, minRelevanceScore: threshold); // Assert Assert.NotNull(topNResult); @@ -601,16 +488,15 @@ public async Task GetNearestMatchAsyncReturnsExpectedAsync() Assert.True(topNResult.Value.Item2 >= threshold); } - [Fact(Skip = SkipOrNot)] + [Fact(Skip = SkipReason)] public async Task GetNearestMatchesDifferentiatesIdenticalVectorsByKeyAsync() { // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); var compareEmbedding = new Embedding(new float[] { 1, 1, 1 }); int topN = 4; - string collection = 
"test_collection" + this._collectionNum; - this._collectionNum++; - await db.CreateCollectionAsync(collection); + string collection = "test_collection"; + await memoryStore.CreateCollectionAsync(collection); for (int i = 0; i < 10; i++) { @@ -619,11 +505,11 @@ public async Task GetNearestMatchesDifferentiatesIdenticalVectorsByKeyAsync() text: "text" + i, description: "description" + i, embedding: new Embedding(new float[] { 1, 1, 1 })); - _ = await db.UpsertAsync(collection, testRecord); + _ = await memoryStore.UpsertAsync(collection, testRecord); } // Act - var topNResults = db.GetNearestMatchesAsync(collection, compareEmbedding, limit: topN, minRelevanceScore: 0.75).ToEnumerable().ToArray(); + var topNResults = memoryStore.GetNearestMatchesAsync(collection, compareEmbedding, limit: topN, minRelevanceScore: 0.75).ToEnumerable().ToArray(); IEnumerable topNKeys = topNResults.Select(x => x.Item1.Key).ToImmutableSortedSet(); // Assert @@ -637,20 +523,19 @@ public async Task GetNearestMatchesDifferentiatesIdenticalVectorsByKeyAsync() } } - [Fact(Skip = SkipOrNot)] + [Fact(Skip = SkipReason)] public async Task ItCanBatchUpsertRecordsAsync() { // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); int numRecords = 10; - string collection = "test_collection" + this._collectionNum; - this._collectionNum++; + string collection = "test_collection"; IEnumerable records = this.CreateBatchRecords(numRecords); // Act - await db.CreateCollectionAsync(collection); - var keys = db.UpsertBatchAsync(collection, records); - var resultRecords = db.GetBatchAsync(collection, keys.ToEnumerable()); + await memoryStore.CreateCollectionAsync(collection); + var keys = memoryStore.UpsertBatchAsync(collection, records); + var resultRecords = memoryStore.GetBatchAsync(collection, keys.ToEnumerable()); // Assert Assert.NotNull(keys); @@ -658,20 +543,19 @@ public async Task ItCanBatchUpsertRecordsAsync() 
Assert.Equal(numRecords, resultRecords.ToEnumerable().Count()); } - [Fact(Skip = SkipOrNot)] + [Fact(Skip = SkipReason)] public async Task ItCanBatchGetRecordsAsync() { // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); int numRecords = 10; - string collection = "test_collection" + this._collectionNum; - this._collectionNum++; + string collection = "test_collection"; IEnumerable records = this.CreateBatchRecords(numRecords); - var keys = db.UpsertBatchAsync(collection, records); + var keys = memoryStore.UpsertBatchAsync(collection, records); // Act - await db.CreateCollectionAsync(collection); - var results = db.GetBatchAsync(collection, keys.ToEnumerable()); + await memoryStore.CreateCollectionAsync(collection); + var results = memoryStore.GetBatchAsync(collection, keys.ToEnumerable()); // Assert Assert.NotNull(keys); @@ -679,43 +563,141 @@ public async Task ItCanBatchGetRecordsAsync() Assert.Equal(numRecords, results.ToEnumerable().Count()); } - [Fact(Skip = SkipOrNot)] + [Fact(Skip = SkipReason)] public async Task ItCanBatchRemoveRecordsAsync() { // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); int numRecords = 10; - string collection = "test_collection" + this._collectionNum; - this._collectionNum++; + string collection = "test_collection"; IEnumerable records = this.CreateBatchRecords(numRecords); - await db.CreateCollectionAsync(collection); + await memoryStore.CreateCollectionAsync(collection); List keys = new(); // Act - await foreach (var key in db.UpsertBatchAsync(collection, records)) + await foreach (var key in memoryStore.UpsertBatchAsync(collection, records)) { keys.Add(key); } - await db.RemoveBatchAsync(collection, keys); + await memoryStore.RemoveBatchAsync(collection, keys); // Assert - await foreach (var result in db.GetBatchAsync(collection, keys)) + await 
foreach (var result in memoryStore.GetBatchAsync(collection, keys)) { Assert.Null(result); } } - [Fact(Skip = SkipOrNot)] + [Fact(Skip = SkipReason)] public async Task DeletingNonExistentCollectionDoesNothingAsync() { // Arrange - using PostgresMemoryStore db = await this.CreateMemoryStoreAsync(); - string collection = "test_collection" + this._collectionNum; - this._collectionNum++; + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); + string collection = "test_collection"; + + // Act + await memoryStore.DeleteCollectionAsync(collection); + } + + [Fact(Skip = SkipReason)] + public async Task ItCanBatchGetRecordsAndSkipIfKeysDoNotExistAsync() + { + // Arrange + PostgresMemoryStore memoryStore = this.CreateMemoryStore(); + int numRecords = 10; + string collection = "test_collection"; + IEnumerable records = this.CreateBatchRecords(numRecords); // Act - await db.DeleteCollectionAsync(collection); + await memoryStore.CreateCollectionAsync(collection); + var keys = await memoryStore.UpsertBatchAsync(collection, records).ToListAsync(); + keys.Insert(0, "not-exist-key-0"); + keys.Insert(5, "not-exist-key-5"); + keys.Add("not-exist-key-n"); + var resultRecords = memoryStore.GetBatchAsync(collection, keys); + + // Assert + Assert.NotNull(keys); + Assert.Equal(numRecords, keys.Count - 3); + Assert.Equal(numRecords, resultRecords.ToEnumerable().Count()); } + + #region private ================================================================================ + + private string _connectionString = null!; + private string _databaseName = null!; + private NpgsqlDataSource _dataSource = null!; + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "The database name is generated randomly, it does not support parameterized passing.")] + private async Task CreateDatabaseAsync() + { + using NpgsqlDataSource dataSource = NpgsqlDataSource.Create(this._connectionString); + await using 
(NpgsqlConnection conn = await dataSource.OpenConnectionAsync()) + { + await using (NpgsqlCommand command = new($"CREATE DATABASE \"{this._databaseName}\"", conn)) + { + await command.ExecuteNonQueryAsync(); + } + } + + await using (NpgsqlConnection conn = await this._dataSource.OpenConnectionAsync()) + { + await using (NpgsqlCommand command = new("CREATE EXTENSION vector", conn)) + { + await command.ExecuteNonQueryAsync(); + } + await conn.ReloadTypesAsync(); + } + } + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "The database name is generated randomly, it does not support parameterized passing.")] + private async Task DropDatabaseAsync() + { + using NpgsqlDataSource dataSource = NpgsqlDataSource.Create(this._connectionString); + await using (NpgsqlConnection conn = await dataSource.OpenConnectionAsync()) + { + await using (NpgsqlCommand command = new($"DROP DATABASE IF EXISTS \"{this._databaseName}\"", conn)) + { + await command.ExecuteNonQueryAsync(); + } + }; + } + + private PostgresMemoryStore CreateMemoryStore() + { + return new PostgresMemoryStore(this._dataSource!, vectorSize: 3, schema: "public"); + } + + private IEnumerable CreateBatchRecords(int numRecords) + { + Assert.True(numRecords % 2 == 0, "Number of records must be even"); + Assert.True(numRecords > 0, "Number of records must be greater than 0"); + + IEnumerable records = new List(numRecords); + for (int i = 0; i < numRecords / 2; i++) + { + var testRecord = MemoryRecord.LocalRecord( + id: "test" + i, + text: "text" + i, + description: "description" + i, + embedding: new Embedding(new float[] { 1, 1, 1 })); + records = records.Append(testRecord); + } + + for (int i = numRecords / 2; i < numRecords; i++) + { + var testRecord = MemoryRecord.ReferenceRecord( + externalId: "test" + i, + sourceName: "sourceName" + i, + description: "description" + i, + embedding: new Embedding(new float[] { 1, 2, 3 })); + 
records = records.Append(testRecord); + } + + return records; + } + + #endregion } diff --git a/dotnet/src/IntegrationTests/README.md b/dotnet/src/IntegrationTests/README.md index ca68c03b329a..00186f6309f6 100644 --- a/dotnet/src/IntegrationTests/README.md +++ b/dotnet/src/IntegrationTests/README.md @@ -8,6 +8,7 @@ 3. **HuggingFace API key**: see https://huggingface.co/docs/huggingface_hub/guides/inference for details. 4. **Azure Bing Web Search API**: go to [Bing Web Search API](https://www.microsoft.com/en-us/bing/apis/bing-web-search-api) and select `Try Now` to get started. +5. **Postgres**: start a Postgres instance with the [pgvector](https://github.com/pgvector/pgvector) extension installed. The easiest way is to use the Docker image [ankane/pgvector](https://hub.docker.com/r/ankane/pgvector). ## Setup @@ -44,6 +45,7 @@ dotnet user-secrets set "AzureOpenAIEmbeddings:ApiKey" "..." dotnet user-secrets set "HuggingFace:ApiKey" "..." dotnet user-secrets set "Bing:ApiKey" "..." +dotnet user-secrets set "Postgres:ConnectionString" "..." ``` ### Option 2: Use Configuration File @@ -86,6 +88,9 @@ For example: }, "Bing": { "ApiKey": "...." + }, + "Postgres": { + "ConnectionString": "Host=localhost;Database=postgres;User Id=postgres;Password=mysecretpassword" } } ``` @@ -106,6 +111,7 @@ When setting environment variables, use a double underscore (i.e. "\_\_") to del export AzureOpenAI__Endpoint="https://contoso.openai.azure.com/" export HuggingFace__ApiKey="...." export Bing__ApiKey="...." + export Postgres__ConnectionString="...." ``` - PowerShell: @@ -119,4 +125,5 @@ When setting environment variables, use a double underscore (i.e. "\_\_") to del $env:AzureOpenAI__Endpoint = "https://contoso.openai.azure.com/" $env:HuggingFace__ApiKey = "...." $env:Bing__ApiKey = "...." + $env:Postgres__ConnectionString = "...." 
``` diff --git a/dotnet/src/IntegrationTests/testsettings.json b/dotnet/src/IntegrationTests/testsettings.json index 5efd594e827e..2b5e41c5cbd7 100644 --- a/dotnet/src/IntegrationTests/testsettings.json +++ b/dotnet/src/IntegrationTests/testsettings.json @@ -27,5 +27,8 @@ }, "Bing": { "ApiKey": "" + }, + "Postgres": { + "ConnectionString": "" } } \ No newline at end of file