diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutput.java b/server/src/main/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutput.java index 13fd49faf41b4..4ce1157a034ff 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutput.java @@ -10,6 +10,7 @@ package org.elasticsearch.common.io.stream; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.CompositeBytesReference; @@ -62,8 +63,12 @@ */ public class RecyclerBytesStreamOutput extends BytesStream implements Releasable { - private ArrayList> pages = new ArrayList<>(8); private final Recycler recycler; + + @Nullable // if no circuit breaker in use + private final CircuitBreaker circuitBreaker; + + private ArrayList> pages = new ArrayList<>(8); private final int pageSize; private int pageIndex = -1; private int currentCapacity = 0; @@ -89,7 +94,12 @@ public class RecyclerBytesStreamOutput extends BytesStream implements Releasable private long positionOffset; public RecyclerBytesStreamOutput(Recycler recycler) { + this(recycler, null); + } + + public RecyclerBytesStreamOutput(Recycler recycler, @Nullable CircuitBreaker circuitBreaker) { this.recycler = recycler; + this.circuitBreaker = circuitBreaker; this.pageSize = recycler.pageSize(); this.currentOffset = this.maxOffset = pageSize; // Always start with a page. This is because if we don't have a page, one of the hot write paths would be forced to go through @@ -280,7 +290,7 @@ public void legacyWriteWithSizePrefix(Writeable writeable) throws IOException { // TODO: do this without copying the bytes from tmp by calling writeBytes and just use the pages in tmp directly through // manipulation of the offsets on the pages after writing to tmp. This will require adjustments to the places in this class // that make assumptions about the page size - try (RecyclerBytesStreamOutput tmp = new RecyclerBytesStreamOutput(recycler)) { + try (RecyclerBytesStreamOutput tmp = new RecyclerBytesStreamOutput(recycler, circuitBreaker)) { tmp.setTransportVersion(getTransportVersion()); writeable.writeTo(tmp); int size = tmp.size(); @@ -415,6 +425,9 @@ public void close() { if (pages != null) { closeFields(); Releasables.close(pages); + if (circuitBreaker != null) { + circuitBreaker.addWithoutBreaking(-(long) pageSize * pages.size()); + } } } @@ -429,7 +442,24 @@ public ReleasableBytesReference moveToBytesReference() { var pages = this.pages; closeFields(); - return new ReleasableBytesReference(bytes, pages.size() == 1 ? pages.getFirst() : Releasables.wrap(pages)); + final Releasable releasable; + if (pages.size() == 1) { + if (circuitBreaker == null) { + releasable = pages.getFirst(); + } else { + final var pageSize = this.pageSize; + releasable = Releasables.wrap(pages.getFirst(), () -> circuitBreaker.addWithoutBreaking(-pageSize)); + } + } else { + if (circuitBreaker == null) { + releasable = Releasables.wrap(pages); + } else { + final long releaseSize = (long) this.pageSize * pages.size(); + releasable = Releasables.wrap(Releasables.wrap(pages), () -> circuitBreaker.addWithoutBreaking(-releaseSize)); + } + } + + return new ReleasableBytesReference(bytes, releasable); } /** @@ -594,10 +624,22 @@ private void ensureCapacityFromPosition(long newPosition) { // Calculate number of additional pages needed int additionalPagesNeeded = (int) ((additionalCapacityNeeded + pageSize - 1) / pageSize); pages.ensureCapacity(pages.size() + additionalPagesNeeded); - for (int i = 0; i < additionalPagesNeeded; i++) { - Recycler.V newPage = recycler.obtain(); - assert pageSize == newPage.v().length; - pages.add(newPage); + + if (circuitBreaker != null) { + circuitBreaker.addEstimateBytesAndMaybeBreak((long) pageSize * additionalPagesNeeded, "RecyclerBytesStreamOutput"); + } + int pagesAdded = 0; + try { + while (pagesAdded < additionalPagesNeeded) { + Recycler.V newPage = recycler.obtain(); + assert pageSize == newPage.v().length; + pages.add(newPage); + pagesAdded += 1; + } + } finally { + if (circuitBreaker != null && pagesAdded < additionalPagesNeeded) { + circuitBreaker.addWithoutBreaking((long) pageSize * (pagesAdded - additionalPagesNeeded)); + } } currentCapacity += additionalPagesNeeded * pageSize; } diff --git a/server/src/test/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutputTests.java b/server/src/test/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutputTests.java index 7f283cce68f19..58082e31c3e34 100644 --- a/server/src/test/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutputTests.java +++ b/server/src/test/java/org/elasticsearch/common/io/stream/RecyclerBytesStreamOutputTests.java @@ -11,6 +11,7 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.Constants; +import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; @@ -19,7 +20,9 @@ import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.common.recycler.Recycler; import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.Maps; +import org.elasticsearch.common.util.MockBigArrays; import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.core.Assertions; import org.elasticsearch.core.TimeValue; @@ -1557,17 +1560,104 @@ public void testToBase64String() throws IOException { public void testWriteAllBytesFrom() throws IOException { final var bytes = randomBytesReference(between(0, PageCacheRecycler.BYTE_PAGE_SIZE * 4)); try (var out = new RecyclerBytesStreamOutput(recycler)) { - if (randomBoolean()) { - out.writeAllBytesFrom(bytes.streamInput()); - } else { - var remaining = bytes; - while (remaining.length() > 0) { - var thisSlice = remaining.slice(0, between(1, remaining.length())); - remaining = remaining.slice(thisSlice.length(), remaining.length() - thisSlice.length()); - out.writeAllBytesFrom(thisSlice.streamInput()); + writeAllBytesRandomSlices(out, bytes); + assertThat(out.bytes(), equalBytes(bytes)); + } + } + + public void testCircuitBreakerTracking() throws IOException { + final var bytes = randomBytesReference(between(0, PageCacheRecycler.BYTE_PAGE_SIZE * 4)); + final var expectedAllocation = getExpectedAllocation(bytes.length()); + final var circuitBreaker = new MockBigArrays.LimitedBreaker("test", ByteSizeValue.ofBytes(expectedAllocation + between(0, 100))); + try (var out = new RecyclerBytesStreamOutput(recycler, circuitBreaker)) { + writeAllBytesRandomSlices(out, bytes); + assertThat(out.bytes(), equalBytes(bytes)); + assertThat(circuitBreaker.getUsed(), equalTo(expectedAllocation)); + } + + assertThat(circuitBreaker.getUsed(), equalTo(0L)); + } + + public void testCircuitBreakerMoveToBytesReference() throws IOException { + final var bytes = randomBytesReference(between(0, PageCacheRecycler.BYTE_PAGE_SIZE * 4)); + final var expectedTracked = getExpectedAllocation(bytes.length()); + final var circuitBreaker = new MockBigArrays.LimitedBreaker("test", ByteSizeValue.ofBytes(expectedTracked + between(0, 100))); + final ReleasableBytesReference actualBytes; + try (var out = new RecyclerBytesStreamOutput(recycler, circuitBreaker)) { + writeAllBytesRandomSlices(out, bytes); + assertThat(circuitBreaker.getUsed(), equalTo(expectedTracked)); + actualBytes = out.moveToBytesReference(); + assertThat(circuitBreaker.getUsed(), equalTo(expectedTracked)); + } + + assertThat(circuitBreaker.getUsed(), equalTo(expectedTracked)); + assertThat(actualBytes, equalBytes(bytes)); + actualBytes.close(); + assertThat(circuitBreaker.getUsed(), equalTo(0L)); + } + + public void testCircuitBreakerTripping() { + final var bytes = randomBytesReference(between(0, PageCacheRecycler.BYTE_PAGE_SIZE * 4)); + final var expectedTracked = getExpectedAllocation(bytes.length()); + final var circuitBreaker = new MockBigArrays.LimitedBreaker( + "test", + ByteSizeValue.ofBytes(randomLongBetween(0L, expectedTracked - 1L)) + ); + + expectThrows(CircuitBreakingException.class, () -> { + try (var out = new RecyclerBytesStreamOutput(recycler, circuitBreaker)) { + writeAllBytesRandomSlices(out, bytes); + } + }); + assertThat(circuitBreaker.getUsed(), equalTo(0L)); + } + + public void testCircuitBreakerReleaseOnRecyclerFailure() { + final var bytes = randomBytesReference(between(PageCacheRecycler.BYTE_PAGE_SIZE * 3 + 1, PageCacheRecycler.BYTE_PAGE_SIZE * 4)); + assertEquals(PageCacheRecycler.BYTE_PAGE_SIZE * 4L, getExpectedAllocation(bytes.length())); + final var circuitBreaker = new MockBigArrays.LimitedBreaker("test", ByteSizeValue.ofBytes(PageCacheRecycler.BYTE_PAGE_SIZE * 4)); + + final var failingRecycler = new Recycler() { + int pagesLeft = between(0, 3); + + @Override + public V obtain() { + if (pagesLeft == 0) { + throw new RuntimeException("simulated recycler failure"); } + pagesLeft -= 1; + return recycler.obtain(); + } + + @Override + public int pageSize() { + return recycler.pageSize(); + } + }; + + expectThrows(RuntimeException.class, () -> { + try (var out = new RecyclerBytesStreamOutput(failingRecycler, circuitBreaker)) { + writeAllBytesRandomSlices(out, bytes); + } + }); + assertThat(circuitBreaker.getUsed(), equalTo(0L)); + } + + private static void writeAllBytesRandomSlices(RecyclerBytesStreamOutput out, BytesReference bytes) throws IOException { + if (randomBoolean()) { + out.writeAllBytesFrom(bytes.streamInput()); + } else { + var remaining = bytes; + while (remaining.length() > 0) { + var thisSlice = remaining.slice(0, between(1, remaining.length())); + remaining = remaining.slice(thisSlice.length(), remaining.length() - thisSlice.length()); + out.writeAllBytesFrom(thisSlice.streamInput()); } - assertThat(out.bytes(), equalBytes(bytes)); } } + + private static long getExpectedAllocation(int length) { + final var expectedPages = Math.max(1, (length + PageCacheRecycler.BYTE_PAGE_SIZE - 1) / PageCacheRecycler.BYTE_PAGE_SIZE); + return expectedPages * PageCacheRecycler.BYTE_PAGE_SIZE; + } }