@@ -32,13 +32,14 @@ import org.scalatest.matchers.should.Matchers._
3232import org .apache .spark .SparkException
3333import org .apache .spark .scheduler .{SparkListener , SparkListenerJobEnd }
3434import org .apache .spark .sql .catalyst .TableIdentifier
35+ import org .apache .spark .sql .catalyst .analysis .MultiInstanceRelation
3536import org .apache .spark .sql .catalyst .encoders .{ExpressionEncoder , RowEncoder }
36- import org .apache .spark .sql .catalyst .expressions .Uuid
37+ import org .apache .spark .sql .catalyst .expressions .{ Attribute , AttributeMap , AttributeReference , Uuid }
3738import org .apache .spark .sql .catalyst .optimizer .ConvertToLocalRelation
38- import org .apache .spark .sql .catalyst .plans .logical .{LocalRelation , OneRowRelation }
39+ import org .apache .spark .sql .catalyst .plans .logical .{ColumnStat , LeafNode , LocalRelation , LogicalPlan , OneRowRelation , Statistics }
3940import org .apache .spark .sql .catalyst .util .DateTimeUtils
4041import org .apache .spark .sql .connector .FakeV2Provider
41- import org .apache .spark .sql .execution .{FilterExec , QueryExecution , WholeStageCodegenExec }
42+ import org .apache .spark .sql .execution .{FilterExec , LogicalRDD , QueryExecution , WholeStageCodegenExec }
4243import org .apache .spark .sql .execution .adaptive .AdaptiveSparkPlanHelper
4344import org .apache .spark .sql .execution .aggregate .HashAggregateExec
4445import org .apache .spark .sql .execution .exchange .{BroadcastExchangeExec , ReusedExchangeExec , ShuffleExchangeExec }
@@ -2010,6 +2011,68 @@ class DataFrameSuite extends QueryTest
20102011 }
20112012 }
20122013
2014+ test(" SPARK-39748: build the stats for LogicalRDD based on originLogicalPlan" ) {
2015+ def buildExpectedColumnStats (attrs : Seq [Attribute ]): AttributeMap [ColumnStat ] = {
2016+ AttributeMap (
2017+ attrs.map {
2018+ case attr if attr.dataType == BooleanType =>
2019+ attr -> ColumnStat (
2020+ distinctCount = Some (2 ),
2021+ min = Some (false ),
2022+ max = Some (true ),
2023+ nullCount = Some (0 ),
2024+ avgLen = Some (1 ),
2025+ maxLen = Some (1 ))
2026+
2027+ case attr if attr.dataType == ByteType =>
2028+ attr -> ColumnStat (
2029+ distinctCount = Some (2 ),
2030+ min = Some (1 ),
2031+ max = Some (2 ),
2032+ nullCount = Some (0 ),
2033+ avgLen = Some (1 ),
2034+ maxLen = Some (1 ))
2035+
2036+ case attr => attr -> ColumnStat ()
2037+ }
2038+ )
2039+ }
2040+
2041+ val outputList = Seq (
2042+ AttributeReference (" cbool" , BooleanType )(),
2043+ AttributeReference (" cbyte" , BooleanType )()
2044+ )
2045+
2046+ val expectedSize = 16
2047+ val statsPlan = OutputListAwareStatsTestPlan (
2048+ outputList = outputList,
2049+ rowCount = 2 ,
2050+ size = Some (expectedSize))
2051+
2052+ withSQLConf(SQLConf .CBO_ENABLED .key -> " true" ) {
2053+ val df = Dataset .ofRows(spark, statsPlan)
2054+
2055+ val logicalRDD = LogicalRDD (
2056+ df.logicalPlan.output, spark.sparkContext.emptyRDD, Some (df.queryExecution.analyzed),
2057+ isStreaming = true )(spark)
2058+
2059+ val stats = logicalRDD.computeStats()
2060+ val expectedStats = Statistics (sizeInBytes = expectedSize, rowCount = Some (2 ),
2061+ attributeStats = buildExpectedColumnStats(logicalRDD.output))
2062+ assert(stats === expectedStats)
2063+
2064+ // This method re-issues expression IDs for all outputs. We expect column stats to be
2065+ // reflected as well.
2066+ val newLogicalRDD = logicalRDD.newInstance()
2067+ val newStats = newLogicalRDD.computeStats()
2068+ // LogicalRDD.newInstance adds projection to originLogicalPlan, which triggers estimation
2069+ // on sizeInBytes. We don't intend to check the estimated value.
2070+ val newExpectedStats = Statistics (sizeInBytes = newStats.sizeInBytes, rowCount = Some (2 ),
2071+ attributeStats = buildExpectedColumnStats(newLogicalRDD.output))
2072+ assert(newStats === newExpectedStats)
2073+ }
2074+ }
2075+
20132076 test(" SPARK-10656: completely support special chars" ) {
20142077 val df = Seq (1 -> " a" ).toDF(" i_$.a" , " d^'a." )
20152078 checkAnswer(df.select(df(" *" )), Row (1 , " a" ))
@@ -3249,3 +3312,47 @@ class DataFrameSuite extends QueryTest
32493312case class GroupByKey (a : Int , b : Int )
32503313
32513314case class Bar2 (s : String )
3315+
3316+ /**
3317+ * This class is used for unit-testing. It's a logical plan whose output and stats are passed in.
3318+ */
3319+ case class OutputListAwareStatsTestPlan (
3320+ outputList : Seq [Attribute ],
3321+ rowCount : BigInt ,
3322+ size : Option [BigInt ] = None ) extends LeafNode with MultiInstanceRelation {
3323+ override def output : Seq [Attribute ] = outputList
3324+ override def computeStats (): Statistics = {
3325+ val columnInfo = outputList.map { attr =>
3326+ attr.dataType match {
3327+ case BooleanType =>
3328+ attr -> ColumnStat (
3329+ distinctCount = Some (2 ),
3330+ min = Some (false ),
3331+ max = Some (true ),
3332+ nullCount = Some (0 ),
3333+ avgLen = Some (1 ),
3334+ maxLen = Some (1 ))
3335+
3336+ case ByteType =>
3337+ attr -> ColumnStat (
3338+ distinctCount = Some (2 ),
3339+ min = Some (1 ),
3340+ max = Some (2 ),
3341+ nullCount = Some (0 ),
3342+ avgLen = Some (1 ),
3343+ maxLen = Some (1 ))
3344+
3345+ case _ =>
3346+ attr -> ColumnStat ()
3347+ }
3348+ }
3349+ val attrStats = AttributeMap (columnInfo)
3350+
3351+ Statistics (
3352+ // If sizeInBytes is useless in testing, we just use a fake value
3353+ sizeInBytes = size.getOrElse(Int .MaxValue ),
3354+ rowCount = Some (rowCount),
3355+ attributeStats = attrStats)
3356+ }
3357+ override def newInstance (): LogicalPlan = copy(outputList = outputList.map(_.newInstance()))
3358+ }
0 commit comments