SPARK-29213: FilterExec should not throw NPE

basicPhysicalOperators.scala (FilterExec): while generating code for the
predicates in otherPreds, a referenced attribute r whose exprId is in
notNullAttributes, and that falls through the preceding notNullPreds branch
(that branch's condition lies outside these hunks), now gets an explicit
IsNotNull(r) check generated against child.output. extraIsNotNullAttrs
remembers the attributes already handled, so this extra check is generated
at most once per attribute.

SQLQuerySuite.scala: adds a regression test that runs a LEFT JOIN against a
filtered UNION ALL subquery; the test only requires collect() to complete.

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index b072a7f5d914..ca0cfb6834f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
 
 import java.util.concurrent.TimeUnit._
 
+import scala.collection.mutable
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration.Duration
 
@@ -171,6 +172,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
     // This is very perf sensitive.
     // TODO: revisit this. We can consider reordering predicates as well.
     val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length)
+    val extraIsNotNullAttrs = mutable.Set[Attribute]()
     val generated = otherPreds.map { c =>
       val nullChecks = c.references.map { r =>
         val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)}
@@ -178,6 +180,9 @@ case class FilterExec(condition: Expression, child: SparkPlan)
           generatedIsNotNullChecks(idx) = true
           // Use the child's output. The nullability is what the child produced.
           genPredicate(notNullPreds(idx), input, child.output)
+        } else if (notNullAttributes.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) {
+          extraIsNotNullAttrs += r
+          genPredicate(IsNotNull(r), input, child.output)
         } else {
           ""
         }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index b5d021549c7a..676e10fe59dc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -3233,6 +3233,32 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-29213: FilterExec should not throw NPE") {
+    withTempView("t1", "t2", "t3") {
+      sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t1")
+      sql("SELECT * FROM VALUES 0, CAST(NULL AS BIGINT)")
+        .as[java.lang.Long]
+        .map(identity)
+        .toDF("x")
+        .createOrReplaceTempView("t2")
+      sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t3")
+      sql(
+        """
+          |SELECT t1.x
+          |FROM t1
+          |LEFT JOIN (
+          | SELECT x FROM (
+          | SELECT x FROM t2
+          | UNION ALL
+          | SELECT SUBSTR(x,5) x FROM t3
+          | ) a
+          | WHERE LENGTH(x)>0
+          |) t3
+          |ON t1.x=t3.x
+        """.stripMargin).collect()
+    }
+  }
 }
 
 case class Foo(bar: Option[String])