diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/AsyncBatchWriteHandler.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/AsyncBatchWriteHandler.java
new file mode 100644
index 000000000000..892961fa7d86
--- /dev/null
+++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/AsyncBatchWriteHandler.java
@@ -0,0 +1,403 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.aws2.common;
+
+import static java.util.function.Function.identity;
+import static java.util.stream.Collectors.counting;
+import static java.util.stream.Collectors.groupingBy;
+import static java.util.stream.Collectors.joining;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicates.notNull;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.BiConsumer;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.NotThreadSafe;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.util.BackOff;
+import org.apache.beam.sdk.util.BackOffUtils;
+import org.apache.beam.sdk.util.FluentBackoff;
+import org.apache.beam.sdk.util.Sleeper;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
+import org.joda.time.DateTimeUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Async handler that automatically retries unprocessed records in case of a partial success.
+ *
+ * <p>The handler enforces the provided upper limit of concurrent requests. Once that limit is
+ * reached any further call to {@link #batchWrite(String, List)} will block until another request
+ * completed.
+ *
+ * <p>The handler is fail fast and won't submit any further request after a failure. Async failures
+ * can be polled using {@link #checkForAsyncFailure()}.
+ *
+ * @param <RecT> Record type in batch
+ * @param <ResT> Potentially erroneous result that needs to be correlated to a record using {@link
+ * #failedRecords(List, List)}
+ */
+@NotThreadSafe
+@Internal
+public abstract class AsyncBatchWriteHandler<RecT, ResT> {
+ private static final Logger LOG = LoggerFactory.getLogger(AsyncBatchWriteHandler.class);
+ private final FluentBackoff backoff;
+ private final int concurrentRequests;
+ private final Stats stats;
+ protected final BiFunction<String, List<RecT>, CompletableFuture<List<ResT>>> submitFn;
+ protected final Function<ResT, String> errorCodeFn;
+
+ private AtomicBoolean hasErrored;
+ private AtomicReference<Throwable> asyncFailure;
+ private Semaphore requestPermits;
+
+ protected AsyncBatchWriteHandler(
+ int concurrency,
+ FluentBackoff backoff,
+ Stats stats,
+ Function<ResT, String> errorCodeFn,
+ BiFunction<String, List<RecT>, CompletableFuture<List<ResT>>> submitFn) {
+ this.backoff = backoff;
+ this.concurrentRequests = concurrency;
+ this.errorCodeFn = errorCodeFn;
+ this.submitFn = submitFn;
+ this.hasErrored = new AtomicBoolean(false);
+ this.asyncFailure = new AtomicReference<>();
+ this.requestPermits = new Semaphore(concurrentRequests);
+ this.stats = stats;
+ }
+
+ public final int requestsInProgress() {
+ return concurrentRequests - requestPermits.availablePermits();
+ }
+
+ public final void reset() {
+ hasErrored = new AtomicBoolean(false);
+ asyncFailure = new AtomicReference<>();
+ requestPermits = new Semaphore(concurrentRequests);
+ }
+
+ /** If this handler has errored since it was last reset. */
+ public final boolean hasErrored() {
+ return hasErrored.get();
+ }
+
+ /**
+ * Check if any failure happened async.
+ *
+ * @throws Throwable The last async failure; the failure state is reset afterwards.
+ */
+ public final void checkForAsyncFailure() throws Throwable {
+ @SuppressWarnings("nullness")
+ Throwable failure = asyncFailure.getAndSet(null);
+ if (failure != null) {
+ throw failure;
+ }
+ }
+
+ /**
+ * Wait for all pending requests to complete and check for failures.
+ *
+ * @throws Throwable The last async failure if present using {@link #checkForAsyncFailure()}
+ */
+ public final void waitForCompletion() throws Throwable {
+ requestPermits.acquireUninterruptibly(concurrentRequests);
+ checkForAsyncFailure();
+ }
+
+ /**
+ * Asynchronously trigger a batch write request (unless already in error state).
+ *
+ * <p>This will respect the concurrency limit of the handler and first wait for a permit.
+ *
+ * @throws Throwable The last async failure if present using {@link #checkForAsyncFailure()}
+ */
+ public final void batchWrite(String destination, List<RecT> records) throws Throwable {
+ batchWrite(destination, records, true);
+ }
+
+ /**
+ * Asynchronously trigger a batch write request (unless already in error state).
+ *
+ * <p>This will respect the concurrency limit of the handler and first wait for a permit.
+ *
+ * @param throwAsyncFailures Whether to check for and throw pending async failures
+ * @throws Throwable The last async failure if present using {@link #checkForAsyncFailure()}
+ */
+ public final void batchWrite(String destination, List<RecT> records, boolean throwAsyncFailures)
+ throws Throwable {
+ if (!hasErrored()) {
+ requestPermits.acquireUninterruptibly();
+ new RetryHandler(destination, records).run();
+ }
+ if (throwAsyncFailures) {
+ checkForAsyncFailure();
+ }
+ }
+
+ protected abstract List<RecT> failedRecords(List<RecT> records, List<ResT> results);
+
+ protected abstract boolean hasFailedRecords(List<ResT> results);
+
+ /** Statistics on the batch request. */
+ public interface Stats {
+ Stats NONE = new Stats() {};
+
+ default void addBatchWriteRequest(long latencyMillis, boolean isPartialRetry) {}
+ }
+
+ /**
+ * AsyncBatchWriteHandler that correlates records and results by position in the respective list.
+ */
+ public static <RecT, ResT> AsyncBatchWriteHandler<RecT, ResT> byPosition(
+ int concurrency,
+ int partialRetries,
+ @Nullable RetryConfiguration retry,
+ Stats stats,
+ BiFunction<String, List<RecT>, CompletableFuture<List<ResT>>> submitFn,
+ Function<ResT, String> errorCodeFn) {
+ FluentBackoff backoff = retryBackoff(partialRetries, retry);
+ return byPosition(concurrency, backoff, stats, submitFn, errorCodeFn);
+ }
+
+ /**
+ * AsyncBatchWriteHandler that correlates records and results by position in the respective list.
+ */
+ public static <RecT, ResT> AsyncBatchWriteHandler<RecT, ResT> byPosition(
+ int concurrency,
+ FluentBackoff backoff,
+ Stats stats,
+ BiFunction<String, List<RecT>, CompletableFuture<List<ResT>>> submitFn,
+ Function<ResT, String> errorCodeFn) {
+ return new AsyncBatchWriteHandler<RecT, ResT>(
+ concurrency, backoff, stats, errorCodeFn, submitFn) {
+
+ @Override
+ protected boolean hasFailedRecords(List<ResT> results) {
+ for (int i = 0; i < results.size(); i++) {
+ if (errorCodeFn.apply(results.get(i)) != null) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ @Override
+ protected List<RecT> failedRecords(List<RecT> records, List<ResT> results) {
+ int size = Math.min(records.size(), results.size());
+ List<RecT> filtered = new ArrayList<>();
+ for (int i = 0; i < size; i++) {
+ if (errorCodeFn.apply(results.get(i)) != null) {
+ filtered.add(records.get(i));
+ }
+ }
+ return filtered;
+ }
+ };
+ }
+
+ /**
+ * AsyncBatchWriteHandler that correlates records and results by id; results are returned for
+ * erroneous records only.
+ */
+ public static <RecT, ErrT> AsyncBatchWriteHandler<RecT, ErrT> byId(
+ int concurrency,
+ int partialRetries,
+ @Nullable RetryConfiguration retry,
+ Stats stats,
+ BiFunction<String, List<RecT>, CompletableFuture<List<ErrT>>> submitFn,
+ Function<ErrT, String> errorCodeFn,
+ Function<RecT, String> recordIdFn,
+ Function<ErrT, String> errorIdFn) {
+ FluentBackoff backoff = retryBackoff(partialRetries, retry);
+ return byId(concurrency, backoff, stats, submitFn, errorCodeFn, recordIdFn, errorIdFn);
+ }
+
+ /**
+ * AsyncBatchWriteHandler that correlates records and results by id; results are returned for
+ * erroneous records only.
+ */
+ public static <RecT, ErrT> AsyncBatchWriteHandler<RecT, ErrT> byId(
+ int concurrency,
+ FluentBackoff backoff,
+ Stats stats,
+ BiFunction<String, List<RecT>, CompletableFuture<List<ErrT>>> submitFn,
+ Function<ErrT, String> errorCodeFn,
+ Function<RecT, String> recordIdFn,
+ Function<ErrT, String> errorIdFn) {
+ return new AsyncBatchWriteHandler<RecT, ErrT>(
+ concurrency, backoff, stats, errorCodeFn, submitFn) {
+ @Override
+ protected boolean hasFailedRecords(List<ErrT> errors) {
+ return !errors.isEmpty();
+ }
+
+ @Override
+ protected List<RecT> failedRecords(List<RecT> records, List<ErrT> errors) {
+ Set<String> ids = Sets.newHashSetWithExpectedSize(errors.size());
+ errors.forEach(e -> ids.add(errorIdFn.apply(e)));
+
+ List<RecT> filtered = new ArrayList<>(errors.size());
+ for (int i = 0; i < records.size(); i++) {
+ RecT rec = records.get(i);
+ if (ids.contains(recordIdFn.apply(rec))) {
+ filtered.add(rec);
+ if (filtered.size() == errors.size()) {
+ return filtered;
+ }
+ }
+ }
+ return filtered;
+ }
+ };
+ }
+
+ /**
+ * This handler coordinates retries in case of a partial success.
+ *
+ * <ul>
+ * <li>Release permit if all (remaining) records are successful to allow for a new request to
+ * start.
+ * <li>Attempt retry in case of partial success for all erroneous records using backoff. Set
+ * async failure once retries are exceeded.
+ * <li>Set async failure if the entire request fails. Retries, if configured & applicable, have
+ * already been attempted by the AWS SDK in that case.
+ * </ul>
+ *
+ * <p>The next call of {@link #checkForAsyncFailure()}, {@link #batchWrite(String, List)}} or
+ * {@link #waitForCompletion()} will check for the last async failure and throw it. Afterwards the
+ * failure state is reset.
+ */
+ private class RetryHandler implements BiConsumer<List<ResT>, Throwable> {
+ private final String destination;
+ private final int totalRecords;
+ private final BackOff backoff; // backoff in case of throttling
+
+ private final long handlerStartTime;
+ private long requestStartTime;
+ private int requests;
+
+ private List<RecT> records;
+
+ RetryHandler(String destination, List<RecT> records) {
+ this.destination = destination;
+ this.totalRecords = records.size();
+ this.records = records;
+ this.backoff = AsyncBatchWriteHandler.this.backoff.backoff();
+ this.handlerStartTime = DateTimeUtils.currentTimeMillis();
+ this.requestStartTime = 0;
+ this.requests = 0;
+ }
+
+ @SuppressWarnings({"FutureReturnValueIgnored"})
+ void run() {
+ if (!hasErrored.get()) {
+ try {
+ requests++;
+ requestStartTime = DateTimeUtils.currentTimeMillis();
+ submitFn.apply(destination, records).whenComplete(this);
+ } catch (Throwable e) {
+ setAsyncFailure(e);
+ }
+ }
+ }
+
+ @Override
+ public void accept(List<ResT> results, Throwable throwable) {
+ try {
+ long now = DateTimeUtils.currentTimeMillis();
+ long latencyMillis = now - requestStartTime;
+ synchronized (stats) {
+ stats.addBatchWriteRequest(latencyMillis, requests > 1);
+ }
+ if (results != null && !hasErrored.get()) {
+ if (!hasFailedRecords(results)) {
+ // Request succeeded, release one permit
+ requestPermits.release();
+ LOG.debug(
+ "Done writing {} records [{} ms, {} request(s)]",
+ totalRecords,
+ now - handlerStartTime,
+ requests);
+ } else {
+ try {
+ if (BackOffUtils.next(Sleeper.DEFAULT, backoff)) {
+ LOG.info(summarizeErrors("Attempting retry", results));
+ records = failedRecords(records, results);
+ run();
+ } else {
+ throwable = new IOException(summarizeErrors("Exceeded retries", results));
+ }
+ } catch (Throwable e) {
+ throwable = new IOException(summarizeErrors("Aborted retries", results), e);
+ }
+ }
+ }
+ } catch (Throwable e) {
+ throwable = e;
+ }
+ if (throwable != null) {
+ setAsyncFailure(throwable);
+ }
+ }
+
+ private void setAsyncFailure(Throwable throwable) {
+ LOG.warn("Error when writing batch.", throwable);
+ hasErrored.set(true);
+ asyncFailure.updateAndGet(
+ ex -> {
+ if (ex != null) {
+ throwable.addSuppressed(ex);
+ }
+ return throwable;
+ });
+ requestPermits.release(concurrentRequests); // unblock everything to fail fast
+ }
+
+ private String summarizeErrors(String prefix, List<ResT> results) {
+ Map<String, Long> countsPerError =
+ results.stream()
+ .map(errorCodeFn)
+ .filter(notNull())
+ .collect(groupingBy(identity(), counting()));
+ return countsPerError.entrySet().stream()
+ .map(kv -> String.format("code %s for %d record(s)", kv.getKey(), kv.getValue()))
+ .collect(joining(", ", prefix + " after partial failure: ", "."));
+ }
+ }
+
+ private static FluentBackoff retryBackoff(int retries, @Nullable RetryConfiguration retry) {
+ FluentBackoff backoff = FluentBackoff.DEFAULT.withMaxRetries(retries);
+ if (retry != null) {
+ if (retry.throttledBaseBackoff() != null) {
+ backoff = backoff.withInitialBackoff(retry.throttledBaseBackoff());
+ }
+ if (retry.maxBackoff() != null) {
+ backoff = backoff.withMaxBackoff(retry.maxBackoff());
+ }
+ }
+ return backoff;
+ }
+}
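A minimal usage sketch of the by-position variant, not part of this change: the helper `submitBatch(destination, records)` is assumed to stand in for an async AWS client call returning `CompletableFuture<List<Boolean>>`, and the error-code function treats any non-null string as a failed record, mirroring the factory contract above.

```java
// Sketch only: wiring and per-bundle lifecycle of the new handler.
AsyncBatchWriteHandler<String, Boolean> handler =
    AsyncBatchWriteHandler.byPosition(
        5,    // upper limit of concurrent requests
        10,   // partial retries before failing
        null, // @Nullable RetryConfiguration; null falls back to FluentBackoff defaults
        AsyncBatchWriteHandler.Stats.NONE,
        (destination, records) -> submitBatch(destination, records), // assumed async call
        success -> success ? null : "InternalFailure"); // non-null code marks a failure

handler.batchWrite("my-destination", Arrays.asList("a", "b", "c")); // blocks when out of permits
handler.waitForCompletion(); // acquires all permits, then rethrows the last async failure
```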
diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientConfiguration.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientConfiguration.java
index 9ee8eb277ddc..08fb595bd037 100644
--- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientConfiguration.java
+++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientConfiguration.java
@@ -52,6 +52,7 @@
@JsonInclude(value = JsonInclude.Include.NON_EMPTY)
@JsonDeserialize(builder = ClientConfiguration.Builder.class)
public abstract class ClientConfiguration implements Serializable {
+ public static final ClientConfiguration EMPTY = ClientConfiguration.builder().build();
/**
* Optional {@link AwsCredentialsProvider}. If set, this overwrites the default in {@link
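The new `EMPTY` constant is a shared instance of the existing empty default, so call sites below (e.g. in `SqsIO`) can swap the builder chain for the constant with identical behavior:

```java
// Equivalent defaults; EMPTY just avoids rebuilding an empty configuration per call site.
ClientConfiguration viaBuilder = ClientConfiguration.builder().build();
ClientConfiguration viaConstant = ClientConfiguration.EMPTY;
```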
diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/AsyncPutRecordsHandler.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/AsyncPutRecordsHandler.java
deleted file mode 100644
index 78439cb1d349..000000000000
--- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/AsyncPutRecordsHandler.java
+++ /dev/null
@@ -1,271 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.io.aws2.kinesis;
-
-import static java.util.function.Function.identity;
-import static java.util.stream.Collectors.counting;
-import static java.util.stream.Collectors.groupingBy;
-import static java.util.stream.Collectors.joining;
-import static java.util.stream.Collectors.toList;
-
-import java.io.IOException;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.Semaphore;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.BiConsumer;
-import java.util.function.Supplier;
-import javax.annotation.concurrent.NotThreadSafe;
-import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair;
-import org.apache.beam.sdk.annotations.Internal;
-import org.apache.beam.sdk.util.BackOff;
-import org.apache.beam.sdk.util.BackOffUtils;
-import org.apache.beam.sdk.util.FluentBackoff;
-import org.apache.beam.sdk.util.Sleeper;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams;
-import org.joda.time.DateTimeUtils;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
-import software.amazon.awssdk.services.kinesis.model.PutRecordsRequest;
-import software.amazon.awssdk.services.kinesis.model.PutRecordsRequestEntry;
-import software.amazon.awssdk.services.kinesis.model.PutRecordsResponse;
-
-/**
- * Async handler for {@link KinesisAsyncClient#putRecords(PutRecordsRequest)} that automatically
- * retries unprocessed records in case of a partial success.
- *
- * <p>The handler enforces the provided upper limit of concurrent requests. Once that limit is
- * reached any further call to {@link #putRecords(String, List)} will block until another request
- * completed.
- *
- * <p>The handler is fail fast and won't submit any further request after a failure. Async failures
- * can be polled using {@link #checkForAsyncFailure()}.
- */
-@NotThreadSafe
-@Internal
-class AsyncPutRecordsHandler {
- private static final Logger LOG = LoggerFactory.getLogger(AsyncPutRecordsHandler.class);
- private final KinesisAsyncClient kinesis;
- private final Supplier<BackOff> backoff;
- private final int concurrentRequests;
- private final Stats stats;
-
- private AtomicBoolean hasErrored;
- private AtomicReference<Throwable> asyncFailure;
- private Semaphore pendingRequests;
-
- AsyncPutRecordsHandler(
- KinesisAsyncClient kinesis, int concurrency, Supplier<BackOff> backoff, Stats stats) {
- this.kinesis = kinesis;
- this.backoff = backoff;
- this.concurrentRequests = concurrency;
- this.hasErrored = new AtomicBoolean(false);
- this.asyncFailure = new AtomicReference<>();
- this.pendingRequests = new Semaphore(concurrentRequests);
- this.stats = stats;
- }
-
- AsyncPutRecordsHandler(
- KinesisAsyncClient kinesis, int concurrency, FluentBackoff backoff, Stats stats) {
- this(kinesis, concurrency, () -> backoff.backoff(), stats);
- }
-
- protected int pendingRequests() {
- return concurrentRequests - pendingRequests.availablePermits();
- }
-
- void reset() {
- hasErrored = new AtomicBoolean(false);
- asyncFailure = new AtomicReference<>();
- pendingRequests = new Semaphore(concurrentRequests);
- }
-
- /** If this handler has errored since it was last reset. */
- boolean hasErrored() {
- return hasErrored.get();
- }
-
- /**
- * Check if any failure happened async.
- *
- * @throws Throwable The last async failure, afterwards reset it.
- */
- void checkForAsyncFailure() throws Throwable {
- @SuppressWarnings("nullness")
- Throwable failure = asyncFailure.getAndSet(null);
- if (failure != null) {
- throw failure;
- }
- }
-
- /**
- * Wait for all pending requests to complete and check for failures.
- *
- * @throws Throwable The last async failure if present using {@link #checkForAsyncFailure()}
- */
- void waitForCompletion() throws Throwable {
- pendingRequests.acquireUninterruptibly(concurrentRequests);
- checkForAsyncFailure();
- }
-
- /**
- * Asynchronously trigger a put records request.
- *
- * <p>This will respect the concurrency limit of the handler and first wait for a permit.
- *
- * @throws Throwable The last async failure if present using {@link #checkForAsyncFailure()}
- */
- void putRecords(String stream, List<PutRecordsRequestEntry> records) throws Throwable {
- pendingRequests.acquireUninterruptibly();
- new RetryHandler(stream, records).run();
- checkForAsyncFailure();
- }
-
- interface Stats {
- void addPutRecordsRequest(long latencyMillis, boolean isPartialRetry);
- }
-
- /**
- * This handler coordinates retries in case of a partial success.
- *
- * <ul>
- * <li>Release permit if all (remaining) records are successful to allow for a new request to
- * start.
- * <li>Attempt retry in case of partial success for all erroneous records using backoff. Set
- * async failure once retries are exceeded.
- * <li>Set async failure if the entire request fails. Retries, if configured & applicable, have
- * already been attempted by the AWS SDK in that case.
- * </ul>
- *
- * <p>The next call of {@link #checkForAsyncFailure()}, {@link #putRecords(String, List)} or {@link
- * #waitForCompletion()} will check for the last async failure and throw it. Afterwards the
- * failure state is reset.
- */
- private class RetryHandler implements BiConsumer<PutRecordsResponse, Throwable> {
- private final int totalRecords;
- private final String stream;
- private final BackOff backoff; // backoff in case of throttling
-
- private final long handlerStartTime;
- private long requestStartTime;
- private int requests;
-
- private List<PutRecordsRequestEntry> records;
-
- RetryHandler(String stream, List<PutRecordsRequestEntry> records) {
- this.stream = stream;
- this.totalRecords = records.size();
- this.records = records;
- this.backoff = AsyncPutRecordsHandler.this.backoff.get();
- this.handlerStartTime = DateTimeUtils.currentTimeMillis();
- this.requestStartTime = 0;
- this.requests = 0;
- }
-
- @SuppressWarnings({"FutureReturnValueIgnored"})
- void run() {
- if (!hasErrored.get()) {
- try {
- requests++;
- requestStartTime = DateTimeUtils.currentTimeMillis();
- PutRecordsRequest request =
- PutRecordsRequest.builder().streamName(stream).records(records).build();
- kinesis.putRecords(request).whenComplete(this);
- } catch (Throwable e) {
- setAsyncFailure(e);
- }
- }
- }
-
- @Override
- public void accept(PutRecordsResponse response, Throwable throwable) {
- try {
- long now = DateTimeUtils.currentTimeMillis();
- long latencyMillis = now - requestStartTime;
- synchronized (stats) {
- stats.addPutRecordsRequest(latencyMillis, requests > 1);
- }
- if (response != null && !hasErrored.get()) {
- if (!hasErrors(response)) {
- // Request succeeded, release one permit
- pendingRequests.release();
- LOG.debug(
- "Done writing {} records [{} ms, {} request(s)]",
- totalRecords,
- now - handlerStartTime,
- requests);
- } else {
- try {
- if (BackOffUtils.next(Sleeper.DEFAULT, backoff)) {
- LOG.info(summarizeErrors("Attempting retry", response));
- records = failedRecords(response);
- run();
- } else {
- throwable = new IOException(summarizeErrors("Exceeded retries", response));
- }
- } catch (Throwable e) {
- throwable = new IOException(summarizeErrors("Aborted retries", response), e);
- }
- }
- }
- } catch (Throwable e) {
- throwable = e;
- }
- if (throwable != null) {
- setAsyncFailure(throwable);
- }
- }
-
- private void setAsyncFailure(Throwable throwable) {
- LOG.warn("Error when writing to Kinesis.", throwable);
- hasErrored.set(true);
- asyncFailure.updateAndGet(
- ex -> {
- if (ex != null) {
- throwable.addSuppressed(ex);
- }
- return throwable;
- });
- pendingRequests.release(concurrentRequests); // unblock everything to fail fast
- }
-
- private boolean hasErrors(PutRecordsResponse response) {
- return response.records().stream().anyMatch(e -> e.errorCode() != null);
- }
-
- private List<PutRecordsRequestEntry> failedRecords(PutRecordsResponse response) {
- return Streams.zip(records.stream(), response.records().stream(), Pair::of)
- .filter(p -> p.getRight().errorCode() != null)
- .map(p -> p.getLeft())
- .collect(toList());
- }
-
- private String summarizeErrors(String prefix, PutRecordsResponse response) {
- Map<String, Long> countPerError =
- response.records().stream()
- .filter(e -> e.errorCode() != null)
- .map(e -> e.errorCode())
- .collect(groupingBy(identity(), counting()));
- return countPerError.entrySet().stream()
- .map(kv -> String.format("%s for %d record(s)", kv.getKey(), kv.getValue()))
- .collect(joining(", ", prefix + " after failure when writing to Kinesis: ", "."));
- }
- }
-}
diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java
index cf6796198bb4..6338de1ef309 100644
--- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java
+++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java
@@ -35,6 +35,7 @@
import java.util.Map;
import java.util.NavigableSet;
import java.util.TreeSet;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Supplier;
@@ -42,11 +43,11 @@
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.Read.Unbounded;
+import org.apache.beam.sdk.io.aws2.common.AsyncBatchWriteHandler;
import org.apache.beam.sdk.io.aws2.common.ClientBuilderFactory;
import org.apache.beam.sdk.io.aws2.common.ClientConfiguration;
import org.apache.beam.sdk.io.aws2.common.ObjectPool;
import org.apache.beam.sdk.io.aws2.common.ObjectPool.ClientPool;
-import org.apache.beam.sdk.io.aws2.common.RetryConfiguration;
import org.apache.beam.sdk.io.aws2.kinesis.KinesisPartitioner.ExplicitPartitioner;
import org.apache.beam.sdk.io.aws2.options.AwsOptions;
import org.apache.beam.sdk.metrics.Counter;
@@ -61,7 +62,6 @@
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.Sum;
-import org.apache.beam.sdk.util.FluentBackoff;
import org.apache.beam.sdk.util.MovingFunction;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
@@ -84,7 +84,9 @@
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.awssdk.services.kinesis.model.ListShardsRequest;
+import software.amazon.awssdk.services.kinesis.model.PutRecordsRequest;
import software.amazon.awssdk.services.kinesis.model.PutRecordsRequestEntry;
+import software.amazon.awssdk.services.kinesis.model.PutRecordsResultEntry;
import software.amazon.awssdk.services.kinesis.model.Shard;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler;
@@ -559,7 +561,7 @@ public Read withCustomRateLimitPolicy(RateLimitPolicyFactory rateLimitPolicyFact
* corresponds to the number of in-flight shard events which itself can contain multiple,
* potentially even aggregated records.
*
- * @see {@link #withConsumerArn(String)}
+ * @see #withConsumerArn(String)
*/
public Read withMaxCapacityPerShard(Integer maxCapacity) {
checkArgument(maxCapacity > 0, "maxCapacity must be positive, but was: %s", maxCapacity);
@@ -901,30 +903,32 @@ private static class Writer<T> implements AutoCloseable {
protected final Write<T> spec;
protected final Stats stats;
- protected final AsyncPutRecordsHandler handler;
+ protected final AsyncBatchWriteHandler<PutRecordsRequestEntry, PutRecordsResultEntry> handler;
protected final KinesisAsyncClient kinesis;
-
private List<PutRecordsRequestEntry> requestEntries;
private int requestBytes = 0;
Writer(PipelineOptions options, Write<T> spec) {
ClientConfiguration clientConfig = spec.clientConfiguration();
- RetryConfiguration retryConfig = clientConfig.retry();
- FluentBackoff backoff = FluentBackoff.DEFAULT.withMaxRetries(PARTIAL_RETRIES);
- if (retryConfig != null) {
- if (retryConfig.throttledBaseBackoff() != null) {
- backoff = backoff.withInitialBackoff(retryConfig.throttledBaseBackoff());
- }
- if (retryConfig.maxBackoff() != null) {
- backoff = backoff.withMaxBackoff(retryConfig.maxBackoff());
- }
- }
this.spec = spec;
this.stats = new Stats();
this.kinesis = CLIENTS.retain(options.as(AwsOptions.class), clientConfig);
- this.handler =
- new AsyncPutRecordsHandler(kinesis, spec.concurrentRequests(), backoff, stats);
this.requestEntries = new ArrayList<>();
+ this.handler =
+ AsyncBatchWriteHandler.byPosition(
+ spec.concurrentRequests(),
+ PARTIAL_RETRIES,
+ clientConfig.retry(),
+ stats,
+ (stream, records) -> putRecords(kinesis, stream, records),
+ r -> r.errorCode());
+ }
+
+ private static CompletableFuture<List<PutRecordsResultEntry>> putRecords(
+ KinesisAsyncClient kinesis, String stream, List<PutRecordsRequestEntry> records) {
+ PutRecordsRequest req =
+ PutRecordsRequest.builder().streamName(stream).records(records).build();
+ return kinesis.putRecords(req).thenApply(resp -> resp.records());
}
public void startBundle() {
@@ -998,7 +1002,7 @@ protected final void asyncFlushEntries() throws Throwable {
List<PutRecordsRequestEntry> recordsToWrite = requestEntries;
requestEntries = new ArrayList<>();
requestBytes = 0;
- handler.putRecords(spec.streamName(), recordsToWrite);
+ handler.batchWrite(spec.streamName(), recordsToWrite);
}
}
@@ -1115,7 +1119,7 @@ protected void write(String partitionKey, @Nullable String explicitHashKey, byte
}
// only check timeouts sporadically if concurrency is already maxed out
- if (handler.pendingRequests() < spec.concurrentRequests() || Math.random() < 0.05) {
+ if (handler.requestsInProgress() < spec.concurrentRequests() || Math.random() < 0.05) {
checkAggregationTimeouts();
}
}
@@ -1275,7 +1279,7 @@ private BigInteger lowerHashKey(Shard shard) {
}
}
- private static class Stats implements AsyncPutRecordsHandler.Stats {
+ private static class Stats implements AsyncBatchWriteHandler.Stats {
private static final Logger LOG = LoggerFactory.getLogger(Stats.class);
private static final Duration LOG_STATS_PERIOD = Duration.standardSeconds(10);
@@ -1328,7 +1332,7 @@ void addClientRecord(int recordBytes) {
}
@Override
- public void addPutRecordsRequest(long latencyMillis, boolean isPartialRetry) {
+ public void addBatchWriteRequest(long latencyMillis, boolean isPartialRetry) {
long timeMillis = DateTimeUtils.currentTimeMillis();
numPutRequests.add(timeMillis, 1);
if (isPartialRetry) {
diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java
index 72befc6003fe..ac31738154a6 100644
--- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java
+++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java
@@ -17,10 +17,26 @@
*/
package org.apache.beam.sdk.io.aws2.sqs;
+import static java.util.Collections.EMPTY_LIST;
+import static org.apache.beam.sdk.io.aws2.common.ClientBuilderFactory.buildClient;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
import com.google.auto.value.AutoValue;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.BiConsumer;
+import java.util.function.BiFunction;
import java.util.function.Consumer;
+import javax.annotation.concurrent.NotThreadSafe;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.io.aws2.common.AsyncBatchWriteHandler;
+import org.apache.beam.sdk.io.aws2.common.AsyncBatchWriteHandler.Stats;
import org.apache.beam.sdk.io.aws2.common.ClientBuilderFactory;
import org.apache.beam.sdk.io.aws2.common.ClientConfiguration;
import org.apache.beam.sdk.io.aws2.options.AwsOptions;
@@ -31,11 +47,24 @@
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.sdk.values.PInput;
+import org.apache.beam.sdk.values.POutput;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
+import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
+import org.checkerframework.dataflow.qual.Pure;
import org.joda.time.Duration;
+import org.joda.time.Instant;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.sqs.SqsAsyncClient;
import software.amazon.awssdk.services.sqs.SqsClient;
+import software.amazon.awssdk.services.sqs.model.BatchResultErrorEntry;
+import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
+import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry;
import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
/**
@@ -91,21 +120,28 @@
* then opt to retry the current partition in entirety or abort if the max number of retries of the
* runner is reached.
*/
-@SuppressWarnings({
- "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
public class SqsIO {
public static Read read() {
return new AutoValue_SqsIO_Read.Builder()
- .setClientConfiguration(ClientConfiguration.builder().build())
+ .setClientConfiguration(ClientConfiguration.EMPTY)
.setMaxNumRecords(Long.MAX_VALUE)
.build();
}
public static Write write() {
return new AutoValue_SqsIO_Write.Builder()
- .setClientConfiguration(ClientConfiguration.builder().build())
+ .setClientConfiguration(ClientConfiguration.EMPTY)
+ .build();
+ }
+
+ public static <T> WriteBatches<T> writeBatches(WriteBatches.EntryBuilder<T> entryBuilder) {
+ return new AutoValue_SqsIO_WriteBatches.Builder<T>()
+ .clientConfiguration(ClientConfiguration.EMPTY)
+ .concurrentRequests(WriteBatches.DEFAULT_CONCURRENCY)
+ .batchSize(WriteBatches.MAX_BATCH_SIZE)
+ .batchTimeout(WriteBatches.DEFAULT_BATCH_TIMEOUT)
+ .entryBuilder(entryBuilder)
.build();
}
@@ -124,7 +160,7 @@ public abstract static class Read extends PTransform<PBegin, PCollection<SqsMessage>> {
public PCollection<SqsMessage> expand(PBegin input) {
AwsOptions awsOptions = input.getPipeline().getOptions().as(AwsOptions.class);
ClientBuilderFactory.validate(awsOptions, clientConfiguration());
@@ -188,7 +225,6 @@ public PCollection<SqsMessage> expand(PBegin input) {
return input.getPipeline().apply(transform);
}
}
- // TODO: Add write batch api to improve performance
/**
* A {@link PTransform} to send messages to SQS. See {@link SqsIO} for more information on usage
* and configuration.
@@ -196,7 +232,7 @@ public PCollection<SqsMessage> expand(PBegin input) {
@AutoValue
public abstract static class Write extends PTransform<PCollection<SendMessageRequest>, PDone> {
- abstract ClientConfiguration getClientConfiguration();
+ abstract @Pure ClientConfiguration getClientConfiguration();
abstract Builder builder();
@@ -225,7 +261,7 @@ public PDone expand(PCollection<SendMessageRequest> input) {
private static class SqsWriteFn extends DoFn<SendMessageRequest, Void> {
private final Write spec;
- private transient SqsClient sqs;
+ private transient @MonotonicNonNull SqsClient sqs = null;
SqsWriteFn(Write write) {
this.spec = write;
@@ -241,7 +277,441 @@ public void setup(PipelineOptions options) throws Exception {
@ProcessElement
public void processElement(ProcessContext processContext) throws Exception {
+ if (sqs == null) {
+ throw new IllegalStateException("No SQS client");
+ }
sqs.sendMessage(processContext.element());
}
}
+
+ /**
+ * A {@link PTransform} to send messages to SQS. See {@link SqsIO} for more information on usage
+ * and configuration.
+ */
+ @AutoValue
+ public abstract static class WriteBatches<T>
+ extends PTransform<PCollection<T>, WriteBatches.Result> {
+ private static final int DEFAULT_CONCURRENCY = 5;
+ private static final int MAX_BATCH_SIZE = 10;
+ private static final Duration DEFAULT_BATCH_TIMEOUT = Duration.standardSeconds(3);
+
+ abstract @Pure int concurrentRequests();
+
+ abstract @Pure Duration batchTimeout();
+
+ abstract @Pure int batchSize();
+
+ abstract @Pure ClientConfiguration clientConfiguration();
+
+ abstract @Pure EntryBuilder<T> entryBuilder();
+
+ abstract @Pure @Nullable DynamicDestination<T> dynamicDestination();
+
+ abstract @Pure @Nullable String queueUrl();
+
+ abstract Builder<T> builder();
+
+ public interface DynamicDestination<T> extends Serializable {
+ String queueUrl(T message);
+ }
+
+ @AutoValue.Builder
+ abstract static class Builder<T> {
+ abstract Builder<T> concurrentRequests(int concurrentRequests);
+
+ abstract Builder<T> batchTimeout(Duration duration);
+
+ abstract Builder<T> batchSize(int batchSize);
+
+ abstract Builder<T> clientConfiguration(ClientConfiguration config);
+
+ abstract Builder<T> entryBuilder(EntryBuilder<T> entryBuilder);
+
+ abstract Builder<T> dynamicDestination(@Nullable DynamicDestination<T> destination);
+
+ abstract Builder<T> queueUrl(@Nullable String queueUrl);
+
+ abstract WriteBatches<T> build();
+ }
+
+ /** Configuration of SQS client. */
+ public WriteBatches<T> withClientConfiguration(ClientConfiguration config) {
+ checkArgument(config != null, "ClientConfiguration cannot be null");
+ return builder().clientConfiguration(config).build();
+ }
+
+ /** Max number of concurrent batch write requests per bundle, default is {@code 5}. */
+ public WriteBatches<T> withConcurrentRequests(int concurrentRequests) {
+ checkArgument(concurrentRequests > 0, "concurrentRequests must be > 0");
+ return builder().concurrentRequests(concurrentRequests).build();
+ }
+
+ /** The batch size to use, default (and AWS limit) is {@code 10}. */
+ public WriteBatches<T> withBatchSize(int batchSize) {
+ checkArgument(
+ batchSize > 0 && batchSize <= MAX_BATCH_SIZE,
+ "Maximum allowed batch size is " + MAX_BATCH_SIZE);
+ return builder().batchSize(batchSize).build();
+ }
+
+ /**
+ * The duration to accumulate records before timing out, default is 3 secs.
+ *
+ * <p>Timeouts will be checked upon arrival of new messages.
+ */
+ public WriteBatches<T> withBatchTimeout(Duration timeout) {
+ return builder().batchTimeout(timeout).build();
+ }
+
+ /** Dynamic record based destination to write to. */
+ public WriteBatches<T> to(DynamicDestination<T> destination) {
+ checkArgument(destination != null, "DynamicDestination cannot be null");
+ return builder().queueUrl(null).dynamicDestination(destination).build();
+ }
+
+ /** Queue url to write to. */
+ public WriteBatches<T> to(String queueUrl) {
+ checkArgument(queueUrl != null, "queueUrl cannot be null");
+ return builder().dynamicDestination(null).queueUrl(queueUrl).build();
+ }
+
+ @Override
+ public Result expand(PCollection<T> input) {
+ AwsOptions awsOptions = input.getPipeline().getOptions().as(AwsOptions.class);
+ ClientBuilderFactory.validate(awsOptions, clientConfiguration());
+
+ input.apply(
+ ParDo.of(
+ new DoFn<T, Void>() {
+ private @Nullable BatchHandler<T> handler = null;
+
+ @Setup
+ public void setup(PipelineOptions options) {
+ handler = new BatchHandler<>(WriteBatches.this, options.as(AwsOptions.class));
+ }
+
+ @StartBundle
+ public void startBundle() {
+ handler().startBundle();
+ }
+
+ @ProcessElement
+ public void processElement(ProcessContext cxt) throws Throwable {
+ handler().process(cxt.element());
+ }
+
+ @FinishBundle
+ public void finishBundle() throws Throwable {
+ handler().finishBundle();
+ }
+
+ @Teardown
+ public void teardown() throws Exception {
+ if (handler != null) {
+ handler.close();
+ handler = null;
+ }
+ }
+
+ private BatchHandler<T> handler() {
+ return checkStateNotNull(handler, "SQS handler is null");
+ }
+ }));
+ return new Result(input.getPipeline());
+ }
+
+ /** Batch entry builder. */
+ public interface EntryBuilder<T>
+ extends BiConsumer<SendMessageBatchRequestEntry.Builder, T>, Serializable {}
+
+ /** Result of {@link #writeBatches}. */
+ public static class Result implements POutput {
+ private final Pipeline pipeline;
+
+ private Result(Pipeline pipeline) {
+ this.pipeline = pipeline;
+ }
+
+ @Override
+ public Pipeline getPipeline() {
+ return pipeline;
+ }
+
+ @Override
+ public Map<TupleTag<?>, PValue> expand() {
+ return ImmutableMap.of();
+ }
+
+ @Override
+ public void finishSpecifyingOutput(
+ String transformName, PInput input, PTransform<?, ?> transform) {}
+ }
+
+ private static class BatchHandler<T> implements AutoCloseable {
+ private final WriteBatches<T> spec;
+ private final SqsAsyncClient sqs;
+ private final Batches batches;
+ private final AsyncBatchWriteHandler<SendMessageBatchRequestEntry, BatchResultErrorEntry> handler;
+
+ BatchHandler(WriteBatches<T> spec, AwsOptions options) {
+ this.spec = spec;
+ this.sqs = buildClient(options, SqsAsyncClient.builder(), spec.clientConfiguration());
+ this.handler =
+ AsyncBatchWriteHandler.byId(
+ spec.concurrentRequests(),
+ spec.batchSize(),
+ spec.clientConfiguration().retry(),
+ Stats.NONE,
+ (queue, records) -> sendMessageBatch(sqs, queue, records),
+ error -> error.code(),
+ record -> record.id(),
+ error -> error.id());
+ if (spec.queueUrl() != null) {
+ this.batches = new Single(spec.queueUrl());
+ } else if (spec.dynamicDestination() != null) {
+ this.batches = new Dynamic(spec.dynamicDestination());
+ } else {
+ throw new IllegalStateException("to(queueUrl) or to(dynamicDestination) is required");
+ }
+ }
+
+ private static CompletableFuture<List<BatchResultErrorEntry>> sendMessageBatch(
+ SqsAsyncClient sqs, String queue, List<SendMessageBatchRequestEntry> records) {
+ SendMessageBatchRequest request =
+ SendMessageBatchRequest.builder().queueUrl(queue).entries(records).build();
+ return sqs.sendMessageBatch(request).thenApply(resp -> resp.failed());
+ }
+
+ public void startBundle() {
+ handler.reset();
+ }
+
+ public void process(T msg) {
+ SendMessageBatchRequestEntry.Builder builder = SendMessageBatchRequestEntry.builder();
+ spec.entryBuilder().accept(builder, msg);
+ SendMessageBatchRequestEntry entry = builder.id(batches.nextId()).build();
+
+ Batch batch = batches.getLocked(msg);
+ batch.add(entry);
+ if (batch.size() >= spec.batchSize() || batch.isExpired()) {
+ writeEntries(batch, true);
+ } else {
+ checkState(batch.lock(false)); // unlock to continue writing to batch
+ }
+
+ // check timeouts synchronously on arrival of new messages
+ batches.writeExpired(true);
+ }
+
+ private void writeEntries(Batch batch, boolean throwPendingFailures) {
+ try {
+ handler.batchWrite(batch.queue, batch.getAndClear(), throwPendingFailures);
+ } catch (RuntimeException e) {
+ throw e;
+ } catch (Throwable e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public void finishBundle() throws Throwable {
+ batches.writeAll();
+ handler.waitForCompletion();
+ }
+
+ @Override
+ public void close() throws Exception {
+ sqs.close();
+ }
+
+ /**
+ * Batch(es) of a single fixed or several dynamic queues.
+ *
+ * <p>{@link #getLocked} is meant to support atomic writes from multiple threads if using an
+ * appropriate thread-safe implementation. This is necessary to later support strict timeouts
+ * (see below).
+ *
+ * <p>For simplicity, check for expired messages after appending to a batch. For strict
+ * enforcement of timeouts, {@link #writeExpired} would have to be periodically called using a
+ * scheduler and requires also a thread-safe impl of {@link Batch#lock(boolean)}.
+ */
+ private abstract class Batches {
+ private int nextId = 0; // only ever used from one "runner" thread
+
+ abstract int maxBatches();
+
+ /** Next batch entry id is guaranteed to be unique for all open batches. */
+ String nextId() {
+ if (nextId >= (spec.batchSize() * maxBatches())) {
+ nextId = 0;
+ }
+ return Integer.toString(nextId++);
+ }
+
+ /** Get existing or new locked batch that can be written to. */
+ abstract Batch getLocked(T record);
+
+ /** Write all remaining batches (that can be locked). */
+ abstract void writeAll();
+
+ /** Write all expired batches (that can be locked). */
+ abstract void writeExpired(boolean throwPendingFailures);
+
+ /** Create a new locked batch that is ready for writing. */
+ Batch createLocked(String queue) {
+ return new Batch(queue, spec.batchSize(), spec.batchTimeout());
+ }
+
+ /** Write a batch if it can be locked. */
+ protected boolean writeLocked(Batch batch, boolean throwPendingFailures) {
+ if (batch.lock(true)) {
+ writeEntries(batch, throwPendingFailures);
+ return true;
+ }
+ return false;
+ }
+ }
+
+ /** Batch of a single, fixed queue. */
+ @NotThreadSafe
+ private class Single extends Batches {
+ private Batch batch;
+
+ Single(String queue) {
+ this.batch = new Batch(queue, EMPTY_LIST, Batch.NEVER); // locked
+ }
+
+ @Override
+ int maxBatches() {
+ return 1;
+ }
+
+ @Override
+ Batch getLocked(T record) {
+ return batch.lock(true) ? batch : (batch = createLocked(batch.queue));
+ }
+
+ @Override
+ void writeAll() {
+ writeLocked(batch, true);
+ }
+
+ @Override
+ void writeExpired(boolean throwPendingFailures) {
+ if (batch.isExpired()) {
+ writeLocked(batch, throwPendingFailures);
+ }
+ }
+ }
+
+ /** Batches of one or several dynamic queues. */
+ @NotThreadSafe
+ private class Dynamic extends Batches {
+ @SuppressWarnings("method.invocation") // necessary dependencies are initialized
+ private final BiFunction<@NonNull String, @Nullable Batch, Batch> getLocked =
+ (queue, batch) -> batch != null && batch.lock(true) ? batch : createLocked(queue);
+
+ private final Map<@NonNull String, Batch> batches = new HashMap<>();
+ private final DynamicDestination<T> destination;
+ private Instant nextTimeout = Batch.NEVER;
+
+ Dynamic(DynamicDestination<T> destination) {
+ this.destination = destination;
+ }
+
+ @Override
+ int maxBatches() {
+ return batches.size() + 1; // next record and id might belong to new batch
+ }
+
+ @Override
+ Batch getLocked(T record) {
+ return batches.compute(destination.queueUrl(record), getLocked);
+ }
+
+ @Override
+ void writeAll() {
+ batches.values().forEach(batch -> writeLocked(batch, true));
+ batches.clear();
+ nextTimeout = Batch.NEVER;
+ }
+
+ private void writeExpired(Batch batch) {
+ if (!batch.isExpired() || !writeLocked(batch, true)) {
+ // find next timeout for remaining, unwritten batches
+ if (batch.timeout.isBefore(nextTimeout)) {
+ nextTimeout = batch.timeout;
+ }
+ }
+ }
+
+ @Override
+ void writeExpired(boolean throwPendingFailures) {
+ if (nextTimeout.isBeforeNow()) {
+ nextTimeout = Batch.NEVER;
+ batches.values().forEach(this::writeExpired);
+ }
+ }
+
+ @Override
+ Batch createLocked(String queue) {
+ Batch batch = super.createLocked(queue);
+ if (batch.timeout.isBefore(nextTimeout)) {
+ nextTimeout = batch.timeout;
+ }
+ return batch;
+ }
+ }
+ }
+
+ /**
+ * Batch of entries of a queue.
+ *
+ * <p>Overwrite {@link #lock} with a thread-safe implementation to support concurrent usage.
+ */
+ @NotThreadSafe
+ private static final class Batch {
+ private static final Instant NEVER = Instant.ofEpochMilli(Long.MAX_VALUE);
+ private final String queue;
+ private Instant timeout;
+ private List<SendMessageBatchRequestEntry> entries;
+
+ Batch(String queue, int size, Duration bufferedTime) {
+ this(queue, new ArrayList<>(size), Instant.now().plus(bufferedTime));
+ }
+
+ Batch(String queue, List<SendMessageBatchRequestEntry> entries, Instant timeout) {
+ this.queue = queue;
+ this.entries = entries;
+ this.timeout = timeout;
+ }
+
+ /** Attempt to un/lock this batch and return if successful. */
+ boolean lock(boolean lock) {
+ // thread unsafe dummy impl that rejects locking batches after getAndClear
+ return !NEVER.equals(timeout) || !lock;
+ }
+
+ /** Get and clear entries for writing. */
+ List<SendMessageBatchRequestEntry> getAndClear() {
+ List<SendMessageBatchRequestEntry> res = entries;
+ entries = EMPTY_LIST;
+ timeout = NEVER;
+ return res;
+ }
+
+ /** Add entry to this batch. */
+ void add(SendMessageBatchRequestEntry entry) {
+ entries.add(entry);
+ }
+
+ int size() {
+ return entries.size();
+ }
+
+ boolean isExpired() {
+ return timeout.isBeforeNow();
+ }
+ }
+ }
}
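A usage sketch for the new transform (queue URL and upstream collection are placeholders); the `EntryBuilder` maps each element onto a `SendMessageBatchRequestEntry.Builder`, while entry ids are assigned internally:

```java
PCollection<String> messages = ...; // upstream collection
messages.apply(
    SqsIO.<String>writeBatches((entry, msg) -> entry.messageBody(msg))
        .withBatchSize(10)                             // AWS maximum, also the default
        .withBatchTimeout(Duration.standardSeconds(3)) // flush incomplete batches
        .to("https://sqs.us-east-1.amazonaws.com/123456789012/my-queue"));
```

For per-record routing, `.to(msg -> queueUrlFor(msg))` (with a hypothetical `queueUrlFor`) selects the target queue via `DynamicDestination` instead of a fixed URL.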
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/AsyncBatchWriteHandlerTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/AsyncBatchWriteHandlerTest.java
new file mode 100644
index 000000000000..339fa673ed17
--- /dev/null
+++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/AsyncBatchWriteHandlerTest.java
@@ -0,0 +1,267 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.aws2.common;
+
+import static java.util.Collections.emptyList;
+import static java.util.concurrent.ForkJoinPool.commonPool;
+import static java.util.function.Function.identity;
+import static org.apache.beam.sdk.io.aws2.common.AsyncBatchWriteHandler.byId;
+import static org.apache.beam.sdk.io.aws2.common.AsyncBatchWriteHandler.byPosition;
+import static org.apache.beam.sdk.util.FluentBackoff.DEFAULT;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.joda.time.Duration.millis;
+import static org.mockito.ArgumentMatchers.anyList;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Future;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import org.apache.beam.sdk.io.aws2.common.AsyncBatchWriteHandler.Stats;
+import org.apache.beam.sdk.util.FluentBackoff;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.junit.Test;
+import org.junit.function.ThrowingRunnable;
+import org.mockito.Mockito;
+
+public class AsyncBatchWriteHandlerTest {
+ private static final int CONCURRENCY = 10;
+
+ private CompletableFuture<List<Boolean>> resultsByPos = new CompletableFuture<>();
+ private CompletableFuture<List<String>> errorsById = new CompletableFuture<>();
+
+ private AsyncBatchWriteHandler<Integer, Boolean> byPositionHandler(FluentBackoff backoff) {
+ SubmitFn<Integer, Boolean> submitFn = Mockito.spy(new SubmitFn<>(() -> resultsByPos));
+ Function<Boolean, String> errorFn = success -> success ? null : "REASON";
+ return byPosition(CONCURRENCY, backoff, Stats.NONE, submitFn, errorFn);
+ }
+
+ private AsyncBatchWriteHandler<String, String> byIdHandler(FluentBackoff backoff) {
+ SubmitFn<String, String> submitFn = Mockito.spy(new SubmitFn<>(() -> errorsById));
+ Function<String, String> errorFn = err -> "REASON";
+ return byId(CONCURRENCY, backoff, Stats.NONE, submitFn, errorFn, identity(), identity());
+ }
+
+ @Test
+ public void retryOnPartialSuccessByPosition() throws Throwable {
+ AsyncBatchWriteHandler<Integer, Boolean> handler =
+ byPositionHandler(DEFAULT.withMaxBackoff(millis(1)));
+ CompletableFuture<List<Boolean>> pendingResponse1 = new CompletableFuture<>();
+ CompletableFuture<List<Boolean>> pendingResponse2 = new CompletableFuture<>();
+ CompletableFuture<List<Boolean>> pendingResponse3 = new CompletableFuture<>();
+
+ resultsByPos = pendingResponse1;
+ handler.batchWrite("destination", ImmutableList.of(1, 2, 3, 4));
+
+ // 1st attempt
+ eventually(5, () -> verify(handler.submitFn, times(1)).apply(anyString(), anyList()));
+ verify(handler.submitFn).apply("destination", ImmutableList.of(1, 2, 3, 4));
+ assertThat(handler.requestsInProgress()).isEqualTo(1);
+
+ resultsByPos = pendingResponse2;
+ pendingResponse1.complete(ImmutableList.of(true, true, false, false));
+
+ // 2nd attempt
+ eventually(5, () -> verify(handler.submitFn, times(2)).apply(anyString(), anyList()));
+ verify(handler.submitFn).apply("destination", ImmutableList.of(3, 4));
+ assertThat(handler.requestsInProgress()).isEqualTo(1);
+
+ // 3rd attempt
+ resultsByPos = pendingResponse3;
+ pendingResponse2.complete(ImmutableList.of(true, false));
+
+ eventually(5, () -> verify(handler.submitFn, times(3)).apply(anyString(), anyList()));
+ verify(handler.submitFn).apply("destination", ImmutableList.of(4));
+
+ assertThat(handler.requestsInProgress()).isEqualTo(1);
+
+ // 4th attempt
+ pendingResponse3.complete(ImmutableList.of(true)); // success
+
+ eventually(5, () -> assertThat(handler.requestsInProgress()).isEqualTo(0));
+ verify(handler.submitFn, times(3)).apply(anyString(), anyList());
+ }
+
+ @Test
+ public void retryOnPartialSuccessById() throws Throwable {
+ AsyncBatchWriteHandler<String, String> handler = byIdHandler(DEFAULT.withMaxBackoff(millis(1)));
+ CompletableFuture<List<String>> pendingResponse1 = new CompletableFuture<>();
+ CompletableFuture<List<String>> pendingResponse2 = new CompletableFuture<>();
+ CompletableFuture<List<String>> pendingResponse3 = new CompletableFuture<>();
+
+ errorsById = pendingResponse1;
+ handler.batchWrite("destination", ImmutableList.of("1", "2", "3", "4"));
+
+ // 1st attempt
+ eventually(5, () -> verify(handler.submitFn, times(1)).apply(anyString(), anyList()));
+ verify(handler.submitFn).apply("destination", ImmutableList.of("1", "2", "3", "4"));
+ assertThat(handler.requestsInProgress()).isEqualTo(1);
+
+ errorsById = pendingResponse2;
+ pendingResponse1.complete(ImmutableList.of("3", "4"));
+
+ // 2nd attempt
+ eventually(5, () -> verify(handler.submitFn, times(2)).apply(anyString(), anyList()));
+ verify(handler.submitFn).apply("destination", ImmutableList.of("3", "4"));
+ assertThat(handler.requestsInProgress()).isEqualTo(1);
+
+ // 3rd attempt
+ errorsById = pendingResponse3;
+ pendingResponse2.complete(ImmutableList.of("4"));
+
+ eventually(5, () -> verify(handler.submitFn, times(3)).apply(anyString(), anyList()));
+ verify(handler.submitFn).apply("destination", ImmutableList.of("4"));
+
+ assertThat(handler.requestsInProgress()).isEqualTo(1);
+
+ // 4th attempt
+ pendingResponse3.complete(ImmutableList.of()); // success
+
+ eventually(5, () -> assertThat(handler.requestsInProgress()).isEqualTo(0));
+ verify(handler.submitFn, times(3)).apply(anyString(), anyList());
+ }
+
+ @Test
+ public void retryLimitOnPartialSuccessByPosition() throws Throwable {
+ AsyncBatchWriteHandler<Integer, Boolean> handler = byPositionHandler(DEFAULT.withMaxRetries(0));
+
+ handler.batchWrite("destination", ImmutableList.of(1, 2, 3, 4));
+
+ resultsByPos.complete(ImmutableList.of(true, true, false, false));
+
+ assertThatThrownBy(() -> handler.waitForCompletion())
+ .hasMessageContaining("Exceeded retries")
+ .hasMessageEndingWith("REASON for 2 record(s).")
+ .isInstanceOf(IOException.class);
+ verify(handler.submitFn).apply("destination", ImmutableList.of(1, 2, 3, 4));
+ }
+
+ @Test
+ public void retryLimitOnPartialSuccessById() throws Throwable {
+ AsyncBatchWriteHandler<String, String> handler = byIdHandler(DEFAULT.withMaxRetries(0));
+
+ handler.batchWrite("destination", ImmutableList.of("1", "2", "3", "4"));
+
+ errorsById.complete(ImmutableList.of("3", "4"));
+
+ assertThatThrownBy(() -> handler.waitForCompletion())
+ .hasMessageContaining("Exceeded retries")
+ .hasMessageEndingWith("REASON for 2 record(s).")
+ .isInstanceOf(IOException.class);
+ verify(handler.submitFn).apply("destination", ImmutableList.of("1", "2", "3", "4"));
+ }
+
+ @Test
+ public void propagateErrorOnPutRecords() throws Throwable {
+ AsyncBatchWriteHandler<Integer, Boolean> handler = byPositionHandler(DEFAULT);
+ handler.batchWrite("destination", emptyList());
+ resultsByPos.completeExceptionally(new RuntimeException("Request failed"));
+
+ assertThatThrownBy(() -> handler.batchWrite("destination", emptyList()))
+ .hasMessage("Request failed");
+ assertThat(handler.hasErrored()).isTrue();
+ verify(handler.submitFn).apply("destination", emptyList());
+ }
+
+ @Test
+ public void propagateErrorWhenPolling() throws Throwable {
+ AsyncBatchWriteHandler<Integer, Boolean> handler = byPositionHandler(DEFAULT);
+ handler.batchWrite("destination", emptyList());
+ handler.checkForAsyncFailure(); // none yet
+ resultsByPos.completeExceptionally(new RuntimeException("Request failed"));
+
+ assertThatThrownBy(() -> handler.checkForAsyncFailure()).hasMessage("Request failed");
+ assertThat(handler.hasErrored()).isTrue();
+ handler.checkForAsyncFailure(); // already reset
+ }
+
+ @Test
+ public void propagateErrorOnWaitForCompletion() throws Throwable {
+ AsyncBatchWriteHandler<Integer, Boolean> handler = byPositionHandler(DEFAULT);
+ handler.batchWrite("destination", emptyList());
+ resultsByPos.completeExceptionally(new RuntimeException("Request failed"));
+
+ assertThatThrownBy(() -> handler.waitForCompletion()).hasMessage("Request failed");
+ }
+
+ @Test
+ public void correctlyLimitConcurrency() throws Throwable {
+ AsyncBatchWriteHandler<Integer, Boolean> handler = byPositionHandler(DEFAULT);
+
+ // exhaust concurrency limit so that batchWrite blocks
+ Runnable task = repeat(CONCURRENCY + 1, () -> handler.batchWrite("destination", emptyList()));
+ Future> future = commonPool().submit(task);
+
+ eventually(5, () -> assertThat(handler.requestsInProgress()).isEqualTo(CONCURRENCY));
+ eventually(
+ 5, () -> verify(handler.submitFn, times(CONCURRENCY)).apply("destination", emptyList()));
+ assertThat(future).isNotDone();
+
+ // complete responses and unblock last request
+ resultsByPos.complete(emptyList());
+
+ eventually(
+ 5,
+ () -> verify(handler.submitFn, times(CONCURRENCY + 1)).apply("destination", emptyList()));
+ handler.waitForCompletion();
+ assertThat(future).isDone();
+ }
+
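+ /** Test stub that ignores destination and records and returns the next response future. */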
+ static class SubmitFn<T, V> implements BiFunction<String, List<T>, CompletableFuture<List<V>>> {
+ private final Supplier<CompletableFuture<List<V>>> resp;
+
+ SubmitFn(Supplier<CompletableFuture<List<V>>> resp) {
+ this.resp = resp;
+ }
+
+ @Override
+ public CompletableFuture<List<V>> apply(String destination, List<T> input) {
+ return resp.get();
+ }
+ }
+
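+ /** Retries {@code fun} with linearly growing sleeps; the final attempt propagates any failure. */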
+ private void eventually(int attempts, Runnable fun) {
+ for (int i = 0; i < attempts - 1; i++) {
+ try {
+ Thread.sleep(i * 100);
+ fun.run();
+ return;
+ } catch (AssertionError | InterruptedException t) {
+ }
+ }
+ fun.run();
+ }
+
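+ /** Returns a runnable invoking {@code fun} {@code times} times, rethrowing failures unchecked. */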
+ private Runnable repeat(int times, ThrowingRunnable fun) {
+ return () -> {
+ for (int i = 0; i < times; i++) {
+ try {
+ fun.run();
+ } catch (Throwable t) {
+ throw new RuntimeException(t);
+ }
+ }
+ };
+ }
+}
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/AsyncPutRecordsHandlerTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/AsyncPutRecordsHandlerTest.java
deleted file mode 100644
index 6f0b92654de4..000000000000
--- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/AsyncPutRecordsHandlerTest.java
+++ /dev/null
@@ -1,192 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.io.aws2.kinesis;
-
-import static java.util.Collections.emptyList;
-import static java.util.concurrent.ForkJoinPool.commonPool;
-import static org.apache.beam.sdk.io.common.TestRow.getExpectedValues;
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatThrownBy;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyNoMoreInteractions;
-import static org.mockito.Mockito.when;
-
-import java.io.IOException;
-import java.util.List;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.Future;
-import java.util.function.Supplier;
-import org.apache.beam.sdk.util.BackOff;
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.function.ThrowingRunnable;
-import org.junit.runner.RunWith;
-import org.mockito.Mock;
-import org.mockito.junit.MockitoJUnitRunner;
-import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
-import software.amazon.awssdk.services.kinesis.model.PutRecordsRequest;
-import software.amazon.awssdk.services.kinesis.model.PutRecordsRequestEntry;
-import software.amazon.awssdk.services.kinesis.model.PutRecordsResponse;
-
-@RunWith(MockitoJUnitRunner.StrictStubs.class)
-public class AsyncPutRecordsHandlerTest extends PutRecordsHelpers {
- private static final String STREAM = "streamName";
- private static final int CONCURRENCY = 10;
-
- private CompletableFuture<PutRecordsResponse> pendingResponse = new CompletableFuture<>();
-
- @Mock private KinesisAsyncClient client;
- @Mock private Supplier<BackOff> backoff;
- private AsyncPutRecordsHandler handler;
-
- @Before
- public void init() {
- handler =
- new AsyncPutRecordsHandler(
- client, CONCURRENCY, backoff, mock(AsyncPutRecordsHandler.Stats.class));
- when(client.putRecords(anyRequest())).thenReturn(pendingResponse);
- }
-
- @Test
- public void retryOnPartialSuccess() throws Throwable {
- when(backoff.get()).thenReturn(BackOff.ZERO_BACKOFF);
- CompletableFuture<PutRecordsResponse> pendingResponse2 = new CompletableFuture<>();
- CompletableFuture<PutRecordsResponse> pendingResponse3 = new CompletableFuture<>();
- when(client.putRecords(anyRequest()))
- .thenReturn(pendingResponse, pendingResponse2, pendingResponse3);
-
- List<PutRecordsRequestEntry> records = fromTestRows(getExpectedValues(0, 100));
- handler.putRecords(STREAM, records);
-
- // 1st attempt
- eventually(5, () -> verify(client, times(1)).putRecords(anyRequest()));
- verify(client).putRecords(request(records));
- assertThat(handler.pendingRequests()).isEqualTo(1);
-
- // 2nd attempt
- pendingResponse.complete(partialSuccessResponse(50, 50));
- eventually(5, () -> verify(client, times(2)).putRecords(anyRequest()));
- verify(client).putRecords(request(records.subList(50, 100)));
- assertThat(handler.pendingRequests()).isEqualTo(1);
-
- // 3rd attempt
- pendingResponse2.complete(partialSuccessResponse(25, 25));
- eventually(5, () -> verify(client, times(3)).putRecords(anyRequest()));
- verify(client).putRecords(request(records.subList(75, 100)));
- assertThat(handler.pendingRequests()).isEqualTo(1);
-
- // 4th attempt
- pendingResponse3.complete(PutRecordsResponse.builder().build()); // success
- verifyNoMoreInteractions(client);
-
- eventually(5, () -> assertThat(handler.pendingRequests()).isEqualTo(0));
- }
-
- @Test
- public void retryLimitOnPartialSuccess() throws Throwable {
- when(backoff.get()).thenReturn(BackOff.STOP_BACKOFF);
-
- List<PutRecordsRequestEntry> records = fromTestRows(getExpectedValues(0, 100));
- handler.putRecords(STREAM, records);
-
- pendingResponse.complete(partialSuccessResponse(98, 2));
-
- assertThatThrownBy(() -> handler.waitForCompletion())
- .hasMessageContaining("Exceeded retries")
- .hasMessageEndingWith(ERROR_CODE + " for 2 record(s).")
- .isInstanceOf(IOException.class);
- verify(client).putRecords(anyRequest());
- }
-
- @Test
- public void propagateErrorOnPutRecords() throws Throwable {
- handler.putRecords(STREAM, emptyList());
- pendingResponse.completeExceptionally(new RuntimeException("Request failed"));
-
- assertThatThrownBy(() -> handler.putRecords(STREAM, emptyList())).hasMessage("Request failed");
- assertThat(handler.hasErrored()).isTrue();
- verify(client).putRecords(anyRequest());
- }
-
- @Test
- public void propagateErrorWhenPolling() throws Throwable {
- handler.putRecords(STREAM, emptyList());
- handler.checkForAsyncFailure(); // none yet
- pendingResponse.completeExceptionally(new RuntimeException("Request failed"));
-
- assertThatThrownBy(() -> handler.checkForAsyncFailure()).hasMessage("Request failed");
- assertThat(handler.hasErrored()).isTrue();
- handler.checkForAsyncFailure(); // already reset
- }
-
- @Test
- public void propagateErrorOnWaitForCompletion() throws Throwable {
- handler.putRecords(STREAM, emptyList());
- pendingResponse.completeExceptionally(new RuntimeException("Request failed"));
-
- assertThatThrownBy(() -> handler.waitForCompletion()).hasMessage("Request failed");
- }
-
- @Test
- public void correctlyLimitConcurrency() throws Throwable {
- // exhaust concurrency limit so that putRecords blocks
- Runnable task = repeat(CONCURRENCY + 1, () -> handler.putRecords(STREAM, emptyList()));
- Future<?> future = commonPool().submit(task);
-
- eventually(5, () -> assertThat(handler.pendingRequests()).isEqualTo(CONCURRENCY));
- eventually(5, () -> verify(client, times(CONCURRENCY)).putRecords(anyRequest()));
- assertThat(future).isNotDone();
-
- // complete responses and unblock last request
- pendingResponse.complete(PutRecordsResponse.builder().build());
-
- eventually(5, () -> verify(client, times(CONCURRENCY + 1)).putRecords(anyRequest()));
- handler.waitForCompletion();
- assertThat(future).isDone();
- }
-
- private PutRecordsRequest request(List<PutRecordsRequestEntry> records) {
- return PutRecordsRequest.builder().streamName(STREAM).records(records).build();
- }
-
- private void eventually(int attempts, Runnable fun) {
- for (int i = 0; i < attempts - 1; i++) {
- try {
- Thread.sleep(i * 100);
- fun.run();
- return;
- } catch (AssertionError | InterruptedException t) {
- }
- }
- fun.run();
- }
-
- private Runnable repeat(int times, ThrowingRunnable fun) {
- return () -> {
- for (int i = 0; i < times; i++) {
- try {
- fun.run();
- } catch (Throwable t) {
- throw new RuntimeException(t);
- }
- }
- };
- }
-}
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java
new file mode 100644
index 000000000000..aeb9122df9f2
--- /dev/null
+++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java
@@ -0,0 +1,264 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.aws2.sqs;
+
+import static java.util.concurrent.CompletableFuture.completedFuture;
+import static java.util.concurrent.CompletableFuture.supplyAsync;
+import static java.util.stream.Collectors.toList;
+import static java.util.stream.IntStream.range;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.joda.time.Duration.millis;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.when;
+
+import java.util.Arrays;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.io.aws2.MockClientBuilderFactory;
+import org.apache.beam.sdk.io.aws2.common.AsyncBatchWriteHandler;
+import org.apache.beam.sdk.io.aws2.common.ClientConfiguration;
+import org.apache.beam.sdk.io.aws2.common.RetryConfiguration;
+import org.apache.beam.sdk.io.aws2.sqs.SqsIO.WriteBatches;
+import org.apache.beam.sdk.io.aws2.sqs.SqsIO.WriteBatches.EntryBuilder;
+import org.apache.beam.sdk.testing.ExpectedLogs;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams;
+import org.joda.time.Duration;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+import software.amazon.awssdk.services.sqs.SqsAsyncClient;
+import software.amazon.awssdk.services.sqs.SqsAsyncClientBuilder;
+import software.amazon.awssdk.services.sqs.model.BatchResultErrorEntry;
+import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
+import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry;
+import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse;
+
+/** Tests for {@link WriteBatches}. */
+@RunWith(MockitoJUnitRunner.class)
+public class SqsIOWriteBatchesTest {
+ private static final EntryBuilder<String> SET_MESSAGE_BODY =
+ SendMessageBatchRequestEntry.Builder::messageBody;
+ private static final SendMessageBatchResponse SUCCESS =
+ SendMessageBatchResponse.builder().build();
+
+ @Rule public TestPipeline p = TestPipeline.create();
+ @Mock public SqsAsyncClient sqs;
+ @Rule public ExpectedLogs logs = ExpectedLogs.none(AsyncBatchWriteHandler.class);
+
+ @Before
+ public void configureClientBuilderFactory() {
+ MockClientBuilderFactory.set(p, SqsAsyncClientBuilder.class, sqs);
+ }
+
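+ // default batch size is 10: 23 messages flush as batches of 10, 10 and 3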
+ @Test
+ public void testWriteBatches() {
+ when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS));
+
+ p.apply(Create.of(23))
+ .apply(ParDo.of(new CreateMessages()))
+ .apply(SqsIO.writeBatches(SET_MESSAGE_BODY).to("queue"));
+
+ p.run().waitUntilFinish();
+
+ verify(sqs).sendMessageBatch(request("queue", range(0, 10)));
+ verify(sqs).sendMessageBatch(request("queue", range(10, 20)));
+ verify(sqs).sendMessageBatch(request("queue", range(20, 23)));
+
+ verify(sqs).close();
+ verifyNoMoreInteractions(sqs);
+ }
+
+ @Test
+ public void testWriteBatchesFailure() {
+ when(sqs.sendMessageBatch(anyRequest()))
+ .thenReturn(
+ completedFuture(SUCCESS),
+ supplyAsync(() -> checkNotNull(null, "sendMessageBatch failed")),
+ completedFuture(SUCCESS));
+
+ p.apply(Create.of(23))
+ .apply(ParDo.of(new CreateMessages()))
+ .apply(SqsIO.writeBatches(SET_MESSAGE_BODY).to("queue"));
+
+ assertThatThrownBy(() -> p.run().waitUntilFinish())
+ .isInstanceOf(Pipeline.PipelineExecutionException.class)
+ .hasMessageContaining("sendMessageBatch failed");
+ }
+
+ @Test
+ public void testWriteBatchesPartialSuccess() {
+ SendMessageBatchRequestEntry[] entries = entries(range(0, 10));
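+ // 1st batch: entries 2 and 3 fail; 1st retry: entry 3 fails again; 2nd retry succeeds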
+ when(sqs.sendMessageBatch(anyRequest()))
+ .thenReturn(
+ completedFuture(partialSuccessResponse(entries[2].id(), entries[3].id())),
+ completedFuture(partialSuccessResponse(entries[3].id())),
+ completedFuture(SUCCESS));
+
+ p.apply(Create.of(23))
+ .apply(ParDo.of(new CreateMessages()))
+ .apply(SqsIO.writeBatches(SET_MESSAGE_BODY).to("queue"));
+
+ p.run().waitUntilFinish();
+
+ verify(sqs).sendMessageBatch(request("queue", entries));
+ verify(sqs).sendMessageBatch(request("queue", entries[2], entries[3]));
+ verify(sqs).sendMessageBatch(request("queue", entries[3]));
+ verify(sqs).sendMessageBatch(request("queue", range(10, 20)));
+ verify(sqs).sendMessageBatch(request("queue", range(20, 23)));
+
+ verify(sqs).close();
+ verifyNoMoreInteractions(sqs);
+
+ logs.verifyInfo("retry after partial failure: code REASON for 2 record(s)");
+ logs.verifyInfo("retry after partial failure: code REASON for 1 record(s)");
+ }
+
+ @Test
+ public void testWriteCustomBatches() {
+ when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS));
+
+ p.apply(Create.of(8))
+ .apply(ParDo.of(new CreateMessages()))
+ .apply(SqsIO.writeBatches(SET_MESSAGE_BODY).withBatchSize(3).to("queue"));
+
+ p.run().waitUntilFinish();
+
+ verify(sqs).sendMessageBatch(request("queue", range(0, 3)));
+ verify(sqs).sendMessageBatch(request("queue", range(3, 6)));
+ verify(sqs).sendMessageBatch(request("queue", range(6, 8)));
+
+ verify(sqs).close();
+ verifyNoMoreInteractions(sqs);
+ }
+
+ @Test
+ public void testWriteBatchesToDynamic() {
+ when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS));
+
+ // minimize delay due to retries
+ RetryConfiguration retry = RetryConfiguration.builder().maxBackoff(millis(1)).build();
+
+ p.apply(Create.of(10))
+ .apply(ParDo.of(new CreateMessages()))
+ .apply(
+ SqsIO.writeBatches(SET_MESSAGE_BODY)
+ .withClientConfiguration(ClientConfiguration.builder().retry(retry).build())
+ .withBatchSize(3)
+ .to(msg -> Integer.valueOf(msg) % 2 == 0 ? "even" : "uneven"));
+
+ p.run().waitUntilFinish();
+
+ // id generator creates ids in range of [0, batch size * (queues + 1))
+ SendMessageBatchRequestEntry[] entries = entries(range(0, 9), range(9, 10));
+
+ verify(sqs).sendMessageBatch(request("even", entries[0], entries[2], entries[4]));
+ verify(sqs).sendMessageBatch(request("uneven", entries[1], entries[3], entries[5]));
+ verify(sqs).sendMessageBatch(request("even", entries[6], entries[8]));
+ verify(sqs).sendMessageBatch(request("uneven", entries[7], entries[9]));
+
+ verify(sqs).close();
+ verifyNoMoreInteractions(sqs);
+ }
+
+ @Test
+ public void testWriteBatchesWithTimeout() {
+ when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS));
+
+ p.apply(Create.of(5))
+ .apply(ParDo.of(new CreateMessages()))
+ .apply(
+ // simulate delay between messages > batch timeout
+ SqsIO.writeBatches(withDelay(millis(200), SET_MESSAGE_BODY))
+ .withBatchTimeout(millis(100))
+ .to("queue"));
+
+ p.run().waitUntilFinish();
+
+ SendMessageBatchRequestEntry[] entries = entries(range(0, 5));
+ // due to added delay, batches are timed out on arrival of every 2nd msg
+ verify(sqs).sendMessageBatch(request("queue", entries[0], entries[1]));
+ verify(sqs).sendMessageBatch(request("queue", entries[2], entries[3]));
+ verify(sqs).sendMessageBatch(request("queue", entries[4]));
+ }
+
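+ /** Typed {@code any()} matcher for {@link SendMessageBatchRequest}. */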
+ private SendMessageBatchRequest anyRequest() {
+ return any();
+ }
+
+ private SendMessageBatchRequest request(String queue, SendMessageBatchRequestEntry... entries) {
+ return SendMessageBatchRequest.builder()
+ .queueUrl(queue)
+ .entries(Arrays.asList(entries))
+ .build();
+ }
+
+ private SendMessageBatchRequest request(String queue, IntStream msgs) {
+ return request(queue, entries(msgs));
+ }
+
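+ /** Flattens message streams into entries, assigning ids by position within each stream. */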
+ private SendMessageBatchRequestEntry[] entries(IntStream... msgStreams) {
+ return Arrays.stream(msgStreams)
+ .flatMap(msgs -> Streams.mapWithIndex(msgs, this::entry))
+ .toArray(SendMessageBatchRequestEntry[]::new);
+ }
+
+ private SendMessageBatchRequestEntry entry(int msg, long id) {
+ return SendMessageBatchRequestEntry.builder()
+ .id(Long.toString(id))
+ .messageBody(Integer.toString(msg))
+ .build();
+ }
+
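+ /** Response marking the given entry ids as failed with error code REASON. */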
+ private SendMessageBatchResponse partialSuccessResponse(String... failedIds) {
+ Stream<BatchResultErrorEntry> errors =
+ Arrays.stream(failedIds)
+ .map(BatchResultErrorEntry.builder()::id)
+ .map(b -> b.code("REASON").build());
+ return SendMessageBatchResponse.builder().failed(errors.collect(toList())).build();
+ }
+
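+ /** Emits messages "0" to "count - 1" for each input count. */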
+ private static class CreateMessages extends DoFn<Integer, String> {
+ @ProcessElement
+ public void processElement(@Element Integer count, OutputReceiver<String> out) {
+ for (int i = 0; i < count; i++) {
+ out.output(Integer.toString(i));
+ }
+ }
+ }
+
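+ /** Wraps an {@link EntryBuilder} to sleep after building each entry, simulating slow input. */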
+ private static EntryBuilder<String> withDelay(Duration delay, EntryBuilder<String> builder) {
+ return (t1, t2) -> {
+ builder.accept(t1, t2);
+ try {
+ Thread.sleep(delay.getMillis());
+ } catch (InterruptedException e) {
+ }
+ };
+ }
+}
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/testing/SqsIOIT.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/testing/SqsIOIT.java
index f1e176f572d6..2f10f3d08f1d 100644
--- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/testing/SqsIOIT.java
+++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/testing/SqsIOIT.java
@@ -19,6 +19,7 @@
import static org.apache.beam.sdk.io.common.TestRow.getExpectedHashForRowCount;
import static org.apache.beam.sdk.values.TypeDescriptors.strings;
+import static org.apache.commons.lang3.RandomStringUtils.randomAlphanumeric;
import static org.testcontainers.containers.localstack.LocalStackContainer.Service.SQS;
import java.io.Serializable;
@@ -103,6 +104,33 @@ public void testWriteThenRead() {
pipelineRead.run();
}
+ @Test
+ public void testWriteBatchesThenRead() {
+ int rows = env.options().getNumberOfRows();
+
+ // Write test dataset to SQS.
+ pipelineWrite
+ .apply("Generate Sequence", GenerateSequence.from(0).to(rows))
+ .apply("Prepare TestRows", ParDo.of(new DeterministicallyConstructTestRowFn()))
+ .apply(
+ "Write to SQS",
+ SqsIO.writeBatches((b, row) -> b.messageBody(row.name())).to(sqsQueue.url));
+
+ // Read test dataset from SQS.
+ PCollection<String> output =
+ pipelineRead
+ .apply("Read from SQS", SqsIO.read().withQueueUrl(sqsQueue.url).withMaxNumRecords(rows))
+ .apply("Extract body", MapElements.into(strings()).via(SqsMessage::getBody));
+
+ PAssert.thatSingleton(output.apply("Count All", Count.globally())).isEqualTo((long) rows);
+
+ PAssert.that(output.apply(Combine.globally(new HashingFn()).withoutDefaults()))
+ .containsInAnyOrder(getExpectedHashForRowCount(rows));
+
+ pipelineWrite.run();
+ pipelineRead.run();
+ }
+
private static class SqsQueue extends ExternalResource implements Serializable {
private transient SqsClient client = env.buildClient(SqsClient.builder());
private String url;
@@ -113,7 +141,8 @@ SendMessageRequest messageRequest(TestRow r) {
@Override
protected void before() throws Throwable {
- url = client.createQueue(b -> b.queueName("beam-sqsio-it")).queueUrl();
+ url =
+ client.createQueue(b -> b.queueName("beam-sqsio-it-" + randomAlphanumeric(4))).queueUrl();
}
@Override