diff --git a/flink/src/main/java/org/apache/iceberg/flink/sink/BaseDeltaTaskWriter.java b/flink/src/main/java/org/apache/iceberg/flink/sink/BaseDeltaTaskWriter.java
index 10dab416091c..3696b446a435 100644
--- a/flink/src/main/java/org/apache/iceberg/flink/sink/BaseDeltaTaskWriter.java
+++ b/flink/src/main/java/org/apache/iceberg/flink/sink/BaseDeltaTaskWriter.java
@@ -41,6 +41,7 @@ abstract class BaseDeltaTaskWriter extends BaseTaskWriter {
   private final Schema schema;
   private final Schema deleteSchema;
   private final RowDataWrapper wrapper;
+  private final boolean upsert;

   BaseDeltaTaskWriter(PartitionSpec spec,
                       FileFormat format,
@@ -50,11 +51,13 @@ abstract class BaseDeltaTaskWriter extends BaseTaskWriter {
                       long targetFileSize,
                       Schema schema,
                       RowType flinkSchema,
-                      List equalityFieldIds) {
+                      List equalityFieldIds,
+                      boolean upsert) {
     super(spec, format, appenderFactory, fileFactory, io, targetFileSize);
     this.schema = schema;
     this.deleteSchema = TypeUtil.select(schema, Sets.newHashSet(equalityFieldIds));
     this.wrapper = new RowDataWrapper(flinkSchema, schema.asStruct());
+    this.upsert = upsert;
   }

   abstract RowDataDeltaWriter route(RowData row);
@@ -70,6 +73,9 @@ public void write(RowData row) throws IOException {
     switch (row.getRowKind()) {
       case INSERT:
       case UPDATE_AFTER:
+        if (upsert) {
+          writer.delete(row);
+        }
         writer.write(row);
         break;

diff --git a/flink/src/main/java/org/apache/iceberg/flink/sink/FlinkSink.java b/flink/src/main/java/org/apache/iceberg/flink/sink/FlinkSink.java
index 8c4486aa28c6..138cfa7bdae9 100644
--- a/flink/src/main/java/org/apache/iceberg/flink/sink/FlinkSink.java
+++ b/flink/src/main/java/org/apache/iceberg/flink/sink/FlinkSink.java
@@ -39,6 +39,7 @@
 import org.apache.flink.types.Row;
 import org.apache.iceberg.DataFile;
 import org.apache.iceberg.FileFormat;
+import org.apache.iceberg.PartitionField;
 import org.apache.iceberg.Schema;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.flink.FlinkSchemaUtil;
@@ -115,6 +116,7 @@ public static class Builder {
     private TableSchema tableSchema;
     private boolean overwrite = false;
     private Integer writeParallelism = null;
+    private boolean upsert = false;
     private List equalityFieldColumns = null;

     private Builder() {
@@ -172,6 +174,20 @@ public Builder writeParallelism(int newWriteParallelism) {
       return this;
     }

+    /**
+     * All INSERT/UPDATE_AFTER events from the input stream will be transformed into UPSERT events, which means the
+     * writer will DELETE the old record and then INSERT the new record. For a partitioned table, the partition
+     * fields must be a subset of the equality fields; otherwise an old row located in partition-A could not be
+     * deleted by a new row located in partition-B.
+     *
+     * @param enable indicates whether to transform all INSERT/UPDATE_AFTER events into UPSERT.
+     * @return {@link Builder} to connect the iceberg table.
+     */
+    public Builder upsert(boolean enable) {
+      this.upsert = enable;
+      return this;
+    }
+
     /**
      * Configuring the equality field columns for iceberg table that accept CDC or UPSERT events.
      *
@@ -209,7 +225,22 @@ public DataStreamSink build() {
         }
       }

-      IcebergStreamWriter streamWriter = createStreamWriter(table, tableSchema, equalityFieldIds);
+      // Convert the iceberg schema to flink's RowType.
+      RowType flinkSchema = convertToRowType(table, tableSchema);
+
+      // Convert the INSERT stream to be an UPSERT stream if needed.
+      if (upsert) {
+        Preconditions.checkState(!equalityFieldIds.isEmpty(),
+            "Equality field columns shouldn't be empty when configuring to use UPSERT data stream.");
+        if (!table.spec().isUnpartitioned()) {
+          for (PartitionField partitionField : table.spec().fields()) {
+            Preconditions.checkState(equalityFieldIds.contains(partitionField.sourceId()),
+                "Partition field '%s' is not included in equality fields: '%s'", partitionField, equalityFieldColumns);
+          }
+        }
+      }
+
+      IcebergStreamWriter streamWriter = createStreamWriter(table, flinkSchema, equalityFieldIds, upsert);

       IcebergFilesCommitter filesCommitter = new IcebergFilesCommitter(tableLoader, overwrite);
       this.writeParallelism = writeParallelism == null ? rowDataInput.getParallelism() : writeParallelism;
@@ -227,8 +258,7 @@ public DataStreamSink build() {
     }
   }

-  static IcebergStreamWriter createStreamWriter(Table table, TableSchema requestedSchema,
-                                                List equalityFieldIds) {
+  private static RowType convertToRowType(Table table, TableSchema requestedSchema) {
     Preconditions.checkArgument(table != null, "Iceberg table should't be null");

     RowType flinkSchema;
@@ -246,13 +276,22 @@ static IcebergStreamWriter createStreamWriter(Table table, TableSchema
       flinkSchema = FlinkSchemaUtil.convert(table.schema());
     }

+    return flinkSchema;
+  }
+
+  static IcebergStreamWriter createStreamWriter(Table table,
+                                                RowType flinkSchema,
+                                                List equalityFieldIds,
+                                                boolean upsert) {
+    Preconditions.checkArgument(table != null, "Iceberg table shouldn't be null");
+
     Map props = table.properties();
     long targetFileSize = getTargetFileSizeBytes(props);
     FileFormat fileFormat = getFileFormat(props);

     TaskWriterFactory taskWriterFactory = new RowDataTaskWriterFactory(table.schema(), flinkSchema, table.spec(),
         table.locationProvider(), table.io(), table.encryption(), targetFileSize, fileFormat, props,
-        equalityFieldIds);
+        equalityFieldIds, upsert);

     return new IcebergStreamWriter<>(table.name(), taskWriterFactory);
   }
diff --git a/flink/src/main/java/org/apache/iceberg/flink/sink/PartitionedDeltaWriter.java b/flink/src/main/java/org/apache/iceberg/flink/sink/PartitionedDeltaWriter.java
index b2f8ceece9f8..1eee6298e933 100644
--- a/flink/src/main/java/org/apache/iceberg/flink/sink/PartitionedDeltaWriter.java
+++ b/flink/src/main/java/org/apache/iceberg/flink/sink/PartitionedDeltaWriter.java
@@ -49,8 +49,10 @@ class PartitionedDeltaWriter extends BaseDeltaTaskWriter {
                          long targetFileSize,
                          Schema schema,
                          RowType flinkSchema,
-                         List equalityFieldIds) {
-    super(spec, format, appenderFactory, fileFactory, io, targetFileSize, schema, flinkSchema, equalityFieldIds);
+                         List equalityFieldIds,
+                         boolean upsert) {
+    super(spec, format, appenderFactory, fileFactory, io, targetFileSize, schema, flinkSchema, equalityFieldIds,
+        upsert);
     this.partitionKey = new PartitionKey(spec, schema);
   }

diff --git a/flink/src/main/java/org/apache/iceberg/flink/sink/RowDataTaskWriterFactory.java b/flink/src/main/java/org/apache/iceberg/flink/sink/RowDataTaskWriterFactory.java
index b0776f49d190..be7268da6670 100644
--- a/flink/src/main/java/org/apache/iceberg/flink/sink/RowDataTaskWriterFactory.java
+++ b/flink/src/main/java/org/apache/iceberg/flink/sink/RowDataTaskWriterFactory.java
@@ -49,6 +49,7 @@ public class RowDataTaskWriterFactory implements TaskWriterFactory {
   private final long targetFileSizeBytes;
   private final FileFormat format;
   private final List equalityFieldIds;
+  private final boolean upsert;
   private final FileAppenderFactory appenderFactory;

   private transient OutputFileFactory outputFileFactory;
@@ -62,7 +63,8 @@ public RowDataTaskWriterFactory(Schema schema,
                                   long targetFileSizeBytes,
                                   FileFormat format,
                                   Map tableProperties,
-                                  List equalityFieldIds) {
+                                  List equalityFieldIds,
+                                  boolean upsert) {
     this.schema = schema;
     this.flinkSchema = flinkSchema;
     this.spec = spec;
@@ -72,6 +74,7 @@ public RowDataTaskWriterFactory(Schema schema,
     this.targetFileSizeBytes = targetFileSizeBytes;
     this.format = format;
     this.equalityFieldIds = equalityFieldIds;
+    this.upsert = upsert;

     if (equalityFieldIds == null || equalityFieldIds.isEmpty()) {
       this.appenderFactory = new FlinkAppenderFactory(schema, flinkSchema, tableProperties, spec);
@@ -104,10 +107,10 @@ public TaskWriter create() {
       // Initialize a task writer to write both INSERT and equality DELETE.
       if (spec.isUnpartitioned()) {
         return new UnpartitionedDeltaWriter(spec, format, appenderFactory, outputFileFactory, io,
-            targetFileSizeBytes, schema, flinkSchema, equalityFieldIds);
+            targetFileSizeBytes, schema, flinkSchema, equalityFieldIds, upsert);
       } else {
         return new PartitionedDeltaWriter(spec, format, appenderFactory, outputFileFactory, io,
-            targetFileSizeBytes, schema, flinkSchema, equalityFieldIds);
+            targetFileSizeBytes, schema, flinkSchema, equalityFieldIds, upsert);
       }
     }
   }
diff --git a/flink/src/main/java/org/apache/iceberg/flink/sink/UnpartitionedDeltaWriter.java b/flink/src/main/java/org/apache/iceberg/flink/sink/UnpartitionedDeltaWriter.java
index 341e634df713..331ed7c78192 100644
--- a/flink/src/main/java/org/apache/iceberg/flink/sink/UnpartitionedDeltaWriter.java
+++ b/flink/src/main/java/org/apache/iceberg/flink/sink/UnpartitionedDeltaWriter.java
@@ -41,8 +41,10 @@ class UnpartitionedDeltaWriter extends BaseDeltaTaskWriter {
                             long targetFileSize,
                             Schema schema,
                             RowType flinkSchema,
-                            List equalityFieldIds) {
-    super(spec, format, appenderFactory, fileFactory, io, targetFileSize, schema, flinkSchema, equalityFieldIds);
+                            List equalityFieldIds,
+                            boolean upsert) {
+    super(spec, format, appenderFactory, fileFactory, io, targetFileSize, schema, flinkSchema, equalityFieldIds,
+        upsert);
     this.writer = new RowDataDeltaWriter(null);
   }

diff --git a/flink/src/main/java/org/apache/iceberg/flink/source/RowDataRewriter.java b/flink/src/main/java/org/apache/iceberg/flink/source/RowDataRewriter.java
index 8b4986dcd67b..7691366e9b9b 100644
--- a/flink/src/main/java/org/apache/iceberg/flink/source/RowDataRewriter.java
+++ b/flink/src/main/java/org/apache/iceberg/flink/source/RowDataRewriter.java
@@ -81,7 +81,8 @@ public RowDataRewriter(Table table, boolean caseSensitive, FileIO io, Encryption
         Long.MAX_VALUE,
         format,
         table.properties(),
-        null);
+        null,
+        false);
   }

   public List rewriteDataForTasks(DataStream dataStream, int parallelism) {
diff --git a/flink/src/test/java/org/apache/iceberg/flink/sink/TestDeltaTaskWriter.java b/flink/src/test/java/org/apache/iceberg/flink/sink/TestDeltaTaskWriter.java
index 603562bc70a3..1b157c9d6efb 100644
--- a/flink/src/test/java/org/apache/iceberg/flink/sink/TestDeltaTaskWriter.java
+++ b/flink/src/test/java/org/apache/iceberg/flink/sink/TestDeltaTaskWriter.java
@@ -333,6 +333,6 @@ private StructLikeSet actualRowSet(String... columns) throws IOException {
   private TaskWriterFactory createTaskWriterFactory(List equalityFieldIds) {
     return new RowDataTaskWriterFactory(table.schema(), FlinkSchemaUtil.convert(table.schema()),
         table.spec(), table.locationProvider(), table.io(), table.encryption(), 128 * 1024 * 1024,
-        format, table.properties(), equalityFieldIds);
+        format, table.properties(), equalityFieldIds, false);
   }
 }
diff --git a/flink/src/test/java/org/apache/iceberg/flink/sink/TestFlinkIcebergSinkV2.java b/flink/src/test/java/org/apache/iceberg/flink/sink/TestFlinkIcebergSinkV2.java
index 93222ddc4535..6d68d022f8f7 100644
--- a/flink/src/test/java/org/apache/iceberg/flink/sink/TestFlinkIcebergSinkV2.java
+++ b/flink/src/test/java/org/apache/iceberg/flink/sink/TestFlinkIcebergSinkV2.java
@@ -31,6 +31,7 @@
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.types.Row;
 import org.apache.flink.types.RowKind;
+import org.apache.iceberg.AssertHelpers;
 import org.apache.iceberg.FileFormat;
 import org.apache.iceberg.PartitionSpec;
 import org.apache.iceberg.Snapshot;
@@ -132,6 +133,7 @@ private List findValidSnapshots(Table table) {
   private void testChangeLogs(List equalityFieldColumns,
                               KeySelector keySelector,
+                              boolean insertAsUpsert,
                               List> elementsPerCheckpoint,
                               List> expectedRecordsPerCheckpoint) throws Exception {
     DataStream dataStream = env.addSource(new BoundedTestSource<>(elementsPerCheckpoint), ROW_TYPE_INFO);
@@ -145,6 +147,7 @@ private void testChangeLogs(List equalityFieldColumns,
         .tableSchema(SimpleDataUtil.FLINK_SCHEMA)
         .writeParallelism(parallelism)
         .equalityFieldColumns(equalityFieldColumns)
+        .upsert(insertAsUpsert)
         .build();

     // Execute the program.
@@ -207,7 +210,8 @@ public void testChangeLogOnIdKey() throws Exception {
         ImmutableList.of(record(1, "ddd"), record(2, "ddd"))
     );

-    testChangeLogs(ImmutableList.of("id"), row -> row.getField(ROW_ID_POS), elementsPerCheckpoint, expectedRecords);
+    testChangeLogs(ImmutableList.of("id"), row -> row.getField(ROW_ID_POS), false,
+        elementsPerCheckpoint, expectedRecords);
   }

   @Test
@@ -238,7 +242,8 @@ public void testChangeLogOnDataKey() throws Exception {
         ImmutableList.of(record(1, "aaa"), record(1, "ccc"), record(2, "aaa"), record(2, "ccc"))
     );

-    testChangeLogs(ImmutableList.of("data"), row -> row.getField(ROW_DATA_POS), elementsPerCheckpoint, expectedRecords);
+    testChangeLogs(ImmutableList.of("data"), row -> row.getField(ROW_DATA_POS), false,
+        elementsPerCheckpoint, expectedRecords);
   }

   @Test
@@ -269,7 +274,7 @@ public void testChangeLogOnIdDataKey() throws Exception {
     );

     testChangeLogs(ImmutableList.of("data", "id"), row -> Row.of(row.getField(ROW_ID_POS), row.getField(ROW_DATA_POS)),
-        elementsPerCheckpoint, expectedRecords);
+        false, elementsPerCheckpoint, expectedRecords);
   }

   @Test
@@ -307,9 +312,103 @@ public void testChangeLogOnSameKey() throws Exception {
     );

     testChangeLogs(ImmutableList.of("id", "data"), row -> Row.of(row.getField(ROW_ID_POS), row.getField(ROW_DATA_POS)),
+        false,
         elementsPerCheckpoint, expectedRecords);
   }

+  @Test
+  public void testUpsertOnIdKey() throws Exception {
+    List> elementsPerCheckpoint = ImmutableList.of(
+        ImmutableList.of(
+            row("+I", 1, "aaa"),
+            row("+U", 1, "bbb")
+        ),
+        ImmutableList.of(
+            row("+I", 1, "ccc")
+        ),
+        ImmutableList.of(
+            row("+U", 1, "ddd"),
+            row("+I", 1, "eee")
+        )
+    );
+
+    List> expectedRecords = ImmutableList.of(
+        ImmutableList.of(record(1, "bbb")),
+        ImmutableList.of(record(1, "ccc")),
+        ImmutableList.of(record(1, "eee"))
+    );
+
+    if (!partitioned) {
+      testChangeLogs(ImmutableList.of("id"), row -> row.getField(ROW_ID_POS), true,
+          elementsPerCheckpoint, expectedRecords);
+    } else {
+      AssertHelpers.assertThrows("Should be error because equality field columns don't include all partition keys",
+          IllegalStateException.class, "not included in equality fields",
+          () -> {
+            testChangeLogs(ImmutableList.of("id"), row -> row.getField(ROW_ID_POS), true, elementsPerCheckpoint,
+                expectedRecords);
+            return null;
+          });
+    }
+  }
+
+  @Test
+  public void testUpsertOnDataKey() throws Exception {
+    List> elementsPerCheckpoint = ImmutableList.of(
+        ImmutableList.of(
+            row("+I", 1, "aaa"),
+            row("+I", 2, "aaa"),
+            row("+I", 3, "bbb")
+        ),
+        ImmutableList.of(
+            row("+U", 4, "aaa"),
+            row("-U", 3, "bbb"),
+            row("+U", 5, "bbb")
+        ),
+        ImmutableList.of(
+            row("+I", 6, "aaa"),
+            row("+U", 7, "bbb")
+        )
+    );
+
+    List> expectedRecords = ImmutableList.of(
+        ImmutableList.of(record(2, "aaa"), record(3, "bbb")),
+        ImmutableList.of(record(4, "aaa"), record(5, "bbb")),
+        ImmutableList.of(record(6, "aaa"), record(7, "bbb"))
+    );
+
+    testChangeLogs(ImmutableList.of("data"), row -> row.getField(ROW_DATA_POS), true, elementsPerCheckpoint, expectedRecords);
+  }
+
+  @Test
+  public void testUpsertOnIdDataKey() throws Exception {
+    List> elementsPerCheckpoint = ImmutableList.of(
+        ImmutableList.of(
+            row("+I", 1, "aaa"),
+            row("+U", 1, "aaa"),
+            row("+I", 2, "bbb")
+        ),
+        ImmutableList.of(
+            row("+I", 1, "aaa"),
+            row("-D", 2, "bbb"),
+            row("+I", 2, "ccc")
+        ),
+        ImmutableList.of(
+            row("-U", 1, "aaa"),
+            row("+U", 1, "bbb")
+        )
+    );
+
+    List> expectedRecords = ImmutableList.of(
+        ImmutableList.of(record(1, "aaa"), record(2, "bbb")),
+        ImmutableList.of(record(1, "aaa"), record(2, "ccc")),
+        ImmutableList.of(record(1, "bbb"), record(2, "ccc"))
+    );
+
+    testChangeLogs(ImmutableList.of("id", "data"), row -> Row.of(row.getField(ROW_ID_POS), row.getField(ROW_DATA_POS)),
+        true, elementsPerCheckpoint, expectedRecords);
+  }
+
   private StructLikeSet expectedRowSet(Record... records) {
     return SimpleDataUtil.expectedRowSet(table, records);
   }
diff --git a/flink/src/test/java/org/apache/iceberg/flink/sink/TestIcebergStreamWriter.java b/flink/src/test/java/org/apache/iceberg/flink/sink/TestIcebergStreamWriter.java
index c6c20e0624fb..a6b46ac61ed6 100644
--- a/flink/src/test/java/org/apache/iceberg/flink/sink/TestIcebergStreamWriter.java
+++ b/flink/src/test/java/org/apache/iceberg/flink/sink/TestIcebergStreamWriter.java
@@ -32,6 +32,7 @@
 import org.apache.flink.table.api.TableSchema;
 import org.apache.flink.table.data.GenericRowData;
 import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.types.logical.RowType;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.LocatedFileStatus;
@@ -337,7 +338,8 @@ private OneInputStreamOperatorTestHarness createIcebergStr
   private OneInputStreamOperatorTestHarness createIcebergStreamWriter(
       Table icebergTable, TableSchema flinkSchema) throws Exception {
-    IcebergStreamWriter streamWriter = FlinkSink.createStreamWriter(icebergTable, flinkSchema, null);
+    RowType rowType = (RowType) flinkSchema.toRowDataType().getLogicalType();
+    IcebergStreamWriter streamWriter = FlinkSink.createStreamWriter(icebergTable, rowType, null, false);
     OneInputStreamOperatorTestHarness harness = new OneInputStreamOperatorTestHarness<>(
         streamWriter, 1, 1, 0);

diff --git a/flink/src/test/java/org/apache/iceberg/flink/sink/TestTaskWriters.java b/flink/src/test/java/org/apache/iceberg/flink/sink/TestTaskWriters.java
index 8439f7d80c41..84160773e26e 100644
--- a/flink/src/test/java/org/apache/iceberg/flink/sink/TestTaskWriters.java
+++ b/flink/src/test/java/org/apache/iceberg/flink/sink/TestTaskWriters.java
@@ -239,7 +239,7 @@ private TaskWriter createTaskWriter(long targetFileSize) {
     TaskWriterFactory taskWriterFactory = new RowDataTaskWriterFactory(table.schema(),
         (RowType) SimpleDataUtil.FLINK_SCHEMA.toRowDataType().getLogicalType(), table.spec(),
         table.locationProvider(), table.io(), table.encryption(),
-        targetFileSize, format, table.properties(), null);
+        targetFileSize, format, table.properties(), null, false);
     taskWriterFactory.initialize(1, 1);
     return taskWriterFactory.create();
   }
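
For context, here is a minimal sketch of how the new upsert(boolean) builder option could be wired into a streaming job. Only the builder methods shown in this diff (tableSchema, writeParallelism, equalityFieldColumns, upsert, build) come from the change itself; the surrounding pieces (FlinkSink.forRow, table/tableLoader wiring, the "id" key column) are assumed from the existing Flink sink API and are illustrative, not part of this patch.

// Illustrative only: appends an upsert-mode Iceberg sink to a changelog stream
// keyed by "id". The equality field columns must cover every partition source
// field, otherwise the Preconditions check added in FlinkSink#build() throws
// IllegalStateException.
import java.util.Collections;

import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.iceberg.Table;
import org.apache.iceberg.flink.TableLoader;
import org.apache.iceberg.flink.sink.FlinkSink;

public class UpsertSinkExample {

  public static void appendUpsertSink(DataStream<Row> changelog, TableSchema schema,
                                      Table table, TableLoader tableLoader) {
    FlinkSink.forRow(changelog, schema)
        .table(table)
        .tableLoader(tableLoader)
        .tableSchema(schema)
        .equalityFieldColumns(Collections.singletonList("id")) // key columns used for equality deletes
        .upsert(true)                                          // rewrite INSERT/UPDATE_AFTER as DELETE + INSERT
        .writeParallelism(1)
        .build();
  }
}

In upsert mode each +I/+U row first emits an equality delete on the key columns and then an insert (see BaseDeltaTaskWriter#write above), which is why the key must uniquely identify a row within its partition.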