diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 1b141572cc7f9..d7e8571f6ce43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -161,7 +161,10 @@ abstract class Optimizer(catalogManager: CatalogManager) // LocalRelation and does not trigger many rules. Batch("LocalRelation early", fixedPoint, ConvertToLocalRelation, - PropagateEmptyRelation) :: + PropagateEmptyRelation, + // PropagateEmptyRelation can change the nullability of an attribute from nullable to + // non-nullable when an empty relation child of a Union is removed + UpdateAttributeNullability) :: Batch("Pullup Correlated Expressions", Once, PullupCorrelatedPredicates) :: // Subquery batch applies the optimizer rules recursively. Therefore, it makes no sense @@ -198,7 +201,10 @@ abstract class Optimizer(catalogManager: CatalogManager) ReassignLambdaVariableID) :+ Batch("LocalRelation", fixedPoint, ConvertToLocalRelation, - PropagateEmptyRelation) :+ + PropagateEmptyRelation, + // PropagateEmptyRelation can change the nullability of an attribute from nullable to + // non-nullable when an empty relation child of a Union is removed + UpdateAttributeNullability) :+ // The following batch should be executed after batch "Join Reorder" and "LocalRelation". Batch("Check Cartesian Products", Once, CheckCartesianProducts) :+ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index b19e13870aa65..0299646150ff3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -50,8 +50,26 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper wit override def conf: SQLConf = SQLConf.get def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case p: Union if p.children.forall(isEmptyLocalRelation) => - empty(p) + case p @ Union(children) if children.exists(isEmptyLocalRelation) => + val newChildren = children.filterNot(isEmptyLocalRelation) + if (newChildren.isEmpty) { + empty(p) + } else { + val newPlan = if (newChildren.size > 1) Union(newChildren) else newChildren.head + val outputs = newPlan.output.zip(p.output) + // the original Union may produce different output attributes than the new one so we alias + // them if needed + if (outputs.forall { case (newAttr, oldAttr) => newAttr.exprId == oldAttr.exprId }) { + newPlan + } else { + val outputAliases = outputs.map { case (newAttr, oldAttr) => + val newExplicitMetadata = + if (oldAttr.metadata != newAttr.metadata) Some(oldAttr.metadata) else None + Alias(newAttr, oldAttr.name)(oldAttr.exprId, explicitMetadata = newExplicitMetadata) + } + Project(outputAliases, newPlan) + } + } // Joins on empty LocalRelations generated from streaming sources are not eliminated // as stateful streaming joins need to perform other state management operations other than diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 9c7d4c7d8d233..dc323d4e5c77c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StructType} class PropagateEmptyRelationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -55,6 +55,9 @@ class PropagateEmptyRelationSuite extends PlanTest { val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1))) val testRelation2 = LocalRelation.fromExternalRows(Seq('b.int), data = Seq(Row(1))) + val metadata = new MetadataBuilder().putLong("test", 1).build() + val testRelation3 = + LocalRelation.fromExternalRows(Seq('c.int.notNull.withMetadata(metadata)), data = Seq(Row(1))) test("propagate empty relation through Union") { val query = testRelation1 @@ -67,6 +70,39 @@ class PropagateEmptyRelationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("SPARK-32241: remove empty relation children from Union") { + val query = testRelation1.union(testRelation2.where(false)) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation1 + comparePlans(optimized, correctAnswer) + + val query2 = testRelation1.where(false).union(testRelation2) + val optimized2 = Optimize.execute(query2.analyze) + val correctAnswer2 = testRelation2.select('b.as('a)).analyze + comparePlans(optimized2, correctAnswer2) + + val query3 = testRelation1.union(testRelation2.where(false)).union(testRelation3) + val optimized3 = Optimize.execute(query3.analyze) + val correctAnswer3 = testRelation1.union(testRelation3) + comparePlans(optimized3, correctAnswer3) + + val query4 = testRelation1.where(false).union(testRelation2).union(testRelation3) + val optimized4 = Optimize.execute(query4.analyze) + val correctAnswer4 = testRelation2.union(testRelation3).select('b.as('a)).analyze + comparePlans(optimized4, correctAnswer4) + + // Nullability can change from nullable to non-nullable + val query5 = testRelation1.where(false).union(testRelation3) + val optimized5 = Optimize.execute(query5.analyze) + assert(query5.output.head.nullable, "Original output should be nullable") + assert(!optimized5.output.head.nullable, "New output should be non-nullable") + + // Keep metadata + val query6 = testRelation3.where(false).union(testRelation1) + val optimized6 = Optimize.execute(query6.analyze) + assert(optimized6.output.head.metadata == metadata, "New output should keep metadata") + } + test("propagate empty relation through Join") { // Testcases are tuples of (left predicate, right predicate, joinType, correct answer) // Note that `None` is used to compare with OptimizeWithoutPropagateEmptyRelation.