From 240ae7a976f03ce58bcdeed34f8022ebfea23a4c Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 2 Apr 2018 17:07:52 +0900 Subject: [PATCH] Fix --- .../sql/catalyst/analysis/Analyzer.scala | 6 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 4 +- .../UnsupportedOperationChecker.scala | 2 +- .../sql/catalyst/expressions/predicates.scala | 6 ++ .../optimizer/CostBasedJoinReorder.scala | 12 +-- .../sql/catalyst/optimizer/Optimizer.scala | 79 +++++++++++++------ .../optimizer/PropagateEmptyRelation.scala | 2 +- .../sql/catalyst/optimizer/expressions.scala | 2 +- .../spark/sql/catalyst/optimizer/joins.scala | 2 +- .../sql/catalyst/optimizer/subquery.scala | 2 +- .../sql/catalyst/planning/patterns.scala | 10 +-- .../plans/logical/basicLogicalOperators.scala | 18 +++-- .../optimizer/JoinOptimizationSuite.scala | 2 +- ...ullabilityInAttributeReferencesSuite.scala | 56 ++++++++++--- .../spark/sql/catalyst/plans/PlanTest.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 10 +-- .../execution/basicPhysicalOperators.scala | 6 -- .../apache/spark/sql/DataFrameJoinSuite.scala | 10 +-- 18 files changed, 148 insertions(+), 83 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e821e96522f7..04f7849d98ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -856,7 +856,7 @@ class Analyzer( failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF") // To resolve duplicate expression IDs for Join and Intersect - case j @ Join(left, right, _, _) if !j.duplicateResolved => + case j @ Join(left, right, _, _, _) if !j.duplicateResolved => j.copy(right = dedupRight(left, right)) case i @ Intersect(left, right) if !i.duplicateResolved => i.copy(right = dedupRight(left, right)) @@ -2087,10 +2087,10 @@ class Analyzer( */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { - case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) + case j @ Join(left, right, UsingJoin(joinType, usingCols), _, _) if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) - case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural => + case j @ Join(left, right, NaturalJoin(joinType), condition, _) if j.resolvedExceptNatural => // find common column names from both sides val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) commonNaturalJoinProcessing(left, right, joinType, joinNames, condition) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 90bda2a72ad8..e2af2283af31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -147,7 +147,7 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis("Null-aware predicate sub-queries cannot be used in nested " + s"conditions: $condition") - case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => + case j @ Join(_, _, _, Some(condition), _) if condition.dataType != 
BooleanType => failAnalysis( s"join condition '${condition.sql}' " + s"of type ${condition.dataType.simpleString} is not a boolean.") @@ -583,7 +583,7 @@ trait CheckAnalysis extends PredicateHelper { failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) // Join can host correlated expressions. - case j @ Join(left, right, joinType, _) => + case j @ Join(left, right, joinType, _, _) => joinType match { // Inner join, like Filter, can be anywhere. case _: InnerLike => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index ff9d6d7a7dde..b620fd803b5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -228,7 +228,7 @@ object UnsupportedOperationChecker { throwError("dropDuplicates is not supported after aggregation on a " + "streaming DataFrame/Dataset") - case Join(left, right, joinType, condition) => + case Join(left, right, joinType, condition, _) => joinType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index e195ec17f3bc..033106726fed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -63,6 +63,12 @@ trait PredicateHelper { } } + // If one expression and its children are null intolerant, it is null intolerant. + protected def isNullIntolerant(expr: Expression): Boolean = expr match { + case e: NullIntolerant => e.children.forall(isNullIntolerant) + case _ => false + } + // Substitute any known alias from a map. protected def replaceAlias( condition: Expression, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 064ca68b7a62..d391aa882397 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -42,9 +42,9 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { } else { val result = plan transformDown { // Start reordering with a joinable item, which is an InnerLike join with conditions. 
- case j @ Join(_, _, _: InnerLike, Some(cond)) => + case j @ Join(_, _, _: InnerLike, Some(cond), _) => reorder(j, j.output) - case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond))) + case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), _)) if projectList.forall(_.isInstanceOf[Attribute]) => reorder(p, p.output) } @@ -76,12 +76,12 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { */ private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = { plan match { - case Join(left, right, _: InnerLike, Some(cond)) => + case Join(left, right, _: InnerLike, Some(cond), _) => val (leftPlans, leftConditions) = extractInnerJoins(left) val (rightPlans, rightConditions) = extractInnerJoins(right) (leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++ leftConditions ++ rightConditions) - case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) + case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) if projectList.forall(_.isInstanceOf[Attribute]) => extractInnerJoins(j) case _ => @@ -90,11 +90,11 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { } private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match { - case j @ Join(left, right, jt: InnerLike, Some(cond)) => + case j @ Join(left, right, jt: InnerLike, Some(cond), _) => val replacedLeft = replaceWithOrderedJoin(left) val replacedRight = replaceWithOrderedJoin(right) OrderedJoin(replacedLeft, replacedRight, jt, Some(cond)) - case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) => + case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) => p.copy(child = replaceWithOrderedJoin(j)) case _ => plan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 45f13956a0a8..7cd977cb63ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -254,7 +254,7 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { // not allowed to use the same attributes. We use a blacklist to prevent us from creating a // situation in which this happens; the rule will only remove an alias if its child // attribute is not on the black list. - case Join(left, right, joinType, condition) => + case Join(left, right, joinType, condition, notNullAttrs) => val newLeft = removeRedundantAliases(left, blacklist ++ right.outputSet) val newRight = removeRedundantAliases(right, blacklist ++ newLeft.outputSet) val mapping = AttributeMap( @@ -263,7 +263,7 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { val newCondition = condition.map(_.transform { case a: Attribute => mapping.getOrElse(a, a) }) - Join(newLeft, newRight, joinType, newCondition) + Join(newLeft, newRight, joinType, newCondition, notNullAttrs) case _ => // Remove redundant aliases in the subtree(s). @@ -354,7 +354,7 @@ object LimitPushDown extends Rule[LogicalPlan] { // on both sides if it is applied multiple times. Therefore: // - If one side is already limited, stack another limit on top if the new limit is smaller. // The redundant limit will be collapsed by the CombineLimits rule. 
- case LocalLimit(exp, join @ Join(left, right, joinType, _)) => + case LocalLimit(exp, join @ Join(left, right, joinType, _, _)) => val newJoin = joinType match { case RightOuter => join.copy(right = maybePushLocalLimit(exp, right)) case LeftOuter => join.copy(left = maybePushLocalLimit(exp, left)) @@ -468,7 +468,7 @@ object ColumnPruning extends Rule[LogicalPlan] { p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices)) // Eliminate unneeded attributes from right side of a Left Existence Join. - case j @ Join(_, right, LeftExistence(_), _) => + case j @ Join(_, right, LeftExistence(_), _, _) => j.copy(right = prunedChild(right, j.references)) // all the columns will be used to compare, so we can't prune them @@ -661,27 +661,38 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] filter } - case join @ Join(left, right, joinType, conditionOpt) => + case join @ Join(left, right, joinType, conditionOpt, notNullAttrs) => joinType match { // For inner join, we can infer additional filters for both sides. LeftSemi is kind of an // inner join, it just drops the right side in the final output. case _: InnerLike | LeftSemi => val allConstraints = getAllConstraints(left, right, conditionOpt) - val newLeft = inferNewFilter(left, allConstraints) - val newRight = inferNewFilter(right, allConstraints) - join.copy(left = newLeft, right = newRight) + val newLeftPredicates = inferNewPredicate(left, allConstraints) + val newRightPredicates = inferNewPredicate(right, allConstraints) + val newNotNullAttrs = getNotNullAttributes( + newLeftPredicates ++ newRightPredicates, notNullAttrs) + join.copy( + left = addFilterIfNeeded(left, newLeftPredicates), + right = addFilterIfNeeded(right, newRightPredicates), + notNullAttributes = newNotNullAttrs) // For right outer join, we can only infer additional filters for left side. case RightOuter => val allConstraints = getAllConstraints(left, right, conditionOpt) - val newLeft = inferNewFilter(left, allConstraints) - join.copy(left = newLeft) + val newLeftPredicates = inferNewPredicate(left, allConstraints) + val newNotNullAttrs = getNotNullAttributes(newLeftPredicates, notNullAttrs) + join.copy( + left = addFilterIfNeeded(left, newLeftPredicates), + notNullAttributes = newNotNullAttrs) // For left join, we can only infer additional filters for right side. 
case LeftOuter | LeftAnti => val allConstraints = getAllConstraints(left, right, conditionOpt) - val newRight = inferNewFilter(right, allConstraints) - join.copy(right = newRight) + val newRightPredicates = inferNewPredicate(right, allConstraints) + val newNotNullAttrs = getNotNullAttributes(newRightPredicates, notNullAttrs) + join.copy( + right = addFilterIfNeeded(right, newRightPredicates), + notNullAttributes = newNotNullAttrs) case _ => join } @@ -696,16 +707,32 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] baseConstraints.union(inferAdditionalConstraints(baseConstraints)) } - private def inferNewFilter(plan: LogicalPlan, constraints: Set[Expression]): LogicalPlan = { - val newPredicates = constraints + private def inferNewPredicate( + plan: LogicalPlan, constraints: Set[Expression]): Set[Expression] = { + constraints .union(constructIsNotNullConstraints(constraints, plan.output)) .filter { c => c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic } -- plan.constraints - if (newPredicates.isEmpty) { - plan - } else { + } + + private def getNotNullAttributes( + constraints: Set[Expression], + curNotNullAttrs: Set[ExprId]): Set[ExprId] = { + + // Split out all the IsNotNulls from the `constraints` + val (notNullPreds, _) = constraints.partition { + case IsNotNull(a) => isNullIntolerant(a) + case _ => false + } + notNullPreds.flatMap(_.references.map(_.exprId)) ++ curNotNullAttrs + } + + private def addFilterIfNeeded(plan: LogicalPlan, newPredicates: Set[Expression]): LogicalPlan = { + if (newPredicates.nonEmpty) { Filter(newPredicates.reduce(And), plan) + } else { + plan } } } @@ -1048,7 +1075,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // push the where condition down into join filter - case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) => + case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition, notNullAttrs)) => val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = split(splitConjunctivePredicates(filterCondition), left, right) joinType match { @@ -1062,7 +1089,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { commonFilterCondition.partition(canEvaluateWithinJoin) val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And) - val join = Join(newLeft, newRight, joinType, newJoinCond) + val join = Join(newLeft, newRight, joinType, newJoinCond, notNullAttrs) if (others.nonEmpty) { Filter(others.reduceLeft(And), join) } else { @@ -1074,7 +1101,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newRight = rightFilterConditions. reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = joinCondition - val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond) + val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond, notNullAttrs) (leftFilterConditions ++ commonFilterCondition). reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) @@ -1084,7 +1111,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) val newRight = right val newJoinCond = joinCondition - val newJoin = Join(newLeft, newRight, joinType, newJoinCond) + val newJoin = Join(newLeft, newRight, joinType, newJoinCond, notNullAttrs) (rightFilterConditions ++ commonFilterCondition). 
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) @@ -1094,7 +1121,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { } // push down the join filter into sub query scanning if applicable - case j @ Join(left, right, joinType, joinCondition) => + case j @ Join(left, right, joinType, joinCondition, notNullAttrs) => val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) @@ -1107,7 +1134,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = commonJoinCondition.reduceLeftOption(And) - Join(newLeft, newRight, joinType, newJoinCond) + Join(newLeft, newRight, joinType, newJoinCond, notNullAttrs) case RightOuter => // push down the left side only join filter for left side sub query val newLeft = leftJoinConditions. @@ -1115,7 +1142,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newRight = right val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And) - Join(newLeft, newRight, RightOuter, newJoinCond) + Join(newLeft, newRight, RightOuter, newJoinCond, notNullAttrs) case LeftOuter | LeftAnti | ExistenceJoin(_) => // push down the right side only join filter for right sub query val newLeft = left @@ -1123,7 +1150,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And) - Join(newLeft, newRight, joinType, newJoinCond) + Join(newLeft, newRight, joinType, newJoinCond, notNullAttrs) case FullOuter => j case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") case UsingJoin(_, _) => sys.error("Untransformed Using join node") @@ -1179,7 +1206,7 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { if (SQLConf.get.crossJoinEnabled) { plan } else plan transform { - case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _) + case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _, _) if isCartesianProduct(j) => throw new AnalysisException( s"""Detected cartesian product for ${j.joinType.sql} join between logical plans diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index c3fdb924243d..b19e13870aa6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -56,7 +56,7 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper wit // Joins on empty LocalRelations generated from streaming sources are not eliminated // as stateful streaming joins need to perform other state management operations other than // just processing the input data. 
- case p @ Join(_, _, joinType, _) + case p @ Join(_, _, joinType, _, _) if !p.children.exists(_.isStreaming) => val isLeftEmpty = isEmptyLocalRelation(p.left) val isRightEmpty = isEmptyLocalRelation(p.right) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1c0b7bd80680..1094f84ef256 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -544,7 +544,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { // propagating the foldable expressions. // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes // of outer join. - case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty => + case j @ Join(left, right, joinType, _, _) if foldableMap.nonEmpty => val newJoin = j.transformExpressions(replaceFoldable) val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match { case _: InnerLike | LeftExistence(_) => Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index edbeaf273fd6..999ceec99505 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -147,7 +147,7 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _)) => + case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _, _)) => val newJoinType = buildNewJoinType(f, j) if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 709db6d8bec7..673168fc1064 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -54,7 +54,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // the produced join then becomes unresolved and break structural integrity. We should // de-duplicate conflicting attributes. We don't use transformation here because we only // care about the most top join converted from correlated predicate subquery. 
- case j @ Join(left, right, joinType @ (LeftSemi | LeftAnti | ExistenceJoin(_)), joinCond) => + case j @ Join(left, right, joinType @ (LeftSemi | LeftAnti | ExistenceJoin(_)), joinCond, _) => val duplicates = right.outputSet.intersect(left.outputSet) if (duplicates.nonEmpty) { val aliasMap = AttributeMap(duplicates.map { dup => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 626f90570719..cd04e02148ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -102,7 +102,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { (JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan) def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case join @ Join(left, right, joinType, condition) => + case join @ Join(left, right, joinType, condition, _) => logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. @@ -165,11 +165,11 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { */ def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner) : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match { - case Join(left, right, joinType: InnerLike, cond) => + case Join(left, right, joinType: InnerLike, cond, _) => val (plans, conditions) = flattenJoin(left, joinType) (plans ++ Seq((right, joinType)), conditions ++ cond.toSeq.flatMap(splitConjunctivePredicates)) - case Filter(filterCondition, j @ Join(left, right, _: InnerLike, joinCondition)) => + case Filter(filterCondition, j @ Join(left, right, _: InnerLike, joinCondition, _)) => val (plans, conditions) = flattenJoin(j) (plans, conditions ++ splitConjunctivePredicates(filterCondition)) @@ -178,9 +178,9 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { def unapply(plan: LogicalPlan): Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])] = plan match { - case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _)) => + case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _, _)) => Some(flattenJoin(f)) - case j @ Join(_, _, joinType, _) => + case j @ Join(_, _, joinType, _, _) => Some(flattenJoin(j)) case _ => None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 10df50479543..e65b88861b51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -280,23 +280,29 @@ case class Join( left: LogicalPlan, right: LogicalPlan, joinType: JoinType, - condition: Option[Expression]) + condition: Option[Expression], + notNullAttributes: Set[ExprId] = Set.empty) extends BinaryNode with PredicateHelper { + private def updateNullabilityOf(attrs: Seq[Attribute]): Seq[Attribute] = attrs.map { + case a if a.nullable && notNullAttributes.contains(a.exprId) => a.withNullability(false) + case a => a + } + override def output: Seq[Attribute] = { joinType match { case j: ExistenceJoin => - left.output :+ j.exists + updateNullabilityOf(left.output :+ 
j.exists) case LeftExistence(_) => - left.output + updateNullabilityOf(left.output) case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) + updateNullabilityOf(left.output) ++ right.output.map(_.withNullability(true)) case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output + left.output.map(_.withNullability(true)) ++ updateNullabilityOf(right.output) case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) case _ => - left.output ++ right.output + updateNullabilityOf(left.output ++ right.output) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index ccd9d8dd4d21..f88cb176aed8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -141,7 +141,7 @@ class JoinOptimizationSuite extends PlanTest { comparePlans(optimized, expected) val broadcastChildren = optimized.collect { - case Join(_, r, _, _) if r.stats.sizeInBytes == 1 => r + case Join(_, r, _, _, _) if r.stats.sizeInBytes == 1 => r } assert(broadcastChildren.size == 1) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala index 09b11f5aba2a..916e9786c59f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala @@ -19,28 +19,37 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{CreateArray, GetArrayItem} -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateArray, GetArrayItem} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor class UpdateNullabilityInAttributeReferencesSuite extends PlanTest { - object Optimizer extends RuleExecutor[LogicalPlan] { + object Optimizer1 extends RuleExecutor[LogicalPlan] { val batches = Batch("Constant Folding", FixedPoint(10), - NullPropagation, - ConstantFolding, - BooleanSimplification, - SimplifyConditionals, - SimplifyBinaryComparison, - SimplifyExtractValueOps) :: + NullPropagation, + ConstantFolding, + BooleanSimplification, + SimplifyConditionals, + SimplifyBinaryComparison, + SimplifyExtractValueOps) :: Batch("UpdateAttributeReferences", Once, UpdateNullabilityInAttributeReferences) :: Nil } - test("update nullability in AttributeReference") { + object Optimizer2 extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Predicate Handling", FixedPoint(10), + PushPredicateThroughJoin, + InferFiltersFromConstraints) :: + Batch("UpdateAttributeReferences", Once, + UpdateNullabilityInAttributeReferences) :: Nil + } + + test("update nullability in 
AttributeReference") { val rel = LocalRelation('a.long.notNull) // In the 'original' plans below, the Aggregate node produced by groupBy() has a // nullable AttributeReference to `b`, because both array indexing and map lookup are @@ -51,7 +60,30 @@ class UpdateNullabilityInAttributeReferencesSuite extends PlanTest { .select(GetArrayItem(CreateArray(Seq('a, 'a + 1L)), 0) as "b") .groupBy($"b")("1") val expected = rel.select('a as "b").groupBy($"b")("1").analyze - val optimized = Optimizer.execute(original.analyze) + val optimized = Optimizer1.execute(original.analyze) + comparePlans(optimized, expected) + } + + test("SPARK-XXXXX update nullability in Join output") { + val r1 = LocalRelation('k1.int, 'v1.int) + val r2 = LocalRelation('k2.int, 'v2.int) + val joined = r1 + .join(r2, Inner, Some($"k1" === $"k2")) + .where($"v1" + $"v2" > 0) + .select($"v1", $"v2") + val optimized = Optimizer2.execute(joined.analyze) + val expected = r1 + .where($"k1".isNotNull && $"v1".isNotNull) + .join(r2.where($"k2".isNotNull && $"v2".isNotNull), Inner, + Some($"k1" === $"k2" && $"v1" + $"v2" > 0)) + .select($"v1", $"v2") + .analyze match { + // `projectList` should have the not-null `v1` and `v2` + case p @ Project(projectList, _) => + p.copy(projectList = projectList.map { + case a: Attribute => a.withNullability(false) + }) + } comparePlans(optimized, expected) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 6241d5cbb1d2..863f62447190 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -78,7 +78,7 @@ trait PlanTestBase extends PredicateHelper { self: Suite => .reduce(And), child) case sample: Sample => sample.copy(seed = 0L) - case Join(left, right, joinType, condition) if condition.isDefined => + case Join(left, right, joinType, condition, _) if condition.isDefined => val newCondition = splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(_.hashCode()) .reduce(And) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 82b4eb9fba24..edcc5462c8ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -279,23 +279,23 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- Without joining keys ------------------------------------------------------------ // Pick BroadcastNestedLoopJoin if one side could be broadcast - case j @ logical.Join(left, right, joinType, condition) + case j @ logical.Join(left, right, joinType, condition, _) if canBroadcastByHints(joinType, left, right) => val buildSide = broadcastSideByHints(joinType, left, right) joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil - case j @ logical.Join(left, right, joinType, condition) + case j @ logical.Join(left, right, joinType, condition, _) if canBroadcastBySizes(joinType, left, right) => val buildSide = broadcastSideBySizes(joinType, left, right) joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil // Pick CartesianProduct for InnerJoin - case logical.Join(left, right, _: InnerLike, condition) => + 
case logical.Join(left, right, _: InnerLike, condition, _) => joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil - case logical.Join(left, right, joinType, condition) => + case logical.Join(left, right, joinType, condition, _) => val buildSide = broadcastSide( left.stats.hints.broadcast, right.stats.hints.broadcast, left, right) // This join could be very slow or OOM @@ -359,7 +359,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { new StreamingSymmetricHashJoinExec( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil - case Join(left, right, _, _) if left.isStreaming && right.isStreaming => + case Join(left, right, _, _, _) if left.isStreaming && right.isStreaming => throw new AnalysisException( "Stream stream joins without equality predicate is not supported", plan = Some(plan)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 1edfdc888afd..470ce2870ea5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -91,12 +91,6 @@ case class FilterExec(condition: Expression, child: SparkPlan) case _ => false } - // If one expression and its children are null intolerant, it is null intolerant. - private def isNullIntolerant(expr: Expression): Boolean = expr match { - case e: NullIntolerant => e.children.forall(isNullIntolerant) - case _ => false - } - // The columns that will filtered out by `IsNotNull` could be considered as not nullable. private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 0d9eeabb397a..792dfee7c250 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -198,7 +198,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { // outer -> left val outerJoin2Left = df.join(df2, $"a.int" === $"b.int", "outer").where($"a.int" === 3) assert(outerJoin2Left.queryExecution.optimizedPlan.collect { - case j @ Join(_, _, LeftOuter, _) => j }.size === 1) + case j @ Join(_, _, LeftOuter, _, _) => j }.size === 1) checkAnswer( outerJoin2Left, Row(3, 4, "3", null, null, null) :: Nil) @@ -206,7 +206,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { // outer -> right val outerJoin2Right = df.join(df2, $"a.int" === $"b.int", "outer").where($"b.int" === 5) assert(outerJoin2Right.queryExecution.optimizedPlan.collect { - case j @ Join(_, _, RightOuter, _) => j }.size === 1) + case j @ Join(_, _, RightOuter, _, _) => j }.size === 1) checkAnswer( outerJoin2Right, Row(null, null, null, 5, 6, "5") :: Nil) @@ -215,7 +215,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { val outerJoin2Inner = df.join(df2, $"a.int" === $"b.int", "outer"). 
where($"a.int" === 1 && $"b.int2" === 3) assert(outerJoin2Inner.queryExecution.optimizedPlan.collect { - case j @ Join(_, _, Inner, _) => j }.size === 1) + case j @ Join(_, _, Inner, _, _) => j }.size === 1) checkAnswer( outerJoin2Inner, Row(1, 2, "1", 1, 3, "1") :: Nil) @@ -223,7 +223,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { // right -> inner val rightJoin2Inner = df.join(df2, $"a.int" === $"b.int", "right").where($"a.int" === 1) assert(rightJoin2Inner.queryExecution.optimizedPlan.collect { - case j @ Join(_, _, Inner, _) => j }.size === 1) + case j @ Join(_, _, Inner, _, _) => j }.size === 1) checkAnswer( rightJoin2Inner, Row(1, 2, "1", 1, 3, "1") :: Nil) @@ -231,7 +231,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { // left -> inner val leftJoin2Inner = df.join(df2, $"a.int" === $"b.int", "left").where($"b.int2" === 3) assert(leftJoin2Inner.queryExecution.optimizedPlan.collect { - case j @ Join(_, _, Inner, _) => j }.size === 1) + case j @ Join(_, _, Inner, _, _) => j }.size === 1) checkAnswer( leftJoin2Inner, Row(1, 2, "1", 1, 3, "1") :: Nil)