Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 16 additions & 25 deletions dramatiq_workflow/_barrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,14 @@

class AtMostOnceBarrier(dramatiq.rate_limits.Barrier):
"""
The AtMostOnceBarrier is a barrier that ensures that it is released at most
once.

We use this because we want to avoid running callbacks in chains multiple
times. Running callbacks more than once can have compounding effects
especially when groups are involved.

The downside of this is that we cannot guarantee that the barrier will be
released at all. Theoretically a worker could die after releasing the
barrier but just before it has a chance to schedule the callbacks.
A barrier that lets the workflow middleware record completion in two steps.

The regular barrier `wait` semantics are untouched. Once all parties arrive,
callers may optionally invoke :meth:`confirm_release` to permanently record
that the completion callbacks have been scheduled. By deferring this final
confirmation until after expensive work (for example, unserializing a large
workflow) has succeeded, we lower the chance that a worker fails between
releasing the barrier and scheduling follow-up tasks.
"""

def __init__(self, backend, key, *args, ttl=900000):
Expand All @@ -27,18 +25,11 @@ def create(self, parties):
self.backend.add(self.ran_key, -1, self.ttl)
return super().create(parties)

def wait(self, *args, block=True, timeout=None):
    """
    Signal one arrival at the barrier and report a release at most once.

    Only non-blocking use is supported, so callers must pass ``block=False``
    explicitly. ``timeout`` is accepted for signature compatibility but is
    not used by this method.

    Returns the result of recording the release when the underlying barrier
    reports that it has been released, and ``False`` otherwise.

    Raises:
        ValueError: if ``block`` is truthy.
    """
    if block:
        # A blocking wait could hang forever when the barrier has already
        # been released, so it is rejected outright.
        raise ValueError("Blocking is not supported by AtMostOnceBarrier")

    if not super().wait(*args, block=False):
        return False

    # Only the first caller to bump the "ran" counter gets a truthy result;
    # every later attempt is logged and reported as not released.
    first_release = self.backend.incr(self.ran_key, 1, 0, self.ttl)
    if not first_release:
        logger.warning("Barrier %s release already recorded; ignoring subsequent release attempt", self.key)
    return first_release
def confirm_release(self):
    """
    Record that the barrier's release has been acted upon.

    Increments the counter stored under ``self.ran_key`` by one with a
    maximum of zero, using ``self.ttl`` as the expiry. The backend's result
    is returned unchanged: truthy when this call recorded the release, falsy
    when the increment was refused (in which case a warning is logged).
    """
    first_confirmation = self.backend.incr(self.ran_key, 1, 0, self.ttl)
    if first_confirmation:
        return first_confirmation

    # The increment was refused, so a release has already been recorded.
    logger.warning("Barrier %s release already recorded; ignoring subsequent release attempt", self.key)
    return first_confirmation
16 changes: 15 additions & 1 deletion dramatiq_workflow/_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,24 @@ def _process_completion_callbacks(

logger.debug("Barrier completed: %s", completion_id)
completion_callbacks.pop()
confirm_release = getattr(barrier, "confirm_release", None) or (lambda: True)

if remaining_workflow is None:
release_confirmed = confirm_release()
if not release_confirmed:
break
continue

workflow = unserialize_workflow(remaining_workflow)
# unserialize_workflow can be expensive for large workflows. By confirming the
# AtMostOnceBarrier only after this step succeeds we reduce the chance that a
# worker crashes between recording the release and doing the heavy work.
release_confirmed = confirm_release()
if not release_confirmed:
break

workflow_with_completion_callbacks(
unserialize_workflow(remaining_workflow),
workflow,
broker,
completion_callbacks,
).run()
Expand Down
26 changes: 5 additions & 21 deletions dramatiq_workflow/tests/test_barrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,11 @@ def setUp(self):
self.ttl = 900000
self.barrier = AtMostOnceBarrier(self.backend, self.key, ttl=self.ttl)

def test_wait_block_true_raises(self):
    """A blocking wait must be rejected with a descriptive ValueError."""
    expected_message = "Blocking is not supported by AtMostOnceBarrier"
    with self.assertRaises(ValueError) as raised:
        self.barrier.wait(block=True)
    self.assertEqual(str(raised.exception), expected_message)

def test_wait_releases_once(self):
def test_confirm_release_only_succeeds_once(self):
self.barrier.create(self.parties)
for _ in range(self.parties - 1):
result = self.barrier.wait(block=False)
self.assertFalse(result)
result = self.barrier.wait(block=False)
self.assertTrue(result)
result = self.barrier.wait(block=False)
self.assertFalse(result)
self.assertFalse(self.barrier.wait(block=False))

def test_wait_does_not_release_when_db_emptied(self):
    """
    Once the backing store has been wiped, no wait call may report a release.
    """
    self.barrier.create(self.parties)
    # Replace the backend's storage with an empty dict after creation
    # (assumes the test backend keeps its keys in ``db``).
    self.backend.db = {}
    for _attempt in range(self.parties):
        self.assertFalse(self.barrier.wait(block=False))
self.assertTrue(self.barrier.wait(block=False))
self.assertTrue(self.barrier.confirm_release())
self.assertFalse(self.barrier.confirm_release())
24 changes: 24 additions & 0 deletions dramatiq_workflow/tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,27 @@ def test_after_process_message_with_lazy_loaded_workflow(self, mock_time):
self.assertIn(workflow_ref, storage.loaded_workflows)

self.broker.enqueue.assert_called_once_with(self._make_message(message_timestamp=1337_000)._message, delay=None)

@mock.patch("dramatiq_workflow._middleware.workflow_with_completion_callbacks")
@mock.patch("dramatiq_workflow._middleware.unserialize_workflow")
def test_barrier_confirmation_happens_after_unserialize(self, mock_unserialize, mock_workflow_with_callbacks):
    """
    The barrier release is confirmed after unserializing and before running.

    Each patched step appends a marker to a shared list so the exact call
    order can be asserted once the middleware has processed the message.
    """
    events: list[str] = []

    def record_unserialize(workflow):
        events.append("unserialize")
        return mock.sentinel.workflow

    def record_run():
        events.append("run")

    def record_confirm(*_args, **_kwargs):
        events.append("confirm")
        return True

    mock_unserialize.side_effect = record_unserialize

    runner = mock.Mock()
    runner.run.side_effect = record_run
    mock_workflow_with_callbacks.return_value = runner

    with mock.patch.object(AtMostOnceBarrier, "confirm_release", autospec=True) as mock_confirm:
        mock_confirm.side_effect = record_confirm

        # A single-party barrier, so the one processed message completes it.
        barrier_key = "barrier_order"
        AtMostOnceBarrier(self.rate_limiter_backend, barrier_key).create(1)

        # The second tuple slot is a placeholder object (presumably the
        # serialized remaining workflow; unserialize_workflow is patched).
        callbacks = [(barrier_key, object(), True)]
        message = self._make_message({OPTION_KEY_CALLBACKS: callbacks})

        self.middleware.after_process_message(self.broker, message)

    self.assertEqual(events, ["unserialize", "confirm", "run"])
    mock_unserialize.assert_called_once()
    mock_confirm.assert_called_once()
    mock_workflow_with_callbacks.assert_called_once()