@@ -31,6 +31,7 @@ import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleCompiler}
import org.codehaus.janino.util.ClassFile

import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException}
import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.metrics.source.CodegenMetrics
import org.apache.spark.sql.catalyst.InternalRow
@@ -933,7 +934,8 @@ object CodeGenerator extends Logging {
classOf[UnsafeMapData].getName,
classOf[Expression].getName,
classOf[TaskContext].getName,
classOf[TaskKilledException].getName
classOf[TaskKilledException].getName,
classOf[InputMetrics].getName
))
evaluator.setExtendedClass(classOf[GeneratedClass])
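
For context, every class name in this list becomes a default import of the Janino evaluator, which is what lets generated code refer to it by its simple name; that is the only reason InputMetrics needs to be registered here. A minimal sketch of that mechanism, assuming Janino's ClassBodyEvaluator API and a hypothetical generated class body (this is not Spark's actual compiler wiring):

```scala
import org.codehaus.janino.ClassBodyEvaluator

// Sketch: a simple name like `InputMetrics` resolves inside generated code only
// because the class was registered as a default import of the evaluator.
val evaluator = new ClassBodyEvaluator()
evaluator.setDefaultImports(Array(classOf[org.apache.spark.executor.InputMetrics].getName))
evaluator.cook(
  """public Object generate() {
    |  InputMetrics metrics = null;  // compiles thanks to the default import
    |  return metrics;
    |}""".stripMargin)
```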

@@ -365,6 +365,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)

val taskContext = ctx.freshName("taskContext")
ctx.addMutableState("TaskContext", taskContext, s"$taskContext = TaskContext.get();")
val inputMetrics = ctx.freshName("inputMetrics")
ctx.addMutableState("InputMetrics", inputMetrics,
s"$inputMetrics = $taskContext.taskMetrics().inputMetrics();")

// In order to periodically update the metrics without inflicting performance penalty, this
// operator produces elements in batches. After a batch is complete, the metrics are updated
@@ -460,7 +463,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| if ($nextBatchTodo == 0) break;
| }
| $numOutput.add($nextBatchTodo);
| $numGenerated.add($nextBatchTodo);
| $inputMetrics.incRecordsRead($nextBatchTodo);
|
| $batchEnd += $nextBatchTodo * ${step}L;
| }
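
The truncated comment above describes the batching trick: rows are produced in batches so the counters, now including InputMetrics, are updated once per completed batch instead of once per row. A hand-written sketch of the same pattern (illustrative only, not the generated code; the batch size and numElements are placeholders, and it assumes it runs inside a Spark task so TaskContext.get() is non-null):

```scala
// Placed under an org.apache.spark package because InputMetrics.incRecordsRead
// is private[spark]; the generated Java code is not subject to that restriction.
package org.apache.spark.sql.execution.sketch

import org.apache.spark.TaskContext

object RangeBatchingSketch {
  def produce(numElements: Long, batchSize: Long = 1000L): Unit = {
    val inputMetrics = TaskContext.get().taskMetrics().inputMetrics
    var produced = 0L
    while (produced < numElements) {
      val todo = math.min(batchSize, numElements - produced)
      var i = 0L
      while (i < todo) {
        // emit one row of the range here
        i += 1
      }
      // metrics are updated once per completed batch, not once per row
      inputMetrics.incRecordsRead(todo)
      produced += todo
    }
  }
}
```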
@@ -469,7 +472,6 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)

protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
val numGeneratedRows = longMetric("numGeneratedRows")
sqlContext
.sparkContext
.parallelize(0 until numSlices, numSlices)
@@ -488,10 +490,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
val safePartitionEnd = getSafeMargin(partitionEnd)
val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize
val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1)
val taskContext = TaskContext.get()

val iter = new Iterator[InternalRow] {
private[this] var number: Long = safePartitionStart
private[this] var overflow: Boolean = false
private[this] val inputMetrics = taskContext.taskMetrics().inputMetrics

override def hasNext =
if (!overflow) {
@@ -513,12 +517,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
}

numOutputRows += 1
numGeneratedRows += 1
inputMetrics.incRecordsRead(1)
unsafeRow.setLong(0, ret)
unsafeRow
}
}
new InterruptibleIterator(TaskContext.get(), iter)
new InterruptibleIterator(taskContext, iter)
}
}


This file was deleted.

@@ -17,7 +17,12 @@

package org.apache.spark.sql.execution.metric

import java.io.File

import scala.collection.mutable.HashMap

import org.apache.spark.SparkFunSuite
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.execution.SparkPlanInfo
@@ -309,4 +314,103 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
assert(metricInfoDeser.metadata === Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER))
}

test("range metrics") {
val res1 = InputOutputMetricsHelper.run(
spark.range(30).filter(x => x % 3 == 0).toDF()
)
assert(res1 === (30L, 0L, 30L) :: Nil)

val res2 = InputOutputMetricsHelper.run(
spark.range(150).repartition(4).filter(x => x < 10).toDF()
)
assert(res2 === (150L, 0L, 150L) :: (0L, 150L, 10L) :: Nil)

withTempDir { tempDir =>
val dir = new File(tempDir, "pqS").getCanonicalPath

spark.range(10).write.parquet(dir)
spark.read.parquet(dir).createOrReplaceTempView("pqS")

val res3 = InputOutputMetricsHelper.run(
spark.range(30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2).toDF()
)
// The query above is executed in the following stages:
// 1. sql("select * from pqS") => (10, 0, 10)
// 2. range(30) => (30, 0, 30)
// 3. crossJoin(...) of 1. and 2. => (0, 30, 300), i.e. 30 * 10 = 300 output rows
// 4. shuffle & return results => (0, 300, 0)
assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil)
}
}
}

object InputOutputMetricsHelper {
private class InputOutputMetricsListener extends SparkListener {
private case class MetricsResult(
var recordsRead: Long = 0L,
var shuffleRecordsRead: Long = 0L,
var sumMaxOutputRows: Long = 0L)

private[this] val stageIdToMetricsResult = HashMap.empty[Int, MetricsResult]

def reset(): Unit = {
stageIdToMetricsResult.clear()
}

/**
* Return a list of recorded metrics aggregated per stage.
*
* The list is sorted in ascending order by stageId.
* For each recorded stage, the following tuple is returned:
* - sum of inputMetrics.recordsRead for all the tasks in the stage
* - sum of shuffleReadMetrics.recordsRead for all the tasks in the stage
* - sum of the highest values of "number of output rows" metric for all the tasks in the stage
*/
def getResults(): List[(Long, Long, Long)] = {
stageIdToMetricsResult.keySet.toList.sorted.map { stageId =>
val res = stageIdToMetricsResult(stageId)
(res.recordsRead, res.shuffleRecordsRead, res.sumMaxOutputRows)
}
}

override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
val res = stageIdToMetricsResult.getOrElseUpdate(taskEnd.stageId, MetricsResult())

res.recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
res.shuffleRecordsRead += taskEnd.taskMetrics.shuffleReadMetrics.recordsRead

var maxOutputRows = 0L
for (accum <- taskEnd.taskMetrics.externalAccums) {
val info = accum.toInfo(Some(accum.value), None)
if (info.name.toString.contains("number of output rows")) {
info.update match {
case Some(n: Number) =>
if (n.longValue() > maxOutputRows) {
maxOutputRows = n.longValue()
}
case _ => // Ignore.
}
}
}
res.sumMaxOutputRows += maxOutputRows
}
}

// Run df.collect() and return aggregated metrics for each stage.
def run(df: DataFrame): List[(Long, Long, Long)] = {
val spark = df.sparkSession
val sparkContext = spark.sparkContext
val listener = new InputOutputMetricsListener()
[Review comment from a Contributor]: Use try...finally here. (Addressed by the try/finally block below.)

sparkContext.addSparkListener(listener)

try {
sparkContext.listenerBus.waitUntilEmpty(5000)
listener.reset()
df.collect()
sparkContext.listenerBus.waitUntilEmpty(5000)
} finally {
sparkContext.removeSparkListener(listener)
}
listener.getResults()
}
}
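
All three suites touched by this change call the helper the same way; for reference, a standalone usage would look like this (assuming an active SparkSession named spark and an illustrative query):

```scala
// Illustrative usage: each tuple is (recordsRead, shuffleRecordsRead,
// sumMaxOutputRows) for one stage, listed in ascending stageId order.
val perStage: List[(Long, Long, Long)] =
  InputOutputMetricsHelper.run(spark.range(100).filter(x => x % 2 == 0).toDF())
perStage.foreach { case (read, shuffleRead, maxOutputRows) =>
  println(s"read=$read shuffleRead=$shuffleRead maxOutputRows=$maxOutputRows")
}
```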
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
import org.apache.spark.sql.execution.MetricsTestHelper
import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -917,13 +917,10 @@ class JDBCSuite extends SparkFunSuite
assert(e2.contains("User specified schema not supported with `jdbc`"))
}

test("Input/generated/output metrics on JDBC") {
test("Checking metrics correctness with JDBC") {
val foobarCnt = spark.table("foobar").count()
val res = MetricsTestHelper.runAndGetMetrics(sql("SELECT * FROM foobar").toDF())
assert(res.recordsRead === foobarCnt :: Nil)
assert(res.shuffleRecordsRead.sum === 0)
assert(res.generatedRows.isEmpty)
assert(res.outputRows === foobarCnt :: Nil)
val res = InputOutputMetricsHelper.run(sql("SELECT * FROM foobar").toDF())
assert(res === (foobarCnt, 0L, foobarCnt) :: Nil)
}

test("SPARK-19318: Connection properties keys should be case-sensitive.") {
@@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.execution

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql.execution.MetricsTestHelper
import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
import org.apache.spark.sql.hive.test.TestHive

/**
@@ -49,21 +49,15 @@ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll {

createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM episodes_part")

test("Test input/generated/output metrics") {
test("Checking metrics correctness") {
import TestHive._

val episodesCnt = sql("select * from episodes").count()
val episodesRes = MetricsTestHelper.runAndGetMetrics(sql("select * from episodes").toDF())
assert(episodesRes.recordsRead === episodesCnt :: Nil)
assert(episodesRes.shuffleRecordsRead.sum === 0)
assert(episodesRes.generatedRows.isEmpty)
assert(episodesRes.outputRows === episodesCnt :: Nil)
val episodesRes = InputOutputMetricsHelper.run(sql("select * from episodes").toDF())
assert(episodesRes === (episodesCnt, 0L, episodesCnt) :: Nil)

val serdeinsCnt = sql("select * from serdeins").count()
val serdeinsRes = MetricsTestHelper.runAndGetMetrics(sql("select * from serdeins").toDF())
assert(serdeinsRes.recordsRead === serdeinsCnt :: Nil)
assert(serdeinsRes.shuffleRecordsRead.sum === 0)
assert(serdeinsRes.generatedRows.isEmpty)
assert(serdeinsRes.outputRows === serdeinsCnt :: Nil)
val serdeinsRes = InputOutputMetricsHelper.run(sql("select * from serdeins").toDF())
assert(serdeinsRes === (serdeinsCnt, 0L, serdeinsCnt) :: Nil)
}
}