
Commit 3ccabdf

wzhfy authored and rxin committed
[SPARK-17077][SQL] Cardinality estimation for project operator
## What changes were proposed in this pull request?

Support cardinality estimation for project operator.

## How was this patch tested?

Add a test suite and a base class in the catalyst package.

Author: Zhenhua Wang <[email protected]>

Closes #16430 from wzhfy/projectEstimation.
1 parent 19d9d4c commit 3ccabdf

6 files changed: +196, -0 lines changed
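At a high level, the new estimation keeps the child's row count and recomputes sizeInBytes as rowCount times the estimated size of one projected output row. A minimal sketch of that arithmetic with made-up numbers (not part of the diff):

```scala
// Illustrative only: a child reporting 1000 rows whose projected row is
// estimated at 16 bytes yields a Project estimate of 16000 bytes.
// The row count itself is unchanged, since a projection never adds or drops rows.
val rowCount = 1000L
val rowSize = 16L
val sizeInBytes = rowCount * rowSize
assert(sizeInBytes == 16000L)
```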

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@ class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
   override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)

+  override def contains(k: Attribute): Boolean = get(k).isDefined
+
   override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = baseMap.values.toMap + kv

   override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator
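A small usage sketch (not in the diff) of the new contains override: AttributeMap resolves keys through exprId, so an attribute copy with the same exprId is still found.

```scala
// Assumes the contains() added above; lookups go through exprId, not object identity.
import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("key1", IntegerType)()
val m = AttributeMap(Seq(a -> 42))
assert(m.contains(a))                          // found via a.exprId
assert(m.contains(a.withNullability(false)))   // same exprId, still found
```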

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ProjectEstimation
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils

@@ -53,6 +54,9 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
   override def validConstraints: Set[Expression] =
     child.constraints.union(getAliasedConstraints(projectList))
+
+  override lazy val statistics: Statistics =
+    ProjectEstimation.estimate(this).getOrElse(super.statistics)
 }

 /**
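One detail worth noting (general Scala behavior, not specific to this diff): Option.getOrElse takes its default by name, so super.statistics above is only computed when ProjectEstimation.estimate returns None.

```scala
// Sketch of the by-name default: the fallback is evaluated only on None.
var defaultComputed = false
def fallback(): Int = { defaultComputed = true; 100 }

Some(32).getOrElse(fallback())
assert(!defaultComputed)                   // Some(...) never touches the default
(None: Option[Int]).getOrElse(fallback())
assert(defaultComputed)                    // None forces the fallback
```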
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.plans.logical.statsEstimation

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
import org.apache.spark.sql.types.StringType


object EstimationUtils {

  /** Check if each plan has rowCount in its statistics. */
  def rowCountsExist(plans: LogicalPlan*): Boolean =
    plans.forall(_.statistics.rowCount.isDefined)

  /** Get column stats for output attributes. */
  def getOutputMap(inputMap: AttributeMap[ColumnStat], output: Seq[Attribute])
    : AttributeMap[ColumnStat] = {
    AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _)))
  }

  def getRowSize(attributes: Seq[Attribute], attrStats: AttributeMap[ColumnStat]): Long = {
    // We assign a generic overhead for a Row object; the actual overhead differs between
    // row formats.
    8 + attributes.map { attr =>
      if (attrStats.contains(attr)) {
        attr.dataType match {
          case StringType =>
            // UTF8String: base + offset + numBytes
            attrStats(attr).avgLen + 8 + 4
          case _ =>
            attrStats(attr).avgLen
        }
      } else {
        attr.dataType.defaultSize
      }
    }.sum
  }
}
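A worked example (not in the diff) of getRowSize, reusing the ColumnStat constructor shape from the test suite below: an IntegerType column with avgLen 4 plus a StringType column with avgLen 10 gives 8 + 4 + (10 + 8 + 4) = 34 bytes per row.

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils.getRowSize
import org.apache.spark.sql.types.{IntegerType, StringType}

val intCol = AttributeReference("i", IntegerType)()
val strCol = AttributeReference("s", StringType)()
val stats = AttributeMap(Seq(
  intCol -> ColumnStat(2, Some(1), Some(2), 0, 4, 4),   // avgLen = 4
  strCol -> ColumnStat(2, None, None, 0, 10, 12)))      // avgLen = 10 for the string
// 8 (row overhead) + 4 (int) + (10 + 8 + 4) (UTF8String) = 34
assert(getRowSize(Seq(intCol, strCol), stats) == 34)
```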
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.plans.logical.statsEstimation

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{Project, Statistics}

object ProjectEstimation {
  import EstimationUtils._

  def estimate(project: Project): Option[Statistics] = {
    if (rowCountsExist(project.child)) {
      val childStats = project.child.statistics
      val inputAttrStats = childStats.attributeStats
      // Match alias with its child's column stat
      val aliasStats = project.expressions.collect {
        case alias @ Alias(attr: Attribute, _) if inputAttrStats.contains(attr) =>
          alias.toAttribute -> inputAttrStats(attr)
      }
      val outputAttrStats =
        getOutputMap(AttributeMap(inputAttrStats.toSeq ++ aliasStats), project.output)
      Some(childStats.copy(
        sizeInBytes = childStats.rowCount.get * getRowSize(project.output, outputAttrStats),
        attributeStats = outputAttrStats))
    } else {
      None
    }
  }
}
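To make the alias handling concrete, a small sketch (not in the diff) of the collect above: an Alias over a plain attribute reuses that attribute's column stat under its own exprId and name.

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
import org.apache.spark.sql.types.IntegerType

val key2 = AttributeReference("key2", IntegerType)()
val inputAttrStats = AttributeMap(Seq(key2 -> ColumnStat(1, Some(10), Some(10), 0, 4, 4)))
val abc = Alias(key2, "abc")()

// Same pattern as in ProjectEstimation.estimate: the alias inherits key2's stats.
val aliasStats = Seq(abc).collect {
  case alias @ Alias(attr: Attribute, _) if inputAttrStats.contains(attr) =>
    alias.toAttribute -> inputAttrStats(attr)
}
assert(aliasStats.head._1.name == "abc")
assert(aliasStats.head._2 == inputAttrStats(key2))
```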
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.statsEstimation

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
import org.apache.spark.sql.types.IntegerType


class ProjectEstimationSuite extends StatsEstimationTestBase {

  test("estimate project with alias") {
    val ar1 = AttributeReference("key1", IntegerType)()
    val ar2 = AttributeReference("key2", IntegerType)()
    val colStat1 = ColumnStat(2, Some(1), Some(2), 0, 4, 4)
    val colStat2 = ColumnStat(1, Some(10), Some(10), 0, 4, 4)

    val child = StatsTestPlan(
      outputList = Seq(ar1, ar2),
      stats = Statistics(
        sizeInBytes = 2 * (4 + 4),
        rowCount = Some(2),
        attributeStats = AttributeMap(Seq(ar1 -> colStat1, ar2 -> colStat2))))

    val project = Project(Seq(ar1, Alias(ar2, "abc")()), child)
    val expectedColStats = Seq("key1" -> colStat1, "abc" -> colStat2)
    val expectedAttrStats = toAttributeMap(expectedColStats, project)
    // The number of rows won't change for project.
    val expectedStats = Statistics(
      sizeInBytes = 2 * getRowSize(project.output, expectedAttrStats),
      rowCount = Some(2),
      attributeStats = expectedAttrStats)
    assert(project.statistics == expectedStats)
  }
}
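For reference, the expected numbers in this test work out as follows (derived from getRowSize): both output columns are IntegerType with avgLen 4, so one row is 8 + 4 + 4 = 16 bytes and two rows give sizeInBytes = 32.

```scala
// Worked arithmetic behind expectedStats above.
val perRowSize = 8 + 4 + 4        // row overhead + key1 + abc
val expectedSizeInBytes = 2 * perRowSize
assert(expectedSizeInBytes == 32)
```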
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.statsEstimation

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}


class StatsEstimationTestBase extends SparkFunSuite {

  /** Convert (column name, column stat) pairs to an AttributeMap based on plan output. */
  def toAttributeMap(colStats: Seq[(String, ColumnStat)], plan: LogicalPlan)
    : AttributeMap[ColumnStat] = {
    val nameToAttr: Map[String, Attribute] = plan.output.map(a => (a.name, a)).toMap
    AttributeMap(colStats.map(kv => nameToAttr(kv._1) -> kv._2))
  }
}

/**
 * This class is used for unit-testing. It's a logical plan whose output and stats are passed in.
 */
protected case class StatsTestPlan(outputList: Seq[Attribute], stats: Statistics) extends LeafNode {
  override def output: Seq[Attribute] = outputList
  override lazy val statistics = stats
}
