diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
index ef22c0ab44e4..d8b3cb1e685f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering
+import org.apache.spark.sql.execution.LogicalRDD
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
@@ -184,6 +185,8 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
   private def calculatePlanOverhead(plan: LogicalPlan): Float = {
     val (cached, notCached) = plan.collectLeaves().partition(p => p match {
       case _: InMemoryRelation => true
+      case _: LocalRelation => true
+      case l: LogicalRDD if isLogicalRDDWithStats(l) => true
       case _ => false
     })
     val scanOverhead = notCached.map(_.stats.sizeInBytes).sum.toFloat
@@ -195,10 +198,21 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
         m.stats.sizeInBytes.toFloat * 0.2
       case m: InMemoryRelation if m.cacheBuilder.storageLevel.useMemory =>
         0.0
+      case _: LocalRelation => 0.0
+      case l: LogicalRDD if isLogicalRDDWithStats(l) => 0.0
     }.sum.toFloat
     scanOverhead + cachedOverhead
   }
 
+  /**
+   * Check whether a LogicalRDD has actual statistics (indicating materialized data)
+   * rather than just the default size estimate. A LogicalRDD with rowCount stats
+   * indicates that the data was already computed and stats were collected.
+   */
+  private def isLogicalRDDWithStats(rdd: LogicalRDD): Boolean = {
+    rdd.stats.rowCount.isDefined
+  }
+
   /**
    * Search a filtering predicate in a given logical plan
    */
@@ -206,6 +220,8 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
   private def hasSelectivePredicate(plan: LogicalPlan): Boolean = {
     plan.exists {
       case f: Filter => isLikelySelective(f.condition)
+      case _: LocalRelation => true
+      case _: LogicalRDD => true
       case _ => false
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
index e1a2fd33c7c9..7336eac34ca1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -1821,6 +1821,175 @@ class DynamicPartitionPruningV1SuiteAEOn extends DynamicPartitionPruningV1Suite
       checkAnswer(df, Row(1000, 1) :: Row(1010, 2) :: Row(1020, 2) :: Nil)
     }
   }
+
+  test("SPARK-54593: DPP with LocalRelation in broadcast join") {
+    withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
+      SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
+      // Create a LocalRelation using a VALUES clause
+      val filterDF = sql("SELECT * FROM VALUES (1), (9) AS t(store_id)")
+      filterDF.createOrReplaceTempView("small_stores")
+
+      // Join the partitioned table with the LocalRelation
+      val df = sql(
+        """
+          |SELECT f.date_id, f.product_id, f.store_id, f.units_sold
+          |FROM fact_stats f
+          |JOIN small_stores s ON f.store_id = s.store_id
+        """.stripMargin)
+
+      checkPartitionPruningPredicate(df, false, true)
+
+      // Only 2 rows in fact_stats have store_id 1 or 9
+      checkAnswer(df,
+        Row(1000, 1, 1, 10) ::
+        Row(1150, 1, 9, 20) :: Nil
+      )
+
+      // Verify DPP predicates exist in the optimized logical plan
+      val optimizedPlan = df.queryExecution.optimizedPlan.toString()
+      assert(optimizedPlan.contains("DynamicPruningSubquery") ||
+        optimizedPlan.contains("dynamicpruning"),
+        s"Optimized plan should contain DynamicPruningSubquery:\n$optimizedPlan")
+    }
+  }
+
+  test("SPARK-54593: DPP with LogicalRDD in broadcast join") {
+    withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
+      SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
+      // Create a LogicalRDD from an RDD. DPP should be applied since a LogicalRDD
+      // represents materialized data that is typically small.
+      import testImplicits._
+      val smallStores = Seq((1, "Store1"), (9, "Store9")).toDF("store_id", "store_name")
+      val collectedDF = spark.createDataFrame(
+        smallStores.rdd,
+        smallStores.schema
+      )
+      collectedDF.createOrReplaceTempView("small_stores_rdd")
+
+      // Join the partitioned table with the LogicalRDD
+      val df = sql(
+        """
+          |SELECT f.date_id, f.product_id, f.store_id, f.units_sold
+          |FROM fact_stats f
+          |JOIN small_stores_rdd s ON f.store_id = s.store_id
+        """.stripMargin)
+
+      // Broadcast DPP should be applied
+      checkPartitionPruningPredicate(df, false, true)
+
+      // Only 2 rows in fact_stats have store_id 1 or 9
+      checkAnswer(df,
+        Row(1000, 1, 1, 10) ::
+        Row(1150, 1, 9, 20) :: Nil
+      )
+
+      // Verify DPP predicates exist in the optimized logical plan
+      val optimizedPlan = df.queryExecution.optimizedPlan.toString()
+      assert(optimizedPlan.contains("DynamicPruningSubquery") ||
+        optimizedPlan.contains("dynamicpruning"),
+        s"Optimized plan should contain DynamicPruningSubquery:\n$optimizedPlan")
+    }
+  }
+
+  test("SPARK-54593: DPP with empty LocalRelation should not fail") {
+    withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
+      SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
+      // Create an empty LocalRelation
+      val emptyDF = sql("SELECT * FROM VALUES (1) AS t(store_id) WHERE store_id > 100")
+      emptyDF.createOrReplaceTempView("empty_stores")
+
+      // The join should return no results but must not fail
+      val df = sql(
+        """
+          |SELECT f.date_id, f.product_id, f.store_id, f.units_sold
+          |FROM fact_stats f
+          |JOIN empty_stores s ON f.store_id = s.store_id
+        """.stripMargin)
+
+      checkAnswer(df, Nil)
+
+      // Verify DPP predicates exist in the optimized logical plan
+      val optimizedPlan = df.queryExecution.optimizedPlan.toString()
+      assert(optimizedPlan.contains("DynamicPruningSubquery") ||
+        optimizedPlan.contains("dynamicpruning"),
+        s"Optimized plan should contain DynamicPruningSubquery:\n$optimizedPlan")
+    }
+  }
+
+  test("SPARK-54593: DPP should not trigger for LogicalRDD without originStats") {
+    withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
+      SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
+      // Create an RDD without stats (using parallelize directly)
+      import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+      import org.apache.spark.sql.Row
+
+      val rdd = spark.sparkContext.parallelize(Seq(Row(1, "Store1"), Row(9, "Store9")))
+      val schema = StructType(Seq(
+        StructField("store_id", IntegerType, nullable = false),
+        StructField("store_name", StringType, nullable = true)
+      ))
+      val rddDF = spark.createDataFrame(rdd, schema)
+      rddDF.createOrReplaceTempView("rdd_no_stats")
+
+      // The query should still work, but DPP should not trigger since originStats is missing
+      val df = sql(
+        """
+          |SELECT f.date_id, f.product_id, f.store_id, f.units_sold
+          |FROM fact_stats f
+          |JOIN rdd_no_stats s ON f.store_id = s.store_id
+        """.stripMargin)
+
+      // Should still produce correct results (only 2 rows match store_ids 1 and 9)
+      checkAnswer(df,
+        Row(1000, 1, 1, 10) ::
+        Row(1150, 1, 9, 20) :: Nil
+      )
+
+      // Verify no DPP predicates exist in the optimized logical plan
+      val optimizedPlan = df.queryExecution.optimizedPlan.toString()
+      assert(!optimizedPlan.contains("DynamicPruningSubquery") &&
+        !optimizedPlan.contains("dynamicpruning"),
+        s"Optimized plan should not contain DynamicPruningSubquery:\n$optimizedPlan")
+    }
+  }
+
+  test("SPARK-54593: DPP with large LocalRelation should still work") {
+    withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
+      SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
+      // Create a larger LocalRelation (all store_ids from 1 to 5)
+      val largeDF = sql(
+        "SELECT * FROM VALUES (1), (2), (3), (4), (5) AS t(store_id)")
+      largeDF.createOrReplaceTempView("many_stores")
+
+      val df = sql(
+        """
+          |SELECT f.date_id, f.store_id
+          |FROM fact_stats f
+          |JOIN many_stores s ON f.store_id = s.store_id
+        """.stripMargin)
+
+      checkPartitionPruningPredicate(df, false, true)
+
+      // Exact expected results for store_ids 1-5
+      checkAnswer(df,
+        Row(1000, 1) :: Row(1010, 1) :: Row(1020, 1) :: Row(1150, 1) :: Row(1160, 1) ::
+        Row(1170, 1) :: Row(1280, 1) :: Row(1290, 1) :: Row(1300, 1) ::
+        Row(1030, 2) :: Row(1040, 2) :: Row(1050, 2) :: Row(1060, 2) :: Row(1070, 2) ::
+        Row(1180, 2) :: Row(1190, 2) ::
+        Row(1080, 3) :: Row(1090, 3) :: Row(1100, 3) :: Row(1110, 3) :: Row(1200, 3) ::
+        Row(1200, 3) ::
+        Row(1120, 4) :: Row(1130, 4) :: Row(1140, 4) :: Row(1210, 4) :: Row(1220, 4) ::
+        Row(1230, 4) ::
+        Row(1240, 5) :: Row(1250, 5) :: Row(1260, 5) :: Row(1270, 5) :: Nil
+      )
+
+      // Verify DPP predicates exist in the optimized logical plan
+      val optimizedPlan = df.queryExecution.optimizedPlan.toString()
+      assert(optimizedPlan.contains("DynamicPruningSubquery") ||
+        optimizedPlan.contains("dynamicpruning"),
+        s"Optimized plan should contain DynamicPruningSubquery:\n$optimizedPlan")
+    }
+  }
 }
 
 abstract class DynamicPartitionPruningV2Suite extends DynamicPartitionPruningDataSourceSuiteBase {
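
A minimal, self-contained sketch (not part of the patch) of how the new behavior can be observed from a plain SparkSession, outside the test suite. The table and view names (demo_fact, small_stores) are illustrative assumptions; the two string config keys are the ones behind SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED and SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.

import org.apache.spark.sql.SparkSession

object DppLocalRelationDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("dpp-localrelation-demo")
      .config("spark.sql.optimizer.dynamicPartitionPruning.enabled", "true")
      .config("spark.sql.optimizer.dynamicPartitionPruning.reuseBroadcastOnly", "true")
      .getOrCreate()

    // A small fact table, partitioned by store_id, so DPP has partitions to prune.
    // demo_fact is a hypothetical table for this sketch only.
    spark.range(0, 100)
      .selectExpr("id AS date_id", "CAST(id % 10 AS INT) AS store_id")
      .write.partitionBy("store_id").mode("overwrite").format("parquet")
      .saveAsTable("demo_fact")

    // The dimension side is a LocalRelation (VALUES clause). Before this patch it
    // carried no selective Filter, so hasSelectivePredicate rejected it and DPP
    // never kicked in; the patch accepts LocalRelation directly.
    spark.sql("SELECT * FROM VALUES (1), (9) AS t(store_id)")
      .createOrReplaceTempView("small_stores")

    val df = spark.sql(
      "SELECT f.date_id FROM demo_fact f JOIN small_stores s ON f.store_id = s.store_id")

    // With the patch applied, the optimized plan should carry a DynamicPruningSubquery
    // on demo_fact's store_id partition column.
    println(df.queryExecution.optimizedPlan)
    spark.stop()
  }
}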