diff --git a/faust/transport/producer.py b/faust/transport/producer.py index 06dc379c7..ab90be8b4 100644 --- a/faust/transport/producer.py +++ b/faust/transport/producer.py @@ -8,7 +8,6 @@ """ import asyncio import time -from asyncio import QueueEmpty from typing import Any, Awaitable, Mapping, Optional, cast from mode import Seconds, Service, get_logger @@ -29,6 +28,7 @@ class ProducerBuffer(Service, ProducerBufferT): def __post_init__(self) -> None: self.pending = asyncio.Queue() + self.message_sent = asyncio.Event() def put(self, fut: FutureMessage) -> None: """Add message to buffer. @@ -50,34 +50,22 @@ async def on_stop(self) -> None: async def flush(self) -> None: """Flush all messages (draining the buffer).""" - get_pending = self.pending.get_nowait - send_pending = self._send_pending - - if self.size: - while True: - try: - msg = get_pending() - except QueueEmpty: - break - else: - await send_pending(msg) + await self.flush_atmost(None) - async def flush_atmost(self, n: int) -> int: + async def flush_atmost(self, max_messages: Optional[int]) -> int: """Flush at most ``n`` messages.""" - get_pending = self.pending.get_nowait - send_pending = self._send_pending - - if self.size: - for i in range(n): - try: - msg = get_pending() - except QueueEmpty: - return i - else: - await send_pending(msg) - return n - else: - return 0 + flushed_messages = 0 + while True: + if self.state != "running" and self.size: + raise RuntimeError("Cannot flush: Producer not Running") + if self.size != 0 and ( + (max_messages is None or flushed_messages < max_messages) + ): + self.message_sent.clear() + await self.message_sent.wait() + flushed_messages += 1 + else: + return flushed_messages async def _send_pending(self, fut: FutureMessage) -> None: await fut.message.channel.publish_message(fut, wait=False) @@ -109,6 +97,7 @@ async def _handle_pending(self) -> None: while not self.should_stop: msg = await get_pending() await send_pending(msg) + self.message_sent.set() @property def size(self) -> int: diff --git a/tests/unit/transport/test_producer.py b/tests/unit/transport/test_producer.py index 416137d1b..c958e2011 100644 --- a/tests/unit/transport/test_producer.py +++ b/tests/unit/transport/test_producer.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from mode.utils.mocks import AsyncMock, Mock, call @@ -57,7 +59,22 @@ async def on_send(fut): @pytest.mark.asyncio async def test_wait_until_ebb(self, *, buf): buf.max_messages = 10 - buf._send_pending = AsyncMock() + + def create_send_pending_mock(max_messages): + sent_messages = 0 + + async def _inner(): + nonlocal sent_messages + if sent_messages < max_messages: + sent_messages += 1 + return + else: + await asyncio.Future() + + return create_send_pending_mock + + buf._send_pending = create_send_pending_mock(10) + await buf.start() self._put(buf, range(20)) assert buf.size == 20 @@ -71,7 +88,22 @@ async def test_wait_until_ebb(self, *, buf): @pytest.mark.asyncio async def test_flush(self, *, buf): - buf._send_pending = AsyncMock() + def create_send_pending_mock(max_messages): + sent_messages = 0 + + async def _inner(): + nonlocal sent_messages + if sent_messages < max_messages: + sent_messages += 1 + return + else: + await asyncio.Future() + + return create_send_pending_mock + + buf._send_pending = create_send_pending_mock(10) + await buf.start() + assert not buf.size await buf.flush() @@ -87,7 +119,19 @@ def _put(self, buf, items): @pytest.mark.asyncio async def test_flush_atmost(self, *, buf): - buf._send_pending = AsyncMock() + def create_send_pending_mock(max_messages): + sent_messages = 0 + + async def _inner(): + nonlocal sent_messages + if sent_messages < max_messages: + sent_messages += 1 + return + else: + await asyncio.Future() + + return create_send_pending_mock + assert await buf.flush_atmost(10) == 0 self._put(buf, range(3))