diff --git a/src/Marten/Storage/Encryption/AesConnectionStringEncryptor.cs b/src/Marten/Storage/Encryption/AesConnectionStringEncryptor.cs
new file mode 100644
index 0000000000..7df0637520
--- /dev/null
+++ b/src/Marten/Storage/Encryption/AesConnectionStringEncryptor.cs
@@ -0,0 +1,97 @@
+using System;
+using System.Security.Cryptography;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using Npgsql;
+
+namespace Marten.Storage.Encryption;
+
+///
+/// Provides AES-based encryption and decryption for connection strings using a 32-byte key.
+/// Connection strings are encrypted in memory using AES with a randomly generated IV.
+///
+internal class AesConnectionStringEncryptor : IConnectionStringEncryptor
+{
+ private readonly string _encryptionKey;
+
+ ///
+ /// Initializes a new instance of the AES encryption provider.
+ ///
+ /// The encryption key for AES encryption
+ /// Thrown when the key is null or empty
+ public AesConnectionStringEncryptor(string encryptionKey)
+ {
+ if (string.IsNullOrWhiteSpace(encryptionKey))
+ throw new ArgumentException("AES encryption key cannot be empty or whitespace", nameof(encryptionKey));
+
+ _encryptionKey = encryptionKey;
+ }
+
+ public string Encrypt(string connectionString)
+ {
+ using var aes = Aes.Create();
+ using var deriveBytes = new Rfc2898DeriveBytes(_encryptionKey, 16, 1000, HashAlgorithmName.SHA256);
+ aes.Key = deriveBytes.GetBytes(32);
+ aes.IV = deriveBytes.GetBytes(16);
+
+ using var encryptor = aes.CreateEncryptor();
+ var plainTextBytes = Encoding.UTF8.GetBytes(connectionString);
+ var cipherTextBytes = encryptor.TransformFinalBlock(plainTextBytes, 0, plainTextBytes.Length);
+
+ var result = new byte[aes.IV.Length + cipherTextBytes.Length];
+ Buffer.BlockCopy(deriveBytes.Salt, 0, result, 0, 16);
+ Buffer.BlockCopy(cipherTextBytes, 0, result, 16, cipherTextBytes.Length);
+
+ return Convert.ToBase64String(result);
+ }
+
+ public string Decrypt(string encryptedConnectionString)
+ {
+ try
+ {
+ var cipherTextBytes = Convert.FromBase64String(encryptedConnectionString);
+ if (cipherTextBytes.Length < 16) // IV size
+ return encryptedConnectionString;
+
+ using var aes = Aes.Create();
+ var salt = new byte[16];
+ var cipher = new byte[cipherTextBytes.Length - 16];
+ Buffer.BlockCopy(cipherTextBytes, 0, salt, 0, 16);
+ Buffer.BlockCopy(cipherTextBytes, 16, cipher, 0, cipherTextBytes.Length - 16);
+ using var deriveBytes = new Rfc2898DeriveBytes(_encryptionKey, salt, 1000, HashAlgorithmName.SHA256);
+ aes.Key = deriveBytes.GetBytes(32);
+ aes.IV = deriveBytes.GetBytes(16);
+
+ using var decryptor = aes.CreateDecryptor();
+ var plainTextBytes = decryptor.TransformFinalBlock(cipher, 0, cipher.Length);
+ return Encoding.UTF8.GetString(plainTextBytes);
+ }
+ catch
+ {
+ // If decryption fails, return the original string
+ return encryptedConnectionString;
+ }
+ }
+
+ public (string sql, object[] parameters) GetInsertSql(string schemaName, string tableName, string tenantId, string connectionString)
+ {
+ var encryptedString = Encrypt(connectionString);
+ return ($"insert into {schemaName}.{tableName} (tenant_id, connection_string) values (?, ?) " +
+ "on conflict (tenant_id) do update set connection_string = ?",
+ [
+ tenantId,
+ encryptedString,
+ encryptedString
+ ]);
+ }
+
+ public (string sql, object[] parameters) GetSelectSql(string schemaName, string tableName, string tenantId)
+ {
+ return ($"select tenant_id, connection_string from {schemaName}.{tableName}" +
+ (tenantId == "*" ? "" : " where tenant_id = ?"),
+ tenantId == "*" ? Array.Empty() : [tenantId]);
+ }
+
+ // No prerequisites needed for AES encryption since it's done in memory
+}
diff --git a/src/Marten/Storage/Encryption/EncryptionOptions.cs b/src/Marten/Storage/Encryption/EncryptionOptions.cs
new file mode 100644
index 0000000000..575ff172a7
--- /dev/null
+++ b/src/Marten/Storage/Encryption/EncryptionOptions.cs
@@ -0,0 +1,100 @@
+using System;
+
+namespace Marten.Storage.Encryption;
+
+///
+/// Specifies the encryption method to use for tenant database connection strings.
+///
+public enum ConnectionStringEncryption
+{
+ ///
+ /// No encryption of connection strings. Connection strings will be stored as plain text.
+ ///
+ None,
+
+ ///
+ /// Use AES encryption with a provided 32-byte encryption key.
+ /// Connection strings are encrypted in memory using AES with a randomly generated IV.
+ /// The encrypted data is stored as base64-encoded strings.
+ ///
+ AES,
+
+ ///
+ /// Use PostgreSQL's pgcrypto extension for encryption.
+ /// Connection strings are encrypted and decrypted directly in the database using
+ /// pgp_sym_encrypt and pgp_sym_decrypt functions. This requires the pgcrypto
+ /// extension to be installed in the same schema as the tenant table.
+ ///
+ PgCrypto
+}
+
+///
+/// Options for configuring connection string encryption.
+///
+public class EncryptionOptions
+{
+ private string? _encryptionKey;
+
+ ///
+ /// The type of encryption to use for connection strings.
+ ///
+ public ConnectionStringEncryption Type { get; private set; } = ConnectionStringEncryption.None;
+
+ ///
+ /// The encryption key used to encrypt/decrypt connection strings.
+ /// Must be exactly 32 characters long.
+ ///
+ public string? Key
+ {
+ get => _encryptionKey;
+ private set
+ {
+ if (value != null && string.IsNullOrWhiteSpace(value))
+ throw new ArgumentException("Encryption key cannot be empty or whitespace", nameof(value));
+ _encryptionKey = value;
+ }
+ }
+
+ ///
+ /// Use AES encryption with the specified key for connection strings.
+ ///
+ /// The encryption key for AES encryption
+ /// The current options instance for method chaining
+ public EncryptionOptions UseAes(string key)
+ {
+ var keyBytes = Convert.FromBase64String(key);
+ if (keyBytes.Length < 16 || keyBytes.Length > 32)
+ throw new ArgumentException("AES encryption key must be between 16 and 32 bytes (128-256 bits) when base64 decoded", nameof(key));
+
+ Key = key;
+ Type = ConnectionStringEncryption.AES;
+ return this;
+ }
+
+ ///
+ /// Use PostgreSQL's pgcrypto extension with the specified key for connection strings.
+ ///
+ /// The encryption key for pgcrypto encryption
+ /// The current options instance for method chaining
+ public EncryptionOptions UsePgCrypto(string key)
+ {
+ var keyBytes = Convert.FromBase64String(key);
+ if (keyBytes.Length < 16)
+ throw new ArgumentException("PgCrypto encryption key must be at least 16 bytes (128 bits) when base64 decoded", nameof(key));
+
+ Key = key;
+ Type = ConnectionStringEncryption.PgCrypto;
+ return this;
+ }
+
+ ///
+ /// Disable encryption for connection strings.
+ ///
+ /// The current options instance for method chaining
+ public EncryptionOptions UseNoEncryption()
+ {
+ Key = null;
+ Type = ConnectionStringEncryption.None;
+ return this;
+ }
+}
diff --git a/src/Marten/Storage/Encryption/IConnectionStringEncryptor.cs b/src/Marten/Storage/Encryption/IConnectionStringEncryptor.cs
new file mode 100644
index 0000000000..1536ac122f
--- /dev/null
+++ b/src/Marten/Storage/Encryption/IConnectionStringEncryptor.cs
@@ -0,0 +1,61 @@
+using System.Threading;
+using System.Threading.Tasks;
+using Npgsql;
+
+namespace Marten.Storage.Encryption;
+
+///
+/// Provides encryption and decryption services for tenant database connection strings.
+/// Implementations should handle both the encryption/decryption operations and
+/// the generation of SQL commands for database operations involving encrypted data.
+///
+internal interface IConnectionStringEncryptor
+{
+ ///
+ /// Encrypts a connection string using the provider's encryption method.
+ ///
+ /// The connection string to encrypt
+ /// The encrypted connection string
+ string Encrypt(string connectionString);
+
+ ///
+ /// Decrypts an encrypted connection string using the provider's decryption method.
+ /// If decryption fails, implementations should return the original string.
+ ///
+ /// The encrypted connection string to decrypt
+ /// The decrypted connection string, or the original string if decryption fails
+ string Decrypt(string encryptedConnectionString);
+
+ ///
+ /// Generates a parameterized SQL command for inserting or updating an encrypted connection string.
+ ///
+ /// The database schema name
+ /// The table name
+ /// The tenant identifier
+ /// The connection string to encrypt and store
+ /// A tuple containing the SQL command text and its parameters
+ (string sql, object[] parameters) GetInsertSql(string schemaName, string tableName, string tenantId, string connectionString);
+
+ ///
+ /// Generates a parameterized SQL command for selecting and decrypting connection strings.
+ ///
+ /// The database schema name
+ /// The table name
+ /// The tenant identifier, use "*" to select all tenants
+ /// A tuple containing the SQL command text and its parameters
+ (string sql, object[] parameters) GetSelectSql(string schemaName, string tableName, string tenantId);
+
+ ///
+ /// Ensures any prerequisites required by the encryption provider are met.
+ /// For example, checking if required database extensions are installed.
+ ///
+ /// The database data source to check against
+ /// The database schema name
+ /// Optional cancellation token
+ /// A task representing the asynchronous operation
+ Task EnsurePrerequisitesAsync(NpgsqlDataSource dataSource, string schemaName, CancellationToken token = default)
+ {
+ // Default implementation assumes no prerequisites are needed
+ return Task.CompletedTask;
+ }
+}
diff --git a/src/Marten/Storage/Encryption/KeyGenerator.cs b/src/Marten/Storage/Encryption/KeyGenerator.cs
new file mode 100644
index 0000000000..83a6eb2f3b
--- /dev/null
+++ b/src/Marten/Storage/Encryption/KeyGenerator.cs
@@ -0,0 +1,26 @@
+using System;
+using System.Security.Cryptography;
+
+namespace Marten.Storage.Encryption;
+
+///
+/// Utility class for generating secure encryption keys compatible with Marten's encryption providers.
+///
+public static class KeyGenerator
+{
+ ///
+ /// Generates a secure random encryption key suitable for use with Marten's encryption providers.
+ /// The key is generated using a cryptographically secure random number generator.
+ ///
+ /// The number of random bytes to generate. Default is 24 bytes which produces a 32-character base64 string.
+ /// Use this flag for only testing purposes.
+ /// A base64-encoded string suitable for use as an encryption key
+ public static string GenerateKey(int byteLength = 24, bool allowAnyByteLength = false)
+ {
+ if (byteLength < 16 && !allowAnyByteLength)
+ throw new ArgumentException("For security, key should be at least 16 bytes (128 bits)", nameof(byteLength));
+
+ var keyBytes = RandomNumberGenerator.GetBytes(byteLength);
+ return Convert.ToBase64String(keyBytes);
+ }
+}
diff --git a/src/Marten/Storage/Encryption/NoopConnectionStringEncryptor.cs b/src/Marten/Storage/Encryption/NoopConnectionStringEncryptor.cs
new file mode 100644
index 0000000000..444deb83a9
--- /dev/null
+++ b/src/Marten/Storage/Encryption/NoopConnectionStringEncryptor.cs
@@ -0,0 +1,52 @@
+using System;
+using Npgsql;
+
+namespace Marten.Storage.Encryption;
+
+///
+/// A no-operation implementation of IConnectionStringEncryptor that passes through connection strings without encryption.
+/// This provides a consistent interface when no encryption is needed.
+///
+internal class NoopConnectionStringEncryptor : IConnectionStringEncryptor
+{
+ ///
+ /// Returns the connection string as-is without encryption
+ ///
+ public string Encrypt(string connectionString) => connectionString;
+
+ ///
+ /// Returns the connection string as-is without decryption
+ ///
+ public string Decrypt(string encryptedConnectionString) => encryptedConnectionString;
+
+ ///
+ /// Generates a parameterized SQL command for inserting or updating an unencrypted connection string
+ ///
+ public (string sql, object[] parameters) GetInsertSql(string schemaName, string tableName, string tenantId, string connectionString)
+ {
+ var sql = $"insert into {schemaName}.{tableName} (tenant_id, connection_string) values (?, ?) " +
+ "on conflict (tenant_id) do update set connection_string = ?";
+
+ return (sql, [
+ tenantId,
+ connectionString,
+ connectionString
+ ]);
+ }
+
+ ///
+ /// Generates a parameterized SQL command for selecting unencrypted connection strings
+ ///
+ public (string sql, object[] parameters) GetSelectSql(string schemaName, string tableName, string tenantId)
+ {
+ if (tenantId == "*")
+ {
+ return ($"select tenant_id, connection_string from {schemaName}.{tableName}", Array.Empty());
+ }
+
+ var sql = $"select connection_string from {schemaName}.{tableName} where tenant_id = ?";
+
+ return (sql, [tenantId]);
+ }
+
+}
diff --git a/src/Marten/Storage/Encryption/PgCryptoConnectionStringEncryptor.cs b/src/Marten/Storage/Encryption/PgCryptoConnectionStringEncryptor.cs
new file mode 100644
index 0000000000..57ead66c45
--- /dev/null
+++ b/src/Marten/Storage/Encryption/PgCryptoConnectionStringEncryptor.cs
@@ -0,0 +1,98 @@
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Marten.Exceptions;
+using Npgsql;
+
+namespace Marten.Storage.Encryption;
+
+///
+/// Provides database-level encryption and decryption for connection strings using PostgreSQL's pgcrypto extension.
+/// Connection strings are encrypted and decrypted directly in the database using pgp_sym_encrypt and pgp_sym_decrypt functions.
+///
+internal class PgCryptoConnectionStringEncryptor : IConnectionStringEncryptor
+{
+ private readonly string _encryptionKey;
+
+ ///
+ /// Initializes a new instance of the PgCrypto encryption provider.
+ ///
+ /// The encryption key for pgcrypto functions
+ /// Thrown when the encryption key is null
+ public PgCryptoConnectionStringEncryptor(string encryptionKey)
+ {
+ _encryptionKey = encryptionKey ?? throw new ArgumentNullException(nameof(encryptionKey));
+ }
+
+ ///
+ /// Returns the connection string as-is since encryption is handled at the database level.
+ ///
+ public string Encrypt(string connectionString) => connectionString;
+
+ ///
+ /// Returns the connection string as-is since decryption is handled at the database level.
+ ///
+ public string Decrypt(string encryptedConnectionString) => encryptedConnectionString;
+
+ ///
+ /// Generates SQL to insert or update an encrypted connection string using pgp_sym_encrypt.
+ /// The encryption is performed by the database using the pgcrypto extension.
+ ///
+ public (string sql, object[] parameters) GetInsertSql(string schemaName, string tableName, string tenantId, string connectionString)
+ {
+ return ($"insert into {schemaName}.{tableName} (tenant_id, connection_string) " +
+ $"values (?, pgp_sym_encrypt(?::text, ?::text)) " +
+ $"on conflict (tenant_id) do update set connection_string = pgp_sym_encrypt(?::text, ?::text)",
+ [
+ tenantId,
+ connectionString,
+ _encryptionKey,
+ connectionString,
+ _encryptionKey,
+ ]);
+ }
+
+ ///
+ /// Generates SQL to select and decrypt connection strings using pgp_sym_decrypt.
+ /// The decryption is performed by the database using the pgcrypto extension.
+ ///
+ public (string sql, object[] parameters) GetSelectSql(string schemaName, string tableName, string tenantId)
+ {
+ var whereClause = tenantId == "*" ? "" : " where tenant_id = ?";
+ var sql = $"select tenant_id, pgp_sym_decrypt(connection_string::bytea, ?::text) as connection_string " +
+ $"from {schemaName}.{tableName}{whereClause}";
+
+ return (sql, [_encryptionKey, tenantId]);
+ }
+
+ ///
+ /// Ensures the pgcrypto extension is available and properly configured in the specified schema.
+ ///
+ public async Task EnsurePrerequisitesAsync(NpgsqlDataSource dataSource, string schemaName, CancellationToken token = default)
+ {
+ await using var conn = await dataSource.OpenConnectionAsync(token).ConfigureAwait(false);
+ try
+ {
+ await using var cmd = conn.CreateCommand();
+ cmd.CommandText = "SELECT EXISTS(" +
+ "SELECT 1 FROM pg_extension e " +
+ "WHERE e.extname = 'pgcrypto')";
+ cmd.Parameters.AddWithValue("schema_name", schemaName);
+ var extensionExists = (bool)await cmd.ExecuteScalarAsync(token).ConfigureAwait(false);
+
+ if (!extensionExists)
+ {
+ throw new MartenException(
+ $"PgCrypto encryption requires the pgcrypto extension to be installed.\n" +
+ $"Run 'CREATE EXTENSION IF NOT EXISTS pgcrypto;' as a superuser or contact your database administrator.")
+ {
+ HelpLink = "https://www.postgresql.org/docs/current/pgcrypto.html"
+ };
+ }
+ }
+ finally
+ {
+ await conn.CloseAsync().ConfigureAwait(false);
+ }
+ }
+}
diff --git a/src/Marten/Storage/MasterTableTenancy.cs b/src/Marten/Storage/MasterTableTenancy.cs
index 57ed910e9a..9785ef355f 100644
--- a/src/Marten/Storage/MasterTableTenancy.cs
+++ b/src/Marten/Storage/MasterTableTenancy.cs
@@ -1,12 +1,13 @@
using System;
using System.Collections.Generic;
-using System.Data.Common;
using System.Linq;
using System.Reflection;
+using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using JasperFx.Core;
using Marten.Schema;
+using Marten.Storage.Encryption;
using Npgsql;
using Weasel.Core;
using Weasel.Core.Migrations;
@@ -29,6 +30,15 @@ public class MasterTableTenancyOptions
///
public NpgsqlDataSource? DataSource { get; set; }
+ ///
+ /// The encryption key used to encrypt/decrypt connection strings. Must be exactly 32 bytes.
+ ///
+ public EncryptionOptions ConnectionStringEncryptionOpts
+ {
+ get;
+ set;
+ } = new();
+
///
/// If specified, override the database schema name for the tenants table
/// Default is "public"
@@ -73,7 +83,6 @@ internal string CorrectConnectionString(string connectionString)
if (ApplicationName.IsNotEmpty())
{
var builder = new NpgsqlConnectionStringBuilder(connectionString) { ApplicationName = ApplicationName };
-
return builder.ConnectionString;
}
@@ -89,6 +98,7 @@ public class MasterTableTenancy: ITenancy, ITenancyWithMasterDatabase
private readonly string _schemaName;
private readonly Lazy _tenantDatabase;
private ImHashMap _databases = ImHashMap.Empty;
+ private readonly IConnectionStringEncryptor _encryptionProvider;
private bool _hasAppliedChanges;
private bool _hasAppliedDefaults;
@@ -101,7 +111,6 @@ public MasterTableTenancy(StoreOptions options, string connectionString, string
public MasterTableTenancy(StoreOptions options, MasterTableTenancyOptions tenancyOptions)
{
_options = options;
-
_configuration = tenancyOptions;
if (tenancyOptions.DataSource != null)
@@ -124,9 +133,20 @@ public MasterTableTenancy(StoreOptions options, MasterTableTenancyOptions tenanc
Cleaner = new CompositeDocumentCleaner(this, _options);
_tenantDatabase = new Lazy(() =>
- new TenantLookupDatabase(_options, _dataSource.Value, tenancyOptions.SchemaName));
+ new TenantLookupDatabase(_options, _dataSource.Value, tenancyOptions.SchemaName, _configuration.ConnectionStringEncryptionOpts));
+
+ _encryptionProvider = _configuration.ConnectionStringEncryptionOpts.Type switch
+ {
+ ConnectionStringEncryption.AES when _configuration.ConnectionStringEncryptionOpts.Key != null => new AesConnectionStringEncryptor(_configuration.ConnectionStringEncryptionOpts.Key),
+ ConnectionStringEncryption.PgCrypto when _configuration.ConnectionStringEncryptionOpts.Key != null => new PgCryptoConnectionStringEncryptor(_configuration.ConnectionStringEncryptionOpts.Key),
+ _ => new NoopConnectionStringEncryptor()
+ };
+
+ // Provider prerequisites will be checked in BuildDatabases
}
+
+
public void Dispose()
{
foreach (var entry in _databases.Enumerate()) entry.Value.Dispose();
@@ -139,6 +159,8 @@ public void Dispose()
public async ValueTask> BuildDatabases()
{
+ await _encryptionProvider.EnsurePrerequisitesAsync(_dataSource.Value, _schemaName).ConfigureAwait(false);
+
await maybeApplyChanges(_tenantDatabase.Value).ConfigureAwait(false);
await using var conn = _dataSource.Value.CreateConnection();
@@ -151,10 +173,15 @@ public async ValueTask> BuildDatabases()
await seedDatabasesAsync(conn).ConfigureAwait(false);
}
- await using var reader = await ((DbCommand)conn
- .CreateCommand($"select tenant_id, connection_string from {_schemaName}.{TenantTable.TableName}"))
- .ExecuteReaderAsync().ConfigureAwait(false);
+ var command = conn.CreateCommand();
+ var (sql, parameters) = _encryptionProvider.GetSelectSql(_schemaName, TenantTable.TableName, "*");
+ command.CommandText = ConvertToPgPositionalParams(sql);
+ foreach (var param in parameters)
+ {
+ command.Parameters.Add(new NpgsqlParameter { Value = param });
+ }
+ await using var reader = await command.ExecuteReaderAsync().ConfigureAwait(false);
while (await reader.ReadAsync().ConfigureAwait(false))
{
var tenantId = await reader.GetFieldValueAsync(0).ConfigureAwait(false);
@@ -167,6 +194,7 @@ public async ValueTask> BuildDatabases()
}
var connectionString = await reader.GetFieldValueAsync(1).ConfigureAwait(false);
+ connectionString = _encryptionProvider?.Decrypt(connectionString) ?? connectionString;
connectionString = _configuration.CorrectConnectionString(connectionString);
var database = new MartenDatabase(_options, _options.NpgsqlDataSourceFactory.Create(connectionString),
@@ -258,46 +286,51 @@ public bool IsTenantStoredInCurrentDatabase(IMartenDatabase database, string ten
public async Task DeleteDatabaseRecordAsync(string tenantId)
{
tenantId = _options.MaybeCorrectTenantId(tenantId);
+#pragma warning disable MA0032
await maybeApplyChanges(_tenantDatabase.Value).ConfigureAwait(false);
+#pragma warning restore MA0032
- await _dataSource.Value
- .CreateCommand($"delete from {_schemaName}.{TenantTable.TableName} where tenant_id = :id")
- .With("id", tenantId)
- .ExecuteNonQueryAsync(CancellationToken.None).ConfigureAwait(false);
+ var cmd = _dataSource.Value.CreateCommand($"delete from {_schemaName}.{TenantTable.TableName} where tenant_id = @tenant_id");
+ cmd.Parameters.AddWithValue("tenant_id", tenantId);
+ await cmd.ExecuteNonQueryAsync(CancellationToken.None).ConfigureAwait(false);
}
public async Task ClearAllDatabaseRecordsAsync()
{
+#pragma warning disable MA0032
await maybeApplyChanges(_tenantDatabase.Value).ConfigureAwait(false);
+#pragma warning restore MA0032
- await _dataSource.Value.CreateCommand($"delete from {_schemaName}.{TenantTable.TableName}")
- .ExecuteNonQueryAsync(CancellationToken.None).ConfigureAwait(false);
+ var cmd = _dataSource.Value.CreateCommand($"delete from {_schemaName}.{TenantTable.TableName}");
+ await cmd.ExecuteNonQueryAsync(CancellationToken.None).ConfigureAwait(false);
}
public async Task AddDatabaseRecordAsync(string tenantId, string connectionString)
{
tenantId = _options.MaybeCorrectTenantId(tenantId);
- await _dataSource.Value
- .CreateCommand(
- $"insert into {_schemaName}.{TenantTable.TableName} (tenant_id, connection_string) values (:id, :connection) on conflict (tenant_id) do update set connection_string = :connection")
- .With("id", tenantId)
- .With("connection", connectionString)
- .ExecuteNonQueryAsync(CancellationToken.None).ConfigureAwait(false);
+
+ var (sql, parameters) = _encryptionProvider.GetInsertSql(_schemaName, TenantTable.TableName, tenantId, connectionString);
+ var cmd = _dataSource.Value.CreateCommand(ConvertToPgPositionalParams(sql));
+ foreach (var param in parameters)
+ {
+ cmd.Parameters.Add(new NpgsqlParameter { Value = param });
+ }
+#pragma warning disable MA0032
+ await cmd.ExecuteNonQueryAsync(CancellationToken.None).ConfigureAwait(false);
+#pragma warning restore MA0032
}
- private async Task maybeApplyChanges(TenantLookupDatabase tenantDatabase)
+ private async Task maybeApplyChanges(TenantLookupDatabase tenantDatabase, CancellationToken token = default)
{
if (!_hasAppliedChanges && (_configuration.AutoCreate ?? _options.AutoCreateSchemaObjects) != AutoCreate.None)
{
-#pragma warning disable MA0032
await tenantDatabase
- .ApplyAllConfiguredChangesToDatabaseAsync(_options.AutoCreateSchemaObjects).ConfigureAwait(false);
-#pragma warning restore MA0032
+ .ApplyAllConfiguredChangesToDatabaseAsync(_options.AutoCreateSchemaObjects, ct: default).ConfigureAwait(false);
_hasAppliedChanges = true;
}
}
- private async Task seedDatabasesAsync(NpgsqlConnection conn)
+ private async Task seedDatabasesAsync(NpgsqlConnection conn, CancellationToken token = default)
{
if (!_configuration.SeedDatabases.Any())
{
@@ -308,17 +341,17 @@ private async Task seedDatabasesAsync(NpgsqlConnection conn)
foreach (var pair in _configuration.SeedDatabases)
{
builder.StartNewCommand();
- var parameters = builder.AppendWithParameters(
- $"insert into {_schemaName}.{TenantTable.TableName} (tenant_id, connection_string) values (?, ?) on conflict (tenant_id) do update set connection_string = ?");
-
- parameters[0].Value = pair.Key;
- parameters[1].Value = pair.Value;
- parameters[2].Value = pair.Value;
+ var (sql, parameters) = _encryptionProvider.GetInsertSql(_schemaName, TenantTable.TableName, pair.Key, pair.Value);
+ var builderParams = builder.AppendWithParameters(sql);
+ for (var i = 0; i < builderParams.Length; i++)
+ {
+ builderParams[i].Value = parameters[i];
+ }
}
var batch = builder.Compile();
batch.Connection = conn;
- await batch.ExecuteNonQueryAsync(CancellationToken.None).ConfigureAwait(false);
+ await batch.ExecuteNonQueryAsync(token).ConfigureAwait(false);
_hasAppliedDefaults = true;
}
@@ -326,17 +359,27 @@ private async Task seedDatabasesAsync(NpgsqlConnection conn)
private async Task tryFindTenantDatabase(string tenantId)
{
tenantId = _options.MaybeCorrectTenantId(tenantId);
- var connectionString = (string)await _dataSource.Value
- .CreateCommand($"select connection_string from {_schemaName}.{TenantTable.TableName} where tenant_id = :id")
- .With("id", tenantId)
- .ExecuteScalarAsync(CancellationToken.None).ConfigureAwait(false);
- if (connectionString.IsEmpty())
+ var (sql, parameters) = _encryptionProvider.GetSelectSql(_schemaName, TenantTable.TableName, tenantId);
+ var cmd = _dataSource.Value.CreateCommand(ConvertToPgPositionalParams(sql));
+ foreach (var param in parameters)
+ {
+ cmd.Parameters.Add(new NpgsqlParameter { Value = param });
+ }
+
+ await using var reader = await cmd.ExecuteReaderAsync(CancellationToken.None).ConfigureAwait(false);
+ var connectionString = string.Empty;
+ while (await reader.ReadAsync(CancellationToken.None).ConfigureAwait(false))
{
- return null;
+ connectionString = await reader.GetFieldValueAsync(1, CancellationToken.None).ConfigureAwait(false);
+ if (!string.IsNullOrWhiteSpace(connectionString))
+ {
+ connectionString = _encryptionProvider?.Decrypt(connectionString) ?? connectionString;
+ connectionString = _configuration.CorrectConnectionString(connectionString);
+ }
}
- connectionString = _configuration.CorrectConnectionString(connectionString);
+ await reader.CloseAsync().ConfigureAwait(false);
return connectionString.IsNotEmpty()
? new MartenDatabase(_options,
@@ -344,14 +387,23 @@ private async Task seedDatabasesAsync(NpgsqlConnection conn)
: null;
}
+ private string ConvertToPgPositionalParams(string sql)
+ {
+ var index = 0;
+#pragma warning disable MA0009
+ return Regex.Replace(sql, @"\?", _ => $"${++index}");
+#pragma warning restore MA0009
+ }
+
+
internal class TenantLookupDatabase: PostgresqlDatabase
{
private readonly TenantDatabaseStorage _feature;
- public TenantLookupDatabase(StoreOptions options, NpgsqlDataSource dataSource, string schemaName): base(options,
+ public TenantLookupDatabase(StoreOptions options, NpgsqlDataSource dataSource, string schemaName, EncryptionOptions encryptionOpts): base(options,
options.AutoCreateSchemaObjects, options.Advanced.Migrator, "TenantDatabases", dataSource)
{
- _feature = new TenantDatabaseStorage(schemaName, options);
+ _feature = new TenantDatabaseStorage(schemaName, options, encryptionOpts);
}
public override IFeatureSchema[] BuildFeatureSchemas()
@@ -364,16 +416,23 @@ internal class TenantDatabaseStorage: FeatureSchemaBase
{
private readonly StoreOptions _options;
private readonly string _schemaName;
+ private readonly EncryptionOptions _encryptionOpts;
- public TenantDatabaseStorage(string schemaName, StoreOptions options): base("TenantDatabases",
+ public TenantDatabaseStorage(string schemaName, StoreOptions options,EncryptionOptions encryptionOpts): base("TenantDatabases",
options.Advanced.Migrator)
{
_schemaName = schemaName;
+ _encryptionOpts = encryptionOpts;
}
protected override IEnumerable schemaObjects()
{
yield return new TenantTable(_schemaName);
+
+ if (_encryptionOpts.Type == ConnectionStringEncryption.PgCrypto)
+ {
+ yield return new Extension("pgcrypto");
+ }
}
}
diff --git a/src/MultiTenancyTests/using_master_table_multi_tenancy_with_encryption.cs b/src/MultiTenancyTests/using_master_table_multi_tenancy_with_encryption.cs
new file mode 100644
index 0000000000..d4cb5d6857
--- /dev/null
+++ b/src/MultiTenancyTests/using_master_table_multi_tenancy_with_encryption.cs
@@ -0,0 +1,209 @@
+using System;
+using System.Threading.Tasks;
+using Marten;
+using Marten.Services;
+using Marten.Storage.Encryption;
+using Marten.Testing.Documents;
+using Marten.Testing.Harness;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Hosting;
+using Npgsql;
+using Shouldly;
+using Weasel.Core;
+using Weasel.Postgresql;
+using Weasel.Postgresql.Migrations;
+
+namespace MultiTenancyTests;
+
+
+[CollectionDefinition("multi-tenancy", DisableParallelization = true)]
+public class using_master_table_multi_tenancy_with_aes_encryption: IAsyncLifetime
+{
+ private IHost _host;
+ private IDocumentStore theStore;
+ private string tenant1ConnectionString;
+
+ private async Task CreateDatabaseIfNotExists(NpgsqlConnection conn, string databaseName)
+ {
+ var builder = new NpgsqlConnectionStringBuilder(ConnectionSource.ConnectionString);
+
+ var exists = await conn.DatabaseExists(databaseName);
+
+ if (!exists)
+ {
+ await new DatabaseSpecification().BuildDatabase(conn, databaseName);
+ }
+
+ builder.Database = databaseName;
+
+ return builder.ConnectionString;
+ }
+
+ public async Task InitializeAsync()
+ {
+ await using var conn = new NpgsqlConnection(ConnectionSource.ConnectionString);
+ await conn.OpenAsync();
+ await conn.DropSchemaAsync("tenants");
+
+ tenant1ConnectionString = await CreateDatabaseIfNotExists(conn, "tenant1");
+
+ _host = await Host.CreateDefaultBuilder()
+ .ConfigureServices(services =>
+ {
+ services.AddNpgsqlDataSource(ConnectionSource.ConnectionString);
+ services.AddMarten(opts =>
+ {
+ opts.RegisterDocumentType();
+ opts.RegisterDocumentType();
+ })
+ // All detected changes will be applied to all
+ // the configured tenant databases on startup
+ .ApplyAllDatabaseChangesOnStartup();
+
+ services.ConfigureMarten((sp, so) =>
+ {
+ so.MultiTenantedDatabasesWithMasterDatabaseTable(x =>
+ {
+ x.DataSource = sp.GetRequiredService();
+ x.ConnectionStringEncryptionOpts.UseAes(KeyGenerator.GenerateKey());
+ x.SchemaName = "tenants";
+ x.ApplicationName = "Sample";
+ x.RegisterDatabase("tenant1", tenant1ConnectionString);
+ });
+ });
+ }).StartAsync();
+
+
+ theStore = _host.Services.GetRequiredService();
+
+ await _host.ClearAllTenantDatabaseRecordsAsync();
+ }
+
+ public async Task DisposeAsync()
+ {
+ await _host.StopAsync();
+ theStore.Dispose();
+ }
+
+ [Theory]
+ [InlineData(8)] // Too short
+ [InlineData(15)] // Just under minimum
+ [InlineData(33)] // Just over maximum
+ [InlineData(64)] // Way too long
+ public void aes_encryption_rejects_invalid_key_lengths(int keyBytes)
+ {
+ // Arrange
+ var key = KeyGenerator.GenerateKey(keyBytes, true);
+
+ // Act & Assert
+ Should.Throw(() =>
+ {
+ var opts = new EncryptionOptions().UseAes(key);
+ }).Message.ShouldContain("AES encryption key must be between 16 and 32 bytes (128-256 bits)");
+ }
+
+ [Fact]
+ public async Task can_open_a_session_to_a_different_database()
+ {
+ await _host.AddTenantDatabaseAsync("tenant1", tenant1ConnectionString);
+
+ await using var session =
+ theStore.LightweightSession(new SessionOptions { TenantId = "tenant1" });
+ }
+}
+
+[CollectionDefinition("multi-tenancy", DisableParallelization = true)]
+public class using_master_table_multi_tenancy_with_pgcrypto_encryption: IAsyncLifetime
+{
+ private IHost _host;
+ private IDocumentStore theStore;
+ private string tenant1ConnectionString;
+
+ private async Task CreateDatabaseIfNotExists(NpgsqlConnection conn, string databaseName)
+ {
+ var builder = new NpgsqlConnectionStringBuilder(ConnectionSource.ConnectionString);
+
+ var exists = await conn.DatabaseExists(databaseName);
+ if (!exists)
+ {
+ await new DatabaseSpecification().BuildDatabase(conn, databaseName);
+ }
+
+ builder.Database = databaseName;
+
+ return builder.ConnectionString;
+ }
+
+ public async Task InitializeAsync()
+ {
+ await using var conn = new NpgsqlConnection(ConnectionSource.ConnectionString);
+ await conn.OpenAsync();
+ await conn.DropSchemaAsync("tenants");
+
+ tenant1ConnectionString = await CreateDatabaseIfNotExists(conn, "tenant1");
+
+ await conn.CloseAsync();
+
+ _host = await Host.CreateDefaultBuilder()
+ .ConfigureServices(services =>
+ {
+ services.AddNpgsqlDataSource(ConnectionSource.ConnectionString);
+ services.AddMarten(opts =>
+ {
+ opts.RegisterDocumentType();
+ opts.RegisterDocumentType();
+
+ opts.MultiTenantedDatabasesWithMasterDatabaseTable(x =>
+ {
+ x.AutoCreate = AutoCreate.CreateOrUpdate;
+ x.ConnectionString = ConnectionSource.ConnectionString;
+ x.ConnectionStringEncryptionOpts.UsePgCrypto(KeyGenerator.GenerateKey());
+ x.SchemaName = "tenants";
+ x.ApplicationName = "Sample";
+ x.RegisterDatabase("tenant1", tenant1ConnectionString);
+ });
+ })
+ // All detected changes will be applied to all
+ // the configured tenant databases on startup
+ .ApplyAllDatabaseChangesOnStartup();
+ }).StartAsync();
+
+ theStore = _host.Services.GetRequiredService();
+
+ await _host.ClearAllTenantDatabaseRecordsAsync();
+ }
+
+ public async Task DisposeAsync()
+ {
+ await _host.StopAsync();
+ theStore.Dispose();
+ }
+
+ [Theory]
+ [InlineData(8)] // Too short
+ [InlineData(15)] // Just under minimum
+ public void pgcrypto_encryption_rejects_invalid_key_lengths(int keyBytes)
+ {
+ // Arrange
+ var key = KeyGenerator.GenerateKey(keyBytes, true);
+
+ // Act & Assert
+ Should.Throw(() =>
+ {
+ var opts = new EncryptionOptions().UsePgCrypto(key);
+ }).Message.ShouldContain("at least 16 bytes");
+ }
+
+ [Fact]
+ public async Task can_open_a_session_to_a_different_database()
+ {
+ await _host.AddTenantDatabaseAsync("tenant1", tenant1ConnectionString);
+
+ await using var session =
+ theStore.LightweightSession(new SessionOptions { TenantId = "tenant1" });
+
+ session.Connection.Database.ShouldBe("tenant1");
+ }
+}
+
+