Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,76 @@ object Utils {
* aggregation in which the regular aggregation expressions and every distinct clause is aggregated
* in a separate group. The results are then combined in a second aggregate.
*
* TODO Expression cannocalization
* TODO Eliminate foldable expressions from distinct clauses.
* TODO This eliminates all distinct expressions. We could safely pass one to the aggregate
* operator. Perhaps this is a good thing? It is much simpler to plan later on...
* For example (in scala):
* {{{
* val data = Seq(
* ("a", "ca1", "cb1", 10),
* ("a", "ca1", "cb2", 5),
* ("b", "ca1", "cb1", 13))
* .toDF("key", "cat1", "cat2", "value")
* data.registerTempTable("data")
*
* val agg = data.groupBy($"key")
* .agg(
* countDistinct($"cat1").as("cat1_cnt"),
* countDistinct($"cat2").as("cat2_cnt"),
* sum($"value").as("total"))
* }}}
*
* This translates to the following (pseudo) logical plan:
* {{{
* Aggregate(
* key = ['key]
* functions = [COUNT(DISTINCT 'cat1),
* COUNT(DISTINCT 'cat2),
* sum('value)]
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
* LocalTableScan [...]
* }}}
*
* This rule rewrites this logical plan to the following (pseudo) logical plan:
* {{{
* Aggregate(
* key = ['key]
* functions = [count(if (('gid = 1)) 'cat1 else null),
* count(if (('gid = 2)) 'cat2 else null),
* first(if (('gid = 0)) 'total else null) ignore nulls]
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
* Aggregate(
* key = ['key, 'cat1, 'cat2, 'gid]
* functions = [sum('value)]
* output = ['key, 'cat1, 'cat2, 'gid, 'total])
* Expand(
* projections = [('key, null, null, 0, cast('value as bigint)),
* ('key, 'cat1, null, 1, null),
* ('key, null, 'cat2, 2, null)]
* output = ['key, 'cat1, 'cat2, 'gid, 'value])
* LocalTableScan [...]
* }}}
*
* The rule does the following things here:
* 1. Expand the data. There are three aggregation groups in this query:
* i. the non-distinct group;
* ii. the distinct 'cat1 group;
* iii. the distinct 'cat2 group.
* An expand operator is inserted to expand the child data for each group. The expand will null
* out all unused columns for the given group; this must be done in order to ensure correctness
* later on. Groups can by identified by a group id (gid) column added by the expand operator.
* 2. De-duplicate the distinct paths and aggregate the non-aggregate path. The group by clause of
* this aggregate consists of the original group by clause, all the requested distinct columns
* and the group id. Both de-duplication of distinct column and the aggregation of the
* non-distinct group take advantage of the fact that we group by the group id (gid) and that we
* have nulled out all non-relevant columns for the the given group.
* 3. Aggregating the distinct groups and combining this with the results of the non-distinct
* aggregation. In this step we use the group id to filter the inputs for the aggregate
* functions. The result of the non-distinct group are 'aggregated' by using the first operator,
* it might be more elegant to use the native UDAF merge mechanism for this in the future.
*
* This rule duplicates the input data by two or more times (# distinct groups + an optional
* non-distinct group). This will put quite a bit of memory pressure of the used aggregate and
* exchange operators. Keeping the number of distinct groups as low a possible should be priority,
* we could improve this in the current rule by applying more advanced expression cannocalization
* techniques.
*/
object MultipleDistinctRewriter extends Rule[LogicalPlan] {

Expand Down Expand Up @@ -261,11 +327,10 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
// Functions used to modify aggregate functions and their inputs.
def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
def patchAggregateFunctionChildren(
af: AggregateFunction2,
id: Literal,
attrs: Map[Expression, Expression]): AggregateFunction2 = {
af.withNewChildren(af.children.map { case afc =>
evalWithinGroup(id, attrs(afc))
af: AggregateFunction2)(
attrs: Expression => Expression): AggregateFunction2 = {
af.withNewChildren(af.children.map {
case afc => attrs(afc)
}).asInstanceOf[AggregateFunction2]
}

Expand All @@ -288,7 +353,9 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
// Final aggregate
val operators = expressions.map { e =>
val af = e.aggregateFunction
val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap)
val naf = patchAggregateFunctionChildren(af) { x =>
evalWithinGroup(id, distinctAggChildAttrMap(x))
}
(e, e.copy(aggregateFunction = naf, isDistinct = false))
}

Expand All @@ -304,26 +371,27 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
val regularGroupId = Literal(0)
val regularAggOperatorMap = regularAggExprs.map { e =>
// Perform the actual aggregation in the initial aggregate.
val af = patchAggregateFunctionChildren(
e.aggregateFunction,
regularGroupId,
regularAggChildAttrMap)
val a = Alias(e.copy(aggregateFunction = af), e.toString)()

// Get the result of the first aggregate in the last aggregate.
val b = AggregateExpression2(
aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)),
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrMap)
val operator = Alias(e.copy(aggregateFunction = af), e.toString)()

// Select the result of the first aggregate in the last aggregate.
val result = AggregateExpression2(
aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
mode = Complete,
isDistinct = false)

// Some aggregate functions (COUNT) have the special property that they can return a
// non-null result without any input. We need to make sure we return a result in this case.
val c = af.defaultResult match {
case Some(lit) => Coalesce(Seq(b, lit))
case None => b
val resultWithDefault = af.defaultResult match {
case Some(lit) => Coalesce(Seq(result, lit))
case None => result
}

(e, a, c)
// Return a Tuple3 containing:
// i. The original aggregate expression (used for look ups).
// ii. The actual aggregation operator (used in the first aggregate).
// iii. The operator that selects and returns the result (used in the second aggregate).
(e, operator, resultWithDefault)
}

// Construct the regular aggregate input projection only if we need one.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(3, 4, 4, 3, null) :: Nil)
}

test("multiple distinct column sets") {
checkAnswer(
sqlContext.sql(
"""
|SELECT
| key,
| count(distinct value1),
| count(distinct value2)
|FROM agg2
|GROUP BY key
""".stripMargin),
Row(null, 3, 3) ::
Row(1, 2, 3) ::
Row(2, 2, 1) ::
Row(3, 0, 1) :: Nil)
}

test("test count") {
checkAnswer(
sqlContext.sql(
Expand Down