Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions src/Marten/Storage/Encryption/AesConnectionStringEncryptor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
using System;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Npgsql;

namespace Marten.Storage.Encryption;

/// <summary>
/// Provides AES-based encryption and decryption for connection strings using a 32-byte key.
/// Connection strings are encrypted in memory using AES with a randomly generated IV.
/// </summary>
internal class AesConnectionStringEncryptor : IConnectionStringEncryptor
{
private readonly string _encryptionKey;

/// <summary>
/// Initializes a new instance of the AES encryption provider.
/// </summary>
/// <param name="encryptionKey">The encryption key for AES encryption</param>
/// <exception cref="ArgumentException">Thrown when the key is null or empty</exception>
public AesConnectionStringEncryptor(string encryptionKey)
{
if (string.IsNullOrWhiteSpace(encryptionKey))
throw new ArgumentException("AES encryption key cannot be empty or whitespace", nameof(encryptionKey));

_encryptionKey = encryptionKey;
}

public string Encrypt(string connectionString)
{
using var aes = Aes.Create();
using var deriveBytes = new Rfc2898DeriveBytes(_encryptionKey, 16, 1000, HashAlgorithmName.SHA256);
aes.Key = deriveBytes.GetBytes(32);
aes.IV = deriveBytes.GetBytes(16);

using var encryptor = aes.CreateEncryptor();
var plainTextBytes = Encoding.UTF8.GetBytes(connectionString);
var cipherTextBytes = encryptor.TransformFinalBlock(plainTextBytes, 0, plainTextBytes.Length);

var result = new byte[aes.IV.Length + cipherTextBytes.Length];
Buffer.BlockCopy(deriveBytes.Salt, 0, result, 0, 16);
Buffer.BlockCopy(cipherTextBytes, 0, result, 16, cipherTextBytes.Length);

return Convert.ToBase64String(result);
}

public string Decrypt(string encryptedConnectionString)
{
try
{
var cipherTextBytes = Convert.FromBase64String(encryptedConnectionString);
if (cipherTextBytes.Length < 16) // IV size
return encryptedConnectionString;

using var aes = Aes.Create();
var salt = new byte[16];
var cipher = new byte[cipherTextBytes.Length - 16];
Buffer.BlockCopy(cipherTextBytes, 0, salt, 0, 16);
Buffer.BlockCopy(cipherTextBytes, 16, cipher, 0, cipherTextBytes.Length - 16);
using var deriveBytes = new Rfc2898DeriveBytes(_encryptionKey, salt, 1000, HashAlgorithmName.SHA256);
aes.Key = deriveBytes.GetBytes(32);
aes.IV = deriveBytes.GetBytes(16);

using var decryptor = aes.CreateDecryptor();
var plainTextBytes = decryptor.TransformFinalBlock(cipher, 0, cipher.Length);
return Encoding.UTF8.GetString(plainTextBytes);
}
catch
{
// If decryption fails, return the original string
return encryptedConnectionString;
}
}

public (string sql, object[] parameters) GetInsertSql(string schemaName, string tableName, string tenantId, string connectionString)
{
var encryptedString = Encrypt(connectionString);
return ($"insert into {schemaName}.{tableName} (tenant_id, connection_string) values (?, ?) " +
"on conflict (tenant_id) do update set connection_string = ?",
[
tenantId,
encryptedString,
encryptedString
]);
}

public (string sql, object[] parameters) GetSelectSql(string schemaName, string tableName, string tenantId)
{
return ($"select tenant_id, connection_string from {schemaName}.{tableName}" +
(tenantId == "*" ? "" : " where tenant_id = ?"),
tenantId == "*" ? Array.Empty<NpgsqlParameter>() : [tenantId]);
}

// No prerequisites needed for AES encryption since it's done in memory
}
100 changes: 100 additions & 0 deletions src/Marten/Storage/Encryption/EncryptionOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
using System;

namespace Marten.Storage.Encryption;

/// <summary>
/// Specifies the encryption method to use for tenant database connection strings.
/// </summary>
public enum ConnectionStringEncryption
{
/// <summary>
/// No encryption of connection strings. Connection strings will be stored as plain text.
/// </summary>
None,

/// <summary>
/// Use AES encryption with a provided 32-byte encryption key.
/// Connection strings are encrypted in memory using AES with a randomly generated IV.
/// The encrypted data is stored as base64-encoded strings.
/// </summary>
AES,

/// <summary>
/// Use PostgreSQL's pgcrypto extension for encryption.
/// Connection strings are encrypted and decrypted directly in the database using
/// pgp_sym_encrypt and pgp_sym_decrypt functions. This requires the pgcrypto
/// extension to be installed in the same schema as the tenant table.
/// </summary>
PgCrypto
}

/// <summary>
/// Options for configuring connection string encryption.
/// </summary>
public class EncryptionOptions
{
private string? _encryptionKey;

/// <summary>
/// The type of encryption to use for connection strings.
/// </summary>
public ConnectionStringEncryption Type { get; private set; } = ConnectionStringEncryption.None;

/// <summary>
/// The encryption key used to encrypt/decrypt connection strings.
/// Must be exactly 32 characters long.
/// </summary>
public string? Key
{
get => _encryptionKey;
private set
{
if (value != null && string.IsNullOrWhiteSpace(value))
throw new ArgumentException("Encryption key cannot be empty or whitespace", nameof(value));
_encryptionKey = value;
}
}

/// <summary>
/// Use AES encryption with the specified key for connection strings.
/// </summary>
/// <param name="key">The encryption key for AES encryption</param>
/// <returns>The current options instance for method chaining</returns>
public EncryptionOptions UseAes(string key)
{
var keyBytes = Convert.FromBase64String(key);
if (keyBytes.Length < 16 || keyBytes.Length > 32)
throw new ArgumentException("AES encryption key must be between 16 and 32 bytes (128-256 bits) when base64 decoded", nameof(key));

Key = key;
Type = ConnectionStringEncryption.AES;
return this;
}

/// <summary>
/// Use PostgreSQL's pgcrypto extension with the specified key for connection strings.
/// </summary>
/// <param name="key">The encryption key for pgcrypto encryption</param>
/// <returns>The current options instance for method chaining</returns>
public EncryptionOptions UsePgCrypto(string key)
{
var keyBytes = Convert.FromBase64String(key);
if (keyBytes.Length < 16)
throw new ArgumentException("PgCrypto encryption key must be at least 16 bytes (128 bits) when base64 decoded", nameof(key));

Key = key;
Type = ConnectionStringEncryption.PgCrypto;
return this;
}

/// <summary>
/// Disable encryption for connection strings.
/// </summary>
/// <returns>The current options instance for method chaining</returns>
public EncryptionOptions UseNoEncryption()
{
Key = null;
Type = ConnectionStringEncryption.None;
return this;
}
}
61 changes: 61 additions & 0 deletions src/Marten/Storage/Encryption/IConnectionStringEncryptor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using System.Threading;
using System.Threading.Tasks;
using Npgsql;

namespace Marten.Storage.Encryption;

/// <summary>
/// Provides encryption and decryption services for tenant database connection strings.
/// Implementations should handle both the encryption/decryption operations and
/// the generation of SQL commands for database operations involving encrypted data.
/// </summary>
internal interface IConnectionStringEncryptor
{
/// <summary>
/// Encrypts a connection string using the provider's encryption method.
/// </summary>
/// <param name="connectionString">The connection string to encrypt</param>
/// <returns>The encrypted connection string</returns>
string Encrypt(string connectionString);

/// <summary>
/// Decrypts an encrypted connection string using the provider's decryption method.
/// If decryption fails, implementations should return the original string.
/// </summary>
/// <param name="encryptedConnectionString">The encrypted connection string to decrypt</param>
/// <returns>The decrypted connection string, or the original string if decryption fails</returns>
string Decrypt(string encryptedConnectionString);

/// <summary>
/// Generates a parameterized SQL command for inserting or updating an encrypted connection string.
/// </summary>
/// <param name="schemaName">The database schema name</param>
/// <param name="tableName">The table name</param>
/// <param name="tenantId">The tenant identifier</param>
/// <param name="connectionString">The connection string to encrypt and store</param>
/// <returns>A tuple containing the SQL command text and its parameters</returns>
(string sql, object[] parameters) GetInsertSql(string schemaName, string tableName, string tenantId, string connectionString);

/// <summary>
/// Generates a parameterized SQL command for selecting and decrypting connection strings.
/// </summary>
/// <param name="schemaName">The database schema name</param>
/// <param name="tableName">The table name</param>
/// <param name="tenantId">The tenant identifier, use "*" to select all tenants</param>
/// <returns>A tuple containing the SQL command text and its parameters</returns>
(string sql, object[] parameters) GetSelectSql(string schemaName, string tableName, string tenantId);

/// <summary>
/// Ensures any prerequisites required by the encryption provider are met.
/// For example, checking if required database extensions are installed.
/// </summary>
/// <param name="dataSource">The database data source to check against</param>
/// <param name="schemaName">The database schema name</param>
/// <param name="token">Optional cancellation token</param>
/// <returns>A task representing the asynchronous operation</returns>
Task EnsurePrerequisitesAsync(NpgsqlDataSource dataSource, string schemaName, CancellationToken token = default)
{
// Default implementation assumes no prerequisites are needed
return Task.CompletedTask;
}
}
26 changes: 26 additions & 0 deletions src/Marten/Storage/Encryption/KeyGenerator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using System;
using System.Security.Cryptography;

namespace Marten.Storage.Encryption;

/// <summary>
/// Utility class for generating secure encryption keys compatible with Marten's encryption providers.
/// </summary>
public static class KeyGenerator
{
/// <summary>
/// Generates a secure random encryption key suitable for use with Marten's encryption providers.
/// The key is generated using a cryptographically secure random number generator.
/// </summary>
/// <param name="byteLength">The number of random bytes to generate. Default is 24 bytes which produces a 32-character base64 string.</param>
/// <param name="allowAnyByteLength">Use this flag for only testing purposes.</param>
/// <returns>A base64-encoded string suitable for use as an encryption key</returns>
public static string GenerateKey(int byteLength = 24, bool allowAnyByteLength = false)
{
if (byteLength < 16 && !allowAnyByteLength)
throw new ArgumentException("For security, key should be at least 16 bytes (128 bits)", nameof(byteLength));

var keyBytes = RandomNumberGenerator.GetBytes(byteLength);
return Convert.ToBase64String(keyBytes);
}
}
52 changes: 52 additions & 0 deletions src/Marten/Storage/Encryption/NoopConnectionStringEncryptor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using System;
using Npgsql;

namespace Marten.Storage.Encryption;

/// <summary>
/// A no-operation implementation of IConnectionStringEncryptor that passes through connection strings without encryption.
/// This provides a consistent interface when no encryption is needed.
/// </summary>
internal class NoopConnectionStringEncryptor : IConnectionStringEncryptor
{
/// <summary>
/// Returns the connection string as-is without encryption
/// </summary>
public string Encrypt(string connectionString) => connectionString;

/// <summary>
/// Returns the connection string as-is without decryption
/// </summary>
public string Decrypt(string encryptedConnectionString) => encryptedConnectionString;

/// <summary>
/// Generates a parameterized SQL command for inserting or updating an unencrypted connection string
/// </summary>
public (string sql, object[] parameters) GetInsertSql(string schemaName, string tableName, string tenantId, string connectionString)
{
var sql = $"insert into {schemaName}.{tableName} (tenant_id, connection_string) values (?, ?) " +
"on conflict (tenant_id) do update set connection_string = ?";

return (sql, [
tenantId,
connectionString,
connectionString
]);
}

/// <summary>
/// Generates a parameterized SQL command for selecting unencrypted connection strings
/// </summary>
public (string sql, object[] parameters) GetSelectSql(string schemaName, string tableName, string tenantId)
{
if (tenantId == "*")
{
return ($"select tenant_id, connection_string from {schemaName}.{tableName}", Array.Empty<NpgsqlParameter>());
}

var sql = $"select connection_string from {schemaName}.{tableName} where tenant_id = ?";

return (sql, [tenantId]);
}

}
Loading