diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index bf9ef6991e3e..149e70e56d02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -116,10 +116,20 @@ case class LogicalRDD(
       case e: Attribute => rewrite.getOrElse(e, e)
     }.asInstanceOf[SortOrder])
 
+    val rewrittenOriginLogicalPlan = originLogicalPlan.map { plan =>
+      assert(output == plan.output, "The output columns are expected to be the same for output " +
+        s"and originLogicalPlan. output: $output / output in originLogicalPlan: ${plan.output}")
+
+      val projectList = output.map { attr =>
+        Alias(attr, attr.name)(exprId = rewrite(attr).exprId)
+      }
+      Project(projectList, plan)
+    }
+
     LogicalRDD(
       output.map(rewrite),
       rdd,
-      originLogicalPlan,
+      rewrittenOriginLogicalPlan,
       rewrittenPartitioning,
       rewrittenOrdering,
       isStreaming
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
index 1c6bca241af4..395ed056be28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
@@ -30,10 +30,13 @@ class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: Expr
   override def addBatch(batchId: Long, data: DataFrame): Unit = {
     val rdd = data.queryExecution.toRdd
     val executedPlan = data.queryExecution.executedPlan
+    val analyzedPlanWithoutMarkerNode = eliminateWriteMarkerNode(data.queryExecution.analyzed)
+    // assertion on precondition
+    assert(data.logicalPlan.output == analyzedPlanWithoutMarkerNode.output)
     val node = LogicalRDD(
-      data.schema.toAttributes,
+      data.logicalPlan.output,
       rdd,
-      Some(eliminateWriteMarkerNode(data.queryExecution.analyzed)),
+      Some(analyzedPlanWithoutMarkerNode),
       executedPlan.outputPartitioning,
       executedPlan.outputOrdering)(data.sparkSession)
     implicit val enc = encoder
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 41593c701a7f..e802159f2634 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -32,13 +32,14 @@ import org.scalatest.matchers.should.Matchers._
 import org.apache.spark.SparkException
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
-import org.apache.spark.sql.catalyst.expressions.Uuid
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Uuid}
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LocalRelation, LogicalPlan, OneRowRelation, Statistics}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.FakeV2Provider
-import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{FilterExec, LogicalRDD, QueryExecution, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
@@ -2010,6 +2011,68 @@ class DataFrameSuite extends QueryTest
     }
   }
 
+  test("SPARK-39748: build the stats for LogicalRDD based on originLogicalPlan") {
+    def buildExpectedColumnStats(attrs: Seq[Attribute]): AttributeMap[ColumnStat] = {
+      AttributeMap(
+        attrs.map {
+          case attr if attr.dataType == BooleanType =>
+            attr -> ColumnStat(
+              distinctCount = Some(2),
+              min = Some(false),
+              max = Some(true),
+              nullCount = Some(0),
+              avgLen = Some(1),
+              maxLen = Some(1))
+
+          case attr if attr.dataType == ByteType =>
+            attr -> ColumnStat(
+              distinctCount = Some(2),
+              min = Some(1),
+              max = Some(2),
+              nullCount = Some(0),
+              avgLen = Some(1),
+              maxLen = Some(1))
+
+          case attr => attr -> ColumnStat()
+        }
+      )
+    }
+
+    val outputList = Seq(
+      AttributeReference("cbool", BooleanType)(),
+      AttributeReference("cbyte", ByteType)()
+    )
+
+    val expectedSize = 16
+    val statsPlan = OutputListAwareStatsTestPlan(
+      outputList = outputList,
+      rowCount = 2,
+      size = Some(expectedSize))
+
+    withSQLConf(SQLConf.CBO_ENABLED.key -> "true") {
+      val df = Dataset.ofRows(spark, statsPlan)
+
+      val logicalRDD = LogicalRDD(
+        df.logicalPlan.output, spark.sparkContext.emptyRDD, Some(df.queryExecution.analyzed),
+        isStreaming = true)(spark)
+
+      val stats = logicalRDD.computeStats()
+      val expectedStats = Statistics(sizeInBytes = expectedSize, rowCount = Some(2),
+        attributeStats = buildExpectedColumnStats(logicalRDD.output))
+      assert(stats === expectedStats)
+
+      // This method re-issues expression IDs for all outputs. We expect column stats to be
+      // reflected as well.
+      val newLogicalRDD = logicalRDD.newInstance()
+      val newStats = newLogicalRDD.computeStats()
+      // LogicalRDD.newInstance adds a projection to originLogicalPlan, which triggers estimation
+      // of sizeInBytes. We don't intend to check the estimated value.
+      val newExpectedStats = Statistics(sizeInBytes = newStats.sizeInBytes, rowCount = Some(2),
+        attributeStats = buildExpectedColumnStats(newLogicalRDD.output))
+      assert(newStats === newExpectedStats)
+    }
+  }
+
   test("SPARK-10656: completely support special chars") {
     val df = Seq(1 -> "a").toDF("i_$.a", "d^'a.")
     checkAnswer(df.select(df("*")), Row(1, "a"))
@@ -3249,3 +3312,47 @@ class DataFrameSuite extends QueryTest
 case class GroupByKey(a: Int, b: Int)
 
 case class Bar2(s: String)
+
+/**
+ * This class is used for unit-testing. It's a logical plan whose output and stats are passed in.
+ */
+case class OutputListAwareStatsTestPlan(
+    outputList: Seq[Attribute],
+    rowCount: BigInt,
+    size: Option[BigInt] = None) extends LeafNode with MultiInstanceRelation {
+  override def output: Seq[Attribute] = outputList
+  override def computeStats(): Statistics = {
+    val columnInfo = outputList.map { attr =>
+      attr.dataType match {
+        case BooleanType =>
+          attr -> ColumnStat(
+            distinctCount = Some(2),
+            min = Some(false),
+            max = Some(true),
+            nullCount = Some(0),
+            avgLen = Some(1),
+            maxLen = Some(1))
+
+        case ByteType =>
+          attr -> ColumnStat(
+            distinctCount = Some(2),
+            min = Some(1),
+            max = Some(2),
+            nullCount = Some(0),
+            avgLen = Some(1),
+            maxLen = Some(1))
+
+        case _ =>
+          attr -> ColumnStat()
+      }
+    }
+    val attrStats = AttributeMap(columnInfo)
+
+    Statistics(
+      // If sizeInBytes is useless in testing, we just use a fake value
+      sizeInBytes = size.getOrElse(Int.MaxValue),
+      rowCount = Some(rowCount),
+      attributeStats = attrStats)
+  }
+  override def newInstance(): LogicalPlan = copy(outputList = outputList.map(_.newInstance()))
+}
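
Not part of the patch above: a minimal sketch of how the change can be observed from user code, assuming a local SparkSession and the built-in rate source; the object name ForeachBatchStatsSketch and all variable names are illustrative only. Because ForeachBatchSink now passes the analyzed plan (minus the write marker node) as originLogicalPlan, the batch DataFrame handed to foreachBatch is backed by a LogicalRDD whose statistics can be derived from that plan instead of defaulting to an unknown size.

// Sketch only; names are hypothetical and not taken from the patch.
import org.apache.spark.sql.{DataFrame, SparkSession}

object ForeachBatchStatsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("foreach-batch-stats-sketch")
      .getOrCreate()

    // Streaming source with columns (timestamp, value).
    val stream = spark.readStream.format("rate").load()

    val query = stream.writeStream
      .foreachBatch { (batchDf: DataFrame, batchId: Long) =>
        // The batch DataFrame is backed by a LogicalRDD; with originLogicalPlan set,
        // its reported statistics can reflect the analyzed streaming plan.
        println(s"batch $batchId stats: ${batchDf.queryExecution.optimizedPlan.stats}")
        batchDf.groupBy("value").count().show()
      }
      .start()

    query.processAllAvailable()
    query.stop()
    spark.stop()
  }
}

Printing batchDf.queryExecution.optimizedPlan.stats per micro-batch is one way to confirm which statistics the optimizer sees inside foreachBatch.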