diff --git a/Cargo.toml b/Cargo.toml index 281fc0bf252..4b3aa92bb92 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -98,7 +98,7 @@ collector = ["gateway"] # Enables the Framework trait which is an abstraction for old-style text commands. framework = ["gateway"] # Enables gateway support, which allows bots to listen for Discord events. -gateway = ["model", "flate2"] +gateway = ["model", "flate2", "dashmap"] # Enables HTTP, which enables bots to execute actions on Discord. http = ["dashmap", "mime_guess", "percent-encoding"] # Enables wrapper methods around HTTP requests on model types. diff --git a/examples/e07_shard_manager/src/main.rs b/examples/e07_shard_manager/src/main.rs index 059b396f6e9..7c816782e27 100644 --- a/examples/e07_shard_manager/src/main.rs +++ b/examples/e07_shard_manager/src/main.rs @@ -7,8 +7,7 @@ //! //! This isn't particularly useful for small bots, but is useful for large bots that may need to //! split load on separate VPSs or dedicated servers. Additionally, Discord requires that there be -//! at least one shard for every -//! 2500 guilds that a bot is on. +//! at least one shard for every 2500 guilds that a bot is on. //! //! For the purposes of this example, we'll print the current statuses of the two shards to the //! terminal every 30 seconds. This includes the ID of the shard, the current connection stage, @@ -60,22 +59,19 @@ async fn main() { let mut client = Client::builder(token, intents).event_handler(Handler).await.expect("Err creating client"); - // Here we get a HashMap of of the shards' status that we move into a new thread. A separate - // tokio task holds the ownership to each entry, so each one will require acquiring a lock - // before reading. - let runners = client.shard_manager.runner_info(); + // Here we get a DashMap of of the shards' status that we move into a new thread. + let runners = client.shard_manager.runners.clone(); tokio::spawn(async move { loop { sleep(Duration::from_secs(30)).await; - for (id, runner) in &runners { - if let Ok(runner) = runner.lock() { - println!( - "Shard ID {} is {} with a latency of {:?}", - id, runner.stage, runner.latency, - ); - } + for entry in runners.iter() { + let (id, (runner, _)) = entry.pair(); + println!( + "Shard ID {} is {} with a latency of {:?}", + id, runner.stage, runner.latency, + ); } } }); diff --git a/src/gateway/client/context.rs b/src/gateway/client/context.rs index 6c66cb157b8..cc61bf50e8a 100644 --- a/src/gateway/client/context.rs +++ b/src/gateway/client/context.rs @@ -1,5 +1,6 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Arc; +use dashmap::DashMap; use futures::channel::mpsc::UnboundedSender as Sender; #[cfg(feature = "cache")] @@ -46,7 +47,8 @@ pub struct Context { pub http: Arc, #[cfg(feature = "cache")] pub cache: Arc, - pub runner_info: Arc>, + /// Metadata about the initialised shards, and their control channels. + pub runners: Arc)>>, #[cfg(feature = "collector")] pub(crate) collectors: Arc>>, } diff --git a/src/gateway/sharding/shard_manager.rs b/src/gateway/sharding/shard_manager.rs index 57e8a637bc9..a9c4fc285e3 100644 --- a/src/gateway/sharding/shard_manager.rs +++ b/src/gateway/sharding/shard_manager.rs @@ -1,10 +1,10 @@ -use std::collections::HashMap; use std::num::NonZeroU16; +use std::sync::Arc; #[cfg(feature = "framework")] use std::sync::OnceLock; -use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; +use dashmap::DashMap; use futures::StreamExt; use futures::channel::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender}; use tokio::time::{sleep, timeout}; @@ -67,7 +67,7 @@ pub struct ShardManager { /// /// **Note**: It is highly recommended to not mutate this yourself unless you need to. Instead /// prefer to use methods on this struct that are provided where possible. - pub runners: HashMap>, Sender)>, + pub runners: Arc)>>, /// A copy of the client's voice manager. #[cfg(feature = "voice")] pub voice_manager: Option>, @@ -103,7 +103,7 @@ impl ShardManager { framework: opt.framework, last_start: None, queue: ShardQueue::new(opt.max_concurrency), - runners: HashMap::new(), + runners: Arc::new(DashMap::new()), #[cfg(feature = "voice")] voice_manager: opt.voice_manager, ws_url: opt.ws_url, @@ -187,7 +187,7 @@ impl ShardManager { pub fn restart(&mut self, shard_id: ShardId) { info!("Restarting shard {shard_id}"); - if let Some((_, tx)) = self.runners.remove(&shard_id) { + if let Some((_, (_, tx))) = self.runners.remove(&shard_id) { if let Err(why) = tx.unbounded_send(ShardRunnerMessage::Restart) { warn!("Failed to send restart signal to shard {shard_id}: {why:?}"); } @@ -203,7 +203,7 @@ impl ShardManager { pub fn shutdown(&mut self, shard_id: ShardId, code: u16) { info!("Shutting down shard {}", shard_id); - if let Some((_, tx)) = self.runners.remove(&shard_id) { + if let Some((_, (_, tx))) = self.runners.remove(&shard_id) { if let Err(why) = tx.unbounded_send(ShardRunnerMessage::Shutdown(code)) { warn!("Failed to send shutdown signal to shard {shard_id}: {why:?}"); } @@ -263,18 +263,13 @@ impl ShardManager { let cloned_http = Arc::clone(&self.http); shard.set_application_id_callback(move |id| cloned_http.set_application_id(id)); - let runner_info = Arc::new(Mutex::new(ShardRunnerInfo { - latency: None, - stage: ConnectionStage::Disconnected, - })); - let mut runner = ShardRunner::new(ShardRunnerOptions { data: Arc::clone(&self.data), event_handler: self.event_handler.clone(), raw_event_handler: self.raw_event_handler.clone(), #[cfg(feature = "framework")] framework: self.framework.get().cloned(), - runner_info: Arc::clone(&runner_info), + runners: Arc::clone(&self.runners), manager_tx: self.manager_tx.clone(), #[cfg(feature = "voice")] voice_manager: self.voice_manager.clone(), @@ -284,6 +279,11 @@ impl ShardManager { http: Arc::clone(&self.http), }); + let runner_info = ShardRunnerInfo { + latency: None, + stage: ConnectionStage::Disconnected, + }; + self.runners.insert(shard_id, (runner_info, runner.runner_tx())); spawn_named("shard_runner::run", async move { runner.run().await }); @@ -305,17 +305,7 @@ impl ShardManager { #[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))] #[must_use] pub fn shards_instantiated(&self) -> Vec { - self.runners.keys().copied().collect() - } - - /// Returns the [`ShardRunnerInfo`] corresponding to each running shard. - /// - /// Note that the shard runner also holds a copy of its info, which is why each entry is - /// wrapped in `Arc>`. - #[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))] - #[must_use] - pub fn runner_info(&self) -> HashMap>> { - self.runners.iter().map(|(&id, (runner, _))| (id, Arc::clone(runner))).collect() + self.runners.iter().map(|entries| *entries.key()).collect() } /// Returns the gateway intents used for this gateway connection. @@ -334,7 +324,8 @@ impl Drop for ShardManager { fn drop(&mut self) { info!("Shutting down all shards"); - for (shard_id, (_, tx)) in self.runners.drain() { + for entry in self.runners.iter() { + let (shard_id, (_, tx)) = entry.pair(); info!("Shutting down shard {}", shard_id); if let Err(why) = tx.unbounded_send(ShardRunnerMessage::Shutdown(1000)) { warn!("Failed to send shutdown signal to shard {shard_id}: {why:?}"); diff --git a/src/gateway/sharding/shard_runner.rs b/src/gateway/sharding/shard_runner.rs index 6f5c247591a..393c86d1818 100644 --- a/src/gateway/sharding/shard_runner.rs +++ b/src/gateway/sharding/shard_runner.rs @@ -1,5 +1,7 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Arc; +use dashmap::DashMap; +use dashmap::try_result::TryResult; use futures::channel::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender}; use tokio_tungstenite::tungstenite; use tokio_tungstenite::tungstenite::error::Error as TungsteniteError; @@ -28,7 +30,7 @@ use crate::model::event::Event; use crate::model::event::GatewayEvent; #[cfg(feature = "voice")] use crate::model::id::ChannelId; -use crate::model::id::GuildId; +use crate::model::id::{GuildId, ShardId}; use crate::model::user::OnlineStatus; /// A runner for managing a [`Shard`] and its respective WebSocket client. @@ -38,7 +40,7 @@ pub struct ShardRunner { raw_event_handler: Option>, #[cfg(feature = "framework")] framework: Option>, - runner_info: Arc>, + runners: Arc)>>, // channel to send messages back to the shard manager manager_tx: Sender, // channel to receive messages from the shard manager and dispatches @@ -66,7 +68,7 @@ impl ShardRunner { raw_event_handler: opt.raw_event_handler, #[cfg(feature = "framework")] framework: opt.framework, - runner_info: opt.runner_info, + runners: opt.runners, manager_tx: opt.manager_tx, runner_rx: rx, runner_tx: tx, @@ -458,7 +460,9 @@ impl ShardRunner { #[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))] fn update_runner_info(&self) { - if let Ok(mut runner_info) = self.runner_info.try_lock() { + if let TryResult::Present(mut entry) = self.runners.try_get_mut(&self.shard.info.id) { + let (runner_info, _) = entry.value_mut(); + runner_info.latency = self.shard.latency(); runner_info.stage = self.shard.stage(); } @@ -473,7 +477,7 @@ impl ShardRunner { http: Arc::clone(&self.http), #[cfg(feature = "cache")] cache: Arc::clone(&self.cache), - runner_info: Arc::clone(&self.runner_info), + runners: Arc::clone(&self.runners), #[cfg(feature = "collector")] collectors: Arc::clone(&self.collectors), } @@ -491,7 +495,7 @@ pub struct ShardRunnerOptions { pub raw_event_handler: Option>, #[cfg(feature = "framework")] pub framework: Option>, - pub runner_info: Arc>, + pub runners: Arc)>>, pub manager_tx: Sender, pub shard: Shard, #[cfg(feature = "voice")]