diff --git a/src/EFCore/DbContextOptions.cs b/src/EFCore/DbContextOptions.cs index 90b347ef80d..e18fd17ee31 100644 --- a/src/EFCore/DbContextOptions.cs +++ b/src/EFCore/DbContextOptions.cs @@ -93,6 +93,15 @@ public virtual TExtension GetExtension() public abstract DbContextOptions WithExtension(TExtension extension) where TExtension : class, IDbContextOptionsExtension; + /// + /// Removes the given extension from the underlying options and creates a new + /// with the extension removed. + /// + /// The type of extension to be removed. + /// The new options instance with the extension removed. + public abstract DbContextOptions WithoutExtension() + where TExtension : class, IDbContextOptionsExtension; + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in diff --git a/src/EFCore/DbContextOptionsBuilder.cs b/src/EFCore/DbContextOptionsBuilder.cs index 548d8868b8b..ff48245461c 100644 --- a/src/EFCore/DbContextOptionsBuilder.cs +++ b/src/EFCore/DbContextOptionsBuilder.cs @@ -786,6 +786,17 @@ public virtual DbContextOptionsBuilder UseAsyncSeeding(Func(TExtension extension) => _options = _options.WithExtension(extension); + /// + /// Removes the extension of the given type from the options. If no extension of the given type exists, this is a no-op. + /// + /// + /// This method is intended for use by extension methods to configure the context. It is not intended to be used in + /// application code. + /// + /// The type of extension to be removed. + void IDbContextOptionsBuilderInfrastructure.RemoveExtension() + => _options = _options.WithoutExtension(); + private DbContextOptionsBuilder WithOption(Func withFunc) { ((IDbContextOptionsBuilderInfrastructure)this).AddOrUpdateExtension( diff --git a/src/EFCore/DbContextOptions`.cs b/src/EFCore/DbContextOptions`.cs index 067906b256f..4544c03f13b 100644 --- a/src/EFCore/DbContextOptions`.cs +++ b/src/EFCore/DbContextOptions`.cs @@ -57,6 +57,30 @@ public override DbContextOptions WithExtension(TExtension extension) return new DbContextOptions(ExtensionsMap.SetItem(type, (extension, ordinal))); } + /// + public override DbContextOptions WithoutExtension() + { + var type = typeof(TExtension); + if (!ExtensionsMap.TryGetValue(type, out var removedValue)) + { + return this; + } + + var removedOrdinal = removedValue.Ordinal; + var newMap = ExtensionsMap.Remove(type); + + // Renormalize ordinals for extensions that followed the removed one + foreach (var (key, value) in newMap) + { + if (value.Ordinal > removedOrdinal) + { + newMap = newMap.SetItem(key, (value.Extension, value.Ordinal - 1)); + } + } + + return new DbContextOptions(newMap); + } + /// /// The type of context that these options are for (). /// diff --git a/src/EFCore/Extensions/EntityFrameworkServiceCollectionExtensions.cs b/src/EFCore/Extensions/EntityFrameworkServiceCollectionExtensions.cs index 163a5ece1ff..41d4148ca45 100644 --- a/src/EFCore/Extensions/EntityFrameworkServiceCollectionExtensions.cs +++ b/src/EFCore/Extensions/EntityFrameworkServiceCollectionExtensions.cs @@ -1163,6 +1163,66 @@ public static IServiceCollection ConfigureDbContext return serviceCollection; } + /// + /// Removes services for the given context type from the . + /// + /// + /// + /// This method can be used to remove the context registration in integration testing scenarios + /// where a different database provider is used for tests. + /// + /// + /// See Using DbContext with dependency injection for more information and examples. + /// + /// + /// The type of context to be removed. + /// The to remove services from. + /// + /// If , only the registrations will be removed; + /// the context itself will remain registered. If (the default), all services related to the context + /// will be removed. + /// + /// The same service collection so that multiple calls can be chained. + public static IServiceCollection RemoveDbContext + <[DynamicallyAccessedMembers(DbContext.DynamicallyAccessedMemberTypes)] TContext>( + this IServiceCollection serviceCollection, + bool removeConfigurationOnly = false) + where TContext : DbContext + { + Check.NotNull(serviceCollection); + + if (removeConfigurationOnly) + { + var configurations = serviceCollection + .Where(d => d.ServiceType == typeof(IDbContextOptionsConfiguration)) + .ToList(); + + foreach (var descriptor in configurations) + { + serviceCollection.Remove(descriptor); + } + } + else + { + var descriptorsToRemove = serviceCollection + .Where(d => d.ServiceType == typeof(TContext) + || d.ServiceType == typeof(DbContextOptions) + || d.ServiceType == typeof(IDbContextOptionsConfiguration) + || d.ServiceType == typeof(IDbContextFactorySource) + || d.ServiceType == typeof(IDbContextFactory) + || d.ServiceType == typeof(IDbContextPool) + || d.ServiceType == typeof(IScopedDbContextLease)) + .ToList(); + + foreach (var descriptor in descriptorsToRemove) + { + serviceCollection.Remove(descriptor); + } + } + + return serviceCollection; + } + private static void AddCoreServices( IServiceCollection serviceCollection, Action? optionsAction, diff --git a/src/EFCore/Infrastructure/IDbContextOptionsBuilderInfrastructure.cs b/src/EFCore/Infrastructure/IDbContextOptionsBuilderInfrastructure.cs index bcad121c394..e1352df63a4 100644 --- a/src/EFCore/Infrastructure/IDbContextOptionsBuilderInfrastructure.cs +++ b/src/EFCore/Infrastructure/IDbContextOptionsBuilderInfrastructure.cs @@ -36,4 +36,21 @@ public interface IDbContextOptionsBuilderInfrastructure /// The extension to be added. void AddOrUpdateExtension(TExtension extension) where TExtension : class, IDbContextOptionsExtension; + + /// + /// + /// Removes the extension of the given type from the options. If no extension of the given type exists, this is a no-op. + /// + /// + /// This method is intended for use by extension methods to configure the context. It is not intended to be used in + /// application code. + /// + /// + /// + /// See Implementation of database providers and extensions + /// for more information and examples. + /// + /// The type of extension to be removed. + void RemoveExtension() + where TExtension : class, IDbContextOptionsExtension; } diff --git a/test/EFCore.Tests/DbContextOptionsTest.cs b/test/EFCore.Tests/DbContextOptionsTest.cs index a560679e5af..bc5d0e59beb 100644 --- a/test/EFCore.Tests/DbContextOptionsTest.cs +++ b/test/EFCore.Tests/DbContextOptionsTest.cs @@ -96,6 +96,76 @@ public void Can_update_an_existing_extension() Assert.Same(extension2, optionsBuilder.Options.FindExtension()); } + [ConditionalFact] + public void Can_remove_an_existing_extension() + { + var optionsBuilder = new DbContextOptionsBuilder(); + + var extension1 = new FakeDbContextOptionsExtension1(); + var extension2 = new FakeDbContextOptionsExtension2(); + + ((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(extension1); + ((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(extension2); + + Assert.Equal(2, optionsBuilder.Options.Extensions.Count()); + + ((IDbContextOptionsBuilderInfrastructure)optionsBuilder).RemoveExtension(); + + Assert.Single(optionsBuilder.Options.Extensions); + Assert.Null(optionsBuilder.Options.FindExtension()); + Assert.Same(extension2, optionsBuilder.Options.FindExtension()); + } + + [ConditionalFact] + public void Removing_non_existent_extension_is_no_op() + { + var optionsBuilder = new DbContextOptionsBuilder(); + + var extension = new FakeDbContextOptionsExtension1(); + ((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(extension); + + Assert.Single(optionsBuilder.Options.Extensions); + + ((IDbContextOptionsBuilderInfrastructure)optionsBuilder).RemoveExtension(); + + Assert.Single(optionsBuilder.Options.Extensions); + Assert.Same(extension, optionsBuilder.Options.FindExtension()); + } + + [ConditionalFact] + public void Removing_extension_from_middle_renormalizes_ordinals_and_preserves_insertion_order() + { + var optionsBuilder = new DbContextOptionsBuilder(); + + var extension1 = new FakeDbContextOptionsExtension1(); + var extension2 = new FakeDbContextOptionsExtension2(); + var extension3 = new FakeDbContextOptionsExtension3(); + + ((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(extension1); + ((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(extension2); + ((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(extension3); + + Assert.Equal(3, optionsBuilder.Options.Extensions.Count()); + + // Remove the middle extension + ((IDbContextOptionsBuilderInfrastructure)optionsBuilder).RemoveExtension(); + + Assert.Equal(2, optionsBuilder.Options.Extensions.Count()); + var extensionsList = optionsBuilder.Options.Extensions.ToList(); + Assert.Same(extension1, extensionsList[0]); + Assert.Same(extension3, extensionsList[1]); + + // Add a new extension after removing the middle one - ordinals should stay contiguous + var extension2New = new FakeDbContextOptionsExtension2(); + ((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(extension2New); + + Assert.Equal(3, optionsBuilder.Options.Extensions.Count()); + extensionsList = optionsBuilder.Options.Extensions.ToList(); + Assert.Same(extension1, extensionsList[0]); + Assert.Same(extension3, extensionsList[1]); + Assert.Same(extension2New, extensionsList[2]); + } + [ConditionalFact] public void IsConfigured_returns_true_if_any_provider_extensions_have_been_added() { @@ -199,6 +269,42 @@ public override void PopulateDebugInfo(IDictionary debugInfo) } } + private class FakeDbContextOptionsExtension3 : IDbContextOptionsExtension + { + private DbContextOptionsExtensionInfo _info; + + public DbContextOptionsExtensionInfo Info + => _info ??= new ExtensionInfo(this); + + public bool AppliedServices { get; private set; } + + public virtual void ApplyServices(IServiceCollection services) + => AppliedServices = true; + + public virtual void Validate(IDbContextOptions options) + { + } + + private sealed class ExtensionInfo(IDbContextOptionsExtension extension) : DbContextOptionsExtensionInfo(extension) + { + public override bool IsDatabaseProvider + => false; + + public override int GetServiceProviderHashCode() + => 0; + + public override bool ShouldUseSameServiceProvider(DbContextOptionsExtensionInfo other) + => true; + + public override string LogFragment + => ""; + + public override void PopulateDebugInfo(IDictionary debugInfo) + { + } + } + } + [ConditionalFact] public void UseModel_on_generic_builder_returns_generic_builder() { diff --git a/test/EFCore.Tests/DbContextTest.Services.cs b/test/EFCore.Tests/DbContextTest.Services.cs index 3097e5db118..bdd6537c390 100644 --- a/test/EFCore.Tests/DbContextTest.Services.cs +++ b/test/EFCore.Tests/DbContextTest.Services.cs @@ -3911,6 +3911,138 @@ protected DerivedContext1(DbContextOptions options) } private class DerivedContext2(DbContextOptions options) : DerivedContext1(options); + + [ConditionalFact] + public void RemoveDbContext_removes_all_context_services() + { + var serviceCollection = new ServiceCollection() + .AddDbContext(b => b.EnableServiceProviderCaching(false) + .UseInMemoryDatabase(Guid.NewGuid().ToString()) + .ConfigureWarnings(w => w.Default(WarningBehavior.Throw))); + + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(ConstructorTestContext1A)); + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(DbContextOptions)); + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(IDbContextOptionsConfiguration)); + + serviceCollection.RemoveDbContext(); + + Assert.DoesNotContain(serviceCollection, d => d.ServiceType == typeof(ConstructorTestContext1A)); + Assert.DoesNotContain(serviceCollection, d => d.ServiceType == typeof(DbContextOptions)); + Assert.DoesNotContain(serviceCollection, d => d.ServiceType == typeof(IDbContextOptionsConfiguration)); + } + + [ConditionalFact] + public void RemoveDbContext_does_not_remove_services_for_different_context() + { + var serviceCollection = new ServiceCollection() + .AddDbContext(b => b.EnableServiceProviderCaching(false) + .UseInMemoryDatabase(Guid.NewGuid().ToString()) + .ConfigureWarnings(w => w.Default(WarningBehavior.Throw))) + .AddDbContext(b => b.EnableServiceProviderCaching(false) + .UseInMemoryDatabase(Guid.NewGuid().ToString()) + .ConfigureWarnings(w => w.Default(WarningBehavior.Throw))); + + serviceCollection.RemoveDbContext(); + + Assert.DoesNotContain(serviceCollection, d => d.ServiceType == typeof(ConstructorTestContext1A)); + Assert.DoesNotContain(serviceCollection, d => d.ServiceType == typeof(DbContextOptions)); + Assert.DoesNotContain(serviceCollection, d => d.ServiceType == typeof(IDbContextOptionsConfiguration)); + + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(ConstructorTestContextWithOC3A)); + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(DbContextOptions)); + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(DbContextOptions)); + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(IDbContextOptionsConfiguration)); + } + + [ConditionalFact] + public void RemoveDbContext_with_removeConfigurationOnly_only_removes_configurations() + { + var serviceCollection = new ServiceCollection() + .AddDbContext(b => b.EnableServiceProviderCaching(false) + .UseInMemoryDatabase(Guid.NewGuid().ToString()) + .ConfigureWarnings(w => w.Default(WarningBehavior.Throw))); + + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(ConstructorTestContext1A)); + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(DbContextOptions)); + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(IDbContextOptionsConfiguration)); + + serviceCollection.RemoveDbContext(removeConfigurationOnly: true); + + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(ConstructorTestContext1A)); + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(DbContextOptions)); + Assert.DoesNotContain(serviceCollection, d => d.ServiceType == typeof(IDbContextOptionsConfiguration)); + } + + [ConditionalFact] + public void RemoveDbContext_allows_re_registration_with_different_provider() + { + var serviceCollection = new ServiceCollection() + .AddDbContext(b => b.EnableServiceProviderCaching(false) + .UseInMemoryDatabase("OriginalDb") + .ConfigureWarnings(w => w.Default(WarningBehavior.Throw))); + + serviceCollection.RemoveDbContext(); + serviceCollection.AddDbContext(b => b.EnableServiceProviderCaching(false) + .UseInMemoryDatabase("ReplacementDb") + .ConfigureWarnings(w => w.Default(WarningBehavior.Throw))); + + var appServiceProvider = serviceCollection.BuildServiceProvider(validateScopes: true); + + using var serviceScope = appServiceProvider + .GetRequiredService() + .CreateScope(); + var context = serviceScope.ServiceProvider.GetService(); + Assert.NotNull(context); + Assert.Equal( + "ReplacementDb", + context.GetService().FindExtension().StoreName); + } + + [ConditionalFact] + public void RemoveDbContext_removes_pooled_context_factory_services() + { + var serviceCollection = new ServiceCollection() + .AddPooledDbContextFactory(b => b.EnableServiceProviderCaching(false) + .UseInMemoryDatabase(Guid.NewGuid().ToString()) + .ConfigureWarnings(w => w.Default(WarningBehavior.Throw))); + + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(ConstructorTestContext1A)); + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(DbContextOptions)); + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(IDbContextOptionsConfiguration)); + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(IDbContextFactory)); + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(IDbContextPool)); + + serviceCollection.RemoveDbContext(); + + Assert.DoesNotContain(serviceCollection, d => d.ServiceType == typeof(ConstructorTestContext1A)); + Assert.DoesNotContain(serviceCollection, d => d.ServiceType == typeof(DbContextOptions)); + Assert.DoesNotContain(serviceCollection, d => d.ServiceType == typeof(IDbContextOptionsConfiguration)); + Assert.DoesNotContain(serviceCollection, d => d.ServiceType == typeof(IDbContextFactory)); + Assert.DoesNotContain(serviceCollection, d => d.ServiceType == typeof(IDbContextPool)); + } + + [ConditionalFact] + public void RemoveDbContext_removes_pooled_context_services() + { + var serviceCollection = new ServiceCollection() + .AddDbContextPool(b => b.EnableServiceProviderCaching(false) + .UseInMemoryDatabase(Guid.NewGuid().ToString()) + .ConfigureWarnings(w => w.Default(WarningBehavior.Throw))); + + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(ConstructorTestContext1A)); + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(DbContextOptions)); + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(IDbContextOptionsConfiguration)); + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(IDbContextPool)); + Assert.Contains(serviceCollection, d => d.ServiceType == typeof(IScopedDbContextLease)); + + serviceCollection.RemoveDbContext(); + + Assert.DoesNotContain(serviceCollection, d => d.ServiceType == typeof(ConstructorTestContext1A)); + Assert.DoesNotContain(serviceCollection, d => d.ServiceType == typeof(DbContextOptions)); + Assert.DoesNotContain(serviceCollection, d => d.ServiceType == typeof(IDbContextOptionsConfiguration)); + Assert.DoesNotContain(serviceCollection, d => d.ServiceType == typeof(IDbContextPool)); + Assert.DoesNotContain(serviceCollection, d => d.ServiceType == typeof(IScopedDbContextLease)); + } } }