From cd76fd3c34d18ace5935d6db049c15ef805f586b Mon Sep 17 00:00:00 2001 From: Bradley Grainger Date: Thu, 16 Nov 2023 12:45:56 -0800 Subject: [PATCH] Support keyed services in dependency injection. Fixes #1391 Add tests for MySqlConnector.DependencyInjection project. --- .ci/build-steps.yml | 12 ++ .ci/mysqlconnector-tests-steps.yml | 10 + .ci/test.ps1 | 6 + Directory.Packages.props | 3 +- MySqlConnector.sln | 6 + ...SqlConnectorServiceCollectionExtensions.cs | 76 +++++++- .../docs/README.md | 30 +++ src/MySqlConnector/MySqlConnector.csproj | 1 + .../DependencyInjectionTests.cs | 174 ++++++++++++++++++ ...Connector.DependencyInjection.Tests.csproj | 33 ++++ 10 files changed, 348 insertions(+), 3 deletions(-) create mode 100644 tests/MySqlConnector.DependencyInjection.Tests/DependencyInjectionTests.cs create mode 100644 tests/MySqlConnector.DependencyInjection.Tests/MySqlConnector.DependencyInjection.Tests.csproj diff --git a/.ci/build-steps.yml b/.ci/build-steps.yml index c95f7acf3..e81044196 100644 --- a/.ci/build-steps.yml +++ b/.ci/build-steps.yml @@ -47,6 +47,18 @@ steps: artifactName: 'Conformance.Tests-8.0-$(Agent.OS)' targetPath: 'artifacts/publish/Conformance.Tests/release_net8.0' +- task: DotNetCoreCLI@2 + displayName: 'Publish MySqlConnector.DependencyInjection.Tests' + inputs: + command: 'publish' + arguments: '-c Release -f net8.0 --no-build tests/MySqlConnector.DependencyInjection.Tests/MySqlConnector.DependencyInjection.Tests.csproj' + publishWebProjects: false + zipAfterPublish: false +- task: PublishPipelineArtifact@0 + inputs: + artifactName: 'MySqlConnector.DependencyInjection.Tests-8.0-$(Agent.OS)' + targetPath: 'artifacts/publish/MySqlConnector.DependencyInjection.Tests/release_net8.0' + - task: DotNetCoreCLI@2 displayName: 'Publish IntegrationTests (7.0)' inputs: diff --git a/.ci/mysqlconnector-tests-steps.yml b/.ci/mysqlconnector-tests-steps.yml index ca4b1ce8f..012b11a0b 100644 --- a/.ci/mysqlconnector-tests-steps.yml +++ b/.ci/mysqlconnector-tests-steps.yml @@ -14,6 +14,16 @@ steps: command: 'custom' custom: 'vstest' arguments: 'MySqlConnector.Tests.dll /logger:trx' +- task: DownloadPipelineArtifact@0 + inputs: + artifactName: 'MySqlConnector.DependencyInjection.Tests-8.0-$(Agent.OS)' + targetPath: $(System.DefaultWorkingDirectory) +- task: DotNetCoreCLI@2 + displayName: 'Run MySqlConnector.DependencyInjection.Tests' + inputs: + command: 'custom' + custom: 'vstest' + arguments: 'MySqlConnector.DependencyInjection.Tests.dll /logger:trx' - task: PublishTestResults@2 inputs: testResultsFormat: VSTest diff --git a/.ci/test.ps1 b/.ci/test.ps1 index 7ebe26842..a619c6402 100644 --- a/.ci/test.ps1 +++ b/.ci/test.ps1 @@ -23,6 +23,12 @@ if ($LASTEXITCODE -ne 0){ exit $LASTEXITCODE; } popd +pushd tests\MySqlConnector.DependencyInjection.Tests +dotnet test -c Release +if ($LASTEXITCODE -ne 0){ + exit $LASTEXITCODE; +} +popd pushd .\tests\IntegrationTests diff --git a/Directory.Packages.props b/Directory.Packages.props index b319ba3da..d8225df5a 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -12,7 +12,8 @@ - + + diff --git a/MySqlConnector.sln b/MySqlConnector.sln index 0d56e5fa4..c3646c415 100644 --- a/MySqlConnector.sln +++ b/MySqlConnector.sln @@ -26,6 +26,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "SchemaCollectionGenerator", EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MySqlConnector.DependencyInjection", "src\MySqlConnector.DependencyInjection\MySqlConnector.DependencyInjection.csproj", "{D48B3619-7FE1-420C-A96C-B231B7EA73EA}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MySqlConnector.DependencyInjection.Tests", "tests\MySqlConnector.DependencyInjection.Tests\MySqlConnector.DependencyInjection.Tests.csproj", "{E41AD8B7-2F67-444F-A8DC-51C3C8B1FD16}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -76,6 +78,10 @@ Global {D48B3619-7FE1-420C-A96C-B231B7EA73EA}.Debug|Any CPU.Build.0 = Debug|Any CPU {D48B3619-7FE1-420C-A96C-B231B7EA73EA}.Release|Any CPU.ActiveCfg = Release|Any CPU {D48B3619-7FE1-420C-A96C-B231B7EA73EA}.Release|Any CPU.Build.0 = Release|Any CPU + {E41AD8B7-2F67-444F-A8DC-51C3C8B1FD16}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E41AD8B7-2F67-444F-A8DC-51C3C8B1FD16}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E41AD8B7-2F67-444F-A8DC-51C3C8B1FD16}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E41AD8B7-2F67-444F-A8DC-51C3C8B1FD16}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/MySqlConnector.DependencyInjection/MySqlConnectorServiceCollectionExtensions.cs b/src/MySqlConnector.DependencyInjection/MySqlConnectorServiceCollectionExtensions.cs index 782ac6183..f00ff40db 100644 --- a/src/MySqlConnector.DependencyInjection/MySqlConnectorServiceCollectionExtensions.cs +++ b/src/MySqlConnector.DependencyInjection/MySqlConnectorServiceCollectionExtensions.cs @@ -42,6 +42,45 @@ public static IServiceCollection AddMySqlDataSource( ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton) => DoAddMySqlDataSource(serviceCollection, connectionString, dataSourceBuilderAction, connectionLifetime, dataSourceLifetime); + /// + /// Registers a and a in the . + /// + /// The to add services to. + /// The of the service. + /// A MySQL connection string. + /// The lifetime with which to register the in the container. Defaults to . + /// The lifetime with which to register the service in the container. Defaults to . + /// The same service collection so that multiple calls can be chained. + /// If the is a , it will automatically be used to initialize the data source name. + public static IServiceCollection AddKeyedMySqlDataSource( + this IServiceCollection serviceCollection, + object? serviceKey, + string connectionString, + ServiceLifetime connectionLifetime = ServiceLifetime.Transient, + ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton) => + DoAddMySqlDataSource(serviceCollection, serviceKey, connectionString, dataSourceBuilderAction: null, connectionLifetime, dataSourceLifetime); + + /// + /// Registers a and a in the . + /// + /// The to add services to. + /// The of the service. + /// A MySQL connection string. + /// An action to configure the for further customizations of the . + /// The lifetime with which to register the in the container. Defaults to . + /// The lifetime with which to register the service in the container. Defaults to . + /// The same service collection so that multiple calls can be chained. + /// If the is a , it will automatically be used to initialize the data source name; this can + /// be overridden by the configuration action. + public static IServiceCollection AddKeyedMySqlDataSource( + this IServiceCollection serviceCollection, + object? serviceKey, + string connectionString, + Action dataSourceBuilderAction, + ServiceLifetime connectionLifetime = ServiceLifetime.Transient, + ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton) => + DoAddMySqlDataSource(serviceCollection, serviceKey, connectionString, dataSourceBuilderAction, connectionLifetime, dataSourceLifetime); + private static IServiceCollection DoAddMySqlDataSource( this IServiceCollection serviceCollection, string connectionString, @@ -52,10 +91,10 @@ private static IServiceCollection DoAddMySqlDataSource( serviceCollection.TryAdd( new ServiceDescriptor( typeof(MySqlDataSource), - x => + serviceProvider => { var dataSourceBuilder = new MySqlDataSourceBuilder(connectionString) - .UseLoggerFactory(x.GetService()); + .UseLoggerFactory(serviceProvider.GetService()); dataSourceBuilderAction?.Invoke(dataSourceBuilder); return dataSourceBuilder.Build(); }, @@ -71,4 +110,37 @@ private static IServiceCollection DoAddMySqlDataSource( return serviceCollection; } + + private static IServiceCollection DoAddMySqlDataSource( + this IServiceCollection serviceCollection, + object? serviceKey, + string connectionString, + Action? dataSourceBuilderAction, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime) + { + serviceCollection.TryAdd( + new ServiceDescriptor( + typeof(MySqlDataSource), + serviceKey, + (serviceProvider, serviceKey) => + { + var dataSourceBuilder = new MySqlDataSourceBuilder(connectionString) + .UseLoggerFactory(serviceProvider.GetService()) + .UseName(serviceKey as string); + dataSourceBuilderAction?.Invoke(dataSourceBuilder); + return dataSourceBuilder.Build(); + }, + dataSourceLifetime)); + + serviceCollection.TryAdd(new ServiceDescriptor(typeof(MySqlConnection), serviceKey, (sp, sk) => sp.GetRequiredKeyedService(sk).CreateConnection(), connectionLifetime)); + +#if NET7_0_OR_GREATER + serviceCollection.TryAdd(new ServiceDescriptor(typeof(DbDataSource), serviceKey, (sp, sk) => sp.GetRequiredKeyedService(sk), dataSourceLifetime)); +#endif + + serviceCollection.TryAdd(new ServiceDescriptor(typeof(DbConnection), serviceKey, (sp, sk) => sp.GetRequiredKeyedService(sk), connectionLifetime)); + + return serviceCollection; + } } diff --git a/src/MySqlConnector.DependencyInjection/docs/README.md b/src/MySqlConnector.DependencyInjection/docs/README.md index d28d69389..a19a3f96b 100644 --- a/src/MySqlConnector.DependencyInjection/docs/README.md +++ b/src/MySqlConnector.DependencyInjection/docs/README.md @@ -48,3 +48,33 @@ builder.Services.AddMySqlDataSource("Server=server;User ID=test;Password=test;Da x => x.UseRemoteCertificateValidationCallback((sender, certificate, chain, sslPolicyErrors) => { /* custom logic */ }) ); ``` + +## Keyed Services + +Use the `AddKeyedMySqlDataSource` method to register a `MySqlDataSource` as a [keyed service](https://learn.microsoft.com/en-us/dotnet/core/whats-new/dotnet-8#keyed-di-services). +This is useful if you have multiple connection strings or need to connect to multiple databases. +If the service key is a string, it will automatically be used as the `MySqlDataSource` name; +to customize this, call the `AddKeyedMySqlDataSource(object?, string, Action)` overload and call `MySqlDataSourceBuilder.UseName`. + +```csharp +builder.Services.AddKeyedMySqlDataSource("users", builder.Configuration.GetConnectionString("Users")); +builder.Services.AddKeyedMySqlDataSource("products", builder.Configuration.GetConnectionString("Products")); + +app.MapGet("/users/{userId}", async (int userId, [FromKeyedServices("users")] MySqlConnection connection) => +{ + await connection.OpenAsync(); + await using var command = connection.CreateCommand(); + command.CommandText = "SELECT name FROM users WHERE user_id = @userId LIMIT 1"; + command.Parameters.AddWithValue("@userId", userId); + return $"Hello, {await command.ExecuteScalarAsync()}"; +}); + +app.MapGet("/products/{productId}", async (int productId, [FromKeyedServices("products")] MySqlConnection connection) => +{ + await connection.OpenAsync(); + await using var command = connection.CreateCommand(); + command.CommandText = "SELECT name FROM products WHERE product_id = @productId LIMIT 1"; + command.Parameters.AddWithValue("@productId", productId); + return await command.ExecuteScalarAsync(); +}); +``` diff --git a/src/MySqlConnector/MySqlConnector.csproj b/src/MySqlConnector/MySqlConnector.csproj index 5ce7293e0..536f437c3 100644 --- a/src/MySqlConnector/MySqlConnector.csproj +++ b/src/MySqlConnector/MySqlConnector.csproj @@ -34,6 +34,7 @@ + diff --git a/tests/MySqlConnector.DependencyInjection.Tests/DependencyInjectionTests.cs b/tests/MySqlConnector.DependencyInjection.Tests/DependencyInjectionTests.cs new file mode 100644 index 000000000..6d9bf4775 --- /dev/null +++ b/tests/MySqlConnector.DependencyInjection.Tests/DependencyInjectionTests.cs @@ -0,0 +1,174 @@ +namespace MySqlConnector.DependencyInjection.Tests; + +public class DependencyInjectionTests +{ + [Fact] + public async Task MySqlDataSourceIsRegistered() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddMySqlDataSource(c_connectionString); + + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + + var dataSource = serviceProvider.GetRequiredService(); + await using var connection = dataSource.CreateConnection(); + Assert.Equal(c_connectionString, connection.ConnectionString); + } + + [Fact] + public async Task MySqlConnectionIsRegistered() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddMySqlDataSource(c_connectionString); + + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + + await using var connection = serviceProvider.GetRequiredService(); + Assert.Equal(c_connectionString, connection.ConnectionString); + } + + [Fact] + public async Task DbConnectionIsRegistered() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddMySqlDataSource(c_connectionString); + + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + + await using var connection = serviceProvider.GetRequiredService(); + Assert.IsAssignableFrom(connection); + Assert.Equal(c_connectionString, connection.ConnectionString); + } + + [Fact] + public async Task DbDataSourceIsRegistered() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddMySqlDataSource(c_connectionString); + + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + + await using var dataSource = serviceProvider.GetRequiredService(); + Assert.IsAssignableFrom(dataSource); + await using var connection = dataSource.CreateConnection(); + Assert.IsAssignableFrom(connection); + Assert.Equal(c_connectionString, connection.ConnectionString); + } + + [Fact] + public async Task MySqlDataSourceCanSetName() + { + var serviceCollection = new ServiceCollection(); + + serviceCollection.AddMySqlDataSource(c_connectionString, builder => + { + builder.UseName("MyName"); + }); + + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + var dataSource = serviceProvider.GetRequiredService(); + Assert.Equal("MyName", dataSource.Name); + } + + [Fact] + public async Task KeyedMySqlDataSourceIsRegistered() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddKeyedMySqlDataSource(KeyedService.AnyKey, c_connectionString); + + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + + var dataSource = serviceProvider.GetRequiredKeyedService(new object()); + Assert.Null(dataSource.Name); + await using var connection = dataSource.CreateConnection(); + Assert.Equal(c_connectionString, connection.ConnectionString); + } + + [Fact] + public async Task StringKeyedMySqlDataSourceHasNameSet() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddKeyedMySqlDataSource("key", c_connectionString); + + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + + var dataSource = serviceProvider.GetRequiredKeyedService("key"); + Assert.Equal("key", dataSource.Name); + await using var connection = dataSource.CreateConnection(); + Assert.Equal(c_connectionString, connection.ConnectionString); + } + + [Fact] + public async Task KeyedMySqlDataSourceRetrievedWithStringKeyHasName() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddKeyedMySqlDataSource(KeyedService.AnyKey, c_connectionString); + + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + + var dataSource = serviceProvider.GetRequiredKeyedService("key"); + Assert.Equal("key", dataSource.Name); + await using var connection = dataSource.CreateConnection(); + Assert.Equal(c_connectionString, connection.ConnectionString); + } + + [Fact] + public async Task KeyedMySqlConnectionIsRegistered() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddKeyedMySqlDataSource("key", c_connectionString); + + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + + await using var connection = serviceProvider.GetRequiredKeyedService("key"); + Assert.Equal(c_connectionString, connection.ConnectionString); + } + + [Fact] + public async Task TwoKeyedMySqlDataConnectionsAreRegistered() + { + const string c_connectionString2 = c_connectionString + ";Database=test"; + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddKeyedMySqlDataSource(KeyedService.AnyKey, c_connectionString); + serviceCollection.AddKeyedMySqlDataSource("key2", c_connectionString2); + + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + + await using var connection1 = serviceProvider.GetRequiredKeyedService("key"); + Assert.Equal(c_connectionString, connection1.ConnectionString); + + await using var connection2 = serviceProvider.GetRequiredKeyedService("key2"); + Assert.Equal(c_connectionString2, connection2.ConnectionString); + } + + [Fact] + public async Task KeyedDbConnectionIsRegistered() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddKeyedMySqlDataSource("key", c_connectionString); + + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + + await using var connection = serviceProvider.GetRequiredKeyedService("key"); + Assert.IsAssignableFrom(connection); + Assert.Equal(c_connectionString, connection.ConnectionString); + } + + [Fact] + public async Task KeyedDbDataSourceIsRegistered() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddKeyedMySqlDataSource("key", c_connectionString); + + await using var serviceProvider = serviceCollection.BuildServiceProvider(); + + await using var dataSource = serviceProvider.GetRequiredKeyedService("key"); + Assert.IsAssignableFrom(dataSource); + await using var connection = dataSource.CreateConnection(); + Assert.IsAssignableFrom(connection); + Assert.Equal(c_connectionString, connection.ConnectionString); + } + + const string c_connectionString = "Server=localhost;User ID=root;Password=pass"; +} diff --git a/tests/MySqlConnector.DependencyInjection.Tests/MySqlConnector.DependencyInjection.Tests.csproj b/tests/MySqlConnector.DependencyInjection.Tests/MySqlConnector.DependencyInjection.Tests.csproj new file mode 100644 index 000000000..fac726a3a --- /dev/null +++ b/tests/MySqlConnector.DependencyInjection.Tests/MySqlConnector.DependencyInjection.Tests.csproj @@ -0,0 +1,33 @@ + + + + net8.0 + true + true + ..\..\MySqlConnector.snk + true + enable + enable + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + + + +