@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.metric
1919
2020import java .text .NumberFormat
2121import java .util .Locale
22+ import java .util .concurrent .atomic .LongAdder
2223
2324import org .apache .spark .SparkContext
2425import org .apache .spark .scheduler .AccumulableInfo
@@ -32,40 +33,45 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}
3233 * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates() ]].
3334 */
3435class SQLMetric (val metricType : String , initValue : Long = 0L ) extends AccumulatorV2 [Long , Long ] {
36+
3537 // This is a workaround for SPARK-11013.
3638 // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will
3739 // update it at the end of task and the value will be at least 0. Then we can filter out the -1
3840 // values before calculate max, min, etc.
39- private [this ] var _value = initValue
40- private var _zeroValue = initValue
41+ private [this ] val _value = new LongAdder
42+ private val _zeroValue = initValue
43+ _value.add(initValue)
4144
4245 override def copy (): SQLMetric = {
43- val newAcc = new SQLMetric (metricType, _value )
44- newAcc._zeroValue = initValue
46+ val newAcc = new SQLMetric (metricType, initValue )
47+ newAcc.add(_value.sum())
4548 newAcc
4649 }
4750
48- override def reset (): Unit = _value = _zeroValue
51+ override def reset (): Unit = this .set( _zeroValue)
4952
5053 override def merge (other : AccumulatorV2 [Long , Long ]): Unit = other match {
51- case o : SQLMetric => _value += o.value
54+ case o : SQLMetric => _value.add( o.value)
5255 case _ => throw new UnsupportedOperationException (
5356 s " Cannot merge ${this .getClass.getName} with ${other.getClass.getName}" )
5457 }
5558
56- override def isZero (): Boolean = _value == _zeroValue
59+ override def isZero (): Boolean = _value.sum() == _zeroValue
5760
58- override def add (v : Long ): Unit = _value += v
61+ override def add (v : Long ): Unit = _value.add(v)
5962
6063 // We can set a double value to `SQLMetric` which stores only long value, if it is
6164 // average metrics.
6265 def set (v : Double ): Unit = SQLMetrics .setDoubleForAverageMetrics(this , v)
6366
64- def set (v : Long ): Unit = _value = v
67+ def set (v : Long ): Unit = {
68+ _value.reset()
69+ _value.add(v)
70+ }
6571
66- def += (v : Long ): Unit = _value += v
72+ def += (v : Long ): Unit = _value.add(v)
6773
68- override def value : Long = _value
74+ override def value : Long = _value.sum()
6975
7076 // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
7177 override def toInfo (update : Option [Any ], value : Option [Any ]): AccumulableInfo = {
@@ -153,7 +159,7 @@ object SQLMetrics {
153159 Seq .fill(3 )(0L )
154160 } else {
155161 val sorted = validValues.sorted
156- Seq (sorted( 0 ) , sorted(validValues.length / 2 ), sorted(validValues.length - 1 ))
162+ Seq (sorted.head , sorted(validValues.length / 2 ), sorted(validValues.length - 1 ))
157163 }
158164 metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric))
159165 }
@@ -173,7 +179,8 @@ object SQLMetrics {
173179 Seq .fill(4 )(0L )
174180 } else {
175181 val sorted = validValues.sorted
176- Seq (sorted.sum, sorted(0 ), sorted(validValues.length / 2 ), sorted(validValues.length - 1 ))
182+ Seq (sorted.sum, sorted.head, sorted(validValues.length / 2 ),
183+ sorted(validValues.length - 1 ))
177184 }
178185 metric.map(strFormat)
179186 }
0 commit comments