diff --git a/core/trino-main/src/main/java/io/trino/metadata/HandleJsonModule.java b/core/trino-main/src/main/java/io/trino/metadata/HandleJsonModule.java index 8ffd3f94a8e8..c9674675e591 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/HandleJsonModule.java +++ b/core/trino-main/src/main/java/io/trino/metadata/HandleJsonModule.java @@ -20,6 +20,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorIndexHandle; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorOutputTableHandle; import io.trino.spi.connector.ConnectorPartitioningHandle; import io.trino.spi.connector.ConnectorSplit; @@ -72,6 +73,12 @@ public static com.fasterxml.jackson.databind.Module insertTableHandleModule(Hand return new AbstractTypedJacksonModule<>(ConnectorInsertTableHandle.class, resolver::getId, resolver::getInsertTableHandleClass) {}; } + @ProvidesIntoSet + public static com.fasterxml.jackson.databind.Module mergeTableHandleModule(HandleResolver resolver) + { + return new AbstractTypedJacksonModule<>(ConnectorMergeTableHandle.class, resolver::getId, resolver::getMergeTableHandleClass) {}; + } + @ProvidesIntoSet public static com.fasterxml.jackson.databind.Module indexHandleModule(HandleResolver resolver) { diff --git a/core/trino-main/src/main/java/io/trino/metadata/HandleResolver.java b/core/trino-main/src/main/java/io/trino/metadata/HandleResolver.java index b2b61a372192..3eb3d9d8a4f4 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/HandleResolver.java +++ b/core/trino-main/src/main/java/io/trino/metadata/HandleResolver.java @@ -19,6 +19,7 @@ import io.trino.spi.connector.ConnectorHandleResolver; import io.trino.spi.connector.ConnectorIndexHandle; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorOutputTableHandle; import 
io.trino.spi.connector.ConnectorPartitioningHandle; import io.trino.spi.connector.ConnectorSplit; @@ -103,6 +104,11 @@ public String getId(ConnectorInsertTableHandle insertHandle) return getId(insertHandle, MaterializedHandleResolver::getInsertTableHandleClass); } + public String getId(ConnectorMergeTableHandle mergeHandle) + { + return getId(mergeHandle, MaterializedHandleResolver::getMergeTableHandleClass); + } + public String getId(ConnectorPartitioningHandle partitioningHandle) { return getId(partitioningHandle, MaterializedHandleResolver::getPartitioningHandleClass); @@ -148,6 +154,11 @@ public Class getInsertTableHandleClass(Str return resolverFor(id).getInsertTableHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); } + public Class getMergeTableHandleClass(String id) + { + return resolverFor(id).getMergeTableHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); + } + public Class getPartitioningHandleClass(String id) { return resolverFor(id).getPartitioningHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); @@ -188,6 +199,7 @@ private static class MaterializedHandleResolver private final Optional> indexHandle; private final Optional> outputTableHandle; private final Optional> insertTableHandle; + private final Optional> mergeTableHandle; private final Optional> partitioningHandle; private final Optional> transactionHandle; @@ -200,6 +212,7 @@ public MaterializedHandleResolver(ConnectorHandleResolver resolver) indexHandle = getHandleClass(resolver::getIndexHandleClass); outputTableHandle = getHandleClass(resolver::getOutputTableHandleClass); insertTableHandle = getHandleClass(resolver::getInsertTableHandleClass); + mergeTableHandle = getHandleClass(resolver::getMergeTableHandleClass); partitioningHandle = getHandleClass(resolver::getPartitioningHandleClass); transactionHandle = getHandleClass(resolver::getTransactionHandleClass); } @@ -249,6 +262,11 @@ 
public Optional> getInsertTableHandl return insertTableHandle; } + public Optional> getMergeTableHandleClass() + { + return mergeTableHandle; + } + public Optional> getPartitioningHandleClass() { return partitioningHandle; diff --git a/core/trino-main/src/main/java/io/trino/metadata/MergeHandle.java b/core/trino-main/src/main/java/io/trino/metadata/MergeHandle.java new file mode 100644 index 000000000000..39a598a59cd2 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/metadata/MergeHandle.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.metadata; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.connector.ConnectorMergeTableHandle; + +import static java.util.Objects.requireNonNull; + +public final class MergeHandle +{ + private final TableHandle tableHandle; + private final ConnectorMergeTableHandle connectorMergeHandle; + + @JsonCreator + public MergeHandle( + @JsonProperty("tableHandle") TableHandle tableHandle, + @JsonProperty("connectorMergeHandle") ConnectorMergeTableHandle connectorMergeHandle) + { + this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); + this.connectorMergeHandle = requireNonNull(connectorMergeHandle, "connectorMergeHandle is null"); + } + + @JsonProperty + public TableHandle getTableHandle() + { + return tableHandle; + } + + @JsonProperty + public ConnectorMergeTableHandle getConnectorMergeHandle() + { + return connectorMergeHandle; + } +} diff --git a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java index bb76cd1e54a8..c956728bcf66 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java +++ b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java @@ -13,6 +13,7 @@ */ package io.trino.metadata; +import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.trino.Session; import io.trino.connector.CatalogName; @@ -39,7 +40,9 @@ import io.trino.spi.connector.JoinType; import io.trino.spi.connector.LimitApplicationResult; import io.trino.spi.connector.MaterializedViewFreshness; +import io.trino.spi.connector.MergeDetails; import io.trino.spi.connector.ProjectionApplicationResult; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SampleType; import io.trino.spi.connector.SortItem; import io.trino.spi.connector.SystemTable; @@ -112,6 +115,16 @@ public interface Metadata */ Optional 
getCommonPartitioning(Session session, PartitioningHandle left, PartitioningHandle right); + /** + * Return the column handles for the columns that must be present in order + * to perform the partitioning and/or bucketing required. By default, the table + * has no such columns. + */ + default List getWriteRedistributionColumns(Session session, TableHandle table) + { + return ImmutableList.of(); + } + Optional getInfo(Session session, TableHandle handle); /** @@ -348,6 +361,29 @@ Optional finishRefreshMaterializedView( */ void finishUpdate(Session session, TableHandle tableHandle, Collection fragments); + /** + * Return the row update paradigm supported by the connector on the table or throw + * an exception if row change is not supported. + */ + RowChangeParadigm getRowChangeParadigm(Session session, TableHandle tableHandle); + + /** + * Get the column handle that will generate row IDs for the merge operation. + * These IDs will be passed to the {@code storeMergedRows()} method of the + * {@link io.trino.spi.connector.ConnectorMergeSink} that created them. + */ + ColumnHandle getMergeRowIdColumnHandle(Session session, TableHandle tableHandle, MergeDetails mergeDetails); + + /** + * Begin merge query + */ + MergeHandle beginMerge(Session session, TableHandle tableHandle, MergeDetails mergeDetails); + + /** + * Finish merge query + */ + void finishMerge(Session session, MergeHandle tableHandle, Collection fragments, Collection computedStatistics); + /** * Returns a connector id for the specified catalog name. 
*/ diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java index 78377e0b9483..30c8d5c2c386 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java @@ -57,6 +57,7 @@ import io.trino.spi.connector.ConnectorCapabilities; import io.trino.spi.connector.ConnectorInsertTableHandle; import io.trino.spi.connector.ConnectorMaterializedViewDefinition; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorMetadata; import io.trino.spi.connector.ConnectorOutputMetadata; import io.trino.spi.connector.ConnectorOutputTableHandle; @@ -80,7 +81,9 @@ import io.trino.spi.connector.JoinType; import io.trino.spi.connector.LimitApplicationResult; import io.trino.spi.connector.MaterializedViewFreshness; +import io.trino.spi.connector.MergeDetails; import io.trino.spi.connector.ProjectionApplicationResult; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SampleType; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; @@ -489,6 +492,15 @@ public Optional getCommonPartitioning(Session session, Parti return commonHandle.map(handle -> new PartitioningHandle(Optional.of(catalogName), left.getTransactionHandle(), handle)); } + @Override + public List getWriteRedistributionColumns(Session session, TableHandle table) + { + CatalogName catalogName = table.getCatalogName(); + CatalogMetadata catalogMetadata = getCatalogMetadata(session, catalogName); + ConnectorMetadata metadata = catalogMetadata.getMetadataFor(catalogName); + return metadata.getWriteRedistributionColumns(session.toConnectorSession(catalogName), table.getConnectorHandle()); + } + @Override public Optional getInfo(Session session, TableHandle handle) { @@ -936,6 +948,14 @@ public ColumnHandle 
getUpdateRowIdColumnHandle(Session session, TableHandle tabl return metadata.getUpdateRowIdColumnHandle(session.toConnectorSession(catalogName), tableHandle.getConnectorHandle(), updatedColumns); } + @Override + public ColumnHandle getMergeRowIdColumnHandle(Session session, TableHandle tableHandle, MergeDetails mergeDetails) + { + CatalogName catalogName = tableHandle.getCatalogName(); + ConnectorMetadata metadata = getMetadata(session, catalogName); + return metadata.getMergeRowIdColumnHandle(session.toConnectorSession(catalogName), tableHandle.getConnectorHandle(), mergeDetails); + } + @Override public boolean supportsMetadataDelete(Session session, TableHandle tableHandle) { @@ -1017,6 +1037,31 @@ public void finishUpdate(Session session, TableHandle tableHandle, Collection fragments, Collection computedStatistics) + { + CatalogName catalogName = mergeHandle.getTableHandle().getCatalogName(); + ConnectorMetadata metadata = getMetadata(session, catalogName); + metadata.finishMerge(session.toConnectorSession(catalogName), mergeHandle.getConnectorMergeHandle(), fragments, computedStatistics); + } + @Override public Optional getCatalogHandle(Session session, String catalogName) { diff --git a/core/trino-main/src/main/java/io/trino/metadata/TableHandle.java b/core/trino-main/src/main/java/io/trino/metadata/TableHandle.java index d6918c2e7c79..1f79ac42e82b 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/TableHandle.java +++ b/core/trino-main/src/main/java/io/trino/metadata/TableHandle.java @@ -78,6 +78,11 @@ public String toString() return catalogName + ":" + connectorHandle; } + public TableHandle withConnectorHandle(ConnectorTableHandle connectorHandle) + { + return new TableHandle(catalogName, connectorHandle, transaction, layout); + } + @Override public boolean equals(Object o) { diff --git a/core/trino-main/src/main/java/io/trino/operator/AbstractRowChangeOperator.java b/core/trino-main/src/main/java/io/trino/operator/AbstractRowChangeOperator.java 
index 33cb5f24d25e..e1e1f6a0c0da 100644 --- a/core/trino-main/src/main/java/io/trino/operator/AbstractRowChangeOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/AbstractRowChangeOperator.java @@ -49,7 +49,7 @@ protected enum State protected State state = State.RUNNING; protected long rowCount; private boolean closed; - private ListenableFuture> finishFuture; + protected ListenableFuture> finishFuture; private Supplier> pageSource = Optional::empty; public AbstractRowChangeOperator(OperatorContext operatorContext) @@ -149,7 +149,7 @@ public void setPageSource(Supplier> pageSource) protected UpdatablePageSource pageSource() { Optional source = pageSource.get(); - checkState(source.isPresent(), "UpdatablePageSource not set"); + checkState(source.isPresent(), "pageSource not set"); return source.get(); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java b/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java new file mode 100644 index 000000000000..473001ed7d15 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java @@ -0,0 +1,159 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.operator; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.RowChangeParadigm; +import io.trino.spi.type.Type; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static io.trino.spi.connector.RowChangeParadigm.CHANGE_ONLY_UPDATED_COLUMNS; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class ChangeOnlyUpdatedColumnsMergeProcessor + implements MergeRowChangeProcessor +{ + private final List dataColumns; + private final List dataColumnTypes; + private final List writeRedistributionColumns; + private final Type rowIdType; + private final List dataColumnChannels; + private final DuplicateRowFinder duplicateRowFinder; + + /** + * Creates a merge processor that reports the CHANGE_ONLY_UPDATED_COLUMNS paradigm. + * Each data column handle is paired with its index in {@code dataColumns}; transformPage + * uses that index to pick the column's block from the children of the merge case block. + * + * @throws NullPointerException if any argument is null + */ + @JsonCreator + public ChangeOnlyUpdatedColumnsMergeProcessor( + @JsonProperty("dataColumns") List dataColumns, + @JsonProperty("dataColumnTypes") List dataColumnTypes, + @JsonProperty("writeRedistributionColumns") List writeRedistributionColumns, + @JsonProperty("rowIdType") Type rowIdType) + { + /* Fail fast with a named message instead of a bare NPE later in the loop or in transformPage */ + this.dataColumns = requireNonNull(dataColumns, "dataColumns is null"); + this.dataColumnTypes = requireNonNull(dataColumnTypes, "dataColumnTypes is null"); + this.writeRedistributionColumns = requireNonNull(writeRedistributionColumns, "writeRedistributionColumns is null"); + this.rowIdType = requireNonNull(rowIdType, "rowIdType is null"); + int dataColumnChannel = 0; + List dataColumnChannelsBuilder = new ArrayList<>(); + for (ColumnHandle handle : dataColumns) { + dataColumnChannelsBuilder.add(new HandleAndChannel(handle, dataColumnChannel)); + dataColumnChannel++; + } + this.dataColumnChannels = Collections.unmodifiableList(dataColumnChannelsBuilder); + this.duplicateRowFinder = new DuplicateRowFinder(dataColumns, dataColumnTypes, writeRedistributionColumns, rowIdType); + } + + @Override + public RowChangeParadigm getRowChangeParadigm() + { + return CHANGE_ONLY_UPDATED_COLUMNS; + 
} + + @JsonProperty + public List getDataColumns() + { + return dataColumns; + } + + @JsonProperty + public List getDataColumnTypes() + { + return dataColumnTypes; + } + + @JsonProperty + public List getWriteRedistributionColumns() + { + return writeRedistributionColumns; + } + + @JsonProperty + public Type getRowIdType() + { + return rowIdType; + } + + /** + * Transform the input page, which contains the target table's write redistribution column + * blocks, the rowId block, and the merge case RowBlock. Each row in the output Page + * starts with all the data column blocks, including the partition column blocks, in + * table column order; followed by the "operation" block from the merge case RowBlock, + * whose values are {@link io.trino.spi.connector.MergeDetails#INSERT_OPERATION_NUMBER}, + * {@link io.trino.spi.connector.MergeDetails#DELETE_OPERATION_NUMBER}, or + * {@link io.trino.spi.connector.MergeDetails#UPDATE_OPERATION_NUMBER}; followed by the rowId block. + * @param inputPage The page to be transformed. + * @return A page containing all data column blocks, followed by the operation block and the rowId block. 
+ */ + @Override + public Page transformPage(Page inputPage) + { + requireNonNull(inputPage, "inputPage is null"); + int inputChannelCount = inputPage.getChannelCount(); + if (inputChannelCount != 2 + writeRedistributionColumns.size()) { + throw new IllegalArgumentException(format("inputPage channelCount (%s) should be = 2 + %s", inputChannelCount, writeRedistributionColumns.size())); + } + + int positionCount = inputPage.getPositionCount(); + if (positionCount <= 0) { + throw new IllegalArgumentException("positionCount should be > 0, but is " + positionCount); + } + + Block mergeCaseBlock = inputPage.getBlock(writeRedistributionColumns.size() + 1); + + List mergeCaseBlocks = mergeCaseBlock.getChildren(); + int mergeBlocksSize = mergeCaseBlocks.size(); + Block operationChannelBlock = mergeCaseBlocks.get(mergeBlocksSize - 1); + // The rowId block is the last block of the resulting page + Block rowIdBlock = inputPage.getBlock(writeRedistributionColumns.size()); + duplicateRowFinder.checkForDuplicateTargetRows(inputPage, operationChannelBlock); + + // Add the data columns + List builder = new ArrayList<>(); + + dataColumnChannels.forEach(handleAndChannel -> builder.add(mergeCaseBlocks.get(handleAndChannel.getChannel()))); + + builder.add(operationChannelBlock); + + builder.add(rowIdBlock); + return new Page(builder.toArray(new Block[]{})); + } + + private static class HandleAndChannel + { + private final ColumnHandle handle; + private final int channel; + + public HandleAndChannel(ColumnHandle handle, int channel) + { + this.handle = handle; + this.channel = channel; + } + + public ColumnHandle getHandle() + { + return handle; + } + + public int getChannel() + { + return channel; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertMergeProcessor.java b/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertMergeProcessor.java new file mode 100644 index 000000000000..5fd2927294c8 --- /dev/null +++ 
b/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertMergeProcessor.java @@ -0,0 +1,342 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.block.RowBlock; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.MergeDetails; +import io.trino.spi.connector.MergeProcessorUtilities; +import io.trino.spi.connector.RowChangeParadigm; +import io.trino.spi.type.Type; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static io.trino.spi.connector.MergeDetails.DELETE_OPERATION_NUMBER; +import static io.trino.spi.connector.MergeDetails.INSERT_OPERATION_NUMBER; +import static io.trino.spi.connector.MergeDetails.UPDATE_OPERATION_NUMBER; +import static io.trino.spi.connector.MergeProcessorUtilities.getPositionsForPredicate; +import static io.trino.spi.connector.MergeProcessorUtilities.getUnderlyingBlock; +import static io.trino.spi.connector.RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class 
DeleteAndInsertMergeProcessor + implements MergeRowChangeProcessor +{ + private final MergeDetails mergeDetails; + private final List dataColumns; + private final List dataColumnTypes; + private final List writeRedistributionColumns; + private final Type rowIdType; + private final List dataColumnChannels; + private final List writeRedistributionColumnChannels; + + private final DuplicateRowFinder duplicateRowFinder; + + /** + * Creates a merge processor for the delete-row-and-insert-row paradigm. + * Each data column handle is paired with its index in {@code dataColumns}; handles that also + * appear in {@code writeRedistributionColumns} are additionally recorded, with that same index, + * in the write redistribution channel list. + * + * @throws NullPointerException if any argument is null + */ + @JsonCreator + public DeleteAndInsertMergeProcessor( + @JsonProperty("mergeDetails") MergeDetails mergeDetails, + @JsonProperty("dataColumns") List dataColumns, + @JsonProperty("dataColumnTypes") List dataColumnTypes, + @JsonProperty("writeRedistributionColumns") List writeRedistributionColumns, + @JsonProperty("rowIdType") Type rowIdType) + { + this.mergeDetails = requireNonNull(mergeDetails, "mergeDetails is null"); + this.dataColumns = requireNonNull(dataColumns, "dataColumns is null"); + this.dataColumnTypes = requireNonNull(dataColumnTypes, "dataColumnTypes is null"); + this.rowIdType = requireNonNull(rowIdType, "rowIdType is null"); + this.writeRedistributionColumns = requireNonNull(writeRedistributionColumns, "writeRedistributionColumns is null"); + int dataColumnChannel = 0; + ImmutableList.Builder dataColumnChannelsBuilder = ImmutableList.builder(); + ImmutableList.Builder writeRedistributionChannelsBuilder = ImmutableList.builder(); + for (ColumnHandle handle : dataColumns) { + dataColumnChannelsBuilder.add(new HandleAndChannel(handle, dataColumnChannel)); + if (writeRedistributionColumns.contains(handle)) { + writeRedistributionChannelsBuilder.add(new HandleAndChannel(handle, dataColumnChannel)); + } + /* Advance only after both lists have recorded this column, so they agree on its channel */ + dataColumnChannel++; + } + this.dataColumnChannels = dataColumnChannelsBuilder.build(); + this.writeRedistributionColumnChannels = writeRedistributionChannelsBuilder.build(); + this.duplicateRowFinder = new DuplicateRowFinder(dataColumns, dataColumnTypes, writeRedistributionColumns, rowIdType); + } + + @Override + 
public RowChangeParadigm getRowChangeParadigm() + { + return DELETE_ROW_AND_INSERT_ROW; + } + + @JsonProperty + public MergeDetails getMergeDetails() + { + return mergeDetails; + } + + @JsonProperty + public List getDataColumns() + { + return dataColumns; + } + + @JsonProperty + public List getDataColumnTypes() + { + return dataColumnTypes; + } + + @JsonProperty + public List getWriteRedistributionColumns() + { + return writeRedistributionColumns; + } + + @JsonProperty + public Type getRowIdType() + { + return rowIdType; + } + + /** + * Transform the input page containing the target table's partition column blocks; the + * rowId block; and the merge case RowBlock into a page that duplicates rows for + * UPDATE_OPERATION_NUMBER operations. Each row in the outputPage starts with all the + * data column blocks, including the partition columns blocks, in declared column order; + * followed by an operation block with DELETE_OPERATION_NUMBER for delete rows or + * INSERT_OPERATION_NUMBER for insert rows; followed by the rowId block. + * {@link Block#getPositions(int[], int, int)} is used to do the duplication, so no value + * copying is required to project the columns into the page. + * @param inputPage A page containing write redistribution column blocks for the target table; the + * rowId block; the merge case RowBlock; and for partitioned or bucketed tables, a hash column block. + * @return A page containing all data columns, the operation block and the rowId block, + * with UPDATE rows expanded into delete and insert rows, so they can be routed separately. + * The delete rows contain the partition key blocks from the target table, whereas the insert + * rows have the partition key blocks from the merge case RowBlock. 
+ */ + @Override + public Page transformPage(Page inputPage) + { + requireNonNull(inputPage, "inputPage is null"); + int inputChannelCount = inputPage.getChannelCount(); + if (inputChannelCount < 2 + writeRedistributionColumnChannels.size()) { + throw new IllegalArgumentException(format("inputPage channelCount (%s) should be >= 2 + partition columns size (%s)", inputChannelCount, writeRedistributionColumnChannels.size())); + } + + int originalPositionCount = inputPage.getPositionCount(); + if (originalPositionCount <= 0) { + throw new IllegalArgumentException("originalPositionCount should be > 0, but is " + originalPositionCount); + } + Block mergeCaseBlock = inputPage.getBlock(writeRedistributionColumnChannels.size() + 1); + int[] operationPositions = getPositionsExcludingNotMatchedCases(mergeCaseBlock); + int positionCount = operationPositions.length; + + List mergeCaseBlocks = mergeCaseBlock.getChildren(); + int mergeBlocksSize = mergeCaseBlocks.size(); + Block operationChannelBlock = mergeCaseBlocks.get(mergeBlocksSize - 1); + Block rowIdBlock = inputPage.getBlock(writeRedistributionColumnChannels.size()); + duplicateRowFinder.checkForDuplicateTargetRows(inputPage, operationChannelBlock); + + Block restrictedOperationChannelBlock = operationChannelBlock.getPositions(operationPositions, 0, positionCount); + Block caseNumberBlock = mergeCaseBlocks.get(mergeBlocksSize - 2); + int[] deletePositions = getDeletePositions(positionCount, restrictedOperationChannelBlock); + int deleteCount = deletePositions.length; + int[] insertPositions = getInsertPositions(positionCount, restrictedOperationChannelBlock); + int insertCount = insertPositions.length; + int totalPositions = deleteCount + insertCount; + + // Create a position array with both the delete and insert positions, + // so that it duplicates the UPDATE rows + int[] deleteAndInsertPositions = new int[totalPositions]; + int[] operationBlockArray = new int[totalPositions]; + for (int offset = 0; offset < 
deleteCount; offset++) { + deleteAndInsertPositions[offset] = deletePositions[offset]; + operationBlockArray[offset] = DELETE_OPERATION_NUMBER; + } + for (int offset = 0; offset < insertCount; offset++) { + deleteAndInsertPositions[deleteCount + offset] = insertPositions[offset]; + operationBlockArray[deleteCount + offset] = INSERT_OPERATION_NUMBER; + } + + // Add the data columns + List builder = new ArrayList<>(); + + // Partition column blocks for the target table, required to route deletes, + // are at the top level of the inputPage, starting with channel 2 + int writeRedistributionColumnIndex = 0; + for (int index = 0; index < dataColumnChannels.size(); index++) { + HandleAndChannel handleAndChannel = dataColumnChannels.get(index); + Type columnType = dataColumnTypes.get(index); + ColumnHandle column = handleAndChannel.getHandle(); + int channel = handleAndChannel.getChannel(); + Block mergeBlock = mergeCaseBlocks.get(channel); + if (writeRedistributionColumns.contains(column)) { + Block block = inputPage.getBlock(writeRedistributionColumnIndex); + builder.add(buildWriteRedistributionColumnValues(columnType, deleteAndInsertPositions, deleteCount, mergeBlock, block)); + writeRedistributionColumnIndex++; + } + else { + DictionaryBlock cleanedBlock = new DictionaryBlock(mergeBlock, operationPositions); + Block deleteAndInsertBlock = ((DictionaryBlock) cleanedBlock.getPositions(deleteAndInsertPositions, 0, totalPositions)).compact(); + builder.add(deleteAndInsertBlock); + } + } + + // Add the operations block + builder.add(new IntArrayBlock(totalPositions, Optional.empty(), operationBlockArray)); + + Block newRowIdBlock; + Block underlyingBlock = getUnderlyingBlock(rowIdBlock); + if (rowIdBlock.allPositionsAreNull()) { + newRowIdBlock = MergeProcessorUtilities.getAllNullsRowIdBlock(rowIdBlock, underlyingBlock, totalPositions); + } + else { + int[] rowIdPositions = getRowIdPositions(operationChannelBlock, caseNumberBlock, totalPositions); + if (underlyingBlock 
instanceof RowBlock) { + List newRowIdChildrenBuilder = new ArrayList<>(); + rowIdBlock.getChildren().stream() + .map(block -> block.getPositions(rowIdPositions, 0, totalPositions)) + .forEach(newRowIdChildrenBuilder::add); + newRowIdBlock = RowBlock.fromFieldBlocks( + totalPositions, + Optional.empty(), + newRowIdChildrenBuilder.toArray(new Block[] {})); + } + else { + newRowIdBlock = rowIdBlock.getPositions(rowIdPositions, 0, totalPositions); + } + } + builder.add(newRowIdBlock); + + return new Page(builder.toArray(new Block[]{})); + } + + private Block buildWriteRedistributionColumnValues(Type columnType, int[] deleteAndInsertPositions, int deleteCount, Block mergeBlock, Block writeRedistributionColumnBlock) + { + int positionCount = deleteAndInsertPositions.length; + BlockBuilder builder = columnType.createBlockBuilder(null, positionCount, 0); + for (int position = 0; position < positionCount; position++) { + int dictionaryPosition = deleteAndInsertPositions[position]; + if (position < deleteCount) { + // The value comes from the target table's partition block + columnType.appendTo(writeRedistributionColumnBlock, dictionaryPosition, builder); + } + else { + // The value comes from the mergeCaseBlock + columnType.appendTo(mergeBlock, dictionaryPosition, builder); + } + } + return builder.build(); + } + + private static int[] getRowIdPositions(Block operationBlock, Block caseNumberBlock, int finalPositionCount) + { + int inputPositions = caseNumberBlock.getPositionCount(); + int[] positions = new int[finalPositionCount]; + int rowIdCursor = 0; + int positionCursor = 0; + for (int position = 0; position < inputPositions; position++) { + if (caseNumberBlock.getInt(position, 0) != -1) { + int operation = operationBlock.getInt(position, 0); + if (operation != INSERT_OPERATION_NUMBER) { + positions[positionCursor] = rowIdCursor; + rowIdCursor++; + positionCursor++; + } + } + } + for (int position = 0; position < caseNumberBlock.getPositionCount(); position++) { + if 
(caseNumberBlock.getInt(position, 0) != -1) { + int operation = operationBlock.getInt(position, 0); + if (operation != DELETE_OPERATION_NUMBER) { + positions[positionCursor] = 0; + positionCursor++; + } + } + } + if (positionCursor != finalPositionCount) { + throw new IllegalArgumentException(format("positionCursor (%s) is not equal to finalPositionCount (%s)", positionCursor, finalPositionCount)); + } + return positions; + } + + private static int[] getPositionsExcludingNotMatchedCases(Block mergeCaseBlock) + { + List mergeCaseBlocks = mergeCaseBlock.getChildren(); + int mergeBlocksSize = mergeCaseBlocks.size(); + Block caseNumberBlock = mergeCaseBlocks.get(mergeBlocksSize - 2); + int counter = 0; + for (int position = 0; position < caseNumberBlock.getPositionCount(); position++) { + if (caseNumberBlock.getInt(position, 0) != -1) { + counter++; + } + } + int cursor = 0; + int[] positions = new int[counter]; + for (int position = 0; position < caseNumberBlock.getPositionCount(); position++) { + if (caseNumberBlock.getInt(position, 0) != -1) { + positions[cursor] = position; + cursor++; + } + } + return positions; + } + + private static int[] getDeletePositions(int positionCount, Block operationBlock) + { + return getPositionsForPredicate(positionCount, position -> { + int operation = operationBlock.getInt(position, 0); + return operation == DELETE_OPERATION_NUMBER || operation == UPDATE_OPERATION_NUMBER; + }); + } + + private static int[] getInsertPositions(int positionCount, Block operationBlock) + { + return getPositionsForPredicate(positionCount, position -> { + int operation = operationBlock.getInt(position, 0); + return operation == INSERT_OPERATION_NUMBER || operation == UPDATE_OPERATION_NUMBER; + }); + } + + private static class HandleAndChannel + { + private final ColumnHandle handle; + private final int channel; + + public HandleAndChannel(ColumnHandle handle, int channel) + { + this.handle = handle; + this.channel = channel; + } + + public ColumnHandle 
getHandle() + { + return handle; + } + + public int getChannel() + { + return channel; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertOperator.java b/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertOperator.java new file mode 100644 index 000000000000..dbec62fd1bc6 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertOperator.java @@ -0,0 +1,117 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator; + +import io.trino.spi.Page; +import io.trino.sql.planner.plan.PlanNodeId; + +import static com.google.common.base.Preconditions.checkState; +import static io.trino.operator.BasicWorkProcessorOperatorAdapter.createAdapterOperatorFactory; +import static io.trino.operator.WorkProcessor.TransformationState.finished; +import static io.trino.operator.WorkProcessor.TransformationState.ofResult; +import static java.util.Objects.requireNonNull; + +/** + * This operator is used by operations like SQL MERGE and UPDATE to support connectors + * that represent a modification of a row as a DELETE plus an INSERT, and support a + * partition and/or bucket paradigm. NOTE: Not all + * {@link io.trino.spi.connector.RowChangeParadigm}s require + * separation of UPDATEs in to DELETEs and INSERTs. 
That is determined by the + * {@link RowChangeProcessor} + */ +public class DeleteAndInsertOperator + implements WorkProcessorOperator +{ + public static OperatorFactory createOperatorFactory( + int operatorId, + PlanNodeId planNodeId, + RowChangeProcessor rowChangeProcessor) + { + return createAdapterOperatorFactory(new Factory(operatorId, planNodeId, rowChangeProcessor)); + } + + public static class Factory + implements BasicWorkProcessorOperatorAdapter.BasicAdapterWorkProcessorOperatorFactory + { + private final int operatorId; + private final PlanNodeId planNodeId; + private final RowChangeProcessor rowChangeProcessor; + private boolean closed; + + public Factory(int operatorId, PlanNodeId planNodeId, RowChangeProcessor rowChangeProcessor) + { + this.operatorId = operatorId; + this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); + this.rowChangeProcessor = requireNonNull(rowChangeProcessor, "rowChangeProcessor is null"); + } + + @Override + public WorkProcessorOperator create(ProcessorContext processorContext, WorkProcessor sourcePages) + { + checkState(!closed, "Factory is already closed"); + return new DeleteAndInsertOperator(sourcePages, rowChangeProcessor); + } + + @Override + public int getOperatorId() + { + return operatorId; + } + + @Override + public PlanNodeId getPlanNodeId() + { + return planNodeId; + } + + @Override + public String getOperatorType() + { + return DeleteAndInsertOperator.class.getSimpleName(); + } + + @Override + public void close() + { + closed = true; + } + + @Override + public Factory duplicate() + { + return new Factory(operatorId, planNodeId, rowChangeProcessor); + } + } + + private final WorkProcessor pages; + + private DeleteAndInsertOperator( + WorkProcessor sourcePages, + RowChangeProcessor rowChangeProcessor) + { + pages = sourcePages + .transform(page -> { + if (page == null) { + return finished(); + } + return ofResult(rowChangeProcessor.transformPage(page)); + }); + } + + @Override + public WorkProcessor 
getOutputPages() + { + return pages; + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/DuplicateRowFinder.java b/core/trino-main/src/main/java/io/trino/operator/DuplicateRowFinder.java new file mode 100644 index 000000000000..6ec157180652 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/DuplicateRowFinder.java @@ -0,0 +1,91 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.Page; +import io.trino.spi.TrinoException; +import io.trino.spi.block.Block; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.type.Type; +import io.trino.type.BlockTypeOperators; +import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom; + +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.trino.spi.StandardErrorCode.MERGE_TARGET_ROW_MULTIPLE_MATCHES; +import static io.trino.spi.connector.MergeDetails.DEFAULT_CASE_OPERATION_NUMBER; +import static io.trino.spi.connector.MergeDetails.INSERT_OPERATION_NUMBER; + +public class DuplicateRowFinder +{ + private final List channels; + private final List channelPositionIsDistinctFrom; + private Page lastPage; + private int lastRowIdPosition = -1; + + public DuplicateRowFinder(List dataColumns, List dataColumnTypes, List writeRedistributionColumns, Type 
rowIdType) + { + // TODO: David said I ought to be able to inject this, but I don't see how + BlockTypeOperators blockTypeOperators = new BlockTypeOperators(); + Map typeMap = IntStream.range(0, dataColumns.size()) + .boxed() + .collect(toImmutableMap(dataColumns::get, dataColumnTypes::get)); + ImmutableList.Builder channelsBuilder = ImmutableList.builder(); + ImmutableList.Builder channelPositionIsDistinctFromBuilder = ImmutableList.builder(); + for (int channel = 0; channel < writeRedistributionColumns.size(); channel++) { + channelsBuilder.add(channel); + ColumnHandle handle = writeRedistributionColumns.get(channel); + channelPositionIsDistinctFromBuilder.add(blockTypeOperators.getDistinctFromOperator(typeMap.get(handle))); + } + channelsBuilder.add(writeRedistributionColumns.size()); + channelPositionIsDistinctFromBuilder.add(blockTypeOperators.getDistinctFromOperator(rowIdType)); + channels = channelsBuilder.build(); + channelPositionIsDistinctFrom = channelPositionIsDistinctFromBuilder.build(); + } + + /** + * This method looks for sequential duplicates in the target rowId block and redistribution column blocks. + * A sequential duplicate signals that multiple target table rows matched a source row, which violates the + * SQL MERGE spec. + * @param page The page + * @param operationBlock The operation block extracted from the MERGE case RowBlock. 
+ */ + public void checkForDuplicateTargetRows(Page page, Block operationBlock) + { + int positionCount = page.getPositionCount(); + for (int position = 0; position < positionCount; position++) { + int operation = operationBlock.getInt(position, 0); + if (operation != INSERT_OPERATION_NUMBER && operation != DEFAULT_CASE_OPERATION_NUMBER) { + if (lastPage != null) { + boolean isDistinct = false; + for (int channel = 0; channel < channels.size(); channel++) { + if (channelPositionIsDistinctFrom.get(channel).isDistinctFrom(lastPage.getBlock(channel), lastRowIdPosition, page.getBlock(channel), position)) { + isDistinct = true; + break; + } + } + if (!isDistinct) { + throw new TrinoException(MERGE_TARGET_ROW_MULTIPLE_MATCHES, "One MERGE target table row matched more than one source row"); + } + } + lastPage = page; + lastRowIdPosition = position; + } + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/MergeRowChangeProcessor.java b/core/trino-main/src/main/java/io/trino/operator/MergeRowChangeProcessor.java new file mode 100644 index 000000000000..d9ac248737ef --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/MergeRowChangeProcessor.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * A {@link RowChangeProcessor} for SQL MERGE. It adds no methods of its own; it
 * re-declares {@link #transformPage(Page)} to document the MERGE page layout.
 */
public interface MergeRowChangeProcessor
        extends RowChangeProcessor
{
    /**
     * Transform a page generated by an SQL MERGE operation into a page of data columns and
     * operations. The SQL MERGE inputPage consists of the write redistribution
     * column values; followed by the rowId block; followed by the merge case RowBlock;
     * followed by a hash column block for partitioned or bucketed tables in some connectors.
     * The result is a page starting with all non-hidden columns computed by the SQL MERGE
     * operation, in table column order, followed by the rowId column, followed by the "operation"
     * block whose values are {@link io.trino.spi.connector.MergeDetails#INSERT_OPERATION_NUMBER},
     * {@link io.trino.spi.connector.MergeDetails#DELETE_OPERATION_NUMBER}, or
     * {@link io.trino.spi.connector.MergeDetails#UPDATE_OPERATION_NUMBER}. In some implementations,
     * a row in the inputPage can give rise to multiple rows in the transformed page.
     *
     * @param inputPage the page produced by the SQL MERGE operation, laid out as described above
     * @return the transformed page, laid out as described above
     */
    @Override
    Page transformPage(Page inputPage);
}
/**
 * Transforms pages of row changes according to a {@link RowChangeParadigm}.
 *
 * <p>The Jackson annotations make the interface polymorphically serializable: the
 * concrete type is written to the {@code "@type"} property, using the name
 * {@code "deleteAndInsert"} for {@code DeleteAndInsertMergeProcessor} and
 * {@code "changeOnlyUpdated"} for {@code ChangeOnlyUpdatedColumnsMergeProcessor}.
 */
@JsonTypeInfo(
        use = JsonTypeInfo.Id.NAME,
        property = "@type")
@JsonSubTypes({
        @JsonSubTypes.Type(value = DeleteAndInsertMergeProcessor.class, name = "deleteAndInsert"),
        @JsonSubTypes.Type(value = ChangeOnlyUpdatedColumnsMergeProcessor.class, name = "changeOnlyUpdated"),
})
public interface RowChangeProcessor
{
    /**
     * @return the RowChangeParadigm used by the RowChangeProcessor.
     */
    RowChangeParadigm getRowChangeParadigm();

    /**
     * Transform the inputPage into an output page. The format of the inputPage depends
     * on the operation.
     *
     * @param inputPage the page to transform
     * @return the transformed page
     */
    Page transformPage(Page inputPage);
}
+ */ +package io.trino.operator; + +import io.trino.Session; +import io.trino.spi.Page; +import io.trino.spi.connector.ConnectorMergeSink; +import io.trino.split.PageSinkManager; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.TableWriterNode.MergeTarget; + +import static com.google.common.base.Preconditions.checkState; +import static io.airlift.concurrent.MoreFutures.toListenableFuture; +import static java.util.Objects.requireNonNull; + +public class SqlMergeOperator + extends AbstractRowChangeOperator +{ + public static class SqlMergeOperatorFactory + implements OperatorFactory + { + private final int operatorId; + private final PlanNodeId planNodeId; + private final PageSinkManager pageSinkManager; + private final MergeTarget target; + private final Session session; + private boolean closed; + + public SqlMergeOperatorFactory(int operatorId, PlanNodeId planNodeId, PageSinkManager pageSinkManager, MergeTarget target, Session session) + { + this.operatorId = operatorId; + this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); + this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null"); + this.target = requireNonNull(target, "target is null"); + this.session = requireNonNull(session, "session is null"); + } + + @Override + public Operator createOperator(DriverContext driverContext) + { + checkState(!closed, "Factory is already closed"); + OperatorContext context = driverContext.addOperatorContext(operatorId, planNodeId, SqlMergeOperator.class.getSimpleName()); + ConnectorMergeSink mergeSink = pageSinkManager.createMergeSink(session, target.getMergeHandle().get()); + return new SqlMergeOperator(context, mergeSink); + } + + @Override + public void noMoreOperators() + { + closed = true; + } + + @Override + public OperatorFactory duplicate() + { + return new SqlMergeOperatorFactory(operatorId, planNodeId, pageSinkManager, target, session); + } + } + + private final ConnectorMergeSink mergeSink; + + 
public SqlMergeOperator(OperatorContext operatorContext, ConnectorMergeSink mergeSink) + { + super(operatorContext); + this.mergeSink = requireNonNull(mergeSink, "mergeSink is null"); + } + + @Override + public void addInput(Page page) + { + requireNonNull(page, "page is null"); + checkState(state == State.RUNNING, "Operator is %s", state); + + mergeSink.storeMergedRows(page); + rowCount += page.getPositionCount(); + } + + @Override + public void finish() + { + if (state == State.RUNNING) { + state = State.FINISHING; + finishFuture = toListenableFuture(mergeSink.finish()); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/split/PageSinkManager.java b/core/trino-main/src/main/java/io/trino/split/PageSinkManager.java index 1c9f7d62f4e8..70a2ba0c1369 100644 --- a/core/trino-main/src/main/java/io/trino/split/PageSinkManager.java +++ b/core/trino-main/src/main/java/io/trino/split/PageSinkManager.java @@ -16,7 +16,10 @@ import io.trino.Session; import io.trino.connector.CatalogName; import io.trino.metadata.InsertTableHandle; +import io.trino.metadata.MergeHandle; import io.trino.metadata.OutputTableHandle; +import io.trino.metadata.TableHandle; +import io.trino.spi.connector.ConnectorMergeSink; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorPageSinkProvider; import io.trino.spi.connector.ConnectorSession; @@ -61,6 +64,15 @@ public ConnectorPageSink createPageSink(Session session, InsertTableHandle table return providerFor(tableHandle.getCatalogName()).createPageSink(tableHandle.getTransactionHandle(), connectorSession, tableHandle.getConnectorHandle()); } + @Override + public ConnectorMergeSink createMergeSink(Session session, MergeHandle mergeHandle) + { + // assumes connectorId and catalog are the same + TableHandle tableHandle = mergeHandle.getTableHandle(); + ConnectorSession connectorSession = session.toConnectorSession(tableHandle.getCatalogName()); + return 
/**
 * Creates the connector sinks used to write query results for a session.
 */
public interface PageSinkProvider
{
    /**
     * Used for CTAS
     *
     * @param session the session the sink is created for
     * @param tableHandle handle of the table being created
     */
    ConnectorPageSink createPageSink(Session session, OutputTableHandle tableHandle);

    /**
     * Used to insert into an existing table
     *
     * @param session the session the sink is created for
     * @param tableHandle handle of the table being inserted into
     */
    ConnectorPageSink createPageSink(Session session, InsertTableHandle tableHandle);

    /**
     * Used to write the result of SQL MERGE to an existing table
     *
     * @param session the session the sink is created for
     * @param mergeHandle handle describing the merge
     */
    ConnectorMergeSink createMergeSink(Session session, MergeHandle mergeHandle);
}
io.trino.spi.eventlistener.TableInfo; @@ -186,6 +187,7 @@ public class Analysis private Optional refreshMaterializedView = Optional.empty(); private Optional analyzeTarget = Optional.empty(); private Optional> updatedColumns = Optional.empty(); + private Optional mergeAnalysis = Optional.empty(); // for describe input and describe output private final boolean isDescribe; @@ -235,7 +237,7 @@ public void resetUpdateType() public boolean isUpdateTarget(Table table) { - return ("DELETE".equals(updateType) || "UPDATE".equals(updateType)) && + return ("DELETE".equals(updateType) || "UPDATE".equals(updateType) || "MERGE".equals(updateType)) && target.orElseThrow(() -> new IllegalStateException("Update target not set")) .getTable().orElseThrow(() -> new IllegalStateException("Table reference not set in update target")) == table; // intentional comparison by reference } @@ -712,6 +714,16 @@ public void setUpdatedColumns(List updatedColumns) this.updatedColumns = Optional.of(updatedColumns); } + public Optional getMergeAnalysis() + { + return mergeAnalysis; + } + + public void setMergeAnalysis(MergeAnalysis mergeAnalysis) + { + this.mergeAnalysis = Optional.of(mergeAnalysis); + } + public Optional> getUpdatedColumns() { return updatedColumns; @@ -1308,6 +1320,73 @@ public boolean isFrameInherited() } } + // All string column names have been translated from the user's spelling of + // column names to the column names from the table column metadata, so identification + // of columns by string column name is consistent. 
/**
 * Immutable holder for the results of analyzing a SQL MERGE statement.
 *
 * <p>All constructor arguments must be non-null; each is stored as passed, without
 * copying, and returned unchanged by its getter.
 * NOTE(review): the element types of the maps, list and optionals are not visible
 * here. Presumably the maps go from column name to column type - confirm.
 */
public static class MergeAnalysis
{
    // The target table of the MERGE
    private final Table targetTable;
    private final MergeDetails mergeDetails;
    private final Map allColumnTypes;
    private final Map allUpdatedColumnTypes;
    private final List writeRedistributionColumnNames;
    private final Optional newTableLayout;
    private final Optional finalQuery;

    /**
     * @throws NullPointerException if any argument is null
     */
    public MergeAnalysis(
            Table targetTable,
            MergeDetails mergeDetails,
            Map allColumnTypes,
            Map allUpdatedColumnTypes,
            List writeRedistributionColumnNames,
            Optional newTableLayout,
            Optional finalQuery)
    {
        this.targetTable = requireNonNull(targetTable, "targetTable is null");
        this.mergeDetails = requireNonNull(mergeDetails, "mergeDetails is null");
        this.allColumnTypes = requireNonNull(allColumnTypes, "allColumnTypes is null");
        this.allUpdatedColumnTypes = requireNonNull(allUpdatedColumnTypes, "allUpdatedColumnTypes is null");
        this.writeRedistributionColumnNames = requireNonNull(writeRedistributionColumnNames, "writeRedistributionColumnNames is null");
        this.newTableLayout = requireNonNull(newTableLayout, "newTableLayout is null");
        this.finalQuery = requireNonNull(finalQuery, "finalQuery is null");
    }

    public Table getTargetTable()
    {
        return targetTable;
    }

    public MergeDetails getMergeDetails()
    {
        return mergeDetails;
    }

    public Map getAllColumnTypes()
    {
        return allColumnTypes;
    }

    public Map getAllUpdatedColumnTypes()
    {
        return allUpdatedColumnTypes;
    }

    public List getWriteRedistributionColumnNames()
    {
        return writeRedistributionColumnNames;
    }

    public Optional getNewTableLayout()
    {
        return newTableLayout;
    }

    public Optional getFinalQuery()
    {
        return finalQuery;
    }
}
100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/CanonicalizationAware.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/CanonicalizationAware.java @@ -99,4 +99,9 @@ public static String canonicalize(Identifier identifier) return identifier.getValue().toUpperCase(ENGLISH); } + + public static String canonicalize(String name) + { + return name.toUpperCase(ENGLISH); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 07fc30b832b6..2057f08e7f91 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -47,6 +47,9 @@ import io.trino.spi.connector.ConnectorTableMetadata; import io.trino.spi.connector.ConnectorViewDefinition; import io.trino.spi.connector.ConnectorViewDefinition.ViewColumn; +import io.trino.spi.connector.MergeCaseDetails; +import io.trino.spi.connector.MergeCaseKind; +import io.trino.spi.connector.MergeDetails; import io.trino.spi.function.OperatorType; import io.trino.spi.security.AccessDeniedException; import io.trino.spi.security.GroupProvider; @@ -62,6 +65,7 @@ import io.trino.sql.InterpretedFunctionInvoker; import io.trino.sql.SqlPath; import io.trino.sql.analyzer.Analysis.GroupingSetAnalysis; +import io.trino.sql.analyzer.Analysis.MergeAnalysis; import io.trino.sql.analyzer.Analysis.ResolvedWindow; import io.trino.sql.analyzer.Analysis.SelectExpression; import io.trino.sql.analyzer.Analysis.UnnestAnalysis; @@ -79,7 +83,9 @@ import io.trino.sql.tree.AllRows; import io.trino.sql.tree.Analyze; import io.trino.sql.tree.AstVisitor; +import io.trino.sql.tree.BooleanLiteral; import io.trino.sql.tree.Call; +import io.trino.sql.tree.Cast; import io.trino.sql.tree.Comment; import io.trino.sql.tree.Commit; import io.trino.sql.tree.CreateMaterializedView; @@ -113,17 +119,25 @@ import 
io.trino.sql.tree.Identifier; import io.trino.sql.tree.Insert; import io.trino.sql.tree.Intersect; +import io.trino.sql.tree.IsNotNullPredicate; +import io.trino.sql.tree.IsNullPredicate; import io.trino.sql.tree.Join; import io.trino.sql.tree.JoinCriteria; import io.trino.sql.tree.JoinOn; import io.trino.sql.tree.JoinUsing; import io.trino.sql.tree.Lateral; import io.trino.sql.tree.Limit; +import io.trino.sql.tree.LogicalBinaryExpression; import io.trino.sql.tree.LongLiteral; import io.trino.sql.tree.Merge; +import io.trino.sql.tree.MergeCase; +import io.trino.sql.tree.MergeDelete; +import io.trino.sql.tree.MergeInsert; +import io.trino.sql.tree.MergeUpdate; import io.trino.sql.tree.NaturalJoin; import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeRef; +import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.Offset; import io.trino.sql.tree.OrderBy; import io.trino.sql.tree.Parameter; @@ -144,6 +158,7 @@ import io.trino.sql.tree.Rollup; import io.trino.sql.tree.Row; import io.trino.sql.tree.SampledRelation; +import io.trino.sql.tree.SearchedCaseExpression; import io.trino.sql.tree.Select; import io.trino.sql.tree.SelectItem; import io.trino.sql.tree.SetOperation; @@ -166,6 +181,7 @@ import io.trino.sql.tree.UpdateAssignment; import io.trino.sql.tree.Use; import io.trino.sql.tree.Values; +import io.trino.sql.tree.WhenClause; import io.trino.sql.tree.Window; import io.trino.sql.tree.WindowDefinition; import io.trino.sql.tree.WindowFrame; @@ -184,12 +200,10 @@ import java.util.Optional; import java.util.OptionalLong; import java.util.Set; -import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; @@ -241,10 +255,12 @@ import 
static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; import static io.trino.spi.StandardErrorCode.VIEW_IS_RECURSIVE; import static io.trino.spi.StandardErrorCode.VIEW_IS_STALE; +import static io.trino.spi.connector.MergeDetails.DEFAULT_CASE_OPERATION_NUMBER; import static io.trino.spi.connector.StandardWarningCode.REDUNDANT_ORDER_BY; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.NodeUtils.getSortItemsFromOrderBy; import static io.trino.sql.NodeUtils.mapFromProperties; @@ -264,6 +280,7 @@ import static io.trino.sql.analyzer.ScopeReferenceExtractor.getReferencesToScope; import static io.trino.sql.analyzer.SemanticExceptions.semanticException; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL; import static io.trino.sql.tree.ExplainType.Type.DISTRIBUTED; import static io.trino.sql.tree.Join.Type.FULL; @@ -278,6 +295,7 @@ import static java.util.Collections.emptyList; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; class StatementAnalyzer { @@ -325,16 +343,25 @@ public Scope analyze(Node node, Optional outerQueryScope) .process(node, Optional.empty()); } - public Scope analyzeForUpdate(Table table, Optional outerQueryScope, UpdateKind updateKind) + public Scope analyzeForUpdate(Relation relation, Optional outerQueryScope, UpdateKind updateKind) { return new Visitor(outerQueryScope, warningCollector, Optional.of(updateKind)) - .process(table, Optional.empty()); + .process(relation, Optional.empty()); } private enum UpdateKind { DELETE, - UPDATE; + UPDATE, + MERGE, + /**/; + } + + private 
static void checkArgument(boolean test, String message, Object... args) + { + if (!test) { + throw new IllegalArgumentException(format(message, args)); + } } /** @@ -1292,7 +1319,22 @@ protected Scope visitTable(Table table, Optional scope) analysis.setColumn(field, columnHandle); } + boolean addRowIdColumn = false; if (updateKind.isPresent()) { + UpdateKind kind = updateKind.get(); + if (kind == UpdateKind.MERGE) { + checkArgument(analysis.getMergeAnalysis().isPresent(), "analysis.getMergeAnalysis() isn't present"); + boolean isMergeTarget = table.shallowEquals(analysis.getMergeAnalysis().get().getTargetTable()); + if (isMergeTarget) { + addRowIdColumn = true; + } + } + else { + addRowIdColumn = true; + } + } + + if (addRowIdColumn) { // Add the row id field ColumnHandle rowIdColumnHandle; switch (updateKind.get()) { @@ -1309,6 +1351,11 @@ protected Scope visitTable(Table table, Optional scope) .collect(toImmutableList()); rowIdColumnHandle = metadata.getUpdateRowIdColumnHandle(session, tableHandle.get(), updatedColumns); break; + case MERGE: + Optional mergeAnalysis = analysis.getMergeAnalysis(); + checkArgument(mergeAnalysis.isPresent(), "mergeAnalysis isn't present"); + rowIdColumnHandle = metadata.getMergeRowIdColumnHandle(session, tableHandle.get(), mergeAnalysis.get().getMergeDetails()); + break; default: throw new UnsupportedOperationException("Unknown UpdateKind " + updateKind.get()); } @@ -1325,7 +1372,7 @@ protected Scope visitTable(Table table, Optional scope) Scope tableScope = createAndAssignScope(table, scope, outputFields); - if (updateKind.isPresent()) { + if (addRowIdColumn) { FieldReference reference = new FieldReference(outputFields.size() - 1); analyzeExpression(reference, tableScope); analysis.setRowIdField(table, reference); @@ -1823,6 +1870,7 @@ else if (node.getType() == FULL) { if (node.getType() == Join.Type.CROSS || node.getType() == Join.Type.IMPLICIT) { return output; } + checkArgument(criteria instanceof JoinOn, "criteria isn't an 
instance of JoinOn, but instead %s", criteria); if (criteria instanceof JoinOn) { Expression expression = ((JoinOn) criteria).getExpression(); verifyNoAggregateWindowOrGroupingFunctions(metadata, expression, "JOIN clause"); @@ -1835,7 +1883,7 @@ else if (node.getType() == FULL) { if (!clauseType.equals(UNKNOWN)) { throw semanticException(TYPE_MISMATCH, expression, "JOIN ON clause must evaluate to a boolean: actual type %s", clauseType); } - // coerce null to boolean + // coerce expression to boolean analysis.addCoercion(expression, BOOLEAN, false); } @@ -1865,7 +1913,7 @@ protected Scope visitUpdate(Update update, Optional scope) List allColumns = tableMetadata.getColumns(); Map columns = allColumns.stream() - .collect(toImmutableMap(ColumnMetadata::getName, Function.identity())); + .collect(toImmutableMap(ColumnMetadata::getName, identity())); for (UpdateAssignment assignment : update.getAssignments()) { String columnName = assignment.getName().getValue(); @@ -1958,7 +2006,325 @@ protected Scope visitUpdate(Update update, Optional scope) @Override protected Scope visitMerge(Merge merge, Optional scope) { - throw new TrinoException(NOT_SUPPORTED, "This connector does not support merge"); + Table table = merge.getTable(); + QualifiedObjectName tableName = createQualifiedObjectName(session, table, table.getName()); + if (metadata.getView(session, tableName).isPresent()) { + throw semanticException(NOT_SUPPORTED, merge, "Merging through views is not supported"); + } + + TableHandle handle = metadata.getTableHandle(session, tableName) + .orElseThrow(() -> semanticException(TABLE_NOT_FOUND, table, "Table '%s' does not exist", tableName)); + + // Strip out the "hidden" columns + TableMetadata tableMetadata = metadata.getTableMetadata(session, handle); + List allColumns = tableMetadata.getColumns().stream() + .filter(column -> !column.isHidden()) + .collect(toImmutableList()); + + // All identifiers for columns in the target table are mapped to the table column names, + 
// through this canonicalization map + Map canonicalNameToTableColumnName = allColumns.stream() + .map(ColumnMetadata::getName) + .collect(toImmutableMap(column -> canonicalize(column), identity())); + + // Create MergeDetails + ImmutableList.Builder caseDetailsBuilder = ImmutableList.builder(); + ImmutableMap.Builder> mergeCaseColumnsListsBuilder = ImmutableMap.builder(); + ImmutableSet.Builder allColumnNamesBuilder = ImmutableSet.builder(); + int caseCounter = 0; + for (MergeCase operation : merge.getMergeCases()) { + List mergeColumnNames = translateToTableColumnNames(operation.getSetColumns(), canonicalNameToTableColumnName); + allColumnNamesBuilder.addAll(mergeColumnNames); + mergeCaseColumnsListsBuilder.put(caseCounter, mergeColumnNames); + if (operation instanceof MergeInsert) { + caseDetailsBuilder.add(new MergeCaseDetails( + caseCounter, + MergeCaseKind.INSERT, + ImmutableSet.copyOf(mergeColumnNames))); + } + else if (operation instanceof MergeUpdate) { + caseDetailsBuilder.add(new MergeCaseDetails(caseCounter, MergeCaseKind.UPDATE, ImmutableSet.copyOf(mergeColumnNames))); + } + else if (operation instanceof MergeDelete) { + caseDetailsBuilder.add(new MergeCaseDetails(caseCounter, MergeCaseKind.DELETE, ImmutableSet.of())); + } + else { + throw new IllegalArgumentException(format("Unknown MergeOperation %s of class %s", operation, operation.getClass().getName())); + } + caseCounter++; + } + List mergeCases = caseDetailsBuilder.build(); + Map> mergeCaseColumnsLists = mergeCaseColumnsListsBuilder.build(); + Set allColumnsUpdated = allColumnNamesBuilder.build(); + MergeDetails mergeDetails = new MergeDetails(mergeCases); + + // Build the required mappings between table column name and column type + List allColumnNames = allColumns.stream().map(ColumnMetadata::getName).collect(toImmutableList()); + + List columnTypes = allColumns.stream() + .map(metadata -> new NameAndType(metadata.getName(), metadata.getType())) + .collect(toImmutableList()); + Map 
allUpdatedColumnTypes = createColumnTypeMap(columnTypes, allColumnsUpdated); + Map allColumnTypes = createColumnTypeMap(columnTypes, allColumnNames); + + // Create the RowType that holds all column values + ImmutableList.Builder fieldsBuilder = ImmutableList.builder(); + allColumnTypes.forEach((columnName, type) -> fieldsBuilder.add(new RowType.Field(Optional.of(columnName), type))); + // Add the case number and the operation number + fieldsBuilder.add(new RowType.Field(Optional.empty(), INTEGER)); + fieldsBuilder.add(new RowType.Field(Optional.empty(), INTEGER)); + List updatedColumnFields = fieldsBuilder.build(); + RowType rowType = RowType.from(updatedColumnFields); + + // Perform legality and access control checks + for (MergeCaseDetails mergeCase : mergeDetails.getCases()) { + switch (mergeCase.getCaseKind()) { + case INSERT: + accessControl.checkCanInsertIntoTable(session.toSecurityContext(), tableName); + break; + case DELETE: + accessControl.checkCanDeleteFromTable(session.toSecurityContext(), tableName); + break; + case UPDATE: + accessControl.checkCanUpdateTableColumns(session.toSecurityContext(), tableName, mergeCase.getUpdatedColumns()); + break; + default: + throw new IllegalStateException("Unknown MergeCaseKind " + mergeCase.getCaseKind()); + } + } + + List updatedColumns = allColumns.stream() + .filter(column -> allColumnsUpdated.contains(column.getName())) + .collect(toImmutableList()); + + if (!accessControl.getRowFilters(session.toSecurityContext(), tableName).isEmpty()) { + throw semanticException(NOT_SUPPORTED, merge, "Merge table with row filter"); + } + + for (ColumnMetadata tableColumn : allColumns) { + if (!accessControl.getColumnMasks(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isEmpty()) { + throw semanticException(NOT_SUPPORTED, merge, "Merge table with column mask"); + } + } + + Map columnHandles = metadata.getColumnHandles(session, handle); + List redistributionColumns = 
metadata.getWriteRedistributionColumns(session, handle); + List writeRedistributionColumnNames = columnHandles.entrySet().stream() + .filter(entry -> redistributionColumns.contains(entry.getValue())) + .map(entry -> entry.getKey()) + .collect(toImmutableList()); + + analysis.setUpdateType("MERGE", tableName, Optional.of(table), Optional.empty()); + analysis.setUpdatedColumns(updatedColumns); + Optional newTableLayout = metadata.getInsertLayout(session, handle); + // Save MergeAnalysis before the query is generated, because visitTable() needs the data. + analysis.setMergeAnalysis(new MergeAnalysis(table, mergeDetails, allColumnTypes, allUpdatedColumnTypes, writeRedistributionColumnNames, newTableLayout, Optional.empty())); + + // The identifier for the "matched" boolean. + // TODO: Can this collide with user names? How do I make sure it doesn't? + Identifier matchedIdentifier = new Identifier("$row_matched"); + + // TODO: Can this collide with user names? How do I make sure it doesn't? + Identifier targetAlias = merge.getTargetAlias().orElse(new Identifier("$target_alias")); + + // Build the rows of the SearchedCaseExpression + ImmutableList.Builder rowsBuilder = ImmutableList.builder(); + for (int caseNumber = 0; caseNumber < mergeCases.size(); caseNumber++) { + MergeCase mergeCase = merge.getMergeCases().get(caseNumber); + MergeCaseKind mergeKind = getMergeCaseKind(mergeCase); + List caseExpressions = mergeCase.getSetExpressions(); + ImmutableList.Builder rowExpressions = ImmutableList.builder(); + + // Add the updated columns, filling in nulls where this MERGE case didn't assign a value to the column + for (Map.Entry entry : allColumnTypes.entrySet()) { + String column = entry.getKey(); + int index = mergeCaseColumnsLists.get(caseNumber).indexOf(column); + if (index >= 0) { + rowExpressions.add(caseExpressions.get(index)); + } + else { + rowExpressions.add(new DereferenceExpression(targetAlias, new Identifier(column))); + } + } + + // Build the match condition 
for the MERGE case + boolean matched = mergeKind != MergeCaseKind.INSERT; + Expression caseMatchedExpression = matched ? new IsNotNullPredicate(matchedIdentifier) : new IsNullPredicate(matchedIdentifier); + Expression casePredicate = mergeCase.getExpression().isPresent() ? + new LogicalBinaryExpression(LogicalBinaryExpression.Operator.AND, caseMatchedExpression, mergeCase.getExpression().get()) : + caseMatchedExpression; + + // Add the caseNumber and the operation number + rowExpressions.add(new LongLiteral(String.valueOf(caseNumber))); + rowExpressions.add(new LongLiteral(String.valueOf(mergeKind.getOperationNumber()))); + + rowsBuilder.add(new WhenClause(casePredicate, new Row(rowExpressions.build()))); + } + + List rows = rowsBuilder.build(); + ImmutableList.Builder caseDefaultRow = ImmutableList.builder(); + + // The default value if no WHEN clause matched - - Connectors must ignore default value rows. + allColumnTypes.values().forEach(type -> caseDefaultRow.add(new Cast(new NullLiteral(), toSqlType(type)))); + caseDefaultRow.add(new LongLiteral(String.valueOf(DEFAULT_CASE_OPERATION_NUMBER))); + caseDefaultRow.add(new LongLiteral(String.valueOf(-1))); + + SearchedCaseExpression searchedCaseExpression = new SearchedCaseExpression(rows, Optional.of(new Row(caseDefaultRow.build()))); + analysis.addCoercion(searchedCaseExpression, rowType, true); + + // Analyzer checks for select permissions but MERGE has a separate permission, so disable access checks + StatementAnalyzer analyzer = new StatementAnalyzer( + analysis, + metadata, + sqlParser, + groupProvider, + new AllowAllAccessControl(), + session, + warningCollector, + CorrelationSupport.ALLOWED); + + // Analyze the table to get the rowId column + new Visitor(outerQueryScope, warningCollector, Optional.of(UpdateKind.MERGE)) + .process(table, Optional.empty()); + + // Create the list of SelectItems containing all columns from the table, and the "matched" flag + Select select = new Select( + false, + 
ImmutableList.of( + new AllColumns(), + new SingleColumn(new BooleanLiteral("TRUE"), Optional.of(matchedIdentifier)))); + + QuerySpecification query = new QuerySpecification(select, Optional.of(merge.getTable()), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableList.of(), Optional.empty(), Optional.empty(), Optional.empty()); + + AliasedRelation aliasedRelation = new AliasedRelation(query, targetAlias, null); + + Join join = new Join(merge.getLocation(), Join.Type.RIGHT, aliasedRelation, merge.getRelation(), Optional.of(new JoinOn(merge.getExpression()))); + + // Build the outer select - - the merge case RowBlock; the rowId RowBlock, and the partition key columns from the target table if they exist + ImmutableList.Builder itemsBuilder = ImmutableList.builder(); + itemsBuilder.add(new SingleColumn(searchedCaseExpression), new SingleColumn(new FieldReference(allColumnNames.size()))); + + writeRedistributionColumnNames.forEach(column -> itemsBuilder.add(new SingleColumn(new DereferenceExpression(targetAlias, new Identifier(column))))); + Select outerSelect = new Select(false, itemsBuilder.build()); + QuerySpecification finalQuery = new QuerySpecification(outerSelect, Optional.of(join), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableList.of(), Optional.empty(), Optional.empty(), Optional.empty()); + + analyzer.analyzeForUpdate(finalQuery, scope, UpdateKind.MERGE); + + // In WHEN xxx AND condition phrases, cast the condition to BOOLEAN if it isn't already + + for (MergeCase operation : merge.getMergeCases()) { + if (operation.getExpression().isPresent()) { + Expression predicate = operation.getExpression().get(); + Type predicateType = analysis.getType(predicate); + if (!predicateType.equals(BOOLEAN)) { + if (!predicateType.equals(UNKNOWN)) { + throw semanticException(TYPE_MISMATCH, predicate, "WHERE clause must evaluate to a boolean: actual type %s", predicateType); + } + // Coerce the predicate to boolean + analysis.addCoercion(predicate, 
BOOLEAN, false); + } + } + } + + // Add any necessary casts of column expressions to column type, and throw an exception + // if the expression type cannot be cast to the column type + for (int caseIndex = 0; caseIndex < merge.getMergeCases().size(); caseIndex++) { + MergeCase mergeCase = merge.getMergeCases().get(caseIndex); + if (mergeCase instanceof MergeDelete) { + continue; + } + ImmutableList.Builder setColumnTypesBuilder = ImmutableList.builder(); + ImmutableList.Builder setExpressionTypesBuilder = ImmutableList.builder(); + List columnList = mergeCaseColumnsLists.get(caseIndex); + allColumnTypes.forEach((name, type) -> { + int index = columnList.indexOf(name); + if (index >= 0) { + setColumnTypesBuilder.add(type); + Expression expression = requireNonNull(mergeCase.getSetExpressions().get(index), "merge set expression is null"); + Type expressionType = analysis.getType(expression); + setExpressionTypesBuilder.add(expressionType); + } + }); + List setColumnTypes = setColumnTypesBuilder.build(); + List setExpressionTypes = setExpressionTypesBuilder.build(); + if (!typesMatchForInsert(setColumnTypes, setExpressionTypes)) { + throw semanticException(TYPE_MISMATCH, + mergeCase, + "MERGE table column types don't match for MERGE case %s, SET expressions: Table: [%s], Expressions: [%s]", + caseIndex, + Joiner.on(", ").join(setColumnTypes), + Joiner.on(", ").join(setExpressionTypes)); + } + for (int index = 0; index < setColumnTypes.size(); index++) { + Expression expression = mergeCase.getSetExpressions().get(index); + Type targetType = setColumnTypes.get(index); + Type expressionType = setExpressionTypes.get(index); + if (!targetType.equals(expressionType)) { + analysis.addCoercion(expression, targetType, typeCoercion.isTypeOnlyCoercion(expressionType, targetType)); + } + } + } + + // The final version of MergeAnalysis, with the finalQuery + MergeAnalysis mergeAnalysis = new MergeAnalysis(table, mergeDetails, allColumnTypes, allUpdatedColumnTypes, 
writeRedistributionColumnNames, newTableLayout, Optional.of(finalQuery)); + analysis.setMergeAnalysis(mergeAnalysis); + return createAndAssignScope(merge, scope, Field.newUnqualified("rows", BIGINT)); + } + + private List translateToTableColumnNames(Collection identifiers, Map canonicalNameToTableColumnName) + { + return identifiers.stream() + .map(identifier -> translateToTableColumnName(identifier, canonicalNameToTableColumnName)) + .collect(toImmutableList()); + } + + private String translateToTableColumnName(Identifier identifier, Map canonicalNameToTableColumnName) + { + return requireNonNull(canonicalNameToTableColumnName.get(canonicalize(identifier.getValue())), "tableColumnName for identifier is null"); + } + + private MergeCaseKind getMergeCaseKind(MergeCase mergeCase) + { + requireNonNull(mergeCase, "mergeCase is null"); + if (mergeCase instanceof MergeInsert) { + return MergeCaseKind.INSERT; + } + if (mergeCase instanceof MergeUpdate) { + return MergeCaseKind.UPDATE; + } + if (mergeCase instanceof MergeDelete) { + return MergeCaseKind.DELETE; + } + throw new IllegalArgumentException("Unrecognized MergeCase " + mergeCase.getClass()); + } + + private class NameAndType + { + private final String name; + private final Type type; + + public NameAndType(String name, Type type) + { + this.name = name; + this.type = type; + } + + public String getName() + { + return name; + } + + public Type getType() + { + return type; + } + } + + private Map createColumnTypeMap(List allColumnTypes, Collection selectedColumns) + { + return allColumnTypes.stream() + .filter(nameAndType -> selectedColumns.contains(nameAndType.getName())) + .collect(toImmutableMap(NameAndType::getName, NameAndType::getType)); } private Scope analyzeJoinUsing(Join node, List columns, Optional scope, Scope left, Scope right) @@ -2792,6 +3158,14 @@ private void analyzeWhere(Node node, Scope scope, Expression predicate) analysis.addCoercion(predicate, BOOLEAN, false); } + if 
(!predicateType.equals(BOOLEAN)) { + if (!predicateType.equals(UNKNOWN)) { + throw semanticException(TYPE_MISMATCH, predicate, "WHERE clause must evaluate to a boolean: actual type %s", predicateType); + } + // coerce null to boolean + analysis.addCoercion(predicate, BOOLEAN, false); + } + analysis.setWhere(node, predicate); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/DistributedExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/DistributedExecutionPlanner.java index 0ca6bc3844d0..6619e87ab961 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/DistributedExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/DistributedExecutionPlanner.java @@ -30,6 +30,7 @@ import io.trino.sql.DynamicFilters; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AssignUniqueId; +import io.trino.sql.planner.plan.DeleteAndInsertNode; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.EnforceSingleRowNode; @@ -41,6 +42,7 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanNodeId; @@ -416,6 +418,18 @@ public Map visitUpdate(UpdateNode node, Void context) return node.getSource().accept(this, context); } + @Override + public Map visitMerge(MergeNode node, Void context) + { + return node.getSource().accept(this, context); + } + + @Override + public Map visitDeleteAndInsert(DeleteAndInsertNode node, Void context) + { + return node.getSource().accept(this, context); + } + @Override public Map visitTableDelete(TableDeleteNode node, Void context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java 
b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 767b7407c971..46add9374999 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -39,11 +39,13 @@ import io.trino.execution.buffer.OutputBuffer; import io.trino.execution.buffer.PagesSerdeFactory; import io.trino.index.IndexManager; +import io.trino.metadata.MergeHandle; import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TableHandle; import io.trino.operator.AggregationOperator.AggregationOperatorFactory; import io.trino.operator.AssignUniqueIdOperator; +import io.trino.operator.DeleteAndInsertOperator; import io.trino.operator.DeleteOperator.DeleteOperatorFactory; import io.trino.operator.DevNullOperator.DevNullOperatorFactory; import io.trino.operator.DriverFactory; @@ -87,6 +89,7 @@ import io.trino.operator.SpatialIndexBuilderOperator.SpatialIndexBuilderOperatorFactory; import io.trino.operator.SpatialIndexBuilderOperator.SpatialPredicate; import io.trino.operator.SpatialJoinOperator.SpatialJoinOperatorFactory; +import io.trino.operator.SqlMergeOperator.SqlMergeOperatorFactory; import io.trino.operator.StageExecutionDescriptor; import io.trino.operator.StatisticsWriterOperator.StatisticsWriterOperatorFactory; import io.trino.operator.StreamingAggregationOperator; @@ -153,6 +156,7 @@ import io.trino.sql.planner.plan.AggregationNode.Step; import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DeleteAndInsertNode; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.DynamicFilterId; @@ -166,6 +170,7 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import 
io.trino.sql.planner.plan.MergeNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanNodeId; @@ -184,6 +189,7 @@ import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TableWriterNode.DeleteTarget; +import io.trino.sql.planner.plan.TableWriterNode.MergeTarget; import io.trino.sql.planner.plan.TableWriterNode.UpdateTarget; import io.trino.sql.planner.plan.TopNNode; import io.trino.sql.planner.plan.TopNRankingNode; @@ -2481,7 +2487,6 @@ private Set getCoordinatorDynamicFilters(Set d @Override public PhysicalOperation visitTableWriter(TableWriterNode node, LocalExecutionPlanContext context) { - // Set table writer count context.setDriverInstanceCount(getTaskWriterCount(session)); // serialize writes by forcing data through a single writer @@ -2674,6 +2679,38 @@ private List createColumnValueAndRowIdChannels(List outputSymbo return Arrays.asList(columnValueAndRowIdChannels); } + @Override + public PhysicalOperation visitMerge(MergeNode node, LocalExecutionPlanContext context) + { + context.setDriverInstanceCount(getTaskWriterCount(session)); + + PhysicalOperation source = node.getSource().accept(this, context); + OperatorFactory operatorFactory = new SqlMergeOperatorFactory(context.getNextOperatorId(), node.getId(), pageSinkManager, node.getTarget(), session); + + Map layout = ImmutableMap.builder() + .put(node.getOutputSymbols().get(0), 0) + .put(node.getOutputSymbols().get(1), 1) + .build(); + + return new PhysicalOperation(operatorFactory, layout, context, source); + } + + @Override + public PhysicalOperation visitDeleteAndInsert(DeleteAndInsertNode node, LocalExecutionPlanContext context) + { + PhysicalOperation source = node.getSource().accept(this, context); + OperatorFactory operatorFactory = DeleteAndInsertOperator.createOperatorFactory(context.getNextOperatorId(), node.getId(), 
node.getTarget().getRowChangeProcessor()); + + ImmutableMap.Builder layoutBuilder = ImmutableMap.builder(); + int index = 0; + for (Symbol symbol : node.getOutputSymbols()) { + layoutBuilder.put(symbol, index); + index++; + } + + return new PhysicalOperation(operatorFactory, layoutBuilder.build(), context, source); + } + @Override public PhysicalOperation visitTableDelete(TableDeleteNode node, LocalExecutionPlanContext context) { @@ -3217,6 +3254,12 @@ else if (target instanceof UpdateTarget) { metadata.finishUpdate(session, ((UpdateTarget) target).getHandleOrElseThrow(), fragments); return Optional.empty(); } + else if (target instanceof MergeTarget) { + MergeTarget mergeTarget = (MergeTarget) target; + MergeHandle mergeHandle = mergeTarget.getMergeHandle().orElseThrow(() -> new IllegalArgumentException("mergeHandle isn't present")); + metadata.finishMerge(session, mergeHandle, fragments, statistics); + return Optional.empty(); + } else { throw new AssertionError("Unhandled target type: " + target.getClass().getName()); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java index cf782eed104a..0e18a22f9707 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java @@ -51,6 +51,7 @@ import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.ExplainAnalyzeNode; import io.trino.sql.planner.plan.LimitNode; +import io.trino.sql.planner.plan.MergeNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.ProjectNode; @@ -75,6 +76,7 @@ import io.trino.sql.tree.IfExpression; import io.trino.sql.tree.Insert; import io.trino.sql.tree.LambdaArgumentDeclaration; +import io.trino.sql.tree.Merge; import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.NullLiteral; import 
io.trino.sql.tree.QualifiedName; @@ -262,6 +264,9 @@ private RelationPlan planStatementWithoutOutput(Analysis analysis, Statement sta if (statement instanceof Update) { return createUpdatePlan(analysis, (Update) statement); } + if (statement instanceof Merge) { + return createMergePlan(analysis, (Merge) statement); + } if (statement instanceof Query) { return createRelationPlan(analysis, (Query) statement); } @@ -665,6 +670,49 @@ private RelationPlan createUpdatePlan(Analysis analysis, Update node) return new RelationPlan(commitNode, analysis.getScope(node), commitNode.getOutputSymbols(), Optional.empty()); } + private RelationPlan createMergePlan(Analysis analysis, Merge node) + { + MergeNode mergeNode = new QueryPlanner(analysis, symbolAllocator, idAllocator, buildLambdaDeclarationToSymbolMap(analysis, symbolAllocator), metadata, Optional.empty(), session, ImmutableMap.of()) + .plan(node); + + TableFinishNode commitNode = new TableFinishNode( + idAllocator.getNextId(), + mergeNode, + mergeNode.getTarget(), + symbolAllocator.newSymbol("rows", BIGINT), + Optional.empty(), + Optional.empty()); + + return new RelationPlan(commitNode, analysis.getScope(node), commitNode.getOutputSymbols(), Optional.empty()); + } + + public static Optional createPartitioningScheme(Optional writeTableLayout, List symbols, List columnNames) + { + if (writeTableLayout.isPresent()) { + List partitionFunctionArguments = new ArrayList<>(); + writeTableLayout.get().getPartitionColumns().stream() + .mapToInt(columnNames::indexOf) + .mapToObj(symbols::get) + .forEach(partitionFunctionArguments::add); + + List outputLayout = new ArrayList<>(symbols); + + Optional partitioningHandle = writeTableLayout.get().getPartitioning(); + if (partitioningHandle.isPresent()) { + return Optional.of(new PartitioningScheme( + Partitioning.create(partitioningHandle.get(), partitionFunctionArguments), + outputLayout)); + } + else { + // empty connector partitioning handle means evenly partitioning on partitioning 
columns + return Optional.of(new PartitioningScheme( + Partitioning.create(FIXED_HASH_DISTRIBUTION, partitionFunctionArguments), + outputLayout)); + } + } + return Optional.empty(); + } + private PlanNode createOutputPlan(RelationPlan plan, Analysis analysis) { ImmutableList.Builder outputs = ImmutableList.builder(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java index d8a573d108f7..fd1835344eff 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java @@ -31,9 +31,11 @@ import io.trino.spi.connector.ConnectorPartitioningHandle; import io.trino.spi.type.Type; import io.trino.sql.planner.plan.AggregationNode; +import io.trino.sql.planner.plan.DeleteAndInsertNode; import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.planner.plan.ExplainAnalyzeNode; import io.trino.sql.planner.plan.JoinNode; +import io.trino.sql.planner.plan.MergeNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PlanFragmentId; import io.trino.sql.planner.plan.PlanNode; @@ -324,6 +326,24 @@ public PlanNode visitTableWriter(TableWriterNode node, RewriteContext context) + { + if (node.getPartitioningScheme().isPresent()) { + context.get().setDistribution(node.getPartitioningScheme().get().getPartitioning().getHandle(), metadata, session); + } + return context.defaultRewrite(node, context.get()); + } + + @Override + public PlanNode visitDeleteAndInsert(DeleteAndInsertNode node, RewriteContext context) + { +// if (node.getPartitioningScheme().isPresent()) { +// context.get().setDistribution(node.getPartitioningScheme().get().getPartitioning().getHandle(), metadata, session); +// } + return context.defaultRewrite(node, context.get()); + } + @Override public PlanNode visitValues(ValuesNode node, RewriteContext context) { diff --git 
a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index 88b6c93a88a3..9d6acfe035ab 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -100,6 +100,7 @@ import io.trino.sql.planner.iterative.rule.PruneJoinColumns; import io.trino.sql.planner.iterative.rule.PruneLimitColumns; import io.trino.sql.planner.iterative.rule.PruneMarkDistinctColumns; +import io.trino.sql.planner.iterative.rule.PruneMergeSourceColumns; import io.trino.sql.planner.iterative.rule.PruneOffsetColumns; import io.trino.sql.planner.iterative.rule.PruneOrderByInAggregation; import io.trino.sql.planner.iterative.rule.PruneOutputSourceColumns; @@ -324,6 +325,7 @@ public PlanOptimizers( new PruneCorrelatedJoinCorrelation(), new PruneDeleteSourceColumns(), new PruneUpdateSourceColumns(), + new PruneMergeSourceColumns(), new PruneDistinctLimitSourceColumns(), new PruneEnforceSingleRowColumns(), new PruneExceptSourceColumns(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index aa45804b111c..a1257b47a2d2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -24,14 +24,20 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TableHandle; import io.trino.metadata.TableMetadata; +import io.trino.operator.ChangeOnlyUpdatedColumnsMergeProcessor; +import io.trino.operator.DeleteAndInsertMergeProcessor; +import io.trino.operator.RowChangeProcessor; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; +import io.trino.spi.connector.MergeDetails; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SortOrder; import 
io.trino.spi.type.DecimalType; import io.trino.spi.type.Type; import io.trino.sql.NodeUtils; import io.trino.sql.analyzer.Analysis; import io.trino.sql.analyzer.Analysis.GroupingSetAnalysis; +import io.trino.sql.analyzer.Analysis.MergeAnalysis; import io.trino.sql.analyzer.Analysis.ResolvedWindow; import io.trino.sql.analyzer.Analysis.SelectExpression; import io.trino.sql.analyzer.FieldId; @@ -39,10 +45,12 @@ import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DeleteAndInsertNode; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.GroupIdNode; import io.trino.sql.planner.plan.LimitNode; +import io.trino.sql.planner.plan.MergeNode; import io.trino.sql.planner.plan.OffsetNode; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanNodeId; @@ -51,6 +59,7 @@ import io.trino.sql.planner.plan.SortNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode.DeleteTarget; +import io.trino.sql.planner.plan.TableWriterNode.MergeTarget; import io.trino.sql.planner.plan.TableWriterNode.UpdateTarget; import io.trino.sql.planner.plan.UnionNode; import io.trino.sql.planner.plan.UpdateNode; @@ -71,6 +80,7 @@ import io.trino.sql.tree.LambdaArgumentDeclaration; import io.trino.sql.tree.LambdaExpression; import io.trino.sql.tree.LongLiteral; +import io.trino.sql.tree.Merge; import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.Offset; @@ -81,6 +91,8 @@ import io.trino.sql.tree.Relation; import io.trino.sql.tree.SortItem; import io.trino.sql.tree.StringLiteral; +import io.trino.sql.tree.SubscriptExpression; +import io.trino.sql.tree.SymbolReference; import io.trino.sql.tree.Table; import io.trino.sql.tree.Union; import io.trino.sql.tree.Update; @@ -90,6 +102,7 @@ import 
java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; @@ -109,6 +122,7 @@ import static io.trino.SystemSessionProperties.isSkipRedundantSort; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.NodeUtils.getSortItemsFromOrderBy; @@ -116,6 +130,7 @@ import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.GroupingOperationRewriter.rewriteGroupingOperation; +import static io.trino.sql.planner.LogicalPlanner.createPartitioningScheme; import static io.trino.sql.planner.OrderingScheme.sortItemToSortOrder; import static io.trino.sql.planner.PlanBuilder.newPlanBuilder; import static io.trino.sql.planner.ScopeAware.scopeAwareKey; @@ -560,6 +575,145 @@ public UpdateNode plan(Update node) outputs); } + public MergeNode plan(Merge node) + { + Table table = node.getTable(); + TableHandle handle = analysis.getTableHandle(table); + TableMetadata tableMetadata = metadata.getTableMetadata(session, handle); + MergeAnalysis mergeAnalysis = analysis.getMergeAnalysis().orElseThrow(() -> new IllegalArgumentException("analysis.getMergeAnalysis() isn't present")); + QuerySpecification query = mergeAnalysis.getFinalQuery().orElseThrow(() -> new IllegalArgumentException("mergeAnalysis.getFinalQuery() not present")); + + // create table scan + RelationPlan relationPlan = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, outerContext, session, recursiveSubqueries) + .process(query, null); + + List relationProjectedSymbols = 
ImmutableList.copyOf(relationPlan.getFieldMappings()); + List writeRedistributionColumns = mergeAnalysis.getWriteRedistributionColumnNames(); + int expectedSize = 2 + writeRedistributionColumns.size(); + checkArgument(relationProjectedSymbols.size() == expectedSize, "projectedSymbols should have size %s, but is %s", expectedSize, relationProjectedSymbols); + + Symbol mergeRow = relationProjectedSymbols.get(0); + Symbol rowId = relationProjectedSymbols.get(1); + List outputs = ImmutableList.of( + symbolAllocator.newSymbol("partialrows", BIGINT), + symbolAllocator.newSymbol("fragment", VARBINARY)); + MergeDetails mergeDetails = mergeAnalysis.getMergeDetails(); + + RowChangeParadigm paradigm = metadata.getRowChangeParadigm(session, handle); + Map columnMap = metadata.getColumnHandles(session, handle); + Type rowIdType = analysis.getType(analysis.getRowIdField(table)); + RowChangeProcessor rowChangeProcessor = createMergeProcessor(paradigm, tableMetadata, mergeDetails, columnMap, writeRedistributionColumns, rowIdType); + + Map updatedColumnTypes = mergeAnalysis.getAllUpdatedColumnTypes(); + Set columnNamesSet = new HashSet<>(updatedColumnTypes.keySet()); + columnNamesSet.addAll(mergeAnalysis.getWriteRedistributionColumnNames()); + + List columnNames = tableMetadata.getColumns().stream() + .map(column -> column.getName()) + .filter(name -> columnNamesSet.contains(name)) + .collect(toImmutableList()); + List columnSymbols = relationPlan.getRoot().getOutputSymbols().stream() + .filter(symbol -> columnNamesSet.contains(symbol.getName())) + .collect(toImmutableList()); + checkState(columnNames.size() == columnSymbols.size(), "Didn't find symbols for all the columns, columns %s, symbols %s", columnNames, columnSymbols); + + Assignments.Builder assignmentsBuilder = Assignments.builder(); + assignmentsBuilder.put(mergeRow, new SymbolReference(mergeRow.getName())); + assignmentsBuilder.put(rowId, new SymbolReference(rowId.getName())); + for (String column : 
mergeAnalysis.getWriteRedistributionColumnNames()) { + Type type = requireNonNull(mergeAnalysis.getAllColumnTypes().get(column), "column type is null"); + Symbol symbol = symbolAllocator.newSymbol(column, type); + assignmentsBuilder.put(symbol, new SymbolReference(column)); + } + Assignments projectedAssignments = assignmentsBuilder.build(); + ProjectNode projectNode = new ProjectNode(idAllocator.getNextId(), relationPlan.getRoot(), projectedAssignments); + + ImmutableList.Builder projectedSymbolsBuilder = ImmutableList.builder(); + int subscriptIndex = 1; + for (ColumnMetadata columnMetadata : tableMetadata.getColumns()) { + if (!columnMetadata.isHidden()) { + SubscriptExpression subscriptExpression = new SubscriptExpression(new SymbolReference(mergeRow.getName()), new LongLiteral(String.valueOf(subscriptIndex))); + Symbol symbol = new Symbol(columnMetadata.getName()); + projectedSymbolsBuilder.add(symbol); + analysis.addTypes(ImmutableMap.of(NodeRef.of(subscriptExpression), columnMetadata.getType())); + subscriptIndex++; + } + } + + Symbol operationSymbol = symbolAllocator.newSymbol("$operation", INTEGER); + projectedSymbolsBuilder.add(operationSymbol); + projectedSymbolsBuilder.add(rowId); + + List finalProjectedSymbols = projectedSymbolsBuilder.build(); + MergeTarget target = new MergeTarget(handle, Optional.empty(), tableMetadata.getTable(), mergeDetails, rowChangeProcessor); + + Optional partitioningScheme = createPartitioningScheme(mergeAnalysis.getNewTableLayout(), columnSymbols, columnNames); + DeleteAndInsertNode deleteAndInsertNode = new DeleteAndInsertNode( + idAllocator.getNextId(), + projectNode, + target, + finalProjectedSymbols); + + Optional tableScanId = getIdForLeftTableScan(relationPlan.getRoot()); + checkArgument(tableScanId.isPresent(), "tableScanId not present"); + return new MergeNode( + idAllocator.getNextId(), + deleteAndInsertNode, + target, + tableScanId, + finalProjectedSymbols, + partitioningScheme, + outputs); + } + + private 
RowChangeProcessor createMergeProcessor( + RowChangeParadigm paradigm, + TableMetadata tableMetadata, + MergeDetails mergeDetails, + Map columnHandles, + List writeRedistributionColumnNames, + Type rowIdType) + { + switch (paradigm) { + case DELETE_ROW_AND_INSERT_ROW: + return createDeleteAndInsertMergeProcessor(tableMetadata, mergeDetails, columnHandles, writeRedistributionColumnNames, rowIdType); + case CHANGE_ONLY_UPDATED_COLUMNS: + return createChangeOnlyUpdatedColumnsMergeProcessor(tableMetadata, columnHandles, writeRedistributionColumnNames, rowIdType); + default: + throw new IllegalArgumentException("Unsupported RowChangeParadigm " + paradigm); + } + } + + private RowChangeProcessor createDeleteAndInsertMergeProcessor(TableMetadata tableMetadata, MergeDetails mergeDetails, Map columnHandles, List writeRedistributionColumnNames, Type rowIdType) + { + List dataColumnMetadata = tableMetadata.getMetadata().getColumns().stream() + .filter(column -> !column.isHidden()) + .collect(toImmutableList()); + List dataColumnTypes = dataColumnMetadata.stream().map(ColumnMetadata::getType).collect(toImmutableList()); + List dataColumns = dataColumnMetadata.stream() + .map(column -> columnHandles.get(column.getName())) + .collect(toImmutableList()); + List writeRedistributionColumns = writeRedistributionColumnNames.stream() + .map(columnHandles::get) + .collect(toImmutableList()); + return new DeleteAndInsertMergeProcessor(mergeDetails, dataColumns, dataColumnTypes, writeRedistributionColumns, rowIdType); + } + + private RowChangeProcessor createChangeOnlyUpdatedColumnsMergeProcessor(TableMetadata tableMetadata, Map columnHandles, List writeRedistributionColumnNames, Type rowIdType) + { + List dataColumnMetadata = tableMetadata.getMetadata().getColumns().stream() + .filter(column -> !column.isHidden()) + .collect(toImmutableList()); + List dataColumnTypes = dataColumnMetadata.stream().map(ColumnMetadata::getType).collect(toImmutableList()); + List dataColumns = 
dataColumnMetadata.stream() + .map(column -> columnHandles.get(column.getName())) + .collect(toImmutableList()); + List writeRedistributionColumns = writeRedistributionColumnNames.stream() + .map(columnHandles::get) + .collect(toImmutableList()); + return new ChangeOnlyUpdatedColumnsMergeProcessor(dataColumns, dataColumnTypes, writeRedistributionColumns, rowIdType); + } + private Optional getIdForLeftTableScan(PlanNode node) { if (node instanceof TableScanNode) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneMergeSourceColumns.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneMergeSourceColumns.java new file mode 100644 index 000000000000..5f24fe3a2a21 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneMergeSourceColumns.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package io.trino.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableSet; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.iterative.Rule; +import io.trino.sql.planner.plan.MergeNode; + +import static io.trino.sql.planner.iterative.rule.Util.restrictChildOutputs; +import static io.trino.sql.planner.plan.Patterns.merge; + +/** + * Restricts the outputs of a {@link MergeNode}'s source to the symbols the merge needs: + * its projected symbols plus the columns and hash column of its partitioning scheme, if any. + */ +public class PruneMergeSourceColumns + implements Rule +{ + private static final Pattern PATTERN = merge(); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(MergeNode mergeNode, Captures captures, Context context) + { + ImmutableSet.Builder builder = ImmutableSet.builder(); + builder.addAll(mergeNode.getProjectedSymbols()); + // Exchanges planned from the partitioning scheme read these columns from the source, so they must survive pruning as well + mergeNode.getPartitioningScheme().ifPresent(scheme -> { + builder.addAll(scheme.getPartitioning().getColumns()); + scheme.getHashColumn().ifPresent(builder::add); + }); + return restrictChildOutputs(context.getIdAllocator(), mergeNode, builder.build()) + .map(Result::ofPlanNode) + .orElse(Result.empty()); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java index 4842bde4544c..859df6c9f205 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java @@ -52,6 +52,7 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanVisitor; @@ -531,8 +532,26 @@ public PlanWithProperties visitTableWriter(TableWriterNode node, PreferredProper PlanWithProperties source = node.getSource().accept(this, preferredProperties); Optional partitioningScheme = node.getPartitioningScheme(); + PlanWithProperties partitionedSource =
getWriterPlanWithProperties(partitioningScheme, source, true); + + return rebaseAndDeriveProperties(node, partitionedSource); + } + + @Override + public PlanWithProperties visitMerge(MergeNode node, PreferredProperties preferredProperties) + { + PlanWithProperties source = node.getSource().accept(this, preferredProperties); + + Optional partitioningScheme = node.getPartitioningScheme(); + PlanWithProperties partitionedSource = getWriterPlanWithProperties(partitioningScheme, source, false); + + return rebaseAndDeriveProperties(node, partitionedSource); + } + + private PlanWithProperties getWriterPlanWithProperties(Optional partitioningScheme, PlanWithProperties source, boolean allowScaleWriters) + { if (partitioningScheme.isEmpty()) { - if (scaleWriters) { + if (scaleWriters && allowScaleWriters) { partitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(SCALED_WRITER_DISTRIBUTION, ImmutableList.of()), source.getNode().getOutputSymbols())); } else if (redistributeWrites) { @@ -549,7 +568,7 @@ else if (redistributeWrites) { partitioningScheme.get()), source.getProperties()); } - return rebaseAndDeriveProperties(node, source); + return source; } private Optional planTableScan(TableScanNode node, Expression predicate) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java index c4f270836904..605588f5a3d1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java @@ -44,6 +44,7 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PlanNode; import 
io.trino.sql.planner.plan.PlanVisitor; @@ -73,6 +74,7 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.SystemSessionProperties.getTaskConcurrency; import static io.trino.SystemSessionProperties.getTaskWriterCount; import static io.trino.SystemSessionProperties.isDistributedSortEnabled; @@ -534,15 +536,30 @@ public PlanWithProperties visitTopNRanking(TopNRankingNode node, StreamPreferred @Override public PlanWithProperties visitTableWriter(TableWriterNode node, StreamPreferredProperties parentPreferences) + { + return visitPartitionedWriter(node, node.getPartitioningScheme(), parentPreferences); + } + + // + // Merge + // + + @Override + public PlanWithProperties visitMerge(MergeNode node, StreamPreferredProperties parentPreferences) + { + return visitPartitionedWriter(node, node.getPartitioningScheme(), parentPreferences); + } + + private PlanWithProperties visitPartitionedWriter(PlanNode node, Optional optionalPartitioning, StreamPreferredProperties parentPreferences) { if (getTaskWriterCount(session) == 1) { return planAndEnforceChildren(node, singleStream(), defaultParallelism(session)); } - if (node.getPartitioningScheme().isEmpty()) { + if (optionalPartitioning.isEmpty()) { return planAndEnforceChildren(node, fixedParallelism(), fixedParallelism()); } - PartitioningScheme partitioningScheme = node.getPartitioningScheme().get(); + PartitioningScheme partitioningScheme = optionalPartitioning.get(); if (partitioningScheme.getPartitioning().getHandle().equals(FIXED_HASH_DISTRIBUTION)) { // arbitrary hash function on predefined set of partition columns StreamPreferredProperties preference = partitionedOn(partitioningScheme.getPartitioning().getColumns()); @@ -554,13 +571,13 @@ public PlanWithProperties 
visitTableWriter(TableWriterNode node, StreamPreferred verify( partitioningScheme.getPartitioning().getArguments().stream().noneMatch(Partitioning.ArgumentBinding::isConstant), "Table writer partitioning has constant arguments"); - PlanWithProperties source = node.getSource().accept(this, parentPreferences); + PlanWithProperties source = getOnlyElement(node.getSources()).accept(this, parentPreferences); PlanWithProperties exchange = deriveProperties( partitionedExchange( idAllocator.getNextId(), LOCAL, source.getNode(), - node.getPartitioningScheme().get()), + partitioningScheme), source.getProperties()); return rebaseAndDeriveProperties(node, ImmutableList.of(exchange)); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java index caec54d1ef30..f2e76bbf444d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java @@ -16,16 +16,20 @@ import com.google.common.collect.ImmutableList; import io.trino.Session; import io.trino.execution.warnings.WarningCollector; +import io.trino.metadata.MergeHandle; import io.trino.metadata.Metadata; import io.trino.metadata.TableHandle; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.TypeProvider; +import io.trino.sql.planner.plan.DeleteAndInsertNode; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.JoinNode; +import io.trino.sql.planner.plan.MergeNode; import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.PlanNodeId; import io.trino.sql.planner.plan.ProjectNode; import io.trino.sql.planner.plan.SemiJoinNode; import 
io.trino.sql.planner.plan.SimplePlanRewriter; @@ -39,6 +43,7 @@ import io.trino.sql.planner.plan.TableWriterNode.DeleteTarget; import io.trino.sql.planner.plan.TableWriterNode.InsertReference; import io.trino.sql.planner.plan.TableWriterNode.InsertTarget; +import io.trino.sql.planner.plan.TableWriterNode.MergeTarget; import io.trino.sql.planner.plan.TableWriterNode.UpdateTarget; import io.trino.sql.planner.plan.TableWriterNode.WriterTarget; import io.trino.sql.planner.plan.UnionNode; @@ -47,6 +52,7 @@ import java.util.Optional; import java.util.Set; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar; import static io.trino.sql.planner.plan.ChildReplacer.replaceChildren; @@ -132,6 +138,60 @@ public PlanNode visitUpdate(UpdateNode node, RewriteContext> context) + { + MergeTarget mergeTarget = (MergeTarget) getContextTarget(context); + PlanNodeId tableScanId = mergeNode.getTableScanId().orElseThrow(() -> new IllegalArgumentException("tableScanId not present")); + return new MergeNode( + mergeNode.getId(), + rewriteTableScanWithId(mergeNode.getSource(), tableScanId, mergeTarget.getHandle(), context), + mergeTarget, + mergeNode.getTableScanId(), + mergeNode.getProjectedSymbols(), + mergeNode.getPartitioningScheme(), + mergeNode.getOutputSymbols()); + } + + @Override + public PlanNode visitDeleteAndInsert(DeleteAndInsertNode node, RewriteContext> context) + { + MergeTarget mergeTarget = (MergeTarget) getContextTarget(context); + return new DeleteAndInsertNode( + node.getId(), + node.getSource(), + mergeTarget, + node.getOutputSymbols()); + } + + private PlanNode rewriteTableScanWithId(PlanNode node, PlanNodeId tableScanNodeId, TableHandle tableHandle, RewriteContext> context) + { + if (node.getId().equals(tableScanNodeId)) { + TableScanNode scan = (TableScanNode) node; + return new TableScanNode( + 
scan.getId(), + tableHandle, + scan.getOutputSymbols(), + scan.getAssignments(), + scan.getEnforcedConstraint(), + scan.isUpdateTarget(), + scan.getUseConnectorNodePartitioning()); + } + if (node instanceof DeleteAndInsertNode) { + DeleteAndInsertNode deleter = (DeleteAndInsertNode) node; + MergeTarget mergeTarget = (MergeTarget) getContextTarget(context); + return new DeleteAndInsertNode( + deleter.getId(), + rewriteTableScanWithId(deleter.getSource(), tableScanNodeId, mergeTarget.getHandle(), context), + mergeTarget, + node.getOutputSymbols()); + } + + return node.replaceChildren(node.getSources().stream() + .map(child -> rewriteTableScanWithId(child, tableScanNodeId, tableHandle, context)) + .collect(toImmutableList())); + } + @Override public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, RewriteContext> context) { @@ -190,6 +250,9 @@ public WriterTarget getWriterTarget(PlanNode node) update.getUpdatedColumns(), update.getUpdatedColumnHandles()); } + if (node instanceof MergeNode) { + return ((MergeNode) node).getTarget(); + } if (node instanceof ExchangeNode || node instanceof UnionNode) { Set writerTargets = node.getSources().stream() .map(this::getWriterTarget) @@ -225,6 +288,16 @@ private WriterTarget createWriterTarget(WriterTarget target) update.getUpdatedColumns(), update.getUpdatedColumnHandles()); } + if (target instanceof MergeTarget) { + MergeTarget merge = (MergeTarget) target; + MergeHandle mergeHandle = metadata.beginMerge(session, merge.getHandle(), merge.getMergeDetails()); + return new MergeTarget( + mergeHandle.getTableHandle(), + Optional.of(mergeHandle), + merge.getSchemaTableName(), + merge.getMergeDetails(), + merge.getRowChangeProcessor()); + } if (target instanceof TableWriterNode.RefreshMaterializedViewReference) { TableWriterNode.RefreshMaterializedViewReference refreshMV = (TableWriterNode.RefreshMaterializedViewReference) target; return new TableWriterNode.RefreshMaterializedViewTarget( diff --git 
a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java index 5e5de457db80..5ff6840b1f7f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java @@ -44,6 +44,7 @@ import io.trino.sql.planner.plan.ApplyNode; import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.CorrelatedJoinNode; +import io.trino.sql.planner.plan.DeleteAndInsertNode; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.EnforceSingleRowNode; @@ -56,6 +57,7 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanVisitor; @@ -424,6 +426,18 @@ public ActualProperties visitUpdate(UpdateNode node, List inpu return Iterables.getOnlyElement(inputProperties).translate(symbol -> Optional.empty()); } + @Override + public ActualProperties visitMerge(MergeNode node, List inputProperties) + { + return visitPartitionedWriter(inputProperties); + } + + @Override + public ActualProperties visitDeleteAndInsert(DeleteAndInsertNode node, List inputProperties) + { + return Iterables.getOnlyElement(inputProperties).translate(symbol -> Optional.empty()); + } + @Override public ActualProperties visitJoin(JoinNode node, List inputProperties) { @@ -685,6 +699,11 @@ else if (!(value instanceof Expression)) { @Override public ActualProperties visitTableWriter(TableWriterNode node, List inputProperties) + { + return visitPartitionedWriter(inputProperties); + } + + private ActualProperties 
visitPartitionedWriter(List inputProperties) { ActualProperties properties = Iterables.getOnlyElement(inputProperties); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PruneUnreferencedOutputs.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PruneUnreferencedOutputs.java index 26b73ead520b..1468600a8a34 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -38,6 +38,7 @@ import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.CorrelatedJoinNode; +import io.trino.sql.planner.plan.DeleteAndInsertNode; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.ExceptNode; @@ -51,6 +52,7 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeNode; import io.trino.sql.planner.plan.OffsetNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PlanNode; @@ -726,6 +728,28 @@ public PlanNode visitUpdate(UpdateNode node, RewriteContext> context return new UpdateNode(node.getId(), source, node.getTarget(), node.getRowId(), node.getColumnValueAndRowIdSymbols(), node.getOutputSymbols()); } + @Override + public PlanNode visitMerge(MergeNode node, RewriteContext> context) + { + ImmutableSet.Builder expectedInputs = ImmutableSet.builder(); + if (node.getPartitioningScheme().isPresent()) { + PartitioningScheme partitioningScheme = node.getPartitioningScheme().get(); + partitioningScheme.getPartitioning().getColumns().forEach(expectedInputs::add); + partitioningScheme.getHashColumn().ifPresent(expectedInputs::add); + } + expectedInputs.addAll(node.getProjectedSymbols()); + PlanNode 
source = context.rewrite(node.getSource(), expectedInputs.build()); + return new MergeNode(node.getId(), source, node.getTarget(), node.getTableScanId(), node.getProjectedSymbols(), node.getPartitioningScheme(), node.getOutputSymbols()); + } + + @Override + public PlanNode visitDeleteAndInsert(DeleteAndInsertNode node, RewriteContext> context) + { + // The node does not record which source symbols it consumes, so keep every source output instead of pruning them all + PlanNode source = context.rewrite(node.getSource(), ImmutableSet.copyOf(node.getSource().getOutputSymbols())); + return new DeleteAndInsertNode(node.getId(), source, node.getTarget(), node.getOutputSymbols()); + } + @Override public PlanNode visitUnion(UnionNode node, RewriteContext> context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java index 1ed3b150698a..94307800c3df 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java @@ -32,6 +32,7 @@ import io.trino.sql.planner.plan.ApplyNode; import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.CorrelatedJoinNode; +import io.trino.sql.planner.plan.DeleteAndInsertNode; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.EnforceSingleRowNode; @@ -44,6 +45,7 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanVisitor; @@ -442,6 +444,20 @@ public StreamProperties visitUpdate(UpdateNode node, List inpu return properties.withUnspecifiedPartitioning(); } + @Override +
public StreamProperties visitMerge(MergeNode node, List inputProperties) + { + StreamProperties properties = Iterables.getOnlyElement(inputProperties); + return properties.withUnspecifiedPartitioning(); + } + + @Override + public StreamProperties visitDeleteAndInsert(DeleteAndInsertNode node, List inputProperties) + { + StreamProperties properties = Iterables.getOnlyElement(inputProperties); + return properties.withUnspecifiedPartitioning(); + } + @Override public StreamProperties visitTableWriter(TableWriterNode node, List inputProperties) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java index b32f63d73268..b430cdee013b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java @@ -22,9 +22,11 @@ import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; +import io.trino.sql.planner.plan.DeleteAndInsertNode; import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.GroupIdNode; import io.trino.sql.planner.plan.LimitNode; +import io.trino.sql.planner.plan.MergeNode; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.sql.planner.plan.RowNumberNode; @@ -306,6 +308,48 @@ public TableWriterNode map(TableWriterNode node, PlanNode source, PlanNodeId new node.getStatisticsAggregationDescriptor().map(descriptor -> descriptor.map(this::map))); } + public MergeNode map(MergeNode node, PlanNode source) + { + return map(node, source, node.getId()); + } + + /** + * Rebuilds {@code node} on top of {@code source} with mapped symbols; the resulting node gets the id {@code newId}. + */ + public MergeNode map(MergeNode node, PlanNode source, PlanNodeId newId) + { + // Intentionally does not use mapAndDistinct on columns as that would remove columns + List newOutputs =
map(node.getOutputSymbols()); + + return new MergeNode( + newId, + source, + node.getTarget(), + node.getTableScanId(), + map(node.getProjectedSymbols()), + node.getPartitioningScheme().map(partitioningScheme -> map(partitioningScheme, source.getOutputSymbols())), + newOutputs); + } + + public DeleteAndInsertNode map(DeleteAndInsertNode node, PlanNode source) + { + return map(node, source, node.getId()); + } + + /** + * Rebuilds {@code node} on top of {@code source} with mapped output symbols; the resulting node gets the id {@code newId}. + */ + public DeleteAndInsertNode map(DeleteAndInsertNode node, PlanNode source, PlanNodeId newId) + { + List newOutputs = map(node.getOutputSymbols()); + + return new DeleteAndInsertNode( + newId, + source, + node.getTarget(), + newOutputs); + } + public PartitioningScheme map(PartitioningScheme scheme, List sourceLayout) { return new PartitioningScheme( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java index 819896463742..2c9956704754 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -36,6 +36,7 @@ import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.CorrelatedJoinNode; +import io.trino.sql.planner.plan.DeleteAndInsertNode; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.DynamicFilterId; @@ -51,6 +52,7 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeNode; import io.trino.sql.planner.plan.OffsetNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PlanNode; @@ -580,6 +582,30 @@ public PlanAndMappings visitUpdate(UpdateNode node,
UnaliasContext context) mapping); } + @Override + public PlanAndMappings visitMerge(MergeNode node, UnaliasContext context) + { + PlanAndMappings rewrittenSource = node.getSource().accept(this, context); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); + + MergeNode rewrittenMerge = mapper.map(node, rewrittenSource.getRoot()); + + return new PlanAndMappings(rewrittenMerge, mapping); + } + + @Override + public PlanAndMappings visitDeleteAndInsert(DeleteAndInsertNode node, UnaliasContext context) + { + PlanAndMappings rewrittenSource = node.getSource().accept(this, context); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); + + DeleteAndInsertNode rewrittenMerge = mapper.map(node, rewrittenSource.getRoot()); + + return new PlanAndMappings(rewrittenMerge, mapping); + } + @Override public PlanAndMappings visitStatisticsWriterNode(StatisticsWriterNode node, UnaliasContext context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/DeleteAndInsertNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/DeleteAndInsertNode.java new file mode 100644 index 000000000000..fc8ae2abab6e --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/DeleteAndInsertNode.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.planner.plan; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import io.trino.sql.planner.Symbol; + +import java.util.List; + +import static io.trino.sql.planner.plan.TableWriterNode.MergeTarget; +import static java.util.Objects.requireNonNull; + +/** + * A node that supports connectors that morph update operations into delete + * and insert operations, if required by the connector's + * {@link io.trino.spi.connector.RowChangeParadigm}. + */ +public class DeleteAndInsertNode + extends PlanNode +{ + private final PlanNode source; + private final MergeTarget target; + private final List outputs; + + @JsonCreator + public DeleteAndInsertNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("source") PlanNode source, + @JsonProperty("target") MergeTarget target, + @JsonProperty("outputs") List outputs) + { + super(id); + + this.source = requireNonNull(source, "source is null"); + this.target = requireNonNull(target, "target is null"); + this.outputs = ImmutableList.copyOf(requireNonNull(outputs, "outputs is null")); + } + + @JsonProperty + public PlanNode getSource() + { + return source; + } + + @JsonProperty + public MergeTarget getTarget() + { + return target; + } + + @JsonProperty("outputs") + @Override + public List getOutputSymbols() + { + return outputs; + } + + @Override + public List getSources() + { + return ImmutableList.of(source); + } + + @Override + public R accept(PlanVisitor visitor, C context) + { + return visitor.visitDeleteAndInsert(this, context); + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + return new DeleteAndInsertNode(getId(), Iterables.getOnlyElement(newChildren), target, outputs); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/MergeNode.java 
b/core/trino-main/src/main/java/io/trino/sql/planner/plan/MergeNode.java new file mode 100644 index 000000000000..b1b0b35c26ff --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/MergeNode.java @@ -0,0 +1,123 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.plan; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import io.trino.sql.planner.PartitioningScheme; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.plan.TableWriterNode.MergeTarget; + +import javax.annotation.concurrent.Immutable; + +import java.util.List; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +/** + * Plan node for a MERGE write. It carries the {@link MergeTarget}, an optional table scan id, + * the projected symbols and the optional partitioning scheme; symbol lists are copied on construction. + */ +@Immutable +public class MergeNode + extends PlanNode +{ + private final PlanNode source; + private final MergeTarget target; + private final Optional tableScanId; + private final List projectedSymbols; + private final Optional partitioningScheme; + private final List outputs; + + @JsonCreator + public MergeNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("source") PlanNode source, + @JsonProperty("target") MergeTarget target, + @JsonProperty("tableScanId") Optional tableScanId, + @JsonProperty("projectedSymbols") List projectedSymbols, + @JsonProperty("partitioningScheme") Optional partitioningScheme, +
@JsonProperty("outputs") List outputs) + { + super(id); + + this.source = requireNonNull(source, "source is null"); + this.target = requireNonNull(target, "target is null"); + this.tableScanId = requireNonNull(tableScanId, "tableScanId is null"); + this.projectedSymbols = ImmutableList.copyOf(requireNonNull(projectedSymbols, "projectedSymbols is null")); + this.partitioningScheme = requireNonNull(partitioningScheme, "partitioningScheme is null"); + this.outputs = ImmutableList.copyOf(requireNonNull(outputs, "outputs is null")); + } + + @JsonProperty + public PlanNode getSource() + { + return source; + } + + @JsonProperty + public MergeTarget getTarget() + { + return target; + } + + @JsonProperty + public Optional getTableScanId() + { + return tableScanId; + } + + @JsonProperty + public List getProjectedSymbols() + { + return projectedSymbols; + } + + @JsonProperty + public Optional getPartitioningScheme() + { + return partitioningScheme; + } + + /** + * Aggregate information about updated data + */ + @JsonProperty("outputs") + @Override + public List getOutputSymbols() + { + return outputs; + } + + @Override + public List getSources() + { + return ImmutableList.of(source); + } + + @Override + public R accept(PlanVisitor visitor, C context) + { + return visitor.visitMerge(this, context); + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + return new MergeNode(getId(), Iterables.getOnlyElement(newChildren), target, tableScanId, projectedSymbols, partitioningScheme, outputs); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java index f8f1ad46f4c2..2fdb09bcd7cb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java @@ -63,6 +63,11 @@ public static Pattern update() return typeOf(UpdateNode.class); } + public static Pattern merge() + { + return
typeOf(MergeNode.class); + } + public static Pattern exchange() { return typeOf(ExchangeNode.class); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java index a3fd5ced3024..5de7f913bb9e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java @@ -50,6 +50,8 @@ @JsonSubTypes.Type(value = TableWriterNode.class, name = "tablewriter"), @JsonSubTypes.Type(value = DeleteNode.class, name = "delete"), @JsonSubTypes.Type(value = UpdateNode.class, name = "update"), + @JsonSubTypes.Type(value = MergeNode.class, name = "merge"), + @JsonSubTypes.Type(value = DeleteAndInsertNode.class, name = "deleteAndInsert"), @JsonSubTypes.Type(value = TableDeleteNode.class, name = "tableDelete"), @JsonSubTypes.Type(value = TableFinishNode.class, name = "tablecommit"), @JsonSubTypes.Type(value = UnnestNode.class, name = "unnest"), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java index 0e313b9d0412..a3243a015d6e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java @@ -134,6 +134,16 @@ public R visitUpdate(UpdateNode node, C context) return visitPlan(node, context); } + public R visitMerge(MergeNode node, C context) + { + return visitPlan(node, context); + } + + public R visitDeleteAndInsert(DeleteAndInsertNode node, C context) + { + return visitPlan(node, context); + } + public R visitTableDelete(TableDeleteNode node, C context) { return visitPlan(node, context); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java index 5c3d0255d6c2..e59393bd59ed 100644 --- 
a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java @@ -22,12 +22,15 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import io.trino.metadata.InsertTableHandle; +import io.trino.metadata.MergeHandle; import io.trino.metadata.NewTableLayout; import io.trino.metadata.OutputTableHandle; import io.trino.metadata.QualifiedObjectName; import io.trino.metadata.TableHandle; +import io.trino.operator.RowChangeProcessor; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorTableMetadata; +import io.trino.spi.connector.MergeDetails; import io.trino.spi.connector.SchemaTableName; import io.trino.sql.planner.PartitioningScheme; import io.trino.sql.planner.Symbol; @@ -199,6 +202,7 @@ public PlanNode replaceChildren(List newChildren) @JsonSubTypes.Type(value = InsertTarget.class, name = "InsertTarget"), @JsonSubTypes.Type(value = DeleteTarget.class, name = "DeleteTarget"), @JsonSubTypes.Type(value = UpdateTarget.class, name = "UpdateTarget"), + @JsonSubTypes.Type(value = MergeTarget.class, name = "MergeTarget"), @JsonSubTypes.Type(value = RefreshMaterializedViewTarget.class, name = "RefreshMaterializedViewTarget")}) @SuppressWarnings({"EmptyClass", "ClassMayBeInterface"}) public abstract static class WriterTarget @@ -528,4 +532,65 @@ public String toString() return handle.map(Object::toString).orElse("[]"); } } + + public static class MergeTarget + extends WriterTarget + { + private final TableHandle handle; + private final Optional mergeHandle; + private final SchemaTableName schemaTableName; + private final MergeDetails mergeDetails; + private final RowChangeProcessor rowChangeProcessor; + + @JsonCreator + public MergeTarget( + @JsonProperty("handle") TableHandle handle, + @JsonProperty("mergeHandle") Optional mergeHandle, + @JsonProperty("schemaTableName") SchemaTableName schemaTableName, + 
@JsonProperty("mergeDetails") MergeDetails mergeDetails, + @JsonProperty("rowChangeProcessor") RowChangeProcessor rowChangeProcessor) + { + this.handle = requireNonNull(handle, "handle is null"); + this.mergeHandle = requireNonNull(mergeHandle, "mergeHandle is null"); + this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null"); + this.mergeDetails = requireNonNull(mergeDetails, "mergeDetails is null"); + this.rowChangeProcessor = requireNonNull(rowChangeProcessor, "rowChangeProcessor is null"); + } + + @JsonProperty + public TableHandle getHandle() + { + return handle; + } + + @JsonProperty + public Optional getMergeHandle() + { + return mergeHandle; + } + + @JsonProperty + public SchemaTableName getSchemaTableName() + { + return schemaTableName; + } + + @JsonProperty + public MergeDetails getMergeDetails() + { + return mergeDetails; + } + + @JsonProperty + public RowChangeProcessor getRowChangeProcessor() + { + return rowChangeProcessor; + } + + @Override + public String toString() + { + return handle.toString(); + } + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/IoPlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/IoPlanPrinter.java index 4d094ea036dc..c75160d8ff4a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/IoPlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/IoPlanPrinter.java @@ -44,6 +44,7 @@ import io.trino.sql.planner.plan.TableWriterNode.DeleteTarget; import io.trino.sql.planner.plan.TableWriterNode.InsertReference; import io.trino.sql.planner.plan.TableWriterNode.InsertTarget; +import io.trino.sql.planner.plan.TableWriterNode.MergeTarget; import io.trino.sql.planner.plan.TableWriterNode.UpdateTarget; import io.trino.sql.planner.plan.TableWriterNode.WriterTarget; import io.trino.sql.planner.planprinter.IoPlanPrinter.FormattedMarker.Bound; @@ -672,6 +673,13 @@ else if (writerTarget instanceof 
UpdateTarget) { target.getSchemaTableName().getSchemaName(), target.getSchemaTableName().getTableName())); } + else if (writerTarget instanceof MergeTarget) { + MergeTarget target = (MergeTarget) writerTarget; + context.setOutputTable(new CatalogSchemaTableName( + target.getHandle().getCatalogName().getCatalogName(), + target.getSchemaTableName().getSchemaName(), + target.getSchemaTableName().getTableName())); + } else if (writerTarget instanceof TableWriterNode.RefreshMaterializedViewTarget) { TableWriterNode.RefreshMaterializedViewTarget target = (TableWriterNode.RefreshMaterializedViewTarget) writerTarget; context.setOutputTable(new CatalogSchemaTableName( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index db3ed7da51d1..d3ff93a68036 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -53,6 +53,7 @@ import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.CorrelatedJoinNode; +import io.trino.sql.planner.plan.DeleteAndInsertNode; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.DynamicFilterId; @@ -69,6 +70,7 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeNode; import io.trino.sql.planner.plan.OffsetNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PlanFragmentId; @@ -1129,6 +1131,22 @@ public Void visitUpdate(UpdateNode node, Void context) return processChildren(node, context); } + @Override + public Void visitMerge(MergeNode node, Void context) + { + addNode(node, "Merge", format("[%s]", 
node.getTarget())); + + return processChildren(node, context); + } + + @Override + public Void visitDeleteAndInsert(DeleteAndInsertNode node, Void context) + { + addNode(node, "DeleteAndInsert"); + + return processChildren(node, context); + } + @Override public Void visitTableDelete(TableDeleteNode node, Void context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java index 7c1894ad891a..0a9c2118f202 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java @@ -29,6 +29,7 @@ import io.trino.sql.planner.plan.ApplyNode; import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.CorrelatedJoinNode; +import io.trino.sql.planner.plan.DeleteAndInsertNode; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.EnforceSingleRowNode; @@ -43,6 +44,7 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeNode; import io.trino.sql.planner.plan.OffsetNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PlanNode; @@ -586,6 +588,23 @@ public Void visitUpdate(UpdateNode node, Set boundSymbols) return null; } + @Override + public Void visitMerge(MergeNode node, Set boundSymbols) + { + PlanNode source = node.getSource(); + source.accept(this, boundSymbols); // visit child + return null; + } + + @Override + public Void visitDeleteAndInsert(DeleteAndInsertNode node, Set boundSymbols) + { + PlanNode source = node.getSource(); + source.accept(this, boundSymbols); // visit child + + return null; + } + @Override public Void 
visitTableDelete(TableDeleteNode node, Set boundSymbols) { diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java b/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java index 5c0dfa4edb36..2e30da057d35 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java @@ -602,7 +602,7 @@ public enum TestingPrivilegeType EXECUTE_QUERY, VIEW_QUERY, KILL_QUERY, EXECUTE_FUNCTION, CREATE_SCHEMA, DROP_SCHEMA, RENAME_SCHEMA, - SHOW_CREATE_TABLE, CREATE_TABLE, DROP_TABLE, RENAME_TABLE, COMMENT_TABLE, COMMENT_COLUMN, INSERT_TABLE, DELETE_TABLE, UPDATE_TABLE, SHOW_COLUMNS, + SHOW_CREATE_TABLE, CREATE_TABLE, DROP_TABLE, RENAME_TABLE, COMMENT_TABLE, COMMENT_COLUMN, INSERT_TABLE, DELETE_TABLE, UPDATE_TABLE, MERGE_TABLE, SHOW_COLUMNS, ADD_COLUMN, DROP_COLUMN, RENAME_COLUMN, SELECT_COLUMN, CREATE_VIEW, RENAME_VIEW, DROP_VIEW, CREATE_VIEW_WITH_SELECT_COLUMNS, GRANT_EXECUTE_FUNCTION, diff --git a/core/trino-main/src/main/java/io/trino/util/StatementUtils.java b/core/trino-main/src/main/java/io/trino/util/StatementUtils.java index 747faf2318f4..ee64ab23af57 100644 --- a/core/trino-main/src/main/java/io/trino/util/StatementUtils.java +++ b/core/trino-main/src/main/java/io/trino/util/StatementUtils.java @@ -40,6 +40,7 @@ import io.trino.sql.tree.Grant; import io.trino.sql.tree.GrantRoles; import io.trino.sql.tree.Insert; +import io.trino.sql.tree.Merge; import io.trino.sql.tree.Prepare; import io.trino.sql.tree.Query; import io.trino.sql.tree.RefreshMaterializedView; @@ -96,6 +97,8 @@ private StatementUtils() {} builder.put(Update.class, QueryType.UPDATE); + builder.put(Merge.class, QueryType.MERGE); + builder.put(ShowCatalogs.class, QueryType.DESCRIBE); builder.put(ShowCreate.class, QueryType.DESCRIBE); builder.put(ShowFunctions.class, QueryType.DESCRIBE); diff --git 
a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java index d1df73ff6e36..c146b153759e 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java +++ b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java @@ -44,7 +44,9 @@ import io.trino.spi.connector.JoinType; import io.trino.spi.connector.LimitApplicationResult; import io.trino.spi.connector.MaterializedViewFreshness; +import io.trino.spi.connector.MergeDetails; import io.trino.spi.connector.ProjectionApplicationResult; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SampleType; import io.trino.spi.connector.SortItem; import io.trino.spi.connector.SystemTable; @@ -427,6 +429,30 @@ public void finishUpdate(Session session, TableHandle tableHandle, Collection fragments, Collection computedStatistics) + { + throw new UnsupportedOperationException(); + } + @Override public Optional getCatalogHandle(Session session, String catalogName) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/Join.java b/core/trino-parser/src/main/java/io/trino/sql/tree/Join.java index 1be838e34cc1..ab9edaeb67a2 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/Join.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/Join.java @@ -36,7 +36,7 @@ public Join(NodeLocation location, Type type, Relation left, Relation right, Opt this(Optional.of(location), type, left, right, criteria); } - private Join(Optional location, Type type, Relation left, Relation right, Optional criteria) + public Join(Optional location, Type type, Relation left, Relation right, Optional criteria) { super(location); requireNonNull(left, "left is null"); diff --git a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java index 001975266f20..d97d435400e8 100644 --- 
a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java +++ b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java @@ -115,6 +115,7 @@ public enum StandardErrorCode DUPLICATE_WINDOW_NAME(92, USER_ERROR), INVALID_WINDOW_REFERENCE(93, USER_ERROR), INVALID_PARTITION_BY(94, USER_ERROR), + MERGE_TARGET_ROW_MULTIPLE_MATCHES(95, USER_ERROR), GENERIC_INTERNAL_ERROR(65536, INTERNAL_ERROR), TOO_MANY_REQUESTS_FAILED(65537, INTERNAL_ERROR), diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Block.java b/core/trino-spi/src/main/java/io/trino/spi/block/Block.java index 7c111523d0ee..0a28b98fda76 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Block.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Block.java @@ -317,4 +317,15 @@ default List getChildren() { return Collections.emptyList(); } + + default boolean allPositionsAreNull() + { + int positionCount = getPositionCount(); + for (int position = 0; position < positionCount; position++) { + if (!isNull(position)) { + return false; + } + } + return true; + } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorHandleResolver.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorHandleResolver.java index 8e65748008e2..cc3182a0f3d2 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorHandleResolver.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorHandleResolver.java @@ -50,6 +50,11 @@ default Class getInsertTableHandleClass() throw new UnsupportedOperationException(); } + default Class getMergeTableHandleClass() + { + throw new UnsupportedOperationException(); + } + default Class getPartitioningHandleClass() { throw new UnsupportedOperationException(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMergeSink.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMergeSink.java new file mode 100644 index 000000000000..e55f8627e291 --- /dev/null +++ 
b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMergeSink.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.connector; + +import io.airlift.slice.Slice; +import io.trino.spi.Page; + +import java.util.Collection; +import java.util.concurrent.CompletableFuture; + +public interface ConnectorMergeSink +{ + /** + * Store the page(s) resulting from a merge. The page consists n blocks, numbered 0..n-1: + *
<ul>
+ * <li>Blocks 0..n-3 in page are the data column blocks.</li>
+ * <li>Block n-2 is the "operation" IntArrayBlock, whose values are {@link MergeDetails#INSERT_OPERATION_NUMBER},
+ * {@link MergeDetails#DELETE_OPERATION_NUMBER} or {@link MergeDetails#UPDATE_OPERATION_NUMBER}
+ * </li>
+ * <li>Block n-1 is a connector-specific rowId block, previously returned by
+ * {@link ConnectorMetadata#getMergeRowIdColumnHandle(ConnectorSession, ConnectorTableHandle, MergeDetails)}
+ * </li>
+ * </ul>
+ * @param page The page to store. + */ + default void storeMergedRows(Page page) + { + throw new UnsupportedOperationException("This connector does not support row merge"); + } + + CompletableFuture> finish(); +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMergeTableHandle.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMergeTableHandle.java new file mode 100644 index 000000000000..6bf97e7fbe2a --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMergeTableHandle.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.connector; + +public interface ConnectorMergeTableHandle +{ + ConnectorTableHandle getTableHandle(); +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java index f6bcd3b0c43c..fe08a989b8f8 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java @@ -172,6 +172,17 @@ default ConnectorTableSchema getTableSchema(ConnectorSession session, ConnectorT return getTableMetadata(session, table).getTableSchema(); } + /** + * Return the column handles for the columns that must be present to in order + * to perform the partitioning and/or bucketing required. By default, the table + * has no such columns. 
+ * @return Return the write redistribution columns for the table + */ + default List getWriteRedistributionColumns(ConnectorSession session, ConnectorTableHandle table) + { + return List.of(); + } + /** * Return the metadata for the specified table handle. * @@ -556,6 +567,45 @@ default void finishUpdate(ConnectorSession session, ConnectorTableHandle tableHa throw new TrinoException(NOT_SUPPORTED, "This connector does not support updates"); } + /** + * Return the row change paradigm supported by the connector on the table. + */ + default RowChangeParadigm getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle) + { + throw new TrinoException(NOT_SUPPORTED, "This connector does not support merges"); + } + + /** + * Get the column handle that will generate row IDs for the merge operation. + * These IDs will be passed to the {@link ConnectorMergeSink#storeMergedRows} + * method of the {@link ConnectorMergeSink} that created them. + */ + default ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle, MergeDetails mergeDetails) + { + throw new TrinoException(NOT_SUPPORTED, "This connector does not support merges"); + } + + /** + * Do whatever is necessary to start an MERGE query, returning the {@link ConnectorMergeTableHandle} + * instance that will be passed to the PageSink, and to the {@link #finishMerge} method. + */ + default ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, MergeDetails mergeDetails) + { + throw new TrinoException(NOT_SUPPORTED, "This connector does not support merges"); + } + + /** + * Finish a merge query + * @param session The session + * @param tableHandle A ConnectorMergeTableHandle for the table that is the target of the merge + * @param fragments All fragments returned by {@link UpdatablePageSource#finish()} + * @param computedStatistics Statistics for the table, meaningful only to the connector that produced them. 
+ */ + default void finishMerge(ConnectorSession session, ConnectorMergeTableHandle tableHandle, Collection fragments, Collection computedStatistics) + { + throw new TrinoException(GENERIC_INTERNAL_ERROR, "ConnectorMetadata beginMerge() is implemented without finishMerge()"); + } + /** * Create the specified view. The view definition is intended to * be serialized by the connector for permanent storage. diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSinkProvider.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSinkProvider.java index 345b3d72cdeb..a19900cd9ee8 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSinkProvider.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSinkProvider.java @@ -13,9 +13,18 @@ */ package io.trino.spi.connector; +import io.trino.spi.TrinoException; + +import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; + public interface ConnectorPageSinkProvider { ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorOutputTableHandle outputTableHandle); ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorInsertTableHandle insertTableHandle); + + default ConnectorMergeSink createMergeSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorMergeTableHandle mergeHandle) + { + throw new TrinoException(NOT_SUPPORTED, "This connector does not support SQL MERGE operations"); + } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/MergeCaseDetails.java b/core/trino-spi/src/main/java/io/trino/spi/connector/MergeCaseDetails.java new file mode 100644 index 000000000000..652c194b45aa --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/MergeCaseDetails.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use 
this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.connector; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Set; + +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class MergeCaseDetails +{ + private final int caseNumber; + private final MergeCaseKind caseKind; + private final Set updatedColumns; + + @JsonCreator + public MergeCaseDetails( + @JsonProperty("caseNumber") int caseNumber, + @JsonProperty("caseKind") MergeCaseKind caseKind, + @JsonProperty("updatedColumns") Set updatedColumns) + { + this.caseNumber = caseNumber; + this.caseKind = requireNonNull(caseKind, "caseKind is null"); + this.updatedColumns = requireNonNull(updatedColumns, "updatedColumns is null"); + switch (caseKind) { + case DELETE: + checkArgument(updatedColumns.isEmpty(), "For DELETE operations, updatedColumns must be empty, but is %s", updatedColumns); + break; + case INSERT: + case UPDATE: + checkArgument(!updatedColumns.isEmpty(), "For INSERT and UPDATE operations, updatedColumns must be non-empty"); + break; + } + } + + @JsonProperty + public int getCaseNumber() + { + return caseNumber; + } + + @JsonProperty + public MergeCaseKind getCaseKind() + { + return caseKind; + } + + @JsonProperty + public Set getUpdatedColumns() + { + return updatedColumns; + } + + private static void checkArgument(boolean test, String message, Object... 
args) + { + if (!test) { + throw new IllegalArgumentException(format(message, args)); + } + } + + @Override + public String toString() + { + return String.format("Case{caseNumber=%s, caseKind=%s, updatedColumns=%s}", caseNumber, caseKind, updatedColumns); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/MergeCaseKind.java b/core/trino-spi/src/main/java/io/trino/spi/connector/MergeCaseKind.java new file mode 100644 index 000000000000..73039f2e11bb --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/MergeCaseKind.java @@ -0,0 +1,42 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.connector; + +import static io.trino.spi.connector.MergeDetails.DELETE_OPERATION_NUMBER; +import static io.trino.spi.connector.MergeDetails.INSERT_OPERATION_NUMBER; +import static io.trino.spi.connector.MergeDetails.UPDATE_OPERATION_NUMBER; + +public enum MergeCaseKind +{ + INSERT(INSERT_OPERATION_NUMBER), + DELETE(DELETE_OPERATION_NUMBER), + UPDATE(UPDATE_OPERATION_NUMBER); + + private final int operationNumber; + + MergeCaseKind(int operationNumber) + { + this.operationNumber = operationNumber; + } + + public int getOperationNumber() + { + return operationNumber; + } + + public boolean matchedKind() + { + return this != INSERT; + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/MergeDetails.java b/core/trino-spi/src/main/java/io/trino/spi/connector/MergeDetails.java new file mode 100644 index 000000000000..60c2d5fe1a48 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/MergeDetails.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.connector; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; + +public class MergeDetails +{ + public static final int DEFAULT_CASE_OPERATION_NUMBER = -1; + public static final int INSERT_OPERATION_NUMBER = 1; + public static final int DELETE_OPERATION_NUMBER = 2; + public static final int UPDATE_OPERATION_NUMBER = 3; + + private final List cases; + + @JsonCreator + public MergeDetails(@JsonProperty("cases") List cases) + { + this.cases = requireNonNull(cases, "cases is null"); + } + + @JsonProperty + public List getCases() + { + return cases; + } + + @Override + public String toString() + { + return format("MergeDetails{cases=[%s]}", cases.stream().map(MergeCaseDetails::toString).collect(joining(", "))); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/MergeProcessorUtilities.java b/core/trino-spi/src/main/java/io/trino/spi/connector/MergeProcessorUtilities.java new file mode 100644 index 000000000000..2aa11d6b62d7 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/MergeProcessorUtilities.java @@ -0,0 +1,134 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.connector; + +import io.trino.spi.Page; +import io.trino.spi.block.ArrayBlock; +import io.trino.spi.block.Block; +import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.RowBlock; + +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.function.Predicate; + +import static io.trino.spi.connector.MergeDetails.DELETE_OPERATION_NUMBER; +import static io.trino.spi.connector.MergeDetails.INSERT_OPERATION_NUMBER; +import static java.lang.String.format; + +public final class MergeProcessorUtilities +{ + private MergeProcessorUtilities() {} + + public static PagePair createMergedDeleteAndInsertPages(Page inputPage, int dataColumnCount) + { + // Blocks 0..n-3 in inputPage are the data column blocks. The last block in the + // inputPage is is the rowId block. The next-to-last block in the inputPage is the + // IntArrayBlock operation block, with values DELETE_OPERATION_NUMBER + // and INSERT_OPERATION_NUMBER + int inputChannelCount = inputPage.getChannelCount(); + if (inputChannelCount != dataColumnCount + 2) { + throw new IllegalArgumentException(format("inputPage channelCount (%s) == dataColumns size (%s) + 2", inputChannelCount, dataColumnCount)); + } + + int positionCount = inputPage.getPositionCount(); + if (positionCount <= 0) { + throw new IllegalArgumentException("positionCount should be > 0, but is " + positionCount); + } + Block operationBlock = inputPage.getBlock(inputChannelCount - 2); + + Optional deletePage = Optional.empty(); + int[] deletePositions = getPositionsForPredicate(positionCount, position -> operationBlock.getInt(position, 0) == DELETE_OPERATION_NUMBER); + if (deletePositions.length > 0) { + deletePage = Optional.of(getPositionsHandlingRowId(inputPage, deletePositions)); + } + Optional insertPage = Optional.empty(); + int[] insertPositions = getPositionsForPredicate(positionCount, position -> operationBlock.getInt(position, 0) == INSERT_OPERATION_NUMBER); + if 
(insertPositions.length > 0) { + insertPage = Optional.of(getPositionsHandlingRowId(inputPage, insertPositions)); + } + return new PagePair(deletePage, insertPage); + } + + public static Block getUnderlyingBlock(Block block) + { + while (block instanceof DictionaryBlock) { + block = ((DictionaryBlock) block).getDictionary(); + } + return block; + } + + public static Block getAllNullsRowIdBlock(Block rowIdBlock, Block underlyingBlock, int positionCount) + { + boolean[] nulls = new boolean[positionCount]; + Arrays.fill(nulls, true); + if (underlyingBlock instanceof RowBlock) { + return RowBlock.fromFieldBlocks(positionCount, Optional.of(nulls), rowIdBlock.getChildren().toArray(new Block[]{})); + } + else { + return ArrayBlock.fromElementBlock(positionCount, Optional.of(nulls), new int[positionCount], underlyingBlock); + } + } + + public static int[] getPositionsForPredicate(int positionCount, Predicate positionPicker) + { + int counter = 0; + for (int position = 0; position < positionCount; position++) { + if (positionPicker.test(position)) { + counter++; + } + } + int[] positions = new int[counter]; + int cursor = 0; + for (int position = 0; position < positionCount; position++) { + if (positionPicker.test(position)) { + positions[cursor] = position; + cursor++; + } + } + return positions; + } + + private static Page getPositionsHandlingRowId(Page page, int[] positions) + { + int positionCount = positions.length; + int channelCount = page.getChannelCount(); + Block[] newPageBlocks = new Block[channelCount]; + for (int channel = 0; channel < channelCount - 1; channel++) { + newPageBlocks[channel] = page.getBlock(channel).getPositions(positions, 0, positionCount); + } + Block rowIdBlock = page.getBlock(channelCount - 1); + newPageBlocks[channelCount - 1] = extractRowIdBlockPositions(rowIdBlock, positions); + return new Page(newPageBlocks); + } + + private static Block extractRowIdBlockPositions(Block rowIdBlock, int[] positions) + { + int positionCount = 
positions.length; + Block underlyingBlock = MergeProcessorUtilities.getUnderlyingBlock(rowIdBlock); + if (underlyingBlock.allPositionsAreNull()) { + return getAllNullsRowIdBlock(rowIdBlock, underlyingBlock, positionCount); + } + else { + List rowIdChildren = rowIdBlock.getChildren(); + int rowIdChildCount = rowIdChildren.size(); + Block[] rowIdBlocks = new Block[rowIdChildCount]; + for (int channel = 0; channel < rowIdChildCount; channel++) { + rowIdBlocks[channel] = rowIdChildren.get(channel).getPositions(positions, 0, positionCount); + } + return RowBlock.fromFieldBlocks(positionCount, Optional.empty(), rowIdBlocks); + } + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/PagePair.java b/core/trino-spi/src/main/java/io/trino/spi/connector/PagePair.java new file mode 100644 index 000000000000..97dcc40c9c9f --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/PagePair.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.connector; + +import io.trino.spi.Page; + +import java.util.Optional; + +public class PagePair +{ + private final Optional deletionsPage; + private final Optional insertionsPage; + + public PagePair(Optional deletionsPage, Optional insertionsPage) + { + this.deletionsPage = deletionsPage; + this.insertionsPage = insertionsPage; + } + + public Optional getDeletionsPage() + { + return deletionsPage; + } + + public Optional getInsertionsPage() + { + return insertionsPage; + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/RowChangeParadigm.java b/core/trino-spi/src/main/java/io/trino/spi/connector/RowChangeParadigm.java new file mode 100644 index 000000000000..93c09cad1555 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/RowChangeParadigm.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * The ways in which a connector can represent a change to an existing table row.
 */
public enum RowChangeParadigm
{
    /**
     * The JDBC paradigm -- requires just the rowId and the changed columns
     */
    CHANGE_ONLY_UPDATED_COLUMNS,

    /**
     * The Hive and Iceberg paradigm -- translates a changed row into a delete
     * by rowId, and an insert of a new record, which will get a new rowId
     */
    DELETE_ROW_AND_INSERT_ROW,

    /**
     * The Delta Lake paradigm -- to change any field of any row, the entire file
     * must be replaced
     */
    COPY_FILE_ON_CHANGE,
}
develop/connectors develop/example-http develop/delete-and-update + develop/supporting-merge develop/types develop/functions develop/system-access-control diff --git a/docs/src/main/sphinx/develop/supporting-merge.rst b/docs/src/main/sphinx/develop/supporting-merge.rst new file mode 100644 index 000000000000..9eb1f54fb4f4 --- /dev/null +++ b/docs/src/main/sphinx/develop/supporting-merge.rst @@ -0,0 +1,373 @@ +==================== +Supporting ``MERGE`` +==================== + +The Trino engine provides APIs to support row-level SQL ``MERGE``. +To implement ``MERGE``, a connector must provide an implementation +of ``ConnectorMergeSink``, which is typically layered on top of a +``ConnectorPageSink``, and define ``ConnectorMetadata`` +methods to get a "rowId" column handle; get the row change paradigm; +and to start and complete the ``MERGE`` operation. + +Standard SQL ``MERGE`` +---------------------- + +Different query engines support varying definitions of SQL ``MERGE``. +Trino supports the strict SQL specification ``ISO/IEC 9075``, published +in 2016. 
As a simple example, given tables ``target_table`` and +``source_table`` defined as:: + + CREATE TABLE target_table ( + customer VARCHAR, + purchases DECIMAL, + address VARCHAR); + INSERT INTO target_table (customer, purchases, address) VALUES ...; + CREATE TABLE source_table ( + customer VARCHAR, + purchases DECIMAL, + address VARCHAR); + INSERT INTO source_table (customer, purchases, address) VALUES ...; + +Here is a possible ``MERGE`` operation, from ``source_table`` to +``target_table``:: + + MERGE INTO target_table t USING source_table s + ON (t.customer = s.customer) + WHEN MATCHED AND s.address = 'Berkeley' THEN + DELETE + WHEN MATCHED AND s.customer = 'Joe Shmoe' THEN + UPDATE SET purchases = purchases + 100.0 + WHEN MATCHED THEN + UPDATE + SET purchases = s.purchases + t.purchases, address = s.address + WHEN NOT MATCHED THEN + INSERT (customer, purchases, address) + VALUES (s.customer, s.purchases, s.address); + +SQL ``MERGE`` supports two distinct operations on the target table and source +when a row from the source query matches a row in the target table: + +* ``UPDATE``, in which the columns in the target row are updated. +* ``DELETE``, in which the target row is deleted. + +In the ``NOT MATCHED`` case, SQL ``MERGE`` supports only ``INSERT`` of the +unmatched row from the source query into the target table. + +``RowChangeParadigm`` +--------------------- + +Different connectors have different ways of representing row updates, +imposed by the underlying storage systems. The Trino engine classifies +these different paradigms as elements of the ``RowChangeParadigm`` +enumeration, returned by ``ConnectorMetadata`` method +``getRowChangeParadigm(...)``. + +The ``RowChangeParadigm`` enumeration values are: + +* ``CHANGE_ONLY_UPDATED_COLUMNS``, intended for connectors like + ``trino-jdbc`` and ``trino-kudu`` that can update individual columns of + rows identified by a ``rowId``. 
The corresponding merge processor class is + ``ChangeOnlyUpdatedColumnsMergeProcessor``. +* ``DELETE_ROW_AND_INSERT_ROW``, intended for connectors like ``trino-hive`` + and ``trino-iceberg`` that represent a row change as a row deletion paired + with a row insertion. The corresponding merge processor class is + ``DeleteAndInsertMergeProcessor``. +* ``COPY_FILE_ON_CHANGE``, intended for connectors that must copy an entire + file to update one or more rows in the file. + +Overview of ``MERGE`` processing +-------------------------------- + +A ``MERGE`` statement is processed by creating a ``RIGHT JOIN`` between the +target table and the source query, on the ``MERGE`` criteria, returning a +``ROW`` containing the column values from the ``UPDATE`` or ``INSERT`` +cases, and ``NULL`` for the ``DELETE`` cases; an integer identifying the +``MERGE`` case matched; and whether the merge case operation is +``UPDATE``, ``DELETE`` or ``INSERT``. The example above is executed as +if it were written as:: + + SELECT + FROM (SELECT *, true AS present FROM target_table) t + RIGHT JOIN source_table s ON s.customer = t.customer + CASE + WHEN present AND s.address = 'Berkeley' THEN + // Null values for delete; case #0; operation DELETE=2 + row(null, null, null, 0, 2) + WHEN present AND s.customer = 'Joe Shmoe' THEN + // Update column values; case #1; operation UPDATE=3 + row(null, t.purchases + 100.0, null, 1, 3) + WHEN present THEN + // Update column values; case #2; operation UPDATE=3 + row(null, s.purchases + t.purchases, s.address, 2, 3) + WHEN (NOT present) THEN + // Insert column values; case #3; operation INSERT=1 + row(s.customer, s.purchases, s.address, 3, 1) + ELSE + // Null values for no case matched; case #-1; operation=-1 + row(null, null, null, -1, -1) + END; + +The Trino engine executes the ``RIGHT JOIN`` and ``SELECT CASE``, +creating a sequence of pages to be routed to the node that runs the +``ConnectorMergeSink`` ``storeMergedRows`` method.
To ensure that all +``NOT MATCHED`` rows for a given partition end up on a single node, a +redistribution hash on the partition key/bucket column(s) is applied +to the page partition key(s). As a result of the hash, all rows for a +specific partition/bucket will hash together, whether they were +``MATCHED`` rows or ``NOT MATCHED`` rows. + +Like ``DELETE`` and ``UPDATE``, ``MERGE`` target table rows are identified by +a connector-specific ``rowId`` column handle, returned by +``ConnectorMetadata.getMergeRowIdColumnHandle(...)``. For ``MERGE``, +that ``rowId`` column contains whatever the connector needs in order +to process all the possible ``MERGE`` cases for that row. For example, +in the case of the Hive connector, the merge handle contains the +``InsertTableHandle`` for the target table. + +Representation of ``MERGE`` cases +--------------------------------- + +The Trino engine provides a ``MergeDetails`` instance to describe the ``MERGE`` +cases to the connector. ``MergeDetails`` contains a list of +``MergeCaseDetails`` instances. ``MergeCaseDetails`` has these members: + +* The ``int`` ``caseNumber``, starting from 0, in syntactic order. +* The ``MergeCaseKind`` ``caseKind``: One of ``UPDATE``, ``DELETE`` or + ``INSERT``. +* The ``Set`` ``updatedColumns``: The columns updated by the case. + For an ``INSERT``, all data columns targeted in the insert are included. + +``MERGE`` Redistribution +------------------------ + +The Trino ``MERGE`` implementation allows ``UPDATE`` and ``INSERT`` to change +the values of columns that determine partitioning and/or bucketing, and so +the Trino engine must "redistribute" rows from the ``MERGE`` operation to +the worker nodes responsible for writing rows with the merged partitioning +and/or bucketing columns.
+ +Connector support for ``MERGE`` +=============================== + +To start ``MERGE`` processing, the Trino engine calls: + +* ``ConnectorMetadata.getMergeRowIdColumnHandle(...)`` to get the + ``rowId`` column handle. +* ``ConnectorMetadata.getRowChangeParadigm(...)`` to get the paradigm + supported by the connector for changing existing table rows. +* ``ConnectorMetadata.beginMerge(...)`` to get a + ``ConnectorMergeTableHandle`` for the merge operation. That + ``ConnectorMergeTableHandle`` object contains whatever information the + connector needs to specify the ``MERGE`` operation. +* ``ConnectorMetadata.getWriteRedistributionColumns(...)`` to get the list + of partition or table columns that impact write redistribution. + +On nodes that are targets of the hash, the Trino engine calls +``ConnectorPageSinkProvider.createMergeSink(...)`` to create a +``ConnectorMergeSink``. + +To write out each page of merged rows, the Trino engine calls +``ConnectorMergeSink.storeMergedRows(Page)``. The ``storeMergedRows(Page)`` +method iterates over the rows in the page, performing updates and deletes +in the ``MATCHED`` cases, and inserts in the ``NOT MATCHED`` cases. + +For some ``RowChangeParadigms``, ``UPDATE`` operations may have been +translated into the corresponding ``DELETE`` and ``INSERT`` operations +before ``storeMergedRows(Page)`` was called. + +To complete the ``MERGE`` operation, the Trino engine calls +``ConnectorMetadata.finishMerge(...)``, passing the table handle +and the collection of ``Slice`` fragments with information about what +was changed, and the connector takes appropriate actions, if +any. + +``RowChangeProcessor`` implementation for ``MERGE`` +--------------------------------------------------- + +In SQL ``MERGE``, each supported ``RowChangeParadigm`` +corresponds to an internal Trino engine class that implements interface +``RowChangeProcessor``. ``RowChangeProcessor`` has one interesting method: +``Page transformPage(Page)``.
The format of the input and output page depends +on the SQL operation being implemented. + +The connector has no access to the ``RowChangeProcessor`` instance -- it +is used inside the Trino engine to transform the merge page rows into rows +to be stored, based on the connector's choice of ``RowChangeParadigm``. + +For SQL MERGE, the page supplied to ``transformPage()`` consists of: + +* The write redistribution columns if any +* The ``rowId`` column for the row from the target table if matched, or + null if not matched +* The merge case ``RowBlock`` +* For partitioned or bucketed tables, a hash value column. + +The merge case ``RowBlock`` has this layout: + +* Blocks for each column in the table, including partition columns, in + table column order. +* A block containing the merge case number of the matching merge case for + the row starting at 0, or -1 if no merge case matched for the row. +* A block containing the merge case operation, encoded as ``INSERT`` = 1, + ``DELETE`` = 2, ``UPDATE`` = 3 and if no merge case matched, -1. + +The page returned from ``transformPage`` consists of all table columns, +in table column order, followed by the rowId column, followed by the +operation column from the merge case ``RowBlock``. + +Interface ``RowChangeProcessor`` now supports SQL MERGE, and will later be +used to upgrade the SQL UPDATE engine implementation. + +Detecting duplicate matching target rows +---------------------------------------- + +The SQL ``MERGE`` specification requires that in each ``MERGE`` case, +a single target table row must match at most one source row, after +applying the ``MERGE`` case condition expression. This check is done in +method ``DuplicateRowFinder.checkForDuplicateTargetRows()``. That method +examines successive rows produced by the ``SELECT CASE`` expression, +excluding inserted rows, and checks whether the ``writeRedistributionColumns`` +and the ``rowId`` column are identical in successive rows.
+If so, it raises the ``MERGE_TARGET_ROW_MULTIPLE_MATCHES`` +exception. + +``ConnectorMergeTableHandle`` API +--------------------------------- + +Interface ``ConnectorMergeTableHandle`` defines one method, +``getTableHandle()``, to retrieve the ``ConnectorTableHandle`` +originally created by ``ConnectorMetadata.beginMerge()``. + +``ConnectorPageSinkProvider`` API +--------------------------------- + +To support ``SQL MERGE``, ``ConnectorPageSinkProvider`` must implement +the method that creates the ``ConnectorMergeSink``: + +* ``createMergeSink``:: + + ConnectorMergeSink createMergeSink( + ConnectorTransactionHandle transactionHandle, + ConnectorSession session, + ConnectorMergeTableHandle mergeHandle) + + Create a ``ConnectorMergeSink`` for the given + ``transactionHandle``, ``session`` and + ``mergeHandle``. + +``ConnectorMergeSink`` API +-------------------------- + +As mentioned above, to support ``MERGE``, the connector must define an +implementation of ``ConnectorMergeSink``, usually layered over the +connector's ``ConnectorPageSink``. + +The ``ConnectorMergeSink`` is created by a call to +``ConnectorPageSinkProvider.createMergeSink()``. + +The only interesting methods are: + +* ``storeMergedRows``:: + + void storeMergedRows(Page page) + + The Trino engine calls the + ``storeMergedRows(Page)`` method of the ``ConnectorMergeSink`` + instance returned by ``ConnectorPageSinkProvider.createMergeSink()``, + passing the page generated by the ``RowChangeProcessor.transformPage()`` + method. That page consists of all table columns, in table column + order, followed by the rowId column, followed by the operation column + from the merge case ``RowBlock``. + + The job of ``storeMergedRows()`` is to iterate over the rows in the page, + and based on the value of the operation column, ``INSERT``, ``DELETE``, + ``UPDATE``, or ignore the page row. For some connectors, ``UPDATE`` + operations are transformed into ``DELETE`` and ``INSERT`` operations.
+ +* ``finish``:: + + CompletableFuture<Collection<Slice>> finish() + + The Trino engine + calls ``finish()`` when all the data has been processed by a specific + ``ConnectorMergeSink`` instance. The connector returns a future containing + a collection of ``Slice``, representing connector-specific information + about the rows processed. Usually this includes the row count, and + might include information like the files or partitions created or changed. + +``ConnectorMetadata`` ``MERGE`` API +----------------------------------- + +A connector implementing ``MERGE`` must implement these ``ConnectorMetadata`` +methods. + +* ``getRowChangeParadigm()``:: + + RowChangeParadigm getRowChangeParadigm( + ConnectorSession session, + ConnectorTableHandle tableHandle) + + This method is called as the engine starts processing a ``MERGE`` statement. + The connector must return a ``RowChangeParadigm`` enum instance. If the + connector does not support ``MERGE`` it should throw a ``NOT_SUPPORTED`` exception, meaning + that ``SQL MERGE`` is not supported by the connector. + +* ``getMergeRowIdColumnHandle()``:: + + ColumnHandle getMergeRowIdColumnHandle( + ConnectorSession session, + ConnectorTableHandle tableHandle, + MergeDetails mergeDetails) + + This method is called in the early stages of query planning for ``MERGE`` + statements. The ``ColumnHandle`` returned provides the ``rowId`` used by the + connector to identify rows to be merged, as well as any other fields of + the row that the connector needs to complete the ``MERGE`` operation. + +* ``getWriteRedistributionColumns()``:: + + List<ColumnHandle> getWriteRedistributionColumns( + ConnectorSession session, + ConnectorTableHandle table) + + This method returns a list of ``ColumnHandles`` for table columns that + impact write redistribution, e.g., columns that impact partitioning or + bucketing. By default, this method returns an empty list.
+ +* ``beginMerge()``:: + + ConnectorMergeTableHandle beginMerge( + ConnectorSession session, + ConnectorTableHandle tableHandle, + MergeDetails mergeDetails) + + As the last step in creating the ``MERGE`` execution plan, the connector's + ``beginMerge()`` method is called, passing the ``session``, the + ``tableHandle`` and the ``MergeDetails`` object. + + ``beginMerge()`` performs any orchestration needed in the connector to + start processing the ``MERGE``. This orchestration varies from connector + to connector. In the Hive ACID connector, for example, ``beginMerge()`` + checks that the table is transactional and that all updated columns are + writable, and starts a Hive Metastore transaction. + + ``beginMerge()`` returns a ``ConnectorMergeTableHandle`` with any added + information the connector needs when the handle is passed back to + ``finishMerge()`` and the split generation machinery. For most + connectors, the returned table handle contains at least a flag identifying + the table handle as a table handle for a ``MERGE`` operation. + +* ``finishMerge()``:: + + void finishMerge( + ConnectorSession session, + ConnectorMergeTableHandle tableHandle, + Collection<Slice> fragments, Collection<ComputedStatistics> computedStatistics) + + During ``MERGE`` processing, the Trino engine accumulates the ``Slice`` + collections returned by ``ConnectorMergeSink.finish()``. The engine calls + ``finishMerge()``, passing the table handle and that collection of + ``Slice`` fragments. In response, the connector takes appropriate actions + to complete the ``MERGE`` operation. Those actions might include + committing the transaction, assuming the connector supports a transaction + paradigm. diff --git a/docs/src/main/sphinx/sql.rst b/docs/src/main/sphinx/sql.rst index 2f23244fce4e..b24cbc41493d 100644 --- a/docs/src/main/sphinx/sql.rst +++ b/docs/src/main/sphinx/sql.rst @@ -38,6 +38,7 @@ Trino also provides :doc:`numerous SQL functions and operators`.
sql/grant sql/grant-roles sql/insert + sql/merge sql/prepare sql/reset-session sql/revoke diff --git a/docs/src/main/sphinx/sql/merge.rst b/docs/src/main/sphinx/sql/merge.rst new file mode 100644 index 000000000000..671a3cd7ddcc --- /dev/null +++ b/docs/src/main/sphinx/sql/merge.rst @@ -0,0 +1,94 @@ +===== +MERGE +===== + +Synopsis +-------- + +.. code-block:: text + + MERGE INTO target_table [ [ AS ] target_alias ] + USING { source_table | query } [ [ AS ] source_alias ] + ON search_condition + when_clause [...] + +where ``when_clause`` is one of + +.. code-block:: text + + WHEN MATCHED [ AND condition ] + THEN DELETE + +.. code-block:: text + + WHEN MATCHED [ AND condition ] + THEN UPDATE SET ( column = expression [, ...] ) + +.. code-block:: text + + WHEN NOT MATCHED [ AND condition ] + THEN INSERT [ column_list ] VALUES (expression, ...) + +Description +----------- + +Conditionally update and/or delete rows of a table and/or insert new +rows into a table. + +``MERGE`` supports an arbitrary number of ``WHEN`` clauses with different +``MATCHED`` conditions, executing the ``DELETE``, ``UPDATE`` or ``INSERT`` +operation in the first ``WHEN`` clause selected by the ``MATCHED`` +state and the match condition. + +In ``WHEN`` clauses with ``UPDATE`` operations, the column value expressions +can depend on any field of the target or the source. In the ``NOT MATCHED`` +case, the ``INSERT`` expressions can depend only on the source. + +Each row in the target table must be matched by at most one row in the source. +If more than one source row is matched by a target table row, a +``MERGE_TARGET_ROW_MULTIPLE_MATCHES`` exception is raised.
+ + +Examples +-------- + +Delete all customers mentioned in the source table:: + + MERGE INTO target_table t USING source_table s + ON (t.customer = s.customer) + WHEN MATCHED + THEN DELETE + +For matching customer rows, increment the purchases, and if there is no +match, insert the row from the source table:: + + MERGE INTO target_table t USING source_table s + ON (t.customer = s.customer) + WHEN MATCHED + THEN UPDATE SET purchases = s.purchases + t.purchases + WHEN NOT MATCHED + THEN INSERT (customer, purchases, address) + VALUES(s.customer, s.purchases, s.address) + +``MERGE`` into the target table from the source table, deleting any matching +target row for which the source address is Centreville. For all other +matching rows, add the source purchases and set the address to the source +address, if there is no match in the target table, insert the source +table row:: + + MERGE INTO target_table t USING source_table s + ON (t.customer = s.customer) + WHEN MATCHED AND s.address = 'Centreville' + THEN DELETE + WHEN MATCHED + THEN UPDATE + SET purchases = s.purchases + t.purchases, address = s.address + WHEN NOT MATCHED + THEN INSERT (customer, purchases, address) + VALUES(s.customer, s.purchases, s.address) + +Limitations +----------- + +Some connectors have limited or no support for ``MERGE``. +See connector documentation for more details. diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMergeSink.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMergeSink.java new file mode 100644 index 000000000000..12684af8b9e0 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMergeSink.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.classloader; + +import io.airlift.slice.Slice; +import io.trino.spi.Page; +import io.trino.spi.classloader.ThreadContextClassLoader; +import io.trino.spi.connector.ConnectorMergeSink; + +import javax.inject.Inject; + +import java.util.Collection; +import java.util.concurrent.CompletableFuture; + +import static java.util.Objects.requireNonNull; + +public class ClassLoaderSafeConnectorMergeSink + implements ConnectorMergeSink +{ + private final ConnectorMergeSink delegate; + private final ClassLoader classLoader; + + @Inject + public ClassLoaderSafeConnectorMergeSink(@ForClassLoaderSafe ConnectorMergeSink delegate, ClassLoader classLoader) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + this.classLoader = requireNonNull(classLoader, "classLoader is null"); + } + + @Override + public void storeMergedRows(Page page) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + delegate.storeMergedRows(page); + } + } + + @Override + public CompletableFuture> finish() + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.finish(); + } + } +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java index 1bb31955d5dc..3ab020f1b791 100644 --- 
a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java @@ -22,6 +22,7 @@ import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.connector.ConnectorInsertTableHandle; import io.trino.spi.connector.ConnectorMaterializedViewDefinition; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorMetadata; import io.trino.spi.connector.ConnectorNewTableLayout; import io.trino.spi.connector.ConnectorOutputMetadata; @@ -45,7 +46,9 @@ import io.trino.spi.connector.JoinType; import io.trino.spi.connector.LimitApplicationResult; import io.trino.spi.connector.MaterializedViewFreshness; +import io.trino.spi.connector.MergeDetails; import io.trino.spi.connector.ProjectionApplicationResult; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SampleType; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; @@ -115,6 +118,14 @@ public Optional getCommonPartitioningHandle(Connect } } + @Override + public List getWriteRedistributionColumns(ConnectorSession session, ConnectorTableHandle table) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getWriteRedistributionColumns(session, table); + } + } + @Override public ConnectorTableLayoutHandle makeCompatiblePartitioning(ConnectorSession session, ConnectorTableLayoutHandle tableLayoutHandle, ConnectorPartitioningHandle partitioningHandle) { @@ -569,6 +580,14 @@ public ColumnHandle getUpdateRowIdColumnHandle(ConnectorSession session, Connect } } + @Override + public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle, MergeDetails mergeDetails) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return 
delegate.getMergeRowIdColumnHandle(session, tableHandle, mergeDetails); + } + } + @Override public ConnectorTableHandle beginDelete(ConnectorSession session, ConnectorTableHandle tableHandle) { @@ -890,4 +909,35 @@ public void finishUpdate(ConnectorSession session, ConnectorTableHandle tableHan delegate.finishUpdate(session, tableHandle, fragments); } } + + @Override + public RowChangeParadigm getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getRowChangeParadigm(session, tableHandle); + } + } + + @Override + public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, MergeDetails mergeDetails) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.beginMerge(session, tableHandle, mergeDetails); + } + } + + @Override + public void finishMerge(ConnectorSession session, ConnectorMergeTableHandle tableHandle, Collection fragments, Collection computedStatistics) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + delegate.finishMerge(session, tableHandle, fragments, computedStatistics); + } + } + + @Override + protected Object clone() + throws CloneNotSupportedException + { + return super.clone(); + } } diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorPageSinkProvider.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorPageSinkProvider.java index 4898cf8da4f1..8708a84ba4e6 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorPageSinkProvider.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorPageSinkProvider.java @@ -15,6 +15,8 @@ import 
io.trino.spi.classloader.ThreadContextClassLoader; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeSink; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorOutputTableHandle; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorPageSinkProvider; @@ -53,4 +55,12 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa return new ClassLoaderSafeConnectorPageSink(delegate.createPageSink(transactionHandle, session, insertTableHandle), classLoader); } } + + @Override + public ConnectorMergeSink createMergeSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorMergeTableHandle mergeHandle) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return new ClassLoaderSafeConnectorMergeSink(delegate.createMergeSink(transactionHandle, session, mergeHandle), classLoader); + } + } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/AbstractHiveAcidWriters.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/AbstractHiveAcidWriters.java index 4668355e2ea0..e35c7e64c1f1 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/AbstractHiveAcidWriters.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/AbstractHiveAcidWriters.java @@ -15,12 +15,18 @@ import io.trino.plugin.hive.acid.AcidOperation; import io.trino.plugin.hive.acid.AcidTransaction; +import io.trino.plugin.hive.orc.OrcFileWriter; import io.trino.plugin.hive.orc.OrcFileWriterFactory; +import io.trino.spi.Page; import io.trino.spi.block.Block; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.type.TypeManager; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; +import java.util.List; import java.util.Optional; 
import java.util.OptionalInt; import java.util.Properties; @@ -31,11 +37,15 @@ import static io.trino.orc.OrcWriter.OrcOperation.DELETE; import static io.trino.orc.OrcWriter.OrcOperation.INSERT; import static io.trino.plugin.hive.HiveStorageFormat.ORC; +import static io.trino.plugin.hive.HiveUpdatablePageSource.BUCKET_CHANNEL; +import static io.trino.plugin.hive.HiveUpdatablePageSource.ORIGINAL_TRANSACTION_CHANNEL; +import static io.trino.plugin.hive.HiveUpdatablePageSource.ROW_ID_CHANNEL; import static io.trino.plugin.hive.acid.AcidSchema.ACID_COLUMN_NAMES; import static io.trino.plugin.hive.acid.AcidSchema.createAcidSchema; import static io.trino.plugin.hive.metastore.StorageFormat.fromHiveStorageFormat; import static io.trino.plugin.hive.util.ConfigurationUtils.toJobConf; import static io.trino.spi.predicate.Utils.nativeValueToBlock; +import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -57,17 +67,21 @@ public abstract class AbstractHiveAcidWriters protected final AcidTransaction transaction; protected final OptionalInt bucketNumber; protected final int statementId; + protected final Block bucketValueBlock; + private final OrcFileWriterFactory orcFileWriterFactory; private final Configuration configuration; protected final ConnectorSession session; protected final HiveType hiveRowType; private final Properties hiveAcidSchema; + protected final Block hiveRowTypeNullsBlock; + protected Path deltaDirectory; protected final Path deleteDeltaDirectory; private final String bucketFilename; - protected Optional deltaDirectory; protected Optional deleteFileWriter = Optional.empty(); protected Optional insertFileWriter = Optional.empty(); + private int insertRowCounter; public AbstractHiveAcidWriters( AcidTransaction transaction, @@ -78,18 +92,21 @@ public AbstractHiveAcidWriters( OrcFileWriterFactory orcFileWriterFactory, 
Configuration configuration, ConnectorSession session, + TypeManager typeManager, HiveType hiveRowType, AcidOperation updateKind) { this.transaction = requireNonNull(transaction, "transaction is null"); this.statementId = statementId; this.bucketNumber = requireNonNull(bucketNumber, "bucketNumber is null"); + this.bucketValueBlock = nativeValueToBlock(INTEGER, Long.valueOf(OrcFileWriter.computeBucketValue(bucketNumber.orElse(0), statementId))); this.orcFileWriterFactory = requireNonNull(orcFileWriterFactory, "orcFileWriterFactory is null"); this.configuration = requireNonNull(configuration, "configuration is null"); this.session = requireNonNull(session, "session is null"); checkArgument(transaction.isTransactional(), "Not in a transaction: %s", transaction); this.hiveRowType = requireNonNull(hiveRowType, "hiveRowType is null"); this.hiveAcidSchema = createAcidSchema(hiveRowType); + this.hiveRowTypeNullsBlock = nativeValueToBlock(hiveRowType.getType(typeManager), null); requireNonNull(bucketPath, "bucketPath is null"); Matcher matcher; if (originalFile) { @@ -109,13 +126,33 @@ public AbstractHiveAcidWriters( } } long writeId = transaction.getWriteId(); + this.deltaDirectory = new Path(format("%s/%s", matcher.group("rootDir"), deltaSubdir(writeId, writeId, statementId))); this.deleteDeltaDirectory = new Path(format("%s/%s", matcher.group("rootDir"), deleteDeltaSubdir(writeId, writeId, statementId))); - if (updateKind == AcidOperation.UPDATE) { - this.deltaDirectory = Optional.of(new Path(format("%s/%s", matcher.group("rootDir"), deltaSubdir(writeId, writeId, statementId)))); - } - else { - this.deltaDirectory = Optional.empty(); + } + + protected Page buildDeletePage(Block rowIds, long writeId) + { + int positionCount = rowIds.getPositionCount(); + List blocks = rowIds.getChildren(); + Block[] blockArray = { + new RunLengthEncodedBlock(DELETE_OPERATION_BLOCK, positionCount), + blocks.get(ORIGINAL_TRANSACTION_CHANNEL), + blocks.get(BUCKET_CHANNEL), + 
blocks.get(ROW_ID_CHANNEL), + RunLengthEncodedBlock.create(BIGINT, writeId, positionCount), + new RunLengthEncodedBlock(hiveRowTypeNullsBlock, positionCount), + }; + return new Page(blockArray); + } + + protected Block createRowIdBlock(int positionCount) + { + long[] rowIds = new long[positionCount]; + for (int index = 0; index < positionCount; index++) { + rowIds[index] = insertRowCounter; + insertRowCounter++; } + return new LongArrayBlock(positionCount, Optional.empty(), rowIds); } protected void lazyInitializeDeleteFileWriter() @@ -142,9 +179,8 @@ protected void lazyInitializeInsertFileWriter() if (insertFileWriter.isEmpty()) { Properties schemaCopy = new Properties(); schemaCopy.putAll(hiveAcidSchema); - Path deltaDir = deltaDirectory.orElseThrow(() -> new IllegalArgumentException("deltaDirectory not present")); insertFileWriter = orcFileWriterFactory.createFileWriter( - new Path(format("%s/%s", deltaDir, bucketFilename)), + new Path(format("%s/%s", deltaDirectory, bucketFilename)), ACID_COLUMN_NAMES, fromHiveStorageFormat(ORC), schemaCopy, diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveColumnHandle.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveColumnHandle.java index b9ca91d2c683..3ca118d6a7f5 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveColumnHandle.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveColumnHandle.java @@ -15,24 +15,31 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.Sets; import io.trino.plugin.hive.metastore.Column; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; +import io.trino.spi.connector.MergeCaseDetails; +import io.trino.spi.connector.MergeCaseKind; +import io.trino.spi.connector.MergeDetails; import io.trino.spi.type.Type; +import java.util.HashSet; import java.util.List; import java.util.Objects; import 
java.util.Optional; +import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.PARTITION_KEY; import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.SYNTHESIZED; import static io.trino.plugin.hive.HiveType.HIVE_INT; import static io.trino.plugin.hive.HiveType.HIVE_LONG; import static io.trino.plugin.hive.HiveType.HIVE_STRING; import static io.trino.plugin.hive.HiveType.toHiveType; -import static io.trino.plugin.hive.HiveUpdateProcessor.getUpdateRowIdColumnHandle; +import static io.trino.plugin.hive.HiveUpdateProcessor.getRowIdColumnHandleForNonUpdatedColumns; import static io.trino.plugin.hive.acid.AcidSchema.ACID_ROW_ID_ROW_TYPE; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; @@ -265,7 +272,34 @@ public static HiveColumnHandle updateRowIdColumnHandle(List co List nonUpdatedColumnHandles = columnHandles.stream() .filter(column -> !column.isPartitionKey() && !column.isHidden() && !updatedColumns.contains(column)) .collect(toImmutableList()); - return getUpdateRowIdColumnHandle(nonUpdatedColumnHandles); + return getRowIdColumnHandleForNonUpdatedColumns(nonUpdatedColumnHandles); + } + + public static HiveColumnHandle mergeRowIdColumnHandle() + { + return createBaseColumn(UPDATE_ROW_ID_COLUMN_NAME, UPDATE_ROW_ID_COLUMN_INDEX, toHiveType(ACID_ROW_ID_ROW_TYPE), ACID_ROW_ID_ROW_TYPE, SYNTHESIZED, Optional.empty()); + } + + /** + * Return union of all columns not updated in some merge case + * @param allColumns All table data columns, in table column order + * @param mergeDetails A MergeDetails instance + * @return the list of all columns updated in _some_ merge case but not + * updated in other(s), in table column order. 
+ */ + public static List computeNonUpdatedColumns(List allColumns, MergeDetails mergeDetails) + { + Set allColumnsNotUpdated = new HashSet<>(); + Set allColumnNames = allColumns.stream().map(HiveColumnHandle::getName).collect(toImmutableSet()); + for (MergeCaseDetails mergeCase : mergeDetails.getCases()) { + if (mergeCase.getCaseKind() != MergeCaseKind.DELETE) { + allColumnsNotUpdated.addAll(Sets.difference(allColumnNames, mergeCase.getUpdatedColumns())); + } + } + + return allColumns.stream() + .filter(column -> allColumnsNotUpdated.contains(column.getName())) + .collect(toImmutableList()); } public static HiveColumnHandle pathColumnHandle() diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveHandleResolver.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveHandleResolver.java index f02a499af864..b4f1a875027b 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveHandleResolver.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveHandleResolver.java @@ -16,6 +16,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorHandleResolver; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorOutputTableHandle; import io.trino.spi.connector.ConnectorPartitioningHandle; import io.trino.spi.connector.ConnectorSplit; @@ -55,6 +56,12 @@ public Class getInsertTableHandleClass() return HiveInsertTableHandle.class; } + @Override + public Class getMergeTableHandleClass() + { + return HiveMergeTableHandle.class; + } + @Override public Class getTransactionHandleClass() { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMergeTableHandle.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMergeTableHandle.java new file mode 100644 index 000000000000..15837d9d861a --- /dev/null +++ 
b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMergeTableHandle.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.connector.ConnectorMergeTableHandle; + +import static java.util.Objects.requireNonNull; + +public class HiveMergeTableHandle + implements ConnectorMergeTableHandle +{ + private final HiveTableHandle tableHandle; + private final HiveInsertTableHandle insertHandle; + + @JsonCreator + public HiveMergeTableHandle( + @JsonProperty("tableHandle") HiveTableHandle tableHandle, + @JsonProperty("insertHandle") HiveInsertTableHandle insertHandle) + { + this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); + this.insertHandle = requireNonNull(insertHandle, "insertHandle is null"); + } + + @Override + @JsonProperty + public HiveTableHandle getTableHandle() + { + return tableHandle; + } + + @JsonProperty + public HiveInsertTableHandle getInsertHandle() + { + return insertHandle; + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java index 7af0c762e746..b6ae0b3f832e 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java @@ -60,6 +60,7 
@@ import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.connector.ConnectorInsertTableHandle; import io.trino.spi.connector.ConnectorMaterializedViewDefinition; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorNewTableLayout; import io.trino.spi.connector.ConnectorOutputMetadata; import io.trino.spi.connector.ConnectorOutputTableHandle; @@ -76,7 +77,9 @@ import io.trino.spi.connector.DiscretePredicates; import io.trino.spi.connector.InMemoryRecordSet; import io.trino.spi.connector.MaterializedViewFreshness; +import io.trino.spi.connector.MergeDetails; import io.trino.spi.connector.ProjectionApplicationResult; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SchemaNotFoundException; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; @@ -163,6 +166,7 @@ import static io.trino.plugin.hive.HiveColumnHandle.PARTITION_COLUMN_NAME; import static io.trino.plugin.hive.HiveColumnHandle.PATH_COLUMN_NAME; import static io.trino.plugin.hive.HiveColumnHandle.createBaseColumn; +import static io.trino.plugin.hive.HiveColumnHandle.mergeRowIdColumnHandle; import static io.trino.plugin.hive.HiveColumnHandle.updateRowIdColumnHandle; import static io.trino.plugin.hive.HiveErrorCode.HIVE_COLUMN_ORDER_MISMATCH; import static io.trino.plugin.hive.HiveErrorCode.HIVE_CONCURRENT_MODIFICATION_DETECTED; @@ -262,6 +266,7 @@ import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.StandardErrorCode.SCHEMA_NOT_EMPTY; import static io.trino.spi.StandardErrorCode.TABLE_NOT_FOUND; +import static io.trino.spi.connector.RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW; import static io.trino.spi.predicate.TupleDomain.withColumnDomains; import static io.trino.spi.statistics.TableStatisticType.ROW_COUNT; import static io.trino.spi.type.BigintType.BIGINT; @@ -1647,7 +1652,8 @@ public void finishUpdate(ConnectorSession session, 
ConnectorTableHandle tableHan HdfsContext context = new HdfsContext(session); for (PartitionAndStatementId ps : partitionAndStatementIds) { - createOrcAcidVersionFile(context, new Path(ps.getDeleteDeltaDirectory())); + createOrcAcidVersionFile(context, new Path(ps.getDeltaDirectory().get())); + createOrcAcidVersionFile(context, new Path(ps.getDeleteDeltaDirectory().get())); } LocationHandle locationHandle = locationService.forExistingTable(metastore, session, table); @@ -1655,8 +1661,92 @@ public void finishUpdate(ConnectorSession session, ConnectorTableHandle tableHan metastore.finishUpdate(session, table.getDatabaseName(), table.getTableName(), writeInfo.getWritePath(), partitionAndStatementIds); } + @Override + public RowChangeParadigm getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle) + { + HiveTableHandle handle = (HiveTableHandle) tableHandle; + Optional> properties = handle.getTableParameters(); + if (isTransactionalTable(properties.get())) { + return DELETE_ROW_AND_INSERT_ROW; + } + throw new TrinoException(NOT_SUPPORTED, "Hive merge is only supported for transactional tables"); + } + + @Override + public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, MergeDetails mergeDetails) + { + HiveTableHandle hiveTableHandle = (HiveTableHandle) tableHandle; + HiveIdentity identity = new HiveIdentity(session); + SchemaTableName tableName = hiveTableHandle.getSchemaTableName(); + Table table = metastore.getTable(identity, tableName.getSchemaName(), tableName.getTableName()) + .orElseThrow(() -> new TableNotFoundException(tableName)); + + if (!isTransactionalTable(table.getParameters())) { + throw new TrinoException(NOT_SUPPORTED, "Hive merge is only supported for transactional tables"); + } + + checkTableIsWritable(table, writesToNonManagedTablesEnabled); + + for (Column column : table.getDataColumns()) { + if (!isWritableType(column.getType())) { + throw new TrinoException(NOT_SUPPORTED, 
format("Updating a Hive table with column type %s not supported", column.getType())); + } + } + + if (table.getParameters().containsKey(SKIP_HEADER_COUNT_KEY)) { + throw new TrinoException(NOT_SUPPORTED, format("Updating a Hive table with %s property not supported", SKIP_HEADER_COUNT_KEY)); + } + if (table.getParameters().containsKey(SKIP_FOOTER_COUNT_KEY)) { + throw new TrinoException(NOT_SUPPORTED, format("Updating a Hive table with %s property not supported", SKIP_FOOTER_COUNT_KEY)); + } + + HiveInsertTableHandle insertHandle = beginInsertOrMerge(session, tableHandle, "Merging", true); + return new HiveMergeTableHandle(hiveTableHandle.withTransaction(insertHandle.getTransaction()), insertHandle); + } + + @Override + public void finishMerge(ConnectorSession session, ConnectorMergeTableHandle tableHandle, Collection fragments, Collection computedStatistics) + { + HiveMergeTableHandle mergeHandle = (HiveMergeTableHandle) tableHandle; + HiveInsertTableHandle insertHandle = mergeHandle.getInsertHandle(); + HiveTableHandle handle = mergeHandle.getTableHandle(); + checkArgument(handle.isAcidMerge(), "handle should be a merge handle, but is %s", handle); + + requireNonNull(fragments, "fragments is null"); + List partitionMergeResults = fragments.stream() + .map(Slice::getBytes) + .map(PartitionUpdateAndMergeResults.CODEC::fromJson) + .collect(toList()); + + List partitionUpdates = partitionMergeResults.stream() + .map(PartitionUpdateAndMergeResults::getPartitionUpdate) + .collect(toImmutableList()); + + Table table = finishChangingTable(AcidOperation.MERGE, "merge", session, insertHandle, partitionUpdates, computedStatistics); + + HdfsContext context = new HdfsContext(session); + for (PartitionUpdateAndMergeResults ps : partitionMergeResults) { + ps.getDeltaDirectory().ifPresent(deltaDirectory -> createOrcAcidVersionFile(context, new Path(deltaDirectory))); + ps.getDeleteDeltaDirectory().ifPresent(deleteDeltadDirectory -> createOrcAcidVersionFile(context, new 
Path(deleteDeltadDirectory))); + } + + List partitions = partitionUpdates.stream() + .filter(update -> !update.getName().isEmpty()) + .map(update -> buildPartitionObject(session, table, update)) + .collect(toImmutableList()); + + LocationHandle locationHandle = locationService.forExistingTable(metastore, session, table); + WriteInfo writeInfo = locationService.getQueryWriteInfo(locationHandle); + metastore.finishMerge(session, table.getDatabaseName(), table.getTableName(), writeInfo.getWritePath(), partitionMergeResults, partitions); + } + @Override public HiveInsertTableHandle beginInsert(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return beginInsertOrMerge(session, tableHandle, "Inserting", false); + } + + private HiveInsertTableHandle beginInsertOrMerge(ConnectorSession session, ConnectorTableHandle tableHandle, String description, boolean isForMerge) { HiveIdentity identity = new HiveIdentity(session); SchemaTableName tableName = ((HiveTableHandle) tableHandle).getSchemaTableName(); @@ -1667,7 +1757,7 @@ public HiveInsertTableHandle beginInsert(ConnectorSession session, ConnectorTabl for (Column column : table.getDataColumns()) { if (!isWritableType(column.getType())) { - throw new TrinoException(NOT_SUPPORTED, format("Inserting into Hive table %s with column type %s not supported", tableName, column.getType())); + throw new TrinoException(NOT_SUPPORTED, format("%s into Hive table %s with column type %s not supported", description, tableName, column.getType())); } } @@ -1677,15 +1767,24 @@ public HiveInsertTableHandle beginInsert(ConnectorSession session, ConnectorTabl HiveStorageFormat tableStorageFormat = extractHiveStorageFormat(table); if (table.getParameters().containsKey(SKIP_HEADER_COUNT_KEY)) { - throw new TrinoException(NOT_SUPPORTED, format("Inserting into Hive table with %s property not supported", SKIP_HEADER_COUNT_KEY)); + throw new TrinoException(NOT_SUPPORTED, format("%s into Hive table with %s property not supported", 
description, SKIP_HEADER_COUNT_KEY)); } if (table.getParameters().containsKey(SKIP_FOOTER_COUNT_KEY)) { - throw new TrinoException(NOT_SUPPORTED, format("Inserting into Hive table with %s property not supported", SKIP_FOOTER_COUNT_KEY)); + throw new TrinoException(NOT_SUPPORTED, format("%s into Hive table with %s property not supported", description, SKIP_FOOTER_COUNT_KEY)); } LocationHandle locationHandle = locationService.forExistingTable(metastore, session, table); - AcidTransaction transaction = isTransactionalTable(table.getParameters()) ? metastore.beginInsert(session, table) : NO_ACID_TRANSACTION; - + AcidTransaction transaction = NO_ACID_TRANSACTION; + boolean isTransactional = isTransactionalTable(table.getParameters()); + if (isForMerge) { + checkArgument(isTransactional, "It's a merge, but isTransactional is false"); + transaction = metastore.beginMerge(session, table); + } + else { + if (isTransactional) { + transaction = metastore.beginInsert(session, table); + } + } HiveInsertTableHandle result = new HiveInsertTableHandle( tableName.getSchemaName(), tableName.getTableName(), @@ -1712,13 +1811,32 @@ public Optional finishInsert(ConnectorSession session, .map(partitionUpdateCodec::fromJson) .collect(toList()); + Table table = finishChangingTable(AcidOperation.INSERT, "insert", session, handle, partitionUpdates, computedStatistics); + + if (isFullAcidTable(table.getParameters())) { + HdfsContext context = new HdfsContext(session); + for (PartitionUpdate update : partitionUpdates) { + long writeId = handle.getTransaction().getWriteId(); + Path deltaDirectory = new Path(format("%s/%s/%s", table.getStorage().getLocation(), update.getName(), deltaSubdir(writeId, writeId, 0))); + createOrcAcidVersionFile(context, deltaDirectory); + } + } + + return Optional.of(new HiveWrittenPartitions( + partitionUpdates.stream() + .map(PartitionUpdate::getName) + .collect(toImmutableList()))); + } + + private Table finishChangingTable(AcidOperation acidOperation, String 
changeDescription, ConnectorSession session, HiveInsertTableHandle handle, List partitionUpdates, Collection computedStatistics) + { HiveStorageFormat tableStorageFormat = handle.getTableStorageFormat(); partitionUpdates = PartitionUpdate.mergePartitionUpdates(partitionUpdates); Table table = metastore.getTable(new HiveIdentity(session), handle.getSchemaName(), handle.getTableName()) .orElseThrow(() -> new TableNotFoundException(handle.getSchemaTableName())); if (!table.getStorage().getStorageFormat().getInputFormat().equals(tableStorageFormat.getInputFormat()) && isRespectTableFormat(session)) { - throw new TrinoException(HIVE_CONCURRENT_MODIFICATION_DETECTED, "Table format changed during insert"); + throw new TrinoException(HIVE_CONCURRENT_MODIFICATION_DETECTED, "Table format changed during " + changeDescription); } if (handle.getBucketProperty().isPresent() && isCreateEmptyBucketFiles(session)) { @@ -1746,7 +1864,7 @@ public Optional finishInsert(ConnectorSession session, if (partitionUpdate.getName().isEmpty()) { // insert into unpartitioned table if (!table.getStorage().getStorageFormat().getInputFormat().equals(handle.getPartitionStorageFormat().getInputFormat()) && isRespectTableFormat(session)) { - throw new TrinoException(HIVE_CONCURRENT_MODIFICATION_DETECTED, "Table format changed during insert"); + throw new TrinoException(HIVE_CONCURRENT_MODIFICATION_DETECTED, "Table format changed during " + changeDescription); } PartitionStatistics partitionStatistics = createPartitionStatistics( @@ -1766,7 +1884,8 @@ public Optional finishInsert(ConnectorSession session, } else if (partitionUpdate.getUpdateMode() == NEW || partitionUpdate.getUpdateMode() == APPEND) { // insert into unpartitioned table - metastore.finishInsertIntoExistingTable( + metastore.finishChangingExistingTable( + acidOperation, session, handle.getSchemaName(), handle.getTableName(), @@ -1813,20 +1932,7 @@ else if (partitionUpdate.getUpdateMode() == NEW || partitionUpdate.getUpdateMode throw new 
IllegalArgumentException(format("Unsupported update mode: %s", partitionUpdate.getUpdateMode())); } } - - if (isFullAcidTable(table.getParameters())) { - HdfsContext context = new HdfsContext(session); - for (PartitionUpdate update : partitionUpdates) { - long writeId = handle.getTransaction().getWriteId(); - Path deltaDirectory = new Path(format("%s/%s/%s", table.getStorage().getLocation(), update.getName(), deltaSubdir(writeId, writeId, 0))); - createOrcAcidVersionFile(context, deltaDirectory); - } - } - - return Optional.of(new HiveWrittenPartitions( - partitionUpdates.stream() - .map(PartitionUpdate::getName) - .collect(toImmutableList()))); + return table; } private void createOrcAcidVersionFile(HdfsContext context, Path deltaDirectory) @@ -2115,7 +2221,6 @@ public void finishDelete(ConnectorSession session, ConnectorTableHandle tableHan HiveIdentity identity = new HiveIdentity(session); Table table = metastore.getTable(identity, tableName.getSchemaName(), tableName.getTableName()) .orElseThrow(() -> new TableNotFoundException(tableName)); - ensureTableSupportsDelete(table); List partitionAndStatementIds = fragments.stream() .map(Slice::getBytes) @@ -2124,7 +2229,7 @@ public void finishDelete(ConnectorSession session, ConnectorTableHandle tableHan HdfsContext context = new HdfsContext(session); for (PartitionAndStatementId ps : partitionAndStatementIds) { - createOrcAcidVersionFile(context, new Path(ps.getDeleteDeltaDirectory())); + createOrcAcidVersionFile(context, new Path(ps.getDeleteDeltaDirectory().get())); } LocationHandle locationHandle = locationService.forExistingTable(metastore, session, table); @@ -2152,6 +2257,12 @@ public ColumnHandle getUpdateRowIdColumnHandle(ConnectorSession session, Connect return updateRowIdColumnHandle(table.getDataColumns(), updatedColumns); } + @Override + public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle, MergeDetails mergeDetails) + { + return 
mergeRowIdColumnHandle(); + } + @Override public Optional applyDelete(ConnectorSession session, ConnectorTableHandle handle) { @@ -2447,6 +2558,19 @@ public Optional getCommonPartitioningHandle(Connect false)); } + @Override + public List getWriteRedistributionColumns(ConnectorSession session, ConnectorTableHandle tableHandle) + { + HiveTableHandle hiveTableHandle = (HiveTableHandle) tableHandle; + Map columnNames = new HashMap<>(); + hiveTableHandle.getPartitionColumns().stream().forEach(column -> columnNames.put(column.getBaseColumnName(), column)); + hiveTableHandle.getBucketHandle().ifPresent(handle -> handle.getColumns().forEach(column -> columnNames.put(column.getBaseColumnName(), column))); + return getTableMetadata(session, hiveTableHandle.getSchemaTableName()).getColumns().stream() + .map(column -> columnNames.get(column.getName())) + .filter(column -> column != null) + .collect(toImmutableList()); + } + private static OptionalInt min(OptionalInt left, OptionalInt right) { if (left.isEmpty()) { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSink.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSink.java index d275b2aa9c80..a31218bc55dd 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSink.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSink.java @@ -29,6 +29,8 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.IntArrayBlockBuilder; +import io.trino.spi.block.RowBlock; +import io.trino.spi.connector.ConnectorMergeSink; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.Type; @@ -46,6 +48,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executors; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; import static 
com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.slice.Slices.wrappedBuffer; @@ -57,7 +60,7 @@ import static java.util.stream.Collectors.toList; public class HivePageSink - implements ConnectorPageSink + implements ConnectorPageSink, ConnectorMergeSink { private static final Logger log = Logger.get(HivePageSink.class); @@ -83,11 +86,13 @@ public class HivePageSink private final ConnectorSession session; + private final boolean isMergeSink; private long writtenBytes; private long systemMemoryUsage; private long validationCpuNanos; public HivePageSink( + HiveWritableTableHandle tableHandle, HiveWriterFactory writerFactory, List inputColumns, Optional bucketProperty, @@ -109,6 +114,7 @@ public HivePageSink( this.writeVerificationExecutor = requireNonNull(writeVerificationExecutor, "writeVerificationExecutor is null"); this.partitionUpdateCodec = requireNonNull(partitionUpdateCodec, "partitionUpdateCodec is null"); + this.isMergeSink = tableHandle.getTransaction().isMerge(); requireNonNull(bucketProperty, "bucketProperty is null"); this.pagePartitioner = new HiveWriterPagePartitioner( inputColumns, @@ -178,10 +184,28 @@ public CompletableFuture> finish() { // Must be wrapped in doAs entirely // Implicit FileSystem initializations are possible in HiveRecordWriter#commit -> RecordWriter#close - ListenableFuture> result = hdfsEnvironment.doAs(session.getUser(), this::doFinish); + ListenableFuture> result = hdfsEnvironment.doAs( + session.getUser(), + isMergeSink ? 
this::doMergeSinkFinish : this::doFinish); return MoreFutures.toCompletableFuture(result); } + private ListenableFuture> doMergeSinkFinish() + { + ImmutableList.Builder resultSlices = ImmutableList.builder(); + for (HiveWriter writer : writers) { + writer.commit(); + MergeFileWriter mergeFileWriter = (MergeFileWriter) writer.getFileWriter(); + PartitionUpdateAndMergeResults results = mergeFileWriter.getPartitionUpdateAndMergeResults(writer.getPartitionUpdate()); + resultSlices.add(wrappedBuffer(PartitionUpdateAndMergeResults.CODEC.toJsonBytes(results))); + } + List result = resultSlices.build(); + writtenBytes = writers.stream() + .mapToLong(HiveWriter::getWrittenBytes) + .sum(); + return Futures.immediateFuture(result); + } + private ListenableFuture> doFinish() { ImmutableList.Builder partitionUpdates = ImmutableList.builder(); @@ -307,7 +331,7 @@ private void writePage(Page page) Page pageForWriter = dataPage; if (positions.length != dataPage.getPositionCount()) { verify(positions.length == counts[index]); - pageForWriter = pageForWriter.getPositions(positions, 0, positions.length); + pageForWriter = makePageForWriter(pageForWriter, positions); } HiveWriter writer = writers.get(index); @@ -322,11 +346,36 @@ private void writePage(Page page) } } + private Page makePageForWriter(Page page, int[] positions) + { + if (!isMergeSink) { + return page.getPositions(positions, 0, positions.length); + } + int positionCount = positions.length; + Block[] blocks = new Block[page.getChannelCount()]; + for (int channel = 0; channel < page.getChannelCount(); channel++) { + Block block = page.getBlock(channel); + if (block instanceof RowBlock) { + RowBlock rowBlock = (RowBlock) block; + List children = rowBlock.getChildren(); + blocks[channel] = RowBlock.fromFieldBlocks( + positionCount, + Optional.empty(), + children.stream() + .map(child -> child.getPositions(positions, 0, positionCount)) + .toArray(length -> new Block[length])); + } + else { + blocks[channel] = 
block.getPositions(positions, 0, positionCount); + } + } + return new Page(positionCount, blocks); + } + private int[] getWriterIndexes(Page page) { - Page partitionColumns = extractColumns(page, partitionColumnsInputIndex); Block bucketBlock = buildBucketBlock(page); - int[] writerIndexes = pagePartitioner.partitionPage(partitionColumns, bucketBlock); + int[] writerIndexes = pagePartitioner.partitionPage(extractColumns(page, partitionColumnsInputIndex), bucketBlock); if (pagePartitioner.getMaxIndex() >= maxOpenWriters) { throw new TrinoException(HIVE_TOO_MANY_OPEN_PARTITIONS, format("Exceeded limit of %s open writers for partitions/buckets", maxOpenWriters)); } @@ -347,7 +396,7 @@ private int[] getWriterIndexes(Page page) if (bucketBlock != null) { bucketNumber = OptionalInt.of(bucketBlock.getInt(position, 0)); } - HiveWriter writer = writerFactory.createWriter(partitionColumns, position, bucketNumber); + HiveWriter writer = writerFactory.createWriter(extractColumns(page, partitionColumnsInputIndex), position, bucketNumber); writers.set(writerIndex, writer); } verify(writers.size() == pagePartitioner.getMaxIndex() + 1); @@ -358,6 +407,9 @@ private int[] getWriterIndexes(Page page) private Page getDataPage(Page page) { + if (isMergeSink) { + return page; + } Block[] blocks = new Block[dataColumnInputIndex.length]; for (int i = 0; i < dataColumnInputIndex.length; i++) { int dataColumn = dataColumnInputIndex[i]; @@ -391,6 +443,13 @@ private static Page extractColumns(Page page, int[] columns) return new Page(page.getPositionCount(), blocks); } + @Override + public void storeMergedRows(Page page) + { + checkArgument(isMergeSink, "isMergeSink is false"); + appendPage(page); + } + private static class HiveWriterPagePartitioner { private final PageIndexer pageIndexer; diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSinkProvider.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSinkProvider.java index 2588be4b8898..369f0f0685bd 
100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSinkProvider.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSinkProvider.java @@ -28,6 +28,8 @@ import io.trino.spi.PageIndexerFactory; import io.trino.spi.PageSorter; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeSink; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorOutputTableHandle; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorPageSinkProvider; @@ -43,6 +45,7 @@ import java.util.OptionalInt; import java.util.Set; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.trino.plugin.hive.metastore.cache.CachingHiveMetastore.memoizeMetastore; @@ -121,7 +124,16 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transaction, return createPageSink(handle, false, session, ImmutableMap.of() /* for insert properties are taken from metastore */); } - private ConnectorPageSink createPageSink(HiveWritableTableHandle handle, boolean isCreateTable, ConnectorSession session, Map additionalTableParameters) + @Override + public ConnectorMergeSink createMergeSink(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorMergeTableHandle mergeHandle) + { + HiveMergeTableHandle hiveMergeHandle = (HiveMergeTableHandle) mergeHandle; + HiveInsertTableHandle insertHandle = hiveMergeHandle.getInsertHandle(); + checkArgument(insertHandle.getTransaction().isMerge(), "handle isn't an ACID MERGE"); + return createPageSink(insertHandle, false, session, ImmutableMap.of()); + } + + private HivePageSink createPageSink(HiveWritableTableHandle handle, boolean isCreateTable, ConnectorSession session, Map additionalTableParameters) { OptionalInt 
bucketCount = OptionalInt.empty(); List sortedBy = ImmutableList.of(); @@ -163,6 +175,7 @@ private ConnectorPageSink createPageSink(HiveWritableTableHandle handle, boolean hiveWriterStats); return new HivePageSink( + handle, writerFactory, handle.getInputColumns(), handle.getBucketProperty(), diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveTableHandle.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveTableHandle.java index b99895a2f35e..d65316627693 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveTableHandle.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveTableHandle.java @@ -359,6 +359,12 @@ public boolean isAcidUpdate() return transaction.isUpdate(); } + @JsonIgnore + public boolean isAcidMerge() + { + return transaction.isMerge(); + } + @JsonIgnore public Optional getUpdateProcessor() { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdatablePageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdatablePageSource.java index fa61a3d92eb9..2c5e654105b6 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdatablePageSource.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdatablePageSource.java @@ -22,7 +22,6 @@ import io.trino.spi.Page; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; -import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.RowBlock; import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.connector.ConnectorPageSource; @@ -42,7 +41,6 @@ import static com.google.common.base.Verify.verify; import static io.trino.plugin.hive.HiveErrorCode.HIVE_WRITER_CLOSE_ERROR; import static io.trino.plugin.hive.PartitionAndStatementId.CODEC; -import static io.trino.spi.predicate.Utils.nativeValueToBlock; import static io.trino.spi.type.BigintType.BIGINT; import static java.util.Objects.requireNonNull; import static 
java.util.concurrent.CompletableFuture.completedFuture; @@ -62,13 +60,11 @@ public class HiveUpdatablePageSource private final String partitionName; private final ConnectorPageSource hivePageSource; private final AcidOperation updateKind; - private final Block hiveRowTypeNullsBlock; private final long writeId; private final Optional> dependencyChannels; private long maxWriteId; private long rowCount; - private long insertRowCounter; private boolean closed; @@ -88,11 +84,10 @@ public HiveUpdatablePageSource( List dependencyColumns, AcidOperation updateKind) { - super(hiveTableHandle.getTransaction(), statementId, bucketNumber, bucketPath, originalFile, orcFileWriterFactory, configuration, session, hiveRowType, updateKind); + super(hiveTableHandle.getTransaction(), statementId, bucketNumber, bucketPath, originalFile, orcFileWriterFactory, configuration, session, typeManager, hiveRowType, updateKind); this.partitionName = requireNonNull(partitionName, "partitionName is null"); this.hivePageSource = requireNonNull(hivePageSource, "hivePageSource is null"); this.updateKind = requireNonNull(updateKind, "updateKind is null"); - this.hiveRowTypeNullsBlock = nativeValueToBlock(hiveRowType.getType(typeManager), null); checkArgument(hiveTableHandle.isInAcidTransaction(), "Not in a transaction; hiveTableHandle: %s", hiveTableHandle); this.writeId = hiveTableHandle.getWriteId(); if (updateKind == AcidOperation.UPDATE) { @@ -117,15 +112,7 @@ private void deleteRowsInternal(Block rowIds) { int positionCount = rowIds.getPositionCount(); List blocks = rowIds.getChildren(); - Block[] blockArray = { - new RunLengthEncodedBlock(DELETE_OPERATION_BLOCK, positionCount), - blocks.get(ORIGINAL_TRANSACTION_CHANNEL), - blocks.get(BUCKET_CHANNEL), - blocks.get(ROW_ID_CHANNEL), - RunLengthEncodedBlock.create(BIGINT, writeId, positionCount), - new RunLengthEncodedBlock(hiveRowTypeNullsBlock, positionCount), - }; - Page deletePage = new Page(blockArray); + Page deletePage = 
buildDeletePage(rowIds, writeId); Block block = blocks.get(ORIGINAL_TRANSACTION_CHANNEL); for (int index = 0; index < positionCount; index++) { @@ -167,15 +154,6 @@ public void updateRows(Page page, List columnValueAndRowIdChannels) insertFileWriter.orElseThrow(() -> new IllegalArgumentException("insertFileWriter not present")).appendRows(insertPage); } - Block createRowIdBlock(int positionCount) - { - long[] rowIds = new long[positionCount]; - for (int index = 0; index < positionCount; index++) { - rowIds[index] = insertRowCounter++; - } - return new LongArrayBlock(positionCount, Optional.empty(), rowIds); - } - @Override public CompletableFuture> finish() { @@ -196,8 +174,7 @@ public CompletableFuture> finish() OrcFileWriter insertWriter = (OrcFileWriter) insertFileWriter.get(); insertWriter.setMaxWriteId(maxWriteId); insertWriter.commit(); - checkArgument(deltaDirectory.isPresent(), "deltaDirectory not present"); - deltaDirectoryString = Optional.of(deltaDirectory.get().toString()); + deltaDirectoryString = Optional.of(deltaDirectory.toString()); break; default: @@ -207,7 +184,7 @@ public CompletableFuture> finish() partitionName, statementId, rowCount, - deleteDeltaDirectory.toString(), + Optional.of(deleteDeltaDirectory.toString()), deltaDirectoryString))); return completedFuture(ImmutableList.of(fragment)); } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdateProcessor.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdateProcessor.java index 4598324c1ff2..fa21251d5c26 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdateProcessor.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdateProcessor.java @@ -162,9 +162,9 @@ public Page removeNonDependencyColumns(Page page, List dependencyChanne } /** - * Return the column UPDATE column handle, which depends on the 3 ACID columns as well as the non-updated columns. 
+ * Return the rowId column handle for UPDATE or MERGE, which depends on the 3 ACID columns as well as the non-updated columns. */ - public static HiveColumnHandle getUpdateRowIdColumnHandle(List nonUpdatedColumnHandles) + public static HiveColumnHandle getRowIdColumnHandleForNonUpdatedColumns(List nonUpdatedColumnHandles) { List allAcidFields = new ArrayList<>(ACID_READ_FIELDS); if (!nonUpdatedColumnHandles.isEmpty()) { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriter.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriter.java index 531eb4065d26..55b8000791ea 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriter.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriter.java @@ -57,6 +57,11 @@ public HiveWriter( this.hiveWriterStats = requireNonNull(hiveWriterStats, "hiveWriterStats is null"); } + public FileWriter getFileWriter() + { + return fileWriter; + } + public long getWrittenBytes() { return fileWriter.getWrittenBytes(); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriterFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriterFactory.java index 83d752b712dd..bd2f53ed9b67 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriterFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriterFactory.java @@ -40,6 +40,7 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.SortOrder; import io.trino.spi.session.PropertyMetadata; +import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; import org.apache.hadoop.conf.Configuration; @@ -86,7 +87,9 @@ import static io.trino.plugin.hive.HiveSessionProperties.getTemporaryStagingDirectoryPath; import static io.trino.plugin.hive.HiveSessionProperties.getTimestampPrecision; import static io.trino.plugin.hive.HiveSessionProperties.isTemporaryStagingDirectoryEnabled; +import 
static io.trino.plugin.hive.HiveType.toHiveType; import static io.trino.plugin.hive.LocationHandle.WriteMode.DIRECT_TO_TARGET_EXISTING_DIRECTORY; +import static io.trino.plugin.hive.acid.AcidOperation.CREATE_TABLE; import static io.trino.plugin.hive.metastore.MetastoreUtil.getHiveSchema; import static io.trino.plugin.hive.metastore.StorageFormat.fromHiveStorageFormat; import static io.trino.plugin.hive.util.CompressionConfigUtil.configureCompression; @@ -118,6 +121,7 @@ public class HiveWriterFactory private final String schemaName; private final String tableName; private final AcidTransaction transaction; + private final List inputColumns; private final List dataColumns; @@ -155,6 +159,7 @@ public class HiveWriterFactory private final Map sessionProperties; private final HiveWriterStats hiveWriterStats; + private final Optional rowType; public HiveWriterFactory( Set fileWriterFactories, @@ -188,6 +193,7 @@ public HiveWriterFactory( this.schemaName = requireNonNull(schemaName, "schemaName is null"); this.tableName = requireNonNull(tableName, "tableName is null"); this.transaction = requireNonNull(transaction, "transaction is null"); + this.inputColumns = requireNonNull(inputColumns, "inputColumns is null"); this.tableStorageFormat = requireNonNull(tableStorageFormat, "tableStorageFormat is null"); this.partitionStorageFormat = requireNonNull(partitionStorageFormat, "partitionStorageFormat is null"); this.additionalTableParameters = ImmutableMap.copyOf(requireNonNull(additionalTableParameters, "additionalTableParameters is null")); @@ -209,7 +215,6 @@ public HiveWriterFactory( this.parquetTimeZone = requireNonNull(parquetTimeZone, "parquetTimeZone is null"); // divide input columns into partition and data columns - requireNonNull(inputColumns, "inputColumns is null"); ImmutableList.Builder partitionColumnNames = ImmutableList.builder(); ImmutableList.Builder partitionColumnTypes = ImmutableList.builder(); ImmutableList.Builder dataColumns = ImmutableList.builder(); 
@@ -223,6 +228,15 @@ public HiveWriterFactory( dataColumns.add(new DataColumn(column.getName(), hiveType)); } } + if (transaction.isMerge()) { + this.rowType = Optional.of(toHiveType(RowType.from(inputColumns.stream() + .filter(column -> !column.isPartitionKey()) + .map(column -> new RowType.Field(Optional.of(column.getName()), column.getType())) + .collect(toImmutableList())))); + } + else { + this.rowType = Optional.empty(); + } this.partitionColumnNames = partitionColumnNames.build(); this.partitionColumnTypes = partitionColumnTypes.build(); this.dataColumns = dataColumns.build(); @@ -443,10 +457,11 @@ public HiveWriter createWriter(Page partitionColumns, int position, OptionalInt Path path; String fileNameWithExtension; - if (transaction.isAcidTransactionRunning()) { + if (transaction.isAcidTransactionRunning() && transaction.getOperation() != CREATE_TABLE) { String subdir = computeAcidSubdir(transaction); Path subdirPath = new Path(writeInfo.getWritePath(), subdir); - path = createHiveBucketPath(subdirPath, bucketToUse, table.getParameters()); + String nameFormat = table != null && isInsertOnlyTable(table.getParameters()) ? 
"%05d_0" : "bucket_%05d"; + path = new Path(subdirPath, format(nameFormat, bucketToUse)); fileNameWithExtension = path.getName(); } else { @@ -458,24 +473,34 @@ public HiveWriter createWriter(Page partitionColumns, int position, OptionalInt boolean useAcidSchema = isCreateTransactionalTable || (table != null && isFullAcidTable(table.getParameters())); FileWriter hiveFileWriter = null; - for (HiveFileWriterFactory fileWriterFactory : fileWriterFactories) { - Optional fileWriter = fileWriterFactory.createFileWriter( - path, - dataColumns.stream() - .map(DataColumn::getName) - .collect(toList()), - outputStorageFormat, - schema, - conf, - session, - bucketNumber, - transaction, - useAcidSchema, - WriterKind.INSERT); - - if (fileWriter.isPresent()) { - hiveFileWriter = fileWriter.get(); - break; + if (transaction.isMerge()) { + OrcFileWriterFactory orcFileWriterFactory = (OrcFileWriterFactory) fileWriterFactories.stream() + .filter(factory -> factory instanceof OrcFileWriterFactory) + .findFirst() + .get(); + checkArgument(rowType.isPresent(), "rowTypes not present"); + hiveFileWriter = new MergeFileWriter(transaction, 0, bucketNumber, path, partitionName, orcFileWriterFactory, inputColumns, conf, session, typeManager, rowType.get()); + } + if (hiveFileWriter == null) { + for (HiveFileWriterFactory fileWriterFactory : fileWriterFactories) { + Optional fileWriter = fileWriterFactory.createFileWriter( + path, + dataColumns.stream() + .map(DataColumn::getName) + .collect(toList()), + outputStorageFormat, + schema, + conf, + session, + bucketNumber, + transaction, + useAcidSchema, + WriterKind.INSERT); + + if (fileWriter.isPresent()) { + hiveFileWriter = fileWriter.get(); + break; + } } } @@ -590,12 +615,6 @@ public HiveWriter createWriter(Page partitionColumns, int position, OptionalInt hiveWriterStats); } - private static Path createHiveBucketPath(Path subdirPath, int bucketToUse, Map tableParameters) - { - String nameFormat = isInsertOnlyTable(tableParameters) ? 
"%05d_0" : "bucket_%05d"; - return new Path(subdirPath, format(nameFormat, bucketToUse)); - } - private void validateSchema(Optional partitionName, Properties schema) { // existing tables may have columns in a different order @@ -647,9 +666,12 @@ private String computeAcidSubdir(AcidTransaction transaction) long writeId = transaction.getWriteId(); switch (transaction.getOperation()) { case INSERT: + case CREATE_TABLE: return deltaSubdir(writeId, writeId, 0); case DELETE: return deleteDeltaSubdir(writeId, writeId, 0); + case MERGE: + return deltaSubdir(writeId, writeId, 0); default: throw new UnsupportedOperationException("transaction operation is " + transaction.getOperation()); } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/MergeFileWriter.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/MergeFileWriter.java new file mode 100644 index 000000000000..c3118c11918c --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/MergeFileWriter.java @@ -0,0 +1,167 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.hive; + +import io.trino.plugin.hive.acid.AcidOperation; +import io.trino.plugin.hive.acid.AcidTransaction; +import io.trino.plugin.hive.orc.OrcFileWriterFactory; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.RowBlock; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.PagePair; +import io.trino.spi.type.TypeManager; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; + +import java.util.List; +import java.util.Optional; +import java.util.OptionalInt; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.spi.connector.MergeProcessorUtilities.createMergedDeleteAndInsertPages; +import static io.trino.spi.type.BigintType.BIGINT; +import static java.util.Objects.requireNonNull; + +public class MergeFileWriter + extends AbstractHiveAcidWriters + implements FileWriter +{ + private final String partitionName; + private final List inputColumns; + + private int deleteRowCount; + private int insertRowCount; + + public MergeFileWriter( + AcidTransaction transaction, + int statementId, + OptionalInt bucketNumber, + Path bucketPath, + Optional partitionName, + OrcFileWriterFactory orcFileWriterFactory, + List inputColumns, + Configuration configuration, + ConnectorSession session, + TypeManager typeManager, + HiveType hiveRowType) + { + super(transaction, statementId, bucketNumber, bucketPath, false, orcFileWriterFactory, configuration, session, typeManager, hiveRowType, AcidOperation.MERGE); + this.partitionName = requireNonNull(partitionName, "partitionName is null").orElse(""); + this.inputColumns = requireNonNull(inputColumns, "inputColumns is null"); + } + + @Override + public void appendRows(Page page) + { + int positionCount = page.getPositionCount(); + if (positionCount == 0) { + 
return; + } + + PagePair pagePair = createMergedDeleteAndInsertPages(page, inputColumns.size()); + pagePair.getDeletionsPage().ifPresent(deletePage -> { + if (deletePage.getPositionCount() > 0) { + Block acidBlock = deletePage.getBlock(deletePage.getChannelCount() - 1); + Page orcDeletePage = buildDeletePage(acidBlock, transaction.getWriteId()); + lazyInitializeDeleteFileWriter(); + checkArgument(deleteFileWriter.isPresent(), "deleteFileWriter not present"); + deleteFileWriter.get().appendRows(orcDeletePage); + deleteRowCount += deletePage.getPositionCount(); + } + }); + pagePair.getInsertionsPage().ifPresent(insertPage -> { + if (insertPage.getPositionCount() > 0) { + Page orcInsertPage = buildInsertPage(insertPage, transaction.getWriteId()); + lazyInitializeInsertFileWriter(); + checkArgument(insertFileWriter.isPresent(), "insertFileWriter not present"); + insertFileWriter.get().appendRows(orcInsertPage); + insertRowCount += insertPage.getPositionCount(); + } + }); + } + + private Page buildInsertPage(Page insertPage, long writeId) + { + int positionCount = insertPage.getPositionCount(); + List dataColumns = inputColumns.stream() + .filter(column -> !column.isPartitionKey() && !column.isHidden()) + .map(column -> insertPage.getBlock(column.getBaseHiveColumnIndex())) + .collect(toImmutableList()); + Block mergedColumnsBlock = RowBlock.fromFieldBlocks(positionCount, Optional.empty(), dataColumns.toArray(new Block[]{})); + Block currentTransactionBlock = RunLengthEncodedBlock.create(BIGINT, writeId, positionCount); + Block[] blockArray = { + new RunLengthEncodedBlock(INSERT_OPERATION_BLOCK, positionCount), + currentTransactionBlock, + new RunLengthEncodedBlock(bucketValueBlock, positionCount), + createRowIdBlock(positionCount), + currentTransactionBlock, + mergedColumnsBlock + }; + + return new Page(blockArray); + } + + public String getPartitionName() + { + return partitionName; + } + + @Override + public long getWrittenBytes() + { + return 
(deleteFileWriter.map(FileWriter::getWrittenBytes).orElse(0L)) + + (insertFileWriter.map(FileWriter::getWrittenBytes).orElse(0L)); + } + + @Override + public long getSystemMemoryUsage() + { + return (deleteFileWriter.map(FileWriter::getSystemMemoryUsage).orElse(0L)) + + (insertFileWriter.map(FileWriter::getSystemMemoryUsage).orElse(0L)); + } + + @Override + public void commit() + { + deleteFileWriter.ifPresent(FileWriter::commit); + insertFileWriter.ifPresent(FileWriter::commit); + } + + @Override + public void rollback() + { + deleteFileWriter.ifPresent(FileWriter::rollback); + insertFileWriter.ifPresent(FileWriter::rollback); + } + + @Override + public long getValidationCpuNanos() + { + return (deleteFileWriter.map(FileWriter::getValidationCpuNanos).orElse(0L)) + + (insertFileWriter.map(FileWriter::getValidationCpuNanos).orElse(0L)); + } + + public PartitionUpdateAndMergeResults getPartitionUpdateAndMergeResults(PartitionUpdate partitionUpdate) + { + return new PartitionUpdateAndMergeResults( + partitionUpdate, + insertRowCount, + insertFileWriter.isPresent() ? Optional.of(deltaDirectory.toString()) : Optional.empty(), + deleteRowCount, + deleteFileWriter.isPresent() ? 
Optional.of(deleteDeltaDirectory.toString()) : Optional.empty()); + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/PartitionAndStatementId.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/PartitionAndStatementId.java index dcfda26b8a7c..8a9e332ad760 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/PartitionAndStatementId.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/PartitionAndStatementId.java @@ -32,7 +32,7 @@ public class PartitionAndStatementId private final String partitionName; private final int statementId; private final long rowCount; - private final String deleteDeltaDirectory; + private final Optional deleteDeltaDirectory; private final Optional deltaDirectory; @JsonCreator @@ -40,7 +40,7 @@ public PartitionAndStatementId( @JsonProperty("partitionName") String partitionName, @JsonProperty("statementId") int statementId, @JsonProperty("rowCount") long rowCount, - @JsonProperty("deleteDeltaDirectory") String deleteDeltaDirectory, + @JsonProperty("deleteDeltaDirectory") Optional deleteDeltaDirectory, @JsonProperty("deltaDirectory") Optional deltaDirectory) { this.partitionName = requireNonNull(partitionName, "partitionName is null"); @@ -69,7 +69,7 @@ public long getRowCount() } @JsonProperty - public String getDeleteDeltaDirectory() + public Optional getDeleteDeltaDirectory() { return deleteDeltaDirectory; } @@ -83,9 +83,10 @@ public Optional getDeltaDirectory() @JsonIgnore public List getAllDirectories() { - return deltaDirectory - .map(directory -> ImmutableList.of(deleteDeltaDirectory, directory)) - .orElseGet(() -> ImmutableList.of(deleteDeltaDirectory)); + ImmutableList.Builder builder = ImmutableList.builder(); + deltaDirectory.ifPresent(builder::add); + deleteDeltaDirectory.ifPresent(builder::add); + return builder.build(); } @Override diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/PartitionUpdateAndMergeResults.java 
b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/PartitionUpdateAndMergeResults.java new file mode 100644 index 000000000000..f1e91c960829 --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/PartitionUpdateAndMergeResults.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.airlift.json.JsonCodec; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class PartitionUpdateAndMergeResults +{ + public static final JsonCodec CODEC = JsonCodec.jsonCodec(PartitionUpdateAndMergeResults.class); + + private final PartitionUpdate partitionUpdate; + private final long insertRowCount; + private final Optional deltaDirectory; + private final long deleteRowCount; + private final Optional deleteDeltaDirectory; + + @JsonCreator + public PartitionUpdateAndMergeResults( + @JsonProperty("partitionUpdate") PartitionUpdate partitionUpdate, + @JsonProperty("insertRowCount") long insertRowCount, + @JsonProperty("deltaDirectory") Optional deltaDirectory, + @JsonProperty("deleteRowCount") long deleteRowCount, + @JsonProperty("deleteDeltaDirectory") Optional deleteDeltaDirectory) + { + this.partitionUpdate = requireNonNull(partitionUpdate, "partitionUpdate is null"); + this.insertRowCount = insertRowCount; + this.deltaDirectory = 
requireNonNull(deltaDirectory, "deltaDirectory is null"); + this.deleteRowCount = deleteRowCount; + this.deleteDeltaDirectory = requireNonNull(deleteDeltaDirectory, "deleteDeltaDirectory is null"); + } + + @JsonProperty + public PartitionUpdate getPartitionUpdate() + { + return partitionUpdate; + } + + @JsonProperty + public long getInsertRowCount() + { + return insertRowCount; + } + + @JsonProperty + public Optional getDeltaDirectory() + { + return deltaDirectory; + } + + @JsonProperty + public long getDeleteRowCount() + { + return deleteRowCount; + } + + @JsonProperty + public Optional getDeleteDeltaDirectory() + { + return deleteDeltaDirectory; + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/acid/AcidOperation.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/acid/AcidOperation.java index 5e00160681e3..8f8aaacdec11 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/acid/AcidOperation.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/acid/AcidOperation.java @@ -22,16 +22,18 @@ public enum AcidOperation { - // UPDATE and MERGE will be added when they are implemented NONE, CREATE_TABLE, DELETE, INSERT, - UPDATE; + UPDATE, + MERGE, + /**/; private static final Map DATA_OPERATION_TYPES = ImmutableMap.of( DELETE, DataOperationType.DELETE, - INSERT, DataOperationType.INSERT); + INSERT, DataOperationType.INSERT, + MERGE, DataOperationType.UPDATE); private static final Map ORC_OPERATIONS = ImmutableMap.of( DELETE, OrcOperation.DELETE, diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/acid/AcidTransaction.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/acid/AcidTransaction.java index cdcfd17c1bd2..b67f92824a66 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/acid/AcidTransaction.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/acid/AcidTransaction.java @@ -27,6 +27,7 @@ import static io.trino.plugin.hive.acid.AcidOperation.CREATE_TABLE; import static 
io.trino.plugin.hive.acid.AcidOperation.DELETE; import static io.trino.plugin.hive.acid.AcidOperation.INSERT; +import static io.trino.plugin.hive.acid.AcidOperation.MERGE; import static io.trino.plugin.hive.acid.AcidOperation.NONE; import static io.trino.plugin.hive.acid.AcidOperation.UPDATE; import static java.util.Objects.requireNonNull; @@ -50,7 +51,7 @@ public AcidTransaction( this.operation = requireNonNull(operation, "operation is null"); this.transactionId = transactionId; this.writeId = writeId; - this.updateProcessor = updateProcessor; + this.updateProcessor = requireNonNull(updateProcessor, "updateProcessor is null"); } @JsonProperty("operation") @@ -71,7 +72,7 @@ public long getWriteIdForSerialization() return writeId; } - @JsonProperty + @JsonProperty("updateProcessor") public Optional getUpdateProcessor() { return updateProcessor; @@ -80,7 +81,7 @@ public Optional getUpdateProcessor() @JsonIgnore public boolean isAcidTransactionRunning() { - return operation == INSERT || operation == DELETE || operation == UPDATE; + return operation == INSERT || operation == CREATE_TABLE || operation == DELETE || operation == UPDATE || operation == MERGE; } @JsonIgnore @@ -133,6 +134,12 @@ public boolean isUpdate() return operation == UPDATE; } + @JsonIgnore + public boolean isMerge() + { + return operation == MERGE; + } + public boolean isAcidInsertOperation(WriterKind writerKind) { return isInsert() || (isUpdate() && writerKind == WriterKind.INSERT); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/HivePrivilegeInfo.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/HivePrivilegeInfo.java index 8e26dc148584..679522367975 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/HivePrivilegeInfo.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/HivePrivilegeInfo.java @@ -27,6 +27,7 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static 
io.trino.plugin.hive.metastore.HivePrivilegeInfo.HivePrivilege.DELETE; import static io.trino.plugin.hive.metastore.HivePrivilegeInfo.HivePrivilege.INSERT; +import static io.trino.plugin.hive.metastore.HivePrivilegeInfo.HivePrivilege.MERGE; import static io.trino.plugin.hive.metastore.HivePrivilegeInfo.HivePrivilege.SELECT; import static io.trino.plugin.hive.metastore.HivePrivilegeInfo.HivePrivilege.UPDATE; import static java.util.Objects.requireNonNull; @@ -36,7 +37,13 @@ public class HivePrivilegeInfo { public enum HivePrivilege { - SELECT, INSERT, UPDATE, DELETE, OWNERSHIP + SELECT, + INSERT, + UPDATE, + DELETE, + MERGE, + OWNERSHIP, + /**/; } private final HivePrivilege hivePrivilege; @@ -92,8 +99,11 @@ public static HivePrivilege toHivePrivilege(Privilege privilege) return DELETE; case UPDATE: return UPDATE; + case MERGE: + return MERGE; + default: + throw new IllegalArgumentException("Unexpected privilege: " + privilege); } - throw new IllegalArgumentException("Unexpected privilege: " + privilege); } public boolean isContainedIn(HivePrivilegeInfo hivePrivilegeInfo) @@ -114,6 +124,8 @@ public Set toPrivilegeInfo() return ImmutableSet.of(new PrivilegeInfo(Privilege.DELETE, isGrantOption())); case UPDATE: return ImmutableSet.of(new PrivilegeInfo(Privilege.UPDATE, isGrantOption())); + case MERGE: + return ImmutableSet.of(new PrivilegeInfo(Privilege.MERGE, isGrantOption())); case OWNERSHIP: return ImmutableSet.of(); } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/SemiTransactionalHiveMetastore.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/SemiTransactionalHiveMetastore.java index e8fe14dba0f7..b80ca3a7df3d 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/SemiTransactionalHiveMetastore.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/SemiTransactionalHiveMetastore.java @@ -33,6 +33,7 @@ import io.trino.plugin.hive.PartitionAndStatementId; import 
io.trino.plugin.hive.PartitionNotFoundException; import io.trino.plugin.hive.PartitionStatistics; +import io.trino.plugin.hive.PartitionUpdateAndMergeResults; import io.trino.plugin.hive.TableAlreadyExistsException; import io.trino.plugin.hive.acid.AcidOperation; import io.trino.plugin.hive.acid.AcidTransaction; @@ -122,6 +123,10 @@ public class SemiTransactionalHiveMetastore private static final int PARTITION_COMMIT_BATCH_SIZE = 8; private static final Pattern DELTA_DIRECTORY_MATCHER = Pattern.compile("(delete_)?delta_[\\d]+_[\\d]+_[\\d]+$"); + private static final Map ACID_OPERATION_ACTION_TYPES = ImmutableMap.of( + AcidOperation.INSERT, ActionType.INSERT_EXISTING, + AcidOperation.MERGE, ActionType.MERGE); + private final HiveMetastoreClosure delegate; private final HdfsEnvironment hdfsEnvironment; private final Executor renameExecutor; @@ -211,6 +216,7 @@ public synchronized Optional getTable(HiveIdentity identity, String datab case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: return Optional.of(tableAction.getData().getTable()); case DROP: return Optional.empty(); @@ -239,6 +245,7 @@ public synchronized PartitionStatistics getTableStatistics(HiveIdentity identity case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: return tableAction.getData().getStatistics(); case DROP: return PartitionStatistics.empty(); @@ -312,6 +319,7 @@ private TableSource getTableSource(String databaseName, String tableName) case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: return TableSource.PRE_EXISTING_TABLE; case DROP_PRESERVE_DATA: // TODO @@ -459,6 +467,7 @@ public synchronized void createTable( case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: throw new TableAlreadyExistsException(table.getSchemaTableName()); case DROP_PRESERVE_DATA: // TODO @@ -488,6 +497,7 @@ public synchronized void dropTable(ConnectorSession session, String databaseName case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: throw new 
UnsupportedOperationException("dropping a table added/modified in the same transaction is not supported"); case DROP_PRESERVE_DATA: // TODO @@ -536,7 +546,8 @@ public synchronized void dropColumn(HiveIdentity identity, String databaseName, setExclusive((delegate, hdfsEnvironment) -> delegate.dropColumn(identity, databaseName, tableName, columnName)); } - public synchronized void finishInsertIntoExistingTable( + public synchronized void finishChangingExistingTable( + AcidOperation acidOperation, ConnectorSession session, String databaseName, String tableName, @@ -549,6 +560,7 @@ public synchronized void finishInsertIntoExistingTable( setShared(); HiveIdentity identity = new HiveIdentity(session); SchemaTableName schemaTableName = new SchemaTableName(databaseName, tableName); + ActionType actionType = requireNonNull(ACID_OPERATION_ACTION_TYPES.get(acidOperation), "ACID_OPERATION_ACTION_TYPES doesn't contain the acidOperation"); Action oldTableAction = tableActions.get(schemaTableName); if (oldTableAction == null) { Table table = getExistingTable(identity, schemaTableName.getSchemaName(), schemaTableName.getTableName()); @@ -560,7 +572,7 @@ public synchronized void finishInsertIntoExistingTable( tableActions.put( schemaTableName, new Action<>( - ActionType.INSERT_EXISTING, + actionType, new TableAndMore( table, identity, @@ -584,6 +596,7 @@ public synchronized void finishInsertIntoExistingTable( case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: throw new UnsupportedOperationException("Inserting into an unpartitioned table that were added, altered, or inserted into in the same transaction is not supported"); case DROP_PRESERVE_DATA: // TODO @@ -667,6 +680,7 @@ public synchronized void finishRowLevelDelete( case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: throw new UnsupportedOperationException("Inserting or deleting in an unpartitioned table that were added, altered, or inserted into in the same transaction is not supported"); case 
DROP_PRESERVE_DATA: // TODO @@ -717,6 +731,63 @@ public synchronized void finishUpdate( case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: + throw new UnsupportedOperationException("Inserting, updating or deleting in a table that was added, altered, inserted into, updated or deleted from in the same transaction is not supported"); + default: + throw new IllegalStateException("Unknown action type"); + } + } + + public synchronized void finishMerge( + ConnectorSession session, + String databaseName, + String tableName, + Path currentLocation, + List partitionUpdateAndMergeResults, + List partitions) + { + if (partitionUpdateAndMergeResults.isEmpty()) { + return; + } + checkArgument(partitionUpdateAndMergeResults.size() >= partitions.size(), "partitionUpdateAndMergeResults.size() (%s) < partitions.size() (%s)", partitionUpdateAndMergeResults.size(), partitions.size()); + setShared(); + if (partitions.isEmpty()) { + return; + } + HiveIdentity identity = new HiveIdentity(session); + + SchemaTableName schemaTableName = new SchemaTableName(databaseName, tableName); + Action oldTableAction = tableActions.get(schemaTableName); + if (oldTableAction == null) { + Table table = getExistingTable(identity, schemaTableName.getSchemaName(), schemaTableName.getTableName()); + HdfsContext hdfsContext = new HdfsContext(session); + PrincipalPrivileges principalPrivileges = buildInitialPrivilegeSet(table.getOwner()); + tableActions.put( + schemaTableName, + new Action<>( + ActionType.MERGE, + new TableAndMergeResults( + table, + identity, + Optional.of(principalPrivileges), + Optional.of(currentLocation), + partitionUpdateAndMergeResults, + partitions), + hdfsContext, + identity, + session.getQueryId())); + return; + } + + switch (oldTableAction.getType()) { + case DROP: + throw new TableNotFoundException(schemaTableName); + case ADD: + case ALTER: + case INSERT_EXISTING: + case DELETE_ROWS: + case UPDATE: + case MERGE: throw new UnsupportedOperationException("Inserting, 
updating or deleting in a table that was added, altered, inserted into, updated or deleted from in the same transaction is not supported"); case DROP_PRESERVE_DATA: // TODO @@ -807,6 +878,7 @@ private Optional> doGetPartitionNames( case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: resultBuilder.add(partitionName); break; default: @@ -873,6 +945,7 @@ private static Optional getPartitionFromPartitionAction(Action listTablePrivileges(HiveIdentity iden case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: return delegate.listTablePrivileges(databaseName, tableName, getTableOwner(identity, databaseName, tableName), principal); case DROP: throw new TableNotFoundException(schemaTableName); @@ -1226,6 +1303,11 @@ public AcidTransaction beginUpdate(ConnectorSession session, Table table, HiveUp return beginOperation(session, table, AcidOperation.UPDATE, DataOperationType.UPDATE, Optional.of(updateProcessor)); } + public AcidTransaction beginMerge(ConnectorSession session, Table table) + { + return beginOperation(session, table, AcidOperation.MERGE, DataOperationType.UPDATE, Optional.empty()); + } + private AcidTransaction beginOperation(ConnectorSession session, Table table, AcidOperation operation, DataOperationType hiveOperation, Optional updateProcessor) { String queryId = session.getQueryId(); @@ -1361,6 +1443,9 @@ private void commitShared() case UPDATE: committer.prepareUpdateExistingTable(action.getHdfsContext(), action.getData()); break; + case MERGE: + committer.prepareMergeExistingTable(action.getHdfsContext(), action.getData()); + break; default: throw new IllegalStateException("Unknown action type"); } @@ -1386,6 +1471,9 @@ private void commitShared() case INSERT_EXISTING: committer.prepareInsertExistingPartition(action.getHdfsContext(), action.getIdentity(), action.getData()); break; + case MERGE: + committer.prepareInsertExistingPartition(action.getHdfsContext(), action.getIdentity(), action.getData()); + break; case UPDATE: case 
DELETE_ROWS: break; @@ -1647,9 +1735,12 @@ private void prepareDeleteRowsFromExistingTable(HdfsContext context, TableAndMor AcidTransaction transaction = currentHiveTransaction.get().getTransaction(); checkArgument(transaction.isDelete(), "transaction should be delete, but is %s", transaction); - cleanUpTasksForAbort.addAll(deletionState.getPartitionAndStatementIds().stream() - .map(ps -> new DirectoryCleanUpTask(context, new Path(ps.getDeleteDeltaDirectory()), true)) - .collect(toImmutableList())); + deletionState.getPartitionAndStatementIds().stream().forEach(ps -> { + ps.getDeleteDeltaDirectory().ifPresent(dir -> + cleanUpTasksForAbort.add(new DirectoryCleanUpTask(context, new Path(dir), true))); + ps.getDeltaDirectory().ifPresent(dir -> + cleanUpTasksForAbort.add(new DirectoryCleanUpTask(context, new Path(dir), true))); + }); Map partitionRowCounts = new HashMap<>(partitionAndStatementIds.size()); int totalRowsDeleted = 0; @@ -1747,6 +1838,30 @@ private void prepareUpdateExistingTable(HdfsContext context, TableAndMore tableA updateTableWriteId(identity, databaseName, tableName, transactionId, writeId, OptionalLong.empty()); } + private void prepareMergeExistingTable(HdfsContext context, TableAndMore tableAndMore) + { + checkArgument(currentHiveTransaction.isPresent(), "currentHiveTransaction isn't present"); + AcidTransaction transaction = currentHiveTransaction.get().getTransaction(); + checkArgument(transaction.isMerge(), "transaction should be merge, but is %s", transaction); + + deleteOnly = false; + Table table = tableAndMore.getTable(); + Path targetPath = new Path(table.getStorage().getLocation()); + Path currentPath = tableAndMore.getCurrentLocation().get(); + cleanUpTasksForAbort.add(new DirectoryCleanUpTask(context, targetPath, false)); + if (!targetPath.equals(currentPath)) { + asyncRename(hdfsEnvironment, renameExecutor, fileRenameCancelled, fileRenameFutures, context, currentPath, targetPath, tableAndMore.getFileNames().get()); + } + 
updateStatisticsOperations.add(new UpdateStatisticsOperation( + tableAndMore.getIdentity(), + table.getSchemaTableName(), + Optional.empty(), + tableAndMore.getStatisticsUpdate(), + true)); + + updateTableWriteId(tableAndMore.getIdentity(), table.getDatabaseName(), table.getTableName(), transaction.getAcidTransactionId(), transaction.getWriteId(), OptionalLong.empty()); + } + private void prepareDropPartition(HiveIdentity identity, SchemaTableName schemaTableName, List partitionValues, boolean deleteData) { metastoreDeleteOperations.add(new IrreversibleMetastoreOperation( @@ -2585,6 +2700,7 @@ private enum ActionType INSERT_EXISTING, DELETE_ROWS, UPDATE, + MERGE, } private enum TableSource @@ -2772,6 +2888,42 @@ public String toString() } } + private static class TableAndMergeResults + extends TableAndMore + { + private final List partitionMergeResults; + private final List partitions; + + public TableAndMergeResults(Table table, HiveIdentity identity, Optional principalPrivileges, Optional currentLocation, List partitionMergeResults, List partitions) + { + super(table, identity, principalPrivileges, currentLocation, Optional.empty(), false, PartitionStatistics.empty(), PartitionStatistics.empty()); + this.partitionMergeResults = requireNonNull(partitionMergeResults, "partitionMergeResults is null"); + this.partitions = requireNonNull(partitions, "partitions is nul"); + } + + public List getPartitionMergeResults() + { + return partitionMergeResults; + } + + public List getPartitions() + { + return partitions; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("table", getTable()) + .add("partitionMergeResults", partitionMergeResults) + .add("partitions", partitions) + .add("principalPrivileges", getPrincipalPrivileges()) + .add("currentLocation", getCurrentLocation()) + .toString(); + } + } + private static class PartitionAndMore { private final HiveIdentity identity; diff --git 
a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftMetastoreUtil.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftMetastoreUtil.java index 7dde95c74e2c..6c6e2d95b5f8 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftMetastoreUtil.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/thrift/ThriftMetastoreUtil.java @@ -107,6 +107,7 @@ import static io.trino.plugin.hive.metastore.HiveColumnStatistics.createStringColumnStatistics; import static io.trino.plugin.hive.metastore.HivePrivilegeInfo.HivePrivilege.DELETE; import static io.trino.plugin.hive.metastore.HivePrivilegeInfo.HivePrivilege.INSERT; +import static io.trino.plugin.hive.metastore.HivePrivilegeInfo.HivePrivilege.MERGE; import static io.trino.plugin.hive.metastore.HivePrivilegeInfo.HivePrivilege.OWNERSHIP; import static io.trino.plugin.hive.metastore.HivePrivilegeInfo.HivePrivilege.SELECT; import static io.trino.plugin.hive.metastore.HivePrivilegeInfo.HivePrivilege.UPDATE; @@ -745,6 +746,8 @@ public static Set parsePrivilege(PrivilegeGrantInfo userGrant return ImmutableSet.of(new HivePrivilegeInfo(UPDATE, grantOption, grantor, grantee.orElse(grantor))); case "DELETE": return ImmutableSet.of(new HivePrivilegeInfo(DELETE, grantOption, grantor, grantee.orElse(grantor))); + case "MERGE": + return ImmutableSet.of(new HivePrivilegeInfo(MERGE, grantOption, grantor, grantee.orElse(grantor))); case "OWNERSHIP": return ImmutableSet.of(new HivePrivilegeInfo(OWNERSHIP, grantOption, grantor, grantee.orElse(grantor))); default: diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcFileWriter.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcFileWriter.java index ef95c83f7310..d5aec7b5e0fc 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcFileWriter.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcFileWriter.java @@ -15,6 
+15,7 @@ import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; +import io.airlift.log.Logger; import io.trino.orc.OrcDataSink; import io.trino.orc.OrcDataSource; import io.trino.orc.OrcWriteValidation.OrcWriteValidationMode; @@ -63,6 +64,7 @@ public class OrcFileWriter implements FileWriter { + private static final Logger log = Logger.get(OrcFileWriter.class); private static final int INSTANCE_SIZE = ClassLayout.parseClass(OrcFileWriter.class).instanceSize(); private static final ThreadMXBean THREAD_MX_BEAN = ManagementFactory.getThreadMXBean(); @@ -128,6 +130,9 @@ public OrcFileWriter( validationInputFactory.isPresent(), validationMode, stats); + if (transaction.isTransactional()) { + this.setMaxWriteId(transaction.getWriteId()); + } } @Override @@ -189,6 +194,7 @@ public void commit() } catch (Exception ignored) { // ignore + log.error(ignored, "Exception when committing file"); } throw new TrinoException(HIVE_WRITER_CLOSE_ERROR, "Error committing write to Hive", e); } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcFileWriterFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcFileWriterFactory.java index 79ec9163ed8f..43e671be6216 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcFileWriterFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcFileWriterFactory.java @@ -26,6 +26,7 @@ import io.trino.plugin.hive.FileWriter; import io.trino.plugin.hive.HdfsEnvironment; import io.trino.plugin.hive.HiveFileWriterFactory; +import io.trino.plugin.hive.HiveType; import io.trino.plugin.hive.NodeVersion; import io.trino.plugin.hive.WriterKind; import io.trino.plugin.hive.acid.AcidTransaction; @@ -66,6 +67,7 @@ import static io.trino.plugin.hive.HiveSessionProperties.getOrcStringStatisticsLimit; import static io.trino.plugin.hive.HiveSessionProperties.getTimestampPrecision; import static 
io.trino.plugin.hive.HiveSessionProperties.isOrcOptimizedWriterValidate; +import static io.trino.plugin.hive.HiveType.toHiveType; import static io.trino.plugin.hive.acid.AcidSchema.ACID_COLUMN_NAMES; import static io.trino.plugin.hive.acid.AcidSchema.createAcidColumnPrestoTypes; import static io.trino.plugin.hive.acid.AcidSchema.createRowType; @@ -230,6 +232,17 @@ public Optional createFileWriter( } } + public static HiveType createHiveRowType(Properties schema, TypeManager typeManager, ConnectorSession session) + { + List dataColumnNames = getColumnNames(schema); + List dataColumnTypes = getColumnTypes(schema).stream() + .map(hiveType -> hiveType.getType(typeManager, getTimestampPrecision(session))) + .collect(toList()); + Type dataRowType = createRowType(dataColumnNames, dataColumnTypes); + Type acidRowType = createRowType(ACID_COLUMN_NAMES, createAcidColumnPrestoTypes(dataRowType)); + return toHiveType(acidRowType); + } + public static OrcDataSink createOrcDataSink(FileSystem fileSystem, Path path) throws IOException { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSource.java index 9e2c11480e5f..878b76178ded 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSource.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSource.java @@ -233,6 +233,11 @@ static ColumnAdaptation updatedRowColumns(HiveUpdateProcessor updateProcessor, L { return new UpdatedRowAdaptation(updateProcessor, dependencyColumns); } + + static ColumnAdaptation mergedRowColumns() + { + return new MergedRowAdaptation(); + } } private static class NullColumn @@ -359,6 +364,33 @@ public Block block(Page sourcePage, MaskDeletedRowsFunction maskDeletedRowsFunct } } + /** + * This ColumnAdaptation creates a RowBlock column containing the three + * ACID columns - - originalTransaction, rowId, bucket - - and + * then all the partition columns + */ + 
private static final class MergedRowAdaptation + implements ColumnAdaptation + { + @Override + public Block block(Page page, MaskDeletedRowsFunction maskDeletedRowsFunction, long filePosition) + { + requireNonNull(page, "page is null"); + int acidBlocks = 3; + + Block[] blocks = new Block[acidBlocks]; + blocks[ORIGINAL_TRANSACTION_CHANNEL] = page.getBlock(ORIGINAL_TRANSACTION_CHANNEL); + blocks[ROW_ID_CHANNEL] = page.getBlock(ROW_ID_CHANNEL); + blocks[BUCKET_CHANNEL] = page.getBlock(BUCKET_CHANNEL); + + Block block = maskDeletedRowsFunction.apply(fromFieldBlocks( + page.getPositionCount(), + Optional.empty(), + blocks)); + return block; + } + } + private static class OriginalFileRowIdAdaptation implements ColumnAdaptation { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSourceFactory.java index e69ac8afb82e..6a17f5e2aac6 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSourceFactory.java @@ -96,6 +96,7 @@ import static io.trino.plugin.hive.HiveSessionProperties.isOrcNestedLazy; import static io.trino.plugin.hive.HiveSessionProperties.isUseOrcColumnNames; import static io.trino.plugin.hive.ReaderPageSource.noProjectionAdaptation; +import static io.trino.plugin.hive.orc.OrcPageSource.ColumnAdaptation.mergedRowColumns; import static io.trino.plugin.hive.orc.OrcPageSource.ColumnAdaptation.updatedRowColumns; import static io.trino.plugin.hive.orc.OrcPageSource.handleException; import static io.trino.plugin.hive.util.HiveUtil.isDeserializerClass; @@ -412,6 +413,9 @@ else if (transaction.isUpdate()) { .collect(toImmutableList()); columnAdaptations.add(updatedRowColumns(updateProcessor, dependencyColumns)); } + else if (transaction.isMerge()) { + columnAdaptations.add(mergedRowColumns()); + } return new OrcPageSource( recordReader, diff --git 
a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/TestSemiTransactionalHiveMetastore.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/TestSemiTransactionalHiveMetastore.java index cb9bff7fcf3f..86b55ef74d45 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/TestSemiTransactionalHiveMetastore.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/TestSemiTransactionalHiveMetastore.java @@ -19,6 +19,7 @@ import io.trino.plugin.hive.HiveMetastoreClosure; import io.trino.plugin.hive.HiveType; import io.trino.plugin.hive.PartitionStatistics; +import io.trino.plugin.hive.acid.AcidOperation; import io.trino.plugin.hive.acid.AcidTransaction; import io.trino.plugin.hive.authentication.HiveIdentity; import org.apache.hadoop.fs.Path; @@ -103,7 +104,7 @@ public void testParallelUpdateStatisticsOperations() else { semiTransactionalHiveMetastore = getSemiTransactionalHiveMetastoreWithUpdateExecutor(newFixedThreadPool(updateThreads)); } - IntStream.range(0, tablesToUpdate).forEach(i -> semiTransactionalHiveMetastore.finishInsertIntoExistingTable(SESSION, + IntStream.range(0, tablesToUpdate).forEach(i -> semiTransactionalHiveMetastore.finishChangingExistingTable(AcidOperation.INSERT, SESSION, "database", "table_" + i, new Path("location"), diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduHandleResolver.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduHandleResolver.java index d0a040a4fa48..f2ce5d0c0307 100755 --- a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduHandleResolver.java +++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduHandleResolver.java @@ -16,6 +16,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorHandleResolver; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorOutputTableHandle; import 
io.trino.spi.connector.ConnectorPartitioningHandle; import io.trino.spi.connector.ConnectorSplit; @@ -55,6 +56,12 @@ public Class getInsertTableHandleClass() return KuduInsertTableHandle.class; } + @Override + public Class getMergeTableHandleClass() + { + return KuduMergeTableHandle.class; + } + @Override public Class getOutputTableHandleClass() { diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduMergeTableHandle.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduMergeTableHandle.java new file mode 100644 index 000000000000..d4798e8fa5bc --- /dev/null +++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduMergeTableHandle.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.kudu; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.connector.ConnectorMergeTableHandle; +import io.trino.spi.type.Type; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class KuduMergeTableHandle + implements ConnectorMergeTableHandle, KuduTableMapping +{ + private final KuduTableHandle tableHandle; + private final KuduOutputTableHandle outputTableHandle; + + @JsonCreator + public KuduMergeTableHandle( + @JsonProperty("tableHandle") KuduTableHandle tableHandle, + @JsonProperty("outputTableHandle") KuduOutputTableHandle outputTableHandle) + { + this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); + this.outputTableHandle = requireNonNull(outputTableHandle, "outputTableHandle is null"); + } + + @JsonProperty + @Override + public KuduTableHandle getTableHandle() + { + return tableHandle; + } + + @JsonProperty + public KuduOutputTableHandle getOutputTableHandle() + { + return outputTableHandle; + } + + @Override + public boolean isGenerateUUID() + { + return outputTableHandle.isGenerateUUID(); + } + + @Override + public List getColumnTypes() + { + return outputTableHandle.getColumnTypes(); + } + + @Override + public List getOriginalColumnTypes() + { + return outputTableHandle.getOriginalColumnTypes(); + } +} diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduMetadata.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduMetadata.java index 97969f1dd64b..87f0b5b2c6a9 100755 --- a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduMetadata.java +++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduMetadata.java @@ -25,6 +25,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeTableHandle; import 
io.trino.spi.connector.ConnectorMetadata; import io.trino.spi.connector.ConnectorNewTableLayout; import io.trino.spi.connector.ConnectorOutputMetadata; @@ -38,8 +39,10 @@ import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.LimitApplicationResult; import io.trino.spi.connector.LocalProperty; +import io.trino.spi.connector.MergeDetails; import io.trino.spi.connector.NotFoundException; import io.trino.spi.connector.ProjectionApplicationResult; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.expression.ConnectorExpression; @@ -67,11 +70,13 @@ import java.util.OptionalInt; import java.util.OptionalLong; import java.util.Set; +import java.util.function.Consumer; import java.util.stream.Collectors; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.plugin.kudu.KuduSessionProperties.isKuduGroupedExecutionEnabled; import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static io.trino.spi.connector.RowChangeParadigm.CHANGE_ONLY_UPDATED_COLUMNS; import static java.util.Objects.requireNonNull; public class KuduMetadata @@ -174,18 +179,21 @@ private ConnectorTableMetadata getTableMetadata(KuduTableHandle tableHandle) public Map getColumnHandles(ConnectorSession session, ConnectorTableHandle connectorTableHandle) { KuduTableHandle tableHandle = (KuduTableHandle) connectorTableHandle; + ImmutableMap.Builder columnHandles = ImmutableMap.builder(); Schema schema = clientSession.getTableSchema(tableHandle); + forAllColumnHandles(schema, column -> columnHandles.put(column.getName(), column)); + return columnHandles.build(); + } - ImmutableMap.Builder columnHandles = ImmutableMap.builder(); + private void forAllColumnHandles(Schema schema, Consumer handleEater) + { for (int ordinal = 0; ordinal < schema.getColumnCount(); ordinal++) { ColumnSchema col = 
schema.getColumnByIndex(ordinal); String name = col.getName(); Type type = TypeHelper.fromKuduColumn(col); KuduColumnHandle columnHandle = new KuduColumnHandle(name, ordinal, type); - columnHandles.put(name, columnHandle); + handleEater.accept(columnHandle); } - - return columnHandles.build(); } @Override @@ -432,6 +440,64 @@ public void finishDelete(ConnectorSession session, ConnectorTableHandle tableHan { } + @Override + public RowChangeParadigm getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return CHANGE_ONLY_UPDATED_COLUMNS; + } + + @Override + public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle, MergeDetails mergeDetails) + { + return KuduColumnHandle.ROW_ID_HANDLE; + } + + @Override + public List getWriteRedistributionColumns(ConnectorSession session, ConnectorTableHandle tableHandle) + { + KuduTableHandle kuduTableHandle = (KuduTableHandle) tableHandle; + ImmutableList.Builder handlesListBuilder = ImmutableList.builder(); + Schema schema = clientSession.getTableSchema(kuduTableHandle); + forAllColumnHandles(schema, handlesListBuilder::add); + List handlesList = handlesListBuilder.build(); + + KuduTable table = kuduTableHandle.getTable(clientSession); + ImmutableSet.Builder redistributionColumnIdsBuilder = ImmutableSet.builder(); + redistributionColumnIdsBuilder.addAll(table.getPartitionSchema().getRangeSchema().getColumnIds()); + table.getPartitionSchema().getHashBucketSchemas().forEach(partitionSchema -> + redistributionColumnIdsBuilder.addAll(partitionSchema.getColumnIds())); + Set redistributionColumnIds = redistributionColumnIdsBuilder.build(); + + return handlesList.stream() + .filter(column -> redistributionColumnIds.contains(column.getOrdinalPosition())) + .collect(toImmutableList()); + } + + @Override + public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, MergeDetails mergeDetails) + { + KuduTableHandle 
kuduTableHandle = (KuduTableHandle) tableHandle; + KuduTable table = kuduTableHandle.getTable(clientSession); + Schema schema = table.getSchema(); + List columns = schema.getColumns(); + List columnTypes = columns.stream() + .map(TypeHelper::fromKuduColumn).collect(toImmutableList()); + ConnectorTableMetadata tableMetadata = getTableMetadata(kuduTableHandle); + List columnOriginalTypes = tableMetadata.getColumns().stream() + .map(ColumnMetadata::getType).collect(toImmutableList()); + PartitionDesign design = KuduTableProperties.getPartitionDesign(tableMetadata.getProperties()); + boolean generateUUID = !design.hasPartitions(); + return new KuduMergeTableHandle( + kuduTableHandle, + new KuduOutputTableHandle(tableMetadata.getTable(), columnOriginalTypes, columnTypes, generateUUID, table)); + } + + @Override + public void finishMerge(ConnectorSession session, ConnectorMergeTableHandle tableHandle, Collection fragments, Collection computedStatistics) + { + // For Kudu, nothing needs to be done to finish the merge. 
+ } + @Override public boolean usesLegacyTableLayouts() { diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSink.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSink.java index 47c4f39839a8..7e74e0899c50 100644 --- a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSink.java +++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSink.java @@ -19,6 +19,7 @@ import io.airlift.slice.Slice; import io.trino.spi.Page; import io.trino.spi.block.Block; +import io.trino.spi.connector.ConnectorMergeSink; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.DecimalType; @@ -26,6 +27,10 @@ import io.trino.spi.type.SqlDecimal; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; +import org.apache.kudu.Schema; +import org.apache.kudu.client.Delete; +import org.apache.kudu.client.Insert; +import org.apache.kudu.client.KeyEncoderAccessor; import org.apache.kudu.client.KuduException; import org.apache.kudu.client.KuduSession; import org.apache.kudu.client.KuduTable; @@ -43,6 +48,11 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; +import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.spi.connector.MergeDetails.DEFAULT_CASE_OPERATION_NUMBER; +import static io.trino.spi.connector.MergeDetails.DELETE_OPERATION_NUMBER; +import static io.trino.spi.connector.MergeDetails.INSERT_OPERATION_NUMBER; +import static io.trino.spi.connector.MergeDetails.UPDATE_OPERATION_NUMBER; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DateType.DATE; @@ -61,7 +71,7 @@ import static java.util.concurrent.CompletableFuture.completedFuture; public class KuduPageSink - implements ConnectorPageSink + implements ConnectorPageSink, ConnectorMergeSink { private final ConnectorSession connectorSession; 
private final KuduSession session; @@ -89,6 +99,14 @@ public KuduPageSink( this(connectorSession, clientSession, tableHandle.getTable(clientSession), tableHandle); } + public KuduPageSink( + ConnectorSession connectorSession, + KuduClientSession clientSession, + KuduMergeTableHandle tableHandle) + { + this(connectorSession, clientSession, tableHandle.getOutputTableHandle().getTable(clientSession), tableHandle); + } + private KuduPageSink( ConnectorSession connectorSession, KuduClientSession clientSession, @@ -189,6 +207,57 @@ else if (type instanceof DecimalType) { } } + @Override + public void storeMergedRows(Page page) + { + // The last channel in the page is the rowId block, the next-to-last is the operation block + int columnCount = columnTypes.size(); + checkArgument(page.getChannelCount() == 2 + columnCount, "The page size should be 2 + columnCount (%s), but is %s", columnCount, page.getChannelCount()); + Block operationBlock = page.getBlock(columnCount); + Block rowIds = page.getBlock(columnCount + 1); + + Schema schema = table.getSchema(); + for (int position = 0; position < page.getPositionCount(); position++) { + int operation = operationBlock.getInt(position, 0); + + if (operation == DEFAULT_CASE_OPERATION_NUMBER) { + continue; + } + + if (operation == DELETE_OPERATION_NUMBER || operation == UPDATE_OPERATION_NUMBER) { + Delete delete = table.newDelete(); + Slice deleteRowId = rowIds.getSlice(position, 0, rowIds.getSliceLength(position)); + RowHelper.copyPrimaryKey(schema, KeyEncoderAccessor.decodePrimaryKey(schema, deleteRowId.getBytes()), delete.getRow()); + try { + session.apply(delete); + } + catch (KuduException e) { + throw new RuntimeException(e); + } + } + + if (operation == INSERT_OPERATION_NUMBER || operation == UPDATE_OPERATION_NUMBER) { + Insert insert = table.newInsert(); + PartialRow insertRow = insert.getRow(); + int insertStart = 0; + if (generateUUID) { + String id = format("%s-%08x", uuid, nextSubId++); + insertRow.addString(0, id); + 
insertStart = 1; + } + for (int channel = 0; channel < columnCount; channel++) { + appendColumn(insertRow, page, position, channel, channel + insertStart); + } + try { + session.apply(insert); + } + catch (KuduException e) { + throw new RuntimeException(e); + } + } + } + } + @Override public CompletableFuture> finish() { diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSinkProvider.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSinkProvider.java index eff964c67a78..fd97c9ba1754 100644 --- a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSinkProvider.java +++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSinkProvider.java @@ -14,6 +14,8 @@ package io.trino.plugin.kudu; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeSink; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorOutputTableHandle; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorPageSinkProvider; @@ -55,4 +57,14 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa return new KuduPageSink(session, clientSession, handle); } + + @Override + public ConnectorMergeSink createMergeSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorMergeTableHandle mergeHandle) + { + requireNonNull(mergeHandle, "mergeHandle is null"); + checkArgument(mergeHandle instanceof KuduMergeTableHandle, "mergeHandle is not an instance of KuduMergeTableHandle"); + KuduMergeTableHandle handle = (KuduMergeTableHandle) mergeHandle; + + return new KuduPageSink(session, clientSession, handle); + } } diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/KuduTableProperties.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/KuduTableProperties.java index 5bcb783bd286..ea09ee60688d 100644 --- 
a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/KuduTableProperties.java +++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/KuduTableProperties.java @@ -173,7 +173,7 @@ public static PartitionDesign getPartitionDesign(Map tableProper @SuppressWarnings("unchecked") List hashColumns = (List) tableProperties.get(PARTITION_BY_HASH_COLUMNS); @SuppressWarnings("unchecked") - List hashColumns2 = (List) tableProperties.get(PARTITION_BY_HASH_COLUMNS_2); + List hashColumns2 = (List) tableProperties.getOrDefault(PARTITION_BY_HASH_COLUMNS_2, ImmutableList.of()); PartitionDesign design = new PartitionDesign(); if (!hashColumns.isEmpty()) { diff --git a/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/AbstractKuduIntegrationSmokeTest.java b/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/AbstractKuduIntegrationSmokeTest.java index 11328f8b763a..b69e919529ea 100644 --- a/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/AbstractKuduIntegrationSmokeTest.java +++ b/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/AbstractKuduIntegrationSmokeTest.java @@ -22,6 +22,8 @@ import java.util.Optional; import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import static io.trino.plugin.kudu.KuduQueryRunnerFactory.createKuduQueryRunnerTpch; import static io.trino.spi.type.VarcharType.VARCHAR; @@ -31,6 +33,7 @@ import static io.trino.tpch.TpchTable.NATION; import static io.trino.tpch.TpchTable.ORDERS; import static io.trino.tpch.TpchTable.REGION; +import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; import static org.testng.Assert.assertTrue; @@ -175,4 +178,245 @@ private void assertTableProperty(String tableProperties, String key, String rege assertTrue(Pattern.compile(key + "\\s*=\\s*" + regexValue + ",?\\s+").matcher(tableProperties).find(), "Not found: " + key + " = " + regexValue + " in " + tableProperties); } + + @Test + public void 
testMergeSimpleSelect() + { + String targetTable = "simple_select_target"; + String sourceTable = "simple_select_source"; + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR WITH (primary_key=true), purchases INT, address VARCHAR)" + + "WITH (" + + " partition_by_hash_columns = ARRAY['customer'], " + + " partition_by_hash_buckets = 2" + + ")", targetTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + + assertUpdate(format("CREATE TABLE %s (customer VARCHAR WITH (primary_key=true), purchases INT, address VARCHAR)" + + "WITH (" + + " partition_by_hash_columns = ARRAY['customer'], " + + " partition_by_hash_buckets = 2" + + ")", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable), 4); + + assertUpdate(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)", + 4); + + assertQuery("SELECT * FROM " + targetTable, "SELECT * FROM (VALUES('Aaron', 11, 'Arches'), ('Ed', 7, 'Etherville'), ('Bill', 7, 'Buena'), ('Dave', 22, 'Darbyshire'))"); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + targetTable); + assertUpdate("DROP TABLE IF EXISTS " + sourceTable); + } + } + + @Test + public void testMergeMultipleOperations() + { + String targetTable = "merge_multiple"; + try { + int targetCustomerCount = 10; + String originalInsertFirstHalf = IntStream.range(0, targetCustomerCount / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe 
Ct')", intValue, 1000, 91000, intValue, intValue)) + .collect(Collectors.joining(", ")); + String originalInsertSecondHalf = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 2000, 92000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + assertUpdate(format("CREATE TABLE %s (customer VARCHAR WITH (primary_key=true), purchases INT, zipcode INT, spouse VARCHAR, address VARCHAR) " + + "WITH (" + + " partition_by_hash_columns = ARRAY['customer'], " + + " partition_by_hash_buckets = 2" + + ")", targetTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, zipcode, spouse, address) VALUES %s, %s", targetTable, originalInsertFirstHalf, originalInsertSecondHalf), targetCustomerCount); + + String firstMergeSource = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 83000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + assertUpdate(format("MERGE INTO %s t USING (SELECT * FROM (VALUES %s)) AS s(customer, purchases, zipcode, spouse, address)", targetTable, firstMergeSource) + + " ON t.customer = s.customer" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases, zipcode = s.zipcode, spouse = s.spouse, address = s.address", + targetCustomerCount / 2); + + assertQuery( + "SELECT customer, purchases, zipcode, spouse, address FROM " + targetTable, + format("SELECT * FROM (VALUES %s, %s) AS v(customer, purchases, zipcode, spouse, address)", originalInsertFirstHalf, firstMergeSource)); + + String nextInsert = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('jack_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 4000, 74000, intValue, intValue)) + .collect(Collectors.joining(", ")); + assertUpdate(format("INSERT INTO %s (customer, purchases, zipcode, spouse, address) VALUES %s", 
targetTable, nextInsert), targetCustomerCount / 2); + + String secondMergeSource = IntStream.range(0, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + assertUpdate(format("MERGE INTO %s t USING (SELECT * FROM (VALUES %s)) AS s(customer, purchases, zipcode, spouse, address)", targetTable, secondMergeSource) + + " ON t.customer = s.customer" + + " WHEN MATCHED AND t.zipcode = 91000 THEN DELETE" + + " WHEN MATCHED AND s.zipcode = 85000 THEN UPDATE SET zipcode = 60000" + + " WHEN MATCHED THEN UPDATE SET zipcode = s.zipcode, spouse = s.spouse, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, zipcode, spouse, address) VALUES(s.customer, s.purchases, s.zipcode, s.spouse, s.address)", + targetCustomerCount * 3 / 2); + + String updatedBeginning = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 60000, intValue, intValue)) + .collect(Collectors.joining(", ")); + String updatedMiddle = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue)) + .collect(Collectors.joining(", ")); + String updatedEnd = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('jack_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 4000, 74000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + assertQuery( + "SELECT customer, purchases, zipcode, spouse, address FROM " + targetTable, + format("SELECT * FROM (VALUES %s, %s, %s) AS v(customer, purchases, zipcode, spouse, address)", updatedBeginning, updatedMiddle, updatedEnd)); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + targetTable); + } + } + + @Test + public void testMergeInsertAll() 
+ { + String targetTable = "merge_update_all"; + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR WITH (primary_key=true), purchases INT, address VARCHAR)" + + "WITH (" + + " partition_by_hash_columns = ARRAY['customer'], " + + " partition_by_hash_buckets = 3" + + ")", targetTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 11, 'Antioch'), ('Bill', 7, 'Buena')", targetTable), 2); + + assertUpdate(format("MERGE INTO %s t USING ", targetTable) + + "(SELECT * FROM (VALUES ('Carol', 9, 'Centreville'), ('Dave', 22, 'Darbyshire'))) AS s(customer, purchases, address)" + + "ON (t.customer = s.customer)" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)", + 2); + + assertQuery( + "SELECT * FROM " + targetTable, + "SELECT * FROM (VALUES('Aaron', 11, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 9, 'Centreville'), ('Dave', 22, 'Darbyshire'))"); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + targetTable); + } + } + + @Test + public void testMergeAllColumnsUpdated() + { + String targetTable = "merge_all_columns_updated_target"; + String sourceTable = "merge_all_columns_updated_source"; + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR WITH (primary_key=true), purchases INT, address VARCHAR) " + + "WITH (" + + " partition_by_hash_columns = ARRAY['customer'], " + + " partition_by_hash_buckets = 2" + + ")", targetTable)); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Devon'), ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge')", targetTable), 4); + + assertUpdate(format("CREATE TABLE %s (customer VARCHAR WITH (primary_key=true), purchases INT, address VARCHAR) " + + "WITH (" + + " partition_by_hash_columns = ARRAY['customer'], " + + " partition_by_hash_buckets = 2" + + ")", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 
11, 'Darbyshire'), ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Ed', 7, 'Etherville')", sourceTable), 4); + + assertUpdate(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED THEN UPDATE SET customer = CONCAT(t.customer, '_updated'), purchases = s.purchases + t.purchases, address = s.address", + 4); + + assertQuery( + "SELECT * FROM " + targetTable, + "SELECT * FROM (VALUES ('Dave_updated', 22, 'Darbyshire'), ('Aaron_updated', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Carol_updated', 12, 'Centreville'))"); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + targetTable); + assertUpdate("DROP TABLE IF EXISTS " + sourceTable); + } + } + + @Test + public void testMergeAllMatchesDeleted() + { + String targetTable = "merge_all_matches_deleted_target"; + String sourceTable = "merge_all_matches_deleted_source"; + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR WITH (primary_key=true), purchases INT, address VARCHAR) " + + "WITH (" + + " partition_by_hash_columns = ARRAY['customer'], " + + " partition_by_hash_buckets = 2" + + ")", targetTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + + assertUpdate(format("CREATE TABLE %s (customer VARCHAR WITH (primary_key=true), purchases INT, address VARCHAR) " + + "WITH (" + + " partition_by_hash_columns = ARRAY['customer'], " + + " partition_by_hash_buckets = 2" + + ")", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville')", sourceTable), 4); + + assertUpdate(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED THEN DELETE", + 4); + + assertQuery("SELECT * FROM " + targetTable, "SELECT * FROM 
(VALUES ('Bill', 7, 'Buena'))"); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + targetTable); + assertUpdate("DROP TABLE IF EXISTS " + sourceTable); + } + } + + @Test + public void testMergeLimes() + { + String targetTable = "merge_with_various_formats"; + String sourceTable = "merge_simple_source"; + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR WITH (primary_key=true), purchase VARCHAR) " + + "WITH (" + + " partition_by_hash_columns = ARRAY['customer'], " + + " partition_by_hash_buckets = 5" + + ")", targetTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchase) VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')", targetTable), 3); + assertQuery("SELECT * FROM " + targetTable, "SELECT * FROM (VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles'))"); + + assertUpdate(format("CREATE TABLE %s (customer VARCHAR WITH (primary_key=true), purchase VARCHAR) " + + "WITH (" + + " partition_by_hash_columns = ARRAY['customer'], " + + " partition_by_hash_buckets = 5" + + ")", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchase) VALUES ('Craig', 'candles'), ('Len', 'limes'), ('Joe', 'jellybeans')", sourceTable), 3); + + String sql = format("MERGE INTO %s t USING %s s ON (t.purchase = s.purchase)", targetTable, sourceTable) + + " WHEN MATCHED AND s.purchase = 'limes' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET customer = CONCAT(t.customer, '_', s.customer)" + + " WHEN NOT MATCHED THEN INSERT (customer, purchase) VALUES(s.customer, s.purchase)"; + + assertUpdate(sql, 3); + + assertQuery("SELECT * FROM " + targetTable, "SELECT * FROM (VALUES ('Dave', 'dates'), row('Carol_Craig', 'candles'), row('Joe', 'jellybeans'))"); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + targetTable); + assertUpdate("DROP TABLE IF EXISTS " + sourceTable); + } + } } diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/hive/TestHiveTransactionalTable.java 
b/testing/trino-product-tests/src/main/java/io/trino/tests/hive/TestHiveTransactionalTable.java index 2b5b96d30f22..15e8a1240a48 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/hive/TestHiveTransactionalTable.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/hive/TestHiveTransactionalTable.java @@ -19,7 +19,7 @@ import io.airlift.log.Logger; import io.airlift.units.Duration; import io.trino.plugin.hive.metastore.thrift.ThriftHiveMetastoreClient; -import io.trino.tempto.assertions.QueryAssert; +import io.trino.tempto.assertions.QueryAssert.Row; import io.trino.tempto.hadoop.hdfs.HdfsClient; import io.trino.tempto.query.QueryExecutor; import io.trino.tempto.query.QueryResult; @@ -1290,6 +1290,420 @@ public void testDeleteWholePartition() }); } + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeSimpleSelect() + { + withTemporaryTable("merge_simple_target", true, true, NONE, targetTable -> { + query(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", targetTable)); + + log.info("Inserting into target"); + query(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable)); + + withTemporaryTable("merge_simple_source", true, true, NONE, sourceTable -> { + query(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", sourceTable)); + + log.info("Inserting into source"); + query(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable)); + + log.info("About to merge, target table %s, source table %s", targetTable, sourceTable); + query(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED AND s.address = 
'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"); + + log.info("Verifying MERGE"); + verifySelectForPrestoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Aaron", 11, "Arches"), row("Ed", 7, "Etherville"), row("Bill", 7, "Buena"), row("Dave", 22, "Darbyshire")); + }); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeSimpleSelectPartitioned() + { + withTemporaryTable("merge_simple_target", true, true, NONE, targetTable -> { + query(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true, partitioned_by = ARRAY['address'])", targetTable)); + + log.info("Inserting into target"); + query(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable)); + + withTemporaryTable("merge_simple_source", true, true, NONE, sourceTable -> { + query(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", sourceTable)); + + log.info("Inserting into source"); + query(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable)); + + String sql = format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + log.info("About to merge, target table %s, source table %s", targetTable, sourceTable); + query(sql); + + log.info("Verifying 
MERGE"); + verifySelectForPrestoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Aaron", 11, "Arches"), row("Ed", 7, "Etherville"), row("Bill", 7, "Buena"), row("Dave", 22, "Darbyshire")); + }); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = TEST_TIMEOUT, dataProvider = "partitionedAndBucketedProvider") + public void testMergeUpdateWithVariousLayouts(boolean partitioned, String bucketing) + { + log.info("In testMergeUpdateWithVariousLayouts, partitioned %s, bucketing %s", partitioned, bucketing); + BucketingType bucketingType = bucketing.isEmpty() ? NONE : BUCKETED_V2; + withTemporaryTable("merge_with_various_formats", true, true, bucketingType, targetTable -> { + StringBuilder builder = new StringBuilder(); + builder.append("CREATE TABLE ") + .append(targetTable) + .append("(customer STRING"); + builder.append(partitioned ? ") PARTITIONED BY (" : ", ") + .append("purchase STRING) "); + if (!bucketing.isEmpty()) { + builder.append(bucketing); + } + builder.append(" STORED AS ORC TBLPROPERTIES ('transactional' = 'true')"); + onHive().executeQuery(builder.toString()); + + log.info("About to insert"); + query(format("INSERT INTO %s (customer, purchase) VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')", targetTable)); + verifySelectForPrestoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Dave", "dates"), row("Lou", "limes"), row("Carol", "candles")); + + withTemporaryTable("merge_simple_source", true, true, NONE, sourceTable -> { + query(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR) WITH (transactional = true)", sourceTable)); + + log.info("Inserting into source"); + query(format("INSERT INTO %s (customer, purchase) VALUES ('Craig', 'candles'), ('Len', 'limes'), ('Joe', 'jellybeans')", sourceTable)); + + String sql = format("MERGE INTO %s t USING %s s ON (t.purchase = s.purchase)", targetTable, sourceTable) + + " WHEN MATCHED AND s.purchase = 'limes' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET customer = 
CONCAT(t.customer, '_', s.customer)" + + " WHEN NOT MATCHED THEN INSERT (customer, purchase) VALUES(s.customer, s.purchase)"; + + log.info("About to merge, target table %s, source table %s", targetTable, sourceTable); + query(sql); + + log.info("Verifying MERGE"); + verifySelectForPrestoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Dave", "dates"), row("Carol_Craig", "candles"), row("Joe", "jellybeans")); + }); + }); + } + + /** + * This test demonstrates a failure of Hive to verify the result of a MERGE operation, + * specifically, Hive fails to recognize the delete_delta file written by the MERGE. I + * captured the HDFS delta and delete_delta files and verified that they are correct. + * I used Wireshark to capture the traffic between Trino and the Hive metastore during + * the MERGE, and it was all as expected. I tried to vary the test to understand the + * issue, but almost any change I made to the test caused Hive to correctly verify the + * MERGE. + * TODO: Determine what is causing the Hive verification failure + */ + @Test(groups = HIVE_TRANSACTIONAL, timeOut = TEST_TIMEOUT) + public void testMergeUnBucketedUnPartitionedFailure() + { + log.info("In testMergeUnbucketedUnpartitioned"); + withTemporaryTable("merge_with_various_formats", true, true, NONE, targetTable -> { + query(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR) WITH (transactional = true)", targetTable)); + + log.info("About to insert"); + query(format("INSERT INTO %s (customer, purchase) VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')", targetTable)); + verifySelectForPrestoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Dave", "dates"), row("Lou", "limes"), row("Carol", "candles")); + + withTemporaryTable("merge_simple_source", true, true, NONE, sourceTable -> { + query(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR) WITH (transactional = true)", sourceTable)); + + log.info("Inserting into source"); + query(format("INSERT INTO %s 
(customer, purchase) VALUES ('Craig', 'candles'), ('Len', 'limes'), ('Joe', 'jellybeans')", sourceTable)); + + String sql = format("MERGE INTO %s t USING %s s ON (t.purchase = s.purchase)", targetTable, sourceTable) + + " WHEN MATCHED AND s.purchase = 'limes' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET customer = CONCAT(t.customer, '_', s.customer)" + + " WHEN NOT MATCHED THEN INSERT (customer, purchase) VALUES(s.customer, s.purchase)"; + + log.info("About to merge, target table %s, source table %s", targetTable, sourceTable); + query(sql); + + log.info("Verifying MERGE on Presto"); + assertThat(query("SELECT * FROM " + targetTable)) + .containsOnly(row("Dave", "dates"), row("Carol_Craig", "candles"), row("Joe", "jellybeans")); + + log.info("Verifying MERGE on Hive fails - - and it shouldn't"); + assertThat(onHive().executeQuery("SELECT * FROM " + targetTable)) + .containsOnly(row("Dave", "dates"), row("Lou", "limes"), row("Carol", "candles"), row("Carol_Craig", "candles"), row("Joe", "jellybeans")); + + log.info("Finished"); + }); + }); + } + + @DataProvider + public Object[][] partitionedAndBucketedProvider() + { + return new Object[][] { + {false, "CLUSTERED BY (customer) INTO 3 BUCKETS"}, + {false, "CLUSTERED BY (purchase) INTO 4 BUCKETS"}, + {true, ""}, + {true, "CLUSTERED BY (customer) INTO 3 BUCKETS"}, + }; + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeMultipleOperationsUnbucketedUnpartitioned() + { + withTemporaryTable("merge_multiple", true, true, NONE, targetTable -> { + query(format("CREATE TABLE %s (customer VARCHAR, purchases INT, zipcode INT, spouse VARCHAR, address VARCHAR) WITH (transactional = true)", targetTable)); + testMergeMultipleOperationsInternal(targetTable, 32); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeMultipleOperationsUnbucketedPartitioned() + { + withTemporaryTable("merge_multiple", true, true, NONE, targetTable -> { + 
query(format("CREATE TABLE %s (purchases INT, zipcode INT, spouse VARCHAR, address VARCHAR, customer VARCHAR) WITH (transactional = true, partitioned_by = ARRAY['address', 'customer'])", targetTable)); + testMergeMultipleOperationsInternal(targetTable, 32); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeMultipleOperationsBucketedUnpartitioned() + { + withTemporaryTable("merge_multiple", true, true, NONE, targetTable -> { + onHive().executeQuery(format("CREATE TABLE %s (customer STRING, purchases INT, zipcode INT, spouse STRING, address STRING)" + + " CLUSTERED BY(customer, zipcode, address) INTO 4 BUCKETS STORED AS ORC TBLPROPERTIES ('transactional'='true')", targetTable)); + testMergeMultipleOperationsInternal(targetTable, 32); + }); + } + + private void testMergeMultipleOperationsInternal(String targetTable, int targetCustomerCount) + { + log.info("Inserting a bunch into target"); + String originalInsertFirstHalf = IntStream.range(1, targetCustomerCount / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 1000, 91000, intValue, intValue)) + .collect(Collectors.joining(", ")); + String originalInsertSecondHalf = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 2000, 92000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + log.info("About to run first insert"); + query(format("INSERT INTO %s (customer, purchases, zipcode, spouse, address) VALUES %s, %s", targetTable, originalInsertFirstHalf, originalInsertSecondHalf)); + + String firstMergeSource = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 83000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + log.info("About to run first merge"); + + query(format("MERGE INTO %s t USING (SELECT * FROM (VALUES 
%s)) AS s(customer, purchases, zipcode, spouse, address)", targetTable, firstMergeSource) + + " ON t.customer = s.customer" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases, zipcode = s.zipcode, spouse = s.spouse, address = s.address"); + + QueryResult expectedResult = query(format("SELECT * FROM (VALUES %s, %s) AS v(customer, purchases, zipcode, spouse, address)", originalInsertFirstHalf, firstMergeSource)); + verifyOnTrinoAndHiveFromQueryResults("SELECT customer, purchases, zipcode, spouse, address FROM " + targetTable, expectedResult); + + String nextInsert = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('jack_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 4000, 74000, intValue, intValue)) + .collect(Collectors.joining(", ")); + log.info("About to run second insert"); + query(format("INSERT INTO %s (customer, purchases, zipcode, spouse, address) VALUES %s", targetTable, nextInsert)); + + String secondMergeSource = IntStream.range(1, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + query(format("MERGE INTO %s t USING (SELECT * FROM (VALUES %s)) AS s(customer, purchases, zipcode, spouse, address)", targetTable, secondMergeSource) + + " ON t.customer = s.customer" + + " WHEN MATCHED AND t.zipcode = 91000 THEN DELETE" + + " WHEN MATCHED AND s.zipcode = 85000 THEN UPDATE SET zipcode = 60000" + + " WHEN MATCHED THEN UPDATE SET zipcode = s.zipcode, spouse = s.spouse, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, zipcode, spouse, address) VALUES(s.customer, s.purchases, s.zipcode, s.spouse, s.address)"); + + String updatedBeginning = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 60000, intValue, intValue)) + 
.collect(Collectors.joining(", ")); + String updatedMiddle = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue)) + .collect(Collectors.joining(", ")); + String updatedEnd = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('jack_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 4000, 74000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + expectedResult = query(format("SELECT * FROM (VALUES %s, %s, %s) AS v(customer, purchases, zipcode, spouse, address)", updatedBeginning, updatedMiddle, updatedEnd)); + verifyOnTrinoAndHiveFromQueryResults("SELECT customer, purchases, zipcode, spouse, address FROM " + targetTable, expectedResult); + } + + /** Runs sql once through query(...) and once through onHive().executeQuery(...), asserting that each result contains the rows of expectedResult. NOTE(review): contains presumably tolerates additional rows in the actual result; confirm that containsOnly was not intended. */ private void verifyOnTrinoAndHiveFromQueryResults(String sql, QueryResult expectedResult) + { + QueryResult trinoResult = query(sql); + assertThat(trinoResult).contains(getRowsFromQueryResult(expectedResult)); + QueryResult hiveResult = onHive().executeQuery(sql); + assertThat(hiveResult).contains(getRowsFromQueryResult(expectedResult)); + } + + /** Wraps every entry of result.rows() in a Row and collects the rows with toImmutableList(). */ private List getRowsFromQueryResult(QueryResult result) + { + return result.rows().stream().map(values -> new Row(values)).collect(toImmutableList()); + } + + /** MERGE from an inline VALUES source into a table created with transactional = true, using a conditional DELETE clause, an UPDATE clause and an INSERT clause. Expected outcome per the final verification: Carol is deleted, Aaron and Dave get summed purchases and the source address, Ed is inserted, Bill is unchanged. */ @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeSimpleQuery() + { + withTemporaryTable("merge_simple_target", true, true, NONE, targetTable -> { + query(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", targetTable)); + + log.info("Inserting into target"); + query(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable)); + + log.info("About to merge, target table %s", targetTable); + query(format("MERGE INTO %s t USING ", targetTable) +
"(SELECT * FROM (VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville'))) AS s(customer, purchases, address)" + + " " + + "ON (t.customer = s.customer)" + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"); + + log.info("Verifying MERGE"); + verifySelectForPrestoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Aaron", 11, "Arches"), row("Bill", 7, "Buena"), row("Dave", 22, "Darbyshire"), row("Ed", 7, "Etherville")); + }); + } + + /** MERGE whose only clause is WHEN NOT MATCHED THEN INSERT. No source customer matches a target row, so both source rows are expected to be added next to the two existing rows. */ @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeAllInserts() + { + withTemporaryTable("merge_all_inserts", true, true, NONE, targetTable -> { + query(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", targetTable)); + + log.info("Inserting into target"); + query(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 11, 'Antioch'), ('Bill', 7, 'Buena')", targetTable)); + + log.info("About to merge, target table %s", targetTable); + query(format("MERGE INTO %s t USING ", targetTable) + + "(SELECT * FROM (VALUES ('Carol', 9, 'Centreville'), ('Dave', 22, 'Darbyshire'))) AS s(customer, purchases, address)" + + " " + + "ON (t.customer = s.customer)" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"); + + log.info("Verifying MERGE"); + verifySelectForPrestoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Aaron", 11, "Antioch"), row("Bill", 7, "Buena"), row("Carol", 9, "Centreville"), row("Dave", 22, "Darbyshire")); + }); + } + + /** Same data, MERGE statement and expected rows as testMergeSimpleQuery, but the target table is created with partitioned_by on the address column, so the UPDATE clause assigns the partition column. NOTE(review): the temporary-table root name merge_simple_target is the same one testMergeSimpleQuery uses; confirm whether a distinct name was intended. */ @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeSimpleQueryPartitioned() + { + withTemporaryTable("merge_simple_target", true, true, NONE,
targetTable -> { + query(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true, partitioned_by = ARRAY['address'])", targetTable)); + + log.info("Inserting into target"); + query(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable)); + + String query = format("MERGE INTO %s t USING ", targetTable) + + "(SELECT * FROM (VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville'))) AS s(customer, purchases, address)" + + " " + + "ON (t.customer = s.customer)" + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + log.info("About to merge, target table %s", targetTable); + query(query); + + log.info("Verifying MERGE"); + verifySelectForPrestoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Aaron", 11, "Arches"), row("Bill", 7, "Buena"), row("Dave", 22, "Darbyshire"), row("Ed", 7, "Etherville")); + }); + } + + /** MERGE from a source table with a single WHEN MATCHED THEN UPDATE clause that assigns all three columns. Matched customers are expected to get an _updated suffix, summed purchases and the source address; Bill has no source row and is expected unchanged, and the unmatched source row for Ed is not inserted. */ @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeAllColumnsUpdated() + { + withTemporaryTable("merge_all_columns_updated_target", true, true, NONE, targetTable -> { + query(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", targetTable)); + + log.info("Inserting into target"); + query(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Devon'), ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge')", targetTable)); + + withTemporaryTable("merge_all_columns_updated_source", true, true, NONE, sourceTable -> { + query(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH
(transactional = true)", sourceTable)); + + log.info("Inserting into source"); + query(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Darbyshire'), ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Ed', 7, 'Etherville')", sourceTable)); + + log.info("About to merge, target table %s, source table %s", targetTable, sourceTable); + query(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED THEN UPDATE SET customer = CONCAT(t.customer, '_updated'), purchases = s.purchases + t.purchases, address = s.address"); + + log.info("Verifying MERGE"); + verifySelectForPrestoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Dave_updated", 22, "Darbyshire"), row("Aaron_updated", 11, "Arches"), row("Bill", 7, "Buena"), row("Carol_updated", 12, "Centreville")); + }); + }); + } + + /** MERGE from a source table with a single WHEN MATCHED THEN DELETE clause. Only Bill, who has no row in the source table, is expected to remain in the target. */ @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeAllMatchesDeleted() + { + withTemporaryTable("merge_all_matches_deleted_target", true, true, NONE, targetTable -> { + query(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", targetTable)); + + log.info("Inserting into target"); + query(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable)); + + withTemporaryTable("merge_all_matches_deleted_source", true, true, NONE, sourceTable -> { + query(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", sourceTable)); + + log.info("Inserting into source"); + query(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville')", sourceTable)); + + log.info("About to merge, target table %s, source table %s", targetTable, sourceTable); + query(format("MERGE INTO %s t USING
%s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED THEN DELETE"); + + log.info("Verifying MERGE"); + verifySelectForPrestoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Bill", 7, "Buena")); + }); + }); + } + + /** Runs once per CREATE TABLE template supplied by partitionedBucketedFailure; the target table is created through onHive() from createTableSql. The source table holds two Aaron rows, so the unconditional WHEN MATCHED UPDATE is expected to fail with the multiple-source-row message, while the second MERGE, whose WHEN clause keeps only the Adelphi row, is expected to succeed. NOTE(review): the temporary-table root names are the ones testMergeAllMatchesDeleted uses; confirm whether distinct names were intended. */ @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000, dataProvider = "partitionedBucketedFailure") + public void testMergeMultipleRowsMatchFails(String createTableSql) + { + withTemporaryTable("merge_all_matches_deleted_target", true, true, NONE, targetTable -> { + onHive().executeQuery(format(createTableSql, targetTable)); + + log.info("Inserting into target"); + query(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Antioch')", targetTable)); + + withTemporaryTable("merge_all_matches_deleted_source", true, true, NONE, sourceTable -> { + query(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", sourceTable)); + + log.info("Inserting into source"); + query(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Adelphi'), ('Aaron', 8, 'Ashland')", sourceTable)); + + log.info("About to run failing merge, target table %s, source table %s", targetTable, sourceTable); + assertThat(() -> query(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED THEN UPDATE SET address = s.address")) + .failsWithMessage("One MERGE target table row matched more than one source row"); + + log.info("Merge succeeds if the WHEN clause condition limits to one source row"); + query(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED AND s.address = 'Adelphi' THEN UPDATE SET address = s.address"); + log.info("Final SELECT"); + verifySelectForPrestoAndHive("SELECT customer, purchases, address FROM " + targetTable, "TRUE", row("Aaron", 5, "Adelphi"), row("Bill", 7, "Antioch")); + }); + }); + } + + /** Supplies Hive DDL templates, each with a single %s placeholder for the table name, for transactional ORC tables in five layouts: plain, bucketed by customer, partitioned by customer, partitioned by address and bucketed by customer, and partitioned by customer and bucketed by address. */ @DataProvider +
public Object[][] partitionedBucketedFailure() + { + return new Object[][] { + {"CREATE TABLE %s (customer STRING, purchases INT, address STRING) STORED AS ORC TBLPROPERTIES ('transactional'='true')"}, + {"CREATE TABLE %s (customer STRING, purchases INT, address STRING) CLUSTERED BY (customer) INTO 3 BUCKETS STORED AS ORC TBLPROPERTIES ('transactional'='true')"}, + {"CREATE TABLE %s (purchases INT, address STRING) PARTITIONED BY (customer STRING) STORED AS ORC TBLPROPERTIES ('transactional'='true')"}, + {"CREATE TABLE %s (customer STRING, purchases INT) PARTITIONED BY (address STRING) CLUSTERED BY (customer) INTO 3 BUCKETS STORED AS ORC TBLPROPERTIES ('transactional'='true')"}, + {"CREATE TABLE %s (purchases INT, address STRING) PARTITIONED BY (customer STRING) CLUSTERED BY (address) INTO 3 BUCKETS STORED AS ORC TBLPROPERTIES ('transactional'='true')"} + }; + } + @DataProvider public Object[][] insertersProvider() { @@ -1607,13 +2021,13 @@ private void ensureSchemaEvolutionSupported() } } - private static void verifySelectForPrestoAndHive(String select, String whereClause, QueryAssert.Row... rows) + private static void verifySelectForPrestoAndHive(String select, String whereClause, Row... rows) { verifySelect("onPresto", onTrino(), select, whereClause, rows); verifySelect("onHive", onHive(), select, whereClause, rows); } - private static void verifySelect(String name, QueryExecutor executor, String select, String whereClause, QueryAssert.Row... rows) + private static void verifySelect(String name, QueryExecutor executor, String select, String whereClause, Row... rows) { String fullQuery = format("%s WHERE %s", select, whereClause);