Skip to content

Commit e9a6eea

Browse files
committed
Fix detection of ill-nested awaits
- Adapt detection logic to the post-erasure tree shapes - Lift restriction about || and && in favour of a rewrite to `If` in the ANF transform - import relevant parts of the test from scala-async.
1 parent d425b2f commit e9a6eea

File tree

10 files changed

+165
-34
lines changed

10 files changed

+165
-34
lines changed

src/compiler/scala/tools/nsc/transform/UnCurry.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@ package tools.nsc
1515
package transform
1616

1717
import scala.annotation.tailrec
18-
1918
import symtab.Flags._
2019
import scala.collection.mutable
2120
import scala.collection.mutable.ListBuffer
2221
import scala.reflect.internal.util.ListOfNil
23-
2422
import PartialFunction.cond
23+
import scala.reflect.NameTransformer
2524

2625
/*<export> */
2726
/** - uncurry all symbol and tree types (@see UnCurryPhase) -- this includes normalizing all proper types.
@@ -149,7 +148,7 @@ abstract class UnCurry extends InfoTransform
149148
/** Return non-local return key for given method */
150149
private def nonLocalReturnKey(meth: Symbol) =
151150
nonLocalReturnKeys.getOrElseUpdate(meth,
152-
meth.newValue(unit.freshTermName("nonLocalReturnKey"), meth.pos, SYNTHETIC) setInfo ObjectTpe
151+
meth.newValue(unit.freshTermName(nme.NON_LOCAL_RETURN_KEY_STRING), meth.pos, SYNTHETIC) setInfo ObjectTpe
153152
)
154153

155154
/** Generate a non-local return throw with given return expression from given method.

src/compiler/scala/tools/nsc/transform/async/AnfTransform.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,13 @@ private[async] trait AnfTransform extends TransformUtils {
5252
tree match {
5353
case _: ClassDef | _: ModuleDef | _: Function | _: DefDef =>
5454
tree
55-
case _: RefTree if tree.symbol.hasPackageFlag =>
56-
tree
5755
case _ if !treeContainsAwait =>
5856
tree
59-
case Apply(fun, args) if !isBooleanShortCircuit(fun.symbol) =>
57+
case Apply(sel @ Select(fun, _), arg :: Nil) if isBooleanAnd(sel.symbol) && containsAwait(arg) =>
58+
transform(treeCopy.If(tree, fun, arg, literalBool(false)))
59+
case Apply(sel @ Select(fun, _), arg :: Nil) if isBooleanOr(sel.symbol) && containsAwait(arg) =>
60+
transform(treeCopy.If(tree, fun, literalBool(true), arg))
61+
case Apply(fun, args) =>
6062
val lastAwaitArgIndex: Int = args.lastIndexWhere(containsAwait)
6163
val simpleFun = transform(fun)
6264
var i = 0

src/compiler/scala/tools/nsc/transform/async/AsyncAnalysis.scala

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
package scala.tools.nsc.transform.async
1414

1515
import scala.collection.mutable.ListBuffer
16-
import scala.reflect.internal.Flags
16+
import scala.reflect.NameTransformer
1717

1818
trait AsyncAnalysis extends TransformUtils {
1919
import global._
@@ -35,8 +35,8 @@ trait AsyncAnalysis extends TransformUtils {
3535
reportUnsupportedAwait(classDef, s"nested $kind")
3636
}
3737

38-
override def nestedModule(module: ModuleDef): Unit = {
39-
reportUnsupportedAwait(module, "nested object")
38+
override def nestedModuleClass(moduleClass: ClassDef): Unit = {
39+
reportUnsupportedAwait(moduleClass, "nested object")
4040
}
4141

4242
override def nestedMethod(defDef: DefDef): Unit = {
@@ -50,24 +50,19 @@ trait AsyncAnalysis extends TransformUtils {
5050
override def function(function: Function): Unit = {
5151
reportUnsupportedAwait(function, "nested function")
5252
}
53-
54-
override def patMatFunction(tree: Match): Unit = {
55-
reportUnsupportedAwait(tree, "nested function")
53+
override def function(expandedFunction: ClassDef): Unit = {
54+
reportUnsupportedAwait(expandedFunction, "nested function")
5655
}
5756

5857
override def traverse(tree: Tree): Unit = {
5958
tree match {
6059
case Try(_, _, _) if containsAwait(tree) =>
6160
reportUnsupportedAwait(tree, "try/catch")
6261
super.traverse(tree)
63-
case Return(_) =>
62+
case Throw(Apply(fun, Ident(name) :: _)) if fun.symbol.isConstructor && fun.symbol.owner == definitions.NonLocalReturnControlClass && name.startsWith(nme.NON_LOCAL_RETURN_KEY_STRING) =>
6463
global.reporter.error(tree.pos, "return is illegal within a async block")
65-
case DefDef(mods, _, _, _, _, _) if mods.hasFlag(Flags.LAZY) && containsAwait(tree) =>
66-
reportUnsupportedAwait(tree, "lazy val initializer")
67-
case ValDef(mods, _, _, _) if mods.hasFlag(Flags.LAZY) && containsAwait(tree) =>
64+
case DefDef(mods, _, _, _, _, _) if tree.symbol.name.endsWith(nme.LAZY_SLOW_SUFFIX) && containsAwait(tree) =>
6865
reportUnsupportedAwait(tree, "lazy val initializer")
69-
case CaseDef(_, guard, _) if guard exists isAwait =>
70-
reportUnsupportedAwait(tree, "pattern guard")
7166
case _ =>
7267
super.traverse(tree)
7368
}
@@ -91,7 +86,7 @@ trait AsyncAnalysis extends TransformUtils {
9186
traverser(tree)
9287
badAwaits foreach {
9388
tree =>
94-
reportError(tree.pos, s"await must not be used under a $whyUnsupported.")
89+
reportError(tree.pos, s"${currentTransformState.Async_await.decodedName} must not be used under a $whyUnsupported.")
9590
}
9691
badAwaits.nonEmpty
9792
}

src/compiler/scala/tools/nsc/transform/async/LiveVariables.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,14 @@ trait LiveVariables extends ExprBuilder {
8282

8383
override def nestedClass(classDef: ClassDef): Unit = capturingCheck(classDef)
8484

85-
override def nestedModule(module: ModuleDef): Unit = capturingCheck(module)
85+
override def nestedModuleClass(moduleClass: ClassDef): Unit = capturingCheck(moduleClass)
8686

8787
override def nestedMethod(defdef: DefDef): Unit = capturingCheck(defdef)
8888

8989
override def byNameArgument(arg: Tree): Unit = capturingCheck(arg)
9090

9191
override def function(function: Function): Unit = capturingCheck(function)
92-
93-
override def patMatFunction(tree: Match): Unit = capturingCheck(tree)
92+
override def function(expandedFunction: ClassDef): Unit = capturingCheck(expandedFunction)
9493
}
9594

9695
val findUses = new FindUseTraverser

src/compiler/scala/tools/nsc/transform/async/TransformUtils.scala

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ package scala.tools.nsc.transform.async
1414

1515
import scala.collection.mutable
1616
import scala.collection.mutable.ListBuffer
17-
import scala.reflect.internal.Mode
1817

1918
/**
2019
* Utilities used in both `ExprBuilder` and `AnfTransform`.
@@ -36,8 +35,10 @@ private[async] trait TransformUtils extends AsyncTransformStates {
3635

3736
def isAwait(fun: Tree): Boolean = fun.symbol == currentTransformState.Async_await
3837

39-
def isBooleanShortCircuit(sym: Symbol): Boolean =
40-
sym.owner == definitions.BooleanClass && (sym == definitions.Boolean_and || sym == definitions.Boolean_or)
38+
def isBooleanAnd(sym: Symbol): Boolean =
39+
sym.owner == definitions.BooleanClass && sym == definitions.Boolean_and
40+
def isBooleanOr(sym: Symbol): Boolean =
41+
sym.owner == definitions.BooleanClass && sym == definitions.Boolean_or
4142

4243
def isLabel(sym: Symbol): Boolean = sym != null && sym.isLabel
4344
def isCaseLabel(sym: Symbol): Boolean = sym != null && sym.isLabel && sym.name.startsWith("case")
@@ -67,6 +68,7 @@ private[async] trait TransformUtils extends AsyncTransformStates {
6768

6869
def literalUnit: Tree = Literal(Constant(())).setType(definitions.UnitTpe) // a def to avoid sharing trees
6970
def literalBoxedUnit: Tree = gen.mkAttributedRef(definitions.BoxedUnit_UNIT)
71+
def literalBool(b: Boolean): Tree = Literal(Constant(b)).setType(definitions.BooleanTpe)
7072

7173
def isLiteralUnit(t: Tree): Boolean = t match {
7274
case Literal(Constant(())) => true
@@ -158,7 +160,7 @@ private[async] trait TransformUtils extends AsyncTransformStates {
158160
def nestedClass(classDef: ClassDef): Unit = {
159161
}
160162

161-
def nestedModule(module: ModuleDef): Unit = {
163+
def nestedModuleClass(moduleClass: ClassDef): Unit = {
162164
}
163165

164166
def nestedMethod(defdef: DefDef): Unit = {
@@ -170,20 +172,17 @@ private[async] trait TransformUtils extends AsyncTransformStates {
170172
def function(function: Function): Unit = {
171173
}
172174

173-
def patMatFunction(tree: Match): Unit = {
175+
def function(expandedFunction: ClassDef): Unit = {
174176
}
175177

176178
override def traverse(tree: Tree): Unit = {
177179
tree match {
178-
case cd: ClassDef => nestedClass(cd)
179-
case md: ModuleDef => nestedModule(md)
180+
case cd: ClassDef =>
181+
if (cd.symbol.isAnonymousClass) function(cd)
182+
else if (cd.symbol.isModuleClass) nestedModuleClass(cd)
183+
else nestedClass(cd)
180184
case dd: DefDef => nestedMethod(dd)
181185
case fun: Function => function(fun)
182-
case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions`
183-
case Apply(fun, arg1 :: arg2 :: Nil) if isBooleanShortCircuit(fun.symbol) =>
184-
traverse(fun)
185-
traverse(arg1)
186-
byNameArgument(arg2)
187186
case _ => super.traverse(tree)
188187
}
189188
}

src/reflect/scala/reflect/internal/StdNames.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ trait StdNames {
129129
val NESTED_IN_ANON_FUN: String = NESTED_IN + ANON_FUN_NAME.toString.replace("$", "")
130130
val NESTED_IN_LAMBDA: String = NESTED_IN + DELAMBDAFY_LAMBDA_CLASS_NAME.toString.replace("$", "")
131131

132+
val NON_LOCAL_RETURN_KEY_STRING: String = "nonLocalReturnKey"
133+
132134
/**
133135
* Ensures that name mangling does not accidentally make a class respond `true` to any of
134136
* isAnonymousClass, isAnonymousFunction, isDelambdafyFunction, e.g. by introducing "$anon".
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
ill-nested-await.scala:15: error: await must not be used under a nested method.
2+
async { foo(0)(await(f(0))) }
3+
^
4+
ill-nested-await.scala:20: error: await must not be used under a nested object.
5+
async { object Nested { await(f(false)) } }
6+
^
7+
ill-nested-await.scala:25: error: await must not be used under a nested trait.
8+
async { trait Nested { await(f(false)) } }
9+
^
10+
ill-nested-await.scala:30: error: await must not be used under a nested class.
11+
async { class Nested { await(f(false)) } }
12+
^
13+
ill-nested-await.scala:35: error: await must not be used under a nested method.
14+
async { () => { await(f(false)) } }
15+
^
16+
ill-nested-await.scala:40: error: await must not be used under a nested function.
17+
async { { case 0 => { await(f(false)) } } : PartialFunction[Int, Boolean] }
18+
^
19+
ill-nested-await.scala:45: error: await must not be used under a try/catch.
20+
async { try { await(f(false)) } catch { case _: Throwable => } }
21+
^
22+
ill-nested-await.scala:50: error: await must not be used under a try/catch.
23+
async { try { () } catch { case _: Throwable => await(f(false)) } }
24+
^
25+
ill-nested-await.scala:55: error: await must not be used under a try/catch.
26+
async { try { () } finally { await(f(false)) } }
27+
^
28+
ill-nested-await.scala:60: error: await must not be used under a nested method.
29+
async { def foo = await(f(false)) }
30+
^
31+
ill-nested-await.scala:69: error: await must not be used under a lazy val initializer.
32+
def foo(): Any = async { val x = { lazy val y = await(f(0)); y } }
33+
^
34+
ill-nested-await.scala:9: error: `await` must be enclosed in an `async` block
35+
await[Any](f(null))
36+
^
37+
12 errors found
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
-Xasync
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import scala.concurrent._
2+
import ExecutionContext.Implicits.global
3+
import scala.tools.partest.async.Async._
4+
import Future.{successful => f}
5+
6+
7+
class NakedAwait {
8+
def `await only allowed in async neg`(): Unit = {
9+
await[Any](f(null))
10+
}
11+
12+
def `await not allowed in by-name argument`(): Unit = {
13+
// expectError("await must not be used under a by-name argument.") {
14+
def foo(a: Int)(b: => Int) = 0
15+
async { foo(0)(await(f(0))) }
16+
}
17+
18+
def nestedObject(): Unit = {
19+
// expectError("await must not be used under a nested object.") {
20+
async { object Nested { await(f(false)) } }
21+
}
22+
23+
def nestedTrait(): Unit = {
24+
// expectError("await must not be used under a nested trait.") {
25+
async { trait Nested { await(f(false)) } }
26+
}
27+
28+
def nestedClass(): Unit = {
29+
// expectError("await must not be used under a nested class.") {
30+
async { class Nested { await(f(false)) } }
31+
}
32+
33+
def nestedFunction(): Unit = {
34+
// expectError("await must not be used under a nested function.") {
35+
async { () => { await(f(false)) } }
36+
}
37+
38+
def nestedPatMatFunction(): Unit = {
39+
// expectError("await must not be used under a nested class.") { // TODO more specific error message
40+
async { { case 0 => { await(f(false)) } } : PartialFunction[Int, Boolean] }
41+
}
42+
43+
def tryBody(): Unit = {
44+
// expectError("await must not be used under a try/catch.") {
45+
async { try { await(f(false)) } catch { case _: Throwable => } }
46+
}
47+
48+
def catchBody(): Unit = {
49+
// expectError("await must not be used under a try/catch.") {
50+
async { try { () } catch { case _: Throwable => await(f(false)) } }
51+
}
52+
53+
def finallyBody(): Unit = {
54+
// expectError("await must not be used under a try/catch.") {
55+
async { try { () } finally { await(f(false)) } }
56+
}
57+
58+
def nestedMethod(): Unit = {
59+
// expectError("await must not be used under a nested method.") {
60+
async { def foo = await(f(false)) }
61+
}
62+
63+
// def returnIllegal(): Unit = {
64+
// def foo(): Any = async { return false } //!!!
65+
// }
66+
67+
def lazyValIllegal(): Unit = {
68+
//expectError("await must not be used under a lazy val initializer")
69+
def foo(): Any = async { val x = { lazy val y = await(f(0)); y } }
70+
}
71+
}

test/junit/scala/tools/nsc/async/AnnotationDrivenAsync.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,32 @@ class AnnotationDrivenAsync {
3434
assertEquals(3, run(code))
3535
}
3636

37+
@Test
38+
def testBooleanAndOr(): Unit = {
39+
val code =
40+
"""import scala.concurrent._, duration.Duration, ExecutionContext.Implicits.global
41+
|import scala.tools.partest.async.Async.{async, await}
42+
|import Future.{successful => f}
43+
|
44+
|object Test {
45+
| var counter = 0
46+
| def ordered(i: Int, b: Boolean): Boolean = { assert(counter == i, (counter, i)); counter += 1; b }
47+
| def test: Future[Any] = async {
48+
| counter = 0; assert(!(ordered(0, false) && await(f(ordered(-1, true)))))
49+
| counter = 0; assert(!(ordered(0, false) && await(f(ordered(-1, false)))))
50+
| counter = 0; assert( (ordered(0, true) && await(f(ordered( 1, true)))))
51+
| counter = 0; assert(!(ordered(0, true) && await(f(ordered( 1, false)))))
52+
| counter = 0; assert( (ordered(0, false) || await(f(ordered( 1, true)))))
53+
| counter = 0; assert(!(ordered(0, false) || await(f(ordered( 1, false)))))
54+
| counter = 0; assert( (ordered(0, true) || await(f(ordered(-1, false)))))
55+
| counter = 0; assert( (ordered(0, true) || await(f(ordered(-1, true)))))
56+
| ()
57+
| }
58+
|}
59+
|""".stripMargin
60+
assertEquals((), run(code))
61+
}
62+
3763
@Test
3864
def testBasicScalaConcurrentViaMacroFrontEnd(): Unit = {
3965
val code =

0 commit comments

Comments
 (0)