[Bugfix] fix rewrite bug after insert new partition data (StarRocks#20157)

Fix a rewrite failure that occurs after inserting data into new partitions.
The root cause is that the MV plan is cached, but the plan can change once new partition data is ingested.
The partition predicate calculation for MV rewrite also depends on the scan node's selected partition ids,
which may likewise change after ingestion. The result is an invalid compensation predicate during MV rewrite,
and therefore an invalid rewritten plan (a minimal sketch of this stale-plan problem follows the change summary below). Fix it by:
1. Disabling partition prune rules while compiling the MV plan. This keeps the partition predicates in the MV plan and avoids the stale cached-plan problem.
2. Removing the MV partition predicate compensation logic, which is unnecessary after step 1. This also fixes the redundant partition predicates left over after the MV's partition pruning.

Signed-off-by: ABingHuang <[email protected]>
ABingHuang authored and abc982627271 committed Jun 5, 2023
1 parent 1a15c07 commit 47b2d2f
Showing 10 changed files with 135 additions and 122 deletions.
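
For context, here is a minimal, self-contained sketch of the stale-plan problem described in the commit message. This is toy Java, not StarRocks code; the partition names and the predicate are illustrative only.

```java
import java.util.ArrayList;
import java.util.List;

// Toy model: a cached MV plan that was already partition-pruned pins concrete partition ids,
// so partitions ingested later are invisible to the compensation logic; a plan that keeps the
// partition predicate instead stays valid.
public class StaleMvPlanSketch {
    public static void main(String[] args) {
        // Partitions that existed when the MV plan was compiled and cached.
        List<String> partitionsAtCompileTime = List.of("p2022", "p2023");
        // The cached, pruned plan remembers the selected partition ids ...
        List<String> cachedSelectedPartitionIds = new ArrayList<>(partitionsAtCompileTime);

        // ... and then new partition data is ingested.
        List<String> partitionsAfterIngestion = new ArrayList<>(partitionsAtCompileTime);
        partitionsAfterIngestion.add("p2024");

        // Compensation derived from the stale selection misses p2024 -> invalid rewrite.
        List<String> invisibleToCachedPlan = new ArrayList<>(partitionsAfterIngestion);
        invisibleToCachedPlan.removeAll(cachedSelectedPartitionIds);
        System.out.println("missed by the cached plan: " + invisibleToCachedPlan); // [p2024]

        // With partition prune rules disabled while compiling the MV plan, the cached plan keeps
        // the predicate itself (e.g. "c3 < 2000"), which remains correct however many partitions
        // are added later.
        String keptPartitionPredicate = "c3 < 2000";
        System.out.println("predicate kept in the cached plan: " + keptPartitionPredicate);
    }
}
```

This is why step 1 of the fix disables partition prune rules when compiling the MV plan: the cached plan then carries the predicate rather than a snapshot of selected partition ids.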
@@ -16,6 +16,7 @@
package com.starrocks.sql.optimizer;

import com.starrocks.catalog.Table;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rewrite.ReplaceColumnRefRewriter;
import com.starrocks.sql.optimizer.rule.transformation.materialization.PredicateSplit;

@@ -32,6 +33,10 @@ public class MvRewriteContext {
private final ReplaceColumnRefRewriter queryColumnRefRewriter;
private final PredicateSplit queryPredicateSplit;

// mv's partition and distribution related conjunct predicate,
// used to prune partitions and buckets of scan mv operator after rewrite
private ScalarOperator mvPruneConjunct;

public MvRewriteContext(
MaterializationContext materializationContext,
List<Table> queryTables,
@@ -64,4 +69,12 @@ public ReplaceColumnRefRewriter getQueryColumnRefRewriter() {
public PredicateSplit getQueryPredicateSplit() {
return queryPredicateSplit;
}

public ScalarOperator getMvPruneConjunct() {
return mvPruneConjunct;
}

public void setMvPruneConjunct(ScalarOperator mvPruneConjunct) {
this.mvPruneConjunct = mvPruneConjunct;
}
}
@@ -33,14 +33,11 @@
import com.starrocks.qe.ConnectContext;
import com.starrocks.sql.ast.PartitionNames;
import com.starrocks.sql.optimizer.base.ColumnRefFactory;
import com.starrocks.sql.optimizer.base.ColumnRefSet;
import com.starrocks.sql.optimizer.base.DistributionSpec;
import com.starrocks.sql.optimizer.base.HashDistributionDesc;
import com.starrocks.sql.optimizer.operator.Operator;
import com.starrocks.sql.optimizer.operator.logical.LogicalOlapScanOperator;
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rule.transformation.materialization.MvUtils;
import org.apache.logging.log4j.LogManager;
@@ -189,7 +186,6 @@ private LogicalOlapScanOperator createScanMvOperator(MaterializationContext mvCo
final Map<Column, ColumnRefOperator> columnMetaToColRefMap = columnMetaToColRefMapBuilder.build();

// construct distribution
final Set<Integer> mvPartitionDistributionColumnRef = Sets.newHashSet();
DistributionInfo distributionInfo = mv.getDefaultDistributionInfo();
// only hash distribution is supported
Preconditions.checkState(distributionInfo instanceof HashDistributionInfo);
@@ -217,54 +213,12 @@
}
final PartitionNames partitionNames = new PartitionNames(false, selectedPartitionNames);

// NOTE:
// - To partition/distribution prune, need filter predicates that belong to MV.
// - Those predicates are only used for partition/distribution pruning and don't affect the real
// query compute.
// - after partition/distribution pruning, those predicates should be removed from mv rewrite result.
final OptExpression mvExpression = mvContext.getMvExpression();
final List<ScalarOperator> conjuncts = MvUtils.getAllPredicates(mvExpression);
final ColumnRefSet mvOutputColumnRefSet = mvExpression.getOutputColumns();
final List<ScalarOperator> mvConjuncts = Lists.newArrayList();

// Construct partition/distribution key column refs to filter conjunctions which need to retain.
Set<String> mvPruneKeyColNames = Sets.newHashSet();
distributedColumns.stream().forEach(distKey -> mvPruneKeyColNames.add(distKey.getName()));
mv.getPartitionNames().stream().forEach(partName -> mvPruneKeyColNames.add(partName));
final Set<Integer> mvPruneColumnIdSet = mvOutputColumnRefSet.getStream().map(
id -> mvContext.getMvColumnRefFactory().getColumnRef(id))
.filter(colRef -> mvPruneKeyColNames.contains(colRef.getName()))
.map(colRef -> colRef.getId())
.collect(Collectors.toSet());
// Case1: keeps original predicates which belong to MV table(which are not pruned after mv's partition pruning)
for (ScalarOperator conj : conjuncts) {
// ignore binary predicates which cannot be used for pruning.
if (conj instanceof BinaryPredicateOperator) {
BinaryPredicateOperator conjOp = (BinaryPredicateOperator) conj;
if (conjOp.getChild(0).isColumnRef() && conjOp.getChild(1).isColumnRef()) {
continue;
}
}
final List<Integer> conjColumnRefOperators =
Utils.extractColumnRef(conj).stream().map(ref -> ref.getId()).collect(Collectors.toList());
if (mvPruneColumnIdSet.containsAll(conjColumnRefOperators)) {
mvConjuncts.add(conj);
}
}
// Case2: compensated partition predicates which are pruned after mv's partition pruning.
// Compensate partition predicates and add them into mv predicate.
final ScalarOperator mvPartitionPredicate =
MvUtils.compensatePartitionPredicate(mvExpression, mvContext.getMvColumnRefFactory());
if (!ConstantOperator.TRUE.equals(mvPartitionPredicate)) {
mvConjuncts.add(mvPartitionPredicate);
}

return new LogicalOlapScanOperator(mv,
colRefToColumnMetaMapBuilder.build(),
columnMetaToColRefMap,
DistributionSpec.createHashDistributionSpec(hashDistributionDesc),
Operator.DEFAULT_LIMIT,
Utils.compoundAnd(mvConjuncts),
null,
mv.getBaseIndexId(),
selectPartitionIds,
partitionNames,
@@ -15,9 +15,11 @@
package com.starrocks.sql.optimizer.rule.transformation.materialization;

import com.google.common.collect.Lists;
import com.starrocks.sql.optimizer.MvRewriteContext;
import com.starrocks.sql.optimizer.OptExpression;
import com.starrocks.sql.optimizer.OptExpressionVisitor;
import com.starrocks.sql.optimizer.OptimizerContext;
import com.starrocks.sql.optimizer.Utils;
import com.starrocks.sql.optimizer.operator.Operator;
import com.starrocks.sql.optimizer.operator.OperatorBuilderFactory;
import com.starrocks.sql.optimizer.operator.logical.LogicalDeltaLakeScanOperator;
@@ -28,21 +30,29 @@
import com.starrocks.sql.optimizer.operator.logical.LogicalIcebergScanOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalOlapScanOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalScanOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rewrite.OptDistributionPruner;
import com.starrocks.sql.optimizer.rewrite.OptExternalPartitionPruner;
import com.starrocks.sql.optimizer.rewrite.OptOlapPartitionPruner;

import java.util.List;

public class MVPartitionPruner {
private final OptimizerContext optimizerContext;
private final MvRewriteContext mvRewriteContext;

public OptExpression prunePartition(OptimizerContext context, OptExpression queryExpression) {
return queryExpression.getOp().accept(new MVPartitionPrunerVisitor(), queryExpression, context);
public MVPartitionPruner(OptimizerContext optimizerContext, MvRewriteContext mvRewriteContext) {
this.optimizerContext = optimizerContext;
this.mvRewriteContext = mvRewriteContext;
}

private class MVPartitionPrunerVisitor extends OptExpressionVisitor<OptExpression, OptimizerContext> {
public OptExpression prunePartition(OptExpression queryExpression) {
return queryExpression.getOp().accept(new MVPartitionPrunerVisitor(), queryExpression, null);
}

private class MVPartitionPrunerVisitor extends OptExpressionVisitor<OptExpression, Void> {
@Override
public OptExpression visitLogicalTableScan(OptExpression optExpression, OptimizerContext context) {
public OptExpression visitLogicalTableScan(OptExpression optExpression, Void context) {
LogicalScanOperator scanOperator = optExpression.getOp().cast();

if (scanOperator instanceof LogicalOlapScanOperator) {
@@ -56,6 +66,18 @@ public OptExpression visitLogicalTableScan(OptExpression optExpression, Optimize
.setPrunedPartitionPredicates(Lists.newArrayList())
.setSelectedPartitionId(Lists.newArrayList())
.setSelectedTabletId(Lists.newArrayList());

// For mv: select c1, c3, c2 from test_base_part where c3 < 2000 and c1 = 1,
// where c3 is the partition column and c1 is the distribution column,
// we should add the predicate c3 < 2000 and c1 = 1 into the scan operator to do pruning.
boolean isAddMvPrunePredicate = scanOperator.getTable().isMaterializedView()
&& scanOperator.getTable().getId() == mvRewriteContext.getMaterializationContext().getMv().getId()
&& mvRewriteContext.getMvPruneConjunct() != null;
if (isAddMvPrunePredicate) {
ScalarOperator originPredicate = scanOperator.getPredicate();
ScalarOperator newPredicate = Utils.compoundAnd(originPredicate, mvRewriteContext.getMvPruneConjunct());
builder.setPredicate(newPredicate);
}
LogicalOlapScanOperator copiedOlapScanOperator = builder.build();

// prune partition
@@ -67,9 +89,19 @@ public OptExpression visitLogicalTableScan(OptExpression optExpression, Optimize
List<Long> selectedTabletIds = OptDistributionPruner.pruneTabletIds(copiedOlapScanOperator,
prunedOlapScanOperator.getSelectedPartitionId());

ScalarOperator scanPredicate = prunedOlapScanOperator.getPredicate();
if (isAddMvPrunePredicate) {
List<ScalarOperator> originConjuncts = Utils.extractConjuncts(scanOperator.getPredicate());
List<ScalarOperator> pruneConjuncts = Utils.extractConjuncts(mvRewriteContext.getMvPruneConjunct());
pruneConjuncts.removeAll(originConjuncts);
List<ScalarOperator> currentConjuncts = Utils.extractConjuncts(prunedOlapScanOperator.getPredicate());
currentConjuncts.removeAll(pruneConjuncts);
scanPredicate = Utils.compoundAnd(currentConjuncts);
}

LogicalOlapScanOperator.Builder rewrittenBuilder = new LogicalOlapScanOperator.Builder();
scanOperator = rewrittenBuilder.withOperator(prunedOlapScanOperator)
.setPredicate(MvUtils.canonizePredicate(prunedOlapScanOperator.getPredicate()))
.setPredicate(MvUtils.canonizePredicate(scanPredicate))
.setSelectedTabletId(selectedTabletIds)
.build();
} else if (scanOperator instanceof LogicalHiveScanOperator ||
@@ -81,16 +113,16 @@ public OptExpression visitLogicalTableScan(OptExpression optExpression, Optimize
Operator.Builder builder = OperatorBuilderFactory.build(scanOperator);
LogicalScanOperator copiedScanOperator =
(LogicalScanOperator) builder.withOperator(scanOperator).build();
scanOperator = OptExternalPartitionPruner.prunePartitions(context,
scanOperator = OptExternalPartitionPruner.prunePartitions(optimizerContext,
copiedScanOperator);
}
return OptExpression.create(scanOperator);
}

public OptExpression visit(OptExpression optExpression, OptimizerContext context) {
public OptExpression visit(OptExpression optExpression, Void context) {
List<OptExpression> children = Lists.newArrayList();
for (int i = 0; i < optExpression.arity(); ++i) {
children.add(optExpression.inputAt(i).getOp().accept(this, optExpression.inputAt(i), context));
children.add(optExpression.inputAt(i).getOp().accept(this, optExpression.inputAt(i), null));
}
return OptExpression.create(optExpression.getOp(), children);
}
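
The add-then-strip handling of the MV prune conjunct in MVPartitionPruner above can be summarized with a hedged, standalone sketch: plain strings stand in for ScalarOperator, and the two helpers only mimic Utils.extractConjuncts / Utils.compoundAnd.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Simplified stand-in for the bookkeeping in MVPartitionPruner: the MV prune conjunct is ANDed
// into the scan predicate only so partition/bucket pruning can use it, and is stripped out again
// afterwards so it does not leak into the rewritten plan.
public class PruneConjunctSketch {
    static List<String> extractConjuncts(String predicate) {
        return predicate == null ? new ArrayList<>()
                : new ArrayList<>(Arrays.asList(predicate.split(" AND ")));
    }

    static String compoundAnd(List<String> conjuncts) {
        return conjuncts.isEmpty() ? null : String.join(" AND ", conjuncts);
    }

    public static void main(String[] args) {
        String originPredicate = "c2 > 10";               // predicate already on the MV scan
        String mvPruneConjunct = "c3 < 2000 AND c1 = 1";  // partition/distribution conjuncts of the MV

        // 1. Add the prune conjunct so partition and tablet pruning can use it.
        String predicateForPruning = originPredicate + " AND " + mvPruneConjunct;

        // 2. After pruning, remove the prune-only conjuncts from the scan predicate again.
        List<String> pruneConjuncts = extractConjuncts(mvPruneConjunct);
        pruneConjuncts.removeAll(extractConjuncts(originPredicate));
        List<String> currentConjuncts = extractConjuncts(predicateForPruning);
        currentConjuncts.removeAll(pruneConjuncts);

        System.out.println(compoundAnd(currentConjuncts)); // prints: c2 > 10
    }
}
```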
@@ -27,8 +27,11 @@
import com.google.common.graph.MutableGraph;
import com.starrocks.analysis.JoinOperator;
import com.starrocks.catalog.Column;
import com.starrocks.catalog.DistributionInfo;
import com.starrocks.catalog.ForeignKeyConstraint;
import com.starrocks.catalog.HashDistributionInfo;
import com.starrocks.catalog.KeysType;
import com.starrocks.catalog.MaterializedView;
import com.starrocks.catalog.OlapTable;
import com.starrocks.catalog.Table;
import com.starrocks.catalog.UniqueConstraint;
@@ -155,22 +158,9 @@ public OptExpression rewrite() {
final ColumnRefFactory mvColumnRefFactory = materializationContext.getMvColumnRefFactory();
final ReplaceColumnRefRewriter mvColumnRefRewriter =
MvUtils.getReplaceColumnRefWriter(mvExpression, mvColumnRefFactory);
// Compensate partition predicates and add them into mv predicate,
// eg: c3 is partition column
// MV : select c1, c3, c2 from test_base_part where c3 < 2000
// Query : select c1, c3, c2 from test_base_part
// `c3 < 2000` is missed after partition pruning, so `mvPredicate` must add `mvPartitionPredicate`,
// otherwise query above may be rewritten by mv.
final ScalarOperator mvPartitionPredicate =
MvUtils.compensatePartitionPredicate(mvExpression, mvColumnRefFactory);
if (mvPartitionPredicate == null) {
return null;
}

ScalarOperator mvPredicate = MvUtils.rewriteOptExprCompoundPredicate(mvExpression, mvColumnRefRewriter);

if (!ConstantOperator.TRUE.equals(mvPartitionPredicate)) {
mvPredicate = MvUtils.canonizePredicate(Utils.compoundAnd(mvPredicate, mvPartitionPredicate));
}
if (materializationContext.getMvPartialPartitionPredicate() != null) {
// add latest partition predicate to mv predicate
ScalarOperator rewritten = mvColumnRefRewriter.rewrite(materializationContext.getMvPartialPartitionPredicate());
@@ -222,7 +212,12 @@ public OptExpression rewrite() {
materializationContext.getMvColumnRefFactory(), mvColumnRefRewriter,
materializationContext.getOutputMapping(), queryColumnSet);

// collect partition and distribution related predicates in mv
// used to prune partition and buckets after mv rewrite
ScalarOperator mvPrunePredicate = collectMvPrunePredicate(materializationContext);

for (BiMap<Integer, Integer> relationIdMapping : relationIdMappings) {
mvRewriteContext.setMvPruneConjunct(mvPrunePredicate);
rewriteContext.setQueryToMvRelationIdMapping(relationIdMapping);

// for view delta, should add compensation join columns to query ec
@@ -500,6 +495,45 @@ private boolean isJoinMatch(OptExpression queryExpression,
}
}

private ScalarOperator collectMvPrunePredicate(MaterializationContext mvContext) {
final OptExpression mvExpression = mvContext.getMvExpression();
final List<ScalarOperator> conjuncts = MvUtils.getAllPredicates(mvExpression);
final ColumnRefSet mvOutputColumnRefSet = mvExpression.getOutputColumns();
// conjuncts related to partition and distribution
final List<ScalarOperator> mvPrunePredicates = Lists.newArrayList();

// Construct partition/distribution key column refs to filter conjunctions which need to retain.
Set<String> mvPruneKeyColNames = Sets.newHashSet();
MaterializedView mv = mvContext.getMv();
DistributionInfo distributionInfo = mv.getDefaultDistributionInfo();
// only hash distribution is supported
Preconditions.checkState(distributionInfo instanceof HashDistributionInfo);
HashDistributionInfo hashDistributionInfo = (HashDistributionInfo) distributionInfo;
List<Column> distributedColumns = hashDistributionInfo.getDistributionColumns();
distributedColumns.stream().forEach(distKey -> mvPruneKeyColNames.add(distKey.getName()));
mv.getPartitionColumnNames().stream().forEach(partName -> mvPruneKeyColNames.add(partName));
final Set<Integer> mvPruneColumnIdSet = mvOutputColumnRefSet.getStream().map(
id -> mvContext.getMvColumnRefFactory().getColumnRef(id))
.filter(colRef -> mvPruneKeyColNames.contains(colRef.getName()))
.map(colRef -> colRef.getId())
.collect(Collectors.toSet());
for (ScalarOperator conj : conjuncts) {
// ignore binary predicates which cannot be used for pruning.
if (conj instanceof BinaryPredicateOperator) {
BinaryPredicateOperator conjOp = (BinaryPredicateOperator) conj;
if (conjOp.getChild(0).isVariable() && conjOp.getChild(1).isVariable()) {
continue;
}
}
final List<Integer> conjColumnRefOperators =
Utils.extractColumnRef(conj).stream().map(ref -> ref.getId()).collect(Collectors.toList());
if (mvPruneColumnIdSet.containsAll(conjColumnRefOperators)) {
mvPrunePredicates.add(conj);
}
}
return Utils.compoundAnd(mvPrunePredicates);
}

private OptExpression tryRewriteForRelationMapping(RewriteContext rewriteContext) {
// the rewritten expression to replace query
// should copy the op because the op will be modified and reused
@@ -510,13 +544,11 @@ private OptExpression tryRewriteForRelationMapping(RewriteContext rewriteContext
// Rewrite original mv's predicates into query if needed.
final ColumnRewriter columnRewriter = new ColumnRewriter(rewriteContext);
final Map<ColumnRefOperator, ScalarOperator> mvColumnRefToScalarOp = rewriteContext.getMVColumnRefToScalarOp();
ScalarOperator mvOriginalPredicates = mvScanOperator.getPredicate();
if (mvOriginalPredicates != null && !ConstantOperator.TRUE.equals(mvOriginalPredicates)) {
mvOriginalPredicates = rewriteMVCompensationExpression(rewriteContext, columnRewriter,
mvColumnRefToScalarOp, mvOriginalPredicates, false);
if (!ConstantOperator.TRUE.equals(mvOriginalPredicates)) {
mvScanBuilder.setPredicate(mvOriginalPredicates);
}
if (mvRewriteContext.getMvPruneConjunct() != null
&& !ConstantOperator.TRUE.equals(mvRewriteContext.getMvPruneConjunct())) {
ScalarOperator rewrittenPrunePredicate = rewriteMVCompensationExpression(rewriteContext, columnRewriter,
mvColumnRefToScalarOp, mvRewriteContext.getMvPruneConjunct(), false);
mvRewriteContext.setMvPruneConjunct(MvUtils.canonizePredicate(rewrittenPrunePredicate));
}
OptExpression mvScanOptExpression = OptExpression.create(mvScanBuilder.build());
deriveLogicalProperty(mvScanOptExpression);
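
The new collectMvPrunePredicate keeps only conjuncts whose referenced columns are all partition or distribution columns of the MV, and skips predicates that merely compare two variables, since those cannot drive pruning. Below is a rough standalone approximation: the Conjunct record is a stand-in for ScalarOperator, and the column names follow the test_base_part example in the comments above.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

// Simplified stand-in for collectMvPrunePredicate: retain only conjuncts that reference nothing
// but the MV's partition/distribution columns, skipping variable-to-variable comparisons.
public class CollectPrunePredicateSketch {
    record Conjunct(String text, Set<String> referencedColumns, boolean comparesTwoVariables) {}

    static List<Conjunct> collectPruneConjuncts(List<Conjunct> allConjuncts, Set<String> pruneKeyColumns) {
        List<Conjunct> kept = new ArrayList<>();
        for (Conjunct conj : allConjuncts) {
            if (conj.comparesTwoVariables()) {
                continue; // e.g. c1 = c2 cannot prune partitions or buckets
            }
            if (pruneKeyColumns.containsAll(conj.referencedColumns())) {
                kept.add(conj);
            }
        }
        return kept;
    }

    public static void main(String[] args) {
        // MV: select c1, c3, c2 from test_base_part where c3 < 2000 and c1 = 1,
        // where c3 is the partition column and c1 is the distribution column.
        Set<String> pruneKeys = Set.of("c1", "c3");
        List<Conjunct> conjuncts = List.of(
                new Conjunct("c3 < 2000", Set.of("c3"), false),
                new Conjunct("c1 = 1", Set.of("c1"), false),
                new Conjunct("c2 > 10", Set.of("c2"), false),
                new Conjunct("c1 = c2", Set.of("c1", "c2"), true));

        // Prints "c3 < 2000" and "c1 = 1"; the other conjuncts are left out of the prune predicate.
        collectPruneConjuncts(conjuncts, pruneKeys).forEach(c -> System.out.println(c.text()));
    }
}
```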