From 6070ffdbcdddf6675923a4f9276f1ef61e39f183 Mon Sep 17 00:00:00 2001 From: Elena Frank Date: Mon, 5 May 2025 17:41:37 +0200 Subject: [PATCH 1/2] refactor(stream_map): use `Option` to mark exhaused `TaggedStream` --- src/stream_map.rs | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/src/stream_map.rs b/src/stream_map.rs index 75fc4b9..1934b3d 100644 --- a/src/stream_map.rs +++ b/src/stream_map.rs @@ -1,10 +1,9 @@ -use std::mem; use std::pin::Pin; use std::task::{Context, Poll, Waker}; use std::time::Duration; use futures_util::stream::{BoxStream, SelectAll}; -use futures_util::{stream, FutureExt, Stream, StreamExt}; +use futures_util::{FutureExt, Stream, StreamExt}; use crate::{Delay, PushError, Timeout}; @@ -69,11 +68,8 @@ where pub fn remove(&mut self, id: ID) -> Option> { let tagged = self.inner.iter_mut().find(|s| s.key == id)?; - - let inner = mem::replace(&mut tagged.inner.inner, stream::pending().boxed()); - tagged.exhausted = true; // Setting this will emit `None` on the next poll and ensure `SelectAll` cleans up the resources. - - Some(inner) + let inner = tagged.inner.take()?; // `TaggedStream` will emit `None` on the next poll and ensure `SelectAll` cleans up the resources. + Some(inner.inner) } pub fn len(&self) -> usize { @@ -137,17 +133,14 @@ where struct TaggedStream { key: K, - inner: S, - - exhausted: bool, + inner: Option, } impl TaggedStream { fn new(key: K, inner: S) -> Self { Self { key, - inner, - exhausted: false, + inner: Some(inner), } } } @@ -160,15 +153,14 @@ where type Item = (K, Option); fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.exhausted { + let Some(inner) = self.inner.as_mut() else { return Poll::Ready(None); - } + }; - match futures_util::ready!(self.inner.poll_next_unpin(cx)) { + match futures_util::ready!(inner.poll_next_unpin(cx)) { Some(item) => Poll::Ready(Some((self.key.clone(), Some(item)))), None => { - self.exhausted = true; - + self.inner.take(); Poll::Ready(Some((self.key.clone(), None))) } } @@ -237,7 +229,7 @@ mod tests { fn removing_stream() { let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_millis(100)), 1); - let _ = streams.try_push("ID", stream::once(ready(()))); + let _ = streams.try_push("ID", once(ready(()))); { let cancelled_stream = streams.remove("ID"); From e26a4d7ceb7f037e519364dfff6d40210f341a61 Mon Sep 17 00:00:00 2001 From: Elena Frank Date: Mon, 5 May 2025 17:44:48 +0200 Subject: [PATCH 2/2] feat(stream_map): iterable `StreamMap` Refactor existing `StreamMap` into a `StreamMapInterable` that doesn't box the streams and allows iterating over the inner streams. Re-add type `StreamMap` as a wrapper around `StreamMapIterable` that does the old boxing and thus avoids breaking API. --- src/lib.rs | 2 +- src/stream_map.rs | 147 ++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 136 insertions(+), 13 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 235cb54..f5c307a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,7 +9,7 @@ pub use delay::Delay; pub use futures_map::FuturesMap; pub use futures_set::FuturesSet; pub use futures_tuple_set::FuturesTupleSet; -pub use stream_map::StreamMap; +pub use stream_map::{StreamMap, StreamMapIterable}; pub use stream_set::StreamSet; use std::fmt; diff --git a/src/stream_map.rs b/src/stream_map.rs index 1934b3d..f6c24c9 100644 --- a/src/stream_map.rs +++ b/src/stream_map.rs @@ -2,7 +2,7 @@ use std::pin::Pin; use std::task::{Context, Poll, Waker}; use std::time::Duration; -use futures_util::stream::{BoxStream, SelectAll}; +use futures_util::stream::{select_all, BoxStream, SelectAll}; use futures_util::{FutureExt, Stream, StreamExt}; use crate::{Delay, PushError, Timeout}; @@ -10,17 +10,68 @@ use crate::{Delay, PushError, Timeout}; /// Represents a map of [`Stream`]s. /// /// Each stream must finish within the specified time and the map never outgrows its capacity. -pub struct StreamMap { +pub struct StreamMap(StreamMapIterable>); + +impl StreamMap +where + ID: Clone + Unpin, +{ + pub fn new(make_delay: impl Fn() -> Delay + Send + Sync + 'static, capacity: usize) -> Self { + Self(StreamMapIterable::new(make_delay, capacity)) + } +} + +impl StreamMap +where + ID: Clone + PartialEq + Send + Unpin + 'static, + O: Send + 'static, +{ + /// Push a stream into the map. + pub fn try_push(&mut self, id: ID, stream: F) -> Result<(), PushError>> + where + F: Stream + Send + 'static, + { + self.0.try_push(id, stream.boxed()) + } + + pub fn remove(&mut self, id: ID) -> Option> { + self.0.remove(id) + } + + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + #[allow(unknown_lints, clippy::needless_pass_by_ref_mut)] // &mut Context is idiomatic. + pub fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<()> { + self.0.poll_ready_unpin(cx) + } + + pub fn poll_next_unpin( + &mut self, + cx: &mut Context<'_>, + ) -> Poll<(ID, Option>)> { + self.0.poll_next_unpin(cx) + } +} + +/// Iterable variant of [`StreamMap`] without boxed streams. +pub struct StreamMapIterable { make_delay: Box Delay + Send + Sync>, capacity: usize, - inner: SelectAll>>>, + inner: SelectAll>>, empty_waker: Option, full_waker: Option, } -impl StreamMap +impl StreamMapIterable where ID: Clone + Unpin, + F: Stream + Unpin, { pub fn new(make_delay: impl Fn() -> Delay + Send + Sync + 'static, capacity: usize) -> Self { Self { @@ -33,18 +84,18 @@ where } } -impl StreamMap +impl StreamMapIterable where ID: Clone + PartialEq + Send + Unpin + 'static, - O: Send + 'static, + F: Stream + Unpin, { /// Push a stream into the map. - pub fn try_push(&mut self, id: ID, stream: F) -> Result<(), PushError>> + pub fn try_push(&mut self, id: ID, stream: F) -> Result<(), PushError> where - F: Stream + Send + 'static, + F: Stream + Send + 'static, { if self.inner.len() >= self.capacity { - return Err(PushError::BeyondCapacity(stream.boxed())); + return Err(PushError::BeyondCapacity(stream)); } if let Some(waker) = self.empty_waker.take() { @@ -55,7 +106,7 @@ where self.inner.push(TaggedStream::new( id, TimeoutStream { - inner: stream.boxed(), + inner: stream, timeout: (self.make_delay)(), }, )); @@ -66,7 +117,7 @@ where } } - pub fn remove(&mut self, id: ID) -> Option> { + pub fn remove(&mut self, id: ID) -> Option { let tagged = self.inner.iter_mut().find(|s| s.key == id)?; let inner = tagged.inner.take()?; // `TaggedStream` will emit `None` on the next poll and ensure `SelectAll` cleans up the resources. Some(inner.inner) @@ -94,7 +145,7 @@ where pub fn poll_next_unpin( &mut self, cx: &mut Context<'_>, - ) -> Poll<(ID, Option>)> { + ) -> Poll<(ID, Option>)> { match futures_util::ready!(self.inner.poll_next_unpin(cx)) { None => { self.empty_waker = Some(cx.waker().clone()); @@ -109,6 +160,14 @@ where Some((id, None)) => Poll::Ready((id, None)), } } + + pub fn iter(&self) -> Iter { + Iter(self.inner.iter()) + } + + pub fn iter_mut(&mut self) -> IterMut { + IterMut(self.inner.iter_mut()) + } } struct TimeoutStream { @@ -167,6 +226,46 @@ where } } +pub struct Iter<'a, ID: Unpin, F: Unpin>(select_all::Iter<'a, TaggedStream>>); + +impl<'a, ID, F> Iterator for Iter<'a, ID, F> +where + ID: Clone + Unpin, + F: Unpin + Stream, +{ + type Item = (ID, &'a F); + + fn next(&mut self) -> Option { + let next = self.0.next()?; + Some((next.key.clone(), &next.inner.as_ref()?.inner)) + } + + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } +} + +pub struct IterMut<'a, ID: Unpin, F: Unpin>( + select_all::IterMut<'a, TaggedStream>>, +); + +impl<'a, ID, F> Iterator for IterMut<'a, ID, F> +where + ID: Clone + Unpin, + F: Unpin + Stream, +{ + type Item = (ID, &'a mut F); + + fn next(&mut self) -> Option { + let next = self.0.next()?; + Some((next.key.clone(), &mut next.inner.as_mut()?.inner)) + } + + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } +} + #[cfg(all(test, feature = "futures-timer"))] mod tests { use futures::channel::mpsc; @@ -248,6 +347,30 @@ mod tests { ); } + #[test] + fn iterating_streams() { + const N: usize = 10; + let mut streams = + StreamMapIterable::new(|| Delay::futures_timer(Duration::from_millis(100)), N); + let mut sender = Vec::with_capacity(N); + for i in 0..N { + let (tx, rx) = mpsc::channel::<()>(1); + let _ = streams.try_push(i, rx); + sender.push(tx); + } + assert_eq!(streams.iter().count(), N); + for (i, (id, _)) in streams.iter().enumerate() { + let expect_id = N - i - 1; // Reverse order. + assert_eq!(id, expect_id); + } + assert!(!sender.iter().any(|tx| tx.is_closed())); + + for (_, rx) in streams.iter_mut() { + rx.close(); + } + assert!(sender.iter().all(|tx| tx.is_closed())); + } + #[tokio::test] async fn replaced_stream_is_still_registered() { let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_millis(100)), 3);