diff --git a/playground/PostgresEndToEnd/PostgresEndToEnd.AppHost/Program.cs b/playground/PostgresEndToEnd/PostgresEndToEnd.AppHost/Program.cs index 78ae32c84c7..2ca4ccc5967 100644 --- a/playground/PostgresEndToEnd/PostgresEndToEnd.AppHost/Program.cs +++ b/playground/PostgresEndToEnd/PostgresEndToEnd.AppHost/Program.cs @@ -13,7 +13,7 @@ // Containerized resources. var db5 = builder.AddPostgres("pg4").WithPgAdmin().PublishAsContainer().AddDatabase("db5"); var db6 = builder.AddPostgres("pg5").WithPgAdmin().PublishAsContainer().AddDatabase("db6"); -var pg6 = builder.AddPostgres("pg6").WithPgAdmin(c => c.WithHostPort(8999).WithImageTag("8.3")).PublishAsContainer(); +var pg6 = builder.AddPostgres("pg6").WithPgAdmin(c => c.WithHostPort(8999)).PublishAsContainer(); var db7 = pg6.AddDatabase("db7"); var db8 = pg6.AddDatabase("db8"); var db9 = pg6.AddDatabase("db9", "db8"); // different connection string (db9) on same database as db8 diff --git a/src/Aspire.Hosting.PostgreSQL/PostgresBuilderExtensions.cs b/src/Aspire.Hosting.PostgreSQL/PostgresBuilderExtensions.cs index 0a40c3b8665..469178fe98f 100644 --- a/src/Aspire.Hosting.PostgreSQL/PostgresBuilderExtensions.cs +++ b/src/Aspire.Hosting.PostgreSQL/PostgresBuilderExtensions.cs @@ -8,6 +8,8 @@ using Aspire.Hosting.Postgres; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Npgsql; namespace Aspire.Hosting; @@ -67,6 +69,32 @@ public static IResourceBuilder AddPostgres(this IDistrib } }); + builder.Eventing.Subscribe(postgresServer, async (@event, ct) => + { + if (connectionString is null) + { + throw new DistributedApplicationException($"ResourceReadyEvent was published for the '{postgresServer.Name}' resource but the connection string was null."); + } + + // Non-database scoped connection string + using var npgsqlConnection = new NpgsqlConnection(connectionString + ";Database=postgres;"); + + await npgsqlConnection.OpenAsync(ct).ConfigureAwait(false); + + if (npgsqlConnection.State != System.Data.ConnectionState.Open) + { + throw new InvalidOperationException($"Could not open connection to '{postgresServer.Name}'"); + } + + foreach (var name in postgresServer.Databases.Keys) + { + if (builder.Resources.FirstOrDefault(n => string.Equals(n.Name, name, StringComparisons.ResourceName)) is PostgresDatabaseResource postgreDatabase) + { + await CreateDatabaseAsync(npgsqlConnection, postgreDatabase, @event.Services, ct).ConfigureAwait(false); + } + } + }); + var healthCheckKey = $"{name}_check"; builder.Services.AddHealthChecks().AddNpgSql(sp => connectionString ?? throw new InvalidOperationException("Connection string is unavailable"), name: healthCheckKey, configure: (connection) => { @@ -121,9 +149,28 @@ public static IResourceBuilder AddDatabase(this IResou // Use the resource name as the database name if it's not provided databaseName ??= name; - builder.Resource.AddDatabase(name, databaseName); var postgresDatabase = new PostgresDatabaseResource(name, databaseName, builder.Resource); - return builder.ApplicationBuilder.AddResource(postgresDatabase); + + builder.Resource.AddDatabase(postgresDatabase.Name, postgresDatabase.DatabaseName); + + string? connectionString = null; + + builder.ApplicationBuilder.Eventing.Subscribe(postgresDatabase, async (@event, ct) => + { + connectionString = await postgresDatabase.ConnectionStringExpression.GetValueAsync(ct).ConfigureAwait(false); + + if (connectionString == null) + { + throw new DistributedApplicationException($"ConnectionStringAvailableEvent was published for the '{name}' resource but the connection string was null."); + } + }); + + var healthCheckKey = $"{name}_check"; + builder.ApplicationBuilder.Services.AddHealthChecks().AddNpgSql(sp => connectionString ?? throw new InvalidOperationException("Connection string is unavailable"), name: healthCheckKey); + + return builder.ApplicationBuilder + .AddResource(postgresDatabase) + .WithHealthCheck(healthCheckKey); } /// @@ -418,6 +465,27 @@ public static IResourceBuilder WithInitBindMount(this IR return builder.WithBindMount(source, "/docker-entrypoint-initdb.d", isReadOnly); } + /// + /// Defines the SQL script used to create the database. + /// + /// The builder for the . + /// The SQL script used to create the database. + /// A reference to the . + /// + /// The script can only contain SQL statements applying to the default database like CREATE DATABASE. Custom statements like table creation + /// and data insertion are not supported since they require a distinct connection to the newly created database. + /// Default script is CREATE DATABASE "<QUOTED_DATABASE_NAME>" + /// + public static IResourceBuilder WithCreationScript(this IResourceBuilder builder, string script) + { + ArgumentNullException.ThrowIfNull(builder); + ArgumentNullException.ThrowIfNull(script); + + builder.WithAnnotation(new CreationScriptAnnotation(script)); + + return builder; + } + private static string WritePgWebBookmarks(IEnumerable postgresInstances, out byte[] contentHash) { var dir = Directory.CreateTempSubdirectory().FullName; @@ -488,4 +556,26 @@ private static string WritePgAdminServerJson(IEnumerable return filePath; } + + private static async Task CreateDatabaseAsync(NpgsqlConnection npgsqlConnection, PostgresDatabaseResource npgsqlDatabase, IServiceProvider serviceProvider, CancellationToken cancellationToken) + { + var scriptAnnotation = npgsqlDatabase.Annotations.OfType().LastOrDefault(); + + try + { + var quotedDatabaseIdentifier = new NpgsqlCommandBuilder().QuoteIdentifier(npgsqlDatabase.DatabaseName); + using var command = npgsqlConnection.CreateCommand(); + command.CommandText = scriptAnnotation?.Script ?? $"CREATE DATABASE {quotedDatabaseIdentifier}"; + await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + catch (PostgresException p) when (p.SqlState == "42P04") + { + // Ignore the error if the database already exists. + } + catch (Exception e) + { + var logger = serviceProvider.GetRequiredService().GetLogger(npgsqlDatabase.Parent); + logger.LogError(e, "Failed to create database '{DatabaseName}'", npgsqlDatabase.DatabaseName); + } + } } diff --git a/src/Aspire.Hosting.PostgreSQL/PostgresContainerImageTags.cs b/src/Aspire.Hosting.PostgreSQL/PostgresContainerImageTags.cs index 87cff6f0edd..972f8f48e5f 100644 --- a/src/Aspire.Hosting.PostgreSQL/PostgresContainerImageTags.cs +++ b/src/Aspire.Hosting.PostgreSQL/PostgresContainerImageTags.cs @@ -20,8 +20,8 @@ internal static class PostgresContainerImageTags /// dpage/pgadmin4 public const string PgAdminImage = "dpage/pgadmin4"; - /// 8.14 - public const string PgAdminTag = "8.14"; + /// 9.1.0 + public const string PgAdminTag = "9.1.0"; /// docker.io public const string PgWebRegistry = "docker.io"; diff --git a/src/Aspire.Hosting.PostgreSQL/PostgresDatabaseResource.cs b/src/Aspire.Hosting.PostgreSQL/PostgresDatabaseResource.cs index d06c6a24924..5ae94ad5575 100644 --- a/src/Aspire.Hosting.PostgreSQL/PostgresDatabaseResource.cs +++ b/src/Aspire.Hosting.PostgreSQL/PostgresDatabaseResource.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Data.Common; using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; @@ -23,9 +24,18 @@ public class PostgresDatabaseResource(string name, string databaseName, Postgres /// /// Gets the connection string expression for the Postgres database. /// - public ReferenceExpression ConnectionStringExpression => - ReferenceExpression.Create($"{Parent};Database={DatabaseName}"); + public ReferenceExpression ConnectionStringExpression + { + get + { + var connectionStringBuilder = new DbConnectionStringBuilder + { + ["Database"] = DatabaseName + }; + return ReferenceExpression.Create($"{Parent};{connectionStringBuilder.ToString()}"); + } + } /// /// Gets the database name. /// diff --git a/tests/Aspire.Hosting.PostgreSQL.Tests/PostgresFunctionalTests.cs b/tests/Aspire.Hosting.PostgreSQL.Tests/PostgresFunctionalTests.cs index 8c94bea84db..57e2a949800 100644 --- a/tests/Aspire.Hosting.PostgreSQL.Tests/PostgresFunctionalTests.cs +++ b/tests/Aspire.Hosting.PostgreSQL.Tests/PostgresFunctionalTests.cs @@ -107,7 +107,7 @@ public async Task VerifyPostgresResource() var postgresDbName = "db1"; - var postgres = builder.AddPostgres("pg").WithEnvironment("POSTGRES_DB", postgresDbName); + var postgres = builder.AddPostgres("pg"); var db = postgres.AddDatabase(postgresDbName); using var app = builder.Build(); @@ -127,14 +127,16 @@ public async Task VerifyPostgresResource() await host.StartAsync(); + await app.ResourceNotifications.WaitForResourceHealthyAsync(postgres.Resource.Name, cts.Token); + await pipeline.ExecuteAsync(async token => { using var connection = host.Services.GetRequiredService(); await connection.OpenAsync(token); - var command = connection.CreateCommand(); + using var command = connection.CreateCommand(); command.CommandText = $"SELECT 1"; - var results = await command.ExecuteReaderAsync(token); + using var results = await command.ExecuteReaderAsync(token); Assert.True(results.HasRows); }, cts.Token); @@ -206,7 +208,7 @@ public async Task WithDataShouldPersistStateBetweenUsages(bool useVolume) var usernameParameter = builder1.AddParameter("user", username); var passwordParameter = builder1.AddParameter("pwd", password, secret: true); - var postgres1 = builder1.AddPostgres("pg", usernameParameter, passwordParameter).WithEnvironment("POSTGRES_DB", postgresDbName); + var postgres1 = builder1.AddPostgres("pg", usernameParameter, passwordParameter); var db1 = postgres1.AddDatabase(postgresDbName); @@ -215,7 +217,7 @@ public async Task WithDataShouldPersistStateBetweenUsages(bool useVolume) // Use a deterministic volume name to prevent them from exhausting the machines if deletion fails volumeName = VolumeNameGenerator.Generate(postgres1, nameof(WithDataShouldPersistStateBetweenUsages)); - // if the volume already exists (because of a crashing previous run), delete it + // If the volume already exists (because of a crashing previous run), delete it DockerUtils.AttemptDeleteDockerVolume(volumeName, throwOnFailure: true); postgres1.WithDataVolume(volumeName); } @@ -252,14 +254,14 @@ await pipeline.ExecuteAsync(async token => using var connection = host.Services.GetRequiredService(); await connection.OpenAsync(token); - var command = connection.CreateCommand(); + using var command = connection.CreateCommand(); command.CommandText = """ CREATE TABLE cars (brand VARCHAR(255)); INSERT INTO cars (brand) VALUES ('BatMobile'); SELECT * FROM cars; """; - var results = await command.ExecuteReaderAsync(token); + using var results = await command.ExecuteReaderAsync(token); Assert.True(results.HasRows); }, cts.Token); @@ -314,9 +316,9 @@ await pipeline.ExecuteAsync(async token => using var connection = host.Services.GetRequiredService(); await connection.OpenAsync(token); - var command = connection.CreateCommand(); + using var command = connection.CreateCommand(); command.CommandText = $"SELECT * FROM cars;"; - var results = await command.ExecuteReaderAsync(token); + using var results = await command.ExecuteReaderAsync(token); Assert.True(await results.ReadAsync(token)); Assert.Equal("BatMobile", results.GetString("brand")); @@ -361,7 +363,7 @@ public async Task VerifyWithInitBindMount() var cts = new CancellationTokenSource(TimeSpan.FromMinutes(5)); var pipeline = new ResiliencePipelineBuilder() - .AddRetry(new() { MaxRetryAttempts = 10, Delay = TimeSpan.FromSeconds(2) }) + .AddRetry(new() { MaxRetryAttempts = 3, Delay = TimeSpan.FromSeconds(2) }) .Build(); var bindMountPath = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); @@ -371,15 +373,15 @@ public async Task VerifyWithInitBindMount() try { File.WriteAllText(Path.Combine(bindMountPath, "init.sql"), """ - CREATE TABLE cars (brand VARCHAR(255)); - INSERT INTO cars (brand) VALUES ('BatMobile'); - """); + CREATE TABLE "Cars" (brand VARCHAR(255)); + INSERT INTO "Cars" (brand) VALUES ('BatMobile'); + """); using var builder = TestDistributedApplicationBuilder.CreateWithTestContainerRegistry(testOutputHelper); var postgresDbName = "db1"; - var postgres = builder.AddPostgres("pg").WithEnvironment("POSTGRES_DB", postgresDbName); + var db = postgres.AddDatabase(postgresDbName); postgres.WithInitBindMount(bindMountPath); @@ -401,6 +403,8 @@ public async Task VerifyWithInitBindMount() await host.StartAsync(); + await app.ResourceNotifications.WaitForResourceHealthyAsync(db.Resource.Name, cts.Token); + // Wait until the database is available await pipeline.ExecuteAsync(async token => { @@ -414,9 +418,9 @@ await pipeline.ExecuteAsync(async token => using var connection = host.Services.GetRequiredService(); await connection.OpenAsync(token); - var command = connection.CreateCommand(); - command.CommandText = $"SELECT * FROM cars;"; - var results = await command.ExecuteReaderAsync(token); + using var command = connection.CreateCommand(); + command.CommandText = $"SELECT * FROM \"Cars\";"; + using var results = await command.ExecuteReaderAsync(token); Assert.True(await results.ReadAsync(token)); Assert.Equal("BatMobile", results.GetString("brand")); @@ -498,4 +502,220 @@ public async Task Postgres_WithPersistentLifetime_ReusesContainers() return resourceEvent.Snapshot.Properties.FirstOrDefault(x => x.Name == "container.id")?.Value?.ToString(); } } + + [Fact] + [RequiresDocker] + public async Task AddDatabaseCreatesDatabaseWithCustomScript() + { + const string databaseName = "newdb"; + + var cts = new CancellationTokenSource(TimeSpan.FromMinutes(5)); + + using var builder = TestDistributedApplicationBuilder.Create(o => { }, testOutputHelper); + + var postgres = builder.AddPostgres("pg1"); + + var newDb = postgres.AddDatabase(databaseName) + .WithCreationScript($$""" + CREATE DATABASE {{databaseName}} + ENCODING = 'UTF8'; + """); + + using var app = builder.Build(); + + await app.StartAsync(cts.Token); + + var hb = Host.CreateApplicationBuilder(); + + hb.Configuration[$"ConnectionStrings:{newDb.Resource.Name}"] = await newDb.Resource.ConnectionStringExpression.GetValueAsync(default); + + hb.AddNpgsqlDataSource(newDb.Resource.Name); + + using var host = hb.Build(); + + await host.StartAsync(); + + await app.ResourceNotifications.WaitForResourceHealthyAsync(newDb.Resource.Name, cts.Token); + + var conn = host.Services.GetRequiredService(); + + if (conn.State != ConnectionState.Open) + { + await conn.OpenAsync(cts.Token); + } + + Assert.Equal(ConnectionState.Open, conn.State); + } + + [Fact] + [RequiresDocker] + public async Task AddDatabaseCreatesDatabaseWithSpecialNames() + { + const string databaseName = "!']`'[\""; + const string resourceName = "db"; + + var cts = new CancellationTokenSource(TimeSpan.FromMinutes(5)); + + using var builder = TestDistributedApplicationBuilder.Create(o => { }, testOutputHelper); + + var postgres = builder.AddPostgres("pg1"); + + var newDb = postgres.AddDatabase(resourceName, databaseName); + + using var app = builder.Build(); + + await app.StartAsync(cts.Token); + + var hb = Host.CreateApplicationBuilder(); + + hb.Configuration[$"ConnectionStrings:{newDb.Resource.Name}"] = await newDb.Resource.ConnectionStringExpression.GetValueAsync(default); + + hb.AddNpgsqlDataSource(newDb.Resource.Name); + + using var host = hb.Build(); + + await host.StartAsync(); + + await app.ResourceNotifications.WaitForResourceHealthyAsync(newDb.Resource.Name, cts.Token); + + var conn = host.Services.GetRequiredService(); + + if (conn.State != ConnectionState.Open) + { + await conn.OpenAsync(cts.Token); + } + + Assert.Equal(ConnectionState.Open, conn.State); + } + + [Fact] + [RequiresDocker] + public async Task AddDatabaseCreatesDatabaseResiliently() + { + // Creating the database multiple times should not fail + + const string databaseName = "db1"; + const string resourceName = "db"; + + string? volumeName = null; + + var cts = new CancellationTokenSource(TimeSpan.FromMinutes(5)); + var pipeline = new ResiliencePipelineBuilder() + .AddRetry(new() { MaxRetryAttempts = 3, BackoffType = DelayBackoffType.Linear, Delay = TimeSpan.FromSeconds(2) }) + .Build(); + + var username = "postgres"; + var password = "p@ssw0rd1"; + + try + { + for (var i = 0; i < 2; i++) + { + using var builder = TestDistributedApplicationBuilder.Create(o => { }, testOutputHelper); + + var usernameParameter = builder.AddParameter("user", username); + var passwordParameter = builder.AddParameter("pwd", password, secret: true); + + var postgres = builder.AddPostgres("pg1", usernameParameter, passwordParameter); + + // Use a deterministic volume name to prevent them from exhausting the machines if deletion fails + volumeName = VolumeNameGenerator.Generate(postgres, nameof(AddDatabaseCreatesDatabaseResiliently)); + + if (i == 0) + { + // If the volume already exists (because of a crashing previous run), delete it + DockerUtils.AttemptDeleteDockerVolume(volumeName); + } + + postgres.WithDataVolume(volumeName); + + var newDb = postgres.AddDatabase(resourceName, databaseName); + + using var app = builder.Build(); + + await app.StartAsync(cts.Token); + + var hb = Host.CreateApplicationBuilder(); + + hb.Configuration[$"ConnectionStrings:{newDb.Resource.Name}"] = await newDb.Resource.ConnectionStringExpression.GetValueAsync(default); + + hb.AddNpgsqlDataSource(newDb.Resource.Name); + + using var host = hb.Build(); + + await host.StartAsync(); + + await app.ResourceNotifications.WaitForResourceHealthyAsync(postgres.Resource.Name, cts.Token); + + // Test connection + await pipeline.ExecuteAsync(async token => + { + var conn = host.Services.GetRequiredService(); + + if (conn.State != ConnectionState.Open) + { + await conn.OpenAsync(token); + } + + Assert.Equal(ConnectionState.Open, conn.State); + }, cts.Token); + + await app.StopAsync(cts.Token); + } + } + finally + { + if (volumeName is not null) + { + DockerUtils.AttemptDeleteDockerVolume(volumeName); + } + } + } + + [Fact] + [RequiresDocker] + public async Task AddDatabaseCreatesMultipleDatabases() + { + var cts = new CancellationTokenSource(TimeSpan.FromMinutes(5)); + + using var builder = TestDistributedApplicationBuilder.Create(o => { }, testOutputHelper); + + var postgres = builder.AddPostgres("pg1"); + + var db1 = postgres.AddDatabase("db1"); + var db2 = postgres.AddDatabase("db2"); + var db3 = postgres.AddDatabase("db3"); + + var dbs = new[] { db1, db2, db3 }; + + using var app = builder.Build(); + + await app.StartAsync(cts.Token); + + var hb = Host.CreateApplicationBuilder(); + + foreach (var db in dbs) + { + hb.Configuration[$"ConnectionStrings:{db.Resource.Name}"] = await db.Resource.ConnectionStringExpression.GetValueAsync(default); + hb.AddKeyedNpgsqlDataSource(db.Resource.Name); + } + + using var host = hb.Build(); + + await host.StartAsync(); + + foreach (var db in dbs) + { + await app.ResourceNotifications.WaitForResourceHealthyAsync(db.Resource.Name, cts.Token); + + var conn = host.Services.GetRequiredKeyedService(db.Resource.Name); + + if (conn.State != ConnectionState.Open) + { + await conn.OpenAsync(cts.Token); + } + + Assert.Equal(ConnectionState.Open, conn.State); + } + } }