diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
index 60e383afadf1..4b6f73235a57 100644
--- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -20,12 +20,13 @@ package org.apache.spark.rdd
 import java.io.{IOException, ObjectOutputStream}

 import scala.collection.mutable.ArrayBuffer
-import scala.collection.parallel.ForkJoinTaskSupport
+import scala.concurrent.ExecutionContext
 import scala.concurrent.forkjoin.ForkJoinPool
 import scala.reflect.ClassTag

 import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.ThreadUtils.parmap
 import org.apache.spark.util.Utils

 /**
@@ -59,8 +60,7 @@ private[spark] class UnionPartition[T: ClassTag](
 }

 object UnionRDD {
-  private[spark] lazy val partitionEvalTaskSupport =
-    new ForkJoinTaskSupport(new ForkJoinPool(8))
+  private[spark] lazy val threadPool = new ForkJoinPool(8)
 }

 @DeveloperApi
@@ -74,14 +74,13 @@ class UnionRDD[T: ClassTag](
     rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)

   override def getPartitions: Array[Partition] = {
-    val parRDDs = if (isPartitionListingParallel) {
-      val parArray = rdds.par
-      parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
-      parArray
+    val partitionLengths = if (isPartitionListingParallel) {
+      implicit val ec = ExecutionContext.fromExecutor(UnionRDD.threadPool)
+      parmap(rdds)(_.partitions.length)
     } else {
-      rdds
+      rdds.map(_.partitions.length)
     }
-    val array = new Array[Partition](parRDDs.map(_.partitions.length).seq.sum)
+    val array = new Array[Partition](partitionLengths.sum)
     var pos = 0
     for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
       array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 0f08a2b0ad89..f0e5addbe5b5 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -19,8 +19,12 @@ package org.apache.spark.util

 import java.util.concurrent._

+import scala.collection.TraversableLike
+import scala.collection.generic.CanBuildFrom
+import scala.language.higherKinds
+
 import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
-import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor}
+import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor, Future}
 import scala.concurrent.duration.{Duration, FiniteDuration}
 import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
 import scala.util.control.NonFatal
@@ -254,4 +258,62 @@ private[spark] object ThreadUtils {
       executor.shutdownNow()
     }
   }
+
+  /**
+   * Transforms the input collection by applying the given function to each element in parallel.
+   * Unlike the map() method of Scala parallel collections, this method can be interrupted
+   * at any time. This is useful when canceling task execution, for example.
+   *
+   * @param in - the input collection that should be transformed in parallel.
+   * @param prefix - the prefix assigned to the underlying thread pool.
+   * @param maxThreads - the maximum number of threads that can be created during execution.
+   * @param f - the lambda function that will be applied to each element of `in`.
+   * @tparam I - the type of elements in the input collection.
+   * @tparam O - the type of elements in the resulting collection.
+   * @return a new collection in which each element is the result of applying the function `f`
+   *         to the corresponding element of `in`.
+   */
+  def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]]
+      (in: Col[I], prefix: String, maxThreads: Int)
+      (f: I => O)
+      (implicit
+        cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // for in.map
+        cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]] // for Future.sequence
+      ): Col[O] = {
+    val pool = newForkJoinPool(prefix, maxThreads)
+    try {
+      implicit val ec = ExecutionContext.fromExecutor(pool)
+
+      parmap(in)(f)
+    } finally {
+      pool.shutdownNow()
+    }
+  }
+
+  /**
+   * Transforms the input collection by applying the given function to each element in parallel.
+   * Unlike the map() method of Scala parallel collections, this method can be interrupted
+   * at any time. This is useful when canceling task execution, for example.
+   *
+   * @param in - the input collection that should be transformed in parallel.
+   * @param f - the lambda function that will be applied to each element of `in`.
+   * @param ec - the execution context on which the given function `f` is applied in parallel.
+   * @tparam I - the type of elements in the input collection.
+   * @tparam O - the type of elements in the resulting collection.
+   * @return a new collection in which each element is the result of applying the function `f`
+   *         to the corresponding element of `in`.
+   */
+  def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]]
+      (in: Col[I])
+      (f: I => O)
+      (implicit
+        cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // for in.map
+        cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]], // for Future.sequence
+        ec: ExecutionContext
+      ): Col[O] = {
+    val futures = in.map(x => Future(f(x)))
+    val futureSeq = Future.sequence(futures)
+
+    awaitResult(futureSeq, Duration.Inf)
+  }
 }
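For reference, a minimal usage sketch of the new helper (illustrative only, not part of the patch; the file names, pool prefix, and thread count below are made up). The three-argument overload creates a dedicated ForkJoinPool named after the prefix, blocks until every element has been processed, and shuts the pool down in its finally block:

    import org.apache.spark.util.ThreadUtils

    // Hypothetical example: compute the sizes of a few local files in parallel
    // on a pool prefixed "file-len" that is capped at 4 threads.
    val paths = Seq("/tmp/a.txt", "/tmp/b.txt", "/tmp/c.txt")
    val lengths: Seq[Long] = ThreadUtils.parmap(paths, "file-len", 4) { path =>
      new java.io.File(path).length()
    }

The single-collection overload runs on whatever ExecutionContext is in implicit scope instead, which is how UnionRDD.getPartitions above reuses its shared pool.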
diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
index ae3b3d829f1b..604f1e1ca310 100644
--- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
@@ -133,4 +133,37 @@ class ThreadUtilsSuite extends SparkFunSuite {
       "stack trace contains unexpected references to ThreadUtils"
     )
   }
+
+  test("parmap should be interruptible") {
+    val t = new Thread() {
+      setDaemon(true)
+
+      override def run() {
+        try {
+          // "par" is uninterruptible. The following will keep running even if the thread is
+          // interrupted. We should prefer to use "ThreadUtils.parmap".
+          //
+          //        (1 to 10).par.flatMap { i =>
+          //          Thread.sleep(100000)
+          //          1 to i
+          //        }
+          //
+          ThreadUtils.parmap(1 to 10, "test", 2) { i =>
+            Thread.sleep(100000)
+            1 to i
+          }.flatten
+        } catch {
+          case _: InterruptedException => // expected
+        }
+      }
+    }
+    t.start()
+    eventually(timeout(10.seconds)) {
+      assert(t.isAlive)
+    }
+    t.interrupt()
+    eventually(timeout(10.seconds)) {
+      assert(!t.isAlive)
+    }
+  }
 }
+ // + // (1 to 10).par.flatMap { i => + // Thread.sleep(100000) + // 1 to i + // } + // + ThreadUtils.parmap(1 to 10, "test", 2) { i => + Thread.sleep(100000) + 1 to i + }.flatten + } catch { + case _: InterruptedException => // excepted + } + } + } + t.start() + eventually(timeout(10.seconds)) { + assert(t.isAlive) + } + t.interrupt() + eventually(timeout(10.seconds)) { + assert(!t.isAlive) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index e1faecedd20e..7a6f5741862c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command import java.util.Locale import scala.collection.{GenMap, GenSeq} -import scala.collection.parallel.ForkJoinTaskSupport +import scala.concurrent.ExecutionContext import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration @@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver} +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} @@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.types._ import org.apache.spark.util.{SerializableConfiguration, ThreadUtils} +import org.apache.spark.util.ThreadUtils.parmap // Note: The definition of these commands are based on the ones described in // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL @@ -621,8 +622,9 @@ case class AlterTableRecoverPartitionsCommand( val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8) val partitionSpecsAndLocs: Seq[(TablePartitionSpec, Path)] = try { + implicit val ec = ExecutionContext.fromExecutor(evalPool) scanPartitions(spark, fs, pathFilter, root, Map(), table.partitionColumnNames, threshold, - spark.sessionState.conf.resolver, new ForkJoinTaskSupport(evalPool)).seq + spark.sessionState.conf.resolver) } finally { evalPool.shutdown() } @@ -654,23 +656,13 @@ case class AlterTableRecoverPartitionsCommand( spec: TablePartitionSpec, partitionNames: Seq[String], threshold: Int, - resolver: Resolver, - evalTaskSupport: ForkJoinTaskSupport): GenSeq[(TablePartitionSpec, Path)] = { + resolver: Resolver)(implicit ec: ExecutionContext): Seq[(TablePartitionSpec, Path)] = { if (partitionNames.isEmpty) { return Seq(spec -> path) } - val statuses = fs.listStatus(path, filter) - val statusPar: GenSeq[FileStatus] = - if (partitionNames.length > 1 && statuses.length > threshold || partitionNames.length > 2) { - // parallelize the list of partitions here, then we can have better parallelism later. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 283d7761d22d..b2409f3470e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -22,7 +22,6 @@ import java.net.URI

 import scala.collection.JavaConverters._
 import scala.collection.mutable
-import scala.collection.parallel.ForkJoinTaskSupport
 import scala.util.{Failure, Try}

 import org.apache.hadoop.conf.Configuration
@@ -532,30 +531,23 @@ object ParquetFileFormat extends Logging {
       conf: Configuration,
       partFiles: Seq[FileStatus],
       ignoreCorruptFiles: Boolean): Seq[Footer] = {
-    val parFiles = partFiles.par
-    val pool = ThreadUtils.newForkJoinPool("readingParquetFooters", 8)
-    parFiles.tasksupport = new ForkJoinTaskSupport(pool)
-    try {
-      parFiles.flatMap { currentFile =>
-        try {
-          // Skips row group information since we only need the schema.
-          // ParquetFileReader.readFooter throws RuntimeException, instead of IOException,
-          // when it can't read the footer.
-          Some(new Footer(currentFile.getPath(),
-            ParquetFileReader.readFooter(
-              conf, currentFile, SKIP_ROW_GROUPS)))
-        } catch { case e: RuntimeException =>
-          if (ignoreCorruptFiles) {
-            logWarning(s"Skipped the footer in the corrupted file: $currentFile", e)
-            None
-          } else {
-            throw new IOException(s"Could not read footer for file: $currentFile", e)
-          }
+    ThreadUtils.parmap(partFiles, "readingParquetFooters", 8) { currentFile =>
+      try {
+        // Skips row group information since we only need the schema.
+        // ParquetFileReader.readFooter throws RuntimeException, instead of IOException,
+        // when it can't read the footer.
+        Some(new Footer(currentFile.getPath(),
+          ParquetFileReader.readFooter(
+            conf, currentFile, SKIP_ROW_GROUPS)))
+      } catch { case e: RuntimeException =>
+        if (ignoreCorruptFiles) {
+          logWarning(s"Skipped the footer in the corrupted file: $currentFile", e)
+          None
+        } else {
+          throw new IOException(s"Could not read footer for file: $currentFile", e)
         }
-      }.seq
-    } finally {
-      pool.shutdown()
-    }
+      }
+    }.flatten
   }

   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala
index 3a0867fd2b78..94abf115cef3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala
@@ -51,9 +51,9 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo
     }

     testReadFooters(true)
-    val exception = intercept[java.io.IOException] {
+    val exception = intercept[SparkException] {
       testReadFooters(false)
-    }
+    }.getCause
     assert(exception.getMessage().contains("Could not read footer for file"))
   }
 }
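A note on the test change above: parmap waits on its futures through ThreadUtils.awaitResult, which wraps non-fatal failures in a SparkException and keeps the original error as the cause, so the suite now intercepts SparkException and unwraps it with getCause before checking the IOException message. A hedged sketch of the pattern, not taken from the patch (the pool name and the plain try/catch are assumptions):

    import java.io.IOException

    import org.apache.spark.SparkException
    import org.apache.spark.util.ThreadUtils

    // A failure thrown inside the mapped function surfaces on the calling
    // thread as the cause of a SparkException.
    try {
      ThreadUtils.parmap(Seq(1, 2, 3), "demo", 2) { _ =>
        throw new IOException("boom")
      }
    } catch {
      case e: SparkException =>
        assert(e.getCause.isInstanceOf[IOException])
    }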
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
index 2e8599026ea1..bba071e80c0e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
@@ -312,10 +312,10 @@ private[streaming] object FileBasedWriteAheadLog {
       handler: I => Iterator[O]): Iterator[O] = {
     val taskSupport = new ExecutionContextTaskSupport(executionContext)
     val groupSize = taskSupport.parallelismLevel.max(8)
+    implicit val ec = executionContext
+
     source.grouped(groupSize).flatMap { group =>
-      val parallelCollection = group.par
-      parallelCollection.tasksupport = taskSupport
-      parallelCollection.map(handler)
+      ThreadUtils.parmap(group)(handler)
     }.flatten
   }
 }