From e95bb69a537504f55ae53d620c5434a72f76314d Mon Sep 17 00:00:00 2001 From: Brennan Stehling <277419+brennanMKE@users.noreply.github.com> Date: Wed, 10 Aug 2022 13:59:02 -0700 Subject: [PATCH] adds AsyncChannel and AsyncThrowingChannel with tests (#2086) * adds AsyncChannel and AsyncThrowingChannel with tests --- Amplify/Core/Support/AsyncChannel.swift | 104 ++++++ .../Core/Support/AsyncThrowingChannel.swift | 126 +++++++ .../CoreTests/AsyncChannelTests.swift | 327 ++++++++++++++++++ 3 files changed, 557 insertions(+) create mode 100644 Amplify/Core/Support/AsyncChannel.swift create mode 100644 Amplify/Core/Support/AsyncThrowingChannel.swift create mode 100644 AmplifyTests/CoreTests/AsyncChannelTests.swift diff --git a/Amplify/Core/Support/AsyncChannel.swift b/Amplify/Core/Support/AsyncChannel.swift new file mode 100644 index 0000000000..9a97f5cdf3 --- /dev/null +++ b/Amplify/Core/Support/AsyncChannel.swift @@ -0,0 +1,104 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +import Foundation + +public actor AsyncChannel: AsyncSequence { + public struct Iterator: AsyncIteratorProtocol, Sendable { + private let channel: AsyncChannel + + public init(_ channel: AsyncChannel) { + self.channel = channel + } + + public mutating func next() async -> Element? { + Task.isCancelled ? nil : await channel.next() + } + } + + public enum InternalFailure: Error { + case cannotSendAfterTerminated + } + public typealias ChannelContinuation = CheckedContinuation + + private var continuations: [ChannelContinuation] = [] + private var elements: [Element] = [] + private var cancelled: Bool = false + private var terminated: Bool = false + + private var hasNext: Bool { + !continuations.isEmpty && !elements.isEmpty + } + + private var canTerminate: Bool { + terminated && elements.isEmpty && !continuations.isEmpty + } + + init() { + } + + public nonisolated func makeAsyncIterator() -> Iterator { + Iterator(self) + } + + public func next() async -> Element? { + if cancelled || terminated { + return nil + } + return await withCheckedContinuation { (continuation: ChannelContinuation) in + continuations.append(continuation) + processNext() + } + } + + public func send(_ element: Element) throws { + if Task.isCancelled { + cancelled = true + processNext() + throw CancellationError() + } + guard !terminated else { + throw InternalFailure.cannotSendAfterTerminated + } + elements.append(element) + processNext() + } + + public func finish() { + terminated = true + processNext() + } + + private func processNext() { + if cancelled && !continuations.isEmpty { + let continuation = continuations.removeFirst() + assert(continuations.isEmpty) + continuation.resume(returning: nil) + return + } + + if canTerminate { + let continuation = continuations.removeFirst() + assert(continuations.isEmpty) + assert(elements.isEmpty) + continuation.resume(returning: nil) + return + } + + guard hasNext else { + return + } + + assert(!continuations.isEmpty) + assert(!elements.isEmpty) + + let continuation = continuations.removeFirst() + let element = elements.removeFirst() + + continuation.resume(returning: element) + } +} diff --git a/Amplify/Core/Support/AsyncThrowingChannel.swift b/Amplify/Core/Support/AsyncThrowingChannel.swift new file mode 100644 index 0000000000..30ada211d1 --- /dev/null +++ b/Amplify/Core/Support/AsyncThrowingChannel.swift @@ -0,0 +1,126 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +import Foundation + +public actor AsyncThrowingChannel: AsyncSequence { + public struct Iterator: AsyncIteratorProtocol, Sendable { + private let channel: AsyncThrowingChannel + + public init(_ channel: AsyncThrowingChannel) { + self.channel = channel + } + + public mutating func next() async throws -> Element? { + try Task.checkCancellation() + return try await channel.next() + } + } + + public enum InternalFailure: Error { + case cannotSendAfterTerminated + } + public typealias ChannelContinuation = CheckedContinuation + + private var continuations: [ChannelContinuation] = [] + private var elements: [Element] = [] + private var cancelled: Bool = false + private var terminated: Bool = false + private var error: Error? = nil + + private var hasNext: Bool { + !continuations.isEmpty && !elements.isEmpty + } + + private var canFail: Bool { + error != nil && !continuations.isEmpty + } + + private var canTerminate: Bool { + terminated && elements.isEmpty && !continuations.isEmpty + } + + init() { + } + + public nonisolated func makeAsyncIterator() -> Iterator { + Iterator(self) + } + + public func next() async throws -> Element? { + if cancelled { + throw CancellationError() + } + return try await withCheckedThrowingContinuation { (continuation: ChannelContinuation) in + continuations.append(continuation) + processNext() + } + } + + public func send(_ element: Element) throws { + if Task.isCancelled { + cancelled = true + processNext() + throw CancellationError() + } + guard !terminated else { + throw InternalFailure.cannotSendAfterTerminated + } + elements.append(element) + processNext() + } + + public func fail(_ error: Error) where Failure == Error { + self.error = error + processNext() + } + + public func finish() { + terminated = true + processNext() + } + + private func processNext() { + if cancelled && !continuations.isEmpty { + let continuation = continuations.removeFirst() + assert(continuations.isEmpty) + continuation.resume(throwing: CancellationError()) + return + } + + if canFail { + let continuation = continuations.removeFirst() + assert(continuations.isEmpty) + assert(elements.isEmpty) + assert(error != nil) + if let error = error { + continuation.resume(throwing: error) + return + } + } + + if canTerminate { + let continuation = continuations.removeFirst() + assert(continuations.isEmpty) + assert(elements.isEmpty) + continuation.resume(returning: nil) + return + } + + guard hasNext else { + return + } + + assert(!continuations.isEmpty) + assert(!elements.isEmpty) + + let continuation = continuations.removeFirst() + let element = elements.removeFirst() + + continuation.resume(returning: element) + } +} diff --git a/AmplifyTests/CoreTests/AsyncChannelTests.swift b/AmplifyTests/CoreTests/AsyncChannelTests.swift new file mode 100644 index 0000000000..ba02d49e75 --- /dev/null +++ b/AmplifyTests/CoreTests/AsyncChannelTests.swift @@ -0,0 +1,327 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +import XCTest +@testable import Amplify + +final class AsyncChannelTests: XCTestCase { + enum Failure: Error { + case unluckyNumber + } + + actor Output { + var elements: [Element] = [] + func append(_ element: Element) { + elements.append(element) + } + } + + let sleepSeconds = 0.1 + + func testNumberSequence() async throws { + let input = [1, 2, 3, 4, 5] + let channel = AsyncChannel() + + // load all numbers into the channel with delays + Task { + try await send(elements: input, channel: channel, sleepSeconds: sleepSeconds) + } + + var output: [Int] = [] + + print("-- before --") + for try await element in channel { + print(element) + output.append(element) + } + print("-- after --") + + XCTAssertEqual(input, output) + } + + func testStringSequence() async throws { + let input = ["one", "two", "three", "four", "five"] + let channel = AsyncChannel() + + // load all strings into the channel with delays + Task { + try await send(elements: input, channel: channel, sleepSeconds: sleepSeconds) + } + + var output: [String] = [] + + print("-- before --") + for try await element in channel { + print(element) + output.append(element) + } + print("-- after --") + + XCTAssertEqual(input, output) + } + + func testSendAfterFinishing() async throws { + let input = ["a", "b", "c"] + let channel = AsyncChannel() + + // load all strings into the channel with delays + Task { + try await send(elements: input, channel: channel, sleepSeconds: sleepSeconds) + var thrown: Error? = nil + do { + try await channel.send("z") + } catch { + thrown = error + } + XCTAssertNotNil(thrown) + } + + var output: [String] = [] + + print("-- before --") + for try await element in channel { + print(element) + output.append(element) + } + print("-- after --") + + XCTAssertEqual(input, output) + } + + func testSendAfterFinishingThrowing() async throws { + let input = ["x", "y", "z"] + let channel = AsyncThrowingChannel() + + // load all strings into the channel with delays + Task { + try await send(elements: input, channel: channel, sleepSeconds: sleepSeconds) + var thrown: Error? = nil + do { + try await channel.send("a") + } catch { + thrown = error + } + XCTAssertNotNil(thrown) + } + + var output: [String] = [] + + print("-- before --") + for try await element in channel { + print(element) + output.append(element) + } + print("-- after --") + + XCTAssertEqual(input, output) + } + + func testSucceedingSequence() async throws { + let input = [3, 7, 14, 21] + let channel = AsyncThrowingChannel() + + // load all numbers into the channel with delays + Task { + try await send(elements: input, channel: channel, sleepSeconds: sleepSeconds) { element in + if element == 13 { + throw Failure.unluckyNumber + } else { + return element + } + } + } + + var output: [Int] = [] + var thrown: Error? = nil + + print("-- before --") + do { + for try await element in channel { + print(element) + output.append(element) + } + } catch { + thrown = error + } + print("-- after --") + + XCTAssertNil(thrown) + XCTAssertEqual(input, output) + } + + func testFailingSequence() async throws { + let input = [3, 7, 13, 21] + let channel = AsyncThrowingChannel() + + // load all numbers into the channel with delays + Task { + try await send(elements: input, channel: channel, sleepSeconds: sleepSeconds) { element in + if element == 13 { + throw Failure.unluckyNumber + } else { + return element + } + } + } + + var output: [Int] = [] + var thrown: Error? = nil + + print("-- before --") + do { + for try await element in channel { + print(element) + output.append(element) + } + } catch { + thrown = error + } + print("-- after --") + + XCTAssertNotNil(thrown) + let expected = Array(input[0..<2]) + XCTAssertEqual(expected, output) + } + + func testChannelCancelled() async throws { + let delay = 1.25 + let input = [1, 2, 3, 4, 5] + let channel = AsyncChannel() + let sendExp = expectation(description: "send") + let reduceExp = expectation(description: "reduce") + + let sendTask = Task { + print("send - start") + var thrown: Error? + do { + var index = 0 + while index < input.count { + try await Task.sleep(seconds: delay) + try await channel.send(input[index]) + index += 1 + } + } catch { + thrown = error + } + print("send - end") + + XCTAssertNotNil(thrown) + XCTAssertTrue(thrown is CancellationError) + + sendExp.fulfill() + } + + let reduceTask = Task { + print("reduce - start") + let result = await channel.reduce(0, +) + print(result) + print("reduce - end") + + reduceExp.fulfill() + } + + Task { + try await Task.sleep(seconds: delay * 2) + sendTask.cancel() + } + + await waitForExpectations(timeout: 5.0) + + XCTAssertFalse(reduceTask.isCancelled) + } + + func testThrowingChannelCancelled() async throws { + let delay = 1.25 + let input = [1, 2, 3, 4, 5] + let channel = AsyncThrowingChannel() + let sendExp = expectation(description: "send") + let reduceExp = expectation(description: "reduce") + + let sendTask = Task { + print("send - start") + var thrown: Error? + do { + var index = 0 + while index < input.count { + try await Task.sleep(seconds: delay) + try await channel.send(input[index]) + index += 1 + } + } catch { + thrown = error + } + print("send - end") + + XCTAssertNotNil(thrown) + XCTAssertTrue(thrown is CancellationError) + + sendExp.fulfill() + } + + let reduceTask = Task { + print("reduce - start") + var thrown: Error? + do { + let result = try await channel.reduce(0, +) + print(result) + } catch { + thrown = error + } + + print("reduce - end") + + XCTAssertNotNil(thrown) + XCTAssertTrue(thrown is CancellationError) + + reduceExp.fulfill() + } + + Task { + try await Task.sleep(seconds: delay * 2) + sendTask.cancel() + } + + await waitForExpectations(timeout: 5.0) + + XCTAssertFalse(reduceTask.isCancelled) + } + + private func send(elements: [Element], channel: AsyncChannel, sleepSeconds: Double = 0.1) async throws { + var index = 0 + while index < elements.count { + try await Task.sleep(seconds: sleepSeconds) + let element = elements[index] + try await channel.send(element) + + index += 1 + } + await channel.finish() + } + + private func send(elements: [Element], channel: AsyncThrowingChannel, sleepSeconds: Double = 0.1, processor: ((Element) throws -> Element)? = nil) async throws { + var index = 0 + while index < elements.count { + try await Task.sleep(seconds: sleepSeconds) + let element = elements[index] + if let processor = processor { + do { + let processed = try processor(element) + try await channel.send(processed) + } catch { + print("throwing \(error)") + await channel.fail(error) + } + } else { + try await channel.send(element) + } + + index += 1 + } + await channel.finish() + } + +}