core/src/main/scala/org/apache/spark/rdd/RDD.scala (20 additions, 10 deletions)
@@ -717,7 +717,8 @@ abstract class RDD[T: ClassTag](
   def mapPartitionsWithContext[U: ClassTag](
       f: (TaskContext, Iterator[T]) => Iterator[U],
       preservesPartitioning: Boolean = false): RDD[U] = withScope {
-    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter)
+    val cleanF = sc.clean(f)
+    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => cleanF(context, iter)
     new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
   }

@@ -741,9 +742,11 @@ abstract class RDD[T: ClassTag](
   def mapWith[A, U: ClassTag]
       (constructA: Int => A, preservesPartitioning: Boolean = false)
       (f: (T, A) => U): RDD[U] = withScope {
+    val cleanF = sc.clean(f)
+    val cleanA = sc.clean(constructA)
     mapPartitionsWithIndex((index, iter) => {
-      val a = constructA(index)
-      iter.map(t => f(t, a))
+      val a = cleanA(index)
+      iter.map(t => cleanF(t, a))
Contributor:

My understanding is that mapPartitionsWithIndex cleans, and thus anything in it is cleaned. Maybe that's an incorrect assumption. cc @andrewor14?
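
For reference, a minimal sketch of what mapPartitionsWithIndex does in 1.x (paraphrased from RDD.scala, not the verbatim source), which is the cleaning behavior being assumed here:

    // Sketch: mapPartitionsWithIndex cleans the closure it receives before
    // shipping it to executors.
    def mapPartitionsWithIndex[U: ClassTag](
        f: (Int, Iterator[T]) => Iterator[U],
        preservesPartitioning: Boolean = false): RDD[U] = withScope {
      val cleanedF = sc.clean(f)  // cleans the function object `f` itself
      new MapPartitionsRDD(
        this,
        (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(index, iter),
        preservesPartitioning)
    }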

Author:

Looks like Reynold is correct.
I can update the PR for collect() and undo the change for the other methods touched.

Contributor:

I think collect might've been cleaned in DAGScheduler's runJob. Double check that.

Contributor:

I believe that @andrewor14 added some tests for closure cleaning as part of his recent ClosureCleaner patch; we might check whether that test suite covers these methods.

Contributor:

The test suite does not currently cover these methods because they are deprecated, but maybe we should just add them. (@JoshRosen is referring to ClosureCleanerSuite)

Contributor:

@rxin @ted-yu actually, even though mapPartitionsWithIndex does clean already, it cleans the closure it is given but not the closures used inside it. In this case I believe it's actually necessary to clean f here, since it won't be cleaned by mapPartitionsWithIndex. For the same reason I believe we also need to clean constructA, since it's a closure provided by the user.
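
A toy illustration of that distinction (illustrative only; `clean` below is a stand-in for ClosureCleaner.clean, which operates on the single function object it is handed):

    // Illustrative only: cleaning the outer closure does not recurse into the
    // separate user-provided function objects it captures.
    def clean[F <: AnyRef](f: F): F = {
      println(s"cleaning ${f.getClass.getName}")  // stand-in for ClosureCleaner.clean
      f
    }

    def mapWithSketch[T, A, U](constructA: Int => A)(f: (T, A) => U): Unit = {
      val outer = (index: Int, iter: Iterator[T]) => {
        val a = constructA(index)  // `constructA` and `f` are captured as distinct
        iter.map(t => f(t, a))     // function objects; clean(outer) never sees them
      }
      clean(outer)  // only `outer` is cleaned; constructA and f are not
    }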

     }, preservesPartitioning)
   }

@@ -756,9 +759,11 @@ abstract class RDD[T: ClassTag](
   def flatMapWith[A, U: ClassTag]
       (constructA: Int => A, preservesPartitioning: Boolean = false)
       (f: (T, A) => Seq[U]): RDD[U] = withScope {
+    val cleanF = sc.clean(f)
+    val cleanA = sc.clean(constructA)
     mapPartitionsWithIndex((index, iter) => {
-      val a = constructA(index)
-      iter.flatMap(t => f(t, a))
+      val a = cleanA(index)
+      iter.flatMap(t => cleanF(t, a))
     }, preservesPartitioning)
   }

@@ -769,9 +774,11 @@ abstract class RDD[T: ClassTag](
    */
   @deprecated("use mapPartitionsWithIndex and foreach", "1.0.0")
   def foreachWith[A](constructA: Int => A)(f: (T, A) => Unit): Unit = withScope {
+    val cleanF = sc.clean(f)
+    val cleanA = sc.clean(constructA)
     mapPartitionsWithIndex { (index, iter) =>
-      val a = constructA(index)
-      iter.map(t => {f(t, a); t})
+      val a = cleanA(index)
+      iter.map(t => {cleanF(t, a); t})
     }
   }

@@ -782,9 +789,11 @@ abstract class RDD[T: ClassTag](
    */
   @deprecated("use mapPartitionsWithIndex and filter", "1.0.0")
   def filterWith[A](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = withScope {
+    val cleanP = sc.clean(p)
+    val cleanA = sc.clean(constructA)
     mapPartitionsWithIndex((index, iter) => {
-      val a = constructA(index)
-      iter.filter(t => p(t, a))
+      val a = cleanA(index)
+      iter.filter(t => cleanP(t, a))
     }, preservesPartitioning = true)
   }

@@ -901,7 +910,8 @@ abstract class RDD[T: ClassTag](
    * Return an RDD that contains all matching values by applying `f`.
    */
   def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = withScope {
-    filter(f.isDefinedAt).map(f)
+    val cleanF = sc.clean(f)
+    filter(cleanF.isDefinedAt).map(cleanF)
Contributor:

I believe this is correct, but I'm actually not 100% sure it's necessary. I think it is, because the filter closure references f indirectly through f.isDefinedAt, so it seems we do need to clean it.

In any case, I would recommend that we keep this change since in the worst case we clean a closure twice, which is harmless (we have tests for this).
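
A small self-contained example of that indirect reference (illustrative Scala only; pf and pred are made-up names): eta-expanding isDefinedAt produces a function value that closes over the PartialFunction itself, which is why the filter predicate drags f along.

    // Illustrative only: `pf.isDefinedAt` eta-expands to a Function1 that
    // captures `pf`, so whatever `pf` references travels with the predicate.
    val pf: PartialFunction[Int, String] = { case n if n % 2 == 0 => s"even $n" }
    val pred: Int => Boolean = pf.isDefinedAt  // closure over `pf`
    assert(pred(2) && !pred(3))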

   }

   /**
core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala

@@ -18,6 +18,7 @@
 package org.apache.spark.util

 import java.io.NotSerializableException
+import java.util.Random

 import org.scalatest.FunSuite

@@ -92,6 +93,11 @@ class ClosureCleanerSuite extends FunSuite {
     expectCorrectException { TestUserClosuresActuallyCleaned.testKeyBy(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitions(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithIndex(rdd) }
+    expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithContext(rdd) }
+    expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMapWith(rdd) }
+    expectCorrectException { TestUserClosuresActuallyCleaned.testFilterWith(rdd) }
+    expectCorrectException { TestUserClosuresActuallyCleaned.testForEachWith(rdd) }
+    expectCorrectException { TestUserClosuresActuallyCleaned.testMapWith(rdd) }
Contributor:

I believe this is missing testForeachWith
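
For context, these tests rely on a simple convention: each test* method passes a closure containing a return statement, and ClosureCleaner throws ReturnStatementInClosureException only if it actually processes that closure, so a passing expectCorrectException proves the method under test cleans its argument. A simplified sketch of the suite's helper (assuming the existing expectCorrectException and ReturnStatementInClosureException; not the verbatim source):

    // Simplified sketch of the suite's helper (not the verbatim source).
    private def expectCorrectException(body: => Unit): Unit = {
      try {
        body
        fail("expected ReturnStatementInClosureException; closure was not cleaned")
      } catch {
        case _: ReturnStatementInClosureException => // closure was actually cleaned
      }
    }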

     expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions2(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions3(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions4(rdd) }

@@ -260,6 +266,21 @@ private object TestUserClosuresActuallyCleaned {
   def testMapPartitionsWithIndex(rdd: RDD[Int]): Unit = {
     rdd.mapPartitionsWithIndex { (_, it) => return; it }.count()
   }
+  def testFlatMapWith(rdd: RDD[Int]): Unit = {
+    rdd.flatMapWith ((index: Int) => new Random(index + 42)){ (_, it) => return; Seq() }.count()
+  }
+  def testMapWith(rdd: RDD[Int]): Unit = {
+    rdd.mapWith ((index: Int) => new Random(index + 42)){ (_, it) => return; 0 }.count()
+  }
+  def testFilterWith(rdd: RDD[Int]): Unit = {
+    rdd.filterWith ((index: Int) => new Random(index + 42)){ (_, it) => return; true }.count()
+  }
+  def testForEachWith(rdd: RDD[Int]): Unit = {
+    rdd.foreachWith ((index: Int) => new Random(index + 42)){ (_, it) => return }
+  }
+  def testMapPartitionsWithContext(rdd: RDD[Int]): Unit = {
+    rdd.mapPartitionsWithContext { (_, it) => return; it }.count()
+  }
   def testZipPartitions2(rdd: RDD[Int]): Unit = {
     rdd.zipPartitions(rdd) { case (it1, it2) => return; it1 }.count()
   }