@@ -59,10 +59,12 @@ case class AnalyzeColumnCommand(

def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = {
val (rowCount, columnStats) = computeColStats(sparkSession, relation)
// We also update table-level stats in order to keep them consistent with column-level stats.
val statistics = Statistics(
sizeInBytes = newTotalSize,
rowCount = Some(rowCount),
colStats = columnStats ++ catalogTable.stats.map(_.colStats).getOrElse(Map()))
// Newly computed column stats should override the existing ones.
colStats = catalogTable.stats.map(_.colStats).getOrElse(Map()) ++ columnStats)
sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics)))
// Refresh the cached data source table in the catalog.
sessionState.catalog.refreshTable(tableIdentWithDB)
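The one-line reorder above works because Scala's `Map ++` is right-biased: on duplicate keys, entries from the right-hand operand win. A minimal standalone sketch (plain Scala, not part of the diff) of the two operand orders:

    val existing   = Map("key" -> "old stat")   // stats already stored in the catalog
    val recomputed = Map("key" -> "new stat")   // stats just computed by ANALYZE
    existing ++ recomputed  // Map(key -> new stat): newly computed stats override
    recomputed ++ existing  // Map(key -> old stat): the pre-fix operand order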
@@ -90,8 +92,9 @@ case class AnalyzeColumnCommand(
}
}
if (duplicatedColumns.nonEmpty) {
logWarning(s"Duplicated columns ${duplicatedColumns.mkString("(", ", ", ")")} detected " +
s"when analyzing columns ${columnNames.mkString("(", ", ", ")")}, ignoring them.")
logWarning("Duplicate column names were deduplicated in `ANALYZE TABLE` statement. " +
s"Input columns: ${columnNames.mkString("(", ", ", ")")}. " +
s"Duplicate columns: ${duplicatedColumns.mkString("(", ", ", ")")}.")
}

// Collect statistics per column.
@@ -116,42 +119,44 @@
}

object ColumnStatStruct {
val zero = Literal(0, LongType)
val one = Literal(1, LongType)
private val zero = Literal(0, LongType)
private val one = Literal(1, LongType)

def numNulls(e: Expression): Expression = if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero
def max(e: Expression): Expression = Max(e)
def min(e: Expression): Expression = Min(e)
def ndv(e: Expression, relativeSD: Double): Expression = {
private def numNulls(e: Expression): Expression = {
if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero
}
private def max(e: Expression): Expression = Max(e)
private def min(e: Expression): Expression = Min(e)
private def ndv(e: Expression, relativeSD: Double): Expression = {
// the approximate ndv should never be larger than the number of rows
Least(Seq(HyperLogLogPlusPlus(e, relativeSD), Count(one)))
}
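As an aside (values hypothetical, not Spark code), the `Least` above simply clamps the HyperLogLog++ estimate, which can overshoot, to the exact row count:

    val hllEstimate = 105L                      // hypothetical approximate distinct count
    val rowCount    = 100L                      // exact number of rows in the column
    val ndv = math.min(hllEstimate, rowCount)   // 100L: the reported ndv never exceeds the row count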
def avgLength(e: Expression): Expression = Average(Length(e))
def maxLength(e: Expression): Expression = Max(Length(e))
def numTrues(e: Expression): Expression = Sum(If(e, one, zero))
def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero))
private def avgLength(e: Expression): Expression = Average(Length(e))
private def maxLength(e: Expression): Expression = Max(Length(e))
private def numTrues(e: Expression): Expression = Sum(If(e, one, zero))
private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero))

def getStruct(exprs: Seq[Expression]): CreateStruct = {
private def getStruct(exprs: Seq[Expression]): CreateStruct = {
CreateStruct(exprs.map { expr: Expression =>
expr.transformUp {
case af: AggregateFunction => af.toAggregateExpression()
}
})
}

def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
private def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
Seq(numNulls(e), max(e), min(e), ndv(e, relativeSD))
}

def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
private def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
Seq(numNulls(e), avgLength(e), maxLength(e), ndv(e, relativeSD))
}

def binaryColumnStat(e: Expression): Seq[Expression] = {
private def binaryColumnStat(e: Expression): Seq[Expression] = {
Seq(numNulls(e), avgLength(e), maxLength(e))
}

def booleanColumnStat(e: Expression): Seq[Expression] = {
private def booleanColumnStat(e: Expression): Seq[Expression] = {
Seq(numNulls(e), numTrues(e), numFalses(e))
}

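For cross-checking against the tests further down, a hedged sketch (plain Scala, with hypothetical two-row columns) of the values the numeric and string stat sequences above produce, in the same field order:

    val intValues = Seq(1, 2)
    val numericStats = (0L, intValues.max, intValues.min, intValues.distinct.length.toLong)
    // (numNulls = 0, max = 2, min = 1, ndv = 2)

    val stringValues = Seq("a", "bb")
    val stringStats = (
      0L,                                                              // numNulls
      stringValues.map(_.length).sum / stringValues.length.toDouble,   // avgLength = 1.5
      stringValues.map(_.length).max,                                  // maxLength = 2
      stringValues.distinct.length.toLong)                             // ndv = 2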
Expand All @@ -162,14 +167,14 @@ object ColumnStatStruct {
}
}

def apply(e: Attribute, relativeSD: Double): CreateStruct = e.dataType match {
def apply(attr: Attribute, relativeSD: Double): CreateStruct = attr.dataType match {
// Use aggregate functions to compute statistics we need.
case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(e, relativeSD))
case StringType => getStruct(stringColumnStat(e, relativeSD))
case BinaryType => getStruct(binaryColumnStat(e))
case BooleanType => getStruct(booleanColumnStat(e))
case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD))
case StringType => getStruct(stringColumnStat(attr, relativeSD))
case BinaryType => getStruct(binaryColumnStat(attr))
case BooleanType => getStruct(booleanColumnStat(attr))
case otherType =>
throw new AnalysisException("Analyzing columns is not supported for column " +
s"${e.name} of data type: ${e.dataType}.")
s"${attr.name} of data type: ${attr.dataType}.")
}
}
@@ -578,7 +578,8 @@ object SQLConf {
val NDV_MAX_ERROR =
SQLConfigBuilder("spark.sql.statistics.ndv.maxError")
.internal()
.doc("The maximum estimation error allowed in HyperLogLog++ algorithm.")
.doc("The maximum estimation error allowed in HyperLogLog++ algorithm when generating " +
"column level statistics.")
.doubleConf
.createWithDefault(0.05)

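A usage sketch (not from the diff; the table and column names are hypothetical, and the option is marked internal above): the error bound can be tightened before computing column-level statistics.

    spark.conf.set("spark.sql.statistics.ndv.maxError", "0.01")           // default is 0.05
    spark.sql("ANALYZE TABLE my_table COMPUTE STATISTICS FOR COLUMNS key")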
198 changes: 166 additions & 32 deletions sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -21,7 +21,7 @@ import java.io.{File, PrintWriter}

import scala.reflect.ClassTag

import org.apache.spark.sql.{AnalysisException, QueryTest, Row, StatisticsTest}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils}
@@ -358,53 +358,187 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
}
}

test("generate column-level statistics and load them from hive metastore") {
private def getStatsBeforeAfterUpdate(isAnalyzeColumns: Boolean): (Statistics, Statistics) = {
val tableName = "tbl"
var statsBeforeUpdate: Statistics = null
var statsAfterUpdate: Statistics = null
withTable(tableName) {
val tableIndent = TableIdentifier(tableName, Some("default"))
val catalog = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog]
sql(s"CREATE TABLE $tableName (key int) USING PARQUET")
sql(s"INSERT INTO $tableName SELECT 1")
if (isAnalyzeColumns) {
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key")
} else {
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
}
// Table lookup will make the table cached.
catalog.lookupRelation(tableIndent)
statsBeforeUpdate = catalog.getCachedDataSourceTable(tableIndent)
.asInstanceOf[LogicalRelation].catalogTable.get.stats.get

sql(s"INSERT INTO $tableName SELECT 2")
if (isAnalyzeColumns) {
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key")
} else {
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
}
catalog.lookupRelation(tableIndent)
statsAfterUpdate = catalog.getCachedDataSourceTable(tableIndent)
.asInstanceOf[LogicalRelation].catalogTable.get.stats.get
}
(statsBeforeUpdate, statsAfterUpdate)
}

test("test refreshing table stats of cached data source table by `ANALYZE TABLE` statement") {
val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = false)

assert(statsBeforeUpdate.sizeInBytes > 0)
assert(statsBeforeUpdate.rowCount == Some(1))

assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes)
assert(statsAfterUpdate.rowCount == Some(2))
}

test("test refreshing column stats of cached data source table by `ANALYZE TABLE` statement") {
val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = true)

assert(statsBeforeUpdate.sizeInBytes > 0)
assert(statsBeforeUpdate.rowCount == Some(1))
StatisticsTest.checkColStat(
dataType = IntegerType,
colStat = statsBeforeUpdate.colStats("key"),
expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
rsd = spark.sessionState.conf.ndvMaxError)

assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes)
assert(statsAfterUpdate.rowCount == Some(2))
StatisticsTest.checkColStat(
dataType = IntegerType,
colStat = statsAfterUpdate.colStats("key"),
expectedColStat = ColumnStat(InternalRow(0L, 2, 1, 2L)),
rsd = spark.sessionState.conf.ndvMaxError)
}

private lazy val (testDataFrame, expectedColStatsSeq) = {
import testImplicits._

val intSeq = Seq(1, 2)
val stringSeq = Seq("a", "bb")
val binarySeq = Seq("a", "bb").map(_.getBytes)
val booleanSeq = Seq(true, false)

val data = intSeq.indices.map { i =>
(intSeq(i), stringSeq(i), booleanSeq(i))
(intSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i))
}
val tableName = "table"
withTable(tableName) {
val df = data.toDF("c1", "c2", "c3")
df.write.format("parquet").saveAsTable(tableName)
val expectedColStatsSeq = df.schema.map { f =>
val colStat = f.dataType match {
case IntegerType =>
ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong))
case StringType =>
ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble,
stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong))
case BooleanType =>
ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong,
booleanSeq.count(_.equals(false)).toLong))
}
(f, colStat)
val df: DataFrame = data.toDF("c1", "c2", "c3", "c4")
val expectedColStatsSeq: Seq[(StructField, ColumnStat)] = df.schema.map { f =>
val colStat = f.dataType match {
case IntegerType =>
ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong))
case StringType =>
ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble,
stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong))
case BinaryType =>
ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble,
binarySeq.map(_.length).max.toInt))
case BooleanType =>
ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong,
booleanSeq.count(_.equals(false)).toLong))
}
(f, colStat)
}
(df, expectedColStatsSeq)
}

private def checkColStats(
tableName: String,
isDataSourceTable: Boolean,
expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = {
val readback = spark.table(tableName)
val stats = readback.queryExecution.analyzed.collect {
case rel: MetastoreRelation =>
assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table")
rel.catalogTable.stats.get
case rel: LogicalRelation =>
assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table")
rel.catalogTable.get.stats.get
}
assert(stats.length == 1)
val columnStats = stats.head.colStats
assert(columnStats.size == expectedColStatsSeq.length)
expectedColStatsSeq.foreach { case (field, expectedColStat) =>
StatisticsTest.checkColStat(
dataType = field.dataType,
colStat = columnStats(field.name),
expectedColStat = expectedColStat,
rsd = spark.sessionState.conf.ndvMaxError)
}
}

sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1, c2, c3")
val readback = spark.table(tableName)
val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation =>
val columnStats = rel.catalogTable.get.stats.get.colStats
expectedColStatsSeq.foreach { case (field, expectedColStat) =>
assert(columnStats.contains(field.name))
val colStat = columnStats(field.name)
test("generate and load column-level stats for data source table") {
val dsTable = "dsTable"
withTable(dsTable) {
testDataFrame.write.format("parquet").saveAsTable(dsTable)
sql(s"ANALYZE TABLE $dsTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4")
checkColStats(dsTable, isDataSourceTable = true, expectedColStatsSeq)
}
}

test("generate and load column-level stats for hive serde table") {
val hTable = "hTable"
val tmp = "tmp"
withTable(hTable, tmp) {
testDataFrame.write.format("parquet").saveAsTable(tmp)
sql(s"CREATE TABLE $hTable (c1 int, c2 string, c3 binary, c4 boolean) STORED AS TEXTFILE")
sql(s"INSERT INTO $hTable SELECT * FROM $tmp")
sql(s"ANALYZE TABLE $hTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4")
checkColStats(hTable, isDataSourceTable = false, expectedColStatsSeq)
}
}

// When caseSensitive is on, for columns with only case difference, they are different columns
// and we should generate column stats for all of them.
private def checkCaseSensitiveColStats(columnName: String): Unit = {
val tableName = "tbl"
withTable(tableName) {
val column1 = columnName.toLowerCase
val column2 = columnName.toUpperCase
withSQLConf("spark.sql.caseSensitive" -> "true") {
sql(s"CREATE TABLE $tableName (`$column1` int, `$column2` double) USING PARQUET")
Inline review discussion on the CREATE TABLE line above:

Member: We hit a bug here... Not by your PRs, but this test case just exposes it. No need to worry about it. I will fix it.

Contributor Author: What bug? Please let me know when that bug fix PR is sent. :)

Member: We should not attempt to create a Hive-compatible table in this case. It always fails because of column names.

Contributor: Does it cause any problems? The logic to create a Hive-compatible table is quite conservative: we try to save into the Hive metastore first and, if that fails, fall back to the Spark-specific format.

Contributor Author (wzhfy, Oct 14, 2016): It outputs a warning including an exception, and the test completes successfully.

WARN org.apache.spark.sql.hive.HiveExternalCatalog: Could not persist `default`.`tbl` in a Hive compatible way. Persisting it into Hive metastore in Spark SQL specific format.
org.apache.hadoop.hive.ql.metadata.HiveException: org.apache.hadoop.hive.ql.metadata.HiveException: Duplicate column name c1 in the table definition.
    at org.apache.hadoop.hive.ql.metadata.Hive.createTable(Hive.java:720)
...

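For readers unfamiliar with the fallback the reviewers describe, a schematic sketch follows; every name in it is a hypothetical stand-in rather than the actual HiveExternalCatalog API:

    // Hypothetical stand-ins: the Hive-compatible path fails on the duplicate
    // column name, the Spark-specific path always succeeds.
    def persistHiveCompatible(tableName: String): Unit =
      throw new RuntimeException("Duplicate column name c1 in the table definition.")
    def persistSparkSpecific(tableName: String): Unit = ()

    def createTable(tableName: String): Unit = {
      try {
        persistHiveCompatible(tableName)   // try the Hive-compatible layout first
      } catch {
        case e: Exception =>
          // fall back, as in the WARN quoted above; the test still passes
          println(s"WARN: Could not persist `$tableName` in a Hive compatible way: ${e.getMessage}")
          persistSparkSpecific(tableName)
      }
    }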
sql(s"INSERT INTO $tableName SELECT 1, 3.0")
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS `$column1`, `$column2`")
val readback = spark.table(tableName)
val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation =>
val columnStats = rel.catalogTable.get.stats.get.colStats
assert(columnStats.size == 2)
StatisticsTest.checkColStat(
dataType = IntegerType,
colStat = columnStats(column1),
expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
rsd = spark.sessionState.conf.ndvMaxError)
StatisticsTest.checkColStat(
dataType = field.dataType,
colStat = colStat,
expectedColStat = expectedColStat,
dataType = DoubleType,
colStat = columnStats(column2),
expectedColStat = ColumnStat(InternalRow(0L, 3.0d, 3.0d, 1L)),
rsd = spark.sessionState.conf.ndvMaxError)
rel
}
rel
assert(relations.size == 1)
}
assert(relations.size == 1)
}
}

test("check column statistics for case sensitive column names") {
checkCaseSensitiveColStats(columnName = "c1")
}

test("check column statistics for case sensitive non-ascii column names") {
// scalastyle:off
// non ascii characters are not allowed in the source code, so we disable the scalastyle.
checkCaseSensitiveColStats(columnName = "列c")
// scalastyle:on
}

test("estimates the size of a test MetastoreRelation") {
val df = sql("""SELECT * FROM src""")
val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation =>