Skip to content

Commit a3aa465

Browse files
author
Andrew Or
committed
Add more tests for individual closure cleaner operations
1 parent e672170 commit a3aa465

File tree

2 files changed

+256
-8
lines changed

2 files changed

+256
-8
lines changed

core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import org.apache.spark.{Logging, SparkEnv, SparkException}
3232
private[spark] object ClosureCleaner extends Logging {
3333

3434
// Get an ASM class reader for a given class from the JAR that loaded it
35-
def getClassReader(cls: Class[_]): ClassReader = {
35+
private[util] def getClassReader(cls: Class[_]): ClassReader = {
3636
// Copy data over, before delegating to ClassReader - else we can run out of open file handles.
3737
val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
3838
val resourceStream = cls.getResourceAsStream(className)
@@ -45,7 +45,7 @@ private[spark] object ClosureCleaner extends Logging {
4545
}
4646

4747
// Check whether a class represents a Scala closure
48-
private def isClosure(cls: Class[_]): Boolean = {
48+
private[util] def isClosure(cls: Class[_]): Boolean = {
4949
cls.getName.contains("$anonfun$")
5050
}
5151

@@ -55,10 +55,11 @@ private[spark] object ClosureCleaner extends Logging {
5555
// for outer objects beyond that because cloning the user's object is probably
5656
// not a good idea (whereas we can clone closure objects just fine since we
5757
// understand how all their fields are used).
58-
private def getOuterClasses(obj: AnyRef): List[Class[_]] = {
58+
private[util] def getOuterClasses(obj: AnyRef): List[Class[_]] = {
5959
for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
6060
f.setAccessible(true)
6161
val outer = f.get(obj)
62+
// The outer pointer may be null if we have cleaned this closure before
6263
if (outer != null) {
6364
if (isClosure(f.getType)) {
6465
return f.getType :: getOuterClasses(f.get(obj))
@@ -71,10 +72,11 @@ private[spark] object ClosureCleaner extends Logging {
7172
}
7273

7374
// Get a list of the outer objects for a given closure object.
74-
private def getOuterObjects(obj: AnyRef): List[AnyRef] = {
75+
private[util] def getOuterObjects(obj: AnyRef): List[AnyRef] = {
7576
for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
7677
f.setAccessible(true)
7778
val outer = f.get(obj)
79+
// The outer pointer may be null if we have cleaned this closure before
7880
if (outer != null) {
7981
if (isClosure(f.getType)) {
8082
return f.get(obj) :: getOuterObjects(f.get(obj))
@@ -89,7 +91,7 @@ private[spark] object ClosureCleaner extends Logging {
8991
/**
9092
* Return a list of classes that represent closures enclosed in the given closure object.
9193
*/
92-
private def getInnerClasses(obj: AnyRef): List[Class[_]] = {
94+
private[util] def getInnerClasses(obj: AnyRef): List[Class[_]] = {
9395
val seen = Set[Class[_]](obj.getClass)
9496
var stack = List[Class[_]](obj.getClass)
9597
while (!stack.isEmpty) {
@@ -372,7 +374,7 @@ private case class MethodIdentifier(cls: Class[_], name: String, desc: String)
372374
* @param specificMethod if not empty, visit only this method
373375
* @param visitedMethods a list of visited methods to avoid cycles
374376
*/
375-
private class FieldAccessFinder(
377+
private[util] class FieldAccessFinder(
376378
fields: Map[Class[_], Set[String]],
377379
findTransitively: Boolean,
378380
specificMethod: Option[MethodIdentifier] = None,
@@ -387,8 +389,8 @@ private class FieldAccessFinder(
387389
exceptions: Array[String]): MethodVisitor = {
388390

389391
// Ignore this method unless we are told to visit it
390-
if (specificMethod.nonEmpty &&
391-
specificMethod.get.name != name || specificMethod.get.desc != desc) {
392+
if (specificMethod.isDefined &&
393+
(specificMethod.get.name != name || specificMethod.get.desc != desc)) {
392394
return null
393395
}
394396

core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ package org.apache.spark.util
1919

2020
import java.io.NotSerializableException
2121

22+
import scala.collection.mutable
23+
2224
import org.scalatest.{BeforeAndAfterAll, FunSuite}
2325

2426
import org.apache.spark.{SparkContext, SparkException}
@@ -93,6 +95,250 @@ class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll {
9395
assertSerializable(closure, serializableAfter)
9496
}
9597

98+
/**
99+
* Return the fields accessed by the given closure by class.
100+
* This also optionally finds the fields transitively referenced through methods
101+
* that belong to other classes.
102+
*/
103+
private def findAccessedFields(
104+
closure: AnyRef,
105+
outerClasses: Seq[Class[_]],
106+
findTransitively: Boolean): Map[Class[_], Set[String]] = {
107+
val fields = new mutable.HashMap[Class[_], mutable.Set[String]]
108+
outerClasses.foreach { c => fields(c) = new mutable.HashSet[String] }
109+
ClosureCleaner.getClassReader(closure.getClass)
110+
.accept(new FieldAccessFinder(fields, findTransitively), 0)
111+
fields.mapValues(_.toSet).toMap
112+
}
113+
114+
test("get inner classes") {
115+
val closure1 = () => 1
116+
val closure2 = () => { () => 1 }
117+
val closure3 = (i: Int) => {
118+
(1 to i).map { x => x + 1 }.filter { x => x > 5 }
119+
}
120+
val closure4 = (j: Int) => {
121+
(1 to j).flatMap { x =>
122+
(1 to x).flatMap { y =>
123+
(1 to y).map { z => z + 1 }
124+
}
125+
}
126+
}
127+
val inner1 = ClosureCleaner.getInnerClasses(closure1)
128+
val inner2 = ClosureCleaner.getInnerClasses(closure2)
129+
val inner3 = ClosureCleaner.getInnerClasses(closure3)
130+
val inner4 = ClosureCleaner.getInnerClasses(closure4)
131+
assert(inner1.isEmpty)
132+
assert(inner2.size === 1)
133+
assert(inner3.size === 2)
134+
assert(inner4.size === 3)
135+
assert(inner2.forall(ClosureCleaner.isClosure))
136+
assert(inner3.forall(ClosureCleaner.isClosure))
137+
assert(inner4.forall(ClosureCleaner.isClosure))
138+
}
139+
140+
test("get outer classes and objects") {
141+
val localValue = someSerializableValue
142+
val closure1 = () => 1
143+
val closure2 = () => localValue
144+
val closure3 = () => someSerializableValue
145+
val closure4 = () => someSerializableMethod()
146+
val outerClasses1 = ClosureCleaner.getOuterClasses(closure1)
147+
val outerClasses2 = ClosureCleaner.getOuterClasses(closure2)
148+
val outerClasses3 = ClosureCleaner.getOuterClasses(closure3)
149+
val outerClasses4 = ClosureCleaner.getOuterClasses(closure4)
150+
val outerObjects1 = ClosureCleaner.getOuterObjects(closure1)
151+
val outerObjects2 = ClosureCleaner.getOuterObjects(closure2)
152+
val outerObjects3 = ClosureCleaner.getOuterObjects(closure3)
153+
val outerObjects4 = ClosureCleaner.getOuterObjects(closure4)
154+
155+
// The classes and objects should have the same size
156+
assert(outerClasses1.size === outerObjects1.size)
157+
assert(outerClasses2.size === outerObjects2.size)
158+
assert(outerClasses3.size === outerObjects3.size)
159+
assert(outerClasses4.size === outerObjects4.size)
160+
161+
// These do not have $outer pointers because they reference only local variables
162+
assert(outerClasses1.isEmpty)
163+
assert(outerClasses2.isEmpty)
164+
165+
// These closures do have $outer pointers because they ultimately reference `this`
166+
// The first $outer pointer refers to the closure that defines this test (see FunSuite#test)
167+
// The second $outer pointer refers to ClosureCleanerSuite2
168+
assert(outerClasses3.size === 2)
169+
assert(outerClasses4.size === 2)
170+
assert(ClosureCleaner.isClosure(outerClasses3(0)))
171+
assert(ClosureCleaner.isClosure(outerClasses4(0)))
172+
assert(outerClasses3(0) === outerClasses4(0)) // part of the same "FunSuite#test" scope
173+
assert(outerClasses3(1) === this.getClass)
174+
assert(outerClasses4(1) === this.getClass)
175+
assert(outerObjects3(1) === this)
176+
assert(outerObjects4(1) === this)
177+
}
178+
179+
test("get outer classes and objects with nesting") {
180+
val localValue = someSerializableValue
181+
182+
val test1 = () => {
183+
val x = 1
184+
val closure1 = () => 1
185+
val closure2 = () => x
186+
val outerClasses1 = ClosureCleaner.getOuterClasses(closure1)
187+
val outerClasses2 = ClosureCleaner.getOuterClasses(closure2)
188+
val outerObjects1 = ClosureCleaner.getOuterObjects(closure1)
189+
val outerObjects2 = ClosureCleaner.getOuterObjects(closure2)
190+
assert(outerClasses1.size === outerObjects1.size)
191+
assert(outerClasses2.size === outerObjects2.size)
192+
// These inner closures only reference local variables, and so do not have $outer pointers
193+
assert(outerClasses1.isEmpty)
194+
assert(outerClasses2.isEmpty)
195+
}
196+
197+
val test2 = () => {
198+
def y = 1
199+
val closure1 = () => 1
200+
val closure2 = () => y
201+
val closure3 = () => localValue
202+
val outerClasses1 = ClosureCleaner.getOuterClasses(closure1)
203+
val outerClasses2 = ClosureCleaner.getOuterClasses(closure2)
204+
val outerClasses3 = ClosureCleaner.getOuterClasses(closure3)
205+
val outerObjects1 = ClosureCleaner.getOuterObjects(closure1)
206+
val outerObjects2 = ClosureCleaner.getOuterObjects(closure2)
207+
val outerObjects3 = ClosureCleaner.getOuterObjects(closure3)
208+
assert(outerClasses1.size === outerObjects1.size)
209+
assert(outerClasses2.size === outerObjects2.size)
210+
assert(outerClasses3.size === outerObjects3.size)
211+
// Same as above, this closure only references local variables
212+
assert(outerClasses1.isEmpty)
213+
// This closure references the "test2" scope because it needs to find the method `y`
214+
// Scope hierarchy: "test2" < "FunSuite#test" < ClosureCleanerSuite2
215+
assert(outerClasses2.size === 3)
216+
// This closure references the "test2" scope because it needs to find the
217+
// `localValue` defined outside of this scope
218+
assert(outerClasses3.size === 3)
219+
assert(ClosureCleaner.isClosure(outerClasses2(0)))
220+
assert(ClosureCleaner.isClosure(outerClasses3(0)))
221+
assert(ClosureCleaner.isClosure(outerClasses2(1)))
222+
assert(ClosureCleaner.isClosure(outerClasses3(1)))
223+
assert(outerClasses2(0) === outerClasses3(0)) // part of the same "test2" scope
224+
assert(outerClasses2(1) === outerClasses3(1)) // part of the same "FunSuite#test" scope
225+
assert(outerClasses2(2) === this.getClass)
226+
assert(outerClasses3(2) === this.getClass)
227+
assert(outerObjects2(2) === this)
228+
assert(outerObjects3(2) === this)
229+
}
230+
231+
test1()
232+
test2()
233+
}
234+
235+
test("find accessed fields") {
236+
val localValue = someSerializableValue
237+
val closure1 = () => 1
238+
val closure2 = () => localValue
239+
val closure3 = () => someSerializableValue
240+
val outerClasses1 = ClosureCleaner.getOuterClasses(closure1)
241+
val outerClasses2 = ClosureCleaner.getOuterClasses(closure2)
242+
val outerClasses3 = ClosureCleaner.getOuterClasses(closure3)
243+
244+
val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false)
245+
val fields2 = findAccessedFields(closure2, outerClasses2, findTransitively = false)
246+
val fields3 = findAccessedFields(closure3, outerClasses3, findTransitively = false)
247+
assert(fields1.isEmpty)
248+
assert(fields2.isEmpty)
249+
assert(fields3.size === 2)
250+
// This corresponds to the "FunSuite#test" closure. This is empty because the
251+
// field that `closure3` references belongs to its parent (i.e. ClosureCleanerSuite2)
252+
assert(fields3(outerClasses3(0)).isEmpty)
253+
// This corresponds to the ClosureCleanerSuite2. This is also empty, however,
254+
// because we did not find fields transitively (i.e. beyond 1 enclosing scope)
255+
assert(fields3(outerClasses3(1)).isEmpty)
256+
257+
val fields1t = findAccessedFields(closure1, outerClasses1, findTransitively = true)
258+
val fields2t = findAccessedFields(closure2, outerClasses2, findTransitively = true)
259+
val fields3t = findAccessedFields(closure3, outerClasses3, findTransitively = true)
260+
assert(fields1t.isEmpty)
261+
assert(fields2t.isEmpty)
262+
assert(fields3t.size === 2)
263+
// Because we find fields transitively now, we are able to detect that we need the
264+
// $outer pointer to get the field from the ClosureCleanerSuite2.
265+
assert(fields3t(outerClasses3(0)).size === 1)
266+
assert(fields3t(outerClasses3(0)).head === "$outer")
267+
assert(fields3t(outerClasses3(1)).size === 1)
268+
assert(fields3t(outerClasses3(1)).head.contains("someSerializableValue"))
269+
}
270+
271+
test("find accessed fields with nesting") {
272+
val localValue = someSerializableValue
273+
274+
val test1 = () => {
275+
def a = localValue + 1
276+
val closure1 = () => 1
277+
val closure2 = () => a
278+
val closure3 = () => localValue
279+
val closure4 = () => someSerializableValue
280+
val outerClasses1 = ClosureCleaner.getOuterClasses(closure1)
281+
val outerClasses2 = ClosureCleaner.getOuterClasses(closure2)
282+
val outerClasses3 = ClosureCleaner.getOuterClasses(closure3)
283+
val outerClasses4 = ClosureCleaner.getOuterClasses(closure4)
284+
285+
// First, find only fields the closures directly access
286+
val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false)
287+
val fields2 = findAccessedFields(closure2, outerClasses2, findTransitively = false)
288+
val fields3 = findAccessedFields(closure3, outerClasses3, findTransitively = false)
289+
val fields4 = findAccessedFields(closure4, outerClasses4, findTransitively = false)
290+
assert(fields1.isEmpty)
291+
// "test1" < "FunSuite#test" < ClosureCleanerSuite2
292+
assert(fields2.size === 3)
293+
assert(fields2(outerClasses2(0)).isEmpty) // `def a` is not a field
294+
assert(fields2(outerClasses2(1)).isEmpty)
295+
assert(fields2(outerClasses2(2)).isEmpty)
296+
assert(fields3.size === 3)
297+
// Note that `localValue` is a field of the "test1" closure because `def a` needs it
298+
// Further note that it is NOT a field of the "FunSuite#test" closure but a local variable
299+
assert(fields3(outerClasses3(0)).size === 1)
300+
assert(fields3(outerClasses3(0)).head.contains("localValue"))
301+
assert(fields3(outerClasses3(1)).isEmpty)
302+
assert(fields3(outerClasses3(2)).isEmpty)
303+
assert(fields4.size === 3)
304+
assert(fields4(outerClasses4(0)).isEmpty)
305+
assert(fields4(outerClasses4(1)).isEmpty)
306+
// Because `someSerializableValue` is a val, even an explicit reference here actually
307+
// involves a method call to access the underlying value of the variable. Because we are
308+
// not finding fields transitively here, we do not consider the fields accessed by this
309+
// "method" (i.e. the val's accessor).
310+
assert(fields4(outerClasses4(2)).isEmpty)
311+
312+
// Now do the same, but find fields that the closures transitively reference
313+
val fields1t = findAccessedFields(closure1, outerClasses1, findTransitively = true)
314+
val fields2t = findAccessedFields(closure2, outerClasses2, findTransitively = true)
315+
val fields3t = findAccessedFields(closure3, outerClasses3, findTransitively = true)
316+
val fields4t = findAccessedFields(closure4, outerClasses4, findTransitively = true)
317+
assert(fields1t.isEmpty)
318+
assert(fields2t.size === 3)
319+
// This closure transitively references `localValue` because `def a` uses it
320+
assert(fields2t(outerClasses2(0)).size === 1)
321+
assert(fields2t(outerClasses2(0)).head.contains("localValue"))
322+
assert(fields2t(outerClasses2(1)).isEmpty)
323+
assert(fields2t(outerClasses2(2)).isEmpty)
324+
assert(fields3t.size === 3)
325+
assert(fields3t(outerClasses3(0)).size === 1) // as before
326+
assert(fields3t(outerClasses3(0)).head.contains("localValue"))
327+
assert(fields3t(outerClasses3(1)).isEmpty)
328+
assert(fields3t(outerClasses3(2)).isEmpty)
329+
assert(fields4t.size === 3)
330+
// Through a series of method calls, we are able to detect that we ultimately access
331+
// ClosureCleanerSuite2's field `someSerializableValue`. Along the way, we also accessed
332+
// a few $outer parent pointers to get to the outermost object.
333+
assert(fields4t(outerClasses4(0)) === Set("$outer"))
334+
assert(fields4t(outerClasses4(1)) === Set("$outer"))
335+
assert(fields4t(outerClasses4(2)).size === 1)
336+
assert(fields4t(outerClasses4(2)).head.contains("someSerializableValue"))
337+
}
338+
339+
test1()
340+
}
341+
96342
test("clean basic serializable closures") {
97343
val localSerializableVal = someSerializableValue
98344
val closure1 = () => 1

0 commit comments

Comments
 (0)