diff --git a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/blobstore/cache/CachedBlobContainer.java b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/blobstore/cache/CachedBlobContainer.java
new file mode 100644
index 0000000000000..7fafea4990506
--- /dev/null
+++ b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/blobstore/cache/CachedBlobContainer.java
@@ -0,0 +1,134 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.blobstore.cache;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.common.blobstore.BlobContainer;
import org.elasticsearch.common.blobstore.support.FilterBlobContainer;
import org.elasticsearch.common.bytes.PagedBytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.util.ByteArray;

import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * A {@link FilterBlobContainer} that wraps its children in {@link CachedBlobContainer}s and provides
 * {@link CopyOnReadInputStream}, a stream wrapper used to capture the first bytes of a blob as it is read.
 */
public class CachedBlobContainer extends FilterBlobContainer {

    // Default capacity of the copy buffer: 16 KiB
    protected static final int DEFAULT_BYTE_ARRAY_SIZE = 1 << 14;

    public CachedBlobContainer(BlobContainer delegate) {
        super(delegate);
    }

    @Override
    protected BlobContainer wrapChild(BlobContainer child) {
        return new CachedBlobContainer(child);
    }

    /**
     * A {@link FilterInputStream} that copies over all the bytes read from the original input stream to a given {@link ByteArray}. The
     * number of bytes copied cannot exceed the size of the {@link ByteArray}. When the stream is closed, the listener is notified exactly
     * once: with a {@link ReleasableBytesReference} over the copied bytes on success (ownership of the {@link ByteArray} is transferred to
     * the reference), or with the first {@link IOException} seen if the copy is incomplete.
     */
    static class CopyOnReadInputStream extends FilterInputStream {

        // Completed exactly once on close(); takes ownership of `bytes` on success.
        private final ActionListener<ReleasableBytesReference> listener;
        // Guards against double-notification if close() is called more than once.
        private final AtomicBoolean closed;
        private final ByteArray bytes;

        // First IOException thrown by the wrapped stream, if any.
        private IOException failure;
        // Number of bytes consumed so far (reads and skips); may exceed bytes.size().
        private long count;
        // Value of `count` at the last mark().
        private long mark;

        protected CopyOnReadInputStream(InputStream in, ByteArray byteArray, ActionListener<ReleasableBytesReference> listener) {
            super(in);
            this.listener = Objects.requireNonNull(listener);
            this.bytes = Objects.requireNonNull(byteArray);
            this.closed = new AtomicBoolean(false);
        }

        // Executes the supplier, remembering the first IOException so that close() can report it.
        private <T> T handleFailure(CheckedSupplier<T, IOException> supplier) throws IOException {
            try {
                return supplier.get();
            } catch (IOException e) {
                assert failure == null;
                failure = e;
                throw e;
            }
        }

        @Override
        public int read() throws IOException {
            final int result = handleFailure(super::read);
            if (result != -1) {
                if (count < bytes.size()) {
                    bytes.set(count, (byte) result);
                }
                count++;
            }
            return result;
        }

        @Override
        public int read(byte[] b, int off, int len) throws IOException {
            final int result = handleFailure(() -> super.read(b, off, len));
            if (result != -1) {
                if (count < bytes.size()) {
                    // copy only up to the remaining capacity of the byte array
                    bytes.set(count, b, off, Math.toIntExact(Math.min(bytes.size() - count, result)));
                }
                count += result;
            }
            return result;
        }

        @Override
        public long skip(long n) throws IOException {
            // NOTE: skipped bytes count as consumed but are NOT copied; callers that need an exact
            // copy must mark()/reset() around skips (as the tests do).
            final long skip = handleFailure(() -> super.skip(n));
            if (skip > 0L) {
                count += skip;
            }
            return skip;
        }

        @Override
        public synchronized void mark(int readlimit) {
            super.mark(readlimit);
            mark = count;
        }

        @Override
        public synchronized void reset() throws IOException {
            handleFailure(() -> {
                super.reset();
                return null;
            });
            count = mark;
        }

        @Override
        public final void close() throws IOException {
            if (closed.compareAndSet(false, true)) {
                boolean success = false;
                try {
                    super.close();
                    // succeed if nothing failed, or if the array was completely filled before the failure occurred
                    if (failure == null || bytes.size() <= count) {
                        PagedBytesReference reference = new PagedBytesReference(bytes, Math.toIntExact(Math.min(count, bytes.size())));
                        listener.onResponse(new ReleasableBytesReference(reference, bytes));
                        success = true;
                    } else {
                        listener.onFailure(failure);
                    }
                } finally {
                    if (success == false) {
                        // the listener did not take ownership: release the array here
                        bytes.close();
                    }
                }
            }
        }
    }
}
diff --git a/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/blobstore/cache/CachedBlobContainerTests.java b/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/blobstore/cache/CachedBlobContainerTests.java
new file mode 100644
index 0000000000000..e049a9dc9947b
--- /dev/null
+++ b/x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/blobstore/cache/CachedBlobContainerTests.java
@@ -0,0 +1,208 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.blobstore.cache;

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.blobstore.cache.CachedBlobContainer.CopyOnReadInputStream;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.ByteArray;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.test.ESTestCase;

import java.io.ByteArrayInputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;

import static org.elasticsearch.blobstore.cache.CachedBlobContainer.DEFAULT_BYTE_ARRAY_SIZE;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static
org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;

/**
 * Unit tests for {@link CachedBlobContainer.CopyOnReadInputStream}.
 */
public class CachedBlobContainerTests extends ESTestCase {

    private final MockBigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());

    public void testCopyOnReadInputStreamDoesNotCopyMoreThanByteArraySize() throws Exception {
        final SetOnce<ReleasableBytesReference> onSuccess = new SetOnce<>();
        final SetOnce<Exception> onFailure = new SetOnce<>();
        final ActionListener<ReleasableBytesReference> listener = ActionListener.wrap(onSuccess::set, onFailure::set);

        final byte[] blobContent = randomByteArray();

        // the byte array may be smaller than the blob: only the first byteArray.size() bytes must be copied
        final ByteArray byteArray = bigArrays.newByteArray(randomIntBetween(0, DEFAULT_BYTE_ARRAY_SIZE));
        final InputStream stream = new CopyOnReadInputStream(new ByteArrayInputStream(blobContent), byteArray, listener);
        randomReads(stream, blobContent.length);
        stream.close();

        final ReleasableBytesReference releasable = onSuccess.get();
        assertThat(releasable, notNullValue());
        assertThat(releasable.length(), equalTo(Math.toIntExact(Math.min(blobContent.length, byteArray.size()))));
        assertArrayEquals(Arrays.copyOfRange(blobContent, 0, releasable.length()), BytesReference.toBytes(releasable));
        releasable.close();

        final Exception failure = onFailure.get();
        assertThat(failure, nullValue());
    }

    public void testCopyOnReadInputStream() throws Exception {
        final SetOnce<ReleasableBytesReference> onSuccess = new SetOnce<>();
        final SetOnce<Exception> onFailure = new SetOnce<>();
        final ActionListener<ReleasableBytesReference> listener = ActionListener.wrap(onSuccess::set, onFailure::set);

        final byte[] blobContent = randomByteArray();
        final ByteArray byteArray = bigArrays.newByteArray(DEFAULT_BYTE_ARRAY_SIZE);

        // read only part of the blob: the copy must contain exactly the bytes that were read
        final int maxBytesToRead = randomIntBetween(0, blobContent.length);
        final InputStream stream = new CopyOnReadInputStream(new ByteArrayInputStream(blobContent), byteArray, listener);
        randomReads(stream, maxBytesToRead);
        stream.close();

        final ReleasableBytesReference releasable = onSuccess.get();
        assertThat(releasable, notNullValue());
        assertThat(releasable.length(), equalTo((int) Math.min(maxBytesToRead, byteArray.size())));
        assertArrayEquals(Arrays.copyOfRange(blobContent, 0, releasable.length()), BytesReference.toBytes(releasable));
        releasable.close();

        final Exception failure = onFailure.get();
        assertThat(failure, nullValue());
    }

    public void testCopyOnReadWithFailure() throws Exception {
        final SetOnce<ReleasableBytesReference> onSuccess = new SetOnce<>();
        final SetOnce<Exception> onFailure = new SetOnce<>();
        final ActionListener<ReleasableBytesReference> listener = ActionListener.wrap(onSuccess::set, onFailure::set);

        // fix: the result of randomByteArray() was previously discarded (blobContent was always
        // new byte[0]), which made this test degenerate — it never exercised non-empty blobs
        final byte[] blobContent = randomByteArray();

        // InputStream that throws an IOException once byte at position N is read/skipped
        final int failAfterNBytesRead = randomIntBetween(0, Math.max(0, blobContent.length - 1));
        final InputStream erroneousStream = new FilterInputStream(new ByteArrayInputStream(blobContent)) {

            long bytesRead;
            long mark;

            void canReadMoreBytes() throws IOException {
                if (failAfterNBytesRead <= bytesRead) {
                    throw new IOException("Cannot read more bytes");
                }
            }

            @Override
            public int read() throws IOException {
                canReadMoreBytes();
                final int read = super.read();
                if (read != -1) {
                    bytesRead++;
                }
                return read;
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                canReadMoreBytes();
                // cap the read so the failure happens exactly at position N
                final int read = super.read(b, off, Math.min(len, Math.toIntExact(failAfterNBytesRead - bytesRead)));
                if (read != -1) {
                    bytesRead += read;
                }
                return read;
            }

            @Override
            public long skip(long n) throws IOException {
                canReadMoreBytes();
                final long skipped = super.skip(Math.min(n, Math.toIntExact(failAfterNBytesRead - bytesRead)));
                if (skipped > 0L) {
                    bytesRead += skipped;
                }
                return skipped;
            }

            @Override
            public synchronized void reset() throws IOException {
                super.reset();
                bytesRead = mark;
            }

            @Override
            public synchronized void mark(int readlimit) {
                super.mark(readlimit);
                mark = bytesRead;
            }
        };

        final int byteSize = randomIntBetween(0, DEFAULT_BYTE_ARRAY_SIZE);
        try (InputStream stream = new CopyOnReadInputStream(erroneousStream, bigArrays.newByteArray(byteSize), listener)) {
            IOException exception = expectThrows(IOException.class, () -> randomReads(stream, Math.max(1, blobContent.length)));
            assertThat(exception.getMessage(), containsString("Cannot read more bytes"));
        }

        if (failAfterNBytesRead < byteSize) {
            // the copy is incomplete: the failure must be propagated to the listener
            final Exception failure = onFailure.get();
            assertThat(failure, notNullValue());
            assertThat(failure.getMessage(), containsString("Cannot read more bytes"));
            assertThat(onSuccess.get(), nullValue());

        } else {
            // the byte array was completely filled before the failure: the copy is still usable
            final ReleasableBytesReference releasable = onSuccess.get();
            assertThat(releasable, notNullValue());
            assertArrayEquals(Arrays.copyOfRange(blobContent, 0, byteSize), BytesReference.toBytes(releasable));
            assertThat(onFailure.get(), nullValue());
            releasable.close();
        }
    }

    private static byte[] randomByteArray() {
        return randomByteArrayOfLength(randomIntBetween(0, frequently() ? DEFAULT_BYTE_ARRAY_SIZE : 1 << 20)); // rarely up to 1mb
    }

    private void randomReads(final InputStream stream, final int maxBytesToRead) throws IOException {
        int remaining = maxBytesToRead;
        while (remaining > 0) {
            int read;
            switch (randomInt(3)) {
                case 0: // single byte read
                    read = stream.read();
                    if (read != -1) {
                        remaining--;
                    }
                    break;
                case 1: // buffered read with fixed buffer offset/length
                    read = stream.read(new byte[randomIntBetween(1, remaining)]);
                    if (read != -1) {
                        remaining -= read;
                    }
                    break;
                case 2: // buffered read with random buffer offset/length
                    final byte[] tmp = new byte[randomIntBetween(1, remaining)];
                    final int off = randomIntBetween(0, tmp.length - 1);
                    // fix: Math.min(1, tmp.length - off) forced the read length to always be 1,
                    // defeating the randomization this case is meant to provide
                    read = stream.read(tmp, off, randomIntBetween(1, tmp.length - off));
                    if (read != -1) {
                        remaining -= read;
                    }
                    break;

                case 3: // mark & reset with intermediate skip()
                    final int toSkip = randomIntBetween(1, remaining);
                    stream.mark(toSkip);
                    stream.skip(toSkip);
                    stream.reset();
                    break;
                default:
                    fail("Unsupported test condition in " + getTestName());
            }
        }
    }
}