Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion twilight-standby/src/future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::{
};
use tokio::sync::{mpsc, oneshot};

/// Future canceled due to Standby being dropped.
/// Future canceled due to Standby being dropped or shutdown.
#[derive(Debug)]
pub struct Canceled(());

Expand Down
166 changes: 126 additions & 40 deletions twilight-standby/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use dashmap::DashMap;
use std::{
fmt,
hash::Hash,
sync::atomic::{AtomicUsize, Ordering},
sync::atomic::{AtomicBool, AtomicUsize, Ordering},
};
use tokio::sync::{mpsc, oneshot};
use twilight_model::{
Expand Down Expand Up @@ -125,6 +125,8 @@ pub struct Standby {
/// List of reaction bystanders where the ID of the message is known
/// beforehand.
reactions: DashMap<Id<MessageMarker>, Vec<Bystander<ReactionAdd>>>,
/// Whether the standby is shutdown or not.
shutdown: AtomicBool,
}

impl Standby {
Expand Down Expand Up @@ -186,6 +188,47 @@ impl Standby {
completions
}

/// Returns whether or not the bystander is shutdown.
pub fn is_shutdown(&self) -> bool {
self.shutdown.load(Ordering::Relaxed)
}

/// Cancels all this instance's [`WaitForFuture`] and [`WaitForStream`].
///
/// # Example
///
/// ```no_run
/// # use twilight_gateway::{Intents, Shard, ShardId};
/// # use twilight_standby::Standby;
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
/// # let standby = Standby::new();
/// use tokio_stream::StreamExt as _;
/// use twilight_gateway::{CloseFrame, Message};
///
/// shard.close(CloseFrame::NORMAL);
///
/// while let Some(item) = shard.next().await {
/// match item {
/// Ok(Message::Close(_)) => break,
/// Ok(Message::Text(_)) => unimplemented!(),
/// Err(source) => unimplemented!(),
/// }
/// }
///
/// // Cancel event handlers waiting for new events.
/// standby.shutdown();
///
/// unimplemented!("await all event handlers");
///
/// # Ok(()) }
/// ```
pub fn shutdown(&self) {
self.shutdown.store(true, Ordering::Relaxed);
self.clear();
}

/// Wait for an event in a certain guild.
///
/// To wait for multiple guild events matching the given predicate use
Expand Down Expand Up @@ -217,7 +260,7 @@ impl Standby {
/// # Errors
///
/// The returned future resolves to a [`Canceled`] error if the associated
/// [`Standby`] instance is dropped.
/// [`Standby`] instance is dropped or shutdown.
///
/// [`BanAdd`]: twilight_model::gateway::payload::incoming::BanAdd
/// [`wait_for_stream`]: Self::wait_for_stream
Expand All @@ -226,7 +269,7 @@ impl Standby {
guild_id: Id<GuildMarker>,
check: F,
) -> WaitForFuture<Event> {
Self::wait_for_inner(&self.guilds, guild_id, Box::new(check))
self.wait_for_inner(&self.guilds, guild_id, Box::new(check))
}

/// Wait for a stream of events in a certain guild.
Expand Down Expand Up @@ -265,7 +308,7 @@ impl Standby {
/// # Errors
///
/// The returned stream ends when the associated [`Standby`] instance is
/// dropped.
/// dropped or shutdown.
///
/// [`BanAdd`]: twilight_model::gateway::payload::incoming::BanAdd
/// [`wait_for`]: Self::wait_for
Expand All @@ -274,7 +317,7 @@ impl Standby {
guild_id: Id<GuildMarker>,
check: F,
) -> WaitForStream<Event> {
Self::wait_for_stream_inner(&self.guilds, guild_id, Box::new(check))
self.wait_for_stream_inner(&self.guilds, guild_id, Box::new(check))
}

/// Wait for an event not in a certain guild. This must be filtered by an
Expand Down Expand Up @@ -306,7 +349,7 @@ impl Standby {
/// # Errors
///
/// The returned future resolves to a [`Canceled`] error if the associated
/// [`Standby`] instance is dropped.
/// [`Standby`] instance is dropped or shutdown.
///
/// [`Ready`]: twilight_model::gateway::payload::incoming::Ready
/// [`wait_for_event_stream`]: Self::wait_for_event_stream
Expand Down Expand Up @@ -349,7 +392,7 @@ impl Standby {
/// # Errors
///
/// The returned stream ends when the associated [`Standby`] instance is
/// dropped.
/// dropped or shutdown.
///
/// [`Ready`]: twilight_model::gateway::payload::incoming::Ready
/// [`wait_for_event`]: Self::wait_for_event
Expand Down Expand Up @@ -391,15 +434,15 @@ impl Standby {
/// # Errors
///
/// The returned future resolves to a [`Canceled`] error if the associated
/// [`Standby`] instance is dropped.
/// [`Standby`] instance is dropped or shutdown.
///
/// [`wait_for_message_stream`]: Self::wait_for_message_stream
pub fn wait_for_message<F: Fn(&MessageCreate) -> bool + Send + Sync + 'static>(
&self,
channel_id: Id<ChannelMarker>,
check: F,
) -> WaitForFuture<MessageCreate> {
Self::wait_for_inner(&self.messages, channel_id, Box::new(check))
self.wait_for_inner(&self.messages, channel_id, Box::new(check))
}

/// Wait for a stream of message in a certain channel.
Expand Down Expand Up @@ -437,15 +480,15 @@ impl Standby {
/// # Errors
///
/// The returned stream ends when the associated [`Standby`] instance is
/// dropped.
/// dropped or shutdown.
///
/// [`wait_for_message`]: Self::wait_for_message
pub fn wait_for_message_stream<F: Fn(&MessageCreate) -> bool + Send + Sync + 'static>(
&self,
channel_id: Id<ChannelMarker>,
check: F,
) -> WaitForStream<MessageCreate> {
Self::wait_for_stream_inner(&self.messages, channel_id, Box::new(check))
self.wait_for_stream_inner(&self.messages, channel_id, Box::new(check))
}

/// Wait for a reaction on a certain message.
Expand Down Expand Up @@ -477,15 +520,15 @@ impl Standby {
/// # Errors
///
/// The returned future resolves to a [`Canceled`] error if the associated
/// [`Standby`] instance is dropped.
/// [`Standby`] instance is dropped or shutdown.
///
/// [`wait_for_reaction_stream`]: Self::wait_for_reaction_stream
pub fn wait_for_reaction<F: Fn(&ReactionAdd) -> bool + Send + Sync + 'static>(
&self,
message_id: Id<MessageMarker>,
check: F,
) -> WaitForFuture<ReactionAdd> {
Self::wait_for_inner(&self.reactions, message_id, Box::new(check))
self.wait_for_inner(&self.reactions, message_id, Box::new(check))
}

/// Wait for a stream of reactions on a certain message.
Expand Down Expand Up @@ -525,15 +568,15 @@ impl Standby {
/// # Errors
///
/// The returned stream ends when the associated [`Standby`] instance is
/// dropped.
/// dropped or shutdown.
///
/// [`wait_for_reaction`]: Self::wait_for_reaction
pub fn wait_for_reaction_stream<F: Fn(&ReactionAdd) -> bool + Send + Sync + 'static>(
&self,
message_id: Id<MessageMarker>,
check: F,
) -> WaitForStream<ReactionAdd> {
Self::wait_for_stream_inner(&self.reactions, message_id, Box::new(check))
self.wait_for_stream_inner(&self.reactions, message_id, Box::new(check))
}

/// Wait for a component on a certain message.
Expand Down Expand Up @@ -570,7 +613,7 @@ impl Standby {
message_id: Id<MessageMarker>,
check: F,
) -> WaitForFuture<InteractionCreate> {
Self::wait_for_inner(&self.components, message_id, Box::new(check))
self.wait_for_inner(&self.components, message_id, Box::new(check))
}

/// Wait for a stream of components on a certain message.
Expand Down Expand Up @@ -608,51 +651,78 @@ impl Standby {
/// # Errors
///
/// The returned stream ends when the associated [`Standby`] instance is
/// dropped.
/// dropped or shutdown.
///
/// [`wait_for_component`]: Self::wait_for_component
pub fn wait_for_component_stream<F: Fn(&InteractionCreate) -> bool + Send + Sync + 'static>(
&self,
message_id: Id<MessageMarker>,
check: F,
) -> WaitForStream<InteractionCreate> {
Self::wait_for_stream_inner(&self.components, message_id, Box::new(check))
self.wait_for_stream_inner(&self.components, message_id, Box::new(check))
}

/// Clears all bystanders.
fn clear(&self) {
self.components.clear();
self.events.clear();
self.guilds.clear();
self.messages.clear();
self.reactions.clear();
}

/// Next event ID in [`Standby::event_counter`].
fn next_event_id(&self) -> usize {
self.event_counter.fetch_add(1, Ordering::Relaxed)
}

/// Run a specified action if not shutdown.
///
/// If shutdown during invocation, `action`'s added bystanders are removed.
fn cancellable(&self, action: impl FnOnce()) {
if !self.is_shutdown() {
action();
if self.is_shutdown() {
self.clear();
}
}
}

/// Wait for a `T`.
fn wait_for_inner<IdKind, T>(
&self,
map: &DashMap<Id<IdKind>, Vec<Bystander<T>>>,
id: Id<IdKind>,
check: Box<dyn Fn(&T) -> bool + Send + Sync + 'static>,
) -> WaitForFuture<T> {
let (tx, rx) = oneshot::channel();

let mut entry = map.entry(id).or_default();
entry.push(Bystander {
func: check,
sender: Some(Sender::Future(tx)),
self.cancellable(|| {
let mut entry = map.entry(id).or_default();
entry.push(Bystander {
func: check,
sender: Some(Sender::Future(tx)),
});
});

WaitForFuture { rx }
}

/// Wait for a stream of `T`.
fn wait_for_stream_inner<IdKind, T>(
&self,
map: &DashMap<Id<IdKind>, Vec<Bystander<T>>>,
id: Id<IdKind>,
check: Box<dyn Fn(&T) -> bool + Send + Sync + 'static>,
) -> WaitForStream<T> {
let (tx, rx) = mpsc::unbounded_channel();

let mut entry = map.entry(id).or_default();
entry.push(Bystander {
func: check,
sender: Some(Sender::Stream(tx)),
self.cancellable(|| {
let mut entry = map.entry(id).or_default();
entry.push(Bystander {
func: check,
sender: Some(Sender::Stream(tx)),
});
});

WaitForStream { rx }
Expand All @@ -665,13 +735,15 @@ impl Standby {
) -> WaitForFuture<Event> {
let (tx, rx) = oneshot::channel();

self.events.insert(
self.next_event_id(),
Bystander {
func: check,
sender: Some(Sender::Future(tx)),
},
);
self.cancellable(|| {
self.events.insert(
self.next_event_id(),
Bystander {
func: check,
sender: Some(Sender::Future(tx)),
},
);
});

WaitForFuture { rx }
}
Expand All @@ -683,13 +755,15 @@ impl Standby {
) -> WaitForStream<Event> {
let (tx, rx) = mpsc::unbounded_channel();

self.events.insert(
self.next_event_id(),
Bystander {
func: check,
sender: Some(Sender::Stream(tx)),
},
);
self.cancellable(|| {
self.events.insert(
self.next_event_id(),
Bystander {
func: check,
sender: Some(Sender::Stream(tx)),
},
);
});

WaitForStream { rx }
}
Expand Down Expand Up @@ -1416,4 +1490,16 @@ mod tests {
standby.process(&Event::ReactionAdd(Box::new(ReactionAdd(reaction()))));
assert!(matches!(wait.await, Ok(Event::ReactionAdd(_))));
}

#[tokio::test]
async fn test_shutdown() {
let standby = Standby::new();

let wait = standby.wait_for_event(|event| event.kind() == EventType::MessageCreate);
standby.shutdown();
assert!(wait.await.is_err());

let wait = standby.wait_for_event(|event| event.kind() == EventType::MessageCreate);
assert!(wait.await.is_err());
}
}