Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.CastSupport
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf

/**
* Collapse plans consisting empty local relations generated by [[PruneFilters]].
Expand All @@ -32,7 +34,7 @@ import org.apache.spark.sql.catalyst.rules._
* - Aggregate with all empty children and at least one grouping expression.
* - Generate(Explode) with all empty children. Others like Hive UDTF may return results.
*/
object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper with CastSupport {
private def isEmptyLocalRelation(plan: LogicalPlan): Boolean = plan match {
case p: LocalRelation => p.data.isEmpty
case _ => false
Expand All @@ -43,7 +45,9 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {

// Construct a project list from plan's output, while the value is always NULL.
private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] =
plan.output.map{ a => Alias(Literal(null), a.name)(a.exprId) }
plan.output.map{ a => Alias(cast(Literal(null), a.dataType), a.name)(a.exprId) }

override def conf: SQLConf = SQLConf.get

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p: Union if p.children.forall(isEmptyLocalRelation) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{IntegerType, StructType}

class PropagateEmptyRelationSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
Expand All @@ -37,7 +37,8 @@ class PropagateEmptyRelationSuite extends PlanTest {
ReplaceIntersectWithSemiJoin,
PushDownPredicate,
PruneFilters,
PropagateEmptyRelation) :: Nil
PropagateEmptyRelation,
CollapseProject) :: Nil
}

object OptimizeWithoutPropagateEmptyRelation extends RuleExecutor[LogicalPlan] {
Expand All @@ -48,7 +49,8 @@ class PropagateEmptyRelationSuite extends PlanTest {
ReplaceExceptWithAntiJoin,
ReplaceIntersectWithSemiJoin,
PushDownPredicate,
PruneFilters) :: Nil
PruneFilters,
CollapseProject) :: Nil
}

val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1)))
Expand Down Expand Up @@ -79,18 +81,21 @@ class PropagateEmptyRelationSuite extends PlanTest {

(true, false, Inner, Some(LocalRelation('a.int, 'b.int))),
(true, false, Cross, Some(LocalRelation('a.int, 'b.int))),
(true, false, LeftOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)),
(true, false, LeftOuter,
Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)),
(true, false, RightOuter, Some(LocalRelation('a.int, 'b.int))),
(true, false, FullOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)),
(true, false, FullOuter,
Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)),
(true, false, LeftAnti, Some(testRelation1)),
(true, false, LeftSemi, Some(LocalRelation('a.int))),

(false, true, Inner, Some(LocalRelation('a.int, 'b.int))),
(false, true, Cross, Some(LocalRelation('a.int, 'b.int))),
(false, true, LeftOuter, Some(LocalRelation('a.int, 'b.int))),
(false, true, RightOuter,
Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)),
(false, true, FullOuter, Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)),
Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)),
(false, true, FullOuter,
Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)),
(false, true, LeftAnti, Some(LocalRelation('a.int))),
(false, true, LeftSemi, Some(LocalRelation('a.int))),

Expand Down Expand Up @@ -209,4 +214,11 @@ class PropagateEmptyRelationSuite extends PlanTest {

comparePlans(optimized, correctAnswer)
}

test("propagate empty relation keeps the plan resolved") {
val query = testRelation1.join(
LocalRelation('a.int, 'b.int), UsingJoin(FullOuter, "a" :: Nil), None)
val optimized = Optimize.execute(query.analyze)
assert(optimized.resolved)
}
}