From 97f48a2d94e97351032f4233fd1221a72a081072 Mon Sep 17 00:00:00 2001 From: Adrian Carpente Recouso Date: Mon, 12 Jan 2026 12:23:11 +0100 Subject: [PATCH 1/3] Add SQL Support for MERGE INTO In Presto #20578 (iceberg) Support SQL MERGE in the Iceberg connector Cherry-pick of trinodb/trino@6cb188b Co-authored-by: David Phillips --- .../iceberg/IcebergAbstractMetadata.java | 83 +++++++ .../presto/iceberg/IcebergColumnHandle.java | 7 + .../presto/iceberg/IcebergHandleResolver.java | 6 + .../presto/iceberg/IcebergMergeSink.java | 219 ++++++++++++++++++ .../iceberg/IcebergMergeTableHandle.java | 68 ++++++ .../presto/iceberg/IcebergMetadataColumn.java | 10 +- .../iceberg/IcebergPageSinkProvider.java | 30 +++ .../iceberg/IcebergPageSourceProvider.java | 57 ++++- .../iceberg/IcebergUpdateablePageSource.java | 38 +-- .../iceberg/delete/IcebergDeletePageSink.java | 9 +- 10 files changed, 494 insertions(+), 33 deletions(-) create mode 100644 presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMergeSink.java create mode 100644 presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMergeTableHandle.java diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java index d1e8f670a9594..913c501b5795d 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java @@ -41,6 +41,7 @@ import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorInsertTableHandle; +import com.facebook.presto.spi.ConnectorMergeTableHandle; import com.facebook.presto.spi.ConnectorNewTableLayout; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorSession; @@ -68,6 +69,7 @@ import 
com.facebook.presto.spi.connector.ConnectorTableVersion; import com.facebook.presto.spi.connector.ConnectorTableVersion.VersionOperator; import com.facebook.presto.spi.connector.ConnectorTableVersion.VersionType; +import com.facebook.presto.spi.connector.RowChangeParadigm; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; import com.facebook.presto.spi.procedure.BaseProcedure; @@ -98,6 +100,7 @@ import org.apache.iceberg.FileFormat; import org.apache.iceberg.FileMetadata; import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.MetadataColumns; import org.apache.iceberg.MetricsConfig; import org.apache.iceberg.MetricsModes.None; import org.apache.iceberg.PartitionSpec; @@ -119,6 +122,7 @@ import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.types.Types; import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.types.Types.StringType; import org.apache.iceberg.util.CharSequenceSet; import org.apache.iceberg.view.View; @@ -159,6 +163,7 @@ import static com.facebook.presto.iceberg.IcebergColumnHandle.PATH_COLUMN_HANDLE; import static com.facebook.presto.iceberg.IcebergColumnHandle.PATH_COLUMN_METADATA; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_COMMIT_ERROR; +import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_INVALID_FORMAT_VERSION; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_INVALID_MATERIALIZED_VIEW; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_INVALID_SNAPSHOT_ID; import static com.facebook.presto.iceberg.IcebergMaterializedViewProperties.getRefreshType; @@ -170,6 +175,8 @@ import static com.facebook.presto.iceberg.IcebergMetadataColumn.DELETE_FILE_PATH; import static com.facebook.presto.iceberg.IcebergMetadataColumn.FILE_PATH; import static com.facebook.presto.iceberg.IcebergMetadataColumn.IS_DELETED; +import static 
com.facebook.presto.iceberg.IcebergMetadataColumn.MERGE_PARTITION_DATA; +import static com.facebook.presto.iceberg.IcebergMetadataColumn.MERGE_TARGET_ROW_ID_DATA; import static com.facebook.presto.iceberg.IcebergMetadataColumn.UPDATE_ROW_DATA; import static com.facebook.presto.iceberg.IcebergPartitionType.ALL; import static com.facebook.presto.iceberg.IcebergSessionProperties.getCompressionCodec; @@ -225,6 +232,7 @@ import static com.facebook.presto.spi.StandardErrorCode.ALREADY_EXISTS; import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static com.facebook.presto.spi.connector.RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW; import static com.facebook.presto.spi.statistics.TableStatisticType.ROW_COUNT; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Strings.isNullOrEmpty; @@ -232,11 +240,13 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Maps.transformValues; import static java.lang.Long.parseLong; import static java.lang.String.format; import static java.util.Collections.singletonList; import static java.util.Objects.requireNonNull; import static org.apache.iceberg.MetadataColumns.ROW_POSITION; +import static org.apache.iceberg.MetadataColumns.SPEC_ID; import static org.apache.iceberg.RowLevelOperationMode.MERGE_ON_READ; import static org.apache.iceberg.SnapshotSummary.DELETED_RECORDS_PROP; import static org.apache.iceberg.SnapshotSummary.REMOVED_EQ_DELETES_PROP; @@ -795,6 +805,78 @@ public Optional getDeleteRowIdColumn(ConnectorSession session, Con return Optional.of(IcebergColumnHandle.create(ROW_POSITION, typeManager, REGULAR)); } + /** + * Return the row change paradigm supported by the connector on the table. 
+ */ + @Override + public RowChangeParadigm getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return DELETE_ROW_AND_INSERT_ROW; + } + + @Override + public ColumnHandle getMergeTargetTableRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle) + { + Types.StructType type = Types.StructType.of(ImmutableList.builder() + .add(MetadataColumns.FILE_PATH) + .add(ROW_POSITION) + .add(SPEC_ID) + .add(NestedField.required(MERGE_PARTITION_DATA.getId(), MERGE_PARTITION_DATA.getColumnName(), StringType.get())) + .build()); + + NestedField field = NestedField.required(MERGE_TARGET_ROW_ID_DATA.getId(), MERGE_TARGET_ROW_ID_DATA.getColumnName(), type); + return IcebergColumnHandle.create(field, typeManager, SYNTHESIZED); + } + + @Override + public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle) + { + IcebergTableHandle icebergTableHandle = (IcebergTableHandle) tableHandle; + verify(icebergTableHandle.getIcebergTableName().getTableType() == DATA, "only the data table can have data merged"); + Table icebergTable = getIcebergTable(session, icebergTableHandle.getSchemaTableName()); + int formatVersion = ((BaseTable) icebergTable).operations().current().formatVersion(); + + if (formatVersion < MIN_FORMAT_VERSION_FOR_DELETE || + !Optional.ofNullable(icebergTable.properties().get(TableProperties.UPDATE_MODE)) + .map(mode -> mode.equals(MERGE_ON_READ.modeName())) + .orElse(false)) { + throw new PrestoException(ICEBERG_INVALID_FORMAT_VERSION, + "Iceberg table updates require at least format version 2 and update mode must be merge-on-read"); + } + validateTableMode(session, icebergTable); + transaction = icebergTable.newTransaction(); + + IcebergInsertTableHandle insertHandle = new IcebergInsertTableHandle( + icebergTableHandle.getSchemaName(), + icebergTableHandle.getIcebergTableName(), + toPrestoSchema(icebergTable.schema(), typeManager), + toPrestoPartitionSpec(icebergTable.spec(), 
typeManager), + getColumns(icebergTable.schema(), icebergTable.spec(), typeManager), + icebergTable.location(), + getFileFormat(icebergTable), + getCompressionCodec(session), + icebergTable.properties(), + getSupportedSortFields(icebergTable.schema(), icebergTable.sortOrder()), + Optional.empty()); + + Map partitionSpecs = transformValues(icebergTable.specs(), partitionSpec -> toPrestoPartitionSpec(partitionSpec, typeManager)); + + return new IcebergMergeTableHandle(icebergTableHandle, insertHandle, partitionSpecs); + } + + @Override + public void finishMerge( + ConnectorSession session, + ConnectorMergeTableHandle tableHandle, + Collection fragments, + Collection computedStatistics) + { + IcebergWritableTableHandle insertTableHandle = + ((IcebergMergeTableHandle) tableHandle).getInsertTableHandle(); + + finishWrite(session, insertTableHandle, fragments, UPDATE_AFTER); + } + @Override public boolean isLegacyGetLayoutSupported(ConnectorSession session, ConnectorTableHandle tableHandle) { @@ -814,6 +896,7 @@ protected List getColumnMetadata(ConnectorSession session, Table .setExtraInfo(partitionFields.containsKey(column.name()) ? 
columnExtraInfo(partitionFields.get(column.name())) : null) + .setNullable(column.isOptional()) .build()) .collect(toImmutableList()); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergColumnHandle.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergColumnHandle.java index 62fafea0d48d8..3afa99d710d72 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergColumnHandle.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergColumnHandle.java @@ -37,6 +37,7 @@ import static com.facebook.presto.iceberg.IcebergMetadataColumn.DELETE_FILE_PATH; import static com.facebook.presto.iceberg.IcebergMetadataColumn.FILE_PATH; import static com.facebook.presto.iceberg.IcebergMetadataColumn.IS_DELETED; +import static com.facebook.presto.iceberg.IcebergMetadataColumn.MERGE_TARGET_ROW_ID_DATA; import static com.facebook.presto.iceberg.IcebergMetadataColumn.UPDATE_ROW_DATA; import static com.facebook.presto.iceberg.TypeConverter.toPrestoType; import static com.google.common.base.Preconditions.checkArgument; @@ -109,6 +110,12 @@ public boolean isUpdateRowIdColumn() return columnIdentity.getId() == UPDATE_ROW_DATA.getId(); } + @JsonIgnore + public boolean isMergeTargetTableRowIdColumn() + { + return columnIdentity.getId() == MERGE_TARGET_ROW_ID_DATA.getId(); + } + @Override public ColumnHandle withRequiredSubfields(List subfields) { diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHandleResolver.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHandleResolver.java index 92d3d0e9fdeec..fbb24b55b577d 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHandleResolver.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHandleResolver.java @@ -19,6 +19,7 @@ import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorHandleResolver; import 
com.facebook.presto.spi.ConnectorInsertTableHandle; +import com.facebook.presto.spi.ConnectorMergeTableHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.ConnectorTableHandle; @@ -64,6 +65,11 @@ public Class getInsertTableHandleClass() return IcebergInsertTableHandle.class; } + public Class getMergeTableHandleClass() + { + return IcebergMergeTableHandle.class; + } + @Override public Class getDeleteTableHandleClass() { diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMergeSink.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMergeSink.java new file mode 100644 index 0000000000000..579aca18ce44b --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMergeSink.java @@ -0,0 +1,219 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.iceberg; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.presto.common.Page; +import com.facebook.presto.common.PageBuilder; +import com.facebook.presto.common.block.ColumnarRow; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.hive.HdfsContext; +import com.facebook.presto.hive.HdfsEnvironment; +import com.facebook.presto.iceberg.delete.IcebergDeletePageSink; +import com.facebook.presto.spi.ConnectorMergeSink; +import com.facebook.presto.spi.ConnectorPageSink; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.connector.MergePage; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.io.LocationProvider; +import org.roaringbitmap.longlong.ImmutableLongBitmapDataProvider; +import org.roaringbitmap.longlong.LongBitmapDataProvider; +import org.roaringbitmap.longlong.Roaring64Bitmap; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import static com.facebook.presto.common.block.ColumnarRow.toColumnarRow; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.plugin.base.util.Closables.closeAllSuppress; +import static com.facebook.presto.spi.connector.MergePage.createDeleteAndInsertPages; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.CompletableFuture.completedFuture; + +public class IcebergMergeSink + implements ConnectorMergeSink +{ + private final LocationProvider locationProvider; + private final IcebergFileWriterFactory fileWriterFactory; + private final HdfsEnvironment hdfsEnvironment; + private final 
JsonCodec jsonCodec; + private final ConnectorSession session; + private final FileFormat fileFormat; + private final Map partitionsSpecs; + private final ConnectorPageSink insertPageSink; + private final int columnCount; + private final Map fileDeletions = new HashMap<>(); + + public IcebergMergeSink( + LocationProvider locationProvider, + IcebergFileWriterFactory fileWriterFactory, + HdfsEnvironment hdfsEnvironment, + JsonCodec jsonCodec, + ConnectorSession session, + FileFormat fileFormat, + Map partitionsSpecs, + ConnectorPageSink insertPageSink, + int columnCount) + { + this.locationProvider = requireNonNull(locationProvider, "locationProvider is null"); + this.fileWriterFactory = requireNonNull(fileWriterFactory, "fileWriterFactory is null"); + this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); + this.jsonCodec = requireNonNull(jsonCodec, "jsonCodec is null"); + this.session = requireNonNull(session, "session is null"); + this.fileFormat = requireNonNull(fileFormat, "fileFormat is null"); + this.partitionsSpecs = requireNonNull(partitionsSpecs, "partitionsSpecs is null"); + this.insertPageSink = requireNonNull(insertPageSink, "insertPageSink is null"); + this.columnCount = columnCount; + } + + /** + * @param page It has N + 2 channels/blocks, where N is the number of columns in the source table.
+ * 1: Source table column 1.
+ * 2: Source table column 2.
+ * N: Source table column N.
+ * N + 1: Operation: INSERT(1), DELETE(2), UPDATE(3). More info: {@link ConnectorMergeSink}
+ * N + 2: Target Table Row ID (_file:varchar, _pos:bigint, partition_spec_id:integer, partition_data:varchar). + */ + @Override + public void storeMergedRows(Page page) + { + MergePage mergePage = createDeleteAndInsertPages(page, columnCount); + + mergePage.getInsertionsPage().ifPresent(insertPageSink::appendPage); + + mergePage.getDeletionsPage().ifPresent(deletions -> { + ColumnarRow rowIdRow = toColumnarRow(deletions.getBlock(deletions.getChannelCount() - 1)); + + for (int position = 0; position < rowIdRow.getPositionCount(); position++) { + Slice filePath = VarcharType.VARCHAR.getSlice(rowIdRow.getField(0), position); + long rowPosition = BIGINT.getLong(rowIdRow.getField(1), position); + + int index = position; + FileDeletion deletion = fileDeletions.computeIfAbsent(filePath, ignored -> { + int partitionSpecId = toIntExact(INTEGER.getLong(rowIdRow.getField(2), index)); + String partitionData = VarcharType.VARCHAR.getSlice(rowIdRow.getField(3), index).toStringUtf8(); + return new FileDeletion(partitionSpecId, partitionData); + }); + + deletion.rowsToDelete().addLong(rowPosition); + } + }); + } + + @Override + public CompletableFuture> finish() + { + List fragments = new ArrayList<>(insertPageSink.finish().join()); + + fileDeletions.forEach((dataFilePath, deletion) -> { + ConnectorPageSink sink = createPositionDeletePageSink( + dataFilePath.toStringUtf8(), + partitionsSpecs.get(deletion.partitionSpecId()), + deletion.partitionDataJson()); + + fragments.addAll(writePositionDeletes(sink, deletion.rowsToDelete())); + }); + + return completedFuture(fragments); + } + + @Override + public void abort() + { + insertPageSink.abort(); + } + + private ConnectorPageSink createPositionDeletePageSink(String dataFilePath, PartitionSpec partitionSpec, String partitionDataJson) + { + return new IcebergDeletePageSink( + partitionSpec, + Optional.of(partitionDataJson), + locationProvider, + fileWriterFactory, + hdfsEnvironment, + new HdfsContext(session), + jsonCodec, + session, + 
dataFilePath, + fileFormat); + } + + private static Collection writePositionDeletes(ConnectorPageSink sink, ImmutableLongBitmapDataProvider rowsToDelete) + { + try { + return doWritePositionDeletes(sink, rowsToDelete); + } + catch (Throwable t) { + closeAllSuppress(t, sink::abort); + throw t; + } + } + + private static Collection doWritePositionDeletes(ConnectorPageSink sink, ImmutableLongBitmapDataProvider rowsToDelete) + { + PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(BIGINT)); + + rowsToDelete.forEach(rowPosition -> { + BIGINT.writeLong(pageBuilder.getBlockBuilder(0), rowPosition); + pageBuilder.declarePosition(); + if (pageBuilder.isFull()) { + sink.appendPage(pageBuilder.build()); + pageBuilder.reset(); + } + }); + + if (!pageBuilder.isEmpty()) { + sink.appendPage(pageBuilder.build()); + } + + return sink.finish().join(); + } + + private static class FileDeletion + { + private final int partitionSpecId; + private final String partitionDataJson; + private final LongBitmapDataProvider rowsToDelete = new Roaring64Bitmap(); + + public FileDeletion(int partitionSpecId, String partitionDataJson) + { + this.partitionSpecId = partitionSpecId; + this.partitionDataJson = requireNonNull(partitionDataJson, "partitionDataJson is null"); + } + + public int partitionSpecId() + { + return partitionSpecId; + } + + public String partitionDataJson() + { + return partitionDataJson; + } + + public LongBitmapDataProvider rowsToDelete() + { + return rowsToDelete; + } + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMergeTableHandle.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMergeTableHandle.java new file mode 100644 index 0000000000000..7d706cb1e2d40 --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMergeTableHandle.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg; + +import com.facebook.drift.annotations.ThriftConstructor; +import com.facebook.drift.annotations.ThriftField; +import com.facebook.drift.annotations.ThriftStruct; +import com.facebook.presto.spi.ConnectorMergeTableHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public class IcebergMergeTableHandle + implements ConnectorMergeTableHandle +{ + private final IcebergTableHandle tableHandle; + private final IcebergInsertTableHandle insertTableHandle; + private final Map partitionSpecs; + + @JsonCreator + @ThriftConstructor + public IcebergMergeTableHandle( + @JsonProperty("tableHandle") IcebergTableHandle tableHandle, + @JsonProperty("insertTableHandle") IcebergInsertTableHandle insertTableHandle, + @JsonProperty("partitionSpecs") Map partitionSpecs) + { + this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); + this.insertTableHandle = requireNonNull(insertTableHandle, "insertTableHandle is null"); + this.partitionSpecs = requireNonNull(partitionSpecs, "partitionSpecs is null"); + } + + @Override + @JsonProperty + @ThriftField(1) + public IcebergTableHandle getTableHandle() + { + return tableHandle; + } + + @JsonProperty + @ThriftField(2) + public IcebergInsertTableHandle getInsertTableHandle() + { + return insertTableHandle; + } + + @JsonProperty + @ThriftField(3) + public Map getPartitionSpecs() + { + return 
partitionSpecs; + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMetadataColumn.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMetadataColumn.java index 5862fba4975f8..f89488911f8cd 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMetadataColumn.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMetadataColumn.java @@ -39,12 +39,14 @@ public enum IcebergMetadataColumn * Iceberg reserved row ids begin at INTEGER.MAX_VALUE and count down. Starting with MIN_VALUE here to avoid conflicts. * Inner type for row is not known until runtime. */ - UPDATE_ROW_DATA(Integer.MIN_VALUE, "$row_id", RowType.anonymous(ImmutableList.of(UNKNOWN)), STRUCT) + UPDATE_ROW_DATA(Integer.MIN_VALUE, "$row_id", RowType.anonymous(ImmutableList.of(UNKNOWN)), STRUCT), + MERGE_TARGET_ROW_ID_DATA(Integer.MIN_VALUE + 1, "$target_table_row_id", RowType.anonymous(ImmutableList.of(UNKNOWN)), STRUCT), + MERGE_PARTITION_DATA(Integer.MIN_VALUE + 2, "partition_data", VARCHAR, PRIMITIVE) /**/; - private static final Set COLUMN_IDS = Stream.of(values()) - .map(IcebergMetadataColumn::getId) - .collect(toImmutableSet()); + private static final Set COLUMN_IDS = Stream.concat( + Stream.of(values()).map(IcebergMetadataColumn::getId), + Stream.of(MetadataColumns.SPEC_ID.fieldId())).collect(toImmutableSet()); private final int id; private final String columnName; private final Type type; diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSinkProvider.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSinkProvider.java index e14d0178b153d..45e8a7164df8f 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSinkProvider.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSinkProvider.java @@ -18,6 +18,8 @@ import com.facebook.presto.hive.HdfsEnvironment; import 
com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorInsertTableHandle; +import com.facebook.presto.spi.ConnectorMergeSink; +import com.facebook.presto.spi.ConnectorMergeTableHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorPageSink; import com.facebook.presto.spi.ConnectorSession; @@ -32,12 +34,14 @@ import org.apache.iceberg.Table; import org.apache.iceberg.io.LocationProvider; +import java.util.Map; import java.util.Optional; import static com.facebook.presto.iceberg.IcebergUtil.getLocationProvider; import static com.facebook.presto.iceberg.IcebergUtil.getShallowWrappedIcebergTable; import static com.facebook.presto.iceberg.PartitionSpecConverter.toIcebergPartitionSpec; import static com.facebook.presto.iceberg.SchemaConverter.toIcebergSchema; +import static com.google.common.collect.Maps.transformValues; import static java.util.Objects.requireNonNull; public class IcebergPageSinkProvider @@ -109,4 +113,30 @@ private ConnectorPageSink createPageSink(ConnectorSession session, IcebergWritab tableHandle.getSortOrder(), sortParameters); } + + @Override + public ConnectorMergeSink createMergeSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorMergeTableHandle mergeHandle) + { + IcebergMergeTableHandle merge = (IcebergMergeTableHandle) mergeHandle; + IcebergWritableTableHandle tableHandle = merge.getInsertTableHandle(); + SchemaTableName schemaTableName = new SchemaTableName(tableHandle.getSchemaName(), tableHandle.getTableName().getTableName()); + LocationProvider locationProvider = getLocationProvider(schemaTableName, tableHandle.getOutputPath(), tableHandle.getStorageProperties()); + + Schema schema = toIcebergSchema(tableHandle.getSchema()); + Map partitionSpecs = transformValues(merge.getPartitionSpecs(), + prestoIcebergPartitionSpec -> toIcebergPartitionSpec(prestoIcebergPartitionSpec).toUnbound().bind(schema)); + + 
ConnectorPageSink pageSink = createPageSink(session, tableHandle); + + return new IcebergMergeSink( + locationProvider, + fileWriterFactory, + hdfsEnvironment, + jsonCodec, + session, + tableHandle.getFileFormat(), + partitionSpecs, + pageSink, + tableHandle.getInputColumns().size()); + } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSourceProvider.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSourceProvider.java index c99e5b8034c3a..2b6bc0b6e9f2f 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSourceProvider.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSourceProvider.java @@ -97,6 +97,7 @@ import org.apache.iceberg.Table; import org.apache.iceberg.io.LocationProvider; import org.apache.iceberg.types.Conversions; +import org.apache.iceberg.types.Types; import org.apache.iceberg.types.Types.NestedField; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.crypto.InternalFileDecryptor; @@ -150,6 +151,7 @@ import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_CANNOT_OPEN_SPLIT; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_MISSING_COLUMN; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_MISSING_DATA; +import static com.facebook.presto.iceberg.IcebergMetadataColumn.MERGE_PARTITION_DATA; import static com.facebook.presto.iceberg.IcebergOrcColumn.ROOT_COLUMN_ID; import static com.facebook.presto.iceberg.IcebergUtil.getColumns; import static com.facebook.presto.iceberg.IcebergUtil.getLocationProvider; @@ -181,6 +183,7 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Maps.uniqueIndex; +import static io.airlift.slice.Slices.EMPTY_SLICE; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.String.format; import static 
java.time.ZoneOffset.UTC; @@ -188,7 +191,9 @@ import static java.util.Objects.requireNonNull; import static org.apache.iceberg.MetadataColumns.DELETE_FILE_PATH; import static org.apache.iceberg.MetadataColumns.DELETE_FILE_POS; +import static org.apache.iceberg.MetadataColumns.FILE_PATH; import static org.apache.iceberg.MetadataColumns.ROW_POSITION; +import static org.apache.iceberg.MetadataColumns.SPEC_ID; import static org.apache.parquet.io.ColumnIOConverter.constructField; import static org.apache.parquet.io.ColumnIOConverter.findNestedColumnIO; @@ -356,7 +361,8 @@ private static ConnectorPageSourceWithRowPositions createParquetPageSource( Type prestoType = column.getType(); prestoTypes.add(prestoType); - if (column.getColumnType() == IcebergColumnHandle.ColumnType.SYNTHESIZED && !column.isUpdateRowIdColumn()) { + if (column.getColumnType() == IcebergColumnHandle.ColumnType.SYNTHESIZED && + !column.isUpdateRowIdColumn() && !column.isMergeTargetTableRowIdColumn()) { Subfield pushedDownSubfield = getPushedDownSubfield(column); List nestedColumnPath = nestedColumnPath(pushedDownSubfield); Optional columnIO = findNestedColumnIO(lookupColumnByName(messageColumnIO, pushedDownSubfield.getRootName()), nestedColumnPath); @@ -751,10 +757,10 @@ public ConnectorPageSource createPageSource( Map partitionKeys = split.getPartitionKeys(); - // the update row isn't a valid column that can be read from storage. + // The update row id and merge target table row id aren't valid columns that can be read from storage. // Filter it out from columns passed to the storage page source. 
Set columnsToReadFromStorage = icebergColumns.stream() - .filter(not(IcebergColumnHandle::isUpdateRowIdColumn)) + .filter(not(column -> column.isUpdateRowIdColumn() || column.isMergeTargetTableRowIdColumn())) .collect(Collectors.toSet()); // add any additional columns which may need to be read from storage @@ -765,22 +771,36 @@ public ConnectorPageSource createPageSource( .filter(not(icebergColumns::contains)) .forEach(columnsToReadFromStorage::add); - // finally, add the fields that the update column requires. - Optional updateRow = icebergColumns.stream() - .filter(IcebergColumnHandle::isUpdateRowIdColumn) + // finally, add the fields that the UPDATE and MERGE column requires. + Optional rowIdColumnHandle = icebergColumns.stream() + .filter(column -> column.isUpdateRowIdColumn() || column.isMergeTargetTableRowIdColumn()) .findFirst(); - updateRow.ifPresent(updateRowIdColumn -> { + rowIdColumnHandle.ifPresent(rowIdColumn -> { Set alreadyRequiredColumnIds = columnsToReadFromStorage.stream() .map(IcebergColumnHandle::getId) .collect(toImmutableSet()); - updateRowIdColumn.getColumnIdentity().getChildren() + rowIdColumn.getColumnIdentity().getChildren() .stream() .filter(colId -> !alreadyRequiredColumnIds.contains(colId.getId())) .forEach(colId -> { - if (colId.getId() == ROW_POSITION.fieldId()) { + if (colId.getId() == FILE_PATH.fieldId()) { + IcebergColumnHandle handle = IcebergColumnHandle.create(FILE_PATH, typeManager, REGULAR); + columnsToReadFromStorage.add(handle); + } + else if (colId.getId() == ROW_POSITION.fieldId()) { IcebergColumnHandle handle = IcebergColumnHandle.create(ROW_POSITION, typeManager, REGULAR); columnsToReadFromStorage.add(handle); } + else if (colId.getId() == SPEC_ID.fieldId()) { + IcebergColumnHandle handle = IcebergColumnHandle.create(SPEC_ID, typeManager, REGULAR); + columnsToReadFromStorage.add(handle); + } + else if (colId.getId() == MERGE_PARTITION_DATA.getId()) { + NestedField mergePartitionData = 
NestedField.required(MERGE_PARTITION_DATA.getId(), + MERGE_PARTITION_DATA.getColumnName(), Types.StringType.get()); + IcebergColumnHandle handle = IcebergColumnHandle.create(mergePartitionData, typeManager, REGULAR); + columnsToReadFromStorage.add(handle); + } else { NestedField column = tableSchema.findField(colId.getId()); if (column == null) { @@ -814,6 +834,20 @@ public ConnectorPageSource createPageSource( else if (icebergColumn.isDataSequenceNumberColumn()) { metadataValues.put(icebergColumn.getColumnIdentity().getId(), split.getDataSequenceNumber()); } + else if (icebergColumn.isMergeTargetTableRowIdColumn()) { + for (ColumnIdentity subColumn : icebergColumn.getColumnIdentity().getChildren()) { + if (subColumn.getId() == FILE_PATH.fieldId()) { + metadataValues.put(subColumn.getId(), utf8Slice(split.getPath())); + } + else if (subColumn.getId() == SPEC_ID.fieldId()) { + metadataValues.put(subColumn.getId(), (long) partitionSpec.specId()); + } + else if (subColumn.getId() == MERGE_PARTITION_DATA.getId()) { + Optional partitionDataJson = split.getPartitionDataJson(); + metadataValues.put(subColumn.getId(), partitionDataJson.isPresent() ? 
utf8Slice(partitionDataJson.get()) : EMPTY_SLICE); + } + } + } } List delegateColumns = columnsToReadFromStorage.stream().collect(toImmutableList()); @@ -830,8 +864,7 @@ else if (icebergColumn.isDataSequenceNumberColumn()) { LocationProvider locationProvider = getLocationProvider(table.getSchemaTableName(), outputPath.get(), storageProperties.get()); Supplier deleteSinkSupplier = () -> new IcebergDeletePageSink( - tableSchema, - split.getPartitionSpecAsJson(), + partitionSpec, split.getPartitionDataJson(), locationProvider, fileWriterFactory, @@ -893,7 +926,7 @@ else if (icebergColumn.isDataSequenceNumberColumn()) { deleteFilters, updatedRowPageSinkSupplier, table.getUpdatedColumns(), - updateRow); + rowIdColumnHandle); if (split.getChangelogSplitInfo().isPresent()) { dataSource = new ChangelogPageSource(dataSource, split.getChangelogSplitInfo().get(), (List) (List) desiredColumns, icebergColumns); diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergUpdateablePageSource.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergUpdateablePageSource.java index 7d8d4bb1500fd..8a0bbdd1b16e8 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergUpdateablePageSource.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergUpdateablePageSource.java @@ -113,7 +113,7 @@ public IcebergUpdateablePageSource( Supplier updatedRowPageSinkSupplier, // the columns that this page source is supposed to update List updatedColumns, - Optional updateRowIdColumn) + Optional rowIdColumn) { requireNonNull(partitionKeys, "partitionKeys is null"); this.tableSchema = requireNonNull(tableSchema, "tableSchema is null"); @@ -128,14 +128,14 @@ public IcebergUpdateablePageSource( this.updatedRowPageSinkSupplier = requireNonNull(updatedRowPageSinkSupplier, "updatedRowPageSinkSupplier is null"); this.updatedColumns = requireNonNull(updatedColumns, "updatedColumns is null"); this.outputColumnToDelegateMapping = new 
int[columns.size()]; - this.updateRowIdColumnIndex = updateRowIdColumn.map(columns::indexOf).orElse(-1); - this.updateRowIdChildColumnIndexes = updateRowIdColumn + this.updateRowIdColumnIndex = rowIdColumn.map(columns::indexOf).orElse(-1); + this.updateRowIdChildColumnIndexes = rowIdColumn .map(column -> new int[column.getColumnIdentity().getChildren().size()]) .orElse(new int[0]); Map columnToIndex = IntStream.range(0, delegateColumns.size()) .boxed() .collect(toImmutableMap(index -> delegateColumns.get(index).getColumnIdentity(), identity())); - updateRowIdColumn.ifPresent(column -> { + rowIdColumn.ifPresent(column -> { List rowIdFields = column.getColumnIdentity().getChildren(); for (int i = 0; i < rowIdFields.size(); i++) { ColumnIdentity columnIdentity = rowIdFields.get(i); @@ -151,15 +151,16 @@ public IcebergUpdateablePageSource( } } for (int i = 0; i < outputColumnToDelegateMapping.length; i++) { - if (outputColumns.get(i).isUpdateRowIdColumn()) { + IcebergColumnHandle outputColumn = outputColumns.get(i); + if (outputColumn.isUpdateRowIdColumn() || outputColumn.isMergeTargetTableRowIdColumn()) { continue; } - if (!columnToIndex.containsKey(outputColumns.get(i).getColumnIdentity())) { - throw new PrestoException(ICEBERG_MISSING_COLUMN, format("Column %s not found in delegate column map", outputColumns.get(i))); + if (!columnToIndex.containsKey(outputColumn.getColumnIdentity())) { + throw new PrestoException(ICEBERG_MISSING_COLUMN, format("Column %s not found in delegate column map", outputColumn)); } else { - outputColumnToDelegateMapping[i] = columnToIndex.get(outputColumns.get(i).getColumnIdentity()); + outputColumnToDelegateMapping[i] = columnToIndex.get(outputColumn.getColumnIdentity()); } } this.isDeletedColumnId = getDelegateColumnId(IcebergColumnHandle::isDeletedColumn); @@ -198,7 +199,7 @@ public boolean isFinished() * {@link IcebergPartitionInsertingPageSource}. * 2. Using the newly retrieved page, apply any necessary delete filters. * 3. 
Finally, take the necessary channels from the page with the delete filters applied and - * nest them into the updateRowId channel in {@link #setUpdateRowIdBlock(Page)} + * nest them into the updateRowId channel in {@link #setRowIdBlock(Page)} */ @Override public Page getNextPage() @@ -229,7 +230,7 @@ else if (deleteFilterPredicate.isPresent()) { dataPage = deleteFilterPredicate.get().filterPage(dataPage); } - return setUpdateRowIdBlock(dataPage); + return setRowIdBlock(dataPage); } catch (RuntimeException e) { closeWithSuppression(e); @@ -247,6 +248,14 @@ public void deleteRows(Block rowIds) positionDeleteSink.appendPage(new Page(rowIds)); } + /** + * @param page This page contains the following channels: + *
+     * <ul>
+     *     <li>One channel for the row ID, which includes the position number of this row within the file and the values of the unmodified columns.</li>
+     *     <li>One additional channel for each updated column. These channels contain the new values for the updated columns.</li>
+     * </ul>
+ * @param columnValueAndRowIdChannels Channel numbers of the column values and the row ID's channel number at the end of the list. + */ @Override public void updateRows(Page page, List columnValueAndRowIdChannels) { @@ -268,6 +277,7 @@ public void updateRows(Page page, List columnValueAndRowIdChannels) Set updatedColumnFieldIds = columnIdentityToUpdatedColumnIndex.keySet(); List tableColumns = tableSchema.columns(); Block[] fullPage = new Block[tableColumns.size()]; + // Build a page that will contain the values of the updated rows. The rows stored in the "fullPage" include both updated and non-updated field values. for (int targetChannel = 0; targetChannel < tableColumns.size(); targetChannel++) { Types.NestedField column = tableColumns.get(targetChannel); ColumnIdentity columnIdentity = ColumnIdentity.createColumnIdentity(column); @@ -309,18 +319,20 @@ public void abort() } /** - * The $row_id column used for updates is a composite column of at least one other column in the Page. + * The $row_id column used for updates and merge is a composite column of at least one other column in the Page. * The indexes of the columns needed for the $row_id are in the updateRowIdChildColumnIndexes array. * * @param page The raw Page from the Parquet/ORC reader. * @return A Page where the $row_id channel has been populated. 
*/ - private Page setUpdateRowIdBlock(Page page) + private Page setRowIdBlock(Page page) { Block[] fullPage = new Block[columns.size()]; Block[] rowIdFields; Consumer loopFunc; - if (updateRowIdColumnIndex == -1 || updatedColumns.isEmpty()) { + boolean isMergeTargetTable = columns.stream().anyMatch(IcebergColumnHandle::isMergeTargetTableRowIdColumn); + + if ((updateRowIdColumnIndex == -1 || updatedColumns.isEmpty()) && !isMergeTargetTable) { loopFunc = (channel) -> fullPage[channel] = page.getBlock(outputColumnToDelegateMapping[channel]); } else { diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/IcebergDeletePageSink.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/IcebergDeletePageSink.java index ac7ee1aa9bbcf..9d735f40d409a 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/IcebergDeletePageSink.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/IcebergDeletePageSink.java @@ -36,7 +36,6 @@ import org.apache.hadoop.mapred.JobConf; import org.apache.iceberg.MetricsConfig; import org.apache.iceberg.PartitionSpec; -import org.apache.iceberg.PartitionSpecParser; import org.apache.iceberg.Schema; import org.apache.iceberg.io.LocationProvider; @@ -81,8 +80,7 @@ public class IcebergDeletePageSink private static final MetricsConfig FULL_METRICS_CONFIG = MetricsConfig.fromProperties(ImmutableMap.of(DEFAULT_WRITE_METRICS_MODE, "full")); public IcebergDeletePageSink( - Schema outputSchema, - String partitionSpecAsJson, + PartitionSpec partitionSpec, Optional partitionDataAsJson, LocationProvider locationProvider, IcebergFileWriterFactory fileWriterFactory, @@ -101,7 +99,7 @@ public IcebergDeletePageSink( this.session = requireNonNull(session, "session is null"); this.dataFile = requireNonNull(dataFile, "dataFile is null"); this.fileFormat = requireNonNull(fileFormat, "fileFormat is null"); - this.partitionSpec = PartitionSpecParser.fromJson(outputSchema, partitionSpecAsJson); 
+ this.partitionSpec = requireNonNull(partitionSpec, "partitionSpec is null"); this.partitionData = partitionDataFromJson(partitionSpec, partitionDataAsJson); String fileName = fileFormat.addExtension(String.format("delete_file_%s", randomUUID().toString())); this.outputPath = partitionData.map(partition -> new Path(locationProvider.newDataLocation(partitionSpec, partition, fileName))) @@ -182,6 +180,9 @@ public IcebergPositionDeleteWriter() this.writer = createWriter(); } + /** + * @param page Only one channel. It contains the list of row positions to delete. + */ public void appendPage(Page page) { if (page.getChannelCount() == 1) { From dd9beda57b0ff5b76f0bb2a581c3754581021713 Mon Sep 17 00:00:00 2001 From: "Adrian Carpente (Denodo)" Date: Thu, 8 Jan 2026 19:33:40 +0100 Subject: [PATCH 2/3] Add SQL Support for MERGE INTO In Presto #20578 (iceberg-tests) SQL MERGE automated tests for Iceberg connector Cherry-pick of https://github.com/trinodb/trino/pull/7933/commits/6cb188b32716e225e446b749cc5de63588929abb Co-authored-by: David Phillips --- .../iceberg/IcebergDistributedTestBase.java | 932 ++++++++++++++++++ 1 file changed, 932 insertions(+) diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java index 53eb1a93cc467..a30bef26a46c3 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java @@ -173,7 +173,9 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; +import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.INSERT_TABLE; import static 
com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.SELECT_COLUMN; +import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.UPDATE_TABLE; import static com.facebook.presto.testing.TestingAccessControlManager.privilege; import static com.facebook.presto.testing.TestingConnectorSession.SESSION; import static com.facebook.presto.testing.assertions.Assert.assertEquals; @@ -182,6 +184,7 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.lang.String.format; import static java.nio.file.Files.createTempDirectory; +import static java.util.Locale.ENGLISH; import static java.util.Locale.ROOT; import static java.util.Objects.requireNonNull; import static java.util.UUID.randomUUID; @@ -2929,6 +2932,935 @@ public void testUpdateOnPartitionTable() assertQuery("SELECT a, b FROM " + tableName, "VALUES (3,'first'), (4,'4th'), (3,'third')"); } + @DataProvider + public Object[][] partitionedProvider() + { + return new Object[][] { + {""}, // Without partitions. 
+ {"WITH (partitioning = ARRAY['address'])"} + }; + } + + @Test(dataProvider = "partitionedProvider") + public void testMergeSimpleQuery(String partitioning) + { + String targetTable = "merge_query_" + randomTableSuffix(); + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) %s", targetTable, partitioning)); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING ", targetTable) + + "(VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville')) AS s(customer, purchases, address) " + + "ON (t.customer = s.customer) " + + "WHEN MATCHED THEN" + + " UPDATE SET purchases = s.purchases + t.purchases, address = s.address " + + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + assertUpdate(sqlMergeCommand, 4); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Carol', 12, 'Centreville'), ('Dave', 22, 'Darbyshire'), ('Ed', 7, 'Etherville')"); + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeSimpleQueryPartitioned() + { + String targetTable = "merge_simple_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (partitioning = ARRAY['customer'])", targetTable)); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING ", targetTable) + + "(SELECT * FROM (VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), 
('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville'))) AS s(customer, purchases, address) " + + "ON (t.customer = s.customer) " + + "WHEN MATCHED THEN" + + " UPDATE SET purchases = s.purchases + t.purchases, address = s.address " + + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + assertUpdate(sqlMergeCommand, 4); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Carol', 12, 'Centreville'), ('Dave', 22, 'Darbyshire'), ('Ed', 7, 'Etherville')"); + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeWithoutTablesAliases() + { + String targetTable = "test_without_aliases_target_" + randomTableSuffix(); + String sourceTable = "test_without_aliases_source_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable)); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable), 4); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s USING %s ", targetTable, sourceTable) + + format("ON (%s.customer = %s.customer) ", targetTable, sourceTable) + + format("WHEN MATCHED THEN" + + " UPDATE SET purchases = %s.purchases + %s.purchases, address = %s.address ", sourceTable, targetTable, sourceTable) + + format("WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, address) VALUES(%s.customer, %s.purchases, %s.address)", sourceTable, sourceTable, sourceTable); + + 
assertUpdate(sqlMergeCommand, 4); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Carol', 12, 'Centreville'), ('Dave', 22, 'Darbyshire'), ('Ed', 7, 'Etherville')"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeUsingUpdateAndInsert() + { + String targetTable = "merge_simple_target_" + randomTableSuffix(); + String sourceTable = "merge_simple_source_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable)); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable), 4); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + + "ON (t.customer = s.customer) " + + "WHEN MATCHED THEN" + + " UPDATE SET purchases = s.purchases + t.purchases, address = s.address " + + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + assertUpdate(sqlMergeCommand, 4); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Aaron', 11, 'Arches'), ('Ed', 7, 'Etherville'), ('Bill', 7, 'Buena'), ('Carol', 12, 'Centreville'), ('Dave', 22, 'Darbyshire')"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeOnlyInsertNewRows() + { + String targetTable = "merge_inserts_" + randomTableSuffix(); + + try { + 
assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable)); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 11, 'Antioch'), ('Bill', 7, 'Buena')", targetTable), 2); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING ", targetTable) + + "(VALUES ('Carol', 9, 'Centreville'), ('Dave', 22, 'Darbyshire')) AS s(customer, purchases, address)" + + "ON (t.customer = s.customer)" + + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + assertUpdate(sqlMergeCommand, 2); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Aaron', 11, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 9, 'Centreville'), ('Dave', 22, 'Darbyshire')"); + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeOnlyUpdateExistingRows() + { + String targetTable = "merge_all_columns_updated_target_" + randomTableSuffix(); + String sourceTable = "merge_all_columns_updated_source_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable)); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Devon'), ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge')", targetTable), 4); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Darbyshire'), ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Ed', 7, 'Etherville')", sourceTable), 4); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + + "ON (t.customer = s.customer) " + + "WHEN MATCHED THEN" + + " UPDATE SET customer = CONCAT(t.customer, '_updated'), purchases = s.purchases + t.purchases, address = 
s.address"; + + assertUpdate(sqlMergeCommand, 3); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Dave_updated', 22, 'Darbyshire'), ('Aaron_updated', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Carol_updated', 12, 'Centreville')"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @DataProvider + public Object[][] partitionedAndBucketedProvider() + { + return new Object[][] { + {""}, // Without partitions. + {"WITH (partitioning = ARRAY['customer'])"}, + {"WITH (partitioning = ARRAY['purchases'])"}, + {"WITH (partitioning = ARRAY['bucket(customer, 3)'])"}, + {"WITH (partitioning = ARRAY['bucket(purchases, 4)'])"}, + }; + } + + @Test(dataProvider = "partitionedAndBucketedProvider") + public void testMergeUsingSelectQuery(String partitioning) + { + String targetTable = "merge_various_target_" + randomTableSuffix(); + String sourceTable = "merge_various_source_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases VARCHAR) %s", targetTable, partitioning)); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases) VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')", targetTable), 3); + assertUpdate(format("INSERT INTO %s (customer, purchases) VALUES ('Craig', 'candles'), ('Len', 'limes'), ('Joe', 'jellybeans')", sourceTable), 3); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING (SELECT customer, purchases FROM %s) s ", targetTable, sourceTable) + + "ON (t.purchases = s.purchases) " + + "WHEN MATCHED THEN" + + " UPDATE SET customer = CONCAT(t.customer, '_', s.customer) " + + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases) VALUES(s.customer, s.purchases)"; + + assertUpdate(sqlMergeCommand, 3); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Dave', 'dates'), ('Carol_Craig', 'candles'), 
('Lou_Len', 'limes'), ('Joe', 'jellybeans')"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test(dataProvider = "partitionedAndBucketedProvider") + public void testMultipleMergeCommands(String partitioning) + { + int targetCustomerCount = 32; + String targetTable = "merge_multiple_" + randomTableSuffix(); + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, zipcode INT, spouse VARCHAR, address VARCHAR) %s", targetTable, partitioning)); + + // joe_1, 1000, 91000, jan_1, 1 Poe Ct + // ... + // joe_15, 1000, 91000, jan_15, 15 Poe Ct + String originalInsertFirstHalf = IntStream.range(1, targetCustomerCount / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 1000, 91000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + // joe_16, 2000, 92000, jan_16, 16 Poe Ct + // ... + // joe_32, 2000, 92000, jan_32, 32 Poe Ct + String originalInsertSecondHalf = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 2000, 92000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + assertUpdate(format("INSERT INTO %s (customer, purchases, zipcode, spouse, address) " + + "VALUES %s, %s", targetTable, originalInsertFirstHalf, originalInsertSecondHalf), targetCustomerCount - 1); + + // joe_16, 3000, 83000, jan_16, 16 Eop Ct + // ... 
+ // joe_32, 3000, 83000, jan_32, 32 Eop Ct + String firstMergeSource = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 83000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchases, zipcode, spouse, address)", targetTable, firstMergeSource) + + "ON t.customer = s.customer " + + "WHEN MATCHED THEN" + + " UPDATE SET purchases = s.purchases, zipcode = s.zipcode, spouse = s.spouse, address = s.address"; + + assertUpdate(sqlMergeCommand, targetCustomerCount / 2); + + assertQuery( + format("SELECT customer, purchases, zipcode, spouse, address FROM %s", targetTable), + format("VALUES %s, %s", originalInsertFirstHalf, firstMergeSource)); + + // jack_32, 4000, 74000, jan_32, 32 Poe Ct + // ... + // jack_48, 4000, 74000, jan_48, 48 Poe Ct + String nextInsert = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('jack_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 4000, 74000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + assertUpdate(format("INSERT INTO %s (customer, purchases, zipcode, spouse, address) VALUES %s", targetTable, nextInsert), targetCustomerCount / 2); + + // joe_1, 5000, 85000, jen_32, 32 Poe Ct + // ... + // joe_48, 5000, 85000, jen_48, 48 Poe Ct + String secondMergeSource = IntStream.range(1, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + // Note that the following MERGE INTO does not update the "purchases" column. 
+ sqlMergeCommand = + format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchases, zipcode, spouse, address)", targetTable, secondMergeSource) + + "ON t.customer = s.customer " + + "WHEN MATCHED THEN" + + " UPDATE SET zipcode = s.zipcode, spouse = s.spouse, address = s.address " + + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, zipcode, spouse, address) VALUES(s.customer, s.purchases, s.zipcode, s.spouse, s.address)"; + + assertUpdate(sqlMergeCommand, targetCustomerCount * 3 / 2 - 1); + + // joe_1, 1000, 85000, jen_1, 1 Poe Ct + // ... + // joe_15, 1000, 85000, jen_15, 15 Poe Ct + String updatedFirstHalf = IntStream.range(1, targetCustomerCount / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 1000, 85000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + // joe_16, 3000, 85000, jen_16, 16 Poe Ct + // ... + // joe_32, 3000, 85000, jen_32, 32 Poe Ct + String updatedSecondHalf = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 3000, 85000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + // jack_32, 4000, 74000, jan_32, 32 Poe Ct + // ... + // jack_48, 4000, 74000, jan_48, 48 Poe Ct + String nonUpdatedRows = nextInsert; + + // joe_32, 5000, 85000, jen_32, 32 Poe Ct + // ... 
+ // joe_48, 5000, 85000, jen_48, 48 Poe Ct + String insertedRows = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + assertQuery( + format("SELECT customer, purchases, zipcode, spouse, address FROM %s", targetTable), + format("VALUES %s, %s, %s, %s", updatedFirstHalf, updatedSecondHalf, nonUpdatedRows, insertedRows)); + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeMillionRows() + { + String tableName = "test_merge_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (orderkey BIGINT, custkey BIGINT, totalprice DOUBLE)", tableName)); + + // Initialize the merge target table with data: + // When "mod(orderkey, 3) = 0" -> copy rows, when "mod(orderkey, 3) = 1" -> double price, when "mod(orderkey, 3) = 2" -> rows with new orderkey + assertUpdate( + format("INSERT INTO %s " + + "SELECT orderkey, custkey, totalprice FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 0 " + // rows copied + "UNION ALL " + + "SELECT orderkey, custkey, 2*totalprice as totalprice FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 1 " + // rows with updated price + "UNION ALL " + + "SELECT orderkey + 100000002 as orderkey, custkey, totalprice as totalprice FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 2", // rows with new orderkey + tableName), + (long) computeActual("SELECT count(*) FROM tpch.sf1.orders").getOnlyValue()); + + // verify copied rows: same total price + assertQueryWithSameQueryRunner( + "SELECT count(*), round(sum(totalprice)) FROM " + tableName + " WHERE mod(orderkey, 3) = 0", + "SELECT count(*), round(sum(totalprice)) FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 0"); + + // verify rows will be updated: double total price + assertQueryWithSameQueryRunner( + "SELECT count(*), round(sum(totalprice)) FROM " + tableName + " WHERE 
mod(orderkey, 3) = 1", + "SELECT count(*), round(2*sum(totalprice)) FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 1"); + + // verify rows will be inserted: same total price and different orderkey. + assertQueryWithSameQueryRunner( + "SELECT count(*), round(sum(totalprice)) FROM " + tableName + " WHERE mod(orderkey, 3) = 2", + "SELECT count(*), round(sum(totalprice)) FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 2"); + + // MERGE INTO command to update the price of the existing orders and insert new orders, multiplying the original price by 3. + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING (SELECT * FROM tpch.sf1.orders) s ", tableName) + + "ON (t.orderkey = s.orderkey) " + + "WHEN MATCHED THEN" + + " UPDATE SET totalprice = s.totalprice " + + "WHEN NOT MATCHED THEN" + + " INSERT (orderkey, custkey, totalprice) VALUES (s.orderkey, s.custkey, 3*s.totalprice)"; + + assertUpdate(sqlMergeCommand, 1_500_000); + + // verify unmodified rows: same total price + assertQueryWithSameQueryRunner( + "SELECT count(*), round(sum(totalprice)) FROM " + tableName + " WHERE mod(orderkey, 3) = 0", + "SELECT count(*), round(sum(totalprice)) FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 0"); + assertQueryWithSameQueryRunner( + "SELECT count(*), round(sum(totalprice)) FROM " + tableName + " WHERE mod(orderkey, 3) = 2 AND orderkey > 100000002", + "SELECT count(*), round(sum(totalprice)) FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 2"); + + // verify updated rows: same total price (these rows originally had double total price in the target table) + assertQueryWithSameQueryRunner( + "SELECT count(*), round(sum(totalprice)) FROM " + tableName + " WHERE mod(orderkey, 3) = 1", + "SELECT count(*), round(sum(totalprice)) FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 1"); + + // verify inserted rows: triple original price + assertQueryWithSameQueryRunner( + "SELECT count(*), round(sum(totalprice)) FROM " + tableName + " WHERE mod(orderkey, 3) = 2 AND orderkey < 
100000002", + "SELECT count(*), round(3*sum(totalprice)) FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 2"); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Test + public void testMergeQueryWithWeirdColumnsCapitalization() + { + String targetTable = "merge_weird_capitalization_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable)); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING ", targetTable.toUpperCase(ENGLISH)) + + "(VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville')) AS s(customer, purchases, address) " + + "ON (t.customer = s.customer) " + + "WHEN MATCHED THEN" + + " UPDATE SET purCHases = s.PurchaseS + t.pUrchases, aDDress = s.addrESs " + + "WHEN NOT MATCHED THEN" + + " INSERT (CUSTOMER, purchases, addRESS) VALUES(s.custoMer, s.Purchases, s.ADDress)"; + + assertUpdate(sqlMergeCommand, 4); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Carol', 12, 'Centreville'), ('Dave', 22, 'Darbyshire'), ('Ed', 7, 'Etherville')"); + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeWithMultipleConditions() + { + String targetTable = "merge_predicates_target_" + randomTableSuffix(); + String sourceTable = "merge_predicates_source_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (id INT, customer VARCHAR, purchases INT, address VARCHAR)", targetTable)); + assertUpdate(format("CREATE TABLE %s (id INT, customer VARCHAR, purchases INT, address VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (id, customer, purchases, address) VALUES (1, 'Dave', 
10, 'Devon'), (2, 'Dave', 20, 'Darbyshire')", targetTable), 2); + assertUpdate(format("INSERT INTO %s (id, customer, purchases, address) VALUES (3, 'Dave', 2, 'Madrid'), (4, 'Dave', 15, 'Barcelona')", sourceTable), 2); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + + "ON t.customer = s.customer AND s.purchases < 6 " + + "WHEN MATCHED THEN" + + " UPDATE SET purchases = s.purchases + t.purchases, address = concat(t.address, '/', s.address) " + + "WHEN NOT MATCHED THEN" + + " INSERT (id, customer, purchases, address) VALUES (s.id, s.customer, s.purchases, s.address)"; + + assertUpdate(sqlMergeCommand, 3); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES (1, 'Dave', 12, 'Devon/Madrid'), (2, 'Dave', 22, 'Darbyshire/Madrid'), (4, 'Dave', 15, 'Barcelona')"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeCasts() + { + String targetTable = "merge_cast_target_" + randomTableSuffix(); + String sourceTable = "merge_cast_source_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (col1 INT, col2 BIGINT, col3 REAL, col4 DOUBLE, col5 DOUBLE)", targetTable)); + assertUpdate(format("CREATE TABLE %s (col1 INT, col2 INT, col3 INT, col4 INT, col5 REAL)", sourceTable)); + + assertUpdate(format("INSERT INTO %s VALUES (1, 2, 3, 4, 5)", targetTable), 1); + assertUpdate(format("INSERT INTO %s VALUES (2, 3, 4, 5, 6)", sourceTable), 1); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + + "ON (t.col1 + 1 = s.col1) " + // Note that the merge condition contains a sum. 
+ "WHEN MATCHED THEN" + + " UPDATE SET col1 = s.col1, col2 = s.col2, col3 = s.col3, col4 = s.col4, col5 = s.col5"; + + assertUpdate(sqlMergeCommand, 1); + + assertQuery("SELECT * FROM " + targetTable, "VALUES (2, 3, 4.0, 5.0, 6.0)"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeSubqueries() + { + String targetTable = "merge_nation_target_" + randomTableSuffix(); + String sourceTable = "merge_nation_source_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR)", targetTable)); + assertUpdate(format("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (nation_name, region_name) VALUES ('GERMANY', 'EUROPE'), ('ALGERIA', 'AFRICA'), ('FRANCE', 'EUROPE')", targetTable), 3); + assertUpdate(format("INSERT INTO %s VALUES ('ALGERIA', 'AFRICA'), ('FRANCE', 'EUROPE'), ('EGYPT', 'MIDDLE EAST'), ('RUSSIA', 'EUROPE')", sourceTable), 4); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + + "ON (t.nation_name = s.nation_name) " + + "WHEN MATCHED THEN" + + " UPDATE SET region_name = (SELECT CONCAT(name, '_UPDATED') FROM tpch.tiny.region WHERE name = t.region_name) " + + "WHEN NOT MATCHED THEN" + + " INSERT VALUES(s.nation_name, (SELECT CONCAT(name, '_INSERTED') FROM tpch.tiny.region WHERE name = s.region_name))"; + + assertUpdate(sqlMergeCommand, 4); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('GERMANY', 'EUROPE'), " + + "('ALGERIA', 'AFRICA_UPDATED'), ('FRANCE', 'EUROPE_UPDATED'), " + + "('EGYPT', 'MIDDLE EAST_INSERTED'), ('RUSSIA', 'EUROPE_INSERTED')"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @DataProvider + public Object[][] partitionedBucketedFailure() + { + return new Object[][] { + {"CREATE TABLE %s 
(customer VARCHAR, purchases INT, address VARCHAR)"}, + {"CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (partitioning = ARRAY['customer'])"}, + {"CREATE TABLE %s (customer VARCHAR, address VARCHAR, purchases INT) WITH (partitioning = ARRAY['address'])"}, + {"CREATE TABLE %s (purchases INT, customer VARCHAR, address VARCHAR) WITH (partitioning = ARRAY['customer', 'address'])"}, + {"CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (partitioning = ARRAY['bucket(customer, 3)'])"} + }; + } + + @Test(dataProvider = "partitionedBucketedFailure") + public void testMergeMultipleRowsMatchMustFails(String createTableSql) + { + String targetTable = "merge_multiple_rows_match_target_" + randomTableSuffix(); + String sourceTable = "merge_multiple_rows_match_source_" + randomTableSuffix(); + + try { + assertUpdate(format(createTableSql, targetTable)); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Antioch')", targetTable), 2); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Adelphi'), ('Aaron', 8, 'Ashland')", sourceTable), 2); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + + "ON (t.customer = s.customer) " + + "WHEN MATCHED THEN" + + " UPDATE SET address = s.address"; + + assertQueryFails(sqlMergeCommand, ".*The MERGE INTO command requires each target row to match at most one source row.*"); + + assertUpdate(format("DELETE FROM %s WHERE purchases = 8", sourceTable), 1); + + assertUpdate(sqlMergeCommand, 1); + + assertQuery("SELECT customer, purchases, address FROM " + targetTable, + "VALUES ('Aaron', 5, 'Adelphi'), ('Bill', 7, 'Antioch')"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + 
} + + private void createNationRegionTable(String targetTable) + { + assertUpdate(format("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR NOT NULL)", targetTable)); + } + + @Test + public void testMergeNonNullableColumns() + { + String targetTable = "merge_non_nullable_target_" + randomTableSuffix(); + + try { + createNationRegionTable(targetTable); + assertUpdate(format("INSERT INTO %s (nation_name, region_name) VALUES ('FRANCE', 'EUROPE'), ('ALGERIA', 'AFRICA'), ('GERMANY', 'EUROPE')", targetTable), 3); + + List sqlMergeCommands = Arrays.asList( + // Command to check that updating using a null value fails. + format("MERGE INTO %s t ", targetTable) + + "USING (VALUES ('ALGERIA', 'AFRICA')) s(nation_name, region_name) " + + "ON (t.nation_name = s.nation_name)\n" + + "WHEN MATCHED THEN" + + " UPDATE SET region_name = NULL", + + // Command to check that inserting using a null value fails. + format("MERGE INTO %s t ", targetTable) + + " USING (VALUES ('ANGOLA', 'AFRICA')) s(nation_name, region_name) " + + "ON (t.nation_name = s.nation_name) " + + "WHEN NOT MATCHED THEN" + + " INSERT (nation_name, region_name) VALUES (s.nation_name, NULL)", + + // Command to check that inserting using an implicit null value fails. + format("MERGE INTO %s t ", targetTable) + + "USING (VALUES ('ANGOLA', 'AFRICA')) s(nation_name, region_name) " + + "ON (t.nation_name = s.nation_name) " + + "WHEN NOT MATCHED THEN" + + " INSERT (nation_name) VALUES ('CANADA')", + + // Command to check that if the updated value is provided by a function unpredictably computing null, the merge fails. + format("MERGE INTO %s t ", targetTable) + + "USING (VALUES ('ALGERIA', 'AFRICA')) s(nation_name, region_name) " + + "ON (t.nation_name = s.nation_name) " + + "WHEN MATCHED THEN" + + " UPDATE SET region_name = CAST(TRY(5/0) AS VARCHAR)"); + + for (@Language("SQL") String sqlMergeCommand : sqlMergeCommands) { + assertQueryFails(sqlMergeCommand, "NULL value not allowed for NOT NULL column. 
Table: merge_non_nullable_target_.* Column: region_name"); + } + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @DataProvider + public Object[][] targetAndSourceWithDifferentPartitioning() + { + return new Object[][] { + { + "target_flat_source_flat", + "", + "" + }, + { + "target_partitioned_source_flat", + "WITH (partitioning = ARRAY['customer'])", + "" + }, + { + "target_bucketed_source_flat", + "WITH (partitioning = ARRAY['bucket(customer, 3)'])", + "" + }, + { + "target_partitioned_and_bucketed_source_flat", + "WITH (partitioning = ARRAY['address', 'bucket(customer, 3)'])", + "" + }, + { + "target_partitioned_and_bucketed_source_partitioned", + "WITH (partitioning = ARRAY['address', 'bucket(customer, 3)'])", + "WITH (partitioning = ARRAY['customer'])" + }, + { + "target_and_source_partitioned_and_bucketed", + "WITH (partitioning = ARRAY['address', 'bucket(customer, 3)'])", + "WITH (partitioning = ARRAY['address', 'bucket(customer, 3)'])" + } + }; + } + + @Test(dataProvider = "targetAndSourceWithDifferentPartitioning") + public void testMergeWithDifferentPartitioning(String testDescription, String targetTablePartitioning, String sourceTablePartitioning) + { + String targetTable = format("%s_target_%s", testDescription, randomTableSuffix()); + String sourceTable = format("%s_source_%s", testDescription, randomTableSuffix()); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) %s", targetTable, targetTablePartitioning)); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) %s", sourceTable, sourceTablePartitioning)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 
'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable), 4); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + + "ON (t.customer = s.customer) " + + "WHEN MATCHED THEN" + + " UPDATE SET purchases = s.purchases + t.purchases, address = s.address " + + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + assertUpdate(sqlMergeCommand, 4); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Carol', 12, 'Centreville'), ('Dave', 22, 'Darbyshire'), ('Ed', 7, 'Etherville')"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeAccessControl() + { + String catalogName = getSession().getCatalog().get(); + String schemaName = getSession().getSchema().get(); + + String targetTable = "merge_nation_target_" + randomTableSuffix(); + String targetName = format("%s.%s.%s", catalogName, schemaName, targetTable); + + String sourceTable = "merge_nation_source_" + randomTableSuffix(); + String sourceName = format("%s.%s.%s", catalogName, schemaName, sourceTable); + + try { + assertUpdate(format("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR)", targetTable)); + assertUpdate(format("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR)", sourceTable)); + + String baseMergeSql = format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + + "ON (t.nation_name = s.nation_name) "; + String updateCase = + "WHEN MATCHED THEN" + + " UPDATE SET nation_name = concat(s.nation_name, '_foo')"; + String insertCase = + "WHEN NOT MATCHED THEN" + + " INSERT VALUES(s.nation_name, (SELECT 'EUROPE'))"; + + ImmutableList mergeCases = ImmutableList.of(updateCase, insertCase); + for (String mergeCase : mergeCases) { + // Show that without SELECT privilege on the source table, the MERGE fails regardless of which 
case is included + assertAccessDenied(baseMergeSql + mergeCase, "Cannot select from columns .* in table or view " + sourceName, privilege(sourceTable, SELECT_COLUMN)); + + // Show that without SELECT privilege on the target table, the MERGE fails regardless of which case is included + assertAccessDenied(baseMergeSql + mergeCase, "Cannot select from columns .* in table or view " + targetName, privilege(targetTable, SELECT_COLUMN)); + } + + // Show that without INSERT privilege on the target table, the MERGE fails + assertAccessDenied(baseMergeSql + insertCase, "Cannot insert into table " + targetName, privilege(targetTable, INSERT_TABLE)); + + // Show that without UPDATE privilege on the target table, the MERGE fails + assertAccessDenied(baseMergeSql + updateCase, "Cannot update columns \\[\\[nation_name\\]\\] in table " + targetName, privilege(targetTable, UPDATE_TABLE)); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testInvalidMergePredicate() + { + String targetTable = "merge_invalid_predicate_" + randomTableSuffix(); + + try { + createNationRegionTable(targetTable); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING (VALUES ('ALGERIA', 'AFRICA')) s(nation_name, region_name) ", targetTable) + + "ON (t.nation_name) " + + "WHEN MATCHED THEN" + + " UPDATE SET region_name = s.region_name"; + + assertQueryFails(sqlMergeCommand, ".*The MERGE predicate must evaluate to a boolean: actual type varchar"); + + sqlMergeCommand = + format("MERGE INTO %s t USING (VALUES (1, 'ALGERIA', 'AFRICA')) s(nation_id, nation_name, region_name) ", targetTable) + + "ON (t.nation_name = s.nation_id) " + + "WHEN MATCHED THEN" + + " UPDATE SET region_name = s.region_name"; + + assertQueryFails(sqlMergeCommand, ".*'=' cannot be applied to varchar, integer"); + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void 
testMergeUnknownColumnName() + { + String targetTable = "merge_unknown_column_" + randomTableSuffix(); + + try { + createNationRegionTable(targetTable); + + String baseMergeSql = format("MERGE INTO %s t USING (VALUES ('ALGERIA', 'AFRICA')) s(nation_name, region_name) ", targetTable) + + "ON (t.nation_name = s.nation_name) "; + + List sqlMergeCommands = Arrays.asList( + // Unknown column in the UPDATE statement. + baseMergeSql + + "WHEN MATCHED THEN" + + " UPDATE SET unknown_column = s.region_name", + + // Unknown column in the INSERT statement. + baseMergeSql + + "WHEN NOT MATCHED THEN" + + " INSERT (nation_name, unknown_column) VALUES(s.nation_name, (SELECT 'EUROPE'))"); + + for (@Language("SQL") String sqlMergeCommand : sqlMergeCommands) { + assertQueryFails(sqlMergeCommand, ".*Merge column name does not exist in target table: unknown_column"); + } + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeDuplicateColumnName() + { + String targetTable = "merge_duplicate_column_" + randomTableSuffix(); + + try { + createNationRegionTable(targetTable); + + String baseMergeSql = format("MERGE INTO %s t USING (VALUES ('ALGERIA', 'AFRICA')) s(nation_name, region_name) ", targetTable) + + "ON (t.nation_name = s.nation_name) "; + + List sqlMergeCommands = Arrays.asList( + // Duplicate column in the UPDATE statement. + baseMergeSql + + "WHEN MATCHED THEN" + + " UPDATE SET region_name = s.region_name, region_name = 'AFRICA'", + + // Duplicate column in the INSERT statement. 
+ baseMergeSql + + "WHEN NOT MATCHED THEN" + + " INSERT (nation_name, region_name, region_name) VALUES(s.nation_name, (SELECT 'EUROPE'), 'AFRICA')"); + + for (@Language("SQL") String sqlMergeCommand : sqlMergeCommands) { + assertQueryFails(sqlMergeCommand, ".*Merge column name is specified more than once: region_name"); + } + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeMismatchedColumnDataTypes() + { + String targetTable = "merge_mismatched_column_data_types_" + randomTableSuffix(); + + try { + createNationRegionTable(targetTable); + + String baseMergeSql = format("MERGE INTO %s t USING (VALUES ('ALGERIA', 'AFRICA')) s(nation_name, region_name) ", targetTable) + + "ON (t.nation_name = s.nation_name) "; + + List sqlMergeCommands = Arrays.asList( + // Mismatched column in the UPDATE statement. + baseMergeSql + + "WHEN MATCHED THEN" + + " UPDATE SET region_name = 1", + + // Mismatched column in the INSERT statement. + baseMergeSql + + "WHEN NOT MATCHED THEN" + + " INSERT (region_name) VALUES(1)"); + + for (@Language("SQL") String sqlMergeCommand : sqlMergeCommands) { + assertQueryFails(sqlMergeCommand, + ".*MERGE table column types don't match for MERGE case 0, SET expressions: Table: \\[varchar\\], Expressions: \\[integer\\]"); + } + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeWithPartitionSpecEvolutionAddPartitionedField() + { + String targetTable = "merge_query_" + randomTableSuffix(); + try { + assertUpdate(format("CREATE TABLE %s (a int, b varchar)", targetTable)); + + assertUpdate(format("INSERT INTO %s VALUES (1, '1001'), (2, '1002')", targetTable), 2); + assertUpdate(format("INSERT INTO %s VALUES (3, '1003'), (4, '1004')", targetTable), 2); + + // Add a partition field to the target iceberg table. 
+ assertUpdate(format("ALTER TABLE %s ADD COLUMN c int WITH(partitioning = 'identity')", targetTable)); + + assertUpdate(format("INSERT INTO %s VALUES (5, '1005', 5), (6, '1006', 6)", targetTable), 2); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING ", targetTable) + + "(VALUES (1, 11), (3, 33), (5, 55), (7, 77)) AS s(a, c) " + + "ON (t.a = s.a) " + + "WHEN MATCHED THEN" + + " UPDATE SET c = s.c " + + "WHEN NOT MATCHED THEN" + + " INSERT (a, b, c) VALUES(s.a, 'NEW_LINE', s.c)"; + + assertUpdate(sqlMergeCommand, 4); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES (1, '1001', 11), (2, '1002', NULL), (3, '1003', 33), (4, '1004', NULL), (5, '1005', 55), (6, '1006', 6), (7, 'NEW_LINE', 77)"); + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeWithPartitionSpecEvolutionRemovePartitionedField() + { + String targetTable = "merge_query_" + randomTableSuffix(); + try { + assertUpdate(format("CREATE TABLE %s (a int, b varchar, c int) with(partitioning = ARRAY['a', 'c'])", targetTable)); + assertUpdate(format("INSERT INTO %s VALUES (1, '1001', 11), (2, '1002', 12)", targetTable), 2); + + // Remove a partitioned field from the target iceberg table. 
+ Table icebergTable = loadTable(targetTable); + String partitionFieldName = icebergTable.spec().fields().get(0).name(); + icebergTable.updateSpec().removeField(partitionFieldName).commit(); + + assertUpdate(format("INSERT INTO %s VALUES (3, '1003', 13), (4, '1004', 14)", targetTable), 2); + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING ", targetTable) + + "(VALUES (1, 111), (3, 333), (5, 555)) AS s(a, c) " + + "ON (t.a = s.a) " + + "WHEN MATCHED THEN" + + " UPDATE SET c = s.c " + + "WHEN NOT MATCHED THEN" + + " INSERT (a, b, c) VALUES(s.a, 'NEW_LINE', s.c)"; + + assertUpdate(sqlMergeCommand, 3); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES (1, '1001', 111), (2, '1002', 12), (3, '1003', 333), (4, '1004', 14), (5, 'NEW_LINE', 555)"); + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + private void testCheckDeleteFiles(Table icebergTable, int expectedSize, List expectedFileContent) { // check delete file list From 92c97a06406cf54787f4719212b8fd3533db7db1 Mon Sep 17 00:00:00 2001 From: Adrian Carpente Recouso Date: Mon, 12 Jan 2026 17:32:06 +0100 Subject: [PATCH 3/3] Add SQL Support for MERGE INTO In Presto #20578 (iceberg) - Improved CompletableFuture management. 
- Added new automated tests --- .../presto/iceberg/IcebergMergeSink.java | 28 +++--- .../iceberg/IcebergDistributedTestBase.java | 89 +++++++++++++++++-- 2 files changed, 100 insertions(+), 17 deletions(-) diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMergeSink.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMergeSink.java index 579aca18ce44b..a6447da093c08 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMergeSink.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMergeSink.java @@ -49,6 +49,7 @@ import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; import static java.util.concurrent.CompletableFuture.completedFuture; +import static java.util.concurrent.CompletableFuture.failedFuture; public class IcebergMergeSink implements ConnectorMergeSink @@ -123,18 +124,23 @@ public void storeMergedRows(Page page) @Override public CompletableFuture> finish() { - List fragments = new ArrayList<>(insertPageSink.finish().join()); - - fileDeletions.forEach((dataFilePath, deletion) -> { - ConnectorPageSink sink = createPositionDeletePageSink( - dataFilePath.toStringUtf8(), - partitionsSpecs.get(deletion.partitionSpecId()), - deletion.partitionDataJson()); - - fragments.addAll(writePositionDeletes(sink, deletion.rowsToDelete())); + return insertPageSink.finish().thenCompose(insertFragments -> { + List fragments = new ArrayList<>(insertFragments); + + try { + fileDeletions.forEach((dataFilePath, deletion) -> { + ConnectorPageSink sink = createPositionDeletePageSink( + dataFilePath.toStringUtf8(), + partitionsSpecs.get(deletion.partitionSpecId()), + deletion.partitionDataJson()); + fragments.addAll(writePositionDeletes(sink, deletion.rowsToDelete())); + }); + return completedFuture(fragments); + } + catch (Exception e) { + return failedFuture(e); + } }); - - return completedFuture(fragments); } @Override diff --git 
a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java index a30bef26a46c3..e047e05178597 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java @@ -3060,9 +3060,20 @@ public void testMergeUsingUpdateAndInsert() } } - @Test - public void testMergeOnlyInsertNewRows() + @DataProvider + public Object[][] mergeIncludeWhenAndWhenNotMatchedProvider() + { + return new Object[][] { + {true}, + {false}, + }; + } + + @Test(dataProvider = "mergeIncludeWhenAndWhenNotMatchedProvider") + public void testMergeOnlyInsertNewRows(boolean includeWhenMatched) { + // This test verifies that the MERGE command works correctly when no rows in the source table meet the MERGE condition. + // It means that the MERGE command will behave as an INSERT command. String targetTable = "merge_inserts_" + randomTableSuffix(); try { @@ -3073,6 +3084,9 @@ public void testMergeOnlyInsertNewRows() format("MERGE INTO %s t USING ", targetTable) + "(VALUES ('Carol', 9, 'Centreville'), ('Dave', 22, 'Darbyshire')) AS s(customer, purchases, address)" + "ON (t.customer = s.customer)" + + (includeWhenMatched ? 
+ "WHEN MATCHED THEN" + + " UPDATE SET customer = CONCAT(t.customer, '_updated'), purchases = s.purchases + t.purchases, address = s.address " : "") + "WHEN NOT MATCHED THEN" + " INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; @@ -3086,9 +3100,11 @@ public void testMergeOnlyInsertNewRows() } } - @Test - public void testMergeOnlyUpdateExistingRows() + @Test(dataProvider = "mergeIncludeWhenAndWhenNotMatchedProvider") + public void testMergeOnlyUpdateExistingRows(boolean includeWhenNotMatched) { + // This test verifies that the MERGE command works correctly when all rows in the source table meet the MERGE condition. + // It means that the MERGE command will behave as an UPDATE command. String targetTable = "merge_all_columns_updated_target_" + randomTableSuffix(); String sourceTable = "merge_all_columns_updated_source_" + randomTableSuffix(); @@ -3097,13 +3113,16 @@ public void testMergeOnlyUpdateExistingRows() assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable)); assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Devon'), ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge')", targetTable), 4); - assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Darbyshire'), ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Ed', 7, 'Etherville')", sourceTable), 4); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Darbyshire'), ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville')", sourceTable), 3); @Language("SQL") String sqlMergeCommand = format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + "ON (t.customer = s.customer) " + "WHEN MATCHED THEN" + - " UPDATE SET customer = CONCAT(t.customer, '_updated'), purchases = s.purchases + t.purchases, address = s.address"; + " UPDATE SET customer = CONCAT(t.customer, '_updated'), purchases = s.purchases + 
t.purchases, address = s.address " + + (includeWhenNotMatched ? + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)" : ""); assertUpdate(sqlMergeCommand, 3); @@ -3116,6 +3135,64 @@ public void testMergeOnlyUpdateExistingRows() } } + @Test + public void testMergeEmptyTargetTable() + { + String targetTable = "merge_inserts_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable)); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING ", targetTable) + + "(VALUES ('Carol', 9, 'Centreville'), ('Dave', 22, 'Darbyshire')) AS s(customer, purchases, address)" + + "ON (t.customer = s.customer)" + + "WHEN MATCHED THEN" + + " UPDATE SET customer = CONCAT(t.customer, '_updated'), purchases = s.purchases + t.purchases, address = s.address " + + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + assertUpdate(sqlMergeCommand, 2); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Carol', 9, 'Centreville'), ('Dave', 22, 'Darbyshire')"); + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeEmptySourceTable() + { + String targetTable = "merge_all_columns_updated_target_" + randomTableSuffix(); + String sourceTable = "merge_all_columns_updated_source_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable)); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Devon'), ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge')", targetTable), 4); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + 
+ "ON (t.customer = s.customer) " + + "WHEN MATCHED THEN" + + " UPDATE SET customer = CONCAT(t.customer, '_updated'), purchases = s.purchases + t.purchases, address = s.address " + + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + assertUpdate(sqlMergeCommand, 0); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Dave', 11, 'Devon'), ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge')"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + @DataProvider public Object[][] partitionedAndBucketedProvider() {