diff --git a/src/Testing/CoreTests/Runtime/Handlers/HandlerGraphTests.cs b/src/Testing/CoreTests/Runtime/Handlers/HandlerGraphTests.cs index a3de0368f..519d815f7 100644 --- a/src/Testing/CoreTests/Runtime/Handlers/HandlerGraphTests.cs +++ b/src/Testing/CoreTests/Runtime/Handlers/HandlerGraphTests.cs @@ -1,8 +1,12 @@ -using System.Diagnostics; +using ImTools; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Module1; +using System.Diagnostics; +using System.Reflection; +using JasperFx.Core.Reflection; using Wolverine.Attributes; +using Wolverine.Runtime; using Wolverine.Runtime.Handlers; using Wolverine.Util; using Xunit; @@ -65,6 +69,25 @@ await Should.ThrowAsync(async () => }).StartAsync(); }); } + + [Fact] + public async Task Concurrent_Registration_No_Race_Condition() + { + // just make sure the list consists of at least a couple of messages + var typesToRegister = typeof(DummyMessage).Assembly.GetTypes().Where(x => x.Name.EndsWith("Message")).ToArray(); + + using var host = await Host.CreateDefaultBuilder() + .UseWolverine() + .StartAsync(); + + var graph = host.Services.GetRequiredService(); + var runtime = host.Services.GetRequiredService(); + Parallel.ForEach(typesToRegister, t => runtime.RegisterMessageType(t)); + + var missingTypes = typesToRegister.Select(t => t.ToMessageTypeName()) + .Where(t => graph.TryFindMessageType(t, out var _) is false).ToArray(); + missingTypes.ShouldBeEmpty(); + } } public class DummyMessage { } diff --git a/src/Wolverine/Runtime/Handlers/HandlerGraph.cs b/src/Wolverine/Runtime/Handlers/HandlerGraph.cs index ffda4a370..a9f926d90 100644 --- a/src/Wolverine/Runtime/Handlers/HandlerGraph.cs +++ b/src/Wolverine/Runtime/Handlers/HandlerGraph.cs @@ -46,6 +46,8 @@ public partial class HandlerGraph : ICodeFileCollectionWithServices, IWithFailur private bool _hasGrouped; + private object _messageTypesLock = new(); + private ImHashMap _messageTypes = ImHashMap.Empty; private ImmutableList _replyTypes = ImmutableList.Empty; @@ -334,24 +336,28 @@ private void tryApplyLocalQueueConfiguration(WolverineOptions options) private void registerMessageTypes() { - _messageTypes = - _messageTypes.AddOrUpdate(typeof(Acknowledgement).ToMessageTypeName(), typeof(Acknowledgement)); - - foreach (var chain in Chains) + lock (_messageTypesLock) { - _messageTypes = _messageTypes.AddOrUpdate(chain.MessageType.ToMessageTypeName(), chain.MessageType); + _messageTypes = + _messageTypes.AddOrUpdate(typeof(Acknowledgement).ToMessageTypeName(), typeof(Acknowledgement)); - if (chain.MessageType.TryGetAttribute(out var att)) + foreach (var chain in Chains) { - _messageTypes = _messageTypes.AddOrUpdate(att.InteropType.ToMessageTypeName(), chain.MessageType); - } - else - { - foreach (var @interface in chain.MessageType.GetInterfaces()) + _messageTypes = _messageTypes.AddOrUpdate(chain.MessageType.ToMessageTypeName(), chain.MessageType); + + if (chain.MessageType.TryGetAttribute(out var att)) { - if (InteropAssemblies.Contains(@interface.Assembly)) + _messageTypes = _messageTypes.AddOrUpdate(att.InteropType.ToMessageTypeName(), chain.MessageType); + } + else + { + foreach (var @interface in chain.MessageType.GetInterfaces()) { - _messageTypes = _messageTypes.AddOrUpdate(@interface.ToMessageTypeName(), chain.MessageType); + if (InteropAssemblies.Contains(@interface.Assembly)) + { + _messageTypes = + _messageTypes.AddOrUpdate(@interface.ToMessageTypeName(), chain.MessageType); + } } } } @@ -465,8 +471,11 @@ public void RegisterMessageType(Type messageType) return; } - _messageTypes = _messageTypes.AddOrUpdate(messageType.ToMessageTypeName(), messageType); - _replyTypes = _replyTypes.Add(messageType); + lock (_messageTypesLock) + { + _messageTypes = _messageTypes.AddOrUpdate(messageType.ToMessageTypeName(), messageType); + _replyTypes = _replyTypes.Add(messageType); + } } public void RegisterMessageType(Type messageType, string messageAlias) @@ -481,8 +490,11 @@ public void RegisterMessageType(Type messageType, string messageAlias) return; } - _messageTypes = _messageTypes.AddOrUpdate(messageAlias, messageType); - _replyTypes = _replyTypes.Add(messageType); + lock (_messageTypesLock) + { + _messageTypes = _messageTypes.AddOrUpdate(messageAlias, messageType); + _replyTypes = _replyTypes.Add(messageType); + } } public IEnumerable AllChains()