88
99package scala .concurrent
1010
11+ import java .util .ArrayDeque
1112import java .util .concurrent .Executor
1213import scala .annotation .tailrec
14+ import scala .util .control .NonFatal
15+
16+ /**
17+ * Marker trait to indicate that a Runnable is Batchable by BatchingExecutors
18+ */
19+ trait Batchable {
20+ self : Runnable =>
21+ }
1322
1423/**
1524 * Mixin trait for an Executor
@@ -39,79 +48,90 @@ import scala.annotation.tailrec
3948 * WARNING: The underlying Executor's execute-method must not execute the submitted Runnable
4049 * in the calling thread synchronously. It must enqueue/handoff the Runnable.
4150 */
42- private [concurrent] trait BatchingExecutor extends Executor {
51+ private [concurrent] trait BatchingExecutor extends Executor {
52+ private [this ] final val _tasksLocal = new ThreadLocal [Batch ]()
53+
54+ private [this ] final class Batch (capacity : Int ) extends ArrayDeque [Runnable ](capacity) with Runnable with BlockContext with (BlockContext => Unit ) {
55+ private [this ] final var parentBlockContext : BlockContext = _
4356
44- // invariant: if "_tasksLocal.get ne null" then we are inside BatchingRunnable.run; if it is null, we are outside
45- private [this ] val _tasksLocal = new ThreadLocal [List [Runnable ]]()
57+ def this (r : Runnable , capacity : Int ) = {
58+ this (capacity)
59+ addLast(r)
60+ }
61+
62+ final def executor : BatchingExecutor = BatchingExecutor .this
4663
47- private class Batch (val initial : List [Runnable ]) extends Runnable with BlockContext {
48- private [this ] var parentBlockContext : BlockContext = _
4964 // this method runs in the delegate ExecutionContext's thread
50- override def run (): Unit = {
51- require(_tasksLocal.get eq null )
52-
53- val prevBlockContext = BlockContext .current
54- BlockContext .withBlockContext(this ) {
55- try {
56- parentBlockContext = prevBlockContext
57-
58- @ tailrec def processBatch (batch : List [Runnable ]): Unit = batch match {
59- case Nil => ()
60- case head :: tail =>
61- _tasksLocal set tail
62- try {
63- head.run()
64- } catch {
65- case t : Throwable =>
66- // if one task throws, move the
67- // remaining tasks to another thread
68- // so we can throw the exception
69- // up to the invoking executor
70- val remaining = _tasksLocal.get
71- _tasksLocal set Nil
72- unbatchedExecute(new Batch (remaining)) // TODO what if this submission fails?
73- throw t // rethrow
74- }
75- processBatch(_tasksLocal.get) // since head.run() can add entries, always do _tasksLocal.get here
76- }
77-
78- processBatch(initial)
79- } finally {
80- _tasksLocal.remove()
81- parentBlockContext = null
82- }
65+ override final def run (): Unit = BlockContext .usingBlockContext(this )(this )
66+
67+ override final def apply (prevBlockContext : BlockContext ): Unit = {
68+ // This invariant needs to hold: require(_tasksLocal.get eq null)
69+ parentBlockContext = prevBlockContext
70+ try {
71+ _tasksLocal.set(this )
72+ runAll()
73+ _tasksLocal.remove() // Will be cleared in the throwing-case by runAll()
74+ } finally {
75+ parentBlockContext = null
8376 }
8477 }
8578
79+ @ tailrec private [this ] final def runAll (): Unit = {
80+ val next = pollLast()
81+ if (next ne null ) {
82+ try next.run() catch {
83+ case t : Throwable =>
84+ parentBlockContext = null // Need to reset this before re-submitting it
85+ _tasksLocal.remove() // If unbatchedExecute runs synchronously
86+ handleRunFailure(t)
87+ }
88+ runAll()
89+ }
90+ }
91+
92+ private [this ] final def handleRunFailure (cause : Throwable ): Nothing =
93+ if (NonFatal (cause) || cause.isInstanceOf [InterruptedException ]) {
94+ try unbatchedExecute(this ) catch {
95+ case inner : Throwable =>
96+ if (NonFatal (inner)) {
97+ val e = new ExecutionException (" Non-fatal error occurred and resubmission failed, see suppressed exception." , cause)
98+ e.addSuppressed(inner)
99+ throw e
100+ } else throw inner
101+ }
102+ throw cause
103+ } else throw cause
104+
86105 override def blockOn [T ](thunk : => T )(implicit permission : CanAwait ): T = {
87- // if we know there will be blocking, we don't want to keep tasks queued up because it could deadlock.
88- {
89- val tasks = _tasksLocal.get
90- _tasksLocal set Nil
91- if ((tasks ne null ) && tasks.nonEmpty )
92- unbatchedExecute(new Batch (tasks) )
106+ val pbc = parentBlockContext
107+ if ( ! isEmpty) { // if we know there will be blocking, we don't want to keep tasks queued up because it could deadlock.
108+ val b = new Batch (math.max( 4 , this .size))
109+ b.addAll( this )
110+ this .clear( )
111+ unbatchedExecute(b )
93112 }
94113
95- // now delegate the blocking to the previous BC
96- require(parentBlockContext ne null )
97- parentBlockContext.blockOn(thunk)
114+ if (pbc ne null ) pbc.blockOn(thunk) // now delegate the blocking to the previous BC
115+ else {
116+ try thunk finally throw new IllegalStateException (" BUG in BatchingExecutor.Batch: parentBlockContext is null" )
117+ }
98118 }
99119 }
100120
101121 protected def unbatchedExecute (r : Runnable ): Unit
102122
103- override def execute (runnable : Runnable ): Unit = {
104- if (batchable(runnable)) { // If we can batch the runnable
105- _tasksLocal.get match {
106- case null => unbatchedExecute(new Batch (runnable :: Nil )) // If we aren't in batching mode yet, enqueue batch
107- case some => _tasksLocal.set(runnable :: some) // If we are already in batching mode, add to batch
108- }
109- } else unbatchedExecute(runnable) // If not batchable, just delegate to underlying
123+ private [this ] final def batchedExecute (runnable : Runnable ): Unit = {
124+ val b = _tasksLocal.get
125+ if (b ne null ) b.addLast(runnable)
126+ else unbatchedExecute(new Batch (runnable, 4 ))
110127 }
111128
112- /** Override this to define which runnables will be batched. */
113- def batchable (runnable : Runnable ): Boolean = runnable match {
114- case _ : OnCompleteRunnable => true
115- case _ => false
116- }
117- }
129+ override def execute (runnable : Runnable ): Unit =
130+ if (batchable(runnable)) batchedExecute(runnable)
131+ else unbatchedExecute(runnable)
132+
133+ /** Override this to define which runnables will be batched.
134+ * By default it tests the Runnable for being an instance of [Batchable].
135+ **/
136+ protected def batchable (runnable : Runnable ): Boolean = runnable.isInstanceOf [Batchable ]
137+ }
0 commit comments