diff --git a/src/Persistence/EfCoreTests/dbContext_transactions_with_abstractions_tests.cs b/src/Persistence/EfCoreTests/dbContext_transactions_with_abstractions_tests.cs new file mode 100644 index 000000000..e09645918 --- /dev/null +++ b/src/Persistence/EfCoreTests/dbContext_transactions_with_abstractions_tests.cs @@ -0,0 +1,133 @@ +using IntegrationTests; +using JasperFx.CodeGeneration.Frames; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Shouldly; +using Wolverine; +using Wolverine.Attributes; +using Wolverine.EntityFrameworkCore; +using Wolverine.EntityFrameworkCore.Codegen; +using Wolverine.Tracking; +using Wolverine.Postgresql; +using JasperFx.Resources; + + +namespace EfCoreTests; + +public abstract class DbContextAbstractionTestFixture +{ + public record AbstractionCommand; + + public interface IItemRepository; + + public class AbstractionDbContext(DbContextOptions options) : DbContext(options), IItemRepository + { + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.MapWolverineEnvelopeStorage("wolverine"); + } + } + + [WolverineHandler] + public class AbstractionCommandHandler + { + public async Task Handle(AbstractionCommand command, IItemRepository items) + { + + } + } +} + +public class dbContext_transactions_with_abstractions_tests +{ + [Fact] + public async Task can_apply_transactional_middleware_to_abstraction() + { + using var host = await Host.CreateDefaultBuilder() + .UseWolverine(opts => + { + opts.CodeGeneration.SourceCodeWritingEnabled = true; + opts.Durability.Mode = DurabilityMode.Solo; + + opts.Services.AddDbContextWithWolverineIntegration(x => + x.UseNpgsql(Servers.PostgresConnectionString)); + + opts.Services.AddScoped(); + + opts.PersistMessagesWithPostgresql(Servers.PostgresConnectionString, "wolverine"); + + opts.UseEntityFrameworkCoreTransactions(mode: Wolverine.Persistence.TransactionMiddlewareMode.Eager) + .WithDbContextAbstraction(); + + opts.Policies.AutoApplyTransactions(); + + opts.Discovery.DisableConventionalDiscovery().IncludeType(); + }).StartAsync(); + + var runtime = host.GetRuntime(); + var chain = runtime.Handlers.ChainFor(); + + chain.ShouldNotBeNull(); + + chain.Middleware.OfType().ShouldNotBeEmpty(); + chain.Middleware.OfType().ShouldNotBeEmpty(); + } + + + [Fact] + public async Task codegen_works_with_abstraction() + { + using var host = await Host.CreateDefaultBuilder() + .UseWolverine(opts => + { + opts.Services.AddDbContextWithWolverineIntegration(x => + x.UseNpgsql(Servers.PostgresConnectionString)); + + opts.Services.AddScoped(); + + opts.Services.AddResourceSetupOnStartup(StartupAction.ResetState); + opts.PersistMessagesWithPostgresql(Servers.PostgresConnectionString, "wolverine"); + + opts.UseEntityFrameworkCoreTransactions() + .WithDbContextAbstraction(); + + opts.Policies.AutoApplyTransactions(); + + opts.Discovery.DisableConventionalDiscovery().IncludeType(); + }).StartAsync(); + + // If it compiles and runs without error, the cast worked + Should.NotThrow(async () => await host.InvokeMessageAndWaitAsync(new DbContextAbstractionTestFixture.AbstractionCommand())); + } + + [Fact] + public async Task should_add_save_changes_async_call_to_postprocessors() + { + using var host = await Host.CreateDefaultBuilder() + .UseWolverine(opts => + { + opts.Services.AddDbContextWithWolverineIntegration(x => + x.UseNpgsql(Servers.PostgresConnectionString)); + + opts.Services.AddScoped(); + + opts.PersistMessagesWithPostgresql(Servers.PostgresConnectionString, "wolverine"); + + opts.UseEntityFrameworkCoreTransactions() + .WithDbContextAbstraction(); + + opts.Policies.AutoApplyTransactions(); + + opts.Discovery.DisableConventionalDiscovery().IncludeType(); + }).StartAsync(); + + var runtime = host.GetRuntime(); + var chain = runtime.Handlers.ChainFor(); + + chain.ShouldNotBeNull(); + chain.Postprocessors.OfType() + .Any(x => x.Method.Name == nameof(DbContext.SaveChangesAsync)) + .ShouldBeTrue(); + } +} diff --git a/src/Persistence/Wolverine.EntityFrameworkCore/Codegen/EFCorePersistenceFrameProvider.cs b/src/Persistence/Wolverine.EntityFrameworkCore/Codegen/EFCorePersistenceFrameProvider.cs index 2b810c0ec..5f988d3be 100644 --- a/src/Persistence/Wolverine.EntityFrameworkCore/Codegen/EFCorePersistenceFrameProvider.cs +++ b/src/Persistence/Wolverine.EntityFrameworkCore/Codegen/EFCorePersistenceFrameProvider.cs @@ -43,9 +43,15 @@ internal class EFCorePersistenceFrameProvider : IPersistenceFrameProvider public const string UsingEfCoreTransaction = "uses_efcore_transaction"; public const string TransactionModeKey = "TransactionMiddlewareMode"; private ImHashMap _dbContextTypes = ImHashMap.Empty; + private ImHashMap _abstractions = ImHashMap.Empty; public TransactionMiddlewareMode DefaultMode { get; set; } = TransactionMiddlewareMode.Eager; + public void RegisterAbstraction(Type abstractionType, Type dbContextType) + { + _abstractions = _abstractions.AddOrUpdate(abstractionType, dbContextType); + } + public bool CanPersist(Type entityType, IServiceContainer container, out Type persistenceService) { var dbContextType = TryDetermineDbContextType(entityType, container); @@ -209,6 +215,12 @@ public void ApplyTransactionSupport(IChain chain, IServiceContainer container) } } + var abstractionType = chain.ServiceDependencies(container, Type.EmptyTypes).FirstOrDefault(x => _abstractions.Contains(x)); + if (abstractionType != null) + { + chain.Middleware.Insert(0, new CastDbContextFrame(abstractionType, dbContextType)); + } + var saveChangesAsync = dbContextType.GetMethod(nameof(DbContext.SaveChangesAsync), [typeof(CancellationToken)]); @@ -309,6 +321,12 @@ public void ApplyTransactionSupport(IChain chain, IServiceContainer container, T } } + var abstractionType = chain.ServiceDependencies(container, Type.EmptyTypes).FirstOrDefault(x => _abstractions.Contains(x)); + if (abstractionType != null) + { + chain.Middleware.Insert(0, new CastDbContextFrame(abstractionType, dbType)); + } + var saveChangesAsync = dbType.GetMethod(nameof(DbContext.SaveChangesAsync), [typeof(CancellationToken)]); @@ -341,7 +359,7 @@ public bool CanApply(IChain chain, IServiceContainer container) } var serviceDependencies = chain.ServiceDependencies(container, Type.EmptyTypes).ToArray(); - return serviceDependencies.Any(x => x.CanBeCastTo()); + return serviceDependencies.Any(x => x.CanBeCastTo() || _abstractions.Contains(x)); } internal Type? TryDetermineDbContextType(Type entityType, IServiceContainer container) @@ -420,8 +438,22 @@ public Type DetermineDbContextType(IChain chain, IServiceContainer container) { return DetermineDbContextType(saga.SagaType, container); } -// START HERE. Look for any IStorageAction, and use the T - var contextTypes = chain.ServiceDependencies(container, Type.EmptyTypes).Where(x => x.CanBeCastTo()).ToArray(); + + IEnumerable FindDbContextTypes() + { + var dependencies = chain.ServiceDependencies(container, Type.EmptyTypes); + + var contextTypes = dependencies.Where(x => x.CanBeCastTo()).ToArray(); + var abstractionTypes = dependencies.Where(x => _abstractions.Contains(x)).ToArray(); + + return contextTypes + .Concat(abstractionTypes.Select(x => _abstractions.TryFind(x, out var concrete) ? concrete : null)) + .OfType() // Removes nullability + .Distinct() + .ToArray(); + } + + var contextTypes = FindDbContextTypes().ToArray(); if (contextTypes.Length == 0) { @@ -448,6 +480,34 @@ public Type DetermineDbContextType(IChain chain, IServiceContainer container) return contextTypes.Single(); } + public class CastDbContextFrame : SyncFrame + { + private readonly Type _abstractionType; + private readonly Type _dbContextType; + private Variable _abstraction = null!; + + public CastDbContextFrame(Type abstractionType, Type dbContextType) + { + _abstractionType = abstractionType; + _dbContextType = dbContextType; + DbContext = new Variable(_dbContextType, this); + } + + public Variable DbContext { get; } + + public override void GenerateCode(GeneratedMethod method, ISourceWriter writer) + { + writer.WriteLine($"if ({_abstraction.Usage} is not {_dbContextType.FullNameInCode()} {DbContext.Usage}) throw new System.Exception($\"DbContext abstraction - {_abstraction.Usage} must be implemented by {_dbContextType.FullNameInCode()}.\");"); + Next?.GenerateCode(method, writer); + } + + public override IEnumerable FindVariables(IMethodVariables chain) + { + _abstraction = chain.FindVariable(_abstractionType); + yield return _abstraction; + } + } + public class IncrementSagaVersionIfNecessary : SyncFrame { private readonly Type _dbContextType; @@ -515,4 +575,4 @@ public override void GenerateCode(GeneratedMethod method, ISourceWriter writer) } } -} \ No newline at end of file +} diff --git a/src/Persistence/Wolverine.EntityFrameworkCore/EFCoreTransactionConfiguration.cs b/src/Persistence/Wolverine.EntityFrameworkCore/EFCoreTransactionConfiguration.cs new file mode 100644 index 000000000..b193f2127 --- /dev/null +++ b/src/Persistence/Wolverine.EntityFrameworkCore/EFCoreTransactionConfiguration.cs @@ -0,0 +1,29 @@ +using Wolverine.EntityFrameworkCore.Codegen; + +namespace Wolverine.EntityFrameworkCore; + +public class EFCoreTransactionConfiguration +{ + private readonly WolverineOptions _options; + private readonly EFCorePersistenceFrameProvider _provider; + + internal EFCoreTransactionConfiguration(WolverineOptions options, EFCorePersistenceFrameProvider provider) + { + _options = options; + _provider = provider; + } + + /// + /// Register a DbContext abstraction that should be used for auto-transactions + /// when the abstraction is used as a dependency in a handler. + /// + /// The abstraction type (e.g., IUnitOfWork) + /// The concrete DbContext type + public EFCoreTransactionConfiguration WithDbContextAbstraction() where TDbContext : Microsoft.EntityFrameworkCore.DbContext, TAbstraction + { + _provider.RegisterAbstraction(typeof(TAbstraction), typeof(TDbContext)); + _options.CodeGeneration.AlwaysUseServiceLocationFor(); + + return this; + } +} diff --git a/src/Persistence/Wolverine.EntityFrameworkCore/WolverineEntityCoreExtensions.cs b/src/Persistence/Wolverine.EntityFrameworkCore/WolverineEntityCoreExtensions.cs index 082f62078..a8c926f3b 100644 --- a/src/Persistence/Wolverine.EntityFrameworkCore/WolverineEntityCoreExtensions.cs +++ b/src/Persistence/Wolverine.EntityFrameworkCore/WolverineEntityCoreExtensions.cs @@ -10,6 +10,7 @@ using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Options; using Weasel.EntityFrameworkCore; using Wolverine.EntityFrameworkCore.Codegen; using Wolverine.EntityFrameworkCore.Internals; @@ -255,9 +256,9 @@ private static void registerEFCoreSagaStoreDiagnostics(IServiceCollection servic /// middleware using mode by default. /// /// - public static void UseEntityFrameworkCoreTransactions(this WolverineOptions options) + public static EFCoreTransactionConfiguration UseEntityFrameworkCoreTransactions(this WolverineOptions options) { - options.UseEntityFrameworkCoreTransactions(TransactionMiddlewareMode.Eager); + return options.UseEntityFrameworkCoreTransactions(TransactionMiddlewareMode.Eager); } /// @@ -269,7 +270,7 @@ public static void UseEntityFrameworkCoreTransactions(this WolverineOptions opti /// /// /// The transaction middleware mode to use - public static void UseEntityFrameworkCoreTransactions(this WolverineOptions options, TransactionMiddlewareMode mode) + public static EFCoreTransactionConfiguration UseEntityFrameworkCoreTransactions(this WolverineOptions options, TransactionMiddlewareMode mode) { try { @@ -319,10 +320,12 @@ public static void UseEntityFrameworkCoreTransactions(this WolverineOptions opti var providers = options.CodeGeneration.PersistenceProviders(); var efProvider = providers.OfType().FirstOrDefault(); - if (efProvider != null) + if (efProvider == null) { - efProvider.DefaultMode = mode; + throw new Exception($"Unable to find any ${typeof(EFCorePersistenceFrameProvider)}"); } + efProvider.DefaultMode = mode; + return new EFCoreTransactionConfiguration(options, efProvider); } /// @@ -460,4 +463,4 @@ public static WolverineOptions PublishDomainEventsFromEntityFrameworkCore(scraper); return options; } -} \ No newline at end of file +}