diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/AsyncBatchWriteHandler.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/AsyncBatchWriteHandler.java new file mode 100644 index 000000000000..892961fa7d86 --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/AsyncBatchWriteHandler.java @@ -0,0 +1,403 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.common; + +import static java.util.function.Function.identity; +import static java.util.stream.Collectors.counting; +import static java.util.stream.Collectors.groupingBy; +import static java.util.stream.Collectors.joining; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicates.notNull; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; +import javax.annotation.Nullable; +import javax.annotation.concurrent.NotThreadSafe; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.util.BackOff; +import org.apache.beam.sdk.util.BackOffUtils; +import org.apache.beam.sdk.util.FluentBackoff; +import org.apache.beam.sdk.util.Sleeper; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets; +import org.joda.time.DateTimeUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Async handler that automatically retries unprocessed records in case of a partial success. + * + *

The handler enforces the provided upper limit of concurrent requests. Once that limit is + * reached any further call to {@link #batchWrite(String, List)} will block until another request + * completed. + * + *

The handler is fail fast and won't submit any further request after a failure. Async failures + * can be polled using {@link #checkForAsyncFailure()}. + * + * @param Record type in batch + * @param Potentially erroneous result that needs to be correlated to a record using {@link + * #failedRecords(List, List)} + */ +@NotThreadSafe +@Internal +public abstract class AsyncBatchWriteHandler { + private static final Logger LOG = LoggerFactory.getLogger(AsyncBatchWriteHandler.class); + private final FluentBackoff backoff; + private final int concurrentRequests; + private final Stats stats; + protected final BiFunction, CompletableFuture>> submitFn; + protected final Function errorCodeFn; + + private AtomicBoolean hasErrored; + private AtomicReference asyncFailure; + private Semaphore requestPermits; + + protected AsyncBatchWriteHandler( + int concurrency, + FluentBackoff backoff, + Stats stats, + Function errorCodeFn, + BiFunction, CompletableFuture>> submitFn) { + this.backoff = backoff; + this.concurrentRequests = concurrency; + this.errorCodeFn = errorCodeFn; + this.submitFn = submitFn; + this.hasErrored = new AtomicBoolean(false); + this.asyncFailure = new AtomicReference<>(); + this.requestPermits = new Semaphore(concurrentRequests); + this.stats = stats; + } + + public final int requestsInProgress() { + return concurrentRequests - requestPermits.availablePermits(); + } + + public final void reset() { + hasErrored = new AtomicBoolean(false); + asyncFailure = new AtomicReference<>(); + requestPermits = new Semaphore(concurrentRequests); + } + + /** If this handler has errored since it was last reset. */ + public final boolean hasErrored() { + return hasErrored.get(); + } + + /** + * Check if any failure happened async. + * + * @throws Throwable The last async failure, afterwards reset it. + */ + public final void checkForAsyncFailure() throws Throwable { + @SuppressWarnings("nullness") + Throwable failure = asyncFailure.getAndSet(null); + if (failure != null) { + throw failure; + } + } + + /** + * Wait for all pending requests to complete and check for failures. + * + * @throws Throwable The last async failure if present using {@link #checkForAsyncFailure()} + */ + public final void waitForCompletion() throws Throwable { + requestPermits.acquireUninterruptibly(concurrentRequests); + checkForAsyncFailure(); + } + + /** + * Asynchronously trigger a batch write request (unless already in error state). + * + *

This will respect the concurrency limit of the handler and first wait for a permit. + * + * @throws Throwable The last async failure if present using {@link #checkForAsyncFailure()} + */ + public final void batchWrite(String destination, List records) throws Throwable { + batchWrite(destination, records, true); + } + + /** + * Asynchronously trigger a batch write request (unless already in error state). + * + *

This will respect the concurrency limit of the handler and first wait for a permit. + * + * @param throwAsyncFailures If to check and throw pending async failures + * @throws Throwable The last async failure if present using {@link #checkForAsyncFailure()} + */ + public final void batchWrite(String destination, List records, boolean throwAsyncFailures) + throws Throwable { + if (!hasErrored()) { + requestPermits.acquireUninterruptibly(); + new RetryHandler(destination, records).run(); + } + if (throwAsyncFailures) { + checkForAsyncFailure(); + } + } + + protected abstract List failedRecords(List records, List results); + + protected abstract boolean hasFailedRecords(List results); + + /** Statistics on the batch request. */ + public interface Stats { + Stats NONE = new Stats() {}; + + default void addBatchWriteRequest(long latencyMillis, boolean isPartialRetry) {} + } + + /** + * AsyncBatchWriteHandler that correlates records and results by position in the respective list. + */ + public static AsyncBatchWriteHandler byPosition( + int concurrency, + int partialRetries, + @Nullable RetryConfiguration retry, + Stats stats, + BiFunction, CompletableFuture>> submitFn, + Function errorCodeFn) { + FluentBackoff backoff = retryBackoff(partialRetries, retry); + return byPosition(concurrency, backoff, stats, submitFn, errorCodeFn); + } + + /** + * AsyncBatchWriteHandler that correlates records and results by position in the respective list. + */ + public static AsyncBatchWriteHandler byPosition( + int concurrency, + FluentBackoff backoff, + Stats stats, + BiFunction, CompletableFuture>> submitFn, + Function errorCodeFn) { + return new AsyncBatchWriteHandler( + concurrency, backoff, stats, errorCodeFn, submitFn) { + + @Override + protected boolean hasFailedRecords(List results) { + for (int i = 0; i < results.size(); i++) { + if (errorCodeFn.apply(results.get(i)) != null) { + return true; + } + } + return false; + } + + @Override + protected List failedRecords(List records, List results) { + int size = Math.min(records.size(), results.size()); + List filtered = new ArrayList<>(); + for (int i = 0; i < size; i++) { + if (errorCodeFn.apply(results.get(i)) != null) { + filtered.add(records.get(i)); + } + } + return filtered; + } + }; + } + + /** + * AsyncBatchWriteHandler that correlates records and results by id, all results are erroneous. + */ + public static AsyncBatchWriteHandler byId( + int concurrency, + int partialRetries, + @Nullable RetryConfiguration retry, + Stats stats, + BiFunction, CompletableFuture>> submitFn, + Function errorCodeFn, + Function recordIdFn, + Function errorIdFn) { + FluentBackoff backoff = retryBackoff(partialRetries, retry); + return byId(concurrency, backoff, stats, submitFn, errorCodeFn, recordIdFn, errorIdFn); + } + + /** + * AsyncBatchWriteHandler that correlates records and results by id, all results are erroneous. + */ + public static AsyncBatchWriteHandler byId( + int concurrency, + FluentBackoff backoff, + Stats stats, + BiFunction, CompletableFuture>> submitFn, + Function errorCodeFn, + Function recordIdFn, + Function errorIdFn) { + return new AsyncBatchWriteHandler( + concurrency, backoff, stats, errorCodeFn, submitFn) { + @Override + protected boolean hasFailedRecords(List errors) { + return !errors.isEmpty(); + } + + @Override + protected List failedRecords(List records, List errors) { + Set ids = Sets.newHashSetWithExpectedSize(errors.size()); + errors.forEach(e -> ids.add(errorIdFn.apply(e))); + + List filtered = new ArrayList<>(errors.size()); + for (int i = 0; i < records.size(); i++) { + RecT rec = records.get(i); + if (ids.contains(recordIdFn.apply(rec))) { + filtered.add(rec); + if (filtered.size() == errors.size()) { + return filtered; + } + } + } + return filtered; + } + }; + } + + /** + * This handler coordinates retries in case of a partial success. + * + *

+ * + * The next call of {@link #checkForAsyncFailure()}, {@link #batchWrite(String, List< RecT >)}} or + * {@link #waitForCompletion()} will check for the last async failure and throw it. Afterwards the + * failure state is reset. + */ + private class RetryHandler implements BiConsumer, Throwable> { + private final String destination; + private final int totalRecords; + private final BackOff backoff; // backoff in case of throttling + + private final long handlerStartTime; + private long requestStartTime; + private int requests; + + private List records; + + RetryHandler(String destination, List records) { + this.destination = destination; + this.totalRecords = records.size(); + this.records = records; + this.backoff = AsyncBatchWriteHandler.this.backoff.backoff(); + this.handlerStartTime = DateTimeUtils.currentTimeMillis(); + this.requestStartTime = 0; + this.requests = 0; + } + + @SuppressWarnings({"FutureReturnValueIgnored"}) + void run() { + if (!hasErrored.get()) { + try { + requests++; + requestStartTime = DateTimeUtils.currentTimeMillis(); + submitFn.apply(destination, records).whenComplete(this); + } catch (Throwable e) { + setAsyncFailure(e); + } + } + } + + @Override + public void accept(List results, Throwable throwable) { + try { + long now = DateTimeUtils.currentTimeMillis(); + long latencyMillis = now - requestStartTime; + synchronized (stats) { + stats.addBatchWriteRequest(latencyMillis, requests > 1); + } + if (results != null && !hasErrored.get()) { + if (!hasFailedRecords(results)) { + // Request succeeded, release one permit + requestPermits.release(); + LOG.debug( + "Done writing {} records [{} ms, {} request(s)]", + totalRecords, + now - handlerStartTime, + requests); + } else { + try { + if (BackOffUtils.next(Sleeper.DEFAULT, backoff)) { + LOG.info(summarizeErrors("Attempting retry", results)); + records = failedRecords(records, results); + run(); + } else { + throwable = new IOException(summarizeErrors("Exceeded retries", results)); + } + } catch (Throwable e) { + throwable = new IOException(summarizeErrors("Aborted retries", results), e); + } + } + } + } catch (Throwable e) { + throwable = e; + } + if (throwable != null) { + setAsyncFailure(throwable); + } + } + + private void setAsyncFailure(Throwable throwable) { + LOG.warn("Error when writing batch.", throwable); + hasErrored.set(true); + asyncFailure.updateAndGet( + ex -> { + if (ex != null) { + throwable.addSuppressed(ex); + } + return throwable; + }); + requestPermits.release(concurrentRequests); // unblock everything to fail fast + } + + private String summarizeErrors(String prefix, List results) { + Map countsPerError = + results.stream() + .map(errorCodeFn) + .filter(notNull()) + .collect(groupingBy(identity(), counting())); + return countsPerError.entrySet().stream() + .map(kv -> String.format("code %s for %d record(s)", kv.getKey(), kv.getValue())) + .collect(joining(", ", prefix + " after partial failure: ", ".")); + } + } + + private static FluentBackoff retryBackoff(int retries, @Nullable RetryConfiguration retry) { + FluentBackoff backoff = FluentBackoff.DEFAULT.withMaxRetries(retries); + if (retry != null) { + if (retry.throttledBaseBackoff() != null) { + backoff = backoff.withInitialBackoff(retry.throttledBaseBackoff()); + } + if (retry.maxBackoff() != null) { + backoff = backoff.withMaxBackoff(retry.maxBackoff()); + } + } + return backoff; + } +} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientConfiguration.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientConfiguration.java index 9ee8eb277ddc..08fb595bd037 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientConfiguration.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientConfiguration.java @@ -52,6 +52,7 @@ @JsonInclude(value = JsonInclude.Include.NON_EMPTY) @JsonDeserialize(builder = ClientConfiguration.Builder.class) public abstract class ClientConfiguration implements Serializable { + public static final ClientConfiguration EMPTY = ClientConfiguration.builder().build(); /** * Optional {@link AwsCredentialsProvider}. If set, this overwrites the default in {@link diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/AsyncPutRecordsHandler.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/AsyncPutRecordsHandler.java deleted file mode 100644 index 78439cb1d349..000000000000 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/AsyncPutRecordsHandler.java +++ /dev/null @@ -1,271 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws2.kinesis; - -import static java.util.function.Function.identity; -import static java.util.stream.Collectors.counting; -import static java.util.stream.Collectors.groupingBy; -import static java.util.stream.Collectors.joining; -import static java.util.stream.Collectors.toList; - -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.concurrent.Semaphore; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiConsumer; -import java.util.function.Supplier; -import javax.annotation.concurrent.NotThreadSafe; -import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; -import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.sdk.util.BackOff; -import org.apache.beam.sdk.util.BackOffUtils; -import org.apache.beam.sdk.util.FluentBackoff; -import org.apache.beam.sdk.util.Sleeper; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams; -import org.joda.time.DateTimeUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; -import software.amazon.awssdk.services.kinesis.model.PutRecordsRequest; -import software.amazon.awssdk.services.kinesis.model.PutRecordsRequestEntry; -import software.amazon.awssdk.services.kinesis.model.PutRecordsResponse; - -/** - * Async handler for {@link KinesisAsyncClient#putRecords(PutRecordsRequest)} that automatically - * retries unprocessed records in case of a partial success. - * - *

The handler enforces the provided upper limit of concurrent requests. Once that limit is - * reached any further call to {@link #putRecords(String, List)} will block until another request - * completed. - * - *

The handler is fail fast and won't submit any further request after a failure. Async failures - * can be polled using {@link #checkForAsyncFailure()}. - */ -@NotThreadSafe -@Internal -class AsyncPutRecordsHandler { - private static final Logger LOG = LoggerFactory.getLogger(AsyncPutRecordsHandler.class); - private final KinesisAsyncClient kinesis; - private final Supplier backoff; - private final int concurrentRequests; - private final Stats stats; - - private AtomicBoolean hasErrored; - private AtomicReference asyncFailure; - private Semaphore pendingRequests; - - AsyncPutRecordsHandler( - KinesisAsyncClient kinesis, int concurrency, Supplier backoff, Stats stats) { - this.kinesis = kinesis; - this.backoff = backoff; - this.concurrentRequests = concurrency; - this.hasErrored = new AtomicBoolean(false); - this.asyncFailure = new AtomicReference<>(); - this.pendingRequests = new Semaphore(concurrentRequests); - this.stats = stats; - } - - AsyncPutRecordsHandler( - KinesisAsyncClient kinesis, int concurrency, FluentBackoff backoff, Stats stats) { - this(kinesis, concurrency, () -> backoff.backoff(), stats); - } - - protected int pendingRequests() { - return concurrentRequests - pendingRequests.availablePermits(); - } - - void reset() { - hasErrored = new AtomicBoolean(false); - asyncFailure = new AtomicReference<>(); - pendingRequests = new Semaphore(concurrentRequests); - } - - /** If this handler has errored since it was last reset. */ - boolean hasErrored() { - return hasErrored.get(); - } - - /** - * Check if any failure happened async. - * - * @throws Throwable The last async failure, afterwards reset it. - */ - void checkForAsyncFailure() throws Throwable { - @SuppressWarnings("nullness") - Throwable failure = asyncFailure.getAndSet(null); - if (failure != null) { - throw failure; - } - } - - /** - * Wait for all pending requests to complete and check for failures. - * - * @throws Throwable The last async failure if present using {@link #checkForAsyncFailure()} - */ - void waitForCompletion() throws Throwable { - pendingRequests.acquireUninterruptibly(concurrentRequests); - checkForAsyncFailure(); - } - - /** - * Asynchronously trigger a put records request. - * - *

This will respect the concurrency limit of the handler and first wait for a permit. - * - * @throws Throwable The last async failure if present using {@link #checkForAsyncFailure()} - */ - void putRecords(String stream, List records) throws Throwable { - pendingRequests.acquireUninterruptibly(); - new RetryHandler(stream, records).run(); - checkForAsyncFailure(); - } - - interface Stats { - void addPutRecordsRequest(long latencyMillis, boolean isPartialRetry); - } - - /** - * This handler coordinates retries in case of a partial success. - * - *

    - *
  • Release permit if all (remaining) records are successful to allow for a new request to - * start. - *
  • Attempt retry in case of partial success for all erroneous records using backoff. Set - * async failure once retries are exceeded. - *
  • Set async failure if the entire request fails. Retries, if configured & applicable, have - * already been attempted by the AWS SDK in that case. - *
- * - * The next call of {@link #checkForAsyncFailure()}, {@link #putRecords(String, List)} or {@link - * #waitForCompletion()} will check for the last async failure and throw it. Afterwards the - * failure state is reset. - */ - private class RetryHandler implements BiConsumer { - private final int totalRecords; - private final String stream; - private final BackOff backoff; // backoff in case of throttling - - private final long handlerStartTime; - private long requestStartTime; - private int requests; - - private List records; - - RetryHandler(String stream, List records) { - this.stream = stream; - this.totalRecords = records.size(); - this.records = records; - this.backoff = AsyncPutRecordsHandler.this.backoff.get(); - this.handlerStartTime = DateTimeUtils.currentTimeMillis(); - this.requestStartTime = 0; - this.requests = 0; - } - - @SuppressWarnings({"FutureReturnValueIgnored"}) - void run() { - if (!hasErrored.get()) { - try { - requests++; - requestStartTime = DateTimeUtils.currentTimeMillis(); - PutRecordsRequest request = - PutRecordsRequest.builder().streamName(stream).records(records).build(); - kinesis.putRecords(request).whenComplete(this); - } catch (Throwable e) { - setAsyncFailure(e); - } - } - } - - @Override - public void accept(PutRecordsResponse response, Throwable throwable) { - try { - long now = DateTimeUtils.currentTimeMillis(); - long latencyMillis = now - requestStartTime; - synchronized (stats) { - stats.addPutRecordsRequest(latencyMillis, requests > 1); - } - if (response != null && !hasErrored.get()) { - if (!hasErrors(response)) { - // Request succeeded, release one permit - pendingRequests.release(); - LOG.debug( - "Done writing {} records [{} ms, {} request(s)]", - totalRecords, - now - handlerStartTime, - requests); - } else { - try { - if (BackOffUtils.next(Sleeper.DEFAULT, backoff)) { - LOG.info(summarizeErrors("Attempting retry", response)); - records = failedRecords(response); - run(); - } else { - throwable = new IOException(summarizeErrors("Exceeded retries", response)); - } - } catch (Throwable e) { - throwable = new IOException(summarizeErrors("Aborted retries", response), e); - } - } - } - } catch (Throwable e) { - throwable = e; - } - if (throwable != null) { - setAsyncFailure(throwable); - } - } - - private void setAsyncFailure(Throwable throwable) { - LOG.warn("Error when writing to Kinesis.", throwable); - hasErrored.set(true); - asyncFailure.updateAndGet( - ex -> { - if (ex != null) { - throwable.addSuppressed(ex); - } - return throwable; - }); - pendingRequests.release(concurrentRequests); // unblock everything to fail fast - } - - private boolean hasErrors(PutRecordsResponse response) { - return response.records().stream().anyMatch(e -> e.errorCode() != null); - } - - private List failedRecords(PutRecordsResponse response) { - return Streams.zip(records.stream(), response.records().stream(), Pair::of) - .filter(p -> p.getRight().errorCode() != null) - .map(p -> p.getLeft()) - .collect(toList()); - } - - private String summarizeErrors(String prefix, PutRecordsResponse response) { - Map countPerError = - response.records().stream() - .filter(e -> e.errorCode() != null) - .map(e -> e.errorCode()) - .collect(groupingBy(identity(), counting())); - return countPerError.entrySet().stream() - .map(kv -> String.format("%s for %d record(s)", kv.getKey(), kv.getValue())) - .collect(joining(", ", prefix + " after failure when writing to Kinesis: ", ".")); - } - } -} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java index cf6796198bb4..6338de1ef309 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java @@ -35,6 +35,7 @@ import java.util.Map; import java.util.NavigableSet; import java.util.TreeSet; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import java.util.function.Supplier; @@ -42,11 +43,11 @@ import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.io.Read.Unbounded; +import org.apache.beam.sdk.io.aws2.common.AsyncBatchWriteHandler; import org.apache.beam.sdk.io.aws2.common.ClientBuilderFactory; import org.apache.beam.sdk.io.aws2.common.ClientConfiguration; import org.apache.beam.sdk.io.aws2.common.ObjectPool; import org.apache.beam.sdk.io.aws2.common.ObjectPool.ClientPool; -import org.apache.beam.sdk.io.aws2.common.RetryConfiguration; import org.apache.beam.sdk.io.aws2.kinesis.KinesisPartitioner.ExplicitPartitioner; import org.apache.beam.sdk.io.aws2.options.AwsOptions; import org.apache.beam.sdk.metrics.Counter; @@ -61,7 +62,6 @@ import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.Sum; -import org.apache.beam.sdk.util.FluentBackoff; import org.apache.beam.sdk.util.MovingFunction; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; @@ -84,7 +84,9 @@ import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; import software.amazon.awssdk.services.kinesis.model.ListShardsRequest; +import software.amazon.awssdk.services.kinesis.model.PutRecordsRequest; import software.amazon.awssdk.services.kinesis.model.PutRecordsRequestEntry; +import software.amazon.awssdk.services.kinesis.model.PutRecordsResultEntry; import software.amazon.awssdk.services.kinesis.model.Shard; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler; @@ -559,7 +561,7 @@ public Read withCustomRateLimitPolicy(RateLimitPolicyFactory rateLimitPolicyFact * corresponds to the number of in-flight shard events which itself can contain multiple, * potentially even aggregated records. * - * @see {@link #withConsumerArn(String)} + * @see #withConsumerArn(String) */ public Read withMaxCapacityPerShard(Integer maxCapacity) { checkArgument(maxCapacity > 0, "maxCapacity must be positive, but was: %s", maxCapacity); @@ -901,30 +903,32 @@ private static class Writer implements AutoCloseable { protected final Write spec; protected final Stats stats; - protected final AsyncPutRecordsHandler handler; + protected final AsyncBatchWriteHandler handler; protected final KinesisAsyncClient kinesis; - private List requestEntries; private int requestBytes = 0; Writer(PipelineOptions options, Write spec) { ClientConfiguration clientConfig = spec.clientConfiguration(); - RetryConfiguration retryConfig = clientConfig.retry(); - FluentBackoff backoff = FluentBackoff.DEFAULT.withMaxRetries(PARTIAL_RETRIES); - if (retryConfig != null) { - if (retryConfig.throttledBaseBackoff() != null) { - backoff = backoff.withInitialBackoff(retryConfig.throttledBaseBackoff()); - } - if (retryConfig.maxBackoff() != null) { - backoff = backoff.withMaxBackoff(retryConfig.maxBackoff()); - } - } this.spec = spec; this.stats = new Stats(); this.kinesis = CLIENTS.retain(options.as(AwsOptions.class), clientConfig); - this.handler = - new AsyncPutRecordsHandler(kinesis, spec.concurrentRequests(), backoff, stats); this.requestEntries = new ArrayList<>(); + this.handler = + AsyncBatchWriteHandler.byPosition( + spec.concurrentRequests(), + PARTIAL_RETRIES, + clientConfig.retry(), + stats, + (stream, records) -> putRecords(kinesis, stream, records), + r -> r.errorCode()); + } + + private static CompletableFuture> putRecords( + KinesisAsyncClient kinesis, String stream, List records) { + PutRecordsRequest req = + PutRecordsRequest.builder().streamName(stream).records(records).build(); + return kinesis.putRecords(req).thenApply(resp -> resp.records()); } public void startBundle() { @@ -998,7 +1002,7 @@ protected final void asyncFlushEntries() throws Throwable { List recordsToWrite = requestEntries; requestEntries = new ArrayList<>(); requestBytes = 0; - handler.putRecords(spec.streamName(), recordsToWrite); + handler.batchWrite(spec.streamName(), recordsToWrite); } } @@ -1115,7 +1119,7 @@ protected void write(String partitionKey, @Nullable String explicitHashKey, byte } // only check timeouts sporadically if concurrency is already maxed out - if (handler.pendingRequests() < spec.concurrentRequests() || Math.random() < 0.05) { + if (handler.requestsInProgress() < spec.concurrentRequests() || Math.random() < 0.05) { checkAggregationTimeouts(); } } @@ -1275,7 +1279,7 @@ private BigInteger lowerHashKey(Shard shard) { } } - private static class Stats implements AsyncPutRecordsHandler.Stats { + private static class Stats implements AsyncBatchWriteHandler.Stats { private static final Logger LOG = LoggerFactory.getLogger(Stats.class); private static final Duration LOG_STATS_PERIOD = Duration.standardSeconds(10); @@ -1328,7 +1332,7 @@ void addClientRecord(int recordBytes) { } @Override - public void addPutRecordsRequest(long latencyMillis, boolean isPartialRetry) { + public void addBatchWriteRequest(long latencyMillis, boolean isPartialRetry) { long timeMillis = DateTimeUtils.currentTimeMillis(); numPutRequests.add(timeMillis, 1); if (isPartialRetry) { diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java index 72befc6003fe..ac31738154a6 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java @@ -17,10 +17,26 @@ */ package org.apache.beam.sdk.io.aws2.sqs; +import static java.util.Collections.EMPTY_LIST; +import static org.apache.beam.sdk.io.aws2.common.ClientBuilderFactory.buildClient; +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; import com.google.auto.value.AutoValue; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; import java.util.function.Consumer; +import javax.annotation.concurrent.NotThreadSafe; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.aws2.common.AsyncBatchWriteHandler; +import org.apache.beam.sdk.io.aws2.common.AsyncBatchWriteHandler.Stats; import org.apache.beam.sdk.io.aws2.common.ClientBuilderFactory; import org.apache.beam.sdk.io.aws2.common.ClientConfiguration; import org.apache.beam.sdk.io.aws2.options.AwsOptions; @@ -31,11 +47,24 @@ import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.dataflow.qual.Pure; import org.joda.time.Duration; +import org.joda.time.Instant; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sqs.SqsAsyncClient; import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sqs.model.BatchResultErrorEntry; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry; import software.amazon.awssdk.services.sqs.model.SendMessageRequest; /** @@ -91,21 +120,28 @@ * then opt to retry the current partition in entirety or abort if the max number of retries of the * runner is reached. */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) public class SqsIO { public static Read read() { return new AutoValue_SqsIO_Read.Builder() - .setClientConfiguration(ClientConfiguration.builder().build()) + .setClientConfiguration(ClientConfiguration.EMPTY) .setMaxNumRecords(Long.MAX_VALUE) .build(); } public static Write write() { return new AutoValue_SqsIO_Write.Builder() - .setClientConfiguration(ClientConfiguration.builder().build()) + .setClientConfiguration(ClientConfiguration.EMPTY) + .build(); + } + + public static WriteBatches writeBatches(WriteBatches.EntryBuilder entryBuilder) { + return new AutoValue_SqsIO_WriteBatches.Builder() + .clientConfiguration(ClientConfiguration.EMPTY) + .concurrentRequests(WriteBatches.DEFAULT_CONCURRENCY) + .batchSize(WriteBatches.MAX_BATCH_SIZE) + .batchTimeout(WriteBatches.DEFAULT_BATCH_TIMEOUT) + .entryBuilder(entryBuilder) .build(); } @@ -124,7 +160,7 @@ public abstract static class Read extends PTransform expand(PBegin input) { AwsOptions awsOptions = input.getPipeline().getOptions().as(AwsOptions.class); ClientBuilderFactory.validate(awsOptions, clientConfiguration()); @@ -188,7 +225,6 @@ public PCollection expand(PBegin input) { return input.getPipeline().apply(transform); } } - // TODO: Add write batch api to improve performance /** * A {@link PTransform} to send messages to SQS. See {@link SqsIO} for more information on usage * and configuration. @@ -196,7 +232,7 @@ public PCollection expand(PBegin input) { @AutoValue public abstract static class Write extends PTransform, PDone> { - abstract ClientConfiguration getClientConfiguration(); + abstract @Pure ClientConfiguration getClientConfiguration(); abstract Builder builder(); @@ -225,7 +261,7 @@ public PDone expand(PCollection input) { private static class SqsWriteFn extends DoFn { private final Write spec; - private transient SqsClient sqs; + private transient @MonotonicNonNull SqsClient sqs = null; SqsWriteFn(Write write) { this.spec = write; @@ -241,7 +277,441 @@ public void setup(PipelineOptions options) throws Exception { @ProcessElement public void processElement(ProcessContext processContext) throws Exception { + if (sqs == null) { + throw new IllegalStateException("No SQS client"); + } sqs.sendMessage(processContext.element()); } } + + /** + * A {@link PTransform} to send messages to SQS. See {@link SqsIO} for more information on usage + * and configuration. + */ + @AutoValue + public abstract static class WriteBatches + extends PTransform, WriteBatches.Result> { + private static final int DEFAULT_CONCURRENCY = 5; + private static final int MAX_BATCH_SIZE = 10; + private static final Duration DEFAULT_BATCH_TIMEOUT = Duration.standardSeconds(3); + + abstract @Pure int concurrentRequests(); + + abstract @Pure Duration batchTimeout(); + + abstract @Pure int batchSize(); + + abstract @Pure ClientConfiguration clientConfiguration(); + + abstract @Pure EntryBuilder entryBuilder(); + + abstract @Pure @Nullable DynamicDestination dynamicDestination(); + + abstract @Pure @Nullable String queueUrl(); + + abstract Builder builder(); + + public interface DynamicDestination extends Serializable { + String queueUrl(T message); + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder concurrentRequests(int concurrentRequests); + + abstract Builder batchTimeout(Duration duration); + + abstract Builder batchSize(int batchSize); + + abstract Builder clientConfiguration(ClientConfiguration config); + + abstract Builder entryBuilder(EntryBuilder entryBuilder); + + abstract Builder dynamicDestination(@Nullable DynamicDestination destination); + + abstract Builder queueUrl(@Nullable String queueUrl); + + abstract WriteBatches build(); + } + + /** Configuration of SQS client. */ + public WriteBatches withClientConfiguration(ClientConfiguration config) { + checkArgument(config != null, "ClientConfiguration cannot be null"); + return builder().clientConfiguration(config).build(); + } + + /** Max number of concurrent batch write requests per bundle, default is {@code 5}. */ + public WriteBatches withConcurrentRequests(int concurrentRequests) { + checkArgument(concurrentRequests > 0, "concurrentRequests must be > 0"); + return builder().concurrentRequests(concurrentRequests).build(); + } + + /** The batch size to use, default (and AWS limit) is {@code 10}. */ + public WriteBatches withBatchSize(int batchSize) { + checkArgument( + batchSize > 0 && batchSize <= MAX_BATCH_SIZE, + "Maximum allowed batch size is " + MAX_BATCH_SIZE); + return builder().batchSize(batchSize).build(); + } + + /** + * The duration to accumulate records before timing out, default is 3 secs. + * + *

Timeouts will be checked upon arrival of new messages. + */ + public WriteBatches withBatchTimeout(Duration timeout) { + return builder().batchTimeout(timeout).build(); + } + + /** Dynamic record based destination to write to. */ + public WriteBatches to(DynamicDestination destination) { + checkArgument(destination != null, "DynamicDestination cannot be null"); + return builder().queueUrl(null).dynamicDestination(destination).build(); + } + + /** Queue url to write to. */ + public WriteBatches to(String queueUrl) { + checkArgument(queueUrl != null, "queueUrl cannot be null"); + return builder().dynamicDestination(null).queueUrl(queueUrl).build(); + } + + @Override + public Result expand(PCollection input) { + AwsOptions awsOptions = input.getPipeline().getOptions().as(AwsOptions.class); + ClientBuilderFactory.validate(awsOptions, clientConfiguration()); + + input.apply( + ParDo.of( + new DoFn() { + private @Nullable BatchHandler handler = null; + + @Setup + public void setup(PipelineOptions options) { + handler = new BatchHandler<>(WriteBatches.this, options.as(AwsOptions.class)); + } + + @StartBundle + public void startBundle() { + handler().startBundle(); + } + + @ProcessElement + public void processElement(ProcessContext cxt) throws Throwable { + handler().process(cxt.element()); + } + + @FinishBundle + public void finishBundle() throws Throwable { + handler().finishBundle(); + } + + @Teardown + public void teardown() throws Exception { + if (handler != null) { + handler.close(); + handler = null; + } + } + + private BatchHandler handler() { + return checkStateNotNull(handler, "SQS handler is null"); + } + })); + return new Result(input.getPipeline()); + } + + /** Batch entry builder. */ + public interface EntryBuilder + extends BiConsumer, Serializable {} + + /** Result of {@link #writeBatches}. */ + public static class Result implements POutput { + private final Pipeline pipeline; + + private Result(Pipeline pipeline) { + this.pipeline = pipeline; + } + + @Override + public Pipeline getPipeline() { + return pipeline; + } + + @Override + public Map, PValue> expand() { + return ImmutableMap.of(); + } + + @Override + public void finishSpecifyingOutput( + String transformName, PInput input, PTransform transform) {} + } + + private static class BatchHandler implements AutoCloseable { + private final WriteBatches spec; + private final SqsAsyncClient sqs; + private final Batches batches; + private final AsyncBatchWriteHandler + handler; + + BatchHandler(WriteBatches spec, AwsOptions options) { + this.spec = spec; + this.sqs = buildClient(options, SqsAsyncClient.builder(), spec.clientConfiguration()); + this.handler = + AsyncBatchWriteHandler.byId( + spec.concurrentRequests(), + spec.batchSize(), + spec.clientConfiguration().retry(), + Stats.NONE, + (queue, records) -> sendMessageBatch(sqs, queue, records), + error -> error.code(), + record -> record.id(), + error -> error.id()); + if (spec.queueUrl() != null) { + this.batches = new Single(spec.queueUrl()); + } else if (spec.dynamicDestination() != null) { + this.batches = new Dynamic(spec.dynamicDestination()); + } else { + throw new IllegalStateException("to(queueUrl) or to(dynamicDestination) is required"); + } + } + + private static CompletableFuture> sendMessageBatch( + SqsAsyncClient sqs, String queue, List records) { + SendMessageBatchRequest request = + SendMessageBatchRequest.builder().queueUrl(queue).entries(records).build(); + return sqs.sendMessageBatch(request).thenApply(resp -> resp.failed()); + } + + public void startBundle() { + handler.reset(); + } + + public void process(T msg) { + SendMessageBatchRequestEntry.Builder builder = SendMessageBatchRequestEntry.builder(); + spec.entryBuilder().accept(builder, msg); + SendMessageBatchRequestEntry entry = builder.id(batches.nextId()).build(); + + Batch batch = batches.getLocked(msg); + batch.add(entry); + if (batch.size() >= spec.batchSize() || batch.isExpired()) { + writeEntries(batch, true); + } else { + checkState(batch.lock(false)); // unlock to continue writing to batch + } + + // check timeouts synchronously on arrival of new messages + batches.writeExpired(true); + } + + private void writeEntries(Batch batch, boolean throwPendingFailures) { + try { + handler.batchWrite(batch.queue, batch.getAndClear(), throwPendingFailures); + } catch (RuntimeException e) { + throw e; + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + + public void finishBundle() throws Throwable { + batches.writeAll(); + handler.waitForCompletion(); + } + + @Override + public void close() throws Exception { + sqs.close(); + } + + /** + * Batch(es) of a single fixed or several dynamic queues. + * + *

{@link #getLocked} is meant to support atomic writes from multiple threads if using an + * appropriate thread-safe implementation. This is necessary to later support strict timeouts + * (see below). + * + *

For simplicity, check for expired messages after appending to a batch. For strict + * enforcement of timeouts, {@link #writeExpired} would have to be periodically called using a + * scheduler and requires also a thread-safe impl of {@link Batch#lock(boolean)}. + */ + private abstract class Batches { + private int nextId = 0; // only ever used from one "runner" thread + + abstract int maxBatches(); + + /** Next batch entry id is guaranteed to be unique for all open batches. */ + String nextId() { + if (nextId >= (spec.batchSize() * maxBatches())) { + nextId = 0; + } + return Integer.toString(nextId++); + } + + /** Get existing or new locked batch that can be written to. */ + abstract Batch getLocked(T record); + + /** Write all remaining batches (that can be locked). */ + abstract void writeAll(); + + /** Write all expired batches (that can be locked). */ + abstract void writeExpired(boolean throwPendingFailures); + + /** Create a new locked batch that is ready for writing. */ + Batch createLocked(String queue) { + return new Batch(queue, spec.batchSize(), spec.batchTimeout()); + } + + /** Write a batch if it can be locked. */ + protected boolean writeLocked(Batch batch, boolean throwPendingFailures) { + if (batch.lock(true)) { + writeEntries(batch, throwPendingFailures); + return true; + } + return false; + } + } + + /** Batch of a single, fixed queue. */ + @NotThreadSafe + private class Single extends Batches { + private Batch batch; + + Single(String queue) { + this.batch = new Batch(queue, EMPTY_LIST, Batch.NEVER); // locked + } + + @Override + int maxBatches() { + return 1; + } + + @Override + Batch getLocked(T record) { + return batch.lock(true) ? batch : (batch = createLocked(batch.queue)); + } + + @Override + void writeAll() { + writeLocked(batch, true); + } + + @Override + void writeExpired(boolean throwPendingFailures) { + if (batch.isExpired()) { + writeLocked(batch, throwPendingFailures); + } + } + } + + /** Batches of one or several dynamic queues. */ + @NotThreadSafe + private class Dynamic extends Batches { + @SuppressWarnings("method.invocation") // necessary dependencies are initialized + private final BiFunction<@NonNull String, @Nullable Batch, Batch> getLocked = + (queue, batch) -> batch != null && batch.lock(true) ? batch : createLocked(queue); + + private final Map<@NonNull String, Batch> batches = new HashMap<>(); + private final DynamicDestination destination; + private Instant nextTimeout = Batch.NEVER; + + Dynamic(DynamicDestination destination) { + this.destination = destination; + } + + @Override + int maxBatches() { + return batches.size() + 1; // next record and id might belong to new batch + } + + @Override + Batch getLocked(T record) { + return batches.compute(destination.queueUrl(record), getLocked); + } + + @Override + void writeAll() { + batches.values().forEach(batch -> writeLocked(batch, true)); + batches.clear(); + nextTimeout = Batch.NEVER; + } + + private void writeExpired(Batch batch) { + if (!batch.isExpired() || !writeLocked(batch, true)) { + // find next timeout for remaining, unwritten batches + if (batch.timeout.isBefore(nextTimeout)) { + nextTimeout = batch.timeout; + } + } + } + + @Override + void writeExpired(boolean throwPendingFailures) { + if (nextTimeout.isBeforeNow()) { + nextTimeout = Batch.NEVER; + batches.values().forEach(this::writeExpired); + } + } + + @Override + Batch createLocked(String queue) { + Batch batch = super.createLocked(queue); + if (batch.timeout.isBefore(nextTimeout)) { + nextTimeout = batch.timeout; + } + return batch; + } + } + } + + /** + * Batch of entries of a queue. + * + *

Overwrite {@link #lock} with a thread-safe implementation to support concurrent usage. + */ + @NotThreadSafe + private static final class Batch { + private static final Instant NEVER = Instant.ofEpochMilli(Long.MAX_VALUE); + private final String queue; + private Instant timeout; + private List entries; + + Batch(String queue, int size, Duration bufferedTime) { + this(queue, new ArrayList<>(size), Instant.now().plus(bufferedTime)); + } + + Batch(String queue, List entries, Instant timeout) { + this.queue = queue; + this.entries = entries; + this.timeout = timeout; + } + + /** Attempt to un/lock this batch and return if successful. */ + boolean lock(boolean lock) { + // thread unsafe dummy impl that rejects locking batches after getAndClear + return !NEVER.equals(timeout) || !lock; + } + + /** Get and clear entries for writing. */ + List getAndClear() { + List res = entries; + entries = EMPTY_LIST; + timeout = NEVER; + return res; + } + + /** Add entry to this batch. */ + void add(SendMessageBatchRequestEntry entry) { + entries.add(entry); + } + + int size() { + return entries.size(); + } + + boolean isExpired() { + return timeout.isBeforeNow(); + } + } + } } diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/AsyncBatchWriteHandlerTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/AsyncBatchWriteHandlerTest.java new file mode 100644 index 000000000000..339fa673ed17 --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/AsyncBatchWriteHandlerTest.java @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.common; + +import static java.util.Collections.emptyList; +import static java.util.concurrent.ForkJoinPool.commonPool; +import static java.util.function.Function.identity; +import static org.apache.beam.sdk.io.aws2.common.AsyncBatchWriteHandler.byId; +import static org.apache.beam.sdk.io.aws2.common.AsyncBatchWriteHandler.byPosition; +import static org.apache.beam.sdk.util.FluentBackoff.DEFAULT; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.joda.time.Duration.millis; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; +import org.apache.beam.sdk.io.aws2.common.AsyncBatchWriteHandler.Stats; +import org.apache.beam.sdk.util.FluentBackoff; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.junit.Test; +import org.junit.function.ThrowingRunnable; +import org.mockito.Mockito; + +public class AsyncBatchWriteHandlerTest { + private static final int CONCURRENCY = 10; + + private CompletableFuture> resultsByPos = new CompletableFuture<>(); + private CompletableFuture> errorsById = new CompletableFuture<>(); + + private AsyncBatchWriteHandler byPositionHandler(FluentBackoff backoff) { + SubmitFn submitFn = Mockito.spy(new SubmitFn<>(() -> resultsByPos)); + Function errorFn = success -> success ? null : "REASON"; + return byPosition(CONCURRENCY, backoff, Stats.NONE, submitFn, errorFn); + } + + private AsyncBatchWriteHandler byIdHandler(FluentBackoff backoff) { + SubmitFn submitFn = Mockito.spy(new SubmitFn<>(() -> errorsById)); + Function errorFn = err -> "REASON"; + return byId(CONCURRENCY, backoff, Stats.NONE, submitFn, errorFn, identity(), identity()); + } + + @Test + public void retryOnPartialSuccessByPosition() throws Throwable { + AsyncBatchWriteHandler handler = + byPositionHandler(DEFAULT.withMaxBackoff(millis(1))); + CompletableFuture> pendingResponse1 = new CompletableFuture<>(); + CompletableFuture> pendingResponse2 = new CompletableFuture<>(); + CompletableFuture> pendingResponse3 = new CompletableFuture<>(); + + resultsByPos = pendingResponse1; + handler.batchWrite("destination", ImmutableList.of(1, 2, 3, 4)); + + // 1st attempt + eventually(5, () -> verify(handler.submitFn, times(1)).apply(anyString(), anyList())); + verify(handler.submitFn).apply("destination", ImmutableList.of(1, 2, 3, 4)); + assertThat(handler.requestsInProgress()).isEqualTo(1); + + resultsByPos = pendingResponse2; + pendingResponse1.complete(ImmutableList.of(true, true, false, false)); + + // 2nd attempt + eventually(5, () -> verify(handler.submitFn, times(2)).apply(anyString(), anyList())); + verify(handler.submitFn).apply("destination", ImmutableList.of(3, 4)); + assertThat(handler.requestsInProgress()).isEqualTo(1); + + // 3rd attempt + resultsByPos = pendingResponse3; + pendingResponse2.complete(ImmutableList.of(true, false)); + + eventually(5, () -> verify(handler.submitFn, times(3)).apply(anyString(), anyList())); + verify(handler.submitFn).apply("destination", ImmutableList.of(4)); + + assertThat(handler.requestsInProgress()).isEqualTo(1); + + // 4th attempt + pendingResponse3.complete(ImmutableList.of(true)); // success + + eventually(5, () -> assertThat(handler.requestsInProgress()).isEqualTo(0)); + verify(handler.submitFn, times(3)).apply(anyString(), anyList()); + } + + @Test + public void retryOnPartialSuccessById() throws Throwable { + AsyncBatchWriteHandler handler = byIdHandler(DEFAULT.withMaxBackoff(millis(1))); + CompletableFuture> pendingResponse1 = new CompletableFuture<>(); + CompletableFuture> pendingResponse2 = new CompletableFuture<>(); + CompletableFuture> pendingResponse3 = new CompletableFuture<>(); + + errorsById = pendingResponse1; + handler.batchWrite("destination", ImmutableList.of("1", "2", "3", "4")); + + // 1st attempt + eventually(5, () -> verify(handler.submitFn, times(1)).apply(anyString(), anyList())); + verify(handler.submitFn).apply("destination", ImmutableList.of("1", "2", "3", "4")); + assertThat(handler.requestsInProgress()).isEqualTo(1); + + errorsById = pendingResponse2; + pendingResponse1.complete(ImmutableList.of("3", "4")); + + // 2nd attempt + eventually(5, () -> verify(handler.submitFn, times(2)).apply(anyString(), anyList())); + verify(handler.submitFn).apply("destination", ImmutableList.of("3", "4")); + assertThat(handler.requestsInProgress()).isEqualTo(1); + + // 3rd attempt + errorsById = pendingResponse3; + pendingResponse2.complete(ImmutableList.of("4")); + + eventually(5, () -> verify(handler.submitFn, times(3)).apply(anyString(), anyList())); + verify(handler.submitFn).apply("destination", ImmutableList.of("4")); + + assertThat(handler.requestsInProgress()).isEqualTo(1); + + // 4th attempt + pendingResponse3.complete(ImmutableList.of()); // success + + eventually(5, () -> assertThat(handler.requestsInProgress()).isEqualTo(0)); + verify(handler.submitFn, times(3)).apply(anyString(), anyList()); + } + + @Test + public void retryLimitOnPartialSuccessByPosition() throws Throwable { + AsyncBatchWriteHandler handler = byPositionHandler(DEFAULT.withMaxRetries(0)); + + handler.batchWrite("destination", ImmutableList.of(1, 2, 3, 4)); + + resultsByPos.complete(ImmutableList.of(true, true, false, false)); + + assertThatThrownBy(() -> handler.waitForCompletion()) + .hasMessageContaining("Exceeded retries") + .hasMessageEndingWith("REASON for 2 record(s).") + .isInstanceOf(IOException.class); + verify(handler.submitFn).apply("destination", ImmutableList.of(1, 2, 3, 4)); + } + + @Test + public void retryLimitOnPartialSuccessById() throws Throwable { + AsyncBatchWriteHandler handler = byIdHandler(DEFAULT.withMaxRetries(0)); + + handler.batchWrite("destination", ImmutableList.of("1", "2", "3", "4")); + + errorsById.complete(ImmutableList.of("3", "4")); + + assertThatThrownBy(() -> handler.waitForCompletion()) + .hasMessageContaining("Exceeded retries") + .hasMessageEndingWith("REASON for 2 record(s).") + .isInstanceOf(IOException.class); + verify(handler.submitFn).apply("destination", ImmutableList.of("1", "2", "3", "4")); + } + + @Test + public void propagateErrorOnPutRecords() throws Throwable { + AsyncBatchWriteHandler handler = byPositionHandler(DEFAULT); + handler.batchWrite("destination", emptyList()); + resultsByPos.completeExceptionally(new RuntimeException("Request failed")); + + assertThatThrownBy(() -> handler.batchWrite("destination", emptyList())) + .hasMessage("Request failed"); + assertThat(handler.hasErrored()).isTrue(); + verify(handler.submitFn).apply("destination", emptyList()); + } + + @Test + public void propagateErrorWhenPolling() throws Throwable { + AsyncBatchWriteHandler handler = byPositionHandler(DEFAULT); + handler.batchWrite("destination", emptyList()); + handler.checkForAsyncFailure(); // none yet + resultsByPos.completeExceptionally(new RuntimeException("Request failed")); + + assertThatThrownBy(() -> handler.checkForAsyncFailure()).hasMessage("Request failed"); + assertThat(handler.hasErrored()).isTrue(); + handler.checkForAsyncFailure(); // already reset + } + + @Test + public void propagateErrorOnWaitForCompletion() throws Throwable { + AsyncBatchWriteHandler handler = byPositionHandler(DEFAULT); + handler.batchWrite("destination", emptyList()); + resultsByPos.completeExceptionally(new RuntimeException("Request failed")); + + assertThatThrownBy(() -> handler.waitForCompletion()).hasMessage("Request failed"); + } + + @Test + public void correctlyLimitConcurrency() throws Throwable { + AsyncBatchWriteHandler handler = byPositionHandler(DEFAULT); + + // exhaust concurrency limit so that batchWrite blocks + Runnable task = repeat(CONCURRENCY + 1, () -> handler.batchWrite("destination", emptyList())); + Future future = commonPool().submit(task); + + eventually(5, () -> assertThat(handler.requestsInProgress()).isEqualTo(CONCURRENCY)); + eventually( + 5, () -> verify(handler.submitFn, times(CONCURRENCY)).apply("destination", emptyList())); + assertThat(future).isNotDone(); + + // complete responses and unblock last request + resultsByPos.complete(emptyList()); + + eventually( + 5, + () -> verify(handler.submitFn, times(CONCURRENCY + 1)).apply("destination", emptyList())); + handler.waitForCompletion(); + assertThat(future).isDone(); + } + + static class SubmitFn implements BiFunction, CompletableFuture>> { + private final Supplier>> resp; + + SubmitFn(Supplier>> resp) { + this.resp = resp; + } + + @Override + public CompletableFuture> apply(String destination, List input) { + return resp.get(); + } + } + + private void eventually(int attempts, Runnable fun) { + for (int i = 0; i < attempts - 1; i++) { + try { + Thread.sleep(i * 100); + fun.run(); + return; + } catch (AssertionError | InterruptedException t) { + } + } + fun.run(); + } + + private Runnable repeat(int times, ThrowingRunnable fun) { + return () -> { + for (int i = 0; i < times; i++) { + try { + fun.run(); + } catch (Throwable t) { + throw new RuntimeException(t); + } + } + }; + } +} diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/AsyncPutRecordsHandlerTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/AsyncPutRecordsHandlerTest.java deleted file mode 100644 index 6f0b92654de4..000000000000 --- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/AsyncPutRecordsHandlerTest.java +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws2.kinesis; - -import static java.util.Collections.emptyList; -import static java.util.concurrent.ForkJoinPool.commonPool; -import static org.apache.beam.sdk.io.common.TestRow.getExpectedValues; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; - -import java.io.IOException; -import java.util.List; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Future; -import java.util.function.Supplier; -import org.apache.beam.sdk.util.BackOff; -import org.junit.Before; -import org.junit.Test; -import org.junit.function.ThrowingRunnable; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnitRunner; -import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; -import software.amazon.awssdk.services.kinesis.model.PutRecordsRequest; -import software.amazon.awssdk.services.kinesis.model.PutRecordsRequestEntry; -import software.amazon.awssdk.services.kinesis.model.PutRecordsResponse; - -@RunWith(MockitoJUnitRunner.StrictStubs.class) -public class AsyncPutRecordsHandlerTest extends PutRecordsHelpers { - private static final String STREAM = "streamName"; - private static final int CONCURRENCY = 10; - - private CompletableFuture pendingResponse = new CompletableFuture<>(); - - @Mock private KinesisAsyncClient client; - @Mock private Supplier backoff; - private AsyncPutRecordsHandler handler; - - @Before - public void init() { - handler = - new AsyncPutRecordsHandler( - client, CONCURRENCY, backoff, mock(AsyncPutRecordsHandler.Stats.class)); - when(client.putRecords(anyRequest())).thenReturn(pendingResponse); - } - - @Test - public void retryOnPartialSuccess() throws Throwable { - when(backoff.get()).thenReturn(BackOff.ZERO_BACKOFF); - CompletableFuture pendingResponse2 = new CompletableFuture<>(); - CompletableFuture pendingResponse3 = new CompletableFuture<>(); - when(client.putRecords(anyRequest())) - .thenReturn(pendingResponse, pendingResponse2, pendingResponse3); - - List records = fromTestRows(getExpectedValues(0, 100)); - handler.putRecords(STREAM, records); - - // 1st attempt - eventually(5, () -> verify(client, times(1)).putRecords(anyRequest())); - verify(client).putRecords(request(records)); - assertThat(handler.pendingRequests()).isEqualTo(1); - - // 2nd attempt - pendingResponse.complete(partialSuccessResponse(50, 50)); - eventually(5, () -> verify(client, times(2)).putRecords(anyRequest())); - verify(client).putRecords(request(records.subList(50, 100))); - assertThat(handler.pendingRequests()).isEqualTo(1); - - // 3rd attempt - pendingResponse2.complete(partialSuccessResponse(25, 25)); - eventually(5, () -> verify(client, times(3)).putRecords(anyRequest())); - verify(client).putRecords(request(records.subList(75, 100))); - assertThat(handler.pendingRequests()).isEqualTo(1); - - // 4th attempt - pendingResponse3.complete(PutRecordsResponse.builder().build()); // success - verifyNoMoreInteractions(client); - - eventually(5, () -> assertThat(handler.pendingRequests()).isEqualTo(0)); - } - - @Test - public void retryLimitOnPartialSuccess() throws Throwable { - when(backoff.get()).thenReturn(BackOff.STOP_BACKOFF); - - List records = fromTestRows(getExpectedValues(0, 100)); - handler.putRecords(STREAM, records); - - pendingResponse.complete(partialSuccessResponse(98, 2)); - - assertThatThrownBy(() -> handler.waitForCompletion()) - .hasMessageContaining("Exceeded retries") - .hasMessageEndingWith(ERROR_CODE + " for 2 record(s).") - .isInstanceOf(IOException.class); - verify(client).putRecords(anyRequest()); - } - - @Test - public void propagateErrorOnPutRecords() throws Throwable { - handler.putRecords(STREAM, emptyList()); - pendingResponse.completeExceptionally(new RuntimeException("Request failed")); - - assertThatThrownBy(() -> handler.putRecords(STREAM, emptyList())).hasMessage("Request failed"); - assertThat(handler.hasErrored()).isTrue(); - verify(client).putRecords(anyRequest()); - } - - @Test - public void propagateErrorWhenPolling() throws Throwable { - handler.putRecords(STREAM, emptyList()); - handler.checkForAsyncFailure(); // none yet - pendingResponse.completeExceptionally(new RuntimeException("Request failed")); - - assertThatThrownBy(() -> handler.checkForAsyncFailure()).hasMessage("Request failed"); - assertThat(handler.hasErrored()).isTrue(); - handler.checkForAsyncFailure(); // already reset - } - - @Test - public void propagateErrorOnWaitForCompletion() throws Throwable { - handler.putRecords(STREAM, emptyList()); - pendingResponse.completeExceptionally(new RuntimeException("Request failed")); - - assertThatThrownBy(() -> handler.waitForCompletion()).hasMessage("Request failed"); - } - - @Test - public void correctlyLimitConcurrency() throws Throwable { - // exhaust concurrency limit so that putRecords blocks - Runnable task = repeat(CONCURRENCY + 1, () -> handler.putRecords(STREAM, emptyList())); - Future future = commonPool().submit(task); - - eventually(5, () -> assertThat(handler.pendingRequests()).isEqualTo(CONCURRENCY)); - eventually(5, () -> verify(client, times(CONCURRENCY)).putRecords(anyRequest())); - assertThat(future).isNotDone(); - - // complete responses and unblock last request - pendingResponse.complete(PutRecordsResponse.builder().build()); - - eventually(5, () -> verify(client, times(CONCURRENCY + 1)).putRecords(anyRequest())); - handler.waitForCompletion(); - assertThat(future).isDone(); - } - - private PutRecordsRequest request(List records) { - return PutRecordsRequest.builder().streamName(STREAM).records(records).build(); - } - - private void eventually(int attempts, Runnable fun) { - for (int i = 0; i < attempts - 1; i++) { - try { - Thread.sleep(i * 100); - fun.run(); - return; - } catch (AssertionError | InterruptedException t) { - } - } - fun.run(); - } - - private Runnable repeat(int times, ThrowingRunnable fun) { - return () -> { - for (int i = 0; i < times; i++) { - try { - fun.run(); - } catch (Throwable t) { - throw new RuntimeException(t); - } - } - }; - } -} diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java new file mode 100644 index 000000000000..aeb9122df9f2 --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.sqs; + +import static java.util.concurrent.CompletableFuture.completedFuture; +import static java.util.concurrent.CompletableFuture.supplyAsync; +import static java.util.stream.Collectors.toList; +import static java.util.stream.IntStream.range; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.joda.time.Duration.millis; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.aws2.MockClientBuilderFactory; +import org.apache.beam.sdk.io.aws2.common.AsyncBatchWriteHandler; +import org.apache.beam.sdk.io.aws2.common.ClientConfiguration; +import org.apache.beam.sdk.io.aws2.common.RetryConfiguration; +import org.apache.beam.sdk.io.aws2.sqs.SqsIO.WriteBatches; +import org.apache.beam.sdk.io.aws2.sqs.SqsIO.WriteBatches.EntryBuilder; +import org.apache.beam.sdk.testing.ExpectedLogs; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams; +import org.joda.time.Duration; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import software.amazon.awssdk.services.sqs.SqsAsyncClient; +import software.amazon.awssdk.services.sqs.SqsAsyncClientBuilder; +import software.amazon.awssdk.services.sqs.model.BatchResultErrorEntry; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse; + +/** Tests for {@link WriteBatches}. */ +@RunWith(MockitoJUnitRunner.class) +public class SqsIOWriteBatchesTest { + private static final EntryBuilder SET_MESSAGE_BODY = + SendMessageBatchRequestEntry.Builder::messageBody; + private static final SendMessageBatchResponse SUCCESS = + SendMessageBatchResponse.builder().build(); + + @Rule public TestPipeline p = TestPipeline.create(); + @Mock public SqsAsyncClient sqs; + @Rule public ExpectedLogs logs = ExpectedLogs.none(AsyncBatchWriteHandler.class); + + @Before + public void configureClientBuilderFactory() { + MockClientBuilderFactory.set(p, SqsAsyncClientBuilder.class, sqs); + } + + @Test + public void testWriteBatches() { + when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS)); + + p.apply(Create.of(23)) + .apply(ParDo.of(new CreateMessages())) + .apply(SqsIO.writeBatches(SET_MESSAGE_BODY).to("queue")); + + p.run().waitUntilFinish(); + + verify(sqs).sendMessageBatch(request("queue", range(0, 10))); + verify(sqs).sendMessageBatch(request("queue", range(10, 20))); + verify(sqs).sendMessageBatch(request("queue", range(20, 23))); + + verify(sqs).close(); + verifyNoMoreInteractions(sqs); + } + + @Test + public void testWriteBatchesFailure() { + when(sqs.sendMessageBatch(anyRequest())) + .thenReturn( + completedFuture(SUCCESS), + supplyAsync(() -> checkNotNull(null, "sendMessageBatch failed")), + completedFuture(SUCCESS)); + + p.apply(Create.of(23)) + .apply(ParDo.of(new CreateMessages())) + .apply(SqsIO.writeBatches(SET_MESSAGE_BODY).to("queue")); + + assertThatThrownBy(() -> p.run().waitUntilFinish()) + .isInstanceOf(Pipeline.PipelineExecutionException.class) + .hasMessageContaining("sendMessageBatch failed"); + } + + @Test + public void testWriteBatchesPartialSuccess() { + SendMessageBatchRequestEntry[] entries = entries(range(0, 10)); + when(sqs.sendMessageBatch(anyRequest())) + .thenReturn( + completedFuture(partialSuccessResponse(entries[2].id(), entries[3].id())), + completedFuture(partialSuccessResponse(entries[3].id())), + completedFuture(SUCCESS)); + + p.apply(Create.of(23)) + .apply(ParDo.of(new CreateMessages())) + .apply(SqsIO.writeBatches(SET_MESSAGE_BODY).to("queue")); + + p.run().waitUntilFinish(); + + verify(sqs).sendMessageBatch(request("queue", entries)); + verify(sqs).sendMessageBatch(request("queue", entries[2], entries[3])); + verify(sqs).sendMessageBatch(request("queue", entries[3])); + verify(sqs).sendMessageBatch(request("queue", range(10, 20))); + verify(sqs).sendMessageBatch(request("queue", range(20, 23))); + + verify(sqs).close(); + verifyNoMoreInteractions(sqs); + + logs.verifyInfo("retry after partial failure: code REASON for 2 record(s)"); + logs.verifyInfo("retry after partial failure: code REASON for 1 record(s)"); + } + + @Test + public void testWriteCustomBatches() { + when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS)); + + p.apply(Create.of(8)) + .apply(ParDo.of(new CreateMessages())) + .apply(SqsIO.writeBatches(SET_MESSAGE_BODY).withBatchSize(3).to("queue")); + + p.run().waitUntilFinish(); + + verify(sqs).sendMessageBatch(request("queue", range(0, 3))); + verify(sqs).sendMessageBatch(request("queue", range(3, 6))); + verify(sqs).sendMessageBatch(request("queue", range(6, 8))); + + verify(sqs).close(); + verifyNoMoreInteractions(sqs); + } + + @Test + public void testWriteBatchesToDynamic() { + when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS)); + + // minimize delay due to retries + RetryConfiguration retry = RetryConfiguration.builder().maxBackoff(millis(1)).build(); + + p.apply(Create.of(10)) + .apply(ParDo.of(new CreateMessages())) + .apply( + SqsIO.writeBatches(SET_MESSAGE_BODY) + .withClientConfiguration(ClientConfiguration.builder().retry(retry).build()) + .withBatchSize(3) + .to(msg -> Integer.valueOf(msg) % 2 == 0 ? "even" : "uneven")); + + p.run().waitUntilFinish(); + + // id generator creates ids in range of [0, batch size * (queues + 1)) + SendMessageBatchRequestEntry[] entries = entries(range(0, 9), range(9, 10)); + + verify(sqs).sendMessageBatch(request("even", entries[0], entries[2], entries[4])); + verify(sqs).sendMessageBatch(request("uneven", entries[1], entries[3], entries[5])); + verify(sqs).sendMessageBatch(request("even", entries[6], entries[8])); + verify(sqs).sendMessageBatch(request("uneven", entries[7], entries[9])); + + verify(sqs).close(); + verifyNoMoreInteractions(sqs); + } + + @Test + public void testWriteBatchesWithTimeout() { + when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS)); + + p.apply(Create.of(5)) + .apply(ParDo.of(new CreateMessages())) + .apply( + // simulate delay between messages > batch timeout + SqsIO.writeBatches(withDelay(millis(200), SET_MESSAGE_BODY)) + .withBatchTimeout(millis(100)) + .to("queue")); + + p.run().waitUntilFinish(); + + SendMessageBatchRequestEntry[] entries = entries(range(0, 5)); + // due to added delay, batches are timed out on arrival of every 2nd msg + verify(sqs).sendMessageBatch(request("queue", entries[0], entries[1])); + verify(sqs).sendMessageBatch(request("queue", entries[2], entries[3])); + verify(sqs).sendMessageBatch(request("queue", entries[4])); + } + + private SendMessageBatchRequest anyRequest() { + return any(); + } + + private SendMessageBatchRequest request(String queue, SendMessageBatchRequestEntry... entries) { + return SendMessageBatchRequest.builder() + .queueUrl(queue) + .entries(Arrays.asList(entries)) + .build(); + } + + private SendMessageBatchRequest request(String queue, IntStream msgs) { + return request(queue, entries(msgs)); + } + + private SendMessageBatchRequestEntry[] entries(IntStream... msgStreams) { + return Arrays.stream(msgStreams) + .flatMap(msgs -> Streams.mapWithIndex(msgs, this::entry)) + .toArray(SendMessageBatchRequestEntry[]::new); + } + + private SendMessageBatchRequestEntry entry(int msg, long id) { + return SendMessageBatchRequestEntry.builder() + .id(Long.toString(id)) + .messageBody(Integer.toString(msg)) + .build(); + } + + private SendMessageBatchResponse partialSuccessResponse(String... failedIds) { + Stream errors = + Arrays.stream(failedIds) + .map(BatchResultErrorEntry.builder()::id) + .map(b -> b.code("REASON").build()); + return SendMessageBatchResponse.builder().failed(errors.collect(toList())).build(); + } + + private static class CreateMessages extends DoFn { + @ProcessElement + public void processElement(@Element Integer count, OutputReceiver out) { + for (int i = 0; i < count; i++) { + out.output(Integer.toString(i)); + } + } + } + + private static EntryBuilder withDelay(Duration delay, EntryBuilder builder) { + return (t1, t2) -> { + builder.accept(t1, t2); + try { + Thread.sleep(delay.getMillis()); + } catch (InterruptedException e) { + } + }; + } +} diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/testing/SqsIOIT.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/testing/SqsIOIT.java index f1e176f572d6..2f10f3d08f1d 100644 --- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/testing/SqsIOIT.java +++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/testing/SqsIOIT.java @@ -19,6 +19,7 @@ import static org.apache.beam.sdk.io.common.TestRow.getExpectedHashForRowCount; import static org.apache.beam.sdk.values.TypeDescriptors.strings; +import static org.apache.commons.lang3.RandomStringUtils.randomAlphanumeric; import static org.testcontainers.containers.localstack.LocalStackContainer.Service.SQS; import java.io.Serializable; @@ -103,6 +104,33 @@ public void testWriteThenRead() { pipelineRead.run(); } + @Test + public void testWriteBatchesThenRead() { + int rows = env.options().getNumberOfRows(); + + // Write test dataset to SQS. + pipelineWrite + .apply("Generate Sequence", GenerateSequence.from(0).to(rows)) + .apply("Prepare TestRows", ParDo.of(new DeterministicallyConstructTestRowFn())) + .apply( + "Write to SQS", + SqsIO.writeBatches((b, row) -> b.messageBody(row.name())).to(sqsQueue.url)); + + // Read test dataset from SQS. + PCollection output = + pipelineRead + .apply("Read from SQS", SqsIO.read().withQueueUrl(sqsQueue.url).withMaxNumRecords(rows)) + .apply("Extract body", MapElements.into(strings()).via(SqsMessage::getBody)); + + PAssert.thatSingleton(output.apply("Count All", Count.globally())).isEqualTo((long) rows); + + PAssert.that(output.apply(Combine.globally(new HashingFn()).withoutDefaults())) + .containsInAnyOrder(getExpectedHashForRowCount(rows)); + + pipelineWrite.run(); + pipelineRead.run(); + } + private static class SqsQueue extends ExternalResource implements Serializable { private transient SqsClient client = env.buildClient(SqsClient.builder()); private String url; @@ -113,7 +141,8 @@ SendMessageRequest messageRequest(TestRow r) { @Override protected void before() throws Throwable { - url = client.createQueue(b -> b.queueName("beam-sqsio-it")).queueUrl(); + url = + client.createQueue(b -> b.queueName("beam-sqsio-it-" + randomAlphanumeric(4))).queueUrl(); } @Override