Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import static io.trino.spi.session.PropertyMetadata.enumProperty;
import static io.trino.spi.session.PropertyMetadata.integerProperty;
import static io.trino.spi.session.PropertyMetadata.stringProperty;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.TimeZoneKey.getTimeZoneKey;
import static java.lang.Math.min;
Expand All @@ -56,6 +57,7 @@ public final class SystemSessionProperties
public static final String OPTIMIZE_HASH_GENERATION = "optimize_hash_generation";
public static final String JOIN_DISTRIBUTION_TYPE = "join_distribution_type";
public static final String JOIN_MAX_BROADCAST_TABLE_SIZE = "join_max_broadcast_table_size";
public static final String JOIN_MULTI_CLAUSE_INDEPENDENCE_FACTOR = "join_multi_clause_independence_factor";
public static final String DISTRIBUTED_INDEX_JOIN = "distributed_index_join";
public static final String HASH_PARTITION_COUNT = "hash_partition_count";
public static final String GROUPED_EXECUTION = "grouped_execution";
Expand Down Expand Up @@ -121,6 +123,7 @@ public final class SystemSessionProperties
public static final String IGNORE_STATS_CALCULATOR_FAILURES = "ignore_stats_calculator_failures";
public static final String MAX_DRIVERS_PER_TASK = "max_drivers_per_task";
public static final String DEFAULT_FILTER_FACTOR_ENABLED = "default_filter_factor_enabled";
public static final String FILTER_CONJUNCTION_INDEPENDENCE_FACTOR = "filter_conjunction_independence_factor";
public static final String SKIP_REDUNDANT_SORT = "skip_redundant_sort";
public static final String ALLOW_PUSHDOWN_INTO_CONNECTORS = "allow_pushdown_into_connectors";
public static final String COMPLEX_EXPRESSION_PUSHDOWN = "complex_expression_pushdown";
Expand Down Expand Up @@ -203,6 +206,15 @@ public SystemSessionProperties(
"Maximum estimated size of a table that can be broadcast when using automatic join type selection",
optimizerConfig.getJoinMaxBroadcastTableSize(),
false),
new PropertyMetadata<>(
JOIN_MULTI_CLAUSE_INDEPENDENCE_FACTOR,
"Scales the strength of independence assumption for selectivity estimates of multi-clause joins",
DOUBLE,
Double.class,
optimizerConfig.getJoinMultiClauseIndependenceFactor(),
false,
value -> validateDoubleRange(value, JOIN_MULTI_CLAUSE_INDEPENDENCE_FACTOR, 0.0, 1.0),
value -> value),
booleanProperty(
DISTRIBUTED_INDEX_JOIN,
"Distribute index joins on join keys instead of executing inline",
Expand Down Expand Up @@ -551,6 +563,15 @@ public SystemSessionProperties(
"use a default filter factor for unknown filters in a filter node",
optimizerConfig.isDefaultFilterFactorEnabled(),
false),
new PropertyMetadata<>(
FILTER_CONJUNCTION_INDEPENDENCE_FACTOR,
"Scales the strength of independence assumption for selectivity estimates of the conjunction of multiple filters",
DOUBLE,
Double.class,
optimizerConfig.getFilterConjunctionIndependenceFactor(),
false,
value -> validateDoubleRange(value, FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, 0.0, 1.0),
value -> value),
booleanProperty(
SKIP_REDUNDANT_SORT,
"Skip redundant sort operations",
Expand Down Expand Up @@ -756,6 +777,11 @@ public static DataSize getJoinMaxBroadcastTableSize(Session session)
return session.getSystemProperty(JOIN_MAX_BROADCAST_TABLE_SIZE, DataSize.class);
}

/**
 * Reads the {@code join_multi_clause_independence_factor} system session property.
 *
 * @param session session to read the property from
 * @return the property value, unboxed to a primitive {@code double}
 */
public static double getJoinMultiClauseIndependenceFactor(Session session)
{
    Double independenceFactor = session.getSystemProperty(JOIN_MULTI_CLAUSE_INDEPENDENCE_FACTOR, Double.class);
    return independenceFactor;
}

public static boolean isDistributedIndexJoinEnabled(Session session)
{
return session.getSystemProperty(DISTRIBUTED_INDEX_JOIN, Boolean.class);
Expand Down Expand Up @@ -1103,6 +1129,17 @@ private static Integer validateIntegerValue(Object value, String property, int l
return intValue;
}

/**
 * Validates that a double-valued session property lies within a closed range.
 *
 * @param value property value; it is cast to {@code double}, so it must be a non-null {@code Double}
 * @param property property name, used only in the error message
 * @param lowerBoundIncluded smallest accepted value
 * @param upperBoundIncluded largest accepted value
 * @return the validated value
 * @throws TrinoException with {@code INVALID_SESSION_PROPERTY} when the value is NaN or lies outside the range
 */
private static double validateDoubleRange(Object value, String property, double lowerBoundIncluded, double upperBoundIncluded)
{
    double doubleValue = (double) value;
    // Phrased as a negated "in range" test so that NaN is rejected: every ordered comparison
    // involving NaN is false, so "value < lower || value > upper" would silently accept it.
    boolean withinRange = doubleValue >= lowerBoundIncluded && doubleValue <= upperBoundIncluded;
    if (!withinRange) {
        throw new TrinoException(
                INVALID_SESSION_PROPERTY,
                format("%s must be in the range [%.2f, %.2f]: %.2f", property, lowerBoundIncluded, upperBoundIncluded, doubleValue));
    }
    return doubleValue;
}

public static boolean isStatisticsCpuTimerEnabled(Session session)
{
return session.getSystemProperty(STATISTICS_CPU_TIMER_ENABLED, Boolean.class);
Expand Down Expand Up @@ -1133,6 +1170,11 @@ public static boolean isDefaultFilterFactorEnabled(Session session)
return session.getSystemProperty(DEFAULT_FILTER_FACTOR_ENABLED, Boolean.class);
}

/**
 * Reads the {@code filter_conjunction_independence_factor} system session property.
 *
 * @param session session to read the property from
 * @return the property value, unboxed to a primitive {@code double}
 */
public static double getFilterConjunctionIndependenceFactor(Session session)
{
    Double independenceFactor = session.getSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, Double.class);
    return independenceFactor;
}

public static boolean isSkipRedundantSort(Session session)
{
return session.getSystemProperty(SKIP_REDUNDANT_SORT, Boolean.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import java.util.OptionalDouble;

import static io.trino.cost.SymbolStatsEstimate.buildFrom;
import static io.trino.util.MoreMath.firstNonNaN;
import static io.trino.util.MoreMath.averageExcludingNaNs;
import static io.trino.util.MoreMath.max;
import static io.trino.util.MoreMath.min;
import static java.lang.Double.NEGATIVE_INFINITY;
Expand Down Expand Up @@ -239,15 +239,4 @@ private static PlanNodeStatsEstimate estimateExpressionNotEqualToExpression(
rightExpressionSymbol.ifPresent(symbol -> result.addSymbolStatistics(symbol, rightNullsFiltered));
return result.build();
}

/**
 * Averages two values while ignoring NaN inputs.
 *
 * @param first first value, possibly NaN
 * @param second second value, possibly NaN
 * @return NaN when both inputs are NaN, the arithmetic mean when neither is NaN,
 *         and otherwise the result of {@code firstNonNaN(first, second)}
 */
private static double averageExcludingNaNs(double first, double second)
{
    boolean firstMissing = isNaN(first);
    boolean secondMissing = isNaN(second);
    if (firstMissing && secondMissing) {
        return NaN;
    }
    if (firstMissing || secondMissing) {
        // Exactly one input is NaN; defer to the shared helper to pick the usable one.
        return firstNonNaN(first, second);
    }
    return (first + second) / 2;
}
}
136 changes: 111 additions & 25 deletions core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
package io.trino.cost;

import com.google.common.base.VerifyException;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ListMultimap;
import io.trino.Session;
import io.trino.execution.warnings.WarningCollector;
import io.trino.security.AllowAllAccessControl;
Expand Down Expand Up @@ -44,30 +46,37 @@
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.NotExpression;
import io.trino.sql.tree.SymbolReference;
import io.trino.util.DisjointSet;

import javax.annotation.Nullable;
import javax.inject.Inject;

import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.Set;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.SystemSessionProperties.getFilterConjunctionIndependenceFactor;
import static io.trino.cost.ComparisonStatsCalculator.estimateExpressionToExpressionComparison;
import static io.trino.cost.ComparisonStatsCalculator.estimateExpressionToLiteralComparison;
import static io.trino.cost.PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues;
import static io.trino.cost.PlanNodeStatsEstimateMath.capStats;
import static io.trino.cost.PlanNodeStatsEstimateMath.estimateCorrelatedConjunctionRowCount;
import static io.trino.cost.PlanNodeStatsEstimateMath.intersectCorrelatedStats;
import static io.trino.cost.PlanNodeStatsEstimateMath.subtractSubsetStats;
import static io.trino.spi.statistics.StatsUtil.toStatsRepresentation;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.sql.DynamicFilters.isDynamicFilter;
import static io.trino.sql.ExpressionUtils.and;
import static io.trino.sql.ExpressionUtils.getExpressionTypes;
import static io.trino.sql.planner.ExpressionInterpreter.evaluateConstantExpression;
import static io.trino.sql.planner.SymbolsExtractor.extractUnique;
import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL;
import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL;
import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
Expand Down Expand Up @@ -137,7 +146,14 @@ private class FilterExpressionStatsCalculatingVisitor
@Override
public PlanNodeStatsEstimate process(Node node, @Nullable Void context)
{
return normalizer.normalize(super.process(node, context), types);
PlanNodeStatsEstimate output;
if (input.getOutputRowCount() == 0 || input.isOutputRowCountUnknown()) {
output = input;
}
else {
output = super.process(node, context);
}
return normalizer.normalize(output, types);
}

@Override
Expand Down Expand Up @@ -169,35 +185,56 @@ protected PlanNodeStatsEstimate visitLogicalExpression(LogicalExpression node, V

private PlanNodeStatsEstimate estimateLogicalAnd(List<Expression> terms)
{
// first try to estimate in the fair way
PlanNodeStatsEstimate estimate = process(terms.get(0));
if (!estimate.isOutputRowCountUnknown()) {
for (int i = 1; i < terms.size(); i++) {
estimate = new FilterExpressionStatsCalculatingVisitor(estimate, session, types).process(terms.get(i));
double filterConjunctionIndependenceFactor = getFilterConjunctionIndependenceFactor(session);
List<PlanNodeStatsEstimate> estimates = estimateCorrelatedExpressions(terms, filterConjunctionIndependenceFactor);
double outputRowCount = estimateCorrelatedConjunctionRowCount(
input,
estimates,
filterConjunctionIndependenceFactor);
if (isNaN(outputRowCount)) {
return PlanNodeStatsEstimate.unknown();
}
return normalizer.normalize(new PlanNodeStatsEstimate(outputRowCount, intersectCorrelatedStats(estimates)), types);
}

if (estimate.isOutputRowCountUnknown()) {
break;
/**
* There can be multiple predicate expressions for the same symbol, e.g. x > 0 AND x <= 1, x BETWEEN 1 AND 10.
* We attempt to detect such cases in extractCorrelatedGroups and calculate a combined estimate for each
* such group of expressions. This is done so that the independence scaling factor is not applied when combining
* estimates from a conjunction of multiple predicates on the same symbol, which would underestimate the output.
**/
private List<PlanNodeStatsEstimate> estimateCorrelatedExpressions(List<Expression> terms, double filterConjunctionIndependenceFactor)
{
ImmutableList.Builder<PlanNodeStatsEstimate> estimatesBuilder = ImmutableList.builder();
boolean hasUnestimatedTerm = false;
for (List<Expression> correlatedExpressions : extractCorrelatedGroups(terms, filterConjunctionIndependenceFactor)) {
PlanNodeStatsEstimate combinedEstimate = PlanNodeStatsEstimate.unknown();
for (Expression expression : correlatedExpressions) {
PlanNodeStatsEstimate estimate;
// combinedEstimate is unknown until the 1st known estimated term
if (combinedEstimate.isOutputRowCountUnknown()) {
estimate = process(expression);
}
else {
estimate = new FilterExpressionStatsCalculatingVisitor(combinedEstimate, session, types)
.process(expression);
}
}

if (!estimate.isOutputRowCountUnknown()) {
return estimate;
if (estimate.isOutputRowCountUnknown()) {
hasUnestimatedTerm = true;
}
else {
// update combinedEstimate only when the term estimate is known so that all the known estimates
// can be applied progressively through FilterExpressionStatsCalculatingVisitor calls.
combinedEstimate = estimate;
}
}
estimatesBuilder.add(combinedEstimate);
}

// If some of the filters cannot be estimated, take the smallest estimate.
// Apply 0.9 filter factor as "unknown filter" factor.
Optional<PlanNodeStatsEstimate> smallest = terms.stream()
.map(this::process)
.filter(termEstimate -> !termEstimate.isOutputRowCountUnknown())
.sorted(Comparator.comparingDouble(PlanNodeStatsEstimate::getOutputRowCount))
.findFirst();

if (smallest.isEmpty()) {
return PlanNodeStatsEstimate.unknown();
if (hasUnestimatedTerm) {
estimatesBuilder.add(PlanNodeStatsEstimate.unknown());
}

return smallest.get().mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT);
return estimatesBuilder.build();
}

private PlanNodeStatsEstimate estimateLogicalOr(List<Expression> terms)
Expand Down Expand Up @@ -442,4 +479,53 @@ private OptionalDouble doubleValueFromLiteral(Type type, Expression literal)
return toStatsRepresentation(type, literalValue);
}
}

/**
 * Splits the terms of a conjunction into groups of expressions that share symbols.
 * <p>
 * Symbols that appear together in a term are unioned in a {@code DisjointSet}; each term is then
 * assigned to the partition that contains its first symbol. Terms that reference no symbols are
 * collected into one additional group. When {@code filterConjunctionIndependenceFactor} is exactly 1,
 * no grouping is attempted and all terms are returned as a single group.
 *
 * @param terms the conjuncts to group
 * @param filterConjunctionIndependenceFactor independence factor; only the value 1 is special-cased here
 * @return the groups of terms; every entry of {@code terms} is placed in one group
 */
private static List<List<Expression>> extractCorrelatedGroups(List<Expression> terms, double filterConjunctionIndependenceFactor)
{
if (filterConjunctionIndependenceFactor == 1) {
// Allows the filters to be estimated as if there is no correlation between any of the terms
return ImmutableList.of(terms);
}

// Symbols referenced by each term, as reported by extractUnique
ListMultimap<Expression, Symbol> expressionUniqueSymbols = ArrayListMultimap.create();
terms.forEach(expression -> expressionUniqueSymbols.putAll(expression, extractUnique(expression)));
// Partition symbols into disjoint sets such that the symbols belonging to different disjoint sets
// do not appear together in any expression.
DisjointSet<Symbol> symbolsPartitioner = new DisjointSet<>();
for (Expression term : terms) {
List<Symbol> expressionSymbols = expressionUniqueSymbols.get(term);
if (expressionSymbols.isEmpty()) {
continue;
}
// Ensure that symbol is added to DisjointSet when there is only one symbol in the list
symbolsPartitioner.find(expressionSymbols.get(0));
// Union every other symbol of the term with the first one
for (int i = 1; i < expressionSymbols.size(); i++) {
symbolsPartitioner.findAndUnion(expressionSymbols.get(0), expressionSymbols.get(i));
}
}

// Use disjoint sets of symbols to partition the given list of expressions
List<Set<Symbol>> symbolPartitions = ImmutableList.copyOf(symbolsPartitioner.getEquivalentClasses());
checkState(symbolPartitions.size() <= terms.size(), "symbolPartitions size exceeds number of expressions");
// Key is the index of the symbol partition in symbolPartitions; symbolPartitions.size() is
// used as the key for terms without symbols, so it cannot collide with a real partition index.
ListMultimap<Integer, Expression> expressionPartitions = ArrayListMultimap.create();
for (Expression term : terms) {
List<Symbol> expressionSymbols = expressionUniqueSymbols.get(term);
int expressionPartitionId;
if (expressionSymbols.isEmpty()) {
expressionPartitionId = symbolPartitions.size(); // For expressions with no symbols
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: -1 instead?

}
else {
Symbol symbol = expressionSymbols.get(0); // Lookup any symbol to find the partition id
expressionPartitionId = IntStream.range(0, symbolPartitions.size())
.filter(partition -> symbolPartitions.get(partition).contains(symbol))
.findFirst()
.orElseThrow();
}
expressionPartitions.put(expressionPartitionId, term);
}

// One list per partition id, in the iteration order of the multimap's key set
return expressionPartitions.keySet().stream()
.map(expressionPartitions::get)
.collect(toImmutableList());
}
}
Loading