[SPARK-19447] Make Range operator generate "recordsRead" metric #16960
Changes from 2 commits
@@ -17,7 +17,12 @@

package org.apache.spark.sql.execution.metric

import java.io.File

import scala.collection.mutable.HashMap

import org.apache.spark.SparkFunSuite
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.execution.SparkPlanInfo
@@ -309,4 +314,94 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
    assert(metricInfoDeser.metadata === Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER))
  }

  test("range metrics") {
    val res1 = InputOutputMetricsHelper.run(
      spark.range(30).filter(x => x % 3 == 0).toDF()
    )
    assert(res1 === (30L, 0L, 30L) :: Nil)

    val res2 = InputOutputMetricsHelper.run(
      spark.range(150).repartition(4).filter(x => x < 10).toDF()
    )
    assert(res2 === (150L, 0L, 150L) :: (0L, 150L, 10L) :: Nil)

    withTempDir { tempDir =>
      val dir = new File(tempDir, "pqS").getCanonicalPath

      spark.range(10).write.parquet(dir)
      spark.read.parquet(dir).createOrReplaceTempView("pqS")

      val res3 = InputOutputMetricsHelper.run(
        spark.range(0, 30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2).toDF()
      )
      assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil)
    }
  }
}
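A note on reading these assertions (an editor's aside, not part of the diff): each tuple is the per-stage (recordsRead, shuffleRecordsRead, sumMaxOutputRows) triple produced by InputOutputMetricsHelper.getResults() below. For res2, one way to read the expected values is that the stage running the Range source reads 150 records and reports 150 output rows, giving (150L, 0L, 150L), while the post-repartition stage reads nothing from the source, shuffle-reads those 150 records, and the filter leaves 10 output rows, giving (0L, 150L, 10L). The first field is the "recordsRead" metric this PR makes the Range operator populate.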
object InputOutputMetricsHelper {
  private class InputOutputMetricsListener extends SparkListener {
    private case class MetricsResult(
Contributor
Nit: add space
        var recordsRead: Long = 0L,
        var shuffleRecordsRead: Long = 0L,
        var sumMaxOutputRows: Long = 0L)

    private[this] var stageIdToMetricsResult = HashMap.empty[Int, MetricsResult]

    def reset(): Unit = {
      stageIdToMetricsResult = HashMap.empty[Int, MetricsResult]
    }

    /**
     * Return a list of recorded metrics aggregated per stage.
     *
     * The list is sorted in the ascending order on the stageId.
     * For each recorded stage, the following tuple is returned:
     *  - sum of inputMetrics.recordsRead for all the tasks in the stage
     *  - sum of shuffleReadMetrics.recordsRead for all the tasks in the stage
     *  - sum of the highest values of "number of output rows" metric for all the tasks in the stage
     */
    def getResults(): List[(Long, Long, Long)] = {
Contributor
here too (Long, Long, Long)
      stageIdToMetricsResult.keySet.toList.sorted.map({ stageId =>
        val res = stageIdToMetricsResult(stageId)
        (res.recordsRead, res.shuffleRecordsRead, res.sumMaxOutputRows)})
    }

    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
      val res = stageIdToMetricsResult.getOrElseUpdate(taskEnd.stageId, { MetricsResult() })

      res.recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
      res.shuffleRecordsRead += taskEnd.taskMetrics.shuffleReadMetrics.recordsRead

      var maxOutputRows = 0L
      for (accum <- taskEnd.taskMetrics.externalAccums) {
        val info = accum.toInfo(Some(accum.value), None)
        if (info.name.toString.contains("number of output rows")) {
          info.update match {
            case Some(n: Number) =>
              if (n.longValue() > maxOutputRows) {
                maxOutputRows = n.longValue()
              }
            case _ => // Ignore.
          }
        }
      }
      res.sumMaxOutputRows += maxOutputRows
    }
  }

  // Run df.collect() and return aggregated metrics for each stage.
  def run(df: DataFrame): List[(Long, Long, Long)] = {
Contributor
document what the (Long, Long, Long) are for?
    val spark = df.sparkSession
    val sparkContext = spark.sparkContext
    val listener = new InputOutputMetricsListener()
Contributor
Use try...finally here
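One possible shape for this suggestion (an editor's sketch built from the code below, not a change taken from the PR):

sparkContext.addSparkListener(listener)
try {
  sparkContext.listenerBus.waitUntilEmpty(5000)
  listener.reset()
  df.collect()
  sparkContext.listenerBus.waitUntilEmpty(5000)
} finally {
  // Remove the listener even if collect() throws, so it cannot leak into later tests.
  sparkContext.removeSparkListener(listener)
}
listener.getResults()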
    sparkContext.addSparkListener(listener)

    sparkContext.listenerBus.waitUntilEmpty(5000)
    listener.reset()
    df.collect()
    sparkContext.listenerBus.waitUntilEmpty(5000)
    sparkContext.removeSparkListener(listener)
    listener.getResults()
  }
}
This is hard to reason about. Could you add a few lines of documentation?
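As a rough illustration of the documentation being asked for here (an editor's sketch derived only from the helper above, not wording from the PR):

// InputOutputMetricsHelper.run(df) collects `df` and returns one (Long, Long, Long)
// tuple per stage, in ascending stageId order:
//   _1 = sum of inputMetrics.recordsRead over the stage's tasks
//   _2 = sum of shuffleReadMetrics.recordsRead over the stage's tasks
//   _3 = sum over tasks of each task's largest "number of output rows" accumulator value
//
// For example, spark.range(30).filter(x => x % 3 == 0) runs as a single stage that
// reads 30 records from the Range source, so the helper returns (30L, 0L, 30L) :: Nil.
val res = InputOutputMetricsHelper.run(spark.range(30).filter(x => x % 3 == 0).toDF())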