diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a2dced57c7153..fe15819bd44a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -264,6 +264,7 @@ abstract class Optimizer(catalogManager: CatalogManager) CheckCartesianProducts), Batch("RewriteSubquery", Once, RewritePredicateSubquery, + NullPropagation, PushPredicateThroughJoin, LimitPushDown, ColumnPruning, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 885512d4d1980..dd3bd0a9e96d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -1236,13 +1236,13 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString) { // positive not in subquery case var joinExec = assertJoin(( - "select * from testData where key not in (select a from testData2)", + "select * from testData where key not in (select b from testData3)", classOf[BroadcastHashJoinExec])) assert(joinExec.asInstanceOf[BroadcastHashJoinExec].isNullAwareAntiJoin) // negative not in subquery case since multi-column is not supported assertJoin(( - "select * from testData where (key, key + 1) not in (select * from testData2)", + "select * from testData where (key, key + 1) not in (select b, b + 1 from testData3)", classOf[BroadcastNestedLoopJoinExec])) // positive hand-written left anti join @@ -1271,6 +1271,23 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan } } + test("SPARK-54972: Improve not in subqueries with non-nullable columns") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString) { + // testData.key nullable false + // testData2.* nullable false + + val joinExec = assertJoin(( + "select * from testData where key not in (select a from testData2)", + classOf[BroadcastHashJoinExec])) + assert(!joinExec.asInstanceOf[BroadcastHashJoinExec].isNullAwareAntiJoin) + + val joinExec2 = assertJoin(( + "select * from testData where (key, key + 1) not in (select * from testData2)", + classOf[BroadcastHashJoinExec])) + assert(!joinExec2.asInstanceOf[BroadcastHashJoinExec].isNullAwareAntiJoin) + } + } + test("SPARK-32399: Full outer shuffled hash join") { val inputDFs = Seq( // Test unique join key