Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,23 @@
*/
@NotThreadSafe
public class DataMessageClientResponseObserver<ReqT, RespT>
implements ClientResponseObserver<ReqT, RespT>, DataMessageMarshallerProvider<RespT> {
extends DataMessageMarshallerProvider<ReqT, RespT>
implements ClientResponseObserver<ReqT, RespT> {
private static final Logger LOG =
LoggerFactory.getLogger(DataMessageClientResponseObserver.class);

private final StreamObserver<RespT> mObserver;
private final DataMessageMarshaller<RespT> mMarshaller;

/**
* @param observer the original response observer
* @param marshaller the marshaller for the response
* @param requestMarshaller the marshaller for the request
* @param responseMarshaller the marshaller for the response
*/
public DataMessageClientResponseObserver(StreamObserver<RespT> observer,
DataMessageMarshaller<RespT> marshaller) {
DataMessageMarshaller<ReqT> requestMarshaller,
DataMessageMarshaller<RespT> responseMarshaller) {
super(requestMarshaller, responseMarshaller);
mObserver = observer;
mMarshaller = marshaller;
}

@Override
Expand All @@ -70,9 +72,4 @@ public void beforeStart(ClientCallStreamObserver<ReqT> requestStream) {
LOG.warn("{} does not implement ClientResponseObserver:beforeStart", mObserver);
}
}

@Override
public DataMessageMarshaller<RespT> getMarshaller() {
return mMarshaller;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import alluxio.grpc.GrpcSerializationUtils;
import alluxio.util.network.NettyUtils;

import com.google.common.base.Preconditions;
import com.google.common.io.Closer;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
Expand Down Expand Up @@ -114,14 +115,29 @@ public void close() throws IOException {

@Override
public StreamObserver<WriteRequest> writeBlock(StreamObserver<WriteResponse> responseObserver) {
return mStreamingAsyncStub.writeBlock(responseObserver);
if (responseObserver instanceof DataMessageMarshallerProvider) {
DataMessageMarshaller<WriteRequest> marshaller =
((DataMessageMarshallerProvider<WriteRequest, WriteResponse>) responseObserver)
.getRequestMarshaller();
Preconditions.checkNotNull(marshaller);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't a DataMessageMarshallerProvider provide null sometimes?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DataMessageMarshallerProvider is a generic holder of request marshaller and response marshaller. Depending on the situation, the caller most likely only provide one of the marshallers. In this writeBlock method, it is required that the user should provide the request marshaller for zero copy to work correctly, hence the check.

If a user does not want to use zero copy at all, they should not pass in a DataMessageMarshallerProvider.

return mStreamingAsyncStub
.withOption(GrpcSerializationUtils.OVERRIDDEN_METHOD_DESCRIPTOR,
BlockWorkerGrpc.getWriteBlockMethod().toBuilder()
.setRequestMarshaller(marshaller)
.build())
.writeBlock(responseObserver);
} else {
return mStreamingAsyncStub.writeBlock(responseObserver);
}
}

@Override
public StreamObserver<ReadRequest> readBlock(StreamObserver<ReadResponse> responseObserver) {
if (responseObserver instanceof DataMessageMarshallerProvider) {
DataMessageMarshaller<ReadResponse> marshaller =
((DataMessageMarshallerProvider<ReadResponse>) responseObserver).getMarshaller();
((DataMessageMarshallerProvider<ReadRequest, ReadResponse>) responseObserver)
.getResponseMarshaller();
Preconditions.checkNotNull(marshaller);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't this be null sometimes?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above, if user does not want to use zero copy for reading blocks, they should not pass DataMessageMarshallerProvider instead of giving a null response marshaller.

return mStreamingAsyncStub
.withOption(GrpcSerializationUtils.OVERRIDDEN_METHOD_DESCRIPTOR,
BlockWorkerGrpc.getReadBlockMethod().toBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import alluxio.grpc.DataMessageMarshaller;
import alluxio.network.protocol.databuffer.DataBuffer;

import com.google.common.base.Preconditions;
import io.grpc.stub.StreamObserver;

import java.io.IOException;
Expand All @@ -31,32 +32,40 @@
*/
@NotThreadSafe
public class GrpcDataMessageBlockingStream<ReqT, ResT> extends GrpcBlockingStream<ReqT, ResT> {
private final DataMessageMarshaller<ResT> mMarshaller;
private final DataMessageMarshaller<ReqT> mRequestMarshaller;
private final DataMessageMarshaller<ResT> mResponseMarshaller;

/**
* @param rpcFunc the gRPC bi-directional stream stub function
* @param bufferSize maximum number of incoming messages the buffer can hold
* @param description description of this stream
* @param deserializer custom deserializer for the response
* @param requestMarshaller the marshaller for the request
* @param responseMarshaller the marshaller for the response
*/
public GrpcDataMessageBlockingStream(Function<StreamObserver<ResT>, StreamObserver<ReqT>> rpcFunc,
int bufferSize, String description, DataMessageMarshaller<ResT> deserializer) {
int bufferSize, String description, DataMessageMarshaller<ReqT> requestMarshaller,
DataMessageMarshaller<ResT> responseMarshaller) {
super((resObserver) -> {
DataMessageClientResponseObserver<ReqT, ResT> newObserver =
new DataMessageClientResponseObserver<>(resObserver, deserializer);
new DataMessageClientResponseObserver<>(resObserver, requestMarshaller,
responseMarshaller);
StreamObserver<ReqT> requestObserver = rpcFunc.apply(newObserver);
return requestObserver;
}, bufferSize, description);
mMarshaller = deserializer;
mRequestMarshaller = requestMarshaller;
mResponseMarshaller = responseMarshaller;
}

@Override
public ResT receive(long timeoutMs) throws IOException {
if (mResponseMarshaller == null) {
return super.receive(timeoutMs);
}
DataMessage<ResT, DataBuffer> message = receiveDataMessage(timeoutMs);
if (message == null) {
return null;
}
return mMarshaller.combineData(message);
return mResponseMarshaller.combineData(message);
}

/**
Expand All @@ -70,16 +79,38 @@ public ResT receive(long timeoutMs) throws IOException {
* @throws IOException if any error occurs
*/
public DataMessage<ResT, DataBuffer> receiveDataMessage(long timeoutMs) throws IOException {
Preconditions.checkNotNull(mResponseMarshaller,
"Cannot retrieve data message without a response marshaller.");
ResT response = super.receive(timeoutMs);
if (response == null) {
return null;
}
DataBuffer buffer = mMarshaller.pollBuffer(response);
DataBuffer buffer = mResponseMarshaller.pollBuffer(response);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mResponseMarshaller can be null?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

User should not call this method if they do not have a response marshaller. I will add a check.

return new DataMessage<>(response, buffer);
}

/**
* Sends a request. Will wait until the stream is ready before sending or timeout if the
* given timeout is reached.
*
* @param message the request message with {@link DataBuffer attached}
* @param timeoutMs maximum wait time before throwing a {@link DeadlineExceededException}
* @throws IOException if any error occurs
*/
public void sendDataMessage(DataMessage<ReqT, DataBuffer> message, long timeoutMs)
throws IOException {
if (mRequestMarshaller != null) {
mRequestMarshaller.offerBuffer(message.getBuffer(), message.getMessage());
}
super.send(message.getMessage(), timeoutMs);
}

@Override
public void waitForComplete(long timeoutMs) throws IOException {
if (mResponseMarshaller == null) {
super.waitForComplete(timeoutMs);
return;
}
DataMessage<ResT, DataBuffer> message;
while (!isCanceled() && (message = receiveDataMessage(timeoutMs)) != null) {
if (message.getBuffer() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ private GrpcDataReader(FileSystemContext context, WorkerNetAddress address,
.add("request", mReadRequest)
.add("address", address)
.toString(),
mMarshaller);
null, mMarshaller);
} else {
mStream = new GrpcBlockingStream<>(mClient::readBlock, mReaderBufferSizeMessages,
MoreObjects.toStringHelper(this)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
import alluxio.conf.PropertyKey;
import alluxio.exception.status.UnavailableException;
import alluxio.grpc.Chunk;
import alluxio.grpc.DataMessage;
import alluxio.grpc.RequestType;
import alluxio.grpc.WriteRequest;
import alluxio.grpc.WriteRequestCommand;
import alluxio.grpc.WriteRequestMarshaller;
import alluxio.grpc.WriteResponse;
import alluxio.network.protocol.databuffer.NettyDataBuffer;
import alluxio.proto.dataserver.Protocol;
import alluxio.util.proto.ProtoUtils;
import alluxio.wire.WorkerNetAddress;
Expand Down Expand Up @@ -69,6 +72,7 @@ public final class GrpcDataWriter implements DataWriter {
private final WriteRequestCommand mPartialRequest;
private final long mChunkSize;
private final GrpcBlockingStream<WriteRequest, WriteResponse> mStream;
private final WriteRequestMarshaller mMarshaller;

/**
* The next pos to queue to the buffer.
Expand Down Expand Up @@ -152,11 +156,21 @@ private GrpcDataWriter(FileSystemContext context, final WorkerNetAddress address
mPartialRequest = builder.buildPartial();
mChunkSize = chunkSize;
mClient = client;
mStream = new GrpcBlockingStream<>(mClient::writeBlock, mWriterBufferSizeMessages,
MoreObjects.toStringHelper(this)
.add("request", mPartialRequest)
.add("address", address)
.toString());
mMarshaller = new WriteRequestMarshaller();
if (conf.getBoolean(PropertyKey.USER_NETWORK_ZEROCOPY_ENABLED)) {
mStream = new GrpcDataMessageBlockingStream<>(
mClient::writeBlock, mWriterBufferSizeMessages,
MoreObjects.toStringHelper(this)
.add("request", mPartialRequest)
.add("address", address)
.toString(), mMarshaller, null);
} else {
mStream = new GrpcBlockingStream<>(mClient::writeBlock, mWriterBufferSizeMessages,
MoreObjects.toStringHelper(this)
.add("request", mPartialRequest)
.add("address", address)
.toString());
}
mStream.send(WriteRequest.newBuilder().setCommand(mPartialRequest.toBuilder()).build(),
mDataTimeoutMs);
}
Expand All @@ -170,11 +184,16 @@ public long pos() {
public void writeChunk(final ByteBuf buf) throws IOException {
mPosToQueue += buf.readableBytes();
try {
mStream.send(WriteRequest.newBuilder().setCommand(mPartialRequest).setChunk(
WriteRequest request = WriteRequest.newBuilder().setCommand(mPartialRequest).setChunk(
Chunk.newBuilder()
.setData(UnsafeByteOperations.unsafeWrap(buf.nioBuffer()))
.build()).build(),
mDataTimeoutMs);
.build()).build();
if (mStream instanceof GrpcDataMessageBlockingStream) {
((GrpcDataMessageBlockingStream<WriteRequest, WriteResponse>) mStream)
.sendDataMessage(new DataMessage<>(request, new NettyDataBuffer(buf)), mDataTimeoutMs);
} else {
mStream.send(request, mDataTimeoutMs);
}
} finally {
buf.release();
}
Expand Down
19 changes: 19 additions & 0 deletions core/common/src/main/java/alluxio/conf/PropertyKey.java
Original file line number Diff line number Diff line change
Expand Up @@ -1989,6 +1989,21 @@ public String toString() {
.setConsistencyCheckLevel(ConsistencyCheckLevel.WARN)
.setScope(Scope.WORKER)
.build();
public static final PropertyKey WORKER_NETWORK_BLOCK_WRITER_THREADS_MAX =
new Builder(Name.WORKER_NETWORK_BLOCK_WRITER_THREADS_MAX)
.setDefaultValue(1024)
.setDescription("The maximum number of threads used to write blocks in the data server.")
.setConsistencyCheckLevel(ConsistencyCheckLevel.WARN)
.setScope(Scope.WORKER)
.build();
public static final PropertyKey WORKER_NETWORK_WRITER_BUFFER_SIZE_MESSAGES =
new Builder(Name.WORKER_NETWORK_WRITER_BUFFER_SIZE_MESSAGES)
.setDefaultValue(8)
.setDescription("When a client writes to a remote worker, the maximum number of "
+ "data messages to buffer by the server for each request.")
.setConsistencyCheckLevel(ConsistencyCheckLevel.WARN)
.setScope(Scope.WORKER)
.build();
public static final PropertyKey WORKER_NETWORK_FLOWCONTROL_WINDOW =
new Builder(Name.WORKER_NETWORK_FLOWCONTROL_WINDOW)
.setDefaultValue("2MB")
Expand Down Expand Up @@ -3821,6 +3836,10 @@ public static final class Name {
"alluxio.worker.network.async.cache.manager.threads.max";
public static final String WORKER_NETWORK_BLOCK_READER_THREADS_MAX =
"alluxio.worker.network.block.reader.threads.max";
public static final String WORKER_NETWORK_BLOCK_WRITER_THREADS_MAX =
"alluxio.worker.network.block.writer.threads.max";
public static final String WORKER_NETWORK_WRITER_BUFFER_SIZE_MESSAGES =
"alluxio.worker.network.writer.buffer.size.messages";
public static final String WORKER_NETWORK_FLOWCONTROL_WINDOW =
"alluxio.worker.network.flowcontrol.window";
public static final String WORKER_NETWORK_KEEPALIVE_TIME_MS =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,36 @@
package alluxio.grpc;

/**
* A provider of {@link DataMessageMarshaller}.
* A provider of {@link DataMessageMarshaller} for a gRPC call.
*
* @param <T> type of the message
* @param <ReqT> type of the request message
* @param <ResT> type of the response message
*/
public interface DataMessageMarshallerProvider<T> {
DataMessageMarshaller<T> getMarshaller();
public class DataMessageMarshallerProvider<ReqT, ResT> {
private final DataMessageMarshaller<ReqT> mRequestMarshaller;
private final DataMessageMarshaller<ResT> mResponseMarshaller;

/**
* @param requestMarshaller the marshaller for the request, or null if not provided
* @param responseMarshaller the marshaller for the response, or null if not provided
*/
public DataMessageMarshallerProvider(DataMessageMarshaller<ReqT> requestMarshaller,
DataMessageMarshaller<ResT> responseMarshaller) {
mRequestMarshaller = requestMarshaller;
mResponseMarshaller = responseMarshaller;
}

/**
* @return the request marshaller
*/
public DataMessageMarshaller<ReqT> getRequestMarshaller() {
return mRequestMarshaller;
}

/**
* @return the response marshaller
*/
public DataMessageMarshaller<ResT> getResponseMarshaller() {
return mResponseMarshaller;
}
}
Loading