17 changes: 8 additions & 9 deletions core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -20,12 +20,13 @@ package org.apache.spark.rdd
import java.io.{IOException, ObjectOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.collection.parallel.ForkJoinTaskSupport
import scala.concurrent.ExecutionContext
import scala.concurrent.forkjoin.ForkJoinPool
import scala.reflect.ClassTag

import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.ThreadUtils.parmap
import org.apache.spark.util.Utils

/**
@@ -59,8 +60,7 @@ private[spark] class UnionPartition[T: ClassTag](
}

object UnionRDD {
private[spark] lazy val partitionEvalTaskSupport =
new ForkJoinTaskSupport(new ForkJoinPool(8))
private[spark] lazy val threadPool = new ForkJoinPool(8)
}

@DeveloperApi
@@ -74,14 +74,13 @@ class UnionRDD[T: ClassTag](
rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)

override def getPartitions: Array[Partition] = {
val parRDDs = if (isPartitionListingParallel) {
val parArray = rdds.par
parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
parArray
val partitionLengths = if (isPartitionListingParallel) {
implicit val ec = ExecutionContext.fromExecutor(UnionRDD.threadPool)
parmap(rdds)(_.partitions.length)
} else {
rdds
rdds.map(_.partitions.length)
}
val array = new Array[Partition](parRDDs.map(_.partitions.length).seq.sum)
val array = new Array[Partition](partitionLengths.sum)
var pos = 0
for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
64 changes: 63 additions & 1 deletion core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -19,8 +19,12 @@ package org.apache.spark.util

import java.util.concurrent._

import scala.collection.TraversableLike
import scala.collection.generic.CanBuildFrom
import scala.language.higherKinds

import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor}
import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor, Future}
import scala.concurrent.duration.{Duration, FiniteDuration}
import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
import scala.util.control.NonFatal
@@ -254,4 +258,62 @@ private[spark] object ThreadUtils {
executor.shutdownNow()
}
}

Review (Member): I'd still include a note in these docs about what this does differently from .par. Just a sentence about it being interruptible.
Review (Member Author): I added a comment about this.

/**
* Transforms the input collection by applying the given function to each element in parallel.
* Compared to the map() method of Scala parallel collections, this method can be interrupted
* at any time. This is useful when canceling task execution, for example.
*
* @param in - the input collection which should be transformed in parallel.
* @param prefix - the prefix assigned to the underlying thread pool.
* @param maxThreads - the maximum number of threads that can be created during execution.
* @param f - the lambda function that will be applied to each element of `in`.
* @tparam I - the type of elements in the input collection.
* @tparam O - the type of elements in the resulting collection.
* @return a new collection in which each element is the result of applying the lambda
* function `f` to the corresponding element of the input collection `in`.
*/
def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]]
(in: Col[I], prefix: String, maxThreads: Int)
(f: I => O)
(implicit
cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // For in.map
cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]] // for Future.sequence
): Col[O] = {
val pool = newForkJoinPool(prefix, maxThreads)
try {
implicit val ec = ExecutionContext.fromExecutor(pool)

parmap(in)(f)
} finally {
pool.shutdownNow()
Review (Member): @ConeyLiu this line interrupts the tasks in the thread pool. Scala par doesn't do this.
Review (Contributor): @zsxwing, thanks very much for your answer.
}
}

/**
* Transforms the input collection by applying the given function to each element in parallel.
* Compared to the map() method of Scala parallel collections, this method can be interrupted
* at any time. This is useful when canceling task execution, for example.
*
* @param in - the input collection which should be transformed in parallel.
* @param f - the lambda function that will be applied to each element of `in`.
* @param ec - the execution context used to apply the given function `f` in parallel.
* @tparam I - the type of elements in the input collection.
* @tparam O - the type of elements in the resulting collection.
* @return a new collection in which each element is the result of applying the lambda
* function `f` to the corresponding element of the input collection `in`.
*/
def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]]
(in: Col[I])
(f: I => O)
(implicit
cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // For in.map
cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]], // for Future.sequence
ec: ExecutionContext
): Col[O] = {
val futures = in.map(x => Future(f(x)))
val futureSeq = Future.sequence(futures)

awaitResult(futureSeq, Duration.Inf)
}
}
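For orientation, a minimal usage sketch of the two new overloads (not part of the PR). It assumes the calling code lives inside the org.apache.spark package, since ThreadUtils is private[spark], and the pool names "parmap-demo" and "parmap-shared" are made up for illustration:

import scala.concurrent.ExecutionContext

import org.apache.spark.util.ThreadUtils

object ParmapUsageSketch {
  def main(args: Array[String]): Unit = {
    // Overload 1: parmap creates a ForkJoinPool named "parmap-demo" with at most
    // 2 threads and shuts it down (shutdownNow) as soon as the call returns.
    val lengths: Seq[Int] =
      ThreadUtils.parmap(Seq("a", "bb", "ccc"), "parmap-demo", 2)(_.length)

    // Overload 2: the caller owns the pool and supplies an ExecutionContext implicitly,
    // so several parmap calls can share one pool; the caller must shut it down.
    val pool = ThreadUtils.newForkJoinPool("parmap-shared", 2)
    try {
      implicit val ec = ExecutionContext.fromExecutor(pool)
      val doubled: Seq[Int] = ThreadUtils.parmap(Seq(1, 2, 3))(_ * 2)
      println((lengths, doubled))
    } finally {
      pool.shutdown()
    }
  }
}

The first overload is the convenient one-shot form; the second is the shape used by UnionRDD and AlterTableRecoverPartitionsCommand in this PR, where a single pool serves repeated or recursive calls.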
33 changes: 33 additions & 0 deletions core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
@@ -133,4 +133,37 @@ class ThreadUtilsSuite extends SparkFunSuite {
"stack trace contains unexpected references to ThreadUtils"
)
}

test("parmap should be interruptible") {
val t = new Thread() {
setDaemon(true)

override def run() {
try {
// "par" is uninterruptible. The following will keep running even if the thread is
// interrupted. We should prefer to use "ThreadUtils.parmap".
//
// (1 to 10).par.flatMap { i =>
// Thread.sleep(100000)
// 1 to i
// }
//
ThreadUtils.parmap(1 to 10, "test", 2) { i =>
Thread.sleep(100000)
1 to i
}.flatten
} catch {
case _: InterruptedException => // expected
}
}
}
t.start()
eventually(timeout(10.seconds)) {
assert(t.isAlive)
}
t.interrupt()
eventually(timeout(10.seconds)) {
assert(!t.isAlive)
}
}
Review (Contributor): Hi @MaxGekk @zsxwing, could you tell me why this can be interrupted while (1 to 10).par can't?
}
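To make the interruptibility point above concrete, here is a rough standalone sketch (not part of the PR; it again assumes org.apache.spark package access, and the object and pool names are made up). Interrupting the thread that is blocked in parmap makes the underlying awaitResult throw InterruptedException, and parmap's finally block calls shutdownNow() on its pool, which interrupts the still-running worker tasks; a .par collection has no such hook, so its tasks would keep sleeping.

import java.util.concurrent.TimeUnit

import org.apache.spark.util.ThreadUtils

object ParmapInterruptDemo {
  def main(args: Array[String]): Unit = {
    val caller = new Thread {
      override def run(): Unit = {
        try {
          // Every task sleeps "forever"; the call can only finish via interruption.
          ThreadUtils.parmap(1 to 4, "interrupt-demo", 2) { i =>
            Thread.sleep(Long.MaxValue)
            i
          }
        } catch {
          case _: InterruptedException =>
            // awaitResult lets the interrupt propagate, so the caller unblocks here.
            println("parmap was interrupted")
        }
      }
    }
    caller.start()
    TimeUnit.SECONDS.sleep(1)
    // Interrupting the caller aborts the wait; parmap's finally block then calls
    // shutdownNow() on its pool, interrupting the still-sleeping worker tasks.
    caller.interrupt()
    caller.join()
  }
}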
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command
import java.util.Locale

import scala.collection.{GenMap, GenSeq}
import scala.collection.parallel.ForkJoinTaskSupport
import scala.concurrent.ExecutionContext
import scala.util.control.NonFatal

import org.apache.hadoop.conf.Configuration
@@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf}

import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver}
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
@@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
import org.apache.spark.sql.internal.HiveSerDe
import org.apache.spark.sql.types._
import org.apache.spark.util.{SerializableConfiguration, ThreadUtils}
import org.apache.spark.util.ThreadUtils.parmap

// Note: The definition of these commands are based on the ones described in
// https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
@@ -621,8 +622,9 @@ case class AlterTableRecoverPartitionsCommand(
val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8)
val partitionSpecsAndLocs: Seq[(TablePartitionSpec, Path)] =
try {
implicit val ec = ExecutionContext.fromExecutor(evalPool)
scanPartitions(spark, fs, pathFilter, root, Map(), table.partitionColumnNames, threshold,
spark.sessionState.conf.resolver, new ForkJoinTaskSupport(evalPool)).seq
spark.sessionState.conf.resolver)
} finally {
evalPool.shutdown()
}
@@ -654,23 +656,13 @@ case class AlterTableRecoverPartitionsCommand(
spec: TablePartitionSpec,
partitionNames: Seq[String],
threshold: Int,
resolver: Resolver,
evalTaskSupport: ForkJoinTaskSupport): GenSeq[(TablePartitionSpec, Path)] = {
resolver: Resolver)(implicit ec: ExecutionContext): Seq[(TablePartitionSpec, Path)] = {
if (partitionNames.isEmpty) {
return Seq(spec -> path)
}

val statuses = fs.listStatus(path, filter)
val statusPar: GenSeq[FileStatus] =
if (partitionNames.length > 1 && statuses.length > threshold || partitionNames.length > 2) {
// parallelize the list of partitions here, then we can have better parallelism later.
val parArray = statuses.par
parArray.tasksupport = evalTaskSupport
parArray
} else {
statuses
}
statusPar.flatMap { st =>
val statuses = fs.listStatus(path, filter).toSeq
def handleStatus(st: FileStatus): Seq[(TablePartitionSpec, Path)] = {
val name = st.getPath.getName
if (st.isDirectory && name.contains("=")) {
val ps = name.split("=", 2)
@@ -679,7 +671,7 @@
val value = ExternalCatalogUtils.unescapePathName(ps(1))
if (resolver(columnName, partitionNames.head)) {
scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(partitionNames.head -> value),
partitionNames.drop(1), threshold, resolver, evalTaskSupport)
partitionNames.drop(1), threshold, resolver)
} else {
logWarning(
s"expected partition column ${partitionNames.head}, but got ${ps(0)}, ignoring it")
@@ -690,6 +682,14 @@
Seq.empty
}
}
val result = if (partitionNames.length > 1 &&
statuses.length > threshold || partitionNames.length > 2) {
parmap(statuses)(handleStatus _)
} else {
statuses.map(handleStatus)
}

result.flatten
}

private def gatherPartitionStats(
@@ -22,7 +22,6 @@ import java.net.URI

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.parallel.ForkJoinTaskSupport
import scala.util.{Failure, Try}

import org.apache.hadoop.conf.Configuration
@@ -532,30 +531,23 @@ object ParquetFileFormat extends Logging {
conf: Configuration,
partFiles: Seq[FileStatus],
ignoreCorruptFiles: Boolean): Seq[Footer] = {
val parFiles = partFiles.par
val pool = ThreadUtils.newForkJoinPool("readingParquetFooters", 8)
parFiles.tasksupport = new ForkJoinTaskSupport(pool)
try {
parFiles.flatMap { currentFile =>
try {
// Skips row group information since we only need the schema.
// ParquetFileReader.readFooter throws RuntimeException, instead of IOException,
// when it can't read the footer.
Some(new Footer(currentFile.getPath(),
ParquetFileReader.readFooter(
conf, currentFile, SKIP_ROW_GROUPS)))
} catch { case e: RuntimeException =>
if (ignoreCorruptFiles) {
logWarning(s"Skipped the footer in the corrupted file: $currentFile", e)
None
} else {
throw new IOException(s"Could not read footer for file: $currentFile", e)
}
ThreadUtils.parmap(partFiles, "readingParquetFooters", 8) { currentFile =>
try {
// Skips row group information since we only need the schema.
// ParquetFileReader.readFooter throws RuntimeException, instead of IOException,
// when it can't read the footer.
Some(new Footer(currentFile.getPath(),
ParquetFileReader.readFooter(
conf, currentFile, SKIP_ROW_GROUPS)))
} catch { case e: RuntimeException =>
if (ignoreCorruptFiles) {
logWarning(s"Skipped the footer in the corrupted file: $currentFile", e)
None
} else {
throw new IOException(s"Could not read footer for file: $currentFile", e)
}
}.seq
} finally {
pool.shutdown()
}
}
}.flatten
}

/**
@@ -51,9 +51,9 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo
}

testReadFooters(true)
val exception = intercept[java.io.IOException] {
val exception = intercept[SparkException] {
testReadFooters(false)
}
}.getCause
assert(exception.getMessage().contains("Could not read footer for file"))
}
}
@@ -312,10 +312,10 @@ private[streaming] object FileBasedWriteAheadLog {
handler: I => Iterator[O]): Iterator[O] = {
val taskSupport = new ExecutionContextTaskSupport(executionContext)
val groupSize = taskSupport.parallelismLevel.max(8)
implicit val ec = executionContext

source.grouped(groupSize).flatMap { group =>
val parallelCollection = group.par
parallelCollection.tasksupport = taskSupport
parallelCollection.map(handler)
ThreadUtils.parmap(group)(handler)
}.flatten
}
}