diff --git a/Mockly.Specs/HttpMockSpecs.cs b/Mockly.Specs/HttpMockSpecs.cs index 37e2339..717f9cf 100644 --- a/Mockly.Specs/HttpMockSpecs.cs +++ b/Mockly.Specs/HttpMockSpecs.cs @@ -6,6 +6,11 @@ using System.Net.Http.Headers; using System.Text; using System.Text.Json; +#if NET8_0_OR_GREATER +using System.Collections.Concurrent; +using System.Net.Http.Json; +using System.Threading; +#endif using System.Threading.Tasks; using FluentAssertions; using Xunit; @@ -2146,4 +2151,61 @@ public async Task Can_use_custom_options_to_match_request_body_as_object() response.StatusCode.Should().Be(HttpStatusCode.NoContent); } } + + /// + /// Tests for concurrency issues. + /// + /// + /// Only really proves that parallelism didn't work sometimes. + /// But doesn't necessarily prove that everything is thread safe now. + /// This test had to be run ~10 times without thread safety fixes before a race condition occurred. + /// + public class WhenInMultiThreadedContext + { +#if NET6_0_OR_GREATER + [Fact] + public async Task Can_handle_parallel_scenario() + { + // Arrange + var mock = new HttpMock(); + + var testData = new TestData + { + Id = 123, + Name = "Test" + }; + + mock.ForGet().WithPath("/api/data").RespondsWithJsonContent(testData); + + // Act + ConcurrentBag responses = []; + + var options = new ParallelOptions + { + CancellationToken = new CancellationToken(canceled: false), + MaxDegreeOfParallelism = 50, + }; + + var client = mock.GetClient(); + + await Parallel.ForAsync(0, 1000, options, async (_, token) => + { + var response = await client.GetFromJsonAsync("https://localhost/api/data", token); + responses.Add(response); + }); + + // Assert + responses.Count.Should().Be(1000); + responses.Should().AllBeEquivalentTo(testData); + } + + private class TestData + { + public int Id { get; init; } + + public string Name { get; init; } + } +#endif + } + } diff --git a/Mockly/RequestCollection.cs b/Mockly/RequestCollection.cs index 8fb47f7..08bc20b 100644 --- a/Mockly/RequestCollection.cs +++ b/Mockly/RequestCollection.cs @@ -1,4 +1,5 @@ using System.Collections; +using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; namespace Mockly; @@ -10,15 +11,20 @@ namespace Mockly; Justification = "Collection suffix is appropriate for this class")] public class RequestCollection : IEnumerable { - private readonly List requests = new(); + private readonly ConcurrentQueue requests = new(); + private readonly object addLock = new(); /// /// Adds a captured request to the collection. /// internal void Add(CapturedRequest request) { - requests.Add(request); - request.Sequence = requests.Count; + // Although the requests collection is thread-safe, using a lock ensures thread safety for sequence assignment. + lock (addLock) + { + requests.Enqueue(request); + request.Sequence = requests.Count; + } } /// @@ -29,7 +35,7 @@ internal void Add(CapturedRequest request) /// /// Checks if the collection is empty. /// - public bool IsEmpty => requests.Count == 0; + public bool IsEmpty => requests.IsEmpty; /// /// Checks if any unexpected requests were captured. diff --git a/Mockly/RequestMock.cs b/Mockly/RequestMock.cs index 7b05b25..fb530b0 100644 --- a/Mockly/RequestMock.cs +++ b/Mockly/RequestMock.cs @@ -14,6 +14,8 @@ namespace Mockly; public class RequestMock { private static readonly ConcurrentDictionary RegexCache = new(StringComparer.OrdinalIgnoreCase); + private readonly object hostNormalizationLock = new(); + private int invocationCount; private bool hostPatternNormalized; @@ -32,12 +34,12 @@ public class RequestMock public Func Responder { get; set; } = _ => new HttpResponseMessage(); - public RequestCollection? RequestCollection { get; init; } = new(); + public RequestCollection? RequestCollection { get; init; } = []; /// /// Gets a value determining how many times this mock has been invoked. /// - public int InvocationCount { get; private set; } + public int InvocationCount => Volatile.Read(ref invocationCount); /// /// Gets the maximum number of times this mock can be invoked. @@ -49,7 +51,7 @@ public class RequestMock /// Gets a value indicating whether this mock has been exhausted, i.e. /// it has been invoked at least times when that limit is set. /// - internal bool IsExhausted => MaxInvocations is not null && InvocationCount >= MaxInvocations; + internal bool IsExhausted => MaxInvocations is not null && Volatile.Read(ref invocationCount) >= MaxInvocations; /// /// Checks if this mock matches the given request. @@ -133,16 +135,19 @@ public async Task Matches(RequestInfo request) /// private void NormalizeHostPatternOnce() { - if (!hostPatternNormalized && HostPattern is not null && HostPattern != "*") + lock (hostNormalizationLock) { - string[] segments = HostPattern.Split(':'); - if (segments.Length == 1) + if (!hostPatternNormalized && HostPattern is not null && HostPattern != "*") { - HostPattern += Scheme!.Equals("https", StringComparison.OrdinalIgnoreCase) ? ":443" : ":80"; + string[] segments = HostPattern.Split(':'); + if (segments.Length == 1) + { + HostPattern += Scheme!.Equals("https", StringComparison.OrdinalIgnoreCase) ? ":443" : ":80"; + } } - } - hostPatternNormalized = true; + hostPatternNormalized = true; + } } /// @@ -251,7 +256,7 @@ private static bool MatchesPattern(string value, string pattern) /// public CapturedRequest TrackRequest(RequestInfo request) { - InvocationCount++; + Interlocked.Increment(ref invocationCount); CapturedRequest capturedRequest = new(request) {