Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ collector = ["gateway"]
# Enables the Framework trait which is an abstraction for old-style text commands.
framework = ["gateway"]
# Enables gateway support, which allows bots to listen for Discord events.
gateway = ["model", "flate2"]
gateway = ["model", "flate2", "dashmap"]
# Enables HTTP, which enables bots to execute actions on Discord.
http = ["dashmap", "mime_guess", "percent-encoding"]
# Enables wrapper methods around HTTP requests on model types.
Expand Down
22 changes: 9 additions & 13 deletions examples/e07_shard_manager/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
//!
//! This isn't particularly useful for small bots, but is useful for large bots that may need to
//! split load on separate VPSs or dedicated servers. Additionally, Discord requires that there be
//! at least one shard for every
//! 2500 guilds that a bot is on.
//! at least one shard for every 2500 guilds that a bot is on.
//!
//! For the purposes of this example, we'll print the current statuses of the two shards to the
//! terminal every 30 seconds. This includes the ID of the shard, the current connection stage,
Expand Down Expand Up @@ -60,22 +59,19 @@ async fn main() {
let mut client =
Client::builder(token, intents).event_handler(Handler).await.expect("Err creating client");

// Here we get a HashMap of of the shards' status that we move into a new thread. A separate
// tokio task holds the ownership to each entry, so each one will require acquiring a lock
// before reading.
let runners = client.shard_manager.runner_info();
// Here we get a DashMap of of the shards' status that we move into a new thread.
let runners = client.shard_manager.runners.clone();

tokio::spawn(async move {
loop {
sleep(Duration::from_secs(30)).await;

for (id, runner) in &runners {
if let Ok(runner) = runner.lock() {
println!(
"Shard ID {} is {} with a latency of {:?}",
id, runner.stage, runner.latency,
);
}
for entry in runners.iter() {
let (id, (runner, _)) = entry.pair();
println!(
"Shard ID {} is {} with a latency of {:?}",
id, runner.stage, runner.latency,
);
}
}
});
Expand Down
6 changes: 4 additions & 2 deletions src/gateway/client/context.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::sync::{Arc, Mutex};
use std::sync::Arc;

use dashmap::DashMap;
use futures::channel::mpsc::UnboundedSender as Sender;

#[cfg(feature = "cache")]
Expand Down Expand Up @@ -46,7 +47,8 @@ pub struct Context {
pub http: Arc<Http>,
#[cfg(feature = "cache")]
pub cache: Arc<Cache>,
pub runner_info: Arc<Mutex<ShardRunnerInfo>>,
/// Metadata about the initialised shards, and their control channels.
pub runners: Arc<DashMap<ShardId, (ShardRunnerInfo, Sender<ShardRunnerMessage>)>>,
#[cfg(feature = "collector")]
pub(crate) collectors: Arc<parking_lot::RwLock<Vec<CollectorCallback>>>,
}
Expand Down
39 changes: 15 additions & 24 deletions src/gateway/sharding/shard_manager.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::collections::HashMap;
use std::num::NonZeroU16;
use std::sync::Arc;
#[cfg(feature = "framework")]
use std::sync::OnceLock;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

use dashmap::DashMap;
use futures::StreamExt;
use futures::channel::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender};
use tokio::time::{sleep, timeout};
Expand Down Expand Up @@ -67,7 +67,7 @@ pub struct ShardManager {
///
/// **Note**: It is highly recommended to not mutate this yourself unless you need to. Instead
/// prefer to use methods on this struct that are provided where possible.
pub runners: HashMap<ShardId, (Arc<Mutex<ShardRunnerInfo>>, Sender<ShardRunnerMessage>)>,
pub runners: Arc<DashMap<ShardId, (ShardRunnerInfo, Sender<ShardRunnerMessage>)>>,
/// A copy of the client's voice manager.
#[cfg(feature = "voice")]
pub voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
Expand Down Expand Up @@ -103,7 +103,7 @@ impl ShardManager {
framework: opt.framework,
last_start: None,
queue: ShardQueue::new(opt.max_concurrency),
runners: HashMap::new(),
runners: Arc::new(DashMap::new()),
#[cfg(feature = "voice")]
voice_manager: opt.voice_manager,
ws_url: opt.ws_url,
Expand Down Expand Up @@ -187,7 +187,7 @@ impl ShardManager {
pub fn restart(&mut self, shard_id: ShardId) {
info!("Restarting shard {shard_id}");

if let Some((_, tx)) = self.runners.remove(&shard_id) {
if let Some((_, (_, tx))) = self.runners.remove(&shard_id) {
if let Err(why) = tx.unbounded_send(ShardRunnerMessage::Restart) {
warn!("Failed to send restart signal to shard {shard_id}: {why:?}");
}
Expand All @@ -203,7 +203,7 @@ impl ShardManager {
pub fn shutdown(&mut self, shard_id: ShardId, code: u16) {
info!("Shutting down shard {}", shard_id);

if let Some((_, tx)) = self.runners.remove(&shard_id) {
if let Some((_, (_, tx))) = self.runners.remove(&shard_id) {
if let Err(why) = tx.unbounded_send(ShardRunnerMessage::Shutdown(code)) {
warn!("Failed to send shutdown signal to shard {shard_id}: {why:?}");
}
Expand Down Expand Up @@ -263,18 +263,13 @@ impl ShardManager {
let cloned_http = Arc::clone(&self.http);
shard.set_application_id_callback(move |id| cloned_http.set_application_id(id));

let runner_info = Arc::new(Mutex::new(ShardRunnerInfo {
latency: None,
stage: ConnectionStage::Disconnected,
}));

let mut runner = ShardRunner::new(ShardRunnerOptions {
data: Arc::clone(&self.data),
event_handler: self.event_handler.clone(),
raw_event_handler: self.raw_event_handler.clone(),
#[cfg(feature = "framework")]
framework: self.framework.get().cloned(),
runner_info: Arc::clone(&runner_info),
runners: Arc::clone(&self.runners),
manager_tx: self.manager_tx.clone(),
#[cfg(feature = "voice")]
voice_manager: self.voice_manager.clone(),
Expand All @@ -284,6 +279,11 @@ impl ShardManager {
http: Arc::clone(&self.http),
});

let runner_info = ShardRunnerInfo {
latency: None,
stage: ConnectionStage::Disconnected,
};

self.runners.insert(shard_id, (runner_info, runner.runner_tx()));

spawn_named("shard_runner::run", async move { runner.run().await });
Expand All @@ -305,17 +305,7 @@ impl ShardManager {
#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
#[must_use]
pub fn shards_instantiated(&self) -> Vec<ShardId> {
self.runners.keys().copied().collect()
}

/// Returns the [`ShardRunnerInfo`] corresponding to each running shard.
///
/// Note that the shard runner also holds a copy of its info, which is why each entry is
/// wrapped in `Arc<Mutex<T>>`.
#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
#[must_use]
pub fn runner_info(&self) -> HashMap<ShardId, Arc<Mutex<ShardRunnerInfo>>> {
self.runners.iter().map(|(&id, (runner, _))| (id, Arc::clone(runner))).collect()
self.runners.iter().map(|entries| *entries.key()).collect()
}

/// Returns the gateway intents used for this gateway connection.
Expand All @@ -334,7 +324,8 @@ impl Drop for ShardManager {
fn drop(&mut self) {
info!("Shutting down all shards");

for (shard_id, (_, tx)) in self.runners.drain() {
for entry in self.runners.iter() {
let (shard_id, (_, tx)) = entry.pair();
info!("Shutting down shard {}", shard_id);
if let Err(why) = tx.unbounded_send(ShardRunnerMessage::Shutdown(1000)) {
warn!("Failed to send shutdown signal to shard {shard_id}: {why:?}");
Expand Down
18 changes: 11 additions & 7 deletions src/gateway/sharding/shard_runner.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::sync::{Arc, Mutex};
use std::sync::Arc;

use dashmap::DashMap;
use dashmap::try_result::TryResult;
use futures::channel::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender};
use tokio_tungstenite::tungstenite;
use tokio_tungstenite::tungstenite::error::Error as TungsteniteError;
Expand Down Expand Up @@ -28,7 +30,7 @@ use crate::model::event::Event;
use crate::model::event::GatewayEvent;
#[cfg(feature = "voice")]
use crate::model::id::ChannelId;
use crate::model::id::GuildId;
use crate::model::id::{GuildId, ShardId};
use crate::model::user::OnlineStatus;

/// A runner for managing a [`Shard`] and its respective WebSocket client.
Expand All @@ -38,7 +40,7 @@ pub struct ShardRunner {
raw_event_handler: Option<Arc<dyn RawEventHandler>>,
#[cfg(feature = "framework")]
framework: Option<Arc<dyn Framework>>,
runner_info: Arc<Mutex<ShardRunnerInfo>>,
runners: Arc<DashMap<ShardId, (ShardRunnerInfo, Sender<ShardRunnerMessage>)>>,
// channel to send messages back to the shard manager
manager_tx: Sender<ShardManagerMessage>,
// channel to receive messages from the shard manager and dispatches
Expand Down Expand Up @@ -66,7 +68,7 @@ impl ShardRunner {
raw_event_handler: opt.raw_event_handler,
#[cfg(feature = "framework")]
framework: opt.framework,
runner_info: opt.runner_info,
runners: opt.runners,
manager_tx: opt.manager_tx,
runner_rx: rx,
runner_tx: tx,
Expand Down Expand Up @@ -458,7 +460,9 @@ impl ShardRunner {

#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
fn update_runner_info(&self) {
if let Ok(mut runner_info) = self.runner_info.try_lock() {
if let TryResult::Present(mut entry) = self.runners.try_get_mut(&self.shard.info.id) {
let (runner_info, _) = entry.value_mut();

runner_info.latency = self.shard.latency();
runner_info.stage = self.shard.stage();
}
Expand All @@ -473,7 +477,7 @@ impl ShardRunner {
http: Arc::clone(&self.http),
#[cfg(feature = "cache")]
cache: Arc::clone(&self.cache),
runner_info: Arc::clone(&self.runner_info),
runners: Arc::clone(&self.runners),
#[cfg(feature = "collector")]
collectors: Arc::clone(&self.collectors),
}
Expand All @@ -491,7 +495,7 @@ pub struct ShardRunnerOptions {
pub raw_event_handler: Option<Arc<dyn RawEventHandler>>,
#[cfg(feature = "framework")]
pub framework: Option<Arc<dyn Framework>>,
pub runner_info: Arc<Mutex<ShardRunnerInfo>>,
pub runners: Arc<DashMap<ShardId, (ShardRunnerInfo, Sender<ShardRunnerMessage>)>>,
pub manager_tx: Sender<ShardManagerMessage>,
pub shard: Shard,
#[cfg(feature = "voice")]
Expand Down
Loading