Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions TUnit.Mocks/MockEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using TUnit.Mocks.Setup.Behaviors;
using TUnit.Mocks.Verification;
using System.Collections.Concurrent;
using System.Threading;
using System.ComponentModel;

namespace TUnit.Mocks;
Expand All @@ -24,11 +25,12 @@ public sealed class MockEngine<T> : IMockEngineAccess where T : class
private readonly Dictionary<int, List<MethodSetup>> _setupsByMember = new();
private readonly Lock _setupLock = new();
private readonly ConcurrentQueue<CallRecord> _callHistory = new();
private readonly ConcurrentDictionary<string, object?> _autoTrackValues = new();
private readonly ConcurrentQueue<(string EventName, bool IsSubscribe)> _eventSubscriptions = new();
private readonly ConcurrentDictionary<string, Action> _onSubscribeCallbacks = new();
private readonly ConcurrentDictionary<string, Action> _onUnsubscribeCallbacks = new();
private readonly ConcurrentDictionary<string, IMock?> _autoMockCache = new();

private ConcurrentDictionary<string, object?>? _autoTrackValues;
private ConcurrentQueue<(string EventName, bool IsSubscribe)>? _eventSubscriptions;
private ConcurrentDictionary<string, Action>? _onSubscribeCallbacks;
private ConcurrentDictionary<string, Action>? _onUnsubscribeCallbacks;
private ConcurrentDictionary<string, IMock?>? _autoMockCache;

/// <summary>
/// The current state name for state machine mocking. Null means no state (all setups match).
Expand Down Expand Up @@ -77,6 +79,21 @@ public MockEngine(MockBehavior behavior)
Behavior = behavior;
}

private ConcurrentDictionary<string, object?> AutoTrackValues
=> LazyInitializer.EnsureInitialized(ref _autoTrackValues)!;

private ConcurrentQueue<(string EventName, bool IsSubscribe)> EventSubscriptions
=> LazyInitializer.EnsureInitialized(ref _eventSubscriptions)!;

private ConcurrentDictionary<string, Action> OnSubscribeCallbacks
=> LazyInitializer.EnsureInitialized(ref _onSubscribeCallbacks)!;

private ConcurrentDictionary<string, Action> OnUnsubscribeCallbacks
=> LazyInitializer.EnsureInitialized(ref _onUnsubscribeCallbacks)!;

private ConcurrentDictionary<string, IMock?> AutoMockCache
=> LazyInitializer.EnsureInitialized(ref _autoMockCache)!;

/// <summary>
/// Transitions the engine to the specified state. Null clears the state.
/// </summary>
Expand Down Expand Up @@ -125,7 +142,7 @@ public void HandleCall(int memberId, string memberName, object?[] args)
// Auto-track property setters: store value keyed by property name
if (AutoTrackProperties && memberName.StartsWith("set_", StringComparison.Ordinal) && args.Length > 0)
{
_autoTrackValues[memberName.Substring(4)] = args[0];
AutoTrackValues[memberName[4..]] = args[0];
}

var (setupFound, behavior, matchedSetup) = FindMatchingSetup(memberId, args);
Expand Down Expand Up @@ -211,9 +228,9 @@ public TReturn HandleCallWithReturn<TReturn>(int memberId, string memberName, ob
callRecord.IsUnmatched = true;

// Auto-track property getters: return stored value if available
if (AutoTrackProperties && memberName.StartsWith("get_", StringComparison.Ordinal))
if (AutoTrackProperties && Volatile.Read(ref _autoTrackValues) is { } trackValues && memberName.StartsWith("get_", StringComparison.Ordinal))
{
if (_autoTrackValues.TryGetValue(memberName.Substring(4), out var trackedValue))
if (trackValues.TryGetValue(memberName[4..], out var trackedValue))
{
if (trackedValue is TReturn typed) return typed;
if (trackedValue is null) return default(TReturn)!;
Expand All @@ -236,7 +253,7 @@ public TReturn HandleCallWithReturn<TReturn>(int memberId, string memberName, ob
if (Behavior == MockBehavior.Loose && typeof(TReturn).IsInterface)
{
var cacheKey = memberName + "|" + typeof(TReturn).FullName;
var autoMock = _autoMockCache.GetOrAdd(cacheKey, _ =>
var autoMock = AutoMockCache.GetOrAdd(cacheKey, _ =>
{
Mock.TryCreateAutoMock(typeof(TReturn), Behavior, out var m);
return m;
Expand Down Expand Up @@ -465,7 +482,12 @@ public Diagnostics.MockDiagnostics GetDiagnostics()
[EditorBrowsable(EditorBrowsableState.Never)]
public bool TryGetAutoMock(string cacheKey, [System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out IMock? mock)
{
return _autoMockCache.TryGetValue(cacheKey, out mock);
if (Volatile.Read(ref _autoMockCache) is not { } cache)
{
mock = null;
return false;
}
return cache.TryGetValue(cacheKey, out mock);
}

/// <summary>
Expand All @@ -483,11 +505,11 @@ public void Reset()
// Drain the queue
while (_callHistory.TryDequeue(out _)) { }

_autoTrackValues.Clear();
while (_eventSubscriptions.TryDequeue(out _)) { }
_onSubscribeCallbacks.Clear();
_onUnsubscribeCallbacks.Clear();
_autoMockCache.Clear();
Volatile.Write(ref _autoTrackValues, null);
Volatile.Write(ref _eventSubscriptions, null);
Volatile.Write(ref _onSubscribeCallbacks, null);
Volatile.Write(ref _onUnsubscribeCallbacks, null);
Volatile.Write(ref _autoMockCache, null);
}

/// <summary>
Expand All @@ -496,7 +518,7 @@ public void Reset()
[EditorBrowsable(EditorBrowsableState.Never)]
public void OnSubscribe(string eventName, Action callback)
{
_onSubscribeCallbacks[eventName] = callback;
OnSubscribeCallbacks[eventName] = callback;
}

/// <summary>
Expand All @@ -505,7 +527,7 @@ public void OnSubscribe(string eventName, Action callback)
[EditorBrowsable(EditorBrowsableState.Never)]
public void OnUnsubscribe(string eventName, Action callback)
{
_onUnsubscribeCallbacks[eventName] = callback;
OnUnsubscribeCallbacks[eventName] = callback;
}

/// <summary>
Expand All @@ -514,18 +536,18 @@ public void OnUnsubscribe(string eventName, Action callback)
[EditorBrowsable(EditorBrowsableState.Never)]
public void RecordEventSubscription(string eventName, bool isSubscribe)
{
_eventSubscriptions.Enqueue((eventName, isSubscribe));
EventSubscriptions.Enqueue((eventName, isSubscribe));

if (isSubscribe)
{
if (_onSubscribeCallbacks.TryGetValue(eventName, out var callback))
if (Volatile.Read(ref _onSubscribeCallbacks) is { } subCallbacks && subCallbacks.TryGetValue(eventName, out var callback))
{
callback();
}
}
else
{
if (_onUnsubscribeCallbacks.TryGetValue(eventName, out var callback))
if (Volatile.Read(ref _onUnsubscribeCallbacks) is { } unsubCallbacks && unsubCallbacks.TryGetValue(eventName, out var callback))
{
callback();
}
Expand All @@ -538,8 +560,9 @@ public void RecordEventSubscription(string eventName, bool isSubscribe)
[EditorBrowsable(EditorBrowsableState.Never)]
public int GetEventSubscriberCount(string eventName)
{
if (Volatile.Read(ref _eventSubscriptions) is not { } subs) return 0;
int count = 0;
foreach (var (name, isSub) in _eventSubscriptions)
foreach (var (name, isSub) in subs)
{
if (name == eventName)
{
Expand All @@ -555,7 +578,8 @@ public int GetEventSubscriberCount(string eventName)
[EditorBrowsable(EditorBrowsableState.Never)]
public bool WasEventSubscribed(string eventName)
{
foreach (var (name, isSub) in _eventSubscriptions)
if (Volatile.Read(ref _eventSubscriptions) is not { } subs) return false;
foreach (var (name, isSub) in subs)
{
if (name == eventName && isSub) return true;
}
Expand Down
Loading