Skip to content

Commit

Permalink
Some fixes for AnnotatedTypes mapping (#19957)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbovel authored Sep 13, 2024
2 parents 8e9ded0 + ac76938 commit c9b9aea
Show file tree
Hide file tree
Showing 23 changed files with 242 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package dotty.tools.benchmarks

import org.openjdk.jmh.annotations.{Benchmark, BenchmarkMode, Fork, Level, Measurement, Mode as JMHMode, Param, Scope, Setup, State, Warmup}
import java.util.concurrent.TimeUnit.SECONDS

import dotty.tools.dotc.{Driver, Run, Compiler}
import dotty.tools.dotc.ast.{tpd, TreeTypeMap}, tpd.{Apply, Block, Tree, TreeAccumulator, TypeApply}
import dotty.tools.dotc.core.Annotations.{Annotation, ConcreteAnnotation, EmptyAnnotation}
import dotty.tools.dotc.core.Contexts.{ContextBase, Context, ctx, withMode}
import dotty.tools.dotc.core.Mode
import dotty.tools.dotc.core.Phases.Phase
import dotty.tools.dotc.core.Symbols.{defn, mapSymbols, Symbol}
import dotty.tools.dotc.core.Types.{AnnotatedType, NoType, SkolemType, TermRef, Type, TypeMap}
import dotty.tools.dotc.parsing.Parser
import dotty.tools.dotc.typer.TyperPhase

/** Measures the performance of mapping over annotated types.
*
* Run with: scala3-bench-micro / Jmh / run AnnotationsMappingBenchmark
*/
@Fork(value = 4)
@Warmup(iterations = 4, time = 1, timeUnit = SECONDS)
@Measurement(iterations = 4, time = 1, timeUnit = SECONDS)
@BenchmarkMode(Array(JMHMode.Throughput))
@State(Scope.Thread)
class AnnotationsMappingBenchmark:
var tp: Type = null
var specialIntTp: Type = null
var context: Context = null
var typeFunction: Context ?=> Type => Type = null
var typeMap: TypeMap = null

@Param(Array("v1", "v2", "v3", "v4"))
var valName: String = null

@Param(Array("id", "mapInts"))
var typeFunctionName: String = null

@Setup(Level.Iteration)
def setup(): Unit =
val testPhase =
new Phase:
final override def phaseName = "testPhase"
final override def run(using ctx: Context): Unit =
val pkg = ctx.compilationUnit.tpdTree.symbol
tp = pkg.requiredClass("Test").requiredValueRef(valName).underlying
specialIntTp = pkg.requiredClass("Test").requiredType("SpecialInt").typeRef
context = ctx

val compiler =
new Compiler:
private final val baseCompiler = new Compiler()
final override def phases = List(List(Parser()), List(TyperPhase()), List(testPhase))

val driver =
new Driver:
final override def newCompiler(using Context): Compiler = compiler

driver.process(Array("-classpath", System.getProperty("BENCH_CLASS_PATH"), "tests/someAnnotatedTypes.scala"))

typeFunction =
typeFunctionName match
case "id" => tp => tp
case "mapInts" => tp => (if tp frozen_=:= defn.IntType then specialIntTp else tp)
case _ => throw new IllegalArgumentException(s"Unknown type function: $typeFunctionName")

typeMap =
new TypeMap(using context):
final override def apply(tp: Type): Type = typeFunction(mapOver(tp))

@Benchmark def applyTypeMap() = typeMap.apply(tp)
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dotty.tools.benchmarks.lazyvals

import compiletime.uninitialized
import org.openjdk.jmh.annotations._
import LazyVals.LazyHolder
import org.openjdk.jmh.infra.Blackhole
Expand All @@ -16,12 +17,12 @@ import java.util.concurrent.{Executors, ExecutorService}
class ContendedInitialization {

@Param(Array("2000000", "5000000"))
var size: Int = _
var size: Int = uninitialized

@Param(Array("2", "4", "8"))
var nThreads: Int = _
var nThreads: Int = uninitialized

var executor: ExecutorService = _
var executor: ExecutorService = uninitialized

@Setup
def prepare: Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dotty.tools.benchmarks.lazyvals

import compiletime.uninitialized
import org.openjdk.jmh.annotations._
import LazyVals.LazyHolder
import org.openjdk.jmh.infra.Blackhole
Expand All @@ -14,7 +15,7 @@ import java.util.concurrent.TimeUnit
@State(Scope.Benchmark)
class InitializedAccess {

var holder: LazyHolder = _
var holder: LazyHolder = uninitialized

@Setup
def prepare: Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dotty.tools.benchmarks.lazyvals

import compiletime.uninitialized
import org.openjdk.jmh.annotations._
import LazyVals.LazyAnyHolder
import org.openjdk.jmh.infra.Blackhole
Expand All @@ -14,7 +15,7 @@ import java.util.concurrent.TimeUnit
@State(Scope.Benchmark)
class InitializedAccessAny {

var holder: LazyAnyHolder = _
var holder: LazyAnyHolder = uninitialized

@Setup
def prepare: Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dotty.tools.benchmarks.lazyvals

import compiletime.uninitialized
import org.openjdk.jmh.annotations._
import LazyVals.LazyGenericHolder
import org.openjdk.jmh.infra.Blackhole
Expand All @@ -14,7 +15,7 @@ import java.util.concurrent.TimeUnit
@State(Scope.Benchmark)
class InitializedAccessGeneric {

var holder: LazyGenericHolder[String] = _
var holder: LazyGenericHolder[String] = uninitialized

@Setup
def prepare: Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dotty.tools.benchmarks.lazyvals

import compiletime.uninitialized
import org.openjdk.jmh.annotations.*
import org.openjdk.jmh.infra.Blackhole
import LazyVals.LazyIntHolder
Expand All @@ -14,7 +15,7 @@ import java.util.concurrent.TimeUnit
@State(Scope.Benchmark)
class InitializedAccessInt {

var holder: LazyIntHolder = _
var holder: LazyIntHolder = uninitialized

@Setup
def prepare: Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dotty.tools.benchmarks.lazyvals

import compiletime.uninitialized
import org.openjdk.jmh.annotations._
import LazyVals.LazyHolder
import org.openjdk.jmh.infra.Blackhole
Expand All @@ -14,7 +15,7 @@ import java.util.concurrent.TimeUnit
@State(Scope.Benchmark)
class InitializedAccessMultiple {

var holders: Array[LazyHolder] = _
var holders: Array[LazyHolder] = uninitialized

@Setup
def prepare: Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dotty.tools.benchmarks.lazyvals

import compiletime.uninitialized
import org.openjdk.jmh.annotations._
import LazyVals.LazyStringHolder
import org.openjdk.jmh.infra.Blackhole
Expand All @@ -14,7 +15,7 @@ import java.util.concurrent.TimeUnit
@State(Scope.Benchmark)
class InitializedAccessString {

var holder: LazyStringHolder = _
var holder: LazyStringHolder = uninitialized

@Setup
def prepare: Unit = {
Expand Down
28 changes: 28 additions & 0 deletions bench-micro/tests/someAnnotatedTypes.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
class Test:
class FlagAnnot extends annotation.StaticAnnotation
class StringAnnot(val s: String) extends annotation.StaticAnnotation
class LambdaAnnot(val f: Int => Boolean) extends annotation.StaticAnnotation

type SpecialInt <: Int

val v1: Int @FlagAnnot = 42

val v2: Int @StringAnnot("hello") = 42

val v3: Int @LambdaAnnot(it => it == 42) = 42

val v4: Int @LambdaAnnot(it => {
def g(x: Int, y: Int) = x - y + 5
g(it, 7) * 2 == 80
}) = 42

/*val v5: Int @LambdaAnnot(it => {
class Foo(x: Int):
def xPlus10 = x + 10
def xPlus20 = x + 20
def xPlus(y: Int) = x + y
val foo = Foo(it)
foo.xPlus10 - foo.xPlus20 + foo.xPlus(30) == 62
}) = 42*/

def main(args: Array[String]): Unit = ???
10 changes: 9 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,17 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
loop(tree, Nil)

/** All term arguments of an application in a single flattened list */
def allTermArguments(tree: Tree): List[Tree] = unsplice(tree) match {
case Apply(fn, args) => allArguments(fn) ::: args
case TypeApply(fn, args) => allArguments(fn)
case Block(_, expr) => allArguments(expr)
case _ => Nil
}

/** All type and term arguments of an application in a single flattened list */
def allArguments(tree: Tree): List[Tree] = unsplice(tree) match {
case Apply(fn, args) => allArguments(fn) ::: args
case TypeApply(fn, _) => allArguments(fn)
case TypeApply(fn, args) => allArguments(fn) ::: args
case Block(_, expr) => allArguments(expr)
case _ => Nil
}
Expand Down
11 changes: 7 additions & 4 deletions compiler/src/dotty/tools/dotc/core/Annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ object Annotations {
def derivedAnnotation(tree: Tree)(using Context): Annotation =
if (tree eq this.tree) this else Annotation(tree)

/** All arguments to this annotation in a single flat list */
def arguments(using Context): List[Tree] = tpd.allArguments(tree)
/** All term arguments of this annotation in a single flat list */
def arguments(using Context): List[Tree] = tpd.allTermArguments(tree)

def argument(i: Int)(using Context): Option[Tree] = {
val args = arguments
Expand All @@ -54,15 +54,18 @@ object Annotations {
* type, since ranges cannot be types of trees.
*/
def mapWith(tm: TypeMap)(using Context) =
val args = arguments
val args = tpd.allArguments(tree)
if args.isEmpty then this
else
// Checks if `tm` would result in any change by applying it to types
// inside the annotations' arguments and checking if the resulting types
// are different.
val findDiff = new TreeAccumulator[Type]:
def apply(x: Type, tree: Tree)(using Context): Type =
if tm.isRange(x) then x
else
val tp1 = tm(tree.tpe)
foldOver(if tp1 frozen_=:= tree.tpe then x else tp1, tree)
foldOver(if !tp1.exists || (tp1 frozen_=:= tree.tpe) then x else tp1, tree)
val diff = findDiff(NoType, args)
if tm.isRange(diff) then EmptyAnnotation
else if diff.exists then derivedAnnotation(tm.mapOver(tree))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ object PositionPickler:
pickler: TastyPickler,
addrOfTree: TreeToAddr,
treeAnnots: untpd.MemberDef => List[tpd.Tree],
typeAnnots: List[tpd.Tree],
relativePathReference: String,
source: SourceFile,
roots: List[Tree],
Expand Down Expand Up @@ -136,6 +137,9 @@ object PositionPickler:
}
for (root <- roots)
traverse(root, NoSource)

for annotTree <- typeAnnots do
traverse(annotTree, NoSource)
end picklePositions
end PositionPickler

7 changes: 7 additions & 0 deletions compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
*/
private val annotTrees = util.EqHashMap[untpd.MemberDef, mutable.ListBuffer[Tree]]()

/** A set of annotation trees appearing in annotated types.
*/
private val annotatedTypeTrees = mutable.ListBuffer[Tree]()

/** A map from member definitions to their doc comments, so that later
* parallel comment pickling does not need to access symbols of trees (which
* would involve accessing symbols of named types and possibly changing phases
Expand All @@ -57,6 +61,8 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
val ts = annotTrees.lookup(tree)
if ts == null then Nil else ts.toList

def typeAnnots: List[Tree] = annotatedTypeTrees.toList

def docString(tree: untpd.MemberDef): Option[Comment] =
Option(docStrings.lookup(tree))

Expand Down Expand Up @@ -278,6 +284,7 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
case tpe: AnnotatedType =>
writeByte(ANNOTATEDtype)
withLength { pickleType(tpe.parent, richTypes); pickleTree(tpe.annot.tree) }
annotatedTypeTrees += tpe.annot.tree
case tpe: AndType =>
writeByte(ANDtype)
withLength { pickleType(tpe.tp1, richTypes); pickleType(tpe.tp2, richTypes) }
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ object PickledQuotes {
if tree.span.exists then
val positionWarnings = new mutable.ListBuffer[Message]()
val reference = ctx.settings.sourceroot.value
PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference,
PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference,
ctx.compilationUnit.source, tree :: Nil, positionWarnings)
positionWarnings.foreach(report.warning(_))

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/transform/Pickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ class Pickler extends Phase {
if tree.span.exists then
val reference = ctx.settings.sourceroot.value
PositionPickler.picklePositions(
pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference,
pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference,
unit.source, tree :: Nil, positionWarnings,
scratch.positionBuffer, scratch.pickledIndices)

Expand Down
6 changes: 4 additions & 2 deletions compiler/test/dotty/tools/dotc/printing/PrintingTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import scala.language.unsafeNulls

import vulpix.FileDiff
import vulpix.TestConfiguration
import vulpix.TestConfiguration
import vulpix.ParallelTesting
import reporting.TestReporter

import java.io._
Expand All @@ -25,7 +25,9 @@ import java.io.File
class PrintingTest {

def options(phase: String, flags: List[String]) =
List(s"-Xprint:$phase", "-color:never", "-nowarn", "-classpath", TestConfiguration.basicClasspath) ::: flags
val outDir = ParallelTesting.defaultOutputDir + "printing" + File.pathSeparator
File(outDir).mkdirs()
List(s"-Xprint:$phase", "-color:never", "-nowarn", "-d", outDir, "-classpath", TestConfiguration.basicClasspath) ::: flags

private def compileFile(path: JPath, phase: String): Boolean = {
val baseFilePath = path.toString.stripSuffix(".scala")
Expand Down
10 changes: 10 additions & 0 deletions tests/pos/annot-17939b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import scala.annotation.Annotation
class myRefined(f: ? => Boolean) extends Annotation

def test(axes: Int) = true

trait Tensor:
def mean(axes: Int): Int @myRefined(_ => test(axes))

class TensorImpl() extends Tensor:
def mean(axes: Int) = ???
9 changes: 9 additions & 0 deletions tests/pos/annot-18064.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//> using options "-Xprint:typer"

class myAnnot[T]() extends annotation.Annotation

trait Tensor[T]:
def add: Tensor[T] @myAnnot[T]()

class TensorImpl[A]() extends Tensor[A]:
def add /* : Tensor[A] @myAnnot[A] */ = this
10 changes: 10 additions & 0 deletions tests/pos/annot-5789.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class Annot[T] extends scala.annotation.Annotation

class D[T](val f: Int@Annot[T])

object A{
def main(a:Array[String]) = {
val c = new D[Int](1)
c.f
}
}
Loading

0 comments on commit c9b9aea

Please sign in to comment.