diff --git a/Sources/MockingKit/Mock.swift b/Sources/MockingKit/Mock.swift index f1c42cc..1981390 100644 --- a/Sources/MockingKit/Mock.swift +++ b/Sources/MockingKit/Mock.swift @@ -16,7 +16,7 @@ import Foundation /// /// Inherit this type instead of implementing the ``Mockable`` /// protocol, to save some code for every mock you create. -open class Mock: Mockable { +open class Mock: Mockable, @unchecked Sendable { public init() {} @@ -24,4 +24,5 @@ open class Mock: Mockable { var registeredCalls: [UUID: [AnyCall]] = [:] var registeredResults: [UUID: Function] = [:] + let registeredCallsLock = NSLock() } diff --git a/Sources/MockingKit/Mockable.swift b/Sources/MockingKit/Mockable.swift index 326aa6e..a885ccc 100644 --- a/Sources/MockingKit/Mockable.swift +++ b/Sources/MockingKit/Mockable.swift @@ -22,7 +22,7 @@ import Foundation /// /// Implement this protocol instead of inheriting the ``Mock`` /// base class, to save some code for every mock you create. -public protocol Mockable { +public protocol Mockable: Sendable { typealias Function = Any @@ -38,41 +38,55 @@ extension Mockable { _ call: MockCall, for ref: MockReference ) { - let calls = mock.registeredCalls[ref.id] ?? [] - mock.registeredCalls[ref.id] = calls + [call] + mock.registeredCallsLock.withLock { + let calls = mock.registeredCalls[ref.id] ?? [] + mock.registeredCalls[ref.id] = calls + [call] + } } func registerCall( _ call: MockCall, for ref: AsyncMockReference ) { - let calls = mock.registeredCalls[ref.id] ?? [] - mock.registeredCalls[ref.id] = calls + [call] + mock.registeredCallsLock.withLock { + let calls = mock.registeredCalls[ref.id] ?? [] + mock.registeredCalls[ref.id] = calls + [call] + } } func registeredCalls( for ref: MockReference ) -> [MockCall] { - let calls = mock.registeredCalls[ref.id] - return (calls as? [MockCall]) ?? [] + mock.registeredCallsLock.withLock { + let calls = mock.registeredCalls[ref.id] + return (calls as? [MockCall]) ?? [] + } } func registeredCalls( for ref: AsyncMockReference ) -> [MockCall] { - let calls = mock.registeredCalls[ref.id] - return (calls as? [MockCall]) ?? [] + mock.registeredCallsLock.withLock { + let calls = mock.registeredCalls[ref.id] + return (calls as? [MockCall]) ?? [] + } } func registeredResult( for ref: MockReference ) -> ((Arguments) throws -> Result)? { - mock.registeredResults[ref.id] as? (Arguments) throws -> Result + mock.registeredCallsLock.withLock { + let result = mock.registeredResults[ref.id] as? (Arguments) throws -> Result + return result + } } func registeredResult( for ref: AsyncMockReference ) -> ((Arguments) async throws -> Result)? { - mock.registeredResults[ref.id] as? (Arguments) async throws -> Result + mock.registeredCallsLock.withLock { + let result = mock.registeredResults[ref.id] as? (Arguments) async throws -> Result + return result + } } } diff --git a/Sources/MockingKit/Mocks/MockPasteboard.swift b/Sources/MockingKit/Mocks/MockPasteboard.swift index 3cc54c6..eceb4b7 100644 --- a/Sources/MockingKit/Mocks/MockPasteboard.swift +++ b/Sources/MockingKit/Mocks/MockPasteboard.swift @@ -32,7 +32,7 @@ import AppKit This mock only mocks `setValue(_:forKey:)` for now, but you can subclass this class and mock more functionality. */ -public class MockPasteboard: NSPasteboard, Mockable { +public class MockPasteboard: NSPasteboard, Mockable, @unchecked Sendable { public lazy var setValueForKeyRef = MockReference(setValueForKey) diff --git a/Sources/MockingKit/Mocks/MockUserDefaults.swift b/Sources/MockingKit/Mocks/MockUserDefaults.swift index 1dfa74b..fe47335 100644 --- a/Sources/MockingKit/Mocks/MockUserDefaults.swift +++ b/Sources/MockingKit/Mocks/MockUserDefaults.swift @@ -9,7 +9,7 @@ import Foundation /// This class can be used to mock `UserDefaults`. -open class MockUserDefaults: UserDefaults, Mockable { +open class MockUserDefaults: UserDefaults, Mockable, @unchecked Sendable { public lazy var boolRef = MockReference(bool) public lazy var arrayRef = MockReference(array) diff --git a/Tests/MockingKitTests/GenericTests.swift b/Tests/MockingKitTests/GenericTests.swift index aa65493..b62d590 100644 --- a/Tests/MockingKitTests/GenericTests.swift +++ b/Tests/MockingKitTests/GenericTests.swift @@ -21,7 +21,7 @@ final class GenericTests: XCTestCase { } } -private class GenericMock: Mock { +private class GenericMock: Mock, @unchecked Sendable { lazy var doitRef = MockReference(doit) diff --git a/Tests/MockingKitTests/MockableAsyncTests.swift b/Tests/MockingKitTests/MockableAsyncTests.swift index 5c65e6f..fbb1dc0 100644 --- a/Tests/MockingKitTests/MockableAsyncTests.swift +++ b/Tests/MockingKitTests/MockableAsyncTests.swift @@ -252,11 +252,33 @@ class MockableAsyncTests: XCTestCase { XCTAssertFalse(mock.hasCalled(mock.functionWithIntResultRef)) XCTAssertTrue(mock.hasCalled(\.functionWithStringResultRef)) } + + func testMultiThreadedAccess_doesNotCorruptState() async throws { + let expectation = XCTestExpectation() + expectation.expectedFulfillmentCount = 2 + let mock = TestClass() + + Task { + for index in 0..<100 { + await mock.functionWithVoidResult(arg1: "Test", arg2: index) + } + expectation.fulfill() + } + + Task { + for _ in 0..<100 { + _ = mock.hasCalled(\.functionWithIntResultRef) + } + expectation.fulfill() + } + + await fulfillment(of: [expectation]) + } } -private class TestClass: AsyncTestProtocol, Mockable { +private final class TestClass: AsyncTestProtocol, Mockable, @unchecked Sendable { - var mock = Mock() + let mock = Mock() lazy var functionWithIntResultRef = AsyncMockReference(functionWithIntResult) lazy var functionWithStringResultRef = AsyncMockReference(functionWithStringResult) diff --git a/Tests/MockingKitTests/MockableTests.swift b/Tests/MockingKitTests/MockableTests.swift index 4730bbf..cc272cc 100644 --- a/Tests/MockingKitTests/MockableTests.swift +++ b/Tests/MockingKitTests/MockableTests.swift @@ -256,11 +256,30 @@ class MockableTests: XCTestCase { XCTAssertFalse(mock.hasCalled(mock.functionWithIntResultRef)) XCTAssertTrue(mock.hasCalled(\.functionWithStringResultRef)) } + + func testMultiThreadedAccess_doesNotCorruptState() { + let queueA = DispatchQueue(label: "QueueA") + let queueB = DispatchQueue(label: "QueueB") + + let mock = TestClass() + + queueA.async { + for index in 0..<100 { + mock.functionWithVoidResult(arg1: "Something", arg2: index) + } + } + + queueB.async { + for _ in 0..<100 { + _ = mock.hasCalled(\.functionWithIntResultRef) + } + } + } } -private class TestClass: AsyncTestProtocol, Mockable { +private final class TestClass: AsyncTestProtocol, Mockable, @unchecked Sendable { - var mock = Mock() + let mock = Mock() lazy var functionWithIntResultRef = MockReference(functionWithIntResult) lazy var functionWithStringResultRef = MockReference(functionWithStringResult)