
Commit 131ca14

MaxGekk authored and cloud-fan committed
[SPARK-24005][CORE] Remove usage of Scala’s parallel collection
## What changes were proposed in this pull request?

In this PR, I propose to replace Scala parallel collections with new `parmap()` methods. The methods use futures to transform a sequential collection by applying a lambda function to each element in parallel; the result of `parmap` is another regular (sequential) collection. The proposed `parmap` methods address the fact that Scala parallel collections cannot be interrupted, and interruptibility is needed for reliable task preemption.

## How was this patch tested?

A test was added to `ThreadUtilsSuite`.

Closes #21913 from MaxGekk/par-map.

Authored-by: Maxim Gekk <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
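For readers skimming the diff, here is a rough before/after sketch of the intended usage. This is not part of the patch: the input collection, pool name, and thread count are made up for illustration, and `ThreadUtils` is `private[spark]`, so the snippet only compiles inside Spark's own code base.

```scala
import org.apache.spark.util.ThreadUtils

object ParmapSketch {
  // Stand-in for some per-element work done on the pool.
  private def footerSize(path: String): Int = path.length

  def main(args: Array[String]): Unit = {
    val files = Seq("a.parquet", "b.parquet", "c.parquet")

    // Before: Scala parallel collections. Once started, the work cannot be
    // interrupted, which gets in the way of task preemption.
    // val sizes = files.par.map(footerSize).seq

    // After: parmap runs each element in a Future on a dedicated ForkJoinPool
    // ("demo-pool") and blocks in ThreadUtils.awaitResult, so interrupting the
    // calling thread cancels the wait, and the pool is shut down afterwards.
    val sizes: Seq[Int] = ThreadUtils.parmap(files, "demo-pool", maxThreads = 2)(footerSize)
    println(sizes)
  }
}
```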
1 parent 88e0c7b commit 131ca14

File tree

7 files changed: +142 -56 lines changed


core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala

Lines changed: 8 additions & 9 deletions
@@ -20,12 +20,13 @@ package org.apache.spark.rdd
 import java.io.{IOException, ObjectOutputStream}

 import scala.collection.mutable.ArrayBuffer
-import scala.collection.parallel.ForkJoinTaskSupport
+import scala.concurrent.ExecutionContext
 import scala.concurrent.forkjoin.ForkJoinPool
 import scala.reflect.ClassTag

 import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.ThreadUtils.parmap
 import org.apache.spark.util.Utils

 /**
@@ -59,8 +60,7 @@ private[spark] class UnionPartition[T: ClassTag](
 }

 object UnionRDD {
-  private[spark] lazy val partitionEvalTaskSupport =
-    new ForkJoinTaskSupport(new ForkJoinPool(8))
+  private[spark] lazy val threadPool = new ForkJoinPool(8)
 }

 @DeveloperApi
@@ -74,14 +74,13 @@ class UnionRDD[T: ClassTag](
     rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)

   override def getPartitions: Array[Partition] = {
-    val parRDDs = if (isPartitionListingParallel) {
-      val parArray = rdds.par
-      parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
-      parArray
+    val partitionLengths = if (isPartitionListingParallel) {
+      implicit val ec = ExecutionContext.fromExecutor(UnionRDD.threadPool)
+      parmap(rdds)(_.partitions.length)
     } else {
-      rdds
+      rdds.map(_.partitions.length)
     }
-    val array = new Array[Partition](parRDDs.map(_.partitions.length).seq.sum)
+    val array = new Array[Partition](partitionLengths.sum)
     var pos = 0
     for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
       array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)

core/src/main/scala/org/apache/spark/util/ThreadUtils.scala

Lines changed: 63 additions & 1 deletion
@@ -19,8 +19,12 @@ package org.apache.spark.util

 import java.util.concurrent._

+import scala.collection.TraversableLike
+import scala.collection.generic.CanBuildFrom
+import scala.language.higherKinds
+
 import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
-import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor}
+import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor, Future}
 import scala.concurrent.duration.{Duration, FiniteDuration}
 import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
 import scala.util.control.NonFatal
@@ -254,4 +258,62 @@ private[spark] object ThreadUtils {
       executor.shutdownNow()
     }
   }
+
+  /**
+   * Transforms the input collection by applying the given function to each element in parallel.
+   * Compared to the map() method of Scala parallel collections, this method can be interrupted
+   * at any time. This is useful for canceling task execution, for example.
+   *
+   * @param in - the input collection which should be transformed in parallel.
+   * @param prefix - the prefix assigned to the underlying thread pool.
+   * @param maxThreads - maximum number of threads that can be created during execution.
+   * @param f - the lambda function that will be applied to each element of `in`.
+   * @tparam I - the type of elements in the input collection.
+   * @tparam O - the type of elements in the resulting collection.
+   * @return a new collection in which each element was obtained by applying the lambda
+   *         function `f` to an element of the input collection `in`.
+   */
+  def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]]
+      (in: Col[I], prefix: String, maxThreads: Int)
+      (f: I => O)
+      (implicit
+        cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // For in.map
+        cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]] // for Future.sequence
+      ): Col[O] = {
+    val pool = newForkJoinPool(prefix, maxThreads)
+    try {
+      implicit val ec = ExecutionContext.fromExecutor(pool)
+
+      parmap(in)(f)
+    } finally {
+      pool.shutdownNow()
+    }
+  }
+
+  /**
+   * Transforms the input collection by applying the given function to each element in parallel.
+   * Compared to the map() method of Scala parallel collections, this method can be interrupted
+   * at any time. This is useful for canceling task execution, for example.
+   *
+   * @param in - the input collection which should be transformed in parallel.
+   * @param f - the lambda function that will be applied to each element of `in`.
+   * @param ec - an execution context for parallel application of the given function `f`.
+   * @tparam I - the type of elements in the input collection.
+   * @tparam O - the type of elements in the resulting collection.
+   * @return a new collection in which each element was obtained by applying the lambda
+   *         function `f` to an element of the input collection `in`.
+   */
+  def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]]
+      (in: Col[I])
+      (f: I => O)
+      (implicit
+        cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // For in.map
+        cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]], // for Future.sequence
+        ec: ExecutionContext
+      ): Col[O] = {
+    val futures = in.map(x => Future(f(x)))
+    val futureSeq = Future.sequence(futures)
+
+    awaitResult(futureSeq, Duration.Inf)
+  }
 }
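A quick usage note, not part of the diff: the single-argument overload picks up its `ExecutionContext` implicitly, which is the pattern the updated call sites (`UnionRDD`, `ddl.scala`, `FileBasedWriteAheadLog`) follow when they manage their own pool. A minimal sketch, with an invented pool name and inputs:

```scala
import scala.concurrent.ExecutionContext

import org.apache.spark.util.ThreadUtils

object ParmapWithCallerPool {
  // Sketch only: "sketch-pool", the pool size, and the inputs are illustrative,
  // and ThreadUtils is private[spark], so this has to live inside Spark code.
  def lengths(words: Seq[String]): Seq[Int] = {
    val pool = ThreadUtils.newForkJoinPool("sketch-pool", 4)
    try {
      // The one-argument parmap takes the ExecutionContext implicitly, so the
      // caller decides which pool the per-element Futures run on.
      implicit val ec = ExecutionContext.fromExecutor(pool)
      ThreadUtils.parmap(words)(_.length)
    } finally {
      // Unlike the three-argument overload, the pool lifecycle is the caller's job.
      pool.shutdownNow()
    }
  }
}
```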

core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala

Lines changed: 33 additions & 0 deletions
@@ -133,4 +133,37 @@ class ThreadUtilsSuite extends SparkFunSuite {
       "stack trace contains unexpected references to ThreadUtils"
     )
   }
+
+  test("parmap should be interruptible") {
+    val t = new Thread() {
+      setDaemon(true)
+
+      override def run() {
+        try {
+          // "par" is uninterruptible. The following will keep running even if the thread is
+          // interrupted. We should prefer to use "ThreadUtils.parmap".
+          //
+          // (1 to 10).par.flatMap { i =>
+          //   Thread.sleep(100000)
+          //   1 to i
+          // }
+          //
+          ThreadUtils.parmap(1 to 10, "test", 2) { i =>
+            Thread.sleep(100000)
+            1 to i
+          }.flatten
+        } catch {
+          case _: InterruptedException => // expected
+        }
+      }
+    }
+    t.start()
+    eventually(timeout(10.seconds)) {
+      assert(t.isAlive)
+    }
+    t.interrupt()
+    eventually(timeout(10.seconds)) {
+      assert(!t.isAlive)
+    }
+  }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala

Lines changed: 17 additions & 17 deletions
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command
 import java.util.Locale

 import scala.collection.{GenMap, GenSeq}
-import scala.collection.parallel.ForkJoinTaskSupport
+import scala.concurrent.ExecutionContext
 import scala.util.control.NonFatal

 import org.apache.hadoop.conf.Configuration
@@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf}

 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver}
+import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
@@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
 import org.apache.spark.sql.internal.HiveSerDe
 import org.apache.spark.sql.types._
 import org.apache.spark.util.{SerializableConfiguration, ThreadUtils}
+import org.apache.spark.util.ThreadUtils.parmap

 // Note: The definition of these commands are based on the ones described in
 // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
@@ -621,8 +622,9 @@ case class AlterTableRecoverPartitionsCommand(
     val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8)
     val partitionSpecsAndLocs: Seq[(TablePartitionSpec, Path)] =
       try {
+        implicit val ec = ExecutionContext.fromExecutor(evalPool)
         scanPartitions(spark, fs, pathFilter, root, Map(), table.partitionColumnNames, threshold,
-          spark.sessionState.conf.resolver, new ForkJoinTaskSupport(evalPool)).seq
+          spark.sessionState.conf.resolver)
       } finally {
         evalPool.shutdown()
       }
@@ -654,23 +656,13 @@
       spec: TablePartitionSpec,
       partitionNames: Seq[String],
       threshold: Int,
-      resolver: Resolver,
-      evalTaskSupport: ForkJoinTaskSupport): GenSeq[(TablePartitionSpec, Path)] = {
+      resolver: Resolver)(implicit ec: ExecutionContext): Seq[(TablePartitionSpec, Path)] = {
     if (partitionNames.isEmpty) {
       return Seq(spec -> path)
     }

-    val statuses = fs.listStatus(path, filter)
-    val statusPar: GenSeq[FileStatus] =
-      if (partitionNames.length > 1 && statuses.length > threshold || partitionNames.length > 2) {
-        // parallelize the list of partitions here, then we can have better parallelism later.
-        val parArray = statuses.par
-        parArray.tasksupport = evalTaskSupport
-        parArray
-      } else {
-        statuses
-      }
-    statusPar.flatMap { st =>
+    val statuses = fs.listStatus(path, filter).toSeq
+    def handleStatus(st: FileStatus): Seq[(TablePartitionSpec, Path)] = {
       val name = st.getPath.getName
       if (st.isDirectory && name.contains("=")) {
         val ps = name.split("=", 2)
@@ -679,7 +671,7 @@
         val value = ExternalCatalogUtils.unescapePathName(ps(1))
         if (resolver(columnName, partitionNames.head)) {
           scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(partitionNames.head -> value),
-            partitionNames.drop(1), threshold, resolver, evalTaskSupport)
+            partitionNames.drop(1), threshold, resolver)
         } else {
           logWarning(
             s"expected partition column ${partitionNames.head}, but got ${ps(0)}, ignoring it")
@@ -690,6 +682,14 @@
         Seq.empty
       }
     }
+    val result = if (partitionNames.length > 1 &&
+        statuses.length > threshold || partitionNames.length > 2) {
+      parmap(statuses)(handleStatus _)
+    } else {
+      statuses.map(handleStatus)
+    }
+
+    result.flatten
   }

   private def gatherPartitionStats(

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala

Lines changed: 16 additions & 24 deletions
@@ -22,7 +22,6 @@ import java.net.URI

 import scala.collection.JavaConverters._
 import scala.collection.mutable
-import scala.collection.parallel.ForkJoinTaskSupport
 import scala.util.{Failure, Try}

 import org.apache.hadoop.conf.Configuration
@@ -532,30 +531,23 @@ object ParquetFileFormat extends Logging {
       conf: Configuration,
       partFiles: Seq[FileStatus],
       ignoreCorruptFiles: Boolean): Seq[Footer] = {
-    val parFiles = partFiles.par
-    val pool = ThreadUtils.newForkJoinPool("readingParquetFooters", 8)
-    parFiles.tasksupport = new ForkJoinTaskSupport(pool)
-    try {
-      parFiles.flatMap { currentFile =>
-        try {
-          // Skips row group information since we only need the schema.
-          // ParquetFileReader.readFooter throws RuntimeException, instead of IOException,
-          // when it can't read the footer.
-          Some(new Footer(currentFile.getPath(),
-            ParquetFileReader.readFooter(
-              conf, currentFile, SKIP_ROW_GROUPS)))
-        } catch { case e: RuntimeException =>
-          if (ignoreCorruptFiles) {
-            logWarning(s"Skipped the footer in the corrupted file: $currentFile", e)
-            None
-          } else {
-            throw new IOException(s"Could not read footer for file: $currentFile", e)
-          }
+    ThreadUtils.parmap(partFiles, "readingParquetFooters", 8) { currentFile =>
+      try {
+        // Skips row group information since we only need the schema.
+        // ParquetFileReader.readFooter throws RuntimeException, instead of IOException,
+        // when it can't read the footer.
+        Some(new Footer(currentFile.getPath(),
+          ParquetFileReader.readFooter(
+            conf, currentFile, SKIP_ROW_GROUPS)))
+      } catch { case e: RuntimeException =>
+        if (ignoreCorruptFiles) {
+          logWarning(s"Skipped the footer in the corrupted file: $currentFile", e)
+          None
+        } else {
+          throw new IOException(s"Could not read footer for file: $currentFile", e)
         }
-      }.seq
-    } finally {
-      pool.shutdown()
-    }
+      }
+    }.flatten
   }

   /**

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -51,9 +51,9 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo
     }

     testReadFooters(true)
-    val exception = intercept[java.io.IOException] {
+    val exception = intercept[SparkException] {
       testReadFooters(false)
-    }
+    }.getCause
     assert(exception.getMessage().contains("Could not read footer for file"))
   }
 }
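A note on why the expected exception type changes here: `ThreadUtils.awaitResult`, which `parmap` uses to wait for its futures, rethrows non-fatal failures wrapped in a `SparkException`, so the worker's `IOException` now sits on `getCause`. A minimal standalone sketch of that behavior (illustrative only, and runnable only from Spark-internal code since `ThreadUtils` is `private[spark]`):

```scala
import scala.concurrent.ExecutionContext.Implicits.global

import org.apache.spark.SparkException
import org.apache.spark.util.ThreadUtils

object ParmapFailureSketch {
  def main(args: Array[String]): Unit = {
    try {
      // The lambda fails, so the underlying Future fails with IOException...
      ThreadUtils.parmap(Seq(1, 2, 3))(_ => throw new java.io.IOException("boom"))
    } catch {
      // ...and awaitResult surfaces it to the caller wrapped in a SparkException.
      case e: SparkException => assert(e.getCause.isInstanceOf[java.io.IOException])
    }
  }
}
```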

streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala

Lines changed: 3 additions & 3 deletions
@@ -312,10 +312,10 @@ private[streaming] object FileBasedWriteAheadLog {
       handler: I => Iterator[O]): Iterator[O] = {
     val taskSupport = new ExecutionContextTaskSupport(executionContext)
     val groupSize = taskSupport.parallelismLevel.max(8)
+    implicit val ec = executionContext
+
     source.grouped(groupSize).flatMap { group =>
-      val parallelCollection = group.par
-      parallelCollection.tasksupport = taskSupport
-      parallelCollection.map(handler)
+      ThreadUtils.parmap(group)(handler)
     }.flatten
   }
 }
