@@ -23,6 +23,8 @@ trait ExprBuilder {
2323 val labelDefStates = collection.mutable.Map [Symbol , Int ]()
2424
2525 trait AsyncState {
26+ var switchId : Int = - 1
27+
2628 def state : Int
2729
2830 def nextStates : Array [Int ]
@@ -76,13 +78,13 @@ trait ExprBuilder {
7678 mkHandlerCase(state, stats)
7779
7880 override val toString : String =
79- s " AsyncStateWithoutAwait # $state, nextStates = $nextStates"
81+ s " AsyncStateWithoutAwait # $state, nextStates = ${ nextStates.toList} "
8082 }
8183
8284 /** A sequence of statements that concludes with an `await` call. The `onComplete`
8385 * handler will unconditionally transition to `nextState`.
8486 */
85- final class AsyncStateWithAwait (var stats : List [Tree ], val state : Int , onCompleteState : Int , nextState : Int ,
87+ final class AsyncStateWithAwait (var stats : List [Tree ], val state : Int , val onCompleteState : Int , nextState : Int ,
8688 val awaitable : Awaitable , symLookup : SymLookup )
8789 extends AsyncState {
8890
@@ -268,11 +270,11 @@ trait ExprBuilder {
268270 }
269271
270272 // populate asyncStates
271- def add (stat : Tree ): Unit = stat match {
273+ def add (stat : Tree , afterState : Option [ Int ] = None ): Unit = stat match {
272274 // the val name = await(..) pattern
273275 case vd @ ValDef (mods, name, tpt, Apply (fun, arg :: Nil )) if isAwait(fun) =>
274276 val onCompleteState = nextState()
275- val afterAwaitState = nextState()
277+ val afterAwaitState = afterState.getOrElse( nextState() )
276278 val awaitable = Awaitable (arg, stat.symbol, tpt.tpe, vd)
277279 asyncStates += stateBuilder.resultWithAwait(awaitable, onCompleteState, afterAwaitState) // complete with await
278280 currState = afterAwaitState
@@ -283,7 +285,7 @@ trait ExprBuilder {
283285
284286 val thenStartState = nextState()
285287 val elseStartState = nextState()
286- val afterIfState = nextState()
288+ val afterIfState = afterState.getOrElse( nextState() )
287289
288290 asyncStates +=
289291 // the two Int arguments are the start state of the then branch and the else branch, respectively
@@ -305,7 +307,7 @@ trait ExprBuilder {
305307 java.util.Arrays .setAll(caseStates, new IntUnaryOperator {
306308 override def applyAsInt (operand : Int ): Int = nextState()
307309 })
308- val afterMatchState = nextState()
310+ val afterMatchState = afterState.getOrElse( nextState() )
309311
310312 asyncStates +=
311313 stateBuilder.resultWithMatch(scrutinee, cases, caseStates, symLookup)
@@ -323,15 +325,16 @@ trait ExprBuilder {
323325 if containsAwait(rhs) || directlyAdjacentLabelDefs(ld).exists(containsAwait) =>
324326
325327 val startLabelState = stateIdForLabel(ld.symbol)
326- val afterLabelState = nextState()
328+ val afterLabelState = afterState.getOrElse( nextState() )
327329 asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup)
328330 labelDefStates(ld.symbol) = startLabelState
329331 val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState)
330332 asyncStates ++= builder.asyncStates
331333 currState = afterLabelState
332334 stateBuilder = new AsyncStateBuilder (currState, symLookup)
333335 case b @ Block (stats, expr) =>
334- (stats :+ expr) foreach (add)
336+ for (stat <- stats) add(stat)
337+ add(expr, afterState = Some (endState))
335338 case _ =>
336339 checkForUnsupportedAwait(stat)
337340 stateBuilder += stat
@@ -346,7 +349,7 @@ trait ExprBuilder {
346349
347350 def onCompleteHandler [T : WeakTypeTag ]: Tree
348351
349- def toDot ( afterDSE : Boolean ) : String
352+ def toDot : String
350353 }
351354
352355 case class SymLookup (stateMachineClass : Symbol , applyTrParam : Symbol ) {
@@ -371,11 +374,10 @@ trait ExprBuilder {
371374 val blockBuilder = new AsyncBlockBuilder (stats, expr, startState, endState, symLookup)
372375
373376 new AsyncBlock {
374- val liveStates = mutable.AnyRefMap [Integer , Integer ]()
375- val deadStates = mutable.AnyRefMap [Integer , Integer ]()
377+ val switchIds = mutable.AnyRefMap [Integer , Integer ]()
376378
377379 // render with http://graphviz.it/#/new
378- def toDot ( afterDSE : Boolean ) : String = {
380+ def toDot : String = {
379381 val states = asyncStates
380382 def toHtmlLabel (label : String , preText : String , builder : StringBuilder ): Unit = {
381383 builder.append(" <b>" ).append(label).append(" </b>" ).append(" <br/>" )
@@ -390,39 +392,60 @@ trait ExprBuilder {
390392 val dotBuilder = new StringBuilder ()
391393 dotBuilder.append(" digraph {\n " )
392394 def stateLabel (s : Int ) = {
393- val beforeDseLabel = if (s == 0 ) " INITIAL" else if (s == Int .MaxValue ) " TERMINAL" else if (s > 0 ) " S" + s else " C" + Math .abs(s)
394- if (afterDSE) {
395- " \" S" + liveStates.getOrElse(s, s) + " (" + beforeDseLabel + " )\" "
396- } else {
397- beforeDseLabel
398- }
399-
395+ if (s == 0 ) " INITIAL" else if (s == Int .MaxValue ) " TERMINAL" else switchIds.getOrElse(s, s).toString
400396 }
401397 val length = asyncStates.size
402398 for ((state, i) <- asyncStates.zipWithIndex) {
403- val liveStateIdOpt : Option [Int ] = if (afterDSE) {
404- liveStates.get(state.state).map(_.intValue())
405- } else Some (state.state)
406- for (_ <- liveStateIdOpt) {
407- dotBuilder.append(s """ ${stateLabel(state.state)} [label= """ ).append(" <" )
408- if (i != length - 1 ) {
409- val CaseDef (_, _, body) = state.mkHandlerCaseForState
410- toHtmlLabel(stateLabel(state.state), showCode(body), dotBuilder)
411- } else {
412- toHtmlLabel(stateLabel(state.state), state.allStats.map(showCode(_)).mkString(" \n " ), dotBuilder)
413- }
414- dotBuilder.append(" > ]\n " )
399+ dotBuilder.append(s """ ${stateLabel(state.state)} [label= """ ).append(" <" )
400+ if (i != length - 1 ) {
401+ val CaseDef (_, _, body) = state.mkHandlerCaseForState
402+ toHtmlLabel(stateLabel(state.state), showCode(body), dotBuilder)
403+ } else {
404+ toHtmlLabel(stateLabel(state.state), state.allStats.map(showCode(_)).mkString(" \n " ), dotBuilder)
415405 }
406+ dotBuilder.append(" > ]\n " )
416407 }
417- for (state <- states; if liveStates.contains(state.state); succ <- state.nextStates) {
408+ for (state <- states; succ <- state.nextStates) {
418409 dotBuilder.append(s """ ${stateLabel(state.state)} -> ${stateLabel(succ)}""" )
419410 dotBuilder.append(" \n " )
420411 }
421412 dotBuilder.append(" }\n " )
422413 dotBuilder.toString
423414 }
424415
425- def asyncStates = blockBuilder.asyncStates.toList
416+ lazy val asyncStates : List [AsyncState ] = filterStates
417+
418+ def filterStates = {
419+ val all = blockBuilder.asyncStates.toList
420+ val (initial :: rest) = all
421+ val map = all.iterator.map(x => (x.state, x)).toMap
422+ var seen = mutable.HashSet [Int ]()
423+ def loop (state : AsyncState ): Unit = {
424+ seen.add(state.state)
425+ for (i <- state.nextStates) {
426+ if (i != Int .MaxValue && ! seen.contains(i)) {
427+ loop(map(i))
428+ }
429+ }
430+ }
431+ loop(initial)
432+ val live = rest.filter(state => seen(state.state))
433+ var nextSwitchId = 1
434+ (initial :: live).foreach { state =>
435+ val switchId = nextSwitchId
436+ switchIds(state.state) = switchId
437+ nextSwitchId += 1
438+ state match {
439+ case state : AsyncStateWithAwait =>
440+ val switchId = nextSwitchId
441+ switchIds(state.onCompleteState) = switchId
442+ nextSwitchId += 1
443+ case _ =>
444+ }
445+ }
446+ initial :: live
447+
448+ }
426449
427450 def mkCombinedHandlerCases [T : WeakTypeTag ]: List [CaseDef ] = {
428451 val caseForLastState : CaseDef = {
@@ -488,43 +511,14 @@ trait ExprBuilder {
488511 // Identify dead states: `case <id> => { state = nextId; (); (); ... }, eliminated, and compact state ids to
489512 // enable emission of a tableswitch.
490513 private def eliminateDeadStates (m : Match ): Tree = {
491- object DeadState {
492- private var compactedStateId = 1
493- for (CaseDef (Literal (Constant (stateId : Integer )), EmptyTree , body) <- m.cases) {
494- body match {
495- case _ if (stateId == 0 ) => liveStates(stateId) = stateId
496- case Block (Assign (_, Literal (Constant (nextState : Integer ))) :: rest, expr) if (expr :: rest).forall(t => isLiteralUnit(t)) =>
497- deadStates(stateId) = nextState
498- case _ =>
499- liveStates(stateId) = compactedStateId
500- compactedStateId += 1
501- }
502- }
503- if (deadStates.nonEmpty)
504- AsyncUtils .vprintln(s " ${deadStates.size} dead states eliminated " )
505- def isDead (i : Integer ) = deadStates.contains(i)
506- def translatedStateId (i : Integer , tree : Tree ): Integer = {
507- def chaseDead (i : Integer ): Integer = {
508- val replacement = deadStates.getOrNull(i)
509- if (replacement == null ) i
510- else chaseDead(replacement)
511- }
512-
513- val live = chaseDead(i)
514- liveStates.get(live) match {
515- case Some (x) => x
516- case None => sys.error(s " $live, $liveStates \n $deadStates\n $m\n\n ==== \n $tree" )
517- }
518- }
519- }
520514 val stateMemberSymbol = symLookup.stateMachineMember(name.state)
521515 // - remove CaseDef-s for dead states
522516 // - rewrite state transitions to dead states to instead transition to the
523517 // non-dead successor.
524518 val elimDeadStateTransform = new Transformer {
525519 override def transform (tree : Tree ): Tree = tree match {
526520 case as @ Assign (lhs, Literal (Constant (i : Integer ))) if lhs.symbol == stateMemberSymbol =>
527- val replacement = DeadState .translatedStateId(i, as )
521+ val replacement = switchIds(i )
528522 treeCopy.Assign (tree, lhs, Literal (Constant (replacement)))
529523 case _ : Match | _ : CaseDef | _ : Block | _ : If =>
530524 super .transform(tree)
@@ -533,12 +527,9 @@ trait ExprBuilder {
533527 }
534528 val cases1 = m.cases.flatMap {
535529 case cd @ CaseDef (Literal (Constant (i : Integer )), EmptyTree , rhs) =>
536- if (DeadState .isDead(i)) Nil
537- else {
538- val replacement = DeadState .translatedStateId(i, cd)
539- val rhs1 = elimDeadStateTransform.transform(rhs)
540- treeCopy.CaseDef (cd, Literal (Constant (replacement)), EmptyTree , rhs1) :: Nil
541- }
530+ val replacement = switchIds(i)
531+ val rhs1 = elimDeadStateTransform.transform(rhs)
532+ treeCopy.CaseDef (cd, Literal (Constant (replacement)), EmptyTree , rhs1) :: Nil
542533 case x => x :: Nil
543534 }
544535 treeCopy.Match (m, m.selector, cases1)
0 commit comments