From c9d0c1d51fb920028496e34191220d06b7bdb233 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Mon, 2 Nov 2015 08:15:55 +0100
Subject: [PATCH 1/8] rebase

---
 .../expressions/aggregate/Utils.scala         | 155 +++++++++++++++++-
 .../sql/catalyst/optimizer/Optimizer.scala    |   6 +-
 .../plans/logical/basicOperators.scala        |  80 +++++----
 .../spark/sql/execution/SparkStrategies.scala |   2 +-
 4 files changed, 199 insertions(+), 44 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
index 644c6211d5f3..1a7b513a0e21 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
@@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
+import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.{IntegerType, StructType, MapType, ArrayType}
 
 /**
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -41,7 +42,7 @@ object Utils {
   private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
     case p: Aggregate if supportsGroupingKeySchema(p) =>
-      val converted = p.transformExpressionsDown {
+      val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown {
         case expressions.Average(child) =>
           aggregate.AggregateExpression2(
             aggregateFunction = aggregate.Average(child),
             mode = aggregate.Complete,
             isDistinct = false)
@@ -144,7 +145,8 @@ object Utils {
           aggregate.AggregateExpression2(
             aggregateFunction = aggregate.VarianceSamp(child),
             mode = aggregate.Complete,
             isDistinct = false)
-      }
+      })
+
       // Check if there is any expressions.AggregateExpression1 left.
       // If so, we cannot convert this plan.
       val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
@@ -156,6 +158,7 @@ object Utils {
       }
 
       // Check if there are multiple distinct columns.
+      // TODO remove this.
       val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
         expr.collect {
           case agg: AggregateExpression2 => agg
@@ -213,3 +216,147 @@ object Utils {
     case other => None
   }
 }
+
+/**
+ * This rule rewrites an aggregate query with multiple distinct clauses into an expanded double
+ * aggregation in which the regular aggregation expressions and every distinct clause is aggregated
+ * in a separate group. The results are then combined in a second aggregate.
+ *
+ * TODO Expression canonicalization
+ * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate
+ * operator. Perhaps this is a good thing? It is much simpler to plan later on...
+ */
+object MultipleDistinctRewriter extends Rule[LogicalPlan] {
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    case a: Aggregate => rewrite(a)
+    case p => p
+  }
+
+  def rewrite(a: Aggregate): Aggregate = {
+
+    // Collect all aggregate expressions.
+    val aggExpressions = a.aggregateExpressions.flatMap { e =>
+      e.collect {
+        case ae: AggregateExpression2 =>
+          ae
+      }
+    }
+
+    // Extract distinct aggregate expressions.
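+    // Illustrative example (hypothetical table and column names, not part of the original
+    // patch): for
+    //   SELECT COUNT(DISTINCT a), COUNT(DISTINCT b), SUM(c) FROM t GROUP BY key
+    // the distinct aggregates are grouped by their argument sets, yielding two groups, {a}
+    // and {b}, while SUM(c) is left for the regular (non-distinct) aggregation path. Note
+    // that COUNT(DISTINCT a) and SUM(DISTINCT a) would share the single group {a}.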
+    val distinctAggGroups = aggExpressions
+      .filter(_.isDistinct)
+      .groupBy(_.aggregateFunction.children.toSet)
+
+    // Only continue to rewrite if there is more than one distinct group.
+    if (distinctAggGroups.size > 1) {
+      // Create the attributes for the grouping id and the group by clause.
+      val gid = new AttributeReference("gid", IntegerType, false)()
+      val groupByMap = a.groupingExpressions.map(expressionAttributePair)
+      val groupByAttrs = groupByMap.map(_._2)
+
+      // Setup unique distinct aggregate children.
+      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
+      val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap
+      val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq
+
+      // Setup expand & aggregate operators for distinct aggregate expressions.
+      val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
+        case ((group, expressions), i) =>
+          val id = Literal(i + 1)
+
+          // Expand projection
+          val projection = distinctAggChildren.map { e =>
+            if (group.contains(e)) {
+              e
+            } else {
+              nullify(e)
+            }
+          } :+ id
+
+          // Final aggregate
+          val operators = expressions.map { e =>
+            val af = e.aggregateFunction
+            val naf = af.withNewChildren(af.children.map { case afc =>
+              // Make sure only the input originating from the projection above is used for
+              // aggregation.
+              If(EqualTo(gid, id), distinctAggChildAttrMap(afc), nullify(afc))
+            }).asInstanceOf[AggregateFunction2]
+            (e, e.copy(aggregateFunction = naf, isDistinct = false))
+          }
+
+          (projection, operators)
+      }
+
+      // Setup expand for the 'regular' aggregate expressions.
+      val regularAggExprs = aggExpressions.filter(!_.isDistinct)
+      val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
+      val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair).toMap
+
+      // Setup aggregates for 'regular' aggregate expressions.
+      val regularAggOperatorMap = regularAggExprs.map { e =>
+        // Perform the actual aggregation in the initial aggregate.
+        val a = Alias(e.transform(regularAggChildAttrMap), "ra")()
+        // Get the result of the first aggregate in the last aggregate.
+        val b = AggregateExpression2(aggregate.First(
+          If(EqualTo(gid, Literal(0)), a.toAttribute, nullify(e)), Literal(true)),
+          mode = Complete,
+          isDistinct = false)
+        (e, a, b)
+      }
+
+      // Construct the regular aggregate input projection only when we need one.
+      val regularAggProjection = if (regularAggExprs.nonEmpty) {
+        Seq(a.groupingExpressions ++
+          distinctAggChildren.map(nullify) ++
+          Seq(Literal(0)) ++
+          regularAggChildren)
+      } else {
+        Seq.empty[Seq[Expression]]
+      }
+
+      // Construct the distinct aggregate input projections.
+      val regularAggNulls = regularAggChildren.map(nullify)
+      val distinctAggProjections = distinctAggOperatorMap.map {
+        case (projection, _) =>
+          a.groupingExpressions ++
+            projection ++
+            regularAggNulls
+      }
+
+      // Construct the expand operator.
+      val expand = Expand(
+        regularAggProjection ++ distinctAggProjections,
+        groupByAttrs ++ distinctAggChildAttrs ++ regularAggChildAttrMap.values.toSeq :+ gid,
+        a.child)
+
+      // Construct the first aggregate operator. This de-duplicates all the children of
+      // distinct operators, and applies the regular aggregate operators.
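+      // Sketch of the overall plan shape produced by this rewrite (reading bottom-up; an
+      // assumed summary, not code from the original patch):
+      //   Aggregate(groupByAttrs, patched aggregate expressions)        <- second aggregate
+      //     Aggregate(groupByAttrs ++ distinctAggChildAttrs :+ gid, ..) <- first aggregate below
+      //       Expand(regularAggProjection ++ distinctAggProjections, ..)
+      //         a.child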
+      val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
+      val firstAggregate = Aggregate(
+        firstAggregateGroupBy,
+        firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
+        expand)
+
+      // Construct the second aggregate
+      val transformations = (groupByMap ++
+        distinctAggOperatorMap.flatMap(_._2) ++
+        regularAggOperatorMap.map(e => (e._1, e._3))).toMap
+      val patchedAggExpressions = a.aggregateExpressions.map { e =>
+        e.transformDown(transformations).asInstanceOf[NamedExpression]
+      }
+      Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
+    } else {
+      a
+    }
+  }
+
+  private def nullify(e: Expression) = Literal.create(null, e.dataType)
+
+  private def expressionAttributePair(e: Expression) = (e, toAttribute(e))
+
+  private def toAttribute(e: Expression) = e match {
+    case ne: NamedExpression => ne.toAttribute.withNullability(true)
+    case e: Expression => new AttributeReference(e.prettyName, e.dataType, true)()
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 338c5193cb7a..d222dfa33ad8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -200,9 +200,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
  */
 object ColumnPruning extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child))
-      if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty =>
-      a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references)))
+    case a @ Aggregate(_, _, e @ Expand(_, _, child))
+      if (child.outputSet -- AttributeSet(e.output) -- a.references).nonEmpty =>
+      a.copy(child = e.copy(child = prunedChild(child, AttributeSet(e.output) ++ a.references)))
 
     // Eliminate attributes that are not needed to calculate the specified aggregates.
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 4cb67aacf33e..fb963e2f8f7e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -235,33 +235,17 @@ case class Window(
     projectList ++ windowExpressions.map(_.toAttribute)
 }
 
-/**
- * Apply the all of the GroupExpressions to every input row, hence we will get
- * multiple output rows for a input row.
- * @param bitmasks The bitmask set represents the grouping sets
- * @param groupByExprs The grouping by expressions
- * @param child Child operator
- */
-case class Expand(
-    bitmasks: Seq[Int],
-    groupByExprs: Seq[Expression],
-    gid: Attribute,
-    child: LogicalPlan) extends UnaryNode {
-  override def statistics: Statistics = {
-    val sizeInBytes = child.statistics.sizeInBytes * projections.length
-    Statistics(sizeInBytes = sizeInBytes)
-  }
-
-  val projections: Seq[Seq[Expression]] = expand()
-
+private[sql] object Expand {
   /**
-   * Extract attribute set according to the grouping id
+   * Extract attribute set according to the grouping id.
    * @param bitmask bitmask to represent the selected of the attribute sequence
    * @param exprs the attributes in sequence
    * @return the attributes of non selected specified via bitmask (with the bit set to 1)
    */
-  private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
-    : OpenHashSet[Expression] = {
+  private def buildNonSelectExprSet(
+      bitmask: Int,
+      exprs: Seq[Expression]): OpenHashSet[Expression] = {
     val set = new OpenHashSet[Expression](2)
 
     var bit = exprs.length - 1
@@ -274,18 +258,28 @@ case class Expand(
   }
 
   /**
-   * Create an array of Projections for the child projection, and replace the projections'
-   * expressions which equal GroupBy expressions with Literal(null), if those expressions
-   * are not set for this grouping set (according to the bit mask).
+   * Apply all of the GroupExpressions to every input row, hence we will get
+   * multiple output rows for an input row.
+   *
+   * @param bitmasks The bitmasks that represent the grouping sets
+   * @param groupByExprs The group by expressions
+   * @param gid Attribute of the grouping id
+   * @param child Child operator
    */
-  private[this] def expand(): Seq[Seq[Expression]] = {
-    val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]
-
-    bitmasks.foreach { bitmask =>
+  def apply(
+      bitmasks: Seq[Int],
+      groupByExprs: Seq[Expression],
+      gid: Attribute,
+      child: LogicalPlan): Expand = {
+    // Create an array of Projections for the child projection, and replace the projections'
+    // expressions which equal GroupBy expressions with Literal(null), if those expressions
+    // are not set for this grouping set (according to the bit mask).
+    val projections = bitmasks.map { bitmask =>
       // get the non selected grouping attributes according to the bit mask
       val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs)
 
-      val substitution = (child.output :+ gid).map(expr => expr transformDown {
+      (child.output :+ gid).map(expr => expr transformDown {
+        // TODO this causes a problem when a column is used both for grouping and aggregation.
         case x: Expression if nonSelectedGroupExprSet.contains(x) =>
          // if the input attribute in the Invalid Grouping Expression set of for this group
          // replace it with constant null
@@ -294,15 +288,29 @@ case class Expand(
          // replace the groupingId with concrete value (the bit mask)
          Literal.create(bitmask, IntegerType)
       })
-
-      result += substitution
     }
-
-    result.toSeq
+    Expand(projections, child.output :+ gid, child)
   }
+}
 
-  override def output: Seq[Attribute] = {
-    child.output :+ gid
+/**
+ * Apply a number of projections to every input row, hence we will get multiple output rows for
+ * an input row.
+ *
+ * @param projections to apply
+ * @param output of all projections.
+ * @param child operator.
+ */
+case class Expand(
+    projections: Seq[Seq[Expression]],
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+
+  override def statistics: Statistics = {
+    // TODO shouldn't we factor in the size of the projection versus the size of the backing child
+    // row?
+    val sizeInBytes = child.statistics.sizeInBytes * projections.length
+    Statistics(sizeInBytes = sizeInBytes)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f4464e0b916f..dd3bb33c5728 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -420,7 +420,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       }
       case logical.Filter(condition, child) =>
         execution.Filter(condition, planLater(child)) :: Nil
-      case e @ logical.Expand(_, _, _, child) =>
+      case e @ logical.Expand(_, _, child) =>
         execution.Expand(e.projections, e.output, planLater(child)) :: Nil
       case a @ logical.Aggregate(group, agg, child) => {
         val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled

From 733fcede821bbe0f6781c938386c4ca14cc01a7a Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Mon, 2 Nov 2015 10:30:52 +0100
Subject: [PATCH 2/8] Fix a few small bugs.

---
 .../expressions/aggregate/Utils.scala         | 35 ++++++++++++-------
 1 file changed, 23 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
index 1a7b513a0e21..f12a7cad691d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
@@ -223,6 +223,7 @@ object Utils {
  * in a separate group. The results are then combined in a second aggregate.
  *
  * TODO Expression canonicalization
+ * TODO Eliminate foldable expressions from distinct clauses.
  * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate
  * operator. Perhaps this is a good thing? It is much simpler to plan later on...
  */
@@ -238,8 +239,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
     // Collect all aggregate expressions.
     val aggExpressions = a.aggregateExpressions.flatMap { e =>
       e.collect {
-        case ae: AggregateExpression2 =>
-          ae
+        case ae: AggregateExpression2 => ae
       }
     }
 
@@ -255,6 +255,17 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
       val groupByMap = a.groupingExpressions.map(expressionAttributePair)
       val groupByAttrs = groupByMap.map(_._2)
 
+      // Functions used to modify aggregate functions and their inputs.
+      def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
+      def patchAggregateFunctionChildren(
+          af: AggregateFunction2,
+          id: Literal,
+          attrs: Map[Expression, Expression]): AggregateFunction2 = {
+        af.withNewChildren(af.children.map { case afc =>
+          evalWithinGroup(id, attrs(afc))
+        }).asInstanceOf[AggregateFunction2]
+      }
+
       // Setup unique distinct aggregate children.
       val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
       val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap
@@ -277,11 +288,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
           // Final aggregate
           val operators = expressions.map { e =>
             val af = e.aggregateFunction
-            val naf = af.withNewChildren(af.children.map { case afc =>
-              // Make sure only the input originating from the projection above is used for
-              // aggregation.
-              If(EqualTo(gid, id), distinctAggChildAttrMap(afc), nullify(afc))
-            }).asInstanceOf[AggregateFunction2]
+            val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap)
             (e, e.copy(aggregateFunction = naf, isDistinct = false))
           }
 
@@ -295,13 +302,17 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
 
       // Setup aggregates for 'regular' aggregate expressions.
       val regularAggOperatorMap = regularAggExprs.map { e =>
+        val id = Literal(0)
+
         // Perform the actual aggregation in the initial aggregate.
-        val a = Alias(e.transform(regularAggChildAttrMap), "ra")()
+        val af = patchAggregateFunctionChildren(e.aggregateFunction, id, regularAggChildAttrMap)
+        val a = Alias(e.copy(aggregateFunction = af), "ra")()
+
        // Get the result of the first aggregate in the last aggregate.
-        val b = AggregateExpression2(aggregate.First(
-          If(EqualTo(gid, Literal(0)), a.toAttribute, nullify(e)), Literal(true)),
-          mode = Complete,
-          isDistinct = false)
+        val b = AggregateExpression2(aggregate.First(
+          evalWithinGroup(id, a.toAttribute), Literal(true)),
+          mode = Complete,
+          isDistinct = false)
         (e, a, b)
       }
 
@@ -327,7 +338,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
       // Construct the expand operator.
       val expand = Expand(
         regularAggProjection ++ distinctAggProjections,
-        groupByAttrs ++ distinctAggChildAttrs ++ regularAggChildAttrMap.values.toSeq :+ gid,
+        groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.values.toSeq,
         a.child)
 
       // Construct the first aggregate operator. This de-duplicates all the children of

From d85462dca9b894b2582d071dcf395bf1808228cc Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Wed, 4 Nov 2015 13:51:46 +0100
Subject: [PATCH 3/8] Improve readability

---
 .../spark/sql/catalyst/expressions/aggregate/Utils.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
index f12a7cad691d..c46b1ab204d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
@@ -309,8 +309,8 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
         val a = Alias(e.copy(aggregateFunction = af), "ra")()
 
         // Get the result of the first aggregate in the last aggregate.
-        val b = AggregateExpression2(aggregate.First(
-          evalWithinGroup(id, a.toAttribute), Literal(true)),
+        val b = AggregateExpression2(
+          aggregate.First(evalWithinGroup(id, a.toAttribute), Literal(true)),
           mode = Complete,
           isDistinct = false)
         (e, a, b)

From 7b5369c215f155dad8e61baaa7a005eaf2fd1ca5 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Wed, 4 Nov 2015 16:57:42 +0100
Subject: [PATCH 4/8] Fix issue with variable reuse between regular and
 distinct aggregate operators.
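
When a column feeds both a distinct and a regular aggregate, e.g. (an illustrative
query, not taken from the patch itself)

  SELECT COUNT(DISTINCT a), COUNT(DISTINCT b), SUM(a) FROM t

reusing the attribute of a named child caused the regular aggregate to be bound to
the expand column that is nulled out for the distinct groups. Regular and distinct
aggregate children therefore no longer share attribute references.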
---
 .../expressions/aggregate/Utils.scala         | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
index c46b1ab204d6..e1663eac3bd9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
@@ -252,7 +252,10 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
     if (distinctAggGroups.size > 1) {
       // Create the attributes for the grouping id and the group by clause.
       val gid = new AttributeReference("gid", IntegerType, false)()
-      val groupByMap = a.groupingExpressions.map(expressionAttributePair)
+      val groupByMap = a.groupingExpressions.collect {
+        case ne: NamedExpression => ne -> ne.toAttribute
+        case e => e -> new AttributeReference(e.prettyName, e.dataType, e.nullable)()
+      }
       val groupByAttrs = groupByMap.map(_._2)
 
       // Functions used to modify aggregate functions and their inputs.
@@ -306,7 +309,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
 
         // Perform the actual aggregation in the initial aggregate.
         val af = patchAggregateFunctionChildren(e.aggregateFunction, id, regularAggChildAttrMap)
-        val a = Alias(e.copy(aggregateFunction = af), "ra")()
+        val a = Alias(e.copy(aggregateFunction = af), e.toString)()
 
         // Get the result of the first aggregate in the last aggregate.
         val b = AggregateExpression2(
@@ -364,10 +367,10 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
 
   private def nullify(e: Expression) = Literal.create(null, e.dataType)
 
-  private def expressionAttributePair(e: Expression) = (e, toAttribute(e))
-
-  private def toAttribute(e: Expression) = e match {
-    case ne: NamedExpression => ne.toAttribute.withNullability(true)
-    case e: Expression => new AttributeReference(e.prettyName, e.dataType, true)()
-  }
+  private def expressionAttributePair(e: Expression) =
+    // We are creating a new reference here instead of reusing a reference in case of a
+    // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
+    // children, in this case attribute reuse causes the input of the regular aggregate to be
+    // bound to the (nulled out) input of the distinct aggregate.
+    e -> new AttributeReference(e.prettyName, e.dataType, true)()
 }

From d626c20499af12f6ee901eb851ab4d6a4915a80a Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Thu, 5 Nov 2015 22:01:57 +0100
Subject: [PATCH 5/8] Fix Group By Clause equality

---
 .../expressions/aggregate/Utils.scala         | 44 ++++++++++++-------
 1 file changed, 27 insertions(+), 17 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
index e1663eac3bd9..f0a3642d61a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
@@ -244,6 +244,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
     }
 
     // Extract distinct aggregate expressions.
+    // TODO try to get the distinct group as small
     val distinctAggGroups = aggExpressions
       .filter(_.isDistinct)
       .groupBy(_.aggregateFunction.children.toSet)
@@ -280,12 +281,9 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
           val id = Literal(i + 1)
 
           // Expand projection
-          val projection = distinctAggChildren.map { e =>
-            if (group.contains(e)) {
-              e
-            } else {
-              nullify(e)
-            }
+          val projection = distinctAggChildren.map {
+            case e if group.contains(e) => e
+            case e => nullify(e)
           } :+ id
 
           // Final aggregate
@@ -302,22 +300,24 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
       // Setup aggregates for 'regular' aggregate expressions.
+      val regularGroupId = Literal(0)
       val regularAggOperatorMap = regularAggExprs.map { e =>
-        val id = Literal(0)
-
         // Perform the actual aggregation in the initial aggregate.
-        val af = patchAggregateFunctionChildren(e.aggregateFunction, id, regularAggChildAttrMap)
+        val af = patchAggregateFunctionChildren(
+          e.aggregateFunction,
+          regularGroupId,
+          regularAggChildAttrMap)
         val a = Alias(e.copy(aggregateFunction = af), e.toString)()
 
         // Get the result of the first aggregate in the last aggregate.
         val b = AggregateExpression2(
-          aggregate.First(evalWithinGroup(id, a.toAttribute), Literal(true)),
+          aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)),
           mode = Complete,
           isDistinct = false)
         (e, a, b)
       }
 
-      // Construct the regular aggregate input projection only when we need one.
+      // Construct the regular aggregate input projection only if we need one.
       val regularAggProjection = if (regularAggExprs.nonEmpty) {
         Seq(a.groupingExpressions ++
           distinctAggChildren.map(nullify) ++
-          Seq(Literal(0)) ++
+          Seq(regularGroupId) ++
           regularAggChildren)
       } else {
         Seq.empty[Seq[Expression]]
       }
@@ -353,11 +353,21 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
         expand)
 
       // Construct the second aggregate
-      val transformations = (groupByMap ++
-        distinctAggOperatorMap.flatMap(_._2) ++
-        regularAggOperatorMap.map(e => (e._1, e._3))).toMap
+      val transformations: Map[Expression, Expression] =
+        (distinctAggOperatorMap.flatMap(_._2) ++
+          regularAggOperatorMap.map(e => (e._1, e._3))).toMap
+
       val patchedAggExpressions = a.aggregateExpressions.map { e =>
-        e.transformDown(transformations).asInstanceOf[NamedExpression]
+        e.transformDown {
+          case e: Expression =>
+            // GROUP BY can different in form (name) but must be semantically equal. This makes
+            // a map lookup tricky. So we do a linear search for a semantically equal group by
+            // expression.
+            groupByMap
+              .find(ge => e.semanticEquals(ge._1))
+              .map(_._2)
+              .getOrElse(transformations.getOrElse(e, e))
+        }.asInstanceOf[NamedExpression]
       }
       Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
     } else {
@@ -374,9 +384,9 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
 
   private def nullify(e: Expression) = Literal.create(null, e.dataType)
 
   private def expressionAttributePair(e: Expression) =
-    // We are creating a new reference here instead of reusing a reference in case of a
+    // We are creating a new reference here instead of reusing the attribute in case of a
     // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
     // children, in this case attribute reuse causes the input of the regular aggregate to be
     // bound to the (nulled out) input of the distinct aggregate.
From ece657b662df2b4145dcf9177aa4d7f275cafec9 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 6 Nov 2015 08:48:04 +0100 Subject: [PATCH 6/8] Fixing count default values (1) --- .../sql/catalyst/expressions/aggregate/Utils.scala | 13 ++++++++++--- .../catalyst/expressions/aggregate/interfaces.scala | 6 ++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala index f0a3642d61a4..f94fb7ac8843 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala @@ -244,7 +244,6 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { } // Extract distinct aggregate expressions. - // TODO try to get the distinct group as small val distinctAggGroups = aggExpressions .filter(_.isDistinct) .groupBy(_.aggregateFunction.children.toSet) @@ -316,7 +315,15 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)), mode = Complete, isDistinct = false) - (e, a, b) + + // COUNT has the special property that it can return a result without input rows. This is + // almost impossible to + val c = af match { + case _: Count => Coalesce(Seq(b, Literal(0L))) + case _ => b + } + + (e, a, c) } // Construct the regular aggregate input projection only if we need one. @@ -360,7 +367,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { val patchedAggExpressions = a.aggregateExpressions.map { e => e.transformDown { case e: Expression => - // GROUP BY can different in form (name) but must be semantically equal. This makes + // GROUP BY can be different in form () but must be semantically equal. This makes // a map lookup tricky. So we do a linear search for a semantically equal group by // expression. groupByMap diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index a2fab258fcac..5c5b3d1ccd3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -133,6 +133,12 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp */ def supportsPartial: Boolean = true + /** + * Result of the aggregate function when the input is empty. This is currently only used for the + * proper rewriting of distinct aggregate functions. + */ + def defaultResult: Option[Literal] = None + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") } From 9be5b9d9e3c9de81473ac93750b687e2de824bb2 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 6 Nov 2015 09:17:14 +0100 Subject: [PATCH 7/8] Fixing count default values (2). 
---
 .../sql/catalyst/expressions/aggregate/Count.scala |  2 ++
 .../sql/catalyst/expressions/aggregate/Utils.scala | 10 +++++-----
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 54df96cd2446..ec0c8b483a90 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -49,4 +49,6 @@ case class Count(child: Expression) extends DeclarativeAggregate {
   )
 
   override val evaluateExpression = Cast(count, LongType)
+
+  override def defaultResult: Option[Literal] = Option(Literal(0L))
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
index f94fb7ac8843..4cad9757fb7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
@@ -316,11 +316,11 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
           mode = Complete,
           isDistinct = false)
 
-        // COUNT has the special property that it can return a result without input rows. This is
-        // almost impossible to
-        val c = af match {
-          case _: Count => Coalesce(Seq(b, Literal(0L)))
-          case _ => b
+        // Some aggregate functions (COUNT) have the special property that they can return a
+        // non-null result without any input. We need to make sure we return a result in this case.
+        val c = af.defaultResult match {
+          case Some(lit) => Coalesce(Seq(b, lit))
+          case None => b
         }
 
         (e, a, c)

From d3bdb2bb0096663af7d9faf4c0963a3df00065aa Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Fri, 6 Nov 2015 11:15:53 +0100
Subject: [PATCH 8/8] Improve docs. Triggering build :P...

---
 .../spark/sql/catalyst/expressions/aggregate/Utils.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
index 4cad9757fb7b..39010c3be6d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
@@ -367,9 +367,9 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
       val patchedAggExpressions = a.aggregateExpressions.map { e =>
         e.transformDown {
           case e: Expression =>
-            // GROUP BY can be different in form () but must be semantically equal. This makes
-            // a map lookup tricky. So we do a linear search for a semantically equal group by
-            // expression.
+            // The same GROUP BY clauses can have different forms (different names for instance) in
+            // the groupBy and aggregate expressions of an aggregate. This makes a map lookup
+            // tricky. So we do a linear search for a semantically equal group by expression.
             groupByMap
               .find(ge => e.semanticEquals(ge._1))
               .map(_._2)