diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java index be9914a38757..a85c033cedc1 100644 --- a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java @@ -456,6 +456,35 @@ public static Publisher outputStreamPublisher( consumer::accept, new DataBufferMapper(bufferFactory), executor, chunkSize); } + /** + * Subscribes to given {@link Publisher} and returns subscription + * as {@link InputStream} that allows reading all propagated {@link DataBuffer} messages via its imperative API. + * Given the {@link InputStream} implementation buffers messages as per configuration. + * The returned {@link InputStream} is considered terminated when the given {@link Publisher} signaled one of the + * terminal signal ({@link Subscriber#onComplete() or {@link Subscriber#onError(Throwable)}}) + * and all the stored {@link DataBuffer} polled from the internal buffer. + * The returned {@link InputStream} will call {@link Subscription#cancel()} and release all stored {@link DataBuffer} + * when {@link InputStream#close()} is called. + *

+ * Note: The implementation of the returned {@link InputStream} disallow concurrent call on + * any of the {@link InputStream#read} methods + *

+ * Note: {@link Subscription#request(long)} happens eagerly for the first time upon subscription + * and then repeats every time {@code bufferSize - (bufferSize >> 2)} consumed + * + * @param publisher the source of {@link DataBuffer} which should be represented as an {@link InputStream} + * @param bufferSize the maximum amount of {@link DataBuffer} prefetched in advance and stored inside {@link InputStream} + * @return an {@link InputStream} instance representing given {@link Publisher} messages + */ + public static InputStream subscribeAsInputStream(Publisher publisher, int bufferSize) { + Assert.notNull(publisher, "Publisher must not be null"); + Assert.isTrue(bufferSize > 0, "Buffer size must be > 0"); + + InputStreamSubscriber inputStreamSubscriber = new InputStreamSubscriber(bufferSize); + publisher.subscribe(inputStreamSubscriber); + return inputStreamSubscriber; + } + //--------------------------------------------------------------------- // Various diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/InputStreamSubscriber.java b/spring-core/src/main/java/org/springframework/core/io/buffer/InputStreamSubscriber.java new file mode 100644 index 000000000000..b364927d953d --- /dev/null +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/InputStreamSubscriber.java @@ -0,0 +1,355 @@ +package org.springframework.core.io.buffer; + +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.springframework.lang.Nullable; +import reactor.core.Exceptions; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.ConcurrentModificationException; +import java.util.Objects; +import java.util.Queue; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.LockSupport; +import java.util.concurrent.locks.ReentrantLock; + +/** + * Bridges between {@link Publisher Publisher<DataBuffer>} and {@link InputStream}. + * + *

Note that this class has a near duplicate in + * {@link org.springframework.http.client.InputStreamSubscriber}. + * + * @author Oleh Dokuka + * @since 6.1 + */ +final class InputStreamSubscriber extends InputStream implements Subscriber { + + static final Object READY = new Object(); + static final DataBuffer DONE = DefaultDataBuffer.fromEmptyByteBuffer(DefaultDataBufferFactory.sharedInstance, ByteBuffer.allocate(0)); + static final DataBuffer CLOSED = DefaultDataBuffer.fromEmptyByteBuffer(DefaultDataBufferFactory.sharedInstance, ByteBuffer.allocate(0)); + + final int prefetch; + final int limit; + final ReentrantLock lock; + final Queue queue; + + final AtomicReference parkedThread = new AtomicReference<>(); + final AtomicInteger workAmount = new AtomicInteger(); + + volatile boolean closed; + int consumed; + + @Nullable + DataBuffer available; + + @Nullable + Subscription s; + boolean done; + @Nullable + Throwable error; + + InputStreamSubscriber(int prefetch) { + this.prefetch = prefetch; + this.limit = prefetch == Integer.MAX_VALUE ? Integer.MAX_VALUE : prefetch - (prefetch >> 2); + this.queue = new ArrayBlockingQueue<>(prefetch); + this.lock = new ReentrantLock(false); + } + + @Override + public void onSubscribe(Subscription subscription) { + if (this.s != null) { + subscription.cancel(); + return; + } + + this.s = subscription; + subscription.request(prefetch == Integer.MAX_VALUE ? Long.MAX_VALUE : prefetch); + } + + @Override + public void onNext(DataBuffer t) { + if (this.done) { + discard(t); + return; + } + + if (!queue.offer(t)) { + discard(t); + error = new RuntimeException("Buffer overflow"); + done = true; + } + + int previousWorkState = addWork(); + if (previousWorkState == Integer.MIN_VALUE) { + DataBuffer value = queue.poll(); + if (value != null) { + discard(value); + } + return; + } + + if (previousWorkState == 0) { + resume(); + } + } + + @Override + public void onError(Throwable throwable) { + if (this.done) { + return; + } + this.error = throwable; + this.done = true; + + if (addWork() == 0) { + resume(); + } + } + + @Override + public void onComplete() { + if (this.done) { + return; + } + + this.done = true; + + if (addWork() == 0) { + resume(); + } + } + + int addWork() { + for (;;) { + int produced = this.workAmount.getPlain(); + + if (produced == Integer.MIN_VALUE) { + return Integer.MIN_VALUE; + } + + int nextProduced = produced == Integer.MAX_VALUE ? 1 : produced + 1; + + + if (workAmount.weakCompareAndSetRelease(produced, nextProduced)) { + return produced; + } + } + } + + @Override + public int read() throws IOException { + if (!lock.tryLock()) { + if (this.closed) { + return -1; + } + throw new ConcurrentModificationException("concurrent access is disallowed"); + } + + try { + DataBuffer bytes = getBytesOrAwait(); + + if (bytes == DONE) { + this.closed = true; + cleanAndFinalize(); + if (this.error == null) { + return -1; + } + else { + throw Exceptions.propagate(error); + } + } else if (bytes == CLOSED) { + cleanAndFinalize(); + return -1; + } + + return bytes.read() & 0xFF; + } + catch (Throwable t) { + this.closed = true; + this.s.cancel(); + cleanAndFinalize(); + throw Exceptions.propagate(t); + } + finally { + lock.unlock(); + } + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + Objects.checkFromIndexSize(off, len, b.length); + if (len == 0) { + return 0; + } + + if (!lock.tryLock()) { + if (this.closed) { + return -1; + } + throw new ConcurrentModificationException("concurrent access is disallowed"); + } + + try { + for (int j = 0; j < len;) { + DataBuffer bytes = getBytesOrAwait(); + + if (bytes == DONE) { + cleanAndFinalize(); + if (this.error == null) { + this.closed = true; + return j == 0 ? -1 : j; + } + else { + if (j == 0) { + this.closed = true; + throw Exceptions.propagate(error); + } + + return j; + } + } else if (bytes == CLOSED) { + this.s.cancel(); + cleanAndFinalize(); + return -1; + } + int initialReadPosition = bytes.readPosition(); + bytes.read(b, off + j, Math.min(len - j, bytes.readableByteCount())); + j += bytes.readPosition() - initialReadPosition; + } + + return len; + } + catch (Throwable t) { + this.closed = true; + this.s.cancel(); + cleanAndFinalize(); + throw Exceptions.propagate(t); + } + finally { + lock.unlock(); + } + } + + DataBuffer getBytesOrAwait() { + if (this.available == null || this.available.readableByteCount() == 0) { + + discard(this.available); + this.available = null; + + int actualWorkAmount = this.workAmount.getAcquire(); + for (;;) { + if (this.closed) { + return CLOSED; + } + + boolean d = this.done; + DataBuffer t = this.queue.poll(); + if (t != null) { + int consumed = ++this.consumed; + this.available = t; + if (consumed == this.limit) { + this.consumed = 0; + this.s.request(this.limit); + } + break; + } + + if (d) { + return DONE; + } + + actualWorkAmount = workAmount.addAndGet(-actualWorkAmount); + if (actualWorkAmount == 0) { + await(); + } + } + } + + return this.available; + } + + void cleanAndFinalize() { + discard(this.available); + this.available = null; + + for (;;) { + int workAmount = this.workAmount.getPlain(); + DataBuffer value; + + while((value = queue.poll()) != null) { + discard(value); + } + + if (this.workAmount.weakCompareAndSetPlain(workAmount, Integer.MIN_VALUE)) { + return; + } + } + } + + void discard(@Nullable DataBuffer value) { + DataBufferUtils.release(value); + } + + @Override + public void close() throws IOException { + if (this.closed) { + return; + } + + this.closed = true; + + if (!this.lock.tryLock()) { + if (addWork() == 0) { + resume(); + } + return; + } + + try { + this.s.cancel(); + cleanAndFinalize(); + } + finally { + this.lock.unlock(); + } + } + + private void await() { + Thread toUnpark = Thread.currentThread(); + + while (true) { + Object current = this.parkedThread.get(); + if (current == READY) { + break; + } + + if (current != null && current != toUnpark) { + throw new IllegalStateException("Only one (Virtual)Thread can await!"); + } + + if (parkedThread.compareAndSet( null, toUnpark)) { + LockSupport.park(); + // we don't just break here because park() can wake up spuriously + // if we got a proper resume, get() == READY and the loop will quit above + } + } + // clear the resume indicator so that the next await call will park without a resume() + this.parkedThread.lazySet(null); + } + + private void resume() { + if (this.parkedThread != READY) { + Object old = parkedThread.getAndSet(READY); + if (old != READY) { + LockSupport.unpark((Thread)old); + } + } + } + + +} diff --git a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java index 9ea04e339c62..d0fe3c54466e 100644 --- a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java +++ b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java @@ -17,6 +17,7 @@ package org.springframework.core.io.buffer; import java.io.IOException; +import java.io.InputStream; import java.io.OutputStream; import java.io.OutputStreamWriter; import java.net.URI; @@ -27,15 +28,18 @@ import java.nio.channels.ReadableByteChannel; import java.nio.channels.SeekableByteChannel; import java.nio.channels.WritableByteChannel; +import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.nio.file.StandardOpenOption; import java.time.Duration; +import java.util.ArrayList; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; +import java.util.concurrent.ThreadLocalRandom; import io.netty.buffer.ByteBuf; import io.netty.buffer.PooledByteBufAllocator; @@ -688,6 +692,189 @@ void outputStreamPublisherClosed(DataBufferFactory bufferFactory) throws Interru latch.await(); } + + + @ParameterizedDataBufferAllocatingTest + void inputStreamSubscriberChunkSize(DataBufferFactory bufferFactory) { + genericInputStreamSubscriberTest(bufferFactory, 3, 3, 64, List.of("foo", "bar", "baz"), List.of("foo", "bar", "baz")); + } + + @ParameterizedDataBufferAllocatingTest + void inputStreamSubscriberChunkSize2(DataBufferFactory bufferFactory) { + genericInputStreamSubscriberTest(bufferFactory, 3, 3, 1, List.of("foo", "bar", "baz"), List.of("foo", "bar", "baz")); + } + + @ParameterizedDataBufferAllocatingTest + void inputStreamSubscriberChunkSize3(DataBufferFactory bufferFactory) { + genericInputStreamSubscriberTest(bufferFactory, 3, 12, 1, List.of("foo", "bar", "baz"), List.of("foobarbaz")); + } + + @ParameterizedDataBufferAllocatingTest + void inputStreamSubscriberChunkSize4(DataBufferFactory bufferFactory) { + genericInputStreamSubscriberTest(bufferFactory, 3, 1, 1, List.of("foo", "bar", "baz"), List.of("f", "o", "o", "b", "a", "r", "b", "a", "z")); + } + + @ParameterizedDataBufferAllocatingTest + void inputStreamSubscriberChunkSize5(DataBufferFactory bufferFactory) { + genericInputStreamSubscriberTest(bufferFactory, 3, 2, 1, List.of("foo", "bar", "baz"), List.of("fo", "ob", "ar", "ba", "z")); + } + + @ParameterizedDataBufferAllocatingTest + void inputStreamSubscriberChunkSize6(DataBufferFactory bufferFactory) { + genericInputStreamSubscriberTest(bufferFactory, 1, 3, 1, List.of("foo", "bar", "baz"), List.of("foo", "bar", "baz")); + } + + @ParameterizedDataBufferAllocatingTest + void inputStreamSubscriberChunkSize7(DataBufferFactory bufferFactory) { + genericInputStreamSubscriberTest(bufferFactory, 1, 3, 64, List.of("foo", "bar", "baz"), List.of("foo", "bar", "baz")); + } + + void genericInputStreamSubscriberTest(DataBufferFactory bufferFactory, int writeChunkSize, int readChunkSize, int bufferSize, List input, List expectedOutput) { + super.bufferFactory = bufferFactory; + + Publisher publisher = DataBufferUtils.outputStreamPublisher(outputStream -> { + try { + for (String word : input) { + outputStream.write(word.getBytes(StandardCharsets.UTF_8)); + } + } + catch (IOException ex) { + fail(ex.getMessage(), ex); + } + }, super.bufferFactory, Executors.newSingleThreadExecutor(), writeChunkSize); + + + + byte[] chunk = new byte[readChunkSize]; + ArrayList words = new ArrayList<>(); + + try (InputStream inputStream = DataBufferUtils.subscribeAsInputStream(publisher, bufferSize)) { + int read; + while((read = inputStream.read(chunk)) > -1) { + String word = new String(chunk, 0, read, StandardCharsets.UTF_8); + words.add(word); + } + } + catch (IOException e) { + throw new RuntimeException(e); + } + assertThat(words).containsExactlyElementsOf(expectedOutput); + } + + @ParameterizedDataBufferAllocatingTest + void inputStreamSubscriberError(DataBufferFactory bufferFactory) throws InterruptedException { + super.bufferFactory = bufferFactory; + + var input = List.of("foo ", "bar ", "baz"); + + Publisher publisher = DataBufferUtils.outputStreamPublisher(outputStream -> { + try { + for (String word : input) { + outputStream.write(word.getBytes(StandardCharsets.UTF_8)); + } + throw new RuntimeException("boom"); + } + catch (IOException ex) { + fail(ex.getMessage(), ex); + } + }, super.bufferFactory, Executors.newSingleThreadExecutor(), 1); + + + RuntimeException error = null; + byte[] chunk = new byte[4]; + ArrayList words = new ArrayList<>(); + + try (InputStream inputStream = DataBufferUtils.subscribeAsInputStream(publisher, 1)) { + int read; + while((read = inputStream.read(chunk)) > -1) { + String word = new String(chunk, 0, read, StandardCharsets.UTF_8); + words.add(word); + } + } + catch (IOException e) { + throw new RuntimeException(e); + } + catch (RuntimeException e) { + error = e; + } + assertThat(words).containsExactlyElementsOf(List.of("foo ", "bar ", "baz")); + assertThat(error).hasMessage("boom"); + } + + @ParameterizedDataBufferAllocatingTest + void inputStreamSubscriberMixedReadMode(DataBufferFactory bufferFactory) throws InterruptedException { + super.bufferFactory = bufferFactory; + + var input = List.of("foo ", "bar ", "baz"); + + Publisher publisher = DataBufferUtils.outputStreamPublisher(outputStream -> { + try { + for (String word : input) { + outputStream.write(word.getBytes(StandardCharsets.UTF_8)); + } + } + catch (IOException ex) { + fail(ex.getMessage(), ex); + } + }, super.bufferFactory, Executors.newSingleThreadExecutor(), 1); + + + byte[] chunk = new byte[3]; + ArrayList words = new ArrayList<>(); + + try (InputStream inputStream = DataBufferUtils.subscribeAsInputStream(publisher, 1)) { + words.add(new String(chunk,0, inputStream.read(chunk), StandardCharsets.UTF_8)); + assertThat(inputStream.read()).isEqualTo(' ' & 0xFF); + words.add(new String(chunk,0, inputStream.read(chunk), StandardCharsets.UTF_8)); + assertThat(inputStream.read()).isEqualTo(' ' & 0xFF); + words.add(new String(chunk,0, inputStream.read(chunk), StandardCharsets.UTF_8)); + assertThat(inputStream.read()).isEqualTo(-1); + } + catch (IOException e) { + throw new RuntimeException(e); + } + assertThat(words).containsExactlyElementsOf(List.of("foo", "bar", "baz")); + } + + @ParameterizedDataBufferAllocatingTest + void inputStreamSubscriberClose(DataBufferFactory bufferFactory) throws InterruptedException { + for (int i = 1; i < 100; i++) { + CountDownLatch latch = new CountDownLatch(1); + super.bufferFactory = bufferFactory; + + var input = List.of("foo", "bar", "baz"); + + Publisher publisher = DataBufferUtils.outputStreamPublisher(outputStream -> { + try { + assertThatIOException() + .isThrownBy(() -> { + for (String word : input) { + outputStream.write(word.getBytes(StandardCharsets.UTF_8)); + outputStream.flush(); + } + }) + .withMessage("Subscription has been terminated"); + } finally { + latch.countDown(); + } + }, super.bufferFactory, Executors.newSingleThreadExecutor(), 1); + + + byte[] chunk = new byte[3]; + ArrayList words = new ArrayList<>(); + + try (InputStream inputStream = DataBufferUtils.subscribeAsInputStream(publisher, ThreadLocalRandom.current().nextInt(1, 4))) { + inputStream.read(chunk); + String word = new String(chunk, StandardCharsets.UTF_8); + words.add(word); + } catch (IOException e) { + throw new RuntimeException(e); + } + assertThat(words).containsExactlyElementsOf(List.of("foo")); + latch.await(); + } + } + @ParameterizedDataBufferAllocatingTest void readAndWriteByteChannel(DataBufferFactory bufferFactory) throws Exception { super.bufferFactory = bufferFactory; diff --git a/spring-web/src/main/java/org/springframework/http/client/InputStreamSubscriber.java b/spring-web/src/main/java/org/springframework/http/client/InputStreamSubscriber.java new file mode 100644 index 000000000000..606527044fe0 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/InputStreamSubscriber.java @@ -0,0 +1,405 @@ +package org.springframework.http.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import reactor.core.Exceptions; + +import java.io.IOException; +import java.io.InputStream; +import java.util.ConcurrentModificationException; +import java.util.Objects; +import java.util.Queue; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.LockSupport; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; +import java.util.function.Function; + +/** + * Bridges between {@link Flow.Publisher Flow.Publisher<T>} and {@link InputStream}. + * + *

Note that this class has a near duplicate in + * {@link org.springframework.core.io.buffer.InputStreamSubscriber}. + * + * @author Oleh Dokuka + * @since 6.1 + */ +final class InputStreamSubscriber extends InputStream implements Flow.Subscriber { + + private static final Log logger = LogFactory.getLog(InputStreamSubscriber.class); + + static final Object READY = new Object(); + static final byte[] DONE = new byte[0]; + static final byte[] CLOSED = new byte[0]; + + final int prefetch; + final int limit; + final Function mapper; + final Consumer onDiscardHandler; + final ReentrantLock lock; + final Queue queue; + + final AtomicReference parkedThread = new AtomicReference<>(); + final AtomicInteger workAmount = new AtomicInteger(); + + volatile boolean closed; + int consumed; + + @Nullable + byte[] available; + int position; + + @Nullable + Flow.Subscription s; + boolean done; + @Nullable + Throwable error; + + private InputStreamSubscriber(Function mapper, Consumer onDiscardHandler, int prefetch) { + this.prefetch = prefetch; + this.limit = prefetch == Integer.MAX_VALUE ? Integer.MAX_VALUE : prefetch - (prefetch >> 2); + this.mapper = mapper; + this.onDiscardHandler = onDiscardHandler; + this.queue = new ArrayBlockingQueue<>(prefetch); + this.lock = new ReentrantLock(false); + } + + /** + * Subscribes to given {@link Publisher} and returns subscription + * as {@link InputStream} that allows reading all propagated {@link DataBuffer} messages via its imperative API. + * Given the {@link InputStream} implementation buffers messages as per configuration. + * The returned {@link InputStream} is considered terminated when the given {@link Publisher} signaled one of the + * terminal signal ({@link Subscriber#onComplete() or {@link Subscriber#onError(Throwable)}}) + * and all the stored {@link DataBuffer} polled from the internal buffer. + * The returned {@link InputStream} will call {@link Subscription#cancel()} and release all stored {@link DataBuffer} + * when {@link InputStream#close()} is called. + *

+ * Note: The implementation of the returned {@link InputStream} disallow concurrent call on + * any of the {@link InputStream#read} methods + *

+ * Note: {@link Subscription#request(long)} happens eagerly for the first time upon subscription + * and then repeats every time {@code bufferSize - (bufferSize >> 2)} consumed + * + * @param publisher the source of {@link DataBuffer} which should be represented as an {@link InputStream} + * @param mapper function to transform <T> element to {@code byte[]}. Note, <T> should be released during the mapping if needed. + * @param onDiscardHandler <T> element consumer if returned {@link InputStream} is closed prematurely. + * @param bufferSize the maximum amount of <T> elements prefetched in advance and stored inside {@link InputStream} + * @return an {@link InputStream} instance representing given {@link Publisher} messages + */ + public static InputStream subscribeTo(Flow.Publisher publisher, Function mapper, Consumer onDiscardHandler, int bufferSize) { + + Assert.notNull(publisher, "Flow.Publisher must not be null"); + Assert.notNull(mapper, "mapper must not be null"); + Assert.notNull(onDiscardHandler, "onDiscardHandler must not be null"); + Assert.isTrue(bufferSize > 0, "bufferSize must be greater than 0"); + + InputStreamSubscriber iss = new InputStreamSubscriber<>(mapper, onDiscardHandler, bufferSize); + publisher.subscribe(iss); + return iss; + } + + @Override + public void onSubscribe(Flow.Subscription subscription) { + if (this.s != null) { + subscription.cancel(); + return; + } + + this.s = subscription; + subscription.request(prefetch == Integer.MAX_VALUE ? Long.MAX_VALUE : prefetch); + } + + @Override + public void onNext(T t) { + Assert.notNull(t, "T value must not be null"); + + if (this.done) { + discard(t); + return; + } + + if (!queue.offer(t)) { + discard(t); + error = new RuntimeException("Buffer overflow"); + done = true; + } + + int previousWorkState = addWork(); + if (previousWorkState == Integer.MIN_VALUE) { + T value = queue.poll(); + if (value != null) { + discard(value); + } + return; + } + + if (previousWorkState == 0) { + resume(); + } + } + + @Override + public void onError(Throwable throwable) { + if (this.done) { + return; + } + this.error = throwable; + this.done = true; + + if (addWork() == 0) { + resume(); + } + } + + @Override + public void onComplete() { + if (this.done) { + return; + } + + this.done = true; + + if (addWork() == 0) { + resume(); + } + } + + int addWork() { + for (;;) { + int produced = this.workAmount.getPlain(); + + if (produced == Integer.MIN_VALUE) { + return Integer.MIN_VALUE; + } + + int nextProduced = produced == Integer.MAX_VALUE ? 1 : produced + 1; + + + if (workAmount.weakCompareAndSetRelease(produced, nextProduced)) { + return produced; + } + } + } + + @Override + public int read() throws IOException { + if (!lock.tryLock()) { + if (this.closed) { + return -1; + } + throw new ConcurrentModificationException("concurrent access is disallowed"); + } + + try { + byte[] bytes = getBytesOrAwait(); + + if (bytes == DONE) { + this.closed = true; + cleanAndFinalize(); + if (this.error == null) { + return -1; + } + else { + throw Exceptions.propagate(error); + } + } else if (bytes == CLOSED) { + cleanAndFinalize(); + return -1; + } + + return bytes[this.position++] & 0xFF; + } + catch (Throwable t) { + this.closed = true; + this.s.cancel(); + cleanAndFinalize(); + throw Exceptions.propagate(t); + } + finally { + lock.unlock(); + } + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + Objects.checkFromIndexSize(off, len, b.length); + if (len == 0) { + return 0; + } + + if (!lock.tryLock()) { + if (this.closed) { + return -1; + } + throw new ConcurrentModificationException("concurrent access is disallowed"); + } + + try { + for (int j = 0; j < len;) { + byte[] bytes = getBytesOrAwait(); + + if (bytes == DONE) { + this.closed = true; + cleanAndFinalize(); + if (this.error == null) { + return j == 0 ? -1 : j; + } + else { + throw Exceptions.propagate(error); + } + } else if (bytes == CLOSED) { + this.s.cancel(); + cleanAndFinalize(); + return -1; + } + + int i = this.position; + for (; i < bytes.length && j < len; i++, j++) { + b[off + j] = bytes[i]; + } + this.position = i; + } + + return len; + } + catch (Throwable t) { + this.closed = true; + this.s.cancel(); + cleanAndFinalize(); + throw Exceptions.propagate(t); + } + finally { + lock.unlock(); + } + } + + byte[] getBytesOrAwait() { + if (this.available == null || this.available.length - this.position == 0) { + this.available = null; + + int actualWorkAmount = this.workAmount.getAcquire(); + for (;;) { + if (this.closed) { + return CLOSED; + } + + boolean d = this.done; + T t = this.queue.poll(); + if (t != null) { + int consumed = ++this.consumed; + this.position = 0; + this.available = Objects.requireNonNull(this.mapper.apply(t)); + if (consumed == this.limit) { + this.consumed = 0; + this.s.request(this.limit); + } + break; + } + + if (d) { + return DONE; + } + + actualWorkAmount = workAmount.addAndGet(-actualWorkAmount); + if (actualWorkAmount == 0) { + await(); + } + } + } + + return this.available; + } + + void cleanAndFinalize() { + this.available = null; + + for (;;) { + int workAmount = this.workAmount.getPlain(); + T value; + + while((value = queue.poll()) != null) { + discard(value); + } + + if (this.workAmount.weakCompareAndSetPlain(workAmount, Integer.MIN_VALUE)) { + return; + } + } + } + + void discard(T value) { + try { + this.onDiscardHandler.accept(value); + } catch (Throwable t) { + if (logger.isDebugEnabled()) { + logger.debug("Failed to release " + value.getClass().getSimpleName() + ": " + value, t); + } + } + } + + @Override + public void close() throws IOException { + if (this.closed) { + return; + } + + this.closed = true; + + if (!this.lock.tryLock()) { + if (addWork() == 0) { + resume(); + } + return; + } + + try { + this.s.cancel(); + cleanAndFinalize(); + } + finally { + this.lock.unlock(); + } + } + + private void await() { + Thread toUnpark = Thread.currentThread(); + + while (true) { + Object current = this.parkedThread.get(); + if (current == READY) { + break; + } + + if (current != null && current != toUnpark) { + throw new IllegalStateException("Only one (Virtual)Thread can await!"); + } + + if (parkedThread.compareAndSet( null, toUnpark)) { + LockSupport.park(); + // we don't just break here because park() can wake up spuriously + // if we got a proper resume, get() == READY and the loop will quit above + } + } + // clear the resume indicator so that the next await call will park without a resume() + this.parkedThread.lazySet(null); + } + + private void resume() { + if (this.parkedThread != READY) { + Object old = parkedThread.getAndSet(READY); + if (old != READY) { + LockSupport.unpark((Thread)old); + } + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/InputStreamSubscriberTests.java b/spring-web/src/test/java/org/springframework/http/client/InputStreamSubscriberTests.java new file mode 100644 index 000000000000..9dd635dffba2 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/InputStreamSubscriberTests.java @@ -0,0 +1,259 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import org.junit.jupiter.api.Test; +import org.reactivestreams.FlowAdapters; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.Flow; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIOException; + +/** + * @author Arjen Poutsma + * @author Oleh Dokuka + */ +class InputStreamSubscriberTests { + + private static final byte[] FOO = "foo".getBytes(StandardCharsets.UTF_8); + + private static final byte[] BAR = "bar".getBytes(StandardCharsets.UTF_8); + + private static final byte[] BAZ = "baz".getBytes(StandardCharsets.UTF_8); + + + private final Executor executor = Executors.newSingleThreadExecutor(); + + private final OutputStreamPublisher.ByteMapper byteMapper = + new OutputStreamPublisher.ByteMapper<>() { + @Override + public byte[] map(int b) { + return new byte[]{(byte) b}; + } + + @Override + public byte[] map(byte[] b, int off, int len) { + byte[] result = new byte[len]; + System.arraycopy(b, off, result, 0, len); + return result; + } + }; + + + @Test + void basic() { + Flow.Publisher flowPublisher = OutputStreamPublisher.create(outputStream -> { + outputStream.write(FOO); + outputStream.write(BAR); + outputStream.write(BAZ); + }, this.byteMapper, this.executor); + Flux flux = toString(flowPublisher); + + StepVerifier.create(flux) + .assertNext(s -> assertThat(s).isEqualTo("foobarbaz")) + .verifyComplete(); + } + + @Test + void flush() { + Flow.Publisher flowPublisher = OutputStreamPublisher.create(outputStream -> { + outputStream.write(FOO); + outputStream.flush(); + outputStream.write(BAR); + outputStream.flush(); + outputStream.write(BAZ); + outputStream.flush(); + }, this.byteMapper, this.executor); + Flux flux = toString(flowPublisher); + + try (InputStream is = InputStreamSubscriber.subscribeTo(FlowAdapters.toFlowPublisher(flux), (s) -> s.getBytes(StandardCharsets.UTF_8), (ignore) -> {}, 1)) { + byte[] chunk = new byte[3]; + + assertThat(is.read(chunk)).isEqualTo(3); + assertThat(chunk).containsExactly(FOO); + assertThat(is.read(chunk)).isEqualTo(3); + assertThat(chunk).containsExactly(BAR); + assertThat(is.read(chunk)).isEqualTo(3); + assertThat(chunk).containsExactly(BAZ); + assertThat(is.read(chunk)).isEqualTo(-1); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Test + void chunkSize() { + Flow.Publisher flowPublisher = OutputStreamPublisher.create(outputStream -> { + outputStream.write(FOO); + outputStream.write(BAR); + outputStream.write(BAZ); + }, this.byteMapper, this.executor, 2); + Flux flux = toString(flowPublisher); + + try (InputStream is = InputStreamSubscriber.subscribeTo(FlowAdapters.toFlowPublisher(flux), (s) -> s.getBytes(StandardCharsets.UTF_8), (ignore) -> {}, 1)) { + StringBuilder stringBuilder = new StringBuilder(); + byte[] chunk = new byte[3]; + + + stringBuilder + .append(new String(new byte[]{(byte)is.read()}, StandardCharsets.UTF_8)); + assertThat(is.read(chunk)).isEqualTo(3); + stringBuilder + .append(new String(chunk, StandardCharsets.UTF_8)); + assertThat(is.read(chunk)).isEqualTo(3); + stringBuilder + .append(new String(chunk, StandardCharsets.UTF_8)); + assertThat(is.read(chunk)).isEqualTo(2); + stringBuilder + .append(new String(chunk,0, 2, StandardCharsets.UTF_8)); + assertThat(is.read()).isEqualTo(-1); + + assertThat(stringBuilder.toString()).isEqualTo("foobarbaz"); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Test + void cancel() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + + Flow.Publisher flowPublisher = OutputStreamPublisher.create(outputStream -> { + assertThatIOException() + .isThrownBy(() -> { + outputStream.write(FOO); + outputStream.flush(); + outputStream.write(BAR); + outputStream.flush(); + outputStream.write(BAZ); + outputStream.flush(); + }) + .withMessage("Subscription has been terminated"); + latch.countDown(); + + }, this.byteMapper, this.executor); + Flux flux = toString(flowPublisher); + List discarded = new ArrayList<>(); + + try (InputStream is = InputStreamSubscriber.subscribeTo(FlowAdapters.toFlowPublisher(flux), (s) -> s.getBytes(StandardCharsets.UTF_8), discarded::add, 1)) { + byte[] chunk = new byte[3]; + + assertThat(is.read(chunk)).isEqualTo(3); + assertThat(chunk).containsExactly(FOO); + } + catch (IOException e) { + throw new RuntimeException(e); + } + + latch.await(); + + assertThat(discarded).containsExactly("bar"); + } + + @Test + void closed() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + + Flow.Publisher flowPublisher = OutputStreamPublisher.create(outputStream -> { + OutputStreamWriter writer = new OutputStreamWriter(outputStream, StandardCharsets.UTF_8); + writer.write("foo"); + writer.close(); + assertThatIOException().isThrownBy(() -> writer.write("bar")) + .withMessage("Stream closed"); + latch.countDown(); + }, this.byteMapper, this.executor); + Flux flux = toString(flowPublisher); + + try (InputStream is = InputStreamSubscriber.subscribeTo(FlowAdapters.toFlowPublisher(flux), (s) -> s.getBytes(StandardCharsets.UTF_8), ig -> {}, 1)) { + byte[] chunk = new byte[3]; + + assertThat(is.read(chunk)).isEqualTo(3); + assertThat(chunk).containsExactly(FOO); + + assertThat(is.read(chunk)).isEqualTo(-1); + } + catch (IOException e) { + throw new RuntimeException(e); + } + + latch.await(); + } + + @Test + void mapperThrowsException() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + + Flow.Publisher flowPublisher = OutputStreamPublisher.create(outputStream -> { + outputStream.write(FOO); + outputStream.flush(); + assertThatIOException().isThrownBy(() -> { + outputStream.write(BAR); + outputStream.flush(); + }).withMessage("Subscription has been terminated"); + latch.countDown(); + }, this.byteMapper, this.executor); + Throwable ex = null; + + StringBuilder stringBuilder = new StringBuilder(); + try (InputStream is = InputStreamSubscriber.subscribeTo(flowPublisher, (s) -> { + throw new NullPointerException("boom"); + }, ig -> {}, 1)) { + byte[] chunk = new byte[3]; + + stringBuilder + .append(new String(new byte[]{(byte)is.read()}, StandardCharsets.UTF_8)); + assertThat(is.read(chunk)).isEqualTo(3); + stringBuilder + .append(new String(chunk, StandardCharsets.UTF_8)); + assertThat(is.read(chunk)).isEqualTo(3); + stringBuilder + .append(new String(chunk, StandardCharsets.UTF_8)); + assertThat(is.read(chunk)).isEqualTo(2); + stringBuilder + .append(new String(chunk,0, 2, StandardCharsets.UTF_8)); + assertThat(is.read()).isEqualTo(-1); + } + catch (Throwable e) { + ex = e; + } + + latch.await(); + + assertThat(stringBuilder.toString()).isEqualTo(""); + assertThat(ex).hasMessage("boom"); + } + + private static Flux toString(Flow.Publisher flowPublisher) { + return Flux.from(FlowAdapters.toPublisher(flowPublisher)) + .map(bytes -> new String(bytes, StandardCharsets.UTF_8)); + } + +}