diff --git a/faust/transport/_cython/conductor.pyx b/faust/transport/_cython/conductor.pyx index 68ad89d61..38e6c75b9 100644 --- a/faust/transport/_cython/conductor.pyx +++ b/faust/transport/_cython/conductor.pyx @@ -17,6 +17,9 @@ cdef class ConductorHandler: object acquire_flow_control object consumer object wait_until_producer_ebb + object consumer_on_buffer_full + object consumer_on_buffer_drop + def __init__(self, object conductor, object tp, object channels): self.conductor = conductor @@ -29,6 +32,21 @@ cdef class ConductorHandler: self.acquire_flow_control = self.app.flow_control.acquire self.wait_until_producer_ebb = self.app.producer.buffer.wait_until_ebb self.consumer = self.app.consumer + # We divide `stream_buffer_maxsize` with Queue.pressure_ratio + # find a limit to the number of messages we will buffer + # before considering the buffer to be under high pressure. + # When the buffer is under high pressure, we call + # Consumer.on_buffer_full(tp) to remove this topic partition + # from the fetcher. + # We still accept anything that's currently in the fetcher (it's + # already in memory so we are just moving the data) without blocking, + # but signal the fetcher to stop retrieving any more data for this + # partition. + self.consumer_on_buffer_full = self.app.consumer.on_buffer_full + + # when the buffer drops down to half we re-enable fetching + # from the partition. + self.consumer_on_buffer_drop = self.app.consumer.on_buffer_drop async def __call__(self, object message): cdef: @@ -78,6 +96,18 @@ cdef class ConductorHandler: await chan.put(event) delivered.add(chan) + # callback called when the queue is under high pressure/ + # about to become full. + def on_pressure_high(self) -> None: + self.on_topic_buffer_full(self.tp) + self.consumer_on_buffer_full(self.tp) + + # callback used when pressure drops. + # added to Queue._pending_pressure_drop_callbacks + # when the buffer is under high pressure/full. + def on_pressure_drop(self) -> None: + self.consumer_on_buffer_drop(self.tp) + cdef object _decode(self, object event, object channel, object event_keyid): keyid = channel.key_type, channel.value_type if event_keyid is None or event is None: @@ -96,5 +126,9 @@ cdef class ConductorHandler: full.append((event, channel)) return False else: - queue.put_nowait(event) + queue.put_nowait_enhanced( + value=event, + on_pressure_high=self.on_pressure_high, + on_pressure_drop=self.on_pressure_drop, + ) return True diff --git a/faust/transport/conductor.py b/faust/transport/conductor.py index ba24bee5f..e55ddff81 100644 --- a/faust/transport/conductor.py +++ b/faust/transport/conductor.py @@ -137,7 +137,11 @@ async def on_message(message: Message) -> None: if queue.full(): full.append((event, chan)) continue - queue.put_nowait(event) + queue.put_nowait_enhanced( + event, + on_pressure_high=on_pressure_high, + on_pressure_drop=on_pressure_drop, + ) else: # subsequent channels may have a different # key/value type pair, meaning they all can @@ -153,7 +157,11 @@ async def on_message(message: Message) -> None: if queue.full(): full.append((dest_event, chan)) continue - queue.put_nowait(dest_event) + queue.put_nowait_enhanced( + dest_event, + on_pressure_high=on_pressure_high, + on_pressure_drop=on_pressure_drop, + ) delivered.add(chan) if full: for _, dest_chan in full: @@ -165,6 +173,7 @@ async def on_message(message: Message) -> None: ], return_when=asyncio.ALL_COMPLETED, ) + except KeyDecodeError as exc: remaining = channels - delivered message.ack(app.consumer, n=len(remaining)) diff --git a/faust/transport/consumer.py b/faust/transport/consumer.py index 3337542f5..3c8260a30 100644 --- a/faust/transport/consumer.py +++ b/faust/transport/consumer.py @@ -518,9 +518,11 @@ def _set_active_tps(self, tps: Set[TP]) -> Set[TP]: return xtps def on_buffer_full(self, tp: TP) -> None: - active_partitions = self._get_active_partitions() - active_partitions.discard(tp) - self._buffered_partitions.add(tp) + # do not remove the partition when in recovery + if not self.app.rebalancing: + active_partitions = self._get_active_partitions() + active_partitions.discard(tp) + self._buffered_partitions.add(tp) def on_buffer_drop(self, tp: TP) -> None: buffered_partitions = self._buffered_partitions