Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import static io.trino.spi.session.PropertyMetadata.enumProperty;
Comment thread
raunaqmorarka marked this conversation as resolved.
Outdated
import static io.trino.spi.session.PropertyMetadata.integerProperty;
import static io.trino.spi.session.PropertyMetadata.stringProperty;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.TimeZoneKey.getTimeZoneKey;
import static java.lang.Math.min;
Expand All @@ -56,6 +57,7 @@ public final class SystemSessionProperties
public static final String OPTIMIZE_HASH_GENERATION = "optimize_hash_generation";
public static final String JOIN_DISTRIBUTION_TYPE = "join_distribution_type";
public static final String JOIN_MAX_BROADCAST_TABLE_SIZE = "join_max_broadcast_table_size";
Comment thread
sopel39 marked this conversation as resolved.
Outdated
public static final String JOIN_MULTI_CLAUSE_INDEPENDENCE_FACTOR = "join_multi_clause_independence_factor";
public static final String DISTRIBUTED_INDEX_JOIN = "distributed_index_join";
public static final String HASH_PARTITION_COUNT = "hash_partition_count";
public static final String GROUPED_EXECUTION = "grouped_execution";
Expand Down Expand Up @@ -121,6 +123,7 @@ public final class SystemSessionProperties
public static final String IGNORE_STATS_CALCULATOR_FAILURES = "ignore_stats_calculator_failures";
public static final String MAX_DRIVERS_PER_TASK = "max_drivers_per_task";
public static final String DEFAULT_FILTER_FACTOR_ENABLED = "default_filter_factor_enabled";
public static final String FILTER_CONJUNCTION_INDEPENDENCE_FACTOR = "filter_conjunction_independence_factor";
public static final String SKIP_REDUNDANT_SORT = "skip_redundant_sort";
public static final String ALLOW_PUSHDOWN_INTO_CONNECTORS = "allow_pushdown_into_connectors";
public static final String COMPLEX_EXPRESSION_PUSHDOWN = "complex_expression_pushdown";
Expand Down Expand Up @@ -203,6 +206,15 @@ public SystemSessionProperties(
"Maximum estimated size of a table that can be broadcast when using automatic join type selection",
optimizerConfig.getJoinMaxBroadcastTableSize(),
false),
new PropertyMetadata<>(
JOIN_MULTI_CLAUSE_INDEPENDENCE_FACTOR,
"Scales the strength of independence assumption for selectivity estimates of multi-clause joins",
DOUBLE,
Double.class,
optimizerConfig.getJoinMultiClauseIndependenceFactor(),
false,
value -> validateDoubleRange(value, JOIN_MULTI_CLAUSE_INDEPENDENCE_FACTOR, 0.0, 1.0),
value -> value),
booleanProperty(
DISTRIBUTED_INDEX_JOIN,
"Distribute index joins on join keys instead of executing inline",
Expand Down Expand Up @@ -551,6 +563,15 @@ public SystemSessionProperties(
"use a default filter factor for unknown filters in a filter node",
optimizerConfig.isDefaultFilterFactorEnabled(),
false),
new PropertyMetadata<>(
FILTER_CONJUNCTION_INDEPENDENCE_FACTOR,
"Scales the strength of independence assumption for selectivity estimates of the conjunction of multiple filters",
DOUBLE,
Double.class,
optimizerConfig.getFilterConjunctionIndependenceFactor(),
false,
value -> validateDoubleRange(value, FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, 0.0, 1.0),
value -> value),
booleanProperty(
SKIP_REDUNDANT_SORT,
"Skip redundant sort operations",
Expand Down Expand Up @@ -756,6 +777,11 @@ public static DataSize getJoinMaxBroadcastTableSize(Session session)
return session.getSystemProperty(JOIN_MAX_BROADCAST_TABLE_SIZE, DataSize.class);
}

/**
 * Looks up the {@code join_multi_clause_independence_factor} system property on the given session.
 *
 * @param session session whose system properties are consulted
 * @return the property value, fetched as a {@code Double} and unboxed
 */
public static double getJoinMultiClauseIndependenceFactor(Session session)
{
    Double independenceFactor = session.getSystemProperty(JOIN_MULTI_CLAUSE_INDEPENDENCE_FACTOR, Double.class);
    return independenceFactor;
}

public static boolean isDistributedIndexJoinEnabled(Session session)
{
return session.getSystemProperty(DISTRIBUTED_INDEX_JOIN, Boolean.class);
Expand Down Expand Up @@ -1103,6 +1129,17 @@ private static Integer validateIntegerValue(Object value, String property, int l
return intValue;
}

/**
 * Checks that a session property value lies within an inclusive numeric range.
 *
 * @param value the raw property value; it is cast to {@code double}
 * @param property property name, used as the prefix of the error message
 * @param lowerBoundIncluded smallest accepted value
 * @param upperBoundIncluded largest accepted value
 * @return the value as a primitive {@code double} when it is inside the range
 * @throws TrinoException with {@code INVALID_SESSION_PROPERTY} when the value is outside the range or is NaN
 */
private static double validateDoubleRange(Object value, String property, double lowerBoundIncluded, double upperBoundIncluded)
{
    double doubleValue = (double) value;
    // Expressed as a negated "inside the range" test on purpose: every ordered comparison with NaN
    // is false, so the "below lower || above upper" form would silently accept NaN as valid.
    if (!(doubleValue >= lowerBoundIncluded && doubleValue <= upperBoundIncluded)) {
        throw new TrinoException(
                INVALID_SESSION_PROPERTY,
                format("%s must be in the range [%.2f, %.2f]: %.2f", property, lowerBoundIncluded, upperBoundIncluded, doubleValue));
    }
    return doubleValue;
}

public static boolean isStatisticsCpuTimerEnabled(Session session)
{
return session.getSystemProperty(STATISTICS_CPU_TIMER_ENABLED, Boolean.class);
Expand Down Expand Up @@ -1133,6 +1170,11 @@ public static boolean isDefaultFilterFactorEnabled(Session session)
return session.getSystemProperty(DEFAULT_FILTER_FACTOR_ENABLED, Boolean.class);
}

/**
 * Looks up the {@code filter_conjunction_independence_factor} system property on the given session.
 *
 * @param session session whose system properties are consulted
 * @return the property value, fetched as a {@code Double} and unboxed
 */
public static double getFilterConjunctionIndependenceFactor(Session session)
{
    Double independenceFactor = session.getSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, Double.class);
    return independenceFactor;
}

public static boolean isSkipRedundantSort(Session session)
{
return session.getSystemProperty(SKIP_REDUNDANT_SORT, Boolean.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import java.util.OptionalDouble;

import static io.trino.cost.SymbolStatsEstimate.buildFrom;
import static io.trino.util.MoreMath.firstNonNaN;
import static io.trino.util.MoreMath.averageExcludingNaNs;
import static io.trino.util.MoreMath.max;
import static io.trino.util.MoreMath.min;
import static java.lang.Double.NEGATIVE_INFINITY;
Expand Down Expand Up @@ -239,15 +239,4 @@ private static PlanNodeStatsEstimate estimateExpressionNotEqualToExpression(
rightExpressionSymbol.ifPresent(symbol -> result.addSymbolStatistics(symbol, rightNullsFiltered));
return result.build();
}

/**
 * Averages two values while ignoring NaN inputs.
 *
 * @param first first value, may be NaN
 * @param second second value, may be NaN
 * @return NaN when both inputs are NaN, the arithmetic mean when neither is NaN,
 *         otherwise the result of {@code firstNonNaN(first, second)}
 */
private static double averageExcludingNaNs(double first, double second)
Comment thread
raunaqmorarka marked this conversation as resolved.
Outdated
{
// nothing to average: both inputs are NaN
if (isNaN(first) && isNaN(second)) {
return NaN;
}
// both inputs are usable: plain arithmetic mean
if (!isNaN(first) && !isNaN(second)) {
return (first + second) / 2;
}
// exactly one input is NaN here; presumably firstNonNaN yields the other one (helper not visible here, confirm)
return firstNonNaN(first, second);
}
}
157 changes: 118 additions & 39 deletions core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.cost;

import com.google.common.base.VerifyException;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
Expand All @@ -39,6 +40,7 @@
import io.trino.sql.tree.InPredicate;
import io.trino.sql.tree.IsNotNullPredicate;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.LikePredicate;
import io.trino.sql.tree.LogicalExpression;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.NodeRef;
Expand All @@ -48,7 +50,6 @@
import javax.annotation.Nullable;
import javax.inject.Inject;

import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -57,10 +58,13 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.SystemSessionProperties.getFilterConjunctionIndependenceFactor;
import static io.trino.cost.ComparisonStatsCalculator.estimateExpressionToExpressionComparison;
import static io.trino.cost.ComparisonStatsCalculator.estimateExpressionToLiteralComparison;
import static io.trino.cost.PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues;
import static io.trino.cost.PlanNodeStatsEstimateMath.capStats;
import static io.trino.cost.PlanNodeStatsEstimateMath.estimateCorrelatedConjunctionRowCount;
import static io.trino.cost.PlanNodeStatsEstimateMath.intersectCorrelatedStats;
import static io.trino.cost.PlanNodeStatsEstimateMath.subtractSubsetStats;
import static io.trino.spi.statistics.StatsUtil.toStatsRepresentation;
import static io.trino.spi.type.BooleanType.BOOLEAN;
Expand Down Expand Up @@ -137,7 +141,14 @@ private class FilterExpressionStatsCalculatingVisitor
@Override
public PlanNodeStatsEstimate process(Node node, @Nullable Void context)
{
return normalizer.normalize(super.process(node, context), types);
PlanNodeStatsEstimate output;
if (input.getOutputRowCount() == 0 || input.isOutputRowCountUnknown()) {
output = input;
}
else {
output = super.process(node, context);
}
return normalizer.normalize(output, types);
}

@Override
Expand Down Expand Up @@ -169,35 +180,93 @@ protected PlanNodeStatsEstimate visitLogicalExpression(LogicalExpression node, V

private PlanNodeStatsEstimate estimateLogicalAnd(List<Expression> terms)
{
// first try to estimate in the fair way
PlanNodeStatsEstimate estimate = process(terms.get(0));
if (!estimate.isOutputRowCountUnknown()) {
for (int i = 1; i < terms.size(); i++) {
estimate = new FilterExpressionStatsCalculatingVisitor(estimate, session, types).process(terms.get(i));
double filterConjunctionIndependenceFactor = getFilterConjunctionIndependenceFactor(session);
List<PlanNodeStatsEstimate> estimates = estimateCorrelatedExpressions(terms, filterConjunctionIndependenceFactor);
double outputRowCount = estimateCorrelatedConjunctionRowCount(
input,
estimates,
filterConjunctionIndependenceFactor);
if (isNaN(outputRowCount)) {
return PlanNodeStatsEstimate.unknown();
}
return normalizer.normalize(new PlanNodeStatsEstimate(outputRowCount, intersectCorrelatedStats(estimates)), types);
}

Comment thread
raunaqmorarka marked this conversation as resolved.
Outdated
if (estimate.isOutputRowCountUnknown()) {
break;
/**
 * There can be multiple predicate expressions for the same symbol, e.g. x > 0 AND x <= 1, or x BETWEEN 1 AND 10.
 * We attempt to detect such cases in extractUncorrelatedGroups and calculate a combined estimate for each
 * such group of expressions. This is done so that the filter conjunction independence factor is not applied
 * when combining estimates from a conjunction of multiple predicates on the same symbol, which would
 * underestimate the output.
 */
private List<PlanNodeStatsEstimate> estimateCorrelatedExpressions(List<Expression> terms, double filterConjunctionIndependenceFactor)
{
ImmutableList.Builder<PlanNodeStatsEstimate> estimatesBuilder = ImmutableList.builder();
boolean hasUnestimatedTerm = false;
for (List<Expression> uncorrelatedExpressions : extractUncorrelatedGroups(terms, filterConjunctionIndependenceFactor)) {
PlanNodeStatsEstimate combinedEstimate = PlanNodeStatsEstimate.unknown();
Comment thread
sopel39 marked this conversation as resolved.
Outdated
for (Expression expression : uncorrelatedExpressions) {
PlanNodeStatsEstimate estimate;
// combinedEstimate is unknown until the 1st known estimated term
if (combinedEstimate.isOutputRowCountUnknown()) {
estimate = process(expression);
Comment thread
raunaqmorarka marked this conversation as resolved.
Outdated
}
else {
estimate = new FilterExpressionStatsCalculatingVisitor(combinedEstimate, session, types)
.process(expression);
}
}

if (!estimate.isOutputRowCountUnknown()) {
return estimate;
if (estimate.isOutputRowCountUnknown()) {
hasUnestimatedTerm = true;
}
else {
// update combinedEstimate only when the term estimate is known so that all the known estimates
// can be applied progressively through FilterExpressionStatsCalculatingVisitor calls.
combinedEstimate = estimate;
}
}
estimatesBuilder.add(combinedEstimate);
Comment thread
sopel39 marked this conversation as resolved.
Outdated
}

// If some of the filters cannot be estimated, take the smallest estimate.
// Apply 0.9 filter factor as "unknown filter" factor.
Optional<PlanNodeStatsEstimate> smallest = terms.stream()
.map(this::process)
.filter(termEstimate -> !termEstimate.isOutputRowCountUnknown())
.sorted(Comparator.comparingDouble(PlanNodeStatsEstimate::getOutputRowCount))
.findFirst();

if (smallest.isEmpty()) {
return PlanNodeStatsEstimate.unknown();
if (hasUnestimatedTerm) {
estimatesBuilder.add(PlanNodeStatsEstimate.unknown());
}
return estimatesBuilder.build();
}

return smallest.get().mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT);
/**
 * Partitions the terms of a conjunction into groups that share a common key expression.
 * <p>
 * When {@code filterConjunctionIndependenceFactor} equals 1, all terms are returned as a single group.
 * Otherwise each term is keyed as follows:
 * <ul>
 * <li>{@code ComparisonExpression}: keyed by the left operand of its normalized form; the normalized
 * comparison (not the original term) is what gets stored in the group</li>
 * <li>{@code InPredicate}, {@code NotExpression}, {@code LikePredicate}, {@code IsNotNullPredicate},
 * {@code IsNullPredicate}: keyed by {@code getValue()}</li>
 * <li>any other expression: keyed by itself</li>
 * </ul>
 *
 * @param terms the terms of the conjunction
 * @param filterConjunctionIndependenceFactor independence factor; the value 1 disables grouping
 * @return one list of expressions per distinct key, in the iteration order of the multimap's key set
 */
private List<List<Expression>> extractUncorrelatedGroups(List<Expression> terms, double filterConjunctionIndependenceFactor)
{
if (filterConjunctionIndependenceFactor == 1) {
// Allows the filters to be estimated as if there is no correlation between any of the terms
return ImmutableList.of(terms);
}
ArrayListMultimap<Expression, Expression> groupedExpressions = ArrayListMultimap.create();
Comment thread
sopel39 marked this conversation as resolved.
Outdated
for (Expression expression : terms) {
// group expressions by common terms on LHS
Comment thread
raunaqmorarka marked this conversation as resolved.
Outdated
if (expression instanceof ComparisonExpression) {
// normalize first so that the grouping key is taken from the canonical left operand
ComparisonExpression normalized = normalize((ComparisonExpression) expression);
groupedExpressions.put(normalized.getLeft(), normalized);
}
else if (expression instanceof InPredicate) {
groupedExpressions.put(((InPredicate) expression).getValue(), expression);
}
else if (expression instanceof NotExpression) {
groupedExpressions.put(((NotExpression) expression).getValue(), expression);
}
else if (expression instanceof LikePredicate) {
groupedExpressions.put(((LikePredicate) expression).getValue(), expression);
}
else if (expression instanceof IsNotNullPredicate) {
groupedExpressions.put(((IsNotNullPredicate) expression).getValue(), expression);
}
else if (expression instanceof IsNullPredicate) {
groupedExpressions.put(((IsNullPredicate) expression).getValue(), expression);
}
else {
// no recognizable key: the expression forms its own group
groupedExpressions.put(expression, expression);
}
}
// one output list per distinct key
return groupedExpressions.keySet().stream()
.map(groupedExpressions::get)
.collect(toImmutableList());
}

private PlanNodeStatsEstimate estimateLogicalOr(List<Expression> terms)
Expand Down Expand Up @@ -349,22 +418,12 @@ protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context)
@Override
protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression node, Void context)
{
ComparisonExpression.Operator operator = node.getOperator();
Expression left = node.getLeft();
Expression right = node.getRight();

checkArgument(!(isEffectivelyLiteral(left) && isEffectivelyLiteral(right)), "Literal-to-literal not supported here, should be eliminated earlier");

if (!(left instanceof SymbolReference) && right instanceof SymbolReference) {
// normalize so that symbol is on the left
return process(new ComparisonExpression(operator.flip(), right, left));
}
checkArgument(!(isEffectivelyLiteral(node.getLeft()) && isEffectivelyLiteral(node.getRight())), "Literal-to-literal not supported here, should be eliminated earlier");

if (isEffectivelyLiteral(left)) {
verify(!isEffectivelyLiteral(right));
// normalize so that literal is on the right
return process(new ComparisonExpression(operator.flip(), right, left));
}
ComparisonExpression normalized = normalize(node);
ComparisonExpression.Operator operator = normalized.getOperator();
Expression left = normalized.getLeft();
Expression right = normalized.getRight();

if (left instanceof SymbolReference && left.equals(right)) {
return process(new IsNotNullPredicate(left));
Expand Down Expand Up @@ -441,5 +500,25 @@ private OptionalDouble doubleValueFromLiteral(Type type, Expression literal)
ImmutableMap.of());
return toStatsRepresentation(type, literalValue);
}

/**
 * Puts the operands of a comparison into a canonical order.
 * <p>
 * The operands are swapped (and the operator flipped) when the right operand is a
 * {@code SymbolReference} and the left one is not, or when the left operand is effectively a literal.
 * In every other case the node is returned unchanged.
 *
 * @param node the comparison to normalize
 * @return the same node, or a new comparison with swapped operands and a flipped operator
 */
private ComparisonExpression normalize(ComparisonExpression node)
{
    Expression lhs = node.getLeft();
    Expression rhs = node.getRight();

    // symbol belongs on the left
    boolean symbolOnlyOnRight = rhs instanceof SymbolReference && !(lhs instanceof SymbolReference);

    // already canonical: no symbol to move to the left and no literal to move to the right
    if (!symbolOnlyOnRight && !isEffectivelyLiteral(lhs)) {
        return node;
    }

    if (!symbolOnlyOnRight) {
        // swapping because the literal belongs on the right; both sides being literals is not expected here
        verify(!isEffectivelyLiteral(rhs));
    }

    return new ComparisonExpression(node.getOperator().flip(), rhs, lhs);
}
}
}
Loading