1818package org .apache .spark .util
1919
2020import java .io .{ByteArrayInputStream , ByteArrayOutputStream }
21- import java .lang .invoke .SerializedLambda
21+ import java .lang .invoke .{ MethodHandleInfo , SerializedLambda }
2222
23+ import scala .collection .JavaConverters ._
2324import scala .collection .mutable .{Map , Set , Stack }
2425
25- import org .apache .xbean .asm7 .{ClassReader , ClassVisitor , MethodVisitor , Type }
26+ import org .apache .commons .lang3 .ClassUtils
27+ import org .apache .xbean .asm7 .{ClassReader , ClassVisitor , Handle , MethodVisitor , Type }
2628import org .apache .xbean .asm7 .Opcodes ._
29+ import org .apache .xbean .asm7 .tree .{ClassNode , MethodNode }
2730
2831import org .apache .spark .{SparkEnv , SparkException }
2932import org .apache .spark .internal .Logging
@@ -159,39 +162,6 @@ private[spark] object ClosureCleaner extends Logging {
159162 clean(closure, checkSerializable, cleanTransitively, Map .empty)
160163 }
161164
162- /**
163- * Try to get a serialized Lambda from the closure.
164- *
165- * @param closure the closure to check.
166- */
167- private def getSerializedLambda (closure : AnyRef ): Option [SerializedLambda ] = {
168- val isClosureCandidate =
169- closure.getClass.isSynthetic &&
170- closure
171- .getClass
172- .getInterfaces.exists(_.getName == " scala.Serializable" )
173-
174- if (isClosureCandidate) {
175- try {
176- Option (inspect(closure))
177- } catch {
178- case e : Exception =>
179- // no need to check if debug is enabled here the Spark
180- // logging api covers this.
181- logDebug(" Closure is not a serialized lambda." , e)
182- None
183- }
184- } else {
185- None
186- }
187- }
188-
189- private def inspect (closure : AnyRef ): SerializedLambda = {
190- val writeReplace = closure.getClass.getDeclaredMethod(" writeReplace" )
191- writeReplace.setAccessible(true )
192- writeReplace.invoke(closure).asInstanceOf [java.lang.invoke.SerializedLambda ]
193- }
194-
195165 /**
196166 * Helper method to clean the given closure in place.
197167 *
@@ -239,12 +209,12 @@ private[spark] object ClosureCleaner extends Logging {
239209 cleanTransitively : Boolean ,
240210 accessedFields : Map [Class [_], Set [String ]]): Unit = {
241211
242- // most likely to be the case with 2.12, 2.13
212+ // indylambda check. Most likely to be the case with 2.12, 2.13
243213 // so we check first
244214 // non LMF-closures should be less frequent from now on
245- val lambdaFunc = getSerializedLambda (func)
215+ val maybeIndylambdaProxy = IndylambdaScalaClosures .getSerializationProxy (func)
246216
247- if (! isClosure(func.getClass) && lambdaFunc .isEmpty) {
217+ if (! isClosure(func.getClass) && maybeIndylambdaProxy .isEmpty) {
248218 logDebug(s " Expected a closure; got ${func.getClass.getName}" )
249219 return
250220 }
@@ -256,7 +226,7 @@ private[spark] object ClosureCleaner extends Logging {
256226 return
257227 }
258228
259- if (lambdaFunc .isEmpty) {
229+ if (maybeIndylambdaProxy .isEmpty) {
260230 logDebug(s " +++ Cleaning closure $func ( ${func.getClass.getName}) +++ " )
261231
262232 // A list of classes that represents closures enclosed in the given one
@@ -372,14 +342,60 @@ private[spark] object ClosureCleaner extends Logging {
372342
373343 logDebug(s " +++ closure $func ( ${func.getClass.getName}) is now cleaned +++ " )
374344 } else {
375- logDebug(s " Cleaning lambda: ${lambdaFunc.get.getImplMethodName}" )
345+ val lambdaProxy = maybeIndylambdaProxy.get
346+ val implMethodName = lambdaProxy.getImplMethodName
347+
348+ logDebug(s " Cleaning indylambda closure: $implMethodName" )
349+
350+ // capturing class is the class that declared this lambda
351+ val capturingClassName = lambdaProxy.getCapturingClass.replace('/' , '.' )
352+ val classLoader = func.getClass.getClassLoader // this is the safest option
353+ // scalastyle:off classforname
354+ val capturingClass = Class .forName(capturingClassName, false , classLoader)
355+ // scalastyle:on classforname
376356
377- val captClass = Utils .classForName(lambdaFunc.get.getCapturingClass.replace('/' , '.' ),
378- initialize = false , noSparkClassLoader = true )
379357 // Fail fast if we detect return statements in closures
380- getClassReader(captClass)
381- .accept(new ReturnStatementFinder (Some (lambdaFunc.get.getImplMethodName)), 0 )
382- logDebug(s " +++ Lambda closure ( ${lambdaFunc.get.getImplMethodName}) is now cleaned +++ " )
358+ val capturingClassReader = getClassReader(capturingClass)
359+ capturingClassReader.accept(new ReturnStatementFinder (Option (implMethodName)), 0 )
360+
361+ val isClosureDeclaredInScalaRepl = capturingClassName.startsWith(" $line" ) &&
362+ capturingClassName.endsWith(" $iw" )
363+ val outerThisOpt = if (lambdaProxy.getCapturedArgCount > 0 ) {
364+ Option (lambdaProxy.getCapturedArg(0 ))
365+ } else {
366+ None
367+ }
368+
369+ // only need to clean when there is an enclosing "this" captured by the closure, and it
370+ // should be something cleanable, i.e. a Scala REPL line object
371+ val needsCleaning = isClosureDeclaredInScalaRepl &&
372+ outerThisOpt.isDefined && outerThisOpt.get.getClass.getName == capturingClassName
373+
374+ if (needsCleaning) {
375+ assert(accessedFields.isEmpty)
376+
377+ initAccessedFields(accessedFields, Seq (capturingClass))
378+ IndylambdaScalaClosures .findAccessedFields(lambdaProxy, classLoader, accessedFields)
379+
380+ logDebug(s " + fields accessed by starting closure: " + accessedFields.size)
381+ accessedFields.foreach { f => logDebug(" " + f) }
382+
383+ if (accessedFields(capturingClass).size < capturingClass.getDeclaredFields.length) {
384+ // clone and clean the enclosing `this` only when there are fields to null out
385+
386+ val outerThis = outerThisOpt.get
387+
388+ logDebug(s " + cloning instance of REPL class $capturingClassName" )
389+ val clonedOuterThis = cloneAndSetFields(
390+ parent = null , outerThis, capturingClass, accessedFields)
391+
392+ val outerField = func.getClass.getDeclaredField(" arg$1" )
393+ outerField.setAccessible(true )
394+ outerField.set(func, clonedOuterThis)
395+ }
396+ }
397+
398+ logDebug(s " +++ indylambda closure ( $implMethodName) is now cleaned +++ " )
383399 }
384400
385401 if (checkSerializable) {
@@ -414,6 +430,139 @@ private[spark] object ClosureCleaner extends Logging {
414430 }
415431}
416432
/**
 * Helpers for detecting and analysing "indylambda" Scala closures, i.e. closures whose
 * classes are generated at runtime by Java's `LambdaMetafactory`.
 */
private[spark] object IndylambdaScalaClosures extends Logging {
  // internal name of java.lang.invoke.LambdaMetafactory
  val LambdaMetafactoryClassName: String = "java/lang/invoke/LambdaMetafactory"
  // the method that Scala indylambda use for bootstrap method
  val LambdaMetafactoryMethodName: String = "altMetafactory"
  val LambdaMetafactoryMethodDesc: String = "(Ljava/lang/invoke/MethodHandles$Lookup;" +
    "Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)" +
    "Ljava/lang/invoke/CallSite;"

  /**
   * Check if the given reference is an indylambda style Scala closure.
   * If so, return a non-empty serialization proxy (SerializedLambda) of the closure;
   * otherwise return None.
   *
   * @param maybeClosure the closure to check.
   * @return `Some(proxy)` when `maybeClosure` is a synthetic class implementing both
   *         `scala.Serializable` and a `scala.Function*` interface and its serialization proxy
   *         passes [[isIndylambdaScalaClosure]]; `None` otherwise, including when obtaining the
   *         proxy throws an `Exception`.
   */
  def getSerializationProxy(maybeClosure: AnyRef): Option[SerializedLambda] = {
    val maybeClosureClass = maybeClosure.getClass

    // fast check first: indylambda closure classes are generated by Java's
    // LambdaMetafactory, and they're always synthetic.
    if (!maybeClosureClass.isSynthetic) {
      None
    } else {
      val implementedInterfaces = ClassUtils.getAllInterfaces(maybeClosureClass).asScala
      val isClosureCandidate =
        implementedInterfaces.exists(_.getName == "scala.Serializable") &&
          implementedInterfaces.exists(_.getName.startsWith("scala.Function"))

      if (isClosureCandidate) {
        try {
          Option(inspect(maybeClosure)).filter(isIndylambdaScalaClosure)
        } catch {
          case e: Exception =>
            // no need to check if debug is enabled here the Spark logging api covers this.
            logDebug("The given reference is not an indylambda Scala closure.", e)
            None
        }
      } else {
        None
      }
    }
  }

  /**
   * Decide whether a serialization proxy describes a Scala indylambda closure: the
   * implementation method must be invoked statically and its name must contain `$anonfun$`.
   *
   * @param lambdaProxy the serialization proxy to examine.
   */
  def isIndylambdaScalaClosure(lambdaProxy: SerializedLambda): Boolean = {
    lambdaProxy.getImplMethodKind == MethodHandleInfo.REF_invokeStatic &&
      lambdaProxy.getImplMethodName.contains("$anonfun$")
      // && implements a scala.runtime.java8 functional interface
  }

  /**
   * Reflectively invoke the `writeReplace` method declared on the closure's class and return
   * its result as a `SerializedLambda`.
   *
   * Reflection failures (e.g. no declared `writeReplace`) and a result of an unexpected type
   * propagate as exceptions; [[getSerializationProxy]] is the caller that handles them.
   *
   * @param closure the object whose serialization proxy is requested.
   */
  def inspect(closure: AnyRef): SerializedLambda = {
    val writeReplace = closure.getClass.getDeclaredMethod("writeReplace")
    writeReplace.setAccessible(true)
    writeReplace.invoke(closure).asInstanceOf[SerializedLambda]
  }

  /**
   * Scan the bytecode of the lambda's implementation method, plus every method of the same
   * class reachable from it through direct calls or nested lambda creation, and record each
   * GETFIELD / PUTFIELD whose owner is one of the classes already present as a key in
   * `accessedFields`.
   *
   * Each method is scanned at most once, so recursive or mutually recursive helpers cannot
   * make the traversal loop forever.
   *
   * @param lambdaProxy serialization proxy of the closure being analysed.
   * @param lambdaClassLoader class loader used to resolve the implementation class.
   * @param accessedFields in/out map from tracked class to the names of its accessed fields;
   *                       only classes that are already keys of the map are updated.
   * @throws NoSuchElementException if the implementation method named by `lambdaProxy` is not
   *                                declared in the implementation class.
   */
  def findAccessedFields(
      lambdaProxy: SerializedLambda,
      lambdaClassLoader: ClassLoader,
      accessedFields: Map[Class[_], Set[String]]): Unit = {
    val implClassInternalName = lambdaProxy.getImplClass
    // scalastyle:off classforname
    val implClass = Class.forName(
      implClassInternalName.replace('/', '.'), false, lambdaClassLoader)
    // scalastyle:on classforname
    val implClassNode = new ClassNode()
    val implClassReader = ClosureCleaner.getClassReader(implClass)
    implClassReader.accept(implClassNode, 0)

    val methodsByName = Map.empty[MethodIdentifier[_], MethodNode]
    for (m <- implClassNode.methods.asScala) {
      methodsByName(MethodIdentifier(implClass, m.name, m.desc)) = m
    }

    val implMethodId = MethodIdentifier(
      implClass, lambdaProxy.getImplMethodName, lambdaProxy.getImplMethodSignature)
    // The lambda body itself must exist; anything else means the proxy and the class disagree.
    if (!methodsByName.contains(implMethodId)) {
      throw new NoSuchElementException(s"key not found: $implMethodId")
    }

    val visited = Set[MethodIdentifier[_]](implMethodId)
    val stack = Stack[MethodIdentifier[_]](implMethodId)

    // `Set.add` returns true only for elements that were not present yet, which guarantees
    // that every method is pushed (and therefore scanned) at most once.
    def pushIfNotVisited(methodId: MethodIdentifier[_]): Unit = {
      if (visited.add(methodId)) {
        stack.push(methodId)
      }
    }

    while (stack.nonEmpty) {
      val currentId = stack.pop()
      methodsByName.get(currentId) match {
        case None =>
          // The call site names the impl class as owner but the method is not declared there
          // (e.g. it is inherited), so there is no bytecode in this class to scan.
          logTrace(s"  skipping $currentId: not declared in ${implClass.getName}")

        case Some(currentMethodNode) =>
          logTrace(s"  scanning $currentId")
          currentMethodNode.accept(new MethodVisitor(ASM7) {
            override def visitFieldInsn(
                op: Int, owner: String, name: String, desc: String): Unit = {
              if (op == GETFIELD || op == PUTFIELD) {
                val ownerExternalName = owner.replace('/', '.')
                for (cl <- accessedFields.keys if cl.getName == ownerExternalName) {
                  logTrace(s"    found field access $name on $owner")
                  accessedFields(cl) += name
                }
              }
            }

            override def visitMethodInsn(
                op: Int, owner: String, name: String, desc: String, itf: Boolean): Unit = {
              if (owner == implClassInternalName) {
                logTrace(s"    found intra class call to $owner.$name$desc")
                pushIfNotVisited(MethodIdentifier(implClass, name, desc))
              } else {
                // keep the same behavior as the original ClosureCleaner
                logTrace(s"    ignoring call to $owner.$name$desc")
              }
            }

            // find the lexically nested closures
            override def visitInvokeDynamicInsn(
                name: String, desc: String, bsmHandle: Handle, bsmArgs: Object*): Unit = {
              logTrace(s"    invokedynamic: $name$desc, bsmHandle=$bsmHandle, bsmArgs=$bsmArgs")

              // fast check: we only care about Scala lambda creation
              val createsScalaFunction = name.startsWith("apply") &&
                Type.getReturnType(desc).getDescriptor.startsWith("Lscala/Function")

              // the bootstrap method used for serializable Java 8 style lambda creation
              val usesLambdaMetafactory = bsmHandle.getOwner == LambdaMetafactoryClassName &&
                bsmHandle.getName == LambdaMetafactoryMethodName &&
                bsmHandle.getDesc == LambdaMetafactoryMethodDesc

              if (createsScalaFunction && usesLambdaMetafactory) {
                // the second bootstrap argument is the handle of the lambda body
                bsmArgs.lift(1) match {
                  case Some(targetHandle: Handle)
                      if targetHandle.getOwner == implClassInternalName &&
                        targetHandle.getDesc.startsWith(s"(L$implClassInternalName;") =>
                    // this is a lexically nested closure that also captures the enclosing `this`
                    logDebug(s"    found inner closure $targetHandle")
                    pushIfNotVisited(
                      MethodIdentifier(implClass, targetHandle.getName, targetHandle.getDesc))
                  case _ =>
                }
              }
            }
          })
      }
    }
  }
}
565+
/**
 * A [[SparkException]] carrying the fixed message that return statements are not allowed in
 * Spark closures.
 *
 * NOTE(review): presumably raised by the return-statement check that runs while a closure is
 * being cleaned -- confirm against the throwing site, which is not visible here.
 */
417566private [spark] class ReturnStatementInClosureException
 418568 extends SparkException (" Return statements aren't allowed in Spark closures" )
419568
0 commit comments