Skip to content

Commit c1a1021

Browse files
committed
Add support for discovering inner classes (as opposed to just direct inner closures)
The new unit test in the previous commit would now pass.
1 parent 34504f2 commit c1a1021

File tree

1 file changed

+117
-38
lines changed

1 file changed

+117
-38
lines changed

core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala

Lines changed: 117 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -373,12 +373,13 @@ private[spark] object ClosureCleaner extends Logging {
373373

374374
if (needsCleaning) {
375375
// indylambda closures do not reference enclosing closures via an `$outer` chain, so no
376-
// transitive cleaning is needed. Thus clean() shouldn't be recursively called with a
377-
// non-empty accessedFields.
376+
// transitive cleaning on the `$outer` chain is needed.
377+
// Thus clean() shouldn't be recursively called with a non-empty accessedFields.
378378
assert(accessedFields.isEmpty)
379379

380380
initAccessedFields(accessedFields, Seq(capturingClass))
381-
IndylambdaScalaClosures.findAccessedFields(lambdaProxy, classLoader, accessedFields)
381+
IndylambdaScalaClosures.findAccessedFields(
382+
lambdaProxy, classLoader, accessedFields, cleanTransitively)
382383

383384
logDebug(s" + fields accessed by starting closure: ${accessedFields.size} classes")
384385
accessedFields.foreach { f => logDebug(" " + f) }
@@ -488,44 +489,109 @@ private[spark] object IndylambdaScalaClosures extends Logging {
488489
writeReplace.invoke(closure).asInstanceOf[SerializedLambda]
489490
}
490491

492+
/**
493+
* Check if the handle represents the LambdaMetafactory that indylambda Scala closures
494+
* use for creating the lambda class and getting a closure instance.
495+
*/
496+
def isLambdaMetafactory(bsmHandle: Handle): Boolean = {
497+
bsmHandle.getOwner == LambdaMetafactoryClassName &&
498+
bsmHandle.getName == LambdaMetafactoryMethodName &&
499+
bsmHandle.getDesc == LambdaMetafactoryMethodDesc
500+
}
501+
502+
/**
503+
* Check if the handle represents a target method that is:
504+
* - a STATIC method that implements a Scala lambda body in the indylambda style
505+
* - captures the enclosing `this`, i.e. the first argument is a reference to the same type as
506+
* the owning class.
507+
* Returns true if both criteria above are met.
508+
*/
509+
def isLambdaBodyCapturingOuter(handle: Handle, ownerInternalName: String): Boolean = {
510+
handle.getTag == H_INVOKESTATIC &&
511+
handle.getName.contains("$anonfun$") &&
512+
handle.getOwner == ownerInternalName &&
513+
handle.getDesc.startsWith(s"(L$ownerInternalName;")
514+
}
515+
516+
/**
517+
* Check if the callee of a call site is a inner class constructor.
518+
* - A constructor has to be invoked via INVOKESPECIAL
519+
* - A constructor's internal name is "<init>" and the return type is "V" (void)
520+
* - An inner class' first argument in the signature has to be a reference to the
521+
* enclosing "this", aka `$outer` in Scala.
522+
*/
523+
def isInnerClassCtorCapturingOuter(
524+
op: Int, owner: String, name: String, desc: String, callerInternalName: String): Boolean = {
525+
op == INVOKESPECIAL && name == "<init>" && desc.startsWith(s"(L$callerInternalName;")
526+
}
527+
491528
/**
492529
* Scans an indylambda Scala closure, along with its lexically nested closures, and populate
493530
* the accessed fields info on which fields on the outer object are accessed.
494531
*/
495532
def findAccessedFields(
496533
lambdaProxy: SerializedLambda,
497534
lambdaClassLoader: ClassLoader,
498-
accessedFields: Map[Class[_], Set[String]]): Unit = {
499-
val implClassInternalName = lambdaProxy.getImplClass
500-
// scalastyle:off classforname
501-
val implClass = Class.forName(
502-
implClassInternalName.replace('/', '.'), false, lambdaClassLoader)
503-
// scalastyle:on classforname
504-
val implClassNode = new ClassNode()
505-
val implClassReader = ClosureCleaner.getClassReader(implClass)
506-
implClassReader.accept(implClassNode, 0)
507-
508-
val methodsByName = Map.empty[MethodIdentifier[_], MethodNode]
509-
for (m <- implClassNode.methods.asScala) {
510-
methodsByName(MethodIdentifier(implClass, m.name, m.desc)) = m
535+
accessedFields: Map[Class[_], Set[String]],
536+
findTransitively: Boolean): Unit = {
537+
538+
// We may need to visit the same class multiple times for different methods on it, and we'll
539+
// need to lookup by name. So we use ASM's Tree API and cache the ClassNode/MethodNode.
540+
val classInfoByInternalName = Map.empty[String, (Class[_], ClassNode)]
541+
val methodNodeById = Map.empty[MethodIdentifier[_], MethodNode]
542+
def getOrUpdateClassInfo(classInternalName: String): (Class[_], ClassNode) = {
543+
val classInfo = classInfoByInternalName.getOrElseUpdate(classInternalName, {
544+
val classExternalName = classInternalName.replace('/', '.')
545+
// scalastyle:off classforname
546+
val clazz = Class.forName(classExternalName, false, lambdaClassLoader)
547+
// scalastyle:on classforname
548+
val classNode = new ClassNode()
549+
val classReader = ClosureCleaner.getClassReader(clazz)
550+
classReader.accept(classNode, 0)
551+
552+
for (m <- classNode.methods.asScala) {
553+
methodNodeById(MethodIdentifier(clazz, m.name, m.desc)) = m
554+
}
555+
556+
(clazz, classNode)
557+
})
558+
classInfo
511559
}
512560

561+
val implClassInternalName = lambdaProxy.getImplClass
562+
val (implClass, _) = getOrUpdateClassInfo(implClassInternalName)
563+
513564
val implMethodId = MethodIdentifier(
514565
implClass, lambdaProxy.getImplMethodName, lambdaProxy.getImplMethodSignature)
515566

567+
// The set of classes that we would consider following the calls into.
568+
// Candidates are: known outer class which happens to be the starting closure's impl class,
569+
// and all inner classes discovered below.
570+
val trackedClasses = Set[Class[_]](implClass)
571+
572+
// Depth-first search for inner closures and track the fields that were accessed in them.
573+
// Start from the lambda body's implementation method, follow method invocations
516574
val visited = Set.empty[MethodIdentifier[_]]
517575
val stack = Stack[MethodIdentifier[_]](implMethodId)
576+
def pushIfNotVisited(methodId: MethodIdentifier[_]): Unit = {
577+
if (!visited.contains(methodId)) {
578+
stack.push(methodId)
579+
}
580+
}
581+
518582
while (!stack.isEmpty) {
519583
val currentId = stack.pop
520584
visited += currentId
521585

522-
val currentMethodNode = methodsByName(currentId)
523-
logTrace(s" scanning $currentId")
586+
val currentClass = currentId.cls
587+
val currentMethodNode = methodNodeById(currentId)
588+
logTrace(s" scanning ${currentId.cls.getName}.${currentId.name}${currentId.desc}")
524589
currentMethodNode.accept(new MethodVisitor(ASM7) {
525-
// FIXME: record self class name, get a Class[_] for it
526-
// val selfClassName: String =
527-
// val selfClass: Class[_] =
590+
val currentClassName = currentClass.getName
591+
val currentClassInternalName = currentClassName.replace('.', '/')
528592

593+
// Find and update the accessedFields info. Only fields on known outer classes are tracked.
594+
// This is the FieldAccessFinder equivalent.
529595
override def visitFieldInsn(op: Int, owner: String, name: String, desc: String): Unit = {
530596
if (op == GETFIELD || op == PUTFIELD) {
531597
val ownerExternalName = owner.replace('/', '.')
@@ -538,43 +604,56 @@ private[spark] object IndylambdaScalaClosures extends Logging {
538604

539605
override def visitMethodInsn(
540606
op: Int, owner: String, name: String, desc: String, itf: Boolean): Unit = {
541-
if (owner == implClassInternalName) {
542-
val ownerExternalName = owner.replace('/', '.')
607+
val ownerExternalName = owner.replace('/', '.')
608+
if (owner == currentClassInternalName) {
543609
logTrace(s" found intra class call to $ownerExternalName.$name$desc")
544-
val calleeMethodId = MethodIdentifier(implClass, name, desc)
545-
if (!visited.contains(calleeMethodId)) {
546-
stack.push(calleeMethodId)
610+
// could be invoking a helper method or a field accessor method, just follow it.
611+
pushIfNotVisited(MethodIdentifier(currentClass, name, desc))
612+
} else if (isInnerClassCtorCapturingOuter(
613+
op, owner, name, desc, currentClassInternalName)) {
614+
// Discover inner classes.
615+
// This this the InnerClassFinder equivalent for inner classes, which still use the
616+
// `$outer` chain. So this is NOT controlled by the `findTransitively` flag.
617+
logTrace(s" found inner class $ownerExternalName")
618+
val innerClassInfo = getOrUpdateClassInfo(owner)
619+
val innerClass = innerClassInfo._1
620+
val innerClassNode = innerClassInfo._2
621+
trackedClasses += innerClass
622+
// We need to visit all methods on the inner class so that we don't missing anything.
623+
for (m <- innerClassNode.methods.asScala) {
624+
pushIfNotVisited(MethodIdentifier(innerClass, m.name, m.desc))
547625
}
626+
} else if (findTransitively &&
627+
trackedClasses.find(_.getName == ownerExternalName).isDefined) {
628+
logTrace(s" found call to outer $ownerExternalName.$name$desc")
629+
val (calleeClass, _) = getOrUpdateClassInfo(owner)
630+
pushIfNotVisited(MethodIdentifier(calleeClass, name, desc))
548631
} else {
549-
// FIXME: implement findTransitively
550632
// keep the same behavior as the original ClosureCleaner
551-
logTrace(s" ignoring call to $owner.$name$desc")
633+
logTrace(s" ignoring call to $ownerExternalName.$name$desc")
552634
}
553635
}
554636

555-
// find the lexically nested closures
637+
// Find the lexically nested closures
638+
// This is the InnerClosureFinder equivalent for indylambda nested closures
556639
override def visitInvokeDynamicInsn(
557640
name: String, desc: String, bsmHandle: Handle, bsmArgs: Object*): Unit = {
558641
logTrace(s" invokedynamic: $name$desc, bsmHandle=$bsmHandle, bsmArgs=$bsmArgs")
559642

560643
// fast check: we only care about Scala lambda creation
644+
// TODO: maybe lift this restriction and support other functional interfaces
561645
if (!name.startsWith("apply")) return
562646
if (!Type.getReturnType(desc).getDescriptor.startsWith("Lscala/Function")) return
563647

564-
if (bsmHandle.getOwner == LambdaMetafactoryClassName &&
565-
bsmHandle.getName == LambdaMetafactoryMethodName &&
566-
bsmHandle.getDesc == LambdaMetafactoryMethodDesc) {
648+
if (isLambdaMetafactory(bsmHandle)) {
567649
// OK we're in the right bootstrap method for serializable Java 8 style lambda creation
568650
val targetHandle = bsmArgs(1).asInstanceOf[Handle]
569-
if (targetHandle.getOwner == implClassInternalName &&
570-
targetHandle.getDesc.startsWith(s"(L$implClassInternalName;")) {
651+
if (isLambdaBodyCapturingOuter(targetHandle, currentClassInternalName)) {
571652
// this is a lexically nested closure that also captures the enclosing `this`
572653
logDebug(s" found inner closure $targetHandle")
573654
val calleeMethodId =
574-
MethodIdentifier(implClass, targetHandle.getName, targetHandle.getDesc)
575-
if (!visited.contains(calleeMethodId)) {
576-
stack.push(calleeMethodId)
577-
}
655+
MethodIdentifier(currentClass, targetHandle.getName, targetHandle.getDesc)
656+
pushIfNotVisited(calleeMethodId)
578657
}
579658
}
580659
}

0 commit comments

Comments
 (0)