Skip to content

Commit

Permalink
Backport "Patmat: Use less type variables in prefix inference" (#17440)
Browse files Browse the repository at this point in the history
Backports #16827

Solves #17368
  • Loading branch information
Kordyjan authored May 9, 2023
2 parents 597144e + 3b9b83d commit 752ad2f
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 16 deletions.
53 changes: 38 additions & 15 deletions compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package dotty.tools
package dotc
package core

import Contexts._, Types._, Symbols._, Names._, Flags._
import Contexts._, Types._, Symbols._, Names._, NameKinds.*, Flags._
import SymDenotations._
import util.Spans._
import util.Stats
Expand Down Expand Up @@ -839,40 +839,63 @@ object TypeOps:
}
}

// Prefix inference, replace `p.C.this.Child` with `X.Child` where `X <: p.C`
// Note: we need to strip ThisType in `p` recursively.
/** Gather GADT symbols and `ThisType`s found in `tp2`, ie. the scrutinee. */
object TraverseTp2 extends TypeTraverser:
val thisTypes = util.HashSet[ThisType]()
val gadtSyms = new mutable.ListBuffer[Symbol]

def traverse(tp: Type) = {
val tpd = tp.dealias
if tpd ne tp then traverse(tpd)
else tp match
case tp: ThisType if !tp.tref.symbol.isStaticOwner && !thisTypes.contains(tp) =>
thisTypes += tp
traverseChildren(tp.tref)
case tp: TypeRef if tp.symbol.isAbstractOrParamType =>
gadtSyms += tp.symbol
traverseChildren(tp)
case _ =>
traverseChildren(tp)
}
TraverseTp2.traverse(tp2)
val thisTypes = TraverseTp2.thisTypes
val gadtSyms = TraverseTp2.gadtSyms.toList

// Prefix inference, given `p.C.this.Child`:
// 1. return it as is, if `C.this` is found in `tp`, i.e. the scrutinee; or
// 2. replace it with `X.Child` where `X <: p.C`, stripping ThisType in `p` recursively.
//
// See tests/patmat/i3938.scala
// See tests/patmat/i3938.scala, tests/pos/i15029.more.scala, tests/pos/i16785.scala
class InferPrefixMap extends TypeMap {
var prefixTVar: Type | Null = null
def apply(tp: Type): Type = tp match {
case ThisType(tref: TypeRef) if !tref.symbol.isStaticOwner =>
case tp @ ThisType(tref) if !tref.symbol.isStaticOwner =>
val symbol = tref.symbol
if (symbol.is(Module))
if thisTypes.contains(tp) then
prefixTVar = tp // e.g. tests/pos/i16785.scala, keep Outer.this
prefixTVar.uncheckedNN
else if symbol.is(Module) then
TermRef(this(tref.prefix), symbol.sourceModule)
else if (prefixTVar != null)
this(tref)
else {
prefixTVar = WildcardType // prevent recursive call from assigning it
val tvars = tref.typeParams.map { tparam => newTypeVar(tparam.paramInfo.bounds) }
// e.g. tests/pos/i15029.more.scala, create a TypeVar for `Instances`' B, so we can disregard `Ints`
val tvars = tref.typeParams.map { tparam => newTypeVar(tparam.paramInfo.bounds, DepParamName.fresh(tparam.paramName)) }
val tref2 = this(tref.applyIfParameterized(tvars))
prefixTVar = newTypeVar(TypeBounds.upper(tref2))
prefixTVar = newTypeVar(TypeBounds.upper(tref2), DepParamName.fresh(tref.name))
prefixTVar.uncheckedNN
}
case tp => mapOver(tp)
}
}

val inferThisMap = new InferPrefixMap
val tvars = tp1.typeParams.map { tparam => newTypeVar(tparam.paramInfo.bounds) }
val tvars = tp1.typeParams.map { tparam => newTypeVar(tparam.paramInfo.bounds, DepParamName.fresh(tparam.paramName)) }
val protoTp1 = inferThisMap.apply(tp1).appliedTo(tvars)

val getAbstractSymbols = new TypeAccumulator[List[Symbol]]:
def apply(xs: List[Symbol], tp: Type) = tp.dealias match
case tp: TypeRef if tp.symbol.exists && !tp.symbol.isClass => foldOver(tp.symbol :: xs, tp)
case tp => foldOver(xs, tp)
val syms2 = getAbstractSymbols(Nil, tp2).reverse
if syms2.nonEmpty then ctx.gadtState.addToConstraint(syms2)
if gadtSyms.nonEmpty then
ctx.gadtState.addToConstraint(gadtSyms)

// If parent contains a reference to an abstract type, then we should
// refine subtype checking to eliminate abstract types according to
Expand Down
2 changes: 1 addition & 1 deletion tests/patmat/i12408.check
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
13: Pattern Match Exhaustivity: X[<?>] & (X.this : X[T]).A(_), X[<?>] & (X.this : X[T]).C(_)
13: Pattern Match Exhaustivity: A(_), C(_)
21: Pattern Match
11 changes: 11 additions & 0 deletions tests/pos/i16785.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
class VarImpl[Lbl, A]

class Outer[|*|[_, _], Lbl1]:
type Var[A1] = VarImpl[Lbl1, A1]

sealed trait Foo[G]
case class Bar[T, U]()
extends Foo[Var[T] |*| Var[U]]

def go[X](scr: Foo[Var[X]]): Unit = scr match // was: compile hang
case Bar() => ()

0 comments on commit 752ad2f

Please sign in to comment.