Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 3 additions & 12 deletions examples/gateway-parallel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
//! [`ShardMessageStream`]: twilight_gateway::stream::ShardMessageStream

use futures_util::{future::join_all, StreamExt};
use std::{env, iter, sync::Arc, thread};
use std::{env, iter, thread};
use tokio::{signal, sync::watch, task::JoinSet};
use twilight_gateway::{
queue::LocalQueue,
stream::{self, ShardEventStream},
CloseFrame, Config, Intents, Shard,
};
Expand All @@ -20,23 +19,15 @@ async fn main() -> anyhow::Result<()> {

let token = env::var("DISCORD_TOKEN")?;
let client = Client::new(token.clone());

let queue = Arc::new(LocalQueue::new());
// callback to create a config for each shard, useful for when not all
// shards have the same configuration, such as for per-shard presences
let config_callback = |_| {
Config::builder(token.clone(), Intents::GUILDS)
.queue(queue.clone())
.build()
};
let config = Config::new(token.clone(), Intents::GUILDS);

let tasks = thread::available_parallelism()?.get();

// Split shards into a vec of `tasks` vecs of shards.
let init = iter::repeat_with(Vec::new)
.take(tasks)
.collect::<Vec<Vec<_>>>();
let shards = stream::create_recommended(&client, config_callback)
let shards = stream::create_recommended(&client, config, |_, builder| builder.build())
.await?
.enumerate()
.fold(init, |mut fold, (idx, shard)| {
Expand Down
30 changes: 12 additions & 18 deletions examples/gateway-reshard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ use futures_util::StreamExt;
use std::{env, sync::Arc, time::Duration};
use tokio::time;
use twilight_gateway::{
queue::{LocalQueue, Queue},
stream::{self, ShardEventStream, ShardMessageStream},
Config, Event, Intents, Shard, ShardId,
Config, ConfigBuilder, Event, Intents, Shard, ShardId,
};
use twilight_http::Client;

Expand All @@ -15,19 +14,13 @@ async fn main() -> anyhow::Result<()> {

let token = env::var("DISCORD_TOKEN")?;
let client = Arc::new(Client::new(token.clone()));
let queue: Arc<dyn Queue> = Arc::new(LocalQueue::new());

let config_callback = |_| {
// A queue must be specified in the builder for the shards to reuse the
// same one, which is necessary to not hit any gateway queue ratelimit.
Config::builder(
token.clone(),
Intents::GUILD_MESSAGES | Intents::MESSAGE_CONTENT,
)
.queue(Arc::clone(&queue))
.build()
};
let mut shards = stream::create_recommended(&client, &config_callback)
let config = Config::new(
token.clone(),
Intents::GUILD_MESSAGES | Intents::MESSAGE_CONTENT,
);
let config_callback = |_, builder: ConfigBuilder| builder.build();

let mut shards = stream::create_recommended(&client, config.clone(), &config_callback)
.await?
.collect::<Vec<_>>();

Expand All @@ -40,7 +33,7 @@ async fn main() -> anyhow::Result<()> {
_ = gateway_runner(Arc::clone(&client), shards) => break,
// Resharding complete! Time to run `gateway_runner` with the new
// list of shards.
Ok(Some(new_shards)) = reshard(&client, config_callback) => {
Ok(Some(new_shards)) = reshard(&client, config.clone(), config_callback) => {
// Assign the new list of shards to `shards`, dropping the
// old list.
shards = new_shards;
Expand Down Expand Up @@ -94,14 +87,15 @@ async fn event_handler(client: Arc<Client>, event: Event) -> anyhow::Result<()>
#[tracing::instrument(skip_all)]
async fn reshard(
client: &Client,
config_callback: impl Fn(ShardId) -> Config,
config: Config,
config_callback: impl Fn(ShardId, ConfigBuilder) -> Config,
) -> anyhow::Result<Option<Vec<Shard>>> {
const RESHARD_DURATION: Duration = Duration::from_secs(60 * 60 * 8);

// Reshard every eight hours.
time::sleep(RESHARD_DURATION).await;

let mut shards = stream::create_recommended(client, config_callback)
let mut shards = stream::create_recommended(client, config, config_callback)
.await?
.collect::<Vec<_>>();

Expand Down
15 changes: 3 additions & 12 deletions twilight-gateway/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,8 @@ Create the recommended number of shards and stream over their events:

```rust,no_run
use futures::StreamExt;
use std::{collections::HashMap, env, sync::Arc};
use std::env;
use twilight_gateway::{
queue::LocalQueue,
stream::{self, ShardEventStream},
Config, Intents,
};
Expand All @@ -91,17 +90,9 @@ async fn main() -> anyhow::Result<()> {

let token = env::var("DISCORD_TOKEN")?;
let client = Client::new(token.clone());
let config = Config::new(token, Intents::GUILDS);

let queue = Arc::new(LocalQueue::new());
// Callback to create a config for each shard, useful for when not all shards
// have the same configuration, such as for per-shard presences.
let config_callback = |_| {
Config::builder(token.clone(), Intents::GUILDS)
.queue(queue.clone())
.build()
};

let mut shards = stream::create_recommended(&client, config_callback)
let mut shards = stream::create_recommended(&client, config, |_, builder| builder.build())
.await?
.collect::<Vec<_>>();

Expand Down
8 changes: 0 additions & 8 deletions twilight-gateway/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,6 @@ impl Config {
&self.token.inner
}

/// Set the TLS container for the configuration.
///
/// This is necessary for sharing a TLS container across configurations.
#[allow(clippy::missing_const_for_fn)]
pub(crate) fn set_tls(&mut self, tls: TlsContainer) {
self.tls = tls;
}

/// Session information to resume a shard on initialization.
pub(crate) fn take_session(&mut self) -> Option<Session> {
self.session.take()
Expand Down
107 changes: 38 additions & 69 deletions twilight-gateway/src/stream.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
//! Utilities for managing collections of shards.
//!
//! Multiple shards may easily be created at once, with a per shard config
//! created from a `Fn(ShardId) -> Config` closure, with the help of the
//! `create_` set of functions. These functions will also reuse shards' TLS
//! context, something otherwise achieved by cloning an existing [`Config`], but
//! will not by default set a shared [session queue] (see
//! [`ConfigBuilder::queue`]).
//! created from a `Fn(ShardId, ConfigBuilder) -> Config` closure, with the help
//! of the `create_` set of functions. These functions will reuse shards'
//! TLS context and [session queue], something otherwise achieved by cloning an
//! existing [`Config`].
//!
//! # Concurrency
//!
Expand All @@ -32,9 +31,7 @@
//! [gateway-parallel]: https://github.com/twilight-rs/twilight/blob/main/examples/gateway-parallel.rs
//! [session queue]: crate::queue

use crate::{
error::ReceiveMessageError, message::Message, tls::TlsContainer, Config, Shard, ShardId,
};
use crate::{error::ReceiveMessageError, message::Message, Config, ConfigBuilder, Shard, ShardId};
use futures_util::{
future::BoxFuture,
stream::{FuturesUnordered, Stream, StreamExt},
Expand Down Expand Up @@ -112,9 +109,8 @@ pub enum StartRecommendedErrorType {
///
/// ```no_run
/// use futures::StreamExt;
/// use std::{env, sync::Arc};
/// use std::env;
/// use twilight_gateway::{
/// queue::LocalQueue,
/// stream::{self, ShardEventStream},
/// Config, Intents,
/// };
Expand All @@ -124,17 +120,9 @@ pub enum StartRecommendedErrorType {
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let token = env::var("DISCORD_TOKEN")?;
/// let client = Client::new(token.clone());
/// let config = Config::new(token.clone(), Intents::GUILDS);
///
/// let queue = Arc::new(LocalQueue::new());
/// // callback to create a config for each shard, useful for when not all shards
/// // have the same configuration, such as for per-shard presences
/// let config_callback = |_| {
/// Config::builder(token.clone(), Intents::GUILDS)
/// .queue(queue.clone())
/// .build()
/// };
///
/// let mut shards = stream::create_recommended(&client, config_callback)
/// let mut shards = stream::create_recommended(&client, config, |_, builder| builder.build())
/// .await?
/// .collect::<Vec<_>>();
///
Expand Down Expand Up @@ -224,9 +212,8 @@ impl<'a> Stream for ShardEventStream<'a> {
///
/// ```no_run
/// use futures::StreamExt;
/// use std::{env, sync::Arc};
/// use std::env;
/// use twilight_gateway::{
/// queue::LocalQueue,
/// stream::{self, ShardMessageStream},
/// Config, Intents,
/// };
Expand All @@ -236,17 +223,9 @@ impl<'a> Stream for ShardEventStream<'a> {
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let token = env::var("DISCORD_TOKEN")?;
/// let client = Client::new(token.clone());
/// let config = Config::new(token.clone(), Intents::GUILDS);
///
/// let queue = Arc::new(LocalQueue::new());
/// // callback to create a config for each shard, useful for when not all shards
/// // have the same configuration, such as for per-shard presences
/// let config_callback = |_| {
/// Config::builder(token.clone(), Intents::GUILDS)
/// .queue(queue.clone())
/// .build()
/// };
///
/// let mut shards = stream::create_recommended(&client, config_callback)
/// let mut shards = stream::create_recommended(&client, config, |_, builder| builder.build())
/// .await?
/// .collect::<Vec<_>>();
///
Expand Down Expand Up @@ -371,32 +350,26 @@ struct NextItemOutput<'a, Item> {
shard: &'a mut Shard,
}

/// Create a single bucket's worth of shards with provided configuration for
/// each shard.
/// Create a single bucket's worth of shards.
///
/// Passing a primary config is required. Further customization of this config
/// may be performed in the callback.
///
/// # Examples
///
/// Start bucket 2 out of 10 with 100 shards in total and collect them into a
/// list:
///
/// ```no_run
/// use std::{env, sync::Arc};
/// use twilight_gateway::{queue::LocalQueue, stream, Config, Intents};
/// use std::env;
/// use twilight_gateway::{stream, Config, Intents};
///
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let token = env::var("DISCORD_TOKEN")?;
///
/// let queue = Arc::new(LocalQueue::new());
/// // callback to create a config for each shard, useful for when not all shards
/// // have the same configuration, such as for per-shard presences
/// let config_callback = |_| {
/// Config::builder(token.clone(), Intents::GUILDS)
/// .queue(queue.clone())
/// .build()
/// };
///
/// let shards = stream::create_bucket(2, 10, 100, config_callback)
/// let config = Config::new(token.clone(), Intents::GUILDS);
/// let shards = stream::create_bucket(2, 10, 100, config, |_, builder| builder.build())
/// .map(|shard| (shard.id().number(), shard))
/// .collect::<Vec<_>>();
///
Expand All @@ -412,10 +385,11 @@ struct NextItemOutput<'a, Item> {
///
/// Panics if loading TLS certificates fails.
#[track_caller]
pub fn create_bucket<F: Fn(ShardId) -> Config>(
pub fn create_bucket<F: Fn(ShardId, ConfigBuilder) -> Config>(
bucket_id: u64,
concurrency: u64,
total: u64,
config: Config,
per_shard_config: F,
) -> impl Iterator<Item = Shard> {
assert!(bucket_id < total, "bucket id must be less than the total");
Expand All @@ -425,18 +399,19 @@ pub fn create_bucket<F: Fn(ShardId) -> Config>(
);

let concurrency = concurrency.try_into().unwrap();
let tls = TlsContainer::new().unwrap();

(bucket_id..total).step_by(concurrency).map(move |index| {
let id = ShardId::new(index, total);
let mut config = per_shard_config(id);
config.set_tls(tls.clone());
let config = per_shard_config(id, ConfigBuilder::with_config(config.clone()));

Shard::with_config(id, config)
})
}

/// Create a range of shards with provided configuration for each shard.
/// Create a range of shards.
///
/// Passing a primary config is required. Further customization of this config
/// may be performed in the callback.
///
/// # Examples
///
Expand All @@ -450,16 +425,8 @@ pub fn create_bucket<F: Fn(ShardId) -> Config>(
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let token = env::var("DISCORD_TOKEN")?;
///
/// let queue = Arc::new(LocalQueue::new());
/// // callback to create a config for each shard, useful for when not all shards
/// // have the same configuration, such as for per-shard presences
/// let config_callback = |_| {
/// Config::builder(token.clone(), Intents::GUILDS)
/// .queue(queue.clone())
/// .build()
/// };
///
/// let shards = stream::create_range(0..10, 10, config_callback)
/// let config = Config::new(token.clone(), Intents::GUILDS);
/// let shards = stream::create_range(0..10, 10, config, |_, builder| builder.build())
/// .map(|shard| (shard.id().number(), shard))
/// .collect::<HashMap<_, _>>();
///
Expand All @@ -474,25 +441,26 @@ pub fn create_bucket<F: Fn(ShardId) -> Config>(
///
/// Panics if loading TLS certificates fails.
#[track_caller]
pub fn create_range<F: Fn(ShardId) -> Config>(
pub fn create_range<F: Fn(ShardId, ConfigBuilder) -> Config>(
range: impl RangeBounds<u64>,
total: u64,
config: Config,
per_shard_config: F,
) -> impl Iterator<Item = Shard> {
let range = calculate_range(range, total);
let tls = TlsContainer::new().unwrap();

range.map(move |index| {
let id = ShardId::new(index, total);
let mut config = per_shard_config(id);
config.set_tls(tls.clone());
let config = per_shard_config(id, ConfigBuilder::with_config(config.clone()));

Shard::with_config(id, config)
})
}

/// Create a range of shards from Discord's recommendation with configuration
/// for each shard.
/// Create a range of shards from Discord's recommendation.
///
/// Passing a primary config is required. Further customization of this config
/// may be performed in the callback.
///
/// Internally calls [`create_range`] with the values from [`GetGatewayAuthed`].
///
Expand All @@ -510,8 +478,9 @@ pub fn create_range<F: Fn(ShardId) -> Config>(
///
/// [`GetGatewayAuthed`]: twilight_http::request::GetGatewayAuthed
#[cfg(feature = "twilight-http")]
pub async fn create_recommended<F: Fn(ShardId) -> Config>(
pub async fn create_recommended<F: Fn(ShardId, ConfigBuilder) -> Config>(
client: &Client,
config: Config,
per_shard_config: F,
) -> Result<impl Iterator<Item = Shard>, StartRecommendedError> {
let request = client.gateway().authed();
Expand All @@ -527,7 +496,7 @@ pub async fn create_recommended<F: Fn(ShardId) -> Config>(
source: Some(Box::new(source)),
})?;

Ok(create_range(.., info.shards, per_shard_config))
Ok(create_range(.., info.shards, config, per_shard_config))
}

/// Transform any range into a sized range based on the total.
Expand Down