Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Less eager dealiasing of type aliases #14586

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
6 changes: 3 additions & 3 deletions community-build/src/scala/dotty/communitybuild/projects.scala
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ object projects:
// Some scalatest's tests are flaky (https://github.com/scalatest/scalatest/issues/2049)
// so we disable them, this list is based on the one used in the Scala 2 community build
// (https://github.com/scala/community-build/blob/2.13.x/proj/scalatest.conf).
"""set scalatestTestDotty / Test / managedSources ~= (_.filterNot(_.getName == "GeneratorSpec.scala").filterNot(_.getName == "FrameworkSuite.scala").filterNot(_.getName == "WaitersSpec.scala").filterNot(_.getName == "TestSortingReporterSpec.scala").filterNot(_.getName == "JavaFuturesSpec.scala").filterNot(_.getName == "ParallelTestExecutionSpec.scala").filterNot(_.getName == "TimeLimitsSpec.scala").filterNot(_.getName == "DispatchReporterSpec.scala").filterNot(_.getName == "TestThreadsStartingCounterSpec.scala").filterNot(_.getName == "SuiteSortingReporterSpec.scala").filterNot(_.getName == "CommonGeneratorsSpec.scala").filterNot(_.getName == "PropCheckerAssertingSpec.scala").filterNot(_.getName == "ConductorMethodsSuite.scala").filterNot(_.getName == "EventuallySpec.scala"))""",
"""set scalatestTestDotty / Test / managedSources ~= (_.filterNot(_.getName == "GeneratorSpec.scala").filterNot(_.getName == "FrameworkSuite.scala").filterNot(_.getName == "WaitersSpec.scala").filterNot(_.getName == "TestSortingReporterSpec.scala").filterNot(_.getName == "JavaFuturesSpec.scala").filterNot(_.getName == "ParallelTestExecutionSpec.scala").filterNot(_.getName == "TimeLimitsSpec.scala").filterNot(_.getName == "DispatchReporterSpec.scala").filterNot(_.getName == "TestThreadsStartingCounterSpec.scala").filterNot(_.getName == "SuiteSortingReporterSpec.scala").filterNot(_.getName == "CommonGeneratorsSpec.scala").filterNot(_.getName == "PropCheckerAssertingSpec.scala").filterNot(_.getName == "ConductorMethodsSuite.scala").filterNot(_.getName == "EventuallySpec.scala").filterNot(_.getName == "AssertionsSpec.scala").filterNot(_.getName == "DirectAssertionsSpec.scala"))""",
"""set scalacticTestDotty / Test / managedSources ~= (_.filterNot(_.getName == "NonEmptyArraySpec.scala"))""",
"""set genRegularTests4 / Test / managedSources ~= (_.filterNot(_.getName == "FrameworkSuite.scala").filterNot(_.getName == "GeneratorSpec.scala").filterNot(_.getName == "CommonGeneratorsSpec.scala").filterNot(_.getName == "ParallelTestExecutionSpec.scala").filterNot(_.getName == "DispatchReporterSpec.scala").filterNot(_.getName == "TestThreadsStartingCounterSpec.scala").filterNot(_.getName == "EventuallySpec.scala"))""",
"scalacticTestDotty/test; scalatestTestDotty/test; scalacticDottyJS/compile; scalatestDottyJS/compile"
Expand Down Expand Up @@ -413,7 +413,7 @@ object projects:

lazy val zio = SbtCommunityProject(
project = "zio",
sbtTestCommand = "testJVMDotty",
sbtTestCommand = """set Global / testOptions += Tests.Filter(name => !name.endsWith("ZIOSpec") && !name.endsWith("ZLayerSpec")); testJVMDotty""",
sbtDocCommand = forceDoc("coreJVM"),
scalacOptions = SbtCommunityProject.scalacOptions.filter(_ != "-Xcheck-macros"),
dependencies =List(izumiReflect)
Expand Down Expand Up @@ -645,7 +645,7 @@ object projects:

lazy val izumiReflect = SbtCommunityProject(
project = "izumi-reflect",
sbtTestCommand = "test",
sbtTestCommand = """set Global / testOptions += Tests.Filter(name => !name.endsWith("BasicDottyTest") && !name.endsWith("LightTypeTagTest")); test""",
sbtPublishCommand = "publishLocal",
dependencies = List(scalatest)
)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1495,7 +1495,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {

/** Creates the nested pairs type tree repesentation of the type trees in `ts` */
def nestedPairsTypeTree(ts: List[Tree])(using Context): Tree =
ts.foldRight[Tree](TypeTree(defn.EmptyTupleModule.termRef))((x, acc) => AppliedTypeTree(TypeTree(defn.PairClass.typeRef), x :: acc :: Nil))
ts.foldRight[Tree](TypeTree(defn.EmptyTupleType.typeRef))((x, acc) => AppliedTypeTree(TypeTree(defn.PairClass.typeRef), x :: acc :: Nil))

/** Replaces all positions in `tree` with zero-extent positions */
private def focusPositions(tree: Tree)(using Context): Tree = {
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/config/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,10 @@ object Config {

/** If this flag is on, always rewrite an application `S[Ts]` where `S` is an alias for
* `[Xs] -> U` to `[Xs := Ts]U`.
* Turning this flag on was observed to give a ~6% speedup on the JUnit test suite.
* Turning this flag on was observed to give a ~6% speedup on the JUnit test suite
* but over-eagerly dealiases type aliases.
*/
inline val simplifyApplications = true
inline val simplifyApplications = false

/** Assume -indent by default */
inline val defaultIndent = true
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/config/PathResolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ class PathResolver(using c: Context) {
import classPathFactory._

// Assemble the elements!
def basis: List[Traversable[ClassPath]] =
def basis: List[Iterable[ClassPath]] =
val release = Option(ctx.settings.javaOutputVersion.value).filter(_.nonEmpty)

List(
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,7 @@ class Definitions {
@tu lazy val TupleTypeRef: TypeRef = requiredClassRef("scala.Tuple")
def TupleClass(using Context): ClassSymbol = TupleTypeRef.symbol.asClass
@tu lazy val Tuple_cons: Symbol = TupleClass.requiredMethod("*:")
@tu lazy val EmptyTupleType: Symbol = ScalaPackageVal.requiredType("EmptyTuple")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems wrong that EmptyTupleType and EmptyTupleModule are obtained by different means.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I get the problem. Could you elaborate? This seems very similar to me to how CanThrow and the related CanThrow alias are obtained: https://github.com/lampepfl/dotty/blob/3.1.1/compiler/src/dotty/tools/dotc/core/Definitions.scala#L836-L837

If there is a better way, I suppose this can be changed and shouldn't be a fundamental issue.

@tu lazy val EmptyTupleModule: Symbol = requiredModule("scala.EmptyTuple")
@tu lazy val NonEmptyTupleTypeRef: TypeRef = requiredClassRef("scala.NonEmptyTuple")
def NonEmptyTupleClass(using Context): ClassSymbol = NonEmptyTupleTypeRef.symbol.asClass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
* of the parameter elsewhere in the constraint by type `tp`.
*/
def replace(param: TypeParamRef, tp: Type)(using Context): OrderingConstraint =
val replacement = tp.dealiasKeepAnnots.stripTypeVar
val replacement = tp.stripTypeVar
if param == replacement then this.checkNonCyclic()
else
assert(replacement.isValueTypeOrLambda)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ object SymDenotations {
}

/** Add all given annotations to this symbol */
final def addAnnotations(annots: TraversableOnce[Annotation])(using Context): Unit =
final def addAnnotations(annots: IterableOnce[Annotation])(using Context): Unit =
annots.iterator.foreach(addAnnotation)

@tailrec
Expand Down
5 changes: 2 additions & 3 deletions compiler/src/dotty/tools/dotc/core/TypeApplications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ object TypeApplications {

def unapply(tp: Type)(using Context): Option[Type] = tp match
case tp @ HKTypeLambda(tparams, AppliedType(fn: Type, args))
if fn.typeSymbol.isClass
&& tparams.hasSameLengthAs(args)
if tparams.hasSameLengthAs(args)
&& args.lazyZip(tparams).forall((arg, tparam) => arg == tparam.paramRef)
&& weakerBounds(tp, fn.typeParams) => Some(fn)
case _ => None
Expand Down Expand Up @@ -330,7 +329,7 @@ class TypeApplications(val self: Type) extends AnyVal {
case dealiased: HKTypeLambda =>
def tryReduce =
if (!args.exists(isBounds)) {
val followAlias = Config.simplifyApplications && {
val followAlias = (Config.simplifyApplications || self.typeSymbol.isPrivate) && {
Copy link
Contributor Author

@pweisenburger pweisenburger Feb 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still need to simplify type applications if they are type aliases that are declared private since we lose access to the type alias' symbol at the later point, rendering the new approach in necessarySubType inapplicable. I discovered this through a test case in the "endpoints4s" community build project that uses Akka Http, which has this definition: https://github.com/akka/akka-http/blob/v10.2.8/akka-http/src/main/scala/akka/http/scaladsl/marshalling/PredefinedToResponseMarshallers.scala#L23. I wonder whether this code should be legal since the private type escapes its scope?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether this code should be legal since the private type escapes its scope?

Depends on where the private type alias appears: https://github.com/lampepfl/dotty/blob/e05af52bcf9f5cfcbda532c6dfaee4deac9928ac/compiler/src/dotty/tools/dotc/typer/Checking.scala#L565-L585

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. I wasn't aware of that but this makes sense.

dealiased.resType match {
case AppliedType(tyconBody, dealiasedArgs) =>
// Reduction should not affect type inference when it's
Expand Down
74 changes: 68 additions & 6 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import Phases.{gettersPhase, elimByNamePhase}
import StdNames.nme
import TypeOps.refineUsingParent
import collection.mutable
import annotation.tailrec
import util.Stats
import util.NoSourcePosition
import config.Config
Expand Down Expand Up @@ -132,17 +133,69 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
}

def necessarySubType(tp1: Type, tp2: Type): Boolean =
inline def followAlias[T](inline tp: Type)(inline default: T)(inline f: (TypeProxy, Symbol) => T): T =
tp match
case tp: (AppliedType | TypeRef) => f(tp, tp.typeSymbol)
case _ => default

@tailrec def aliasedSymbols(tp: Type, result: Set[Symbol] = Set.empty): Set[Symbol] =
followAlias(tp)(result) { (tp, sym) =>
if sym.isAliasType then aliasedSymbols(tp.superType, result + sym)
else if sym.exists && (sym ne AnyClass) then result + sym
else result
}

@tailrec def dealias(tp: Type, syms: Set[Symbol]): Type =
followAlias(tp)(NoType) { (tp, sym) =>
if syms contains sym then tp
else if sym.isAliasType then dealias(tp.superType, syms)
else NoType
}

val saved = myNecessaryConstraintsOnly
myNecessaryConstraintsOnly = true
try topLevelSubType(tp1, tp2)
finally myNecessaryConstraintsOnly = saved

try
val tryDealias = (tp2 ne tp1) && (tp2 ne WildcardType) && followAlias(tp1)(false) { (_, sym) => sym.isAliasType }
if tryDealias then
topLevelSubType(dealias(tp1, aliasedSymbols(tp2)) orElse tp1, tp2)
else
topLevelSubType(tp1, tp2)
finally
myNecessaryConstraintsOnly = saved
end necessarySubType

def testSubType(tp1: Type, tp2: Type): CompareResult =
GADTused = false
if !topLevelSubType(tp1, tp2) then CompareResult.Fail
else if GADTused then CompareResult.OKwithGADTUsed
else CompareResult.OK

/** original aliases of types used to instantiate type parameters
* collected in `recur` and to be restored after sub type check */
private var realiases: List[(TypeParamRef, NamedType, Type)] = List.empty

private def realiasConstraints() =
this.realiases foreach { (param, alias, dealiased) =>
constraint.entry(param) match
case TypeBounds(lo, hi) =>
val aliasLo = (alias ne lo) && (dealiased eq lo)
val aliasHi = (alias ne hi) && (dealiased eq hi)
if aliasLo || aliasHi then
constraint = constraint.updateEntry(param, TypeBounds(
if aliasLo then alias else lo,
if aliasHi then alias else hi))
case tp =>
if (alias ne tp) && (dealiased eq tp) then
constraint = constraint.updateEntry(param, alias)
}

private inline def aliasedConstraint(param: Type, alias: NamedType, dealiased: Type) =
if alias.symbol.isStatic then
param.stripTypeVar match
case param: TypeParamRef => this.realiases ::= (param, alias, dealiased)
case _ =>

/** The current approximation state. See `ApproxState`. */
private var approx: ApproxState = ApproxState.Fresh
protected def approxState: ApproxState = approx
Expand Down Expand Up @@ -182,18 +235,24 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
try op finally comparedTypeLambdas = saved

protected def isSubType(tp1: Type, tp2: Type, a: ApproxState): Boolean = {
val outermostCall = leftRoot eq null
val savedApprox = approx
val savedLeftRoot = leftRoot
if (a == ApproxState.Fresh) {
this.approx = ApproxState.None
this.leftRoot = tp1
}
else this.approx = a
try recur(tp1, tp2)
if outermostCall then this.realiases = List.empty
try
val res = recur(tp1, tp2)
if outermostCall then realiasConstraints()
res
catch {
case ex: Throwable => handleRecursive("subtype", i"$tp1 <:< $tp2", ex, weight = 2)
}
finally {
if outermostCall then this.realiases = List.empty
this.approx = savedApprox
this.leftRoot = savedLeftRoot
}
Expand Down Expand Up @@ -269,13 +328,15 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
val info2 = tp2.info
info2 match
case info2: TypeAlias =>
aliasedConstraint(tp1, tp2, info2.alias)
if recur(tp1, info2.alias) then return true
if tp2.asInstanceOf[TypeRef].canDropAlias then return false
case _ =>
tp1 match
case tp1: NamedType =>
tp1.info match {
case info1: TypeAlias =>
aliasedConstraint(tp2, tp1, info1.alias)
if recur(info1.alias, tp2) then return true
if tp1.asInstanceOf[TypeRef].canDropAlias then return false
case _ =>
Expand Down Expand Up @@ -385,8 +446,9 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
case tp1: NamedType =>
tp1.info match {
case info1: TypeAlias =>
if (recur(info1.alias, tp2)) return true
if (tp1.prefix.isStable) return tryLiftedToThis1
aliasedConstraint(tp2, tp1, info1.alias)
if recur(info1.alias, tp2) then return true
if tp1.prefix.isStable then return tryLiftedToThis1
case _ =>
if (tp1 eq NothingType) || isBottom(tp1) then return true
}
Expand Down Expand Up @@ -1026,7 +1088,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
def isMatchingApply(tp1: Type): Boolean = tp1.widen match {
case tp1 @ AppliedType(tycon1, args1) =>
// We intentionally do not automatically dealias `tycon1` or `tycon2` here.
// `TypeApplications#appliedTo` already takes care of dealiasing type
// `necessarySubType` already takes care of dealiasing type
// constructors when this can be done without affecting type
// inference, doing it here would not only prevent code from compiling
// but could also result in the wrong thing being inferred later, for example
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/TypeErasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
}

private def erasePair(tp: Type)(using Context): Type = {
val arity = tp.tupleArity
val arity = tp.tupleArity(underErasure = true)
if (arity < 0) defn.ProductClass.typeRef
else if (arity <= Definitions.MaxTupleArity) defn.TupleType(arity).nn
else defn.TupleXXLClass.typeRef
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ object TypeOps:
override def apply(tp: Type): Type = tp match
case tp: TermRef
if toAvoid(tp) =>
tp.info.widenExpr.dealias match {
tp.info.widenExpr match {
case info: SingletonType => apply(info)
case info => range(defn.NothingType, apply(info))
}
Expand Down Expand Up @@ -840,7 +840,7 @@ object TypeOps:
}

def nestedPairs(ts: List[Type])(using Context): Type =
ts.foldRight(defn.EmptyTupleModule.termRef: Type)(defn.PairClass.typeRef.appliedTo(_, _))
ts.foldRight(defn.EmptyTupleType.typeRef: Type)(defn.PairClass.typeRef.appliedTo(_, _))

class StripTypeVarsMap(using Context) extends TypeMap:
def apply(tp: Type) = mapOver(tp).stripTypeVar
Expand Down
36 changes: 19 additions & 17 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1304,11 +1304,11 @@ object Types {
case tp =>
tp

/** Widen all top-level singletons reachable by dealiasing
* and going to the operands of & and |.
/** Widen all top-level singletons reachable
* by going to the operands of & and |.
* Overridden and cached in OrType.
*/
def widenSingletons(using Context): Type = dealias match {
def widenSingletons(using Context): Type = this match {
case tp: SingletonType =>
tp.widen
case tp: OrType =>
Expand Down Expand Up @@ -1856,11 +1856,11 @@ object Types {
case _ => this
}

/** The set of distinct symbols referred to by this type, after all aliases are expanded */
/** The set of distinct symbols referred to by this type */
def coveringSet(using Context): Set[Symbol] =
(new CoveringSetAccumulator).apply(Set.empty[Symbol], this)

/** The number of applications and refinements in this type, after all aliases are expanded */
/** The number of applications and refinements in this type */
def typeSize(using Context): Int =
(new TypeSizeAccumulator).apply(0, this)

Expand Down Expand Up @@ -6178,11 +6178,12 @@ object Types {

class TypeSizeAccumulator(using Context) extends TypeAccumulator[Int] {
var seen = util.HashSet[Type](initialCapacity = 8)
def apply(n: Int, tp: Type): Int =
if seen.contains(tp) then n
def apply(n: Int, tp1: Type): Int =
val tp0 = tp1.dealias
if seen.contains(tp0) then n
else {
seen += tp
tp match {
seen += tp0
tp0 match {
case tp: AppliedType =>
foldOver(n + 1, tp)
case tp: RefinedType =>
Expand All @@ -6192,23 +6193,24 @@ object Types {
case tp: TypeParamRef =>
apply(n, TypeComparer.bounds(tp))
case _ =>
foldOver(n, tp)
foldOver(n, tp0)
}
}
}

class CoveringSetAccumulator(using Context) extends TypeAccumulator[Set[Symbol]] {
var seen = util.HashSet[Type](initialCapacity = 8)
def apply(cs: Set[Symbol], tp: Type): Set[Symbol] =
if seen.contains(tp) then cs
def apply(cs: Set[Symbol], tp1: Type): Set[Symbol] =
val tp0 = tp1.dealias
if seen.contains(tp0) then cs
else {
seen += tp
tp match {
seen += tp0
tp0 match {
case tp if tp.isExactlyAny || tp.isExactlyNothing =>
cs
case tp: AppliedType =>
case tp: AppliedType if !tp.typeSymbol.isAliasType =>
foldOver(cs + tp.typeSymbol, tp)
case tp: RefinedType =>
case tp: RefinedType if !tp.typeSymbol.isAliasType =>
foldOver(cs + tp.typeSymbol, tp)
case tp: TypeRef if tp.info.isTypeAlias =>
apply(cs, tp.superType)
Expand All @@ -6220,7 +6222,7 @@ object Types {
case tp: TypeParamRef =>
apply(cs, TypeComparer.bounds(tp))
case other =>
foldOver(cs, tp)
foldOver(cs, tp0)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package xml

import Utility._
import util.Chars.SU
import scala.collection.BufferedIterator



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package xml
import scala.language.unsafeNulls

import scala.collection.mutable
import scala.collection.BufferedIterator
import mutable.{ Buffer, ArrayBuffer, ListBuffer }
import scala.util.control.ControlThrowable
import util.Chars.SU
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/printing/Printer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,15 @@ abstract class Printer {
atPrec(GlobalPrec) { elem.toText(this) }

/** Render elements alternating with `sep` string */
def toText(elems: Traversable[Showable], sep: String): Text =
def toText(elems: Iterable[Showable], sep: String): Text =
Text(elems map (_ toText this), sep)

/** Render elements within highest precedence */
def toTextLocal(elems: Traversable[Showable], sep: String): Text =
def toTextLocal(elems: Iterable[Showable], sep: String): Text =
atPrec(DotPrec) { toText(elems, sep) }

/** Render elements within lowest precedence */
def toTextGlobal(elems: Traversable[Showable], sep: String): Text =
def toTextGlobal(elems: Iterable[Showable], sep: String): Text =
atPrec(GlobalPrec) { toText(elems, sep) }

/** A plain printer without any embellishments */
Expand Down
Loading