Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 2 additions & 28 deletions src/gateway/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,35 +392,9 @@ impl IntoFuture for ClientBuilder {
/// [`Event::MessageCreate`]: crate::model::event::Event::MessageCreate
pub struct Client {
data: Arc<dyn std::any::Any + Send + Sync>,
/// A HashMap of all shards instantiated by the Client.
/// The shard manager for the client.
///
/// The key is the shard ID and the value is the shard itself.
///
/// # Examples
///
/// If you call [`client.start_shard(3, 5)`][`Client::start_shard`], this HashMap will only
/// ever contain a single key of `3`, as that's the only Shard the client is responsible for.
///
/// If you call [`client.start_shards(10)`][`Client::start_shards`], this HashMap will contain
/// keys 0 through 9, one for each shard handled by the client.
///
/// Printing the number of shards currently instantiated by the client every 5 seconds:
///
/// ```rust,no_run
/// # use serenity::prelude::*;
/// # use std::time::Duration;
/// #
/// # fn run(client: Client) {
/// tokio::spawn(async move {
/// loop {
/// let count = client.shard_manager.shards_instantiated().len();
/// println!("Shard count instantiated: {}", count);
///
/// tokio::time::sleep(Duration::from_millis(5000)).await;
/// }
/// });
/// # }
/// ```
/// This is the brains, managing shards (websocket connections) and bot lifecycle.
pub shard_manager: ShardManager,
/// The voice manager for the client.
///
Expand Down
61 changes: 10 additions & 51 deletions src/gateway/sharding/shard_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ impl ShardManager {
}
}

/// Retrieves a function which can be used to shut down the ShardManager later.
///
/// This function will return `true` if the ShardManager has successfully been
/// notified to shut down, or false if it has already shut down and been dropped.
pub fn get_shutdown_trigger(&self) -> impl FnOnce() -> bool + Send + use<> {
let manager_tx = self.manager_tx.clone();
move || manager_tx.unbounded_send(ShardManagerMessage::Quit(Ok(()))).is_ok()
}

Comment on lines +121 to +129
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think exposing self.manager_tx is maybe a more flexible solution? This function kinda sticks out, and I don't really like the name.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be more flexible, but it's exposing too much API to public consumers imo, leading to defensive or defective programming inside the currently quite simple ShardManager. This is the bare minimum interface that users (at least I) need, and future refactors/reworks cannot be merged if they cannot replicate this behavior.

/// The main interface for starting the management of shards. Initializes the shards by
/// queueing them for starting, and then listens for [`ShardManagerMessage`]s in a loop.
///
Expand Down Expand Up @@ -161,7 +170,7 @@ impl ShardManager {
/// Note that this queues all shards but does not actually start them. To start the manager's
/// event loop and dispatch [`ShardRunner`]s as they get queued, call [`Self::run`] instead.
#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
pub fn initialize(&mut self, shard_index: u16, shard_init: u16, shard_total: NonZeroU16) {
fn initialize(&mut self, shard_index: u16, shard_init: u16, shard_total: NonZeroU16) {
let shard_to = shard_index + shard_init;

self.shard_total = shard_total;
Expand All @@ -177,39 +186,6 @@ impl ShardManager {
self.queue.push_back(shard_id);
}

/// Restarts a shard runner.
///
/// Sends a shutdown signal to a shard's associated [`ShardRunner`], and then queues an
/// initialization of a new shard runner for the same shard.
///
/// [`ShardRunner`]: super::ShardRunner
#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
pub fn restart(&mut self, shard_id: ShardId) {
info!("Restarting shard {shard_id}");

if let Some((_, (_, tx))) = self.runners.remove(&shard_id) {
if let Err(why) = tx.unbounded_send(ShardRunnerMessage::Restart) {
warn!("Failed to send restart signal to shard {shard_id}: {why:?}");
}
}
}

/// Attempts to shut down the shard runner by Id.
///
/// **Note**: If the receiving end of an mpsc channel - owned by the shard runner - no longer
/// exists, then the shard runner will not know it should shut down. This _should never happen_.
/// It may already be stopped.
#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
pub fn shutdown(&mut self, shard_id: ShardId, code: u16) {
info!("Shutting down shard {}", shard_id);

if let Some((_, (_, tx))) = self.runners.remove(&shard_id) {
if let Err(why) = tx.unbounded_send(ShardRunnerMessage::Shutdown(code)) {
warn!("Failed to send shutdown signal to shard {shard_id}: {why:?}");
}
}
}

// This function assumes that each of the shard ids are bucketed separately according to
// `max_concurrency`. If this assumption is violated, you will likely get ratelimited.
//
Expand Down Expand Up @@ -291,23 +267,6 @@ impl ShardManager {
Ok(())
}

/// Returns whether the shard manager contains an active instance of a shard runner responsible
/// for the given ID.
///
/// If a shard has been queued but has not yet been initiated, then this will return `false`.
#[must_use]
pub fn has(&self, shard_id: ShardId) -> bool {
self.runners.contains_key(&shard_id)
}

/// Returns the [`ShardId`]s of the shards that have been instantiated and currently have a
/// valid [`ShardRunner`].
#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
#[must_use]
pub fn shards_instantiated(&self) -> Vec<ShardId> {
self.runners.iter().map(|entries| *entries.key()).collect()
}

/// Returns the gateway intents used for this gateway connection.
#[must_use]
pub fn intents(&self) -> GatewayIntents {
Expand Down