Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ framework = ["client", "model", "utils"]
# Enables gateway support, which allows bots to listen for Discord events.
gateway = ["flate2"]
# Enables HTTP, which enables bots to execute actions on Discord.
http = ["mime_guess", "percent-encoding"]
http = ["dashmap", "mime_guess", "percent-encoding"]
# Enables wrapper methods around HTTP requests on model types.
# Requires "builder" to configure the requests and "http" to execute them.
# Note: the model type definitions themselves are always active, regardless of this feature.
Expand Down
97 changes: 54 additions & 43 deletions src/http/ratelimiting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@
//!
//! [Taken from]: https://discord.com/developers/docs/topics/rate-limits#rate-limits

use std::collections::HashMap;
use std::borrow::Cow;
use std::fmt;
use std::str::{self, FromStr};
use std::sync::Arc;
use std::time::SystemTime;

use dashmap::DashMap;
use reqwest::header::HeaderMap;
use reqwest::{Client, Response, StatusCode};
use secrecy::{ExposeSecret, SecretString};
use tokio::sync::{Mutex, RwLock};
use tokio::sync::Mutex;
use tokio::time::{sleep, Duration};
use tracing::debug;

Expand All @@ -60,7 +60,7 @@ pub struct RatelimitInfo {
pub timeout: std::time::Duration,
pub limit: i64,
pub method: LightMethod,
pub path: String,
pub path: Cow<'static, str>,
pub global: bool,
}

Expand All @@ -83,10 +83,8 @@ pub struct RatelimitInfo {
/// [`reset`]: Ratelimit::reset
pub struct Ratelimiter {
client: Client,
global: Arc<Mutex<()>>,
// When futures is implemented, make tasks clear out their respective entry when the 'reset'
// passes.
routes: Arc<RwLock<HashMap<RatelimitingBucket, Arc<Mutex<Ratelimit>>>>>,
global: Mutex<()>,
routes: DashMap<RatelimitingBucket, Ratelimit>,
token: SecretString,
absolute_ratelimits: bool,
ratelimit_callback: Box<dyn Fn(RatelimitInfo) + Send + Sync>,
Expand Down Expand Up @@ -117,8 +115,8 @@ impl Ratelimiter {
fn _new(client: Client, token: String) -> Self {
Self {
client,
global: Arc::default(),
routes: Arc::default(),
global: Mutex::default(),
routes: DashMap::new(),
token: SecretString::new(token),
ratelimit_callback: Box::new(|_| {}),
absolute_ratelimits: false,
Expand Down Expand Up @@ -156,23 +154,22 @@ impl Ratelimiter {
/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
/// # let http: Http = unimplemented!();
/// let routes = http.ratelimiter.unwrap().routes();
/// let reader = routes.read().await;
///
/// let channel_id = ChannelId::new(7);
/// let route = Route::Channel {
/// channel_id,
/// };
/// if let Some(route) = reader.get(&route.ratelimiting_bucket()) {
/// if let Some(reset) = route.lock().await.reset() {
/// if let Some(route) = routes.get(&route.ratelimiting_bucket()) {
/// if let Some(reset) = route.reset() {
/// println!("Reset time at: {:?}", reset);
/// }
/// }
/// # Ok(())
/// # }
/// ```
#[must_use]
pub fn routes(&self) -> Arc<RwLock<HashMap<RatelimitingBucket, Arc<Mutex<Ratelimit>>>>> {
Arc::clone(&self.routes)
pub fn routes(&self) -> &DashMap<RatelimitingBucket, Ratelimit> {
&self.routes
}

/// # Errors
Expand All @@ -191,10 +188,14 @@ impl Ratelimiter {
// - sleep if there is 0 remaining
// - then, perform the request
let ratelimiting_bucket = req.route.ratelimiting_bucket();
let bucket =
Arc::clone(self.routes.write().await.entry(ratelimiting_bucket).or_default());
let delay_time = {
let mut bucket = self.routes.entry(ratelimiting_bucket).or_default();
bucket.pre_hook(&req, &self.ratelimit_callback)
};

bucket.lock().await.pre_hook(&req, &self.ratelimit_callback).await;
if let Some(delay_time) = delay_time {
sleep(delay_time).await;
}

let request = req.clone().build(&self.client, self.token.expose_secret(), None)?;
let response = self.client.execute(request.build()?).await?;
Expand Down Expand Up @@ -230,7 +231,7 @@ impl Ratelimiter {
timeout: Duration::from_secs_f64(retry_after),
limit: 50,
method: req.method,
path: req.route.path().to_string(),
path: req.route.path(),
global: true,
});
sleep(Duration::from_secs_f64(retry_after)).await;
Expand All @@ -241,11 +242,23 @@ impl Ratelimiter {
},
)
} else {
bucket
.lock()
.await
.post_hook(&response, &req, &self.ratelimit_callback, self.absolute_ratelimits)
.await
let delay_time = if let Some(mut bucket) = self.routes.get_mut(&ratelimiting_bucket)
{
bucket.post_hook(
&response,
&req,
&self.ratelimit_callback,
self.absolute_ratelimits,
)
} else {
Ok(None)
};

if let Ok(Some(delay_time)) = delay_time {
sleep(delay_time).await;
};

delay_time.map(|d| d.is_some())
};

if !redo.unwrap_or(true) {
Expand Down Expand Up @@ -277,28 +290,29 @@ pub struct Ratelimit {
}

impl Ratelimit {
#[must_use]
#[cfg_attr(feature = "tracing_instrument", instrument(skip(ratelimit_callback)))]
pub async fn pre_hook(
pub fn pre_hook(
&mut self,
req: &Request<'_>,
ratelimit_callback: &(dyn Fn(RatelimitInfo) + Send + Sync),
) {
) -> Option<std::time::Duration> {
if self.limit() == 0 {
return;
return None;
}

let Some(reset) = self.reset else {
// We're probably in the past.
self.remaining = self.limit;
return;
return None;
};

let Ok(delay) = reset.duration_since(SystemTime::now()) else {
// if duration is negative (i.e. adequate time has passed since last call to this api)
if self.remaining() != 0 {
self.remaining -= 1;
}
return;
return None;
};

if self.remaining() == 0 {
Expand All @@ -311,29 +325,28 @@ impl Ratelimit {
timeout: delay,
limit: self.limit,
method: req.method,
path: req.route.path().to_string(),
path: req.route.path(),
global: false,
});

sleep(delay).await;

return;
Some(delay)
} else {
self.remaining -= 1;
None
}

self.remaining -= 1;
}

/// # Errors
///
/// Errors if unable to parse response headers.
#[cfg_attr(feature = "tracing_instrument", instrument(skip(ratelimit_callback)))]
pub async fn post_hook(
pub fn post_hook(
&mut self,
response: &Response,
req: &Request<'_>,
ratelimit_callback: &(dyn Fn(RatelimitInfo) + Send + Sync),
absolute_ratelimits: bool,
) -> Result<bool> {
) -> Result<Option<Duration>> {
if let Some(limit) = parse_header(response.headers(), "x-ratelimit-limit")? {
self.limit = limit;
}
Expand All @@ -359,7 +372,7 @@ impl Ratelimit {
}

Ok(if response.status() != StatusCode::TOO_MANY_REQUESTS {
false
None
} else if let Some(retry_after) = parse_header::<f64>(response.headers(), "retry-after")? {
debug!(
"Ratelimited on route {:?} for {:?}s",
Expand All @@ -370,15 +383,13 @@ impl Ratelimit {
timeout: Duration::from_secs_f64(retry_after),
limit: self.limit,
method: req.method,
path: req.route.path().to_string(),
path: req.route.path(),
global: false,
});

sleep(Duration::from_secs_f64(retry_after)).await;

true
Some(Duration::from_secs_f64(retry_after))
} else {
false
None
})
}

Expand Down