Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -192,49 +192,17 @@ class Analyzer(
Seq.tabulate(1 << c.groupByExprs.length)(i => i)
}

/**
* Create an array of Projections for the child projection, and replace the projections'
* expressions which equal GroupBy expressions with Literal(null), if those expressions
* are not set for this grouping set (according to the bit mask).
*/
private[this] def expand(g: GroupingSets): Seq[Seq[Expression]] = {
val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]

g.bitmasks.foreach { bitmask =>
// get the non selected grouping attributes according to the bit mask
val nonSelectedGroupExprs = ArrayBuffer.empty[Expression]
var bit = g.groupByExprs.length - 1
while (bit >= 0) {
if (((bitmask >> bit) & 1) == 0) nonSelectedGroupExprs += g.groupByExprs(bit)
bit -= 1
}

val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown {
case x: Expression if nonSelectedGroupExprs.find(_ semanticEquals x).isDefined =>
// if the input attribute in the Invalid Grouping Expression set of for this group
// replace it with constant null
Literal.create(null, expr.dataType)
case x if x == g.gid =>
// replace the groupingId with concrete value (the bit mask)
Literal.create(bitmask, IntegerType)
})

result += substitution
}

result.toSeq
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case a: Cube if a.resolved =>
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid)
case a: Rollup if a.resolved =>
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid)
case x: GroupingSets if x.resolved =>
case a: Cube =>
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
case a: Rollup =>
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
case x: GroupingSets =>
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
Aggregate(
x.groupByExprs :+ x.gid,
x.groupByExprs :+ VirtualColumn.groupingIdAttribute,
x.aggregations,
Expand(expand(x), x.child.output :+ x.gid, x.child))
Expand(x.bitmasks, x.groupByExprs, gid, x.child))
}
}

Expand Down Expand Up @@ -368,12 +336,7 @@ class Analyzer(

case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressionsUp {
case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 &&
resolver(nameParts(0), VirtualColumn.groupingIdName) &&
q.isInstanceOf[GroupingAnalytics] =>
// Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics
q.asInstanceOf[GroupingAnalytics].gid
q transformExpressionsUp {
case u @ UnresolvedAttribute(nameParts) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,5 +262,5 @@ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[E

object VirtualColumn {
val groupingIdName: String = "grouping__id"
def newGroupingId: AttributeReference = AttributeReference(groupingIdName, IntegerType, false)()
val groupingIdAttribute: UnresolvedAttribute = UnresolvedAttribute(groupingIdName)
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ object UnionPushdown extends Rule[LogicalPlan] {
*/
object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child))
if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty =>
a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references)))

// Eliminate attributes that are not needed to calculate the specified aggregates.
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
a.copy(child = Project(a.references.toSeq, child))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet

case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
Expand Down Expand Up @@ -228,24 +229,76 @@ case class Window(
/**
* Apply the all of the GroupExpressions to every input row, hence we will get
* multiple output rows for a input row.
* @param projections The group of expressions, all of the group expressions should
* output the same schema specified by the parameter `output`
* @param output The output Schema
* @param bitmasks The bitmask set represents the grouping sets
* @param groupByExprs The grouping by expressions
* @param child Child operator
*/
case class Expand(
projections: Seq[Seq[Expression]],
output: Seq[Attribute],
bitmasks: Seq[Int],
groupByExprs: Seq[Expression],
gid: Attribute,
child: LogicalPlan) extends UnaryNode {
override def statistics: Statistics = {
val sizeInBytes = child.statistics.sizeInBytes * projections.length
Statistics(sizeInBytes = sizeInBytes)
}

val projections: Seq[Seq[Expression]] = expand()

/**
* Extract attribute set according to the grouping id
* @param bitmask bitmask to represent the selected of the attribute sequence
* @param exprs the attributes in sequence
* @return the attributes of non selected specified via bitmask (with the bit set to 1)
*/
private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
: OpenHashSet[Expression] = {
val set = new OpenHashSet[Expression](2)

var bit = exprs.length - 1
while (bit >= 0) {
if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
bit -= 1
}

set
}

/**
* Create an array of Projections for the child projection, and replace the projections'
* expressions which equal GroupBy expressions with Literal(null), if those expressions
* are not set for this grouping set (according to the bit mask).
*/
private[this] def expand(): Seq[Seq[Expression]] = {
val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]

bitmasks.foreach { bitmask =>
// get the non selected grouping attributes according to the bit mask
val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs)

val substitution = (child.output :+ gid).map(expr => expr transformDown {
case x: Expression if nonSelectedGroupExprSet.contains(x) =>
// if the input attribute in the Invalid Grouping Expression set of for this group
// replace it with constant null
Literal.create(null, expr.dataType)
case x if x == gid =>
// replace the groupingId with concrete value (the bit mask)
Literal.create(bitmask, IntegerType)
})

result += substitution
}

result.toSeq
}

override def output: Seq[Attribute] = {
child.output :+ gid
}
}

trait GroupingAnalytics extends UnaryNode {
self: Product =>
def gid: AttributeReference
def groupByExprs: Seq[Expression]
def aggregations: Seq[NamedExpression]

Expand All @@ -266,17 +319,12 @@ trait GroupingAnalytics extends UnaryNode {
* @param child Child operator
* @param aggregations The Aggregation expressions, those non selected group by expressions
* will be considered as constant null if it appears in the expressions
* @param gid The attribute represents the virtual column GROUPING__ID, and it's also
* the bitmask indicates the selected GroupBy Expressions for each
* aggregating output row.
* The associated output will be one of the value in `bitmasks`
*/
case class GroupingSets(
bitmasks: Seq[Int],
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
aggregations: Seq[NamedExpression]) extends GroupingAnalytics {

def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
this.copy(aggregations = aggs)
Expand All @@ -290,15 +338,11 @@ case class GroupingSets(
* @param child Child operator
* @param aggregations The Aggregation expressions, those non selected group by expressions
* will be considered as constant null if it appears in the expressions
* @param gid The attribute represents the virtual column GROUPING__ID, and it's also
* the bitmask indicates the selected GroupBy Expressions for each
* aggregating output row.
*/
case class Cube(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
aggregations: Seq[NamedExpression]) extends GroupingAnalytics {

def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
this.copy(aggregations = aggs)
Expand All @@ -313,15 +357,11 @@ case class Cube(
* @param child Child operator
* @param aggregations The Aggregation expressions, those non selected group by expressions
* will be considered as constant null if it appears in the expressions
* @param gid The attribute represents the virtual column GROUPING__ID, and it's also
* the bitmask indicates the selected GroupBy Expressions for each
* aggregating output row.
*/
case class Rollup(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
aggregations: Seq[NamedExpression]) extends GroupingAnalytics {

def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
this.copy(aggregations = aggs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Project(projectList, planLater(child)) :: Nil
case logical.Filter(condition, child) =>
execution.Filter(condition, planLater(child)) :: Nil
case logical.Expand(projections, output, child) =>
execution.Expand(projections, output, planLater(child)) :: Nil
case e @ logical.Expand(_, _, _, child) =>
execution.Expand(e.projections, e.output, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
case logical.Window(projectList, windowExpressions, spec, child) =>
Expand Down