Skip to content
25 changes: 14 additions & 11 deletions src/main/scala/scala/async/internal/AnfTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ private[async] trait AnfTransform {
stats :+ expr :+ api.typecheck(atPos(expr.pos)(Throw(Apply(Select(New(gen.mkAttributedRef(defn.IllegalStateExceptionClass)), nme.CONSTRUCTOR), Nil))))
expr match {
case Apply(fun, args) if isAwait(fun) =>
val valDef = defineVal(name.await, expr, tree.pos)
val valDef = defineVal(name.await(), expr, tree.pos)
val ref = gen.mkAttributedStableRef(valDef.symbol).setType(tree.tpe)
val ref1 = if (ref.tpe =:= definitions.UnitTpe)
// https://github.com/scala/async/issues/74
Expand Down Expand Up @@ -109,7 +109,7 @@ private[async] trait AnfTransform {
} else if (expr.tpe =:= definitions.NothingTpe) {
statsExprThrow
} else {
val varDef = defineVar(name.ifRes, expr.tpe, tree.pos)
val varDef = defineVar(name.ifRes(), expr.tpe, tree.pos)
def typedAssign(lhs: Tree) =
api.typecheck(atPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, tpe(varDef.symbol)))))

Expand Down Expand Up @@ -140,7 +140,7 @@ private[async] trait AnfTransform {
} else if (expr.tpe =:= definitions.NothingTpe) {
statsExprThrow
} else {
val varDef = defineVar(name.matchRes, expr.tpe, tree.pos)
val varDef = defineVar(name.matchRes(), expr.tpe, tree.pos)
def typedAssign(lhs: Tree) =
api.typecheck(atPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, tpe(varDef.symbol)))))
val casesWithAssign = cases map {
Expand All @@ -163,14 +163,14 @@ private[async] trait AnfTransform {
}
}

def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(uncheckedBounds(tp))
def defineVar(name: TermName, tp: Type, pos: Position): ValDef = {
val sym = api.currentOwner.newTermSymbol(name, pos, MUTABLE | SYNTHETIC).setInfo(uncheckedBounds(tp))
valDef(sym, mkZero(uncheckedBounds(tp))).setType(NoType).setPos(pos)
}
}

def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = {
val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, SYNTHETIC).setInfo(uncheckedBounds(lhs.tpe))
def defineVal(name: TermName, lhs: Tree, pos: Position): ValDef = {
val sym = api.currentOwner.newTermSymbol(name, pos, SYNTHETIC).setInfo(uncheckedBounds(lhs.tpe))
internal.valDef(sym, internal.changeOwner(lhs, api.currentOwner, sym)).setType(NoType).setPos(pos)
}

Expand Down Expand Up @@ -212,7 +212,7 @@ private[async] trait AnfTransform {
case Arg(expr, _, argName) =>
linearize.transformToList(expr) match {
case stats :+ expr1 =>
val valDef = defineVal(argName, expr1, expr1.pos)
val valDef = defineVal(name.freshen(argName), expr1, expr1.pos)
require(valDef.tpe != null, valDef)
val stats1 = stats :+ valDef
(stats1, atPos(tree.pos.makeTransparent)(gen.stabilize(gen.mkAttributedIdent(valDef.symbol))))
Expand Down Expand Up @@ -279,8 +279,9 @@ private[async] trait AnfTransform {
// TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`.
val block = linearize.transformToBlock(body)
val (valDefs, mappings) = (pat collect {
case b@Bind(name, _) =>
val vd = defineVal(name.toTermName + AnfTransform.this.name.bindSuffix, gen.mkAttributedStableRef(b.symbol).setPos(b.pos), b.pos)
case b@Bind(bindName, _) =>
val vd = defineVal(name.freshen(bindName.toTermName), gen.mkAttributedStableRef(b.symbol).setPos(b.pos), b.pos)
vd.symbol.updateAttachment(SyntheticBindVal)
(vd, (b.symbol, vd.symbol))
}).unzip
val (from, to) = mappings.unzip
Expand Down Expand Up @@ -333,7 +334,7 @@ private[async] trait AnfTransform {
// Otherwise, create the matchres var. We'll callers of the label def below.
// Remember: we're iterating through the statement sequence in reverse, so we'll get
// to the LabelDef and mutate `matchResults` before we'll get to its callers.
val matchResult = linearize.defineVar(name.matchRes, param.tpe, ld.pos)
val matchResult = linearize.defineVar(name.matchRes(), param.tpe, ld.pos)
matchResults += matchResult
caseDefToMatchResult(ld.symbol) = matchResult.symbol
val rhs2 = ld.rhs.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil)
Expand Down Expand Up @@ -408,3 +409,5 @@ private[async] trait AnfTransform {
}).asInstanceOf[Block]
}
}

object SyntheticBindVal
11 changes: 11 additions & 0 deletions src/main/scala/scala/async/internal/AsyncMacro.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,18 @@ package scala.async.internal
object AsyncMacro {
def apply(c0: reflect.macros.Context, base: AsyncBase)(body0: c0.Tree): AsyncMacro { val c: c0.type } = {
import language.reflectiveCalls

// Use an attachment on RootClass as a sneaky place for a per-Global cache
val att = c0.internal.attachments(c0.universe.rootMirror.RootClass)
val names = att.get[AsyncNames[_]].getOrElse {
val names = new AsyncNames[c0.universe.type](c0.universe)
att.update(names)
names
}

new AsyncMacro { self =>
val c: c0.type = c0
val asyncNames: AsyncNames[c.universe.type] = names.asInstanceOf[AsyncNames[c.universe.type]]
val body: c.Tree = body0
// This member is required by `AsyncTransform`:
val asyncBase: AsyncBase = base
Expand All @@ -23,6 +33,7 @@ private[async] trait AsyncMacro
val c: scala.reflect.macros.Context
val body: c.Tree
var containsAwait: c.Tree => Boolean
val asyncNames: AsyncNames[c.universe.type]

lazy val macroPos: c.universe.Position = c.macroApplication.pos.makeTransparent
def atMacroPos(t: c.Tree): c.Tree = c.universe.atPos(macroPos)(t)
Expand Down
109 changes: 109 additions & 0 deletions src/main/scala/scala/async/internal/AsyncNames.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package scala.async.internal

import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.api.Names

/**
* A per-global cache of names needed by the Async macro.
*/
final class AsyncNames[U <: Names with Singleton](val u: U) {
self =>
import u._

abstract class NameCache[N <: U#Name](base: String) {
val cached = new ArrayBuffer[N]()
protected def newName(s: String): N
def apply(i: Int): N = {
if (cached.isDefinedAt(i)) cached(i)
else {
assert(cached.length == i)
val name = newName(freshenString(base, i))
cached += name
name
}
}
}

final class TermNameCache(base: String) extends NameCache[U#TermName](base) {
override protected def newName(s: String): U#TermName = newTermName(s)
}
final class TypeNameCache(base: String) extends NameCache[U#TypeName](base) {
override protected def newName(s: String): U#TypeName = newTypeName(s)
}
private val matchRes: TermNameCache = new TermNameCache("match")
private val ifRes: TermNameCache = new TermNameCache("if")
private val await: TermNameCache = new TermNameCache("await")

private val result = newTermName("result$async")
private val completed: TermName = newTermName("completed$async")
private val apply = newTermName("apply")
private val stateMachine = newTermName("stateMachine$async")
private val stateMachineT = stateMachine.toTypeName
private val state: u.TermName = newTermName("state$async")
private val execContext = newTermName("execContext$async")
private val tr: u.TermName = newTermName("tr$async")
private val t: u.TermName = newTermName("throwable$async")

final class NameSource[N <: U#Name](cache: NameCache[N]) {
private val count = new AtomicInteger(0)
def apply(): N = cache(count.getAndIncrement())
}

class AsyncName {
final val matchRes = new NameSource[U#TermName](self.matchRes)
final val ifRes = new NameSource[U#TermName](self.matchRes)
final val await = new NameSource[U#TermName](self.await)
final val completed = self.completed
final val result = self.result
final val apply = self.apply
final val stateMachine = self.stateMachine
final val stateMachineT = self.stateMachineT
final val state: u.TermName = self.state
final val execContext = self.execContext
final val tr: u.TermName = self.tr
final val t: u.TermName = self.t

private val seenPrefixes = mutable.AnyRefMap[Name, AtomicInteger]()
private val freshened = mutable.HashSet[Name]()

final def freshenIfNeeded(name: TermName): TermName = {
seenPrefixes.getOrNull(name) match {
case null =>
seenPrefixes.put(name, new AtomicInteger())
name
case counter =>
freshen(name, counter)
}
}
final def freshenIfNeeded(name: TypeName): TypeName = {
seenPrefixes.getOrNull(name) match {
case null =>
seenPrefixes.put(name, new AtomicInteger())
name
case counter =>
freshen(name, counter)
}
}
final def freshen(name: TermName): TermName = {
val counter = seenPrefixes.getOrElseUpdate(name, new AtomicInteger())
freshen(name, counter)
}
final def freshen(name: TypeName): TypeName = {
val counter = seenPrefixes.getOrElseUpdate(name, new AtomicInteger())
freshen(name, counter)
}
private def freshen(name: TermName, counter: AtomicInteger): TermName = {
if (freshened.contains(name)) name
else TermName(freshenString(name.toString, counter.incrementAndGet()))
}
private def freshen(name: TypeName, counter: AtomicInteger): TypeName = {
if (freshened.contains(name)) name
else TypeName(freshenString(name.toString, counter.incrementAndGet()))
}
}

private def freshenString(name: String, counter: Int): String = name.toString + "$async$" + counter
}
12 changes: 8 additions & 4 deletions src/main/scala/scala/async/internal/AsyncTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ trait AsyncTransform {
buildAsyncBlock(anfTree, symLookup)
}

if(AsyncUtils.verbose)
logDiagnostics(anfTree, asyncBlock.asyncStates.map(_.toString))

val liftedFields: List[Tree] = liftables(asyncBlock.asyncStates)

// live variables analysis
Expand Down Expand Up @@ -114,10 +111,15 @@ trait AsyncTransform {
futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }`
else
startStateMachine

if(AsyncUtils.verbose) {
logDiagnostics(anfTree, asyncBlock, asyncBlock.asyncStates.map(_.toString))
}
futureSystemOps.dot(enclosingOwner, body).foreach(f => f(asyncBlock.toDot))
cleanupContainsAwaitAttachments(result)
}

def logDiagnostics(anfTree: Tree, states: Seq[String]): Unit = {
def logDiagnostics(anfTree: Tree, block: AsyncBlock, states: Seq[String]): Unit = {
def location = try {
macroPos.source.path
} catch {
Expand All @@ -129,6 +131,8 @@ trait AsyncTransform {
AsyncUtils.vprintln(s"${c.macroApplication}")
AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree")
states foreach (s => AsyncUtils.vprintln(s))
AsyncUtils.vprintln("===== DOT =====")
AsyncUtils.vprintln(block.toDot)
}

/**
Expand Down
Loading