@@ -300,6 +300,16 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
*/
object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case a @ Aggregate(_, _, e @ Expand(projects, output, child))
[Review comment — Contributor]
To summarize my understanding for other reviewers: this new rule handles the case where an Expand sits beneath an Aggregate and the Expand produces rows with columns that are not referenced by the Aggregate operator. In this case, we want to rewrite the Expand's projections in order to eliminate the unreferenced columns.

if (e.outputSet -- a.references).nonEmpty =>
val newOutput = output.filter(a.references.contains(_))
val newProjects = projects.map { proj =>
proj.zip(output).filter { case (e, a) =>
[Review comment — Contributor]
The projection has as many expressions as there are output fields, which is why this zip is safe. The goal of this block is to keep only the expressions whose output columns are referenced by the aggregate.
newOutput.contains(a)
}.unzip._1
[Review comment — Contributor]
I think this is equivalent to .map(_._1), which seems marginally easier to understand (but not much so).
}
a.copy(child = Expand(newProjects, newOutput, child))
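The pruning performed by this first case can be modeled on plain Scala collections. The sketch below is a hypothetical standalone illustration, not Catalyst code: the `Expand` case class, `pruneExpand`, and the column names are invented for the example, and strings stand in for expressions and attributes.

```scala
// Standalone model of the rule: an Expand is a list of projection rows plus
// an output schema. Pruning keeps only the output columns the parent
// references, and drops the corresponding expression from every projection.
case class Expand(projects: Seq[Seq[String]], output: Seq[String])

def pruneExpand(e: Expand, referenced: Set[String]): Expand = {
  val newOutput = e.output.filter(referenced.contains)
  val newProjects = e.projects.map { proj =>
    // Each projection row has one expression per output field, so zipping
    // against the schema is safe; filter keeps only the referenced slots.
    proj.zip(e.output).filter { case (_, attr) => newOutput.contains(attr) }.map(_._1)
  }
  Expand(newProjects, newOutput)
}

val e = Expand(
  projects = Seq(Seq("a", "b", "c", "null", "1"), Seq("a", "b", "c", "a", "2")),
  output   = Seq("a", "b", "c", "aa", "gid"))

// Suppose the aggregate references only c, aa, and gid:
val pruned = pruneExpand(e, Set("c", "aa", "gid"))
// pruned.output   == Seq("c", "aa", "gid")
// pruned.projects == Seq(Seq("c", "null", "1"), Seq("c", "a", "2"))
```

As in the real rule, the final `.map(_._1)` here is interchangeable with `.unzip._1`, which is the reviewer's point below the original code.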

case a @ Aggregate(_, _, e @ Expand(_, _, child))
[Review comment — Contributor]
This rule is slightly different: it prunes columns on the Expand's child if they are not referenced by the Aggregate or by the Expand.

if (child.outputSet -- e.references -- a.references).nonEmpty =>
a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references)))
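The effect of this second case can likewise be modeled on plain collections. The sketch below is a hypothetical standalone illustration of what inserting a pruning Project below the Expand achieves; the helper name mirrors the `prunedChild` used in the rule, but the signature and string-based columns are invented for the example.

```scala
// Standalone model: if the child emits columns that neither the Expand nor
// the Aggregate references, narrow the child's output (in Catalyst this is
// done by wrapping the child in a Project over the referenced attributes).
def prunedChild(childOutput: Seq[String], references: Set[String]): Seq[String] =
  if ((childOutput.toSet -- references).nonEmpty)
    childOutput.filter(references.contains) // models Project(kept, child)
  else
    childOutput                             // child is already minimal

// Expand and Aggregate together reference only c and a, so b is pruned:
val kept = prunedChild(Seq("a", "b", "c"), Set("c", "a"))
// kept == Seq("a", "c")
```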
@@ -19,9 +19,9 @@ package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.Explode
+import org.apache.spark.sql.catalyst.expressions.{Explode, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.StringType

@@ -96,5 +96,34 @@ class ColumnPruningSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("Column pruning for Expand") {
val input = LocalRelation('a.int, 'b.string, 'c.double)
val query =
Aggregate(
Seq('aa, 'gid),
Seq(sum('c).as("sum")),
Expand(
Seq(
Seq('a, 'b, 'c, Literal.create(null, StringType), 1),
Seq('a, 'b, 'c, 'a, 2)),
Seq('a, 'b, 'c, 'aa.int, 'gid.int),
input)).analyze
val optimized = Optimize.execute(query)

val expected =
Aggregate(
Seq('aa, 'gid),
Seq(sum('c).as("sum")),
Expand(
Seq(
Seq('c, Literal.create(null, StringType), 1),
Seq('c, 'a, 2)),
Seq('c, 'aa.int, 'gid.int),
Project(Seq('c, 'a),
input))).analyze

comparePlans(optimized, expected)
}
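One detail of the expected plan worth spelling out: 'b is pruned entirely, but 'a survives in the Project below the Expand even though the Aggregate never mentions it, because the pruned Expand's second projection still uses 'a to populate the 'aa slot. A quick standalone check of the surviving attribute references (strings stand in for attributes; literals are excluded):

```scala
// Attribute references remaining in the pruned Expand's projections:
val prunedProjects = Seq(Seq("c"), Seq("c", "a")) // literals omitted
val expandRefs = prunedProjects.flatten.toSet
// expandRefs == Set("c", "a"), so Project(Seq('c, 'a), input) is exactly
// the narrowest child that still satisfies the Expand; 'b is dropped.
```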

// todo: add more tests for column pruning
}