
Commit b486ffc

alarxin authored and committed
[SPARK-19447] Make Range operator generate "recordsRead" metric
## What changes were proposed in this pull request?

The Range operator was modified to produce the "recordsRead" metric instead of the "generated rows" metric. The tests were updated and partially moved to SQLMetricsSuite.

## How was this patch tested?

Unit tests.

Author: Ala Luszczak <[email protected]>

Closes #16960 from ala/range-records-read.
1 parent 729ce37 commit b486ffc
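The effect of this change can be observed through a `SparkListener`, which is also how the new tests below verify it. A minimal sketch of that idea, assuming a running `SparkSession` named `spark`; the `RecordsReadListener` helper is made up for illustration, and `listenerBus.waitUntilEmpty` is `private[spark]`, so the snippet is written as if it lived inside Spark's own test tree:

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Hypothetical listener: sums inputMetrics.recordsRead over all finished tasks.
class RecordsReadListener extends SparkListener {
  @volatile var recordsRead = 0L
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
  }
}

val listener = new RecordsReadListener
spark.sparkContext.addSparkListener(listener)
try {
  spark.range(30).filter(x => x % 3 == 0).collect()
  // Listener events are delivered asynchronously; flush them before reading the counter.
  spark.sparkContext.listenerBus.waitUntilEmpty(5000)
} finally {
  spark.sparkContext.removeSparkListener(listener)
}
// With this patch, Range reports its rows as "recordsRead", so this should print 30
// (previously Range only bumped a "number of generated rows" SQL metric and this stayed 0).
println(listener.recordsRead)
```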

6 files changed: +125 −155 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 3 additions & 1 deletion
@@ -31,6 +31,7 @@ import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleComp
 import org.codehaus.janino.util.ClassFile
 
 import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException}
+import org.apache.spark.executor.InputMetrics
 import org.apache.spark.internal.Logging
 import org.apache.spark.metrics.source.CodegenMetrics
 import org.apache.spark.sql.catalyst.InternalRow
@@ -933,7 +934,8 @@ object CodeGenerator extends Logging {
       classOf[UnsafeMapData].getName,
       classOf[Expression].getName,
       classOf[TaskContext].getName,
-      classOf[TaskKilledException].getName
+      classOf[TaskKilledException].getName,
+      classOf[InputMetrics].getName
     ))
     evaluator.setExtendedClass(classOf[GeneratedClass])

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 8 additions & 4 deletions
@@ -365,6 +365,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
 
     val taskContext = ctx.freshName("taskContext")
     ctx.addMutableState("TaskContext", taskContext, s"$taskContext = TaskContext.get();")
+    val inputMetrics = ctx.freshName("inputMetrics")
+    ctx.addMutableState("InputMetrics", inputMetrics,
+      s"$inputMetrics = $taskContext.taskMetrics().inputMetrics();")
 
     // In order to periodically update the metrics without inflicting performance penalty, this
     // operator produces elements in batches. After a batch is complete, the metrics are updated
@@ -460,7 +463,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
       | if ($nextBatchTodo == 0) break;
       | }
       | $numOutput.add($nextBatchTodo);
-      | $numGenerated.add($nextBatchTodo);
+      | $inputMetrics.incRecordsRead($nextBatchTodo);
       |
       | $batchEnd += $nextBatchTodo * ${step}L;
       | }
@@ -469,7 +472,6 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-    val numGeneratedRows = longMetric("numGeneratedRows")
     sqlContext
       .sparkContext
       .parallelize(0 until numSlices, numSlices)
@@ -488,10 +490,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
         val safePartitionEnd = getSafeMargin(partitionEnd)
         val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize
         val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1)
+        val taskContext = TaskContext.get()
 
         val iter = new Iterator[InternalRow] {
           private[this] var number: Long = safePartitionStart
           private[this] var overflow: Boolean = false
+          private[this] val inputMetrics = taskContext.taskMetrics().inputMetrics
 
           override def hasNext =
             if (!overflow) {
@@ -513,12 +517,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
             }
 
             numOutputRows += 1
-            numGeneratedRows += 1
+            inputMetrics.incRecordsRead(1)
             unsafeRow.setLong(0, ret)
             unsafeRow
           }
         }
-        new InterruptibleIterator(TaskContext.get(), iter)
+        new InterruptibleIterator(taskContext, iter)
       }
   }

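The codegen comment kept in the diff above explains why the generated loop adds to the metric once per batch rather than once per row. As a rough illustration of that pattern outside of codegen, here is a small hand-written sketch; `Metric`, `BatchedCountingIterator`, and the batch size are stand-ins for this example, not Spark APIs:

```scala
// Wraps an iterator and reports how many elements were consumed, flushing the
// count to a metric once per batch so the per-element hot path stays cheap.
trait Metric { def add(v: Long): Unit }

class BatchedCountingIterator[T](underlying: Iterator[T], metric: Metric, batchSize: Int = 1000)
  extends Iterator[T] {

  private var pending = 0L

  override def hasNext: Boolean = {
    val more = underlying.hasNext
    if (!more) flush() // flush the tail once the input is exhausted
    more
  }

  override def next(): T = {
    pending += 1
    if (pending >= batchSize) flush() // update the metric once per full batch
    underlying.next()
  }

  private def flush(): Unit = {
    if (pending > 0) { metric.add(pending); pending = 0 }
  }
}
```

Over an input of, say, 10,000 elements with the default batch size, the metric is touched roughly 10 times instead of 10,000, which is the same trade-off the generated Range code makes.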
sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala

Lines changed: 0 additions & 131 deletions
This file was deleted.

sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala

Lines changed: 104 additions & 0 deletions
@@ -17,7 +17,12 @@
 
 package org.apache.spark.sql.execution.metric
 
+import java.io.File
+
+import scala.collection.mutable.HashMap
+
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.execution.SparkPlanInfo
@@ -309,4 +314,103 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
     assert(metricInfoDeser.metadata === Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER))
   }
 
+  test("range metrics") {
+    val res1 = InputOutputMetricsHelper.run(
+      spark.range(30).filter(x => x % 3 == 0).toDF()
+    )
+    assert(res1 === (30L, 0L, 30L) :: Nil)
+
+    val res2 = InputOutputMetricsHelper.run(
+      spark.range(150).repartition(4).filter(x => x < 10).toDF()
+    )
+    assert(res2 === (150L, 0L, 150L) :: (0L, 150L, 10L) :: Nil)
+
+    withTempDir { tempDir =>
+      val dir = new File(tempDir, "pqS").getCanonicalPath
+
+      spark.range(10).write.parquet(dir)
+      spark.read.parquet(dir).createOrReplaceTempView("pqS")
+
+      val res3 = InputOutputMetricsHelper.run(
+        spark.range(30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2).toDF()
+      )
+      // The query above is executed in the following stages:
+      // 1. sql("select * from pqS") => (10, 0, 10)
+      // 2. range(30) => (30, 0, 30)
+      // 3. crossJoin(...) of 1. and 2. => (0, 30, 300)
+      // 4. shuffle & return results => (0, 300, 0)
+      assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil)
+    }
+  }
+}
+
+object InputOutputMetricsHelper {
+  private class InputOutputMetricsListener extends SparkListener {
+    private case class MetricsResult(
+      var recordsRead: Long = 0L,
+      var shuffleRecordsRead: Long = 0L,
+      var sumMaxOutputRows: Long = 0L)
+
+    private[this] val stageIdToMetricsResult = HashMap.empty[Int, MetricsResult]
+
+    def reset(): Unit = {
+      stageIdToMetricsResult.clear()
+    }
+
+    /**
+     * Return a list of recorded metrics aggregated per stage.
+     *
+     * The list is sorted in the ascending order on the stageId.
+     * For each recorded stage, the following tuple is returned:
+     *  - sum of inputMetrics.recordsRead for all the tasks in the stage
+     *  - sum of shuffleReadMetrics.recordsRead for all the tasks in the stage
+     *  - sum of the highest values of "number of output rows" metric for all the tasks in the stage
+     */
+    def getResults(): List[(Long, Long, Long)] = {
+      stageIdToMetricsResult.keySet.toList.sorted.map { stageId =>
+        val res = stageIdToMetricsResult(stageId)
+        (res.recordsRead, res.shuffleRecordsRead, res.sumMaxOutputRows)
+      }
+    }
+
+    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
+      val res = stageIdToMetricsResult.getOrElseUpdate(taskEnd.stageId, MetricsResult())
+
+      res.recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
+      res.shuffleRecordsRead += taskEnd.taskMetrics.shuffleReadMetrics.recordsRead
+
+      var maxOutputRows = 0L
+      for (accum <- taskEnd.taskMetrics.externalAccums) {
+        val info = accum.toInfo(Some(accum.value), None)
+        if (info.name.toString.contains("number of output rows")) {
+          info.update match {
+            case Some(n: Number) =>
+              if (n.longValue() > maxOutputRows) {
+                maxOutputRows = n.longValue()
+              }
+            case _ => // Ignore.
+          }
+        }
+      }
+      res.sumMaxOutputRows += maxOutputRows
+    }
+  }
+
+  // Run df.collect() and return aggregated metrics for each stage.
+  def run(df: DataFrame): List[(Long, Long, Long)] = {
+    val spark = df.sparkSession
+    val sparkContext = spark.sparkContext
+    val listener = new InputOutputMetricsListener()
+    sparkContext.addSparkListener(listener)
+
+    try {
+      sparkContext.listenerBus.waitUntilEmpty(5000)
+      listener.reset()
+      df.collect()
+      sparkContext.listenerBus.waitUntilEmpty(5000)
+    } finally {
+      sparkContext.removeSparkListener(listener)
+    }
+    listener.getResults()
+  }
 }

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala

Lines changed: 4 additions & 7 deletions
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
-import org.apache.spark.sql.execution.MetricsTestHelper
+import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -917,13 +917,10 @@ class JDBCSuite extends SparkFunSuite
     assert(e2.contains("User specified schema not supported with `jdbc`"))
   }
 
-  test("Input/generated/output metrics on JDBC") {
+  test("Checking metrics correctness with JDBC") {
     val foobarCnt = spark.table("foobar").count()
-    val res = MetricsTestHelper.runAndGetMetrics(sql("SELECT * FROM foobar").toDF())
-    assert(res.recordsRead === foobarCnt :: Nil)
-    assert(res.shuffleRecordsRead.sum === 0)
-    assert(res.generatedRows.isEmpty)
-    assert(res.outputRows === foobarCnt :: Nil)
+    val res = InputOutputMetricsHelper.run(sql("SELECT * FROM foobar").toDF())
+    assert(res === (foobarCnt, 0L, foobarCnt) :: Nil)
   }
 
   test("SPARK-19318: Connection properties keys should be case-sensitive.") {

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala

Lines changed: 6 additions & 12 deletions
@@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.execution
 
 import org.scalatest.BeforeAndAfterAll
 
-import org.apache.spark.sql.execution.MetricsTestHelper
+import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
 import org.apache.spark.sql.hive.test.TestHive
 
 /**
@@ -49,21 +49,15 @@ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll {
 
   createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM episodes_part")
 
-  test("Test input/generated/output metrics") {
+  test("Checking metrics correctness") {
     import TestHive._
 
     val episodesCnt = sql("select * from episodes").count()
-    val episodesRes = MetricsTestHelper.runAndGetMetrics(sql("select * from episodes").toDF())
-    assert(episodesRes.recordsRead === episodesCnt :: Nil)
-    assert(episodesRes.shuffleRecordsRead.sum === 0)
-    assert(episodesRes.generatedRows.isEmpty)
-    assert(episodesRes.outputRows === episodesCnt :: Nil)
+    val episodesRes = InputOutputMetricsHelper.run(sql("select * from episodes").toDF())
+    assert(episodesRes === (episodesCnt, 0L, episodesCnt) :: Nil)
 
     val serdeinsCnt = sql("select * from serdeins").count()
-    val serdeinsRes = MetricsTestHelper.runAndGetMetrics(sql("select * from serdeins").toDF())
-    assert(serdeinsRes.recordsRead === serdeinsCnt :: Nil)
-    assert(serdeinsRes.shuffleRecordsRead.sum === 0)
-    assert(serdeinsRes.generatedRows.isEmpty)
-    assert(serdeinsRes.outputRows === serdeinsCnt :: Nil)
+    val serdeinsRes = InputOutputMetricsHelper.run(sql("select * from serdeins").toDF())
+    assert(serdeinsRes === (serdeinsCnt, 0L, serdeinsCnt) :: Nil)
   }
 }
