Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.util.Arrays;
import java.util.List;
import javax.net.ssl.SSLException;
import org.apache.iceberg.exceptions.NotFoundException;
import org.apache.iceberg.io.FileIOMetricsContext;
Expand All @@ -35,6 +36,7 @@
import org.apache.iceberg.metrics.Counter;
import org.apache.iceberg.metrics.MetricsContext;
import org.apache.iceberg.metrics.MetricsContext.Unit;
import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting;
import org.apache.iceberg.relocated.com.google.common.base.Joiner;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
Expand All @@ -50,6 +52,9 @@
class S3InputStream extends SeekableInputStream implements RangeReadable {
private static final Logger LOG = LoggerFactory.getLogger(S3InputStream.class);

private static final List<Class<? extends Throwable>> RETRYABLE_EXCEPTIONS =
ImmutableList.of(SSLException.class, SocketTimeoutException.class, SocketException.class);

private final StackTraceElement[] createStack;
private final S3Client s3;
private final S3URI location;
Expand All @@ -66,10 +71,18 @@ class S3InputStream extends SeekableInputStream implements RangeReadable {
private int skipSize = 1024 * 1024;
private RetryPolicy<Object> retryPolicy =
RetryPolicy.builder()
.handle(
ImmutableList.of(
SSLException.class, SocketTimeoutException.class, SocketException.class))
.onFailure(failure -> openStream(true))
.handle(RETRYABLE_EXCEPTIONS)
.onRetry(
e -> {
LOG.warn(
"Retrying read from S3, reopening stream (attempt {})", e.getAttemptCount());
resetForRetry();
})
.onFailure(
e ->
LOG.error(
"Failed to read from S3 input stream after exhausting all retries",
e.getException()))
.withMaxRetries(3)
.build();

Expand Down Expand Up @@ -230,6 +243,11 @@ private void openStream(boolean closeQuietly) throws IOException {
}
}

/**
 * Reopens the underlying stream by delegating to {@code openStream(true)}, i.e. with the
 * {@code closeQuietly} flag set.
 *
 * <p>The retry policy's {@code onRetry} listener calls this before each retry attempt. It is a
 * separate, non-private method so that tests can override it (for example to count how many
 * times a retry actually reset the stream).
 *
 * @throws IOException propagated from {@code openStream}
 */
@VisibleForTesting
void resetForRetry() throws IOException {
openStream(true);
}

private void closeStream(boolean closeQuietly) throws IOException {
if (stream != null) {
// if we aren't at the end of the stream, and the stream is abortable, then
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.iceberg.aws.s3;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
Expand All @@ -29,6 +30,7 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import javax.net.ssl.SSLException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
Expand All @@ -49,10 +51,29 @@

public class TestFlakyS3InputStream extends TestS3InputStream {

private AtomicInteger resetForRetryCounter;

/**
 * Installs a fresh zeroed counter before every test so that the {@code resetForRetry}
 * invocation counts asserted by each test are not affected by earlier tests.
 */
@BeforeEach
public void setupTest() {
resetForRetryCounter = new AtomicInteger(0);
}

@Override
S3InputStream newInputStream(S3Client s3Client, S3URI uri) {
return new S3InputStream(s3Client, uri) {
@Override
void resetForRetry() throws IOException {
resetForRetryCounter.incrementAndGet();
super.resetForRetry();
}
Comment on lines +65 to +68
Copy link
Contributor

@amogh-jahagirdar amogh-jahagirdar Oct 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this; it was crucial for verifying that the retry behavior actually resets the input stream. It seems we were just getting lucky with the tests before, since they only counted the number of attempts and not what was actually happening in each attempt.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. That's correct.

};
}

/**
 * Runs the shared read scenario against a flaky client built from a counter of 3 and the given
 * retryable exception, and expects the read to complete without throwing.
 *
 * <p>Afterwards, verifies that the stream was reset for retry exactly twice. NOTE(review): the
 * exact meaning of the counter passed to {@code flakyStreamClient} is defined outside this view;
 * confirm against that helper.
 *
 * @param exception the retryable exception supplied by {@code retryableExceptions}
 */
@ParameterizedTest
@MethodSource("retryableExceptions")
public void testReadWithFlakyStreamRetrySucceed(IOException exception) throws Exception {
  AtomicInteger flakyCounter = new AtomicInteger(3);
  testRead(flakyStreamClient(flakyCounter, exception));

  int observedResets = resetForRetryCounter.get();
  assertThat(observedResets).isEqualTo(2);
}

@ParameterizedTest
Expand All @@ -61,6 +82,7 @@ public void testReadWithFlakyStreamExhaustedRetries(IOException exception) {
assertThatThrownBy(() -> testRead(flakyStreamClient(new AtomicInteger(5), exception)))
.isInstanceOf(exception.getClass())
.hasMessage(exception.getMessage());
assertThat(resetForRetryCounter.get()).isEqualTo(3);
}

@ParameterizedTest
Expand All @@ -69,12 +91,14 @@ public void testReadWithFlakyStreamNonRetryableException(IOException exception)
assertThatThrownBy(() -> testRead(flakyStreamClient(new AtomicInteger(3), exception)))
.isInstanceOf(exception.getClass())
.hasMessage(exception.getMessage());
assertThat(resetForRetryCounter.get()).isEqualTo(0);
}

/**
 * Runs the shared seek scenario against a flaky client built from a counter of 3 and the given
 * retryable exception, and expects the seek/read to complete without throwing.
 *
 * <p>Afterwards, verifies that the stream was reset for retry exactly twice. NOTE(review): the
 * exact meaning of the counter passed to {@code flakyStreamClient} is defined outside this view;
 * confirm against that helper.
 *
 * @param exception the retryable exception supplied by {@code retryableExceptions}
 */
@ParameterizedTest
@MethodSource("retryableExceptions")
public void testSeekWithFlakyStreamRetrySucceed(IOException exception) throws Exception {
  AtomicInteger flakyCounter = new AtomicInteger(3);
  testSeek(flakyStreamClient(flakyCounter, exception));

  int observedResets = resetForRetryCounter.get();
  assertThat(observedResets).isEqualTo(2);
}

@ParameterizedTest
Expand All @@ -83,6 +107,7 @@ public void testSeekWithFlakyStreamExhaustedRetries(IOException exception) {
assertThatThrownBy(() -> testSeek(flakyStreamClient(new AtomicInteger(5), exception)))
.isInstanceOf(exception.getClass())
.hasMessage(exception.getMessage());
assertThat(resetForRetryCounter.get()).isEqualTo(3);
}

@ParameterizedTest
Expand All @@ -91,6 +116,7 @@ public void testSeekWithFlakyStreamNonRetryableException(IOException exception)
assertThatThrownBy(() -> testSeek(flakyStreamClient(new AtomicInteger(3), exception)))
.isInstanceOf(exception.getClass())
.hasMessage(exception.getMessage());
assertThat(resetForRetryCounter.get()).isEqualTo(0);
}

private static Stream<Arguments> retryableExceptions() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,18 @@ public void testRead() throws Exception {
testRead(s3);
}

/**
 * Creates the {@link S3InputStream} under test.
 *
 * <p>Tests in this class obtain their stream through this factory instead of calling the
 * constructor directly, so that subclasses (such as {@code TestFlakyS3InputStream}) can override
 * it to return an instrumented stream.
 *
 * @param s3Client the client the stream reads through
 * @param uri the S3 location to read
 * @return a new input stream for {@code uri}
 */
S3InputStream newInputStream(S3Client s3Client, S3URI uri) {
return new S3InputStream(s3Client, uri);
}

protected void testRead(S3Client s3Client) throws Exception {
S3URI uri = new S3URI("s3://bucket/path/to/read.dat");
int dataSize = 1024 * 1024 * 10;
byte[] data = randomData(dataSize);

writeS3Data(uri, data);

try (SeekableInputStream in = new S3InputStream(s3Client, uri)) {
try (SeekableInputStream in = newInputStream(s3Client, uri)) {
int readSize = 1024;
readAndCheck(in, in.getPos(), readSize, data, false);
readAndCheck(in, in.getPos(), readSize, data, true);
Expand Down Expand Up @@ -128,7 +132,7 @@ protected void testRangeRead(S3Client s3Client) throws Exception {

writeS3Data(uri, expected);

try (RangeReadable in = new S3InputStream(s3Client, uri)) {
try (RangeReadable in = newInputStream(s3Client, uri)) {
// first 1k
position = 0;
offset = 0;
Expand Down Expand Up @@ -160,7 +164,7 @@ private void readAndCheckRanges(
@Test
public void testClose() throws Exception {
S3URI uri = new S3URI("s3://bucket/path/to/closed.dat");
SeekableInputStream closed = new S3InputStream(s3, uri);
SeekableInputStream closed = newInputStream(s3, uri);
closed.close();
assertThatThrownBy(() -> closed.seek(0))
.isInstanceOf(IllegalStateException.class)
Expand All @@ -178,7 +182,7 @@ protected void testSeek(S3Client s3Client) throws Exception {

writeS3Data(uri, expected);

try (SeekableInputStream in = new S3InputStream(s3Client, uri)) {
try (SeekableInputStream in = newInputStream(s3Client, uri)) {
in.seek(expected.length / 2);
byte[] actual = new byte[expected.length / 2];
IOUtil.readFully(in, actual, 0, expected.length / 2);
Expand Down