Commit 4dca831

[EXT][SPARK-24648][SQL] SQLMetrics should be thread-safe (apache#21634)

Use LongAdder to make SQLMetrics thread-safe.
1 parent 2b7fc7a commit 4dca831
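
The change swaps the plain var backing SQLMetric for java.util.concurrent.atomic.LongAdder. As a standalone illustration (not part of this commit; the object name LongAdderSketch and all counts are made up), the following Scala sketch shows the race the patch removes: concurrent += 1 on an unsynchronized var can drop updates, while LongAdder.add does not.

import java.util.concurrent.atomic.LongAdder

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

// Standalone sketch, not from the patch: a racy `var` counter versus a LongAdder
// under the same concurrent workload.
object LongAdderSketch {
  def main(args: Array[String]): Unit = {
    implicit val ec: ExecutionContext = ExecutionContext.global

    var plainCounter = 0L        // unsynchronized read-modify-write; increments can be lost
    val adder = new LongAdder    // thread-safe; per-thread cells are summed on read

    val tasks = (1 to 100).map { _ =>
      Future {
        for (_ <- 1 to 10000) {
          plainCounter += 1      // racy increment
          adder.add(1L)          // safe increment
        }
      }
    }
    Await.result(Future.sequence(tasks), 1.minute)

    println(s"plain var: $plainCounter (frequently below 1000000)")
    println(s"LongAdder: ${adder.sum()} (always 1000000)")
  }
}

LongAdder keeps per-thread cells and only folds them together when sum() is called, which makes it a good fit for a metric that is incremented by many tasks but read rarely.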

File tree

2 files changed: +56 -14 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala

Lines changed: 20 additions & 13 deletions

@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.metric
 
 import java.text.NumberFormat
 import java.util.Locale
+import java.util.concurrent.atomic.LongAdder
 
 import org.apache.spark.SparkContext
 import org.apache.spark.scheduler.AccumulableInfo
@@ -32,40 +33,45 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}
  * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]].
  */
 class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] {
+
   // This is a workaround for SPARK-11013.
   // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will
   // update it at the end of task and the value will be at least 0. Then we can filter out the -1
   // values before calculate max, min, etc.
-  private[this] var _value = initValue
-  private var _zeroValue = initValue
+  private[this] val _value = new LongAdder
+  private val _zeroValue = initValue
+  _value.add(initValue)
 
   override def copy(): SQLMetric = {
-    val newAcc = new SQLMetric(metricType, _value)
-    newAcc._zeroValue = initValue
+    val newAcc = new SQLMetric(metricType, initValue)
+    newAcc.add(_value.sum())
     newAcc
   }
 
-  override def reset(): Unit = _value = _zeroValue
+  override def reset(): Unit = this.set(_zeroValue)
 
   override def merge(other: AccumulatorV2[Long, Long]): Unit = other match {
-    case o: SQLMetric => _value += o.value
+    case o: SQLMetric => _value.add(o.value)
     case _ => throw new UnsupportedOperationException(
       s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
   }
 
-  override def isZero(): Boolean = _value == _zeroValue
+  override def isZero(): Boolean = _value.sum() == _zeroValue
 
-  override def add(v: Long): Unit = _value += v
+  override def add(v: Long): Unit = _value.add(v)
 
   // We can set a double value to `SQLMetric` which stores only long value, if it is
   // average metrics.
   def set(v: Double): Unit = SQLMetrics.setDoubleForAverageMetrics(this, v)
 
-  def set(v: Long): Unit = _value = v
+  def set(v: Long): Unit = {
+    _value.reset()
+    _value.add(v)
+  }
 
-  def +=(v: Long): Unit = _value += v
+  def +=(v: Long): Unit = _value.add(v)
 
-  override def value: Long = _value
+  override def value: Long = _value.sum()
 
   // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
   override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
@@ -153,7 +159,7 @@ object SQLMetrics {
         Seq.fill(3)(0L)
       } else {
         val sorted = validValues.sorted
-        Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+        Seq(sorted.head, sorted(validValues.length / 2), sorted(validValues.length - 1))
       }
       metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric))
     }
@@ -173,7 +179,8 @@ object SQLMetrics {
         Seq.fill(4)(0L)
       } else {
         val sorted = validValues.sorted
-        Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+        Seq(sorted.sum, sorted.head, sorted(validValues.length / 2),
+          sorted(validValues.length - 1))
      }
      metric.map(strFormat)
    }
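
Not part of the commit: a short driver-side usage sketch, assuming the patched SQLMetric above is compiled and on the classpath, walking through the rewritten copy/merge/reset paths with the default initValue of 0. The metric type "size" and the byte counts are arbitrary.

import org.apache.spark.sql.execution.metric.SQLMetric

// Hypothetical walkthrough of the LongAdder-backed SQLMetric; values are arbitrary.
object SQLMetricUsageSketch {
  def main(args: Array[String]): Unit = {
    val bytesRead = new SQLMetric("size")   // initValue defaults to 0L
    bytesRead.add(4096L)

    val snapshot = bytesRead.copy()         // copy() seeds a new metric, then adds the current sum
    assert(snapshot.value == 4096L)

    val other = new SQLMetric("size")
    other.add(1024L)
    bytesRead.merge(other)                  // merge() delegates to LongAdder.add(other.value)
    assert(bytesRead.value == 5120L)

    bytesRead.reset()                       // reset() is set(_zeroValue): LongAdder reset + add
    assert(bytesRead.isZero())
  }
}

One consequence of the LongAdder backing is that set() is a reset() followed by an add(), two steps rather than one atomic update, so adds racing with a concurrent set() could be lost; the thread-safety gain is primarily for the add-only path exercised by tasks.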

sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala

Lines changed: 36 additions & 1 deletion

@@ -19,12 +19,12 @@ package org.apache.spark.sql.execution.metric
 
 import java.io.File
 
+import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future}
 import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.execution.ui.SQLAppStatusStore
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -413,6 +413,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
     testMetricsDynamicPartition("parquet", "parquet", "t1")
   }
 
+<<<<<<< HEAD
   test("SPARK-26327: FileSourceScanExec metrics") {
     withTable("testDataForScan") {
       spark.range(10).selectExpr("id", "id % 3 as p")
@@ -427,4 +428,38 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
       )
     }
   }
+
+  test("writing metrics from single thread") {
+    val nAdds = 10
+    val acc = new SQLMetric("test", -10)
+    assert(acc.isZero())
+    acc.set(0)
+    for (i <- 1 to nAdds) acc.add(1)
+    assert(!acc.isZero())
+    assert(nAdds === acc.value)
+    acc.reset()
+    assert(acc.isZero())
+  }
+
+  test("writing metrics from multiple threads") {
+    implicit val ec: ExecutionContextExecutor = ExecutionContext.global
+    val nFutures = 1000
+    val nAdds = 100
+    val acc = new SQLMetric("test", -10)
+    assert(acc.isZero() === true)
+    acc.set(0)
+    val l = for (i <- 1 to nFutures) yield {
+      Future {
+        for (j <- 1 to nAdds) acc.add(1)
+        i
+      }
+    }
+    for (futures <- Future.sequence(l)) {
+      assert(nFutures === futures.length)
+      assert(!acc.isZero())
+      assert(nFutures * nAdds === acc.value)
+      acc.reset()
+      assert(acc.isZero())
+    }
+  }
 }
