diff --git a/src/Mocha/src/Examples/MediatorShowcase/Generated/Mocha.Analyzers/Mocha.Analyzers.MediatorGenerator/MediatorShowcaseMediatorBuilderExtensions._5hYhW_IBO7W3L_MSoxCPw.g.cs b/src/Mocha/src/Examples/MediatorShowcase/Generated/Mocha.Analyzers/Mocha.Analyzers.MediatorGenerator/MediatorShowcaseMediatorBuilderExtensions._5hYhW_IBO7W3L_MSoxCPw.g.cs index 8831201bb13..862dbd1de0a 100644 --- a/src/Mocha/src/Examples/MediatorShowcase/Generated/Mocha.Analyzers/Mocha.Analyzers.MediatorGenerator/MediatorShowcaseMediatorBuilderExtensions._5hYhW_IBO7W3L_MSoxCPw.g.cs +++ b/src/Mocha/src/Examples/MediatorShowcase/Generated/Mocha.Analyzers/Mocha.Analyzers.MediatorGenerator/MediatorShowcaseMediatorBuilderExtensions._5hYhW_IBO7W3L_MSoxCPw.g.cs @@ -11,76 +11,92 @@ public static class MediatorShowcaseMediatorBuilderExtensions public static global::Mocha.Mediator.IMediatorHostBuilder AddMediatorShowcase( this global::Mocha.Mediator.IMediatorHostBuilder builder) { - var services = builder.Services; - var lifetime = builder.Options.ServiceLifetime; - // Register handlers - services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mocha.Mediator.ICommandHandler), typeof(global::MediatorShowcase.PlaceOrderCommandHandler), lifetime)); - services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mocha.Mediator.ICommandHandler), typeof(global::MediatorShowcase.RiskyCommandHandler), lifetime)); - services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mocha.Mediator.ICommandHandler), typeof(global::MediatorShowcase.CreateProductCommandHandler), lifetime)); - services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mocha.Mediator.ICommandHandler), typeof(global::MediatorShowcase.CreateProductCommandHandler2), lifetime)); - services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mocha.Mediator.ICommandHandler), typeof(global::MediatorShowcase.CreateProductCommandHandler3), lifetime)); - services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mocha.Mediator.ICommandHandler), typeof(global::MediatorShowcase.CreateProductCommandHandler4), lifetime)); - services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mocha.Mediator.IQueryHandler), typeof(global::MediatorShowcase.GetProductByIdQueryHandler), lifetime)); - services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mocha.Mediator.IQueryHandler>), typeof(global::MediatorShowcase.GetProductsQueryHandler), lifetime)); - - // Register notification handlers - services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::MediatorShowcase.OrderShippedAnalyticsHandler), typeof(global::MediatorShowcase.OrderShippedAnalyticsHandler), lifetime)); - services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::MediatorShowcase.OrderShippedEmailHandler), typeof(global::MediatorShowcase.OrderShippedEmailHandler), lifetime)); - - // Register pipelines - global::Mocha.Mediator.MediatorHostBuilderExtensions.ConfigureMediator(builder, static b => - { - b.RegisterPipeline(new global::Mocha.Mediator.MediatorPipelineConfiguration - { - MessageType = typeof(global::MediatorShowcase.PlaceOrderCommand), - ResponseType = typeof(global::MediatorShowcase.OrderResult), - Terminal = global::Mocha.Mediator.PipelineBuilder.BuildCommandTerminal() - }); - b.RegisterPipeline(new global::Mocha.Mediator.MediatorPipelineConfiguration - { - MessageType = typeof(global::MediatorShowcase.RiskyCommand), - ResponseType = typeof(string), - Terminal = global::Mocha.Mediator.PipelineBuilder.BuildCommandTerminal() - }); - b.RegisterPipeline(new global::Mocha.Mediator.MediatorPipelineConfiguration + // Register handler configurations + global::Mocha.Mediator.MediatorHostBuilderHandlerExtensions.AddHandlerConfiguration(builder, + new global::Mocha.Mediator.MediatorHandlerConfiguration { + HandlerType = typeof(global::MediatorShowcase.CreateProductCommandHandler), MessageType = typeof(global::MediatorShowcase.CreateProductCommand), - Terminal = global::Mocha.Mediator.PipelineBuilder.BuildVoidCommandTerminal() + Kind = global::Mocha.Mediator.MediatorHandlerKind.Command, + Delegate = global::Mocha.Mediator.PipelineBuilder.BuildCommandPipeline() }); - b.RegisterPipeline(new global::Mocha.Mediator.MediatorPipelineConfiguration + global::Mocha.Mediator.MediatorHostBuilderHandlerExtensions.AddHandlerConfiguration(builder, + new global::Mocha.Mediator.MediatorHandlerConfiguration { + HandlerType = typeof(global::MediatorShowcase.CreateProductCommandHandler2), MessageType = typeof(global::MediatorShowcase.CreateProductCommand2), - Terminal = global::Mocha.Mediator.PipelineBuilder.BuildVoidCommandTerminal() + Kind = global::Mocha.Mediator.MediatorHandlerKind.Command, + Delegate = global::Mocha.Mediator.PipelineBuilder.BuildCommandPipeline() }); - b.RegisterPipeline(new global::Mocha.Mediator.MediatorPipelineConfiguration + global::Mocha.Mediator.MediatorHostBuilderHandlerExtensions.AddHandlerConfiguration(builder, + new global::Mocha.Mediator.MediatorHandlerConfiguration { + HandlerType = typeof(global::MediatorShowcase.CreateProductCommandHandler3), MessageType = typeof(global::MediatorShowcase.CreateProductCommand3), - Terminal = global::Mocha.Mediator.PipelineBuilder.BuildVoidCommandTerminal() + Kind = global::Mocha.Mediator.MediatorHandlerKind.Command, + Delegate = global::Mocha.Mediator.PipelineBuilder.BuildCommandPipeline() }); - b.RegisterPipeline(new global::Mocha.Mediator.MediatorPipelineConfiguration + global::Mocha.Mediator.MediatorHostBuilderHandlerExtensions.AddHandlerConfiguration(builder, + new global::Mocha.Mediator.MediatorHandlerConfiguration { + HandlerType = typeof(global::MediatorShowcase.CreateProductCommandHandler4), MessageType = typeof(global::MediatorShowcase.CreateProductCommand4), - Terminal = global::Mocha.Mediator.PipelineBuilder.BuildVoidCommandTerminal() + Kind = global::Mocha.Mediator.MediatorHandlerKind.Command, + Delegate = global::Mocha.Mediator.PipelineBuilder.BuildCommandPipeline() }); - b.RegisterPipeline(new global::Mocha.Mediator.MediatorPipelineConfiguration + global::Mocha.Mediator.MediatorHostBuilderHandlerExtensions.AddHandlerConfiguration(builder, + new global::Mocha.Mediator.MediatorHandlerConfiguration { + HandlerType = typeof(global::MediatorShowcase.PlaceOrderCommandHandler), + MessageType = typeof(global::MediatorShowcase.PlaceOrderCommand), + ResponseType = typeof(global::MediatorShowcase.OrderResult), + Kind = global::Mocha.Mediator.MediatorHandlerKind.CommandResponse, + Delegate = global::Mocha.Mediator.PipelineBuilder.BuildCommandResponsePipeline() + }); + global::Mocha.Mediator.MediatorHostBuilderHandlerExtensions.AddHandlerConfiguration(builder, + new global::Mocha.Mediator.MediatorHandlerConfiguration + { + HandlerType = typeof(global::MediatorShowcase.RiskyCommandHandler), + MessageType = typeof(global::MediatorShowcase.RiskyCommand), + ResponseType = typeof(string), + Kind = global::Mocha.Mediator.MediatorHandlerKind.CommandResponse, + Delegate = global::Mocha.Mediator.PipelineBuilder.BuildCommandResponsePipeline() + }); + global::Mocha.Mediator.MediatorHostBuilderHandlerExtensions.AddHandlerConfiguration(builder, + new global::Mocha.Mediator.MediatorHandlerConfiguration + { + HandlerType = typeof(global::MediatorShowcase.GetProductByIdQueryHandler), MessageType = typeof(global::MediatorShowcase.GetProductByIdQuery), ResponseType = typeof(global::MediatorShowcase.ProductDto), - Terminal = global::Mocha.Mediator.PipelineBuilder.BuildQueryTerminal() + Kind = global::Mocha.Mediator.MediatorHandlerKind.Query, + Delegate = global::Mocha.Mediator.PipelineBuilder.BuildQueryPipeline() }); - b.RegisterPipeline(new global::Mocha.Mediator.MediatorPipelineConfiguration + global::Mocha.Mediator.MediatorHostBuilderHandlerExtensions.AddHandlerConfiguration(builder, + new global::Mocha.Mediator.MediatorHandlerConfiguration { + HandlerType = typeof(global::MediatorShowcase.GetProductsQueryHandler), MessageType = typeof(global::MediatorShowcase.GetProductsQuery), ResponseType = typeof(global::System.Collections.Generic.IReadOnlyList), - Terminal = global::Mocha.Mediator.PipelineBuilder.BuildQueryTerminal>() + Kind = global::Mocha.Mediator.MediatorHandlerKind.Query, + Delegate = global::Mocha.Mediator.PipelineBuilder.BuildQueryPipeline>() + }); + global::Mocha.Mediator.MediatorHostBuilderHandlerExtensions.AddHandlerConfiguration(builder, + new global::Mocha.Mediator.MediatorHandlerConfiguration + { + HandlerType = typeof(global::MediatorShowcase.OrderShippedAnalyticsHandler), + MessageType = typeof(global::MediatorShowcase.OrderShippedNotification), + Kind = global::Mocha.Mediator.MediatorHandlerKind.Notification, + Delegate = global::Mocha.Mediator.PipelineBuilder.BuildNotificationPipeline() }); - b.RegisterPipeline(new global::Mocha.Mediator.MediatorPipelineConfiguration + global::Mocha.Mediator.MediatorHostBuilderHandlerExtensions.AddHandlerConfiguration(builder, + new global::Mocha.Mediator.MediatorHandlerConfiguration { + HandlerType = typeof(global::MediatorShowcase.OrderShippedEmailHandler), MessageType = typeof(global::MediatorShowcase.OrderShippedNotification), - Terminal = global::Mocha.Mediator.PipelineBuilder.BuildNotificationTerminal(new global::System.Type[] { typeof(global::MediatorShowcase.OrderShippedAnalyticsHandler), typeof(global::MediatorShowcase.OrderShippedEmailHandler) }) + Kind = global::Mocha.Mediator.MediatorHandlerKind.Notification, + Delegate = global::Mocha.Mediator.PipelineBuilder.BuildNotificationPipeline() }); - }); return builder; } diff --git a/src/Mocha/src/Examples/MediatorShowcase/MediatorShowcase.cs b/src/Mocha/src/Examples/MediatorShowcase/MediatorShowcase.cs index 3f04f6a9d3e..e1c539a9991 100644 --- a/src/Mocha/src/Examples/MediatorShowcase/MediatorShowcase.cs +++ b/src/Mocha/src/Examples/MediatorShowcase/MediatorShowcase.cs @@ -1,4 +1,3 @@ -using System.Runtime.CompilerServices; using MediatorShowcase; using Mocha.Mediator; @@ -8,6 +7,7 @@ builder.Services.AddMediator() .AddMediatorShowcase() .Use(LoggingMiddleware.Create()) + .Use(CommandAuditMiddleware.Create()) .Use(PlaceOrderValidationMiddleware.Create()) .Use(PlaceOrderAuditMiddleware.Create()) .Use(ExceptionHandlingMiddleware.Create()) diff --git a/src/Mocha/src/Examples/MediatorShowcase/PipelineBehaviors.cs b/src/Mocha/src/Examples/MediatorShowcase/PipelineBehaviors.cs index 439e7adaaa6..c71b29ee8c2 100644 --- a/src/Mocha/src/Examples/MediatorShowcase/PipelineBehaviors.cs +++ b/src/Mocha/src/Examples/MediatorShowcase/PipelineBehaviors.cs @@ -1,5 +1,4 @@ using System.Diagnostics; -using System.Runtime.CompilerServices; using Mocha.Mediator; namespace MediatorShowcase; @@ -12,133 +11,152 @@ namespace MediatorShowcase; /// Middleware that logs and times every message passing through the pipeline. /// Applies to all commands, queries, and notifications automatically. /// -public static class LoggingMiddleware +internal sealed class LoggingMiddleware(ILogger logger) { + public async ValueTask InvokeAsync(IMediatorContext context, MediatorDelegate next) + { + var messageTypeName = context.MessageType.Name; + logger.LogInformation("[Pipeline] Handling {MessageType}...", messageTypeName); + + var sw = Stopwatch.StartNew(); + await next(context); + sw.Stop(); + + logger.LogInformation( + "[Pipeline] Handled {MessageType} in {ElapsedMs}ms", + messageTypeName, sw.ElapsedMilliseconds); + } + public static MediatorMiddlewareConfiguration Create() => new( static (factoryCtx, next) => { - var logger = factoryCtx.Services.GetRequiredService() - .CreateLogger("Pipeline.Logging"); - - return ctx => - { - var messageTypeName = ctx.MessageType.Name; - logger.LogInformation("[Pipeline] Handling {MessageType}...", messageTypeName); - - var sw = Stopwatch.StartNew(); - var task = next(ctx); - - if (task.IsCompletedSuccessfully) - { - sw.Stop(); - logger.LogInformation( - "[Pipeline] Handled {MessageType} in {ElapsedMs}ms", - messageTypeName, sw.ElapsedMilliseconds); - return default; - } - - return Awaited(task, sw, logger, messageTypeName); - - [MethodImpl(MethodImplOptions.NoInlining)] - static async ValueTask Awaited( - ValueTask t, Stopwatch sw, ILogger log, string msgType) - { - await t.ConfigureAwait(false); - sw.Stop(); - log.LogInformation( - "[Pipeline] Handled {MessageType} in {ElapsedMs}ms", - msgType, sw.ElapsedMilliseconds); - } - }; + var logger = factoryCtx.Services.GetRequiredService>(); + var middleware = new LoggingMiddleware(logger); + return ctx => middleware.InvokeAsync(ctx, next); }, "Logging"); } // ────────────────────────────────────────────────── -// Validation Middleware (message-specific pre-check) +// Command Audit Middleware (compile-time scoped to commands) // ────────────────────────────────────────────────── /// -/// Middleware that validates PlaceOrderCommand before the handler runs. -/// Demonstrates message-type-specific pre-processing. +/// Middleware that audits every write operation. Demonstrates compile-time filtering by +/// message kind: queries and notifications are skipped at startup with +/// and +/// , so this middleware +/// is only compiled into command pipelines. /// -public static class PlaceOrderValidationMiddleware +internal sealed class CommandAuditMiddleware(ILogger logger) { + public async ValueTask InvokeAsync(IMediatorContext context, MediatorDelegate next) + { + logger.LogInformation("[Audit] Executing command {CommandType}", context.MessageType.Name); + await next(context); + logger.LogInformation("[Audit] Completed command {CommandType}", context.MessageType.Name); + } + public static MediatorMiddlewareConfiguration Create() => new( static (factoryCtx, next) => { - var logger = factoryCtx.Services.GetRequiredService() - .CreateLogger("Pipeline.Validation"); + // Compile-time filter: skip queries and notifications - audit only commands + if (factoryCtx.IsQuery() || factoryCtx.IsNotification()) + { + return next; + } + + var logger = factoryCtx.Services.GetRequiredService>(); + var middleware = new CommandAuditMiddleware(logger); + return ctx => middleware.InvokeAsync(ctx, next); + }, + "CommandAudit"); +} + +// ────────────────────────────────────────────────── +// Validation Middleware (compile-time scoped to PlaceOrderCommand) +// ────────────────────────────────────────────────── - return ctx => +/// +/// Middleware that validates before the handler runs. +/// Demonstrates compile-time filtering: the factory inspects the message type and returns +/// next for unrelated pipelines, so this middleware is only compiled into the +/// PlaceOrderCommand pipeline - zero runtime cost everywhere else. +/// +internal sealed class PlaceOrderValidationMiddleware(ILogger logger) +{ + public async ValueTask InvokeAsync(IMediatorContext context, MediatorDelegate next) + { + // Safe to cast - compile-time filter guarantees the message type + var order = (PlaceOrderCommand)context.Message; + + logger.LogInformation( + "[PreProcessor] Validating order: {Quantity}x {Product}", + order.Quantity, order.ProductName); + + if (order.Quantity <= 0) + { + throw new ArgumentException("Quantity must be greater than zero."); + } + + await next(context); + } + + public static MediatorMiddlewareConfiguration Create() + => new( + static (factoryCtx, next) => + { + // Compile-time filter: skip every pipeline whose message is not PlaceOrderCommand + if (!factoryCtx.IsMessageAssignableTo()) { - if (ctx.Message is PlaceOrderCommand order) - { - logger.LogInformation( - "[PreProcessor] Validating order: {Quantity}x {Product}", - order.Quantity, order.ProductName); - - if (order.Quantity <= 0) - { - throw new ArgumentException("Quantity must be greater than zero."); - } - } - - return next(ctx); - }; + return next; + } + + var logger = factoryCtx.Services.GetRequiredService>(); + var middleware = new PlaceOrderValidationMiddleware(logger); + return ctx => middleware.InvokeAsync(ctx, next); }, "Validation"); } // ────────────────────────────────────────────────── -// Auditing Middleware (post-processing) +// Auditing Middleware (compile-time scoped to OrderResult responses) // ────────────────────────────────────────────────── /// -/// Middleware that audits PlaceOrderCommand results after the handler runs. -/// Demonstrates message-type-specific post-processing. +/// Middleware that audits any handler returning an after it runs. +/// Demonstrates response-type-based compile-time filtering with +/// . /// -public static class PlaceOrderAuditMiddleware +internal sealed class PlaceOrderAuditMiddleware(ILogger logger) { + public async ValueTask InvokeAsync(IMediatorContext context, MediatorDelegate next) + { + await next(context); + + if (context.Result is OrderResult result) + { + logger.LogInformation( + "[PostProcessor] Order {OrderId} confirmed with total {Total:C}", + result.OrderId, result.Total); + } + } + public static MediatorMiddlewareConfiguration Create() => new( static (factoryCtx, next) => { - var logger = factoryCtx.Services.GetRequiredService() - .CreateLogger("Pipeline.Audit"); - - return ctx => + // Compile-time filter: only commands/queries that return OrderResult + if (!factoryCtx.IsResponseAssignableTo()) { - var task = next(ctx); - - if (task.IsCompletedSuccessfully) - { - LogResult(ctx, logger); - return default; - } - - return Awaited(task, ctx, logger); - - [MethodImpl(MethodImplOptions.NoInlining)] - static async ValueTask Awaited( - ValueTask t, IMediatorContext ctx, ILogger log) - { - await t.ConfigureAwait(false); - LogResult(ctx, log); - } - - static void LogResult(IMediatorContext ctx, ILogger log) - { - if (ctx.Result is OrderResult result) - { - log.LogInformation( - "[PostProcessor] Order {OrderId} confirmed with total {Total:C}", - result.OrderId, result.Total); - } - } - }; + return next; + } + + var logger = factoryCtx.Services.GetRequiredService>(); + var middleware = new PlaceOrderAuditMiddleware(logger); + return ctx => middleware.InvokeAsync(ctx, next); }, "Audit"); } @@ -148,61 +166,34 @@ static void LogResult(IMediatorContext ctx, ILogger log) // ────────────────────────────────────────────────── /// -/// Middleware that catches InvalidOperationException from RiskyCommand +/// Middleware that catches from /// and returns a fallback response instead of propagating the exception. /// -public static class ExceptionHandlingMiddleware +internal sealed class ExceptionHandlingMiddleware(ILogger logger) { + public async ValueTask InvokeAsync(IMediatorContext context, MediatorDelegate next) + { + try + { + await next(context); + } + catch (InvalidOperationException ex) when (context.Message is RiskyCommand) + { + logger.LogWarning( + "[ExceptionHandler] Caught {ExceptionType}: {Message}", + ex.GetType().Name, ex.Message); + + context.Result = "Recovered gracefully from error."; + } + } + public static MediatorMiddlewareConfiguration Create() => new( static (factoryCtx, next) => { - var logger = factoryCtx.Services.GetRequiredService() - .CreateLogger("Pipeline.ExceptionHandler"); - - return ctx => - { - try - { - var task = next(ctx); - - if (task.IsCompletedSuccessfully) - { - return default; - } - - return Awaited(task, ctx, logger); - } - catch (InvalidOperationException ex) when (ctx.Message is RiskyCommand) - { - HandleException(ctx, ex, logger); - return default; - } - - [MethodImpl(MethodImplOptions.NoInlining)] - static async ValueTask Awaited( - ValueTask t, IMediatorContext ctx, ILogger log) - { - try - { - await t.ConfigureAwait(false); - } - catch (InvalidOperationException ex) when (ctx.Message is RiskyCommand) - { - HandleException(ctx, ex, log); - } - } - - static void HandleException( - IMediatorContext ctx, InvalidOperationException ex, ILogger log) - { - log.LogWarning( - "[ExceptionHandler] Caught {ExceptionType}: {Message}", - ex.GetType().Name, ex.Message); - - ctx.Result = "Recovered gracefully from error."; - } - }; + var logger = factoryCtx.Services.GetRequiredService>(); + var middleware = new ExceptionHandlingMiddleware(logger); + return ctx => middleware.InvokeAsync(ctx, next); }, "ExceptionHandler"); } diff --git a/website/src/docs/mocha/v16/mediator/pipeline-and-middleware.md b/website/src/docs/mocha/v16/mediator/pipeline-and-middleware.md index a515cdee82a..465720feaea 100644 --- a/website/src/docs/mocha/v16/mediator/pipeline-and-middleware.md +++ b/website/src/docs/mocha/v16/mediator/pipeline-and-middleware.md @@ -4,47 +4,34 @@ description: "Add cross-cutting concerns to the Mocha Mediator dispatch pipeline --- ```csharp -public static class LoggingMiddleware +internal sealed class LoggingMiddleware(ILogger logger) { + public async ValueTask InvokeAsync(IMediatorContext context, MediatorDelegate next) + { + logger.LogInformation("Handling {MessageType}...", context.MessageType.Name); + + var sw = Stopwatch.StartNew(); + await next(context); + sw.Stop(); + + logger.LogInformation( + "Handled {MessageType} in {ElapsedMs}ms", + context.MessageType.Name, sw.ElapsedMilliseconds); + } + public static MediatorMiddlewareConfiguration Create() => new( static (factoryCtx, next) => { - var logger = factoryCtx.Services.GetRequiredService() - .CreateLogger("Pipeline.Logging"); - - return ctx => - { - logger.LogInformation("Handling {MessageType}...", ctx.MessageType.Name); - - var sw = Stopwatch.StartNew(); - var task = next(ctx); - - if (task.IsCompletedSuccessfully) - { - sw.Stop(); - logger.LogInformation("Handled {MessageType} in {ElapsedMs}ms", - ctx.MessageType.Name, sw.ElapsedMilliseconds); - return default; - } - - return Awaited(task, sw, logger, ctx.MessageType.Name); - - static async ValueTask Awaited( - ValueTask t, Stopwatch sw, ILogger log, string msgType) - { - await t.ConfigureAwait(false); - sw.Stop(); - log.LogInformation("Handled {MessageType} in {ElapsedMs}ms", - msgType, sw.ElapsedMilliseconds); - } - }; + var logger = factoryCtx.Services.GetRequiredService>(); + var middleware = new LoggingMiddleware(logger); + return ctx => middleware.InvokeAsync(ctx, next); }, "Logging"); } ``` -That is a middleware. It wraps every command, query, and notification with timing and logging. Register it with `.Use()` and it runs for every message that passes through the pipeline. +That is a middleware. It wraps every command, query, and notification with timing and logging. Register it with `.Use(LoggingMiddleware.Create())` and it runs for every message that passes through the pipeline. # How the pipeline works @@ -89,7 +76,7 @@ graph LR # Write a middleware -A middleware is a static class with a `Create()` method that returns a `MediatorMiddlewareConfiguration`. The configuration holds two things: the factory delegate and an optional string key used for [positioning](#middleware-positioning). +A middleware is a class with an `InvokeAsync(IMediatorContext, MediatorDelegate)` method and a static `Create()` method that returns a `MediatorMiddlewareConfiguration`. The configuration holds two things: a factory delegate and an optional string key used for [positioning](#middleware-positioning). The factory delegate receives two arguments: @@ -98,34 +85,38 @@ The factory delegate receives two arguments: | `MediatorMiddlewareFactoryContext` | Startup (compile time) | Resolve singleton services, inspect message/response types, opt out of the pipeline | | `MediatorDelegate next` | Startup (compile time) | The next middleware or handler in the chain | -The factory returns a `MediatorDelegate` - the runtime function that receives `IMediatorContext` for each dispatch. +The factory returns a `MediatorDelegate` - the runtime function that receives `IMediatorContext` for each dispatch. By convention, that runtime delegate forwards to an `InvokeAsync` method on a small middleware class so the dispatch logic reads top-to-bottom. -Here is a minimal timing middleware, step by step: +Here is a minimal timing middleware: ```csharp -public static class TimingMiddleware +internal sealed class TimingMiddleware(ILogger logger) { + public async ValueTask InvokeAsync(IMediatorContext context, MediatorDelegate next) + { + var sw = Stopwatch.StartNew(); + + await next(context); // call the next middleware or handler + + sw.Stop(); + logger.LogInformation( + "{MessageType} handled in {ElapsedMs}ms", + context.MessageType.Name, + sw.ElapsedMilliseconds); + } + public static MediatorMiddlewareConfiguration Create() => new( static (factoryCtx, next) => { // 1. Resolve services once at startup (not per request) - var logger = factoryCtx.Services.GetRequiredService() - .CreateLogger("Pipeline.Timing"); + var logger = factoryCtx.Services.GetRequiredService>(); - // 2. Return the runtime delegate - return async ctx => - { - var sw = Stopwatch.StartNew(); - - await next(ctx); // 3. Call the next middleware or handler + // 2. Build the middleware instance once + var middleware = new TimingMiddleware(logger); - sw.Stop(); - logger.LogInformation( - "{MessageType} handled in {ElapsedMs}ms", - ctx.MessageType.Name, - sw.ElapsedMilliseconds); - }; + // 3. Return the runtime delegate + return ctx => middleware.InvokeAsync(ctx, next); }, "Timing"); // 4. Key for positioning } @@ -155,16 +146,18 @@ The `IMediatorContext` available at runtime provides everything you need during ## Short-circuiting -To prevent the handler from executing, return without calling `next`: +To prevent the handler from executing, return from `InvokeAsync` without calling `next`: ```csharp -return ctx => +public async ValueTask InvokeAsync(IMediatorContext context, MediatorDelegate next) { - if (ctx.Message is PlaceOrderCommand { Quantity: <= 0 }) + if (context.Message is PlaceOrderCommand { Quantity: <= 0 }) + { throw new ArgumentException("Quantity must be greater than zero."); + } - return next(ctx); // only reached if validation passes -}; + await next(context); // only reached if validation passes +} ``` ## Exception handling @@ -172,63 +165,34 @@ return ctx => Wrap `next` in a try/catch to handle exceptions: ```csharp -public static class ExceptionHandlingMiddleware +internal sealed class ExceptionHandlingMiddleware(ILogger logger) { + public async ValueTask InvokeAsync(IMediatorContext context, MediatorDelegate next) + { + try + { + await next(context); + } + catch (Exception ex) + { + logger.LogError(ex, "Error handling {MessageType}", context.MessageType.Name); + throw; // re-throw or set context.Result to recover + } + } + public static MediatorMiddlewareConfiguration Create() => new( static (factoryCtx, next) => { - var logger = factoryCtx.Services.GetRequiredService() - .CreateLogger("Pipeline.ExceptionHandler"); - - return async ctx => - { - try - { - await next(ctx); - } - catch (Exception ex) - { - logger.LogError(ex, "Error handling {MessageType}", - ctx.MessageType.Name); - throw; // re-throw or set ctx.Result to recover - } - }; + var logger = factoryCtx.Services.GetRequiredService>(); + var middleware = new ExceptionHandlingMiddleware(logger); + return ctx => middleware.InvokeAsync(ctx, next); }, "ExceptionHandling"); } ``` -To recover from an exception instead of re-throwing, set `ctx.Result` to a fallback value and return normally. - -## Synchronous fast-path optimization - -When `next` completes synchronously (common for in-memory handlers), you can avoid the `async` state machine overhead by checking `IsCompletedSuccessfully`: - -```csharp -return ctx => -{ - logger.LogInformation("Before"); - - var task = next(ctx); - - if (task.IsCompletedSuccessfully) - { - logger.LogInformation("After (sync)"); - return default; - } - - return Awaited(task, logger); - - static async ValueTask Awaited(ValueTask t, ILogger log) - { - await t.ConfigureAwait(false); - log.LogInformation("After (async)"); - } -}; -``` - -This pattern avoids allocating an async state machine when the pipeline completes synchronously. Use it in performance-sensitive middleware; use plain `async`/`await` everywhere else. +To recover from an exception instead of re-throwing, set `context.Result` to a fallback value and return normally. # Compile-time filtering @@ -237,41 +201,45 @@ The `MediatorMiddlewareFactoryContext` is available during pipeline compilation To opt out, return `next` directly from the factory. The middleware is not included in that pipeline at all - zero runtime cost, no delegate wrapper, no type check on every dispatch. ```csharp -public static class TransactionMiddleware +internal sealed class TransactionMiddleware { + public async ValueTask InvokeAsync(IMediatorContext context, MediatorDelegate next) + { + // Resolve DbContext from the scoped service provider so each dispatch gets its own + var db = context.Services.GetRequiredService(); + + await using var tx = await db.Database.BeginTransactionAsync(context.CancellationToken); + try + { + await next(context); + await tx.CommitAsync(context.CancellationToken); + } + catch + { + await tx.RollbackAsync(context.CancellationToken); + throw; + } + } + public static MediatorMiddlewareConfiguration Create() => new( static (factoryCtx, next) => { - // Skip notifications and queries at compile time + // Skip notifications and queries at compile time - they don't need transactions if (factoryCtx.IsNotification() || factoryCtx.IsQuery()) { return next; // not included in this pipeline } - return async ctx => - { - // Resolve DbContext from the scoped service provider - var db = ctx.Services.GetRequiredService(); - - await using var tx = await db.Database - .BeginTransactionAsync(ctx.CancellationToken); - try - { - await next(ctx); - await tx.CommitAsync(ctx.CancellationToken); - } - catch - { - await tx.RollbackAsync(ctx.CancellationToken); - throw; - } - }; + var middleware = new TransactionMiddleware(); + return ctx => middleware.InvokeAsync(ctx, next); }, "Transaction"); } ``` +Notice that `DbContext` is scoped, so it must be resolved per dispatch from `context.Services` inside `InvokeAsync` - **not** from `factoryCtx.Services` in the factory, which would capture a single startup-scope instance and share it across every message. + ## Message kind checks | Method | Returns true when | @@ -290,27 +258,37 @@ public static class TransactionMiddleware | `IsResponseAssignableTo()` | Response type is assignable to `T` (false for void commands and notifications) | | `IsResponseAssignableTo(Type)` | Response type is assignable to the given type | -Use `IsMessageAssignableTo` to scope a middleware to a specific message or base type: +Use `IsMessageAssignableTo` to scope a middleware to a specific message or base type. Once the factory has filtered, `InvokeAsync` can cast directly without re-checking: ```csharp -public static class PlaceOrderValidationMiddleware +internal sealed class PlaceOrderValidationMiddleware { + public async ValueTask InvokeAsync(IMediatorContext context, MediatorDelegate next) + { + // Safe cast - the factory's compile-time filter guarantees the message type + var order = (PlaceOrderCommand)context.Message; + + if (order.Quantity <= 0) + { + throw new ArgumentException("Quantity must be greater than zero."); + } + + await next(context); + } + public static MediatorMiddlewareConfiguration Create() => new( static (factoryCtx, next) => { - // Only compile into the PlaceOrderCommand pipeline + // Only compile this middleware into the PlaceOrderCommand pipeline. + // Every other pipeline gets `next` directly - no wrapper, no per-dispatch type check. if (!factoryCtx.IsMessageAssignableTo()) { return next; } - return ctx => - { - if (ctx.Message is PlaceOrderCommand order && order.Quantity <= 0) - throw new ArgumentException("Quantity must be greater than zero."); - return next(ctx); - }; + var middleware = new PlaceOrderValidationMiddleware(); + return ctx => middleware.InvokeAsync(ctx, next); }, "Validation"); }