diff --git a/spec/std/wait_group_spec.cr b/spec/std/wait_group_spec.cr new file mode 100644 index 000000000000..459af8d5c898 --- /dev/null +++ b/spec/std/wait_group_spec.cr @@ -0,0 +1,185 @@ +require "spec" +require "wait_group" + +private def block_until_pending_waiter(wg) + while wg.@waiting.empty? + Fiber.yield + end +end + +private def forge_counter(wg, value) + wg.@counter.set(value) +end + +describe WaitGroup do + describe "#add" do + it "can't decrement to a negative counter" do + wg = WaitGroup.new + wg.add(5) + wg.add(-3) + expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.add(-5) } + end + + it "resumes waiters when reaching negative counter" do + wg = WaitGroup.new(1) + spawn do + block_until_pending_waiter(wg) + wg.add(-2) + rescue RuntimeError + end + expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.wait } + end + + it "can't increment after reaching negative counter" do + wg = WaitGroup.new + forge_counter(wg, -1) + + # check twice, to make sure the waitgroup counter wasn't incremented back + # to a positive value! + expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.add(5) } + expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.add(3) } + end + end + + describe "#done" do + it "can't decrement to a negative counter" do + wg = WaitGroup.new + wg.add(1) + wg.done + expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.done } + end + + it "resumes waiters when reaching negative counter" do + wg = WaitGroup.new(1) + spawn do + block_until_pending_waiter(wg) + forge_counter(wg, 0) + wg.done + rescue RuntimeError + end + expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.wait } + end + end + + describe "#wait" do + it "immediately returns when counter is zero" do + channel = Channel(Nil).new(1) + + spawn do + wg = WaitGroup.new(0) + wg.wait + channel.send(nil) + end + + select + when channel.receive + # success + when timeout(1.second) + fail "expected #wait to not block the fiber" + end + end + + it "immediately raises when counter is negative" do + wg = WaitGroup.new(0) + expect_raises(RuntimeError) { wg.done } + expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.wait } + end + + it "raises when counter is positive after wake up" do + wg = WaitGroup.new(1) + waiter = Fiber.current + + spawn do + block_until_pending_waiter(wg) + waiter.enqueue + end + + expect_raises(RuntimeError, "Positive WaitGroup counter (early wake up?)") { wg.wait } + end + end + + it "waits until concurrent executions are finished" do + wg1 = WaitGroup.new + wg2 = WaitGroup.new + + 8.times do + wg1.add(16) + wg2.add(16) + exited = Channel(Bool).new(16) + + 16.times do + spawn do + wg1.done + wg2.wait + exited.send(true) + end + end + + wg1.wait + + 16.times do + select + when exited.receive + fail "WaitGroup released group too soon" + else + end + wg2.done + end + + 16.times do + select + when x = exited.receive + x.should eq(true) + when timeout(1.millisecond) + fail "Expected channel to receive value" + end + end + end + end + + it "increments the counter from executing fibers" do + wg = WaitGroup.new(16) + extra = Atomic(Int32).new(0) + + 16.times do + spawn do + wg.add(2) + + 2.times do + spawn do + extra.add(1) + wg.done + end + end + + wg.done + end + end + + wg.wait + extra.get.should eq(32) + end + + # the test takes far too much time for the interpreter to complete + {% unless flag?(:interpreted) %} + it "stress add/done/wait" do + wg = WaitGroup.new + + 1000.times do + counter = Atomic(Int32).new(0) + + 2.times do + wg.add(1) + + spawn do + counter.add(1) + wg.done + end + end + + wg.wait + counter.get.should eq(2) + end + end + {% end %} +end diff --git a/src/crystal/pointer_linked_list.cr b/src/crystal/pointer_linked_list.cr index 0ce17b071bd0..03109979d662 100644 --- a/src/crystal/pointer_linked_list.cr +++ b/src/crystal/pointer_linked_list.cr @@ -80,4 +80,10 @@ struct Crystal::PointerLinkedList(T) node = _next end end + + # Iterates the list before clearing it. + def consume_each(&) : Nil + each { |node| yield node } + @head = Pointer(T).null + end end diff --git a/src/wait_group.cr b/src/wait_group.cr new file mode 100644 index 000000000000..d9ae4ae740ac --- /dev/null +++ b/src/wait_group.cr @@ -0,0 +1,120 @@ +require "fiber" +require "crystal/spin_lock" +require "crystal/pointer_linked_list" + +# Suspend execution until a collection of fibers are finished. +# +# The wait group is a declarative counter of how many concurrent fibers have +# been started. Each such fiber is expected to call `#done` to report that they +# are finished doing their work. Whenever the counter reaches zero the waiters +# will be resumed. +# +# This is a simpler and more efficient alternative to using a `Channel(Nil)` +# then looping a number of times until we received N messages to resume +# execution. +# +# Basic example: +# +# ``` +# require "wait_group" +# wg = WaitGroup.new(5) +# +# 5.times do +# spawn do +# do_something +# ensure +# wg.done # the fiber has finished +# end +# end +# +# # suspend the current fiber until the 5 fibers are done +# wg.wait +# ``` +class WaitGroup + private struct Waiting + include Crystal::PointerLinkedList::Node + + def initialize(@fiber : Fiber) + end + + def enqueue : Nil + @fiber.enqueue + end + end + + def initialize(n : Int32 = 0) + @waiting = Crystal::PointerLinkedList(Waiting).new + @lock = Crystal::SpinLock.new + @counter = Atomic(Int32).new(n) + end + + # Increments the counter by how many fibers we want to wait for. + # + # A negative value decrements the counter. When the counter reaches zero, + # all waiting fibers will be resumed. + # Raises `RuntimeError` if the counter reaches a negative value. + # + # Can be called at any time, allowing concurrent fibers to add more fibers to + # wait for, but they must always do so before calling `#done` that would + # decrement the counter, to make sure that the counter may never inadvertently + # reach zero before all fibers are done. + def add(n : Int32 = 1) : Nil + counter = @counter.get(:acquire) + + loop do + raise RuntimeError.new("Negative WaitGroup counter") if counter < 0 + + counter, success = @counter.compare_and_set(counter, counter + n, :acquire_release, :acquire) + break if success + end + + new_counter = counter + n + return if new_counter > 0 + + @lock.sync do + @waiting.consume_each do |node| + node.value.enqueue + end + end + + raise RuntimeError.new("Negative WaitGroup counter") if new_counter < 0 + end + + # Decrements the counter by one. Must be called by concurrent fibers once they + # have finished processing. When the counter reaches zero, all waiting fibers + # will be resumed. + def done : Nil + add(-1) + end + + # Suspends the current fiber until the counter reaches zero, at which point + # the fiber will be resumed. + # + # Can be called from different fibers. + def wait : Nil + return if done? + + waiting = Waiting.new(Fiber.current) + + @lock.sync do + # must check again to avoid a race condition where #done may have + # decremented the counter to zero between the above check and #wait + # acquiring the lock; we'd push the current fiber to the wait list that + # would never be resumed (oops) + return if done? + + @waiting.push(pointerof(waiting)) + end + + Crystal::Scheduler.reschedule + + return if done? + raise RuntimeError.new("Positive WaitGroup counter (early wake up?)") + end + + private def done? + counter = @counter.get(:acquire) + raise RuntimeError.new("Negative WaitGroup counter") if counter < 0 + counter == 0 + end +end