
Commit 252417f

HyukjinKwon authored and srowen committed
[SPARK-15322][SQL][FOLLOWUP] Use the new long accumulator for old int accumulators.
## What changes were proposed in this pull request?

This PR corrects the remaining cases that still use the old accumulators. It does not change the old accumulator usages below:

- `ImplicitSuite.scala` - tests dedicated to the old accumulator, for implicits with `AccumulatorParam`.
- `AccumulatorSuite.scala` - tests dedicated to the old accumulator.
- `JavaSparkContext.scala` - keeps support for old accumulators in the Java API.
- `debug.package.scala` - usage with `HashSet[String]`. Currently there seems to be no new-accumulator implementation for this. I could write an anonymous class for it, but I didn't because I don't think it is worth the extra code just for this case.
- `SQLMetricsSuite.scala` - uses the old accumulator for checking type boxing. The new accumulator does not seem to require boxing for this case, whereas the old one does (due to its use of generics).

## How was this patch tested?

Existing tests cover this.

Author: hyukjinkwon <[email protected]>

Closes #13434 from HyukjinKwon/accum.
1 parent b85d18f commit 252417f

7 files changed: +22 −23 lines
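Every file below applies the same migration: the deprecated `sc.accumulator(0)` / `+=` pattern is replaced with the typed `sc.longAccumulator` / `add` API from SPARK-15322. A minimal sketch of the pattern, assuming a local `SparkContext`; the variable names and the `"sum"` label are illustrative, not taken from the diff:

```scala
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("accum-demo"))

// Old API (deprecated in 2.0): a generic Accumulator[Int] backed by AccumulatorParam,
// updated with `+=`; the generic signature boxes each update.
val oldAcc = sc.accumulator(0)
sc.parallelize(1 to 10).foreach(x => oldAcc += x)
assert(oldAcc.value == 55)    // Int

// New API: a dedicated LongAccumulator, updated with `add`; no generic boxing,
// and the optional name is shown in the web UI.
val newAcc = sc.longAccumulator("sum")
sc.parallelize(1 to 10).foreach(x => newAcc.add(x))
assert(newAcc.value == 55L)   // java.lang.Long
```

The switch from `Int` to `java.lang.Long` values is also why the REPL tests below now expect `res1: Long = 55` instead of `res1: Int = 55`.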

core/src/test/scala/org/apache/spark/DistributedSuite.scala

Lines changed: 2 additions & 3 deletions

@@ -92,8 +92,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
   test("accumulators") {
     sc = new SparkContext(clusterUrl, "test")
-    val accum = sc.accumulator(0)
-    sc.parallelize(1 to 10, 10).foreach(x => accum += x)
+    val accum = sc.longAccumulator
+    sc.parallelize(1 to 10, 10).foreach(x => accum.add(x))
     assert(accum.value === 55)
   }

@@ -109,7 +109,6 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
   test("repeatedly failing task") {
     sc = new SparkContext(clusterUrl, "test")
-    val accum = sc.accumulator(0)
     val thrown = intercept[SparkException] {
       // scalastyle:off println
       sc.parallelize(1 to 10, 10).foreach(x => println(x / 0))

repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala

Lines changed: 3 additions & 3 deletions

@@ -107,13 +107,13 @@ class ReplSuite extends SparkFunSuite {
   test("simple foreach with accumulator") {
     val output = runInterpreter("local",
       """
-        |val accum = sc.accumulator(0)
-        |sc.parallelize(1 to 10).foreach(x => accum += x)
+        |val accum = sc.longAccumulator
+        |sc.parallelize(1 to 10).foreach(x => accum.add(x))
         |accum.value
       """.stripMargin)
     assertDoesNotContain("error:", output)
     assertDoesNotContain("Exception", output)
-    assertContains("res1: Int = 55", output)
+    assertContains("res1: Long = 55", output)
   }

   test("external vars") {

repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala

Lines changed: 3 additions & 3 deletions

@@ -150,13 +150,13 @@ class ReplSuite extends SparkFunSuite {
   test("simple foreach with accumulator") {
     val output = runInterpreter("local",
       """
-        |val accum = sc.accumulator(0)
-        |sc.parallelize(1 to 10).foreach(x => accum += x)
+        |val accum = sc.longAccumulator
+        |sc.parallelize(1 to 10).foreach(x => accum.add(x))
         |accum.value
       """.stripMargin)
     assertDoesNotContain("error:", output)
     assertDoesNotContain("Exception", output)
-    assertContains("res1: Int = 55", output)
+    assertContains("res1: Long = 55", output)
   }

   test("external vars") {

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala

Lines changed: 5 additions & 6 deletions

@@ -21,7 +21,6 @@ import scala.collection.JavaConverters._

 import org.apache.commons.lang.StringUtils

-import org.apache.spark.Accumulator
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow

@@ -36,7 +35,7 @@ import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.UserDefinedType
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{AccumulatorContext, ListAccumulator}
+import org.apache.spark.util.{AccumulatorContext, ListAccumulator, LongAccumulator}


 private[sql] object InMemoryRelation {

@@ -294,8 +293,8 @@ private[sql] case class InMemoryTableScanExec(
     sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean

   // Accumulators used for testing purposes
-  lazy val readPartitions: Accumulator[Int] = sparkContext.accumulator(0)
-  lazy val readBatches: Accumulator[Int] = sparkContext.accumulator(0)
+  lazy val readPartitions = sparkContext.longAccumulator
+  lazy val readBatches = sparkContext.longAccumulator

   private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning

@@ -339,7 +338,7 @@ private[sql] case class InMemoryTableScanExec(
           false
         } else {
           if (enableAccumulators) {
-            readBatches += 1
+            readBatches.add(1)
           }
           true
         }

@@ -361,7 +360,7 @@ private[sql] case class InMemoryTableScanExec(
       val columnarIterator = GenerateColumnAccessor.generate(columnTypes)
       columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray)
       if (enableAccumulators && columnarIterator.hasNext) {
-        readPartitions += 1
+        readPartitions.add(1)
       }
       columnarIterator
     }

sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala

Lines changed: 4 additions & 3 deletions

@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.LongAccumulator

 /**
  * Contains methods for debugging query execution.

@@ -122,13 +123,13 @@ package object debug {
    /**
     * A collection of metrics for each column of output.
     *
-    * @param elementTypes the actual runtime types for the output. Useful when there are bugs
+    * @param elementTypes the actual runtime types for the output. Useful when there are bugs
     *                     causing the wrong data to be projected.
     */
    case class ColumnMetrics(
      elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))

-   val tupleCount: Accumulator[Int] = sparkContext.accumulator[Int](0)
+   val tupleCount: LongAccumulator = sparkContext.longAccumulator

    val numColumns: Int = child.output.size
    val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics())

@@ -149,7 +150,7 @@ package object debug {
    def next(): InternalRow = {
      val currentRow = iter.next()
-     tupleCount += 1
+     tupleCount.add(1)
      var i = 0
      while (i < numColumns) {
        val value = currentRow.get(i, output(i).dataType)

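The `elementTypes: Accumulator[HashSet[String]]` field above is the case the PR description leaves on the old API because no built-in replacement exists. A minimal sketch of the anonymous-class route the author decided against, assuming the standard `AccumulatorV2` contract (`isZero`/`copy`/`reset`/`add`/`merge`/`value`); the class name and registration label are illustrative, not part of this commit:

```scala
import scala.collection.mutable.HashSet
import org.apache.spark.util.AccumulatorV2

// A set-valued accumulator: each task adds type names, the driver sees the union.
class StringSetAccumulator extends AccumulatorV2[String, HashSet[String]] {
  private val set = HashSet.empty[String]
  override def isZero: Boolean = set.isEmpty
  override def copy(): StringSetAccumulator = {
    val copied = new StringSetAccumulator
    copied.set ++= set
    copied
  }
  override def reset(): Unit = set.clear()
  override def add(v: String): Unit = set += v
  override def merge(other: AccumulatorV2[String, HashSet[String]]): Unit = set ++= other.value
  override def value: HashSet[String] = set
}

// Usage sketch: register with the SparkContext so task updates are merged on the driver.
// val elementTypes = new StringSetAccumulator
// sc.register(elementTypes, "elementTypes")
```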
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 3 additions & 3 deletions

@@ -2067,9 +2067,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3))

     // Identity udf that tracks the number of times it is called.
-    val countAcc = sparkContext.accumulator(0, "CallCount")
+    val countAcc = sparkContext.longAccumulator("CallCount")
     spark.udf.register("testUdf", (x: Int) => {
-      countAcc.++=(1)
+      countAcc.add(1)
       x
     })

@@ -2092,7 +2092,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2)

     val testUdf = functions.udf((x: Int) => {
-      countAcc.++=(1)
+      countAcc.add(1)
       x
     })
     verifyCallCount(

sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala

Lines changed: 2 additions & 2 deletions

@@ -365,9 +365,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
     // This task has both accumulators that are SQL metrics and accumulators that are not.
     // The listener should only track the ones that are actually SQL metrics.
     val sqlMetric = SQLMetrics.createMetric(sparkContext, "beach umbrella")
-    val nonSqlMetric = sparkContext.accumulator[Int](0, "baseball")
+    val nonSqlMetric = sparkContext.longAccumulator("baseball")
     val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.value), None)
-    val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.localValue), None)
+    val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.value), None)
     val taskInfo = createTaskInfo(0, 0)
     taskInfo.accumulables ++= Seq(sqlMetricInfo, nonSqlMetricInfo)
     val taskEnd = SparkListenerTaskEnd(0, 0, "just-a-task", null, taskInfo, null)
