diff --git a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala index 316acf02d453..a52502364c4d 100644 --- a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -805,11 +805,11 @@ object PatternMatcher { */ private def collectSwitchCases(scrutinee: Tree, plan: SeqPlan): List[(List[Tree], Plan)] = { def isSwitchableType(tpe: Type): Boolean = - (tpe isRef defn.IntClass) || - (tpe isRef defn.ByteClass) || - (tpe isRef defn.ShortClass) || - (tpe isRef defn.CharClass) || - (tpe isRef defn.StringClass) + (tpe <:< defn.IntType) || + (tpe <:< defn.ByteType) || + (tpe <:< defn.ShortType) || + (tpe <:< defn.CharType) || + (tpe <:< defn.StringType) val seen = mutable.Set[Any]() @@ -859,7 +859,7 @@ object PatternMatcher { (Nil, plan) :: Nil } - if (isSwitchableType(scrutinee.tpe.widen)) recur(plan) + if (isSwitchableType(scrutinee.tpe)) recur(plan) else Nil } @@ -880,8 +880,8 @@ object PatternMatcher { */ val (primScrutinee, scrutineeTpe) = - if (scrutinee.tpe.widen.isRef(defn.IntClass)) (scrutinee, defn.IntType) - else if (scrutinee.tpe.widen.isRef(defn.StringClass)) (scrutinee, defn.StringType) + if (scrutinee.tpe <:< defn.IntType) (scrutinee, defn.IntType) + else if (scrutinee.tpe <:< defn.StringType) (scrutinee, defn.StringType) else (scrutinee.select(nme.toInt), defn.IntType) def primLiteral(lit: Tree): Tree = diff --git a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala index 415e5f05487f..1a141dbdb978 100644 --- a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala @@ -158,6 +158,126 @@ class DottyBytecodeTests extends DottyBytecodeTest { } } + @Test def switchOnUnionOfInts = { + val source = + """ + |object Foo { + | def foo(x: 1 | 2 | 3 | 4 | 5) = x match { + | case 1 => println(3) + | case 2 | 3 => println(2) + | case 4 => println(1) + | case 5 => println(0) + | } + |} + """.stripMargin + + checkBCode(source) { dir => + val moduleIn = dir.lookupName("Foo$.class", directory = false) + val moduleNode = loadClassNode(moduleIn.input) + val methodNode = getMethod(moduleNode, "foo") + assert(verifySwitch(methodNode)) + } + } + + @Test def switchOnUnionOfStrings = { + val source = + """ + |object Foo { + | def foo(s: "one" | "two" | "three" | "four" | "five") = s match { + | case "one" => println(3) + | case "two" | "three" => println(2) + | case "four" | "five" => println(1) + | case _ => println(0) + | } + |} + """.stripMargin + + checkBCode(source) { dir => + val moduleIn = dir.lookupName("Foo$.class", directory = false) + val moduleNode = loadClassNode(moduleIn.input) + val methodNode = getMethod(moduleNode, "foo") + assert(verifySwitch(methodNode)) + } + } + + @Test def switchOnUnionOfChars = { + val source = + """ + |object Foo { + | def foo(ch: 'a' | 'b' | 'c' | 'd' | 'e'): Int = ch match { + | case 'a' => 1 + | case 'b' => 2 + | case 'c' => 3 + | case 'd' => 4 + | case 'e' => 5 + | } + |} + """.stripMargin + + checkBCode(source) { dir => + val moduleIn = dir.lookupName("Foo$.class", directory = false) + val moduleNode = loadClassNode(moduleIn.input) + val methodNode = getMethod(moduleNode, "foo") + assert(verifySwitch(methodNode)) + } + } + + @Test def switchOnUnionOfIntSingletons = { + val source = + """ + |object Foo { + | final val One = 1 + | final val Two = 2 + | final val Three = 3 + | final val Four = 4 + | final val Five = 5 + | type Values = One.type | Two.type | Three.type | Four.type | Five.type + | + | def foo(s: Values) = s match { + | case One => println(3) + | case Two | Three => println(2) + | case Four => println(1) + | case Five => println(0) + | } + |} + """.stripMargin + + checkBCode(source) { dir => + val moduleIn = dir.lookupName("Foo$.class", directory = false) + val moduleNode = loadClassNode(moduleIn.input) + val methodNode = getMethod(moduleNode, "foo") + assert(verifySwitch(methodNode)) + } + } + + @Test def switchOnUnionOfStringSingletons = { + val source = + """ + |object Foo { + | final val One = "one" + | final val Two = "two" + | final val Three = "three" + | final val Four = "four" + | final val Five = "five" + | type Values = One.type | Two.type | Three.type | Four.type | Five.type + | + | def foo(s: Values) = s match { + | case One => println(3) + | case Two | Three => println(2) + | case Four => println(1) + | case Five => println(0) + | } + |} + """.stripMargin + + checkBCode(source) { dir => + val moduleIn = dir.lookupName("Foo$.class", directory = false) + val moduleNode = loadClassNode(moduleIn.input) + val methodNode = getMethod(moduleNode, "foo") + assert(verifySwitch(methodNode)) + } + } + @Test def matchWithDefaultNoThrowMatchError = { val source = """class Test {