diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 6148fb30783e8..06085497de19a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -360,23 +360,25 @@ case class OneRowRelationExec() extends LeafExecNode
   override val output: Seq[Attribute] = Nil
 
   private val rdd: RDD[InternalRow] = {
-    val numOutputRows = longMetric("numOutputRows")
     session
       .sparkContext
       .parallelize(Seq(""), 1)
       .mapPartitionsInternal { _ =>
         val proj = UnsafeProjection.create(Seq.empty[Expression])
-        Iterator(proj.apply(InternalRow.empty)).map { r =>
-          numOutputRows += 1
-          r
-        }
+        Iterator(proj.apply(InternalRow.empty))
       }
   }
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
-  protected override def doExecute(): RDD[InternalRow] = rdd
+  protected override def doExecute(): RDD[InternalRow] = {
+    val numOutputRows = longMetric("numOutputRows")
+    rdd.map { r =>
+      numOutputRows += 1
+      r
+    }
+  }
 
   override def simpleString(maxFields: Int): String = s"$nodeName[]"
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 47efcaaa4e1b4..58457091a4a42 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -107,6 +107,25 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
     }
   }
 
+  test("SPARK-54749: OneRowRelation metrics") {
+    Seq((1L, "false"), (2L, "true")).foreach { case (nodeId, enableWholeStage) =>
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> enableWholeStage) {
+        val df = spark.sql("select 1 as c1")
+        val oneRowRelation = df.queryExecution.executedPlan.collect {
+          case oneRowRelation: OneRowRelationExec => oneRowRelation
+        }
+        df.collect()
+        sparkContext.listenerBus.waitUntilEmpty()
+        assert(oneRowRelation.size == 1)
+
+        val expected = Map("number of output rows" -> 1L)
+        testSparkPlanMetrics(df.toDF(), 1, Map(
+          nodeId -> (("Scan OneRowRelation", expected))))
+      }
+    }
+  }
+
+
   test("Recursive CTEs metrics") {
     withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "") {
       val df = sql(