Skip to content

Commit 72cdfeb

Browse files
committed
Porting UnionRDD on parmap
1 parent ad03004 commit 72cdfeb

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@ package org.apache.spark.rdd
2020
import java.io.{IOException, ObjectOutputStream}
2121

2222
import scala.collection.mutable.ArrayBuffer
23-
import scala.collection.parallel.ForkJoinTaskSupport
23+
import scala.concurrent.ExecutionContext
2424
import scala.concurrent.forkjoin.ForkJoinPool
2525
import scala.reflect.ClassTag
2626

2727
import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
2828
import org.apache.spark.annotation.DeveloperApi
29+
import org.apache.spark.util.ThreadUtils.parmap
2930
import org.apache.spark.util.Utils
3031

3132
/**
@@ -59,8 +60,7 @@ private[spark] class UnionPartition[T: ClassTag](
5960
}
6061

6162
object UnionRDD {
62-
private[spark] lazy val partitionEvalTaskSupport =
63-
new ForkJoinTaskSupport(new ForkJoinPool(8))
63+
private[spark] lazy val threadPool = new ForkJoinPool(8)
6464
}
6565

6666
@DeveloperApi
@@ -74,14 +74,13 @@ class UnionRDD[T: ClassTag](
7474
rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)
7575

7676
override def getPartitions: Array[Partition] = {
77-
val parRDDs = if (isPartitionListingParallel) {
78-
val parArray = rdds.par
79-
parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
80-
parArray
77+
val partitionLengths = if (isPartitionListingParallel) {
78+
implicit val ec = ExecutionContext.fromExecutor(UnionRDD.threadPool)
79+
parmap(rdds)(_.partitions.length)
8180
} else {
82-
rdds
81+
rdds.map(_.partitions.length)
8382
}
84-
val array = new Array[Partition](parRDDs.map(_.partitions.length).seq.sum)
83+
val array = new Array[Partition](partitionLengths.sum)
8584
var pos = 0
8685
for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
8786
array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)

0 commit comments

Comments
 (0)