diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 8bc0ff7936da..8c2ce883093c 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -45,7 +45,8 @@ import org.apache.spark.util._
 private[spark] class PythonRDD(
     parent: RDD[_],
     func: PythonFunction,
-    preservePartitoning: Boolean)
+    preservePartitoning: Boolean,
+    isFromBarrier: Boolean = false)
   extends RDD[Array[Byte]](parent) {

   val bufferSize = conf.getInt("spark.buffer.size", 65536)
@@ -63,6 +64,9 @@ private[spark] class PythonRDD(
     val runner = PythonRunner(func, bufferSize, reuseWorker)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
+
+  @transient protected lazy override val isBarrier_ : Boolean =
+    isFromBarrier || dependencies.exists(_.rdd.isBarrier())
 }

 /**
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
index 978e7c004e5e..b399bf9febae 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
@@ -19,7 +19,6 @@ package org.apache.spark.rdd

 import scala.reflect.ClassTag

-import org.apache.spark.BarrierTaskContext
 import org.apache.spark.TaskContext
 import org.apache.spark.annotation.{Experimental, Since}

diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 951851804b1d..d17a8eb76ad4 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2406,6 +2406,22 @@ def toLocalIterator(self):
         sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
         return _load_from_socket(sock_info, self._jrdd_deserializer)

+    def barrier(self):
+        """
+        .. note:: Experimental
+
+        Indicates that Spark must launch the tasks together for the current stage.
+
+        .. versionadded:: 2.4.0
+        """
+        return RDDBarrier(self)
+
+    def _is_barrier(self):
+        """
+        Whether this RDD is in a barrier stage.
+        """
+        return self._jrdd.rdd().isBarrier()
+

 def _prepare_for_python_RDD(sc, command):
     # the serialized command will be compressed by broadcast
@@ -2429,6 +2445,33 @@ def _wrap_function(sc, func, deserializer, serializer, profiler=None):
                                   sc.pythonVer, broadcast_vars, sc._javaAccumulator)


+class RDDBarrier(object):
+
+    """
+    .. note:: Experimental
+
+    An RDDBarrier turns an RDD into a barrier RDD, which forces Spark to launch tasks of the
+    stage that contains this RDD together.
+
+    .. versionadded:: 2.4.0
+    """
+
+    def __init__(self, rdd):
+        self.rdd = rdd
+
+    def mapPartitions(self, f, preservesPartitioning=False):
+        """
+        .. note:: Experimental
+
+        Return a new RDD by applying a function to each partition of this RDD.
+
+        .. versionadded:: 2.4.0
+        """
+        def func(s, iterator):
+            return f(iterator)
+        return PipelinedRDD(self.rdd, func, preservesPartitioning, isFromBarrier=True)
+
+
 class PipelinedRDD(RDD):

     """
@@ -2448,7 +2491,7 @@ class PipelinedRDD(RDD):
     20
     """

-    def __init__(self, prev, func, preservesPartitioning=False):
+    def __init__(self, prev, func, preservesPartitioning=False, isFromBarrier=False):
         if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
             # This transformation is the first in its stage:
             self.func = func
@@ -2474,6 +2517,7 @@ def pipeline_func(split, iterator):
         self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
         self.partitioner = prev.partitioner if self.preservesPartitioning else None
+        self.is_barrier = prev._is_barrier() or isFromBarrier

     def getNumPartitions(self):
         return self._prev_jrdd.partitions().size()
@@ -2493,7 +2537,7 @@ def _jrdd(self):
         wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer,
                                       self._jrdd_deserializer, profiler)
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func,
-                                             self.preservesPartitioning)
+                                             self.preservesPartitioning, self.is_barrier)
         self._jrdd_val = python_rdd.asJavaRDD()

         if profiler:
@@ -2509,6 +2553,9 @@ def id(self):
     def _is_pipelinable(self):
         return not (self.is_cached or self.is_checkpointed)

+    def _is_barrier(self):
+        return self.is_barrier
+

 def _test():
     import doctest
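To make the new API concrete, here is a minimal usage sketch of the `RDD.barrier()` / `RDDBarrier.mapPartitions()` path this patch adds. It assumes a running local `SparkContext`; the app name and `train_partition` function are illustrative, not part of the patch, and a real cluster must have enough free slots to launch every task of the barrier stage at once.

```python
from pyspark import SparkContext

# Hypothetical demo app; only barrier() and RDDBarrier.mapPartitions() come
# from this patch, everything else is standard PySpark.
sc = SparkContext(appName="barrier-demo")

rdd = sc.parallelize(range(8), numSlices=4)

def train_partition(iterator):
    # All four tasks of this barrier stage are launched together; each task
    # processes one partition's iterator, exactly as in plain mapPartitions().
    return [sum(iterator)]

# barrier() wraps the RDD in an RDDBarrier; mapPartitions() then builds a
# PipelinedRDD with isFromBarrier=True, so the JVM-side PythonRDD is
# constructed with isBarrier_ = true.
result = rdd.barrier().mapPartitions(train_partition).collect()
print(result)  # [1, 5, 9, 13] for partitions [0,1], [2,3], [4,5], [6,7]

sc.stop()
```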
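Continuing the sketch above, a quick check (assumed for illustration, not part of the patch) of the propagation logic: `PipelinedRDD.__init__` stores `prev._is_barrier() or isFromBarrier` in `self.is_barrier`, so further pipelined transformations keep the flag, mirroring how the Scala side ORs `isFromBarrier` with the dependencies' `isBarrier()`.

```python
# _is_barrier() is private, so this is a debugging aid rather than public API.
barrier_rdd = rdd.barrier().mapPartitions(train_partition).map(lambda x: x * 2)
assert barrier_rdd._is_barrier()  # flag survives pipelining through map()
assert not rdd._is_barrier()      # the source RDD itself is unaffected
```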