@@ -19,6 +19,8 @@ package org.apache.spark.util
1919
2020import java .io .NotSerializableException
2121
22+ import scala .collection .mutable
23+
2224import org .scalatest .{BeforeAndAfterAll , FunSuite }
2325
2426import org .apache .spark .{SparkContext , SparkException }
@@ -93,6 +95,250 @@ class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll {
9395 assertSerializable(closure, serializableAfter)
9496 }
9597
98+ /**
99+ * Return the fields accessed by the given closure by class.
100+ * This also optionally finds the fields transitively referenced through methods
101+ * that belong to other classes.
102+ */
103+ private def findAccessedFields (
104+ closure : AnyRef ,
105+ outerClasses : Seq [Class [_]],
106+ findTransitively : Boolean ): Map [Class [_], Set [String ]] = {
107+ val fields = new mutable.HashMap [Class [_], mutable.Set [String ]]
108+ outerClasses.foreach { c => fields(c) = new mutable.HashSet [String ] }
109+ ClosureCleaner .getClassReader(closure.getClass)
110+ .accept(new FieldAccessFinder (fields, findTransitively), 0 )
111+ fields.mapValues(_.toSet).toMap
112+ }
113+
114+ test(" get inner classes" ) {
115+ val closure1 = () => 1
116+ val closure2 = () => { () => 1 }
117+ val closure3 = (i : Int ) => {
118+ (1 to i).map { x => x + 1 }.filter { x => x > 5 }
119+ }
120+ val closure4 = (j : Int ) => {
121+ (1 to j).flatMap { x =>
122+ (1 to x).flatMap { y =>
123+ (1 to y).map { z => z + 1 }
124+ }
125+ }
126+ }
127+ val inner1 = ClosureCleaner .getInnerClasses(closure1)
128+ val inner2 = ClosureCleaner .getInnerClasses(closure2)
129+ val inner3 = ClosureCleaner .getInnerClasses(closure3)
130+ val inner4 = ClosureCleaner .getInnerClasses(closure4)
131+ assert(inner1.isEmpty)
132+ assert(inner2.size === 1 )
133+ assert(inner3.size === 2 )
134+ assert(inner4.size === 3 )
135+ assert(inner2.forall(ClosureCleaner .isClosure))
136+ assert(inner3.forall(ClosureCleaner .isClosure))
137+ assert(inner4.forall(ClosureCleaner .isClosure))
138+ }
139+
140+ test(" get outer classes and objects" ) {
141+ val localValue = someSerializableValue
142+ val closure1 = () => 1
143+ val closure2 = () => localValue
144+ val closure3 = () => someSerializableValue
145+ val closure4 = () => someSerializableMethod()
146+ val outerClasses1 = ClosureCleaner .getOuterClasses(closure1)
147+ val outerClasses2 = ClosureCleaner .getOuterClasses(closure2)
148+ val outerClasses3 = ClosureCleaner .getOuterClasses(closure3)
149+ val outerClasses4 = ClosureCleaner .getOuterClasses(closure4)
150+ val outerObjects1 = ClosureCleaner .getOuterObjects(closure1)
151+ val outerObjects2 = ClosureCleaner .getOuterObjects(closure2)
152+ val outerObjects3 = ClosureCleaner .getOuterObjects(closure3)
153+ val outerObjects4 = ClosureCleaner .getOuterObjects(closure4)
154+
155+ // The classes and objects should have the same size
156+ assert(outerClasses1.size === outerObjects1.size)
157+ assert(outerClasses2.size === outerObjects2.size)
158+ assert(outerClasses3.size === outerObjects3.size)
159+ assert(outerClasses4.size === outerObjects4.size)
160+
161+ // These do not have $outer pointers because they reference only local variables
162+ assert(outerClasses1.isEmpty)
163+ assert(outerClasses2.isEmpty)
164+
165+ // These closures do have $outer pointers because they ultimately reference `this`
166+ // The first $outer pointer refers to the closure defines this test (see FunSuite#test)
167+ // The second $outer pointer refers to ClosureCleanerSuite2
168+ assert(outerClasses3.size === 2 )
169+ assert(outerClasses4.size === 2 )
170+ assert(ClosureCleaner .isClosure(outerClasses3(0 )))
171+ assert(ClosureCleaner .isClosure(outerClasses4(0 )))
172+ assert(outerClasses3(0 ) === outerClasses4(0 )) // part of the same "FunSuite#test" scope
173+ assert(outerClasses3(1 ) === this .getClass)
174+ assert(outerClasses4(1 ) === this .getClass)
175+ assert(outerObjects3(1 ) === this )
176+ assert(outerObjects4(1 ) === this )
177+ }
178+
179+ test(" get outer classes and objects with nesting" ) {
180+ val localValue = someSerializableValue
181+
182+ val test1 = () => {
183+ val x = 1
184+ val closure1 = () => 1
185+ val closure2 = () => x
186+ val outerClasses1 = ClosureCleaner .getOuterClasses(closure1)
187+ val outerClasses2 = ClosureCleaner .getOuterClasses(closure2)
188+ val outerObjects1 = ClosureCleaner .getOuterObjects(closure1)
189+ val outerObjects2 = ClosureCleaner .getOuterObjects(closure2)
190+ assert(outerClasses1.size === outerObjects1.size)
191+ assert(outerClasses2.size === outerObjects2.size)
192+ // These inner closures only reference local variables, and so do not have $outer pointer
193+ assert(outerClasses1.isEmpty)
194+ assert(outerClasses2.isEmpty)
195+ }
196+
197+ val test2 = () => {
198+ def y = 1
199+ val closure1 = () => 1
200+ val closure2 = () => y
201+ val closure3 = () => localValue
202+ val outerClasses1 = ClosureCleaner .getOuterClasses(closure1)
203+ val outerClasses2 = ClosureCleaner .getOuterClasses(closure2)
204+ val outerClasses3 = ClosureCleaner .getOuterClasses(closure3)
205+ val outerObjects1 = ClosureCleaner .getOuterObjects(closure1)
206+ val outerObjects2 = ClosureCleaner .getOuterObjects(closure2)
207+ val outerObjects3 = ClosureCleaner .getOuterObjects(closure3)
208+ assert(outerClasses1.size === outerObjects1.size)
209+ assert(outerClasses2.size === outerObjects2.size)
210+ assert(outerClasses3.size === outerObjects3.size)
211+ // Same as above, this closure only references local variables
212+ assert(outerClasses1.isEmpty)
213+ // This closure references the "test2" scope because it needs to find the method `y`
214+ // Scope hierarchy: "test2" < "FunSuite#test" < ClosureCleanerSuite2
215+ assert(outerClasses2.size === 3 )
216+ // This closure references the "test2" scope because it needs to find the
217+ // `localValue` defined outside of this scope
218+ assert(outerClasses3.size === 3 )
219+ assert(ClosureCleaner .isClosure(outerClasses2(0 )))
220+ assert(ClosureCleaner .isClosure(outerClasses3(0 )))
221+ assert(ClosureCleaner .isClosure(outerClasses2(1 )))
222+ assert(ClosureCleaner .isClosure(outerClasses3(1 )))
223+ assert(outerClasses2(0 ) === outerClasses3(0 )) // part of the same "test2" scope
224+ assert(outerClasses2(1 ) === outerClasses3(1 )) // part of the same "FunSuite#test" scope
225+ assert(outerClasses2(2 ) === this .getClass)
226+ assert(outerClasses3(2 ) === this .getClass)
227+ assert(outerObjects2(2 ) === this )
228+ assert(outerObjects3(2 ) === this )
229+ }
230+
231+ test1()
232+ test2()
233+ }
234+
235+ test(" find accessed fields" ) {
236+ val localValue = someSerializableValue
237+ val closure1 = () => 1
238+ val closure2 = () => localValue
239+ val closure3 = () => someSerializableValue
240+ val outerClasses1 = ClosureCleaner .getOuterClasses(closure1)
241+ val outerClasses2 = ClosureCleaner .getOuterClasses(closure2)
242+ val outerClasses3 = ClosureCleaner .getOuterClasses(closure3)
243+
244+ val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false )
245+ val fields2 = findAccessedFields(closure2, outerClasses2, findTransitively = false )
246+ val fields3 = findAccessedFields(closure3, outerClasses3, findTransitively = false )
247+ assert(fields1.isEmpty)
248+ assert(fields2.isEmpty)
249+ assert(fields3.size === 2 )
250+ // This corresponds to the "FunSuite#test" closure. This is empty because the
251+ // field `closure3` references belongs to its parent (i.e. ClosureCleanerSuite2)
252+ assert(fields3(outerClasses3(0 )).isEmpty)
253+ // This corresponds to the ClosureCleanerSuite2. This is also empty, however,
254+ // because we did not find fields transitively (i.e. beyond 1 enclosing scope)
255+ assert(fields3(outerClasses3(1 )).isEmpty)
256+
257+ val fields1t = findAccessedFields(closure1, outerClasses1, findTransitively = true )
258+ val fields2t = findAccessedFields(closure2, outerClasses2, findTransitively = true )
259+ val fields3t = findAccessedFields(closure3, outerClasses3, findTransitively = true )
260+ assert(fields1t.isEmpty)
261+ assert(fields2t.isEmpty)
262+ assert(fields3t.size === 2 )
263+ // Because we find fields transitively now, we are able to detect that we need the
264+ // $outer pointer to get the field from the ClosureCleanerSuite2.
265+ assert(fields3t(outerClasses3(0 )).size === 1 )
266+ assert(fields3t(outerClasses3(0 )).head === " $outer" )
267+ assert(fields3t(outerClasses3(1 )).size === 1 )
268+ assert(fields3t(outerClasses3(1 )).head.contains(" someSerializableValue" ))
269+ }
270+
271+ test(" find accessed fields with nesting" ) {
272+ val localValue = someSerializableValue
273+
274+ val test1 = () => {
275+ def a = localValue + 1
276+ val closure1 = () => 1
277+ val closure2 = () => a
278+ val closure3 = () => localValue
279+ val closure4 = () => someSerializableValue
280+ val outerClasses1 = ClosureCleaner .getOuterClasses(closure1)
281+ val outerClasses2 = ClosureCleaner .getOuterClasses(closure2)
282+ val outerClasses3 = ClosureCleaner .getOuterClasses(closure3)
283+ val outerClasses4 = ClosureCleaner .getOuterClasses(closure4)
284+
285+ // First, find only fields the closures directly access
286+ val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false )
287+ val fields2 = findAccessedFields(closure2, outerClasses2, findTransitively = false )
288+ val fields3 = findAccessedFields(closure3, outerClasses3, findTransitively = false )
289+ val fields4 = findAccessedFields(closure4, outerClasses4, findTransitively = false )
290+ assert(fields1.isEmpty)
291+ // "test1" < "FunSuite#test" < ClosureCleanerSuite2
292+ assert(fields2.size === 3 )
293+ assert(fields2(outerClasses2(0 )).isEmpty) // `def a` is not a field
294+ assert(fields2(outerClasses2(1 )).isEmpty)
295+ assert(fields2(outerClasses2(2 )).isEmpty)
296+ assert(fields3.size === 3 )
297+ // Note that `localValue` is a field of the "test1" closure because `def a` needs it
298+ // Further note that it is NOT a field of the "FunSuite#test" closure but a local variable
299+ assert(fields3(outerClasses3(0 )).size === 1 )
300+ assert(fields3(outerClasses3(0 )).head.contains(" localValue" ))
301+ assert(fields3(outerClasses3(1 )).isEmpty)
302+ assert(fields3(outerClasses3(2 )).isEmpty)
303+ assert(fields4.size === 3 )
304+ assert(fields4(outerClasses4(0 )).isEmpty)
305+ assert(fields4(outerClasses4(1 )).isEmpty)
306+ // Because `someSerializableValue` is a val, even an explicit reference here actually
307+ // involves a method call to access the underlying value of the variable. Because we are
308+ // not finding fields transitively here, we do not consider the fields accessed by this
309+ // "method" (i.e. the val's accessor).
310+ assert(fields4(outerClasses4(2 )).isEmpty)
311+
312+ // Now do the same, but find fields that the closures transitively reference
313+ val fields1t = findAccessedFields(closure1, outerClasses1, findTransitively = true )
314+ val fields2t = findAccessedFields(closure2, outerClasses2, findTransitively = true )
315+ val fields3t = findAccessedFields(closure3, outerClasses3, findTransitively = true )
316+ val fields4t = findAccessedFields(closure4, outerClasses4, findTransitively = true )
317+ assert(fields1t.isEmpty)
318+ assert(fields2t.size === 3 )
319+ // This closure transitively references `localValue` because `def a` uses it
320+ assert(fields2t(outerClasses2(0 )).size === 1 )
321+ assert(fields2t(outerClasses2(0 )).head.contains(" localValue" ))
322+ assert(fields2t(outerClasses2(1 )).isEmpty)
323+ assert(fields2t(outerClasses2(2 )).isEmpty)
324+ assert(fields3t.size === 3 )
325+ assert(fields3t(outerClasses3(0 )).size === 1 ) // as before
326+ assert(fields3t(outerClasses3(0 )).head.contains(" localValue" ))
327+ assert(fields3t(outerClasses3(1 )).isEmpty)
328+ assert(fields3t(outerClasses3(2 )).isEmpty)
329+ assert(fields4t.size === 3 )
330+ // Through a series of method calls, we are able to detect that we ultimately access
331+ // ClosureCleanerSuite2's field `someSerializableValue`. Along the way, we also accessed
332+ // a few $outer parent pointers to get to the outermost object.
333+ assert(fields4t(outerClasses4(0 )) === Set (" $outer" ))
334+ assert(fields4t(outerClasses4(1 )) === Set (" $outer" ))
335+ assert(fields4t(outerClasses4(2 )).size === 1 )
336+ assert(fields4t(outerClasses4(2 )).head.contains(" someSerializableValue" ))
337+ }
338+
339+ test1()
340+ }
341+
96342 test(" clean basic serializable closures" ) {
97343 val localSerializableVal = someSerializableValue
98344 val closure1 = () => 1
0 commit comments