Skip to content

Commit

Permalink
fix race (out of order) in flushing topics (wrong state is stored in …
Browse files Browse the repository at this point in the history
…changelog) (#112)

* fix race (out of order) in flushing topics

* fix style issues

* fix race in ProducerBuffer.flush: fix tests

since the flush functions now simply
wait for the main producer loop to send the messages, we have to
mock this behaviour in the tests, too

Co-authored-by: Tobias Rauter <[email protected]>
Co-authored-by: Vikram Patki <[email protected]>
  • Loading branch information
3 people authored Feb 27, 2021
1 parent 70e5516 commit a841a0e
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 30 deletions.
43 changes: 16 additions & 27 deletions faust/transport/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
"""
import asyncio
import time
from asyncio import QueueEmpty
from typing import Any, Awaitable, Mapping, Optional, cast

from mode import Seconds, Service, get_logger
Expand All @@ -29,6 +28,7 @@ class ProducerBuffer(Service, ProducerBufferT):

def __post_init__(self) -> None:
self.pending = asyncio.Queue()
self.message_sent = asyncio.Event()

def put(self, fut: FutureMessage) -> None:
"""Add message to buffer.
Expand All @@ -50,34 +50,22 @@ async def on_stop(self) -> None:

async def flush(self) -> None:
"""Flush all messages (draining the buffer)."""
get_pending = self.pending.get_nowait
send_pending = self._send_pending

if self.size:
while True:
try:
msg = get_pending()
except QueueEmpty:
break
else:
await send_pending(msg)
await self.flush_atmost(None)

async def flush_atmost(self, n: int) -> int:
async def flush_atmost(self, max_messages: Optional[int]) -> int:
"""Flush at most ``n`` messages."""
get_pending = self.pending.get_nowait
send_pending = self._send_pending

if self.size:
for i in range(n):
try:
msg = get_pending()
except QueueEmpty:
return i
else:
await send_pending(msg)
return n
else:
return 0
flushed_messages = 0
while True:
if self.state != "running" and self.size:
raise RuntimeError("Cannot flush: Producer not Running")
if self.size != 0 and (
(max_messages is None or flushed_messages < max_messages)
):
self.message_sent.clear()
await self.message_sent.wait()
flushed_messages += 1
else:
return flushed_messages

async def _send_pending(self, fut: FutureMessage) -> None:
await fut.message.channel.publish_message(fut, wait=False)
Expand Down Expand Up @@ -109,6 +97,7 @@ async def _handle_pending(self) -> None:
while not self.should_stop:
msg = await get_pending()
await send_pending(msg)
self.message_sent.set()

@property
def size(self) -> int:
Expand Down
50 changes: 47 additions & 3 deletions tests/unit/transport/test_producer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

import pytest
from mode.utils.mocks import AsyncMock, Mock, call

Expand Down Expand Up @@ -57,7 +59,22 @@ async def on_send(fut):
@pytest.mark.asyncio
async def test_wait_until_ebb(self, *, buf):
buf.max_messages = 10
buf._send_pending = AsyncMock()

def create_send_pending_mock(max_messages):
sent_messages = 0

async def _inner():
nonlocal sent_messages
if sent_messages < max_messages:
sent_messages += 1
return
else:
await asyncio.Future()

return create_send_pending_mock

buf._send_pending = create_send_pending_mock(10)
await buf.start()
self._put(buf, range(20))
assert buf.size == 20

Expand All @@ -71,7 +88,22 @@ async def test_wait_until_ebb(self, *, buf):

@pytest.mark.asyncio
async def test_flush(self, *, buf):
buf._send_pending = AsyncMock()
def create_send_pending_mock(max_messages):
sent_messages = 0

async def _inner():
nonlocal sent_messages
if sent_messages < max_messages:
sent_messages += 1
return
else:
await asyncio.Future()

return create_send_pending_mock

buf._send_pending = create_send_pending_mock(10)
await buf.start()

assert not buf.size
await buf.flush()

Expand All @@ -87,7 +119,19 @@ def _put(self, buf, items):

@pytest.mark.asyncio
async def test_flush_atmost(self, *, buf):
buf._send_pending = AsyncMock()
def create_send_pending_mock(max_messages):
sent_messages = 0

async def _inner():
nonlocal sent_messages
if sent_messages < max_messages:
sent_messages += 1
return
else:
await asyncio.Future()

return create_send_pending_mock

assert await buf.flush_atmost(10) == 0

self._put(buf, range(3))
Expand Down

0 comments on commit a841a0e

Please sign in to comment.