From 0ea221af9fd5b3155645f58b98cf09a68bef1eae Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 16 Mar 2024 15:54:58 -0700 Subject: [PATCH] refactor reqwest client suite w/ conf items. Signed-off-by: Jason Volk --- src/api/appservice_server.rs | 2 +- src/api/client_server/media.rs | 6 +- src/api/server_server.rs | 5 +- src/config/mod.rs | 70 +++++++++++++ src/database/mod.rs | 2 +- src/service/globals/mod.rs | 177 +++++++++++++++++---------------- src/service/pusher/mod.rs | 2 +- 7 files changed, 168 insertions(+), 96 deletions(-) diff --git a/src/api/appservice_server.rs b/src/api/appservice_server.rs index 3d503a57a..217acf328 100644 --- a/src/api/appservice_server.rs +++ b/src/api/appservice_server.rs @@ -47,7 +47,7 @@ where *reqwest_request.timeout_mut() = Some(Duration::from_secs(120)); let url = reqwest_request.url().clone(); - let mut response = match services().globals.default_client().execute(reqwest_request).await { + let mut response = match services().globals.client.appservice.execute(reqwest_request).await { Ok(r) => r, Err(e) => { warn!( diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index bb98814bc..a36be69a0 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -707,7 +707,7 @@ fn url_request_allowed(addr: &IpAddr) -> bool { } async fn request_url_preview(url: &str) -> Result { - let client = services().globals.url_preview_client(); + let client = &services().globals.client.url_preview; let response = client.head(url).send().await?; if !response.remote_addr().map_or(false, |a| url_request_allowed(&a.ip())) { @@ -722,8 +722,8 @@ async fn request_url_preview(url: &str) -> Result { None => return Err(Error::BadRequest(ErrorKind::Unknown, "Unknown Content-Type")), }; let data = match content_type { - html if html.starts_with("text/html") => download_html(&client, url).await?, - img if img.starts_with("image/") => download_image(&client, url).await?, + html if html.starts_with("text/html") => download_html(client, url).await?, + img if img.starts_with("image/") => download_image(client, url).await?, _ => return Err(Error::BadRequest(ErrorKind::Unknown, "Unsupported Content-Type")), }; diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 471d4e8ca..928713031 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -239,7 +239,7 @@ where let url = reqwest_request.url().clone(); debug!("Sending request to {destination} at {url}"); - let response = services().globals.federation_client().execute(reqwest_request).await; + let response = services().globals.client.federation.execute(reqwest_request).await; debug!("Received response from {destination} at {url}"); match response { @@ -517,7 +517,8 @@ async fn request_well_known(destination: &str) -> Option { let response = services() .globals - .default_client() + .client + .well_known .get(&format!("https://{destination}/.well-known/matrix/server")) .send() .await; diff --git a/src/config/mod.rs b/src/config/mod.rs index e56187661..b922b4832 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -63,6 +63,34 @@ pub struct Config { pub max_concurrent_requests: u16, #[serde(default = "default_max_fetch_prev_events")] pub max_fetch_prev_events: u16, + #[serde(default = "default_request_conn_timeout")] + pub request_conn_timeout: u64, + #[serde(default = "default_request_timeout")] + pub request_timeout: u64, + #[serde(default = "default_request_idle_per_host")] + pub request_idle_per_host: u16, + #[serde(default = "default_request_idle_timeout")] + pub request_idle_timeout: u64, + #[serde(default = "default_well_known_conn_timeout")] + pub well_known_conn_timeout: u64, + #[serde(default = "default_well_known_timeout")] + pub well_known_timeout: u64, + #[serde(default = "default_federation_timeout")] + pub federation_timeout: u64, + #[serde(default = "default_federation_idle_per_host")] + pub federation_idle_per_host: u16, + #[serde(default = "default_federation_idle_timeout")] + pub federation_idle_timeout: u64, + #[serde(default = "default_sender_timeout")] + pub sender_timeout: u64, + #[serde(default = "default_sender_idle_timeout")] + pub sender_idle_timeout: u64, + #[serde(default = "default_appservice_timeout")] + pub appservice_timeout: u64, + #[serde(default = "default_appservice_idle_timeout")] + pub appservice_idle_timeout: u64, + #[serde(default = "default_pusher_idle_timeout")] + pub pusher_idle_timeout: u64, #[serde(default)] pub allow_registration: bool, #[serde(default)] @@ -272,6 +300,20 @@ impl fmt::Display for Config { ("Cleanup interval in seconds", &self.cleanup_second_interval.to_string()), ("Maximum request size (bytes)", &self.max_request_size.to_string()), ("Maximum concurrent requests", &self.max_concurrent_requests.to_string()), + ("Request connect timeout", &self.request_conn_timeout.to_string()), + ("Request timeout", &self.request_timeout.to_string()), + ("Idle connections per host", &self.request_idle_per_host.to_string()), + ("Request pool idle timeout", &self.request_idle_timeout.to_string()), + ("Well_known connect timeout", &self.well_known_conn_timeout.to_string()), + ("Well_known timeout", &self.well_known_timeout.to_string()), + ("Federation timeout", &self.federation_timeout.to_string()), + ("Federation pool idle per host", &self.federation_idle_per_host.to_string()), + ("Federation pool idle timeout", &self.federation_idle_timeout.to_string()), + ("Sender timeout", &self.sender_timeout.to_string()), + ("Sender pool idle timeout", &self.sender_idle_timeout.to_string()), + ("Appservice timeout", &self.appservice_timeout.to_string()), + ("Appservice pool idle timeout", &self.appservice_idle_timeout.to_string()), + ("Pusher pool idle timeout", &self.pusher_idle_timeout.to_string()), ("Allow registration", &self.allow_registration.to_string()), ( "Registration token", @@ -487,6 +529,34 @@ fn default_max_request_size() -> u32 { fn default_max_concurrent_requests() -> u16 { 500 } +fn default_request_conn_timeout() -> u64 { 10 } + +fn default_request_timeout() -> u64 { 35 } + +fn default_request_idle_per_host() -> u16 { 1 } + +fn default_request_idle_timeout() -> u64 { 5 } + +fn default_well_known_conn_timeout() -> u64 { 6 } + +fn default_well_known_timeout() -> u64 { 10 } + +fn default_federation_timeout() -> u64 { 300 } + +fn default_federation_idle_per_host() -> u16 { 1 } + +fn default_federation_idle_timeout() -> u64 { 25 } + +fn default_sender_timeout() -> u64 { 75 } + +fn default_sender_idle_timeout() -> u64 { 50 } + +fn default_appservice_timeout() -> u64 { 120 } + +fn default_appservice_idle_timeout() -> u64 { 300 } + +fn default_pusher_idle_timeout() -> u64 { 15 } + fn default_max_fetch_prev_events() -> u16 { 100_u16 } fn default_trusted_servers() -> Vec { vec![OwnedServerName::try_from("matrix.org").unwrap()] } diff --git a/src/database/mod.rs b/src/database/mod.rs index 55d4ce9c6..0822fb80d 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1058,7 +1058,7 @@ impl KeyValueDatabase { async fn try_handle_updates() -> Result<()> { let response = - services().globals.default_client().get("https://pupbrain.dev/check-for-updates/stable").send().await?; + services().globals.client.default.get("https://pupbrain.dev/check-for-updates/stable").send().await?; let response = serde_json::from_str::(&response.text().await?).map_err(|e| { error!("Bad check for updates response: {e}"); diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 5fe1fd20f..092179a0a 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -23,7 +23,10 @@ use hyper::{ service::Service as HyperService, }; use regex::RegexSet; -use reqwest::dns::{Addrs, Resolve, Resolving}; +use reqwest::{ + dns::{Addrs, Resolve, Resolving}, + redirect, +}; use ruma::{ api::{ client::sync::sync_events, @@ -57,9 +60,7 @@ pub struct Service<'a> { keypair: Arc, dns_resolver: TokioAsyncResolver, jwt_decoding_key: Option, - url_preview_client: reqwest::Client, - federation_client: reqwest::Client, - default_client: reqwest::Client, + pub client: Client, pub stable_room_versions: Vec, pub unstable_room_versions: Vec, pub bad_event_ratelimiter: Arc>>, @@ -78,6 +79,16 @@ pub struct Service<'a> { pub argon: Argon2<'a>, } +pub struct Client { + pub default: reqwest::Client, + pub url_preview: reqwest::Client, + pub well_known: reqwest::Client, + pub federation: reqwest::Client, + pub sender: reqwest::Client, + pub appservice: reqwest::Client, + pub pusher: reqwest::Client, +} + /// Handles "rotation" of long-polling requests. "Rotation" in this context is /// similar to "rotation" of log files and the like. /// @@ -145,6 +156,77 @@ impl Resolve for Resolver { } } +impl Client { + pub fn new(config: &Config, tls_name_override: &Arc>) -> Client { + let resolver = Arc::new(Resolver::new(tls_name_override.clone())); + Client { + default: Self::base(config).unwrap().build().unwrap(), + + url_preview: Self::base(config).unwrap().build().unwrap(), + + well_known: Self::base(config).unwrap() + .dns_resolver(resolver.clone()) + .connect_timeout(Duration::from_secs(config.well_known_conn_timeout)) + .timeout(Duration::from_secs(config.well_known_timeout)) + .pool_max_idle_per_host(0) + .redirect(redirect::Policy::limited(4)) + .build() + .unwrap(), + + federation: Self::base(config).unwrap() + .dns_resolver(resolver.clone()) + .timeout(Duration::from_secs(config.federation_timeout)) + .pool_max_idle_per_host(config.federation_idle_per_host.into()) + .pool_idle_timeout(Duration::from_secs(config.federation_idle_timeout)) + .redirect(redirect::Policy::limited(2)) + .build() + .unwrap(), + + sender: Self::base(config).unwrap() + .dns_resolver(resolver) + .timeout(Duration::from_secs(config.sender_timeout)) + .pool_max_idle_per_host(1) + .pool_idle_timeout(Duration::from_secs(config.sender_idle_timeout)) + .redirect(redirect::Policy::limited(2)) + .build() + .unwrap(), + + appservice: Self::base(config).unwrap() + .connect_timeout(Duration::from_secs(5)) + .timeout(Duration::from_secs(config.appservice_timeout)) + .pool_max_idle_per_host(1) + .pool_idle_timeout(Duration::from_secs(config.appservice_idle_timeout)) + .redirect(redirect::Policy::limited(2)) + .build() + .unwrap(), + + pusher: Self::base(config).unwrap() + .pool_max_idle_per_host(1) + .pool_idle_timeout(Duration::from_secs(config.pusher_idle_timeout)) + .redirect(redirect::Policy::limited(2)) + .build() + .unwrap(), + } + } + + fn base(config: &Config) -> Result { + let builder = reqwest::Client::builder() + .hickory_dns(true) + .timeout(Duration::from_secs(config.request_timeout)) + .connect_timeout(Duration::from_secs(config.request_conn_timeout)) + .pool_max_idle_per_host(config.request_idle_per_host.into()) + .pool_idle_timeout(Duration::from_secs(config.request_idle_timeout)) + .user_agent("Conduwuit".to_owned() + "/" + env!("CARGO_PKG_VERSION")) + .redirect(redirect::Policy::limited(6)); + + if let Some(proxy) = config.proxy.to_proxy()? { + Ok(builder.proxy(proxy)) + } else { + Ok(builder) + } + } +} + impl Service<'_> { pub fn load(db: &'static dyn Data, config: Config) -> Result { let keypair = db.load_keypair(); @@ -163,12 +245,6 @@ impl Service<'_> { let jwt_decoding_key = config.jwt_secret.as_ref().map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes())); - let url_preview_client = url_preview_reqwest_client_builder(&config)?.build()?; - let default_client = reqwest_client_builder(&config)?.build()?; - let federation_client = reqwest_client_builder(&config)? - .dns_resolver(Arc::new(Resolver::new(tls_name_override.clone()))) - .build()?; - // Supported and stable room versions let stable_room_versions = vec![ RoomVersionId::V6, @@ -193,17 +269,15 @@ impl Service<'_> { ); let mut s = Self { db, - config, + config: config.clone(), keypair: Arc::new(keypair), dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|e| { error!("Failed to set up trust dns resolver with system config: {}", e); Error::bad_config("Failed to set up trust dns resolver with system config.") })?, actual_destination_cache: Arc::new(RwLock::new(WellKnownMap::new())), - tls_name_override, - url_preview_client, - federation_client, - default_client, + tls_name_override: tls_name_override.clone(), + client: Client::new(&config, &tls_name_override), jwt_decoding_key, stable_room_versions, unstable_room_versions, @@ -235,26 +309,6 @@ impl Service<'_> { /// Returns this server's keypair. pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair } - /// Returns a reqwest client which can be used to send requests for URL - /// previews This is the same as `default_client()` except a redirect policy - /// of max 2 is set - pub fn url_preview_client(&self) -> reqwest::Client { - // Client is cheap to clone (Arc wrapper) and avoids lifetime issues - self.url_preview_client.clone() - } - - /// Returns a reqwest client which can be used to send requests - pub fn default_client(&self) -> reqwest::Client { - // Client is cheap to clone (Arc wrapper) and avoids lifetime issues - self.default_client.clone() - } - - /// Returns a client used for resolving .well-knowns - pub fn federation_client(&self) -> reqwest::Client { - // Client is cheap to clone (Arc wrapper) and avoids lifetime issues - self.federation_client.clone() - } - #[tracing::instrument(skip(self))] pub fn next_count(&self) -> Result { self.db.next_count() } @@ -488,56 +542,3 @@ impl Service<'_> { services().globals.rotate.fire(); } } - -fn reqwest_client_builder(config: &Config) -> Result { - let redirect_policy = reqwest::redirect::Policy::custom(|attempt| { - if attempt.previous().len() > 6 { - attempt.error("Too many redirects (max is 6)") - } else { - attempt.follow() - } - }); - - let mut reqwest_client_builder = reqwest::Client::builder() - .hickory_dns(true) - .pool_max_idle_per_host(1) - .pool_idle_timeout(Duration::from_secs(50)) - .connect_timeout(Duration::from_secs(60)) - .timeout(Duration::from_secs(60 * 5)) - .redirect(redirect_policy) - .user_agent("Conduwuit".to_owned() + "/" + env!("CARGO_PKG_VERSION")); - - if let Some(proxy) = config.proxy.to_proxy()? { - reqwest_client_builder = reqwest_client_builder.proxy(proxy); - } - - Ok(reqwest_client_builder) -} - -fn url_preview_reqwest_client_builder(config: &Config) -> Result { - // for security reasons (e.g. malicious open redirect), we do not want to follow - // too many redirects when generating URL previews. let's keep it at least 2 to - // account for HTTP -> HTTPS upgrades, if it becomes an issue we can consider - // raising it to 3. - let redirect_policy = reqwest::redirect::Policy::custom(|attempt| { - if attempt.previous().len() > 2 { - attempt.error("Too many redirects (max is 2)") - } else { - attempt.follow() - } - }); - - let mut reqwest_client_builder = reqwest::Client::builder() - .hickory_dns(true) - .pool_max_idle_per_host(0) - .connect_timeout(Duration::from_secs(20)) - .timeout(Duration::from_secs(30)) - .redirect(redirect_policy) - .user_agent("Conduwuit".to_owned() + "/" + env!("CARGO_PKG_VERSION")); - - if let Some(proxy) = config.proxy.to_proxy()? { - reqwest_client_builder = reqwest_client_builder.proxy(proxy); - } - - Ok(reqwest_client_builder) -} diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 13abe52fd..47d5572fb 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -63,7 +63,7 @@ impl Service { //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); let url = reqwest_request.url().clone(); - let response = services().globals.default_client().execute(reqwest_request).await; + let response = services().globals.client.pusher.execute(reqwest_request).await; match response { Ok(mut response) => {