Replace calculateAverageRowsPerPartition with calculatePartitionsRowCount in
MetastoreHiveStatisticsProvider.

The new helper returns a PartitionsRowCount holding both the average rows per
partition and the extrapolated row count for all queried partitions. If at
most two sampled partitions report a row count, or the sample size equals the
number of queried partitions, the plain average is multiplied by the queried
partition count. Otherwise the smallest and largest sampled row counts are
left out of the average and added back once to the extrapolated total, so a
couple of outlier partitions do not skew the estimate.

getTableStatistics reads the row count and the per-partition average from the
returned PartitionsRowCount, and the unit test is renamed and extended to
cover the extrapolation and the outlier handling.

diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/statistics/MetastoreHiveStatisticsProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/statistics/MetastoreHiveStatisticsProvider.java
index 1a248924a3b9b..d98a3562e2bf1 100644
--- a/presto-hive/src/main/java/com/facebook/presto/hive/statistics/MetastoreHiveStatisticsProvider.java
+++ b/presto-hive/src/main/java/com/facebook/presto/hive/statistics/MetastoreHiveStatisticsProvider.java
@@ -50,6 +50,7 @@
 import java.math.BigDecimal;
 import java.time.LocalDate;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Comparator;
 import java.util.List;
@@ -78,6 +79,7 @@
 import static com.facebook.presto.hive.HiveSessionProperties.isStatisticsEnabled;
 import static com.facebook.presto.hive.metastore.MetastoreUtil.getMetastoreHeaders;
 import static com.facebook.presto.hive.metastore.MetastoreUtil.isUserDefinedTypeEncodingEnabled;
+import static com.google.common.base.MoreObjects.toStringHelper;
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Verify.verify;
 import static com.google.common.collect.ImmutableList.toImmutableList;
@@ -397,15 +399,11 @@ private static TableStatistics getTableStatistics(
         }
 
         checkArgument(!partitions.isEmpty(), "partitions is empty");
-
-        OptionalDouble optionalAverageRowsPerPartition = calculateAverageRowsPerPartition(statistics.values());
-        if (!optionalAverageRowsPerPartition.isPresent()) {
+        Optional<PartitionsRowCount> optionalRowCount = calculatePartitionsRowCount(statistics.values(), partitions.size());
+        if (!optionalRowCount.isPresent()) {
             return TableStatistics.empty();
         }
-        double averageRowsPerPartition = optionalAverageRowsPerPartition.getAsDouble();
-        verify(averageRowsPerPartition >= 0, "averageRowsPerPartition must be greater than or equal to zero");
-        int queriedPartitionsCount = partitions.size();
-        double rowCount = averageRowsPerPartition * queriedPartitionsCount;
+        double rowCount = optionalRowCount.get().getRowCount();
 
         TableStatistics.Builder result = TableStatistics.builder();
         result.setRowCount(Estimate.of(rowCount));
@@ -414,7 +412,7 @@ private static TableStatistics getTableStatistics(
         if (optionalAverageSizePerPartition.isPresent()) {
             double averageSizePerPartition = optionalAverageSizePerPartition.getAsDouble();
             verify(averageSizePerPartition >= 0, "averageSizePerPartition must be greater than or equal to zero: %s", averageSizePerPartition);
-            double totalSize = averageSizePerPartition * queriedPartitionsCount;
+            double totalSize = averageSizePerPartition * partitions.size();
             result.setTotalSize(Estimate.of(totalSize));
         }
 
@@ -424,6 +422,7 @@ private static TableStatistics getTableStatistics(
             Type columnType = columnTypes.get(columnName);
             ColumnStatistics columnStatistics;
             if (columnHandle.isPartitionKey()) {
+                double averageRowsPerPartition = optionalRowCount.get().getAverageRowsPerPartition();
                 columnStatistics = createPartitionColumnStatistics(columnHandle, columnType, partitions, statistics, averageRowsPerPartition, rowCount);
             }
             else {
@@ -435,15 +434,98 @@ private static TableStatistics getTableStatistics(
     }
 
     @VisibleForTesting
-    static OptionalDouble calculateAverageRowsPerPartition(Collection<PartitionStatistics> statistics)
+    static Optional<PartitionsRowCount> calculatePartitionsRowCount(Collection<PartitionStatistics> statistics, int queriedPartitionsCount)
     {
-        return statistics.stream()
+        long[] rowCounts = statistics.stream()
                 .map(PartitionStatistics::getBasicStatistics)
                 .map(HiveBasicStatistics::getRowCount)
                 .filter(OptionalLong::isPresent)
                 .mapToLong(OptionalLong::getAsLong)
                 .peek(count -> verify(count >= 0, "count must be greater than or equal to zero"))
-                .average();
+                .toArray();
+        int sampleSize = statistics.size();
+        // Sample contains all the queried partitions, estimate avg normally
+        if (rowCounts.length <= 2 || queriedPartitionsCount == sampleSize) {
+            OptionalDouble averageRowsPerPartitionOptional = Arrays.stream(rowCounts).average();
+            if (!averageRowsPerPartitionOptional.isPresent()) {
+                return Optional.empty();
+            }
+            double averageRowsPerPartition = averageRowsPerPartitionOptional.getAsDouble();
+            return Optional.of(new PartitionsRowCount(averageRowsPerPartition, averageRowsPerPartition * queriedPartitionsCount));
+        }
+
+        // Some partitions (e.g. __HIVE_DEFAULT_PARTITION__) may be outliers in terms of row count.
+        // Excluding the min and max rowCount values from averageRowsPerPartition calculation helps to reduce the
+        // possibility of errors in the extrapolated rowCount due to a couple of outliers.
+        int minIndex = 0;
+        int maxIndex = 0;
+        long rowCountSum = rowCounts[0];
+        for (int index = 1; index < rowCounts.length; index++) {
+            if (rowCounts[index] < rowCounts[minIndex]) {
+                minIndex = index;
+            }
+            else if (rowCounts[index] > rowCounts[maxIndex]) {
+                maxIndex = index;
+            }
+            rowCountSum += rowCounts[index];
+        }
+        double averageWithoutOutliers = ((double) (rowCountSum - rowCounts[minIndex] - rowCounts[maxIndex])) / (rowCounts.length - 2);
+        double rowCount = (averageWithoutOutliers * (queriedPartitionsCount - 2)) + rowCounts[minIndex] + rowCounts[maxIndex];
+        return Optional.of(new PartitionsRowCount(averageWithoutOutliers, rowCount));
+    }
+
+    @VisibleForTesting
+    static class PartitionsRowCount
+    {
+        private final double averageRowsPerPartition;
+        private final double rowCount;
+
+        PartitionsRowCount(double averageRowsPerPartition, double rowCount)
+        {
+            verify(averageRowsPerPartition >= 0, "averageRowsPerPartition must be greater than or equal to zero");
+            verify(rowCount >= 0, "rowCount must be greater than or equal to zero");
+            this.averageRowsPerPartition = averageRowsPerPartition;
+            this.rowCount = rowCount;
+        }
+
+        private double getAverageRowsPerPartition()
+        {
+            return averageRowsPerPartition;
+        }
+
+        private double getRowCount()
+        {
+            return rowCount;
+        }
+
+        @Override
+        public boolean equals(Object o)
+        {
+            if (this == o) {
+                return true;
+            }
+            if (o == null || getClass() != o.getClass()) {
+                return false;
+            }
+            PartitionsRowCount that = (PartitionsRowCount) o;
+            return Double.compare(that.averageRowsPerPartition, averageRowsPerPartition) == 0
+                    && Double.compare(that.rowCount, rowCount) == 0;
+        }
+
+        @Override
+        public int hashCode()
+        {
+            return Objects.hash(averageRowsPerPartition, rowCount);
+        }
+
+        @Override
+        public String toString()
+        {
+            return toStringHelper(this)
+                    .add("averageRowsPerPartition", averageRowsPerPartition)
+                    .add("rowCount", rowCount)
+                    .toString();
+        }
     }
 
     @VisibleForTesting
diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/statistics/TestMetastoreHiveStatisticsProvider.java b/presto-hive/src/test/java/com/facebook/presto/hive/statistics/TestMetastoreHiveStatisticsProvider.java
index 44944c82d02d8..944ce503ae96d 100644
--- a/presto-hive/src/test/java/com/facebook/presto/hive/statistics/TestMetastoreHiveStatisticsProvider.java
+++ b/presto-hive/src/test/java/com/facebook/presto/hive/statistics/TestMetastoreHiveStatisticsProvider.java
@@ -71,7 +71,7 @@
 import static com.facebook.presto.hive.metastore.HiveColumnStatistics.createDecimalColumnStatistics;
 import static com.facebook.presto.hive.metastore.HiveColumnStatistics.createDoubleColumnStatistics;
 import static com.facebook.presto.hive.metastore.HiveColumnStatistics.createIntegerColumnStatistics;
-import static com.facebook.presto.hive.statistics.MetastoreHiveStatisticsProvider.calculateAverageRowsPerPartition;
+import static com.facebook.presto.hive.statistics.MetastoreHiveStatisticsProvider.PartitionsRowCount;
 import static com.facebook.presto.hive.statistics.MetastoreHiveStatisticsProvider.calculateAverageSizePerPartition;
 import static com.facebook.presto.hive.statistics.MetastoreHiveStatisticsProvider.calculateDataSize;
 import static com.facebook.presto.hive.statistics.MetastoreHiveStatisticsProvider.calculateDataSizeForPartitioningKey;
@@ -79,6 +79,7 @@
 import static com.facebook.presto.hive.statistics.MetastoreHiveStatisticsProvider.calculateDistinctValuesCount;
 import static com.facebook.presto.hive.statistics.MetastoreHiveStatisticsProvider.calculateNullsFraction;
 import static com.facebook.presto.hive.statistics.MetastoreHiveStatisticsProvider.calculateNullsFractionForPartitioningKey;
+import static com.facebook.presto.hive.statistics.MetastoreHiveStatisticsProvider.calculatePartitionsRowCount;
 import static com.facebook.presto.hive.statistics.MetastoreHiveStatisticsProvider.calculateRange;
 import static com.facebook.presto.hive.statistics.MetastoreHiveStatisticsProvider.calculateRangeForPartitioningKey;
 import static com.facebook.presto.hive.statistics.MetastoreHiveStatisticsProvider.convertPartitionValueToDouble;
@@ -87,6 +88,7 @@
 import static com.facebook.presto.hive.statistics.MetastoreHiveStatisticsProvider.validatePartitionStatistics;
 import static java.lang.Double.NaN;
 import static java.lang.String.format;
+import static java.util.Collections.nCopies;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.testng.Assert.assertEquals;
@@ -255,15 +257,34 @@ public void testValidatePartitionStatistics()
     }
 
     @Test
-    public void testCalculateAverageRowsPerPartition()
-    {
-        assertThat(calculateAverageRowsPerPartition(ImmutableList.of())).isEmpty();
-        assertThat(calculateAverageRowsPerPartition(ImmutableList.of(PartitionStatistics.empty()))).isEmpty();
-        assertThat(calculateAverageRowsPerPartition(ImmutableList.of(PartitionStatistics.empty(), PartitionStatistics.empty()))).isEmpty();
-        assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10))), OptionalDouble.of(10));
-        assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10), PartitionStatistics.empty())), OptionalDouble.of(10));
-        assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10), rowsCount(20))), OptionalDouble.of(15));
-        assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10), rowsCount(20), PartitionStatistics.empty())), OptionalDouble.of(15));
+    public void testCalculatePartitionsRowCount()
+    {
+        assertThat(calculatePartitionsRowCount(ImmutableList.of(), 0)).isEmpty();
+        assertThat(calculatePartitionsRowCount(ImmutableList.of(PartitionStatistics.empty()), 1)).isEmpty();
+        assertThat(calculatePartitionsRowCount(ImmutableList.of(PartitionStatistics.empty(), PartitionStatistics.empty()), 2)).isEmpty();
+        assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10)), 1))
+                .isEqualTo(Optional.of(new PartitionsRowCount(10, 10)));
+        assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10)), 2))
+                .isEqualTo(Optional.of(new PartitionsRowCount(10, 20)));
+        assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), PartitionStatistics.empty()), 2))
+                .isEqualTo(Optional.of(new PartitionsRowCount(10, 20)));
+        assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(20)), 2))
+                .isEqualTo(Optional.of(new PartitionsRowCount(15, 30)));
+        assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(20)), 3))
+                .isEqualTo(Optional.of(new PartitionsRowCount(15, 45)));
+        assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(20), PartitionStatistics.empty()), 3))
+                .isEqualTo(Optional.of(new PartitionsRowCount(15, 45)));
+
+        assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(100), rowsCount(1000)), 3))
+                .isEqualTo(Optional.of(new PartitionsRowCount((10 + 100 + 1000) / 3.0, 10 + 100 + 1000)));
+        // Exclude outliers from average row count
+        assertThat(calculatePartitionsRowCount(ImmutableList.<PartitionStatistics>builder()
+                        .addAll(nCopies(10, rowsCount(100)))
+                        .add(rowsCount(1))
+                        .add(rowsCount(1000))
+                        .build(),
+                50))
+                .isEqualTo(Optional.of(new PartitionsRowCount(100, (100 * 48) + 1 + 1000)));
     }
 
     @Test