From a82248f50ca1fca33a27fa5f5d306075e4012bad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Fri, 29 Sep 2023 10:16:12 +0200 Subject: [PATCH] BCode: Track the exact types on the stack, rather than only the height. --- .../tools/backend/jvm/BCodeBodyBuilder.scala | 113 ++++++++++-------- .../tools/backend/jvm/BCodeSkelBuilder.scala | 56 ++++++++- 2 files changed, 115 insertions(+), 54 deletions(-) diff --git a/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala b/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala index e7b5a0dad1bf..974f26f7adeb 100644 --- a/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala +++ b/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala @@ -79,14 +79,14 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { tree match { case Assign(lhs @ DesugaredSelect(qual, _), rhs) => - val savedStackHeight = stackHeight + val savedStackSize = stack.recordSize() val isStatic = lhs.symbol.isStaticMember if (!isStatic) { - genLoadQualifier(lhs) - stackHeight += 1 + val qualTK = genLoad(qual) + stack.push(qualTK) } genLoad(rhs, symInfoTK(lhs.symbol)) - stackHeight = savedStackHeight + stack.restoreSize(savedStackSize) lineNumber(tree) // receiverClass is used in the bytecode to access the field. using sym.owner may lead to IllegalAccessError val receiverClass = qual.tpe.typeSymbol @@ -150,9 +150,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { } genLoad(larg, resKind) - stackHeight += resKind.size + stack.push(resKind) genLoad(rarg, if (isShift) INT else resKind) - stackHeight -= resKind.size + stack.pop() (code: @switch) match { case ADD => bc add resKind @@ -189,19 +189,19 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { if (isArrayGet(code)) { // load argument on stack assert(args.length == 1, s"Too many arguments for array get operation: $tree"); - stackHeight += 1 + stack.push(k) genLoad(args.head, INT) - stackHeight -= 1 + stack.pop() generatedType = k.asArrayBType.componentType bc.aload(elementType) } else if (isArraySet(code)) { val List(a1, a2) = args - stackHeight += 1 + stack.push(k) genLoad(a1, INT) - stackHeight += 1 + stack.push(INT) genLoad(a2) - stackHeight -= 2 + stack.pop(2) generatedType = UNIT bc.astore(elementType) } else { @@ -235,7 +235,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { val resKind = if (hasUnitBranch) UNIT else tpeTK(tree) val postIf = new asm.Label - genLoadTo(thenp, resKind, LoadDestination.Jump(postIf, stackHeight)) + genLoadTo(thenp, resKind, LoadDestination.Jump(postIf, stack.recordSize())) markProgramPoint(failure) genLoadTo(elsep, resKind, LoadDestination.FallThrough) markProgramPoint(postIf) @@ -294,8 +294,10 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { ) } - def genLoad(tree: Tree): Unit = { - genLoad(tree, tpeTK(tree)) + def genLoad(tree: Tree): BType = { + val generatedType = tpeTK(tree) + genLoad(tree, generatedType) + generatedType } /* Generate code for trees that produce values on the stack */ @@ -364,6 +366,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { case t @ Ident(_) => (t, Nil) } + val savedStackSize = stack.recordSize() if (!fun.symbol.isStaticMember) { // load receiver of non-static implementation of lambda @@ -372,10 +375,12 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { // AbstractValidatingLambdaMetafactory.validateMetafactoryArgs val DesugaredSelect(prefix, _) = fun: @unchecked - genLoad(prefix) + val prefixTK = genLoad(prefix) + stack.push(prefixTK) } genLoadArguments(env, fun.symbol.info.firstParamTypes map toTypeKind) + stack.restoreSize(savedStackSize) generatedType = genInvokeDynamicLambda(NoSymbol, fun.symbol, env.size, functionalInterface) case app @ Apply(_, _) => @@ -494,9 +499,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { dest match case LoadDestination.FallThrough => () - case LoadDestination.Jump(label, targetStackHeight) => - if targetStackHeight < stackHeight then - val stackDiff = stackHeight - targetStackHeight + case LoadDestination.Jump(label, targetStackSize) => + val stackDiff = stack.heightDiffWrt(targetStackSize) + if stackDiff != 0 then if expectedType == UNIT then bc dropMany stackDiff else @@ -599,7 +604,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { if dest == LoadDestination.FallThrough then val resKind = tpeTK(tree) val jumpTarget = new asm.Label - registerJumpDest(labelSym, resKind, LoadDestination.Jump(jumpTarget, stackHeight)) + registerJumpDest(labelSym, resKind, LoadDestination.Jump(jumpTarget, stack.recordSize())) genLoad(expr, resKind) markProgramPoint(jumpTarget) resKind @@ -657,7 +662,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { markProgramPoint(loop) if isInfinite then - val dest = LoadDestination.Jump(loop, stackHeight) + val dest = LoadDestination.Jump(loop, stack.recordSize()) genLoadTo(body, UNIT, dest) dest else @@ -672,7 +677,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { val failure = new asm.Label genCond(cond, success, failure, targetIfNoJump = success) markProgramPoint(success) - genLoadTo(body, UNIT, LoadDestination.Jump(loop, stackHeight)) + genLoadTo(body, UNIT, LoadDestination.Jump(loop, stack.recordSize())) markProgramPoint(failure) end match LoadDestination.FallThrough @@ -765,10 +770,10 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { // on the stack (contrary to what the type in the AST says). // scala/bug#10290: qual can be `this.$outer()` (not just `this`), so we call genLoad (not just ALOAD_0) - genLoad(superQual) - stackHeight += 1 + val superQualTK = genLoad(superQual) + stack.push(superQualTK) genLoadArguments(args, paramTKs(app)) - stackHeight -= 1 + stack.pop() generatedType = genCallMethod(fun.symbol, InvokeStyle.Super, app.span) // 'new' constructor call: Note: since constructors are @@ -790,9 +795,10 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { assert(classBTypeFromSymbol(ctor.owner) == rt, s"Symbol ${ctor.owner.showFullName} is different from $rt") mnode.visitTypeInsn(asm.Opcodes.NEW, rt.internalName) bc dup generatedType - stackHeight += 2 + stack.push(rt) + stack.push(rt) genLoadArguments(args, paramTKs(app)) - stackHeight -= 2 + stack.pop(2) genCallMethod(ctor, InvokeStyle.Special, app.span) case _ => @@ -825,12 +831,11 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { else if (app.hasAttachment(BCodeHelpers.UseInvokeSpecial)) InvokeStyle.Special else InvokeStyle.Virtual - val savedStackHeight = stackHeight + val savedStackSize = stack.recordSize() if invokeStyle.hasInstance then - genLoadQualifier(fun) - stackHeight += 1 + stack.push(genLoadQualifier(fun)) genLoadArguments(args, paramTKs(app)) - stackHeight = savedStackHeight + stack.restoreSize(savedStackSize) val DesugaredSelect(qual, name) = fun: @unchecked // fun is a Select, also checked in genLoadQualifier val isArrayClone = name == nme.clone_ && qual.tpe.widen.isInstanceOf[JavaArrayType] @@ -888,7 +893,10 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { bc iconst elems.length bc newarray elmKind - stackHeight += 3 // during the genLoad below, there is the result, its dup, and the index + // during the genLoad below, there is the result, its dup, and the index + stack.push(generatedType) + stack.push(generatedType) + stack.push(INT) var i = 0 var rest = elems @@ -901,7 +909,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { i = i + 1 } - stackHeight -= 3 + stack.pop(3) generatedType } @@ -917,7 +925,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { val (generatedType, postMatch, postMatchDest) = if dest == LoadDestination.FallThrough then val postMatch = new asm.Label - (tpeTK(tree), postMatch, LoadDestination.Jump(postMatch, stackHeight)) + (tpeTK(tree), postMatch, LoadDestination.Jump(postMatch, stack.recordSize())) else (expectedType, null, dest) @@ -1179,7 +1187,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { } /* Emit code to Load the qualifier of `tree` on top of the stack. */ - def genLoadQualifier(tree: Tree): Unit = { + def genLoadQualifier(tree: Tree): BType = { lineNumber(tree) tree match { case DesugaredSelect(qualifier, _) => genLoad(qualifier) @@ -1188,6 +1196,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { case Some(sel) => genLoadQualifier(sel) case None => assert(t.symbol.owner == this.claszSymbol) + UNIT } case _ => abort(s"Unknown qualifier $tree") } @@ -1200,14 +1209,14 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { btpes match case btpe :: btpes1 => genLoad(arg, btpe) - stackHeight += btpe.size + stack.push(btpe) loop(args1, btpes1) case _ => case _ => - val savedStackHeight = stackHeight + val savedStackSize = stack.recordSize() loop(args, btpes) - stackHeight = savedStackHeight + stack.restoreSize(savedStackSize) end genLoadArguments def genLoadModule(tree: Tree): BType = { @@ -1307,13 +1316,13 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { }.sum bc.genNewStringBuilder(approxBuilderSize) - stackHeight += 1 // during the genLoad below, there is a reference to the StringBuilder on the stack + stack.push(jlStringBuilderRef) // during the genLoad below, there is a reference to the StringBuilder on the stack for (elem <- concatArguments) { val elemType = tpeTK(elem) genLoad(elem, elemType) bc.genStringBuilderAppend(elemType) } - stackHeight -= 1 + stack.pop() bc.genStringBuilderEnd } else { @@ -1331,7 +1340,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { var totalArgSlots = 0 var countConcats = 1 // ie. 1 + how many times we spilled - val savedStackHeight = stackHeight + val savedStackSize = stack.recordSize() for (elem <- concatArguments) { val tpe = tpeTK(elem) @@ -1339,7 +1348,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { // Unlikely spill case if (totalArgSlots + elemSlots >= MaxIndySlots) { - stackHeight = savedStackHeight + countConcats + stack.restoreSize(savedStackSize) + for _ <- 0 until countConcats do + stack.push(StringRef) bc.genIndyStringConcat(recipe.toString, argTypes.result(), constVals.result()) countConcats += 1 totalArgSlots = 0 @@ -1364,10 +1375,10 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { val tpe = tpeTK(elem) argTypes += tpe.toASMType genLoad(elem, tpe) - stackHeight += 1 + stack.push(tpe) } } - stackHeight = savedStackHeight + stack.restoreSize(savedStackSize) bc.genIndyStringConcat(recipe.toString, argTypes.result(), constVals.result()) // If we spilled, generate one final concat @@ -1562,9 +1573,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { } else { val tk = tpeTK(l).maxType(tpeTK(r)) genLoad(l, tk) - stackHeight += tk.size + stack.push(tk) genLoad(r, tk) - stackHeight -= tk.size + stack.pop() genCJUMP(success, failure, op, tk, targetIfNoJump) } } @@ -1679,9 +1690,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { } genLoad(l, ObjectRef) - stackHeight += 1 + stack.push(ObjectRef) genLoad(r, ObjectRef) - stackHeight -= 1 + stack.pop() genCallMethod(equalsMethod, InvokeStyle.Static) genCZJUMP(success, failure, Primitives.NE, BOOL, targetIfNoJump) } @@ -1697,9 +1708,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { } else if (isNonNullExpr(l)) { // SI-7852 Avoid null check if L is statically non-null. genLoad(l, ObjectRef) - stackHeight += 1 + stack.push(ObjectRef) genLoad(r, ObjectRef) - stackHeight -= 1 + stack.pop() genCallMethod(defn.Any_equals, InvokeStyle.Virtual) genCZJUMP(success, failure, Primitives.NE, BOOL, targetIfNoJump) } else { @@ -1709,9 +1720,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { val lNonNull = new asm.Label genLoad(l, ObjectRef) - stackHeight += 1 + stack.push(ObjectRef) genLoad(r, ObjectRef) - stackHeight -= 1 + stack.pop() locals.store(eqEqTempLocal) bc dup ObjectRef genCZJUMP(lNull, lNonNull, Primitives.EQ, ObjectRef, targetIfNoJump = lNull) diff --git a/compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala b/compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala index f892c2bb753c..49ff83aa716a 100644 --- a/compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala +++ b/compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala @@ -40,12 +40,62 @@ trait BCodeSkelBuilder extends BCodeHelpers { lazy val NativeAttr: Symbol = requiredClass[scala.native] + final class BTypesStack: + // Anecdotally, growing past 16 to 32 is common; growing past 32 is rare + private var stack = new Array[BType](32) + private var size = 0 + + def push(btype: BType): Unit = + if size == stack.length then + stack = java.util.Arrays.copyOf(stack, stack.length * 2) + stack(size) = btype + size += 1 + + def pop(): Unit = pop(1) + + def pop(count: Int): Unit = + assert(size >= count) + size -= count + + def height: Int = heightBetween(0, size) + + private def heightBetween(start: Int, end: Int): Int = + var result = 0 + var i = start + while i != end do + result += stack(i).size + i += 1 + result + + def recordSize(): BTypesStack.Size = BTypesStack.intToSize(size) + + def restoreSize(targetSize: BTypesStack.Size): Unit = + val targetSize1 = BTypesStack.sizeToInt(targetSize) + assert(size >= targetSize1) + size = targetSize1 + + def heightDiffWrt(targetSize: BTypesStack.Size): Int = + val targetSize1 = BTypesStack.sizeToInt(targetSize) + assert(size >= targetSize1) + heightBetween(targetSize1, size) + + def clear(): Unit = + size = 0 + end BTypesStack + + object BTypesStack: + opaque type Size = Int + + private def intToSize(size: Int): Size = size + private def sizeToInt(size: Size): Int = size + end BTypesStack + /** The destination of a value generated by `genLoadTo`. */ enum LoadDestination: /** The value is put on the stack, and control flows through to the next opcode. */ case FallThrough /** The value is put on the stack, and control flow is transferred to the given `label`. */ - case Jump(label: asm.Label, targetStackHeight: Int) + case Jump(label: asm.Label, targetStackSize: BTypesStack.Size) /** The value is RETURN'ed from the enclosing method. */ case Return /** The value is ATHROW'n. */ @@ -369,7 +419,7 @@ trait BCodeSkelBuilder extends BCodeHelpers { var earlyReturnVar: Symbol = null var shouldEmitCleanup = false // stack tracking - var stackHeight = 0 + val stack = new BTypesStack // line numbers var lastEmittedLineNr = -1 @@ -589,7 +639,7 @@ trait BCodeSkelBuilder extends BCodeHelpers { earlyReturnVar = null shouldEmitCleanup = false - stackHeight = 0 + stack.clear() lastEmittedLineNr = -1 }