diff --git a/twilight-standby/src/future.rs b/twilight-standby/src/future.rs index 2e6e418f72..b28e9b8008 100644 --- a/twilight-standby/src/future.rs +++ b/twilight-standby/src/future.rs @@ -12,7 +12,7 @@ use std::{ }; use tokio::sync::{mpsc, oneshot}; -/// Future canceled due to Standby being dropped. +/// Future canceled due to Standby being dropped or shutdown. #[derive(Debug)] pub struct Canceled(()); diff --git a/twilight-standby/src/lib.rs b/twilight-standby/src/lib.rs index 7087ffa5d8..4fcdba37d6 100644 --- a/twilight-standby/src/lib.rs +++ b/twilight-standby/src/lib.rs @@ -19,7 +19,7 @@ use dashmap::DashMap; use std::{ fmt, hash::Hash, - sync::atomic::{AtomicUsize, Ordering}, + sync::atomic::{AtomicBool, AtomicUsize, Ordering}, }; use tokio::sync::{mpsc, oneshot}; use twilight_model::{ @@ -125,6 +125,8 @@ pub struct Standby { /// List of reaction bystanders where the ID of the message is known /// beforehand. reactions: DashMap, Vec>>, + /// Whether the standby is shutdown or not. + shutdown: AtomicBool, } impl Standby { @@ -186,6 +188,47 @@ impl Standby { completions } + /// Returns whether or not the bystander is shutdown. + pub fn is_shutdown(&self) -> bool { + self.shutdown.load(Ordering::Relaxed) + } + + /// Cancels all this instance's [`WaitForFuture`] and [`WaitForStream`]. + /// + /// # Example + /// + /// ```no_run + /// # use twilight_gateway::{Intents, Shard, ShardId}; + /// # use twilight_standby::Standby; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty()); + /// # let standby = Standby::new(); + /// use tokio_stream::StreamExt as _; + /// use twilight_gateway::{CloseFrame, Message}; + /// + /// shard.close(CloseFrame::NORMAL); + /// + /// while let Some(item) = shard.next().await { + /// match item { + /// Ok(Message::Close(_)) => break, + /// Ok(Message::Text(_)) => unimplemented!(), + /// Err(source) => unimplemented!(), + /// } + /// } + /// + /// // Cancel event handlers waiting for new events. + /// standby.shutdown(); + /// + /// unimplemented!("await all event handlers"); + /// + /// # Ok(()) } + /// ``` + pub fn shutdown(&self) { + self.shutdown.store(true, Ordering::Relaxed); + self.clear(); + } + /// Wait for an event in a certain guild. /// /// To wait for multiple guild events matching the given predicate use @@ -217,7 +260,7 @@ impl Standby { /// # Errors /// /// The returned future resolves to a [`Canceled`] error if the associated - /// [`Standby`] instance is dropped. + /// [`Standby`] instance is dropped or shutdown. /// /// [`BanAdd`]: twilight_model::gateway::payload::incoming::BanAdd /// [`wait_for_stream`]: Self::wait_for_stream @@ -226,7 +269,7 @@ impl Standby { guild_id: Id, check: F, ) -> WaitForFuture { - Self::wait_for_inner(&self.guilds, guild_id, Box::new(check)) + self.wait_for_inner(&self.guilds, guild_id, Box::new(check)) } /// Wait for a stream of events in a certain guild. @@ -265,7 +308,7 @@ impl Standby { /// # Errors /// /// The returned stream ends when the associated [`Standby`] instance is - /// dropped. + /// dropped or shutdown. /// /// [`BanAdd`]: twilight_model::gateway::payload::incoming::BanAdd /// [`wait_for`]: Self::wait_for @@ -274,7 +317,7 @@ impl Standby { guild_id: Id, check: F, ) -> WaitForStream { - Self::wait_for_stream_inner(&self.guilds, guild_id, Box::new(check)) + self.wait_for_stream_inner(&self.guilds, guild_id, Box::new(check)) } /// Wait for an event not in a certain guild. This must be filtered by an @@ -306,7 +349,7 @@ impl Standby { /// # Errors /// /// The returned future resolves to a [`Canceled`] error if the associated - /// [`Standby`] instance is dropped. + /// [`Standby`] instance is dropped or shutdown. /// /// [`Ready`]: twilight_model::gateway::payload::incoming::Ready /// [`wait_for_event_stream`]: Self::wait_for_event_stream @@ -349,7 +392,7 @@ impl Standby { /// # Errors /// /// The returned stream ends when the associated [`Standby`] instance is - /// dropped. + /// dropped or shutdown. /// /// [`Ready`]: twilight_model::gateway::payload::incoming::Ready /// [`wait_for_event`]: Self::wait_for_event @@ -391,7 +434,7 @@ impl Standby { /// # Errors /// /// The returned future resolves to a [`Canceled`] error if the associated - /// [`Standby`] instance is dropped. + /// [`Standby`] instance is dropped or shutdown. /// /// [`wait_for_message_stream`]: Self::wait_for_message_stream pub fn wait_for_message bool + Send + Sync + 'static>( @@ -399,7 +442,7 @@ impl Standby { channel_id: Id, check: F, ) -> WaitForFuture { - Self::wait_for_inner(&self.messages, channel_id, Box::new(check)) + self.wait_for_inner(&self.messages, channel_id, Box::new(check)) } /// Wait for a stream of message in a certain channel. @@ -437,7 +480,7 @@ impl Standby { /// # Errors /// /// The returned stream ends when the associated [`Standby`] instance is - /// dropped. + /// dropped or shutdown. /// /// [`wait_for_message`]: Self::wait_for_message pub fn wait_for_message_stream bool + Send + Sync + 'static>( @@ -445,7 +488,7 @@ impl Standby { channel_id: Id, check: F, ) -> WaitForStream { - Self::wait_for_stream_inner(&self.messages, channel_id, Box::new(check)) + self.wait_for_stream_inner(&self.messages, channel_id, Box::new(check)) } /// Wait for a reaction on a certain message. @@ -477,7 +520,7 @@ impl Standby { /// # Errors /// /// The returned future resolves to a [`Canceled`] error if the associated - /// [`Standby`] instance is dropped. + /// [`Standby`] instance is dropped or shutdown. /// /// [`wait_for_reaction_stream`]: Self::wait_for_reaction_stream pub fn wait_for_reaction bool + Send + Sync + 'static>( @@ -485,7 +528,7 @@ impl Standby { message_id: Id, check: F, ) -> WaitForFuture { - Self::wait_for_inner(&self.reactions, message_id, Box::new(check)) + self.wait_for_inner(&self.reactions, message_id, Box::new(check)) } /// Wait for a stream of reactions on a certain message. @@ -525,7 +568,7 @@ impl Standby { /// # Errors /// /// The returned stream ends when the associated [`Standby`] instance is - /// dropped. + /// dropped or shutdown. /// /// [`wait_for_reaction`]: Self::wait_for_reaction pub fn wait_for_reaction_stream bool + Send + Sync + 'static>( @@ -533,7 +576,7 @@ impl Standby { message_id: Id, check: F, ) -> WaitForStream { - Self::wait_for_stream_inner(&self.reactions, message_id, Box::new(check)) + self.wait_for_stream_inner(&self.reactions, message_id, Box::new(check)) } /// Wait for a component on a certain message. @@ -570,7 +613,7 @@ impl Standby { message_id: Id, check: F, ) -> WaitForFuture { - Self::wait_for_inner(&self.components, message_id, Box::new(check)) + self.wait_for_inner(&self.components, message_id, Box::new(check)) } /// Wait for a stream of components on a certain message. @@ -608,7 +651,7 @@ impl Standby { /// # Errors /// /// The returned stream ends when the associated [`Standby`] instance is - /// dropped. + /// dropped or shutdown. /// /// [`wait_for_component`]: Self::wait_for_component pub fn wait_for_component_stream bool + Send + Sync + 'static>( @@ -616,7 +659,16 @@ impl Standby { message_id: Id, check: F, ) -> WaitForStream { - Self::wait_for_stream_inner(&self.components, message_id, Box::new(check)) + self.wait_for_stream_inner(&self.components, message_id, Box::new(check)) + } + + /// Clears all bystanders. + fn clear(&self) { + self.components.clear(); + self.events.clear(); + self.guilds.clear(); + self.messages.clear(); + self.reactions.clear(); } /// Next event ID in [`Standby::event_counter`]. @@ -624,18 +676,33 @@ impl Standby { self.event_counter.fetch_add(1, Ordering::Relaxed) } + /// Run a specified action if not shutdown. + /// + /// If shutdown during invocation, `action`'s added bystanders are removed. + fn cancellable(&self, action: impl FnOnce()) { + if !self.is_shutdown() { + action(); + if self.is_shutdown() { + self.clear(); + } + } + } + /// Wait for a `T`. fn wait_for_inner( + &self, map: &DashMap, Vec>>, id: Id, check: Box bool + Send + Sync + 'static>, ) -> WaitForFuture { let (tx, rx) = oneshot::channel(); - let mut entry = map.entry(id).or_default(); - entry.push(Bystander { - func: check, - sender: Some(Sender::Future(tx)), + self.cancellable(|| { + let mut entry = map.entry(id).or_default(); + entry.push(Bystander { + func: check, + sender: Some(Sender::Future(tx)), + }); }); WaitForFuture { rx } @@ -643,16 +710,19 @@ impl Standby { /// Wait for a stream of `T`. fn wait_for_stream_inner( + &self, map: &DashMap, Vec>>, id: Id, check: Box bool + Send + Sync + 'static>, ) -> WaitForStream { let (tx, rx) = mpsc::unbounded_channel(); - let mut entry = map.entry(id).or_default(); - entry.push(Bystander { - func: check, - sender: Some(Sender::Stream(tx)), + self.cancellable(|| { + let mut entry = map.entry(id).or_default(); + entry.push(Bystander { + func: check, + sender: Some(Sender::Stream(tx)), + }); }); WaitForStream { rx } @@ -665,13 +735,15 @@ impl Standby { ) -> WaitForFuture { let (tx, rx) = oneshot::channel(); - self.events.insert( - self.next_event_id(), - Bystander { - func: check, - sender: Some(Sender::Future(tx)), - }, - ); + self.cancellable(|| { + self.events.insert( + self.next_event_id(), + Bystander { + func: check, + sender: Some(Sender::Future(tx)), + }, + ); + }); WaitForFuture { rx } } @@ -683,13 +755,15 @@ impl Standby { ) -> WaitForStream { let (tx, rx) = mpsc::unbounded_channel(); - self.events.insert( - self.next_event_id(), - Bystander { - func: check, - sender: Some(Sender::Stream(tx)), - }, - ); + self.cancellable(|| { + self.events.insert( + self.next_event_id(), + Bystander { + func: check, + sender: Some(Sender::Stream(tx)), + }, + ); + }); WaitForStream { rx } } @@ -1416,4 +1490,16 @@ mod tests { standby.process(&Event::ReactionAdd(Box::new(ReactionAdd(reaction())))); assert!(matches!(wait.await, Ok(Event::ReactionAdd(_)))); } + + #[tokio::test] + async fn test_shutdown() { + let standby = Standby::new(); + + let wait = standby.wait_for_event(|event| event.kind() == EventType::MessageCreate); + standby.shutdown(); + assert!(wait.await.is_err()); + + let wait = standby.wait_for_event(|event| event.kind() == EventType::MessageCreate); + assert!(wait.await.is_err()); + } }