diff --git a/src/Marten/Storage/Encryption/AesConnectionStringEncryptor.cs b/src/Marten/Storage/Encryption/AesConnectionStringEncryptor.cs new file mode 100644 index 0000000000..7df0637520 --- /dev/null +++ b/src/Marten/Storage/Encryption/AesConnectionStringEncryptor.cs @@ -0,0 +1,97 @@ +using System; +using System.Security.Cryptography; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Npgsql; + +namespace Marten.Storage.Encryption; + +/// +/// Provides AES-based encryption and decryption for connection strings using a 32-byte key. +/// Connection strings are encrypted in memory using AES with a randomly generated IV. +/// +internal class AesConnectionStringEncryptor : IConnectionStringEncryptor +{ + private readonly string _encryptionKey; + + /// + /// Initializes a new instance of the AES encryption provider. + /// + /// The encryption key for AES encryption + /// Thrown when the key is null or empty + public AesConnectionStringEncryptor(string encryptionKey) + { + if (string.IsNullOrWhiteSpace(encryptionKey)) + throw new ArgumentException("AES encryption key cannot be empty or whitespace", nameof(encryptionKey)); + + _encryptionKey = encryptionKey; + } + + public string Encrypt(string connectionString) + { + using var aes = Aes.Create(); + using var deriveBytes = new Rfc2898DeriveBytes(_encryptionKey, 16, 1000, HashAlgorithmName.SHA256); + aes.Key = deriveBytes.GetBytes(32); + aes.IV = deriveBytes.GetBytes(16); + + using var encryptor = aes.CreateEncryptor(); + var plainTextBytes = Encoding.UTF8.GetBytes(connectionString); + var cipherTextBytes = encryptor.TransformFinalBlock(plainTextBytes, 0, plainTextBytes.Length); + + var result = new byte[aes.IV.Length + cipherTextBytes.Length]; + Buffer.BlockCopy(deriveBytes.Salt, 0, result, 0, 16); + Buffer.BlockCopy(cipherTextBytes, 0, result, 16, cipherTextBytes.Length); + + return Convert.ToBase64String(result); + } + + public string Decrypt(string encryptedConnectionString) + { + try + { + var cipherTextBytes = Convert.FromBase64String(encryptedConnectionString); + if (cipherTextBytes.Length < 16) // IV size + return encryptedConnectionString; + + using var aes = Aes.Create(); + var salt = new byte[16]; + var cipher = new byte[cipherTextBytes.Length - 16]; + Buffer.BlockCopy(cipherTextBytes, 0, salt, 0, 16); + Buffer.BlockCopy(cipherTextBytes, 16, cipher, 0, cipherTextBytes.Length - 16); + using var deriveBytes = new Rfc2898DeriveBytes(_encryptionKey, salt, 1000, HashAlgorithmName.SHA256); + aes.Key = deriveBytes.GetBytes(32); + aes.IV = deriveBytes.GetBytes(16); + + using var decryptor = aes.CreateDecryptor(); + var plainTextBytes = decryptor.TransformFinalBlock(cipher, 0, cipher.Length); + return Encoding.UTF8.GetString(plainTextBytes); + } + catch + { + // If decryption fails, return the original string + return encryptedConnectionString; + } + } + + public (string sql, object[] parameters) GetInsertSql(string schemaName, string tableName, string tenantId, string connectionString) + { + var encryptedString = Encrypt(connectionString); + return ($"insert into {schemaName}.{tableName} (tenant_id, connection_string) values (?, ?) " + + "on conflict (tenant_id) do update set connection_string = ?", + [ + tenantId, + encryptedString, + encryptedString + ]); + } + + public (string sql, object[] parameters) GetSelectSql(string schemaName, string tableName, string tenantId) + { + return ($"select tenant_id, connection_string from {schemaName}.{tableName}" + + (tenantId == "*" ? "" : " where tenant_id = ?"), + tenantId == "*" ? Array.Empty() : [tenantId]); + } + + // No prerequisites needed for AES encryption since it's done in memory +} diff --git a/src/Marten/Storage/Encryption/EncryptionOptions.cs b/src/Marten/Storage/Encryption/EncryptionOptions.cs new file mode 100644 index 0000000000..575ff172a7 --- /dev/null +++ b/src/Marten/Storage/Encryption/EncryptionOptions.cs @@ -0,0 +1,100 @@ +using System; + +namespace Marten.Storage.Encryption; + +/// +/// Specifies the encryption method to use for tenant database connection strings. +/// +public enum ConnectionStringEncryption +{ + /// + /// No encryption of connection strings. Connection strings will be stored as plain text. + /// + None, + + /// + /// Use AES encryption with a provided 32-byte encryption key. + /// Connection strings are encrypted in memory using AES with a randomly generated IV. + /// The encrypted data is stored as base64-encoded strings. + /// + AES, + + /// + /// Use PostgreSQL's pgcrypto extension for encryption. + /// Connection strings are encrypted and decrypted directly in the database using + /// pgp_sym_encrypt and pgp_sym_decrypt functions. This requires the pgcrypto + /// extension to be installed in the same schema as the tenant table. + /// + PgCrypto +} + +/// +/// Options for configuring connection string encryption. +/// +public class EncryptionOptions +{ + private string? _encryptionKey; + + /// + /// The type of encryption to use for connection strings. + /// + public ConnectionStringEncryption Type { get; private set; } = ConnectionStringEncryption.None; + + /// + /// The encryption key used to encrypt/decrypt connection strings. + /// Must be exactly 32 characters long. + /// + public string? Key + { + get => _encryptionKey; + private set + { + if (value != null && string.IsNullOrWhiteSpace(value)) + throw new ArgumentException("Encryption key cannot be empty or whitespace", nameof(value)); + _encryptionKey = value; + } + } + + /// + /// Use AES encryption with the specified key for connection strings. + /// + /// The encryption key for AES encryption + /// The current options instance for method chaining + public EncryptionOptions UseAes(string key) + { + var keyBytes = Convert.FromBase64String(key); + if (keyBytes.Length < 16 || keyBytes.Length > 32) + throw new ArgumentException("AES encryption key must be between 16 and 32 bytes (128-256 bits) when base64 decoded", nameof(key)); + + Key = key; + Type = ConnectionStringEncryption.AES; + return this; + } + + /// + /// Use PostgreSQL's pgcrypto extension with the specified key for connection strings. + /// + /// The encryption key for pgcrypto encryption + /// The current options instance for method chaining + public EncryptionOptions UsePgCrypto(string key) + { + var keyBytes = Convert.FromBase64String(key); + if (keyBytes.Length < 16) + throw new ArgumentException("PgCrypto encryption key must be at least 16 bytes (128 bits) when base64 decoded", nameof(key)); + + Key = key; + Type = ConnectionStringEncryption.PgCrypto; + return this; + } + + /// + /// Disable encryption for connection strings. + /// + /// The current options instance for method chaining + public EncryptionOptions UseNoEncryption() + { + Key = null; + Type = ConnectionStringEncryption.None; + return this; + } +} diff --git a/src/Marten/Storage/Encryption/IConnectionStringEncryptor.cs b/src/Marten/Storage/Encryption/IConnectionStringEncryptor.cs new file mode 100644 index 0000000000..1536ac122f --- /dev/null +++ b/src/Marten/Storage/Encryption/IConnectionStringEncryptor.cs @@ -0,0 +1,61 @@ +using System.Threading; +using System.Threading.Tasks; +using Npgsql; + +namespace Marten.Storage.Encryption; + +/// +/// Provides encryption and decryption services for tenant database connection strings. +/// Implementations should handle both the encryption/decryption operations and +/// the generation of SQL commands for database operations involving encrypted data. +/// +internal interface IConnectionStringEncryptor +{ + /// + /// Encrypts a connection string using the provider's encryption method. + /// + /// The connection string to encrypt + /// The encrypted connection string + string Encrypt(string connectionString); + + /// + /// Decrypts an encrypted connection string using the provider's decryption method. + /// If decryption fails, implementations should return the original string. + /// + /// The encrypted connection string to decrypt + /// The decrypted connection string, or the original string if decryption fails + string Decrypt(string encryptedConnectionString); + + /// + /// Generates a parameterized SQL command for inserting or updating an encrypted connection string. + /// + /// The database schema name + /// The table name + /// The tenant identifier + /// The connection string to encrypt and store + /// A tuple containing the SQL command text and its parameters + (string sql, object[] parameters) GetInsertSql(string schemaName, string tableName, string tenantId, string connectionString); + + /// + /// Generates a parameterized SQL command for selecting and decrypting connection strings. + /// + /// The database schema name + /// The table name + /// The tenant identifier, use "*" to select all tenants + /// A tuple containing the SQL command text and its parameters + (string sql, object[] parameters) GetSelectSql(string schemaName, string tableName, string tenantId); + + /// + /// Ensures any prerequisites required by the encryption provider are met. + /// For example, checking if required database extensions are installed. + /// + /// The database data source to check against + /// The database schema name + /// Optional cancellation token + /// A task representing the asynchronous operation + Task EnsurePrerequisitesAsync(NpgsqlDataSource dataSource, string schemaName, CancellationToken token = default) + { + // Default implementation assumes no prerequisites are needed + return Task.CompletedTask; + } +} diff --git a/src/Marten/Storage/Encryption/KeyGenerator.cs b/src/Marten/Storage/Encryption/KeyGenerator.cs new file mode 100644 index 0000000000..83a6eb2f3b --- /dev/null +++ b/src/Marten/Storage/Encryption/KeyGenerator.cs @@ -0,0 +1,26 @@ +using System; +using System.Security.Cryptography; + +namespace Marten.Storage.Encryption; + +/// +/// Utility class for generating secure encryption keys compatible with Marten's encryption providers. +/// +public static class KeyGenerator +{ + /// + /// Generates a secure random encryption key suitable for use with Marten's encryption providers. + /// The key is generated using a cryptographically secure random number generator. + /// + /// The number of random bytes to generate. Default is 24 bytes which produces a 32-character base64 string. + /// Use this flag for only testing purposes. + /// A base64-encoded string suitable for use as an encryption key + public static string GenerateKey(int byteLength = 24, bool allowAnyByteLength = false) + { + if (byteLength < 16 && !allowAnyByteLength) + throw new ArgumentException("For security, key should be at least 16 bytes (128 bits)", nameof(byteLength)); + + var keyBytes = RandomNumberGenerator.GetBytes(byteLength); + return Convert.ToBase64String(keyBytes); + } +} diff --git a/src/Marten/Storage/Encryption/NoopConnectionStringEncryptor.cs b/src/Marten/Storage/Encryption/NoopConnectionStringEncryptor.cs new file mode 100644 index 0000000000..444deb83a9 --- /dev/null +++ b/src/Marten/Storage/Encryption/NoopConnectionStringEncryptor.cs @@ -0,0 +1,52 @@ +using System; +using Npgsql; + +namespace Marten.Storage.Encryption; + +/// +/// A no-operation implementation of IConnectionStringEncryptor that passes through connection strings without encryption. +/// This provides a consistent interface when no encryption is needed. +/// +internal class NoopConnectionStringEncryptor : IConnectionStringEncryptor +{ + /// + /// Returns the connection string as-is without encryption + /// + public string Encrypt(string connectionString) => connectionString; + + /// + /// Returns the connection string as-is without decryption + /// + public string Decrypt(string encryptedConnectionString) => encryptedConnectionString; + + /// + /// Generates a parameterized SQL command for inserting or updating an unencrypted connection string + /// + public (string sql, object[] parameters) GetInsertSql(string schemaName, string tableName, string tenantId, string connectionString) + { + var sql = $"insert into {schemaName}.{tableName} (tenant_id, connection_string) values (?, ?) " + + "on conflict (tenant_id) do update set connection_string = ?"; + + return (sql, [ + tenantId, + connectionString, + connectionString + ]); + } + + /// + /// Generates a parameterized SQL command for selecting unencrypted connection strings + /// + public (string sql, object[] parameters) GetSelectSql(string schemaName, string tableName, string tenantId) + { + if (tenantId == "*") + { + return ($"select tenant_id, connection_string from {schemaName}.{tableName}", Array.Empty()); + } + + var sql = $"select connection_string from {schemaName}.{tableName} where tenant_id = ?"; + + return (sql, [tenantId]); + } + +} diff --git a/src/Marten/Storage/Encryption/PgCryptoConnectionStringEncryptor.cs b/src/Marten/Storage/Encryption/PgCryptoConnectionStringEncryptor.cs new file mode 100644 index 0000000000..57ead66c45 --- /dev/null +++ b/src/Marten/Storage/Encryption/PgCryptoConnectionStringEncryptor.cs @@ -0,0 +1,98 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Marten.Exceptions; +using Npgsql; + +namespace Marten.Storage.Encryption; + +/// +/// Provides database-level encryption and decryption for connection strings using PostgreSQL's pgcrypto extension. +/// Connection strings are encrypted and decrypted directly in the database using pgp_sym_encrypt and pgp_sym_decrypt functions. +/// +internal class PgCryptoConnectionStringEncryptor : IConnectionStringEncryptor +{ + private readonly string _encryptionKey; + + /// + /// Initializes a new instance of the PgCrypto encryption provider. + /// + /// The encryption key for pgcrypto functions + /// Thrown when the encryption key is null + public PgCryptoConnectionStringEncryptor(string encryptionKey) + { + _encryptionKey = encryptionKey ?? throw new ArgumentNullException(nameof(encryptionKey)); + } + + /// + /// Returns the connection string as-is since encryption is handled at the database level. + /// + public string Encrypt(string connectionString) => connectionString; + + /// + /// Returns the connection string as-is since decryption is handled at the database level. + /// + public string Decrypt(string encryptedConnectionString) => encryptedConnectionString; + + /// + /// Generates SQL to insert or update an encrypted connection string using pgp_sym_encrypt. + /// The encryption is performed by the database using the pgcrypto extension. + /// + public (string sql, object[] parameters) GetInsertSql(string schemaName, string tableName, string tenantId, string connectionString) + { + return ($"insert into {schemaName}.{tableName} (tenant_id, connection_string) " + + $"values (?, pgp_sym_encrypt(?::text, ?::text)) " + + $"on conflict (tenant_id) do update set connection_string = pgp_sym_encrypt(?::text, ?::text)", + [ + tenantId, + connectionString, + _encryptionKey, + connectionString, + _encryptionKey, + ]); + } + + /// + /// Generates SQL to select and decrypt connection strings using pgp_sym_decrypt. + /// The decryption is performed by the database using the pgcrypto extension. + /// + public (string sql, object[] parameters) GetSelectSql(string schemaName, string tableName, string tenantId) + { + var whereClause = tenantId == "*" ? "" : " where tenant_id = ?"; + var sql = $"select tenant_id, pgp_sym_decrypt(connection_string::bytea, ?::text) as connection_string " + + $"from {schemaName}.{tableName}{whereClause}"; + + return (sql, [_encryptionKey, tenantId]); + } + + /// + /// Ensures the pgcrypto extension is available and properly configured in the specified schema. + /// + public async Task EnsurePrerequisitesAsync(NpgsqlDataSource dataSource, string schemaName, CancellationToken token = default) + { + await using var conn = await dataSource.OpenConnectionAsync(token).ConfigureAwait(false); + try + { + await using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT EXISTS(" + + "SELECT 1 FROM pg_extension e " + + "WHERE e.extname = 'pgcrypto')"; + cmd.Parameters.AddWithValue("schema_name", schemaName); + var extensionExists = (bool)await cmd.ExecuteScalarAsync(token).ConfigureAwait(false); + + if (!extensionExists) + { + throw new MartenException( + $"PgCrypto encryption requires the pgcrypto extension to be installed.\n" + + $"Run 'CREATE EXTENSION IF NOT EXISTS pgcrypto;' as a superuser or contact your database administrator.") + { + HelpLink = "https://www.postgresql.org/docs/current/pgcrypto.html" + }; + } + } + finally + { + await conn.CloseAsync().ConfigureAwait(false); + } + } +} diff --git a/src/Marten/Storage/MasterTableTenancy.cs b/src/Marten/Storage/MasterTableTenancy.cs index 57ed910e9a..9785ef355f 100644 --- a/src/Marten/Storage/MasterTableTenancy.cs +++ b/src/Marten/Storage/MasterTableTenancy.cs @@ -1,12 +1,13 @@ using System; using System.Collections.Generic; -using System.Data.Common; using System.Linq; using System.Reflection; +using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; using JasperFx.Core; using Marten.Schema; +using Marten.Storage.Encryption; using Npgsql; using Weasel.Core; using Weasel.Core.Migrations; @@ -29,6 +30,15 @@ public class MasterTableTenancyOptions /// public NpgsqlDataSource? DataSource { get; set; } + /// + /// The encryption key used to encrypt/decrypt connection strings. Must be exactly 32 bytes. + /// + public EncryptionOptions ConnectionStringEncryptionOpts + { + get; + set; + } = new(); + /// /// If specified, override the database schema name for the tenants table /// Default is "public" @@ -73,7 +83,6 @@ internal string CorrectConnectionString(string connectionString) if (ApplicationName.IsNotEmpty()) { var builder = new NpgsqlConnectionStringBuilder(connectionString) { ApplicationName = ApplicationName }; - return builder.ConnectionString; } @@ -89,6 +98,7 @@ public class MasterTableTenancy: ITenancy, ITenancyWithMasterDatabase private readonly string _schemaName; private readonly Lazy _tenantDatabase; private ImHashMap _databases = ImHashMap.Empty; + private readonly IConnectionStringEncryptor _encryptionProvider; private bool _hasAppliedChanges; private bool _hasAppliedDefaults; @@ -101,7 +111,6 @@ public MasterTableTenancy(StoreOptions options, string connectionString, string public MasterTableTenancy(StoreOptions options, MasterTableTenancyOptions tenancyOptions) { _options = options; - _configuration = tenancyOptions; if (tenancyOptions.DataSource != null) @@ -124,9 +133,20 @@ public MasterTableTenancy(StoreOptions options, MasterTableTenancyOptions tenanc Cleaner = new CompositeDocumentCleaner(this, _options); _tenantDatabase = new Lazy(() => - new TenantLookupDatabase(_options, _dataSource.Value, tenancyOptions.SchemaName)); + new TenantLookupDatabase(_options, _dataSource.Value, tenancyOptions.SchemaName, _configuration.ConnectionStringEncryptionOpts)); + + _encryptionProvider = _configuration.ConnectionStringEncryptionOpts.Type switch + { + ConnectionStringEncryption.AES when _configuration.ConnectionStringEncryptionOpts.Key != null => new AesConnectionStringEncryptor(_configuration.ConnectionStringEncryptionOpts.Key), + ConnectionStringEncryption.PgCrypto when _configuration.ConnectionStringEncryptionOpts.Key != null => new PgCryptoConnectionStringEncryptor(_configuration.ConnectionStringEncryptionOpts.Key), + _ => new NoopConnectionStringEncryptor() + }; + + // Provider prerequisites will be checked in BuildDatabases } + + public void Dispose() { foreach (var entry in _databases.Enumerate()) entry.Value.Dispose(); @@ -139,6 +159,8 @@ public void Dispose() public async ValueTask> BuildDatabases() { + await _encryptionProvider.EnsurePrerequisitesAsync(_dataSource.Value, _schemaName).ConfigureAwait(false); + await maybeApplyChanges(_tenantDatabase.Value).ConfigureAwait(false); await using var conn = _dataSource.Value.CreateConnection(); @@ -151,10 +173,15 @@ public async ValueTask> BuildDatabases() await seedDatabasesAsync(conn).ConfigureAwait(false); } - await using var reader = await ((DbCommand)conn - .CreateCommand($"select tenant_id, connection_string from {_schemaName}.{TenantTable.TableName}")) - .ExecuteReaderAsync().ConfigureAwait(false); + var command = conn.CreateCommand(); + var (sql, parameters) = _encryptionProvider.GetSelectSql(_schemaName, TenantTable.TableName, "*"); + command.CommandText = ConvertToPgPositionalParams(sql); + foreach (var param in parameters) + { + command.Parameters.Add(new NpgsqlParameter { Value = param }); + } + await using var reader = await command.ExecuteReaderAsync().ConfigureAwait(false); while (await reader.ReadAsync().ConfigureAwait(false)) { var tenantId = await reader.GetFieldValueAsync(0).ConfigureAwait(false); @@ -167,6 +194,7 @@ public async ValueTask> BuildDatabases() } var connectionString = await reader.GetFieldValueAsync(1).ConfigureAwait(false); + connectionString = _encryptionProvider?.Decrypt(connectionString) ?? connectionString; connectionString = _configuration.CorrectConnectionString(connectionString); var database = new MartenDatabase(_options, _options.NpgsqlDataSourceFactory.Create(connectionString), @@ -258,46 +286,51 @@ public bool IsTenantStoredInCurrentDatabase(IMartenDatabase database, string ten public async Task DeleteDatabaseRecordAsync(string tenantId) { tenantId = _options.MaybeCorrectTenantId(tenantId); +#pragma warning disable MA0032 await maybeApplyChanges(_tenantDatabase.Value).ConfigureAwait(false); +#pragma warning restore MA0032 - await _dataSource.Value - .CreateCommand($"delete from {_schemaName}.{TenantTable.TableName} where tenant_id = :id") - .With("id", tenantId) - .ExecuteNonQueryAsync(CancellationToken.None).ConfigureAwait(false); + var cmd = _dataSource.Value.CreateCommand($"delete from {_schemaName}.{TenantTable.TableName} where tenant_id = @tenant_id"); + cmd.Parameters.AddWithValue("tenant_id", tenantId); + await cmd.ExecuteNonQueryAsync(CancellationToken.None).ConfigureAwait(false); } public async Task ClearAllDatabaseRecordsAsync() { +#pragma warning disable MA0032 await maybeApplyChanges(_tenantDatabase.Value).ConfigureAwait(false); +#pragma warning restore MA0032 - await _dataSource.Value.CreateCommand($"delete from {_schemaName}.{TenantTable.TableName}") - .ExecuteNonQueryAsync(CancellationToken.None).ConfigureAwait(false); + var cmd = _dataSource.Value.CreateCommand($"delete from {_schemaName}.{TenantTable.TableName}"); + await cmd.ExecuteNonQueryAsync(CancellationToken.None).ConfigureAwait(false); } public async Task AddDatabaseRecordAsync(string tenantId, string connectionString) { tenantId = _options.MaybeCorrectTenantId(tenantId); - await _dataSource.Value - .CreateCommand( - $"insert into {_schemaName}.{TenantTable.TableName} (tenant_id, connection_string) values (:id, :connection) on conflict (tenant_id) do update set connection_string = :connection") - .With("id", tenantId) - .With("connection", connectionString) - .ExecuteNonQueryAsync(CancellationToken.None).ConfigureAwait(false); + + var (sql, parameters) = _encryptionProvider.GetInsertSql(_schemaName, TenantTable.TableName, tenantId, connectionString); + var cmd = _dataSource.Value.CreateCommand(ConvertToPgPositionalParams(sql)); + foreach (var param in parameters) + { + cmd.Parameters.Add(new NpgsqlParameter { Value = param }); + } +#pragma warning disable MA0032 + await cmd.ExecuteNonQueryAsync(CancellationToken.None).ConfigureAwait(false); +#pragma warning restore MA0032 } - private async Task maybeApplyChanges(TenantLookupDatabase tenantDatabase) + private async Task maybeApplyChanges(TenantLookupDatabase tenantDatabase, CancellationToken token = default) { if (!_hasAppliedChanges && (_configuration.AutoCreate ?? _options.AutoCreateSchemaObjects) != AutoCreate.None) { -#pragma warning disable MA0032 await tenantDatabase - .ApplyAllConfiguredChangesToDatabaseAsync(_options.AutoCreateSchemaObjects).ConfigureAwait(false); -#pragma warning restore MA0032 + .ApplyAllConfiguredChangesToDatabaseAsync(_options.AutoCreateSchemaObjects, ct: default).ConfigureAwait(false); _hasAppliedChanges = true; } } - private async Task seedDatabasesAsync(NpgsqlConnection conn) + private async Task seedDatabasesAsync(NpgsqlConnection conn, CancellationToken token = default) { if (!_configuration.SeedDatabases.Any()) { @@ -308,17 +341,17 @@ private async Task seedDatabasesAsync(NpgsqlConnection conn) foreach (var pair in _configuration.SeedDatabases) { builder.StartNewCommand(); - var parameters = builder.AppendWithParameters( - $"insert into {_schemaName}.{TenantTable.TableName} (tenant_id, connection_string) values (?, ?) on conflict (tenant_id) do update set connection_string = ?"); - - parameters[0].Value = pair.Key; - parameters[1].Value = pair.Value; - parameters[2].Value = pair.Value; + var (sql, parameters) = _encryptionProvider.GetInsertSql(_schemaName, TenantTable.TableName, pair.Key, pair.Value); + var builderParams = builder.AppendWithParameters(sql); + for (var i = 0; i < builderParams.Length; i++) + { + builderParams[i].Value = parameters[i]; + } } var batch = builder.Compile(); batch.Connection = conn; - await batch.ExecuteNonQueryAsync(CancellationToken.None).ConfigureAwait(false); + await batch.ExecuteNonQueryAsync(token).ConfigureAwait(false); _hasAppliedDefaults = true; } @@ -326,17 +359,27 @@ private async Task seedDatabasesAsync(NpgsqlConnection conn) private async Task tryFindTenantDatabase(string tenantId) { tenantId = _options.MaybeCorrectTenantId(tenantId); - var connectionString = (string)await _dataSource.Value - .CreateCommand($"select connection_string from {_schemaName}.{TenantTable.TableName} where tenant_id = :id") - .With("id", tenantId) - .ExecuteScalarAsync(CancellationToken.None).ConfigureAwait(false); - if (connectionString.IsEmpty()) + var (sql, parameters) = _encryptionProvider.GetSelectSql(_schemaName, TenantTable.TableName, tenantId); + var cmd = _dataSource.Value.CreateCommand(ConvertToPgPositionalParams(sql)); + foreach (var param in parameters) + { + cmd.Parameters.Add(new NpgsqlParameter { Value = param }); + } + + await using var reader = await cmd.ExecuteReaderAsync(CancellationToken.None).ConfigureAwait(false); + var connectionString = string.Empty; + while (await reader.ReadAsync(CancellationToken.None).ConfigureAwait(false)) { - return null; + connectionString = await reader.GetFieldValueAsync(1, CancellationToken.None).ConfigureAwait(false); + if (!string.IsNullOrWhiteSpace(connectionString)) + { + connectionString = _encryptionProvider?.Decrypt(connectionString) ?? connectionString; + connectionString = _configuration.CorrectConnectionString(connectionString); + } } - connectionString = _configuration.CorrectConnectionString(connectionString); + await reader.CloseAsync().ConfigureAwait(false); return connectionString.IsNotEmpty() ? new MartenDatabase(_options, @@ -344,14 +387,23 @@ private async Task seedDatabasesAsync(NpgsqlConnection conn) : null; } + private string ConvertToPgPositionalParams(string sql) + { + var index = 0; +#pragma warning disable MA0009 + return Regex.Replace(sql, @"\?", _ => $"${++index}"); +#pragma warning restore MA0009 + } + + internal class TenantLookupDatabase: PostgresqlDatabase { private readonly TenantDatabaseStorage _feature; - public TenantLookupDatabase(StoreOptions options, NpgsqlDataSource dataSource, string schemaName): base(options, + public TenantLookupDatabase(StoreOptions options, NpgsqlDataSource dataSource, string schemaName, EncryptionOptions encryptionOpts): base(options, options.AutoCreateSchemaObjects, options.Advanced.Migrator, "TenantDatabases", dataSource) { - _feature = new TenantDatabaseStorage(schemaName, options); + _feature = new TenantDatabaseStorage(schemaName, options, encryptionOpts); } public override IFeatureSchema[] BuildFeatureSchemas() @@ -364,16 +416,23 @@ internal class TenantDatabaseStorage: FeatureSchemaBase { private readonly StoreOptions _options; private readonly string _schemaName; + private readonly EncryptionOptions _encryptionOpts; - public TenantDatabaseStorage(string schemaName, StoreOptions options): base("TenantDatabases", + public TenantDatabaseStorage(string schemaName, StoreOptions options,EncryptionOptions encryptionOpts): base("TenantDatabases", options.Advanced.Migrator) { _schemaName = schemaName; + _encryptionOpts = encryptionOpts; } protected override IEnumerable schemaObjects() { yield return new TenantTable(_schemaName); + + if (_encryptionOpts.Type == ConnectionStringEncryption.PgCrypto) + { + yield return new Extension("pgcrypto"); + } } } diff --git a/src/MultiTenancyTests/using_master_table_multi_tenancy_with_encryption.cs b/src/MultiTenancyTests/using_master_table_multi_tenancy_with_encryption.cs new file mode 100644 index 0000000000..d4cb5d6857 --- /dev/null +++ b/src/MultiTenancyTests/using_master_table_multi_tenancy_with_encryption.cs @@ -0,0 +1,209 @@ +using System; +using System.Threading.Tasks; +using Marten; +using Marten.Services; +using Marten.Storage.Encryption; +using Marten.Testing.Documents; +using Marten.Testing.Harness; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Npgsql; +using Shouldly; +using Weasel.Core; +using Weasel.Postgresql; +using Weasel.Postgresql.Migrations; + +namespace MultiTenancyTests; + + +[CollectionDefinition("multi-tenancy", DisableParallelization = true)] +public class using_master_table_multi_tenancy_with_aes_encryption: IAsyncLifetime +{ + private IHost _host; + private IDocumentStore theStore; + private string tenant1ConnectionString; + + private async Task CreateDatabaseIfNotExists(NpgsqlConnection conn, string databaseName) + { + var builder = new NpgsqlConnectionStringBuilder(ConnectionSource.ConnectionString); + + var exists = await conn.DatabaseExists(databaseName); + + if (!exists) + { + await new DatabaseSpecification().BuildDatabase(conn, databaseName); + } + + builder.Database = databaseName; + + return builder.ConnectionString; + } + + public async Task InitializeAsync() + { + await using var conn = new NpgsqlConnection(ConnectionSource.ConnectionString); + await conn.OpenAsync(); + await conn.DropSchemaAsync("tenants"); + + tenant1ConnectionString = await CreateDatabaseIfNotExists(conn, "tenant1"); + + _host = await Host.CreateDefaultBuilder() + .ConfigureServices(services => + { + services.AddNpgsqlDataSource(ConnectionSource.ConnectionString); + services.AddMarten(opts => + { + opts.RegisterDocumentType(); + opts.RegisterDocumentType(); + }) + // All detected changes will be applied to all + // the configured tenant databases on startup + .ApplyAllDatabaseChangesOnStartup(); + + services.ConfigureMarten((sp, so) => + { + so.MultiTenantedDatabasesWithMasterDatabaseTable(x => + { + x.DataSource = sp.GetRequiredService(); + x.ConnectionStringEncryptionOpts.UseAes(KeyGenerator.GenerateKey()); + x.SchemaName = "tenants"; + x.ApplicationName = "Sample"; + x.RegisterDatabase("tenant1", tenant1ConnectionString); + }); + }); + }).StartAsync(); + + + theStore = _host.Services.GetRequiredService(); + + await _host.ClearAllTenantDatabaseRecordsAsync(); + } + + public async Task DisposeAsync() + { + await _host.StopAsync(); + theStore.Dispose(); + } + + [Theory] + [InlineData(8)] // Too short + [InlineData(15)] // Just under minimum + [InlineData(33)] // Just over maximum + [InlineData(64)] // Way too long + public void aes_encryption_rejects_invalid_key_lengths(int keyBytes) + { + // Arrange + var key = KeyGenerator.GenerateKey(keyBytes, true); + + // Act & Assert + Should.Throw(() => + { + var opts = new EncryptionOptions().UseAes(key); + }).Message.ShouldContain("AES encryption key must be between 16 and 32 bytes (128-256 bits)"); + } + + [Fact] + public async Task can_open_a_session_to_a_different_database() + { + await _host.AddTenantDatabaseAsync("tenant1", tenant1ConnectionString); + + await using var session = + theStore.LightweightSession(new SessionOptions { TenantId = "tenant1" }); + } +} + +[CollectionDefinition("multi-tenancy", DisableParallelization = true)] +public class using_master_table_multi_tenancy_with_pgcrypto_encryption: IAsyncLifetime +{ + private IHost _host; + private IDocumentStore theStore; + private string tenant1ConnectionString; + + private async Task CreateDatabaseIfNotExists(NpgsqlConnection conn, string databaseName) + { + var builder = new NpgsqlConnectionStringBuilder(ConnectionSource.ConnectionString); + + var exists = await conn.DatabaseExists(databaseName); + if (!exists) + { + await new DatabaseSpecification().BuildDatabase(conn, databaseName); + } + + builder.Database = databaseName; + + return builder.ConnectionString; + } + + public async Task InitializeAsync() + { + await using var conn = new NpgsqlConnection(ConnectionSource.ConnectionString); + await conn.OpenAsync(); + await conn.DropSchemaAsync("tenants"); + + tenant1ConnectionString = await CreateDatabaseIfNotExists(conn, "tenant1"); + + await conn.CloseAsync(); + + _host = await Host.CreateDefaultBuilder() + .ConfigureServices(services => + { + services.AddNpgsqlDataSource(ConnectionSource.ConnectionString); + services.AddMarten(opts => + { + opts.RegisterDocumentType(); + opts.RegisterDocumentType(); + + opts.MultiTenantedDatabasesWithMasterDatabaseTable(x => + { + x.AutoCreate = AutoCreate.CreateOrUpdate; + x.ConnectionString = ConnectionSource.ConnectionString; + x.ConnectionStringEncryptionOpts.UsePgCrypto(KeyGenerator.GenerateKey()); + x.SchemaName = "tenants"; + x.ApplicationName = "Sample"; + x.RegisterDatabase("tenant1", tenant1ConnectionString); + }); + }) + // All detected changes will be applied to all + // the configured tenant databases on startup + .ApplyAllDatabaseChangesOnStartup(); + }).StartAsync(); + + theStore = _host.Services.GetRequiredService(); + + await _host.ClearAllTenantDatabaseRecordsAsync(); + } + + public async Task DisposeAsync() + { + await _host.StopAsync(); + theStore.Dispose(); + } + + [Theory] + [InlineData(8)] // Too short + [InlineData(15)] // Just under minimum + public void pgcrypto_encryption_rejects_invalid_key_lengths(int keyBytes) + { + // Arrange + var key = KeyGenerator.GenerateKey(keyBytes, true); + + // Act & Assert + Should.Throw(() => + { + var opts = new EncryptionOptions().UsePgCrypto(key); + }).Message.ShouldContain("at least 16 bytes"); + } + + [Fact] + public async Task can_open_a_session_to_a_different_database() + { + await _host.AddTenantDatabaseAsync("tenant1", tenant1ConnectionString); + + await using var session = + theStore.LightweightSession(new SessionOptions { TenantId = "tenant1" }); + + session.Connection.Database.ShouldBe("tenant1"); + } +} + +