From b9d2af6332de252ebaf46dcad1cd405c9183c49e Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Tue, 13 Sep 2016 20:03:05 -0700 Subject: [PATCH 01/22] support column-level stats --- .../catalyst/plans/logical/Statistics.scala | 55 +++- .../spark/sql/execution/SparkSqlParser.scala | 6 +- .../command/AnalyzeColumnCommand.scala | 209 ++++++++++++++ .../command/AnalyzeTableCommand.scala | 98 ++++--- .../apache/spark/sql/StatisticsSuite.scala | 45 ++- .../spark/sql/hive/HiveExternalCatalog.scala | 25 +- .../spark/sql/hive/StatisticsSuite.scala | 273 ++++++++++++++++-- 7 files changed, 629 insertions(+), 82 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 3cf20385dd71..e2c1be1f8bc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.types.DataType + /** * Estimates of various statistics. The default estimation logic simply lazily multiplies the * corresponding statistic produced by the children. To override this behavior, override @@ -32,19 +34,70 @@ package org.apache.spark.sql.catalyst.plans.logical * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it * defaults to the product of children's `sizeInBytes`. * @param rowCount Estimated number of rows. + * @param basicColStats Basic column-level statistics. * @param isBroadcastable If true, output is small enough to be used in a broadcast join. */ case class Statistics( sizeInBytes: BigInt, rowCount: Option[BigInt] = None, + basicColStats: Map[String, BasicColStats] = Map.empty, isBroadcastable: Boolean = false) { + override def toString: String = "Statistics(" + simpleString + ")" /** Readable string representation for the Statistics. 
*/ def simpleString: String = { Seq(s"sizeInBytes=$sizeInBytes", if (rowCount.isDefined) s"rowCount=${rowCount.get}" else "", + if (basicColStats.nonEmpty) s"basicColStats=$basicColStats" else "", s"isBroadcastable=$isBroadcastable" - ).filter(_.nonEmpty).mkString("", ", ", "") + ).filter(_.nonEmpty).mkString(", ") + } +} + +case class BasicColStats( + dataType: DataType, + numNulls: Long, + max: Option[Any] = None, + min: Option[Any] = None, + ndv: Option[Long] = None, + avgColLen: Option[Double] = None, + maxColLen: Option[Long] = None, + numTrues: Option[Long] = None, + numFalses: Option[Long] = None) { + + override def toString: String = "BasicColStats(" + simpleString + ")" + + def simpleString: String = { + Seq(s"numNulls=$numNulls", + if (max.isDefined) s"max=${max.get}" else "", + if (min.isDefined) s"min=${min.get}" else "", + if (ndv.isDefined) s"ndv=${ndv.get}" else "", + if (avgColLen.isDefined) s"avgColLen=${avgColLen.get}" else "", + if (maxColLen.isDefined) s"maxColLen=${maxColLen.get}" else "", + if (numTrues.isDefined) s"numTrues=${numTrues.get}" else "", + if (numFalses.isDefined) s"numFalses=${numFalses.get}" else "" + ).filter(_.nonEmpty).mkString(", ") + } +} + +object BasicColStats { + def fromString(str: String, dataType: DataType): BasicColStats = { + val suffix = ",\\s|\\)" + BasicColStats( + dataType = dataType, + numNulls = findItem(source = str, prefix = "numNulls=", suffix = suffix).map(_.toLong).get, + max = findItem(source = str, prefix = "max=", suffix = suffix), + min = findItem(source = str, prefix = "min=", suffix = suffix), + ndv = findItem(source = str, prefix = "ndv=", suffix = suffix).map(_.toLong), + avgColLen = findItem(source = str, prefix = "avgColLen=", suffix = suffix).map(_.toDouble), + maxColLen = findItem(source = str, prefix = "maxColLen=", suffix = suffix).map(_.toLong), + numTrues = findItem(source = str, prefix = "numTrues=", suffix = suffix).map(_.toLong), + numFalses = findItem(source = str, prefix = "numFalses=", suffix = suffix).map(_.toLong)) + } + + private def findItem(source: String, prefix: String, suffix: String): Option[String] = { + val pattern = s"(?<=$prefix)(.+?)(?=$suffix)".r + pattern.findFirstIn(source) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 5359cedc8097..3d60af2a5d76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -98,8 +98,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { ctx.identifier != null && ctx.identifier.getText.toLowerCase == "noscan") { AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier).toString) - } else { + } else if (ctx.identifierSeq() == null) { AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier).toString, noscan = false) + } else { + AnalyzeColumnCommand( + visitTableIdentifier(ctx.tableIdentifier).toString, + visitIdentifierSeq(ctx.identifierSeq())) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala new file mode 100644 index 000000000000..9d3c51be8c84 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import scala.collection.mutable + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} +import org.apache.spark.sql.catalyst.plans.logical.{BasicColStats, Statistics} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + + +/** + * Analyzes the given columns of the given table in the current database to generate statistics, + * which will be used in query optimizations. + */ +case class AnalyzeColumnCommand( + tableName: String, + columnNames: Seq[String]) extends RunnableCommand { + + override def run(sparkSession: SparkSession): Seq[Row] = { + val sessionState = sparkSession.sessionState + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) + val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdent)) + + // check correctness for column names + val attributeNames = relation.output.map(_.name.toLowerCase) + val invalidColumns = columnNames.filterNot { col => attributeNames.contains(col.toLowerCase)} + if (invalidColumns.nonEmpty) { + throw new AnalysisException(s"Invalid columns for table $tableName: $invalidColumns.") + } + + relation match { + case catalogRel: CatalogRelation => + updateStats(catalogRel.catalogTable, + AnalyzeTableCommand.calculateTotalSize(sparkSession, catalogRel.catalogTable)) + + case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => + updateStats(logicalRel.catalogTable.get, logicalRel.relation.sizeInBytes) + + case otherRelation => + throw new AnalysisException(s"ANALYZE TABLE is not supported for " + + s"${otherRelation.nodeName}.") + } + + def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = { + val lowerCaseNames = columnNames.map(_.toLowerCase) + val attributes = + relation.output.filter(attr => lowerCaseNames.contains(attr.name.toLowerCase)) + + // collect column statistics + val aggColumns = mutable.ArrayBuffer[Column](count(Column("*"))) + attributes.foreach(entry => aggColumns ++= statsAgg(entry.name, entry.dataType)) + val statsRow: InternalRow = Dataset.ofRows(sparkSession, relation).select(aggColumns: _*) + .queryExecution.toRdd.collect().head + + // We also update table-level stats to prevent inconsistency in case of table modification + // between the two ANALYZE commands for collecting table-level stats and column-level stats. 
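// Illustrative sketch (not in the diff): the single aggregate row collected above packs
// count(*) into slot 0, followed by each analyzed column's stats in the order defined by
// its StatsAggFunc object. Assuming two hypothetical columns, i: INT and j: STRING, the
// aggColumns sequence expands roughly to:
//   Seq(count(Column("*")),                                               // slot 0: row count
//       max(i), min(i), approxCountDistinct(i), count(i),                 // i -> NumericStatsAgg.statsSeq, slots 1..4
//       max(length(j)), sum(length(j)), approxCountDistinct(j), count(j)) // j -> StringStatsAgg.statsSeq, slots 5..8
// The decoding below walks this row using the *Index offsets defined on each StatsAgg object.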
+ val rowCount = statsRow.getLong(0) + var newStats: Statistics = if (catalogTable.stats.isDefined) { + catalogTable.stats.get.copy(sizeInBytes = newTotalSize, rowCount = Some(rowCount)) + } else { + Statistics(sizeInBytes = newTotalSize, rowCount = Some(rowCount)) + } + + var pos = 1 + val colStats = mutable.HashMap[String, BasicColStats]() + attributes.foreach { attr => + attr.dataType match { + case n: NumericType => + colStats += attr.name -> BasicColStats( + dataType = attr.dataType, + numNulls = rowCount - statsRow.getLong(pos + NumericStatsAgg.numNotNullsIndex), + max = Option(statsRow.get(pos + NumericStatsAgg.maxIndex, attr.dataType)), + min = Option(statsRow.get(pos + NumericStatsAgg.minIndex, attr.dataType)), + ndv = Some(statsRow.getLong(pos + NumericStatsAgg.ndvIndex))) + pos += NumericStatsAgg.statsSeq.length + case TimestampType | DateType => + colStats += attr.name -> BasicColStats( + dataType = attr.dataType, + numNulls = rowCount - statsRow.getLong(pos + NumericStatsAgg.numNotNullsIndex), + max = Option(statsRow.get(pos + NumericStatsAgg.maxIndex, attr.dataType)), + min = Option(statsRow.get(pos + NumericStatsAgg.minIndex, attr.dataType)), + ndv = Some(statsRow.getLong(pos + NumericStatsAgg.ndvIndex))) + pos += NumericStatsAgg.statsSeq.length + case StringType => + colStats += attr.name -> BasicColStats( + dataType = attr.dataType, + numNulls = rowCount - statsRow.getLong(pos + StringStatsAgg.numNotNullsIndex), + maxColLen = Some(statsRow.getLong(pos + StringStatsAgg.maxLenIndex)), + avgColLen = + Some(statsRow.getLong(pos + StringStatsAgg.sumLenIndex) / (1.0 * rowCount)), + ndv = Some(statsRow.getLong(pos + StringStatsAgg.ndvIndex))) + pos += StringStatsAgg.statsSeq.length + case BinaryType => + colStats += attr.name -> BasicColStats( + dataType = attr.dataType, + numNulls = rowCount - statsRow.getLong(pos + BinaryStatsAgg.numNotNullsIndex), + maxColLen = Some(statsRow.getLong(pos + BinaryStatsAgg.maxLenIndex)), + avgColLen = + Some(statsRow.getLong(pos + BinaryStatsAgg.sumLenIndex) / (1.0 * rowCount))) + pos += BinaryStatsAgg.statsSeq.length + case BooleanType => + val numOfNotNulls = statsRow.getLong(pos + BooleanStatsAgg.numNotNullsIndex) + val numOfTrues = Some(statsRow.getLong(pos + BooleanStatsAgg.numTruesIndex)) + colStats += attr.name -> BasicColStats( + dataType = attr.dataType, + numNulls = rowCount - numOfNotNulls, + numTrues = numOfTrues, + numFalses = numOfTrues.map(i => numOfNotNulls - i), + ndv = Some(2)) + pos += BooleanStatsAgg.statsSeq.length + } + } + newStats = newStats.copy(basicColStats = colStats.toMap) + sessionState.catalog.alterTable(catalogTable.copy(stats = Some(newStats))) + // Refresh the cached data source table in the catalog. + sessionState.catalog.refreshTable(tableIdent) + } + + Seq.empty[Row] + } + + private def statsAgg(name: String, dataType: DataType): Seq[Column] = dataType match { + // Currently we only support stats generation for atomic types + case n: NumericType => NumericStatsAgg(name) + case TimestampType | DateType => NumericStatsAgg(name) + case StringType => StringStatsAgg(name) + case BinaryType => BinaryStatsAgg(name) + case BooleanType => BooleanStatsAgg(name) + case otherType => + throw new AnalysisException(s"Analyzing column $name of $otherType is not supported.") + } +} + +object ColumnStats extends Enumeration { + val MAX, MIN, NDV, NUM_NOT_NULLS, MAX_LENGTH, SUM_LENGTH, NUM_TRUES = Value +} + +trait StatsAggFunc { + // This sequence is used to track the order of stats results when collecting. 
+ val statsSeq: Seq[ColumnStats.Value] + + def apply(name: String): Seq[Column] = { + val col = Column(name) + statsSeq.map { + case ColumnStats.MAX => max(col) + case ColumnStats.MIN => min(col) + // count(distinct col) will have a shuffle, so we use an approximate ndv for efficiency + case ColumnStats.NDV => approxCountDistinct(col) + case ColumnStats.NUM_NOT_NULLS => count(col) + case ColumnStats.MAX_LENGTH => max(length(col)) + case ColumnStats.SUM_LENGTH => sum(length(col)) + case ColumnStats.NUM_TRUES => sum(col.cast(IntegerType)) + } + } + + // This is used to locate the needed stat in the sequence. + def offset: Map[ColumnStats.Value, Int] = statsSeq.zipWithIndex.toMap + + def numNotNullsIndex: Int = offset(ColumnStats.NUM_NOT_NULLS) +} + +object NumericStatsAgg extends StatsAggFunc { + override val statsSeq = Seq(ColumnStats.MAX, ColumnStats.MIN, ColumnStats.NDV, + ColumnStats.NUM_NOT_NULLS) + def maxIndex: Int = offset(ColumnStats.MAX) + def minIndex: Int = offset(ColumnStats.MIN) + def ndvIndex: Int = offset(ColumnStats.NDV) +} + +object StringStatsAgg extends StatsAggFunc { + override val statsSeq = Seq(ColumnStats.MAX_LENGTH, ColumnStats.SUM_LENGTH, ColumnStats.NDV, + ColumnStats.NUM_NOT_NULLS) + def maxLenIndex: Int = offset(ColumnStats.MAX_LENGTH) + def sumLenIndex: Int = offset(ColumnStats.SUM_LENGTH) + def ndvIndex: Int = offset(ColumnStats.NDV) +} + +object BinaryStatsAgg extends StatsAggFunc { + override val statsSeq = Seq(ColumnStats.MAX_LENGTH, ColumnStats.SUM_LENGTH, + ColumnStats.NUM_NOT_NULLS) + def maxLenIndex: Int = offset(ColumnStats.MAX_LENGTH) + def sumLenIndex: Int = offset(ColumnStats.SUM_LENGTH) +} + +object BooleanStatsAgg extends StatsAggFunc { + override val statsSeq = Seq(ColumnStats.NUM_TRUES, ColumnStats.NUM_NOT_NULLS) + def numTruesIndex: Int = offset(ColumnStats.NUM_TRUES) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 40aecafecf5b..0d59c6a8846a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -18,9 +18,8 @@ package org.apache.spark.sql.execution.command import scala.util.control.NonFatal - import org.apache.hadoop.fs.{FileSystem, Path} - +import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases @@ -44,51 +43,8 @@ case class AnalyzeTableCommand(tableName: String, noscan: Boolean = true) extend relation match { case relation: CatalogRelation => - val catalogTable: CatalogTable = relation.catalogTable - // This method is mainly based on - // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) - // in Hive 0.13 (except that we do not use fs.getContentSummary). - // TODO: Generalize statistics collection. - // TODO: Why fs.getContentSummary returns wrong size on Jenkins? - // Can we use fs.getContentSummary in future? - // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use - // countFileSize to count the table size. 
- val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") - - def calculateTableSize(fs: FileSystem, path: Path): Long = { - val fileStatus = fs.getFileStatus(path) - val size = if (fileStatus.isDirectory) { - fs.listStatus(path) - .map { status => - if (!status.getPath.getName.startsWith(stagingDir)) { - calculateTableSize(fs, status.getPath) - } else { - 0L - } - }.sum - } else { - fileStatus.getLen - } - - size - } - - val newTotalSize = - catalogTable.storage.locationUri.map { p => - val path = new Path(p) - try { - val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf()) - calculateTableSize(fs, path) - } catch { - case NonFatal(e) => - logWarning( - s"Failed to get the size of table ${catalogTable.identifier.table} in the " + - s"database ${catalogTable.identifier.database} because of ${e.toString}", e) - 0L - } - }.getOrElse(0L) - - updateTableStats(catalogTable, newTotalSize) + updateTableStats(relation.catalogTable, + AnalyzeTableCommand.calculateTotalSize(sparkSession, relation.catalogTable)) // data source tables have been converted into LogicalRelations case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => @@ -132,3 +88,51 @@ case class AnalyzeTableCommand(tableName: String, noscan: Boolean = true) extend Seq.empty[Row] } } + +object AnalyzeTableCommand extends Logging { + + def calculateTotalSize(sparkSession: SparkSession, catalogTable: CatalogTable): Long = { + // This method is mainly based on + // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) + // in Hive 0.13 (except that we do not use fs.getContentSummary). + // TODO: Generalize statistics collection. + // TODO: Why fs.getContentSummary returns wrong size on Jenkins? + // Can we use fs.getContentSummary in future? + // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use + // countFileSize to count the table size. 
+ val stagingDir = + sparkSession.sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") + + def calculateTableSize(fs: FileSystem, path: Path): Long = { + val fileStatus = fs.getFileStatus(path) + val size = if (fileStatus.isDirectory) { + fs.listStatus(path) + .map { status => + if (!status.getPath.getName.startsWith(stagingDir)) { + calculateTableSize(fs, status.getPath) + } else { + 0L + } + }.sum + } else { + fileStatus.getLen + } + + size + } + + catalogTable.storage.locationUri.map { p => + val path = new Path(p) + try { + val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf()) + calculateTableSize(fs, path) + } catch { + case NonFatal(e) => + logWarning( + s"Failed to get the size of table ${catalogTable.identifier.table} in the " + + s"database ${catalogTable.identifier.database} because of ${e.toString}", e) + 0L + } + }.getOrElse(0L) + } +} \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala index 264a2ffbebeb..ac7cdecc6dbb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, Join, LocalLimit} +import org.apache.spark.sql.catalyst.plans.logical.{BasicColStats, GlobalLimit, Join, LocalLimit} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -101,4 +101,47 @@ class StatisticsSuite extends QueryTest with SharedSQLContext { checkTableStats(tableName, expectedRowCount = Some(2)) } } + + test("test column-level statistics for data source table created in InMemoryCatalog") { + def checkColStats(colStats: BasicColStats, expectedColStats: BasicColStats): Unit = { + assert(colStats.dataType == expectedColStats.dataType) + assert(colStats.numNulls == expectedColStats.numNulls) + assert(colStats.max == expectedColStats.max) + assert(colStats.min == expectedColStats.min) + if (expectedColStats.ndv.isDefined) { + // ndv is an approximate value, so we just make sure we have the value + assert(colStats.ndv.get >= 0) + } + assert(colStats.avgColLen == expectedColStats.avgColLen) + assert(colStats.maxColLen == expectedColStats.maxColLen) + assert(colStats.numTrues == expectedColStats.numTrues) + assert(colStats.numFalses == expectedColStats.numFalses) + } + + val tableName = "tbl" + withTable(tableName) { + sql(s"CREATE TABLE $tableName(i INT, j STRING) USING parquet") + Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto("tbl") + + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS i, j") + val df = sql(s"SELECT * FROM $tableName") + val expectedRowCount = Some(2) + val expectedColStatsSeq: Seq[(String, BasicColStats)] = Seq( + ("i", BasicColStats(dataType = IntegerType, numNulls = 0, max = Some(2), min = Some(1), + ndv = Some(2))), + ("j", BasicColStats(dataType = StringType, numNulls = 0, maxColLen = Some(1), + avgColLen = Some(1), ndv = Some(2)))) + val relations = df.queryExecution.analyzed.collect { case rel: LogicalRelation => + val stats = rel.catalogTable.get.stats.get + assert(stats.rowCount == expectedRowCount) + expectedColStatsSeq.foreach { case (column, expectedColStats) => + assert(stats.basicColStats.contains(column)) + checkColStats(colStats = stats.basicColStats(column), 
expectedColStats = expectedColStats) + } + rel + } + assert(relations.size == 1) + } + } + } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index d35a681b67e3..242a7b0de781 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.logical.{BasicColStats, Statistics} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap import org.apache.spark.sql.hive.client.HiveClient @@ -401,7 +401,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat var statsProperties: Map[String, String] = Map(STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString()) if (stats.rowCount.isDefined) { - statsProperties += (STATISTICS_NUM_ROWS -> stats.rowCount.get.toString()) + statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() + } + stats.basicColStats.foreach { case (colName, colStats) => + statsProperties += (STATISTICS_BASIC_COL_STATS_PREFIX + colName) -> colStats.toString } tableDefinition.copy(properties = tableDefinition.properties ++ statsProperties) } else { @@ -473,15 +476,20 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } // construct Spark's statistics from information in Hive metastore - if (catalogTable.properties.contains(STATISTICS_TOTAL_SIZE)) { - val totalSize = BigInt(catalogTable.properties.get(STATISTICS_TOTAL_SIZE).get) - // TODO: we will compute "estimatedSize" when we have column stats: - // average size of row * number of rows + if (catalogTable.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)).nonEmpty) { + val colStatsProps = catalogTable.properties + .filterKeys(_.startsWith(STATISTICS_BASIC_COL_STATS_PREFIX)) + .map { case (k, v) => (k.replace(STATISTICS_BASIC_COL_STATS_PREFIX, ""), v)} + val colStats: Map[String, BasicColStats] = catalogTable.schema.collect { + case field if colStatsProps.contains(field.name) => + (field.name, BasicColStats.fromString(colStatsProps(field.name), field.dataType)) + }.toMap catalogTable.copy( properties = removeStatsProperties(catalogTable), stats = Some(Statistics( - sizeInBytes = totalSize, - rowCount = catalogTable.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_))))) + sizeInBytes = BigInt(catalogTable.properties(STATISTICS_TOTAL_SIZE)), + rowCount = catalogTable.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)), + basicColStats = colStats))) } else { catalogTable } @@ -693,6 +701,7 @@ object HiveExternalCatalog { val STATISTICS_PREFIX = "spark.sql.statistics." val STATISTICS_TOTAL_SIZE = STATISTICS_PREFIX + "totalSize" val STATISTICS_NUM_ROWS = STATISTICS_PREFIX + "numRows" + val STATISTICS_BASIC_COL_STATS_PREFIX = STATISTICS_PREFIX + "basicColStats." 
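// Illustrative sketch (not in the diff): how column-level stats round-trip through table
// properties. For a hypothetical INT column "i", the statsProperties map built above stores
//   "spark.sql.statistics.basicColStats.i" -> "BasicColStats(numNulls=0, max=2, min=1, ndv=2)"
// and the read path rebuilds the stats from that string with
//   BasicColStats.fromString("BasicColStats(numNulls=0, max=2, min=1, ndv=2)", IntegerType)
// which relies on the regex-based findItem helper to pull each field back out of the string.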
def removeStatsProperties(metadata: CatalogTable): Map[String, String] = { metadata.properties.filterNot { case (key, _) => key.startsWith(STATISTICS_PREFIX) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 9956706929cd..9378af6f6413 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -18,19 +18,21 @@ package org.apache.spark.sql.hive import java.io.{File, PrintWriter} +import java.sql.{Date, Timestamp} import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.logical.{BasicColStats, Statistics} import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructType, _} class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { @@ -171,7 +173,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils TableIdentifier("tempTable"), ignoreIfNotExists = true, purge = false) } - private def checkStats( + private def checkTableStats( stats: Option[Statistics], hasSizeInBytes: Boolean, expectedRowCounts: Option[Int]): Unit = { @@ -184,7 +186,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils } } - private def checkStats( + private def checkTableStats( tableName: String, isDataSourceTable: Boolean, hasSizeInBytes: Boolean, @@ -192,12 +194,12 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils val df = sql(s"SELECT * FROM $tableName") val stats = df.queryExecution.analyzed.collect { case rel: MetastoreRelation => - checkStats(rel.catalogTable.stats, hasSizeInBytes, expectedRowCounts) - assert(!isDataSourceTable, "Expected a data source table, but got a Hive serde table") + checkTableStats(rel.catalogTable.stats, hasSizeInBytes, expectedRowCounts) + assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table") rel.catalogTable.stats case rel: LogicalRelation => - checkStats(rel.catalogTable.get.stats, hasSizeInBytes, expectedRowCounts) - assert(isDataSourceTable, "Expected a Hive serde table, but got a data source table") + checkTableStats(rel.catalogTable.get.stats, hasSizeInBytes, expectedRowCounts) + assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table") rel.catalogTable.get.stats } assert(stats.size == 1) @@ -210,13 +212,13 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils // Currently Spark's statistics are self-contained, we don't have statistics until we use // the `ANALYZE TABLE` command. 
sql(s"CREATE TABLE $textTable (key STRING, value STRING) STORED AS TEXTFILE") - checkStats( + checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = false, expectedRowCounts = None) sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") - checkStats( + checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = false, @@ -224,12 +226,12 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils // noscan won't count the number of rows sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") - val fetchedStats1 = checkStats( + val fetchedStats1 = checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = None) // without noscan, we count the number of rows sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS") - val fetchedStats2 = checkStats( + val fetchedStats2 = checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = Some(500)) assert(fetchedStats1.get.sizeInBytes == fetchedStats2.get.sizeInBytes) } @@ -241,19 +243,19 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils sql(s"CREATE TABLE $textTable (key STRING, value STRING) STORED AS TEXTFILE") sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS") - val fetchedStats1 = checkStats( + val fetchedStats1 = checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = Some(500)) sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") // when the total size is not changed, the old row count is kept - val fetchedStats2 = checkStats( + val fetchedStats2 = checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = Some(500)) assert(fetchedStats1 == fetchedStats2) sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") sql(s"ANALYZE TABLE $textTable COMPUTE STATISTICS noscan") // update total size and remove the old and invalid row count - val fetchedStats3 = checkStats( + val fetchedStats3 = checkTableStats( textTable, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = None) assert(fetchedStats3.get.sizeInBytes > fetchedStats2.get.sizeInBytes) } @@ -271,20 +273,20 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils // the default value for `spark.sql.hive.convertMetastoreParquet` is true, here we just set it // for robustness withSQLConf("spark.sql.hive.convertMetastoreParquet" -> "true") { - checkStats( + checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") - checkStats( + checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = Some(500)) } withSQLConf("spark.sql.hive.convertMetastoreOrc" -> "true") { - checkStats( + checkTableStats( orcTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) sql(s"ANALYZE TABLE $orcTable COMPUTE STATISTICS") - checkStats( + checkTableStats( orcTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = Some(500)) } } @@ -298,23 +300,23 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils assert(DDLUtils.isDatasourceTable(catalogTable)) sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") - checkStats( + checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = false, expectedRowCounts = None) // noscan won't count the number of 
rows sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") - val fetchedStats1 = checkStats( + val fetchedStats1 = checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS noscan") - val fetchedStats2 = checkStats( + val fetchedStats2 = checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = None) assert(fetchedStats2.get.sizeInBytes > fetchedStats1.get.sizeInBytes) // without noscan, we count the number of rows sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") - val fetchedStats3 = checkStats( + val fetchedStats3 = checkTableStats( parquetTable, isDataSourceTable = true, hasSizeInBytes = true, @@ -330,7 +332,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils val dfNoCols = spark.createDataFrame(rddNoCols, StructType(Seq.empty)) dfNoCols.write.format("json").saveAsTable(table_no_cols) sql(s"ANALYZE TABLE $table_no_cols COMPUTE STATISTICS") - checkStats( + checkTableStats( table_no_cols, isDataSourceTable = true, hasSizeInBytes = true, @@ -338,6 +340,229 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils } } + private def checkColStats( + rowRDD: RDD[Row], + schema: StructType, + expectedColStatsSeq: Seq[(String, BasicColStats)]): Unit = { + val table = "tbl" + withTable(table) { + var df = spark.createDataFrame(rowRDD, schema) + df.write.format("json").saveAsTable(table) + val columns = expectedColStatsSeq.map(_._1).mkString(", ") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS $columns") + df = sql(s"SELECT * FROM $table") + val stats = df.queryExecution.analyzed.collect { + case rel: LogicalRelation => + expectedColStatsSeq.foreach { expected => + assert(rel.catalogTable.get.stats.get.basicColStats.contains(expected._1)) + checkColStats(colStats = rel.catalogTable.get.stats.get.basicColStats(expected._1), + expectedColStats = expected._2) + } + } + assert(stats.size == 1) + } + } + + private def checkColStats(colStats: BasicColStats, expectedColStats: BasicColStats): Unit = { + assert(colStats.dataType == expectedColStats.dataType) + assert(colStats.numNulls == expectedColStats.numNulls) + colStats.dataType match { + case ByteType | ShortType | IntegerType | LongType => + assert(colStats.max.map(_.toString.toLong) == expectedColStats.max.map(_.toString.toLong)) + assert(colStats.min.map(_.toString.toLong) == expectedColStats.min.map(_.toString.toLong)) + case FloatType | DoubleType => + assert(colStats.max.map(_.toString.toDouble) == expectedColStats.max + .map(_.toString.toDouble)) + assert(colStats.min.map(_.toString.toDouble) == expectedColStats.min + .map(_.toString.toDouble)) + case DecimalType.SYSTEM_DEFAULT => + assert(colStats.max.map(i => Decimal(i.toString)) == expectedColStats.max + .map(i => Decimal(i.toString))) + assert(colStats.min.map(i => Decimal(i.toString)) == expectedColStats.min + .map(i => Decimal(i.toString))) + case DateType | TimestampType => + if (expectedColStats.max.isDefined) { + // just check the difference to exclude the influence of timezones + assert(colStats.max.get.toString.toLong - colStats.min.get.toString.toLong == + expectedColStats.max.get.toString.toLong - expectedColStats.min.get.toString.toLong) + } else { + assert(colStats.max.isEmpty && colStats.min.isEmpty) + } + case _ => // only numeric types, date type and timestamp type have max and min stats + } 
+ colStats.dataType match { + case BinaryType => assert(colStats.ndv.isEmpty) + case BooleanType => assert(colStats.ndv.contains(2)) + case _ => + // ndv is an approximate value, so we just make sure we have the value + assert(colStats.ndv.get >= 0) + } + assert(colStats.avgColLen == expectedColStats.avgColLen) + assert(colStats.maxColLen == expectedColStats.maxColLen) + assert(colStats.numTrues == expectedColStats.numTrues) + assert(colStats.numFalses == expectedColStats.numFalses) + } + + test("basic statistics for integral type columns") { + val rdd = sparkContext.parallelize(Seq("1", null, "2", "3", null)).map { i => + if (i != null) Row(i.toByte, i.toShort, i.toInt, i.toLong) else Row(i, i, i, i) + } + val schema = StructType( + StructField(name = "c1", dataType = ByteType, nullable = true) :: + StructField(name = "c2", dataType = ShortType, nullable = true) :: + StructField(name = "c3", dataType = IntegerType, nullable = true) :: + StructField(name = "c4", dataType = LongType, nullable = true) :: Nil) + val expectedBasicStats = BasicColStats( + dataType = ByteType, numNulls = 2, max = Some(3), min = Some(1), ndv = Some(3)) + val statsSeq = Seq( + ("c1", expectedBasicStats), + ("c2", expectedBasicStats.copy(dataType = ShortType)), + ("c3", expectedBasicStats.copy(dataType = IntegerType)), + ("c4", expectedBasicStats.copy(dataType = LongType))) + checkColStats(rdd, schema, statsSeq) + } + + test("basic statistics for fractional type columns") { + val rdd = sparkContext.parallelize(Seq(null, "1.01", "2.02", "3.03")).map { i => + if (i != null) Row(i.toFloat, i.toDouble, Decimal(i)) else Row(i, i, i) + } + val schema = StructType( + StructField(name = "c1", dataType = FloatType, nullable = true) :: + StructField(name = "c2", dataType = DoubleType, nullable = true) :: + StructField(name = "c3", dataType = DecimalType.SYSTEM_DEFAULT, nullable = true) :: Nil) + val expectedBasicStats = BasicColStats( + dataType = FloatType, numNulls = 1, max = Some(3.03), min = Some(1.01), ndv = Some(3)) + val statsSeq = Seq( + ("c1", expectedBasicStats), + ("c2", expectedBasicStats.copy(dataType = DoubleType)), + ("c3", expectedBasicStats.copy(dataType = DecimalType.SYSTEM_DEFAULT))) + checkColStats(rdd, schema, statsSeq) + } + + test("basic statistics for string column") { + val rdd = sparkContext.parallelize(Seq(null, "a", "bbbb", "cccc")).map(Row(_)) + val schema = StructType(StructField(name = "c1", dataType = StringType, nullable = true) :: Nil) + val statsSeq = Seq(("c1", BasicColStats(dataType = StringType, numNulls = 1, + maxColLen = Some(4), avgColLen = Some(2.25), ndv = Some(3)))) + checkColStats(rdd, schema, statsSeq) + } + + test("basic statistics for binary column") { + val rdd = sparkContext.parallelize(Seq(null, "a", "bbbb", "cccc")).map { i => + if (i != null) Row(i.getBytes) else Row(i) + } + val schema = StructType(StructField(name = "c1", dataType = BinaryType, nullable = true) :: Nil) + val statsSeq = Seq(("c1", BasicColStats(dataType = BinaryType, numNulls = 1, + maxColLen = Some(4), avgColLen = Some(2.25)))) + checkColStats(rdd, schema, statsSeq) + } + + test("basic statistics for boolean column") { + val rdd = sparkContext.parallelize(Seq(null, true, false, true)).map(Row(_)) + val schema = + StructType(StructField(name = "c1", dataType = BooleanType, nullable = true) :: Nil) + val statsSeq = Seq(("c1", BasicColStats(dataType = BooleanType, numNulls = 1, + numTrues = Some(2), numFalses = Some(1)))) + checkColStats(rdd, schema, statsSeq) + } + + test("basic statistics for date 
column") { + val rdd = sparkContext.parallelize(Seq(null, "1970-01-01", "1970-02-02")).map { i => + if (i != null) Row(Date.valueOf(i)) else Row(i) + } + val schema = + StructType(StructField(name = "c1", dataType = DateType, nullable = true) :: Nil) + val statsSeq = Seq(("c1", BasicColStats(dataType = DateType, numNulls = 1, + max = Some(32), min = Some(0), ndv = Some(2)))) + checkColStats(rdd, schema, statsSeq) + } + + test("basic statistics for timestamp column") { + val rdd = sparkContext.parallelize(Seq(null, "1970-01-01 00:00:00", "1970-01-01 00:00:05")) + .map(i => if (i != null) Row(Timestamp.valueOf(i)) else Row(i)) + val schema = + StructType(StructField(name = "c1", dataType = TimestampType, nullable = true) :: Nil) + val statsSeq = Seq(("c1", BasicColStats(dataType = TimestampType, numNulls = 1, + max = Some(5000000), min = Some(0), ndv = Some(2)))) + checkColStats(rdd, schema, statsSeq) + } + + test("basic statistics for null columns") { + val rdd = sparkContext.parallelize(Seq(Row(null, null))) + val schema = StructType( + StructField(name = "c1", dataType = LongType, nullable = true) :: + StructField(name = "c2", dataType = TimestampType, nullable = true) :: Nil) + val expectedBasicStats = BasicColStats(dataType = LongType, numNulls = 1, + max = None, min = None, ndv = Some(0)) + val statsSeq = Seq( + ("c1", expectedBasicStats), + ("c2", expectedBasicStats.copy(dataType = TimestampType))) + checkColStats(rdd, schema, statsSeq) + } + + test("basic statistics for columns with different types") { + val rdd = sparkContext.parallelize(Seq( + Row(1, 1.01, "a", "a".getBytes, true, Date.valueOf("1970-01-01"), + Timestamp.valueOf("1970-01-01 00:00:00"), 5.toLong), + Row(2, 2.02, "bb", "bb".getBytes, false, Date.valueOf("1970-02-02"), + Timestamp.valueOf("1970-01-01 00:00:05"), 4.toLong))) + val schema = StructType(Seq( + StructField(name = "c1", dataType = IntegerType, nullable = false), + StructField(name = "c2", dataType = DoubleType, nullable = false), + StructField(name = "c3", dataType = StringType, nullable = false), + StructField(name = "c4", dataType = BinaryType, nullable = false), + StructField(name = "c5", dataType = BooleanType, nullable = false), + StructField(name = "c6", dataType = DateType, nullable = false), + StructField(name = "c7", dataType = TimestampType, nullable = false), + StructField(name = "c8", dataType = LongType, nullable = false))) + val statsSeq = Seq( + ("c1", BasicColStats(dataType = IntegerType, numNulls = 0, max = Some(2), min = Some(1), + ndv = Some(2))), + ("c2", BasicColStats(dataType = DoubleType, numNulls = 0, max = Some(2.02), min = Some(1.01), + ndv = Some(2))), + ("c3", BasicColStats(dataType = StringType, numNulls = 0, maxColLen = Some(2), + avgColLen = Some(1.5), ndv = Some(2))), + ("c4", BasicColStats(dataType = BinaryType, numNulls = 0, maxColLen = Some(2), + avgColLen = Some(1.5))), + ("c5", BasicColStats(dataType = BooleanType, numNulls = 0, numTrues = Some(1), + numFalses = Some(1), ndv = Some(2))), + ("c6", BasicColStats(dataType = DateType, numNulls = 0, max = Some(32), min = Some(0), + ndv = Some(2))), + ("c7", BasicColStats(dataType = TimestampType, numNulls = 0, max = Some(5000000), + min = Some(0), ndv = Some(2))), + ("c8", BasicColStats(dataType = LongType, numNulls = 0, max = Some(5), min = Some(4), + ndv = Some(2)))) + checkColStats(rdd, schema, statsSeq) + } + + test("update table-level stats while collecting column-level stats") { + val table = "tbl" + val tmpTable = "tmp" + withTable(table, tmpTable) { + val rdd = 
sparkContext.parallelize(Seq(Row(1))) + val df = spark.createDataFrame(rdd, StructType(Seq( + StructField(name = "c1", dataType = IntegerType, nullable = false)))) + df.write.format("json").saveAsTable(tmpTable) + + sql(s"CREATE TABLE $table (c1 int)") + sql(s"INSERT INTO $table SELECT * FROM $tmpTable") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") + val fetchedStats1 = checkTableStats(tableName = table, isDataSourceTable = false, + hasSizeInBytes = true, expectedRowCounts = Some(1)) + + // update table between analyze table and analyze column commands + sql(s"INSERT INTO $table SELECT * FROM $tmpTable") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") + val fetchedStats2 = checkTableStats(tableName = table, isDataSourceTable = false, + hasSizeInBytes = true, expectedRowCounts = Some(2)) + + assert(fetchedStats2.get.sizeInBytes > fetchedStats1.get.sizeInBytes) + val basicColStats = fetchedStats2.get.basicColStats("c1") + checkColStats(colStats = basicColStats, expectedColStats = BasicColStats( + dataType = IntegerType, numNulls = 0, max = Some(1), min = Some(1), ndv = Some(1))) + } + } + test("estimates the size of a test MetastoreRelation") { val df = sql("""SELECT * FROM src""") val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation => From 18f7bfd474415e84899193aa2f7a46d5becef1d6 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Tue, 13 Sep 2016 21:02:10 -0700 Subject: [PATCH 02/22] fix style --- .../spark/sql/execution/command/AnalyzeTableCommand.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 0d59c6a8846a..f7996ee318d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.execution.command import scala.util.control.NonFatal + import org.apache.hadoop.fs.{FileSystem, Path} + import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier @@ -135,4 +137,4 @@ object AnalyzeTableCommand extends Logging { } }.getOrElse(0L) } -} \ No newline at end of file +} From 230f1d3c6c3b69a1b88896eeca21a8fb227545c1 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Wed, 14 Sep 2016 15:44:50 -0700 Subject: [PATCH 03/22] 1. add test cases for parsing command and columns, 2. 
refactor test suites --- .../spark/sql/execution/SparkSqlParser.scala | 10 +- .../command/AnalyzeColumnCommand.scala | 28 +- .../command/AnalyzeTableCommand.scala | 2 +- .../sql/hive/StatisticsColumnSuite.scala | 228 ++++++++++++++ .../spark/sql/hive/StatisticsSuite.scala | 285 +----------------- .../spark/sql/hive/StatisticsTest.scala | 145 +++++++++ 6 files changed, 398 insertions(+), 300 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 3d60af2a5d76..9aba2742187f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -87,11 +87,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } /** - * Create an [[AnalyzeTableCommand]] command. This currently only implements the NOSCAN - * option (other options are passed on to Hive) e.g.: - * {{{ - * ANALYZE TABLE table COMPUTE STATISTICS NOSCAN; - * }}} + * Create an [[AnalyzeTableCommand]] command or an [[AnalyzeColumnCommand]] command. */ override def visitAnalyze(ctx: AnalyzeContext): LogicalPlan = withOrigin(ctx) { if (ctx.partitionSpec == null && @@ -99,6 +95,10 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { ctx.identifier.getText.toLowerCase == "noscan") { AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier).toString) } else if (ctx.identifierSeq() == null) { + if (ctx.FOR() != null || ctx.COLUMNS() != null) { + throw new ParseException("Need to specify the columns to analyze. 
Usage: " + + "ANALYZE TABLE tbl COMPUTE STATISTICS FOR COLUMNS key, value", ctx) + } AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier).toString, noscan = false) } else { AnalyzeColumnCommand( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 9d3c51be8c84..72f7d8947eda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} +import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.apache.spark.sql.catalyst.plans.logical.{BasicColStats, Statistics} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.functions._ @@ -42,11 +43,18 @@ case class AnalyzeColumnCommand( val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdent)) - // check correctness for column names - val attributeNames = relation.output.map(_.name.toLowerCase) - val invalidColumns = columnNames.filterNot { col => attributeNames.contains(col.toLowerCase)} - if (invalidColumns.nonEmpty) { - throw new AnalysisException(s"Invalid columns for table $tableName: $invalidColumns.") + // check correctness of column names + val validColumns = mutable.HashSet[NamedExpression]() + val resolver = sparkSession.sessionState.conf.resolver + columnNames.foreach { col => + val exprOption = relation.resolve(col.split("\\."), resolver) + if (exprOption.isEmpty) { + throw new AnalysisException(s"Invalid column name: $col") + } + if (validColumns.map(_.exprId).contains(exprOption.get.exprId)) { + throw new AnalysisException(s"Duplicate column name: $col") + } + validColumns += exprOption.get } relation match { @@ -58,18 +66,14 @@ case class AnalyzeColumnCommand( updateStats(logicalRel.catalogTable.get, logicalRel.relation.sizeInBytes) case otherRelation => - throw new AnalysisException(s"ANALYZE TABLE is not supported for " + + throw new AnalysisException("ANALYZE TABLE is not supported for " + s"${otherRelation.nodeName}.") } def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = { - val lowerCaseNames = columnNames.map(_.toLowerCase) - val attributes = - relation.output.filter(attr => lowerCaseNames.contains(attr.name.toLowerCase)) - // collect column statistics val aggColumns = mutable.ArrayBuffer[Column](count(Column("*"))) - attributes.foreach(entry => aggColumns ++= statsAgg(entry.name, entry.dataType)) + validColumns.foreach(entry => aggColumns ++= statsAgg(entry.name, entry.dataType)) val statsRow: InternalRow = Dataset.ofRows(sparkSession, relation).select(aggColumns: _*) .queryExecution.toRdd.collect().head @@ -84,7 +88,7 @@ case class AnalyzeColumnCommand( var pos = 1 val colStats = mutable.HashMap[String, BasicColStats]() - attributes.foreach { attr => + validColumns.foreach { attr => attr.dataType match { case n: NumericType => colStats += attr.name -> BasicColStats( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index f7996ee318d0..8c82d10d66c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -53,7 +53,7 @@ case class AnalyzeTableCommand(tableName: String, noscan: Boolean = true) extend updateTableStats(logicalRel.catalogTable.get, logicalRel.relation.sizeInBytes) case otherRelation => - throw new AnalysisException(s"ANALYZE TABLE is not supported for " + + throw new AnalysisException("ANALYZE TABLE is not supported for " + s"${otherRelation.nodeName}.") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala new file mode 100644 index 000000000000..ff9ccc2ece4f --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.catalyst.plans.logical.BasicColStats +import org.apache.spark.sql.execution.command.AnalyzeColumnCommand +import org.apache.spark.sql.types._ + +class StatisticsColumnSuite extends StatisticsTest { + + test("parse analyze column commands") { + val table = "table" + assertAnalyzeCommand( + s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key, value", + classOf[AnalyzeColumnCommand]) + + val noColumnError = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS") + } + assert(noColumnError.message == "Need to specify the columns to analyze. 
Usage: " + + "ANALYZE TABLE tbl COMPUTE STATISTICS FOR COLUMNS key, value") + + withTable(table) { + sql(s"CREATE TABLE $table (key INT, value STRING)") + val invalidColError = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS k") + } + assert(invalidColError.message == s"Invalid column name: k") + + val duplicateColError = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key, value, key") + } + assert(duplicateColError.message == s"Duplicate column name: key") + + withSQLConf("spark.sql.caseSensitive" -> "true") { + val invalidErr = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS keY") + } + assert(invalidErr.message == s"Invalid column name: keY") + } + + withSQLConf("spark.sql.caseSensitive" -> "false") { + val duplicateErr = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key, value, vaLue") + } + assert(duplicateErr.message == s"Duplicate column name: vaLue") + } + } + } + + test("basic statistics for integral type columns") { + val rdd = sparkContext.parallelize(Seq("1", null, "2", "3", null)).map { i => + if (i != null) Row(i.toByte, i.toShort, i.toInt, i.toLong) else Row(i, i, i, i) + } + val schema = StructType( + StructField(name = "c1", dataType = ByteType, nullable = true) :: + StructField(name = "c2", dataType = ShortType, nullable = true) :: + StructField(name = "c3", dataType = IntegerType, nullable = true) :: + StructField(name = "c4", dataType = LongType, nullable = true) :: Nil) + val expectedBasicStats = BasicColStats( + dataType = ByteType, numNulls = 2, max = Some(3), min = Some(1), ndv = Some(3)) + val statsSeq = Seq( + ("c1", expectedBasicStats), + ("c2", expectedBasicStats.copy(dataType = ShortType)), + ("c3", expectedBasicStats.copy(dataType = IntegerType)), + ("c4", expectedBasicStats.copy(dataType = LongType))) + checkColStats(rdd, schema, statsSeq) + } + + test("basic statistics for fractional type columns") { + val rdd = sparkContext.parallelize(Seq(null, "1.01", "2.02", "3.03")).map { i => + if (i != null) Row(i.toFloat, i.toDouble, Decimal(i)) else Row(i, i, i) + } + val schema = StructType( + StructField(name = "c1", dataType = FloatType, nullable = true) :: + StructField(name = "c2", dataType = DoubleType, nullable = true) :: + StructField(name = "c3", dataType = DecimalType.SYSTEM_DEFAULT, nullable = true) :: Nil) + val expectedBasicStats = BasicColStats( + dataType = FloatType, numNulls = 1, max = Some(3.03), min = Some(1.01), ndv = Some(3)) + val statsSeq = Seq( + ("c1", expectedBasicStats), + ("c2", expectedBasicStats.copy(dataType = DoubleType)), + ("c3", expectedBasicStats.copy(dataType = DecimalType.SYSTEM_DEFAULT))) + checkColStats(rdd, schema, statsSeq) + } + + test("basic statistics for string column") { + val rdd = sparkContext.parallelize(Seq(null, "a", "bbbb", "cccc")).map(Row(_)) + val schema = StructType(StructField(name = "c1", dataType = StringType, nullable = true) :: Nil) + val statsSeq = Seq(("c1", BasicColStats(dataType = StringType, numNulls = 1, + maxColLen = Some(4), avgColLen = Some(2.25), ndv = Some(3)))) + checkColStats(rdd, schema, statsSeq) + } + + test("basic statistics for binary column") { + val rdd = sparkContext.parallelize(Seq(null, "a", "bbbb", "cccc")).map { i => + if (i != null) Row(i.getBytes) else Row(i) + } + val schema = StructType(StructField(name = "c1", dataType = BinaryType, nullable = true) :: Nil) + val statsSeq = Seq(("c1", 
BasicColStats(dataType = BinaryType, numNulls = 1, + maxColLen = Some(4), avgColLen = Some(2.25)))) + checkColStats(rdd, schema, statsSeq) + } + + test("basic statistics for boolean column") { + val rdd = sparkContext.parallelize(Seq(null, true, false, true)).map(Row(_)) + val schema = + StructType(StructField(name = "c1", dataType = BooleanType, nullable = true) :: Nil) + val statsSeq = Seq(("c1", BasicColStats(dataType = BooleanType, numNulls = 1, + numTrues = Some(2), numFalses = Some(1)))) + checkColStats(rdd, schema, statsSeq) + } + + test("basic statistics for date column") { + val rdd = sparkContext.parallelize(Seq(null, "1970-01-01", "1970-02-02")).map { i => + if (i != null) Row(Date.valueOf(i)) else Row(i) + } + val schema = + StructType(StructField(name = "c1", dataType = DateType, nullable = true) :: Nil) + val statsSeq = Seq(("c1", BasicColStats(dataType = DateType, numNulls = 1, + max = Some(32), min = Some(0), ndv = Some(2)))) + checkColStats(rdd, schema, statsSeq) + } + + test("basic statistics for timestamp column") { + val rdd = sparkContext.parallelize(Seq(null, "1970-01-01 00:00:00", "1970-01-01 00:00:05")) + .map(i => if (i != null) Row(Timestamp.valueOf(i)) else Row(i)) + val schema = + StructType(StructField(name = "c1", dataType = TimestampType, nullable = true) :: Nil) + val statsSeq = Seq(("c1", BasicColStats(dataType = TimestampType, numNulls = 1, + max = Some(5000000), min = Some(0), ndv = Some(2)))) + checkColStats(rdd, schema, statsSeq) + } + + test("basic statistics for null columns") { + val rdd = sparkContext.parallelize(Seq(Row(null, null))) + val schema = StructType( + StructField(name = "c1", dataType = LongType, nullable = true) :: + StructField(name = "c2", dataType = TimestampType, nullable = true) :: Nil) + val expectedBasicStats = BasicColStats(dataType = LongType, numNulls = 1, + max = None, min = None, ndv = Some(0)) + val statsSeq = Seq( + ("c1", expectedBasicStats), + ("c2", expectedBasicStats.copy(dataType = TimestampType))) + checkColStats(rdd, schema, statsSeq) + } + + test("basic statistics for columns with different types") { + val rdd = sparkContext.parallelize(Seq( + Row(1, 1.01, "a", "a".getBytes, true, Date.valueOf("1970-01-01"), + Timestamp.valueOf("1970-01-01 00:00:00"), 5.toLong), + Row(2, 2.02, "bb", "bb".getBytes, false, Date.valueOf("1970-02-02"), + Timestamp.valueOf("1970-01-01 00:00:05"), 4.toLong))) + val schema = StructType(Seq( + StructField(name = "c1", dataType = IntegerType, nullable = false), + StructField(name = "c2", dataType = DoubleType, nullable = false), + StructField(name = "c3", dataType = StringType, nullable = false), + StructField(name = "c4", dataType = BinaryType, nullable = false), + StructField(name = "c5", dataType = BooleanType, nullable = false), + StructField(name = "c6", dataType = DateType, nullable = false), + StructField(name = "c7", dataType = TimestampType, nullable = false), + StructField(name = "c8", dataType = LongType, nullable = false))) + val statsSeq = Seq( + ("c1", BasicColStats(dataType = IntegerType, numNulls = 0, max = Some(2), min = Some(1), + ndv = Some(2))), + ("c2", BasicColStats(dataType = DoubleType, numNulls = 0, max = Some(2.02), min = Some(1.01), + ndv = Some(2))), + ("c3", BasicColStats(dataType = StringType, numNulls = 0, maxColLen = Some(2), + avgColLen = Some(1.5), ndv = Some(2))), + ("c4", BasicColStats(dataType = BinaryType, numNulls = 0, maxColLen = Some(2), + avgColLen = Some(1.5))), + ("c5", BasicColStats(dataType = BooleanType, numNulls = 0, numTrues = Some(1), + 
numFalses = Some(1), ndv = Some(2))), + ("c6", BasicColStats(dataType = DateType, numNulls = 0, max = Some(32), min = Some(0), + ndv = Some(2))), + ("c7", BasicColStats(dataType = TimestampType, numNulls = 0, max = Some(5000000), + min = Some(0), ndv = Some(2))), + ("c8", BasicColStats(dataType = LongType, numNulls = 0, max = Some(5), min = Some(4), + ndv = Some(2)))) + checkColStats(rdd, schema, statsSeq) + } + + test("update table-level stats while collecting column-level stats") { + val table = "tbl" + val tmpTable = "tmp" + withTable(table, tmpTable) { + val rdd = sparkContext.parallelize(Seq(Row(1))) + val df = spark.createDataFrame(rdd, StructType(Seq( + StructField(name = "c1", dataType = IntegerType, nullable = false)))) + df.write.format("json").saveAsTable(tmpTable) + + sql(s"CREATE TABLE $table (c1 int)") + sql(s"INSERT INTO $table SELECT * FROM $tmpTable") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") + val fetchedStats1 = checkTableStats(tableName = table, isDataSourceTable = false, + hasSizeInBytes = true, expectedRowCounts = Some(1)) + + // update table between analyze table and analyze column commands + sql(s"INSERT INTO $table SELECT * FROM $tmpTable") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") + val fetchedStats2 = checkTableStats(tableName = table, isDataSourceTable = false, + hasSizeInBytes = true, expectedRowCounts = Some(2)) + assert(fetchedStats2.get.sizeInBytes > fetchedStats1.get.sizeInBytes) + + val basicColStats = fetchedStats2.get.basicColStats("c1") + checkColStats(colStats = basicColStats, expectedColStats = BasicColStats( + dataType = IntegerType, numNulls = 0, max = Some(1), min = Some(1), ndv = Some(1))) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 9378af6f6413..f83d797b3ccb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -18,42 +18,19 @@ package org.apache.spark.sql.hive import java.io.{File, PrintWriter} -import java.sql.{Date, Timestamp} import scala.reflect.ClassTag -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.plans.logical.{BasicColStats, Statistics} import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils} -import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.{StructType, _} +import org.apache.spark.sql.types.StructType -class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { +class StatisticsSuite extends StatisticsTest { test("parse analyze commands") { - def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { - val parsed = spark.sessionState.sqlParser.parsePlan(analyzeCommand) - val operators = parsed.collect { - case a: AnalyzeTableCommand => a - case o => o - } - - assert(operators.size === 1) - if (operators(0).getClass() != c) { - fail( - s"""$analyzeCommand expected command: $c, but got ${operators(0)} - |parsed command: - |$parsed - """.stripMargin) - } - } - assertAnalyzeCommand( 
"ANALYZE TABLE Table1 COMPUTE STATISTICS", classOf[AnalyzeTableCommand]) @@ -173,39 +150,6 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils TableIdentifier("tempTable"), ignoreIfNotExists = true, purge = false) } - private def checkTableStats( - stats: Option[Statistics], - hasSizeInBytes: Boolean, - expectedRowCounts: Option[Int]): Unit = { - if (hasSizeInBytes || expectedRowCounts.nonEmpty) { - assert(stats.isDefined) - assert(stats.get.sizeInBytes > 0) - assert(stats.get.rowCount === expectedRowCounts) - } else { - assert(stats.isEmpty) - } - } - - private def checkTableStats( - tableName: String, - isDataSourceTable: Boolean, - hasSizeInBytes: Boolean, - expectedRowCounts: Option[Int]): Option[Statistics] = { - val df = sql(s"SELECT * FROM $tableName") - val stats = df.queryExecution.analyzed.collect { - case rel: MetastoreRelation => - checkTableStats(rel.catalogTable.stats, hasSizeInBytes, expectedRowCounts) - assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table") - rel.catalogTable.stats - case rel: LogicalRelation => - checkTableStats(rel.catalogTable.get.stats, hasSizeInBytes, expectedRowCounts) - assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table") - rel.catalogTable.get.stats - } - assert(stats.size == 1) - stats.head - } - test("test table-level statistics for hive tables created in HiveExternalCatalog") { val textTable = "textTable" withTable(textTable) { @@ -340,229 +284,6 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils } } - private def checkColStats( - rowRDD: RDD[Row], - schema: StructType, - expectedColStatsSeq: Seq[(String, BasicColStats)]): Unit = { - val table = "tbl" - withTable(table) { - var df = spark.createDataFrame(rowRDD, schema) - df.write.format("json").saveAsTable(table) - val columns = expectedColStatsSeq.map(_._1).mkString(", ") - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS $columns") - df = sql(s"SELECT * FROM $table") - val stats = df.queryExecution.analyzed.collect { - case rel: LogicalRelation => - expectedColStatsSeq.foreach { expected => - assert(rel.catalogTable.get.stats.get.basicColStats.contains(expected._1)) - checkColStats(colStats = rel.catalogTable.get.stats.get.basicColStats(expected._1), - expectedColStats = expected._2) - } - } - assert(stats.size == 1) - } - } - - private def checkColStats(colStats: BasicColStats, expectedColStats: BasicColStats): Unit = { - assert(colStats.dataType == expectedColStats.dataType) - assert(colStats.numNulls == expectedColStats.numNulls) - colStats.dataType match { - case ByteType | ShortType | IntegerType | LongType => - assert(colStats.max.map(_.toString.toLong) == expectedColStats.max.map(_.toString.toLong)) - assert(colStats.min.map(_.toString.toLong) == expectedColStats.min.map(_.toString.toLong)) - case FloatType | DoubleType => - assert(colStats.max.map(_.toString.toDouble) == expectedColStats.max - .map(_.toString.toDouble)) - assert(colStats.min.map(_.toString.toDouble) == expectedColStats.min - .map(_.toString.toDouble)) - case DecimalType.SYSTEM_DEFAULT => - assert(colStats.max.map(i => Decimal(i.toString)) == expectedColStats.max - .map(i => Decimal(i.toString))) - assert(colStats.min.map(i => Decimal(i.toString)) == expectedColStats.min - .map(i => Decimal(i.toString))) - case DateType | TimestampType => - if (expectedColStats.max.isDefined) { - // just check the difference to exclude the influence of timezones - 
assert(colStats.max.get.toString.toLong - colStats.min.get.toString.toLong == - expectedColStats.max.get.toString.toLong - expectedColStats.min.get.toString.toLong) - } else { - assert(colStats.max.isEmpty && colStats.min.isEmpty) - } - case _ => // only numeric types, date type and timestamp type have max and min stats - } - colStats.dataType match { - case BinaryType => assert(colStats.ndv.isEmpty) - case BooleanType => assert(colStats.ndv.contains(2)) - case _ => - // ndv is an approximate value, so we just make sure we have the value - assert(colStats.ndv.get >= 0) - } - assert(colStats.avgColLen == expectedColStats.avgColLen) - assert(colStats.maxColLen == expectedColStats.maxColLen) - assert(colStats.numTrues == expectedColStats.numTrues) - assert(colStats.numFalses == expectedColStats.numFalses) - } - - test("basic statistics for integral type columns") { - val rdd = sparkContext.parallelize(Seq("1", null, "2", "3", null)).map { i => - if (i != null) Row(i.toByte, i.toShort, i.toInt, i.toLong) else Row(i, i, i, i) - } - val schema = StructType( - StructField(name = "c1", dataType = ByteType, nullable = true) :: - StructField(name = "c2", dataType = ShortType, nullable = true) :: - StructField(name = "c3", dataType = IntegerType, nullable = true) :: - StructField(name = "c4", dataType = LongType, nullable = true) :: Nil) - val expectedBasicStats = BasicColStats( - dataType = ByteType, numNulls = 2, max = Some(3), min = Some(1), ndv = Some(3)) - val statsSeq = Seq( - ("c1", expectedBasicStats), - ("c2", expectedBasicStats.copy(dataType = ShortType)), - ("c3", expectedBasicStats.copy(dataType = IntegerType)), - ("c4", expectedBasicStats.copy(dataType = LongType))) - checkColStats(rdd, schema, statsSeq) - } - - test("basic statistics for fractional type columns") { - val rdd = sparkContext.parallelize(Seq(null, "1.01", "2.02", "3.03")).map { i => - if (i != null) Row(i.toFloat, i.toDouble, Decimal(i)) else Row(i, i, i) - } - val schema = StructType( - StructField(name = "c1", dataType = FloatType, nullable = true) :: - StructField(name = "c2", dataType = DoubleType, nullable = true) :: - StructField(name = "c3", dataType = DecimalType.SYSTEM_DEFAULT, nullable = true) :: Nil) - val expectedBasicStats = BasicColStats( - dataType = FloatType, numNulls = 1, max = Some(3.03), min = Some(1.01), ndv = Some(3)) - val statsSeq = Seq( - ("c1", expectedBasicStats), - ("c2", expectedBasicStats.copy(dataType = DoubleType)), - ("c3", expectedBasicStats.copy(dataType = DecimalType.SYSTEM_DEFAULT))) - checkColStats(rdd, schema, statsSeq) - } - - test("basic statistics for string column") { - val rdd = sparkContext.parallelize(Seq(null, "a", "bbbb", "cccc")).map(Row(_)) - val schema = StructType(StructField(name = "c1", dataType = StringType, nullable = true) :: Nil) - val statsSeq = Seq(("c1", BasicColStats(dataType = StringType, numNulls = 1, - maxColLen = Some(4), avgColLen = Some(2.25), ndv = Some(3)))) - checkColStats(rdd, schema, statsSeq) - } - - test("basic statistics for binary column") { - val rdd = sparkContext.parallelize(Seq(null, "a", "bbbb", "cccc")).map { i => - if (i != null) Row(i.getBytes) else Row(i) - } - val schema = StructType(StructField(name = "c1", dataType = BinaryType, nullable = true) :: Nil) - val statsSeq = Seq(("c1", BasicColStats(dataType = BinaryType, numNulls = 1, - maxColLen = Some(4), avgColLen = Some(2.25)))) - checkColStats(rdd, schema, statsSeq) - } - - test("basic statistics for boolean column") { - val rdd = sparkContext.parallelize(Seq(null, true, false, 
true)).map(Row(_)) - val schema = - StructType(StructField(name = "c1", dataType = BooleanType, nullable = true) :: Nil) - val statsSeq = Seq(("c1", BasicColStats(dataType = BooleanType, numNulls = 1, - numTrues = Some(2), numFalses = Some(1)))) - checkColStats(rdd, schema, statsSeq) - } - - test("basic statistics for date column") { - val rdd = sparkContext.parallelize(Seq(null, "1970-01-01", "1970-02-02")).map { i => - if (i != null) Row(Date.valueOf(i)) else Row(i) - } - val schema = - StructType(StructField(name = "c1", dataType = DateType, nullable = true) :: Nil) - val statsSeq = Seq(("c1", BasicColStats(dataType = DateType, numNulls = 1, - max = Some(32), min = Some(0), ndv = Some(2)))) - checkColStats(rdd, schema, statsSeq) - } - - test("basic statistics for timestamp column") { - val rdd = sparkContext.parallelize(Seq(null, "1970-01-01 00:00:00", "1970-01-01 00:00:05")) - .map(i => if (i != null) Row(Timestamp.valueOf(i)) else Row(i)) - val schema = - StructType(StructField(name = "c1", dataType = TimestampType, nullable = true) :: Nil) - val statsSeq = Seq(("c1", BasicColStats(dataType = TimestampType, numNulls = 1, - max = Some(5000000), min = Some(0), ndv = Some(2)))) - checkColStats(rdd, schema, statsSeq) - } - - test("basic statistics for null columns") { - val rdd = sparkContext.parallelize(Seq(Row(null, null))) - val schema = StructType( - StructField(name = "c1", dataType = LongType, nullable = true) :: - StructField(name = "c2", dataType = TimestampType, nullable = true) :: Nil) - val expectedBasicStats = BasicColStats(dataType = LongType, numNulls = 1, - max = None, min = None, ndv = Some(0)) - val statsSeq = Seq( - ("c1", expectedBasicStats), - ("c2", expectedBasicStats.copy(dataType = TimestampType))) - checkColStats(rdd, schema, statsSeq) - } - - test("basic statistics for columns with different types") { - val rdd = sparkContext.parallelize(Seq( - Row(1, 1.01, "a", "a".getBytes, true, Date.valueOf("1970-01-01"), - Timestamp.valueOf("1970-01-01 00:00:00"), 5.toLong), - Row(2, 2.02, "bb", "bb".getBytes, false, Date.valueOf("1970-02-02"), - Timestamp.valueOf("1970-01-01 00:00:05"), 4.toLong))) - val schema = StructType(Seq( - StructField(name = "c1", dataType = IntegerType, nullable = false), - StructField(name = "c2", dataType = DoubleType, nullable = false), - StructField(name = "c3", dataType = StringType, nullable = false), - StructField(name = "c4", dataType = BinaryType, nullable = false), - StructField(name = "c5", dataType = BooleanType, nullable = false), - StructField(name = "c6", dataType = DateType, nullable = false), - StructField(name = "c7", dataType = TimestampType, nullable = false), - StructField(name = "c8", dataType = LongType, nullable = false))) - val statsSeq = Seq( - ("c1", BasicColStats(dataType = IntegerType, numNulls = 0, max = Some(2), min = Some(1), - ndv = Some(2))), - ("c2", BasicColStats(dataType = DoubleType, numNulls = 0, max = Some(2.02), min = Some(1.01), - ndv = Some(2))), - ("c3", BasicColStats(dataType = StringType, numNulls = 0, maxColLen = Some(2), - avgColLen = Some(1.5), ndv = Some(2))), - ("c4", BasicColStats(dataType = BinaryType, numNulls = 0, maxColLen = Some(2), - avgColLen = Some(1.5))), - ("c5", BasicColStats(dataType = BooleanType, numNulls = 0, numTrues = Some(1), - numFalses = Some(1), ndv = Some(2))), - ("c6", BasicColStats(dataType = DateType, numNulls = 0, max = Some(32), min = Some(0), - ndv = Some(2))), - ("c7", BasicColStats(dataType = TimestampType, numNulls = 0, max = Some(5000000), - min = Some(0), ndv = 
Some(2))), - ("c8", BasicColStats(dataType = LongType, numNulls = 0, max = Some(5), min = Some(4), - ndv = Some(2)))) - checkColStats(rdd, schema, statsSeq) - } - - test("update table-level stats while collecting column-level stats") { - val table = "tbl" - val tmpTable = "tmp" - withTable(table, tmpTable) { - val rdd = sparkContext.parallelize(Seq(Row(1))) - val df = spark.createDataFrame(rdd, StructType(Seq( - StructField(name = "c1", dataType = IntegerType, nullable = false)))) - df.write.format("json").saveAsTable(tmpTable) - - sql(s"CREATE TABLE $table (c1 int)") - sql(s"INSERT INTO $table SELECT * FROM $tmpTable") - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") - val fetchedStats1 = checkTableStats(tableName = table, isDataSourceTable = false, - hasSizeInBytes = true, expectedRowCounts = Some(1)) - - // update table between analyze table and analyze column commands - sql(s"INSERT INTO $table SELECT * FROM $tmpTable") - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") - val fetchedStats2 = checkTableStats(tableName = table, isDataSourceTable = false, - hasSizeInBytes = true, expectedRowCounts = Some(2)) - - assert(fetchedStats2.get.sizeInBytes > fetchedStats1.get.sizeInBytes) - val basicColStats = fetchedStats2.get.basicColStats("c1") - checkColStats(colStats = basicColStats, expectedColStats = BasicColStats( - dataType = IntegerType, numNulls = 0, max = Some(1), min = Some(1), ndv = Some(1))) - } - } - test("estimates the size of a test MetastoreRelation") { val df = sql("""SELECT * FROM src""") val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala new file mode 100644 index 000000000000..2f7efa850de6 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.plans.logical.{BasicColStats, Statistics} +import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, AnalyzeTableCommand} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ + +trait StatisticsTest extends QueryTest with TestHiveSingleton with SQLTestUtils { + + def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { + val parsed = spark.sessionState.sqlParser.parsePlan(analyzeCommand) + val operators = parsed.collect { + case a: AnalyzeTableCommand => a + case b: AnalyzeColumnCommand => b + case o => o + } + + assert(operators.size === 1) + if (operators(0).getClass() != c) { + fail( + s"""$analyzeCommand expected command: $c, but got ${operators(0)} + |parsed command: + |$parsed + """.stripMargin) + } + } + + def checkTableStats( + stats: Option[Statistics], + hasSizeInBytes: Boolean, + expectedRowCounts: Option[Int]): Unit = { + if (hasSizeInBytes || expectedRowCounts.nonEmpty) { + assert(stats.isDefined) + assert(stats.get.sizeInBytes > 0) + assert(stats.get.rowCount === expectedRowCounts) + } else { + assert(stats.isEmpty) + } + } + + def checkTableStats( + tableName: String, + isDataSourceTable: Boolean, + hasSizeInBytes: Boolean, + expectedRowCounts: Option[Int]): Option[Statistics] = { + val df = sql(s"SELECT * FROM $tableName") + val stats = df.queryExecution.analyzed.collect { + case rel: MetastoreRelation => + checkTableStats(rel.catalogTable.stats, hasSizeInBytes, expectedRowCounts) + assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table") + rel.catalogTable.stats + case rel: LogicalRelation => + checkTableStats(rel.catalogTable.get.stats, hasSizeInBytes, expectedRowCounts) + assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table") + rel.catalogTable.get.stats + } + assert(stats.size == 1) + stats.head + } + + def checkColStats( + rowRDD: RDD[Row], + schema: StructType, + expectedColStatsSeq: Seq[(String, BasicColStats)]): Unit = { + val table = "tbl" + withTable(table) { + var df = spark.createDataFrame(rowRDD, schema) + df.write.format("json").saveAsTable(table) + val columns = expectedColStatsSeq.map(_._1).mkString(", ") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS $columns") + df = sql(s"SELECT * FROM $table") + val stats = df.queryExecution.analyzed.collect { + case rel: LogicalRelation => + expectedColStatsSeq.foreach { expected => + assert(rel.catalogTable.get.stats.get.basicColStats.contains(expected._1)) + checkColStats(colStats = rel.catalogTable.get.stats.get.basicColStats(expected._1), + expectedColStats = expected._2) + } + } + assert(stats.size == 1) + } + } + + def checkColStats(colStats: BasicColStats, expectedColStats: BasicColStats): Unit = { + assert(colStats.dataType == expectedColStats.dataType) + assert(colStats.numNulls == expectedColStats.numNulls) + colStats.dataType match { + case ByteType | ShortType | IntegerType | LongType => + assert(colStats.max.map(_.toString.toLong) == expectedColStats.max.map(_.toString.toLong)) + assert(colStats.min.map(_.toString.toLong) == expectedColStats.min.map(_.toString.toLong)) + case FloatType | DoubleType => + assert(colStats.max.map(_.toString.toDouble) == expectedColStats.max + .map(_.toString.toDouble)) + 
assert(colStats.min.map(_.toString.toDouble) == expectedColStats.min + .map(_.toString.toDouble)) + case DecimalType.SYSTEM_DEFAULT => + assert(colStats.max.map(i => Decimal(i.toString)) == expectedColStats.max + .map(i => Decimal(i.toString))) + assert(colStats.min.map(i => Decimal(i.toString)) == expectedColStats.min + .map(i => Decimal(i.toString))) + case DateType | TimestampType => + if (expectedColStats.max.isDefined) { + // just check the difference to exclude the influence of timezones + assert(colStats.max.get.toString.toLong - colStats.min.get.toString.toLong == + expectedColStats.max.get.toString.toLong - expectedColStats.min.get.toString.toLong) + } else { + assert(colStats.max.isEmpty && colStats.min.isEmpty) + } + case _ => // only numeric types, date type and timestamp type have max and min stats + } + colStats.dataType match { + case BinaryType => assert(colStats.ndv.isEmpty) + case BooleanType => assert(colStats.ndv.contains(2)) + case _ => + // ndv is an approximate value, so we just make sure we have the value + assert(colStats.ndv.get >= 0) + } + assert(colStats.avgColLen == expectedColStats.avgColLen) + assert(colStats.maxColLen == expectedColStats.maxColLen) + assert(colStats.numTrues == expectedColStats.numTrues) + assert(colStats.numFalses == expectedColStats.numFalses) + } + +} From 924a41d755092890b18e1cfcbc14cfe1e3c11e19 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Thu, 15 Sep 2016 21:48:51 -0700 Subject: [PATCH 04/22] refactor column stats collection --- .../spark/sql/execution/SparkSqlParser.scala | 6 +- .../command/AnalyzeColumnCommand.scala | 212 +++++++----------- .../command/AnalyzeTableCommand.scala | 11 +- .../apache/spark/sql/internal/SQLConf.scala | 9 + .../spark/sql/internal/SessionState.scala | 20 +- .../sql/hive/StatisticsColumnSuite.scala | 10 +- 6 files changed, 115 insertions(+), 153 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 9aba2742187f..22e507e9e3d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -93,16 +93,16 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (ctx.partitionSpec == null && ctx.identifier != null && ctx.identifier.getText.toLowerCase == "noscan") { - AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier).toString) + AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier)) } else if (ctx.identifierSeq() == null) { if (ctx.FOR() != null || ctx.COLUMNS() != null) { throw new ParseException("Need to specify the columns to analyze. 
Usage: " + "ANALYZE TABLE tbl COMPUTE STATISTICS FOR COLUMNS key, value", ctx) } - AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier).toString, noscan = false) + AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier), noscan = false) } else { AnalyzeColumnCommand( - visitTableIdentifier(ctx.tableIdentifier).toString, + visitTableIdentifier(ctx.tableIdentifier), visitIdentifierSeq(ctx.identifierSeq())) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 72f7d8947eda..7ef05d7fa534 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.command import scala.collection.mutable import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} -import org.apache.spark.sql.catalyst.expressions.NamedExpression -import org.apache.spark.sql.catalyst.plans.logical.{BasicColStats, Statistics} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, BasicColStats, Statistics} import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -35,17 +35,16 @@ import org.apache.spark.sql.types._ * which will be used in query optimizations. */ case class AnalyzeColumnCommand( - tableName: String, + tableIdent: TableIdentifier, columnNames: Seq[String]) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val sessionState = sparkSession.sessionState - val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdent)) // check correctness of column names - val validColumns = mutable.HashSet[NamedExpression]() - val resolver = sparkSession.sessionState.conf.resolver + val validColumns = mutable.MutableList[NamedExpression]() + val resolver = sessionState.conf.resolver columnNames.foreach { col => val exprOption = relation.resolve(col.split("\\."), resolver) if (exprOption.isEmpty) { @@ -71,143 +70,90 @@ case class AnalyzeColumnCommand( } def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = { - // collect column statistics - val aggColumns = mutable.ArrayBuffer[Column](count(Column("*"))) - validColumns.foreach(entry => aggColumns ++= statsAgg(entry.name, entry.dataType)) - val statsRow: InternalRow = Dataset.ofRows(sparkSession, relation).select(aggColumns: _*) + // Collect statistics per column. + // The first element in the result will be the overall row count, the following elements + // will be structs containing all column stats. + // The layout of each struct follows the layout of the BasicColStats. 
+ val ndvMaxErr = sessionState.conf.ndvMaxError + val expressions = Count(Literal(1)).toAggregateExpression() +: + validColumns.map(ColumnStatsStruct(_, ndvMaxErr)) + val namedExpressions = expressions.map(e => Alias(e, e.toString)()) + val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)) .queryExecution.toRdd.collect().head - // We also update table-level stats to prevent inconsistency in case of table modification - // between the two ANALYZE commands for collecting table-level stats and column-level stats. + // unwrap the result val rowCount = statsRow.getLong(0) - var newStats: Statistics = if (catalogTable.stats.isDefined) { - catalogTable.stats.get.copy(sizeInBytes = newTotalSize, rowCount = Some(rowCount)) - } else { - Statistics(sizeInBytes = newTotalSize, rowCount = Some(rowCount)) - } - - var pos = 1 - val colStats = mutable.HashMap[String, BasicColStats]() - validColumns.foreach { attr => - attr.dataType match { - case n: NumericType => - colStats += attr.name -> BasicColStats( - dataType = attr.dataType, - numNulls = rowCount - statsRow.getLong(pos + NumericStatsAgg.numNotNullsIndex), - max = Option(statsRow.get(pos + NumericStatsAgg.maxIndex, attr.dataType)), - min = Option(statsRow.get(pos + NumericStatsAgg.minIndex, attr.dataType)), - ndv = Some(statsRow.getLong(pos + NumericStatsAgg.ndvIndex))) - pos += NumericStatsAgg.statsSeq.length - case TimestampType | DateType => - colStats += attr.name -> BasicColStats( - dataType = attr.dataType, - numNulls = rowCount - statsRow.getLong(pos + NumericStatsAgg.numNotNullsIndex), - max = Option(statsRow.get(pos + NumericStatsAgg.maxIndex, attr.dataType)), - min = Option(statsRow.get(pos + NumericStatsAgg.minIndex, attr.dataType)), - ndv = Some(statsRow.getLong(pos + NumericStatsAgg.ndvIndex))) - pos += NumericStatsAgg.statsSeq.length - case StringType => - colStats += attr.name -> BasicColStats( - dataType = attr.dataType, - numNulls = rowCount - statsRow.getLong(pos + StringStatsAgg.numNotNullsIndex), - maxColLen = Some(statsRow.getLong(pos + StringStatsAgg.maxLenIndex)), - avgColLen = - Some(statsRow.getLong(pos + StringStatsAgg.sumLenIndex) / (1.0 * rowCount)), - ndv = Some(statsRow.getLong(pos + StringStatsAgg.ndvIndex))) - pos += StringStatsAgg.statsSeq.length - case BinaryType => - colStats += attr.name -> BasicColStats( - dataType = attr.dataType, - numNulls = rowCount - statsRow.getLong(pos + BinaryStatsAgg.numNotNullsIndex), - maxColLen = Some(statsRow.getLong(pos + BinaryStatsAgg.maxLenIndex)), - avgColLen = - Some(statsRow.getLong(pos + BinaryStatsAgg.sumLenIndex) / (1.0 * rowCount))) - pos += BinaryStatsAgg.statsSeq.length - case BooleanType => - val numOfNotNulls = statsRow.getLong(pos + BooleanStatsAgg.numNotNullsIndex) - val numOfTrues = Some(statsRow.getLong(pos + BooleanStatsAgg.numTruesIndex)) - colStats += attr.name -> BasicColStats( - dataType = attr.dataType, - numNulls = rowCount - numOfNotNulls, - numTrues = numOfTrues, - numFalses = numOfTrues.map(i => numOfNotNulls - i), - ndv = Some(2)) - pos += BooleanStatsAgg.statsSeq.length - } - } - newStats = newStats.copy(basicColStats = colStats.toMap) - sessionState.catalog.alterTable(catalogTable.copy(stats = Some(newStats))) + val colStats = validColumns.zipWithIndex.map { case (expr, i) => + val colInfo = statsRow.getStruct(i + 1, ColumnStatsStruct.statsNumber) + val colStats = ColumnStatsStruct.unwrapRow(expr, colInfo) + (expr.name, colStats) + }.toMap + + val statistics = + Statistics(sizeInBytes = newTotalSize, rowCount = 
Some(rowCount), basicColStats = colStats) + sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdent) } Seq.empty[Row] } - - private def statsAgg(name: String, dataType: DataType): Seq[Column] = dataType match { - // Currently we only support stats generation for atomic types - case n: NumericType => NumericStatsAgg(name) - case TimestampType | DateType => NumericStatsAgg(name) - case StringType => StringStatsAgg(name) - case BinaryType => BinaryStatsAgg(name) - case BooleanType => BooleanStatsAgg(name) - case otherType => - throw new AnalysisException(s"Analyzing column $name of $otherType is not supported.") - } } -object ColumnStats extends Enumeration { - val MAX, MIN, NDV, NUM_NOT_NULLS, MAX_LENGTH, SUM_LENGTH, NUM_TRUES = Value -} - -trait StatsAggFunc { - // This sequence is used to track the order of stats results when collecting. - val statsSeq: Seq[ColumnStats.Value] - - def apply(name: String): Seq[Column] = { - val col = Column(name) - statsSeq.map { - case ColumnStats.MAX => max(col) - case ColumnStats.MIN => min(col) - // count(distinct col) will have a shuffle, so we use an approximate ndv for efficiency - case ColumnStats.NDV => approxCountDistinct(col) - case ColumnStats.NUM_NOT_NULLS => count(col) - case ColumnStats.MAX_LENGTH => max(length(col)) - case ColumnStats.SUM_LENGTH => sum(length(col)) - case ColumnStats.NUM_TRUES => sum(col.cast(IntegerType)) +object ColumnStatsStruct { + val zero = Literal(0, LongType) + val one = Literal(1, LongType) + val two = Literal(2, LongType) + val nullLong = Literal(null, LongType) + val nullDouble = Literal(null, DoubleType) + val nullString = Literal(null, StringType) + val nullBinary = Literal(null, BinaryType) + val nullBoolean = Literal(null, BooleanType) + val statsNumber = 8 + + def apply(e: Expression, relativeSD: Double): CreateStruct = { + var statistics = e.dataType match { + case n: NumericType => + Seq(Max(e), Min(e), HyperLogLogPlusPlus(e, relativeSD), nullDouble, nullLong, nullLong, + nullLong) + case TimestampType | DateType => + Seq(Max(e), Min(e), HyperLogLogPlusPlus(e, relativeSD), nullDouble, nullLong, nullLong, + nullLong) + case StringType => + Seq(nullString, nullString, HyperLogLogPlusPlus(e, relativeSD), Average(Length(e)), + Max(Length(e)), nullLong, nullLong) + case BinaryType => + Seq(nullBinary, nullBinary, nullLong, Average(Length(e)), Max(Length(e)), nullLong, + nullLong) + case BooleanType => + Seq(nullBoolean, nullBoolean, two, nullDouble, nullLong, Sum(If(e, one, zero)), + Sum(If(e, zero, one))) + case otherType => + throw new AnalysisException("ANALYZE command is not supported for data type: " + + s"${e.dataType}") } + statistics = if (e.nullable) { + Sum(If(IsNull(e), one, zero)) +: statistics + } else { + zero +: statistics + } + assert(statistics.length == statsNumber) + CreateStruct(statistics.map { + case af: AggregateFunction => af.toAggregateExpression() + case e: Expression => e + }) } - // This is used to locate the needed stat in the sequence. 
- def offset: Map[ColumnStats.Value, Int] = statsSeq.zipWithIndex.toMap - - def numNotNullsIndex: Int = offset(ColumnStats.NUM_NOT_NULLS) -} - -object NumericStatsAgg extends StatsAggFunc { - override val statsSeq = Seq(ColumnStats.MAX, ColumnStats.MIN, ColumnStats.NDV, - ColumnStats.NUM_NOT_NULLS) - def maxIndex: Int = offset(ColumnStats.MAX) - def minIndex: Int = offset(ColumnStats.MIN) - def ndvIndex: Int = offset(ColumnStats.NDV) -} - -object StringStatsAgg extends StatsAggFunc { - override val statsSeq = Seq(ColumnStats.MAX_LENGTH, ColumnStats.SUM_LENGTH, ColumnStats.NDV, - ColumnStats.NUM_NOT_NULLS) - def maxLenIndex: Int = offset(ColumnStats.MAX_LENGTH) - def sumLenIndex: Int = offset(ColumnStats.SUM_LENGTH) - def ndvIndex: Int = offset(ColumnStats.NDV) -} - -object BinaryStatsAgg extends StatsAggFunc { - override val statsSeq = Seq(ColumnStats.MAX_LENGTH, ColumnStats.SUM_LENGTH, - ColumnStats.NUM_NOT_NULLS) - def maxLenIndex: Int = offset(ColumnStats.MAX_LENGTH) - def sumLenIndex: Int = offset(ColumnStats.SUM_LENGTH) -} - -object BooleanStatsAgg extends StatsAggFunc { - override val statsSeq = Seq(ColumnStats.NUM_TRUES, ColumnStats.NUM_NOT_NULLS) - def numTruesIndex: Int = offset(ColumnStats.NUM_TRUES) + def unwrapRow(e: Expression, row: InternalRow): BasicColStats = { + BasicColStats( + dataType = e.dataType, + numNulls = row.getLong(0), + max = if (row.isNullAt(1)) None else Some(row.get(1, e.dataType)), + min = if (row.isNullAt(2)) None else Some(row.get(2, e.dataType)), + ndv = if (row.isNullAt(3)) None else Some(row.getLong(3)), + avgColLen = if (row.isNullAt(4)) None else Some(row.getDouble(4)), + maxColLen = if (row.isNullAt(5)) None else Some(row.getLong(5)), + numTrues = if (row.isNullAt(6)) None else Some(row.getLong(6)), + numFalses = if (row.isNullAt(7)) None else Some(row.getLong(7) - row.getLong(0))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 8c82d10d66c3..1db6af2ac5a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -34,14 +34,15 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation * Analyzes the given table in the current database to generate statistics, which will be * used in query optimizations. 
*/ -case class AnalyzeTableCommand(tableName: String, noscan: Boolean = true) extends RunnableCommand { +case class AnalyzeTableCommand( + tableIdent: TableIdentifier, + noscan: Boolean = true) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val sessionState = sparkSession.sessionState - val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) - val tableIdentwithDB = TableIdentifier(tableIdent.table, Some(db)) - val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentwithDB)) + val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) + val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) relation match { case relation: CatalogRelation => @@ -83,7 +84,7 @@ case class AnalyzeTableCommand(tableName: String, noscan: Boolean = true) extend if (newStats.isDefined) { sessionState.catalog.alterTable(catalogTable.copy(stats = newStats)) // Refresh the cached data source table in the catalog. - sessionState.catalog.refreshTable(tableIdent) + sessionState.catalog.refreshTable(tableIdentWithDB) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e67140fefef9..46a727911a95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -581,6 +581,13 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(10L) + val NDV_MAX_ERROR = + SQLConfigBuilder("spark.sql.ndv.maxError") + .internal() + .doc("The maximum estimation error allowed in HyperLogLog++ algorithm.") + .doubleConf + .createWithDefault(0.05) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -757,6 +764,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) override def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) + + def ndvMaxError: Double = getConf(NDV_MAX_ERROR) /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index c899773b6b36..ded8e16d52f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -23,13 +23,14 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.command.AnalyzeTableCommand +import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, AnalyzeTableCommand} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager} import org.apache.spark.sql.util.ExecutionListenerManager @@ -186,13 +187,18 @@ private[sql] class SessionState(sparkSession: SparkSession) { } /** - * Analyzes the given table in the current database to generate statistics, which will be + * Analyzes the given table in the current database to generate table-level statistics, which + * will be used in query optimizations. + */ + def analyzeTable(tableIdent: TableIdentifier, noscan: Boolean = true): Unit = { + AnalyzeTableCommand(tableIdent, noscan).run(sparkSession) + } + + /** + * Analyzes the given columns in the table to generate column-level statistics, which will be * used in query optimizations. - * - * Right now, it only supports catalog tables and it only updates the size of a catalog table - * in the external catalog. 
*/ - def analyze(tableName: String, noscan: Boolean = true): Unit = { - AnalyzeTableCommand(tableName, noscan).run(sparkSession) + def analyzeTableColumns(tableIdent: TableIdentifier, columnNames: Seq[String]): Unit = { + AnalyzeColumnCommand(tableIdent, columnNames).run(sparkSession) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala index ff9ccc2ece4f..c2764d1df9d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala @@ -106,7 +106,7 @@ class StatisticsColumnSuite extends StatisticsTest { val rdd = sparkContext.parallelize(Seq(null, "a", "bbbb", "cccc")).map(Row(_)) val schema = StructType(StructField(name = "c1", dataType = StringType, nullable = true) :: Nil) val statsSeq = Seq(("c1", BasicColStats(dataType = StringType, numNulls = 1, - maxColLen = Some(4), avgColLen = Some(2.25), ndv = Some(3)))) + maxColLen = Some(4), avgColLen = Some(3), ndv = Some(3)))) checkColStats(rdd, schema, statsSeq) } @@ -116,7 +116,7 @@ class StatisticsColumnSuite extends StatisticsTest { } val schema = StructType(StructField(name = "c1", dataType = BinaryType, nullable = true) :: Nil) val statsSeq = Seq(("c1", BasicColStats(dataType = BinaryType, numNulls = 1, - maxColLen = Some(4), avgColLen = Some(2.25)))) + maxColLen = Some(4), avgColLen = Some(3)))) checkColStats(rdd, schema, statsSeq) } @@ -207,16 +207,16 @@ class StatisticsColumnSuite extends StatisticsTest { StructField(name = "c1", dataType = IntegerType, nullable = false)))) df.write.format("json").saveAsTable(tmpTable) - sql(s"CREATE TABLE $table (c1 int)") + sql(s"CREATE TABLE $table (c1 int) STORED AS PARQUET") sql(s"INSERT INTO $table SELECT * FROM $tmpTable") sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") - val fetchedStats1 = checkTableStats(tableName = table, isDataSourceTable = false, + val fetchedStats1 = checkTableStats(tableName = table, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = Some(1)) // update table between analyze table and analyze column commands sql(s"INSERT INTO $table SELECT * FROM $tmpTable") sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") - val fetchedStats2 = checkTableStats(tableName = table, isDataSourceTable = false, + val fetchedStats2 = checkTableStats(tableName = table, isDataSourceTable = true, hasSizeInBytes = true, expectedRowCounts = Some(2)) assert(fetchedStats2.get.sizeInBytes > fetchedStats1.get.sizeInBytes) From 22fc9fcc6826f1ccb6fa032339efa955f843ec19 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Thu, 15 Sep 2016 23:36:52 -0700 Subject: [PATCH 05/22] fix test case --- .../org/apache/spark/sql/hive/StatisticsColumnSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala index c2764d1df9d6..873c1b985de3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala @@ -207,16 +207,16 @@ class StatisticsColumnSuite extends StatisticsTest { StructField(name = "c1", dataType = IntegerType, nullable = false)))) df.write.format("json").saveAsTable(tmpTable) - sql(s"CREATE TABLE $table (c1 int) STORED AS PARQUET") + sql(s"CREATE TABLE 
$table (c1 int) STORED AS TEXTFILE") sql(s"INSERT INTO $table SELECT * FROM $tmpTable") sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") - val fetchedStats1 = checkTableStats(tableName = table, isDataSourceTable = true, + val fetchedStats1 = checkTableStats(tableName = table, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = Some(1)) // update table between analyze table and analyze column commands sql(s"INSERT INTO $table SELECT * FROM $tmpTable") sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") - val fetchedStats2 = checkTableStats(tableName = table, isDataSourceTable = true, + val fetchedStats2 = checkTableStats(tableName = table, isDataSourceTable = false, hasSizeInBytes = true, expectedRowCounts = Some(2)) assert(fetchedStats2.get.sizeInBytes > fetchedStats1.get.sizeInBytes) From 10b51a802e21879780dfc23f271435727355944d Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Mon, 19 Sep 2016 17:58:20 -0700 Subject: [PATCH 06/22] changes based on comments --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../catalyst/plans/logical/Statistics.scala | 20 +- .../spark/sql/execution/SparkSqlParser.scala | 4 - .../command/AnalyzeColumnCommand.scala | 79 ++-- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../apache/spark/sql/StatisticsSuite.scala | 14 +- .../spark/sql/hive/HiveExternalCatalog.scala | 12 +- .../sql/hive/StatisticsColumnSuite.scala | 369 +++++++++++------- .../spark/sql/hive/StatisticsTest.scala | 64 ++- 9 files changed, 330 insertions(+), 236 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index de2f9ee6bc7a..1284681fe80b 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -86,7 +86,7 @@ statement | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier LIKE source=tableIdentifier #createTableLike | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS - (identifier | FOR COLUMNS identifierSeq?)? #analyze + (identifier | FOR COLUMNS identifierSeq)? #analyze | ALTER (TABLE | VIEW) from=tableIdentifier RENAME TO to=tableIdentifier #renameTable | ALTER (TABLE | VIEW) tableIdentifier diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index e2c1be1f8bc1..0080007db3b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -34,13 +34,13 @@ import org.apache.spark.sql.types.DataType * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it * defaults to the product of children's `sizeInBytes`. * @param rowCount Estimated number of rows. - * @param basicColStats Basic column-level statistics. + * @param colStats Column-level statistics. * @param isBroadcastable If true, output is small enough to be used in a broadcast join. 
*/ case class Statistics( sizeInBytes: BigInt, rowCount: Option[BigInt] = None, - basicColStats: Map[String, BasicColStats] = Map.empty, + colStats: Map[String, ColumnStats] = Map.empty, isBroadcastable: Boolean = false) { override def toString: String = "Statistics(" + simpleString + ")" @@ -49,13 +49,17 @@ case class Statistics( def simpleString: String = { Seq(s"sizeInBytes=$sizeInBytes", if (rowCount.isDefined) s"rowCount=${rowCount.get}" else "", - if (basicColStats.nonEmpty) s"basicColStats=$basicColStats" else "", + if (colStats.nonEmpty) s"colStats=$colStats" else "", s"isBroadcastable=$isBroadcastable" ).filter(_.nonEmpty).mkString(", ") } } -case class BasicColStats( +/** + * Statistics for a column. + * @param ndv Number of distinct values of the column. + */ +case class ColumnStats( dataType: DataType, numNulls: Long, max: Option[Any] = None, @@ -66,7 +70,7 @@ case class BasicColStats( numTrues: Option[Long] = None, numFalses: Option[Long] = None) { - override def toString: String = "BasicColStats(" + simpleString + ")" + override def toString: String = "ColumnStats(" + simpleString + ")" def simpleString: String = { Seq(s"numNulls=$numNulls", @@ -81,10 +85,10 @@ case class BasicColStats( } } -object BasicColStats { - def fromString(str: String, dataType: DataType): BasicColStats = { +object ColumnStats { + def fromString(str: String, dataType: DataType): ColumnStats = { val suffix = ",\\s|\\)" - BasicColStats( + ColumnStats( dataType = dataType, numNulls = findItem(source = str, prefix = "numNulls=", suffix = suffix).map(_.toLong).get, max = findItem(source = str, prefix = "max=", suffix = suffix), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 22e507e9e3d7..50133693aa48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -95,10 +95,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { ctx.identifier.getText.toLowerCase == "noscan") { AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier)) } else if (ctx.identifierSeq() == null) { - if (ctx.FOR() != null || ctx.COLUMNS() != null) { - throw new ParseException("Need to specify the columns to analyze. 
Usage: " + - "ANALYZE TABLE tbl COMPUTE STATISTICS FOR COLUMNS key, value", ctx) - } AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier), noscan = false) } else { AnalyzeColumnCommand( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 7ef05d7fa534..a4d2234f4dbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, BasicColStats, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, ColumnStats, Statistics} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.types._ @@ -43,17 +43,17 @@ case class AnalyzeColumnCommand( val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdent)) // check correctness of column names - val validColumns = mutable.MutableList[NamedExpression]() - val resolver = sessionState.conf.resolver + val attributesToAnalyze = mutable.MutableList[Attribute]() + val caseSensitive = sessionState.conf.caseSensitiveAnalysis columnNames.foreach { col => - val exprOption = relation.resolve(col.split("\\."), resolver) - if (exprOption.isEmpty) { - throw new AnalysisException(s"Invalid column name: $col") + val exprOption = relation.output.find { attr => + if (caseSensitive) attr.name == col else attr.name.equalsIgnoreCase(col) } - if (validColumns.map(_.exprId).contains(exprOption.get.exprId)) { - throw new AnalysisException(s"Duplicate column name: $col") + val expr = exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col.")) + // do deduplication + if (!attributesToAnalyze.contains(expr)) { + attributesToAnalyze += expr } - validColumns += exprOption.get } relation match { @@ -73,24 +73,22 @@ case class AnalyzeColumnCommand( // Collect statistics per column. // The first element in the result will be the overall row count, the following elements // will be structs containing all column stats. - // The layout of each struct follows the layout of the BasicColStats. + // The layout of each struct follows the layout of the ColumnStats. 
val ndvMaxErr = sessionState.conf.ndvMaxError val expressions = Count(Literal(1)).toAggregateExpression() +: - validColumns.map(ColumnStatsStruct(_, ndvMaxErr)) + attributesToAnalyze.map(ColumnStatsStruct(_, ndvMaxErr)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)) .queryExecution.toRdd.collect().head // unwrap the result val rowCount = statsRow.getLong(0) - val colStats = validColumns.zipWithIndex.map { case (expr, i) => - val colInfo = statsRow.getStruct(i + 1, ColumnStatsStruct.statsNumber) - val colStats = ColumnStatsStruct.unwrapRow(expr, colInfo) - (expr.name, colStats) + val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => + (expr.name, ColumnStatsStruct.unwrapStruct(statsRow, i + 1, expr)) }.toMap val statistics = - Statistics(sizeInBytes = newTotalSize, rowCount = Some(rowCount), basicColStats = colStats) + Statistics(sizeInBytes = newTotalSize, rowCount = Some(rowCount), colStats = columnStats) sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdent) @@ -103,7 +101,6 @@ case class AnalyzeColumnCommand( object ColumnStatsStruct { val zero = Literal(0, LongType) val one = Literal(1, LongType) - val two = Literal(2, LongType) val nullLong = Literal(null, LongType) val nullDouble = Literal(null, DoubleType) val nullString = Literal(null, StringType) @@ -111,12 +108,9 @@ object ColumnStatsStruct { val nullBoolean = Literal(null, BooleanType) val statsNumber = 8 - def apply(e: Expression, relativeSD: Double): CreateStruct = { + def apply(e: NamedExpression, relativeSD: Double): CreateStruct = { var statistics = e.dataType match { - case n: NumericType => - Seq(Max(e), Min(e), HyperLogLogPlusPlus(e, relativeSD), nullDouble, nullLong, nullLong, - nullLong) - case TimestampType | DateType => + case _: NumericType | TimestampType | DateType => Seq(Max(e), Min(e), HyperLogLogPlusPlus(e, relativeSD), nullDouble, nullLong, nullLong, nullLong) case StringType => @@ -126,11 +120,11 @@ object ColumnStatsStruct { Seq(nullBinary, nullBinary, nullLong, Average(Length(e)), Max(Length(e)), nullLong, nullLong) case BooleanType => - Seq(nullBoolean, nullBoolean, two, nullDouble, nullLong, Sum(If(e, one, zero)), - Sum(If(e, zero, one))) + Seq(nullBoolean, nullBoolean, nullLong, nullDouble, nullLong, Sum(If(e, one, zero)), + Sum(If(Not(e), one, zero))) case otherType => - throw new AnalysisException("ANALYZE command is not supported for data type: " + - s"${e.dataType}") + throw new AnalysisException("Analyzing columns is not supported for column " + + s"${e.name} of data type: ${e.dataType}.") } statistics = if (e.nullable) { Sum(If(IsNull(e), one, zero)) +: statistics @@ -144,16 +138,29 @@ object ColumnStatsStruct { }) } - def unwrapRow(e: Expression, row: InternalRow): BasicColStats = { - BasicColStats( + def unwrapStruct(row: InternalRow, offset: Int, e: Expression): ColumnStats = { + val struct = row.getStruct(offset, statsNumber) + ColumnStats( dataType = e.dataType, - numNulls = row.getLong(0), - max = if (row.isNullAt(1)) None else Some(row.get(1, e.dataType)), - min = if (row.isNullAt(2)) None else Some(row.get(2, e.dataType)), - ndv = if (row.isNullAt(3)) None else Some(row.getLong(3)), - avgColLen = if (row.isNullAt(4)) None else Some(row.getDouble(4)), - maxColLen = if (row.isNullAt(5)) None else Some(row.getLong(5)), - numTrues = if 
(row.isNullAt(6)) None else Some(row.getLong(6)), - numFalses = if (row.isNullAt(7)) None else Some(row.getLong(7) - row.getLong(0))) + numNulls = struct.getLong(0), + max = getField(struct, 1, e.dataType), + min = getField(struct, 2, e.dataType), + ndv = getLongField(struct, 3), + avgColLen = getDoubleField(struct, 4), + maxColLen = getLongField(struct, 5), + numTrues = getLongField(struct, 6), + numFalses = getLongField(struct, 7)) + } + + private def getField(struct: InternalRow, index: Int, dataType: DataType): Option[Any] = { + if (struct.isNullAt(index)) None else Some(struct.get(index, dataType)) + } + + private def getLongField(struct: InternalRow, index: Int): Option[Long] = { + if (struct.isNullAt(index)) None else Some(struct.getLong(index)) + } + + private def getDoubleField(struct: InternalRow, index: Int): Option[Double] = { + if (struct.isNullAt(index)) None else Some(struct.getDouble(index)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 46a727911a95..fecdf792fd14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -582,7 +582,7 @@ object SQLConf { .createWithDefault(10L) val NDV_MAX_ERROR = - SQLConfigBuilder("spark.sql.ndv.maxError") + SQLConfigBuilder("spark.sql.statistics.ndv.maxError") .internal() .doc("The maximum estimation error allowed in HyperLogLog++ algorithm.") .doubleConf diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala index ac7cdecc6dbb..26eabee1d31e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.plans.logical.{BasicColStats, GlobalLimit, Join, LocalLimit} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStats, GlobalLimit, Join, LocalLimit} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -103,7 +103,7 @@ class StatisticsSuite extends QueryTest with SharedSQLContext { } test("test column-level statistics for data source table created in InMemoryCatalog") { - def checkColStats(colStats: BasicColStats, expectedColStats: BasicColStats): Unit = { + def checkColStats(colStats: ColumnStats, expectedColStats: ColumnStats): Unit = { assert(colStats.dataType == expectedColStats.dataType) assert(colStats.numNulls == expectedColStats.numNulls) assert(colStats.max == expectedColStats.max) @@ -126,17 +126,17 @@ class StatisticsSuite extends QueryTest with SharedSQLContext { sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS i, j") val df = sql(s"SELECT * FROM $tableName") val expectedRowCount = Some(2) - val expectedColStatsSeq: Seq[(String, BasicColStats)] = Seq( - ("i", BasicColStats(dataType = IntegerType, numNulls = 0, max = Some(2), min = Some(1), + val expectedColStatsSeq: Seq[(String, ColumnStats)] = Seq( + ("i", ColumnStats(dataType = IntegerType, numNulls = 0, max = Some(2), min = Some(1), ndv = Some(2))), - ("j", BasicColStats(dataType = StringType, numNulls = 0, maxColLen = Some(1), + ("j", ColumnStats(dataType = StringType, numNulls = 0, maxColLen = Some(1), avgColLen = Some(1), ndv = Some(2)))) val relations = 
df.queryExecution.analyzed.collect { case rel: LogicalRelation => val stats = rel.catalogTable.get.stats.get assert(stats.rowCount == expectedRowCount) expectedColStatsSeq.foreach { case (column, expectedColStats) => - assert(stats.basicColStats.contains(column)) - checkColStats(colStats = stats.basicColStats(column), expectedColStats = expectedColStats) + assert(stats.colStats.contains(column)) + checkColStats(colStats = stats.colStats(column), expectedColStats = expectedColStats) } rel } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 242a7b0de781..b6c2a291b7ed 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.plans.logical.{BasicColStats, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStats, Statistics} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap import org.apache.spark.sql.hive.client.HiveClient @@ -403,7 +403,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat if (stats.rowCount.isDefined) { statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() } - stats.basicColStats.foreach { case (colName, colStats) => + stats.colStats.foreach { case (colName, colStats) => statsProperties += (STATISTICS_BASIC_COL_STATS_PREFIX + colName) -> colStats.toString } tableDefinition.copy(properties = tableDefinition.properties ++ statsProperties) @@ -480,16 +480,16 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val colStatsProps = catalogTable.properties .filterKeys(_.startsWith(STATISTICS_BASIC_COL_STATS_PREFIX)) .map { case (k, v) => (k.replace(STATISTICS_BASIC_COL_STATS_PREFIX, ""), v)} - val colStats: Map[String, BasicColStats] = catalogTable.schema.collect { + val colStats: Map[String, ColumnStats] = catalogTable.schema.collect { case field if colStatsProps.contains(field.name) => - (field.name, BasicColStats.fromString(colStatsProps(field.name), field.dataType)) + (field.name, ColumnStats.fromString(colStatsProps(field.name), field.dataType)) }.toMap catalogTable.copy( properties = removeStatsProperties(catalogTable), stats = Some(Statistics( sizeInBytes = BigInt(catalogTable.properties(STATISTICS_TOTAL_SIZE)), rowCount = catalogTable.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)), - basicColStats = colStats))) + colStats = colStats))) } else { catalogTable } @@ -701,7 +701,7 @@ object HiveExternalCatalog { val STATISTICS_PREFIX = "spark.sql.statistics." val STATISTICS_TOTAL_SIZE = STATISTICS_PREFIX + "totalSize" val STATISTICS_NUM_ROWS = STATISTICS_PREFIX + "numRows" - val STATISTICS_BASIC_COL_STATS_PREFIX = STATISTICS_PREFIX + "basicColStats." + val STATISTICS_BASIC_COL_STATS_PREFIX = STATISTICS_PREFIX + "colStats." 
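  // An illustrative sketch (assumed key layout, not lines taken from this patch): with the
  // prefixes above, a table analyzed for columns `i` and `j` would carry properties roughly like
  // {{{
  //   spark.sql.statistics.totalSize   -> <sizeInBytes>
  //   spark.sql.statistics.numRows     -> <rowCount>
  //   spark.sql.statistics.colStats.i  -> colStats("i").toString
  //   spark.sql.statistics.colStats.j  -> colStats("j").toString
  // }}}
  // When the table metadata is read back, keys matching the colStats prefix are stripped of it
  // and each value is parsed with ColumnStats.fromString, using the column's data type from the
  // table schema.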
def removeStatsProperties(metadata: CatalogTable): Map[String, String] = { metadata.properties.filterNot { case (key, _) => key.startsWith(STATISTICS_PREFIX) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala index 873c1b985de3..61920ff2076d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala @@ -19,12 +19,14 @@ package org.apache.spark.sql.hive import java.sql.{Date, Timestamp} -import org.apache.spark.sql.{AnalysisException, Row} -import org.apache.spark.sql.catalyst.plans.logical.BasicColStats +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.plans.logical.ColumnStats +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.command.AnalyzeColumnCommand import org.apache.spark.sql.types._ class StatisticsColumnSuite extends StatisticsTest { + import testImplicits._ test("parse analyze column commands") { val table = "table" @@ -32,197 +34,288 @@ class StatisticsColumnSuite extends StatisticsTest { s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key, value", classOf[AnalyzeColumnCommand]) - val noColumnError = intercept[AnalysisException] { + intercept[AnalysisException] { sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS") } - assert(noColumnError.message == "Need to specify the columns to analyze. Usage: " + - "ANALYZE TABLE tbl COMPUTE STATISTICS FOR COLUMNS key, value") + } + test("check correctness of columns") { + val table = "tbl" + val quotedColumn = "x.yz" + val quotedName = s"`$quotedColumn`" withTable(table) { - sql(s"CREATE TABLE $table (key INT, value STRING)") - val invalidColError = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS k") - } - assert(invalidColError.message == s"Invalid column name: k") + sql(s"CREATE TABLE $table (abc int, $quotedName string)") - val duplicateColError = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key, value, key") + val invalidColError = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key") } - assert(duplicateColError.message == s"Duplicate column name: key") + assert(invalidColError.message == s"Invalid column name: key.") withSQLConf("spark.sql.caseSensitive" -> "true") { val invalidErr = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS keY") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS ABC") } - assert(invalidErr.message == s"Invalid column name: keY") + assert(invalidErr.message == s"Invalid column name: ABC.") } withSQLConf("spark.sql.caseSensitive" -> "false") { - val duplicateErr = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key, value, vaLue") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS ${quotedName.toUpperCase}, " + + s"ABC, $quotedName") + val df = sql(s"SELECT * FROM $table") + val stats = df.queryExecution.analyzed.collect { + case rel: MetastoreRelation => + val colStats = rel.catalogTable.stats.get.colStats + // check deduplication + assert(colStats.size == 2) + assert(colStats.contains(quotedColumn)) + assert(colStats.contains("abc")) } - assert(duplicateErr.message == s"Duplicate column name: vaLue") + assert(stats.size == 1) } } } - test("basic 
statistics for integral type columns") { - val rdd = sparkContext.parallelize(Seq("1", null, "2", "3", null)).map { i => - if (i != null) Row(i.toByte, i.toShort, i.toInt, i.toLong) else Row(i, i, i, i) - } - val schema = StructType( - StructField(name = "c1", dataType = ByteType, nullable = true) :: - StructField(name = "c2", dataType = ShortType, nullable = true) :: - StructField(name = "c3", dataType = IntegerType, nullable = true) :: - StructField(name = "c4", dataType = LongType, nullable = true) :: Nil) - val expectedBasicStats = BasicColStats( - dataType = ByteType, numNulls = 2, max = Some(3), min = Some(1), ndv = Some(3)) - val statsSeq = Seq( - ("c1", expectedBasicStats), - ("c2", expectedBasicStats.copy(dataType = ShortType)), - ("c3", expectedBasicStats.copy(dataType = IntegerType)), - ("c4", expectedBasicStats.copy(dataType = LongType))) - checkColStats(rdd, schema, statsSeq) + private def getNonNullValues[T](values: Seq[Option[T]]): Seq[T] = { + values.filter(_.isDefined).map(_.get) } - test("basic statistics for fractional type columns") { - val rdd = sparkContext.parallelize(Seq(null, "1.01", "2.02", "3.03")).map { i => - if (i != null) Row(i.toFloat, i.toDouble, Decimal(i)) else Row(i, i, i) - } - val schema = StructType( - StructField(name = "c1", dataType = FloatType, nullable = true) :: - StructField(name = "c2", dataType = DoubleType, nullable = true) :: - StructField(name = "c3", dataType = DecimalType.SYSTEM_DEFAULT, nullable = true) :: Nil) - val expectedBasicStats = BasicColStats( - dataType = FloatType, numNulls = 1, max = Some(3.03), min = Some(1.01), ndv = Some(3)) - val statsSeq = Seq( - ("c1", expectedBasicStats), - ("c2", expectedBasicStats.copy(dataType = DoubleType)), - ("c3", expectedBasicStats.copy(dataType = DecimalType.SYSTEM_DEFAULT))) - checkColStats(rdd, schema, statsSeq) + test("column-level statistics for integral type columns") { + val values = (0 to 5).map { i => + if (i % 2 == 0) None else Some(i) + } + val data = values.map { i => + (i.map(_.toByte), i.map(_.toShort), i.map(_.toInt), i.map(_.toLong)) + } + + val df = data.toDF("c1", "c2", "c3", "c4") + val nonNullValues = getNonNullValues[Int](values) + val statsSeq = df.schema.map { f => + val colStats = ColumnStats( + dataType = f.dataType, + numNulls = values.count(_.isEmpty), + max = Some(nonNullValues.max), + min = Some(nonNullValues.min), + ndv = Some(nonNullValues.distinct.length.toLong)) + (f.name, colStats) + } + checkColStats(df, statsSeq) + } + + test("column-level statistics for fractional type columns") { + val values = (0 to 5).map { i => + if (i == 0) None else Some(i + i * 0.01d) + } + val data = values.map { i => + (i.map(_.toFloat), i.map(_.toDouble), i.map(Decimal(_))) + } + + val df = data.toDF("c1", "c2", "c3") + val nonNullValues = getNonNullValues[Double](values) + val statsSeq = df.schema.map { f => + val colStats = ColumnStats( + dataType = f.dataType, + numNulls = values.count(_.isEmpty), + max = Some(nonNullValues.max), + min = Some(nonNullValues.min), + ndv = Some(nonNullValues.distinct.length.toLong)) + (f.name, colStats) + } + checkColStats(df, statsSeq) } - test("basic statistics for string column") { - val rdd = sparkContext.parallelize(Seq(null, "a", "bbbb", "cccc")).map(Row(_)) - val schema = StructType(StructField(name = "c1", dataType = StringType, nullable = true) :: Nil) - val statsSeq = Seq(("c1", BasicColStats(dataType = StringType, numNulls = 1, - maxColLen = Some(4), avgColLen = Some(3), ndv = Some(3)))) - checkColStats(rdd, schema, statsSeq) + 
test("column-level statistics for string column") { + val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc")) + val df = values.toDF("c1") + val nonNullValues = getNonNullValues[String](values) + val statsSeq = df.schema.map { f => + val colStats = ColumnStats( + dataType = f.dataType, + numNulls = values.count(_.isEmpty), + ndv = Some(nonNullValues.distinct.length.toLong), + maxColLen = Some(nonNullValues.map(_.length).max.toLong), + avgColLen = Some(nonNullValues.map(_.length).sum / nonNullValues.length.toDouble)) + (f.name, colStats) + } + checkColStats(df, statsSeq) } - test("basic statistics for binary column") { - val rdd = sparkContext.parallelize(Seq(null, "a", "bbbb", "cccc")).map { i => - if (i != null) Row(i.getBytes) else Row(i) + test("column-level statistics for binary column") { + val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc")).map(_.map(_.getBytes)) + val df = values.toDF("c1") + val nonNullValues = getNonNullValues[Array[Byte]](values) + val statsSeq = df.schema.map { f => + val colStats = ColumnStats( + dataType = f.dataType, + numNulls = values.count(_.isEmpty), + maxColLen = Some(nonNullValues.map(_.length).max.toLong), + avgColLen = Some(nonNullValues.map(_.length).sum / nonNullValues.length.toDouble)) + (f.name, colStats) } - val schema = StructType(StructField(name = "c1", dataType = BinaryType, nullable = true) :: Nil) - val statsSeq = Seq(("c1", BasicColStats(dataType = BinaryType, numNulls = 1, - maxColLen = Some(4), avgColLen = Some(3)))) - checkColStats(rdd, schema, statsSeq) + checkColStats(df, statsSeq) } - test("basic statistics for boolean column") { - val rdd = sparkContext.parallelize(Seq(null, true, false, true)).map(Row(_)) - val schema = - StructType(StructField(name = "c1", dataType = BooleanType, nullable = true) :: Nil) - val statsSeq = Seq(("c1", BasicColStats(dataType = BooleanType, numNulls = 1, - numTrues = Some(2), numFalses = Some(1)))) - checkColStats(rdd, schema, statsSeq) + test("column-level statistics for boolean column") { + val values = Seq(None, Some(true), Some(false), Some(true)) + val df = values.toDF("c1") + val nonNullValues = getNonNullValues[Boolean](values) + val statsSeq = df.schema.map { f => + val colStats = ColumnStats( + dataType = f.dataType, + numNulls = values.count(_.isEmpty), + numTrues = Some(nonNullValues.count(_.equals(true)).toLong), + numFalses = Some(nonNullValues.count(_.equals(false)).toLong)) + (f.name, colStats) + } + checkColStats(df, statsSeq) } - test("basic statistics for date column") { - val rdd = sparkContext.parallelize(Seq(null, "1970-01-01", "1970-02-02")).map { i => - if (i != null) Row(Date.valueOf(i)) else Row(i) + test("column-level statistics for date column") { + val values = Seq(None, Some("1970-01-01"), Some("1970-02-02")).map(_.map(Date.valueOf)) + val df = values.toDF("c1") + val nonNullValues = getNonNullValues[Date](values) + val statsSeq = df.schema.map { f => + val colStats = ColumnStats( + dataType = f.dataType, + numNulls = values.count(_.isEmpty), + // Internally, DateType is represented as the number of days from 1970-01-01. 
+ max = Some(nonNullValues.map(DateTimeUtils.fromJavaDate).max), + min = Some(nonNullValues.map(DateTimeUtils.fromJavaDate).min), + ndv = Some(nonNullValues.distinct.length.toLong)) + (f.name, colStats) } - val schema = - StructType(StructField(name = "c1", dataType = DateType, nullable = true) :: Nil) - val statsSeq = Seq(("c1", BasicColStats(dataType = DateType, numNulls = 1, - max = Some(32), min = Some(0), ndv = Some(2)))) - checkColStats(rdd, schema, statsSeq) + checkColStats(df, statsSeq) } - test("basic statistics for timestamp column") { - val rdd = sparkContext.parallelize(Seq(null, "1970-01-01 00:00:00", "1970-01-01 00:00:05")) - .map(i => if (i != null) Row(Timestamp.valueOf(i)) else Row(i)) - val schema = - StructType(StructField(name = "c1", dataType = TimestampType, nullable = true) :: Nil) - val statsSeq = Seq(("c1", BasicColStats(dataType = TimestampType, numNulls = 1, - max = Some(5000000), min = Some(0), ndv = Some(2)))) - checkColStats(rdd, schema, statsSeq) + test("column-level statistics for timestamp column") { + val values = Seq(None, Some("1970-01-01 00:00:00"), Some("1970-01-01 00:00:05")).map { i => + i.map(Timestamp.valueOf) + } + val df = values.toDF("c1") + val nonNullValues = getNonNullValues[Timestamp](values) + val statsSeq = df.schema.map { f => + val colStats = ColumnStats( + dataType = f.dataType, + numNulls = values.count(_.isEmpty), + // Internally, TimestampType is represented as the number of days from 1970-01-01 + max = Some(nonNullValues.map(DateTimeUtils.fromJavaTimestamp).max), + min = Some(nonNullValues.map(DateTimeUtils.fromJavaTimestamp).min), + ndv = Some(nonNullValues.distinct.length.toLong)) + (f.name, colStats) + } + checkColStats(df, statsSeq) } - test("basic statistics for null columns") { - val rdd = sparkContext.parallelize(Seq(Row(null, null))) - val schema = StructType( - StructField(name = "c1", dataType = LongType, nullable = true) :: - StructField(name = "c2", dataType = TimestampType, nullable = true) :: Nil) - val expectedBasicStats = BasicColStats(dataType = LongType, numNulls = 1, - max = None, min = None, ndv = Some(0)) - val statsSeq = Seq( - ("c1", expectedBasicStats), - ("c2", expectedBasicStats.copy(dataType = TimestampType))) - checkColStats(rdd, schema, statsSeq) + test("column-level statistics for null columns") { + val values = Seq(None, None) + val data = values.map { i => + (i.map(_.toString), i.map(_.toString.toInt)) + } + val df = data.toDF("c1", "c2") + val statsSeq = df.schema.map { f => + val colStats = f.dataType match { + case StringType => + ColumnStats( + dataType = f.dataType, + numNulls = values.count(_.isEmpty), + ndv = Some(0), + maxColLen = None, + avgColLen = None) + case IntegerType => + ColumnStats( + dataType = f.dataType, + numNulls = values.count(_.isEmpty), + max = None, + min = None, + ndv = Some(0)) + } + (f.name, colStats) + } + checkColStats(df, statsSeq) } - test("basic statistics for columns with different types") { - val rdd = sparkContext.parallelize(Seq( - Row(1, 1.01, "a", "a".getBytes, true, Date.valueOf("1970-01-01"), - Timestamp.valueOf("1970-01-01 00:00:00"), 5.toLong), - Row(2, 2.02, "bb", "bb".getBytes, false, Date.valueOf("1970-02-02"), - Timestamp.valueOf("1970-01-01 00:00:05"), 4.toLong))) - val schema = StructType(Seq( - StructField(name = "c1", dataType = IntegerType, nullable = false), - StructField(name = "c2", dataType = DoubleType, nullable = false), - StructField(name = "c3", dataType = StringType, nullable = false), - StructField(name = "c4", dataType = BinaryType, 
nullable = false), - StructField(name = "c5", dataType = BooleanType, nullable = false), - StructField(name = "c6", dataType = DateType, nullable = false), - StructField(name = "c7", dataType = TimestampType, nullable = false), - StructField(name = "c8", dataType = LongType, nullable = false))) - val statsSeq = Seq( - ("c1", BasicColStats(dataType = IntegerType, numNulls = 0, max = Some(2), min = Some(1), - ndv = Some(2))), - ("c2", BasicColStats(dataType = DoubleType, numNulls = 0, max = Some(2.02), min = Some(1.01), - ndv = Some(2))), - ("c3", BasicColStats(dataType = StringType, numNulls = 0, maxColLen = Some(2), - avgColLen = Some(1.5), ndv = Some(2))), - ("c4", BasicColStats(dataType = BinaryType, numNulls = 0, maxColLen = Some(2), - avgColLen = Some(1.5))), - ("c5", BasicColStats(dataType = BooleanType, numNulls = 0, numTrues = Some(1), - numFalses = Some(1), ndv = Some(2))), - ("c6", BasicColStats(dataType = DateType, numNulls = 0, max = Some(32), min = Some(0), - ndv = Some(2))), - ("c7", BasicColStats(dataType = TimestampType, numNulls = 0, max = Some(5000000), - min = Some(0), ndv = Some(2))), - ("c8", BasicColStats(dataType = LongType, numNulls = 0, max = Some(5), min = Some(4), - ndv = Some(2)))) - checkColStats(rdd, schema, statsSeq) + test("column-level statistics for columns with different types") { + val intSeq = Seq(1, 2) + val doubleSeq = Seq(1.01d, 2.02d) + val stringSeq = Seq("a", "bb") + val binarySeq = Seq("a", "bb").map(_.getBytes) + val booleanSeq = Seq(true, false) + val dateSeq = Seq("1970-01-01", "1970-02-02").map(Date.valueOf) + val timestampSeq = Seq("1970-01-01 00:00:00", "1970-01-01 00:00:05").map(Timestamp.valueOf) + val longSeq = Seq(5L, 4L) + + val data = intSeq.indices.map { i => + (intSeq(i), doubleSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i), dateSeq(i), + timestampSeq(i), longSeq(i)) + } + val df = data.toDF("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8") + val statsSeq = df.schema.map { f => + val colStats = f.dataType match { + case IntegerType => + ColumnStats(dataType = f.dataType, numNulls = 0, max = Some(intSeq.max), + min = Some(intSeq.min), ndv = Some(intSeq.distinct.length.toLong)) + case DoubleType => + ColumnStats(dataType = f.dataType, numNulls = 0, max = Some(doubleSeq.max), + min = Some(doubleSeq.min), ndv = Some(doubleSeq.distinct.length.toLong)) + case StringType => + ColumnStats(dataType = f.dataType, numNulls = 0, + maxColLen = Some(stringSeq.map(_.length).max.toLong), + avgColLen = Some(stringSeq.map(_.length).sum / stringSeq.length.toDouble), + ndv = Some(stringSeq.distinct.length.toLong)) + case BinaryType => + ColumnStats(dataType = f.dataType, numNulls = 0, + maxColLen = Some(binarySeq.map(_.length).max.toLong), + avgColLen = Some(binarySeq.map(_.length).sum / binarySeq.length.toDouble)) + case BooleanType => + ColumnStats(dataType = f.dataType, numNulls = 0, + numTrues = Some(booleanSeq.count(_.equals(true)).toLong), + numFalses = Some(booleanSeq.count(_.equals(false)).toLong)) + case DateType => + ColumnStats(dataType = f.dataType, numNulls = 0, + max = Some(dateSeq.map(DateTimeUtils.fromJavaDate).max), + min = Some(dateSeq.map(DateTimeUtils.fromJavaDate).min), + ndv = Some(dateSeq.distinct.length.toLong)) + case TimestampType => + ColumnStats(dataType = f.dataType, numNulls = 0, + max = Some(timestampSeq.map(DateTimeUtils.fromJavaTimestamp).max), + min = Some(timestampSeq.map(DateTimeUtils.fromJavaTimestamp).min), + ndv = Some(timestampSeq.distinct.length.toLong)) + case LongType => + ColumnStats(dataType = f.dataType, 
numNulls = 0, max = Some(longSeq.max), + min = Some(longSeq.min), ndv = Some(longSeq.distinct.length.toLong)) + } + (f.name, colStats) + } + checkColStats(df, statsSeq) } test("update table-level stats while collecting column-level stats") { val table = "tbl" val tmpTable = "tmp" withTable(table, tmpTable) { - val rdd = sparkContext.parallelize(Seq(Row(1))) - val df = spark.createDataFrame(rdd, StructType(Seq( - StructField(name = "c1", dataType = IntegerType, nullable = false)))) + val values = Seq(1) + val df = values.toDF("c1") df.write.format("json").saveAsTable(tmpTable) sql(s"CREATE TABLE $table (c1 int) STORED AS TEXTFILE") sql(s"INSERT INTO $table SELECT * FROM $tmpTable") sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") val fetchedStats1 = checkTableStats(tableName = table, isDataSourceTable = false, - hasSizeInBytes = true, expectedRowCounts = Some(1)) + hasSizeInBytes = true, expectedRowCounts = Some(values.length)) // update table between analyze table and analyze column commands sql(s"INSERT INTO $table SELECT * FROM $tmpTable") sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") val fetchedStats2 = checkTableStats(tableName = table, isDataSourceTable = false, - hasSizeInBytes = true, expectedRowCounts = Some(2)) + hasSizeInBytes = true, expectedRowCounts = Some(values.length * 2)) assert(fetchedStats2.get.sizeInBytes > fetchedStats1.get.sizeInBytes) - val basicColStats = fetchedStats2.get.basicColStats("c1") - checkColStats(colStats = basicColStats, expectedColStats = BasicColStats( - dataType = IntegerType, numNulls = 0, max = Some(1), min = Some(1), ndv = Some(1))) + val colStats = fetchedStats2.get.colStats("c1") + checkColStats(colStats = colStats, expectedColStats = ColumnStats( + dataType = IntegerType, + numNulls = 0, + max = Some(values.max), + min = Some(values.min), + ndv = Some(values.distinct.length.toLong))) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala index 2f7efa850de6..b6a95ef308f0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.hive -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.catalyst.plans.logical.{BasicColStats, Statistics} +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStats, Statistics} import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, AnalyzeTableCommand} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -80,21 +79,19 @@ trait StatisticsTest extends QueryTest with TestHiveSingleton with SQLTestUtils } def checkColStats( - rowRDD: RDD[Row], - schema: StructType, - expectedColStatsSeq: Seq[(String, BasicColStats)]): Unit = { + df: DataFrame, + expectedColStatsSeq: Seq[(String, ColumnStats)]): Unit = { val table = "tbl" withTable(table) { - var df = spark.createDataFrame(rowRDD, schema) df.write.format("json").saveAsTable(table) val columns = expectedColStatsSeq.map(_._1).mkString(", ") sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS $columns") - df = sql(s"SELECT * FROM $table") - val stats = df.queryExecution.analyzed.collect { + val readback = sql(s"SELECT * FROM $table") + val stats = readback.queryExecution.analyzed.collect { case rel: 
LogicalRelation => expectedColStatsSeq.foreach { expected => - assert(rel.catalogTable.get.stats.get.basicColStats.contains(expected._1)) - checkColStats(colStats = rel.catalogTable.get.stats.get.basicColStats(expected._1), + assert(rel.catalogTable.get.stats.get.colStats.contains(expected._1)) + checkColStats(colStats = rel.catalogTable.get.stats.get.colStats(expected._1), expectedColStats = expected._2) } } @@ -102,39 +99,36 @@ trait StatisticsTest extends QueryTest with TestHiveSingleton with SQLTestUtils } } - def checkColStats(colStats: BasicColStats, expectedColStats: BasicColStats): Unit = { + def checkColStats(colStats: ColumnStats, expectedColStats: ColumnStats): Unit = { assert(colStats.dataType == expectedColStats.dataType) assert(colStats.numNulls == expectedColStats.numNulls) colStats.dataType match { - case ByteType | ShortType | IntegerType | LongType => + case _: IntegralType | DateType | TimestampType => assert(colStats.max.map(_.toString.toLong) == expectedColStats.max.map(_.toString.toLong)) assert(colStats.min.map(_.toString.toLong) == expectedColStats.min.map(_.toString.toLong)) - case FloatType | DoubleType => - assert(colStats.max.map(_.toString.toDouble) == expectedColStats.max - .map(_.toString.toDouble)) - assert(colStats.min.map(_.toString.toDouble) == expectedColStats.min - .map(_.toString.toDouble)) - case DecimalType.SYSTEM_DEFAULT => - assert(colStats.max.map(i => Decimal(i.toString)) == expectedColStats.max - .map(i => Decimal(i.toString))) - assert(colStats.min.map(i => Decimal(i.toString)) == expectedColStats.min - .map(i => Decimal(i.toString))) - case DateType | TimestampType => - if (expectedColStats.max.isDefined) { - // just check the difference to exclude the influence of timezones - assert(colStats.max.get.toString.toLong - colStats.min.get.toString.toLong == - expectedColStats.max.get.toString.toLong - expectedColStats.min.get.toString.toLong) - } else { - assert(colStats.max.isEmpty && colStats.min.isEmpty) - } - case _ => // only numeric types, date type and timestamp type have max and min stats + case _: FractionalType => + assert(colStats.max.map(_.toString.toDouble) == expectedColStats + .max.map(_.toString.toDouble)) + assert(colStats.min.map(_.toString.toDouble) == expectedColStats + .min.map(_.toString.toDouble)) + case _ => + // other types don't have max and min stats + assert(colStats.max.isEmpty) + assert(colStats.min.isEmpty) } colStats.dataType match { - case BinaryType => assert(colStats.ndv.isEmpty) - case BooleanType => assert(colStats.ndv.contains(2)) + case BinaryType | BooleanType => assert(colStats.ndv.isEmpty) case _ => - // ndv is an approximate value, so we just make sure we have the value + // ndv is an approximate value, so we make sure we have the value, and it should be + // within 3*SD's of the given rsd. assert(colStats.ndv.get >= 0) + if (expectedColStats.ndv.get == 0) { + assert(colStats.ndv.get == 0) + } else if (expectedColStats.ndv.get > 0) { + val rsd = spark.sessionState.conf.ndvMaxError + val error = math.abs((colStats.ndv.get / expectedColStats.ndv.get.toDouble) - 1.0d) + assert(error <= rsd * 3.0d, "Error should be within 3 std. 
errors.") + } } assert(colStats.avgColLen == expectedColStats.avgColLen) assert(colStats.maxColLen == expectedColStats.maxColLen) From b5cd2ff59d98b748f222b564f68aeb9d2c292c7c Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Mon, 19 Sep 2016 21:25:25 -0700 Subject: [PATCH 07/22] support analyze column stats independently --- .../command/AnalyzeColumnCommand.scala | 6 +++-- .../sql/hive/StatisticsColumnSuite.scala | 22 +++++++++++++++++++ .../spark/sql/hive/StatisticsTest.scala | 2 +- 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index a4d2234f4dbd..04ab1dc53825 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -87,8 +87,10 @@ case class AnalyzeColumnCommand( (expr.name, ColumnStatsStruct.unwrapStruct(statsRow, i + 1, expr)) }.toMap - val statistics = - Statistics(sizeInBytes = newTotalSize, rowCount = Some(rowCount), colStats = columnStats) + val statistics = Statistics( + sizeInBytes = newTotalSize, + rowCount = Some(rowCount), + colStats = columnStats ++ catalogTable.stats.map(_.colStats).getOrElse(Map())) sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdent) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala index 61920ff2076d..6ef44752e34a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala @@ -318,4 +318,26 @@ class StatisticsColumnSuite extends StatisticsTest { ndv = Some(values.distinct.length.toLong))) } } + + test("analyze column stats independently") { + val table = "tbl" + withTable(table) { + sql(s"CREATE TABLE $table (c1 int, c2 long) STORED AS TEXTFILE") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") + val fetchedStats1 = checkTableStats(tableName = table, isDataSourceTable = false, + hasSizeInBytes = false, expectedRowCounts = Some(0)) + assert(fetchedStats1.get.colStats.size == 1) + val expected1 = ColumnStats(dataType = IntegerType, numNulls = 0, ndv = Some(0L)) + checkColStats(colStats = fetchedStats1.get.colStats("c1"), expectedColStats = expected1) + + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") + val fetchedStats2 = checkTableStats(tableName = table, isDataSourceTable = false, + hasSizeInBytes = false, expectedRowCounts = Some(0)) + // column c1 is kept in the stats + assert(fetchedStats2.get.colStats.size == 2) + checkColStats(colStats = fetchedStats2.get.colStats("c1"), expectedColStats = expected1) + val expected2 = ColumnStats(dataType = LongType, numNulls = 0, ndv = Some(0L)) + checkColStats(colStats = fetchedStats2.get.colStats("c2"), expectedColStats = expected2) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala index b6a95ef308f0..666d6d207384 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala @@ 
-51,7 +51,7 @@ trait StatisticsTest extends QueryTest with TestHiveSingleton with SQLTestUtils expectedRowCounts: Option[Int]): Unit = { if (hasSizeInBytes || expectedRowCounts.nonEmpty) { assert(stats.isDefined) - assert(stats.get.sizeInBytes > 0) + assert(stats.get.sizeInBytes >= 0) assert(stats.get.rowCount === expectedRowCounts) } else { assert(stats.isEmpty) From f279370e62bacdc17a17025a1ce47fc1ebbe775e Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Tue, 20 Sep 2016 16:55:26 -0700 Subject: [PATCH 08/22] fix comments --- .../spark/sql/execution/SparkSqlParser.scala | 9 ++++ .../command/AnalyzeColumnCommand.scala | 20 ++++++-- .../command/AnalyzeTableCommand.scala | 10 ++-- .../sql/hive/StatisticsColumnSuite.scala | 47 ++++++++++--------- .../spark/sql/hive/StatisticsTest.scala | 3 +- 5 files changed, 56 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 50133693aa48..b8adc969a25e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -88,6 +88,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { /** * Create an [[AnalyzeTableCommand]] command or an [[AnalyzeColumnCommand]] command. + * Example SQL for analyzing table : + * {{{ + * ANALYZE TABLE table COMPUTE STATISTICS NOSCAN; + * ANALYZE TABLE table COMPUTE STATISTICS; + * }}} + * Example SQL for analyzing columns : + * {{{ + * ANALYZE TABLE table COMPUTE STATISTICS FOR COLUMNS column1, column2; + * }}} */ override def visitAnalyze(ctx: AnalyzeContext): LogicalPlan = withOrigin(ctx) { if (ctx.partitionSpec == null && diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 04ab1dc53825..b1183b9cb03a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -40,7 +40,9 @@ case class AnalyzeColumnCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val sessionState = sparkSession.sessionState - val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdent)) + val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) + val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) + val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) // check correctness of column names val attributesToAnalyze = mutable.MutableList[Attribute]() @@ -59,7 +61,7 @@ case class AnalyzeColumnCommand( relation match { case catalogRel: CatalogRelation => updateStats(catalogRel.catalogTable, - AnalyzeTableCommand.calculateTotalSize(sparkSession, catalogRel.catalogTable)) + AnalyzeTableCommand.calculateTotalSize(sessionState, catalogRel.catalogTable)) case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => updateStats(logicalRel.catalogTable.get, logicalRel.relation.sizeInBytes) @@ -93,7 +95,7 @@ case class AnalyzeColumnCommand( colStats = columnStats ++ catalogTable.stats.map(_.colStats).getOrElse(Map())) sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) // Refresh the cached data source table in the catalog. 
- sessionState.catalog.refreshTable(tableIdent) + sessionState.catalog.refreshTable(tableIdentWithDB) } Seq.empty[Row] @@ -111,6 +113,18 @@ object ColumnStatsStruct { val statsNumber = 8 def apply(e: NamedExpression, relativeSD: Double): CreateStruct = { + // Use aggregate functions to compute statistics we need: + // - number of nulls: Sum(If(IsNull(e), one, zero)); + // - maximum value: Max(e); + // - minimum value: Min(e); + // - ndv (number of distinct values): HyperLogLogPlusPlus(e, relativeSD); + // - average length of values: Average(Length(e)); + // - maximum length of values: Max(Length(e)); + // - number of true values: Sum(If(e, one, zero)); + // - number of false values: Sum(If(Not(e), one, zero)); + // - If we don't need some statistic for the data type, use null literal. + // Note that: the order of each sequence must be as follows: + // numNulls, max, min, ndv, avgColLen, maxColLen, numTrues, numFalses var statistics = e.dataType match { case _: NumericType | TimestampType | DateType => Seq(Max(e), Min(e), HyperLogLogPlusPlus(e, relativeSD), nullDouble, nullLong, nullLong, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 1db6af2ac5a5..96c0d9d14770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.internal.SessionState /** @@ -47,7 +48,7 @@ case class AnalyzeTableCommand( relation match { case relation: CatalogRelation => updateTableStats(relation.catalogTable, - AnalyzeTableCommand.calculateTotalSize(sparkSession, relation.catalogTable)) + AnalyzeTableCommand.calculateTotalSize(sessionState, relation.catalogTable)) // data source tables have been converted into LogicalRelations case logicalRel: LogicalRelation if logicalRel.catalogTable.isDefined => @@ -94,7 +95,7 @@ case class AnalyzeTableCommand( object AnalyzeTableCommand extends Logging { - def calculateTotalSize(sparkSession: SparkSession, catalogTable: CatalogTable): Long = { + def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): Long = { // This method is mainly based on // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) // in Hive 0.13 (except that we do not use fs.getContentSummary). @@ -103,8 +104,7 @@ object AnalyzeTableCommand extends Logging { // Can we use fs.getContentSummary in future? // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use // countFileSize to count the table size. 
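    // An illustrative sketch (assumed directory layout, not from this patch): for a table stored as
    // {{{
    //   /warehouse/tbl/part-00000          <- counted
    //   /warehouse/tbl/.hive-staging_.../  <- skipped, matches the staging-dir prefix read below
    // }}}
    // only the data files are expected to contribute to the reported total size.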
- val stagingDir = - sparkSession.sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") + val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") def calculateTableSize(fs: FileSystem, path: Path): Long = { val fileStatus = fs.getFileStatus(path) @@ -127,7 +127,7 @@ object AnalyzeTableCommand extends Logging { catalogTable.storage.locationUri.map { p => val path = new Path(p) try { - val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf()) + val fs = path.getFileSystem(sessionState.newHadoopConf()) calculateTableSize(fs, path) } catch { case NonFatal(e) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala index 6ef44752e34a..944de02f090f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import java.sql.{Date, Timestamp} import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.ColumnStats import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.command.AnalyzeColumnCommand @@ -34,7 +35,7 @@ class StatisticsColumnSuite extends StatisticsTest { s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key, value", classOf[AnalyzeColumnCommand]) - intercept[AnalysisException] { + intercept[ParseException] { sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS") } } @@ -49,13 +50,13 @@ class StatisticsColumnSuite extends StatisticsTest { val invalidColError = intercept[AnalysisException] { sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key") } - assert(invalidColError.message == s"Invalid column name: key.") + assert(invalidColError.message == "Invalid column name: key.") withSQLConf("spark.sql.caseSensitive" -> "true") { val invalidErr = intercept[AnalysisException] { sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS ABC") } - assert(invalidErr.message == s"Invalid column name: ABC.") + assert(invalidErr.message == "Invalid column name: ABC.") } withSQLConf("spark.sql.caseSensitive" -> "false") { @@ -89,7 +90,7 @@ class StatisticsColumnSuite extends StatisticsTest { val df = data.toDF("c1", "c2", "c3", "c4") val nonNullValues = getNonNullValues[Int](values) - val statsSeq = df.schema.map { f => + val expectedColStatsSeq = df.schema.map { f => val colStats = ColumnStats( dataType = f.dataType, numNulls = values.count(_.isEmpty), @@ -98,7 +99,7 @@ class StatisticsColumnSuite extends StatisticsTest { ndv = Some(nonNullValues.distinct.length.toLong)) (f.name, colStats) } - checkColStats(df, statsSeq) + checkColStats(df, expectedColStatsSeq) } test("column-level statistics for fractional type columns") { @@ -111,7 +112,7 @@ class StatisticsColumnSuite extends StatisticsTest { val df = data.toDF("c1", "c2", "c3") val nonNullValues = getNonNullValues[Double](values) - val statsSeq = df.schema.map { f => + val expectedColStatsSeq = df.schema.map { f => val colStats = ColumnStats( dataType = f.dataType, numNulls = values.count(_.isEmpty), @@ -120,14 +121,14 @@ class StatisticsColumnSuite extends StatisticsTest { ndv = Some(nonNullValues.distinct.length.toLong)) (f.name, colStats) } - checkColStats(df, statsSeq) + checkColStats(df, expectedColStatsSeq) } test("column-level statistics for 
string column") { - val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc")) + val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some("")) val df = values.toDF("c1") val nonNullValues = getNonNullValues[String](values) - val statsSeq = df.schema.map { f => + val expectedColStatsSeq = df.schema.map { f => val colStats = ColumnStats( dataType = f.dataType, numNulls = values.count(_.isEmpty), @@ -136,14 +137,14 @@ class StatisticsColumnSuite extends StatisticsTest { avgColLen = Some(nonNullValues.map(_.length).sum / nonNullValues.length.toDouble)) (f.name, colStats) } - checkColStats(df, statsSeq) + checkColStats(df, expectedColStatsSeq) } test("column-level statistics for binary column") { - val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc")).map(_.map(_.getBytes)) + val values = Seq(None, Some("a"), Some("bbbb"), Some("cccc"), Some("")).map(_.map(_.getBytes)) val df = values.toDF("c1") val nonNullValues = getNonNullValues[Array[Byte]](values) - val statsSeq = df.schema.map { f => + val expectedColStatsSeq = df.schema.map { f => val colStats = ColumnStats( dataType = f.dataType, numNulls = values.count(_.isEmpty), @@ -151,14 +152,14 @@ class StatisticsColumnSuite extends StatisticsTest { avgColLen = Some(nonNullValues.map(_.length).sum / nonNullValues.length.toDouble)) (f.name, colStats) } - checkColStats(df, statsSeq) + checkColStats(df, expectedColStatsSeq) } test("column-level statistics for boolean column") { val values = Seq(None, Some(true), Some(false), Some(true)) val df = values.toDF("c1") val nonNullValues = getNonNullValues[Boolean](values) - val statsSeq = df.schema.map { f => + val expectedColStatsSeq = df.schema.map { f => val colStats = ColumnStats( dataType = f.dataType, numNulls = values.count(_.isEmpty), @@ -166,14 +167,14 @@ class StatisticsColumnSuite extends StatisticsTest { numFalses = Some(nonNullValues.count(_.equals(false)).toLong)) (f.name, colStats) } - checkColStats(df, statsSeq) + checkColStats(df, expectedColStatsSeq) } test("column-level statistics for date column") { val values = Seq(None, Some("1970-01-01"), Some("1970-02-02")).map(_.map(Date.valueOf)) val df = values.toDF("c1") val nonNullValues = getNonNullValues[Date](values) - val statsSeq = df.schema.map { f => + val expectedColStatsSeq = df.schema.map { f => val colStats = ColumnStats( dataType = f.dataType, numNulls = values.count(_.isEmpty), @@ -183,7 +184,7 @@ class StatisticsColumnSuite extends StatisticsTest { ndv = Some(nonNullValues.distinct.length.toLong)) (f.name, colStats) } - checkColStats(df, statsSeq) + checkColStats(df, expectedColStatsSeq) } test("column-level statistics for timestamp column") { @@ -192,7 +193,7 @@ class StatisticsColumnSuite extends StatisticsTest { } val df = values.toDF("c1") val nonNullValues = getNonNullValues[Timestamp](values) - val statsSeq = df.schema.map { f => + val expectedColStatsSeq = df.schema.map { f => val colStats = ColumnStats( dataType = f.dataType, numNulls = values.count(_.isEmpty), @@ -202,7 +203,7 @@ class StatisticsColumnSuite extends StatisticsTest { ndv = Some(nonNullValues.distinct.length.toLong)) (f.name, colStats) } - checkColStats(df, statsSeq) + checkColStats(df, expectedColStatsSeq) } test("column-level statistics for null columns") { @@ -211,7 +212,7 @@ class StatisticsColumnSuite extends StatisticsTest { (i.map(_.toString), i.map(_.toString.toInt)) } val df = data.toDF("c1", "c2") - val statsSeq = df.schema.map { f => + val expectedColStatsSeq = df.schema.map { f => val colStats = f.dataType match { case 
StringType => ColumnStats( @@ -230,7 +231,7 @@ class StatisticsColumnSuite extends StatisticsTest { } (f.name, colStats) } - checkColStats(df, statsSeq) + checkColStats(df, expectedColStatsSeq) } test("column-level statistics for columns with different types") { @@ -248,7 +249,7 @@ class StatisticsColumnSuite extends StatisticsTest { timestampSeq(i), longSeq(i)) } val df = data.toDF("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8") - val statsSeq = df.schema.map { f => + val expectedColStatsSeq = df.schema.map { f => val colStats = f.dataType match { case IntegerType => ColumnStats(dataType = f.dataType, numNulls = 0, max = Some(intSeq.max), @@ -285,7 +286,7 @@ class StatisticsColumnSuite extends StatisticsTest { } (f.name, colStats) } - checkColStats(df, statsSeq) + checkColStats(df, expectedColStatsSeq) } test("update table-level stats while collecting column-level stats") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala index 666d6d207384..e667d8db3ee9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala @@ -32,7 +32,6 @@ trait StatisticsTest extends QueryTest with TestHiveSingleton with SQLTestUtils val operators = parsed.collect { case a: AnalyzeTableCommand => a case b: AnalyzeColumnCommand => b - case o => o } assert(operators.size === 1) @@ -41,7 +40,7 @@ trait StatisticsTest extends QueryTest with TestHiveSingleton with SQLTestUtils s"""$analyzeCommand expected command: $c, but got ${operators(0)} |parsed command: |$parsed - """.stripMargin) + """.stripMargin) } } From 12d94cd483b3058b61f93b2337663d2024aa0c6b Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Wed, 21 Sep 2016 16:39:35 -0700 Subject: [PATCH 09/22] test column stats without depending on the catalog module --- .../command/AnalyzeColumnCommand.scala | 71 +++++++++------- .../spark/sql/internal/SessionState.scala | 11 ++- .../spark/sql}/StatisticsColumnSuite.scala | 72 +++++++++------- .../apache/spark/sql/StatisticsSuite.scala | 61 +------------- .../apache/spark/sql}/StatisticsTest.scala | 84 ++++--------------- .../spark/sql/hive/StatisticsSuite.scala | 58 ++++++++++++- 6 files changed, 165 insertions(+), 192 deletions(-) rename sql/{hive/src/test/scala/org/apache/spark/sql/hive => core/src/test/scala/org/apache/spark/sql}/StatisticsColumnSuite.scala (86%) rename sql/{hive/src/test/scala/org/apache/spark/sql/hive => core/src/test/scala/org/apache/spark/sql}/StatisticsTest.scala (54%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index b1183b9cb03a..8169a4efa86c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, ColumnStats, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, ColumnStats, LogicalPlan, Statistics} import 
org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.types._ @@ -44,20 +44,6 @@ case class AnalyzeColumnCommand( val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) - // check correctness of column names - val attributesToAnalyze = mutable.MutableList[Attribute]() - val caseSensitive = sessionState.conf.caseSensitiveAnalysis - columnNames.foreach { col => - val exprOption = relation.output.find { attr => - if (caseSensitive) attr.name == col else attr.name.equalsIgnoreCase(col) - } - val expr = exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col.")) - // do deduplication - if (!attributesToAnalyze.contains(expr)) { - attributesToAnalyze += expr - } - } - relation match { case catalogRel: CatalogRelation => updateStats(catalogRel.catalogTable, @@ -72,23 +58,7 @@ case class AnalyzeColumnCommand( } def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = { - // Collect statistics per column. - // The first element in the result will be the overall row count, the following elements - // will be structs containing all column stats. - // The layout of each struct follows the layout of the ColumnStats. - val ndvMaxErr = sessionState.conf.ndvMaxError - val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(ColumnStatsStruct(_, ndvMaxErr)) - val namedExpressions = expressions.map(e => Alias(e, e.toString)()) - val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)) - .queryExecution.toRdd.collect().head - - // unwrap the result - val rowCount = statsRow.getLong(0) - val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => - (expr.name, ColumnStatsStruct.unwrapStruct(statsRow, i + 1, expr)) - }.toMap - + val (rowCount, columnStats) = computeColStats(sparkSession, relation) val statistics = Statistics( sizeInBytes = newTotalSize, rowCount = Some(rowCount), @@ -100,6 +70,43 @@ case class AnalyzeColumnCommand( Seq.empty[Row] } + + def computeColStats( + sparkSession: SparkSession, + relation: LogicalPlan): (Long, Map[String, ColumnStats]) = { + + // check correctness of column names + val attributesToAnalyze = mutable.MutableList[Attribute]() + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + columnNames.foreach { col => + val exprOption = relation.output.find { attr => + if (caseSensitive) attr.name == col else attr.name.equalsIgnoreCase(col) + } + val expr = exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col.")) + // do deduplication + if (!attributesToAnalyze.contains(expr)) { + attributesToAnalyze += expr + } + } + + // Collect statistics per column. + // The first element in the result will be the overall row count, the following elements + // will be structs containing all column stats. + // The layout of each struct follows the layout of the ColumnStats. 
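    // A rough illustration of the collected row (editorial sketch; the values are assumptions,
    // not output of this patch). Analyzing columns (i: Int, j: String) over the two rows
    // (1, "a") and (2, "b") should give a single row shaped like
    // {{{
    //   [ 2,                                         // overall row count
    //     [0, 2, 1, ndv_i, null, null, null, null],  // i: numNulls, max, min, ndv, ...
    //     [0, null, null, ndv_j, 1.0, 1, null, null] // j: string columns fill avg/max length
    //   ]
    // }}}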
+ val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError + val expressions = Count(Literal(1)).toAggregateExpression() +: + attributesToAnalyze.map(ColumnStatsStruct(_, ndvMaxErr)) + val namedExpressions = expressions.map(e => Alias(e, e.toString)()) + val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)) + .queryExecution.toRdd.collect().head + + // unwrap the result + val rowCount = statsRow.getLong(0) + val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => + (expr.name, ColumnStatsStruct.unwrapStruct(statsRow, i + 1, expr)) + }.toMap + (rowCount, columnStats) + } } object ColumnStatsStruct { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index ded8e16d52f0..f4b3a0be2211 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStats, LogicalPlan} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, AnalyzeTableCommand} import org.apache.spark.sql.execution.datasources._ @@ -201,4 +201,13 @@ private[sql] class SessionState(sparkSession: SparkSession) { def analyzeTableColumns(tableIdent: TableIdentifier, columnNames: Seq[String]): Unit = { AnalyzeColumnCommand(tableIdent, columnNames).run(sparkSession) } + + // This api is used for testing. + def computeColumnStats(tableName: String, columnNames: Seq[String]): Map[String, ColumnStats] = { + val tableIdent = sqlParser.parseTableIdentifier(tableName) + val db = tableIdent.database.getOrElse(catalog.getCurrentDatabase) + val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) + val relation = sparkSession.sessionState.catalog.lookupRelation(tableIdentWithDB) + AnalyzeColumnCommand(tableIdent, columnNames).computeColStats(sparkSession, relation)._2 + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala similarity index 86% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala index 944de02f090f..1ceda0824844 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -15,11 +15,10 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.hive +package org.apache.spark.sql import java.sql.{Date, Timestamp} -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.ColumnStats import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -30,8 +29,24 @@ class StatisticsColumnSuite extends StatisticsTest { import testImplicits._ test("parse analyze column commands") { + def assertAnalyzeColumnCommand(analyzeCommand: String, c: Class[_]) { + val parsed = spark.sessionState.sqlParser.parsePlan(analyzeCommand) + val operators = parsed.collect { + case a: AnalyzeColumnCommand => a + case o => o + } + assert(operators.size == 1) + if (operators.head.getClass != c) { + fail( + s"""$analyzeCommand expected command: $c, but got ${operators.head} + |parsed command: + |$parsed + """.stripMargin) + } + } + val table = "table" - assertAnalyzeCommand( + assertAnalyzeColumnCommand( s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key, value", classOf[AnalyzeColumnCommand]) @@ -42,10 +57,11 @@ class StatisticsColumnSuite extends StatisticsTest { test("check correctness of columns") { val table = "tbl" - val quotedColumn = "x.yz" - val quotedName = s"`$quotedColumn`" + val colName1 = "abc" + val colName2 = "x.yz" + val quotedColName2 = s"`$colName2`" withTable(table) { - sql(s"CREATE TABLE $table (abc int, $quotedName string)") + sql(s"CREATE TABLE $table ($colName1 int, $quotedColName2 string) USING PARQUET") val invalidColError = intercept[AnalysisException] { sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key") @@ -54,24 +70,19 @@ class StatisticsColumnSuite extends StatisticsTest { withSQLConf("spark.sql.caseSensitive" -> "true") { val invalidErr = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS ABC") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS ${colName1.toUpperCase}") } - assert(invalidErr.message == "Invalid column name: ABC.") + assert(invalidErr.message == s"Invalid column name: ${colName1.toUpperCase}.") } withSQLConf("spark.sql.caseSensitive" -> "false") { - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS ${quotedName.toUpperCase}, " + - s"ABC, $quotedName") - val df = sql(s"SELECT * FROM $table") - val stats = df.queryExecution.analyzed.collect { - case rel: MetastoreRelation => - val colStats = rel.catalogTable.stats.get.colStats - // check deduplication - assert(colStats.size == 2) - assert(colStats.contains(quotedColumn)) - assert(colStats.contains("abc")) - } - assert(stats.size == 1) + val columnsToAnalyze = Seq(colName2.toUpperCase, colName1, colName2) + val columnStats = spark.sessionState.computeColumnStats(table, columnsToAnalyze) + assert(columnStats.contains(colName1)) + assert(columnStats.contains(colName2)) + // check deduplication + assert(columnStats.size == 2) + assert(!columnStats.contains(colName2.toUpperCase)) } } } @@ -297,18 +308,17 @@ class StatisticsColumnSuite extends StatisticsTest { val df = values.toDF("c1") df.write.format("json").saveAsTable(tmpTable) - sql(s"CREATE TABLE $table (c1 int) STORED AS TEXTFILE") + sql(s"CREATE TABLE $table (c1 int) USING PARQUET") sql(s"INSERT INTO $table SELECT * FROM $tmpTable") sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") - val fetchedStats1 = checkTableStats(tableName = table, isDataSourceTable = false, - hasSizeInBytes = true, expectedRowCounts = Some(values.length)) + val fetchedStats1 = + checkTableStats(tableName = table, expectedRowCount = 
Some(values.length)) - // update table between analyze table and analyze column commands + // update table-level stats between analyze table and analyze column commands sql(s"INSERT INTO $table SELECT * FROM $tmpTable") sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") - val fetchedStats2 = checkTableStats(tableName = table, isDataSourceTable = false, - hasSizeInBytes = true, expectedRowCounts = Some(values.length * 2)) - assert(fetchedStats2.get.sizeInBytes > fetchedStats1.get.sizeInBytes) + val fetchedStats2 = + checkTableStats(tableName = table, expectedRowCount = Some(values.length * 2)) val colStats = fetchedStats2.get.colStats("c1") checkColStats(colStats = colStats, expectedColStats = ColumnStats( @@ -323,17 +333,15 @@ class StatisticsColumnSuite extends StatisticsTest { test("analyze column stats independently") { val table = "tbl" withTable(table) { - sql(s"CREATE TABLE $table (c1 int, c2 long) STORED AS TEXTFILE") + sql(s"CREATE TABLE $table (c1 int, c2 long) USING PARQUET") sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") - val fetchedStats1 = checkTableStats(tableName = table, isDataSourceTable = false, - hasSizeInBytes = false, expectedRowCounts = Some(0)) + val fetchedStats1 = checkTableStats(tableName = table, expectedRowCount = Some(0)) assert(fetchedStats1.get.colStats.size == 1) val expected1 = ColumnStats(dataType = IntegerType, numNulls = 0, ndv = Some(0L)) checkColStats(colStats = fetchedStats1.get.colStats("c1"), expectedColStats = expected1) sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") - val fetchedStats2 = checkTableStats(tableName = table, isDataSourceTable = false, - hasSizeInBytes = false, expectedRowCounts = Some(0)) + val fetchedStats2 = checkTableStats(tableName = table, expectedRowCount = Some(0)) // column c1 is kept in the stats assert(fetchedStats2.get.colStats.size == 2) checkColStats(colStats = fetchedStats2.get.colStats("c1"), expectedColStats = expected1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala index 26eabee1d31e..8cf42e9248c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStats, GlobalLimit, Join, LocalLimit} -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, Join, LocalLimit} import org.apache.spark.sql.types._ -class StatisticsSuite extends QueryTest with SharedSQLContext { +class StatisticsSuite extends StatisticsTest { import testImplicits._ test("SPARK-15392: DataFrame created from RDD should not be broadcasted") { @@ -77,20 +75,10 @@ class StatisticsSuite extends QueryTest with SharedSQLContext { } test("test table-level statistics for data source table created in InMemoryCatalog") { - def checkTableStats(tableName: String, expectedRowCount: Option[BigInt]): Unit = { - val df = sql(s"SELECT * FROM $tableName") - val relations = df.queryExecution.analyzed.collect { case rel: LogicalRelation => - assert(rel.catalogTable.isDefined) - assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) - rel - } - assert(relations.size === 1) - } - val tableName = "tbl" withTable(tableName) { sql(s"CREATE TABLE $tableName(i INT, j STRING) USING 
parquet") - Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto("tbl") + Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto(tableName) // noscan won't count the number of rows sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") @@ -101,47 +89,4 @@ class StatisticsSuite extends QueryTest with SharedSQLContext { checkTableStats(tableName, expectedRowCount = Some(2)) } } - - test("test column-level statistics for data source table created in InMemoryCatalog") { - def checkColStats(colStats: ColumnStats, expectedColStats: ColumnStats): Unit = { - assert(colStats.dataType == expectedColStats.dataType) - assert(colStats.numNulls == expectedColStats.numNulls) - assert(colStats.max == expectedColStats.max) - assert(colStats.min == expectedColStats.min) - if (expectedColStats.ndv.isDefined) { - // ndv is an approximate value, so we just make sure we have the value - assert(colStats.ndv.get >= 0) - } - assert(colStats.avgColLen == expectedColStats.avgColLen) - assert(colStats.maxColLen == expectedColStats.maxColLen) - assert(colStats.numTrues == expectedColStats.numTrues) - assert(colStats.numFalses == expectedColStats.numFalses) - } - - val tableName = "tbl" - withTable(tableName) { - sql(s"CREATE TABLE $tableName(i INT, j STRING) USING parquet") - Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto("tbl") - - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS i, j") - val df = sql(s"SELECT * FROM $tableName") - val expectedRowCount = Some(2) - val expectedColStatsSeq: Seq[(String, ColumnStats)] = Seq( - ("i", ColumnStats(dataType = IntegerType, numNulls = 0, max = Some(2), min = Some(1), - ndv = Some(2))), - ("j", ColumnStats(dataType = StringType, numNulls = 0, maxColLen = Some(1), - avgColLen = Some(1), ndv = Some(2)))) - val relations = df.queryExecution.analyzed.collect { case rel: LogicalRelation => - val stats = rel.catalogTable.get.stats.get - assert(stats.rowCount == expectedRowCount) - expectedColStatsSeq.foreach { case (column, expectedColStats) => - assert(stats.colStats.contains(column)) - checkColStats(colStats = stats.colStats(column), expectedColStats = expectedColStats) - } - rel - } - assert(relations.size == 1) - } - } - } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala similarity index 54% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala rename to sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala index e667d8db3ee9..a28f094a73a3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala @@ -15,67 +15,14 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.hive +package org.apache.spark.sql -import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStats, Statistics} -import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, AnalyzeTableCommand} import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -trait StatisticsTest extends QueryTest with TestHiveSingleton with SQLTestUtils { - - def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { - val parsed = spark.sessionState.sqlParser.parsePlan(analyzeCommand) - val operators = parsed.collect { - case a: AnalyzeTableCommand => a - case b: AnalyzeColumnCommand => b - } - - assert(operators.size === 1) - if (operators(0).getClass() != c) { - fail( - s"""$analyzeCommand expected command: $c, but got ${operators(0)} - |parsed command: - |$parsed - """.stripMargin) - } - } - - def checkTableStats( - stats: Option[Statistics], - hasSizeInBytes: Boolean, - expectedRowCounts: Option[Int]): Unit = { - if (hasSizeInBytes || expectedRowCounts.nonEmpty) { - assert(stats.isDefined) - assert(stats.get.sizeInBytes >= 0) - assert(stats.get.rowCount === expectedRowCounts) - } else { - assert(stats.isEmpty) - } - } - - def checkTableStats( - tableName: String, - isDataSourceTable: Boolean, - hasSizeInBytes: Boolean, - expectedRowCounts: Option[Int]): Option[Statistics] = { - val df = sql(s"SELECT * FROM $tableName") - val stats = df.queryExecution.analyzed.collect { - case rel: MetastoreRelation => - checkTableStats(rel.catalogTable.stats, hasSizeInBytes, expectedRowCounts) - assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table") - rel.catalogTable.stats - case rel: LogicalRelation => - checkTableStats(rel.catalogTable.get.stats, hasSizeInBytes, expectedRowCounts) - assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table") - rel.catalogTable.get.stats - } - assert(stats.size == 1) - stats.head - } +trait StatisticsTest extends QueryTest with SharedSQLContext { def checkColStats( df: DataFrame, @@ -83,18 +30,12 @@ trait StatisticsTest extends QueryTest with TestHiveSingleton with SQLTestUtils val table = "tbl" withTable(table) { df.write.format("json").saveAsTable(table) - val columns = expectedColStatsSeq.map(_._1).mkString(", ") - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS $columns") - val readback = sql(s"SELECT * FROM $table") - val stats = readback.queryExecution.analyzed.collect { - case rel: LogicalRelation => - expectedColStatsSeq.foreach { expected => - assert(rel.catalogTable.get.stats.get.colStats.contains(expected._1)) - checkColStats(colStats = rel.catalogTable.get.stats.get.colStats(expected._1), - expectedColStats = expected._2) - } + val columns = expectedColStatsSeq.map(_._1) + val columnStats = spark.sessionState.computeColumnStats(table, columns) + expectedColStatsSeq.foreach { expected => + assert(columnStats.contains(expected._1)) + checkColStats(colStats = columnStats(expected._1), expectedColStats = expected._2) } - assert(stats.size == 1) } } @@ -135,4 +76,13 @@ trait StatisticsTest extends QueryTest with TestHiveSingleton with SQLTestUtils assert(colStats.numFalses == expectedColStats.numFalses) } + def checkTableStats(tableName: String, expectedRowCount: Option[Int]): Option[Statistics] = { + val df = sql(s"SELECT 
* FROM $tableName") + val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation => + assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) + rel.catalogTable.get.stats + } + assert(stats.size == 1) + stats.head + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index f83d797b3ccb..e275aa5add99 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -21,16 +21,37 @@ import java.io.{File, PrintWriter} import scala.reflect.ClassTag -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils} +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.StructType -class StatisticsSuite extends StatisticsTest { +class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { test("parse analyze commands") { + def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { + val parsed = spark.sessionState.sqlParser.parsePlan(analyzeCommand) + val operators = parsed.collect { + case a: AnalyzeTableCommand => a + case o => o + } + + assert(operators.size === 1) + if (operators(0).getClass() != c) { + fail( + s"""$analyzeCommand expected command: $c, but got ${operators(0)} + |parsed command: + |$parsed + """.stripMargin) + } + } + assertAnalyzeCommand( "ANALYZE TABLE Table1 COMPUTE STATISTICS", classOf[AnalyzeTableCommand]) @@ -150,6 +171,39 @@ class StatisticsSuite extends StatisticsTest { TableIdentifier("tempTable"), ignoreIfNotExists = true, purge = false) } + private def checkTableStats( + stats: Option[Statistics], + hasSizeInBytes: Boolean, + expectedRowCounts: Option[Int]): Unit = { + if (hasSizeInBytes || expectedRowCounts.nonEmpty) { + assert(stats.isDefined) + assert(stats.get.sizeInBytes > 0) + assert(stats.get.rowCount === expectedRowCounts) + } else { + assert(stats.isEmpty) + } + } + + private def checkTableStats( + tableName: String, + isDataSourceTable: Boolean, + hasSizeInBytes: Boolean, + expectedRowCounts: Option[Int]): Option[Statistics] = { + val df = sql(s"SELECT * FROM $tableName") + val stats = df.queryExecution.analyzed.collect { + case rel: MetastoreRelation => + checkTableStats(rel.catalogTable.stats, hasSizeInBytes, expectedRowCounts) + assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table") + rel.catalogTable.stats + case rel: LogicalRelation => + checkTableStats(rel.catalogTable.get.stats, hasSizeInBytes, expectedRowCounts) + assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table") + rel.catalogTable.get.stats + } + assert(stats.size == 1) + stats.head + } + test("test table-level statistics for hive tables created in HiveExternalCatalog") { val textTable = "textTable" withTable(textTable) { From d1def0368725b33c537e7b7edcf93ef3c27ee0e2 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Wed, 21 Sep 2016 16:51:27 -0700 Subject: [PATCH 10/22] comments --- 
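Note (not part of the original patch): besides turning ColumnStats.fromString into an apply factory, this revision caps the estimated ndv at the observed row count, since HyperLogLogPlusPlus only returns an approximate distinct count and can overshoot the number of rows actually scanned. A minimal sketch of the capping step, with ndvEstimate and rowCount as hypothetical names used purely for illustration:

    // ndv comes from an approximate aggregate, so bound it by the exact row count
    val boundedNdv: Option[Long] = ndvEstimate.map(math.min(_, rowCount))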
.../spark/sql/catalyst/plans/logical/Statistics.scala | 2 +- .../spark/sql/execution/command/AnalyzeColumnCommand.scala | 6 +++--- .../org/apache/spark/sql/hive/HiveExternalCatalog.scala | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 0080007db3b7..9bd042d21a5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -86,7 +86,7 @@ case class ColumnStats( } object ColumnStats { - def fromString(str: String, dataType: DataType): ColumnStats = { + def apply(str: String, dataType: DataType): ColumnStats = { val suffix = ",\\s|\\)" ColumnStats( dataType = dataType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 8169a4efa86c..20a399310b38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -103,7 +103,7 @@ case class AnalyzeColumnCommand( // unwrap the result val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => - (expr.name, ColumnStatsStruct.unwrapStruct(statsRow, i + 1, expr)) + (expr.name, ColumnStatsStruct.unwrapStruct(statsRow, i + 1, expr, rowCount)) }.toMap (rowCount, columnStats) } @@ -161,14 +161,14 @@ object ColumnStatsStruct { }) } - def unwrapStruct(row: InternalRow, offset: Int, e: Expression): ColumnStats = { + def unwrapStruct(row: InternalRow, offset: Int, e: Expression, rowCount: Long): ColumnStats = { val struct = row.getStruct(offset, statsNumber) ColumnStats( dataType = e.dataType, numNulls = struct.getLong(0), max = getField(struct, 1, e.dataType), min = getField(struct, 2, e.dataType), - ndv = getLongField(struct, 3), + ndv = getLongField(struct, 3).map(math.min(_, rowCount)), avgColLen = getDoubleField(struct, 4), maxColLen = getLongField(struct, 5), numTrues = getLongField(struct, 6), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index b6c2a291b7ed..d86f7b880ebd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -482,7 +482,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat .map { case (k, v) => (k.replace(STATISTICS_BASIC_COL_STATS_PREFIX, ""), v)} val colStats: Map[String, ColumnStats] = catalogTable.schema.collect { case field if colStatsProps.contains(field.name) => - (field.name, ColumnStats.fromString(colStatsProps(field.name), field.dataType)) + (field.name, ColumnStats(colStatsProps(field.name), field.dataType)) }.toMap catalogTable.copy( properties = removeStatsProperties(catalogTable), From 410125a0458be2ca9544cf9865b3b7555f3883cf Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Wed, 21 Sep 2016 20:51:22 -0700 Subject: [PATCH 11/22] delete unnecessary api --- .../org/apache/spark/sql/internal/SessionState.scala | 11 +---------- .../org/apache/spark/sql/StatisticsColumnSuite.scala | 6 
+++++- .../scala/org/apache/spark/sql/StatisticsTest.scala | 7 ++++++- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index f4b3a0be2211..ded8e16d52f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStats, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, AnalyzeTableCommand} import org.apache.spark.sql.execution.datasources._ @@ -201,13 +201,4 @@ private[sql] class SessionState(sparkSession: SparkSession) { def analyzeTableColumns(tableIdent: TableIdentifier, columnNames: Seq[String]): Unit = { AnalyzeColumnCommand(tableIdent, columnNames).run(sparkSession) } - - // This api is used for testing. - def computeColumnStats(tableName: String, columnNames: Seq[String]): Map[String, ColumnStats] = { - val tableIdent = sqlParser.parseTableIdentifier(tableName) - val db = tableIdent.database.getOrElse(catalog.getCurrentDatabase) - val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) - val relation = sparkSession.sessionState.catalog.lookupRelation(tableIdentWithDB) - AnalyzeColumnCommand(tableIdent, columnNames).computeColStats(sparkSession, relation)._2 - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala index 1ceda0824844..c214a5fe7f62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.ColumnStats import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -77,7 +78,10 @@ class StatisticsColumnSuite extends StatisticsTest { withSQLConf("spark.sql.caseSensitive" -> "false") { val columnsToAnalyze = Seq(colName2.toUpperCase, colName1, colName2) - val columnStats = spark.sessionState.computeColumnStats(table, columnsToAnalyze) + val tableIdent = TableIdentifier(table, Some("default")) + val relation = spark.sessionState.catalog.lookupRelation(tableIdent) + val columnStats = + AnalyzeColumnCommand(tableIdent, columnsToAnalyze).computeColStats(spark, relation)._2 assert(columnStats.contains(colName1)) assert(columnStats.contains(colName2)) // check deduplication diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala index a28f094a73a3..44da6026c1f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.TableIdentifier import 
org.apache.spark.sql.catalyst.plans.logical.{ColumnStats, Statistics} +import org.apache.spark.sql.execution.command.AnalyzeColumnCommand import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -31,7 +33,10 @@ trait StatisticsTest extends QueryTest with SharedSQLContext { withTable(table) { df.write.format("json").saveAsTable(table) val columns = expectedColStatsSeq.map(_._1) - val columnStats = spark.sessionState.computeColumnStats(table, columns) + val tableIdent = TableIdentifier(table, Some("default")) + val relation = spark.sessionState.catalog.lookupRelation(tableIdent) + val columnStats = + AnalyzeColumnCommand(tableIdent, columns).computeColStats(spark, relation)._2 expectedColStatsSeq.foreach { expected => assert(columnStats.contains(expected._1)) checkColStats(colStats = columnStats(expected._1), expectedColStats = expected._2) From b43115cba3c818b08d8e0b576bf4d088dcfebd5e Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Thu, 22 Sep 2016 15:44:22 -0700 Subject: [PATCH 12/22] fix and add test cases --- .../command/AnalyzeColumnCommand.scala | 9 ++++-- .../command/AnalyzeTableCommand.scala | 8 +++-- .../spark/sql/StatisticsColumnSuite.scala | 30 ++++++++++++++++--- .../org/apache/spark/sql/StatisticsTest.scala | 2 +- .../spark/sql/hive/HiveExternalCatalog.scala | 2 +- .../spark/sql/hive/StatisticsSuite.scala | 3 +- 6 files changed, 41 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 20a399310b38..143c97c97552 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -40,9 +40,7 @@ case class AnalyzeColumnCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val sessionState = sparkSession.sessionState - val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) - val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) - val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) + val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdent)) relation match { case catalogRel: CatalogRelation => @@ -64,6 +62,10 @@ case class AnalyzeColumnCommand( rowCount = Some(rowCount), colStats = columnStats ++ catalogTable.stats.map(_.colStats).getOrElse(Map())) sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) + // We need to add database info to the table identifier so that we will not refresh the temp + // table with the same table name. + val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) + val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdentWithDB) } @@ -117,6 +119,7 @@ object ColumnStatsStruct { val nullString = Literal(null, StringType) val nullBinary = Literal(null, BinaryType) val nullBoolean = Literal(null, BooleanType) + // The number of different kinds of column-level statistics. 
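  // In struct order: numNulls, max, min, ndv, avgColLen, maxColLen, numTrues, numFalses.
  // Statistics that do not apply to a column's data type are kept as null literals.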
val statsNumber = 8 def apply(e: NamedExpression, relativeSD: Double): CreateStruct = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 96c0d9d14770..6b47d81d023c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -41,9 +41,7 @@ case class AnalyzeTableCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val sessionState = sparkSession.sessionState - val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) - val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) - val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) + val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdent)) relation match { case relation: CatalogRelation => @@ -84,6 +82,10 @@ case class AnalyzeTableCommand( // recorded in the metastore. if (newStats.isDefined) { sessionState.catalog.alterTable(catalogTable.copy(stats = newStats)) + // We need to add database info to the table identifier so that we will not refresh the + // temp table with the same table name. + val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) + val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdentWithDB) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala index c214a5fe7f62..de4c16d358cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.ColumnStats import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.command.AnalyzeColumnCommand +import org.apache.spark.sql.test.SQLTestData.ArrayData import org.apache.spark.sql.types._ class StatisticsColumnSuite extends StatisticsTest { @@ -56,6 +57,28 @@ class StatisticsColumnSuite extends StatisticsTest { } } + test("analyzing columns in temporary tables is not supported") { + val viewName = "tbl" + withTempView(viewName) { + spark.range(10).createOrReplaceTempView(viewName) + val err = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") + } + assert(err.message.contains("ANALYZE TABLE is not supported")) + } + } + + test("analyzing columns of non-atomic types is not supported") { + val tableName = "tbl" + withTable(tableName) { + Seq(ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3)))).toDF().write.saveAsTable(tableName) + val err = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data") + } + assert(err.message.contains("Analyzing columns is not supported")) + } + } + test("check correctness of columns") { val table = "tbl" val colName1 = "abc" @@ -315,16 +338,15 @@ class StatisticsColumnSuite extends StatisticsTest { sql(s"CREATE TABLE $table (c1 int) USING PARQUET") sql(s"INSERT INTO $table SELECT * FROM $tmpTable") sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") - val fetchedStats1 = - 
checkTableStats(tableName = table, expectedRowCount = Some(values.length)) + checkTableStats(tableName = table, expectedRowCount = Some(values.length)) // update table-level stats between analyze table and analyze column commands sql(s"INSERT INTO $table SELECT * FROM $tmpTable") sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") - val fetchedStats2 = + val fetchedStats = checkTableStats(tableName = table, expectedRowCount = Some(values.length * 2)) - val colStats = fetchedStats2.get.colStats("c1") + val colStats = fetchedStats.get.colStats("c1") checkColStats(colStats = colStats, expectedColStats = ColumnStats( dataType = IntegerType, numNulls = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala index 44da6026c1f4..2add22f20b77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala @@ -82,7 +82,7 @@ trait StatisticsTest extends QueryTest with SharedSQLContext { } def checkTableStats(tableName: String, expectedRowCount: Option[Int]): Option[Statistics] = { - val df = sql(s"SELECT * FROM $tableName") + val df = spark.table(tableName) val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation => assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) rel.catalogTable.get.stats diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index d86f7b880ebd..b52ab6a4bdf3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -479,7 +479,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat if (catalogTable.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)).nonEmpty) { val colStatsProps = catalogTable.properties .filterKeys(_.startsWith(STATISTICS_BASIC_COL_STATS_PREFIX)) - .map { case (k, v) => (k.replace(STATISTICS_BASIC_COL_STATS_PREFIX, ""), v)} + .map { case (k, v) => (k.replace(STATISTICS_BASIC_COL_STATS_PREFIX, ""), v) } val colStats: Map[String, ColumnStats] = catalogTable.schema.collect { case field if colStatsProps.contains(field.name) => (field.name, ColumnStats(colStatsProps(field.name), field.dataType)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index e275aa5add99..476d588c98f8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -164,9 +164,10 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils // Try to analyze a temp table sql("""SELECT * FROM src""").createOrReplaceTempView("tempTable") - intercept[AnalysisException] { + val err = intercept[AnalysisException] { sql("ANALYZE TABLE tempTable COMPUTE STATISTICS") } + assert(err.message.contains("ANALYZE TABLE is not supported")) spark.sessionState.catalog.dropTable( TableIdentifier("tempTable"), ignoreIfNotExists = true, purge = false) } From bb19f72789abc960efb937712512c0716fecd800 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Thu, 22 Sep 2016 19:38:53 -0700 Subject: [PATCH 13/22] fix tests --- .../command/AnalyzeColumnCommand.scala | 8 +++---- 
.../command/AnalyzeTableCommand.scala | 8 +++---- .../spark/sql/StatisticsColumnSuite.scala | 11 --------- .../spark/sql/hive/StatisticsSuite.scala | 23 +++++++++++++++++-- .../sql/hive/execution/SQLViewSuite.scala | 1 + 5 files changed, 28 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 143c97c97552..fc68b9420880 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -40,7 +40,9 @@ case class AnalyzeColumnCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val sessionState = sparkSession.sessionState - val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdent)) + val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) + val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) + val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) relation match { case catalogRel: CatalogRelation => @@ -62,10 +64,6 @@ case class AnalyzeColumnCommand( rowCount = Some(rowCount), colStats = columnStats ++ catalogTable.stats.map(_.colStats).getOrElse(Map())) sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics))) - // We need to add database info to the table identifier so that we will not refresh the temp - // table with the same table name. - val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) - val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdentWithDB) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 6b47d81d023c..96c0d9d14770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -41,7 +41,9 @@ case class AnalyzeTableCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val sessionState = sparkSession.sessionState - val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdent)) + val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) + val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) + val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) relation match { case relation: CatalogRelation => @@ -82,10 +84,6 @@ case class AnalyzeTableCommand( // recorded in the metastore. if (newStats.isDefined) { sessionState.catalog.alterTable(catalogTable.copy(stats = newStats)) - // We need to add database info to the table identifier so that we will not refresh the - // temp table with the same table name. - val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) - val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) // Refresh the cached data source table in the catalog. 
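  // tableIdentWithDB is database-qualified at the top of run(), so a temporary table that
  // happens to share the table name is not refreshed by mistake.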
sessionState.catalog.refreshTable(tableIdentWithDB) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala index de4c16d358cb..97c195588768 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -57,17 +57,6 @@ class StatisticsColumnSuite extends StatisticsTest { } } - test("analyzing columns in temporary tables is not supported") { - val viewName = "tbl" - withTempView(viewName) { - spark.range(10).createOrReplaceTempView(viewName) - val err = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") - } - assert(err.message.contains("ANALYZE TABLE is not supported")) - } - } - test("analyzing columns of non-atomic types is not supported") { val tableName = "tbl" withTable(tableName) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 476d588c98f8..fbe2a4336807 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -164,14 +164,33 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils // Try to analyze a temp table sql("""SELECT * FROM src""").createOrReplaceTempView("tempTable") - val err = intercept[AnalysisException] { + intercept[AnalysisException] { sql("ANALYZE TABLE tempTable COMPUTE STATISTICS") } - assert(err.message.contains("ANALYZE TABLE is not supported")) spark.sessionState.catalog.dropTable( TableIdentifier("tempTable"), ignoreIfNotExists = true, purge = false) } + test("analyzing views is not supported") { + def assertAnalyzeUnsupported(analyzeCommand: String): Unit = { + val err = intercept[AnalysisException] { + sql(analyzeCommand) + } + assert(err.message.contains("ANALYZE TABLE is not supported")) + } + + val tableName = "tbl" + withTable(tableName) { + spark.range(10).write.saveAsTable(tableName) + val viewName = "view" + withView(viewName) { + sql(s"CREATE VIEW $viewName AS SELECT * FROM $tableName") + assertAnalyzeUnsupported(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") + assertAnalyzeUnsupported(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") + } + } + } + private def checkTableStats( stats: Option[Statistics], hasSizeInBytes: Boolean, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala index a215c70da0c5..f5c605fe5e2f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala @@ -123,6 +123,7 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { assertNoSuchTable(s"SHOW CREATE TABLE $viewName") assertNoSuchTable(s"SHOW PARTITIONS $viewName") assertNoSuchTable(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") + assertNoSuchTable(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") } } From 2b645496e8965730690b5e72ea0f589371f9e9d7 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Mon, 26 Sep 2016 15:52:41 -0700 Subject: [PATCH 14/22] use InternalRow in ColumnStat --- .../catalyst/plans/logical/Statistics.scala | 91 ++++---- .../spark/sql/execution/SparkSqlParser.scala | 3 +- 
.../command/AnalyzeColumnCommand.scala | 124 +++++----- .../spark/sql/StatisticsColumnSuite.scala | 212 +++++++++--------- .../org/apache/spark/sql/StatisticsTest.scala | 88 +++++--- .../spark/sql/hive/HiveExternalCatalog.scala | 10 +- 6 files changed, 272 insertions(+), 256 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 9bd042d21a5f..86d0236d85ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -17,7 +17,11 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.types.DataType +import org.apache.commons.codec.binary.Base64 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types._ /** * Estimates of various statistics. The default estimation logic simply lazily multiplies the @@ -40,7 +44,7 @@ import org.apache.spark.sql.types.DataType case class Statistics( sizeInBytes: BigInt, rowCount: Option[BigInt] = None, - colStats: Map[String, ColumnStats] = Map.empty, + colStats: Map[String, ColumnStat] = Map.empty, isBroadcastable: Boolean = false) { override def toString: String = "Statistics(" + simpleString + ")" @@ -49,7 +53,6 @@ case class Statistics( def simpleString: String = { Seq(s"sizeInBytes=$sizeInBytes", if (rowCount.isDefined) s"rowCount=${rowCount.get}" else "", - if (colStats.nonEmpty) s"colStats=$colStats" else "", s"isBroadcastable=$isBroadcastable" ).filter(_.nonEmpty).mkString(", ") } @@ -57,51 +60,55 @@ case class Statistics( /** * Statistics for a column. - * @param ndv Number of distinct values of the column. 
*/ -case class ColumnStats( - dataType: DataType, - numNulls: Long, - max: Option[Any] = None, - min: Option[Any] = None, - ndv: Option[Long] = None, - avgColLen: Option[Double] = None, - maxColLen: Option[Long] = None, - numTrues: Option[Long] = None, - numFalses: Option[Long] = None) { +case class ColumnStat(dataType: DataType, statRow: InternalRow) { - override def toString: String = "ColumnStats(" + simpleString + ")" + def forNumeric[T <: AtomicType](dataType: T): NumericColumnStat[T] = { + NumericColumnStat(statRow, dataType) + } + def forString: StringColumnStat = StringColumnStat(statRow) + def forBinary: BinaryColumnStat = BinaryColumnStat(statRow) + def forBoolean: BooleanColumnStat = BooleanColumnStat(statRow) - def simpleString: String = { - Seq(s"numNulls=$numNulls", - if (max.isDefined) s"max=${max.get}" else "", - if (min.isDefined) s"min=${min.get}" else "", - if (ndv.isDefined) s"ndv=${ndv.get}" else "", - if (avgColLen.isDefined) s"avgColLen=${avgColLen.get}" else "", - if (maxColLen.isDefined) s"maxColLen=${maxColLen.get}" else "", - if (numTrues.isDefined) s"numTrues=${numTrues.get}" else "", - if (numFalses.isDefined) s"numFalses=${numFalses.get}" else "" - ).filter(_.nonEmpty).mkString(", ") + override def toString: String = { + // use Base64 for encoding + Base64.encodeBase64String(statRow.asInstanceOf[UnsafeRow].getBytes) } } -object ColumnStats { - def apply(str: String, dataType: DataType): ColumnStats = { - val suffix = ",\\s|\\)" - ColumnStats( - dataType = dataType, - numNulls = findItem(source = str, prefix = "numNulls=", suffix = suffix).map(_.toLong).get, - max = findItem(source = str, prefix = "max=", suffix = suffix), - min = findItem(source = str, prefix = "min=", suffix = suffix), - ndv = findItem(source = str, prefix = "ndv=", suffix = suffix).map(_.toLong), - avgColLen = findItem(source = str, prefix = "avgColLen=", suffix = suffix).map(_.toDouble), - maxColLen = findItem(source = str, prefix = "maxColLen=", suffix = suffix).map(_.toLong), - numTrues = findItem(source = str, prefix = "numTrues=", suffix = suffix).map(_.toLong), - numFalses = findItem(source = str, prefix = "numFalses=", suffix = suffix).map(_.toLong)) +object ColumnStat { + def apply(dataType: DataType, str: String): ColumnStat = { + // use Base64 for decoding + ColumnStat(dataType, InternalRow(Base64.decodeBase64(str))) } +} - private def findItem(source: String, prefix: String, suffix: String): Option[String] = { - val pattern = s"(?<=$prefix)(.+?)(?=$suffix)".r - pattern.findFirstIn(source) - } +case class NumericColumnStat[T <: AtomicType](statRow: InternalRow, dataType: T) { + // The indices here must be consistent with `ColumnStatStruct.numericColumnStat`. + val numNulls: Long = statRow.getLong(0) + val max: T#InternalType = statRow.get(1, dataType).asInstanceOf[T#InternalType] + val min: T#InternalType = statRow.get(2, dataType).asInstanceOf[T#InternalType] + val ndv: Long = statRow.getLong(3) +} + +case class StringColumnStat(statRow: InternalRow) { + // The indices here must be consistent with `ColumnStatStruct.stringColumnStat`. + val numNulls: Long = statRow.getLong(0) + val avgColLen: Double = statRow.getDouble(1) + val maxColLen: Long = statRow.getLong(2) + val ndv: Long = statRow.getLong(3) +} + +case class BinaryColumnStat(statRow: InternalRow) { + // The indices here must be consistent with `ColumnStatStruct.binaryColumnStat`. 
+ val numNulls: Long = statRow.getLong(0) + val avgColLen: Double = statRow.getDouble(1) + val maxColLen: Long = statRow.getLong(2) +} + +case class BooleanColumnStat(statRow: InternalRow) { + // The indices here must be consistent with `ColumnStatStruct.booleanColumnStat`. + val numNulls: Long = statRow.getLong(0) + val numTrues: Long = statRow.getLong(1) + val numFalses: Long = statRow.getLong(2) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index b8adc969a25e..ad1302fc38b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -90,8 +90,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * Create an [[AnalyzeTableCommand]] command or an [[AnalyzeColumnCommand]] command. * Example SQL for analyzing table : * {{{ - * ANALYZE TABLE table COMPUTE STATISTICS NOSCAN; - * ANALYZE TABLE table COMPUTE STATISTICS; + * ANALYZE TABLE table COMPUTE STATISTICS [NOSCAN]; * }}} * Example SQL for analyzing columns : * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index fc68b9420880..d68a28cdf587 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, ColumnStats, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, ColumnStat, LogicalPlan, Statistics} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.types._ @@ -73,7 +73,7 @@ case class AnalyzeColumnCommand( def computeColStats( sparkSession: SparkSession, - relation: LogicalPlan): (Long, Map[String, ColumnStats]) = { + relation: LogicalPlan): (Long, Map[String, ColumnStat]) = { // check correctness of column names val attributesToAnalyze = mutable.MutableList[Attribute]() @@ -95,7 +95,7 @@ case class AnalyzeColumnCommand( // The layout of each struct follows the layout of the ColumnStats. 
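  // A single scan produces every requested statistic: the aggregate below evaluates
  // Count(Literal(1)) plus one struct of aggregate functions per analyzed column, so the
  // collected result is exactly one row, with the table row count at ordinal 0 and the
  // struct for the i-th column at ordinal i + 1. As a rough sketch (not part of this
  // patch), for a table t(i INT, s STRING) the generated plan is conceptually equivalent to:
  //   SELECT count(1),
  //          struct(sum(if(isnull(i), 1, 0)), max(i), min(i), approx_count_distinct(i)),
  //          struct(sum(if(isnull(s), 1, 0)), avg(length(s)), max(length(s)), approx_count_distinct(s))
  //   FROM t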
val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(ColumnStatsStruct(_, ndvMaxErr)) + attributesToAnalyze.map(ColumnStatStruct(_, ndvMaxErr)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)) .queryExecution.toRdd.collect().head @@ -103,88 +103,76 @@ case class AnalyzeColumnCommand( // unwrap the result val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => - (expr.name, ColumnStatsStruct.unwrapStruct(statsRow, i + 1, expr, rowCount)) + (expr.name, ColumnStatStruct.unwrapStruct(statsRow, i + 1, expr, ndvMaxErr, rowCount)) }.toMap (rowCount, columnStats) } } -object ColumnStatsStruct { +object ColumnStatStruct { val zero = Literal(0, LongType) val one = Literal(1, LongType) - val nullLong = Literal(null, LongType) - val nullDouble = Literal(null, DoubleType) - val nullString = Literal(null, StringType) - val nullBinary = Literal(null, BinaryType) - val nullBoolean = Literal(null, BooleanType) - // The number of different kinds of column-level statistics. - val statsNumber = 8 - - def apply(e: NamedExpression, relativeSD: Double): CreateStruct = { - // Use aggregate functions to compute statistics we need: - // - number of nulls: Sum(If(IsNull(e), one, zero)); - // - maximum value: Max(e); - // - minimum value: Min(e); - // - ndv (number of distinct values): HyperLogLogPlusPlus(e, relativeSD); - // - average length of values: Average(Length(e)); - // - maximum length of values: Max(Length(e)); - // - number of true values: Sum(If(e, one, zero)); - // - number of false values: Sum(If(Not(e), one, zero)); - // - If we don't need some statistic for the data type, use null literal. 
- // Note that: the order of each sequence must be as follows: - // numNulls, max, min, ndv, avgColLen, maxColLen, numTrues, numFalses - var statistics = e.dataType match { - case _: NumericType | TimestampType | DateType => - Seq(Max(e), Min(e), HyperLogLogPlusPlus(e, relativeSD), nullDouble, nullLong, nullLong, - nullLong) - case StringType => - Seq(nullString, nullString, HyperLogLogPlusPlus(e, relativeSD), Average(Length(e)), - Max(Length(e)), nullLong, nullLong) - case BinaryType => - Seq(nullBinary, nullBinary, nullLong, Average(Length(e)), Max(Length(e)), nullLong, - nullLong) - case BooleanType => - Seq(nullBoolean, nullBoolean, nullLong, nullDouble, nullLong, Sum(If(e, one, zero)), - Sum(If(Not(e), one, zero))) - case otherType => - throw new AnalysisException("Analyzing columns is not supported for column " + - s"${e.name} of data type: ${e.dataType}.") - } - statistics = if (e.nullable) { - Sum(If(IsNull(e), one, zero)) +: statistics - } else { - zero +: statistics - } - assert(statistics.length == statsNumber) - CreateStruct(statistics.map { + + def numNulls(e: Expression): Expression = if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero + def max(e: Expression): Expression = Max(e) + def min(e: Expression): Expression = Min(e) + def ndv(e: Expression, relativeSD: Double): Expression = HyperLogLogPlusPlus(e, relativeSD) + def avgLength(e: Expression): Expression = Average(Length(e)) + def maxLength(e: Expression): Expression = Max(Length(e)) + def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) + def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) + + def getStruct(exprs: Seq[Expression]): CreateStruct = { + CreateStruct(exprs.map { case af: AggregateFunction => af.toAggregateExpression() case e: Expression => e }) } - def unwrapStruct(row: InternalRow, offset: Int, e: Expression, rowCount: Long): ColumnStats = { - val struct = row.getStruct(offset, statsNumber) - ColumnStats( - dataType = e.dataType, - numNulls = struct.getLong(0), - max = getField(struct, 1, e.dataType), - min = getField(struct, 2, e.dataType), - ndv = getLongField(struct, 3).map(math.min(_, rowCount)), - avgColLen = getDoubleField(struct, 4), - maxColLen = getLongField(struct, 5), - numTrues = getLongField(struct, 6), - numFalses = getLongField(struct, 7)) + def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { + Seq(numNulls(e), max(e), min(e), ndv(e, relativeSD)) + } + + def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = { + Seq(numNulls(e), avgLength(e), maxLength(e), ndv(e, relativeSD)) } - private def getField(struct: InternalRow, index: Int, dataType: DataType): Option[Any] = { - if (struct.isNullAt(index)) None else Some(struct.get(index, dataType)) + def binaryColumnStat(e: Expression): Seq[Expression] = { + Seq(numNulls(e), avgLength(e), maxLength(e)) } - private def getLongField(struct: InternalRow, index: Int): Option[Long] = { - if (struct.isNullAt(index)) None else Some(struct.getLong(index)) + def booleanColumnStat(e: Expression): Seq[Expression] = { + Seq(numNulls(e), numTrues(e), numFalses(e)) } - private def getDoubleField(struct: InternalRow, index: Int): Option[Double] = { - if (struct.isNullAt(index)) None else Some(struct.getDouble(index)) + def apply(e: Attribute, relativeSD: Double): CreateStruct = e.dataType match { + // Use aggregate functions to compute statistics we need. 
+ case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(e, relativeSD)) + case StringType => getStruct(stringColumnStat(e, relativeSD)) + case BinaryType => getStruct(binaryColumnStat(e)) + case BooleanType => getStruct(booleanColumnStat(e)) + case otherType => + throw new AnalysisException("Analyzing columns is not supported for column " + + s"${e.name} of data type: ${e.dataType}.") + } + + def unwrapStruct( + row: InternalRow, + offset: Int, + e: Expression, + relativeSD: Double, + rowCount: Long): ColumnStat = { + val numFields = e.dataType match { + case _: NumericType | TimestampType | DateType => numericColumnStat(e, relativeSD).length + case StringType => stringColumnStat(e, relativeSD).length + case BinaryType => binaryColumnStat(e).length + case BooleanType => booleanColumnStat(e).length + } + val struct = row.getStruct(offset, numFields) + if (numFields >= 3 && !struct.isNullAt(3)) { + // ndv should not be larger than number of rows + if (struct.getLong(3) > rowCount) struct.asInstanceOf[UnsafeRow].setLong(3, rowCount) + } + ColumnStat(e.dataType, struct) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala index 97c195588768..5476a9f7f171 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.ColumnStats +import org.apache.spark.sql.catalyst.plans.logical.ColumnStat import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.command.AnalyzeColumnCommand import org.apache.spark.sql.test.SQLTestData.ArrayData @@ -118,35 +118,42 @@ class StatisticsColumnSuite extends StatisticsTest { val df = data.toDF("c1", "c2", "c3", "c4") val nonNullValues = getNonNullValues[Int](values) val expectedColStatsSeq = df.schema.map { f => - val colStats = ColumnStats( + val colStat = ColumnStat( dataType = f.dataType, - numNulls = values.count(_.isEmpty), - max = Some(nonNullValues.max), - min = Some(nonNullValues.min), - ndv = Some(nonNullValues.distinct.length.toLong)) - (f.name, colStats) + statRow = InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.max, + nonNullValues.min, + nonNullValues.distinct.length.toLong)) + (f.name, colStat) } checkColStats(df, expectedColStatsSeq) } test("column-level statistics for fractional type columns") { - val values = (0 to 5).map { i => - if (i == 0) None else Some(i + i * 0.01d) + val values: Seq[Option[Decimal]] = (0 to 5).map { i => + if (i == 0) None else Some(Decimal(i + i * 0.01)) } val data = values.map { i => - (i.map(_.toFloat), i.map(_.toDouble), i.map(Decimal(_))) + (i.map(_.toFloat), i.map(_.toDouble), i) } val df = data.toDF("c1", "c2", "c3") - val nonNullValues = getNonNullValues[Double](values) + val nonNullValues = getNonNullValues[Decimal](values) + val numNulls = values.count(_.isEmpty).toLong + val ndv = nonNullValues.distinct.length.toLong val expectedColStatsSeq = df.schema.map { f => - val colStats = ColumnStats( - dataType = f.dataType, - numNulls = values.count(_.isEmpty), - max = Some(nonNullValues.max), - min = Some(nonNullValues.min), - 
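  // Per data type, the struct fields are:
  //   numeric/timestamp/date -> numNulls, max, min, ndv
  //   string                 -> numNulls, avgLength, maxLength, ndv
  //   binary                 -> numNulls, avgLength, maxLength
  //   boolean                -> numNulls, numTrues, numFalses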
ndv = Some(nonNullValues.distinct.length.toLong)) - (f.name, colStats) + val colStat = f.dataType match { + case floatType: FloatType => + ColumnStat(floatType, InternalRow(numNulls, nonNullValues.max.toFloat, + nonNullValues.min.toFloat, ndv)) + case doubleType: DoubleType => + ColumnStat(doubleType, InternalRow(numNulls, nonNullValues.max.toDouble, + nonNullValues.min.toDouble, ndv)) + case decimalType: DecimalType => + ColumnStat(decimalType, InternalRow(numNulls, nonNullValues.max, nonNullValues.min, ndv)) + } + (f.name, colStat) } checkColStats(df, expectedColStatsSeq) } @@ -156,13 +163,14 @@ class StatisticsColumnSuite extends StatisticsTest { val df = values.toDF("c1") val nonNullValues = getNonNullValues[String](values) val expectedColStatsSeq = df.schema.map { f => - val colStats = ColumnStats( + val colStat = ColumnStat( dataType = f.dataType, - numNulls = values.count(_.isEmpty), - ndv = Some(nonNullValues.distinct.length.toLong), - maxColLen = Some(nonNullValues.map(_.length).max.toLong), - avgColLen = Some(nonNullValues.map(_.length).sum / nonNullValues.length.toDouble)) - (f.name, colStats) + statRow = InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, + nonNullValues.map(_.length).max.toLong, + nonNullValues.distinct.length.toLong)) + (f.name, colStat) } checkColStats(df, expectedColStatsSeq) } @@ -172,12 +180,13 @@ class StatisticsColumnSuite extends StatisticsTest { val df = values.toDF("c1") val nonNullValues = getNonNullValues[Array[Byte]](values) val expectedColStatsSeq = df.schema.map { f => - val colStats = ColumnStats( + val colStat = ColumnStat( dataType = f.dataType, - numNulls = values.count(_.isEmpty), - maxColLen = Some(nonNullValues.map(_.length).max.toLong), - avgColLen = Some(nonNullValues.map(_.length).sum / nonNullValues.length.toDouble)) - (f.name, colStats) + statRow = InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, + nonNullValues.map(_.length).max.toLong)) + (f.name, colStat) } checkColStats(df, expectedColStatsSeq) } @@ -187,12 +196,13 @@ class StatisticsColumnSuite extends StatisticsTest { val df = values.toDF("c1") val nonNullValues = getNonNullValues[Boolean](values) val expectedColStatsSeq = df.schema.map { f => - val colStats = ColumnStats( + val colStat = ColumnStat( dataType = f.dataType, - numNulls = values.count(_.isEmpty), - numTrues = Some(nonNullValues.count(_.equals(true)).toLong), - numFalses = Some(nonNullValues.count(_.equals(false)).toLong)) - (f.name, colStats) + statRow = InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.count(_.equals(true)).toLong, + nonNullValues.count(_.equals(false)).toLong)) + (f.name, colStat) } checkColStats(df, expectedColStatsSeq) } @@ -202,14 +212,15 @@ class StatisticsColumnSuite extends StatisticsTest { val df = values.toDF("c1") val nonNullValues = getNonNullValues[Date](values) val expectedColStatsSeq = df.schema.map { f => - val colStats = ColumnStats( + val colStat = ColumnStat( dataType = f.dataType, - numNulls = values.count(_.isEmpty), - // Internally, DateType is represented as the number of days from 1970-01-01. - max = Some(nonNullValues.map(DateTimeUtils.fromJavaDate).max), - min = Some(nonNullValues.map(DateTimeUtils.fromJavaDate).min), - ndv = Some(nonNullValues.distinct.length.toLong)) - (f.name, colStats) + statRow = InternalRow( + values.count(_.isEmpty).toLong, + // Internally, DateType is represented as the number of days from 1970-01-01. 
+ nonNullValues.map(DateTimeUtils.fromJavaDate).max, + nonNullValues.map(DateTimeUtils.fromJavaDate).min, + nonNullValues.distinct.length.toLong)) + (f.name, colStat) } checkColStats(df, expectedColStatsSeq) } @@ -221,14 +232,15 @@ class StatisticsColumnSuite extends StatisticsTest { val df = values.toDF("c1") val nonNullValues = getNonNullValues[Timestamp](values) val expectedColStatsSeq = df.schema.map { f => - val colStats = ColumnStats( + val colStat = ColumnStat( dataType = f.dataType, - numNulls = values.count(_.isEmpty), - // Internally, TimestampType is represented as the number of days from 1970-01-01 - max = Some(nonNullValues.map(DateTimeUtils.fromJavaTimestamp).max), - min = Some(nonNullValues.map(DateTimeUtils.fromJavaTimestamp).min), - ndv = Some(nonNullValues.distinct.length.toLong)) - (f.name, colStats) + statRow = InternalRow( + values.count(_.isEmpty).toLong, + // Internally, TimestampType is represented as the number of days from 1970-01-01 + nonNullValues.map(DateTimeUtils.fromJavaTimestamp).max, + nonNullValues.map(DateTimeUtils.fromJavaTimestamp).min, + nonNullValues.distinct.length.toLong)) + (f.name, colStat) } checkColStats(df, expectedColStatsSeq) } @@ -240,23 +252,10 @@ class StatisticsColumnSuite extends StatisticsTest { } val df = data.toDF("c1", "c2") val expectedColStatsSeq = df.schema.map { f => - val colStats = f.dataType match { - case StringType => - ColumnStats( - dataType = f.dataType, - numNulls = values.count(_.isEmpty), - ndv = Some(0), - maxColLen = None, - avgColLen = None) - case IntegerType => - ColumnStats( - dataType = f.dataType, - numNulls = values.count(_.isEmpty), - max = None, - min = None, - ndv = Some(0)) - } - (f.name, colStats) + val colStat = ColumnStat( + dataType = f.dataType, + statRow = InternalRow(values.count(_.isEmpty).toLong, null, null, 0L)) + (f.name, colStat) } checkColStats(df, expectedColStatsSeq) } @@ -277,41 +276,48 @@ class StatisticsColumnSuite extends StatisticsTest { } val df = data.toDF("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8") val expectedColStatsSeq = df.schema.map { f => - val colStats = f.dataType match { + val colStat = f.dataType match { case IntegerType => - ColumnStats(dataType = f.dataType, numNulls = 0, max = Some(intSeq.max), - min = Some(intSeq.min), ndv = Some(intSeq.distinct.length.toLong)) + ColumnStat( + dataType = f.dataType, + statRow = InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) case DoubleType => - ColumnStats(dataType = f.dataType, numNulls = 0, max = Some(doubleSeq.max), - min = Some(doubleSeq.min), ndv = Some(doubleSeq.distinct.length.toLong)) + ColumnStat( + dataType = f.dataType, + statRow = InternalRow(0L, doubleSeq.max, doubleSeq.min, + doubleSeq.distinct.length.toLong)) case StringType => - ColumnStats(dataType = f.dataType, numNulls = 0, - maxColLen = Some(stringSeq.map(_.length).max.toLong), - avgColLen = Some(stringSeq.map(_.length).sum / stringSeq.length.toDouble), - ndv = Some(stringSeq.distinct.length.toLong)) + ColumnStat( + dataType = f.dataType, + statRow = InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, + stringSeq.map(_.length).max.toLong, stringSeq.distinct.length.toLong)) case BinaryType => - ColumnStats(dataType = f.dataType, numNulls = 0, - maxColLen = Some(binarySeq.map(_.length).max.toLong), - avgColLen = Some(binarySeq.map(_.length).sum / binarySeq.length.toDouble)) + ColumnStat( + dataType = f.dataType, + statRow = InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble, + 
binarySeq.map(_.length).max.toLong)) case BooleanType => - ColumnStats(dataType = f.dataType, numNulls = 0, - numTrues = Some(booleanSeq.count(_.equals(true)).toLong), - numFalses = Some(booleanSeq.count(_.equals(false)).toLong)) + ColumnStat( + dataType = f.dataType, + statRow = InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, + booleanSeq.count(_.equals(false)).toLong)) case DateType => - ColumnStats(dataType = f.dataType, numNulls = 0, - max = Some(dateSeq.map(DateTimeUtils.fromJavaDate).max), - min = Some(dateSeq.map(DateTimeUtils.fromJavaDate).min), - ndv = Some(dateSeq.distinct.length.toLong)) + ColumnStat( + dataType = f.dataType, + statRow = InternalRow(0L, dateSeq.map(DateTimeUtils.fromJavaDate).max, + dateSeq.map(DateTimeUtils.fromJavaDate).min, dateSeq.distinct.length.toLong)) case TimestampType => - ColumnStats(dataType = f.dataType, numNulls = 0, - max = Some(timestampSeq.map(DateTimeUtils.fromJavaTimestamp).max), - min = Some(timestampSeq.map(DateTimeUtils.fromJavaTimestamp).min), - ndv = Some(timestampSeq.distinct.length.toLong)) + ColumnStat( + dataType = f.dataType, + statRow = InternalRow(0L, timestampSeq.map(DateTimeUtils.fromJavaTimestamp).max, + timestampSeq.map(DateTimeUtils.fromJavaTimestamp).min, + timestampSeq.distinct.length.toLong)) case LongType => - ColumnStats(dataType = f.dataType, numNulls = 0, max = Some(longSeq.max), - min = Some(longSeq.min), ndv = Some(longSeq.distinct.length.toLong)) + ColumnStat( + dataType = f.dataType, + statRow = InternalRow(0L, longSeq.max, longSeq.min, longSeq.distinct.length.toLong)) } - (f.name, colStats) + (f.name, colStat) } checkColStats(df, expectedColStatsSeq) } @@ -335,13 +341,11 @@ class StatisticsColumnSuite extends StatisticsTest { val fetchedStats = checkTableStats(tableName = table, expectedRowCount = Some(values.length * 2)) - val colStats = fetchedStats.get.colStats("c1") - checkColStats(colStats = colStats, expectedColStats = ColumnStats( + val colStat = fetchedStats.get.colStats("c1") + checkColStat(colStat = colStat, expectedColStat = ColumnStat( dataType = IntegerType, - numNulls = 0, - max = Some(values.max), - min = Some(values.min), - ndv = Some(values.distinct.length.toLong))) + statRow = InternalRow.fromSeq( + Seq(0L, values.max, values.min, values.distinct.length.toLong)))) } } @@ -352,16 +356,18 @@ class StatisticsColumnSuite extends StatisticsTest { sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") val fetchedStats1 = checkTableStats(tableName = table, expectedRowCount = Some(0)) assert(fetchedStats1.get.colStats.size == 1) - val expected1 = ColumnStats(dataType = IntegerType, numNulls = 0, ndv = Some(0L)) - checkColStats(colStats = fetchedStats1.get.colStats("c1"), expectedColStats = expected1) + val expected1 = ColumnStat( + dataType = IntegerType, statRow = InternalRow(0L, null, null, 0L)) + checkColStat(colStat = fetchedStats1.get.colStats("c1"), expectedColStat = expected1) sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") val fetchedStats2 = checkTableStats(tableName = table, expectedRowCount = Some(0)) // column c1 is kept in the stats assert(fetchedStats2.get.colStats.size == 2) - checkColStats(colStats = fetchedStats2.get.colStats("c1"), expectedColStats = expected1) - val expected2 = ColumnStats(dataType = LongType, numNulls = 0, ndv = Some(0L)) - checkColStats(colStats = fetchedStats2.get.colStats("c2"), expectedColStats = expected2) + checkColStat(colStat = fetchedStats2.get.colStats("c1"), expectedColStat = expected1) + val expected2 = ColumnStat( + dataType 
= LongType, statRow = InternalRow(0L, null, null, 0L)) + checkColStat(colStat = fetchedStats2.get.colStats("c2"), expectedColStat = expected2) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala index 2add22f20b77..c167d3decd49 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStats, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} import org.apache.spark.sql.execution.command.AnalyzeColumnCommand import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.test.SharedSQLContext @@ -28,7 +28,7 @@ trait StatisticsTest extends QueryTest with SharedSQLContext { def checkColStats( df: DataFrame, - expectedColStatsSeq: Seq[(String, ColumnStats)]): Unit = { + expectedColStatsSeq: Seq[(String, ColumnStat)]): Unit = { val table = "tbl" withTable(table) { df.write.format("json").saveAsTable(table) @@ -39,46 +39,62 @@ trait StatisticsTest extends QueryTest with SharedSQLContext { AnalyzeColumnCommand(tableIdent, columns).computeColStats(spark, relation)._2 expectedColStatsSeq.foreach { expected => assert(columnStats.contains(expected._1)) - checkColStats(colStats = columnStats(expected._1), expectedColStats = expected._2) + checkColStat(colStat = columnStats(expected._1), expectedColStat = expected._2) } } } - def checkColStats(colStats: ColumnStats, expectedColStats: ColumnStats): Unit = { - assert(colStats.dataType == expectedColStats.dataType) - assert(colStats.numNulls == expectedColStats.numNulls) - colStats.dataType match { - case _: IntegralType | DateType | TimestampType => - assert(colStats.max.map(_.toString.toLong) == expectedColStats.max.map(_.toString.toLong)) - assert(colStats.min.map(_.toString.toLong) == expectedColStats.min.map(_.toString.toLong)) - case _: FractionalType => - assert(colStats.max.map(_.toString.toDouble) == expectedColStats - .max.map(_.toString.toDouble)) - assert(colStats.min.map(_.toString.toDouble) == expectedColStats - .min.map(_.toString.toDouble)) - case _ => - // other types don't have max and min stats - assert(colStats.max.isEmpty) - assert(colStats.min.isEmpty) + def checkColStat(colStat: ColumnStat, expectedColStat: ColumnStat): Unit = { + assert(colStat.dataType == expectedColStat.dataType) + colStat.dataType match { + case StringType => + val cs = colStat.forString + val expectedCS = expectedColStat.forString + assert(cs.numNulls == expectedCS.numNulls) + assert(cs.avgColLen == expectedCS.avgColLen) + assert(cs.maxColLen == expectedCS.maxColLen) + checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv) + case BinaryType => + val cs = colStat.forBinary + val expectedCS = expectedColStat.forBinary + assert(cs.numNulls == expectedCS.numNulls) + assert(cs.avgColLen == expectedCS.avgColLen) + assert(cs.maxColLen == expectedCS.maxColLen) + case BooleanType => + val cs = colStat.forBoolean + val expectedCS = expectedColStat.forBoolean + assert(cs.numNulls == expectedCS.numNulls) + assert(cs.numTrues == expectedCS.numTrues) + assert(cs.numFalses == expectedCS.numFalses) + case atomicType: AtomicType => + checkNumericColStats( + dataType = atomicType, colStat = colStat, expectedColStat = expectedColStat) } - colStats.dataType match { - 
case BinaryType | BooleanType => assert(colStats.ndv.isEmpty) - case _ => - // ndv is an approximate value, so we make sure we have the value, and it should be - // within 3*SD's of the given rsd. - assert(colStats.ndv.get >= 0) - if (expectedColStats.ndv.get == 0) { - assert(colStats.ndv.get == 0) - } else if (expectedColStats.ndv.get > 0) { - val rsd = spark.sessionState.conf.ndvMaxError - val error = math.abs((colStats.ndv.get / expectedColStats.ndv.get.toDouble) - 1.0d) - assert(error <= rsd * 3.0d, "Error should be within 3 std. errors.") - } + } + + private def checkNumericColStats( + dataType: AtomicType, + colStat: ColumnStat, + expectedColStat: ColumnStat): Unit = { + val cs = colStat.forNumeric(dataType) + val expectedCS = expectedColStat.forNumeric(dataType) + assert(cs.numNulls == expectedCS.numNulls) + assert(cs.max == expectedCS.max) + assert(cs.min == expectedCS.min) + checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv) + } + + private def checkNdv(ndv: Long, expectedNdv: Long): Unit = { + // ndv is an approximate value, so we make sure we have the value, and it should be + // within 3*SD's of the given rsd. + if (expectedNdv == 0) { + assert(ndv == 0) + } else if (expectedNdv > 0) { + assert(ndv > 0) + val rsd = spark.sessionState.conf.ndvMaxError + val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d) + assert(error <= rsd * 3.0d, "Error should be within 3 std. errors.") } - assert(colStats.avgColLen == expectedColStats.avgColLen) - assert(colStats.maxColLen == expectedColStats.maxColLen) - assert(colStats.numTrues == expectedColStats.numTrues) - assert(colStats.numFalses == expectedColStats.numFalses) } def checkTableStats(tableName: String, expectedRowCount: Option[Int]): Option[Statistics] = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index b52ab6a4bdf3..ea6355c868ff 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStats, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap import org.apache.spark.sql.hive.client.HiveClient @@ -403,8 +403,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat if (stats.rowCount.isDefined) { statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() } - stats.colStats.foreach { case (colName, colStats) => - statsProperties += (STATISTICS_BASIC_COL_STATS_PREFIX + colName) -> colStats.toString + stats.colStats.foreach { case (colName, colStat) => + statsProperties += (STATISTICS_BASIC_COL_STATS_PREFIX + colName) -> colStat.toString } tableDefinition.copy(properties = tableDefinition.properties ++ statsProperties) } else { @@ -480,9 +480,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val colStatsProps = catalogTable.properties .filterKeys(_.startsWith(STATISTICS_BASIC_COL_STATS_PREFIX)) .map { case (k, v) => (k.replace(STATISTICS_BASIC_COL_STATS_PREFIX, ""), v) } 
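A note on the tolerance used in checkNdv above: ndv comes from HyperLogLog++, so the test only requires the estimate to fall within three standard errors of the configured relative error. A minimal worked example of that bound, assuming an rsd of 0.05 (that value is illustrative, not taken from this patch):

    val rsd = 0.05                 // assumed relative standard deviation (ndvMaxError)
    val expectedNdv = 1000L
    val estimatedNdv = 960L        // a plausible HyperLogLog++ estimate
    val error = math.abs((estimatedNdv / expectedNdv.toDouble) - 1.0d)
    assert(error <= rsd * 3.0d)    // 0.04 <= 0.15, so the estimate passes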
- val colStats: Map[String, ColumnStats] = catalogTable.schema.collect { + val colStats: Map[String, ColumnStat] = catalogTable.schema.collect { case field if colStatsProps.contains(field.name) => - (field.name, ColumnStats(colStatsProps(field.name), field.dataType)) + (field.name, ColumnStat(field.dataType, colStatsProps(field.name))) }.toMap catalogTable.copy( properties = removeStatsProperties(catalogTable), From 08df66937d8834c8e0b7300beea1973705f852b7 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Mon, 26 Sep 2016 19:47:34 -0700 Subject: [PATCH 15/22] fix tests --- .../org/apache/spark/sql/catalyst/expressions/UnsafeRow.java | 1 + .../spark/sql/execution/command/AnalyzeColumnCommand.scala | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index dd2f39eb816f..48a802d6e9ab 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -31,6 +31,7 @@ import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; +import org.apache.spark.sql.catalyst.plans.logical.Except; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index d68a28cdf587..bbe6ad1e3578 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -169,7 +169,8 @@ object ColumnStatStruct { case BooleanType => booleanColumnStat(e).length } val struct = row.getStruct(offset, numFields) - if (numFields >= 3 && !struct.isNullAt(3)) { + // NumericType, TimestampType, DateType and StringType have ndv and its index is 3. 
+ if (numFields >= 4 && !struct.isNullAt(3)) { // ndv should not be larger than number of rows if (struct.getLong(3) > rowCount) struct.asInstanceOf[UnsafeRow].setLong(3, rowCount) } From 7cd8f144c3d64fc407f42b069bd7a53d70604974 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Mon, 26 Sep 2016 23:30:18 -0700 Subject: [PATCH 16/22] change ndv expression --- .../command/AnalyzeColumnCommand.scala | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index bbe6ad1e3578..7d5dec4cfa19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -116,16 +116,22 @@ object ColumnStatStruct { def numNulls(e: Expression): Expression = if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero def max(e: Expression): Expression = Max(e) def min(e: Expression): Expression = Min(e) - def ndv(e: Expression, relativeSD: Double): Expression = HyperLogLogPlusPlus(e, relativeSD) + def ndv(e: Expression, relativeSD: Double): Expression = { + val approxNdv = HyperLogLogPlusPlus(e, relativeSD) + // the approximate ndv should not be larger than the number of rows + If(LessThanOrEqual(approxNdv, Count(one)), approxNdv, Count(one)) + } def avgLength(e: Expression): Expression = Average(Length(e)) def maxLength(e: Expression): Expression = Max(Length(e)) def numTrues(e: Expression): Expression = Sum(If(e, one, zero)) def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero)) def getStruct(exprs: Seq[Expression]): CreateStruct = { - CreateStruct(exprs.map { - case af: AggregateFunction => af.toAggregateExpression() - case e: Expression => e + CreateStruct(exprs.map { expr: Expression => + expr.transformUp { + case af: AggregateFunction => af.toAggregateExpression() + case e: Expression => e + } }) } @@ -169,11 +175,6 @@ object ColumnStatStruct { case BooleanType => booleanColumnStat(e).length } val struct = row.getStruct(offset, numFields) - // NumericType, TimestampType, DateType and StringType have ndv and its index is 3. 
- if (numFields >= 4 && !struct.isNullAt(3)) { - // ndv should not be larger than number of rows - if (struct.getLong(3) > rowCount) struct.asInstanceOf[UnsafeRow].setLong(3, rowCount) - } ColumnStat(e.dataType, struct) } } From 377a4919f43f6da86f1dd7c0d121d0436d377431 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Tue, 27 Sep 2016 12:01:26 -0700 Subject: [PATCH 17/22] remove unnecessary dataType in ColumnStat and fix commets --- .../sql/catalyst/expressions/UnsafeRow.java | 1 - .../catalyst/plans/logical/Statistics.scala | 6 +- .../command/AnalyzeColumnCommand.scala | 19 +-- .../command/AnalyzeTableCommand.scala | 3 +- .../spark/sql/internal/SessionState.scala | 16 +- .../spark/sql/StatisticsColumnSuite.scala | 155 +++++++----------- .../org/apache/spark/sql/StatisticsTest.scala | 16 +- .../spark/sql/hive/HiveExternalCatalog.scala | 10 +- 8 files changed, 91 insertions(+), 135 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 48a802d6e9ab..dd2f39eb816f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -31,7 +31,6 @@ import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; -import org.apache.spark.sql.catalyst.plans.logical.Except; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 86d0236d85ae..ba62778119ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -61,7 +61,7 @@ case class Statistics( /** * Statistics for a column. */ -case class ColumnStat(dataType: DataType, statRow: InternalRow) { +case class ColumnStat(statRow: InternalRow) { def forNumeric[T <: AtomicType](dataType: T): NumericColumnStat[T] = { NumericColumnStat(statRow, dataType) @@ -77,9 +77,9 @@ case class ColumnStat(dataType: DataType, statRow: InternalRow) { } object ColumnStat { - def apply(dataType: DataType, str: String): ColumnStat = { + def apply(str: String): ColumnStat = { // use Base64 for decoding - ColumnStat(dataType, InternalRow(Base64.decodeBase64(str))) + ColumnStat(InternalRow(Base64.decodeBase64(str))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 7d5dec4cfa19..5967ba4832a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -31,8 +31,8 @@ import org.apache.spark.sql.types._ /** - * Analyzes the given columns of the given table in the current database to generate statistics, - * which will be used in query optimizations. + * Analyzes the given columns of the given table to generate statistics, which will be used in + * query optimizations. 
*/ case class AnalyzeColumnCommand( tableIdent: TableIdentifier, @@ -77,11 +77,9 @@ case class AnalyzeColumnCommand( // check correctness of column names val attributesToAnalyze = mutable.MutableList[Attribute]() - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val resolver = sparkSession.sessionState.conf.resolver columnNames.foreach { col => - val exprOption = relation.output.find { attr => - if (caseSensitive) attr.name == col else attr.name.equalsIgnoreCase(col) - } + val exprOption = relation.output.find(attr => resolver(attr.name, col)) val expr = exprOption.getOrElse(throw new AnalysisException(s"Invalid column name: $col.")) // do deduplication if (!attributesToAnalyze.contains(expr)) { @@ -117,9 +115,8 @@ object ColumnStatStruct { def max(e: Expression): Expression = Max(e) def min(e: Expression): Expression = Min(e) def ndv(e: Expression, relativeSD: Double): Expression = { - val approxNdv = HyperLogLogPlusPlus(e, relativeSD) - // the approximate ndv should not be larger than the number of rows - If(LessThanOrEqual(approxNdv, Count(one)), approxNdv, Count(one)) + // the approximate ndv should never be larger than the number of rows + Least(Seq(HyperLogLogPlusPlus(e, relativeSD), Count(one))) } def avgLength(e: Expression): Expression = Average(Length(e)) def maxLength(e: Expression): Expression = Max(Length(e)) @@ -130,7 +127,6 @@ object ColumnStatStruct { CreateStruct(exprs.map { expr: Expression => expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() - case e: Expression => e } }) } @@ -174,7 +170,6 @@ object ColumnStatStruct { case BinaryType => binaryColumnStat(e).length case BooleanType => booleanColumnStat(e).length } - val struct = row.getStruct(offset, numFields) - ColumnStat(e.dataType, struct) + ColumnStat(row.getStruct(offset, numFields)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 96c0d9d14770..7b0e49b665f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -32,8 +32,7 @@ import org.apache.spark.sql.internal.SessionState /** - * Analyzes the given table in the current database to generate statistics, which will be - * used in query optimizations. + * Analyzes the given table to generate statistics, which will be used in query optimizations. 
*/ case class AnalyzeTableCommand( tableIdent: TableIdentifier, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index ded8e16d52f0..9f7d0019c6b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, AnalyzeTableCommand} +import org.apache.spark.sql.execution.command.AnalyzeTableCommand import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager} import org.apache.spark.sql.util.ExecutionListenerManager @@ -187,18 +187,10 @@ private[sql] class SessionState(sparkSession: SparkSession) { } /** - * Analyzes the given table in the current database to generate table-level statistics, which - * will be used in query optimizations. - */ - def analyzeTable(tableIdent: TableIdentifier, noscan: Boolean = true): Unit = { - AnalyzeTableCommand(tableIdent, noscan).run(sparkSession) - } - - /** - * Analyzes the given columns in the table to generate column-level statistics, which will be + * Analyzes the given table in the current database to generate statistics, which will be * used in query optimizations. */ - def analyzeTableColumns(tableIdent: TableIdentifier, columnNames: Seq[String]): Unit = { - AnalyzeColumnCommand(tableIdent, columnNames).run(sparkSession) + def analyze(tableIdent: TableIdentifier, noscan: Boolean = true): Unit = { + AnalyzeTableCommand(tableIdent, noscan).run(sparkSession) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala index 5476a9f7f171..241a9fd87db7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -118,14 +118,12 @@ class StatisticsColumnSuite extends StatisticsTest { val df = data.toDF("c1", "c2", "c3", "c4") val nonNullValues = getNonNullValues[Int](values) val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat( - dataType = f.dataType, - statRow = InternalRow( - values.count(_.isEmpty).toLong, - nonNullValues.max, - nonNullValues.min, - nonNullValues.distinct.length.toLong)) - (f.name, colStat) + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.max, + nonNullValues.min, + nonNullValues.distinct.length.toLong)) + (f, colStat) } checkColStats(df, expectedColStatsSeq) } @@ -145,15 +143,15 @@ class StatisticsColumnSuite extends StatisticsTest { val expectedColStatsSeq = df.schema.map { f => val colStat = f.dataType match { case floatType: FloatType => - ColumnStat(floatType, InternalRow(numNulls, nonNullValues.max.toFloat, - nonNullValues.min.toFloat, ndv)) + ColumnStat(InternalRow(numNulls, nonNullValues.max.toFloat, nonNullValues.min.toFloat, + ndv)) case doubleType: DoubleType => - ColumnStat(doubleType, InternalRow(numNulls, nonNullValues.max.toDouble, - nonNullValues.min.toDouble, ndv)) + ColumnStat(InternalRow(numNulls, nonNullValues.max.toDouble, nonNullValues.min.toDouble, + ndv)) case 
decimalType: DecimalType => - ColumnStat(decimalType, InternalRow(numNulls, nonNullValues.max, nonNullValues.min, ndv)) + ColumnStat(InternalRow(numNulls, nonNullValues.max, nonNullValues.min, ndv)) } - (f.name, colStat) + (f, colStat) } checkColStats(df, expectedColStatsSeq) } @@ -163,14 +161,12 @@ class StatisticsColumnSuite extends StatisticsTest { val df = values.toDF("c1") val nonNullValues = getNonNullValues[String](values) val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat( - dataType = f.dataType, - statRow = InternalRow( - values.count(_.isEmpty).toLong, - nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, - nonNullValues.map(_.length).max.toLong, - nonNullValues.distinct.length.toLong)) - (f.name, colStat) + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, + nonNullValues.map(_.length).max.toLong, + nonNullValues.distinct.length.toLong)) + (f, colStat) } checkColStats(df, expectedColStatsSeq) } @@ -180,13 +176,11 @@ class StatisticsColumnSuite extends StatisticsTest { val df = values.toDF("c1") val nonNullValues = getNonNullValues[Array[Byte]](values) val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat( - dataType = f.dataType, - statRow = InternalRow( - values.count(_.isEmpty).toLong, - nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, - nonNullValues.map(_.length).max.toLong)) - (f.name, colStat) + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.map(_.length).sum / nonNullValues.length.toDouble, + nonNullValues.map(_.length).max.toLong)) + (f, colStat) } checkColStats(df, expectedColStatsSeq) } @@ -196,13 +190,11 @@ class StatisticsColumnSuite extends StatisticsTest { val df = values.toDF("c1") val nonNullValues = getNonNullValues[Boolean](values) val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat( - dataType = f.dataType, - statRow = InternalRow( - values.count(_.isEmpty).toLong, - nonNullValues.count(_.equals(true)).toLong, - nonNullValues.count(_.equals(false)).toLong)) - (f.name, colStat) + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + nonNullValues.count(_.equals(true)).toLong, + nonNullValues.count(_.equals(false)).toLong)) + (f, colStat) } checkColStats(df, expectedColStatsSeq) } @@ -212,15 +204,13 @@ class StatisticsColumnSuite extends StatisticsTest { val df = values.toDF("c1") val nonNullValues = getNonNullValues[Date](values) val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat( - dataType = f.dataType, - statRow = InternalRow( - values.count(_.isEmpty).toLong, - // Internally, DateType is represented as the number of days from 1970-01-01. - nonNullValues.map(DateTimeUtils.fromJavaDate).max, - nonNullValues.map(DateTimeUtils.fromJavaDate).min, - nonNullValues.distinct.length.toLong)) - (f.name, colStat) + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + // Internally, DateType is represented as the number of days from 1970-01-01. 
+ nonNullValues.map(DateTimeUtils.fromJavaDate).max, + nonNullValues.map(DateTimeUtils.fromJavaDate).min, + nonNullValues.distinct.length.toLong)) + (f, colStat) } checkColStats(df, expectedColStatsSeq) } @@ -232,15 +222,13 @@ class StatisticsColumnSuite extends StatisticsTest { val df = values.toDF("c1") val nonNullValues = getNonNullValues[Timestamp](values) val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat( - dataType = f.dataType, - statRow = InternalRow( - values.count(_.isEmpty).toLong, - // Internally, TimestampType is represented as the number of days from 1970-01-01 - nonNullValues.map(DateTimeUtils.fromJavaTimestamp).max, - nonNullValues.map(DateTimeUtils.fromJavaTimestamp).min, - nonNullValues.distinct.length.toLong)) - (f.name, colStat) + val colStat = ColumnStat(InternalRow( + values.count(_.isEmpty).toLong, + // Internally, TimestampType is represented as the number of days from 1970-01-01 + nonNullValues.map(DateTimeUtils.fromJavaTimestamp).max, + nonNullValues.map(DateTimeUtils.fromJavaTimestamp).min, + nonNullValues.distinct.length.toLong)) + (f, colStat) } checkColStats(df, expectedColStatsSeq) } @@ -252,10 +240,7 @@ class StatisticsColumnSuite extends StatisticsTest { } val df = data.toDF("c1", "c2") val expectedColStatsSeq = df.schema.map { f => - val colStat = ColumnStat( - dataType = f.dataType, - statRow = InternalRow(values.count(_.isEmpty).toLong, null, null, 0L)) - (f.name, colStat) + (f, ColumnStat(InternalRow(values.count(_.isEmpty).toLong, null, null, 0L))) } checkColStats(df, expectedColStatsSeq) } @@ -278,46 +263,30 @@ class StatisticsColumnSuite extends StatisticsTest { val expectedColStatsSeq = df.schema.map { f => val colStat = f.dataType match { case IntegerType => - ColumnStat( - dataType = f.dataType, - statRow = InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) + ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) case DoubleType => - ColumnStat( - dataType = f.dataType, - statRow = InternalRow(0L, doubleSeq.max, doubleSeq.min, + ColumnStat(InternalRow(0L, doubleSeq.max, doubleSeq.min, doubleSeq.distinct.length.toLong)) case StringType => - ColumnStat( - dataType = f.dataType, - statRow = InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, + ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, stringSeq.map(_.length).max.toLong, stringSeq.distinct.length.toLong)) case BinaryType => - ColumnStat( - dataType = f.dataType, - statRow = InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble, + ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble, binarySeq.map(_.length).max.toLong)) case BooleanType => - ColumnStat( - dataType = f.dataType, - statRow = InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, + ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, booleanSeq.count(_.equals(false)).toLong)) case DateType => - ColumnStat( - dataType = f.dataType, - statRow = InternalRow(0L, dateSeq.map(DateTimeUtils.fromJavaDate).max, + ColumnStat(InternalRow(0L, dateSeq.map(DateTimeUtils.fromJavaDate).max, dateSeq.map(DateTimeUtils.fromJavaDate).min, dateSeq.distinct.length.toLong)) case TimestampType => - ColumnStat( - dataType = f.dataType, - statRow = InternalRow(0L, timestampSeq.map(DateTimeUtils.fromJavaTimestamp).max, + ColumnStat(InternalRow(0L, timestampSeq.map(DateTimeUtils.fromJavaTimestamp).max, timestampSeq.map(DateTimeUtils.fromJavaTimestamp).min, 
timestampSeq.distinct.length.toLong)) case LongType => - ColumnStat( - dataType = f.dataType, - statRow = InternalRow(0L, longSeq.max, longSeq.min, longSeq.distinct.length.toLong)) + ColumnStat(InternalRow(0L, longSeq.max, longSeq.min, longSeq.distinct.length.toLong)) } - (f.name, colStat) + (f, colStat) } checkColStats(df, expectedColStatsSeq) } @@ -342,9 +311,8 @@ class StatisticsColumnSuite extends StatisticsTest { checkTableStats(tableName = table, expectedRowCount = Some(values.length * 2)) val colStat = fetchedStats.get.colStats("c1") - checkColStat(colStat = colStat, expectedColStat = ColumnStat( - dataType = IntegerType, - statRow = InternalRow.fromSeq( + checkColStat(dataType = IntegerType, colStat = colStat, expectedColStat = + ColumnStat(InternalRow.fromSeq( Seq(0L, values.max, values.min, values.distinct.length.toLong)))) } } @@ -356,18 +324,19 @@ class StatisticsColumnSuite extends StatisticsTest { sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") val fetchedStats1 = checkTableStats(tableName = table, expectedRowCount = Some(0)) assert(fetchedStats1.get.colStats.size == 1) - val expected1 = ColumnStat( - dataType = IntegerType, statRow = InternalRow(0L, null, null, 0L)) - checkColStat(colStat = fetchedStats1.get.colStats("c1"), expectedColStat = expected1) + val expected1 = ColumnStat(InternalRow(0L, null, null, 0L)) + checkColStat(dataType = IntegerType, colStat = fetchedStats1.get.colStats("c1"), + expectedColStat = expected1) sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") val fetchedStats2 = checkTableStats(tableName = table, expectedRowCount = Some(0)) // column c1 is kept in the stats assert(fetchedStats2.get.colStats.size == 2) - checkColStat(colStat = fetchedStats2.get.colStats("c1"), expectedColStat = expected1) - val expected2 = ColumnStat( - dataType = LongType, statRow = InternalRow(0L, null, null, 0L)) - checkColStat(colStat = fetchedStats2.get.colStats("c2"), expectedColStat = expected2) + checkColStat(dataType = IntegerType, colStat = fetchedStats2.get.colStats("c1"), + expectedColStat = expected1) + val expected2 = ColumnStat(InternalRow(0L, null, null, 0L)) + checkColStat(dataType = LongType, colStat = fetchedStats2.get.colStats("c2"), + expectedColStat = expected2) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala index c167d3decd49..787aff4a2fd3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala @@ -28,7 +28,7 @@ trait StatisticsTest extends QueryTest with SharedSQLContext { def checkColStats( df: DataFrame, - expectedColStatsSeq: Seq[(String, ColumnStat)]): Unit = { + expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = { val table = "tbl" withTable(table) { df.write.format("json").saveAsTable(table) @@ -36,17 +36,19 @@ trait StatisticsTest extends QueryTest with SharedSQLContext { val tableIdent = TableIdentifier(table, Some("default")) val relation = spark.sessionState.catalog.lookupRelation(tableIdent) val columnStats = - AnalyzeColumnCommand(tableIdent, columns).computeColStats(spark, relation)._2 + AnalyzeColumnCommand(tableIdent, columns.map(_.name)).computeColStats(spark, relation)._2 expectedColStatsSeq.foreach { expected => - assert(columnStats.contains(expected._1)) - checkColStat(colStat = columnStats(expected._1), expectedColStat = expected._2) + assert(columnStats.contains(expected._1.name)) + checkColStat( + 
dataType = expected._1.dataType, + colStat = columnStats(expected._1.name), + expectedColStat = expected._2) } } } - def checkColStat(colStat: ColumnStat, expectedColStat: ColumnStat): Unit = { - assert(colStat.dataType == expectedColStat.dataType) - colStat.dataType match { + def checkColStat(dataType: DataType, colStat: ColumnStat, expectedColStat: ColumnStat): Unit = { + dataType match { case StringType => val cs = colStat.forString val expectedCS = expectedColStat.forString diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index ea6355c868ff..ec550f09f135 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -404,7 +404,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() } stats.colStats.foreach { case (colName, colStat) => - statsProperties += (STATISTICS_BASIC_COL_STATS_PREFIX + colName) -> colStat.toString + statsProperties += (STATISTICS_COL_STATS_PREFIX + colName) -> colStat.toString } tableDefinition.copy(properties = tableDefinition.properties ++ statsProperties) } else { @@ -478,11 +478,11 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // construct Spark's statistics from information in Hive metastore if (catalogTable.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)).nonEmpty) { val colStatsProps = catalogTable.properties - .filterKeys(_.startsWith(STATISTICS_BASIC_COL_STATS_PREFIX)) - .map { case (k, v) => (k.replace(STATISTICS_BASIC_COL_STATS_PREFIX, ""), v) } + .filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX)) + .map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) } val colStats: Map[String, ColumnStat] = catalogTable.schema.collect { case field if colStatsProps.contains(field.name) => - (field.name, ColumnStat(field.dataType, colStatsProps(field.name))) + (field.name, ColumnStat(colStatsProps(field.name))) }.toMap catalogTable.copy( properties = removeStatsProperties(catalogTable), @@ -701,7 +701,7 @@ object HiveExternalCatalog { val STATISTICS_PREFIX = "spark.sql.statistics." val STATISTICS_TOTAL_SIZE = STATISTICS_PREFIX + "totalSize" val STATISTICS_NUM_ROWS = STATISTICS_PREFIX + "numRows" - val STATISTICS_BASIC_COL_STATS_PREFIX = STATISTICS_PREFIX + "colStats." + val STATISTICS_COL_STATS_PREFIX = STATISTICS_PREFIX + "colStats." 
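The constants above determine how column-level statistics are keyed in the metastore table properties: the shared statistics prefix, then "colStats.", then the column name. A minimal sketch of building and stripping such a key (the column name c1 is a made-up example, not part of the patch):

    val STATISTICS_PREFIX = "spark.sql.statistics."
    val STATISTICS_COL_STATS_PREFIX = STATISTICS_PREFIX + "colStats."
    val key = STATISTICS_COL_STATS_PREFIX + "c1"                // spark.sql.statistics.colStats.c1
    val colName = key.drop(STATISTICS_COL_STATS_PREFIX.length)  // recovers "c1" when stats are read back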
def removeStatsProperties(metadata: CatalogTable): Map[String, String] = { metadata.properties.filterNot { case (key, _) => key.startsWith(STATISTICS_PREFIX) } From 1d3dd621531185e34fc5ad91e7e85337009ab286 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Wed, 28 Sep 2016 12:48:44 -0700 Subject: [PATCH 18/22] encoding and decoding for ColumnStat --- .../catalyst/plans/logical/Statistics.scala | 11 +++- .../spark/sql/StatisticsColumnSuite.scala | 31 +++++++--- .../org/apache/spark/sql/StatisticsTest.scala | 57 ++++++++++++------- .../spark/sql/hive/HiveExternalCatalog.scala | 2 +- .../spark/sql/hive/StatisticsSuite.scala | 55 ++++++++++++++++-- 5 files changed, 120 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index ba62778119ed..4445ea2e1ed9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -77,9 +77,16 @@ case class ColumnStat(statRow: InternalRow) { } object ColumnStat { - def apply(str: String): ColumnStat = { + def apply(dataType: DataType, str: String): ColumnStat = { // use Base64 for decoding - ColumnStat(InternalRow(Base64.decodeBase64(str))) + val bytes = Base64.decodeBase64(str) + val numFields = dataType match { + case BinaryType | BooleanType => 3 + case _ => 4 + } + val unsafeRow = new UnsafeRow(numFields) + unsafeRow.pointTo(bytes, bytes.length) + ColumnStat(unsafeRow) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala index 241a9fd87db7..73fd01b10761 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -311,9 +311,12 @@ class StatisticsColumnSuite extends StatisticsTest { checkTableStats(tableName = table, expectedRowCount = Some(values.length * 2)) val colStat = fetchedStats.get.colStats("c1") - checkColStat(dataType = IntegerType, colStat = colStat, expectedColStat = - ColumnStat(InternalRow.fromSeq( - Seq(0L, values.max, values.min, values.distinct.length.toLong)))) + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = colStat, + expectedColStat = ColumnStat(InternalRow.fromSeq( + Seq(0L, values.max, values.min, values.distinct.length.toLong))), + rsd = spark.sessionState.conf.ndvMaxError) } } @@ -325,18 +328,28 @@ class StatisticsColumnSuite extends StatisticsTest { val fetchedStats1 = checkTableStats(tableName = table, expectedRowCount = Some(0)) assert(fetchedStats1.get.colStats.size == 1) val expected1 = ColumnStat(InternalRow(0L, null, null, 0L)) - checkColStat(dataType = IntegerType, colStat = fetchedStats1.get.colStats("c1"), - expectedColStat = expected1) + val rsd = spark.sessionState.conf.ndvMaxError + StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = fetchedStats1.get.colStats("c1"), + expectedColStat = expected1, + rsd = rsd) sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") val fetchedStats2 = checkTableStats(tableName = table, expectedRowCount = Some(0)) // column c1 is kept in the stats assert(fetchedStats2.get.colStats.size == 2) - checkColStat(dataType = IntegerType, colStat = fetchedStats2.get.colStats("c1"), - expectedColStat = expected1) + 
StatisticsTest.checkColStat( + dataType = IntegerType, + colStat = fetchedStats2.get.colStats("c1"), + expectedColStat = expected1, + rsd = rsd) val expected2 = ColumnStat(InternalRow(0L, null, null, 0L)) - checkColStat(dataType = LongType, colStat = fetchedStats2.get.colStats("c2"), - expectedColStat = expected2) + StatisticsTest.checkColStat( + dataType = LongType, + colStat = fetchedStats2.get.colStats("c2"), + expectedColStat = expected2, + rsd = rsd) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala index 787aff4a2fd3..79a458183a62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala @@ -39,15 +39,42 @@ trait StatisticsTest extends QueryTest with SharedSQLContext { AnalyzeColumnCommand(tableIdent, columns.map(_.name)).computeColStats(spark, relation)._2 expectedColStatsSeq.foreach { expected => assert(columnStats.contains(expected._1.name)) - checkColStat( + val colStat = columnStats(expected._1.name) + StatisticsTest.checkColStat( dataType = expected._1.dataType, - colStat = columnStats(expected._1.name), - expectedColStat = expected._2) + colStat = colStat, + expectedColStat = expected._2, + rsd = spark.sessionState.conf.ndvMaxError) + + // check if we get the same colStat after encoding and decoding + val encodedCS = colStat.toString + val decodedCS = ColumnStat(expected._1.dataType, encodedCS) + StatisticsTest.checkColStat( + dataType = expected._1.dataType, + colStat = decodedCS, + expectedColStat = expected._2, + rsd = spark.sessionState.conf.ndvMaxError) } } } - def checkColStat(dataType: DataType, colStat: ColumnStat, expectedColStat: ColumnStat): Unit = { + def checkTableStats(tableName: String, expectedRowCount: Option[Int]): Option[Statistics] = { + val df = spark.table(tableName) + val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation => + assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) + rel.catalogTable.get.stats + } + assert(stats.size == 1) + stats.head + } +} + +object StatisticsTest { + def checkColStat( + dataType: DataType, + colStat: ColumnStat, + expectedColStat: ColumnStat, + rsd: Double): Unit = { dataType match { case StringType => val cs = colStat.forString @@ -55,7 +82,7 @@ trait StatisticsTest extends QueryTest with SharedSQLContext { assert(cs.numNulls == expectedCS.numNulls) assert(cs.avgColLen == expectedCS.avgColLen) assert(cs.maxColLen == expectedCS.maxColLen) - checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv) + checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd) case BinaryType => val cs = colStat.forBinary val expectedCS = expectedColStat.forBinary @@ -70,42 +97,32 @@ trait StatisticsTest extends QueryTest with SharedSQLContext { assert(cs.numFalses == expectedCS.numFalses) case atomicType: AtomicType => checkNumericColStats( - dataType = atomicType, colStat = colStat, expectedColStat = expectedColStat) + dataType = atomicType, colStat = colStat, expectedColStat = expectedColStat, rsd = rsd) } } private def checkNumericColStats( dataType: AtomicType, colStat: ColumnStat, - expectedColStat: ColumnStat): Unit = { + expectedColStat: ColumnStat, + rsd: Double): Unit = { val cs = colStat.forNumeric(dataType) val expectedCS = expectedColStat.forNumeric(dataType) assert(cs.numNulls == expectedCS.numNulls) assert(cs.max == expectedCS.max) assert(cs.min == expectedCS.min) - checkNdv(ndv = cs.ndv, 
expectedNdv = expectedCS.ndv) + checkNdv(ndv = cs.ndv, expectedNdv = expectedCS.ndv, rsd = rsd) } - private def checkNdv(ndv: Long, expectedNdv: Long): Unit = { + private def checkNdv(ndv: Long, expectedNdv: Long, rsd: Double): Unit = { // ndv is an approximate value, so we make sure we have the value, and it should be // within 3*SD's of the given rsd. if (expectedNdv == 0) { assert(ndv == 0) } else if (expectedNdv > 0) { assert(ndv > 0) - val rsd = spark.sessionState.conf.ndvMaxError val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d) assert(error <= rsd * 3.0d, "Error should be within 3 std. errors.") } } - - def checkTableStats(tableName: String, expectedRowCount: Option[Int]): Option[Statistics] = { - val df = spark.table(tableName) - val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation => - assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) - rel.catalogTable.get.stats - } - assert(stats.size == 1) - stats.head - } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index ec550f09f135..b483c15267ce 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -482,7 +482,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat .map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) } val colStats: Map[String, ColumnStat] = catalogTable.schema.collect { case field if colStatsProps.contains(field.name) => - (field.name, ColumnStat(colStatsProps(field.name))) + (field.name, ColumnStat(field.dataType, colStatsProps(field.name))) }.toMap catalogTable.copy( properties = removeStatsProperties(catalogTable), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index fbe2a4336807..4fd7343f2821 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -21,16 +21,16 @@ import java.io.{File, PrintWriter} import scala.reflect.ClassTag -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.{AnalysisException, QueryTest, Row, StatisticsTest} +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { @@ -358,6 +358,53 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils } } + test("generate column-level statistics and load them from hive metastore") { + import testImplicits._ + + val intSeq = Seq(1, 2) + val stringSeq = Seq("a", "bb") + val booleanSeq = Seq(true, false) + + val data = intSeq.indices.map { i 
=> + (intSeq(i), stringSeq(i), booleanSeq(i)) + } + val tableName = "table" + withTable(tableName) { + val df = data.toDF("c1", "c2", "c3") + df.write.format("parquet").saveAsTable(tableName) + val expectedColStatsSeq = df.schema.map { f => + val colStat = f.dataType match { + case IntegerType => + ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong)) + case StringType => + ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble, + stringSeq.map(_.length).max.toLong, stringSeq.distinct.length.toLong)) + case BooleanType => + ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong, + booleanSeq.count(_.equals(false)).toLong)) + } + (f, colStat) + } + + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1, c2, c3") + val readback = spark.table(tableName) + val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation => + val columnStats = rel.catalogTable.get.stats.get.colStats + expectedColStatsSeq.foreach { expected => + assert(columnStats.contains(expected._1.name)) + val colStat = columnStats(expected._1.name) + StatisticsTest.checkColStat( + dataType = expected._1.dataType, + colStat = colStat, + expectedColStat = expected._2, + rsd = spark.sessionState.conf.ndvMaxError) + } + rel + } + assert(relations.size == 1) + } + } + test("estimates the size of a test MetastoreRelation") { val df = sql("""SELECT * FROM src""") val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation => From 3335af65c2b53dcfa96bf1f5602043cf3d11e3b1 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Thu, 29 Sep 2016 15:06:43 -0700 Subject: [PATCH 19/22] fix comments --- .../catalyst/plans/logical/Statistics.scala | 6 +---- .../command/AnalyzeColumnCommand.scala | 27 +++++++------------ .../spark/sql/StatisticsColumnSuite.scala | 3 +-- .../org/apache/spark/sql/StatisticsTest.scala | 5 ++-- .../spark/sql/hive/HiveExternalCatalog.scala | 7 ++--- 5 files changed, 19 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 4445ea2e1ed9..43455c989c0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -77,13 +77,9 @@ case class ColumnStat(statRow: InternalRow) { } object ColumnStat { - def apply(dataType: DataType, str: String): ColumnStat = { + def apply(numFields: Int, str: String): ColumnStat = { // use Base64 for decoding val bytes = Base64.decodeBase64(str) - val numFields = dataType match { - case BinaryType | BooleanType => 3 - case _ => 4 - } val unsafeRow = new UnsafeRow(numFields) unsafeRow.pointTo(bytes, bytes.length) ColumnStat(unsafeRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 5967ba4832a1..7a85e2b63cd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command import scala.collection.mutable import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import 
org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} import org.apache.spark.sql.catalyst.expressions._ @@ -101,7 +101,8 @@ case class AnalyzeColumnCommand( // unwrap the result val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => - (expr.name, ColumnStatStruct.unwrapStruct(statsRow, i + 1, expr, ndvMaxErr, rowCount)) + val numFields = ColumnStatStruct.numStatFields(expr.dataType) + (expr.name, ColumnStat(statsRow.getStruct(i + 1, numFields))) }.toMap (rowCount, columnStats) } @@ -147,6 +148,13 @@ object ColumnStatStruct { Seq(numNulls(e), numTrues(e), numFalses(e)) } + def numStatFields(dataType: DataType): Int = { + dataType match { + case BinaryType | BooleanType => 3 + case _ => 4 + } + } + def apply(e: Attribute, relativeSD: Double): CreateStruct = e.dataType match { // Use aggregate functions to compute statistics we need. case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(e, relativeSD)) @@ -157,19 +165,4 @@ object ColumnStatStruct { throw new AnalysisException("Analyzing columns is not supported for column " + s"${e.name} of data type: ${e.dataType}.") } - - def unwrapStruct( - row: InternalRow, - offset: Int, - e: Expression, - relativeSD: Double, - rowCount: Long): ColumnStat = { - val numFields = e.dataType match { - case _: NumericType | TimestampType | DateType => numericColumnStat(e, relativeSD).length - case StringType => stringColumnStat(e, relativeSD).length - case BinaryType => binaryColumnStat(e).length - case BooleanType => booleanColumnStat(e).length - } - ColumnStat(row.getStruct(offset, numFields)) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala index 73fd01b10761..69792790ec58 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -72,9 +72,8 @@ class StatisticsColumnSuite extends StatisticsTest { val table = "tbl" val colName1 = "abc" val colName2 = "x.yz" - val quotedColName2 = s"`$colName2`" withTable(table) { - sql(s"CREATE TABLE $table ($colName1 int, $quotedColName2 string) USING PARQUET") + sql(s"CREATE TABLE $table ($colName1 int, `$colName2` string) USING PARQUET") val invalidColError = intercept[AnalysisException] { sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala index 79a458183a62..b60de8b54791 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} -import org.apache.spark.sql.execution.command.AnalyzeColumnCommand +import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, ColumnStatStruct} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -48,7 +48,8 @@ trait StatisticsTest extends QueryTest with SharedSQLContext { // check if we get the same colStat after encoding 
and decoding val encodedCS = colStat.toString - val decodedCS = ColumnStat(expected._1.dataType, encodedCS) + val numFields = ColumnStatStruct.numStatFields(expected._1.dataType) + val decodedCS = ColumnStat(numFields, encodedCS) StatisticsTest.checkColStat( dataType = expected._1.dataType, colStat = decodedCS, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index b483c15267ce..84bd66e9ad65 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics} -import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.command.{ColumnStatStruct, DDLUtils} import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe @@ -481,8 +481,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat .filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX)) .map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) } val colStats: Map[String, ColumnStat] = catalogTable.schema.collect { - case field if colStatsProps.contains(field.name) => - (field.name, ColumnStat(field.dataType, colStatsProps(field.name))) + case f if colStatsProps.contains(f.name) => + val numFields = ColumnStatStruct.numStatFields(f.dataType) + (f.name, ColumnStat(numFields, colStatsProps(f.name))) }.toMap catalogTable.copy( properties = removeStatsProperties(catalogTable), From 06819ddbfae6fad8056ca6e8293692c22856fb95 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Thu, 29 Sep 2016 20:57:11 -0700 Subject: [PATCH 20/22] fix test style --- .../spark/sql/StatisticsColumnSuite.scala | 29 +++++-------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala index 69792790ec58..0673b646cf77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -31,30 +31,17 @@ class StatisticsColumnSuite extends StatisticsTest { import testImplicits._ test("parse analyze column commands") { - def assertAnalyzeColumnCommand(analyzeCommand: String, c: Class[_]) { - val parsed = spark.sessionState.sqlParser.parsePlan(analyzeCommand) - val operators = parsed.collect { - case a: AnalyzeColumnCommand => a - case o => o - } - assert(operators.size == 1) - if (operators.head.getClass != c) { - fail( - s"""$analyzeCommand expected command: $c, but got ${operators.head} - |parsed command: - |$parsed - """.stripMargin) - } - } - - val table = "table" - assertAnalyzeColumnCommand( - s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS key, value", - classOf[AnalyzeColumnCommand]) + val tableName = "tbl" + // we need to specify column names intercept[ParseException] { - sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS") + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS") } + + val analyzeSql = s"ANALYZE TABLE $tableName 
COMPUTE STATISTICS FOR COLUMNS key, value" + val parsed = spark.sessionState.sqlParser.parsePlan(analyzeSql) + val expected = AnalyzeColumnCommand(TableIdentifier(tableName), Seq("key", "value")) + comparePlans(parsed, expected) } test("analyzing columns of non-atomic types is not supported") { From 95c2d2fab6f24d433aedf6d4a3d28de69d6c2889 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Fri, 30 Sep 2016 10:16:24 -0700 Subject: [PATCH 21/22] add logWarning for duplicated columns --- .../spark/sql/execution/command/AnalyzeColumnCommand.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 7a85e2b63cd6..76daaaeda622 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -84,6 +84,8 @@ case class AnalyzeColumnCommand( // do deduplication if (!attributesToAnalyze.contains(expr)) { attributesToAnalyze += expr + } else { + logWarning(s"Duplicated column: $col") } } From 734abad045a5378d14489a4e956b7a8e1c95a811 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Sat, 1 Oct 2016 10:59:40 -0700 Subject: [PATCH 22/22] more comments --- .../command/AnalyzeColumnCommand.scala | 7 +++++- .../spark/sql/StatisticsColumnSuite.scala | 23 +++++++------------ .../org/apache/spark/sql/StatisticsTest.scala | 20 ++++++++-------- .../spark/sql/hive/HiveExternalCatalog.scala | 6 ++--- .../spark/sql/hive/StatisticsSuite.scala | 10 ++++---- 5 files changed, 32 insertions(+), 34 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 76daaaeda622..706637827997 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -77,6 +77,7 @@ case class AnalyzeColumnCommand( // check correctness of column names val attributesToAnalyze = mutable.MutableList[Attribute]() + val duplicatedColumns = mutable.MutableList[String]() val resolver = sparkSession.sessionState.conf.resolver columnNames.foreach { col => val exprOption = relation.output.find(attr => resolver(attr.name, col)) @@ -85,9 +86,13 @@ case class AnalyzeColumnCommand( if (!attributesToAnalyze.contains(expr)) { attributesToAnalyze += expr } else { - logWarning(s"Duplicated column: $col") + duplicatedColumns += col } } + if (duplicatedColumns.nonEmpty) { + logWarning(s"Duplicated columns ${duplicatedColumns.mkString("(", ", ", ")")} detected " + + s"when analyzing columns ${columnNames.mkString("(", ", ", ")")}, ignoring them.") + } // Collect statistics per column. 
// The first element in the result will be the overall row count, the following elements diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala index 0673b646cf77..0ee0547c4559 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsColumnSuite.scala @@ -78,8 +78,8 @@ class StatisticsColumnSuite extends StatisticsTest { val columnsToAnalyze = Seq(colName2.toUpperCase, colName1, colName2) val tableIdent = TableIdentifier(table, Some("default")) val relation = spark.sessionState.catalog.lookupRelation(tableIdent) - val columnStats = - AnalyzeColumnCommand(tableIdent, columnsToAnalyze).computeColStats(spark, relation)._2 + val (_, columnStats) = + AnalyzeColumnCommand(tableIdent, columnsToAnalyze).computeColStats(spark, relation) assert(columnStats.contains(colName1)) assert(columnStats.contains(colName2)) // check deduplication @@ -279,29 +279,22 @@ class StatisticsColumnSuite extends StatisticsTest { test("update table-level stats while collecting column-level stats") { val table = "tbl" - val tmpTable = "tmp" - withTable(table, tmpTable) { - val values = Seq(1) - val df = values.toDF("c1") - df.write.format("json").saveAsTable(tmpTable) - + withTable(table) { sql(s"CREATE TABLE $table (c1 int) USING PARQUET") - sql(s"INSERT INTO $table SELECT * FROM $tmpTable") + sql(s"INSERT INTO $table SELECT 1") sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") - checkTableStats(tableName = table, expectedRowCount = Some(values.length)) + checkTableStats(tableName = table, expectedRowCount = Some(1)) // update table-level stats between analyze table and analyze column commands - sql(s"INSERT INTO $table SELECT * FROM $tmpTable") + sql(s"INSERT INTO $table SELECT 1") sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") - val fetchedStats = - checkTableStats(tableName = table, expectedRowCount = Some(values.length * 2)) + val fetchedStats = checkTableStats(tableName = table, expectedRowCount = Some(2)) val colStat = fetchedStats.get.colStats("c1") StatisticsTest.checkColStat( dataType = IntegerType, colStat = colStat, - expectedColStat = ColumnStat(InternalRow.fromSeq( - Seq(0L, values.max, values.min, values.distinct.length.toLong))), + expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)), rsd = spark.sessionState.conf.ndvMaxError) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala index b60de8b54791..5134ac0e7e5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsTest.scala @@ -35,25 +35,25 @@ trait StatisticsTest extends QueryTest with SharedSQLContext { val columns = expectedColStatsSeq.map(_._1) val tableIdent = TableIdentifier(table, Some("default")) val relation = spark.sessionState.catalog.lookupRelation(tableIdent) - val columnStats = - AnalyzeColumnCommand(tableIdent, columns.map(_.name)).computeColStats(spark, relation)._2 - expectedColStatsSeq.foreach { expected => - assert(columnStats.contains(expected._1.name)) - val colStat = columnStats(expected._1.name) + val (_, columnStats) = + AnalyzeColumnCommand(tableIdent, columns.map(_.name)).computeColStats(spark, relation) + expectedColStatsSeq.foreach { case (field, expectedColStat) => + assert(columnStats.contains(field.name)) + val colStat = 
columnStats(field.name) StatisticsTest.checkColStat( - dataType = expected._1.dataType, + dataType = field.dataType, colStat = colStat, - expectedColStat = expected._2, + expectedColStat = expectedColStat, rsd = spark.sessionState.conf.ndvMaxError) // check if we get the same colStat after encoding and decoding val encodedCS = colStat.toString - val numFields = ColumnStatStruct.numStatFields(expected._1.dataType) + val numFields = ColumnStatStruct.numStatFields(field.dataType) val decodedCS = ColumnStat(numFields, encodedCS) StatisticsTest.checkColStat( - dataType = expected._1.dataType, + dataType = field.dataType, colStat = decodedCS, - expectedColStat = expected._2, + expectedColStat = expectedColStat, rsd = spark.sessionState.conf.ndvMaxError) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 84bd66e9ad65..261cc6feff09 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -476,9 +476,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } // construct Spark's statistics from information in Hive metastore - if (catalogTable.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)).nonEmpty) { - val colStatsProps = catalogTable.properties - .filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX)) + val statsProps = catalogTable.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) + if (statsProps.nonEmpty) { + val colStatsProps = statsProps.filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX)) .map { case (k, v) => (k.drop(STATISTICS_COL_STATS_PREFIX.length), v) } val colStats: Map[String, ColumnStat] = catalogTable.schema.collect { case f if colStatsProps.contains(f.name) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 4fd7343f2821..99dd080683d4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -390,13 +390,13 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils val readback = spark.table(tableName) val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation => val columnStats = rel.catalogTable.get.stats.get.colStats - expectedColStatsSeq.foreach { expected => - assert(columnStats.contains(expected._1.name)) - val colStat = columnStats(expected._1.name) + expectedColStatsSeq.foreach { case (field, expectedColStat) => + assert(columnStats.contains(field.name)) + val colStat = columnStats(field.name) StatisticsTest.checkColStat( - dataType = expected._1.dataType, + dataType = field.dataType, colStat = colStat, - expectedColStat = expected._2, + expectedColStat = expectedColStat, rsd = spark.sessionState.conf.ndvMaxError) } rel