Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

package org.elasticsearch.common.io.stream;

import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Constants;
import org.elasticsearch.common.bytes.BytesArray;
Expand All @@ -25,7 +26,7 @@
import org.elasticsearch.core.Tuple;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.TransportVersionUtils;
import org.elasticsearch.transport.BytesRefRecycler;
import org.junit.After;

import java.io.EOFException;
import java.io.IOException;
Expand All @@ -34,6 +35,7 @@
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
Expand All @@ -42,6 +44,7 @@
import java.util.Objects;
import java.util.TreeMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import java.util.stream.Collectors;
Expand All @@ -60,11 +63,65 @@
import static org.hamcrest.Matchers.nullValue;

/**
* Tests for {@link StreamOutput}.
* Tests for {@link RecyclerBytesStreamOutput}.
*/
public class RecyclerBytesStreamOutputTests extends ESTestCase {

private final Recycler<BytesRef> recycler = new BytesRefRecycler(PageCacheRecycler.NON_RECYCLING_INSTANCE);
private final AtomicInteger activePageCount = new AtomicInteger();

private final Recycler<BytesRef> recycler = new Recycler<>() {

@Override
public V<BytesRef> obtain() {
activePageCount.incrementAndGet();
final var bufferPool = randomByteArrayOfLength(between(pageSize(), pageSize() * 4));
final var offset = randomBoolean()
// align to a multiple of pageSize(), which is the usual case in production
? between(0, bufferPool.length / pageSize() - 1) * pageSize()
// no alignment, to detect cases where alignment is assumed
: between(0, bufferPool.length - pageSize());
final var bytesRef = new BytesRef(bufferPool, offset, pageSize());

final var bufferPoolCopy = ArrayUtil.copyArray(bufferPool); // keep for a later check for out-of-bounds writes
final var dummyByte = randomByte();
Arrays.fill(bufferPoolCopy, offset, offset + pageSize(), dummyByte);

return new V<>() {
@Override
public BytesRef v() {
return bytesRef;
}

@Override
public boolean isRecycled() {
throw new AssertionError("shouldn't matter");
}

@Override
public void close() {
// page must not be changed
assertSame(bufferPool, bytesRef.bytes);
assertEquals(offset, bytesRef.offset);
assertEquals(pageSize(), bytesRef.length);

Arrays.fill(bufferPool, offset, offset + pageSize(), dummyByte); // overwrite buffer contents to detect use-after-free
assertArrayEquals(bufferPoolCopy, bufferPool); // remainder of pool must be unmodified

activePageCount.decrementAndGet();
}
};
}

@Override
public int pageSize() {
return PageCacheRecycler.BYTE_PAGE_SIZE;
}
};

@After
public void ensureClosed() {
assertEquals(0, activePageCount.getAndSet(0));
}

public void testEmpty() throws Exception {
RecyclerBytesStreamOutput out = new RecyclerBytesStreamOutput(recycler);
Expand Down Expand Up @@ -589,6 +646,7 @@ public void testWriteMap() throws IOException {

assertThat(loaded.size(), equalTo(expected.size()));
assertThat(expected, equalTo(loaded));
out.close();
}

public void testWriteImmutableMap() throws IOException {
Expand All @@ -605,6 +663,7 @@ public void testWriteImmutableMap() throws IOException {
final ImmutableOpenMap<String, String> loaded = in.readImmutableOpenMap(StreamInput::readString, StreamInput::readString);

assertThat(expected, equalTo(loaded));
out.close();
}

public void testWriteImmutableMapOfWritable() throws IOException {
Expand All @@ -621,6 +680,7 @@ public void testWriteImmutableMapOfWritable() throws IOException {
final ImmutableOpenMap<TestWriteable, TestWriteable> loaded = in.readImmutableOpenMap(TestWriteable::new, TestWriteable::new);

assertThat(expected, equalTo(loaded));
out.close();
}

public void testWriteMapAsList() throws IOException {
Expand All @@ -638,6 +698,7 @@ public void testWriteMapAsList() throws IOException {

assertThat(loaded.size(), equalTo(expected.size()));
assertThat(expected, equalTo(loaded));
out.close();
}

private abstract static class BaseNamedWriteable implements NamedWriteable {
Expand Down Expand Up @@ -728,7 +789,25 @@ public void testWriteLongToCompletePage() throws IOException {
}

public void testRandomWritesAndSeeks() throws IOException {
try (RecyclerBytesStreamOutput out = new RecyclerBytesStreamOutput(recycler)) {
try (RecyclerBytesStreamOutput out = new RecyclerBytesStreamOutput(new Recycler<>() {
@Override
public V<BytesRef> obtain() {
final var result = recycler.obtain();
final var bytesRef = result.v();
// This seems kinda trappy: a recycler doesn't guarantee anything about the contents of the buffers it supplies, and in
// practice it might contain data left there by the previous user. As used today this is all ok, we always overwrite
// everything eventually in all production usages, but it seems like it might cause problems at some point.
// TODO should we wipe these contents when extending the stream with a seek like this just to be on the safe side?
// In the meantime, for this test only, zero out the buffer contents so that it matches expectedBuffer.
Comment on lines +797 to +801
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NB this bit, should we address this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe introduce friends to seek and skip that perform filling, when in doubt? And rename current seek/skip as unsafeSeek saying there might be garbage in between.

Arrays.fill(bytesRef.bytes, bytesRef.offset, bytesRef.offset + bytesRef.length, (byte) 0);
return result;
}

@Override
public int pageSize() {
return recycler.pageSize();
}
})) {

final byte[] expectedBuffer = new byte[between(0, PageCacheRecycler.BYTE_PAGE_SIZE * 4)];
int currentPos = 0;
Expand Down Expand Up @@ -758,9 +837,7 @@ public void testRandomWritesAndSeeks() throws IOException {
}
}

final byte[] expected = new byte[currentPos];
System.arraycopy(expectedBuffer, 0, expected, 0, currentPos);
assertArrayEquals(expected, BytesReference.toBytes(out.bytes()));
assertThat(out.bytes(), equalBytes(new BytesArray(expectedBuffer, 0, currentPos)));
}
}

Expand Down Expand Up @@ -859,7 +936,7 @@ public void testWriteRandomStrings() throws IOException {
output.writeString(s);
}

try (StreamInput streamInput = output.bytes().streamInput()) {
try (output; StreamInput streamInput = output.bytes().streamInput()) {
for (int i = 0; i < numStrings; i++) {
String s = streamInput.readString();
assertEquals(strings.get(i), s);
Expand All @@ -878,7 +955,7 @@ public void testWriteLargeSurrogateOnlyString() throws IOException {
assertEquals("expands to 4 bytes", 4, new BytesRef(deseretLetter).length);
try (RecyclerBytesStreamOutput output = new RecyclerBytesStreamOutput(recycler)) {
output.writeString(largeString);
try (StreamInput streamInput = output.bytes().streamInput()) {
try (output; StreamInput streamInput = output.bytes().streamInput()) {
assertEquals(largeString, streamInput.readString());
}
}
Expand All @@ -895,7 +972,7 @@ public void testReadTooLargeArraySize() throws IOException {
for (int i = 0; i < 10; i++) {
output.writeInt(i);
}
try (StreamInput streamInput = output.bytes().streamInput()) {
try (output; StreamInput streamInput = output.bytes().streamInput()) {
int[] ints = streamInput.readIntArray();
for (int i = 0; i < 10; i++) {
assertEquals(i, ints[i]);
Expand Down Expand Up @@ -962,16 +1039,17 @@ public void testVInt() throws IOException {
}
simple.writeByte((byte) i);
assertEquals(simple.bytes().toBytesRef().toString(), output.bytes().toBytesRef().toString());
simple.close();

StreamInput input = output.bytes().streamInput();
assertEquals(value, input.readVInt());
output.close();
}

public void testVLong() throws IOException {
final long value = randomLong();
{
try (RecyclerBytesStreamOutput output = new RecyclerBytesStreamOutput(recycler)) {
// Read works for positive and negative numbers
RecyclerBytesStreamOutput output = new RecyclerBytesStreamOutput(recycler);
output.writeVLongNoCheck(value); // Use NoCheck variant so we can write negative numbers
StreamInput input = output.bytes().streamInput();
assertEquals(value, input.readVLong());
Expand All @@ -981,6 +1059,7 @@ public void testVLong() throws IOException {
RecyclerBytesStreamOutput output = new RecyclerBytesStreamOutput(recycler);
Exception e = expectThrows(IllegalStateException.class, () -> output.writeVLong(value));
assertEquals("Negative longs unsupported, use writeLong or writeZLong for negative numbers [" + value + "]", e.getMessage());
output.close();
}
}

Expand All @@ -997,6 +1076,7 @@ public void testEnum() throws IOException {
StreamInput input = output.bytes().streamInput();
assertEquals(value, input.readEnum(TestEnum.class));
assertEquals(0, input.available());
output.close();
}

public void testInvalidEnum() throws IOException {
Expand All @@ -1012,6 +1092,7 @@ public void testInvalidEnum() throws IOException {
assertEquals("Unknown TestEnum ordinal [" + randomNumber + "]", ex.getMessage());
}
assertEquals(0, input.available());
output.close();
}

private void assertEqualityAfterSerialize(TimeValue value, int expectedSize) throws IOException {
Expand All @@ -1025,6 +1106,7 @@ private void assertEqualityAfterSerialize(TimeValue value, int expectedSize) thr
assertThat(inValue, equalTo(value));
assertThat(inValue.duration(), equalTo(value.duration()));
assertThat(inValue.timeUnit(), equalTo(value.timeUnit()));
out.close();
}

public void testTimeValueSerialize() throws Exception {
Expand All @@ -1037,6 +1119,7 @@ public void testTimeValueSerialize() throws Exception {
RecyclerBytesStreamOutput out = new RecyclerBytesStreamOutput(recycler);
out.writeZLong(timeValue.duration());
assertEqualityAfterSerialize(timeValue, 1 + out.bytes().length());
out.close();
}

public void testOverflow() {
Expand Down Expand Up @@ -1098,6 +1181,7 @@ public void testSeekToPageBoundary() {
byte b = randomByte();
out.writeByte(b);
assertEquals(b, out.bytes().get(PageCacheRecycler.BYTE_PAGE_SIZE));
out.close();
}

public void testWriteIntFallbackToSuperClass() throws IOException {
Expand Down