diff --git a/aws/src/integration/java/org/apache/iceberg/aws/s3/TestFlakyS3InputStream.java b/aws/src/integration/java/org/apache/iceberg/aws/s3/TestFlakyS3InputStream.java index f98d1a3d4471..222c4e623710 100644 --- a/aws/src/integration/java/org/apache/iceberg/aws/s3/TestFlakyS3InputStream.java +++ b/aws/src/integration/java/org/apache/iceberg/aws/s3/TestFlakyS3InputStream.java @@ -30,6 +30,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Stream; import javax.net.ssl.SSLException; +import org.apache.http.ConnectionClosedException; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -52,6 +53,7 @@ public class TestFlakyS3InputStream extends TestS3InputStream { private AtomicInteger resetForRetryCounter; + private static final long CONTENT_LENGTH = 1024 * 1024; // An arbitrary content length. @BeforeEach public void setupTest() { @@ -60,7 +62,7 @@ public void setupTest() { @Override S3InputStream newInputStream(S3Client s3Client, S3URI uri) { - return new S3InputStream(s3Client, uri) { + return new S3InputStream(s3Client, uri, CONTENT_LENGTH) { @Override void resetForRetry() throws IOException { resetForRetryCounter.incrementAndGet(); @@ -122,6 +124,7 @@ public void testSeekWithFlakyStreamNonRetryableException(IOException exception) private static Stream retryableExceptions() { return Stream.of( Arguments.of( + new ConnectionClosedException("connection closed exception"), new SocketTimeoutException("socket timeout exception"), new SSLException("some ssl exception"))); } diff --git a/aws/src/integration/java/org/apache/iceberg/aws/s3/TestS3InputStream.java b/aws/src/integration/java/org/apache/iceberg/aws/s3/TestS3InputStream.java index f8903842df37..e6934027dcbe 100644 --- a/aws/src/integration/java/org/apache/iceberg/aws/s3/TestS3InputStream.java +++ b/aws/src/integration/java/org/apache/iceberg/aws/s3/TestS3InputStream.java @@ -46,6 +46,8 @@ public class TestS3InputStream { private final S3Client s3 = MinioUtil.createS3Client(MINIO); private final Random random = new Random(1); + private static final int CONTENT_LENGTH = 1024 * 1024 * 10; // 10MB + @BeforeEach public void before() { createBucket("bucket"); @@ -57,13 +59,13 @@ public void testRead() throws Exception { } S3InputStream newInputStream(S3Client s3Client, S3URI uri) { - return new S3InputStream(s3Client, uri); + return new S3InputStream(s3Client, uri, CONTENT_LENGTH); } protected void testRead(S3Client s3Client) throws Exception { S3URI uri = new S3URI("s3://bucket/path/to/read.dat"); - int dataSize = 1024 * 1024 * 10; - byte[] data = randomData(dataSize); + + byte[] data = randomData(CONTENT_LENGTH); writeS3Data(uri, data); @@ -121,9 +123,9 @@ public void testRangeRead() throws Exception { protected void testRangeRead(S3Client s3Client) throws Exception { S3URI uri = new S3URI("s3://bucket/path/to/range-read.dat"); - int dataSize = 1024 * 1024 * 10; - byte[] expected = randomData(dataSize); - byte[] actual = new byte[dataSize]; + + byte[] expected = randomData(CONTENT_LENGTH); + byte[] actual = new byte[CONTENT_LENGTH]; long position; int offset; @@ -139,13 +141,13 @@ protected void testRangeRead(S3Client s3Client) throws Exception { readAndCheckRanges(in, expected, position, actual, offset, length); // last 1k - position = dataSize - 1024; - offset = dataSize - 1024; + position = CONTENT_LENGTH - 1024; + offset = CONTENT_LENGTH - 1024; readAndCheckRanges(in, expected, position, actual, offset, length); // middle 2k - position = dataSize / 2 - 1024; - offset = dataSize / 2 - 1024; + position = CONTENT_LENGTH / 2 - 1024; + offset = CONTENT_LENGTH / 2 - 1024; length = 1024 * 2; readAndCheckRanges(in, expected, position, actual, offset, length); } diff --git a/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputFile.java b/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputFile.java index 5e4346fe9f9f..4cfb349b021d 100644 --- a/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputFile.java +++ b/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputFile.java @@ -76,7 +76,7 @@ public SeekableInputStream newStream() { if (s3FileIOProperties().isS3AnalyticsAcceleratorEnabled()) { return AnalyticsAcceleratorUtil.newStream(this); } - return new S3InputStream(client(), uri(), s3FileIOProperties(), metrics()); + return new S3InputStream(client(), uri(), s3FileIOProperties(), metrics(), getLength()); } @Override diff --git a/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java b/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java index 4d37ac333030..d3a049ff1374 100644 --- a/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java +++ b/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java @@ -28,6 +28,7 @@ import java.util.Arrays; import java.util.List; import javax.net.ssl.SSLException; +import org.apache.http.ConnectionClosedException; import org.apache.iceberg.exceptions.NotFoundException; import org.apache.iceberg.io.FileIOMetricsContext; import org.apache.iceberg.io.IOUtil; @@ -53,7 +54,11 @@ class S3InputStream extends SeekableInputStream implements RangeReadable { private static final Logger LOG = LoggerFactory.getLogger(S3InputStream.class); private static final List> RETRYABLE_EXCEPTIONS = - ImmutableList.of(SSLException.class, SocketTimeoutException.class, SocketException.class); + ImmutableList.of( + SSLException.class, + SocketTimeoutException.class, + SocketException.class, + ConnectionClosedException.class); private final StackTraceElement[] createStack; private final S3Client s3; @@ -63,6 +68,7 @@ class S3InputStream extends SeekableInputStream implements RangeReadable { private InputStream stream; private long pos = 0; private long next = 0; + private long contentLength = 0; private boolean closed = false; private final Counter readBytes; @@ -86,15 +92,20 @@ class S3InputStream extends SeekableInputStream implements RangeReadable { .withMaxRetries(3) .build(); - S3InputStream(S3Client s3, S3URI location) { - this(s3, location, new S3FileIOProperties(), MetricsContext.nullMetrics()); + S3InputStream(S3Client s3, S3URI location, long contentLength) { + this(s3, location, new S3FileIOProperties(), MetricsContext.nullMetrics(), contentLength); } S3InputStream( - S3Client s3, S3URI location, S3FileIOProperties s3FileIOProperties, MetricsContext metrics) { + S3Client s3, + S3URI location, + S3FileIOProperties s3FileIOProperties, + MetricsContext metrics, + long contentLength) { this.s3 = s3; this.location = location; this.s3FileIOProperties = s3FileIOProperties; + this.contentLength = contentLength; this.readBytes = metrics.counter(FileIOMetricsContext.READ_BYTES, Unit.BYTES); this.readOperations = metrics.counter(FileIOMetricsContext.READ_OPERATIONS); @@ -278,7 +289,7 @@ private void closeStream(boolean closeQuietly) throws IOException { private void abortStream() { try { - if (stream instanceof Abortable && stream.read() != -1) { + if (stream instanceof Abortable && remainingInCurrentRequest() > 0) { ((Abortable) stream).abort(); } } catch (Exception e) { @@ -286,6 +297,10 @@ private void abortStream() { } } + private long remainingInCurrentRequest() { + return this.contentLength - this.pos; + } + public void setSkipSize(int skipSize) { this.skipSize = skipSize; } diff --git a/aws/src/test/java/org/apache/iceberg/aws/s3/TestS3InputStream.java b/aws/src/test/java/org/apache/iceberg/aws/s3/TestS3InputStream.java index d8f415cfd221..f59a0d6b0339 100644 --- a/aws/src/test/java/org/apache/iceberg/aws/s3/TestS3InputStream.java +++ b/aws/src/test/java/org/apache/iceberg/aws/s3/TestS3InputStream.java @@ -20,17 +20,21 @@ import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.io.ByteArrayInputStream; import java.io.IOException; -import java.io.InputStream; +import java.util.Arrays; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.http.AbortableInputStream; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.model.GetObjectRequest; @@ -38,15 +42,22 @@ public final class TestS3InputStream { @Mock private S3Client s3Client; - @Mock private InputStream inputStream; + private AbortableInputStream inputStream; private S3InputStream s3InputStream; + private static final int CONTENT_LENGTH = 1024; + @BeforeEach void before() { + byte[] writeValue = new byte[CONTENT_LENGTH]; + Arrays.fill(writeValue, (byte) 1); + + inputStream = + spy(AbortableInputStream.create(new ByteArrayInputStream(new byte[CONTENT_LENGTH]))); when(s3Client.getObject(any(GetObjectRequest.class), any(ResponseTransformer.class))) .thenReturn(inputStream); - s3InputStream = new S3InputStream(s3Client, mock()); + s3InputStream = new S3InputStream(s3Client, mock(), CONTENT_LENGTH); } @Test @@ -62,4 +73,24 @@ void testReadTailClosesTheStream() throws IOException { verify(inputStream).close(); } + + @Test + void testAbortIsCalledAfterPartialRead() throws IOException { + byte[] buff = new byte[500]; + s3InputStream.read(buff); + + // close after reading partial object, should call abort + s3InputStream.close(); + verify(inputStream).abort(); + } + + @Test + void testAbortIsCalledAfterFullRead() throws IOException { + byte[] buff = new byte[CONTENT_LENGTH]; + s3InputStream.read(buff); + + // If we're at EoF, this should not call abort. + s3InputStream.close(); + verify(inputStream, never()).abort(); + } }