diff --git a/core/trino-main/src/main/java/io/trino/metadata/HandleJsonModule.java b/core/trino-main/src/main/java/io/trino/metadata/HandleJsonModule.java index 72969cedc99e..e199110a1b87 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/HandleJsonModule.java +++ b/core/trino-main/src/main/java/io/trino/metadata/HandleJsonModule.java @@ -20,6 +20,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorIndexHandle; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorOutputTableHandle; import io.trino.spi.connector.ConnectorPartitioningHandle; import io.trino.spi.connector.ConnectorSplit; @@ -75,6 +76,12 @@ public static com.fasterxml.jackson.databind.Module tableExecuteHandleModule(Han return new AbstractTypedJacksonModule<>(ConnectorTableExecuteHandle.class, resolver::getId, resolver::getHandleClass) {}; } + @ProvidesIntoSet + public static com.fasterxml.jackson.databind.Module mergeTableHandleModule(HandleResolver resolver) + { + return new AbstractTypedJacksonModule<>(ConnectorMergeTableHandle.class, resolver::getId, resolver::getHandleClass) {}; + } + @ProvidesIntoSet public static com.fasterxml.jackson.databind.Module indexHandleModule(HandleResolver resolver) { diff --git a/core/trino-main/src/main/java/io/trino/metadata/MergeHandle.java b/core/trino-main/src/main/java/io/trino/metadata/MergeHandle.java new file mode 100644 index 000000000000..39a598a59cd2 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/metadata/MergeHandle.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.metadata; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.connector.ConnectorMergeTableHandle; + +import static java.util.Objects.requireNonNull; + +public final class MergeHandle +{ + private final TableHandle tableHandle; + private final ConnectorMergeTableHandle connectorMergeHandle; + + @JsonCreator + public MergeHandle( + @JsonProperty("tableHandle") TableHandle tableHandle, + @JsonProperty("connectorMergeHandle") ConnectorMergeTableHandle connectorMergeHandle) + { + this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); + this.connectorMergeHandle = requireNonNull(connectorMergeHandle, "connectorMergeHandle is null"); + } + + @JsonProperty + public TableHandle getTableHandle() + { + return tableHandle; + } + + @JsonProperty + public ConnectorMergeTableHandle getConnectorMergeHandle() + { + return connectorMergeHandle; + } +} diff --git a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java index 9f2c749ed923..b5c62a8b7a27 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java +++ b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java @@ -37,6 +37,7 @@ import io.trino.spi.connector.LimitApplicationResult; import io.trino.spi.connector.MaterializedViewFreshness; import io.trino.spi.connector.ProjectionApplicationResult; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SampleApplicationResult; 
import io.trino.spi.connector.SampleType; import io.trino.spi.connector.SortItem; @@ -375,6 +376,34 @@ Optional finishRefreshMaterializedView( */ void finishUpdate(Session session, TableHandle tableHandle, Collection fragments); + /** + * Return the row update paradigm supported by the connector on the table or throw + * an exception if row change is not supported. + */ + RowChangeParadigm getRowChangeParadigm(Session session, TableHandle tableHandle); + + /** + * Get the column handle that will generate row IDs for the merge operation. + * These IDs will be passed to the {@code storeMergedRows()} method of the + * {@link io.trino.spi.connector.ConnectorMergeSink} that created them. + */ + ColumnHandle getMergeRowIdColumnHandle(Session session, TableHandle tableHandle); + + /** + * Get the physical layout for updated or deleted rows of a MERGE operation. + */ + Optional getUpdateLayout(Session session, TableHandle tableHandle); + + /** + * Begin merge query + */ + MergeHandle beginMerge(Session session, TableHandle tableHandle); + + /** + * Finish merge query + */ + void finishMerge(Session session, MergeHandle tableHandle, Collection fragments, Collection computedStatistics); + /** * Returns a catalog handle for the specified catalog name. 
*/ diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java index 81947cfcbd6b..964342c46f22 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java @@ -43,6 +43,7 @@ import io.trino.spi.connector.ConnectorCapabilities; import io.trino.spi.connector.ConnectorInsertTableHandle; import io.trino.spi.connector.ConnectorMaterializedViewDefinition; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorMetadata; import io.trino.spi.connector.ConnectorOutputMetadata; import io.trino.spi.connector.ConnectorOutputTableHandle; @@ -65,6 +66,7 @@ import io.trino.spi.connector.LimitApplicationResult; import io.trino.spi.connector.MaterializedViewFreshness; import io.trino.spi.connector.ProjectionApplicationResult; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SampleApplicationResult; import io.trino.spi.connector.SampleType; import io.trino.spi.connector.SchemaTableName; @@ -959,6 +961,26 @@ public ColumnHandle getUpdateRowIdColumnHandle(Session session, TableHandle tabl return metadata.getUpdateRowIdColumnHandle(session.toConnectorSession(catalogHandle), tableHandle.getConnectorHandle(), updatedColumns); } + @Override + public ColumnHandle getMergeRowIdColumnHandle(Session session, TableHandle tableHandle) + { + CatalogHandle catalogHandle = tableHandle.getCatalogHandle(); + ConnectorMetadata metadata = getMetadata(session, catalogHandle); + return metadata.getMergeRowIdColumnHandle(session.toConnectorSession(catalogHandle), tableHandle.getConnectorHandle()); + } + + @Override + public Optional getUpdateLayout(Session session, TableHandle tableHandle) + { + CatalogHandle catalogHandle = tableHandle.getCatalogHandle(); + CatalogMetadata catalogMetadata = getCatalogMetadataForWrite(session, catalogHandle); + 
ConnectorMetadata metadata = catalogMetadata.getMetadata(session); + ConnectorTransactionHandle transactionHandle = catalogMetadata.getTransactionHandleFor(catalogHandle); + + return metadata.getUpdateLayout(session.toConnectorSession(catalogHandle), tableHandle.getConnectorHandle()) + .map(partitioning -> new PartitioningHandle(Optional.of(catalogHandle), Optional.of(transactionHandle), partitioning)); + } + @Override public Optional applyDelete(Session session, TableHandle table) { @@ -1014,6 +1036,31 @@ public void finishUpdate(Session session, TableHandle tableHandle, Collection fragments, Collection computedStatistics) + { + CatalogHandle catalogHandle = mergeHandle.getTableHandle().getCatalogHandle(); + ConnectorMetadata metadata = getMetadata(session, catalogHandle); + metadata.finishMerge(session.toConnectorSession(catalogHandle), mergeHandle.getConnectorMergeHandle(), fragments, computedStatistics); + } + @Override public Optional getCatalogHandle(Session session, String catalogName) { diff --git a/core/trino-main/src/main/java/io/trino/operator/AbstractRowChangeOperator.java b/core/trino-main/src/main/java/io/trino/operator/AbstractRowChangeOperator.java index d9cca3e1a83f..3b172ca54d90 100644 --- a/core/trino-main/src/main/java/io/trino/operator/AbstractRowChangeOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/AbstractRowChangeOperator.java @@ -51,7 +51,7 @@ protected enum State protected State state = State.RUNNING; protected long rowCount; private boolean closed; - private ListenableFuture> finishFuture; + protected ListenableFuture> finishFuture; private ListenableFuture blockedFutureView; private Supplier> pageSource = Optional::empty; @@ -146,6 +146,7 @@ public void close() } else { pageSource.get().ifPresent(UpdatablePageSource::abort); + abort(); } } } @@ -158,7 +159,9 @@ public void setPageSource(Supplier> pageSource) protected UpdatablePageSource pageSource() { Optional source = pageSource.get(); - 
checkState(source.isPresent(), "UpdatablePageSource not set"); + checkState(source.isPresent(), "pageSource not set"); return source.get(); } + + protected void abort() {} } diff --git a/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java b/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java new file mode 100644 index 000000000000..05396a4f1b3c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java @@ -0,0 +1,109 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator; + +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.ColumnarRow; +import io.trino.spi.block.RunLengthEncodedBlock; + +import java.util.ArrayList; +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.spi.block.ColumnarRow.toColumnarRow; +import static io.trino.spi.predicate.Utils.nativeValueToBlock; +import static io.trino.spi.type.TinyintType.TINYINT; +import static java.util.Objects.requireNonNull; + +/** + * The transformPage() method in this class does two things: + *
+ * <ul>
+ *     <li>Transform the input page into an "update" page format</li>
+ *     <li>Removes all rows whose operation number is DEFAULT_CASE_OPERATION_NUMBER</li>
+ * </ul>
+ */ +public class ChangeOnlyUpdatedColumnsMergeProcessor + implements MergeRowChangeProcessor +{ + private static final Block INSERT_FROM_UPDATE_BLOCK = nativeValueToBlock(TINYINT, 0L); + + private final int rowIdChannel; + private final int mergeRowChannel; + private final List dataColumnChannels; + private final int writeRedistributionColumnCount; + + public ChangeOnlyUpdatedColumnsMergeProcessor( + int rowIdChannel, + int mergeRowChannel, + List dataColumnChannels, + List redistributionColumnChannels) + { + this.rowIdChannel = rowIdChannel; + this.mergeRowChannel = mergeRowChannel; + this.dataColumnChannels = requireNonNull(dataColumnChannels, "dataColumnChannels is null"); + this.writeRedistributionColumnCount = redistributionColumnChannels.size(); + } + + @Override + public Page transformPage(Page inputPage) + { + requireNonNull(inputPage, "inputPage is null"); + int inputChannelCount = inputPage.getChannelCount(); + checkArgument(inputChannelCount >= 2 + writeRedistributionColumnCount, "inputPage channelCount (%s) should be >= 2 + %s", inputChannelCount, writeRedistributionColumnCount); + int positionCount = inputPage.getPositionCount(); + // TODO: Check with Karol to see if we can get empty pages + checkArgument(positionCount > 0, "positionCount should be > 0, but is %s", positionCount); + + ColumnarRow mergeRow = toColumnarRow(inputPage.getBlock(mergeRowChannel)); + checkArgument(!mergeRow.mayHaveNull(), "The mergeRow may not have null rows"); + + // We've verified that the mergeRow block has no null rows, so it's okay to get the field blocks + + List builder = new ArrayList<>(dataColumnChannels.size() + 3); + + for (int channel : dataColumnChannels) { + builder.add(mergeRow.getField(channel)); + } + Block operationChannelBlock = mergeRow.getField(mergeRow.getFieldCount() - 2); + builder.add(operationChannelBlock); + builder.add(inputPage.getBlock(rowIdChannel)); + builder.add(new RunLengthEncodedBlock(INSERT_FROM_UPDATE_BLOCK, positionCount)); + + Page 
result = new Page(builder.toArray(Block[]::new)); + + int defaultCaseCount = 0; + for (int position = 0; position < positionCount; position++) { + if (TINYINT.getLong(operationChannelBlock, position) == DEFAULT_CASE_OPERATION_NUMBER) { + defaultCaseCount++; + } + } + if (defaultCaseCount == 0) { + return result; + } + + int usedCases = 0; + int[] positions = new int[positionCount - defaultCaseCount]; + for (int position = 0; position < positionCount; position++) { + if (TINYINT.getLong(operationChannelBlock, position) != DEFAULT_CASE_OPERATION_NUMBER) { + positions[usedCases] = position; + usedCases++; + } + } + + checkArgument(usedCases + defaultCaseCount == positionCount, "usedCases (%s) + defaultCaseCount (%s) != positionCount (%s)", usedCases, defaultCaseCount, positionCount); + + return result.getPositions(positions, 0, usedCases); + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertMergeProcessor.java b/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertMergeProcessor.java new file mode 100644 index 000000000000..0d3b8dea0294 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertMergeProcessor.java @@ -0,0 +1,200 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.operator; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import io.trino.spi.Page; +import io.trino.spi.PageBuilder; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ColumnarRow; +import io.trino.spi.type.Type; + +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static io.trino.spi.block.ColumnarRow.toColumnarRow; +import static io.trino.spi.connector.ConnectorMergeSink.DELETE_OPERATION_NUMBER; +import static io.trino.spi.connector.ConnectorMergeSink.INSERT_OPERATION_NUMBER; +import static io.trino.spi.connector.ConnectorMergeSink.UPDATE_OPERATION_NUMBER; +import static io.trino.spi.type.TinyintType.TINYINT; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; + +public class DeleteAndInsertMergeProcessor + implements MergeRowChangeProcessor +{ + private final List dataColumnTypes; + private final Type rowIdType; + private final int rowIdChannel; + private final int mergeRowChannel; + private final List dataColumnChannels; + private final int redistributionColumnCount; + private final List redistributionChannelNumbers; + + public DeleteAndInsertMergeProcessor( + List dataColumnTypes, + Type rowIdType, + int rowIdChannel, + int mergeRowChannel, + List redistributionChannelNumbers, + List dataColumnChannels) + { + this.dataColumnTypes = requireNonNull(dataColumnTypes, "dataColumnTypes is null"); + this.rowIdType = requireNonNull(rowIdType, "rowIdType is null"); + this.rowIdChannel = rowIdChannel; + this.mergeRowChannel = mergeRowChannel; + this.redistributionColumnCount = redistributionChannelNumbers.size(); + int redistributionSourceIndex = 0; + this.dataColumnChannels = requireNonNull(dataColumnChannels, "dataColumnChannels is null"); + ImmutableList.Builder 
redistributionChannelNumbersBuilder = ImmutableList.builder(); + for (int dataColumnChannel : dataColumnChannels) { + if (redistributionChannelNumbers.contains(dataColumnChannel)) { + redistributionChannelNumbersBuilder.add(redistributionSourceIndex); + redistributionSourceIndex++; + } + else { + redistributionChannelNumbersBuilder.add(-1); + } + } + this.redistributionChannelNumbers = redistributionChannelNumbersBuilder.build(); + } + + @JsonProperty + public List getDataColumnTypes() + { + return dataColumnTypes; + } + + @JsonProperty + public Type getRowIdType() + { + return rowIdType; + } + + /** + * Transform UPDATE operations into an INSERT and DELETE operation. + * See {@link MergeRowChangeProcessor#transformPage} for details. + */ + @Override + public Page transformPage(Page inputPage) + { + requireNonNull(inputPage, "inputPage is null"); + int inputChannelCount = inputPage.getChannelCount(); + checkArgument(inputChannelCount >= 2 + redistributionColumnCount, "inputPage channelCount (%s) should be >= 2 + partition columns size (%s)", inputChannelCount, redistributionColumnCount); + + int originalPositionCount = inputPage.getPositionCount(); + checkArgument(originalPositionCount > 0, "originalPositionCount should be > 0, but is %s", originalPositionCount); + + ColumnarRow mergeRow = toColumnarRow(inputPage.getBlock(mergeRowChannel)); + Block operationChannelBlock = mergeRow.getField(mergeRow.getFieldCount() - 2); + + int updatePositions = 0; + int insertPositions = 0; + int deletePositions = 0; + for (int position = 0; position < originalPositionCount; position++) { + int operation = toIntExact(TINYINT.getLong(operationChannelBlock, position)); + switch (operation) { + case DEFAULT_CASE_OPERATION_NUMBER -> { /* ignored */ } + case INSERT_OPERATION_NUMBER -> insertPositions++; + case DELETE_OPERATION_NUMBER -> deletePositions++; + case UPDATE_OPERATION_NUMBER -> updatePositions++; + default -> throw new IllegalArgumentException("Unknown operator number: " + 
operation); + } + } + + int totalPositions = insertPositions + deletePositions + (2 * updatePositions); + List pageTypes = ImmutableList.builder() + .addAll(dataColumnTypes) + .add(TINYINT) + .add(rowIdType) + .add(TINYINT) + .build(); + + PageBuilder pageBuilder = new PageBuilder(totalPositions, pageTypes); + for (int position = 0; position < originalPositionCount; position++) { + long operation = TINYINT.getLong(operationChannelBlock, position); + if (operation != DEFAULT_CASE_OPERATION_NUMBER) { + // Delete and Update because both create a delete row + if (operation == DELETE_OPERATION_NUMBER || operation == UPDATE_OPERATION_NUMBER) { + addDeleteRow(pageBuilder, inputPage, position); + } + // Insert and update because both create an insert row + if (operation == INSERT_OPERATION_NUMBER || operation == UPDATE_OPERATION_NUMBER) { + addInsertRow(pageBuilder, mergeRow, position, operation == UPDATE_OPERATION_NUMBER); + } + } + } + + Page page = pageBuilder.build(); + verify(page.getPositionCount() == totalPositions, "page positions (%s) is not equal to (%s)", page.getPositionCount(), totalPositions); + return page; + } + + private void addDeleteRow(PageBuilder pageBuilder, Page originalPage, int position) + { + // TODO: There is no need to copy the data columns themselves. Instead, we could + // use a DictionaryBlock to omit columns. 
+ // Copy the write redistribution columns + for (int targetChannel : dataColumnChannels) { + Type columnType = dataColumnTypes.get(targetChannel); + BlockBuilder targetBlock = pageBuilder.getBlockBuilder(targetChannel); + + int redistributionChannelNumber = redistributionChannelNumbers.get(targetChannel); + if (redistributionChannelNumbers.get(targetChannel) >= 0) { + // The value comes from that column of the page + columnType.appendTo(originalPage.getBlock(redistributionChannelNumber), position, targetBlock); + } + else { + // We don't care about the other data columns + targetBlock.appendNull(); + } + } + + // Add the operation column == deleted + TINYINT.writeLong(pageBuilder.getBlockBuilder(dataColumnChannels.size()), DELETE_OPERATION_NUMBER); + + // Copy row ID column + rowIdType.appendTo(originalPage.getBlock(rowIdChannel), position, pageBuilder.getBlockBuilder(dataColumnChannels.size() + 1)); + + // Write 0, meaning this row is not an insert derived from an update + TINYINT.writeLong(pageBuilder.getBlockBuilder(dataColumnChannels.size() + 2), 0); + + pageBuilder.declarePosition(); + } + + private void addInsertRow(PageBuilder pageBuilder, ColumnarRow mergeCaseBlock, int position, boolean causedByUpdate) + { + // Copy the values from the merge block + for (int targetChannel : dataColumnChannels) { + Type columnType = dataColumnTypes.get(targetChannel); + BlockBuilder targetBlock = pageBuilder.getBlockBuilder(targetChannel); + // The value comes from that column of the page + columnType.appendTo(mergeCaseBlock.getField(targetChannel), position, targetBlock); + } + + // Add the operation column == insert + TINYINT.writeLong(pageBuilder.getBlockBuilder(dataColumnChannels.size()), INSERT_OPERATION_NUMBER); + + // Add null row ID column + pageBuilder.getBlockBuilder(dataColumnChannels.size() + 1).appendNull(); + + // Write 1 if this row is an insert derived from an update, 0 otherwise + TINYINT.writeLong(pageBuilder.getBlockBuilder(dataColumnChannels.size() + 
2), causedByUpdate ? 1 : 0); + + pageBuilder.declarePosition(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/MergeProcessorOperator.java b/core/trino-main/src/main/java/io/trino/operator/MergeProcessorOperator.java new file mode 100644 index 000000000000..69843e4d12ed --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/MergeProcessorOperator.java @@ -0,0 +1,140 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator; + +import io.trino.spi.Page; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.TableWriterNode.MergeParadigmAndTypes; + +import java.util.List; + +import static com.google.common.base.Preconditions.checkState; +import static io.trino.operator.BasicWorkProcessorOperatorAdapter.createAdapterOperatorFactory; +import static io.trino.operator.WorkProcessor.TransformationState.finished; +import static io.trino.operator.WorkProcessor.TransformationState.ofResult; +import static java.util.Objects.requireNonNull; + +/** + * This operator is used by operations like SQL MERGE. It is used + * for all {@link io.trino.spi.connector.RowChangeParadigm}s. This operator + * creates the {@link MergeRowChangeProcessor}. 
+ */ +public class MergeProcessorOperator + implements WorkProcessorOperator +{ + public static OperatorFactory createOperatorFactory( + int operatorId, + PlanNodeId planNodeId, + MergeParadigmAndTypes merge, + int rowIdChannel, + int mergeRowChannel, + List redistributionColumns, + List dataColumnChannels) + { + MergeRowChangeProcessor rowChangeProcessor = createRowChangeProcessor(merge, rowIdChannel, mergeRowChannel, redistributionColumns, dataColumnChannels); + return createAdapterOperatorFactory(new Factory(operatorId, planNodeId, rowChangeProcessor)); + } + + private static MergeRowChangeProcessor createRowChangeProcessor(MergeParadigmAndTypes merge, int rowIdChannel, int mergeRowChannel, List redistributionColumnChannels, List dataColumnChannels) + { + return switch (merge.getParadigm()) { + case DELETE_ROW_AND_INSERT_ROW -> new DeleteAndInsertMergeProcessor( + merge.getColumnTypes(), + merge.getRowIdType(), + rowIdChannel, + mergeRowChannel, + redistributionColumnChannels, + dataColumnChannels); + case CHANGE_ONLY_UPDATED_COLUMNS -> new ChangeOnlyUpdatedColumnsMergeProcessor( + rowIdChannel, + mergeRowChannel, + dataColumnChannels, + redistributionColumnChannels); + }; + } + + public static class Factory + implements BasicWorkProcessorOperatorAdapter.BasicAdapterWorkProcessorOperatorFactory + { + private final int operatorId; + private final PlanNodeId planNodeId; + private final MergeRowChangeProcessor rowChangeProcessor; + private boolean closed; + + public Factory(int operatorId, PlanNodeId planNodeId, MergeRowChangeProcessor rowChangeProcessor) + { + this.operatorId = operatorId; + this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); + this.rowChangeProcessor = requireNonNull(rowChangeProcessor, "rowChangeProcessor is null"); + } + + @Override + public WorkProcessorOperator create(ProcessorContext processorContext, WorkProcessor sourcePages) + { + checkState(!closed, "Factory is already closed"); + return new 
MergeProcessorOperator(sourcePages, rowChangeProcessor); + } + + @Override + public int getOperatorId() + { + return operatorId; + } + + @Override + public PlanNodeId getPlanNodeId() + { + return planNodeId; + } + + @Override + public String getOperatorType() + { + return MergeProcessorOperator.class.getSimpleName(); + } + + @Override + public void close() + { + closed = true; + } + + @Override + public Factory duplicate() + { + return new Factory(operatorId, planNodeId, rowChangeProcessor); + } + } + + private final WorkProcessor pages; + + private MergeProcessorOperator( + WorkProcessor sourcePages, + MergeRowChangeProcessor rowChangeProcessor) + { + pages = sourcePages + .transform(page -> { + if (page == null) { + return finished(); + } + return ofResult(rowChangeProcessor.transformPage(page)); + }); + } + + @Override + public WorkProcessor getOutputPages() + { + return pages; + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/MergeRowChangeProcessor.java b/core/trino-main/src/main/java/io/trino/operator/MergeRowChangeProcessor.java new file mode 100644 index 000000000000..82720c08b8dd --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/MergeRowChangeProcessor.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.operator; + +import io.trino.spi.Page; +import io.trino.spi.connector.ConnectorMergeSink; + +public interface MergeRowChangeProcessor +{ + int DEFAULT_CASE_OPERATION_NUMBER = -1; + + /** + * Transform a page generated by an SQL MERGE operation into page of data columns and + * operations. The SQL MERGE input page consists of the following: + *
+ * <ul>
+ *     <li>The write redistribution columns, if any</li>
+ *     <li>For partitioned or bucketed tables, a hash value column</li>
+ *     <li>The rowId column for the row from the target table if matched, or null if not matched</li>
+ *     <li>The merge case row block</li>
+ * </ul>
+ * The output page consists of the following:
+ * <ul>
+ *     <li>All data columns, in table column order</li>
+ *     <li>{@link ConnectorMergeSink#storeMergedRows The operation block}</li>
+ *     <li>The rowId block</li>
+ *     <li>The last column in the resulting page is 1 if the row is an insert
+ *     derived from an update, and zero otherwise.</li>
+ * </ul>
+ *
+ * The {@link DeleteAndInsertMergeProcessor} implementation will transform each UPDATE + * row into multiple rows: an INSERT row and a DELETE row. + */ + Page transformPage(Page inputPage); +} diff --git a/core/trino-main/src/main/java/io/trino/operator/MergeWriterOperator.java b/core/trino-main/src/main/java/io/trino/operator/MergeWriterOperator.java new file mode 100644 index 000000000000..ddc416dba9dc --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/MergeWriterOperator.java @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.operator; + +import io.trino.Session; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.connector.ConnectorMergeSink; +import io.trino.split.PageSinkManager; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.TableWriterNode.MergeTarget; + +import java.util.stream.IntStream; + +import static com.google.common.base.Preconditions.checkState; +import static io.airlift.concurrent.MoreFutures.toListenableFuture; +import static io.trino.spi.type.TinyintType.TINYINT; +import static java.util.Objects.requireNonNull; + +public class MergeWriterOperator + extends AbstractRowChangeOperator +{ + public static class MergeWriterOperatorFactory + implements OperatorFactory + { + private final int operatorId; + private final PlanNodeId planNodeId; + private final PageSinkManager pageSinkManager; + private final MergeTarget target; + private final Session session; + private boolean closed; + + public MergeWriterOperatorFactory(int operatorId, PlanNodeId planNodeId, PageSinkManager pageSinkManager, MergeTarget target, Session session) + { + this.operatorId = operatorId; + this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); + this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null"); + this.target = requireNonNull(target, "target is null"); + this.session = requireNonNull(session, "session is null"); + } + + @Override + public Operator createOperator(DriverContext driverContext) + { + checkState(!closed, "Factory is already closed"); + OperatorContext context = driverContext.addOperatorContext(operatorId, planNodeId, MergeWriterOperator.class.getSimpleName()); + ConnectorMergeSink mergeSink = pageSinkManager.createMergeSink(session, target.getMergeHandle().get()); + return new MergeWriterOperator(context, mergeSink); + } + + @Override + public void noMoreOperators() + { + closed = true; + } + + @Override + public OperatorFactory duplicate() + { + return new 
MergeWriterOperatorFactory(operatorId, planNodeId, pageSinkManager, target, session); + } + } + + private final ConnectorMergeSink mergeSink; + + public MergeWriterOperator(OperatorContext operatorContext, ConnectorMergeSink mergeSink) + { + super(operatorContext); + this.mergeSink = requireNonNull(mergeSink, "mergeSink is null"); + } + + @Override + public void addInput(Page page) + { + requireNonNull(page, "page is null"); + checkState(state == State.RUNNING, "Operator is %s", state); + + // Copy all but the last block to a new page. + // The last block exists only to get the rowCount right. + int outputChannelCount = page.getChannelCount() - 1; + int[] columns = IntStream.range(0, outputChannelCount).toArray(); + Page newPage = page.getColumns(columns); + + // Store the page + mergeSink.storeMergedRows(newPage); + + // Calculate the amount to increment the rowCount + Block insertFromUpdateColumn = page.getBlock(page.getChannelCount() - 1); + long insertsFromUpdates = 0; + int positionCount = page.getPositionCount(); + for (int position = 0; position < positionCount; position++) { + insertsFromUpdates += TINYINT.getLong(insertFromUpdateColumn, position); + } + rowCount += positionCount - insertsFromUpdates; + } + + @Override + public void finish() + { + if (state == State.RUNNING) { + state = State.FINISHING; + finishFuture = toListenableFuture(mergeSink.finish()); + } + } + + @Override + protected void abort() + { + mergeSink.abort(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java index 4d297fc21ea8..1c721e453bae 100644 --- a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java +++ b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java @@ -14,6 +14,7 @@ package io.trino.operator.exchange; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.VerifyException; 
import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; import io.airlift.slice.XxHash64; @@ -214,8 +215,9 @@ private static PartitionFunction createPartitionFunction( // The same bucket function (with the same bucket count) as for node // partitioning must be used. This way rows within a single bucket // will be being processed by single thread. - ConnectorBucketNodeMap connectorBucketNodeMap = nodePartitioningManager.getConnectorBucketNodeMap(session, partitioning); - int bucketCount = connectorBucketNodeMap.getBucketCount(); + int bucketCount = nodePartitioningManager.getConnectorBucketNodeMap(session, partitioning) + .map(ConnectorBucketNodeMap::getBucketCount) + .orElseThrow(() -> new VerifyException("No bucket node map for partitioning: " + partitioning)); int[] bucketToPartition = new int[bucketCount]; for (int bucket = 0; bucket < bucketCount; bucket++) { // mix the bucket bits so we don't use the same bucket number used to distribute between stages diff --git a/core/trino-main/src/main/java/io/trino/split/PageSinkManager.java b/core/trino-main/src/main/java/io/trino/split/PageSinkManager.java index f391217b7c3e..5a403afda00f 100644 --- a/core/trino-main/src/main/java/io/trino/split/PageSinkManager.java +++ b/core/trino-main/src/main/java/io/trino/split/PageSinkManager.java @@ -17,8 +17,11 @@ import io.trino.connector.CatalogHandle; import io.trino.connector.CatalogServiceProvider; import io.trino.metadata.InsertTableHandle; +import io.trino.metadata.MergeHandle; import io.trino.metadata.OutputTableHandle; import io.trino.metadata.TableExecuteHandle; +import io.trino.metadata.TableHandle; +import io.trino.spi.connector.ConnectorMergeSink; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorPageSinkProvider; import io.trino.spi.connector.ConnectorSession; @@ -41,7 +44,6 @@ public PageSinkManager(CatalogServiceProvider pageSin @Override public ConnectorPageSink createPageSink(Session 
session, OutputTableHandle tableHandle) { - // assumes connectorId and catalog are the same ConnectorSession connectorSession = session.toConnectorSession(tableHandle.getCatalogHandle()); return providerFor(tableHandle.getCatalogHandle()).createPageSink(tableHandle.getTransactionHandle(), connectorSession, tableHandle.getConnectorHandle()); } @@ -62,6 +64,15 @@ public ConnectorPageSink createPageSink(Session session, TableExecuteHandle tabl return providerFor(tableHandle.getCatalogHandle()).createPageSink(tableHandle.getTransactionHandle(), connectorSession, tableHandle.getConnectorHandle()); } + @Override + public ConnectorMergeSink createMergeSink(Session session, MergeHandle mergeHandle) + { + // assumes connectorId and catalog are the same + TableHandle tableHandle = mergeHandle.getTableHandle(); + ConnectorSession connectorSession = session.toConnectorSession(tableHandle.getCatalogHandle()); + return providerFor(tableHandle.getCatalogHandle()).createMergeSink(tableHandle.getTransaction(), connectorSession, mergeHandle.getConnectorMergeHandle()); + } + private ConnectorPageSinkProvider providerFor(CatalogHandle catalogHandle) { return pageSinkProvider.getService(catalogHandle); diff --git a/core/trino-main/src/main/java/io/trino/split/PageSinkProvider.java b/core/trino-main/src/main/java/io/trino/split/PageSinkProvider.java index effcfaef23f3..e8551a96ee09 100644 --- a/core/trino-main/src/main/java/io/trino/split/PageSinkProvider.java +++ b/core/trino-main/src/main/java/io/trino/split/PageSinkProvider.java @@ -15,15 +15,28 @@ import io.trino.Session; import io.trino.metadata.InsertTableHandle; +import io.trino.metadata.MergeHandle; import io.trino.metadata.OutputTableHandle; import io.trino.metadata.TableExecuteHandle; +import io.trino.spi.connector.ConnectorMergeSink; import io.trino.spi.connector.ConnectorPageSink; public interface PageSinkProvider { + /* + * Used for CTAS + */ ConnectorPageSink createPageSink(Session session, OutputTableHandle tableHandle); 
+ /* + * Used to insert into an existing table + */ ConnectorPageSink createPageSink(Session session, InsertTableHandle tableHandle); ConnectorPageSink createPageSink(Session session, TableExecuteHandle tableHandle); + + /* + * Used to write the result of SQL MERGE to an existing table + */ + ConnectorMergeSink createMergeSink(Session session, MergeHandle mergeHandle); } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index fb518888803c..3442b7f24d12 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -47,6 +47,7 @@ import io.trino.spi.type.Type; import io.trino.sql.analyzer.ExpressionAnalyzer.LabelPrefixedReference; import io.trino.sql.analyzer.JsonPathAnalyzer.JsonPathAnalysis; +import io.trino.sql.planner.PartitioningHandle; import io.trino.sql.tree.AllColumns; import io.trino.sql.tree.DereferenceExpression; import io.trino.sql.tree.ExistsPredicate; @@ -220,6 +221,7 @@ public class Analysis private Optional delegatedRefreshMaterializedView = Optional.empty(); private Optional analyzeMetadata = Optional.empty(); private Optional> updatedColumns = Optional.empty(); + private Optional mergeAnalysis = Optional.empty(); private final QueryType queryType; @@ -769,6 +771,16 @@ public Optional> getUpdatedColumns() return updatedColumns; } + public Optional getMergeAnalysis() + { + return mergeAnalysis; + } + + public void setMergeAnalysis(MergeAnalysis mergeAnalysis) + { + this.mergeAnalysis = Optional.of(mergeAnalysis); + } + public void setRefreshMaterializedView(RefreshMaterializedViewAnalysis refreshMaterializedView) { this.refreshMaterializedView = Optional.of(refreshMaterializedView); @@ -1625,6 +1637,102 @@ public boolean isFrameInherited() } } + public static class MergeAnalysis + { + private final Table targetTable; + private final List dataColumnSchemas; + 
private final List dataColumnHandles; + private final List redistributionColumnHandles; + private final List> mergeCaseColumnHandles; + private final Map columnHandleFieldNumbers; + private final List insertPartitioningArgumentIndexes; + private final Optional insertLayout; + private final Optional updateLayout; + private final Scope targetTableScope; + private final Scope joinScope; + + public MergeAnalysis( + Table targetTable, + List dataColumnSchemas, + List dataColumnHandles, + List redistributionColumnHandles, + List> mergeCaseColumnHandles, + Map columnHandleFieldNumbers, + List insertPartitioningArgumentIndexes, + Optional insertLayout, + Optional updateLayout, + Scope targetTableScope, + Scope joinScope) + { + this.targetTable = requireNonNull(targetTable, "targetTable is null"); + this.dataColumnSchemas = requireNonNull(dataColumnSchemas, "dataColumnSchemas is null"); + this.dataColumnHandles = requireNonNull(dataColumnHandles, "dataColumnHandles is null"); + this.redistributionColumnHandles = requireNonNull(redistributionColumnHandles, "redistributionColumnHandles is null"); + this.mergeCaseColumnHandles = requireNonNull(mergeCaseColumnHandles, "mergeCaseColumnHandles is null"); + this.columnHandleFieldNumbers = requireNonNull(columnHandleFieldNumbers, "columnHandleFieldNumbers is null"); + this.insertLayout = requireNonNull(insertLayout, "insertLayout is null"); + this.updateLayout = requireNonNull(updateLayout, "updateLayout is null"); + this.insertPartitioningArgumentIndexes = (requireNonNull(insertPartitioningArgumentIndexes, "insertPartitioningArgumentIndexes is null")); + this.targetTableScope = requireNonNull(targetTableScope, "targetTableScope is null"); + this.joinScope = requireNonNull(joinScope, "joinScope is null"); + } + + public Table getTargetTable() + { + return targetTable; + } + + public List getDataColumnSchemas() + { + return dataColumnSchemas; + } + + public List getDataColumnHandles() + { + return dataColumnHandles; + } + + public 
List getRedistributionColumnHandles() + { + return redistributionColumnHandles; + } + + public List> getMergeCaseColumnHandles() + { + return mergeCaseColumnHandles; + } + + public Map getColumnHandleFieldNumbers() + { + return columnHandleFieldNumbers; + } + + public List getInsertPartitioningArgumentIndexes() + { + return insertPartitioningArgumentIndexes; + } + + public Optional getInsertLayout() + { + return insertLayout; + } + + public Optional getUpdateLayout() + { + return updateLayout; + } + + public Scope getJoinScope() + { + return joinScope; + } + + public Scope getTargetTableScope() + { + return targetTableScope; + } + } + public static final class AccessControlInfo { private final AccessControl accessControl; diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 730a5366634a..73b9cbe9580e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -100,6 +100,7 @@ import io.trino.sql.PlannerContext; import io.trino.sql.SqlPath; import io.trino.sql.analyzer.Analysis.GroupingSetAnalysis; +import io.trino.sql.analyzer.Analysis.MergeAnalysis; import io.trino.sql.analyzer.Analysis.ResolvedWindow; import io.trino.sql.analyzer.Analysis.SelectExpression; import io.trino.sql.analyzer.Analysis.SourceColumn; @@ -111,6 +112,7 @@ import io.trino.sql.parser.SqlParser; import io.trino.sql.planner.DeterminismEvaluator; import io.trino.sql.planner.ExpressionInterpreter; +import io.trino.sql.planner.PartitioningHandle; import io.trino.sql.planner.ScopeAware; import io.trino.sql.planner.SymbolsExtractor; import io.trino.sql.planner.TypeProvider; @@ -168,6 +170,10 @@ import io.trino.sql.tree.LongLiteral; import io.trino.sql.tree.MeasureDefinition; import io.trino.sql.tree.Merge; +import io.trino.sql.tree.MergeCase; +import 
io.trino.sql.tree.MergeDelete; +import io.trino.sql.tree.MergeInsert; +import io.trino.sql.tree.MergeUpdate; import io.trino.sql.tree.NaturalJoin; import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeRef; @@ -429,16 +435,17 @@ public Scope analyze(Node node, Optional outerQueryScope) .process(node, Optional.empty()); } - private Scope analyzeForUpdate(Table table, Optional outerQueryScope, UpdateKind updateKind) + public Scope analyzeForUpdate(Relation relation, Optional outerQueryScope, UpdateKind updateKind) { return new Visitor(outerQueryScope, warningCollector, Optional.of(updateKind)) - .process(table, Optional.empty()); + .process(relation, Optional.empty()); } private enum UpdateKind { DELETE, UPDATE, + MERGE, } /** @@ -1829,7 +1836,9 @@ protected Scope visitTable(Table table, Optional scope) ImmutableList.Builder fields = ImmutableList.builder(); fields.addAll(analyzeTableOutputFields(table, targetTableName, tableSchema, columnHandles)); - if (updateKind.isPresent()) { + boolean addRowIdColumn = updateKind.isPresent(); + + if (addRowIdColumn) { // Add the row id field ColumnHandle rowIdColumnHandle; switch (updateKind.get()) { @@ -1848,6 +1857,9 @@ protected Scope visitTable(Table table, Optional scope) .collect(toImmutableList()); rowIdColumnHandle = metadata.getUpdateRowIdColumnHandle(session, tableHandle.get(), updatedColumns); break; + case MERGE: + rowIdColumnHandle = metadata.getMergeRowIdColumnHandle(session, tableHandle.get()); + break; default: throw new UnsupportedOperationException("Unknown UpdateKind " + updateKind.get()); } @@ -1864,7 +1876,7 @@ protected Scope visitTable(Table table, Optional scope) Scope tableScope = createAndAssignScope(table, scope, outputFields); - if (updateKind.isPresent()) { + if (addRowIdColumn) { FieldReference reference = new FieldReference(outputFields.size() - 1); analyzeExpression(reference, tableScope); analysis.setRowIdField(table, reference); @@ -2682,7 +2694,7 @@ else if (node.getType() == FULL) { if 
(!clauseType.equals(UNKNOWN)) { throw semanticException(TYPE_MISMATCH, expression, "JOIN ON clause must evaluate to a boolean: actual type %s", clauseType); } - // coerce null to boolean + // coerce expression to boolean analysis.addCoercion(expression, BOOLEAN, false); } @@ -2801,7 +2813,239 @@ protected Scope visitUpdate(Update update, Optional scope) @Override protected Scope visitMerge(Merge merge, Optional scope) { - throw new TrinoException(NOT_SUPPORTED, "This connector does not support merge"); + Relation relation = merge.getTarget(); + Table table = getMergeTargetTable(relation); + QualifiedObjectName tableName = createQualifiedObjectName(session, table, table.getName()); + if (metadata.getView(session, tableName).isPresent()) { + throw semanticException(NOT_SUPPORTED, merge, "MERGE INTO a view is not supported"); + } + + TableHandle targetTableHandle = metadata.getTableHandle(session, tableName) + .orElseThrow(() -> semanticException(TABLE_NOT_FOUND, table, "Table '%s' does not exist", tableName)); + + StatementAnalyzer analyzer = statementAnalyzerFactory + .withSpecializedAccessControl(new AllowAllAccessControl()) + .createStatementAnalyzer(analysis, session, warningCollector, CorrelationSupport.ALLOWED); + + Scope targetTableScope = analyzer.analyzeForUpdate(relation, scope, UpdateKind.MERGE); + Scope sourceTableScope = process(merge.getSource(), scope); + Scope joinScope = createAndAssignScope(merge, scope, targetTableScope.getRelationType().joinWith(sourceTableScope.getRelationType())); + + TableSchema tableSchema = metadata.getTableSchema(session, targetTableHandle); + + List dataColumnSchemas = tableSchema.getColumns().stream() + .filter(column -> !column.isHidden()) + .collect(toImmutableList()); + + Optional insertLayout = metadata.getInsertLayout(session, targetTableHandle); + + Map allColumnHandles = metadata.getColumnHandles(session, targetTableHandle); + ImmutableList.Builder dataColumnHandlesBuilder = ImmutableList.builder(); + 
ImmutableSet.Builder dataColumnNamesBuilder = ImmutableSet.builder(); + ImmutableList.Builder redistributionColumnHandlesBuilder = ImmutableList.builder(); + Set partitioningColumnNames = ImmutableSet.copyOf(insertLayout.map(TableLayout::getPartitionColumns).orElse(ImmutableList.of())); + for (ColumnSchema columnSchema : dataColumnSchemas) { + String name = columnSchema.getName(); + ColumnHandle handle = allColumnHandles.get(name); + dataColumnNamesBuilder.add(name); + dataColumnHandlesBuilder.add(handle); + if (partitioningColumnNames.contains(name)) { + redistributionColumnHandlesBuilder.add(handle); + } + } + List dataColumnHandles = dataColumnHandlesBuilder.build(); + Set dataColumnNames = dataColumnNamesBuilder.build(); + List redistributionColumnHandles = redistributionColumnHandlesBuilder.build(); + + Map dataColumnTypes = dataColumnSchemas.stream().collect(toImmutableMap(ColumnSchema::getName, ColumnSchema::getType)); + + // Analyze all expressions in the Merge node + + Expression mergePredicate = merge.getPredicate(); + ExpressionAnalysis predicateAnalysis = analyzeExpression(mergePredicate, joinScope, CorrelationSupport.DISALLOWED); + Type mergePredicateType = predicateAnalysis.getType(mergePredicate); + if (!typeCoercion.canCoerce(mergePredicateType, BOOLEAN)) { + throw semanticException(TYPE_MISMATCH, mergePredicate, "The MERGE predicate must evaluate to a boolean: actual type %s", mergePredicateType); + } + if (!mergePredicateType.equals(BOOLEAN)) { + analysis.addCoercion(mergePredicate, BOOLEAN, typeCoercion.isTypeOnlyCoercion(mergePredicateType, BOOLEAN)); + } + analysis.recordSubqueries(merge, predicateAnalysis); + + ImmutableSet.Builder allUpdateColumnNamesBuilder = ImmutableSet.builder(); + + for (int caseCounter = 0; caseCounter < merge.getMergeCases().size(); caseCounter++) { + MergeCase operation = merge.getMergeCases().get(caseCounter); + List caseColumnNames = lowercaseIdentifierList(operation.getSetColumns()); + if (operation instanceof 
MergeUpdate) { + allUpdateColumnNamesBuilder.addAll(caseColumnNames); + } + else if (operation instanceof MergeInsert && caseColumnNames.isEmpty()) { + caseColumnNames = dataColumnSchemas.stream().map(ColumnSchema::getName).collect(toImmutableList()); + } + int columnCount = caseColumnNames.size(); + List setExpressions = operation.getSetExpressions(); + checkArgument(columnCount == setExpressions.size(), "Number of merge columns (%s) isn't equal to number of expressions (%s)", columnCount, setExpressions.size()); + Set columnNameSet = new HashSet<>(columnCount); + caseColumnNames.forEach(mergeColumn -> { + if (!dataColumnNames.contains(mergeColumn)) { + throw semanticException(COLUMN_NOT_FOUND, merge, "Merge column name does not exist in target table: %s", mergeColumn); + } + if (!columnNameSet.add(mergeColumn)) { + throw semanticException(DUPLICATE_COLUMN_NAME, merge, "Merge column name is specified more than once: %s", mergeColumn); + } + }); + + if (operation.getExpression().isPresent()) { + Expression predicate = operation.getExpression().get(); + analysis.recordSubqueries(merge, analyzeExpression(predicate, joinScope)); + Type predicateType = analysis.getType(predicate); + + if (!predicateType.equals(BOOLEAN)) { + if (!typeCoercion.canCoerce(predicateType, BOOLEAN)) { + throw semanticException(TYPE_MISMATCH, predicate, "WHERE clause predicate must evaluate to a boolean: actual type %s", predicateType); + } + // Coerce the predicate to boolean + analysis.addCoercion(predicate, BOOLEAN, typeCoercion.isTypeOnlyCoercion(predicateType, BOOLEAN)); + } + } + + ImmutableList.Builder setColumnTypesBuilder = ImmutableList.builder(); + ImmutableList.Builder setExpressionTypesBuilder = ImmutableList.builder(); + for (int index = 0; index < caseColumnNames.size(); index++) { + String columnName = caseColumnNames.get(index); + Expression expression = setExpressions.get(index); + ExpressionAnalysis expressionAnalysis = analyzeExpression(expression, joinScope); + 
analysis.recordSubqueries(merge, expressionAnalysis); + Type targetType = requireNonNull(dataColumnTypes.get(columnName)); + setColumnTypesBuilder.add(targetType); + setExpressionTypesBuilder.add(expressionAnalysis.getType(expression)); + } + List setColumnTypes = setColumnTypesBuilder.build(); + List setExpressionTypes = setExpressionTypesBuilder.build(); + if (!typesMatchForInsert(setColumnTypes, setExpressionTypes)) { + throw semanticException(TYPE_MISMATCH, + operation, + "MERGE table column types don't match for MERGE case %s, SET expressions: Table: [%s], Expressions: [%s]", + caseCounter, + Joiner.on(", ").join(setColumnTypes), + Joiner.on(", ").join(setExpressionTypes)); + } + for (int index = 0; index < caseColumnNames.size(); index++) { + Expression expression = operation.getSetExpressions().get(index); + Type targetType = dataColumnTypes.get(caseColumnNames.get(index)); + Type expressionType = setExpressionTypes.get(index); + if (!targetType.equals(expressionType)) { + analysis.addCoercion(expression, targetType, typeCoercion.isTypeOnlyCoercion(expressionType, targetType)); + } + } + } + + merge.getMergeCases().stream() + .filter(mergeCase -> mergeCase instanceof MergeInsert) + .findFirst() + .ifPresent(mergeCase -> accessControl.checkCanInsertIntoTable(session.toSecurityContext(), tableName)); + + merge.getMergeCases().stream() + .filter(mergeCase -> mergeCase instanceof MergeDelete) + .findFirst() + .ifPresent(mergeCase -> accessControl.checkCanDeleteFromTable(session.toSecurityContext(), tableName)); + + Set allUpdateColumnNames = allUpdateColumnNamesBuilder.build(); + if (!allUpdateColumnNames.isEmpty()) { + accessControl.checkCanUpdateTableColumns(session.toSecurityContext(), tableName, allUpdateColumnNames); + } + + if (!accessControl.getRowFilters(session.toSecurityContext(), tableName).isEmpty()) { + throw semanticException(NOT_SUPPORTED, merge, "Merge table with row filter"); + } + + for (ColumnSchema column : dataColumnSchemas) { + if 
(!accessControl.getColumnMasks(session.toSecurityContext(), tableName, column.getName(), column.getType()).isEmpty()) { + throw semanticException(NOT_SUPPORTED, merge, "Merge table with column mask"); + } + } + + List updatedColumns = allColumnHandles.keySet().stream() + .filter(allUpdateColumnNames::contains) + .map(columnHandle -> new OutputColumn(new Column(columnHandle, dataColumnTypes.get(columnHandle).toString()), ImmutableSet.of())) + .collect(toImmutableList()); + + analysis.setUpdateType("MERGE"); + analysis.setUpdateTarget(tableName, Optional.of(table), Optional.of(updatedColumns)); + List> mergeCaseColumnHandles = buildCaseColumnLists(merge, dataColumnSchemas, allColumnHandles); + + Optional updateLayout = metadata.getUpdateLayout(session, targetTableHandle); + + ImmutableMap.Builder columnHandleFieldNumbersBuilder = ImmutableMap.builder(); + Map fieldIndexes = new HashMap<>(); + RelationType relationType = targetTableScope.getRelationType(); + for (Field field : relationType.getAllFields()) { + // Only the rowId column handle will have no name, and we want to skip that column + field.getName().ifPresent(name -> { + int fieldIndex = relationType.indexOf(field); + ColumnHandle handle = allColumnHandles.get(name); + verify(handle != null, "allColumnHandles does not contain the named handle: %s", name); + columnHandleFieldNumbersBuilder.put(handle, fieldIndex); + fieldIndexes.put(name, fieldIndex); + }); + } + Map columnHandleFieldNumbers = columnHandleFieldNumbersBuilder.buildOrThrow(); + + List insertPartitioningArgumentIndexes = partitioningColumnNames.stream() + .map(fieldIndexes::get) + .collect(toImmutableList()); + + analysis.setMergeAnalysis(new MergeAnalysis( + table, + dataColumnSchemas, + dataColumnHandles, + redistributionColumnHandles, + mergeCaseColumnHandles, + columnHandleFieldNumbers, + insertPartitioningArgumentIndexes, + insertLayout, + updateLayout, + targetTableScope, + joinScope)); + + return createAndAssignScope(merge, 
Optional.empty(), Field.newUnqualified("rows", BIGINT)); + } + + private static Table getMergeTargetTable(Relation relation) + { + if (relation instanceof Table table) { + return table; + } + checkArgument(relation instanceof AliasedRelation, "relation is neither a Table nor an AliasedRelation"); + return (Table) ((AliasedRelation) relation).getRelation(); + } + + private List> buildCaseColumnLists(Merge merge, List columnSchemas, Map allColumnHandles) + { + ImmutableList.Builder> mergeCaseColumnsListsBuilder = ImmutableList.builder(); + for (int caseCounter = 0; caseCounter < merge.getMergeCases().size(); caseCounter++) { + MergeCase operation = merge.getMergeCases().get(caseCounter); + List mergeColumnNames; + if (operation instanceof MergeInsert && operation.getSetColumns().isEmpty()) { + mergeColumnNames = columnSchemas.stream().map(ColumnSchema::getName).collect(toImmutableList()); + } + else { + mergeColumnNames = lowercaseIdentifierList(operation.getSetColumns()); + } + mergeCaseColumnsListsBuilder.add( + mergeColumnNames.stream() + .map(name -> requireNonNull(allColumnHandles.get(name), "No column found for name")) + .collect(toImmutableList())); + } + return mergeCaseColumnsListsBuilder.build(); + } + + private List lowercaseIdentifierList(Collection identifiers) + { + return identifiers.stream() + .map(identifier -> identifier.getValue().toLowerCase(ENGLISH)) + .collect(toImmutableList()); } private Scope analyzeJoinUsing(Join node, List columns, Optional scope, Scope left, Scope right) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 2dfcf5a7c978..a141770ade1d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -44,6 +44,7 @@ import io.trino.index.IndexManager; import io.trino.metadata.BoundSignature; 
import io.trino.metadata.FunctionId; +import io.trino.metadata.MergeHandle; import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TableExecuteHandle; @@ -67,6 +68,8 @@ import io.trino.operator.LocalPlannerAware; import io.trino.operator.MarkDistinctOperator.MarkDistinctOperatorFactory; import io.trino.operator.MergeOperator.MergeOperatorFactory; +import io.trino.operator.MergeProcessorOperator; +import io.trino.operator.MergeWriterOperator.MergeWriterOperatorFactory; import io.trino.operator.OperatorFactories; import io.trino.operator.OperatorFactory; import io.trino.operator.OrderByOperator.OrderByOperatorFactory; @@ -197,6 +200,8 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeProcessorNode; +import io.trino.sql.planner.plan.MergeWriterNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PatternRecognitionNode; import io.trino.sql.planner.plan.PatternRecognitionNode.Measure; @@ -221,6 +226,7 @@ import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TableWriterNode.DeleteTarget; +import io.trino.sql.planner.plan.TableWriterNode.MergeTarget; import io.trino.sql.planner.plan.TableWriterNode.TableExecuteTarget; import io.trino.sql.planner.plan.TableWriterNode.UpdateTarget; import io.trino.sql.planner.plan.TopNNode; @@ -3414,6 +3420,44 @@ public PhysicalOperation visitTableExecute(TableExecuteNode node, LocalExecution return new PhysicalOperation(operatorFactory, outputMapping.buildOrThrow(), context, source); } + @Override + public PhysicalOperation visitMergeWriter(MergeWriterNode node, LocalExecutionPlanContext context) + { + context.setDriverInstanceCount(getTaskWriterCount(session)); + + PhysicalOperation source = node.getSource().accept(this, context); + OperatorFactory 
operatorFactory = new MergeWriterOperatorFactory(context.getNextOperatorId(), node.getId(), pageSinkManager, node.getTarget(), session); + return new PhysicalOperation(operatorFactory, makeLayout(node), context, source); + } + + @Override + public PhysicalOperation visitMergeProcessor(MergeProcessorNode node, LocalExecutionPlanContext context) + { + PhysicalOperation source = node.getSource().accept(this, context); + + Map nodeLayout = makeLayout(node); + Map sourceLayout = makeLayout(node.getSource()); + int rowIdChannel = sourceLayout.get(node.getRowIdSymbol()); + int mergeRowChannel = sourceLayout.get(node.getMergeRowSymbol()); + + List redistributionColumns = node.getRedistributionColumnSymbols().stream() + .map(nodeLayout::get) + .collect(toImmutableList()); + List dataColumnChannels = node.getDataColumnSymbols().stream() + .map(nodeLayout::get) + .collect(toImmutableList()); + + OperatorFactory operatorFactory = MergeProcessorOperator.createOperatorFactory( + context.getNextOperatorId(), + node.getId(), + node.getTarget().getMergeParadigmAndTypes(), + rowIdChannel, + mergeRowChannel, + redistributionColumns, + dataColumnChannels); + return new PhysicalOperation(operatorFactory, nodeLayout, context, source); + } + @Override public PhysicalOperation visitTableDelete(TableDeleteNode node, LocalExecutionPlanContext context) { @@ -4052,6 +4096,11 @@ else if (target instanceof TableExecuteTarget) { metadata.finishTableExecute(session, tableExecuteHandle, fragments, tableExecuteContext.getSplitsInfo()); return Optional.empty(); } + else if (target instanceof MergeTarget mergeTarget) { + MergeHandle mergeHandle = mergeTarget.getMergeHandle().orElseThrow(() -> new IllegalArgumentException("mergeHandle not present")); + metadata.finishMerge(session, mergeHandle, fragments, statistics); + return Optional.empty(); + } else { throw new AssertionError("Unhandled target type: " + target.getClass().getName()); } diff --git 
a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java index f4514e569f17..e0e14f4dbe2c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java @@ -59,6 +59,7 @@ import io.trino.sql.planner.plan.ExplainAnalyzeNode; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.LimitNode; +import io.trino.sql.planner.plan.MergeWriterNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.ProjectNode; @@ -87,6 +88,7 @@ import io.trino.sql.tree.IfExpression; import io.trino.sql.tree.Insert; import io.trino.sql.tree.LambdaArgumentDeclaration; +import io.trino.sql.tree.Merge; import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.QualifiedName; @@ -311,6 +313,9 @@ private RelationPlan planStatementWithoutOutput(Analysis analysis, Statement sta if (statement instanceof Update) { return createUpdatePlan(analysis, (Update) statement); } + if (statement instanceof Merge) { + return createMergePlan(analysis, (Merge) statement); + } if (statement instanceof Query) { return createRelationPlan(analysis, (Query) statement); } @@ -756,6 +761,22 @@ private RelationPlan createUpdatePlan(Analysis analysis, Update node) return new RelationPlan(commitNode, analysis.getScope(node), commitNode.getOutputSymbols(), Optional.empty()); } + private RelationPlan createMergePlan(Analysis analysis, Merge node) + { + MergeWriterNode mergeNode = new QueryPlanner(analysis, symbolAllocator, idAllocator, buildLambdaDeclarationToSymbolMap(analysis, symbolAllocator), plannerContext, Optional.empty(), session, ImmutableMap.of()) + .plan(node); + + TableFinishNode commitNode = new TableFinishNode( + idAllocator.getNextId(), + mergeNode, + mergeNode.getTarget(), + 
symbolAllocator.newSymbol("rows", BIGINT), + Optional.empty(), + Optional.empty()); + + return new RelationPlan(commitNode, analysis.getScope(node), commitNode.getOutputSymbols(), Optional.empty()); + } + private PlanNode createOutputPlan(RelationPlan plan, Analysis analysis) { ImmutableList.Builder outputs = ImmutableList.builder(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java b/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java new file mode 100644 index 000000000000..bd4e984bb96d --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java @@ -0,0 +1,186 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.planner; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonGetter; +import com.google.common.base.VerifyException; +import io.trino.operator.BucketPartitionFunction; +import io.trino.operator.PartitionFunction; +import io.trino.spi.Page; +import io.trino.spi.TrinoException; +import io.trino.spi.block.Block; +import io.trino.spi.connector.ConnectorPartitioningHandle; +import io.trino.spi.type.Type; +import io.trino.sql.planner.SystemPartitioningHandle.SystemPartitionFunction.RoundRobinBucketFunction; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.function.Function; +import java.util.stream.IntStream; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.Iterables.getLast; +import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; +import static io.trino.spi.connector.ConnectorMergeSink.DELETE_OPERATION_NUMBER; +import static io.trino.spi.connector.ConnectorMergeSink.INSERT_OPERATION_NUMBER; +import static io.trino.spi.connector.ConnectorMergeSink.UPDATE_OPERATION_NUMBER; +import static io.trino.spi.type.TinyintType.TINYINT; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; + +public final class MergePartitioningHandle + implements ConnectorPartitioningHandle +{ + private final Optional insertPartitioning; + private final Optional updatePartitioning; + + @JsonCreator + public MergePartitioningHandle(Optional insertPartitioning, Optional updatePartitioning) + { + this.insertPartitioning = requireNonNull(insertPartitioning, "insertPartitioning is null"); + this.updatePartitioning = requireNonNull(updatePartitioning, "updatePartitioning is null"); + checkArgument(insertPartitioning.isPresent() || updatePartitioning.isPresent(), "insert or update partitioning must be present"); + } + + @JsonGetter + public Optional 
getInsertPartitioning() + { + return insertPartitioning; + } + + @JsonGetter + public Optional getUpdatePartitioning() + { + return updatePartitioning; + } + + @Override + public String toString() + { + List parts = new ArrayList<>(); + insertPartitioning.ifPresent(scheme -> parts.add("insert = " + scheme.getPartitioning().getHandle())); + updatePartitioning.ifPresent(scheme -> parts.add("update = " + scheme.getPartitioning().getHandle())); + return "MERGE " + parts; + } + + public NodePartitionMap getNodePartitioningMap(Function getMap) + { + Optional optionalInsertMap = insertPartitioning.map(scheme -> scheme.getPartitioning().getHandle()).map(getMap); + Optional optionalUpdateMap = updatePartitioning.map(scheme -> scheme.getPartitioning().getHandle()).map(getMap); + + if (optionalInsertMap.isPresent() && optionalUpdateMap.isPresent()) { + NodePartitionMap insertMap = optionalInsertMap.get(); + NodePartitionMap updateMap = optionalUpdateMap.get(); + if (!insertMap.getPartitionToNode().equals(updateMap.getPartitionToNode()) || + !Arrays.equals(insertMap.getBucketToPartition(), updateMap.getBucketToPartition())) { + throw new TrinoException(NOT_SUPPORTED, "Insert and update layout have mismatched BucketNodeMap"); + } + } + + return optionalInsertMap.orElseGet(optionalUpdateMap::orElseThrow); + } + + public PartitionFunction getPartitionFunction(PartitionFunctionLookup partitionFunctionLookup, List types, int[] bucketToPartition) + { + // channels: merge row, insert arguments, update row ID + List insertTypes = types.subList(1, types.size() - (updatePartitioning.isPresent() ? 
1 : 0)); + + Optional insertFunction = insertPartitioning.map(scheme -> + partitionFunctionLookup.get(scheme, insertTypes)); + + Optional updateFunction = updatePartitioning.map(scheme -> + partitionFunctionLookup.get(scheme, List.of(getLast(types)))); + + return getPartitionFunction(insertFunction, updateFunction, insertTypes.size(), bucketToPartition); + } + + private static PartitionFunction getPartitionFunction(Optional insertFunction, Optional updateFunction, int insertArguments, int[] bucketToPartition) + { + if (insertFunction.isPresent() && updateFunction.isPresent()) { + return new MergePartitionFunction( + insertFunction.get(), + updateFunction.get(), + IntStream.range(1, insertArguments + 1).toArray(), + new int[] {insertArguments + 1}); + } + + PartitionFunction roundRobinFunction = new BucketPartitionFunction(new RoundRobinBucketFunction(bucketToPartition.length), bucketToPartition); + + if (insertFunction.isPresent()) { + return new MergePartitionFunction( + insertFunction.get(), + roundRobinFunction, + IntStream.range(1, insertArguments + 1).toArray(), + new int[] {}); + } + + if (updateFunction.isPresent()) { + return new MergePartitionFunction( + roundRobinFunction, + updateFunction.get(), + new int[] {}, + new int[] {insertArguments + 1}); + } + + throw new AssertionError(); + } + + public interface PartitionFunctionLookup + { + PartitionFunction get(PartitioningScheme scheme, List partitionChannelTypes); + } + + private static final class MergePartitionFunction + implements PartitionFunction + { + private final PartitionFunction insertFunction; + private final PartitionFunction updateFunction; + private final int[] insertColumns; + private final int[] updateColumns; + + public MergePartitionFunction(PartitionFunction insertFunction, PartitionFunction updateFunction, int[] insertColumns, int[] updateColumns) + { + this.insertFunction = requireNonNull(insertFunction, "insertFunction is null"); + this.updateFunction = requireNonNull(updateFunction, 
"updateFunction is null"); + this.insertColumns = requireNonNull(insertColumns, "insertColumns is null"); + this.updateColumns = requireNonNull(updateColumns, "updateColumns is null"); + checkArgument(insertFunction.getPartitionCount() == updateFunction.getPartitionCount(), "partition counts must match"); + } + + @Override + public int getPartitionCount() + { + return insertFunction.getPartitionCount(); + } + + @Override + public int getPartition(Page page, int position) + { + Block operationBlock = page.getBlock(0); + int operation = toIntExact(TINYINT.getLong(operationBlock, position)); + switch (operation) { + case INSERT_OPERATION_NUMBER: + return insertFunction.getPartition(page.getColumns(insertColumns), position); + case UPDATE_OPERATION_NUMBER: + case DELETE_OPERATION_NUMBER: + return updateFunction.getPartition(page.getColumns(updateColumns), position); + default: + throw new VerifyException("Invalid merge operation number: " + operation); + } + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java b/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java index d4cf282626a8..5bce806abe20 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java @@ -15,12 +15,14 @@ import com.google.common.collect.BiMap; import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; import io.trino.Session; import io.trino.connector.CatalogHandle; import io.trino.connector.CatalogServiceProvider; import io.trino.execution.scheduler.BucketNodeMap; import io.trino.execution.scheduler.FixedBucketNodeMap; import io.trino.execution.scheduler.NodeScheduler; +import io.trino.execution.scheduler.NodeSelector; import io.trino.execution.scheduler.group.DynamicBucketNodeMap; import io.trino.metadata.InternalNode; import io.trino.metadata.Split; @@ -32,6 +34,7 @@ 
import io.trino.spi.connector.ConnectorSplit; import io.trino.spi.type.Type; import io.trino.split.EmptySplit; +import io.trino.sql.planner.SystemPartitioningHandle.SystemPartitioning; import io.trino.type.BlockTypeOperators; import javax.inject.Inject; @@ -39,14 +42,21 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.ToIntFunction; import java.util.stream.IntStream; import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.SystemSessionProperties.getHashPartitionCount; +import static io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; +import static io.trino.util.Failures.checkCondition; import static java.util.Objects.requireNonNull; public class NodePartitioningManager @@ -83,17 +93,39 @@ public PartitionFunction getPartitionFunction( blockTypeOperators); } + if (partitioningHandle.getConnectorHandle() instanceof MergePartitioningHandle handle) { + return handle.getPartitionFunction( + (scheme, types) -> getPartitionFunction(session, scheme, types, bucketToPartition), + partitionChannelTypes, + bucketToPartition); + } + + return getPartitionFunction(session, partitioningScheme, partitionChannelTypes, bucketToPartition); + } + + public PartitionFunction getPartitionFunction(Session session, PartitioningScheme partitioningScheme, List partitionChannelTypes, int[] bucketToPartition) + { + PartitioningHandle partitioningHandle = partitioningScheme.getPartitioning().getHandle(); + + if (partitioningHandle.getConnectorHandle() instanceof SystemPartitioningHandle handle) { + return handle.getPartitionFunction( + partitionChannelTypes, + 
partitioningScheme.getHashColumn().isPresent(), + bucketToPartition, + blockTypeOperators); + } + BucketFunction bucketFunction = getBucketFunction(session, partitioningHandle, partitionChannelTypes, bucketToPartition.length); return new BucketPartitionFunction(bucketFunction, bucketToPartition); } public BucketFunction getBucketFunction(Session session, PartitioningHandle partitioningHandle, List partitionChannelTypes, int bucketCount) { - CatalogHandle catalogHandle = partitioningHandle.getCatalogHandle() - .orElseThrow(() -> new IllegalArgumentException("No catalog handle for partitioning handle: " + partitioningHandle)); + CatalogHandle catalogHandle = requiredCatalogHandle(partitioningHandle); ConnectorNodePartitioningProvider partitioningProvider = getPartitioningProvider(catalogHandle); + BucketFunction bucketFunction = partitioningProvider.getBucketFunction( - partitioningHandle.getTransactionHandle().orElseThrow(() -> new IllegalArgumentException("No transactionHandle for partitioning handle: " + partitioningHandle)), + partitioningHandle.getTransactionHandle().orElseThrow(), session.toConnectorSession(), partitioningHandle.getConnectorHandle(), partitionChannelTypes, @@ -103,15 +135,37 @@ public BucketFunction getBucketFunction(Session session, PartitioningHandle part } public NodePartitionMap getNodePartitioningMap(Session session, PartitioningHandle partitioningHandle) + { + return getNodePartitioningMap(session, partitioningHandle, new HashMap<>(), new AtomicReference<>()); + } + + /** + * This method is recursive for MergePartitioningHandle. It caches the node mappings + * to ensure that both the insert and update layouts use the same mapping. 
+ */ + private NodePartitionMap getNodePartitioningMap( + Session session, + PartitioningHandle partitioningHandle, + Map> bucketToNodeCache, + AtomicReference> systemPartitioningCache) { requireNonNull(session, "session is null"); requireNonNull(partitioningHandle, "partitioningHandle is null"); if (partitioningHandle.getConnectorHandle() instanceof SystemPartitioningHandle) { - return ((SystemPartitioningHandle) partitioningHandle.getConnectorHandle()).getNodePartitionMap(session, nodeScheduler); + return systemNodePartitionMap(session, partitioningHandle, systemPartitioningCache); + } + + if (partitioningHandle.getConnectorHandle() instanceof MergePartitioningHandle mergeHandle) { + return mergeHandle.getNodePartitioningMap(handle -> getNodePartitioningMap(session, handle, bucketToNodeCache, systemPartitioningCache)); } - ConnectorBucketNodeMap connectorBucketNodeMap = getConnectorBucketNodeMap(session, partitioningHandle); + Optional optionalMap = getConnectorBucketNodeMap(session, partitioningHandle); + if (optionalMap.isEmpty()) { + return systemNodePartitionMap(session, FIXED_HASH_DISTRIBUTION, systemPartitioningCache); + } + ConnectorBucketNodeMap connectorBucketNodeMap = optionalMap.get(); + // safety check for crazy partitioning checkArgument(connectorBucketNodeMap.getBucketCount() < 1_000_000, "Too many buckets in partitioning: %s", connectorBucketNodeMap.getBucketCount()); @@ -120,11 +174,10 @@ public NodePartitionMap getNodePartitioningMap(Session session, PartitioningHand bucketToNode = getFixedMapping(connectorBucketNodeMap); } else { - CatalogHandle catalogHandle = partitioningHandle.getCatalogHandle() - .orElseThrow(() -> new IllegalArgumentException("No catalog handle for partitioning handle: " + partitioningHandle)); - bucketToNode = createArbitraryBucketToNode( - nodeScheduler.createNodeSelector(session, Optional.of(catalogHandle)).allNodes(), - connectorBucketNodeMap.getBucketCount()); + CatalogHandle catalogHandle = 
requiredCatalogHandle(partitioningHandle); + bucketToNode = bucketToNodeCache.computeIfAbsent( + connectorBucketNodeMap.getBucketCount(), + bucketCount -> createArbitraryBucketToNode(getAllNodes(session, catalogHandle), bucketCount)); } int[] bucketToPartition = new int[connectorBucketNodeMap.getBucketCount()]; @@ -148,26 +201,60 @@ public NodePartitionMap getNodePartitioningMap(Session session, PartitioningHand return new NodePartitionMap(partitionToNode, bucketToPartition, getSplitToBucket(session, partitioningHandle)); } + private NodePartitionMap systemNodePartitionMap(Session session, PartitioningHandle partitioningHandle, AtomicReference> nodesCache) + { + SystemPartitioning partitioning = ((SystemPartitioningHandle) partitioningHandle.getConnectorHandle()).getPartitioning(); + + NodeSelector nodeSelector = nodeScheduler.createNodeSelector(session, Optional.empty()); + + List nodes = switch (partitioning) { + case COORDINATOR_ONLY -> ImmutableList.of(nodeSelector.selectCurrentNode()); + case SINGLE -> nodeSelector.selectRandomNodes(1); + case FIXED -> { + List value = nodesCache.get(); + if (value == null) { + value = nodeSelector.selectRandomNodes(getHashPartitionCount(session)); + nodesCache.set(value); + } + yield value; + } + default -> throw new IllegalArgumentException("Unsupported plan distribution " + partitioning); + }; + checkCondition(!nodes.isEmpty(), NO_NODES_AVAILABLE, "No worker nodes available"); + + return new NodePartitionMap(nodes, split -> { + throw new UnsupportedOperationException("System distribution does not support source splits"); + }); + } + public BucketNodeMap getBucketNodeMap(Session session, PartitioningHandle partitioningHandle, boolean preferDynamic) { - ConnectorBucketNodeMap connectorBucketNodeMap = getConnectorBucketNodeMap(session, partitioningHandle); + Optional bucketNodeMap = getConnectorBucketNodeMap(session, partitioningHandle); ToIntFunction splitToBucket = getSplitToBucket(session, partitioningHandle); - if 
(connectorBucketNodeMap.hasFixedMapping()) { - return new FixedBucketNodeMap(splitToBucket, getFixedMapping(connectorBucketNodeMap)); + if (bucketNodeMap.map(ConnectorBucketNodeMap::hasFixedMapping).orElse(false)) { + return new FixedBucketNodeMap(splitToBucket, getFixedMapping(bucketNodeMap.get())); } if (preferDynamic) { - return new DynamicBucketNodeMap(splitToBucket, connectorBucketNodeMap.getBucketCount()); + int bucketCount = bucketNodeMap.map(ConnectorBucketNodeMap::getBucketCount) + .orElseGet(() -> getNodeCount(session, partitioningHandle)); + return new DynamicBucketNodeMap(splitToBucket, bucketCount); } - Optional catalogName = partitioningHandle.getCatalogHandle(); - checkArgument(catalogName.isPresent(), "No catalog handle for partitioning handle: %s", partitioningHandle); - return new FixedBucketNodeMap( - splitToBucket, - createArbitraryBucketToNode( - new ArrayList<>(nodeScheduler.createNodeSelector(session, catalogName).allNodes()), - connectorBucketNodeMap.getBucketCount())); + List nodes = getAllNodes(session, requiredCatalogHandle(partitioningHandle)); + int bucketCount = bucketNodeMap.map(ConnectorBucketNodeMap::getBucketCount).orElseGet(nodes::size); + return new FixedBucketNodeMap(splitToBucket, createArbitraryBucketToNode(nodes, bucketCount)); + } + + public int getNodeCount(Session session, PartitioningHandle partitioningHandle) + { + return getAllNodes(session, requiredCatalogHandle(partitioningHandle)).size(); + } + + private List getAllNodes(Session session, CatalogHandle catalogHandle) + { + return nodeScheduler.createNodeSelector(session, Optional.of(catalogHandle)).allNodes(); } private static List getFixedMapping(ConnectorBucketNodeMap connectorBucketNodeMap) @@ -177,27 +264,24 @@ private static List getFixedMapping(ConnectorBucketNodeMap connect .collect(toImmutableList()); } - public ConnectorBucketNodeMap getConnectorBucketNodeMap(Session session, PartitioningHandle partitioningHandle) + public Optional 
getConnectorBucketNodeMap(Session session, PartitioningHandle partitioningHandle) { - CatalogHandle catalogHandle = partitioningHandle.getCatalogHandle() - .orElseThrow(() -> new IllegalArgumentException("No catalog handle for partitioning handle: " + partitioningHandle)); + CatalogHandle catalogHandle = requiredCatalogHandle(partitioningHandle); ConnectorNodePartitioningProvider partitioningProvider = getPartitioningProvider(catalogHandle); - ConnectorBucketNodeMap connectorBucketNodeMap = partitioningProvider.getBucketNodeMap( - partitioningHandle.getTransactionHandle().orElseThrow(() -> new IllegalArgumentException("No transactionHandle for partitioning handle: " + partitioningHandle)), + + return partitioningProvider.getBucketNodeMapping( + partitioningHandle.getTransactionHandle().orElseThrow(), session.toConnectorSession(catalogHandle), partitioningHandle.getConnectorHandle()); - checkArgument(connectorBucketNodeMap != null, "No partition map %s", partitioningHandle); - return connectorBucketNodeMap; } private ToIntFunction getSplitToBucket(Session session, PartitioningHandle partitioningHandle) { - CatalogHandle catalogHandle = partitioningHandle.getCatalogHandle() - .orElseThrow(() -> new IllegalArgumentException("No catalog handle for partitioning handle: " + partitioningHandle)); + CatalogHandle catalogHandle = requiredCatalogHandle(partitioningHandle); ConnectorNodePartitioningProvider partitioningProvider = getPartitioningProvider(catalogHandle); ToIntFunction splitBucketFunction = partitioningProvider.getSplitBucketFunction( - partitioningHandle.getTransactionHandle().orElseThrow(() -> new IllegalArgumentException("No transactionHandle for partitioning handle: " + partitioningHandle)), + partitioningHandle.getTransactionHandle().orElseThrow(), session.toConnectorSession(catalogHandle), partitioningHandle.getConnectorHandle()); checkArgument(splitBucketFunction != null, "No partitioning %s", partitioningHandle); @@ -219,6 +303,12 @@ private 
ConnectorNodePartitioningProvider getPartitioningProvider(CatalogHandle return partitioningProvider.getService(requireNonNull(catalogHandle, "catalogHandle is null")); } + private static CatalogHandle requiredCatalogHandle(PartitioningHandle partitioningHandle) + { + return partitioningHandle.getCatalogHandle().orElseThrow(() -> + new IllegalStateException("No catalog handle for partitioning handle: " + partitioningHandle)); + } + private static List createArbitraryBucketToNode(List nodes, int bucketCount) { return cyclingShuffledStream(nodes) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java index ef9592b1f227..38ddb77e1419 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java @@ -31,6 +31,8 @@ import io.trino.spi.type.Type; import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.planner.plan.ExplainAnalyzeNode; +import io.trino.sql.planner.plan.MergeProcessorNode; +import io.trino.sql.planner.plan.MergeWriterNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PlanFragmentId; import io.trino.sql.planner.plan.PlanNode; @@ -309,6 +311,21 @@ public PlanNode visitTableWriter(TableWriterNode node, RewriteContext context) + { + if (node.getPartitioningScheme().isPresent()) { + context.get().setDistribution(node.getPartitioningScheme().get().getPartitioning().getHandle(), metadata, session); + } + return context.defaultRewrite(node, context.get()); + } + + @Override + public PlanNode visitMergeProcessor(MergeProcessorNode node, RewriteContext context) + { + return context.defaultRewrite(node, context.get()); + } + @Override public PlanNode visitValues(ValuesNode node, RewriteContext context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java 
b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index 51a76d8e3377..374cf1f6c216 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -113,6 +113,7 @@ import io.trino.sql.planner.iterative.rule.PruneJoinColumns; import io.trino.sql.planner.iterative.rule.PruneLimitColumns; import io.trino.sql.planner.iterative.rule.PruneMarkDistinctColumns; +import io.trino.sql.planner.iterative.rule.PruneMergeSourceColumns; import io.trino.sql.planner.iterative.rule.PruneOffsetColumns; import io.trino.sql.planner.iterative.rule.PruneOrderByInAggregation; import io.trino.sql.planner.iterative.rule.PruneOutputSourceColumns; @@ -1022,6 +1023,7 @@ public static Set> columnPruningRules(Metadata metadata) new PruneJoinColumns(), new PruneLimitColumns(), new PruneMarkDistinctColumns(), + new PruneMergeSourceColumns(), new PruneOffsetColumns(), new PruneOutputSourceColumns(), new PrunePattenRecognitionColumns(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index a3658de0cf41..9537863ffac7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -20,18 +20,26 @@ import com.google.common.collect.Iterables; import com.google.common.collect.Sets; import io.trino.Session; +import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TableHandle; +import io.trino.metadata.TableLayout; +import io.trino.metadata.TableMetadata; import io.trino.metadata.TableSchema; import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.connector.ColumnSchema; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SortOrder; import 
io.trino.spi.type.DecimalType; +import io.trino.spi.type.RowType; import io.trino.spi.type.Type; +import io.trino.sql.ExpressionUtils; import io.trino.sql.NodeUtils; import io.trino.sql.PlannerContext; import io.trino.sql.analyzer.Analysis; import io.trino.sql.analyzer.Analysis.GroupingSetAnalysis; +import io.trino.sql.analyzer.Analysis.MergeAnalysis; import io.trino.sql.analyzer.Analysis.ResolvedWindow; import io.trino.sql.analyzer.Analysis.SelectExpression; import io.trino.sql.analyzer.FieldId; @@ -39,11 +47,15 @@ import io.trino.sql.planner.RelationPlanner.PatternRecognitionComponents; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; +import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.GroupIdNode; import io.trino.sql.planner.plan.LimitNode; +import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeProcessorNode; +import io.trino.sql.planner.plan.MergeWriterNode; import io.trino.sql.planner.plan.OffsetNode; import io.trino.sql.planner.plan.PatternRecognitionNode; import io.trino.sql.planner.plan.PlanNode; @@ -53,6 +65,8 @@ import io.trino.sql.planner.plan.SortNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode.DeleteTarget; +import io.trino.sql.planner.plan.TableWriterNode.MergeParadigmAndTypes; +import io.trino.sql.planner.plan.TableWriterNode.MergeTarget; import io.trino.sql.planner.plan.TableWriterNode.UpdateTarget; import io.trino.sql.planner.plan.UnionNode; import io.trino.sql.planner.plan.UpdateNode; @@ -64,18 +78,30 @@ import io.trino.sql.tree.Delete; import io.trino.sql.tree.Expression; import io.trino.sql.tree.FetchFirst; +import io.trino.sql.tree.FieldReference; import io.trino.sql.tree.FrameBound; import io.trino.sql.tree.FunctionCall; import 
io.trino.sql.tree.FunctionCall.NullTreatment; import io.trino.sql.tree.GenericLiteral; import io.trino.sql.tree.IfExpression; import io.trino.sql.tree.IntervalLiteral; +import io.trino.sql.tree.IsNotNullPredicate; +import io.trino.sql.tree.IsNullPredicate; +import io.trino.sql.tree.Join; import io.trino.sql.tree.LambdaArgumentDeclaration; import io.trino.sql.tree.LambdaExpression; +import io.trino.sql.tree.LogicalExpression; import io.trino.sql.tree.LongLiteral; import io.trino.sql.tree.MeasureDefinition; +import io.trino.sql.tree.Merge; +import io.trino.sql.tree.MergeCase; +import io.trino.sql.tree.MergeDelete; +import io.trino.sql.tree.MergeInsert; +import io.trino.sql.tree.MergeUpdate; import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeRef; +import io.trino.sql.tree.NotExpression; +import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.Offset; import io.trino.sql.tree.OrderBy; import io.trino.sql.tree.PatternRecognitionRelation.RowsPerMatch; @@ -83,11 +109,15 @@ import io.trino.sql.tree.Query; import io.trino.sql.tree.QuerySpecification; import io.trino.sql.tree.Relation; +import io.trino.sql.tree.Row; +import io.trino.sql.tree.SearchedCaseExpression; import io.trino.sql.tree.SortItem; +import io.trino.sql.tree.SubscriptExpression; import io.trino.sql.tree.Table; import io.trino.sql.tree.Union; import io.trino.sql.tree.Update; import io.trino.sql.tree.VariableDefinition; +import io.trino.sql.tree.WhenClause; import io.trino.sql.tree.WindowFrame; import io.trino.sql.tree.WindowOperation; import io.trino.type.TypeCoercion; @@ -115,9 +145,15 @@ import static io.trino.SystemSessionProperties.getMaxRecursionDepth; import static io.trino.SystemSessionProperties.isSkipRedundantSort; import static io.trino.spi.StandardErrorCode.INVALID_WINDOW_FRAME; +import static io.trino.spi.StandardErrorCode.MERGE_TARGET_ROW_MULTIPLE_MATCHES; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; +import static 
io.trino.spi.connector.ConnectorMergeSink.DELETE_OPERATION_NUMBER; +import static io.trino.spi.connector.ConnectorMergeSink.INSERT_OPERATION_NUMBER; +import static io.trino.spi.connector.ConnectorMergeSink.UPDATE_OPERATION_NUMBER; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.sql.NodeUtils.getSortItemsFromOrderBy; import static io.trino.sql.analyzer.ExpressionAnalyzer.isNumericType; @@ -127,6 +163,7 @@ import static io.trino.sql.planner.OrderingScheme.sortItemToSortOrder; import static io.trino.sql.planner.PlanBuilder.newPlanBuilder; import static io.trino.sql.planner.ScopeAware.scopeAwareKey; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static io.trino.sql.planner.plan.AggregationNode.groupingSets; import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; @@ -573,6 +610,299 @@ public UpdateNode plan(Update node) outputs); } + public MergeWriterNode plan(Merge merge) + { + MergeAnalysis mergeAnalysis = analysis.getMergeAnalysis().orElseThrow(() -> new IllegalArgumentException("analysis.getMergeAnalysis() isn't present")); + + List> mergeCaseColumnsHandles = mergeAnalysis.getMergeCaseColumnHandles(); + + // Make the plan for the merge target table scan + RelationPlan targetTablePlan = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, plannerContext, outerContext, session, recursiveSubqueries) + .process(merge.getTarget()); + + // Assign a unique id to every target table row + Symbol uniqueIdSymbol = symbolAllocator.newSymbol("unique_id", BIGINT); + RelationPlan planWithUniqueId = new RelationPlan( + new AssignUniqueId(idAllocator.getNextId(), 
targetTablePlan.getRoot(), uniqueIdSymbol), + mergeAnalysis.getTargetTableScope(), + targetTablePlan.getFieldMappings(), + outerContext); + + // Project the "present" column + Assignments.Builder projections = Assignments.builder(); + projections.putIdentities(planWithUniqueId.getRoot().getOutputSymbols()); + + Symbol presentColumn = symbolAllocator.newSymbol("present", BOOLEAN); + projections.put(presentColumn, TRUE_LITERAL); + + RelationPlan planWithPresentColumn = new RelationPlan( + new ProjectNode(idAllocator.getNextId(), planWithUniqueId.getRoot(), projections.build()), + mergeAnalysis.getTargetTableScope(), + planWithUniqueId.getFieldMappings(), + outerContext); + + RelationPlan source = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, plannerContext, outerContext, session, recursiveSubqueries) + .process(merge.getSource()); + + RelationPlan joinPlan = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, plannerContext, outerContext, session, recursiveSubqueries) + .planJoin(coerceIfNecessary(analysis, merge.getPredicate(), merge.getPredicate()), Join.Type.RIGHT, mergeAnalysis.getJoinScope(), planWithPresentColumn, source, analysis.getSubqueries(merge)); + + PlanBuilder subPlan = newPlanBuilder(joinPlan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext); + + // Build the SearchedCaseExpression that creates the project merge_row + + ImmutableList.Builder whenClauses = ImmutableList.builder(); + for (int caseNumber = 0; caseNumber < merge.getMergeCases().size(); caseNumber++) { + MergeCase mergeCase = merge.getMergeCases().get(caseNumber); + + Optional casePredicate = Optional.empty(); + if (mergeCase.getExpression().isPresent()) { + Expression original = mergeCase.getExpression().get(); + Expression predicate = coerceIfNecessary(analysis, original, original); + casePredicate = Optional.of(predicate); + subPlan = subqueryPlanner.handleSubqueries(subPlan, predicate, 
analysis.getSubqueries(merge)); + } + + ImmutableList.Builder rowBuilder = ImmutableList.builder(); + List mergeCaseSetColumns = mergeCaseColumnsHandles.get(caseNumber); + for (ColumnHandle dataColumnHandle : mergeAnalysis.getDataColumnHandles()) { + int index = mergeCaseSetColumns.indexOf(dataColumnHandle); + if (index >= 0) { + Expression original = mergeCase.getSetExpressions().get(index); + Expression setExpression = coerceIfNecessary(analysis, original, original); + subPlan = subqueryPlanner.handleSubqueries(subPlan, setExpression, analysis.getSubqueries(merge)); + rowBuilder.add(subPlan.rewrite(setExpression)); + } + else { + Integer fieldNumber = requireNonNull(mergeAnalysis.getColumnHandleFieldNumbers().get(dataColumnHandle), "Field number for ColumnHandle is null"); + rowBuilder.add(planWithPresentColumn.getFieldMappings().get(fieldNumber).toSymbolReference()); + } + } + + // Build the match condition for the MERGE case + + // Add a boolean column which is true if a target table row was matched + rowBuilder.add(new IsNotNullPredicate(presentColumn.toSymbolReference())); + + // Add the operation number + rowBuilder.add(new GenericLiteral("TINYINT", String.valueOf(getMergeCaseOperationNumber(mergeCase)))); + + // Add the merge case number, needed by MarkDistinct + rowBuilder.add(new GenericLiteral("INTEGER", String.valueOf(caseNumber))); + + Optional rewritten = casePredicate.map(subPlan::rewrite); + Expression condition = presentColumn.toSymbolReference(); + if (mergeCase instanceof MergeInsert) { + condition = new IsNullPredicate(presentColumn.toSymbolReference()); + } + + if (rewritten.isPresent()) { + condition = ExpressionUtils.and(condition, rewritten.get()); + } + + whenClauses.add(new WhenClause(condition, new Row(rowBuilder.build()))); + } + + // Build the "else" clause for the SearchedCaseExpression + ImmutableList.Builder rowBuilder = ImmutableList.builder(); + mergeAnalysis.getDataColumnSchemas().forEach(columnSchema -> + rowBuilder.add(new 
Cast(new NullLiteral(), toSqlType(columnSchema.getType())))); + rowBuilder.add(new IsNotNullPredicate(presentColumn.toSymbolReference())); + // The operation number + rowBuilder.add(new GenericLiteral("TINYINT", "-1")); + // The case number + rowBuilder.add(new GenericLiteral("INTEGER", "-1")); + + SearchedCaseExpression caseExpression = new SearchedCaseExpression(whenClauses.build(), Optional.of(new Row(rowBuilder.build()))); + RowType rowType = createMergeRowType(mergeAnalysis.getDataColumnSchemas()); + + FieldReference rowIdReference = analysis.getRowIdField(mergeAnalysis.getTargetTable()); + Symbol rowIdSymbol = planWithPresentColumn.getFieldMappings().get(rowIdReference.getFieldIndex()); + Symbol mergeRowSymbol = symbolAllocator.newSymbol("merge_row", rowType); + Symbol caseNumberSymbol = symbolAllocator.newSymbol("case_number", INTEGER); + + // Project the partition symbols, the merge_row, the rowId, and the unique_id symbol + Assignments.Builder projectionAssignmentsBuilder = Assignments.builder(); + for (ColumnHandle column : mergeAnalysis.getRedistributionColumnHandles()) { + int fieldIndex = requireNonNull(mergeAnalysis.getColumnHandleFieldNumbers().get(column), "Could not find fieldIndex for redistribution column"); + Symbol symbol = planWithPresentColumn.getFieldMappings().get(fieldIndex); + projectionAssignmentsBuilder.put(symbol, symbol.toSymbolReference()); + } + projectionAssignmentsBuilder.put(uniqueIdSymbol, uniqueIdSymbol.toSymbolReference()); + projectionAssignmentsBuilder.put(rowIdSymbol, rowIdSymbol.toSymbolReference()); + projectionAssignmentsBuilder.put(mergeRowSymbol, caseExpression); + + ProjectNode subPlanProject = new ProjectNode( + idAllocator.getNextId(), + subPlan.getRoot(), + projectionAssignmentsBuilder.build()); + + // Now add a column for the case_number, gotten from the merge_row + ProjectNode project = new ProjectNode( + idAllocator.getNextId(), + subPlanProject, + Assignments.builder() + 
.putIdentities(subPlanProject.getOutputSymbols()) + .put(caseNumberSymbol, new SubscriptExpression(mergeRowSymbol.toSymbolReference(), new LongLiteral(Long.toString(rowType.getFields().size())))) + .build()); + + // Mark distinct combinations of the unique_id value and the case_number + Symbol isDistinctSymbol = symbolAllocator.newSymbol("is_distinct", BOOLEAN); + MarkDistinctNode markDistinctNode = new MarkDistinctNode(idAllocator.getNextId(), project, isDistinctSymbol, ImmutableList.of(uniqueIdSymbol, caseNumberSymbol), Optional.empty()); + + // Raise an error if unique_id symbol is non-null and the unique_id/case_number combination was not distinct + Metadata metadata = plannerContext.getMetadata(); + Expression filter = new IfExpression( + LogicalExpression.and( + new NotExpression(isDistinctSymbol.toSymbolReference()), + new IsNotNullPredicate(uniqueIdSymbol.toSymbolReference())), + new Cast( + failFunction(metadata, session, MERGE_TARGET_ROW_MULTIPLE_MATCHES, "One MERGE target table row matched more than one source row"), + toSqlType(BOOLEAN)), + TRUE_LITERAL); + + FilterNode filterNode = new FilterNode(idAllocator.getNextId(), markDistinctNode, filter); + + Table table = merge.getTargetTable(); + TableHandle handle = analysis.getTableHandle(table); + TableMetadata tableMetadata = metadata.getTableMetadata(session, handle); + + RowChangeParadigm paradigm = metadata.getRowChangeParadigm(session, handle); + Type rowIdType = analysis.getType(analysis.getRowIdField(table)); + List dataColumnTypes = tableMetadata.getMetadata().getColumns().stream() + .filter(column -> !column.isHidden()) + .map(ColumnMetadata::getType) + .collect(toImmutableList()); + + MergeParadigmAndTypes mergeParadigmAndTypes = new MergeParadigmAndTypes(paradigm, dataColumnTypes, rowIdType); + MergeTarget mergeTarget = new MergeTarget(handle, Optional.empty(), tableMetadata.getTable(), mergeParadigmAndTypes); + + ImmutableList.Builder columnSymbolsBuilder = ImmutableList.builder(); + for 
(ColumnHandle columnHandle : mergeAnalysis.getDataColumnHandles()) { + int fieldIndex = requireNonNull(mergeAnalysis.getColumnHandleFieldNumbers().get(columnHandle), "Could not find field number for column handle"); + columnSymbolsBuilder.add(planWithPresentColumn.getFieldMappings().get(fieldIndex)); + } + List columnSymbols = columnSymbolsBuilder.build(); + ImmutableList.Builder redistributionSymbolsBuilder = ImmutableList.builder(); + for (ColumnHandle columnHandle : mergeAnalysis.getRedistributionColumnHandles()) { + int fieldIndex = requireNonNull(mergeAnalysis.getColumnHandleFieldNumbers().get(columnHandle), "Could not find field number for column handle"); + redistributionSymbolsBuilder.add(planWithPresentColumn.getFieldMappings().get(fieldIndex)); + } + + Symbol operationSymbol = symbolAllocator.newSymbol("operation", TINYINT); + Symbol insertFromUpdateSymbol = symbolAllocator.newSymbol("insert_from_update", TINYINT); + + List projectedSymbols = ImmutableList.builder() + .addAll(columnSymbols) + .add(operationSymbol) + .add(rowIdSymbol) + .add(insertFromUpdateSymbol) + .build(); + + MergeProcessorNode mergeProcessorNode = new MergeProcessorNode( + idAllocator.getNextId(), + filterNode, + mergeTarget, + rowIdSymbol, + mergeRowSymbol, + columnSymbols, + redistributionSymbolsBuilder.build(), + projectedSymbols); + + Optional partitioningScheme = createMergePartitioningScheme( + mergeAnalysis.getInsertLayout(), + columnSymbols, + mergeAnalysis.getInsertPartitioningArgumentIndexes(), + mergeAnalysis.getUpdateLayout(), + rowIdSymbol, + operationSymbol); + + List outputs = ImmutableList.of( + symbolAllocator.newSymbol("partialrows", BIGINT), + symbolAllocator.newSymbol("fragment", VARBINARY)); + + return new MergeWriterNode( + idAllocator.getNextId(), + mergeProcessorNode, + mergeTarget, + projectedSymbols, + partitioningScheme, + outputs); + } + + private static int getMergeCaseOperationNumber(MergeCase mergeCase) + { + if (mergeCase instanceof MergeInsert) { + 
return INSERT_OPERATION_NUMBER; + } + if (mergeCase instanceof MergeUpdate) { + return UPDATE_OPERATION_NUMBER; + } + if (mergeCase instanceof MergeDelete) { + return DELETE_OPERATION_NUMBER; + } + throw new IllegalArgumentException("Unrecognized MergeCase: " + mergeCase); + } + + private static RowType createMergeRowType(List allColumnsSchema) + { + // create the RowType that holds all column values + List fields = new ArrayList<>(); + for (ColumnSchema schema : allColumnsSchema) { + fields.add(new RowType.Field(Optional.of(schema.getName()), schema.getType())); + } + fields.add(new RowType.Field(Optional.empty(), BOOLEAN)); // present + fields.add(new RowType.Field(Optional.empty(), TINYINT)); // operation_number + fields.add(new RowType.Field(Optional.empty(), INTEGER)); // case_number + return RowType.from(fields); + } + + public static Optional createMergePartitioningScheme( + Optional insertLayout, + List symbols, + List insertPartitioningArgumentIndexes, + Optional updateLayout, + Symbol rowIdSymbol, + Symbol operationSymbol) + { + if (insertLayout.isEmpty() && updateLayout.isEmpty()) { + return Optional.empty(); + } + + Optional insertPartitioning = insertLayout.map(layout -> { + List arguments = insertPartitioningArgumentIndexes.stream() + .map(symbols::get) + .collect(toImmutableList()); + + return layout.getPartitioning() + .map(handle -> new PartitioningScheme(Partitioning.create(handle, arguments), symbols)) + // empty connector partitioning handle means evenly partitioning on partitioning columns + .orElseGet(() -> new PartitioningScheme(Partitioning.create(FIXED_HASH_DISTRIBUTION, arguments), symbols)); + }); + + Optional updatePartitioning = updateLayout.map(handle -> + new PartitioningScheme(Partitioning.create(handle, ImmutableList.of(rowIdSymbol)), ImmutableList.of(rowIdSymbol))); + + PartitioningHandle partitioningHandle = new PartitioningHandle( + Optional.empty(), + Optional.empty(), + new MergePartitioningHandle(insertPartitioning, 
updatePartitioning)); + + List combinedSymbols = new ArrayList<>(); + combinedSymbols.add(operationSymbol); + insertPartitioning.ifPresent(scheme -> combinedSymbols.addAll(partitioningSymbols(scheme))); + updatePartitioning.ifPresent(scheme -> combinedSymbols.addAll(partitioningSymbols(scheme))); + + return Optional.of(new PartitioningScheme(Partitioning.create(partitioningHandle, combinedSymbols), combinedSymbols)); + } + + private static List partitioningSymbols(PartitioningScheme scheme) + { + return scheme.getPartitioning().getArguments().stream() + .map(Partitioning.ArgumentBinding::getColumn) + .collect(toImmutableList()); + } + private static Optional getIdForLeftTableScan(PlanNode node) { if (node instanceof TableScanNode) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 29c4c5d92d4c..70556fd420ce 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -579,7 +579,7 @@ protected RelationPlan visitJoin(Join node, Void context) return planJoin(analysis.getJoinCriteria(node), node.getType(), analysis.getScope(node), leftPlan, rightPlan, analysis.getSubqueries(node)); } - private RelationPlan planJoin(Expression criteria, Join.Type type, Scope scope, RelationPlan leftPlan, RelationPlan rightPlan, Analysis.SubqueryAnalysis subqueries) + public RelationPlan planJoin(Expression criteria, Join.Type type, Scope scope, RelationPlan leftPlan, RelationPlan rightPlan, Analysis.SubqueryAnalysis subqueries) { // NOTE: symbols must be in the same order as the outputDescriptor List outputSymbols = ImmutableList.builder() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java b/core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java index 642aa6c14159..3ac03c33eca7 100644 --- 
a/core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java @@ -40,6 +40,8 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeProcessorNode; +import io.trino.sql.planner.plan.MergeWriterNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PatternRecognitionNode; import io.trino.sql.planner.plan.PlanNode; @@ -409,6 +411,18 @@ public Map visitUpdate(UpdateNode node, Void context) return node.getSource().accept(this, context); } + @Override + public Map visitMergeWriter(MergeWriterNode node, Void context) + { + return node.getSource().accept(this, context); + } + + @Override + public Map visitMergeProcessor(MergeProcessorNode node, Void context) + { + return node.getSource().accept(this, context); + } + @Override public Map visitTableDelete(TableDeleteNode node, Void context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/SystemPartitioningHandle.java b/core/trino-main/src/main/java/io/trino/sql/planner/SystemPartitioningHandle.java index bf433e6a23bd..ce2abf3682f8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/SystemPartitioningHandle.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/SystemPartitioningHandle.java @@ -15,11 +15,6 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import com.google.common.collect.ImmutableList; -import io.trino.Session; -import io.trino.execution.scheduler.NodeScheduler; -import io.trino.execution.scheduler.NodeSelector; -import io.trino.metadata.InternalNode; import io.trino.operator.BucketPartitionFunction; import io.trino.operator.InterpretedHashGenerator; import io.trino.operator.PartitionFunction; @@ -36,15 +31,12 @@ import static 
com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; -import static io.trino.SystemSessionProperties.getHashPartitionCount; -import static io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE; -import static io.trino.util.Failures.checkCondition; import static java.util.Objects.requireNonNull; public final class SystemPartitioningHandle implements ConnectorPartitioningHandle { - private enum SystemPartitioning + enum SystemPartitioning { SINGLE, FIXED, @@ -134,32 +126,6 @@ public String toString() return partitioning.toString(); } - public NodePartitionMap getNodePartitionMap(Session session, NodeScheduler nodeScheduler) - { - NodeSelector nodeSelector = nodeScheduler.createNodeSelector(session, Optional.empty()); - List nodes; - - switch (partitioning) { - case COORDINATOR_ONLY: - nodes = ImmutableList.of(nodeSelector.selectCurrentNode()); - break; - case SINGLE: - nodes = nodeSelector.selectRandomNodes(1); - break; - case FIXED: - nodes = nodeSelector.selectRandomNodes(getHashPartitionCount(session)); - break; - default: - throw new IllegalArgumentException("Unsupported plan distribution " + partitioning); - } - - checkCondition(!nodes.isEmpty(), NO_NODES_AVAILABLE, "No worker nodes available"); - - return new NodePartitionMap(nodes, split -> { - throw new UnsupportedOperationException("System distribution does not support source splits"); - }); - } - public PartitionFunction getPartitionFunction(List partitionChannelTypes, boolean isHashPrecomputed, int[] bucketToPartition, BlockTypeOperators blockTypeOperators) { requireNonNull(partitionChannelTypes, "partitionChannelTypes is null"); @@ -227,7 +193,7 @@ public int getBucket(Page page, int position) } } - private static class RoundRobinBucketFunction + public static class RoundRobinBucketFunction implements BucketFunction { private final int bucketCount; diff --git 
a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DetermineTableScanNodePartitioning.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DetermineTableScanNodePartitioning.java index bf98175c6d35..8b562a63ad29 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DetermineTableScanNodePartitioning.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DetermineTableScanNodePartitioning.java @@ -24,6 +24,8 @@ import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.TableScanNode; +import java.util.Optional; + import static io.trino.SystemSessionProperties.getTableScanNodePartitioningMinBucketToTaskRatio; import static io.trino.SystemSessionProperties.isUseTableScanNodePartitioning; import static io.trino.sql.planner.plan.Patterns.tableScan; @@ -62,8 +64,8 @@ public Result apply(TableScanNode node, Captures captures, Context context) } TablePartitioning partitioning = properties.getTablePartitioning().get(); - ConnectorBucketNodeMap bucketNodeMap = nodePartitioningManager.getConnectorBucketNodeMap(context.getSession(), partitioning.getPartitioningHandle()); - if (bucketNodeMap.hasFixedMapping()) { + Optional bucketNodeMap = nodePartitioningManager.getConnectorBucketNodeMap(context.getSession(), partitioning.getPartitioningHandle()); + if (bucketNodeMap.map(ConnectorBucketNodeMap::hasFixedMapping).orElse(false)) { // use connector table scan node partitioning when bucket to node assignments are fixed return Result.ofPlanNode(node.withUseConnectorNodePartitioning(true)); } @@ -72,7 +74,8 @@ public Result apply(TableScanNode node, Captures captures, Context context) return Result.ofPlanNode(node.withUseConnectorNodePartitioning(false)); } - int numberOfBuckets = bucketNodeMap.getBucketCount(); + int numberOfBuckets = bucketNodeMap.map(ConnectorBucketNodeMap::getBucketCount) + .orElseGet(() -> nodePartitioningManager.getNodeCount(context.getSession(), 
partitioning.getPartitioningHandle())); int numberOfTasks = max(taskCountEstimator.estimateSourceDistributedTaskCount(context.getSession()), 1); return Result.ofPlanNode(node diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneMergeSourceColumns.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneMergeSourceColumns.java new file mode 100644 index 000000000000..84beb96ab1af --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneMergeSourceColumns.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableSet; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.sql.planner.iterative.Rule; +import io.trino.sql.planner.plan.MergeWriterNode; + +import static io.trino.sql.planner.iterative.rule.Util.restrictChildOutputs; +import static io.trino.sql.planner.plan.Patterns.merge; + +public class PruneMergeSourceColumns + implements Rule +{ + private static final Pattern PATTERN = merge(); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(MergeWriterNode mergeNode, Captures captures, Context context) + { + return restrictChildOutputs(context.getIdAllocator(), mergeNode, ImmutableSet.copyOf(mergeNode.getProjectedSymbols())) + .map(Result::ofPlanNode) + .orElse(Result.empty()); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java index aa4c8fff0e0a..f7361941e93f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java @@ -54,6 +54,7 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeWriterNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PatternRecognitionNode; import io.trino.sql.planner.plan.PlanNode; @@ -624,7 +625,24 @@ public PlanWithProperties visitSimpleTableExecuteNode(SimpleTableExecuteNode nod private PlanWithProperties visitTableWriter(PlanNode node, Optional partitioningScheme, PlanNode source, PreferredProperties preferredProperties, TableWriterNode.WriterTarget writerTarget) { PlanWithProperties newSource = source.accept(this, preferredProperties); + 
PlanWithProperties partitionedSource = getWriterPlanWithProperties(partitioningScheme, newSource, writerTarget); + return rebaseAndDeriveProperties(node, partitionedSource); + } + + @Override + public PlanWithProperties visitMergeWriter(MergeWriterNode node, PreferredProperties preferredProperties) + { + PlanWithProperties source = node.getSource().accept(this, preferredProperties); + + Optional partitioningScheme = node.getPartitioningScheme(); + PlanWithProperties partitionedSource = getWriterPlanWithProperties(partitioningScheme, source, node.getTarget()); + + return rebaseAndDeriveProperties(node, partitionedSource); + } + + private PlanWithProperties getWriterPlanWithProperties(Optional partitioningScheme, PlanWithProperties newSource, TableWriterNode.WriterTarget writerTarget) + { if (partitioningScheme.isEmpty()) { if (scaleWriters && writerTarget.supportsReportingWrittenBytes(plannerContext.getMetadata(), session)) { partitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(SCALED_WRITER_DISTRIBUTION, ImmutableList.of()), newSource.getNode().getOutputSymbols())); @@ -633,7 +651,6 @@ else if (redistributeWrites) { partitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), newSource.getNode().getOutputSymbols())); } } - if (partitioningScheme.isPresent() && !newSource.getProperties().isCompatibleTablePartitioningWith(partitioningScheme.get().getPartitioning(), false, plannerContext.getMetadata(), session)) { newSource = withDerivedProperties( partitionedExchange( @@ -643,7 +660,7 @@ else if (redistributeWrites) { partitioningScheme.get()), newSource.getProperties()); } - return rebaseAndDeriveProperties(node, newSource); + return newSource; } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java index 45067db0775a..6db68732abed 
100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java @@ -43,6 +43,7 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeWriterNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PatternRecognitionNode; import io.trino.sql.planner.plan.PlanNode; @@ -600,15 +601,21 @@ public PlanWithProperties visitTableExecute(TableExecuteNode node, StreamPreferr } private PlanWithProperties visitTableWriter(PlanNode node, Optional partitioningSchemeOptional, PlanNode source, StreamPreferredProperties parentPreferences) + { + return visitPartitionedWriter(node, partitioningSchemeOptional, source, parentPreferences); + } + + private PlanWithProperties visitPartitionedWriter(PlanNode node, Optional optionalPartitioning, PlanNode source, StreamPreferredProperties parentPreferences) { if (getTaskWriterCount(session) == 1) { return planAndEnforceChildren(node, singleStream(), defaultParallelism(session)); } - if (partitioningSchemeOptional.isEmpty()) { + + if (optionalPartitioning.isEmpty()) { return planAndEnforceChildren(node, fixedParallelism(), fixedParallelism()); } - PartitioningScheme partitioningScheme = partitioningSchemeOptional.get(); + PartitioningScheme partitioningScheme = optionalPartitioning.get(); if (partitioningScheme.getPartitioning().getHandle().equals(FIXED_HASH_DISTRIBUTION)) { // arbitrary hash function on predefined set of partition columns @@ -633,6 +640,16 @@ private PlanWithProperties visitTableWriter(PlanNode node, Optional> context) + { + MergeTarget mergeTarget = (MergeTarget) getContextTarget(context); + return new MergeWriterNode( + mergeNode.getId(), + rewriteModifyTableScan(mergeNode.getSource(), mergeTarget.getHandle()), + mergeTarget, + 
mergeNode.getProjectedSymbols(), + mergeNode.getPartitioningScheme(), + mergeNode.getOutputSymbols()); + } + @Override public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, RewriteContext> context) { @@ -242,6 +258,11 @@ public WriterTarget getWriterTarget(PlanNode node) target.getSchemaTableName(), target.isReportingWrittenBytesSupported()); } + + if (node instanceof MergeWriterNode mergeWriterNode) { + return mergeWriterNode.getTarget(); + } + if (node instanceof ExchangeNode || node instanceof UnionNode) { Set writerTargets = node.getSources().stream() .map(this::getWriterTarget) @@ -277,6 +298,14 @@ private WriterTarget createWriterTarget(WriterTarget target) update.getUpdatedColumns(), update.getUpdatedColumnHandles()); } + if (target instanceof MergeTarget merge) { + MergeHandle mergeHandle = metadata.beginMerge(session, merge.getHandle()); + return new MergeTarget( + mergeHandle.getTableHandle(), + Optional.of(mergeHandle), + merge.getSchemaTableName(), + merge.getMergeParadigmAndTypes()); + } if (target instanceof TableWriterNode.RefreshMaterializedViewReference) { TableWriterNode.RefreshMaterializedViewReference refreshMV = (TableWriterNode.RefreshMaterializedViewReference) target; return new TableWriterNode.RefreshMaterializedViewTarget( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java index 69a827e5696d..835751c1e29a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java @@ -55,6 +55,8 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeProcessorNode; +import io.trino.sql.planner.plan.MergeWriterNode; import 
io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PatternRecognitionNode; import io.trino.sql.planner.plan.PlanNode; @@ -492,6 +494,18 @@ public ActualProperties visitSimpleTableExecuteNode(SimpleTableExecuteNode node, .build(); } + @Override + public ActualProperties visitMergeWriter(MergeWriterNode node, List inputProperties) + { + return visitPartitionedWriter(inputProperties); + } + + @Override + public ActualProperties visitMergeProcessor(MergeProcessorNode node, List inputProperties) + { + return Iterables.getOnlyElement(inputProperties).translate(symbol -> Optional.empty()); + } + @Override public ActualProperties visitJoin(JoinNode node, List inputProperties) { @@ -782,6 +796,11 @@ public ActualProperties visitRefreshMaterializedView(RefreshMaterializedViewNode @Override public ActualProperties visitTableWriter(TableWriterNode node, List inputProperties) + { + return visitPartitionedWriter(inputProperties); + } + + private ActualProperties visitPartitionedWriter(List inputProperties) { ActualProperties properties = Iterables.getOnlyElement(inputProperties); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java index 126e19ca2cd2..558e553f7924 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java @@ -46,6 +46,8 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeProcessorNode; +import io.trino.sql.planner.plan.MergeWriterNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PatternRecognitionNode; import io.trino.sql.planner.plan.PlanNode; @@ -471,6 +473,20 @@ public 
StreamProperties visitRefreshMaterializedView(RefreshMaterializedViewNode return StreamProperties.singleStream(); } + @Override + public StreamProperties visitMergeWriter(MergeWriterNode node, List inputProperties) + { + StreamProperties properties = Iterables.getOnlyElement(inputProperties); + return properties.withUnspecifiedPartitioning(); + } + + @Override + public StreamProperties visitMergeProcessor(MergeProcessorNode node, List inputProperties) + { + StreamProperties properties = Iterables.getOnlyElement(inputProperties); + return properties.withUnspecifiedPartitioning(); + } + @Override public StreamProperties visitTableWriter(TableWriterNode node, List inputProperties) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java index 2c438897e8f9..9d3f12e7e42e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java @@ -25,6 +25,8 @@ import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.GroupIdNode; import io.trino.sql.planner.plan.LimitNode; +import io.trino.sql.planner.plan.MergeProcessorNode; +import io.trino.sql.planner.plan.MergeWriterNode; import io.trino.sql.planner.plan.PatternRecognitionNode; import io.trino.sql.planner.plan.PatternRecognitionNode.Measure; import io.trino.sql.planner.plan.PlanNode; @@ -430,6 +432,49 @@ public TableExecuteNode map(TableExecuteNode node, PlanNode source, PlanNodeId n node.getPreferredPartitioningScheme().map(partitioningScheme -> map(partitioningScheme, source.getOutputSymbols()))); } + public MergeWriterNode map(MergeWriterNode node, PlanNode source) + { + // Intentionally does not use mapAndDistinct on columns as that would remove columns + List newOutputs = map(node.getOutputSymbols()); + + return new MergeWriterNode( + node.getId(), + 
source, + node.getTarget(), + map(node.getProjectedSymbols()), + node.getPartitioningScheme().map(partitioningScheme -> map(partitioningScheme, source.getOutputSymbols())), + newOutputs); + } + + public MergeWriterNode map(MergeWriterNode node, PlanNode source, PlanNodeId newId) + { + // Intentionally does not use mapAndDistinct on columns as that would remove columns + List newOutputs = map(node.getOutputSymbols()); + + return new MergeWriterNode( + newId, + source, + node.getTarget(), + map(node.getProjectedSymbols()), + node.getPartitioningScheme().map(partitioningScheme -> map(partitioningScheme, source.getOutputSymbols())), + newOutputs); + } + + public MergeProcessorNode map(MergeProcessorNode node, PlanNode source) + { + List newOutputs = map(node.getOutputSymbols()); + + return new MergeProcessorNode( + node.getId(), + source, + node.getTarget(), + map(node.getRowIdSymbol()), + map(node.getMergeRowSymbol()), + map(node.getDataColumnSymbols()), + map(node.getRedistributionColumnSymbols()), + newOutputs); + } + public PartitioningScheme map(PartitioningScheme scheme, List sourceLayout) { return new PartitioningScheme( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java index f6c0880cc022..02694735a291 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -54,6 +54,8 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeProcessorNode; +import io.trino.sql.planner.plan.MergeWriterNode; import io.trino.sql.planner.plan.OffsetNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PatternRecognitionNode; @@ 
-687,6 +689,30 @@ public PlanAndMappings visitSimpleTableExecuteNode(SimpleTableExecuteNode node, mapping); } + @Override + public PlanAndMappings visitMergeWriter(MergeWriterNode node, UnaliasContext context) + { + PlanAndMappings rewrittenSource = node.getSource().accept(this, context); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); + + MergeWriterNode rewrittenMerge = mapper.map(node, rewrittenSource.getRoot()); + + return new PlanAndMappings(rewrittenMerge, mapping); + } + + @Override + public PlanAndMappings visitMergeProcessor(MergeProcessorNode node, UnaliasContext context) + { + PlanAndMappings rewrittenSource = node.getSource().accept(this, context); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); + + MergeProcessorNode mergeProcessorNode = mapper.map(node, rewrittenSource.getRoot()); + + return new PlanAndMappings(mergeProcessorNode, mapping); + } + @Override public PlanAndMappings visitStatisticsWriterNode(StatisticsWriterNode node, UnaliasContext context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/MergeProcessorNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/MergeProcessorNode.java new file mode 100644 index 000000000000..4860632ae388 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/MergeProcessorNode.java @@ -0,0 +1,124 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.plan; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import io.trino.sql.planner.Symbol; + +import java.util.List; + +import static io.trino.sql.planner.plan.TableWriterNode.MergeTarget; +import static java.util.Objects.requireNonNull; + +/** + * The node processes the result of the Searched CASE and RIGHT JOIN + * derived from a MERGE statement. + */ +public class MergeProcessorNode + extends PlanNode +{ + private final PlanNode source; + private final MergeTarget target; + private final Symbol rowIdSymbol; + private final Symbol mergeRowSymbol; + private final List dataColumnSymbols; + private final List redistributionColumnSymbols; + private final List outputs; + + @JsonCreator + public MergeProcessorNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("source") PlanNode source, + @JsonProperty("target") MergeTarget target, + @JsonProperty("rowIdSymbol") Symbol rowIdSymbol, + @JsonProperty("mergeRowSymbol") Symbol mergeRowSymbol, + @JsonProperty("dataColumnSymbols") List dataColumnSymbols, + @JsonProperty("redistributionColumnSymbols") List redistributionColumnSymbols, + @JsonProperty("outputs") List outputs) + { + super(id); + + this.source = requireNonNull(source, "source is null"); + this.target = requireNonNull(target, "target is null"); + this.mergeRowSymbol = requireNonNull(mergeRowSymbol, "mergeRowSymbol is null"); + this.rowIdSymbol = requireNonNull(rowIdSymbol, "rowIdSymbol is null"); + this.dataColumnSymbols = requireNonNull(dataColumnSymbols, "dataColumnSymbols is null"); + this.redistributionColumnSymbols = requireNonNull(redistributionColumnSymbols, "redistributionColumnSymbols is null"); + this.outputs = 
ImmutableList.copyOf(requireNonNull(outputs, "outputs is null")); + } + + @JsonProperty + public PlanNode getSource() + { + return source; + } + + @JsonProperty + public MergeTarget getTarget() + { + return target; + } + + @JsonProperty + public Symbol getMergeRowSymbol() + { + return mergeRowSymbol; + } + + @JsonProperty + public Symbol getRowIdSymbol() + { + return rowIdSymbol; + } + + @JsonProperty + public List getDataColumnSymbols() + { + return dataColumnSymbols; + } + + @JsonProperty + public List getRedistributionColumnSymbols() + { + return redistributionColumnSymbols; + } + + @JsonProperty("outputs") + @Override + public List getOutputSymbols() + { + return outputs; + } + + @Override + public List getSources() + { + return ImmutableList.of(source); + } + + @Override + public R accept(PlanVisitor visitor, C context) + { + return visitor.visitMergeProcessor(this, context); + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + return new MergeProcessorNode(getId(), Iterables.getOnlyElement(newChildren), target, rowIdSymbol, mergeRowSymbol, dataColumnSymbols, redistributionColumnSymbols, outputs); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/MergeWriterNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/MergeWriterNode.java new file mode 100644 index 000000000000..2c82e469e109 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/MergeWriterNode.java @@ -0,0 +1,110 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.plan; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import io.trino.sql.planner.PartitioningScheme; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.plan.TableWriterNode.MergeTarget; + +import javax.annotation.concurrent.Immutable; + +import java.util.List; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +@Immutable +public class MergeWriterNode + extends PlanNode +{ + private final PlanNode source; + private final MergeTarget target; + private final List projectedSymbols; + private final Optional partitioningScheme; + private final List outputs; + + @JsonCreator + public MergeWriterNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("source") PlanNode source, + @JsonProperty("target") MergeTarget target, + @JsonProperty("projectedSymbols") List projectedSymbols, + @JsonProperty("partitioningScheme") Optional partitioningScheme, + @JsonProperty("outputs") List outputs) + { + super(id); + + this.source = requireNonNull(source, "source is null"); + this.target = requireNonNull(target, "target is null"); + this.projectedSymbols = requireNonNull(projectedSymbols, "projectedSymbols is null"); + this.partitioningScheme = requireNonNull(partitioningScheme, "partitioningScheme is null"); + this.outputs = ImmutableList.copyOf(requireNonNull(outputs, "outputs is null")); + } + + @JsonProperty + public PlanNode getSource() + { + return source; + } + + @JsonProperty + public MergeTarget getTarget() + { + return target; + } + + @JsonProperty + public List getProjectedSymbols() + { + return projectedSymbols; + } + + @JsonProperty + public Optional getPartitioningScheme() + { + return partitioningScheme; + } + + /** + * 
Aggregate information about updated data + */ + @JsonProperty("outputs") + @Override + public List getOutputSymbols() + { + return outputs; + } + + @Override + public List getSources() + { + return ImmutableList.of(source); + } + + @Override + public R accept(PlanVisitor visitor, C context) + { + return visitor.visitMergeWriter(this, context); + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + return new MergeWriterNode(getId(), Iterables.getOnlyElement(newChildren), target, projectedSymbols, partitioningScheme, outputs); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java index 8c7ae8026beb..6e993908de2f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java @@ -70,6 +70,11 @@ public static Pattern tableExecute() return typeOf(TableExecuteNode.class); } + public static Pattern merge() + { + return typeOf(MergeWriterNode.class); + } + public static Pattern exchange() { return typeOf(ExchangeNode.class); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java index 340d9b1f1a08..39fab6777293 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java @@ -54,6 +54,8 @@ @JsonSubTypes.Type(value = UpdateNode.class, name = "update"), @JsonSubTypes.Type(value = TableExecuteNode.class, name = "tableExecute"), @JsonSubTypes.Type(value = SimpleTableExecuteNode.class, name = "simpleTableExecuteNode"), + @JsonSubTypes.Type(value = MergeWriterNode.class, name = "mergeWriter"), + @JsonSubTypes.Type(value = MergeProcessorNode.class, name = "mergeProcessor"), @JsonSubTypes.Type(value = TableDeleteNode.class, name = "tableDelete"), 
@JsonSubTypes.Type(value = TableFinishNode.class, name = "tablecommit"), @JsonSubTypes.Type(value = UnnestNode.class, name = "unnest"), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java index 8ec9406e2fef..acf854937e27 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java @@ -154,6 +154,16 @@ public R visitSimpleTableExecuteNode(SimpleTableExecuteNode node, C context) return visitPlan(node, context); } + public R visitMergeWriter(MergeWriterNode node, C context) + { + return visitPlan(node, context); + } + + public R visitMergeProcessor(MergeProcessorNode node, C context) + { + return visitPlan(node, context); + } + public R visitTableDelete(TableDeleteNode node, C context) { return visitPlan(node, context); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java index f2464741a7c5..bfb994f7a290 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java @@ -23,6 +23,7 @@ import com.google.common.collect.Iterables; import io.trino.Session; import io.trino.metadata.InsertTableHandle; +import io.trino.metadata.MergeHandle; import io.trino.metadata.Metadata; import io.trino.metadata.OutputTableHandle; import io.trino.metadata.QualifiedObjectName; @@ -31,7 +32,9 @@ import io.trino.metadata.TableLayout; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorTableMetadata; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.type.Type; import io.trino.sql.planner.PartitioningScheme; import io.trino.sql.planner.Symbol; import 
io.trino.sql.tree.Table;

    /**
     * Identifies the target table of a MERGE statement, together with the
     * connector merge handle (present after beginMerge) and the row-change
     * paradigm and column types the connector declared.
     */
    public static class MergeTarget
            extends WriterTarget
    {
        private final TableHandle handle;
        private final Optional<MergeHandle> mergeHandle;
        private final SchemaTableName schemaTableName;
        private final MergeParadigmAndTypes mergeParadigmAndTypes;

        @JsonCreator
        public MergeTarget(
                @JsonProperty("handle") TableHandle handle,
                @JsonProperty("mergeHandle") Optional<MergeHandle> mergeHandle,
                @JsonProperty("schemaTableName") SchemaTableName schemaTableName,
                @JsonProperty("mergeParadigmAndTypes") MergeParadigmAndTypes mergeParadigmAndTypes)
        {
            this.handle = requireNonNull(handle, "handle is null");
            this.mergeHandle = requireNonNull(mergeHandle, "mergeHandle is null");
            this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null");
            // Message previously said "mergeElements is null", which does not match the parameter name
            this.mergeParadigmAndTypes = requireNonNull(mergeParadigmAndTypes, "mergeParadigmAndTypes is null");
        }

        @JsonProperty
        public TableHandle getHandle()
        {
            return handle;
        }

        @JsonProperty
        public Optional<MergeHandle> getMergeHandle()
        {
            return mergeHandle;
        }

        @JsonProperty
        public SchemaTableName getSchemaTableName()
        {
            return schemaTableName;
        }

        @JsonProperty
        public MergeParadigmAndTypes getMergeParadigmAndTypes()
        {
            return mergeParadigmAndTypes;
        }

        @Override
        public String toString()
        {
            return handle.toString();
        }

        @Override
        public boolean supportsReportingWrittenBytes(Metadata metadata, Session session)
        {
            return false;
        }
    }

    /**
     * The connector-declared {@link RowChangeParadigm} plus the types of the
     * target table's data columns and of the row id column.
     */
    public static class MergeParadigmAndTypes
    {
        private final RowChangeParadigm paradigm;
        private final List<Type> columnTypes;
        private final Type rowIdType;

        @JsonCreator
        public MergeParadigmAndTypes(
                @JsonProperty("paradigm") RowChangeParadigm paradigm,
                @JsonProperty("columnTypes") List<Type> columnTypes,
                @JsonProperty("rowIdType") Type rowIdType)
        {
            this.paradigm = requireNonNull(paradigm, "paradigm is null");
            this.columnTypes = requireNonNull(columnTypes, "columnTypes is null");
            this.rowIdType = requireNonNull(rowIdType, "rowIdType is null");
        }

        @JsonProperty
        public RowChangeParadigm getParadigm()
        {
            return paradigm;
        }

        @JsonProperty
        public List<Type> getColumnTypes()
        {
            return columnTypes;
        }

        @JsonProperty
        public Type getRowIdType()
        {
            return rowIdType;
        }
    }
}
io.trino.sql.planner.plan.TableWriterNode.WriterTarget; import io.trino.sql.planner.planprinter.IoPlanPrinter.FormattedMarker.Bound; @@ -668,6 +669,12 @@ else if (writerTarget instanceof UpdateTarget) { target.getSchemaTableName().getSchemaName(), target.getSchemaTableName().getTableName())); } + else if (writerTarget instanceof MergeTarget target) { + context.setOutputTable(new CatalogSchemaTableName( + target.getHandle().getCatalogHandle().getCatalogName(), + target.getSchemaTableName().getSchemaName(), + target.getSchemaTableName().getTableName())); + } else if (writerTarget instanceof TableWriterNode.RefreshMaterializedViewTarget) { TableWriterNode.RefreshMaterializedViewTarget target = (TableWriterNode.RefreshMaterializedViewTarget) writerTarget; context.setOutputTable(new CatalogSchemaTableName( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index 4ff2a7ff32aa..ec4152b80891 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -74,6 +74,8 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeProcessorNode; +import io.trino.sql.planner.plan.MergeWriterNode; import io.trino.sql.planner.plan.OffsetNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PatternRecognitionNode; @@ -1467,6 +1469,27 @@ public Void visitSimpleTableExecuteNode(SimpleTableExecuteNode node, Void contex addNode(node, "SimpleTableExecute", ImmutableMap.of("table", node.getExecuteHandle().toString())); + return null; + } + + @Override + public Void visitMergeWriter(MergeWriterNode node, Void context) + { + addNode(node, + "MergeWriter", + ImmutableMap.of("table", 
node.getTarget().toString())); + return processChildren(node, context); + } + + @Override + public Void visitMergeProcessor(MergeProcessorNode node, Void context) + { + NodeRepresentation nodeOutput = addNode(node, "MergeProcessor"); + nodeOutput.appendDetails("target: %s", node.getTarget()); + nodeOutput.appendDetails("merge row column: %s", node.getMergeRowSymbol()); + nodeOutput.appendDetails("row id column: %s", node.getRowIdSymbol()); + nodeOutput.appendDetails("redistribution columns: %s", node.getRedistributionColumnSymbols()); + nodeOutput.appendDetails("data columns: %s", node.getDataColumnSymbols()); return processChildren(node, context); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java index 907d2ffe9ec8..734bbad11c59 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java @@ -42,6 +42,8 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeProcessorNode; +import io.trino.sql.planner.plan.MergeWriterNode; import io.trino.sql.planner.plan.OffsetNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PatternRecognitionNode; @@ -672,6 +674,25 @@ public Void visitTableExecute(TableExecuteNode node, Set boundSymbols) { PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child + return null; + } + + @Override + public Void visitMergeWriter(MergeWriterNode node, Set boundSymbols) + { + PlanNode source = node.getSource(); + source.accept(this, boundSymbols); // visit child + return null; + } + + @Override + public Void visitMergeProcessor(MergeProcessorNode node, Set boundSymbols) + { + 
PlanNode source = node.getSource(); + source.accept(this, boundSymbols); // visit child + + checkArgument(source.getOutputSymbols().contains(node.getRowIdSymbol()), "Invalid node. rowId symbol (%s) is not in source plan output (%s)", node.getRowIdSymbol(), node.getSource().getOutputSymbols()); + checkArgument(source.getOutputSymbols().contains(node.getMergeRowSymbol()), "Invalid node. Merge row symbol (%s) is not in source plan output (%s)", node.getMergeRowSymbol(), node.getSource().getOutputSymbols()); return null; } diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java b/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java index 05ecca57fa05..7189e5d7ada9 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java @@ -711,7 +711,7 @@ public enum TestingPrivilegeType EXECUTE_QUERY, VIEW_QUERY, KILL_QUERY, EXECUTE_FUNCTION, CREATE_SCHEMA, DROP_SCHEMA, RENAME_SCHEMA, - SHOW_CREATE_TABLE, CREATE_TABLE, DROP_TABLE, RENAME_TABLE, COMMENT_TABLE, COMMENT_VIEW, COMMENT_COLUMN, INSERT_TABLE, DELETE_TABLE, UPDATE_TABLE, TRUNCATE_TABLE, SET_TABLE_PROPERTIES, SHOW_COLUMNS, + SHOW_CREATE_TABLE, CREATE_TABLE, DROP_TABLE, RENAME_TABLE, COMMENT_TABLE, COMMENT_VIEW, COMMENT_COLUMN, INSERT_TABLE, DELETE_TABLE, MERGE_TABLE, UPDATE_TABLE, TRUNCATE_TABLE, SET_TABLE_PROPERTIES, SHOW_COLUMNS, ADD_COLUMN, DROP_COLUMN, RENAME_COLUMN, SELECT_COLUMN, CREATE_VIEW, RENAME_VIEW, DROP_VIEW, CREATE_VIEW_WITH_SELECT_COLUMNS, CREATE_MATERIALIZED_VIEW, REFRESH_MATERIALIZED_VIEW, DROP_MATERIALIZED_VIEW, RENAME_MATERIALIZED_VIEW, SET_MATERIALIZED_VIEW_PROPERTIES, diff --git a/core/trino-main/src/main/java/io/trino/util/StatementUtils.java b/core/trino-main/src/main/java/io/trino/util/StatementUtils.java index 26a7a761f85f..0ba72aa0d8e8 100644 --- a/core/trino-main/src/main/java/io/trino/util/StatementUtils.java +++ 
b/core/trino-main/src/main/java/io/trino/util/StatementUtils.java @@ -84,6 +84,7 @@ import io.trino.sql.tree.Grant; import io.trino.sql.tree.GrantRoles; import io.trino.sql.tree.Insert; +import io.trino.sql.tree.Merge; import io.trino.sql.tree.Prepare; import io.trino.sql.tree.Query; import io.trino.sql.tree.RefreshMaterializedView; @@ -139,6 +140,7 @@ import static io.trino.spi.resourcegroups.QueryType.DESCRIBE; import static io.trino.spi.resourcegroups.QueryType.EXPLAIN; import static io.trino.spi.resourcegroups.QueryType.INSERT; +import static io.trino.spi.resourcegroups.QueryType.MERGE; import static io.trino.spi.resourcegroups.QueryType.SELECT; import static io.trino.spi.resourcegroups.QueryType.UPDATE; import static java.lang.String.format; @@ -176,6 +178,7 @@ private StatementUtils() {} .add(basicStatement(Insert.class, INSERT)) .add(basicStatement(Update.class, UPDATE)) .add(basicStatement(Delete.class, DELETE)) + .add(basicStatement(Merge.class, MERGE)) .add(basicStatement(Analyze.class, ANALYZE)) // DDL .add(dataDefinitionStatement(AddColumn.class, AddColumnTask.class)) diff --git a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java index 273d7427f51f..dbcb7a0b1bde 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java +++ b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java @@ -42,6 +42,7 @@ import io.trino.spi.connector.LimitApplicationResult; import io.trino.spi.connector.MaterializedViewFreshness; import io.trino.spi.connector.ProjectionApplicationResult; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SampleApplicationResult; import io.trino.spi.connector.SampleType; import io.trino.spi.connector.SortItem; @@ -468,6 +469,36 @@ public void finishUpdate(Session session, TableHandle tableHandle, Collection getUpdateLayout(Session session, TableHandle tableHandle) + { + 
throw new UnsupportedOperationException(); + } + + @Override + public MergeHandle beginMerge(Session session, TableHandle tableHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public void finishMerge(Session session, MergeHandle tableHandle, Collection fragments, Collection computedStatistics) + { + throw new UnsupportedOperationException(); + } + @Override public Optional getCatalogHandle(Session session, String catalogName) { diff --git a/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java b/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java index 8c78a498c921..776d25c0f14c 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java +++ b/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java @@ -39,6 +39,7 @@ import io.trino.spi.connector.LimitApplicationResult; import io.trino.spi.connector.MaterializedViewFreshness; import io.trino.spi.connector.ProjectionApplicationResult; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SampleApplicationResult; import io.trino.spi.connector.SampleType; import io.trino.spi.connector.SortItem; @@ -469,6 +470,36 @@ public void finishUpdate(Session session, TableHandle tableHandle, Collection getUpdateLayout(Session session, TableHandle tableHandle) + { + return delegate.getUpdateLayout(session, tableHandle); + } + + @Override + public MergeHandle beginMerge(Session session, TableHandle tableHandle) + { + return delegate.beginMerge(session, tableHandle); + } + + @Override + public void finishMerge(Session session, MergeHandle tableHandle, Collection fragments, Collection computedStatistics) + { + delegate.finishMerge(session, tableHandle, fragments, computedStatistics); + } + @Override public Optional getCatalogHandle(Session session, String catalogName) { diff --git a/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java 
b/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java index 6003964c5dd6..ddbd6ac3719f 100644 --- a/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java +++ b/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java @@ -33,7 +33,6 @@ import io.trino.spi.connector.ConnectorNodePartitioningProvider; import io.trino.spi.connector.ConnectorPartitioningHandle; import io.trino.spi.connector.ConnectorSession; -import io.trino.spi.connector.ConnectorSplit; import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; @@ -50,7 +49,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.function.Consumer; -import java.util.function.ToIntFunction; import static com.google.common.base.Preconditions.checkArgument; import static io.trino.spi.connector.ConnectorBucketNodeMap.createBucketNodeMap; @@ -444,15 +442,9 @@ public void testPartitionCustomPartitioning() ConnectorNodePartitioningProvider connectorNodePartitioningProvider = new ConnectorNodePartitioningProvider() { @Override - public ConnectorBucketNodeMap getBucketNodeMap(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) + public Optional getBucketNodeMapping(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) { - return createBucketNodeMap(2); - } - - @Override - public ToIntFunction getSplitBucketFunction(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) - { - throw new UnsupportedOperationException(); + return Optional.of(createBucketNodeMap(2)); } @Override diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java 
b/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner;

import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slices;
import io.trino.operator.DeleteAndInsertMergeProcessor;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.ByteArrayBlock;
import io.trino.spi.block.IntArrayBlock;
import io.trino.spi.block.LongArrayBlock;
import io.trino.spi.block.PageBuilderStatus;
import io.trino.spi.block.RowBlock;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import org.testng.annotations.Test;

import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Optional;

import static io.trino.operator.MergeRowChangeProcessor.DEFAULT_CASE_OPERATION_NUMBER;
import static io.trino.spi.connector.ConnectorMergeSink.DELETE_OPERATION_NUMBER;
import static io.trino.spi.connector.ConnectorMergeSink.INSERT_OPERATION_NUMBER;
import static io.trino.spi.connector.ConnectorMergeSink.UPDATE_OPERATION_NUMBER;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.TinyintType.TINYINT;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static org.assertj.core.api.Assertions.assertThat;

public class TestDeleteAndInsertMergeProcessor
{
    @Test
    public void testSimpleDeletedRowMerge()
    {
        // target: ('Dave', 11, 'Devon'), ('Dave', 11, 'Darbyshire')
        // source: ('Dave', 11, 'Darbyshire')
        // merge:
        //     MERGE INTO target t USING source s
        //         ON t.customer = s.customer" +
        //         WHEN MATCHED AND t.address <> 'Darbyshire' AND s.purchases * 2 > 20" +
        //             THEN DELETE
        // expected: ('Dave', 11, 'Darbyshire')
        DeleteAndInsertMergeProcessor processor = makeMergeProcessor();
        Page inputPage = makePageFromBlocks(
                2,
                Optional.empty(),
                new Block[] {
                        makeLongArrayBlock(1, 1), // TransactionId
                        makeLongArrayBlock(1, 0), // rowId
                        makeIntArrayBlock(536870912, 536870912)}, // bucket
                new Block[] {
                        makeVarcharArrayBlock("", "Dave"), // customer
                        makeIntArrayBlock(0, 11), // purchases
                        makeVarcharArrayBlock("", "Devon"), // address
                        makeByteArrayBlock(1, 1), // "present" boolean
                        makeByteArrayBlock(DEFAULT_CASE_OPERATION_NUMBER, DELETE_OPERATION_NUMBER),
                        makeIntArrayBlock(-1, 0)});

        Page outputPage = processor.transformPage(inputPage);
        assertThat(outputPage.getPositionCount()).isEqualTo(1);

        // The single operation is a delete
        assertThat(TINYINT.getLong(outputPage.getBlock(3), 0)).isEqualTo(DELETE_OPERATION_NUMBER);

        // Show that the row to be deleted is rowId 0, e.g. ('Dave', 11, 'Devon')
        Block rowIdRow = outputPage.getBlock(4).getObject(0, Block.class);
        assertThat(INTEGER.getLong(rowIdRow, 1)).isEqualTo(0);
    }

    @Test
    public void testUpdateAndDeletedMerge()
    {
        // target: ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Dave', 11, 'Darbyshire'), ('Dave', 11, 'Devon'), ('Ed', 7, 'Etherville')
        // source: ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville')
        // merge:
        //     MERGE INTO target t USING source s
        //         ON t.customer = s.customer" +
        //         WHEN MATCHED AND t.address <> 'Darbyshire' AND s.purchases * 2 > 20
        //             THEN DELETE" +
        //         WHEN MATCHED" +
        //             THEN UPDATE SET purchases = s.purchases + t.purchases, address = concat(t.address, '/', s.address)" +
        //         WHEN NOT MATCHED" +
        //             THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)
        // expected: ('Aaron', 17, 'Arches/Arches'), ('Bill', 7, 'Buena'), ('Carol', 9, 'Centreville'), ('Dave', 22, 'Darbyshire/Darbyshire'), ('Ed', 14, 'Etherville/Etherville'), ('Fred', 30, 'Franklin')
        DeleteAndInsertMergeProcessor processor = makeMergeProcessor();
        boolean[] rowIdNulls = new boolean[] {false, true, false, false, false};
        Page inputPage = makePageFromBlocks(
                5,
                Optional.of(rowIdNulls),
                new Block[] {
                        makeLongArrayBlockWithNulls(rowIdNulls, 5, 2, 1, 2, 2), // TransactionId
                        makeLongArrayBlockWithNulls(rowIdNulls, 5, 0, 3, 1, 2), // rowId
                        makeIntArrayBlockWithNulls(rowIdNulls, 5, 536870912, 536870912, 536870912, 536870912)}, // bucket
                new Block[] {
                        // customer
                        makeVarcharArrayBlock("Aaron", "Carol", "Dave", "Dave", "Ed"),
                        // purchases
                        makeIntArrayBlock(17, 9, 11, 22, 14),
                        // address
                        makeVarcharArrayBlock("Arches/Arches", "Centreville", "Devon", "Darbyshire/Darbyshire", "Etherville/Etherville"),
                        // "present" boolean
                        makeByteArrayBlock(1, 0, 1, 1, 1),
                        // operation number: update, insert, delete, update
                        makeByteArrayBlock(UPDATE_OPERATION_NUMBER, INSERT_OPERATION_NUMBER, DELETE_OPERATION_NUMBER, UPDATE_OPERATION_NUMBER, UPDATE_OPERATION_NUMBER),
                        makeIntArrayBlock(0, 1, 2, 0, 0)});

        Page outputPage = processor.transformPage(inputPage);
        assertThat(outputPage.getPositionCount()).isEqualTo(8);
        RowBlock rowIdBlock = (RowBlock) outputPage.getBlock(4);
        assertThat(rowIdBlock.getPositionCount()).isEqualTo(8);
        // Show that the first row has address "Arches"
        assertThat(getString(outputPage.getBlock(2), 1)).isEqualTo("Arches/Arches");
    }

    @Test
    public void testAnotherMergeCase()
    {
        /*
        inputPage: Page[positions=5
            0:Row[0:Long[2, 1, 2, 2], 1:Long[0, 3, 1, 2], 2:Int[536870912, 536870912, 536870912, 536870912]],
            1:Row[0:VarWidth["Aaron", "Carol", "Dave", "Dave", "Ed"], 1:Int[17, 9, 11, 22, 14], 2:VarWidth["Arches/Arches", "Centreville", "Devon", "Darbyshire/Darbyshir...", "Etherville/Ethervill..."], 3:Int[1, 2, 0, 1, 1], 4:Int[3, 1, 2, 3, 3]]]
Page[positions=8 0:Dict[VarWidth["Aaron", "Dave", "Dave", "Ed", "Aaron", "Carol", "Dave", "Ed"]], 1:Dict[Int[17, 11, 22, 14, 17, 9, 22, 14]], 2:Dict[VarWidth["Arches/Arches", "Devon", "Darbyshire/Darbyshir...", "Etherville/Ethervill...", "Arches/Arches", "Centreville", "Darbyshire/Darbyshir...", "Etherville/Ethervill..."]], 3:Int[2, 2, 2, 2, 1, 1, 1, 1], 4:Row[0:Dict[Long[2, 1, 2, 2, 2, 2, 2, 2]], 1:Dict[Long[0, 3, 1, 2, 0, 0, 0, 0]], 2:Dict[Int[536870912, 536870912, 536870912, 536870912, 536870912, 536870912, 536870912, 536870912]]]]
        Expected row count to be <5>, but was <7>; rows=[[Bill, 7, Buena], [Dave, 11, Devon], [Aaron, 11, Arches], [Aaron, 17, Arches/Arches], [Carol, 9, Centreville], [Dave, 22, Darbyshire/Darbyshire], [Ed, 14, Etherville/Etherville]]
         */
        DeleteAndInsertMergeProcessor processor = makeMergeProcessor();
        boolean[] rowIdNulls = new boolean[] {false, true, false, false, false};
        Page inputPage = makePageFromBlocks(
                5,
                Optional.of(rowIdNulls),
                new Block[] {
                        makeLongArrayBlockWithNulls(rowIdNulls, 5, 2, 1, 2, 2), // TransactionId
                        makeLongArrayBlockWithNulls(rowIdNulls, 5, 0, 3, 1, 2), // rowId
                        makeIntArrayBlockWithNulls(rowIdNulls, 5, 536870912, 536870912, 536870912, 536870912)}, // bucket
                new Block[] {
                        // customer
                        makeVarcharArrayBlock("Aaron", "Carol", "Dave", "Dave", "Ed"),
                        // purchases
                        makeIntArrayBlock(17, 9, 11, 22, 14),
                        // address
                        makeVarcharArrayBlock("Arches/Arches", "Centreville", "Devon", "Darbyshire/Darbyshire", "Etherville/Etherville"),
                        // "present" boolean
                        makeByteArrayBlock(1, 0, 1, 1, 0),
                        // operation number: update, insert, delete, update, update
                        makeByteArrayBlock(3, 1, 2, 3, 3),
                        makeIntArrayBlock(0, -1, 1, 0, 0)});

        Page outputPage = processor.transformPage(inputPage);
        assertThat(outputPage.getPositionCount()).isEqualTo(8);
        RowBlock rowIdBlock = (RowBlock) outputPage.getBlock(4);
        assertThat(rowIdBlock.getPositionCount()).isEqualTo(8);
        // Show that the first row has address "Arches/Arches"
        assertThat(getString(outputPage.getBlock(2), 1)).isEqualTo("Arches/Arches");
    }

    // Assembles the two-column page shape the processor consumes: a row-id row block and a merge-case row block
    private Page makePageFromBlocks(int positionCount, Optional<boolean[]> rowIdNulls, Block[] rowIdBlocks, Block[] mergeCaseBlocks)
    {
        Block[] pageBlocks = new Block[] {
                RowBlock.fromFieldBlocks(positionCount, rowIdNulls, rowIdBlocks),
                RowBlock.fromFieldBlocks(positionCount, Optional.empty(), mergeCaseBlocks)
        };
        return new Page(pageBlocks);
    }

    private DeleteAndInsertMergeProcessor makeMergeProcessor()
    {
        // CREATE TABLE (customer VARCHAR, purchases INTEGER, address VARCHAR)
        List<Type> types = ImmutableList.of(VARCHAR, INTEGER, VARCHAR);

        RowType rowIdType = RowType.anonymous(ImmutableList.of(BIGINT, BIGINT, INTEGER));
        return new DeleteAndInsertMergeProcessor(types, rowIdType, 0, 1, ImmutableList.of(), ImmutableList.of(0, 1, 2));
    }

    private String getString(Block block, int position)
    {
        // The test data is ASCII, so decode explicitly as UTF-8 rather than
        // the platform default charset, which varies between machines
        return VARBINARY.getSlice(block, position).toString(StandardCharsets.UTF_8);
    }

    private LongArrayBlock makeLongArrayBlock(long... elements)
    {
        return new LongArrayBlock(elements.length, Optional.empty(), elements);
    }

    private LongArrayBlock makeLongArrayBlockWithNulls(boolean[] nulls, int positionCount, long... elements)
    {
        // One value is supplied for each non-null position
        assertThat(countNulls(nulls) + elements.length).isEqualTo(positionCount);
        return new LongArrayBlock(elements.length, Optional.of(nulls), elements);
    }

    private IntArrayBlock makeIntArrayBlock(int... elements)
    {
        return new IntArrayBlock(elements.length, Optional.empty(), elements);
    }

    private IntArrayBlock makeIntArrayBlockWithNulls(boolean[] nulls, int positionCount, int... elements)
    {
        // One value is supplied for each non-null position
        assertThat(countNulls(nulls) + elements.length).isEqualTo(positionCount);
        return new IntArrayBlock(elements.length, Optional.of(nulls), elements);
    }

    // Counts the true entries of the null-flags array, i.e. the null positions
    // (previously misnamed countNonNull)
    private int countNulls(boolean[] nulls)
    {
        int count = 0;
        for (int position = 0; position < nulls.length; position++) {
            if (nulls[position]) {
                count++;
            }
        }
        return count;
    }

    private ByteArrayBlock makeByteArrayBlock(int... elements)
    {
        byte[] bytes = new byte[elements.length];
        for (int index = 0; index < elements.length; index++) {
            bytes[index] = (byte) elements[index];
        }
        return new ByteArrayBlock(elements.length, Optional.empty(), bytes);
    }

    private Block makeVarcharArrayBlock(String... elements)
    {
        BlockBuilder builder = VARCHAR.createBlockBuilder(new PageBuilderStatus().createBlockBuilderStatus(), elements.length);
        for (String element : elements) {
            VARCHAR.writeSlice(builder, Slices.utf8Slice(element));
        }
        return builder.build();
    }
}
b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneMergeSourceColumns.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.spi.connector.SchemaTableName; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.assertions.PlanMatchPattern; +import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; +import io.trino.sql.planner.plan.MergeWriterNode; +import org.testng.annotations.Test; + +import static io.trino.sql.planner.assertions.PlanMatchPattern.node; +import static io.trino.sql.planner.assertions.PlanMatchPattern.strictProject; +import static io.trino.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPruneMergeSourceColumns + extends BaseRuleTest +{ + @Test + public void testPruneInputColumn() + { + tester().assertThat(new PruneMergeSourceColumns()) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol mergeRow = p.symbol("merge_row"); + Symbol rowId = p.symbol("row_id"); + Symbol partialRows = p.symbol("partial_rows"); + Symbol fragment = p.symbol("fragment"); + return p.merge( + new SchemaTableName("schema", "table"), + p.values(a, mergeRow, rowId), + mergeRow, + rowId, + ImmutableList.of(partialRows, fragment)); + }) + .matches( + node( + MergeWriterNode.class, + strictProject( + ImmutableMap.of( + "row_id", 
PlanMatchPattern.expression("row_id"), + "merge_row", PlanMatchPattern.expression("merge_row")), + values("a", "merge_row", "row_id")))); + } + + @Test + public void testDoNotPruneRowId() + { + tester().assertThat(new PruneMergeSourceColumns()) + .on(p -> { + Symbol mergeRow = p.symbol("merge_row"); + Symbol rowId = p.symbol("row_id"); + Symbol partialRows = p.symbol("partial_rows"); + Symbol fragment = p.symbol("fragment"); + return p.merge( + new SchemaTableName("schema", "table"), + p.values(mergeRow, rowId), + mergeRow, + rowId, + ImmutableList.of(partialRows, fragment)); + }) + .doesNotFire(); + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java index 0c57a4566466..2dcd30ed4b67 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java @@ -30,6 +30,7 @@ import io.trino.metadata.TableHandle; import io.trino.operator.RetryPolicy; import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SortOrder; import io.trino.spi.predicate.TupleDomain; @@ -69,6 +70,7 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.MarkDistinctNode; +import io.trino.sql.planner.plan.MergeWriterNode; import io.trino.sql.planner.plan.OffsetNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PatternRecognitionNode; @@ -90,6 +92,8 @@ import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TableWriterNode.CreateTarget; import io.trino.sql.planner.plan.TableWriterNode.DeleteTarget; +import io.trino.sql.planner.plan.TableWriterNode.MergeParadigmAndTypes; +import 
io.trino.sql.planner.plan.TableWriterNode.MergeTarget; import io.trino.sql.planner.plan.TableWriterNode.UpdateTarget; import io.trino.sql.planner.plan.TableWriterNode.WriterTarget; import io.trino.sql.planner.plan.TopNNode; @@ -127,6 +131,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; @@ -809,6 +814,29 @@ private UpdateTarget updateTarget(SchemaTableName schemaTableName, List .collect(toImmutableList())); } + public MergeWriterNode merge(SchemaTableName schemaTableName, PlanNode mergeSource, Symbol mergeRow, Symbol rowId, List outputs) + { + return new MergeWriterNode( + idAllocator.getNextId(), + mergeSource, + mergeTarget(schemaTableName), + ImmutableList.of(mergeRow, rowId), + Optional.empty(), + outputs); + } + + private MergeTarget mergeTarget(SchemaTableName schemaTableName) + { + return new MergeTarget( + new TableHandle( + TEST_CATALOG_HANDLE, + new TestingTableHandle(), + TestingTransactionHandle.create()), + Optional.empty(), + schemaTableName, + new MergeParadigmAndTypes(RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW, ImmutableList.of(), INTEGER)); + } + public ExchangeNode gatheringExchange(ExchangeNode.Scope scope, PlanNode child) { return exchange(builder -> builder.type(ExchangeNode.Type.GATHER) diff --git a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java index 1427fec5ed98..06f66c949a67 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java @@ -813,7 +813,7 @@ 
protected Void visitIntersect(Intersect node, Integer indent) protected Void visitMerge(Merge node, Integer indent) { builder.append("MERGE INTO ") - .append(node.getTable().getName()); + .append(node.getTargetTable().getName()); node.getTargetAlias().ifPresent(value -> builder .append(' ') @@ -822,11 +822,11 @@ protected Void visitMerge(Merge node, Integer indent) append(indent + 1, "USING "); - processRelation(node.getRelation(), indent + 2); + processRelation(node.getSource(), indent + 2); builder.append("\n"); append(indent + 1, "ON "); - builder.append(formatExpression(node.getExpression())); + builder.append(formatExpression(node.getPredicate())); for (MergeCase mergeCase : node.getMergeCases()) { builder.append("\n"); diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index de314ca7a176..e9465b64037c 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -560,10 +560,14 @@ public Node visitTruncateTable(SqlBaseParser.TruncateTableContext context) @Override public Node visitMerge(SqlBaseParser.MergeContext context) { + Table table = new Table(getLocation(context), getQualifiedName(context.qualifiedName())); + Relation targetRelation = table; + if (context.identifier() != null) { + targetRelation = new AliasedRelation(table, (Identifier) visit(context.identifier()), null); + } return new Merge( getLocation(context), - new Table(getLocation(context), getQualifiedName(context.qualifiedName())), - visitIfPresent(context.identifier(), Identifier.class), + targetRelation, (Relation) visit(context.relation()), (Expression) visit(context.expression()), visit(context.mergeCase(), MergeCase.class)); diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java b/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java 
index 46b78c9be354..7ac60009dd53 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java @@ -678,10 +678,9 @@ protected Void visitUpdateAssignment(UpdateAssignment node, C context) @Override protected Void visitMerge(Merge node, C context) { - process(node.getTable(), context); - node.getTargetAlias().ifPresent(target -> process(target, context)); - process(node.getRelation(), context); - process(node.getExpression(), context); + process(node.getTarget(), context); + process(node.getSource(), context); + process(node.getPredicate(), context); node.getMergeCases().forEach(mergeCase -> process(mergeCase, context)); return null; } diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/Join.java b/core/trino-parser/src/main/java/io/trino/sql/tree/Join.java index 1be838e34cc1..ab9edaeb67a2 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/Join.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/Join.java @@ -36,7 +36,7 @@ public Join(NodeLocation location, Type type, Relation left, Relation right, Opt this(Optional.of(location), type, left, right, criteria); } - private Join(Optional location, Type type, Relation left, Relation right, Optional criteria) + public Join(Optional location, Type type, Relation left, Relation right, Optional criteria) { super(location); requireNonNull(left, "left is null"); diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/Merge.java b/core/trino-parser/src/main/java/io/trino/sql/tree/Merge.java index 0b2fb29b0135..ee07a29d0af3 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/Merge.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/Merge.java @@ -20,78 +20,78 @@ import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; import static 
java.util.Objects.requireNonNull; public final class Merge extends Statement { - private final Table table; - private final Optional targetAlias; - private final Relation relation; - private final Expression expression; + private final Relation target; + private final Relation source; + private final Expression predicate; private final List mergeCases; - public Merge( - Table table, - Optional targetAlias, - Relation relation, - Expression expression, - List mergeCases) - - { - this(Optional.empty(), table, targetAlias, relation, expression, mergeCases); - } - public Merge( NodeLocation location, - Table table, - Optional targetAlias, - Relation relation, - Expression expression, + Relation target, + Relation source, + Expression predicate, List mergeCases) { - this(Optional.of(location), table, targetAlias, relation, expression, mergeCases); + this(Optional.of(location), target, source, predicate, mergeCases); } public Merge( Optional location, - Table table, - Optional targetAlias, - Relation relation, - Expression expression, + Relation target, + Relation source, + Expression predicate, List mergeCases) { super(location); - this.table = requireNonNull(table, "table is null"); - this.targetAlias = requireNonNull(targetAlias, "targetAlias is null"); - this.relation = requireNonNull(relation, "relation is null"); - this.expression = requireNonNull(expression, "expression is null"); + // Check that the target is either a Table or an AliasedRelation + this.target = requireNonNull(target, "target is null"); + checkArgument(target instanceof Table || target instanceof AliasedRelation, "target (%s) is neither a Table nor an AliasedRelation", target); + this.source = requireNonNull(source, "source is null"); + this.predicate = requireNonNull(predicate, "predicate is null"); + this.mergeCases = ImmutableList.copyOf(requireNonNull(mergeCases, "mergeCases is null")); } - public Table getTable() + public Relation getTarget() { - return table; + return target; } - public Optional 
getTargetAlias() + public Relation getSource() { - return targetAlias; + return source; } - public Relation getRelation() + public Expression getPredicate() { - return relation; + return predicate; } - public Expression getExpression() + public List getMergeCases() { - return expression; + return mergeCases; } - public List getMergeCases() + public Table getTargetTable() { - return mergeCases; + if (target instanceof Table) { + return (Table) target; + } + checkArgument(target instanceof AliasedRelation, "MERGE relation is neither a Table nor an AliasedRelation"); + return (Table) ((AliasedRelation) target).getRelation(); + } + + public Optional getTargetAlias() + { + if (target instanceof AliasedRelation) { + return Optional.of(((AliasedRelation) target).getAlias()); + } + return Optional.empty(); } @Override @@ -104,9 +104,9 @@ public R accept(AstVisitor visitor, C context) public List getChildren() { ImmutableList.Builder builder = ImmutableList.builder(); - builder.add(table); - builder.add(relation); - builder.add(expression); + builder.add(target); + builder.add(source); + builder.add(predicate); builder.addAll(mergeCases); return builder.build(); } @@ -121,27 +121,25 @@ public boolean equals(Object o) return false; } Merge merge = (Merge) o; - return Objects.equals(table, merge.table) && - Objects.equals(targetAlias, merge.targetAlias) && - Objects.equals(relation, merge.relation) && - Objects.equals(expression, merge.expression) && + return Objects.equals(target, merge.target) && + Objects.equals(source, merge.source) && + Objects.equals(predicate, merge.predicate) && Objects.equals(mergeCases, merge.mergeCases); } @Override public int hashCode() { - return Objects.hash(table, targetAlias, relation, expression, mergeCases); + return Objects.hash(target, source, predicate, mergeCases); } @Override public String toString() { return toStringHelper(this) - .add("table", table) - .add("targetAlias", targetAlias.orElse(null)) - .add("relation", relation) - 
.add("expression", expression) + .add("target", target) + .add("relation", source) + .add("expression", predicate) .add("mergeCases", mergeCases) .omitNullValues() .toString(); diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java index d1c9ad99bf81..c69a96e1ab8c 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java @@ -1849,6 +1849,7 @@ public void testDelete() @Test public void testMerge() { + NodeLocation location = new NodeLocation(1, 1); assertStatement("" + "MERGE INTO inventory AS i " + " USING changes AS c " + @@ -1862,8 +1863,8 @@ public void testMerge() "WHEN NOT MATCHED AND c.action = 'new' " + " THEN INSERT (part, qty) VALUES (c.part, c.qty)", new Merge( - table(QualifiedName.of("inventory")), - Optional.of(new Identifier("i")), + location, + new AliasedRelation(location, table(QualifiedName.of("inventory")), new Identifier("i"), null), aliased(table(QualifiedName.of("changes")), "c"), equal(nameReference("i", "part"), nameReference("c", "part")), ImmutableList.of( diff --git a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java index c24cd137a880..b98424d0b2d9 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java +++ b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java @@ -140,6 +140,7 @@ public enum StandardErrorCode PATH_EVALUATION_ERROR(116, USER_ERROR), INVALID_JSON_LITERAL(117, USER_ERROR), JSON_VALUE_RESULT_ERROR(118, USER_ERROR), + MERGE_TARGET_ROW_MULTIPLE_MATCHES(119, USER_ERROR), GENERIC_INTERNAL_ERROR(65536, INTERNAL_ERROR), TOO_MANY_REQUESTS_FAILED(65537, INTERNAL_ERROR), diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMergeSink.java 
b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMergeSink.java new file mode 100644 index 000000000000..0638f4291877 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMergeSink.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.connector; + +import io.airlift.slice.Slice; +import io.trino.spi.Page; + +import java.util.Collection; +import java.util.concurrent.CompletableFuture; + +public interface ConnectorMergeSink +{ + int INSERT_OPERATION_NUMBER = 1; + int DELETE_OPERATION_NUMBER = 2; + int UPDATE_OPERATION_NUMBER = 3; + + /** + * Store the page resulting from a merge. The page consists of {@code n} channels, numbered {@code 0..n-1}: + *

+ * <ul>
+ *     <li>Blocks {@code 0..n-3} in page are the data columns</li>
+ *     <li>Block {@code n-2} is the tinyint operation:
+ *         <ul>
+ *             <li>{@link #INSERT_OPERATION_NUMBER}</li>
+ *             <li>{@link #DELETE_OPERATION_NUMBER}</li>
+ *             <li>{@link #UPDATE_OPERATION_NUMBER}</li>
+ *         </ul>
+ *     </li>
+ *     <li>Block {@code n-1} is a connector-specific rowId column, whose handle was previously returned by
+ *     {@link ConnectorMetadata#getMergeRowIdColumnHandle(ConnectorSession, ConnectorTableHandle) getMergeRowIdColumnHandle()}</li>
+ * </ul>
+ *
+ * @param page The page to store. + */ + void storeMergedRows(Page page); + + CompletableFuture> finish(); + + default void abort() {} +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMergeTableHandle.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMergeTableHandle.java new file mode 100644 index 000000000000..c73c6a4f042a --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMergeTableHandle.java @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.connector; + +public interface ConnectorMergeTableHandle +{ + /** + * This method is required because the {@link ConnectorTableHandle} returned by + * {@link ConnectorMetadata#beginMerge} is in general different than the + * one passed to that method, but the updated handle must be made + * available to {@link ConnectorMetadata#finishMerge} + * + * @return the {@link ConnectorTableHandle} returned by {@link ConnectorMetadata#beginMerge} + */ + ConnectorTableHandle getTableHandle(); +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java index df92507eb8a5..d75259abe25d 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java @@ -698,6 +698,55 @@ default void finishUpdate(ConnectorSession session, ConnectorTableHandle tableHa throw new TrinoException(NOT_SUPPORTED, "This connector does not support updates"); } + /** + * Return the row change paradigm supported by the connector on the table. + */ + default RowChangeParadigm getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle) + { + throw new TrinoException(NOT_SUPPORTED, "This connector does not support merges"); + } + + /** + * Get the column handle that will generate row IDs for the merge operation. + * These IDs will be passed to the {@link ConnectorMergeSink#storeMergedRows} + * method of the {@link ConnectorMergeSink} that created them. + */ + default ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle) + { + throw new TrinoException(NOT_SUPPORTED, "This connector does not support merges"); + } + + /** + * Get the physical layout for updated or deleted rows of a MERGE operation. + * Inserted rows are handled by {@link #getInsertLayout}. 
+ * This layout always uses the {@link #getMergeRowIdColumnHandle merge row ID column}. + */ + default Optional getUpdateLayout(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return Optional.empty(); + } + + /** + * Do whatever is necessary to start a MERGE query, returning the {@link ConnectorMergeTableHandle} + * instance that will be passed to the PageSink, and to the {@link #finishMerge} method. + */ + default ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, RetryMode retryMode) + { + throw new TrinoException(NOT_SUPPORTED, "This connector does not support merges"); + } + + /** + * Finish a merge query + * @param session The session + * @param tableHandle A ConnectorMergeTableHandle for the table that is the target of the merge + * @param fragments All fragments returned by {@link UpdatablePageSource#finish()} + * @param computedStatistics Statistics for the table, meaningful only to the connector that produced them. + */ + default void finishMerge(ConnectorSession session, ConnectorMergeTableHandle tableHandle, Collection fragments, Collection computedStatistics) + { + throw new TrinoException(GENERIC_INTERNAL_ERROR, "ConnectorMetadata beginMerge() is implemented without finishMerge()"); + } + /** * Create the specified view. The view definition is intended to * be serialized by the connector for permanent storage. 
diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorNodePartitioningProvider.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorNodePartitioningProvider.java index da4b951bd035..88e0168bbc01 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorNodePartitioningProvider.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorNodePartitioningProvider.java @@ -16,13 +16,31 @@ import io.trino.spi.type.Type; import java.util.List; +import java.util.Optional; import java.util.function.ToIntFunction; public interface ConnectorNodePartitioningProvider { - ConnectorBucketNodeMap getBucketNodeMap(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle); + /** + * @deprecated use {@link #getBucketNodeMapping} + */ + @Deprecated + default ConnectorBucketNodeMap getBucketNodeMap(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) + { + return null; + } - ToIntFunction getSplitBucketFunction(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle); + default Optional getBucketNodeMapping(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) + { + return Optional.ofNullable(getBucketNodeMap(transactionHandle, session, partitioningHandle)); + } + + default ToIntFunction getSplitBucketFunction(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) + { + return split -> { + throw new UnsupportedOperationException(); + }; + } BucketFunction getBucketFunction( ConnectorTransactionHandle transactionHandle, diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSinkProvider.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSinkProvider.java index 
2b7363621f26..9a0655d7396e 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSinkProvider.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSinkProvider.java @@ -13,6 +13,10 @@ */ package io.trino.spi.connector; +import io.trino.spi.TrinoException; + +import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; + public interface ConnectorPageSinkProvider { ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorOutputTableHandle outputTableHandle); @@ -23,4 +27,9 @@ default ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionH { throw new IllegalArgumentException("createPageSink not supported for tableExecuteHandle"); } + + default ConnectorMergeSink createMergeSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorMergeTableHandle mergeHandle) + { + throw new TrinoException(NOT_SUPPORTED, "This connector does not support SQL MERGE operations"); + } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/MergePage.java b/core/trino-spi/src/main/java/io/trino/spi/connector/MergePage.java new file mode 100644 index 000000000000..5d16c9be4920 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/MergePage.java @@ -0,0 +1,116 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.connector; + +import io.trino.spi.Page; +import io.trino.spi.block.Block; + +import java.util.Optional; +import java.util.stream.IntStream; + +import static io.trino.spi.connector.ConnectorMergeSink.DELETE_OPERATION_NUMBER; +import static io.trino.spi.connector.ConnectorMergeSink.INSERT_OPERATION_NUMBER; +import static io.trino.spi.type.TinyintType.TINYINT; +import static java.lang.Math.toIntExact; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +/** + * Separate deletions and insertions pages from a merge using + * {@link RowChangeParadigm#DELETE_ROW_AND_INSERT_ROW}. + */ +public final class MergePage +{ + private final Optional deletionsPage; + private final Optional insertionsPage; + + private MergePage(Optional deletionsPage, Optional insertionsPage) + { + this.deletionsPage = requireNonNull(deletionsPage); + this.insertionsPage = requireNonNull(insertionsPage); + } + + /** + * @return delete page with data columns followed by row ID column + */ + public Optional getDeletionsPage() + { + return deletionsPage; + } + + /** + * @return insert page with data columns + */ + public Optional getInsertionsPage() + { + return insertionsPage; + } + + public static MergePage createDeleteAndInsertPages(Page inputPage, int dataColumnCount) + { + // see page description in ConnectorMergeSink + int inputChannelCount = inputPage.getChannelCount(); + if (inputChannelCount != dataColumnCount + 2) { + throw new IllegalArgumentException(format("inputPage channelCount (%s) == dataColumns size (%s) + 2", inputChannelCount, dataColumnCount)); + } + + int positionCount = inputPage.getPositionCount(); + if (positionCount <= 0) { + throw new IllegalArgumentException("positionCount should be > 0, but is " + positionCount); + } + Block operationBlock = inputPage.getBlock(inputChannelCount - 2); + + int[] deletePositions = new int[positionCount]; + int[] insertPositions = new int[positionCount]; + int deletePositionCount 
= 0; + int insertPositionCount = 0; + + for (int position = 0; position < positionCount; position++) { + int operation = toIntExact(TINYINT.getLong(operationBlock, position)); + switch (operation) { + case DELETE_OPERATION_NUMBER: + deletePositions[deletePositionCount] = position; + deletePositionCount++; + break; + case INSERT_OPERATION_NUMBER: + insertPositions[insertPositionCount] = position; + insertPositionCount++; + break; + default: + throw new IllegalArgumentException("Invalid merge operation: " + operation); + } + } + + Optional deletePage = Optional.empty(); + if (deletePositionCount > 0) { + int[] columns = new int[dataColumnCount + 1]; + for (int i = 0; i < dataColumnCount; i++) { + columns[i] = i; + } + columns[dataColumnCount] = dataColumnCount + 1; // row ID channel + deletePage = Optional.of(inputPage + .getColumns(columns) + .getPositions(deletePositions, 0, deletePositionCount)); + } + + Optional insertPage = Optional.empty(); + if (insertPositionCount > 0) { + insertPage = Optional.of(inputPage + .getColumns(IntStream.range(0, dataColumnCount).toArray()) + .getPositions(insertPositions, 0, insertPositionCount)); + } + + return new MergePage(deletePage, insertPage); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/RowChangeParadigm.java b/core/trino-spi/src/main/java/io/trino/spi/connector/RowChangeParadigm.java new file mode 100644 index 000000000000..46805ce34b4f --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/RowChangeParadigm.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.connector; + +/** + * Different connectors have different ways of representing row updates, + * imposed by the underlying storage systems. The Trino engine classifies + * these different paradigms as elements of this RowChangeParadigm + * enumeration, returned by {@link ConnectorMetadata#getRowChangeParadigm} + */ +public enum RowChangeParadigm +{ + /** + * A storage paradigm in which the connector can update individual columns + * of rows identified by a rowId. The corresponding merge processor class is + * {@code ChangeOnlyUpdatedColumnsMergeProcessor} + */ + CHANGE_ONLY_UPDATED_COLUMNS, + + /** + * A paradigm that translates a changed row into a delete by rowId, and an insert of a + * new record, which will get a new rowId when the connector writes it out. The + * corresponding merge processor class is {@code DeleteAndInsertMergeProcessor}. + */ + DELETE_ROW_AND_INSERT_ROW, +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/resourcegroups/QueryType.java b/core/trino-spi/src/main/java/io/trino/spi/resourcegroups/QueryType.java index 7ee8ab593075..5bd3004656f7 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/resourcegroups/QueryType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/resourcegroups/QueryType.java @@ -24,4 +24,5 @@ public enum QueryType ANALYZE, DATA_DEFINITION, ALTER_TABLE_EXECUTE, + MERGE, } diff --git a/core/trino-spi/src/test/java/io/trino/spi/TestSpiBackwardCompatibility.java b/core/trino-spi/src/test/java/io/trino/spi/TestSpiBackwardCompatibility.java index 399b64d48d29..c4c4529421c6 100644 --- a/core/trino-spi/src/test/java/io/trino/spi/TestSpiBackwardCompatibility.java +++ b/core/trino-spi/src/test/java/io/trino/spi/TestSpiBackwardCompatibility.java @@ -58,16 +58,9 @@ public class TestSpiBackwardCompatibility .put("123", "Method: public void 
io.trino.spi.predicate.BenchmarkSortedRangeSet$Data.init()") // example .put("123", "Field: public java.util.List io.trino.spi.predicate.BenchmarkSortedRangeSet$Data.ranges") - .put("377", "Constructor: public io.trino.spi.memory.MemoryPoolInfo(long,long,long,java.util.Map,java.util.Map>,java.util.Map)") - .put("383", "Method: public abstract java.lang.String io.trino.spi.function.AggregationState.value()") - .put("383", "Method: public default void io.trino.spi.security.SystemAccessControl.checkCanExecuteFunction(io.trino.spi.security.SystemSecurityContext,io.trino.spi.connector.CatalogSchemaRoutineName)") - .put("383", "Method: public default void io.trino.spi.connector.ConnectorAccessControl.checkCanExecuteFunction(io.trino.spi.connector.ConnectorSecurityContext,io.trino.spi.connector.SchemaRoutineName)") - .put("384", "Constructor: public io.trino.spi.eventlistener.QueryInputMetadata(java.lang.String,java.lang.String,java.lang.String,java.util.List,java.util.Optional,java.util.OptionalLong,java.util.OptionalLong)") - .put("386", "Method: public default java.util.stream.Stream io.trino.spi.connector.ConnectorMetadata.streamTableColumns(io.trino.spi.connector.ConnectorSession,io.trino.spi.connector.SchemaTablePrefix)") - .put("386", "Method: public default boolean io.trino.spi.connector.ConnectorMetadata.isSupportedVersionType(io.trino.spi.connector.ConnectorSession,io.trino.spi.connector.SchemaTableName,io.trino.spi.connector.PointerType,io.trino.spi.type.Type)") - .put("387", "Constructor: public io.trino.spi.eventlistener.QueryContext(java.lang.String,java.util.Optional,java.util.Set,java.util.Optional,java.util.Optional,java.util.Optional,java.util.Optional,java.util.Set,java.util.Set,java.util.Optional,java.util.Optional,java.util.Optional,java.util.Optional,java.util.Map,io.trino.spi.session.ResourceEstimates,java.lang.String,java.lang.String,java.lang.String,java.util.Optional)") - .put("388", "Method: public abstract java.util.concurrent.CompletableFuture 
io.trino.spi.connector.ConnectorSplitSource.getNextBatch(io.trino.spi.connector.ConnectorPartitionHandle,int)") - .put("388", "Method: public java.util.concurrent.CompletableFuture io.trino.spi.connector.FixedSplitSource.getNextBatch(io.trino.spi.connector.ConnectorPartitionHandle,int)") + // changes + .put("393", "Method: public abstract io.trino.spi.connector.ConnectorBucketNodeMap io.trino.spi.connector.ConnectorNodePartitioningProvider.getBucketNodeMap(io.trino.spi.connector.ConnectorTransactionHandle,io.trino.spi.connector.ConnectorSession,io.trino.spi.connector.ConnectorPartitioningHandle)") + .put("393", "Method: public abstract java.util.function.ToIntFunction io.trino.spi.connector.ConnectorNodePartitioningProvider.getSplitBucketFunction(io.trino.spi.connector.ConnectorTransactionHandle,io.trino.spi.connector.ConnectorSession,io.trino.spi.connector.ConnectorPartitioningHandle)") .build(); @Test diff --git a/docs/src/main/sphinx/develop.rst b/docs/src/main/sphinx/develop.rst index fad497456e27..b4a6799ac524 100644 --- a/docs/src/main/sphinx/develop.rst +++ b/docs/src/main/sphinx/develop.rst @@ -12,6 +12,7 @@ This guide is intended for Trino contributors and plugin developers. develop/example-http develop/insert develop/delete-and-update + develop/supporting-merge develop/types develop/functions develop/table-functions diff --git a/docs/src/main/sphinx/develop/supporting-merge.rst b/docs/src/main/sphinx/develop/supporting-merge.rst new file mode 100644 index 000000000000..10d75bceb05f --- /dev/null +++ b/docs/src/main/sphinx/develop/supporting-merge.rst @@ -0,0 +1,415 @@ +==================== +Supporting ``MERGE`` +==================== + +The Trino engine provides APIs to support row-level SQL ``MERGE``. 
+To implement ``MERGE``, a connector must provide an implementation +of ``ConnectorMergeSink``, which is typically layered on top of a +``ConnectorPageSink``, and define ``ConnectorMetadata`` +methods to get a "rowId" column handle; get the row change paradigm; +and to start and complete the ``MERGE`` operation. + +Standard SQL ``MERGE`` +---------------------- + +Different query engines support varying definitions of SQL ``MERGE``. +Trino supports the strict SQL specification ``ISO/IEC 9075``, published +in 2016. As a simple example, given tables ``target_table`` and +``source_table`` defined as:: + + CREATE TABLE accounts ( + customer VARCHAR, + purchases DECIMAL, + address VARCHAR); + INSERT INTO accounts (customer, purchases, address) VALUES ...; + CREATE TABLE monthly_accounts_update ( + customer VARCHAR, + purchases DECIMAL, + address VARCHAR); + INSERT INTO monthly_accounts_update (customer, purchases, address) VALUES ...; + +Here is a possible ``MERGE`` operation, from ``monthly_accounts_update`` to +``accounts``:: + + MERGE INTO accounts t USING monthly_accounts_update s + ON (t.customer = s.customer) + WHEN MATCHED AND s.address = 'Berkeley' THEN + DELETE + WHEN MATCHED AND s.customer = 'Joe Shmoe' THEN + UPDATE SET purchases = purchases + 100.0 + WHEN MATCHED THEN + UPDATE + SET purchases = s.purchases + t.purchases, address = s.address + WHEN NOT MATCHED THEN + INSERT (customer, purchases, address) + VALUES (s.customer, s.purchases, s.address); + +SQL ``MERGE`` tries to match each ``WHEN`` clause in source order. When +a match is found, the corresponding ``DELETE``, ``INSERT`` or ``UPDATE`` +is executed and subsequent ``WHEN`` clauses are ignored. + +SQL ``MERGE`` supports two operations on the target table and source +when a row from the source table or query matches a row in the target table: + +* ``UPDATE``, in which the columns in the target row are updated. +* ``DELETE``, in which the target row is deleted. 
+ +In the ``NOT MATCHED`` case, SQL ``MERGE`` supports only ``INSERT`` +operations. The values inserted are arbitrary but usually come from +the unmatched row of the source table or query. + +``RowChangeParadigm`` +--------------------- + +Different connectors have different ways of representing row updates, +imposed by the underlying storage systems. The Trino engine classifies +these different paradigms as elements of the ``RowChangeParadigm`` +enumeration, returned by the method +``ConnectorMetadata.getRowChangeParadigm(...)``. + +The ``RowChangeParadigm`` enumeration values are: + +* ``CHANGE_ONLY_UPDATED_COLUMNS``, intended for connectors that can update + individual columns of rows identified by a ``rowId``. The corresponding + merge processor class is ``ChangeOnlyUpdatedColumnsMergeProcessor``. +* ``DELETE_ROW_AND_INSERT_ROW``, intended for connectors that represent a + row change as a row deletion paired with a row insertion. The corresponding + merge processor class is ``DeleteAndInsertMergeProcessor``. + +Overview of ``MERGE`` processing +-------------------------------- + +A ``MERGE`` statement is processed by creating a ``RIGHT JOIN`` between the +target table and the source, on the ``MERGE`` criteria. The source may be +a table or an arbitrary query. For each row in the source table or query, +``MERGE`` produces a ``ROW`` object containing: + +* the data column values from the ``UPDATE`` or ``INSERT`` cases. For the + ``DELETE`` cases, only the partition columns, which determine + partitioning and bucketing, are non-NULL. +* a boolean column containing ``true`` for source rows that matched some + target row, and ``false`` otherwise. +* an integer that identifies whether the merge case operation is ``UPDATE``, + ``DELETE`` or ``INSERT``, or a source row for which no case matched.
If a + source row does not match any merge case, all data column values except + those that determine distribution are null, and the operation number + is -1. + +A ``SearchedCaseExpression`` is constructed from ``RIGHT JOIN`` result +to represent the ``WHEN`` clauses of the ``MERGE``. In the example above +the ``MERGE`` is executed as if the ``SearchedCaseExpression`` were written as:: + + SELECT + CASE + WHEN present AND s.address = 'Berkeley' THEN + -- Null values for delete; present=true; operation DELETE=2, case_number=0 + row(null, null, null, false, 2, 0) + WHEN present AND s.customer = 'Joe Shmoe' THEN + -- Update column values; present=true; operation UPDATE=3, case_number=1 + row(t.customer, t.purchases + 100.0, t.address, true, 3, 1) + WHEN present THEN + -- Update column values; present=true; operation UPDATE=3, case_number=2 + row(t.customer, s.purchases + t.purchases, s.address, true, 3, 2) + WHEN (present IS NULL) THEN + -- Insert column values; present=false; operation INSERT=1, case_number=3 + row(s.customer, s.purchases, s.address, false, 1, 3) + ELSE + -- Null values for no case matched; present=false; operation=-1, + -- case_number=-1 + row(null, null, null, false, -1, -1) + END + FROM (SELECT *, true AS present FROM target_table) t + RIGHT JOIN source_table s ON s.customer = t.customer; + +The Trino engine executes the ``RIGHT JOIN`` and ``CASE`` expression, +and ensures that no target table row matches more than one source expression +row, and ultimately creates a sequence of pages to be routed to the node that +runs the ``ConnectorMergeSink.storeMergedRows(...)`` method. + +Like ``DELETE`` and ``UPDATE``, ``MERGE`` target table rows are identified by +a connector-specific ``rowId`` column handle. For ``MERGE``, the ``rowId`` +handle is returned by ``ConnectorMetadata.getMergeRowIdColumnHandle(...)``. 
+ +``MERGE`` Redistribution +------------------------ + +The Trino ``MERGE`` implementation allows ``UPDATE`` to change +the values of columns that determine partitioning and/or bucketing, and so +it must "redistribute" rows from the ``MERGE`` operation to the worker +nodes responsible for writing rows with the merged partitioning and/or +bucketing columns. + +Since the ``MERGE`` process in general requires redistribution of +merged rows among Trino nodes, the order of rows in pages to be stored +is indeterminate. Connectors like Hive that depend on an ascending +rowId order for deleted rows must sort the deleted rows before storing +them. + +To ensure that all inserted rows for a given partition end up on a +single node, the redistribution hash on the partition key/bucket column(s) +is applied to the page partition key(s). As a result of the hash, all +rows for a specific partition/bucket hash together, whether they +were ``MATCHED`` rows or ``NOT MATCHED`` rows. + +For connectors whose ``RowChangeParadigm`` is ``DELETE_ROW_AND_INSERT_ROW``, +inserted rows are distributed using the layout supplied by +``ConnectorMetadata.getInsertLayout()``. For some connectors, the same +layout is used for updated rows. Other connectors require a special +layout for updated rows, supplied by ``ConnectorMetadata.getUpdateLayout()``. + +Connector support for ``MERGE`` +=============================== + +To start ``MERGE`` processing, the Trino engine calls: + +* ``ConnectorMetadata.getMergeRowIdColumnHandle(...)`` to get the + ``rowId`` column handle. +* ``ConnectorMetadata.getRowChangeParadigm(...)`` to get the paradigm + supported by the connector for changing existing table rows. +* ``ConnectorMetadata.beginMerge(...)`` to get a + ``ConnectorMergeTableHandle`` for the merge operation. That + ``ConnectorMergeTableHandle`` object contains whatever information the + connector needs to specify the ``MERGE`` operation.
+* ``ConnectorMetadata.getInsertLayout(...)``, from which it extracts + the list of partition or table columns that impact write redistribution. +* ``ConnectorMetadata.getUpdateLayout(...)``. If that layout is non-empty, + it is used to distribute updated rows resulting from the ``MERGE`` + operation. + +On nodes that are targets of the hash, the Trino engine calls +``ConnectorPageSinkProvider.createMergeSink(...)`` to create a +``ConnectorMergeSink``. + +To write out each page of merged rows, the Trino engine calls +``ConnectorMergeSink.storeMergedRows(Page)``. The ``storeMergedRows(Page)`` +method iterates over the rows in the page, performing updates and deletes +in the ``MATCHED`` cases, and inserts in the ``NOT MATCHED`` cases. + +For some ``RowChangeParadigm``s, ``UPDATE`` operations are translated into the +corresponding ``DELETE`` and ``INSERT`` operations before +``storeMergedRows(Page)`` is called. + +To complete the ``MERGE`` operation, the Trino engine calls +``ConnectorMetadata.finishMerge(...)``, passing the table handle +and a collection of JSON objects encoded as ``Slice`` instances. These +objects contain connector-specific information specifying what was changed +by the ``MERGE`` operation. Typically this JSON object contains the files +written and table and partition statistics generated by the ``MERGE`` +operation. The connector takes appropriate actions, if any. + +``RowChangeProcessor`` implementation for ``MERGE`` +--------------------------------------------------- + +In the ``MERGE`` implementation, each ``RowChangeParadigm`` +corresponds to an internal Trino engine class that implements interface +``RowChangeProcessor``. ``RowChangeProcessor`` has one interesting method: +``Page transformPage(Page)``. The format of the output page depends +on the ``RowChangeParadigm``.
+ +The connector has no access to the ``RowChangeProcessor`` instance -- it +is used inside the Trino engine to transform the merge page rows into rows +to be stored, based on the connector's choice of ``RowChangeParadigm``. + +The page supplied to ``transformPage()`` consists of: + +* The write redistribution columns if any +* For partitioned or bucketed tables, a long hash value column. +* The ``rowId`` column for the row from the target table if matched, or + null if not matched +* The merge case ``RowBlock`` +* The integer case number block +* The byte is_distinct block, with value 0 if not distinct. + +The merge case ``RowBlock`` has the following layout: + +* Blocks for each column in the table, including partition columns, in + table column order. +* A block containing the boolean "present" value which is true if the + source row matched a target row, and false otherwise. +* A block containing the ``MERGE`` case operation number, encoded as + ``INSERT`` = 1, ``DELETE`` = 2, ``UPDATE`` = 3 and if no ``MERGE`` + case matched, -1. +* A block containing the number, starting with 0, for the + ``WHEN`` clause that matched for the row, or -1 if no clause + matched. + +The page returned from ``transformPage`` consists of: + +* All table columns, in table column order. +* The merge case operation block. +* The rowId block. +* A byte block containing 1 if the row is an insert derived from an + update operation, and 0 otherwise. This block is used to correctly + calculate the count of rows changed for connectors that represent + updates and deletes plus inserts. + +``transformPage`` +must ensure that there are no rows whose operation number is -1 in +the page it returns. + +Detecting duplicate matching target rows +---------------------------------------- + +The SQL ``MERGE`` specification requires that in each ``MERGE`` case, +a single target table row must match at most one source row, after +applying the ``MERGE`` case condition expression. 
The first step +toward finding these errors is done by labeling each row in the target +table with a unique id, using an ``AssignUniqueId`` node above the +target table scan. The projected results from the ``RIGHT JOIN`` +have these unique ids for matched target table rows as well as +the ``WHEN`` clause number. A ``MarkDistinct`` node adds an +"is_distinct" column which is true if no other row has the same +unique id and ``WHEN`` clause number, and false otherwise. If +any row has "is_distinct" = false, a +``MERGE_TARGET_ROW_MULTIPLE_MATCHES`` exception is raised and +the ``MERGE`` operation fails. + +``ConnectorMergeTableHandle`` API +--------------------------------- + +Interface ``ConnectorMergeTableHandle`` defines one method, +``getTableHandle()`` to retrieve the ``ConnectorTableHandle`` +originally passed to ``ConnectorMetadata.beginMerge()``. + +``ConnectorPageSinkProvider`` API +--------------------------------- + +To support SQL ``MERGE``, ``ConnectorPageSinkProvider`` must implement +the method that creates the ``ConnectorMergeSink``: + +* ``createMergeSink``:: + + ConnectorMergeSink createMergeSink( + ConnectorTransactionHandle transactionHandle, + ConnectorSession session, + ConnectorMergeTableHandle mergeHandle) + +``ConnectorMergeSink`` API + -------------------------- + +As mentioned above, to support ``MERGE``, the connector must define an +implementation of ``ConnectorMergeSink``, usually layered over the +connector's ``ConnectorPageSink``. + +The ``ConnectorMergeSink`` is created by a call to +``ConnectorPageSinkProvider.createMergeSink()``. + +The only interesting methods are: + +* ``storeMergedRows``:: + + void storeMergedRows(Page page) + + The Trino engine calls the ``storeMergedRows(Page)`` method of the + ``ConnectorMergeSink`` instance returned by + ``ConnectorPageSinkProvider.createMergeSink()``, passing the page + generated by the ``RowChangeProcessor.transformPage()`` method.
+ That page consists of all table columns, in table column order, + followed by the rowId column, followed by the operation column + from the merge case ``RowBlock``. + + The job of ``storeMergedRows()`` is to iterate over the rows in the page, + and process them based on the value of the operation column, ``INSERT``, + ``DELETE``, ``UPDATE``, or ignore the row. By choosing the appropriate + paradigm, the connector can request that the UPDATE operation be + transformed into ``DELETE`` and ``INSERT`` operations. + +* ``finish``:: + + ``CompletableFuture> finish()`` + + The Trino engine calls ``finish()`` when all the data has been processed by + a specific ``ConnectorMergeSink`` instance. The connector returns a future + containing a collection of ``Slice``, representing connector-specific + information about the rows processed. Usually this includes the row count, + and might include information like the files or partitions created or + changed. + +``ConnectorMetadata`` ``MERGE`` API +----------------------------------- + +A connector implementing ``MERGE`` must implement these ``ConnectorMetadata`` +methods. + +* ``getRowChangeParadigm()``:: + + RowChangeParadigm getRowChangeParadigm( + ConnectorSession session, + ConnectorTableHandle tableHandle) + + This method is called as the engine starts processing a ``MERGE`` statement. + The connector must return a ``RowChangeParadigm`` enum instance. If the + connector does not support ``MERGE`` it should throw a ``NOT_SUPPORTED`` + exception, meaning that SQL ``MERGE`` is not supported by the connector. + +* ``getMergeRowIdColumnHandle()``:: + + ColumnHandle getMergeRowIdColumnHandle( + ConnectorSession session, + ConnectorTableHandle tableHandle) + + This method is called in the early stages of query planning for ``MERGE`` + statements.
The ColumnHandle returned provides the ``rowId`` used by the + connector to identify rows to be merged, as well as any other fields of + the row that the connector needs to complete the ``MERGE`` operation. + +* ``getInsertLayout()``:: + + Optional getInsertLayout( + ConnectorSession session, + ConnectorTableHandle tableHandle) + + This method is called during query planning to get the table layout to be + used for rows inserted by the ``MERGE`` operation. For some connectors, + this layout will be used for rows deleted as well. + +* ``getUpdateLayout()``:: + + Optional getUpdateLayout( + ConnectorSession session, + ConnectorTableHandle tableHandle) + + This method is called during query planning to get the table layout to + be used for rows deleted by the ``MERGE`` operation. If the optional + return value is present, the Trino engine will use the layout for + updated rows. Otherwise, it will use the result of + ``ConnectorMetadata.getInsertLayout`` to distribute updated rows. + +* ``beginMerge()``:: + + ConnectorMergeTableHandle beginMerge( + ConnectorSession session, + ConnectorTableHandle tableHandle, + MergeDetails mergeDetails) + + As the last step in creating the ``MERGE`` execution plan, the connector's + ``beginMerge()`` method is called, passing the ``session``, the + ``tableHandle`` and the ``MergeDetails`` object. + + ``beginMerge()`` performs any orchestration needed in the connector to + start processing the ``MERGE``. This orchestration varies from connector + to connector. In the case of Hive connector operating on ACID tables, + for example, ``beginMerge()`` checks that the table is transactional and + that all updated columns are writable, and starts a Hive Metastore + transaction. + + ``beginMerge()`` returns a ``ConnectorMergeTableHandle`` with any added + information the connector needs when the handle is passed back to + ``finishMerge()`` and the split generation machinery. 
For most + connectors, the returned table handle contains at least a flag identifying + the table handle as a table handle for a ``MERGE`` operation. + +* ``finishMerge()``:: + + void finishMerge( + ConnectorSession session, + ConnectorMergeTableHandle tableHandle, + Collection fragments) + + During ``MERGE`` processing, the Trino engine accumulates the ``Slice`` + collections returned by ``ConnectorMergeSink.finish()``. The engine calls + ``finishMerge()``, passing the table handle and that collection of + ``Slice`` fragments. In response, the connector takes appropriate actions + to complete the ``MERGE`` operation. Those actions might include + committing an underlying transaction (if any) or freeing any other + resources. diff --git a/docs/src/main/sphinx/sql.rst b/docs/src/main/sphinx/sql.rst index 1dc40a81fe11..d89f1ad5014c 100644 --- a/docs/src/main/sphinx/sql.rst +++ b/docs/src/main/sphinx/sql.rst @@ -42,6 +42,7 @@ Trino also provides :doc:`numerous SQL functions and operators`. sql/grant-roles sql/insert sql/match-recognize + sql/merge sql/pattern-recognition-in-window sql/prepare sql/refresh-materialized-view diff --git a/docs/src/main/sphinx/sql/merge.rst b/docs/src/main/sphinx/sql/merge.rst new file mode 100644 index 000000000000..d75f41167006 --- /dev/null +++ b/docs/src/main/sphinx/sql/merge.rst @@ -0,0 +1,97 @@ +===== +MERGE +===== + +Synopsis +-------- + +.. code-block:: text + + MERGE INTO target_table [ [ AS ] target_alias ] + USING { source_table | query } [ [ AS ] source_alias ] + ON search_condition + when_clause [...] + +where ``when_clause`` is one of + +.. code-block:: text + + WHEN MATCHED [ AND condition ] + THEN DELETE + +.. code-block:: text + + WHEN MATCHED [ AND condition ] + THEN UPDATE SET ( column = expression [, ...] ) + +.. code-block:: text + + WHEN NOT MATCHED [ AND condition ] + THEN INSERT [ column_list ] VALUES (expression, ...) 
+ +Description +----------- + +Conditionally update and/or delete rows of a table and/or insert new +rows into a table. + +``MERGE`` supports an arbitrary number of ``WHEN`` clauses with different +``MATCHED`` conditions, executing the ``DELETE``, ``UPDATE`` or ``INSERT`` +operation in the first ``WHEN`` clause selected by the ``MATCHED`` +state and the match condition. + +For each source row, the ``WHEN`` clauses are processed in order. Only +the first matching ``WHEN`` clause is executed and subsequent clauses +are ignored. A ``MERGE_TARGET_ROW_MULTIPLE_MATCHES`` exception is +raised when a single target table row matches more than one source row. + +If a source row is not matched by any ``WHEN`` clause and there is no +``WHEN NOT MATCHED`` clause, the source row is ignored. + +In ``WHEN`` clauses with ``UPDATE`` operations, the column value expressions +can depend on any field of the target or the source. In the ``NOT MATCHED`` +case, the ``INSERT`` expressions can depend on any field of the source. + +Examples +-------- + +Delete all customers mentioned in the source table:: + + MERGE INTO accounts t USING monthly_accounts_update s + ON t.customer = s.customer + WHEN MATCHED + THEN DELETE + +For matching customer rows, increment the purchases, and if there is no +match, insert the row from the source table:: + + MERGE INTO accounts t USING monthly_accounts_update s + ON (t.customer = s.customer) + WHEN MATCHED + THEN UPDATE SET purchases = s.purchases + t.purchases + WHEN NOT MATCHED + THEN INSERT (customer, purchases, address) + VALUES(s.customer, s.purchases, s.address) + +``MERGE`` into the target table from the source table, deleting any matching +target row for which the source address is Centreville.
For all other +matching rows, add the source purchases and set the address to the source +address. If there is no match in the target table, insert the source +table row:: + + MERGE INTO accounts t USING monthly_accounts_update s + ON (t.customer = s.customer) + WHEN MATCHED AND s.address = 'Centreville' + THEN DELETE + WHEN MATCHED + THEN UPDATE + SET purchases = s.purchases + t.purchases, address = s.address + WHEN NOT MATCHED + THEN INSERT (customer, purchases, address) + VALUES(s.customer, s.purchases, s.address) + +Limitations +----------- + +Some connectors have limited or no support for ``MERGE``. +See connector documentation for more details. diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMergeSink.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMergeSink.java new file mode 100644 index 000000000000..b0713e31bae8 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMergeSink.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.base.classloader; + +import io.airlift.slice.Slice; +import io.trino.spi.Page; +import io.trino.spi.classloader.ThreadContextClassLoader; +import io.trino.spi.connector.ConnectorMergeSink; + +import javax.inject.Inject; + +import java.util.Collection; +import java.util.concurrent.CompletableFuture; + +import static java.util.Objects.requireNonNull; + +public class ClassLoaderSafeConnectorMergeSink + implements ConnectorMergeSink +{ + private final ConnectorMergeSink delegate; + private final ClassLoader classLoader; + + @Inject + public ClassLoaderSafeConnectorMergeSink(@ForClassLoaderSafe ConnectorMergeSink delegate, ClassLoader classLoader) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + this.classLoader = requireNonNull(classLoader, "classLoader is null"); + } + + @Override + public void storeMergedRows(Page page) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + delegate.storeMergedRows(page); + } + } + + @Override + public CompletableFuture> finish() + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.finish(); + } + } + + @Override + public void abort() + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + delegate.abort(); + } + } +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java index 2b34ae96a163..963aa2dda83a 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java @@ -25,6 +25,7 @@ import io.trino.spi.connector.ConnectorAnalyzeMetadata; import io.trino.spi.connector.ConnectorInsertTableHandle; import 
io.trino.spi.connector.ConnectorMaterializedViewDefinition; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorMetadata; import io.trino.spi.connector.ConnectorOutputMetadata; import io.trino.spi.connector.ConnectorOutputTableHandle; @@ -49,6 +50,7 @@ import io.trino.spi.connector.MaterializedViewFreshness; import io.trino.spi.connector.ProjectionApplicationResult; import io.trino.spi.connector.RetryMode; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SampleApplicationResult; import io.trino.spi.connector.SampleType; import io.trino.spi.connector.SchemaTableName; @@ -999,6 +1001,36 @@ public void finishUpdate(ConnectorSession session, ConnectorTableHandle tableHan } } + @Override + public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getMergeRowIdColumnHandle(session, tableHandle); + } + } + + @Override + public Optional getUpdateLayout(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return delegate.getUpdateLayout(session, tableHandle); + } + + @Override + public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, RetryMode retryMode) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.beginMerge(session, tableHandle, retryMode); + } + } + + @Override + public void finishMerge(ConnectorSession session, ConnectorMergeTableHandle tableHandle, Collection fragments, Collection computedStatistics) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + delegate.finishMerge(session, tableHandle, fragments, computedStatistics); + } + } + @Override public Optional redirectTable(ConnectorSession session, SchemaTableName tableName) { @@ -1015,6 +1047,14 @@ public ConnectorTableHandle 
getTableHandle(ConnectorSession session, SchemaTable } } + @Override + public RowChangeParadigm getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getRowChangeParadigm(session, tableHandle); + } + } + @Override public boolean supportsReportingWrittenBytes(ConnectorSession session, SchemaTableName schemaTableName, Map tableProperties) { @@ -1030,4 +1070,11 @@ public boolean supportsReportingWrittenBytes(ConnectorSession session, Connector return delegate.supportsReportingWrittenBytes(session, connectorTableHandle); } } + + @Override + protected Object clone() + throws CloneNotSupportedException + { + return super.clone(); + } } diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorPageSinkProvider.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorPageSinkProvider.java index 88ab689cca07..1600b92efc90 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorPageSinkProvider.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorPageSinkProvider.java @@ -15,6 +15,8 @@ import io.trino.spi.classloader.ThreadContextClassLoader; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeSink; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorOutputTableHandle; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorPageSinkProvider; @@ -62,4 +64,12 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa return new ClassLoaderSafeConnectorPageSink(delegate.createPageSink(transactionHandle, session, tableExecuteHandle), classLoader); } } + + @Override + public ConnectorMergeSink 
createMergeSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorMergeTableHandle mergeHandle) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return new ClassLoaderSafeConnectorMergeSink(delegate.createMergeSink(transactionHandle, session, mergeHandle), classLoader); + } + } } diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeNodePartitioningProvider.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeNodePartitioningProvider.java index e512ecc2bfe7..50d158427664 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeNodePartitioningProvider.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeNodePartitioningProvider.java @@ -26,6 +26,7 @@ import javax.inject.Inject; import java.util.List; +import java.util.Optional; import java.util.function.ToIntFunction; import static java.util.Objects.requireNonNull; @@ -56,6 +57,14 @@ public BucketFunction getBucketFunction( } } + @Override + public Optional getBucketNodeMapping(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getBucketNodeMapping(transactionHandle, session, partitioningHandle); + } + } + @Override public ConnectorBucketNodeMap getBucketNodeMap(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) { diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/classloader/TestClassLoaderSafeWrappers.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/classloader/TestClassLoaderSafeWrappers.java index 448db30b079a..30e324b509ab 100644 --- 
a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/classloader/TestClassLoaderSafeWrappers.java +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/classloader/TestClassLoaderSafeWrappers.java @@ -14,6 +14,7 @@ package io.trino.plugin.base.classloader; import io.trino.spi.connector.ConnectorAccessControl; +import io.trino.spi.connector.ConnectorMergeSink; import io.trino.spi.connector.ConnectorMetadata; import io.trino.spi.connector.ConnectorNodePartitioningProvider; import io.trino.spi.connector.ConnectorPageSink; @@ -46,6 +47,7 @@ public void test() { testClassLoaderSafe(ConnectorAccessControl.class, ClassLoaderSafeConnectorAccessControl.class); testClassLoaderSafe(ConnectorMetadata.class, ClassLoaderSafeConnectorMetadata.class); + testClassLoaderSafe(ConnectorMergeSink.class, ClassLoaderSafeConnectorMergeSink.class); testClassLoaderSafe(ConnectorPageSink.class, ClassLoaderSafeConnectorPageSink.class); testClassLoaderSafe(ConnectorPageSinkProvider.class, ClassLoaderSafeConnectorPageSinkProvider.class); testClassLoaderSafe(ConnectorPageSourceProvider.class, ClassLoaderSafeConnectorPageSourceProvider.class); diff --git a/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleConnectorFactory.java b/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleConnectorFactory.java index 7919fd63c9a3..92d0d0978c6d 100644 --- a/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleConnectorFactory.java +++ b/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleConnectorFactory.java @@ -45,7 +45,7 @@ public Connector create(String catalogName, Map requiredConfig, new BlackHoleSplitManager(), new BlackHolePageSourceProvider(executorService), new BlackHolePageSinkProvider(executorService), - new BlackHoleNodePartitioningProvider(context.getNodeManager(), context.getTypeManager().getTypeOperators()), + new 
BlackHoleNodePartitioningProvider(context.getTypeManager().getTypeOperators()), context.getTypeManager(), executorService); } diff --git a/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleNodePartitioningProvider.java b/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleNodePartitioningProvider.java index d6565f73f916..8d7b1b3ffeea 100644 --- a/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleNodePartitioningProvider.java +++ b/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleNodePartitioningProvider.java @@ -13,61 +13,33 @@ */ package io.trino.plugin.blackhole; -import io.trino.spi.NodeManager; -import io.trino.spi.TrinoException; import io.trino.spi.connector.BucketFunction; -import io.trino.spi.connector.ConnectorBucketNodeMap; import io.trino.spi.connector.ConnectorNodePartitioningProvider; import io.trino.spi.connector.ConnectorPartitioningHandle; import io.trino.spi.connector.ConnectorSession; -import io.trino.spi.connector.ConnectorSplit; import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import java.lang.invoke.MethodHandle; import java.util.List; -import java.util.function.ToIntFunction; import static com.google.common.base.Throwables.throwIfUnchecked; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; -import static io.trino.spi.connector.ConnectorBucketNodeMap.createBucketNodeMap; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.simpleConvention; -import static java.util.Objects.requireNonNull; public class BlackHoleNodePartitioningProvider implements ConnectorNodePartitioningProvider { - private final 
NodeManager nodeManager; private final TypeOperators typeOperators; - public BlackHoleNodePartitioningProvider(NodeManager nodeManager, TypeOperators typeOperators) + public BlackHoleNodePartitioningProvider(TypeOperators typeOperators) { - this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); this.typeOperators = typeOperators; } - @Override - public ConnectorBucketNodeMap getBucketNodeMap(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) - { - // create one bucket per node - return createBucketNodeMap(nodeManager.getRequiredWorkerNodes().size()); - } - - @Override - public ToIntFunction getSplitBucketFunction( - ConnectorTransactionHandle transactionHandle, - ConnectorSession session, - ConnectorPartitioningHandle partitioningHandle) - { - return value -> { - throw new TrinoException(NOT_SUPPORTED, "Black hole connector does not supported distributed reads"); - }; - } - @Override public BucketFunction getBucketFunction( ConnectorTransactionHandle transactionHandle, diff --git a/plugin/trino-delta-lake/pom.xml b/plugin/trino-delta-lake/pom.xml index 8ae6bd43e604..cd3626911eb5 100644 --- a/plugin/trino-delta-lake/pom.xml +++ b/plugin/trino-delta-lake/pom.xml @@ -163,6 +163,11 @@ failsafe + + org.roaringbitmap + RoaringBitmap + + org.weakref jmxutils diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeColumnHandle.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeColumnHandle.java index ca33b533a641..6e3715515c46 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeColumnHandle.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeColumnHandle.java @@ -27,6 +27,8 @@ import static io.trino.plugin.deltalake.DeltaHiveTypeTranslator.toHiveType; import static io.trino.plugin.deltalake.DeltaLakeColumnType.SYNTHESIZED; import static 
io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.RowType.field; +import static io.trino.spi.type.RowType.rowType; import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; import static io.trino.spi.type.VarcharType.VARCHAR; import static java.util.Objects.requireNonNull; @@ -39,6 +41,11 @@ public class DeltaLakeColumnHandle public static final String ROW_ID_COLUMN_NAME = "$row_id"; public static final Type ROW_ID_COLUMN_TYPE = BIGINT; + public static final Type MERGE_ROW_ID_TYPE = rowType( + field("path", VARCHAR), + field("position", BIGINT), + field("partition", VARCHAR)); + public static final String PATH_COLUMN_NAME = "$path"; public static final Type PATH_TYPE = VARCHAR; diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeResult.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeResult.java new file mode 100644 index 000000000000..2f5b410b3c45 --- /dev/null +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeResult.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.deltalake; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class DeltaLakeMergeResult +{ + private final Optional oldFile; + private final Optional newFile; + + @JsonCreator + public DeltaLakeMergeResult(Optional oldFile, Optional newFile) + { + this.oldFile = requireNonNull(oldFile, "oldFile is null"); + this.newFile = requireNonNull(newFile, "newFile is null"); + checkArgument(oldFile.isPresent() || newFile.isPresent(), "old or new must be present"); + } + + @JsonProperty + public Optional getOldFile() + { + return oldFile; + } + + @JsonProperty + public Optional getNewFile() + { + return newFile; + } +} diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java new file mode 100644 index 000000000000..f2f2cb7e7fb0 --- /dev/null +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java @@ -0,0 +1,339 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.deltalake; + +import com.google.common.collect.ImmutableList; +import io.airlift.json.JsonCodec; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.parquet.ParquetReaderOptions; +import io.trino.parquet.writer.ParquetSchemaConverter; +import io.trino.parquet.writer.ParquetWriterOptions; +import io.trino.plugin.hive.FileFormatDataSourceStats; +import io.trino.plugin.hive.FileWriter; +import io.trino.plugin.hive.HdfsEnvironment; +import io.trino.plugin.hive.HdfsEnvironment.HdfsContext; +import io.trino.plugin.hive.ReaderPageSource; +import io.trino.plugin.hive.parquet.ParquetFileWriter; +import io.trino.plugin.hive.parquet.ParquetPageSourceFactory; +import io.trino.spi.Page; +import io.trino.spi.TrinoException; +import io.trino.spi.block.ColumnarRow; +import io.trino.spi.connector.ConnectorMergeSink; +import io.trino.spi.connector.ConnectorPageSink; +import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.MergePage; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.type.TimestampWithTimeZoneType; +import io.trino.spi.type.Type; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.parquet.hadoop.metadata.CompressionCodecName; +import org.joda.time.DateTimeZone; +import org.roaringbitmap.longlong.ImmutableLongBitmapDataProvider; +import org.roaringbitmap.longlong.LongBitmapDataProvider; +import org.roaringbitmap.longlong.Roaring64Bitmap; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.stream.IntStream; + +import static com.google.common.base.Verify.verify; +import static 
com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.json.JsonCodec.listJsonCodec; +import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.plugin.deltalake.DeltaLakeColumnType.REGULAR; +import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_BAD_WRITE; +import static io.trino.plugin.deltalake.DeltaLakeSessionProperties.getCompressionCodec; +import static io.trino.plugin.deltalake.DeltaLakeSessionProperties.getParquetWriterBlockSize; +import static io.trino.plugin.deltalake.DeltaLakeSessionProperties.getParquetWriterPageSize; +import static io.trino.spi.block.ColumnarRow.toColumnarRow; +import static io.trino.spi.connector.MergePage.createDeleteAndInsertPages; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static java.util.Objects.requireNonNull; +import static java.util.UUID.randomUUID; +import static java.util.concurrent.CompletableFuture.completedFuture; + +public class DeltaLakeMergeSink + implements ConnectorMergeSink +{ + private static final JsonCodec> PARTITIONS_CODEC = listJsonCodec(String.class); + + private final HdfsEnvironment hdfsEnvironment; + private final ConnectorSession session; + private final DateTimeZone parquetDateTimeZone; + private final String trinoVersion; + private final JsonCodec dataFileInfoCodec; + private final JsonCodec mergeResultJsonCodec; + private final DeltaLakeWriterStats writerStats; + private final String rootTableLocation; + private final ConnectorPageSink insertPageSink; + private final List dataColumns; + private final int tableColumnCount; + private final Map fileDeletions = new HashMap<>(); + + public DeltaLakeMergeSink( + HdfsEnvironment hdfsEnvironment, + ConnectorSession session, + DateTimeZone parquetDateTimeZone, + String trinoVersion, + JsonCodec dataFileInfoCodec, + JsonCodec mergeResultJsonCodec, + DeltaLakeWriterStats 
writerStats, + String rootTableLocation, + ConnectorPageSink insertPageSink, + List tableColumns) + { + this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); + this.session = requireNonNull(session, "session is null"); + this.parquetDateTimeZone = requireNonNull(parquetDateTimeZone, "parquetDateTimeZone is null"); + this.trinoVersion = requireNonNull(trinoVersion, "trinoVersion is null"); + this.dataFileInfoCodec = requireNonNull(dataFileInfoCodec, "dataFileInfoCodec is null"); + this.mergeResultJsonCodec = requireNonNull(mergeResultJsonCodec, "mergeResultJsonCodec is null"); + this.writerStats = requireNonNull(writerStats, "writerStats is null"); + this.rootTableLocation = requireNonNull(rootTableLocation, "rootTableLocation is null"); + this.insertPageSink = requireNonNull(insertPageSink, "insertPageSink is null"); + requireNonNull(tableColumns, "tableColumns is null"); + this.tableColumnCount = tableColumns.size(); + this.dataColumns = tableColumns.stream() + .filter(column -> column.getColumnType() == REGULAR) + .collect(toImmutableList()); + } + + @Override + public void storeMergedRows(Page page) + { + MergePage mergePage = createDeleteAndInsertPages(page, tableColumnCount); + + mergePage.getInsertionsPage().ifPresent(insertPageSink::appendPage); + + mergePage.getDeletionsPage().ifPresent(deletions -> { + ColumnarRow rowIdRow = toColumnarRow(deletions.getBlock(deletions.getChannelCount() - 1)); + + for (int position = 0; position < rowIdRow.getPositionCount(); position++) { + Slice filePath = VARCHAR.getSlice(rowIdRow.getField(0), position); + long rowPosition = BIGINT.getLong(rowIdRow.getField(1), position); + Slice partitions = VARCHAR.getSlice(rowIdRow.getField(2), position); + + List partitionValues = PARTITIONS_CODEC.fromJson(partitions.toStringUtf8()); + + FileDeletion deletion = fileDeletions.computeIfAbsent(filePath, x -> new FileDeletion(partitionValues)); + + deletion.rowsToDelete().addLong(rowPosition); + } + }); + } + 
+ @Override + public CompletableFuture> finish() + { + List fragments = new ArrayList<>(); + + insertPageSink.finish().join().stream() + .map(Slice::getBytes) + .map(dataFileInfoCodec::fromJson) + .map(info -> new DeltaLakeMergeResult(Optional.empty(), Optional.of(info))) + .map(mergeResultJsonCodec::toJsonBytes) + .map(Slices::wrappedBuffer) + .forEach(fragments::add); + + fileDeletions.forEach((path, deletion) -> + fragments.addAll(rewriteFile(new Path(path.toStringUtf8()), deletion))); + + return completedFuture(fragments); + } + + // In spite of the name "Delta" Lake, we must rewrite the entire file to delete rows. + private List rewriteFile(Path sourcePath, FileDeletion deletion) + { + try { + Path rootTablePath = new Path(rootTableLocation); + String sourceRelativePath = rootTablePath.toUri().relativize(sourcePath.toUri()).toString(); + FileSystem fileSystem = hdfsEnvironment.getFileSystem(new HdfsContext(session.getIdentity()), rootTablePath); + + Path targetPath = new Path(sourcePath.getParent(), session.getQueryId() + "_" + randomUUID()); + String targetRelativePath = rootTablePath.toUri().relativize(targetPath.toUri()).toString(); + FileWriter fileWriter = createParquetFileWriter(fileSystem, targetPath, dataColumns); + + DeltaLakeWriter writer = new DeltaLakeWriter( + fileSystem, + fileWriter, + rootTablePath, + targetRelativePath, + deletion.partitionValues(), + writerStats, + dataColumns); + + DataFileInfo newFileInfo = rewriteParquetFile(sourcePath, deletion.rowsToDelete(), writer); + + DeltaLakeMergeResult result = new DeltaLakeMergeResult(Optional.of(sourceRelativePath), Optional.of(newFileInfo)); + return ImmutableList.of(utf8Slice(mergeResultJsonCodec.toJson(result))); + } + catch (IOException e) { + throw new TrinoException(DELTA_LAKE_BAD_WRITE, "Unable to rewrite Parquet file", e); + } + } + + private FileWriter createParquetFileWriter(FileSystem fileSystem, Path path, List dataColumns) + { + ParquetWriterOptions parquetWriterOptions = 
ParquetWriterOptions.builder() + .setMaxBlockSize(getParquetWriterBlockSize(session)) + .setMaxPageSize(getParquetWriterPageSize(session)) + .build(); + CompressionCodecName compressionCodecName = getCompressionCodec(session).getParquetCompressionCodec(); + + try { + Callable rollbackAction = () -> { + fileSystem.delete(path, false); + return null; + }; + + List parquetTypes = dataColumns.stream() + .map(column -> { + Type type = column.getType(); + if (type instanceof TimestampWithTimeZoneType timestamp) { + verify(timestamp.getPrecision() == 3, "Unsupported type: %s", type); + return TIMESTAMP_MILLIS; + } + return type; + }) + .collect(toImmutableList()); + + ParquetSchemaConverter schemaConverter = new ParquetSchemaConverter( + parquetTypes, + dataColumns.stream() + .map(DeltaLakeColumnHandle::getName) + .collect(toImmutableList()), + false, + false); + + return new ParquetFileWriter( + fileSystem.create(path), + rollbackAction, + parquetTypes, + schemaConverter.getMessageType(), + schemaConverter.getPrimitiveTypes(), + parquetWriterOptions, + IntStream.range(0, dataColumns.size()).toArray(), + compressionCodecName, + trinoVersion, + Optional.empty()); + } + catch (IOException e) { + throw new TrinoException(DELTA_LAKE_BAD_WRITE, "Error creating Parquet file", e); + } + } + + private DataFileInfo rewriteParquetFile(Path path, ImmutableLongBitmapDataProvider rowsToDelete, DeltaLakeWriter fileWriter) + throws IOException + { + try (ConnectorPageSource connectorPageSource = createParquetPageSource(path).get()) { + long filePosition = 0; + while (!connectorPageSource.isFinished()) { + Page page = connectorPageSource.getNextPage(); + if (page == null) { + continue; + } + + int positionCount = page.getPositionCount(); + int[] retained = new int[positionCount]; + int retainedCount = 0; + for (int position = 0; position < positionCount; position++) { + if (!rowsToDelete.contains(filePosition)) { + retained[retainedCount] = position; + retainedCount++; + } + 
filePosition++; + } + if (retainedCount != positionCount) { + page = page.getPositions(retained, 0, retainedCount); + } + + fileWriter.appendRows(page); + } + fileWriter.commit(); + } + catch (Throwable t) { + try { + fileWriter.rollback(); + } + catch (RuntimeException e) { + if (!t.equals(e)) { + t.addSuppressed(e); + } + } + throw t; + } + + return fileWriter.getDataFileInfo(); + } + + private ReaderPageSource createParquetPageSource(Path path) + throws IOException + { + HdfsContext hdfsContext = new HdfsContext(session); + Configuration config = hdfsEnvironment.getConfiguration(hdfsContext, path); + FileSystem fileSystem = hdfsEnvironment.getFileSystem(hdfsContext, path); + long fileSize = fileSystem.getFileStatus(path).getLen(); + + return ParquetPageSourceFactory.createPageSource( + path, + 0, + fileSize, + fileSize, + dataColumns.stream() + .map(DeltaLakeColumnHandle::toHiveColumnHandle) + .collect(toImmutableList()), + TupleDomain.all(), + true, + hdfsEnvironment, + config, + session.getIdentity(), + parquetDateTimeZone, + new FileFormatDataSourceStats(), + new ParquetReaderOptions()); + } + + private static class FileDeletion + { + private final List partitionValues; + private final LongBitmapDataProvider rowsToDelete = new Roaring64Bitmap(); + + private FileDeletion(List partitionValues) + { + this.partitionValues = ImmutableList.copyOf(requireNonNull(partitionValues, "partitionValues is null")); + } + + public List partitionValues() + { + return partitionValues; + } + + public LongBitmapDataProvider rowsToDelete() + { + return rowsToDelete; + } + } +} diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeTableHandle.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeTableHandle.java new file mode 100644 index 000000000000..d417eea1709f --- /dev/null +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeTableHandle.java @@ -0,0 +1,47 @@ +/* + * Licensed 
under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.deltalake; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.connector.ConnectorMergeTableHandle; + +import static java.util.Objects.requireNonNull; + +public class DeltaLakeMergeTableHandle + implements ConnectorMergeTableHandle +{ + private final DeltaLakeTableHandle tableHandle; + private final DeltaLakeInsertTableHandle insertTableHandle; + + @JsonCreator + public DeltaLakeMergeTableHandle(DeltaLakeTableHandle tableHandle, DeltaLakeInsertTableHandle insertTableHandle) + { + this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); + this.insertTableHandle = requireNonNull(insertTableHandle, "insertTableHandle is null"); + } + + @Override + @JsonProperty + public DeltaLakeTableHandle getTableHandle() + { + return tableHandle; + } + + @JsonProperty + public DeltaLakeInsertTableHandle getInsertTableHandle() + { + return insertTableHandle; + } +} diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java index 355ecda5442a..18861427027b 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java @@ -70,9 +70,11 @@ import 
io.trino.spi.connector.ColumnMetadata; import io.trino.spi.connector.ConnectorAnalyzeMetadata; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorMetadata; import io.trino.spi.connector.ConnectorOutputMetadata; import io.trino.spi.connector.ConnectorOutputTableHandle; +import io.trino.spi.connector.ConnectorPartitioningHandle; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableExecuteHandle; import io.trino.spi.connector.ConnectorTableHandle; @@ -83,6 +85,7 @@ import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.ProjectionApplicationResult; import io.trino.spi.connector.RetryMode; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SchemaNotFoundException; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; @@ -146,6 +149,7 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.FILE_MODIFIED_TIME_COLUMN_NAME; +import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.MERGE_ROW_ID_TYPE; import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.ROW_ID_COLUMN_NAME; import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.ROW_ID_COLUMN_TYPE; import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.fileModifiedTimeColumnHandle; @@ -197,6 +201,7 @@ import static io.trino.spi.StandardErrorCode.INVALID_TABLE_PROPERTY; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.connector.RetryMode.NO_RETRIES; +import static io.trino.spi.connector.RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW; import static io.trino.spi.connector.SchemaTableName.schemaTableName; import static io.trino.spi.predicate.Range.greaterThanOrEqual; import static 
io.trino.spi.predicate.Range.lessThanOrEqual; @@ -245,6 +250,7 @@ public class DeltaLakeMetadata public static final String CREATE_TABLE_OPERATION = "CREATE TABLE"; public static final String ADD_COLUMN_OPERATION = "ADD COLUMNS"; public static final String INSERT_OPERATION = "WRITE"; + public static final String MERGE_OPERATION = "MERGE"; public static final String DELETE_OPERATION = "DELETE"; public static final String UPDATE_OPERATION = "UPDATE"; public static final String OPTIMIZE_OPERATION = "OPTIMIZE"; @@ -272,6 +278,7 @@ public class DeltaLakeMetadata private final boolean unsafeWritesEnabled; private final JsonCodec dataFileInfoCodec; private final JsonCodec updateResultJsonCodec; + private final JsonCodec mergeResultJsonCodec; private final TransactionLogWriterFactory transactionLogWriterFactory; private final String nodeVersion; private final String nodeId; @@ -290,6 +297,7 @@ public DeltaLakeMetadata( boolean unsafeWritesEnabled, JsonCodec dataFileInfoCodec, JsonCodec updateResultJsonCodec, + JsonCodec mergeResultJsonCodec, TransactionLogWriterFactory transactionLogWriterFactory, NodeManager nodeManager, CheckpointWriterManager checkpointWriterManager, @@ -308,6 +316,7 @@ public DeltaLakeMetadata( this.unsafeWritesEnabled = unsafeWritesEnabled; this.dataFileInfoCodec = requireNonNull(dataFileInfoCodec, "dataFileInfoCodec is null"); this.updateResultJsonCodec = requireNonNull(updateResultJsonCodec, "updateResultJsonCodec is null"); + this.mergeResultJsonCodec = requireNonNull(mergeResultJsonCodec, "mergeResultJsonCodec is null"); this.transactionLogWriterFactory = requireNonNull(transactionLogWriterFactory, "transactionLogWriterFactory is null"); this.nodeVersion = nodeManager.getCurrentNode().getVersion(); this.nodeId = nodeManager.getCurrentNode().getNodeIdentifier(); @@ -1246,6 +1255,12 @@ public ConnectorInsertTableHandle beginInsert(ConnectorSession session, Connecto // This check acts as a safeguard in cases where the input columns may differ from 
the table metadata case-sensitively checkAllColumnsPassedOnInsert(tableMetadata, inputColumns); + + return createInsertHandle(session, retryMode, table, inputColumns, tableMetadata); + } + + private DeltaLakeInsertTableHandle createInsertHandle(ConnectorSession session, RetryMode retryMode, DeltaLakeTableHandle table, List inputColumns, ConnectorTableMetadata tableMetadata) + { String tableLocation = getLocation(tableMetadata.getProperties()); try { FileSystem fileSystem = hdfsEnvironment.getFileSystem(new HdfsContext(session), new Path(tableLocation)); @@ -1460,6 +1475,134 @@ public void finishUpdate(ConnectorSession session, ConnectorTableHandle tableHan finishWrite(session, tableHandle, fragments); } + @Override + public RowChangeParadigm getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return DELETE_ROW_AND_INSERT_ROW; + } + + @Override + public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return new DeltaLakeColumnHandle(ROW_ID_COLUMN_NAME, MERGE_ROW_ID_TYPE, ROW_ID_COLUMN_NAME, MERGE_ROW_ID_TYPE, SYNTHESIZED); + } + + @Override + public Optional getUpdateLayout(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return Optional.of(DeltaLakeUpdateHandle.INSTANCE); + } + + @Override + public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, RetryMode retryMode) + { + DeltaLakeTableHandle handle = (DeltaLakeTableHandle) tableHandle; + if (isAppendOnly(handle.getMetadataEntry())) { + throw new TrinoException(NOT_SUPPORTED, "Cannot update rows from a table with '" + APPEND_ONLY_CONFIGURATION_KEY + "' set to true"); + } + if (!allowWrite(session, handle)) { + String fileSystem = new Path(handle.getLocation()).toUri().getScheme(); + throw new TrinoException(NOT_SUPPORTED, format("Updates are not supported on the %s filesystem", fileSystem)); + } + if 
(getColumnsNullability(handle.getMetadataEntry()).values().stream().anyMatch(nullability -> !nullability)) { + throw new TrinoException(NOT_SUPPORTED, "Updates are not supported for tables with non-nullable columns"); + } + if (!getColumnInvariants(handle.getMetadataEntry()).isEmpty()) { + throw new TrinoException(NOT_SUPPORTED, "Updates are not supported for tables with delta invariants"); + } + checkSupportedWriterVersion(session, handle.getSchemaTableName()); + + ConnectorTableMetadata tableMetadata = getTableMetadata(session, handle); + + List inputColumns = getColumns(handle.getMetadataEntry()).stream() + .filter(column -> column.getColumnType() != SYNTHESIZED) + .collect(toImmutableList()); + + DeltaLakeInsertTableHandle insertHandle = createInsertHandle(session, retryMode, handle, inputColumns, tableMetadata); + + return new DeltaLakeMergeTableHandle(handle, insertHandle); + } + + @Override + public void finishMerge(ConnectorSession session, ConnectorMergeTableHandle tableHandle, Collection fragments, Collection computedStatistics) + { + DeltaLakeTableHandle handle = ((DeltaLakeMergeTableHandle) tableHandle).getTableHandle(); + + List mergeResults = fragments.stream() + .map(Slice::getBytes) + .map(mergeResultJsonCodec::fromJson) + .collect(toImmutableList()); + + List oldFiles = mergeResults.stream() + .map(DeltaLakeMergeResult::getOldFile) + .flatMap(Optional::stream) + .collect(toImmutableList()); + + List newFiles = mergeResults.stream() + .map(DeltaLakeMergeResult::getNewFile) + .flatMap(Optional::stream) + .collect(toImmutableList()); + + if (handle.isRetriesEnabled()) { + cleanExtraOutputFilesForUpdate(session, handle.getLocation(), newFiles); + } + + Optional checkpointInterval = handle.getMetadataEntry().getCheckpointInterval(); + + String tableLocation = metastore.getTableLocation(handle.getSchemaTableName(), session); + + boolean writeCommitted = false; + try { + TransactionLogWriter transactionLogWriter = 
transactionLogWriterFactory.newWriter(session, tableLocation); + + long createdTime = Instant.now().toEpochMilli(); + + FileSystem fileSystem = hdfsEnvironment.getFileSystem(new HdfsContext(session), new Path(tableLocation)); + long currentVersion = getMandatoryCurrentVersion(fileSystem, new Path(tableLocation)); + if (currentVersion != handle.getReadVersion()) { + throw new TransactionConflictException(format("Conflicting concurrent writes found. Expected transaction log version: %s, actual version: %s", handle.getReadVersion(), currentVersion)); + } + long commitVersion = currentVersion + 1; + + transactionLogWriter.appendCommitInfoEntry( + new CommitInfoEntry( + commitVersion, + createdTime, + session.getUser(), + session.getUser(), + MERGE_OPERATION, + ImmutableMap.of("queryId", session.getQueryId()), + null, + null, + "trino-" + nodeVersion + "-" + nodeId, + handle.getReadVersion(), + ISOLATION_LEVEL, + true)); + // TODO: Delta writes another field "operationMetrics" (https://github.com/trinodb/trino/issues/12005) + + long writeTimestamp = Instant.now().toEpochMilli(); + + for (String file : oldFiles) { + transactionLogWriter.appendRemoveFileEntry(new RemoveFileEntry(file, writeTimestamp, true)); + } + + List partitionColumns = handle.getMetadataEntry().getOriginalPartitionColumns(); + appendAddFileEntries(transactionLogWriter, newFiles, partitionColumns, true); + + transactionLogWriter.flush(); + writeCommitted = true; + + writeCheckpointIfNeeded(session, new SchemaTableName(handle.getSchemaName(), handle.getTableName()), checkpointInterval, commitVersion); + } + catch (IOException | RuntimeException e) { + if (!writeCommitted) { + // TODO perhaps it should happen in a background thread (https://github.com/trinodb/trino/issues/12011) + cleanupFailedWrite(session, tableLocation, newFiles); + } + throw new TrinoException(DELTA_LAKE_BAD_WRITE, "Failed to write Delta Lake transaction log entry", e); + } + } + @Override public Optional getTableHandleForExecute( 
ConnectorSession session, @@ -1690,8 +1833,13 @@ private void finishWrite(ConnectorSession session, ConnectorTableHandle tableHan .map(updateResultJsonCodec::fromJson) .collect(toImmutableList()); + List newFiles = updateResults.stream() + .map(DeltaLakeUpdateResult::getNewFile) + .flatMap(Optional::stream) + .collect(toImmutableList()); + if (handle.isRetriesEnabled()) { - cleanExtraOutputFilesForUpdate(session, handle.getLocation(), updateResults); + cleanExtraOutputFilesForUpdate(session, handle.getLocation(), newFiles); } String tableLocation = metastore.getTableLocation(handle.getSchemaTableName(), session); @@ -1747,11 +1895,7 @@ private void finishWrite(ConnectorSession session, ConnectorTableHandle tableHan } appendAddFileEntries( transactionLogWriter, - updateResults.stream() - .map(DeltaLakeUpdateResult::getNewFile) - .filter(Optional::isPresent) - .map(Optional::get) - .collect(toImmutableList()), + newFiles, handle.getMetadataEntry().getOriginalPartitionColumns(), true); @@ -1762,11 +1906,7 @@ private void finishWrite(ConnectorSession session, ConnectorTableHandle tableHan catch (Exception e) { if (!writeCommitted) { // TODO perhaps it should happen in a background thread (https://github.com/trinodb/trino/issues/12011) - cleanupFailedWrite(session, tableLocation, updateResults.stream() - .map(DeltaLakeUpdateResult::getNewFile) - .filter(Optional::isPresent) - .map(Optional::get) - .collect(toImmutableList())); + cleanupFailedWrite(session, tableLocation, newFiles); } throw new TrinoException(DELTA_LAKE_BAD_WRITE, "Failed to write Delta Lake transaction log entry", e); } @@ -2249,12 +2389,9 @@ private void cleanExtraOutputFiles(ConnectorSession session, String baseLocation cleanExtraOutputFiles(session, writtenFilePaths); } - private void cleanExtraOutputFilesForUpdate(ConnectorSession session, String baseLocation, List validUpdateResults) + private void cleanExtraOutputFilesForUpdate(ConnectorSession session, String baseLocation, List newFiles) { - Set 
writtenFilePaths = validUpdateResults.stream() - .map(DeltaLakeUpdateResult::getNewFile) - .filter(Optional::isPresent) - .map(Optional::get) + Set writtenFilePaths = newFiles.stream() .map(dataFileInfo -> baseLocation + "/" + dataFileInfo.getPath()) .collect(toImmutableSet()); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadataFactory.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadataFactory.java index b72580456374..4989f023c7b0 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadataFactory.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadataFactory.java @@ -42,6 +42,7 @@ public class DeltaLakeMetadataFactory private final DeltaLakeAccessControlMetadataFactory accessControlMetadataFactory; private final JsonCodec dataFileInfoCodec; private final JsonCodec updateResultJsonCodec; + private final JsonCodec mergeResultJsonCodec; private final TransactionLogWriterFactory transactionLogWriterFactory; private final NodeManager nodeManager; private final CheckpointWriterManager checkpointWriterManager; @@ -65,6 +66,7 @@ public DeltaLakeMetadataFactory( DeltaLakeConfig deltaLakeConfig, JsonCodec dataFileInfoCodec, JsonCodec updateResultJsonCodec, + JsonCodec mergeResultJsonCodec, TransactionLogWriterFactory transactionLogWriterFactory, NodeManager nodeManager, CheckpointWriterManager checkpointWriterManager, @@ -78,6 +80,7 @@ public DeltaLakeMetadataFactory( this.accessControlMetadataFactory = requireNonNull(accessControlMetadataFactory, "accessControlMetadataFactory is null"); this.dataFileInfoCodec = requireNonNull(dataFileInfoCodec, "dataFileInfoCodec is null"); this.updateResultJsonCodec = requireNonNull(updateResultJsonCodec, "updateResultJsonCodec is null"); + this.mergeResultJsonCodec = requireNonNull(mergeResultJsonCodec, "mergeResultJsonCodec is null"); this.transactionLogWriterFactory = 
requireNonNull(transactionLogWriterFactory, "transactionLogWriterFactory is null"); this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); this.checkpointWriterManager = requireNonNull(checkpointWriterManager, "checkpointWriterManager is null"); @@ -114,6 +117,7 @@ public DeltaLakeMetadata create(ConnectorIdentity identity) unsafeWritesEnabled, dataFileInfoCodec, updateResultJsonCodec, + mergeResultJsonCodec, transactionLogWriterFactory, nodeManager, checkpointWriterManager, diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeModule.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeModule.java index 51bae29cb931..67e1e0264720 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeModule.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeModule.java @@ -146,6 +146,7 @@ public void setup(Binder binder) jsonCodecBinder(binder).bindJsonCodec(DataFileInfo.class); jsonCodecBinder(binder).bindJsonCodec(DeltaLakeUpdateResult.class); + jsonCodecBinder(binder).bindJsonCodec(DeltaLakeMergeResult.class); binder.bind(DeltaLakeWriterStats.class).in(Scopes.SINGLETON); binder.bind(FileFormatDataSourceStats.class).in(Scopes.SINGLETON); newExporter(binder).export(FileFormatDataSourceStats.class) diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeNodePartitioningProvider.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeNodePartitioningProvider.java index dcdecb4a26bd..02141ad1aa2e 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeNodePartitioningProvider.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeNodePartitioningProvider.java @@ -13,13 +13,10 @@ */ package io.trino.plugin.deltalake; -import io.trino.spi.NodeManager; import io.trino.spi.connector.BucketFunction; -import 
io.trino.spi.connector.ConnectorBucketNodeMap; import io.trino.spi.connector.ConnectorNodePartitioningProvider; import io.trino.spi.connector.ConnectorPartitioningHandle; import io.trino.spi.connector.ConnectorSession; -import io.trino.spi.connector.ConnectorSplit; import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; @@ -28,39 +25,18 @@ import javax.inject.Inject; import java.util.List; -import java.util.function.ToIntFunction; -import static io.trino.spi.connector.ConnectorBucketNodeMap.createBucketNodeMap; import static java.util.Objects.requireNonNull; public class DeltaLakeNodePartitioningProvider implements ConnectorNodePartitioningProvider { private final TypeOperators typeOperators; - private final NodeManager nodeManager; @Inject - public DeltaLakeNodePartitioningProvider(TypeManager typeManager, NodeManager nodeManager) + public DeltaLakeNodePartitioningProvider(TypeManager typeManager) { this.typeOperators = requireNonNull(typeManager, "typeManager is null").getTypeOperators(); - this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); - } - - @Override - public ConnectorBucketNodeMap getBucketNodeMap(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) - { - return createBucketNodeMap(nodeManager.getRequiredWorkerNodes().size()); - } - - @Override - public ToIntFunction getSplitBucketFunction( - ConnectorTransactionHandle transactionHandle, - ConnectorSession session, - ConnectorPartitioningHandle partitioningHandle) - { - return split -> { - throw new UnsupportedOperationException(); - }; } @Override @@ -71,6 +47,10 @@ public BucketFunction getBucketFunction( List partitionChannelTypes, int bucketCount) { + if (partitioningHandle instanceof DeltaLakeUpdateHandle) { + return new DeltaLakeUpdateBucketFunction(bucketCount); + } + DeltaLakePartitioningHandle handle = (DeltaLakePartitioningHandle) 
partitioningHandle; return new DeltaLakeBucketFunction(typeOperators, handle.getPartitioningColumns(), bucketCount); } diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSinkProvider.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSinkProvider.java index bb30510a961e..8ae41b409018 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSinkProvider.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSinkProvider.java @@ -20,6 +20,8 @@ import io.trino.plugin.hive.NodeVersion; import io.trino.spi.PageIndexerFactory; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeSink; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorOutputTableHandle; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorPageSinkProvider; @@ -27,6 +29,7 @@ import io.trino.spi.connector.ConnectorTableExecuteHandle; import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.type.TypeManager; +import org.joda.time.DateTimeZone; import javax.inject.Inject; @@ -38,8 +41,10 @@ public class DeltaLakePageSinkProvider private final PageIndexerFactory pageIndexerFactory; private final HdfsEnvironment hdfsEnvironment; private final JsonCodec dataFileInfoCodec; + private final JsonCodec mergeResultJsonCodec; private final DeltaLakeWriterStats stats; private final int maxPartitionsPerWriter; + private final DateTimeZone parquetDateTimeZone; private final TypeManager typeManager; private final String trinoVersion; @@ -48,6 +53,7 @@ public DeltaLakePageSinkProvider( PageIndexerFactory pageIndexerFactory, HdfsEnvironment hdfsEnvironment, JsonCodec dataFileInfoCodec, + JsonCodec mergeResultJsonCodec, DeltaLakeWriterStats stats, DeltaLakeConfig deltaLakeConfig, TypeManager typeManager, @@ -56,8 +62,10 @@ public 
DeltaLakePageSinkProvider( this.pageIndexerFactory = pageIndexerFactory; this.hdfsEnvironment = hdfsEnvironment; this.dataFileInfoCodec = dataFileInfoCodec; + this.mergeResultJsonCodec = requireNonNull(mergeResultJsonCodec, "mergeResultJsonCodec is null"); this.stats = stats; this.maxPartitionsPerWriter = deltaLakeConfig.getMaxPartitionsPerWriter(); + this.parquetDateTimeZone = deltaLakeConfig.getParquetDateTimeZone(); this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.trinoVersion = requireNonNull(nodeVersion, "nodeVersion is null").toString(); } @@ -121,4 +129,24 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa throw new IllegalArgumentException("Unknown procedure: " + executeHandle.getProcedureId()); } + + @Override + public ConnectorMergeSink createMergeSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorMergeTableHandle mergeHandle) + { + DeltaLakeMergeTableHandle merge = (DeltaLakeMergeTableHandle) mergeHandle; + DeltaLakeInsertTableHandle tableHandle = merge.getInsertTableHandle(); + ConnectorPageSink pageSink = createPageSink(transactionHandle, session, tableHandle); + + return new DeltaLakeMergeSink( + hdfsEnvironment, + session, + parquetDateTimeZone, + trinoVersion, + dataFileInfoCodec, + mergeResultJsonCodec, + stats, + tableHandle.getLocation(), + pageSink, + tableHandle.getInputColumns()); + } } diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSource.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSource.java index 8f91fd8bbc40..0f09ff1fe8bb 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSource.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSource.java @@ -13,9 +13,12 @@ */ package io.trino.plugin.deltalake; +import io.airlift.json.JsonCodec; +import io.airlift.json.JsonCodecFactory; import 
io.trino.spi.Page; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; +import io.trino.spi.block.RowBlock; import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.metrics.Metrics; @@ -31,28 +34,37 @@ import static com.google.common.base.Throwables.throwIfInstanceOf; import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.slice.Slices.wrappedBuffer; import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.FILE_MODIFIED_TIME_COLUMN_NAME; import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.FILE_MODIFIED_TIME_TYPE; import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.FILE_SIZE_COLUMN_NAME; import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.FILE_SIZE_TYPE; import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.PATH_COLUMN_NAME; import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.PATH_TYPE; +import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.ROW_ID_COLUMN_NAME; import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_BAD_DATA; import static io.trino.plugin.deltalake.transactionlog.TransactionLogParser.deserializePartitionValue; import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone; import static io.trino.spi.type.TimeZoneKey.UTC_KEY; +import static io.trino.spi.type.VarcharType.VARCHAR; import static java.util.Objects.requireNonNull; public class DeltaLakePageSource implements ConnectorPageSource { + private static final JsonCodec> PARTITIONS_CODEC = new JsonCodecFactory().listJsonCodec(String.class); + private final Block[] prefilledBlocks; private final int[] delegateIndexes; + private final int rowIdIndex; + private final Block pathBlock; + private final Block partitionsBlock; private final ConnectorPageSource delegate; public DeltaLakePageSource( List columns, Map> partitionKeys, + List partitionValues, ConnectorPageSource delegate, String path, long fileSize, @@ -67,6 +79,11 
@@ public DeltaLakePageSource( int outputIndex = 0; int delegateIndex = 0; + + int rowIdIndex = -1; + Block pathBlock = null; + Block partitionsBlock = null; + for (DeltaLakeColumnHandle column : columns) { if (partitionKeys.containsKey(column.getName())) { Type type = column.getType(); @@ -87,12 +104,23 @@ else if (column.getName().equals(FILE_MODIFIED_TIME_COLUMN_NAME)) { prefilledBlocks[outputIndex] = Utils.nativeValueToBlock(FILE_MODIFIED_TIME_TYPE, packedTimestamp); delegateIndexes[outputIndex] = -1; } + else if (column.getName().equals(ROW_ID_COLUMN_NAME)) { + rowIdIndex = outputIndex; + pathBlock = Utils.nativeValueToBlock(VARCHAR, utf8Slice(path)); + partitionsBlock = Utils.nativeValueToBlock(VARCHAR, wrappedBuffer(PARTITIONS_CODEC.toJsonBytes(partitionValues))); + delegateIndexes[outputIndex] = delegateIndex; + delegateIndex++; + } else { delegateIndexes[outputIndex] = delegateIndex; delegateIndex++; } outputIndex++; } + + this.rowIdIndex = rowIdIndex; + this.pathBlock = pathBlock; + this.partitionsBlock = partitionsBlock; } @Override @@ -133,6 +161,9 @@ public Page getNextPage() if (prefilledBlocks[i] != null) { blocks[i] = new RunLengthEncodedBlock(prefilledBlocks[i], batchSize); } + else if (i == rowIdIndex) { + blocks[i] = createRowIdBlock(dataPage.getBlock(delegateIndexes[i])); + } else { blocks[i] = dataPage.getBlock(delegateIndexes[i]); } @@ -146,6 +177,17 @@ public Page getNextPage() } } + private Block createRowIdBlock(Block rowIndexBlock) + { + int positions = rowIndexBlock.getPositionCount(); + Block[] fields = { + new RunLengthEncodedBlock(pathBlock, positions), + rowIndexBlock, + new RunLengthEncodedBlock(partitionsBlock, positions), + }; + return RowBlock.fromFieldBlocks(positions, Optional.empty(), fields); + } + @Override public void close() { diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java 
b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java index 2696940ab5e2..32375a19d2e5 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java @@ -42,6 +42,7 @@ import javax.inject.Inject; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; @@ -49,9 +50,12 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.ROW_ID_COLUMN_NAME; import static io.trino.plugin.deltalake.DeltaLakeColumnType.REGULAR; import static io.trino.plugin.deltalake.DeltaLakeSessionProperties.getParquetMaxReadBlockSize; +import static io.trino.plugin.deltalake.transactionlog.DeltaLakeSchemaSupport.extractSchema; import static io.trino.plugin.hive.HiveSessionProperties.isParquetUseColumnIndex; +import static io.trino.plugin.hive.parquet.ParquetPageSourceFactory.PARQUET_ROW_INDEX_COLUMN; import static java.util.Objects.requireNonNull; public class DeltaLakePageSourceProvider @@ -117,12 +121,27 @@ public ConnectorPageSource createPageSource( Map> partitionKeys = split.getPartitionKeys(); + List partitionValues = new ArrayList<>(); + if (deltaLakeColumns.stream().anyMatch(column -> column.getName().equals(ROW_ID_COLUMN_NAME))) { + for (DeltaLakeColumnMetadata column : extractSchema(table.getMetadataEntry(), typeManager)) { + Optional value = partitionKeys.get(column.getName()); + if (value != null) { + partitionValues.add(value.orElse(null)); + } + } + } + List regularColumns = deltaLakeColumns.stream() - .filter(column -> column.getColumnType() == REGULAR) + .filter(column -> (column.getColumnType() == REGULAR) || column.getName().equals(ROW_ID_COLUMN_NAME)) .collect(toImmutableList()); List hiveColumnHandles = 
regularColumns.stream() - .map(DeltaLakeColumnHandle::toHiveColumnHandle) + .map(column -> { + if (column.getName().equals(ROW_ID_COLUMN_NAME)) { + return PARQUET_ROW_INDEX_COLUMN; + } + return column.toHiveColumnHandle(); + }) .collect(toImmutableList()); Path path = new Path(split.getPath()); @@ -166,7 +185,14 @@ public ConnectorPageSource createPageSource( verify(pageSource.getReaderColumns().isEmpty(), "All columns expected to be base columns"); - return new DeltaLakePageSource(deltaLakeColumns, partitionKeys, pageSource.get(), split.getPath(), split.getFileSize(), split.getFileModifiedTime()); + return new DeltaLakePageSource( + deltaLakeColumns, + partitionKeys, + partitionValues, + pageSource.get(), + split.getPath(), + split.getFileSize(), + split.getFileModifiedTime()); } private static TupleDomain getParquetTupleDomain(TupleDomain effectivePredicate) diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeUpdatablePageSource.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeUpdatablePageSource.java index 387501b69098..042bb1961d37 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeUpdatablePageSource.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeUpdatablePageSource.java @@ -220,6 +220,7 @@ public DeltaLakeUpdatablePageSource( this.pageSourceDelegate = new DeltaLakePageSource( delegatedColumns, partitionKeys, + ImmutableList.of(), parquetPageSource.get(), path, fileSize, diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeUpdateBucketFunction.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeUpdateBucketFunction.java new file mode 100644 index 000000000000..086c86849473 --- /dev/null +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeUpdateBucketFunction.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.deltalake; + +import io.airlift.slice.Slice; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.connector.BucketFunction; + +import static io.trino.spi.type.VarcharType.VARCHAR; + +public class DeltaLakeUpdateBucketFunction + implements BucketFunction +{ + private final int bucketCount; + + public DeltaLakeUpdateBucketFunction(int bucketCount) + { + this.bucketCount = bucketCount; + } + + @Override + public int getBucket(Page page, int position) + { + Block row = page.getBlock(0).getObject(position, Block.class); + Slice value = VARCHAR.getSlice(row, 0); // file path field of row ID + return (value.hashCode() & Integer.MAX_VALUE) % bucketCount; + } +} diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeUpdateHandle.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeUpdateHandle.java new file mode 100644 index 000000000000..fca88a6d9cf6 --- /dev/null +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeUpdateHandle.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.deltalake; + +import io.trino.spi.connector.ConnectorPartitioningHandle; + +public enum DeltaLakeUpdateHandle + implements ConnectorPartitioningHandle +{ + INSTANCE +} diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java index 504287be60d7..029cb3e38777 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java @@ -172,6 +172,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) case SUPPORTS_DELETE: case SUPPORTS_UPDATE: + case SUPPORTS_MERGE: return true; default: diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeMinioConnectorTest.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeMinioConnectorTest.java index 2f3e8eff6e80..0b62a14189d9 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeMinioConnectorTest.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeMinioConnectorTest.java @@ -38,6 +38,8 @@ import java.util.Optional; import java.util.OptionalInt; import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import static com.google.common.collect.ImmutableList.toImmutableList; import 
    // Verifies MERGE with all three clause kinds (DELETE / UPDATE / INSERT) against a
    // partitioned target table fed from a second, unpartitioned table.
    @Test
    public void testMergeSimpleSelectPartitioned()
    {
        String targetTable = "merge_simple_target_" + randomTableSuffix();
        String sourceTable = "merge_simple_source_" + randomTableSuffix();
        assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (location = 's3://%s/%s', partitioned_by = ARRAY['address'])", targetTable, bucketName, targetTable));

        assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4);

        assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (location = 's3://%s/%s')", sourceTable, bucketName, sourceTable));

        assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable), 4);

        @Language("SQL") String sql = format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) +
                " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" +
                " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" +
                " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)";

        // 4 rows affected: 1 delete (Carol), 2 updates (Aaron, Dave), 1 insert (Ed)
        assertUpdate(sql, 4);

        assertQuery("SELECT * FROM " + targetTable, "VALUES ('Aaron', 11, 'Arches'), ('Ed', 7, 'Etherville'), ('Bill', 7, 'Buena'), ('Dave', 22, 'Darbyshire')");

        assertUpdate("DROP TABLE " + sourceTable);
        assertUpdate("DROP TABLE " + targetTable);
    }

    // Exercises MERGE against targets that are unpartitioned, partitioned on the join
    // column, and partitioned on an updated column (see partitionedProvider).
    @Test(dataProvider = "partitionedProvider")
    public void testMergeUpdateWithVariousLayouts(String partitionPhase)
    {
        String targetTable = "merge_formats_target_" + randomTableSuffix();
        String sourceTable = "merge_formats_source_" + randomTableSuffix();
        assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR) WITH (location = 's3://%s/%s'%s)", targetTable, bucketName, targetTable, partitionPhase));

        assertUpdate(format("INSERT INTO %s (customer, purchase) VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')", targetTable), 3);
        assertQuery("SELECT * FROM " + targetTable, "VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')");

        assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR) WITH (location = 's3://%s/%s')", sourceTable, bucketName, sourceTable));

        assertUpdate(format("INSERT INTO %s (customer, purchase) VALUES ('Craig', 'candles'), ('Len', 'limes'), ('Joe', 'jellybeans')", sourceTable), 3);

        @Language("SQL") String sql = format("MERGE INTO %s t USING %s s ON (t.purchase = s.purchase)", targetTable, sourceTable) +
                " WHEN MATCHED AND s.purchase = 'limes' THEN DELETE" +
                " WHEN MATCHED THEN UPDATE SET customer = CONCAT(t.customer, '_', s.customer)" +
                " WHEN NOT MATCHED THEN INSERT (customer, purchase) VALUES(s.customer, s.purchase)";

        // 3 rows affected: delete 'limes', update 'candles', insert 'jellybeans'
        assertUpdate(sql, 3);

        assertQuery("SELECT * FROM " + targetTable, "VALUES ('Dave', 'dates'), ('Carol_Craig', 'candles'), ('Joe', 'jellybeans')");
        assertUpdate("DROP TABLE " + sourceTable);
        assertUpdate("DROP TABLE " + targetTable);
    }

    // Partitioning clauses appended to CREATE TABLE in the layout-sensitive merge tests.
    @DataProvider
    public Object[][] partitionedProvider()
    {
        return new Object[][] {
                {""},
                {", partitioned_by = ARRAY['customer']"},
                {", partitioned_by = ARRAY['purchase']"}
        };
    }

    // Runs two successive MERGE statements (plus interleaved inserts) over the same
    // target and checks cumulative results, including conditional DELETE/UPDATE arms.
    @Test(dataProvider = "partitionedProvider")
    public void testMergeMultipleOperations(String partitioning)
    {
        int targetCustomerCount = 32;
        String targetTable = "merge_multiple_" + randomTableSuffix();
        assertUpdate(format("CREATE TABLE %s (purchase INT, zipcode INT, spouse VARCHAR, address VARCHAR, customer VARCHAR) WITH (location = 's3://%s/%s'%s)", targetTable, bucketName, targetTable, partitioning));
        String originalInsertFirstHalf = IntStream.range(1, targetCustomerCount / 2)
                .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 1000, 91000, intValue, intValue))
                .collect(Collectors.joining(", "));
        String originalInsertSecondHalf = IntStream.range(targetCustomerCount / 2, targetCustomerCount)
                .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 2000, 92000, intValue, intValue))
                .collect(Collectors.joining(", "));

        assertUpdate(format("INSERT INTO %s (customer, purchase, zipcode, spouse, address) VALUES %s, %s", targetTable, originalInsertFirstHalf, originalInsertSecondHalf), targetCustomerCount - 1);

        // First merge: pure UPDATE over the second half of the customers
        String firstMergeSource = IntStream.range(targetCustomerCount / 2, targetCustomerCount)
                .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 83000, intValue, intValue))
                .collect(Collectors.joining(", "));

        assertUpdate(format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchase, zipcode, spouse, address)", targetTable, firstMergeSource) +
                        " ON t.customer = s.customer" +
                        " WHEN MATCHED THEN UPDATE SET purchase = s.purchase, zipcode = s.zipcode, spouse = s.spouse, address = s.address",
                targetCustomerCount / 2);

        assertQuery(
                "SELECT customer, purchase, zipcode, spouse, address FROM " + targetTable,
                format("VALUES %s, %s", originalInsertFirstHalf, firstMergeSource));

        String nextInsert = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2)
                .mapToObj(intValue -> format("('jack_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 4000, 74000, intValue, intValue))
                .collect(Collectors.joining(", "));

        assertUpdate(format("INSERT INTO %s (customer, purchase, zipcode, spouse, address) VALUES %s", targetTable, nextInsert), targetCustomerCount / 2);

        // Second merge: mixes conditional DELETE, two UPDATE arms, and INSERT
        String secondMergeSource = IntStream.range(1, targetCustomerCount * 3 / 2)
                .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue))
                .collect(Collectors.joining(", "));

        assertUpdate(format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchase, zipcode, spouse, address)", targetTable, secondMergeSource) +
                        " ON t.customer = s.customer" +
                        " WHEN MATCHED AND t.zipcode = 91000 THEN DELETE" +
                        " WHEN MATCHED AND s.zipcode = 85000 THEN UPDATE SET zipcode = 60000" +
                        " WHEN MATCHED THEN UPDATE SET zipcode = s.zipcode, spouse = s.spouse, address = s.address" +
                        " WHEN NOT MATCHED THEN INSERT (customer, purchase, zipcode, spouse, address) VALUES(s.customer, s.purchase, s.zipcode, s.spouse, s.address)",
                targetCustomerCount * 3 / 2 - 1);

        String updatedBeginning = IntStream.range(targetCustomerCount / 2, targetCustomerCount)
                .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 60000, intValue, intValue))
                .collect(Collectors.joining(", "));
        String updatedMiddle = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2)
                .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue))
                .collect(Collectors.joining(", "));
        String updatedEnd = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2)
                .mapToObj(intValue -> format("('jack_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 4000, 74000, intValue, intValue))
                .collect(Collectors.joining(", "));

        assertQuery(
                "SELECT customer, purchase, zipcode, spouse, address FROM " + targetTable,
                format("VALUES %s, %s, %s", updatedBeginning, updatedMiddle, updatedEnd));

        assertUpdate("DROP TABLE " + targetTable);
    }
    // Verifies that a MERGE fails when one target row matches multiple source rows,
    // and that the same statement succeeds once a clause predicate disambiguates.
    @Test(dataProvider = "targetWithDifferentPartitioning")
    public void testMergeMultipleRowsMatchFails(String createTableSql)
    {
        String targetTable = "merge_multiple_target_" + randomTableSuffix();
        String sourceTable = "merge_multiple_source_" + randomTableSuffix();
        assertUpdate(format(createTableSql, targetTable, bucketName, targetTable));

        assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Antioch')", targetTable), 2);

        assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (location = 's3://%s/%s')", sourceTable, bucketName, sourceTable));

        // Two source rows for 'Aaron' make the unconditional UPDATE ambiguous
        assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Adelphi'), ('Aaron', 8, 'Ashland')", sourceTable), 2);

        assertThatThrownBy(() -> computeActual(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) +
                " WHEN MATCHED THEN UPDATE SET address = s.address"))
                .hasMessage("One MERGE target table row matched more than one source row");

        assertUpdate(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) +
                        " WHEN MATCHED AND s.address = 'Adelphi' THEN UPDATE SET address = s.address",
                1);
        assertQuery("SELECT customer, purchases, address FROM " + targetTable, "VALUES ('Aaron', 5, 'Adelphi'), ('Bill', 7, 'Antioch')");
        assertUpdate("DROP TABLE " + sourceTable);
        assertUpdate("DROP TABLE " + targetTable);
    }

    // Target-table layouts (column order and partitioning) for the multi-match test.
    @DataProvider
    public Object[][] targetWithDifferentPartitioning()
    {
        return new Object[][] {
                {"CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (location = 's3://%s/%s')"},
                {"CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (location = 's3://%s/%s', partitioned_by = ARRAY['customer'])"},
                {"CREATE TABLE %s (customer VARCHAR, address VARCHAR, purchases INT) WITH (location = 's3://%s/%s', partitioned_by = ARRAY['address'])"},
                // TODO: enable when https://github.com/trinodb/trino/issues/13505 is fixed
                // {"CREATE TABLE %s (purchases INT, customer VARCHAR, address VARCHAR) WITH (location = 's3://%s/%s', partitioned_by = ARRAY['address', 'customer'])"}
                {"CREATE TABLE %s (purchases INT, address VARCHAR, customer VARCHAR) WITH (location = 's3://%s/%s', partitioned_by = ARRAY['address', 'customer'])"}
        };
    }

    // Runs the standard three-clause MERGE over pairs of target/source layouts that
    // differ in partitioning (see targetAndSourceWithDifferentPartitioning).
    @Test(dataProvider = "targetAndSourceWithDifferentPartitioning")
    public void testMergeWithDifferentPartitioning(String testDescription, String createTargetTableSql, String createSourceTableSql)
    {
        String targetTable = format("%s_target_%s", testDescription, randomTableSuffix());
        String sourceTable = format("%s_source_%s", testDescription, randomTableSuffix());

        assertUpdate(format(createTargetTableSql, targetTable, bucketName, targetTable));

        assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4);

        assertUpdate(format(createSourceTableSql, sourceTable, bucketName, sourceTable));

        assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable), 4);

        @Language("SQL") String sql = format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) +
                " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" +
                " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" +
                " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)";
        assertUpdate(sql, 4);

        assertQuery("SELECT * FROM " + targetTable, "VALUES ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Dave', 22, 'Darbyshire'), ('Ed', 7, 'Etherville')");

        assertUpdate("DROP TABLE " + sourceTable);
        assertUpdate("DROP TABLE " + targetTable);
    }

    // Target/source partitioning combinations for testMergeWithDifferentPartitioning.
    @DataProvider
    public Object[][] targetAndSourceWithDifferentPartitioning()
    {
        return new Object[][] {
                // TODO: enable when https://github.com/trinodb/trino/issues/13505 is fixed
                // {
                //         "target_partitioned_source_and_target_partitioned",
                //         "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (location = 's3://%s/%s', partitioned_by = ARRAY['address', 'customer'])",
                //         "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (location = 's3://%s/%s', partitioned_by = ARRAY['address'])",
                // },
                {
                        "target_partitioned_source_and_target_partitioned",
                        "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (location = 's3://%s/%s', partitioned_by = ARRAY['customer', 'address'])",
                        "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (location = 's3://%s/%s', partitioned_by = ARRAY['address'])",
                },
                {
                        "target_flat_source_partitioned_by_customer",
                        "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (location = 's3://%s/%s')",
                        "CREATE TABLE %s (purchases INT, address VARCHAR, customer VARCHAR) WITH (location = 's3://%s/%s', partitioned_by = ARRAY['customer'])"
                },
                {
                        "target_partitioned_by_customer_source_flat",
                        "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (location = 's3://%s/%s', partitioned_by = ARRAY['address'])",
                        "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (location = 's3://%s/%s')",
                },
                {
                        "target_bucketed_by_customer_source_flat",
                        "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (location = 's3://%s/%s', partitioned_by = ARRAY['customer', 'address'])",
                        "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (location = 's3://%s/%s')",
                },
                {
                        "target_partitioned_source_partitioned",
                        "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (location = 's3://%s/%s', partitioned_by = ARRAY['customer'])",
                        "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (location = 's3://%s/%s', partitioned_by = ARRAY['address'])",
                },
                {
                        "target_partitioned_target_partitioned",
                        "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (location = 's3://%s/%s', partitioned_by = ARRAY['address'])",
                        "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (location = 's3://%s/%s', partitioned_by = ARRAY['customer'])",
                }
        };
    }
a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakePageSink.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakePageSink.java @@ -171,6 +171,7 @@ private static ConnectorPageSink createPageSink(Path outputPath, DeltaLakeWriter new GroupByHashPageIndexerFactory(new JoinCompiler(new TypeOperators()), new BlockTypeOperators()), HDFS_ENVIRONMENT, JsonCodec.jsonCodec(DataFileInfo.class), + JsonCodec.jsonCodec(DeltaLakeMergeResult.class), stats, deltaLakeConfig, new TestingTypeManager(), diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/AbstractHiveAcidWriters.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/AbstractHiveAcidWriters.java index 8081018633be..a72079937bc4 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/AbstractHiveAcidWriters.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/AbstractHiveAcidWriters.java @@ -13,11 +13,19 @@ */ package io.trino.plugin.hive; +import com.google.common.annotations.VisibleForTesting; +import io.trino.plugin.hive.HiveWriterFactory.RowIdSortingFileWriterMaker; import io.trino.plugin.hive.acid.AcidOperation; import io.trino.plugin.hive.acid.AcidTransaction; +import io.trino.plugin.hive.orc.OrcFileWriter; import io.trino.plugin.hive.orc.OrcFileWriterFactory; +import io.trino.spi.Page; import io.trino.spi.block.Block; +import io.trino.spi.block.ColumnarRow; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.type.TypeManager; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; @@ -31,11 +39,16 @@ import static io.trino.orc.OrcWriter.OrcOperation.DELETE; import static io.trino.orc.OrcWriter.OrcOperation.INSERT; import static io.trino.plugin.hive.HiveStorageFormat.ORC; +import static io.trino.plugin.hive.HiveUpdatablePageSource.BUCKET_CHANNEL; +import static 
io.trino.plugin.hive.HiveUpdatablePageSource.ORIGINAL_TRANSACTION_CHANNEL; +import static io.trino.plugin.hive.HiveUpdatablePageSource.ROW_ID_CHANNEL; import static io.trino.plugin.hive.acid.AcidSchema.ACID_COLUMN_NAMES; import static io.trino.plugin.hive.acid.AcidSchema.createAcidSchema; import static io.trino.plugin.hive.metastore.StorageFormat.fromHiveStorageFormat; import static io.trino.plugin.hive.util.ConfigurationUtils.toJobConf; +import static io.trino.spi.block.ColumnarRow.toColumnarRow; import static io.trino.spi.predicate.Utils.nativeValueToBlock; +import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -57,40 +70,53 @@ public abstract class AbstractHiveAcidWriters protected final AcidTransaction transaction; protected final OptionalInt bucketNumber; protected final int statementId; + protected final Block bucketValueBlock; + + private final Optional sortingFileWriterMaker; private final OrcFileWriterFactory orcFileWriterFactory; private final Configuration configuration; protected final ConnectorSession session; protected final HiveType hiveRowType; + private final AcidOperation updateKind; private final Properties hiveAcidSchema; + protected final Block hiveRowTypeNullsBlock; + protected Path deltaDirectory; protected final Path deleteDeltaDirectory; private final String bucketFilename; - protected Optional deltaDirectory; protected Optional deleteFileWriter = Optional.empty(); protected Optional insertFileWriter = Optional.empty(); + private int insertRowCounter; public AbstractHiveAcidWriters( AcidTransaction transaction, int statementId, OptionalInt bucketNumber, + Optional sortingFileWriterMaker, Path bucketPath, boolean originalFile, OrcFileWriterFactory orcFileWriterFactory, Configuration configuration, ConnectorSession session, + TypeManager typeManager, HiveType hiveRowType, AcidOperation updateKind) { 
    // Builds the ACID delete page for the given row-ID block, filling the data column
    // with this writer's cached all-null row block.
    protected Page buildDeletePage(Block rowIds, long writeId)
    {
        return buildDeletePage(rowIds, writeId, hiveRowTypeNullsBlock);
    }

    /**
     * Builds an ORC ACID delete page: operation marker, the original transaction /
     * bucket / row-ID fields extracted from the row-ID row block, the current write ID,
     * and a null row column.
     *
     * @param rowIdsRowBlock row-typed block of ACID row IDs; must contain no null rows
     * @param writeId write ID of the current transaction, repeated for every position
     * @param rowTypeNullsBlock single-position all-null block of the table's row type
     */
    @VisibleForTesting
    public static Page buildDeletePage(Block rowIdsRowBlock, long writeId, Block rowTypeNullsBlock)
    {
        ColumnarRow columnarRow = toColumnarRow(rowIdsRowBlock);
        checkArgument(!columnarRow.mayHaveNull(), "The rowIdsRowBlock may not have null rows");
        int positionCount = rowIdsRowBlock.getPositionCount();
        // We've verified that the rowIds block has no null rows, so it's okay to get the field blocks
        Block[] blockArray = {
                new RunLengthEncodedBlock(DELETE_OPERATION_BLOCK, positionCount),
                columnarRow.getField(ORIGINAL_TRANSACTION_CHANNEL),
                columnarRow.getField(BUCKET_CHANNEL),
                columnarRow.getField(ROW_ID_CHANNEL),
                RunLengthEncodedBlock.create(BIGINT, writeId, positionCount),
                new RunLengthEncodedBlock(rowTypeNullsBlock, positionCount),
        };
        return new Page(blockArray);
    }

    // Returns the next positionCount sequential row IDs and advances the per-writer counter.
    protected Block createRowIdBlock(int positionCount)
    {
        Block block = createRowIdBlock(positionCount, insertRowCounter);
        insertRowCounter += positionCount;
        return block;
    }

    /**
     * Creates a BIGINT block of {@code positionCount} consecutive row IDs starting
     * at {@code rowCounter}.
     */
    @VisibleForTesting
    public static Block createRowIdBlock(int positionCount, int rowCounter)
    {
        long[] rowIds = new long[positionCount];
        for (int index = 0; index < positionCount; index++) {
            rowIds[index] = rowCounter;
            rowCounter++;
        }
        return new LongArrayBlock(positionCount, Optional.empty(), rowIds);
    }
    // Unwraps an optional file writer, failing if it was never initialized.
    // NOTE(review): this signals a state problem, so IllegalStateException would
    // arguably fit better than IllegalArgumentException — confirm before changing.
    private FileWriter getWriter(Optional<FileWriter> writer)
    {
        return writer.orElseThrow(() -> new IllegalArgumentException("writer is not present"));
    }
    /**
     * Returns the synthesized column handle for the MERGE row ID, typed as the
     * ACID row-ID row type ({@code ACID_ROW_ID_ROW_TYPE}).
     */
    public static HiveColumnHandle mergeRowIdColumnHandle()
    {
        return createBaseColumn(UPDATE_ROW_ID_COLUMN_NAME, UPDATE_ROW_ID_COLUMN_INDEX, toHiveType(ACID_ROW_ID_ROW_TYPE), ACID_ROW_ID_ROW_TYPE, SYNTHESIZED, Optional.empty());
    }
/**
 * JSON-serializable handle for an in-progress Hive MERGE operation, pairing the
 * target table handle with the insert handle used for the merge's write path.
 */
public class HiveMergeTableHandle
        implements ConnectorMergeTableHandle
{
    private final HiveTableHandle tableHandle;
    private final HiveInsertTableHandle insertHandle;

    @JsonCreator
    public HiveMergeTableHandle(
            @JsonProperty("tableHandle") HiveTableHandle tableHandle,
            @JsonProperty("insertHandle") HiveInsertTableHandle insertHandle)
    {
        this.tableHandle = requireNonNull(tableHandle, "tableHandle is null");
        this.insertHandle = requireNonNull(insertHandle, "insertHandle is null");
    }

    @Override
    @JsonProperty
    public HiveTableHandle getTableHandle()
    {
        return tableHandle;
    }

    @JsonProperty
    public HiveInsertTableHandle getInsertHandle()
    {
        return insertHandle;
    }
}
io.trino.spi.connector.SchemaNotFoundException; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; @@ -170,6 +172,7 @@ import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.REGULAR; import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.SYNTHESIZED; import static io.trino.plugin.hive.HiveColumnHandle.createBaseColumn; +import static io.trino.plugin.hive.HiveColumnHandle.mergeRowIdColumnHandle; import static io.trino.plugin.hive.HiveColumnHandle.updateRowIdColumnHandle; import static io.trino.plugin.hive.HiveCompressionCodecs.selectCompressionCodec; import static io.trino.plugin.hive.HiveErrorCode.HIVE_COLUMN_ORDER_MISMATCH; @@ -292,6 +295,7 @@ import static io.trino.spi.StandardErrorCode.UNSUPPORTED_TABLE_TYPE; import static io.trino.spi.connector.Constraint.alwaysTrue; import static io.trino.spi.connector.RetryMode.NO_RETRIES; +import static io.trino.spi.connector.RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW; import static io.trino.spi.predicate.TupleDomain.withColumnDomains; import static io.trino.spi.statistics.TableStatisticType.ROW_COUNT; import static io.trino.spi.type.BigintType.BIGINT; @@ -1783,7 +1787,8 @@ public void finishUpdate(ConnectorSession session, ConnectorTableHandle tableHan HdfsContext context = new HdfsContext(session); for (PartitionAndStatementId ps : partitionAndStatementIds) { - createOrcAcidVersionFile(context, new Path(ps.getDeleteDeltaDirectory())); + createOrcAcidVersionFile(context, new Path(ps.getDeltaDirectory().get())); + createOrcAcidVersionFile(context, new Path(ps.getDeleteDeltaDirectory().get())); } LocationHandle locationHandle = locationService.forExistingTable(metastore, session, table); @@ -1791,8 +1796,78 @@ public void finishUpdate(ConnectorSession session, ConnectorTableHandle tableHan metastore.finishUpdate(session, table.getDatabaseName(), table.getTableName(), writeInfo.getWritePath(), partitionAndStatementIds); } + @Override + public RowChangeParadigm 
getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle) + { + HiveTableHandle handle = (HiveTableHandle) tableHandle; + Optional> properties = handle.getTableParameters(); + if (isTransactionalTable(properties.get())) { + return DELETE_ROW_AND_INSERT_ROW; + } + // TODO: At some point we should add detection to see if the metastore supports + // transactional tables and just say merge isn't supported in the HMS + throw new TrinoException(NOT_SUPPORTED, "Hive merge is only supported for transactional tables"); + } + + @Override + public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, RetryMode retryMode) + { + HiveTableHandle hiveTableHandle = (HiveTableHandle) tableHandle; + SchemaTableName tableName = hiveTableHandle.getSchemaTableName(); + Table table = metastore.getTable(tableName.getSchemaName(), tableName.getTableName()) + .orElseThrow(() -> new TableNotFoundException(tableName)); + + if (!isTransactionalTable(table.getParameters())) { + throw new TrinoException(NOT_SUPPORTED, "Hive merge is only supported for transactional tables"); + } + + HiveInsertTableHandle insertHandle = beginInsertOrMerge(session, tableHandle, retryMode, "Merging", true); + return new HiveMergeTableHandle(hiveTableHandle.withTransaction(insertHandle.getTransaction()), insertHandle); + } + + @Override + public void finishMerge(ConnectorSession session, ConnectorMergeTableHandle tableHandle, Collection fragments, Collection computedStatistics) + { + HiveMergeTableHandle mergeHandle = (HiveMergeTableHandle) tableHandle; + HiveInsertTableHandle insertHandle = mergeHandle.getInsertHandle(); + HiveTableHandle handle = mergeHandle.getTableHandle(); + checkArgument(handle.isAcidMerge(), "handle should be a merge handle, but is %s", handle); + + requireNonNull(fragments, "fragments is null"); + List partitionMergeResults = fragments.stream() + .map(Slice::getBytes) + .map(PartitionUpdateAndMergeResults.CODEC::fromJson) + 
.collect(toImmutableList()); + + List partitionUpdates = partitionMergeResults.stream() + .map(PartitionUpdateAndMergeResults::getPartitionUpdate) + .collect(toImmutableList()); + + Table table = finishChangingTable(AcidOperation.MERGE, "merge", session, insertHandle, partitionUpdates, computedStatistics); + + HdfsContext context = new HdfsContext(session); + for (PartitionUpdateAndMergeResults results : partitionMergeResults) { + results.getDeltaDirectory().ifPresent(deltaDirectory -> createOrcAcidVersionFile(context, new Path(deltaDirectory))); + results.getDeleteDeltaDirectory().ifPresent(deleteDeltadDirectory -> createOrcAcidVersionFile(context, new Path(deleteDeltadDirectory))); + } + + List partitions = partitionUpdates.stream() + .filter(update -> !update.getName().isEmpty()) + .map(update -> buildPartitionObject(session, table, update)) + .collect(toImmutableList()); + + LocationHandle locationHandle = locationService.forExistingTable(metastore, session, table); + WriteInfo writeInfo = locationService.getQueryWriteInfo(locationHandle); + metastore.finishMerge(session, table.getDatabaseName(), table.getTableName(), writeInfo.getWritePath(), partitionMergeResults, partitions); + } + @Override public HiveInsertTableHandle beginInsert(ConnectorSession session, ConnectorTableHandle tableHandle, List columns, RetryMode retryMode) + { + return beginInsertOrMerge(session, tableHandle, retryMode, "Inserting", false); + } + + private HiveInsertTableHandle beginInsertOrMerge(ConnectorSession session, ConnectorTableHandle tableHandle, RetryMode retryMode, String description, boolean isForMerge) { SchemaTableName tableName = ((HiveTableHandle) tableHandle).getSchemaTableName(); Table table = metastore.getTable(tableName.getSchemaName(), tableName.getTableName()) @@ -1802,7 +1877,7 @@ public HiveInsertTableHandle beginInsert(ConnectorSession session, ConnectorTabl for (Column column : table.getDataColumns()) { if (!isWritableType(column.getType())) { - throw new 
TrinoException(NOT_SUPPORTED, format("Inserting into Hive table %s with column type %s not supported", tableName, column.getType())); + throw new TrinoException(NOT_SUPPORTED, format("%s into Hive table %s with column type %s not supported", description, tableName, column.getType())); } } @@ -1824,16 +1899,22 @@ public HiveInsertTableHandle beginInsert(ConnectorSession session, ConnectorTabl HiveStorageFormat tableStorageFormat = extractHiveStorageFormat(table); Optional.ofNullable(table.getParameters().get(SKIP_HEADER_COUNT_KEY)).map(Integer::parseInt).ifPresent(headerSkipCount -> { if (headerSkipCount > 1) { - throw new TrinoException(NOT_SUPPORTED, format("Inserting into Hive table with value of %s property greater than 1 is not supported", SKIP_HEADER_COUNT_KEY)); + throw new TrinoException(NOT_SUPPORTED, format("%s into Hive table with value of %s property greater than 1 is not supported", description, SKIP_HEADER_COUNT_KEY)); } }); if (table.getParameters().containsKey(SKIP_FOOTER_COUNT_KEY)) { - throw new TrinoException(NOT_SUPPORTED, format("Inserting into Hive table with %s property not supported", SKIP_FOOTER_COUNT_KEY)); + throw new TrinoException(NOT_SUPPORTED, format("%s into Hive table with %s property not supported", description, SKIP_FOOTER_COUNT_KEY)); } LocationHandle locationHandle = locationService.forExistingTable(metastore, session, table); - AcidTransaction transaction = isTransactional ? 
metastore.beginInsert(session, table) : NO_ACID_TRANSACTION; - + AcidTransaction transaction = NO_ACID_TRANSACTION; + if (isForMerge) { + checkArgument(isTransactional, "The target table in Hive MERGE must be a transactional table"); + transaction = metastore.beginMerge(session, table); + } + else if (isTransactional) { + transaction = metastore.beginInsert(session, table); + } HiveInsertTableHandle result = new HiveInsertTableHandle( tableName.getSchemaName(), tableName.getTableName(), @@ -1872,13 +1953,32 @@ public Optional finishInsert(ConnectorSession session, .map(partitionUpdateCodec::fromJson) .collect(toImmutableList()); + Table table = finishChangingTable(AcidOperation.INSERT, "insert", session, handle, partitionUpdates, computedStatistics); + + if (isFullAcidTable(table.getParameters())) { + HdfsContext context = new HdfsContext(session); + for (PartitionUpdate update : partitionUpdates) { + long writeId = handle.getTransaction().getWriteId(); + Path deltaDirectory = new Path(format("%s/%s/%s", table.getStorage().getLocation(), update.getName(), deltaSubdir(writeId, writeId, 0))); + createOrcAcidVersionFile(context, deltaDirectory); + } + } + + return Optional.of(new HiveWrittenPartitions( + partitionUpdates.stream() + .map(PartitionUpdate::getName) + .collect(toImmutableList()))); + } + + private Table finishChangingTable(AcidOperation acidOperation, String changeDescription, ConnectorSession session, HiveInsertTableHandle handle, List partitionUpdates, Collection computedStatistics) + { HiveStorageFormat tableStorageFormat = handle.getTableStorageFormat(); partitionUpdates = PartitionUpdate.mergePartitionUpdates(partitionUpdates); Table table = metastore.getTable(handle.getSchemaName(), handle.getTableName()) .orElseThrow(() -> new TableNotFoundException(handle.getSchemaTableName())); if (!table.getStorage().getStorageFormat().getInputFormat().equals(tableStorageFormat.getInputFormat()) && isRespectTableFormat(session)) { - throw new 
TrinoException(HIVE_CONCURRENT_MODIFICATION_DETECTED, "Table format changed during insert"); + throw new TrinoException(HIVE_CONCURRENT_MODIFICATION_DETECTED, "Table format changed during " + changeDescription); } if (handle.getBucketProperty().isPresent() && isCreateEmptyBucketFiles(session)) { @@ -1914,7 +2014,7 @@ public Optional finishInsert(ConnectorSession session, if (partitionUpdate.getName().isEmpty()) { // insert into unpartitioned table if (!table.getStorage().getStorageFormat().getInputFormat().equals(handle.getPartitionStorageFormat().getInputFormat()) && isRespectTableFormat(session)) { - throw new TrinoException(HIVE_CONCURRENT_MODIFICATION_DETECTED, "Table format changed during insert"); + throw new TrinoException(HIVE_CONCURRENT_MODIFICATION_DETECTED, "Table format changed during " + changeDescription); } PartitionStatistics partitionStatistics = createPartitionStatistics( @@ -1934,7 +2034,8 @@ public Optional finishInsert(ConnectorSession session, } else if (partitionUpdate.getUpdateMode() == NEW || partitionUpdate.getUpdateMode() == APPEND) { // insert into unpartitioned table - metastore.finishInsertIntoExistingTable( + metastore.finishChangingExistingTable( + acidOperation, session, handle.getSchemaName(), handle.getTableName(), @@ -1995,20 +2096,7 @@ else if (partitionUpdate.getUpdateMode() == NEW || partitionUpdate.getUpdateMode throw new IllegalArgumentException(format("Unsupported update mode: %s", partitionUpdate.getUpdateMode())); } } - - if (isFullAcidTable(table.getParameters())) { - HdfsContext context = new HdfsContext(session); - for (PartitionUpdate update : partitionUpdates) { - long writeId = handle.getTransaction().getWriteId(); - Path deltaDirectory = new Path(format("%s/%s/%s", table.getStorage().getLocation(), update.getName(), deltaSubdir(writeId, writeId, 0))); - createOrcAcidVersionFile(context, deltaDirectory); - } - } - - return Optional.of(new HiveWrittenPartitions( - partitionUpdates.stream() - 
.map(PartitionUpdate::getName) - .collect(toImmutableList()))); + return table; } private void removeNonCurrentQueryFiles(ConnectorSession session, Path partitionPath) @@ -2253,7 +2341,12 @@ private void finishOptimize(ConnectorSession session, ConnectorTableExecuteHandl throw new TrinoException(HIVE_CONCURRENT_MODIFICATION_DETECTED, "Table format changed during optimize"); } - metastore.finishInsertIntoExistingTable( + AcidOperation operation = handle.getTransaction().getOperation(); + if (operation == AcidOperation.NONE) { + operation = AcidOperation.INSERT; + } + metastore.finishChangingExistingTable( + operation, session, handle.getSchemaName(), handle.getTableName(), @@ -2559,7 +2652,6 @@ public void finishDelete(ConnectorSession session, ConnectorTableHandle tableHan SchemaTableName tableName = handle.getSchemaTableName(); Table table = metastore.getTable(tableName.getSchemaName(), tableName.getTableName()) .orElseThrow(() -> new TableNotFoundException(tableName)); - ensureTableSupportsDelete(table); List partitionAndStatementIds = fragments.stream() .map(Slice::getBytes) @@ -2568,7 +2660,7 @@ public void finishDelete(ConnectorSession session, ConnectorTableHandle tableHan HdfsContext context = new HdfsContext(session); for (PartitionAndStatementId ps : partitionAndStatementIds) { - createOrcAcidVersionFile(context, new Path(ps.getDeleteDeltaDirectory())); + createOrcAcidVersionFile(context, new Path(ps.getDeleteDeltaDirectory().get())); } LocationHandle locationHandle = locationService.forExistingTable(metastore, session, table); @@ -2596,6 +2688,12 @@ public ColumnHandle getUpdateRowIdColumnHandle(ConnectorSession session, Connect return updateRowIdColumnHandle(table.getDataColumns(), updatedColumns); } + @Override + public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return mergeRowIdColumnHandle(); + } + @Override public Optional applyDelete(ConnectorSession session, ConnectorTableHandle handle) 
{ diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveNodePartitioningProvider.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveNodePartitioningProvider.java index c1d48ee09444..3301c38f99f6 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveNodePartitioningProvider.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveNodePartitioningProvider.java @@ -28,6 +28,7 @@ import javax.inject.Inject; import java.util.List; +import java.util.Optional; import java.util.function.ToIntFunction; import static io.trino.spi.connector.ConnectorBucketNodeMap.createBucketNodeMap; @@ -71,11 +72,11 @@ public BucketFunction getBucketFunction( } @Override - public ConnectorBucketNodeMap getBucketNodeMap(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) + public Optional getBucketNodeMapping(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) { HivePartitioningHandle handle = (HivePartitioningHandle) partitioningHandle; if (!handle.isUsePartitionedBucketing()) { - return createBucketNodeMap(handle.getBucketCount()); + return Optional.of(createBucketNodeMap(handle.getBucketCount())); } // Allocate a fixed number of buckets. 
Trino will assign consecutive buckets @@ -88,7 +89,7 @@ public ConnectorBucketNodeMap getBucketNodeMap(ConnectorTransactionHandle transa // However, number of partitions is not known here // If number of workers < ( P * B), we need multiple writers per node to fully // parallelize the write within a worker - return createBucketNodeMap(nodeManager.getRequiredWorkerNodes().size() * PARTITIONED_BUCKETS_PER_NODE); + return Optional.of(createBucketNodeMap(nodeManager.getRequiredWorkerNodes().size() * PARTITIONED_BUCKETS_PER_NODE)); } @Override diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSink.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSink.java index c0833772b498..91c59996751c 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSink.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSink.java @@ -30,6 +30,7 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.IntArrayBlockBuilder; +import io.trino.spi.connector.ConnectorMergeSink; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.Type; @@ -47,6 +48,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executors; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.slice.Slices.wrappedBuffer; @@ -58,7 +60,7 @@ import static java.util.stream.Collectors.toList; public class HivePageSink - implements ConnectorPageSink + implements ConnectorPageSink, ConnectorMergeSink { private static final Logger log = Logger.get(HivePageSink.class); @@ -90,11 +92,13 @@ public class HivePageSink private final List partitionUpdates = new ArrayList<>(); private final List> verificationTasks = new ArrayList<>(); + private final boolean 
isMergeSink; private long writtenBytes; private long memoryUsage; private long validationCpuNanos; public HivePageSink( + HiveWritableTableHandle tableHandle, HiveWriterFactory writerFactory, List inputColumns, boolean isTransactional, @@ -118,6 +122,7 @@ public HivePageSink( this.writeVerificationExecutor = requireNonNull(writeVerificationExecutor, "writeVerificationExecutor is null"); this.partitionUpdateCodec = requireNonNull(partitionUpdateCodec, "partitionUpdateCodec is null"); + this.isMergeSink = tableHandle.getTransaction().isMerge(); requireNonNull(bucketProperty, "bucketProperty is null"); this.pagePartitioner = new HiveWriterPagePartitioner( inputColumns, @@ -188,11 +193,30 @@ public CompletableFuture> finish() { // Must be wrapped in doAs entirely // Implicit FileSystem initializations are possible in HiveRecordWriter#commit -> RecordWriter#close - ListenableFuture> result = hdfsEnvironment.doAs(session.getIdentity(), this::doFinish); + ListenableFuture> result = hdfsEnvironment.doAs( + session.getIdentity(), + isMergeSink ? 
this::doMergeSinkFinish : this::doInsertSinkFinish); + return MoreFutures.toCompletableFuture(result); } - private ListenableFuture> doFinish() + private ListenableFuture> doMergeSinkFinish() + { + ImmutableList.Builder resultSlices = ImmutableList.builder(); + for (HiveWriter writer : writers) { + writer.commit(); + MergeFileWriter mergeFileWriter = (MergeFileWriter) writer.getFileWriter(); + PartitionUpdateAndMergeResults results = mergeFileWriter.getPartitionUpdateAndMergeResults(writer.getPartitionUpdate()); + resultSlices.add(wrappedBuffer(PartitionUpdateAndMergeResults.CODEC.toJsonBytes(results))); + } + List result = resultSlices.build(); + writtenBytes = writers.stream() + .mapToLong(HiveWriter::getWrittenBytes) + .sum(); + return Futures.immediateFuture(result); + } + + private ListenableFuture> doInsertSinkFinish() { for (HiveWriter writer : writers) { closeWriter(writer); @@ -373,9 +397,11 @@ private int[] getWriterIndexes(Page page) OptionalInt bucketNumber = OptionalInt.empty(); if (bucketBlock != null) { - bucketNumber = OptionalInt.of(bucketBlock.getInt(position, 0)); + bucketNumber = OptionalInt.of((int) INTEGER.getLong(bucketBlock, position)); } + writer = writerFactory.createWriter(partitionColumns, position, bucketNumber); + writers.set(writerIndex, writer); } verify(writers.size() == pagePartitioner.getMaxIndex() + 1); @@ -386,6 +412,9 @@ private int[] getWriterIndexes(Page page) private Page getDataPage(Page page) { + if (isMergeSink) { + return page; + } Block[] blocks = new Block[dataColumnInputIndex.length]; for (int i = 0; i < dataColumnInputIndex.length; i++) { int dataColumn = dataColumnInputIndex[i]; @@ -419,6 +448,13 @@ private static Page extractColumns(Page page, int[] columns) return new Page(page.getPositionCount(), blocks); } + @Override + public void storeMergedRows(Page page) + { + checkArgument(isMergeSink, "isMergeSink is false"); + appendPage(page); + } + private static class HiveWriterPagePartitioner { private final 
PageIndexer pageIndexer; diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSinkProvider.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSinkProvider.java index 33d6b9f6e21f..529563237e06 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSinkProvider.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSinkProvider.java @@ -27,6 +27,8 @@ import io.trino.spi.PageIndexerFactory; import io.trino.spi.PageSorter; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeSink; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorOutputTableHandle; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorPageSinkProvider; @@ -44,6 +46,7 @@ import java.util.OptionalInt; import java.util.Set; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.trino.plugin.hive.metastore.cache.CachingHiveMetastore.memoizeMetastore; @@ -129,7 +132,16 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa return createPageSink(handle, false, session, ImmutableMap.of()); } - private ConnectorPageSink createPageSink(HiveWritableTableHandle handle, boolean isCreateTable, ConnectorSession session, Map additionalTableParameters) + @Override + public ConnectorMergeSink createMergeSink(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorMergeTableHandle mergeHandle) + { + HiveMergeTableHandle hiveMergeHandle = (HiveMergeTableHandle) mergeHandle; + HiveInsertTableHandle insertHandle = hiveMergeHandle.getInsertHandle(); + checkArgument(insertHandle.getTransaction().isMerge(), "handle isn't an ACID MERGE"); + return createPageSink(insertHandle, false, session, 
ImmutableMap.of()); + } + + private HivePageSink createPageSink(HiveWritableTableHandle handle, boolean isCreateTable, ConnectorSession session, Map additionalTableParameters) { OptionalInt bucketCount = OptionalInt.empty(); List sortedBy = ImmutableList.of(); @@ -170,6 +182,7 @@ private ConnectorPageSink createPageSink(HiveWritableTableHandle handle, boolean hiveWriterStats); return new HivePageSink( + handle, writerFactory, handle.getInputColumns(), handle.isTransactional(), diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveTableHandle.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveTableHandle.java index a42830ef250d..92d55c3f42ac 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveTableHandle.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveTableHandle.java @@ -435,6 +435,12 @@ public boolean isAcidUpdate() return transaction.isUpdate(); } + @JsonIgnore + public boolean isAcidMerge() + { + return transaction.isMerge(); + } + @JsonIgnore public Optional getUpdateProcessor() { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdatablePageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdatablePageSource.java index 995f0bef8c93..3ddba881672a 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdatablePageSource.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdatablePageSource.java @@ -44,7 +44,6 @@ import static io.trino.plugin.hive.PartitionAndStatementId.CODEC; import static io.trino.spi.StandardErrorCode.GENERIC_INSUFFICIENT_RESOURCES; import static io.trino.spi.block.ColumnarRow.toColumnarRow; -import static io.trino.spi.predicate.Utils.nativeValueToBlock; import static io.trino.spi.type.BigintType.BIGINT; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -65,7 +64,6 @@ public class HiveUpdatablePageSource private final String partitionName; private final 
ConnectorPageSource hivePageSource; private final AcidOperation updateKind; - private final Block hiveRowTypeNullsBlock; private final long writeId; private final Optional> dependencyChannels; @@ -95,11 +93,10 @@ public HiveUpdatablePageSource( long initialRowId, long maxNumberOfRowsPerSplit) { - super(hiveTableHandle.getTransaction(), statementId, bucketNumber, bucketPath, originalFile, orcFileWriterFactory, configuration, session, hiveRowType, updateKind); + super(hiveTableHandle.getTransaction(), statementId, bucketNumber, Optional.empty(), bucketPath, originalFile, orcFileWriterFactory, configuration, session, typeManager, hiveRowType, updateKind); this.partitionName = requireNonNull(partitionName, "partitionName is null"); this.hivePageSource = requireNonNull(hivePageSource, "hivePageSource is null"); this.updateKind = requireNonNull(updateKind, "updateKind is null"); - this.hiveRowTypeNullsBlock = nativeValueToBlock(hiveRowType.getType(typeManager), null); checkArgument(hiveTableHandle.isInAcidTransaction(), "Not in a transaction; hiveTableHandle: %s", hiveTableHandle); this.writeId = hiveTableHandle.getWriteId(); this.initialRowId = initialRowId; @@ -147,8 +144,7 @@ private void deleteRowsInternal(ColumnarRow columnarRow) maxWriteId = Math.max(maxWriteId, originalTransactionChannel.getLong(index, 0)); } - lazyInitializeDeleteFileWriter(); - deleteFileWriter.orElseThrow(() -> new IllegalArgumentException("deleteFileWriter not present")).appendRows(deletePage); + getOrCreateDeleteFileWriter().appendRows(deletePage); rowCount += positionCount; } @@ -178,11 +174,12 @@ public void updateRows(Page page, List columnValueAndRowIdChannels) }; Page insertPage = new Page(blockArray); - lazyInitializeInsertFileWriter(); + getOrCreateInsertFileWriter(); insertFileWriter.orElseThrow(() -> new IllegalArgumentException("insertFileWriter not present")).appendRows(insertPage); } - Block createRowIdBlock(int positionCount) + @Override + protected Block createRowIdBlock(int 
positionCount) { long[] rowIds = new long[positionCount]; for (int index = 0; index < positionCount; index++) { @@ -214,8 +211,7 @@ public CompletableFuture> finish() OrcFileWriter insertWriter = (OrcFileWriter) insertFileWriter.get(); insertWriter.setMaxWriteId(maxWriteId); insertWriter.commit(); - checkArgument(deltaDirectory.isPresent(), "deltaDirectory not present"); - deltaDirectoryString = Optional.of(deltaDirectory.get().toString()); + deltaDirectoryString = Optional.of(deltaDirectory.toString()); break; default: @@ -225,7 +221,7 @@ public CompletableFuture> finish() partitionName, statementId, rowCount, - deleteDeltaDirectory.toString(), + Optional.of(deleteDeltaDirectory.toString()), deltaDirectoryString))); return completedFuture(ImmutableList.of(fragment)); } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdateProcessor.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdateProcessor.java index a57ce4c972a7..4f1a0c952224 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdateProcessor.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdateProcessor.java @@ -164,9 +164,9 @@ public Page removeNonDependencyColumns(Page page, List dependencyChanne } /** - * Return the column UPDATE column handle, which depends on the 3 ACID columns as well as the non-updated columns. + * Return the column rowId for UPDATE or MERGED column handle, which depends on the 3 ACID columns as well as the non-updated columns. 
*/ - public static HiveColumnHandle getUpdateRowIdColumnHandle(List nonUpdatedColumnHandles) + public static HiveColumnHandle getRowIdColumnHandleForNonUpdatedColumns(List nonUpdatedColumnHandles) { List allAcidFields = new ArrayList<>(ACID_READ_FIELDS); if (!nonUpdatedColumnHandles.isEmpty()) { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriter.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriter.java index 308becf8e752..c4e89fd130dd 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriter.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriter.java @@ -57,6 +57,11 @@ public HiveWriter( this.hiveWriterStats = requireNonNull(hiveWriterStats, "hiveWriterStats is null"); } + public FileWriter getFileWriter() + { + return fileWriter; + } + public long getWrittenBytes() { return fileWriter.getWrittenBytes(); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriterFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriterFactory.java index 561ecbf1259d..045229b8ceba 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriterFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveWriterFactory.java @@ -39,6 +39,7 @@ import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.SortOrder; +import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; import org.apache.hadoop.conf.Configuration; @@ -88,7 +89,9 @@ import static io.trino.plugin.hive.HiveSessionProperties.getTemporaryStagingDirectoryPath; import static io.trino.plugin.hive.HiveSessionProperties.getTimestampPrecision; import static io.trino.plugin.hive.HiveSessionProperties.isTemporaryStagingDirectoryEnabled; +import static io.trino.plugin.hive.HiveType.toHiveType; import static io.trino.plugin.hive.LocationHandle.WriteMode.DIRECT_TO_TARGET_EXISTING_DIRECTORY; 
+import static io.trino.plugin.hive.acid.AcidOperation.CREATE_TABLE; import static io.trino.plugin.hive.metastore.MetastoreUtil.getHiveSchema; import static io.trino.plugin.hive.metastore.StorageFormat.fromHiveStorageFormat; import static io.trino.plugin.hive.util.CompressionConfigUtil.assertCompressionConfigured; @@ -97,6 +100,9 @@ import static io.trino.plugin.hive.util.HiveUtil.getColumnNames; import static io.trino.plugin.hive.util.HiveUtil.getColumnTypes; import static io.trino.plugin.hive.util.HiveWriteUtils.createPartitionValues; +import static io.trino.spi.connector.SortOrder.ASC_NULLS_FIRST; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.IntegerType.INTEGER; import static java.lang.Math.min; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -122,6 +128,7 @@ public class HiveWriterFactory private final String schemaName; private final String tableName; private final AcidTransaction transaction; + private final List inputColumns; private final List dataColumns; @@ -159,6 +166,8 @@ public class HiveWriterFactory private final Map sessionProperties; private final HiveWriterStats hiveWriterStats; + private final Optional rowType; + private final Optional hiveRowtype; public HiveWriterFactory( Set fileWriterFactories, @@ -192,6 +201,7 @@ public HiveWriterFactory( this.schemaName = requireNonNull(schemaName, "schemaName is null"); this.tableName = requireNonNull(tableName, "tableName is null"); this.transaction = requireNonNull(transaction, "transaction is null"); + this.inputColumns = requireNonNull(inputColumns, "inputColumns is null"); this.tableStorageFormat = requireNonNull(tableStorageFormat, "tableStorageFormat is null"); this.partitionStorageFormat = requireNonNull(partitionStorageFormat, "partitionStorageFormat is null"); this.additionalTableParameters = ImmutableMap.copyOf(requireNonNull(additionalTableParameters, "additionalTableParameters is null")); @@ -213,7 +223,6 @@ 
public HiveWriterFactory( this.parquetTimeZone = requireNonNull(parquetTimeZone, "parquetTimeZone is null"); // divide input columns into partition and data columns - requireNonNull(inputColumns, "inputColumns is null"); ImmutableList.Builder partitionColumnNames = ImmutableList.builder(); ImmutableList.Builder partitionColumnTypes = ImmutableList.builder(); ImmutableList.Builder dataColumns = ImmutableList.builder(); @@ -227,6 +236,18 @@ public HiveWriterFactory( dataColumns.add(new DataColumn(column.getName(), hiveType)); } } + if (transaction.isMerge()) { + Type mergeRowType = RowType.from(inputColumns.stream() + .filter(column -> !column.isPartitionKey()) + .map(column -> new RowType.Field(Optional.of(column.getName()), column.getType())) + .collect(toImmutableList())); + this.rowType = Optional.of(mergeRowType); + this.hiveRowtype = Optional.of(toHiveType(mergeRowType)); + } + else { + this.rowType = Optional.empty(); + this.hiveRowtype = Optional.empty(); + } this.partitionColumnNames = partitionColumnNames.build(); this.partitionColumnTypes = partitionColumnTypes.build(); this.dataColumns = dataColumns.build(); @@ -457,10 +478,11 @@ public HiveWriter createWriter(Page partitionColumns, int position, OptionalInt Path path; String fileNameWithExtension; - if (transaction.isAcidTransactionRunning()) { + if (transaction.isAcidTransactionRunning() && transaction.getOperation() != CREATE_TABLE) { String subdir = computeAcidSubdir(transaction); Path subdirPath = new Path(writeInfo.getWritePath(), subdir); - path = createHiveBucketPath(subdirPath, bucketToUse, table.getParameters()); + String nameFormat = table != null && isInsertOnlyTable(table.getParameters()) ? 
"%05d_0" : "bucket_%05d"; + path = new Path(subdirPath, format(nameFormat, bucketToUse)); fileNameWithExtension = path.getName(); } else { @@ -472,24 +494,36 @@ public HiveWriter createWriter(Page partitionColumns, int position, OptionalInt boolean useAcidSchema = isCreateTransactionalTable || (table != null && isFullAcidTable(table.getParameters())); FileWriter hiveFileWriter = null; - for (HiveFileWriterFactory fileWriterFactory : fileWriterFactories) { - Optional fileWriter = fileWriterFactory.createFileWriter( - path, - dataColumns.stream() - .map(DataColumn::getName) - .collect(toList()), - outputStorageFormat, - schema, - outputConf, - session, - bucketNumber, - transaction, - useAcidSchema, - WriterKind.INSERT); - - if (fileWriter.isPresent()) { - hiveFileWriter = fileWriter.get(); - break; + + if (transaction.isMerge()) { + OrcFileWriterFactory orcFileWriterFactory = (OrcFileWriterFactory) fileWriterFactories.stream() + .filter(factory -> factory instanceof OrcFileWriterFactory) + .findFirst() + .get(); + checkArgument(hiveRowtype.isPresent(), "rowTypes not present"); + RowIdSortingFileWriterMaker fileWriterMaker = (deleteWriter, deletePath) -> makeRowIdSortingWriter(deleteWriter, deletePath); + hiveFileWriter = new MergeFileWriter(transaction, 0, bucketNumber, fileWriterMaker, path, partitionName, orcFileWriterFactory, inputColumns, conf, session, typeManager, hiveRowtype.get()); + } + else { + for (HiveFileWriterFactory fileWriterFactory : fileWriterFactories) { + Optional fileWriter = fileWriterFactory.createFileWriter( + path, + dataColumns.stream() + .map(DataColumn::getName) + .collect(toList()), + outputStorageFormat, + schema, + outputConf, + session, + bucketNumber, + transaction, + useAcidSchema, + WriterKind.INSERT); + + if (fileWriter.isPresent()) { + hiveFileWriter = fileWriter.get(); + break; + } } } @@ -604,10 +638,43 @@ public HiveWriter createWriter(Page partitionColumns, int position, OptionalInt hiveWriterStats); } - private static Path 
createHiveBucketPath(Path subdirPath, int bucketToUse, Map tableParameters) + public interface RowIdSortingFileWriterMaker { - String nameFormat = isInsertOnlyTable(tableParameters) ? "%05d_0" : "bucket_%05d"; - return new Path(subdirPath, format(nameFormat, bucketToUse)); + SortingFileWriter makeFileWriter(FileWriter deleteFileWriter, Path path); + } + + public SortingFileWriter makeRowIdSortingWriter(FileWriter deleteFileWriter, Path path) + { + FileSystem fileSystem; + Path tempFilePath = new Path(path.getParent(), ".tmp-sort." + path.getName()); + try { + Configuration configuration = new Configuration(conf); + // Explicitly set the default FS to local file system to avoid getting HDFS when sortedWritingTempStagingPath specifies no scheme + configuration.set(FS_DEFAULT_NAME_KEY, "file:///"); + fileSystem = hdfsEnvironment.getFileSystem(session.getIdentity(), tempFilePath, configuration); + } + catch (IOException e) { + throw new TrinoException(HIVE_WRITER_OPEN_ERROR, e); + } + // The ORC columns are: operation, originalTransaction, bucket, rowId, row + // The deleted rows should be sorted by originalTransaction, then by rowId + List sortFields = ImmutableList.of(1, 3); + List sortOrders = ImmutableList.of(ASC_NULLS_FIRST, ASC_NULLS_FIRST); + // The types are indexed by sortField in the SortFileWriter stack + List types = ImmutableList.of(INTEGER, BIGINT, INTEGER, BIGINT, BIGINT, rowType.get()); + + return new SortingFileWriter( + fileSystem, + tempFilePath, + deleteFileWriter, + sortBufferSize, + maxOpenSortFiles, + types, + sortFields, + sortOrders, + pageSorter, + typeManager.getTypeOperators(), + OrcFileWriterFactory::createOrcDataSink); } private void validateSchema(Optional partitionName, Properties schema) @@ -661,9 +728,12 @@ private String computeAcidSubdir(AcidTransaction transaction) long writeId = transaction.getWriteId(); switch (transaction.getOperation()) { case INSERT: + case CREATE_TABLE: return deltaSubdir(writeId, writeId, 0); case DELETE: 
return deleteDeltaSubdir(writeId, writeId, 0); + case MERGE: + return deltaSubdir(writeId, writeId, 0); default: throw new UnsupportedOperationException("transaction operation is " + transaction.getOperation()); } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/MergeFileWriter.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/MergeFileWriter.java new file mode 100644 index 000000000000..fa72951e9fd2 --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/MergeFileWriter.java @@ -0,0 +1,178 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.hive; + +import com.google.common.annotations.VisibleForTesting; +import io.trino.plugin.hive.HiveWriterFactory.RowIdSortingFileWriterMaker; +import io.trino.plugin.hive.acid.AcidOperation; +import io.trino.plugin.hive.acid.AcidTransaction; +import io.trino.plugin.hive.orc.OrcFileWriterFactory; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.RowBlock; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.MergePage; +import io.trino.spi.type.TypeManager; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; + +import java.util.List; +import java.util.Optional; +import java.util.OptionalInt; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.spi.connector.MergePage.createDeleteAndInsertPages; +import static io.trino.spi.type.BigintType.BIGINT; +import static java.util.Objects.requireNonNull; + +public class MergeFileWriter + extends AbstractHiveAcidWriters + implements FileWriter +{ + private final String partitionName; + private final List inputColumns; + + private int deleteRowCount; + private int insertRowCount; + + public MergeFileWriter( + AcidTransaction transaction, + int statementId, + OptionalInt bucketNumber, + RowIdSortingFileWriterMaker sortingFileWriterMaker, + Path bucketPath, + Optional partitionName, + OrcFileWriterFactory orcFileWriterFactory, + List inputColumns, + Configuration configuration, + ConnectorSession session, + TypeManager typeManager, + HiveType hiveRowType) + { + super(transaction, + statementId, + bucketNumber, + Optional.of(sortingFileWriterMaker), + bucketPath, + false, + orcFileWriterFactory, + configuration, + session, + typeManager, + hiveRowType, + AcidOperation.MERGE); + this.partitionName = requireNonNull(partitionName, "partitionName is null").orElse(""); + this.inputColumns = requireNonNull(inputColumns, 
"inputColumns is null"); + } + + @Override + public void appendRows(Page page) + { + if (page.getPositionCount() == 0) { + return; + } + + MergePage mergePage = createDeleteAndInsertPages(page, inputColumns.size()); + mergePage.getDeletionsPage().ifPresent(deletePage -> { + Block acidBlock = deletePage.getBlock(deletePage.getChannelCount() - 1); + Page orcDeletePage = buildDeletePage(acidBlock, transaction.getWriteId()); + getOrCreateDeleteFileWriter().appendRows(orcDeletePage); + deleteRowCount += deletePage.getPositionCount(); + }); + mergePage.getInsertionsPage().ifPresent(insertPage -> { + Page orcInsertPage = buildInsertPage(insertPage, transaction.getWriteId(), inputColumns, bucketValueBlock, insertRowCount); + insertRowCount += insertPage.getPositionCount(); + getOrCreateInsertFileWriter().appendRows(orcInsertPage); + insertRowCount += insertPage.getPositionCount(); + }); + } + + @VisibleForTesting + public static Page buildInsertPage(Page insertPage, long writeId, List columns, Block bucketValueBlock, int insertRowCount) + { + int positionCount = insertPage.getPositionCount(); + List dataColumns = columns.stream() + .filter(column -> !column.isPartitionKey() && !column.isHidden()) + .map(column -> insertPage.getBlock(column.getBaseHiveColumnIndex())) + .collect(toImmutableList()); + Block mergedColumnsBlock = RowBlock.fromFieldBlocks(positionCount, Optional.empty(), dataColumns.toArray(new Block[]{})); + Block currentTransactionBlock = RunLengthEncodedBlock.create(BIGINT, writeId, positionCount); + Block[] blockArray = { + new RunLengthEncodedBlock(INSERT_OPERATION_BLOCK, positionCount), + currentTransactionBlock, + new RunLengthEncodedBlock(bucketValueBlock, positionCount), + createRowIdBlock(positionCount, insertRowCount), + currentTransactionBlock, + mergedColumnsBlock + }; + + return new Page(blockArray); + } + + public String getPartitionName() + { + return partitionName; + } + + @Override + public long getWrittenBytes() + { + return 
deleteFileWriter.map(FileWriter::getWrittenBytes).orElse(0L) + + insertFileWriter.map(FileWriter::getWrittenBytes).orElse(0L); + } + + @Override + public long getMemoryUsage() + { + return (deleteFileWriter.map(FileWriter::getMemoryUsage).orElse(0L)) + + (insertFileWriter.map(FileWriter::getMemoryUsage).orElse(0L)); + } + + @Override + public void commit() + { + deleteFileWriter.ifPresent(FileWriter::commit); + insertFileWriter.ifPresent(FileWriter::commit); + } + + @Override + public void rollback() + { + // Make sure both writers get rolled back + try { + deleteFileWriter.ifPresent(FileWriter::rollback); + } + finally { + insertFileWriter.ifPresent(FileWriter::rollback); + } + } + + @Override + public long getValidationCpuNanos() + { + return (deleteFileWriter.map(FileWriter::getValidationCpuNanos).orElse(0L)) + + (insertFileWriter.map(FileWriter::getValidationCpuNanos).orElse(0L)); + } + + public PartitionUpdateAndMergeResults getPartitionUpdateAndMergeResults(PartitionUpdate partitionUpdate) + { + return new PartitionUpdateAndMergeResults( + partitionUpdate, + insertRowCount, + insertFileWriter.isPresent() ? Optional.of(deltaDirectory.toString()) : Optional.empty(), + deleteRowCount, + deleteFileWriter.isPresent() ? 
Optional.of(deleteDeltaDirectory.toString()) : Optional.empty()); + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/PartitionAndStatementId.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/PartitionAndStatementId.java index dcfda26b8a7c..8a9e332ad760 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/PartitionAndStatementId.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/PartitionAndStatementId.java @@ -32,7 +32,7 @@ public class PartitionAndStatementId private final String partitionName; private final int statementId; private final long rowCount; - private final String deleteDeltaDirectory; + private final Optional deleteDeltaDirectory; private final Optional deltaDirectory; @JsonCreator @@ -40,7 +40,7 @@ public PartitionAndStatementId( @JsonProperty("partitionName") String partitionName, @JsonProperty("statementId") int statementId, @JsonProperty("rowCount") long rowCount, - @JsonProperty("deleteDeltaDirectory") String deleteDeltaDirectory, + @JsonProperty("deleteDeltaDirectory") Optional deleteDeltaDirectory, @JsonProperty("deltaDirectory") Optional deltaDirectory) { this.partitionName = requireNonNull(partitionName, "partitionName is null"); @@ -69,7 +69,7 @@ public long getRowCount() } @JsonProperty - public String getDeleteDeltaDirectory() + public Optional getDeleteDeltaDirectory() { return deleteDeltaDirectory; } @@ -83,9 +83,10 @@ public Optional getDeltaDirectory() @JsonIgnore public List getAllDirectories() { - return deltaDirectory - .map(directory -> ImmutableList.of(deleteDeltaDirectory, directory)) - .orElseGet(() -> ImmutableList.of(deleteDeltaDirectory)); + ImmutableList.Builder builder = ImmutableList.builder(); + deltaDirectory.ifPresent(builder::add); + deleteDeltaDirectory.ifPresent(builder::add); + return builder.build(); } @Override diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/PartitionUpdateAndMergeResults.java 
b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/PartitionUpdateAndMergeResults.java new file mode 100644 index 000000000000..f1e91c960829 --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/PartitionUpdateAndMergeResults.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.airlift.json.JsonCodec; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class PartitionUpdateAndMergeResults +{ + public static final JsonCodec CODEC = JsonCodec.jsonCodec(PartitionUpdateAndMergeResults.class); + + private final PartitionUpdate partitionUpdate; + private final long insertRowCount; + private final Optional deltaDirectory; + private final long deleteRowCount; + private final Optional deleteDeltaDirectory; + + @JsonCreator + public PartitionUpdateAndMergeResults( + @JsonProperty("partitionUpdate") PartitionUpdate partitionUpdate, + @JsonProperty("insertRowCount") long insertRowCount, + @JsonProperty("deleteDirectory") Optional deltaDirectory, + @JsonProperty("deleteRowCount") long deleteRowCount, + @JsonProperty("deleteDeltaDirectory") Optional deleteDeltaDirectory) + { + this.partitionUpdate = requireNonNull(partitionUpdate, "partitionUpdate is null"); + this.insertRowCount = insertRowCount; + this.deltaDirectory = 
requireNonNull(deltaDirectory, "deltaDirectory is null"); + this.deleteRowCount = deleteRowCount; + this.deleteDeltaDirectory = requireNonNull(deleteDeltaDirectory, "deleteDeltaDirectory is null"); + } + + @JsonProperty + public PartitionUpdate getPartitionUpdate() + { + return partitionUpdate; + } + + @JsonProperty + public long getInsertRowCount() + { + return insertRowCount; + } + + @JsonProperty + public Optional getDeltaDirectory() + { + return deltaDirectory; + } + + @JsonProperty + public long getDeleteRowCount() + { + return deleteRowCount; + } + + @JsonProperty + public Optional getDeleteDeltaDirectory() + { + return deleteDeltaDirectory; + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/acid/AcidOperation.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/acid/AcidOperation.java index 5e00160681e3..8f8aaacdec11 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/acid/AcidOperation.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/acid/AcidOperation.java @@ -22,16 +22,18 @@ public enum AcidOperation { - // UPDATE and MERGE will be added when they are implemented NONE, CREATE_TABLE, DELETE, INSERT, - UPDATE; + UPDATE, + MERGE, + /**/; private static final Map DATA_OPERATION_TYPES = ImmutableMap.of( DELETE, DataOperationType.DELETE, - INSERT, DataOperationType.INSERT); + INSERT, DataOperationType.INSERT, + MERGE, DataOperationType.UPDATE); private static final Map ORC_OPERATIONS = ImmutableMap.of( DELETE, OrcOperation.DELETE, diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/acid/AcidTransaction.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/acid/AcidTransaction.java index cdcfd17c1bd2..9c73a7aa5264 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/acid/AcidTransaction.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/acid/AcidTransaction.java @@ -27,6 +27,7 @@ import static io.trino.plugin.hive.acid.AcidOperation.CREATE_TABLE; import static 
io.trino.plugin.hive.acid.AcidOperation.DELETE; import static io.trino.plugin.hive.acid.AcidOperation.INSERT; +import static io.trino.plugin.hive.acid.AcidOperation.MERGE; import static io.trino.plugin.hive.acid.AcidOperation.NONE; import static io.trino.plugin.hive.acid.AcidOperation.UPDATE; import static java.util.Objects.requireNonNull; @@ -50,10 +51,10 @@ public AcidTransaction( this.operation = requireNonNull(operation, "operation is null"); this.transactionId = transactionId; this.writeId = writeId; - this.updateProcessor = updateProcessor; + this.updateProcessor = requireNonNull(updateProcessor, "updateProcessor is null"); } - @JsonProperty("operation") + @JsonProperty public AcidOperation getOperation() { return operation; @@ -80,7 +81,7 @@ public Optional getUpdateProcessor() @JsonIgnore public boolean isAcidTransactionRunning() { - return operation == INSERT || operation == DELETE || operation == UPDATE; + return operation == INSERT || operation == CREATE_TABLE || operation == DELETE || operation == UPDATE || operation == MERGE; } @JsonIgnore @@ -133,6 +134,12 @@ public boolean isUpdate() return operation == UPDATE; } + @JsonIgnore + public boolean isMerge() + { + return operation == MERGE; + } + public boolean isAcidInsertOperation(WriterKind writerKind) { return isInsert() || (isUpdate() && writerKind == WriterKind.INSERT); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/SemiTransactionalHiveMetastore.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/SemiTransactionalHiveMetastore.java index cbed4dfc2799..b7a085d4804d 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/SemiTransactionalHiveMetastore.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/SemiTransactionalHiveMetastore.java @@ -34,6 +34,7 @@ import io.trino.plugin.hive.PartitionAndStatementId; import io.trino.plugin.hive.PartitionNotFoundException; import io.trino.plugin.hive.PartitionStatistics; 
+import io.trino.plugin.hive.PartitionUpdateAndMergeResults; import io.trino.plugin.hive.TableAlreadyExistsException; import io.trino.plugin.hive.TableInvalidationCallback; import io.trino.plugin.hive.acid.AcidOperation; @@ -107,6 +108,7 @@ import static io.trino.plugin.hive.acid.AcidTransaction.NO_ACID_TRANSACTION; import static io.trino.plugin.hive.metastore.HivePrivilegeInfo.HivePrivilege.OWNERSHIP; import static io.trino.plugin.hive.metastore.MetastoreUtil.buildInitialPrivilegeSet; +import static io.trino.plugin.hive.metastore.PrincipalPrivileges.NO_PRIVILEGES; import static io.trino.plugin.hive.metastore.thrift.ThriftMetastoreUtil.NUM_ROWS; import static io.trino.plugin.hive.util.HiveUtil.toPartitionValues; import static io.trino.plugin.hive.util.HiveWriteUtils.checkedDelete; @@ -140,6 +142,10 @@ public class SemiTransactionalHiveMetastore .maxAttempts(3) .exponentialBackoff(new Duration(1, SECONDS), new Duration(1, SECONDS), new Duration(10, SECONDS), 2.0); + private static final Map ACID_OPERATION_ACTION_TYPES = ImmutableMap.of( + AcidOperation.INSERT, ActionType.INSERT_EXISTING, + AcidOperation.MERGE, ActionType.MERGE); + private final HiveMetastoreClosure delegate; private final HdfsEnvironment hdfsEnvironment; private final Executor renameExecutor; @@ -246,6 +252,7 @@ public synchronized Optional getTable(String databaseName, String tableNa case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: return Optional.of(tableAction.getData().getTable()); case DROP: return Optional.empty(); @@ -269,6 +276,7 @@ public synchronized boolean isReadableWithinTransaction(String databaseName, Str case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: // Until transaction is committed, the table data may or may not be visible. 
return false; case DROP: @@ -296,6 +304,7 @@ public synchronized PartitionStatistics getTableStatistics(String databaseName, case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: return tableAction.getData().getStatistics(); case DROP: return PartitionStatistics.empty(); @@ -369,6 +378,7 @@ private TableSource getTableSource(String databaseName, String tableName) case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: return TableSource.PRE_EXISTING_TABLE; case DROP_PRESERVE_DATA: // TODO @@ -539,6 +549,7 @@ public synchronized void createTable( case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: throw new TableAlreadyExistsException(table.getSchemaTableName()); case DROP_PRESERVE_DATA: // TODO @@ -567,6 +578,7 @@ public synchronized void dropTable(ConnectorSession session, String databaseName case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: throw new UnsupportedOperationException("dropping a table added/modified in the same transaction is not supported"); case DROP_PRESERVE_DATA: // TODO @@ -624,7 +636,8 @@ public synchronized void dropColumn(String databaseName, String tableName, Strin setExclusive((delegate, hdfsEnvironment) -> delegate.dropColumn(databaseName, tableName, columnName)); } - public synchronized void finishInsertIntoExistingTable( + public synchronized void finishChangingExistingTable( + AcidOperation acidOperation, ConnectorSession session, String databaseName, String tableName, @@ -637,6 +650,7 @@ public synchronized void finishInsertIntoExistingTable( // Therefore, this method assumes that the table is unpartitioned. 
setShared(); SchemaTableName schemaTableName = new SchemaTableName(databaseName, tableName); + ActionType actionType = requireNonNull(ACID_OPERATION_ACTION_TYPES.get(acidOperation), "ACID_OPERATION_ACTION_TYPES doesn't contain the acidOperation"); Action oldTableAction = tableActions.get(schemaTableName); if (oldTableAction == null) { Table table = getExistingTable(schemaTableName.getSchemaName(), schemaTableName.getTableName()); @@ -648,7 +662,7 @@ public synchronized void finishInsertIntoExistingTable( tableActions.put( schemaTableName, new Action<>( - ActionType.INSERT_EXISTING, + actionType, new TableAndMore( table, Optional.empty(), @@ -671,6 +685,7 @@ public synchronized void finishInsertIntoExistingTable( case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: throw new UnsupportedOperationException("Inserting into an unpartitioned table that were added, altered, or inserted into in the same transaction is not supported"); case DROP_PRESERVE_DATA: // TODO @@ -751,6 +766,7 @@ public synchronized void finishRowLevelDelete( case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: throw new UnsupportedOperationException("Inserting or deleting in an unpartitioned table that were added, altered, or inserted into in the same transaction is not supported"); case DROP_PRESERVE_DATA: // TODO @@ -798,6 +814,60 @@ public synchronized void finishUpdate( case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: + throw new UnsupportedOperationException("Inserting, updating or deleting in a table that was added, altered, inserted into, updated or deleted from in the same transaction is not supported"); + default: + throw new IllegalStateException("Unknown action type"); + } + } + + public synchronized void finishMerge( + ConnectorSession session, + String databaseName, + String tableName, + Path currentLocation, + List partitionUpdateAndMergeResults, + List partitions) + { + if (partitionUpdateAndMergeResults.isEmpty()) { + return; + } + 
checkArgument(partitionUpdateAndMergeResults.size() >= partitions.size(), "partitionUpdateAndMergeResults.size() (%s) < partitions.size() (%s)", partitionUpdateAndMergeResults.size(), partitions.size()); + setShared(); + if (partitions.isEmpty()) { + return; + } + SchemaTableName schemaTableName = new SchemaTableName(databaseName, tableName); + Action oldTableAction = tableActions.get(schemaTableName); + if (oldTableAction == null) { + Table table = getExistingTable(schemaTableName.getSchemaName(), schemaTableName.getTableName()); + HdfsContext hdfsContext = new HdfsContext(session); + PrincipalPrivileges principalPrivileges = table.getOwner().isEmpty() ? NO_PRIVILEGES : + buildInitialPrivilegeSet(table.getOwner().get()); + tableActions.put( + schemaTableName, + new Action<>( + ActionType.MERGE, + new TableAndMergeResults( + table, + Optional.of(principalPrivileges), + Optional.of(currentLocation), + partitionUpdateAndMergeResults, + partitions), + hdfsContext, + session.getQueryId())); + return; + } + + switch (oldTableAction.getType()) { + case DROP: + throw new TableNotFoundException(schemaTableName); + case ADD: + case ALTER: + case INSERT_EXISTING: + case DELETE_ROWS: + case UPDATE: + case MERGE: throw new UnsupportedOperationException("Inserting, updating or deleting in a table that was added, altered, inserted into, updated or deleted from in the same transaction is not supported"); case DROP_PRESERVE_DATA: // TODO @@ -886,6 +956,7 @@ private Optional> doGetPartitionNames( case INSERT_EXISTING: case DELETE_ROWS: case UPDATE: + case MERGE: resultBuilder.add(partitionName); break; default: @@ -951,6 +1022,7 @@ private static Optional getPartitionFromPartitionAction(Action listTablePrivileges(String databaseNa case DELETE_ROWS: case UPDATE: return delegate.listTablePrivileges(databaseName, tableName, getExistingTable(databaseName, tableName).getOwner(), principal); + case MERGE: + return delegate.listTablePrivileges(databaseName, tableName, 
getExistingTable(databaseName, tableName).getOwner(), principal); case DROP: throw new TableNotFoundException(schemaTableName); case DROP_PRESERVE_DATA: @@ -1319,6 +1396,11 @@ public AcidTransaction beginUpdate(ConnectorSession session, Table table, HiveUp return beginOperation(session, table, AcidOperation.UPDATE, DataOperationType.UPDATE, Optional.of(updateProcessor)); } + public AcidTransaction beginMerge(ConnectorSession session, Table table) + { + return beginOperation(session, table, AcidOperation.MERGE, DataOperationType.UPDATE, Optional.empty()); + } + private AcidTransaction beginOperation(ConnectorSession session, Table table, AcidOperation operation, DataOperationType hiveOperation, Optional updateProcessor) { String queryId = session.getQueryId(); @@ -1477,6 +1559,9 @@ private void commitShared() case UPDATE: committer.prepareUpdateExistingTable(action.getHdfsContext(), action.getData()); break; + case MERGE: + committer.prepareMergeExistingTable(action.getHdfsContext(), action.getData()); + break; default: throw new IllegalStateException("Unknown action type: " + action.getType()); } @@ -1502,6 +1587,9 @@ private void commitShared() case INSERT_EXISTING: committer.prepareInsertExistingPartition(action.getHdfsContext(), action.getQueryId(), action.getData()); break; + case MERGE: + committer.prepareInsertExistingPartition(action.getHdfsContext(), action.getQueryId(), action.getData()); + break; case UPDATE: case DELETE_ROWS: break; @@ -1790,9 +1878,12 @@ private void prepareDeleteRowsFromExistingTable(HdfsContext context, TableAndMor AcidTransaction transaction = currentHiveTransaction.get().getTransaction(); checkArgument(transaction.isDelete(), "transaction should be delete, but is %s", transaction); - cleanUpTasksForAbort.addAll(deletionState.getPartitionAndStatementIds().stream() - .map(ps -> new DirectoryCleanUpTask(context, new Path(ps.getDeleteDeltaDirectory()), true)) - .collect(toImmutableList())); + 
deletionState.getPartitionAndStatementIds().stream().forEach(ps -> { + ps.getDeleteDeltaDirectory().ifPresent(dir -> + cleanUpTasksForAbort.add(new DirectoryCleanUpTask(context, new Path(dir), true))); + ps.getDeltaDirectory().ifPresent(dir -> + cleanUpTasksForAbort.add(new DirectoryCleanUpTask(context, new Path(dir), true))); + }); Map partitionRowCounts = new HashMap<>(partitionAndStatementIds.size()); int totalRowsDeleted = 0; @@ -1888,6 +1979,29 @@ private void prepareUpdateExistingTable(HdfsContext context, TableAndMore tableA updateTableWriteId(databaseName, tableName, transactionId, writeId, OptionalLong.empty()); } + private void prepareMergeExistingTable(HdfsContext context, TableAndMore tableAndMore) + { + checkArgument(currentHiveTransaction.isPresent(), "currentHiveTransaction isn't present"); + AcidTransaction transaction = currentHiveTransaction.get().getTransaction(); + checkArgument(transaction.isMerge(), "transaction should be merge, but is %s", transaction); + + deleteOnly = false; + Table table = tableAndMore.getTable(); + Path targetPath = new Path(table.getStorage().getLocation()); + Path currentPath = tableAndMore.getCurrentLocation().get(); + cleanUpTasksForAbort.add(new DirectoryCleanUpTask(context, targetPath, false)); + if (!targetPath.equals(currentPath)) { + asyncRename(hdfsEnvironment, renameExecutor, fileRenameCancelled, fileRenameFutures, context, currentPath, targetPath, tableAndMore.getFileNames().get()); + } + updateStatisticsOperations.add(new UpdateStatisticsOperation( + table.getSchemaTableName(), + Optional.empty(), + tableAndMore.getStatisticsUpdate(), + true)); + + updateTableWriteId(table.getDatabaseName(), table.getTableName(), transaction.getAcidTransactionId(), transaction.getWriteId(), OptionalLong.empty()); + } + private void prepareDropPartition(SchemaTableName schemaTableName, List partitionValues, boolean deleteData) { metastoreDeleteOperations.add(new IrreversibleMetastoreOperation( @@ -2783,6 +2897,7 @@ private 
enum ActionType INSERT_EXISTING, DELETE_ROWS, UPDATE, + MERGE, } private enum TableSource @@ -2972,6 +3087,42 @@ public String toString() } } + private static class TableAndMergeResults + extends TableAndMore + { + private final List partitionMergeResults; + private final List partitions; + + public TableAndMergeResults(Table table, Optional principalPrivileges, Optional currentLocation, List partitionMergeResults, List partitions) + { + super(table, principalPrivileges, currentLocation, Optional.empty(), false, PartitionStatistics.empty(), PartitionStatistics.empty(), false); // retries are not supported for transactional tables + this.partitionMergeResults = requireNonNull(partitionMergeResults, "partitionMergeResults is null"); + this.partitions = requireNonNull(partitions, "partitions is nul"); + } + + public List getPartitionMergeResults() + { + return partitionMergeResults; + } + + public List getPartitions() + { + return partitions; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("table", getTable()) + .add("partitionMergeResults", partitionMergeResults) + .add("partitions", partitions) + .add("principalPrivileges", getPrincipalPrivileges()) + .add("currentLocation", getCurrentLocation()) + .toString(); + } + } + private static class PartitionAndMore { private final Partition partition; diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcFileWriter.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcFileWriter.java index 416f85e4ec92..c2350ab9f84b 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcFileWriter.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcFileWriter.java @@ -15,6 +15,7 @@ import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; +import io.airlift.log.Logger; import io.trino.orc.OrcDataSink; import io.trino.orc.OrcDataSource; import io.trino.orc.OrcWriteValidation.OrcWriteValidationMode; 
@@ -63,6 +64,7 @@ public class OrcFileWriter implements FileWriter { + private static final Logger log = Logger.get(OrcFileWriter.class); private static final int INSTANCE_SIZE = ClassLayout.parseClass(OrcFileWriter.class).instanceSize(); private static final ThreadMXBean THREAD_MX_BEAN = ManagementFactory.getThreadMXBean(); @@ -127,6 +129,9 @@ public OrcFileWriter( validationInputFactory.isPresent(), validationMode, stats); + if (transaction.isTransactional()) { + this.setMaxWriteId(transaction.getWriteId()); + } } @Override @@ -188,6 +193,7 @@ public void commit() } catch (Exception ignored) { // ignore + log.error(ignored, "Exception when committing file"); } throw new TrinoException(HIVE_WRITER_CLOSE_ERROR, "Error committing write to Hive", e); } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcFileWriterFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcFileWriterFactory.java index 4f907960acf6..b597a5031f90 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcFileWriterFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcFileWriterFactory.java @@ -27,6 +27,7 @@ import io.trino.plugin.hive.FileWriter; import io.trino.plugin.hive.HdfsEnvironment; import io.trino.plugin.hive.HiveFileWriterFactory; +import io.trino.plugin.hive.HiveType; import io.trino.plugin.hive.NodeVersion; import io.trino.plugin.hive.WriterKind; import io.trino.plugin.hive.acid.AcidTransaction; @@ -66,6 +67,7 @@ import static io.trino.plugin.hive.HiveSessionProperties.getOrcStringStatisticsLimit; import static io.trino.plugin.hive.HiveSessionProperties.getTimestampPrecision; import static io.trino.plugin.hive.HiveSessionProperties.isOrcOptimizedWriterValidate; +import static io.trino.plugin.hive.HiveType.toHiveType; import static io.trino.plugin.hive.acid.AcidSchema.ACID_COLUMN_NAMES; import static io.trino.plugin.hive.acid.AcidSchema.createAcidColumnPrestoTypes; import static 
io.trino.plugin.hive.acid.AcidSchema.createRowType; @@ -224,6 +226,17 @@ public Optional createFileWriter( } } + public static HiveType createHiveRowType(Properties schema, TypeManager typeManager, ConnectorSession session) + { + List dataColumnNames = getColumnNames(schema); + List dataColumnTypes = getColumnTypes(schema).stream() + .map(hiveType -> hiveType.getType(typeManager, getTimestampPrecision(session))) + .collect(toList()); + Type dataRowType = createRowType(dataColumnNames, dataColumnTypes); + Type acidRowType = createRowType(ACID_COLUMN_NAMES, createAcidColumnPrestoTypes(dataRowType)); + return toHiveType(acidRowType); + } + public static OrcDataSink createOrcDataSink(FileSystem fileSystem, Path path) throws IOException { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSource.java index 53dc70039c08..803712a6bcb7 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSource.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSource.java @@ -302,6 +302,16 @@ static ColumnAdaptation positionColumn() { return new PositionAdaptation(); } + + static ColumnAdaptation mergedRowColumns() + { + return new MergedRowAdaptation(); + } + + static ColumnAdaptation mergedRowColumnsWithOriginalFiles(long startingRowId, int bucketId) + { + return new MergedRowAdaptationWithOriginalFiles(startingRowId, bucketId); + } } private static class NullColumn @@ -463,11 +473,64 @@ public Block block(Page sourcePage, MaskDeletedRowsFunction maskDeletedRowsFunct for (int channel = 0; channel < sourcePage.getChannelCount(); channel++) { originalFilesBlockBuilder.add(sourcePage.getBlock(channel)); } - Page page = new Page(originalFilesBlockBuilder.build().toArray(new Block[]{})); + Page page = new Page(originalFilesBlockBuilder.build().toArray(new Block[] {})); return updateProcessor.createUpdateRowBlock(page, 
nonUpdatedSourceChannels, maskDeletedRowsFunction); } } + /* + * The rowId contains the ACID columns - - originalTransaction, rowId, bucket + */ + private static final class MergedRowAdaptation + implements ColumnAdaptation + { + @Override + public Block block(Page page, MaskDeletedRowsFunction maskDeletedRowsFunction, long filePosition, OptionalLong startRowId) + { + requireNonNull(page, "page is null"); + return maskDeletedRowsFunction.apply(fromFieldBlocks( + page.getPositionCount(), + Optional.empty(), + new Block[] { + page.getBlock(ORIGINAL_TRANSACTION_CHANNEL), + page.getBlock(BUCKET_CHANNEL), + page.getBlock(ROW_ID_CHANNEL) + })); + } + } + + /** + * The rowId contains the ACID columns - - originalTransaction, rowId, bucket, + * derived from the original file. The transactionId is always zero, + * and the rowIds count up from the startingRowId. + */ + private static final class MergedRowAdaptationWithOriginalFiles + implements ColumnAdaptation + { + private final long startingRowId; + private final Block bucketBlock; + + public MergedRowAdaptationWithOriginalFiles(long startingRowId, int bucketId) + { + this.startingRowId = startingRowId; + this.bucketBlock = nativeValueToBlock(INTEGER, Long.valueOf(computeBucketValue(bucketId, 0))); + } + + @Override + public Block block(Page sourcePage, MaskDeletedRowsFunction maskDeletedRowsFunction, long filePosition, OptionalLong startRowId) + { + int positionCount = sourcePage.getPositionCount(); + return maskDeletedRowsFunction.apply(fromFieldBlocks( + positionCount, + Optional.empty(), + new Block[] { + new RunLengthEncodedBlock(ORIGINAL_FILE_TRANSACTION_ID_BLOCK, positionCount), + new RunLengthEncodedBlock(bucketBlock, positionCount), + createRowNumberBlock(startingRowId, filePosition, positionCount) + })); + } + } + private static class OriginalFileRowIdAdaptation implements ColumnAdaptation { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSourceFactory.java 
b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSourceFactory.java index 3043192ccf0e..34ffd6a45896 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSourceFactory.java @@ -99,6 +99,7 @@ import static io.trino.plugin.hive.HiveSessionProperties.isOrcBloomFiltersEnabled; import static io.trino.plugin.hive.HiveSessionProperties.isOrcNestedLazy; import static io.trino.plugin.hive.HiveSessionProperties.isUseOrcColumnNames; +import static io.trino.plugin.hive.orc.OrcPageSource.ColumnAdaptation.mergedRowColumns; import static io.trino.plugin.hive.orc.OrcPageSource.ColumnAdaptation.updatedRowColumns; import static io.trino.plugin.hive.orc.OrcPageSource.ColumnAdaptation.updatedRowColumnsWithOriginalFiles; import static io.trino.plugin.hive.orc.OrcPageSource.handleException; @@ -440,6 +441,16 @@ else if (transaction.isUpdate()) { columnAdaptations.add(updatedRowColumns(updateProcessor, dependencyColumns)); } } + else if (transaction.isMerge()) { + if (originalFile) { + int bucket = bucketNumber.orElse(0); + long startingRowId = originalFileRowId.orElse(0L); + columnAdaptations.add(OrcPageSource.ColumnAdaptation.mergedRowColumnsWithOriginalFiles(startingRowId, bucket)); + } + else { + columnAdaptations.add(mergedRowColumns()); + } + } return new OrcPageSource( recordReader, diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHive.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHive.java index 225c40721274..fb20d951485f 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHive.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/AbstractTestHive.java @@ -5626,7 +5626,7 @@ protected void insertBucketedTableLayout(boolean transactional) false); assertEquals(insertLayout.get().getPartitioning(), Optional.of(partitioningHandle)); 
assertEquals(insertLayout.get().getPartitionColumns(), ImmutableList.of("column1")); - ConnectorBucketNodeMap connectorBucketNodeMap = nodePartitioningProvider.getBucketNodeMap(transaction.getTransactionHandle(), session, partitioningHandle); + ConnectorBucketNodeMap connectorBucketNodeMap = nodePartitioningProvider.getBucketNodeMapping(transaction.getTransactionHandle(), session, partitioningHandle).orElseThrow(); assertEquals(connectorBucketNodeMap.getBucketCount(), 4); assertFalse(connectorBucketNodeMap.hasFixedMapping()); } @@ -5676,7 +5676,7 @@ protected void insertPartitionedBucketedTableLayout(boolean transactional) true); assertEquals(insertLayout.get().getPartitioning(), Optional.of(partitioningHandle)); assertEquals(insertLayout.get().getPartitionColumns(), ImmutableList.of("column1", "column2")); - ConnectorBucketNodeMap connectorBucketNodeMap = nodePartitioningProvider.getBucketNodeMap(transaction.getTransactionHandle(), session, partitioningHandle); + ConnectorBucketNodeMap connectorBucketNodeMap = nodePartitioningProvider.getBucketNodeMapping(transaction.getTransactionHandle(), session, partitioningHandle).orElseThrow(); assertEquals(connectorBucketNodeMap.getBucketCount(), 32); assertFalse(connectorBucketNodeMap.hasFixedMapping()); } @@ -5737,7 +5737,7 @@ public void testCreateBucketedTableLayout() false); assertEquals(newTableLayout.get().getPartitioning(), Optional.of(partitioningHandle)); assertEquals(newTableLayout.get().getPartitionColumns(), ImmutableList.of("column1")); - ConnectorBucketNodeMap connectorBucketNodeMap = nodePartitioningProvider.getBucketNodeMap(transaction.getTransactionHandle(), session, partitioningHandle); + ConnectorBucketNodeMap connectorBucketNodeMap = nodePartitioningProvider.getBucketNodeMapping(transaction.getTransactionHandle(), session, partitioningHandle).orElseThrow(); assertEquals(connectorBucketNodeMap.getBucketCount(), 10); assertFalse(connectorBucketNodeMap.hasFixedMapping()); } @@ -5770,7 +5770,7 @@ public 
void testCreatePartitionedBucketedTableLayout() true); assertEquals(newTableLayout.get().getPartitioning(), Optional.of(partitioningHandle)); assertEquals(newTableLayout.get().getPartitionColumns(), ImmutableList.of("column1", "column2")); - ConnectorBucketNodeMap connectorBucketNodeMap = nodePartitioningProvider.getBucketNodeMap(transaction.getTransactionHandle(), session, partitioningHandle); + ConnectorBucketNodeMap connectorBucketNodeMap = nodePartitioningProvider.getBucketNodeMapping(transaction.getTransactionHandle(), session, partitioningHandle).orElseThrow(); assertEquals(connectorBucketNodeMap.getBucketCount(), 32); assertFalse(connectorBucketNodeMap.hasFixedMapping()); } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java index dcf6fe091ef3..e059df6d5d20 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java @@ -223,6 +223,10 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) case SUPPORTS_UPDATE: return true; + case SUPPORTS_MERGE: + // FIXME: Fails because only allowed with transactional tables + return false; + case SUPPORTS_MULTI_STATEMENT_WRITES: return true; @@ -234,6 +238,12 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) } } + @Override + protected String createTableForWrites(String createTable) + { + return createTable + " WITH (transactional = true)"; + } + @Override protected void verifySelectAfterInsertFailurePermissible(Throwable e) { diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorSmokeTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorSmokeTest.java index ebac07726f6c..f3f1f8de12ed 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorSmokeTest.java 
+++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorSmokeTest.java @@ -51,6 +51,9 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) case SUPPORTS_UPDATE: return true; + case SUPPORTS_MERGE: + return true; + case SUPPORTS_MULTI_STATEMENT_WRITES: return true; @@ -73,6 +76,13 @@ public void testUpdate() .hasMessage("Hive update is only supported for ACID transactional tables"); } + @Override + public void testMerge() + { + assertThatThrownBy(super::testMerge) + .hasMessage("Hive merge is only supported for transactional tables"); + } + @Test @Override public void testShowCreateTable() diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/TestSemiTransactionalHiveMetastore.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/TestSemiTransactionalHiveMetastore.java index fb1edd768ac6..df0a3f1b1a3b 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/TestSemiTransactionalHiveMetastore.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/TestSemiTransactionalHiveMetastore.java @@ -36,6 +36,7 @@ import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.trino.plugin.hive.HiveBasicStatistics.createEmptyStatistics; import static io.trino.plugin.hive.HiveTestUtils.HDFS_ENVIRONMENT; +import static io.trino.plugin.hive.acid.AcidOperation.INSERT; import static io.trino.plugin.hive.util.HiveBucketing.BucketingVersion.BUCKETING_V1; import static io.trino.testing.TestingConnectorSession.SESSION; import static java.util.concurrent.Executors.newFixedThreadPool; @@ -105,7 +106,7 @@ public void testParallelUpdateStatisticsOperations() else { semiTransactionalHiveMetastore = getSemiTransactionalHiveMetastoreWithUpdateExecutor(newFixedThreadPool(updateThreads)); } - IntStream.range(0, tablesToUpdate).forEach(i -> semiTransactionalHiveMetastore.finishInsertIntoExistingTable(SESSION, + IntStream.range(0, 
tablesToUpdate).forEach(i -> semiTransactionalHiveMetastore.finishChangingExistingTable(INSERT, SESSION, "database", "table_" + i, new Path("location"), diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergColumnHandle.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergColumnHandle.java index 55801fa8aa49..07b803c8c7cb 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergColumnHandle.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergColumnHandle.java @@ -36,8 +36,13 @@ public class IcebergColumnHandle implements ColumnHandle { // Iceberg reserved row ids begin at INTEGER.MAX_VALUE and count down. Starting with MIN_VALUE here to avoid conflicts. - public static final int TRINO_UPDATE_ROW_ID_COLUMN_ID = Integer.MIN_VALUE; - public static final String TRINO_UPDATE_ROW_ID_COLUMN_NAME = "$row_id"; + public static final int TRINO_UPDATE_ROW_ID = Integer.MIN_VALUE; + public static final int TRINO_MERGE_ROW_ID = Integer.MIN_VALUE + 1; + public static final String TRINO_ROW_ID_NAME = "$row_id"; + + public static final int TRINO_MERGE_FILE_RECORD_COUNT = Integer.MIN_VALUE + 2; + public static final int TRINO_MERGE_PARTITION_SPEC_ID = Integer.MIN_VALUE + 3; + public static final int TRINO_MERGE_PARTITION_DATA = Integer.MIN_VALUE + 4; private final ColumnIdentity baseColumnIdentity; private final Type baseType; @@ -157,7 +162,13 @@ public boolean isRowPositionColumn() @JsonIgnore public boolean isUpdateRowIdColumn() { - return id == TRINO_UPDATE_ROW_ID_COLUMN_ID; + return id == TRINO_UPDATE_ROW_ID; + } + + @JsonIgnore + public boolean isMergeRowIdColumn() + { + return id == TRINO_MERGE_ROW_ID; } /** diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMergeSink.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMergeSink.java new file mode 100644 index 000000000000..bad2040ce338 --- /dev/null +++ 
b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMergeSink.java @@ -0,0 +1,243 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.iceberg; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.json.JsonCodec; +import io.airlift.slice.Slice; +import io.trino.plugin.hive.HdfsEnvironment; +import io.trino.plugin.hive.HdfsEnvironment.HdfsContext; +import io.trino.plugin.iceberg.delete.IcebergPositionDeletePageSink; +import io.trino.spi.Page; +import io.trino.spi.PageBuilder; +import io.trino.spi.block.ColumnarRow; +import io.trino.spi.connector.ConnectorMergeSink; +import io.trino.spi.connector.ConnectorPageSink; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.MergePage; +import io.trino.spi.type.VarcharType; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; +import org.apache.iceberg.io.LocationProvider; +import org.apache.iceberg.types.Type; +import org.roaringbitmap.longlong.ImmutableLongBitmapDataProvider; +import org.roaringbitmap.longlong.LongBitmapDataProvider; +import org.roaringbitmap.longlong.Roaring64Bitmap; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import static 
io.trino.plugin.base.util.Closables.closeAllSuppress; +import static io.trino.spi.block.ColumnarRow.toColumnarRow; +import static io.trino.spi.connector.MergePage.createDeleteAndInsertPages; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.IntegerType.INTEGER; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.CompletableFuture.completedFuture; + +public class IcebergMergeSink + implements ConnectorMergeSink +{ + private final LocationProvider locationProvider; + private final IcebergFileWriterFactory fileWriterFactory; + private final HdfsEnvironment hdfsEnvironment; + private final FileIoProvider fileIoProvider; + private final JsonCodec jsonCodec; + private final ConnectorSession session; + private final IcebergFileFormat fileFormat; + private final Map storageProperties; + private final Schema schema; + private final Map partitionsSpecs; + private final ConnectorPageSink insertPageSink; + private final int columnCount; + private final Map fileDeletions = new HashMap<>(); + + public IcebergMergeSink( + LocationProvider locationProvider, + IcebergFileWriterFactory fileWriterFactory, + HdfsEnvironment hdfsEnvironment, + FileIoProvider fileIoProvider, + JsonCodec jsonCodec, + ConnectorSession session, + IcebergFileFormat fileFormat, + Map storageProperties, + Schema schema, + Map partitionsSpecs, + ConnectorPageSink insertPageSink, + int columnCount) + { + this.locationProvider = requireNonNull(locationProvider, "locationProvider is null"); + this.fileWriterFactory = requireNonNull(fileWriterFactory, "fileWriterFactory is null"); + this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); + this.fileIoProvider = requireNonNull(fileIoProvider, "fileIoProvider is null"); + this.jsonCodec = requireNonNull(jsonCodec, "jsonCodec is null"); + this.session = requireNonNull(session, "session is null"); + this.fileFormat = 
requireNonNull(fileFormat, "fileFormat is null"); + this.storageProperties = ImmutableMap.copyOf(requireNonNull(storageProperties, "storageProperties is null")); + this.schema = requireNonNull(schema, "schema is null"); + this.partitionsSpecs = ImmutableMap.copyOf(requireNonNull(partitionsSpecs, "partitionsSpecs is null")); + this.insertPageSink = requireNonNull(insertPageSink, "insertPageSink is null"); + this.columnCount = columnCount; + } + + @Override + public void storeMergedRows(Page page) + { + MergePage mergePage = createDeleteAndInsertPages(page, columnCount); + + mergePage.getInsertionsPage().ifPresent(insertPageSink::appendPage); + + mergePage.getDeletionsPage().ifPresent(deletions -> { + ColumnarRow rowIdRow = toColumnarRow(deletions.getBlock(deletions.getChannelCount() - 1)); + + for (int position = 0; position < rowIdRow.getPositionCount(); position++) { + Slice filePath = VarcharType.VARCHAR.getSlice(rowIdRow.getField(0), position); + long rowPosition = BIGINT.getLong(rowIdRow.getField(1), position); + + int index = position; + FileDeletion deletion = fileDeletions.computeIfAbsent(filePath, ignored -> { + long fileRecordCount = BIGINT.getLong(rowIdRow.getField(2), index); + int partitionSpecId = toIntExact(INTEGER.getLong(rowIdRow.getField(3), index)); + String partitionData = VarcharType.VARCHAR.getSlice(rowIdRow.getField(4), index).toStringUtf8(); + return new FileDeletion(partitionSpecId, partitionData, fileRecordCount); + }); + + deletion.rowsToDelete().addLong(rowPosition); + } + }); + } + + @Override + public CompletableFuture> finish() + { + List fragments = new ArrayList<>(insertPageSink.finish().join()); + + fileDeletions.forEach((dataFilePath, deletion) -> { + ConnectorPageSink sink = createPositionDeletePageSink( + dataFilePath.toStringUtf8(), + partitionsSpecs.get(deletion.partitionSpecId()), + deletion.partitionDataJson(), + deletion.fileRecordCount()); + + fragments.addAll(writePositionDeletes(sink, deletion.rowsToDelete())); + }); + + 
return completedFuture(fragments); + } + + @Override + public void abort() + { + insertPageSink.abort(); + } + + private ConnectorPageSink createPositionDeletePageSink(String dataFilePath, PartitionSpec partitionSpec, String partitionDataJson, long fileRecordCount) + { + Optional partitionData = Optional.empty(); + if (partitionSpec.isPartitioned()) { + Type[] columnTypes = partitionSpec.fields().stream() + .map(field -> field.transform().getResultType(schema.findType(field.sourceId()))) + .toArray(Type[]::new); + partitionData = Optional.of(PartitionData.fromJson(partitionDataJson, columnTypes)); + } + + return new IcebergPositionDeletePageSink( + dataFilePath, + partitionSpec, + partitionData, + locationProvider, + fileWriterFactory, + hdfsEnvironment, + new HdfsContext(session), + fileIoProvider, + jsonCodec, + session, + fileFormat, + storageProperties, + fileRecordCount); + } + + private static Collection writePositionDeletes(ConnectorPageSink sink, ImmutableLongBitmapDataProvider rowsToDelete) + { + try { + return doWritePositionDeletes(sink, rowsToDelete); + } + catch (Throwable t) { + closeAllSuppress(t, sink::abort); + throw t; + } + } + + private static Collection doWritePositionDeletes(ConnectorPageSink sink, ImmutableLongBitmapDataProvider rowsToDelete) + { + PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(BIGINT)); + + rowsToDelete.forEach(rowPosition -> { + BIGINT.writeLong(pageBuilder.getBlockBuilder(0), rowPosition); + pageBuilder.declarePosition(); + if (pageBuilder.isFull()) { + sink.appendPage(pageBuilder.build()); + pageBuilder.reset(); + } + }); + + if (!pageBuilder.isEmpty()) { + sink.appendPage(pageBuilder.build()); + } + + return sink.finish().join(); + } + + private static class FileDeletion + { + private final int partitionSpecId; + private final String partitionDataJson; + private final long fileRecordCount; + private final LongBitmapDataProvider rowsToDelete = new Roaring64Bitmap(); + + public FileDeletion(int partitionSpecId, 
String partitionDataJson, long fileRecordCount) + { + this.partitionSpecId = partitionSpecId; + this.partitionDataJson = requireNonNull(partitionDataJson, "partitionDataJson is null"); + this.fileRecordCount = fileRecordCount; + } + + public int partitionSpecId() + { + return partitionSpecId; + } + + public String partitionDataJson() + { + return partitionDataJson; + } + + public long fileRecordCount() + { + return fileRecordCount; + } + + public LongBitmapDataProvider rowsToDelete() + { + return rowsToDelete; + } + } +} diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMergeTableHandle.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMergeTableHandle.java new file mode 100644 index 000000000000..57767288444e --- /dev/null +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMergeTableHandle.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.iceberg; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.connector.ConnectorMergeTableHandle; + +import static java.util.Objects.requireNonNull; + +public class IcebergMergeTableHandle + implements ConnectorMergeTableHandle +{ + private final IcebergTableHandle tableHandle; + private final IcebergWritableTableHandle insertTableHandle; + + @JsonCreator + public IcebergMergeTableHandle(IcebergTableHandle tableHandle, IcebergWritableTableHandle insertTableHandle) + { + this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); + this.insertTableHandle = requireNonNull(insertTableHandle, "insertTableHandle is null"); + } + + @Override + @JsonProperty + public IcebergTableHandle getTableHandle() + { + return tableHandle; + } + + @JsonProperty + public IcebergWritableTableHandle getInsertTableHandle() + { + return insertTableHandle; + } +} diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java index 813d68117673..2c9fbae6128d 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java @@ -47,9 +47,11 @@ import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.connector.ConnectorInsertTableHandle; import io.trino.spi.connector.ConnectorMaterializedViewDefinition; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorMetadata; import io.trino.spi.connector.ConnectorOutputMetadata; import io.trino.spi.connector.ConnectorOutputTableHandle; +import io.trino.spi.connector.ConnectorPartitioningHandle; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableExecuteHandle; import io.trino.spi.connector.ConnectorTableHandle; @@ -65,6 
+67,7 @@ import io.trino.spi.connector.MaterializedViewNotFoundException; import io.trino.spi.connector.ProjectionApplicationResult; import io.trino.spi.connector.RetryMode; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.connector.SystemTable; @@ -97,6 +100,7 @@ import org.apache.iceberg.FileScanTask; import org.apache.iceberg.IsolationLevel; import org.apache.iceberg.ManifestFile; +import org.apache.iceberg.MetadataColumns; import org.apache.iceberg.PartitionField; import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.PartitionSpecParser; @@ -117,7 +121,10 @@ import org.apache.iceberg.expressions.Term; import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types.IntegerType; +import org.apache.iceberg.types.Types.LongType; import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.types.Types.StringType; import org.apache.iceberg.types.Types.StructType; import java.io.IOException; @@ -150,6 +157,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Maps.transformValues; import static com.google.common.collect.Sets.difference; import static com.google.common.collect.Sets.union; import static com.google.common.collect.Streams.concat; @@ -160,8 +168,12 @@ import static io.trino.plugin.hive.util.HiveUtil.isStructuralType; import static io.trino.plugin.iceberg.ConstraintExtractor.extractTupleDomain; import static io.trino.plugin.iceberg.ExpressionConverter.toIcebergExpression; -import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_UPDATE_ROW_ID_COLUMN_ID; -import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_UPDATE_ROW_ID_COLUMN_NAME; 
+import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_MERGE_FILE_RECORD_COUNT; +import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_MERGE_PARTITION_DATA; +import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_MERGE_PARTITION_SPEC_ID; +import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_MERGE_ROW_ID; +import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_ROW_ID_NAME; +import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_UPDATE_ROW_ID; import static io.trino.plugin.iceberg.IcebergColumnHandle.fileModifiedTimeColumnHandle; import static io.trino.plugin.iceberg.IcebergColumnHandle.fileModifiedTimeColumnMetadata; import static io.trino.plugin.iceberg.IcebergColumnHandle.pathColumnHandle; @@ -205,6 +217,7 @@ import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.connector.RetryMode.NO_RETRIES; +import static io.trino.spi.connector.RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc; import static io.trino.spi.type.UuidType.UUID; @@ -214,7 +227,6 @@ import static java.util.stream.Collectors.groupingBy; import static java.util.stream.Collectors.joining; import static org.apache.iceberg.FileContent.POSITION_DELETES; -import static org.apache.iceberg.MetadataColumns.ROW_POSITION; import static org.apache.iceberg.ReachableFileUtil.metadataFileLocations; import static org.apache.iceberg.ReachableFileUtil.versionHintLocation; import static org.apache.iceberg.SnapshotSummary.DELETED_RECORDS_PROP; @@ -669,15 +681,7 @@ public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, Con "Cannot create a table on a non-empty location: %s, set 'iceberg.unique-table-location=true' in your Iceberg catalog properties " + "to use unique table locations for every table.", location)); } - return new 
IcebergWritableTableHandle( - tableMetadata.getTable(), - SchemaParser.toJson(transaction.table().schema()), - PartitionSpecParser.toJson(transaction.table().spec()), - getColumns(transaction.table().schema(), typeManager), - location, - getFileFormat(transaction.table()), - transaction.table().properties(), - retryMode); + return newWritableTableHandle(tableMetadata.getTable(), transaction.table(), retryMode); } catch (IOException e) { throw new TrinoException(ICEBERG_FILESYSTEM_ERROR, "Failed checking new table's location: " + location, e); @@ -742,14 +746,20 @@ public ConnectorInsertTableHandle beginInsert(ConnectorSession session, Connecto beginTransaction(icebergTable); + return newWritableTableHandle(table.getSchemaTableName(), icebergTable, retryMode); + } + + private IcebergWritableTableHandle newWritableTableHandle(SchemaTableName name, Table table, RetryMode retryMode) + { return new IcebergWritableTableHandle( - table.getSchemaTableName(), - SchemaParser.toJson(icebergTable.schema()), - PartitionSpecParser.toJson(icebergTable.spec()), - getColumns(icebergTable.schema(), typeManager), - icebergTable.location(), - getFileFormat(icebergTable), - icebergTable.properties(), + name, + SchemaParser.toJson(table.schema()), + transformValues(table.specs(), PartitionSpecParser::toJson), + table.spec().specId(), + getColumns(table.schema(), typeManager), + table.location(), + getFileFormat(table), + table.properties(), retryMode); } @@ -1423,7 +1433,7 @@ public void finishDelete(ConnectorSession session, ConnectorTableHandle tableHan @Override public ColumnHandle getDeleteRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle) { - return getColumnHandle(ROW_POSITION, typeManager); + return getColumnHandle(MetadataColumns.ROW_POSITION, typeManager); } @Override @@ -1452,7 +1462,7 @@ public void finishUpdate(ConnectorSession session, ConnectorTableHandle tableHan public ColumnHandle getUpdateRowIdColumnHandle(ConnectorSession session, 
ConnectorTableHandle tableHandle, List updatedColumns) { List unmodifiedColumns = new ArrayList<>(); - unmodifiedColumns.add(ROW_POSITION); + unmodifiedColumns.add(MetadataColumns.ROW_POSITION); // Include all the non-updated columns. These are needed when writing the new data file with updated column values. IcebergTableHandle table = (IcebergTableHandle) tableHandle; @@ -1466,8 +1476,59 @@ public ColumnHandle getUpdateRowIdColumnHandle(ConnectorSession session, Connect } } - NestedField icebergRowIdField = NestedField.required(TRINO_UPDATE_ROW_ID_COLUMN_ID, TRINO_UPDATE_ROW_ID_COLUMN_NAME, StructType.of(unmodifiedColumns)); - return getColumnHandle(icebergRowIdField, typeManager); + NestedField rowIdField = NestedField.required(TRINO_UPDATE_ROW_ID, TRINO_ROW_ID_NAME, StructType.of(unmodifiedColumns)); + return getColumnHandle(rowIdField, typeManager); + } + + @Override + public RowChangeParadigm getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return DELETE_ROW_AND_INSERT_ROW; + } + + @Override + public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle) + { + StructType type = StructType.of(ImmutableList.builder() + .add(MetadataColumns.FILE_PATH) + .add(MetadataColumns.ROW_POSITION) + .add(NestedField.required(TRINO_MERGE_FILE_RECORD_COUNT, "file_record_count", LongType.get())) + .add(NestedField.required(TRINO_MERGE_PARTITION_SPEC_ID, "partition_spec_id", IntegerType.get())) + .add(NestedField.required(TRINO_MERGE_PARTITION_DATA, "partition_data", StringType.get())) + .build()); + + NestedField field = NestedField.required(TRINO_MERGE_ROW_ID, TRINO_ROW_ID_NAME, type); + return getColumnHandle(field, typeManager); + } + + @Override + public Optional getUpdateLayout(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return Optional.of(IcebergUpdateHandle.INSTANCE); + } + + @Override + public ConnectorMergeTableHandle beginMerge(ConnectorSession session, 
ConnectorTableHandle tableHandle, RetryMode retryMode) + { + IcebergTableHandle table = (IcebergTableHandle) tableHandle; + verifyTableVersionForUpdate(table); + + Table icebergTable = catalog.loadTable(session, table.getSchemaTableName()); + validateNotModifyingOldSnapshot(table, icebergTable); + + beginTransaction(icebergTable); + + IcebergTableHandle newTableHandle = table.withRetryMode(retryMode); + IcebergWritableTableHandle insertHandle = newWritableTableHandle(table.getSchemaTableName(), icebergTable, retryMode); + + return new IcebergMergeTableHandle(newTableHandle, insertHandle); + } + + @Override + public void finishMerge(ConnectorSession session, ConnectorMergeTableHandle tableHandle, Collection fragments, Collection computedStatistics) + { + IcebergTableHandle handle = ((IcebergMergeTableHandle) tableHandle).getTableHandle(); + finishWrite(session, handle, fragments, true); } private static void verifyTableVersionForUpdate(IcebergTableHandle table) @@ -1969,15 +2030,7 @@ public ConnectorInsertTableHandle beginRefreshMaterializedView(ConnectorSession Table icebergTable = catalog.loadTable(session, table.getSchemaTableName()); beginTransaction(icebergTable); - return new IcebergWritableTableHandle( - table.getSchemaTableName(), - SchemaParser.toJson(icebergTable.schema()), - PartitionSpecParser.toJson(icebergTable.spec()), - getColumns(icebergTable.schema(), typeManager), - icebergTable.location(), - getFileFormat(icebergTable), - icebergTable.properties(), - retryMode); + return newWritableTableHandle(table.getSchemaTableName(), icebergTable, retryMode); } @Override diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergNodePartitioningProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergNodePartitioningProvider.java index 95fc18a5d953..fbae3e9c2b9e 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergNodePartitioningProvider.java +++ 
b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergNodePartitioningProvider.java @@ -13,13 +13,10 @@ */ package io.trino.plugin.iceberg; -import io.trino.spi.NodeManager; import io.trino.spi.connector.BucketFunction; -import io.trino.spi.connector.ConnectorBucketNodeMap; import io.trino.spi.connector.ConnectorNodePartitioningProvider; import io.trino.spi.connector.ConnectorPartitioningHandle; import io.trino.spi.connector.ConnectorSession; -import io.trino.spi.connector.ConnectorSplit; import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; @@ -29,42 +26,20 @@ import javax.inject.Inject; import java.util.List; -import java.util.function.ToIntFunction; import static io.trino.plugin.iceberg.IcebergUtil.schemaFromHandles; import static io.trino.plugin.iceberg.PartitionFields.parsePartitionFields; -import static io.trino.spi.connector.ConnectorBucketNodeMap.createBucketNodeMap; import static java.util.Objects.requireNonNull; public class IcebergNodePartitioningProvider implements ConnectorNodePartitioningProvider { private final TypeOperators typeOperators; - private final NodeManager nodeManager; @Inject - public IcebergNodePartitioningProvider(TypeManager typeManager, NodeManager nodeManager) + public IcebergNodePartitioningProvider(TypeManager typeManager) { this.typeOperators = requireNonNull(typeManager, "typeManager is null").getTypeOperators(); - this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); - } - - @Override - public ConnectorBucketNodeMap getBucketNodeMap(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) - { - return createBucketNodeMap(nodeManager.getRequiredWorkerNodes().size()); - } - - @Override - public ToIntFunction getSplitBucketFunction( - ConnectorTransactionHandle transactionHandle, - ConnectorSession session, - ConnectorPartitioningHandle partitioningHandle) - { - 
return split -> { - // Not currently used, likely because IcebergMetadata.getTableProperties currently does not expose partitioning. - throw new UnsupportedOperationException(); - }; } @Override @@ -75,6 +50,10 @@ public BucketFunction getBucketFunction( List partitionChannelTypes, int bucketCount) { + if (partitioningHandle instanceof IcebergUpdateHandle) { + return new IcebergUpdateBucketFunction(bucketCount); + } + IcebergPartitioningHandle handle = (IcebergPartitioningHandle) partitioningHandle; Schema schema = schemaFromHandles(handle.getPartitioningColumns()); return new IcebergBucketFunction( diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSinkProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSinkProvider.java index 92cd97b9d8ba..cf62fca1ca3e 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSinkProvider.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSinkProvider.java @@ -20,6 +20,8 @@ import io.trino.plugin.iceberg.procedure.IcebergTableExecuteHandle; import io.trino.spi.PageIndexerFactory; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeSink; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorOutputTableHandle; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorPageSinkProvider; @@ -34,6 +36,9 @@ import javax.inject.Inject; +import java.util.Map; + +import static com.google.common.collect.Maps.transformValues; import static io.trino.plugin.iceberg.IcebergUtil.getLocationProvider; import static java.util.Objects.requireNonNull; @@ -81,7 +86,8 @@ private ConnectorPageSink createPageSink(ConnectorSession session, IcebergWritab { HdfsContext hdfsContext = new HdfsContext(session); Schema schema = SchemaParser.fromJson(tableHandle.getSchemaAsJson()); - PartitionSpec partitionSpec = 
PartitionSpecParser.fromJson(schema, tableHandle.getPartitionSpecAsJson()); + String partitionSpecJson = tableHandle.getPartitionsSpecsAsJson().get(tableHandle.getPartitionSpecId()); + PartitionSpec partitionSpec = PartitionSpecParser.fromJson(schema, partitionSpecJson); LocationProvider locationProvider = getLocationProvider(tableHandle.getName(), tableHandle.getOutputPath(), tableHandle.getStorageProperties()); return new IcebergPageSink( schema, @@ -133,4 +139,29 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa } throw new IllegalArgumentException("Unknown procedure: " + executeHandle.getProcedureId()); } + + @Override + public ConnectorMergeSink createMergeSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorMergeTableHandle mergeHandle) + { + IcebergMergeTableHandle merge = (IcebergMergeTableHandle) mergeHandle; + IcebergWritableTableHandle tableHandle = merge.getInsertTableHandle(); + LocationProvider locationProvider = getLocationProvider(tableHandle.getName(), tableHandle.getOutputPath(), tableHandle.getStorageProperties()); + Schema schema = SchemaParser.fromJson(tableHandle.getSchemaAsJson()); + Map partitionsSpecs = transformValues(tableHandle.getPartitionsSpecsAsJson(), json -> PartitionSpecParser.fromJson(schema, json)); + ConnectorPageSink pageSink = createPageSink(session, tableHandle); + + return new IcebergMergeSink( + locationProvider, + fileWriterFactory, + hdfsEnvironment, + fileIoProvider, + jsonCodec, + session, + tableHandle.getFileFormat(), + tableHandle.getStorageProperties(), + schema, + partitionsSpecs, + pageSink, + tableHandle.getInputColumns().size()); + } } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSource.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSource.java index c443fc8328c7..862431c74ca5 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSource.java +++ 
b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSource.java @@ -64,9 +64,9 @@ public class IcebergPageSource private final Supplier updatedRowPageSinkSupplier; // An array with one element per field in the $row_id column. The value in the array points to the // channel where the data can be read from. - private int[] updateRowIdChildColumnIndexes = new int[0]; + private int[] rowIdChildColumnIndexes = new int[0]; // The $row_id's index in 'expectedColumns', or -1 if there isn't one - private int updateRowIdColumnIndex = -1; + private int rowIdColumnIndex = -1; // Maps the Iceberg field ids of unmodified columns to their indexes in updateRowIdChildColumnIndexes private Map icebergIdToRowIdColumnIndex = ImmutableMap.of(); // Maps the Iceberg field ids of modified columns to their indexes in the updateColumns columnValueAndRowIdChannels array @@ -99,16 +99,16 @@ public IcebergPageSource( checkArgument(expectedColumn.equals(requiredColumns.get(i)), "Expected columns must be a prefix of required columns"); expectedColumnIndexes[i] = i; - if (expectedColumn.isUpdateRowIdColumn()) { - this.updateRowIdColumnIndex = i; + if (expectedColumn.isUpdateRowIdColumn() || expectedColumn.isMergeRowIdColumn()) { + this.rowIdColumnIndex = i; Map fieldIdToColumnIndex = mapFieldIdsToIndex(requiredColumns); List rowIdFields = expectedColumn.getColumnIdentity().getChildren(); ImmutableMap.Builder fieldIdToRowIdIndex = ImmutableMap.builder(); - this.updateRowIdChildColumnIndexes = new int[rowIdFields.size()]; + this.rowIdChildColumnIndexes = new int[rowIdFields.size()]; for (int columnIndex = 0; columnIndex < rowIdFields.size(); columnIndex++) { int fieldId = rowIdFields.get(columnIndex).getId(); - updateRowIdChildColumnIndexes[columnIndex] = requireNonNull(fieldIdToColumnIndex.get(fieldId), () -> format("Column %s not found in requiredColumns", fieldId)); + rowIdChildColumnIndexes[columnIndex] = requireNonNull(fieldIdToColumnIndex.get(fieldId), () -> format("Column 
%s not found in requiredColumns", fieldId)); fieldIdToRowIdIndex.put(fieldId, columnIndex); } this.icebergIdToRowIdColumnIndex = fieldIdToRowIdIndex.buildOrThrow(); @@ -167,7 +167,7 @@ public Page getNextPage() dataPage = projectionsAdapter.get().adaptPage(dataPage); } - dataPage = setUpdateRowIdBlock(dataPage); + dataPage = withRowIdBlock(dataPage); dataPage = dataPage.getColumns(expectedColumnIndexes); return dataPage; @@ -185,20 +185,20 @@ public Page getNextPage() * @param page The raw Page from the Parquet/ORC reader. * @return A Page where the $row_id channel has been populated. */ - private Page setUpdateRowIdBlock(Page page) + private Page withRowIdBlock(Page page) { - if (updateRowIdColumnIndex == -1) { + if (rowIdColumnIndex == -1) { return page; } - Block[] rowIdFields = new Block[updateRowIdChildColumnIndexes.length]; - for (int childIndex = 0; childIndex < updateRowIdChildColumnIndexes.length; childIndex++) { - rowIdFields[childIndex] = page.getBlock(updateRowIdChildColumnIndexes[childIndex]); + Block[] rowIdFields = new Block[rowIdChildColumnIndexes.length]; + for (int childIndex = 0; childIndex < rowIdChildColumnIndexes.length; childIndex++) { + rowIdFields[childIndex] = page.getBlock(rowIdChildColumnIndexes[childIndex]); } Block[] fullPage = new Block[page.getChannelCount()]; for (int channel = 0; channel < page.getChannelCount(); channel++) { - if (channel == updateRowIdColumnIndex) { + if (channel == rowIdColumnIndex) { fullPage[channel] = RowBlock.fromFieldBlocks(page.getPositionCount(), Optional.empty(), rowIdFields); continue; } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java index f3cfaeacd6c3..ef50e8698970 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java +++ 
b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java @@ -89,6 +89,7 @@ import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hdfs.BlockMissingException; +import org.apache.iceberg.MetadataColumns; import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.PartitionSpecParser; import org.apache.iceberg.Schema; @@ -143,6 +144,9 @@ import static io.trino.parquet.ParquetTypeUtils.getDescriptors; import static io.trino.parquet.predicate.PredicateUtils.buildPredicate; import static io.trino.parquet.predicate.PredicateUtils.predicateMatches; +import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_MERGE_FILE_RECORD_COUNT; +import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_MERGE_PARTITION_DATA; +import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_MERGE_PARTITION_SPEC_ID; import static io.trino.plugin.iceberg.IcebergErrorCode.ICEBERG_BAD_DATA; import static io.trino.plugin.iceberg.IcebergErrorCode.ICEBERG_CANNOT_OPEN_SPLIT; import static io.trino.plugin.iceberg.IcebergErrorCode.ICEBERG_CURSOR_ERROR; @@ -177,6 +181,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone; +import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.TimeZoneKey.UTC_KEY; import static io.trino.spi.type.UuidType.UUID; import static io.trino.spi.type.VarbinaryType.VARBINARY; @@ -277,19 +282,32 @@ public ConnectorPageSource createPageSource( .forEach(requiredColumns::add); icebergColumns.stream() - .filter(IcebergColumnHandle::isUpdateRowIdColumn) - .findFirst().ifPresent(updateRowIdColumn -> { + .filter(column -> column.isUpdateRowIdColumn() || column.isMergeRowIdColumn()) + .findFirst().ifPresent(rowIdColumn -> { Set alreadyRequiredColumnIds = requiredColumns.stream() .map(IcebergColumnHandle::getId) 
.collect(toImmutableSet()); - for (ColumnIdentity requiredColumnIdentity : updateRowIdColumn.getColumnIdentity().getChildren()) { - if (!alreadyRequiredColumnIds.contains(requiredColumnIdentity.getId())) { - if (requiredColumnIdentity.getId() == ROW_POSITION.fieldId()) { - requiredColumns.add(new IcebergColumnHandle(requiredColumnIdentity, BIGINT, ImmutableList.of(), BIGINT, Optional.empty())); - } - else { - requiredColumns.add(getColumnHandle(tableSchema.findField(requiredColumnIdentity.getId()), typeManager)); - } + for (ColumnIdentity identity : rowIdColumn.getColumnIdentity().getChildren()) { + if (alreadyRequiredColumnIds.contains(identity.getId())) { + // ignore + } + else if (identity.getId() == MetadataColumns.FILE_PATH.fieldId()) { + requiredColumns.add(new IcebergColumnHandle(identity, VARCHAR, ImmutableList.of(), VARCHAR, Optional.empty())); + } + else if (identity.getId() == ROW_POSITION.fieldId()) { + requiredColumns.add(new IcebergColumnHandle(identity, BIGINT, ImmutableList.of(), BIGINT, Optional.empty())); + } + else if (identity.getId() == TRINO_MERGE_FILE_RECORD_COUNT) { + requiredColumns.add(new IcebergColumnHandle(identity, BIGINT, ImmutableList.of(), BIGINT, Optional.empty())); + } + else if (identity.getId() == TRINO_MERGE_PARTITION_SPEC_ID) { + requiredColumns.add(new IcebergColumnHandle(identity, INTEGER, ImmutableList.of(), INTEGER, Optional.empty())); + } + else if (identity.getId() == TRINO_MERGE_PARTITION_DATA) { + requiredColumns.add(new IcebergColumnHandle(identity, VARCHAR, ImmutableList.of(), VARCHAR, Optional.empty())); + } + else { + requiredColumns.add(getColumnHandle(tableSchema.findField(identity.getId()), typeManager)); } } }); @@ -328,6 +346,9 @@ public ConnectorPageSource createPageSource( split.getLength(), fileSize, fileModifiedTime, + split.getFileRecordCount(), + partitionSpec.specId(), + split.getPartitionDataJson(), split.getFileFormat(), split.getSchemaAsJson().map(SchemaParser::fromJson), requiredColumns, @@ -460,6 
+481,9 @@ private ConnectorPageSource openDeletes( delete.fileSizeInBytes(), delete.fileSizeInBytes(), OptionalLong.empty(), + delete.recordCount(), + 0, + "", IcebergFileFormat.fromIceberg(delete.format()), Optional.of(schemaFromHandles(columns)), columns, @@ -477,6 +501,9 @@ public ReaderPageSource createDataPageSource( long length, long fileSize, OptionalLong fileModifiedTime, + long fileRecordCount, + int partitionSpecId, + String partitionData, IcebergFileFormat fileFormat, Optional fileSchema, List dataColumns, @@ -497,6 +524,9 @@ public ReaderPageSource createDataPageSource( length, fileSize, fileModifiedTime, + fileRecordCount, + partitionSpecId, + partitionData, dataColumns, predicate, orcReaderOptions @@ -523,6 +553,9 @@ public ReaderPageSource createDataPageSource( length, fileSize, fileModifiedTime, + fileRecordCount, + partitionSpecId, + partitionData, dataColumns, parquetReaderOptions .withMaxReadBlockSize(getParquetMaxReadBlockSize(session)), @@ -538,6 +571,9 @@ public ReaderPageSource createDataPageSource( start, length, fileModifiedTime, + fileRecordCount, + partitionSpecId, + partitionData, fileSchema.orElseThrow(), nameMapping, dataColumns); @@ -556,6 +592,9 @@ private static ReaderPageSource createOrcPageSource( long length, long fileSize, OptionalLong fileModifiedTime, + long fileRecordCount, + int partitionSpecId, + String partitionData, List columns, TupleDomain effectivePredicate, OrcReaderOptions options, @@ -625,13 +664,22 @@ else if (column.isPathColumn()) { else if (column.isFileModifiedTimeColumn()) { columnAdaptations.add(ColumnAdaptation.constantColumn(nativeValueToBlock(FILE_MODIFIED_TIME.getType(), packDateTimeWithZone(fileModifiedTime.orElseThrow(), UTC_KEY)))); } - else if (column.isUpdateRowIdColumn()) { + else if (column.isUpdateRowIdColumn() || column.isMergeRowIdColumn()) { // $row_id is a composite of multiple physical columns. 
It is assembled by the IcebergPageSource columnAdaptations.add(ColumnAdaptation.nullColumn(column.getType())); } else if (column.isRowPositionColumn()) { columnAdaptations.add(ColumnAdaptation.positionColumn()); } + else if (column.getId() == TRINO_MERGE_FILE_RECORD_COUNT) { + columnAdaptations.add(ColumnAdaptation.constantColumn(nativeValueToBlock(column.getType(), fileRecordCount))); + } + else if (column.getId() == TRINO_MERGE_PARTITION_SPEC_ID) { + columnAdaptations.add(ColumnAdaptation.constantColumn(nativeValueToBlock(column.getType(), (long) partitionSpecId))); + } + else if (column.getId() == TRINO_MERGE_PARTITION_DATA) { + columnAdaptations.add(ColumnAdaptation.constantColumn(nativeValueToBlock(column.getType(), utf8Slice(partitionData)))); + } else if (orcColumn != null) { Type readType = getOrcReadType(column.getType(), typeManager); @@ -832,7 +880,7 @@ public IdBasedFieldMapperFactory(List columns) ImmutableMap.Builder> mapping = ImmutableMap.builder(); for (IcebergColumnHandle column : columns) { - if (column.isUpdateRowIdColumn()) { + if (column.isUpdateRowIdColumn() || column.isMergeRowIdColumn()) { // The update $row_id column contains fields which should not be accounted for in the mapping. 
continue; } @@ -904,6 +952,9 @@ private static ReaderPageSource createParquetPageSource( long length, long fileSize, OptionalLong fileModifiedTime, + long fileRecordCount, + int partitionSpecId, + String partitionData, List regularColumns, ParquetReaderOptions options, TupleDomain effectivePredicate, @@ -993,7 +1044,7 @@ else if (column.isPathColumn()) { else if (column.isFileModifiedTimeColumn()) { constantPopulatingPageSourceBuilder.addConstantColumn(nativeValueToBlock(FILE_MODIFIED_TIME.getType(), packDateTimeWithZone(fileModifiedTime.orElseThrow(), UTC_KEY))); } - else if (column.isUpdateRowIdColumn()) { + else if (column.isUpdateRowIdColumn() || column.isMergeRowIdColumn()) { // $row_id is a composite of multiple physical columns, it is assembled by the IcebergPageSource trinoTypes.add(column.getType()); internalFields.add(Optional.empty()); @@ -1008,6 +1059,15 @@ else if (column.isRowPositionColumn()) { constantPopulatingPageSourceBuilder.addDelegateColumn(parquetSourceChannel); parquetSourceChannel++; } + else if (column.getId() == TRINO_MERGE_FILE_RECORD_COUNT) { + constantPopulatingPageSourceBuilder.addConstantColumn(nativeValueToBlock(column.getType(), fileRecordCount)); + } + else if (column.getId() == TRINO_MERGE_PARTITION_SPEC_ID) { + constantPopulatingPageSourceBuilder.addConstantColumn(nativeValueToBlock(column.getType(), (long) partitionSpecId)); + } + else if (column.getId() == TRINO_MERGE_PARTITION_DATA) { + constantPopulatingPageSourceBuilder.addConstantColumn(nativeValueToBlock(column.getType(), utf8Slice(partitionData))); + } else { rowIndexChannels.add(false); org.apache.parquet.schema.Type parquetField = parquetFields.get(columnIndex); @@ -1066,6 +1126,9 @@ private ReaderPageSource createAvroPageSource( long start, long length, OptionalLong fileModifiedTime, + long fileRecordCount, + int partitionSpecId, + String partitionData, Schema fileSchema, Optional nameMapping, List columns) @@ -1113,6 +1176,15 @@ else if (column.isRowPositionColumn()) 
{ constantPopulatingPageSourceBuilder.addDelegateColumn(avroSourceChannel); avroSourceChannel++; } + else if (column.getId() == TRINO_MERGE_FILE_RECORD_COUNT) { + constantPopulatingPageSourceBuilder.addConstantColumn(nativeValueToBlock(column.getType(), fileRecordCount)); + } + else if (column.getId() == TRINO_MERGE_PARTITION_SPEC_ID) { + constantPopulatingPageSourceBuilder.addConstantColumn(nativeValueToBlock(column.getType(), (long) partitionSpecId)); + } + else if (column.getId() == TRINO_MERGE_PARTITION_DATA) { + constantPopulatingPageSourceBuilder.addConstantColumn(nativeValueToBlock(column.getType(), utf8Slice(partitionData))); + } else if (field == null) { constantPopulatingPageSourceBuilder.addConstantColumn(nativeValueToBlock(column.getType(), null)); } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergUpdateBucketFunction.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergUpdateBucketFunction.java new file mode 100644 index 000000000000..8a72729cc3e2 --- /dev/null +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergUpdateBucketFunction.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.iceberg; + +import io.airlift.slice.Slice; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.connector.BucketFunction; + +import static io.trino.spi.type.VarcharType.VARCHAR; + +public class IcebergUpdateBucketFunction + implements BucketFunction +{ + private final int bucketCount; + + public IcebergUpdateBucketFunction(int bucketCount) + { + this.bucketCount = bucketCount; + } + + @Override + public int getBucket(Page page, int position) + { + Block row = page.getBlock(0).getObject(position, Block.class); + Slice value = VARCHAR.getSlice(row, 0); // file path field of row ID + return (value.hashCode() & Integer.MAX_VALUE) % bucketCount; + } +} diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergUpdateHandle.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergUpdateHandle.java new file mode 100644 index 000000000000..9b08d04b8d9a --- /dev/null +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergUpdateHandle.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.iceberg; + +import io.trino.spi.connector.ConnectorPartitioningHandle; + +public enum IcebergUpdateHandle + implements ConnectorPartitioningHandle +{ + INSTANCE +} diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergWritableTableHandle.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergWritableTableHandle.java index 8cab0fbbcf09..b3ada32bac4d 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergWritableTableHandle.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergWritableTableHandle.java @@ -16,6 +16,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.trino.spi.connector.ConnectorInsertTableHandle; import io.trino.spi.connector.ConnectorOutputTableHandle; import io.trino.spi.connector.RetryMode; @@ -24,6 +25,7 @@ import java.util.List; import java.util.Map; +import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; public class IcebergWritableTableHandle @@ -31,7 +33,8 @@ public class IcebergWritableTableHandle { private final SchemaTableName name; private final String schemaAsJson; - private final String partitionSpecAsJson; + private final Map<Integer, String> partitionsSpecsAsJson; + private final int partitionSpecId; private final List<IcebergColumnHandle> inputColumns; private final String outputPath; private final IcebergFileFormat fileFormat; @@ -42,7 +45,8 @@ public class IcebergWritableTableHandle public IcebergWritableTableHandle( @JsonProperty("name") SchemaTableName name, @JsonProperty("schemaAsJson") String schemaAsJson, - @JsonProperty("partitionSpecAsJson") String partitionSpecAsJson, + @JsonProperty("partitionsSpecsAsJson") Map<Integer, String> partitionsSpecsAsJson, + @JsonProperty("partitionSpecId") int partitionSpecId, @JsonProperty("inputColumns") List<IcebergColumnHandle>
inputColumns, @JsonProperty("outputPath") String outputPath, @JsonProperty("fileFormat") IcebergFileFormat fileFormat, @@ -51,12 +55,14 @@ public IcebergWritableTableHandle( { this.name = requireNonNull(name, "name is null"); this.schemaAsJson = requireNonNull(schemaAsJson, "schemaAsJson is null"); - this.partitionSpecAsJson = requireNonNull(partitionSpecAsJson, "partitionSpecAsJson is null"); + this.partitionsSpecsAsJson = ImmutableMap.copyOf(requireNonNull(partitionsSpecsAsJson, "partitionsSpecsAsJson is null")); + this.partitionSpecId = partitionSpecId; this.inputColumns = ImmutableList.copyOf(requireNonNull(inputColumns, "inputColumns is null")); this.outputPath = requireNonNull(outputPath, "outputPath is null"); this.fileFormat = requireNonNull(fileFormat, "fileFormat is null"); - this.storageProperties = requireNonNull(storageProperties, "storageProperties is null"); + this.storageProperties = ImmutableMap.copyOf(requireNonNull(storageProperties, "storageProperties is null")); this.retryMode = requireNonNull(retryMode, "retryMode is null"); + checkArgument(partitionsSpecsAsJson.containsKey(partitionSpecId), "partitionSpecId missing from partitionSpecs"); } @JsonProperty @@ -72,9 +78,15 @@ public String getSchemaAsJson() } @JsonProperty - public String getPartitionSpecAsJson() + public Map getPartitionsSpecsAsJson() { - return partitionSpecAsJson; + return partitionsSpecsAsJson; + } + + @JsonProperty + public int getPartitionSpecId() + { + return partitionSpecId; } @JsonProperty diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorSmokeTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorSmokeTest.java index 0dd50effeaf1..b5f0abcf3562 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorSmokeTest.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorSmokeTest.java @@ -60,6 +60,7 @@ protected boolean 
hasBehavior(TestingConnectorBehavior connectorBehavior) case SUPPORTS_DELETE: case SUPPORTS_UPDATE: + case SUPPORTS_MERGE: return true; default: return super.hasBehavior(connectorBehavior); diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java index 10d78c69e3c1..8be691db8d6c 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java @@ -22,7 +22,6 @@ import io.trino.metadata.QualifiedObjectName; import io.trino.metadata.TableHandle; import io.trino.operator.OperatorStats; -import io.trino.plugin.hive.HdfsEnvironment; import io.trino.spi.QueryId; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.Constraint; @@ -76,7 +75,6 @@ import java.util.function.Consumer; import java.util.regex.Matcher; import java.util.regex.Pattern; -import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.LongStream; import java.util.stream.Stream; @@ -132,7 +130,6 @@ import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotEquals; import static org.testng.Assert.assertNull; -import static org.testng.Assert.assertTrue; public abstract class BaseIcebergConnectorTest extends BaseConnectorTest @@ -179,6 +176,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) case SUPPORTS_DELETE: case SUPPORTS_UPDATE: + case SUPPORTS_MERGE: return true; case SUPPORTS_COMMENT_ON_VIEW: @@ -1144,7 +1142,7 @@ public void testLargeInOnPartitionedColumns() assertUpdate("INSERT INTO test_in_predicate_large_set VALUES (1, 10)", 1L); assertUpdate("INSERT INTO test_in_predicate_large_set VALUES (2, 20)", 1L); - List predicates = IntStream.range(0, 25_000).boxed() + List predicates = range(0, 25_000).boxed() .map(Object::toString) 
.collect(toImmutableList()); String filter = format("col2 IN (%s)", join(",", predicates)); @@ -3142,7 +3140,7 @@ private void assertFilterPushdown( Optional> result = metadata.applyFilter(session, table, new Constraint(domains)); - assertTrue(result.isEmpty() == (expectedUnenforcedPredicate == null && expectedEnforcedPredicate == null)); + assertEquals((expectedUnenforcedPredicate == null && expectedEnforcedPredicate == null), result.isEmpty()); if (result.isPresent()) { IcebergTableHandle newTable = (IcebergTableHandle) result.get().getHandle().getConnectorHandle(); @@ -3417,11 +3415,11 @@ public void testIncorrectIcebergFileSizes() // Alter data file entry to store incorrect file size GenericData.Record dataFile = (GenericData.Record) entry.get("data_file"); long alteredValue = 50L; - assertNotEquals((long) dataFile.get("file_size_in_bytes"), alteredValue); + assertNotEquals(dataFile.get("file_size_in_bytes"), alteredValue); dataFile.put("file_size_in_bytes", alteredValue); // Replace the file through HDFS client. This is required for correct checksums. 
- HdfsEnvironment.HdfsContext context = new HdfsContext(getSession().toConnectorSession()); + HdfsContext context = new HdfsContext(getSession().toConnectorSession()); org.apache.hadoop.fs.Path manifestFilePath = new org.apache.hadoop.fs.Path(manifestFile); FileSystem fs = HDFS_ENVIRONMENT.getFileSystem(context, manifestFilePath); @@ -3888,7 +3886,7 @@ public void testSplitPruningFromDataFileStatistics(DataMappingTestSetup testSetu nCopies(100, testSetup.getSampleValueLiteral()).stream(), nCopies(100, testSetup.getHighValueLiteral()).stream()) .map(value -> "(" + value + ", rand())") - .collect(Collectors.joining(", ")); + .collect(joining(", ")); assertUpdate(withSmallRowGroups(getSession()), "INSERT INTO " + tableName + " VALUES " + values, 200); String query = "SELECT * FROM " + tableName + " WHERE col = " + testSetup.getSampleValueLiteral(); @@ -5221,6 +5219,254 @@ public void testReadFromVersionedTableWithExpiredHistory() assertQueryFails("SELECT * FROM " + tableName + " FOR TIMESTAMP AS OF " + timestampLiteral(v1EpochMillis, 9), "No version history table .* at or before .*"); } + @Test + public void testMergeSimpleSelectPartitioned() + { + String targetTable = "merge_simple_target_" + randomTableSuffix(); + String sourceTable = "merge_simple_source_" + randomTableSuffix(); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (partitioning = ARRAY['address'])", targetTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable), 4); + + String sql = format("MERGE INTO %s t USING 
%s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + assertUpdate(sql, 4); + + assertQuery("SELECT * FROM " + targetTable, "VALUES ('Aaron', 11, 'Arches'), ('Ed', 7, 'Etherville'), ('Bill', 7, 'Buena'), ('Dave', 22, 'Darbyshire')"); + + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + + @Test(dataProvider = "partitionedAndBucketedProvider") + public void testMergeUpdateWithVariousLayouts(String partitionPhase) + { + String targetTable = "merge_formats_target_" + randomTableSuffix(); + String sourceTable = "merge_formats_source_" + randomTableSuffix(); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR) %s", targetTable, partitionPhase)); + + assertUpdate(format("INSERT INTO %s (customer, purchase) VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')", targetTable), 3); + assertQuery("SELECT * FROM " + targetTable, "VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')"); + + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchase) VALUES ('Craig', 'candles'), ('Len', 'limes'), ('Joe', 'jellybeans')", sourceTable), 3); + + String sql = format("MERGE INTO %s t USING %s s ON (t.purchase = s.purchase)", targetTable, sourceTable) + + " WHEN MATCHED AND s.purchase = 'limes' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET customer = CONCAT(t.customer, '_', s.customer)" + + " WHEN NOT MATCHED THEN INSERT (customer, purchase) VALUES(s.customer, s.purchase)"; + + assertUpdate(sql, 3); + + assertQuery("SELECT * FROM " + targetTable, "VALUES ('Dave', 'dates'), ('Carol_Craig', 'candles'), ('Joe', 'jellybeans')"); + 
assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + + @DataProvider + public Object[][] partitionedAndBucketedProvider() + { + return new Object[][] { + {"WITH (partitioning = ARRAY['customer'])"}, + {"WITH (partitioning = ARRAY['purchase'])"}, + {"WITH (partitioning = ARRAY['bucket(customer, 3)'])"}, + {"WITH (partitioning = ARRAY['bucket(purchase, 4)'])"}, + }; + } + + @Test(dataProvider = "partitionedAndBucketedProvider") + public void testMergeMultipleOperations(String partitioning) + { + int targetCustomerCount = 32; + String targetTable = "merge_multiple_" + randomTableSuffix(); + assertUpdate(format("CREATE TABLE %s (purchase INT, zipcode INT, spouse VARCHAR, address VARCHAR, customer VARCHAR) %s", targetTable, partitioning)); + String originalInsertFirstHalf = range(1, targetCustomerCount / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 1000, 91000, intValue, intValue)) + .collect(joining(", ")); + String originalInsertSecondHalf = range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 2000, 92000, intValue, intValue)) + .collect(joining(", ")); + + assertUpdate(format("INSERT INTO %s (customer, purchase, zipcode, spouse, address) VALUES %s, %s", targetTable, originalInsertFirstHalf, originalInsertSecondHalf), targetCustomerCount - 1); + + String firstMergeSource = range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 83000, intValue, intValue)) + .collect(joining(", ")); + + assertUpdate(format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchase, zipcode, spouse, address)", targetTable, firstMergeSource) + + " ON t.customer = s.customer" + + " WHEN MATCHED THEN UPDATE SET purchase = s.purchase, zipcode = s.zipcode, spouse = s.spouse, address = s.address", + targetCustomerCount / 2); + + 
assertQuery( + "SELECT customer, purchase, zipcode, spouse, address FROM " + targetTable, + format("VALUES %s, %s", originalInsertFirstHalf, firstMergeSource)); + + String nextInsert = range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('jack_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 4000, 74000, intValue, intValue)) + .collect(joining(", ")); + + assertUpdate(format("INSERT INTO %s (customer, purchase, zipcode, spouse, address) VALUES %s", targetTable, nextInsert), targetCustomerCount / 2); + + String secondMergeSource = range(1, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue)) + .collect(joining(", ")); + + assertUpdate(format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchase, zipcode, spouse, address)", targetTable, secondMergeSource) + + " ON t.customer = s.customer" + + " WHEN MATCHED AND t.zipcode = 91000 THEN DELETE" + + " WHEN MATCHED AND s.zipcode = 85000 THEN UPDATE SET zipcode = 60000" + + " WHEN MATCHED THEN UPDATE SET zipcode = s.zipcode, spouse = s.spouse, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchase, zipcode, spouse, address) VALUES(s.customer, s.purchase, s.zipcode, s.spouse, s.address)", + targetCustomerCount * 3 / 2 - 1); + + String updatedBeginning = range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 60000, intValue, intValue)) + .collect(joining(", ")); + String updatedMiddle = range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue)) + .collect(joining(", ")); + String updatedEnd = range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('jack_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 4000, 74000, intValue, intValue)) + 
.collect(joining(", ")); + + assertQuery( + "SELECT customer, purchase, zipcode, spouse, address FROM " + targetTable, + format("VALUES %s, %s, %s", updatedBeginning, updatedMiddle, updatedEnd)); + + assertUpdate("DROP TABLE " + targetTable); + } + + @Test + public void testMergeSimpleQueryPartitioned() + { + String targetTable = "merge_simple_" + randomTableSuffix(); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (partitioning = ARRAY['address'])", targetTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + + @Language("SQL") String query = format("MERGE INTO %s t USING ", targetTable) + + "(SELECT * FROM (VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville'))) AS s(customer, purchases, address)" + + " " + + "ON (t.customer = s.customer)" + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + assertUpdate(query, 4); + + assertQuery("SELECT * FROM " + targetTable, "VALUES ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Dave', 22, 'Darbyshire'), ('Ed', 7, 'Etherville')"); + + assertUpdate("DROP TABLE " + targetTable); + } + + @Test(dataProvider = "partitionedBucketedFailure") + public void testMergeMultipleRowsMatchFails(String createTableSql) + { + String targetTable = "merge_multiple_target_" + randomTableSuffix(); + String sourceTable = "merge_multiple_source_" + randomTableSuffix(); + assertUpdate(format(createTableSql, targetTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Antioch')", targetTable), 2); + + 
assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Adelphi'), ('Aaron', 8, 'Ashland')", sourceTable), 2); + + assertThatThrownBy(() -> computeActual(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED THEN UPDATE SET address = s.address")) + .hasMessage("One MERGE target table row matched more than one source row"); + + assertUpdate(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED AND s.address = 'Adelphi' THEN UPDATE SET address = s.address", + 1); + assertQuery("SELECT customer, purchases, address FROM " + targetTable, "VALUES ('Aaron', 5, 'Adelphi'), ('Bill', 7, 'Antioch')"); + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + + @DataProvider + public Object[][] partitionedBucketedFailure() + { + return new Object[][] { + {"CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)"}, + {"CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (partitioning = ARRAY['bucket(customer, 3)'])"}, + {"CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (partitioning = ARRAY['customer'])"}, + {"CREATE TABLE %s (customer VARCHAR, address VARCHAR, purchases INT) WITH (partitioning = ARRAY['address'])"}, + {"CREATE TABLE %s (purchases INT, customer VARCHAR, address VARCHAR) WITH (partitioning = ARRAY['address', 'customer'])"} + }; + } + + @Test(dataProvider = "targetAndSourceWithDifferentPartitioning") + public void testMergeWithDifferentPartitioning(String testDescription, String createTargetTableSql, String createSourceTableSql) + { + String targetTable = format("%s_target_%s", testDescription, randomTableSuffix()); + String sourceTable = format("%s_source_%s", testDescription, randomTableSuffix()); + + 
assertUpdate(format(createTargetTableSql, targetTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + + assertUpdate(format(createSourceTableSql, sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable), 4); + + @Language("SQL") String sql = format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + assertUpdate(sql, 4); + + assertQuery("SELECT * FROM " + targetTable, "VALUES ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Dave', 22, 'Darbyshire'), ('Ed', 7, 'Etherville')"); + + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + + @DataProvider + public Object[][] targetAndSourceWithDifferentPartitioning() + { + return new Object[][] { + { + "target_partitioned_source_and_target_partitioned_and_bucketed", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (partitioning = ARRAY['address', 'bucket(customer, 3)'])", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (partitioning = ARRAY['address', 'bucket(customer, 3)'])", + }, + { + "target_flat_source_partitioned_by_customer", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", + "CREATE TABLE %s (purchases INT, address VARCHAR, customer VARCHAR) WITH (partitioning = ARRAY['customer'])" + }, + { + "target_partitioned_by_customer_source_flat", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address 
VARCHAR) WITH (partitioning = ARRAY['customer'])", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", + }, + { + "target_bucketed_by_customer_source_flat", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (partitioning = ARRAY['bucket(customer, 3)'])", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", + }, + { + "target_partitioned_source_partitioned_and_bucketed", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (partitioning = ARRAY['customer'])", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (partitioning = ARRAY['address', 'bucket(customer, 3)'])", + }, + { + "target_partitioned_target_partitioned_and_bucketed", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (partitioning = ARRAY['address', 'bucket(customer, 3)'])", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (partitioning = ARRAY['customer'])", + } + }; + } + @Override protected OptionalInt maxTableNameLength() { diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduMergeTableHandle.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduMergeTableHandle.java new file mode 100644 index 000000000000..d4798e8fa5bc --- /dev/null +++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduMergeTableHandle.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.kudu; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.connector.ConnectorMergeTableHandle; +import io.trino.spi.type.Type; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class KuduMergeTableHandle + implements ConnectorMergeTableHandle, KuduTableMapping +{ + private final KuduTableHandle tableHandle; + private final KuduOutputTableHandle outputTableHandle; + + @JsonCreator + public KuduMergeTableHandle( + @JsonProperty("tableHandle") KuduTableHandle tableHandle, + @JsonProperty("outputTableHandle") KuduOutputTableHandle outputTableHandle) + { + this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); + this.outputTableHandle = requireNonNull(outputTableHandle, "outputTableHandle is null"); + } + + @JsonProperty + @Override + public KuduTableHandle getTableHandle() + { + return tableHandle; + } + + @JsonProperty + public KuduOutputTableHandle getOutputTableHandle() + { + return outputTableHandle; + } + + @Override + public boolean isGenerateUUID() + { + return outputTableHandle.isGenerateUUID(); + } + + @Override + public List getColumnTypes() + { + return outputTableHandle.getColumnTypes(); + } + + @Override + public List getOriginalColumnTypes() + { + return outputTableHandle.getOriginalColumnTypes(); + } +} diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduMetadata.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduMetadata.java index 45b2be25c817..a8eb69d5ef65 100755 --- a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduMetadata.java +++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduMetadata.java @@ -23,6 +23,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeTableHandle; import 
io.trino.spi.connector.ConnectorMetadata; import io.trino.spi.connector.ConnectorOutputMetadata; import io.trino.spi.connector.ConnectorOutputTableHandle; @@ -39,6 +40,7 @@ import io.trino.spi.connector.NotFoundException; import io.trino.spi.connector.ProjectionApplicationResult; import io.trino.spi.connector.RetryMode; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.expression.ConnectorExpression; @@ -65,11 +67,13 @@ import java.util.OptionalInt; import java.util.OptionalLong; import java.util.Set; +import java.util.function.Consumer; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.plugin.kudu.KuduColumnHandle.ROW_ID; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.connector.RetryMode.NO_RETRIES; +import static io.trino.spi.connector.RowChangeParadigm.CHANGE_ONLY_UPDATED_COLUMNS; import static java.util.Objects.requireNonNull; public class KuduMetadata @@ -173,18 +177,21 @@ private ConnectorTableMetadata getTableMetadata(KuduTableHandle tableHandle) public Map getColumnHandles(ConnectorSession session, ConnectorTableHandle connectorTableHandle) { KuduTableHandle tableHandle = (KuduTableHandle) connectorTableHandle; + ImmutableMap.Builder columnHandles = ImmutableMap.builder(); Schema schema = clientSession.getTableSchema(tableHandle); + forAllColumnHandles(schema, column -> columnHandles.put(column.getName(), column)); + return columnHandles.buildOrThrow(); + } - ImmutableMap.Builder columnHandles = ImmutableMap.builder(); + private void forAllColumnHandles(Schema schema, Consumer handleEater) + { for (int ordinal = 0; ordinal < schema.getColumnCount(); ordinal++) { ColumnSchema col = schema.getColumnByIndex(ordinal); String name = col.getName(); Type type = TypeHelper.fromKuduColumn(col); KuduColumnHandle columnHandle = new KuduColumnHandle(name, ordinal, type); - 
columnHandles.put(name, columnHandle); + handleEater.accept(columnHandle); } - - return columnHandles.buildOrThrow(); } @Override @@ -413,6 +420,45 @@ public void finishDelete(ConnectorSession session, ConnectorTableHandle tableHan { } + @Override + public RowChangeParadigm getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return CHANGE_ONLY_UPDATED_COLUMNS; + } + + @Override + public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return KuduColumnHandle.ROW_ID_HANDLE; + } + + @Override + public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, RetryMode retryMode) + { + KuduTableHandle kuduTableHandle = (KuduTableHandle) tableHandle; + KuduTable table = kuduTableHandle.getTable(clientSession); + Schema schema = table.getSchema(); + List columns = schema.getColumns(); + List columnTypes = columns.stream() + .map(TypeHelper::fromKuduColumn) + .collect(toImmutableList()); + ConnectorTableMetadata tableMetadata = getTableMetadata(kuduTableHandle); + List columnOriginalTypes = tableMetadata.getColumns().stream() + .map(ColumnMetadata::getType) + .collect(toImmutableList()); + PartitionDesign design = KuduTableProperties.getPartitionDesign(tableMetadata.getProperties()); + boolean generateUUID = !design.hasPartitions(); + return new KuduMergeTableHandle( + kuduTableHandle, + new KuduOutputTableHandle(tableMetadata.getTable(), columnOriginalTypes, columnTypes, generateUUID, table)); + } + + @Override + public void finishMerge(ConnectorSession session, ConnectorMergeTableHandle tableHandle, Collection fragments, Collection computedStatistics) + { + // For Kudu, nothing needs to be done to finish the merge.
+ } + @Override public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) { diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduNodePartitioningProvider.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduNodePartitioningProvider.java index d12b809192d3..1839ac9f16a7 100644 --- a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduNodePartitioningProvider.java +++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduNodePartitioningProvider.java @@ -26,6 +26,7 @@ import javax.inject.Inject; import java.util.List; +import java.util.Optional; import java.util.function.ToIntFunction; import static io.trino.spi.connector.ConnectorBucketNodeMap.createBucketNodeMap; @@ -43,13 +44,13 @@ public KuduNodePartitioningProvider(KuduClientSession clientSession) } @Override - public ConnectorBucketNodeMap getBucketNodeMap( + public Optional getBucketNodeMapping( ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) { KuduPartitioningHandle handle = (KuduPartitioningHandle) partitioningHandle; - return createBucketNodeMap(handle.getBucketCount()); + return Optional.of(createBucketNodeMap(handle.getBucketCount())); } @Override diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSink.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSink.java index 961f61dfde93..71db5611ee4a 100644 --- a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSink.java +++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSink.java @@ -19,6 +19,7 @@ import io.airlift.slice.Slice; import io.trino.spi.Page; import io.trino.spi.block.Block; +import io.trino.spi.connector.ConnectorMergeSink; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.DecimalType; @@ -26,6 +27,10 @@ import io.trino.spi.type.SqlDecimal; import 
io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; +import org.apache.kudu.Schema; +import org.apache.kudu.client.Delete; +import org.apache.kudu.client.Insert; +import org.apache.kudu.client.KeyEncoderAccessor; import org.apache.kudu.client.KuduException; import org.apache.kudu.client.KuduOperationApplier; import org.apache.kudu.client.KuduTable; @@ -42,6 +47,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; +import static com.google.common.base.Preconditions.checkArgument; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DateType.DATE; @@ -60,7 +66,7 @@ import static java.util.concurrent.CompletableFuture.completedFuture; public class KuduPageSink - implements ConnectorPageSink + implements ConnectorPageSink, ConnectorMergeSink { private final ConnectorSession connectorSession; private final KuduClientSession session; @@ -88,6 +94,14 @@ public KuduPageSink( this(connectorSession, clientSession, tableHandle.getTable(clientSession), tableHandle); } + public KuduPageSink( + ConnectorSession connectorSession, + KuduClientSession clientSession, + KuduMergeTableHandle tableHandle) + { + this(connectorSession, clientSession, tableHandle.getOutputTableHandle().getTable(clientSession), tableHandle); + } + private KuduPageSink( ConnectorSession connectorSession, KuduClientSession clientSession, @@ -187,6 +201,58 @@ else if (type instanceof DecimalType) { } } + @Override + public void storeMergedRows(Page page) + { + // The last channel in the page is the rowId block, the next-to-last is the operation block + int columnCount = columnTypes.size(); + checkArgument(page.getChannelCount() == 2 + columnCount, "The page size should be 2 + columnCount (%s), but is %s", columnCount, page.getChannelCount()); + Block operationBlock = page.getBlock(columnCount); + Block rowIds = page.getBlock(columnCount + 1); + + Schema schema = table.getSchema(); 
+ try (KuduOperationApplier operationApplier = KuduOperationApplier.fromKuduClientSession(session)) { + for (int position = 0; position < page.getPositionCount(); position++) { + long operation = TINYINT.getLong(operationBlock, position); + + if (operation == DELETE_OPERATION_NUMBER || operation == UPDATE_OPERATION_NUMBER) { + Delete delete = table.newDelete(); + Slice deleteRowId = VARBINARY.getSlice(rowIds, position); + RowHelper.copyPrimaryKey(schema, KeyEncoderAccessor.decodePrimaryKey(schema, deleteRowId.getBytes()), delete.getRow()); + try { + operationApplier.applyOperationAsync(delete); + } + catch (KuduException e) { + throw new RuntimeException(e); + } + } + + if (operation == INSERT_OPERATION_NUMBER || operation == UPDATE_OPERATION_NUMBER) { + Insert insert = table.newInsert(); + PartialRow insertRow = insert.getRow(); + int insertStart = 0; + if (generateUUID) { + String id = format("%s-%08x", uuid, nextSubId++); + insertRow.addString(0, id); + insertStart = 1; + } + for (int channel = 0; channel < columnCount; channel++) { + appendColumn(insertRow, page, position, channel, channel + insertStart); + } + try { + operationApplier.applyOperationAsync(insert); + } + catch (KuduException e) { + throw new RuntimeException(e); + } + } + } + } + catch (KuduException e) { + throw new RuntimeException(e); + } + } + @Override public CompletableFuture> finish() { diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSinkProvider.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSinkProvider.java index eff964c67a78..86d4c1796219 100644 --- a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSinkProvider.java +++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/KuduPageSinkProvider.java @@ -14,6 +14,8 @@ package io.trino.plugin.kudu; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeSink; +import io.trino.spi.connector.ConnectorMergeTableHandle; import 
io.trino.spi.connector.ConnectorOutputTableHandle; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorPageSinkProvider; @@ -55,4 +57,10 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa return new KuduPageSink(session, clientSession, handle); } + + @Override + public ConnectorMergeSink createMergeSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorMergeTableHandle mergeHandle) + { + return new KuduPageSink(session, clientSession, (KuduMergeTableHandle) mergeHandle); + } } diff --git a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/KuduTableProperties.java b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/KuduTableProperties.java index de7d2864401b..1603a233bc3d 100644 --- a/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/KuduTableProperties.java +++ b/plugin/trino-kudu/src/main/java/io/trino/plugin/kudu/properties/KuduTableProperties.java @@ -171,7 +171,7 @@ public static PartitionDesign getPartitionDesign(Map tableProper @SuppressWarnings("unchecked") List hashColumns = (List) tableProperties.get(PARTITION_BY_HASH_COLUMNS); @SuppressWarnings("unchecked") - List hashColumns2 = (List) tableProperties.get(PARTITION_BY_HASH_COLUMNS_2); + List hashColumns2 = (List) tableProperties.getOrDefault(PARTITION_BY_HASH_COLUMNS_2, ImmutableList.of()); PartitionDesign design = new PartitionDesign(); if (!hashColumns.isEmpty()) { diff --git a/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/BaseKuduConnectorSmokeTest.java b/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/BaseKuduConnectorSmokeTest.java index ae8c3ac2fbd5..6642dcef6cc9 100644 --- a/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/BaseKuduConnectorSmokeTest.java +++ b/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/BaseKuduConnectorSmokeTest.java @@ -44,6 +44,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) { switch 
(connectorBehavior) { case SUPPORTS_DELETE: + case SUPPORTS_MERGE: return true; case SUPPORTS_RENAME_SCHEMA: case SUPPORTS_CREATE_TABLE_WITH_TABLE_COMMENT: diff --git a/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java b/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java index 6f50dc7c4394..a2a85f5ba118 100644 --- a/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java +++ b/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java @@ -24,6 +24,7 @@ import org.testng.annotations.Test; import java.util.Optional; +import java.util.regex.Matcher; import java.util.regex.Pattern; import static io.trino.plugin.kudu.KuduQueryRunnerFactory.createKuduQueryRunnerTpch; @@ -66,6 +67,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) { switch (connectorBehavior) { case SUPPORTS_DELETE: + case SUPPORTS_MERGE: return true; case SUPPORTS_RENAME_SCHEMA: case SUPPORTS_CREATE_TABLE_WITH_TABLE_COMMENT: @@ -85,6 +87,18 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) } } + @Override + protected String createTableForWrites(String createTable) + { + // assume primary key column is the first column and there are multiple columns + Matcher matcher = Pattern.compile("CREATE TABLE .* \\((\\w+) .*").matcher(createTable); + assertThat(matcher.matches()).as(createTable).isTrue(); + String column = matcher.group(1); + + return createTable.replaceFirst(",", " WITH (primary_key=true),") + + format("WITH (partition_by_hash_columns = ARRAY['%s'], partition_by_hash_buckets = 2)", column); + } + @Test @Override public void testCreateSchema() diff --git a/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestingKuduServer.java b/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestingKuduServer.java index 2aaa204bc2ea..a3166aee0922 100644 --- a/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestingKuduServer.java +++ 
b/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestingKuduServer.java @@ -42,7 +42,7 @@ public class TestingKuduServer private static final Integer KUDU_TSERVER_PORT = 7050; private static final Integer NUMBER_OF_REPLICA = 3; - private static final String TOXIPROXY_IMAGE = "shopify/toxiproxy:2.1.0"; + private static final String TOXIPROXY_IMAGE = "shopify/toxiproxy:2.1.4"; private static final String TOXIPROXY_NETWORK_ALIAS = "toxiproxy"; private final Network network; diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotNodePartitioningProvider.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotNodePartitioningProvider.java index 31905dd60e5e..26f02f4bf23b 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotNodePartitioningProvider.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotNodePartitioningProvider.java @@ -23,18 +23,21 @@ import io.trino.spi.type.Type; import java.util.List; +import java.util.Optional; import java.util.function.ToIntFunction; +import static io.trino.spi.connector.ConnectorBucketNodeMap.createBucketNodeMap; + public class PinotNodePartitioningProvider implements ConnectorNodePartitioningProvider { @Override - public ConnectorBucketNodeMap getBucketNodeMap( + public Optional getBucketNodeMapping( ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) { - return ConnectorBucketNodeMap.createBucketNodeMap(1); + return Optional.of(createBucketNodeMap(1)); } @Override diff --git a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorBucketedUpdateFunction.java b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorBucketedUpdateFunction.java new file mode 100644 index 000000000000..f5d56cb72599 --- /dev/null +++ b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorBucketedUpdateFunction.java @@ -0,0 +1,32 @@ +/* + * Licensed 
under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.raptor.legacy; + +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.connector.BucketFunction; + +import static io.trino.spi.type.IntegerType.INTEGER; +import static java.lang.Math.toIntExact; + +public class RaptorBucketedUpdateFunction + implements BucketFunction +{ + @Override + public int getBucket(Page page, int position) + { + Block row = page.getBlock(0).getObject(position, Block.class); + return toIntExact(INTEGER.getLong(row, 0)); // bucket field of row ID + } +} diff --git a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/util/Types.java b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorBucketedUpdateHandle.java similarity index 60% rename from plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/util/Types.java rename to plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorBucketedUpdateHandle.java index 35771a149e1b..2c0d7ff6a465 100644 --- a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/util/Types.java +++ b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorBucketedUpdateHandle.java @@ -11,23 +11,18 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package io.trino.plugin.raptor.legacy.util; +package io.trino.plugin.raptor.legacy; -import io.trino.spi.type.ArrayType; -import io.trino.spi.type.MapType; -import io.trino.spi.type.Type; +import com.fasterxml.jackson.annotation.JsonCreator; -public final class Types -{ - private Types() {} +import java.util.List; - public static boolean isArrayType(Type type) - { - return type instanceof ArrayType; - } - - public static boolean isMapType(Type type) +public class RaptorBucketedUpdateHandle + extends RaptorPartitioningHandle +{ + @JsonCreator + public RaptorBucketedUpdateHandle(long distributionId, List bucketToNode) { - return type instanceof MapType; + super(distributionId, bucketToNode); } } diff --git a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorColumnHandle.java b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorColumnHandle.java index 47cbaf4c58d6..44442e880edf 100644 --- a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorColumnHandle.java +++ b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorColumnHandle.java @@ -16,12 +16,16 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.type.BigintType; import io.trino.spi.type.Type; import java.util.Objects; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RowType.field; +import static io.trino.spi.type.RowType.rowType; +import static io.trino.spi.type.UuidType.UUID; import static io.trino.spi.type.VarcharType.createVarcharType; import static java.util.Objects.requireNonNull; @@ -31,6 +35,7 @@ public final class RaptorColumnHandle // Generated rowId column for updates private static final long SHARD_ROW_ID_COLUMN_ID = -1; private static final String SHARD_ROW_ID_COLUMN_NAME = 
"$shard_row_id"; + private static final BigintType SHARD_ROW_ID_COLUMN_TYPE = BIGINT; public static final long SHARD_UUID_COLUMN_ID = -2; public static final String SHARD_UUID_COLUMN_NAME = "$shard_uuid"; @@ -39,6 +44,13 @@ public final class RaptorColumnHandle public static final long BUCKET_NUMBER_COLUMN_ID = -3; public static final String BUCKET_NUMBER_COLUMN_NAME = "$bucket_number"; + private static final long MERGE_ROW_ID_COLUMN_ID = -4; + private static final String MERGE_ROW_ID_COLUMN_NAME = "$merge_row_id"; + private static final Type MERGE_ROW_ID_COLUMN_TYPE = rowType( + field("bucket", INTEGER), + field("uuid", UUID), + field("row_id", BIGINT)); + private final String columnName; private final long columnId; private final Type columnType; @@ -119,7 +131,7 @@ public static boolean isShardRowIdColumn(long columnId) public static RaptorColumnHandle shardRowIdHandle() { - return new RaptorColumnHandle(SHARD_ROW_ID_COLUMN_NAME, SHARD_ROW_ID_COLUMN_ID, BIGINT); + return new RaptorColumnHandle(SHARD_ROW_ID_COLUMN_NAME, SHARD_ROW_ID_COLUMN_ID, SHARD_ROW_ID_COLUMN_TYPE); } public static boolean isShardUuidColumn(long columnId) @@ -142,6 +154,16 @@ public static RaptorColumnHandle bucketNumberColumnHandle() return new RaptorColumnHandle(BUCKET_NUMBER_COLUMN_NAME, BUCKET_NUMBER_COLUMN_ID, INTEGER); } + public static RaptorColumnHandle mergeRowIdHandle() + { + return new RaptorColumnHandle(MERGE_ROW_ID_COLUMN_NAME, MERGE_ROW_ID_COLUMN_ID, MERGE_ROW_ID_COLUMN_TYPE); + } + + public static boolean isMergeRowIdColumn(long columnId) + { + return columnId == MERGE_ROW_ID_COLUMN_ID; + } + public static boolean isHiddenColumn(long columnId) { return columnId < 0; diff --git a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorMergeSink.java b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorMergeSink.java new file mode 100644 index 000000000000..8fc0369434ae --- /dev/null +++ 
b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorMergeSink.java @@ -0,0 +1,134 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.raptor.legacy; + +import com.google.common.collect.ImmutableList; +import io.airlift.json.JsonCodec; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.plugin.raptor.legacy.metadata.ShardDelta; +import io.trino.plugin.raptor.legacy.metadata.ShardInfo; +import io.trino.plugin.raptor.legacy.storage.ShardRewriter; +import io.trino.plugin.raptor.legacy.storage.StorageManager; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.ColumnarRow; +import io.trino.spi.connector.ConnectorMergeSink; +import io.trino.spi.connector.ConnectorPageSink; +import io.trino.spi.connector.MergePage; +import io.trino.spi.type.UuidType; + +import java.util.ArrayList; +import java.util.BitSet; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.OptionalInt; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; + +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.json.JsonCodec.jsonCodec; +import static io.trino.spi.block.ColumnarRow.toColumnarRow; +import static 
io.trino.spi.connector.MergePage.createDeleteAndInsertPages; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.UuidType.trinoUuidToJavaUuid; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.CompletableFuture.allOf; +import static java.util.stream.Collectors.toUnmodifiableList; + +public class RaptorMergeSink + implements ConnectorMergeSink +{ + private static final JsonCodec SHARD_INFO_CODEC = jsonCodec(ShardInfo.class); + private static final JsonCodec SHARD_DELTA_CODEC = jsonCodec(ShardDelta.class); + + private final ConnectorPageSink pageSink; + private final StorageManager storageManager; + private final long transactionId; + private final int columnCount; + private final Map> rowsToDelete = new HashMap<>(); + + public RaptorMergeSink(ConnectorPageSink pageSink, StorageManager storageManager, long transactionId, int columnCount) + { + this.pageSink = requireNonNull(pageSink, "pageSink is null"); + this.storageManager = requireNonNull(storageManager, "storageManager is null"); + this.transactionId = transactionId; + this.columnCount = columnCount; + } + + @Override + public void storeMergedRows(Page page) + { + MergePage mergePage = createDeleteAndInsertPages(page, columnCount); + + mergePage.getInsertionsPage().ifPresent(pageSink::appendPage); + + mergePage.getDeletionsPage().ifPresent(deletions -> { + ColumnarRow rowIdRow = toColumnarRow(deletions.getBlock(deletions.getChannelCount() - 1)); + Block shardBucketBlock = rowIdRow.getField(0); + Block shardUuidBlock = rowIdRow.getField(1); + Block shardRowIdBlock = rowIdRow.getField(2); + + for (int position = 0; position < rowIdRow.getPositionCount(); position++) { + OptionalInt bucketNumber = shardBucketBlock.isNull(position) + ? 
OptionalInt.empty() + : OptionalInt.of(toIntExact(INTEGER.getLong(shardBucketBlock, position))); + UUID uuid = trinoUuidToJavaUuid(UuidType.UUID.getSlice(shardUuidBlock, position)); + int rowId = toIntExact(BIGINT.getLong(shardRowIdBlock, position)); + Entry entry = rowsToDelete.computeIfAbsent(uuid, ignored -> Map.entry(bucketNumber, new BitSet())); + verify(entry.getKey().equals(bucketNumber), "multiple bucket numbers for same shard"); + entry.getValue().set(rowId); + } + }); + } + + @Override + public CompletableFuture> finish() + { + List>> futures = new ArrayList<>(); + + rowsToDelete.forEach((uuid, entry) -> { + OptionalInt bucketNumber = entry.getKey(); + BitSet rowIds = entry.getValue(); + ShardRewriter rewriter = storageManager.createShardRewriter(transactionId, bucketNumber, uuid); + futures.add(rewriter.rewrite(rowIds)); + }); + + futures.add(pageSink.finish().thenApply(slices -> { + List newShards = slices.stream() + .map(slice -> SHARD_INFO_CODEC.fromJson(slice.getBytes())) + .collect(toImmutableList()); + ShardDelta delta = new ShardDelta(ImmutableList.of(), newShards); + return ImmutableList.of(Slices.wrappedBuffer(SHARD_DELTA_CODEC.toJsonBytes(delta))); + })); + + return allOf(futures.toArray(CompletableFuture[]::new)) + .thenApply(ignored -> futures.stream() + .map(CompletableFuture::join) + .flatMap(Collection::stream) + .collect(toUnmodifiableList())); + } + + @Override + public void abort() + { + pageSink.abort(); + } +} diff --git a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorMergeTableHandle.java b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorMergeTableHandle.java new file mode 100644 index 000000000000..4240bd5c0e67 --- /dev/null +++ b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorMergeTableHandle.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.raptor.legacy; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.connector.ConnectorMergeTableHandle; + +import static java.util.Objects.requireNonNull; + +public class RaptorMergeTableHandle + implements ConnectorMergeTableHandle +{ + private final RaptorTableHandle tableHandle; + private final RaptorInsertTableHandle insertTableHandle; + + @JsonCreator + public RaptorMergeTableHandle( + @JsonProperty RaptorTableHandle tableHandle, + @JsonProperty RaptorInsertTableHandle insertTableHandle) + { + this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); + this.insertTableHandle = requireNonNull(insertTableHandle, "insertTableHandle is null"); + } + + @Override + @JsonProperty + public RaptorTableHandle getTableHandle() + { + return tableHandle; + } + + @JsonProperty + public RaptorInsertTableHandle getInsertTableHandle() + { + return insertTableHandle; + } +} diff --git a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorMetadata.java b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorMetadata.java index c5cda14f7bb9..5ca8151e63ce 100644 --- a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorMetadata.java +++ b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorMetadata.java @@ -38,6 +38,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; import 
io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorMetadata; import io.trino.spi.connector.ConnectorOutputMetadata; import io.trino.spi.connector.ConnectorOutputTableHandle; @@ -52,6 +53,7 @@ import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.RetryMode; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.connector.SystemTable; @@ -64,6 +66,7 @@ import java.util.ArrayList; import java.util.Collection; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -78,6 +81,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.MoreCollectors.toOptional; import static io.airlift.json.JsonCodec.jsonCodec; @@ -87,6 +91,7 @@ import static io.trino.plugin.raptor.legacy.RaptorColumnHandle.SHARD_UUID_COLUMN_TYPE; import static io.trino.plugin.raptor.legacy.RaptorColumnHandle.bucketNumberColumnHandle; import static io.trino.plugin.raptor.legacy.RaptorColumnHandle.isHiddenColumn; +import static io.trino.plugin.raptor.legacy.RaptorColumnHandle.mergeRowIdHandle; import static io.trino.plugin.raptor.legacy.RaptorColumnHandle.shardRowIdHandle; import static io.trino.plugin.raptor.legacy.RaptorColumnHandle.shardUuidColumnHandle; import static io.trino.plugin.raptor.legacy.RaptorErrorCode.RAPTOR_ERROR; @@ -114,6 +119,7 @@ import static io.trino.spi.StandardErrorCode.NOT_FOUND; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.connector.RetryMode.NO_RETRIES; +import static 
io.trino.spi.connector.RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW; import static io.trino.spi.connector.SortOrder.ASC_NULLS_FIRST; import static io.trino.spi.type.DateType.DATE; import static io.trino.spi.type.IntegerType.INTEGER; @@ -373,13 +379,11 @@ public Optional getNewTableLayout(ConnectorSession session .map(RaptorColumnHandle::getColumnName) .collect(toList()); - ConnectorPartitioningHandle partitioning = getPartitioningHandle(distribution.get().getDistributionId()); - return Optional.of(new ConnectorTableLayout(partitioning, partitionColumns)); - } + long distributionId = distribution.get().getDistributionId(); + List bucketAssignments = shardManager.getBucketAssignments(distributionId); + ConnectorPartitioningHandle partitioning = new RaptorPartitioningHandle(distributionId, bucketAssignments); - private RaptorPartitioningHandle getPartitioningHandle(long distributionId) - { - return new RaptorPartitioningHandle(distributionId, shardManager.getBucketAssignments(distributionId)); + return Optional.of(new ConnectorTableLayout(partitioning, partitionColumns)); } private Optional getOrCreateDistribution(Map columnHandleMap, Map properties) @@ -717,7 +721,7 @@ public Optional finishCreateTable(ConnectorSession sess } @Override - public ConnectorInsertTableHandle beginInsert(ConnectorSession session, ConnectorTableHandle tableHandle, List columns, RetryMode retryMode) + public RaptorInsertTableHandle beginInsert(ConnectorSession session, ConnectorTableHandle tableHandle, List columns, RetryMode retryMode) { if (retryMode != NO_RETRIES) { throw new TrinoException(NOT_SUPPORTED, "This connector does not support query retries"); @@ -832,33 +836,76 @@ public ConnectorTableHandle beginDelete(ConnectorSession session, ConnectorTable public void finishDelete(ConnectorSession session, ConnectorTableHandle tableHandle, Collection fragments) { RaptorTableHandle table = (RaptorTableHandle) tableHandle; - long transactionId = table.getTransactionId().getAsLong(); - long 
tableId = table.getTableId(); + finishDelete(session, table, table.getTransactionId().orElseThrow(), fragments); + } + + private void finishDelete(ConnectorSession session, RaptorTableHandle tableHandle, long transactionId, Collection fragments) + { + long tableId = tableHandle.getTableId(); List columns = getColumnHandles(session, tableHandle).values().stream() .map(RaptorColumnHandle.class::cast) .map(ColumnInfo::fromHandle).collect(toList()); - ImmutableSet.Builder oldShardUuidsBuilder = ImmutableSet.builder(); - ImmutableList.Builder newShardsBuilder = ImmutableList.builder(); + Set oldShardUuids = new HashSet<>(); + List newShards = new ArrayList<>(); - fragments.stream() - .map(fragment -> SHARD_DELTA_CODEC.fromJson(fragment.getBytes())) - .forEach(delta -> { - oldShardUuidsBuilder.addAll(delta.getOldShardUuids()); - newShardsBuilder.addAll(delta.getNewShards()); - }); + for (Slice fragment : fragments) { + ShardDelta delta = SHARD_DELTA_CODEC.fromJson(fragment.getBytes()); + for (UUID uuid : delta.getOldShardUuids()) { + verify(oldShardUuids.add(uuid), "duplicate old shard: %s", uuid); + } + newShards.addAll(delta.getNewShards()); + } - Set oldShardUuids = oldShardUuidsBuilder.build(); - List newShards = newShardsBuilder.build(); OptionalLong updateTime = OptionalLong.of(session.getStart().toEpochMilli()); - log.info("Finishing delete for tableId %s (removed: %s, rewritten: %s)", tableId, oldShardUuids.size() - newShards.size(), newShards.size()); + log.info("Finishing update for tableId %s (removed: %s, new: %s)", tableId, oldShardUuids.size(), newShards.size()); shardManager.replaceShardUuids(transactionId, tableId, columns, oldShardUuids, newShards, updateTime); clearRollback(); } + @Override + public RowChangeParadigm getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return DELETE_ROW_AND_INSERT_ROW; + } + + @Override + public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle 
tableHandle) + { + return mergeRowIdHandle(); + } + + @Override + public Optional getUpdateLayout(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return ((RaptorTableHandle) tableHandle).getDistributionId().map(distributionId -> + new RaptorBucketedUpdateHandle(distributionId, shardManager.getBucketAssignments(distributionId))) + .or(() -> Optional.of(RaptorUnbucketedUpdateHandle.INSTANCE)); + } + + @Override + public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, RetryMode retryMode) + { + RaptorTableHandle handle = (RaptorTableHandle) tableHandle; + + beginDeleteForTableId.accept(handle.getTableId()); + + RaptorInsertTableHandle insertHandle = beginInsert(session, handle, ImmutableList.of(), retryMode); + + return new RaptorMergeTableHandle(handle, insertHandle); + } + + @Override + public void finishMerge(ConnectorSession session, ConnectorMergeTableHandle tableHandle, Collection fragments, Collection computedStatistics) + { + RaptorMergeTableHandle handle = (RaptorMergeTableHandle) tableHandle; + long transactionId = handle.getInsertTableHandle().getTransactionId(); + finishDelete(session, handle.getTableHandle(), transactionId, fragments); + } + @Override public void createView(ConnectorSession session, SchemaTableName viewName, ConnectorViewDefinition definition, boolean replace) { diff --git a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorNodePartitioningProvider.java b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorNodePartitioningProvider.java index e2f73d88cb7b..58d129a52eaa 100644 --- a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorNodePartitioningProvider.java +++ b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorNodePartitioningProvider.java @@ -29,6 +29,7 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import 
java.util.function.ToIntFunction; import static com.google.common.collect.Maps.uniqueIndex; @@ -48,8 +49,12 @@ public RaptorNodePartitioningProvider(NodeSupplier nodeSupplier) } @Override - public ConnectorBucketNodeMap getBucketNodeMap(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorPartitioningHandle partitioning) + public Optional getBucketNodeMapping(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioning) { + if (partitioning instanceof RaptorUnbucketedUpdateHandle) { + return Optional.empty(); + } + RaptorPartitioningHandle handle = (RaptorPartitioningHandle) partitioning; Map nodesById = uniqueIndex(nodeSupplier.getWorkerNodes(), Node::getNodeIdentifier); @@ -62,7 +67,7 @@ public ConnectorBucketNodeMap getBucketNodeMap(ConnectorTransactionHandle transa } bucketToNode.add(node); } - return createBucketNodeMap(bucketToNode.build()); + return Optional.of(createBucketNodeMap(bucketToNode.build())); } @Override @@ -74,6 +79,12 @@ public ToIntFunction getSplitBucketFunction(ConnectorTransaction @Override public BucketFunction getBucketFunction(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorPartitioningHandle partitioning, List partitionChannelTypes, int bucketCount) { + if (partitioning instanceof RaptorUnbucketedUpdateHandle) { + return new RaptorUnbucketedUpdateFunction(bucketCount); + } + if (partitioning instanceof RaptorBucketedUpdateHandle) { + return new RaptorBucketedUpdateFunction(); + } return new RaptorBucketFunction(bucketCount, partitionChannelTypes); } } diff --git a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorPageSinkProvider.java b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorPageSinkProvider.java index 2202ce5e76e2..8b2ab276855d 100644 --- a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorPageSinkProvider.java +++ 
b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorPageSinkProvider.java @@ -18,6 +18,8 @@ import io.trino.plugin.raptor.legacy.storage.StorageManagerConfig; import io.trino.spi.PageSorter; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeSink; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorOutputTableHandle; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorPageSinkProvider; @@ -82,6 +84,16 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa maxBufferSize); } + @Override + public ConnectorMergeSink createMergeSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorMergeTableHandle mergeHandle) + { + RaptorMergeTableHandle merge = (RaptorMergeTableHandle) mergeHandle; + ConnectorPageSink pageSink = createPageSink(transactionHandle, session, merge.getInsertTableHandle()); + long transactionId = merge.getInsertTableHandle().getTransactionId(); + int columnCount = merge.getInsertTableHandle().getColumnHandles().size(); + return new RaptorMergeSink(pageSink, storageManager, transactionId, columnCount); + } + private static List toColumnIds(List columnHandles) { return columnHandles.stream().map(RaptorColumnHandle::getColumnId).collect(toList()); diff --git a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorUnbucketedUpdateFunction.java b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorUnbucketedUpdateFunction.java new file mode 100644 index 000000000000..f967f325b1ec --- /dev/null +++ b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorUnbucketedUpdateFunction.java @@ -0,0 +1,39 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.raptor.legacy; + +import io.airlift.slice.Slice; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.connector.BucketFunction; +import io.trino.spi.type.UuidType; + +public class RaptorUnbucketedUpdateFunction + implements BucketFunction +{ + private final int bucketCount; + + public RaptorUnbucketedUpdateFunction(int bucketCount) + { + this.bucketCount = bucketCount; + } + + @Override + public int getBucket(Page page, int position) + { + Block row = page.getBlock(0).getObject(position, Block.class); + Slice uuid = UuidType.UUID.getSlice(row, 1); // uuid field of row ID + return (uuid.hashCode() & Integer.MAX_VALUE) % bucketCount; + } +} diff --git a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorUnbucketedUpdateHandle.java b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorUnbucketedUpdateHandle.java new file mode 100644 index 000000000000..83e361b67fc1 --- /dev/null +++ b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/RaptorUnbucketedUpdateHandle.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.raptor.legacy; + +import io.trino.spi.connector.ConnectorPartitioningHandle; + +public enum RaptorUnbucketedUpdateHandle + implements ConnectorPartitioningHandle +{ + INSTANCE +} diff --git a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/OrcFileWriter.java b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/OrcFileWriter.java index 517d712e8bf9..45fda7448415 100644 --- a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/OrcFileWriter.java +++ b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/OrcFileWriter.java @@ -18,13 +18,23 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.json.JsonCodec; -import io.airlift.slice.Slice; import io.trino.hive.orc.NullMemoryManager; import io.trino.plugin.raptor.legacy.util.SyncingFileSystem; import io.trino.spi.Page; import io.trino.spi.TrinoException; import io.trino.spi.classloader.ThreadContextClassLoader; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.DateType; import io.trino.spi.type.DecimalType; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.IntegerType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RealType; +import io.trino.spi.type.SmallintType; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.TinyintType; import io.trino.spi.type.Type; import 
io.trino.spi.type.TypeId; import io.trino.spi.type.VarbinaryType; @@ -66,8 +76,6 @@ import static io.trino.plugin.raptor.legacy.storage.Row.extractRow; import static io.trino.plugin.raptor.legacy.storage.StorageType.arrayOf; import static io.trino.plugin.raptor.legacy.storage.StorageType.mapOf; -import static io.trino.plugin.raptor.legacy.util.Types.isArrayType; -import static io.trino.plugin.raptor.legacy.util.Types.isMapType; import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static java.util.Objects.requireNonNull; @@ -306,28 +314,35 @@ private static StorageType toStorageType(Type type) DecimalType decimalType = (DecimalType) type; return StorageType.decimal(decimalType.getPrecision(), decimalType.getScale()); } - Class javaType = type.getJavaType(); - if (javaType == boolean.class) { + if (type == BooleanType.BOOLEAN) { return StorageType.BOOLEAN; } - if (javaType == long.class) { + if ((type == SmallintType.SMALLINT) || + (type == IntegerType.INTEGER) || + (type == BigintType.BIGINT) || + (type == DateType.DATE) || + type.equals(TimestampType.TIMESTAMP_MILLIS)) { return StorageType.LONG; } - if (javaType == double.class) { + if (type == TinyintType.TINYINT) { + return StorageType.BYTE; + } + if (type == RealType.REAL) { + return StorageType.FLOAT; + } + if (type == DoubleType.DOUBLE) { return StorageType.DOUBLE; } - if (javaType == Slice.class) { - if (type instanceof VarcharType) { - return StorageType.STRING; - } - if (type.equals(VarbinaryType.VARBINARY)) { - return StorageType.BYTES; - } + if (type instanceof VarcharType) { + return StorageType.STRING; + } + if (type == VarbinaryType.VARBINARY) { + return StorageType.BYTES; } - if (isArrayType(type)) { + if (type instanceof ArrayType) { return arrayOf(toStorageType(type.getTypeParameters().get(0))); } - if (isMapType(type)) { + if (type instanceof MapType) { return mapOf(toStorageType(type.getTypeParameters().get(0)), 
toStorageType(type.getTypeParameters().get(1))); } throw new TrinoException(NOT_SUPPORTED, "Unsupported type: " + type); diff --git a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/RaptorPageSource.java b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/RaptorPageSource.java index 6decc7e2eff1..1c0131167f6a 100644 --- a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/RaptorPageSource.java +++ b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/RaptorPageSource.java @@ -22,9 +22,11 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.RowBlock; import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.connector.UpdatablePageSource; import io.trino.spi.type.Type; +import io.trino.spi.type.UuidType; import java.io.IOException; import java.util.BitSet; @@ -44,6 +46,7 @@ import static io.trino.plugin.raptor.legacy.RaptorErrorCode.RAPTOR_ERROR; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.UuidType.javaUuidToTrinoUuid; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; @@ -200,7 +203,12 @@ static ColumnAdaptation bucketNumberColumn(OptionalInt bucketNumber) static ColumnAdaptation rowIdColumn() { - return new RowIdColumn(); + return RowIdColumn.INSTANCE; + } + + static ColumnAdaptation mergeRowIdColumn(OptionalInt bucketNumber, UUID shardUuid) + { + return new MergeRowIdColumn(bucketNumber, shardUuid); } static ColumnAdaptation sourceColumn(int index) @@ -239,6 +247,8 @@ public String toString() private static class RowIdColumn implements ColumnAdaptation { + public static final RowIdColumn INSTANCE = new RowIdColumn(); + @Override public Block block(Page sourcePage, long filePosition) { @@ -258,6 +268,36 @@ public String 
toString() } } + private static class MergeRowIdColumn + implements ColumnAdaptation + { + private final Block bucketNumberValue; + private final Block shardUuidValue; + + public MergeRowIdColumn(OptionalInt bucketNumber, UUID shardUuid) + { + BlockBuilder blockBuilder = INTEGER.createFixedSizeBlockBuilder(1); + bucketNumber.ifPresentOrElse(value -> INTEGER.writeLong(blockBuilder, value), blockBuilder::appendNull); + bucketNumberValue = blockBuilder.build(); + + BlockBuilder builder = UuidType.UUID.createFixedSizeBlockBuilder(1); + UuidType.UUID.writeSlice(builder, javaUuidToTrinoUuid(shardUuid)); + shardUuidValue = builder.build(); + } + + @Override + public Block block(Page sourcePage, long filePosition) + { + Block bucketNumberBlock = new RunLengthEncodedBlock(bucketNumberValue, sourcePage.getPositionCount()); + Block shardUuidBlock = new RunLengthEncodedBlock(shardUuidValue, sourcePage.getPositionCount()); + Block rowIdBlock = RowIdColumn.INSTANCE.block(sourcePage, filePosition); + return RowBlock.fromFieldBlocks( + sourcePage.getPositionCount(), + Optional.empty(), + new Block[] {bucketNumberBlock, shardUuidBlock, rowIdBlock}); + } + } + private static class NullColumn implements ColumnAdaptation { diff --git a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/RaptorStorageManager.java b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/RaptorStorageManager.java index 31c09acdd44c..efe625b89e48 100644 --- a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/RaptorStorageManager.java +++ b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/RaptorStorageManager.java @@ -106,6 +106,7 @@ import static io.trino.orc.metadata.OrcColumnId.ROOT_COLUMN; import static io.trino.plugin.raptor.legacy.RaptorColumnHandle.isBucketNumberColumn; import static io.trino.plugin.raptor.legacy.RaptorColumnHandle.isHiddenColumn; +import static 
io.trino.plugin.raptor.legacy.RaptorColumnHandle.isMergeRowIdColumn; import static io.trino.plugin.raptor.legacy.RaptorColumnHandle.isShardRowIdColumn; import static io.trino.plugin.raptor.legacy.RaptorColumnHandle.isShardUuidColumn; import static io.trino.plugin.raptor.legacy.RaptorErrorCode.RAPTOR_ERROR; @@ -310,6 +311,9 @@ private static ColumnAdaptation specialColumnAdaptation(long columnId, UUID shar if (isBucketNumberColumn(columnId)) { return ColumnAdaptation.bucketNumberColumn(bucketNumber); } + if (isMergeRowIdColumn(columnId)) { + return ColumnAdaptation.mergeRowIdColumn(bucketNumber, shardUuid); + } throw new TrinoException(RAPTOR_ERROR, "Invalid column ID: " + columnId); } @@ -322,7 +326,8 @@ public StoragePageSink createStoragePageSink(long transactionId, OptionalInt buc return new RaptorStoragePageSink(transactionId, columnIds, columnTypes, bucketNumber); } - private ShardRewriter createShardRewriter(long transactionId, OptionalInt bucketNumber, UUID shardUuid) + @Override + public ShardRewriter createShardRewriter(long transactionId, OptionalInt bucketNumber, UUID shardUuid) { return rowsToDelete -> { if (rowsToDelete.isEmpty()) { diff --git a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/Row.java b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/Row.java index 1a7182f76c93..a264fed2dfc1 100644 --- a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/Row.java +++ b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/Row.java @@ -17,9 +17,21 @@ import io.trino.spi.Page; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.DateType; import io.trino.spi.type.DecimalType; +import io.trino.spi.type.DoubleType; import io.trino.spi.type.Int128; +import io.trino.spi.type.IntegerType; 
+import io.trino.spi.type.MapType; +import io.trino.spi.type.RealType; +import io.trino.spi.type.SmallintType; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.TinyintType; import io.trino.spi.type.Type; +import io.trino.spi.type.VarbinaryType; import io.trino.spi.type.VarcharType; import org.apache.hadoop.hive.common.type.HiveDecimal; @@ -34,8 +46,6 @@ import static io.airlift.slice.SizeOf.SIZE_OF_BYTE; import static io.airlift.slice.SizeOf.SIZE_OF_DOUBLE; import static io.airlift.slice.SizeOf.SIZE_OF_LONG; -import static io.trino.plugin.raptor.legacy.util.Types.isArrayType; -import static io.trino.plugin.raptor.legacy.util.Types.isMapType; import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static java.util.Objects.requireNonNull; @@ -147,20 +157,28 @@ private static Object nativeContainerToOrcValue(Type type, Object nativeValue) } return HiveDecimal.create(unscaledValue, decimalType.getScale()); } - if (type.getJavaType() == boolean.class) { + if ((type == BooleanType.BOOLEAN) || + (type == SmallintType.SMALLINT) || + (type == IntegerType.INTEGER) || + (type == BigintType.BIGINT) || + (type == DoubleType.DOUBLE) || + (type == DateType.DATE) || + type.equals(TimestampType.TIMESTAMP_MILLIS)) { return nativeValue; } - if (type.getJavaType() == long.class) { - return nativeValue; + if (type == TinyintType.TINYINT) { + return ((Number) nativeValue).byteValue(); } - if (type.getJavaType() == double.class) { - return nativeValue; + if (type == RealType.REAL) { + return Float.intBitsToFloat(((Number) nativeValue).intValue()); + } + if (type instanceof VarcharType) { + return ((Slice) nativeValue).toStringUtf8(); } - if (type.getJavaType() == Slice.class) { - Slice slice = (Slice) nativeValue; - return type instanceof VarcharType ? 
slice.toStringUtf8() : slice.getBytes(); + if (type == VarbinaryType.VARBINARY) { + return ((Slice) nativeValue).getBytes(); } - if (isArrayType(type)) { + if (type instanceof ArrayType) { Block arrayBlock = (Block) nativeValue; Type elementType = type.getTypeParameters().get(0); List list = new ArrayList<>(); @@ -169,7 +187,7 @@ private static Object nativeContainerToOrcValue(Type type, Object nativeValue) } return list; } - if (isMapType(type)) { + if (type instanceof MapType) { Block mapBlock = (Block) nativeValue; Type keyType = type.getTypeParameters().get(0); Type valueType = type.getTypeParameters().get(1); diff --git a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/StorageManager.java b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/StorageManager.java index 97d43cb18024..d819c6cc6d48 100644 --- a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/StorageManager.java +++ b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/StorageManager.java @@ -52,4 +52,9 @@ StoragePageSink createStoragePageSink( List columnIds, List columnTypes, boolean checkSpace); + + ShardRewriter createShardRewriter( + long transactionId, + OptionalInt bucketNumber, + UUID shardUuid); } diff --git a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/StorageType.java b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/StorageType.java index 5e50ca44f3b4..bd733187b2c0 100644 --- a/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/StorageType.java +++ b/plugin/trino-raptor-legacy/src/main/java/io/trino/plugin/raptor/legacy/storage/StorageType.java @@ -20,14 +20,18 @@ import static org.apache.hadoop.hive.serde.serdeConstants.BOOLEAN_TYPE_NAME; import static org.apache.hadoop.hive.serde.serdeConstants.DECIMAL_TYPE_NAME; import static 
org.apache.hadoop.hive.serde.serdeConstants.DOUBLE_TYPE_NAME; +import static org.apache.hadoop.hive.serde.serdeConstants.FLOAT_TYPE_NAME; import static org.apache.hadoop.hive.serde.serdeConstants.LIST_TYPE_NAME; import static org.apache.hadoop.hive.serde.serdeConstants.MAP_TYPE_NAME; import static org.apache.hadoop.hive.serde.serdeConstants.STRING_TYPE_NAME; +import static org.apache.hadoop.hive.serde.serdeConstants.TINYINT_TYPE_NAME; public final class StorageType { public static final StorageType BOOLEAN = new StorageType(BOOLEAN_TYPE_NAME); + public static final StorageType BYTE = new StorageType(TINYINT_TYPE_NAME); public static final StorageType LONG = new StorageType(BIGINT_TYPE_NAME); + public static final StorageType FLOAT = new StorageType(FLOAT_TYPE_NAME); public static final StorageType DOUBLE = new StorageType(DOUBLE_TYPE_NAME); public static final StorageType STRING = new StorageType(STRING_TYPE_NAME); public static final StorageType BYTES = new StorageType(BINARY_TYPE_NAME); diff --git a/plugin/trino-raptor-legacy/src/test/java/io/trino/plugin/raptor/legacy/BaseRaptorConnectorTest.java b/plugin/trino-raptor-legacy/src/test/java/io/trino/plugin/raptor/legacy/BaseRaptorConnectorTest.java index f5a654476a03..c982b5d46fb4 100644 --- a/plugin/trino-raptor-legacy/src/test/java/io/trino/plugin/raptor/legacy/BaseRaptorConnectorTest.java +++ b/plugin/trino-raptor-legacy/src/test/java/io/trino/plugin/raptor/legacy/BaseRaptorConnectorTest.java @@ -26,6 +26,7 @@ import io.trino.testng.services.Flaky; import org.intellij.lang.annotations.Language; import org.testng.SkipException; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.time.LocalDate; @@ -51,14 +52,15 @@ import static io.trino.spi.type.DateType.DATE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.testing.assertions.Assert.assertEquals; import static 
io.trino.testing.sql.TestTable.randomTableSuffix; import static java.lang.String.format; import static java.util.Arrays.asList; import static java.util.function.Function.identity; +import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toSet; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotEquals; import static org.testng.Assert.assertNotNull; @@ -71,6 +73,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) { switch (connectorBehavior) { case SUPPORTS_DELETE: + case SUPPORTS_MERGE: case SUPPORTS_CREATE_VIEW: return true; case SUPPORTS_CREATE_SCHEMA: @@ -920,4 +923,211 @@ protected void verifyTableNameLengthFailurePermissible(Throwable e) { assertThat(e).hasMessage("Failed to perform metadata operation"); } + + @Test + public void testMergeMultipleOperationsUnbucketed() + { + String targetTable = "merge_multiple_" + randomTableSuffix(); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, zipcode INT, spouse VARCHAR, address VARCHAR)", targetTable)); + testMergeMultipleOperationsInternal(targetTable, 32); + } + + @Test + public void testMergeMultipleOperationsBucketed() + { + String targetTable = "merge_multiple_" + randomTableSuffix(); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, zipcode INT, spouse VARCHAR, address VARCHAR)" + + " WITH (bucket_count=4, bucketed_on=ARRAY['customer'])", targetTable)); + testMergeMultipleOperationsInternal(targetTable, 32); + } + + private void testMergeMultipleOperationsInternal(String targetTable, int targetCustomerCount) + { + String originalInsertFirstHalf = IntStream.range(1, targetCustomerCount / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 1000, 91000, intValue, 
intValue)) + .collect(joining(", ")); + String originalInsertSecondHalf = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 2000, 92000, intValue, intValue)) + .collect(joining(", ")); + + assertUpdate(format("INSERT INTO %s (customer, purchases, zipcode, spouse, address) VALUES %s, %s", targetTable, originalInsertFirstHalf, originalInsertSecondHalf), targetCustomerCount - 1); + + String firstMergeSource = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 83000, intValue, intValue)) + .collect(joining(", ")); + + @Language("SQL") String sql = format("MERGE INTO %s t USING (SELECT * FROM (VALUES %s)) AS s(customer, purchases, zipcode, spouse, address)", targetTable, firstMergeSource) + + " ON t.customer = s.customer" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases, zipcode = s.zipcode, spouse = s.spouse, address = s.address"; + assertUpdate(sql, targetCustomerCount / 2); + + assertQuery( + "SELECT customer, purchases, zipcode, spouse, address FROM " + targetTable, + format("SELECT * FROM (VALUES %s, %s) AS v(customer, purchases, zipcode, spouse, address)", originalInsertFirstHalf, firstMergeSource)); + + String nextInsert = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('jack_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 4000, 74000, intValue, intValue)) + .collect(joining(", ")); + assertUpdate(format("INSERT INTO %s (customer, purchases, zipcode, spouse, address) VALUES %s", targetTable, nextInsert), targetCustomerCount / 2); + + String secondMergeSource = IntStream.range(1, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue)) + .collect(joining(", ")); + + assertUpdate(format("MERGE INTO %s t USING (SELECT 
* FROM (VALUES %s)) AS s(customer, purchases, zipcode, spouse, address)", targetTable, secondMergeSource) + + " ON t.customer = s.customer" + + " WHEN MATCHED AND t.zipcode = 91000 THEN DELETE" + + " WHEN MATCHED AND s.zipcode = 85000 THEN UPDATE SET zipcode = 60000" + + " WHEN MATCHED THEN UPDATE SET zipcode = s.zipcode, spouse = s.spouse, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, zipcode, spouse, address) VALUES(s.customer, s.purchases, s.zipcode, s.spouse, s.address)", + targetCustomerCount * 3 / 2 - 1); + + String updatedBeginning = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 60000, intValue, intValue)) + .collect(joining(", ")); + String updatedMiddle = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue)) + .collect(joining(", ")); + String updatedEnd = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('jack_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 4000, 74000, intValue, intValue)) + .collect(joining(", ")); + + assertQuery( + "SELECT customer, purchases, zipcode, spouse, address FROM " + targetTable, + format("SELECT * FROM (VALUES %s, %s, %s) AS v(customer, purchases, zipcode, spouse, address)", updatedBeginning, updatedMiddle, updatedEnd)); + assertUpdate("DROP TABLE " + targetTable); + } + + @Test + public void testMergeSimpleQueryBucketed() + { + String targetTable = "merge_simple_target_" + randomTableSuffix(); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (bucket_count=7, bucketed_on=ARRAY['address'])", targetTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 
11, 'Devon')", targetTable), 4); + + @Language("SQL") String query = format("MERGE INTO %s t USING ", targetTable) + + "(SELECT * FROM (VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville'))) AS s(customer, purchases, address)" + + " " + + "ON (t.customer = s.customer)" + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + assertUpdate(query, 4); + + assertQuery("SELECT * FROM " + targetTable, "VALUES ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Dave', 22, 'Darbyshire'), ('Ed', 7, 'Etherville')"); + } + + @Test(dataProvider = "partitionedBucketedFailure") + public void testMergeMultipleRowsMatchFails(String createTableSql) + { + String targetTable = "merge_all_matches_deleted_target_" + randomTableSuffix(); + assertUpdate(format(createTableSql, targetTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Antioch')", targetTable), 2); + + String sourceTable = "merge_all_matches_deleted_source_" + randomTableSuffix(); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Adelphi'), ('Aaron', 8, 'Ashland')", sourceTable), 2); + + assertThatThrownBy(() -> computeActual(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED THEN UPDATE SET address = s.address")) + .hasMessage("One MERGE target table row matched more than one source row"); + + assertUpdate(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED AND s.address = 'Adelphi' THEN UPDATE SET address = s.address", + 
1); + assertQuery("SELECT customer, purchases, address FROM " + targetTable, "VALUES ('Aaron', 5, 'Adelphi'), ('Bill', 7, 'Antioch')"); + + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + + @DataProvider + public Object[][] partitionedBucketedFailure() + { + return new Object[][] { + {"CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)"}, + {"CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (bucket_count = 3, bucketed_on = ARRAY['customer'])"}, + {"CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (bucket_count = 4, bucketed_on = ARRAY['address'])"}, + {"CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (bucket_count = 4, bucketed_on = ARRAY['address', 'purchases', 'customer'])"}}; + } + + @Test(dataProvider = "targetAndSourceWithDifferentBucketing") + public void testMergeWithDifferentBucketing(String testDescription, String createTargetTableSql, String createSourceTableSql) + { + testMergeWithDifferentBucketingInternal(testDescription, createTargetTableSql, createSourceTableSql); + } + + private void testMergeWithDifferentBucketingInternal(String testDescription, String createTargetTableSql, String createSourceTableSql) + { + String targetTable = format("%s_target_%s", testDescription, randomTableSuffix()); + assertUpdate(format(createTargetTableSql, targetTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + + String sourceTable = format("%s_source_%s", testDescription, randomTableSuffix()); + assertUpdate(format(createSourceTableSql, sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable), 4); + + @Language("SQL") String sql = 
format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + assertUpdate(sql, 4); + + assertQuery("SELECT customer, purchases, address FROM " + targetTable, "VALUES ('Aaron', 11, 'Arches'), ('Ed', 7, 'Etherville'), ('Bill', 7, 'Buena'), ('Dave', 22, 'Darbyshire')"); + + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + + @DataProvider + public Object[][] targetAndSourceWithDifferentBucketing() + { + return new Object[][] { + { + "target_and_source_with_different_bucketing_counts", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (bucket_count = 5, bucketed_on = ARRAY['customer'])", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (bucket_count = 3, bucketed_on = ARRAY['purchases', 'address'])", + }, + { + "target_and_source_with_different_bucketing_columns", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (bucket_count = 3, bucketed_on = ARRAY['address'])", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (bucket_count = 3, bucketed_on = ARRAY['customer'])", + }, + { + "target_flat_source_bucketed_by_customer", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (bucket_count = 3, bucketed_on = ARRAY['customer'])", + }, + { + "target_bucketed_by_customer_source_flat", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (bucket_count = 3, bucketed_on = ARRAY['customer'])", + "CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", + }, + }; + } + + @Test + public void testMergeOverManySplits() 
+ { + String targetTable = "merge_delete_select_" + randomTableSuffix(); + assertUpdate(format("CREATE TABLE %s (orderkey bigint, custkey bigint, orderstatus varchar(1), totalprice double, orderdate date, orderpriority varchar(15), clerk varchar(15), shippriority integer, comment varchar(79))", targetTable)); + + assertUpdate(format("INSERT INTO %s SELECT * FROM tpch.\"sf0.1\".orders", targetTable), 150000); + + @Language("SQL") String sql = format("MERGE INTO %s t USING (SELECT * FROM tpch.\"sf0.1\".orders) s ON (t.orderkey = s.orderkey)", targetTable) + + " WHEN MATCHED AND mod(s.orderkey, 3) = 0 THEN UPDATE SET totalprice = t.totalprice + s.totalprice" + + " WHEN MATCHED AND mod(s.orderkey, 3) = 1 THEN DELETE"; + + assertUpdate(sql, 100_000); + + assertQuery(format("SELECT count(*) FROM %s t WHERE mod(t.orderkey, 3) = 1", targetTable), "VALUES (0)"); + + assertUpdate("DROP TABLE " + targetTable); + } } diff --git a/plugin/trino-tpcds/src/main/java/io/trino/plugin/tpcds/TpcdsNodePartitioningProvider.java b/plugin/trino-tpcds/src/main/java/io/trino/plugin/tpcds/TpcdsNodePartitioningProvider.java index 82d566f09274..83f30a47ec97 100644 --- a/plugin/trino-tpcds/src/main/java/io/trino/plugin/tpcds/TpcdsNodePartitioningProvider.java +++ b/plugin/trino-tpcds/src/main/java/io/trino/plugin/tpcds/TpcdsNodePartitioningProvider.java @@ -26,6 +26,7 @@ import io.trino.spi.type.Type; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.function.ToIntFunction; @@ -52,7 +53,7 @@ public TpcdsNodePartitioningProvider(NodeManager nodeManager, int splitsPerNode) } @Override - public ConnectorBucketNodeMap getBucketNodeMap(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) + public Optional getBucketNodeMapping(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) { Set nodes = nodeManager.getRequiredWorkerNodes(); 
checkState(!nodes.isEmpty(), "No TPCDS nodes available"); @@ -66,7 +67,7 @@ public ConnectorBucketNodeMap getBucketNodeMap(ConnectorTransactionHandle transa bucketToNode.add(node); } } - return createBucketNodeMap(bucketToNode.build()); + return Optional.of(createBucketNodeMap(bucketToNode.build())); } @Override diff --git a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchNodePartitioningProvider.java b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchNodePartitioningProvider.java index 722037be0c9a..5fa6ed2492d4 100644 --- a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchNodePartitioningProvider.java +++ b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchNodePartitioningProvider.java @@ -26,6 +26,7 @@ import io.trino.spi.type.Type; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.function.ToIntFunction; @@ -50,7 +51,7 @@ public TpchNodePartitioningProvider(NodeManager nodeManager, int splitsPerNode) } @Override - public ConnectorBucketNodeMap getBucketNodeMap(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) + public Optional getBucketNodeMapping(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) { Set nodes = nodeManager.getRequiredWorkerNodes(); checkState(!nodes.isEmpty(), "No TPCH nodes available"); @@ -64,7 +65,7 @@ public ConnectorBucketNodeMap getBucketNodeMap(ConnectorTransactionHandle transa bucketToNode.add(node); } } - return createBucketNodeMap(bucketToNode.build()); + return Optional.of(createBucketNodeMap(bucketToNode.build())); } @Override diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveMerge.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveMerge.java new file mode 100644 index 000000000000..6206ba756f31 --- /dev/null +++ 
b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveMerge.java @@ -0,0 +1,740 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.tests.product.hive; + +import io.trino.tempto.assertions.QueryAssert; +import io.trino.tempto.query.QueryExecutor; +import io.trino.tempto.query.QueryResult; +import io.trino.tests.product.hive.util.TemporaryHiveTable; +import org.testng.SkipException; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.tempto.assertions.QueryAssert.Row.row; +import static io.trino.tempto.assertions.QueryAssert.assertThat; +import static io.trino.tests.product.TestGroups.HIVE_TRANSACTIONAL; +import static io.trino.tests.product.hive.BucketingType.BUCKETED_V2; +import static io.trino.tests.product.hive.BucketingType.NONE; +import static io.trino.tests.product.hive.TestHiveTransactionalTable.TEST_TIMEOUT; +import static io.trino.tests.product.hive.TestHiveTransactionalTable.tableName; +import static io.trino.tests.product.hive.TestHiveTransactionalTable.verifySelectForTrinoAndHive; +import static io.trino.tests.product.utils.QueryExecutors.onHive; +import static io.trino.tests.product.utils.QueryExecutors.onTrino; +import static 
java.lang.String.format; +import static java.util.Locale.ENGLISH; + +public class TestHiveMerge + extends HiveProductTest +{ + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeSimpleSelect() + { + withTemporaryTable("merge_simple_target", true, false, NONE, targetTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", targetTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable)); + + withTemporaryTable("merge_simple_source", true, false, NONE, sourceTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", sourceTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable)); + + String sql = format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + onTrino().executeQuery(sql); + + verifySelectForTrinoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Aaron", 11, "Arches"), row("Ed", 7, "Etherville"), row("Bill", 7, "Buena"), row("Dave", 22, "Darbyshire")); + }); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeSimpleSelectPartitioned() + { + withTemporaryTable("merge_simple_target", true, true, NONE, targetTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address 
VARCHAR) WITH (transactional = true, partitioned_by = ARRAY['address'])", targetTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable)); + + withTemporaryTable("merge_simple_source", true, false, NONE, sourceTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", sourceTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable)); + + String sql = format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + onTrino().executeQuery(sql); + + verifySelectForTrinoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Aaron", 11, "Arches"), row("Ed", 7, "Etherville"), row("Bill", 7, "Buena"), row("Dave", 22, "Darbyshire")); + }); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = TEST_TIMEOUT, dataProvider = "partitionedAndBucketedProvider") + public void testMergeUpdateWithVariousLayouts(boolean partitioned, String bucketing) + { + BucketingType bucketingType = bucketing.isEmpty() ? NONE : BUCKETED_V2; + withTemporaryTable("merge_with_various_formats", true, partitioned, bucketingType, targetTable -> { + StringBuilder builder = new StringBuilder(); + builder.append("CREATE TABLE ") + .append(targetTable) + .append("(customer STRING"); + builder.append(partitioned ? 
") PARTITIONED BY (" : ", ") + .append("purchase STRING) "); + if (!bucketing.isEmpty()) { + builder.append(bucketing); + } + builder.append(" STORED AS ORC TBLPROPERTIES ('transactional' = 'true')"); + onHive().executeQuery(builder.toString()); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchase) VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')", targetTable)); + verifySelectForTrinoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Dave", "dates"), row("Lou", "limes"), row("Carol", "candles")); + + withTemporaryTable("merge_simple_source", true, false, NONE, sourceTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR) WITH (transactional = true)", sourceTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchase) VALUES ('Craig', 'candles'), ('Len', 'limes'), ('Joe', 'jellybeans')", sourceTable)); + + String sql = format("MERGE INTO %s t USING %s s ON (t.purchase = s.purchase)", targetTable, sourceTable) + + " WHEN MATCHED AND s.purchase = 'limes' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET customer = CONCAT(t.customer, '_', s.customer)" + + " WHEN NOT MATCHED THEN INSERT (customer, purchase) VALUES(s.customer, s.purchase)"; + + onTrino().executeQuery(sql); + + verifySelectForTrinoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Dave", "dates"), row("Carol_Craig", "candles"), row("Joe", "jellybeans")); + }); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = TEST_TIMEOUT) + public void testMergeUnBucketedUnPartitionedFailure() + { + withTemporaryTable("merge_with_various_formats", true, false, NONE, targetTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR) WITH (transactional = true)", targetTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchase) VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')", targetTable)); + verifySelectForTrinoAndHive("SELECT * FROM " + 
targetTable, "TRUE", row("Dave", "dates"), row("Lou", "limes"), row("Carol", "candles")); + + withTemporaryTable("merge_simple_source", true, false, NONE, sourceTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR) WITH (transactional = true)", sourceTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchase) VALUES ('Craig', 'candles'), ('Len', 'limes'), ('Joe', 'jellybeans')", sourceTable)); + + String sql = format("MERGE INTO %s t USING %s s ON (t.purchase = s.purchase)", targetTable, sourceTable) + + " WHEN MATCHED AND s.purchase = 'limes' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET customer = CONCAT(t.customer, '_', s.customer)" + + " WHEN NOT MATCHED THEN INSERT (customer, purchase) VALUES(s.customer, s.purchase)"; + + onTrino().executeQuery(sql); + + verifySelectForTrinoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Dave", "dates"), row("Carol_Craig", "candles"), row("Joe", "jellybeans")); + }); + }); + } + + @DataProvider + public Object[][] partitionedAndBucketedProvider() + { + return new Object[][] { + {false, "CLUSTERED BY (customer) INTO 3 BUCKETS"}, + {false, "CLUSTERED BY (purchase) INTO 4 BUCKETS"}, + {true, ""}, + {true, "CLUSTERED BY (customer) INTO 3 BUCKETS"}, + }; + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeMultipleOperationsUnbucketedUnpartitioned() + { + withTemporaryTable("merge_multiple", true, false, NONE, targetTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchases INT, zipcode INT, spouse VARCHAR, address VARCHAR) WITH (transactional = true)", targetTable)); + testMergeMultipleOperationsInternal(targetTable, 32); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeMultipleOperationsUnbucketedPartitioned() + { + withTemporaryTable("merge_multiple", true, true, NONE, targetTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (purchases 
INT, zipcode INT, spouse VARCHAR, address VARCHAR, customer VARCHAR) WITH (transactional = true, partitioned_by = ARRAY['address', 'customer'])", targetTable)); + testMergeMultipleOperationsInternal(targetTable, 32); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeMultipleOperationsBucketedUnpartitioned() + { + withTemporaryTable("merge_multiple", true, false, BUCKETED_V2, targetTable -> { + onHive().executeQuery(format("CREATE TABLE %s (customer STRING, purchases INT, zipcode INT, spouse STRING, address STRING)" + + " CLUSTERED BY(customer, zipcode, address) INTO 4 BUCKETS STORED AS ORC TBLPROPERTIES ('transactional'='true')", targetTable)); + testMergeMultipleOperationsInternal(targetTable, 32); + }); + } + + private void testMergeMultipleOperationsInternal(String targetTable, int targetCustomerCount) + { + String originalInsertFirstHalf = IntStream.range(1, targetCustomerCount / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 1000, 91000, intValue, intValue)) + .collect(Collectors.joining(", ")); + String originalInsertSecondHalf = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 2000, 92000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, zipcode, spouse, address) VALUES %s, %s", targetTable, originalInsertFirstHalf, originalInsertSecondHalf)); + + String firstMergeSource = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 83000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + onTrino().executeQuery(format("MERGE INTO %s t USING (SELECT * FROM (VALUES %s)) AS s(customer, purchases, zipcode, spouse, address)", targetTable, firstMergeSource) + + " ON t.customer = s.customer" + 
+ " WHEN MATCHED THEN UPDATE SET purchases = s.purchases, zipcode = s.zipcode, spouse = s.spouse, address = s.address"); + + QueryResult expectedResult = onTrino().executeQuery(format("SELECT * FROM (VALUES %s, %s) AS v(customer, purchases, zipcode, spouse, address)", originalInsertFirstHalf, firstMergeSource)); + verifyOnTrinoAndHiveFromQueryResults("SELECT customer, purchases, zipcode, spouse, address FROM " + targetTable, expectedResult); + + String nextInsert = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('jack_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 4000, 74000, intValue, intValue)) + .collect(Collectors.joining(", ")); + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, zipcode, spouse, address) VALUES %s", targetTable, nextInsert)); + + String secondMergeSource = IntStream.range(1, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + onTrino().executeQuery(format("MERGE INTO %s t USING (SELECT * FROM (VALUES %s)) AS s(customer, purchases, zipcode, spouse, address)", targetTable, secondMergeSource) + + " ON t.customer = s.customer" + + " WHEN MATCHED AND t.zipcode = 91000 THEN DELETE" + + " WHEN MATCHED AND s.zipcode = 85000 THEN UPDATE SET zipcode = 60000" + + " WHEN MATCHED THEN UPDATE SET zipcode = s.zipcode, spouse = s.spouse, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, zipcode, spouse, address) VALUES(s.customer, s.purchases, s.zipcode, s.spouse, s.address)"); + + String updatedBeginning = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 60000, intValue, intValue)) + .collect(Collectors.joining(", ")); + String updatedMiddle = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + 
.mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue)) + .collect(Collectors.joining(", ")); + String updatedEnd = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('jack_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 4000, 74000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + expectedResult = onTrino().executeQuery(format("SELECT * FROM (VALUES %s, %s, %s) AS v(customer, purchases, zipcode, spouse, address)", updatedBeginning, updatedMiddle, updatedEnd)); + verifyOnTrinoAndHiveFromQueryResults("SELECT customer, purchases, zipcode, spouse, address FROM " + targetTable, expectedResult); + } + + private void verifyOnTrinoAndHiveFromQueryResults(String sql, QueryResult expectedResult) + { + QueryResult trinoResult = onTrino().executeQuery(sql); + assertThat(trinoResult).contains(getRowsFromQueryResult(expectedResult)); + QueryResult hiveResult = onHive().executeQuery(sql); + assertThat(hiveResult).contains(getRowsFromQueryResult(expectedResult)); + } + + private List<QueryAssert.Row> getRowsFromQueryResult(QueryResult result) + { + return result.rows().stream().map(QueryAssert.Row::new).collect(toImmutableList()); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeSimpleQuery() + { + withTemporaryTable("merge_simple_target", true, false, NONE, targetTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", targetTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable)); + + onTrino().executeQuery(format("MERGE INTO %s t USING ", targetTable) + + "(SELECT * FROM (VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville'))) AS s(customer,
purchases, address)" + + " " + + "ON (t.customer = s.customer)" + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"); + + verifySelectForTrinoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Aaron", 11, "Arches"), row("Bill", 7, "Buena"), row("Dave", 22, "Darbyshire"), row("Ed", 7, "Etherville")); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeAllInserts() + { + withTemporaryTable("merge_all_inserts", true, false, NONE, targetTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", targetTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 11, 'Antioch'), ('Bill', 7, 'Buena')", targetTable)); + + onTrino().executeQuery(format("MERGE INTO %s t USING ", targetTable) + + "(SELECT * FROM (VALUES ('Carol', 9, 'Centreville'), ('Dave', 22, 'Darbyshire'))) AS s(customer, purchases, address)" + + " " + + "ON (t.customer = s.customer)" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"); + + verifySelectForTrinoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Aaron", 11, "Antioch"), row("Bill", 7, "Buena"), row("Carol", 9, "Centreville"), row("Dave", 22, "Darbyshire")); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeSimpleQueryPartitioned() + { + withTemporaryTable("merge_simple_target", true, true, NONE, targetTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true, partitioned_by = ARRAY['address'])", targetTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, 
purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable)); + + String query = format("MERGE INTO %s t USING ", targetTable) + + "(SELECT * FROM (VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville'))) AS s(customer, purchases, address)" + + " " + + "ON (t.customer = s.customer)" + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + onTrino().executeQuery(query); + + verifySelectForTrinoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Aaron", 11, "Arches"), row("Bill", 7, "Buena"), row("Dave", 22, "Darbyshire"), row("Ed", 7, "Etherville")); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeAllColumnsUpdated() + { + withTemporaryTable("merge_all_columns_updated_target", true, false, NONE, targetTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", targetTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Devon'), ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge')", targetTable)); + + withTemporaryTable("merge_all_columns_updated_source", true, false, NONE, sourceTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", sourceTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Darbyshire'), ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Ed', 7, 'Etherville')", sourceTable)); + + onTrino().executeQuery(format("MERGE INTO %s t USING %s s ON (t.customer = 
s.customer)", targetTable, sourceTable) + + " WHEN MATCHED THEN UPDATE SET customer = CONCAT(t.customer, '_updated'), purchases = s.purchases + t.purchases, address = s.address"); + + verifySelectForTrinoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Dave_updated", 22, "Darbyshire"), row("Aaron_updated", 11, "Arches"), row("Bill", 7, "Buena"), row("Carol_updated", 12, "Centreville")); + }); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeAllMatchesDeleted() + { + withTemporaryTable("merge_all_matches_deleted_target", true, false, NONE, targetTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", targetTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable)); + + withTemporaryTable("merge_all_matches_deleted_source", true, false, NONE, sourceTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", sourceTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville')", sourceTable)); + + onTrino().executeQuery(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED THEN DELETE"); + + verifySelectForTrinoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Bill", 7, "Buena")); + }); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000, dataProvider = "partitionedBucketedFailure") + public void testMergeMultipleRowsMatchFails(String createTableSql) + { + withTemporaryTable("merge_all_matches_deleted_target", true, true, NONE, targetTable -> { + onHive().executeQuery(format(createTableSql, 
targetTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Antioch')", targetTable)); + + withTemporaryTable("merge_all_matches_deleted_source", true, false, NONE, sourceTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", sourceTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Adelphi'), ('Aaron', 8, 'Ashland')", sourceTable)); + + assertThat(() -> onTrino().executeQuery(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED THEN UPDATE SET address = s.address")) + .failsWithMessage("One MERGE target table row matched more than one source row"); + + onTrino().executeQuery(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED AND s.address = 'Adelphi' THEN UPDATE SET address = s.address"); + verifySelectForTrinoAndHive("SELECT customer, purchases, address FROM " + targetTable, "TRUE", row("Aaron", 5, "Adelphi"), row("Bill", 7, "Antioch")); + }); + }); + } + + @DataProvider + public Object[][] partitionedBucketedFailure() + { + return new Object[][] { + {"CREATE TABLE %s (customer STRING, purchases INT, address STRING) STORED AS ORC TBLPROPERTIES ('transactional'='true')"}, + {"CREATE TABLE %s (customer STRING, purchases INT, address STRING) CLUSTERED BY (customer) INTO 3 BUCKETS STORED AS ORC TBLPROPERTIES ('transactional'='true')"}, + {"CREATE TABLE %s (purchases INT, address STRING) PARTITIONED BY (customer STRING) STORED AS ORC TBLPROPERTIES ('transactional'='true')"}, + {"CREATE TABLE %s (customer STRING, purchases INT) PARTITIONED BY (address STRING) CLUSTERED BY (customer) INTO 3 BUCKETS STORED AS ORC TBLPROPERTIES ('transactional'='true')"}, + {"CREATE TABLE %s (purchases INT, address STRING) PARTITIONED BY (customer 
STRING) CLUSTERED BY (address) INTO 3 BUCKETS STORED AS ORC TBLPROPERTIES ('transactional'='true')"} + }; + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeFailingPartitioning() + { + String testDescription = "failing_merge"; + withTemporaryTable(format("%s_target", testDescription), true, true, NONE, targetTable -> { + onHive().executeQuery(format("CREATE TABLE %s (customer STRING, purchases INT, address STRING) STORED AS ORC TBLPROPERTIES ('transactional'='true')", targetTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable)); + + withTemporaryTable(format("%s_source", testDescription), true, true, NONE, sourceTable -> { + onHive().executeQuery(format("CREATE TABLE %s (purchases INT, address STRING) PARTITIONED BY (customer STRING) STORED AS ORC TBLPROPERTIES ('transactional'='true')", sourceTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable)); + + String sql = format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + onTrino().executeQuery(sql); + + verifySelectForTrinoAndHive("SELECT customer, purchases, address FROM " + targetTable, "TRUE", row("Aaron", 11, "Arches"), row("Ed", 7, "Etherville"), row("Bill", 7, "Buena"), row("Dave", 22, "Darbyshire")); + }); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeFailureWithDifferentPartitioning() + { + 
testMergeWithDifferentPartitioningInternal( + "target_partitioned_source_partitioned_and_bucketed", + "CREATE TABLE %s (purchases INT, address STRING) PARTITIONED BY (customer STRING) STORED AS ORC TBLPROPERTIES ('transactional'='true')", + "CREATE TABLE %s (customer STRING, purchases INT) PARTITIONED BY (address STRING) CLUSTERED BY (customer) INTO 3 BUCKETS STORED AS ORC TBLPROPERTIES ('transactional'='true')"); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000, dataProvider = "targetAndSourceWithDifferentPartitioning") + public void testMergeWithDifferentPartitioning(String testDescription, String createTargetTableSql, String createSourceTableSql) + { + testMergeWithDifferentPartitioningInternal(testDescription, createTargetTableSql, createSourceTableSql); + } + + private void testMergeWithDifferentPartitioningInternal(String testDescription, String createTargetTableSql, String createSourceTableSql) + { + withTemporaryTable(format("%s_target", testDescription), true, true, NONE, targetTable -> { + onHive().executeQuery(format(createTargetTableSql, targetTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable)); + + withTemporaryTable(format("%s_source", testDescription), true, true, NONE, sourceTable -> { + onHive().executeQuery(format(createSourceTableSql, sourceTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable)); + + String sql = format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, 
address) VALUES(s.customer, s.purchases, s.address)"; + + onTrino().executeQuery(sql); + + verifySelectForTrinoAndHive("SELECT customer, purchases, address FROM " + targetTable, "TRUE", row("Aaron", 11, "Arches"), row("Ed", 7, "Etherville"), row("Bill", 7, "Buena"), row("Dave", 22, "Darbyshire")); + }); + }); + } + + @DataProvider + public Object[][] targetAndSourceWithDifferentPartitioning() + { + return new Object[][] { + { + "target_partitioned_source_and_target_partitioned_and_bucketed", + "CREATE TABLE %s (customer STRING, purchases INT) PARTITIONED BY (address STRING) CLUSTERED BY (customer) INTO 3 BUCKETS STORED AS ORC TBLPROPERTIES ('transactional'='true')", + "CREATE TABLE %s (customer STRING, purchases INT) PARTITIONED BY (address STRING) CLUSTERED BY (customer) INTO 3 BUCKETS STORED AS ORC TBLPROPERTIES ('transactional'='true')", + }, + { + "target_flat_source_partitioned_by_customer", + "CREATE TABLE %s (customer STRING, purchases INT, address STRING) STORED AS ORC TBLPROPERTIES ('transactional'='true')", + "CREATE TABLE %s (purchases INT, address STRING) PARTITIONED BY (customer STRING) STORED AS ORC TBLPROPERTIES ('transactional'='true')" + }, + { + "target_partitioned_by_customer_source_flat", + "CREATE TABLE %s (purchases INT, address STRING) PARTITIONED BY (customer STRING) STORED AS ORC TBLPROPERTIES ('transactional'='true')", + "CREATE TABLE %s (customer STRING, purchases INT, address STRING) STORED AS ORC TBLPROPERTIES ('transactional'='true')", + }, + { + "target_bucketed_by_customer_source_flat", + "CREATE TABLE %s (customer STRING, purchases INT, address STRING) CLUSTERED BY (customer) INTO 3 BUCKETS STORED AS ORC TBLPROPERTIES ('transactional'='true')", + "CREATE TABLE %s (customer STRING, purchases INT, address STRING) STORED AS ORC TBLPROPERTIES ('transactional'='true')", + }, + { + "target_partitioned_source_partitioned_and_bucketed", + "CREATE TABLE %s (purchases INT, address STRING) PARTITIONED BY (customer STRING) STORED AS ORC 
TBLPROPERTIES ('transactional'='true')", + "CREATE TABLE %s (customer STRING, purchases INT) PARTITIONED BY (address STRING) CLUSTERED BY (customer) INTO 3 BUCKETS STORED AS ORC TBLPROPERTIES ('transactional'='true')", + }, + { + "target_partitioned_target_partitioned_and_bucketed", + "CREATE TABLE %s (customer STRING, purchases INT) PARTITIONED BY (address STRING) CLUSTERED BY (customer) INTO 3 BUCKETS STORED AS ORC TBLPROPERTIES ('transactional'='true')", + "CREATE TABLE %s (purchases INT, address STRING) PARTITIONED BY (customer STRING) STORED AS ORC TBLPROPERTIES ('transactional'='true')", + } + }; + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeQueryWithStrangeCapitalization() + { + withTemporaryTable("test_without_aliases_target", true, false, NONE, targetTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", targetTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable)); + + onTrino().executeQuery(format("MERGE INTO %s t USING ", targetTable.toUpperCase(ENGLISH)) + + "(SELECT * FROM (VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville'))) AS s(customer, purchases, address)" + + "ON (t.customer = s.customer)" + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purCHases = s.PurchaseS + t.pUrchases, aDDress = s.addrESs" + + " WHEN NOT MATCHED THEN INSERT (CUSTOMER, purchases, addRESS) VALUES(s.custoMer, s.Purchases, s.ADDress)"); + + verifySelectForTrinoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Aaron", 11, "Arches"), row("Bill", 7, "Buena"), row("Dave", 22, "Darbyshire"), row("Ed", 7, "Etherville")); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 
1000) + public void testMergeWithoutTablesAliases() + { + withTemporaryTable("test_without_aliases_target", true, false, NONE, targetTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (cusTomer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", targetTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable)); + + withTemporaryTable("test_without_aliases_source", true, false, NONE, sourceTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", sourceTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable)); + + onTrino().executeQuery(format("MERGE INTO %s USING %s", targetTable, sourceTable) + + format(" ON (%s.customer = %s.customer)", targetTable, sourceTable) + + format(" WHEN MATCHED AND %s.address = 'Centreville' THEN DELETE", sourceTable) + + format(" WHEN MATCHED THEN UPDATE SET purchases = %s.pURCHases + %s.pUrchases, aDDress = %s.addrESs", sourceTable, targetTable, sourceTable) + + format(" WHEN NOT MATCHED THEN INSERT (cusTomer, purchases, addRESS) VALUES(%s.custoMer, %s.Purchases, %s.ADDress)", sourceTable, sourceTable, sourceTable)); + + verifySelectForTrinoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Aaron", 11, "Arches"), row("Bill", 7, "Buena"), row("Dave", 22, "Darbyshire"), row("Ed", 7, "Etherville")); + }); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeWithUnpredictablePredicates() + { + withTemporaryTable("test_without_aliases_target", true, false, NONE, targetTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (cusTomer VARCHAR, purchases INT, address VARCHAR) WITH 
(transactional = true)", targetTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable)); + + withTemporaryTable("test_without_aliases_source", true, false, NONE, sourceTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", sourceTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville')", sourceTable)); + + onTrino().executeQuery(format("MERGE INTO %s t USING %s s", targetTable, sourceTable) + + " ON t.customer = s.customer AND s.purchases < 10.2" + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"); + + verifySelectForTrinoAndHive("SELECT * FROM " + targetTable, "TRUE", + row("Aaron", 11, "Arches"), row("Bill", 7, "Buena"), row("Dave", 11, "Darbyshire"), row("Dave", 11, "Devon"), row("Ed", 7, "Etherville")); + + onTrino().executeQuery(format("MERGE INTO %s t USING %s s", targetTable, sourceTable) + + " ON t.customer = s.customer" + + " WHEN MATCHED AND t.address <> 'Darbyshire' AND s.purchases * 2 > 20" + + " THEN DELETE" + + " WHEN MATCHED" + + " THEN UPDATE SET purchases = s.purchases + t.purchases, address = concat(t.address, '/', s.address)" + + " WHEN NOT MATCHED" + + " THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"); + + verifySelectForTrinoAndHive("SELECT * FROM " + targetTable, "TRUE", + row("Aaron", 17, "Arches/Arches"), row("Bill", 7, "Buena"), row("Carol", 9, "Centreville"), row("Dave", 22, "Darbyshire/Darbyshire"), row("Ed", 14, 
"Etherville/Etherville")); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES('Fred', 30, 'Franklin')", targetTable)); + verifySelectForTrinoAndHive("SELECT * FROM " + targetTable, "TRUE", + row("Aaron", 17, "Arches/Arches"), row("Bill", 7, "Buena"), row("Carol", 9, "Centreville"), row("Dave", 22, "Darbyshire/Darbyshire"), row("Ed", 14, "Etherville/Etherville"), row("Fred", 30, "Franklin")); + }); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeWithSimplifiedUnpredictablePredicates() + { + withTemporaryTable("test_without_aliases_target", true, false, NONE, targetTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", targetTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address)" + + " VALUES ('Dave', 11, 'Devon'), ('Dave', 11, 'Darbyshire')", targetTable)); + + withTemporaryTable("test_without_aliases_source", true, false, NONE, sourceTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (transactional = true)", sourceTable)); + + onTrino().executeQuery(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Darbyshire')", sourceTable)); + + onTrino().executeQuery(format("MERGE INTO %s t USING %s s", targetTable, sourceTable) + + " ON t.customer = s.customer" + + " WHEN MATCHED AND t.address <> 'Darbyshire' AND s.purchases * 2 > 20" + + " THEN DELETE"); + + // BUG: The actual row is [Dave, 11, Devon]. Why did the wrong one get deleted?
+ verifySelectForTrinoAndHive("SELECT * FROM " + targetTable, "TRUE", row("Dave", 11, "Darbyshire")); + }); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeCasts() + { + withTemporaryTable("merge_cast_target", true, false, NONE, targetTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (col1 TINYINT, col2 SMALLINT, col3 INT, col4 BIGINT, col5 REAL, col6 DOUBLE) WITH (transactional = true)", targetTable)); + + onTrino().executeQuery(format("INSERT INTO %s VALUES (1, 2, 3, 4, 5, 6)", targetTable)); + + withTemporaryTable("test_without_aliases_source", true, false, NONE, sourceTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (col1 DOUBLE, col2 REAL, col3 BIGINT, col4 INT, col5 SMALLINT, col6 TINYINT) WITH (transactional = true)", sourceTable)); + + onTrino().executeQuery(format("INSERT INTO %s VALUES (2, 3, 4, 5, 6, 7)", sourceTable)); + + onTrino().executeQuery(format("MERGE INTO %s t USING %s s", targetTable, sourceTable) + + " ON (t.col1 + 1 = s.col1)" + + " WHEN MATCHED THEN UPDATE SET col1 = s.col1, col2 = s.col2, col3 = s.col3, col4 = s.col4, col5 = s.col5, col6 = s.col6"); + + verifySelectForTrinoAndHive("SELECT * FROM " + targetTable, "TRUE", row(2, 3, 4, 5, 6.0, 7.0)); + }); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeSubqueries() + { + withTemporaryTable("merge_nation_target", true, false, NONE, targetTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR) WITH (transactional = true)", targetTable)); + + onTrino().executeQuery(format("INSERT INTO %s (nation_name, region_name) VALUES ('FRANCE', 'EUROPE'), ('ALGERIA', 'AFRICA'), ('GERMANY', 'EUROPE')", targetTable)); + + withTemporaryTable("merge_nation_source", true, false, NONE, sourceTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR) WITH (transactional = true)", sourceTable)); + + 
onTrino().executeQuery(format("INSERT INTO %s VALUES ('ALGERIA', 'AFRICA'), ('FRANCE', 'EUROPE'), ('EGYPT', 'MIDDLE EAST'), ('RUSSIA', 'EUROPE')", sourceTable)); + + onTrino().executeQuery(format("MERGE INTO %s t USING %s s", targetTable, sourceTable) + + " ON (t.nation_name = s.nation_name)" + + " WHEN MATCHED AND t.nation_name > (SELECT name FROM tpch.tiny.region WHERE name = t.region_name AND name LIKE ('A%'))" + + " THEN DELETE" + + " WHEN NOT MATCHED AND s.region_name = 'EUROPE'" + + " THEN INSERT VALUES(s.nation_name, (SELECT 'EUROPE'))"); + + verifySelectForTrinoAndHive("SELECT * FROM " + targetTable, "TRUE", row("FRANCE", "EUROPE"), row("GERMANY", "EUROPE"), row("RUSSIA", "EUROPE")); + }); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = 60 * 60 * 1000) + public void testMergeOriginalFilesTarget() + { + withTemporaryTable("region", true, false, NONE, targetTable -> { + onTrino().executeQuery(format("CREATE TABLE %s WITH (transactional=true) AS TABLE tpch.tiny.region", targetTable)); + + // This merge is illegal, because many nations have the same region + assertThat(() -> onTrino().executeQuery(format("MERGE INTO %s r USING tpch.tiny.nation n", targetTable) + + " ON r.regionkey = n.regionkey" + + " WHEN MATCHED" + + " THEN UPDATE SET comment = n.comment")) + .failsWithMessage("One MERGE target table row matched more than one source row"); + + onTrino().executeQuery(format("MERGE INTO %s r USING tpch.tiny.nation n", targetTable) + + " ON r.regionkey = n.regionkey AND n.name = 'FRANCE'" + + " WHEN MATCHED" + + " THEN UPDATE SET name = 'EUROPEAN'"); + + verifySelectForTrinoAndHive("SELECT name FROM " + targetTable, "name LIKE('EU%')", row("EUROPEAN")); + }); + } + + @Test(groups = HIVE_TRANSACTIONAL, timeOut = TEST_TIMEOUT) + public void testMergeOverManySplits() + { + withTemporaryTable("delete_select", true, false, NONE, targetTable -> { + onTrino().executeQuery(format("CREATE TABLE %s (orderkey bigint, custkey bigint, orderstatus varchar(1), 
totalprice double, orderdate date, orderpriority varchar(15), clerk varchar(15), shippriority integer, comment varchar(79)) WITH (transactional = true)", targetTable)); + + onTrino().executeQuery(format("INSERT INTO %s SELECT * FROM tpch.\"sf0.1\".orders", targetTable)); + + String sql = format("MERGE INTO %s t USING (SELECT * FROM tpch.\"sf0.1\".orders) s ON (t.orderkey = s.orderkey)", targetTable) + + " WHEN MATCHED AND mod(s.orderkey, 3) = 0 THEN UPDATE SET totalprice = t.totalprice + s.totalprice" + + " WHEN MATCHED AND mod(s.orderkey, 3) = 1 THEN DELETE"; + + onTrino().executeQuery(sql); + + verifySelectForTrinoAndHive(format("SELECT count(*) FROM %s t", targetTable), "mod(t.orderkey, 3) = 1", row(0)); + }); + } + + @DataProvider + public Object[][] insertersProvider() + { + return new Object[][] { + {false, Engine.HIVE, Engine.TRINO}, + {false, Engine.TRINO, Engine.TRINO}, + {true, Engine.HIVE, Engine.TRINO}, + {true, Engine.TRINO, Engine.TRINO}, + }; + } + + private static QueryResult execute(Engine engine, String sql, QueryExecutor.QueryParam... 
params) + { + return engine.queryExecutor().executeQuery(sql, params); + } + + @DataProvider + public Object[][] inserterAndDeleterProvider() + { + return new Object[][] { + {Engine.HIVE, Engine.TRINO}, + {Engine.TRINO, Engine.TRINO}, + {Engine.TRINO, Engine.HIVE} + }; + } + + void withTemporaryTable(String rootName, boolean transactional, boolean isPartitioned, BucketingType bucketingType, Consumer testRunner) + { + if (transactional) { + ensureTransactionalHive(); + } + try (TemporaryHiveTable table = TemporaryHiveTable.temporaryHiveTable(tableName(rootName, isPartitioned, bucketingType))) { + testRunner.accept(table.getName()); + } + } + + private void ensureTransactionalHive() + { + if (getHiveVersionMajor() < 3) { + throw new SkipException("Hive transactional tables are supported with Hive version 3 or above"); + } + } +} diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveTransactionalTable.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveTransactionalTable.java index b8c727b497b7..128af14bf6d8 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveTransactionalTable.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveTransactionalTable.java @@ -19,7 +19,7 @@ import io.airlift.log.Logger; import io.airlift.units.Duration; import io.trino.plugin.hive.metastore.thrift.ThriftHiveMetastoreClient; -import io.trino.tempto.assertions.QueryAssert; +import io.trino.tempto.assertions.QueryAssert.Row; import io.trino.tempto.hadoop.hdfs.HdfsClient; import io.trino.tempto.query.QueryExecutor; import io.trino.tempto.query.QueryResult; @@ -80,7 +80,7 @@ public class TestHiveTransactionalTable { private static final Logger log = Logger.get(TestHiveTransactionalTable.class); - private static final int TEST_TIMEOUT = 15 * 60 * 1000; + public static final int TEST_TIMEOUT = 15 * 60 * 1000; // Hive original file path end looks like 
/000000_0 // New Trino original file path end looks like /000000_132574635756428963553891918669625313402 @@ -2172,7 +2172,7 @@ private static Stream> mapRows(QueryResult result) return rows.build().stream(); } - private static String tableName(String testName, boolean isPartitioned, BucketingType bucketingType) + public static String tableName(String testName, boolean isPartitioned, BucketingType bucketingType) { return format("test_%s_%b_%s_%s", testName, isPartitioned, bucketingType.name(), randomTableSuffix()); } @@ -2210,13 +2210,13 @@ private void ensureSchemaEvolutionSupported() } } - private static void verifySelectForTrinoAndHive(String select, String whereClause, QueryAssert.Row... rows) + public static void verifySelectForTrinoAndHive(String select, String whereClause, Row... rows) { verifySelect("onTrino", onTrino(), select, whereClause, rows); verifySelect("onHive", onHive(), select, whereClause, rows); } - private static void verifySelect(String name, QueryExecutor executor, String select, String whereClause, QueryAssert.Row... rows) + public static void verifySelect(String name, QueryExecutor executor, String select, String whereClause, Row... 
rows) { String fullQuery = format("%s WHERE %s", select, whereClause); diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorSmokeTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorSmokeTest.java index 519199fb8c84..ff5f98d210f7 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorSmokeTest.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorSmokeTest.java @@ -29,6 +29,7 @@ import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_CREATE_VIEW; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_DELETE; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_INSERT; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_MERGE; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_RENAME_SCHEMA; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_RENAME_TABLE; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_RENAME_TABLE_ACROSS_SCHEMAS; @@ -144,6 +145,11 @@ protected String getCreateTableDefaultDefinition() return "(a bigint, b double)"; } + protected String expectedValues(String values) + { + return format("SELECT CAST(a AS bigint), CAST(b AS double) FROM (VALUES %s) AS t (a, b)", values); + } + @Test public void testCreateTableAsSelect() { @@ -172,9 +178,9 @@ public void testInsert() } try (TestTable table = new TestTable(getQueryRunner()::execute, "test_insert_", getCreateTableDefaultDefinition())) { - assertUpdate("INSERT INTO " + table.getName() + " (a, b) VALUES (42, -38.5)", 1); + assertUpdate("INSERT INTO " + table.getName() + " (a, b) VALUES (42, -38.5), (13, 99.9)", 2); assertThat(query("SELECT CAST(a AS bigint), b FROM " + table.getName())) - .matches("VALUES (BIGINT '42', -385e-1)"); + .matches(expectedValues("(42, -38.5), (13, 99.9)")); } } @@ -241,12 +247,50 @@ public void testUpdate() return; } - try (TestTable table = new TestTable(getQueryRunner()::execute, 
"test_update", "AS TABLE tpch.tiny.nation")) { - String tableName = table.getName(); - assertUpdate("UPDATE " + tableName + " SET nationkey = 100 + nationkey WHERE regionkey = 2", 5); - assertThat(query("SELECT * FROM " + tableName)) - .skippingTypesCheck() - .matches("SELECT IF(regionkey=2, nationkey + 100, nationkey) nationkey, name, regionkey, comment FROM tpch.tiny.nation"); + if (!hasBehavior(SUPPORTS_INSERT)) { + throw new AssertionError("Cannot test UPDATE without INSERT"); + } + + try (TestTable table = new TestTable(getQueryRunner()::execute, "test_update_", getCreateTableDefaultDefinition())) { + assertUpdate("INSERT INTO " + table.getName() + " (a, b) SELECT regionkey, regionkey * 2.5 FROM region", "SELECT count(*) FROM region"); + assertThat(query("SELECT a, b FROM " + table.getName())) + .matches(expectedValues("(0, 0.0), (1, 2.5), (2, 5.0), (3, 7.5), (4, 10.0)")); + + assertUpdate("UPDATE " + table.getName() + " SET b = b + 1.2 WHERE a % 2 = 0", 3); + assertThat(query("SELECT a, b FROM " + table.getName())) + .matches(expectedValues("(0, 1.2), (1, 2.5), (2, 6.2), (3, 7.5), (4, 11.2)")); + } + } + + @Test + public void testMerge() + { + if (!hasBehavior(SUPPORTS_MERGE)) { + // Note this change is a no-op, if actually run + assertQueryFails("MERGE INTO nation n USING nation s ON (n.nationkey = s.nationkey) " + + "WHEN MATCHED AND n.regionkey < 1 THEN UPDATE SET nationkey = 5", + "This connector does not support merges"); + return; + } + + if (!hasBehavior(SUPPORTS_INSERT)) { + throw new AssertionError("Cannot test MERGE without INSERT"); + } + + try (TestTable table = new TestTable(getQueryRunner()::execute, "test_merge_", getCreateTableDefaultDefinition())) { + assertUpdate("INSERT INTO " + table.getName() + " (a, b) SELECT regionkey, regionkey * 2.5 FROM region", "SELECT count(*) FROM region"); + assertThat(query("SELECT a, b FROM " + table.getName())) + .matches(expectedValues("(0, 0.0), (1, 2.5), (2, 5.0), (3, 7.5), (4, 10.0)")); + + 
assertUpdate("MERGE INTO " + table.getName() + " t " + + "USING (VALUES (0, 1.3), (2, 2.9), (3, 0.0), (4, -5.0), (5, 5.7)) AS s (a, b) " + + "ON (t.a = s.a) " + + "WHEN MATCHED AND s.b > 0 THEN UPDATE SET b = t.b + s.b " + + "WHEN MATCHED AND s.b = 0 THEN DELETE " + + "WHEN NOT MATCHED THEN INSERT VALUES (s.a, s.b)", + 4); + assertThat(query("SELECT a, b FROM " + table.getName())) + .matches(expectedValues("(0, 1.3), (1, 2.5), (2, 7.9), (4, 10.0), (5, 5.7)")); } } diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java index 6af8d2c03852..2bec3419d10d 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java @@ -57,6 +57,7 @@ import java.util.function.Consumer; import java.util.function.Supplier; import java.util.regex.Pattern; +import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -92,6 +93,7 @@ import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_DELETE; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_DROP_COLUMN; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_INSERT; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_MERGE; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_MULTI_STATEMENT_WRITES; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_NEGATIVE_DATE; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_NOT_NULL_CONSTRAINT; @@ -3740,6 +3742,472 @@ public void testPotentialDuplicateDereferencePushdown() } } + protected String createTableForWrites(String createTable) + { + return createTable; + } + + @Test + public void testMergeLarge() + { + skipTestUnless(hasBehavior(SUPPORTS_MERGE) && hasBehavior(SUPPORTS_INSERT)); + + String tableName = "test_merge_" + randomTableSuffix(); + + 
assertUpdate(createTableForWrites(format("CREATE TABLE %s (orderkey BIGINT, custkey BIGINT, totalprice DOUBLE)", tableName))); + + assertUpdate( + format("INSERT INTO %s SELECT orderkey, custkey, totalprice FROM tpch.sf1.orders", tableName), + (long) computeActual("SELECT count(*) FROM tpch.sf1.orders").getOnlyValue()); + + @Language("SQL") String mergeSql = "" + + "MERGE INTO " + tableName + " t USING (SELECT * FROM tpch.sf1.orders) s ON (t.orderkey = s.orderkey)\n" + + "WHEN MATCHED AND mod(s.orderkey, 3) = 0 THEN UPDATE SET totalprice = t.totalprice + s.totalprice\n" + + "WHEN MATCHED AND mod(s.orderkey, 3) = 1 THEN DELETE"; + + assertUpdate(mergeSql, 1_000_000); + + // verify deleted rows + assertQuery("SELECT count(*) FROM " + tableName + " WHERE mod(orderkey, 3) = 1", "SELECT 0"); + + // verify untouched rows + assertEquals( + computeActual("SELECT count(*), cast(sum(totalprice) AS decimal(18,2)) FROM " + tableName + " WHERE mod(orderkey, 3) = 2"), + computeActual("SELECT count(*), cast(sum(totalprice) AS decimal(18,2)) FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 2")); + + // verify updated rows + assertEquals( + computeActual("SELECT count(*), cast(sum(totalprice) AS decimal(18,2)) FROM " + tableName + " WHERE mod(orderkey, 3) = 0"), + computeActual("SELECT count(*), cast(sum(totalprice * 2) AS decimal(18,2)) FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 0")); + + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void testMergeSimpleSelect() + { + skipTestUnless(hasBehavior(SUPPORTS_MERGE)); + + String targetTable = "merge_simple_target_" + randomTableSuffix(); + String sourceTable = "merge_simple_source_" + randomTableSuffix(); + assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 
4); + + assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable))); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable), 4); + + assertUpdate(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)", 4); + + assertQuery("SELECT * FROM " + targetTable, "VALUES ('Aaron', 11, 'Arches'), ('Ed', 7, 'Etherville'), ('Bill', 7, 'Buena'), ('Dave', 22, 'Darbyshire')"); + + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + + @Test + public void testMergeFruits() + { + skipTestUnless(hasBehavior(SUPPORTS_MERGE)); + + String targetTable = "merge_various_target_" + randomTableSuffix(); + String sourceTable = "merge_various_source_" + randomTableSuffix(); + assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR)", targetTable))); + + assertUpdate(format("INSERT INTO %s (customer, purchase) VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')", targetTable), 3); + + assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR)", sourceTable))); + + assertUpdate(format("INSERT INTO %s (customer, purchase) VALUES ('Craig', 'candles'), ('Len', 'limes'), ('Joe', 'jellybeans')", sourceTable), 3); + + assertUpdate(format("MERGE INTO %s t USING %s s ON (t.purchase = s.purchase)", targetTable, sourceTable) + + " WHEN MATCHED AND s.purchase = 'limes' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET customer = CONCAT(t.customer, '_', s.customer)" + + " WHEN NOT 
MATCHED THEN INSERT (customer, purchase) VALUES(s.customer, s.purchase)", 3); + + assertQuery("SELECT * FROM " + targetTable, "VALUES ('Dave', 'dates'), ('Carol_Craig', 'candles'), ('Joe', 'jellybeans')"); + + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + + @Test + public void testMergeMultipleOperations() + { + skipTestUnless(hasBehavior(SUPPORTS_MERGE)); + + int targetCustomerCount = 32; + String targetTable = "merge_multiple_" + randomTableSuffix(); + assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, zipcode INT, spouse VARCHAR, address VARCHAR)", targetTable))); + + String originalInsertFirstHalf = IntStream.range(1, targetCustomerCount / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 1000, 91000, intValue, intValue)) + .collect(Collectors.joining(", ")); + String originalInsertSecondHalf = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 2000, 92000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + assertUpdate(format("INSERT INTO %s (customer, purchases, zipcode, spouse, address) VALUES %s, %s", targetTable, originalInsertFirstHalf, originalInsertSecondHalf), targetCustomerCount - 1); + + String firstMergeSource = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 83000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + assertUpdate(format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchases, zipcode, spouse, address)", targetTable, firstMergeSource) + + " ON t.customer = s.customer" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases, zipcode = s.zipcode, spouse = s.spouse, address = s.address", + targetCustomerCount / 2); + + assertQuery( + "SELECT customer, purchases, zipcode, spouse, 
address FROM " + targetTable, + format("VALUES %s, %s", originalInsertFirstHalf, firstMergeSource)); + + String nextInsert = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('jack_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 4000, 74000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + assertUpdate(format("INSERT INTO %s (customer, purchases, zipcode, spouse, address) VALUES %s", targetTable, nextInsert), targetCustomerCount / 2); + + String secondMergeSource = IntStream.range(1, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + assertUpdate(format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchases, zipcode, spouse, address)", targetTable, secondMergeSource) + + " ON t.customer = s.customer" + + " WHEN MATCHED AND t.zipcode = 91000 THEN DELETE" + + " WHEN MATCHED AND s.zipcode = 85000 THEN UPDATE SET zipcode = 60000" + + " WHEN MATCHED THEN UPDATE SET zipcode = s.zipcode, spouse = s.spouse, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, zipcode, spouse, address) VALUES(s.customer, s.purchases, s.zipcode, s.spouse, s.address)", + targetCustomerCount * 3 / 2 - 1); + + String updatedBeginning = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 60000, intValue, intValue)) + .collect(Collectors.joining(", ")); + String updatedMiddle = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue)) + .collect(Collectors.joining(", ")); + String updatedEnd = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('jack_%s', %s, %s, 'jan_%s', '%s Poe Ct')", 
intValue, 4000, 74000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + assertQuery( + "SELECT customer, purchases, zipcode, spouse, address FROM " + targetTable, + format("VALUES %s, %s, %s", updatedBeginning, updatedMiddle, updatedEnd)); + + assertUpdate("DROP TABLE " + targetTable); + } + + @Test + public void testMergeSimpleQuery() + { + skipTestUnless(hasBehavior(SUPPORTS_MERGE)); + + String targetTable = "merge_query_" + randomTableSuffix(); + assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + + assertUpdate(format("MERGE INTO %s t USING ", targetTable) + + "(VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville')) AS s(customer, purchases, address)" + + " ON (t.customer = s.customer)" + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)", + 4); + + assertQuery("SELECT * FROM " + targetTable, "VALUES ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Dave', 22, 'Darbyshire'), ('Ed', 7, 'Etherville')"); + + assertUpdate("DROP TABLE " + targetTable); + } + + @Test + public void testMergeAllInserts() + { + skipTestUnless(hasBehavior(SUPPORTS_MERGE)); + + String targetTable = "merge_inserts_" + randomTableSuffix(); + assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 11, 'Antioch'), ('Bill', 7, 'Buena')", targetTable), 2); + + assertUpdate(format("MERGE INTO 
%s t USING ", targetTable) + + "(VALUES ('Carol', 9, 'Centreville'), ('Dave', 22, 'Darbyshire')) AS s(customer, purchases, address)" + + " ON (t.customer = s.customer)" + + " WHEN NOT MATCHED THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)", + 2); + + assertQuery("SELECT * FROM " + targetTable, "VALUES ('Aaron', 11, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 9, 'Centreville'), ('Dave', 22, 'Darbyshire')"); + + assertUpdate("DROP TABLE " + targetTable); + } + + @Test + public void testMergeAllColumnsUpdated() + { + skipTestUnless(hasBehavior(SUPPORTS_MERGE)); + + String targetTable = "merge_all_columns_updated_target_" + randomTableSuffix(); + String sourceTable = "merge_all_columns_updated_source_" + randomTableSuffix(); + assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Devon'), ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge')", targetTable), 4); + + assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable))); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Darbyshire'), ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Ed', 7, 'Etherville')", sourceTable), 4); + + assertUpdate(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED THEN UPDATE SET customer = CONCAT(t.customer, '_updated'), purchases = s.purchases + t.purchases, address = s.address", + 3); + + assertQuery("SELECT * FROM " + targetTable, "VALUES ('Dave_updated', 22, 'Darbyshire'), ('Aaron_updated', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Carol_updated', 12, 'Centreville')"); + + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + + @Test + public void 
testMergeAllMatchesDeleted() + { + skipTestUnless(hasBehavior(SUPPORTS_MERGE)); + + String targetTable = "merge_all_matches_deleted_target_" + randomTableSuffix(); + String sourceTable = "merge_all_matches_deleted_source_" + randomTableSuffix(); + assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + + assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable))); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville')", sourceTable), 4); + + assertUpdate(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED THEN DELETE", + 3); + + assertQuery("SELECT * FROM " + targetTable, "VALUES ('Bill', 7, 'Buena')"); + + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + + @Test + public void testMergeMultipleRowsMatchFails() + { + skipTestUnless(hasBehavior(SUPPORTS_MERGE)); + + String targetTable = "merge_multiple_fail_target_" + randomTableSuffix(); + String sourceTable = "merge_multiple_fail_source_" + randomTableSuffix(); + + assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Antioch')", targetTable), 2); + + assertUpdate(createTableForWrites(format("CREATE TABLE %s (id INT, customer VARCHAR, purchases INT, address VARCHAR)", sourceTable))); + + assertUpdate(format("INSERT INTO %s (id, customer, purchases, address) VALUES (1, 
'Aaron', 6, 'Adelphi'), (2, 'Aaron', 8, 'Ashland')", sourceTable), 2); + + assertQueryFails(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED THEN UPDATE SET address = s.address", + "One MERGE target table row matched more than one source row"); + + assertUpdate(format("MERGE INTO %s t USING %s s ON (t.customer = s.customer)", targetTable, sourceTable) + + " WHEN MATCHED AND s.address = 'Adelphi' THEN UPDATE SET address = s.address", + 1); + assertQuery("SELECT customer, purchases, address FROM " + targetTable, "VALUES ('Aaron', 5, 'Adelphi'), ('Bill', 7, 'Antioch')"); + + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + + @Test + public void testMergeQueryWithStrangeCapitalization() + { + skipTestUnless(hasBehavior(SUPPORTS_MERGE)); + + String targetTable = "merge_strange_capitalization_" + randomTableSuffix(); + assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + + assertUpdate(format("MERGE INTO %s t USING ", targetTable.toUpperCase(ENGLISH)) + + "(VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville')) AS s(customer, purchases, address)" + + "ON (t.customer = s.customer)" + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purCHases = s.PurchaseS + t.pUrchases, aDDress = s.addrESs" + + " WHEN NOT MATCHED THEN INSERT (CUSTOMER, purchases, addRESS) VALUES(s.custoMer, s.Purchases, s.ADDress)", + 4); + + assertQuery("SELECT * FROM " + targetTable, "VALUES ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Dave', 22, 'Darbyshire'), ('Ed', 7, 'Etherville')"); + + assertUpdate("DROP TABLE " + 
targetTable); + } + + @Test + public void testMergeWithoutTablesAliases() + { + skipTestUnless(hasBehavior(SUPPORTS_MERGE)); + + String targetTable = "test_without_aliases_target_" + randomTableSuffix(); + String sourceTable = "test_without_aliases_source_" + randomTableSuffix(); + assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + + assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable))); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable), 4); + + assertUpdate(format("MERGE INTO %s USING %s", targetTable, sourceTable) + + format(" ON (%s.customer = %s.customer)", targetTable, sourceTable) + + format(" WHEN MATCHED AND %s.address = 'Centreville' THEN DELETE", sourceTable) + + format(" WHEN MATCHED THEN UPDATE SET purchases = %s.pURCHases + %s.pUrchases, aDDress = %s.addrESs", sourceTable, targetTable, sourceTable) + + format(" WHEN NOT MATCHED THEN INSERT (cusTomer, purchases, addRESS) VALUES(%s.custoMer, %s.Purchases, %s.ADDress)", sourceTable, sourceTable, sourceTable), + 4); + + assertQuery("SELECT * FROM " + targetTable, "VALUES ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Dave', 22, 'Darbyshire'), ('Ed', 7, 'Etherville')"); + + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + + @Test + public void testMergeWithUnpredictablePredicates() + { + skipTestUnless(hasBehavior(SUPPORTS_MERGE)); + + String targetTable = "merge_predicates_target_" + randomTableSuffix(); + String sourceTable = "merge_predicates_source_" + randomTableSuffix(); + 
+ assertUpdate(createTableForWrites(format("CREATE TABLE %s (id INT, customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + + assertUpdate(format("INSERT INTO %s (id, customer, purchases, address) VALUES (1, 'Aaron', 5, 'Antioch'), (2, 'Bill', 7, 'Buena'), (3, 'Carol', 3, 'Cambridge'), (4, 'Dave', 11, 'Devon')", targetTable), 4); + + assertUpdate(createTableForWrites(format("CREATE TABLE %s (id INT, customer VARCHAR, purchases INT, address VARCHAR)", sourceTable))); + + assertUpdate(format("INSERT INTO %s (id, customer, purchases, address) VALUES (5, 'Aaron', 6, 'Arches'), (6, 'Carol', 9, 'Centreville'), (7, 'Dave', 11, 'Darbyshire'), (8, 'Ed', 7, 'Etherville')", sourceTable), 4); + + assertUpdate(format("MERGE INTO %s t USING %s s", targetTable, sourceTable) + + " ON t.customer = s.customer AND s.purchases < 10.2" + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (id, customer, purchases, address) VALUES (s.id, s.customer, s.purchases, s.address)", + 4); + + assertQuery("SELECT * FROM " + targetTable, "VALUES (1, 'Aaron', 11, 'Arches'), (2, 'Bill', 7, 'Buena'), (7, 'Dave', 11, 'Darbyshire'), (4, 'Dave', 11, 'Devon'), (8, 'Ed', 7, 'Etherville')"); + + assertUpdate(format("MERGE INTO %s t USING %s s", targetTable, sourceTable) + + " ON t.customer = s.customer" + + " WHEN MATCHED AND t.address <> 'Darbyshire' AND s.purchases * 2 > 20" + + " THEN DELETE" + + " WHEN MATCHED" + + " THEN UPDATE SET purchases = s.purchases + t.purchases, address = concat(t.address, '/', s.address)" + + " WHEN NOT MATCHED" + + " THEN INSERT (id, customer, purchases, address) VALUES (s.id, s.customer, s.purchases, s.address)", + 5); + + assertQuery( + "SELECT * FROM " + targetTable, + "VALUES (1, 'Aaron', 17, 'Arches/Arches'), (2, 'Bill', 7, 'Buena'), (6, 'Carol', 9, 'Centreville'), (7, 'Dave', 22, 'Darbyshire/Darbyshire'), (8, 'Ed', 
14, 'Etherville/Etherville')"); + + assertUpdate(format("INSERT INTO %s (id, customer, purchases, address) VALUES (9, 'Fred', 30, 'Franklin')", targetTable), 1); + assertQuery( + "SELECT * FROM " + targetTable, + "VALUES (1, 'Aaron', 17, 'Arches/Arches'), (2, 'Bill', 7, 'Buena'), (6, 'Carol', 9, 'Centreville'), (7, 'Dave', 22, 'Darbyshire/Darbyshire'), (8, 'Ed', 14, 'Etherville/Etherville'), (9, 'Fred', 30, 'Franklin')"); + + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + + @Test + public void testMergeWithSimplifiedUnpredictablePredicates() + { + skipTestUnless(hasBehavior(SUPPORTS_MERGE)); + + String targetTable = "merge_predicates_target_" + randomTableSuffix(); + String sourceTable = "merge_predicates_source_" + randomTableSuffix(); + + assertUpdate(createTableForWrites(format("CREATE TABLE %s (id INT, customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + + assertUpdate(format("INSERT INTO %s (id, customer, purchases, address) VALUES (1, 'Dave', 11, 'Devon'), (2, 'Dave', 11, 'Darbyshire')", targetTable), 2); + + assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable))); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Darbyshire')", sourceTable), 1); + + assertUpdate(format("MERGE INTO %s t USING %s s", targetTable, sourceTable) + + " ON t.customer = s.customer" + + " WHEN MATCHED AND t.address <> 'Darbyshire' AND s.purchases * 2 > 20" + + " THEN DELETE", + 1); + + assertQuery("SELECT * FROM " + targetTable, "VALUES (2, 'Dave', 11, 'Darbyshire')"); + + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + + @Test + public void testMergeCasts() + { + skipTestUnless(hasBehavior(SUPPORTS_MERGE)); + + String targetTable = "merge_cast_target_" + randomTableSuffix(); + String sourceTable = "merge_cast_source_" + randomTableSuffix(); + + 
assertUpdate(createTableForWrites(format("CREATE TABLE %s (col1 INT, col2 DOUBLE, col3 INT, col4 BIGINT, col5 REAL, col6 DOUBLE)", targetTable))); + + assertUpdate(format("INSERT INTO %s VALUES (1, 2, 3, 4, 5, 6)", targetTable), 1); + + assertUpdate(createTableForWrites(format("CREATE TABLE %s (col1 BIGINT, col2 REAL, col3 DOUBLE, col4 INT, col5 INT, col6 REAL)", sourceTable))); + + assertUpdate(format("INSERT INTO %s VALUES (2, 3, 4, 5, 6, 7)", sourceTable), 1); + + assertUpdate(format("MERGE INTO %s t USING %s s", targetTable, sourceTable) + + " ON (t.col1 + 1 = s.col1)" + + " WHEN MATCHED THEN UPDATE SET col1 = s.col1, col2 = s.col2, col3 = s.col3, col4 = s.col4, col5 = s.col5, col6 = s.col6", + 1); + + assertQuery("SELECT * FROM " + targetTable, "VALUES (2, 3.0, 4, 5, 6.0, 7.0)"); + + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + + @Test + public void testMergeSubqueries() + { + skipTestUnless(hasBehavior(SUPPORTS_MERGE)); + + String targetTable = "merge_nation_target_" + randomTableSuffix(); + String sourceTable = "merge_nation_source_" + randomTableSuffix(); + + assertUpdate(createTableForWrites(format("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR)", targetTable))); + + assertUpdate(format("INSERT INTO %s (nation_name, region_name) VALUES ('FRANCE', 'EUROPE'), ('ALGERIA', 'AFRICA'), ('GERMANY', 'EUROPE')", targetTable), 3); + + assertUpdate(createTableForWrites(format("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR)", sourceTable))); + + assertUpdate(format("INSERT INTO %s VALUES ('ALGERIA', 'AFRICA'), ('FRANCE', 'EUROPE'), ('EGYPT', 'MIDDLE EAST'), ('RUSSIA', 'EUROPE')", sourceTable), 4); + + assertUpdate(format("MERGE INTO %s t USING %s s", targetTable, sourceTable) + + " ON (t.nation_name = s.nation_name)" + + " WHEN MATCHED AND t.nation_name > (SELECT name FROM tpch.tiny.region WHERE name = t.region_name AND name LIKE ('A%'))" + + " THEN DELETE" + + " WHEN NOT MATCHED AND 
s.region_name = 'EUROPE'" + + " THEN INSERT VALUES(s.nation_name, (SELECT 'EUROPE'))", + 2); + + assertQuery("SELECT * FROM " + targetTable, "VALUES ('FRANCE', 'EUROPE'), ('GERMANY', 'EUROPE'), ('RUSSIA', 'EUROPE')"); + + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + private void verifyUnsupportedTypeException(Throwable exception, String trinoTypeName) { String typeNameBase = trinoTypeName.replaceFirst("\\(.*", ""); diff --git a/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java b/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java index 0684c84cc375..9a082b625770 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java @@ -84,6 +84,8 @@ public enum TestingConnectorBehavior SUPPORTS_UPDATE(false), + SUPPORTS_MERGE(false), + SUPPORTS_TRUNCATE(false), SUPPORTS_ARRAY,