Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions Mockly.Specs/HttpMockSpecs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
using System.Net.Http.Headers;
using System.Text;
using System.Text.Json;
#if NET8_0_OR_GREATER
using System.Collections.Concurrent;
using System.Net.Http.Json;
using System.Threading;
Comment thread Fixed
Comment thread
RoccoDevs marked this conversation as resolved.
#endif
using System.Threading.Tasks;
using FluentAssertions;
using Xunit;
Expand Down Expand Up @@ -2146,4 +2151,61 @@ public async Task Can_use_custom_options_to_match_request_body_as_object()
response.StatusCode.Should().Be(HttpStatusCode.NoContent);
}
}

/// <summary>
/// Tests for concurrency issues.
/// </summary>
/// <remarks>
/// Only really proves that parallelism didn't work sometimes.
/// But doesn't necessarily prove that everything is thread safe now.
/// This test had to be run ~10 times without thread safety fixes before a race condition occurred.
Comment thread
RoccoDevs marked this conversation as resolved.
/// </remarks>
public class WhenInMultiThreadedContext
{
#if NET6_0_OR_GREATER
[Fact]
public async Task Can_handle_parallel_scenario()
{
Comment thread
RoccoDevs marked this conversation as resolved.
// Arrange
var mock = new HttpMock();

var testData = new TestData
{
Id = 123,
Name = "Test"
};

mock.ForGet().WithPath("/api/data").RespondsWithJsonContent(testData);

// Act
ConcurrentBag<TestData> responses = [];

var options = new ParallelOptions
{
CancellationToken = new CancellationToken(canceled: false),
MaxDegreeOfParallelism = 50,
};

var client = mock.GetClient();

await Parallel.ForAsync(0, 1000, options, async (_, token) =>
{
var response = await client.GetFromJsonAsync<TestData>("https://localhost/api/data", token);
responses.Add(response);
});

// Assert
responses.Count.Should().Be(1000);
responses.Should().AllBeEquivalentTo(testData);
}

private class TestData
{
public int Id { get; init; }

public string Name { get; init; }
}
#endif
}

}
14 changes: 10 additions & 4 deletions Mockly/RequestCollection.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections;
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;

namespace Mockly;
Expand All @@ -10,15 +11,20 @@ namespace Mockly;
Justification = "Collection suffix is appropriate for this class")]
public class RequestCollection : IEnumerable<CapturedRequest>
{
private readonly List<CapturedRequest> requests = new();
private readonly ConcurrentQueue<CapturedRequest> requests = new();
private readonly object addLock = new();

/// <summary>
/// Adds a captured request to the collection.
/// </summary>
internal void Add(CapturedRequest request)
{
requests.Add(request);
request.Sequence = requests.Count;
// Although the requests collection is thread-safe, using a lock ensures thread safety for sequence assignment.
lock (addLock)
{
requests.Enqueue(request);
request.Sequence = requests.Count;
}
Comment thread
RoccoDevs marked this conversation as resolved.
}

/// <summary>
Expand All @@ -29,7 +35,7 @@ internal void Add(CapturedRequest request)
/// <summary>
/// Checks if the collection is empty.
/// </summary>
public bool IsEmpty => requests.Count == 0;
public bool IsEmpty => requests.IsEmpty;

/// <summary>
/// Checks if any unexpected requests were captured.
Expand Down
25 changes: 15 additions & 10 deletions Mockly/RequestMock.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ namespace Mockly;
public class RequestMock
{
private static readonly ConcurrentDictionary<string, Regex> RegexCache = new(StringComparer.OrdinalIgnoreCase);
private readonly object hostNormalizationLock = new();
private int invocationCount;

private bool hostPatternNormalized;

Expand All @@ -32,12 +34,12 @@ public class RequestMock

public Func<RequestInfo, HttpResponseMessage> Responder { get; set; } = _ => new HttpResponseMessage();

public RequestCollection? RequestCollection { get; init; } = new();
public RequestCollection? RequestCollection { get; init; } = [];

/// <summary>
/// Gets a value determining how many times this mock has been invoked.
/// </summary>
public int InvocationCount { get; private set; }
public int InvocationCount => Volatile.Read(ref invocationCount);

/// <summary>
/// Gets the maximum number of times this mock can be invoked.
Expand All @@ -49,7 +51,7 @@ public class RequestMock
/// Gets a value indicating whether this mock has been exhausted, i.e.
/// it has been invoked at least <see cref="MaxInvocations"/> times when that limit is set.
/// </summary>
internal bool IsExhausted => MaxInvocations is not null && InvocationCount >= MaxInvocations;
internal bool IsExhausted => MaxInvocations is not null && Volatile.Read(ref invocationCount) >= MaxInvocations;

/// <summary>
/// Checks if this mock matches the given request.
Expand Down Expand Up @@ -133,16 +135,19 @@ public async Task<bool> Matches(RequestInfo request)
/// </summary>
private void NormalizeHostPatternOnce()
{
if (!hostPatternNormalized && HostPattern is not null && HostPattern != "*")
lock (hostNormalizationLock)
{
string[] segments = HostPattern.Split(':');
if (segments.Length == 1)
if (!hostPatternNormalized && HostPattern is not null && HostPattern != "*")
{
HostPattern += Scheme!.Equals("https", StringComparison.OrdinalIgnoreCase) ? ":443" : ":80";
string[] segments = HostPattern.Split(':');
if (segments.Length == 1)
{
HostPattern += Scheme!.Equals("https", StringComparison.OrdinalIgnoreCase) ? ":443" : ":80";
}
Comment thread
RoccoDevs marked this conversation as resolved.
}
}

hostPatternNormalized = true;
hostPatternNormalized = true;
}
}

/// <summary>
Expand Down Expand Up @@ -251,7 +256,7 @@ private static bool MatchesPattern(string value, string pattern)
/// </summary>
public CapturedRequest TrackRequest(RequestInfo request)
{
InvocationCount++;
Interlocked.Increment(ref invocationCount);
Comment thread
RoccoDevs marked this conversation as resolved.

CapturedRequest capturedRequest = new(request)
{
Expand Down
Loading