diff --git a/servicetalk-concurrent-api/gradle/checkstyle/suppressions.xml b/servicetalk-concurrent-api/gradle/checkstyle/suppressions.xml index 679ac02eaa..8979886a5c 100644 --- a/servicetalk-concurrent-api/gradle/checkstyle/suppressions.xml +++ b/servicetalk-concurrent-api/gradle/checkstyle/suppressions.xml @@ -23,4 +23,5 @@ files="docs[\\/]modules[\\/]ROOT[\\/]assets[\\/]images[\\/].+\.svg"/> + diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/Publisher.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/Publisher.java index b4e87d1d18..d71f518d9f 100644 --- a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/Publisher.java +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/Publisher.java @@ -287,11 +287,54 @@ public final Publisher scanWith(Supplier initial, BiFunction Type of the items emitted by the returned {@link Publisher}. * @return A {@link Publisher} that transforms elements emitted by this {@link Publisher} into a different type. * @see ReactiveX scan operator. + * @deprecated Use {@link #scanWithMapper(Supplier)}. */ + @Deprecated public final Publisher scanWith(Supplier> mapperSupplier) { return new ScanWithPublisher<>(this, mapperSupplier); } + /** + * Apply a function to each {@link Subscriber#onNext(Object)} emitted by this {@link Publisher} as well as + * optionally concat one {@link Subscriber#onNext(Object)} signal before the terminal signal is emitted downstream. + *

+ * This method provides a data transformation in sequential programming similar to: + *

{@code
+     *     List results = ...;
+     *     ScanWithLifetimeMapperExt mapper = mapperSupplier.get();
+     *     MappedTerminal mapped = null;
+     *     try {
+     *       for (T t : resultOfThisPublisher()) {
+     *         results.add(mapper.mapOnNext(t));
+     *       }
+     *     } catch (Throwable cause) {
+     *       mapped = mapper.mapOnError(cause);
+     *       if (mapped == null) {
+     *         throw cause;
+     *       }
+     *     }
+     *     if (mapped == null) {
+     *       mapped = mapper.mapOnComplete();
+     *     }
+     *     if (mapped.onNextValid()) {
+     *       results.add(mapped.onNext());
+     *     }
+     *     if (mapped.terminal() != null) {
+     *       throw mapped.terminal();
+     *     }
+     *     return results;
+     * }
+ * @param mapperSupplier Invoked on each {@link PublisherSource#subscribe(Subscriber)} and maintains any necessary + * state for the mapping/accumulation for each {@link Subscriber}. + * @param Type of the items emitted by the returned {@link Publisher}. + * @return A {@link Publisher} that transforms elements emitted by this {@link Publisher} into a different type. + * @see ReactiveX scan operator. + */ + public final Publisher scanWithMapper( + Supplier> mapperSupplier) { + return new ScanWithPublisher<>(mapperSupplier, this); + } + /** * Apply a function to each {@link Subscriber#onNext(Object)} emitted by this {@link Publisher} as well as * optionally concat one {@link Subscriber#onNext(Object)} signal before the terminal signal is emitted downstream. @@ -329,12 +372,63 @@ public final Publisher scanWith(Supplier Type of the items emitted by the returned {@link Publisher}. * @return A {@link Publisher} that transforms elements emitted by this {@link Publisher} into a different type. * @see ReactiveX scan operator. + * @deprecated Use {@link #scanWithLifetimeMapper(Supplier)}. */ + @Deprecated public final Publisher scanWithLifetime( Supplier> mapperSupplier) { return new ScanWithLifetimePublisher<>(this, mapperSupplier); } + /** + * Apply a function to each {@link Subscriber#onNext(Object)} emitted by this {@link Publisher} as well as + * optionally concat one {@link Subscriber#onNext(Object)} signal before the terminal signal is emitted downstream. + * Additionally the {@link ScanLifetimeMapper#afterFinally()} method will be invoked on terminal or cancel + * signals which enables cleanup of state (if required). This provides a similar lifetime management as + * {@link TerminalSignalConsumer}. + * + *

+ * This method provides a data transformation in sequential programming similar to: + *

{@code
+     *     List results = ...;
+     *     ScanWithLifetimeMapperExt mapper = mapperSupplier.get();
+     *     try {
+     *       MappedTerminal mapped = null;
+     *       try {
+     *         for (T t : resultOfThisPublisher()) {
+     *           results.add(mapper.mapOnNext(t));
+     *         }
+     *       } catch (Throwable cause) {
+     *         mapped = mapper.mapOnError(cause);
+     *         if (mapped == null) {
+     *           throw cause;
+     *         }
+     *       }
+     *       if (mapped == null) {
+     *         mapped = mapper.mapOnComplete();
+     *       }
+     *       if (mapped.onNextValid()) {
+     *         results.add(mapped.onNext());
+     *       }
+     *       if (mapped.terminal() != null) {
+     *          throw mapped.terminal();
+     *       }
+     *     } finally {
+     *       mapper.afterFinally();
+     *     }
+     *     return results;
+     * }
+ * @param mapperSupplier Invoked on each {@link PublisherSource#subscribe(Subscriber)} and maintains any necessary + * state for the mapping/accumulation for each {@link Subscriber}. + * @param Type of the items emitted by the returned {@link Publisher}. + * @return A {@link Publisher} that transforms elements emitted by this {@link Publisher} into a different type. + * @see ReactiveX scan operator. + */ + public final Publisher scanWithLifetimeMapper( + Supplier> mapperSupplier) { + return new ScanWithLifetimePublisher<>(mapperSupplier, this); + } + /** * Transform errors emitted on this {@link Publisher} into a {@link Subscriber#onComplete()} signal * (e.g. swallows the error). diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanLifetimeMapper.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanLifetimeMapper.java new file mode 100644 index 0000000000..508c2acb8b --- /dev/null +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanLifetimeMapper.java @@ -0,0 +1,38 @@ +/* + * Copyright © 2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.concurrent.api; + +import io.servicetalk.concurrent.PublisherSource.Subscriber; +import io.servicetalk.concurrent.PublisherSource.Subscription; + +import java.util.function.Supplier; + +/** + * Provides the ability to transform (aka map) signals emitted via + * the {@link Publisher#scanWithLifetimeMapper(Supplier)} operator, as well as the ability to cleanup state + * via {@link #afterFinally}. + * @param Type of items emitted by the {@link Publisher} this operator is applied. + * @param Type of items emitted by this operator. + */ +public interface ScanLifetimeMapper extends ScanMapper { + /** + * Invoked after a terminal signal {@link Subscriber#onError(Throwable)} or + * {@link Subscriber#onComplete()} or {@link Subscription#cancel()}. + * No further interaction will occur with the {@link ScanLifetimeMapper} to prevent use-after-free + * on internal state. + */ + void afterFinally(); +} diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanMapper.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanMapper.java new file mode 100644 index 0000000000..f49dea9711 --- /dev/null +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanMapper.java @@ -0,0 +1,106 @@ +/* + * Copyright © 2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.concurrent.api; + +import io.servicetalk.concurrent.PublisherSource.Subscriber; + +import java.util.function.Supplier; +import javax.annotation.Nullable; + +/** + * Provides the ability to transform (aka map) signals emitted via the {@link Publisher#scanWithMapper(Supplier)} + * operator. + * @param Type of items emitted by the {@link Publisher} this operator is applied. + * @param Type of items emitted by this operator. + */ +public interface ScanMapper { + /** + * Invoked on each {@link Subscriber#onNext(Object)} signal and maps from type {@link T} to type {@link R}. + * @param next The next element emitted from {@link Subscriber#onNext(Object)}. + * @return The result of mapping {@code next}. + */ + @Nullable + R mapOnNext(@Nullable T next); + + /** + * Invoked when a {@link Subscriber#onError(Throwable)} signal is received and can map the current state into an + * object of type {@link R} which will be emitted downstream as {@link Subscriber#onNext(Object)}, followed by + * a terminal signal. + *

+ * If this method throws the exception will be propagated downstream via {@link Subscriber#onError(Throwable)}. + * @param cause The cause from upstream {@link Subscriber#onError(Throwable)}. + * @return + *

    + *
  • {@code null} if no mapping is required and {@code cause} is propagated to + * {@link Subscriber#onError(Throwable)}
  • + *
  • non-{@code null} will propagate {@link MappedTerminal#onNext()} to {@link Subscriber#onNext(Object)} + * then will terminate with {@link MappedTerminal#terminal()}
  • + *
+ * @throws Throwable If an exception occurs, which will be propagated downstream via + * {@link Subscriber#onError(Throwable)}. + */ + @Nullable + MappedTerminal mapOnError(Throwable cause) throws Throwable; + + /** + * Invoked when a {@link Subscriber#onComplete()} signal is received and can map the current state into an + * object of type {@link R} which will be emitted downstream as {@link Subscriber#onNext(Object)}, followed by + * a terminal signal. + *

+ * If this method throws the exception will be propagated downstream via {@link Subscriber#onError(Throwable)}. + * @return + *

    + *
  • {@code null} if no mapping is required and {@code cause} is propagated to + * {@link Subscriber#onError(Throwable)}
  • + *
  • non-{@code null} will propagate {@link MappedTerminal#onNext()} to {@link Subscriber#onNext(Object)} + * then will terminate with {@link MappedTerminal#terminal()}
  • + *
+ * @throws Throwable If an exception occurs, which will be propagated downstream via + * {@link Subscriber#onError(Throwable)}. + */ + @Nullable + MappedTerminal mapOnComplete() throws Throwable; + + /** + * Result of a mapping operation of a terminal signal. + * @param The mapped result type. + */ + interface MappedTerminal { + /** + * Get the signal to be delivered to {@link Subscriber#onNext(Object)} if {@link #onNextValid()}. + * @return the signal to be delivered to {@link Subscriber#onNext(Object)} if {@link #onNextValid()}. + */ + @Nullable + R onNext(); + + /** + * Determine if {@link #onNext()} is valid and should be propagated downstream. + * @return {@code true} to propagate {@link #onNext()}, {@code false} will only propagate {@link #terminal()}. + */ + boolean onNextValid(); + + /** + * The terminal event to propagate. + * @return + *
    + *
  • {@code null} means {@link Subscriber#onComplete()}
  • + *
  • non-{@code null} will propagate as {@link Subscriber#onError(Throwable)}
  • + *
+ */ + @Nullable + Throwable terminal(); + } +} diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithLifetimeMapper.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithLifetimeMapper.java index f9bf15be4b..9e1fb14c60 100644 --- a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithLifetimeMapper.java +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithLifetimeMapper.java @@ -25,9 +25,10 @@ * via {@link #afterFinally}. * @param Type of items emitted by the {@link Publisher} this operator is applied. * @param Type of items emitted by this operator. + * @deprecated Use {@link ScanLifetimeMapper}. */ +@Deprecated public interface ScanWithLifetimeMapper extends ScanWithMapper { - /** * Invoked after a terminal signal {@link PublisherSource.Subscriber#onError(Throwable)} or * {@link PublisherSource.Subscriber#onComplete()} or {@link PublisherSource.Subscription#cancel()}. diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithLifetimePublisher.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithLifetimePublisher.java index d041219b4c..c0da9a8c4c 100644 --- a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithLifetimePublisher.java +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithLifetimePublisher.java @@ -15,6 +15,8 @@ */ package io.servicetalk.concurrent.api; +import io.servicetalk.concurrent.api.ScanWithPublisher.ScanMapperAdapter; +import io.servicetalk.concurrent.api.ScanWithPublisher.ScanWithSubscriber; import io.servicetalk.context.api.ContextMap; import org.slf4j.Logger; @@ -31,10 +33,16 @@ final class ScanWithLifetimePublisher extends AbstractNoHandleSubscribePub private static final Logger LOGGER = LoggerFactory.getLogger(ScanWithLifetimePublisher.class); private final Publisher original; - private final Supplier> mapperSupplier; + private final Supplier> mapperSupplier; ScanWithLifetimePublisher(Publisher original, + @SuppressWarnings("deprecation") Supplier> mapperSupplier) { + this(new SupplierScanMapperLifetime<>(mapperSupplier), original); + } + + ScanWithLifetimePublisher(Supplier> mapperSupplier, + Publisher original) { this.mapperSupplier = requireNonNull(mapperSupplier); this.original = original; } @@ -52,10 +60,10 @@ void handleSubscribe(final Subscriber subscriber, } /** - * Wraps the {@link io.servicetalk.concurrent.api.ScanWithPublisher.ScanWithSubscriber} to provide mutual exclusion - * to the {@link ScanWithLifetimeMapper#afterFinally()} call and guarantee a 'no-use-after-free' contract. + * Wraps the {@link ScanWithSubscriber} to provide mutual exclusion to the {@link ScanLifetimeMapper#afterFinally()} + * call and guarantee a 'no-use-after-free' contract. */ - private static final class ScanWithLifetimeSubscriber extends ScanWithPublisher.ScanWithSubscriber { + private static final class ScanWithLifetimeSubscriber extends ScanWithSubscriber { private static final int STATE_UNLOCKED = 0; private static final int STATE_BUSY = 1; private static final int STATE_FINALIZED = 2; @@ -67,13 +75,13 @@ private static final class ScanWithLifetimeSubscriber extends ScanWithPubl private volatile int state = STATE_UNLOCKED; - private final ScanWithLifetimeMapper mapper; + private final ScanLifetimeMapper mapper; ScanWithLifetimeSubscriber(final Subscriber subscriber, - final ScanWithLifetimeMapper mapper, + final ScanLifetimeMapper mapper, final ContextMap contextMap, final AsyncContextProvider contextProvider) { super(subscriber, mapper, contextProvider, contextMap); - this.mapper = mapper; + this.mapper = requireNonNull(mapper); } @Override @@ -180,24 +188,12 @@ public void onComplete() { } @Override - protected void deliverOnCompleteFromSubscription(final Subscriber subscriber) { - if (shouldDeliverFromSubscription()) { - try { - super.deliverOnCompleteFromSubscription(subscriber); - } finally { - // Done, transit to FINALIZED. - // No need to CAS, we have exclusion, and any cancellations will hand-over finalization to us. - state = STATE_FINALIZED; - finalize0(); - } - } - } - - @Override - protected void deliverOnErrorFromSubscription(final Throwable t, final Subscriber subscriber) { + protected void deliverAllTerminalFromSubscription( + final ScanMapper.MappedTerminal mappedTerminal, + final Subscriber subscriber) { if (shouldDeliverFromSubscription()) { try { - super.deliverOnErrorFromSubscription(t, subscriber); + super.deliverAllTerminalFromSubscription(mappedTerminal, subscriber); } finally { // Done, transit to FINALIZED. // No need to CAS, we have exclusion, and any cancellations will hand-over finalization to us. @@ -252,4 +248,34 @@ private void finalize0() { } } } + + @SuppressWarnings("deprecation") + private static final class SupplierScanMapperLifetime implements + Supplier> { + private final Supplier> mapperSupplier; + + private SupplierScanMapperLifetime( + final Supplier> mapperSupplier) { + this.mapperSupplier = requireNonNull(mapperSupplier); + } + + @Override + public ScanLifetimeMapper get() { + return new ScanMapperLifetimeAdapter<>(mapperSupplier.get()); + } + } + + @SuppressWarnings("deprecation") + private static final class ScanMapperLifetimeAdapter + extends ScanMapperAdapter> + implements ScanLifetimeMapper { + ScanMapperLifetimeAdapter(final ScanWithLifetimeMapper mapper) { + super(mapper); + } + + @Override + public void afterFinally() { + mapper.afterFinally(); + } + } } diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithMapper.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithMapper.java index 5399f23509..036ef6f76b 100644 --- a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithMapper.java +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithMapper.java @@ -24,7 +24,9 @@ * Provides the ability to transform (aka map) signals emitted via the {@link Publisher#scanWith(Supplier)} operator. * @param Type of items emitted by the {@link Publisher} this operator is applied. * @param Type of items emitted by this operator. + * @deprecated Use {@link ScanMapper}. */ +@Deprecated public interface ScanWithMapper { /** * Invoked on each {@link Subscriber#onNext(Object)} signal and maps from type {@link T} to type {@link R}. diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithPublisher.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithPublisher.java index 17d62cde31..396f07731f 100644 --- a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithPublisher.java +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithPublisher.java @@ -15,6 +15,7 @@ */ package io.servicetalk.concurrent.api; +import io.servicetalk.concurrent.api.ScanMapper.MappedTerminal; import io.servicetalk.concurrent.internal.FlowControlUtils; import io.servicetalk.context.api.ContextMap; @@ -31,14 +32,20 @@ final class ScanWithPublisher extends AbstractNoHandleSubscribePublisher { private final Publisher original; - private final Supplier> mapperSupplier; + private final Supplier> mapperSupplier; ScanWithPublisher(Publisher original, Supplier initial, BiFunction accumulator) { - this(original, new SupplierScanWithMapper<>(initial, accumulator)); + this(new SupplierScanWithMapper<>(initial, accumulator), original); } ScanWithPublisher(Publisher original, + @SuppressWarnings("deprecation") Supplier> mapperSupplier) { + this(new SupplierScanMapper<>(mapperSupplier), original); + } + + ScanWithPublisher(Supplier> mapperSupplier, + Publisher original) { this.mapperSupplier = requireNonNull(mapperSupplier); this.original = original; } @@ -63,7 +70,7 @@ static class ScanWithSubscriber implements Subscriber { private static final long TERMINATED = Long.MIN_VALUE; private static final long TERMINAL_PENDING = TERMINATED + 1; /** - * We don't want to invoke {@link ScanWithMapper#mapOnError(Throwable)} for invalid demand because we may never + * We don't want to invoke {@link ScanMapper#mapOnError(Throwable)} for invalid demand because we may never * get enough demand to deliver an {@link #onNext(Object)} to the downstream subscriber. {@code -1} to avoid * {@link #demand} underflow in onNext (in case the source doesn't deliver a timely error). */ @@ -72,16 +79,17 @@ static class ScanWithSubscriber implements Subscriber { private final Subscriber subscriber; private final ContextMap contextMap; private final AsyncContextProvider contextProvider; - private final ScanWithMapper mapper; + private final ScanMapper mapper; private volatile long demand; /** - * Retains the {@link #onError(Throwable)} cause for use in the {@link Subscription}. + * Retains the {@link MappedTerminal} cause for use in the {@link Subscription}. * Happens-before relationship with {@link #demand} means no volatile or other synchronization required. */ @Nullable - private Throwable errorCause; + private MappedTerminal mappedTerminal; - ScanWithSubscriber(final Subscriber subscriber, final ScanWithMapper mapper, + ScanWithSubscriber(final Subscriber subscriber, + final ScanMapper mapper, final AsyncContextProvider contextProvider, final ContextMap contextMap) { this.subscriber = subscriber; this.contextProvider = contextProvider; @@ -103,11 +111,8 @@ public void request(final long n) { } else if (demandUpdater.getAndAccumulate(ScanWithSubscriber.this, n, FlowControlUtils::addWithOverflowProtectionIfNotNegative) == TERMINAL_PENDING) { demand = TERMINATED; - if (errorCause != null) { - deliverOnErrorFromSubscription(errorCause, newOffloadedSubscriber()); - } else { - deliverOnCompleteFromSubscription(newOffloadedSubscriber()); - } + assert mappedTerminal != null; + deliverAllTerminalFromSubscription(mappedTerminal, newOffloadedSubscriber()); } else { subscription.request(n); } @@ -157,116 +162,136 @@ public void onComplete() { /** * Executes the on-error signal and returns {@code true} if demand was sufficient to deliver the result of the - * mapped {@code Throwable} with {@link ScanWithMapper#mapOnError(Throwable)}. + * mapped {@code Throwable} and terminal signal. * * @param t The throwable to propagate * @return {@code true} if the demand was sufficient to deliver the result of the mapped {@code Throwable} with - * {@link ScanWithMapper#mapOnError(Throwable)}. + * terminal signal. */ protected boolean onError0(final Throwable t) { - errorCause = t; - final boolean doMap; try { - doMap = mapper.mapTerminal(); + mappedTerminal = mapper.mapOnError(t); } catch (Throwable cause) { subscriber.onError(cause); return true; } - if (doMap) { - for (;;) { - final long currDemand = demand; - if (currDemand > 0 && demandUpdater.compareAndSet(this, currDemand, TERMINATED)) { - deliverOnError(t, subscriber); - break; - } else if (currDemand == 0 && demandUpdater.compareAndSet(this, currDemand, TERMINAL_PENDING)) { - return false; - } else if (currDemand < 0) { - // Either we previously saw invalid request n, or upstream has sent a duplicate terminal event. - // In either circumstance we propagate the error downstream and bail. - subscriber.onError(t); - break; - } - } - } else { - demand = TERMINATED; - subscriber.onError(t); - } + if (mappedTerminal != null) { + return deliverAllTerminal(mappedTerminal, subscriber, t); + } + demand = TERMINATED; + subscriber.onError(t); return true; } /** * Executes the on-completed signal and returns {@code true} if demand was sufficient to deliver the concat item - * from {@link ScanWithMapper#mapOnComplete()} downstream. + * from {@link ScanMapper#mapOnComplete()} downstream. * * @return {@code true} if demand was sufficient to deliver the concat item from - * {@link ScanWithMapper#mapOnComplete()} downstream. + * {@link ScanMapper#mapOnComplete()} downstream. */ protected boolean onComplete0() { - final boolean doMap; try { - doMap = mapper.mapTerminal(); + mappedTerminal = mapper.mapOnComplete(); } catch (Throwable cause) { subscriber.onError(cause); return true; } - if (doMap) { + if (mappedTerminal != null) { + return deliverAllTerminal(mappedTerminal, subscriber, null); + } + demand = TERMINATED; + subscriber.onComplete(); + return true; + } + + protected void onCancel() { + //NOOP + } + + protected void deliverAllTerminalFromSubscription(final MappedTerminal mappedTerminal, + final Subscriber subscriber) { + deliverOnNextAndTerminal(mappedTerminal, subscriber); + } + + private boolean deliverAllTerminal(final MappedTerminal mappedTerminal, + final Subscriber subscriber, + @Nullable final Throwable originalCause) { + final boolean onNextValid; + try { + onNextValid = mappedTerminal.onNextValid(); + } catch (Throwable cause) { + subscriber.onError(cause); + return true; + } + if (onNextValid) { for (;;) { final long currDemand = demand; if (currDemand > 0 && demandUpdater.compareAndSet(this, currDemand, TERMINATED)) { - deliverOnComplete(subscriber); + deliverOnNextAndTerminal(mappedTerminal, subscriber); break; } else if (currDemand == 0 && demandUpdater.compareAndSet(this, currDemand, TERMINAL_PENDING)) { return false; } else if (currDemand < 0) { - // Either we previously saw invalid request n, or upstream has sent a duplicate terminal event. - // In either circumstance we propagate the error downstream and bail. - subscriber.onError(new IllegalStateException("onComplete with invalid demand: " + currDemand)); + // Either we previously saw invalid request n, or upstream has sent a duplicate terminal + // event. In either circumstance we propagate the error downstream and bail. + subscriber.onError(originalCause != null ? originalCause : + new IllegalStateException("onComplete with invalid demand: " + currDemand)); break; } } } else { demand = TERMINATED; - subscriber.onComplete(); + deliverTerminal(mappedTerminal, subscriber); } - return true; } - protected void onCancel() { - //NOOP - } - - protected void deliverOnErrorFromSubscription(Throwable t, Subscriber subscriber) { - deliverOnError(t, subscriber); - } - - protected void deliverOnCompleteFromSubscription(Subscriber subscriber) { - deliverOnComplete(subscriber); - } - - private void deliverOnError(Throwable t, Subscriber subscriber) { + private void deliverTerminal(final MappedTerminal mappedTerminal, + final Subscriber subscriber) { + final Throwable cause; try { - subscriber.onNext(mapper.mapOnError(t)); - } catch (Throwable cause) { - subscriber.onError(cause); + cause = mappedTerminal.terminal(); + } catch (Throwable cause2) { + subscriber.onError(cause2); return; } - subscriber.onComplete(); + if (cause == null) { + subscriber.onComplete(); + } else { + subscriber.onError(cause); + } } - private void deliverOnComplete(Subscriber subscriber) { + private void deliverOnNextAndTerminal(final MappedTerminal mappedTerminal, + final Subscriber subscriber) { try { - subscriber.onNext(mapper.mapOnComplete()); + assert mappedTerminal.onNextValid(); + subscriber.onNext(mappedTerminal.onNext()); } catch (Throwable cause) { subscriber.onError(cause); return; } - subscriber.onComplete(); + deliverTerminal(mappedTerminal, subscriber); + } + } + + @SuppressWarnings("deprecation") + private static final class SupplierScanMapper implements Supplier> { + private final Supplier> mapperSupplier; + + SupplierScanMapper(Supplier> mapperSupplier) { + this.mapperSupplier = requireNonNull(mapperSupplier); + } + + @Override + public ScanMapper get() { + return new ScanMapperAdapter<>(mapperSupplier.get()); } } - private static final class SupplierScanWithMapper implements Supplier> { + private static final class SupplierScanWithMapper implements Supplier> { private final BiFunction accumulator; private final Supplier initial; @@ -276,8 +301,8 @@ private static final class SupplierScanWithMapper implements Supplier get() { - return new ScanWithMapper() { + public ScanMapper get() { + return new ScanMapper() { @Nullable private R state = initial.get(); @@ -287,25 +312,72 @@ public R mapOnNext(@Nullable final T next) { return state; } + @Nullable @Override - public R mapOnError(final Throwable cause) { - throw newMapTerminalUnsupported(); - } - - @Override - public R mapOnComplete() { - throw newMapTerminalUnsupported(); + public MappedTerminal mapOnError(final Throwable cause) { + return null; } + @Nullable @Override - public boolean mapTerminal() { - return false; + public MappedTerminal mapOnComplete() { + return null; } }; } + } + + @SuppressWarnings("deprecation") + static class ScanMapperAdapter> + implements ScanMapper { + final X mapper; + + ScanMapperAdapter(final X mapper) { + this.mapper = requireNonNull(mapper); + } + + @Nullable + @Override + public R mapOnNext(@Nullable final T next) { + return mapper.mapOnNext(next); + } - private static IllegalStateException newMapTerminalUnsupported() { - throw new IllegalStateException("mapTerminal returns false, this method should never be invoked!"); + @Nullable + @Override + public MappedTerminal mapOnError(final Throwable cause) throws Throwable { + return mapper.mapTerminal() ? new FixedMappedTerminal<>(mapper.mapOnError(cause)) : null; + } + + @Nullable + @Override + public MappedTerminal mapOnComplete() { + return mapper.mapTerminal() ? new FixedMappedTerminal<>(mapper.mapOnComplete()) : null; + } + } + + private static final class FixedMappedTerminal implements MappedTerminal { + @Nullable + private final R onNext; + + private FixedMappedTerminal(@Nullable final R onNext) { + this.onNext = onNext; + } + + @Nullable + @Override + public R onNext() { + return onNext; + } + + @Override + public boolean onNextValid() { + return true; + } + + @Nullable + @Override + public Throwable terminal() { + return null; } } } diff --git a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/ScanWithPublisherTest.java b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/ScanWithPublisherTest.java index 1d77eb9ece..165d4ca413 100644 --- a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/ScanWithPublisherTest.java +++ b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/ScanWithPublisherTest.java @@ -17,6 +17,7 @@ import io.servicetalk.concurrent.PublisherSource; import io.servicetalk.concurrent.PublisherSource.Subscription; +import io.servicetalk.concurrent.api.ScanMapper.MappedTerminal; import io.servicetalk.concurrent.internal.DeliberateException; import io.servicetalk.concurrent.test.internal.TestPublisherSubscriber; @@ -39,7 +40,7 @@ import static io.servicetalk.concurrent.api.SourceAdapters.toSource; import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION; import static io.servicetalk.concurrent.internal.SubscriberUtils.newExceptionForInvalidRequestN; -import static io.servicetalk.utils.internal.PlatformDependent.throwException; +import static io.servicetalk.utils.internal.ThrowableUtils.throwException; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.contains; @@ -48,7 +49,6 @@ import static org.hamcrest.Matchers.nullValue; class ScanWithPublisherTest { - @Test void scanWithComplete() { scanWithNoTerminalMapper(true); @@ -101,8 +101,8 @@ public void request(final long n) { public void cancel() { } }); - toSource(fromSource(syncNoReentryProtectionSource).scanWithLifetime(() - -> new ScanWithLifetimeMapper() { + toSource(fromSource(syncNoReentryProtectionSource).scanWithLifetimeMapper(() + -> new ScanLifetimeMapper() { @Override public void afterFinally() { finalizations.incrementAndGet(); @@ -117,22 +117,18 @@ public Integer mapOnNext(@Nullable final Integer next) { @Nullable @Override - public Integer mapOnError(final Throwable cause) { + public MappedTerminal mapOnError(final Throwable cause) { return null; } @Nullable @Override - public Integer mapOnComplete() { + public MappedTerminal mapOnComplete() { return null; } - - @Override - public boolean mapTerminal() { - return false; - } })).subscribe(new PublisherSource.Subscriber() { - Subscription subscription; + @Nullable + private Subscription subscription; @Override public void onSubscribe(final Subscription subscription) { @@ -142,6 +138,7 @@ public void onSubscribe(final Subscription subscription) { @Override public void onNext(@Nullable final Integer integer) { + assert subscription != null; subscription.request(1); } @@ -186,7 +183,8 @@ private static void scanOnNextTerminalNoConcat(boolean onNext, boolean onComplet final AtomicInteger finalizations = new AtomicInteger(0); PublisherSource.Processor processor = newPublisherProcessor(); TestPublisherSubscriber subscriber = new TestPublisherSubscriber<>(); - toSource(scanWithOperator(fromSource(processor), withLifetime, new ScanWithLifetimeMapper() { + toSource(scanWithOperator(fromSource(processor), withLifetime, + new ScanLifetimeMapper() { @Nullable @Override public Integer mapOnNext(@Nullable final Integer next) { @@ -195,19 +193,14 @@ public Integer mapOnNext(@Nullable final Integer next) { @Nullable @Override - public Integer mapOnError(final Throwable cause) { - throw new UnsupportedOperationException(); + public MappedTerminal mapOnError(final Throwable cause) { + return null; } @Nullable @Override - public Integer mapOnComplete() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean mapTerminal() { - return false; + public MappedTerminal mapOnComplete() { + return null; } @Override @@ -247,7 +240,7 @@ void onErrorConcatWithUpfrontDemand(boolean withLifetime) { } @ParameterizedTest(name = "{displayName} [{index}] {arguments}") - @ValueSource(booleans = {true, false}) + @ValueSource(booleans = {/*true,*/ false}) void onCompleteConcatDelayedDemand(boolean withLifetime) { terminalConcatWithDemand(false, true, withLifetime); } @@ -262,7 +255,8 @@ private static void terminalConcatWithDemand(boolean demandUpFront, boolean onCo final AtomicInteger finalizations = new AtomicInteger(0); PublisherSource.Processor processor = newPublisherProcessor(); TestPublisherSubscriber subscriber = new TestPublisherSubscriber<>(); - toSource(scanWithOperator(fromSource(processor), withLifetime, new ScanWithLifetimeMapper() { + toSource(scanWithOperator(fromSource(processor), withLifetime, + new ScanLifetimeMapper() { private int sum; @Override public Integer mapOnNext(@Nullable final Integer next) { @@ -273,18 +267,17 @@ public Integer mapOnNext(@Nullable final Integer next) { } @Override - public Integer mapOnError(final Throwable cause) { - return ++sum; + public MappedTerminal mapOnError(final Throwable cause) { + return mapTerminal(); } @Override - public Integer mapOnComplete() { - return ++sum; + public MappedTerminal mapOnComplete() { + return mapTerminal(); } - @Override - public boolean mapTerminal() { - return true; + private MappedTerminal mapTerminal() { + return new FixedMappedTerminal<>(++sum); } @Override @@ -322,25 +315,21 @@ void scanWithFinalizationOnCancel() { final AtomicInteger finalizations = new AtomicInteger(0); PublisherSource.Processor processor = newPublisherProcessor(); TestPublisherSubscriber subscriber = new TestPublisherSubscriber<>(); - toSource(scanWithOperator(fromSource(processor), true, new ScanWithLifetimeMapper() { + toSource(scanWithOperator(fromSource(processor), true, new ScanLifetimeMapper() { @Override public Integer mapOnNext(@Nullable final Integer next) { return next; } + @Nullable @Override - public Integer mapOnError(final Throwable cause) throws Throwable { - throw cause; - } - - @Override - public Integer mapOnComplete() { - return 5; + public MappedTerminal mapOnError(final Throwable cause) { + return null; } @Override - public boolean mapTerminal() { - return true; + public MappedTerminal mapOnComplete() { + return new FixedMappedTerminal<>(5); } @Override @@ -374,24 +363,25 @@ void scanWithFinalizationOnCancelDifferentThreads(final boolean interleaveCancel final CountDownLatch nextDeliveredResume = new CountDownLatch(1); final PublisherSource source = toSource(scanWithOperator(fromSource(processor), true, - new ScanWithLifetimeMapper() { + new ScanLifetimeMapper() { @Override public Integer mapOnNext(@Nullable final Integer next) { return next; } + @Nullable @Override - public Integer mapOnError(final Throwable cause) throws Throwable { + public MappedTerminal mapOnError(final Throwable cause) throws Throwable { throw cause; } + @Nullable @Override - public Integer mapOnComplete() { - return 5; + public MappedTerminal mapOnComplete() { + return mapTerminal() ? new FixedMappedTerminal<>(5) : null; } - @Override - public boolean mapTerminal() { + private boolean mapTerminal() { if (interleaveCancellation) { checkpoint.countDown(); try { @@ -469,7 +459,8 @@ private static void terminalThrowsHandled(boolean onComplete, boolean withLifeti final AtomicInteger finalizations = new AtomicInteger(0); PublisherSource.Processor processor = newPublisherProcessor(); TestPublisherSubscriber subscriber = new TestPublisherSubscriber<>(); - toSource(scanWithOperator(fromSource(processor), withLifetime, new ScanWithLifetimeMapper() { + toSource(scanWithOperator(fromSource(processor), withLifetime, + new ScanLifetimeMapper() { @Nullable @Override public Integer mapOnNext(@Nullable final Integer next) { @@ -478,21 +469,16 @@ public Integer mapOnNext(@Nullable final Integer next) { @Nullable @Override - public Integer mapOnError(final Throwable cause) throws Throwable { + public MappedTerminal mapOnError(final Throwable cause) throws Throwable { throw cause; } @Nullable @Override - public Integer mapOnComplete() { + public MappedTerminal mapOnComplete() { throw DELIBERATE_EXCEPTION; } - @Override - public boolean mapTerminal() { - return true; - } - @Override public void afterFinally() { finalizations.incrementAndGet(); @@ -527,7 +513,8 @@ private static void mapTerminalSignalThrows(boolean onComplete, boolean withLife final AtomicInteger finalizations = new AtomicInteger(0); PublisherSource.Processor processor = newPublisherProcessor(); TestPublisherSubscriber subscriber = new TestPublisherSubscriber<>(); - toSource(scanWithOperator(fromSource(processor), withLifetime, new ScanWithLifetimeMapper() { + toSource(scanWithOperator(fromSource(processor), withLifetime, + new ScanLifetimeMapper() { @Nullable @Override public Integer mapOnNext(@Nullable final Integer next) { @@ -536,18 +523,13 @@ public Integer mapOnNext(@Nullable final Integer next) { @Nullable @Override - public Integer mapOnError(final Throwable cause) { - return null; + public MappedTerminal mapOnError(final Throwable cause) { + throw DELIBERATE_EXCEPTION; } @Nullable @Override - public Integer mapOnComplete() { - return null; - } - - @Override - public boolean mapTerminal() { + public MappedTerminal mapOnComplete() { throw DELIBERATE_EXCEPTION; } @@ -619,7 +601,7 @@ void cancelStillAllowsMaps(boolean onError, boolean cancelBefore, boolean withLi final AtomicInteger finalizations = new AtomicInteger(0); TestPublisher publisher = new TestPublisher<>(); TestPublisherSubscriber subscriber = new TestPublisherSubscriber<>(); - toSource(scanWithOperator(publisher, withLifetime, new ScanWithLifetimeMapper() { + toSource(scanWithOperator(publisher, withLifetime, new ScanLifetimeMapper() { private int sum; @Nullable @Override @@ -631,18 +613,13 @@ public Integer mapOnNext(@Nullable final Integer next) { } @Override - public Integer mapOnError(final Throwable cause) { - return sum; - } - - @Override - public Integer mapOnComplete() { - return sum; + public MappedTerminal mapOnError(final Throwable cause) { + return new FixedMappedTerminal<>(sum); } @Override - public boolean mapTerminal() { - return true; + public MappedTerminal mapOnComplete() { + return new FixedMappedTerminal<>(sum); } @Override @@ -693,8 +670,8 @@ private static Stream cancelStillAllowsMapsParams() { Arguments.of(true, true, true)); } - private static ScanWithLifetimeMapper noopMapper(final AtomicInteger finalizations) { - return new ScanWithLifetimeMapper() { + private static ScanLifetimeMapper noopMapper(final AtomicInteger finalizations) { + return new ScanLifetimeMapper() { @Nullable @Override public Integer mapOnNext(@Nullable final Integer next) { @@ -703,21 +680,16 @@ public Integer mapOnNext(@Nullable final Integer next) { @Nullable @Override - public Integer mapOnError(final Throwable cause) { + public MappedTerminal mapOnError(final Throwable cause) { return null; } @Nullable @Override - public Integer mapOnComplete() { + public MappedTerminal mapOnComplete() { return null; } - @Override - public boolean mapTerminal() { - return true; - } - @Override public void afterFinally() { finalizations.incrementAndGet(); @@ -726,7 +698,31 @@ public void afterFinally() { } private static Publisher scanWithOperator(final Publisher source, final boolean withLifetime, - final ScanWithLifetimeMapper mapper) { - return withLifetime ? source.scanWithLifetime(() -> mapper) : source.scanWith(() -> mapper); + final ScanLifetimeMapper mapper) { + return withLifetime ? source.scanWithLifetimeMapper(() -> mapper) : source.scanWithMapper(() -> mapper); + } + + private static final class FixedMappedTerminal implements MappedTerminal { + private final T onNext; + + private FixedMappedTerminal(final T onNext) { + this.onNext = onNext; + } + + @Override + public T onNext() { + return onNext; + } + + @Override + public boolean onNextValid() { + return true; + } + + @Nullable + @Override + public Throwable terminal() { + return null; + } } } diff --git a/servicetalk-concurrent-reactivestreams/src/test/java/io/servicetalk/concurrent/reactivestreams/tck/PublisherScanWithLifetimeTckTest.java b/servicetalk-concurrent-reactivestreams/src/test/java/io/servicetalk/concurrent/reactivestreams/tck/PublisherScanWithLifetimeTckTest.java index 333cd7f9f3..56cf95e709 100644 --- a/servicetalk-concurrent-reactivestreams/src/test/java/io/servicetalk/concurrent/reactivestreams/tck/PublisherScanWithLifetimeTckTest.java +++ b/servicetalk-concurrent-reactivestreams/src/test/java/io/servicetalk/concurrent/reactivestreams/tck/PublisherScanWithLifetimeTckTest.java @@ -16,7 +16,7 @@ package io.servicetalk.concurrent.reactivestreams.tck; import io.servicetalk.concurrent.api.Publisher; -import io.servicetalk.concurrent.api.ScanWithLifetimeMapper; +import io.servicetalk.concurrent.api.ScanLifetimeMapper; import org.testng.annotations.Test; @@ -28,11 +28,7 @@ public class PublisherScanWithLifetimeTckTest extends AbstractPublisherOperatorTckTest { @Override protected Publisher composePublisher(Publisher publisher, int elements) { - return publisher.scanWithLifetime(() -> new ScanWithLifetimeMapper() { - @Override - public void afterFinally() { - } - + return publisher.scanWithLifetimeMapper(() -> new ScanLifetimeMapper() { @Nullable @Override public String mapOnNext(@Nullable final Integer next) { @@ -41,19 +37,18 @@ public String mapOnNext(@Nullable final Integer next) { @Nullable @Override - public String mapOnError(final Throwable cause) throws Throwable { + public MappedTerminal mapOnError(final Throwable cause) { return null; } @Nullable @Override - public String mapOnComplete() { + public MappedTerminal mapOnComplete() { return null; } @Override - public boolean mapTerminal() { - return false; + public void afterFinally() { } }); } diff --git a/servicetalk-concurrent-reactivestreams/src/test/java/io/servicetalk/concurrent/reactivestreams/tck/PublisherScanWithMapperTckTest.java b/servicetalk-concurrent-reactivestreams/src/test/java/io/servicetalk/concurrent/reactivestreams/tck/PublisherScanWithMapperTckTest.java index 90b013b0b6..883a231187 100644 --- a/servicetalk-concurrent-reactivestreams/src/test/java/io/servicetalk/concurrent/reactivestreams/tck/PublisherScanWithMapperTckTest.java +++ b/servicetalk-concurrent-reactivestreams/src/test/java/io/servicetalk/concurrent/reactivestreams/tck/PublisherScanWithMapperTckTest.java @@ -16,7 +16,7 @@ package io.servicetalk.concurrent.reactivestreams.tck; import io.servicetalk.concurrent.api.Publisher; -import io.servicetalk.concurrent.api.ScanWithMapper; +import io.servicetalk.concurrent.api.ScanMapper; import org.testng.annotations.Test; @@ -28,7 +28,7 @@ public class PublisherScanWithMapperTckTest extends AbstractPublisherOperatorTckTest { @Override protected Publisher composePublisher(Publisher publisher, int elements) { - return publisher.scanWith(() -> new ScanWithMapper() { + return publisher.scanWithMapper(() -> new ScanMapper() { @Nullable @Override public String mapOnNext(@Nullable final Integer next) { @@ -37,20 +37,15 @@ public String mapOnNext(@Nullable final Integer next) { @Nullable @Override - public String mapOnError(final Throwable cause) { + public MappedTerminal mapOnError(final Throwable cause) { return null; } @Nullable @Override - public String mapOnComplete() { + public MappedTerminal mapOnComplete() { return null; } - - @Override - public boolean mapTerminal() { - return false; - } }); } } diff --git a/servicetalk-http-api/src/main/java/io/servicetalk/http/api/StreamingHttpPayloadHolder.java b/servicetalk-http-api/src/main/java/io/servicetalk/http/api/StreamingHttpPayloadHolder.java index 7e458cb1eb..8693fe2a3d 100644 --- a/servicetalk-http-api/src/main/java/io/servicetalk/http/api/StreamingHttpPayloadHolder.java +++ b/servicetalk-http-api/src/main/java/io/servicetalk/http/api/StreamingHttpPayloadHolder.java @@ -22,7 +22,7 @@ import io.servicetalk.concurrent.SingleSource.Processor; import io.servicetalk.concurrent.api.Publisher; import io.servicetalk.concurrent.api.PublisherOperator; -import io.servicetalk.concurrent.api.ScanWithMapper; +import io.servicetalk.concurrent.api.ScanMapper; import io.servicetalk.concurrent.api.Single; import io.servicetalk.http.api.HttpDataSourceTransformations.BridgeFlowControlAndDiscardOperator; import io.servicetalk.http.api.HttpDataSourceTransformations.HttpTransportBufferFilterOperator; @@ -152,7 +152,7 @@ void transform(final TrailersTransformer trailersTransformer, final Publisher transformedPayloadBody = body.liftSync( new PreserveTrailersBufferOperator(trailersProcessor)); return merge(serializer.deserialize(headers, transformedPayloadBody, allocator), - fromSource(trailersProcessor)).scanWith(() -> + fromSource(trailersProcessor)).scanWithMapper(() -> new TrailersMapper<>(trailersTransformer, headersFactory)) .shareContextOnSubscribe(); })); @@ -160,7 +160,7 @@ void transform(final TrailersTransformer trailersTransformer, void transform(final TrailersTransformer trailersTransformer) { transform(trailersTransformer, - body -> body.scanWith(() -> new TrailersMapper<>(trailersTransformer, headersFactory))); + body -> body.scanWithMapper(() -> new TrailersMapper<>(trailersTransformer, headersFactory))); } private void transform(final TrailersTransformer trailersTransformer, @@ -276,7 +276,7 @@ private static Publisher merge(Publisher p, Single s) { return from(p, s.toPublisher().filter(Objects::nonNull)).flatMapMerge(identity(), 2); } - private static final class TrailersMapper implements ScanWithMapper { + private static final class TrailersMapper implements ScanMapper { private final TrailersTransformer trailersTransformer; private final HttpHeadersFactory headersFactory; @Nullable @@ -307,19 +307,43 @@ public Object mapOnNext(@Nullable final Object next) { return trailersTransformer.accept(state, nextS); } + @Nullable @Override - public Object mapOnError(final Throwable t) throws Throwable { - return trailersTransformer.catchPayloadFailure(state, t, headersFactory.newEmptyTrailers()); + public MappedTerminal mapOnError(final Throwable t) throws Throwable { + return trailers == null ? new DefaultMappedTerminal<>( + trailersTransformer.catchPayloadFailure(state, t, headersFactory.newEmptyTrailers())) : null; } + @Nullable @Override - public Object mapOnComplete() { - return trailersTransformer.payloadComplete(state, headersFactory.newEmptyTrailers()); + public MappedTerminal mapOnComplete() { + return trailers == null ? new DefaultMappedTerminal<>( + trailersTransformer.payloadComplete(state, headersFactory.newEmptyTrailers())) : null; + } + } + + private static final class DefaultMappedTerminal implements ScanMapper.MappedTerminal { + private final T onNext; + + private DefaultMappedTerminal(final T onNext) { + this.onNext = onNext; } + @Nullable + @Override + public T onNext() { + return onNext; + } + + @Override + public boolean onNextValid() { + return true; + } + + @Nullable @Override - public boolean mapTerminal() { - return trailers == null; + public Throwable terminal() { + return null; } } diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AbstractStreamingHttpConnection.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AbstractStreamingHttpConnection.java index 7d63002b12..6c22f82953 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AbstractStreamingHttpConnection.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AbstractStreamingHttpConnection.java @@ -188,7 +188,7 @@ public Single request(final StreamingHttpRequest request) // requests with non-replayable messageBody flatRequest = Single.succeeded(request).concatDeferSubscribe(messageBody); if (shouldAppendTrailers(connectionContext().protocol(), request)) { - flatRequest = flatRequest.scanWith(HeaderUtils::appendTrailersMapper); + flatRequest = flatRequest.scanWithMapper(HeaderUtils::appendTrailersMapper); } } addRequestTransferEncodingIfNecessary(request); diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/HeaderUtils.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/HeaderUtils.java index 800b122b1a..f9d6a49b24 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/HeaderUtils.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/HeaderUtils.java @@ -18,7 +18,8 @@ import io.servicetalk.buffer.api.Buffer; import io.servicetalk.buffer.api.CharSequences; import io.servicetalk.concurrent.api.Publisher; -import io.servicetalk.concurrent.api.ScanWithMapper; +import io.servicetalk.concurrent.api.ScanMapper; +import io.servicetalk.concurrent.api.ScanMapper.MappedTerminal; import io.servicetalk.http.api.EmptyHttpHeaders; import io.servicetalk.http.api.HttpHeaderNames; import io.servicetalk.http.api.HttpHeaders; @@ -171,8 +172,8 @@ static boolean responseMayHaveContent(final int statusCode, return !isEmptyResponseStatus(statusCode) && !isEmptyConnectResponse(requestMethod, statusCode); } - static ScanWithMapper appendTrailersMapper() { - return new ScanWithMapper() { + static ScanMapper appendTrailersMapper() { + return new ScanMapper() { private boolean sawHeaders; @Nullable @@ -186,22 +187,41 @@ public Object mapOnNext(@Nullable final Object next) { @Nullable @Override - public Object mapOnError(final Throwable t) throws Throwable { - throw t; - } - - @Override - public Object mapOnComplete() { - return EmptyHttpHeaders.INSTANCE; + public MappedTerminal mapOnError(final Throwable cause) throws Throwable { + throw cause; } + @Nullable @Override - public boolean mapTerminal() { - return !sawHeaders; + public MappedTerminal mapOnComplete() { + return sawHeaders ? null : EmptyHeadersComplete.INSTANCE; } }; } + private static final class EmptyHeadersComplete implements MappedTerminal { + private static final MappedTerminal INSTANCE = new EmptyHeadersComplete(); + + private EmptyHeadersComplete() { + } + + @Override + public HttpHeaders onNext() { + return EmptyHttpHeaders.INSTANCE; + } + + @Override + public boolean onNextValid() { + return true; + } + + @Nullable + @Override + public Throwable terminal() { + return null; + } + } + static boolean emptyMessageBody(final HttpMetaData metadata, final Publisher messageBody) { return messageBody == empty() || emptyMessageBody(metadata); } diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/NettyHttpServer.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/NettyHttpServer.java index e0a1e62a83..295a1710dc 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/NettyHttpServer.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/NettyHttpServer.java @@ -429,7 +429,7 @@ private static Publisher handleResponse(final HttpProtocolVersion protoc } else { flatResponse = Single.succeeded(response).concatPropagateCancel(messageBody); if (shouldAppendTrailers(protocolVersion, response)) { - flatResponse = flatResponse.scanWith(HeaderUtils::appendTrailersMapper); + flatResponse = flatResponse.scanWithMapper(HeaderUtils::appendTrailersMapper); } } addResponseTransferEncodingIfNecessary(response, requestMethod);