From 79ba7429abfcfab1263740cc89a92eddc2851fc7 Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Fri, 23 Dec 2016 14:44:57 -0800 Subject: [PATCH 01/36] implemented first version of filter estimation --- .../plans/logical/basicLogicalOperators.scala | 6 +- .../logical/estimation/EstimationUtils.scala | 55 +++ .../logical/estimation/FilterEstimation.scala | 341 ++++++++++++++++++ .../plans/logical/estimation/Range.scala | 75 ++++ .../estimation/FilterEstimationSuite.scala | 57 +++ 5 files changed, 533 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/Range.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index ce1c55dc089e..ded8d94c3c2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTypes} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, EstimationUtils, JoinEstimation, ProjectEstimation} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -129,6 +129,10 @@ case class Filter(condition: Expression, child: LogicalPlan) .filterNot(SubqueryExpression.hasCorrelatedSubquery) child.constraints.union(predicates.toSet) } + + override lazy val statistics: Statistics = + FilterEstimation.estimate(this).getOrElse(super.statistics) + } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala new file mode 100644 index 000000000000..25644554e4a7 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.estimation + +import scala.math.BigDecimal.RoundingMode + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast, Expression} +import org.apache.spark.sql.catalyst.plans.logical.ColumnStat +import org.apache.spark.sql.types.StringType + +object EstimationUtils extends Logging { + + def ceil(bigDecimal: BigDecimal): BigInt = bigDecimal.setScale(0, RoundingMode.CEILING).toBigInt() + + def getRowSize(attributes: Seq[Attribute], colStats: Map[String, ColumnStat]): Long = { + attributes.map { attr => + if (colStats.contains(attr.name)) { + attr.dataType match { + case StringType => + // base + offset + numBytes + colStats(attr.name).avgLen + 8 + 4 + case _ => + colStats(attr.name).avgLen + } + } else { + attr.dataType.defaultSize + } + }.sum + } +} + +/** Attribute Reference extractor */ +object ExtractAttr { + def unapply(exp: Expression): Option[AttributeReference] = exp match { + case ar: AttributeReference => Some(ar) + case Cast(ar: AttributeReference, dataType) => Some(ar) + case _ => None + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala new file mode 100644 index 000000000000..7b6bef60af1c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.estimation + +import scala.collection.immutable.{HashSet, Map} +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + + +object FilterEstimation extends Logging { + + /** + * We use a mutable colStats because we need to update the corresponding ColumnStat + * for a column after we apply a predicate condition. 
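+   * For example, after applying an equality predicate "a = 2", the ColumnStat of
+   * column a can be narrowed to (distinctCount = 1, min = 2, max = 2, nullCount = 0).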
+ */ + private var mutableColStats: mutable.Map[String, ColumnStat] = mutable.Map.empty + + def estimate(plan: Filter): Option[Statistics] = { + val stats: Statistics = plan.child.statistics + if (stats.rowCount.isEmpty) return None + + /** save a mutable copy of colStats so that we can later change it recursively */ + mutableColStats = mutable.HashMap(stats.colStats.toSeq: _*) + + /** estimate selectivity for this filter */ + val percent: Double = calculateConditions(plan, plan.condition) + + /** copy mutableColStats contents to an immutable map */ + val newColStats = mutableColStats.toMap + + val filteredRowCountValue: BigInt = + EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * percent) + val avgRowSize = BigDecimal(EstimationUtils.getRowSize(plan.output, newColStats)) + val filteredSizeInBytes: BigInt = + EstimationUtils.ceil(BigDecimal(filteredRowCountValue) * avgRowSize) + + Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCountValue), + colStats = newColStats)) + } + + def calculateConditions( + plan: Filter, + condition: Expression, + update: Boolean = true) + : Double = { + /** + * For conditions linked by And, we need to update stats after a condition estimation + * so that the stats will be more accurate for subsequent estimation. + * For conditions linked by OR, we do not update stats after a condition estimation. + */ + condition match { + case And(cond1, cond2) => + calculateConditions(plan, cond1, update) * calculateConditions(plan, cond2, update) + + case Or(cond1, cond2) => + val p1 = calculateConditions(plan, cond1, update = false) + val p2 = calculateConditions(plan, cond2, update = false) + math.min(1.0, p1 + p2 - (p1 * p2)) + + case Not(cond) => calculateSingleCondition(plan, cond, isNot = true, update = false) + case _ => calculateSingleCondition(plan, condition, isNot = false, update) + } + } + + def calculateSingleCondition( + plan: Filter, + condition: Expression, + isNot: Boolean, + update: Boolean) + : Double = { + var notSupported: Boolean = false + val planStat = plan.child.statistics + val percent: Double = condition match { + /** + * Currently we only support binary predicates where one side is a column, + * and the other is a literal. + * Note that: all binary predicate computing methods assume the literal is at the right side, + * so we will change the predicate order if not. 
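+       * For example, "2 < a" is evaluated as GreaterThan(a, 2), i.e. "a > 2".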
+ */ + case op@LessThan(ExtractAttr(ar), l: Literal) => + evaluateBinary(op, planStat, ar, l, update) + case op@LessThan(l: Literal, ExtractAttr(ar)) => + evaluateBinary(GreaterThan(ar, l), planStat, ar, l, update) + + case op@LessThanOrEqual(ExtractAttr(ar), l: Literal) => + evaluateBinary(op, planStat, ar, l, update) + case op@LessThanOrEqual(l: Literal, ExtractAttr(ar)) => + evaluateBinary(GreaterThanOrEqual(ar, l), planStat, ar, l, update) + + case op@GreaterThan(ExtractAttr(ar), l: Literal) => + evaluateBinary(op, planStat, ar, l, update) + case op@GreaterThan(l: Literal, ExtractAttr(ar)) => + evaluateBinary(LessThan(ar, l), planStat, ar, l, update) + + case op@GreaterThanOrEqual(ExtractAttr(ar), l: Literal) => + evaluateBinary(op, planStat, ar, l, update) + case op@GreaterThanOrEqual(l: Literal, ExtractAttr(ar)) => + evaluateBinary(LessThanOrEqual(ar, l), planStat, ar, l, update) + + /** EqualTo does not care about the order */ + case op@EqualTo(ExtractAttr(ar), l: Literal) => + evaluateBinary(op, planStat, ar, l, update) + case op@EqualTo(l: Literal, ExtractAttr(ar)) => + evaluateBinary(op, planStat, ar, l, update) + + case In(ExtractAttr(ar), expList) if !expList.exists(!_.isInstanceOf[Literal]) => + /** + * Expression [In (value, seq[Literal])] will be replaced with optimized version + * [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10. + * Here we convert In into InSet anyway, because they share the same processing logic. + */ + val hSet = expList.map(e => e.eval()) + evaluateInSet(planStat, ar, HashSet() ++ hSet, update) + + case InSet(ExtractAttr(ar), set) => + evaluateInSet(planStat, ar, set, update) + + /** + * It's difficult to estimate IsNull after outer joins. Hence, + * we support IsNull and IsNotNull only when the child is a leaf node (table). + */ + case IsNull(ExtractAttr(ar)) => + if (plan.child.isInstanceOf[LeafNode ]) { + evaluateIsNull(planStat, ar, true, update) + } + else 1.0 + + case IsNotNull(ExtractAttr(ar)) => + if (plan.child.isInstanceOf[LeafNode ]) { + evaluateIsNull(planStat, ar, false, update) + } + else 1.0 + + case _ => + /** + * TODO: it's difficult to support string operators without advanced statistics. 
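+         * Supporting them would need, for example, histograms over string values.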
+ * Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _) + * | EndsWith(_, _) are not supported yet + */ + logDebug("[CBO] Unsupported filter condition: " + condition) + notSupported = true + 1.0 + } + if (notSupported) { + 1.0 + } else if (isNot) { + 1.0 - percent + } else { + percent + } + } + + def evaluateIsNull( + planStat: Statistics, + attrRef: AttributeReference, + isNull: Boolean, + update: Boolean) + : Double = { + if (!planStat.colStats.contains(attrRef.name)) { + logDebug("[CBO] No statistics for " + attrRef) + return 1.0 + } + val aColStat = planStat.colStats(attrRef.name) + val rowCountValue = planStat.rowCount.get + val nullPercent: BigDecimal = + if (rowCountValue == 0) 0.0 + else BigDecimal(aColStat.nullCount)/BigDecimal(rowCountValue) + + if (update) { + val newStats = + if (isNull) aColStat.copy(distinctCount = 0, min = None, max = None) + else aColStat.copy(nullCount = 0) + + mutableColStats += (attrRef.name -> newStats) + } + + val percent = + if (isNull) nullPercent.toDouble + else { + /** ISNOTNULL(column) */ + 1.0 - nullPercent.toDouble + } + + percent + } + + /** This method evaluates binary comparison operators such as =, <, <=, >, >= */ + def evaluateBinary( + op: BinaryComparison, + planStat: Statistics, + attrRef: AttributeReference, + literal: Literal, + update: Boolean) + : Double = { + if (!planStat.colStats.contains(attrRef.name)) { + logDebug("[CBO] No statistics for " + attrRef) + return 1.0 + } + op match { + case EqualTo(l, r) => evaluateEqualTo(op, planStat, attrRef, literal, update) + case _ => + attrRef.dataType match { + case _: NumericType | DateType | TimestampType => + evaluateBinaryForNumeric(op, planStat, attrRef, literal, update) + case StringType | BinaryType => + /** + * TODO: It is difficult to support other binary comparisons for String/Binary + * type without min/max and advanced statistics like histogram. + */ + logDebug("[CBO] No statistics for String/Binary type " + attrRef) + return 1.0 + } + } + } + + /** This method evaluates the equality predicate for all data types. */ + def evaluateEqualTo( + op: BinaryComparison, + planStat: Statistics, + attrRef: AttributeReference, + literal: Literal, + update: Boolean) + : Double = { + + val aColStat = planStat.colStats(attrRef.name) + val ndv = aColStat.distinctCount + + /** + * decide if the value is in [min, max] of the column. + * We currently don't store min/max for binary/string type. + * Hence, we assume it is in boundary for binary/string type. 
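+     * For example, if column a has range [1, 10], the literal in "a = 15" is out of
+     * boundary and the predicate's selectivity is 0.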
+ */ + val inBoundary: Boolean = attrRef.dataType match { + case _: NumericType | DateType | TimestampType => + val statsRange = + Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange] + + attrRef.dataType match { + case _: IntegralType => + (BigDecimal(literal.value.asInstanceOf[Long]) >= statsRange.min) && + (BigDecimal(literal.value.asInstanceOf[Long]) <= statsRange.max) + + case _: FractionalType => + (BigDecimal(literal.value.asInstanceOf[Double]) >= statsRange.min) && + (BigDecimal(literal.value.asInstanceOf[Double]) <= statsRange.max) + + case DateType => + val dateLiteral = DateTimeUtils.stringToDate(literal.value.asInstanceOf[UTF8String]) + if (dateLiteral.isEmpty) { + logDebug("[CBO] Date literal is wrong, No statistics for " + attrRef) + return 1.0 + } + val dateBigDecimal = BigDecimal(dateLiteral.asInstanceOf[BigInt]) + (dateBigDecimal >= statsRange.min) && (dateBigDecimal <= statsRange.max) + + case TimestampType => + val tsLiteral = DateTimeUtils.stringToTimestamp(literal.value.asInstanceOf[UTF8String]) + if (tsLiteral.isEmpty) { + logDebug("[CBO] Timestamp literal is wrong, No statistics for " + attrRef) + return 1.0 + } + val tsBigDecimal = BigDecimal(tsLiteral.asInstanceOf[BigInt]) + (tsBigDecimal >= statsRange.min) && (tsBigDecimal <= statsRange.max) + } + + case _ => true /** for String/Binary type */ + } + + val percent: Double = + if (inBoundary) { + + if (update) { + /** + * We update ColumnStat structure after apply this equality predicate. + * Set distinctCount to 1. Set nullCount to 0. + */ + val newStats = attrRef.dataType match { + case _: NumericType | DateType | TimestampType => + val newValue = Some(literal.value) + aColStat.copy(distinctCount = 1, min = newValue, + max = newValue, nullCount = 0) + case _ => aColStat.copy(distinctCount = 1, nullCount = 0) + } + mutableColStats += (attrRef.name -> newStats) + } + + 1.0 / ndv.toDouble + } else { + 0.0 + } + + percent + } + + def evaluateInSet( + planStat: Statistics, + attrRef: AttributeReference, + hSet: Set[Any], + update: Boolean) + : Double = { + if (!planStat.colStats.contains(attrRef.name)) { + logDebug("[CBO] No statistics for " + attrRef) + return 1.0 + } + // TODO: will fill in this method later. + 1.0 + } + + def evaluateBinaryForNumeric( + op: BinaryComparison, + planStat: Statistics, + attrRef: AttributeReference, + literal: Literal, + update: Boolean) + : Double = { + // TODO: will fill in this method later. + 1.0 + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/Range.scala new file mode 100644 index 000000000000..24a4f9b1ca66 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/Range.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.estimation + +import java.math.{BigDecimal => JDecimal} +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _} + + +/** Value range of a column. */ +trait Range + +/** For simplicity we use decimal to unify operations of numeric ranges. */ +case class NumericRange(min: JDecimal, max: JDecimal) extends Range + +/** + * This version of Spark does not have min/max for binary/string types, we define their default + * behaviors by this class. + */ +class DefaultRange extends Range + +/** This is for columns with only null values. */ +class NullRange extends Range + +object Range { + def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match { + case StringType | BinaryType => new DefaultRange() + case _ if min.isEmpty || max.isEmpty => new NullRange() + case _ => toNumericRange(min.get, max.get, dataType) + } + + /** + * For simplicity we use decimal to unify operations of numeric types, the two methods below + * are the contract of conversion. + */ + private def toNumericRange(min: Any, max: Any, dataType: DataType): NumericRange = { + dataType match { + case _: NumericType => + NumericRange(new JDecimal(min.toString), new JDecimal(max.toString)) + case BooleanType => + val min1 = if (min.asInstanceOf[Boolean]) 1 else 0 + val max1 = if (max.asInstanceOf[Boolean]) 1 else 0 + NumericRange(new JDecimal(min1), new JDecimal(max1)) + case DateType => + val min1 = DateTimeUtils.fromJavaDate(min.asInstanceOf[Date]) + val max1 = DateTimeUtils.fromJavaDate(max.asInstanceOf[Date]) + NumericRange(new JDecimal(min1), new JDecimal(max1)) + case TimestampType => + val min1 = DateTimeUtils.fromJavaTimestamp(min.asInstanceOf[Timestamp]) + val max1 = DateTimeUtils.fromJavaTimestamp(max.asInstanceOf[Timestamp]) + NumericRange(new JDecimal(min1), new JDecimal(max1)) + case _ => + throw new AnalysisException(s"Type $dataType is not castable to numeric in estimation.") + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala new file mode 100644 index 000000000000..9e1897b10b61 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.estimation + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.test.SharedSQLContext + + +class FilterEstimationSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + private val data1 = Seq[Long](1, 2, 3, 4) + private val colStatAfterFilter = ColumnStat(1, Some(2L), Some(2L), 0, 8, 8) + private val expectedFilterStats = Statistics( + sizeInBytes = 1 * 8, + rowCount = Some(1), + colStats = Map("key1" -> colStatAfterFilter), + isBroadcastable = false) + + test("filter estimation with equality comparison using basic formula") { + val table1 = "filter_estimation_test1" + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val logicalPlan = + sql(s"select * from $table1 where key1=2").queryExecution.optimizedPlan + val filterNodes = logicalPlan.collect { + case filter: Filter => + val filterStats = filter.statistics + assert(filterStats == expectedFilterStats) + filter + } + assert(filterNodes.size == 1) + } + } +} From c46ccbf58583d31a3800e57c9aebc6ec66ca9ce5 Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Wed, 28 Dec 2016 19:48:08 -0800 Subject: [PATCH 02/36] added evaluateBinaryForNumeric --- .../logical/estimation/FilterEstimation.scala | 182 +++++++++++++++--- .../estimation/FilterEstimationSuite.scala | 2 +- 2 files changed, 151 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala index 7b6bef60af1c..2f38b25ad231 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala @@ -217,6 +217,24 @@ object FilterEstimation extends Logging { logDebug("[CBO] No statistics for " + attrRef) return 1.0 } + + /** Make sure that the Date/Timestamp literal is a valid one */ + attrRef.dataType match { + case DateType => + val dateLiteral = DateTimeUtils.stringToDate(literal.value.asInstanceOf[UTF8String]) + if (dateLiteral.isEmpty) { + logDebug("[CBO] Date literal is wrong, No statistics for " + attrRef) + return 1.0 + } + case TimestampType => + val tsLiteral = DateTimeUtils.stringToTimestamp(literal.value.asInstanceOf[UTF8String]) + if (tsLiteral.isEmpty) { + logDebug("[CBO] Timestamp literal is wrong, No statistics for " + attrRef) + return 1.0 + } + case _ => + } + op match { case EqualTo(l, r) => evaluateEqualTo(op, planStat, attrRef, literal, update) case _ => @@ -234,6 +252,24 @@ object FilterEstimation extends Logging { } } + def numericLiteralToBigDecimal( + literal: Literal, + dataType: DataType) + : BigDecimal = { + dataType match { + case _: IntegralType => + BigDecimal(literal.value.asInstanceOf[Long]) + case _: 
FractionalType => + BigDecimal(literal.value.asInstanceOf[Double]) + case DateType => + val dateLiteral = DateTimeUtils.stringToDate(literal.value.asInstanceOf[UTF8String]) + BigDecimal(dateLiteral.asInstanceOf[BigInt]) + case TimestampType => + val tsLiteral = DateTimeUtils.stringToTimestamp(literal.value.asInstanceOf[UTF8String]) + BigDecimal(tsLiteral.asInstanceOf[BigInt]) + } + } + /** This method evaluates the equality predicate for all data types. */ def evaluateEqualTo( op: BinaryComparison, @@ -255,34 +291,8 @@ object FilterEstimation extends Logging { case _: NumericType | DateType | TimestampType => val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange] - - attrRef.dataType match { - case _: IntegralType => - (BigDecimal(literal.value.asInstanceOf[Long]) >= statsRange.min) && - (BigDecimal(literal.value.asInstanceOf[Long]) <= statsRange.max) - - case _: FractionalType => - (BigDecimal(literal.value.asInstanceOf[Double]) >= statsRange.min) && - (BigDecimal(literal.value.asInstanceOf[Double]) <= statsRange.max) - - case DateType => - val dateLiteral = DateTimeUtils.stringToDate(literal.value.asInstanceOf[UTF8String]) - if (dateLiteral.isEmpty) { - logDebug("[CBO] Date literal is wrong, No statistics for " + attrRef) - return 1.0 - } - val dateBigDecimal = BigDecimal(dateLiteral.asInstanceOf[BigInt]) - (dateBigDecimal >= statsRange.min) && (dateBigDecimal <= statsRange.max) - - case TimestampType => - val tsLiteral = DateTimeUtils.stringToTimestamp(literal.value.asInstanceOf[UTF8String]) - if (tsLiteral.isEmpty) { - logDebug("[CBO] Timestamp literal is wrong, No statistics for " + attrRef) - return 1.0 - } - val tsBigDecimal = BigDecimal(tsLiteral.asInstanceOf[BigInt]) - (tsBigDecimal >= statsRange.min) && (tsBigDecimal <= statsRange.max) - } + val lit = numericLiteralToBigDecimal(literal, attrRef.dataType) + (lit >= statsRange.min) && (lit <= statsRange.max) case _ => true /** for String/Binary type */ } @@ -323,8 +333,51 @@ object FilterEstimation extends Logging { logDebug("[CBO] No statistics for " + attrRef) return 1.0 } - // TODO: will fill in this method later. - 1.0 + + val aColStat = planStat.colStats(attrRef.name) + val ndv = aColStat.distinctCount + val aType = attrRef.dataType + + // use [min, max] to filter the original hSet + val validQuerySet = aType match { + case _: NumericType | DateType | TimestampType => + val statsRange = + Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange] + hSet.map(e => numericLiteralToBigDecimal(e.asInstanceOf[Literal], aType)). + filter(e => e >= statsRange.min && e <= statsRange.max) + + /** We assume the whole set since there is no min/max information for String/Binary type */ + case StringType | BinaryType => hSet + } + if (validQuerySet.isEmpty) { + return 0.0 + } + + val newNdv = validQuerySet.size + val(newMax, newMin) = aType match { + case _: NumericType | DateType | TimestampType => + val tmpSet: Set[Double] = validQuerySet.map(e => e.toString.toDouble) + (Some(tmpSet.max), Some(tmpSet.min)) + case _ => + (None, None) + } + + if (update) { + val newStats = attrRef.dataType match { + case _: NumericType | DateType | TimestampType => + aColStat.copy(distinctCount = newNdv, min = newMin, + max = newMax, nullCount = 0) + case StringType | BinaryType => + aColStat.copy(distinctCount = newNdv, nullCount = 0) + } + mutableColStats += (attrRef.name -> newStats) + } + + /** + * return the filter selectivity. 
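+     * It equals validQuerySet.size / ndv, capped at 1.0.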
Without advanced statistics such as histograms, + * we have to assume uniform distribution. + */ + math.min(1.0, validQuerySet.size / ndv.toDouble) } def evaluateBinaryForNumeric( @@ -334,8 +387,73 @@ object FilterEstimation extends Logging { literal: Literal, update: Boolean) : Double = { - // TODO: will fill in this method later. - 1.0 + + var percent = 1.0 + val aColStat = planStat.colStats(attrRef.name) + val ndv = aColStat.distinctCount + val statsRange = + Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange] + + val literalValueBD = numericLiteralToBigDecimal(literal, attrRef.dataType) + + /** determine the overlapping degree between predicate range and column's range */ + val (noOverlap: Boolean, completeOverlap: Boolean) = op match { + case LessThan(l, r) => + (literalValueBD <= statsRange.min, literalValueBD > statsRange.max) + case LessThanOrEqual(l, r) => + (literalValueBD < statsRange.min, literalValueBD >= statsRange.max) + case GreaterThan(l, r) => + (literalValueBD >= statsRange.max, literalValueBD < statsRange.min) + case GreaterThanOrEqual(l, r) => + (literalValueBD > statsRange.max, literalValueBD <= statsRange.min) + } + + if (noOverlap) { + percent = 0.0 + } else if (completeOverlap) { + percent = 1.0 + } else { + /** this is partial overlap case */ + var newMax = aColStat.max + var newMin = aColStat.min + var newNdv = ndv + val literalToDouble = literalValueBD.toDouble + val maxToDouble = BigDecimal(statsRange.max).toDouble + val minToDouble = BigDecimal(statsRange.min).toDouble + + /** + * Without advanced statistics like histogram, we assume uniform data distribution. + * We just prorate the adjusted range over the initial range to compute filter selectivity. + */ + percent = op match { + case LessThan(l, r) => + (literalToDouble - minToDouble) / (maxToDouble - minToDouble) + case LessThanOrEqual(l, r) => + if (literalValueBD == BigDecimal(statsRange.min)) 1.0 / ndv.toDouble + else (literalToDouble - minToDouble) / (maxToDouble - minToDouble) + case GreaterThan(l, r) => + (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) + case GreaterThanOrEqual(l, r) => + if (literalValueBD == BigDecimal(statsRange.max)) 1.0 / ndv.toDouble + else (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) + } + + if (update) { + op match { + case GreaterThan(l, r) => newMin = Some(literal.value) + case GreaterThanOrEqual(l, r) => newMin = Some(literal.value) + case LessThan(l, r) => newMax = Some(literal.value) + case LessThanOrEqual(l, r) => newMax = Some(literal.value) + } + newNdv = math.max(math.round(ndv.toDouble * percent), 1) + val newStats = aColStat.copy(distinctCount = newNdv, min = newMin, + max = newMax, nullCount = 0) + + mutableColStats += (attrRef.name -> newStats) + } + } + + percent } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala index 9e1897b10b61..b165c85151c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.test.SharedSQLContext class FilterEstimationSuite extends QueryTest with SharedSQLContext { import testImplicits._ - private val data1 = Seq[Long](1, 2, 3, 4) + private val data1 = Seq[Long](1, 2, 3, 4, 5) private val colStatAfterFilter = ColumnStat(1, Some(2L), Some(2L), 0, 8, 8) private val 
expectedFilterStats = Statistics( sizeInBytes = 1 * 8, From cf42e5ecb553156abf9fe95abc1673f83d031714 Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Fri, 30 Dec 2016 18:19:46 -0800 Subject: [PATCH 03/36] support all binary expressions and add test cases --- .../logical/estimation/FilterEstimation.scala | 54 ++-- .../estimation/FilterEstimationSuite.scala | 284 +++++++++++++++++- 2 files changed, 311 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala index 2f38b25ad231..f4afc30183c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala @@ -44,7 +44,7 @@ object FilterEstimation extends Logging { mutableColStats = mutable.HashMap(stats.colStats.toSeq: _*) /** estimate selectivity for this filter */ - val percent: Double = calculateConditions(plan, plan.condition) + val percent: Double = calculateConditions(plan, stats, plan.condition) /** copy mutableColStats contents to an immutable map */ val newColStats = mutableColStats.toMap @@ -61,6 +61,7 @@ object FilterEstimation extends Logging { def calculateConditions( plan: Filter, + planStat: Statistics, condition: Expression, update: Boolean = true) : Double = { @@ -71,26 +72,30 @@ object FilterEstimation extends Logging { */ condition match { case And(cond1, cond2) => - calculateConditions(plan, cond1, update) * calculateConditions(plan, cond2, update) + val newStats1 = planStat.copy(colStats = mutableColStats.toMap) + val p1 = calculateConditions(plan, newStats1, cond1, update) + val newStats2 = planStat.copy(colStats = mutableColStats.toMap) + val p2 = calculateConditions(plan, newStats2, cond2, update) + p1 * p2 case Or(cond1, cond2) => - val p1 = calculateConditions(plan, cond1, update = false) - val p2 = calculateConditions(plan, cond2, update = false) + val p1 = calculateConditions(plan, planStat, cond1, update = false) + val p2 = calculateConditions(plan, planStat, cond2, update = false) math.min(1.0, p1 + p2 - (p1 * p2)) - case Not(cond) => calculateSingleCondition(plan, cond, isNot = true, update = false) - case _ => calculateSingleCondition(plan, condition, isNot = false, update) + case Not(cond) => calculateSingleCondition(plan, planStat, cond, isNot = true, update = false) + case _ => calculateSingleCondition(plan, planStat, condition, isNot = false, update) } } def calculateSingleCondition( plan: Filter, + planStat: Statistics, condition: Expression, isNot: Boolean, update: Boolean) : Double = { var notSupported: Boolean = false - val planStat = plan.child.statistics val percent: Double = condition match { /** * Currently we only support binary predicates where one side is a column, @@ -252,21 +257,36 @@ object FilterEstimation extends Logging { } } + /** + * This method converts a numeric or Literal value of numeric type to a BigDecimal value. + * If isNumeric is true, then it is a numeric value. Otherwise, it is a Literal value. 
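+   * For example, for an integral column, both a raw Long value 2L and Literal(2L)
+   * are converted to BigDecimal(2).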
+ */ def numericLiteralToBigDecimal( - literal: Literal, - dataType: DataType) + literal: Any, + dataType: DataType, + isNumeric: Boolean = false) : BigDecimal = { dataType match { case _: IntegralType => - BigDecimal(literal.value.asInstanceOf[Long]) + if (isNumeric) BigDecimal(literal.asInstanceOf[Long]) + else BigDecimal(literal.asInstanceOf[Literal].value.asInstanceOf[Long]) case _: FractionalType => - BigDecimal(literal.value.asInstanceOf[Double]) + if (isNumeric) BigDecimal(literal.asInstanceOf[Double]) + else BigDecimal(literal.asInstanceOf[Literal].value.asInstanceOf[Double]) case DateType => - val dateLiteral = DateTimeUtils.stringToDate(literal.value.asInstanceOf[UTF8String]) - BigDecimal(dateLiteral.asInstanceOf[BigInt]) + if (isNumeric) BigDecimal(literal.asInstanceOf[BigInt]) + else { + val dateLiteral = DateTimeUtils.stringToDate( + literal.asInstanceOf[Literal].value.asInstanceOf[UTF8String]) + BigDecimal(dateLiteral.asInstanceOf[BigInt]) + } case TimestampType => - val tsLiteral = DateTimeUtils.stringToTimestamp(literal.value.asInstanceOf[UTF8String]) - BigDecimal(tsLiteral.asInstanceOf[BigInt]) + if (isNumeric) BigDecimal(literal.asInstanceOf[BigInt]) + else { + val tsLiteral = DateTimeUtils.stringToTimestamp( + literal.asInstanceOf[Literal].value.asInstanceOf[UTF8String]) + BigDecimal(tsLiteral.asInstanceOf[BigInt]) + } } } @@ -343,7 +363,7 @@ object FilterEstimation extends Logging { case _: NumericType | DateType | TimestampType => val statsRange = Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange] - hSet.map(e => numericLiteralToBigDecimal(e.asInstanceOf[Literal], aType)). + hSet.map(e => numericLiteralToBigDecimal(e, aType, true)). filter(e => e >= statsRange.min && e <= statsRange.max) /** We assume the whole set since there is no min/max information for String/Binary type */ @@ -456,4 +476,4 @@ object FilterEstimation extends Logging { percent } -} +} \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala index b165c85151c9..7e0ea20b2e13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala @@ -21,20 +21,96 @@ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.test.SharedSQLContext +/** + * In this test suite, we test the proedicates containing the following operators: + * =, <, <=, >, >=, AND, OR, IS NULL, IS NOT NULL, IN, NOT IN + */ class FilterEstimationSuite extends QueryTest with SharedSQLContext { import testImplicits._ - private val data1 = Seq[Long](1, 2, 3, 4, 5) - private val colStatAfterFilter = ColumnStat(1, Some(2L), Some(2L), 0, 8, 8) - private val expectedFilterStats = Statistics( - sizeInBytes = 1 * 8, - rowCount = Some(1), - colStats = Map("key1" -> colStatAfterFilter), - isBroadcastable = false) + private val data1 = Seq[Long](1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + private val table1 = "filter_estimation_test1" + + test("filter estimation with equality comparison") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val logicalPlan = + sql(s"SELECT * FROM $table1 WHERE key1 = 2").queryExecution.optimizedPlan + val 
expectedFilterStats = Statistics( + sizeInBytes = 1 * 8, rowCount = Some(1), + colStats = Map("key1" -> ColumnStat(1, Some(2L), Some(2L), 0, 8, 8)), + isBroadcastable = false) + + val filterNodes = logicalPlan.collect { + case filter: Filter => + val filterStats = filter.statistics + assert(filterStats == expectedFilterStats) + filter + } + assert(filterNodes.size == 1) + } + } + + test("filter estimation with less than comparison") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val logicalPlan = + sql(s"SELECT * FROM $table1 WHERE key1 < 3").queryExecution.optimizedPlan + val expectedFilterStats = Statistics( + sizeInBytes = 3 * 8, rowCount = Some(3), + colStats = Map("key1" -> ColumnStat(2, Some(1L), Some(3L), 0, 8, 8)), + isBroadcastable = false) + + val filterNodes = logicalPlan.collect { + case filter: Filter => + val filterStats = filter.statistics + assert(filterStats == expectedFilterStats) + filter + } + assert(filterNodes.size == 1) + } + } + + test("filter estimation with less than or equal to comparison") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val logicalPlan = + sql(s"SELECT * FROM $table1 WHERE key1 <= 3").queryExecution.optimizedPlan + val expectedFilterStats = Statistics( + sizeInBytes = 3 * 8, rowCount = Some(3), + colStats = Map("key1" -> ColumnStat(2, Some(1L), Some(3L), 0, 8, 8)), + isBroadcastable = false) + + val filterNodes = logicalPlan.collect { + case filter: Filter => + val filterStats = filter.statistics + assert(filterStats == expectedFilterStats) + filter + } + assert(filterNodes.size == 1) + } + } - test("filter estimation with equality comparison using basic formula") { - val table1 = "filter_estimation_test1" + test("filter estimation with greater than comparison") { val df1 = data1.toDF("key1") withTable(table1) { df1.write.saveAsTable(table1) @@ -44,7 +120,64 @@ class FilterEstimationSuite extends QueryTest with SharedSQLContext { /** Validate statistics */ val logicalPlan = - sql(s"select * from $table1 where key1=2").queryExecution.optimizedPlan + sql(s"SELECT * FROM $table1 WHERE key1 > 6").queryExecution.optimizedPlan + val expectedFilterStats = Statistics( + sizeInBytes = 5 * 8, rowCount = Some(5), + colStats = Map("key1" -> ColumnStat(4, Some(6L), Some(10L), 0, 8, 8)), + isBroadcastable = false) + + val filterNodes = logicalPlan.collect { + case filter: Filter => + val filterStats = filter.statistics + assert(filterStats == expectedFilterStats) + filter + } + assert(filterNodes.size == 1) + } + } + + test("filter estimation with greater than or equal to comparison") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val logicalPlan = + sql(s"SELECT * FROM $table1 WHERE key1 >= 6").queryExecution.optimizedPlan + val expectedFilterStats = Statistics( + sizeInBytes = 5 * 8, rowCount = Some(5), + colStats = Map("key1" -> ColumnStat(4, Some(6L), Some(10L), 0, 8, 8)), + isBroadcastable = false) + + val filterNodes = logicalPlan.collect { + case filter: Filter => + val filterStats = filter.statistics + assert(filterStats == expectedFilterStats) + 
filter + } + assert(filterNodes.size == 1) + } + } + + test("filter estimation with IS NULL comparison") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val logicalPlan = + sql(s"SELECT * FROM $table1 WHERE key1 IS NULL").queryExecution.optimizedPlan + val expectedFilterStats = Statistics( + sizeInBytes = 0, rowCount = Some(0), + colStats = Map("key1" -> ColumnStat(0, None, None, 0, 8, 8)), + isBroadcastable = false) + val filterNodes = logicalPlan.collect { case filter: Filter => val filterStats = filter.statistics @@ -54,4 +187,135 @@ class FilterEstimationSuite extends QueryTest with SharedSQLContext { assert(filterNodes.size == 1) } } + + test("filter estimation with IS NOT NULL comparison") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val logicalPlan = + sql(s"SELECT * FROM $table1 WHERE key1 IS NOT NULL").queryExecution.optimizedPlan + val expectedFilterStats = Statistics( + sizeInBytes = 10 * 8, rowCount = Some(10), + colStats = Map("key1" -> ColumnStat(10, Some(1L), Some(10L), 0, 8, 8)), + isBroadcastable = false) + + val filterNodes = logicalPlan.collect { + case filter: Filter => + val filterStats = filter.statistics + assert(filterStats == expectedFilterStats) + filter + } + assert(filterNodes.size == 1) + } + } + + test("filter estimation with logical AND operator") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val logicalPlan = + sql(s"SELECT * FROM $table1 WHERE key1 > 3 AND key1 <= 6").queryExecution.optimizedPlan + val expectedFilterStats = Statistics( + sizeInBytes = 4 * 8, rowCount = Some(4), + colStats = Map("key1" -> ColumnStat(3, Some(3L), Some(6L), 0, 8, 8)), + isBroadcastable = false) + + val filterNodes = logicalPlan.collect { + case filter: Filter => + val filterStats = filter.statistics + assert(filterStats == expectedFilterStats) + filter + } + assert(filterNodes.size == 1) + } + } + + test("filter estimation with logical OR operator") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val logicalPlan = + sql(s"SELECT * FROM $table1 WHERE key1 = 3 OR key1 = 6").queryExecution.optimizedPlan + val expectedFilterStats = Statistics( + sizeInBytes = 2 * 8, rowCount = Some(2), + colStats = Map("key1" -> ColumnStat(10, Some(1L), Some(10L), 0, 8, 8)), + isBroadcastable = false) + + val filterNodes = logicalPlan.collect { + case filter: Filter => + val filterStats = filter.statistics + assert(filterStats == expectedFilterStats) + filter + } + assert(filterNodes.size == 1) + } + } + + test("filter estimation with IN operator") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val logicalPlan = + sql(s"SELECT * FROM $table1 WHERE key1 IN (3, 4, 5)").queryExecution.optimizedPlan + val expectedFilterStats = Statistics( + sizeInBytes = 3 * 8, 
rowCount = Some(3), + colStats = Map("key1" -> ColumnStat(3, Some(3L), Some(5L), 0, 8, 8)), + isBroadcastable = false) + + val filterNodes = logicalPlan.collect { + case filter: Filter => + val filterStats = filter.statistics + assert(filterStats == expectedFilterStats) + filter + } + assert(filterNodes.size == 1) + } + } + + test("filter estimation with logical NOT operator") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val logicalPlan = + sql(s"SELECT * FROM $table1 WHERE key1 NOT IN (3, 4, 5)").queryExecution.optimizedPlan + val expectedFilterStats = Statistics( + sizeInBytes = 7 * 8, rowCount = Some(7), + colStats = Map("key1" -> ColumnStat(10, Some(1L), Some(10L), 0, 8, 8)), + isBroadcastable = false) + + val filterNodes = logicalPlan.collect { + case filter: Filter => + val filterStats = filter.statistics + assert(filterStats == expectedFilterStats) + filter + } + assert(filterNodes.size == 1) + } + } + } From 22f9637d4cc2d73f49628e2ce5952c232db7c7be Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Mon, 2 Jan 2017 16:06:24 -0800 Subject: [PATCH 04/36] Use AttributeMap for column statistics --- .../logical/estimation/EstimationUtils.scala | 20 +- .../logical/estimation/FilterEstimation.scala | 117 ++++++---- .../estimation/FilterEstimationSuite.scala | 217 +++++------------- 3 files changed, 140 insertions(+), 214 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala index 25644554e4a7..06187336caf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala @@ -20,23 +20,26 @@ package org.apache.spark.sql.catalyst.plans.logical.estimation import scala.math.BigDecimal.RoundingMode import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast, Expression} -import org.apache.spark.sql.catalyst.plans.logical.ColumnStat +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} import org.apache.spark.sql.types.StringType + object EstimationUtils extends Logging { def ceil(bigDecimal: BigDecimal): BigInt = bigDecimal.setScale(0, RoundingMode.CEILING).toBigInt() - def getRowSize(attributes: Seq[Attribute], colStats: Map[String, ColumnStat]): Long = { - attributes.map { attr => - if (colStats.contains(attr.name)) { + def getRowSize(attributes: Seq[Attribute], attrStats: AttributeMap[ColumnStat]): Long = { + // We assign a generic overhead for a Row object, the actual overhead is different for different + // Row format. 
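+    // For example, a row with one Long column (avgLen 8) and one String column
+    // (avgLen 10) is estimated at 8 + 8 + (10 + 8 + 4) = 38 bytes.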
+ 8 + attributes.map { attr => + if (attrStats.contains(attr)) { attr.dataType match { case StringType => - // base + offset + numBytes - colStats(attr.name).avgLen + 8 + 4 + // UTF8String: base + offset + numBytes + attrStats(attr).avgLen + 8 + 4 case _ => - colStats(attr.name).avgLen + attrStats(attr).avgLen } } else { attr.dataType.defaultSize @@ -49,7 +52,6 @@ object EstimationUtils extends Logging { object ExtractAttr { def unapply(exp: Expression): Option[AttributeReference] = exp match { case ar: AttributeReference => Some(ar) - case Cast(ar: AttributeReference, dataType) => Some(ar) case _ => None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala index f4afc30183c7..cd8f554e572f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala @@ -34,20 +34,30 @@ object FilterEstimation extends Logging { * We use a mutable colStats because we need to update the corresponding ColumnStat * for a column after we apply a predicate condition. */ - private var mutableColStats: mutable.Map[String, ColumnStat] = mutable.Map.empty + // private var mutableColStats: mutable.Map[AttributeReference, ColumnStat] = mutable.Map.empty + private var mutableColStats: mutable.Map[ExprId, ColumnStat] = mutable.Map.empty def estimate(plan: Filter): Option[Statistics] = { val stats: Statistics = plan.child.statistics if (stats.rowCount.isEmpty) return None /** save a mutable copy of colStats so that we can later change it recursively */ - mutableColStats = mutable.HashMap(stats.colStats.toSeq: _*) + val statsSeq: Seq[(ExprId, ColumnStat)] = + stats.attributeStats.map(kv => (kv._1.exprId, kv._2)).toSeq + mutableColStats = mutable.HashMap[ExprId, ColumnStat](statsSeq: _*) + + // mutableColStats = mutable.HashMap(stats.attributeStats.toSeq: _*) /** estimate selectivity for this filter */ - val percent: Double = calculateConditions(plan, stats, plan.condition) + val percent: Double = calculateConditions(plan, plan.condition) - /** copy mutableColStats contents to an immutable map */ - val newColStats = mutableColStats.toMap + /** copy mutableColStats contents to an immutable AttributeMap */ + var mutableAttributeStats: mutable.Map[Attribute, ColumnStat] = mutable.Map.empty + for ( (k, v) <- mutableColStats) { + val attr = mapExprIdToAttribute(stats.attributeStats, k).asInstanceOf[Attribute] + mutableAttributeStats += (attr -> v) + } + val newColStats = AttributeMap(mutableAttributeStats.toSeq) val filteredRowCountValue: BigInt = EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * percent) @@ -56,12 +66,18 @@ object FilterEstimation extends Logging { EstimationUtils.ceil(BigDecimal(filteredRowCountValue) * avgRowSize) Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCountValue), - colStats = newColStats)) + attributeStats = newColStats)) + } + + def mapExprIdToAttribute( + attrStats: AttributeMap[ColumnStat], + exId: ExprId) + : Unit = { + attrStats.foreach( arg => if (arg._1.exprId == exId) return arg._1 ) } def calculateConditions( plan: Filter, - planStat: Statistics, condition: Expression, update: Boolean = true) : Double = { @@ -72,25 +88,22 @@ object FilterEstimation extends Logging { */ condition match { case And(cond1, cond2) => - 
val newStats1 = planStat.copy(colStats = mutableColStats.toMap) - val p1 = calculateConditions(plan, newStats1, cond1, update) - val newStats2 = planStat.copy(colStats = mutableColStats.toMap) - val p2 = calculateConditions(plan, newStats2, cond2, update) + val p1 = calculateConditions(plan, cond1, update) + val p2 = calculateConditions(plan, cond2, update) p1 * p2 case Or(cond1, cond2) => - val p1 = calculateConditions(plan, planStat, cond1, update = false) - val p2 = calculateConditions(plan, planStat, cond2, update = false) + val p1 = calculateConditions(plan, cond1, update = false) + val p2 = calculateConditions(plan, cond2, update = false) math.min(1.0, p1 + p2 - (p1 * p2)) - case Not(cond) => calculateSingleCondition(plan, planStat, cond, isNot = true, update = false) - case _ => calculateSingleCondition(plan, planStat, condition, isNot = false, update) + case Not(cond) => calculateSingleCondition(plan, cond, isNot = true, update = false) + case _ => calculateSingleCondition(plan, condition, isNot = false, update) } } def calculateSingleCondition( plan: Filter, - planStat: Statistics, condition: Expression, isNot: Boolean, update: Boolean) @@ -104,30 +117,30 @@ object FilterEstimation extends Logging { * so we will change the predicate order if not. */ case op@LessThan(ExtractAttr(ar), l: Literal) => - evaluateBinary(op, planStat, ar, l, update) + evaluateBinary(op, ar, l, update) case op@LessThan(l: Literal, ExtractAttr(ar)) => - evaluateBinary(GreaterThan(ar, l), planStat, ar, l, update) + evaluateBinary(GreaterThan(ar, l), ar, l, update) case op@LessThanOrEqual(ExtractAttr(ar), l: Literal) => - evaluateBinary(op, planStat, ar, l, update) + evaluateBinary(op, ar, l, update) case op@LessThanOrEqual(l: Literal, ExtractAttr(ar)) => - evaluateBinary(GreaterThanOrEqual(ar, l), planStat, ar, l, update) + evaluateBinary(GreaterThanOrEqual(ar, l), ar, l, update) case op@GreaterThan(ExtractAttr(ar), l: Literal) => - evaluateBinary(op, planStat, ar, l, update) + evaluateBinary(op, ar, l, update) case op@GreaterThan(l: Literal, ExtractAttr(ar)) => - evaluateBinary(LessThan(ar, l), planStat, ar, l, update) + evaluateBinary(LessThan(ar, l), ar, l, update) case op@GreaterThanOrEqual(ExtractAttr(ar), l: Literal) => - evaluateBinary(op, planStat, ar, l, update) + evaluateBinary(op, ar, l, update) case op@GreaterThanOrEqual(l: Literal, ExtractAttr(ar)) => - evaluateBinary(LessThanOrEqual(ar, l), planStat, ar, l, update) + evaluateBinary(LessThanOrEqual(ar, l), ar, l, update) /** EqualTo does not care about the order */ case op@EqualTo(ExtractAttr(ar), l: Literal) => - evaluateBinary(op, planStat, ar, l, update) + evaluateBinary(op, ar, l, update) case op@EqualTo(l: Literal, ExtractAttr(ar)) => - evaluateBinary(op, planStat, ar, l, update) + evaluateBinary(op, ar, l, update) case In(ExtractAttr(ar), expList) if !expList.exists(!_.isInstanceOf[Literal]) => /** @@ -136,10 +149,10 @@ object FilterEstimation extends Logging { * Here we convert In into InSet anyway, because they share the same processing logic. */ val hSet = expList.map(e => e.eval()) - evaluateInSet(planStat, ar, HashSet() ++ hSet, update) + evaluateInSet(ar, HashSet() ++ hSet, update) case InSet(ExtractAttr(ar), set) => - evaluateInSet(planStat, ar, set, update) + evaluateInSet(ar, set, update) /** * It's difficult to estimate IsNull after outer joins. 
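       * An outer join can introduce nulls that the child's column statistics
       * do not account for.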
Hence, @@ -147,13 +160,13 @@ object FilterEstimation extends Logging { */ case IsNull(ExtractAttr(ar)) => if (plan.child.isInstanceOf[LeafNode ]) { - evaluateIsNull(planStat, ar, true, update) + evaluateIsNull(plan, ar, true, update) } else 1.0 case IsNotNull(ExtractAttr(ar)) => if (plan.child.isInstanceOf[LeafNode ]) { - evaluateIsNull(planStat, ar, false, update) + evaluateIsNull(plan, ar, false, update) } else 1.0 @@ -177,27 +190,35 @@ object FilterEstimation extends Logging { } def evaluateIsNull( - planStat: Statistics, + plan: Filter, attrRef: AttributeReference, isNull: Boolean, update: Boolean) : Double = { - if (!planStat.colStats.contains(attrRef.name)) { + if (!mutableColStats.contains(attrRef.exprId)) { logDebug("[CBO] No statistics for " + attrRef) return 1.0 } - val aColStat = planStat.colStats(attrRef.name) - val rowCountValue = planStat.rowCount.get + val aColStat = mutableColStats(attrRef.exprId) + val rowCountValue = plan.child.statistics.rowCount.get val nullPercent: BigDecimal = if (rowCountValue == 0) 0.0 else BigDecimal(aColStat.nullCount)/BigDecimal(rowCountValue) + /** + * For predicate "WHERE key = 2", parser generates additional conditions to make it + * "WHERE key is not null AND key = 2". We avoid updating mutableColStats for this case. + * val redundant = + * if ((!isNull) && (aColStat.nullCount == 0)) true + * else false + */ + if (update) { val newStats = if (isNull) aColStat.copy(distinctCount = 0, min = None, max = None) else aColStat.copy(nullCount = 0) - mutableColStats += (attrRef.name -> newStats) + mutableColStats += (attrRef.exprId -> newStats) } val percent = @@ -213,12 +234,11 @@ object FilterEstimation extends Logging { /** This method evaluates binary comparison operators such as =, <, <=, >, >= */ def evaluateBinary( op: BinaryComparison, - planStat: Statistics, attrRef: AttributeReference, literal: Literal, update: Boolean) : Double = { - if (!planStat.colStats.contains(attrRef.name)) { + if (!mutableColStats.contains(attrRef.exprId)) { logDebug("[CBO] No statistics for " + attrRef) return 1.0 } @@ -241,11 +261,11 @@ object FilterEstimation extends Logging { } op match { - case EqualTo(l, r) => evaluateEqualTo(op, planStat, attrRef, literal, update) + case EqualTo(l, r) => evaluateEqualTo(op, attrRef, literal, update) case _ => attrRef.dataType match { case _: NumericType | DateType | TimestampType => - evaluateBinaryForNumeric(op, planStat, attrRef, literal, update) + evaluateBinaryForNumeric(op, attrRef, literal, update) case StringType | BinaryType => /** * TODO: It is difficult to support other binary comparisons for String/Binary @@ -293,13 +313,12 @@ object FilterEstimation extends Logging { /** This method evaluates the equality predicate for all data types. 
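   * Under a uniform-distribution assumption, it returns 1/ndv when the literal falls
   * in the column's [min, max] range, and 0 otherwise.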
*/ def evaluateEqualTo( op: BinaryComparison, - planStat: Statistics, attrRef: AttributeReference, literal: Literal, update: Boolean) : Double = { - val aColStat = planStat.colStats(attrRef.name) + val aColStat = mutableColStats(attrRef.exprId) val ndv = aColStat.distinctCount /** @@ -332,7 +351,7 @@ object FilterEstimation extends Logging { max = newValue, nullCount = 0) case _ => aColStat.copy(distinctCount = 1, nullCount = 0) } - mutableColStats += (attrRef.name -> newStats) + mutableColStats += (attrRef.exprId -> newStats) } 1.0 / ndv.toDouble @@ -344,17 +363,16 @@ object FilterEstimation extends Logging { } def evaluateInSet( - planStat: Statistics, attrRef: AttributeReference, hSet: Set[Any], update: Boolean) : Double = { - if (!planStat.colStats.contains(attrRef.name)) { + if (!mutableColStats.contains(attrRef.exprId)) { logDebug("[CBO] No statistics for " + attrRef) return 1.0 } - val aColStat = planStat.colStats(attrRef.name) + val aColStat = mutableColStats(attrRef.exprId) val ndv = aColStat.distinctCount val aType = attrRef.dataType @@ -390,7 +408,7 @@ object FilterEstimation extends Logging { case StringType | BinaryType => aColStat.copy(distinctCount = newNdv, nullCount = 0) } - mutableColStats += (attrRef.name -> newStats) + mutableColStats += (attrRef.exprId -> newStats) } /** @@ -402,14 +420,13 @@ object FilterEstimation extends Logging { def evaluateBinaryForNumeric( op: BinaryComparison, - planStat: Statistics, attrRef: AttributeReference, literal: Literal, update: Boolean) : Double = { var percent = 1.0 - val aColStat = planStat.colStats(attrRef.name) + val aColStat = mutableColStats(attrRef.exprId) val ndv = aColStat.distinctCount val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange] @@ -469,11 +486,11 @@ object FilterEstimation extends Logging { val newStats = aColStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0) - mutableColStats += (attrRef.name -> newStats) + mutableColStats += (attrRef.exprId -> newStats) } } percent } -} \ No newline at end of file +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala index 7e0ea20b2e13..b57475f8740c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.estimation +import org.apache.spark.sql.catalyst.expressions.AttributeMap import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.estimation.EstimationUtils._ import org.apache.spark.sql.test.SharedSQLContext /** @@ -41,20 +43,9 @@ class FilterEstimationSuite extends QueryTest with SharedSQLContext { sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") /** Validate statistics */ - val logicalPlan = - sql(s"SELECT * FROM $table1 WHERE key1 = 2").queryExecution.optimizedPlan - val expectedFilterStats = Statistics( - sizeInBytes = 1 * 8, rowCount = Some(1), - colStats = Map("key1" -> ColumnStat(1, Some(2L), Some(2L), 0, 8, 8)), - isBroadcastable = false) - - val filterNodes = logicalPlan.collect { - case filter: Filter => - val filterStats = filter.statistics - assert(filterStats == expectedFilterStats) - filter - } - assert(filterNodes.size == 1) + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 = 2" + val 
colStats = Seq("key1" -> ColumnStat(1, Some(2L), Some(2L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(1L)) } } @@ -67,20 +58,9 @@ class FilterEstimationSuite extends QueryTest with SharedSQLContext { sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") /** Validate statistics */ - val logicalPlan = - sql(s"SELECT * FROM $table1 WHERE key1 < 3").queryExecution.optimizedPlan - val expectedFilterStats = Statistics( - sizeInBytes = 3 * 8, rowCount = Some(3), - colStats = Map("key1" -> ColumnStat(2, Some(1L), Some(3L), 0, 8, 8)), - isBroadcastable = false) - - val filterNodes = logicalPlan.collect { - case filter: Filter => - val filterStats = filter.statistics - assert(filterStats == expectedFilterStats) - filter - } - assert(filterNodes.size == 1) + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 < 3" + val colStats = Seq("key1" -> ColumnStat(2, Some(1L), Some(3L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(3L)) } } @@ -93,20 +73,9 @@ class FilterEstimationSuite extends QueryTest with SharedSQLContext { sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") /** Validate statistics */ - val logicalPlan = - sql(s"SELECT * FROM $table1 WHERE key1 <= 3").queryExecution.optimizedPlan - val expectedFilterStats = Statistics( - sizeInBytes = 3 * 8, rowCount = Some(3), - colStats = Map("key1" -> ColumnStat(2, Some(1L), Some(3L), 0, 8, 8)), - isBroadcastable = false) - - val filterNodes = logicalPlan.collect { - case filter: Filter => - val filterStats = filter.statistics - assert(filterStats == expectedFilterStats) - filter - } - assert(filterNodes.size == 1) + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 <= 3" + val colStats = Seq("key1" -> ColumnStat(2, Some(1L), Some(3L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(3L)) } } @@ -119,20 +88,9 @@ class FilterEstimationSuite extends QueryTest with SharedSQLContext { sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") /** Validate statistics */ - val logicalPlan = - sql(s"SELECT * FROM $table1 WHERE key1 > 6").queryExecution.optimizedPlan - val expectedFilterStats = Statistics( - sizeInBytes = 5 * 8, rowCount = Some(5), - colStats = Map("key1" -> ColumnStat(4, Some(6L), Some(10L), 0, 8, 8)), - isBroadcastable = false) - - val filterNodes = logicalPlan.collect { - case filter: Filter => - val filterStats = filter.statistics - assert(filterStats == expectedFilterStats) - filter - } - assert(filterNodes.size == 1) + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 > 6" + val colStats = Seq("key1" -> ColumnStat(4, Some(6L), Some(10L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(5L)) } } @@ -145,20 +103,9 @@ class FilterEstimationSuite extends QueryTest with SharedSQLContext { sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") /** Validate statistics */ - val logicalPlan = - sql(s"SELECT * FROM $table1 WHERE key1 >= 6").queryExecution.optimizedPlan - val expectedFilterStats = Statistics( - sizeInBytes = 5 * 8, rowCount = Some(5), - colStats = Map("key1" -> ColumnStat(4, Some(6L), Some(10L), 0, 8, 8)), - isBroadcastable = false) - - val filterNodes = logicalPlan.collect { - case filter: Filter => - val filterStats = filter.statistics - assert(filterStats == expectedFilterStats) - filter - } - assert(filterNodes.size == 1) + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 >= 6" + val colStats = Seq("key1" -> ColumnStat(4, Some(6L), Some(10L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(5L)) } } @@ -171,20 +118,9 @@ class 
FilterEstimationSuite extends QueryTest with SharedSQLContext { sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") /** Validate statistics */ - val logicalPlan = - sql(s"SELECT * FROM $table1 WHERE key1 IS NULL").queryExecution.optimizedPlan - val expectedFilterStats = Statistics( - sizeInBytes = 0, rowCount = Some(0), - colStats = Map("key1" -> ColumnStat(0, None, None, 0, 8, 8)), - isBroadcastable = false) - - val filterNodes = logicalPlan.collect { - case filter: Filter => - val filterStats = filter.statistics - assert(filterStats == expectedFilterStats) - filter - } - assert(filterNodes.size == 1) + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 IS NULL" + val colStats = Seq("key1" -> ColumnStat(0, None, None, 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(0L)) } } @@ -197,20 +133,9 @@ class FilterEstimationSuite extends QueryTest with SharedSQLContext { sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") /** Validate statistics */ - val logicalPlan = - sql(s"SELECT * FROM $table1 WHERE key1 IS NOT NULL").queryExecution.optimizedPlan - val expectedFilterStats = Statistics( - sizeInBytes = 10 * 8, rowCount = Some(10), - colStats = Map("key1" -> ColumnStat(10, Some(1L), Some(10L), 0, 8, 8)), - isBroadcastable = false) - - val filterNodes = logicalPlan.collect { - case filter: Filter => - val filterStats = filter.statistics - assert(filterStats == expectedFilterStats) - filter - } - assert(filterNodes.size == 1) + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 IS NOT NULL" + val colStats = Seq("key1" -> ColumnStat(10, Some(1L), Some(10L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(10L)) } } @@ -223,20 +148,9 @@ class FilterEstimationSuite extends QueryTest with SharedSQLContext { sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") /** Validate statistics */ - val logicalPlan = - sql(s"SELECT * FROM $table1 WHERE key1 > 3 AND key1 <= 6").queryExecution.optimizedPlan - val expectedFilterStats = Statistics( - sizeInBytes = 4 * 8, rowCount = Some(4), - colStats = Map("key1" -> ColumnStat(3, Some(3L), Some(6L), 0, 8, 8)), - isBroadcastable = false) - - val filterNodes = logicalPlan.collect { - case filter: Filter => - val filterStats = filter.statistics - assert(filterStats == expectedFilterStats) - filter - } - assert(filterNodes.size == 1) + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 > 3 AND key1 <= 6" + val colStats = Seq("key1" -> ColumnStat(3, Some(3L), Some(6L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(4L)) } } @@ -249,20 +163,9 @@ class FilterEstimationSuite extends QueryTest with SharedSQLContext { sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") /** Validate statistics */ - val logicalPlan = - sql(s"SELECT * FROM $table1 WHERE key1 = 3 OR key1 = 6").queryExecution.optimizedPlan - val expectedFilterStats = Statistics( - sizeInBytes = 2 * 8, rowCount = Some(2), - colStats = Map("key1" -> ColumnStat(10, Some(1L), Some(10L), 0, 8, 8)), - isBroadcastable = false) - - val filterNodes = logicalPlan.collect { - case filter: Filter => - val filterStats = filter.statistics - assert(filterStats == expectedFilterStats) - filter - } - assert(filterNodes.size == 1) + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 = 3 OR key1 = 6" + val colStats = Seq("key1" -> ColumnStat(10, Some(1L), Some(10L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(2L)) } } @@ -275,20 +178,9 @@ class FilterEstimationSuite extends QueryTest with SharedSQLContext { sql(s"analyze table $table1 
compute STATISTICS FOR COLUMNS key1") /** Validate statistics */ - val logicalPlan = - sql(s"SELECT * FROM $table1 WHERE key1 IN (3, 4, 5)").queryExecution.optimizedPlan - val expectedFilterStats = Statistics( - sizeInBytes = 3 * 8, rowCount = Some(3), - colStats = Map("key1" -> ColumnStat(3, Some(3L), Some(5L), 0, 8, 8)), - isBroadcastable = false) - - val filterNodes = logicalPlan.collect { - case filter: Filter => - val filterStats = filter.statistics - assert(filterStats == expectedFilterStats) - filter - } - assert(filterNodes.size == 1) + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 IN (3, 4, 5)" + val colStats = Seq("key1" -> ColumnStat(3, Some(3L), Some(5L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(3L)) } } @@ -301,21 +193,36 @@ class FilterEstimationSuite extends QueryTest with SharedSQLContext { sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") /** Validate statistics */ - val logicalPlan = - sql(s"SELECT * FROM $table1 WHERE key1 NOT IN (3, 4, 5)").queryExecution.optimizedPlan - val expectedFilterStats = Statistics( - sizeInBytes = 7 * 8, rowCount = Some(7), - colStats = Map("key1" -> ColumnStat(10, Some(1L), Some(10L), 0, 8, 8)), - isBroadcastable = false) - - val filterNodes = logicalPlan.collect { - case filter: Filter => - val filterStats = filter.statistics - assert(filterStats == expectedFilterStats) - filter - } - assert(filterNodes.size == 1) + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 NOT IN (3, 4, 5)" + val colStats = Seq("key1" -> ColumnStat(10, Some(1L), Some(10L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(7L)) } } -} + private def validateEstimatedStats( + sqlStmt: String, + expectedColStats: Seq[(String, ColumnStat)], + rowCount: Option[Long] = None) + : Unit = { + val logicalPlan = sql(sqlStmt).queryExecution.optimizedPlan + val operNode = logicalPlan.collect { + case oper: Filter => + oper + }.head + val expectedRowCount = rowCount.getOrElse(sql(sqlStmt).collect().head.getLong(0)) + val nameToAttr = operNode.output.map(a => (a.name, a)).toMap + val expectedAttrStats = + AttributeMap(expectedColStats.map(kv => nameToAttr(kv._1) -> kv._2)) + val expectedStats = Statistics( + sizeInBytes = expectedRowCount * getRowSize(operNode.output, expectedAttrStats), + rowCount = Some(expectedRowCount), + attributeStats = expectedAttrStats, + isBroadcastable = false) + + val filterStats = operNode.statistics + assert(filterStats.sizeInBytes == expectedStats.sizeInBytes) + assert(filterStats.rowCount == expectedStats.rowCount) + assert(filterStats.isBroadcastable == expectedStats.isBroadcastable) + } + +} \ No newline at end of file From dfe8eb2aa1b0463eade2649573854c9a70fc2145 Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Mon, 2 Jan 2017 17:18:18 -0800 Subject: [PATCH 05/36] maintain an ExprId-to-Attribute map --- .../logical/estimation/FilterEstimation.scala | 15 ++++----------- .../sql/estimation/FilterEstimationSuite.scala | 4 +--- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala index cd8f554e572f..a2131fb759b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala @@ -45,8 +45,9 @@ object FilterEstimation extends 
Logging { val statsSeq: Seq[(ExprId, ColumnStat)] = stats.attributeStats.map(kv => (kv._1.exprId, kv._2)).toSeq mutableColStats = mutable.HashMap[ExprId, ColumnStat](statsSeq: _*) - - // mutableColStats = mutable.HashMap(stats.attributeStats.toSeq: _*) + /** save a copy of ExprId-to-Attribute map for later conversion */ + val expridToAttrMap: Map[ExprId, Attribute] = + stats.attributeStats.map(kv => (kv._1.exprId, kv._1)) /** estimate selectivity for this filter */ val percent: Double = calculateConditions(plan, plan.condition) @@ -54,7 +55,7 @@ object FilterEstimation extends Logging { /** copy mutableColStats contents to an immutable AttributeMap */ var mutableAttributeStats: mutable.Map[Attribute, ColumnStat] = mutable.Map.empty for ( (k, v) <- mutableColStats) { - val attr = mapExprIdToAttribute(stats.attributeStats, k).asInstanceOf[Attribute] + val attr = expridToAttrMap(k) mutableAttributeStats += (attr -> v) } val newColStats = AttributeMap(mutableAttributeStats.toSeq) @@ -205,14 +206,6 @@ object FilterEstimation extends Logging { if (rowCountValue == 0) 0.0 else BigDecimal(aColStat.nullCount)/BigDecimal(rowCountValue) - /** - * For predicate "WHERE key = 2", parser generates additional conditions to make it - * "WHERE key is not null AND key = 2". We avoid updating mutableColStats for this case. - * val redundant = - * if ((!isNull) && (aColStat.nullCount == 0)) true - * else false - */ - if (update) { val newStats = if (isNull) aColStat.copy(distinctCount = 0, min = None, max = None) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala index b57475f8740c..269aff272d7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala @@ -220,9 +220,7 @@ class FilterEstimationSuite extends QueryTest with SharedSQLContext { isBroadcastable = false) val filterStats = operNode.statistics - assert(filterStats.sizeInBytes == expectedStats.sizeInBytes) - assert(filterStats.rowCount == expectedStats.rowCount) - assert(filterStats.isBroadcastable == expectedStats.isBroadcastable) + assert(filterStats == expectedStats) } } \ No newline at end of file From 84d60338c692058fef6b286b73ac095d13696ee0 Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Tue, 3 Jan 2017 12:59:26 -0800 Subject: [PATCH 06/36] fixed sql/test:scalastyle errors --- .../logical/estimation/FilterEstimation.scala | 18 ++++-------------- .../sql/estimation/FilterEstimationSuite.scala | 2 +- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala index a2131fb759b6..62e3d104b4c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala @@ -34,7 +34,6 @@ object FilterEstimation extends Logging { * We use a mutable colStats because we need to update the corresponding ColumnStat * for a column after we apply a predicate condition. 
*/ - // private var mutableColStats: mutable.Map[AttributeReference, ColumnStat] = mutable.Map.empty private var mutableColStats: mutable.Map[ExprId, ColumnStat] = mutable.Map.empty def estimate(plan: Filter): Option[Statistics] = { @@ -45,7 +44,8 @@ object FilterEstimation extends Logging { val statsSeq: Seq[(ExprId, ColumnStat)] = stats.attributeStats.map(kv => (kv._1.exprId, kv._2)).toSeq mutableColStats = mutable.HashMap[ExprId, ColumnStat](statsSeq: _*) - /** save a copy of ExprId-to-Attribute map for later conversion */ + + /** save a copy of ExprId-to-Attribute map for later conversion use */ val expridToAttrMap: Map[ExprId, Attribute] = stats.attributeStats.map(kv => (kv._1.exprId, kv._1)) @@ -53,11 +53,8 @@ object FilterEstimation extends Logging { val percent: Double = calculateConditions(plan, plan.condition) /** copy mutableColStats contents to an immutable AttributeMap */ - var mutableAttributeStats: mutable.Map[Attribute, ColumnStat] = mutable.Map.empty - for ( (k, v) <- mutableColStats) { - val attr = expridToAttrMap(k) - mutableAttributeStats += (attr -> v) - } + val mutableAttributeStats: mutable.Map[Attribute, ColumnStat] = + mutableColStats.map(kv => (expridToAttrMap(kv._1) -> kv._2)) val newColStats = AttributeMap(mutableAttributeStats.toSeq) val filteredRowCountValue: BigInt = @@ -70,13 +67,6 @@ object FilterEstimation extends Logging { attributeStats = newColStats)) } - def mapExprIdToAttribute( - attrStats: AttributeMap[ColumnStat], - exId: ExprId) - : Unit = { - attrStats.foreach( arg => if (arg._1.exprId == exId) return arg._1 ) - } - def calculateConditions( plan: Filter, condition: Expression, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala index 269aff272d7b..a82f61a6ae9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala @@ -223,4 +223,4 @@ class FilterEstimationSuite extends QueryTest with SharedSQLContext { assert(filterStats == expectedStats) } -} \ No newline at end of file +} From c5e3a6b5a2def5d6cb329f1d40efd39f669c34a6 Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Tue, 3 Jan 2017 18:57:39 -0800 Subject: [PATCH 07/36] make mutableColStats start from empty every time estimate is called --- .../plans/logical/estimation/FilterEstimation.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala index 62e3d104b4c1..bb43d1918090 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala @@ -41,9 +41,9 @@ object FilterEstimation extends Logging { if (stats.rowCount.isEmpty) return None /** save a mutable copy of colStats so that we can later change it recursively */ - val statsSeq: Seq[(ExprId, ColumnStat)] = - stats.attributeStats.map(kv => (kv._1.exprId, kv._2)).toSeq - mutableColStats = mutable.HashMap[ExprId, ColumnStat](statsSeq: _*) + val statsExprIdMap: Map[ExprId, ColumnStat] = + stats.attributeStats.map(kv => (kv._1.exprId, kv._2)) + mutableColStats = mutable.Map.empty ++= statsExprIdMap /** 
save a copy of ExprId-to-Attribute map for later conversion use */
     val expridToAttrMap: Map[ExprId, Attribute] =

From c6dcf907f328617e59d19ace412bbb678d1cd20a Mon Sep 17 00:00:00 2001
From: Ron Hu
Date: Mon, 9 Jan 2017 11:36:52 -0800
Subject: [PATCH 08/36] make FilterEstimation a class

---
 .../plans/logical/basicLogicalOperators.scala |   6 +-
 .../logical/estimation/FilterEstimation.scala | 191 ++++++++++++------
 2 files changed, 134 insertions(+), 63 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index ded8d94c3c2d..b1c9acf4db43 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -130,8 +130,10 @@ case class Filter(condition: Expression, child: LogicalPlan)
     child.constraints.union(predicates.toSet)
   }
 
-  override lazy val statistics: Statistics =
-    FilterEstimation.estimate(this).getOrElse(super.statistics)
+  override lazy val statistics: Statistics = {
+    val filterEstimation = new FilterEstimation
+    filterEstimation.estimate(this).getOrElse(super.statistics)
+  }
 
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala
index bb43d1918090..c58cd84dc5b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala
@@ -28,31 +28,45 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 
-object FilterEstimation extends Logging {
+class FilterEstimation extends Logging {
 
   /**
    * We use a mutable colStats because we need to update the corresponding ColumnStat
-   * for a column after we apply a predicate condition.
+   * for a column after we apply a predicate condition. For example, a column c has
+   * [min, max] value as [0, 100]. In a range condition such as (c > 40 AND c <= 50),
+   * we need to set the column's [min, max] value to [40, 100] after we evaluate the
+   * first condition c > 40. We need to set the column's [min, max] value to [40, 50]
+   * after we evaluate the second condition c <= 50.
   */
   private var mutableColStats: mutable.Map[ExprId, ColumnStat] = mutable.Map.empty
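For intuition, the successive narrowing described in that comment can be traced with a small standalone sketch (plain Scala, no Spark dependencies; the values are the hypothetical ones from the comment, not taken from any test in this series):

    object RangeNarrowingSketch {
      // Apply (c > 40 AND c <= 50) to a column whose statistics say [min, max] = [0, 100].
      def main(args: Array[String]): Unit = {
        var min = BigDecimal(0)
        var max = BigDecimal(100)
        min = min.max(BigDecimal(40))  // after evaluating c > 40:  range becomes [40, 100]
        max = max.min(BigDecimal(50))  // after evaluating c <= 50: range becomes [40, 50]
        println(s"[$min, $max]")       // prints [40, 50]
      }
    }

Each conjunct is evaluated against the already-narrowed statistics of the previous one, which is exactly why the map above must stay mutable across the recursive calls.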
+  /**
+   * Returns an option of Statistics for a Filter logical plan node.
+   * For a given compound expression condition, this method computes filter selectivity
+   * (or the percentage of rows meeting the filter condition), which
+   * is used to compute row count, size in bytes, and the updated statistics after a given
+   * predicate is applied.
+   *
+   * @param plan a LogicalPlan node that must be an instance of Filter.
+   * @return Option[Statistics] When there is no statistics collected, it returns None.
+   */
   def estimate(plan: Filter): Option[Statistics] = {
     val stats: Statistics = plan.child.statistics
     if (stats.rowCount.isEmpty) return None
 
-    /** save a mutable copy of colStats so that we can later change it recursively */
+    // save a mutable copy of colStats so that we can later change it recursively
     val statsExprIdMap: Map[ExprId, ColumnStat] =
       stats.attributeStats.map(kv => (kv._1.exprId, kv._2))
     mutableColStats = mutable.Map.empty ++= statsExprIdMap
 
-    /** save a copy of ExprId-to-Attribute map for later conversion use */
+    // save a copy of ExprId-to-Attribute map for later conversion use
     val expridToAttrMap: Map[ExprId, Attribute] =
       stats.attributeStats.map(kv => (kv._1.exprId, kv._1))
 
-    /** estimate selectivity for this filter */
+    // estimate selectivity for this filter predicate
     val percent: Double = calculateConditions(plan, plan.condition)
 
-    /** copy mutableColStats contents to an immutable AttributeMap */
+    // copy mutableColStats contents to an immutable AttributeMap
     val mutableAttributeStats: mutable.Map[Attribute, ColumnStat] =
       mutableColStats.map(kv => (expridToAttrMap(kv._1) -> kv._2))
     val newColStats = AttributeMap(mutableAttributeStats.toSeq)
@@ -67,16 +81,26 @@ object FilterEstimation extends Logging {
       attributeStats = newColStats))
   }
 
+  /**
+   * Returns a percentage of rows meeting a compound condition in Filter node.
+   * A compound condition is decomposed into multiple single conditions linked with AND, OR, NOT.
+   * For logical AND conditions, we need to update stats after a condition estimation
+   * so that the stats will be more accurate for subsequent estimation. This is needed for
+   * range conditions such as (c > 40 AND c <= 50).
+   * For logical OR conditions, we do not update stats after a condition estimation.
+   *
+   * @param plan the Filter LogicalPlan node
+   * @param condition the compound logical expression
+   * @param update a boolean flag to specify if we need to update ColumnStat of a column
+   *               for subsequent conditions
+   * @return a double value to show the percentage of rows meeting a given condition
+   */
   def calculateConditions(
       plan: Filter,
       condition: Expression,
       update: Boolean = true)
     : Double = {
-    /**
-     * For conditions linked by And, we need to update stats after a condition estimation
-     * so that the stats will be more accurate for subsequent estimation.
-     * For conditions linked by OR, we do not update stats after a condition estimation.
-     */
+
     condition match {
       case And(cond1, cond2) =>
         val p1 = calculateConditions(plan, cond1, update)
@@ -93,6 +117,18 @@ object FilterEstimation extends Logging {
     }
   }
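The AND/OR/NOT combinators above are plain probability algebra under an independence assumption. A minimal standalone sketch of that arithmetic (plain Scala; the selectivities below are made-up numbers, not outputs of this patch):

    object SelectivityCombinators {
      def andSelectivity(p1: Double, p2: Double): Double = p1 * p2
      def orSelectivity(p1: Double, p2: Double): Double = math.min(1.0, p1 + p2 - (p1 * p2))
      def notSelectivity(p: Double): Double = 1.0 - p

      def main(args: Array[String]): Unit = {
        println(andSelectivity(0.6, 0.5))  // 0.3
        println(orSelectivity(0.1, 0.2))   // 0.28, by inclusion-exclusion
        println(notSelectivity(0.25))      // 0.75
      }
    }

The min(1.0, ...) clamp mirrors the Or branch above; when the two conditions are correlated rather than independent, the true selectivity can differ in either direction.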
- * Note that: all binary predicate computing methods assume the literal is at the right side, - * so we will change the predicate order if not. - */ + // For evaluateBinary method, we assume the literal on the right side of an operator. + // So we will change the order if not. case op@LessThan(ExtractAttr(ar), l: Literal) => evaluateBinary(op, ar, l, update) case op@LessThan(l: Literal, ExtractAttr(ar)) => @@ -127,28 +159,24 @@ object FilterEstimation extends Logging { case op@GreaterThanOrEqual(l: Literal, ExtractAttr(ar)) => evaluateBinary(LessThanOrEqual(ar, l), ar, l, update) - /** EqualTo does not care about the order */ + // EqualTo does not care about the order case op@EqualTo(ExtractAttr(ar), l: Literal) => evaluateBinary(op, ar, l, update) case op@EqualTo(l: Literal, ExtractAttr(ar)) => evaluateBinary(op, ar, l, update) case In(ExtractAttr(ar), expList) if !expList.exists(!_.isInstanceOf[Literal]) => - /** - * Expression [In (value, seq[Literal])] will be replaced with optimized version - * [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10. - * Here we convert In into InSet anyway, because they share the same processing logic. - */ + // Expression [In (value, seq[Literal])] will be replaced with optimized version + // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10. + // Here we convert In into InSet anyway, because they share the same processing logic. val hSet = expList.map(e => e.eval()) evaluateInSet(ar, HashSet() ++ hSet, update) case InSet(ExtractAttr(ar), set) => evaluateInSet(ar, set, update) - /** - * It's difficult to estimate IsNull after outer joins. Hence, - * we support IsNull and IsNotNull only when the child is a leaf node (table). - */ + // It's difficult to estimate IsNull after outer joins. Hence, + // we support IsNull and IsNotNull only when the child is a leaf node (table). case IsNull(ExtractAttr(ar)) => if (plan.child.isInstanceOf[LeafNode ]) { evaluateIsNull(plan, ar, true, update) @@ -162,11 +190,9 @@ object FilterEstimation extends Logging { else 1.0 case _ => - /** - * TODO: it's difficult to support string operators without advanced statistics. - * Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _) - * | EndsWith(_, _) are not supported yet - */ + // TODO: it's difficult to support string operators without advanced statistics. + // Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _) + // | EndsWith(_, _) are not supported yet logDebug("[CBO] Unsupported filter condition: " + condition) notSupported = true 1.0 @@ -180,6 +206,16 @@ object FilterEstimation extends Logging { } } + /** + * Returns a percentage of rows meeting "IS NULL" or "IS NOT NULL" condition. + * + * @param plan the Filter LogicalPlan node + * @param attrRef an AttributeReference (or a column) + * @param isNull set to true for "IS NULL" condition. set to false for "IS NOT NULL" condition + * @param update a boolean flag to specify if we need to update ColumnStat of a given column + * for subsequent conditions + * @return a doube value to show the percentage of rows meeting a given condition + */ def evaluateIsNull( plan: Filter, attrRef: AttributeReference, @@ -214,7 +250,16 @@ object FilterEstimation extends Logging { percent } - /** This method evaluates binary comparison operators such as =, <, <=, >, >= */ + /** + * Returns a percentage of rows meeting a binary comparison expression. 
-  /** This method evaluates binary comparison operators such as =, <, <=, >, >= */
+  /**
+   * Returns a percentage of rows meeting a binary comparison expression.
+   *
+   * @param op a binary comparison operator such as =, <, <=, >, >=
+   * @param attrRef an AttributeReference (or a column)
+   * @param literal a literal value (or constant)
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return a double value to show the percentage of rows meeting a given condition
+   */
   def evaluateBinary(
       op: BinaryComparison,
       attrRef: AttributeReference,
       literal: Literal,
       update: Boolean)
     : Double = {
     if (!mutableColStats.contains(attrRef.exprId)) {
       logDebug("[CBO] No statistics for " + attrRef)
       return 1.0
     }
 
-    /** Make sure that the Date/Timestamp literal is a valid one */
+    // Make sure that the Date/Timestamp literal is a valid one
     attrRef.dataType match {
       case DateType =>
         val dateLiteral = DateTimeUtils.stringToDate(literal.value.asInstanceOf[UTF8String])
         if (dateLiteral.isEmpty) {
           logDebug("[CBO] Date literal is wrong, No statistics for " + attrRef)
           return 1.0
         }
     }
 
     op match {
-      case EqualTo(l, r) => evaluateEqualTo(op, attrRef, literal, update)
+      case EqualTo(l, r) => evaluateEqualTo(attrRef, literal, update)
       case _ =>
         attrRef.dataType match {
           case _: NumericType | DateType | TimestampType =>
             evaluateBinaryForNumeric(op, attrRef, literal, update)
           case StringType | BinaryType =>
-            /**
-             * TODO: It is difficult to support other binary comparisons for String/Binary
-             * type without min/max and advanced statistics like histogram.
-             */
+
+            // TODO: It is difficult to support other binary comparisons for String/Binary
+            // type without min/max and advanced statistics like histogram.
+
             logDebug("[CBO] No statistics for String/Binary type " + attrRef)
             return 1.0
         }
     }
   }
 
+  /**
+   * Returns a percentage of rows meeting an equality (=) expression.
+   * This method evaluates the equality predicate for all data types.
+   *
+   * @param attrRef an AttributeReference (or a column)
+   * @param literal a literal value (or constant)
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return a double value to show the percentage of rows meeting a given condition
+   */
   def evaluateEqualTo(
-      op: BinaryComparison,
       attrRef: AttributeReference,
       literal: Literal,
       update: Boolean)
     : Double = {
 
     val aColStat = mutableColStats(attrRef.exprId)
     val ndv = aColStat.distinctCount
 
-    /**
-     * decide if the value is in [min, max] of the column.
-     * We currently don't store min/max for binary/string type.
-     * Hence, we assume it is in boundary for binary/string type.
-     */
+
+    // decide if the value is in [min, max] of the column.
+    // We currently don't store min/max for binary/string type.
+    // Hence, we assume it is in boundary for binary/string type.
+
     val inBoundary: Boolean = attrRef.dataType match {
       case _: NumericType | DateType | TimestampType =>
         val statsRange =
           Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange]
         val lit = numericLiteralToBigDecimal(literal, attrRef.dataType)
         (lit >= statsRange.min) && (lit <= statsRange.max)
 
       case _ => true /** for String/Binary type */
     }
 
     val percent: Double =
       if (inBoundary) {
 
         if (update) {
-          /**
-           * We update ColumnStat structure after apply this equality predicate.
-           * Set distinctCount to 1. Set nullCount to 0.
-           */
+          // We update ColumnStat structure after applying this equality predicate.
+          // Set distinctCount to 1. Set nullCount to 0.
           val newStats = attrRef.dataType match {
             case _: NumericType | DateType | TimestampType =>
               val newValue = Some(literal.value)
               aColStat.copy(distinctCount = 1, min = newValue,
                 max = newValue, nullCount = 0)
             case _ => aColStat.copy(distinctCount = 1, nullCount = 0)
           }
           mutableColStats += (attrRef.exprId -> newStats)
         }
 
         1.0 / ndv.toDouble
       } else {
         0.0
       }
 
     percent
   }
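The 1.0 / ndv result above assumes rows are spread uniformly over the distinct values. A small standalone illustration (plain Scala; the ndv and bounds are invented for the example):

    object EqualitySelectivitySketch {
      // Column stats: ndv distinct values within [min, max].
      def selectivity(literal: BigDecimal, min: BigDecimal, max: BigDecimal, ndv: Long): Double =
        if (literal >= min && literal <= max) 1.0 / ndv.toDouble else 0.0

      def main(args: Array[String]): Unit = {
        println(selectivity(BigDecimal(2), BigDecimal(1), BigDecimal(10), 10L))   // 0.1
        println(selectivity(BigDecimal(42), BigDecimal(1), BigDecimal(10), 10L))  // 0.0, out of range
      }
    }

This matches the suite's "key1 = 2" expectation earlier in the series: 10 rows with ndv = 10 yield an estimate of one row.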
+  /**
+   * Returns a percentage of rows meeting the "IN" operator expression.
+   * This method evaluates the IN/InSet predicate for all data types.
+   *
+   * @param attrRef an AttributeReference (or a column)
+   * @param hSet a set of literal values
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return a double value to show the percentage of rows meeting a given condition
+   */
+
   def evaluateInSet(
       attrRef: AttributeReference,
       hSet: Set[Any],
       update: Boolean)
     : Double = {
 
-    /** use [min, max] to filter the original hSet */
+    // use [min, max] to filter the original hSet
     val validQuerySet = aType match {
       case _: NumericType | DateType | TimestampType =>
         val statsRange =
           Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange]
         hSet.map(e => numericLiteralToBigDecimal(e, aType, true)).
           filter(e => e >= statsRange.min && e <= statsRange.max)
 
-      /** We assume the whole set since there is no min/max information for String/Binary type */
+      // We assume the whole set since there is no min/max information for String/Binary type
       case StringType | BinaryType => hSet
     }
     if (validQuerySet.isEmpty) {
       return 0.0
     }
 
       mutableColStats += (attrRef.exprId -> newStats)
     }
 
-    /**
-     * return the filter selectivity. Without advanced statistics such as histograms,
-     * we have to assume uniform distribution.
-     */
+    // return the filter selectivity. Without advanced statistics such as histograms,
+    // we have to assume uniform distribution.
     math.min(1.0, validQuerySet.size / ndv.toDouble)
   }
 
+  /**
+   * Returns a percentage of rows meeting a binary comparison expression.
+   * This method evaluates expressions for numeric columns only.
+   *
+   * @param op a binary comparison operator such as =, <, <=, >, >=
+   * @param attrRef an AttributeReference (or a column)
+   * @param literal a literal value (or constant)
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return a double value to show the percentage of rows meeting a given condition
+   */
   def evaluateBinaryForNumeric(
       op: BinaryComparison,
       attrRef: AttributeReference,
       literal: Literal,
       update: Boolean)
     : Double = {
 
     var percent = 1.0
     val aColStat = mutableColStats(attrRef.exprId)
     val ndv = aColStat.distinctCount
     val statsRange =
       Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange]
 
     val literalValueBD = numericLiteralToBigDecimal(literal, attrRef.dataType)
 
-    /** determine the overlapping degree between predicate range and column's range */
+    // determine the overlapping degree between predicate range and column's range
     val (noOverlap: Boolean, completeOverlap: Boolean) = op match {
       case LessThan(l, r) =>
         (literalValueBD <= statsRange.min, literalValueBD > statsRange.max)
       case LessThanOrEqual(l, r) =>
         (literalValueBD < statsRange.min, literalValueBD >= statsRange.max)
       case GreaterThan(l, r) =>
         (literalValueBD >= statsRange.max, literalValueBD < statsRange.min)
       case GreaterThanOrEqual(l, r) =>
         (literalValueBD > statsRange.max, literalValueBD <= statsRange.min)
     }
 
     if (noOverlap) {
       percent = 0.0
     } else if (completeOverlap) {
       percent = 1.0
     } else {
-      /** this is partial overlap case */
+      // this is the partial overlap case
       var newMax = aColStat.max
       var newMin = aColStat.min
       var newNdv = ndv
       val literalToDouble = literalValueBD.toDouble
       val maxToDouble = BigDecimal(statsRange.max).toDouble
       val minToDouble = BigDecimal(statsRange.min).toDouble
 
-      /**
-       * Without advanced statistics like histogram, we assume uniform data distribution.
-       * We just prorate the adjusted range over the initial range to compute filter selectivity.
-       */
+      // Without advanced statistics like histogram, we assume uniform data distribution.
+      // We just prorate the adjusted range over the initial range to compute filter selectivity.
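      // For example (hypothetical numbers): with [min, max] = [0, 100] and predicate c < 30,
      // the proration below yields (30 - 0) / (100 - 0) = 0.3, and update = true narrows the
      // column's range to [0, 30]; literals outside [0, 100] would instead fall into the
      // noOverlap or completeOverlap branches above.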
percent = op match { case LessThan(l, r) => (literalToDouble - minToDouble) / (maxToDouble - minToDouble) From 3826bd008807fce4b17d80f79a43dc1965164ade Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Mon, 9 Jan 2017 19:33:35 -0800 Subject: [PATCH 09/36] move files to new directory statsEstimation --- .../logical/estimation/EstimationUtils.scala | 57 -- .../logical/estimation/FilterEstimation.scala | 548 ------------------ .../plans/logical/estimation/Range.scala | 75 --- .../estimation/FilterEstimationSuite.scala | 226 -------- 4 files changed, 906 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/Range.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala deleted file mode 100644 index 06187336caf8..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/EstimationUtils.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.plans.logical.estimation - -import scala.math.BigDecimal.RoundingMode - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} -import org.apache.spark.sql.types.StringType - - -object EstimationUtils extends Logging { - - def ceil(bigDecimal: BigDecimal): BigInt = bigDecimal.setScale(0, RoundingMode.CEILING).toBigInt() - - def getRowSize(attributes: Seq[Attribute], attrStats: AttributeMap[ColumnStat]): Long = { - // We assign a generic overhead for a Row object, the actual overhead is different for different - // Row format. 
- 8 + attributes.map { attr => - if (attrStats.contains(attr)) { - attr.dataType match { - case StringType => - // UTF8String: base + offset + numBytes - attrStats(attr).avgLen + 8 + 4 - case _ => - attrStats(attr).avgLen - } - } else { - attr.dataType.defaultSize - } - }.sum - } -} - -/** Attribute Reference extractor */ -object ExtractAttr { - def unapply(exp: Expression): Option[AttributeReference] = exp match { - case ar: AttributeReference => Some(ar) - case _ => None - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala deleted file mode 100644 index c58cd84dc5b5..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/FilterEstimation.scala +++ /dev/null @@ -1,548 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.plans.logical.estimation - -import scala.collection.immutable.{HashSet, Map} -import scala.collection.mutable - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - - -class FilterEstimation extends Logging { - - /** - * We use a mutable colStats because we need to update the corresponding ColumnStat - * for a column after we apply a predicate condition. For example, A column c has - * [min, max] value as [0, 100]. In a range condition such as (c > 40 AND c <= 50), - * we need to set the column's [min, max] value to [40, 100] after we evaluate the - * first condition c > 40. We need to set the column's [min, max] value to [40, 50] - * after we evaluate the second condition c <= 50. - */ - private var mutableColStats: mutable.Map[ExprId, ColumnStat] = mutable.Map.empty - - /** - * Returns an option of Statistics for a Filter logical plan node. - * For a given compound expression condition, this method computes filter selectivity - * (or the percentage of rows meeting the filter condition), which - * is used to compute row count, size in bytes, and the updated statistics after a given - * predicated is applied. - * - * @param plan a LogicalPlan node that must be an instance of Filter. - * @return Option[Statistics] When there is no statistics collected, it returns None. 
- */ - def estimate(plan: Filter): Option[Statistics] = { - val stats: Statistics = plan.child.statistics - if (stats.rowCount.isEmpty) return None - - // save a mutable copy of colStats so that we can later change it recursively - val statsExprIdMap: Map[ExprId, ColumnStat] = - stats.attributeStats.map(kv => (kv._1.exprId, kv._2)) - mutableColStats = mutable.Map.empty ++= statsExprIdMap - - // save a copy of ExprId-to-Attribute map for later conversion use - val expridToAttrMap: Map[ExprId, Attribute] = - stats.attributeStats.map(kv => (kv._1.exprId, kv._1)) - - // estimate selectivity for this filter predicate - val percent: Double = calculateConditions(plan, plan.condition) - - // copy mutableColStats contents to an immutable AttributeMap - val mutableAttributeStats: mutable.Map[Attribute, ColumnStat] = - mutableColStats.map(kv => (expridToAttrMap(kv._1) -> kv._2)) - val newColStats = AttributeMap(mutableAttributeStats.toSeq) - - val filteredRowCountValue: BigInt = - EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * percent) - val avgRowSize = BigDecimal(EstimationUtils.getRowSize(plan.output, newColStats)) - val filteredSizeInBytes: BigInt = - EstimationUtils.ceil(BigDecimal(filteredRowCountValue) * avgRowSize) - - Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCountValue), - attributeStats = newColStats)) - } - - /** - * Returns a percentage of rows meeting a compound condition in Filter node. - * A compound condition is depomposed into multiple single conditions linked with AND, OR, NOT. - * For logical AND conditions, we need to update stats after a condition estimation - * so that the stats will be more accurate for subsequent estimation. This is needed for - * range condition such as (c > 40 AND c <= 50) - * For logical OR conditions, we do not update stats after a condition estimation. - * - * @param plan the Filter LogicalPlan node - * @param condition the compound logical expression - * @param update a boolean flag to specify if we need to update ColumnStat of a column - * for subsequent conditions - * @return a doube value to show the percentage of rows meeting a given condition - */ - def calculateConditions( - plan: Filter, - condition: Expression, - update: Boolean = true) - : Double = { - - condition match { - case And(cond1, cond2) => - val p1 = calculateConditions(plan, cond1, update) - val p2 = calculateConditions(plan, cond2, update) - p1 * p2 - - case Or(cond1, cond2) => - val p1 = calculateConditions(plan, cond1, update = false) - val p2 = calculateConditions(plan, cond2, update = false) - math.min(1.0, p1 + p2 - (p1 * p2)) - - case Not(cond) => calculateSingleCondition(plan, cond, isNot = true, update = false) - case _ => calculateSingleCondition(plan, condition, isNot = false, update) - } - } - - /** - * Returns a percentage of rows meeting a single condition in Filter node. - * Currently we only support binary predicates where one side is a column, - * and the other is a literal. - * - * @param plan the Filter LogicalPlan node - * @param condition a single logical expression - * @param isNot set to true for "IS NULL" condition. 
set to false for "IS NOT NULL" condition - * @param update a boolean flag to specify if we need to update ColumnStat of a column - * for subsequent conditions - * @return a doube value to show the percentage of rows meeting a given condition - */ - def calculateSingleCondition( - plan: Filter, - condition: Expression, - isNot: Boolean, - update: Boolean) - : Double = { - var notSupported: Boolean = false - val percent: Double = condition match { - // For evaluateBinary method, we assume the literal on the right side of an operator. - // So we will change the order if not. - case op@LessThan(ExtractAttr(ar), l: Literal) => - evaluateBinary(op, ar, l, update) - case op@LessThan(l: Literal, ExtractAttr(ar)) => - evaluateBinary(GreaterThan(ar, l), ar, l, update) - - case op@LessThanOrEqual(ExtractAttr(ar), l: Literal) => - evaluateBinary(op, ar, l, update) - case op@LessThanOrEqual(l: Literal, ExtractAttr(ar)) => - evaluateBinary(GreaterThanOrEqual(ar, l), ar, l, update) - - case op@GreaterThan(ExtractAttr(ar), l: Literal) => - evaluateBinary(op, ar, l, update) - case op@GreaterThan(l: Literal, ExtractAttr(ar)) => - evaluateBinary(LessThan(ar, l), ar, l, update) - - case op@GreaterThanOrEqual(ExtractAttr(ar), l: Literal) => - evaluateBinary(op, ar, l, update) - case op@GreaterThanOrEqual(l: Literal, ExtractAttr(ar)) => - evaluateBinary(LessThanOrEqual(ar, l), ar, l, update) - - // EqualTo does not care about the order - case op@EqualTo(ExtractAttr(ar), l: Literal) => - evaluateBinary(op, ar, l, update) - case op@EqualTo(l: Literal, ExtractAttr(ar)) => - evaluateBinary(op, ar, l, update) - - case In(ExtractAttr(ar), expList) if !expList.exists(!_.isInstanceOf[Literal]) => - // Expression [In (value, seq[Literal])] will be replaced with optimized version - // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10. - // Here we convert In into InSet anyway, because they share the same processing logic. - val hSet = expList.map(e => e.eval()) - evaluateInSet(ar, HashSet() ++ hSet, update) - - case InSet(ExtractAttr(ar), set) => - evaluateInSet(ar, set, update) - - // It's difficult to estimate IsNull after outer joins. Hence, - // we support IsNull and IsNotNull only when the child is a leaf node (table). - case IsNull(ExtractAttr(ar)) => - if (plan.child.isInstanceOf[LeafNode ]) { - evaluateIsNull(plan, ar, true, update) - } - else 1.0 - - case IsNotNull(ExtractAttr(ar)) => - if (plan.child.isInstanceOf[LeafNode ]) { - evaluateIsNull(plan, ar, false, update) - } - else 1.0 - - case _ => - // TODO: it's difficult to support string operators without advanced statistics. - // Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _) - // | EndsWith(_, _) are not supported yet - logDebug("[CBO] Unsupported filter condition: " + condition) - notSupported = true - 1.0 - } - if (notSupported) { - 1.0 - } else if (isNot) { - 1.0 - percent - } else { - percent - } - } - - /** - * Returns a percentage of rows meeting "IS NULL" or "IS NOT NULL" condition. - * - * @param plan the Filter LogicalPlan node - * @param attrRef an AttributeReference (or a column) - * @param isNull set to true for "IS NULL" condition. 
set to false for "IS NOT NULL" condition - * @param update a boolean flag to specify if we need to update ColumnStat of a given column - * for subsequent conditions - * @return a doube value to show the percentage of rows meeting a given condition - */ - def evaluateIsNull( - plan: Filter, - attrRef: AttributeReference, - isNull: Boolean, - update: Boolean) - : Double = { - if (!mutableColStats.contains(attrRef.exprId)) { - logDebug("[CBO] No statistics for " + attrRef) - return 1.0 - } - val aColStat = mutableColStats(attrRef.exprId) - val rowCountValue = plan.child.statistics.rowCount.get - val nullPercent: BigDecimal = - if (rowCountValue == 0) 0.0 - else BigDecimal(aColStat.nullCount)/BigDecimal(rowCountValue) - - if (update) { - val newStats = - if (isNull) aColStat.copy(distinctCount = 0, min = None, max = None) - else aColStat.copy(nullCount = 0) - - mutableColStats += (attrRef.exprId -> newStats) - } - - val percent = - if (isNull) nullPercent.toDouble - else { - /** ISNOTNULL(column) */ - 1.0 - nullPercent.toDouble - } - - percent - } - - /** - * Returns a percentage of rows meeting a binary comparison expression. - * - * @param op a binary comparison operator uch as =, <, <=, >, >= - * @param attrRef an AttributeReference (or a column) - * @param literal a literal value (or constant) - * @param update a boolean flag to specify if we need to update ColumnStat of a given column - * for subsequent conditions - * @return a doube value to show the percentage of rows meeting a given condition - */ - def evaluateBinary( - op: BinaryComparison, - attrRef: AttributeReference, - literal: Literal, - update: Boolean) - : Double = { - if (!mutableColStats.contains(attrRef.exprId)) { - logDebug("[CBO] No statistics for " + attrRef) - return 1.0 - } - - // Make sure that the Date/Timestamp literal is a valid one - attrRef.dataType match { - case DateType => - val dateLiteral = DateTimeUtils.stringToDate(literal.value.asInstanceOf[UTF8String]) - if (dateLiteral.isEmpty) { - logDebug("[CBO] Date literal is wrong, No statistics for " + attrRef) - return 1.0 - } - case TimestampType => - val tsLiteral = DateTimeUtils.stringToTimestamp(literal.value.asInstanceOf[UTF8String]) - if (tsLiteral.isEmpty) { - logDebug("[CBO] Timestamp literal is wrong, No statistics for " + attrRef) - return 1.0 - } - case _ => - } - - op match { - case EqualTo(l, r) => evaluateEqualTo(attrRef, literal, update) - case _ => - attrRef.dataType match { - case _: NumericType | DateType | TimestampType => - evaluateBinaryForNumeric(op, attrRef, literal, update) - case StringType | BinaryType => - - // TODO: It is difficult to support other binary comparisons for String/Binary - // type without min/max and advanced statistics like histogram. - - logDebug("[CBO] No statistics for String/Binary type " + attrRef) - return 1.0 - } - } - } - - /** - * This method converts a numeric or Literal value of numeric type to a BigDecimal value. - * If isNumeric is true, then it is a numeric value. Otherwise, it is a Literal value. 
- */ - def numericLiteralToBigDecimal( - literal: Any, - dataType: DataType, - isNumeric: Boolean = false) - : BigDecimal = { - dataType match { - case _: IntegralType => - if (isNumeric) BigDecimal(literal.asInstanceOf[Long]) - else BigDecimal(literal.asInstanceOf[Literal].value.asInstanceOf[Long]) - case _: FractionalType => - if (isNumeric) BigDecimal(literal.asInstanceOf[Double]) - else BigDecimal(literal.asInstanceOf[Literal].value.asInstanceOf[Double]) - case DateType => - if (isNumeric) BigDecimal(literal.asInstanceOf[BigInt]) - else { - val dateLiteral = DateTimeUtils.stringToDate( - literal.asInstanceOf[Literal].value.asInstanceOf[UTF8String]) - BigDecimal(dateLiteral.asInstanceOf[BigInt]) - } - case TimestampType => - if (isNumeric) BigDecimal(literal.asInstanceOf[BigInt]) - else { - val tsLiteral = DateTimeUtils.stringToTimestamp( - literal.asInstanceOf[Literal].value.asInstanceOf[UTF8String]) - BigDecimal(tsLiteral.asInstanceOf[BigInt]) - } - } - } - - /** - * Returns a percentage of rows meeting an equality (=) expression. - * This method evaluates the equality predicate for all data types. - * - * @param attrRef an AttributeReference (or a column) - * @param literal a literal value (or constant) - * @param update a boolean flag to specify if we need to update ColumnStat of a given column - * for subsequent conditions - * @return a doube value to show the percentage of rows meeting a given condition - */ - def evaluateEqualTo( - attrRef: AttributeReference, - literal: Literal, - update: Boolean) - : Double = { - - val aColStat = mutableColStats(attrRef.exprId) - val ndv = aColStat.distinctCount - - - // decide if the value is in [min, max] of the column. - // We currently don't store min/max for binary/string type. - // Hence, we assume it is in boundary for binary/string type. - - val inBoundary: Boolean = attrRef.dataType match { - case _: NumericType | DateType | TimestampType => - val statsRange = - Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange] - val lit = numericLiteralToBigDecimal(literal, attrRef.dataType) - (lit >= statsRange.min) && (lit <= statsRange.max) - - case _ => true /** for String/Binary type */ - } - - val percent: Double = - if (inBoundary) { - - if (update) { - // We update ColumnStat structure after apply this equality predicate. - // Set distinctCount to 1. Set nullCount to 0. - val newStats = attrRef.dataType match { - case _: NumericType | DateType | TimestampType => - val newValue = Some(literal.value) - aColStat.copy(distinctCount = 1, min = newValue, - max = newValue, nullCount = 0) - case _ => aColStat.copy(distinctCount = 1, nullCount = 0) - } - mutableColStats += (attrRef.exprId -> newStats) - } - - 1.0 / ndv.toDouble - } else { - 0.0 - } - - percent - } - - /** - * Returns a percentage of rows meeting "IN" operator expression. - * This method evaluates the equality predicate for all data types. 
- * - * @param attrRef an AttributeReference (or a column) - * @param hSet a set of literal values - * @param update a boolean flag to specify if we need to update ColumnStat of a given column - * for subsequent conditions - * @return a doube value to show the percentage of rows meeting a given condition - */ - - def evaluateInSet( - attrRef: AttributeReference, - hSet: Set[Any], - update: Boolean) - : Double = { - if (!mutableColStats.contains(attrRef.exprId)) { - logDebug("[CBO] No statistics for " + attrRef) - return 1.0 - } - - val aColStat = mutableColStats(attrRef.exprId) - val ndv = aColStat.distinctCount - val aType = attrRef.dataType - - // use [min, max] to filter the original hSet - val validQuerySet = aType match { - case _: NumericType | DateType | TimestampType => - val statsRange = - Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange] - hSet.map(e => numericLiteralToBigDecimal(e, aType, true)). - filter(e => e >= statsRange.min && e <= statsRange.max) - - // We assume the whole set since there is no min/max information for String/Binary type - case StringType | BinaryType => hSet - } - if (validQuerySet.isEmpty) { - return 0.0 - } - - val newNdv = validQuerySet.size - val(newMax, newMin) = aType match { - case _: NumericType | DateType | TimestampType => - val tmpSet: Set[Double] = validQuerySet.map(e => e.toString.toDouble) - (Some(tmpSet.max), Some(tmpSet.min)) - case _ => - (None, None) - } - - if (update) { - val newStats = attrRef.dataType match { - case _: NumericType | DateType | TimestampType => - aColStat.copy(distinctCount = newNdv, min = newMin, - max = newMax, nullCount = 0) - case StringType | BinaryType => - aColStat.copy(distinctCount = newNdv, nullCount = 0) - } - mutableColStats += (attrRef.exprId -> newStats) - } - - // return the filter selectivity. Without advanced statistics such as histograms, - // we have to assume uniform distribution. - math.min(1.0, validQuerySet.size / ndv.toDouble) - } - - /** - * Returns a percentage of rows meeting a binary comparison expression. - * This method evaluate expression for Numeric columns only. 
- * - * @param op a binary comparison operator uch as =, <, <=, >, >= - * @param attrRef an AttributeReference (or a column) - * @param literal a literal value (or constant) - * @param update a boolean flag to specify if we need to update ColumnStat of a given column - * for subsequent conditions - * @return a doube value to show the percentage of rows meeting a given condition - */ - def evaluateBinaryForNumeric( - op: BinaryComparison, - attrRef: AttributeReference, - literal: Literal, - update: Boolean) - : Double = { - - var percent = 1.0 - val aColStat = mutableColStats(attrRef.exprId) - val ndv = aColStat.distinctCount - val statsRange = - Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange] - - val literalValueBD = numericLiteralToBigDecimal(literal, attrRef.dataType) - - // determine the overlapping degree between predicate range and column's range - val (noOverlap: Boolean, completeOverlap: Boolean) = op match { - case LessThan(l, r) => - (literalValueBD <= statsRange.min, literalValueBD > statsRange.max) - case LessThanOrEqual(l, r) => - (literalValueBD < statsRange.min, literalValueBD >= statsRange.max) - case GreaterThan(l, r) => - (literalValueBD >= statsRange.max, literalValueBD < statsRange.min) - case GreaterThanOrEqual(l, r) => - (literalValueBD > statsRange.max, literalValueBD <= statsRange.min) - } - - if (noOverlap) { - percent = 0.0 - } else if (completeOverlap) { - percent = 1.0 - } else { - // this is partial overlap case - var newMax = aColStat.max - var newMin = aColStat.min - var newNdv = ndv - val literalToDouble = literalValueBD.toDouble - val maxToDouble = BigDecimal(statsRange.max).toDouble - val minToDouble = BigDecimal(statsRange.min).toDouble - - // Without advanced statistics like histogram, we assume uniform data distribution. - // We just prorate the adjusted range over the initial range to compute filter selectivity. - percent = op match { - case LessThan(l, r) => - (literalToDouble - minToDouble) / (maxToDouble - minToDouble) - case LessThanOrEqual(l, r) => - if (literalValueBD == BigDecimal(statsRange.min)) 1.0 / ndv.toDouble - else (literalToDouble - minToDouble) / (maxToDouble - minToDouble) - case GreaterThan(l, r) => - (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) - case GreaterThanOrEqual(l, r) => - if (literalValueBD == BigDecimal(statsRange.max)) 1.0 / ndv.toDouble - else (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) - } - - if (update) { - op match { - case GreaterThan(l, r) => newMin = Some(literal.value) - case GreaterThanOrEqual(l, r) => newMin = Some(literal.value) - case LessThan(l, r) => newMax = Some(literal.value) - case LessThanOrEqual(l, r) => newMax = Some(literal.value) - } - newNdv = math.max(math.round(ndv.toDouble * percent), 1) - val newStats = aColStat.copy(distinctCount = newNdv, min = newMin, - max = newMax, nullCount = 0) - - mutableColStats += (attrRef.exprId -> newStats) - } - } - - percent - } - -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/Range.scala deleted file mode 100644 index 24a4f9b1ca66..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/estimation/Range.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.plans.logical.estimation - -import java.math.{BigDecimal => JDecimal} -import java.sql.{Date, Timestamp} - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _} - - -/** Value range of a column. */ -trait Range - -/** For simplicity we use decimal to unify operations of numeric ranges. */ -case class NumericRange(min: JDecimal, max: JDecimal) extends Range - -/** - * This version of Spark does not have min/max for binary/string types, we define their default - * behaviors by this class. - */ -class DefaultRange extends Range - -/** This is for columns with only null values. */ -class NullRange extends Range - -object Range { - def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match { - case StringType | BinaryType => new DefaultRange() - case _ if min.isEmpty || max.isEmpty => new NullRange() - case _ => toNumericRange(min.get, max.get, dataType) - } - - /** - * For simplicity we use decimal to unify operations of numeric types, the two methods below - * are the contract of conversion. - */ - private def toNumericRange(min: Any, max: Any, dataType: DataType): NumericRange = { - dataType match { - case _: NumericType => - NumericRange(new JDecimal(min.toString), new JDecimal(max.toString)) - case BooleanType => - val min1 = if (min.asInstanceOf[Boolean]) 1 else 0 - val max1 = if (max.asInstanceOf[Boolean]) 1 else 0 - NumericRange(new JDecimal(min1), new JDecimal(max1)) - case DateType => - val min1 = DateTimeUtils.fromJavaDate(min.asInstanceOf[Date]) - val max1 = DateTimeUtils.fromJavaDate(max.asInstanceOf[Date]) - NumericRange(new JDecimal(min1), new JDecimal(max1)) - case TimestampType => - val min1 = DateTimeUtils.fromJavaTimestamp(min.asInstanceOf[Timestamp]) - val max1 = DateTimeUtils.fromJavaTimestamp(max.asInstanceOf[Timestamp]) - NumericRange(new JDecimal(min1), new JDecimal(max1)) - case _ => - throw new AnalysisException(s"Type $dataType is not castable to numeric in estimation.") - } - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala deleted file mode 100644 index a82f61a6ae9b..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/estimation/FilterEstimationSuite.scala +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.estimation - -import org.apache.spark.sql.catalyst.expressions.AttributeMap -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.logical.estimation.EstimationUtils._ -import org.apache.spark.sql.test.SharedSQLContext - -/** - * In this test suite, we test the proedicates containing the following operators: - * =, <, <=, >, >=, AND, OR, IS NULL, IS NOT NULL, IN, NOT IN - */ - -class FilterEstimationSuite extends QueryTest with SharedSQLContext { - import testImplicits._ - - private val data1 = Seq[Long](1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - private val table1 = "filter_estimation_test1" - - test("filter estimation with equality comparison") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 = 2" - val colStats = Seq("key1" -> ColumnStat(1, Some(2L), Some(2L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(1L)) - } - } - - test("filter estimation with less than comparison") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 < 3" - val colStats = Seq("key1" -> ColumnStat(2, Some(1L), Some(3L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(3L)) - } - } - - test("filter estimation with less than or equal to comparison") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 <= 3" - val colStats = Seq("key1" -> ColumnStat(2, Some(1L), Some(3L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(3L)) - } - } - - test("filter estimation with greater than comparison") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 > 6" - val colStats = Seq("key1" -> ColumnStat(4, Some(6L), Some(10L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(5L)) - } - } - - test("filter estimation with greater than or equal to comparison") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 >= 6" - val colStats = Seq("key1" -> 
ColumnStat(4, Some(6L), Some(10L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(5L)) - } - } - - test("filter estimation with IS NULL comparison") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 IS NULL" - val colStats = Seq("key1" -> ColumnStat(0, None, None, 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(0L)) - } - } - - test("filter estimation with IS NOT NULL comparison") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 IS NOT NULL" - val colStats = Seq("key1" -> ColumnStat(10, Some(1L), Some(10L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(10L)) - } - } - - test("filter estimation with logical AND operator") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 > 3 AND key1 <= 6" - val colStats = Seq("key1" -> ColumnStat(3, Some(3L), Some(6L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(4L)) - } - } - - test("filter estimation with logical OR operator") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 = 3 OR key1 = 6" - val colStats = Seq("key1" -> ColumnStat(10, Some(1L), Some(10L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(2L)) - } - } - - test("filter estimation with IN operator") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 IN (3, 4, 5)" - val colStats = Seq("key1" -> ColumnStat(3, Some(3L), Some(5L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(3L)) - } - } - - test("filter estimation with logical NOT operator") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 NOT IN (3, 4, 5)" - val colStats = Seq("key1" -> ColumnStat(10, Some(1L), Some(10L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(7L)) - } - } - - private def validateEstimatedStats( - sqlStmt: String, - expectedColStats: Seq[(String, ColumnStat)], - rowCount: Option[Long] = None) - : Unit = { - val logicalPlan = sql(sqlStmt).queryExecution.optimizedPlan - val operNode = logicalPlan.collect { - case oper: Filter => - oper - }.head - val expectedRowCount = rowCount.getOrElse(sql(sqlStmt).collect().head.getLong(0)) - val nameToAttr = operNode.output.map(a => (a.name, a)).toMap - val expectedAttrStats = - AttributeMap(expectedColStats.map(kv => nameToAttr(kv._1) -> kv._2)) - val expectedStats = Statistics( - sizeInBytes = expectedRowCount * 
getRowSize(operNode.output, expectedAttrStats),
-      rowCount = Some(expectedRowCount),
-      attributeStats = expectedAttrStats,
-      isBroadcastable = false)
-
-    val filterStats = operNode.statistics
-    assert(filterStats == expectedStats)
-  }
-
-}

From f007a4d34d453b9c53c00e2e07a1a651ddb8a056 Mon Sep 17 00:00:00 2001
From: Ron Hu
Date: Mon, 9 Jan 2017 19:35:08 -0800
Subject: [PATCH 10/36] added files to statsEstimation

---
 .../statsEstimation/FilterEstimation.scala    | 548 ++++++++++++++++++
 .../plans/logical/statsEstimation/Range.scala |   1 +
 .../FilterEstimationSuite.scala               | 226 ++++++++
 3 files changed, 775 insertions(+)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
new file mode 100644
index 000000000000..c58cd84dc5b5
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -0,0 +1,548 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical.estimation
+
+import scala.collection.immutable.{HashSet, Map}
+import scala.collection.mutable
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+
+class FilterEstimation extends Logging {
+
+  /**
+   * We use a mutable colStats because we need to update the corresponding ColumnStat
+   * for a column after we apply a predicate condition. For example, a column c has
+   * [min, max] value as [0, 100]. In a range condition such as (c > 40 AND c <= 50),
+   * we need to set the column's [min, max] value to [40, 100] after we evaluate the
+   * first condition c > 40. We need to set the column's [min, max] value to [40, 50]
+   * after we evaluate the second condition c <= 50.
+   */
+  private var mutableColStats: mutable.Map[ExprId, ColumnStat] = mutable.Map.empty
+
+  /**
+   * Returns an option of Statistics for a Filter logical plan node.
+   * For a given compound expression condition, this method computes filter selectivity
+   * (or the percentage of rows meeting the filter condition), which
+   * is used to compute row count, size in bytes, and the updated statistics after a given
+   * predicate is applied.
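+   *
+   * For example (hypothetical numbers): if the child plan has rowCount = 100 and the
+   * predicate is estimated at a combined selectivity of 0.25, the Filter's row count
+   * becomes ceil(100 * 0.25) = 25, and its size in bytes is that row count times the
+   * average row size recomputed from the updated column statistics.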
+   *
+   * @param plan a LogicalPlan node that must be an instance of Filter.
+   * @return Option[Statistics] When there is no statistics collected, it returns None.
+   */
+  def estimate(plan: Filter): Option[Statistics] = {
+    val stats: Statistics = plan.child.statistics
+    if (stats.rowCount.isEmpty) return None
+
+    // save a mutable copy of colStats so that we can later change it recursively
+    val statsExprIdMap: Map[ExprId, ColumnStat] =
+      stats.attributeStats.map(kv => (kv._1.exprId, kv._2))
+    mutableColStats = mutable.Map.empty ++= statsExprIdMap
+
+    // save a copy of ExprId-to-Attribute map for later conversion use
+    val expridToAttrMap: Map[ExprId, Attribute] =
+      stats.attributeStats.map(kv => (kv._1.exprId, kv._1))
+
+    // estimate selectivity for this filter predicate
+    val percent: Double = calculateConditions(plan, plan.condition)
+
+    // copy mutableColStats contents to an immutable AttributeMap
+    val mutableAttributeStats: mutable.Map[Attribute, ColumnStat] =
+      mutableColStats.map(kv => (expridToAttrMap(kv._1) -> kv._2))
+    val newColStats = AttributeMap(mutableAttributeStats.toSeq)
+
+    val filteredRowCountValue: BigInt =
+      EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * percent)
+    val avgRowSize = BigDecimal(EstimationUtils.getRowSize(plan.output, newColStats))
+    val filteredSizeInBytes: BigInt =
+      EstimationUtils.ceil(BigDecimal(filteredRowCountValue) * avgRowSize)
+
+    Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCountValue),
+      attributeStats = newColStats))
+  }
+
+  /**
+   * Returns a percentage of rows meeting a compound condition in Filter node.
+   * A compound condition is decomposed into multiple single conditions linked with AND, OR, NOT.
+   * For logical AND conditions, we need to update stats after a condition estimation
+   * so that the stats will be more accurate for subsequent estimation. This is needed for
+   * a range condition such as (c > 40 AND c <= 50).
+   * For logical OR conditions, we do not update stats after a condition estimation.
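+   *
+   * For example (hypothetical numbers): with p1 = 0.5 and p2 = 0.4, an AND condition
+   * is estimated at 0.5 * 0.4 = 0.2, while an OR condition is estimated at
+   * min(1.0, 0.5 + 0.4 - 0.5 * 0.4) = 0.7.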
+   *
+   * @param plan the Filter LogicalPlan node
+   * @param condition the compound logical expression
+   * @param update a boolean flag to specify if we need to update ColumnStat of a column
+   *               for subsequent conditions
+   * @return a double value to show the percentage of rows meeting a given condition
+   */
+  def calculateConditions(
+      plan: Filter,
+      condition: Expression,
+      update: Boolean = true)
+    : Double = {
+
+    condition match {
+      case And(cond1, cond2) =>
+        val p1 = calculateConditions(plan, cond1, update)
+        val p2 = calculateConditions(plan, cond2, update)
+        p1 * p2
+
+      case Or(cond1, cond2) =>
+        val p1 = calculateConditions(plan, cond1, update = false)
+        val p2 = calculateConditions(plan, cond2, update = false)
+        math.min(1.0, p1 + p2 - (p1 * p2))
+
+      case Not(cond) => calculateSingleCondition(plan, cond, isNot = true, update = false)
+      case _ => calculateSingleCondition(plan, condition, isNot = false, update)
+    }
+  }
+
+  /**
+   * Returns a percentage of rows meeting a single condition in Filter node.
+   * Currently we only support binary predicates where one side is a column,
+   * and the other is a literal.
+   *
+   * @param plan the Filter LogicalPlan node
+   * @param condition a single logical expression
+   * @param isNot set to true when the condition is negated by a NOT operator,
+   *              in which case we return one minus the estimated percentage
+   * @param update a boolean flag to specify if we need to update ColumnStat of a column
+   *               for subsequent conditions
+   * @return a double value to show the percentage of rows meeting a given condition
+   */
+  def calculateSingleCondition(
+      plan: Filter,
+      condition: Expression,
+      isNot: Boolean,
+      update: Boolean)
+    : Double = {
+    var notSupported: Boolean = false
+    val percent: Double = condition match {
+      // For evaluateBinary method, we assume the literal on the right side of an operator.
+      // So we will change the order if not.
+      case op@LessThan(ExtractAttr(ar), l: Literal) =>
+        evaluateBinary(op, ar, l, update)
+      case op@LessThan(l: Literal, ExtractAttr(ar)) =>
+        evaluateBinary(GreaterThan(ar, l), ar, l, update)
+
+      case op@LessThanOrEqual(ExtractAttr(ar), l: Literal) =>
+        evaluateBinary(op, ar, l, update)
+      case op@LessThanOrEqual(l: Literal, ExtractAttr(ar)) =>
+        evaluateBinary(GreaterThanOrEqual(ar, l), ar, l, update)
+
+      case op@GreaterThan(ExtractAttr(ar), l: Literal) =>
+        evaluateBinary(op, ar, l, update)
+      case op@GreaterThan(l: Literal, ExtractAttr(ar)) =>
+        evaluateBinary(LessThan(ar, l), ar, l, update)
+
+      case op@GreaterThanOrEqual(ExtractAttr(ar), l: Literal) =>
+        evaluateBinary(op, ar, l, update)
+      case op@GreaterThanOrEqual(l: Literal, ExtractAttr(ar)) =>
+        evaluateBinary(LessThanOrEqual(ar, l), ar, l, update)
+
+      // EqualTo does not care about the order
+      case op@EqualTo(ExtractAttr(ar), l: Literal) =>
+        evaluateBinary(op, ar, l, update)
+      case op@EqualTo(l: Literal, ExtractAttr(ar)) =>
+        evaluateBinary(op, ar, l, update)
+
+      case In(ExtractAttr(ar), expList) if !expList.exists(!_.isInstanceOf[Literal]) =>
+        // Expression [In (value, seq[Literal])] will be replaced with optimized version
+        // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10.
+        // Here we convert In into InSet anyway, because they share the same processing logic.
+        val hSet = expList.map(e => e.eval())
+        evaluateInSet(ar, HashSet() ++ hSet, update)
+
+      case InSet(ExtractAttr(ar), set) =>
+        evaluateInSet(ar, set, update)
+
+      // It's difficult to estimate IsNull after outer joins. Hence,
+      // we support IsNull and IsNotNull only when the child is a leaf node (table).
+      case IsNull(ExtractAttr(ar)) =>
+        if (plan.child.isInstanceOf[LeafNode]) {
+          evaluateIsNull(plan, ar, true, update)
+        } else 1.0
+
+      case IsNotNull(ExtractAttr(ar)) =>
+        if (plan.child.isInstanceOf[LeafNode]) {
+          evaluateIsNull(plan, ar, false, update)
+        } else 1.0
+
+      case _ =>
+        // TODO: it's difficult to support string operators without advanced statistics.
+        // Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _)
+        // | EndsWith(_, _) are not supported yet.
+        logDebug("[CBO] Unsupported filter condition: " + condition)
+        notSupported = true
+        1.0
+    }
+    if (notSupported) {
+      1.0
+    } else if (isNot) {
+      1.0 - percent
+    } else {
+      percent
+    }
+  }
+
+  /**
+   * Returns a percentage of rows meeting "IS NULL" or "IS NOT NULL" condition.
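+   *
+   * For example (hypothetical numbers): if a column has nullCount = 10 and its table
+   * has rowCount = 100, "IS NULL" is estimated at 10 / 100 = 0.1, and "IS NOT NULL"
+   * at 1.0 - 0.1 = 0.9.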
+   *
+   * @param plan the Filter LogicalPlan node
+   * @param attrRef an AttributeReference (or a column)
+   * @param isNull set to true for an "IS NULL" condition and
+   *               set to false for an "IS NOT NULL" condition
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return a double value to show the percentage of rows meeting a given condition
+   */
+  def evaluateIsNull(
+      plan: Filter,
+      attrRef: AttributeReference,
+      isNull: Boolean,
+      update: Boolean)
+    : Double = {
+    if (!mutableColStats.contains(attrRef.exprId)) {
+      logDebug("[CBO] No statistics for " + attrRef)
+      return 1.0
+    }
+    val aColStat = mutableColStats(attrRef.exprId)
+    val rowCountValue = plan.child.statistics.rowCount.get
+    val nullPercent: BigDecimal =
+      if (rowCountValue == 0) 0.0
+      else BigDecimal(aColStat.nullCount) / BigDecimal(rowCountValue)
+
+    if (update) {
+      val newStats =
+        if (isNull) aColStat.copy(distinctCount = 0, min = None, max = None)
+        else aColStat.copy(nullCount = 0)
+
+      mutableColStats += (attrRef.exprId -> newStats)
+    }
+
+    val percent =
+      if (isNull) nullPercent.toDouble
+      else {
+        // ISNOTNULL(column)
+        1.0 - nullPercent.toDouble
+      }
+
+    percent
+  }
+
+  /**
+   * Returns a percentage of rows meeting a binary comparison expression.
+   *
+   * @param op a binary comparison operator such as =, <, <=, >, >=
+   * @param attrRef an AttributeReference (or a column)
+   * @param literal a literal value (or constant)
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return a double value to show the percentage of rows meeting a given condition
+   */
+  def evaluateBinary(
+      op: BinaryComparison,
+      attrRef: AttributeReference,
+      literal: Literal,
+      update: Boolean)
+    : Double = {
+    if (!mutableColStats.contains(attrRef.exprId)) {
+      logDebug("[CBO] No statistics for " + attrRef)
+      return 1.0
+    }
+
+    // Make sure that the Date/Timestamp literal is a valid one
+    attrRef.dataType match {
+      case DateType =>
+        val dateLiteral = DateTimeUtils.stringToDate(literal.value.asInstanceOf[UTF8String])
+        if (dateLiteral.isEmpty) {
+          logDebug("[CBO] Date literal is invalid, no statistics for " + attrRef)
+          return 1.0
+        }
+      case TimestampType =>
+        val tsLiteral = DateTimeUtils.stringToTimestamp(literal.value.asInstanceOf[UTF8String])
+        if (tsLiteral.isEmpty) {
+          logDebug("[CBO] Timestamp literal is invalid, no statistics for " + attrRef)
+          return 1.0
+        }
+      case _ =>
+    }
+
+    op match {
+      case EqualTo(l, r) => evaluateEqualTo(attrRef, literal, update)
+      case _ =>
+        attrRef.dataType match {
+          case _: NumericType | DateType | TimestampType =>
+            evaluateBinaryForNumeric(op, attrRef, literal, update)
+          case StringType | BinaryType =>
+            // TODO: It is difficult to support other binary comparisons for String/Binary
+            // type without min/max and advanced statistics like histogram.
+            logDebug("[CBO] No statistics for String/Binary type " + attrRef)
+            return 1.0
+        }
+    }
+  }
+
+  /**
+   * This method converts a numeric or Literal value of numeric type to a BigDecimal value.
+   * If isNumeric is true, then it is a numeric value. Otherwise, it is a Literal value.
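+   *
+   * For example (hypothetical values): a LongType literal with value 5 simply becomes
+   * BigDecimal(5), while a date or timestamp string is first parsed into its internal
+   * numeric representation before the conversion.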
+   */
+  def numericLiteralToBigDecimal(
+      literal: Any,
+      dataType: DataType,
+      isNumeric: Boolean = false)
+    : BigDecimal = {
+    dataType match {
+      case _: IntegralType =>
+        if (isNumeric) BigDecimal(literal.asInstanceOf[Long])
+        else BigDecimal(literal.asInstanceOf[Literal].value.asInstanceOf[Long])
+      case _: FractionalType =>
+        if (isNumeric) BigDecimal(literal.asInstanceOf[Double])
+        else BigDecimal(literal.asInstanceOf[Literal].value.asInstanceOf[Double])
+      case DateType =>
+        if (isNumeric) BigDecimal(literal.asInstanceOf[BigInt])
+        else {
+          val dateLiteral = DateTimeUtils.stringToDate(
+            literal.asInstanceOf[Literal].value.asInstanceOf[UTF8String])
+          BigDecimal(dateLiteral.asInstanceOf[BigInt])
+        }
+      case TimestampType =>
+        if (isNumeric) BigDecimal(literal.asInstanceOf[BigInt])
+        else {
+          val tsLiteral = DateTimeUtils.stringToTimestamp(
+            literal.asInstanceOf[Literal].value.asInstanceOf[UTF8String])
+          BigDecimal(tsLiteral.asInstanceOf[BigInt])
+        }
+    }
+  }
+
+  /**
+   * Returns a percentage of rows meeting an equality (=) expression.
+   * This method evaluates the equality predicate for all data types.
+   *
+   * @param attrRef an AttributeReference (or a column)
+   * @param literal a literal value (or constant)
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return a double value to show the percentage of rows meeting a given condition
+   */
+  def evaluateEqualTo(
+      attrRef: AttributeReference,
+      literal: Literal,
+      update: Boolean)
+    : Double = {
+
+    val aColStat = mutableColStats(attrRef.exprId)
+    val ndv = aColStat.distinctCount
+
+    // decide if the value is in [min, max] of the column.
+    // We currently don't store min/max for binary/string type.
+    // Hence, we assume it is in boundary for binary/string type.
+    val inBoundary: Boolean = attrRef.dataType match {
+      case _: NumericType | DateType | TimestampType =>
+        val statsRange =
+          Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange]
+        val lit = numericLiteralToBigDecimal(literal, attrRef.dataType)
+        (lit >= statsRange.min) && (lit <= statsRange.max)
+
+      case _ => true  // for String/Binary type
+    }
+
+    val percent: Double =
+      if (inBoundary) {
+
+        if (update) {
+          // We update the ColumnStat structure after applying this equality predicate.
+          // Set distinctCount to 1. Set nullCount to 0.
+          val newStats = attrRef.dataType match {
+            case _: NumericType | DateType | TimestampType =>
+              val newValue = Some(literal.value)
+              aColStat.copy(distinctCount = 1, min = newValue,
+                max = newValue, nullCount = 0)
+            case _ => aColStat.copy(distinctCount = 1, nullCount = 0)
+          }
+          mutableColStats += (attrRef.exprId -> newStats)
+        }
+
+        1.0 / ndv.toDouble
+      } else {
+        0.0
+      }
+
+    percent
+  }
+
+  /**
+   * Returns a percentage of rows meeting an "IN" operator expression.
+   * This method evaluates the equality predicate for all data types.
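+   *
+   * For example (hypothetical numbers): for "c IN (3, 4, 5, 99)" on a column with
+   * [min, max] = [1, 10] and ndv = 10, the literal 99 falls outside the column's range
+   * and is dropped, so the selectivity is min(1.0, 3 / 10.0) = 0.3.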
+   *
+   * @param attrRef an AttributeReference (or a column)
+   * @param hSet a set of literal values
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return a double value to show the percentage of rows meeting a given condition
+   */
+  def evaluateInSet(
+      attrRef: AttributeReference,
+      hSet: Set[Any],
+      update: Boolean)
+    : Double = {
+    if (!mutableColStats.contains(attrRef.exprId)) {
+      logDebug("[CBO] No statistics for " + attrRef)
+      return 1.0
+    }
+
+    val aColStat = mutableColStats(attrRef.exprId)
+    val ndv = aColStat.distinctCount
+    val aType = attrRef.dataType
+
+    // use [min, max] to filter the original hSet
+    val validQuerySet = aType match {
+      case _: NumericType | DateType | TimestampType =>
+        val statsRange =
+          Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange]
+        hSet.map(e => numericLiteralToBigDecimal(e, aType, true)).
+          filter(e => e >= statsRange.min && e <= statsRange.max)
+
+      // We keep the whole set, since there is no min/max information for String/Binary type
+      case StringType | BinaryType => hSet
+    }
+    if (validQuerySet.isEmpty) {
+      return 0.0
+    }
+
+    val newNdv = validQuerySet.size
+    val (newMax, newMin) = aType match {
+      case _: NumericType | DateType | TimestampType =>
+        val tmpSet: Set[Double] = validQuerySet.map(e => e.toString.toDouble)
+        (Some(tmpSet.max), Some(tmpSet.min))
+      case _ =>
+        (None, None)
+    }
+
+    if (update) {
+      val newStats = attrRef.dataType match {
+        case _: NumericType | DateType | TimestampType =>
+          aColStat.copy(distinctCount = newNdv, min = newMin,
+            max = newMax, nullCount = 0)
+        case StringType | BinaryType =>
+          aColStat.copy(distinctCount = newNdv, nullCount = 0)
+      }
+      mutableColStats += (attrRef.exprId -> newStats)
+    }
+
+    // return the filter selectivity. Without advanced statistics such as histograms,
+    // we have to assume uniform distribution.
+    math.min(1.0, validQuerySet.size / ndv.toDouble)
+  }
+
+  /**
+   * Returns a percentage of rows meeting a binary comparison expression.
+   * This method evaluates expressions for numeric columns only.
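+   *
+   * For example (hypothetical numbers): for "c < 4" on a column with
+   * [min, max] = [1, 10], assuming a uniform data distribution, the estimated
+   * selectivity is (4 - 1) / (10 - 1), i.e. about 0.33.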
+   *
+   * @param op a binary comparison operator such as =, <, <=, >, >=
+   * @param attrRef an AttributeReference (or a column)
+   * @param literal a literal value (or constant)
+   * @param update a boolean flag to specify if we need to update ColumnStat of a given column
+   *               for subsequent conditions
+   * @return a double value to show the percentage of rows meeting a given condition
+   */
+  def evaluateBinaryForNumeric(
+      op: BinaryComparison,
+      attrRef: AttributeReference,
+      literal: Literal,
+      update: Boolean)
+    : Double = {
+
+    var percent = 1.0
+    val aColStat = mutableColStats(attrRef.exprId)
+    val ndv = aColStat.distinctCount
+    val statsRange =
+      Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange]
+
+    val literalValueBD = numericLiteralToBigDecimal(literal, attrRef.dataType)
+
+    // determine the overlapping degree between predicate range and column's range
+    val (noOverlap: Boolean, completeOverlap: Boolean) = op match {
+      case LessThan(l, r) =>
+        (literalValueBD <= statsRange.min, literalValueBD > statsRange.max)
+      case LessThanOrEqual(l, r) =>
+        (literalValueBD < statsRange.min, literalValueBD >= statsRange.max)
+      case GreaterThan(l, r) =>
+        (literalValueBD >= statsRange.max, literalValueBD < statsRange.min)
+      case GreaterThanOrEqual(l, r) =>
+        (literalValueBD > statsRange.max, literalValueBD <= statsRange.min)
+    }
+
+    if (noOverlap) {
+      percent = 0.0
+    } else if (completeOverlap) {
+      percent = 1.0
+    } else {
+      // this is the partial overlap case
+      var newMax = aColStat.max
+      var newMin = aColStat.min
+      var newNdv = ndv
+      val literalToDouble = literalValueBD.toDouble
+      val maxToDouble = BigDecimal(statsRange.max).toDouble
+      val minToDouble = BigDecimal(statsRange.min).toDouble
+
+      // Without advanced statistics like histogram, we assume uniform data distribution.
+      // We just prorate the adjusted range over the initial range to compute filter selectivity.
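+      // For example (hypothetical numbers): with [min, max] = [1, 10] and "c <= 7",
+      // the prorated selectivity is (7 - 1) / (10 - 1), i.e. about 0.67; when update
+      // is true, max is then narrowed to 7 and ndv is scaled down proportionally.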
+ percent = op match { + case LessThan(l, r) => + (literalToDouble - minToDouble) / (maxToDouble - minToDouble) + case LessThanOrEqual(l, r) => + if (literalValueBD == BigDecimal(statsRange.min)) 1.0 / ndv.toDouble + else (literalToDouble - minToDouble) / (maxToDouble - minToDouble) + case GreaterThan(l, r) => + (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) + case GreaterThanOrEqual(l, r) => + if (literalValueBD == BigDecimal(statsRange.max)) 1.0 / ndv.toDouble + else (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) + } + + if (update) { + op match { + case GreaterThan(l, r) => newMin = Some(literal.value) + case GreaterThanOrEqual(l, r) => newMin = Some(literal.value) + case LessThan(l, r) => newMax = Some(literal.value) + case LessThanOrEqual(l, r) => newMax = Some(literal.value) + } + newNdv = math.max(math.round(ndv.toDouble * percent), 1) + val newStats = aColStat.copy(distinctCount = newNdv, min = newMin, + max = newMax, nullCount = 0) + + mutableColStats += (attrRef.exprId -> newStats) + } + } + + percent + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala index 5aa6b9353bc4..4a346c924db6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala @@ -113,4 +113,5 @@ object Range { DateTimeUtils.toJavaTimestamp(n.max.longValue())) } } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala new file mode 100644 index 000000000000..a82f61a6ae9b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.estimation + +import org.apache.spark.sql.catalyst.expressions.AttributeMap +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.estimation.EstimationUtils._ +import org.apache.spark.sql.test.SharedSQLContext + +/** + * In this test suite, we test the proedicates containing the following operators: + * =, <, <=, >, >=, AND, OR, IS NULL, IS NOT NULL, IN, NOT IN + */ + +class FilterEstimationSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + private val data1 = Seq[Long](1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + private val table1 = "filter_estimation_test1" + + test("filter estimation with equality comparison") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 = 2" + val colStats = Seq("key1" -> ColumnStat(1, Some(2L), Some(2L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(1L)) + } + } + + test("filter estimation with less than comparison") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 < 3" + val colStats = Seq("key1" -> ColumnStat(2, Some(1L), Some(3L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(3L)) + } + } + + test("filter estimation with less than or equal to comparison") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 <= 3" + val colStats = Seq("key1" -> ColumnStat(2, Some(1L), Some(3L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(3L)) + } + } + + test("filter estimation with greater than comparison") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 > 6" + val colStats = Seq("key1" -> ColumnStat(4, Some(6L), Some(10L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(5L)) + } + } + + test("filter estimation with greater than or equal to comparison") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 >= 6" + val colStats = Seq("key1" -> ColumnStat(4, Some(6L), Some(10L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(5L)) + } + } + + test("filter estimation with IS NULL comparison") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 IS NULL" + val colStats = Seq("key1" -> ColumnStat(0, None, None, 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(0L)) + } + } + + test("filter estimation 
with IS NOT NULL comparison") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 IS NOT NULL" + val colStats = Seq("key1" -> ColumnStat(10, Some(1L), Some(10L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(10L)) + } + } + + test("filter estimation with logical AND operator") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 > 3 AND key1 <= 6" + val colStats = Seq("key1" -> ColumnStat(3, Some(3L), Some(6L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(4L)) + } + } + + test("filter estimation with logical OR operator") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 = 3 OR key1 = 6" + val colStats = Seq("key1" -> ColumnStat(10, Some(1L), Some(10L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(2L)) + } + } + + test("filter estimation with IN operator") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 IN (3, 4, 5)" + val colStats = Seq("key1" -> ColumnStat(3, Some(3L), Some(5L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(3L)) + } + } + + test("filter estimation with logical NOT operator") { + val df1 = data1.toDF("key1") + withTable(table1) { + df1.write.saveAsTable(table1) + + /** Collect statistics */ + sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") + + /** Validate statistics */ + val sqlStmt = s"SELECT * FROM $table1 WHERE key1 NOT IN (3, 4, 5)" + val colStats = Seq("key1" -> ColumnStat(10, Some(1L), Some(10L), 0, 8, 8)) + validateEstimatedStats(sqlStmt, colStats, Some(7L)) + } + } + + private def validateEstimatedStats( + sqlStmt: String, + expectedColStats: Seq[(String, ColumnStat)], + rowCount: Option[Long] = None) + : Unit = { + val logicalPlan = sql(sqlStmt).queryExecution.optimizedPlan + val operNode = logicalPlan.collect { + case oper: Filter => + oper + }.head + val expectedRowCount = rowCount.getOrElse(sql(sqlStmt).collect().head.getLong(0)) + val nameToAttr = operNode.output.map(a => (a.name, a)).toMap + val expectedAttrStats = + AttributeMap(expectedColStats.map(kv => nameToAttr(kv._1) -> kv._2)) + val expectedStats = Statistics( + sizeInBytes = expectedRowCount * getRowSize(operNode.output, expectedAttrStats), + rowCount = Some(expectedRowCount), + attributeStats = expectedAttrStats, + isBroadcastable = false) + + val filterStats = operNode.statistics + assert(filterStats == expectedStats) + } + +} From 7caf600fface838fc5b14b2c239a2c17f23b60fd Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Tue, 10 Jan 2017 13:52:31 -0800 Subject: [PATCH 11/36] reduce the dependency of unit test case FilterEstimationSuite --- .../statsEstimation/EstimationUtils.scala | 15 +- .../statsEstimation/FilterEstimation.scala | 17 +- .../FilterEstimationSuite.scala | 223 
+++--------------- 3 files changed, 49 insertions(+), 206 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 4d18b28be866..0db199798098 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import scala.math.BigDecimal.RoundingMode import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} import org.apache.spark.sql.types.{DataType, StringType} @@ -51,6 +51,8 @@ object EstimationUtils { AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _))) } + def ceil(bigDecimal: BigDecimal): BigInt = bigDecimal.setScale(0, RoundingMode.CEILING).toBigInt() + def getOutputSize( attributes: Seq[Attribute], outputRowCount: BigInt, @@ -76,3 +78,12 @@ object EstimationUtils { if (outputRowCount > 0) outputRowCount * sizePerRow else 1 } } + +/** Attribute Reference extractor */ +object ExtractAttr { + def unapply(exp: Expression): Option[AttributeReference] = exp match { + case ar: AttributeReference => Some(ar) + case _ => None + } +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index c58cd84dc5b5..baae4bee4345 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.plans.logical.estimation +package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import scala.collection.immutable.{HashSet, Map} import scala.collection.mutable @@ -59,16 +59,17 @@ class FilterEstimation extends Logging { stats.attributeStats.map(kv => (kv._1.exprId, kv._2)) mutableColStats = mutable.Map.empty ++= statsExprIdMap - // save a copy of ExprId-to-Attribute map for later conversion use - val expridToAttrMap: Map[ExprId, Attribute] = - stats.attributeStats.map(kv => (kv._1.exprId, kv._1)) - - // estimate selectivity for this filter predicate + // estimate selectivity of this filter predicate val percent: Double = calculateConditions(plan, plan.condition) - // copy mutableColStats contents to an immutable AttributeMap + // attributeStats has mapping Attribute-to-ColumnStat. + // mutableColStats has mapping ExprId-to-ColumnStat. + // We use an ExprId-to-Attribute map to facilitate the mapping Attribute-to-ColumnStat + val expridToAttrMap: Map[ExprId, Attribute] = + stats.attributeStats.map(kv => (kv._1.exprId, kv._1)) + // copy mutableColStats contents to an immutable AttributeMap. 
val mutableAttributeStats: mutable.Map[Attribute, ColumnStat] = - mutableColStats.map(kv => (expridToAttrMap(kv._1) -> kv._2)) + mutableColStats.map(kv => expridToAttrMap(kv._1) -> kv._2) val newColStats = AttributeMap(mutableAttributeStats.toSeq) val filteredRowCountValue: BigInt = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index a82f61a6ae9b..086e53711160 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -1,3 +1,4 @@ + /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with @@ -15,212 +16,42 @@ * limitations under the License. */ -package org.apache.spark.sql.estimation +package org.apache.spark.sql.catalyst.statsEstimation -import org.apache.spark.sql.catalyst.expressions.AttributeMap -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.logical.estimation.EstimationUtils._ -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ +import org.apache.spark.sql.types.IntegerType /** * In this test suite, we test the proedicates containing the following operators: * =, <, <=, >, >=, AND, OR, IS NULL, IS NOT NULL, IN, NOT IN */ -class FilterEstimationSuite extends QueryTest with SharedSQLContext { - import testImplicits._ - - private val data1 = Seq[Long](1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - private val table1 = "filter_estimation_test1" +class FilterEstimationSuite extends StatsEstimationTestBase { test("filter estimation with equality comparison") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 = 2" - val colStats = Seq("key1" -> ColumnStat(1, Some(2L), Some(2L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(1L)) - } - } - - test("filter estimation with less than comparison") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 < 3" - val colStats = Seq("key1" -> ColumnStat(2, Some(1L), Some(3L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(3L)) - } - } - - test("filter estimation with less than or equal to comparison") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 <= 3" - val colStats = Seq("key1" -> ColumnStat(2, Some(1L), Some(3L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(3L)) - } - } - - test("filter estimation with greater than comparison") { - val df1 = data1.toDF("key1") - withTable(table1) { - 
df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 > 6" - val colStats = Seq("key1" -> ColumnStat(4, Some(6L), Some(10L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(5L)) - } - } - - test("filter estimation with greater than or equal to comparison") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 >= 6" - val colStats = Seq("key1" -> ColumnStat(4, Some(6L), Some(10L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(5L)) - } - } - - test("filter estimation with IS NULL comparison") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 IS NULL" - val colStats = Seq("key1" -> ColumnStat(0, None, None, 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(0L)) - } - } - - test("filter estimation with IS NOT NULL comparison") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 IS NOT NULL" - val colStats = Seq("key1" -> ColumnStat(10, Some(1L), Some(10L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(10L)) - } - } - - test("filter estimation with logical AND operator") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 > 3 AND key1 <= 6" - val colStats = Seq("key1" -> ColumnStat(3, Some(3L), Some(6L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(4L)) - } - } - - test("filter estimation with logical OR operator") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 = 3 OR key1 = 6" - val colStats = Seq("key1" -> ColumnStat(10, Some(1L), Some(10L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(2L)) - } - } - - test("filter estimation with IN operator") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 IN (3, 4, 5)" - val colStats = Seq("key1" -> ColumnStat(3, Some(3L), Some(5L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(3L)) - } - } - - test("filter estimation with logical NOT operator") { - val df1 = data1.toDF("key1") - withTable(table1) { - df1.write.saveAsTable(table1) - - /** Collect statistics */ - sql(s"analyze table $table1 compute STATISTICS FOR COLUMNS key1") - - /** Validate statistics */ - val sqlStmt = s"SELECT * FROM $table1 WHERE key1 NOT IN (3, 4, 5)" - val colStats = 
Seq("key1" -> ColumnStat(10, Some(1L), Some(10L), 0, 8, 8)) - validateEstimatedStats(sqlStmt, colStats, Some(7L)) - } - } - - private def validateEstimatedStats( - sqlStmt: String, - expectedColStats: Seq[(String, ColumnStat)], - rowCount: Option[Long] = None) - : Unit = { - val logicalPlan = sql(sqlStmt).queryExecution.optimizedPlan - val operNode = logicalPlan.collect { - case oper: Filter => - oper - }.head - val expectedRowCount = rowCount.getOrElse(sql(sqlStmt).collect().head.getLong(0)) - val nameToAttr = operNode.output.map(a => (a.name, a)).toMap - val expectedAttrStats = - AttributeMap(expectedColStats.map(kv => nameToAttr(kv._1) -> kv._2)) + val ar = AttributeReference("key1", IntegerType)() + val colStat = ColumnStat(2, Some(1), Some(2), 0, 4, 4) + + val child = StatsTestPlan( + outputList = Seq(ar), + stats = Statistics( + sizeInBytes = 10 * 4, + rowCount = Some(10), + attributeStats = AttributeMap(Seq(ar -> colStat)) + ) + ) + + val filterNode = Filter(condition: Expression, child) + val expectedColStats = Seq("key1" -> colStat) + val expectedAttrStats = toAttributeMap(expectedColStats, filterNode) + // The number of rows won't change for project. val expectedStats = Statistics( - sizeInBytes = expectedRowCount * getRowSize(operNode.output, expectedAttrStats), - rowCount = Some(expectedRowCount), - attributeStats = expectedAttrStats, - isBroadcastable = false) - - val filterStats = operNode.statistics - assert(filterStats == expectedStats) + sizeInBytes = 2 * getRowSize(filterNode.output, expectedAttrStats), + rowCount = Some(2), + attributeStats = expectedAttrStats) + assert(filterNode.statistics == expectedStats) } - } + From 6fe19945c3f40534a16ee24b3818df019970ac25 Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Tue, 10 Jan 2017 20:35:52 -0800 Subject: [PATCH 12/36] make FilterEstimationSuite modular --- .../statsEstimation/FilterEstimation.scala | 22 ++- .../FilterEstimationSuite.scala | 156 +++++++++++++++--- 2 files changed, 150 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index baae4bee4345..7d187e19d16a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -140,6 +140,13 @@ class FilterEstimation extends Logging { val percent: Double = condition match { // For evaluateBinary method, we assume the literal on the right side of an operator. // So we will change the order if not. 
+ + // EqualTo does not care about the order + case op@EqualTo(ExtractAttr(ar), l: Literal) => + evaluateBinary(op, ar, l, update) + case op@EqualTo(l: Literal, ExtractAttr(ar)) => + evaluateBinary(op, ar, l, update) + case op@LessThan(ExtractAttr(ar), l: Literal) => evaluateBinary(op, ar, l, update) case op@LessThan(l: Literal, ExtractAttr(ar)) => @@ -160,12 +167,6 @@ class FilterEstimation extends Logging { case op@GreaterThanOrEqual(l: Literal, ExtractAttr(ar)) => evaluateBinary(LessThanOrEqual(ar, l), ar, l, update) - // EqualTo does not care about the order - case op@EqualTo(ExtractAttr(ar), l: Literal) => - evaluateBinary(op, ar, l, update) - case op@EqualTo(l: Literal, ExtractAttr(ar)) => - evaluateBinary(op, ar, l, update) - case In(ExtractAttr(ar), expList) if !expList.exists(!_.isInstanceOf[Literal]) => // Expression [In (value, seq[Literal])] will be replaced with optimized version // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10. @@ -317,11 +318,15 @@ class FilterEstimation extends Logging { : BigDecimal = { dataType match { case _: IntegralType => - if (isNumeric) BigDecimal(literal.asInstanceOf[Long]) - else BigDecimal(literal.asInstanceOf[Literal].value.asInstanceOf[Long]) + val stringValue: String = + if (isNumeric) literal.toString + else literal.asInstanceOf[Literal].value.toString + BigDecimal(java.lang.Long.valueOf(stringValue)) + case _: FractionalType => if (isNumeric) BigDecimal(literal.asInstanceOf[Double]) else BigDecimal(literal.asInstanceOf[Literal].value.asInstanceOf[Double]) + case DateType => if (isNumeric) BigDecimal(literal.asInstanceOf[BigInt]) else { @@ -329,6 +334,7 @@ class FilterEstimation extends Logging { literal.asInstanceOf[Literal].value.asInstanceOf[UTF8String]) BigDecimal(dateLiteral.asInstanceOf[BigInt]) } + case TimestampType => if (isNumeric) BigDecimal(literal.asInstanceOf[BigInt]) else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 086e53711160..e0d1f258286d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with @@ -18,7 +17,7 @@ package org.apache.spark.sql.catalyst.statsEstimation -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.types.IntegerType @@ -30,28 +29,145 @@ import org.apache.spark.sql.types.IntegerType class FilterEstimationSuite extends StatsEstimationTestBase { - test("filter estimation with equality comparison") { - val ar = AttributeReference("key1", IntegerType)() - val colStat = ColumnStat(2, Some(1), Some(2), 0, 4, 4) - - val child = StatsTestPlan( - outputList = Seq(ar), - stats = Statistics( - sizeInBytes = 10 * 4, - rowCount = Some(10), - attributeStats = AttributeMap(Seq(ar -> colStat)) - ) + // Suppose our test table has one column called "key1". 
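The java.lang.Long.valueOf(stringValue) round-trip introduced above looks indirect, but it avoids a real pitfall: an integer literal's value is a boxed java.lang.Integer, and casting that box straight to Long throws ClassCastException at runtime. A tiny sketch of the failure mode and the string detour (illustrative values only):

    val boxed: Any = 2                       // an integer literal's value is a boxed Integer
    // boxed.asInstanceOf[Long]              // would throw ClassCastException: Integer is not Long
    val safe = BigDecimal(java.lang.Long.valueOf(boxed.toString)) // string round-trip: 2
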
+ // It has 10 rows with values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 + // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 + val ar = AttributeReference("key1", IntegerType)() + val childColStat = ColumnStat(10, Some(1), Some(10), 0, 4, 4) + val child = StatsTestPlan( + outputList = Seq(ar), + stats = Statistics( + sizeInBytes = 10 * 4, + rowCount = Some(10), + attributeStats = AttributeMap(Seq(ar -> childColStat)) ) + ) + + test("filter estimation with equality comparison") { + // the predicate is "WHERE key1 = 2" + val intValue = Literal(2, IntegerType) + val condition = EqualTo(ar, intValue) + val filterNode = Filter(condition, child) + val filteredColStats = ColumnStat(1, Some(2), Some(2), 0, 4, 4) + + validateEstimatedStats(filterNode, filteredColStats, Some(1L)) + } + + test("filter estimation with less than comparison") { + // the predicate is "WHERE key1 < 3" + val intValue = Literal(3, IntegerType) + val condition = LessThan(ar, intValue) + val filterNode = Filter(condition, child) + val filteredColStats = ColumnStat(2, Some(1), Some(3), 0, 4, 4) + + validateEstimatedStats(filterNode, filteredColStats, Some(3L)) + } + + test("filter estimation with less than or equal to comparison") { + // the predicate is "WHERE key1 <= 3" + val intValue = Literal(3, IntegerType) + val condition = LessThanOrEqual(ar, intValue) + val filterNode = Filter(condition, child) + val filteredColStats = ColumnStat(2, Some(1), Some(3), 0, 4, 4) + + validateEstimatedStats(filterNode, filteredColStats, Some(3L)) + + } + + test("filter estimation with greater than comparison") { + // the predicate is "WHERE key1 > 6" + val intValue = Literal(6, IntegerType) + val condition = GreaterThan(ar, intValue) + val filterNode = Filter(condition, child) + val filteredColStats = ColumnStat(4, Some(6), Some(10), 0, 4, 4) + + validateEstimatedStats(filterNode, filteredColStats, Some(5L)) + } + + test("filter estimation with greater than or equal to comparison") { + // the predicate is "WHERE key1 >= 6" + val intValue = Literal(6, IntegerType) + val condition = GreaterThanOrEqual(ar, intValue) + val filterNode = Filter(condition, child) + val filteredColStats = ColumnStat(4, Some(6), Some(10), 0, 4, 4) + + validateEstimatedStats(filterNode, filteredColStats, Some(5L)) + + } + + test("filter estimation with IS NULL comparison") { + // the predicate is "WHERE key1 IS NULL" + val condition = IsNull(ar) + val filterNode = Filter(condition, child) + val filteredColStats = ColumnStat(0, None, None, 0, 4, 4) - val filterNode = Filter(condition: Expression, child) - val expectedColStats = Seq("key1" -> colStat) - val expectedAttrStats = toAttributeMap(expectedColStats, filterNode) - // The number of rows won't change for project. 
+ validateEstimatedStats(filterNode, filteredColStats, Some(0L)) + } + + test("filter estimation with IS NOT NULL comparison") { + // the predicate is "WHERE key1 IS NOT NULL" + val condition = IsNotNull(ar) + val filterNode = Filter(condition, child) + val filteredColStats = ColumnStat(10, Some(1), Some(10), 0, 4, 4) + + validateEstimatedStats(filterNode, filteredColStats, Some(10L)) + } + + test("filter estimation with logical AND operator") { + // the predicate is "WHERE key1 > 3 AND key1 <= 6" + val condition1 = GreaterThan(ar, Literal(3, IntegerType)) + val condition2 = LessThanOrEqual(ar, Literal(6, IntegerType)) + val condition = And(condition1, condition2) + val filterNode = Filter(condition, child) + val filteredColStats = ColumnStat(3, Some(3), Some(6), 0, 4, 4) + + validateEstimatedStats(filterNode, filteredColStats, Some(4L)) + } + + test("filter estimation with logical OR operator") { + // the predicate is "WHERE key1 = 3 OR key1 = 6" + val condition1 = EqualTo(ar, Literal(3, IntegerType)) + val condition2 = EqualTo(ar, Literal(6, IntegerType)) + val condition = Or(condition1, condition2) + val filterNode = Filter(condition, child) + val filteredColStats = ColumnStat(10, Some(1), Some(10), 0, 4, 4) + + validateEstimatedStats(filterNode, filteredColStats, Some(2L)) + } + + test("filter estimation with logical IN operator") { + // the predicate is "WHERE key1 IN (3, 4, 5)" + val condition = InSet(ar, Set(3, 4, 5)) + val filterNode = Filter(condition, child) + val filteredColStats = ColumnStat(3, Some(3), Some(5), 0, 4, 4) + + validateEstimatedStats(filterNode, filteredColStats, Some(3L)) + } + + test("filter estimation with logical NOT operator") { + // the predicate is "WHERE key1 NOT IN (3, 4, 5)" + val condition = InSet(ar, Set(3, 4, 5)) + val notCondition = Not(condition) + val filterNode = Filter(notCondition, child) + val filteredColStats = ColumnStat(10, Some(1), Some(10), 0, 4, 4) + + validateEstimatedStats(filterNode, filteredColStats, Some(7L)) + } + + private def validateEstimatedStats( + filterNode: Filter, + filteredColStats: ColumnStat, + rowCount: Option[Long] = None) + : Unit = { + + val expectedRowCount = rowCount.getOrElse(0L) + val expectedAttrStats = toAttributeMap(Seq("key1" -> filteredColStats), filterNode) val expectedStats = Statistics( - sizeInBytes = 2 * getRowSize(filterNode.output, expectedAttrStats), - rowCount = Some(2), + sizeInBytes = expectedRowCount * getRowSize(filterNode.output, expectedAttrStats), + rowCount = Some(expectedRowCount), attributeStats = expectedAttrStats) + assert(filterNode.statistics == expectedStats) } -} +} From 490f41a7c813cedb3f9f010ff50e77ec9334e60a Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Thu, 12 Jan 2017 11:12:53 -0800 Subject: [PATCH 13/36] use variable binding pattern matching --- .../statsEstimation/EstimationUtils.scala | 9 - .../statsEstimation/FilterEstimation.scala | 75 +++--- .../FilterEstimationSuite.scala | 218 ++++++++++-------- 3 files changed, 164 insertions(+), 138 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 0db199798098..f226944520d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -78,12 +78,3 @@ object 
EstimationUtils { if (outputRowCount > 0) outputRowCount * sizePerRow else 1 } } - -/** Attribute Reference extractor */ -object ExtractAttr { - def unapply(exp: Expression): Option[AttributeReference] = exp match { - case ar: AttributeReference => Some(ar) - case _ => None - } -} - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 7d187e19d16a..04e470094117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -84,7 +84,7 @@ class FilterEstimation extends Logging { /** * Returns a percentage of rows meeting a compound condition in Filter node. - * A compound condition is depomposed into multiple single conditions linked with AND, OR, NOT. + * A compound condition is decomposed into multiple single conditions linked with AND, OR, NOT. * For logical AND conditions, we need to update stats after a condition estimation * so that the stats will be more accurate for subsequent estimation. This is needed for * range condition such as (c > 40 AND c <= 50) @@ -125,7 +125,7 @@ class FilterEstimation extends Logging { * * @param plan the Filter LogicalPlan node * @param condition a single logical expression - * @param isNot set to true for "IS NULL" condition. set to false for "IS NOT NULL" condition + * @param isNot set to true for Not logical operator. Otherwise it is set to false. * @param update a boolean flag to specify if we need to update ColumnStat of a column * for subsequent conditions * @return a doube value to show the percentage of rows meeting a given condition @@ -142,54 +142,56 @@ class FilterEstimation extends Logging { // So we will change the order if not. 
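The In/InSet unification in the cases below is worth isolating: the optimizer only rewrites In to InSet when the value list exceeds 10 elements, so the estimator converts eagerly to keep a single code path. A sketch of that conversion under the same literal-only guard the patch uses (the helper name is illustrative, not the patch's API):

    import scala.collection.immutable.HashSet
    import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}

    // Pre-evaluate an IN list into a hash set; bail out if any element is not a literal.
    def toValueSet(expList: Seq[Expression]): Option[Set[Any]] =
      if (expList.forall(_.isInstanceOf[Literal])) Some(HashSet() ++ expList.map(_.eval()))
      else None
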
// EqualTo does not care about the order - case op@EqualTo(ExtractAttr(ar), l: Literal) => + case op @ EqualTo(ar: AttributeReference, l: Literal) => evaluateBinary(op, ar, l, update) - case op@EqualTo(l: Literal, ExtractAttr(ar)) => + case op @ EqualTo(l: Literal, ar: AttributeReference) => evaluateBinary(op, ar, l, update) - case op@LessThan(ExtractAttr(ar), l: Literal) => + case op @ LessThan(ar: AttributeReference, l: Literal) => evaluateBinary(op, ar, l, update) - case op@LessThan(l: Literal, ExtractAttr(ar)) => + case op @ LessThan(l: Literal, ar: AttributeReference) => evaluateBinary(GreaterThan(ar, l), ar, l, update) - case op@LessThanOrEqual(ExtractAttr(ar), l: Literal) => + case op @ LessThanOrEqual(ar: AttributeReference, l: Literal) => evaluateBinary(op, ar, l, update) - case op@LessThanOrEqual(l: Literal, ExtractAttr(ar)) => + case op @ LessThanOrEqual(l: Literal, ar: AttributeReference) => evaluateBinary(GreaterThanOrEqual(ar, l), ar, l, update) - case op@GreaterThan(ExtractAttr(ar), l: Literal) => + case op @ GreaterThan(ar: AttributeReference, l: Literal) => evaluateBinary(op, ar, l, update) - case op@GreaterThan(l: Literal, ExtractAttr(ar)) => + case op @ GreaterThan(l: Literal, ar: AttributeReference) => evaluateBinary(LessThan(ar, l), ar, l, update) - case op@GreaterThanOrEqual(ExtractAttr(ar), l: Literal) => + case op @ GreaterThanOrEqual(ar: AttributeReference, l: Literal) => evaluateBinary(op, ar, l, update) - case op@GreaterThanOrEqual(l: Literal, ExtractAttr(ar)) => + case op @ GreaterThanOrEqual(l: Literal, ar: AttributeReference) => evaluateBinary(LessThanOrEqual(ar, l), ar, l, update) - case In(ExtractAttr(ar), expList) if !expList.exists(!_.isInstanceOf[Literal]) => + case In(ar: AttributeReference, expList) if !expList.exists(!_.isInstanceOf[Literal]) => // Expression [In (value, seq[Literal])] will be replaced with optimized version // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10. // Here we convert In into InSet anyway, because they share the same processing logic. val hSet = expList.map(e => e.eval()) evaluateInSet(ar, HashSet() ++ hSet, update) - case InSet(ExtractAttr(ar), set) => + case InSet(ar: AttributeReference, set) => evaluateInSet(ar, set, update) // It's difficult to estimate IsNull after outer joins. Hence, // we support IsNull and IsNotNull only when the child is a leaf node (table). - case IsNull(ExtractAttr(ar)) => + case IsNull(ar: AttributeReference) => if (plan.child.isInstanceOf[LeafNode ]) { evaluateIsNull(plan, ar, true, update) + } else { + 1.0 } - else 1.0 - case IsNotNull(ExtractAttr(ar)) => + case IsNotNull(ar: AttributeReference) => if (plan.child.isInstanceOf[LeafNode ]) { evaluateIsNull(plan, ar, false, update) + } else { + 1.0 } - else 1.0 case _ => // TODO: it's difficult to support string operators without advanced statistics. @@ -364,7 +366,6 @@ class FilterEstimation extends Logging { val aColStat = mutableColStats(attrRef.exprId) val ndv = aColStat.distinctCount - // decide if the value is in [min, max] of the column. // We currently don't store min/max for binary/string type. // Hence, we assume it is in boundary for binary/string type. @@ -379,28 +380,26 @@ class FilterEstimation extends Logging { case _ => true /** for String/Binary type */ } - val percent: Double = - if (inBoundary) { - - if (update) { - // We update ColumnStat structure after apply this equality predicate. - // Set distinctCount to 1. Set nullCount to 0. 
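The update rule stated just above is the heart of equality estimation: under a uniform-distribution assumption, a predicate "col = literal" whose literal falls inside [min, max] selects 1/ndv of the rows, and the column's statistics collapse to a single value afterwards. A self-contained sketch of that arithmetic in plain Scala (field names mirror ColumnStat; passing the bounds explicitly is a simplification):

    final case class SimpleColStat(distinctCount: BigInt, min: Option[Any],
        max: Option[Any], nullCount: BigInt)

    // Selectivity of "col = v" plus the post-predicate stats, assuming uniformity.
    def equalitySelectivity(stat: SimpleColStat, v: Int, lo: Int, hi: Int)
        : (Double, SimpleColStat) =
      if (v >= lo && v <= hi) {
        (1.0 / stat.distinctCount.toDouble,
          stat.copy(distinctCount = 1, min = Some(v), max = Some(v), nullCount = 0))
      } else {
        (0.0, stat) // out of range: no rows qualify, stats left unchanged
      }
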
- val newStats = attrRef.dataType match { - case _: NumericType | DateType | TimestampType => - val newValue = Some(literal.value) - aColStat.copy(distinctCount = 1, min = newValue, - max = newValue, nullCount = 0) - case _ => aColStat.copy(distinctCount = 1, nullCount = 0) - } - mutableColStats += (attrRef.exprId -> newStats) - } + if (inBoundary) { - 1.0 / ndv.toDouble - } else { - 0.0 + if (update) { + // We update ColumnStat structure after apply this equality predicate. + // Set distinctCount to 1. Set nullCount to 0. + val newStats = attrRef.dataType match { + case _: NumericType | DateType | TimestampType => + val newValue = Some(literal.value) + aColStat.copy(distinctCount = 1, min = newValue, + max = newValue, nullCount = 0) + case _ => aColStat.copy(distinctCount = 1, nullCount = 0) + } + mutableColStats += (attrRef.exprId -> newStats) } - percent + 1.0 / ndv.toDouble + } else { + 0.0 + } + } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index e0d1f258286d..80cb28022e6f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -23,17 +23,18 @@ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUti import org.apache.spark.sql.types.IntegerType /** - * In this test suite, we test the proedicates containing the following operators: + * In this test suite, we test predicates containing the following operators: * =, <, <=, >, >=, AND, OR, IS NULL, IS NOT NULL, IN, NOT IN */ class FilterEstimationSuite extends StatsEstimationTestBase { - // Suppose our test table has one column called "key1". + // Suppose our test table has one column called "key". 
// It has 10 rows with values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 - val ar = AttributeReference("key1", IntegerType)() - val childColStat = ColumnStat(10, Some(1), Some(10), 0, 4, 4) + val ar = AttributeReference("key", IntegerType)() + val childColStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4) val child = StatsTestPlan( outputList = Seq(ar), stats = Statistics( @@ -43,125 +44,160 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) ) - test("filter estimation with equality comparison") { - // the predicate is "WHERE key1 = 2" - val intValue = Literal(2, IntegerType) - val condition = EqualTo(ar, intValue) - val filterNode = Filter(condition, child) - val filteredColStats = ColumnStat(1, Some(2), Some(2), 0, 4, 4) - - validateEstimatedStats(filterNode, filteredColStats, Some(1L)) + test("key = 2") { + // the predicate is "WHERE key = 2" + validateEstimatedStats( + Filter(EqualTo(ar, Literal(2)), child), + ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(1L) + ) } - test("filter estimation with less than comparison") { - // the predicate is "WHERE key1 < 3" - val intValue = Literal(3, IntegerType) - val condition = LessThan(ar, intValue) - val filterNode = Filter(condition, child) - val filteredColStats = ColumnStat(2, Some(1), Some(3), 0, 4, 4) - - validateEstimatedStats(filterNode, filteredColStats, Some(3L)) + test("key = 0") { + // the predicate is "WHERE key = 0" + // This is an out-of-range case since 0 is outside the range [min, max] + validateEstimatedStats( + Filter(EqualTo(ar, Literal(0)), child), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(0L) + ) } - test("filter estimation with less than or equal to comparison") { - // the predicate is "WHERE key1 <= 3" - val intValue = Literal(3, IntegerType) - val condition = LessThanOrEqual(ar, intValue) - val filterNode = Filter(condition, child) - val filteredColStats = ColumnStat(2, Some(1), Some(3), 0, 4, 4) - - validateEstimatedStats(filterNode, filteredColStats, Some(3L)) - + test("key < 3") { + // the predicate is "WHERE key < 3" + validateEstimatedStats( + Filter(LessThan(ar, Literal(3)), child), + ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(3L) + ) } - test("filter estimation with greater than comparison") { - // the predicate is "WHERE key1 > 6" - val intValue = Literal(6, IntegerType) - val condition = GreaterThan(ar, intValue) - val filterNode = Filter(condition, child) - val filteredColStats = ColumnStat(4, Some(6), Some(10), 0, 4, 4) - - validateEstimatedStats(filterNode, filteredColStats, Some(5L)) + test("key < 0") { + // the predicate is "WHERE key < 0" + // This is a corner case since literal 0 is smaller than min. 
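The expected row counts in these range tests are consistent with a simple linear-interpolation model over [min, max], clamped at the boundaries, with the final count rounded up. A sketch that reproduces the numbers in the surrounding tests (the formula is a reconstruction from the expected values, not quoted from the patch):

    import scala.math.BigDecimal.RoundingMode

    // Fraction of rows with col < lit, interpolated linearly over [min, max].
    def lessThanSelectivity(min: Double, max: Double, lit: Double): Double =
      if (lit <= min) 0.0
      else if (lit > max) 1.0
      else (lit - min) / (max - min)

    def estimatedRows(rowCount: Long, sel: Double): BigInt =
      BigDecimal(rowCount * sel).setScale(0, RoundingMode.CEILING).toBigInt()

    // "key < 3" over 10 rows in [1, 10]: ceil(10 * 2/9) = 3, matching Some(3L).
    // "key < 0": literal below min, selectivity 0.0, matching Some(0L).
    // "key > 6" is symmetric: ceil(10 * 4/9) = 5, matching Some(5L).
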
+ validateEstimatedStats( + Filter(LessThan(ar, Literal(0)), child), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(0L) + ) } - test("filter estimation with greater than or equal to comparison") { - // the predicate is "WHERE key1 >= 6" - val intValue = Literal(6, IntegerType) - val condition = GreaterThanOrEqual(ar, intValue) - val filterNode = Filter(condition, child) - val filteredColStats = ColumnStat(4, Some(6), Some(10), 0, 4, 4) - - validateEstimatedStats(filterNode, filteredColStats, Some(5L)) + test("key <= 3") { + // the predicate is "WHERE key <= 3" + validateEstimatedStats( + Filter(LessThanOrEqual(ar, Literal(3)), child), + ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(3L) + ) } - test("filter estimation with IS NULL comparison") { - // the predicate is "WHERE key1 IS NULL" - val condition = IsNull(ar) - val filterNode = Filter(condition, child) - val filteredColStats = ColumnStat(0, None, None, 0, 4, 4) - - validateEstimatedStats(filterNode, filteredColStats, Some(0L)) + test("key > 6") { + // the predicate is "WHERE key > 6" + validateEstimatedStats( + Filter(GreaterThan(ar, Literal(6)), child), + ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(5L) + ) } - test("filter estimation with IS NOT NULL comparison") { - // the predicate is "WHERE key1 IS NOT NULL" - val condition = IsNotNull(ar) - val filterNode = Filter(condition, child) - val filteredColStats = ColumnStat(10, Some(1), Some(10), 0, 4, 4) - - validateEstimatedStats(filterNode, filteredColStats, Some(10L)) + test("key > 10") { + // the predicate is "WHERE key > 10" + // This is a corner case since max value is 10. 
+ validateEstimatedStats( + Filter(GreaterThan(ar, Literal(10)), child), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(0L) + ) } - test("filter estimation with logical AND operator") { - // the predicate is "WHERE key1 > 3 AND key1 <= 6" - val condition1 = GreaterThan(ar, Literal(3, IntegerType)) - val condition2 = LessThanOrEqual(ar, Literal(6, IntegerType)) - val condition = And(condition1, condition2) - val filterNode = Filter(condition, child) - val filteredColStats = ColumnStat(3, Some(3), Some(6), 0, 4, 4) - - validateEstimatedStats(filterNode, filteredColStats, Some(4L)) + test("key >= 6") { + // the predicate is "WHERE key >= 6" + validateEstimatedStats( + Filter(GreaterThanOrEqual(ar, Literal(6)), child), + ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(5L) + ) } - test("filter estimation with logical OR operator") { - // the predicate is "WHERE key1 = 3 OR key1 = 6" - val condition1 = EqualTo(ar, Literal(3, IntegerType)) - val condition2 = EqualTo(ar, Literal(6, IntegerType)) - val condition = Or(condition1, condition2) - val filterNode = Filter(condition, child) - val filteredColStats = ColumnStat(10, Some(1), Some(10), 0, 4, 4) + test("key IS NULL") { + // the predicate is "WHERE key IS NULL" + validateEstimatedStats( + Filter(IsNull(ar), child), + ColumnStat(distinctCount = 0, min = None, max = None, + nullCount = 0, avgLen = 4, maxLen = 4), + Some(0L) + ) + } - validateEstimatedStats(filterNode, filteredColStats, Some(2L)) + test("key IS NOT NULL") { + // the predicate is "WHERE key IS NOT NULL" + validateEstimatedStats( + Filter(IsNotNull(ar), child), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(10L) + ) } - test("filter estimation with logical IN operator") { - // the predicate is "WHERE key1 IN (3, 4, 5)" - val condition = InSet(ar, Set(3, 4, 5)) - val filterNode = Filter(condition, child) - val filteredColStats = ColumnStat(3, Some(3), Some(5), 0, 4, 4) + test("key > 3 AND key <= 6") { + // the predicate is "WHERE key > 3 AND key <= 6" + val condition = And(GreaterThan(ar, Literal(3)), LessThanOrEqual(ar, Literal(6))) + validateEstimatedStats( + Filter(condition, child), + ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(4L) + ) + } - validateEstimatedStats(filterNode, filteredColStats, Some(3L)) + test("key = 3 OR key = 6") { + // the predicate is "WHERE key = 3 OR key = 6" + val condition = Or(EqualTo(ar, Literal(3)), EqualTo(ar, Literal(6))) + validateEstimatedStats( + Filter(condition, child), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(2L) + ) } - test("filter estimation with logical NOT operator") { - // the predicate is "WHERE key1 NOT IN (3, 4, 5)" - val condition = InSet(ar, Set(3, 4, 5)) - val notCondition = Not(condition) - val filterNode = Filter(notCondition, child) - val filteredColStats = ColumnStat(10, Some(1), Some(10), 0, 4, 4) + test("key IN (3, 4, 5)") { + // the predicate is "WHERE key IN (3, 4, 5)" + validateEstimatedStats( + Filter(InSet(ar, Set(3, 4, 5)), child), + ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(3L) + ) + } - validateEstimatedStats(filterNode, filteredColStats, Some(7L)) + test("key NOT IN (3, 4, 5)") { + // the predicate is "WHERE key NOT IN (3, 4, 
5)" + validateEstimatedStats( + Filter(Not(InSet(ar, Set(3, 4, 5))), child), + ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(7L) + ) } private def validateEstimatedStats( filterNode: Filter, - filteredColStats: ColumnStat, + expectedColStats: ColumnStat, rowCount: Option[Long] = None) : Unit = { val expectedRowCount = rowCount.getOrElse(0L) - val expectedAttrStats = toAttributeMap(Seq("key1" -> filteredColStats), filterNode) + val expectedAttrStats = toAttributeMap(Seq("key" -> expectedColStats), filterNode) val expectedStats = Statistics( sizeInBytes = expectedRowCount * getRowSize(filterNode.output, expectedAttrStats), rowCount = Some(expectedRowCount), From 4bc300890535e3e4145fc91ab19cbbfa17e40aa0 Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Thu, 12 Jan 2017 16:54:03 -0800 Subject: [PATCH 14/36] change return type to Option[Double] --- .../statsEstimation/FilterEstimation.scala | 73 +++++++++---------- 1 file changed, 35 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 04e470094117..7cd99e60f17d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -60,7 +60,7 @@ class FilterEstimation extends Logging { mutableColStats = mutable.Map.empty ++= statsExprIdMap // estimate selectivity of this filter predicate - val percent: Double = calculateConditions(plan, plan.condition) + val filterSelectivity: Double = calculateConditions(plan, plan.condition) // attributeStats has mapping Attribute-to-ColumnStat. // mutableColStats has mapping ExprId-to-ColumnStat. @@ -73,7 +73,7 @@ class FilterEstimation extends Logging { val newColStats = AttributeMap(mutableAttributeStats.toSeq) val filteredRowCountValue: BigInt = - EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * percent) + EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * filterSelectivity) val avgRowSize = BigDecimal(EstimationUtils.getRowSize(plan.output, newColStats)) val filteredSizeInBytes: BigInt = EstimationUtils.ceil(BigDecimal(filteredRowCountValue) * avgRowSize) @@ -113,8 +113,15 @@ class FilterEstimation extends Logging { val p2 = calculateConditions(plan, cond2, update = false) math.min(1.0, p1 + p2 - (p1 * p2)) - case Not(cond) => calculateSingleCondition(plan, cond, isNot = true, update = false) - case _ => calculateSingleCondition(plan, condition, isNot = false, update) + case Not(cond) => calculateSingleCondition(plan, cond, update = false) match { + case Some(percent) => 1.0 - percent + case None => 1.0 + } + case _ => calculateSingleCondition(plan, condition, update) match { + case Some(percent) => percent + case None => 1.0 + // for not-supported condition, set filter selectivity to a conservative estimate 100% + } } } @@ -125,19 +132,17 @@ class FilterEstimation extends Logging { * * @param plan the Filter LogicalPlan node * @param condition a single logical expression - * @param isNot set to true for Not logical operator. Otherwise it is set to false. 
* @param update a boolean flag to specify if we need to update ColumnStat of a column * for subsequent conditions - * @return a doube value to show the percentage of rows meeting a given condition + * @return Option[Double] value to show the percentage of rows meeting a given condition. + * It returns None if the condition is not supported. */ def calculateSingleCondition( plan: Filter, condition: Expression, - isNot: Boolean, update: Boolean) - : Double = { - var notSupported: Boolean = false - val percent: Double = condition match { + : Option[Double] = { + condition match { // For evaluateBinary method, we assume the literal on the right side of an operator. // So we will change the order if not. @@ -183,14 +188,14 @@ class FilterEstimation extends Logging { if (plan.child.isInstanceOf[LeafNode ]) { evaluateIsNull(plan, ar, true, update) } else { - 1.0 + None } case IsNotNull(ar: AttributeReference) => if (plan.child.isInstanceOf[LeafNode ]) { evaluateIsNull(plan, ar, false, update) } else { - 1.0 + None } case _ => @@ -198,15 +203,7 @@ class FilterEstimation extends Logging { // Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _) // | EndsWith(_, _) are not supported yet logDebug("[CBO] Unsupported filter condition: " + condition) - notSupported = true - 1.0 - } - if (notSupported) { - 1.0 - } else if (isNot) { - 1.0 - percent - } else { - percent + None } } @@ -225,10 +222,10 @@ class FilterEstimation extends Logging { attrRef: AttributeReference, isNull: Boolean, update: Boolean) - : Double = { + : Option[Double] = { if (!mutableColStats.contains(attrRef.exprId)) { logDebug("[CBO] No statistics for " + attrRef) - return 1.0 + return None } val aColStat = mutableColStats(attrRef.exprId) val rowCountValue = plan.child.statistics.rowCount.get @@ -251,7 +248,7 @@ class FilterEstimation extends Logging { 1.0 - nullPercent.toDouble } - percent + Some(percent) } /** @@ -269,10 +266,10 @@ class FilterEstimation extends Logging { attrRef: AttributeReference, literal: Literal, update: Boolean) - : Double = { + : Option[Double] = { if (!mutableColStats.contains(attrRef.exprId)) { logDebug("[CBO] No statistics for " + attrRef) - return 1.0 + return None } // Make sure that the Date/Timestamp literal is a valid one @@ -281,13 +278,13 @@ class FilterEstimation extends Logging { val dateLiteral = DateTimeUtils.stringToDate(literal.value.asInstanceOf[UTF8String]) if (dateLiteral.isEmpty) { logDebug("[CBO] Date literal is wrong, No statistics for " + attrRef) - return 1.0 + return None } case TimestampType => val tsLiteral = DateTimeUtils.stringToTimestamp(literal.value.asInstanceOf[UTF8String]) if (tsLiteral.isEmpty) { logDebug("[CBO] Timestamp literal is wrong, No statistics for " + attrRef) - return 1.0 + return None } case _ => } @@ -304,7 +301,7 @@ class FilterEstimation extends Logging { // type without min/max and advanced statistics like histogram. 
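The evaluateIsNull path above reduces to a null-fraction computation: IS NULL selects nullCount/rowCount of the rows and IS NOT NULL selects the complement, which is why the zero-null test column yields 0 and 10 rows respectively. A minimal sketch of that rule:

    // Selectivity of IS [NOT] NULL from the column's null fraction.
    def isNullSelectivity(nullCount: BigInt, rowCount: BigInt, isNull: Boolean): Double = {
      val nullFraction =
        if (rowCount == 0) 0.0
        else (BigDecimal(nullCount) / BigDecimal(rowCount)).toDouble
      if (isNull) nullFraction else 1.0 - nullFraction
    }

    // With nullCount = 0 over 10 rows: IS NULL -> 0.0, IS NOT NULL -> 1.0.
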
logDebug("[CBO] No statistics for String/Binary type " + attrRef) - return 1.0 + None } } } @@ -361,7 +358,7 @@ class FilterEstimation extends Logging { attrRef: AttributeReference, literal: Literal, update: Boolean) - : Double = { + : Option[Double] = { val aColStat = mutableColStats(attrRef.exprId) val ndv = aColStat.distinctCount @@ -395,9 +392,9 @@ class FilterEstimation extends Logging { mutableColStats += (attrRef.exprId -> newStats) } - 1.0 / ndv.toDouble + Some(1.0 / ndv.toDouble) } else { - 0.0 + Some(0.0) } } @@ -417,10 +414,10 @@ class FilterEstimation extends Logging { attrRef: AttributeReference, hSet: Set[Any], update: Boolean) - : Double = { + : Option[Double] = { if (!mutableColStats.contains(attrRef.exprId)) { logDebug("[CBO] No statistics for " + attrRef) - return 1.0 + return None } val aColStat = mutableColStats(attrRef.exprId) @@ -439,7 +436,7 @@ class FilterEstimation extends Logging { case StringType | BinaryType => hSet } if (validQuerySet.isEmpty) { - return 0.0 + return Some(0.0) } val newNdv = validQuerySet.size @@ -464,7 +461,7 @@ class FilterEstimation extends Logging { // return the filter selectivity. Without advanced statistics such as histograms, // we have to assume uniform distribution. - math.min(1.0, validQuerySet.size / ndv.toDouble) + Some(math.min(1.0, validQuerySet.size / ndv.toDouble)) } /** @@ -483,7 +480,7 @@ class FilterEstimation extends Logging { attrRef: AttributeReference, literal: Literal, update: Boolean) - : Double = { + : Option[Double] = { var percent = 1.0 val aColStat = mutableColStats(attrRef.exprId) @@ -548,7 +545,7 @@ class FilterEstimation extends Logging { } } - percent + Some(percent) } } From 7af19a6aa24fcf30b7ab7740a9ac2cad22a6caeb Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Thu, 12 Jan 2017 20:30:13 -0800 Subject: [PATCH 15/36] use the unified computeStats method --- .../plans/logical/basicLogicalOperators.scala | 9 ++++++--- .../logical/statsEstimation/FilterEstimation.scala | 13 +++++++------ .../statsEstimation/FilterEstimationSuite.scala | 12 ++++-------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index b1c9acf4db43..d27e9ded451f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -130,9 +130,12 @@ case class Filter(condition: Expression, child: LogicalPlan) child.constraints.union(predicates.toSet) } - override lazy val statistics: Statistics = { - val filterEstimation = new FilterEstimation - filterEstimation.estimate(this).getOrElse(super.statistics) + override def computeStats(conf: CatalystConf): Statistics = { + if (conf.cboEnabled) { + FilterEstimation(conf).estimate(this).getOrElse(super.computeStats(conf)) + } else { + super.computeStats(conf) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 7cd99e60f17d..fcdff3780e9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -21,6 +21,7 @@ import scala.collection.immutable.{HashSet, Map} import scala.collection.mutable import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -28,7 +29,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class FilterEstimation extends Logging { +case class FilterEstimation(catalystConf: CatalystConf) extends Logging { /** * We use a mutable colStats because we need to update the corresponding ColumnStat @@ -51,7 +52,7 @@ class FilterEstimation extends Logging { * @return Option[Statistics] When there is no statistics collected, it returns None. */ def estimate(plan: Filter): Option[Statistics] = { - val stats: Statistics = plan.child.statistics + val stats: Statistics = plan.child.stats(catalystConf) if (stats.rowCount.isEmpty) return None // save a mutable copy of colStats so that we can later change it recursively @@ -74,9 +75,9 @@ class FilterEstimation extends Logging { val filteredRowCountValue: BigInt = EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * filterSelectivity) - val avgRowSize = BigDecimal(EstimationUtils.getRowSize(plan.output, newColStats)) - val filteredSizeInBytes: BigInt = - EstimationUtils.ceil(BigDecimal(filteredRowCountValue) * avgRowSize) + val filteredSizeInBytes: BigInt = EstimationUtils.ceil(BigDecimal( + EstimationUtils.getOutputSize(plan.output, newColStats, filteredRowCountValue) + )) Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCountValue), attributeStats = newColStats)) @@ -228,7 +229,7 @@ class FilterEstimation extends Logging { return None } val aColStat = mutableColStats(attrRef.exprId) - val rowCountValue = plan.child.statistics.rowCount.get + val rowCountValue = plan.child.stats(catalystConf).rowCount.get val nullPercent: BigDecimal = if (rowCountValue == 0) 0.0 else BigDecimal(aColStat.nullCount)/BigDecimal(rowCountValue) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 80cb28022e6f..66dd993293f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -37,11 +37,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { nullCount = 0, avgLen = 4, maxLen = 4) val child = StatsTestPlan( outputList = Seq(ar), - stats = Statistics( - sizeInBytes = 10 * 4, - rowCount = Some(10), - attributeStats = AttributeMap(Seq(ar -> childColStat)) - ) + rowCount = 10L, + attributeStats = AttributeMap(Seq(ar -> childColStat)) ) test("key = 2") { @@ -94,7 +91,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase { nullCount = 0, avgLen = 4, maxLen = 4), Some(3L) ) - } test("key > 6") { @@ -199,11 +195,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val expectedRowCount = rowCount.getOrElse(0L) val expectedAttrStats = toAttributeMap(Seq("key" -> expectedColStats), filterNode) val expectedStats = Statistics( - sizeInBytes = expectedRowCount * getRowSize(filterNode.output, expectedAttrStats), + 
sizeInBytes = getOutputSize(filterNode.output, expectedAttrStats, expectedRowCount), rowCount = Some(expectedRowCount), attributeStats = expectedAttrStats) - assert(filterNode.statistics == expectedStats) + assert(filterNode.stats(conf) == expectedStats) } } From af911e127c57d4f2b6e573985e2c9c5402d23cbb Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Fri, 13 Jan 2017 19:40:27 -0800 Subject: [PATCH 16/36] make method calculateSingleCondition return Option[Double] --- .../plans/logical/basicLogicalOperators.scala | 3 +- .../statsEstimation/FilterEstimation.scala | 98 +++++++---- .../FilterEstimationSuite.scala | 162 ++++++++++++------ 3 files changed, 176 insertions(+), 87 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index d27e9ded451f..ccebae3cc270 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -132,12 +132,11 @@ case class Filter(condition: Expression, child: LogicalPlan) override def computeStats(conf: CatalystConf): Statistics = { if (conf.cboEnabled) { - FilterEstimation(conf).estimate(this).getOrElse(super.computeStats(conf)) + FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf)) } else { super.computeStats(conf) } } - } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index fcdff3780e9e..0fb4ff2d84e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation +import java.sql.Date + import scala.collection.immutable.{HashSet, Map} import scala.collection.mutable @@ -28,8 +30,11 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String - -case class FilterEstimation(catalystConf: CatalystConf) extends Logging { +/** + * @param plan a LogicalPlan node that must be an instance of Filter + * @param catalystConf a configuration showing if CBO is enabled + */ +case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging { /** * We use a mutable colStats because we need to update the corresponding ColumnStat @@ -48,10 +53,9 @@ case class FilterEstimation(catalystConf: CatalystConf) extends Logging { * is used to compute row count, size in bytes, and the updated statistics after a given * predicated is applied. * - * @param plan a LogicalPlan node that must be an instance of Filter. * @return Option[Statistics] When there is no statistics collected, it returns None. 
*/ - def estimate(plan: Filter): Option[Statistics] = { + def estimate: Option[Statistics] = { val stats: Statistics = plan.child.stats(catalystConf) if (stats.rowCount.isEmpty) return None @@ -61,7 +65,7 @@ case class FilterEstimation(catalystConf: CatalystConf) extends Logging { mutableColStats = mutable.Map.empty ++= statsExprIdMap // estimate selectivity of this filter predicate - val filterSelectivity: Double = calculateConditions(plan, plan.condition) + val filterSelectivity: Double = calculateConditions(plan.condition) // attributeStats has mapping Attribute-to-ColumnStat. // mutableColStats has mapping ExprId-to-ColumnStat. @@ -91,34 +95,32 @@ case class FilterEstimation(catalystConf: CatalystConf) extends Logging { * range condition such as (c > 40 AND c <= 50) * For logical OR conditions, we do not update stats after a condition estimation. * - * @param plan the Filter LogicalPlan node * @param condition the compound logical expression * @param update a boolean flag to specify if we need to update ColumnStat of a column * for subsequent conditions * @return a doube value to show the percentage of rows meeting a given condition */ def calculateConditions( - plan: Filter, condition: Expression, update: Boolean = true) : Double = { condition match { case And(cond1, cond2) => - val p1 = calculateConditions(plan, cond1, update) - val p2 = calculateConditions(plan, cond2, update) + val p1 = calculateConditions(cond1, update) + val p2 = calculateConditions(cond2, update) p1 * p2 case Or(cond1, cond2) => - val p1 = calculateConditions(plan, cond1, update = false) - val p2 = calculateConditions(plan, cond2, update = false) + val p1 = calculateConditions(cond1, update = false) + val p2 = calculateConditions(cond2, update = false) math.min(1.0, p1 + p2 - (p1 * p2)) - case Not(cond) => calculateSingleCondition(plan, cond, update = false) match { + case Not(cond) => calculateSingleCondition(cond, update = false) match { case Some(percent) => 1.0 - percent case None => 1.0 } - case _ => calculateSingleCondition(plan, condition, update) match { + case _ => calculateSingleCondition(condition, update) match { case Some(percent) => percent case None => 1.0 // for not-supported condition, set filter selectivity to a conservative estimate 100% @@ -131,15 +133,13 @@ case class FilterEstimation(catalystConf: CatalystConf) extends Logging { * Currently we only support binary predicates where one side is a column, * and the other is a literal. * - * @param plan the Filter LogicalPlan node * @param condition a single logical expression * @param update a boolean flag to specify if we need to update ColumnStat of a column * for subsequent conditions * @return Option[Double] value to show the percentage of rows meeting a given condition. - * It returns None if the condition is not supported. + * It returns None if the condition is not supported. */ def calculateSingleCondition( - plan: Filter, condition: Expression, update: Boolean) : Option[Double] = { @@ -187,14 +187,14 @@ case class FilterEstimation(catalystConf: CatalystConf) extends Logging { // we support IsNull and IsNotNull only when the child is a leaf node (table). 
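Stepping back, the estimate flow around this hunk is: compute a selectivity in [0, 1], scale the child's row count, round up with ceil so the estimate stays integral and conservative, and derive the new size in bytes from the surviving rows. A standalone sketch of that arithmetic (the output size is simplified here to rows times a fixed row width, which is an assumption):

    import scala.math.BigDecimal.RoundingMode

    def ceilBig(d: BigDecimal): BigInt = d.setScale(0, RoundingMode.CEILING).toBigInt()

    val childRowCount = BigInt(10)
    val filterSelectivity = 0.3                  // e.g. an IN list covering 3 of 10 values
    val filteredRows = ceilBig(BigDecimal(childRowCount) * filterSelectivity) // 3
    val rowWidth = BigInt(4)                     // one 4-byte int column
    val filteredSizeInBytes = filteredRows * rowWidth // 12
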
case IsNull(ar: AttributeReference) => if (plan.child.isInstanceOf[LeafNode ]) { - evaluateIsNull(plan, ar, true, update) + evaluateIsNull(ar, true, update) } else { None } case IsNotNull(ar: AttributeReference) => if (plan.child.isInstanceOf[LeafNode ]) { - evaluateIsNull(plan, ar, false, update) + evaluateIsNull(ar, false, update) } else { None } @@ -211,15 +211,14 @@ case class FilterEstimation(catalystConf: CatalystConf) extends Logging { /** * Returns a percentage of rows meeting "IS NULL" or "IS NOT NULL" condition. * - * @param plan the Filter LogicalPlan node * @param attrRef an AttributeReference (or a column) * @param isNull set to true for "IS NULL" condition. set to false for "IS NOT NULL" condition * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for subsequent conditions * @return a doube value to show the percentage of rows meeting a given condition + * It returns None if no statistics collected for a given column. */ def evaluateIsNull( - plan: Filter, attrRef: AttributeReference, isNull: Boolean, update: Boolean) @@ -261,6 +260,7 @@ case class FilterEstimation(catalystConf: CatalystConf) extends Logging { * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for subsequent conditions * @return a doube value to show the percentage of rows meeting a given condition + * It returns None if no statistics exists for a given column or wrong value. */ def evaluateBinary( op: BinaryComparison, @@ -275,13 +275,13 @@ case class FilterEstimation(catalystConf: CatalystConf) extends Logging { // Make sure that the Date/Timestamp literal is a valid one attrRef.dataType match { - case DateType => + case DateType if literal.dataType.isInstanceOf[StringType] => val dateLiteral = DateTimeUtils.stringToDate(literal.value.asInstanceOf[UTF8String]) if (dateLiteral.isEmpty) { logDebug("[CBO] Date literal is wrong, No statistics for " + attrRef) return None } - case TimestampType => + case TimestampType if literal.dataType.isInstanceOf[StringType] => val tsLiteral = DateTimeUtils.stringToTimestamp(literal.value.asInstanceOf[UTF8String]) if (tsLiteral.isEmpty) { logDebug("[CBO] Timestamp literal is wrong, No statistics for " + attrRef) @@ -309,7 +309,15 @@ case class FilterEstimation(catalystConf: CatalystConf) extends Logging { /** * This method converts a numeric or Literal value of numeric type to a BigDecimal value. + * In order to avoid type casting error such as Java int to Java long, we need to + * convert a numeric integer value to String, and then convert it to long, + * and then convert it to BigDecimal. * If isNumeric is true, then it is a numeric value. Otherwise, it is a Literal value. + * + * @param literal can be either a Literal or numeric value + * @param dataType the column data type + * @param isNumeric If isNumeric is true, then it is a numeric value. Otherwise, it is a Literal value. + * @return a BigDecimal value */ def numericLiteralToBigDecimal( literal: Any, @@ -330,17 +338,27 @@ case class FilterEstimation(catalystConf: CatalystConf) extends Logging { case DateType => if (isNumeric) BigDecimal(literal.asInstanceOf[BigInt]) else { - val dateLiteral = DateTimeUtils.stringToDate( - literal.asInstanceOf[Literal].value.asInstanceOf[UTF8String]) - BigDecimal(dateLiteral.asInstanceOf[BigInt]) + val dateLiteral = literal.asInstanceOf[Literal].dataType match { + case StringType => + DateTimeUtils.stringToDate( + literal.asInstanceOf[Literal].value.asInstanceOf[UTF8String]). 
+ getOrElse(0).toString + case _ => literal.asInstanceOf[Literal].value.toString + } + BigDecimal(java.lang.Long.valueOf(dateLiteral)) } case TimestampType => if (isNumeric) BigDecimal(literal.asInstanceOf[BigInt]) else { - val tsLiteral = DateTimeUtils.stringToTimestamp( - literal.asInstanceOf[Literal].value.asInstanceOf[UTF8String]) - BigDecimal(tsLiteral.asInstanceOf[BigInt]) + val tsLiteral = literal.asInstanceOf[Literal].dataType match { + case StringType => + DateTimeUtils.stringToTimestamp( + literal.asInstanceOf[Literal].value.asInstanceOf[UTF8String]). + getOrElse(0).toString + case _ => literal.asInstanceOf[Literal].value.toString + } + BigDecimal(java.lang.Long.valueOf(tsLiteral)) } } } @@ -384,10 +402,31 @@ case class FilterEstimation(catalystConf: CatalystConf) extends Logging { // We update ColumnStat structure after apply this equality predicate. // Set distinctCount to 1. Set nullCount to 0. val newStats = attrRef.dataType match { - case _: NumericType | DateType | TimestampType => + case _: NumericType => val newValue = Some(literal.value) aColStat.copy(distinctCount = 1, min = newValue, max = newValue, nullCount = 0) + + case DateType => + val dateValue = literal.dataType match { + case StringType => + Some(Date.valueOf(literal.value.asInstanceOf[String])) + case _ => Some(literal.value) + } + aColStat.copy(distinctCount = 1, min = dateValue, + max = dateValue, nullCount = 0) + + case TimestampType => + val tsValue = literal.dataType match { + case StringType => + Some(DateTimeUtils.stringToTimestamp( + literal.value.asInstanceOf[UTF8String]). + getOrElse(0)) + case _ => Some(literal.value) + } + aColStat.copy(distinctCount = 1, min = tsValue, + max = tsValue, nullCount = 0) + case _ => aColStat.copy(distinctCount = 1, nullCount = 0) } mutableColStats += (attrRef.exprId -> newStats) @@ -409,6 +448,7 @@ case class FilterEstimation(catalystConf: CatalystConf) extends Logging { * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for subsequent conditions * @return a doube value to show the percentage of rows meeting a given condition + * It returns None if no statistics exists for a given column. */ def evaluateInSet( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 66dd993293f0..fd629507d28c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.statsEstimation +import java.sql.{Date, Timestamp} + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{DateType, IntegerType, TimestampType} /** * In this test suite, we test predicates containing the following operators: @@ -29,125 +31,156 @@ import org.apache.spark.sql.types.IntegerType class FilterEstimationSuite extends StatsEstimationTestBase { - // Suppose our test table has one column called "key". - // It has 10 rows with values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 + // Suppose our test table has 10 rows and 3 columns. 
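The Date/Timestamp branches above all funnel into the same idea: a temporal value is reduced to a single number (days since the Unix epoch for DateType, an analogous numeric encoding for timestamps) so it can be compared against min/max as a BigDecimal. The patch goes through DateTimeUtils; the sketch below uses java.time purely for illustration:

    import java.time.LocalDate
    import java.time.temporal.ChronoUnit

    val epoch = LocalDate.of(1970, 1, 1)
    // "2017-01-02" encodes as 17168 days since the epoch.
    val days = ChronoUnit.DAYS.between(epoch, LocalDate.parse("2017-01-02"))
    val comparable = BigDecimal(days) // directly comparable with the column's min/max
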
+ // First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 - val ar = AttributeReference("key", IntegerType)() - val childColStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + val arInt = AttributeReference("cint", IntegerType)() + val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4) + + // Second column cdate has values, from 2017-01-01 through 2017-01-10 for 10 values. + val dMin = Date.valueOf("2017-01-01") + val dMax = Date.valueOf("2017-01-10") + val arDate = AttributeReference("cdate", DateType)() + val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), nullCount = 0, avgLen = 4, maxLen = 4) + + // Third column ctimestamp has values from "2017-01-01 01:00:00" through + // "2017-01-01 10:00:00" for 10 distinct timestamps (or hours). + val tsMin = Timestamp.valueOf("2017-01-01 01:00:00") + val tsMax = Timestamp.valueOf("2017-01-01 10:00:00") + val arTimestamp = AttributeReference("ctimestamp", TimestampType)() + val childColStatTimestamp = ColumnStat(distinctCount = 10, min = Some(tsMin), max = Some(tsMax), + nullCount = 0, avgLen = 8, maxLen = 8) + val child = StatsTestPlan( - outputList = Seq(ar), + outputList = Seq(arInt), rowCount = 10L, - attributeStats = AttributeMap(Seq(ar -> childColStat)) + attributeStats = AttributeMap(Seq( + arInt -> childColStatInt, + arDate -> childColStatDate, + arTimestamp -> childColStatTimestamp + )) ) - test("key = 2") { - // the predicate is "WHERE key = 2" + test("cint = 2") { + // the predicate is "WHERE cint = 2" validateEstimatedStats( - Filter(EqualTo(ar, Literal(2)), child), + arInt, + Filter(EqualTo(arInt, Literal(2)), child), ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4), Some(1L) ) } - test("key = 0") { - // the predicate is "WHERE key = 0" + test("cint = 0") { + // the predicate is "WHERE cint = 0" // This is an out-of-range case since 0 is outside the range [min, max] validateEstimatedStats( - Filter(EqualTo(ar, Literal(0)), child), + arInt, + Filter(EqualTo(arInt, Literal(0)), child), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(0L) ) } - test("key < 3") { - // the predicate is "WHERE key < 3" + test("cint < 3") { + // the predicate is "WHERE cint < 3" validateEstimatedStats( - Filter(LessThan(ar, Literal(3)), child), + arInt, + Filter(LessThan(arInt, Literal(3)), child), ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), nullCount = 0, avgLen = 4, maxLen = 4), Some(3L) ) } - test("key < 0") { - // the predicate is "WHERE key < 0" + test("cint < 0") { + // the predicate is "WHERE cint < 0" // This is a corner case since literal 0 is smaller than min. 
validateEstimatedStats( - Filter(LessThan(ar, Literal(0)), child), + arInt, + Filter(LessThan(arInt, Literal(0)), child), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(0L) ) } - test("key <= 3") { - // the predicate is "WHERE key <= 3" + test("cint <= 3") { + // the predicate is "WHERE cint <= 3" validateEstimatedStats( - Filter(LessThanOrEqual(ar, Literal(3)), child), + arInt, + Filter(LessThanOrEqual(arInt, Literal(3)), child), ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), nullCount = 0, avgLen = 4, maxLen = 4), Some(3L) ) } - test("key > 6") { - // the predicate is "WHERE key > 6" + test("cint > 6") { + // the predicate is "WHERE cint > 6" validateEstimatedStats( - Filter(GreaterThan(ar, Literal(6)), child), + arInt, + Filter(GreaterThan(arInt, Literal(6)), child), ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(5L) ) } - test("key > 10") { - // the predicate is "WHERE key > 10" + test("cint > 10") { + // the predicate is "WHERE cint > 10" // This is a corner case since max value is 10. validateEstimatedStats( - Filter(GreaterThan(ar, Literal(10)), child), + arInt, + Filter(GreaterThan(arInt, Literal(10)), child), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(0L) ) } - test("key >= 6") { - // the predicate is "WHERE key >= 6" + test("cint >= 6") { + // the predicate is "WHERE cint >= 6" validateEstimatedStats( - Filter(GreaterThanOrEqual(ar, Literal(6)), child), + arInt, + Filter(GreaterThanOrEqual(arInt, Literal(6)), child), ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(5L) ) } - test("key IS NULL") { - // the predicate is "WHERE key IS NULL" + test("cint IS NULL") { + // the predicate is "WHERE cint IS NULL" validateEstimatedStats( - Filter(IsNull(ar), child), + arInt, + Filter(IsNull(arInt), child), ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0, avgLen = 4, maxLen = 4), Some(0L) ) } - test("key IS NOT NULL") { - // the predicate is "WHERE key IS NOT NULL" + test("cint IS NOT NULL") { + // the predicate is "WHERE cint IS NOT NULL" validateEstimatedStats( - Filter(IsNotNull(ar), child), + arInt, + Filter(IsNotNull(arInt), child), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(10L) ) } - test("key > 3 AND key <= 6") { - // the predicate is "WHERE key > 3 AND key <= 6" - val condition = And(GreaterThan(ar, Literal(3)), LessThanOrEqual(ar, Literal(6))) + test("cint > 3 AND cint <= 6") { + // the predicate is "WHERE cint > 3 AND cint <= 6" + val condition = And(GreaterThan(arInt, Literal(3)), LessThanOrEqual(arInt, Literal(6))) validateEstimatedStats( + arInt, Filter(condition, child), ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), nullCount = 0, avgLen = 4, maxLen = 4), @@ -155,10 +188,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) } - test("key = 3 OR key = 6") { - // the predicate is "WHERE key = 3 OR key = 6" - val condition = Or(EqualTo(ar, Literal(3)), EqualTo(ar, Literal(6))) + test("cint = 3 OR cint = 6") { + // the predicate is "WHERE cint = 3 OR cint = 6" + val condition = Or(EqualTo(arInt, Literal(3)), EqualTo(arInt, Literal(6))) validateEstimatedStats( + arInt, Filter(condition, child), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), @@ -166,40 +200,56 @@ 
class FilterEstimationSuite extends StatsEstimationTestBase { ) } - test("key IN (3, 4, 5)") { - // the predicate is "WHERE key IN (3, 4, 5)" + test("cint IN (3, 4, 5)") { + // the predicate is "WHERE cint IN (3, 4, 5)" validateEstimatedStats( - Filter(InSet(ar, Set(3, 4, 5)), child), + arInt, + Filter(InSet(arInt, Set(3, 4, 5)), child), ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), nullCount = 0, avgLen = 4, maxLen = 4), Some(3L) ) } - test("key NOT IN (3, 4, 5)") { - // the predicate is "WHERE key NOT IN (3, 4, 5)" + test("cint NOT IN (3, 4, 5)") { + // the predicate is "WHERE cint NOT IN (3, 4, 5)" validateEstimatedStats( - Filter(Not(InSet(ar, Set(3, 4, 5))), child), + arInt, + Filter(Not(InSet(arInt, Set(3, 4, 5))), child), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(7L) ) } + test("cdate = 2017-01-02") { + // the predicate is: WHERE cdate = "2017-01-02" + val d20170102 = Date.valueOf("2017-01-02") + validateEstimatedStats( + arDate, + Filter(EqualTo(arDate, Literal(d20170102, DateType)), + child.copy(outputList = Seq(arDate))), + ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(1L) + ) + } + private def validateEstimatedStats( + ar: AttributeReference, filterNode: Filter, expectedColStats: ColumnStat, rowCount: Option[Long] = None) - : Unit = { + : Unit = { val expectedRowCount = rowCount.getOrElse(0L) - val expectedAttrStats = toAttributeMap(Seq("key" -> expectedColStats), filterNode) - val expectedStats = Statistics( - sizeInBytes = getOutputSize(filterNode.output, expectedAttrStats, expectedRowCount), - rowCount = Some(expectedRowCount), - attributeStats = expectedAttrStats) + val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode) + val expectedSizeInBytes = getOutputSize(filterNode.output, expectedAttrStats, expectedRowCount) - assert(filterNode.stats(conf) == expectedStats) + val filteredStats = filterNode.stats(conf) + assert(filteredStats.sizeInBytes == expectedSizeInBytes) + assert(filteredStats.rowCount == rowCount) + assert(filteredStats.attributeStats(ar) == expectedColStats) } } From 2121ff2c65e0861d5f349bf59f4a490593e8f8ff Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Sat, 14 Jan 2017 17:23:43 -0800 Subject: [PATCH 17/36] add date and timestamp tests --- .../statsEstimation/FilterEstimation.scala | 8 +-- .../FilterEstimationSuite.scala | 68 +++++++++++-------- 2 files changed, 44 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 0fb4ff2d84e8..4efbefdb08af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import java.sql.Date +import java.sql.{Timestamp, Date} import scala.collection.immutable.{HashSet, Map} import scala.collection.mutable @@ -410,7 +410,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case DateType => val dateValue = literal.dataType match { case StringType => - Some(Date.valueOf(literal.value.asInstanceOf[String])) + 
Some(Date.valueOf(literal.value.toString)) case _ => Some(literal.value) } aColStat.copy(distinctCount = 1, min = dateValue, @@ -419,9 +419,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case TimestampType => val tsValue = literal.dataType match { case StringType => - Some(DateTimeUtils.stringToTimestamp( - literal.value.asInstanceOf[UTF8String]). - getOrElse(0)) + Some(Timestamp.valueOf(literal.value.toString)) case _ => Some(literal.value) } aColStat.copy(distinctCount = 1, min = tsValue, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index fd629507d28c..15171560b32c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -53,21 +53,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val childColStatTimestamp = ColumnStat(distinctCount = 10, min = Some(tsMin), max = Some(tsMax), nullCount = 0, avgLen = 8, maxLen = 8) - val child = StatsTestPlan( - outputList = Seq(arInt), - rowCount = 10L, - attributeStats = AttributeMap(Seq( - arInt -> childColStatInt, - arDate -> childColStatDate, - arTimestamp -> childColStatTimestamp - )) - ) - test("cint = 2") { // the predicate is "WHERE cint = 2" validateEstimatedStats( arInt, - Filter(EqualTo(arInt, Literal(2)), child), + Filter(EqualTo(arInt, Literal(2)), ChildStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4), Some(1L) @@ -79,7 +69,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // This is an out-of-range case since 0 is outside the range [min, max] validateEstimatedStats( arInt, - Filter(EqualTo(arInt, Literal(0)), child), + Filter(EqualTo(arInt, Literal(0)), ChildStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(0L) @@ -90,7 +80,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // the predicate is "WHERE cint < 3" validateEstimatedStats( arInt, - Filter(LessThan(arInt, Literal(3)), child), + Filter(LessThan(arInt, Literal(3)), ChildStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), nullCount = 0, avgLen = 4, maxLen = 4), Some(3L) @@ -102,7 +92,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // This is a corner case since literal 0 is smaller than min. 
validateEstimatedStats( arInt, - Filter(LessThan(arInt, Literal(0)), child), + Filter(LessThan(arInt, Literal(0)), ChildStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(0L) @@ -113,7 +103,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // the predicate is "WHERE cint <= 3" validateEstimatedStats( arInt, - Filter(LessThanOrEqual(arInt, Literal(3)), child), + Filter(LessThanOrEqual(arInt, Literal(3)), ChildStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), nullCount = 0, avgLen = 4, maxLen = 4), Some(3L) @@ -124,7 +114,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // the predicate is "WHERE cint > 6" validateEstimatedStats( arInt, - Filter(GreaterThan(arInt, Literal(6)), child), + Filter(GreaterThan(arInt, Literal(6)), ChildStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(5L) @@ -136,7 +126,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // This is a corner case since max value is 10. validateEstimatedStats( arInt, - Filter(GreaterThan(arInt, Literal(10)), child), + Filter(GreaterThan(arInt, Literal(10)), ChildStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(0L) @@ -147,7 +137,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // the predicate is "WHERE cint >= 6" validateEstimatedStats( arInt, - Filter(GreaterThanOrEqual(arInt, Literal(6)), child), + Filter(GreaterThanOrEqual(arInt, Literal(6)), ChildStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(5L) @@ -158,7 +148,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // the predicate is "WHERE cint IS NULL" validateEstimatedStats( arInt, - Filter(IsNull(arInt), child), + Filter(IsNull(arInt), ChildStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0, avgLen = 4, maxLen = 4), Some(0L) @@ -169,7 +159,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // the predicate is "WHERE cint IS NOT NULL" validateEstimatedStats( arInt, - Filter(IsNotNull(arInt), child), + Filter(IsNotNull(arInt), ChildStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(10L) @@ -181,7 +171,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = And(GreaterThan(arInt, Literal(3)), LessThanOrEqual(arInt, Literal(6))) validateEstimatedStats( arInt, - Filter(condition, child), + Filter(condition, ChildStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), nullCount = 0, avgLen = 4, maxLen = 4), Some(4L) @@ -193,7 +183,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(arInt, Literal(3)), EqualTo(arInt, Literal(6))) validateEstimatedStats( arInt, - Filter(condition, child), + Filter(condition, ChildStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(2L) @@ -204,7 +194,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // the predicate is "WHERE cint IN (3, 4, 5)" validateEstimatedStats( arInt, - Filter(InSet(arInt, Set(3, 4, 5)), child), + Filter(InSet(arInt, Set(3, 4, 5)), 
ChildStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), nullCount = 0, avgLen = 4, maxLen = 4), Some(3L) @@ -215,26 +205,50 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // the predicate is "WHERE cint NOT IN (3, 4, 5)" validateEstimatedStats( arInt, - Filter(Not(InSet(arInt, Set(3, 4, 5))), child), + Filter(Not(InSet(arInt, Set(3, 4, 5))), ChildStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(7L) ) } - test("cdate = 2017-01-02") { + test("cdate = '2017-01-02' ") { // the predicate is: WHERE cdate = "2017-01-02" val d20170102 = Date.valueOf("2017-01-02") validateEstimatedStats( arDate, - Filter(EqualTo(arDate, Literal(d20170102, DateType)), - child.copy(outputList = Seq(arDate))), + Filter(EqualTo(arDate, Literal("2017-01-02")), ChildStatsTestPlan(Seq(arDate))), ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), nullCount = 0, avgLen = 4, maxLen = 4), Some(1L) ) } + test("ctimestamp = '2017-01-01 02:00:00' ") { + // the predicate is: WHERE ctimestamp = "2017-01-01 02:00:00" + val ts20170102 = Timestamp.valueOf("2017-01-01 02:00:00") + validateEstimatedStats( + arTimestamp, + Filter(EqualTo(arTimestamp, Literal("2017-01-01 02:00:00")), + ChildStatsTestPlan(Seq(arTimestamp))), + ColumnStat(distinctCount = 1, min = Some(ts20170102), max = Some(ts20170102), + nullCount = 0, avgLen = 8, maxLen = 8), + Some(1L) + ) + } + + def ChildStatsTestPlan(outList: Seq[Attribute]): StatsTestPlan = { + StatsTestPlan( + outputList = outList, + rowCount = 10L, + attributeStats = AttributeMap(Seq( + arInt -> childColStatInt, + arDate -> childColStatDate, + arTimestamp -> childColStatTimestamp + )) + ) + } + private def validateEstimatedStats( ar: AttributeReference, filterNode: Filter, From 487813c2199861cd2bb42a146bae28a8aec3a60b Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Sat, 14 Jan 2017 17:47:21 -0800 Subject: [PATCH 18/36] fix scalastyle error --- .../plans/logical/statsEstimation/FilterEstimation.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 4efbefdb08af..5e30ace0d844 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import java.sql.{Timestamp, Date} +import java.sql.{Date, Timestamp} import scala.collection.immutable.{HashSet, Map} import scala.collection.mutable @@ -316,7 +316,8 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * * @param literal can be either a Literal or numeric value * @param dataType the column data type - * @param isNumeric If isNumeric is true, then it is a numeric value. Otherwise, it is a Literal value. + * @param isNumeric If isNumeric is true, then it is a numeric value. + * Otherwise, it is a Literal value. 
* @return a BigDecimal value */ def numericLiteralToBigDecimal( From 0f73034dff77481ae697dca90e471dbd2c9ee407 Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Sun, 15 Jan 2017 14:42:46 -0800 Subject: [PATCH 19/36] add additional date / timestamp tests --- .../statsEstimation/FilterEstimation.scala | 43 ++++++++++++++--- .../FilterEstimationSuite.scala | 46 +++++++++++++++++-- 2 files changed, 79 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 5e30ace0d844..ae7695baf36d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -412,7 +412,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val dateValue = literal.dataType match { case StringType => Some(Date.valueOf(literal.value.toString)) - case _ => Some(literal.value) + case _ => Some(DateTimeUtils.toJavaDate(literal.value.toString.toInt)) } aColStat.copy(distinctCount = 1, min = dateValue, max = dateValue, nullCount = 0) @@ -421,7 +421,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val tsValue = literal.dataType match { case StringType => Some(Timestamp.valueOf(literal.value.toString)) - case _ => Some(literal.value) + case _ => Some(DateTimeUtils.toJavaTimestamp(literal.value.toString.toLong)) } aColStat.copy(distinctCount = 1, min = tsValue, max = tsValue, nullCount = 0) @@ -571,11 +571,40 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } if (update) { - op match { - case GreaterThan(l, r) => newMin = Some(literal.value) - case GreaterThanOrEqual(l, r) => newMin = Some(literal.value) - case LessThan(l, r) => newMax = Some(literal.value) - case LessThanOrEqual(l, r) => newMax = Some(literal.value) + attrRef.dataType match { + case DateType => + val dateValue = literal.dataType match { + case StringType => + Date.valueOf(literal.value.toString) + case _ => DateTimeUtils.toJavaDate(literal.value.toString.toInt) + } + op match { + case GreaterThan(l, r) => newMin = Some(dateValue) + case GreaterThanOrEqual(l, r) => newMin = Some(dateValue) + case LessThan(l, r) => newMax = Some(dateValue) + case LessThanOrEqual(l, r) => newMax = Some(dateValue) + } + + case TimestampType => + val tsValue = literal.dataType match { + case StringType => + Timestamp.valueOf(literal.value.toString) + case _ => DateTimeUtils.toJavaTimestamp(literal.value.toString.toLong) + } + op match { + case GreaterThan(l, r) => newMin = Some(tsValue) + case GreaterThanOrEqual(l, r) => newMin = Some(tsValue) + case LessThan(l, r) => newMax = Some(tsValue) + case LessThanOrEqual(l, r) => newMax = Some(tsValue) + } + + case _ => + op match { + case GreaterThan (l, r) => newMin = Some (literal.value) + case GreaterThanOrEqual (l, r) => newMin = Some (literal.value) + case LessThan (l, r) => newMax = Some (literal.value) + case LessThanOrEqual (l, r) => newMax = Some (literal.value) + } } newNdv = math.max(math.round(ndv.toDouble * percent), 1) val newStats = aColStat.copy(distinctCount = newNdv, min = newMin, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 15171560b32c..d28e2c521d6a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -22,7 +22,8 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ -import org.apache.spark.sql.types.{DateType, IntegerType, TimestampType} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types.{DateType, IntegerType, LongType, TimestampType} /** * In this test suite, we test predicates containing the following operators: @@ -224,19 +225,58 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) } + test("cdate = cast('2017-01-02' AS DATE)") { + // the predicate is: WHERE cdate = cast("2017-01-02" AS DATE) + val d20170102 = Date.valueOf("2017-01-02") + val d20170102_SQLDate = DateTimeUtils.fromJavaDate(d20170102) + validateEstimatedStats( + arDate, + Filter(EqualTo(arDate, Literal(d20170102_SQLDate, IntegerType)), + ChildStatsTestPlan(Seq(arDate))), + ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(1L) + ) + } + + test("cdate < '2017-01-03' ") { + // the predicate is: WHERE cdate < "2017-01-03" + val d20170103 = Date.valueOf("2017-01-03") + validateEstimatedStats( + arDate, + Filter(LessThan(arDate, Literal("2017-01-03")), ChildStatsTestPlan(Seq(arDate))), + ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(3L) + ) + } + test("ctimestamp = '2017-01-01 02:00:00' ") { // the predicate is: WHERE ctimestamp = "2017-01-01 02:00:00" - val ts20170102 = Timestamp.valueOf("2017-01-01 02:00:00") + val ts2017010102 = Timestamp.valueOf("2017-01-01 02:00:00") validateEstimatedStats( arTimestamp, Filter(EqualTo(arTimestamp, Literal("2017-01-01 02:00:00")), ChildStatsTestPlan(Seq(arTimestamp))), - ColumnStat(distinctCount = 1, min = Some(ts20170102), max = Some(ts20170102), + ColumnStat(distinctCount = 1, min = Some(ts2017010102), max = Some(ts2017010102), nullCount = 0, avgLen = 8, maxLen = 8), Some(1L) ) } + test("ctimestamp < '2017-01-01 03:00:00' ") { + // the predicate is: WHERE ctimestamp < "2017-01-01 03:00:00" + val ts2017010103 = Timestamp.valueOf("2017-01-01 03:00:00") + validateEstimatedStats( + arTimestamp, + Filter(LessThan(arTimestamp, Literal("2017-01-01 03:00:00")), + ChildStatsTestPlan(Seq(arTimestamp))), + ColumnStat(distinctCount = 2, min = Some(tsMin), max = Some(ts2017010103), + nullCount = 0, avgLen = 8, maxLen = 8), + Some(3L) + ) + } + def ChildStatsTestPlan(outList: Seq[Attribute]): StatsTestPlan = { StatsTestPlan( outputList = outList, From f24cf3edd53903667c5bd32d9c96a19cd3043742 Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Mon, 16 Jan 2017 15:01:07 -0800 Subject: [PATCH 20/36] update code based on wzhfy's comments --- .../statsEstimation/FilterEstimation.scala | 66 +++++++++-------- .../FilterEstimationSuite.scala | 72 +++++++++---------- 2 files changed, 65 insertions(+), 73 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index ae7695baf36d..4b659af685e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -38,7 +38,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo /** * We use a mutable colStats because we need to update the corresponding ColumnStat - * for a column after we apply a predicate condition. For example, A column c has + * for a column after we apply a predicate condition. For example, column c has * [min, max] value as [0, 100]. In a range condition such as (c > 40 AND c <= 50), * we need to set the column's [min, max] value to [40, 100] after we evaluate the * first condition c > 40. We need to set the column's [min, max] value to [40, 50] @@ -60,9 +60,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo if (stats.rowCount.isEmpty) return None // save a mutable copy of colStats so that we can later change it recursively - val statsExprIdMap: Map[ExprId, ColumnStat] = - stats.attributeStats.map(kv => (kv._1.exprId, kv._2)) - mutableColStats = mutable.Map.empty ++= statsExprIdMap + mutableColStats = mutable.Map(stats.attributeStats.map(kv => (kv._1.exprId, kv._2)).toSeq: _*) // estimate selectivity of this filter predicate val filterSelectivity: Double = calculateConditions(plan.condition) @@ -77,13 +75,13 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo mutableColStats.map(kv => expridToAttrMap(kv._1) -> kv._2) val newColStats = AttributeMap(mutableAttributeStats.toSeq) - val filteredRowCountValue: BigInt = + val filteredRowCount: BigInt = EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * filterSelectivity) val filteredSizeInBytes: BigInt = EstimationUtils.ceil(BigDecimal( - EstimationUtils.getOutputSize(plan.output, newColStats, filteredRowCountValue) + EstimationUtils.getOutputSize(plan.output, newColStats, filteredRowCount) )) - Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCountValue), + Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount), attributeStats = newColStats)) } @@ -118,12 +116,13 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case Not(cond) => calculateSingleCondition(cond, update = false) match { case Some(percent) => 1.0 - percent + // for not-supported condition, set filter selectivity to a conservative estimate 100% case None => 1.0 } case _ => calculateSingleCondition(condition, update) match { case Some(percent) => percent + // for not-supported condition, set filter selectivity to a conservative estimate 100% case None => 1.0 - // for not-supported condition, set filter selectivity to a conservative estimate 100% } } } @@ -187,14 +186,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo // we support IsNull and IsNotNull only when the child is a leaf node (table). 
case IsNull(ar: AttributeReference) => if (plan.child.isInstanceOf[LeafNode ]) { - evaluateIsNull(ar, true, update) + evaluateIsNull(ar, isNull = true, update) } else { None } case IsNotNull(ar: AttributeReference) => if (plan.child.isInstanceOf[LeafNode ]) { - evaluateIsNull(ar, false, update) + evaluateIsNull(ar, isNull = false, update) } else { None } @@ -231,7 +230,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val rowCountValue = plan.child.stats(catalystConf).rowCount.get val nullPercent: BigDecimal = if (rowCountValue == 0) 0.0 - else BigDecimal(aColStat.nullCount)/BigDecimal(rowCountValue) + else BigDecimal(aColStat.nullCount) / BigDecimal(rowCountValue) if (update) { val newStats = @@ -242,7 +241,9 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } val percent = - if (isNull) nullPercent.toDouble + if (isNull) { + nullPercent.toDouble + } else { /** ISNOTNULL(column) */ 1.0 - nullPercent.toDouble @@ -297,11 +298,9 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case _: NumericType | DateType | TimestampType => evaluateBinaryForNumeric(op, attrRef, literal, update) case StringType | BinaryType => - // TODO: It is difficult to support other binary comparisons for String/Binary // type without min/max and advanced statistics like histogram. - - logDebug("[CBO] No statistics for String/Binary type " + attrRef) + logDebug("[CBO] No range comparison statistics for String/Binary type " + attrRef) None } } @@ -316,8 +315,11 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * * @param literal can be either a Literal or numeric value * @param dataType the column data type - * @param isNumeric If isNumeric is true, then it is a numeric value. - * Otherwise, it is a Literal value. + * @param isNumeric If isNumeric is true, then it is a numeric value. For example, + * a condition "IN (3, 4, 5)" has numeric values since 3, 4, 5 have + * been converted to integer values (no longer Literal objects). + * For other conditions, isNumeric is set to false because Literal + * objects are passed. * @return a BigDecimal value */ def numericLiteralToBigDecimal( @@ -386,7 +388,6 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo // decide if the value is in [min, max] of the column. // We currently don't store min/max for binary/string type. // Hence, we assume it is in boundary for binary/string type. 
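The IS NULL / IS NOT NULL branch above reduces to the column's null fraction, nullCount / rowCount, with IS NOT NULL taking the complement. A hedged sketch of that computation (illustrative names, guarding against an empty table):

    // Selectivity of IS [NOT] NULL from basic column statistics.
    def nullSelectivity(nullCount: BigInt, rowCount: BigInt, isNull: Boolean): Double = {
      val nullPercent =
        if (rowCount == 0) 0.0   // avoid division by zero on empty tables
        else (BigDecimal(nullCount) / BigDecimal(rowCount)).toDouble
      if (isNull) nullPercent else 1.0 - nullPercent
    }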
- val inBoundary: Boolean = attrRef.dataType match { case _: NumericType | DateType | TimestampType => val statsRange = @@ -407,25 +408,22 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val newValue = Some(literal.value) aColStat.copy(distinctCount = 1, min = newValue, max = newValue, nullCount = 0) - case DateType => val dateValue = literal.dataType match { - case StringType => - Some(Date.valueOf(literal.value.toString)) - case _ => Some(DateTimeUtils.toJavaDate(literal.value.toString.toInt)) + case StringType => + Some(Date.valueOf(literal.value.toString)) + case _ => Some(DateTimeUtils.toJavaDate(literal.value.toString.toInt)) } aColStat.copy(distinctCount = 1, min = dateValue, max = dateValue, nullCount = 0) - case TimestampType => val tsValue = literal.dataType match { - case StringType => - Some(Timestamp.valueOf(literal.value.toString)) - case _ => Some(DateTimeUtils.toJavaTimestamp(literal.value.toString.toLong)) + case StringType => + Some(Timestamp.valueOf(literal.value.toString)) + case _ => Some(DateTimeUtils.toJavaTimestamp(literal.value.toString.toLong)) } aColStat.copy(distinctCount = 1, min = tsValue, max = tsValue, nullCount = 0) - case _ => aColStat.copy(distinctCount = 1, nullCount = 0) } mutableColStats += (attrRef.exprId -> newStats) @@ -469,8 +467,8 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case _: NumericType | DateType | TimestampType => val statsRange = Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange] - hSet.map(e => numericLiteralToBigDecimal(e, aType, true)). - filter(e => e >= statsRange.min && e <= statsRange.max) + hSet.map(e => numericLiteralToBigDecimal(e, aType, isNumeric = true)) + .filter(e => e >= statsRange.min && e <= statsRange.max) // We assume the whole set since there is no min/max information for String/Binary type case StringType | BinaryType => hSet @@ -600,11 +598,11 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case _ => op match { - case GreaterThan (l, r) => newMin = Some (literal.value) - case GreaterThanOrEqual (l, r) => newMin = Some (literal.value) - case LessThan (l, r) => newMax = Some (literal.value) - case LessThanOrEqual (l, r) => newMax = Some (literal.value) - } + case GreaterThan (l, r) => newMin = Some (literal.value) + case GreaterThanOrEqual (l, r) => newMin = Some (literal.value) + case LessThan (l, r) => newMax = Some (literal.value) + case LessThanOrEqual (l, r) => newMax = Some (literal.value) + } } newNdv = math.max(math.round(ndv.toDouble * percent), 1) val newStats = aColStat.copy(distinctCount = newNdv, min = newMin, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index d28e2c521d6a..c1b87d28dd2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -55,10 +55,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { nullCount = 0, avgLen = 8, maxLen = 8) test("cint = 2") { - // the predicate is "WHERE cint = 2" validateEstimatedStats( arInt, - Filter(EqualTo(arInt, Literal(2)), ChildStatsTestPlan(Seq(arInt))), + Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), 
nullCount = 0, avgLen = 4, maxLen = 4), Some(1L) @@ -66,11 +65,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cint = 0") { - // the predicate is "WHERE cint = 0" // This is an out-of-range case since 0 is outside the range [min, max] validateEstimatedStats( arInt, - Filter(EqualTo(arInt, Literal(0)), ChildStatsTestPlan(Seq(arInt))), + Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(0L) @@ -78,10 +76,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cint < 3") { - // the predicate is "WHERE cint < 3" validateEstimatedStats( arInt, - Filter(LessThan(arInt, Literal(3)), ChildStatsTestPlan(Seq(arInt))), + Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), nullCount = 0, avgLen = 4, maxLen = 4), Some(3L) @@ -89,11 +86,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cint < 0") { - // the predicate is "WHERE cint < 0" // This is a corner case since literal 0 is smaller than min. validateEstimatedStats( arInt, - Filter(LessThan(arInt, Literal(0)), ChildStatsTestPlan(Seq(arInt))), + Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(0L) @@ -101,10 +97,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cint <= 3") { - // the predicate is "WHERE cint <= 3" validateEstimatedStats( arInt, - Filter(LessThanOrEqual(arInt, Literal(3)), ChildStatsTestPlan(Seq(arInt))), + Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), nullCount = 0, avgLen = 4, maxLen = 4), Some(3L) @@ -112,10 +107,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cint > 6") { - // the predicate is "WHERE cint > 6" validateEstimatedStats( arInt, - Filter(GreaterThan(arInt, Literal(6)), ChildStatsTestPlan(Seq(arInt))), + Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(5L) @@ -123,11 +117,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cint > 10") { - // the predicate is "WHERE cint > 10" // This is a corner case since max value is 10. 
validateEstimatedStats( arInt, - Filter(GreaterThan(arInt, Literal(10)), ChildStatsTestPlan(Seq(arInt))), + Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(0L) @@ -135,10 +128,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cint >= 6") { - // the predicate is "WHERE cint >= 6" validateEstimatedStats( arInt, - Filter(GreaterThanOrEqual(arInt, Literal(6)), ChildStatsTestPlan(Seq(arInt))), + Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(5L) @@ -146,10 +138,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cint IS NULL") { - // the predicate is "WHERE cint IS NULL" validateEstimatedStats( arInt, - Filter(IsNull(arInt), ChildStatsTestPlan(Seq(arInt))), + Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0, avgLen = 4, maxLen = 4), Some(0L) @@ -157,10 +148,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cint IS NOT NULL") { - // the predicate is "WHERE cint IS NOT NULL" validateEstimatedStats( arInt, - Filter(IsNotNull(arInt), ChildStatsTestPlan(Seq(arInt))), + Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(10L) @@ -168,11 +158,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cint > 3 AND cint <= 6") { - // the predicate is "WHERE cint > 3 AND cint <= 6" val condition = And(GreaterThan(arInt, Literal(3)), LessThanOrEqual(arInt, Literal(6))) validateEstimatedStats( arInt, - Filter(condition, ChildStatsTestPlan(Seq(arInt))), + Filter(condition, childStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), nullCount = 0, avgLen = 4, maxLen = 4), Some(4L) @@ -180,11 +169,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cint = 3 OR cint = 6") { - // the predicate is "WHERE cint = 3 OR cint = 6" val condition = Or(EqualTo(arInt, Literal(3)), EqualTo(arInt, Literal(6))) validateEstimatedStats( arInt, - Filter(condition, ChildStatsTestPlan(Seq(arInt))), + Filter(condition, childStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(2L) @@ -192,10 +180,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cint IN (3, 4, 5)") { - // the predicate is "WHERE cint IN (3, 4, 5)" validateEstimatedStats( arInt, - Filter(InSet(arInt, Set(3, 4, 5)), ChildStatsTestPlan(Seq(arInt))), + Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), nullCount = 0, avgLen = 4, maxLen = 4), Some(3L) @@ -203,10 +190,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cint NOT IN (3, 4, 5)") { - // the predicate is "WHERE cint NOT IN (3, 4, 5)" validateEstimatedStats( arInt, - Filter(Not(InSet(arInt, Set(3, 4, 5))), ChildStatsTestPlan(Seq(arInt))), + Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt))), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(7L) @@ -214,11 +200,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } 
test("cdate = '2017-01-02' ") { - // the predicate is: WHERE cdate = "2017-01-02" val d20170102 = Date.valueOf("2017-01-02") validateEstimatedStats( arDate, - Filter(EqualTo(arDate, Literal("2017-01-02")), ChildStatsTestPlan(Seq(arDate))), + Filter(EqualTo(arDate, Literal("2017-01-02")), childStatsTestPlan(Seq(arDate))), ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), nullCount = 0, avgLen = 4, maxLen = 4), Some(1L) @@ -226,13 +211,12 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cdate = cast('2017-01-02' AS DATE)") { - // the predicate is: WHERE cdate = cast("2017-01-02" AS DATE) val d20170102 = Date.valueOf("2017-01-02") val d20170102_SQLDate = DateTimeUtils.fromJavaDate(d20170102) validateEstimatedStats( arDate, Filter(EqualTo(arDate, Literal(d20170102_SQLDate, IntegerType)), - ChildStatsTestPlan(Seq(arDate))), + childStatsTestPlan(Seq(arDate))), ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), nullCount = 0, avgLen = 4, maxLen = 4), Some(1L) @@ -240,11 +224,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cdate < '2017-01-03' ") { - // the predicate is: WHERE cdate < "2017-01-03" val d20170103 = Date.valueOf("2017-01-03") validateEstimatedStats( arDate, - Filter(LessThan(arDate, Literal("2017-01-03")), ChildStatsTestPlan(Seq(arDate))), + Filter(LessThan(arDate, Literal("2017-01-03")), childStatsTestPlan(Seq(arDate))), ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), nullCount = 0, avgLen = 4, maxLen = 4), Some(3L) @@ -252,12 +235,24 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("ctimestamp = '2017-01-01 02:00:00' ") { - // the predicate is: WHERE ctimestamp = "2017-01-01 02:00:00" val ts2017010102 = Timestamp.valueOf("2017-01-01 02:00:00") validateEstimatedStats( arTimestamp, Filter(EqualTo(arTimestamp, Literal("2017-01-01 02:00:00")), - ChildStatsTestPlan(Seq(arTimestamp))), + childStatsTestPlan(Seq(arTimestamp))), + ColumnStat(distinctCount = 1, min = Some(ts2017010102), max = Some(ts2017010102), + nullCount = 0, avgLen = 8, maxLen = 8), + Some(1L) + ) + } + + test("ctimestamp = cast('2017-01-01 02:00:00' AS TIMESTAMP)") { + val ts2017010102 = Timestamp.valueOf("2017-01-01 02:00:00") + val ts2017010102_SQLTS = DateTimeUtils.fromJavaTimestamp(ts2017010102) + validateEstimatedStats( + arTimestamp, + Filter(EqualTo(arTimestamp, Literal(ts2017010102_SQLTS, LongType)), + childStatsTestPlan(Seq(arTimestamp))), ColumnStat(distinctCount = 1, min = Some(ts2017010102), max = Some(ts2017010102), nullCount = 0, avgLen = 8, maxLen = 8), Some(1L) @@ -265,19 +260,18 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("ctimestamp < '2017-01-01 03:00:00' ") { - // the predicate is: WHERE ctimestamp < "2017-01-01 03:00:00" val ts2017010103 = Timestamp.valueOf("2017-01-01 03:00:00") validateEstimatedStats( arTimestamp, Filter(LessThan(arTimestamp, Literal("2017-01-01 03:00:00")), - ChildStatsTestPlan(Seq(arTimestamp))), + childStatsTestPlan(Seq(arTimestamp))), ColumnStat(distinctCount = 2, min = Some(tsMin), max = Some(ts2017010103), nullCount = 0, avgLen = 8, maxLen = 8), Some(3L) ) } - def ChildStatsTestPlan(outList: Seq[Attribute]): StatsTestPlan = { + private def childStatsTestPlan(outList: Seq[Attribute]): StatsTestPlan = { StatsTestPlan( outputList = outList, rowCount = 10L, From 3ebf3a8855d0d36e851e731d8c0601898a788f9f Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Mon, 16 Jan 2017 19:50:59 -0800 Subject: [PATCH 21/36] 
filtered NDV should be no larger than initial NDV --- .../statsEstimation/FilterEstimation.scala | 17 +++++++-------- .../FilterEstimationSuite.scala | 21 +++++++++++++++++++ 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 4b659af685e7..e9cd51088085 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -96,12 +96,9 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * @param condition the compound logical expression * @param update a boolean flag to specify if we need to update ColumnStat of a column * for subsequent conditions - * @return a doube value to show the percentage of rows meeting a given condition + * @return a double value to show the percentage of rows meeting a given condition */ - def calculateConditions( - condition: Expression, - update: Boolean = true) - : Double = { + def calculateConditions(condition: Expression, update: Boolean = true): Double = { condition match { case And(cond1, cond2) => @@ -138,10 +135,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * @return Option[Double] value to show the percentage of rows meeting a given condition. * It returns None if the condition is not supported. */ - def calculateSingleCondition( - condition: Expression, - update: Boolean) - : Option[Double] = { + def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = { condition match { // For evaluateBinary method, we assume the literal on the right side of an operator. // So we will change the order if not. @@ -477,7 +471,9 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo return Some(0.0) } - val newNdv = validQuerySet.size + // newNdv should be no greater than the old ndv. For example, column has only 2 values + // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. + val newNdv = math.min(validQuerySet.size.toLong, ndv.longValue()) val(newMax, newMin) = aType match { case _: NumericType | DateType | TimestampType => val tmpSet: Set[Double] = validQuerySet.map(e => e.toString.toDouble) @@ -555,6 +551,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo // Without advanced statistics like histogram, we assume uniform data distribution. // We just prorate the adjusted range over the initial range to compute filter selectivity. + // For ease of computation, we convert all relevant numeric values to Double. 
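The proration described in the new comment amounts to taking the fraction of [min, max] that the adjusted range covers. A sketch under the same uniform-distribution assumption (illustrative names; < and <= are treated alike, as in the patch):

    // Fraction of the column's [min, max] selected by a range predicate,
    // with all values already converted to Double.
    def rangeSelectivity(lessThan: Boolean, lit: Double, min: Double, max: Double): Double =
      if (lessThan) (lit - min) / (max - min)
      else (max - lit) / (max - min)

    // cint < 3 over [1, 10]: (3 - 1) / (10 - 1) ~= 0.22; 10 rows * 0.22,
    // rounded up by EstimationUtils.ceil, matches the suite's Some(3L).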
percent = op match { case LessThan(l, r) => (literalToDouble - minToDouble) / (maxToDouble - minToDouble) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index c1b87d28dd2e..d942c07eae37 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -271,6 +271,27 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) } + // This is a corner test case. We want to test if we can handle the case when the number of + // valid values in IN clause is greater than the number of distinct values for a given column. + // For example, column has only 2 distinct values 1 and 6. + // The predicate is: column IN (1, 2, 3, 4, 5). + test("cint IN (1, 2, 3, 4, 5)") { + val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6), + nullCount = 0, avgLen = 4, maxLen = 4) + val cornerChildStatsTestplan = StatsTestPlan( + outputList = Seq(arInt), + rowCount = 2L, + attributeStats = AttributeMap(Seq(arInt -> cornerChildColStatInt)) + ) + validateEstimatedStats( + arInt, + Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), + ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(2L) + ) + } + private def childStatsTestPlan(outList: Seq[Attribute]): StatsTestPlan = { StatsTestPlan( outputList = outList, From 976363599de94126abe2f50716bb2db1c16aa322 Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Tue, 17 Jan 2017 19:07:11 -0800 Subject: [PATCH 22/36] add tests to handle decimal data type --- .../statsEstimation/FilterEstimation.scala | 93 ++++++++++--------- .../FilterEstimationSuite.scala | 58 ++++++++++-- 2 files changed, 102 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index e9cd51088085..91ec8a3b3c90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -307,56 +307,40 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * and then convert it to BigDecimal. * If isNumeric is true, then it is a numeric value. Otherwise, it is a Literal value. * - * @param literal can be either a Literal or numeric value - * @param dataType the column data type - * @param isNumeric If isNumeric is true, then it is a numeric value. For example, - * a condition "IN (3, 4, 5)" has numeric values since 3, 4, 5 have - * been converted to integer values (no longer Literal objects). - * For other conditions, isNumeric is set to false because Literal - * objects are passed. 
+ * @param attrDataType the column data type + * @param litValue can be either a Literal or numeric value + * @param litDataType * @return a BigDecimal value */ def numericLiteralToBigDecimal( - literal: Any, - dataType: DataType, - isNumeric: Boolean = false) + attrDataType: DataType, + litValue: Any, + litDataType: DataType) : BigDecimal = { - dataType match { + attrDataType match { case _: IntegralType => - val stringValue: String = - if (isNumeric) literal.toString - else literal.asInstanceOf[Literal].value.toString - BigDecimal(java.lang.Long.valueOf(stringValue)) + BigDecimal(java.lang.Long.valueOf(litValue.toString)) case _: FractionalType => - if (isNumeric) BigDecimal(literal.asInstanceOf[Double]) - else BigDecimal(literal.asInstanceOf[Literal].value.asInstanceOf[Double]) + BigDecimal(litValue.toString.toDouble) case DateType => - if (isNumeric) BigDecimal(literal.asInstanceOf[BigInt]) - else { - val dateLiteral = literal.asInstanceOf[Literal].dataType match { - case StringType => - DateTimeUtils.stringToDate( - literal.asInstanceOf[Literal].value.asInstanceOf[UTF8String]). - getOrElse(0).toString - case _ => literal.asInstanceOf[Literal].value.toString - } - BigDecimal(java.lang.Long.valueOf(dateLiteral)) + val dateLiteral = litDataType match { + case StringType => + DateTimeUtils.stringToDate(litValue.asInstanceOf[UTF8String]) + .getOrElse(0).toString + case _ => litValue.toString } + BigDecimal(java.lang.Long.valueOf(dateLiteral)) case TimestampType => - if (isNumeric) BigDecimal(literal.asInstanceOf[BigInt]) - else { - val tsLiteral = literal.asInstanceOf[Literal].dataType match { - case StringType => - DateTimeUtils.stringToTimestamp( - literal.asInstanceOf[Literal].value.asInstanceOf[UTF8String]). - getOrElse(0).toString - case _ => literal.asInstanceOf[Literal].value.toString - } - BigDecimal(java.lang.Long.valueOf(tsLiteral)) + val tsLiteral = litDataType match { + case StringType => + DateTimeUtils.stringToTimestamp(litValue.asInstanceOf[UTF8String]) + .getOrElse(0).toString + case _ => litValue.toString } + BigDecimal(java.lang.Long.valueOf(tsLiteral)) } } @@ -386,7 +370,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case _: NumericType | DateType | TimestampType => val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange] - val lit = numericLiteralToBigDecimal(literal, attrRef.dataType) + val lit = numericLiteralToBigDecimal(attrRef.dataType, literal.value, literal.dataType) (lit >= statsRange.min) && (lit <= statsRange.max) case _ => true /** for String/Binary type */ @@ -461,8 +445,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case _: NumericType | DateType | TimestampType => val statsRange = Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange] - hSet.map(e => numericLiteralToBigDecimal(e, aType, isNumeric = true)) - .filter(e => e >= statsRange.min && e <= statsRange.max) + val hSetMap = hSet.map(e => + if (e.isInstanceOf[String]) { + val utf8String = UTF8String.fromString(e.asInstanceOf[String]) + numericLiteralToBigDecimal(aType, utf8String, StringType) + } else { + numericLiteralToBigDecimal(aType, e, aType) + }) + hSetMap.filter(e => e >= statsRange.min && e <= statsRange.max) // We assume the whole set since there is no min/max information for String/Binary type case StringType | BinaryType => hSet @@ -475,9 +465,27 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo // 1 and 6. 
The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. val newNdv = math.min(validQuerySet.size.toLong, ndv.longValue()) val(newMax, newMin) = aType match { - case _: NumericType | DateType | TimestampType => + case _: NumericType => val tmpSet: Set[Double] = validQuerySet.map(e => e.toString.toDouble) (Some(tmpSet.max), Some(tmpSet.min)) + case DateType => + if (hSet.isInstanceOf[Set[String]]) { + val dateMax = Date.valueOf(hSet.asInstanceOf[Set[String]].max) + val dateMin = Date.valueOf(hSet.asInstanceOf[Set[String]].min) + (Some(dateMax), Some(dateMin)) + } else { + val tmpSet: Set[Long] = validQuerySet.map(e => e.toString.toLong) + (Some(tmpSet.max), Some(tmpSet.min)) + } + case TimestampType => + if (hSet.isInstanceOf[Set[String]]) { + val dateMax = Timestamp.valueOf(hSet.asInstanceOf[Set[String]].max) + val dateMin = Timestamp.valueOf(hSet.asInstanceOf[Set[String]].min) + (Some(dateMax), Some(dateMin)) + } else { + val tmpSet: Set[Long] = validQuerySet.map(e => e.toString.toLong) + (Some(tmpSet.max), Some(tmpSet.min)) + } case _ => (None, None) } @@ -522,7 +530,8 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange] - val literalValueBD = numericLiteralToBigDecimal(literal, attrRef.dataType) + val literalValueBD = + numericLiteralToBigDecimal(attrRef.dataType, literal.value, literal.dataType) // determine the overlapping degree between predicate range and column's range val (noOverlap: Boolean, completeOverlap: Boolean) = op match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index d942c07eae37..84f7d889c97e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -23,13 +23,12 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types.{DateType, IntegerType, LongType, TimestampType} +import org.apache.spark.sql.types._ /** * In this test suite, we test predicates containing the following operators: * =, <, <=, >, >=, AND, OR, IS NULL, IS NOT NULL, IN, NOT IN */ - class FilterEstimationSuite extends StatsEstimationTestBase { // Suppose our test table has 10 rows and 3 columns. @@ -39,14 +38,14 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4) - // Second column cdate has values, from 2017-01-01 through 2017-01-10 for 10 values. + // Second column cdate has 10 values from 2017-01-01 through 2017-01-10. 
val dMin = Date.valueOf("2017-01-01") val dMax = Date.valueOf("2017-01-10") val arDate = AttributeReference("cdate", DateType)() val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), nullCount = 0, avgLen = 4, maxLen = 4) - // Third column ctimestamp has values from "2017-01-01 01:00:00" through + // Third column ctimestamp has 10 values from "2017-01-01 01:00:00" through // "2017-01-01 10:00:00" for 10 distinct timestamps (or hours). val tsMin = Timestamp.valueOf("2017-01-01 01:00:00") val tsMax = Timestamp.valueOf("2017-01-01 10:00:00") @@ -54,6 +53,13 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val childColStatTimestamp = ColumnStat(distinctCount = 10, min = Some(tsMin), max = Some(tsMax), nullCount = 0, avgLen = 8, maxLen = 8) + // Fourth column cdecimal has 10 values from 0.20 through 2.00 at increments of 0.2. + val decMin = new java.math.BigDecimal("0.200000000000000000") + val decMax = new java.math.BigDecimal("2.000000000000000000") + val arDecimal = AttributeReference("cdecimal", DecimalType(12, 2))() + val childColStatDecimal = ColumnStat(distinctCount = 10, min = Some(decMin), max = Some(decMax), + nullCount = 0, avgLen = 8, maxLen = 8) + test("cint = 2") { validateEstimatedStats( arInt, @@ -234,6 +240,19 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) } + test("cdate IN ('2017-01-03', '2017-01-04', '2017-01-05')") { + val d20170103 = Date.valueOf("2017-01-03") + val d20170105 = Date.valueOf("2017-01-05") + validateEstimatedStats( + arDate, + Filter(InSet(arDate, Set("2017-01-03", "2017-01-04", "2017-01-05")), + childStatsTestPlan(Seq(arDate))), + ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(3L) + ) + } + test("ctimestamp = '2017-01-01 02:00:00' ") { val ts2017010102 = Timestamp.valueOf("2017-01-01 02:00:00") validateEstimatedStats( @@ -271,6 +290,30 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) } + test("cdecimal = 0.40") { + val dec_0_40 = new java.math.BigDecimal("0.400000000000000000") + validateEstimatedStats( + arDecimal, + Filter(EqualTo(arDecimal, Literal(dec_0_40, DecimalType(12, 2))), + childStatsTestPlan(Seq(arDecimal))), + ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), + nullCount = 0, avgLen = 8, maxLen = 8), + Some(1L) + ) + } + + test("cdecimal < 0.60 ") { + val dec_0_60 = new java.math.BigDecimal("0.600000000000000000") + validateEstimatedStats( + arDecimal, + Filter(LessThan(arDecimal, Literal(dec_0_60, DecimalType(12, 2))), + childStatsTestPlan(Seq(arDecimal))), + ColumnStat(distinctCount = 2, min = Some(decMin), max = Some(dec_0_60), + nullCount = 0, avgLen = 8, maxLen = 8), + Some(3L) + ) + } + // This is a corner test case. We want to test if we can handle the case when the number of + // valid values in IN clause is greater than the number of distinct values for a given column. + // For example, column has only 2 distinct values 1 and 6.
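A sketch of the capped-NDV IN estimate that the corner case above exercises (illustrative names; assumes numeric statistics are available):

    // Selectivity of "col IN (...)": drop literals outside [min, max], then
    // cap the surviving distinct count at the column's ndv, since a filter
    // can never produce more distinct values than the column already holds.
    def inSetSelectivity(hSet: Set[Double], ndv: Long, min: Double, max: Double): Double = {
      val validSet = hSet.filter(v => v >= min && v <= max)
      if (validSet.isEmpty) 0.0
      else math.min(validSet.size.toLong, ndv).toDouble / ndv.toDouble
    }

    // Corner case: ndv = 2 over [1, 6] with IN (1, 2, 3, 4, 5): all five
    // literals survive the range check, newNdv is capped at 2, and the
    // selectivity is 1.0 -- both rows are kept.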
@@ -299,7 +342,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { attributeStats = AttributeMap(Seq( arInt -> childColStatInt, arDate -> childColStatDate, - arTimestamp -> childColStatTimestamp + arTimestamp -> childColStatTimestamp, + arDecimal -> childColStatDecimal )) ) } @@ -308,10 +352,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ar: AttributeReference, filterNode: Filter, expectedColStats: ColumnStat, - rowCount: Option[Long] = None) + rowCount: Option[BigInt] = None) : Unit = { - val expectedRowCount = rowCount.getOrElse(0L) + val expectedRowCount: BigInt = rowCount.getOrElse(0L) val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode) val expectedSizeInBytes = getOutputSize(filterNode.output, expectedAttrStats, expectedRowCount) From 35c213f70c4532f3ee7bdd372eb0f6be0ad37a60 Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Wed, 18 Jan 2017 11:24:51 -0800 Subject: [PATCH 23/36] add test cases for float and double types --- .../statsEstimation/FilterEstimation.scala | 9 +++-- .../FilterEstimationSuite.scala | 36 +++++++++++++++++-- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 91ec8a3b3c90..ce4dfadd95d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -305,11 +305,10 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * In order to avoid type casting error such as Java int to Java long, we need to * convert a numeric integer value to String, and then convert it to long, * and then convert it to BigDecimal. - * If isNumeric is true, then it is a numeric value. Otherwise, it is a Literal value. 
* * @param attrDataType the column data type - * @param litValue can be either a Literal or numeric value - * @param litDataType + * @param litValue the literal value + * @param litDataType the data type of literal value * @return a BigDecimal value */ def numericLiteralToBigDecimal( @@ -469,7 +468,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val tmpSet: Set[Double] = validQuerySet.map(e => e.toString.toDouble) (Some(tmpSet.max), Some(tmpSet.min)) case DateType => - if (hSet.isInstanceOf[Set[String]]) { + if (hSet.forall(e => e.isInstanceOf[String])) { val dateMax = Date.valueOf(hSet.asInstanceOf[Set[String]].max) val dateMin = Date.valueOf(hSet.asInstanceOf[Set[String]].min) (Some(dateMax), Some(dateMin)) @@ -478,7 +477,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo (Some(tmpSet.max), Some(tmpSet.min)) } case TimestampType => - if (hSet.isInstanceOf[Set[String]]) { + if (hSet.forall(e => e.isInstanceOf[String])) { val dateMax = Timestamp.valueOf(hSet.asInstanceOf[Set[String]].max) val dateMin = Timestamp.valueOf(hSet.asInstanceOf[Set[String]].min) (Some(dateMax), Some(dateMin)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 84f7d889c97e..47741d8e77e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._ */ class FilterEstimationSuite extends StatsEstimationTestBase { - // Suppose our test table has 10 rows and 3 columns. + // Suppose our test table has 10 rows and 6 columns. 
// First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 val arInt = AttributeReference("cint", IntegerType)() @@ -60,6 +60,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val childColStatDecimal = ColumnStat(distinctCount = 10, min = Some(decMin), max = Some(decMax), nullCount = 0, avgLen = 8, maxLen = 8) + // Fifth column cfloat has 10 float values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 + val arFloat = AttributeReference("cfloat", FloatType)() + val childColStatFloat = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), + nullCount = 0, avgLen = 4, maxLen = 4) + + // Sixth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 + val arDouble = AttributeReference("cdouble", FloatType)() + val childColStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), + nullCount = 0, avgLen = 8, maxLen = 8) + test("cint = 2") { validateEstimatedStats( arInt, @@ -314,6 +324,26 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) } + test("cfloat < 3.0") { + validateEstimatedStats( + arFloat, + Filter(LessThan(arFloat, Literal(3.0)), childStatsTestPlan(Seq(arFloat))), + ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(3L) + ) + } + + test("cdouble < 3.0") { + validateEstimatedStats( + arDouble, + Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble))), + ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), + nullCount = 0, avgLen = 8, maxLen = 8), + Some(3L) + ) + } + // This is a corner test case. We want to test if we can handle the case when the number of // valid values in IN clause is greater than the number of distinct values for a given column. // For example, column has only 2 distinct values 1 and 6. 
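A quick sanity check on the cfloat/cdouble expectations above (a back-of-the-envelope sketch, not part of the patch): range predicates are prorated over [min, max] under the uniform-distribution assumption, so for "column < 3.0" with min = 1.0, max = 10.0 and 10 rows:

    // selectivity of "column < 3.0" prorated over [1.0, 10.0]
    val selectivity = (3.0 - 1.0) / (10.0 - 1.0)                  // ~= 0.222
    val expectedRows = math.ceil(10 * selectivity).toLong         // 3, matching Some(3L)
    val expectedNdv = math.max(math.round(10 * selectivity), 1)   // 2, matching distinctCount = 2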
@@ -343,7 +373,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { arInt -> childColStatInt, arDate -> childColStatDate, arTimestamp -> childColStatTimestamp, - arDecimal -> childColStatDecimal + arDecimal -> childColStatDecimal, + arFloat -> childColStatFloat, + arDouble -> childColStatDouble )) ) } From 6b8aab3c5c8fdb67362783623f508112210edb0a Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Fri, 20 Jan 2017 10:33:08 -0800 Subject: [PATCH 24/36] add cast-as-date test cases --- .../FilterEstimationSuite.scala | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 47741d8e77e4..918028dac238 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -250,6 +250,19 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) } + test("cdate < cast('2017-01-03' AS DATE)") { + val d20170103 = Date.valueOf("2017-01-03") + val d20170103_SQLDate = DateTimeUtils.fromJavaDate(d20170103) + validateEstimatedStats( + arDate, + Filter(LessThan(arDate, Literal(d20170103_SQLDate, IntegerType)), + childStatsTestPlan(Seq(arDate))), + ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), + nullCount = 0, avgLen = 4, maxLen = 4), + Some(3L) + ) + } + test("cdate IN ('2017-01-03', '2017-01-04', '2017-01-05')") { val d20170103 = Date.valueOf("2017-01-03") val d20170105 = Date.valueOf("2017-01-05") @@ -300,6 +313,19 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) } + test("ctimestamp < cast('2017-01-01 03:00:00' AS TIMESTAMP)") { + val ts2017010103 = Timestamp.valueOf("2017-01-01 03:00:00") + val ts2017010102_SQLTS = DateTimeUtils.fromJavaTimestamp(ts2017010103) + validateEstimatedStats( + arTimestamp, + Filter(LessThan(arTimestamp, Literal(ts2017010102_SQLTS, LongType)), + childStatsTestPlan(Seq(arTimestamp))), + ColumnStat(distinctCount = 2, min = Some(tsMin), max = Some(ts2017010103), + nullCount = 0, avgLen = 8, maxLen = 8), + Some(3L) + ) + } + test("cdecimal = 0.40") { val dec_0_40 = new java.math.BigDecimal("0.400000000000000000") validateEstimatedStats( From 894d85cdb5323ef3c318e9036bcab4e549061646 Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Fri, 20 Jan 2017 11:30:07 -0800 Subject: [PATCH 25/36] update calls to getOutputSize --- .../plans/logical/statsEstimation/FilterEstimation.scala | 2 +- .../sql/catalyst/statsEstimation/FilterEstimationSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index ce4dfadd95d5..092292166959 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -78,7 +78,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val filteredRowCount: BigInt = EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * filterSelectivity) val filteredSizeInBytes: BigInt = 
EstimationUtils.ceil(BigDecimal( - EstimationUtils.getOutputSize(plan.output, newColStats, filteredRowCount) + EstimationUtils.getOutputSize(plan.output, filteredRowCount, newColStats) )) Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 918028dac238..d2f0392839d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -415,7 +415,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val expectedRowCount: BigInt = rowCount.getOrElse(0L) val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode) - val expectedSizeInBytes = getOutputSize(filterNode.output, expectedAttrStats, expectedRowCount) + val expectedSizeInBytes = getOutputSize(filterNode.output, expectedRowCount, expectedAttrStats) val filteredStats = filterNode.stats(conf) assert(filteredStats.sizeInBytes == expectedSizeInBytes) From 97aacdfbbac0a9eb2efec73a6812dda979cd60ca Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Mon, 13 Feb 2017 19:34:23 -0800 Subject: [PATCH 26/36] use Range.isIntersected to decide if a literal is in boundary --- .../statsEstimation/FilterEstimation.scala | 57 +++++++++---------- 1 file changed, 26 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 092292166959..01be7a3adf5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** - * @param plan a LogicalPlan node that must be an instance of Filter * @param catalystConf a configuration showing if CBO is enabled */ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging { @@ -56,6 +55,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * @return Option[Statistics] When there is no statistics collected, it returns None. */ def estimate: Option[Statistics] = { + // We first copy child node's statistics and then modify it based on filter selectivity. val stats: Statistics = plan.child.stats(catalystConf) if (stats.rowCount.isEmpty) return None @@ -96,7 +96,8 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * @param condition the compound logical expression * @param update a boolean flag to specify if we need to update ColumnStat of a column * for subsequent conditions - * @return a double value to show the percentage of rows meeting a given condition + * @return a double value to show the percentage of rows meeting a given condition. + * Returns 1.0 (a conservative filter estimate) if a condition is not supported. 
*/ def calculateConditions(condition: Expression, update: Boolean = true): Double = { @@ -365,15 +366,9 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo // decide if the value is in [min, max] of the column. // We currently don't store min/max for binary/string type. // Hence, we assume it is in boundary for binary/string type. - val inBoundary: Boolean = attrRef.dataType match { - case _: NumericType | DateType | TimestampType => - val statsRange = - Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange] - val lit = numericLiteralToBigDecimal(attrRef.dataType, literal.value, literal.dataType) - (lit >= statsRange.min) && (lit <= statsRange.max) - - case _ => true /** for String/Binary type */ - } + val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType) + val litRange = Range(Some(literal.value), Some(literal.value), literal.dataType) + val inBoundary: Boolean = Range.isIntersected(statsRange, litRange) if (inBoundary) { @@ -534,13 +529,13 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo // determine the overlapping degree between predicate range and column's range val (noOverlap: Boolean, completeOverlap: Boolean) = op match { - case LessThan(l, r) => + case _: LessThan => (literalValueBD <= statsRange.min, literalValueBD > statsRange.max) - case LessThanOrEqual(l, r) => + case _: LessThanOrEqual => (literalValueBD < statsRange.min, literalValueBD >= statsRange.max) - case GreaterThan(l, r) => + case _: GreaterThan => (literalValueBD >= statsRange.max, literalValueBD < statsRange.min) - case GreaterThanOrEqual(l, r) => + case _: GreaterThanOrEqual => (literalValueBD > statsRange.max, literalValueBD <= statsRange.min) } @@ -561,14 +556,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo // We just prorate the adjusted range over the initial range to compute filter selectivity. // For ease of computation, we convert all relevant numeric values to Double. 
percent = op match { - case LessThan(l, r) => + case _: LessThan => (literalToDouble - minToDouble) / (maxToDouble - minToDouble) - case LessThanOrEqual(l, r) => + case _: LessThanOrEqual => if (literalValueBD == BigDecimal(statsRange.min)) 1.0 / ndv.toDouble else (literalToDouble - minToDouble) / (maxToDouble - minToDouble) - case GreaterThan(l, r) => + case _: GreaterThan => (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) - case GreaterThanOrEqual(l, r) => + case _: GreaterThanOrEqual => if (literalValueBD == BigDecimal(statsRange.max)) 1.0 / ndv.toDouble else (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) } @@ -582,10 +577,10 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case _ => DateTimeUtils.toJavaDate(literal.value.toString.toInt) } op match { - case GreaterThan(l, r) => newMin = Some(dateValue) - case GreaterThanOrEqual(l, r) => newMin = Some(dateValue) - case LessThan(l, r) => newMax = Some(dateValue) - case LessThanOrEqual(l, r) => newMax = Some(dateValue) + case _: GreaterThan => newMin = Some(dateValue) + case _: GreaterThanOrEqual => newMin = Some(dateValue) + case _: LessThan => newMax = Some(dateValue) + case _: LessThanOrEqual => newMax = Some(dateValue) } case TimestampType => @@ -595,18 +590,18 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case _ => DateTimeUtils.toJavaTimestamp(literal.value.toString.toLong) } op match { - case GreaterThan(l, r) => newMin = Some(tsValue) - case GreaterThanOrEqual(l, r) => newMin = Some(tsValue) - case LessThan(l, r) => newMax = Some(tsValue) - case LessThanOrEqual(l, r) => newMax = Some(tsValue) + case _: GreaterThan => newMin = Some(tsValue) + case _: GreaterThanOrEqual => newMin = Some(tsValue) + case _: LessThan => newMax = Some(tsValue) + case _: LessThanOrEqual => newMax = Some(tsValue) } case _ => op match { - case GreaterThan (l, r) => newMin = Some (literal.value) - case GreaterThanOrEqual (l, r) => newMin = Some (literal.value) - case LessThan (l, r) => newMax = Some (literal.value) - case LessThanOrEqual (l, r) => newMax = Some (literal.value) + case _: GreaterThan => newMin = Some(literal.value) + case _: GreaterThanOrEqual => newMin = Some(literal.value) + case _: LessThan => newMax = Some(literal.value) + case _: LessThanOrEqual => newMax = Some(literal.value) } } newNdv = math.max(math.round(ndv.toDouble * percent), 1) From 2b4a10aaa134fbfb22de7159b874cb73c2cd955f Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Wed, 15 Feb 2017 13:34:04 -0800 Subject: [PATCH 27/36] handle date/timestamp string literal --- .../statsEstimation/FilterEstimation.scala | 29 ++----------------- .../FilterEstimationSuite.scala | 23 --------------- 2 files changed, 2 insertions(+), 50 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 01be7a3adf5d..363027bd2ee6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -269,23 +269,6 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo return None } - // Make sure that the Date/Timestamp literal is a valid one - attrRef.dataType match { - case DateType if 
literal.dataType.isInstanceOf[StringType] => - val dateLiteral = DateTimeUtils.stringToDate(literal.value.asInstanceOf[UTF8String]) - if (dateLiteral.isEmpty) { - logDebug("[CBO] Date literal is wrong, No statistics for " + attrRef) - return None - } - case TimestampType if literal.dataType.isInstanceOf[StringType] => - val tsLiteral = DateTimeUtils.stringToTimestamp(literal.value.asInstanceOf[UTF8String]) - if (tsLiteral.isEmpty) { - logDebug("[CBO] Timestamp literal is wrong, No statistics for " + attrRef) - return None - } - case _ => - } - op match { case EqualTo(l, r) => evaluateEqualTo(attrRef, literal, update) case _ => @@ -381,19 +364,11 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo aColStat.copy(distinctCount = 1, min = newValue, max = newValue, nullCount = 0) case DateType => - val dateValue = literal.dataType match { - case StringType => - Some(Date.valueOf(literal.value.toString)) - case _ => Some(DateTimeUtils.toJavaDate(literal.value.toString.toInt)) - } + val dateValue = Some(DateTimeUtils.toJavaDate(literal.value.toString.toInt)) aColStat.copy(distinctCount = 1, min = dateValue, max = dateValue, nullCount = 0) case TimestampType => - val tsValue = literal.dataType match { - case StringType => - Some(Timestamp.valueOf(literal.value.toString)) - case _ => Some(DateTimeUtils.toJavaTimestamp(literal.value.toString.toLong)) - } + val tsValue = Some(DateTimeUtils.toJavaTimestamp(literal.value.toString.toLong)) aColStat.copy(distinctCount = 1, min = tsValue, max = tsValue, nullCount = 0) case _ => aColStat.copy(distinctCount = 1, nullCount = 0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index d2f0392839d1..f255733700b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -215,17 +215,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) } - test("cdate = '2017-01-02' ") { - val d20170102 = Date.valueOf("2017-01-02") - validateEstimatedStats( - arDate, - Filter(EqualTo(arDate, Literal("2017-01-02")), childStatsTestPlan(Seq(arDate))), - ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), - nullCount = 0, avgLen = 4, maxLen = 4), - Some(1L) - ) - } - test("cdate = cast('2017-01-02' AS DATE)") { val d20170102 = Date.valueOf("2017-01-02") val d20170102_SQLDate = DateTimeUtils.fromJavaDate(d20170102) @@ -276,18 +265,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) } - test("ctimestamp = '2017-01-01 02:00:00' ") { - val ts2017010102 = Timestamp.valueOf("2017-01-01 02:00:00") - validateEstimatedStats( - arTimestamp, - Filter(EqualTo(arTimestamp, Literal("2017-01-01 02:00:00")), - childStatsTestPlan(Seq(arTimestamp))), - ColumnStat(distinctCount = 1, min = Some(ts2017010102), max = Some(ts2017010102), - nullCount = 0, avgLen = 8, maxLen = 8), - Some(1L) - ) - } - test("ctimestamp = cast('2017-01-01 02:00:00' AS TIMESTAMP)") { val ts2017010102 = Timestamp.valueOf("2017-01-01 02:00:00") val ts2017010102_SQLTS = DateTimeUtils.fromJavaTimestamp(ts2017010102) From f54a6cef0f76fca3ef2152d16d4822df9dcca55b Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Wed, 15 Feb 2017 19:21:13 -0800 Subject: [PATCH 28/36] solve merge conflict in EstimationUtils --- 
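The resolution removes the duplicate ceil helper and trims the imports; getOutputSize(attributes, outputRowCount, attrStats) stays as the sizing entry point. As a rough mental model only -- outputSizeSketch below is a hypothetical stand-in, not Spark's exact formula -- the estimate scales as row count times an avgLen-derived row width:

    // Hypothetical sketch: estimated output size ~= rowCount * bytes per row,
    // approximating the per-row width from each column's avgLen.
    def outputSizeSketch(avgLens: Seq[Long], rowCount: BigInt): BigInt =
      (rowCount * avgLens.sum).max(1)  // never report a relation as zero bytes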
 .../plans/logical/statsEstimation/EstimationUtils.scala | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index f226944520d2..4d18b28be866 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 import scala.math.BigDecimal.RoundingMode
 
 import org.apache.spark.sql.catalyst.CatalystConf
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
 import org.apache.spark.sql.types.{DataType, StringType}
 
 
@@ -51,8 +51,6 @@ object EstimationUtils {
     AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _)))
   }
 
-  def ceil(bigDecimal: BigDecimal): BigInt = bigDecimal.setScale(0, RoundingMode.CEILING).toBigInt()
-
   def getOutputSize(
       attributes: Seq[Attribute],
       outputRowCount: BigInt,

From 07e6320a960b94df9b20bdd691a117dd75d37097 Mon Sep 17 00:00:00 2001
From: Ron Hu
Date: Thu, 16 Feb 2017 11:45:58 -0800
Subject: [PATCH 29/36] remove useless type checking since type coercion
 already did it

---
 .../statsEstimation/FilterEstimation.scala    | 48 +++++--------------
 .../FilterEstimationSuite.scala               | 35 ++--------------
 2 files changed, 16 insertions(+), 67 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 363027bd2ee6..034a9690680c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -309,6 +309,8 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
       case DateType =>
         val dateLiteral = litDataType match {
+          case DateType =>
+            DateTimeUtils.fromJavaDate(litValue.asInstanceOf[Date]).toString
           case StringType =>
             DateTimeUtils.stringToDate(litValue.asInstanceOf[UTF8String])
               .getOrElse(0).toString
@@ -318,6 +320,8 @@
       case TimestampType =>
         val tsLiteral = litDataType match {
+          case TimestampType =>
+            DateTimeUtils.fromJavaTimestamp(litValue.asInstanceOf[Timestamp]).toString
           case StringType =>
             DateTimeUtils.stringToTimestamp(litValue.asInstanceOf[UTF8String])
               .getOrElse(0).toString
@@ -364,11 +368,11 @@
             aColStat.copy(distinctCount = 1, min = newValue, max = newValue, nullCount = 0)
 
           case DateType =>
-            val dateValue = Some(DateTimeUtils.toJavaDate(literal.value.toString.toInt))
+            val dateValue = Some(literal.value)
            aColStat.copy(distinctCount = 1, min = dateValue, max = dateValue, nullCount = 0)
 
           case TimestampType =>
-            val tsValue =
Some(DateTimeUtils.toJavaTimestamp(literal.value.toString.toLong)) + val tsValue = Some(literal.value) aColStat.copy(distinctCount = 1, min = tsValue, max = tsValue, nullCount = 0) case _ => aColStat.copy(distinctCount = 1, nullCount = 0) @@ -544,41 +548,13 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } if (update) { - attrRef.dataType match { - case DateType => - val dateValue = literal.dataType match { - case StringType => - Date.valueOf(literal.value.toString) - case _ => DateTimeUtils.toJavaDate(literal.value.toString.toInt) - } - op match { - case _: GreaterThan => newMin = Some(dateValue) - case _: GreaterThanOrEqual => newMin = Some(dateValue) - case _: LessThan => newMax = Some(dateValue) - case _: LessThanOrEqual => newMax = Some(dateValue) - } - - case TimestampType => - val tsValue = literal.dataType match { - case StringType => - Timestamp.valueOf(literal.value.toString) - case _ => DateTimeUtils.toJavaTimestamp(literal.value.toString.toLong) - } - op match { - case _: GreaterThan => newMin = Some(tsValue) - case _: GreaterThanOrEqual => newMin = Some(tsValue) - case _: LessThan => newMax = Some(tsValue) - case _: LessThanOrEqual => newMax = Some(tsValue) - } - - case _ => - op match { - case _: GreaterThan => newMin = Some(literal.value) - case _: GreaterThanOrEqual => newMin = Some(literal.value) - case _: LessThan => newMax = Some(literal.value) - case _: LessThanOrEqual => newMax = Some(literal.value) - } + op match { + case _: GreaterThan => newMin = Some(literal.value) + case _: GreaterThanOrEqual => newMin = Some(literal.value) + case _: LessThan => newMax = Some(literal.value) + case _: LessThanOrEqual => newMax = Some(literal.value) } + newNdv = math.max(math.round(ndv.toDouble * percent), 1) val newStats = aColStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index f255733700b1..abfcd82192bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -217,10 +217,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cdate = cast('2017-01-02' AS DATE)") { val d20170102 = Date.valueOf("2017-01-02") - val d20170102_SQLDate = DateTimeUtils.fromJavaDate(d20170102) validateEstimatedStats( arDate, - Filter(EqualTo(arDate, Literal(d20170102_SQLDate, IntegerType)), + Filter(EqualTo(arDate, Literal(d20170102, DateType)), childStatsTestPlan(Seq(arDate))), ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), nullCount = 0, avgLen = 4, maxLen = 4), @@ -228,23 +227,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) } - test("cdate < '2017-01-03' ") { - val d20170103 = Date.valueOf("2017-01-03") - validateEstimatedStats( - arDate, - Filter(LessThan(arDate, Literal("2017-01-03")), childStatsTestPlan(Seq(arDate))), - ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), - nullCount = 0, avgLen = 4, maxLen = 4), - Some(3L) - ) - } - test("cdate < cast('2017-01-03' AS DATE)") { val d20170103 = Date.valueOf("2017-01-03") - val d20170103_SQLDate = DateTimeUtils.fromJavaDate(d20170103) validateEstimatedStats( arDate, - Filter(LessThan(arDate, Literal(d20170103_SQLDate, 
IntegerType)),
+      Filter(LessThan(arDate, Literal(d20170103, DateType)),
         childStatsTestPlan(Seq(arDate))),
       ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103),
         nullCount = 0, avgLen = 4, maxLen = 4),
@@ -267,10 +254,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
 
   test("ctimestamp = cast('2017-01-01 02:00:00' AS TIMESTAMP)") {
     val ts2017010102 = Timestamp.valueOf("2017-01-01 02:00:00")
-    val ts2017010102_SQLTS = DateTimeUtils.fromJavaTimestamp(ts2017010102)
     validateEstimatedStats(
       arTimestamp,
-      Filter(EqualTo(arTimestamp, Literal(ts2017010102_SQLTS, LongType)),
+      Filter(EqualTo(arTimestamp, Literal(ts2017010102, TimestampType)),
         childStatsTestPlan(Seq(arTimestamp))),
       ColumnStat(distinctCount = 1, min = Some(ts2017010102), max = Some(ts2017010102),
         nullCount = 0, avgLen = 8, maxLen = 8),
@@ -278,24 +264,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
     )
   }
 
-  test("ctimestamp < '2017-01-01 03:00:00' ") {
-    val ts2017010103 = Timestamp.valueOf("2017-01-01 03:00:00")
-    validateEstimatedStats(
-      arTimestamp,
-      Filter(LessThan(arTimestamp, Literal("2017-01-01 03:00:00")),
-        childStatsTestPlan(Seq(arTimestamp))),
-      ColumnStat(distinctCount = 2, min = Some(tsMin), max = Some(ts2017010103),
-        nullCount = 0, avgLen = 8, maxLen = 8),
-      Some(3L)
-    )
-  }
-
   test("ctimestamp < cast('2017-01-01 03:00:00' AS TIMESTAMP)") {
     val ts2017010103 = Timestamp.valueOf("2017-01-01 03:00:00")
-    val ts2017010102_SQLTS = DateTimeUtils.fromJavaTimestamp(ts2017010103)
     validateEstimatedStats(
       arTimestamp,
-      Filter(LessThan(arTimestamp, Literal(ts2017010102_SQLTS, LongType)),
+      Filter(LessThan(arTimestamp, Literal(ts2017010103, TimestampType)),
         childStatsTestPlan(Seq(arTimestamp))),
       ColumnStat(distinctCount = 2, min = Some(tsMin), max = Some(ts2017010103),
         nullCount = 0, avgLen = 8, maxLen = 8),

From 7ba660940f95dd6d6c0ca8fa6cd0c0402eb8c516 Mon Sep 17 00:00:00 2001
From: Ron Hu
Date: Fri, 17 Feb 2017 19:35:53 -0800
Subject: [PATCH 30/36] add string column tests. remove float type tests.
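This also changes calculateFilterSelectivity to return Option[Double], so an unsupported predicate now propagates None instead of a made-up percentage. A standalone sketch of the combination rules introduced below, assuming independent predicates (andSel/orSel are illustrative names, not patch code):

    // AND: multiply known selectivities; a None (unsupported) side behaves as 100%
    def andSel(p1: Option[Double], p2: Option[Double]): Option[Double] = (p1, p2) match {
      case (Some(a), Some(b)) => Some(a * b)
      case (Some(a), None)    => Some(a)
      case (None, Some(b))    => Some(b)
      case (None, None)       => None
    }

    // OR: inclusion-exclusion capped at 1.0; an unsupported side degrades to 1.0
    def orSel(p1: Option[Double], p2: Option[Double]): Option[Double] = (p1, p2) match {
      case (Some(a), Some(b)) => Some(math.min(1.0, a + b - a * b))
      case (None, None)       => None
      case _                  => Some(1.0)
    }

For example, with selectivities 0.3 and 0.5, AND yields 0.15 and OR yields 0.3 + 0.5 - 0.15 = 0.65.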
--- .../statsEstimation/FilterEstimation.scala | 196 ++++++++---------- .../FilterEstimationSuite.scala | 55 +++-- 2 files changed, 117 insertions(+), 134 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 034a9690680c..8d49abf7a50d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -30,9 +30,6 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -/** - * @param catalystConf a configuration showing if CBO is enabled - */ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging { /** @@ -63,7 +60,11 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo mutableColStats = mutable.Map(stats.attributeStats.map(kv => (kv._1.exprId, kv._2)).toSeq: _*) // estimate selectivity of this filter predicate - val filterSelectivity: Double = calculateConditions(plan.condition) + val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match { + case Some(percent) => percent + // for not-supported condition, set filter selectivity to a conservative estimate 100% + case None => 1.0 + } // attributeStats has mapping Attribute-to-ColumnStat. // mutableColStats has mapping ExprId-to-ColumnStat. @@ -77,9 +78,8 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val filteredRowCount: BigInt = EstimationUtils.ceil(BigDecimal(stats.rowCount.get) * filterSelectivity) - val filteredSizeInBytes: BigInt = EstimationUtils.ceil(BigDecimal( - EstimationUtils.getOutputSize(plan.output, filteredRowCount, newColStats) - )) + val filteredSizeInBytes = + EstimationUtils.getOutputSize(plan.output, filteredRowCount, newColStats) Some(stats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount), attributeStats = newColStats)) @@ -99,29 +99,44 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * @return a double value to show the percentage of rows meeting a given condition. * Returns 1.0 (a conservative filter estimate) if a condition is not supported. 
*/ - def calculateConditions(condition: Expression, update: Boolean = true): Double = { + def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = { condition match { case And(cond1, cond2) => - val p1 = calculateConditions(cond1, update) - val p2 = calculateConditions(cond2, update) - p1 * p2 + val p1 = calculateFilterSelectivity(cond1, update) + val p2 = calculateFilterSelectivity(cond2, update) + p1 match { + case Some(percent1) => p2 match { + case Some(percent2) => Some(percent1 * percent2) + case None => Some(percent1) + } + case None => p2 match { + case Some(percent2) => Some(percent2) + case None => None + } + } case Or(cond1, cond2) => - val p1 = calculateConditions(cond1, update = false) - val p2 = calculateConditions(cond2, update = false) - math.min(1.0, p1 + p2 - (p1 * p2)) + val p1 = calculateFilterSelectivity(cond1, update = false) + val p2 = calculateFilterSelectivity(cond2, update = false) + p1 match { + case Some(percent1) => p2 match { + case Some(percent2) => Some(math.min(1.0, percent1 + percent2 - (percent1 * percent2))) + case None => Some(1.0) + } + case None => p2 match { + case Some(percent2) => Some(1.0) + case None => None + } + } - case Not(cond) => calculateSingleCondition(cond, update = false) match { - case Some(percent) => 1.0 - percent + case Not(cond) => calculateFilterSelectivity(cond, update = false) match { + case Some(percent) => Some(1.0 - percent) // for not-supported condition, set filter selectivity to a conservative estimate 100% - case None => 1.0 - } - case _ => calculateSingleCondition(condition, update) match { - case Some(percent) => percent - // for not-supported condition, set filter selectivity to a conservative estimate 100% - case None => 1.0 + case None => None } + case _ => + calculateSingleCondition(condition, update) } } @@ -167,7 +182,8 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case op @ GreaterThanOrEqual(l: Literal, ar: AttributeReference) => evaluateBinary(LessThanOrEqual(ar, l), ar, l, update) - case In(ar: AttributeReference, expList) if !expList.exists(!_.isInstanceOf[Literal]) => + case In(ar: AttributeReference, expList) + if expList.forall(e => e.isInstanceOf[Literal]) => // Expression [In (value, seq[Literal])] will be replaced with optimized version // [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10. // Here we convert In into InSet anyway, because they share the same processing logic. @@ -177,21 +193,11 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case InSet(ar: AttributeReference, set) => evaluateInSet(ar, set, update) - // It's difficult to estimate IsNull after outer joins. Hence, - // we support IsNull and IsNotNull only when the child is a leaf node (table). case IsNull(ar: AttributeReference) => - if (plan.child.isInstanceOf[LeafNode ]) { - evaluateIsNull(ar, isNull = true, update) - } else { - None - } + evaluateIsNull(ar, isNull = true, update) case IsNotNull(ar: AttributeReference) => - if (plan.child.isInstanceOf[LeafNode ]) { - evaluateIsNull(ar, isNull = false, update) - } else { - None - } + evaluateIsNull(ar, isNull = false, update) case _ => // TODO: it's difficult to support string operators without advanced statistics. @@ -209,7 +215,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * @param isNull set to true for "IS NULL" condition. 
set to false for "IS NOT NULL" condition * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for subsequent conditions - * @return a doube value to show the percentage of rows meeting a given condition + * @return an optional double value to show the percentage of rows meeting a given condition * It returns None if no statistics collected for a given column. */ def evaluateIsNull( @@ -255,7 +261,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for subsequent conditions - * @return a doube value to show the percentage of rows meeting a given condition + * @return an optional double value to show the percentage of rows meeting a given condition * It returns None if no statistics exists for a given column or wrong value. */ def evaluateBinary( @@ -285,48 +291,34 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } /** - * This method converts a numeric or Literal value of numeric type to a BigDecimal value. + * This method converts a Literal value of numeric type to a BigDecimal value. * In order to avoid type casting error such as Java int to Java long, we need to * convert a numeric integer value to String, and then convert it to long, * and then convert it to BigDecimal. * * @param attrDataType the column data type * @param litValue the literal value - * @param litDataType the data type of literal value * @return a BigDecimal value */ def numericLiteralToBigDecimal( attrDataType: DataType, - litValue: Any, - litDataType: DataType) + litValue: Any) : BigDecimal = { attrDataType match { case _: IntegralType => - BigDecimal(java.lang.Long.valueOf(litValue.toString)) + BigDecimal(litValue.toString) case _: FractionalType => BigDecimal(litValue.toString.toDouble) case DateType => - val dateLiteral = litDataType match { - case DateType => - DateTimeUtils.fromJavaDate(litValue.asInstanceOf[Date]).toString - case StringType => - DateTimeUtils.stringToDate(litValue.asInstanceOf[UTF8String]) - .getOrElse(0).toString - case _ => litValue.toString - } + val dateLiteral = + DateTimeUtils.fromJavaDate(litValue.asInstanceOf[Date]).toString BigDecimal(java.lang.Long.valueOf(dateLiteral)) case TimestampType => - val tsLiteral = litDataType match { - case TimestampType => - DateTimeUtils.fromJavaTimestamp(litValue.asInstanceOf[Timestamp]).toString - case StringType => - DateTimeUtils.stringToTimestamp(litValue.asInstanceOf[UTF8String]) - .getOrElse(0).toString - case _ => litValue.toString - } + val tsLiteral = + DateTimeUtils.fromJavaTimestamp(litValue.asInstanceOf[Timestamp]).toString BigDecimal(java.lang.Long.valueOf(tsLiteral)) } } @@ -339,7 +331,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for subsequent conditions - * @return a doube value to show the percentage of rows meeting a given condition + * @return an optional double value to show the percentage of rows meeting a given condition */ def evaluateEqualTo( attrRef: AttributeReference, @@ -395,7 +387,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * @param hSet a set of literal values * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for 
subsequent conditions - * @return a doube value to show the percentage of rows meeting a given condition + * @return an optional double value to show the percentage of rows meeting a given condition * It returns None if no statistics exists for a given column. */ @@ -412,71 +404,49 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val aColStat = mutableColStats(attrRef.exprId) val ndv = aColStat.distinctCount val aType = attrRef.dataType + var newNdv: Long = 0 // use [min, max] to filter the original hSet - val validQuerySet = aType match { + aType match { case _: NumericType | DateType | TimestampType => val statsRange = Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange] - val hSetMap = hSet.map(e => - if (e.isInstanceOf[String]) { - val utf8String = UTF8String.fromString(e.asInstanceOf[String]) - numericLiteralToBigDecimal(aType, utf8String, StringType) - } else { - numericLiteralToBigDecimal(aType, e, aType) - }) - hSetMap.filter(e => e >= statsRange.min && e <= statsRange.max) - // We assume the whole set since there is no min/max information for String/Binary type - case StringType | BinaryType => hSet - } - if (validQuerySet.isEmpty) { - return Some(0.0) - } + // To faciliate finding the min and max values in hSet, we map hSet values to BigDecimal. + // Using hSetBigdec, we can find the min and max values quickly in the ordered hSetBigdec. + // We use hSetBigdecToAnyMap to help us find the original hSet value. + val hSetBigdecToAnyMap: Map[BigDecimal, Any] = + hSet.map(e => numericLiteralToBigDecimal(aType, e) -> e).toMap + val hSetBigdec = hSet.map(e => numericLiteralToBigDecimal(aType, e)) + val validQuerySet = hSetBigdec.filter(e => e >= statsRange.min && e <= statsRange.max) - // newNdv should be no greater than the old ndv. For example, column has only 2 values - // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. - val newNdv = math.min(validQuerySet.size.toLong, ndv.longValue()) - val(newMax, newMin) = aType match { - case _: NumericType => - val tmpSet: Set[Double] = validQuerySet.map(e => e.toString.toDouble) - (Some(tmpSet.max), Some(tmpSet.min)) - case DateType => - if (hSet.forall(e => e.isInstanceOf[String])) { - val dateMax = Date.valueOf(hSet.asInstanceOf[Set[String]].max) - val dateMin = Date.valueOf(hSet.asInstanceOf[Set[String]].min) - (Some(dateMax), Some(dateMin)) - } else { - val tmpSet: Set[Long] = validQuerySet.map(e => e.toString.toLong) - (Some(tmpSet.max), Some(tmpSet.min)) + if (validQuerySet.isEmpty) { + return Some(0.0) } - case TimestampType => - if (hSet.forall(e => e.isInstanceOf[String])) { - val dateMax = Timestamp.valueOf(hSet.asInstanceOf[Set[String]].max) - val dateMin = Timestamp.valueOf(hSet.asInstanceOf[Set[String]].min) - (Some(dateMax), Some(dateMin)) - } else { - val tmpSet: Set[Long] = validQuerySet.map(e => e.toString.toLong) - (Some(tmpSet.max), Some(tmpSet.min)) + + val newMax = Some(hSetBigdecToAnyMap(validQuerySet.max)) + val newMin = Some(hSetBigdecToAnyMap(validQuerySet.min)) + // newNdv should not be greater than the old ndv. For example, column has only 2 values + // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. 
+ newNdv = math.min(validQuerySet.size.toLong, ndv.longValue()) + if (update) { + val newStats = aColStat.copy(distinctCount = newNdv, min = newMin, + max = newMax, nullCount = 0) + mutableColStats += (attrRef.exprId -> newStats) } - case _ => - (None, None) - } - if (update) { - val newStats = attrRef.dataType match { - case _: NumericType | DateType | TimestampType => - aColStat.copy(distinctCount = newNdv, min = newMin, - max = newMax, nullCount = 0) - case StringType | BinaryType => - aColStat.copy(distinctCount = newNdv, nullCount = 0) - } - mutableColStats += (attrRef.exprId -> newStats) + // We assume the whole set since there is no min/max information for String/Binary type + case StringType | BinaryType => + newNdv = math.min(hSet.size.toLong, ndv.longValue()) + if (update) { + val newStats = aColStat.copy(distinctCount = newNdv, nullCount = 0) + mutableColStats += (attrRef.exprId -> newStats) + } } // return the filter selectivity. Without advanced statistics such as histograms, // we have to assume uniform distribution. - Some(math.min(1.0, validQuerySet.size / ndv.toDouble)) + Some(math.min(1.0, newNdv.toDouble / ndv.toDouble)) } /** @@ -488,7 +458,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for subsequent conditions - * @return a doube value to show the percentage of rows meeting a given condition + * @return an optional double value to show the percentage of rows meeting a given condition */ def evaluateBinaryForNumeric( op: BinaryComparison, @@ -504,7 +474,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange] val literalValueBD = - numericLiteralToBigDecimal(attrRef.dataType, literal.value, literal.dataType) + numericLiteralToBigDecimal(attrRef.dataType, literal.value) // determine the overlapping degree between predicate range and column's range val (noOverlap: Boolean, completeOverlap: Boolean) = op match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index abfcd82192bf..70e2d9cac4b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -53,23 +53,24 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val childColStatTimestamp = ColumnStat(distinctCount = 10, min = Some(tsMin), max = Some(tsMax), nullCount = 0, avgLen = 8, maxLen = 8) - // Fourth column cdate has 10 values from 0.20 through 2.00 at increment of 0.2. + // Fourth column cdecimal has 10 values from 0.20 through 2.00 at increment of 0.2. 
val decMin = new java.math.BigDecimal("0.200000000000000000") val decMax = new java.math.BigDecimal("2.000000000000000000") val arDecimal = AttributeReference("cdecimal", DecimalType(12, 2))() val childColStatDecimal = ColumnStat(distinctCount = 10, min = Some(decMin), max = Some(decMax), nullCount = 0, avgLen = 8, maxLen = 8) - // Fifth column cfloat has 10 float values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 - val arFloat = AttributeReference("cfloat", FloatType)() - val childColStatFloat = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), - nullCount = 0, avgLen = 4, maxLen = 4) - - // Sixth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 - val arDouble = AttributeReference("cdouble", FloatType)() + // Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 + val arDouble = AttributeReference("cdouble", DoubleType)() val childColStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), nullCount = 0, avgLen = 8, maxLen = 8) + // Sixth column cstring has 10 String values: + // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9" + val arString = AttributeReference("cstring", StringType)() + val childColStatString = ColumnStat(distinctCount = 10, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2) + test("cint = 2") { validateEstimatedStats( arInt, @@ -241,10 +242,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cdate IN ('2017-01-03', '2017-01-04', '2017-01-05')") { val d20170103 = Date.valueOf("2017-01-03") + val d20170104 = Date.valueOf("2017-01-04") val d20170105 = Date.valueOf("2017-01-05") validateEstimatedStats( arDate, - Filter(InSet(arDate, Set("2017-01-03", "2017-01-04", "2017-01-05")), + Filter(InSet(arDate, Set(d20170103, d20170104, d20170105)), childStatsTestPlan(Seq(arDate))), ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), nullCount = 0, avgLen = 4, maxLen = 4), @@ -300,16 +302,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) } - test("cfloat < 3.0") { - validateEstimatedStats( - arFloat, - Filter(LessThan(arFloat, Literal(3.0)), childStatsTestPlan(Seq(arFloat))), - ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), - nullCount = 0, avgLen = 4, maxLen = 4), - Some(3L) - ) - } - test("cdouble < 3.0") { validateEstimatedStats( arDouble, @@ -320,6 +312,27 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) } + test("cstring = 'A2'") { + validateEstimatedStats( + arString, + Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString))), + ColumnStat(distinctCount = 1, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2), + Some(1L) + ) + } + + // There is no min/max statistics for String type. We estimate 10 rows returned. + test("cstring < 'A2'") { + validateEstimatedStats( + arString, + Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString))), + ColumnStat(distinctCount = 10, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2), + Some(10L) + ) + } + // This is a corner test case. We want to test if we can handle the case when the number of // valid values in IN clause is greater than the number of distinct values for a given column. // For example, column has only 2 distinct values 1 and 6. 
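Working that corner case through evaluateInSet (an arithmetic sketch only): the column has ndv = 2 with values 1 and 6, the predicate is IN (1, 2, 3, 4, 5), and every literal falls inside [min, max] = [1, 6]:

    val ndv = 2L
    val validQuerySetSize = 5L                               // all 5 literals lie within [1, 6]
    val newNdv = math.min(validQuerySetSize, ndv)            // capped at the old ndv: 2
    val selectivity = math.min(1.0, newNdv.toDouble / ndv)   // 1.0, i.e. all rows may qualify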
@@ -350,8 +363,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { arDate -> childColStatDate, arTimestamp -> childColStatTimestamp, arDecimal -> childColStatDecimal, - arFloat -> childColStatFloat, - arDouble -> childColStatDouble + arDouble -> childColStatDouble, + arString -> childColStatString )) ) } From 1f3619f963e9dd4e85d7680c04770d374a81710e Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Sat, 18 Feb 2017 13:52:39 -0800 Subject: [PATCH 31/36] improve readability --- .../statsEstimation/FilterEstimation.scala | 37 +++++++------------ 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 8d49abf7a50d..f3eddf77ad8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -97,37 +97,27 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * @param update a boolean flag to specify if we need to update ColumnStat of a column * for subsequent conditions * @return a double value to show the percentage of rows meeting a given condition. - * Returns 1.0 (a conservative filter estimate) if a condition is not supported. + * It returns None if the condition is not supported. */ def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = { condition match { case And(cond1, cond2) => - val p1 = calculateFilterSelectivity(cond1, update) - val p2 = calculateFilterSelectivity(cond2, update) - p1 match { - case Some(percent1) => p2 match { - case Some(percent2) => Some(percent1 * percent2) - case None => Some(percent1) - } - case None => p2 match { - case Some(percent2) => Some(percent2) - case None => None - } + (calculateFilterSelectivity(cond1, update), calculateFilterSelectivity(cond2, update)) + match { + case (Some(p1), Some(p2)) => Some(p1 * p2) + case (Some(p1), None) => Some(p1) + case (None, Some(p2)) => Some(p2) + case (None, None) => None } case Or(cond1, cond2) => - val p1 = calculateFilterSelectivity(cond1, update = false) - val p2 = calculateFilterSelectivity(cond2, update = false) - p1 match { - case Some(percent1) => p2 match { - case Some(percent2) => Some(math.min(1.0, percent1 + percent2 - (percent1 * percent2))) - case None => Some(1.0) - } - case None => p2 match { - case Some(percent2) => Some(1.0) - case None => None - } + (calculateFilterSelectivity(cond1, update), calculateFilterSelectivity(cond2, update)) + match { + case (Some(p1), Some(p2)) => Some(math.min(1.0, p1 + p2 - (p1 * p2))) + case (Some(p1), None) => Some(1.0) + case (None, Some(p2)) => Some(1.0) + case (None, None) => None } case Not(cond) => calculateFilterSelectivity(cond, update = false) match { @@ -135,6 +125,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo // for not-supported condition, set filter selectivity to a conservative estimate 100% case None => None } + case _ => calculateSingleCondition(condition, update) } From 6411d2122f06213bb28aee53e9bf2b278e9d168e Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Sat, 18 Feb 2017 14:24:50 -0800 Subject: [PATCH 32/36] specify update = false for Or condition --- .../plans/logical/statsEstimation/FilterEstimation.scala | 3 ++- 1 file 
changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index f3eddf77ad8d..148663f44049 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -112,7 +112,8 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } case Or(cond1, cond2) => - (calculateFilterSelectivity(cond1, update), calculateFilterSelectivity(cond2, update)) + (calculateFilterSelectivity(cond1, update = false), + calculateFilterSelectivity(cond2, update = false)) match { case (Some(p1), Some(p2)) => Some(math.min(1.0, p1 + p2 - (p1 * p2))) case (Some(p1), None) => Some(1.0) From 11b6a0b2aab4cf14358c059381a8eee577b52079 Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Sun, 19 Feb 2017 13:53:13 -0800 Subject: [PATCH 33/36] remove the unused import. save with Unix style line ending --- .../plans/logical/statsEstimation/FilterEstimation.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 148663f44049..06d3789139d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging { @@ -112,9 +111,10 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } case Or(cond1, cond2) => - (calculateFilterSelectivity(cond1, update = false), - calculateFilterSelectivity(cond2, update = false)) - match { + // For ease of debugging, we compute percent1 and percent2 in 2 statements. 
+ val percent1 = calculateFilterSelectivity(cond1, update = false) + val percent2 = calculateFilterSelectivity(cond2, update = false) + (percent1, percent2) match { case (Some(p1), Some(p2)) => Some(math.min(1.0, p1 + p2 - (p1 * p2))) case (Some(p1), None) => Some(1.0) case (None, Some(p2)) => Some(1.0) From 298d2558ba84c7c1b22a6f59e005fe61e2b13cfe Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Tue, 21 Feb 2017 10:30:30 -0800 Subject: [PATCH 34/36] update date column test case --- .../plans/logical/statsEstimation/FilterEstimation.scala | 2 +- .../sql/catalyst/statsEstimation/FilterEstimationSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 06d3789139d6..42b8ec8cd6d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -404,7 +404,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val statsRange = Range(aColStat.min, aColStat.max, aType).asInstanceOf[NumericRange] - // To faciliate finding the min and max values in hSet, we map hSet values to BigDecimal. + // To facilitate finding the min and max values in hSet, we map hSet values to BigDecimal. // Using hSetBigdec, we can find the min and max values quickly in the ordered hSetBigdec. // We use hSetBigdecToAnyMap to help us find the original hSet value. val hSetBigdecToAnyMap: Map[BigDecimal, Any] = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 70e2d9cac4b5..12674d21dcf6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -220,7 +220,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val d20170102 = Date.valueOf("2017-01-02") validateEstimatedStats( arDate, - Filter(EqualTo(arDate, Literal(d20170102, DateType)), + Filter(EqualTo(arDate, Literal(d20170102)), childStatsTestPlan(Seq(arDate))), ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), nullCount = 0, avgLen = 4, maxLen = 4), From eac69aff9c15b94cd98271ebe787fec530d3e207 Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Thu, 23 Feb 2017 16:36:12 -0800 Subject: [PATCH 35/36] clean up internal/external type conversion --- .../statsEstimation/FilterEstimation.scala | 90 +++++++----------- .../plans/logical/statsEstimation/Range.scala | 15 +++ .../FilterEstimationSuite.scala | 92 +++++++++++-------- 3 files changed, 103 insertions(+), 94 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 42b8ec8cd6d7..80005ebb4024 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -283,35 +283,24 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } /** - * This method converts a Literal value of numeric type to a BigDecimal value. - * In order to avoid type casting error such as Java int to Java long, we need to - * convert a numeric integer value to String, and then convert it to long, - * and then convert it to BigDecimal. - * - * @param attrDataType the column data type - * @param litValue the literal value - * @return a BigDecimal value - */ - def numericLiteralToBigDecimal( - attrDataType: DataType, - litValue: Any) - : BigDecimal = { + * For a SQL data type, its internal data type may be different from its external type. + * For DateType, its internal type is Int, and its external data type is Java Date type. + * The min/max values in ColumnStat are saved in their corresponding external type. + * + * @param attrDataType the column data type + * @param litValue the literal value + * @return a BigDecimal value + */ + def convertBoundValue(attrDataType: DataType, litValue: Any): Option[Any] = { attrDataType match { - case _: IntegralType => - BigDecimal(litValue.toString) - - case _: FractionalType => - BigDecimal(litValue.toString.toDouble) - case DateType => - val dateLiteral = - DateTimeUtils.fromJavaDate(litValue.asInstanceOf[Date]).toString - BigDecimal(java.lang.Long.valueOf(dateLiteral)) - + Some(DateTimeUtils.toJavaDate(litValue.toString.toInt)) case TimestampType => - val tsLiteral = - DateTimeUtils.fromJavaTimestamp(litValue.asInstanceOf[Timestamp]).toString - BigDecimal(java.lang.Long.valueOf(tsLiteral)) + Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong)) + case StringType | BinaryType => + None + case _ => + Some(litValue) } } @@ -338,29 +327,17 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo // We currently don't store min/max for binary/string type. // Hence, we assume it is in boundary for binary/string type. val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType) - val litRange = Range(Some(literal.value), Some(literal.value), literal.dataType) - val inBoundary: Boolean = Range.isIntersected(statsRange, litRange) + val inBoundary: Boolean = Range.rangeContainsLiteral(statsRange, literal) if (inBoundary) { if (update) { // We update ColumnStat structure after apply this equality predicate. // Set distinctCount to 1. Set nullCount to 0. - val newStats = attrRef.dataType match { - case _: NumericType => - val newValue = Some(literal.value) - aColStat.copy(distinctCount = 1, min = newValue, - max = newValue, nullCount = 0) - case DateType => - val dateValue = Some(literal.value) - aColStat.copy(distinctCount = 1, min = dateValue, - max = dateValue, nullCount = 0) - case TimestampType => - val tsValue = Some(literal.value) - aColStat.copy(distinctCount = 1, min = tsValue, - max = tsValue, nullCount = 0) - case _ => aColStat.copy(distinctCount = 1, nullCount = 0) - } + // Need to save new min/max using the external type value of the literal + val newValue = convertBoundValue(attrRef.dataType, literal.value) + val newStats = aColStat.copy(distinctCount = 1, min = newValue, + max = newValue, nullCount = 0) mutableColStats += (attrRef.exprId -> newStats) } @@ -406,18 +383,20 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo // To facilitate finding the min and max values in hSet, we map hSet values to BigDecimal. 
// Using hSetBigdec, we can find the min and max values quickly in the ordered hSetBigdec. + val hSetBigdec = hSet.map(e => BigDecimal(e.toString)) + val validQuerySet = hSetBigdec.filter(e => e >= statsRange.min && e <= statsRange.max) // We use hSetBigdecToAnyMap to help us find the original hSet value. val hSetBigdecToAnyMap: Map[BigDecimal, Any] = - hSet.map(e => numericLiteralToBigDecimal(aType, e) -> e).toMap - val hSetBigdec = hSet.map(e => numericLiteralToBigDecimal(aType, e)) - val validQuerySet = hSetBigdec.filter(e => e >= statsRange.min && e <= statsRange.max) + hSet.map(e => BigDecimal(e.toString) -> e).toMap if (validQuerySet.isEmpty) { return Some(0.0) } - val newMax = Some(hSetBigdecToAnyMap(validQuerySet.max)) - val newMin = Some(hSetBigdecToAnyMap(validQuerySet.min)) + // Need to save new min/max using the external type value of the literal + val newMax = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.max)) + val newMin = convertBoundValue(attrRef.dataType, hSetBigdecToAnyMap(validQuerySet.min)) + // newNdv should not be greater than the old ndv. For example, column has only 2 values // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. newNdv = math.min(validQuerySet.size.toLong, ndv.longValue()) @@ -465,10 +444,8 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val statsRange = Range(aColStat.min, aColStat.max, attrRef.dataType).asInstanceOf[NumericRange] - val literalValueBD = - numericLiteralToBigDecimal(attrRef.dataType, literal.value) - // determine the overlapping degree between predicate range and column's range + val literalValueBD = BigDecimal(literal.value.toString) val (noOverlap: Boolean, completeOverlap: Boolean) = op match { case _: LessThan => (literalValueBD <= statsRange.min, literalValueBD > statsRange.max) @@ -509,12 +486,15 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo else (maxToDouble - literalToDouble) / (maxToDouble - minToDouble) } + // Need to save new min/max using the external type value of the literal + val newValue = convertBoundValue(attrRef.dataType, literal.value) + if (update) { op match { - case _: GreaterThan => newMin = Some(literal.value) - case _: GreaterThanOrEqual => newMin = Some(literal.value) - case _: LessThan => newMax = Some(literal.value) - case _: LessThanOrEqual => newMax = Some(literal.value) + case _: GreaterThan => newMin = newValue + case _: GreaterThanOrEqual => newMin = newValue + case _: LessThan => newMax = newValue + case _: LessThanOrEqual => newMax = newValue } newNdv = math.max(math.round(ndv.toDouble * percent), 1) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala index 4a346c924db6..455711453272 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import java.math.{BigDecimal => JDecimal} import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _} @@ -57,6 +58,20 @@ object Range { n1.min.compareTo(n2.max) <= 0 && 
n1.max.compareTo(n2.min) >= 0 } + def rangeContainsLiteral(r: Range, lit: Literal): Boolean = r match { + case _: DefaultRange => true + case _: NullRange => false + case n: NumericRange => + val literalValue = if (lit.dataType.isInstanceOf[BooleanType]) { + if (lit.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0) + } else { + assert(lit.dataType.isInstanceOf[NumericType] || lit.dataType.isInstanceOf[DateType] || + lit.dataType.isInstanceOf[TimestampType]) + new JDecimal(lit.value.toString) + } + n.min.compareTo(literalValue) <= 0 && n.max.compareTo(literalValue) >= 0 + } + /** * Intersected results of two ranges. This is only for two overlapped ranges. * The outputs are the intersected min/max values. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 12674d21dcf6..f139c9e28c6c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -53,11 +53,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val childColStatTimestamp = ColumnStat(distinctCount = 10, min = Some(tsMin), max = Some(tsMax), nullCount = 0, avgLen = 8, maxLen = 8) - // Fourth column cdecimal has 10 values from 0.20 through 2.00 at increment of 0.2. + // Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. val decMin = new java.math.BigDecimal("0.200000000000000000") - val decMax = new java.math.BigDecimal("2.000000000000000000") - val arDecimal = AttributeReference("cdecimal", DecimalType(12, 2))() - val childColStatDecimal = ColumnStat(distinctCount = 10, min = Some(decMin), max = Some(decMax), + val decMax = new java.math.BigDecimal("0.800000000000000000") + val arDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() + val childColStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), nullCount = 0, avgLen = 8, maxLen = 8) // Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 @@ -74,7 +74,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint = 2") { validateEstimatedStats( arInt, - Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt))), + Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4), Some(1L) @@ -85,7 +85,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // This is an out-of-range case since 0 is outside the range [min, max] validateEstimatedStats( arInt, - Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt))), + Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(0L) @@ -95,7 +95,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint < 3") { validateEstimatedStats( arInt, - Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt))), + Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), nullCount = 0, avgLen = 4, maxLen = 4), Some(3L) @@ -106,7 +106,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // 
This is a corner case since literal 0 is smaller than min. validateEstimatedStats( arInt, - Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt))), + Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(0L) @@ -116,7 +116,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint <= 3") { validateEstimatedStats( arInt, - Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt))), + Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), nullCount = 0, avgLen = 4, maxLen = 4), Some(3L) @@ -126,7 +126,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint > 6") { validateEstimatedStats( arInt, - Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt))), + Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(5L) @@ -137,7 +137,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // This is a corner case since max value is 10. validateEstimatedStats( arInt, - Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt))), + Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(0L) @@ -147,7 +147,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint >= 6") { validateEstimatedStats( arInt, - Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt))), + Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(5L) @@ -157,7 +157,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint IS NULL") { validateEstimatedStats( arInt, - Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt))), + Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0, avgLen = 4, maxLen = 4), Some(0L) @@ -167,7 +167,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint IS NOT NULL") { validateEstimatedStats( arInt, - Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt))), + Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(10L) @@ -178,7 +178,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = And(GreaterThan(arInt, Literal(3)), LessThanOrEqual(arInt, Literal(6))) validateEstimatedStats( arInt, - Filter(condition, childStatsTestPlan(Seq(arInt))), + Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), nullCount = 0, avgLen = 4, maxLen = 4), Some(4L) @@ -189,7 +189,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(arInt, Literal(3)), EqualTo(arInt, Literal(6))) validateEstimatedStats( arInt, - Filter(condition, childStatsTestPlan(Seq(arInt))), + Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(2L) @@ 
-199,7 +199,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint IN (3, 4, 5)") { validateEstimatedStats( arInt, - Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt))), + Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), nullCount = 0, avgLen = 4, maxLen = 4), Some(3L) @@ -209,7 +209,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( arInt, - Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt))), + Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)), ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), Some(7L) @@ -221,7 +221,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( arDate, Filter(EqualTo(arDate, Literal(d20170102)), - childStatsTestPlan(Seq(arDate))), + childStatsTestPlan(Seq(arDate), 10L)), ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), nullCount = 0, avgLen = 4, maxLen = 4), Some(1L) @@ -232,22 +232,23 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val d20170103 = Date.valueOf("2017-01-03") validateEstimatedStats( arDate, - Filter(LessThan(arDate, Literal(d20170103, DateType)), - childStatsTestPlan(Seq(arDate))), + Filter(LessThan(arDate, Literal(d20170103)), + childStatsTestPlan(Seq(arDate), 10L)), ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), nullCount = 0, avgLen = 4, maxLen = 4), Some(3L) ) } - test("cdate IN ('2017-01-03', '2017-01-04', '2017-01-05')") { + test("""cdate IN ( cast('2017-01-03' AS DATE), + cast('2017-01-04' AS DATE), cast('2017-01-05' AS DATE) )""") { val d20170103 = Date.valueOf("2017-01-03") val d20170104 = Date.valueOf("2017-01-04") val d20170105 = Date.valueOf("2017-01-05") validateEstimatedStats( arDate, - Filter(InSet(arDate, Set(d20170103, d20170104, d20170105)), - childStatsTestPlan(Seq(arDate))), + Filter(In(arDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))), + childStatsTestPlan(Seq(arDate), 10L)), ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), nullCount = 0, avgLen = 4, maxLen = 4), Some(3L) @@ -258,8 +259,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val ts2017010102 = Timestamp.valueOf("2017-01-01 02:00:00") validateEstimatedStats( arTimestamp, - Filter(EqualTo(arTimestamp, Literal(ts2017010102, TimestampType)), - childStatsTestPlan(Seq(arTimestamp))), + Filter(EqualTo(arTimestamp, Literal(ts2017010102)), + childStatsTestPlan(Seq(arTimestamp), 10L)), ColumnStat(distinctCount = 1, min = Some(ts2017010102), max = Some(ts2017010102), nullCount = 0, avgLen = 8, maxLen = 8), Some(1L) @@ -270,20 +271,20 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val ts2017010103 = Timestamp.valueOf("2017-01-01 03:00:00") validateEstimatedStats( arTimestamp, - Filter(LessThan(arTimestamp, Literal(ts2017010103, TimestampType)), - childStatsTestPlan(Seq(arTimestamp))), + Filter(LessThan(arTimestamp, Literal(ts2017010103)), + childStatsTestPlan(Seq(arTimestamp), 10L)), ColumnStat(distinctCount = 2, min = Some(tsMin), max = Some(ts2017010103), nullCount = 0, avgLen = 8, maxLen = 8), Some(3L) ) } - test("cdecimal = 0.40") { + test("cdecimal = 0.400000000000000000") { val dec_0_40 = new java.math.BigDecimal("0.400000000000000000") validateEstimatedStats( arDecimal, - 
Filter(EqualTo(arDecimal, Literal(dec_0_40, DecimalType(12, 2))), - childStatsTestPlan(Seq(arDecimal))), + Filter(EqualTo(arDecimal, Literal(dec_0_40)), + childStatsTestPlan(Seq(arDecimal), 4L)), ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), nullCount = 0, avgLen = 8, maxLen = 8), Some(1L) @@ -295,8 +296,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( arDecimal, Filter(LessThan(arDecimal, Literal(dec_0_60, DecimalType(12, 2))), - childStatsTestPlan(Seq(arDecimal))), - ColumnStat(distinctCount = 2, min = Some(decMin), max = Some(dec_0_60), + childStatsTestPlan(Seq(arDecimal), 4L)), + ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60), nullCount = 0, avgLen = 8, maxLen = 8), Some(3L) ) @@ -305,7 +306,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cdouble < 3.0") { validateEstimatedStats( arDouble, - Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble))), + Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)), ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), nullCount = 0, avgLen = 8, maxLen = 8), Some(3L) @@ -315,7 +316,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cstring = 'A2'") { validateEstimatedStats( arString, - Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString))), + Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), ColumnStat(distinctCount = 1, min = None, max = None, nullCount = 0, avgLen = 2, maxLen = 2), Some(1L) @@ -326,7 +327,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cstring < 'A2'") { validateEstimatedStats( arString, - Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString))), + Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), ColumnStat(distinctCount = 10, min = None, max = None, nullCount = 0, avgLen = 2, maxLen = 2), Some(10L) @@ -354,10 +355,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) } - private def childStatsTestPlan(outList: Seq[Attribute]): StatsTestPlan = { + private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = { StatsTestPlan( outputList = outList, - rowCount = 10L, + rowCount = tableRowCount, attributeStats = AttributeMap(Seq( arInt -> childColStatInt, arDate -> childColStatDate, @@ -383,7 +384,20 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val filteredStats = filterNode.stats(conf) assert(filteredStats.sizeInBytes == expectedSizeInBytes) assert(filteredStats.rowCount == rowCount) - assert(filteredStats.attributeStats(ar) == expectedColStats) + ar.dataType match { + case DecimalType() => + // Due to the internal transformation for DecimalType within the engine, the new min/max + // in ColumnStat may have a different structure even if it contains the right values. + // We convert them to Java BigDecimal values so that we can compare the entire object. + val generatedColumnStats = filteredStats.attributeStats(ar) + val newMax = new java.math.BigDecimal(generatedColumnStats.max.getOrElse(0).toString) + val newMin = new java.math.BigDecimal(generatedColumnStats.min.getOrElse(0).toString) + val outputColStats = generatedColumnStats.copy(min = Some(newMin), max = Some(newMax)) + assert(outputColStats == expectedColStats) + case _ => + // For all other SQL types, we compare the entire object directly. + assert(filteredStats.attributeStats(ar) == expectedColStats) + } } }
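The DecimalType arm of validateEstimatedStats above guards against representation drift: after estimation, the min/max may come back wrapped in a different decimal type than the java.math.BigDecimal the expected ColumnStat holds, so both sides are round-tripped through their string form before comparison. A minimal standalone sketch of that normalization trick, in plain Scala — the object and value names here are invented for illustration and are not code from the patch:

    import java.math.{BigDecimal => JBigDecimal}

    object DecimalBoundCheck {
      // Round-trip any decimal-like value through its string form so that two
      // different wrappers of the same bound become directly comparable.
      def normalize(v: Any): JBigDecimal = new JBigDecimal(v.toString)

      def main(args: Array[String]): Unit = {
        val internal = scala.math.BigDecimal("0.400000000000000000") // stand-in for the engine-side value
        val external = new JBigDecimal("0.400000000000000000")       // stand-in for the value in ColumnStat
        assert(!internal.equals(external))                  // different wrappers: not equal as objects
        assert(normalize(internal) == normalize(external))  // equal after string round-trip
      }
    }

The suite applies this only to the DecimalType arm; every other type keeps a single representation, which is why the default arm compares the ColumnStat objects directly.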
From a48a4fd7dbc1e926cbe0836017dda86ca7486002 Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Thu, 23 Feb 2017 17:01:50 -0800 Subject: [PATCH 36/36] use Javadoc style indentation for multiline comments --- .../statsEstimation/FilterEstimation.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 80005ebb4024..fcc607a610fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -283,14 +283,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } /** - * For a SQL data type, its internal data type may be different from its external type. - * For DateType, its internal type is Int, and its external data type is Java Date type. - * The min/max values in ColumnStat are saved in their corresponding external type. - * - * @param attrDataType the column data type - * @param litValue the literal value - * @return a BigDecimal value - */ + * For a SQL data type, its internal data type may be different from its external type. + * For DateType, its internal type is Int, and its external data type is Java Date type. + * The min/max values in ColumnStat are saved in their corresponding external type. + * + * @param attrDataType the column data type + * @param litValue the literal value + * @return the converted value in its corresponding external type; None for string/binary + */ def convertBoundValue(attrDataType: DataType, litValue: Any): Option[Any] = { attrDataType match { case DateType =>
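For context on the internal/external split that convertBoundValue bridges: Catalyst stores DateType internally as days since the Unix epoch (an Int) and TimestampType as microseconds since the epoch (a Long), while ColumnStat keeps its bounds as java.sql.Date / java.sql.Timestamp. A rough standalone sketch of the same mapping using only JDK types — the helper name, the string type tags, and the millisecond truncation are simplifications for illustration, not Spark's DateTimeUtils:

    import java.sql.{Date, Timestamp}
    import java.time.LocalDate

    object BoundValueSketch {
      // Map an engine-internal value to the external type a stats entry would store.
      // "date"/"timestamp"/"string"/"binary" tags stand in for Catalyst DataTypes.
      def toExternal(typeTag: String, internal: Any): Option[Any] = typeTag match {
        case "date" =>
          Some(Date.valueOf(LocalDate.ofEpochDay(internal.asInstanceOf[Int].toLong)))
        case "timestamp" =>
          // simplified: drops the sub-millisecond precision that microseconds carry
          Some(new Timestamp(internal.asInstanceOf[Long] / 1000L))
        case "string" | "binary" =>
          None // no min/max is tracked for these types
        case _ =>
          Some(internal) // numeric values keep a single representation
      }

      def main(args: Array[String]): Unit = {
        // 17169 days after 1970-01-01 is 2017-01-03, one of the dates in the tests above.
        assert(toExternal("date", 17169).contains(Date.valueOf("2017-01-03")))
        assert(toExternal("string", "A2").isEmpty)
      }
    }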