Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,6 @@
app.MapDefaultEndpoints();
app.MapGet("/", async (MyDb1Context db1Context, MyDb2Context db2Context) =>
{
// You wouldn't normally do this on every call,
// but doing it here just to make this simple.

await db1Context.Database.EnsureCreatedAsync();
await db2Context.Database.EnsureCreatedAsync();

var entry1 = new Entry();
await db1Context.Entries.AddAsync(entry1);
await db1Context.SaveChangesAsync();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@
var created = await db.Database.EnsureCreatedAsync();
if (created)
{
Console.WriteLine("Database created!");
Console.WriteLine("Database schema created!");
}
}
4 changes: 4 additions & 0 deletions src/Aspire.Hosting.SqlServer/Aspire.Hosting.SqlServer.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,8 @@
<ProjectReference Include="..\Aspire.Hosting\Aspire.Hosting.csproj" />
</ItemGroup>

<ItemGroup>
<InternalsVisibleTo Include="Aspire.Hosting.SqlServer.Tests" />
</ItemGroup>

</Project>
139 changes: 136 additions & 3 deletions src/Aspire.Hosting.SqlServer/SqlServerBuilderExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Globalization;
using System.Text.RegularExpressions;
using System.Text;
using Aspire.Hosting;
using Aspire.Hosting.ApplicationModel;
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

namespace Aspire.Hosting;

/// <summary>
/// Provides extension methods for adding SQL Server resources to the application model.
/// </summary>
public static class SqlServerBuilderExtensions
public static partial class SqlServerBuilderExtensions
{
// GO delimiter format: {spaces?}GO{spaces?}{repeat?}{comment?}
// https://learn.microsoft.com/sql/t-sql/language-elements/sql-server-utilities-statements-go
[GeneratedRegex(@"^\s*GO(?<repeat>\s+\d{1,6})?(\s*\-{2,}.*)?\s*$", RegexOptions.CultureInvariant | RegexOptions.IgnoreCase)]
internal static partial Regex GoStatements();

/// <summary>
/// Adds a SQL Server resource to the application model. A container is used for local development.
/// </summary>
Expand Down Expand Up @@ -45,6 +55,27 @@ public static IResourceBuilder<SqlServerServerResource> AddSqlServer(this IDistr
}
});

builder.Eventing.Subscribe<ResourceReadyEvent>(sqlServer, async (@event, ct) =>
{
if (connectionString is null)
{
throw new DistributedApplicationException($"ResourceReadyEvent was published for the '{sqlServer.Name}' resource but the connection string was null.");
}

using var sqlConnection = new SqlConnection(connectionString);
await sqlConnection.OpenAsync(ct).ConfigureAwait(false);

if (sqlConnection.State != System.Data.ConnectionState.Open)
{
throw new InvalidOperationException($"Could not open connection to '{sqlServer.Name}'");
}

foreach (var sqlDatabase in sqlServer.DatabaseResources)
{
await CreateDatabaseAsync(sqlConnection, sqlDatabase, @event.Services, ct).ConfigureAwait(false);
}
});

var healthCheckKey = $"{name}_check";
builder.Services.AddHealthChecks().AddSqlServer(sp => connectionString ?? throw new InvalidOperationException("Connection string is unavailable"), name: healthCheckKey);

Expand Down Expand Up @@ -75,9 +106,28 @@ public static IResourceBuilder<SqlServerDatabaseResource> AddDatabase(this IReso
// Use the resource name as the database name if it's not provided
databaseName ??= name;

builder.Resource.AddDatabase(name, databaseName);
var sqlServerDatabase = new SqlServerDatabaseResource(name, databaseName, builder.Resource);
return builder.ApplicationBuilder.AddResource(sqlServerDatabase);

builder.Resource.AddDatabase(sqlServerDatabase);

string? connectionString = null;

builder.ApplicationBuilder.Eventing.Subscribe<ConnectionStringAvailableEvent>(sqlServerDatabase, async (@event, ct) =>
{
connectionString = await sqlServerDatabase.ConnectionStringExpression.GetValueAsync(ct).ConfigureAwait(false);

if (connectionString == null)
{
throw new DistributedApplicationException($"ConnectionStringAvailableEvent was published for the '{name}' resource but the connection string was null.");
}
});

var healthCheckKey = $"{name}_check";
builder.ApplicationBuilder.Services.AddHealthChecks().AddSqlServer(sp => connectionString ?? throw new InvalidOperationException("Connection string is unavailable"), name: healthCheckKey);

return builder.ApplicationBuilder
.AddResource(sqlServerDatabase)
.WithHealthCheck(healthCheckKey);
}

/// <summary>
Expand Down Expand Up @@ -112,4 +162,87 @@ public static IResourceBuilder<SqlServerServerResource> WithDataBindMount(this I

return builder.WithBindMount(source, "/var/opt/mssql", isReadOnly);
}

/// <summary>
/// Defines the SQL script used to create the database.
/// </summary>
/// <param name="builder">The builder for the <see cref="SqlServerDatabaseResource"/>.</param>
/// <param name="script">The SQL script used to create the database.</param>
/// <returns>A reference to the <see cref="IResourceBuilder{T}"/>.</returns>
/// <remarks>
/// <value>Default script is <code>IF ( NOT EXISTS ( SELECT 1 FROM sys.databases WHERE name = @DatabaseName ) ) CREATE DATABASE [&lt;QUOTED_DATABASE_NAME%gt;];</code></value>
/// </remarks>
public static IResourceBuilder<SqlServerDatabaseResource> WithCreationScript(this IResourceBuilder<SqlServerDatabaseResource> builder, string script)
{
ArgumentNullException.ThrowIfNull(builder);
ArgumentNullException.ThrowIfNull(script);

builder.WithAnnotation(new CreationScriptAnnotation(script));

return builder;
}

private static async Task CreateDatabaseAsync(SqlConnection sqlConnection, SqlServerDatabaseResource sqlDatabase, IServiceProvider serviceProvider, CancellationToken ct)
{
try
{
var scriptAnnotation = sqlDatabase.Annotations.OfType<CreationScriptAnnotation>().LastOrDefault();

if (scriptAnnotation?.Script == null)
{
var quotedDatabaseIdentifier = new SqlCommandBuilder().QuoteIdentifier(sqlDatabase.DatabaseName);
using var command = sqlConnection.CreateCommand();
command.CommandText = $"IF ( NOT EXISTS ( SELECT 1 FROM sys.databases WHERE name = @DatabaseName ) ) CREATE DATABASE {quotedDatabaseIdentifier};";
command.Parameters.Add(new SqlParameter("@DatabaseName", sqlDatabase.DatabaseName));
await command.ExecuteNonQueryAsync(ct).ConfigureAwait(false);
}
else
{
using var reader = new StringReader(scriptAnnotation.Script);
var batchBuilder = new StringBuilder();

while (reader.ReadLine() is { } line)
{
var matchGo = GoStatements().Match(line);

if (matchGo.Success)
{
// Execute the current batch
var count = matchGo.Groups["repeat"].Success ? int.Parse(matchGo.Groups["repeat"].Value, CultureInfo.InvariantCulture) : 1;
var batch = batchBuilder.ToString();

for (var i = 0; i < count; i++)
{
using var command = sqlConnection.CreateCommand();
command.CommandText = batch;
await command.ExecuteNonQueryAsync(ct).ConfigureAwait(false);
}

batchBuilder.Clear();
}
else
{
// Prevent batches with only whitespace
if (!string.IsNullOrWhiteSpace(line))
{
batchBuilder.AppendLine(line);
}
}
}

// Process the remaining batch lines
if (batchBuilder.Length > 0)
{
using var command = sqlConnection.CreateCommand();
command.CommandText = batchBuilder.ToString();
await command.ExecuteNonQueryAsync(ct).ConfigureAwait(false);
}
}
}
catch (Exception e)
{
var logger = serviceProvider.GetRequiredService<ILogger<DistributedApplicationBuilder>>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we going to update this to write the log to the resource instead?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just got taught how to, yes I will change that right away.

logger.LogError(e, "Failed to create database '{DatabaseName}'", sqlDatabase.DatabaseName);
}
}
}
15 changes: 13 additions & 2 deletions src/Aspire.Hosting.SqlServer/SqlServerDatabaseResource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using Microsoft.Data.SqlClient;

namespace Aspire.Hosting.ApplicationModel;

Expand All @@ -23,8 +24,18 @@ public class SqlServerDatabaseResource(string name, string databaseName, SqlServ
/// <summary>
/// Gets the connection string expression for the SQL Server database.
/// </summary>
public ReferenceExpression ConnectionStringExpression =>
ReferenceExpression.Create($"{Parent};Database={DatabaseName}");
public ReferenceExpression ConnectionStringExpression
{
get
{
var connectionStringBuilder = new SqlConnectionStringBuilder
{
["Database"] = DatabaseName
};

return ReferenceExpression.Create($"{Parent};{connectionStringBuilder.ToString()}");
}
}

/// <summary>
/// Gets the database name.
Expand Down
8 changes: 6 additions & 2 deletions src/Aspire.Hosting.SqlServer/SqlServerServerResource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,18 @@ public ReferenceExpression ConnectionStringExpression
}

private readonly Dictionary<string, string> _databases = new(StringComparers.ResourceName);
private readonly List<SqlServerDatabaseResource> _databaseResources = [];

/// <summary>
/// A dictionary where the key is the resource name and the value is the database name.
/// </summary>
public IReadOnlyDictionary<string, string> Databases => _databases;

internal void AddDatabase(string name, string databaseName)
internal void AddDatabase(SqlServerDatabaseResource database)
{
_databases.TryAdd(name, databaseName);
_databases.TryAdd(database.Name, database.DatabaseName);
_databaseResources.Add(database);
}

internal IReadOnlyList<SqlServerDatabaseResource> DatabaseResources => _databaseResources;
}
25 changes: 25 additions & 0 deletions src/Aspire.Hosting/ApplicationModel/CreationScriptAnnotation.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Aspire.Hosting.ApplicationModel;

/// <summary>
/// Represents an annotation for defining a script to create a resource.
/// </summary>
public sealed class CreationScriptAnnotation : IResourceAnnotation
{
/// <summary>
/// Initializes a new instance of the <see cref="CreationScriptAnnotation"/> class.
/// </summary>
/// <param name="script">The script used to create the resource.</param>
public CreationScriptAnnotation(string script)
{
ArgumentNullException.ThrowIfNull(script);
Script = script;
}

/// <summary>
/// Gets the script used to create the resource.
/// </summary>
public string Script { get; }
}
14 changes: 7 additions & 7 deletions tests/Aspire.Hosting.SqlServer.Tests/AddSqlServerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ public async Task SqlServerDatabaseCreatesConnectionString()
var connectionStringResource = (IResourceWithConnectionString)sqlResource;
var connectionString = await connectionStringResource.GetConnectionStringAsync();

Assert.Equal("Server=127.0.0.1,1433;User ID=sa;Password=p@ssw0rd1;TrustServerCertificate=true;Database=mydb", connectionString);
Assert.Equal("{sqlserver.connectionString};Database=mydb", connectionStringResource.ConnectionStringExpression.ValueExpression);
Assert.Equal("Server=127.0.0.1,1433;User ID=sa;Password=p@ssw0rd1;TrustServerCertificate=true;Initial Catalog=mydb", connectionString);
Assert.Equal("{sqlserver.connectionString};Initial Catalog=mydb", connectionStringResource.ConnectionStringExpression.ValueExpression);
}

[Fact]
Expand Down Expand Up @@ -154,7 +154,7 @@ public async Task VerifyManifest()
expectedManifest = """
{
"type": "value.v0",
"connectionString": "{sqlserver.connectionString};Database=db"
"connectionString": "{sqlserver.connectionString};Initial Catalog=db"
}
""";
Assert.Equal(expectedManifest, dbManifest.ToString());
Expand Down Expand Up @@ -228,8 +228,8 @@ public void CanAddDatabasesWithDifferentNamesOnSingleServer()
Assert.Equal("customers1", db1.Resource.DatabaseName);
Assert.Equal("customers2", db2.Resource.DatabaseName);

Assert.Equal("{sqlserver1.connectionString};Database=customers1", db1.Resource.ConnectionStringExpression.ValueExpression);
Assert.Equal("{sqlserver1.connectionString};Database=customers2", db2.Resource.ConnectionStringExpression.ValueExpression);
Assert.Equal("{sqlserver1.connectionString};Initial Catalog=customers1", db1.Resource.ConnectionStringExpression.ValueExpression);
Assert.Equal("{sqlserver1.connectionString};Initial Catalog=customers2", db2.Resource.ConnectionStringExpression.ValueExpression);
}

[Fact]
Expand All @@ -246,7 +246,7 @@ public void CanAddDatabasesWithTheSameNameOnMultipleServers()
Assert.Equal("imports", db1.Resource.DatabaseName);
Assert.Equal("imports", db2.Resource.DatabaseName);

Assert.Equal("{sqlserver1.connectionString};Database=imports", db1.Resource.ConnectionStringExpression.ValueExpression);
Assert.Equal("{sqlserver2.connectionString};Database=imports", db2.Resource.ConnectionStringExpression.ValueExpression);
Assert.Equal("{sqlserver1.connectionString};Initial Catalog=imports", db1.Resource.ConnectionStringExpression.ValueExpression);
Assert.Equal("{sqlserver2.connectionString};Initial Catalog=imports", db2.Resource.ConnectionStringExpression.ValueExpression);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,4 @@
<ProjectReference Include="..\Aspire.Hosting.Tests\Aspire.Hosting.Tests.csproj" />
</ItemGroup>

<ItemGroup>
<Compile Include="$(RepoRoot)src\Aspire.Hosting.SqlServer\SqlServerContainerImageTags.cs" />
</ItemGroup>

</Project>
Loading