diff --git a/src/Mocha/src/Mocha/Builder/MessageBusOptions.cs b/src/Mocha/src/Mocha/Builder/MessageBusOptions.cs index 845d77dd15b..69e46c33384 100644 --- a/src/Mocha/src/Mocha/Builder/MessageBusOptions.cs +++ b/src/Mocha/src/Mocha/Builder/MessageBusOptions.cs @@ -23,6 +23,7 @@ public static class MessageBusServiceCollectionExtensions public static IMessageBusHostBuilder AddMessageBus(this IServiceCollection services) { services.AddLogging(); + services.AddScoped(); services.AddScoped(); services.AddSingleton(static sp => diff --git a/src/Mocha/src/Mocha/Context/ConsumeContextAccessor.cs b/src/Mocha/src/Mocha/Context/ConsumeContextAccessor.cs new file mode 100644 index 00000000000..bd88bda7834 --- /dev/null +++ b/src/Mocha/src/Mocha/Context/ConsumeContextAccessor.cs @@ -0,0 +1,11 @@ +namespace Mocha; + +/// +/// Scoped service that holds a reference to the current +/// during message consumption. Used by to automatically +/// propagate ConversationId and CausationId when publishing or sending from within a handler. +/// +public sealed class ConsumeContextAccessor +{ + public IConsumeContext? Context { get; set; } +} diff --git a/src/Mocha/src/Mocha/Endpoints/ReceiveEndpoint.cs b/src/Mocha/src/Mocha/Endpoints/ReceiveEndpoint.cs index 56d3c4d60cf..b0caa192eb7 100644 --- a/src/Mocha/src/Mocha/Endpoints/ReceiveEndpoint.cs +++ b/src/Mocha/src/Mocha/Endpoints/ReceiveEndpoint.cs @@ -131,6 +131,9 @@ public async ValueTask ExecuteAsync( configure(context, state); + var accessor = scope.ServiceProvider.GetRequiredService(); + accessor.Context = context; + await _pipeline(context); } catch (Exception ex) @@ -140,6 +143,8 @@ public async ValueTask ExecuteAsync( } finally { + var accessor = scope.ServiceProvider.GetRequiredService(); + accessor.Context = null; pools.ReceiveContext.Return(context); } } diff --git a/src/Mocha/src/Mocha/Middlewares/DefaultMessageBus.cs b/src/Mocha/src/Mocha/Middlewares/DefaultMessageBus.cs index 7e0f1d695bc..0e260f778f6 100644 --- a/src/Mocha/src/Mocha/Middlewares/DefaultMessageBus.cs +++ b/src/Mocha/src/Mocha/Middlewares/DefaultMessageBus.cs @@ -17,7 +17,12 @@ namespace Mocha; /// The messaging runtime used to resolve message types, endpoints, and transports. /// The scoped service provider injected into each dispatch context. /// Object pools providing reusable instances. -public sealed class DefaultMessageBus(IMessagingRuntime runtime, IServiceProvider services, IMessagingPools pools) +/// Accessor for the ambient consume context used to propagate correlation IDs. +public sealed class DefaultMessageBus( + IMessagingRuntime runtime, + IServiceProvider services, + IMessagingPools pools, + ConsumeContextAccessor consumeContextAccessor) : IMessageBus { private readonly ObjectPool _contextPool = pools.DispatchContext; @@ -56,6 +61,7 @@ public async ValueTask PublishAsync(T message, PublishOptions options, Cancel var context = _contextPool.Get(); try { + PropagateCorrelationIds(context); context.Initialize(services, endpoint, runtime, messageType, cancellationToken); context.Message = message; context.AddHeaders(options.Headers); @@ -105,6 +111,7 @@ public async ValueTask SendAsync(object message, SendOptions options, Cancellati var context = _contextPool.Get(); try { + PropagateCorrelationIds(context); context.Initialize(services, endpoint, runtime, messageType, cancellationToken); context.Message = message; @@ -263,6 +270,7 @@ private async ValueTask RequestAndWaitAsync( var context = _contextPool.Get(); try { + PropagateCorrelationIds(context); context.CorrelationId = correlationId; context.Initialize(services, endpoint, runtime, requestType, cancellationToken); @@ -288,6 +296,15 @@ private async ValueTask RequestAndWaitAsync( throw new InvalidOperationException("Unexpected response type."); } + + private void PropagateCorrelationIds(DispatchContext context) + { + if (consumeContextAccessor.Context is { } ambient) + { + context.ConversationId ??= ambient.ConversationId; + context.CausationId ??= ambient.MessageId; + } + } } file static class Extensions diff --git a/src/Mocha/test/Mocha.Transport.InMemory.Tests/Behaviors/CorrelationTests.cs b/src/Mocha/test/Mocha.Transport.InMemory.Tests/Behaviors/CorrelationTests.cs new file mode 100644 index 00000000000..1ed4b9fba2d --- /dev/null +++ b/src/Mocha/test/Mocha.Transport.InMemory.Tests/Behaviors/CorrelationTests.cs @@ -0,0 +1,267 @@ +using System.Collections.Concurrent; +using Microsoft.Extensions.DependencyInjection; +using Mocha; +using Mocha.Transport.InMemory.Tests.Helpers; + +namespace Mocha.Transport.InMemory.Tests.Behaviors; + +public class CorrelationTests +{ + private static readonly TimeSpan Timeout = TimeSpan.FromSeconds(10); + + [Fact] + public async Task Publish_Should_AutoGenerateIds_When_NoIdsSet() + { + // arrange + var capture = new ContextCapture(); + await using var provider = await new ServiceCollection() + .AddSingleton(capture) + .AddMessageBus() + .AddConsumer() + .AddInMemory() + .BuildServiceProvider(); + + using var scope = provider.CreateScope(); + var bus = scope.ServiceProvider.GetRequiredService(); + + // act + await bus.PublishAsync(new OrderCreated { OrderId = "ORD-1" }, default); + + // assert + Assert.True(await capture.WaitAsync(Timeout)); + var ctx = Assert.Single(capture.Contexts); + + Assert.NotNull(ctx.MessageId); + Assert.NotNull(ctx.CorrelationId); + Assert.NotNull(ctx.ConversationId); + Assert.True(Guid.TryParse(ctx.MessageId, out _), "MessageId should be a valid GUID"); + Assert.True(Guid.TryParse(ctx.CorrelationId, out _), "CorrelationId should be a valid GUID"); + Assert.True(Guid.TryParse(ctx.ConversationId, out _), "ConversationId should be a valid GUID"); + } + + [Fact] + public async Task Publish_Should_AssignUniqueIds_When_MultipleSeparatePublishes() + { + // arrange + var capture = new ContextCapture(); + await using var provider = await new ServiceCollection() + .AddSingleton(capture) + .AddMessageBus() + .AddConsumer() + .AddInMemory() + .BuildServiceProvider(); + + using var scope = provider.CreateScope(); + var bus = scope.ServiceProvider.GetRequiredService(); + + // act — two independent publishes + await bus.PublishAsync(new OrderCreated { OrderId = "ORD-A" }, default); + Assert.True(await capture.WaitAsync(Timeout)); + + await bus.PublishAsync(new OrderCreated { OrderId = "ORD-B" }, default); + Assert.True(await capture.WaitAsync(Timeout)); + + // assert — each publish gets its own MessageId and ConversationId + Assert.Equal(2, capture.Contexts.Count); + var ids = capture.Contexts.ToArray(); + + Assert.NotEqual(ids[0].MessageId, ids[1].MessageId); + Assert.NotEqual(ids[0].ConversationId, ids[1].ConversationId); + } + + [Fact] + public async Task Consumer_Should_SeeAllCorrelationIds_When_MessageReceived() + { + // arrange — use IConsumer to inspect the full context + var capture = new ContextCapture(); + await using var provider = await new ServiceCollection() + .AddSingleton(capture) + .AddMessageBus() + .AddConsumer() + .AddInMemory() + .BuildServiceProvider(); + + using var scope = provider.CreateScope(); + var bus = scope.ServiceProvider.GetRequiredService(); + + // act + await bus.PublishAsync(new OrderCreated { OrderId = "ORD-CTX" }, default); + + // assert + Assert.True(await capture.WaitAsync(Timeout)); + var ctx = Assert.Single(capture.Contexts); + + Assert.NotNull(ctx.ConversationId); + Assert.NotNull(ctx.CorrelationId); + Assert.NotNull(ctx.MessageId); + } + + [Fact] + public async Task Publish_Should_HaveDistinctMessageIdButSharedCorrelationScope_When_FanOutToMultipleConsumers() + { + // arrange — two consumers receive the same published event via fan-out + var capture = new ContextCapture(); + await using var provider = await new ServiceCollection() + .AddSingleton(capture) + .AddMessageBus() + .AddConsumer() + .AddConsumer() + .AddInMemory() + .BuildServiceProvider(); + + using var scope = provider.CreateScope(); + var bus = scope.ServiceProvider.GetRequiredService(); + + // act + await bus.PublishAsync(new OrderCreated { OrderId = "ORD-FAN" }, default); + + // assert — both consumers received the event + Assert.True(await capture.WaitAsync(Timeout, 2)); + Assert.Equal(2, capture.Contexts.Count); + + var all = capture.Contexts.ToArray(); + + // Both see the same ConversationId (same logical conversation) + Assert.Equal(all[0].ConversationId, all[1].ConversationId); + + // Both see the same CorrelationId (same correlation scope) + Assert.Equal(all[0].CorrelationId, all[1].CorrelationId); + } + + [Fact] + public async Task Chain_Should_PropagateConversationId_When_HandlerPublishesNewMessage() + { + // arrange + // Chain: publish OrderCreated → OrderCreatedForwarder handles it and publishes ProcessPayment + // → PaymentSpy captures the ProcessPayment context + // Verify: ConversationId from message 1 should carry over to message 2, + // and CausationId on message 2 should equal MessageId of message 1. + var capture = new ContextCapture(); + await using var provider = await new ServiceCollection() + .AddSingleton(capture) + .AddMessageBus() + .AddConsumer() + .AddConsumer() + .AddInMemory() + .BuildServiceProvider(); + + using var scope = provider.CreateScope(); + var bus = scope.ServiceProvider.GetRequiredService(); + + // act — publish the initial event + await bus.PublishAsync(new OrderCreated { OrderId = "ORD-CHAIN" }, default); + + // assert — wait for both captures (OrderCreated + ProcessPayment) + Assert.True(await capture.WaitAsync(Timeout, 2), "Both handlers should fire"); + Assert.Equal(2, capture.Contexts.Count); + + var hop1 = capture.Contexts.Single(c => c.Label == "OrderCreatedForwarder"); + var hop2 = capture.Contexts.Single(c => c.Label == "PaymentSpy"); + + // ConversationId must propagate across hops + Assert.Equal(hop1.ConversationId, hop2.ConversationId); + + // CausationId on hop2 should equal MessageId of hop1 (parent→child link) + Assert.Equal(hop1.MessageId, hop2.CausationId); + } + + // ══════════════════════════════════════════════════════════════════════ + // Test infrastructure + // ══════════════════════════════════════════════════════════════════════ + + public sealed class ContextCapture + { + private readonly SemaphoreSlim _semaphore = new(0); + + public ConcurrentBag Contexts { get; } = []; + + public void Record( + string? messageId, + string? correlationId, + string? conversationId, + string? causationId, + string? label = null) + { + Contexts.Add(new CapturedContext + { + MessageId = messageId, + CorrelationId = correlationId, + ConversationId = conversationId, + CausationId = causationId, + Label = label + }); + _semaphore.Release(); + } + + public async Task WaitAsync(TimeSpan timeout, int expectedCount = 1) + { + for (var i = 0; i < expectedCount; i++) + { + if (!await _semaphore.WaitAsync(timeout)) + return false; + } + return true; + } + } + + public sealed class CapturedContext + { + public string? MessageId { get; init; } + public string? CorrelationId { get; init; } + public string? ConversationId { get; init; } + public string? CausationId { get; init; } + public string? Label { get; init; } + } + + public sealed class OrderCreatedSpy(ContextCapture capture) : IConsumer + { + public ValueTask ConsumeAsync(IConsumeContext context) + { + capture.Record(context.MessageId, context.CorrelationId, context.ConversationId, context.CausationId); + return default; + } + } + + public sealed class OrderCreatedSpy2(ContextCapture capture) : IConsumer + { + public ValueTask ConsumeAsync(IConsumeContext context) + { + capture.Record(context.MessageId, context.CorrelationId, context.ConversationId, context.CausationId); + return default; + } + } + + /// + /// Receives OrderCreated and publishes ProcessPayment. ConversationId and + /// CausationId are propagated automatically by the framework. + /// + public sealed class OrderCreatedForwarder(ContextCapture capture) : IConsumer + { + public async ValueTask ConsumeAsync(IConsumeContext context) + { + capture.Record( + context.MessageId, context.CorrelationId, + context.ConversationId, context.CausationId, + nameof(OrderCreatedForwarder)); + + var bus = context.Services.GetRequiredService(); + + // No manual propagation needed — the framework handles it automatically + await bus.PublishAsync( + new ProcessPayment { OrderId = context.Message.OrderId, Amount = 100m }, + context.CancellationToken); + } + } + + public sealed class PaymentSpy(ContextCapture capture) : IConsumer + { + public ValueTask ConsumeAsync(IConsumeContext context) + { + capture.Record( + context.MessageId, context.CorrelationId, + context.ConversationId, context.CausationId, + nameof(PaymentSpy)); + return default; + } + } +}