diff --git a/src/main/java/com/google/devtools/build/lib/remote/BUILD b/src/main/java/com/google/devtools/build/lib/remote/BUILD index 7f8d359926902a..06fb72dffa432b 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/BUILD +++ b/src/main/java/com/google/devtools/build/lib/remote/BUILD @@ -97,7 +97,6 @@ java_library( "//src/main/java/com/google/devtools/build/lib/vfs:pathfragment", "//src/main/java/com/google/devtools/common/options", "//src/main/protobuf:failure_details_java_proto", - "//third_party:apache_commons_compress", "//third_party:auth", "//third_party:caffeine", "//third_party:flogger", diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java index 293562d0043e6d..5dd7dc03428ca6 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java +++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java @@ -37,6 +37,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.flogger.GoogleLogger; +import com.google.common.io.CountingOutputStream; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.MoreExecutors; @@ -67,10 +68,8 @@ import java.util.List; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; import javax.annotation.Nullable; -import org.apache.commons.compress.utils.CountingOutputStream; /** A RemoteActionCache implementation that uses gRPC calls to a remote cache server. */ @ThreadSafe @@ -303,7 +302,7 @@ public ListenableFuture uploadActionResult( public ListenableFuture downloadBlob( RemoteActionExecutionContext context, Digest digest, OutputStream out) { if (digest.getSizeBytes() == 0) { - return Futures.immediateFuture(null); + return Futures.immediateVoidFuture(); } @Nullable Supplier digestSupplier = null; @@ -313,18 +312,7 @@ public ListenableFuture downloadBlob( out = digestOut; } - CountingOutputStream outputStream; - if (options.cacheCompression) { - try { - outputStream = new ZstdDecompressingOutputStream(out); - } catch (IOException e) { - return Futures.immediateFailedFuture(e); - } - } else { - outputStream = new CountingOutputStream(out); - } - - return downloadBlob(context, digest, outputStream, digestSupplier); + return downloadBlob(context, digest, new CountingOutputStream(out), digestSupplier); } private ListenableFuture downloadBlob( @@ -332,7 +320,6 @@ private ListenableFuture downloadBlob( Digest digest, CountingOutputStream out, @Nullable Supplier digestSupplier) { - AtomicLong offset = new AtomicLong(0); ProgressiveBackoff progressiveBackoff = new ProgressiveBackoff(retrier::newBackoff); ListenableFuture downloadFuture = Utils.refreshIfUnauthenticatedAsync( @@ -343,7 +330,6 @@ private ListenableFuture downloadBlob( channel -> requestRead( context, - offset, progressiveBackoff, digest, out, @@ -370,20 +356,25 @@ public static String getResourceName(String instanceName, Digest digest, boolean private ListenableFuture requestRead( RemoteActionExecutionContext context, - AtomicLong offset, ProgressiveBackoff progressiveBackoff, Digest digest, - CountingOutputStream out, + CountingOutputStream rawOut, @Nullable Supplier digestSupplier, Channel channel) { String resourceName = getResourceName(options.remoteInstanceName, digest, options.cacheCompression); SettableFuture future = SettableFuture.create(); + OutputStream out; + try { + out = options.cacheCompression ? new ZstdDecompressingOutputStream(rawOut) : rawOut; + } catch (IOException e) { + return Futures.immediateFailedFuture(e); + } bsAsyncStub(context, channel) .read( ReadRequest.newBuilder() .setResourceName(resourceName) - .setReadOffset(offset.get()) + .setReadOffset(rawOut.getCount()) .build(), new StreamObserver() { @@ -392,7 +383,6 @@ public void onNext(ReadResponse readResponse) { ByteString data = readResponse.getData(); try { data.writeTo(out); - offset.set(out.getBytesWritten()); } catch (IOException e) { // Cancel the call. throw new RuntimeException(e); @@ -403,7 +393,7 @@ public void onNext(ReadResponse readResponse) { @Override public void onError(Throwable t) { - if (offset.get() == digest.getSizeBytes()) { + if (rawOut.getCount() == digest.getSizeBytes()) { // If the file was fully downloaded, it doesn't matter if there was an error at // the end of the stream. logger.atInfo().withCause(t).log( @@ -411,6 +401,7 @@ public void onError(Throwable t) { onCompleted(); return; } + releaseOut(); Status status = Status.fromThrowable(t); if (status.getCode() == Status.Code.NOT_FOUND) { future.setException(new CacheNotFoundException(digest)); @@ -426,12 +417,24 @@ public void onCompleted() { Utils.verifyBlobContents(digest, digestSupplier.get()); } out.flush(); - future.set(offset.get()); + future.set(rawOut.getCount()); } catch (IOException e) { future.setException(e); } catch (RuntimeException e) { logger.atWarning().withCause(e).log("Unexpected exception"); future.setException(e); + } finally { + releaseOut(); + } + } + + private void releaseOut() { + if (out instanceof ZstdDecompressingOutputStream) { + try { + ((ZstdDecompressingOutputStream) out).closeShallow(); + } catch (IOException e) { + logger.atWarning().withCause(e).log("failed to cleanly close output stream"); + } } } }); diff --git a/src/main/java/com/google/devtools/build/lib/remote/zstd/BUILD b/src/main/java/com/google/devtools/build/lib/remote/zstd/BUILD index 6108cddc569f03..75691a65473044 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/zstd/BUILD +++ b/src/main/java/com/google/devtools/build/lib/remote/zstd/BUILD @@ -16,7 +16,6 @@ java_library( name = "zstd", srcs = glob(["*.java"]), deps = [ - "//third_party:apache_commons_compress", "//third_party:guava", "//third_party/protobuf:protobuf_java", "@zstd-jni", diff --git a/src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStream.java b/src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStream.java index ad1c333320964c..9fdb6ae4fdaa89 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStream.java +++ b/src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStream.java @@ -13,35 +13,35 @@ // limitations under the License. package com.google.devtools.build.lib.remote.zstd; -import com.github.luben.zstd.ZstdInputStream; +import com.github.luben.zstd.ZstdInputStreamNoFinalizer; import com.google.protobuf.ByteString; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import org.apache.commons.compress.utils.CountingOutputStream; -/** A {@link CountingOutputStream} that use zstd to decompress the content. */ -public class ZstdDecompressingOutputStream extends CountingOutputStream { +/** An {@link OutputStream} that use zstd to decompress the content. */ +public final class ZstdDecompressingOutputStream extends OutputStream { + private final OutputStream out; private ByteArrayInputStream inner; - private final ZstdInputStream zis; + private final ZstdInputStreamNoFinalizer zis; public ZstdDecompressingOutputStream(OutputStream out) throws IOException { - super(out); + this.out = out; zis = - new ZstdInputStream( - new InputStream() { - @Override - public int read() { - return inner.read(); - } - - @Override - public int read(byte[] b, int off, int len) { - return inner.read(b, off, len); - } - }); - zis.setContinuous(true); + new ZstdInputStreamNoFinalizer( + new InputStream() { + @Override + public int read() { + return inner.read(); + } + + @Override + public int read(byte[] b, int off, int len) { + return inner.read(b, off, len); + } + }) + .setContinuous(true); } @Override @@ -58,6 +58,19 @@ public void write(byte[] b) throws IOException { public void write(byte[] b, int off, int len) throws IOException { inner = new ByteArrayInputStream(b, off, len); byte[] data = ByteString.readFrom(zis).toByteArray(); - super.write(data, 0, data.length); + out.write(data, 0, data.length); + } + + @Override + public void close() throws IOException { + closeShallow(); + out.close(); + } + + /** + * Free resources related to decompression without closing the underlying {@link OutputStream}. + */ + public void closeShallow() throws IOException { + zis.close(); } } diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTestExtra.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTestExtra.java index 51effa08170977..80d55edc7a0677 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTestExtra.java +++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTestExtra.java @@ -15,14 +15,12 @@ import static com.google.common.truth.Truth.assertThat; import static java.nio.charset.StandardCharsets.UTF_8; -import static org.mockito.ArgumentMatchers.any; import build.bazel.remote.execution.v2.Digest; import com.github.luben.zstd.Zstd; import com.google.bytestream.ByteStreamGrpc.ByteStreamImplBase; import com.google.bytestream.ByteStreamProto.ReadRequest; import com.google.bytestream.ByteStreamProto.ReadResponse; -import com.google.devtools.build.lib.remote.Retrier.Backoff; import com.google.devtools.build.lib.remote.options.RemoteOptions; import com.google.devtools.common.options.Options; import com.google.protobuf.ByteString; @@ -31,7 +29,6 @@ import java.io.IOException; import java.util.Arrays; import org.junit.Test; -import org.mockito.Mockito; /** Extra tests for {@link GrpcCacheClient} that are not tested internally. */ public class GrpcCacheClientTestExtra extends GrpcCacheClientTest { @@ -39,30 +36,43 @@ public class GrpcCacheClientTestExtra extends GrpcCacheClientTest { @Test public void compressedDownloadBlobIsRetriedWithProgress() throws IOException, InterruptedException { - Backoff mockBackoff = Mockito.mock(Backoff.class); RemoteOptions options = Options.getDefaults(RemoteOptions.class); options.cacheCompression = true; - final GrpcCacheClient client = newClient(options, () -> mockBackoff); + final GrpcCacheClient client = newClient(options); final Digest digest = DIGEST_UTIL.computeAsUtf8("abcdefg"); - ByteString blob = ByteString.copyFrom(Zstd.compress("abcdefg".getBytes(UTF_8))); + ByteString chunk1 = ByteString.copyFrom(Zstd.compress("abc".getBytes(UTF_8))); + ByteString chunk2 = ByteString.copyFrom(Zstd.compress("def".getBytes(UTF_8))); + ByteString chunk3 = ByteString.copyFrom(Zstd.compress("g".getBytes(UTF_8))); serviceRegistry.addService( new ByteStreamImplBase() { + private boolean first = true; + @Override public void read(ReadRequest request, StreamObserver responseObserver) { assertThat(request.getResourceName().contains(digest.getHash())).isTrue(); - int off = (int) request.getReadOffset(); - // Zstd header size is 9 bytes - ByteString data = off == 0 ? blob.substring(0, 9 + 1) : blob.substring(9 + off); - responseObserver.onNext(ReadResponse.newBuilder().setData(data).build()); - if (off == 0) { + if (first) { + first = false; responseObserver.onError(Status.DEADLINE_EXCEEDED.asException()); - } else { - responseObserver.onCompleted(); + return; + } + switch (Math.toIntExact(request.getReadOffset())) { + case 0: + responseObserver.onNext(ReadResponse.newBuilder().setData(chunk1).build()); + break; + case 3: + responseObserver.onNext(ReadResponse.newBuilder().setData(chunk2).build()); + break; + case 6: + responseObserver.onNext(ReadResponse.newBuilder().setData(chunk3).build()); + responseObserver.onCompleted(); + return; + default: + throw new IllegalStateException("unexpected offset " + request.getReadOffset()); } + responseObserver.onError(Status.DEADLINE_EXCEEDED.asException()); } }); assertThat(new String(downloadBlob(context, client, digest), UTF_8)).isEqualTo("abcdefg"); - Mockito.verify(mockBackoff, Mockito.never()).nextDelayMillis(any(Exception.class)); } @Test diff --git a/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTestExtra.java b/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTestExtra.java index 22cba85b8b6f68..62352dd5678a98 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTestExtra.java +++ b/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTestExtra.java @@ -63,7 +63,6 @@ public void bytesWrittenMatchesDecompressedBytes() throws IOException { for (byte b : compressed.toByteArray()) { zdos.write(b); zdos.flush(); - assertThat(zdos.getBytesWritten()).isEqualTo(decompressed.toByteArray().length); } assertThat(decompressed.toByteArray()).isEqualTo(data); }