From 25c22b90f42f603212b79be6f9f50518667b0458 Mon Sep 17 00:00:00 2001 From: bf8086 Date: Tue, 5 Mar 2019 17:54:02 -0800 Subject: [PATCH 1/4] Implement write path zero copy --- .../DataMessageClientResponseObserver.java | 17 +- .../stream/DefaultBlockWorkerClient.java | 20 ++- .../stream/GrpcDataMessageBlockingStream.java | 42 ++++- .../client/block/stream/GrpcDataReader.java | 2 +- .../client/block/stream/GrpcDataWriter.java | 35 +++- .../main/java/alluxio/conf/PropertyKey.java | 19 ++ .../grpc/DataMessageMarshallerProvider.java | 34 +++- .../alluxio/grpc/GrpcSerializationUtils.java | 58 ++++++- .../java/alluxio/grpc/ReadableDataBuffer.java | 12 ++ .../alluxio/grpc/WriteRequestMarshaller.java | 120 +++++++++++++ .../databuffer/ByteArrayDataBuffer.java | 11 ++ .../protocol/databuffer/DataBuffer.java | 17 ++ .../protocol/databuffer/NettyDataBuffer.java | 12 ++ .../protocol/databuffer/NioDataBuffer.java | 12 ++ .../alluxio/worker/block/io/BlockWriter.java | 10 ++ .../worker/block/io/LocalFileBlockWriter.java | 31 ++++ .../worker/block/io/MockBlockWriter.java | 11 ++ .../worker/grpc/AbstractWriteHandler.java | 162 ++++++++++++------ .../alluxio/worker/grpc/BlockWorkerImpl.java | 18 +- .../worker/grpc/BlockWriteHandler.java | 9 +- .../DataMessageServerRequestObserver.java | 63 +++++++ .../worker/grpc/DelegationWriteHandler.java | 15 +- .../alluxio/worker/grpc/GrpcExecutors.java | 6 + .../grpc/UfsFallbackBlockWriteHandler.java | 8 +- .../worker/grpc/UfsFileWriteHandler.java | 6 +- .../worker/grpc/AbstractWriteHandlerTest.java | 43 +++++ .../worker/grpc/BlockWriteHandlerTest.java | 2 + .../UfsFallbackBlockWriteHandlerTest.java | 3 + .../worker/grpc/UfsFileWriteHandlerTest.java | 2 + 29 files changed, 696 insertions(+), 104 deletions(-) create mode 100644 core/common/src/main/java/alluxio/grpc/WriteRequestMarshaller.java create mode 100644 core/server/worker/src/main/java/alluxio/worker/grpc/DataMessageServerRequestObserver.java diff --git a/core/client/fs/src/main/java/alluxio/client/block/stream/DataMessageClientResponseObserver.java b/core/client/fs/src/main/java/alluxio/client/block/stream/DataMessageClientResponseObserver.java index 6c01231bf3b0..16326e83967e 100644 --- a/core/client/fs/src/main/java/alluxio/client/block/stream/DataMessageClientResponseObserver.java +++ b/core/client/fs/src/main/java/alluxio/client/block/stream/DataMessageClientResponseObserver.java @@ -30,21 +30,23 @@ */ @NotThreadSafe public class DataMessageClientResponseObserver - implements ClientResponseObserver, DataMessageMarshallerProvider { + extends DataMessageMarshallerProvider + implements ClientResponseObserver { private static final Logger LOG = LoggerFactory.getLogger(DataMessageClientResponseObserver.class); private final StreamObserver mObserver; - private final DataMessageMarshaller mMarshaller; /** * @param observer the original response observer - * @param marshaller the marshaller for the response + * @param requestMarshaller the marshaller for the request + * @param responseMarshaller the marshaller for the response */ public DataMessageClientResponseObserver(StreamObserver observer, - DataMessageMarshaller marshaller) { + DataMessageMarshaller requestMarshaller, + DataMessageMarshaller responseMarshaller) { + super(requestMarshaller, responseMarshaller); mObserver = observer; - mMarshaller = marshaller; } @Override @@ -70,9 +72,4 @@ public void beforeStart(ClientCallStreamObserver requestStream) { LOG.warn("{} does not implement ClientResponseObserver:beforeStart", mObserver); } } - - @Override - public DataMessageMarshaller getMarshaller() { - return mMarshaller; - } } diff --git a/core/client/fs/src/main/java/alluxio/client/block/stream/DefaultBlockWorkerClient.java b/core/client/fs/src/main/java/alluxio/client/block/stream/DefaultBlockWorkerClient.java index 4de86fe97e76..668bc5746a5b 100644 --- a/core/client/fs/src/main/java/alluxio/client/block/stream/DefaultBlockWorkerClient.java +++ b/core/client/fs/src/main/java/alluxio/client/block/stream/DefaultBlockWorkerClient.java @@ -35,6 +35,7 @@ import alluxio.grpc.GrpcSerializationUtils; import alluxio.util.network.NettyUtils; +import com.google.common.base.Preconditions; import com.google.common.io.Closer; import io.grpc.StatusRuntimeException; import io.grpc.stub.StreamObserver; @@ -114,14 +115,29 @@ public void close() throws IOException { @Override public StreamObserver writeBlock(StreamObserver responseObserver) { - return mStreamingAsyncStub.writeBlock(responseObserver); + if (responseObserver instanceof DataMessageMarshallerProvider) { + DataMessageMarshaller marshaller = + ((DataMessageMarshallerProvider) responseObserver) + .getRequestMarshaller(); + Preconditions.checkNotNull(marshaller); + return mStreamingAsyncStub + .withOption(GrpcSerializationUtils.OVERRIDDEN_METHOD_DESCRIPTOR, + BlockWorkerGrpc.getWriteBlockMethod().toBuilder() + .setRequestMarshaller(marshaller) + .build()) + .writeBlock(responseObserver); + } else { + return mStreamingAsyncStub.writeBlock(responseObserver); + } } @Override public StreamObserver readBlock(StreamObserver responseObserver) { if (responseObserver instanceof DataMessageMarshallerProvider) { DataMessageMarshaller marshaller = - ((DataMessageMarshallerProvider) responseObserver).getMarshaller(); + ((DataMessageMarshallerProvider) responseObserver) + .getResponseMarshaller(); + Preconditions.checkNotNull(marshaller); return mStreamingAsyncStub .withOption(GrpcSerializationUtils.OVERRIDDEN_METHOD_DESCRIPTOR, BlockWorkerGrpc.getReadBlockMethod().toBuilder() diff --git a/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataMessageBlockingStream.java b/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataMessageBlockingStream.java index 586b64b8f5ed..ef63dc7b6d5b 100644 --- a/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataMessageBlockingStream.java +++ b/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataMessageBlockingStream.java @@ -31,32 +31,40 @@ */ @NotThreadSafe public class GrpcDataMessageBlockingStream extends GrpcBlockingStream { - private final DataMessageMarshaller mMarshaller; + private final DataMessageMarshaller mRequestMarshaller; + private final DataMessageMarshaller mResponseMarshaller; /** * @param rpcFunc the gRPC bi-directional stream stub function * @param bufferSize maximum number of incoming messages the buffer can hold * @param description description of this stream - * @param deserializer custom deserializer for the response + * @param requestMarshaller the marshaller for the request + * @param responseMarshaller the marshaller for the response */ public GrpcDataMessageBlockingStream(Function, StreamObserver> rpcFunc, - int bufferSize, String description, DataMessageMarshaller deserializer) { + int bufferSize, String description, DataMessageMarshaller requestMarshaller, + DataMessageMarshaller responseMarshaller) { super((resObserver) -> { DataMessageClientResponseObserver newObserver = - new DataMessageClientResponseObserver<>(resObserver, deserializer); + new DataMessageClientResponseObserver<>(resObserver, requestMarshaller, + responseMarshaller); StreamObserver requestObserver = rpcFunc.apply(newObserver); return requestObserver; }, bufferSize, description); - mMarshaller = deserializer; + mRequestMarshaller = requestMarshaller; + mResponseMarshaller = responseMarshaller; } @Override public ResT receive(long timeoutMs) throws IOException { + if (mResponseMarshaller == null) { + return super.receive(timeoutMs); + } DataMessage message = receiveDataMessage(timeoutMs); if (message == null) { return null; } - return mMarshaller.combineData(message); + return mResponseMarshaller.combineData(message); } /** @@ -74,12 +82,32 @@ public DataMessage receiveDataMessage(long timeoutMs) throws I if (response == null) { return null; } - DataBuffer buffer = mMarshaller.pollBuffer(response); + DataBuffer buffer = mResponseMarshaller.pollBuffer(response); return new DataMessage<>(response, buffer); } + /** + * Sends a request. Will wait until the stream is ready before sending or timeout if the + * given timeout is reached. + * + * @param message the request message with {@link DataBuffer attached} + * @param timeoutMs maximum wait time before throwing a {@link DeadlineExceededException} + * @throws IOException if any error occurs + */ + public void sendDataMessage(DataMessage message, long timeoutMs) + throws IOException { + if (mRequestMarshaller != null) { + mRequestMarshaller.offerBuffer(message.getBuffer(), message.getMessage()); + } + super.send(message.getMessage(), timeoutMs); + } + @Override public void waitForComplete(long timeoutMs) throws IOException { + if (mResponseMarshaller == null) { + super.waitForComplete(timeoutMs); + return; + } DataMessage message; while (!isCanceled() && (message = receiveDataMessage(timeoutMs)) != null) { if (message.getBuffer() != null) { diff --git a/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataReader.java b/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataReader.java index 954241a424d2..d67aeed56f88 100644 --- a/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataReader.java +++ b/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataReader.java @@ -88,7 +88,7 @@ private GrpcDataReader(FileSystemContext context, WorkerNetAddress address, .add("request", mReadRequest) .add("address", address) .toString(), - mMarshaller); + null, mMarshaller); } else { mStream = new GrpcBlockingStream<>(mClient::readBlock, mReaderBufferSizeMessages, MoreObjects.toStringHelper(this) diff --git a/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataWriter.java b/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataWriter.java index af9792cd7664..474f69b519fc 100644 --- a/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataWriter.java +++ b/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataWriter.java @@ -17,10 +17,13 @@ import alluxio.conf.PropertyKey; import alluxio.exception.status.UnavailableException; import alluxio.grpc.Chunk; +import alluxio.grpc.DataMessage; import alluxio.grpc.RequestType; import alluxio.grpc.WriteRequest; import alluxio.grpc.WriteRequestCommand; +import alluxio.grpc.WriteRequestMarshaller; import alluxio.grpc.WriteResponse; +import alluxio.network.protocol.databuffer.NettyDataBuffer; import alluxio.proto.dataserver.Protocol; import alluxio.util.proto.ProtoUtils; import alluxio.wire.WorkerNetAddress; @@ -69,6 +72,7 @@ public final class GrpcDataWriter implements DataWriter { private final WriteRequestCommand mPartialRequest; private final long mChunkSize; private final GrpcBlockingStream mStream; + private final WriteRequestMarshaller mMarshaller; /** * The next pos to queue to the buffer. @@ -152,11 +156,21 @@ private GrpcDataWriter(FileSystemContext context, final WorkerNetAddress address mPartialRequest = builder.buildPartial(); mChunkSize = chunkSize; mClient = client; - mStream = new GrpcBlockingStream<>(mClient::writeBlock, mWriterBufferSizeMessages, - MoreObjects.toStringHelper(this) - .add("request", mPartialRequest) - .add("address", address) - .toString()); + mMarshaller = new WriteRequestMarshaller(); + if (conf.getBoolean(PropertyKey.USER_NETWORK_ZEROCOPY_ENABLED)) { + mStream = new GrpcDataMessageBlockingStream<>( + mClient::writeBlock, mWriterBufferSizeMessages, + MoreObjects.toStringHelper(this) + .add("request", mPartialRequest) + .add("address", address) + .toString(), mMarshaller, null); + } else { + mStream = new GrpcBlockingStream<>(mClient::writeBlock, mWriterBufferSizeMessages, + MoreObjects.toStringHelper(this) + .add("request", mPartialRequest) + .add("address", address) + .toString()); + } mStream.send(WriteRequest.newBuilder().setCommand(mPartialRequest.toBuilder()).build(), mDataTimeoutMs); } @@ -170,11 +184,16 @@ public long pos() { public void writeChunk(final ByteBuf buf) throws IOException { mPosToQueue += buf.readableBytes(); try { - mStream.send(WriteRequest.newBuilder().setCommand(mPartialRequest).setChunk( + WriteRequest request = WriteRequest.newBuilder().setCommand(mPartialRequest).setChunk( Chunk.newBuilder() .setData(UnsafeByteOperations.unsafeWrap(buf.nioBuffer())) - .build()).build(), - mDataTimeoutMs); + .build()).build(); + if (mStream instanceof GrpcDataMessageBlockingStream) { + ((GrpcDataMessageBlockingStream) mStream) + .sendDataMessage(new DataMessage<>(request, new NettyDataBuffer(buf)), mDataTimeoutMs); + } else { + mStream.send(request, mDataTimeoutMs); + } } finally { buf.release(); } diff --git a/core/common/src/main/java/alluxio/conf/PropertyKey.java b/core/common/src/main/java/alluxio/conf/PropertyKey.java index 5424a7576a4e..6929958d7059 100644 --- a/core/common/src/main/java/alluxio/conf/PropertyKey.java +++ b/core/common/src/main/java/alluxio/conf/PropertyKey.java @@ -1989,6 +1989,21 @@ public String toString() { .setConsistencyCheckLevel(ConsistencyCheckLevel.WARN) .setScope(Scope.WORKER) .build(); + public static final PropertyKey WORKER_NETWORK_BLOCK_WRITER_THREADS_MAX = + new Builder(Name.WORKER_NETWORK_BLOCK_WRITER_THREADS_MAX) + .setDefaultValue(1024) + .setDescription("The maximum number of threads used to write blocks in the data server.") + .setConsistencyCheckLevel(ConsistencyCheckLevel.WARN) + .setScope(Scope.WORKER) + .build(); + public static final PropertyKey WORKER_NETWORK_WRITER_BUFFER_SIZE_MESSAGES = + new Builder(Name.WORKER_NETWORK_WRITER_BUFFER_SIZE_MESSAGES) + .setDefaultValue(8) + .setDescription("When a client writes to a remote worker, the maximum number of " + + "data messages to buffer by the server.") + .setConsistencyCheckLevel(ConsistencyCheckLevel.WARN) + .setScope(Scope.WORKER) + .build(); public static final PropertyKey WORKER_NETWORK_FLOWCONTROL_WINDOW = new Builder(Name.WORKER_NETWORK_FLOWCONTROL_WINDOW) .setDefaultValue("2MB") @@ -3821,6 +3836,10 @@ public static final class Name { "alluxio.worker.network.async.cache.manager.threads.max"; public static final String WORKER_NETWORK_BLOCK_READER_THREADS_MAX = "alluxio.worker.network.block.reader.threads.max"; + public static final String WORKER_NETWORK_BLOCK_WRITER_THREADS_MAX = + "alluxio.worker.network.block.writer.threads.max"; + public static final String WORKER_NETWORK_WRITER_BUFFER_SIZE_MESSAGES = + "alluxio.worker.network.writer.buffer.size.messages"; public static final String WORKER_NETWORK_FLOWCONTROL_WINDOW = "alluxio.worker.network.flowcontrol.window"; public static final String WORKER_NETWORK_KEEPALIVE_TIME_MS = diff --git a/core/common/src/main/java/alluxio/grpc/DataMessageMarshallerProvider.java b/core/common/src/main/java/alluxio/grpc/DataMessageMarshallerProvider.java index 6397eac9c68d..6d19de68c3d8 100644 --- a/core/common/src/main/java/alluxio/grpc/DataMessageMarshallerProvider.java +++ b/core/common/src/main/java/alluxio/grpc/DataMessageMarshallerProvider.java @@ -12,10 +12,36 @@ package alluxio.grpc; /** - * A provider of {@link DataMessageMarshaller}. + * A provider of {@link DataMessageMarshaller} for a gRPC call. * - * @param type of the message + * @param type of the request message + * @param type of the response message */ -public interface DataMessageMarshallerProvider { - DataMessageMarshaller getMarshaller(); +public class DataMessageMarshallerProvider { + private final DataMessageMarshaller mRequestMarshaller; + private final DataMessageMarshaller mResponseMarshaller; + + /** + * @param requestMarshaller the marshaller for the request, or null if not provided + * @param responseMarshaller the marshaller for the response, or null if not provided + */ + public DataMessageMarshallerProvider(DataMessageMarshaller requestMarshaller, + DataMessageMarshaller responseMarshaller) { + mRequestMarshaller = requestMarshaller; + mResponseMarshaller = responseMarshaller; + } + + /** + * @return the request marshaller + */ + public DataMessageMarshaller getRequestMarshaller() { + return mRequestMarshaller; + } + + /** + * @return the response marshaller + */ + public DataMessageMarshaller getResponseMarshaller() { + return mResponseMarshaller; + } } diff --git a/core/common/src/main/java/alluxio/grpc/GrpcSerializationUtils.java b/core/common/src/main/java/alluxio/grpc/GrpcSerializationUtils.java index afbc66625d1d..091e8f8a1cbd 100644 --- a/core/common/src/main/java/alluxio/grpc/GrpcSerializationUtils.java +++ b/core/common/src/main/java/alluxio/grpc/GrpcSerializationUtils.java @@ -16,8 +16,11 @@ import io.grpc.ServerMethodDefinition; import io.grpc.ServerServiceDefinition; import io.grpc.ServiceDescriptor; +import io.grpc.internal.CompositeReadableBuffer; import io.grpc.internal.ReadableBuffer; import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.PooledByteBufAllocator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -28,6 +31,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Queue; /** * Utilities for gRPC message serialization. @@ -42,21 +46,23 @@ public class GrpcSerializationUtils { private static final String BUFFER_INPUT_STREAM_CLASS_NAME = "io.grpc.internal.ReadableBuffers$BufferInputStream"; - private static final String BUFFER_FIELD_NAME = - "buffer"; + private static final String BUFFER_FIELD_NAME = "buffer"; + private static final String BUFFERS_FIELD_NAME = "buffers"; private static final String NETTY_WRITABLE_BUFFER_CLASS_NAME = "io.grpc.netty.NettyWritableBuffer"; + private static final String NETTY_READABLE_BUFFER_CLASS_NAME = + "io.grpc.netty.NettyReadableBuffer"; private static final String BUFFER_CHAIN_OUTPUT_STREAM_CLASS_NAME = "io.grpc.internal.MessageFramer$BufferChainOutputStream"; - private static final String BUFFER_LIST_FIELD_NAME = - "bufferList"; - private static final String CURRENT_FIELD_NAME = - "current"; + private static final String BUFFER_LIST_FIELD_NAME = "bufferList"; + private static final String CURRENT_FIELD_NAME = "current"; private static Constructor sNettyWritableBufferConstructor; private static Field sBufferList; + private static Field sCompositeBuffers = null; private static Field sCurrent; private static Field sReadableBufferField = null; + private static Field sReadableByteBuf = null; private static boolean sZeroCopySendSupported = true; private static boolean sZeroCopyReceiveSupported = true; @@ -72,6 +78,8 @@ public class GrpcSerializationUtils { getPrivateConstructor(NETTY_WRITABLE_BUFFER_CLASS_NAME, ByteBuf.class); sBufferList = getPrivateField(BUFFER_CHAIN_OUTPUT_STREAM_CLASS_NAME, BUFFER_LIST_FIELD_NAME); sCurrent = getPrivateField(BUFFER_CHAIN_OUTPUT_STREAM_CLASS_NAME, CURRENT_FIELD_NAME); + sCompositeBuffers = getPrivateField(CompositeReadableBuffer.class.getName(), BUFFERS_FIELD_NAME); + sReadableByteBuf = getPrivateField(NETTY_READABLE_BUFFER_CLASS_NAME, BUFFER_FIELD_NAME); } catch (Exception e) { LOG.warn("Cannot get gRPC output stream buffer, zero copy receive will be disabled.", e); sZeroCopyReceiveSupported = false; @@ -125,6 +133,42 @@ public static ReadableBuffer getBufferFromStream(InputStream stream) { } } + /** + * Gets a Netty buffer directly from a gRPC ReadableBuffer. + * + * @param buffer the input buffer + * @return the raw ByteBuf + */ + public static ByteBuf getByteBufFromReadableBuffer(ReadableBuffer buffer) { + if (!sZeroCopyReceiveSupported) { + return null; + } + try { + if (buffer instanceof CompositeReadableBuffer) { + Queue buffers = (Queue)sCompositeBuffers.get(buffer); + if (buffers.size() == 1) { + return getByteBufFromReadableBuffer(buffers.peek()); + } else { + CompositeByteBuf buf = PooledByteBufAllocator.DEFAULT.compositeBuffer(); + for (ReadableBuffer readableBuffer : buffers) { + ByteBuf subBuffer = getByteBufFromReadableBuffer(readableBuffer); + if (subBuffer == null) { + return null; + } + buf.addComponent(true, subBuffer); + } + return buf; + } + } else if (buffer.getClass().equals(sReadableByteBuf.getDeclaringClass())) { + return (ByteBuf) sReadableByteBuf.get(buffer); + } + } catch (Exception e) { + LOG.warn("Failed to get data buffer from stream: {}.", e.getMessage()); + return null; + } + return null; + } + /** * Add the given buffers directly to the gRPC output stream. * @@ -149,7 +193,7 @@ public static boolean addBuffersToStream(ByteBuf[] buffers, OutputStream stream) } return true; } catch (Exception e) { - LOG.warn("Failed to add data buffer to stream.", e); + LOG.warn("Failed to add data buffer to stream: {}.", e.getMessage()); return false; } } diff --git a/core/common/src/main/java/alluxio/grpc/ReadableDataBuffer.java b/core/common/src/main/java/alluxio/grpc/ReadableDataBuffer.java index ae7f0a5cf775..099c425d30d5 100644 --- a/core/common/src/main/java/alluxio/grpc/ReadableDataBuffer.java +++ b/core/common/src/main/java/alluxio/grpc/ReadableDataBuffer.java @@ -15,6 +15,8 @@ import io.grpc.internal.ReadableBuffer; +import java.io.IOException; +import java.io.OutputStream; import java.nio.ByteBuffer; public class ReadableDataBuffer implements DataBuffer { @@ -53,4 +55,14 @@ public int readableBytes() { public void release() { mBuffer.close(); } + + @Override + public void readBytes(OutputStream outputStream, int length) throws IOException { + mBuffer.readBytes(outputStream, length); + } + + @Override + public void readBytes(ByteBuffer outputBuf) { + mBuffer.readBytes(outputBuf); + } } diff --git a/core/common/src/main/java/alluxio/grpc/WriteRequestMarshaller.java b/core/common/src/main/java/alluxio/grpc/WriteRequestMarshaller.java new file mode 100644 index 000000000000..8c00892a23a8 --- /dev/null +++ b/core/common/src/main/java/alluxio/grpc/WriteRequestMarshaller.java @@ -0,0 +1,120 @@ +/* + * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0 + * (the "License"). You may not use this work except in compliance with the License, which is + * available at www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied, as more fully set forth in the License. + * + * See the NOTICE file distributed with this work for information regarding copyright ownership. + */ + +package alluxio.grpc; + +import alluxio.network.protocol.databuffer.DataBuffer; +import alluxio.network.protocol.databuffer.NettyDataBuffer; +import alluxio.util.proto.ProtoUtils; + +import com.google.common.base.Preconditions; +import com.google.protobuf.CodedOutputStream; +import com.google.protobuf.UnsafeByteOperations; +import com.google.protobuf.WireFormat; +import io.grpc.internal.ReadableBuffer; +import io.grpc.internal.ReadableBuffers; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import java.io.IOException; +import java.io.InputStream; + +import javax.annotation.concurrent.NotThreadSafe; + +/** + * Marshaller for {@link WriteRequest}. + */ +@NotThreadSafe +public class WriteRequestMarshaller extends DataMessageMarshaller { + private static final int CHUNK_TAG = GrpcSerializationUtils.makeTag( + WriteRequest.CHUNK_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED); + /** + * Creates a {@link WriteRequestMarshaller}. + */ + public WriteRequestMarshaller() { + super(BlockWorkerGrpc.getWriteBlockMethod().getRequestMarshaller()); + } + + @Override + protected ByteBuf[] serialize(WriteRequest message) throws IOException { + if (message.hasCommand()) { + byte[] command = new byte[message.getSerializedSize()]; + CodedOutputStream stream = CodedOutputStream.newInstance(command); + message.writeTo(stream); + return new ByteBuf[] { Unpooled.wrappedBuffer(command) }; + } + DataBuffer chunkBuffer = pollBuffer(message); + if (chunkBuffer == null) { + if (!message.hasChunk() || !message.getChunk().hasData()) { + // nothing to serialize + return new ByteBuf[0]; + } + // attempts to fallback to read chunk from message + chunkBuffer = new NettyDataBuffer( + Unpooled.wrappedBuffer(message.getChunk().getData().asReadOnlyByteBuffer())); + } + int size = message.getSerializedSize() - chunkBuffer.readableBytes(); + byte[] header = new byte[size]; + CodedOutputStream stream = CodedOutputStream.newInstance(header); + stream.writeTag(WriteRequest.CHUNK_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED); + stream.writeUInt32NoTag(message.getChunk().getSerializedSize()); + stream.writeTag(Chunk.DATA_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED); + stream.writeUInt32NoTag(chunkBuffer.readableBytes()); + return new ByteBuf[] { Unpooled.wrappedBuffer(header), (ByteBuf) chunkBuffer.getNettyOutput() }; + } + + @Override + protected WriteRequest deserialize(ReadableBuffer buffer) throws IOException { + if (buffer.readableBytes() == 0) { + return WriteRequest.getDefaultInstance(); + } + try (InputStream is = ReadableBuffers.openStream(buffer, false)) { + int tag = ProtoUtils.readRawVarint32(is); + int messageSize = ProtoUtils.readRawVarint32(is); + if (tag != CHUNK_TAG) { + return WriteRequest.newBuilder().setCommand(WriteRequestCommand.parseFrom(is)).build(); + } + Preconditions.checkState(messageSize == buffer.readableBytes()); + Preconditions.checkState(ProtoUtils.readRawVarint32(is) == GrpcSerializationUtils.makeTag( + Chunk.DATA_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED)); + int chunkSize = ProtoUtils.readRawVarint32(is); + Preconditions.checkState(chunkSize == buffer.readableBytes()); + WriteRequest request = WriteRequest.newBuilder().build(); + ByteBuf bytebuf = GrpcSerializationUtils.getByteBufFromReadableBuffer(buffer); + if (bytebuf != null) { + offerBuffer(new NettyDataBuffer(bytebuf), request); + } else { + offerBuffer(new ReadableDataBuffer(buffer), request); + } + return request; + } + } + + @Override + public WriteRequest combineData(DataMessage message) { + if (message == null) { + return null; + } + DataBuffer buffer = message.getBuffer(); + if (buffer == null) { + return message.getMessage(); + } + try { + byte[] bytes = new byte[buffer.readableBytes()]; + buffer.readBytes(bytes, 0, bytes.length); + return message.getMessage().toBuilder() + .setChunk(Chunk.newBuilder().setData(UnsafeByteOperations.unsafeWrap(bytes)).build()) + .build(); + } finally { + message.getBuffer().release(); + } + } +} diff --git a/core/common/src/main/java/alluxio/network/protocol/databuffer/ByteArrayDataBuffer.java b/core/common/src/main/java/alluxio/network/protocol/databuffer/ByteArrayDataBuffer.java index a0db6c17f793..94a5092a3344 100644 --- a/core/common/src/main/java/alluxio/network/protocol/databuffer/ByteArrayDataBuffer.java +++ b/core/common/src/main/java/alluxio/network/protocol/databuffer/ByteArrayDataBuffer.java @@ -14,6 +14,7 @@ import com.google.common.base.Preconditions; import io.netty.buffer.Unpooled; +import java.io.OutputStream; import java.nio.ByteBuffer; /** @@ -55,6 +56,16 @@ public void readBytes(byte[] dst, int dstIndex, int length) { throw new UnsupportedOperationException("ByteArrayDataBuffer#readBytes is not implemented."); } + @Override + public void readBytes(OutputStream outputStream, int length) { + throw new UnsupportedOperationException("ByteArrayDataBuffer#readBytes is not implemented."); + } + + @Override + public void readBytes(ByteBuffer outputBuf) { + throw new UnsupportedOperationException("ByteArrayDataBuffer#readBytes is not implemented."); + } + @Override public int readableBytes() { throw new UnsupportedOperationException( diff --git a/core/common/src/main/java/alluxio/network/protocol/databuffer/DataBuffer.java b/core/common/src/main/java/alluxio/network/protocol/databuffer/DataBuffer.java index 57d57f6f1efe..40252ce386c1 100644 --- a/core/common/src/main/java/alluxio/network/protocol/databuffer/DataBuffer.java +++ b/core/common/src/main/java/alluxio/network/protocol/databuffer/DataBuffer.java @@ -11,6 +11,8 @@ package alluxio.network.protocol.databuffer; +import java.io.IOException; +import java.io.OutputStream; import java.nio.ByteBuffer; /** @@ -50,6 +52,21 @@ public interface DataBuffer { */ void readBytes(byte[] dst, int dstIndex, int length); + /** + * Transfers this buffer's data to the given stream. + * + * @param outputStream the stream to transfer data to + * @param length length of the data to be transferred + */ + void readBytes(OutputStream outputStream, int length) throws IOException; + + /** + * Transfers this buffer's data to the given {@link ByteBuffer}. + * + * @param outputBuf the buffer to transfer data to + */ + void readBytes(ByteBuffer outputBuf); + /** * @return the number of readable bytes remaining */ diff --git a/core/common/src/main/java/alluxio/network/protocol/databuffer/NettyDataBuffer.java b/core/common/src/main/java/alluxio/network/protocol/databuffer/NettyDataBuffer.java index e2e936a0dca6..ccf35a9c1dcb 100644 --- a/core/common/src/main/java/alluxio/network/protocol/databuffer/NettyDataBuffer.java +++ b/core/common/src/main/java/alluxio/network/protocol/databuffer/NettyDataBuffer.java @@ -14,6 +14,8 @@ import com.google.common.base.Preconditions; import io.netty.buffer.ByteBuf; +import java.io.IOException; +import java.io.OutputStream; import java.nio.ByteBuffer; /** @@ -57,6 +59,16 @@ public void readBytes(byte[] dst, int dstIndex, int length) { mNettyBuf.readBytes(dst, dstIndex, length); } + @Override + public void readBytes(OutputStream outputStream, int length) throws IOException { + mNettyBuf.readBytes(outputStream, length); + } + + @Override + public void readBytes(ByteBuffer outputBuf) { + mNettyBuf.readBytes(outputBuf); + } + @Override public int readableBytes() { return mNettyBuf.readableBytes(); diff --git a/core/common/src/main/java/alluxio/network/protocol/databuffer/NioDataBuffer.java b/core/common/src/main/java/alluxio/network/protocol/databuffer/NioDataBuffer.java index 13289e7c5c68..9b1f29b28fda 100644 --- a/core/common/src/main/java/alluxio/network/protocol/databuffer/NioDataBuffer.java +++ b/core/common/src/main/java/alluxio/network/protocol/databuffer/NioDataBuffer.java @@ -16,6 +16,8 @@ import com.google.common.base.Preconditions; import io.netty.buffer.Unpooled; +import java.io.IOException; +import java.io.OutputStream; import java.nio.ByteBuffer; /** @@ -59,6 +61,16 @@ public void readBytes(byte[] dst, int dstIndex, int length) { mBuffer.get(dst, dstIndex, length); } + @Override + public void readBytes(OutputStream outputStream, int length) throws IOException { + Unpooled.wrappedBuffer(mBuffer).readBytes(outputStream, length).release(); + } + + @Override + public void readBytes(ByteBuffer outputBuf) { + outputBuf.put(mBuffer); + } + @Override public int readableBytes() { return mBuffer.remaining(); diff --git a/core/common/src/main/java/alluxio/worker/block/io/BlockWriter.java b/core/common/src/main/java/alluxio/worker/block/io/BlockWriter.java index a3ebff39e08a..41fb9f4752ea 100644 --- a/core/common/src/main/java/alluxio/worker/block/io/BlockWriter.java +++ b/core/common/src/main/java/alluxio/worker/block/io/BlockWriter.java @@ -11,6 +11,8 @@ package alluxio.worker.block.io; +import alluxio.network.protocol.databuffer.DataBuffer; + import io.netty.buffer.ByteBuf; import java.io.Closeable; @@ -40,6 +42,14 @@ public interface BlockWriter extends Closeable { */ long append(ByteBuf buf) throws IOException; + /** + * Appends buffer.readableBytes() bytes to the end of this block writer from the given buffer. + * + * @param buffer the byte buffer to hold the data + * @return the size of data that was appended in bytes + */ + long append(DataBuffer buffer) throws IOException; + /** * @return the current write position (same as the number of bytes written) */ diff --git a/core/common/src/main/java/alluxio/worker/block/io/LocalFileBlockWriter.java b/core/common/src/main/java/alluxio/worker/block/io/LocalFileBlockWriter.java index 1b4fdf279d55..e2d9f0fa4123 100644 --- a/core/common/src/main/java/alluxio/worker/block/io/LocalFileBlockWriter.java +++ b/core/common/src/main/java/alluxio/worker/block/io/LocalFileBlockWriter.java @@ -11,11 +11,14 @@ package alluxio.worker.block.io; +import alluxio.network.protocol.databuffer.DataBuffer; import alluxio.util.io.BufferUtils; import com.google.common.base.Preconditions; import com.google.common.io.Closer; import io.netty.buffer.ByteBuf; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.IOException; import java.io.RandomAccessFile; @@ -31,6 +34,8 @@ */ @NotThreadSafe public final class LocalFileBlockWriter implements BlockWriter { + private static final Logger LOG = LoggerFactory.getLogger(LocalFileBlockWriter.class); + private final String mFilePath; private final RandomAccessFile mLocalFile; private final FileChannel mLocalFileChannel; @@ -63,6 +68,22 @@ public long append(ByteBuf buf) throws IOException { return bytesWritten; } + @Override + public long append(DataBuffer buffer) throws IOException { + ByteBuf bytebuf = null; + try { + bytebuf = (ByteBuf) buffer.getNettyOutput(); + } catch (Throwable e) { + LOG.debug("Failed to get ByteBuf from DataBuffer, write performance may be degraded."); + } + if (bytebuf != null) { + return append(bytebuf); + } + long bytesWritten = write(mLocalFileChannel.size(), buffer); + mPosition += bytesWritten; + return bytesWritten; + } + @Override public long getPosition() { return mPosition; @@ -100,4 +121,14 @@ private long write(long offset, ByteBuffer inputBuf) throws IOException { BufferUtils.cleanDirectBuffer(outputBuf); return bytesWritten; } + + private long write(long offset, DataBuffer inputBuf) throws IOException { + int inputBufLength = inputBuf.readableBytes(); + MappedByteBuffer outputBuf = + mLocalFileChannel.map(FileChannel.MapMode.READ_WRITE, offset, inputBufLength); + inputBuf.readBytes(outputBuf); + int bytesWritten = outputBuf.limit(); + BufferUtils.cleanDirectBuffer(outputBuf); + return bytesWritten; + } } diff --git a/core/common/src/test/java/alluxio/worker/block/io/MockBlockWriter.java b/core/common/src/test/java/alluxio/worker/block/io/MockBlockWriter.java index 7fc1fd1f5c76..1d84d9b4ee64 100644 --- a/core/common/src/test/java/alluxio/worker/block/io/MockBlockWriter.java +++ b/core/common/src/test/java/alluxio/worker/block/io/MockBlockWriter.java @@ -11,6 +11,8 @@ package alluxio.worker.block.io; +import alluxio.network.protocol.databuffer.DataBuffer; + import io.netty.buffer.ByteBuf; import java.io.ByteArrayOutputStream; @@ -56,6 +58,15 @@ public long append(ByteBuf buf) throws IOException { return bytesWritten; } + @Override + public long append(DataBuffer buffer) throws IOException { + byte[] bytes = new byte[buffer.readableBytes()]; + buffer.readBytes(bytes, 0, bytes.length); + mOutputStream.write(bytes); + mPosition += bytes.length; + return bytes.length; + } + @Override public GatheringByteChannel getChannel() { return new GatheringByteChannel() { diff --git a/core/server/worker/src/main/java/alluxio/worker/grpc/AbstractWriteHandler.java b/core/server/worker/src/main/java/alluxio/worker/grpc/AbstractWriteHandler.java index b4e6fa7fa2eb..dc473f328015 100644 --- a/core/server/worker/src/main/java/alluxio/worker/grpc/AbstractWriteHandler.java +++ b/core/server/worker/src/main/java/alluxio/worker/grpc/AbstractWriteHandler.java @@ -12,11 +12,15 @@ package alluxio.worker.grpc; import alluxio.client.block.stream.GrpcDataWriter; +import alluxio.conf.PropertyKey; +import alluxio.conf.ServerConfiguration; import alluxio.exception.status.AlluxioStatusException; import alluxio.exception.status.InvalidArgumentException; import alluxio.grpc.WriteRequest; import alluxio.grpc.WriteRequestCommand; import alluxio.grpc.WriteResponse; +import alluxio.network.protocol.databuffer.DataBuffer; +import alluxio.network.protocol.databuffer.NioDataBuffer; import alluxio.util.LogUtils; import com.codahale.metrics.Counter; @@ -26,10 +30,13 @@ import com.google.protobuf.ByteString; import io.grpc.Status; import io.grpc.StatusRuntimeException; +import io.grpc.internal.SerializingExecutor; import io.grpc.stub.StreamObserver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.concurrent.Semaphore; + import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.NotThreadSafe; @@ -53,6 +60,9 @@ abstract class AbstractWriteHandler> { private static final Logger LOG = LoggerFactory.getLogger(AbstractWriteHandler.class); private final StreamObserver mResponseObserver; + private final SerializingExecutor mSerializingExecutor; + private final Semaphore mSemaphore = new Semaphore( + ServerConfiguration.getInt(PropertyKey.WORKER_NETWORK_WRITER_BUFFER_SIZE_MESSAGES), true); /** * This is initialized only once for a whole file or block in @@ -71,6 +81,7 @@ abstract class AbstractWriteHandler> { */ AbstractWriteHandler(StreamObserver responseObserver) { mResponseObserver = responseObserver; + mSerializingExecutor = new SerializingExecutor(GrpcExecutors.BLOCK_WRITER_EXECUTOR); } /** @@ -80,65 +91,113 @@ abstract class AbstractWriteHandler> { */ public void write(WriteRequest writeRequest) { try { - if (mContext == null) { - LOG.debug("Received write request {}.", writeRequest); - mContext = createRequestContext(writeRequest); - } else { - Preconditions.checkState(!mContext.isDoneUnsafe(), - "invalid request after write request is completed."); - } - validateWriteRequest(writeRequest); - if (writeRequest.hasCommand()) { - WriteRequestCommand command = writeRequest.getCommand(); - if (command.getFlush()) { - flush(); + mSemaphore.acquire(); + } catch (InterruptedException e) { + LOG.warn("write data request {} is interrupted: {}", writeRequest, e.getMessage()); + abort(new Error(AlluxioStatusException.fromThrowable(e), true)); + Thread.currentThread().interrupt(); + return; + } + mSerializingExecutor.execute(() -> { + try { + if (mContext == null) { + LOG.debug("Received write request {}.", writeRequest); + mContext = createRequestContext(writeRequest); + } else { + Preconditions.checkState(!mContext.isDoneUnsafe(), + "invalid request after write request is completed."); + } + if (mContext.isDoneUnsafe() || mContext.getError() != null) { + return; + } + validateWriteRequest(writeRequest); + if (writeRequest.hasCommand()) { + WriteRequestCommand command = writeRequest.getCommand(); + if (command.getFlush()) { + flush(); + } else { + handleCommand(command, mContext); + } } else { - handleCommand(command, mContext); + Preconditions.checkState(writeRequest.hasChunk(), + "write request is missing data chunk in non-command message"); + ByteString data = writeRequest.getChunk().getData(); + Preconditions.checkState(data != null && data.size() > 0, + "invalid data size from write request message"); + writeData(new NioDataBuffer(data.asReadOnlyByteBuffer(), data.size())); } - } else { - Preconditions.checkState(writeRequest.hasChunk(), - "write request is missing data chunk in non-command message"); - ByteString data = writeRequest.getChunk().getData(); - Preconditions.checkState(data != null && data.size() > 0, - "invalid data size from write request message"); - writeData(data); + } catch (Exception e) { + LogUtils.warnWithException(LOG, "Exception occurred while processing write request {}.", + writeRequest, e); + abort(new Error(AlluxioStatusException.fromThrowable(e), true)); + } finally { + mSemaphore.release(); } - } catch (Exception e) { - LogUtils.warnWithException(LOG, "Exception occurred while processing write request {}.", - writeRequest, e); + }); + } + + /** + * Handles write request with data message. + * + * @param request the request from the client + * @param buffer the data associated with the request + */ + public void writeDataMessage(WriteRequest request, DataBuffer buffer) { + if (buffer == null) { + write(request); + return; + } + Preconditions.checkState(!request.hasCommand(), + "write request command should not come with data buffer"); + Preconditions.checkState(buffer.readableBytes() > 0, + "invalid data size from write request message"); + try { + mSemaphore.acquire(); + } catch (InterruptedException e) { + LOG.warn("write data request {} is interrupted: {}", request, e.getMessage()); abort(new Error(AlluxioStatusException.fromThrowable(e), true)); + Thread.currentThread().interrupt(); + return; } + mSerializingExecutor.execute(() -> { + writeData(buffer); + mSemaphore.release(); + }); } /** * Handles request complete event. */ public void onCompleted() { - Preconditions.checkState(mContext != null); - try { - completeRequest(mContext); - replySuccess(); - } catch (Exception e) { - LogUtils.warnWithException(LOG, "Exception occurred while completing write request {}.", - mContext.getRequest(), e); - Throwables.throwIfUnchecked(e); - abort(new Error(AlluxioStatusException.fromCheckedException(e), true)); - } + mSerializingExecutor.execute(() -> { + Preconditions.checkState(mContext != null); + try { + completeRequest(mContext); + replySuccess(); + } catch (Exception e) { + LogUtils.warnWithException(LOG, "Exception occurred while completing write request {}.", + mContext.getRequest(), e); + Throwables.throwIfUnchecked(e); + abort(new Error(AlluxioStatusException.fromCheckedException(e), true)); + } + }); } /** * Handles request cancellation event. */ public void onCancel() { - try { - cancelRequest(mContext); - replyCancel(); - } catch (Exception e) { - LogUtils.warnWithException(LOG, "Exception occurred while cancelling write request {}.", - mContext.getRequest(), e); - Throwables.throwIfUnchecked(e); - abort(new Error(AlluxioStatusException.fromCheckedException(e), true)); - } + mSerializingExecutor.execute(() -> { + try { + cancelRequest(mContext); + replyCancel(); + } catch (Exception e) { + LogUtils.warnWithException(LOG, "Exception occurred while cancelling write request {}.", + mContext.getRequest(), e); + Throwables.throwIfUnchecked(e); + abort(new Error(AlluxioStatusException.fromCheckedException(e), true)); + } + }); } /** @@ -152,9 +211,11 @@ public void onError(Throwable cause) { // Cancellation is already handled. return; } - LogUtils.warnWithException(LOG, "Exception thrown while handling write request {}", - mContext == null ? "unknown" : mContext.getRequest(), cause); - abort(new Error(AlluxioStatusException.fromThrowable(cause), false)); + mSerializingExecutor.execute(() -> { + LogUtils.warnWithException(LOG, "Exception thrown while handling write request {}", + mContext == null ? "unknown" : mContext.getRequest(), cause); + abort(new Error(AlluxioStatusException.fromThrowable(cause), false)); + }); } /** @@ -174,9 +235,12 @@ private void validateWriteRequest(alluxio.grpc.WriteRequest request) } } - private void writeData(ByteString buf) { + private void writeData(DataBuffer buf) { try { - int readableBytes = buf.size(); + if (mContext.isDoneUnsafe() || mContext.getError() != null) { + return; + } + int readableBytes = buf.readableBytes(); mContext.setPos(mContext.getPos() + readableBytes); writeBuf(mContext, mResponseObserver, buf, mContext.getPos()); incrementMetrics(readableBytes); @@ -184,6 +248,8 @@ private void writeData(ByteString buf) { LOG.error("Failed to write data for request {}", mContext.getRequest(), e); Throwables.throwIfUnchecked(e); abort(new Error(AlluxioStatusException.fromCheckedException(e), true)); + } finally { + buf.release(); } } @@ -261,7 +327,7 @@ private void abort(Error error) { * @param pos the pos */ protected abstract void writeBuf(T context, StreamObserver responseObserver, - ByteString buf, long pos) throws Exception; + DataBuffer buf, long pos) throws Exception; /** * Handles a command in the write request. diff --git a/core/server/worker/src/main/java/alluxio/worker/grpc/BlockWorkerImpl.java b/core/server/worker/src/main/java/alluxio/worker/grpc/BlockWorkerImpl.java index e88cf845ad97..e8a74967e09b 100644 --- a/core/server/worker/src/main/java/alluxio/worker/grpc/BlockWorkerImpl.java +++ b/core/server/worker/src/main/java/alluxio/worker/grpc/BlockWorkerImpl.java @@ -27,6 +27,7 @@ import alluxio.grpc.ReadResponseMarshaller; import alluxio.grpc.RemoveBlockRequest; import alluxio.grpc.RemoveBlockResponse; +import alluxio.grpc.WriteRequestMarshaller; import alluxio.grpc.WriteResponse; import alluxio.util.IdUtils; import alluxio.worker.WorkerProcess; @@ -57,6 +58,7 @@ public class BlockWorkerImpl extends BlockWorkerGrpc.BlockWorkerImplBase { private WorkerProcess mWorkerProcess; private final AsyncCacheRequestManager mRequestManager; private ReadResponseMarshaller mReadResponseMarshaller = new ReadResponseMarshaller(); + private WriteRequestMarshaller mWriteRequestMarshaller = new WriteRequestMarshaller(); /** * Creates a new implementation of gRPC BlockWorker interface. @@ -76,9 +78,13 @@ public BlockWorkerImpl(WorkerProcess workerProcess, FileSystemContext fsContext) */ public Map getOverriddenMethodDescriptors() { if (ZERO_COPY_ENABLED) { - return ImmutableMap.of(BlockWorkerGrpc.getReadBlockMethod(), + return ImmutableMap.of( + BlockWorkerGrpc.getReadBlockMethod(), BlockWorkerGrpc.getReadBlockMethod().toBuilder() - .setResponseMarshaller(mReadResponseMarshaller).build()); + .setResponseMarshaller(mReadResponseMarshaller).build(), + BlockWorkerGrpc.getWriteBlockMethod(), + BlockWorkerGrpc.getWriteBlockMethod().toBuilder() + .setRequestMarshaller(mWriteRequestMarshaller).build()); } return Collections.emptyMap(); } @@ -99,10 +105,14 @@ public StreamObserver readBlock(StreamObserver respon @Override public StreamObserver writeBlock( - final StreamObserver responseObserver) { - DelegationWriteHandler handler = new DelegationWriteHandler(mWorkerProcess, responseObserver); + StreamObserver responseObserver) { ServerCallStreamObserver serverResponseObserver = (ServerCallStreamObserver) responseObserver; + if (ZERO_COPY_ENABLED) { + responseObserver = + new DataMessageServerRequestObserver<>(responseObserver, mWriteRequestMarshaller, null); + } + DelegationWriteHandler handler = new DelegationWriteHandler(mWorkerProcess, responseObserver); serverResponseObserver.setOnCancelHandler(handler::onCancel); return handler; } diff --git a/core/server/worker/src/main/java/alluxio/worker/grpc/BlockWriteHandler.java b/core/server/worker/src/main/java/alluxio/worker/grpc/BlockWriteHandler.java index 36fef34293ce..7526d3c3a74a 100644 --- a/core/server/worker/src/main/java/alluxio/worker/grpc/BlockWriteHandler.java +++ b/core/server/worker/src/main/java/alluxio/worker/grpc/BlockWriteHandler.java @@ -18,11 +18,11 @@ import alluxio.grpc.WriteResponse; import alluxio.metrics.MetricsSystem; import alluxio.metrics.WorkerMetrics; +import alluxio.network.protocol.databuffer.DataBuffer; import alluxio.worker.block.BlockWorker; import com.google.common.base.Preconditions; -import com.google.protobuf.ByteString; import io.grpc.stub.StreamObserver; import org.slf4j.Logger; @@ -101,7 +101,7 @@ protected void flushRequest(BlockWriteRequestContext context) @Override protected void writeBuf(BlockWriteRequestContext context, - StreamObserver observer, ByteString buf, long pos) throws Exception { + StreamObserver observer, DataBuffer buf, long pos) throws Exception { Preconditions.checkState(context != null); WriteRequest request = context.getRequest(); long bytesReserved = context.getBytesReserved(); @@ -119,8 +119,7 @@ protected void writeBuf(BlockWriteRequestContext context, context.setMeter(MetricsSystem.meter(WorkerMetrics.BYTES_WRITTEN_ALLUXIO_THROUGHPUT)); } Preconditions.checkState(context.getBlockWriter() != null); - int sz = buf.size(); - Preconditions.checkState( - context.getBlockWriter().append(buf.asReadOnlyByteBuffer()) == sz); + int sz = buf.readableBytes(); + Preconditions.checkState(context.getBlockWriter().append(buf) == sz); } } diff --git a/core/server/worker/src/main/java/alluxio/worker/grpc/DataMessageServerRequestObserver.java b/core/server/worker/src/main/java/alluxio/worker/grpc/DataMessageServerRequestObserver.java new file mode 100644 index 000000000000..220a5daf90f2 --- /dev/null +++ b/core/server/worker/src/main/java/alluxio/worker/grpc/DataMessageServerRequestObserver.java @@ -0,0 +1,63 @@ +/* + * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0 + * (the "License"). You may not use this work except in compliance with the License, which is + * available at www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied, as more fully set forth in the License. + * + * See the NOTICE file distributed with this work for information regarding copyright ownership. + */ + +package alluxio.worker.grpc; + +import alluxio.grpc.DataMessageMarshaller; +import alluxio.grpc.DataMessageMarshallerProvider; + +import io.grpc.stub.StreamObserver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.concurrent.NotThreadSafe; + +/** + * A {@link StreamObserver} that handles raw data buffers. + * + * @param type of the response message + * @param type of the request message + */ +@NotThreadSafe +public class DataMessageServerRequestObserver + extends DataMessageMarshallerProvider implements StreamObserver { + private static final Logger LOG = + LoggerFactory.getLogger(DataMessageServerRequestObserver.class); + + private final StreamObserver mObserver; + + /** + * @param observer the original response observer + * @param requestMarshaller the marshaller for the request + * @param responseMarshaller the marshaller for the response + */ + public DataMessageServerRequestObserver(StreamObserver observer, + DataMessageMarshaller requestMarshaller, + DataMessageMarshaller responseMarshaller) { + super(requestMarshaller, responseMarshaller); + mObserver = observer; + } + + @Override + public void onNext(ResT value) { + mObserver.onNext(value); + } + + @Override + public void onError(Throwable t) { + mObserver.onError(t); + } + + @Override + public void onCompleted() { + mObserver.onCompleted(); + } +} diff --git a/core/server/worker/src/main/java/alluxio/worker/grpc/DelegationWriteHandler.java b/core/server/worker/src/main/java/alluxio/worker/grpc/DelegationWriteHandler.java index 68c10f59b008..0d9c8dd3cd2e 100644 --- a/core/server/worker/src/main/java/alluxio/worker/grpc/DelegationWriteHandler.java +++ b/core/server/worker/src/main/java/alluxio/worker/grpc/DelegationWriteHandler.java @@ -11,6 +11,8 @@ package alluxio.worker.grpc; +import alluxio.grpc.DataMessageMarshaller; +import alluxio.grpc.DataMessageMarshallerProvider; import alluxio.grpc.WriteRequest; import alluxio.grpc.WriteResponse; import alluxio.worker.WorkerProcess; @@ -25,6 +27,7 @@ public class DelegationWriteHandler implements StreamObserver { private final StreamObserver mResponseObserver; private final WorkerProcess mWorkerProcess; + private final DataMessageMarshaller mMarshaller; private AbstractWriteHandler mWriteHandler; /** @@ -35,6 +38,12 @@ public DelegationWriteHandler(WorkerProcess workerProcess, StreamObserver responseObserver) { mWorkerProcess = workerProcess; mResponseObserver = responseObserver; + if (mResponseObserver instanceof DataMessageMarshallerProvider) { + mMarshaller = ((DataMessageMarshallerProvider) mResponseObserver) + .getRequestMarshaller(); + } else { + mMarshaller = null; + } } private AbstractWriteHandler createWriterHandler(alluxio.grpc.WriteRequest request) { @@ -60,7 +69,11 @@ public void onNext(WriteRequest request) { if (mWriteHandler == null) { mWriteHandler = createWriterHandler(request); } - mWriteHandler.write(request); + if (mMarshaller != null) { + mWriteHandler.writeDataMessage(request, mMarshaller.pollBuffer(request)); + } else { + mWriteHandler.write(request); + } } @Override diff --git a/core/server/worker/src/main/java/alluxio/worker/grpc/GrpcExecutors.java b/core/server/worker/src/main/java/alluxio/worker/grpc/GrpcExecutors.java index f01abf63be00..83a5871bf22b 100644 --- a/core/server/worker/src/main/java/alluxio/worker/grpc/GrpcExecutors.java +++ b/core/server/worker/src/main/java/alluxio/worker/grpc/GrpcExecutors.java @@ -44,6 +44,12 @@ final class GrpcExecutors { THREAD_STOP_MS, TimeUnit.MILLISECONDS, new SynchronousQueue<>(), ThreadFactoryUtils.build("BlockDataReaderExecutor-%d", true)); + public static final ExecutorService BLOCK_WRITER_EXECUTOR = + new ThreadPoolExecutor(THREADS_MIN, + ServerConfiguration.getInt(PropertyKey.WORKER_NETWORK_BLOCK_WRITER_THREADS_MAX), + THREAD_STOP_MS, TimeUnit.MILLISECONDS, new SynchronousQueue<>(), + ThreadFactoryUtils.build("BlockDataWriterExecutor-%d", true)); + /** * Private constructor. */ diff --git a/core/server/worker/src/main/java/alluxio/worker/grpc/UfsFallbackBlockWriteHandler.java b/core/server/worker/src/main/java/alluxio/worker/grpc/UfsFallbackBlockWriteHandler.java index 898ef6df1afb..4d3f332f3e80 100644 --- a/core/server/worker/src/main/java/alluxio/worker/grpc/UfsFallbackBlockWriteHandler.java +++ b/core/server/worker/src/main/java/alluxio/worker/grpc/UfsFallbackBlockWriteHandler.java @@ -22,6 +22,7 @@ import alluxio.metrics.Metric; import alluxio.metrics.MetricsSystem; import alluxio.metrics.WorkerMetrics; +import alluxio.network.protocol.databuffer.DataBuffer; import alluxio.proto.dataserver.Protocol; import alluxio.underfs.UfsManager; import alluxio.underfs.UnderFileSystem; @@ -31,7 +32,6 @@ import alluxio.worker.block.meta.TempBlockMeta; import com.google.common.base.Preconditions; -import com.google.protobuf.ByteString; import io.grpc.stub.StreamObserver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -146,11 +146,11 @@ protected void flushRequest(BlockWriteRequestContext context) throws Exception { @Override protected void writeBuf(BlockWriteRequestContext context, - StreamObserver responseObserver, ByteString buf, long pos) throws Exception { + StreamObserver responseObserver, DataBuffer buf, long pos) throws Exception { if (context.isWritingToLocal()) { // TODO(binfan): change signature of writeBuf to pass current offset and length of buffer. // Currently pos is the calculated offset after writeBuf succeeds. - long posBeforeWrite = pos - buf.size(); + long posBeforeWrite = pos - buf.readableBytes(); try { mBlockWriteHandler.writeBuf(context, responseObserver, buf, pos); return; @@ -175,7 +175,7 @@ protected void writeBuf(BlockWriteRequestContext context, if (context.getOutputStream() == null) { createUfsBlock(context); } - buf.writeTo(context.getOutputStream()); + buf.readBytes(context.getOutputStream(), buf.readableBytes()); } @Override diff --git a/core/server/worker/src/main/java/alluxio/worker/grpc/UfsFileWriteHandler.java b/core/server/worker/src/main/java/alluxio/worker/grpc/UfsFileWriteHandler.java index 0b146f6b77da..34f13abe0dd6 100644 --- a/core/server/worker/src/main/java/alluxio/worker/grpc/UfsFileWriteHandler.java +++ b/core/server/worker/src/main/java/alluxio/worker/grpc/UfsFileWriteHandler.java @@ -17,6 +17,7 @@ import alluxio.metrics.Metric; import alluxio.metrics.MetricsSystem; import alluxio.metrics.WorkerMetrics; +import alluxio.network.protocol.databuffer.DataBuffer; import alluxio.proto.dataserver.Protocol; import alluxio.resource.CloseableResource; import alluxio.security.authorization.Mode; @@ -27,7 +28,6 @@ import com.codahale.metrics.Counter; import com.google.common.base.Preconditions; -import com.google.protobuf.ByteString; import io.grpc.stub.StreamObserver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -132,12 +132,12 @@ protected void flushRequest(UfsFileWriteRequestContext context) @Override protected void writeBuf(UfsFileWriteRequestContext context, - StreamObserver observer, ByteString buf, long pos) throws Exception { + StreamObserver observer, DataBuffer buf, long pos) throws Exception { Preconditions.checkState(context != null); if (context.getOutputStream() == null) { createUfsFile(context); } - buf.writeTo(context.getOutputStream()); + buf.readBytes(context.getOutputStream(), buf.readableBytes()); } private void createUfsFile(UfsFileWriteRequestContext context) diff --git a/core/server/worker/src/test/java/alluxio/worker/grpc/AbstractWriteHandlerTest.java b/core/server/worker/src/test/java/alluxio/worker/grpc/AbstractWriteHandlerTest.java index dc83b7085b7d..e1e9873b147c 100644 --- a/core/server/worker/src/test/java/alluxio/worker/grpc/AbstractWriteHandlerTest.java +++ b/core/server/worker/src/test/java/alluxio/worker/grpc/AbstractWriteHandlerTest.java @@ -14,6 +14,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; @@ -24,6 +25,8 @@ import alluxio.grpc.WriteResponse; import alluxio.network.protocol.databuffer.DataBuffer; import alluxio.network.protocol.databuffer.ByteArrayDataBuffer; +import alluxio.util.CommonUtils; +import alluxio.util.WaitForOptions; import alluxio.util.io.BufferUtils; import com.google.protobuf.ByteString; @@ -39,7 +42,10 @@ import java.io.IOException; import java.io.InputStream; +import java.util.ArrayList; +import java.util.List; import java.util.Random; +import java.util.concurrent.TimeoutException; /** * Unit tests for {@link AbstractWriteHandler}. @@ -51,6 +57,10 @@ public abstract class AbstractWriteHandlerTest { protected static final long TEST_MOUNT_ID = 10L; protected AbstractWriteHandler mWriteHandler; protected StreamObserver mResponseObserver; + protected List mResponses = new ArrayList<>(); + protected boolean mResponseCompleted; + protected Throwable mError; + @Rule public ExpectedException mExpectedException = ExpectedException.none(); @@ -61,6 +71,7 @@ public abstract class AbstractWriteHandlerTest { public void writeEmptyFile() throws Exception { mWriteHandler.write(newWriteRequestCommand(0)); mWriteHandler.onCompleted(); + waitForResponses(); checkComplete(mResponseObserver); } @@ -77,6 +88,7 @@ public void writeNonEmptyFile() throws Exception { } // EOF. mWriteHandler.onCompleted(); + waitForResponses(); checkComplete(mResponseObserver); checkWriteData(checksum, len); } @@ -95,6 +107,7 @@ public void cancel() throws Exception { // Cancel. mWriteHandler.onCancel(); + waitForResponses(); checkComplete(mResponseObserver); // Our current implementation does not really abort the file when the write is cancelled. // The client issues another request to block worker to abort it. @@ -116,6 +129,7 @@ public void cancelIgnoreError() throws Exception { mWriteHandler.onCancel(); mWriteHandler.onError(Status.CANCELLED.asRuntimeException()); + waitForResponses(); checkComplete(mResponseObserver); checkWriteData(checksum, len); verify(mResponseObserver, never()).onError(any(Throwable.class)); @@ -125,6 +139,7 @@ public void cancelIgnoreError() throws Exception { public void writeInvalidOffsetFirstRequest() throws Exception { // The write request contains an invalid offset mWriteHandler.write(newWriteRequestCommand(1)); + waitForResponses(); checkErrorCode(mResponseObserver, Status.Code.INVALID_ARGUMENT); } @@ -133,6 +148,7 @@ public void writeInvalidOffsetLaterRequest() throws Exception { mWriteHandler.write(newWriteRequestCommand(0)); // The write request contains an invalid offset mWriteHandler.write(newWriteRequestCommand(1)); + waitForResponses(); checkErrorCode(mResponseObserver, Status.Code.INVALID_ARGUMENT); } @@ -264,4 +280,31 @@ public static long getChecksum(ByteBuf buffer) { } return ret; } + + /** + * Waits for response messages. + */ + protected void waitForResponses() + throws TimeoutException, InterruptedException { + CommonUtils.waitFor("response", () -> mResponseCompleted || mError != null, + WaitForOptions.defaults().setTimeoutMs(Constants.MINUTE_MS)); + } + + protected void setupResponseTrigger() { + doAnswer(args -> { + mResponseCompleted = true; + return null; + }).when(mResponseObserver).onCompleted(); + doAnswer(args -> { + mResponseCompleted = true; + mError = args.getArgumentAt(0, Throwable.class); + return null; + }).when(mResponseObserver).onError(any(Throwable.class)); + doAnswer((args) -> { + // make a copy of response data before it is released + mResponses.add(WriteResponse.parseFrom( + args.getArgumentAt(0, WriteResponse.class).toByteString())); + return null; + }).when(mResponseObserver).onNext(any(WriteResponse.class)); + } } diff --git a/core/server/worker/src/test/java/alluxio/worker/grpc/BlockWriteHandlerTest.java b/core/server/worker/src/test/java/alluxio/worker/grpc/BlockWriteHandlerTest.java index 0fcd0ccf4bb2..86de916d3475 100644 --- a/core/server/worker/src/test/java/alluxio/worker/grpc/BlockWriteHandlerTest.java +++ b/core/server/worker/src/test/java/alluxio/worker/grpc/BlockWriteHandlerTest.java @@ -52,6 +52,7 @@ public void before() throws Exception { .thenReturn(new LocalFileBlockWriter(mTestFolder.newFile().getPath())); mResponseObserver = Mockito.mock(StreamObserver.class); mWriteHandler = new BlockWriteHandler(mBlockWorker, mResponseObserver); + setupResponseTrigger(); } @Test @@ -59,6 +60,7 @@ public void writeFailure() throws Exception { mWriteHandler.write(newWriteRequestCommand(0)); mBlockWriter.close(); mWriteHandler.write(newWriteRequest(newDataBuffer(CHUNK_SIZE))); + waitForResponses(); checkErrorCode(mResponseObserver, Status.Code.FAILED_PRECONDITION); } diff --git a/core/server/worker/src/test/java/alluxio/worker/grpc/UfsFallbackBlockWriteHandlerTest.java b/core/server/worker/src/test/java/alluxio/worker/grpc/UfsFallbackBlockWriteHandlerTest.java index 072e16b5d2bd..91292406fc7d 100644 --- a/core/server/worker/src/test/java/alluxio/worker/grpc/UfsFallbackBlockWriteHandlerTest.java +++ b/core/server/worker/src/test/java/alluxio/worker/grpc/UfsFallbackBlockWriteHandlerTest.java @@ -93,6 +93,7 @@ public void before() throws Exception { mResponseObserver = Mockito.mock(StreamObserver.class); mWriteHandler = new UfsFallbackBlockWriteHandler(mBlockWorker, ufsManager, mResponseObserver); + setupResponseTrigger(); // create a partial block in block store first mBlockStore.createBlock(TEST_SESSION_ID, TEST_BLOCK_ID, @@ -114,6 +115,7 @@ public void noTempBlockFound() throws Exception { // remove the block partially created mBlockStore.abortBlock(TEST_SESSION_ID, TEST_BLOCK_ID); mWriteHandler.write(newFallbackInitRequest(PARTIAL_WRITTEN)); + waitForResponses(); checkErrorCode(mResponseObserver, Status.Code.NOT_FOUND); } @@ -124,6 +126,7 @@ public void tempBlockWritten() throws Exception { mWriteHandler.write(newFallbackInitRequest(PARTIAL_WRITTEN)); mWriteHandler.write(newWriteRequest(buffer)); mWriteHandler.onCompleted(); + waitForResponses(); checkComplete(mResponseObserver); checkWriteData(checksum, PARTIAL_WRITTEN + CHUNK_SIZE); } diff --git a/core/server/worker/src/test/java/alluxio/worker/grpc/UfsFileWriteHandlerTest.java b/core/server/worker/src/test/java/alluxio/worker/grpc/UfsFileWriteHandlerTest.java index e33188c0592a..a070b0ebe0d9 100644 --- a/core/server/worker/src/test/java/alluxio/worker/grpc/UfsFileWriteHandlerTest.java +++ b/core/server/worker/src/test/java/alluxio/worker/grpc/UfsFileWriteHandlerTest.java @@ -54,6 +54,7 @@ public void before() throws Exception { .thenReturn(new FileOutputStream(mFile, true)); mResponseObserver = Mockito.mock(StreamObserver.class); mWriteHandler = new UfsFileWriteHandler(ufsManager, mResponseObserver); + setupResponseTrigger(); } @After @@ -67,6 +68,7 @@ public void writeFailure() throws Exception { mWriteHandler.write(newWriteRequest(newDataBuffer(CHUNK_SIZE))); mOutputStream.close(); mWriteHandler.write(newWriteRequest(newDataBuffer(CHUNK_SIZE))); + waitForResponses(); checkErrorCode(mResponseObserver, Status.Code.UNKNOWN); } From 6c0c9a3609ef63d707fc2c58ebe73cf8117adf96 Mon Sep 17 00:00:00 2001 From: bf8086 Date: Wed, 6 Mar 2019 17:29:35 -0800 Subject: [PATCH 2/4] Address comments --- .../block/stream/GrpcDataMessageBlockingStream.java | 3 +++ .../src/main/java/alluxio/conf/PropertyKey.java | 2 +- .../java/alluxio/grpc/GrpcSerializationUtils.java | 2 +- .../java/alluxio/grpc/WriteRequestMarshaller.java | 4 ++-- .../protocol/databuffer/ByteArrayDataBuffer.java | 11 +++++++---- .../java/alluxio/worker/block/io/BlockWriter.java | 4 ++-- .../alluxio/worker/grpc/AbstractWriteHandler.java | 3 +++ 7 files changed, 19 insertions(+), 10 deletions(-) diff --git a/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataMessageBlockingStream.java b/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataMessageBlockingStream.java index ef63dc7b6d5b..fc310de49ab8 100644 --- a/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataMessageBlockingStream.java +++ b/core/client/fs/src/main/java/alluxio/client/block/stream/GrpcDataMessageBlockingStream.java @@ -16,6 +16,7 @@ import alluxio.grpc.DataMessageMarshaller; import alluxio.network.protocol.databuffer.DataBuffer; +import com.google.common.base.Preconditions; import io.grpc.stub.StreamObserver; import java.io.IOException; @@ -78,6 +79,8 @@ public ResT receive(long timeoutMs) throws IOException { * @throws IOException if any error occurs */ public DataMessage receiveDataMessage(long timeoutMs) throws IOException { + Preconditions.checkNotNull(mResponseMarshaller, + "Cannot retrieve data message without a response marshaller."); ResT response = super.receive(timeoutMs); if (response == null) { return null; diff --git a/core/common/src/main/java/alluxio/conf/PropertyKey.java b/core/common/src/main/java/alluxio/conf/PropertyKey.java index 6929958d7059..996a33143f34 100644 --- a/core/common/src/main/java/alluxio/conf/PropertyKey.java +++ b/core/common/src/main/java/alluxio/conf/PropertyKey.java @@ -2000,7 +2000,7 @@ public String toString() { new Builder(Name.WORKER_NETWORK_WRITER_BUFFER_SIZE_MESSAGES) .setDefaultValue(8) .setDescription("When a client writes to a remote worker, the maximum number of " - + "data messages to buffer by the server.") + + "data messages to buffer by the server for each request.") .setConsistencyCheckLevel(ConsistencyCheckLevel.WARN) .setScope(Scope.WORKER) .build(); diff --git a/core/common/src/main/java/alluxio/grpc/GrpcSerializationUtils.java b/core/common/src/main/java/alluxio/grpc/GrpcSerializationUtils.java index 091e8f8a1cbd..5d761e2b9800 100644 --- a/core/common/src/main/java/alluxio/grpc/GrpcSerializationUtils.java +++ b/core/common/src/main/java/alluxio/grpc/GrpcSerializationUtils.java @@ -137,7 +137,7 @@ public static ReadableBuffer getBufferFromStream(InputStream stream) { * Gets a Netty buffer directly from a gRPC ReadableBuffer. * * @param buffer the input buffer - * @return the raw ByteBuf + * @return the raw ByteBuf, or null if the ByteBuf cannot be extracted */ public static ByteBuf getByteBufFromReadableBuffer(ReadableBuffer buffer) { if (!sZeroCopyReceiveSupported) { diff --git a/core/common/src/main/java/alluxio/grpc/WriteRequestMarshaller.java b/core/common/src/main/java/alluxio/grpc/WriteRequestMarshaller.java index 8c00892a23a8..2ef3ec2e641e 100644 --- a/core/common/src/main/java/alluxio/grpc/WriteRequestMarshaller.java +++ b/core/common/src/main/java/alluxio/grpc/WriteRequestMarshaller.java @@ -61,8 +61,8 @@ protected ByteBuf[] serialize(WriteRequest message) throws IOException { chunkBuffer = new NettyDataBuffer( Unpooled.wrappedBuffer(message.getChunk().getData().asReadOnlyByteBuffer())); } - int size = message.getSerializedSize() - chunkBuffer.readableBytes(); - byte[] header = new byte[size]; + int headerSize = message.getSerializedSize() - chunkBuffer.readableBytes(); + byte[] header = new byte[headerSize]; CodedOutputStream stream = CodedOutputStream.newInstance(header); stream.writeTag(WriteRequest.CHUNK_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED); stream.writeUInt32NoTag(message.getChunk().getSerializedSize()); diff --git a/core/common/src/main/java/alluxio/network/protocol/databuffer/ByteArrayDataBuffer.java b/core/common/src/main/java/alluxio/network/protocol/databuffer/ByteArrayDataBuffer.java index 94a5092a3344..59ba7620f85f 100644 --- a/core/common/src/main/java/alluxio/network/protocol/databuffer/ByteArrayDataBuffer.java +++ b/core/common/src/main/java/alluxio/network/protocol/databuffer/ByteArrayDataBuffer.java @@ -53,23 +53,26 @@ public ByteBuffer getReadOnlyByteBuffer() { @Override public void readBytes(byte[] dst, int dstIndex, int length) { - throw new UnsupportedOperationException("ByteArrayDataBuffer#readBytes is not implemented."); + throw new UnsupportedOperationException( + "ByteArrayDataBuffer#readBytes(byte[] dst, int dstIndex, int length) is not implemented."); } @Override public void readBytes(OutputStream outputStream, int length) { - throw new UnsupportedOperationException("ByteArrayDataBuffer#readBytes is not implemented."); + throw new UnsupportedOperationException( + "ByteArrayDataBuffer#readBytes(OutputStream outputStream, int length) is not implemented."); } @Override public void readBytes(ByteBuffer outputBuf) { - throw new UnsupportedOperationException("ByteArrayDataBuffer#readBytes is not implemented."); + throw new UnsupportedOperationException( + "ByteArrayDataBuffer#readBytes(ByteBuffer outputBuf) is not implemented."); } @Override public int readableBytes() { throw new UnsupportedOperationException( - "ByteArrayDataBuffer#readableBytes is not implemented."); + "ByteArrayDataBuffer#readableBytes() is not implemented."); } @Override diff --git a/core/common/src/main/java/alluxio/worker/block/io/BlockWriter.java b/core/common/src/main/java/alluxio/worker/block/io/BlockWriter.java index 41fb9f4752ea..5d1d1213d114 100644 --- a/core/common/src/main/java/alluxio/worker/block/io/BlockWriter.java +++ b/core/common/src/main/java/alluxio/worker/block/io/BlockWriter.java @@ -37,7 +37,7 @@ public interface BlockWriter extends Closeable { /** * Appends buf.readableBytes() bytes to the end of this block writer from the given buf. * - * @param buf the byte buffer to hold the data + * @param buf the byte buffer that holds the data * @return the size of data that was appended in bytes */ long append(ByteBuf buf) throws IOException; @@ -45,7 +45,7 @@ public interface BlockWriter extends Closeable { /** * Appends buffer.readableBytes() bytes to the end of this block writer from the given buffer. * - * @param buffer the byte buffer to hold the data + * @param buffer the byte buffer that holds the data * @return the size of data that was appended in bytes */ long append(DataBuffer buffer) throws IOException; diff --git a/core/server/worker/src/main/java/alluxio/worker/grpc/AbstractWriteHandler.java b/core/server/worker/src/main/java/alluxio/worker/grpc/AbstractWriteHandler.java index dc473f328015..11148262614c 100644 --- a/core/server/worker/src/main/java/alluxio/worker/grpc/AbstractWriteHandler.java +++ b/core/server/worker/src/main/java/alluxio/worker/grpc/AbstractWriteHandler.java @@ -59,8 +59,11 @@ abstract class AbstractWriteHandler> { private static final Logger LOG = LoggerFactory.getLogger(AbstractWriteHandler.class); + /** The observer for sending response messages. */ private final StreamObserver mResponseObserver; + /** The executor for running write tasks asynchronously in the submission order. */ private final SerializingExecutor mSerializingExecutor; + /** The semaphore to control the number of write tasks queued up in the executor.*/ private final Semaphore mSemaphore = new Semaphore( ServerConfiguration.getInt(PropertyKey.WORKER_NETWORK_WRITER_BUFFER_SIZE_MESSAGES), true); From f21b56c5c8c95c127170e89c8225c547713cd08a Mon Sep 17 00:00:00 2001 From: bf8086 Date: Thu, 7 Mar 2019 11:08:46 -0800 Subject: [PATCH 3/4] Addressed comments --- .../worker/grpc/AbstractWriteHandler.java | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/core/server/worker/src/main/java/alluxio/worker/grpc/AbstractWriteHandler.java b/core/server/worker/src/main/java/alluxio/worker/grpc/AbstractWriteHandler.java index 11148262614c..86c207cc5b34 100644 --- a/core/server/worker/src/main/java/alluxio/worker/grpc/AbstractWriteHandler.java +++ b/core/server/worker/src/main/java/alluxio/worker/grpc/AbstractWriteHandler.java @@ -93,12 +93,7 @@ abstract class AbstractWriteHandler> { * @param writeRequest the request from the client */ public void write(WriteRequest writeRequest) { - try { - mSemaphore.acquire(); - } catch (InterruptedException e) { - LOG.warn("write data request {} is interrupted: {}", writeRequest, e.getMessage()); - abort(new Error(AlluxioStatusException.fromThrowable(e), true)); - Thread.currentThread().interrupt(); + if (!tryAcquireSemaphore()) { return; } mSerializingExecutor.execute(() -> { @@ -154,17 +149,15 @@ public void writeDataMessage(WriteRequest request, DataBuffer buffer) { "write request command should not come with data buffer"); Preconditions.checkState(buffer.readableBytes() > 0, "invalid data size from write request message"); - try { - mSemaphore.acquire(); - } catch (InterruptedException e) { - LOG.warn("write data request {} is interrupted: {}", request, e.getMessage()); - abort(new Error(AlluxioStatusException.fromThrowable(e), true)); - Thread.currentThread().interrupt(); + if (!tryAcquireSemaphore()) { return; } mSerializingExecutor.execute(() -> { - writeData(buffer); - mSemaphore.release(); + try { + writeData(buffer); + } finally { + mSemaphore.release(); + } }); } @@ -221,6 +214,19 @@ public void onError(Throwable cause) { }); } + private boolean tryAcquireSemaphore() { + try { + mSemaphore.acquire(); + } catch (InterruptedException e) { + LOG.warn("write data request {} is interrupted: {}", + mContext == null ? "unknown" : mContext.getRequest(), e.getMessage()); + abort(new Error(AlluxioStatusException.fromThrowable(e), true)); + Thread.currentThread().interrupt(); + return false; + } + return true; + } + /** * Validates a block write request. * From 111d5a34b9db271f6efe993d708767d933f55436 Mon Sep 17 00:00:00 2001 From: bf8086 Date: Fri, 8 Mar 2019 11:16:24 -0800 Subject: [PATCH 4/4] move ByteArrayDataBuffer --- .../alluxio/network/protocol/databuffer/ByteArrayDataBuffer.java | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename core/{common/src/main => server/worker/src/test}/java/alluxio/network/protocol/databuffer/ByteArrayDataBuffer.java (100%) diff --git a/core/common/src/main/java/alluxio/network/protocol/databuffer/ByteArrayDataBuffer.java b/core/server/worker/src/test/java/alluxio/network/protocol/databuffer/ByteArrayDataBuffer.java similarity index 100% rename from core/common/src/main/java/alluxio/network/protocol/databuffer/ByteArrayDataBuffer.java rename to core/server/worker/src/test/java/alluxio/network/protocol/databuffer/ByteArrayDataBuffer.java