diff --git a/core/trino-main/src/main/java/io/trino/cost/ComparisonStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/ComparisonStatsCalculator.java index 91b8dbac60ba..edcab355271b 100644 --- a/core/trino-main/src/main/java/io/trino/cost/ComparisonStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/ComparisonStatsCalculator.java @@ -22,7 +22,9 @@ import static io.trino.cost.SymbolStatsEstimate.buildFrom; import static io.trino.util.MoreMath.averageExcludingNaNs; import static io.trino.util.MoreMath.max; +import static io.trino.util.MoreMath.maxExcludeNaN; import static io.trino.util.MoreMath.min; +import static io.trino.util.MoreMath.minExcludeNaN; import static java.lang.Double.NEGATIVE_INFINITY; import static java.lang.Double.NaN; import static java.lang.Double.POSITIVE_INFINITY; @@ -31,6 +33,11 @@ public final class ComparisonStatsCalculator { + // We assume uniform distribution of values within each range. + // Within the overlapping range, we assume that all pairs of distinct values from both ranges exist. + // Based on the above, we estimate that half of the pairs of values will match inequality predicate on average. + public static final double OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT = 0.5; + private ComparisonStatsCalculator() {} public static PlanNodeStatsEstimate estimateExpressionToLiteralComparison( @@ -164,6 +171,13 @@ public static PlanNodeStatsEstimate estimateExpressionToExpressionComparison( case LESS_THAN_OR_EQUAL: case GREATER_THAN: case GREATER_THAN_OR_EQUAL: + return estimateExpressionToExpressionInequality( + operator, + inputStatistics, + leftExpressionStatistics, + leftExpressionSymbol, + rightExpressionStatistics, + rightExpressionSymbol); case IS_DISTINCT_FROM: return PlanNodeStatsEstimate.unknown(); } @@ -239,4 +253,128 @@ private static PlanNodeStatsEstimate estimateExpressionNotEqualToExpression( rightExpressionSymbol.ifPresent(symbol -> result.addSymbolStatistics(symbol, rightNullsFiltered)); return result.build(); } + + private static PlanNodeStatsEstimate estimateExpressionToExpressionInequality( + ComparisonExpression.Operator operator, + PlanNodeStatsEstimate inputStatistics, + SymbolStatsEstimate leftExpressionStatistics, + Optional leftExpressionSymbol, + SymbolStatsEstimate rightExpressionStatistics, + Optional rightExpressionSymbol) + { + if (leftExpressionStatistics.isUnknown() || rightExpressionStatistics.isUnknown()) { + return PlanNodeStatsEstimate.unknown(); + } + if (isNaN(leftExpressionStatistics.getNullsFraction()) && isNaN(rightExpressionStatistics.getNullsFraction())) { + return PlanNodeStatsEstimate.unknown(); + } + if (leftExpressionStatistics.statisticRange().isEmpty() || rightExpressionStatistics.statisticRange().isEmpty()) { + return inputStatistics.mapOutputRowCount(rowCount -> 0.0); + } + + // We don't know the correlation between NULLs, so we take the max nullsFraction from the expression statistics + // to make a conservative estimate (nulls are fully correlated) for the NULLs filter factor + double nullsFilterFactor = 1 - maxExcludeNaN(leftExpressionStatistics.getNullsFraction(), rightExpressionStatistics.getNullsFraction()); + switch (operator) { + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + return estimateExpressionLessThanOrEqualToExpression( + inputStatistics, + leftExpressionStatistics, + leftExpressionSymbol, + rightExpressionStatistics, + rightExpressionSymbol, + nullsFilterFactor); + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + return estimateExpressionLessThanOrEqualToExpression( + inputStatistics, + rightExpressionStatistics, + rightExpressionSymbol, + leftExpressionStatistics, + leftExpressionSymbol, + nullsFilterFactor); + default: + throw new IllegalArgumentException("Unsupported inequality operator " + operator); + } + } + + private static PlanNodeStatsEstimate estimateExpressionLessThanOrEqualToExpression( + PlanNodeStatsEstimate inputStatistics, + SymbolStatsEstimate leftExpressionStatistics, + Optional leftExpressionSymbol, + SymbolStatsEstimate rightExpressionStatistics, + Optional rightExpressionSymbol, + double nullsFilterFactor) + { + StatisticRange leftRange = StatisticRange.from(leftExpressionStatistics); + StatisticRange rightRange = StatisticRange.from(rightExpressionStatistics); + // left is always greater than right, no overlap + if (leftRange.getLow() > rightRange.getHigh()) { + return inputStatistics.mapOutputRowCount(rowCount -> 0.0); + } + // left is always lesser than right + if (leftRange.getHigh() < rightRange.getLow()) { + PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics); + leftExpressionSymbol.ifPresent(symbol -> estimate.addSymbolStatistics( + symbol, + leftExpressionStatistics.mapNullsFraction(nullsFraction -> 0.0))); + rightExpressionSymbol.ifPresent(symbol -> estimate.addSymbolStatistics( + symbol, + rightExpressionStatistics.mapNullsFraction(nullsFraction -> 0.0))); + return estimate.setOutputRowCount(inputStatistics.getOutputRowCount() * nullsFilterFactor) + .build(); + } + + PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics); + double leftOverlappingRangeFraction = leftRange.overlapPercentWith(rightRange); + double leftAlwaysLessRangeFraction; + if (leftRange.getLow() < rightRange.getLow()) { + leftAlwaysLessRangeFraction = min( + leftRange.overlapPercentWith(new StatisticRange(leftRange.getLow(), rightRange.getLow(), NaN)), + // Prevents expanding NDVs in case range fractions addition goes beyond 1 for infinite ranges + 1 - leftOverlappingRangeFraction); + } + else { + leftAlwaysLessRangeFraction = 0; + } + leftExpressionSymbol.ifPresent(symbol -> estimate.addSymbolStatistics( + symbol, + SymbolStatsEstimate.builder() + .setLowValue(leftRange.getLow()) + .setHighValue(minExcludeNaN(leftRange.getHigh(), rightRange.getHigh())) + .setAverageRowSize(leftExpressionStatistics.getAverageRowSize()) + .setDistinctValuesCount(leftExpressionStatistics.getDistinctValuesCount() * (leftAlwaysLessRangeFraction + leftOverlappingRangeFraction)) + .setNullsFraction(0) + .build())); + + double rightOverlappingRangeFraction = rightRange.overlapPercentWith(leftRange); + double rightAlwaysGreaterRangeFraction; + if (leftRange.getHigh() < rightRange.getHigh()) { + rightAlwaysGreaterRangeFraction = min( + rightRange.overlapPercentWith(new StatisticRange(leftRange.getHigh(), rightRange.getHigh(), NaN)), + // Prevents expanding NDVs in case range fractions addition goes beyond 1 for infinite ranges + 1 - rightOverlappingRangeFraction); + } + else { + rightAlwaysGreaterRangeFraction = 0; + } + rightExpressionSymbol.ifPresent(symbol -> estimate.addSymbolStatistics( + symbol, + SymbolStatsEstimate.builder() + .setLowValue(maxExcludeNaN(leftRange.getLow(), rightRange.getLow())) + .setHighValue(rightRange.getHigh()) + .setAverageRowSize(rightExpressionStatistics.getAverageRowSize()) + .setDistinctValuesCount(rightExpressionStatistics.getDistinctValuesCount() * (rightOverlappingRangeFraction + rightAlwaysGreaterRangeFraction)) + .setNullsFraction(0) + .build())); + double filterFactor = + // all left range values which are below right range are selected + leftAlwaysLessRangeFraction + + // for pairs in overlapping range, only half of pairs are selected + leftOverlappingRangeFraction * rightOverlappingRangeFraction * OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT + + // all pairs where left value is in overlapping range and right value is above left range are selected + leftOverlappingRangeFraction * rightAlwaysGreaterRangeFraction; + return estimate.setOutputRowCount(inputStatistics.getOutputRowCount() * nullsFilterFactor * filterFactor).build(); + } } diff --git a/core/trino-main/src/test/java/io/trino/cost/TestComparisonStatsCalculator.java b/core/trino-main/src/test/java/io/trino/cost/TestComparisonStatsCalculator.java index 833be89b7b3c..8ac103ff9cb5 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestComparisonStatsCalculator.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestComparisonStatsCalculator.java @@ -36,6 +36,7 @@ import java.util.Objects; import java.util.function.Consumer; +import static io.trino.cost.ComparisonStatsCalculator.OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; @@ -43,7 +44,9 @@ import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL; import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN; +import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.tree.ComparisonExpression.Operator.NOT_EQUAL; import static io.trino.testing.TestingSession.testSessionBuilder; import static java.lang.Double.NEGATIVE_INFINITY; @@ -69,6 +72,7 @@ public class TestComparisonStatsCalculator private SymbolStatsEstimate rightOpenStats; private SymbolStatsEstimate unknownRangeStats; private SymbolStatsEstimate emptyRangeStats; + private SymbolStatsEstimate unknownNdvRangeStats; private SymbolStatsEstimate varcharStats; @BeforeClass @@ -140,6 +144,13 @@ public void setUp() .setHighValue(NaN) .setNullsFraction(1.0) .build(); + unknownNdvRangeStats = SymbolStatsEstimate.builder() + .setAverageRowSize(4.0) + .setDistinctValuesCount(NaN) + .setLowValue(0) + .setHighValue(10) + .setNullsFraction(0.1) + .build(); varcharStats = SymbolStatsEstimate.builder() .setAverageRowSize(4.0) .setDistinctValuesCount(50.0) @@ -157,6 +168,7 @@ public void setUp() .addSymbolStatistics(new Symbol("rightOpen"), rightOpenStats) .addSymbolStatistics(new Symbol("unknownRange"), unknownRangeStats) .addSymbolStatistics(new Symbol("emptyRange"), emptyRangeStats) + .addSymbolStatistics(new Symbol("unknownNdvRange"), unknownNdvRangeStats) .addSymbolStatistics(new Symbol("varchar"), varcharStats) .setOutputRowCount(1000.0) .build(); @@ -171,6 +183,7 @@ public void setUp() .put(new Symbol("rightOpen"), DoubleType.DOUBLE) .put(new Symbol("unknownRange"), DoubleType.DOUBLE) .put(new Symbol("emptyRange"), DoubleType.DOUBLE) + .put(new Symbol("unknownNdvRange"), DoubleType.DOUBLE) .put(new Symbol("varchar"), VarcharType.createVarcharType(10)) .buildOrThrow()); } @@ -695,6 +708,237 @@ public void symbolToCastExpressionNotEqual() .symbolStats("z", equalTo(capNDV(zStats, rowCount))); } + @Test + public void symbolToSymbolInequalityStats() + { + double inputRowCount = standardInputStatistics.getOutputRowCount(); + // z's stats should be unchanged when not involved, except NDV capping to row count + + double nullsFractionX = 0.25; + double rowCount = inputRowCount * (1 - nullsFractionX); + // Same symbol on both sides of inequality, gets simplified to x IS NOT NULL + assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("x"), new SymbolReference("x"))) + .outputRowsCount(rowCount) + .symbolStats("x", equalTo(capNDV(zeroNullsFraction(xStats), rowCount))); + + double nullsFractionU = 0.1; + double nonNullRowCount = inputRowCount * (1 - nullsFractionU); + rowCount = nonNullRowCount * OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT; + // Equal ranges + assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("u"), new SymbolReference("w"))) + .outputRowsCount(rowCount) + .symbolStats("u", equalTo(capNDV(zeroNullsFraction(uStats), rowCount))) + .symbolStats("w", equalTo(capNDV(zeroNullsFraction(wStats), rowCount))) + .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + + double overlappingFractionX = 0.25; + double alwaysLesserFractionX = 0.5; + double nullsFractionY = 0.5; + nonNullRowCount = inputRowCount * (1 - nullsFractionY); + rowCount = nonNullRowCount * (alwaysLesserFractionX + (overlappingFractionX * OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT)); + // One symbol's range is within the other's + assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("x"), new SymbolReference("y"))) + .outputRowsCount(rowCount) + .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(-10) + .highValue(5) + .distinctValuesCount(30) + .nullsFraction(0)) + .symbolStats("y", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(0) + .highValue(5) + .distinctValuesCount(20) + .nullsFraction(0)) + .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + assertCalculate(new ComparisonExpression(LESS_THAN_OR_EQUAL, new SymbolReference("x"), new SymbolReference("y"))) + .outputRowsCount(rowCount) + .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(-10) + .highValue(5) + .distinctValuesCount(30) + .nullsFraction(0)) + .symbolStats("y", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(0) + .highValue(5) + .distinctValuesCount(20) + .nullsFraction(0)) + .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + // Flip symbols to be on opposite sides + assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("y"), new SymbolReference("x"))) + .outputRowsCount(rowCount) + .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(-10) + .highValue(5) + .distinctValuesCount(30) + .nullsFraction(0)) + .symbolStats("y", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(0) + .highValue(5) + .distinctValuesCount(20) + .nullsFraction(0)) + .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + + double alwaysGreaterFractionX = 0.25; + rowCount = nonNullRowCount * (alwaysGreaterFractionX + overlappingFractionX * OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT); + assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("x"), new SymbolReference("y"))) + .outputRowsCount(rowCount) + .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(0) + .highValue(10) + .distinctValuesCount(20) + .nullsFraction(0)) + .symbolStats("y", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(0) + .highValue(5) + .distinctValuesCount(20) + .nullsFraction(0)) + .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + assertCalculate(new ComparisonExpression(GREATER_THAN_OR_EQUAL, new SymbolReference("x"), new SymbolReference("y"))) + .outputRowsCount(rowCount) + .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(0) + .highValue(10) + .distinctValuesCount(20) + .nullsFraction(0)) + .symbolStats("y", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(0) + .highValue(5) + .distinctValuesCount(20) + .nullsFraction(0)) + .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + // Flip symbols to be on opposite sides + assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("y"), new SymbolReference("x"))) + .outputRowsCount(rowCount) + .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(0) + .highValue(10) + .distinctValuesCount(20) + .nullsFraction(0)) + .symbolStats("y", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(0) + .highValue(5) + .distinctValuesCount(20) + .nullsFraction(0)) + .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + + // Partially overlapping ranges + overlappingFractionX = 0.5; + nonNullRowCount = inputRowCount * (1 - nullsFractionX); + double overlappingFractionW = 0.5; + double alwaysGreaterFractionW = 0.5; + rowCount = nonNullRowCount * (alwaysLesserFractionX + + overlappingFractionX * (overlappingFractionW * OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT + alwaysGreaterFractionW)); + assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("x"), new SymbolReference("w"))) + .outputRowsCount(rowCount) + .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(-10) + .highValue(10) + .distinctValuesCount(40) + .nullsFraction(0)) + .symbolStats("w", symbolAssert -> symbolAssert.averageRowSize(8) + .lowValue(0) + .highValue(20) + .distinctValuesCount(30) + .nullsFraction(0)) + .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + // Flip symbols to be on opposite sides + assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("w"), new SymbolReference("x"))) + .outputRowsCount(rowCount) + .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(-10) + .highValue(10) + .distinctValuesCount(40) + .nullsFraction(0)) + .symbolStats("w", symbolAssert -> symbolAssert.averageRowSize(8) + .lowValue(0) + .highValue(20) + .distinctValuesCount(30) + .nullsFraction(0)) + .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + + rowCount = nonNullRowCount * (overlappingFractionX * overlappingFractionW * OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT); + assertCalculate(new ComparisonExpression(GREATER_THAN, new SymbolReference("x"), new SymbolReference("w"))) + .outputRowsCount(rowCount) + .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(0) + .highValue(10) + .distinctValuesCount(20) + .nullsFraction(0)) + .symbolStats("w", symbolAssert -> symbolAssert.averageRowSize(8) + .lowValue(0) + .highValue(10) + .distinctValuesCount(15) + .nullsFraction(0)) + .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + // Flip symbols to be on opposite sides + assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("w"), new SymbolReference("x"))) + .outputRowsCount(rowCount) + .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(0) + .highValue(10) + .distinctValuesCount(20) + .nullsFraction(0)) + .symbolStats("w", symbolAssert -> symbolAssert.averageRowSize(8) + .lowValue(0) + .highValue(10) + .distinctValuesCount(15) + .nullsFraction(0)) + .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + + // Open ranges + double nullsFractionLeft = 0.1; + nonNullRowCount = inputRowCount * (1 - nullsFractionLeft); + double overlappingFractionLeft = 0.25; + double alwaysLesserFractionLeft = 0.5; + double overlappingFractionRight = 0.25; + double alwaysGreaterFractionRight = 0.5; + rowCount = nonNullRowCount * (alwaysLesserFractionLeft + overlappingFractionLeft + * (overlappingFractionRight * OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT + alwaysGreaterFractionRight)); + assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("leftOpen"), new SymbolReference("rightOpen"))) + .outputRowsCount(rowCount) + .symbolStats("leftOpen", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(NEGATIVE_INFINITY) + .highValue(15) + .distinctValuesCount(37.5) + .nullsFraction(0)) + .symbolStats("rightOpen", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(-15) + .highValue(POSITIVE_INFINITY) + .distinctValuesCount(37.5) + .nullsFraction(0)) + .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + + rowCount = nonNullRowCount * (alwaysLesserFractionLeft + overlappingFractionLeft * OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT); + assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("leftOpen"), new SymbolReference("unknownNdvRange"))) + .outputRowsCount(rowCount) + .symbolStats("leftOpen", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(NEGATIVE_INFINITY) + .highValue(10) + .distinctValuesCount(37.5) + .nullsFraction(0)) + .symbolStats("unknownNdvRange", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(0) + .highValue(10) + .distinctValuesCount(NaN) + .nullsFraction(0)) + .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + + rowCount = nonNullRowCount * OVERLAPPING_RANGE_INEQUALITY_FILTER_COEFFICIENT; + assertCalculate(new ComparisonExpression(LESS_THAN, new SymbolReference("leftOpen"), new SymbolReference("unknownRange"))) + .outputRowsCount(rowCount) + .symbolStats("leftOpen", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(NEGATIVE_INFINITY) + .highValue(15) + .distinctValuesCount(50) + .nullsFraction(0)) + .symbolStats("unknownRange", symbolAssert -> symbolAssert.averageRowSize(4) + .lowValue(NEGATIVE_INFINITY) + .highValue(POSITIVE_INFINITY) + .distinctValuesCount(50) + .nullsFraction(0)) + .symbolStats("z", equalTo(capNDV(zStats, rowCount))); + } + private static void checkConsistent(StatsNormalizer normalizer, String source, PlanNodeStatsEstimate stats, Collection outputSymbols, TypeProvider types) { PlanNodeStatsEstimate normalized = normalizer.normalize(stats, outputSymbols, types); diff --git a/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java b/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java index b814aa82a3b7..679f8e2e775a 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java @@ -205,6 +205,34 @@ public void testComparison() } } + @Test + public void testInequalityComparisonApproximation() + { + assertExpression("x > emptyRange").outputRowsCount(0); + + assertExpression("x > y + 20").outputRowsCount(0); + assertExpression("x >= y + 20").outputRowsCount(0); + assertExpression("x < y - 25").outputRowsCount(0); + assertExpression("x <= y - 25").outputRowsCount(0); + + double nullsFractionY = 0.5; + double inputRowCount = standardInputStatistics.getOutputRowCount(); + double nonNullRowCount = inputRowCount * (1 - nullsFractionY); + SymbolStatsEstimate nonNullStatsX = xStats.mapNullsFraction(nullsFraction -> 0.0); + assertExpression("x > y - 25") + .outputRowsCount(nonNullRowCount) + .symbolStats("x", symbolAssert -> symbolAssert.isEqualTo(nonNullStatsX)); + assertExpression("x >= y - 25") + .outputRowsCount(nonNullRowCount) + .symbolStats("x", symbolAssert -> symbolAssert.isEqualTo(nonNullStatsX)); + assertExpression("x < y + 20") + .outputRowsCount(nonNullRowCount) + .symbolStats("x", symbolAssert -> symbolAssert.isEqualTo(nonNullStatsX)); + assertExpression("x <= y + 20") + .outputRowsCount(nonNullRowCount) + .symbolStats("x", symbolAssert -> symbolAssert.isEqualTo(nonNullStatsX)); + } + @Test public void testOrStats() { diff --git a/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q72.plan.txt b/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q72.plan.txt index e867ccf1c95b..92fc29d6d713 100644 --- a/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q72.plan.txt +++ b/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q72.plan.txt @@ -7,40 +7,40 @@ local exchange (GATHER, SINGLE, []) join (LEFT, PARTITIONED): remote exchange (REPARTITION, HASH, ["cs_order_number", "inv_item_sk"]) join (LEFT, REPLICATED): - join (INNER, PARTITIONED): - remote exchange (REPARTITION, HASH, ["d_week_seq_4", "inv_item_sk"]) - join (INNER, REPLICATED): - join (INNER, REPLICATED): - scan inventory - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan date_dim - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan warehouse - local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, ["cs_item_sk", "d_week_seq"]) - join (INNER, REPLICATED): + join (INNER, REPLICATED): + join (INNER, REPLICATED): + join (INNER, PARTITIONED): + remote exchange (REPARTITION, HASH, ["d_week_seq_4", "inv_item_sk"]) join (INNER, REPLICATED): + scan inventory + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan date_dim + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, ["cs_item_sk", "d_week_seq"]) join (INNER, REPLICATED): join (INNER, REPLICATED): join (INNER, REPLICATED): - scan catalog_sales + join (INNER, REPLICATED): + scan catalog_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan customer_demographics local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - scan customer_demographics + scan household_demographics local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - scan household_demographics + scan date_dim local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) scan date_dim - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan date_dim - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan item + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan warehouse + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan item local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) scan promotion diff --git a/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpch/q04.plan.txt b/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpch/q04.plan.txt index 1fc783a9926b..90bd90b44639 100644 --- a/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpch/q04.plan.txt +++ b/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpch/q04.plan.txt @@ -6,10 +6,11 @@ remote exchange (GATHER, SINGLE, []) remote exchange (REPARTITION, HASH, ["orderpriority"]) partial aggregation over (orderpriority) join (INNER, PARTITIONED): - remote exchange (REPARTITION, HASH, ["orderkey"]) - scan orders final aggregation over (orderkey_1) local exchange (GATHER, SINGLE, []) remote exchange (REPARTITION, HASH, ["orderkey_1"]) partial aggregation over (orderkey_1) scan lineitem + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, ["orderkey"]) + scan orders diff --git a/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpch/q21.plan.txt b/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpch/q21.plan.txt index db8906cd1930..01e14a57f294 100644 --- a/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpch/q21.plan.txt +++ b/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpch/q21.plan.txt @@ -4,30 +4,29 @@ local exchange (GATHER, SINGLE, []) local exchange (GATHER, SINGLE, []) remote exchange (REPARTITION, HASH, ["name"]) partial aggregation over (name) - single aggregation over (commitdate, exists, name, name_8, nationkey, orderkey, orderstatus, receiptdate, suppkey, unique) + single aggregation over (commitdate, exists, name, name_8, nationkey, orderkey, orderstatus, receiptdate, suppkey_1, unique) join (LEFT, PARTITIONED): - final aggregation over (commitdate, name, name_8, nationkey, orderkey, orderstatus, receiptdate, suppkey, unique_49) + final aggregation over (commitdate, name, name_8, nationkey, orderkey, orderstatus, receiptdate, suppkey_1, unique_49) local exchange (GATHER, SINGLE, []) - partial aggregation over (commitdate, name, name_8, nationkey, orderkey, orderstatus, receiptdate, suppkey, unique_49) - join (LEFT, PARTITIONED): - join (INNER, REPLICATED): + partial aggregation over (commitdate, name, name_8, nationkey, orderkey, orderstatus, receiptdate, suppkey_1, unique_49) + join (RIGHT, PARTITIONED): + remote exchange (REPARTITION, HASH, ["orderkey_11"]) + scan lineitem + local exchange (GATHER, SINGLE, []) join (INNER, PARTITIONED): remote exchange (REPARTITION, HASH, ["orderkey"]) - join (INNER, PARTITIONED): - remote exchange (REPARTITION, HASH, ["suppkey"]) - scan supplier + join (INNER, REPLICATED): + scan lineitem local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, ["suppkey_1"]) - scan lineitem + remote exchange (REPLICATE, BROADCAST, []) + join (INNER, REPLICATED): + scan supplier + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan nation local exchange (GATHER, SINGLE, []) remote exchange (REPARTITION, HASH, ["orderkey_4"]) scan orders - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan nation - local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, ["orderkey_11"]) - scan lineitem local exchange (GATHER, SINGLE, []) remote exchange (REPARTITION, HASH, ["orderkey_29"]) scan lineitem diff --git a/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchDistributedStats.java b/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchDistributedStats.java index 78b9c85717a0..b4f207cb2ec9 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchDistributedStats.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchDistributedStats.java @@ -78,6 +78,18 @@ public void testFilter() statisticsAssertion.check("SELECT l_orderkey FROM lineitem GROUP BY l_orderkey HAVING sum(l_quantity) > 30", checks -> checks.estimate(OUTPUT_ROW_COUNT, defaultTolerance())); + + statisticsAssertion.check("SELECT * FROM lineitem WHERE l_receiptdate > l_commitdate", + checks -> checks.estimate(OUTPUT_ROW_COUNT, relativeError(-0.2, -0.18))); + + statisticsAssertion.check("SELECT * FROM lineitem WHERE l_receiptdate >= l_commitdate", + checks -> checks.estimate(OUTPUT_ROW_COUNT, relativeError(-0.23, -0.2))); + + statisticsAssertion.check("SELECT * FROM lineitem WHERE l_receiptdate < l_commitdate", + checks -> checks.estimate(OUTPUT_ROW_COUNT, relativeError(0.35, 0.38))); + + statisticsAssertion.check("SELECT * FROM lineitem WHERE l_receiptdate <= l_commitdate", + checks -> checks.estimate(OUTPUT_ROW_COUNT, relativeError(0.3, 0.35))); } @Test diff --git a/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchLocalStats.java b/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchLocalStats.java index 88d19e477ad6..2fd31e373b70 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchLocalStats.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchLocalStats.java @@ -238,12 +238,12 @@ public void testLeftJoinStats() // simple non-equi join statisticsAssertion.check("SELECT * FROM partsupp LEFT JOIN lineitem ON ps_partkey = l_partkey AND ps_suppkey < l_suppkey", checks -> checks - .estimate(OUTPUT_ROW_COUNT, relativeError(4.0)) + .estimate(OUTPUT_ROW_COUNT, relativeError(0.3, 0.4)) .verifyExactColumnStatistics("ps_partkey") - .verifyColumnStatistics("l_partkey", relativeError(0.10)) + .verifyColumnStatistics("l_partkey", relativeError(0.7)) .verifyExactColumnStatistics("ps_suppkey") - .verifyColumnStatistics("l_suppkey", relativeError(1.0)) - .verifyColumnStatistics("l_orderkey", relativeError(0.10))); + .verifyColumnStatistics("l_suppkey", relativeError(0.7)) + .verifyColumnStatistics("l_orderkey", relativeError(0.7))); } @Test @@ -293,12 +293,12 @@ public void testRightJoinStats() // simple non-equi join statisticsAssertion.check("SELECT * FROM lineitem RIGHT JOIN partsupp ON ps_partkey = l_partkey AND ps_suppkey < l_suppkey", checks -> checks - .estimate(OUTPUT_ROW_COUNT, relativeError(4.0)) + .estimate(OUTPUT_ROW_COUNT, relativeError(0.3, 0.4)) .verifyExactColumnStatistics("ps_partkey") - .verifyColumnStatistics("l_partkey", relativeError(0.10)) + .verifyColumnStatistics("l_partkey", relativeError(0.7)) .verifyExactColumnStatistics("ps_suppkey") - .verifyColumnStatistics("l_suppkey", relativeError(1.0)) - .verifyColumnStatistics("l_orderkey", relativeError(0.10))); + .verifyColumnStatistics("l_suppkey", relativeError(0.7)) + .verifyColumnStatistics("l_orderkey", relativeError(0.7))); } @Test @@ -343,12 +343,12 @@ public void testFullJoinStats() // simple non-equi join statisticsAssertion.check("SELECT * FROM lineitem FULL JOIN partsupp ON ps_partkey = l_partkey AND ps_suppkey < l_suppkey", checks -> checks - .estimate(OUTPUT_ROW_COUNT, relativeError(4.0)) - .verifyColumnStatistics("ps_partkey", relativeError(0.10)) - .verifyColumnStatistics("l_partkey", relativeError(0.10)) - .verifyColumnStatistics("ps_suppkey", relativeError(0.10)) - .verifyColumnStatistics("l_suppkey", relativeError(1.0)) - .verifyColumnStatistics("l_orderkey", relativeError(0.10))); + .estimate(OUTPUT_ROW_COUNT, relativeError(0.4, 0.5)) + .verifyColumnStatistics("ps_partkey", relativeError(0.6)) + .verifyColumnStatistics("l_partkey", relativeError(0.6)) + .verifyColumnStatistics("ps_suppkey", relativeError(0.6)) + .verifyColumnStatistics("l_suppkey", relativeError(0.6)) + .verifyColumnStatistics("l_orderkey", relativeError(0.6))); } @Test