
Commit 3c80ed8

[SPARK-39748][SQL][SS][FOLLOWUP] Fix a bug on column stat in LogicalRDD on mismatching exprIDs
### What changes were proposed in this pull request?

This PR fixes a bug in #37161 (the bug is described in the section below) by making sure the output columns in LogicalRDD are always the same as the output columns of originLogicalPlan in LogicalRDD, which is needed to inherit the column stats.

### Why are the changes needed?

Stats for columns in originLogicalPlan refer to the columns in originLogicalPlan, which could be different from the columns in the output of LogicalRDD in terms of expression ID.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New UT.

Closes #37187 from HeartSaVioR/SPARK-39748-FOLLOWUP-2.

Authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
Parent: c05d0fd
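For context on the failure mode: Catalyst keys column statistics by an attribute's expression ID rather than its name, so stats recorded against originLogicalPlan's columns are invisible to output columns that carry different IDs. Below is a minimal sketch of that mismatch using Catalyst's internal API; the object name is invented for illustration and this is not code from the commit:

import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
import org.apache.spark.sql.types.BooleanType

object ExprIdMismatchSketch extends App {
  // Two attributes with the same name but different expression IDs.
  val original = AttributeReference("cbool", BooleanType)()
  val reissued = original.newInstance() // same name, fresh exprId

  // Column stats live in an AttributeMap, which resolves entries by exprId.
  val stats = AttributeMap(Seq(original -> ColumnStat(distinctCount = Some(2))))

  assert(stats.contains(original))  // hit: same exprId
  assert(!stats.contains(reissued)) // miss: the mismatch this PR guards against
}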

3 files changed (+126, −6 lines)

sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala

Lines changed: 11 additions & 1 deletion
@@ -116,10 +116,20 @@ case class LogicalRDD(
       case e: Attribute => rewrite.getOrElse(e, e)
     }.asInstanceOf[SortOrder])
 
+    val rewrittenOriginLogicalPlan = originLogicalPlan.map { plan =>
+      assert(output == plan.output, "The output columns are expected to be the same for output " +
+        s"and originLogicalPlan. output: $output / output in originLogicalPlan: ${plan.output}")
+
+      val projectList = output.map { attr =>
+        Alias(attr, attr.name)(exprId = rewrite(attr).exprId)
+      }
+      Project(projectList, plan)
+    }
+
     LogicalRDD(
       output.map(rewrite),
       rdd,
-      originLogicalPlan,
+      rewrittenOriginLogicalPlan,
       rewrittenPartitioning,
       rewrittenOrdering,
       isStreaming
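Why the added Project preserves stats lookups: wrapping each original column in an Alias pinned to the rewritten exprId yields a plan whose output IDs line up exactly with the rewritten LogicalRDD output, so stat lookups keyed by exprId keep resolving. A self-contained sketch of that trick follows; the plan and names are illustrative stand-ins, not the commit's code:

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types.IntegerType

object AliasRemapSketch extends App {
  val oldAttr = AttributeReference("c", IntegerType)()
  val newAttr = oldAttr.newInstance() // stands in for rewrite(attr)

  // Alias(..., name)(exprId = ...) pins the output attribute's exprId.
  val remapped = Project(
    Seq(Alias(oldAttr, oldAttr.name)(exprId = newAttr.exprId)),
    LocalRelation(oldAttr))

  // The projected output now carries the rewritten exprId.
  assert(remapped.output.map(_.exprId) == Seq(newAttr.exprId))
}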

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala

Lines changed: 5 additions & 2 deletions
@@ -30,10 +30,13 @@ class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: Expr
   override def addBatch(batchId: Long, data: DataFrame): Unit = {
     val rdd = data.queryExecution.toRdd
     val executedPlan = data.queryExecution.executedPlan
+    val analyzedPlanWithoutMarkerNode = eliminateWriteMarkerNode(data.queryExecution.analyzed)
+    // assertion on precondition
+    assert(data.logicalPlan.output == analyzedPlanWithoutMarkerNode.output)
     val node = LogicalRDD(
-      data.schema.toAttributes,
+      data.logicalPlan.output,
       rdd,
-      Some(eliminateWriteMarkerNode(data.queryExecution.analyzed)),
+      Some(analyzedPlanWithoutMarkerNode),
       executedPlan.outputPartitioning,
       executedPlan.outputOrdering)(data.sparkSession)
     implicit val enc = encoder
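The switch from data.schema.toAttributes to data.logicalPlan.output matters because converting a schema to attributes mints fresh expression IDs on every call, so such attributes can never match the analyzed plan's output. A small sketch of that behavior, assuming a Spark version of this era where StructType.toAttributes still exists:

import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object FreshExprIdsSketch extends App {
  val schema = StructType(Seq(StructField("c", IntegerType)))

  // Each conversion creates new AttributeReferences with fresh exprIds...
  val first = schema.toAttributes
  val second = schema.toAttributes

  // ...so two conversions of the same schema never share IDs.
  assert(first.head.exprId != second.head.exprId)
}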

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 110 additions & 3 deletions
@@ -32,13 +32,14 @@ import org.scalatest.matchers.should.Matchers._
 import org.apache.spark.SparkException
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
-import org.apache.spark.sql.catalyst.expressions.Uuid
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Uuid}
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LocalRelation, LogicalPlan, OneRowRelation, Statistics}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.FakeV2Provider
-import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{FilterExec, LogicalRDD, QueryExecution, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
@@ -2010,6 +2011,68 @@ class DataFrameSuite extends QueryTest
     }
   }
 
+  test("SPARK-39748: build the stats for LogicalRDD based on originLogicalPlan") {
+    def buildExpectedColumnStats(attrs: Seq[Attribute]): AttributeMap[ColumnStat] = {
+      AttributeMap(
+        attrs.map {
+          case attr if attr.dataType == BooleanType =>
+            attr -> ColumnStat(
+              distinctCount = Some(2),
+              min = Some(false),
+              max = Some(true),
+              nullCount = Some(0),
+              avgLen = Some(1),
+              maxLen = Some(1))
+
+          case attr if attr.dataType == ByteType =>
+            attr -> ColumnStat(
+              distinctCount = Some(2),
+              min = Some(1),
+              max = Some(2),
+              nullCount = Some(0),
+              avgLen = Some(1),
+              maxLen = Some(1))
+
+          case attr => attr -> ColumnStat()
+        }
+      )
+    }
+
+    val outputList = Seq(
+      AttributeReference("cbool", BooleanType)(),
+      AttributeReference("cbyte", ByteType)()
+    )
+
+    val expectedSize = 16
+    val statsPlan = OutputListAwareStatsTestPlan(
+      outputList = outputList,
+      rowCount = 2,
+      size = Some(expectedSize))
+
+    withSQLConf(SQLConf.CBO_ENABLED.key -> "true") {
+      val df = Dataset.ofRows(spark, statsPlan)
+
+      val logicalRDD = LogicalRDD(
+        df.logicalPlan.output, spark.sparkContext.emptyRDD, Some(df.queryExecution.analyzed),
+        isStreaming = true)(spark)
+
+      val stats = logicalRDD.computeStats()
+      val expectedStats = Statistics(sizeInBytes = expectedSize, rowCount = Some(2),
+        attributeStats = buildExpectedColumnStats(logicalRDD.output))
+      assert(stats === expectedStats)
+
+      // This method re-issues expression IDs for all outputs. We expect column stats to be
+      // reflected as well.
+      val newLogicalRDD = logicalRDD.newInstance()
+      val newStats = newLogicalRDD.computeStats()
+      // LogicalRDD.newInstance adds a projection to originLogicalPlan, which triggers estimation
+      // of sizeInBytes. We don't intend to check the estimated value.
+      val newExpectedStats = Statistics(sizeInBytes = newStats.sizeInBytes, rowCount = Some(2),
+        attributeStats = buildExpectedColumnStats(newLogicalRDD.output))
+      assert(newStats === newExpectedStats)
+    }
+  }
+
   test("SPARK-10656: completely support special chars") {
     val df = Seq(1 -> "a").toDF("i_$.a", "d^'a.")
     checkAnswer(df.select(df("*")), Row(1, "a"))
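To run only the new test locally, Spark's usual ScalaTest filtering should work, along the lines of build/sbt "sql/testOnly *DataFrameSuite -- -z SPARK-39748" (command shape per Spark's developer docs; adjust for your setup).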
@@ -3249,3 +3312,47 @@ class DataFrameSuite extends QueryTest
 case class GroupByKey(a: Int, b: Int)
 
 case class Bar2(s: String)
+
+/**
+ * This class is used for unit-testing. It's a logical plan whose output and stats are passed in.
+ */
+case class OutputListAwareStatsTestPlan(
+    outputList: Seq[Attribute],
+    rowCount: BigInt,
+    size: Option[BigInt] = None) extends LeafNode with MultiInstanceRelation {
+  override def output: Seq[Attribute] = outputList
+  override def computeStats(): Statistics = {
+    val columnInfo = outputList.map { attr =>
+      attr.dataType match {
+        case BooleanType =>
+          attr -> ColumnStat(
+            distinctCount = Some(2),
+            min = Some(false),
+            max = Some(true),
+            nullCount = Some(0),
+            avgLen = Some(1),
+            maxLen = Some(1))
+
+        case ByteType =>
+          attr -> ColumnStat(
+            distinctCount = Some(2),
+            min = Some(1),
+            max = Some(2),
+            nullCount = Some(0),
+            avgLen = Some(1),
+            maxLen = Some(1))
+
+        case _ =>
+          attr -> ColumnStat()
+      }
+    }
+    val attrStats = AttributeMap(columnInfo)
+
+    Statistics(
+      // If sizeInBytes is useless in testing, we just use a fake value
+      sizeInBytes = size.getOrElse(Int.MaxValue),
+      rowCount = Some(rowCount),
+      attributeStats = attrStats)
+  }
+  override def newInstance(): LogicalPlan = copy(outputList = outputList.map(_.newInstance()))
+}
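A note on the helper's design: mixing in MultiInstanceRelation means newInstance() re-issues expression IDs for the output, mirroring what Catalyst does when one relation appears multiple times in a plan (e.g. a self-join). That is exactly the exprId churn the fixed LogicalRDD.newInstance must survive, which is why the test above exercises it directly.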
