diff --git a/api/src/main/java/io/grpc/HasByteBuffer.java b/api/src/main/java/io/grpc/HasByteBuffer.java
new file mode 100644
index 00000000000..35abca53e6c
--- /dev/null
+++ b/api/src/main/java/io/grpc/HasByteBuffer.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2020 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc;
+
+import java.nio.ByteBuffer;
+import javax.annotation.Nullable;
+
+/**
+ * Extension to an {@link java.io.InputStream} whose content can be accessed as {@link
+ * ByteBuffer}s.
+ *
+ *
This can be used for optimizing the case for the consumer of a {@link ByteBuffer}-backed
+ * input stream supports efficient reading from {@link ByteBuffer}s directly. This turns the reader
+ * interface from an {@link java.io.InputStream} to {@link ByteBuffer}s, without copying the
+ * content to a byte array and read from it.
+ */
+@ExperimentalApi("https://github.com/grpc/grpc-java/issues/7387")
+public interface HasByteBuffer {
+
+ /**
+ * Indicates whether or not {@link #getByteBuffer} operation is supported.
+ */
+ boolean getByteBufferSupported();
+
+ /**
+ * Gets a {@link ByteBuffer} containing some bytes of the content next to be read, or {@code
+ * null} if has reached end of the content. The number of bytes contained in the returned buffer
+ * is implementation specific. Calling this method does not change the position of the input
+ * stream. The returned buffer's content should not be modified, but the position, limit, and
+ * mark may be changed. Operations for changing the position, limit, and mark of the returned
+ * buffer does not affect the position, limit, and mark of this input stream. This is an optional
+ * method, so callers should first check {@link #getByteBufferSupported}.
+ *
+ * @throws UnsupportedOperationException if this operation is not supported.
+ */
+ @Nullable
+ ByteBuffer getByteBuffer();
+}
diff --git a/core/src/main/java/io/grpc/internal/AbstractReadableBuffer.java b/core/src/main/java/io/grpc/internal/AbstractReadableBuffer.java
index e43b7a7cc0e..99812c3ffff 100644
--- a/core/src/main/java/io/grpc/internal/AbstractReadableBuffer.java
+++ b/core/src/main/java/io/grpc/internal/AbstractReadableBuffer.java
@@ -16,6 +16,8 @@
package io.grpc.internal;
+import java.nio.ByteBuffer;
+
/**
* Abstract base class for {@link ReadableBuffer} implementations.
*/
@@ -45,6 +47,24 @@ public int arrayOffset() {
throw new UnsupportedOperationException();
}
+ @Override
+ public void mark() {}
+
+ @Override
+ public void reset() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean getByteBufferSupported() {
+ return false;
+ }
+
+ @Override
+ public ByteBuffer getByteBuffer() {
+ throw new UnsupportedOperationException();
+ }
+
@Override
public void close() {}
diff --git a/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java b/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java
index 93dda7cdbc8..c1c2c78c5af 100644
--- a/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java
+++ b/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java
@@ -20,8 +20,10 @@
import java.io.OutputStream;
import java.nio.Buffer;
import java.nio.ByteBuffer;
+import java.nio.InvalidMarkException;
import java.util.ArrayDeque;
-import java.util.Queue;
+import java.util.Deque;
+import javax.annotation.Nullable;
/**
* A {@link ReadableBuffer} that is composed of 0 or more {@link ReadableBuffer}s. This provides a
@@ -34,7 +36,9 @@
public class CompositeReadableBuffer extends AbstractReadableBuffer {
private int readableBytes;
- private final Queue buffers = new ArrayDeque<>();
+ private final Deque readableBuffers = new ArrayDeque<>();
+ private final Deque rewindableBuffers = new ArrayDeque<>();
+ private boolean marked;
/**
* Adds a new {@link ReadableBuffer} at the end of the buffer list. After a buffer is added, it is
@@ -43,16 +47,24 @@ public class CompositeReadableBuffer extends AbstractReadableBuffer {
* this {@code CompositeBuffer}.
*/
public void addBuffer(ReadableBuffer buffer) {
+ boolean markHead = marked && readableBuffers.isEmpty();
+ enqueueBuffer(buffer);
+ if (markHead) {
+ readableBuffers.peek().mark();
+ }
+ }
+
+ private void enqueueBuffer(ReadableBuffer buffer) {
if (!(buffer instanceof CompositeReadableBuffer)) {
- buffers.add(buffer);
+ readableBuffers.add(buffer);
readableBytes += buffer.readableBytes();
return;
}
CompositeReadableBuffer compositeBuffer = (CompositeReadableBuffer) buffer;
- while (!compositeBuffer.buffers.isEmpty()) {
- ReadableBuffer subBuffer = compositeBuffer.buffers.remove();
- buffers.add(subBuffer);
+ while (!compositeBuffer.readableBuffers.isEmpty()) {
+ ReadableBuffer subBuffer = compositeBuffer.readableBuffers.remove();
+ readableBuffers.add(subBuffer);
}
readableBytes += compositeBuffer.readableBytes;
compositeBuffer.readableBytes = 0;
@@ -136,27 +148,73 @@ public int readInternal(ReadableBuffer buffer, int length) throws IOException {
@Override
public CompositeReadableBuffer readBytes(int length) {
- checkReadable(length);
- readableBytes -= length;
-
- CompositeReadableBuffer newBuffer = new CompositeReadableBuffer();
- while (length > 0) {
- ReadableBuffer buffer = buffers.peek();
- if (buffer.readableBytes() > length) {
+ final CompositeReadableBuffer newBuffer = new CompositeReadableBuffer();
+ execute(new ReadOperation() {
+ @Override
+ int readInternal(ReadableBuffer buffer, int length) {
newBuffer.addBuffer(buffer.readBytes(length));
- length = 0;
- } else {
- newBuffer.addBuffer(buffers.poll());
- length -= buffer.readableBytes();
+ return 0;
}
- }
+ }, length);
return newBuffer;
}
+ @Override
+ public void mark() {
+ while (!rewindableBuffers.isEmpty()) {
+ rewindableBuffers.remove().close();
+ }
+ marked = true;
+ ReadableBuffer buffer = readableBuffers.peek();
+ if (buffer != null) {
+ buffer.mark();
+ }
+ }
+
+ @Override
+ public void reset() {
+ if (!marked) {
+ throw new InvalidMarkException();
+ }
+ ReadableBuffer buffer;
+ if ((buffer = readableBuffers.peek()) != null) {
+ int currentRemain = buffer.readableBytes();
+ buffer.reset();
+ readableBytes += (buffer.readableBytes() - currentRemain);
+ }
+ while ((buffer = rewindableBuffers.pollLast()) != null) {
+ buffer.reset();
+ readableBuffers.addFirst(buffer);
+ readableBytes += buffer.readableBytes();
+ }
+ }
+
+ @Override
+ public boolean getByteBufferSupported() {
+ for (ReadableBuffer buffer : readableBuffers) {
+ if (!buffer.getByteBufferSupported()) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ @Nullable
+ @Override
+ public ByteBuffer getByteBuffer() {
+ if (readableBuffers.isEmpty()) {
+ return null;
+ }
+ return readableBuffers.peek().getByteBuffer();
+ }
+
@Override
public void close() {
- while (!buffers.isEmpty()) {
- buffers.remove().close();
+ while (!readableBuffers.isEmpty()) {
+ readableBuffers.remove().close();
+ }
+ while (!rewindableBuffers.isEmpty()) {
+ rewindableBuffers.remove().close();
}
}
@@ -167,12 +225,12 @@ public void close() {
private void execute(ReadOperation op, int length) {
checkReadable(length);
- if (!buffers.isEmpty()) {
+ if (!readableBuffers.isEmpty()) {
advanceBufferIfNecessary();
}
- for (; length > 0 && !buffers.isEmpty(); advanceBufferIfNecessary()) {
- ReadableBuffer buffer = buffers.peek();
+ for (; length > 0 && !readableBuffers.isEmpty(); advanceBufferIfNecessary()) {
+ ReadableBuffer buffer = readableBuffers.peek();
int lengthToCopy = Math.min(length, buffer.readableBytes());
// Perform the read operation for this buffer.
@@ -195,9 +253,17 @@ private void execute(ReadOperation op, int length) {
* If the current buffer is exhausted, removes and closes it.
*/
private void advanceBufferIfNecessary() {
- ReadableBuffer buffer = buffers.peek();
+ ReadableBuffer buffer = readableBuffers.peek();
if (buffer.readableBytes() == 0) {
- buffers.remove().close();
+ if (marked) {
+ rewindableBuffers.add(readableBuffers.remove());
+ ReadableBuffer next = readableBuffers.peek();
+ if (next != null) {
+ next.mark();
+ }
+ } else {
+ readableBuffers.remove().close();
+ }
}
}
diff --git a/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java b/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java
index 954d0ac5486..0103f9fb39f 100644
--- a/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java
+++ b/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java
@@ -21,6 +21,7 @@
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
+import javax.annotation.Nullable;
/**
* Base class for a wrapper around another {@link ReadableBuffer}.
@@ -96,6 +97,27 @@ public int arrayOffset() {
return buf.arrayOffset();
}
+ @Override
+ public void mark() {
+ buf.mark();
+ }
+
+ @Override
+ public void reset() {
+ buf.reset();
+ }
+
+ @Override
+ public boolean getByteBufferSupported() {
+ return buf.getByteBufferSupported();
+ }
+
+ @Nullable
+ @Override
+ public ByteBuffer getByteBuffer() {
+ return buf.getByteBuffer();
+ }
+
@Override
public void close() {
buf.close();
diff --git a/core/src/main/java/io/grpc/internal/ReadableBuffer.java b/core/src/main/java/io/grpc/internal/ReadableBuffer.java
index 7d2ca7ebba5..684039292aa 100644
--- a/core/src/main/java/io/grpc/internal/ReadableBuffer.java
+++ b/core/src/main/java/io/grpc/internal/ReadableBuffer.java
@@ -20,6 +20,7 @@
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
+import javax.annotation.Nullable;
/**
* Interface for an abstract byte buffer. Buffers are intended to be a read-only, except for the
@@ -123,6 +124,39 @@ public interface ReadableBuffer extends Closeable {
*/
int arrayOffset();
+ /**
+ * Marks the current position in this buffer. A subsequent call to the {@link #reset} method
+ * repositions this stream at the last marked position so that subsequent reads re-read the same
+ * bytes.
+ */
+ void mark();
+
+ /**
+ * Repositions this buffer to the position at the time {@link #mark} was last called on this
+ * buffer.
+ */
+ void reset();
+
+ /**
+ * Indicates whether or not {@link #getByteBuffer} operation is supported for this buffer.
+ */
+ boolean getByteBufferSupported();
+
+ /**
+ * Gets a {@link ByteBuffer} that contains some bytes of the content next to be read, or {@code
+ * null} if this buffer has been exhausted. The number of bytes contained in the returned buffer
+ * is implementation specific. The position of this buffer is unchanged after calling this
+ * method. The returned buffer's content should not be modified, but the position, limit, and
+ * mark may be changed. Operations for changing the position, limit, and mark of the returned
+ * buffer does not affect the position, limit, and mark of this buffer. Buffers returned by this
+ * method have independent position, limit and mark. This is an optional method, so callers
+ * should first check {@link #getByteBufferSupported}.
+ *
+ * @throws UnsupportedOperationException the buffer does not support this method.
+ */
+ @Nullable
+ ByteBuffer getByteBuffer();
+
/**
* Closes this buffer and releases any resources.
*/
diff --git a/core/src/main/java/io/grpc/internal/ReadableBuffers.java b/core/src/main/java/io/grpc/internal/ReadableBuffers.java
index cfe5542a573..f764d893fda 100644
--- a/core/src/main/java/io/grpc/internal/ReadableBuffers.java
+++ b/core/src/main/java/io/grpc/internal/ReadableBuffers.java
@@ -19,6 +19,7 @@
import static com.google.common.base.Charsets.UTF_8;
import com.google.common.base.Preconditions;
+import io.grpc.HasByteBuffer;
import io.grpc.KnownLength;
import java.io.IOException;
import java.io.InputStream;
@@ -26,6 +27,7 @@
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
+import javax.annotation.Nullable;
/**
* Utility methods for creating {@link ReadableBuffer} instances.
@@ -128,6 +130,7 @@ private static class ByteArrayWrapper extends AbstractReadableBuffer {
int offset;
final int end;
final byte[] bytes;
+ int mark = -1;
ByteArrayWrapper(byte[] bytes) {
this(bytes, 0, bytes.length);
@@ -204,6 +207,16 @@ public byte[] array() {
public int arrayOffset() {
return offset;
}
+
+ @Override
+ public void mark() {
+ mark = offset;
+ }
+
+ @Override
+ public void reset() {
+ offset = mark;
+ }
}
/**
@@ -291,12 +304,33 @@ public byte[] array() {
public int arrayOffset() {
return bytes.arrayOffset() + bytes.position();
}
+
+ @Override
+ public void mark() {
+ bytes.mark();
+ }
+
+ @Override
+ public void reset() {
+ bytes.reset();
+ }
+
+ @Override
+ public boolean getByteBufferSupported() {
+ return true;
+ }
+
+ @Override
+ public ByteBuffer getByteBuffer() {
+ return bytes.slice();
+ }
}
/**
* An {@link InputStream} that is backed by a {@link ReadableBuffer}.
*/
- private static final class BufferInputStream extends InputStream implements KnownLength {
+ private static final class BufferInputStream extends InputStream
+ implements KnownLength, HasByteBuffer {
final ReadableBuffer buffer;
public BufferInputStream(ReadableBuffer buffer) {
@@ -329,6 +363,39 @@ public int read(byte[] dest, int destOffset, int length) throws IOException {
return length;
}
+ @Override
+ public long skip(long n) throws IOException {
+ int length = (int) Math.min(buffer.readableBytes(), n);
+ buffer.skipBytes(length);
+ return length;
+ }
+
+ @Override
+ public void mark(int readlimit) {
+ buffer.mark();
+ }
+
+ @Override
+ public void reset() throws IOException {
+ buffer.reset();
+ }
+
+ @Override
+ public boolean markSupported() {
+ return true;
+ }
+
+ @Override
+ public boolean getByteBufferSupported() {
+ return buffer.getByteBufferSupported();
+ }
+
+ @Nullable
+ @Override
+ public ByteBuffer getByteBuffer() {
+ return buffer.getByteBuffer();
+ }
+
@Override
public void close() throws IOException {
buffer.close();
diff --git a/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java b/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java
index 660aa116317..f26c8b1ef9a 100644
--- a/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java
+++ b/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java
@@ -17,14 +17,20 @@
package io.grpc.internal;
import static com.google.common.base.Charsets.UTF_8;
+import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
+import java.nio.InvalidMarkException;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -154,6 +160,128 @@ public void readStreamShouldSucceed() throws IOException {
assertEquals(EXPECTED_VALUE, new String(bos.toByteArray(), UTF_8));
}
+ @Test
+ public void resetUnmarkedShouldThrow() {
+ try {
+ composite.reset();
+ fail();
+ } catch (InvalidMarkException expected) {
+ }
+ }
+
+ @Test
+ public void markAndResetWithSkipBytesShouldSucceed() {
+ composite.mark();
+ composite.skipBytes(EXPECTED_VALUE.length() / 2);
+ composite.reset();
+ assertEquals(EXPECTED_VALUE.length(), composite.readableBytes());
+ }
+
+ @Test
+ public void markAndResetWithReadUnsignedByteShouldSucceed() {
+ composite.readUnsignedByte();
+ composite.mark();
+ int b = composite.readUnsignedByte();
+ composite.reset();
+ assertEquals(EXPECTED_VALUE.length() - 1, composite.readableBytes());
+ assertEquals(b, composite.readUnsignedByte());
+ }
+
+ @Test
+ public void markAndResetWithReadByteArrayShouldSucceed() {
+ composite.mark();
+ byte[] first = new byte[EXPECTED_VALUE.length()];
+ composite.readBytes(first, 0, EXPECTED_VALUE.length());
+ composite.reset();
+ assertEquals(EXPECTED_VALUE.length(), composite.readableBytes());
+ byte[] second = new byte[EXPECTED_VALUE.length()];
+ composite.readBytes(second, 0, EXPECTED_VALUE.length());
+ assertArrayEquals(first, second);
+ }
+
+ @Test
+ public void markAndResetWithReadByteBufferShouldSucceed() {
+ byte[] first = new byte[EXPECTED_VALUE.length()];
+ composite.mark();
+ composite.readBytes(ByteBuffer.wrap(first));
+ composite.reset();
+ byte[] second = new byte[EXPECTED_VALUE.length()];
+ assertEquals(EXPECTED_VALUE.length(), composite.readableBytes());
+ composite.readBytes(ByteBuffer.wrap(second));
+ assertArrayEquals(first, second);
+ }
+
+ @Test
+ public void markAndResetWithReadStreamShouldSucceed() throws IOException {
+ ByteArrayOutputStream first = new ByteArrayOutputStream();
+ composite.mark();
+ composite.readBytes(first, EXPECTED_VALUE.length() / 2);
+ composite.reset();
+ assertEquals(EXPECTED_VALUE.length(), composite.readableBytes());
+ ByteArrayOutputStream second = new ByteArrayOutputStream();
+ composite.readBytes(second, EXPECTED_VALUE.length() / 2);
+ assertArrayEquals(first.toByteArray(), second.toByteArray());
+ }
+
+ @Test
+ public void markAndResetWithReadReadableBufferShouldSucceed() {
+ composite.readBytes(EXPECTED_VALUE.length() / 2);
+ int remaining = composite.readableBytes();
+ composite.mark();
+ ReadableBuffer first = composite.readBytes(1);
+ composite.reset();
+ assertEquals(remaining, composite.readableBytes());
+ ReadableBuffer second = composite.readBytes(1);
+ assertEquals(first.readUnsignedByte(), second.readUnsignedByte());
+ }
+
+ @Test
+ public void markAgainShouldOverwritePreviousMark() {
+ composite.mark();
+ composite.skipBytes(EXPECTED_VALUE.length() / 2);
+ int remaining = composite.readableBytes();
+ composite.mark();
+ composite.skipBytes(1);
+ composite.reset();
+ assertEquals(remaining, composite.readableBytes());
+ }
+
+ @Test
+ public void bufferAddedAfterMarkedShouldBeIncluded() {
+ composite = new CompositeReadableBuffer();
+ composite.mark();
+ splitAndAdd(EXPECTED_VALUE);
+ composite.skipBytes(EXPECTED_VALUE.length() / 2);
+ composite.reset();
+ assertEquals(EXPECTED_VALUE.length(), composite.readableBytes());
+ }
+
+ @Test
+ public void canUseByteBufferOnlyAllComponentsSupportUsingByteBuffer() {
+ composite = new CompositeReadableBuffer();
+ ReadableBuffer buffer1 = mock(ReadableBuffer.class);
+ ReadableBuffer buffer2 = mock(ReadableBuffer.class);
+ ReadableBuffer buffer3 = mock(ReadableBuffer.class);
+ when(buffer1.getByteBufferSupported()).thenReturn(true);
+ when(buffer2.getByteBufferSupported()).thenReturn(true);
+ when(buffer3.getByteBufferSupported()).thenReturn(false);
+ composite.addBuffer(buffer1);
+ assertTrue(composite.getByteBufferSupported());
+ composite.addBuffer(buffer2);
+ assertTrue(composite.getByteBufferSupported());
+ composite.addBuffer(buffer3);
+ assertFalse(composite.getByteBufferSupported());
+ }
+
+ @Test
+ public void getByteBufferDelegatesToComponents() {
+ composite = new CompositeReadableBuffer();
+ ReadableBuffer buffer = mock(ReadableBuffer.class);
+ composite.addBuffer(buffer);
+ composite.getByteBuffer();
+ verify(buffer).getByteBuffer();
+ }
+
@Test
public void closeShouldCloseBuffers() {
composite = new CompositeReadableBuffer();
diff --git a/core/src/test/java/io/grpc/internal/ReadableBufferTestBase.java b/core/src/test/java/io/grpc/internal/ReadableBufferTestBase.java
index e469b807d51..86b8e2a399b 100644
--- a/core/src/test/java/io/grpc/internal/ReadableBufferTestBase.java
+++ b/core/src/test/java/io/grpc/internal/ReadableBufferTestBase.java
@@ -24,6 +24,7 @@
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.util.Arrays;
+import org.junit.Assume;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -117,6 +118,58 @@ public void partialReadToReadableBufferShouldSucceed() {
assertArrayEquals(new byte[] {'h', 'e'}, Arrays.copyOfRange(array, 0, 2));
}
+ @Test
+ public void markAndResetWithReadShouldSucceed() {
+ ReadableBuffer buffer = buffer();
+ int offset = 5;
+ buffer.readBytes(new byte[offset], 0, offset);
+ buffer.mark();
+ int b = buffer.readUnsignedByte();
+ assertEquals(msg.length() - offset - 1, buffer.readableBytes());
+ buffer.reset();
+ assertEquals(msg.length() - offset, buffer.readableBytes());
+ assertEquals(b, buffer.readUnsignedByte());
+ }
+
+ @Test
+ public void markAndResetWithReadToReadableBufferShouldSucceed() {
+ ReadableBuffer buffer = buffer();
+ int offset = 5;
+ buffer.readBytes(offset);
+ int testLen = 100;
+ buffer.mark();
+ ReadableBuffer first = buffer.readBytes(testLen);
+ assertEquals(msg.length() - offset - testLen, buffer.readableBytes());
+ buffer.reset();
+ assertEquals(msg.length() - offset, buffer.readableBytes());
+ ReadableBuffer second = buffer.readBytes(testLen);
+ byte[] array1 = new byte[testLen];
+ byte[] array2 = new byte[testLen];
+ first.readBytes(array1, 0, testLen);
+ second.readBytes(array2, 0, testLen);
+ assertArrayEquals(array1, array2);
+ }
+
+ @Test
+ public void getByteBufferDoesNotAffectBufferPosition() {
+ ReadableBuffer buffer = buffer();
+ Assume.assumeTrue(buffer.getByteBufferSupported());
+ ByteBuffer byteBuffer = buffer.getByteBuffer();
+ assertEquals(msg.length(), buffer.readableBytes());
+ byteBuffer.get(new byte[byteBuffer.remaining()]);
+ assertEquals(msg.length(), buffer.readableBytes());
+ }
+
+ @Test
+ public void getByteBufferIsNotAffectedByBufferRead() {
+ ReadableBuffer buffer = buffer();
+ Assume.assumeTrue(buffer.getByteBufferSupported());
+ ByteBuffer byteBuffer = buffer.getByteBuffer();
+ int initialRemaining = byteBuffer.remaining();
+ buffer.readBytes(new byte[100], 0, 100);
+ assertEquals(initialRemaining, byteBuffer.remaining());
+ }
+
protected abstract ReadableBuffer buffer();
private static String repeatUntilLength(String toRepeat, int length) {
diff --git a/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java b/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java
index ea9daeed6a3..3007c19682a 100644
--- a/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java
+++ b/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java
@@ -19,11 +19,15 @@
import static com.google.common.base.Charsets.UTF_8;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+import io.grpc.HasByteBuffer;
+import java.io.IOException;
import java.io.InputStream;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -128,4 +132,30 @@ public void bufferInputStream_close_closesBuffer() throws Exception {
inputStream.close();
verify(buffer, times(1)).close();
}
+
+ @Test
+ public void bufferInputStream_markAndReset() throws IOException {
+ ReadableBuffer buffer = ReadableBuffers.wrap(MSG_BYTES);
+ InputStream inputStream = ReadableBuffers.openStream(buffer, true);
+ assertTrue(inputStream.markSupported());
+ inputStream.mark(2);
+ byte[] first = new byte[5];
+ inputStream.read(first);
+ assertEquals(0, inputStream.available());
+ inputStream.reset();
+ assertEquals(5, inputStream.available());
+ byte[] second = new byte[5];
+ inputStream.read(second);
+ assertArrayEquals(first, second);
+ }
+
+ @Test
+ public void bufferInputStream_getByteBufferDelegatesToBuffer() {
+ ReadableBuffer buffer = mock(ReadableBuffer.class);
+ when(buffer.getByteBufferSupported()).thenReturn(true);
+ InputStream inputStream = ReadableBuffers.openStream(buffer, true);
+ assertTrue(((HasByteBuffer) inputStream).getByteBufferSupported());
+ ((HasByteBuffer) inputStream).getByteBuffer();
+ verify(buffer).getByteBuffer();
+ }
}
diff --git a/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java b/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java
index 37caccb0eb3..75d70978610 100644
--- a/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java
+++ b/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java
@@ -94,6 +94,26 @@ public int arrayOffset() {
return buffer.arrayOffset() + buffer.readerIndex();
}
+ @Override
+ public void mark() {
+ buffer.markReaderIndex();
+ }
+
+ @Override
+ public void reset() {
+ buffer.resetReaderIndex();
+ }
+
+ @Override
+ public boolean getByteBufferSupported() {
+ return buffer.nioBufferCount() > 0;
+ }
+
+ @Override
+ public ByteBuffer getByteBuffer() {
+ return buffer.nioBufferCount() == 1 ? buffer.nioBuffer() : buffer.nioBuffers()[0];
+ }
+
/**
* If the first call to close, calls {@link ByteBuf#release} to release the internal Netty buffer.
*/
diff --git a/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java b/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java
index 8090e601911..23c2f0632a9 100644
--- a/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java
+++ b/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java
@@ -17,11 +17,16 @@
package io.grpc.netty;
import static com.google.common.base.Charsets.UTF_8;
+import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import com.google.common.base.Splitter;
import io.grpc.internal.ReadableBuffer;
import io.grpc.internal.ReadableBufferTestBase;
+import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
+import java.nio.ByteBuffer;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -52,6 +57,29 @@ public void closeMultipleTimesShouldReleaseBufferOnce() {
assertEquals(0, buffer.buffer().refCnt());
}
+ @Test
+ public void getByteBufferFromSingleNioBufferBackedBuffer() {
+ assertTrue(buffer.getByteBufferSupported());
+ ByteBuffer byteBuffer = buffer.getByteBuffer();
+ byte[] arr = new byte[byteBuffer.remaining()];
+ byteBuffer.get(arr);
+ assertArrayEquals(msg.getBytes(UTF_8), arr);
+ }
+
+ @Test
+ public void getByteBufferFromCompositeBufferReturnsOnlyFirstComponent() {
+ CompositeByteBuf composite = Unpooled.compositeBuffer(10);
+ int chunks = 4;
+ int chunkLen = msg.length() / chunks;
+ for (String chunk : Splitter.fixedLength(chunkLen).split(msg)) {
+ composite.addComponent(true, Unpooled.copiedBuffer(chunk.getBytes(UTF_8)));
+ }
+ buffer = new NettyReadableBuffer(composite);
+ byte[] array = new byte[chunkLen];
+ buffer.getByteBuffer().get(array);
+ assertArrayEquals(msg.substring(0, chunkLen).getBytes(UTF_8), array);
+ }
+
@Override
protected ReadableBuffer buffer() {
return buffer;
diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java
index 2ece98ffb97..4aeeae2fa8b 100644
--- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java
+++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java
@@ -56,6 +56,18 @@ public void partialReadToByteBufferShouldSucceed() {
// Not supported.
}
+ @Override
+ @Test
+ public void markAndResetWithReadShouldSucceed() {
+ // Not supported.
+ }
+
+ @Override
+ @Test
+ public void markAndResetWithReadToReadableBufferShouldSucceed() {
+ // Not supported.
+ }
+
@Override
protected ReadableBuffer buffer() {
return buffer;
diff --git a/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java b/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java
index ddba5b8d5b1..47a790f44fc 100644
--- a/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java
+++ b/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java
@@ -25,6 +25,7 @@
import com.google.protobuf.MessageLite;
import com.google.protobuf.Parser;
import io.grpc.ExperimentalApi;
+import io.grpc.HasByteBuffer;
import io.grpc.KnownLength;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor.Marshaller;
@@ -35,6 +36,9 @@
import java.io.OutputStream;
import java.lang.ref.Reference;
import java.lang.ref.WeakReference;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
/**
* Utility methods for using protobuf with grpc.
@@ -48,12 +52,34 @@ public final class ProtoLiteUtils {
private static final int BUF_SIZE = 8192;
+ /**
+ * Assume Java 9+ if it isn't Java 7 or Java 8.
+ */
+ @VisibleForTesting
+ static final boolean IS_JAVA9_OR_HIGHER;
+
/**
* The same value as {@link io.grpc.internal.GrpcUtil#DEFAULT_MAX_MESSAGE_SIZE}.
*/
@VisibleForTesting
static final int DEFAULT_MAX_MESSAGE_SIZE = 4 * 1024 * 1024;
+ /**
+ * Threshold for passing {@link ByteBuffer}s directly into Protobuf.
+ */
+ @VisibleForTesting
+ static final int MESSAGE_ZERO_COPY_THRESHOLD = 64 * 1024;
+
+ static {
+ boolean isJava9OrHigher = true;
+ try {
+ Class.forName("java.lang.StackWalker");
+ } catch (ClassNotFoundException e) {
+ isJava9OrHigher = false;
+ }
+ IS_JAVA9_OR_HIGHER = isJava9OrHigher;
+ }
+
/**
* Sets the global registry for proto marshalling shared across all servers and clients.
*
@@ -173,7 +199,23 @@ public T parse(InputStream stream) {
try {
if (stream instanceof KnownLength) {
int size = stream.available();
- if (size > 0 && size <= DEFAULT_MAX_MESSAGE_SIZE) {
+ if (size == 0) {
+ return defaultInstance;
+ }
+ if (IS_JAVA9_OR_HIGHER
+ && size >= MESSAGE_ZERO_COPY_THRESHOLD
+ && stream instanceof HasByteBuffer
+ && ((HasByteBuffer) stream).getByteBufferSupported()
+ && stream.markSupported()) {
+ List buffers = new ArrayList<>();
+ stream.mark(size);
+ while (stream.available() != 0) {
+ ByteBuffer buffer = ((HasByteBuffer) stream).getByteBuffer();
+ stream.skip(buffer.remaining());
+ buffers.add(buffer);
+ }
+ cis = CodedInputStream.newInstance(buffers);
+ } else if (size > 0 && size <= DEFAULT_MAX_MESSAGE_SIZE) {
Reference ref;
// buf should not be used after this method has returned.
byte[] buf;
@@ -197,8 +239,6 @@ public T parse(InputStream stream) {
throw new RuntimeException("size inaccurate: " + size + " != " + position);
}
cis = CodedInputStream.newInstance(buf, 0, size);
- } else if (size == 0) {
- return defaultInstance;
}
}
} catch (IOException e) {
diff --git a/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java b/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java
index d05e884105e..1d5c1abc806 100644
--- a/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java
+++ b/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java
@@ -22,6 +22,7 @@
import static org.junit.Assert.assertSame;
import static org.junit.Assert.fail;
+import com.google.common.base.Strings;
import com.google.common.io.ByteStreams;
import com.google.protobuf.ByteString;
import com.google.protobuf.Empty;
@@ -29,6 +30,7 @@
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Type;
import io.grpc.Drainable;
+import io.grpc.HasByteBuffer;
import io.grpc.KnownLength;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor.Marshaller;
@@ -40,7 +42,10 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
+import java.nio.ByteBuffer;
import java.util.Arrays;
+import javax.annotation.Nullable;
+import org.junit.Assume;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@@ -226,6 +231,20 @@ public void parseFromKnowLengthInputStream() throws Exception {
assertEquals(expect, result);
}
+ @Test
+ public void parseFromKnownLengthByteBufferInputStream() {
+ Assume.assumeTrue(ProtoLiteUtils.IS_JAVA9_OR_HIGHER);
+ Marshaller marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance());
+ Type expect =
+ Type.newBuilder()
+ .setName(Strings.repeat("M", ProtoLiteUtils.MESSAGE_ZERO_COPY_THRESHOLD))
+ .build();
+
+ Type result = marshaller.parse(
+ new CustomKnownLengthByteBufferInputStream(expect.toByteString().asReadOnlyByteBuffer()));
+ assertEquals(expect, result);
+ }
+
@Test
public void defaultMaxMessageSize() {
assertEquals(GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE, ProtoLiteUtils.DEFAULT_MAX_MESSAGE_SIZE);
@@ -253,4 +272,55 @@ public int read() throws IOException {
return source[position++];
}
}
+
+ private static final class CustomKnownLengthByteBufferInputStream extends InputStream
+ implements KnownLength, HasByteBuffer {
+ private ByteBuffer source;
+
+ private CustomKnownLengthByteBufferInputStream(ByteBuffer source) {
+ this.source = source;
+ }
+
+ @Override
+ public int available() throws IOException {
+ return source.remaining();
+ }
+
+ @Override
+ public synchronized void mark(int readlimit) {
+ source.mark();
+ }
+
+ @Override
+ public synchronized void reset() throws IOException {
+ source.reset();
+ }
+
+ @Override
+ public boolean markSupported() {
+ return true;
+ }
+
+ @Override
+ public int read() throws IOException {
+ throw new UnsupportedOperationException("should not be called");
+ }
+
+ @Override
+ public long skip(long n) throws IOException {
+ source.position((int) (source.position() + n));
+ return n;
+ }
+
+ @Override
+ public boolean getByteBufferSupported() {
+ return true;
+ }
+
+ @Nullable
+ @Override
+ public ByteBuffer getByteBuffer() {
+ return source.slice();
+ }
+ }
}