From ca7c8c53d921da2717426c4560f07169d69cf4f7 Mon Sep 17 00:00:00 2001 From: Esun Kim Date: Wed, 30 Jun 2021 14:33:34 -0700 Subject: [PATCH] Added zero-copy --- end2end-test-examples/gcs/build.gradle | 5 +- .../gcs/src/main/java/io/grpc/gcs/Args.java | 3 + .../src/main/java/io/grpc/gcs/GrpcClient.java | 59 +++++++-- .../grpc/gcs/ZeroCopyMessageMarshaller.java | 112 ++++++++++++++++++ .../io/grpc/gcs/ZeroCopyReadinessChecker.java | 44 +++++++ 5 files changed, 212 insertions(+), 11 deletions(-) create mode 100644 end2end-test-examples/gcs/src/main/java/io/grpc/gcs/ZeroCopyMessageMarshaller.java create mode 100644 end2end-test-examples/gcs/src/main/java/io/grpc/gcs/ZeroCopyReadinessChecker.java diff --git a/end2end-test-examples/gcs/build.gradle b/end2end-test-examples/gcs/build.gradle index 5392a3f7..ef69f5d4 100644 --- a/end2end-test-examples/gcs/build.gradle +++ b/end2end-test-examples/gcs/build.gradle @@ -10,12 +10,11 @@ version '1.0-SNAPSHOT' sourceCompatibility = 1.8 def gcsioVersion = '2.2.3-SNAPSHOT' -def grpcVersion = '1.38.0' -def protobufVersion = '3.17.0' +def grpcVersion = '1.39.0' +def protobufVersion = '3.17.3' def protocVersion = protobufVersion def conscryptVersion = '2.5.1' - repositories { mavenLocal() mavenCentral() diff --git a/end2end-test-examples/gcs/src/main/java/io/grpc/gcs/Args.java b/end2end-test-examples/gcs/src/main/java/io/grpc/gcs/Args.java index bf359db5..ef4a7262 100644 --- a/end2end-test-examples/gcs/src/main/java/io/grpc/gcs/Args.java +++ b/end2end-test-examples/gcs/src/main/java/io/grpc/gcs/Args.java @@ -36,6 +36,7 @@ public class Args { final boolean checksum; final boolean verboseLog; final boolean verboseResult; + final int zeroCopy; // 0=auto, 1=on, -1=off Args(String[] args) throws ArgumentParserException { ArgumentParser parser = @@ -63,6 +64,7 @@ public class Args { parser.addArgument("--checksum").type(Boolean.class).setDefault(false); parser.addArgument("--verboseLog").type(Boolean.class).setDefault(false); parser.addArgument("--verboseResult").type(Boolean.class).setDefault(false); + parser.addArgument("--zeroCopy").type(Integer.class).setDefault(0); Namespace ns = parser.parseArgs(args); @@ -86,5 +88,6 @@ public class Args { checksum = ns.getBoolean("checksum"); verboseLog = ns.getBoolean("verboseLog"); verboseResult = ns.getBoolean("verboseResult"); + zeroCopy = ns.getInt("zeroCopy"); } } diff --git a/end2end-test-examples/gcs/src/main/java/io/grpc/gcs/GrpcClient.java b/end2end-test-examples/gcs/src/main/java/io/grpc/gcs/GrpcClient.java index 06be414f..6cc55fbe 100644 --- a/end2end-test-examples/gcs/src/main/java/io/grpc/gcs/GrpcClient.java +++ b/end2end-test-examples/gcs/src/main/java/io/grpc/gcs/GrpcClient.java @@ -24,20 +24,32 @@ import io.grpc.auth.MoreCallCredentials; import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; import io.grpc.stub.StreamObserver; +import io.grpc.MethodDescriptor; import java.io.IOException; import java.lang.reflect.Field; -import java.util.Iterator; -import java.util.List; -import java.util.Random; +import java.io.InputStream; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.Iterator; +import java.util.List; import java.util.logging.Logger; +import java.util.NoSuchElementException; +import java.util.Random; public class GrpcClient { private static final Logger logger = Logger.getLogger(GrpcClient.class.getName()); + // ZeroCopy version of GetObjectMedia Method + private static final ZeroCopyMessageMarshaller getObjectMediaResponseMarshaller = + new ZeroCopyMessageMarshaller(GetObjectMediaResponse.getDefaultInstance()); + private static final MethodDescriptor getObjectMediaMethod = + StorageGrpc.getGetObjectMediaMethod() + .toBuilder().setResponseMarshaller(getObjectMediaResponseMarshaller) + .build(); + private final boolean useZeroCopy; + private ManagedChannel[] channels; private Args args; private GoogleCredentials creds; @@ -50,7 +62,7 @@ public GrpcClient(Args args) { this.creds = GoogleCredentials.getApplicationDefault(); } catch (IOException e) { e.printStackTrace(); - return; + throw new RuntimeException(e); } ManagedChannelBuilder channelBuilder; @@ -95,6 +107,13 @@ public GrpcClient(Args args) { for (int i = 0; i < args.threads; i++) { channels[i] = channelBuilder.build(); } + + if (args.zeroCopy == 0) { + useZeroCopy = ZeroCopyReadinessChecker.isReady(); + } else { + useZeroCopy = args.zeroCopy > 0; + } + logger.info("useZeroCopy: " + useZeroCopy); } public void startCalls(ResultTable results) throws InterruptedException { @@ -162,6 +181,7 @@ public void startCalls(ResultTable results) throws InterruptedException { } private void makeMediaRequest(ManagedChannel channel, ResultTable results) { + StorageGrpc.StorageBlockingStub blockingStub = StorageGrpc.newBlockingStub(channel).withCallCredentials( MoreCallCredentials.from(creds.createScoped(SCOPE))); @@ -172,12 +192,35 @@ private void makeMediaRequest(ManagedChannel channel, ResultTable results) { GetObjectMediaRequest mediaRequest = GetObjectMediaRequest.newBuilder().setBucket(args.bkt).setObject(args.obj).build(); - + byte[] scratch = new byte[4*1024*1024]; for (int i = 0; i < args.calls; i++) { long start = System.currentTimeMillis(); - Iterator resIterator = blockingStub.getObjectMedia(mediaRequest); - while (resIterator.hasNext()) { - GetObjectMediaResponse res = resIterator.next(); + Iterator resIterator; + if (useZeroCopy) { + resIterator = io.grpc.stub.ClientCalls.blockingServerStreamingCall( + blockingStub.getChannel(), getObjectMediaMethod, blockingStub.getCallOptions(), mediaRequest); + } else { + resIterator = blockingStub.getObjectMedia(mediaRequest); + } + try { + while (true) { + GetObjectMediaResponse res = resIterator.next(); + InputStream stream = getObjectMediaResponseMarshaller.popStream(res); + // Just copy to scratch memory to ensure its data is consumed. + ByteString content = res.getChecksummedData().getContent(); + content.copyTo(scratch, 0); + // When zero-copy mashaller is used, the stream that backs GetObjectMediaResponse + // should be closed when the mssage is no longed needed so that all buffers in the + // stream can be reclaimed. + if (stream != null) { + try { + stream.close(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + } catch (NoSuchElementException e) { } long dur = System.currentTimeMillis() - start; results.reportResult(dur); diff --git a/end2end-test-examples/gcs/src/main/java/io/grpc/gcs/ZeroCopyMessageMarshaller.java b/end2end-test-examples/gcs/src/main/java/io/grpc/gcs/ZeroCopyMessageMarshaller.java new file mode 100644 index 00000000..6ddfea1c --- /dev/null +++ b/end2end-test-examples/gcs/src/main/java/io/grpc/gcs/ZeroCopyMessageMarshaller.java @@ -0,0 +1,112 @@ +package io.grpc.gcs; + +import com.google.protobuf.ByteString; +import com.google.protobuf.CodedInputStream; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.MessageLite; +import com.google.protobuf.Parser; +import com.google.protobuf.UnsafeByteOperations; +import io.grpc.Detachable; +import io.grpc.HasByteBuffer; +import io.grpc.KnownLength; +import io.grpc.MethodDescriptor.PrototypeMarshaller; +import io.grpc.protobuf.lite.ProtoLiteUtils; +import io.grpc.Status; +import io.grpc.MethodDescriptor; +import java.io.InputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Map; + +// Custom gRPC marshaller to use zero memory copy feature of gRPC when deserializing messages. +// This achieves zero-copy by deserializing proto messages pointing to the buffers in the input +// stream to avoid memory copy so stream should live as long as the message can be referenced. +// Hence, it exposes the input stream to applications (through popStream) and applications are +// responsible to close it when it's no longer needed. Otherwise, it'd cause memory leak. +class ZeroCopyMessageMarshaller implements PrototypeMarshaller { + private Map unclosedStreams = Collections.synchronizedMap(new IdentityHashMap<>()); + private final Parser parser; + private final PrototypeMarshaller baseMarshaller; + + ZeroCopyMessageMarshaller(T defaultInstance) { + parser = (Parser) defaultInstance.getParserForType(); + baseMarshaller = (PrototypeMarshaller) ProtoLiteUtils.marshaller(defaultInstance); + } + + @Override + public Class getMessageClass() { + return baseMarshaller.getMessageClass(); + } + + @Override + public T getMessagePrototype() { + return baseMarshaller.getMessagePrototype(); + } + + @Override + public InputStream stream(T value) { + return baseMarshaller.stream(value); + } + + @Override + public T parse(InputStream stream) { + CodedInputStream cis = null; + try { + if (stream instanceof KnownLength) { + int size = stream.available(); + if (stream instanceof Detachable && ((HasByteBuffer) stream).byteBufferSupported()) { + // Stream is now detached here and should be closed later. + stream = ((Detachable) stream).detach(); + // This mark call is to keep buffer while traversing buffers using skip. + stream.mark(size); + List byteStrings = new ArrayList<>(); + while (stream.available() != 0) { + ByteBuffer buffer = ((HasByteBuffer) stream).getByteBuffer(); + byteStrings.add(UnsafeByteOperations.unsafeWrap(buffer)); + stream.skip(buffer.remaining()); + } + stream.reset(); + cis = ByteString.copyFrom(byteStrings).newCodedInput(); + cis.enableAliasing(true); + cis.setSizeLimit(Integer.MAX_VALUE); + } + } + } catch (IOException e) { + throw new RuntimeException(e); + } + if (cis != null) { + // fast path (no memory copy) + T message; + try { + message = parseFrom(cis); + } catch (InvalidProtocolBufferException ipbe) { + throw Status.INTERNAL.withDescription("Invalid protobuf byte sequence").withCause(ipbe).asRuntimeException(); + }edStreams.put(message, stream); + return message; + } else { + // slow path + return baseMarshaller.parse(stream); + } + } + + private T parseFrom(CodedInputStream stream) throws InvalidProtocolBufferException { + T message = parser.parseFrom(stream); + try { + stream.checkLastTagWas(0); + return message; + } catch (InvalidProtocolBufferException e) { + e.setUnfinishedMessage(message); + throw e; + } + } + + // Application needs to call this function to get the stream for the message and + // call stream.close() function to return it to the pool. + public InputStream popStream(T message) { + return unclosedStreams.remove(message); + } +} diff --git a/end2end-test-examples/gcs/src/main/java/io/grpc/gcs/ZeroCopyReadinessChecker.java b/end2end-test-examples/gcs/src/main/java/io/grpc/gcs/ZeroCopyReadinessChecker.java new file mode 100644 index 00000000..825545a8 --- /dev/null +++ b/end2end-test-examples/gcs/src/main/java/io/grpc/gcs/ZeroCopyReadinessChecker.java @@ -0,0 +1,44 @@ +package io.grpc.gcs; + +import com.google.protobuf.MessageLite; +import io.grpc.KnownLength; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.security.Provider; + +public class ZeroCopyReadinessChecker { + private static final boolean isZeroCopyReady; + + static { + // Check whether io.grpc.Detachable exists? + boolean detachableClassExists = false; + try { + // Try to load Detachable interface in the package where KnownLength is in. + // This can be done directly by looking up io.grpc.Detachable but rather + // done indirectly to handle the case where gRPC is being shaded in a + // different package. + String knownLengthClassName = KnownLength.class.getName(); + String detachableClassName = knownLengthClassName.substring(0, knownLengthClassName.lastIndexOf('.') + 1) + + "Detachable"; + Class detachableClass = Class.forName(detachableClassName); + detachableClassExists = (detachableClass != null); + } catch (ClassNotFoundException ex) { + } + // Check whether com.google.protobuf.UnsafeByteOperations exists? + boolean unsafeByteOperationsClassExists = false; + try { + // Same above + String messageLiteClassName = MessageLite.class.getName(); + String unsafeByteOperationsClassName = messageLiteClassName.substring(0, + messageLiteClassName.lastIndexOf('.') + 1) + "UnsafeByteOperations"; + Class unsafeByteOperationsClass = Class.forName(unsafeByteOperationsClassName); + unsafeByteOperationsClassExists = (unsafeByteOperationsClass != null); + } catch (ClassNotFoundException ex) { + } + isZeroCopyReady = detachableClassExists && unsafeByteOperationsClassExists; + } + + public static boolean isReady() { + return isZeroCopyReady; + } +}