Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
using IntegrationTests;
using JasperFx.CodeGeneration.Frames;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Shouldly;
using Wolverine;
using Wolverine.Attributes;
using Wolverine.EntityFrameworkCore;
using Wolverine.EntityFrameworkCore.Codegen;
using Wolverine.Tracking;
using Wolverine.Postgresql;
using JasperFx.Resources;


namespace EfCoreTests;

public abstract class DbContextAbstractionTestFixture
{
public record AbstractionCommand;

public interface IItemRepository;

public class AbstractionDbContext(DbContextOptions<AbstractionDbContext> options) : DbContext(options), IItemRepository
{
protected override void OnModelCreating(ModelBuilder modelBuilder)
{
modelBuilder.MapWolverineEnvelopeStorage("wolverine");
}
}

[WolverineHandler]
public class AbstractionCommandHandler
{
public async Task Handle(AbstractionCommand command, IItemRepository items)
{

}
}
}

public class dbContext_transactions_with_abstractions_tests
{
[Fact]
public async Task can_apply_transactional_middleware_to_abstraction()
{
using var host = await Host.CreateDefaultBuilder()
.UseWolverine(opts =>
{
opts.CodeGeneration.SourceCodeWritingEnabled = true;
opts.Durability.Mode = DurabilityMode.Solo;

opts.Services.AddDbContextWithWolverineIntegration<DbContextAbstractionTestFixture.AbstractionDbContext>(x =>
x.UseNpgsql(Servers.PostgresConnectionString));

opts.Services.AddScoped<DbContextAbstractionTestFixture.IItemRepository, DbContextAbstractionTestFixture.AbstractionDbContext>();

opts.PersistMessagesWithPostgresql(Servers.PostgresConnectionString, "wolverine");

opts.UseEntityFrameworkCoreTransactions(mode: Wolverine.Persistence.TransactionMiddlewareMode.Eager)
.WithDbContextAbstraction<DbContextAbstractionTestFixture.IItemRepository, DbContextAbstractionTestFixture.AbstractionDbContext>();

opts.Policies.AutoApplyTransactions();

opts.Discovery.DisableConventionalDiscovery().IncludeType<DbContextAbstractionTestFixture.AbstractionCommandHandler>();
}).StartAsync();

var runtime = host.GetRuntime();
var chain = runtime.Handlers.ChainFor<DbContextAbstractionTestFixture.AbstractionCommand>();

chain.ShouldNotBeNull();

chain.Middleware.OfType<EnrollDbContextInTransaction>().ShouldNotBeEmpty();
chain.Middleware.OfType<EFCorePersistenceFrameProvider.CastDbContextFrame>().ShouldNotBeEmpty();
}


[Fact]
public async Task codegen_works_with_abstraction()
{
using var host = await Host.CreateDefaultBuilder()
.UseWolverine(opts =>
{
opts.Services.AddDbContextWithWolverineIntegration<DbContextAbstractionTestFixture.AbstractionDbContext>(x =>
x.UseNpgsql(Servers.PostgresConnectionString));

opts.Services.AddScoped<DbContextAbstractionTestFixture.IItemRepository, DbContextAbstractionTestFixture.AbstractionDbContext>();

opts.Services.AddResourceSetupOnStartup(StartupAction.ResetState);
opts.PersistMessagesWithPostgresql(Servers.PostgresConnectionString, "wolverine");

opts.UseEntityFrameworkCoreTransactions()
.WithDbContextAbstraction<DbContextAbstractionTestFixture.IItemRepository, DbContextAbstractionTestFixture.AbstractionDbContext>();

opts.Policies.AutoApplyTransactions();

opts.Discovery.DisableConventionalDiscovery().IncludeType<DbContextAbstractionTestFixture.AbstractionCommandHandler>();
}).StartAsync();

// If it compiles and runs without error, the cast worked
Should.NotThrow(async () => await host.InvokeMessageAndWaitAsync(new DbContextAbstractionTestFixture.AbstractionCommand()));
}

[Fact]
public async Task should_add_save_changes_async_call_to_postprocessors()
{
using var host = await Host.CreateDefaultBuilder()
.UseWolverine(opts =>
{
opts.Services.AddDbContextWithWolverineIntegration<DbContextAbstractionTestFixture.AbstractionDbContext>(x =>
x.UseNpgsql(Servers.PostgresConnectionString));

opts.Services.AddScoped<DbContextAbstractionTestFixture.IItemRepository, DbContextAbstractionTestFixture.AbstractionDbContext>();

opts.PersistMessagesWithPostgresql(Servers.PostgresConnectionString, "wolverine");

opts.UseEntityFrameworkCoreTransactions()
.WithDbContextAbstraction<DbContextAbstractionTestFixture.IItemRepository, DbContextAbstractionTestFixture.AbstractionDbContext>();

opts.Policies.AutoApplyTransactions();

opts.Discovery.DisableConventionalDiscovery().IncludeType<DbContextAbstractionTestFixture.AbstractionCommandHandler>();
}).StartAsync();

var runtime = host.GetRuntime();
var chain = runtime.Handlers.ChainFor<DbContextAbstractionTestFixture.AbstractionCommand>();

chain.ShouldNotBeNull();
chain.Postprocessors.OfType<MethodCall>()
.Any(x => x.Method.Name == nameof(DbContext.SaveChangesAsync))
.ShouldBeTrue();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,15 @@ internal class EFCorePersistenceFrameProvider : IPersistenceFrameProvider
public const string UsingEfCoreTransaction = "uses_efcore_transaction";
public const string TransactionModeKey = "TransactionMiddlewareMode";
private ImHashMap<Type, Type?> _dbContextTypes = ImHashMap<Type, Type?>.Empty;
private ImHashMap<Type, Type> _abstractions = ImHashMap<Type, Type>.Empty;

public TransactionMiddlewareMode DefaultMode { get; set; } = TransactionMiddlewareMode.Eager;

public void RegisterAbstraction(Type abstractionType, Type dbContextType)
{
_abstractions = _abstractions.AddOrUpdate(abstractionType, dbContextType);
}

public bool CanPersist(Type entityType, IServiceContainer container, out Type persistenceService)
{
var dbContextType = TryDetermineDbContextType(entityType, container);
Expand Down Expand Up @@ -209,6 +215,12 @@ public void ApplyTransactionSupport(IChain chain, IServiceContainer container)
}
}

var abstractionType = chain.ServiceDependencies(container, Type.EmptyTypes).FirstOrDefault(x => _abstractions.Contains(x));
if (abstractionType != null)
{
chain.Middleware.Insert(0, new CastDbContextFrame(abstractionType, dbContextType));
}

var saveChangesAsync =
dbContextType.GetMethod(nameof(DbContext.SaveChangesAsync), [typeof(CancellationToken)]);

Expand Down Expand Up @@ -309,6 +321,12 @@ public void ApplyTransactionSupport(IChain chain, IServiceContainer container, T
}
}

var abstractionType = chain.ServiceDependencies(container, Type.EmptyTypes).FirstOrDefault(x => _abstractions.Contains(x));
if (abstractionType != null)
{
chain.Middleware.Insert(0, new CastDbContextFrame(abstractionType, dbType));
}

var saveChangesAsync =
dbType.GetMethod(nameof(DbContext.SaveChangesAsync), [typeof(CancellationToken)]);

Expand Down Expand Up @@ -341,7 +359,7 @@ public bool CanApply(IChain chain, IServiceContainer container)
}

var serviceDependencies = chain.ServiceDependencies(container, Type.EmptyTypes).ToArray();
return serviceDependencies.Any(x => x.CanBeCastTo<DbContext>());
return serviceDependencies.Any(x => x.CanBeCastTo<DbContext>() || _abstractions.Contains(x));
}

internal Type? TryDetermineDbContextType(Type entityType, IServiceContainer container)
Expand Down Expand Up @@ -420,8 +438,22 @@ public Type DetermineDbContextType(IChain chain, IServiceContainer container)
{
return DetermineDbContextType(saga.SagaType, container);
}
// START HERE. Look for any IStorageAction<T>, and use the T
var contextTypes = chain.ServiceDependencies(container, Type.EmptyTypes).Where(x => x.CanBeCastTo<DbContext>()).ToArray();

IEnumerable<Type> FindDbContextTypes()
{
var dependencies = chain.ServiceDependencies(container, Type.EmptyTypes);

var contextTypes = dependencies.Where(x => x.CanBeCastTo<DbContext>()).ToArray();
var abstractionTypes = dependencies.Where(x => _abstractions.Contains(x)).ToArray();

return contextTypes
.Concat(abstractionTypes.Select(x => _abstractions.TryFind(x, out var concrete) ? concrete : null))
.OfType<Type>() // Removes nullability
.Distinct()
.ToArray();
}

var contextTypes = FindDbContextTypes().ToArray();

if (contextTypes.Length == 0)
{
Expand All @@ -448,6 +480,34 @@ public Type DetermineDbContextType(IChain chain, IServiceContainer container)
return contextTypes.Single();
}

public class CastDbContextFrame : SyncFrame
{
private readonly Type _abstractionType;
private readonly Type _dbContextType;
private Variable _abstraction = null!;

public CastDbContextFrame(Type abstractionType, Type dbContextType)
{
_abstractionType = abstractionType;
_dbContextType = dbContextType;
DbContext = new Variable(_dbContextType, this);
}

public Variable DbContext { get; }

public override void GenerateCode(GeneratedMethod method, ISourceWriter writer)
{
writer.WriteLine($"if ({_abstraction.Usage} is not {_dbContextType.FullNameInCode()} {DbContext.Usage}) throw new System.Exception($\"DbContext abstraction - {_abstraction.Usage} must be implemented by {_dbContextType.FullNameInCode()}.\");");
Next?.GenerateCode(method, writer);
}

public override IEnumerable<Variable> FindVariables(IMethodVariables chain)
{
_abstraction = chain.FindVariable(_abstractionType);
yield return _abstraction;
}
}

public class IncrementSagaVersionIfNecessary : SyncFrame
{
private readonly Type _dbContextType;
Expand Down Expand Up @@ -515,4 +575,4 @@ public override void GenerateCode(GeneratedMethod method, ISourceWriter writer)
}

}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using Wolverine.EntityFrameworkCore.Codegen;

namespace Wolverine.EntityFrameworkCore;

public class EFCoreTransactionConfiguration
{
private readonly WolverineOptions _options;
private readonly EFCorePersistenceFrameProvider _provider;

internal EFCoreTransactionConfiguration(WolverineOptions options, EFCorePersistenceFrameProvider provider)
{
_options = options;
_provider = provider;
}

/// <summary>
/// Register a DbContext abstraction that should be used for auto-transactions
/// when the abstraction is used as a dependency in a handler.
/// </summary>
/// <typeparam name="TAbstraction">The abstraction type (e.g., IUnitOfWork)</typeparam>
/// <typeparam name="TDbContext">The concrete DbContext type</typeparam>
public EFCoreTransactionConfiguration WithDbContextAbstraction<TAbstraction, TDbContext>() where TDbContext : Microsoft.EntityFrameworkCore.DbContext, TAbstraction
{
_provider.RegisterAbstraction(typeof(TAbstraction), typeof(TDbContext));
_options.CodeGeneration.AlwaysUseServiceLocationFor<TAbstraction>();

return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Options;
using Weasel.EntityFrameworkCore;
using Wolverine.EntityFrameworkCore.Codegen;
using Wolverine.EntityFrameworkCore.Internals;
Expand Down Expand Up @@ -255,9 +256,9 @@ private static void registerEFCoreSagaStoreDiagnostics(IServiceCollection servic
/// middleware using <see cref="TransactionMiddlewareMode.Eager"/> mode by default.
/// </summary>
/// <param name="options"></param>
public static void UseEntityFrameworkCoreTransactions(this WolverineOptions options)
public static EFCoreTransactionConfiguration UseEntityFrameworkCoreTransactions(this WolverineOptions options)
{
options.UseEntityFrameworkCoreTransactions(TransactionMiddlewareMode.Eager);
return options.UseEntityFrameworkCoreTransactions(TransactionMiddlewareMode.Eager);
}

/// <summary>
Expand All @@ -269,7 +270,7 @@ public static void UseEntityFrameworkCoreTransactions(this WolverineOptions opti
/// </summary>
/// <param name="options"></param>
/// <param name="mode">The transaction middleware mode to use</param>
public static void UseEntityFrameworkCoreTransactions(this WolverineOptions options, TransactionMiddlewareMode mode)
public static EFCoreTransactionConfiguration UseEntityFrameworkCoreTransactions(this WolverineOptions options, TransactionMiddlewareMode mode)
{
try
{
Expand Down Expand Up @@ -319,10 +320,12 @@ public static void UseEntityFrameworkCoreTransactions(this WolverineOptions opti

var providers = options.CodeGeneration.PersistenceProviders();
var efProvider = providers.OfType<EFCorePersistenceFrameProvider>().FirstOrDefault();
if (efProvider != null)
if (efProvider == null)
{
efProvider.DefaultMode = mode;
throw new Exception($"Unable to find any ${typeof(EFCorePersistenceFrameProvider)}");
}
efProvider.DefaultMode = mode;
return new EFCoreTransactionConfiguration(options, efProvider);
}

/// <summary>
Expand Down Expand Up @@ -460,4 +463,4 @@ public static WolverineOptions PublishDomainEventsFromEntityFrameworkCore<TEntit
options.Services.AddSingleton<IDomainEventScraper>(scraper);
return options;
}
}
}