diff --git a/spring-cloud-aws-s3/src/main/java/io/awspring/cloud/s3/AbstractTempFileS3OutputStream.java b/spring-cloud-aws-s3/src/main/java/io/awspring/cloud/s3/AbstractTempFileS3OutputStream.java index 16df44693..19553b501 100644 --- a/spring-cloud-aws-s3/src/main/java/io/awspring/cloud/s3/AbstractTempFileS3OutputStream.java +++ b/spring-cloud-aws-s3/src/main/java/io/awspring/cloud/s3/AbstractTempFileS3OutputStream.java @@ -121,6 +121,17 @@ public void flush() throws IOException { localOutputStream.flush(); } + @Override + public void abort() throws IOException { + if (closed) { + throw new IllegalStateException("Stream is already closed. Too late to abort."); + } + + localOutputStream.close(); + closed = true; + deleteTempFile(); + } + @Override public void close() throws IOException { if (closed) { @@ -145,19 +156,22 @@ public void close() throws IOException { } } this.upload(builder.build()); - boolean result = file.delete(); - - if (!result) { - getLogger().warn(String.format("Temporary file %s could not be deleted", file.getPath())); - } + deleteTempFile(); } catch (Exception se) { - getLogger().error( - String.format("Failed to upload %s. Temporary file @%s", location.getObject(), file.getPath())); + getLogger().error("Failed to upload {}. Temporary file @{}", location.getObject(), file.getPath()); throw new UploadFailedException(file.getPath(), se); } } + private void deleteTempFile() { + boolean result = file.delete(); + + if (!result) { + getLogger().warn("Temporary file {} could not be deleted", file.getPath()); + } + } + protected abstract void upload(PutObjectRequest putObjectRequest); protected Logger getLogger() { diff --git a/spring-cloud-aws-s3/src/main/java/io/awspring/cloud/s3/InMemoryBufferingS3OutputStream.java b/spring-cloud-aws-s3/src/main/java/io/awspring/cloud/s3/InMemoryBufferingS3OutputStream.java index 461487359..a83ec8b30 100644 --- a/spring-cloud-aws-s3/src/main/java/io/awspring/cloud/s3/InMemoryBufferingS3OutputStream.java +++ b/spring-cloud-aws-s3/src/main/java/io/awspring/cloud/s3/InMemoryBufferingS3OutputStream.java @@ -117,6 +117,19 @@ public void write(int b) { } } + @Override + public void abort() { + synchronized (this.monitor) { + if (isClosed()) { + throw new IllegalStateException("Stream is already closed. Too late to abort."); + } + if (isMultiPartUpload()) { + abortMultiPartUpload(multipartUploadResponse); + } + outputStream = null; + } + } + @Override public void close() { synchronized (this.monitor) { diff --git a/spring-cloud-aws-s3/src/main/java/io/awspring/cloud/s3/S3OutputStream.java b/spring-cloud-aws-s3/src/main/java/io/awspring/cloud/s3/S3OutputStream.java index bb104b7b1..82fbe5663 100644 --- a/spring-cloud-aws-s3/src/main/java/io/awspring/cloud/s3/S3OutputStream.java +++ b/spring-cloud-aws-s3/src/main/java/io/awspring/cloud/s3/S3OutputStream.java @@ -15,6 +15,7 @@ */ package io.awspring.cloud.s3; +import java.io.IOException; import java.io.OutputStream; /** @@ -25,4 +26,9 @@ */ public abstract class S3OutputStream extends OutputStream { + /** + * Cancels the upload and cleans up temporal resources (temp files, partial multipart upload). + */ + public void abort() throws IOException { + } } diff --git a/spring-cloud-aws-s3/src/test/java/io/awspring/cloud/s3/DiskBufferingS3OutputStreamTests.java b/spring-cloud-aws-s3/src/test/java/io/awspring/cloud/s3/DiskBufferingS3OutputStreamTests.java index 704a622af..4e746f367 100644 --- a/spring-cloud-aws-s3/src/test/java/io/awspring/cloud/s3/DiskBufferingS3OutputStreamTests.java +++ b/spring-cloud-aws-s3/src/test/java/io/awspring/cloud/s3/DiskBufferingS3OutputStreamTests.java @@ -19,6 +19,7 @@ import static org.assertj.core.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -70,4 +71,16 @@ void throwsExceptionWhenUploadFails() throws IOException { } } + @Test + void abortsWhenExplicitlyInvoked() throws IOException { + S3Client s3Client = mock(S3Client.class); + + try (DiskBufferingS3OutputStream diskBufferingS3OutputStream = new DiskBufferingS3OutputStream( + new Location("bucket", "key"), s3Client, null)) { + diskBufferingS3OutputStream.write("hello".getBytes(StandardCharsets.UTF_8)); + diskBufferingS3OutputStream.abort(); + } + + verify(s3Client, never()).putObject(any(PutObjectRequest.class), any(RequestBody.class)); + } } diff --git a/spring-cloud-aws-s3/src/test/java/io/awspring/cloud/s3/InMemoryBufferingS3OutputStreamTests.java b/spring-cloud-aws-s3/src/test/java/io/awspring/cloud/s3/InMemoryBufferingS3OutputStreamTests.java index 65888bb05..098b2ec22 100644 --- a/spring-cloud-aws-s3/src/test/java/io/awspring/cloud/s3/InMemoryBufferingS3OutputStreamTests.java +++ b/spring-cloud-aws-s3/src/test/java/io/awspring/cloud/s3/InMemoryBufferingS3OutputStreamTests.java @@ -168,4 +168,74 @@ void abortsWhenCompletingMultipartUploadFails() throws IOException { assertThat(requestCaptor.getValue().uploadId()).isEqualTo("uploadId"); } } + + @Test + void abortsWhenExplicitlyInvoked() throws IOException { + when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(CreateMultipartUploadResponse.builder().uploadId("uploadId").build()); + + when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))) + .thenReturn(UploadPartResponse.builder().build()); + + when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) + .thenThrow(SdkException.builder().build()); + + final byte[] content = new byte[DEFAULT_BUFFER_CAPACITY_IN_BYTES + 1]; + + try (InMemoryBufferingS3OutputStream outputStream = new InMemoryBufferingS3OutputStream( + new Location("bucket", "key", null), s3Client, null, null, DEFAULT_BUFFER_CAPACITY)) { + new Random().nextBytes(content); + outputStream.write(content); + outputStream.abort(); + } + final ArgumentCaptor requestCaptor = ArgumentCaptor + .forClass(AbortMultipartUploadRequest.class); + + verify(s3Client, times(1)).abortMultipartUpload(requestCaptor.capture()); + assertThat(requestCaptor.getValue().bucket()).isEqualTo("bucket"); + assertThat(requestCaptor.getValue().key()).isEqualTo("key"); + assertThat(requestCaptor.getValue().uploadId()).isEqualTo("uploadId"); + } + + @Test + void abortsWhenInvokedBeforeWriting() { + try (InMemoryBufferingS3OutputStream outputStream = new InMemoryBufferingS3OutputStream( + new Location("bucket", "key", null), s3Client, null, null, DEFAULT_BUFFER_CAPACITY)) { + outputStream.abort(); + } + + verify(s3Client, never()).createMultipartUpload(any(CreateMultipartUploadRequest.class)); + verify(s3Client, never()).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); + } + + @Test + void failsWhenAbortingAfterClosing() { + InMemoryBufferingS3OutputStream outputStream = null; + try { + outputStream = new InMemoryBufferingS3OutputStream(new Location("bucket", "key", null), s3Client, null, + null, DEFAULT_BUFFER_CAPACITY); + } + finally { + assertThat(outputStream).isNotNull(); + outputStream.close(); + try { + outputStream.abort(); + fail("IllegalStateException should be thrown."); + } + catch (IllegalStateException e) { + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(PutObjectRequest.class); + final ArgumentCaptor bodyCaptor = ArgumentCaptor.forClass(RequestBody.class); + + verify(s3Client, times(1)).putObject(requestCaptor.capture(), bodyCaptor.capture()); + + assertThat(requestCaptor.getValue().bucket()).isEqualTo("bucket"); + assertThat(requestCaptor.getValue().key()).isEqualTo("key"); + assertThat(requestCaptor.getValue().contentLength()).isEqualTo(0); + assertThat(requestCaptor.getValue().contentMD5()).isNotNull(); + + verify(s3Client, never()).createMultipartUpload(any(CreateMultipartUploadRequest.class)); + verify(s3Client, never()).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); + } + } + } }