diff --git a/core/client/fs/src/main/java/alluxio/client/block/stream/DataMessageClientResponseObserver.java b/core/client/fs/src/main/java/alluxio/client/block/stream/DataMessageClientResponseObserver.java new file mode 100644 index 000000000000..6c01231bf3b0 --- /dev/null +++ b/core/client/fs/src/main/java/alluxio/client/block/stream/DataMessageClientResponseObserver.java @@ -0,0 +1,78 @@ +/* + * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0 + * (the "License"). You may not use this work except in compliance with the License, which is + * available at www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied, as more fully set forth in the License. + * + * See the NOTICE file distributed with this work for information regarding copyright ownership. + */ + +package alluxio.client.block.stream; + +import alluxio.grpc.DataMessageMarshaller; +import alluxio.grpc.DataMessageMarshallerProvider; + +import io.grpc.stub.ClientCallStreamObserver; +import io.grpc.stub.ClientResponseObserver; +import io.grpc.stub.StreamObserver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.concurrent.NotThreadSafe; + +/** + * A {@link StreamObserver} that handles raw data buffers. 
+ * + * @param type of the response message + * @param type of the request message + */ +@NotThreadSafe +public class DataMessageClientResponseObserver + implements ClientResponseObserver, DataMessageMarshallerProvider { + private static final Logger LOG = + LoggerFactory.getLogger(DataMessageClientResponseObserver.class); + + private final StreamObserver mObserver; + private final DataMessageMarshaller mMarshaller; + + /** + * @param observer the original response observer + * @param marshaller the marshaller for the response + */ + public DataMessageClientResponseObserver(StreamObserver observer, + DataMessageMarshaller marshaller) { + mObserver = observer; + mMarshaller = marshaller; + } + + @Override + public void onNext(RespT value) { + mObserver.onNext(value); + } + + @Override + public void onError(Throwable t) { + mObserver.onError(t); + } + + @Override + public void onCompleted() { + mObserver.onCompleted(); + } + + @Override + public void beforeStart(ClientCallStreamObserver requestStream) { + if (mObserver instanceof ClientResponseObserver) { + ((ClientResponseObserver) mObserver).beforeStart(requestStream); + } else { + LOG.warn("{} does not implement ClientResponseObserver:beforeStart", mObserver); + } + } + + @Override + public DataMessageMarshaller getMarshaller() { + return mMarshaller; + } +} diff --git a/core/client/fs/src/main/java/alluxio/client/block/stream/DefaultBlockWorkerClient.java b/core/client/fs/src/main/java/alluxio/client/block/stream/DefaultBlockWorkerClient.java index 40faf1e670c1..b0404c04aefd 100644 --- a/core/client/fs/src/main/java/alluxio/client/block/stream/DefaultBlockWorkerClient.java +++ b/core/client/fs/src/main/java/alluxio/client/block/stream/DefaultBlockWorkerClient.java @@ -20,10 +20,12 @@ import alluxio.grpc.BlockWorkerGrpc; import alluxio.grpc.CreateLocalBlockRequest; import alluxio.grpc.CreateLocalBlockResponse; +import alluxio.grpc.DataMessageMarshallerProvider; import alluxio.grpc.GrpcChannel; import 
alluxio.grpc.GrpcChannelBuilder; import alluxio.grpc.GrpcExceptionUtils; import alluxio.grpc.GrpcManagedChannelPool; +import alluxio.grpc.DataMessageMarshaller; import alluxio.grpc.OpenLocalBlockRequest; import alluxio.grpc.OpenLocalBlockResponse; import alluxio.grpc.ReadRequest; @@ -32,6 +34,7 @@ import alluxio.grpc.RemoveBlockResponse; import alluxio.grpc.WriteRequest; import alluxio.grpc.WriteResponse; +import alluxio.grpc.GrpcSerializationUtils; import alluxio.util.network.NettyUtils; import com.google.common.io.Closer; @@ -79,6 +82,7 @@ public DefaultBlockWorkerClient(Subject subject, SocketAddress address, // Channel is still reused due to client pooling. mStreamingChannel = buildChannel(subject, address, GrpcManagedChannelPool.PoolingStrategy.DISABLED, alluxioConf, workerGroup); + mStreamingChannel.intercept(new StreamSerializationClientInterceptor()); // Uses default pooling strategy for RPC calls for better scalability. mRpcChannel = buildChannel(subject, address, GrpcManagedChannelPool.PoolingStrategy.DEFAULT, alluxioConf, workerGroup); @@ -117,7 +121,18 @@ public StreamObserver writeBlock(StreamObserver res @Override public StreamObserver readBlock(StreamObserver responseObserver) { - return mStreamingAsyncStub.readBlock(responseObserver); + if (responseObserver instanceof DataMessageMarshallerProvider) { + DataMessageMarshaller marshaller = + ((DataMessageMarshallerProvider) responseObserver).getMarshaller(); + return mStreamingAsyncStub + .withOption(GrpcSerializationUtils.OVERRIDDEN_METHOD_DESCRIPTOR, + BlockWorkerGrpc.getReadBlockMethod().toBuilder() + .setResponseMarshaller(marshaller) + .build()) + .readBlock(responseObserver); + } else { + return mStreamingAsyncStub.readBlock(responseObserver); + } } @Override diff --git a/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataMessageBlockingStream.java b/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataMessageBlockingStream.java new file mode 100644 index 
000000000000..586b64b8f5ed --- /dev/null +++ b/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataMessageBlockingStream.java @@ -0,0 +1,91 @@ +/* + * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0 + * (the "License"). You may not use this work except in compliance with the License, which is + * available at www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied, as more fully set forth in the License. + * + * See the NOTICE file distributed with this work for information regarding copyright ownership. + */ + +package alluxio.client.block.stream; + +import alluxio.exception.status.DeadlineExceededException; +import alluxio.grpc.DataMessage; +import alluxio.grpc.DataMessageMarshaller; +import alluxio.network.protocol.databuffer.DataBuffer; + +import io.grpc.stub.StreamObserver; + +import java.io.IOException; +import java.util.function.Function; + +import javax.annotation.concurrent.NotThreadSafe; + +/** + * A helper class for accessing gRPC bi-directional stream synchronously. 
+ * + * @param type of the request + * @param type of the response + */ +@NotThreadSafe +public class GrpcDataMessageBlockingStream extends GrpcBlockingStream { + private final DataMessageMarshaller mMarshaller; + + /** + * @param rpcFunc the gRPC bi-directional stream stub function + * @param bufferSize maximum number of incoming messages the buffer can hold + * @param description description of this stream + * @param deserializer custom deserializer for the response + */ + public GrpcDataMessageBlockingStream(Function, StreamObserver> rpcFunc, + int bufferSize, String description, DataMessageMarshaller deserializer) { + super((resObserver) -> { + DataMessageClientResponseObserver newObserver = + new DataMessageClientResponseObserver<>(resObserver, deserializer); + StreamObserver requestObserver = rpcFunc.apply(newObserver); + return requestObserver; + }, bufferSize, description); + mMarshaller = deserializer; + } + + @Override + public ResT receive(long timeoutMs) throws IOException { + DataMessage message = receiveDataMessage(timeoutMs); + if (message == null) { + return null; + } + return mMarshaller.combineData(message); + } + + /** + * Receives a response with data buffer from the server. Will wait until a response is received, + * or throw an exception if times out. Caller of this method must release the buffer after reading + * the data. 
+ * + * @param timeoutMs maximum time to wait before giving up and throwing + * a {@link DeadlineExceededException} + * @return the response message with data buffer, or null if the inbound stream is completed + * @throws IOException if any error occurs + */ + public DataMessage receiveDataMessage(long timeoutMs) throws IOException { + ResT response = super.receive(timeoutMs); + if (response == null) { + return null; + } + DataBuffer buffer = mMarshaller.pollBuffer(response); + return new DataMessage<>(response, buffer); + } + + @Override + public void waitForComplete(long timeoutMs) throws IOException { + DataMessage message; + while (!isCanceled() && (message = receiveDataMessage(timeoutMs)) != null) { + if (message.getBuffer() != null) { + message.getBuffer().release(); + } + } + super.waitForComplete(timeoutMs); + } +} diff --git a/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataReader.java b/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataReader.java index bdf7993dd845..d6cdf7b2c65f 100644 --- a/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataReader.java +++ b/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataReader.java @@ -14,18 +14,20 @@ import alluxio.client.file.FileSystemContext; import alluxio.conf.AlluxioConfiguration; import alluxio.conf.PropertyKey; +import alluxio.grpc.DataMessage; import alluxio.grpc.ReadRequest; import alluxio.grpc.ReadResponse; +import alluxio.grpc.ReadResponseMarshaller; import alluxio.network.protocol.databuffer.DataBuffer; import alluxio.network.protocol.databuffer.NioDataBuffer; import alluxio.wire.WorkerNetAddress; import com.google.common.base.Preconditions; -import com.google.protobuf.ByteString; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; +import java.nio.ByteBuffer; import javax.annotation.concurrent.NotThreadSafe; @@ -53,6 +55,7 @@ public final class GrpcDataReader implements DataReader { private final 
WorkerNetAddress mAddress; private final GrpcBlockingStream mStream; + private final ReadResponseMarshaller mMarshaller; /** The next pos to read. */ private long mPosToRead; @@ -76,9 +79,15 @@ private GrpcDataReader(FileSystemContext context, WorkerNetAddress address, mDataTimeoutMs = alluxioConf.getMs(PropertyKey.USER_NETWORK_DATA_TIMEOUT_MS); mClient = mContext.acquireBlockWorkerClient(address); + mMarshaller = new ReadResponseMarshaller(); try { - mStream = new GrpcBlockingStream<>(mClient::readBlock, mReaderBufferSizeMessages, - address.toString()); + if (alluxioConf.getBoolean(PropertyKey.USER_NETWORK_ZEROCOPY_ENABLED)) { + mStream = new GrpcDataMessageBlockingStream<>(mClient::readBlock, mReaderBufferSizeMessages, + address.toString(), mMarshaller); + } else { + mStream = new GrpcBlockingStream<>(mClient::readBlock, mReaderBufferSizeMessages, + address.toString()); + } mStream.send(mReadRequest, mDataTimeoutMs); } catch (Exception e) { mContext.releaseBlockWorkerClient(address, mClient); @@ -95,16 +104,32 @@ public long pos() { public DataBuffer readChunk() throws IOException { Preconditions.checkState(!mClient.isShutdown(), "Data reader is closed while reading data chunks."); - ByteString buf; - ReadResponse response = mStream.receive(mDataTimeoutMs); + DataBuffer buffer = null; + ReadResponse response = null; + if (mStream instanceof GrpcDataMessageBlockingStream) { + DataMessage message = + ((GrpcDataMessageBlockingStream) mStream) + .receiveDataMessage(mDataTimeoutMs); + if (message != null) { + response = message.getMessage(); + buffer = message.getBuffer(); + Preconditions.checkState(buffer != null, "response should always contain chunk"); + } + } else { + response = mStream.receive(mDataTimeoutMs); + if (response != null) { + Preconditions.checkState(response.hasChunk() && response.getChunk().hasData(), + "response should always contain chunk"); + ByteBuffer byteBuffer = response.getChunk().getData().asReadOnlyByteBuffer(); + buffer = new 
NioDataBuffer(byteBuffer, byteBuffer.remaining()); + } + } if (response == null) { return null; } - Preconditions.checkState(response.hasChunk(), "response should always contain chunk"); - buf = response.getChunk().getData(); - mPosToRead += buf.size(); + mPosToRead += buffer.readableBytes(); Preconditions.checkState(mPosToRead - mReadRequest.getOffset() <= mReadRequest.getLength()); - return new NioDataBuffer(buf.asReadOnlyByteBuffer(), buf.size()); + return buffer; } @Override @@ -116,6 +141,7 @@ public void close() throws IOException { mStream.close(); mStream.waitForComplete(mDataTimeoutMs); } finally { + mMarshaller.close(); mContext.releaseBlockWorkerClient(mAddress, mClient); } } diff --git a/core/client/fs/src/main/java/alluxio/client/block/stream/StreamSerializationClientInterceptor.java b/core/client/fs/src/main/java/alluxio/client/block/stream/StreamSerializationClientInterceptor.java new file mode 100644 index 000000000000..9ec191140686 --- /dev/null +++ b/core/client/fs/src/main/java/alluxio/client/block/stream/StreamSerializationClientInterceptor.java @@ -0,0 +1,36 @@ +/* + * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0 + * (the "License"). You may not use this work except in compliance with the License, which is + * available at www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied, as more fully set forth in the License. + * + * See the NOTICE file distributed with this work for information regarding copyright ownership. + */ + +package alluxio.client.block.stream; + +import alluxio.grpc.GrpcSerializationUtils; + +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.MethodDescriptor; + +/** + * Serialization interceptor for client. 
+ */ +public class StreamSerializationClientInterceptor implements ClientInterceptor { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel channel) { + MethodDescriptor overriddenMethodDescriptor = + callOptions.getOption(GrpcSerializationUtils.OVERRIDDEN_METHOD_DESCRIPTOR); + if (overriddenMethodDescriptor != null) { + method = overriddenMethodDescriptor; + } + return channel.newCall(method, callOptions); + } +} diff --git a/core/common/src/main/java/alluxio/conf/PropertyKey.java b/core/common/src/main/java/alluxio/conf/PropertyKey.java index b5d2a2c77e25..2239b20489d2 100644 --- a/core/common/src/main/java/alluxio/conf/PropertyKey.java +++ b/core/common/src/main/java/alluxio/conf/PropertyKey.java @@ -2020,6 +2020,13 @@ public String toString() { .setConsistencyCheckLevel(ConsistencyCheckLevel.WARN) .setScope(Scope.WORKER) .build(); + public static final PropertyKey WORKER_NETWORK_MAX_INBOUND_MESSAGE_SIZE = + new Builder(Name.WORKER_NETWORK_MAX_INBOUND_MESSAGE_SIZE) + .setDefaultValue("4MB") + .setDescription("The max inbound message size used by worker gRPC connections.") + .setConsistencyCheckLevel(ConsistencyCheckLevel.WARN) + .setScope(Scope.WORKER) + .build(); public static final PropertyKey WORKER_NETWORK_NETTY_BOSS_THREADS = new Builder(Name.WORKER_NETWORK_NETTY_BOSS_THREADS) .setDefaultValue(1) @@ -2084,6 +2091,13 @@ public String toString() { .setConsistencyCheckLevel(ConsistencyCheckLevel.WARN) .setScope(Scope.WORKER) .build(); + public static final PropertyKey WORKER_NETWORK_ZEROCOPY_ENABLED = + new Builder(Name.WORKER_NETWORK_ZEROCOPY_ENABLED) + .setDefaultValue(true) + .setDescription("Whether zero copy is enabled on worker when processing data streams.") + .setConsistencyCheckLevel(ConsistencyCheckLevel.WARN) + .setScope(Scope.WORKER) + .build(); // The default is set to 11. One client is reserved for some light weight operations such as // heartbeat. 
The other 10 clients are used by commitBlock issued from the worker to the block // master. @@ -2922,6 +2936,13 @@ public String toString() { .setConsistencyCheckLevel(ConsistencyCheckLevel.WARN) .setScope(Scope.CLIENT) .build(); + public static final PropertyKey USER_NETWORK_ZEROCOPY_ENABLED = + new Builder(Name.USER_NETWORK_ZEROCOPY_ENABLED) + .setDefaultValue(true) + .setDescription("Whether zero copy is enabled on client when processing data streams.") + .setConsistencyCheckLevel(ConsistencyCheckLevel.WARN) + .setScope(Scope.CLIENT) + .build(); public static final PropertyKey USER_RPC_RETRY_BASE_SLEEP_MS = new Builder(Name.USER_RPC_RETRY_BASE_SLEEP_MS) .setAlias(new String[]{"alluxio.user.rpc.retry.base.sleep.ms"}) @@ -3810,6 +3831,8 @@ public static final class Name { "alluxio.worker.network.keepalive.time"; public static final String WORKER_NETWORK_KEEPALIVE_TIMEOUT_MS = "alluxio.worker.network.keepalive.timeout"; + public static final String WORKER_NETWORK_MAX_INBOUND_MESSAGE_SIZE = + "alluxio.worker.network.max.inbound.message.size"; public static final String WORKER_NETWORK_NETTY_BOSS_THREADS = "alluxio.worker.network.netty.boss.threads"; public static final String WORKER_NETWORK_NETTY_CHANNEL = @@ -3826,6 +3849,8 @@ public static final class Name { "alluxio.worker.network.reader.max.chunk.size.bytes"; public static final String WORKER_NETWORK_SHUTDOWN_TIMEOUT = "alluxio.worker.network.shutdown.timeout"; + public static final String WORKER_NETWORK_ZEROCOPY_ENABLED = + "alluxio.worker.network.zerocopy.enabled"; public static final String WORKER_BLOCK_MASTER_CLIENT_POOL_SIZE = "alluxio.worker.block.master.client.pool.size"; public static final String WORKER_PRINCIPAL = "alluxio.worker.principal"; @@ -3982,6 +4007,8 @@ public static final class Name { "alluxio.user.network.writer.close.timeout"; public static final String USER_NETWORK_WRITER_FLUSH_TIMEOUT = "alluxio.user.network.writer.flush.timeout"; + public static final String USER_NETWORK_ZEROCOPY_ENABLED 
= + "alluxio.user.network.zerocopy.enabled"; public static final String USER_RPC_RETRY_BASE_SLEEP_MS = "alluxio.user.rpc.retry.base.sleep"; public static final String USER_RPC_RETRY_MAX_DURATION = diff --git a/core/common/src/main/java/alluxio/grpc/BufferRepository.java b/core/common/src/main/java/alluxio/grpc/BufferRepository.java new file mode 100644 index 000000000000..342a680af7a0 --- /dev/null +++ b/core/common/src/main/java/alluxio/grpc/BufferRepository.java @@ -0,0 +1,43 @@ +/* + * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0 + * (the "License"). You may not use this work except in compliance with the License, which is + * available at www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied, as more fully set forth in the License. + * + * See the NOTICE file distributed with this work for information regarding copyright ownership. + */ + +package alluxio.grpc; + +import java.io.Closeable; + +/** + * A repository of buffers. + * + * @param type of the message + * @param type of the buffer + */ +public interface BufferRepository extends Closeable { + /** + * Stores a buffer in the repository. + * + * @param buffer the buffer to store + * @param message the associated message + */ + void offerBuffer(TBuf buffer, TMesg message); + + /** + * Retrieves and removes a buffer from the store. + * + * @param message the message that associated with the buffer + * @return the buffer, or null if the buffer is not found + */ + TBuf pollBuffer(TMesg message); + + /** + * Closes the repository and all its buffers. 
+ */ + void close(); +} diff --git a/core/common/src/main/java/alluxio/grpc/DataMessage.java b/core/common/src/main/java/alluxio/grpc/DataMessage.java new file mode 100644 index 000000000000..d3c42735e779 --- /dev/null +++ b/core/common/src/main/java/alluxio/grpc/DataMessage.java @@ -0,0 +1,46 @@ +/* + * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0 + * (the "License"). You may not use this work except in compliance with the License, which is + * available at www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied, as more fully set forth in the License. + * + * See the NOTICE file distributed with this work for information regarding copyright ownership. + */ + +package alluxio.grpc; + +/** + * A struct that carries a message with a data buffer. + * + * @param the type of the message + * @param the type of the data buffer + */ +public class DataMessage { + private final T mMessage; + private final R mBuffer; + + /** + * @param message the message + * @param buffer the data buffer + */ + public DataMessage(T message, R buffer) { + mMessage = message; + mBuffer = buffer; + } + + /** + * @return the message + */ + public T getMessage() { + return mMessage; + } + + /** + * @return the data buffer + */ + public R getBuffer() { + return mBuffer; + } +} diff --git a/core/common/src/main/java/alluxio/grpc/DataMessageMarshaller.java b/core/common/src/main/java/alluxio/grpc/DataMessageMarshaller.java new file mode 100644 index 000000000000..e2651225e43c --- /dev/null +++ b/core/common/src/main/java/alluxio/grpc/DataMessageMarshaller.java @@ -0,0 +1,156 @@ +/* + * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0 (the + * "License"). 
You may not use this work except in compliance with the License, which is available + * at www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied, as more fully set forth in the License. + * + * See the NOTICE file distributed with this work for information regarding copyright ownership. + */ + +package alluxio.grpc; + +import static alluxio.grpc.GrpcSerializationUtils.addBuffersToStream; + +import alluxio.network.protocol.databuffer.DataBuffer; + +import io.grpc.Drainable; +import io.grpc.MethodDescriptor; +import io.grpc.internal.CompositeReadableBuffer; +import io.grpc.internal.ReadableBuffer; +import io.grpc.internal.ReadableBuffers; +import io.netty.buffer.ByteBuf; +import org.jboss.netty.util.internal.ConcurrentIdentityHashMap; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Map; + +/** + * Marshaller for data messages. + * + * @param type of the message + */ +public abstract class DataMessageMarshaller implements MethodDescriptor.Marshaller, + BufferRepository { + private final MethodDescriptor.Marshaller mOriginalMarshaller; + private final Map mBufferMap = new ConcurrentIdentityHashMap<>(); + + /** + * Creates a data marshaller. 
+ * + * @param originalMarshaller the original marshaller for the message + */ + public DataMessageMarshaller(MethodDescriptor.Marshaller originalMarshaller) { + mOriginalMarshaller = originalMarshaller; + } + + @Override + public InputStream stream(T message) { + return new DataBufferInputStream(message); + } + + @Override + public T parse(InputStream message) { + ReadableBuffer rawBuffer = GrpcSerializationUtils.getBufferFromStream(message); + try { + if (rawBuffer != null) { + CompositeReadableBuffer readableBuffer = new CompositeReadableBuffer(); + readableBuffer.addBuffer(rawBuffer); + return deserialize(readableBuffer); + } else { + // falls back to buffer copy + byte[] byteBuffer = new byte[message.available()]; + // a single read may return fewer bytes than requested, so keep reading until full + int offset = 0; + while (offset < byteBuffer.length) { + int bytesRead = message.read(byteBuffer, offset, byteBuffer.length - offset); + if (bytesRead < 0) { + throw new IOException(String.format( + "Stream ended after %d of %d bytes", offset, byteBuffer.length)); + } + offset += bytesRead; + } + return deserialize(ReadableBuffers.wrap(byteBuffer)); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void close() { + for (DataBuffer buffer : mBufferMap.values()) { + buffer.release(); + } + } + + @Override + public void offerBuffer(DataBuffer buffer, T message) { + mBufferMap.put(message, buffer); + } + + @Override + public DataBuffer pollBuffer(T message) { + return mBufferMap.remove(message); + } + + /** + * Combines the data buffer into the message. + * + * @param message the message to be combined + * @return the message with the combined buffer + */ + public abstract T combineData(DataMessage message); + + /** + * Serialize the message to buffers. + * @param message the message to be serialized + * @return an array of {@link ByteBuf}s containing the serialized message + * @throws IOException if the marshaller fails to serialize the message + */ + protected abstract ByteBuf[] serialize(T message) throws IOException; + + /** + * Deserialize data buffer to the message.
+ * + * @param buffer the buffer that contains the message data + * @return the deserialized message + * @throws IOException if the marshaller fails to deserialize the data + */ + protected abstract T deserialize(ReadableBuffer buffer) throws IOException; + + /** + * A {@link InputStream} for writing a message into a gRPC output stream. It will attempt to + * insert raw data buffer directly to the target stream if possible, or fallback to buffer copy if + * the insertion fails. + */ + private class DataBufferInputStream extends InputStream implements Drainable { + private final InputStream mStream; + private final T mMessage; + + DataBufferInputStream(T message) { + mMessage = message; + mStream = mOriginalMarshaller.stream(message); + } + + @Override + public int read() throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public int drainTo(OutputStream target) throws IOException { + int bytesWritten = 0; + ByteBuf[] buffers = serialize(mMessage); + for (ByteBuf buffer : buffers) { + bytesWritten += buffer.readableBytes(); + } + if (!addBuffersToStream(buffers, target)) { + // falls back to buffer copy + for (ByteBuf buffer : buffers) { + buffer.readBytes(target, buffer.readableBytes()); + } + } + return bytesWritten; + } + + @Override + public void close() throws IOException { + mStream.close(); + } + } +} diff --git a/core/common/src/main/java/alluxio/grpc/DataMessageMarshallerProvider.java b/core/common/src/main/java/alluxio/grpc/DataMessageMarshallerProvider.java new file mode 100644 index 000000000000..6397eac9c68d --- /dev/null +++ b/core/common/src/main/java/alluxio/grpc/DataMessageMarshallerProvider.java @@ -0,0 +1,21 @@ +/* + * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0 (the + * "License"). 
You may not use this work except in compliance with the License, which is available + * at www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied, as more fully set forth in the License. + * + * See the NOTICE file distributed with this work for information regarding copyright ownership. + */ + +package alluxio.grpc; + +/** + * A provider of {@link DataMessageMarshaller}. + * + * @param type of the message + */ +public interface DataMessageMarshallerProvider { + DataMessageMarshaller getMarshaller(); +} diff --git a/core/common/src/main/java/alluxio/grpc/GrpcChannel.java b/core/common/src/main/java/alluxio/grpc/GrpcChannel.java index 93e32371510b..8770353972bb 100644 --- a/core/common/src/main/java/alluxio/grpc/GrpcChannel.java +++ b/core/common/src/main/java/alluxio/grpc/GrpcChannel.java @@ -28,7 +28,7 @@ */ public final class GrpcChannel extends Channel { private final GrpcManagedChannelPool.ChannelKey mChannelKey; - private final Channel mChannel; + private Channel mChannel; private boolean mChannelReleased; private boolean mChannelHealthy = true; @@ -54,6 +54,9 @@ public String authority() { return mChannel.authority(); } + public void intercept(ClientInterceptor interceptor) { + mChannel = ClientInterceptors.intercept(mChannel, interceptor); + } /** * Shuts down the channel. */ diff --git a/core/common/src/main/java/alluxio/grpc/GrpcSerializationUtils.java b/core/common/src/main/java/alluxio/grpc/GrpcSerializationUtils.java new file mode 100644 index 000000000000..afbc66625d1d --- /dev/null +++ b/core/common/src/main/java/alluxio/grpc/GrpcSerializationUtils.java @@ -0,0 +1,195 @@ +/* + * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0 (the + * "License"). 
You may not use this work except in compliance with the License, which is available + * at www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied, as more fully set forth in the License. + * + * See the NOTICE file distributed with this work for information regarding copyright ownership. + */ + +package alluxio.grpc; + +import io.grpc.CallOptions; +import io.grpc.MethodDescriptor; +import io.grpc.ServerMethodDefinition; +import io.grpc.ServerServiceDefinition; +import io.grpc.ServiceDescriptor; +import io.grpc.internal.ReadableBuffer; +import io.netty.buffer.ByteBuf; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.InputStream; +import java.io.OutputStream; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Utilities for gRPC message serialization. 
+ */ +public class GrpcSerializationUtils { + public static final CallOptions.Key OVERRIDDEN_METHOD_DESCRIPTOR = + CallOptions.Key.create("overridden method descriptor"); + + private static final Logger LOG = LoggerFactory.getLogger(GrpcSerializationUtils.class); + + private static final int TAG_TYPE_BITS = 3; + + private static final String BUFFER_INPUT_STREAM_CLASS_NAME = + "io.grpc.internal.ReadableBuffers$BufferInputStream"; + private static final String BUFFER_FIELD_NAME = + "buffer"; + private static final String NETTY_WRITABLE_BUFFER_CLASS_NAME = + "io.grpc.netty.NettyWritableBuffer"; + private static final String BUFFER_CHAIN_OUTPUT_STREAM_CLASS_NAME = + "io.grpc.internal.MessageFramer$BufferChainOutputStream"; + private static final String BUFFER_LIST_FIELD_NAME = + "bufferList"; + private static final String CURRENT_FIELD_NAME = + "current"; + + private static Constructor sNettyWritableBufferConstructor; + private static Field sBufferList; + private static Field sCurrent; + private static Field sReadableBufferField = null; + private static boolean sZeroCopySendSupported = true; + private static boolean sZeroCopyReceiveSupported = true; + + static { + // The input stream buffer field is read by getBufferFromStream, which is guarded by the + // receive flag, so a failed lookup here must disable zero copy receive. + try { + sReadableBufferField = getPrivateField(BUFFER_INPUT_STREAM_CLASS_NAME, BUFFER_FIELD_NAME); + } catch (Exception e) { + LOG.warn("Cannot get gRPC input stream buffer, zero copy receive will be disabled.", e); + sZeroCopyReceiveSupported = false; + } + // The output stream members are used by addBuffersToStream, which is guarded by the send + // flag, so a failed lookup here must disable zero copy send. + try { + sNettyWritableBufferConstructor = + getPrivateConstructor(NETTY_WRITABLE_BUFFER_CLASS_NAME, ByteBuf.class); + sBufferList = getPrivateField(BUFFER_CHAIN_OUTPUT_STREAM_CLASS_NAME, BUFFER_LIST_FIELD_NAME); + sCurrent = getPrivateField(BUFFER_CHAIN_OUTPUT_STREAM_CLASS_NAME, CURRENT_FIELD_NAME); + } catch (Exception e) { + LOG.warn("Cannot get gRPC output stream buffer, zero copy send will be disabled.", e); + sZeroCopySendSupported = false; + } + } + + private static Field getPrivateField(String className, String fieldName) + throws
NoSuchFieldException, ClassNotFoundException { + Class declaringClass = Class.forName(className); + Field field = declaringClass.getDeclaredField(fieldName); + field.setAccessible(true); + return field; + } + + private static Constructor getPrivateConstructor(String className, Class ...parameterTypes) + throws ClassNotFoundException, NoSuchMethodException { + Class declaringClass = Class.forName(className); + Constructor constructor = declaringClass.getDeclaredConstructor(parameterTypes); + constructor.setAccessible(true); + return constructor; + } + + /** + * Makes a wire-format tag for a field by packing the field number and the wire type into one + * int. + * + * @param fieldNumber field number + * @param wireType wire type of the field + * @return the tag, computed as {@code (fieldNumber << TAG_TYPE_BITS) | wireType} + */ + public static int makeTag(final int fieldNumber, final int wireType) { + // This is a public version of WireFormat.makeTag. + return (fieldNumber << TAG_TYPE_BITS) | wireType; + } + + /** + * Gets a buffer directly from a gRPC input stream by reflectively reading its buffer field. + * This relies on the zero copy receive flag being false whenever that field was not resolved. + * + * @param stream the input stream + * @return the raw data buffer, or null if zero copy receive is not supported, the stream is not + * exactly the class that declares the buffer field, or the reflective read fails + */ + public static ReadableBuffer getBufferFromStream(InputStream stream) { + if (!sZeroCopyReceiveSupported + || !stream.getClass().equals(sReadableBufferField.getDeclaringClass())) { + return null; + } + try { + return (ReadableBuffer) sReadableBufferField.get(stream); + } catch (Exception e) { + LOG.warn("Failed to get data buffer from stream.", e); + return null; + } + } + + /** + * Add the given buffers directly to the gRPC output stream.
+ * + * @param buffers the buffers to be added + * @param stream the output stream + * @return whether the buffers are added successfully + */ + public static boolean addBuffersToStream(ByteBuf[] buffers, OutputStream stream) { + if (!sZeroCopySendSupported || !stream.getClass().equals(sBufferList.getDeclaringClass())) { + return false; + } + try { + if (sCurrent.get(stream) != null) { + return false; + } + for (ByteBuf buffer : buffers) { + Object nettyBuffer = sNettyWritableBufferConstructor.newInstance(buffer); + List list = (List) sBufferList.get(stream); + list.add(nettyBuffer); + buffer.retain(); + sCurrent.set(stream, nettyBuffer); + } + return true; + } catch (Exception e) { + LOG.warn("Failed to add data buffer to stream.", e); + return false; + } + } + + /** + * Creates a service definition that uses custom marshallers. + * + * @param service the service to intercept + * @param marshallers a map that specifies which marshaller to use for each method + * @return the new service definition + */ + public static ServerServiceDefinition overrideMethods( + final ServerServiceDefinition service, + final Map marshallers) { + List> newMethods = new ArrayList>(); + List> newDescriptors = new ArrayList>(); + // intercepts the descriptors + for (final ServerMethodDefinition definition : service.getMethods()) { + ServerMethodDefinition newMethod = interceptMethod(definition, marshallers); + newDescriptors.add(newMethod.getMethodDescriptor()); + newMethods.add(newMethod); + } + // builds the new service descriptor + final ServerServiceDefinition.Builder serviceBuilder = ServerServiceDefinition + .builder(new ServiceDescriptor(service.getServiceDescriptor().getName(), newDescriptors)); + // creates the new service definition + for (ServerMethodDefinition definition : newMethods) { + serviceBuilder.addMethod(definition); + } + return serviceBuilder.build(); + } + + private static ServerMethodDefinition interceptMethod( + final ServerMethodDefinition definition, + final 
Map newMethods) { + MethodDescriptor descriptor = definition.getMethodDescriptor(); + MethodDescriptor newMethod = newMethods.get(descriptor); + if (newMethod != null) { + return ServerMethodDefinition.create(newMethod, definition.getServerCallHandler()); + } + return definition; + } +} diff --git a/core/common/src/main/java/alluxio/grpc/GrpcServerBuilder.java b/core/common/src/main/java/alluxio/grpc/GrpcServerBuilder.java index bcc44079aad9..8646059cdc2b 100644 --- a/core/common/src/main/java/alluxio/grpc/GrpcServerBuilder.java +++ b/core/common/src/main/java/alluxio/grpc/GrpcServerBuilder.java @@ -162,6 +162,16 @@ public GrpcServerBuilder workerEventLoopGroup(EventLoopGroup workerGroup) { return this; } + /** + * Sets the maximum size of inbound messages + * @param messageSize maximum size of the message + * @return an updated instance of this {@link GrpcServerBuilder} + */ + public GrpcServerBuilder maxInboundMessageSize(int messageSize) { + mNettyServerBuilder = mNettyServerBuilder.maxInboundMessageSize(messageSize); + return this; + } + /** * Add a service to this server. * diff --git a/core/common/src/main/java/alluxio/grpc/ReadResponseMarshaller.java b/core/common/src/main/java/alluxio/grpc/ReadResponseMarshaller.java new file mode 100644 index 000000000000..9b5eaaef326b --- /dev/null +++ b/core/common/src/main/java/alluxio/grpc/ReadResponseMarshaller.java @@ -0,0 +1,105 @@ +/* + * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0 + * (the "License"). You may not use this work except in compliance with the License, which is + * available at www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied, as more fully set forth in the License. + * + * See the NOTICE file distributed with this work for information regarding copyright ownership. 
+ */ + +package alluxio.grpc; + +import alluxio.network.protocol.databuffer.DataBuffer; +import alluxio.network.protocol.databuffer.NettyDataBuffer; +import alluxio.util.proto.ProtoUtils; + +import com.google.common.base.Preconditions; +import com.google.protobuf.CodedOutputStream; +import com.google.protobuf.UnsafeByteOperations; +import com.google.protobuf.WireFormat; +import io.grpc.internal.ReadableBuffer; +import io.grpc.internal.ReadableBuffers; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import java.io.IOException; +import java.io.InputStream; + +import javax.annotation.concurrent.NotThreadSafe; + +/** + * Marshaller for {@link ReadResponse}. + */ +@NotThreadSafe +public class ReadResponseMarshaller extends DataMessageMarshaller { + /** + * Creates a {@link ReadResponseMarshaller}. + */ + public ReadResponseMarshaller() { + super(BlockWorkerGrpc.getReadBlockMethod().getResponseMarshaller()); + } + + @Override + protected ByteBuf[] serialize(ReadResponse message) throws IOException { + DataBuffer chunkBuffer = pollBuffer(message); + if (chunkBuffer == null) { + if (!message.hasChunk() || !message.getChunk().hasData()) { + // nothing to serialize + return new ByteBuf[0]; + } + // attempts to fallback to read chunk from message + chunkBuffer = new NettyDataBuffer( + Unpooled.wrappedBuffer(message.getChunk().getData().asReadOnlyByteBuffer())); + } + int size = message.getSerializedSize() - chunkBuffer.readableBytes(); + byte[] header = new byte[size]; + CodedOutputStream stream = CodedOutputStream.newInstance(header); + stream.writeTag(ReadResponse.CHUNK_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED); + stream.writeUInt32NoTag(message.getChunk().getSerializedSize()); + stream.writeTag(Chunk.DATA_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED); + stream.writeUInt32NoTag(chunkBuffer.readableBytes()); + return new ByteBuf[] { Unpooled.wrappedBuffer(header), (ByteBuf) chunkBuffer.getNettyOutput() }; + } + + @Override + protected 
ReadResponse deserialize(ReadableBuffer buffer) throws IOException { + if (buffer.readableBytes() == 0) { + return ReadResponse.getDefaultInstance(); + } + try (InputStream is = ReadableBuffers.openStream(buffer, false)) { + Preconditions.checkState(ProtoUtils.readRawVarint32(is) == GrpcSerializationUtils.makeTag( + ReadResponse.CHUNK_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED)); + int messageSize = ProtoUtils.readRawVarint32(is); + Preconditions.checkState(messageSize == buffer.readableBytes()); + Preconditions.checkState(ProtoUtils.readRawVarint32(is) == GrpcSerializationUtils.makeTag( + Chunk.DATA_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED)); + int chunkSize = ProtoUtils.readRawVarint32(is); + Preconditions.checkState(chunkSize == buffer.readableBytes()); + ReadResponse response = ReadResponse.newBuilder().build(); + offerBuffer(new ReadableDataBuffer(buffer), response); + return response; + } + } + + @Override + public ReadResponse combineData(DataMessage message) { + if (message == null) { + return null; + } + DataBuffer buffer = message.getBuffer(); + if (buffer == null) { + return message.getMessage(); + } + try { + byte[] bytes = new byte[buffer.readableBytes()]; + buffer.readBytes(bytes, 0, bytes.length); + return message.getMessage().toBuilder() + .setChunk(Chunk.newBuilder().setData(UnsafeByteOperations.unsafeWrap(bytes)).build()) + .build(); + } finally { + message.getBuffer().release(); + } + } +} diff --git a/core/common/src/main/java/alluxio/grpc/ReadableDataBuffer.java b/core/common/src/main/java/alluxio/grpc/ReadableDataBuffer.java new file mode 100644 index 000000000000..ae7f0a5cf775 --- /dev/null +++ b/core/common/src/main/java/alluxio/grpc/ReadableDataBuffer.java @@ -0,0 +1,56 @@ +/* + * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0 + * (the "License"). 
You may not use this work except in compliance with the License, which is + * available at www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied, as more fully set forth in the License. + * + * See the NOTICE file distributed with this work for information regarding copyright ownership. + */ + +package alluxio.grpc; + +import alluxio.network.protocol.databuffer.DataBuffer; + +import io.grpc.internal.ReadableBuffer; + +import java.nio.ByteBuffer; + +public class ReadableDataBuffer implements DataBuffer { + private final ReadableBuffer mBuffer; + + public ReadableDataBuffer(ReadableBuffer buffer) { + mBuffer = buffer; + } + + @Override + public Object getNettyOutput() { + throw new UnsupportedOperationException(); + } + + @Override + public long getLength() { + return mBuffer.readableBytes(); + } + + @Override + public ByteBuffer getReadOnlyByteBuffer() { + throw new UnsupportedOperationException(); + } + + @Override + public void readBytes(byte[] dst, int dstIndex, int length) { + mBuffer.readBytes(dst, dstIndex, length); + } + + @Override + public int readableBytes() { + return mBuffer.readableBytes(); + } + + @Override + public void release() { + mBuffer.close(); + } +} diff --git a/core/common/src/main/java/alluxio/util/proto/ProtoUtils.java b/core/common/src/main/java/alluxio/util/proto/ProtoUtils.java index edff983ce044..214da18983c5 100644 --- a/core/common/src/main/java/alluxio/util/proto/ProtoUtils.java +++ b/core/common/src/main/java/alluxio/util/proto/ProtoUtils.java @@ -52,6 +52,17 @@ public static int readRawVarint32(int firstByte, InputStream input) throws IOExc return CodedInputStream.readRawVarint32(firstByte, input); } + /** + * A wrapper of {@link CodedInputStream#readRawVarint32(InputStream)}. 
+ * + * @param input input stream + * @return an int value read from the input stream + */ + public static int readRawVarint32(InputStream input) throws IOException { + int firstByte = input.read(); + return CodedInputStream.readRawVarint32(firstByte, input); + } + /** * A wrapper of * {@link alluxio.proto.journal.Job.TaskInfo.Builder#setResult} to take byte[] as input. diff --git a/core/common/src/test/java/alluxio/grpc/ReadResponseMarshallerTest.java b/core/common/src/test/java/alluxio/grpc/ReadResponseMarshallerTest.java new file mode 100644 index 000000000000..994946b6990c --- /dev/null +++ b/core/common/src/test/java/alluxio/grpc/ReadResponseMarshallerTest.java @@ -0,0 +1,91 @@ +/* + * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0 + * (the "License"). You may not use this work except in compliance with the License, which is + * available at www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied, as more fully set forth in the License. + * + * See the NOTICE file distributed with this work for information regarding copyright ownership. + */ + +package alluxio.grpc; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import alluxio.network.protocol.databuffer.DataBuffer; +import alluxio.network.protocol.databuffer.NettyDataBuffer; + +import com.google.protobuf.ByteString; +import io.grpc.Drainable; +import io.netty.buffer.Unpooled; +import org.apache.commons.io.output.ByteArrayOutputStream; +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; + +/** + * Unit tests for {@link ReadResponseMarshaller}. 
+ */ +public final class ReadResponseMarshallerTest { + + @Test + public void streamEmptyMessage() throws Exception { + validateStream(ReadResponse.getDefaultInstance()); + } + + @Test + public void streamMessage() throws Exception { + validateStream(buildResponse("test".getBytes())); + } + + @Test + public void parseEmptyMessage() throws Exception { + validateParse(ReadResponse.getDefaultInstance()); + } + + @Test + public void parseMessage() throws Exception { + validateParse(buildResponse("test".getBytes())); + } + + private void validateStream(ReadResponse message) throws IOException { + ReadResponseMarshaller marshaller = new ReadResponseMarshaller(); + byte[] expected = message.toByteArray(); + if (message.hasChunk() && message.getChunk().hasData()) { + marshaller.offerBuffer(new NettyDataBuffer( + Unpooled.wrappedBuffer(message.getChunk().getData().asReadOnlyByteBuffer())), message); + } + InputStream stream = marshaller.stream(message); + assertTrue(stream instanceof Drainable); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + ((Drainable) stream).drainTo(outputStream); + assertArrayEquals(expected, outputStream.toByteArray()); + } + + private void validateParse(ReadResponse message) { + ReadResponseMarshaller marshaller = new ReadResponseMarshaller(); + byte[] data = message.toByteArray(); + ReadResponse parsedMessage = marshaller.parse(new ByteArrayInputStream(data)); + if (data.length > 0) { + DataBuffer buffer = marshaller.pollBuffer(parsedMessage); + assertNotNull(buffer); + byte[] bytes = new byte[buffer.readableBytes()]; + buffer.readBytes(bytes, 0, bytes.length); + parsedMessage = parsedMessage.toBuilder().setChunk(Chunk.newBuilder().setData( + ByteString.copyFrom(bytes) + ).build()).build(); + } + assertEquals(message, parsedMessage); + } + + private ReadResponse buildResponse(byte[] data) { + return ReadResponse.newBuilder().setChunk(Chunk.newBuilder().setData( + ByteString.copyFrom(data))).build(); + } +} diff --git 
a/core/server/worker/src/main/java/alluxio/worker/grpc/AbstractReadHandler.java b/core/server/worker/src/main/java/alluxio/worker/grpc/AbstractReadHandler.java index b07045bf61fc..c13b23677e6a 100644 --- a/core/server/worker/src/main/java/alluxio/worker/grpc/AbstractReadHandler.java +++ b/core/server/worker/src/main/java/alluxio/worker/grpc/AbstractReadHandler.java @@ -11,6 +11,7 @@ package alluxio.worker.grpc; +import alluxio.grpc.DataMessage; import alluxio.conf.PropertyKey; import alluxio.conf.ServerConfiguration; import alluxio.exception.status.AlluxioStatusException; @@ -27,11 +28,13 @@ import com.google.protobuf.UnsafeByteOperations; import io.grpc.Status; import io.grpc.StatusRuntimeException; -import io.grpc.stub.ServerCallStreamObserver; +import io.grpc.internal.SerializingExecutor; +import io.grpc.stub.CallStreamObserver; import io.grpc.stub.StreamObserver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.locks.ReentrantLock; @@ -69,6 +72,8 @@ abstract class AbstractReadHandler> /** The executor to run {@link DataReader}. */ private final ExecutorService mDataReaderExecutor; + /** A serializing executor for sending responses. 
*/ + private Executor mSerializingExecutor; private final ReentrantLock mLock = new ReentrantLock(); @@ -89,6 +94,7 @@ abstract class AbstractReadHandler> AbstractReadHandler(ExecutorService executorService, StreamObserver responseObserver) { mDataReaderExecutor = executorService; + mSerializingExecutor = new SerializingExecutor(executorService); mResponseObserver = responseObserver; } @@ -105,7 +111,8 @@ public void onNext(alluxio.grpc.ReadRequest request) { mDataReaderExecutor.submit(createDataReader(mContext, mResponseObserver)); mContext.setDataReaderActive(true); } catch (Exception e) { - mResponseObserver.onError(GrpcExceptionUtils.fromThrowable(e)); + mSerializingExecutor.execute(() -> + mResponseObserver.onError(GrpcExceptionUtils.fromThrowable(e))); } } @@ -235,7 +242,7 @@ private void incrementMetrics(long bytesRead) { * A runnable that reads data and writes them to the channel. */ protected abstract class DataReader implements Runnable { - private final ServerCallStreamObserver mResponse; + private final CallStreamObserver mResponse; private final T mContext; private final ReadRequest mRequest; private final long mChunkSize; @@ -250,7 +257,7 @@ protected abstract class DataReader implements Runnable { mContext = context; mRequest = context.getRequest(); mChunkSize = Math.min(mRequest.getChunkSize(), MAX_CHUNK_SIZE); - mResponse = (ServerCallStreamObserver) response; + mResponse = (CallStreamObserver) response; } @Override @@ -302,20 +309,33 @@ private void runInternal() { } if (chunk != null) { - ReadResponse response = ReadResponse.newBuilder().setChunk(Chunk.newBuilder() - .setData(UnsafeByteOperations.unsafeWrap(chunk.getReadOnlyByteBuffer())).build()) - .build(); - mResponse.onNext(response); - incrementMetrics(chunk.getLength()); + DataBuffer finalChunk = chunk; + mSerializingExecutor.execute(() -> { + try { + ReadResponse response = ReadResponse.newBuilder().setChunk(Chunk.newBuilder() + 
.setData(UnsafeByteOperations.unsafeWrap(finalChunk.getReadOnlyByteBuffer())) + ).build(); + if (mResponse instanceof DataMessageServerStreamObserver) { + ((DataMessageServerStreamObserver) mResponse) + .onNext(new DataMessage<>(response, finalChunk)); + } else { + mResponse.onNext(response); + } + incrementMetrics(finalChunk.getLength()); + } catch (Exception e) { + LOG.error("Failed to read data.", e); + setError(new Error(AlluxioStatusException.fromThrowable(e), true)); + } finally { + if (finalChunk != null) { + finalChunk.release(); + } + } + }); } } catch (Exception e) { LOG.error("Failed to read data.", e); setError(new Error(AlluxioStatusException.fromThrowable(e), true)); continue; - } finally { - if (chunk != null) { - chunk.release(); - } } } @@ -367,44 +387,50 @@ protected abstract DataBuffer getDataBuffer(T context, StreamObserver { + try { + mResponse.onError(GrpcExceptionUtils.toGrpcStatusException(error.getCause())); + } catch (StatusRuntimeException e) { + // Ignores the error when client already closed the stream. + if (e.getStatus().getCode() != Status.Code.CANCELLED) { + throw e; + } } - } + }); } /** * Writes a success response. */ private void replyEof() { - try { - Preconditions.checkState(!mContext.isDoneUnsafe()); - mContext.setDoneUnsafe(true); - mResponse.onCompleted(); - } catch (StatusRuntimeException e) { - if (e.getStatus().getCode() != Status.Code.CANCELLED) { - throw e; + mSerializingExecutor.execute(() -> { + try { + Preconditions.checkState(!mContext.isDoneUnsafe()); + mContext.setDoneUnsafe(true); + mResponse.onCompleted(); + } catch (StatusRuntimeException e) { + if (e.getStatus().getCode() != Status.Code.CANCELLED) { + throw e; + } } - } + }); } /** * Writes a cancel response. 
*/ private void replyCancel() { - try { - Preconditions.checkState(!mContext.isDoneUnsafe()); - mContext.setDoneUnsafe(true); - mResponse.onCompleted(); - } catch (StatusRuntimeException e) { - if (e.getStatus().getCode() != Status.Code.CANCELLED) { - throw e; + mSerializingExecutor.execute(() -> { + try { + Preconditions.checkState(!mContext.isDoneUnsafe()); + mContext.setDoneUnsafe(true); + mResponse.onCompleted(); + } catch (StatusRuntimeException e) { + if (e.getStatus().getCode() != Status.Code.CANCELLED) { + throw e; + } } - } + }); } } } diff --git a/core/server/worker/src/main/java/alluxio/worker/grpc/BlockWorkerImpl.java b/core/server/worker/src/main/java/alluxio/worker/grpc/BlockWorkerImpl.java index 9a2385585a7c..1a06a6c5714a 100644 --- a/core/server/worker/src/main/java/alluxio/worker/grpc/BlockWorkerImpl.java +++ b/core/server/worker/src/main/java/alluxio/worker/grpc/BlockWorkerImpl.java @@ -12,6 +12,8 @@ package alluxio.worker.grpc; import alluxio.RpcUtils; +import alluxio.conf.PropertyKey; +import alluxio.conf.ServerConfiguration; import alluxio.grpc.AsyncCacheRequest; import alluxio.grpc.AsyncCacheResponse; import alluxio.grpc.BlockWorkerGrpc; @@ -21,6 +23,7 @@ import alluxio.grpc.OpenLocalBlockResponse; import alluxio.grpc.ReadRequest; import alluxio.grpc.ReadResponse; +import alluxio.grpc.ReadResponseMarshaller; import alluxio.grpc.RemoveBlockRequest; import alluxio.grpc.RemoveBlockResponse; import alluxio.grpc.WriteResponse; @@ -29,12 +32,18 @@ import alluxio.worker.block.AsyncCacheRequestManager; import alluxio.worker.block.BlockWorker; +import com.google.common.collect.ImmutableMap; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import io.grpc.MethodDescriptor; +import io.grpc.stub.CallStreamObserver; import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.Collections; +import java.util.Map; + /** * Server side implementation of the 
gRPC BlockWorker interface. */ @@ -42,8 +51,11 @@ public class BlockWorkerImpl extends BlockWorkerGrpc.BlockWorkerImplBase { private static final Logger LOG = LoggerFactory.getLogger(BlockWorkerImpl.class); + private static final boolean ZERO_COPY_ENABLED = + ServerConfiguration.getBoolean(PropertyKey.WORKER_NETWORK_ZEROCOPY_ENABLED); private WorkerProcess mWorkerProcess; private final AsyncCacheRequestManager mRequestManager; + private ReadResponseMarshaller mReadResponseMarshaller = new ReadResponseMarshaller(); /** * Creates a new implementation of gRPC BlockWorker interface. @@ -56,13 +68,29 @@ public BlockWorkerImpl(WorkerProcess workerProcess) { GrpcExecutors.ASYNC_CACHE_MANAGER_EXECUTOR, mWorkerProcess.getWorker(BlockWorker.class)); } + /** + * @return a map of gRPC methods with overridden descriptors + */ + public Map getOverriddenMethodDescriptors() { + if (ZERO_COPY_ENABLED) { + return ImmutableMap.of(BlockWorkerGrpc.getReadBlockMethod(), + BlockWorkerGrpc.getReadBlockMethod().toBuilder() + .setResponseMarshaller(mReadResponseMarshaller).build()); + } + return Collections.emptyMap(); + } + @Override public StreamObserver readBlock(StreamObserver responseObserver) { + CallStreamObserver callStreamObserver = + (CallStreamObserver) responseObserver; + if (ZERO_COPY_ENABLED) { + callStreamObserver = + new DataMessageServerStreamObserver<>(callStreamObserver, mReadResponseMarshaller); + } BlockReadHandler readHandler = new BlockReadHandler(GrpcExecutors.BLOCK_READER_EXECUTOR, - mWorkerProcess.getWorker(BlockWorker.class), responseObserver); - ServerCallStreamObserver serverCallStreamObserver = - (ServerCallStreamObserver) responseObserver; - serverCallStreamObserver.setOnReadyHandler(readHandler::onReady); + mWorkerProcess.getWorker(BlockWorker.class), callStreamObserver); + callStreamObserver.setOnReadyHandler(readHandler::onReady); return readHandler; } diff --git a/core/server/worker/src/main/java/alluxio/worker/grpc/DataMessageServerStreamObserver.java 
b/core/server/worker/src/main/java/alluxio/worker/grpc/DataMessageServerStreamObserver.java new file mode 100644 index 000000000000..cc84e9a72563 --- /dev/null +++ b/core/server/worker/src/main/java/alluxio/worker/grpc/DataMessageServerStreamObserver.java @@ -0,0 +1,96 @@ +/* + * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0 + * (the "License"). You may not use this work except in compliance with the License, which is + * available at www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied, as more fully set forth in the License. + * + * See the NOTICE file distributed with this work for information regarding copyright ownership. + */ + +package alluxio.worker.grpc; + +import alluxio.grpc.BufferRepository; +import alluxio.grpc.DataMessage; +import alluxio.network.protocol.databuffer.DataBuffer; + +import io.grpc.stub.CallStreamObserver; +import io.grpc.stub.StreamObserver; + +import javax.annotation.concurrent.NotThreadSafe; + +/** + * A {@link StreamObserver} for handling raw data buffers. + * + * @param type of the message + */ +@NotThreadSafe +public class DataMessageServerStreamObserver extends CallStreamObserver { + + private final BufferRepository mBufferRepository; + private final CallStreamObserver mObserver; + + /** + * @param observer the original observer + * @param bufferRepository the repository of the buffers + */ + public DataMessageServerStreamObserver(CallStreamObserver observer, + BufferRepository bufferRepository) { + mObserver = observer; + mBufferRepository = bufferRepository; + } + + /** + * Receives a message with data buffer from the stream. 
+ * + * @param value the value passed to the stream + */ + public void onNext(DataMessage value) { + DataBuffer buffer = value.getBuffer(); + if (buffer != null) { + mBufferRepository.offerBuffer(buffer, value.getMessage()); + } + mObserver.onNext(value.getMessage()); + } + + @Override + public void onNext(T value) { + mObserver.onNext(value); + } + + @Override + public void onError(Throwable t) { + mObserver.onError(t); + } + + @Override + public void onCompleted() { + mObserver.onCompleted(); + } + + @Override + public boolean isReady() { + return mObserver.isReady(); + } + + @Override + public void setOnReadyHandler(Runnable onReadyHandler) { + mObserver.setOnReadyHandler(onReadyHandler); + } + + @Override + public void disableAutoInboundFlowControl() { + mObserver.disableAutoInboundFlowControl(); + } + + @Override + public void request(int count) { + mObserver.request(count); + } + + @Override + public void setMessageCompression(boolean enable) { + mObserver.setMessageCompression(enable); + } +} diff --git a/core/server/worker/src/main/java/alluxio/worker/grpc/GrpcDataServer.java b/core/server/worker/src/main/java/alluxio/worker/grpc/GrpcDataServer.java index dc4fdd5df114..6f9e57fff9e0 100644 --- a/core/server/worker/src/main/java/alluxio/worker/grpc/GrpcDataServer.java +++ b/core/server/worker/src/main/java/alluxio/worker/grpc/GrpcDataServer.java @@ -16,6 +16,7 @@ import alluxio.grpc.GrpcServer; import alluxio.grpc.GrpcServerBuilder; import alluxio.grpc.GrpcService; +import alluxio.grpc.GrpcSerializationUtils; import alluxio.network.ChannelType; import alluxio.util.network.NettyUtils; import alluxio.worker.DataServer; @@ -52,9 +53,10 @@ public final class GrpcDataServer implements DataServer { ServerConfiguration.getMs(PropertyKey.WORKER_NETWORK_KEEPALIVE_TIME_MS); private final long mKeepAliveTimeoutMs = ServerConfiguration.getMs(PropertyKey.WORKER_NETWORK_KEEPALIVE_TIMEOUT_MS); - private final long mFlowControlWindow = 
ServerConfiguration.getBytes(PropertyKey.WORKER_NETWORK_FLOWCONTROL_WINDOW); + private final long mMaxInboundMessageSize = + ServerConfiguration.getBytes(PropertyKey.WORKER_NETWORK_MAX_INBOUND_MESSAGE_SIZE); private final long mQuietPeriodMs = ServerConfiguration.getMs(PropertyKey.WORKER_NETWORK_NETTY_SHUTDOWN_QUIET_PERIOD); @@ -73,12 +75,17 @@ public final class GrpcDataServer implements DataServer { public GrpcDataServer(final SocketAddress address, final WorkerProcess workerProcess) { mSocketAddress = address; try { + BlockWorkerImpl blockWorkerService = new BlockWorkerImpl(workerProcess); mServer = createServerBuilder(address, NettyUtils.getWorkerChannel( ServerConfiguration.global())) - .addService(new GrpcService(new BlockWorkerImpl(workerProcess))) + .addService(new GrpcService( + GrpcSerializationUtils.overrideMethods(blockWorkerService.bindService(), + blockWorkerService.getOverriddenMethodDescriptors()) + )) .flowControlWindow((int) mFlowControlWindow) .keepAliveTime(mKeepAliveTimeMs, TimeUnit.MILLISECONDS) .keepAliveTimeout(mKeepAliveTimeoutMs, TimeUnit.MILLISECONDS) + .maxInboundMessageSize((int) mMaxInboundMessageSize) .build() .start(); // There is no way to query domain socket address afterwards.