-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Add WaitGroup synchronization primitive
#14167
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fcbc5a0
5e6dee2
c76bc4c
900649a
6564683
720f22e
463c068
8358b01
fe867f4
efdc0bd
8e8756e
8c06bbd
1bb9224
bc7c174
0deae7c
3ecac1c
cd6ae90
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,185 @@ | ||
require "spec"
require "wait_group"

# Spin until the wait group has at least one pending waiter fiber.
# Relies on `.@ivar` access to peek at WaitGroup internals from specs.
private def block_until_pending_waiter(wg)
  while wg.@waiting.empty?
    Fiber.yield
  end
end

# Force the internal counter to an arbitrary value (e.g. negative) to
# exercise error paths that are hard to reach through the public API.
private def forge_counter(wg, value)
  wg.@counter.set(value)
end

describe WaitGroup do
  describe "#add" do
    it "can't decrement to a negative counter" do
      wg = WaitGroup.new
      wg.add(5)
      wg.add(-3)
      expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.add(-5) }
    end

    it "resumes waiters when reaching negative counter" do
      wg = WaitGroup.new(1)
      spawn do
        block_until_pending_waiter(wg)
        wg.add(-2)
      rescue RuntimeError
      end
      expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.wait }
    end

    it "can't increment after reaching negative counter" do
      wg = WaitGroup.new
      forge_counter(wg, -1)

      # check twice, to make sure the waitgroup counter wasn't incremented back
      # to a positive value!
      expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.add(5) }
      expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.add(3) }
    end
  end

  describe "#done" do
    it "can't decrement to a negative counter" do
      wg = WaitGroup.new
      wg.add(1)
      wg.done
      expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.done }
    end

    it "resumes waiters when reaching negative counter" do
      wg = WaitGroup.new(1)
      spawn do
        block_until_pending_waiter(wg)
        forge_counter(wg, 0)
        wg.done
      rescue RuntimeError
      end
      expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.wait }
    end
  end

  describe "#wait" do
    it "immediately returns when counter is zero" do
      channel = Channel(Nil).new(1)

      spawn do
        wg = WaitGroup.new(0)
        wg.wait
        channel.send(nil)
      end

      select
      when channel.receive
        # success
      when timeout(1.second)
        fail "expected #wait to not block the fiber"
      end
    end

    it "immediately raises when counter is negative" do
      wg = WaitGroup.new(0)
      expect_raises(RuntimeError) { wg.done }
      expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.wait }
    end

    it "raises when counter is positive after wake up" do
      wg = WaitGroup.new(1)
      waiter = Fiber.current

      spawn do
        block_until_pending_waiter(wg)
        waiter.enqueue
      end

      expect_raises(RuntimeError, "Positive WaitGroup counter (early wake up?)") { wg.wait }
    end
  end

  it "waits until concurrent executions are finished" do
    wg1 = WaitGroup.new
    wg2 = WaitGroup.new

    8.times do
      wg1.add(16)
      wg2.add(16)
      exited = Channel(Bool).new(16)

      16.times do
        spawn do
          wg1.done
          wg2.wait
          exited.send(true)
        end
      end

      wg1.wait

      16.times do
        select
        when exited.receive
          fail "WaitGroup released group too soon"
        else
        end
        wg2.done
      end

      16.times do
        select
        when x = exited.receive
          x.should eq(true)
        when timeout(1.millisecond)
          fail "Expected channel to receive value"
        end
      end
    end
  end

  it "increments the counter from executing fibers" do
    wg = WaitGroup.new(16)
    extra = Atomic(Int32).new(0)

    16.times do
      spawn do
        wg.add(2)

        2.times do
          spawn do
            extra.add(1)
            wg.done
          end
        end

        wg.done
      end
    end

    wg.wait
    extra.get.should eq(32)
  end

  # the test takes far too much time for the interpreter to complete
  {% unless flag?(:interpreted) %}
    it "stress add/done/wait" do
      wg = WaitGroup.new

      1000.times do
        counter = Atomic(Int32).new(0)

        2.times do
          wg.add(1)

          spawn do
            counter.add(1)
            wg.done
          end
        end

        wg.wait
        counter.get.should eq(2)
      end
    end
  {% end %}
end
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,120 @@ | ||
require "fiber"
require "crystal/spin_lock"
require "crystal/pointer_linked_list"

# Suspend execution until a collection of fibers are finished.
#
# The wait group is a declarative counter of how many concurrent fibers have
# been started. Each such fiber is expected to call `#done` to report that they
# are finished doing their work. Whenever the counter reaches zero the waiters
# will be resumed.
#
# This is a simpler and more efficient alternative to using a `Channel(Nil)`
# then looping a number of times until we received N messages to resume
# execution.
#
# Basic example:
#
# ```
# require "wait_group"
# wg = WaitGroup.new(5)
#
# 5.times do
#   spawn do
#     do_something
#   ensure
#     wg.done # the fiber has finished
#   end
# end
#
# # suspend the current fiber until the 5 fibers are done
# wg.wait
# ```
class WaitGroup
  # Intrusive linked-list node wrapping a waiting fiber, so `#wait` can park
  # without heap-allocating per waiter (the node lives on the waiter's stack).
  private struct Waiting
    include Crystal::PointerLinkedList::Node

    def initialize(@fiber : Fiber)
    end

    # Reschedules the waiting fiber.
    def enqueue : Nil
      @fiber.enqueue
    end
  end

  def initialize(n : Int32 = 0)
    @waiting = Crystal::PointerLinkedList(Waiting).new
    @lock = Crystal::SpinLock.new
    @counter = Atomic(Int32).new(n)
  end

  # Increments the counter by how many fibers we want to wait for.
  #
  # A negative value decrements the counter. When the counter reaches zero,
  # all waiting fibers will be resumed.
  # Raises `RuntimeError` if the counter reaches a negative value.
  #
  # Can be called at any time, allowing concurrent fibers to add more fibers to
  # wait for, but they must always do so before calling `#done` that would
  # decrement the counter, to make sure that the counter may never inadvertently
  # reach zero before all fibers are done.
  def add(n : Int32 = 1) : Nil
    counter = @counter.get(:acquire)

    # CAS loop: refuse to mutate a counter that already went negative, so a
    # failed wait group stays failed and can't be "revived" by a later #add.
    loop do
      raise RuntimeError.new("Negative WaitGroup counter") if counter < 0

      counter, success = @counter.compare_and_set(counter, counter + n, :acquire_release, :acquire)
      break if success
    end

    new_counter = counter + n
    return if new_counter > 0

    # Reached zero (or went negative): wake every parked waiter. Waiters
    # re-check the counter themselves and raise if it went negative.
    @lock.sync do
      @waiting.consume_each do |node|
        node.value.enqueue
      end
    end

    raise RuntimeError.new("Negative WaitGroup counter") if new_counter < 0
  end

  # Decrements the counter by one. Must be called by concurrent fibers once they
  # have finished processing. When the counter reaches zero, all waiting fibers
  # will be resumed.
  def done : Nil
    add(-1)
  end

  # Suspends the current fiber until the counter reaches zero, at which point
  # the fiber will be resumed.
  #
  # Can be called from different fibers.
  def wait : Nil
    return if done?

    waiting = Waiting.new(Fiber.current)

    @lock.sync do
      # must check again to avoid a race condition where #done may have
      # decremented the counter to zero between the above check and #wait
      # acquiring the lock; we'd push the current fiber to the wait list that
      # would never be resumed (oops)
      return if done?

      @waiting.push(pointerof(waiting))
    end

    Crystal::Scheduler.reschedule

    # Woken up: the counter must be zero (success) or negative (done? raises).
    # A still-positive counter means a spurious/early wake up.
    return if done?
    raise RuntimeError.new("Positive WaitGroup counter (early wake up?)")
  end

  # Returns true when the counter is exactly zero; raises when it is negative
  # (the wait group was over-decremented and is permanently failed).
  private def done?
    counter = @counter.get(:acquire)
    raise RuntimeError.new("Negative WaitGroup counter") if counter < 0
    counter == 0
  end
end
Uh oh!
There was an error while loading. Please reload this page.