Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion src/Testing/CoreTests/Runtime/Handlers/HandlerGraphTests.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
using System.Diagnostics;
using ImTools;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Module1;
using System.Diagnostics;
using System.Reflection;
using JasperFx.Core.Reflection;
using Wolverine.Attributes;
using Wolverine.Runtime;
using Wolverine.Runtime.Handlers;
using Wolverine.Util;
using Xunit;
Expand Down Expand Up @@ -65,6 +69,25 @@ await Should.ThrowAsync<InvalidOperationException>(async () =>
}).StartAsync();
});
}

[Fact]
public async Task Concurrent_Registration_No_Race_Condition()
{
// just make sure the list consists of at least a couple of messages
var typesToRegister = typeof(DummyMessage).Assembly.GetTypes().Where(x => x.Name.EndsWith("Message")).ToArray();

using var host = await Host.CreateDefaultBuilder()
.UseWolverine()
.StartAsync();

var graph = host.Services.GetRequiredService<HandlerGraph>();
var runtime = host.Services.GetRequiredService<IWolverineRuntime>();
Parallel.ForEach(typesToRegister, t => runtime.RegisterMessageType(t));

var missingTypes = typesToRegister.Select(t => t.ToMessageTypeName())
.Where(t => graph.TryFindMessageType(t, out var _) is false).ToArray();
missingTypes.ShouldBeEmpty();
}
}

public class DummyMessage { }
Expand Down
46 changes: 29 additions & 17 deletions src/Wolverine/Runtime/Handlers/HandlerGraph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ public partial class HandlerGraph : ICodeFileCollectionWithServices, IWithFailur

private bool _hasGrouped;

private object _messageTypesLock = new();

private ImHashMap<string, Type> _messageTypes = ImHashMap<string, Type>.Empty;

private ImmutableList<Type> _replyTypes = ImmutableList<Type>.Empty;
Expand Down Expand Up @@ -334,24 +336,28 @@ private void tryApplyLocalQueueConfiguration(WolverineOptions options)

private void registerMessageTypes()
{
_messageTypes =
_messageTypes.AddOrUpdate(typeof(Acknowledgement).ToMessageTypeName(), typeof(Acknowledgement));

foreach (var chain in Chains)
lock (_messageTypesLock)
{
_messageTypes = _messageTypes.AddOrUpdate(chain.MessageType.ToMessageTypeName(), chain.MessageType);
_messageTypes =
_messageTypes.AddOrUpdate(typeof(Acknowledgement).ToMessageTypeName(), typeof(Acknowledgement));

if (chain.MessageType.TryGetAttribute<InteropMessageAttribute>(out var att))
foreach (var chain in Chains)
{
_messageTypes = _messageTypes.AddOrUpdate(att.InteropType.ToMessageTypeName(), chain.MessageType);
}
else
{
foreach (var @interface in chain.MessageType.GetInterfaces())
_messageTypes = _messageTypes.AddOrUpdate(chain.MessageType.ToMessageTypeName(), chain.MessageType);

if (chain.MessageType.TryGetAttribute<InteropMessageAttribute>(out var att))
{
if (InteropAssemblies.Contains(@interface.Assembly))
_messageTypes = _messageTypes.AddOrUpdate(att.InteropType.ToMessageTypeName(), chain.MessageType);
}
else
{
foreach (var @interface in chain.MessageType.GetInterfaces())
{
_messageTypes = _messageTypes.AddOrUpdate(@interface.ToMessageTypeName(), chain.MessageType);
if (InteropAssemblies.Contains(@interface.Assembly))
{
_messageTypes =
_messageTypes.AddOrUpdate(@interface.ToMessageTypeName(), chain.MessageType);
}
}
}
}
Expand Down Expand Up @@ -465,8 +471,11 @@ public void RegisterMessageType(Type messageType)
return;
}

_messageTypes = _messageTypes.AddOrUpdate(messageType.ToMessageTypeName(), messageType);
_replyTypes = _replyTypes.Add(messageType);
lock (_messageTypesLock)
{
_messageTypes = _messageTypes.AddOrUpdate(messageType.ToMessageTypeName(), messageType);
_replyTypes = _replyTypes.Add(messageType);
}
}

public void RegisterMessageType(Type messageType, string messageAlias)
Expand All @@ -481,8 +490,11 @@ public void RegisterMessageType(Type messageType, string messageAlias)
return;
}

_messageTypes = _messageTypes.AddOrUpdate(messageAlias, messageType);
_replyTypes = _replyTypes.Add(messageType);
lock (_messageTypesLock)
{
_messageTypes = _messageTypes.AddOrUpdate(messageAlias, messageType);
_replyTypes = _replyTypes.Add(messageType);
}
}

public IEnumerable<HandlerChain> AllChains()
Expand Down
Loading