Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*/
package io.trino.cost;

import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.Comparison;
import io.trino.sql.planner.Symbol;

import java.util.Optional;
Expand Down Expand Up @@ -45,7 +45,7 @@ public static PlanNodeStatsEstimate estimateExpressionToLiteralComparison(
SymbolStatsEstimate expressionStatistics,
Optional<Symbol> expressionSymbol,
OptionalDouble literalValue,
ComparisonExpression.Operator operator)
Comparison.Operator operator)
{
switch (operator) {
case EQUAL:
Expand Down Expand Up @@ -160,7 +160,7 @@ public static PlanNodeStatsEstimate estimateExpressionToExpressionComparison(
Optional<Symbol> leftExpressionSymbol,
SymbolStatsEstimate rightExpressionStatistics,
Optional<Symbol> rightExpressionSymbol,
ComparisonExpression.Operator operator)
Comparison.Operator operator)
{
switch (operator) {
case EQUAL:
Expand Down Expand Up @@ -255,7 +255,7 @@ private static PlanNodeStatsEstimate estimateExpressionNotEqualToExpression(
}

private static PlanNodeStatsEstimate estimateExpressionToExpressionInequality(
ComparisonExpression.Operator operator,
Comparison.Operator operator,
PlanNodeStatsEstimate inputStatistics,
SymbolStatsEstimate leftExpressionStatistics,
Optional<Symbol> leftExpressionSymbol,
Expand Down
106 changes: 53 additions & 53 deletions core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@
import io.trino.Session;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.BetweenPredicate;
import io.trino.sql.ir.BooleanLiteral;
import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.Between;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FunctionCall;
import io.trino.sql.ir.InPredicate;
import io.trino.sql.ir.In;
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.IsNullPredicate;
import io.trino.sql.ir.LogicalExpression;
import io.trino.sql.ir.NotExpression;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.ir.IsNull;
import io.trino.sql.ir.Logical;
import io.trino.sql.ir.Not;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.IrExpressionInterpreter;
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.Symbol;
Expand All @@ -58,9 +58,9 @@
import static io.trino.spi.statistics.StatsUtil.toStatsRepresentation;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.sql.DynamicFilters.isDynamicFilter;
import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL;
import static io.trino.sql.ir.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL;
import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
import static io.trino.sql.ir.Comparison.Operator.EQUAL;
import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL;
import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL;
import static io.trino.sql.ir.IrUtils.and;
import static io.trino.sql.planner.IrExpressionInterpreter.evaluateConstantExpression;
import static io.trino.sql.planner.SymbolsExtractor.extractUnique;
Expand Down Expand Up @@ -148,11 +148,11 @@ protected PlanNodeStatsEstimate visitExpression(Expression node, Void context)
}

@Override
protected PlanNodeStatsEstimate visitNotExpression(NotExpression node, Void context)
protected PlanNodeStatsEstimate visitNot(Not node, Void context)
{
if (node.getValue() instanceof IsNullPredicate inner) {
if (inner.getValue() instanceof SymbolReference) {
Symbol symbol = Symbol.from(inner.getValue());
if (node.value() instanceof IsNull inner) {
if (inner.value() instanceof Reference) {
Symbol symbol = Symbol.from(inner.value());
SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
result.setOutputRowCount(input.getOutputRowCount() * (1 - symbolStats.getNullsFraction()));
Expand All @@ -161,19 +161,19 @@ protected PlanNodeStatsEstimate visitNotExpression(NotExpression node, Void cont
}
return PlanNodeStatsEstimate.unknown();
}
return subtractSubsetStats(input, process(node.getValue()));
return subtractSubsetStats(input, process(node.value()));
}

@Override
protected PlanNodeStatsEstimate visitLogicalExpression(LogicalExpression node, Void context)
protected PlanNodeStatsEstimate visitLogical(Logical node, Void context)
{
switch (node.getOperator()) {
switch (node.operator()) {
case AND:
return estimateLogicalAnd(node.getTerms());
return estimateLogicalAnd(node.terms());
case OR:
return estimateLogicalOr(node.getTerms());
return estimateLogicalOr(node.terms());
}
throw new IllegalArgumentException("Unexpected binary operator: " + node.getOperator());
throw new IllegalArgumentException("Unexpected binary operator: " + node.operator());
}

private PlanNodeStatsEstimate estimateLogicalAnd(List<Expression> terms)
Expand Down Expand Up @@ -262,8 +262,8 @@ private PlanNodeStatsEstimate estimateLogicalOr(List<Expression> terms)
@Override
protected PlanNodeStatsEstimate visitConstant(Constant node, Void context)
{
if (node.getType().equals(BOOLEAN) && node.getValue() != null) {
if ((boolean) node.getValue()) {
if (node.type().equals(BOOLEAN) && node.value() != null) {
if ((boolean) node.value()) {
return input;
}

Expand All @@ -277,10 +277,10 @@ protected PlanNodeStatsEstimate visitConstant(Constant node, Void context)
}

@Override
protected PlanNodeStatsEstimate visitIsNullPredicate(IsNullPredicate node, Void context)
protected PlanNodeStatsEstimate visitIsNull(IsNull node, Void context)
{
if (node.getValue() instanceof SymbolReference) {
Symbol symbol = Symbol.from(node.getValue());
if (node.value() instanceof Reference) {
Symbol symbol = Symbol.from(node.value());
SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
result.setOutputRowCount(input.getOutputRowCount() * symbolStats.getNullsFraction());
Expand All @@ -296,21 +296,21 @@ protected PlanNodeStatsEstimate visitIsNullPredicate(IsNullPredicate node, Void
}

@Override
protected PlanNodeStatsEstimate visitBetweenPredicate(BetweenPredicate node, Void context)
protected PlanNodeStatsEstimate visitBetween(Between node, Void context)
{
SymbolStatsEstimate valueStats = getExpressionStats(node.getValue());
SymbolStatsEstimate valueStats = getExpressionStats(node.value());
if (valueStats.isUnknown()) {
return PlanNodeStatsEstimate.unknown();
}
if (!getExpressionStats(node.getMin()).isSingleValue()) {
if (!getExpressionStats(node.min()).isSingleValue()) {
return PlanNodeStatsEstimate.unknown();
}
if (!getExpressionStats(node.getMax()).isSingleValue()) {
if (!getExpressionStats(node.max()).isSingleValue()) {
return PlanNodeStatsEstimate.unknown();
}

Expression lowerBound = new ComparisonExpression(GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin());
Expression upperBound = new ComparisonExpression(LESS_THAN_OR_EQUAL, node.getValue(), node.getMax());
Expression lowerBound = new Comparison(GREATER_THAN_OR_EQUAL, node.value(), node.min());
Expression upperBound = new Comparison(LESS_THAN_OR_EQUAL, node.value(), node.max());

Expression transformed;
if (isInfinite(valueStats.getLowValue())) {
Expand All @@ -325,10 +325,10 @@ protected PlanNodeStatsEstimate visitBetweenPredicate(BetweenPredicate node, Voi
}

@Override
protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context)
protected PlanNodeStatsEstimate visitIn(In node, Void context)
{
ImmutableList<PlanNodeStatsEstimate> equalityEstimates = node.getValueList().stream()
.map(inValue -> process(new ComparisonExpression(EQUAL, node.getValue(), inValue)))
ImmutableList<PlanNodeStatsEstimate> equalityEstimates = node.valueList().stream()
.map(inValue -> process(new Comparison(EQUAL, node.value(), inValue)))
.collect(toImmutableList());

if (equalityEstimates.stream().anyMatch(PlanNodeStatsEstimate::isOutputRowCountUnknown)) {
Expand All @@ -343,7 +343,7 @@ protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context)
return PlanNodeStatsEstimate.unknown();
}

SymbolStatsEstimate valueStats = getExpressionStats(node.getValue());
SymbolStatsEstimate valueStats = getExpressionStats(node.value());
if (valueStats.isUnknown()) {
return PlanNodeStatsEstimate.unknown();
}
Expand All @@ -353,8 +353,8 @@ protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context)
PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
result.setOutputRowCount(min(inEstimate.getOutputRowCount(), notNullValuesBeforeIn));

if (node.getValue() instanceof SymbolReference) {
Symbol valueSymbol = Symbol.from(node.getValue());
if (node.value() instanceof Reference) {
Symbol valueSymbol = Symbol.from(node.value());
SymbolStatsEstimate newSymbolStats = inEstimate.getSymbolStatistics(valueSymbol)
.mapDistinctValuesCount(newDistinctValuesCount -> min(newDistinctValuesCount, valueStats.getDistinctValuesCount()));
result.addSymbolStatistics(valueSymbol, newSymbolStats);
Expand All @@ -364,30 +364,30 @@ protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context)

@SuppressWarnings("ArgumentSelectionDefectChecker")
@Override
protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression node, Void context)
protected PlanNodeStatsEstimate visitComparison(Comparison node, Void context)
{
ComparisonExpression.Operator operator = node.getOperator();
Expression left = node.getLeft();
Expression right = node.getRight();
Comparison.Operator operator = node.operator();
Expression left = node.left();
Expression right = node.right();

checkArgument(!(left instanceof Constant && right instanceof Constant), "Literal-to-literal not supported here, should be eliminated earlier");

if (!(left instanceof SymbolReference) && right instanceof SymbolReference) {
if (!(left instanceof Reference) && right instanceof Reference) {
// normalize so that symbol is on the left
return process(new ComparisonExpression(operator.flip(), right, left));
return process(new Comparison(operator.flip(), right, left));
}

if (left instanceof Constant) {
// normalize so that literal is on the right
return process(new ComparisonExpression(operator.flip(), right, left));
return process(new Comparison(operator.flip(), right, left));
}

if (left instanceof SymbolReference && left.equals(right)) {
return process(new NotExpression(new IsNullPredicate(left)));
if (left instanceof Reference && left.equals(right)) {
return process(new Not(new IsNull(left)));
}

SymbolStatsEstimate leftStats = getExpressionStats(left);
Optional<Symbol> leftSymbol = left instanceof SymbolReference ? Optional.of(Symbol.from(left)) : Optional.empty();
Optional<Symbol> leftSymbol = left instanceof Reference ? Optional.of(Symbol.from(left)) : Optional.empty();
if (right instanceof Constant) {
Type type = left.type();
Object literalValue = evaluateConstantExpression(right, plannerContext, session);
Expand All @@ -405,22 +405,22 @@ protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression n
return estimateExpressionToLiteralComparison(input, leftStats, leftSymbol, value, operator);
}

Optional<Symbol> rightSymbol = right instanceof SymbolReference ? Optional.of(Symbol.from(right)) : Optional.empty();
Optional<Symbol> rightSymbol = right instanceof Reference ? Optional.of(Symbol.from(right)) : Optional.empty();
return estimateExpressionToExpressionComparison(input, leftStats, leftSymbol, rightStats, rightSymbol, operator);
}

@Override
protected PlanNodeStatsEstimate visitFunctionCall(FunctionCall node, Void context)
protected PlanNodeStatsEstimate visitCall(Call node, Void context)
{
if (isDynamicFilter(node)) {
return process(BooleanLiteral.TRUE_LITERAL, context);
return process(Booleans.TRUE, context);
}
return PlanNodeStatsEstimate.unknown();
}

private SymbolStatsEstimate getExpressionStats(Expression expression)
{
if (expression instanceof SymbolReference) {
if (expression instanceof Reference) {
Symbol symbol = Symbol.from(expression);
return requireNonNull(input.getSymbolStatistics(symbol), () -> format("No statistics for symbol %s", symbol));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import io.trino.Session;
import io.trino.cost.StatsCalculator.Context;
import io.trino.matching.Pattern;
import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Expression;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.JoinNode;
Expand All @@ -36,7 +36,7 @@
import static io.trino.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT;
import static io.trino.cost.PlanNodeStatsEstimateMath.estimateCorrelatedConjunctionRowCount;
import static io.trino.cost.SymbolStatsEstimate.buildFrom;
import static io.trino.sql.ir.ComparisonExpression.Operator.EQUAL;
import static io.trino.sql.ir.Comparison.Operator.EQUAL;
import static io.trino.sql.ir.IrUtils.extractConjuncts;
import static io.trino.sql.planner.plan.Patterns.join;
import static io.trino.util.MoreMath.firstNonNaN;
Expand Down Expand Up @@ -183,7 +183,7 @@ private PlanNodeStatsEstimate filterByEquiJoinClauses(
// clause separately because stats estimates would be way off.
List<PlanNodeStatsEstimateWithClause> knownEstimates = clauses.stream()
.map(clause -> {
ComparisonExpression predicate = new ComparisonExpression(EQUAL, clause.getLeft().toSymbolReference(), clause.getRight().toSymbolReference());
Comparison predicate = new Comparison(EQUAL, clause.getLeft().toSymbolReference(), clause.getRight().toSymbolReference());
return new PlanNodeStatsEstimateWithClause(filterStatsCalculator.filterStats(stats, predicate, session), clause);
})
.collect(toImmutableList());
Expand Down
Loading