Skip to content

Commit

Permalink
Use a different scheme for deciding when to use the default
Browse files Browse the repository at this point in the history
We now don't try to instantiate selected type variables. Instead, we use a default as fallback if
the expected type is underspecified according to the definition in Implicits. This is
simpler and more expressive.
  • Loading branch information
odersky committed Nov 25, 2024
1 parent 56f849b commit a728906
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 37 deletions.
26 changes: 13 additions & 13 deletions compiler/src/dotty/tools/dotc/typer/Implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,19 @@ object Implicits:
def msg(using Context): Message =
em"${errors.map(_.msg).mkString("\n")}"
}

private def isUnderSpecifiedArgument(tp: Type)(using Context): Boolean =
tp.isRef(defn.NothingClass) || tp.isRef(defn.NullClass) || (tp eq NoPrefix)

def isUnderspecified(tp: Type)(using Context): Boolean = tp.stripTypeVar match
case tp: WildcardType =>
!tp.optBounds.exists || isUnderspecified(tp.optBounds.hiBound)
case tp: ViewProto =>
isUnderspecified(tp.resType)
|| tp.resType.isRef(defn.UnitClass)
|| isUnderSpecifiedArgument(tp.argType.widen)
case _ =>
tp.isAny || tp.isAnyRef
end Implicits

import Implicits.*
Expand Down Expand Up @@ -1665,19 +1678,6 @@ trait Implicits:
res
end searchImplicit

def isUnderSpecifiedArgument(tp: Type): Boolean =
tp.isRef(defn.NothingClass) || tp.isRef(defn.NullClass) || (tp eq NoPrefix)

private def isUnderspecified(tp: Type): Boolean = tp.stripTypeVar match
case tp: WildcardType =>
!tp.optBounds.exists || isUnderspecified(tp.optBounds.hiBound)
case tp: ViewProto =>
isUnderspecified(tp.resType)
|| tp.resType.isRef(defn.UnitClass)
|| isUnderSpecifiedArgument(tp.argType.widen)
case _ =>
tp.isAny || tp.isAnyRef

/** Search implicit in context `ctxImplicits` or else in implicit scope
* of expected type if `ctxImplicits == null`.
*/
Expand Down
33 changes: 13 additions & 20 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2420,32 +2420,25 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
// is either a given instance of type ExpressibleAsCollectionLiteralClass
// or a default instance. The default instance is either Seq or Map,
// depending on the forms of `tree.elems`. We search for a type class if
// the expected type is a value type that is not an uninstantiated type variable.
// the expected type is a value type that is not underspeficied for implicit search.
val maker = pt match
case pt: TypeVar if !pt.isInstantiated =>
defaultMaker
case pt: ValueType =>
case pt: ValueType if !Implicits.isUnderspecified(wildApprox(pt)) =>
val tc = defn.ExpressibleAsCollectionLiteralClass.typeRef.appliedTo(pt)
val nestedCtx = ctx.fresh.setNewTyperState()
val maker = inContext(nestedCtx):
// Find given instance `witness` of type `ExpressibleAsCollectionLiteral[<pt>]`
val witness = inferImplicitArg(tc, tree.span.startPos)
if witness.tpe.isInstanceOf[SearchFailureType] then
val msg = missingArgMsg(witness, pt, "")
if isAmbiguousGiven(witness) then report.error(msg, tree.srcPos)
else typr.println(i"failed collection literal witness: ${msg.toString}")
// Find given instance `witness` of type `ExpressibleAsCollectionLiteral[<pt>]`
val witness = inferImplicitArg(tc, tree.span.startPos)
def errMsg = missingArgMsg(witness, pt, "")
typr.println(i"infer for $tree with $tc = $witness, ${ctx.typerState.constraint}")
witness.tpe match
case _: AmbiguousImplicits =>
report.error(errMsg, tree.srcPos)
defaultMaker
else
// Instantiate local type variables in witness.tpe, so that nested
// SeqLiterals don't get typed as default Seq due to first case above
def isLocal(tv: TypeVar) =
val state = tv.owningState
state != null && (state.get eq ctx.typerState)
instantiateSelected(witness.tpe, isLocal, minimize = false)
case _: SearchFailureType =>
typr.println(i"failed collection literal witness: ${errMsg.toString}")
defaultMaker
case _ =>
// Continue with typing `witness.fromLiteral` as the constructor
untpd.TypedSplice(witness.select(nme.fromLiteral))
nestedCtx.typerState.commit()
maker
case _ =>
defaultMaker
typed(
Expand Down
20 changes: 20 additions & 0 deletions library/src/scala/compiletime/ExpressibleAsCollectionLiteral.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,26 @@ import reflect.ClassTag
type Elem = T
inline def fromLiteral(inline xs: T*) = IArray(xs*)

given arrayBufferFromLiteral: [T: ClassTag] => ExpressibleAsCollectionLiteral[collection.mutable.ArrayBuffer[T]]:
type Elem = T
inline def fromLiteral(inline xs: T*) = collection.mutable.ArrayBuffer(xs*)

given setFromLiteral: [T] => ExpressibleAsCollectionLiteral[Set[T]]:
type Elem = T
inline def fromLiteral(inline xs: T*) = Set(xs*)

given hashSetFromLiteral: [T] => ExpressibleAsCollectionLiteral[collection.mutable.HashSet[T]]:
type Elem = T
inline def fromLiteral(inline xs: T*) = collection.mutable.HashSet(xs*)

given bitsetFromLiteral: ExpressibleAsCollectionLiteral[collection.immutable.BitSet]:
type Elem = Int
inline def fromLiteral(inline xs: Int*) = collection.immutable.BitSet(xs*)

given mapFromLiteral: [K, V] => ExpressibleAsCollectionLiteral[Map[K, V]]:
type Elem = (K, V)
inline def fromLiteral(inline xs: (K, V)*) = Map(xs*)

given hashMapFromLiteral: [K, V] => ExpressibleAsCollectionLiteral[collection.mutable.HashMap[K, V]]:
type Elem = (K, V)
inline def fromLiteral(inline xs: (K, V)*) = collection.mutable.HashMap(xs*)
29 changes: 25 additions & 4 deletions tests/neg/seqlits.check
Original file line number Diff line number Diff line change
@@ -1,11 +1,32 @@
-- [E172] Type Error: tests/neg/seqlits.scala:20:14 --------------------------------------------------------------------
20 | val x: A = [1, 2, 3] // error: ambiguous
-- [E172] Type Error: tests/neg/seqlits.scala:23:14 --------------------------------------------------------------------
23 | val x: A = [1, 2, 3] // error: ambiguous
| ^^^^^^^
|Ambiguous given instances: both given instance given_ExpressibleAsCollectionLiteral_B in object SeqLits and given instance given_ExpressibleAsCollectionLiteral_C in object SeqLits match type scala.compiletime.ExpressibleAsCollectionLiteral[A]
-- [E007] Type Mismatch Error: tests/neg/seqlits.scala:21:14 -----------------------------------------------------------
21 | val y: D = [1, 2, 3] // error: type mismatch
-- [E007] Type Mismatch Error: tests/neg/seqlits.scala:24:14 -----------------------------------------------------------
24 | val y: D = [1, 2, 3] // error: type mismatch
| ^^^^^^^
| Found: Seq[Int]
| Required: D
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg/seqlits.scala:26:39 -----------------------------------------------------------
26 | val mbss: Map[BitSet, Seq[Int]] = [[1] -> [1], [0, 2] -> [1, 2], [0] -> []] // error: type mismatch // error // error
| ^^^^^^^^
| Found: (Seq[Int], Seq[Int])
| Required: (scala.collection.immutable.BitSet, Seq[Int])
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg/seqlits.scala:26:51 -----------------------------------------------------------
26 | val mbss: Map[BitSet, Seq[Int]] = [[1] -> [1], [0, 2] -> [1, 2], [0] -> []] // error: type mismatch // error // error
| ^^^^^^^^^^^^^^
| Found: (Seq[Int], Seq[Int])
| Required: (scala.collection.immutable.BitSet, Seq[Int])
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg/seqlits.scala:26:69 -----------------------------------------------------------
26 | val mbss: Map[BitSet, Seq[Int]] = [[1] -> [1], [0, 2] -> [1, 2], [0] -> []] // error: type mismatch // error // error
| ^^^^^
| Found: (Seq[Int], Seq[Nothing])
| Required: (scala.collection.immutable.BitSet, Seq[Int])
|
| longer explanation available when compiling with `-explain`
5 changes: 5 additions & 0 deletions tests/neg/seqlits.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import language.`3.7`
import compiletime.ExpressibleAsCollectionLiteral
import language.experimental.collectionLiterals
import collection.immutable.BitSet

class A

Expand All @@ -19,3 +22,5 @@ object SeqLits:

val x: A = [1, 2, 3] // error: ambiguous
val y: D = [1, 2, 3] // error: type mismatch

val mbss: Map[BitSet, Seq[Int]] = [[1] -> [1], [0, 2] -> [1, 2], [0] -> []] // error: type mismatch // error // error
5 changes: 5 additions & 0 deletions tests/run/seqlits.check
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,8 @@ Map(1 -> one, 2 -> two)
List(List(1), List(2, 3), List())
List(List(1), List(2, 3), List())
Vector(Vector(1), Vector(2, 3), Vector())
ArrayBuffer(Set(hello, world), Set())
Set(BitSet(1), BitSet(2, 3), BitSet())
Map(1 -> BitSet(1), 2 -> BitSet(1, 2), 0 -> BitSet())
HashMap(0 -> List(), 1 -> List(BitSet(1), BitSet(2, 3)), 2 -> List(BitSet()))
Map(BitSet(1) -> List(1), BitSet(0, 2) -> List(1, 2), BitSet(0) -> List())
13 changes: 13 additions & 0 deletions tests/run/seqlits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import language.`3.7`
import reflect.ClassTag
import compiletime.ExpressibleAsCollectionLiteral
import collection.immutable.BitSet
import collection.mutable.{ArrayBuffer, HashMap}
import language.experimental.collectionLiterals

/** Some delayed computation like a Mill Task */
case class Task[T](body: () => T)
Expand Down Expand Up @@ -33,6 +35,12 @@ object SeqLits:
val ss2 = [[1], [2, 3], []]
val _: Seq[Seq[Int]] = ss2
val vs: Vector[Vector[Int]] = [[1], [2, 3], []]
val ab: ArrayBuffer[Set[String]] = [["hello", "world"], []]
val sbs: Set[BitSet] = [[1], [2, 3], []]
val mbs: Map[Int, BitSet] = [1 -> [1], 2 -> [1, 2], 0 -> []]
val hbs: HashMap[Int, Seq[BitSet]] = [1 -> [[1], [2, 3]], 2 -> [[]], 0 -> []]
// val mbss: Map[BitSet, Seq[Int]] = [[1] -> [1], [0, 2] -> [1, 2], [0] -> []] // error: keys get default value Seq
val mbss: Map[BitSet, Seq[Int]] = [([1], [1]), ([0, 2], [1, 2]), ([0], [])] // ok

println(s"Seq $s")
println(s"Vector $v")
Expand All @@ -48,3 +56,8 @@ object SeqLits:
println(ss)
println(ss2)
println(vs)
println(ab)
println(sbs)
println(mbs)
println(hbs)
println(mbss)

0 comments on commit a728906

Please sign in to comment.