diff --git a/api/src/main/java/io/grpc/HasByteBuffer.java b/api/src/main/java/io/grpc/HasByteBuffer.java new file mode 100644 index 00000000000..35abca53e6c --- /dev/null +++ b/api/src/main/java/io/grpc/HasByteBuffer.java @@ -0,0 +1,52 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.nio.ByteBuffer; +import javax.annotation.Nullable; + +/** + * Extension to an {@link java.io.InputStream} whose content can be accessed as {@link + * ByteBuffer}s. + * + *

This can be used for optimizing the case for the consumer of a {@link ByteBuffer}-backed + * input stream supports efficient reading from {@link ByteBuffer}s directly. This turns the reader + * interface from an {@link java.io.InputStream} to {@link ByteBuffer}s, without copying the + * content to a byte array and read from it. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/7387") +public interface HasByteBuffer { + + /** + * Indicates whether or not {@link #getByteBuffer} operation is supported. + */ + boolean getByteBufferSupported(); + + /** + * Gets a {@link ByteBuffer} containing some bytes of the content next to be read, or {@code + * null} if has reached end of the content. The number of bytes contained in the returned buffer + * is implementation specific. Calling this method does not change the position of the input + * stream. The returned buffer's content should not be modified, but the position, limit, and + * mark may be changed. Operations for changing the position, limit, and mark of the returned + * buffer does not affect the position, limit, and mark of this input stream. This is an optional + * method, so callers should first check {@link #getByteBufferSupported}. + * + * @throws UnsupportedOperationException if this operation is not supported. + */ + @Nullable + ByteBuffer getByteBuffer(); +} diff --git a/core/src/main/java/io/grpc/internal/AbstractReadableBuffer.java b/core/src/main/java/io/grpc/internal/AbstractReadableBuffer.java index e43b7a7cc0e..99812c3ffff 100644 --- a/core/src/main/java/io/grpc/internal/AbstractReadableBuffer.java +++ b/core/src/main/java/io/grpc/internal/AbstractReadableBuffer.java @@ -16,6 +16,8 @@ package io.grpc.internal; +import java.nio.ByteBuffer; + /** * Abstract base class for {@link ReadableBuffer} implementations. */ @@ -45,6 +47,24 @@ public int arrayOffset() { throw new UnsupportedOperationException(); } + @Override + public void mark() {} + + @Override + public void reset() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean getByteBufferSupported() { + return false; + } + + @Override + public ByteBuffer getByteBuffer() { + throw new UnsupportedOperationException(); + } + @Override public void close() {} diff --git a/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java b/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java index 93dda7cdbc8..c1c2c78c5af 100644 --- a/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java +++ b/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java @@ -20,8 +20,10 @@ import java.io.OutputStream; import java.nio.Buffer; import java.nio.ByteBuffer; +import java.nio.InvalidMarkException; import java.util.ArrayDeque; -import java.util.Queue; +import java.util.Deque; +import javax.annotation.Nullable; /** * A {@link ReadableBuffer} that is composed of 0 or more {@link ReadableBuffer}s. This provides a @@ -34,7 +36,9 @@ public class CompositeReadableBuffer extends AbstractReadableBuffer { private int readableBytes; - private final Queue buffers = new ArrayDeque<>(); + private final Deque readableBuffers = new ArrayDeque<>(); + private final Deque rewindableBuffers = new ArrayDeque<>(); + private boolean marked; /** * Adds a new {@link ReadableBuffer} at the end of the buffer list. After a buffer is added, it is @@ -43,16 +47,24 @@ public class CompositeReadableBuffer extends AbstractReadableBuffer { * this {@code CompositeBuffer}. */ public void addBuffer(ReadableBuffer buffer) { + boolean markHead = marked && readableBuffers.isEmpty(); + enqueueBuffer(buffer); + if (markHead) { + readableBuffers.peek().mark(); + } + } + + private void enqueueBuffer(ReadableBuffer buffer) { if (!(buffer instanceof CompositeReadableBuffer)) { - buffers.add(buffer); + readableBuffers.add(buffer); readableBytes += buffer.readableBytes(); return; } CompositeReadableBuffer compositeBuffer = (CompositeReadableBuffer) buffer; - while (!compositeBuffer.buffers.isEmpty()) { - ReadableBuffer subBuffer = compositeBuffer.buffers.remove(); - buffers.add(subBuffer); + while (!compositeBuffer.readableBuffers.isEmpty()) { + ReadableBuffer subBuffer = compositeBuffer.readableBuffers.remove(); + readableBuffers.add(subBuffer); } readableBytes += compositeBuffer.readableBytes; compositeBuffer.readableBytes = 0; @@ -136,27 +148,73 @@ public int readInternal(ReadableBuffer buffer, int length) throws IOException { @Override public CompositeReadableBuffer readBytes(int length) { - checkReadable(length); - readableBytes -= length; - - CompositeReadableBuffer newBuffer = new CompositeReadableBuffer(); - while (length > 0) { - ReadableBuffer buffer = buffers.peek(); - if (buffer.readableBytes() > length) { + final CompositeReadableBuffer newBuffer = new CompositeReadableBuffer(); + execute(new ReadOperation() { + @Override + int readInternal(ReadableBuffer buffer, int length) { newBuffer.addBuffer(buffer.readBytes(length)); - length = 0; - } else { - newBuffer.addBuffer(buffers.poll()); - length -= buffer.readableBytes(); + return 0; } - } + }, length); return newBuffer; } + @Override + public void mark() { + while (!rewindableBuffers.isEmpty()) { + rewindableBuffers.remove().close(); + } + marked = true; + ReadableBuffer buffer = readableBuffers.peek(); + if (buffer != null) { + buffer.mark(); + } + } + + @Override + public void reset() { + if (!marked) { + throw new InvalidMarkException(); + } + ReadableBuffer buffer; + if ((buffer = readableBuffers.peek()) != null) { + int currentRemain = buffer.readableBytes(); + buffer.reset(); + readableBytes += (buffer.readableBytes() - currentRemain); + } + while ((buffer = rewindableBuffers.pollLast()) != null) { + buffer.reset(); + readableBuffers.addFirst(buffer); + readableBytes += buffer.readableBytes(); + } + } + + @Override + public boolean getByteBufferSupported() { + for (ReadableBuffer buffer : readableBuffers) { + if (!buffer.getByteBufferSupported()) { + return false; + } + } + return true; + } + + @Nullable + @Override + public ByteBuffer getByteBuffer() { + if (readableBuffers.isEmpty()) { + return null; + } + return readableBuffers.peek().getByteBuffer(); + } + @Override public void close() { - while (!buffers.isEmpty()) { - buffers.remove().close(); + while (!readableBuffers.isEmpty()) { + readableBuffers.remove().close(); + } + while (!rewindableBuffers.isEmpty()) { + rewindableBuffers.remove().close(); } } @@ -167,12 +225,12 @@ public void close() { private void execute(ReadOperation op, int length) { checkReadable(length); - if (!buffers.isEmpty()) { + if (!readableBuffers.isEmpty()) { advanceBufferIfNecessary(); } - for (; length > 0 && !buffers.isEmpty(); advanceBufferIfNecessary()) { - ReadableBuffer buffer = buffers.peek(); + for (; length > 0 && !readableBuffers.isEmpty(); advanceBufferIfNecessary()) { + ReadableBuffer buffer = readableBuffers.peek(); int lengthToCopy = Math.min(length, buffer.readableBytes()); // Perform the read operation for this buffer. @@ -195,9 +253,17 @@ private void execute(ReadOperation op, int length) { * If the current buffer is exhausted, removes and closes it. */ private void advanceBufferIfNecessary() { - ReadableBuffer buffer = buffers.peek(); + ReadableBuffer buffer = readableBuffers.peek(); if (buffer.readableBytes() == 0) { - buffers.remove().close(); + if (marked) { + rewindableBuffers.add(readableBuffers.remove()); + ReadableBuffer next = readableBuffers.peek(); + if (next != null) { + next.mark(); + } + } else { + readableBuffers.remove().close(); + } } } diff --git a/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java b/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java index 954d0ac5486..0103f9fb39f 100644 --- a/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java +++ b/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.io.OutputStream; import java.nio.ByteBuffer; +import javax.annotation.Nullable; /** * Base class for a wrapper around another {@link ReadableBuffer}. @@ -96,6 +97,27 @@ public int arrayOffset() { return buf.arrayOffset(); } + @Override + public void mark() { + buf.mark(); + } + + @Override + public void reset() { + buf.reset(); + } + + @Override + public boolean getByteBufferSupported() { + return buf.getByteBufferSupported(); + } + + @Nullable + @Override + public ByteBuffer getByteBuffer() { + return buf.getByteBuffer(); + } + @Override public void close() { buf.close(); diff --git a/core/src/main/java/io/grpc/internal/ReadableBuffer.java b/core/src/main/java/io/grpc/internal/ReadableBuffer.java index 7d2ca7ebba5..684039292aa 100644 --- a/core/src/main/java/io/grpc/internal/ReadableBuffer.java +++ b/core/src/main/java/io/grpc/internal/ReadableBuffer.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.io.OutputStream; import java.nio.ByteBuffer; +import javax.annotation.Nullable; /** * Interface for an abstract byte buffer. Buffers are intended to be a read-only, except for the @@ -123,6 +124,39 @@ public interface ReadableBuffer extends Closeable { */ int arrayOffset(); + /** + * Marks the current position in this buffer. A subsequent call to the {@link #reset} method + * repositions this stream at the last marked position so that subsequent reads re-read the same + * bytes. + */ + void mark(); + + /** + * Repositions this buffer to the position at the time {@link #mark} was last called on this + * buffer. + */ + void reset(); + + /** + * Indicates whether or not {@link #getByteBuffer} operation is supported for this buffer. + */ + boolean getByteBufferSupported(); + + /** + * Gets a {@link ByteBuffer} that contains some bytes of the content next to be read, or {@code + * null} if this buffer has been exhausted. The number of bytes contained in the returned buffer + * is implementation specific. The position of this buffer is unchanged after calling this + * method. The returned buffer's content should not be modified, but the position, limit, and + * mark may be changed. Operations for changing the position, limit, and mark of the returned + * buffer does not affect the position, limit, and mark of this buffer. Buffers returned by this + * method have independent position, limit and mark. This is an optional method, so callers + * should first check {@link #getByteBufferSupported}. + * + * @throws UnsupportedOperationException the buffer does not support this method. + */ + @Nullable + ByteBuffer getByteBuffer(); + /** * Closes this buffer and releases any resources. */ diff --git a/core/src/main/java/io/grpc/internal/ReadableBuffers.java b/core/src/main/java/io/grpc/internal/ReadableBuffers.java index cfe5542a573..f764d893fda 100644 --- a/core/src/main/java/io/grpc/internal/ReadableBuffers.java +++ b/core/src/main/java/io/grpc/internal/ReadableBuffers.java @@ -19,6 +19,7 @@ import static com.google.common.base.Charsets.UTF_8; import com.google.common.base.Preconditions; +import io.grpc.HasByteBuffer; import io.grpc.KnownLength; import java.io.IOException; import java.io.InputStream; @@ -26,6 +27,7 @@ import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.charset.Charset; +import javax.annotation.Nullable; /** * Utility methods for creating {@link ReadableBuffer} instances. @@ -128,6 +130,7 @@ private static class ByteArrayWrapper extends AbstractReadableBuffer { int offset; final int end; final byte[] bytes; + int mark = -1; ByteArrayWrapper(byte[] bytes) { this(bytes, 0, bytes.length); @@ -204,6 +207,16 @@ public byte[] array() { public int arrayOffset() { return offset; } + + @Override + public void mark() { + mark = offset; + } + + @Override + public void reset() { + offset = mark; + } } /** @@ -291,12 +304,33 @@ public byte[] array() { public int arrayOffset() { return bytes.arrayOffset() + bytes.position(); } + + @Override + public void mark() { + bytes.mark(); + } + + @Override + public void reset() { + bytes.reset(); + } + + @Override + public boolean getByteBufferSupported() { + return true; + } + + @Override + public ByteBuffer getByteBuffer() { + return bytes.slice(); + } } /** * An {@link InputStream} that is backed by a {@link ReadableBuffer}. */ - private static final class BufferInputStream extends InputStream implements KnownLength { + private static final class BufferInputStream extends InputStream + implements KnownLength, HasByteBuffer { final ReadableBuffer buffer; public BufferInputStream(ReadableBuffer buffer) { @@ -329,6 +363,39 @@ public int read(byte[] dest, int destOffset, int length) throws IOException { return length; } + @Override + public long skip(long n) throws IOException { + int length = (int) Math.min(buffer.readableBytes(), n); + buffer.skipBytes(length); + return length; + } + + @Override + public void mark(int readlimit) { + buffer.mark(); + } + + @Override + public void reset() throws IOException { + buffer.reset(); + } + + @Override + public boolean markSupported() { + return true; + } + + @Override + public boolean getByteBufferSupported() { + return buffer.getByteBufferSupported(); + } + + @Nullable + @Override + public ByteBuffer getByteBuffer() { + return buffer.getByteBuffer(); + } + @Override public void close() throws IOException { buffer.close(); diff --git a/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java b/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java index 660aa116317..f26c8b1ef9a 100644 --- a/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java +++ b/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java @@ -17,14 +17,20 @@ package io.grpc.internal; import static com.google.common.base.Charsets.UTF_8; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.Buffer; import java.nio.ByteBuffer; +import java.nio.InvalidMarkException; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -154,6 +160,128 @@ public void readStreamShouldSucceed() throws IOException { assertEquals(EXPECTED_VALUE, new String(bos.toByteArray(), UTF_8)); } + @Test + public void resetUnmarkedShouldThrow() { + try { + composite.reset(); + fail(); + } catch (InvalidMarkException expected) { + } + } + + @Test + public void markAndResetWithSkipBytesShouldSucceed() { + composite.mark(); + composite.skipBytes(EXPECTED_VALUE.length() / 2); + composite.reset(); + assertEquals(EXPECTED_VALUE.length(), composite.readableBytes()); + } + + @Test + public void markAndResetWithReadUnsignedByteShouldSucceed() { + composite.readUnsignedByte(); + composite.mark(); + int b = composite.readUnsignedByte(); + composite.reset(); + assertEquals(EXPECTED_VALUE.length() - 1, composite.readableBytes()); + assertEquals(b, composite.readUnsignedByte()); + } + + @Test + public void markAndResetWithReadByteArrayShouldSucceed() { + composite.mark(); + byte[] first = new byte[EXPECTED_VALUE.length()]; + composite.readBytes(first, 0, EXPECTED_VALUE.length()); + composite.reset(); + assertEquals(EXPECTED_VALUE.length(), composite.readableBytes()); + byte[] second = new byte[EXPECTED_VALUE.length()]; + composite.readBytes(second, 0, EXPECTED_VALUE.length()); + assertArrayEquals(first, second); + } + + @Test + public void markAndResetWithReadByteBufferShouldSucceed() { + byte[] first = new byte[EXPECTED_VALUE.length()]; + composite.mark(); + composite.readBytes(ByteBuffer.wrap(first)); + composite.reset(); + byte[] second = new byte[EXPECTED_VALUE.length()]; + assertEquals(EXPECTED_VALUE.length(), composite.readableBytes()); + composite.readBytes(ByteBuffer.wrap(second)); + assertArrayEquals(first, second); + } + + @Test + public void markAndResetWithReadStreamShouldSucceed() throws IOException { + ByteArrayOutputStream first = new ByteArrayOutputStream(); + composite.mark(); + composite.readBytes(first, EXPECTED_VALUE.length() / 2); + composite.reset(); + assertEquals(EXPECTED_VALUE.length(), composite.readableBytes()); + ByteArrayOutputStream second = new ByteArrayOutputStream(); + composite.readBytes(second, EXPECTED_VALUE.length() / 2); + assertArrayEquals(first.toByteArray(), second.toByteArray()); + } + + @Test + public void markAndResetWithReadReadableBufferShouldSucceed() { + composite.readBytes(EXPECTED_VALUE.length() / 2); + int remaining = composite.readableBytes(); + composite.mark(); + ReadableBuffer first = composite.readBytes(1); + composite.reset(); + assertEquals(remaining, composite.readableBytes()); + ReadableBuffer second = composite.readBytes(1); + assertEquals(first.readUnsignedByte(), second.readUnsignedByte()); + } + + @Test + public void markAgainShouldOverwritePreviousMark() { + composite.mark(); + composite.skipBytes(EXPECTED_VALUE.length() / 2); + int remaining = composite.readableBytes(); + composite.mark(); + composite.skipBytes(1); + composite.reset(); + assertEquals(remaining, composite.readableBytes()); + } + + @Test + public void bufferAddedAfterMarkedShouldBeIncluded() { + composite = new CompositeReadableBuffer(); + composite.mark(); + splitAndAdd(EXPECTED_VALUE); + composite.skipBytes(EXPECTED_VALUE.length() / 2); + composite.reset(); + assertEquals(EXPECTED_VALUE.length(), composite.readableBytes()); + } + + @Test + public void canUseByteBufferOnlyAllComponentsSupportUsingByteBuffer() { + composite = new CompositeReadableBuffer(); + ReadableBuffer buffer1 = mock(ReadableBuffer.class); + ReadableBuffer buffer2 = mock(ReadableBuffer.class); + ReadableBuffer buffer3 = mock(ReadableBuffer.class); + when(buffer1.getByteBufferSupported()).thenReturn(true); + when(buffer2.getByteBufferSupported()).thenReturn(true); + when(buffer3.getByteBufferSupported()).thenReturn(false); + composite.addBuffer(buffer1); + assertTrue(composite.getByteBufferSupported()); + composite.addBuffer(buffer2); + assertTrue(composite.getByteBufferSupported()); + composite.addBuffer(buffer3); + assertFalse(composite.getByteBufferSupported()); + } + + @Test + public void getByteBufferDelegatesToComponents() { + composite = new CompositeReadableBuffer(); + ReadableBuffer buffer = mock(ReadableBuffer.class); + composite.addBuffer(buffer); + composite.getByteBuffer(); + verify(buffer).getByteBuffer(); + } + @Test public void closeShouldCloseBuffers() { composite = new CompositeReadableBuffer(); diff --git a/core/src/test/java/io/grpc/internal/ReadableBufferTestBase.java b/core/src/test/java/io/grpc/internal/ReadableBufferTestBase.java index e469b807d51..86b8e2a399b 100644 --- a/core/src/test/java/io/grpc/internal/ReadableBufferTestBase.java +++ b/core/src/test/java/io/grpc/internal/ReadableBufferTestBase.java @@ -24,6 +24,7 @@ import java.nio.Buffer; import java.nio.ByteBuffer; import java.util.Arrays; +import org.junit.Assume; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -117,6 +118,58 @@ public void partialReadToReadableBufferShouldSucceed() { assertArrayEquals(new byte[] {'h', 'e'}, Arrays.copyOfRange(array, 0, 2)); } + @Test + public void markAndResetWithReadShouldSucceed() { + ReadableBuffer buffer = buffer(); + int offset = 5; + buffer.readBytes(new byte[offset], 0, offset); + buffer.mark(); + int b = buffer.readUnsignedByte(); + assertEquals(msg.length() - offset - 1, buffer.readableBytes()); + buffer.reset(); + assertEquals(msg.length() - offset, buffer.readableBytes()); + assertEquals(b, buffer.readUnsignedByte()); + } + + @Test + public void markAndResetWithReadToReadableBufferShouldSucceed() { + ReadableBuffer buffer = buffer(); + int offset = 5; + buffer.readBytes(offset); + int testLen = 100; + buffer.mark(); + ReadableBuffer first = buffer.readBytes(testLen); + assertEquals(msg.length() - offset - testLen, buffer.readableBytes()); + buffer.reset(); + assertEquals(msg.length() - offset, buffer.readableBytes()); + ReadableBuffer second = buffer.readBytes(testLen); + byte[] array1 = new byte[testLen]; + byte[] array2 = new byte[testLen]; + first.readBytes(array1, 0, testLen); + second.readBytes(array2, 0, testLen); + assertArrayEquals(array1, array2); + } + + @Test + public void getByteBufferDoesNotAffectBufferPosition() { + ReadableBuffer buffer = buffer(); + Assume.assumeTrue(buffer.getByteBufferSupported()); + ByteBuffer byteBuffer = buffer.getByteBuffer(); + assertEquals(msg.length(), buffer.readableBytes()); + byteBuffer.get(new byte[byteBuffer.remaining()]); + assertEquals(msg.length(), buffer.readableBytes()); + } + + @Test + public void getByteBufferIsNotAffectedByBufferRead() { + ReadableBuffer buffer = buffer(); + Assume.assumeTrue(buffer.getByteBufferSupported()); + ByteBuffer byteBuffer = buffer.getByteBuffer(); + int initialRemaining = byteBuffer.remaining(); + buffer.readBytes(new byte[100], 0, 100); + assertEquals(initialRemaining, byteBuffer.remaining()); + } + protected abstract ReadableBuffer buffer(); private static String repeatUntilLength(String toRepeat, int length) { diff --git a/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java b/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java index ea9daeed6a3..3007c19682a 100644 --- a/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java +++ b/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java @@ -19,11 +19,15 @@ import static com.google.common.base.Charsets.UTF_8; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import io.grpc.HasByteBuffer; +import java.io.IOException; import java.io.InputStream; import org.junit.Test; import org.junit.runner.RunWith; @@ -128,4 +132,30 @@ public void bufferInputStream_close_closesBuffer() throws Exception { inputStream.close(); verify(buffer, times(1)).close(); } + + @Test + public void bufferInputStream_markAndReset() throws IOException { + ReadableBuffer buffer = ReadableBuffers.wrap(MSG_BYTES); + InputStream inputStream = ReadableBuffers.openStream(buffer, true); + assertTrue(inputStream.markSupported()); + inputStream.mark(2); + byte[] first = new byte[5]; + inputStream.read(first); + assertEquals(0, inputStream.available()); + inputStream.reset(); + assertEquals(5, inputStream.available()); + byte[] second = new byte[5]; + inputStream.read(second); + assertArrayEquals(first, second); + } + + @Test + public void bufferInputStream_getByteBufferDelegatesToBuffer() { + ReadableBuffer buffer = mock(ReadableBuffer.class); + when(buffer.getByteBufferSupported()).thenReturn(true); + InputStream inputStream = ReadableBuffers.openStream(buffer, true); + assertTrue(((HasByteBuffer) inputStream).getByteBufferSupported()); + ((HasByteBuffer) inputStream).getByteBuffer(); + verify(buffer).getByteBuffer(); + } } diff --git a/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java b/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java index 37caccb0eb3..75d70978610 100644 --- a/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java +++ b/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java @@ -94,6 +94,26 @@ public int arrayOffset() { return buffer.arrayOffset() + buffer.readerIndex(); } + @Override + public void mark() { + buffer.markReaderIndex(); + } + + @Override + public void reset() { + buffer.resetReaderIndex(); + } + + @Override + public boolean getByteBufferSupported() { + return buffer.nioBufferCount() > 0; + } + + @Override + public ByteBuffer getByteBuffer() { + return buffer.nioBufferCount() == 1 ? buffer.nioBuffer() : buffer.nioBuffers()[0]; + } + /** * If the first call to close, calls {@link ByteBuf#release} to release the internal Netty buffer. */ diff --git a/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java b/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java index 8090e601911..23c2f0632a9 100644 --- a/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java @@ -17,11 +17,16 @@ package io.grpc.netty; import static com.google.common.base.Charsets.UTF_8; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import com.google.common.base.Splitter; import io.grpc.internal.ReadableBuffer; import io.grpc.internal.ReadableBufferTestBase; +import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; +import java.nio.ByteBuffer; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -52,6 +57,29 @@ public void closeMultipleTimesShouldReleaseBufferOnce() { assertEquals(0, buffer.buffer().refCnt()); } + @Test + public void getByteBufferFromSingleNioBufferBackedBuffer() { + assertTrue(buffer.getByteBufferSupported()); + ByteBuffer byteBuffer = buffer.getByteBuffer(); + byte[] arr = new byte[byteBuffer.remaining()]; + byteBuffer.get(arr); + assertArrayEquals(msg.getBytes(UTF_8), arr); + } + + @Test + public void getByteBufferFromCompositeBufferReturnsOnlyFirstComponent() { + CompositeByteBuf composite = Unpooled.compositeBuffer(10); + int chunks = 4; + int chunkLen = msg.length() / chunks; + for (String chunk : Splitter.fixedLength(chunkLen).split(msg)) { + composite.addComponent(true, Unpooled.copiedBuffer(chunk.getBytes(UTF_8))); + } + buffer = new NettyReadableBuffer(composite); + byte[] array = new byte[chunkLen]; + buffer.getByteBuffer().get(array); + assertArrayEquals(msg.substring(0, chunkLen).getBytes(UTF_8), array); + } + @Override protected ReadableBuffer buffer() { return buffer; diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java index 2ece98ffb97..4aeeae2fa8b 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java @@ -56,6 +56,18 @@ public void partialReadToByteBufferShouldSucceed() { // Not supported. } + @Override + @Test + public void markAndResetWithReadShouldSucceed() { + // Not supported. + } + + @Override + @Test + public void markAndResetWithReadToReadableBufferShouldSucceed() { + // Not supported. + } + @Override protected ReadableBuffer buffer() { return buffer; diff --git a/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java b/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java index ddba5b8d5b1..47a790f44fc 100644 --- a/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java +++ b/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java @@ -25,6 +25,7 @@ import com.google.protobuf.MessageLite; import com.google.protobuf.Parser; import io.grpc.ExperimentalApi; +import io.grpc.HasByteBuffer; import io.grpc.KnownLength; import io.grpc.Metadata; import io.grpc.MethodDescriptor.Marshaller; @@ -35,6 +36,9 @@ import java.io.OutputStream; import java.lang.ref.Reference; import java.lang.ref.WeakReference; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; /** * Utility methods for using protobuf with grpc. @@ -48,12 +52,34 @@ public final class ProtoLiteUtils { private static final int BUF_SIZE = 8192; + /** + * Assume Java 9+ if it isn't Java 7 or Java 8. + */ + @VisibleForTesting + static final boolean IS_JAVA9_OR_HIGHER; + /** * The same value as {@link io.grpc.internal.GrpcUtil#DEFAULT_MAX_MESSAGE_SIZE}. */ @VisibleForTesting static final int DEFAULT_MAX_MESSAGE_SIZE = 4 * 1024 * 1024; + /** + * Threshold for passing {@link ByteBuffer}s directly into Protobuf. + */ + @VisibleForTesting + static final int MESSAGE_ZERO_COPY_THRESHOLD = 64 * 1024; + + static { + boolean isJava9OrHigher = true; + try { + Class.forName("java.lang.StackWalker"); + } catch (ClassNotFoundException e) { + isJava9OrHigher = false; + } + IS_JAVA9_OR_HIGHER = isJava9OrHigher; + } + /** * Sets the global registry for proto marshalling shared across all servers and clients. * @@ -173,7 +199,23 @@ public T parse(InputStream stream) { try { if (stream instanceof KnownLength) { int size = stream.available(); - if (size > 0 && size <= DEFAULT_MAX_MESSAGE_SIZE) { + if (size == 0) { + return defaultInstance; + } + if (IS_JAVA9_OR_HIGHER + && size >= MESSAGE_ZERO_COPY_THRESHOLD + && stream instanceof HasByteBuffer + && ((HasByteBuffer) stream).getByteBufferSupported() + && stream.markSupported()) { + List buffers = new ArrayList<>(); + stream.mark(size); + while (stream.available() != 0) { + ByteBuffer buffer = ((HasByteBuffer) stream).getByteBuffer(); + stream.skip(buffer.remaining()); + buffers.add(buffer); + } + cis = CodedInputStream.newInstance(buffers); + } else if (size > 0 && size <= DEFAULT_MAX_MESSAGE_SIZE) { Reference ref; // buf should not be used after this method has returned. byte[] buf; @@ -197,8 +239,6 @@ public T parse(InputStream stream) { throw new RuntimeException("size inaccurate: " + size + " != " + position); } cis = CodedInputStream.newInstance(buf, 0, size); - } else if (size == 0) { - return defaultInstance; } } } catch (IOException e) { diff --git a/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java b/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java index d05e884105e..1d5c1abc806 100644 --- a/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java +++ b/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertSame; import static org.junit.Assert.fail; +import com.google.common.base.Strings; import com.google.common.io.ByteStreams; import com.google.protobuf.ByteString; import com.google.protobuf.Empty; @@ -29,6 +30,7 @@ import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Type; import io.grpc.Drainable; +import io.grpc.HasByteBuffer; import io.grpc.KnownLength; import io.grpc.Metadata; import io.grpc.MethodDescriptor.Marshaller; @@ -40,7 +42,10 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; +import java.nio.ByteBuffer; import java.util.Arrays; +import javax.annotation.Nullable; +import org.junit.Assume; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -226,6 +231,20 @@ public void parseFromKnowLengthInputStream() throws Exception { assertEquals(expect, result); } + @Test + public void parseFromKnownLengthByteBufferInputStream() { + Assume.assumeTrue(ProtoLiteUtils.IS_JAVA9_OR_HIGHER); + Marshaller marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance()); + Type expect = + Type.newBuilder() + .setName(Strings.repeat("M", ProtoLiteUtils.MESSAGE_ZERO_COPY_THRESHOLD)) + .build(); + + Type result = marshaller.parse( + new CustomKnownLengthByteBufferInputStream(expect.toByteString().asReadOnlyByteBuffer())); + assertEquals(expect, result); + } + @Test public void defaultMaxMessageSize() { assertEquals(GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE, ProtoLiteUtils.DEFAULT_MAX_MESSAGE_SIZE); @@ -253,4 +272,55 @@ public int read() throws IOException { return source[position++]; } } + + private static final class CustomKnownLengthByteBufferInputStream extends InputStream + implements KnownLength, HasByteBuffer { + private ByteBuffer source; + + private CustomKnownLengthByteBufferInputStream(ByteBuffer source) { + this.source = source; + } + + @Override + public int available() throws IOException { + return source.remaining(); + } + + @Override + public synchronized void mark(int readlimit) { + source.mark(); + } + + @Override + public synchronized void reset() throws IOException { + source.reset(); + } + + @Override + public boolean markSupported() { + return true; + } + + @Override + public int read() throws IOException { + throw new UnsupportedOperationException("should not be called"); + } + + @Override + public long skip(long n) throws IOException { + source.position((int) (source.position() + n)); + return n; + } + + @Override + public boolean getByteBufferSupported() { + return true; + } + + @Nullable + @Override + public ByteBuffer getByteBuffer() { + return source.slice(); + } + } }