diff --git a/server/src/test/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutputTests.java b/server/src/test/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutputTests.java index 27e86e3f227a7..5e0dfc3ba780b 100644 --- a/server/src/test/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutputTests.java +++ b/server/src/test/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutputTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.common.io.stream; +import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.Constants; import org.elasticsearch.common.bytes.BytesArray; @@ -25,7 +26,7 @@ import org.elasticsearch.core.Tuple; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.TransportVersionUtils; -import org.elasticsearch.transport.BytesRefRecycler; +import org.junit.After; import java.io.EOFException; import java.io.IOException; @@ -34,6 +35,7 @@ import java.time.ZoneId; import java.time.ZonedDateTime; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; @@ -42,6 +44,7 @@ import java.util.Objects; import java.util.TreeMap; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -60,11 +63,65 @@ import static org.hamcrest.Matchers.nullValue; /** - * Tests for {@link StreamOutput}. + * Tests for {@link RecyclerBytesStreamOutput}. */ public class RecyclerBytesStreamOutputTests extends ESTestCase { - private final Recycler recycler = new BytesRefRecycler(PageCacheRecycler.NON_RECYCLING_INSTANCE); + private final AtomicInteger activePageCount = new AtomicInteger(); + + private final Recycler recycler = new Recycler<>() { + + @Override + public V obtain() { + activePageCount.incrementAndGet(); + final var bufferPool = randomByteArrayOfLength(between(pageSize(), pageSize() * 4)); + final var offset = randomBoolean() + // align to a multiple of pageSize(), which is the usual case in production + ? between(0, bufferPool.length / pageSize() - 1) * pageSize() + // no alignment, to detect cases where alignment is assumed + : between(0, bufferPool.length - pageSize()); + final var bytesRef = new BytesRef(bufferPool, offset, pageSize()); + + final var bufferPoolCopy = ArrayUtil.copyArray(bufferPool); // keep for a later check for out-of-bounds writes + final var dummyByte = randomByte(); + Arrays.fill(bufferPoolCopy, offset, offset + pageSize(), dummyByte); + + return new V<>() { + @Override + public BytesRef v() { + return bytesRef; + } + + @Override + public boolean isRecycled() { + throw new AssertionError("shouldn't matter"); + } + + @Override + public void close() { + // page must not be changed + assertSame(bufferPool, bytesRef.bytes); + assertEquals(offset, bytesRef.offset); + assertEquals(pageSize(), bytesRef.length); + + Arrays.fill(bufferPool, offset, offset + pageSize(), dummyByte); // overwrite buffer contents to detect use-after-free + assertArrayEquals(bufferPoolCopy, bufferPool); // remainder of pool must be unmodified + + activePageCount.decrementAndGet(); + } + }; + } + + @Override + public int pageSize() { + return PageCacheRecycler.BYTE_PAGE_SIZE; + } + }; + + @After + public void ensureClosed() { + assertEquals(0, activePageCount.getAndSet(0)); + } public void testEmpty() throws Exception { RecyclerBytesStreamOutput out = new RecyclerBytesStreamOutput(recycler); @@ -589,6 +646,7 @@ public void testWriteMap() throws IOException { assertThat(loaded.size(), equalTo(expected.size())); assertThat(expected, equalTo(loaded)); + out.close(); } public void testWriteImmutableMap() throws IOException { @@ -605,6 +663,7 @@ public void testWriteImmutableMap() throws IOException { final ImmutableOpenMap loaded = in.readImmutableOpenMap(StreamInput::readString, StreamInput::readString); assertThat(expected, equalTo(loaded)); + out.close(); } public void testWriteImmutableMapOfWritable() throws IOException { @@ -621,6 +680,7 @@ public void testWriteImmutableMapOfWritable() throws IOException { final ImmutableOpenMap loaded = in.readImmutableOpenMap(TestWriteable::new, TestWriteable::new); assertThat(expected, equalTo(loaded)); + out.close(); } public void testWriteMapAsList() throws IOException { @@ -638,6 +698,7 @@ public void testWriteMapAsList() throws IOException { assertThat(loaded.size(), equalTo(expected.size())); assertThat(expected, equalTo(loaded)); + out.close(); } private abstract static class BaseNamedWriteable implements NamedWriteable { @@ -728,7 +789,25 @@ public void testWriteLongToCompletePage() throws IOException { } public void testRandomWritesAndSeeks() throws IOException { - try (RecyclerBytesStreamOutput out = new RecyclerBytesStreamOutput(recycler)) { + try (RecyclerBytesStreamOutput out = new RecyclerBytesStreamOutput(new Recycler<>() { + @Override + public V obtain() { + final var result = recycler.obtain(); + final var bytesRef = result.v(); + // This seems kinda trappy: a recycler doesn't guarantee anything about the contents of the buffers it supplies, and in + // practice it might contain data left there by the previous user. As used today this is all ok, we always overwrite + // everything eventually in all production usages, but it seems like it might cause problems at some point. + // TODO should we wipe these contents when extending the stream with a seek like this just to be on the safe side? + // In the meantime, for this test only, zero out the buffer contents so that it matches expectedBuffer. + Arrays.fill(bytesRef.bytes, bytesRef.offset, bytesRef.offset + bytesRef.length, (byte) 0); + return result; + } + + @Override + public int pageSize() { + return recycler.pageSize(); + } + })) { final byte[] expectedBuffer = new byte[between(0, PageCacheRecycler.BYTE_PAGE_SIZE * 4)]; int currentPos = 0; @@ -758,9 +837,7 @@ public void testRandomWritesAndSeeks() throws IOException { } } - final byte[] expected = new byte[currentPos]; - System.arraycopy(expectedBuffer, 0, expected, 0, currentPos); - assertArrayEquals(expected, BytesReference.toBytes(out.bytes())); + assertThat(out.bytes(), equalBytes(new BytesArray(expectedBuffer, 0, currentPos))); } } @@ -859,7 +936,7 @@ public void testWriteRandomStrings() throws IOException { output.writeString(s); } - try (StreamInput streamInput = output.bytes().streamInput()) { + try (output; StreamInput streamInput = output.bytes().streamInput()) { for (int i = 0; i < numStrings; i++) { String s = streamInput.readString(); assertEquals(strings.get(i), s); @@ -878,7 +955,7 @@ public void testWriteLargeSurrogateOnlyString() throws IOException { assertEquals("expands to 4 bytes", 4, new BytesRef(deseretLetter).length); try (RecyclerBytesStreamOutput output = new RecyclerBytesStreamOutput(recycler)) { output.writeString(largeString); - try (StreamInput streamInput = output.bytes().streamInput()) { + try (output; StreamInput streamInput = output.bytes().streamInput()) { assertEquals(largeString, streamInput.readString()); } } @@ -895,7 +972,7 @@ public void testReadTooLargeArraySize() throws IOException { for (int i = 0; i < 10; i++) { output.writeInt(i); } - try (StreamInput streamInput = output.bytes().streamInput()) { + try (output; StreamInput streamInput = output.bytes().streamInput()) { int[] ints = streamInput.readIntArray(); for (int i = 0; i < 10; i++) { assertEquals(i, ints[i]); @@ -962,16 +1039,17 @@ public void testVInt() throws IOException { } simple.writeByte((byte) i); assertEquals(simple.bytes().toBytesRef().toString(), output.bytes().toBytesRef().toString()); + simple.close(); StreamInput input = output.bytes().streamInput(); assertEquals(value, input.readVInt()); + output.close(); } public void testVLong() throws IOException { final long value = randomLong(); - { + try (RecyclerBytesStreamOutput output = new RecyclerBytesStreamOutput(recycler)) { // Read works for positive and negative numbers - RecyclerBytesStreamOutput output = new RecyclerBytesStreamOutput(recycler); output.writeVLongNoCheck(value); // Use NoCheck variant so we can write negative numbers StreamInput input = output.bytes().streamInput(); assertEquals(value, input.readVLong()); @@ -981,6 +1059,7 @@ public void testVLong() throws IOException { RecyclerBytesStreamOutput output = new RecyclerBytesStreamOutput(recycler); Exception e = expectThrows(IllegalStateException.class, () -> output.writeVLong(value)); assertEquals("Negative longs unsupported, use writeLong or writeZLong for negative numbers [" + value + "]", e.getMessage()); + output.close(); } } @@ -997,6 +1076,7 @@ public void testEnum() throws IOException { StreamInput input = output.bytes().streamInput(); assertEquals(value, input.readEnum(TestEnum.class)); assertEquals(0, input.available()); + output.close(); } public void testInvalidEnum() throws IOException { @@ -1012,6 +1092,7 @@ public void testInvalidEnum() throws IOException { assertEquals("Unknown TestEnum ordinal [" + randomNumber + "]", ex.getMessage()); } assertEquals(0, input.available()); + output.close(); } private void assertEqualityAfterSerialize(TimeValue value, int expectedSize) throws IOException { @@ -1025,6 +1106,7 @@ private void assertEqualityAfterSerialize(TimeValue value, int expectedSize) thr assertThat(inValue, equalTo(value)); assertThat(inValue.duration(), equalTo(value.duration())); assertThat(inValue.timeUnit(), equalTo(value.timeUnit())); + out.close(); } public void testTimeValueSerialize() throws Exception { @@ -1037,6 +1119,7 @@ public void testTimeValueSerialize() throws Exception { RecyclerBytesStreamOutput out = new RecyclerBytesStreamOutput(recycler); out.writeZLong(timeValue.duration()); assertEqualityAfterSerialize(timeValue, 1 + out.bytes().length()); + out.close(); } public void testOverflow() { @@ -1098,6 +1181,7 @@ public void testSeekToPageBoundary() { byte b = randomByte(); out.writeByte(b); assertEquals(b, out.bytes().get(PageCacheRecycler.BYTE_PAGE_SIZE)); + out.close(); } public void testWriteIntFallbackToSuperClass() throws IOException {