Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,15 @@
package io.trino.operator;

import io.trino.spi.Page;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.ColumnarRow;
import io.trino.spi.block.RunLengthEncodedBlock;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.spi.StandardErrorCode.CONSTRAINT_VIOLATION;
import static io.trino.spi.block.ColumnarRow.toColumnarRow;
import static io.trino.spi.connector.ConnectorMergeSink.INSERT_OPERATION_NUMBER;
import static io.trino.spi.connector.ConnectorMergeSink.UPDATE_OPERATION_NUMBER;
import static io.trino.spi.predicate.Utils.nativeValueToBlock;
import static io.trino.spi.type.TinyintType.TINYINT;
import static java.util.Objects.requireNonNull;
Expand All @@ -47,24 +42,18 @@ public class ChangeOnlyUpdatedColumnsMergeProcessor
private final int rowIdChannel;
private final int mergeRowChannel;
private final List<Integer> dataColumnChannels;
private final List<String> dataColumnNames;
private final int writeRedistributionColumnCount;
private final Set<Integer> nonNullColumnChannels;

public ChangeOnlyUpdatedColumnsMergeProcessor(
int rowIdChannel,
int mergeRowChannel,
List<Integer> dataColumnChannels,
List<String> dataColumnNames,
List<Integer> redistributionColumnChannels,
Set<Integer> nonNullColumnChannels)
List<Integer> redistributionColumnChannels)
{
this.rowIdChannel = rowIdChannel;
this.mergeRowChannel = mergeRowChannel;
this.dataColumnChannels = requireNonNull(dataColumnChannels, "dataColumnChannels is null");
this.dataColumnNames = requireNonNull(dataColumnNames, "dataColumnNames is null");
this.writeRedistributionColumnCount = redistributionColumnChannels.size();
this.nonNullColumnChannels = requireNonNull(nonNullColumnChannels, "nonNullColumnChannels is null");
}

@Override
Expand Down Expand Up @@ -94,20 +83,6 @@ public Page transformPage(Page inputPage)

Page result = new Page(builder.toArray(Block[]::new));

for (int nonNullColumnChannel : nonNullColumnChannels) {
Block block = result.getBlock(nonNullColumnChannel);
if (block.mayHaveNull()) {
for (int position = 0; position < positionCount; position++) {
long operation = TINYINT.getLong(operationChannelBlock, position);
if (operation == INSERT_OPERATION_NUMBER || operation == UPDATE_OPERATION_NUMBER) {
if (block.isNull(position)) {
throw new TrinoException(CONSTRAINT_VIOLATION, "Assigning NULL to non-null MERGE target table column " + dataColumnNames.get(nonNullColumnChannel));
}
}
}
}
}

int defaultCaseCount = 0;
for (int position = 0; position < positionCount; position++) {
if (TINYINT.getLong(operationChannelBlock, position) == DEFAULT_CASE_OPERATION_NUMBER) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,17 @@

import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.ColumnarRow;
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static io.trino.spi.StandardErrorCode.CONSTRAINT_VIOLATION;
import static io.trino.spi.block.ColumnarRow.toColumnarRow;
import static io.trino.spi.connector.ConnectorMergeSink.DELETE_OPERATION_NUMBER;
import static io.trino.spi.connector.ConnectorMergeSink.INSERT_OPERATION_NUMBER;
Expand All @@ -42,27 +38,22 @@ public class DeleteAndInsertMergeProcessor
implements MergeRowChangeProcessor
{
private final List<Type> dataColumnTypes;
private final List<String> dataColumnNames;
private final Type rowIdType;
private final int rowIdChannel;
private final int mergeRowChannel;
private final List<Integer> dataColumnChannels;
private final int redistributionColumnCount;
private final List<Integer> redistributionChannelNumbers;
private final Set<Integer> nonNullColumnChannels;

public DeleteAndInsertMergeProcessor(
List<Type> dataColumnTypes,
List<String> dataColumnNames,
Type rowIdType,
int rowIdChannel,
int mergeRowChannel,
List<Integer> redistributionChannelNumbers,
List<Integer> dataColumnChannels,
Set<Integer> nonNullColumnChannels)
List<Integer> dataColumnChannels)
{
this.dataColumnTypes = requireNonNull(dataColumnTypes, "dataColumnTypes is null");
this.dataColumnNames = requireNonNull(dataColumnNames, "dataColumnNames is null");
this.rowIdType = requireNonNull(rowIdType, "rowIdType is null");
this.rowIdChannel = rowIdChannel;
this.mergeRowChannel = mergeRowChannel;
Expand All @@ -80,7 +71,6 @@ public DeleteAndInsertMergeProcessor(
}
}
this.redistributionChannelNumbers = redistributionChannelNumbersBuilder.build();
this.nonNullColumnChannels = ImmutableSet.copyOf(requireNonNull(nonNullColumnChannels, "nonNullColumnChannels is null"));
}

@JsonProperty
Expand Down Expand Up @@ -150,18 +140,6 @@ public Page transformPage(Page inputPage)
}

Page page = pageBuilder.build();
int positionCount = page.getPositionCount();
for (int nonNullColumnChannel : nonNullColumnChannels) {
Block nonNullBlock = page.getBlock(nonNullColumnChannel);
Block operationBlock = page.getBlock(dataColumnChannels.size());
if (nonNullBlock.mayHaveNull()) {
for (int position = 0; position < positionCount; position++) {
if (TINYINT.getLong(operationBlock, position) == INSERT_OPERATION_NUMBER && nonNullBlock.isNull(position)) {
throw new TrinoException(CONSTRAINT_VIOLATION, "Assigning NULL to non-null MERGE target table column " + dataColumnNames.get(nonNullColumnChannel));
}
}
}
}
verify(page.getPositionCount() == totalPositions, "page positions (%s) is not equal to (%s)", page.getPositionCount(), totalPositions);
return page;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import io.trino.sql.planner.plan.TableWriterNode.MergeParadigmAndTypes;

import java.util.List;
import java.util.Set;

import static com.google.common.base.Preconditions.checkState;
import static io.trino.operator.BasicWorkProcessorOperatorAdapter.createAdapterOperatorFactory;
Expand All @@ -40,29 +39,29 @@ public static OperatorFactory createOperatorFactory(
MergeParadigmAndTypes merge,
int rowIdChannel,
int mergeRowChannel,
List<Integer> redistributionColumnChannels,
List<Integer> dataColumnChannels,
Set<Integer> nonNullColumnChannels)
List<Integer> redistributionColumns,
List<Integer> dataColumnChannels)
{
MergeRowChangeProcessor rowChangeProcessor = switch (merge.getParadigm()) {
MergeRowChangeProcessor rowChangeProcessor = createRowChangeProcessor(merge, rowIdChannel, mergeRowChannel, redistributionColumns, dataColumnChannels);
return createAdapterOperatorFactory(new Factory(operatorId, planNodeId, rowChangeProcessor));
}

private static MergeRowChangeProcessor createRowChangeProcessor(MergeParadigmAndTypes merge, int rowIdChannel, int mergeRowChannel, List<Integer> redistributionColumnChannels, List<Integer> dataColumnChannels)
{
return switch (merge.getParadigm()) {
case DELETE_ROW_AND_INSERT_ROW -> new DeleteAndInsertMergeProcessor(
merge.getColumnTypes(),
merge.getColumnNames(),
merge.getRowIdType(),
rowIdChannel,
mergeRowChannel,
redistributionColumnChannels,
dataColumnChannels,
nonNullColumnChannels);
dataColumnChannels);
case CHANGE_ONLY_UPDATED_COLUMNS -> new ChangeOnlyUpdatedColumnsMergeProcessor(
rowIdChannel,
mergeRowChannel,
dataColumnChannels,
merge.getColumnNames(),
redistributionColumnChannels,
nonNullColumnChannels);
redistributionColumnChannels);
};
return createAdapterOperatorFactory(new Factory(operatorId, planNodeId, rowChangeProcessor));
}

public static class Factory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3444,9 +3444,6 @@ public PhysicalOperation visitMergeProcessor(MergeProcessorNode node, LocalExecu
List<Integer> dataColumnChannels = node.getDataColumnSymbols().stream()
.map(nodeLayout::get)
.collect(toImmutableList());
Set<Integer> nonNullColumnChannels = node.getNonNullColumnSymbols().stream()
.map(nodeLayout::get)
.collect(toImmutableSet());

OperatorFactory operatorFactory = MergeProcessorOperator.createOperatorFactory(
context.getNextOperatorId(),
Expand All @@ -3455,8 +3452,7 @@ public PhysicalOperation visitMergeProcessor(MergeProcessorNode node, LocalExecu
rowIdChannel,
mergeRowChannel,
redistributionColumns,
dataColumnChannels,
nonNullColumnChannels);
dataColumnChannels);
return new PhysicalOperation(operatorFactory, nodeLayout, context, source);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -677,8 +677,10 @@ public MergeWriterNode plan(Merge merge)
PlanBuilder subPlan = newPlanBuilder(joinPlan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext);

// Build the SearchedCaseExpression that creates the project merge_row

Metadata metadata = plannerContext.getMetadata();
List<ColumnSchema> dataColumnSchemas = mergeAnalysis.getDataColumnSchemas();
ImmutableList.Builder<WhenClause> whenClauses = ImmutableList.builder();
Set<ColumnHandle> nonNullableColumnHandles = mergeAnalysis.getNonNullableColumnHandles();
for (int caseNumber = 0; caseNumber < merge.getMergeCases().size(); caseNumber++) {
MergeCase mergeCase = merge.getMergeCases().get(caseNumber);

Expand All @@ -698,7 +700,14 @@ public MergeWriterNode plan(Merge merge)
Expression original = mergeCase.getSetExpressions().get(index);
Expression setExpression = coerceIfNecessary(analysis, original, original);
subPlan = subqueryPlanner.handleSubqueries(subPlan, setExpression, analysis.getSubqueries(merge));
rowBuilder.add(subPlan.rewrite(setExpression));
Expression rewritten = subPlan.rewrite(setExpression);
if (nonNullableColumnHandles.contains(dataColumnHandle)) {
int fieldIndex = requireNonNull(mergeAnalysis.getColumnHandleFieldNumbers().get(dataColumnHandle), "Could not find fieldIndex for non nullable column");
ColumnSchema columnSchema = dataColumnSchemas.get(fieldIndex);
String columnName = columnSchema.getName();
rewritten = new CoalesceExpression(rewritten, new Cast(failFunction(metadata, session, INVALID_ARGUMENTS, "Assigning NULL to non-null MERGE target table column " + columnName), toSqlType(columnSchema.getType())));
}
rowBuilder.add(rewritten);
}
else {
Integer fieldNumber = requireNonNull(mergeAnalysis.getColumnHandleFieldNumbers().get(dataColumnHandle), "Field number for ColumnHandle is null");
Expand Down Expand Up @@ -732,7 +741,7 @@ public MergeWriterNode plan(Merge merge)

// Build the "else" clause for the SearchedCaseExpression
ImmutableList.Builder<Expression> rowBuilder = ImmutableList.builder();
mergeAnalysis.getDataColumnSchemas().forEach(columnSchema ->
dataColumnSchemas.forEach(columnSchema ->
rowBuilder.add(new Cast(new NullLiteral(), toSqlType(columnSchema.getType()))));
rowBuilder.add(new IsNotNullPredicate(presentColumn.toSymbolReference()));
// The operation number
Expand All @@ -741,7 +750,7 @@ public MergeWriterNode plan(Merge merge)
rowBuilder.add(new GenericLiteral("INTEGER", "-1"));

SearchedCaseExpression caseExpression = new SearchedCaseExpression(whenClauses.build(), Optional.of(new Row(rowBuilder.build())));
RowType rowType = createMergeRowType(mergeAnalysis.getDataColumnSchemas());
RowType rowType = createMergeRowType(dataColumnSchemas);

FieldReference rowIdReference = analysis.getRowIdField(mergeAnalysis.getTargetTable());
Symbol rowIdSymbol = planWithPresentColumn.getFieldMappings().get(rowIdReference.getFieldIndex());
Expand Down Expand Up @@ -778,7 +787,6 @@ public MergeWriterNode plan(Merge merge)
MarkDistinctNode markDistinctNode = new MarkDistinctNode(idAllocator.getNextId(), project, isDistinctSymbol, ImmutableList.of(uniqueIdSymbol, caseNumberSymbol), Optional.empty());

// Raise an error if unique_id symbol is non-null and the unique_id/case_number combination was not distinct
Metadata metadata = plannerContext.getMetadata();
Expression filter = new IfExpression(
LogicalExpression.and(
new NotExpression(isDistinctSymbol.toSymbolReference()),
Expand All @@ -796,26 +804,18 @@ public MergeWriterNode plan(Merge merge)

RowChangeParadigm paradigm = metadata.getRowChangeParadigm(session, handle);
Type rowIdType = analysis.getType(analysis.getRowIdField(table));
ImmutableList.Builder<Type> typeBuilder = ImmutableList.builder();
ImmutableList.Builder<String> columnNamesBuilder = ImmutableList.builder();
tableMetadata.getMetadata().getColumns().stream()
List<Type> dataColumnTypes = tableMetadata.getMetadata().getColumns().stream()
.filter(column -> !column.isHidden())
.forEach(columnMetadata -> {
typeBuilder.add(columnMetadata.getType());
columnNamesBuilder.add(columnMetadata.getName());
});
MergeParadigmAndTypes mergeParadigmAndTypes = new MergeParadigmAndTypes(paradigm, typeBuilder.build(), columnNamesBuilder.build(), rowIdType);
.map(ColumnMetadata::getType)
.collect(toImmutableList());

MergeParadigmAndTypes mergeParadigmAndTypes = new MergeParadigmAndTypes(paradigm, dataColumnTypes, rowIdType);
MergeTarget mergeTarget = new MergeTarget(handle, Optional.empty(), tableMetadata.getTable(), mergeParadigmAndTypes);

ImmutableList.Builder<Symbol> columnSymbolsBuilder = ImmutableList.builder();
ImmutableList.Builder<Symbol> nonNullColumnSymbolsBuilder = ImmutableList.builder();
for (ColumnHandle columnHandle : mergeAnalysis.getDataColumnHandles()) {
int fieldIndex = requireNonNull(mergeAnalysis.getColumnHandleFieldNumbers().get(columnHandle), "Could not find field number for column handle");
Symbol symbol = planWithPresentColumn.getFieldMappings().get(fieldIndex);
columnSymbolsBuilder.add(symbol);
if (mergeAnalysis.getNonNullableColumnHandles().contains(columnHandle)) {
nonNullColumnSymbolsBuilder.add(symbol);
}
columnSymbolsBuilder.add(planWithPresentColumn.getFieldMappings().get(fieldIndex));
}
List<Symbol> columnSymbols = columnSymbolsBuilder.build();
ImmutableList.Builder<Symbol> redistributionSymbolsBuilder = ImmutableList.builder();
Expand All @@ -842,7 +842,6 @@ public MergeWriterNode plan(Merge merge)
mergeRowSymbol,
columnSymbols,
redistributionSymbolsBuilder.build(),
nonNullColumnSymbolsBuilder.build(),
projectedSymbols);

Optional<PartitioningScheme> partitioningScheme = createMergePartitioningScheme(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,6 @@ public MergeProcessorNode map(MergeProcessorNode node, PlanNode source)
map(node.getMergeRowSymbol()),
map(node.getDataColumnSymbols()),
map(node.getRedistributionColumnSymbols()),
map(node.getNonNullColumnSymbols()),
newOutputs);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ public class MergeProcessorNode
private final Symbol mergeRowSymbol;
private final List<Symbol> dataColumnSymbols;
private final List<Symbol> redistributionColumnSymbols;
private final List<Symbol> nonNullColumnSymbols;
private final List<Symbol> outputs;

@JsonCreator
Expand All @@ -49,7 +48,6 @@ public MergeProcessorNode(
@JsonProperty("mergeRowSymbol") Symbol mergeRowSymbol,
@JsonProperty("dataColumnSymbols") List<Symbol> dataColumnSymbols,
@JsonProperty("redistributionColumnSymbols") List<Symbol> redistributionColumnSymbols,
@JsonProperty("nonNullColumnSymbols") List<Symbol> nonNullColumnSymbols,
@JsonProperty("outputs") List<Symbol> outputs)
{
super(id);
Expand All @@ -60,7 +58,6 @@ public MergeProcessorNode(
this.rowIdSymbol = requireNonNull(rowIdSymbol, "rowIdSymbol is null");
this.dataColumnSymbols = requireNonNull(dataColumnSymbols, "dataColumnSymbols is null");
this.redistributionColumnSymbols = requireNonNull(redistributionColumnSymbols, "redistributionColumnSymbols is null");
this.nonNullColumnSymbols = ImmutableList.copyOf(requireNonNull(nonNullColumnSymbols, "nonNullColumnSymbols is null"));
this.outputs = ImmutableList.copyOf(requireNonNull(outputs, "outputs is null"));
}

Expand Down Expand Up @@ -100,12 +97,6 @@ public List<Symbol> getRedistributionColumnSymbols()
return redistributionColumnSymbols;
}

@JsonProperty
public List<Symbol> getNonNullColumnSymbols()
{
return nonNullColumnSymbols;
}

@JsonProperty("outputs")
@Override
public List<Symbol> getOutputSymbols()
Expand All @@ -128,6 +119,6 @@ public <R, C> R accept(PlanVisitor<R, C> visitor, C context)
@Override
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
return new MergeProcessorNode(getId(), Iterables.getOnlyElement(newChildren), target, rowIdSymbol, mergeRowSymbol, dataColumnSymbols, redistributionColumnSymbols, nonNullColumnSymbols, outputs);
return new MergeProcessorNode(getId(), Iterables.getOnlyElement(newChildren), target, rowIdSymbol, mergeRowSymbol, dataColumnSymbols, redistributionColumnSymbols, outputs);
}
}
Loading