diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index 2263538a11676..0c2ceb1a02c7b 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -217,11 +217,11 @@ class BarrierTaskContext private[spark] ( */ @Experimental @Since("3.0.0") - def allGather(message: String): ArrayBuffer[String] = { + def allGather(message: String): Array[String] = { val json = runBarrier(RequestMethod.ALL_GATHER, message) val jsonArray = parse(json) implicit val formats = DefaultFormats - ArrayBuffer(jsonArray.extract[Array[String]]: _*) + jsonArray.extract[Array[String]] } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index fa8bf0fc06358..06c9446c7534e 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -420,7 +420,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( context.asInstanceOf[BarrierTaskContext].barrier() result = BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => - val messages: ArrayBuffer[String] = context.asInstanceOf[BarrierTaskContext].allGather( + val messages: Array[String] = context.asInstanceOf[BarrierTaskContext].allGather( message ) result = compact(render(JArray( diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index 33594c0a50d14..0dd8be72dc904 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -65,7 +65,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { Thread.sleep(Random.nextInt(1000)) // Pass partitionId message in val message: String = context.partitionId().toString - val messages: ArrayBuffer[String] = context.allGather(message) + val messages: Array[String] = context.allGather(message) messages.toList.iterator } // Take a sorted list of all the partitionId messages