From 24fff8987d00eb2f44f60861cbdc687c8ccb98db Mon Sep 17 00:00:00 2001 From: RoccoDevs Date: Mon, 30 Mar 2026 15:26:05 +0200 Subject: [PATCH] Ensure thread safety for concurrent request handling Addresses race conditions when multiple requests are processed simultaneously by adding locking to request collection and host normalization logic. Also ensures atomic updates to invocation counters and adds a parallel test scenario to verify multi-threaded behavior. --- Mockly.Specs/HttpMockSpecs.cs | 62 +++++++++++++++++++++++++++++++++++ Mockly/RequestCollection.cs | 14 +++++--- Mockly/RequestMock.cs | 25 ++++++++------ 3 files changed, 87 insertions(+), 14 deletions(-) diff --git a/Mockly.Specs/HttpMockSpecs.cs b/Mockly.Specs/HttpMockSpecs.cs index 37e2339..717f9cf 100644 --- a/Mockly.Specs/HttpMockSpecs.cs +++ b/Mockly.Specs/HttpMockSpecs.cs @@ -6,6 +6,11 @@ using System.Net.Http.Headers; using System.Text; using System.Text.Json; +#if NET8_0_OR_GREATER +using System.Collections.Concurrent; +using System.Net.Http.Json; +using System.Threading; +#endif using System.Threading.Tasks; using FluentAssertions; using Xunit; @@ -2146,4 +2151,61 @@ public async Task Can_use_custom_options_to_match_request_body_as_object() response.StatusCode.Should().Be(HttpStatusCode.NoContent); } } + + /// + /// Tests for concurrency issues. + /// + /// + /// Only really proves that parallelism didn't work sometimes. + /// But doesn't necessarily prove that everything is thread safe now. + /// This test had to be run ~10 times without thread safety fixes before a race condition occurred. + /// + public class WhenInMultiThreadedContext + { +#if NET6_0_OR_GREATER + [Fact] + public async Task Can_handle_parallel_scenario() + { + // Arrange + var mock = new HttpMock(); + + var testData = new TestData + { + Id = 123, + Name = "Test" + }; + + mock.ForGet().WithPath("/api/data").RespondsWithJsonContent(testData); + + // Act + ConcurrentBag responses = []; + + var options = new ParallelOptions + { + CancellationToken = new CancellationToken(canceled: false), + MaxDegreeOfParallelism = 50, + }; + + var client = mock.GetClient(); + + await Parallel.ForAsync(0, 1000, options, async (_, token) => + { + var response = await client.GetFromJsonAsync("https://localhost/api/data", token); + responses.Add(response); + }); + + // Assert + responses.Count.Should().Be(1000); + responses.Should().AllBeEquivalentTo(testData); + } + + private class TestData + { + public int Id { get; init; } + + public string Name { get; init; } + } +#endif + } + } diff --git a/Mockly/RequestCollection.cs b/Mockly/RequestCollection.cs index 8fb47f7..08bc20b 100644 --- a/Mockly/RequestCollection.cs +++ b/Mockly/RequestCollection.cs @@ -1,4 +1,5 @@ using System.Collections; +using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; namespace Mockly; @@ -10,15 +11,20 @@ namespace Mockly; Justification = "Collection suffix is appropriate for this class")] public class RequestCollection : IEnumerable { - private readonly List requests = new(); + private readonly ConcurrentQueue requests = new(); + private readonly object addLock = new(); /// /// Adds a captured request to the collection. /// internal void Add(CapturedRequest request) { - requests.Add(request); - request.Sequence = requests.Count; + // Although the requests collection is thread-safe, using a lock ensures thread safety for sequence assignment. + lock (addLock) + { + requests.Enqueue(request); + request.Sequence = requests.Count; + } } /// @@ -29,7 +35,7 @@ internal void Add(CapturedRequest request) /// /// Checks if the collection is empty. /// - public bool IsEmpty => requests.Count == 0; + public bool IsEmpty => requests.IsEmpty; /// /// Checks if any unexpected requests were captured. diff --git a/Mockly/RequestMock.cs b/Mockly/RequestMock.cs index 7b05b25..fb530b0 100644 --- a/Mockly/RequestMock.cs +++ b/Mockly/RequestMock.cs @@ -14,6 +14,8 @@ namespace Mockly; public class RequestMock { private static readonly ConcurrentDictionary RegexCache = new(StringComparer.OrdinalIgnoreCase); + private readonly object hostNormalizationLock = new(); + private int invocationCount; private bool hostPatternNormalized; @@ -32,12 +34,12 @@ public class RequestMock public Func Responder { get; set; } = _ => new HttpResponseMessage(); - public RequestCollection? RequestCollection { get; init; } = new(); + public RequestCollection? RequestCollection { get; init; } = []; /// /// Gets a value determining how many times this mock has been invoked. /// - public int InvocationCount { get; private set; } + public int InvocationCount => Volatile.Read(ref invocationCount); /// /// Gets the maximum number of times this mock can be invoked. @@ -49,7 +51,7 @@ public class RequestMock /// Gets a value indicating whether this mock has been exhausted, i.e. /// it has been invoked at least times when that limit is set. /// - internal bool IsExhausted => MaxInvocations is not null && InvocationCount >= MaxInvocations; + internal bool IsExhausted => MaxInvocations is not null && Volatile.Read(ref invocationCount) >= MaxInvocations; /// /// Checks if this mock matches the given request. @@ -133,16 +135,19 @@ public async Task Matches(RequestInfo request) /// private void NormalizeHostPatternOnce() { - if (!hostPatternNormalized && HostPattern is not null && HostPattern != "*") + lock (hostNormalizationLock) { - string[] segments = HostPattern.Split(':'); - if (segments.Length == 1) + if (!hostPatternNormalized && HostPattern is not null && HostPattern != "*") { - HostPattern += Scheme!.Equals("https", StringComparison.OrdinalIgnoreCase) ? ":443" : ":80"; + string[] segments = HostPattern.Split(':'); + if (segments.Length == 1) + { + HostPattern += Scheme!.Equals("https", StringComparison.OrdinalIgnoreCase) ? ":443" : ":80"; + } } - } - hostPatternNormalized = true; + hostPatternNormalized = true; + } } /// @@ -251,7 +256,7 @@ private static bool MatchesPattern(string value, string pattern) /// public CapturedRequest TrackRequest(RequestInfo request) { - InvocationCount++; + Interlocked.Increment(ref invocationCount); CapturedRequest capturedRequest = new(request) {