diff --git a/src/Transports/Kafka/Wolverine.Kafka.Tests/global_partitioned_aggregate_concurrency.cs b/src/Transports/Kafka/Wolverine.Kafka.Tests/global_partitioned_aggregate_concurrency.cs new file mode 100644 index 000000000..94f27b9ba --- /dev/null +++ b/src/Transports/Kafka/Wolverine.Kafka.Tests/global_partitioned_aggregate_concurrency.cs @@ -0,0 +1,332 @@ +using System.Collections.Concurrent; +using IntegrationTests; +using JasperFx.Core; +using JasperFx.Resources; +using Marten; +using Marten.Metadata; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Shouldly; +using Wolverine.Configuration; +using Wolverine.Marten; +using Wolverine.Runtime.Partitioning; +using Wolverine.Tracking; +using Xunit; +using Xunit.Abstractions; + +namespace Wolverine.Kafka.Tests; + +/// +/// Reproduces the concurrency issue reported with Global Partitioning + Kafka: +/// When multiple message types target the same Marten event stream and are processed +/// via global partitioning with sharded Kafka topics across multiple nodes, concurrent +/// processing of messages for the same stream ID causes EventStreamUnexpectedMaxEventIdException. +/// +/// This test simulates the sample app's 2-replica Aspire setup by running 2 Wolverine hosts +/// with a separate publisher host that pumps messages to Kafka input topics. 
public class global_partitioned_aggregate_concurrency : IAsyncLifetime
{
    private readonly ITestOutputHelper _output;
    private IHost _replica1 = null!;
    private IHost _replica2 = null!;
    private IHost _publisher = null!;

    public global_partitioned_aggregate_concurrency(ITestOutputHelper output)
    {
        _output = output;
    }

    /// <summary>
    /// Shared configuration for both "replica" hosts, mirroring the sample app's
    /// 2-replica Aspire deployment: Marten integration, Kafka consumers with a
    /// per-replica ClientId, durable endpoints, and global partitioning over
    /// sharded Kafka topics.
    /// </summary>
    private void ConfigureReplica(WolverineOptions opts, string replicaName)
    {
        opts.ServiceName = replicaName;

        opts.Discovery.DisableConventionalDiscovery()
            .IncludeType(typeof(GpStreamCommandAHandler))
            .IncludeType(typeof(GpStreamCommandBHandler))
            .IncludeType(typeof(GpStreamCascadedHandler));

        opts.Services.AddMarten(m =>
        {
            m.Connection(Servers.PostgresConnectionString);
            m.DatabaseSchemaName = "gp_kafka_concurrency";
            m.DisableNpgsqlLogging = true;
        }).IntegrateWithWolverine();

        var kafka = opts.UseKafka(KafkaContainerFixture.ConnectionString)
            .ConfigureConsumers(c =>
            {
                c.GroupId = "gp-concurrency-test-group";
                // Critical for Kafka co-partitioning: unique ClientId per replica
                c.ClientId = replicaName;
            })
            .AutoProvision();

        opts.Policies.PropagateGroupIdToPartitionKey();
        opts.Policies.AutoApplyTransactions();
        opts.Policies.UseDurableLocalQueues();
        opts.Policies.UseDurableInboxOnAllListeners();
        opts.Policies.UseDurableOutboxOnAllSendingEndpoints();

        // Global partitioning with sharded Kafka topics matching the sample app.
        // NOTE(review): the original comment said "2 shards to match the 2-replica
        // count" but the code provisions 3 shards — confirm the intended count.
        opts.MessagePartitioning.UseInferredMessageGrouping()
            .ByPropertyNamed("Id")
            .GlobalPartitioned(topology =>
            {
                var sharded = topology.UseShardedKafkaTopics("gp-concurrency-test", 3);
                // Generic arguments restored here — the extraction stripped them;
                // the three message types below are the ones the handlers consume.
                sharded.Message<GpStreamCommandA>();
                sharded.Message<GpStreamCommandB>();
                sharded.Message<GpStreamCascaded>();
            });

        // Listen to external Kafka topics (like sample's topic-one, topic-two)
        opts.ListenToKafkaTopic("gp-concurrency-input-a")
            .DisableConsumerGroupIdStamping()
            .PartitionProcessingByGroupId(PartitionSlots.Five);

        opts.ListenToKafkaTopic("gp-concurrency-input-b")
            .DisableConsumerGroupIdStamping()
            .PartitionProcessingByGroupId(PartitionSlots.Five);

        opts.Services.AddResourceSetupOnStartup();
    }

    public async Task InitializeAsync()
    {
        ConcurrencyTracker.Reset();

        // Start replica 1 first to provision topics and database
        _replica1 = await Host.CreateDefaultBuilder()
            .UseWolverine(opts => ConfigureReplica(opts, "replica-1"))
            .StartAsync();

        // Start replica 2
        _replica2 = await Host.CreateDefaultBuilder()
            .UseWolverine(opts => ConfigureReplica(opts, "replica-2"))
            .StartAsync();

        // Start a publisher host that only publishes to Kafka input topics
        _publisher = await Host.CreateDefaultBuilder()
            .UseWolverine(opts =>
            {
                opts.ServiceName = "publisher";
                opts.Durability.Mode = DurabilityMode.Solo;
                opts.Discovery.DisableConventionalDiscovery();

                opts.UseKafka(KafkaContainerFixture.ConnectionString)
                    .AutoProvision();

                // Generic arguments restored (stripped by extraction); each command
                // type is pumped to its own external input topic.
                opts.PublishMessage<GpStreamCommandA>()
                    .ToKafkaTopic("gp-concurrency-input-a");

                opts.PublishMessage<GpStreamCommandB>()
                    .ToKafkaTopic("gp-concurrency-input-b");

                opts.Services.AddResourceSetupOnStartup();
            }).StartAsync();

        // Allow Kafka consumer group rebalancing to settle
        await Task.Delay(5.Seconds());
    }

    public async Task DisposeAsync()
    {
        // Stop in reverse start order; each host is torn down independently so a
        // partially-failed InitializeAsync still cleans up whatever did start.
        await StopAndDisposeAsync(_publisher);
        await StopAndDisposeAsync(_replica2);
        await StopAndDisposeAsync(_replica1);
    }

    // Null-safe stop + dispose for a single host.
    private static async Task StopAndDisposeAsync(IHost? host)
    {
        if (host is null)
        {
            return;
        }

        await host.StopAsync();
        host.Dispose();
    }

    /// <summary>
    /// This test pumps messages from an external publisher to Kafka input topics.
    /// Two replicas each have global partitioning configured with sharded Kafka topics.
    /// Messages for the same stream ID should never be processed concurrently across
    /// any replica, regardless of which input topic they arrive on.
    ///
    /// The bug: with 2 replicas, messages for the same stream arriving on different
    /// input topics can be processed concurrently, causing EventStreamUnexpectedMaxEventIdException.
    /// </summary>
    [Fact]
    public async Task should_not_have_concurrency_exceptions_for_same_stream()
    {
        var store = _replica1.Services.GetRequiredService<IDocumentStore>();
        await store.Advanced.Clean.DeleteAllEventDataAsync();

        var bus = _publisher.Services.GetRequiredService<IMessageBus>();

        // Use a small number of stream IDs but many messages per stream
        var streamIds = Enumerable.Range(1, 4).Select(_ => Guid.NewGuid()).ToArray();
        var messageCount = 0;

        // Pump messages concurrently for the same stream IDs, alternating between
        // the two command types so both input topics carry traffic for each stream
        var tasks = new List<Task>();
        foreach (var streamId in streamIds)
        {
            for (int i = 0; i < 8; i++)
            {
                var id = streamId;
                var iteration = i;
                tasks.Add(Task.Run(async () =>
                {
                    if (iteration % 2 == 0)
                    {
                        await bus.PublishAsync(new GpStreamCommandA(id, $"name-{iteration}"));
                    }
                    else
                    {
                        await bus.PublishAsync(new GpStreamCommandB(id, $"data-{iteration}"));
                    }

                    Interlocked.Increment(ref messageCount);
                }));
            }
        }

        await Task.WhenAll(tasks);
        _output.WriteLine($"Published {messageCount} messages across {streamIds.Length} streams");

        // Wait for processing to complete across both replicas.
        // NOTE(review): fixed sleep — presumably sized for CI Kafka latency; a
        // polling wait on ConcurrencyTracker.TotalHandled would be faster/steadier.
        await Task.Delay(45.Seconds());

        var errors = ConcurrencyTracker.Errors.ToList();
        var concurrentAccessCount = ConcurrencyTracker.ConcurrentAccessDetected;

        _output.WriteLine($"Total handled: {ConcurrencyTracker.TotalHandled}");
        _output.WriteLine($"Concurrent access detected: {concurrentAccessCount} times");

        foreach (var error in errors)
        {
            _output.WriteLine($"ERROR: {error}");
        }

        // Verify all messages were processed (32 original + 16 cascaded from CommandA = 48)
        _output.WriteLine($"Expected at least 32 handled messages, got {ConcurrencyTracker.TotalHandled}");

        // The key assertion: no concurrent access to the same stream should occur
        concurrentAccessCount.ShouldBe(0,
            $"Detected {concurrentAccessCount} instances of concurrent access to the same stream. " +
            "Global partitioning should prevent this. Errors:\n" +
            string.Join("\n", errors.Take(10)));
    }
}

// --- Message types with Id property for partitioning --- 

public record GpStreamCommandA(Guid Id, string Name);
public record GpStreamCommandB(Guid Id, string Data);
public record GpStreamCascaded(Guid Id, string Source);

// --- Events for the Marten stream ---
public record GpStreamEventA(string Name);
public record GpStreamEventB(string Data);
public record GpStreamEventCascaded(string Source);

// --- Aggregate ---
public class GpStreamAggregate : IRevisioned
{
    public Guid Id { get; set; }
    public int Version { get; set; }
    public int ACount { get; set; }
    public int BCount { get; set; }
    public int CascadedCount { get; set; }

    public void Apply(GpStreamEventA _) => ACount++;
    public void Apply(GpStreamEventB _) => BCount++;
    public void Apply(GpStreamEventCascaded _) => CascadedCount++;
}

/// <summary>
/// Process-wide tracker that detects overlapping (concurrent) handling of the
/// same event stream. Handlers bracket their work with TrackStream; an active
/// count above 1 for a stream ID means two handlers touched it at once.
/// </summary>
public static class ConcurrencyTracker
{
    // streamId -> number of handlers currently inside a TrackStream scope
    private static readonly ConcurrentDictionary<string, int> _activeStreams = new();
    private static readonly ConcurrentBag<string> _errors = new();
    private static int _totalHandled;
    private static int _concurrentAccessDetected;

    public static IReadOnlyCollection<string> Errors => _errors;
    public static int TotalHandled => _totalHandled;
    public static int ConcurrentAccessDetected => _concurrentAccessDetected;

    /// <summary>Clears all counters and recorded errors between test runs.</summary>
    public static void Reset()
    {
        _activeStreams.Clear();
        while (_errors.TryTake(out _)) { }
        _totalHandled = 0;
        _concurrentAccessDetected = 0;
    }

    /// <summary>
    /// Marks the given stream as actively being handled and returns a scope
    /// whose disposal releases it. Records an error if the stream was already
    /// active (i.e. concurrent access).
    /// </summary>
    public static IDisposable TrackStream(string streamId, string handlerName)
    {
        var count = _activeStreams.AddOrUpdate(streamId, 1, (_, existing) => existing + 1);
        if (count > 1)
        {
            Interlocked.Increment(ref _concurrentAccessDetected);
            _errors.Add($"Concurrent access to stream '{streamId}' by {handlerName} (active count: {count})");
        }
        Interlocked.Increment(ref _totalHandled);
        return new StreamTracker(streamId);
    }

    // Disposal decrements the active count for the tracked stream.
    private class StreamTracker : IDisposable
    {
        private readonly string _streamId;
        public StreamTracker(string streamId) => _streamId = streamId;
        public void Dispose() => _activeStreams.AddOrUpdate(_streamId, 0, (_, existing) => existing - 1);
    }
}

// --- Handlers that use [AggregateHandler] to target the same stream ---

[AggregateHandler]
public static class GpStreamCommandAHandler
{
    public static async Task<(Events, OutgoingMessages)> Handle(
        GpStreamCommandA command,
        GpStreamAggregate aggregate)
    {
        using var tracker = ConcurrencyTracker.TrackStream(command.Id.ToString(), nameof(GpStreamCommandAHandler));

        // Simulate some work (like the sample app's random delay)
        await Task.Delay(TimeSpan.FromMilliseconds(Random.Shared.Next(10, 100)));

        var events = new Events { new GpStreamEventA(command.Name) };

        // Cascade a message that also targets the same stream (like the sample app)
        var outgoing = new OutgoingMessages
        {
            new GpStreamCascaded(command.Id, $"from-a-{command.Name}")
        };

        return (events, outgoing);
    }
}

[AggregateHandler]
public static class GpStreamCommandBHandler
{
    // Return type restored to Task<Events> — the generic argument was stripped
    // by extraction; the body returns a collection expression of events.
    public static async Task<Events> Handle(
        GpStreamCommandB command,
        GpStreamAggregate aggregate)
    {
        using var tracker = ConcurrencyTracker.TrackStream(command.Id.ToString(), nameof(GpStreamCommandBHandler));

        await Task.Delay(TimeSpan.FromMilliseconds(Random.Shared.Next(10, 100)));

        return [new GpStreamEventB(command.Data)];
    }
}

[AggregateHandler]
public static class GpStreamCascadedHandler
{
    // Return type restored to Task<Events> (stripped generic argument).
    public static async Task<Events> Handle(
        GpStreamCascaded command,
        GpStreamAggregate aggregate)
    {
        using var tracker = ConcurrencyTracker.TrackStream(command.Id.ToString(), nameof(GpStreamCascadedHandler));

        await Task.Delay(TimeSpan.FromMilliseconds(Random.Shared.Next(10, 50)));

        return [new GpStreamEventCascaded(command.Source)];
    }
}
[new GpStreamEventCascaded(command.Source)]; + } +} diff --git a/src/Wolverine/Runtime/Partitioning/GlobalPartitionedInterceptor.cs b/src/Wolverine/Runtime/Partitioning/GlobalPartitionedInterceptor.cs index 2c04de4e7..a541cdabb 100644 --- a/src/Wolverine/Runtime/Partitioning/GlobalPartitionedInterceptor.cs +++ b/src/Wolverine/Runtime/Partitioning/GlobalPartitionedInterceptor.cs @@ -33,22 +33,11 @@ public async ValueTask ReceivedAsync(IListener listener, Envelope[] messages) foreach (var envelope in messages) { - if (envelope.Message != null && ShouldIntercept(envelope.Message.GetType())) + if (ShouldIntercept(envelope)) { - // Re-route through Wolverine's routing which will hit GlobalPartitionedRoute - try + if (!await TryReRouteAsync(listener, envelope)) { - await _messageBus.PublishAsync(envelope.Message, new DeliveryOptions - { - GroupId = envelope.GroupId, - TenantId = envelope.TenantId - }); - await listener.CompleteAsync(envelope); - } - catch (Exception e) - { - _logger.LogError(e, "Error re-routing globally partitioned message {MessageType}", envelope.Message.GetType().Name); - await listener.DeferAsync(envelope); + passThrough.Add(envelope); } } else @@ -65,33 +54,69 @@ public async ValueTask ReceivedAsync(IListener listener, Envelope[] messages) public async ValueTask ReceivedAsync(IListener listener, Envelope envelope) { - if (envelope.Message != null && ShouldIntercept(envelope.Message.GetType())) + if (ShouldIntercept(envelope)) { - try - { - await _messageBus.PublishAsync(envelope.Message, new DeliveryOptions - { - GroupId = envelope.GroupId, - TenantId = envelope.TenantId - }); - await listener.CompleteAsync(envelope); - } - catch (Exception e) + if (await TryReRouteAsync(listener, envelope)) { - _logger.LogError(e, "Error re-routing globally partitioned message {MessageType}", envelope.Message.GetType().Name); - await listener.DeferAsync(envelope); + return; } - return; } await _inner.ReceivedAsync(listener, envelope); } + private async Task 
TryReRouteAsync(IListener listener, Envelope envelope) + { + try + { + // Ensure message is deserialized before re-publishing + if (envelope.Message == null) + { + var result = await Pipeline.TryDeserializeEnvelope(envelope); + if (result is not NullContinuation) + { + // Deserialization failed, let the inner receiver handle it + // (it will apply normal error handling) + return false; + } + } + + // Re-route through Wolverine's routing which will hit GlobalPartitionedRoute + await _messageBus.PublishAsync(envelope.Message!, new DeliveryOptions + { + GroupId = envelope.GroupId, + TenantId = envelope.TenantId + }); + await listener.CompleteAsync(envelope); + return true; + } + catch (Exception e) + { + _logger.LogError(e, "Error re-routing globally partitioned message {MessageType}", + envelope.MessageType ?? envelope.Message?.GetType().Name ?? "unknown"); + await listener.DeferAsync(envelope); + return true; + } + } + public ValueTask DrainAsync() => _inner.DrainAsync(); public void Dispose() => _inner.Dispose(); - private bool ShouldIntercept(Type messageType) + private bool ShouldIntercept(Envelope envelope) { - return _topologies.Any(t => t.Matches(messageType)); + // If message is already deserialized, check the Type directly + if (envelope.Message != null) + { + return _topologies.Any(t => t.Matches(envelope.Message.GetType())); + } + + // For transports that haven't deserialized yet (e.g. 
Kafka), + // check by message type name from envelope metadata/headers + if (envelope.MessageType != null) + { + return _topologies.Any(t => t.MatchesByMessageTypeName(envelope.MessageType)); + } + + return false; } } diff --git a/src/Wolverine/Runtime/Partitioning/GlobalPartitionedMessageTopology.cs b/src/Wolverine/Runtime/Partitioning/GlobalPartitionedMessageTopology.cs index 3345aa4ba..86e2add5e 100644 --- a/src/Wolverine/Runtime/Partitioning/GlobalPartitionedMessageTopology.cs +++ b/src/Wolverine/Runtime/Partitioning/GlobalPartitionedMessageTopology.cs @@ -1,6 +1,7 @@ using System.Reflection; using Wolverine.Configuration; using Wolverine.Runtime.Routing; +using Wolverine.Util; namespace Wolverine.Runtime.Partitioning; @@ -8,6 +9,7 @@ public class GlobalPartitionedMessageTopology { private readonly WolverineOptions _options; private readonly List _subscriptions = new(); + private readonly HashSet _messageTypeNames = new(StringComparer.OrdinalIgnoreCase); private PartitionedMessageTopology? _externalTopology; private LocalPartitionedMessageTopology? _localTopology; @@ -79,6 +81,7 @@ public void Message() public void Message(Type type) { _subscriptions.Add(Subscription.ForType(type)); + _messageTypeNames.Add(type.ToMessageTypeName()); } /// @@ -165,6 +168,31 @@ internal bool Matches(Type messageType) return _subscriptions.Any(x => x.Matches(messageType)); } + /// + /// Check if a message type name (from envelope metadata) matches this topology's subscriptions. + /// This is used by the interceptor when the message hasn't been deserialized yet (e.g. Kafka). + /// + internal bool MatchesByMessageTypeName(string? messageTypeName) + { + return messageTypeName != null && _messageTypeNames.Contains(messageTypeName); + } + + /// + /// Pre-compute message type names for subscription scopes that can't be resolved from + /// a string alone (e.g. MessagesImplementing, namespace, assembly). + /// Called during startup with the set of known handler message types. 
+ /// + internal void ResolveMessageTypeNames(IEnumerable knownMessageTypes) + { + foreach (var type in knownMessageTypes) + { + if (Matches(type)) + { + _messageTypeNames.Add(type.ToMessageTypeName()); + } + } + } + internal bool TryMatch(Type messageType, IWolverineRuntime runtime, out IMessageRoute? route) { route = default; diff --git a/src/Wolverine/Runtime/WolverineRuntime.HostService.cs b/src/Wolverine/Runtime/WolverineRuntime.HostService.cs index 1bbb7da8b..7f6948b62 100644 --- a/src/Wolverine/Runtime/WolverineRuntime.HostService.cs +++ b/src/Wolverine/Runtime/WolverineRuntime.HostService.cs @@ -298,6 +298,18 @@ private async Task startMessagingTransportsAsync() discoverListenersFromConventions(); + // Pre-compute message type names for global partitioning interceptor + // This handles MessagesImplementing(), namespace, and assembly scopes + // that can't be resolved from a string alone + if (Options.MessagePartitioning.GlobalPartitionedTopologies.Count > 0) + { + var knownMessageTypes = Handlers.Chains.Select(x => x.MessageType).ToList(); + foreach (var topology in Options.MessagePartitioning.GlobalPartitionedTopologies) + { + topology.ResolveMessageTypeNames(knownMessageTypes); + } + } + // No local queues if running in Serverless if (Options.Durability.Mode == DurabilityMode.Serverless) {