From c89d16abadf66d8915d44b5ac84e35bc93ff8568 Mon Sep 17 00:00:00 2001 From: Tanguy Leroux Date: Mon, 20 Jul 2020 14:57:56 +0200 Subject: [PATCH 1/2] Add CachedBlobContainer with CopyOnReadInputStream --- .../blobstore/cache/CachedBlobContainer.java | 120 ++++++++++++++++++ .../cache/CachedBlobContainerTests.java | 108 ++++++++++++++++ 2 files changed, 228 insertions(+) create mode 100644 x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/blobstore/cache/CachedBlobContainer.java create mode 100644 x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/blobstore/cache/CachedBlobContainerTests.java diff --git a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/blobstore/cache/CachedBlobContainer.java b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/blobstore/cache/CachedBlobContainer.java new file mode 100644 index 0000000000000..8e88d9eced266 --- /dev/null +++ b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/blobstore/cache/CachedBlobContainer.java @@ -0,0 +1,120 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.blobstore.cache; + +import org.elasticsearch.common.blobstore.BlobContainer; +import org.elasticsearch.common.blobstore.support.FilterBlobContainer; +import org.elasticsearch.common.bytes.PagedBytesReference; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.util.ByteArray; + +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; + +public class CachedBlobContainer extends FilterBlobContainer { + + protected static final int DEFAULT_BYTE_ARRAY_SIZE = 1 << 14; + + public CachedBlobContainer(BlobContainer delegate) { + super(delegate); + } + + @Override + protected BlobContainer wrapChild(BlobContainer child) { + return new CachedBlobContainer(child); + } + + /** + * A {@link FilterInputStream} that copies over all the bytes read from the original input stream to a given {@link ByteArray}. The + * number of bytes copied cannot exceed the size of the {@link ByteArray}. + */ + static class CopyOnReadInputStream extends FilterInputStream { + + private final AtomicBoolean closed; + private final ByteArray bytes; + + private long count; + private long mark; + + protected CopyOnReadInputStream(InputStream in, ByteArray byteArray) { + super(in); + this.bytes = Objects.requireNonNull(byteArray); + this.closed = new AtomicBoolean(false); + } + + long getCount() { + return count; + } + + public int read() throws IOException { + final int result = super.read(); + if (result != -1) { + if (count < bytes.size()) { + bytes.set(count, (byte) result); + } + count++; + } + return result; + } + + public int read(byte[] b, int off, int len) throws IOException { + final int result = super.read(b, off, len); + if (result != -1) { + if (count < bytes.size()) { + bytes.set(count, b, off, Math.toIntExact(Math.min(bytes.size() - count, result))); + } + count += result; + } + return result; + } + + @Override + public long skip(long n) throws IOException { + final long skip = super.skip(n); + if (skip > 0L) { + count += skip; + } + return skip; + } + + @Override + public synchronized void mark(int readlimit) { + super.mark(readlimit); + mark = count; + } + + @Override + public synchronized void reset() throws IOException { + super.reset(); + count = mark; + } + + protected void closeInternal(final ReleasableBytesReference releasable) { + releasable.close(); + } + + @Override + public final void close() throws IOException { + if (closed.compareAndSet(false, true)) { + boolean success = false; + try { + super.close(); + final PagedBytesReference reference = new PagedBytesReference(bytes, Math.toIntExact(Math.min(count, bytes.size()))); + closeInternal(new ReleasableBytesReference(reference, bytes)); + success = true; + } finally { + if (success == false) { + bytes.close(); + } + } + } + } + } +} diff --git a/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/blobstore/cache/CachedBlobContainerTests.java b/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/blobstore/cache/CachedBlobContainerTests.java new file mode 100644 index 0000000000000..f5c7ebf9a64e5 --- /dev/null +++ b/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/blobstore/cache/CachedBlobContainerTests.java @@ -0,0 +1,108 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.blobstore.cache; + +import org.elasticsearch.blobstore.cache.CachedBlobContainer.CopyOnReadInputStream; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.ByteArray; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.util.MockPageCacheRecycler; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.test.ESTestCase; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; + +import static org.elasticsearch.blobstore.cache.CachedBlobContainer.DEFAULT_BYTE_ARRAY_SIZE; +import static org.hamcrest.Matchers.equalTo; + +public class CachedBlobContainerTests extends ESTestCase { + + private final MockBigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); + + public void testCopyOnReadInputStreamDoesNotCopyMoreThanByteArraySize() throws Exception { + final byte[] expected = randomByteArray(); + + final ByteArray byteArray = bigArrays.newByteArray(randomIntBetween(0, DEFAULT_BYTE_ARRAY_SIZE)); + final InputStream stream = new CopyOnReadInputStream(new ByteArrayInputStream(expected), byteArray) { + @Override + protected void closeInternal(ReleasableBytesReference releasable) { + assertThat(getCount(), equalTo((long) expected.length)); + assertArrayEquals( + Arrays.copyOfRange(expected, 0, Math.toIntExact(Math.min(expected.length, byteArray.size()))), + BytesReference.toBytes(releasable) + ); + super.closeInternal(releasable); + } + }; + randomReads(stream, expected.length); + stream.close(); + } + + public void testCopyOnReadInputStream() throws Exception { + final byte[] expected = randomByteArray(); + final ByteArray byteArray = bigArrays.newByteArray(DEFAULT_BYTE_ARRAY_SIZE); + + final int maxBytesToRead = randomIntBetween(0, Math.toIntExact(Math.min(expected.length, byteArray.size()))); + final InputStream stream = new CopyOnReadInputStream(new ByteArrayInputStream(expected), byteArray) { + @Override + protected void closeInternal(ReleasableBytesReference releasable) { + assertThat(getCount(), equalTo((long) maxBytesToRead)); + assertArrayEquals(Arrays.copyOfRange(expected, 0, maxBytesToRead), BytesReference.toBytes(releasable)); + super.closeInternal(releasable); + } + }; + randomReads(stream, maxBytesToRead); + stream.close(); + } + + private static byte[] randomByteArray() { + return randomByteArrayOfLength(randomIntBetween(0, frequently() ? 512 : 1 << 20)); // rarely up to 1mb; + } + + private void randomReads(final InputStream stream, final int maxBytesToRead) throws IOException { + int remaining = maxBytesToRead; + while (remaining > 0) { + int read; + switch (randomInt(3)) { + case 0: // single byte read + read = stream.read(); + if (read != -1) { + remaining--; + } + break; + case 1: // buffered read with fixed buffer offset/length + read = stream.read(new byte[randomIntBetween(1, remaining)]); + if (read != -1) { + remaining -= read; + } + break; + case 2: // buffered read with random buffer offset/length + final byte[] tmp = new byte[randomIntBetween(1, remaining)]; + final int off = randomIntBetween(0, tmp.length - 1); + read = stream.read(tmp, off, randomIntBetween(1, Math.min(1, tmp.length - off))); + if (read != -1) { + remaining -= read; + } + break; + + case 3: // mark & reset with intermediate skip() + final int toSkip = randomIntBetween(1, remaining); + stream.mark(toSkip); + assertThat(stream.skip(toSkip), equalTo((long) toSkip)); + stream.reset(); + break; + default: + fail("Unsupported test condition in " + getTestName()); + } + } + } +} From 17448881b4d00364efe5c8cbb9dc816e3b1d80cf Mon Sep 17 00:00:00 2001 From: Tanguy Leroux Date: Tue, 21 Jul 2020 17:15:33 +0200 Subject: [PATCH 2/2] handle failures + move to listener --- .../blobstore/cache/CachedBlobContainer.java | 42 +++-- .../cache/CachedBlobContainerTests.java | 148 +++++++++++++++--- 2 files changed, 152 insertions(+), 38 deletions(-) diff --git a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/blobstore/cache/CachedBlobContainer.java b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/blobstore/cache/CachedBlobContainer.java index 8e88d9eced266..7fafea4990506 100644 --- a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/blobstore/cache/CachedBlobContainer.java +++ b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/blobstore/cache/CachedBlobContainer.java @@ -6,6 +6,8 @@ package org.elasticsearch.blobstore.cache; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.CheckedSupplier; import org.elasticsearch.common.blobstore.BlobContainer; import org.elasticsearch.common.blobstore.support.FilterBlobContainer; import org.elasticsearch.common.bytes.PagedBytesReference; @@ -37,24 +39,33 @@ protected BlobContainer wrapChild(BlobContainer child) { */ static class CopyOnReadInputStream extends FilterInputStream { + private final ActionListener listener; private final AtomicBoolean closed; private final ByteArray bytes; + private IOException failure; private long count; private long mark; - protected CopyOnReadInputStream(InputStream in, ByteArray byteArray) { + protected CopyOnReadInputStream(InputStream in, ByteArray byteArray, ActionListener listener) { super(in); + this.listener = Objects.requireNonNull(listener); this.bytes = Objects.requireNonNull(byteArray); this.closed = new AtomicBoolean(false); } - long getCount() { - return count; + private T handleFailure(CheckedSupplier supplier) throws IOException { + try { + return supplier.get(); + } catch (IOException e) { + assert failure == null; + failure = e; + throw e; + } } public int read() throws IOException { - final int result = super.read(); + final int result = handleFailure(super::read); if (result != -1) { if (count < bytes.size()) { bytes.set(count, (byte) result); @@ -65,7 +76,7 @@ public int read() throws IOException { } public int read(byte[] b, int off, int len) throws IOException { - final int result = super.read(b, off, len); + final int result = handleFailure(() -> super.read(b, off, len)); if (result != -1) { if (count < bytes.size()) { bytes.set(count, b, off, Math.toIntExact(Math.min(bytes.size() - count, result))); @@ -77,7 +88,7 @@ public int read(byte[] b, int off, int len) throws IOException { @Override public long skip(long n) throws IOException { - final long skip = super.skip(n); + final long skip = handleFailure(() -> super.skip(n)); if (skip > 0L) { count += skip; } @@ -92,23 +103,26 @@ public synchronized void mark(int readlimit) { @Override public synchronized void reset() throws IOException { - super.reset(); + handleFailure(() -> { + super.reset(); + return null; + }); count = mark; } - protected void closeInternal(final ReleasableBytesReference releasable) { - releasable.close(); - } - @Override public final void close() throws IOException { if (closed.compareAndSet(false, true)) { boolean success = false; try { super.close(); - final PagedBytesReference reference = new PagedBytesReference(bytes, Math.toIntExact(Math.min(count, bytes.size()))); - closeInternal(new ReleasableBytesReference(reference, bytes)); - success = true; + if (failure == null || bytes.size() <= count) { + PagedBytesReference reference = new PagedBytesReference(bytes, Math.toIntExact(Math.min(count, bytes.size()))); + listener.onResponse(new ReleasableBytesReference(reference, bytes)); + success = true; + } else { + listener.onFailure(failure); + } } finally { if (success == false) { bytes.close(); diff --git a/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/blobstore/cache/CachedBlobContainerTests.java b/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/blobstore/cache/CachedBlobContainerTests.java index f5c7ebf9a64e5..e049a9dc9947b 100644 --- a/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/blobstore/cache/CachedBlobContainerTests.java +++ b/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/blobstore/cache/CachedBlobContainerTests.java @@ -6,6 +6,8 @@ package org.elasticsearch.blobstore.cache; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.blobstore.cache.CachedBlobContainer.CopyOnReadInputStream; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; @@ -17,55 +19,153 @@ import org.elasticsearch.test.ESTestCase; import java.io.ByteArrayInputStream; +import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; import java.util.Arrays; import static org.elasticsearch.blobstore.cache.CachedBlobContainer.DEFAULT_BYTE_ARRAY_SIZE; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; public class CachedBlobContainerTests extends ESTestCase { private final MockBigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); public void testCopyOnReadInputStreamDoesNotCopyMoreThanByteArraySize() throws Exception { - final byte[] expected = randomByteArray(); + final SetOnce onSuccess = new SetOnce<>(); + final SetOnce onFailure = new SetOnce<>(); + final ActionListener listener = ActionListener.wrap(onSuccess::set, onFailure::set); + + final byte[] blobContent = randomByteArray(); final ByteArray byteArray = bigArrays.newByteArray(randomIntBetween(0, DEFAULT_BYTE_ARRAY_SIZE)); - final InputStream stream = new CopyOnReadInputStream(new ByteArrayInputStream(expected), byteArray) { - @Override - protected void closeInternal(ReleasableBytesReference releasable) { - assertThat(getCount(), equalTo((long) expected.length)); - assertArrayEquals( - Arrays.copyOfRange(expected, 0, Math.toIntExact(Math.min(expected.length, byteArray.size()))), - BytesReference.toBytes(releasable) - ); - super.closeInternal(releasable); - } - }; - randomReads(stream, expected.length); + final InputStream stream = new CopyOnReadInputStream(new ByteArrayInputStream(blobContent), byteArray, listener); + randomReads(stream, blobContent.length); stream.close(); + + final ReleasableBytesReference releasable = onSuccess.get(); + assertThat(releasable, notNullValue()); + assertThat(releasable.length(), equalTo(Math.toIntExact(Math.min(blobContent.length, byteArray.size())))); + assertArrayEquals(Arrays.copyOfRange(blobContent, 0, releasable.length()), BytesReference.toBytes(releasable)); + releasable.close(); + + final Exception failure = onFailure.get(); + assertThat(failure, nullValue()); } public void testCopyOnReadInputStream() throws Exception { - final byte[] expected = randomByteArray(); + final SetOnce onSuccess = new SetOnce<>(); + final SetOnce onFailure = new SetOnce<>(); + final ActionListener listener = ActionListener.wrap(onSuccess::set, onFailure::set); + + final byte[] blobContent = randomByteArray(); final ByteArray byteArray = bigArrays.newByteArray(DEFAULT_BYTE_ARRAY_SIZE); - final int maxBytesToRead = randomIntBetween(0, Math.toIntExact(Math.min(expected.length, byteArray.size()))); - final InputStream stream = new CopyOnReadInputStream(new ByteArrayInputStream(expected), byteArray) { + final int maxBytesToRead = randomIntBetween(0, blobContent.length); + final InputStream stream = new CopyOnReadInputStream(new ByteArrayInputStream(blobContent), byteArray, listener); + randomReads(stream, maxBytesToRead); + stream.close(); + + final ReleasableBytesReference releasable = onSuccess.get(); + assertThat(releasable, notNullValue()); + assertThat(releasable.length(), equalTo((int) Math.min(maxBytesToRead, byteArray.size()))); + assertArrayEquals(Arrays.copyOfRange(blobContent, 0, releasable.length()), BytesReference.toBytes(releasable)); + releasable.close(); + + final Exception failure = onFailure.get(); + assertThat(failure, nullValue()); + } + + public void testCopyOnReadWithFailure() throws Exception { + final SetOnce onSuccess = new SetOnce<>(); + final SetOnce onFailure = new SetOnce<>(); + final ActionListener listener = ActionListener.wrap(onSuccess::set, onFailure::set); + + final byte[] blobContent = new byte[0]; + randomByteArray(); + + // InputStream that throws an IOException once byte at position N is read/skipped + final int failAfterNBytesRead = randomIntBetween(0, Math.max(0, blobContent.length - 1)); + final InputStream erroneousStream = new FilterInputStream(new ByteArrayInputStream(blobContent)) { + + long bytesRead; + long mark; + + void canReadMoreBytes() throws IOException { + if (failAfterNBytesRead <= bytesRead) { + throw new IOException("Cannot read more bytes"); + } + } + @Override - protected void closeInternal(ReleasableBytesReference releasable) { - assertThat(getCount(), equalTo((long) maxBytesToRead)); - assertArrayEquals(Arrays.copyOfRange(expected, 0, maxBytesToRead), BytesReference.toBytes(releasable)); - super.closeInternal(releasable); + public int read() throws IOException { + canReadMoreBytes(); + final int read = super.read(); + if (read != -1) { + bytesRead++; + } + return read; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + canReadMoreBytes(); + final int read = super.read(b, off, Math.min(len, Math.toIntExact(failAfterNBytesRead - bytesRead))); + if (read != -1) { + bytesRead += read; + } + return read; + } + + @Override + public long skip(long n) throws IOException { + canReadMoreBytes(); + final long skipped = super.skip(Math.min(n, Math.toIntExact(failAfterNBytesRead - bytesRead))); + if (skipped > 0L) { + bytesRead += skipped; + } + return skipped; + } + + @Override + public synchronized void reset() throws IOException { + super.reset(); + bytesRead = mark; + } + + @Override + public synchronized void mark(int readlimit) { + super.mark(readlimit); + mark = bytesRead; } }; - randomReads(stream, maxBytesToRead); - stream.close(); + + final int byteSize = randomIntBetween(0, DEFAULT_BYTE_ARRAY_SIZE); + try (InputStream stream = new CopyOnReadInputStream(erroneousStream, bigArrays.newByteArray(byteSize), listener)) { + IOException exception = expectThrows(IOException.class, () -> randomReads(stream, Math.max(1, blobContent.length))); + assertThat(exception.getMessage(), containsString("Cannot read more bytes")); + } + + if (failAfterNBytesRead < byteSize) { + final Exception failure = onFailure.get(); + assertThat(failure, notNullValue()); + assertThat(failure.getMessage(), containsString("Cannot read more bytes")); + assertThat(onSuccess.get(), nullValue()); + + } else { + final ReleasableBytesReference releasable = onSuccess.get(); + assertThat(releasable, notNullValue()); + assertArrayEquals(Arrays.copyOfRange(blobContent, 0, byteSize), BytesReference.toBytes(releasable)); + assertThat(onFailure.get(), nullValue()); + releasable.close(); + } } private static byte[] randomByteArray() { - return randomByteArrayOfLength(randomIntBetween(0, frequently() ? 512 : 1 << 20)); // rarely up to 1mb; + return randomByteArrayOfLength(randomIntBetween(0, frequently() ? DEFAULT_BYTE_ARRAY_SIZE : 1 << 20)); // rarely up to 1mb; } private void randomReads(final InputStream stream, final int maxBytesToRead) throws IOException { @@ -97,7 +197,7 @@ private void randomReads(final InputStream stream, final int maxBytesToRead) thr case 3: // mark & reset with intermediate skip() final int toSkip = randomIntBetween(1, remaining); stream.mark(toSkip); - assertThat(stream.skip(toSkip), equalTo((long) toSkip)); + stream.skip(toSkip); stream.reset(); break; default: