diff --git a/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java b/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java index b02661233898..05396a4f1b3c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java +++ b/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java @@ -14,20 +14,15 @@ package io.trino.operator; import io.trino.spi.Page; -import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.ColumnarRow; import io.trino.spi.block.RunLengthEncodedBlock; import java.util.ArrayList; import java.util.List; -import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; -import static io.trino.spi.StandardErrorCode.CONSTRAINT_VIOLATION; import static io.trino.spi.block.ColumnarRow.toColumnarRow; -import static io.trino.spi.connector.ConnectorMergeSink.INSERT_OPERATION_NUMBER; -import static io.trino.spi.connector.ConnectorMergeSink.UPDATE_OPERATION_NUMBER; import static io.trino.spi.predicate.Utils.nativeValueToBlock; import static io.trino.spi.type.TinyintType.TINYINT; import static java.util.Objects.requireNonNull; @@ -47,24 +42,18 @@ public class ChangeOnlyUpdatedColumnsMergeProcessor private final int rowIdChannel; private final int mergeRowChannel; private final List dataColumnChannels; - private final List dataColumnNames; private final int writeRedistributionColumnCount; - private final Set nonNullColumnChannels; public ChangeOnlyUpdatedColumnsMergeProcessor( int rowIdChannel, int mergeRowChannel, List dataColumnChannels, - List dataColumnNames, - List redistributionColumnChannels, - Set nonNullColumnChannels) + List redistributionColumnChannels) { this.rowIdChannel = rowIdChannel; this.mergeRowChannel = mergeRowChannel; this.dataColumnChannels = requireNonNull(dataColumnChannels, "dataColumnChannels is null"); - this.dataColumnNames = requireNonNull(dataColumnNames, "dataColumnNames is null"); this.writeRedistributionColumnCount = redistributionColumnChannels.size(); - this.nonNullColumnChannels = requireNonNull(nonNullColumnChannels, "nonNullColumnChannels is null"); } @Override @@ -94,20 +83,6 @@ public Page transformPage(Page inputPage) Page result = new Page(builder.toArray(Block[]::new)); - for (int nonNullColumnChannel : nonNullColumnChannels) { - Block block = result.getBlock(nonNullColumnChannel); - if (block.mayHaveNull()) { - for (int position = 0; position < positionCount; position++) { - long operation = TINYINT.getLong(operationChannelBlock, position); - if (operation == INSERT_OPERATION_NUMBER || operation == UPDATE_OPERATION_NUMBER) { - if (block.isNull(position)) { - throw new TrinoException(CONSTRAINT_VIOLATION, "Assigning NULL to non-null MERGE target table column " + dataColumnNames.get(nonNullColumnChannel)); - } - } - } - } - } - int defaultCaseCount = 0; for (int position = 0; position < positionCount; position++) { if (TINYINT.getLong(operationChannelBlock, position) == DEFAULT_CASE_OPERATION_NUMBER) { diff --git a/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertMergeProcessor.java b/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertMergeProcessor.java index ceecf06a7b66..0d3b8dea0294 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertMergeProcessor.java +++ b/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertMergeProcessor.java @@ -15,21 +15,17 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import io.trino.spi.Page; import io.trino.spi.PageBuilder; -import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.ColumnarRow; import io.trino.spi.type.Type; import java.util.List; -import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; -import static io.trino.spi.StandardErrorCode.CONSTRAINT_VIOLATION; import static io.trino.spi.block.ColumnarRow.toColumnarRow; import static io.trino.spi.connector.ConnectorMergeSink.DELETE_OPERATION_NUMBER; import static io.trino.spi.connector.ConnectorMergeSink.INSERT_OPERATION_NUMBER; @@ -42,27 +38,22 @@ public class DeleteAndInsertMergeProcessor implements MergeRowChangeProcessor { private final List dataColumnTypes; - private final List dataColumnNames; private final Type rowIdType; private final int rowIdChannel; private final int mergeRowChannel; private final List dataColumnChannels; private final int redistributionColumnCount; private final List redistributionChannelNumbers; - private final Set nonNullColumnChannels; public DeleteAndInsertMergeProcessor( List dataColumnTypes, - List dataColumnNames, Type rowIdType, int rowIdChannel, int mergeRowChannel, List redistributionChannelNumbers, - List dataColumnChannels, - Set nonNullColumnChannels) + List dataColumnChannels) { this.dataColumnTypes = requireNonNull(dataColumnTypes, "dataColumnTypes is null"); - this.dataColumnNames = requireNonNull(dataColumnNames, "dataColumnNames is null"); this.rowIdType = requireNonNull(rowIdType, "rowIdType is null"); this.rowIdChannel = rowIdChannel; this.mergeRowChannel = mergeRowChannel; @@ -80,7 +71,6 @@ public DeleteAndInsertMergeProcessor( } } this.redistributionChannelNumbers = redistributionChannelNumbersBuilder.build(); - this.nonNullColumnChannels = ImmutableSet.copyOf(requireNonNull(nonNullColumnChannels, "nonNullColumnChannels is null")); } @JsonProperty @@ -150,18 +140,6 @@ public Page transformPage(Page inputPage) } Page page = pageBuilder.build(); - int positionCount = page.getPositionCount(); - for (int nonNullColumnChannel : nonNullColumnChannels) { - Block nonNullBlock = page.getBlock(nonNullColumnChannel); - Block operationBlock = page.getBlock(dataColumnChannels.size()); - if (nonNullBlock.mayHaveNull()) { - for (int position = 0; position < positionCount; position++) { - if (TINYINT.getLong(operationBlock, position) == INSERT_OPERATION_NUMBER && nonNullBlock.isNull(position)) { - throw new TrinoException(CONSTRAINT_VIOLATION, "Assigning NULL to non-null MERGE target table column " + dataColumnNames.get(nonNullColumnChannel)); - } - } - } - } verify(page.getPositionCount() == totalPositions, "page positions (%s) is not equal to (%s)", page.getPositionCount(), totalPositions); return page; } diff --git a/core/trino-main/src/main/java/io/trino/operator/MergeProcessorOperator.java b/core/trino-main/src/main/java/io/trino/operator/MergeProcessorOperator.java index 3b3bde9bb872..69843e4d12ed 100644 --- a/core/trino-main/src/main/java/io/trino/operator/MergeProcessorOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/MergeProcessorOperator.java @@ -18,7 +18,6 @@ import io.trino.sql.planner.plan.TableWriterNode.MergeParadigmAndTypes; import java.util.List; -import java.util.Set; import static com.google.common.base.Preconditions.checkState; import static io.trino.operator.BasicWorkProcessorOperatorAdapter.createAdapterOperatorFactory; @@ -40,29 +39,29 @@ public static OperatorFactory createOperatorFactory( MergeParadigmAndTypes merge, int rowIdChannel, int mergeRowChannel, - List redistributionColumnChannels, - List dataColumnChannels, - Set nonNullColumnChannels) + List redistributionColumns, + List dataColumnChannels) { - MergeRowChangeProcessor rowChangeProcessor = switch (merge.getParadigm()) { + MergeRowChangeProcessor rowChangeProcessor = createRowChangeProcessor(merge, rowIdChannel, mergeRowChannel, redistributionColumns, dataColumnChannels); + return createAdapterOperatorFactory(new Factory(operatorId, planNodeId, rowChangeProcessor)); + } + + private static MergeRowChangeProcessor createRowChangeProcessor(MergeParadigmAndTypes merge, int rowIdChannel, int mergeRowChannel, List redistributionColumnChannels, List dataColumnChannels) + { + return switch (merge.getParadigm()) { case DELETE_ROW_AND_INSERT_ROW -> new DeleteAndInsertMergeProcessor( merge.getColumnTypes(), - merge.getColumnNames(), merge.getRowIdType(), rowIdChannel, mergeRowChannel, redistributionColumnChannels, - dataColumnChannels, - nonNullColumnChannels); + dataColumnChannels); case CHANGE_ONLY_UPDATED_COLUMNS -> new ChangeOnlyUpdatedColumnsMergeProcessor( rowIdChannel, mergeRowChannel, dataColumnChannels, - merge.getColumnNames(), - redistributionColumnChannels, - nonNullColumnChannels); + redistributionColumnChannels); }; - return createAdapterOperatorFactory(new Factory(operatorId, planNodeId, rowChangeProcessor)); } public static class Factory diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index ff87111d023d..a902da9fafa6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -3444,9 +3444,6 @@ public PhysicalOperation visitMergeProcessor(MergeProcessorNode node, LocalExecu List dataColumnChannels = node.getDataColumnSymbols().stream() .map(nodeLayout::get) .collect(toImmutableList()); - Set nonNullColumnChannels = node.getNonNullColumnSymbols().stream() - .map(nodeLayout::get) - .collect(toImmutableSet()); OperatorFactory operatorFactory = MergeProcessorOperator.createOperatorFactory( context.getNextOperatorId(), @@ -3455,8 +3452,7 @@ public PhysicalOperation visitMergeProcessor(MergeProcessorNode node, LocalExecu rowIdChannel, mergeRowChannel, redistributionColumns, - dataColumnChannels, - nonNullColumnChannels); + dataColumnChannels); return new PhysicalOperation(operatorFactory, nodeLayout, context, source); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index 7cfafd485c61..d8be2a47d85b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -677,8 +677,10 @@ public MergeWriterNode plan(Merge merge) PlanBuilder subPlan = newPlanBuilder(joinPlan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext); // Build the SearchedCaseExpression that creates the project merge_row - + Metadata metadata = plannerContext.getMetadata(); + List dataColumnSchemas = mergeAnalysis.getDataColumnSchemas(); ImmutableList.Builder whenClauses = ImmutableList.builder(); + Set nonNullableColumnHandles = mergeAnalysis.getNonNullableColumnHandles(); for (int caseNumber = 0; caseNumber < merge.getMergeCases().size(); caseNumber++) { MergeCase mergeCase = merge.getMergeCases().get(caseNumber); @@ -698,7 +700,14 @@ public MergeWriterNode plan(Merge merge) Expression original = mergeCase.getSetExpressions().get(index); Expression setExpression = coerceIfNecessary(analysis, original, original); subPlan = subqueryPlanner.handleSubqueries(subPlan, setExpression, analysis.getSubqueries(merge)); - rowBuilder.add(subPlan.rewrite(setExpression)); + Expression rewritten = subPlan.rewrite(setExpression); + if (nonNullableColumnHandles.contains(dataColumnHandle)) { + int fieldIndex = requireNonNull(mergeAnalysis.getColumnHandleFieldNumbers().get(dataColumnHandle), "Could not find fieldIndex for non nullable column"); + ColumnSchema columnSchema = dataColumnSchemas.get(fieldIndex); + String columnName = columnSchema.getName(); + rewritten = new CoalesceExpression(rewritten, new Cast(failFunction(metadata, session, INVALID_ARGUMENTS, "Assigning NULL to non-null MERGE target table column " + columnName), toSqlType(columnSchema.getType()))); + } + rowBuilder.add(rewritten); } else { Integer fieldNumber = requireNonNull(mergeAnalysis.getColumnHandleFieldNumbers().get(dataColumnHandle), "Field number for ColumnHandle is null"); @@ -732,7 +741,7 @@ public MergeWriterNode plan(Merge merge) // Build the "else" clause for the SearchedCaseExpression ImmutableList.Builder rowBuilder = ImmutableList.builder(); - mergeAnalysis.getDataColumnSchemas().forEach(columnSchema -> + dataColumnSchemas.forEach(columnSchema -> rowBuilder.add(new Cast(new NullLiteral(), toSqlType(columnSchema.getType())))); rowBuilder.add(new IsNotNullPredicate(presentColumn.toSymbolReference())); // The operation number @@ -741,7 +750,7 @@ public MergeWriterNode plan(Merge merge) rowBuilder.add(new GenericLiteral("INTEGER", "-1")); SearchedCaseExpression caseExpression = new SearchedCaseExpression(whenClauses.build(), Optional.of(new Row(rowBuilder.build()))); - RowType rowType = createMergeRowType(mergeAnalysis.getDataColumnSchemas()); + RowType rowType = createMergeRowType(dataColumnSchemas); FieldReference rowIdReference = analysis.getRowIdField(mergeAnalysis.getTargetTable()); Symbol rowIdSymbol = planWithPresentColumn.getFieldMappings().get(rowIdReference.getFieldIndex()); @@ -778,7 +787,6 @@ public MergeWriterNode plan(Merge merge) MarkDistinctNode markDistinctNode = new MarkDistinctNode(idAllocator.getNextId(), project, isDistinctSymbol, ImmutableList.of(uniqueIdSymbol, caseNumberSymbol), Optional.empty()); // Raise an error if unique_id symbol is non-null and the unique_id/case_number combination was not distinct - Metadata metadata = plannerContext.getMetadata(); Expression filter = new IfExpression( LogicalExpression.and( new NotExpression(isDistinctSymbol.toSymbolReference()), @@ -796,26 +804,18 @@ public MergeWriterNode plan(Merge merge) RowChangeParadigm paradigm = metadata.getRowChangeParadigm(session, handle); Type rowIdType = analysis.getType(analysis.getRowIdField(table)); - ImmutableList.Builder typeBuilder = ImmutableList.builder(); - ImmutableList.Builder columnNamesBuilder = ImmutableList.builder(); - tableMetadata.getMetadata().getColumns().stream() + List dataColumnTypes = tableMetadata.getMetadata().getColumns().stream() .filter(column -> !column.isHidden()) - .forEach(columnMetadata -> { - typeBuilder.add(columnMetadata.getType()); - columnNamesBuilder.add(columnMetadata.getName()); - }); - MergeParadigmAndTypes mergeParadigmAndTypes = new MergeParadigmAndTypes(paradigm, typeBuilder.build(), columnNamesBuilder.build(), rowIdType); + .map(ColumnMetadata::getType) + .collect(toImmutableList()); + + MergeParadigmAndTypes mergeParadigmAndTypes = new MergeParadigmAndTypes(paradigm, dataColumnTypes, rowIdType); MergeTarget mergeTarget = new MergeTarget(handle, Optional.empty(), tableMetadata.getTable(), mergeParadigmAndTypes); ImmutableList.Builder columnSymbolsBuilder = ImmutableList.builder(); - ImmutableList.Builder nonNullColumnSymbolsBuilder = ImmutableList.builder(); for (ColumnHandle columnHandle : mergeAnalysis.getDataColumnHandles()) { int fieldIndex = requireNonNull(mergeAnalysis.getColumnHandleFieldNumbers().get(columnHandle), "Could not find field number for column handle"); - Symbol symbol = planWithPresentColumn.getFieldMappings().get(fieldIndex); - columnSymbolsBuilder.add(symbol); - if (mergeAnalysis.getNonNullableColumnHandles().contains(columnHandle)) { - nonNullColumnSymbolsBuilder.add(symbol); - } + columnSymbolsBuilder.add(planWithPresentColumn.getFieldMappings().get(fieldIndex)); } List columnSymbols = columnSymbolsBuilder.build(); ImmutableList.Builder redistributionSymbolsBuilder = ImmutableList.builder(); @@ -842,7 +842,6 @@ public MergeWriterNode plan(Merge merge) mergeRowSymbol, columnSymbols, redistributionSymbolsBuilder.build(), - nonNullColumnSymbolsBuilder.build(), projectedSymbols); Optional partitioningScheme = createMergePartitioningScheme( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java index 2431da63b85f..395217bef8a4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java @@ -471,7 +471,6 @@ public MergeProcessorNode map(MergeProcessorNode node, PlanNode source) map(node.getMergeRowSymbol()), map(node.getDataColumnSymbols()), map(node.getRedistributionColumnSymbols()), - map(node.getNonNullColumnSymbols()), newOutputs); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/MergeProcessorNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/MergeProcessorNode.java index 7382a3d25faa..4860632ae388 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/MergeProcessorNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/MergeProcessorNode.java @@ -37,7 +37,6 @@ public class MergeProcessorNode private final Symbol mergeRowSymbol; private final List dataColumnSymbols; private final List redistributionColumnSymbols; - private final List nonNullColumnSymbols; private final List outputs; @JsonCreator @@ -49,7 +48,6 @@ public MergeProcessorNode( @JsonProperty("mergeRowSymbol") Symbol mergeRowSymbol, @JsonProperty("dataColumnSymbols") List dataColumnSymbols, @JsonProperty("redistributionColumnSymbols") List redistributionColumnSymbols, - @JsonProperty("nonNullColumnSymbols") List nonNullColumnSymbols, @JsonProperty("outputs") List outputs) { super(id); @@ -60,7 +58,6 @@ public MergeProcessorNode( this.rowIdSymbol = requireNonNull(rowIdSymbol, "rowIdSymbol is null"); this.dataColumnSymbols = requireNonNull(dataColumnSymbols, "dataColumnSymbols is null"); this.redistributionColumnSymbols = requireNonNull(redistributionColumnSymbols, "redistributionColumnSymbols is null"); - this.nonNullColumnSymbols = ImmutableList.copyOf(requireNonNull(nonNullColumnSymbols, "nonNullColumnSymbols is null")); this.outputs = ImmutableList.copyOf(requireNonNull(outputs, "outputs is null")); } @@ -100,12 +97,6 @@ public List getRedistributionColumnSymbols() return redistributionColumnSymbols; } - @JsonProperty - public List getNonNullColumnSymbols() - { - return nonNullColumnSymbols; - } - @JsonProperty("outputs") @Override public List getOutputSymbols() @@ -128,6 +119,6 @@ public R accept(PlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new MergeProcessorNode(getId(), Iterables.getOnlyElement(newChildren), target, rowIdSymbol, mergeRowSymbol, dataColumnSymbols, redistributionColumnSymbols, nonNullColumnSymbols, outputs); + return new MergeProcessorNode(getId(), Iterables.getOnlyElement(newChildren), target, rowIdSymbol, mergeRowSymbol, dataColumnSymbols, redistributionColumnSymbols, outputs); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java index 957e762e75fc..9aaa921b56a5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java @@ -726,19 +726,16 @@ public static class MergeParadigmAndTypes { private final RowChangeParadigm paradigm; private final List columnTypes; - private final List columnNames; private final Type rowIdType; @JsonCreator public MergeParadigmAndTypes( @JsonProperty("paradigm") RowChangeParadigm paradigm, @JsonProperty("columnTypes") List columnTypes, - @JsonProperty("columnNames") List columnNames, @JsonProperty("rowIdType") Type rowIdType) { this.paradigm = requireNonNull(paradigm, "paradigm is null"); this.columnTypes = requireNonNull(columnTypes, "columnTypes is null"); - this.columnNames = requireNonNull(columnNames, "columnNames is null"); this.rowIdType = requireNonNull(rowIdType, "rowIdType is null"); } @@ -754,12 +751,6 @@ public List getColumnTypes() return columnTypes; } - @JsonProperty - public List getColumnNames() - { - return columnNames; - } - @JsonProperty public Type getRowIdType() { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java index dd5a6c7c891f..31ae14eb6c2e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java @@ -14,7 +14,6 @@ package io.trino.sql.planner; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slices; import io.trino.operator.DeleteAndInsertMergeProcessor; import io.trino.spi.Page; @@ -184,7 +183,7 @@ private DeleteAndInsertMergeProcessor makeMergeProcessor() List types = ImmutableList.of(VARCHAR, INTEGER, VARCHAR); RowType rowIdType = RowType.anonymous(ImmutableList.of(BIGINT, BIGINT, INTEGER)); - return new DeleteAndInsertMergeProcessor(types, ImmutableList.of(), rowIdType, 0, 1, ImmutableList.of(), ImmutableList.of(0, 1, 2), ImmutableSet.of()); + return new DeleteAndInsertMergeProcessor(types, rowIdType, 0, 1, ImmutableList.of(), ImmutableList.of(0, 1, 2)); } private String getString(Block block, int position) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java index b0552e9b973c..bdb1e2a1642c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java @@ -835,7 +835,7 @@ private MergeTarget mergeTarget(SchemaTableName schemaTableName) TestingTransactionHandle.create()), Optional.empty(), schemaTableName, - new MergeParadigmAndTypes(RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW, ImmutableList.of(), ImmutableList.of(), INTEGER)); + new MergeParadigmAndTypes(RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW, ImmutableList.of(), INTEGER)); } public ExchangeNode gatheringExchange(ExchangeNode.Scope scope, PlanNode child)