@@ -20,12 +20,13 @@ package org.apache.spark.rdd
 import java.io.{IOException, ObjectOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
-import scala.collection.parallel.ForkJoinTaskSupport
+import scala.concurrent.ExecutionContext
 import scala.concurrent.forkjoin.ForkJoinPool
 import scala.reflect.ClassTag
 
 import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.ThreadUtils.parmap
 import org.apache.spark.util.Utils
 
 /**
@@ -59,8 +60,7 @@ private[spark] class UnionPartition[T: ClassTag](
 }
 
 object UnionRDD {
-  private[spark] lazy val partitionEvalTaskSupport =
-    new ForkJoinTaskSupport(new ForkJoinPool(8))
+  private[spark] lazy val threadPool = new ForkJoinPool(8)
 }
 
 @DeveloperApi
@@ -74,14 +74,13 @@ class UnionRDD[T: ClassTag](
     rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)
 
   override def getPartitions: Array[Partition] = {
-    val parRDDs = if (isPartitionListingParallel) {
-      val parArray = rdds.par
-      parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
-      parArray
+    val partitionLengths = if (isPartitionListingParallel) {
+      implicit val ec = ExecutionContext.fromExecutor(UnionRDD.threadPool)
+      parmap(rdds)(_.partitions.length)
     } else {
-      rdds
+      rdds.map(_.partitions.length)
     }
-    val array = new Array[Partition](parRDDs.map(_.partitions.length).seq.sum)
+    val array = new Array[Partition](partitionLengths.sum)
     var pos = 0
     for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
       array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
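
The new code path drops the Scala parallel-collections idiom (rdds.par with a custom ForkJoinTaskSupport) in favor of ThreadUtils.parmap, which maps a function over a collection using plain Futures on an explicitly supplied ExecutionContext. As a rough, hypothetical sketch of the shape such a helper can take (the real ThreadUtils.parmap may differ in details such as error propagation and how it awaits results):

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

// Hypothetical stand-in for org.apache.spark.util.ThreadUtils.parmap:
// schedule one Future per input element on the caller-supplied
// ExecutionContext, then block for the results, preserving input order.
def parmap[I, O](in: Seq[I])(f: I => O)(implicit ec: ExecutionContext): Seq[O] = {
  val futures = in.map(i => Future(f(i)))
  futures.map(fut => Await.result(fut, Duration.Inf))
}

With the shared UnionRDD.threadPool wrapped via ExecutionContext.fromExecutor, parmap(rdds)(_.partitions.length) computes each child RDD's partition count concurrently on a bounded pool of 8 threads, matching the parallelism of the ForkJoinTaskSupport it replaces.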