Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,21 @@ extension NIOThrowingAsyncSequenceProducer {
return nil
}

case .returnCancellationError:
self._lock.unlock()
// We have deprecated the generic Failure type in the public API and Failure should
// now be `Swift.Error`. However, if users have not migrated to the new API they could
// still use a custom generic Error type and this cast might fail.
// In addition, we use `NIOThrowingAsyncSequenceProducer` in the implementation of the
// non-throwing variant `NIOAsyncSequenceProducer` where `Failure` will be `Never` and
// this cast will fail as well.
// Everything is marked @inlinable and the Failure type is known at compile time,
// therefore this cast should be optimised away in release build.
if let error = CancellationError() as? Failure {
throw error
}
return nil

case .returnNil:
self._lock.unlock()
return nil
Expand Down Expand Up @@ -603,6 +618,9 @@ extension NIOThrowingAsyncSequenceProducer {
failure: Failure?
)

/// The state once a call to next has been cancelled. Cancel the source when entering this state.
case cancelled(iteratorInitialized: Bool)

/// The state once there can be no outstanding demand. This can happen if:
/// 1. The ``NIOThrowingAsyncSequenceProducer/AsyncIterator`` was deinited
/// 2. The underlying source finished and all buffered elements have been consumed
Expand Down Expand Up @@ -644,15 +662,17 @@ extension NIOThrowingAsyncSequenceProducer {
switch self._state {
case .initial(_, iteratorInitialized: false),
.streaming(_, _, _, _, iteratorInitialized: false),
.sourceFinished(_, iteratorInitialized: false, _):
.sourceFinished(_, iteratorInitialized: false, _),
.cancelled(iteratorInitialized: false):
// No iterator was created so we can transition to finished right away.
self._state = .finished(iteratorInitialized: false)

return .callDidTerminate

case .initial(_, iteratorInitialized: true),
.streaming(_, _, _, _, iteratorInitialized: true),
.sourceFinished(_, iteratorInitialized: true, _):
.sourceFinished(_, iteratorInitialized: true, _),
.cancelled(iteratorInitialized: true):
// An iterator was created and we deinited the sequence.
// This is an expected pattern and we just continue on normal.
return .none
Expand All @@ -673,6 +693,7 @@ extension NIOThrowingAsyncSequenceProducer {
case .initial(_, iteratorInitialized: true),
.streaming(_, _, _, _, iteratorInitialized: true),
.sourceFinished(_, iteratorInitialized: true, _),
.cancelled(iteratorInitialized: true),
.finished(iteratorInitialized: true):
// Our sequence is a unicast sequence and does not support multiple AsyncIterator's
fatalError("NIOThrowingAsyncSequenceProducer allows only a single AsyncIterator to be created")
Expand All @@ -694,6 +715,10 @@ extension NIOThrowingAsyncSequenceProducer {
iteratorInitialized: true
)

case .cancelled(iteratorInitialized: false):
// An iterator needs to be initialized before we can be cancelled.
preconditionFailure("Internal inconsistency")

case .sourceFinished(let buffer, false, let failure):
// The first and only iterator was initialized.
self._state = .sourceFinished(
Expand Down Expand Up @@ -727,13 +752,15 @@ extension NIOThrowingAsyncSequenceProducer {
switch self._state {
case .initial(_, iteratorInitialized: false),
.streaming(_, _, _, _, iteratorInitialized: false),
.sourceFinished(_, iteratorInitialized: false, _):
.sourceFinished(_, iteratorInitialized: false, _),
.cancelled(iteratorInitialized: false):
// An iterator needs to be initialized before it can be deinitialized.
preconditionFailure("Internal inconsistency")

case .initial(_, iteratorInitialized: true),
.streaming(_, _, _, _, iteratorInitialized: true),
.sourceFinished(_, iteratorInitialized: true, _):
.sourceFinished(_, iteratorInitialized: true, _),
.cancelled(iteratorInitialized: true):
// An iterator was created and deinited. Since we only support
// a single iterator we can now transition to finish and inform the delegate.
self._state = .finished(iteratorInitialized: true)
Expand Down Expand Up @@ -861,7 +888,7 @@ extension NIOThrowingAsyncSequenceProducer {

return .init(shouldProduceMore: shouldProduceMore)

case .sourceFinished, .finished:
case .cancelled, .sourceFinished, .finished:
// If the source has finished we are dropping the elements.
return .returnDropped

Expand Down Expand Up @@ -913,7 +940,7 @@ extension NIOThrowingAsyncSequenceProducer {

return .none

case .sourceFinished, .finished:
case .cancelled, .sourceFinished, .finished:
// If the source has finished, finishing again has no effect.
return .none

Expand Down Expand Up @@ -968,11 +995,14 @@ extension NIOThrowingAsyncSequenceProducer {
return .resumeContinuationWithCancellationErrorAndCallDidTerminate(continuation)

case .streaming(_, _, continuation: .none, _, let iteratorInitialized):
self._state = .finished(iteratorInitialized: iteratorInitialized)
// We may have elements in the buffer, which is why we have no continuation
// waiting. We must store the cancellation error to hand it out on the next
// next() call.
self._state = .cancelled(iteratorInitialized: iteratorInitialized)

return .callDidTerminate

case .sourceFinished, .finished:
case .cancelled, .sourceFinished, .finished:
// If the source has finished, finishing again has no effect.
return .none

Expand All @@ -992,6 +1022,8 @@ extension NIOThrowingAsyncSequenceProducer {
/// Indicates that the `Failure` should be returned to the caller and
/// that ``NIOAsyncSequenceProducerDelegate/didTerminate()`` should be called.
case returnFailureAndCallDidTerminate(Failure?)
/// Indicates that the next call to AsyncSequence got cancelled
case returnCancellationError
/// Indicates that the `nil` should be returned to the caller.
case returnNil
/// Indicates that the `Task` of the caller should be suspended.
Expand Down Expand Up @@ -1075,6 +1107,10 @@ extension NIOThrowingAsyncSequenceProducer {
return .returnFailureAndCallDidTerminate(failure)
}

case .cancelled(let iteratorInitialized):
self._state = .finished(iteratorInitialized: iteratorInitialized)
return .returnCancellationError

case .finished:
return .returnNil

Expand Down Expand Up @@ -1119,7 +1155,7 @@ extension NIOThrowingAsyncSequenceProducer {
return .none
}

case .streaming(_, _, .some(_), _, _), .sourceFinished, .finished:
case .streaming(_, _, .some(_), _, _), .sourceFinished, .finished, .cancelled:
preconditionFailure("This should have already been handled by `next()`")

case .modifying:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,36 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase {

XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate])
}

func testIteratorThrows_whenCancelled() async {
_ = self.source.yield(contentsOf: Array(0..<100))
await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
var counter = 0
guard let sequence = self.sequence else {
return XCTFail("Expected to have an AsyncSequence")
}

do {
for try await next in sequence {
XCTAssertEqual(next, counter)
counter += 1
}
XCTFail("Expected that this throws")
} catch is CancellationError {
// expected
} catch {
XCTFail("Unexpected error: \(error)")
}

XCTAssertLessThan(counter, 100)
}

group.cancelAll()
}

XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate])
}
}

// This is needed until async let is supported to be used in autoclosures
Expand Down