@@ -373,12 +373,13 @@ private[spark] object ClosureCleaner extends Logging {
373373
374374 if (needsCleaning) {
375375 // indylambda closures do not reference enclosing closures via an `$outer` chain, so no
376- // transitive cleaning is needed. Thus clean() shouldn't be recursively called with a
377- // non-empty accessedFields.
376+ // transitive cleaning on the `$outer` chain is needed.
377+ // Thus clean() shouldn't be recursively called with a non-empty accessedFields.
378378 assert(accessedFields.isEmpty)
379379
380380 initAccessedFields(accessedFields, Seq (capturingClass))
381- IndylambdaScalaClosures .findAccessedFields(lambdaProxy, classLoader, accessedFields)
381+ IndylambdaScalaClosures .findAccessedFields(
382+ lambdaProxy, classLoader, accessedFields, cleanTransitively)
382383
383384 logDebug(s " + fields accessed by starting closure: ${accessedFields.size} classes " )
384385 accessedFields.foreach { f => logDebug(" " + f) }
@@ -488,44 +489,109 @@ private[spark] object IndylambdaScalaClosures extends Logging {
488489 writeReplace.invoke(closure).asInstanceOf [SerializedLambda ]
489490 }
490491
492+ /**
493+ * Check if the handle represents the LambdaMetafactory that indylambda Scala closures
494+ * use for creating the lambda class and getting a closure instance.
495+ */
496+ def isLambdaMetafactory (bsmHandle : Handle ): Boolean = {
497+ bsmHandle.getOwner == LambdaMetafactoryClassName &&
498+ bsmHandle.getName == LambdaMetafactoryMethodName &&
499+ bsmHandle.getDesc == LambdaMetafactoryMethodDesc
500+ }
501+
502+ /**
503+ * Check if the handle represents a target method that is:
504+ * - a STATIC method that implements a Scala lambda body in the indylambda style
505+ * - captures the enclosing `this`, i.e. the first argument is a reference to the same type as
506+ * the owning class.
507+ * Returns true if both criteria above are met.
508+ */
509+ def isLambdaBodyCapturingOuter (handle : Handle , ownerInternalName : String ): Boolean = {
510+ handle.getTag == H_INVOKESTATIC &&
511+ handle.getName.contains(" $anonfun$" ) &&
512+ handle.getOwner == ownerInternalName &&
513+ handle.getDesc.startsWith(s " (L $ownerInternalName; " )
514+ }
515+
516+ /**
517+ * Check if the callee of a call site is a inner class constructor.
518+ * - A constructor has to be invoked via INVOKESPECIAL
519+ * - A constructor's internal name is "<init>" and the return type is "V" (void)
520+ * - An inner class' first argument in the signature has to be a reference to the
521+ * enclosing "this", aka `$outer` in Scala.
522+ */
523+ def isInnerClassCtorCapturingOuter (
524+ op : Int , owner : String , name : String , desc : String , callerInternalName : String ): Boolean = {
525+ op == INVOKESPECIAL && name == " <init>" && desc.startsWith(s " (L $callerInternalName; " )
526+ }
527+
491528 /**
492529 * Scans an indylambda Scala closure, along with its lexically nested closures, and populate
493530 * the accessed fields info on which fields on the outer object are accessed.
494531 */
495532 def findAccessedFields (
496533 lambdaProxy : SerializedLambda ,
497534 lambdaClassLoader : ClassLoader ,
498- accessedFields : Map [Class [_], Set [String ]]): Unit = {
499- val implClassInternalName = lambdaProxy.getImplClass
500- // scalastyle:off classforname
501- val implClass = Class .forName(
502- implClassInternalName.replace('/' , '.' ), false , lambdaClassLoader)
503- // scalastyle:on classforname
504- val implClassNode = new ClassNode ()
505- val implClassReader = ClosureCleaner .getClassReader(implClass)
506- implClassReader.accept(implClassNode, 0 )
507-
508- val methodsByName = Map .empty[MethodIdentifier [_], MethodNode ]
509- for (m <- implClassNode.methods.asScala) {
510- methodsByName(MethodIdentifier (implClass, m.name, m.desc)) = m
535+ accessedFields : Map [Class [_], Set [String ]],
536+ findTransitively : Boolean ): Unit = {
537+
538+ // We may need to visit the same class multiple times for different methods on it, and we'll
539+ // need to lookup by name. So we use ASM's Tree API and cache the ClassNode/MethodNode.
540+ val classInfoByInternalName = Map .empty[String , (Class [_], ClassNode )]
541+ val methodNodeById = Map .empty[MethodIdentifier [_], MethodNode ]
542+ def getOrUpdateClassInfo (classInternalName : String ): (Class [_], ClassNode ) = {
543+ val classInfo = classInfoByInternalName.getOrElseUpdate(classInternalName, {
544+ val classExternalName = classInternalName.replace('/' , '.' )
545+ // scalastyle:off classforname
546+ val clazz = Class .forName(classExternalName, false , lambdaClassLoader)
547+ // scalastyle:on classforname
548+ val classNode = new ClassNode ()
549+ val classReader = ClosureCleaner .getClassReader(clazz)
550+ classReader.accept(classNode, 0 )
551+
552+ for (m <- classNode.methods.asScala) {
553+ methodNodeById(MethodIdentifier (clazz, m.name, m.desc)) = m
554+ }
555+
556+ (clazz, classNode)
557+ })
558+ classInfo
511559 }
512560
561+ val implClassInternalName = lambdaProxy.getImplClass
562+ val (implClass, _) = getOrUpdateClassInfo(implClassInternalName)
563+
513564 val implMethodId = MethodIdentifier (
514565 implClass, lambdaProxy.getImplMethodName, lambdaProxy.getImplMethodSignature)
515566
567+ // The set of classes that we would consider following the calls into.
568+ // Candidates are: known outer class which happens to be the starting closure's impl class,
569+ // and all inner classes discovered below.
570+ val trackedClasses = Set [Class [_]](implClass)
571+
572+ // Depth-first search for inner closures and track the fields that were accessed in them.
573+ // Start from the lambda body's implementation method, follow method invocations
516574 val visited = Set .empty[MethodIdentifier [_]]
517575 val stack = Stack [MethodIdentifier [_]](implMethodId)
576+ def pushIfNotVisited (methodId : MethodIdentifier [_]): Unit = {
577+ if (! visited.contains(methodId)) {
578+ stack.push(methodId)
579+ }
580+ }
581+
518582 while (! stack.isEmpty) {
519583 val currentId = stack.pop
520584 visited += currentId
521585
522- val currentMethodNode = methodsByName(currentId)
523- logTrace(s " scanning $currentId" )
586+ val currentClass = currentId.cls
587+ val currentMethodNode = methodNodeById(currentId)
588+ logTrace(s " scanning ${currentId.cls.getName}. ${currentId.name}${currentId.desc}" )
524589 currentMethodNode.accept(new MethodVisitor (ASM7 ) {
525- // FIXME: record self class name, get a Class[_] for it
526- // val selfClassName: String =
527- // val selfClass: Class[_] =
590+ val currentClassName = currentClass.getName
591+ val currentClassInternalName = currentClassName.replace('.' , '/' )
528592
593+ // Find and update the accessedFields info. Only fields on known outer classes are tracked.
594+ // This is the FieldAccessFinder equivalent.
529595 override def visitFieldInsn (op : Int , owner : String , name : String , desc : String ): Unit = {
530596 if (op == GETFIELD || op == PUTFIELD ) {
531597 val ownerExternalName = owner.replace('/' , '.' )
@@ -538,43 +604,56 @@ private[spark] object IndylambdaScalaClosures extends Logging {
538604
539605 override def visitMethodInsn (
540606 op : Int , owner : String , name : String , desc : String , itf : Boolean ): Unit = {
541- if (owner == implClassInternalName) {
542- val ownerExternalName = owner.replace( '/' , '.' )
607+ val ownerExternalName = owner.replace( '/' , '.' )
608+ if (owner == currentClassInternalName) {
543609 logTrace(s " found intra class call to $ownerExternalName. $name$desc" )
544- val calleeMethodId = MethodIdentifier (implClass, name, desc)
545- if (! visited.contains(calleeMethodId)) {
546- stack.push(calleeMethodId)
610+ // could be invoking a helper method or a field accessor method, just follow it.
611+ pushIfNotVisited(MethodIdentifier (currentClass, name, desc))
612+ } else if (isInnerClassCtorCapturingOuter(
613+ op, owner, name, desc, currentClassInternalName)) {
614+ // Discover inner classes.
615+ // This this the InnerClassFinder equivalent for inner classes, which still use the
616+ // `$outer` chain. So this is NOT controlled by the `findTransitively` flag.
617+ logTrace(s " found inner class $ownerExternalName" )
618+ val innerClassInfo = getOrUpdateClassInfo(owner)
619+ val innerClass = innerClassInfo._1
620+ val innerClassNode = innerClassInfo._2
621+ trackedClasses += innerClass
622+ // We need to visit all methods on the inner class so that we don't missing anything.
623+ for (m <- innerClassNode.methods.asScala) {
624+ pushIfNotVisited(MethodIdentifier (innerClass, m.name, m.desc))
547625 }
626+ } else if (findTransitively &&
627+ trackedClasses.find(_.getName == ownerExternalName).isDefined) {
628+ logTrace(s " found call to outer $ownerExternalName. $name$desc" )
629+ val (calleeClass, _) = getOrUpdateClassInfo(owner)
630+ pushIfNotVisited(MethodIdentifier (calleeClass, name, desc))
548631 } else {
549- // FIXME: implement findTransitively
550632 // keep the same behavior as the original ClosureCleaner
551- logTrace(s " ignoring call to $owner . $name$desc" )
633+ logTrace(s " ignoring call to $ownerExternalName . $name$desc" )
552634 }
553635 }
554636
555- // find the lexically nested closures
637+ // Find the lexically nested closures
638+ // This is the InnerClosureFinder equivalent for indylambda nested closures
556639 override def visitInvokeDynamicInsn (
557640 name : String , desc : String , bsmHandle : Handle , bsmArgs : Object * ): Unit = {
558641 logTrace(s " invokedynamic: $name$desc, bsmHandle= $bsmHandle, bsmArgs= $bsmArgs" )
559642
560643 // fast check: we only care about Scala lambda creation
644+ // TODO: maybe lift this restriction and support other functional interfaces
561645 if (! name.startsWith(" apply" )) return
562646 if (! Type .getReturnType(desc).getDescriptor.startsWith(" Lscala/Function" )) return
563647
564- if (bsmHandle.getOwner == LambdaMetafactoryClassName &&
565- bsmHandle.getName == LambdaMetafactoryMethodName &&
566- bsmHandle.getDesc == LambdaMetafactoryMethodDesc ) {
648+ if (isLambdaMetafactory(bsmHandle)) {
567649 // OK we're in the right bootstrap method for serializable Java 8 style lambda creation
568650 val targetHandle = bsmArgs(1 ).asInstanceOf [Handle ]
569- if (targetHandle.getOwner == implClassInternalName &&
570- targetHandle.getDesc.startsWith(s " (L $implClassInternalName; " )) {
651+ if (isLambdaBodyCapturingOuter(targetHandle, currentClassInternalName)) {
571652 // this is a lexically nested closure that also captures the enclosing `this`
572653 logDebug(s " found inner closure $targetHandle" )
573654 val calleeMethodId =
574- MethodIdentifier (implClass, targetHandle.getName, targetHandle.getDesc)
575- if (! visited.contains(calleeMethodId)) {
576- stack.push(calleeMethodId)
577- }
655+ MethodIdentifier (currentClass, targetHandle.getName, targetHandle.getDesc)
656+ pushIfNotVisited(calleeMethodId)
578657 }
579658 }
580659 }
0 commit comments