From f285fffd482be3b841b9d2f013965845bfbc8555 Mon Sep 17 00:00:00 2001 From: Brennan Stehling Date: Fri, 29 Jul 2022 15:15:09 -0700 Subject: [PATCH 1/4] adds AsyncChannel and AsyncThrowingChannel with tests --- Amplify/Core/Support/AsyncChannel.swift | 88 +++++++ .../Core/Support/AsyncThrowingChannel.swift | 110 +++++++++ .../CoreTests/AsyncChannelTests.swift | 215 ++++++++++++++++++ 3 files changed, 413 insertions(+) create mode 100644 Amplify/Core/Support/AsyncChannel.swift create mode 100644 Amplify/Core/Support/AsyncThrowingChannel.swift create mode 100644 AmplifyTests/CoreTests/AsyncChannelTests.swift diff --git a/Amplify/Core/Support/AsyncChannel.swift b/Amplify/Core/Support/AsyncChannel.swift new file mode 100644 index 0000000000..48ba35ca9f --- /dev/null +++ b/Amplify/Core/Support/AsyncChannel.swift @@ -0,0 +1,88 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +import Foundation + +public actor AsyncChannel: AsyncSequence { + public struct Iterator: AsyncIteratorProtocol, Sendable { + private let channel: AsyncChannel + + public init(_ channel: AsyncChannel) { + self.channel = channel + } + + public mutating func next() async -> Element? { + await channel.next() + } + } + + public enum InternalFailure: Error { + case cannotSendAfterTerminated + } + public typealias ChannelContinuation = CheckedContinuation + + private var continuations: [ChannelContinuation] = [] + private var elements: [Element] = [] + private var terminated: Bool = false + + private var hasNext: Bool { + !continuations.isEmpty && !elements.isEmpty + } + + private var canTerminate: Bool { + terminated && elements.isEmpty && !continuations.isEmpty + } + + public init() { + } + + public nonisolated func makeAsyncIterator() -> Iterator { + Iterator(self) + } + + public func next() async -> Element? { + await withCheckedContinuation { (continuation: ChannelContinuation) in + continuations.append(continuation) + processNext() + } + } + + public func send(_ element: Element) throws { + guard !terminated else { + throw InternalFailure.cannotSendAfterTerminated + } + elements.append(element) + processNext() + } + + public func finish() { + terminated = true + processNext() + } + + private func processNext() { + if canTerminate { + let contination = continuations.removeFirst() + assert(continuations.isEmpty) + assert(elements.isEmpty) + contination.resume(returning: nil) + return + } + + guard hasNext else { + return + } + + assert(!continuations.isEmpty) + assert(!elements.isEmpty) + + let contination = continuations.removeFirst() + let element = elements.removeFirst() + + contination.resume(returning: element) + } +} diff --git a/Amplify/Core/Support/AsyncThrowingChannel.swift b/Amplify/Core/Support/AsyncThrowingChannel.swift new file mode 100644 index 0000000000..43eab86e21 --- /dev/null +++ b/Amplify/Core/Support/AsyncThrowingChannel.swift @@ -0,0 +1,110 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +import Foundation + +public actor AsyncThrowingChannel: AsyncSequence { + public struct Iterator: AsyncIteratorProtocol, Sendable { + private let channel: AsyncThrowingChannel + + public init(_ channel: AsyncThrowingChannel) { + self.channel = channel + } + + public mutating func next() async throws -> Element? { + try await channel.next() + } + } + + public enum InternalFailure: Error { + case cannotSendAfterTerminated + } + public typealias ChannelContinuation = CheckedContinuation + + private var continuations: [ChannelContinuation] = [] + private var elements: [Element] = [] + private var terminated: Bool = false + private var error: Error? = nil + + private var hasNext: Bool { + !continuations.isEmpty && !elements.isEmpty + } + + private var canFail: Bool { + error != nil && !continuations.isEmpty + } + + private var canTerminate: Bool { + terminated && elements.isEmpty && !continuations.isEmpty + } + + public init() { + } + + public nonisolated func makeAsyncIterator() -> Iterator { + Iterator(self) + } + + public func next() async throws -> Element? { + try await withCheckedThrowingContinuation { (continuation: ChannelContinuation) in + continuations.append(continuation) + processNext() + } + } + + public func send(_ element: Element) throws { + guard !terminated else { + throw InternalFailure.cannotSendAfterTerminated + } + elements.append(element) + processNext() + } + + + public func fail(_ error: Error) where Failure == Error { + self.error = error + processNext() + } + + public func finish() { + terminated = true + processNext() + } + + private func processNext() { + if canFail { + let contination = continuations.removeFirst() + assert(continuations.isEmpty) + assert(elements.isEmpty) + assert(error != nil) + if let error = error { + contination.resume(throwing: error) + return + } + } + + if canTerminate { + let contination = continuations.removeFirst() + assert(continuations.isEmpty) + assert(elements.isEmpty) + contination.resume(returning: nil) + return + } + + guard hasNext else { + return + } + + assert(!continuations.isEmpty) + assert(!elements.isEmpty) + + let contination = continuations.removeFirst() + let element = elements.removeFirst() + + contination.resume(returning: element) + } +} diff --git a/AmplifyTests/CoreTests/AsyncChannelTests.swift b/AmplifyTests/CoreTests/AsyncChannelTests.swift new file mode 100644 index 0000000000..41fba48b61 --- /dev/null +++ b/AmplifyTests/CoreTests/AsyncChannelTests.swift @@ -0,0 +1,215 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +import XCTest +@testable import Amplify + +final class AsyncChannelTests: XCTestCase { + enum Failure: Error { + case unluckyNumber + } + let sleepSeconds = 0.1 + + func testNumberSequence() async throws { + let input = [1, 2, 3, 4, 5] + let channel = AsyncChannel() + + // load all numbers into the channel with delays + Task { + try await send(elements: input, channel: channel, sleepSeconds: sleepSeconds) + } + + var output: [Int] = [] + + print("-- before --") + for await element in channel { + print(element) + output.append(element) + } + print("-- after --") + + XCTAssertEqual(input, output) + } + + func testStringSequence() async throws { + let input = ["one", "two", "three", "four", "five"] + let channel = AsyncChannel() + + // load all strings into the channel with delays + Task { + try await send(elements: input, channel: channel, sleepSeconds: sleepSeconds) + } + + var output: [String] = [] + + print("-- before --") + for await element in channel { + print(element) + output.append(element) + } + print("-- after --") + + XCTAssertEqual(input, output) + } + + func testSendAfterFinishing() async throws { + let input = ["a", "b", "c"] + let channel = AsyncChannel() + + // load all strings into the channel with delays + Task { + try await send(elements: input, channel: channel, sleepSeconds: sleepSeconds) + var thrown: Error? = nil + do { + try await channel.send("z") + } catch { + thrown = error + } + XCTAssertNotNil(thrown) + } + + var output: [String] = [] + + print("-- before --") + for await element in channel { + print(element) + output.append(element) + } + print("-- after --") + + XCTAssertEqual(input, output) + } + + func testSendAfterFinishingThrowing() async throws { + let input = ["x", "y", "z"] + let channel = AsyncThrowingChannel() + + // load all strings into the channel with delays + Task { + try await send(elements: input, channel: channel, sleepSeconds: sleepSeconds) + var thrown: Error? = nil + do { + try await channel.send("a") + } catch { + thrown = error + } + XCTAssertNotNil(thrown) + } + + var output: [String] = [] + + print("-- before --") + for try await element in channel { + print(element) + output.append(element) + } + print("-- after --") + + XCTAssertEqual(input, output) + } + + func testSucceedingSequence() async throws { + let input = [3, 7, 14, 21] + let channel = AsyncThrowingChannel() + + // load all numbers into the channel with delays + Task { + try await send(elements: input, channel: channel, sleepSeconds: sleepSeconds) { element in + if element == 13 { + throw Failure.unluckyNumber + } else { + return element + } + } + } + + var output: [Int] = [] + var thrown: Error? = nil + + print("-- before --") + do { + for try await element in channel { + print(element) + output.append(element) + } + } catch { + thrown = error + } + print("-- after --") + + XCTAssertNil(thrown) + XCTAssertEqual(input, output) + } + + func testFailingSequence() async throws { + let input = [3, 7, 13, 21] + let channel = AsyncThrowingChannel() + + // load all numbers into the channel with delays + Task { + try await send(elements: input, channel: channel, sleepSeconds: sleepSeconds) { element in + if element == 13 { + throw Failure.unluckyNumber + } else { + return element + } + } + } + + var output: [Int] = [] + var thrown: Error? = nil + + print("-- before --") + do { + for try await element in channel { + print(element) + output.append(element) + } + } catch { + thrown = error + } + print("-- after --") + + XCTAssertNotNil(thrown) + let expected = Array(input[0..<2]) + XCTAssertEqual(expected, output) + } + + private func send(elements: [Element], channel: AsyncChannel, sleepSeconds: Double = 0.1) async throws { + var index = 0 + while index < elements.count { + try await Task.sleep(seconds: sleepSeconds) + let element = elements[index] + try await channel.send(element) + + index += 1 + } + await channel.finish() + } + + private func send(elements: [Element], channel: AsyncThrowingChannel, sleepSeconds: Double = 0.1, processor: ((Element) throws -> Element)? = nil) async throws { + var index = 0 + while index < elements.count { + try await Task.sleep(seconds: sleepSeconds) + let element = elements[index] + if let processor = processor { + do { + let processed = try processor(element) + try await channel.send(processed) + } catch { + print("throwing \(error)") + await channel.fail(error) + } + } else { + try await channel.send(element) + } + + index += 1 + } + await channel.finish() + } +} From 728e63053ac80ab73451722aa8c2e1ebd9c638f8 Mon Sep 17 00:00:00 2001 From: Brennan Stehling Date: Mon, 8 Aug 2022 16:13:24 -0700 Subject: [PATCH 2/4] fixes typos --- Amplify/Core/Support/AsyncChannel.swift | 8 ++++---- Amplify/Core/Support/AsyncThrowingChannel.swift | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/Amplify/Core/Support/AsyncChannel.swift b/Amplify/Core/Support/AsyncChannel.swift index 48ba35ca9f..3e50d146f1 100644 --- a/Amplify/Core/Support/AsyncChannel.swift +++ b/Amplify/Core/Support/AsyncChannel.swift @@ -66,10 +66,10 @@ public actor AsyncChannel: AsyncSequence { private func processNext() { if canTerminate { - let contination = continuations.removeFirst() + let continuation = continuations.removeFirst() assert(continuations.isEmpty) assert(elements.isEmpty) - contination.resume(returning: nil) + continuation.resume(returning: nil) return } @@ -80,9 +80,9 @@ public actor AsyncChannel: AsyncSequence { assert(!continuations.isEmpty) assert(!elements.isEmpty) - let contination = continuations.removeFirst() + let continuation = continuations.removeFirst() let element = elements.removeFirst() - contination.resume(returning: element) + continuation.resume(returning: element) } } diff --git a/Amplify/Core/Support/AsyncThrowingChannel.swift b/Amplify/Core/Support/AsyncThrowingChannel.swift index 43eab86e21..a8f2ba77af 100644 --- a/Amplify/Core/Support/AsyncThrowingChannel.swift +++ b/Amplify/Core/Support/AsyncThrowingChannel.swift @@ -77,21 +77,21 @@ public actor AsyncThrowingChannel: AsyncSeque private func processNext() { if canFail { - let contination = continuations.removeFirst() + let continuation = continuations.removeFirst() assert(continuations.isEmpty) assert(elements.isEmpty) assert(error != nil) if let error = error { - contination.resume(throwing: error) + continuation.resume(throwing: error) return } } if canTerminate { - let contination = continuations.removeFirst() + let continuation = continuations.removeFirst() assert(continuations.isEmpty) assert(elements.isEmpty) - contination.resume(returning: nil) + continuation.resume(returning: nil) return } @@ -102,9 +102,9 @@ public actor AsyncThrowingChannel: AsyncSeque assert(!continuations.isEmpty) assert(!elements.isEmpty) - let contination = continuations.removeFirst() + let continuation = continuations.removeFirst() let element = elements.removeFirst() - contination.resume(returning: element) + continuation.resume(returning: element) } } From 36bc2a84e486311129703e603be4029682b11b16 Mon Sep 17 00:00:00 2001 From: Brennan Stehling Date: Tue, 9 Aug 2022 14:26:41 -0700 Subject: [PATCH 3/4] makes new types internal (unit tests pass) --- Amplify/Core/Support/AsyncChannel.swift | 22 ++++++++--------- .../Core/Support/AsyncThrowingChannel.swift | 24 +++++++++---------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/Amplify/Core/Support/AsyncChannel.swift b/Amplify/Core/Support/AsyncChannel.swift index 3e50d146f1..969a93cd14 100644 --- a/Amplify/Core/Support/AsyncChannel.swift +++ b/Amplify/Core/Support/AsyncChannel.swift @@ -7,23 +7,23 @@ import Foundation -public actor AsyncChannel: AsyncSequence { - public struct Iterator: AsyncIteratorProtocol, Sendable { +actor AsyncChannel: AsyncSequence { + struct Iterator: AsyncIteratorProtocol, Sendable { private let channel: AsyncChannel - public init(_ channel: AsyncChannel) { + init(_ channel: AsyncChannel) { self.channel = channel } - public mutating func next() async -> Element? { + mutating func next() async -> Element? { await channel.next() } } - public enum InternalFailure: Error { + enum InternalFailure: Error { case cannotSendAfterTerminated } - public typealias ChannelContinuation = CheckedContinuation + typealias ChannelContinuation = CheckedContinuation private var continuations: [ChannelContinuation] = [] private var elements: [Element] = [] @@ -37,21 +37,21 @@ public actor AsyncChannel: AsyncSequence { terminated && elements.isEmpty && !continuations.isEmpty } - public init() { + init() { } - public nonisolated func makeAsyncIterator() -> Iterator { + nonisolated func makeAsyncIterator() -> Iterator { Iterator(self) } - public func next() async -> Element? { + func next() async -> Element? { await withCheckedContinuation { (continuation: ChannelContinuation) in continuations.append(continuation) processNext() } } - public func send(_ element: Element) throws { + func send(_ element: Element) throws { guard !terminated else { throw InternalFailure.cannotSendAfterTerminated } @@ -59,7 +59,7 @@ public actor AsyncChannel: AsyncSequence { processNext() } - public func finish() { + func finish() { terminated = true processNext() } diff --git a/Amplify/Core/Support/AsyncThrowingChannel.swift b/Amplify/Core/Support/AsyncThrowingChannel.swift index a8f2ba77af..b15ae89036 100644 --- a/Amplify/Core/Support/AsyncThrowingChannel.swift +++ b/Amplify/Core/Support/AsyncThrowingChannel.swift @@ -7,23 +7,23 @@ import Foundation -public actor AsyncThrowingChannel: AsyncSequence { - public struct Iterator: AsyncIteratorProtocol, Sendable { +actor AsyncThrowingChannel: AsyncSequence { + struct Iterator: AsyncIteratorProtocol, Sendable { private let channel: AsyncThrowingChannel - public init(_ channel: AsyncThrowingChannel) { + init(_ channel: AsyncThrowingChannel) { self.channel = channel } - public mutating func next() async throws -> Element? { + mutating func next() async throws -> Element? { try await channel.next() } } - public enum InternalFailure: Error { + enum InternalFailure: Error { case cannotSendAfterTerminated } - public typealias ChannelContinuation = CheckedContinuation + typealias ChannelContinuation = CheckedContinuation private var continuations: [ChannelContinuation] = [] private var elements: [Element] = [] @@ -42,21 +42,21 @@ public actor AsyncThrowingChannel: AsyncSeque terminated && elements.isEmpty && !continuations.isEmpty } - public init() { + init() { } - public nonisolated func makeAsyncIterator() -> Iterator { + nonisolated func makeAsyncIterator() -> Iterator { Iterator(self) } - public func next() async throws -> Element? { + func next() async throws -> Element? { try await withCheckedThrowingContinuation { (continuation: ChannelContinuation) in continuations.append(continuation) processNext() } } - public func send(_ element: Element) throws { + func send(_ element: Element) throws { guard !terminated else { throw InternalFailure.cannotSendAfterTerminated } @@ -65,12 +65,12 @@ public actor AsyncThrowingChannel: AsyncSeque } - public func fail(_ error: Error) where Failure == Error { + func fail(_ error: Error) where Failure == Error { self.error = error processNext() } - public func finish() { + func finish() { terminated = true processNext() } From 2cdeb1166a38f220ec3df569d67cf360bfecf3d7 Mon Sep 17 00:00:00 2001 From: Brennan Stehling Date: Wed, 10 Aug 2022 08:19:41 -0700 Subject: [PATCH 4/4] adds support for cancellation --- Amplify/Core/Support/AsyncChannel.swift | 40 ++++-- .../Core/Support/AsyncThrowingChannel.swift | 44 ++++--- .../CoreTests/AsyncChannelTests.swift | 118 +++++++++++++++++- 3 files changed, 173 insertions(+), 29 deletions(-) diff --git a/Amplify/Core/Support/AsyncChannel.swift b/Amplify/Core/Support/AsyncChannel.swift index 969a93cd14..9a97f5cdf3 100644 --- a/Amplify/Core/Support/AsyncChannel.swift +++ b/Amplify/Core/Support/AsyncChannel.swift @@ -7,26 +7,27 @@ import Foundation -actor AsyncChannel: AsyncSequence { - struct Iterator: AsyncIteratorProtocol, Sendable { +public actor AsyncChannel: AsyncSequence { + public struct Iterator: AsyncIteratorProtocol, Sendable { private let channel: AsyncChannel - init(_ channel: AsyncChannel) { + public init(_ channel: AsyncChannel) { self.channel = channel } - mutating func next() async -> Element? { - await channel.next() + public mutating func next() async -> Element? { + Task.isCancelled ? nil : await channel.next() } } - enum InternalFailure: Error { + public enum InternalFailure: Error { case cannotSendAfterTerminated } - typealias ChannelContinuation = CheckedContinuation + public typealias ChannelContinuation = CheckedContinuation private var continuations: [ChannelContinuation] = [] private var elements: [Element] = [] + private var cancelled: Bool = false private var terminated: Bool = false private var hasNext: Bool { @@ -40,18 +41,26 @@ actor AsyncChannel: AsyncSequence { init() { } - nonisolated func makeAsyncIterator() -> Iterator { + public nonisolated func makeAsyncIterator() -> Iterator { Iterator(self) } - func next() async -> Element? { - await withCheckedContinuation { (continuation: ChannelContinuation) in + public func next() async -> Element? { + if cancelled || terminated { + return nil + } + return await withCheckedContinuation { (continuation: ChannelContinuation) in continuations.append(continuation) processNext() } } - func send(_ element: Element) throws { + public func send(_ element: Element) throws { + if Task.isCancelled { + cancelled = true + processNext() + throw CancellationError() + } guard !terminated else { throw InternalFailure.cannotSendAfterTerminated } @@ -59,12 +68,19 @@ actor AsyncChannel: AsyncSequence { processNext() } - func finish() { + public func finish() { terminated = true processNext() } private func processNext() { + if cancelled && !continuations.isEmpty { + let continuation = continuations.removeFirst() + assert(continuations.isEmpty) + continuation.resume(returning: nil) + return + } + if canTerminate { let continuation = continuations.removeFirst() assert(continuations.isEmpty) diff --git a/Amplify/Core/Support/AsyncThrowingChannel.swift b/Amplify/Core/Support/AsyncThrowingChannel.swift index b15ae89036..30ada211d1 100644 --- a/Amplify/Core/Support/AsyncThrowingChannel.swift +++ b/Amplify/Core/Support/AsyncThrowingChannel.swift @@ -7,26 +7,28 @@ import Foundation -actor AsyncThrowingChannel: AsyncSequence { - struct Iterator: AsyncIteratorProtocol, Sendable { +public actor AsyncThrowingChannel: AsyncSequence { + public struct Iterator: AsyncIteratorProtocol, Sendable { private let channel: AsyncThrowingChannel - init(_ channel: AsyncThrowingChannel) { + public init(_ channel: AsyncThrowingChannel) { self.channel = channel } - mutating func next() async throws -> Element? { - try await channel.next() + public mutating func next() async throws -> Element? { + try Task.checkCancellation() + return try await channel.next() } } - enum InternalFailure: Error { + public enum InternalFailure: Error { case cannotSendAfterTerminated } - typealias ChannelContinuation = CheckedContinuation + public typealias ChannelContinuation = CheckedContinuation private var continuations: [ChannelContinuation] = [] private var elements: [Element] = [] + private var cancelled: Bool = false private var terminated: Bool = false private var error: Error? = nil @@ -45,18 +47,26 @@ actor AsyncThrowingChannel: AsyncSequence { init() { } - nonisolated func makeAsyncIterator() -> Iterator { + public nonisolated func makeAsyncIterator() -> Iterator { Iterator(self) } - func next() async throws -> Element? { - try await withCheckedThrowingContinuation { (continuation: ChannelContinuation) in + public func next() async throws -> Element? { + if cancelled { + throw CancellationError() + } + return try await withCheckedThrowingContinuation { (continuation: ChannelContinuation) in continuations.append(continuation) processNext() } } - func send(_ element: Element) throws { + public func send(_ element: Element) throws { + if Task.isCancelled { + cancelled = true + processNext() + throw CancellationError() + } guard !terminated else { throw InternalFailure.cannotSendAfterTerminated } @@ -64,18 +74,24 @@ actor AsyncThrowingChannel: AsyncSequence { processNext() } - - func fail(_ error: Error) where Failure == Error { + public func fail(_ error: Error) where Failure == Error { self.error = error processNext() } - func finish() { + public func finish() { terminated = true processNext() } private func processNext() { + if cancelled && !continuations.isEmpty { + let continuation = continuations.removeFirst() + assert(continuations.isEmpty) + continuation.resume(throwing: CancellationError()) + return + } + if canFail { let continuation = continuations.removeFirst() assert(continuations.isEmpty) diff --git a/AmplifyTests/CoreTests/AsyncChannelTests.swift b/AmplifyTests/CoreTests/AsyncChannelTests.swift index 41fba48b61..ba02d49e75 100644 --- a/AmplifyTests/CoreTests/AsyncChannelTests.swift +++ b/AmplifyTests/CoreTests/AsyncChannelTests.swift @@ -12,6 +12,14 @@ final class AsyncChannelTests: XCTestCase { enum Failure: Error { case unluckyNumber } + + actor Output { + var elements: [Element] = [] + func append(_ element: Element) { + elements.append(element) + } + } + let sleepSeconds = 0.1 func testNumberSequence() async throws { @@ -26,7 +34,7 @@ final class AsyncChannelTests: XCTestCase { var output: [Int] = [] print("-- before --") - for await element in channel { + for try await element in channel { print(element) output.append(element) } @@ -47,7 +55,7 @@ final class AsyncChannelTests: XCTestCase { var output: [String] = [] print("-- before --") - for await element in channel { + for try await element in channel { print(element) output.append(element) } @@ -75,7 +83,7 @@ final class AsyncChannelTests: XCTestCase { var output: [String] = [] print("-- before --") - for await element in channel { + for try await element in channel { print(element) output.append(element) } @@ -179,6 +187,109 @@ final class AsyncChannelTests: XCTestCase { XCTAssertEqual(expected, output) } + func testChannelCancelled() async throws { + let delay = 1.25 + let input = [1, 2, 3, 4, 5] + let channel = AsyncChannel() + let sendExp = expectation(description: "send") + let reduceExp = expectation(description: "reduce") + + let sendTask = Task { + print("send - start") + var thrown: Error? + do { + var index = 0 + while index < input.count { + try await Task.sleep(seconds: delay) + try await channel.send(input[index]) + index += 1 + } + } catch { + thrown = error + } + print("send - end") + + XCTAssertNotNil(thrown) + XCTAssertTrue(thrown is CancellationError) + + sendExp.fulfill() + } + + let reduceTask = Task { + print("reduce - start") + let result = await channel.reduce(0, +) + print(result) + print("reduce - end") + + reduceExp.fulfill() + } + + Task { + try await Task.sleep(seconds: delay * 2) + sendTask.cancel() + } + + await waitForExpectations(timeout: 5.0) + + XCTAssertFalse(reduceTask.isCancelled) + } + + func testThrowingChannelCancelled() async throws { + let delay = 1.25 + let input = [1, 2, 3, 4, 5] + let channel = AsyncThrowingChannel() + let sendExp = expectation(description: "send") + let reduceExp = expectation(description: "reduce") + + let sendTask = Task { + print("send - start") + var thrown: Error? + do { + var index = 0 + while index < input.count { + try await Task.sleep(seconds: delay) + try await channel.send(input[index]) + index += 1 + } + } catch { + thrown = error + } + print("send - end") + + XCTAssertNotNil(thrown) + XCTAssertTrue(thrown is CancellationError) + + sendExp.fulfill() + } + + let reduceTask = Task { + print("reduce - start") + var thrown: Error? + do { + let result = try await channel.reduce(0, +) + print(result) + } catch { + thrown = error + } + + print("reduce - end") + + XCTAssertNotNil(thrown) + XCTAssertTrue(thrown is CancellationError) + + reduceExp.fulfill() + } + + Task { + try await Task.sleep(seconds: delay * 2) + sendTask.cancel() + } + + await waitForExpectations(timeout: 5.0) + + XCTAssertFalse(reduceTask.isCancelled) + } + private func send(elements: [Element], channel: AsyncChannel, sleepSeconds: Double = 0.1) async throws { var index = 0 while index < elements.count { @@ -212,4 +323,5 @@ final class AsyncChannelTests: XCTestCase { } await channel.finish() } + }