37
37
import com .google .common .collect .ImmutableSet ;
38
38
import com .google .common .collect .Iterables ;
39
39
import com .google .common .flogger .GoogleLogger ;
40
+ import com .google .common .io .CountingOutputStream ;
40
41
import com .google .common .util .concurrent .Futures ;
41
42
import com .google .common .util .concurrent .ListenableFuture ;
42
43
import com .google .common .util .concurrent .MoreExecutors ;
67
68
import java .util .List ;
68
69
import java .util .concurrent .TimeUnit ;
69
70
import java .util .concurrent .atomic .AtomicBoolean ;
70
- import java .util .concurrent .atomic .AtomicLong ;
71
71
import java .util .function .Supplier ;
72
72
import javax .annotation .Nullable ;
73
- import org .apache .commons .compress .utils .CountingOutputStream ;
74
73
75
74
/** A RemoteActionCache implementation that uses gRPC calls to a remote cache server. */
76
75
@ ThreadSafe
@@ -303,7 +302,7 @@ public ListenableFuture<Void> uploadActionResult(
303
302
public ListenableFuture <Void > downloadBlob (
304
303
RemoteActionExecutionContext context , Digest digest , OutputStream out ) {
305
304
if (digest .getSizeBytes () == 0 ) {
306
- return Futures .immediateFuture ( null );
305
+ return Futures .immediateVoidFuture ( );
307
306
}
308
307
309
308
@ Nullable Supplier <Digest > digestSupplier = null ;
@@ -313,26 +312,14 @@ public ListenableFuture<Void> downloadBlob(
313
312
out = digestOut ;
314
313
}
315
314
316
- CountingOutputStream outputStream ;
317
- if (options .cacheCompression ) {
318
- try {
319
- outputStream = new ZstdDecompressingOutputStream (out );
320
- } catch (IOException e ) {
321
- return Futures .immediateFailedFuture (e );
322
- }
323
- } else {
324
- outputStream = new CountingOutputStream (out );
325
- }
326
-
327
- return downloadBlob (context , digest , outputStream , digestSupplier );
315
+ return downloadBlob (context , digest , new CountingOutputStream (out ), digestSupplier );
328
316
}
329
317
330
318
private ListenableFuture <Void > downloadBlob (
331
319
RemoteActionExecutionContext context ,
332
320
Digest digest ,
333
321
CountingOutputStream out ,
334
322
@ Nullable Supplier <Digest > digestSupplier ) {
335
- AtomicLong offset = new AtomicLong (0 );
336
323
ProgressiveBackoff progressiveBackoff = new ProgressiveBackoff (retrier ::newBackoff );
337
324
ListenableFuture <Long > downloadFuture =
338
325
Utils .refreshIfUnauthenticatedAsync (
@@ -343,7 +330,6 @@ private ListenableFuture<Void> downloadBlob(
343
330
channel ->
344
331
requestRead (
345
332
context ,
346
- offset ,
347
333
progressiveBackoff ,
348
334
digest ,
349
335
out ,
@@ -370,20 +356,25 @@ public static String getResourceName(String instanceName, Digest digest, boolean
370
356
371
357
private ListenableFuture <Long > requestRead (
372
358
RemoteActionExecutionContext context ,
373
- AtomicLong offset ,
374
359
ProgressiveBackoff progressiveBackoff ,
375
360
Digest digest ,
376
- CountingOutputStream out ,
361
+ CountingOutputStream rawOut ,
377
362
@ Nullable Supplier <Digest > digestSupplier ,
378
363
Channel channel ) {
379
364
String resourceName =
380
365
getResourceName (options .remoteInstanceName , digest , options .cacheCompression );
381
366
SettableFuture <Long > future = SettableFuture .create ();
367
+ OutputStream out ;
368
+ try {
369
+ out = options .cacheCompression ? new ZstdDecompressingOutputStream (rawOut ) : rawOut ;
370
+ } catch (IOException e ) {
371
+ return Futures .immediateFailedFuture (e );
372
+ }
382
373
bsAsyncStub (context , channel )
383
374
.read (
384
375
ReadRequest .newBuilder ()
385
376
.setResourceName (resourceName )
386
- .setReadOffset (offset . get ())
377
+ .setReadOffset (rawOut . getCount ())
387
378
.build (),
388
379
new StreamObserver <ReadResponse >() {
389
380
@@ -392,7 +383,6 @@ public void onNext(ReadResponse readResponse) {
392
383
ByteString data = readResponse .getData ();
393
384
try {
394
385
data .writeTo (out );
395
- offset .set (out .getBytesWritten ());
396
386
} catch (IOException e ) {
397
387
// Cancel the call.
398
388
throw new RuntimeException (e );
@@ -403,14 +393,15 @@ public void onNext(ReadResponse readResponse) {
403
393
404
394
@ Override
405
395
public void onError (Throwable t ) {
406
- if (offset . get () == digest .getSizeBytes ()) {
396
+ if (rawOut . getCount () == digest .getSizeBytes ()) {
407
397
// If the file was fully downloaded, it doesn't matter if there was an error at
408
398
// the end of the stream.
409
399
logger .atInfo ().withCause (t ).log (
410
400
"ignoring error because file was fully received" );
411
401
onCompleted ();
412
402
return ;
413
403
}
404
+ releaseOut ();
414
405
Status status = Status .fromThrowable (t );
415
406
if (status .getCode () == Status .Code .NOT_FOUND ) {
416
407
future .setException (new CacheNotFoundException (digest ));
@@ -426,12 +417,24 @@ public void onCompleted() {
426
417
Utils .verifyBlobContents (digest , digestSupplier .get ());
427
418
}
428
419
out .flush ();
429
- future .set (offset . get ());
420
+ future .set (rawOut . getCount ());
430
421
} catch (IOException e ) {
431
422
future .setException (e );
432
423
} catch (RuntimeException e ) {
433
424
logger .atWarning ().withCause (e ).log ("Unexpected exception" );
434
425
future .setException (e );
426
+ } finally {
427
+ releaseOut ();
428
+ }
429
+ }
430
+
431
+ private void releaseOut () {
432
+ if (out instanceof ZstdDecompressingOutputStream ) {
433
+ try {
434
+ ((ZstdDecompressingOutputStream ) out ).closeShallow ();
435
+ } catch (IOException e ) {
436
+ logger .atWarning ().withCause (e ).log ("failed to cleanly close output stream" );
437
+ }
435
438
}
436
439
}
437
440
});
0 commit comments