diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 87932e09a1d3..760ead42c762 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -31,6 +31,7 @@ import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleComp
 import org.codehaus.janino.util.ClassFile
 
 import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException}
+import org.apache.spark.executor.InputMetrics
 import org.apache.spark.internal.Logging
 import org.apache.spark.metrics.source.CodegenMetrics
 import org.apache.spark.sql.catalyst.InternalRow
@@ -933,7 +934,8 @@ object CodeGenerator extends Logging {
       classOf[UnsafeMapData].getName,
       classOf[Expression].getName,
       classOf[TaskContext].getName,
-      classOf[TaskKilledException].getName
+      classOf[TaskKilledException].getName,
+      classOf[InputMetrics].getName
     ))
     evaluator.setExtendedClass(classOf[GeneratedClass])
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 649c21b29467..6255cff24dc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -365,6 +365,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
 
     val taskContext = ctx.freshName("taskContext")
     ctx.addMutableState("TaskContext", taskContext, s"$taskContext = TaskContext.get();")
+    val inputMetrics = ctx.freshName("inputMetrics")
+    ctx.addMutableState("InputMetrics", inputMetrics,
+      s"$inputMetrics = $taskContext.taskMetrics().inputMetrics();")
 
     // In order to periodically update the metrics without inflicting performance penalty, this
     // operator produces elements in batches. After a batch is complete, the metrics are updated
@@ -460,7 +463,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
       |     if ($nextBatchTodo == 0) break;
       |   }
       |   $numOutput.add($nextBatchTodo);
-      |   $numGenerated.add($nextBatchTodo);
+      |   $inputMetrics.incRecordsRead($nextBatchTodo);
       |
       |   $batchEnd += $nextBatchTodo * ${step}L;
       | }
@@ -469,7 +472,6 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-    val numGeneratedRows = longMetric("numGeneratedRows")
     sqlContext
       .sparkContext
       .parallelize(0 until numSlices, numSlices)
@@ -488,10 +490,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
       val safePartitionEnd = getSafeMargin(partitionEnd)
       val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize
       val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1)
+      val taskContext = TaskContext.get()
 
       val iter = new Iterator[InternalRow] {
         private[this] var number: Long = safePartitionStart
         private[this] var overflow: Boolean = false
+        private[this] val inputMetrics = taskContext.taskMetrics().inputMetrics
 
         override def hasNext =
           if (!overflow) {
@@ -513,12 +517,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
           }
 
           numOutputRows += 1
-          numGeneratedRows += 1
+          inputMetrics.incRecordsRead(1)
           unsafeRow.setLong(0, ret)
           unsafeRow
         }
       }
-      new InterruptibleIterator(TaskContext.get(), iter)
+      new InterruptibleIterator(taskContext, iter)
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
deleted file mode 100644
index ddd7a03e8003..000000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
+++ /dev/null
@@ -1,131 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import java.io.File
-
-import org.scalatest.concurrent.Eventually
-
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
-import org.apache.spark.sql.{DataFrame, QueryTest}
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.util.Utils
-
-class InputGeneratedOutputMetricsSuite extends QueryTest with SharedSQLContext with Eventually {
-
-  test("Range query input/output/generated metrics") {
-    val numRows = 150L
-    val numSelectedRows = 100L
-    val res = MetricsTestHelper.runAndGetMetrics(spark.range(0, numRows, 1).
-      filter(x => x < numSelectedRows).toDF())
-
-    assert(res.recordsRead.sum === 0)
-    assert(res.shuffleRecordsRead.sum === 0)
-    assert(res.generatedRows === numRows :: Nil)
-    assert(res.outputRows === numSelectedRows :: numRows :: Nil)
-  }
-
-  test("Input/output/generated metrics with repartitioning") {
-    val numRows = 100L
-    val res = MetricsTestHelper.runAndGetMetrics(
-      spark.range(0, numRows).repartition(3).filter(x => x % 5 == 0).toDF())
-
-    assert(res.recordsRead.sum === 0)
-    assert(res.shuffleRecordsRead.sum === numRows)
-    assert(res.generatedRows === numRows :: Nil)
-    assert(res.outputRows === 20 :: numRows :: Nil)
-  }
-
-  test("Input/output/generated metrics with more repartitioning") {
-    withTempDir { tempDir =>
-      val dir = new File(tempDir, "pqS").getCanonicalPath
-
-      spark.range(10).write.parquet(dir)
-      spark.read.parquet(dir).createOrReplaceTempView("pqS")
-
-      val res = MetricsTestHelper.runAndGetMetrics(
-        spark.range(0, 30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2)
-          .toDF()
-      )
-
-      assert(res.recordsRead.sum == 10)
-      assert(res.shuffleRecordsRead.sum == 3 * 10 + 2 * 150)
-      assert(res.generatedRows == 30 :: Nil)
-      assert(res.outputRows == 10 :: 30 :: 300 :: Nil)
-    }
-  }
-}
-
-object MetricsTestHelper {
-  case class AggregatedMetricsResult(
-      recordsRead: List[Long],
-      shuffleRecordsRead: List[Long],
-      generatedRows: List[Long],
-      outputRows: List[Long])
-
-  private[this] def extractMetricValues(
-      df: DataFrame,
-      metricValues: Map[Long, String],
-      metricName: String): List[Long] = {
-    df.queryExecution.executedPlan.collect {
-      case plan if plan.metrics.contains(metricName) =>
-        metricValues(plan.metrics(metricName).id).toLong
-    }.toList.sorted
-  }
-
-  def runAndGetMetrics(df: DataFrame, useWholeStageCodeGen: Boolean = false):
-      AggregatedMetricsResult = {
-    val spark = df.sparkSession
-    val sparkContext = spark.sparkContext
-
-    var recordsRead = List[Long]()
-    var shuffleRecordsRead = List[Long]()
-    val listener = new SparkListener() {
-      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
-        if (taskEnd.taskMetrics != null) {
-          recordsRead = taskEnd.taskMetrics.inputMetrics.recordsRead ::
-            recordsRead
-          shuffleRecordsRead = taskEnd.taskMetrics.shuffleReadMetrics.recordsRead ::
-            shuffleRecordsRead
-        }
-      }
-    }
-
-    val oldExecutionIds = spark.sharedState.listener.executionIdToData.keySet
-
-    val prevUseWholeStageCodeGen =
-      spark.sessionState.conf.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED)
-    try {
-      spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, useWholeStageCodeGen)
-      sparkContext.listenerBus.waitUntilEmpty(10000)
-      sparkContext.addSparkListener(listener)
-      df.collect()
-      sparkContext.listenerBus.waitUntilEmpty(10000)
-    } finally {
-      spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, prevUseWholeStageCodeGen)
-    }
-
-    val executionId = spark.sharedState.listener.executionIdToData.keySet.diff(oldExecutionIds).head
-    val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId)
-    val outputRes = extractMetricValues(df, metricValues, "numOutputRows")
-    val generatedRes = extractMetricValues(df, metricValues, "numGeneratedRows")
-
-    AggregatedMetricsResult(recordsRead.sorted, shuffleRecordsRead.sorted, generatedRes, outputRes)
-  }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 229d8814e014..2ce7db6a22c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -17,7 +17,12 @@
 
 package org.apache.spark.sql.execution.metric
 
+import java.io.File
+
+import scala.collection.mutable.HashMap
+
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.execution.SparkPlanInfo
@@ -309,4 +314,103 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
     assert(metricInfoDeser.metadata === Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER))
   }
 
+  test("range metrics") {
+    val res1 = InputOutputMetricsHelper.run(
+      spark.range(30).filter(x => x % 3 == 0).toDF()
+    )
+    assert(res1 === (30L, 0L, 30L) :: Nil)
+
+    val res2 = InputOutputMetricsHelper.run(
+      spark.range(150).repartition(4).filter(x => x < 10).toDF()
+    )
+    assert(res2 === (150L, 0L, 150L) :: (0L, 150L, 10L) :: Nil)
+
+    withTempDir { tempDir =>
+      val dir = new File(tempDir, "pqS").getCanonicalPath
+
+      spark.range(10).write.parquet(dir)
+      spark.read.parquet(dir).createOrReplaceTempView("pqS")
+
+      val res3 = InputOutputMetricsHelper.run(
+        spark.range(30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2).toDF()
+      )
+      // The query above is executed in the following stages:
+      // 1. sql("select * from pqS") => (10, 0, 10)
+      // 2. range(30) => (30, 0, 30)
+      // 3. crossJoin(...) of 1. and 2. => (0, 30, 300)
+      // 4. shuffle & return results => (0, 300, 0)
+      assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil)
+    }
+  }
+}
+
+object InputOutputMetricsHelper {
+  private class InputOutputMetricsListener extends SparkListener {
+    private case class MetricsResult(
+      var recordsRead: Long = 0L,
+      var shuffleRecordsRead: Long = 0L,
+      var sumMaxOutputRows: Long = 0L)
+
+    private[this] val stageIdToMetricsResult = HashMap.empty[Int, MetricsResult]
+
+    def reset(): Unit = {
+      stageIdToMetricsResult.clear()
+    }
+
+    /**
+     * Return a list of recorded metrics aggregated per stage.
+     *
+     * The list is sorted in the ascending order on the stageId.
+     * For each recorded stage, the following tuple is returned:
+     *  - sum of inputMetrics.recordsRead for all the tasks in the stage
+     *  - sum of shuffleReadMetrics.recordsRead for all the tasks in the stage
+     *  - sum of the highest values of "number of output rows" metric for all the tasks in the stage
+     */
+    def getResults(): List[(Long, Long, Long)] = {
+      stageIdToMetricsResult.keySet.toList.sorted.map { stageId =>
+        val res = stageIdToMetricsResult(stageId)
+        (res.recordsRead, res.shuffleRecordsRead, res.sumMaxOutputRows)
+      }
+    }
+
+    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
+      val res = stageIdToMetricsResult.getOrElseUpdate(taskEnd.stageId, MetricsResult())
+
+      res.recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
+      res.shuffleRecordsRead += taskEnd.taskMetrics.shuffleReadMetrics.recordsRead
+
+      var maxOutputRows = 0L
+      for (accum <- taskEnd.taskMetrics.externalAccums) {
+        val info = accum.toInfo(Some(accum.value), None)
+        if (info.name.toString.contains("number of output rows")) {
+          info.update match {
+            case Some(n: Number) =>
+              if (n.longValue() > maxOutputRows) {
+                maxOutputRows = n.longValue()
+              }
+            case _ => // Ignore.
+          }
+        }
+      }
+      res.sumMaxOutputRows += maxOutputRows
+    }
+  }
+
+  // Run df.collect() and return aggregated metrics for each stage.
+  def run(df: DataFrame): List[(Long, Long, Long)] = {
+    val spark = df.sparkSession
+    val sparkContext = spark.sparkContext
+    val listener = new InputOutputMetricsListener()
+    sparkContext.addSparkListener(listener)
+
+    try {
+      sparkContext.listenerBus.waitUntilEmpty(5000)
+      listener.reset()
+      df.collect()
+      sparkContext.listenerBus.waitUntilEmpty(5000)
+    } finally {
+      sparkContext.removeSparkListener(listener)
+    }
+    listener.getResults()
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 92d3e9519fa2..5463728ca0c1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
-import org.apache.spark.sql.execution.MetricsTestHelper
+import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -917,13 +917,10 @@ class JDBCSuite extends SparkFunSuite
     assert(e2.contains("User specified schema not supported with `jdbc`"))
   }
 
-  test("Input/generated/output metrics on JDBC") {
+  test("Checking metrics correctness with JDBC") {
     val foobarCnt = spark.table("foobar").count()
-    val res = MetricsTestHelper.runAndGetMetrics(sql("SELECT * FROM foobar").toDF())
-    assert(res.recordsRead === foobarCnt :: Nil)
-    assert(res.shuffleRecordsRead.sum === 0)
-    assert(res.generatedRows.isEmpty)
-    assert(res.outputRows === foobarCnt :: Nil)
+    val res = InputOutputMetricsHelper.run(sql("SELECT * FROM foobar").toDF())
+    assert(res === (foobarCnt, 0L, foobarCnt) :: Nil)
   }
 
   test("SPARK-19318: Connection properties keys should be case-sensitive.") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
index 35c41b531c36..7803ac39e508 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.execution
 
 import org.scalatest.BeforeAndAfterAll
 
-import org.apache.spark.sql.execution.MetricsTestHelper
+import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
 import org.apache.spark.sql.hive.test.TestHive
 
 /**
@@ -49,21 +49,15 @@ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll {
 
   createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM episodes_part")
 
-  test("Test input/generated/output metrics") {
+  test("Checking metrics correctness") {
     import TestHive._
     val episodesCnt = sql("select * from episodes").count()
-    val episodesRes = MetricsTestHelper.runAndGetMetrics(sql("select * from episodes").toDF())
-    assert(episodesRes.recordsRead === episodesCnt :: Nil)
-    assert(episodesRes.shuffleRecordsRead.sum === 0)
-    assert(episodesRes.generatedRows.isEmpty)
-    assert(episodesRes.outputRows === episodesCnt :: Nil)
+    val episodesRes = InputOutputMetricsHelper.run(sql("select * from episodes").toDF())
+    assert(episodesRes === (episodesCnt, 0L, episodesCnt) :: Nil)
 
     val serdeinsCnt = sql("select * from serdeins").count()
-    val serdeinsRes = MetricsTestHelper.runAndGetMetrics(sql("select * from serdeins").toDF())
-    assert(serdeinsRes.recordsRead === serdeinsCnt :: Nil)
-    assert(serdeinsRes.shuffleRecordsRead.sum === 0)
-    assert(serdeinsRes.generatedRows.isEmpty)
-    assert(serdeinsRes.outputRows === serdeinsCnt :: Nil)
+    val serdeinsRes = InputOutputMetricsHelper.run(sql("select * from serdeins").toDF())
+    assert(serdeinsRes === (serdeinsCnt, 0L, serdeinsCnt) :: Nil)
   }
 
 }
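
Note: the snippet below is an illustrative usage sketch and not part of the patch. It shows how the InputOutputMetricsHelper introduced in SQLMetricsSuite.scala above is meant to be called from a test; the suite name RangeMetricsExampleSuite and the chosen range size are hypothetical, while the expected tuple follows the helper's getResults() contract: per stage, (sum of inputMetrics.recordsRead, sum of shuffleReadMetrics.recordsRead, sum of per-task maxima of the "number of output rows" metric).

package org.apache.spark.sql.execution.metric

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical example suite; assumes the same scaffolding as SQLMetricsSuite above and the
// InputOutputMetricsHelper object added by this patch (same package, so no extra import needed).
class RangeMetricsExampleSuite extends SparkFunSuite with SharedSQLContext {

  test("example: range reports records read via InputMetrics") {
    // Single-stage job: range(100) is generated locally with no shuffle, so one
    // (recordsRead, shuffleRecordsRead, sumMaxOutputRows) tuple is expected,
    // analogous to res1 in the "range metrics" test above.
    val res = InputOutputMetricsHelper.run(spark.range(100).toDF())
    assert(res === (100L, 0L, 100L) :: Nil)
  }
}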