Commit 54e6fa0

tedyu authored and Andrew Or committed
[SPARK-7237] Clean function in several RDD methods
Author: tedyu <[email protected]>

Closes #5959 from ted-yu/master and squashes the following commits:

f83d445 [tedyu] Move cleaning outside of mapPartitionsWithIndex
56d7c92 [tedyu] Consolidate import of Random
f6014c0 [tedyu] Remove cleaning in RDD#filterWith
36feb6c [tedyu] Try to get correct syntax
55d01eb [tedyu] Try to get correct syntax
c2786df [tedyu] Correct syntax
d92bfcf [tedyu] Correct syntax in test
164d3e4 [tedyu] Correct variable name
8b50d93 [tedyu] Address Andrew's review comments
0c8d47e [tedyu] Add test for mapWith()
6846e40 [tedyu] Add test for flatMapWith()
6c124a9 [tedyu] Clean function in several RDD methods
1 parent: bd61f07 · commit: 54e6fa0
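For background: before serializing a user closure and shipping it to executors, Spark runs it through SparkContext#clean (backed by ClosureCleaner), which nulls out enclosing-scope references the closure never uses and fails fast on the driver if the result still is not serializable. This commit makes the deprecated *With operators and collect apply that cleaning to the functions their callers supply. The following is a minimal, Spark-free sketch of the capture problem being guarded against; the `Enclosing` class and `CaptureDemo` object are made up for illustration:

```scala
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

object CaptureDemo {
  // `Enclosing` is hypothetical and deliberately not Serializable.
  class Enclosing {
    val offset = 42
    // Reading `offset` compiles to `Enclosing.this.offset`, so the lambda
    // captures the whole non-serializable instance, not just the Int.
    val f: Int => Int = x => x + offset
  }

  def main(args: Array[String]): Unit = {
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    try out.writeObject(new Enclosing().f)
    catch {
      // This is what a Spark task would hit at serialization time if the
      // closure were shipped without being cleaned first.
      case e: NotSerializableException => println(s"capture failed: $e")
    }
  }
}
```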

2 files changed: 41 additions (+41), 10 deletions (−10)

core/src/main/scala/org/apache/spark/rdd/RDD.scala (20 additions, 10 deletions)

```diff
@@ -717,7 +717,8 @@ abstract class RDD[T: ClassTag](
   def mapPartitionsWithContext[U: ClassTag](
       f: (TaskContext, Iterator[T]) => Iterator[U],
       preservesPartitioning: Boolean = false): RDD[U] = withScope {
-    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter)
+    val cleanF = sc.clean(f)
+    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => cleanF(context, iter)
     new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
   }
 
@@ -741,9 +742,11 @@ abstract class RDD[T: ClassTag](
   def mapWith[A, U: ClassTag]
       (constructA: Int => A, preservesPartitioning: Boolean = false)
       (f: (T, A) => U): RDD[U] = withScope {
+    val cleanF = sc.clean(f)
+    val cleanA = sc.clean(constructA)
     mapPartitionsWithIndex((index, iter) => {
-      val a = constructA(index)
-      iter.map(t => f(t, a))
+      val a = cleanA(index)
+      iter.map(t => cleanF(t, a))
     }, preservesPartitioning)
   }
 
@@ -756,9 +759,11 @@ abstract class RDD[T: ClassTag](
   def flatMapWith[A, U: ClassTag]
       (constructA: Int => A, preservesPartitioning: Boolean = false)
       (f: (T, A) => Seq[U]): RDD[U] = withScope {
+    val cleanF = sc.clean(f)
+    val cleanA = sc.clean(constructA)
     mapPartitionsWithIndex((index, iter) => {
-      val a = constructA(index)
-      iter.flatMap(t => f(t, a))
+      val a = cleanA(index)
+      iter.flatMap(t => cleanF(t, a))
     }, preservesPartitioning)
   }
 
@@ -769,9 +774,11 @@ abstract class RDD[T: ClassTag](
    */
   @deprecated("use mapPartitionsWithIndex and foreach", "1.0.0")
   def foreachWith[A](constructA: Int => A)(f: (T, A) => Unit): Unit = withScope {
+    val cleanF = sc.clean(f)
+    val cleanA = sc.clean(constructA)
     mapPartitionsWithIndex { (index, iter) =>
-      val a = constructA(index)
-      iter.map(t => {f(t, a); t})
+      val a = cleanA(index)
+      iter.map(t => {cleanF(t, a); t})
     }
   }
 
@@ -782,9 +789,11 @@ abstract class RDD[T: ClassTag](
    */
   @deprecated("use mapPartitionsWithIndex and filter", "1.0.0")
   def filterWith[A](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = withScope {
+    val cleanP = sc.clean(p)
+    val cleanA = sc.clean(constructA)
     mapPartitionsWithIndex((index, iter) => {
-      val a = constructA(index)
-      iter.filter(t => p(t, a))
+      val a = cleanA(index)
+      iter.filter(t => cleanP(t, a))
     }, preservesPartitioning = true)
   }
 
@@ -901,7 +910,8 @@ abstract class RDD[T: ClassTag](
    * Return an RDD that contains all matching values by applying `f`.
    */
   def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = withScope {
-    filter(f.isDefinedAt).map(f)
+    val cleanF = sc.clean(f)
+    filter(cleanF.isDefinedAt).map(cleanF)
   }
 
   /**
```
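For reference, a usage sketch of the deprecated mapWith operator touched above: constructA builds one value per partition from the partition index, and the curried f combines it with each element; after this change both closures pass through sc.clean before they are captured. The `MapWithExample` object, `seededNoise` method, and partition count below are assumptions made up for the example:

```scala
import java.util.Random
import org.apache.spark.SparkContext

object MapWithExample {
  // One deterministic PRNG per partition, seeded by the partition index;
  // this mirrors the `new Random(index + 42)` pattern in the tests below.
  def seededNoise(sc: SparkContext): Array[Double] = {
    sc.parallelize(1 to 1000, numSlices = 4)
      .mapWith((index: Int) => new Random(index + 42))((x, prng) => x + prng.nextDouble())
      .collect()
  }
}
```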

core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala (21 additions, 0 deletions)

```diff
@@ -18,6 +18,7 @@
 package org.apache.spark.util
 
 import java.io.NotSerializableException
+import java.util.Random
 
 import org.scalatest.FunSuite
 
@@ -92,6 +93,11 @@ class ClosureCleanerSuite extends FunSuite {
     expectCorrectException { TestUserClosuresActuallyCleaned.testKeyBy(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitions(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithIndex(rdd) }
+    expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithContext(rdd) }
+    expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMapWith(rdd) }
+    expectCorrectException { TestUserClosuresActuallyCleaned.testFilterWith(rdd) }
+    expectCorrectException { TestUserClosuresActuallyCleaned.testForEachWith(rdd) }
+    expectCorrectException { TestUserClosuresActuallyCleaned.testMapWith(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions2(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions3(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions4(rdd) }
@@ -260,6 +266,21 @@ private object TestUserClosuresActuallyCleaned {
   def testMapPartitionsWithIndex(rdd: RDD[Int]): Unit = {
     rdd.mapPartitionsWithIndex { (_, it) => return; it }.count()
   }
+  def testFlatMapWith(rdd: RDD[Int]): Unit = {
+    rdd.flatMapWith ((index: Int) => new Random(index + 42)){ (_, it) => return; Seq() }.count()
+  }
+  def testMapWith(rdd: RDD[Int]): Unit = {
+    rdd.mapWith ((index: Int) => new Random(index + 42)){ (_, it) => return; 0 }.count()
+  }
+  def testFilterWith(rdd: RDD[Int]): Unit = {
+    rdd.filterWith ((index: Int) => new Random(index + 42)){ (_, it) => return; true }.count()
+  }
+  def testForEachWith(rdd: RDD[Int]): Unit = {
+    rdd.foreachWith ((index: Int) => new Random(index + 42)){ (_, it) => return }
+  }
+  def testMapPartitionsWithContext(rdd: RDD[Int]): Unit = {
+    rdd.mapPartitionsWithContext { (_, it) => return; it }.count()
+  }
   def testZipPartitions2(rdd: RDD[Int]): Unit = {
     rdd.zipPartitions(rdd) { case (it1, it2) => return; it1 }.count()
   }
```
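Each probe above embeds a bare `return` in the user closure: ClosureCleaner rejects such closures with ReturnStatementInClosureException, so seeing that exception proves the operator under test really passed the closure to sc.clean, while an uncleaned closure would instead blow up later, at serialization or execution time. The expectCorrectException helper is defined elsewhere in the suite and is not part of this diff; the sketch below is a hypothetical reconstruction of such a check, not the actual helper:

```scala
package org.apache.spark.util // ReturnStatementInClosureException is private[spark]

import java.io.NotSerializableException

import org.scalatest.FunSuite

class CleaningProbeSketch extends FunSuite {
  // Hypothetical stand-in for the suite's real expectCorrectException helper.
  private def expectCorrectException(body: => Unit): Unit = {
    try {
      body
      fail("the closure cleaner should have rejected the bare `return`")
    } catch {
      case _: ReturnStatementInClosureException => // expected: the closure was cleaned
      case e: NotSerializableException =>
        fail(s"closure reached the serializer without being cleaned: $e")
    }
  }
}
```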
