Skip to content

Commit 299022c

Browse files
benjaminpcopybara-github
authored andcommitted
remote: Proactively close the ZstdInputStream in ZstdDecompressingOutputStream.
ZstdInputStream hangs onto some native memory, which should be released as soon as ZstdDecompressingOutputStream is done being used rather than when the finalizer runs. Closes bazelbuild#15061. PiperOrigin-RevId: 438521302
1 parent 5b95286 commit 299022c

File tree

6 files changed

+83
-60
lines changed

6 files changed

+83
-60
lines changed

src/main/java/com/google/devtools/build/lib/remote/BUILD

-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ java_library(
9797
"//src/main/java/com/google/devtools/build/lib/vfs:pathfragment",
9898
"//src/main/java/com/google/devtools/common/options",
9999
"//src/main/protobuf:failure_details_java_proto",
100-
"//third_party:apache_commons_compress",
101100
"//third_party:auth",
102101
"//third_party:caffeine",
103102
"//third_party:flogger",

src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java

+26-23
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import com.google.common.collect.ImmutableSet;
3838
import com.google.common.collect.Iterables;
3939
import com.google.common.flogger.GoogleLogger;
40+
import com.google.common.io.CountingOutputStream;
4041
import com.google.common.util.concurrent.Futures;
4142
import com.google.common.util.concurrent.ListenableFuture;
4243
import com.google.common.util.concurrent.MoreExecutors;
@@ -67,10 +68,8 @@
6768
import java.util.List;
6869
import java.util.concurrent.TimeUnit;
6970
import java.util.concurrent.atomic.AtomicBoolean;
70-
import java.util.concurrent.atomic.AtomicLong;
7171
import java.util.function.Supplier;
7272
import javax.annotation.Nullable;
73-
import org.apache.commons.compress.utils.CountingOutputStream;
7473

7574
/** A RemoteActionCache implementation that uses gRPC calls to a remote cache server. */
7675
@ThreadSafe
@@ -303,7 +302,7 @@ public ListenableFuture<Void> uploadActionResult(
303302
public ListenableFuture<Void> downloadBlob(
304303
RemoteActionExecutionContext context, Digest digest, OutputStream out) {
305304
if (digest.getSizeBytes() == 0) {
306-
return Futures.immediateFuture(null);
305+
return Futures.immediateVoidFuture();
307306
}
308307

309308
@Nullable Supplier<Digest> digestSupplier = null;
@@ -313,26 +312,14 @@ public ListenableFuture<Void> downloadBlob(
313312
out = digestOut;
314313
}
315314

316-
CountingOutputStream outputStream;
317-
if (options.cacheCompression) {
318-
try {
319-
outputStream = new ZstdDecompressingOutputStream(out);
320-
} catch (IOException e) {
321-
return Futures.immediateFailedFuture(e);
322-
}
323-
} else {
324-
outputStream = new CountingOutputStream(out);
325-
}
326-
327-
return downloadBlob(context, digest, outputStream, digestSupplier);
315+
return downloadBlob(context, digest, new CountingOutputStream(out), digestSupplier);
328316
}
329317

330318
private ListenableFuture<Void> downloadBlob(
331319
RemoteActionExecutionContext context,
332320
Digest digest,
333321
CountingOutputStream out,
334322
@Nullable Supplier<Digest> digestSupplier) {
335-
AtomicLong offset = new AtomicLong(0);
336323
ProgressiveBackoff progressiveBackoff = new ProgressiveBackoff(retrier::newBackoff);
337324
ListenableFuture<Long> downloadFuture =
338325
Utils.refreshIfUnauthenticatedAsync(
@@ -343,7 +330,6 @@ private ListenableFuture<Void> downloadBlob(
343330
channel ->
344331
requestRead(
345332
context,
346-
offset,
347333
progressiveBackoff,
348334
digest,
349335
out,
@@ -370,20 +356,25 @@ public static String getResourceName(String instanceName, Digest digest, boolean
370356

371357
private ListenableFuture<Long> requestRead(
372358
RemoteActionExecutionContext context,
373-
AtomicLong offset,
374359
ProgressiveBackoff progressiveBackoff,
375360
Digest digest,
376-
CountingOutputStream out,
361+
CountingOutputStream rawOut,
377362
@Nullable Supplier<Digest> digestSupplier,
378363
Channel channel) {
379364
String resourceName =
380365
getResourceName(options.remoteInstanceName, digest, options.cacheCompression);
381366
SettableFuture<Long> future = SettableFuture.create();
367+
OutputStream out;
368+
try {
369+
out = options.cacheCompression ? new ZstdDecompressingOutputStream(rawOut) : rawOut;
370+
} catch (IOException e) {
371+
return Futures.immediateFailedFuture(e);
372+
}
382373
bsAsyncStub(context, channel)
383374
.read(
384375
ReadRequest.newBuilder()
385376
.setResourceName(resourceName)
386-
.setReadOffset(offset.get())
377+
.setReadOffset(rawOut.getCount())
387378
.build(),
388379
new StreamObserver<ReadResponse>() {
389380

@@ -392,7 +383,6 @@ public void onNext(ReadResponse readResponse) {
392383
ByteString data = readResponse.getData();
393384
try {
394385
data.writeTo(out);
395-
offset.set(out.getBytesWritten());
396386
} catch (IOException e) {
397387
// Cancel the call.
398388
throw new RuntimeException(e);
@@ -403,14 +393,15 @@ public void onNext(ReadResponse readResponse) {
403393

404394
@Override
405395
public void onError(Throwable t) {
406-
if (offset.get() == digest.getSizeBytes()) {
396+
if (rawOut.getCount() == digest.getSizeBytes()) {
407397
// If the file was fully downloaded, it doesn't matter if there was an error at
408398
// the end of the stream.
409399
logger.atInfo().withCause(t).log(
410400
"ignoring error because file was fully received");
411401
onCompleted();
412402
return;
413403
}
404+
releaseOut();
414405
Status status = Status.fromThrowable(t);
415406
if (status.getCode() == Status.Code.NOT_FOUND) {
416407
future.setException(new CacheNotFoundException(digest));
@@ -426,12 +417,24 @@ public void onCompleted() {
426417
Utils.verifyBlobContents(digest, digestSupplier.get());
427418
}
428419
out.flush();
429-
future.set(offset.get());
420+
future.set(rawOut.getCount());
430421
} catch (IOException e) {
431422
future.setException(e);
432423
} catch (RuntimeException e) {
433424
logger.atWarning().withCause(e).log("Unexpected exception");
434425
future.setException(e);
426+
} finally {
427+
releaseOut();
428+
}
429+
}
430+
431+
private void releaseOut() {
432+
if (out instanceof ZstdDecompressingOutputStream) {
433+
try {
434+
((ZstdDecompressingOutputStream) out).closeShallow();
435+
} catch (IOException e) {
436+
logger.atWarning().withCause(e).log("failed to cleanly close output stream");
437+
}
435438
}
436439
}
437440
});

src/main/java/com/google/devtools/build/lib/remote/zstd/BUILD

-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ java_library(
1616
name = "zstd",
1717
srcs = glob(["*.java"]),
1818
deps = [
19-
"//third_party:apache_commons_compress",
2019
"//third_party:guava",
2120
"//third_party/protobuf:protobuf_java",
2221
"@zstd-jni",

src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStream.java

+33-20
Original file line numberDiff line numberDiff line change
@@ -13,35 +13,35 @@
1313
// limitations under the License.
1414
package com.google.devtools.build.lib.remote.zstd;
1515

16-
import com.github.luben.zstd.ZstdInputStream;
16+
import com.github.luben.zstd.ZstdInputStreamNoFinalizer;
1717
import com.google.protobuf.ByteString;
1818
import java.io.ByteArrayInputStream;
1919
import java.io.IOException;
2020
import java.io.InputStream;
2121
import java.io.OutputStream;
22-
import org.apache.commons.compress.utils.CountingOutputStream;
2322

24-
/** A {@link CountingOutputStream} that use zstd to decompress the content. */
25-
public class ZstdDecompressingOutputStream extends CountingOutputStream {
23+
/** An {@link OutputStream} that use zstd to decompress the content. */
24+
public final class ZstdDecompressingOutputStream extends OutputStream {
25+
private final OutputStream out;
2626
private ByteArrayInputStream inner;
27-
private final ZstdInputStream zis;
27+
private final ZstdInputStreamNoFinalizer zis;
2828

2929
public ZstdDecompressingOutputStream(OutputStream out) throws IOException {
30-
super(out);
30+
this.out = out;
3131
zis =
32-
new ZstdInputStream(
33-
new InputStream() {
34-
@Override
35-
public int read() {
36-
return inner.read();
37-
}
38-
39-
@Override
40-
public int read(byte[] b, int off, int len) {
41-
return inner.read(b, off, len);
42-
}
43-
});
44-
zis.setContinuous(true);
32+
new ZstdInputStreamNoFinalizer(
33+
new InputStream() {
34+
@Override
35+
public int read() {
36+
return inner.read();
37+
}
38+
39+
@Override
40+
public int read(byte[] b, int off, int len) {
41+
return inner.read(b, off, len);
42+
}
43+
})
44+
.setContinuous(true);
4545
}
4646

4747
@Override
@@ -58,6 +58,19 @@ public void write(byte[] b) throws IOException {
5858
public void write(byte[] b, int off, int len) throws IOException {
5959
inner = new ByteArrayInputStream(b, off, len);
6060
byte[] data = ByteString.readFrom(zis).toByteArray();
61-
super.write(data, 0, data.length);
61+
out.write(data, 0, data.length);
62+
}
63+
64+
@Override
65+
public void close() throws IOException {
66+
closeShallow();
67+
out.close();
68+
}
69+
70+
/**
71+
* Free resources related to decompression without closing the underlying {@link OutputStream}.
72+
*/
73+
public void closeShallow() throws IOException {
74+
zis.close();
6275
}
6376
}

src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTestExtra.java

+24-14
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,12 @@
1515

1616
import static com.google.common.truth.Truth.assertThat;
1717
import static java.nio.charset.StandardCharsets.UTF_8;
18-
import static org.mockito.ArgumentMatchers.any;
1918

2019
import build.bazel.remote.execution.v2.Digest;
2120
import com.github.luben.zstd.Zstd;
2221
import com.google.bytestream.ByteStreamGrpc.ByteStreamImplBase;
2322
import com.google.bytestream.ByteStreamProto.ReadRequest;
2423
import com.google.bytestream.ByteStreamProto.ReadResponse;
25-
import com.google.devtools.build.lib.remote.Retrier.Backoff;
2624
import com.google.devtools.build.lib.remote.options.RemoteOptions;
2725
import com.google.devtools.common.options.Options;
2826
import com.google.protobuf.ByteString;
@@ -31,38 +29,50 @@
3129
import java.io.IOException;
3230
import java.util.Arrays;
3331
import org.junit.Test;
34-
import org.mockito.Mockito;
3532

3633
/** Extra tests for {@link GrpcCacheClient} that are not tested internally. */
3734
public class GrpcCacheClientTestExtra extends GrpcCacheClientTest {
3835

3936
@Test
4037
public void compressedDownloadBlobIsRetriedWithProgress()
4138
throws IOException, InterruptedException {
42-
Backoff mockBackoff = Mockito.mock(Backoff.class);
4339
RemoteOptions options = Options.getDefaults(RemoteOptions.class);
4440
options.cacheCompression = true;
45-
final GrpcCacheClient client = newClient(options, () -> mockBackoff);
41+
final GrpcCacheClient client = newClient(options);
4642
final Digest digest = DIGEST_UTIL.computeAsUtf8("abcdefg");
47-
ByteString blob = ByteString.copyFrom(Zstd.compress("abcdefg".getBytes(UTF_8)));
43+
ByteString chunk1 = ByteString.copyFrom(Zstd.compress("abc".getBytes(UTF_8)));
44+
ByteString chunk2 = ByteString.copyFrom(Zstd.compress("def".getBytes(UTF_8)));
45+
ByteString chunk3 = ByteString.copyFrom(Zstd.compress("g".getBytes(UTF_8)));
4846
serviceRegistry.addService(
4947
new ByteStreamImplBase() {
48+
private boolean first = true;
49+
5050
@Override
5151
public void read(ReadRequest request, StreamObserver<ReadResponse> responseObserver) {
5252
assertThat(request.getResourceName().contains(digest.getHash())).isTrue();
53-
int off = (int) request.getReadOffset();
54-
// Zstd header size is 9 bytes
55-
ByteString data = off == 0 ? blob.substring(0, 9 + 1) : blob.substring(9 + off);
56-
responseObserver.onNext(ReadResponse.newBuilder().setData(data).build());
57-
if (off == 0) {
53+
if (first) {
54+
first = false;
5855
responseObserver.onError(Status.DEADLINE_EXCEEDED.asException());
59-
} else {
60-
responseObserver.onCompleted();
56+
return;
57+
}
58+
switch (Math.toIntExact(request.getReadOffset())) {
59+
case 0:
60+
responseObserver.onNext(ReadResponse.newBuilder().setData(chunk1).build());
61+
break;
62+
case 3:
63+
responseObserver.onNext(ReadResponse.newBuilder().setData(chunk2).build());
64+
break;
65+
case 6:
66+
responseObserver.onNext(ReadResponse.newBuilder().setData(chunk3).build());
67+
responseObserver.onCompleted();
68+
return;
69+
default:
70+
throw new IllegalStateException("unexpected offset " + request.getReadOffset());
6171
}
72+
responseObserver.onError(Status.DEADLINE_EXCEEDED.asException());
6273
}
6374
});
6475
assertThat(new String(downloadBlob(context, client, digest), UTF_8)).isEqualTo("abcdefg");
65-
Mockito.verify(mockBackoff, Mockito.never()).nextDelayMillis(any(Exception.class));
6676
}
6777

6878
@Test

src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTestExtra.java

-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ public void bytesWrittenMatchesDecompressedBytes() throws IOException {
6363
for (byte b : compressed.toByteArray()) {
6464
zdos.write(b);
6565
zdos.flush();
66-
assertThat(zdos.getBytesWritten()).isEqualTo(decompressed.toByteArray().length);
6766
}
6867
assertThat(decompressed.toByteArray()).isEqualTo(data);
6968
}

0 commit comments

Comments
 (0)