From fcbc5a0abb313c449e90bfcfa4d640d6bec06da7 Mon Sep 17 00:00:00 2001 From: Julien Portalier Date: Thu, 4 Jan 2024 13:06:40 +0100 Subject: [PATCH 01/15] Add WaitGroup synchronization primitive This is more efficient than creating a Channel(Nil) and looping to receive N messages: we don't need a queue, only a counter, and we can avoid spurious wakeups of the main fiber and resume it only once. --- spec/std/wait_group_spec.cr | 20 +++++++ src/crystal/pointer_linked_list.cr | 6 +++ src/wait_group.cr | 87 ++++++++++++++++++++++++++++++ 3 files changed, 113 insertions(+) create mode 100644 spec/std/wait_group_spec.cr create mode 100644 src/wait_group.cr diff --git a/spec/std/wait_group_spec.cr b/spec/std/wait_group_spec.cr new file mode 100644 index 000000000000..091bcbe2748d --- /dev/null +++ b/spec/std/wait_group_spec.cr @@ -0,0 +1,20 @@ +require "spec" +require "wait_group" + +describe WaitGroup do + it "waits until all concurrent executions are done" do + wg = WaitGroup.new + wg.add(500) + count = Atomic(Int32).new(0) + + 500.times do + ::spawn do + count.add(1) + wg.done + end + end + + wg.wait + count.get.should eq(500) + end +end diff --git a/src/crystal/pointer_linked_list.cr b/src/crystal/pointer_linked_list.cr index 0ce17b071bd0..03109979d662 100644 --- a/src/crystal/pointer_linked_list.cr +++ b/src/crystal/pointer_linked_list.cr @@ -80,4 +80,10 @@ struct Crystal::PointerLinkedList(T) node = _next end end + + # Iterates the list before clearing it. + def consume_each(&) : Nil + each { |node| yield node } + @head = Pointer(T).null + end end diff --git a/src/wait_group.cr b/src/wait_group.cr new file mode 100644 index 000000000000..d034d8bb41c6 --- /dev/null +++ b/src/wait_group.cr @@ -0,0 +1,87 @@ +require "fiber" +require "crystal/spin_lock" +require "crystal/pointer_linked_list" + +# Suspend execution until other fibers are done. 
+# +# This is a simpler and more efficient alternative to using a `Channel(Nil)` +# then looping a number of times until we received N messages. +# +# Basic example: +# +# ``` +# require "wait_group" +# wg = WaitGroup.new(5) +# +# 5.times do +# spawn do +# do_something +# wg.done # the fiber is done +# end +# end +# +# # suspend the current fiber until the 5 fibers are done +# wg.wait +# ``` +class WaitGroup + private struct Waiting + include Crystal::PointerLinkedList::Node + + def initialize(@fiber : Fiber) + end + + def enqueue : Nil + @fiber.enqueue + end + end + + def initialize(n : Int32 = 0) + @waiting = Crystal::PointerLinkedList(Waiting).new + @lock = Crystal::SpinLock.new + @counter = Atomic(Int32).new(n) + end + + # Increments the counter by how many fibers we want to wait for. + # + # Can be called at any time, allowing concurrent fibers to add more fibers to + # wait for, but they must always do so before calling `#done` that would + # decrement the counter, to make sure that the counter may never inadvertently + # reach zero before all fibers are done. + def add(n : Int32 = 1) : Nil + @counter.add(n) + end + + # Decrements the counter by one. Must be called by concurrent fibers once they + # have finished processing. When the counter reaches zero, all waiting fibers + # will be resumed. + def done : Nil + return unless @counter.sub(1) == 1 + + @lock.sync do + @waiting.consume_each do |node| + node.value.enqueue + end + end + end + + # Suspends the current fiber until the counter reaches zero, at which point + # the fiber will be resumed. + # + # Can be called from different fibers. 
+ def wait : Nil + return if @counter.get == 0 + waiting = Waiting.new(Fiber.current) + + @lock.sync do + # must check again to avoid a race condition where #done may have + # decremented the counter to zero between the above check and #wait + # acquiring the lock; we'd push the current fiber to the wait list that + # would never be resumed (oops) + return if @counter.get == 0 + + @waiting.push(pointerof(waiting)) + end + + Crystal::Scheduler.reschedule + end +end From 5e6dee2f1021a030ee4832bd02d8353762126fa1 Mon Sep 17 00:00:00 2001 From: Julien Portalier Date: Thu, 4 Jan 2024 23:19:53 +0100 Subject: [PATCH 02/15] Improve tests + add(-n) + raise on negative counter --- spec/std/wait_group_spec.cr | 100 +++++++++++++++++++++++++++++++++--- src/wait_group.cr | 21 +++++--- 2 files changed, 105 insertions(+), 16 deletions(-) diff --git a/spec/std/wait_group_spec.cr b/spec/std/wait_group_spec.cr index 091bcbe2748d..74362e931609 100644 --- a/spec/std/wait_group_spec.cr +++ b/spec/std/wait_group_spec.cr @@ -2,19 +2,103 @@ require "spec" require "wait_group" describe WaitGroup do - it "waits until all concurrent executions are done" do - wg = WaitGroup.new - wg.add(500) - count = Atomic(Int32).new(0) + describe "add" do + it "can't decrement to a negative counter" do + wg = WaitGroup.new + wg.add(5) + wg.add(-3) + expect_raises(Exception) { wg.add(-5) } + end + end + + describe "done" do + it "can't decrement to negative value" do + wg = WaitGroup.new + wg.add(1) + wg.done + expect_raises(Exception) { wg.done } + end + end + + it "waits until concurrent executions are done" do + wg1 = WaitGroup.new + wg2 = WaitGroup.new + + 8.times do + wg1.add(16) + wg2.add(16) + exited = Channel(Bool).new(16) + + 16.times do + spawn do + wg1.done + wg2.wait + exited.send(true) + end + end + + wg1.wait + + 16.times do + select + when exited.receive + raise "WaitGroup released group too soon" + else + end + wg2.done + end + + 16.times do + select + when x = exited.receive + x.should 
eq(true) + when timeout(1.millisecond) + raise "Expected channel to receive value" + end + end + end + end + + it "increments the counter from executing fibers" do + wg = WaitGroup.new(16) + extra = Atomic(Int32).new(0) + + 16.times do + spawn do + wg.add(2) + + 2.times do + spawn do + extra.add(1) + wg.done + end + end - 500.times do - ::spawn do - count.add(1) wg.done end end wg.wait - count.get.should eq(500) + extra.get.should eq(32) + end + + it "stress add/done/wait" do + wg = WaitGroup.new + + 1000.times do + counter = Atomic(Int32).new(0) + + 2.times do + wg.add(1) + + spawn do + counter.add(1) + wg.done + end + end + + wg.wait + counter.get.should eq(2) + end end end diff --git a/src/wait_group.cr b/src/wait_group.cr index d034d8bb41c6..b1eb1393bb8d 100644 --- a/src/wait_group.cr +++ b/src/wait_group.cr @@ -43,19 +43,17 @@ class WaitGroup # Increments the counter by how many fibers we want to wait for. # + # This can also be used to decrement the counter, in which case the behavior + # is identical to `#done`. + # # Can be called at any time, allowing concurrent fibers to add more fibers to # wait for, but they must always do so before calling `#done` that would # decrement the counter, to make sure that the counter may never inadvertently # reach zero before all fibers are done. def add(n : Int32 = 1) : Nil - @counter.add(n) - end - - # Decrements the counter by one. Must be called by concurrent fibers once they - # have finished processing. When the counter reaches zero, all waiting fibers - # will be resumed. - def done : Nil - return unless @counter.sub(1) == 1 + new_value = @counter.add(n) + n + raise "Negative WaitGroup counter" if new_value < 0 + return unless new_value == 0 @lock.sync do @waiting.consume_each do |node| @@ -64,6 +62,13 @@ class WaitGroup end end + # Decrements the counter by one. Must be called by concurrent fibers once they + # have finished processing. When the counter reaches zero, all waiting fibers + # will be resumed. 
+ def done : Nil + add(-1) + end + # Suspends the current fiber until the counter reaches zero, at which point # the fiber will be resumed. # From c76bc4c750f220b90bb725a14c1291e9e18eaf52 Mon Sep 17 00:00:00 2001 From: Julien Portalier Date: Fri, 5 Jan 2024 12:05:31 +0100 Subject: [PATCH 03/15] Fix example to call wg.done inside ensure block --- src/wait_group.cr | 1 + 1 file changed, 1 insertion(+) diff --git a/src/wait_group.cr b/src/wait_group.cr index b1eb1393bb8d..36094f4f3ead 100644 --- a/src/wait_group.cr +++ b/src/wait_group.cr @@ -16,6 +16,7 @@ require "crystal/pointer_linked_list" # 5.times do # spawn do # do_something +# ensure # wg.done # the fiber is done # end # end From 900649a34b563bce0ddd0b01c47b7a58b0db34aa Mon Sep 17 00:00:00 2001 From: Julien Portalier Date: Mon, 8 Jan 2024 10:10:47 +0100 Subject: [PATCH 04/15] Leverage fail + improve documentation --- spec/std/wait_group_spec.cr | 8 ++++---- src/wait_group.cr | 12 +++++++++--- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/spec/std/wait_group_spec.cr b/spec/std/wait_group_spec.cr index 74362e931609..eb16c1d7797d 100644 --- a/spec/std/wait_group_spec.cr +++ b/spec/std/wait_group_spec.cr @@ -12,7 +12,7 @@ describe WaitGroup do end describe "done" do - it "can't decrement to negative value" do + it "can't decrement to a negative counter" do wg = WaitGroup.new wg.add(1) wg.done @@ -20,7 +20,7 @@ describe WaitGroup do end end - it "waits until concurrent executions are done" do + it "waits until concurrent executions are finished" do wg1 = WaitGroup.new wg2 = WaitGroup.new @@ -42,7 +42,7 @@ describe WaitGroup do 16.times do select when exited.receive - raise "WaitGroup released group too soon" + fail "WaitGroup released group too soon" else end wg2.done @@ -53,7 +53,7 @@ describe WaitGroup do when x = exited.receive x.should eq(true) when timeout(1.millisecond) - raise "Expected channel to receive value" + fail "Expected channel to receive value" end end end diff --git 
a/src/wait_group.cr b/src/wait_group.cr index 36094f4f3ead..b4f978138561 100644 --- a/src/wait_group.cr +++ b/src/wait_group.cr @@ -2,10 +2,16 @@ require "fiber" require "crystal/spin_lock" require "crystal/pointer_linked_list" -# Suspend execution until other fibers are done. +# Suspend execution until a collection of fibers are finished. +# +# The wait group is a declarative counter of how many concurrent fibers have +# been started. Each such fiber is expected to call `#done` to report that hey +# are finished doing their work. Whenever the counter reaches zero the waiters +# will be resumed. # # This is a simpler and more efficient alternative to using a `Channel(Nil)` -# then looping a number of times until we received N messages. +# then looping a number of times until we received N messages to resume +# execution. # # Basic example: # @@ -17,7 +23,7 @@ require "crystal/pointer_linked_list" # spawn do # do_something # ensure -# wg.done # the fiber is done +# wg.done # the fiber has finished # end # end # From 65646834ce1c96a577d62656f25c5cd4db529602 Mon Sep 17 00:00:00 2001 From: Julien Portalier Date: Wed, 24 Jan 2024 14:12:00 +0100 Subject: [PATCH 05/15] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Müller --- spec/std/wait_group_spec.cr | 4 ++-- src/wait_group.cr | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/spec/std/wait_group_spec.cr b/spec/std/wait_group_spec.cr index eb16c1d7797d..617ddcccc1e5 100644 --- a/spec/std/wait_group_spec.cr +++ b/spec/std/wait_group_spec.cr @@ -7,7 +7,7 @@ describe WaitGroup do wg = WaitGroup.new wg.add(5) wg.add(-3) - expect_raises(Exception) { wg.add(-5) } + expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.add(-5) } end end @@ -16,7 +16,7 @@ describe WaitGroup do wg = WaitGroup.new wg.add(1) wg.done - expect_raises(Exception) { wg.done } + expect_raises(RuntimeError, "Negative 
WaitGroup counter") { wg.done } end end diff --git a/src/wait_group.cr b/src/wait_group.cr index b4f978138561..a09aab16a9c9 100644 --- a/src/wait_group.cr +++ b/src/wait_group.cr @@ -50,8 +50,9 @@ class WaitGroup # Increments the counter by how many fibers we want to wait for. # - # This can also be used to decrement the counter, in which case the behavior - # is identical to `#done`. + # A negative value decrements the counter. When the counter reaches zero, + # all waiting fibers will be resumed. + # Raises `RuntimeError` if the counter reaches a negative value. # # Can be called at any time, allowing concurrent fibers to add more fibers to # wait for, but they must always do so before calling `#done` that would @@ -59,7 +60,7 @@ class WaitGroup # reach zero before all fibers are done. def add(n : Int32 = 1) : Nil new_value = @counter.add(n) + n - raise "Negative WaitGroup counter" if new_value < 0 + raise RuntimeError.new("Negative WaitGroup counter") if new_value < 0 return unless new_value == 0 @lock.sync do From 720f22e94cf980cee4e6dfbf2cb14fc6fb7a87a5 Mon Sep 17 00:00:00 2001 From: Julien Portalier Date: Thu, 25 Jan 2024 11:39:10 +0100 Subject: [PATCH 06/15] Fix: yet another typo... Co-authored-by: Jason Frey --- src/wait_group.cr | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wait_group.cr b/src/wait_group.cr index a09aab16a9c9..6d642d1617cf 100644 --- a/src/wait_group.cr +++ b/src/wait_group.cr @@ -5,7 +5,7 @@ require "crystal/pointer_linked_list" # Suspend execution until a collection of fibers are finished. # # The wait group is a declarative counter of how many concurrent fibers have -# been started. Each such fiber is expected to call `#done` to report that hey +# been started. Each such fiber is expected to call `#done` to report that they # are finished doing their work. Whenever the counter reaches zero the waiters # will be resumed. 
# From 463c0685e708f0056207e8e80ce2a8d5a8f91947 Mon Sep 17 00:00:00 2001 From: Julien Portalier Date: Thu, 25 Jan 2024 11:30:26 +0100 Subject: [PATCH 07/15] Test WaitGroup with the interpreter Disables the stress test when interpreted as it takes forever to complete. --- spec/interpreter_std_spec.cr | 1 + spec/std/wait_group_spec.cr | 29 ++++++++++++++++------------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/spec/interpreter_std_spec.cr b/spec/interpreter_std_spec.cr index 3cdcc55cbd61..e97159540297 100644 --- a/spec/interpreter_std_spec.cr +++ b/spec/interpreter_std_spec.cr @@ -251,6 +251,7 @@ require "./std/uuid/json_spec.cr" require "./std/uuid_spec.cr" require "./std/uuid/yaml_spec.cr" # require "./std/va_list_spec.cr" (failed to run) +require "./std/wait_group_spec.cr" require "./std/weak_ref_spec.cr" require "./std/winerror_spec.cr" require "./std/xml/builder_spec.cr" diff --git a/spec/std/wait_group_spec.cr b/spec/std/wait_group_spec.cr index 617ddcccc1e5..c66113268728 100644 --- a/spec/std/wait_group_spec.cr +++ b/spec/std/wait_group_spec.cr @@ -82,23 +82,26 @@ describe WaitGroup do extra.get.should eq(32) end - it "stress add/done/wait" do - wg = WaitGroup.new + # the test takes far too much time for the interpreter to complete + {% unless flag?(:interpreted) %} + it "stress add/done/wait" do + wg = WaitGroup.new - 1000.times do - counter = Atomic(Int32).new(0) + 1000.times do + counter = Atomic(Int32).new(0) - 2.times do - wg.add(1) + 2.times do + wg.add(1) - spawn do - counter.add(1) - wg.done + spawn do + counter.add(1) + wg.done + end end - end - wg.wait - counter.get.should eq(2) + wg.wait + counter.get.should eq(2) + end end - end + {% end %} end From 8358b01f8c4e2866ec3bc8c3b8db447f3141b94f Mon Sep 17 00:00:00 2001 From: Julien Portalier Date: Tue, 30 Jan 2024 16:13:20 +0100 Subject: [PATCH 08/15] Fix: disable thread specs for the interpreter I can't reproduce the "can't resuming running fiber" anymore when I disable the thread 
specs. I think we likely need actual support from the interpreter to start threads in the interpreted code. --- spec/interpreter_std_spec.cr | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/spec/interpreter_std_spec.cr b/spec/interpreter_std_spec.cr index e97159540297..05d4e9f636af 100644 --- a/spec/interpreter_std_spec.cr +++ b/spec/interpreter_std_spec.cr @@ -234,9 +234,9 @@ require "./std/system_error_spec.cr" require "./std/system/group_spec.cr" # require "./std/system_spec.cr" (failed to run) require "./std/system/user_spec.cr" -require "./std/thread/condition_variable_spec.cr" -require "./std/thread/mutex_spec.cr" -# require "./std/thread_spec.cr" (failed to run) +# require "./std/thread/condition_variable_spec.cr" (not supported) +# require "./std/thread/mutex_spec.cr" (not supported) +# require "./std/thread_spec.cr" (not supported) require "./std/time/custom_formats_spec.cr" require "./std/time/format_spec.cr" require "./std/time/location_spec.cr" From fe867f45a8b88d132514c66abb1d9a7ce1e612dc Mon Sep 17 00:00:00 2001 From: Julien Portalier Date: Mon, 25 Mar 2024 12:57:15 +0100 Subject: [PATCH 09/15] Fix: resume waiting fibers + raise on negative counter --- spec/std/wait_group_spec.cr | 55 +++++++++++++++++++++++++++++++++++-- src/none.cr | 0 src/wait_group.cr | 34 ++++++++++++++--------- 3 files changed, 74 insertions(+), 15 deletions(-) create mode 100644 src/none.cr diff --git a/spec/std/wait_group_spec.cr b/spec/std/wait_group_spec.cr index c66113268728..1bed2ff37966 100644 --- a/spec/std/wait_group_spec.cr +++ b/spec/std/wait_group_spec.cr @@ -1,23 +1,74 @@ require "spec" require "wait_group" +private def block_until_pending_waiter(wg) + while wg.@waiting.empty? 
+ Fiber.yield + end +end + describe WaitGroup do - describe "add" do + describe "#add" do it "can't decrement to a negative counter" do wg = WaitGroup.new wg.add(5) wg.add(-3) expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.add(-5) } end + + it "resumes waiters when reaching negative counter" do + wg = WaitGroup.new(1) + spawn do + block_until_pending_waiter + wg.add(-2) + rescue RuntimeError + end + expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.wait } + end end - describe "done" do + describe "#done" do it "can't decrement to a negative counter" do wg = WaitGroup.new wg.add(1) wg.done expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.done } end + + it "resumes waiters when reaching negative counter" do + wg = WaitGroup.new(1) + spawn do + block_until_pending_waiter + wg.add(-2) + rescue RuntimeError + end + expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.wait } + end + end + + describe "#wait" do + it "immediately returns when counter is zero" do + channel = Channel(Nil).new(1) + + spawn do + wg = WaitGroup.new(0) + wg.wait + channel.send(nil) + end + + select + when channel.receive + # success + when timeout(1.second) + fail "expected #wait to not block the fiber" + end + end + + it "immediately raises when counter is negative" do + wg = WaitGroup.new(0) + expect_raises(RuntimeError) { wg.done } + expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.wait } + end end it "waits until concurrent executions are finished" do diff --git a/src/none.cr b/src/none.cr new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/wait_group.cr b/src/wait_group.cr index 6d642d1617cf..19afe781c2e9 100644 --- a/src/wait_group.cr +++ b/src/wait_group.cr @@ -60,14 +60,15 @@ class WaitGroup # reach zero before all fibers are done. 
def add(n : Int32 = 1) : Nil new_value = @counter.add(n) + n - raise RuntimeError.new("Negative WaitGroup counter") if new_value < 0 - return unless new_value == 0 + return if new_value > 0 @lock.sync do @waiting.consume_each do |node| node.value.enqueue end end + + raise RuntimeError.new("Negative WaitGroup counter") if new_value < 0 end # Decrements the counter by one. Must be called by concurrent fibers once they @@ -82,19 +83,26 @@ class WaitGroup # # Can be called from different fibers. def wait : Nil - return if @counter.get == 0 - waiting = Waiting.new(Fiber.current) + case @counter.get <=> 0 + when -1 + raise RuntimeError.new("Negative WaitGroup counter") + when 0 + return + when 1 + waiting = Waiting.new(Fiber.current) - @lock.sync do - # must check again to avoid a race condition where #done may have - # decremented the counter to zero between the above check and #wait - # acquiring the lock; we'd push the current fiber to the wait list that - # would never be resumed (oops) - return if @counter.get == 0 + @lock.sync do + # must check again to avoid a race condition where #done may have + # decremented the counter to zero between the above check and #wait + # acquiring the lock; we'd push the current fiber to the wait list that + # would never be resumed (oops) + return if @counter.get == 0 - @waiting.push(pointerof(waiting)) - end + @waiting.push(pointerof(waiting)) + end - Crystal::Scheduler.reschedule + Crystal::Scheduler.reschedule + raise RuntimeError.new("Negative WaitGroup counter") if @counter.get < 0 + end end end From 8e8756e1446f65c23e711af77c644d6394207d91 Mon Sep 17 00:00:00 2001 From: Julien Portalier Date: Mon, 25 Mar 2024 13:09:24 +0100 Subject: [PATCH 10/15] fixup! 
Fix: resume waiting fibers + raise on negative counter --- spec/std/wait_group_spec.cr | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spec/std/wait_group_spec.cr b/spec/std/wait_group_spec.cr index 1bed2ff37966..d7d5cf1e8839 100644 --- a/spec/std/wait_group_spec.cr +++ b/spec/std/wait_group_spec.cr @@ -19,7 +19,7 @@ describe WaitGroup do it "resumes waiters when reaching negative counter" do wg = WaitGroup.new(1) spawn do - block_until_pending_waiter + block_until_pending_waiter(wg) wg.add(-2) rescue RuntimeError end @@ -38,7 +38,7 @@ describe WaitGroup do it "resumes waiters when reaching negative counter" do wg = WaitGroup.new(1) spawn do - block_until_pending_waiter + block_until_pending_waiter(wg) wg.add(-2) rescue RuntimeError end From 8c06bbdb88aa9d69766b4b36ed8fa10d77a1f4d4 Mon Sep 17 00:00:00 2001 From: Julien Portalier Date: Mon, 25 Mar 2024 14:37:33 +0100 Subject: [PATCH 11/15] fix: remove src/none.cr + use done in spec (not add) --- spec/std/wait_group_spec.cr | 3 ++- src/none.cr | 0 2 files changed, 2 insertions(+), 1 deletion(-) delete mode 100644 src/none.cr diff --git a/spec/std/wait_group_spec.cr b/spec/std/wait_group_spec.cr index d7d5cf1e8839..ea3938fa16cf 100644 --- a/spec/std/wait_group_spec.cr +++ b/spec/std/wait_group_spec.cr @@ -39,7 +39,8 @@ describe WaitGroup do wg = WaitGroup.new(1) spawn do block_until_pending_waiter(wg) - wg.add(-2) + wg.done + wg.done rescue RuntimeError end expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.wait } diff --git a/src/none.cr b/src/none.cr deleted file mode 100644 index e69de29bb2d1..000000000000 From 1bb92245aa1157f806c0c8e9a1889b3f8ef796e8 Mon Sep 17 00:00:00 2001 From: Julien Portalier Date: Mon, 25 Mar 2024 18:32:58 +0100 Subject: [PATCH 12/15] Improve readability of WaitGroup#wait Co-authored-by: Sijawusz Pur Rahnama --- src/wait_group.cr | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/wait_group.cr b/src/wait_group.cr index 
19afe781c2e9..3d6f4f0bffe5 100644 --- a/src/wait_group.cr +++ b/src/wait_group.cr @@ -83,12 +83,12 @@ class WaitGroup # # Can be called from different fibers. def wait : Nil - case @counter.get <=> 0 - when -1 + case @counter.get + when .negative? raise RuntimeError.new("Negative WaitGroup counter") - when 0 + when .zero? return - when 1 + when .positive? waiting = Waiting.new(Fiber.current) @lock.sync do From bc7c1749e6871703f07e0eda0ae57b235b229c98 Mon Sep 17 00:00:00 2001 From: Julien Portalier Date: Thu, 28 Mar 2024 12:27:08 +0100 Subject: [PATCH 13/15] Fix: race conditions It's now impossible for `#add` to increment a negative counter back into a positive one. `#wait` now checks for negative counter in addition to zero counter right after grabbing the lock. --- spec/std/wait_group_spec.cr | 16 ++++++++++- src/wait_group.cr | 56 ++++++++++++++++++++++--------------- 2 files changed, 49 insertions(+), 23 deletions(-) diff --git a/spec/std/wait_group_spec.cr b/spec/std/wait_group_spec.cr index ea3938fa16cf..1956ebfbe25b 100644 --- a/spec/std/wait_group_spec.cr +++ b/spec/std/wait_group_spec.cr @@ -7,6 +7,10 @@ private def block_until_pending_waiter(wg) end end +private def forge_counter(wg, value) + wg.@counter.set(value) +end + describe WaitGroup do describe "#add" do it "can't decrement to a negative counter" do @@ -25,6 +29,16 @@ describe WaitGroup do end expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.wait } end + + it "can't increment after reaching negative counter" do + wg = WaitGroup.new + forge_counter(wg, -1) + + # check twice, to make sure the waitgroup counter wasn't incremented back + # to a positive value! 
+ expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.add(5) } + expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.add(3) } + end end describe "#done" do @@ -39,7 +53,7 @@ describe WaitGroup do wg = WaitGroup.new(1) spawn do block_until_pending_waiter(wg) - wg.done + forge_counter(wg, 0) wg.done rescue RuntimeError end diff --git a/src/wait_group.cr b/src/wait_group.cr index 3d6f4f0bffe5..3c51f8c40c30 100644 --- a/src/wait_group.cr +++ b/src/wait_group.cr @@ -59,8 +59,18 @@ class WaitGroup # decrement the counter, to make sure that the counter may never inadvertently # reach zero before all fibers are done. def add(n : Int32 = 1) : Nil - new_value = @counter.add(n) + n - return if new_value > 0 + counter = @counter.get(:acquire) + new_counter = uninitialized Int32 + + loop do + raise RuntimeError.new("Negative WaitGroup counter") if counter < 0 + + new_counter = counter + n + counter, success = @counter.compare_and_set(counter, new_counter, :acquire_release, :acquire) + break if success + end + + return if new_counter > 0 @lock.sync do @waiting.consume_each do |node| @@ -68,7 +78,7 @@ class WaitGroup end end - raise RuntimeError.new("Negative WaitGroup counter") if new_value < 0 + raise RuntimeError.new("Negative WaitGroup counter") if new_counter < 0 end # Decrements the counter by one. Must be called by concurrent fibers once they @@ -83,26 +93,28 @@ class WaitGroup # # Can be called from different fibers. def wait : Nil - case @counter.get - when .negative? - raise RuntimeError.new("Negative WaitGroup counter") - when .zero? - return - when .positive? - waiting = Waiting.new(Fiber.current) - - @lock.sync do - # must check again to avoid a race condition where #done may have - # decremented the counter to zero between the above check and #wait - # acquiring the lock; we'd push the current fiber to the wait list that - # would never be resumed (oops) - return if @counter.get == 0 - - @waiting.push(pointerof(waiting)) - end + return if done? 
+ + waiting = Waiting.new(Fiber.current) - Crystal::Scheduler.reschedule - raise RuntimeError.new("Negative WaitGroup counter") if @counter.get < 0 + @lock.sync do + # must check again to avoid a race condition where #done may have + # decremented the counter to zero between the above check and #wait + # acquiring the lock; we'd push the current fiber to the wait list that + # would never be resumed (oops) + return if done? + + @waiting.push(pointerof(waiting)) end + + Crystal::Scheduler.reschedule + + done? + end + + private def done? + counter = @counter.get(:acquire) + raise RuntimeError.new("Negative WaitGroup counter") if counter < 0 + counter == 0 end end From 0deae7c49d78909594b7546fa2dc6e4185bc7684 Mon Sep 17 00:00:00 2001 From: Julien Portalier Date: Thu, 28 Mar 2024 14:02:43 +0100 Subject: [PATCH 14/15] Fix: raise on early wake up --- spec/std/wait_group_spec.cr | 12 ++++++++++++ src/wait_group.cr | 3 ++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/spec/std/wait_group_spec.cr b/spec/std/wait_group_spec.cr index 1956ebfbe25b..459af8d5c898 100644 --- a/spec/std/wait_group_spec.cr +++ b/spec/std/wait_group_spec.cr @@ -84,6 +84,18 @@ describe WaitGroup do expect_raises(RuntimeError) { wg.done } expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.wait } end + + it "raises when counter is positive after wake up" do + wg = WaitGroup.new(1) + waiter = Fiber.current + + spawn do + block_until_pending_waiter(wg) + waiter.enqueue + end + + expect_raises(RuntimeError, "Positive WaitGroup counter (early wake up?)") { wg.wait } + end end it "waits until concurrent executions are finished" do diff --git a/src/wait_group.cr b/src/wait_group.cr index 3c51f8c40c30..2af83d8343a9 100644 --- a/src/wait_group.cr +++ b/src/wait_group.cr @@ -109,7 +109,8 @@ class WaitGroup Crystal::Scheduler.reschedule - done? + return if done? + raise RuntimeError.new("Positive WaitGroup counter (early wake up?)") end private def done? 
From cd6ae904da9ff4555255b17757fbfdcc6cb8a0f1 Mon Sep 17 00:00:00 2001 From: Julien Portalier Date: Fri, 29 Mar 2024 11:11:06 +0100 Subject: [PATCH 15/15] Fix: avoid uninitialized (and avoid a nilable) --- src/wait_group.cr | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/wait_group.cr b/src/wait_group.cr index 2af83d8343a9..d9ae4ae740ac 100644 --- a/src/wait_group.cr +++ b/src/wait_group.cr @@ -60,16 +60,15 @@ class WaitGroup # reach zero before all fibers are done. def add(n : Int32 = 1) : Nil counter = @counter.get(:acquire) - new_counter = uninitialized Int32 loop do raise RuntimeError.new("Negative WaitGroup counter") if counter < 0 - new_counter = counter + n - counter, success = @counter.compare_and_set(counter, new_counter, :acquire_release, :acquire) + counter, success = @counter.compare_and_set(counter, counter + n, :acquire_release, :acquire) break if success end + new_counter = counter + n return if new_counter > 0 @lock.sync do