Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 42 additions & 19 deletions TUnit.Mocks/MockEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ public sealed class MockEngine<T> : IMockEngineAccess where T : class
private readonly Dictionary<int, List<MethodSetup>> _setupsByMember = new();
private readonly Lock _setupLock = new();
private readonly ConcurrentQueue<CallRecord> _callHistory = new();
private readonly ConcurrentDictionary<string, object?> _autoTrackValues = new();
private readonly ConcurrentQueue<(string EventName, bool IsSubscribe)> _eventSubscriptions = new();
private readonly ConcurrentDictionary<string, Action> _onSubscribeCallbacks = new();
private readonly ConcurrentDictionary<string, Action> _onUnsubscribeCallbacks = new();
private readonly ConcurrentDictionary<string, IMock?> _autoMockCache = new();

private volatile ConcurrentDictionary<string, object?>? _autoTrackValues;
private volatile ConcurrentQueue<(string EventName, bool IsSubscribe)>? _eventSubscriptions;
private volatile ConcurrentDictionary<string, Action>? _onSubscribeCallbacks;
private volatile ConcurrentDictionary<string, Action>? _onUnsubscribeCallbacks;
private volatile ConcurrentDictionary<string, IMock?>? _autoMockCache;

/// <summary>
/// The current state name for state machine mocking. Null means no state (all setups match).
Expand Down Expand Up @@ -77,6 +78,21 @@ public MockEngine(MockBehavior behavior)
Behavior = behavior;
}

private ConcurrentDictionary<string, object?> AutoTrackValues
=> _autoTrackValues ?? Interlocked.CompareExchange(ref _autoTrackValues, new(), null) ?? _autoTrackValues;

private ConcurrentQueue<(string EventName, bool IsSubscribe)> EventSubscriptions
=> _eventSubscriptions ?? Interlocked.CompareExchange(ref _eventSubscriptions, new(), null) ?? _eventSubscriptions;

private ConcurrentDictionary<string, Action> OnSubscribeCallbacks
=> _onSubscribeCallbacks ?? Interlocked.CompareExchange(ref _onSubscribeCallbacks, new(), null) ?? _onSubscribeCallbacks;

private ConcurrentDictionary<string, Action> OnUnsubscribeCallbacks
=> _onUnsubscribeCallbacks ?? Interlocked.CompareExchange(ref _onUnsubscribeCallbacks, new(), null) ?? _onUnsubscribeCallbacks;

private ConcurrentDictionary<string, IMock?> AutoMockCache
=> _autoMockCache ?? Interlocked.CompareExchange(ref _autoMockCache, new(), null) ?? _autoMockCache;

/// <summary>
/// Transitions the engine to the specified state. Null clears the state.
/// </summary>
Expand Down Expand Up @@ -125,7 +141,7 @@ public void HandleCall(int memberId, string memberName, object?[] args)
// Auto-track property setters: store value keyed by property name
if (AutoTrackProperties && memberName.StartsWith("set_", StringComparison.Ordinal) && args.Length > 0)
{
_autoTrackValues[memberName.Substring(4)] = args[0];
AutoTrackValues[memberName[4..]] = args[0];
}

var (setupFound, behavior, matchedSetup) = FindMatchingSetup(memberId, args);
Expand Down Expand Up @@ -211,9 +227,9 @@ public TReturn HandleCallWithReturn<TReturn>(int memberId, string memberName, ob
callRecord.IsUnmatched = true;

// Auto-track property getters: return stored value if available
if (AutoTrackProperties && memberName.StartsWith("get_", StringComparison.Ordinal))
if (AutoTrackProperties && _autoTrackValues is not null && memberName.StartsWith("get_", StringComparison.Ordinal))
{
if (_autoTrackValues.TryGetValue(memberName.Substring(4), out var trackedValue))
if (_autoTrackValues.TryGetValue(memberName[4..], out var trackedValue))
{
if (trackedValue is TReturn typed) return typed;
if (trackedValue is null) return default(TReturn)!;
Expand All @@ -236,7 +252,7 @@ public TReturn HandleCallWithReturn<TReturn>(int memberId, string memberName, ob
if (Behavior == MockBehavior.Loose && typeof(TReturn).IsInterface)
{
var cacheKey = memberName + "|" + typeof(TReturn).FullName;
var autoMock = _autoMockCache.GetOrAdd(cacheKey, _ =>
var autoMock = AutoMockCache.GetOrAdd(cacheKey, _ =>
{
Mock.TryCreateAutoMock(typeof(TReturn), Behavior, out var m);
return m;
Expand Down Expand Up @@ -465,6 +481,11 @@ public Diagnostics.MockDiagnostics GetDiagnostics()
[EditorBrowsable(EditorBrowsableState.Never)]
public bool TryGetAutoMock(string cacheKey, [System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out IMock? mock)
{
if (_autoMockCache is null)
{
mock = null;
return false;
}
return _autoMockCache.TryGetValue(cacheKey, out mock);
}

Expand All @@ -483,11 +504,11 @@ public void Reset()
// Drain the queue
while (_callHistory.TryDequeue(out _)) { }

_autoTrackValues.Clear();
while (_eventSubscriptions.TryDequeue(out _)) { }
_onSubscribeCallbacks.Clear();
_onUnsubscribeCallbacks.Clear();
_autoMockCache.Clear();
_autoTrackValues = null;
_eventSubscriptions = null;
_onSubscribeCallbacks = null;
_onUnsubscribeCallbacks = null;
_autoMockCache = null;
}

/// <summary>
Expand All @@ -496,7 +517,7 @@ public void Reset()
[EditorBrowsable(EditorBrowsableState.Never)]
public void OnSubscribe(string eventName, Action callback)
{
_onSubscribeCallbacks[eventName] = callback;
OnSubscribeCallbacks[eventName] = callback;
}

/// <summary>
Expand All @@ -505,7 +526,7 @@ public void OnSubscribe(string eventName, Action callback)
[EditorBrowsable(EditorBrowsableState.Never)]
public void OnUnsubscribe(string eventName, Action callback)
{
_onUnsubscribeCallbacks[eventName] = callback;
OnUnsubscribeCallbacks[eventName] = callback;
}

/// <summary>
Expand All @@ -514,18 +535,18 @@ public void OnUnsubscribe(string eventName, Action callback)
[EditorBrowsable(EditorBrowsableState.Never)]
public void RecordEventSubscription(string eventName, bool isSubscribe)
{
_eventSubscriptions.Enqueue((eventName, isSubscribe));
EventSubscriptions.Enqueue((eventName, isSubscribe));

if (isSubscribe)
{
if (_onSubscribeCallbacks.TryGetValue(eventName, out var callback))
if (_onSubscribeCallbacks is not null && _onSubscribeCallbacks.TryGetValue(eventName, out var callback))
{
callback();
}
}
else
{
if (_onUnsubscribeCallbacks.TryGetValue(eventName, out var callback))
if (_onUnsubscribeCallbacks is not null && _onUnsubscribeCallbacks.TryGetValue(eventName, out var callback))
{
callback();
}
Expand All @@ -538,6 +559,7 @@ public void RecordEventSubscription(string eventName, bool isSubscribe)
[EditorBrowsable(EditorBrowsableState.Never)]
public int GetEventSubscriberCount(string eventName)
{
if (_eventSubscriptions is null) return 0;
int count = 0;
foreach (var (name, isSub) in _eventSubscriptions)
{
Expand All @@ -555,6 +577,7 @@ public int GetEventSubscriberCount(string eventName)
[EditorBrowsable(EditorBrowsableState.Never)]
public bool WasEventSubscribed(string eventName)
{
if (_eventSubscriptions is null) return false;
foreach (var (name, isSub) in _eventSubscriptions)
{
if (name == eventName && isSub) return true;
Expand Down
Loading