diff --git a/flink/src/main/java/org/apache/iceberg/flink/sink/FlinkSink.java b/flink/src/main/java/org/apache/iceberg/flink/sink/FlinkSink.java
index 5ca68c4713b2..8c4486aa28c6 100644
--- a/flink/src/main/java/org/apache/iceberg/flink/sink/FlinkSink.java
+++ b/flink/src/main/java/org/apache/iceberg/flink/sink/FlinkSink.java
@@ -21,6 +21,7 @@
 import java.io.IOException;
 import java.io.UncheckedIOException;
+import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import org.apache.flink.api.common.functions.MapFunction;
@@ -44,6 +45,7 @@
 import org.apache.iceberg.flink.TableLoader;
 import org.apache.iceberg.io.WriteResult;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.types.TypeUtil;
 import org.apache.iceberg.util.PropertyUtil;
@@ -113,6 +115,7 @@ public static class Builder {
     private TableSchema tableSchema;
     private boolean overwrite = false;
     private Integer writeParallelism = null;
+    private List<String> equalityFieldColumns = null;
 
     private Builder() {
     }
@@ -169,6 +172,17 @@ public Builder writeParallelism(int newWriteParallelism) {
       return this;
     }
 
+    /**
+     * Configure the equality field columns for an iceberg table that accepts CDC or UPSERT events.
+     *
+     * @param columns defines the iceberg table's key.
+     * @return {@link Builder} to connect the iceberg table.
+     */
+    public Builder equalityFieldColumns(List<String> columns) {
+      this.equalityFieldColumns = columns;
+      return this;
+    }
+
     @SuppressWarnings("unchecked")
     public DataStreamSink<RowData> build() {
       Preconditions.checkArgument(rowDataInput != null,
@@ -184,7 +198,18 @@ public DataStreamSink<RowData> build() {
         }
       }
 
-      IcebergStreamWriter<RowData> streamWriter = createStreamWriter(table, tableSchema);
+      // Find out the equality field id list based on the user-provided equality field column names.
+      List<Integer> equalityFieldIds = Lists.newArrayList();
+      if (equalityFieldColumns != null && equalityFieldColumns.size() > 0) {
+        for (String column : equalityFieldColumns) {
+          org.apache.iceberg.types.Types.NestedField field = table.schema().findField(column);
+          Preconditions.checkNotNull(field, "Missing required equality field column '%s' in table schema %s",
+              column, table.schema());
+          equalityFieldIds.add(field.fieldId());
+        }
+      }
+
+      IcebergStreamWriter<RowData> streamWriter = createStreamWriter(table, tableSchema, equalityFieldIds);
       IcebergFilesCommitter filesCommitter = new IcebergFilesCommitter(tableLoader, overwrite);
 
       this.writeParallelism = writeParallelism == null ? rowDataInput.getParallelism() : writeParallelism;
@@ -202,7 +227,8 @@ public DataStreamSink<RowData> build() {
     }
   }
 
-  static IcebergStreamWriter<RowData> createStreamWriter(Table table, TableSchema requestedSchema) {
+  static IcebergStreamWriter<RowData> createStreamWriter(Table table, TableSchema requestedSchema,
+                                                         List<Integer> equalityFieldIds) {
     Preconditions.checkArgument(table != null, "Iceberg table shouldn't be null");
 
     RowType flinkSchema;
@@ -226,7 +252,7 @@ static IcebergStreamWriter<RowData> createStreamWriter(Table table, TableSchema
 
     TaskWriterFactory<RowData> taskWriterFactory = new RowDataTaskWriterFactory(table.schema(), flinkSchema,
         table.spec(), table.locationProvider(), table.io(), table.encryption(), targetFileSize, fileFormat, props,
-        null);
+        equalityFieldIds);
 
     return new IcebergStreamWriter<>(table.name(), taskWriterFactory);
   }
diff --git a/flink/src/test/java/org/apache/iceberg/flink/SimpleDataUtil.java b/flink/src/test/java/org/apache/iceberg/flink/SimpleDataUtil.java
index da064eb057b5..ff80d9da7e02 100644
--- a/flink/src/test/java/org/apache/iceberg/flink/SimpleDataUtil.java
+++ b/flink/src/test/java/org/apache/iceberg/flink/SimpleDataUtil.java
@@ -20,6 +20,7 @@
 package org.apache.iceberg.flink;
 
 import java.io.IOException;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import org.apache.flink.table.api.DataTypes;
@@ -56,6 +57,7 @@
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.types.Types;
 import org.apache.iceberg.util.Pair;
+import org.apache.iceberg.util.StructLikeSet;
 import org.junit.Assert;
 
 import static org.apache.iceberg.hadoop.HadoopOutputFile.fromPath;
@@ -197,4 +199,19 @@ public static void assertTableRecords(String tablePath, List<Record> expected) t
     Preconditions.checkArgument(expected != null, "expected records shouldn't be null");
     assertTableRecords(new HadoopTables().load(tablePath), expected);
   }
+
+  public static StructLikeSet expectedRowSet(Table table, Record... records) {
+    StructLikeSet set = StructLikeSet.create(table.schema().asStruct());
+    Collections.addAll(set, records);
+    return set;
+  }
+
+  public static StructLikeSet actualRowSet(Table table, String... columns) throws IOException {
+    table.refresh();
+    StructLikeSet set = StructLikeSet.create(table.schema().asStruct());
+    try (CloseableIterable<Record> reader = IcebergGenerics.read(table).select(columns).build()) {
+      reader.forEach(set::add);
+    }
+    return set;
+  }
 }
diff --git a/flink/src/test/java/org/apache/iceberg/flink/TestTableLoader.java b/flink/src/test/java/org/apache/iceberg/flink/TestTableLoader.java
new file mode 100644
index 000000000000..f3df4283f781
--- /dev/null
+++ b/flink/src/test/java/org/apache/iceberg/flink/TestTableLoader.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + +package org.apache.iceberg.flink; + +import java.io.File; +import org.apache.iceberg.Table; +import org.apache.iceberg.TestTables; + +public class TestTableLoader implements TableLoader { + private File dir; + + public TestTableLoader(String dir) { + this.dir = new File(dir); + } + + @Override + public void open() { + + } + + @Override + public Table loadTable() { + return TestTables.load(dir, "test"); + } + + @Override + public void close() { + + } +} diff --git a/flink/src/test/java/org/apache/iceberg/flink/sink/TestDeltaTaskWriter.java b/flink/src/test/java/org/apache/iceberg/flink/sink/TestDeltaTaskWriter.java index ed56753f6243..603562bc70a3 100644 --- a/flink/src/test/java/org/apache/iceberg/flink/sink/TestDeltaTaskWriter.java +++ b/flink/src/test/java/org/apache/iceberg/flink/sink/TestDeltaTaskWriter.java @@ -25,7 +25,6 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.stream.Collectors; @@ -35,10 +34,9 @@ import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.RowDelta; import org.apache.iceberg.TableTestBase; -import org.apache.iceberg.data.IcebergGenerics; import org.apache.iceberg.data.Record; import org.apache.iceberg.flink.FlinkSchemaUtil; -import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.flink.SimpleDataUtil; import org.apache.iceberg.io.TaskWriter; import org.apache.iceberg.io.WriteResult; import org.apache.iceberg.relocated.com.google.common.collect.Lists; @@ -325,17 +323,11 @@ private void commitTransaction(WriteResult result) { } private StructLikeSet expectedRowSet(Record... records) { - StructLikeSet set = StructLikeSet.create(table.schema().asStruct()); - Collections.addAll(set, records); - return set; + return SimpleDataUtil.expectedRowSet(table, records); } private StructLikeSet actualRowSet(String... columns) throws IOException { - StructLikeSet set = StructLikeSet.create(table.schema().asStruct()); - try (CloseableIterable reader = IcebergGenerics.read(table).select(columns).build()) { - reader.forEach(set::add); - } - return set; + return SimpleDataUtil.actualRowSet(table, columns); } private TaskWriterFactory createTaskWriterFactory(List equalityFieldIds) { diff --git a/flink/src/test/java/org/apache/iceberg/flink/sink/TestFlinkIcebergSinkV2.java b/flink/src/test/java/org/apache/iceberg/flink/sink/TestFlinkIcebergSinkV2.java new file mode 100644 index 000000000000..93222ddc4535 --- /dev/null +++ b/flink/src/test/java/org/apache/iceberg/flink/sink/TestFlinkIcebergSinkV2.java @@ -0,0 +1,328 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.flink.sink; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.types.RowKind; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.TableTestBase; +import org.apache.iceberg.data.IcebergGenerics; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.flink.SimpleDataUtil; +import org.apache.iceberg.flink.TestTableLoader; +import org.apache.iceberg.flink.source.BoundedTestSource; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.util.StructLikeSet; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class TestFlinkIcebergSinkV2 extends TableTestBase { + private static final int FORMAT_V2 = 2; + private static final TypeInformation ROW_TYPE_INFO = + new RowTypeInfo(SimpleDataUtil.FLINK_SCHEMA.getFieldTypes()); + + private static final Map ROW_KIND_MAP = ImmutableMap.of( + "+I", RowKind.INSERT, + "-D", RowKind.DELETE, + "-U", RowKind.UPDATE_BEFORE, + "+U", RowKind.UPDATE_AFTER); + + private static final int ROW_ID_POS = 0; + private static final int ROW_DATA_POS = 1; + + private final FileFormat format; + private final int parallelism; + private final boolean partitioned; + + private StreamExecutionEnvironment env; + private TestTableLoader tableLoader; + + @Parameterized.Parameters(name = "FileFormat = {0}, Parallelism = {1}, Partitioned={2}") + public static Object[][] parameters() { + return new Object[][] { + new Object[] {"avro", 1, true}, + new Object[] {"avro", 1, false}, + new Object[] {"avro", 2, true}, + new Object[] {"avro", 2, false}, + new Object[] {"parquet", 1, true}, + new Object[] {"parquet", 1, false}, + new Object[] {"parquet", 2, true}, + new Object[] {"parquet", 2, false} + }; + } + + public TestFlinkIcebergSinkV2(String format, int parallelism, boolean partitioned) { + super(FORMAT_V2); + this.format = FileFormat.valueOf(format.toUpperCase(Locale.ENGLISH)); + this.parallelism = parallelism; + this.partitioned = partitioned; + } + + @Before + public void setupTable() throws IOException { + this.tableDir = temp.newFolder(); + this.metadataDir = new File(tableDir, "metadata"); + Assert.assertTrue(tableDir.delete()); + + if (!partitioned) { + table = create(SimpleDataUtil.SCHEMA, PartitionSpec.unpartitioned()); + } else { + table = create(SimpleDataUtil.SCHEMA, PartitionSpec.builderFor(SimpleDataUtil.SCHEMA).identity("data").build()); + } + + table.updateProperties() + .set(TableProperties.DEFAULT_FILE_FORMAT, format.name()) + .commit(); + + env = StreamExecutionEnvironment.getExecutionEnvironment() + .enableCheckpointing(100L) + 
        .setParallelism(parallelism)
+      .setMaxParallelism(parallelism);
+
+    tableLoader = new TestTableLoader(tableDir.getAbsolutePath());
+  }
+
+  private List<Snapshot> findValidSnapshots(Table table) {
+    List<Snapshot> validSnapshots = Lists.newArrayList();
+    for (Snapshot snapshot : table.snapshots()) {
+      if (snapshot.allManifests().stream().anyMatch(m -> snapshot.snapshotId() == m.snapshotId())) {
+        validSnapshots.add(snapshot);
+      }
+    }
+    return validSnapshots;
+  }
+
+  private void testChangeLogs(List<String> equalityFieldColumns,
+                              KeySelector<Row, Object> keySelector,
+                              List<List<Row>> elementsPerCheckpoint,
+                              List<List<Record>> expectedRecordsPerCheckpoint) throws Exception {
+    DataStream<Row> dataStream = env.addSource(new BoundedTestSource<>(elementsPerCheckpoint), ROW_TYPE_INFO);
+
+    // Shuffle by the equality key, so that different operations of the same key could be written in order when
+    // executing tasks in parallel.
+    dataStream = dataStream.keyBy(keySelector);
+
+    FlinkSink.forRow(dataStream, SimpleDataUtil.FLINK_SCHEMA)
+        .tableLoader(tableLoader)
+        .tableSchema(SimpleDataUtil.FLINK_SCHEMA)
+        .writeParallelism(parallelism)
+        .equalityFieldColumns(equalityFieldColumns)
+        .build();
+
+    // Execute the program.
+    env.execute("Test Iceberg Change-Log DataStream.");
+
+    table.refresh();
+    List<Snapshot> snapshots = findValidSnapshots(table);
+    int expectedSnapshotNum = expectedRecordsPerCheckpoint.size();
+    Assert.assertEquals("Should have the expected snapshot number", expectedSnapshotNum, snapshots.size());
+
+    for (int i = 0; i < expectedSnapshotNum; i++) {
+      long snapshotId = snapshots.get(i).snapshotId();
+      List<Record> expectedRecords = expectedRecordsPerCheckpoint.get(i);
+      Assert.assertEquals("Should have the expected records for the checkpoint#" + i,
+          expectedRowSet(expectedRecords.toArray(new Record[0])), actualRowSet(snapshotId, "*"));
+    }
+  }
+
+  private Row row(String rowKind, int id, String data) {
+    RowKind kind = ROW_KIND_MAP.get(rowKind);
+    if (kind == null) {
+      throw new IllegalArgumentException("Unknown row kind: " + rowKind);
+    }
+
+    return Row.ofKind(kind, id, data);
+  }
+
+  private Record record(int id, String data) {
+    return SimpleDataUtil.createRecord(id, data);
+  }
+
+  @Test
+  public void testChangeLogOnIdKey() throws Exception {
+    List<List<Row>> elementsPerCheckpoint = ImmutableList.of(
+        ImmutableList.of(
+            row("+I", 1, "aaa"),
+            row("-D", 1, "aaa"),
+            row("+I", 1, "bbb"),
+            row("+I", 2, "aaa"),
+            row("-D", 2, "aaa"),
+            row("+I", 2, "bbb")
+        ),
+        ImmutableList.of(
+            row("-U", 2, "bbb"),
+            row("+U", 2, "ccc"),
+            row("-D", 2, "ccc"),
+            row("+I", 2, "ddd")
+        ),
+        ImmutableList.of(
+            row("-D", 1, "bbb"),
+            row("+I", 1, "ccc"),
+            row("-D", 1, "ccc"),
+            row("+I", 1, "ddd")
+        )
+    );
+
+    List<List<Record>> expectedRecords = ImmutableList.of(
+        ImmutableList.of(record(1, "bbb"), record(2, "bbb")),
+        ImmutableList.of(record(1, "bbb"), record(2, "ddd")),
+        ImmutableList.of(record(1, "ddd"), record(2, "ddd"))
+    );
+
+    testChangeLogs(ImmutableList.of("id"), row -> row.getField(ROW_ID_POS), elementsPerCheckpoint, expectedRecords);
+  }
+
+  @Test
+  public void testChangeLogOnDataKey() throws Exception {
+    List<List<Row>> elementsPerCheckpoint = ImmutableList.of(
+        ImmutableList.of(
+            row("+I", 1, "aaa"),
+            row("-D", 1, "aaa"),
+            row("+I", 2, "bbb"),
+            row("+I", 1, "bbb"),
+            row("+I", 2, "aaa")
+        ),
+        ImmutableList.of(
+            row("-U", 2, "aaa"),
+            row("+U", 1, "ccc"),
+            row("+I", 1, "aaa")
+        ),
+        ImmutableList.of(
+            row("-D", 1, "bbb"),
+            row("+I", 2, "aaa"),
+            row("+I", 2, "ccc")
+        )
+    );
+
+    List<List<Record>> expectedRecords = ImmutableList.of(
+        ImmutableList.of(record(1, "bbb"), record(2, "aaa")),
+
ImmutableList.of(record(1, "aaa"), record(1, "bbb"), record(1, "ccc")), + ImmutableList.of(record(1, "aaa"), record(1, "ccc"), record(2, "aaa"), record(2, "ccc")) + ); + + testChangeLogs(ImmutableList.of("data"), row -> row.getField(ROW_DATA_POS), elementsPerCheckpoint, expectedRecords); + } + + @Test + public void testChangeLogOnIdDataKey() throws Exception { + List> elementsPerCheckpoint = ImmutableList.of( + ImmutableList.of( + row("+I", 1, "aaa"), + row("-D", 1, "aaa"), + row("+I", 2, "bbb"), + row("+I", 1, "bbb"), + row("+I", 2, "aaa") + ), + ImmutableList.of( + row("-U", 2, "aaa"), + row("+U", 1, "ccc"), + row("+I", 1, "aaa") + ), + ImmutableList.of( + row("-D", 1, "bbb"), + row("+I", 2, "aaa") + ) + ); + + List> expectedRecords = ImmutableList.of( + ImmutableList.of(record(1, "bbb"), record(2, "aaa"), record(2, "bbb")), + ImmutableList.of(record(1, "aaa"), record(1, "bbb"), record(1, "ccc"), record(2, "bbb")), + ImmutableList.of(record(1, "aaa"), record(1, "ccc"), record(2, "aaa"), record(2, "bbb")) + ); + + testChangeLogs(ImmutableList.of("data", "id"), row -> Row.of(row.getField(ROW_ID_POS), row.getField(ROW_DATA_POS)), + elementsPerCheckpoint, expectedRecords); + } + + @Test + public void testChangeLogOnSameKey() throws Exception { + List> elementsPerCheckpoint = ImmutableList.of( + // Checkpoint #1 + ImmutableList.of( + row("+I", 1, "aaa"), + row("-D", 1, "aaa"), + row("+I", 1, "aaa") + ), + // Checkpoint #2 + ImmutableList.of( + row("-U", 1, "aaa"), + row("+U", 1, "aaa") + ), + // Checkpoint #3 + ImmutableList.of( + row("-D", 1, "aaa"), + row("+I", 1, "aaa") + ), + // Checkpoint #4 + ImmutableList.of( + row("-U", 1, "aaa"), + row("+U", 1, "aaa"), + row("+I", 1, "aaa") + ) + ); + + List> expectedRecords = ImmutableList.of( + ImmutableList.of(record(1, "aaa")), + ImmutableList.of(record(1, "aaa")), + ImmutableList.of(record(1, "aaa")), + ImmutableList.of(record(1, "aaa"), record(1, "aaa")) + ); + + testChangeLogs(ImmutableList.of("id", "data"), row -> Row.of(row.getField(ROW_ID_POS), row.getField(ROW_DATA_POS)), + elementsPerCheckpoint, expectedRecords); + } + + private StructLikeSet expectedRowSet(Record... records) { + return SimpleDataUtil.expectedRowSet(table, records); + } + + private StructLikeSet actualRowSet(long snapshotId, String... 
columns) throws IOException { + table.refresh(); + StructLikeSet set = StructLikeSet.create(table.schema().asStruct()); + try (CloseableIterable reader = IcebergGenerics.read(table) + .useSnapshot(snapshotId) + .select(columns) + .build()) { + reader.forEach(set::add); + } + return set; + } +} diff --git a/flink/src/test/java/org/apache/iceberg/flink/sink/TestIcebergFilesCommitter.java b/flink/src/test/java/org/apache/iceberg/flink/sink/TestIcebergFilesCommitter.java index 5ec246b7af07..418e480c8722 100644 --- a/flink/src/test/java/org/apache/iceberg/flink/sink/TestIcebergFilesCommitter.java +++ b/flink/src/test/java/org/apache/iceberg/flink/sink/TestIcebergFilesCommitter.java @@ -48,13 +48,11 @@ import org.apache.iceberg.ManifestContent; import org.apache.iceberg.ManifestFile; import org.apache.iceberg.PartitionSpec; -import org.apache.iceberg.Table; import org.apache.iceberg.TableTestBase; -import org.apache.iceberg.TestTables; import org.apache.iceberg.exceptions.ValidationException; import org.apache.iceberg.flink.FlinkSchemaUtil; import org.apache.iceberg.flink.SimpleDataUtil; -import org.apache.iceberg.flink.TableLoader; +import org.apache.iceberg.flink.TestTableLoader; import org.apache.iceberg.io.FileAppenderFactory; import org.apache.iceberg.io.WriteResult; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; @@ -861,27 +859,4 @@ public Class getStreamOperatorClass(ClassLoader classL return IcebergFilesCommitter.class; } } - - private static class TestTableLoader implements TableLoader { - private File dir = null; - - TestTableLoader(String dir) { - this.dir = new File(dir); - } - - @Override - public void open() { - - } - - @Override - public Table loadTable() { - return TestTables.load(dir, "test"); - } - - @Override - public void close() { - - } - } } diff --git a/flink/src/test/java/org/apache/iceberg/flink/sink/TestIcebergStreamWriter.java b/flink/src/test/java/org/apache/iceberg/flink/sink/TestIcebergStreamWriter.java index 28db89456f7e..c6c20e0624fb 100644 --- a/flink/src/test/java/org/apache/iceberg/flink/sink/TestIcebergStreamWriter.java +++ b/flink/src/test/java/org/apache/iceberg/flink/sink/TestIcebergStreamWriter.java @@ -337,7 +337,7 @@ private OneInputStreamOperatorTestHarness createIcebergStr private OneInputStreamOperatorTestHarness createIcebergStreamWriter( Table icebergTable, TableSchema flinkSchema) throws Exception { - IcebergStreamWriter streamWriter = FlinkSink.createStreamWriter(icebergTable, flinkSchema); + IcebergStreamWriter streamWriter = FlinkSink.createStreamWriter(icebergTable, flinkSchema, null); OneInputStreamOperatorTestHarness harness = new OneInputStreamOperatorTestHarness<>( streamWriter, 1, 1, 0); diff --git a/flink/src/test/java/org/apache/iceberg/flink/source/BoundedTestSource.java b/flink/src/test/java/org/apache/iceberg/flink/source/BoundedTestSource.java new file mode 100644 index 000000000000..1ae04ab6d741 --- /dev/null +++ b/flink/src/test/java/org/apache/iceberg/flink/source/BoundedTestSource.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.flink.source; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.flink.runtime.state.CheckpointListener; +import org.apache.flink.streaming.api.functions.source.SourceFunction; + +/** + * A stream source that: + * 1) emits the elements from elementsPerCheckpoint.get(0) without allowing checkpoints. + * 2) then waits for the checkpoint to complete. + * 3) emits the elements from elementsPerCheckpoint.get(1) without allowing checkpoints. + * 4) then waits for the checkpoint to complete. + * 5) ... + * + *
Until all the lists from elementsPerCheckpoint are exhausted.
+ */
+public final class BoundedTestSource<T> implements SourceFunction<T>, CheckpointListener {
+
+  private final List<List<T>> elementsPerCheckpoint;
+  private volatile boolean running = true;
+
+  private final AtomicInteger numCheckpointsComplete = new AtomicInteger(0);
+
+  /**
+   * Emits all those elements in several checkpoints.
+   */
+  public BoundedTestSource(List<List<T>> elementsPerCheckpoint) {
+    this.elementsPerCheckpoint = elementsPerCheckpoint;
+  }
+
+  /**
+   * Emits all those elements in a single checkpoint.
+   */
+  public BoundedTestSource(T... elements) {
+    this(Collections.singletonList(Arrays.asList(elements)));
+  }
+
+  @Override
+  public void run(SourceContext<T> ctx) throws Exception {
+    for (int checkpoint = 0; checkpoint < elementsPerCheckpoint.size(); checkpoint++) {
+
+      final int checkpointToAwait;
+      synchronized (ctx.getCheckpointLock()) {
+        checkpointToAwait = numCheckpointsComplete.get() + 2;
+        for (T element : elementsPerCheckpoint.get(checkpoint)) {
+          ctx.collect(element);
+        }
+      }
+
+      synchronized (ctx.getCheckpointLock()) {
+        while (running && numCheckpointsComplete.get() < checkpointToAwait) {
+          ctx.getCheckpointLock().wait(1);
+        }
+      }
+    }
+  }
+
+  @Override
+  public void notifyCheckpointComplete(long checkpointId) throws Exception {
+    numCheckpointsComplete.incrementAndGet();
+  }
+
+  @Override
+  public void cancel() {
+    running = false;
+  }
+}
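
Usage note (not part of the patch): the sketch below shows how the new equalityFieldColumns(...) builder option added to FlinkSink above could be wired into a changelog write job. It mirrors the testChangeLogs(...) path in TestFlinkIcebergSinkV2 and reuses SimpleDataUtil.FLINK_SCHEMA from this module; the example package/class name, the Hadoop table location, and the toy input rows are hypothetical and only illustrate the shape of the call chain.

package org.apache.iceberg.flink.examples; // hypothetical package, for illustration only

import java.util.Arrays;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.types.RowKind;
import org.apache.iceberg.flink.SimpleDataUtil;
import org.apache.iceberg.flink.TableLoader;
import org.apache.iceberg.flink.sink.FlinkSink;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;

public class ChangeLogSinkExample {
  public static void main(String[] args) throws Exception {
    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    // Each completed checkpoint commits the buffered data files and equality-delete files.
    env.enableCheckpointing(100L);

    // A toy changelog: insert id=1, then update its data column from "aaa" to "bbb".
    DataStream<Row> changelog = env.fromCollection(
        Arrays.asList(
            Row.ofKind(RowKind.INSERT, 1, "aaa"),
            Row.ofKind(RowKind.UPDATE_BEFORE, 1, "aaa"),
            Row.ofKind(RowKind.UPDATE_AFTER, 1, "bbb")),
        new RowTypeInfo(SimpleDataUtil.FLINK_SCHEMA.getFieldTypes()));

    // Shuffle by the key column so that all operations on the same key reach the same writer in order,
    // exactly as testChangeLogs(...) does above.
    KeySelector<Row, Object> keySelector = row -> row.getField(0);
    changelog = changelog.keyBy(keySelector);

    FlinkSink.forRow(changelog, SimpleDataUtil.FLINK_SCHEMA)
        .tableLoader(TableLoader.fromHadoopTable("file:///tmp/warehouse/test")) // hypothetical location
        .tableSchema(SimpleDataUtil.FLINK_SCHEMA)
        .equalityFieldColumns(ImmutableList.of("id")) // the new option: "id" acts as the table's key
        .build();

    env.execute("Iceberg change-log sink example");
  }
}

As in the tests above, the target table is assumed to already exist as a format-v2 table (equality deletes are a v2 feature) with the SimpleDataUtil.SCHEMA ("id", "data") columns.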