diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/InternalDeltaLakeConnectorFactory.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/InternalDeltaLakeConnectorFactory.java
index 2f39dbbd64ab..254cd158bc73 100644
--- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/InternalDeltaLakeConnectorFactory.java
+++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/InternalDeltaLakeConnectorFactory.java
@@ -38,6 +38,7 @@
 import io.trino.plugin.base.jmx.MBeanServerModule;
 import io.trino.plugin.base.session.SessionPropertiesProvider;
 import io.trino.plugin.deltalake.metastore.DeltaLakeMetastoreModule;
+import io.trino.plugin.hive.HiveConfig;
 import io.trino.plugin.hive.NodeVersion;
 import io.trino.spi.NodeManager;
 import io.trino.spi.PageIndexerFactory;
@@ -61,6 +62,7 @@
 import java.util.Optional;
 import java.util.Set;
 
+import static com.google.common.base.Preconditions.checkState;
 import static com.google.common.collect.ImmutableSet.toImmutableSet;
 import static com.google.inject.multibindings.Multibinder.newSetBinder;
 
@@ -135,6 +137,8 @@ public static Connector createConnector(
         Set<ConnectorTableFunction> connectorTableFunctions = injector.getInstance(Key.get(new TypeLiteral<Set<ConnectorTableFunction>>() {}));
         FunctionProvider functionProvider = injector.getInstance(FunctionProvider.class);
 
+        checkState(!injector.getBindings().containsKey(Key.get(HiveConfig.class)), "HiveConfig should not be bound");
+
         return new DeltaLakeConnector(
                 injector,
                 lifeCycleManager,
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java
index eaa423967fbb..db6cb8640eb4 100644
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java
+++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java
@@ -1585,6 +1585,8 @@ public ConnectorTableHandle beginStatisticsCollection(ConnectorSession session,
     @Override
     public void finishStatisticsCollection(ConnectorSession session, ConnectorTableHandle tableHandle, Collection<ComputedStatistics> computedStatistics)
     {
+        verify(isStatisticsEnabled(session), "statistics not enabled");
+
         HiveTableHandle handle = (HiveTableHandle) tableHandle;
         SchemaTableName tableName = handle.getSchemaTableName();
         Table table = metastore.getTable(tableName.getSchemaName(), tableName.getTableName())
@@ -3453,6 +3455,9 @@ public TableStatisticsMetadata getStatisticsCollectionMetadataForWrite(Connector
         if (!isCollectColumnStatisticsOnWrite(session)) {
             return TableStatisticsMetadata.empty();
         }
+        if (!isStatisticsEnabled(session)) {
+            throw new TrinoException(NOT_SUPPORTED, "Table statistics must be enabled when column statistics collection on write is enabled");
+        }
         if (isTransactional(tableMetadata.getProperties()).orElse(false)) {
             // TODO(https://github.com/trinodb/trino/issues/1956) updating table statistics for transactional not supported right now.
             return TableStatisticsMetadata.empty();
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/DisabledGlueColumnStatisticsProvider.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/DisabledGlueColumnStatisticsProvider.java
deleted file mode 100644
index a0b4024fcb3a..000000000000
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/DisabledGlueColumnStatisticsProvider.java
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package io.trino.plugin.hive.metastore.glue;
-
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSet;
-import io.trino.plugin.hive.HiveColumnStatisticType;
-import io.trino.plugin.hive.metastore.HiveColumnStatistics;
-import io.trino.plugin.hive.metastore.Partition;
-import io.trino.plugin.hive.metastore.Table;
-import io.trino.spi.TrinoException;
-import io.trino.spi.type.Type;
-
-import java.util.Collection;
-import java.util.Map;
-import java.util.Set;
-
-import static com.google.common.collect.ImmutableMap.toImmutableMap;
-import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
-import static java.util.function.UnaryOperator.identity;
-
-public class DisabledGlueColumnStatisticsProvider
-        implements GlueColumnStatisticsProvider
-{
-    @Override
-    public Set<HiveColumnStatisticType> getSupportedColumnStatistics(Type type)
-    {
-        return ImmutableSet.of();
-    }
-
-    @Override
-    public Map<String, HiveColumnStatistics> getTableColumnStatistics(Table table)
-    {
-        return ImmutableMap.of();
-    }
-
-    @Override
-    public Map<Partition, Map<String, HiveColumnStatistics>> getPartitionColumnStatistics(Collection<Partition> partitions)
-    {
-        return partitions.stream().collect(toImmutableMap(identity(), partition -> ImmutableMap.of()));
-    }
-
-    @Override
-    public void updateTableColumnStatistics(Table table, Map<String, HiveColumnStatistics> columnStatistics)
-    {
-        if (!columnStatistics.isEmpty()) {
-            throw new TrinoException(NOT_SUPPORTED, "Glue metastore column level statistics are disabled");
-        }
-    }
-
-    @Override
-    public void updatePartitionStatistics(Set<PartitionStatisticsUpdate> partitionStatisticsUpdates)
-    {
-        if (partitionStatisticsUpdates.stream().anyMatch(update -> !update.getColumnStatistics().isEmpty())) {
-            throw new TrinoException(NOT_SUPPORTED, "Glue metastore column level statistics are disabled");
-        }
-    }
-}
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/DisabledGlueColumnStatisticsProviderFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/DisabledGlueColumnStatisticsProviderFactory.java
deleted file mode 100644
index 6a06aa5bc33f..000000000000
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/DisabledGlueColumnStatisticsProviderFactory.java
+++ /dev/null
@@ -1,26 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package io.trino.plugin.hive.metastore.glue;
-
-import com.amazonaws.services.glue.AWSGlueAsync;
-
-public class DisabledGlueColumnStatisticsProviderFactory
-        implements GlueColumnStatisticsProviderFactory
-{
-    @Override
-    public GlueColumnStatisticsProvider createGlueColumnStatisticsProvider(AWSGlueAsync glueClient, GlueMetastoreStats stats)
-    {
-        return new DisabledGlueColumnStatisticsProvider();
-    }
-}
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/GlueMetastoreModule.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/GlueMetastoreModule.java
index 77fcf7043200..ea4e4cb1e6da 100644
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/GlueMetastoreModule.java
+++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/metastore/glue/GlueMetastoreModule.java
@@ -19,7 +19,6 @@
 import com.amazonaws.services.glue.model.Table;
 import com.google.inject.Binder;
 import com.google.inject.Key;
-import com.google.inject.Module;
 import com.google.inject.Provides;
 import com.google.inject.Scopes;
 import com.google.inject.Singleton;
@@ -31,7 +30,6 @@
 import io.opentelemetry.api.OpenTelemetry;
 import io.opentelemetry.instrumentation.awssdk.v1_11.AwsSdkTelemetry;
 import io.trino.plugin.hive.AllowHiveTableRename;
-import io.trino.plugin.hive.HiveConfig;
 import io.trino.plugin.hive.metastore.HiveMetastoreFactory;
 import io.trino.plugin.hive.metastore.RawHiveMetastoreFactory;
 
@@ -42,8 +40,6 @@
 import static com.google.inject.multibindings.Multibinder.newSetBinder;
 import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder;
 import static io.airlift.concurrent.Threads.daemonThreadsNamed;
-import static io.airlift.configuration.ConditionalModule.conditionalModule;
-import static io.airlift.configuration.ConfigBinder.configBinder;
 import static java.util.concurrent.Executors.newCachedThreadPool;
 import static org.weakref.jmx.guice.ExportBinder.newExporter;
 
@@ -58,7 +54,6 @@ protected void setup(Binder binder)
         glueConfig.getCatalogId().ifPresent(catalogId -> requestHandlers.addBinding().toInstance(new GlueCatalogIdRequestHandler(catalogId)));
         glueConfig.getGlueProxyApiId().ifPresent(glueProxyApiId -> requestHandlers.addBinding()
                 .toInstance(new ProxyApiRequestHandler(glueProxyApiId)));
-        configBinder(binder).bindConfig(HiveConfig.class);
         binder.bind(AWSCredentialsProvider.class).toProvider(GlueCredentialsProvider.class).in(Scopes.SINGLETON);
 
         newOptionalBinder(binder, Key.get(new TypeLiteral<Predicate<Table>>() {}, ForGlueHiveMetastore.class))
@@ -78,19 +73,8 @@ protected void setup(Binder binder)
         binder.bind(Key.get(boolean.class, AllowHiveTableRename.class)).toInstance(false);
 
-        install(conditionalModule(
-                HiveConfig.class,
-                HiveConfig::isTableStatisticsEnabled,
-                getGlueStatisticsModule(DefaultGlueColumnStatisticsProviderFactory.class),
-                getGlueStatisticsModule(DisabledGlueColumnStatisticsProviderFactory.class)));
-    }
-
-    private Module getGlueStatisticsModule(Class<? extends GlueColumnStatisticsProviderFactory> statisticsPrividerFactoryClass)
-    {
-        return internalBinder -> newOptionalBinder(internalBinder, GlueColumnStatisticsProviderFactory.class)
-                .setDefault()
-                .to(statisticsPrividerFactoryClass)
-                .in(Scopes.SINGLETON);
+        newOptionalBinder(binder, GlueColumnStatisticsProviderFactory.class)
+                .setDefault().to(DefaultGlueColumnStatisticsProviderFactory.class).in(Scopes.SINGLETON);
     }
 
     @ProvidesIntoSet
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/statistics/AbstractHiveStatisticsProvider.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/statistics/AbstractHiveStatisticsProvider.java
new file mode 100644
index 000000000000..bfaa9129f620
--- /dev/null
+++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/statistics/AbstractHiveStatisticsProvider.java
@@ -0,0 +1,891 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.trino.plugin.hive.statistics;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.VerifyException;
+import com.google.common.hash.HashFunction;
+import com.google.common.primitives.Ints;
+import com.google.common.primitives.Shorts;
+import com.google.common.primitives.SignedBytes;
+import io.airlift.log.Logger;
+import io.airlift.slice.Slice;
+import io.trino.plugin.hive.HiveBasicStatistics;
+import io.trino.plugin.hive.HiveColumnHandle;
+import io.trino.plugin.hive.HivePartition;
+import io.trino.plugin.hive.PartitionStatistics;
+import io.trino.plugin.hive.metastore.DateStatistics;
+import io.trino.plugin.hive.metastore.DecimalStatistics;
+import io.trino.plugin.hive.metastore.DoubleStatistics;
+import io.trino.plugin.hive.metastore.HiveColumnStatistics;
+import io.trino.plugin.hive.metastore.IntegerStatistics;
+import io.trino.spi.TrinoException;
+import io.trino.spi.connector.ColumnHandle;
+import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SchemaTableName;
+import io.trino.spi.predicate.NullableValue;
+import io.trino.spi.statistics.ColumnStatistics;
+import io.trino.spi.statistics.DoubleRange;
+import io.trino.spi.statistics.Estimate;
+import io.trino.spi.statistics.TableStatistics;
+import io.trino.spi.type.CharType;
+import io.trino.spi.type.DecimalType;
+import io.trino.spi.type.Type;
+import io.trino.spi.type.VarcharType;
+
+import java.math.BigDecimal;
+import java.time.LocalDate;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.OptionalDouble;
+import java.util.OptionalLong;
+import java.util.Set;
+import java.util.stream.DoubleStream;
+
+import static com.google.common.base.MoreObjects.toStringHelper;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.base.Verify.verify;
+import static com.google.common.base.Verify.verifyNotNull;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static com.google.common.collect.Maps.immutableEntry;
+import static com.google.common.hash.Hashing.murmur3_128;
+import static io.trino.plugin.hive.HiveErrorCode.HIVE_CORRUPTED_COLUMN_STATISTICS;
+import static io.trino.plugin.hive.HiveSessionProperties.getPartitionStatisticsSampleSize;
+import static io.trino.plugin.hive.HiveSessionProperties.isIgnoreCorruptedStatistics;
+import static io.trino.plugin.hive.HiveSessionProperties.isStatisticsEnabled;
+import static io.trino.spi.statistics.StatsUtil.toStatsRepresentation; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.DateType.DATE; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.SmallintType.SMALLINT; +import static io.trino.spi.type.TinyintType.TINYINT; +import static java.lang.Double.isFinite; +import static java.lang.Double.isNaN; +import static java.lang.String.format; +import static java.util.Collections.unmodifiableList; + +public abstract class AbstractHiveStatisticsProvider + implements HiveStatisticsProvider +{ + private static final Logger log = Logger.get(AbstractHiveStatisticsProvider.class); + + @Override + public TableStatistics getTableStatistics( + ConnectorSession session, + SchemaTableName table, + Map columns, + Map columnTypes, + List partitions) + { + if (!isStatisticsEnabled(session)) { + return TableStatistics.empty(); + } + if (partitions.isEmpty()) { + return createZeroStatistics(columns, columnTypes); + } + int sampleSize = getPartitionStatisticsSampleSize(session); + List partitionsSample = getPartitionsSample(partitions, sampleSize); + try { + Map statisticsSample = getPartitionsStatistics(session, table, partitionsSample, columns.keySet()); + validatePartitionStatistics(table, statisticsSample); + return getTableStatistics(columns, columnTypes, partitions, statisticsSample); + } + catch (TrinoException e) { + if (e.getErrorCode().equals(HIVE_CORRUPTED_COLUMN_STATISTICS.toErrorCode()) && isIgnoreCorruptedStatistics(session)) { + log.error(e); + return TableStatistics.empty(); + } + throw e; + } + } + + protected abstract Map getPartitionsStatistics(ConnectorSession session, SchemaTableName table, List hivePartitions, Set columns); + + private TableStatistics createZeroStatistics(Map columns, Map columnTypes) + { + TableStatistics.Builder result = TableStatistics.builder(); + result.setRowCount(Estimate.of(0)); + columns.forEach((columnName, columnHandle) -> { + Type columnType = columnTypes.get(columnName); + verifyNotNull(columnType, "columnType is missing for column: %s", columnName); + ColumnStatistics.Builder columnStatistics = ColumnStatistics.builder(); + columnStatistics.setNullsFraction(Estimate.of(0)); + columnStatistics.setDistinctValuesCount(Estimate.of(0)); + if (hasDataSize(columnType)) { + columnStatistics.setDataSize(Estimate.of(0)); + } + result.setColumnStatistics(columnHandle, columnStatistics.build()); + }); + return result.build(); + } + + @VisibleForTesting + static List getPartitionsSample(List partitions, int sampleSize) + { + checkArgument(sampleSize > 0, "sampleSize is expected to be greater than zero"); + + if (partitions.size() <= sampleSize) { + return partitions; + } + + List result = new ArrayList<>(); + + int samplesLeft = sampleSize; + + HivePartition min = partitions.get(0); + HivePartition max = partitions.get(0); + for (HivePartition partition : partitions) { + if (partition.getPartitionId().compareTo(min.getPartitionId()) < 0) { + min = partition; + } + else if (partition.getPartitionId().compareTo(max.getPartitionId()) > 0) { + max = partition; + } + } + + result.add(min); + samplesLeft--; + if (samplesLeft > 0) { + result.add(max); + samplesLeft--; + } + + if (samplesLeft > 0) { + HashFunction hashFunction = murmur3_128(); + Comparator> hashComparator = Comparator + ., Long>comparing(Map.Entry::getValue) + .thenComparing(entry -> 
entry.getKey().getPartitionId()); + partitions.stream() + .filter(partition -> !result.contains(partition)) + .map(partition -> immutableEntry(partition, hashFunction.hashUnencodedChars(partition.getPartitionId()).asLong())) + .sorted(hashComparator) + .limit(samplesLeft) + .forEachOrdered(entry -> result.add(entry.getKey())); + } + + return unmodifiableList(result); + } + + @VisibleForTesting + static void validatePartitionStatistics(SchemaTableName table, Map partitionStatistics) + { + partitionStatistics.forEach((partition, statistics) -> { + HiveBasicStatistics basicStatistics = statistics.getBasicStatistics(); + OptionalLong rowCount = basicStatistics.getRowCount(); + rowCount.ifPresent(count -> checkStatistics(count >= 0, table, partition, "rowCount must be greater than or equal to zero: %s", count)); + basicStatistics.getFileCount().ifPresent(count -> checkStatistics(count >= 0, table, partition, "fileCount must be greater than or equal to zero: %s", count)); + basicStatistics.getInMemoryDataSizeInBytes().ifPresent(size -> checkStatistics(size >= 0, table, partition, "inMemoryDataSizeInBytes must be greater than or equal to zero: %s", size)); + basicStatistics.getOnDiskDataSizeInBytes().ifPresent(size -> checkStatistics(size >= 0, table, partition, "onDiskDataSizeInBytes must be greater than or equal to zero: %s", size)); + statistics.getColumnStatistics().forEach((column, columnStatistics) -> validateColumnStatistics(table, partition, column, rowCount, columnStatistics)); + }); + } + + private static void validateColumnStatistics(SchemaTableName table, String partition, String column, OptionalLong rowCount, HiveColumnStatistics columnStatistics) + { + columnStatistics.getMaxValueSizeInBytes().ifPresent(maxValueSizeInBytes -> + checkStatistics(maxValueSizeInBytes >= 0, table, partition, column, "maxValueSizeInBytes must be greater than or equal to zero: %s", maxValueSizeInBytes)); + columnStatistics.getTotalSizeInBytes().ifPresent(totalSizeInBytes -> + checkStatistics(totalSizeInBytes >= 0, table, partition, column, "totalSizeInBytes must be greater than or equal to zero: %s", totalSizeInBytes)); + columnStatistics.getNullsCount().ifPresent(nullsCount -> { + checkStatistics(nullsCount >= 0, table, partition, column, "nullsCount must be greater than or equal to zero: %s", nullsCount); + if (rowCount.isPresent()) { + checkStatistics( + nullsCount <= rowCount.getAsLong(), + table, + partition, + column, + "nullsCount must be less than or equal to rowCount. nullsCount: %s. rowCount: %s.", + nullsCount, + rowCount.getAsLong()); + } + }); + columnStatistics.getDistinctValuesCount().ifPresent(distinctValuesCount -> { + checkStatistics(distinctValuesCount >= 0, table, partition, column, "distinctValuesCount must be greater than or equal to zero: %s", distinctValuesCount); + if (rowCount.isPresent()) { + checkStatistics( + distinctValuesCount <= rowCount.getAsLong(), + table, + partition, + column, + "distinctValuesCount must be less than or equal to rowCount. distinctValuesCount: %s. rowCount: %s.", + distinctValuesCount, + rowCount.getAsLong()); + } + if (rowCount.isPresent() && columnStatistics.getNullsCount().isPresent()) { + long nonNullsCount = rowCount.getAsLong() - columnStatistics.getNullsCount().getAsLong(); + checkStatistics( + distinctValuesCount <= nonNullsCount, + table, + partition, + column, + "distinctValuesCount must be less than or equal to nonNullsCount. distinctValuesCount: %s. 
nonNullsCount: %s.", + distinctValuesCount, + nonNullsCount); + } + }); + + columnStatistics.getIntegerStatistics().ifPresent(integerStatistics -> { + OptionalLong min = integerStatistics.getMin(); + OptionalLong max = integerStatistics.getMax(); + if (min.isPresent() && max.isPresent()) { + checkStatistics( + min.getAsLong() <= max.getAsLong(), + table, + partition, + column, + "integerStatistics.min must be less than or equal to integerStatistics.max. integerStatistics.min: %s. integerStatistics.max: %s.", + min.getAsLong(), + max.getAsLong()); + } + }); + columnStatistics.getDoubleStatistics().ifPresent(doubleStatistics -> { + OptionalDouble min = doubleStatistics.getMin(); + OptionalDouble max = doubleStatistics.getMax(); + if (min.isPresent() && max.isPresent() && !isNaN(min.getAsDouble()) && !isNaN(max.getAsDouble())) { + checkStatistics( + min.getAsDouble() <= max.getAsDouble(), + table, + partition, + column, + "doubleStatistics.min must be less than or equal to doubleStatistics.max. doubleStatistics.min: %s. doubleStatistics.max: %s.", + min.getAsDouble(), + max.getAsDouble()); + } + }); + columnStatistics.getDecimalStatistics().ifPresent(decimalStatistics -> { + Optional min = decimalStatistics.getMin(); + Optional max = decimalStatistics.getMax(); + if (min.isPresent() && max.isPresent()) { + checkStatistics( + min.get().compareTo(max.get()) <= 0, + table, + partition, + column, + "decimalStatistics.min must be less than or equal to decimalStatistics.max. decimalStatistics.min: %s. decimalStatistics.max: %s.", + min.get(), + max.get()); + } + }); + columnStatistics.getDateStatistics().ifPresent(dateStatistics -> { + Optional min = dateStatistics.getMin(); + Optional max = dateStatistics.getMax(); + if (min.isPresent() && max.isPresent()) { + checkStatistics( + min.get().compareTo(max.get()) <= 0, + table, + partition, + column, + "dateStatistics.min must be less than or equal to dateStatistics.max. dateStatistics.min: %s. dateStatistics.max: %s.", + min.get(), + max.get()); + } + }); + columnStatistics.getBooleanStatistics().ifPresent(booleanStatistics -> { + OptionalLong falseCount = booleanStatistics.getFalseCount(); + OptionalLong trueCount = booleanStatistics.getTrueCount(); + falseCount.ifPresent(count -> + checkStatistics(count >= 0, table, partition, column, "falseCount must be greater than or equal to zero: %s", count)); + trueCount.ifPresent(count -> + checkStatistics(count >= 0, table, partition, column, "trueCount must be greater than or equal to zero: %s", count)); + if (rowCount.isPresent() && falseCount.isPresent()) { + checkStatistics( + falseCount.getAsLong() <= rowCount.getAsLong(), + table, + partition, + column, + "booleanStatistics.falseCount must be less than or equal to rowCount. booleanStatistics.falseCount: %s. rowCount: %s.", + falseCount.getAsLong(), + rowCount.getAsLong()); + } + if (rowCount.isPresent() && trueCount.isPresent()) { + checkStatistics( + trueCount.getAsLong() <= rowCount.getAsLong(), + table, + partition, + column, + "booleanStatistics.trueCount must be less than or equal to rowCount. booleanStatistics.trueCount: %s. rowCount: %s.", + trueCount.getAsLong(), + rowCount.getAsLong()); + } + }); + } + + private static void checkStatistics(boolean expression, SchemaTableName table, String partition, String column, String message, Object... 
args) + { + if (!expression) { + throw new TrinoException( + HIVE_CORRUPTED_COLUMN_STATISTICS, + format("Corrupted partition statistics (Table: %s Partition: [%s] Column: %s): %s", table, partition, column, format(message, args))); + } + } + + private static void checkStatistics(boolean expression, SchemaTableName table, String partition, String message, Object... args) + { + if (!expression) { + throw new TrinoException( + HIVE_CORRUPTED_COLUMN_STATISTICS, + format("Corrupted partition statistics (Table: %s Partition: [%s]): %s", table, partition, format(message, args))); + } + } + + private static TableStatistics getTableStatistics( + Map columns, + Map columnTypes, + List partitions, + Map statistics) + { + if (statistics.isEmpty()) { + return createEmptyTableStatisticsWithPartitionColumnStatistics(columns, columnTypes, partitions); + } + + checkArgument(!partitions.isEmpty(), "partitions is empty"); + + Optional optionalRowCount = calculatePartitionsRowCount(statistics.values(), partitions.size()); + if (optionalRowCount.isEmpty()) { + return createEmptyTableStatisticsWithPartitionColumnStatistics(columns, columnTypes, partitions); + } + double rowCount = optionalRowCount.get().getRowCount(); + + TableStatistics.Builder result = TableStatistics.builder(); + result.setRowCount(Estimate.of(rowCount)); + for (Map.Entry column : columns.entrySet()) { + String columnName = column.getKey(); + HiveColumnHandle columnHandle = (HiveColumnHandle) column.getValue(); + Type columnType = columnTypes.get(columnName); + ColumnStatistics columnStatistics; + if (columnHandle.isPartitionKey()) { + double averageRowsPerPartition = optionalRowCount.get().getAverageRowsPerPartition(); + columnStatistics = createPartitionColumnStatistics(columnHandle, columnType, partitions, statistics, averageRowsPerPartition, rowCount); + } + else { + columnStatistics = createDataColumnStatistics(columnName, columnType, rowCount, statistics.values()); + } + result.setColumnStatistics(columnHandle, columnStatistics); + } + return result.build(); + } + + private static TableStatistics createEmptyTableStatisticsWithPartitionColumnStatistics( + Map columns, + Map columnTypes, + List partitions) + { + TableStatistics.Builder result = TableStatistics.builder(); + // Estimate stats for partitioned columns even when row count is unavailable. This will help us use + // ndv stats in rules like "ApplyPreferredTableWriterPartitioning". 
+ for (Map.Entry column : columns.entrySet()) { + HiveColumnHandle columnHandle = (HiveColumnHandle) column.getValue(); + if (columnHandle.isPartitionKey()) { + result.setColumnStatistics( + columnHandle, + createPartitionColumnStatisticsWithoutRowCount(columnHandle, columnTypes.get(column.getKey()), partitions)); + } + } + return result.build(); + } + + @VisibleForTesting + static Optional calculatePartitionsRowCount(Collection statistics, int queriedPartitionsCount) + { + long[] rowCounts = statistics.stream() + .map(PartitionStatistics::getBasicStatistics) + .map(HiveBasicStatistics::getRowCount) + .filter(OptionalLong::isPresent) + .mapToLong(OptionalLong::getAsLong) + .peek(count -> verify(count >= 0, "count must be greater than or equal to zero")) + .toArray(); + int sampleSize = statistics.size(); + // Sample contains all the queried partitions, estimate avg normally + if (rowCounts.length <= 2 || queriedPartitionsCount == sampleSize) { + OptionalDouble averageRowsPerPartitionOptional = Arrays.stream(rowCounts).average(); + if (averageRowsPerPartitionOptional.isEmpty()) { + return Optional.empty(); + } + double averageRowsPerPartition = averageRowsPerPartitionOptional.getAsDouble(); + return Optional.of(new PartitionsRowCount(averageRowsPerPartition, averageRowsPerPartition * queriedPartitionsCount)); + } + + // Some partitions (e.g. __HIVE_DEFAULT_PARTITION__) may be outliers in terms of row count. + // Excluding the min and max rowCount values from averageRowsPerPartition calculation helps to reduce the + // possibility of errors in the extrapolated rowCount due to a couple of outliers. + int minIndex = 0; + int maxIndex = 0; + long rowCountSum = rowCounts[0]; + for (int index = 1; index < rowCounts.length; index++) { + if (rowCounts[index] < rowCounts[minIndex]) { + minIndex = index; + } + else if (rowCounts[index] > rowCounts[maxIndex]) { + maxIndex = index; + } + rowCountSum += rowCounts[index]; + } + double averageWithoutOutliers = ((double) (rowCountSum - rowCounts[minIndex] - rowCounts[maxIndex])) / (rowCounts.length - 2); + double rowCount = (averageWithoutOutliers * (queriedPartitionsCount - 2)) + rowCounts[minIndex] + rowCounts[maxIndex]; + return Optional.of(new PartitionsRowCount(averageWithoutOutliers, rowCount)); + } + + @VisibleForTesting + static class PartitionsRowCount + { + private final double averageRowsPerPartition; + private final double rowCount; + + PartitionsRowCount(double averageRowsPerPartition, double rowCount) + { + verify(averageRowsPerPartition >= 0, "averageRowsPerPartition must be greater than or equal to zero"); + verify(rowCount >= 0, "rowCount must be greater than or equal to zero"); + this.averageRowsPerPartition = averageRowsPerPartition; + this.rowCount = rowCount; + } + + private double getAverageRowsPerPartition() + { + return averageRowsPerPartition; + } + + private double getRowCount() + { + return rowCount; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PartitionsRowCount that = (PartitionsRowCount) o; + return Double.compare(that.averageRowsPerPartition, averageRowsPerPartition) == 0 + && Double.compare(that.rowCount, rowCount) == 0; + } + + @Override + public int hashCode() + { + return Objects.hash(averageRowsPerPartition, rowCount); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("averageRowsPerPartition", averageRowsPerPartition) + .add("rowCount", rowCount) + .toString(); + } + 
} + + private static ColumnStatistics createPartitionColumnStatistics( + HiveColumnHandle column, + Type type, + List partitions, + Map statistics, + double averageRowsPerPartition, + double rowCount) + { + List nonEmptyPartitions = partitions.stream() + .filter(partition -> getPartitionRowCount(partition.getPartitionId(), statistics).orElse(averageRowsPerPartition) != 0) + .collect(toImmutableList()); + + return ColumnStatistics.builder() + .setDistinctValuesCount(Estimate.of(calculateDistinctPartitionKeys(column, nonEmptyPartitions))) + .setNullsFraction(Estimate.of(calculateNullsFractionForPartitioningKey(column, partitions, statistics, averageRowsPerPartition, rowCount))) + .setRange(calculateRangeForPartitioningKey(column, type, nonEmptyPartitions)) + .setDataSize(calculateDataSizeForPartitioningKey(column, type, partitions, statistics, averageRowsPerPartition)) + .build(); + } + + private static ColumnStatistics createPartitionColumnStatisticsWithoutRowCount(HiveColumnHandle column, Type type, List partitions) + { + if (partitions.isEmpty()) { + return ColumnStatistics.empty(); + } + + // Since we don't know the row count for each partition, we are taking an assumption here that all partitions + // are non-empty and contains exactly same amount of data. This will help us estimate ndv stats for partitioned + // columns which can be useful for certain optimizer rules. + double estimatedNullsCount = partitions.stream() + .filter(partition -> partition.getKeys().get(column).isNull()) + .count(); + + return ColumnStatistics.builder() + .setDistinctValuesCount(Estimate.of(calculateDistinctPartitionKeys(column, partitions))) + .setNullsFraction(Estimate.of(normalizeFraction(estimatedNullsCount / partitions.size()))) + .setRange(calculateRangeForPartitioningKey(column, type, partitions)) + .build(); + } + + @VisibleForTesting + static long calculateDistinctPartitionKeys( + HiveColumnHandle column, + List partitions) + { + return partitions.stream() + .map(partition -> partition.getKeys().get(column)) + .filter(value -> !value.isNull()) + .distinct() + .count(); + } + + @VisibleForTesting + static double calculateNullsFractionForPartitioningKey( + HiveColumnHandle column, + List partitions, + Map statistics, + double averageRowsPerPartition, + double rowCount) + { + if (rowCount == 0) { + return 0; + } + double estimatedNullsCount = partitions.stream() + .filter(partition -> partition.getKeys().get(column).isNull()) + .map(HivePartition::getPartitionId) + .mapToDouble(partitionName -> getPartitionRowCount(partitionName, statistics).orElse(averageRowsPerPartition)) + .sum(); + return normalizeFraction(estimatedNullsCount / rowCount); + } + + private static double normalizeFraction(double fraction) + { + checkArgument(!isNaN(fraction), "fraction is NaN"); + checkArgument(isFinite(fraction), "fraction must be finite"); + if (fraction < 0) { + return 0; + } + if (fraction > 1) { + return 1; + } + return fraction; + } + + @VisibleForTesting + static Estimate calculateDataSizeForPartitioningKey( + HiveColumnHandle column, + Type type, + List partitions, + Map statistics, + double averageRowsPerPartition) + { + if (!hasDataSize(type)) { + return Estimate.unknown(); + } + double dataSize = 0; + for (HivePartition partition : partitions) { + int length = getSize(partition.getKeys().get(column)); + double rowCount = getPartitionRowCount(partition.getPartitionId(), statistics).orElse(averageRowsPerPartition); + dataSize += length * rowCount; + } + return Estimate.of(dataSize); + } + + private static 
boolean hasDataSize(Type type) + { + return type instanceof VarcharType || type instanceof CharType; + } + + private static int getSize(NullableValue nullableValue) + { + if (nullableValue.isNull()) { + return 0; + } + Object value = nullableValue.getValue(); + checkArgument(value instanceof Slice, "value is expected to be of Slice type"); + return ((Slice) value).length(); + } + + private static OptionalDouble getPartitionRowCount(String partitionName, Map statistics) + { + PartitionStatistics partitionStatistics = statistics.get(partitionName); + if (partitionStatistics == null) { + return OptionalDouble.empty(); + } + OptionalLong rowCount = partitionStatistics.getBasicStatistics().getRowCount(); + if (rowCount.isPresent()) { + verify(rowCount.getAsLong() >= 0, "rowCount must be greater than or equal to zero"); + return OptionalDouble.of(rowCount.getAsLong()); + } + return OptionalDouble.empty(); + } + + @VisibleForTesting + static Optional calculateRangeForPartitioningKey(HiveColumnHandle column, Type type, List partitions) + { + List convertedValues = partitions.stream() + .map(HivePartition::getKeys) + .map(keys -> keys.get(column)) + .filter(value -> !value.isNull()) + .map(NullableValue::getValue) + .map(value -> convertPartitionValueToDouble(type, value)) + .collect(toImmutableList()); + + if (convertedValues.stream().noneMatch(OptionalDouble::isPresent)) { + return Optional.empty(); + } + double[] values = convertedValues.stream() + .peek(convertedValue -> checkState(convertedValue.isPresent(), "convertedValue is missing")) + .mapToDouble(OptionalDouble::getAsDouble) + .toArray(); + verify(values.length != 0, "No values"); + + if (DoubleStream.of(values).anyMatch(Double::isNaN)) { + return Optional.empty(); + } + + double min = DoubleStream.of(values).min().orElseThrow(); + double max = DoubleStream.of(values).max().orElseThrow(); + return Optional.of(new DoubleRange(min, max)); + } + + @VisibleForTesting + static OptionalDouble convertPartitionValueToDouble(Type type, Object value) + { + return toStatsRepresentation(type, value); + } + + @VisibleForTesting + static ColumnStatistics createDataColumnStatistics(String column, Type type, double rowsCount, Collection partitionStatistics) + { + List columnStatistics = partitionStatistics.stream() + .map(PartitionStatistics::getColumnStatistics) + .map(statistics -> statistics.get(column)) + .filter(Objects::nonNull) + .collect(toImmutableList()); + + if (columnStatistics.isEmpty()) { + return ColumnStatistics.empty(); + } + + return ColumnStatistics.builder() + .setDistinctValuesCount(calculateDistinctValuesCount(columnStatistics)) + .setNullsFraction(calculateNullsFraction(column, partitionStatistics)) + .setDataSize(calculateDataSize(column, partitionStatistics, rowsCount)) + .setRange(calculateRange(type, columnStatistics)) + .build(); + } + + @VisibleForTesting + static Estimate calculateDistinctValuesCount(List columnStatistics) + { + return columnStatistics.stream() + .map(AbstractHiveStatisticsProvider::getDistinctValuesCount) + .filter(OptionalLong::isPresent) + .map(OptionalLong::getAsLong) + .peek(distinctValuesCount -> verify(distinctValuesCount >= 0, "distinctValuesCount must be greater than or equal to zero")) + .max(Long::compare) + .map(Estimate::of) + .orElse(Estimate.unknown()); + } + + private static OptionalLong getDistinctValuesCount(HiveColumnStatistics statistics) + { + if (statistics.getBooleanStatistics().isPresent() && + statistics.getBooleanStatistics().get().getFalseCount().isPresent() && + 
statistics.getBooleanStatistics().get().getTrueCount().isPresent()) { + long falseCount = statistics.getBooleanStatistics().get().getFalseCount().getAsLong(); + long trueCount = statistics.getBooleanStatistics().get().getTrueCount().getAsLong(); + return OptionalLong.of((falseCount > 0 ? 1 : 0) + (trueCount > 0 ? 1 : 0)); + } + if (statistics.getDistinctValuesCount().isPresent()) { + return statistics.getDistinctValuesCount(); + } + return OptionalLong.empty(); + } + + @VisibleForTesting + static Estimate calculateNullsFraction(String column, Collection partitionStatistics) + { + List statisticsWithKnownRowCountAndNullsCount = partitionStatistics.stream() + .filter(statistics -> { + if (statistics.getBasicStatistics().getRowCount().isEmpty()) { + return false; + } + HiveColumnStatistics columnStatistics = statistics.getColumnStatistics().get(column); + if (columnStatistics == null) { + return false; + } + return columnStatistics.getNullsCount().isPresent(); + }) + .collect(toImmutableList()); + + if (statisticsWithKnownRowCountAndNullsCount.isEmpty()) { + return Estimate.unknown(); + } + + long totalNullsCount = 0; + long totalRowCount = 0; + for (PartitionStatistics statistics : statisticsWithKnownRowCountAndNullsCount) { + long rowCount = statistics.getBasicStatistics().getRowCount().orElseThrow(() -> new VerifyException("rowCount is not present")); + verify(rowCount >= 0, "rowCount must be greater than or equal to zero"); + HiveColumnStatistics columnStatistics = statistics.getColumnStatistics().get(column); + verifyNotNull(columnStatistics, "columnStatistics is null"); + long nullsCount = columnStatistics.getNullsCount().orElseThrow(() -> new VerifyException("nullsCount is not present")); + verify(nullsCount >= 0, "nullsCount must be greater than or equal to zero"); + verify(nullsCount <= rowCount, "nullsCount must be less than or equal to rowCount. nullsCount: %s. rowCount: %s.", nullsCount, rowCount); + totalNullsCount += nullsCount; + totalRowCount += rowCount; + } + + if (totalRowCount == 0) { + return Estimate.zero(); + } + + verify( + totalNullsCount <= totalRowCount, + "totalNullsCount must be less than or equal to totalRowCount. totalNullsCount: %s. 
totalRowCount: %s.", + totalNullsCount, + totalRowCount); + return Estimate.of(((double) totalNullsCount) / totalRowCount); + } + + @VisibleForTesting + static Estimate calculateDataSize(String column, Collection partitionStatistics, double totalRowCount) + { + List statisticsWithKnownRowCountAndDataSize = partitionStatistics.stream() + .filter(statistics -> { + if (statistics.getBasicStatistics().getRowCount().isEmpty()) { + return false; + } + HiveColumnStatistics columnStatistics = statistics.getColumnStatistics().get(column); + if (columnStatistics == null) { + return false; + } + return columnStatistics.getTotalSizeInBytes().isPresent(); + }) + .collect(toImmutableList()); + + if (statisticsWithKnownRowCountAndDataSize.isEmpty()) { + return Estimate.unknown(); + } + + long knownRowCount = 0; + long knownDataSize = 0; + for (PartitionStatistics statistics : statisticsWithKnownRowCountAndDataSize) { + long rowCount = statistics.getBasicStatistics().getRowCount().orElseThrow(() -> new VerifyException("rowCount is not present")); + verify(rowCount >= 0, "rowCount must be greater than or equal to zero"); + HiveColumnStatistics columnStatistics = statistics.getColumnStatistics().get(column); + verifyNotNull(columnStatistics, "columnStatistics is null"); + long dataSize = columnStatistics.getTotalSizeInBytes().orElseThrow(() -> new VerifyException("totalSizeInBytes is not present")); + verify(dataSize >= 0, "dataSize must be greater than or equal to zero"); + knownRowCount += rowCount; + knownDataSize += dataSize; + } + + if (totalRowCount == 0) { + return Estimate.zero(); + } + + if (knownRowCount == 0) { + return Estimate.unknown(); + } + + double averageValueDataSizeInBytes = ((double) knownDataSize) / knownRowCount; + return Estimate.of(averageValueDataSizeInBytes * totalRowCount); + } + + @VisibleForTesting + static Optional calculateRange(Type type, List columnStatistics) + { + return columnStatistics.stream() + .map(statistics -> createRange(type, statistics)) + .filter(Optional::isPresent) + .map(Optional::get) + .reduce(DoubleRange::union); + } + + private static Optional createRange(Type type, HiveColumnStatistics statistics) + { + if (type.equals(BIGINT) || type.equals(INTEGER) || type.equals(SMALLINT) || type.equals(TINYINT)) { + return statistics.getIntegerStatistics().flatMap(integerStatistics -> createIntegerRange(type, integerStatistics)); + } + if (type.equals(DOUBLE) || type.equals(REAL)) { + return statistics.getDoubleStatistics().flatMap(AbstractHiveStatisticsProvider::createDoubleRange); + } + if (type.equals(DATE)) { + return statistics.getDateStatistics().flatMap(AbstractHiveStatisticsProvider::createDateRange); + } + if (type instanceof DecimalType) { + return statistics.getDecimalStatistics().flatMap(AbstractHiveStatisticsProvider::createDecimalRange); + } + return Optional.empty(); + } + + private static Optional createIntegerRange(Type type, IntegerStatistics statistics) + { + if (statistics.getMin().isPresent() && statistics.getMax().isPresent()) { + return Optional.of(createIntegerRange(type, statistics.getMin().getAsLong(), statistics.getMax().getAsLong())); + } + return Optional.empty(); + } + + private static DoubleRange createIntegerRange(Type type, long min, long max) + { + return new DoubleRange(normalizeIntegerValue(type, min), normalizeIntegerValue(type, max)); + } + + private static long normalizeIntegerValue(Type type, long value) + { + if (type.equals(BIGINT)) { + return value; + } + if (type.equals(INTEGER)) { + return Ints.saturatedCast(value); + } + 
if (type.equals(SMALLINT)) { + return Shorts.saturatedCast(value); + } + if (type.equals(TINYINT)) { + return SignedBytes.saturatedCast(value); + } + throw new IllegalArgumentException("Unexpected type: " + type); + } + + private static Optional createDoubleRange(DoubleStatistics statistics) + { + if (statistics.getMin().isPresent() && statistics.getMax().isPresent() && !isNaN(statistics.getMin().getAsDouble()) && !isNaN(statistics.getMax().getAsDouble())) { + return Optional.of(new DoubleRange(statistics.getMin().getAsDouble(), statistics.getMax().getAsDouble())); + } + return Optional.empty(); + } + + private static Optional createDateRange(DateStatistics statistics) + { + if (statistics.getMin().isPresent() && statistics.getMax().isPresent()) { + return Optional.of(new DoubleRange(statistics.getMin().get().toEpochDay(), statistics.getMax().get().toEpochDay())); + } + return Optional.empty(); + } + + private static Optional createDecimalRange(DecimalStatistics statistics) + { + if (statistics.getMin().isPresent() && statistics.getMax().isPresent()) { + return Optional.of(new DoubleRange(statistics.getMin().get().doubleValue(), statistics.getMax().get().doubleValue())); + } + return Optional.empty(); + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/statistics/MetastoreHiveStatisticsProvider.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/statistics/MetastoreHiveStatisticsProvider.java index ded2ad1a766c..3d79eed2a4d8 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/statistics/MetastoreHiveStatisticsProvider.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/statistics/MetastoreHiveStatisticsProvider.java @@ -14,103 +14,40 @@ package io.trino.plugin.hive.statistics; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableMap; -import com.google.common.hash.HashFunction; -import com.google.common.primitives.Ints; -import com.google.common.primitives.Shorts; -import com.google.common.primitives.SignedBytes; -import io.airlift.log.Logger; -import io.airlift.slice.Slice; -import io.trino.plugin.hive.HiveBasicStatistics; -import io.trino.plugin.hive.HiveColumnHandle; import io.trino.plugin.hive.HivePartition; import io.trino.plugin.hive.PartitionStatistics; -import io.trino.plugin.hive.metastore.DateStatistics; -import io.trino.plugin.hive.metastore.DecimalStatistics; -import io.trino.plugin.hive.metastore.DoubleStatistics; -import io.trino.plugin.hive.metastore.HiveColumnStatistics; -import io.trino.plugin.hive.metastore.IntegerStatistics; import io.trino.plugin.hive.metastore.SemiTransactionalHiveMetastore; -import io.trino.spi.TrinoException; -import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.SchemaTableName; -import io.trino.spi.predicate.NullableValue; -import io.trino.spi.statistics.ColumnStatistics; -import io.trino.spi.statistics.DoubleRange; -import io.trino.spi.statistics.Estimate; -import io.trino.spi.statistics.TableStatistics; -import io.trino.spi.type.CharType; -import io.trino.spi.type.DecimalType; -import io.trino.spi.type.Type; -import io.trino.spi.type.VarcharType; -import java.math.BigDecimal; -import java.time.LocalDate; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Comparator; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Optional; -import 
java.util.OptionalDouble; -import java.util.OptionalLong; import java.util.Set; -import java.util.stream.DoubleStream; -import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Verify.verify; -import static com.google.common.base.Verify.verifyNotNull; -import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static com.google.common.collect.Maps.immutableEntry; -import static com.google.common.hash.Hashing.murmur3_128; -import static io.trino.plugin.hive.HiveErrorCode.HIVE_CORRUPTED_COLUMN_STATISTICS; import static io.trino.plugin.hive.HivePartition.UNPARTITIONED_ID; -import static io.trino.plugin.hive.HiveSessionProperties.getPartitionStatisticsSampleSize; -import static io.trino.plugin.hive.HiveSessionProperties.isIgnoreCorruptedStatistics; import static io.trino.plugin.hive.HiveSessionProperties.isStatisticsEnabled; -import static io.trino.spi.statistics.StatsUtil.toStatsRepresentation; -import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.DateType.DATE; -import static io.trino.spi.type.DoubleType.DOUBLE; -import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.spi.type.RealType.REAL; -import static io.trino.spi.type.SmallintType.SMALLINT; -import static io.trino.spi.type.TinyintType.TINYINT; -import static java.lang.Double.isFinite; -import static java.lang.Double.isNaN; -import static java.lang.String.format; -import static java.util.Collections.unmodifiableList; import static java.util.Objects.requireNonNull; public class MetastoreHiveStatisticsProvider - implements HiveStatisticsProvider + extends AbstractHiveStatisticsProvider { - private static final Logger log = Logger.get(MetastoreHiveStatisticsProvider.class); - - private final PartitionsStatisticsProvider statisticsProvider; + private final SemiTransactionalHiveMetastore metastore; public MetastoreHiveStatisticsProvider(SemiTransactionalHiveMetastore metastore) { - requireNonNull(metastore, "metastore is null"); - this.statisticsProvider = (session, table, hivePartitions, columns) -> getPartitionsStatistics(metastore, table, columns, hivePartitions); - } - - @VisibleForTesting - MetastoreHiveStatisticsProvider(PartitionsStatisticsProvider statisticsProvider) - { - this.statisticsProvider = requireNonNull(statisticsProvider, "statisticsProvider is null"); + this.metastore = requireNonNull(metastore, "metastore is null"); } - private static Map getPartitionsStatistics(SemiTransactionalHiveMetastore metastore, SchemaTableName table, Set columns, List hivePartitions) + @Override + protected Map getPartitionsStatistics(ConnectorSession session, SchemaTableName table, List hivePartitions, Set columns) { + if (!isStatisticsEnabled(session)) { + return ImmutableMap.of(); + } if (hivePartitions.isEmpty()) { return ImmutableMap.of(); } @@ -124,807 +61,4 @@ private static Map getPartitionsStatistics(SemiTran .collect(toImmutableSet()); return metastore.getPartitionStatistics(table.getSchemaName(), table.getTableName(), columns, partitionNames); } - - @Override - public TableStatistics getTableStatistics( - ConnectorSession session, - SchemaTableName table, - Map columns, - Map columnTypes, - List partitions) - { - if (!isStatisticsEnabled(session)) { - return TableStatistics.empty(); - } - if (partitions.isEmpty()) { - 
return createZeroStatistics(columns, columnTypes); - } - int sampleSize = getPartitionStatisticsSampleSize(session); - List partitionsSample = getPartitionsSample(partitions, sampleSize); - try { - Map statisticsSample = statisticsProvider.getPartitionsStatistics(session, table, partitionsSample, columns.keySet()); - validatePartitionStatistics(table, statisticsSample); - return getTableStatistics(columns, columnTypes, partitions, statisticsSample); - } - catch (TrinoException e) { - if (e.getErrorCode().equals(HIVE_CORRUPTED_COLUMN_STATISTICS.toErrorCode()) && isIgnoreCorruptedStatistics(session)) { - log.error(e); - return TableStatistics.empty(); - } - throw e; - } - } - - private TableStatistics createZeroStatistics(Map columns, Map columnTypes) - { - TableStatistics.Builder result = TableStatistics.builder(); - result.setRowCount(Estimate.of(0)); - columns.forEach((columnName, columnHandle) -> { - Type columnType = columnTypes.get(columnName); - verifyNotNull(columnType, "columnType is missing for column: %s", columnName); - ColumnStatistics.Builder columnStatistics = ColumnStatistics.builder(); - columnStatistics.setNullsFraction(Estimate.of(0)); - columnStatistics.setDistinctValuesCount(Estimate.of(0)); - if (hasDataSize(columnType)) { - columnStatistics.setDataSize(Estimate.of(0)); - } - result.setColumnStatistics(columnHandle, columnStatistics.build()); - }); - return result.build(); - } - - @VisibleForTesting - static List getPartitionsSample(List partitions, int sampleSize) - { - checkArgument(sampleSize > 0, "sampleSize is expected to be greater than zero"); - - if (partitions.size() <= sampleSize) { - return partitions; - } - - List result = new ArrayList<>(); - - int samplesLeft = sampleSize; - - HivePartition min = partitions.get(0); - HivePartition max = partitions.get(0); - for (HivePartition partition : partitions) { - if (partition.getPartitionId().compareTo(min.getPartitionId()) < 0) { - min = partition; - } - else if (partition.getPartitionId().compareTo(max.getPartitionId()) > 0) { - max = partition; - } - } - - result.add(min); - samplesLeft--; - if (samplesLeft > 0) { - result.add(max); - samplesLeft--; - } - - if (samplesLeft > 0) { - HashFunction hashFunction = murmur3_128(); - Comparator> hashComparator = Comparator - ., Long>comparing(Map.Entry::getValue) - .thenComparing(entry -> entry.getKey().getPartitionId()); - partitions.stream() - .filter(partition -> !result.contains(partition)) - .map(partition -> immutableEntry(partition, hashFunction.hashUnencodedChars(partition.getPartitionId()).asLong())) - .sorted(hashComparator) - .limit(samplesLeft) - .forEachOrdered(entry -> result.add(entry.getKey())); - } - - return unmodifiableList(result); - } - - @VisibleForTesting - static void validatePartitionStatistics(SchemaTableName table, Map partitionStatistics) - { - partitionStatistics.forEach((partition, statistics) -> { - HiveBasicStatistics basicStatistics = statistics.getBasicStatistics(); - OptionalLong rowCount = basicStatistics.getRowCount(); - rowCount.ifPresent(count -> checkStatistics(count >= 0, table, partition, "rowCount must be greater than or equal to zero: %s", count)); - basicStatistics.getFileCount().ifPresent(count -> checkStatistics(count >= 0, table, partition, "fileCount must be greater than or equal to zero: %s", count)); - basicStatistics.getInMemoryDataSizeInBytes().ifPresent(size -> checkStatistics(size >= 0, table, partition, "inMemoryDataSizeInBytes must be greater than or equal to zero: %s", size)); - 
basicStatistics.getOnDiskDataSizeInBytes().ifPresent(size -> checkStatistics(size >= 0, table, partition, "onDiskDataSizeInBytes must be greater than or equal to zero: %s", size)); - statistics.getColumnStatistics().forEach((column, columnStatistics) -> validateColumnStatistics(table, partition, column, rowCount, columnStatistics)); - }); - } - - private static void validateColumnStatistics(SchemaTableName table, String partition, String column, OptionalLong rowCount, HiveColumnStatistics columnStatistics) - { - columnStatistics.getMaxValueSizeInBytes().ifPresent(maxValueSizeInBytes -> - checkStatistics(maxValueSizeInBytes >= 0, table, partition, column, "maxValueSizeInBytes must be greater than or equal to zero: %s", maxValueSizeInBytes)); - columnStatistics.getTotalSizeInBytes().ifPresent(totalSizeInBytes -> - checkStatistics(totalSizeInBytes >= 0, table, partition, column, "totalSizeInBytes must be greater than or equal to zero: %s", totalSizeInBytes)); - columnStatistics.getNullsCount().ifPresent(nullsCount -> { - checkStatistics(nullsCount >= 0, table, partition, column, "nullsCount must be greater than or equal to zero: %s", nullsCount); - if (rowCount.isPresent()) { - checkStatistics( - nullsCount <= rowCount.getAsLong(), - table, - partition, - column, - "nullsCount must be less than or equal to rowCount. nullsCount: %s. rowCount: %s.", - nullsCount, - rowCount.getAsLong()); - } - }); - columnStatistics.getDistinctValuesCount().ifPresent(distinctValuesCount -> { - checkStatistics(distinctValuesCount >= 0, table, partition, column, "distinctValuesCount must be greater than or equal to zero: %s", distinctValuesCount); - if (rowCount.isPresent()) { - checkStatistics( - distinctValuesCount <= rowCount.getAsLong(), - table, - partition, - column, - "distinctValuesCount must be less than or equal to rowCount. distinctValuesCount: %s. rowCount: %s.", - distinctValuesCount, - rowCount.getAsLong()); - } - if (rowCount.isPresent() && columnStatistics.getNullsCount().isPresent()) { - long nonNullsCount = rowCount.getAsLong() - columnStatistics.getNullsCount().getAsLong(); - checkStatistics( - distinctValuesCount <= nonNullsCount, - table, - partition, - column, - "distinctValuesCount must be less than or equal to nonNullsCount. distinctValuesCount: %s. nonNullsCount: %s.", - distinctValuesCount, - nonNullsCount); - } - }); - - columnStatistics.getIntegerStatistics().ifPresent(integerStatistics -> { - OptionalLong min = integerStatistics.getMin(); - OptionalLong max = integerStatistics.getMax(); - if (min.isPresent() && max.isPresent()) { - checkStatistics( - min.getAsLong() <= max.getAsLong(), - table, - partition, - column, - "integerStatistics.min must be less than or equal to integerStatistics.max. integerStatistics.min: %s. integerStatistics.max: %s.", - min.getAsLong(), - max.getAsLong()); - } - }); - columnStatistics.getDoubleStatistics().ifPresent(doubleStatistics -> { - OptionalDouble min = doubleStatistics.getMin(); - OptionalDouble max = doubleStatistics.getMax(); - if (min.isPresent() && max.isPresent() && !isNaN(min.getAsDouble()) && !isNaN(max.getAsDouble())) { - checkStatistics( - min.getAsDouble() <= max.getAsDouble(), - table, - partition, - column, - "doubleStatistics.min must be less than or equal to doubleStatistics.max. doubleStatistics.min: %s. 
doubleStatistics.max: %s.", - min.getAsDouble(), - max.getAsDouble()); - } - }); - columnStatistics.getDecimalStatistics().ifPresent(decimalStatistics -> { - Optional min = decimalStatistics.getMin(); - Optional max = decimalStatistics.getMax(); - if (min.isPresent() && max.isPresent()) { - checkStatistics( - min.get().compareTo(max.get()) <= 0, - table, - partition, - column, - "decimalStatistics.min must be less than or equal to decimalStatistics.max. decimalStatistics.min: %s. decimalStatistics.max: %s.", - min.get(), - max.get()); - } - }); - columnStatistics.getDateStatistics().ifPresent(dateStatistics -> { - Optional min = dateStatistics.getMin(); - Optional max = dateStatistics.getMax(); - if (min.isPresent() && max.isPresent()) { - checkStatistics( - min.get().compareTo(max.get()) <= 0, - table, - partition, - column, - "dateStatistics.min must be less than or equal to dateStatistics.max. dateStatistics.min: %s. dateStatistics.max: %s.", - min.get(), - max.get()); - } - }); - columnStatistics.getBooleanStatistics().ifPresent(booleanStatistics -> { - OptionalLong falseCount = booleanStatistics.getFalseCount(); - OptionalLong trueCount = booleanStatistics.getTrueCount(); - falseCount.ifPresent(count -> - checkStatistics(count >= 0, table, partition, column, "falseCount must be greater than or equal to zero: %s", count)); - trueCount.ifPresent(count -> - checkStatistics(count >= 0, table, partition, column, "trueCount must be greater than or equal to zero: %s", count)); - if (rowCount.isPresent() && falseCount.isPresent()) { - checkStatistics( - falseCount.getAsLong() <= rowCount.getAsLong(), - table, - partition, - column, - "booleanStatistics.falseCount must be less than or equal to rowCount. booleanStatistics.falseCount: %s. rowCount: %s.", - falseCount.getAsLong(), - rowCount.getAsLong()); - } - if (rowCount.isPresent() && trueCount.isPresent()) { - checkStatistics( - trueCount.getAsLong() <= rowCount.getAsLong(), - table, - partition, - column, - "booleanStatistics.trueCount must be less than or equal to rowCount. booleanStatistics.trueCount: %s. rowCount: %s.", - trueCount.getAsLong(), - rowCount.getAsLong()); - } - }); - } - - private static void checkStatistics(boolean expression, SchemaTableName table, String partition, String column, String message, Object... args) - { - if (!expression) { - throw new TrinoException( - HIVE_CORRUPTED_COLUMN_STATISTICS, - format("Corrupted partition statistics (Table: %s Partition: [%s] Column: %s): %s", table, partition, column, format(message, args))); - } - } - - private static void checkStatistics(boolean expression, SchemaTableName table, String partition, String message, Object... 
 - - private static void checkStatistics(boolean expression, SchemaTableName table, String partition, String column, String message, Object... args) - { - if (!expression) { - throw new TrinoException( - HIVE_CORRUPTED_COLUMN_STATISTICS, - format("Corrupted partition statistics (Table: %s Partition: [%s] Column: %s): %s", table, partition, column, format(message, args))); - } - } - - private static void checkStatistics(boolean expression, SchemaTableName table, String partition, String message, Object... args) - { - if (!expression) { - throw new TrinoException( - HIVE_CORRUPTED_COLUMN_STATISTICS, - format("Corrupted partition statistics (Table: %s Partition: [%s]): %s", table, partition, format(message, args))); - } - } - - private static TableStatistics getTableStatistics( - Map columns, - Map columnTypes, - List partitions, - Map statistics) - { - if (statistics.isEmpty()) { - return createEmptyTableStatisticsWithPartitionColumnStatistics(columns, columnTypes, partitions); - } - - checkArgument(!partitions.isEmpty(), "partitions is empty"); - - Optional optionalRowCount = calculatePartitionsRowCount(statistics.values(), partitions.size()); - if (optionalRowCount.isEmpty()) { - return createEmptyTableStatisticsWithPartitionColumnStatistics(columns, columnTypes, partitions); - } - double rowCount = optionalRowCount.get().getRowCount(); - - TableStatistics.Builder result = TableStatistics.builder(); - result.setRowCount(Estimate.of(rowCount)); - for (Map.Entry column : columns.entrySet()) { - String columnName = column.getKey(); - HiveColumnHandle columnHandle = (HiveColumnHandle) column.getValue(); - Type columnType = columnTypes.get(columnName); - ColumnStatistics columnStatistics; - if (columnHandle.isPartitionKey()) { - double averageRowsPerPartition = optionalRowCount.get().getAverageRowsPerPartition(); - columnStatistics = createPartitionColumnStatistics(columnHandle, columnType, partitions, statistics, averageRowsPerPartition, rowCount); - } - else { - columnStatistics = createDataColumnStatistics(columnName, columnType, rowCount, statistics.values()); - } - result.setColumnStatistics(columnHandle, columnStatistics); - } - return result.build(); - } - - private static TableStatistics createEmptyTableStatisticsWithPartitionColumnStatistics( - Map columns, - Map columnTypes, - List partitions) - { - TableStatistics.Builder result = TableStatistics.builder(); - // Estimate stats for partitioned columns even when row count is unavailable. This will help us use - // ndv stats in rules like "ApplyPreferredTableWriterPartitioning". - for (Map.Entry column : columns.entrySet()) { - HiveColumnHandle columnHandle = (HiveColumnHandle) column.getValue(); - if (columnHandle.isPartitionKey()) { - result.setColumnStatistics( - columnHandle, - createPartitionColumnStatisticsWithoutRowCount(columnHandle, columnTypes.get(column.getKey()), partitions)); - } - } - return result.build(); - } - - @VisibleForTesting - static Optional calculatePartitionsRowCount(Collection statistics, int queriedPartitionsCount) - { - long[] rowCounts = statistics.stream() - .map(PartitionStatistics::getBasicStatistics) - .map(HiveBasicStatistics::getRowCount) - .filter(OptionalLong::isPresent) - .mapToLong(OptionalLong::getAsLong) - .peek(count -> verify(count >= 0, "count must be greater than or equal to zero")) - .toArray(); - int sampleSize = statistics.size(); - // Sample contains all the queried partitions, estimate avg normally - if (rowCounts.length <= 2 || queriedPartitionsCount == sampleSize) { - OptionalDouble averageRowsPerPartitionOptional = Arrays.stream(rowCounts).average(); - if (averageRowsPerPartitionOptional.isEmpty()) { - return Optional.empty(); - } - double averageRowsPerPartition = averageRowsPerPartitionOptional.getAsDouble(); - return Optional.of(new PartitionsRowCount(averageRowsPerPartition, averageRowsPerPartition * queriedPartitionsCount)); - }
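// A worked sketch of the branch above: when the sample has at most two known row
// counts, or covers every queried partition, the estimate is the plain average
// scaled by the number of queried partitions (numbers are illustrative):
import java.util.Arrays;

class PlainAverageSketch
{
    public static void main(String[] args)
    {
        long[] rowCounts = {900, 1_100};  // only two known row counts in the sample
        int queriedPartitionsCount = 50;  // partitions the query actually reads
        double averageRowsPerPartition = Arrays.stream(rowCounts).average().orElseThrow();  // 1_000.0
        System.out.println(averageRowsPerPartition * queriedPartitionsCount);               // 50_000.0
    }
}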
 - - // Some partitions (e.g. __HIVE_DEFAULT_PARTITION__) may be outliers in terms of row count. - // Excluding the min and max rowCount values from averageRowsPerPartition calculation helps to reduce the - // possibility of errors in the extrapolated rowCount due to a couple of outliers. - int minIndex = 0; - int maxIndex = 0; - long rowCountSum = rowCounts[0]; - for (int index = 1; index < rowCounts.length; index++) { - if (rowCounts[index] < rowCounts[minIndex]) { - minIndex = index; - } - else if (rowCounts[index] > rowCounts[maxIndex]) { - maxIndex = index; - } - rowCountSum += rowCounts[index]; - } - double averageWithoutOutliers = ((double) (rowCountSum - rowCounts[minIndex] - rowCounts[maxIndex])) / (rowCounts.length - 2); - double rowCount = (averageWithoutOutliers * (queriedPartitionsCount - 2)) + rowCounts[minIndex] + rowCounts[maxIndex]; - return Optional.of(new PartitionsRowCount(averageWithoutOutliers, rowCount)); - }
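// The outlier-trimmed branch above, worked through on concrete (illustrative)
// numbers. A nearly empty partition such as __HIVE_DEFAULT_PARTITION__ would
// otherwise drag the extrapolated total far off:
import java.util.Arrays;

class TrimmedAverageSketch
{
    public static void main(String[] args)
    {
        long[] rowCounts = {5, 980, 1_000, 1_020};  // 4 sampled of 100 queried partitions
        int queriedPartitionsCount = 100;
        long sum = Arrays.stream(rowCounts).sum();                 // 3_005
        long min = Arrays.stream(rowCounts).min().orElseThrow();   // 5
        long max = Arrays.stream(rowCounts).max().orElseThrow();   // 1_020
        // average the middle of the sample: (3_005 - 5 - 1_020) / 2 = 990
        double averageWithoutOutliers = (double) (sum - min - max) / (rowCounts.length - 2);
        // extrapolate the 98 "typical" partitions, then add the two extremes back
        double rowCount = averageWithoutOutliers * (queriedPartitionsCount - 2) + min + max;
        System.out.println(rowCount);  // 98_045.0
    }
}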
 - - @VisibleForTesting - static class PartitionsRowCount - { - private final double averageRowsPerPartition; - private final double rowCount; - - PartitionsRowCount(double averageRowsPerPartition, double rowCount) - { - verify(averageRowsPerPartition >= 0, "averageRowsPerPartition must be greater than or equal to zero"); - verify(rowCount >= 0, "rowCount must be greater than or equal to zero"); - this.averageRowsPerPartition = averageRowsPerPartition; - this.rowCount = rowCount; - } - - private double getAverageRowsPerPartition() - { - return averageRowsPerPartition; - } - - private double getRowCount() - { - return rowCount; - } - - @Override - public boolean equals(Object o) - { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - PartitionsRowCount that = (PartitionsRowCount) o; - return Double.compare(that.averageRowsPerPartition, averageRowsPerPartition) == 0 - && Double.compare(that.rowCount, rowCount) == 0; - } - - @Override - public int hashCode() - { - return Objects.hash(averageRowsPerPartition, rowCount); - } - - @Override - public String toString() - { - return toStringHelper(this) - .add("averageRowsPerPartition", averageRowsPerPartition) - .add("rowCount", rowCount) - .toString(); - } - } - - private static ColumnStatistics createPartitionColumnStatistics( - HiveColumnHandle column, - Type type, - List partitions, - Map statistics, - double averageRowsPerPartition, - double rowCount) - { - List nonEmptyPartitions = partitions.stream() - .filter(partition -> getPartitionRowCount(partition.getPartitionId(), statistics).orElse(averageRowsPerPartition) != 0) - .collect(toImmutableList()); - - return ColumnStatistics.builder() - .setDistinctValuesCount(Estimate.of(calculateDistinctPartitionKeys(column, nonEmptyPartitions))) - .setNullsFraction(Estimate.of(calculateNullsFractionForPartitioningKey(column, partitions, statistics, averageRowsPerPartition, rowCount))) - .setRange(calculateRangeForPartitioningKey(column, type, nonEmptyPartitions)) - .setDataSize(calculateDataSizeForPartitioningKey(column, type, partitions, statistics, averageRowsPerPartition)) - .build(); - } - - private static ColumnStatistics createPartitionColumnStatisticsWithoutRowCount(HiveColumnHandle column, Type type, List partitions) - { - if (partitions.isEmpty()) { - return ColumnStatistics.empty(); - } - - // Since we don't know the row count for each partition, we assume here that all partitions - // are non-empty and contain exactly the same amount of data. This will help us estimate ndv stats for partitioned - // columns, which can be useful for certain optimizer rules. - double estimatedNullsCount = partitions.stream() - .filter(partition -> partition.getKeys().get(column).isNull()) - .count(); - - return ColumnStatistics.builder() - .setDistinctValuesCount(Estimate.of(calculateDistinctPartitionKeys(column, partitions))) - .setNullsFraction(Estimate.of(normalizeFraction(estimatedNullsCount / partitions.size()))) - .setRange(calculateRangeForPartitioningKey(column, type, partitions)) - .build(); - }
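// Partition-key statistics are computed from the partition key values themselves
// rather than from metastore column statistics. A condensed sketch of the
// no-row-count variant above, with String keys (null standing for a null
// partition key); the clamp mirrors normalizeFraction below:
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

class PartitionKeyStatsSketch
{
    public static void main(String[] args)
    {
        List<String> keys = Arrays.asList("us", "eu", "eu", null);  // one key value per partition
        long distinctValues = keys.stream().filter(Objects::nonNull).distinct().count();  // 2
        double nullsFraction = clamp((double) keys.stream().filter(Objects::isNull).count() / keys.size());
        System.out.println(distinctValues + " " + nullsFraction);  // 2 0.25
    }

    // estimates can drift slightly outside [0, 1]; clamp instead of failing
    static double clamp(double fraction)
    {
        return Math.max(0, Math.min(1, fraction));
    }
}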
 - - @VisibleForTesting - static long calculateDistinctPartitionKeys( - HiveColumnHandle column, - List partitions) - { - return partitions.stream() - .map(partition -> partition.getKeys().get(column)) - .filter(value -> !value.isNull()) - .distinct() - .count(); - } - - @VisibleForTesting - static double calculateNullsFractionForPartitioningKey( - HiveColumnHandle column, - List partitions, - Map statistics, - double averageRowsPerPartition, - double rowCount) - { - if (rowCount == 0) { - return 0; - } - double estimatedNullsCount = partitions.stream() - .filter(partition -> partition.getKeys().get(column).isNull()) - .map(HivePartition::getPartitionId) - .mapToDouble(partitionName -> getPartitionRowCount(partitionName, statistics).orElse(averageRowsPerPartition)) - .sum(); - return normalizeFraction(estimatedNullsCount / rowCount); - } - - private static double normalizeFraction(double fraction) - { - checkArgument(!isNaN(fraction), "fraction is NaN"); - checkArgument(isFinite(fraction), "fraction must be finite"); - if (fraction < 0) { - return 0; - } - if (fraction > 1) { - return 1; - } - return fraction; - } - - @VisibleForTesting - static Estimate calculateDataSizeForPartitioningKey( - HiveColumnHandle column, - Type type, - List partitions, - Map statistics, - double averageRowsPerPartition) - { - if (!hasDataSize(type)) { - return Estimate.unknown(); - } - double dataSize = 0; - for (HivePartition partition : partitions) { - int length = getSize(partition.getKeys().get(column)); - double rowCount = getPartitionRowCount(partition.getPartitionId(), statistics).orElse(averageRowsPerPartition); - dataSize += length * rowCount; - } - return Estimate.of(dataSize); - } - - private static boolean hasDataSize(Type type) - { - return type instanceof VarcharType || type instanceof CharType; - } - - private static int getSize(NullableValue nullableValue) - { - if (nullableValue.isNull()) { - return 0; - } - Object value = nullableValue.getValue(); - checkArgument(value instanceof Slice, "value is expected to be of Slice type"); - return ((Slice) value).length(); - } - - private static OptionalDouble getPartitionRowCount(String partitionName, Map statistics) - { - PartitionStatistics partitionStatistics = statistics.get(partitionName); - if (partitionStatistics == null) { - return OptionalDouble.empty(); - } - OptionalLong rowCount = partitionStatistics.getBasicStatistics().getRowCount(); - if (rowCount.isPresent()) { - verify(rowCount.getAsLong() >= 0, "rowCount must be greater than or equal to zero"); - return OptionalDouble.of(rowCount.getAsLong()); - } - return OptionalDouble.empty(); - } - - @VisibleForTesting - static Optional calculateRangeForPartitioningKey(HiveColumnHandle column, Type type, List partitions) - { - List convertedValues = partitions.stream() - .map(HivePartition::getKeys) - .map(keys -> keys.get(column)) - .filter(value -> !value.isNull()) - .map(NullableValue::getValue) - .map(value -> convertPartitionValueToDouble(type, value)) - .collect(toImmutableList()); - - if (convertedValues.stream().noneMatch(OptionalDouble::isPresent)) { - return Optional.empty(); - } - double[] values = convertedValues.stream() - .peek(convertedValue -> checkState(convertedValue.isPresent(), "convertedValue is missing")) - .mapToDouble(OptionalDouble::getAsDouble) - .toArray(); - verify(values.length != 0, "No values"); - - if (DoubleStream.of(values).anyMatch(Double::isNaN)) { - return Optional.empty(); - } - - double min = DoubleStream.of(values).min().orElseThrow(); - double max = DoubleStream.of(values).max().orElseThrow(); - return Optional.of(new DoubleRange(min, max)); - }
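// Range derivation in the shape of calculateRangeForPartitioningKey above:
// values are first mapped to a double "stats representation", and a single NaN
// poisons the whole range rather than producing a bogus bound. DoubleRange is
// simplified to a two-element array here:
import java.util.Optional;
import java.util.stream.DoubleStream;

class RangeSketch
{
    static Optional<double[]> range(double... values)
    {
        if (values.length == 0 || DoubleStream.of(values).anyMatch(Double::isNaN)) {
            return Optional.empty();
        }
        double min = DoubleStream.of(values).min().orElseThrow();
        double max = DoubleStream.of(values).max().orElseThrow();
        return Optional.of(new double[] {min, max});
    }
}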
 - - @VisibleForTesting - static OptionalDouble convertPartitionValueToDouble(Type type, Object value) - { - return toStatsRepresentation(type, value); - } - - @VisibleForTesting - static ColumnStatistics createDataColumnStatistics(String column, Type type, double rowsCount, Collection partitionStatistics) - { - List columnStatistics = partitionStatistics.stream() - .map(PartitionStatistics::getColumnStatistics) - .map(statistics -> statistics.get(column)) - .filter(Objects::nonNull) - .collect(toImmutableList()); - - if (columnStatistics.isEmpty()) { - return ColumnStatistics.empty(); - } - - return ColumnStatistics.builder() - .setDistinctValuesCount(calculateDistinctValuesCount(columnStatistics)) - .setNullsFraction(calculateNullsFraction(column, partitionStatistics)) - .setDataSize(calculateDataSize(column, partitionStatistics, rowsCount)) - .setRange(calculateRange(type, columnStatistics)) - .build(); - } - - @VisibleForTesting - static Estimate calculateDistinctValuesCount(List columnStatistics) - { - return columnStatistics.stream() - .map(MetastoreHiveStatisticsProvider::getDistinctValuesCount) - .filter(OptionalLong::isPresent) - .map(OptionalLong::getAsLong) - .peek(distinctValuesCount -> verify(distinctValuesCount >= 0, "distinctValuesCount must be greater than or equal to zero")) - .max(Long::compare) - .map(Estimate::of) - .orElse(Estimate.unknown()); - } - - private static OptionalLong getDistinctValuesCount(HiveColumnStatistics statistics) - { - if (statistics.getBooleanStatistics().isPresent() && - statistics.getBooleanStatistics().get().getFalseCount().isPresent() && - statistics.getBooleanStatistics().get().getTrueCount().isPresent()) { - long falseCount = statistics.getBooleanStatistics().get().getFalseCount().getAsLong(); - long trueCount = statistics.getBooleanStatistics().get().getTrueCount().getAsLong(); - return OptionalLong.of((falseCount > 0 ? 1 : 0) + (trueCount > 0 ? 1 : 0)); - } - if (statistics.getDistinctValuesCount().isPresent()) { - return statistics.getDistinctValuesCount(); - } - return OptionalLong.empty(); - }
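// For boolean columns the NDV is derived rather than read: each of true/false
// contributes one distinct value only when its count is positive. A worked
// check of the rule above:
class BooleanNdvSketch
{
    static long ndv(long falseCount, long trueCount)
    {
        return (falseCount > 0 ? 1 : 0) + (trueCount > 0 ? 1 : 0);
    }

    public static void main(String[] args)
    {
        System.out.println(ndv(0, 500));    // 1 - the column is effectively constant
        System.out.println(ndv(250, 250));  // 2
    }
}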
totalRowCount: %s.", - totalNullsCount, - totalRowCount); - return Estimate.of(((double) totalNullsCount) / totalRowCount); - } - - @VisibleForTesting - static Estimate calculateDataSize(String column, Collection partitionStatistics, double totalRowCount) - { - List statisticsWithKnownRowCountAndDataSize = partitionStatistics.stream() - .filter(statistics -> { - if (statistics.getBasicStatistics().getRowCount().isEmpty()) { - return false; - } - HiveColumnStatistics columnStatistics = statistics.getColumnStatistics().get(column); - if (columnStatistics == null) { - return false; - } - return columnStatistics.getTotalSizeInBytes().isPresent(); - }) - .collect(toImmutableList()); - - if (statisticsWithKnownRowCountAndDataSize.isEmpty()) { - return Estimate.unknown(); - } - - long knownRowCount = 0; - long knownDataSize = 0; - for (PartitionStatistics statistics : statisticsWithKnownRowCountAndDataSize) { - long rowCount = statistics.getBasicStatistics().getRowCount().orElseThrow(() -> new VerifyException("rowCount is not present")); - verify(rowCount >= 0, "rowCount must be greater than or equal to zero"); - HiveColumnStatistics columnStatistics = statistics.getColumnStatistics().get(column); - verifyNotNull(columnStatistics, "columnStatistics is null"); - long dataSize = columnStatistics.getTotalSizeInBytes().orElseThrow(() -> new VerifyException("totalSizeInBytes is not present")); - verify(dataSize >= 0, "dataSize must be greater than or equal to zero"); - knownRowCount += rowCount; - knownDataSize += dataSize; - } - - if (totalRowCount == 0) { - return Estimate.zero(); - } - - if (knownRowCount == 0) { - return Estimate.unknown(); - } - - double averageValueDataSizeInBytes = ((double) knownDataSize) / knownRowCount; - return Estimate.of(averageValueDataSizeInBytes * totalRowCount); - } - - @VisibleForTesting - static Optional calculateRange(Type type, List columnStatistics) - { - return columnStatistics.stream() - .map(statistics -> createRange(type, statistics)) - .filter(Optional::isPresent) - .map(Optional::get) - .reduce(DoubleRange::union); - } - - private static Optional createRange(Type type, HiveColumnStatistics statistics) - { - if (type.equals(BIGINT) || type.equals(INTEGER) || type.equals(SMALLINT) || type.equals(TINYINT)) { - return statistics.getIntegerStatistics().flatMap(integerStatistics -> createIntegerRange(type, integerStatistics)); - } - if (type.equals(DOUBLE) || type.equals(REAL)) { - return statistics.getDoubleStatistics().flatMap(MetastoreHiveStatisticsProvider::createDoubleRange); - } - if (type.equals(DATE)) { - return statistics.getDateStatistics().flatMap(MetastoreHiveStatisticsProvider::createDateRange); - } - if (type instanceof DecimalType) { - return statistics.getDecimalStatistics().flatMap(MetastoreHiveStatisticsProvider::createDecimalRange); - } - return Optional.empty(); - } - - private static Optional createIntegerRange(Type type, IntegerStatistics statistics) - { - if (statistics.getMin().isPresent() && statistics.getMax().isPresent()) { - return Optional.of(createIntegerRange(type, statistics.getMin().getAsLong(), statistics.getMax().getAsLong())); - } - return Optional.empty(); - } - - private static DoubleRange createIntegerRange(Type type, long min, long max) - { - return new DoubleRange(normalizeIntegerValue(type, min), normalizeIntegerValue(type, max)); - } - - private static long normalizeIntegerValue(Type type, long value) - { - if (type.equals(BIGINT)) { - return value; - } - if (type.equals(INTEGER)) { - return Ints.saturatedCast(value); - 
 - - @VisibleForTesting - static Estimate calculateDataSize(String column, Collection partitionStatistics, double totalRowCount) - { - List statisticsWithKnownRowCountAndDataSize = partitionStatistics.stream() - .filter(statistics -> { - if (statistics.getBasicStatistics().getRowCount().isEmpty()) { - return false; - } - HiveColumnStatistics columnStatistics = statistics.getColumnStatistics().get(column); - if (columnStatistics == null) { - return false; - } - return columnStatistics.getTotalSizeInBytes().isPresent(); - }) - .collect(toImmutableList()); - - if (statisticsWithKnownRowCountAndDataSize.isEmpty()) { - return Estimate.unknown(); - } - - long knownRowCount = 0; - long knownDataSize = 0; - for (PartitionStatistics statistics : statisticsWithKnownRowCountAndDataSize) { - long rowCount = statistics.getBasicStatistics().getRowCount().orElseThrow(() -> new VerifyException("rowCount is not present")); - verify(rowCount >= 0, "rowCount must be greater than or equal to zero"); - HiveColumnStatistics columnStatistics = statistics.getColumnStatistics().get(column); - verifyNotNull(columnStatistics, "columnStatistics is null"); - long dataSize = columnStatistics.getTotalSizeInBytes().orElseThrow(() -> new VerifyException("totalSizeInBytes is not present")); - verify(dataSize >= 0, "dataSize must be greater than or equal to zero"); - knownRowCount += rowCount; - knownDataSize += dataSize; - } - - if (totalRowCount == 0) { - return Estimate.zero(); - } - - if (knownRowCount == 0) { - return Estimate.unknown(); - } - - double averageValueDataSizeInBytes = ((double) knownDataSize) / knownRowCount; - return Estimate.of(averageValueDataSizeInBytes * totalRowCount); - }
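// Data size is extrapolated the same way: derive the average bytes per value
// from the partitions with known statistics, then scale by the estimated total
// row count. Worked (illustrative) numbers:
class DataSizeSketch
{
    public static void main(String[] args)
    {
        long knownRowCount = 2_000;     // rows covered by partitions with size stats
        long knownDataSize = 50_000;    // bytes in those partitions
        double totalRowCount = 10_000;  // estimated rows in the whole table
        double averageValueDataSizeInBytes = (double) knownDataSize / knownRowCount;  // 25.0
        System.out.println(averageValueDataSizeInBytes * totalRowCount);              // 250_000.0
    }
}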
 - - @VisibleForTesting - static Optional calculateRange(Type type, List columnStatistics) - { - return columnStatistics.stream() - .map(statistics -> createRange(type, statistics)) - .filter(Optional::isPresent) - .map(Optional::get) - .reduce(DoubleRange::union); - } - - private static Optional createRange(Type type, HiveColumnStatistics statistics) - { - if (type.equals(BIGINT) || type.equals(INTEGER) || type.equals(SMALLINT) || type.equals(TINYINT)) { - return statistics.getIntegerStatistics().flatMap(integerStatistics -> createIntegerRange(type, integerStatistics)); - } - if (type.equals(DOUBLE) || type.equals(REAL)) { - return statistics.getDoubleStatistics().flatMap(MetastoreHiveStatisticsProvider::createDoubleRange); - } - if (type.equals(DATE)) { - return statistics.getDateStatistics().flatMap(MetastoreHiveStatisticsProvider::createDateRange); - } - if (type instanceof DecimalType) { - return statistics.getDecimalStatistics().flatMap(MetastoreHiveStatisticsProvider::createDecimalRange); - } - return Optional.empty(); - } - - private static Optional createIntegerRange(Type type, IntegerStatistics statistics) - { - if (statistics.getMin().isPresent() && statistics.getMax().isPresent()) { - return Optional.of(createIntegerRange(type, statistics.getMin().getAsLong(), statistics.getMax().getAsLong())); - } - return Optional.empty(); - } - - private static DoubleRange createIntegerRange(Type type, long min, long max) - { - return new DoubleRange(normalizeIntegerValue(type, min), normalizeIntegerValue(type, max)); - } - - private static long normalizeIntegerValue(Type type, long value) - { - if (type.equals(BIGINT)) { - return value; - } - if (type.equals(INTEGER)) { - return Ints.saturatedCast(value); - } - if (type.equals(SMALLINT)) { - return Shorts.saturatedCast(value); - } - if (type.equals(TINYINT)) { - return SignedBytes.saturatedCast(value); - } - throw new IllegalArgumentException("Unexpected type: " + type); - }
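// Why normalizeIntegerValue saturates: metastore statistics sometimes record
// bounds outside the column type's domain (for example a TINYINT max of 300),
// and a plain narrowing cast would wrap around instead of clamping. Guava's
// saturatedCast pins the value to the type bounds:
import com.google.common.primitives.SignedBytes;

class SaturatedCastSketch
{
    public static void main(String[] args)
    {
        System.out.println((byte) 300L);                     // 44  - silent wraparound
        System.out.println(SignedBytes.saturatedCast(300L)); // 127 - clamped to Byte.MAX_VALUE
    }
}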
 - - private static Optional createDoubleRange(DoubleStatistics statistics) - { - if (statistics.getMin().isPresent() && statistics.getMax().isPresent() && !isNaN(statistics.getMin().getAsDouble()) && !isNaN(statistics.getMax().getAsDouble())) { - return Optional.of(new DoubleRange(statistics.getMin().getAsDouble(), statistics.getMax().getAsDouble())); - } - return Optional.empty(); - } - - private static Optional createDateRange(DateStatistics statistics) - { - if (statistics.getMin().isPresent() && statistics.getMax().isPresent()) { - return Optional.of(new DoubleRange(statistics.getMin().get().toEpochDay(), statistics.getMax().get().toEpochDay())); - } - return Optional.empty(); - } - - private static Optional createDecimalRange(DecimalStatistics statistics) - { - if (statistics.getMin().isPresent() && statistics.getMax().isPresent()) { - return Optional.of(new DoubleRange(statistics.getMin().get().doubleValue(), statistics.getMax().get().doubleValue())); - } - return Optional.empty(); - } - - @VisibleForTesting - interface PartitionsStatisticsProvider - { - Map getPartitionsStatistics(ConnectorSession session, SchemaTableName table, List hivePartitions, Set columns); - } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/statistics/TestMetastoreHiveStatisticsProvider.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/statistics/TestMetastoreHiveStatisticsProvider.java index 521ba7c5eed9..5896b6e7d61d 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/statistics/TestMetastoreHiveStatisticsProvider.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/statistics/TestMetastoreHiveStatisticsProvider.java @@ -38,9 +38,12 @@ import java.math.BigDecimal; import java.time.LocalDate; +import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.OptionalDouble; import java.util.OptionalLong; +import java.util.Set; import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.PARTITION_KEY; import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.REGULAR; @@ -58,20 +61,20 @@ import static io.trino.plugin.hive.metastore.HiveColumnStatistics.createDecimalColumnStatistics; import static io.trino.plugin.hive.metastore.HiveColumnStatistics.createDoubleColumnStatistics; import static io.trino.plugin.hive.metastore.HiveColumnStatistics.createIntegerColumnStatistics; -import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.PartitionsRowCount; -import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDataSize; -import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDataSizeForPartitioningKey; -import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDistinctPartitionKeys; -import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDistinctValuesCount; -import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateNullsFraction; -import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateNullsFractionForPartitioningKey; -import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculatePartitionsRowCount; -import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateRange; -import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateRangeForPartitioningKey; -import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.convertPartitionValueToDouble; -import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.createDataColumnStatistics; -import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.getPartitionsSample; -import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.validatePartitionStatistics; +import static io.trino.plugin.hive.statistics.AbstractHiveStatisticsProvider.PartitionsRowCount; +import static io.trino.plugin.hive.statistics.AbstractHiveStatisticsProvider.calculateDataSize; +import static io.trino.plugin.hive.statistics.AbstractHiveStatisticsProvider.calculateDataSizeForPartitioningKey; +import static io.trino.plugin.hive.statistics.AbstractHiveStatisticsProvider.calculateDistinctPartitionKeys; +import static io.trino.plugin.hive.statistics.AbstractHiveStatisticsProvider.calculateDistinctValuesCount; +import static io.trino.plugin.hive.statistics.AbstractHiveStatisticsProvider.calculateNullsFraction; +import static io.trino.plugin.hive.statistics.AbstractHiveStatisticsProvider.calculateNullsFractionForPartitioningKey; +import static io.trino.plugin.hive.statistics.AbstractHiveStatisticsProvider.calculatePartitionsRowCount; +import static io.trino.plugin.hive.statistics.AbstractHiveStatisticsProvider.calculateRange; +import static io.trino.plugin.hive.statistics.AbstractHiveStatisticsProvider.calculateRangeForPartitioningKey; +import static io.trino.plugin.hive.statistics.AbstractHiveStatisticsProvider.convertPartitionValueToDouble; +import static io.trino.plugin.hive.statistics.AbstractHiveStatisticsProvider.createDataColumnStatistics; +import static io.trino.plugin.hive.statistics.AbstractHiveStatisticsProvider.getPartitionsSample; +import static io.trino.plugin.hive.statistics.AbstractHiveStatisticsProvider.validatePartitionStatistics; import static io.trino.plugin.hive.util.HiveUtil.parsePartitionValue; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DateType.DATE; @@ -630,7 +633,14 @@ public void testGetTableStatistics() .setBasicStatistics(new HiveBasicStatistics(OptionalLong.empty(), OptionalLong.of(1000), OptionalLong.empty(), OptionalLong.empty())) .setColumnStatistics(ImmutableMap.of(COLUMN, createIntegerColumnStatistics(OptionalLong.of(-100), OptionalLong.of(100), OptionalLong.of(500), OptionalLong.of(300)))) .build(); - MetastoreHiveStatisticsProvider statisticsProvider = new MetastoreHiveStatisticsProvider((session, table, hivePartitions, columns) -> ImmutableMap.of(partitionName, statistics)); + HiveStatisticsProvider statisticsProvider = new AbstractHiveStatisticsProvider() + { + @Override + protected Map getPartitionsStatistics(ConnectorSession session, SchemaTableName table, List hivePartitions, Set columns) + { + return ImmutableMap.of(partitionName, statistics); + } + }; HiveColumnHandle columnHandle = createBaseColumn(COLUMN, 2, HIVE_LONG, BIGINT, REGULAR, Optional.empty()); TableStatistics expected = TableStatistics.builder() .setRowCount(Estimate.of(1000))
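// The pattern recurring in these test hunks: the old lambda passed to the
// MetastoreHiveStatisticsProvider constructor becomes an anonymous subclass of
// the new AbstractHiveStatisticsProvider overriding its single protected hook.
// A generic, self-contained sketch of that template-method shape (class and
// method names here are illustrative, not the actual Trino API):
import java.util.Map;

abstract class StatisticsProviderBase
{
    // shared pipeline (sampling, validation, extrapolation) lives in the base class
    final long tableRowCount()
    {
        return partitionRowCounts().values().stream().mapToLong(Long::longValue).sum();
    }

    // the only backend-specific piece; tests override it with canned values
    protected abstract Map<String, Long> partitionRowCounts();
}

class StatisticsProviderStubDemo
{
    public static void main(String[] args)
    {
        StatisticsProviderBase stub = new StatisticsProviderBase()
        {
            @Override
            protected Map<String, Long> partitionRowCounts()
            {
                return Map.of("p1=a", 10L, "p1=b", 20L);  // canned stats, as in the tests
            }
        };
        System.out.println(stub.tableRowCount());  // 30
    }
}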
 @@ -679,7 +689,14 @@ public void testGetTableStatisticsUnpartitioned() .setBasicStatistics(new HiveBasicStatistics(OptionalLong.empty(), OptionalLong.of(1000), OptionalLong.empty(), OptionalLong.empty())) .setColumnStatistics(ImmutableMap.of(COLUMN, createIntegerColumnStatistics(OptionalLong.of(-100), OptionalLong.of(100), OptionalLong.of(500), OptionalLong.of(300)))) .build(); - MetastoreHiveStatisticsProvider statisticsProvider = new MetastoreHiveStatisticsProvider((session, table, hivePartitions, columns) -> ImmutableMap.of(UNPARTITIONED_ID, statistics)); + HiveStatisticsProvider statisticsProvider = new AbstractHiveStatisticsProvider() + { + @Override + protected Map getPartitionsStatistics(ConnectorSession session, SchemaTableName table, List hivePartitions, Set columns) + { + return ImmutableMap.of(UNPARTITIONED_ID, statistics); + } + }; HiveColumnHandle columnHandle = createBaseColumn(COLUMN, 2, HIVE_LONG, BIGINT, REGULAR, Optional.empty()); @@ -707,7 +724,14 @@ public void testGetTableStatisticsUnpartitioned() public void testGetTableStatisticsEmpty() { String partitionName = "p1=string1/p2=1234"; - MetastoreHiveStatisticsProvider statisticsProvider = new MetastoreHiveStatisticsProvider((session, table, hivePartitions, columns) -> ImmutableMap.of(partitionName, PartitionStatistics.empty())); + HiveStatisticsProvider statisticsProvider = new AbstractHiveStatisticsProvider() + { + @Override + protected Map getPartitionsStatistics(ConnectorSession session, SchemaTableName table, List hivePartitions, Set columns) + { + return ImmutableMap.of(partitionName, PartitionStatistics.empty()); + } + }; assertEquals( statisticsProvider.getTableStatistics( SESSION, @@ -721,11 +745,15 @@ public void testGetTableStatisticsEmpty() @Test public void testGetTableStatisticsSampling() { - MetastoreHiveStatisticsProvider statisticsProvider = new MetastoreHiveStatisticsProvider((session, table, hivePartitions, columns) -> { - assertEquals(table, TABLE); - assertEquals(hivePartitions.size(), 1); - return ImmutableMap.of(); - }); + HiveStatisticsProvider statisticsProvider = new AbstractHiveStatisticsProvider() { + @Override + protected Map getPartitionsStatistics(ConnectorSession session, SchemaTableName table, List hivePartitions, Set columns) + { + assertEquals(table, TABLE); + assertEquals(hivePartitions.size(), 1); + return ImmutableMap.of(); + } + }; ConnectorSession session = getHiveSession(new HiveConfig() .setPartitionStatisticsSampleSize(1)); statisticsProvider.getTableStatistics( @@ -743,7 +771,14 @@ public void testGetTableStatisticsValidationFailure() .setBasicStatistics(new HiveBasicStatistics(-1, 0, 0, 0)) .build(); String partitionName = "p1=string1/p2=1234"; - MetastoreHiveStatisticsProvider statisticsProvider = new MetastoreHiveStatisticsProvider((session, table, hivePartitions, columns) -> ImmutableMap.of(partitionName, corruptedStatistics)); + HiveStatisticsProvider statisticsProvider = new AbstractHiveStatisticsProvider() + { + @Override + protected Map getPartitionsStatistics(ConnectorSession session, SchemaTableName table, List hivePartitions, Set columns) + { + return ImmutableMap.of(partitionName, corruptedStatistics); + } + }; assertThatThrownBy(() -> statisticsProvider.getTableStatistics( getHiveSession(new HiveConfig().setIgnoreCorruptedStatistics(false)), TABLE, @@ -765,20 +800,32 @@ public void testGetTableStatisticsValidationFailure() @Test public void testEmptyTableStatisticsForPartitionColumnsWhenStatsAreEmpty() { - MetastoreHiveStatisticsProvider statisticsProvider = new MetastoreHiveStatisticsProvider( - (session, table, hivePartitions, columns) -> ImmutableMap.of("p1=string1/p2=1234", PartitionStatistics.empty())); + HiveStatisticsProvider statisticsProvider = new AbstractHiveStatisticsProvider() + { + @Override + protected Map getPartitionsStatistics(ConnectorSession session, SchemaTableName table, List hivePartitions, Set columns) + { + return ImmutableMap.of("p1=string1/p2=1234", PartitionStatistics.empty()); + } + }; testEmptyTableStatisticsForPartitionColumns(statisticsProvider); } @Test public void testEmptyTableStatisticsForPartitionColumnsWhenStatsAreMissing() { - MetastoreHiveStatisticsProvider statisticsProvider = new MetastoreHiveStatisticsProvider( - (session, table, hivePartitions, columns) -> ImmutableMap.of()); + HiveStatisticsProvider statisticsProvider = new AbstractHiveStatisticsProvider() + { + @Override + protected Map getPartitionsStatistics(ConnectorSession session, SchemaTableName table, List hivePartitions, Set columns) + { + return ImmutableMap.of(); + } + }; testEmptyTableStatisticsForPartitionColumns(statisticsProvider); } - private void testEmptyTableStatisticsForPartitionColumns(MetastoreHiveStatisticsProvider statisticsProvider) + private void testEmptyTableStatisticsForPartitionColumns(HiveStatisticsProvider statisticsProvider) { String partitionName1 = "p1=string1/p2=1234"; String partitionName2 = "p1=string2/p2=1235"; diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/InternalIcebergConnectorFactory.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/InternalIcebergConnectorFactory.java index 0ebc1bd4378b..6e05341bc096 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/InternalIcebergConnectorFactory.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/InternalIcebergConnectorFactory.java @@ -33,6 +33,7 @@ import io.trino.plugin.base.jmx.ConnectorObjectNameGeneratorModule; import io.trino.plugin.base.jmx.MBeanServerModule; import io.trino.plugin.base.session.SessionPropertiesProvider; +import io.trino.plugin.hive.HiveConfig; import io.trino.plugin.hive.NodeVersion; import io.trino.plugin.iceberg.catalog.IcebergCatalogModule; import io.trino.spi.NodeManager; @@ -62,6 +63,7 @@ import java.util.Set; import java.util.stream.Stream; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; public final class InternalIcebergConnectorFactory @@ -128,6 +130,8 @@ public static Connector createConnector( .flatMap(Collection::stream) .collect(toImmutableList()); + checkState(!injector.getBindings().containsKey(Key.get(HiveConfig.class)), "HiveConfig should not be bound"); + return new IcebergConnector( injector, lifeCycleManager,
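// The checkState guard added to both connector factories, as a runnable sketch:
// a Guice Injector exposes its explicit bindings as a Map keyed by Key, so the
// absence of a configuration class can be asserted right after module wiring.
// ToyConfig stands in for HiveConfig here:
import com.google.inject.Guice;
import com.google.inject.Injector;
import com.google.inject.Key;

import static com.google.common.base.Preconditions.checkState;

class BindingGuardSketch
{
    static class ToyConfig {}

    public static void main(String[] args)
    {
        Injector injector = Guice.createInjector();  // no module binds ToyConfig
        checkState(!injector.getBindings().containsKey(Key.get(ToyConfig.class)),
                "ToyConfig should not be bound");
        System.out.println("ToyConfig is not bound");
    }
}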