From 8a1460c80b43a61332554b7f36e33ca246ca7830 Mon Sep 17 00:00:00 2001 From: Martin Troup Date: Tue, 22 Mar 2022 15:30:50 +0100 Subject: [PATCH] Fix for ShareReplay operator. --- Sources/Subjects/ReplaySubject.swift | 74 ++++++++++++++++++---------- Tests/ShareReplayTests.swift | 31 ++++++++++++ 2 files changed, 79 insertions(+), 26 deletions(-) diff --git a/Sources/Subjects/ReplaySubject.swift b/Sources/Subjects/ReplaySubject.swift index a97bec3..5d836c1 100644 --- a/Sources/Subjects/ReplaySubject.swift +++ b/Sources/Subjects/ReplaySubject.swift @@ -41,17 +41,17 @@ public final class ReplaySubject: Subject { let subscriptions: [Subscription>] do { - lock.lock() - defer { lock.unlock() } + lock.lock() + defer { lock.unlock() } - guard isActive else { return } + guard isActive else { return } - buffer.append(value) - if buffer.count > bufferSize { - buffer.removeFirst() - } + buffer.append(value) + if buffer.count > bufferSize { + buffer.removeFirst() + } - subscriptions = self.subscriptions + subscriptions = self.subscriptions } subscriptions.forEach { $0.forwardValueToBuffer(value) } @@ -81,25 +81,30 @@ public final class ReplaySubject: Subject { public func receive(subscriber: Subscriber) where Failure == Subscriber.Failure, Output == Subscriber.Input { let subscriberIdentifier = subscriber.combineIdentifier - let subscription = Subscription(downstream: AnySubscriber(subscriber)) { [weak self] in - self?.completeSubscriber(withIdentifier: subscriberIdentifier) - } - - let buffer: [Output] - let completion: Subscribers.Completion? + let subscription = Subscription( + downstream: AnySubscriber(subscriber), + cancellationHandler: { [weak self] in + self?.completeSubscriber(withIdentifier: subscriberIdentifier) + }, + requestReplay: { [weak self] in + let buffer: [Output]? + let completion: Subscribers.Completion? - do { - lock.lock() - defer { lock.unlock() } + do { + self?.lock.lock() + defer { self?.lock.unlock() } - subscriptions.append(subscription) + buffer = self?.buffer + completion = self?.completion + } - buffer = self.buffer - completion = self.completion - } + return (buffer, completion) + } + ) + subscriptions.append(subscription) + // (*) It was called here. subscriber.receive(subscription: subscription) - subscription.replay(buffer, completion: completion) } private func completeSubscriber(withIdentifier subscriberIdentifier: CombineIdentifier) { @@ -115,17 +120,25 @@ extension ReplaySubject { final class Subscription: Combine.Subscription where Output == Downstream.Input, Failure == Downstream.Failure { private var demandBuffer: DemandBuffer? private var cancellationHandler: (() -> Void)? + private var requestReplay: (() -> (buffer: [Output]?, completion: Subscribers.Completion?))? + + private var isActive = false fileprivate let innerSubscriberIdentifier: CombineIdentifier - init(downstream: Downstream, cancellationHandler: (() -> Void)?) { + init( + downstream: Downstream, + cancellationHandler: (() -> Void)?, + requestReplay: (() -> ([Output]?, Subscribers.Completion?))? + ) { self.demandBuffer = DemandBuffer(subscriber: downstream) self.innerSubscriberIdentifier = downstream.combineIdentifier self.cancellationHandler = cancellationHandler + self.requestReplay = requestReplay } - func replay(_ buffer: [Output], completion: Subscribers.Completion?) { - buffer.forEach(forwardValueToBuffer) + func replay(_ buffer: [Output]?, completion: Subscribers.Completion?) { + buffer?.forEach(forwardValueToBuffer) if let completion = completion { forwardCompletionToBuffer(completion) @@ -138,11 +151,20 @@ extension ReplaySubject { func forwardCompletionToBuffer(_ completion: Subscribers.Completion) { demandBuffer?.complete(completion: completion) - cancel() + + if isActive { + cancel() + } } func request(_ demand: Subscribers.Demand) { _ = demandBuffer?.demand(demand) + + isActive = true + + if let replay = requestReplay?() { + self.replay(replay.buffer, completion: replay.completion) + } } func cancel() { diff --git a/Tests/ShareReplayTests.swift b/Tests/ShareReplayTests.swift index 76121d9..b35b34d 100644 --- a/Tests/ShareReplayTests.swift +++ b/Tests/ShareReplayTests.swift @@ -241,5 +241,36 @@ final class ShareReplayTests: XCTestCase { XCTAssertEqual(completions, [.finished]) XCTAssertNil(weakSource) } + + func testSequentialUpstreamWithShareReplay() { + let publisher = Just(1) + .eraseToAnyPublisher() + .share(replay: 1) + + var valueReceived = false + var finishedReceived = false + + Publishers.Zip(publisher, publisher) + .sink( + receiveCompletion: { completion in + switch completion { + case .finished: + finishedReceived = true + case let .failure(error): + XCTFail("Unexpected completion - failure: \(error).") + } + }, + receiveValue: { leftValue, rightValue in + XCTAssertEqual(leftValue, 1) + XCTAssertEqual(rightValue, 1) + + valueReceived = true + } + ) + .store(in: &subscriptions) + + XCTAssertTrue(valueReceived) + XCTAssertTrue(finishedReceived) + } } #endif