diff --git a/TUnit.Mocks/MockEngine.cs b/TUnit.Mocks/MockEngine.cs index bdeefe6f50..b431cd82ab 100644 --- a/TUnit.Mocks/MockEngine.cs +++ b/TUnit.Mocks/MockEngine.cs @@ -4,6 +4,7 @@ using TUnit.Mocks.Setup.Behaviors; using TUnit.Mocks.Verification; using System.Collections.Concurrent; +using System.Threading; using System.ComponentModel; namespace TUnit.Mocks; @@ -24,11 +25,12 @@ public sealed class MockEngine : IMockEngineAccess where T : class private readonly Dictionary> _setupsByMember = new(); private readonly Lock _setupLock = new(); private readonly ConcurrentQueue _callHistory = new(); - private readonly ConcurrentDictionary _autoTrackValues = new(); - private readonly ConcurrentQueue<(string EventName, bool IsSubscribe)> _eventSubscriptions = new(); - private readonly ConcurrentDictionary _onSubscribeCallbacks = new(); - private readonly ConcurrentDictionary _onUnsubscribeCallbacks = new(); - private readonly ConcurrentDictionary _autoMockCache = new(); + + private ConcurrentDictionary? _autoTrackValues; + private ConcurrentQueue<(string EventName, bool IsSubscribe)>? _eventSubscriptions; + private ConcurrentDictionary? _onSubscribeCallbacks; + private ConcurrentDictionary? _onUnsubscribeCallbacks; + private ConcurrentDictionary? _autoMockCache; /// /// The current state name for state machine mocking. Null means no state (all setups match). @@ -77,6 +79,21 @@ public MockEngine(MockBehavior behavior) Behavior = behavior; } + private ConcurrentDictionary AutoTrackValues + => LazyInitializer.EnsureInitialized(ref _autoTrackValues)!; + + private ConcurrentQueue<(string EventName, bool IsSubscribe)> EventSubscriptions + => LazyInitializer.EnsureInitialized(ref _eventSubscriptions)!; + + private ConcurrentDictionary OnSubscribeCallbacks + => LazyInitializer.EnsureInitialized(ref _onSubscribeCallbacks)!; + + private ConcurrentDictionary OnUnsubscribeCallbacks + => LazyInitializer.EnsureInitialized(ref _onUnsubscribeCallbacks)!; + + private ConcurrentDictionary AutoMockCache + => LazyInitializer.EnsureInitialized(ref _autoMockCache)!; + /// /// Transitions the engine to the specified state. Null clears the state. /// @@ -125,7 +142,7 @@ public void HandleCall(int memberId, string memberName, object?[] args) // Auto-track property setters: store value keyed by property name if (AutoTrackProperties && memberName.StartsWith("set_", StringComparison.Ordinal) && args.Length > 0) { - _autoTrackValues[memberName.Substring(4)] = args[0]; + AutoTrackValues[memberName[4..]] = args[0]; } var (setupFound, behavior, matchedSetup) = FindMatchingSetup(memberId, args); @@ -211,9 +228,9 @@ public TReturn HandleCallWithReturn(int memberId, string memberName, ob callRecord.IsUnmatched = true; // Auto-track property getters: return stored value if available - if (AutoTrackProperties && memberName.StartsWith("get_", StringComparison.Ordinal)) + if (AutoTrackProperties && Volatile.Read(ref _autoTrackValues) is { } trackValues && memberName.StartsWith("get_", StringComparison.Ordinal)) { - if (_autoTrackValues.TryGetValue(memberName.Substring(4), out var trackedValue)) + if (trackValues.TryGetValue(memberName[4..], out var trackedValue)) { if (trackedValue is TReturn typed) return typed; if (trackedValue is null) return default(TReturn)!; @@ -236,7 +253,7 @@ public TReturn HandleCallWithReturn(int memberId, string memberName, ob if (Behavior == MockBehavior.Loose && typeof(TReturn).IsInterface) { var cacheKey = memberName + "|" + typeof(TReturn).FullName; - var autoMock = _autoMockCache.GetOrAdd(cacheKey, _ => + var autoMock = AutoMockCache.GetOrAdd(cacheKey, _ => { Mock.TryCreateAutoMock(typeof(TReturn), Behavior, out var m); return m; @@ -465,7 +482,12 @@ public Diagnostics.MockDiagnostics GetDiagnostics() [EditorBrowsable(EditorBrowsableState.Never)] public bool TryGetAutoMock(string cacheKey, [System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out IMock? mock) { - return _autoMockCache.TryGetValue(cacheKey, out mock); + if (Volatile.Read(ref _autoMockCache) is not { } cache) + { + mock = null; + return false; + } + return cache.TryGetValue(cacheKey, out mock); } /// @@ -483,11 +505,11 @@ public void Reset() // Drain the queue while (_callHistory.TryDequeue(out _)) { } - _autoTrackValues.Clear(); - while (_eventSubscriptions.TryDequeue(out _)) { } - _onSubscribeCallbacks.Clear(); - _onUnsubscribeCallbacks.Clear(); - _autoMockCache.Clear(); + Volatile.Write(ref _autoTrackValues, null); + Volatile.Write(ref _eventSubscriptions, null); + Volatile.Write(ref _onSubscribeCallbacks, null); + Volatile.Write(ref _onUnsubscribeCallbacks, null); + Volatile.Write(ref _autoMockCache, null); } /// @@ -496,7 +518,7 @@ public void Reset() [EditorBrowsable(EditorBrowsableState.Never)] public void OnSubscribe(string eventName, Action callback) { - _onSubscribeCallbacks[eventName] = callback; + OnSubscribeCallbacks[eventName] = callback; } /// @@ -505,7 +527,7 @@ public void OnSubscribe(string eventName, Action callback) [EditorBrowsable(EditorBrowsableState.Never)] public void OnUnsubscribe(string eventName, Action callback) { - _onUnsubscribeCallbacks[eventName] = callback; + OnUnsubscribeCallbacks[eventName] = callback; } /// @@ -514,18 +536,18 @@ public void OnUnsubscribe(string eventName, Action callback) [EditorBrowsable(EditorBrowsableState.Never)] public void RecordEventSubscription(string eventName, bool isSubscribe) { - _eventSubscriptions.Enqueue((eventName, isSubscribe)); + EventSubscriptions.Enqueue((eventName, isSubscribe)); if (isSubscribe) { - if (_onSubscribeCallbacks.TryGetValue(eventName, out var callback)) + if (Volatile.Read(ref _onSubscribeCallbacks) is { } subCallbacks && subCallbacks.TryGetValue(eventName, out var callback)) { callback(); } } else { - if (_onUnsubscribeCallbacks.TryGetValue(eventName, out var callback)) + if (Volatile.Read(ref _onUnsubscribeCallbacks) is { } unsubCallbacks && unsubCallbacks.TryGetValue(eventName, out var callback)) { callback(); } @@ -538,8 +560,9 @@ public void RecordEventSubscription(string eventName, bool isSubscribe) [EditorBrowsable(EditorBrowsableState.Never)] public int GetEventSubscriberCount(string eventName) { + if (Volatile.Read(ref _eventSubscriptions) is not { } subs) return 0; int count = 0; - foreach (var (name, isSub) in _eventSubscriptions) + foreach (var (name, isSub) in subs) { if (name == eventName) { @@ -555,7 +578,8 @@ public int GetEventSubscriberCount(string eventName) [EditorBrowsable(EditorBrowsableState.Never)] public bool WasEventSubscribed(string eventName) { - foreach (var (name, isSub) in _eventSubscriptions) + if (Volatile.Read(ref _eventSubscriptions) is not { } subs) return false; + foreach (var (name, isSub) in subs) { if (name == eventName && isSub) return true; }