diff --git a/Cargo.lock b/Cargo.lock index de6eaff0d4..53bdcf6024 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1902,6 +1902,26 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "governor" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b" +dependencies = [ + "cfg-if", + "dashmap 5.5.3", + "futures", + "futures-timer", + "no-std-compat", + "nonzero_ext", + "parking_lot", + "portable-atomic", + "quanta", + "rand 0.8.5", + "smallvec", + "spinning_top", +] + [[package]] name = "group" version = "0.13.0" @@ -2803,14 +2823,18 @@ dependencies = [ "cached", "check-if-email-exists", "cookie_store", + "dashmap 6.1.0", "doc-comment", "email_address", "futures", "glob", + "governor", "headers", "html5ever", "html5gum", "http 1.4.0", + "httpdate", + "humantime-serde", "hyper 1.8.1", "ignore", "ip_network", @@ -2969,6 +2993,12 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" +[[package]] +name = "no-std-compat" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" + [[package]] name = "nom" version = "7.1.3" @@ -2979,6 +3009,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "normalize-line-endings" version = "0.3.0" @@ -3651,6 +3687,21 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "quanta" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + [[package]] name = "quinn" version = "0.11.9" @@ -3780,6 +3831,15 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "raw-cpuid" +version = "11.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" +dependencies = [ + "bitflags 2.10.0", +] + [[package]] name = "rayon" version = "1.11.0" @@ -4566,6 +4626,15 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spinning_top" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.7.3" diff --git a/README.md b/README.md index bda7ed5a4f..2ca3f215cf 100644 --- a/README.md +++ b/README.md @@ -423,7 +423,12 @@ Options: --default-extension This is the default file extension that is applied to files without an extension. - This is useful for files without extensions or with unknown extensions. The extension will be used to determine the file type for processing. Examples: --default-extension md, --default-extension html + This is useful for files without extensions or with unknown extensions. + The extension will be used to determine the file type for processing. + + Examples: + --default-extension md + --default-extension html --dump Don't perform any link checking. 
          Instead, dump all the links extracted from inputs that would be checked
@@ -519,10 +524,39 @@ Options:
           You can specify custom headers in the format 'Name: Value'.
           For example, 'Accept: text/html'. This is the same format that other tools like curl or wget use.
           Multiple headers can be specified by using the flag multiple times.
+          The specified headers are used for ALL requests.
+          Use the `hosts` option to configure headers on a per-host basis.
 
       --hidden
           Do not skip hidden directories and files
 
+      --host-concurrency <HOST_CONCURRENCY>
+          Default maximum concurrent requests per host (default: 10)
+
+          This limits the maximum number of requests that are sent simultaneously
+          to the same host. This helps to prevent overwhelming servers and
+          running into rate limits. Use the `hosts` option to configure this
+          on a per-host basis.
+
+          Examples:
+            --host-concurrency 2   # Conservative for slow APIs
+            --host-concurrency 20  # Aggressive for fast APIs
+
+      --host-request-interval <HOST_REQUEST_INTERVAL>
+          Minimum interval between requests to the same host (default: 50ms)
+
+          Sets a baseline delay between consecutive requests to prevent
+          overloading servers. The adaptive algorithm may increase this based
+          on server responses (rate limits, errors). Use the `hosts` option
+          to configure this on a per-host basis.
+
+          Examples:
+            --host-request-interval 50ms  # Fast for robust APIs
+            --host-request-interval 1s    # Conservative for rate-limited APIs
+
+      --host-stats
+          Show per-host statistics at the end of the run
+
   -i, --insecure
           Proceed for server connections considered insecure (invalid TLS)
diff --git a/fixtures/configs/headers.toml b/fixtures/configs/headers.toml
index 2873301f65..d4e6b7107a 100644
--- a/fixtures/configs/headers.toml
+++ b/fixtures/configs/headers.toml
@@ -4,3 +4,7 @@ X-Bar = "Baz"
 
 # Alternative TOML syntax:
 # header = { X-Foo = "Bar", X-Bar = "Baz" }
+
+
+[hosts."127.0.0.1"]
+headers = { "X-Host-Specific" = "Foo" }
diff --git a/lychee-bin/src/client.rs b/lychee-bin/src/client.rs
index 4c99f6fe7c..3efefffe6f 100644
--- a/lychee-bin/src/client.rs
+++ b/lychee-bin/src/client.rs
@@ -2,7 +2,7 @@ use crate::options::{Config, HeaderMapExt};
 use crate::parse::{parse_duration_secs, parse_remaps};
 use anyhow::{Context, Result};
 use http::{HeaderMap, StatusCode};
-use lychee_lib::{Client, ClientBuilder};
+use lychee_lib::{Client, ClientBuilder, ratelimit::RateLimitConfig};
 use regex::RegexSet;
 use reqwest_cookie_store::CookieStoreMutex;
 use std::sync::Arc;
@@ -55,6 +55,11 @@ pub(crate) fn create(cfg: &Config, cookie_jar: Option<&Arc<CookieStoreMutex>>) -
         .include_fragments(cfg.include_fragments)
         .fallback_extensions(cfg.fallback_extensions.clone())
         .index_files(cfg.index_files.clone())
+        .rate_limit_config(RateLimitConfig::from_options(
+            cfg.host_concurrency,
+            cfg.host_request_interval,
+        ))
+        .hosts(cfg.hosts.clone())
         .build()
         .client()
         .context("Failed to create request client")
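For reference, the options wired up above can be combined in a TOML config roughly as follows. This is an illustrative sketch: the `[hosts]` table follows the `headers.toml` fixture in this patch, while the top-level key names are assumed to mirror the `Config` field names (`host_concurrency`, `host_request_interval`); the hostname and token are examples only.

```toml
# Global defaults (same effect as --host-concurrency / --host-request-interval)
host_concurrency = 10
host_request_interval = "50ms"

# Per-host overrides; hostname and header value are illustrative
[hosts."api.github.com"]
concurrency = 2
request_interval = "1s"
headers = { "Authorization" = "Bearer <token>" }
```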
diff --git a/lychee-bin/src/commands/check.rs b/lychee-bin/src/commands/check.rs
index 62a908998d..d0209a7c31 100644
--- a/lychee-bin/src/commands/check.rs
+++ b/lychee-bin/src/commands/check.rs
@@ -1,8 +1,10 @@
 use std::collections::HashSet;
-use std::sync::{Arc, Mutex};
+use std::sync::Arc;
+use std::sync::Mutex;
 use std::time::Duration;
 
 use futures::StreamExt;
+use lychee_lib::ratelimit::HostPool;
 use reqwest::Url;
 use tokio::sync::mpsc;
 use tokio_stream::wrappers::ReceiverStream;
@@ -24,7 +26,7 @@ use super::CommandParams;
 
 pub(crate) async fn check<S>(
     params: CommandParams<S>,
-) -> Result<(ResponseStats, Arc<Cache>, ExitCode), ErrorKind>
+) -> Result<(ResponseStats, Cache, ExitCode, Arc<HostPool>), ErrorKind>
 where
     S: futures::Stream<Item = Result<Request>>,
 {
@@ -41,7 +43,6 @@ where
     } else {
         ResponseStats::default()
     };
-    let cache_ref = params.cache.clone();
 
     let client = params.client;
     let cache = params.cache;
@@ -53,7 +54,7 @@ where
     let accept = params.cfg.accept.into();
 
     // Start receiving requests
-    tokio::spawn(request_channel_task(
+    let handle = tokio::spawn(request_channel_task(
         recv_req,
         send_resp,
         max_concurrency,
@@ -74,8 +75,9 @@ where
         stats,
     ));
 
-    // Wait until all messages are sent
-    send_inputs_loop(params.requests, send_req, &progress).await?;
+    // Wait until all requests are sent
+    send_requests(params.requests, send_req, &progress).await?;
+    let (cache, client) = handle.await?;
 
     // Wait until all responses are received
     let result = show_results_task.await?;
@@ -103,7 +105,8 @@ where
     } else {
         ExitCode::LinkCheckFailure
     };
-    Ok((stats, cache_ref, code))
+
+    Ok((stats, cache, code, client.host_pool()))
 }
 
 async fn suggest_archived_links(
@@ -143,7 +146,7 @@ async fn suggest_archived_links(
 // drops the `send_req` channel on exit
 // required for the receiver task to end, which closes send_resp, which allows
 // the show_results_task to finish
-async fn send_inputs_loop<S>(
+async fn send_requests<S>(
     requests: S,
     send_req: mpsc::Sender<Result<Request>>,
     progress: &Progress,
@@ -180,17 +183,17 @@ async fn request_channel_task(
     send_resp: mpsc::Sender<Response>,
     max_concurrency: usize,
     client: Client,
-    cache: Arc<Cache>,
+    cache: Cache,
     cache_exclude_status: HashSet<u16>,
     accept: HashSet<u16>,
-) {
+) -> (Cache, Client) {
     StreamExt::for_each_concurrent(
         ReceiverStream::new(recv_req),
         max_concurrency,
         |request: Result<Request>| async {
             let response = handle(
                 &client,
-                cache.clone(),
+                &cache,
                 cache_exclude_status.clone(),
                 request,
                 accept.clone(),
@@ -204,6 +207,8 @@
         },
     )
     .await;
+
+    (cache, client)
 }
 
 /// Check a URL and return a response.
@@ -235,7 +240,7 @@
 /// a failed response.
 async fn handle(
     client: &Client,
-    cache: Arc<Cache>,
+    cache: &Cache,
     cache_exclude_status: HashSet<u16>,
     request: Result<Request>,
     accept: HashSet<u16>,
@@ -247,6 +252,8 @@
     };
 
     let uri = request.uri.clone();
+
+    // First check the persistent disk-based cache
     if let Some(v) = cache.get(&uri) {
         // Found a cached request
        // Overwrite cache status in case the URI is excluded in the
@@ -260,16 +267,28 @@
         // code.
         Status::from_cache_status(v.value().status, &accept)
     };
+
+        // Track cache hit in the per-host stats (only for network URIs)
+        if !uri.is_file()
+            && let Err(e) = client.host_pool().record_cache_hit(&uri)
+        {
+            log::debug!("Failed to record cache hit for {uri}: {e}");
+        }
+
         return Ok(Response::new(uri.clone(), status, request.source.into()));
     }
 
-    // Request was not cached; run a normal check
+    // Cache miss - track it and run a normal check (only for network URIs)
+    if !uri.is_file()
+        && let Err(e) = client.host_pool().record_cache_miss(&uri)
+    {
+        log::debug!("Failed to record cache miss for {uri}: {e}");
+    }
+
     let response = check_url(client, request).await;
 
-    // - Never cache filesystem access as it is fast already so caching has no
-    //   benefit.
-    // - Skip caching unsupported URLs as they might be supported in a
-    //   future run.
+    // - Never cache filesystem access, as it is already fast, so caching has no benefit.
+    // - Skip caching unsupported URLs, as they might be supported in a future run.
     // - Skip caching excluded links; they might not be excluded in the next run.
    // - Skip caching links for which the status code has been explicitly excluded from the cache.
     let status = response.status();
diff --git a/lychee-bin/src/commands/mod.rs b/lychee-bin/src/commands/mod.rs
index 5b2c6f62db..38892e98e3 100644
--- a/lychee-bin/src/commands/mod.rs
+++ b/lychee-bin/src/commands/mod.rs
@@ -10,7 +10,6 @@ pub(crate) use dump_inputs::dump_inputs;
 use std::fs;
 use std::io::{self, Write};
 use std::path::PathBuf;
-use std::sync::Arc;
 
 use crate::cache::Cache;
 use crate::options::Config;
@@ -20,7 +19,7 @@ use lychee_lib::{Client, Request};
 /// Parameters passed to every command
 pub(crate) struct CommandParams<S: futures::Stream<Item = Result<Request>>> {
     pub(crate) client: Client,
-    pub(crate) cache: Arc<Cache>,
+    pub(crate) cache: Cache,
     pub(crate) requests: S,
     pub(crate) cfg: Config,
 }
diff --git a/lychee-bin/src/formatters/host_stats/compact.rs b/lychee-bin/src/formatters/host_stats/compact.rs
new file mode 100644
index 0000000000..121230e259
--- /dev/null
+++ b/lychee-bin/src/formatters/host_stats/compact.rs
@@ -0,0 +1,81 @@
+use anyhow::Result;
+use std::{
+    collections::HashMap,
+    fmt::{self, Display},
+};
+
+use crate::formatters::color::{DIM, NORMAL, color};
+use lychee_lib::ratelimit::HostStats;
+
+use super::HostStatsFormatter;
+
+struct CompactHostStats {
+    host_stats: HashMap<String, HostStats>,
+}
+
+impl Display for CompactHostStats {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        if self.host_stats.is_empty() {
+            return Ok(());
+        }
+
+        writeln!(f)?;
+        writeln!(f, "šŸ“Š Per-host Statistics")?;
+
+        let separator = "─".repeat(60);
+        color!(f, DIM, "{}", separator)?;
+        writeln!(f)?;
+
+        let sorted_hosts = super::sort_host_stats(&self.host_stats);
+
+        // Calculate optimal hostname width based on longest hostname
+        let max_hostname_len = sorted_hosts
+            .iter()
+            .map(|(hostname, _)| hostname.len())
+            .max()
+            .unwrap_or(0);
+        let hostname_width = (max_hostname_len + 2).max(10); // At least 10 chars with padding
+
+        for (hostname, stats) in sorted_hosts {
+            let median_time = stats
+                .median_request_time()
+                .map_or_else(|| "N/A".to_string(), |d| format!("{:.0}ms", d.as_millis()));
+
+            let cache_hit_rate = stats.cache_hit_rate() * 100.0;
+
+            color!(
+                f,
+                NORMAL,
+                "{:<width$} {:>6} reqs │ {:>6.1}% success │ {:>8} median │ {:>6.1}% cached",
+                hostname,
+                stats.total_requests,
+                stats.success_rate() * 100.0,
+                median_time,
+                cache_hit_rate,
+                width = hostname_width
+            )?;
+            writeln!(f)?;
+        }
+
+        Ok(())
+    }
+}
+
+pub(crate) struct Compact;
+
+impl Compact {
+    pub(crate) const fn new() -> Self {
+        Self
+    }
+}
+
+impl HostStatsFormatter for Compact {
+    fn format(&self, host_stats: HashMap<String, HostStats>) -> Result<Option<String>> {
+        if host_stats.is_empty() {
+            return Ok(None);
+        }
+
+        let compact = CompactHostStats { host_stats };
+        Ok(Some(compact.to_string()))
+    }
+}
diff --git a/lychee-bin/src/formatters/host_stats/detailed.rs b/lychee-bin/src/formatters/host_stats/detailed.rs
new file mode 100644
index 0000000000..01bfd42bc8
--- /dev/null
+++ b/lychee-bin/src/formatters/host_stats/detailed.rs
@@ -0,0 +1,90 @@
+use anyhow::Result;
+use std::{
+    collections::HashMap,
+    fmt::{self, Display},
+};
+
+use lychee_lib::ratelimit::HostStats;
+
+use super::HostStatsFormatter;
+
+struct DetailedHostStats {
+    host_stats: HashMap<String, HostStats>,
+}
+
+impl Display for DetailedHostStats {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        if self.host_stats.is_empty() {
+            return Ok(());
+        }
+
+        writeln!(f, "\nšŸ“Š Per-host Statistics")?;
+        writeln!(f, "---------------------")?;
+
+        let sorted_hosts = super::sort_host_stats(&self.host_stats);
+
+        for (hostname, stats) in sorted_hosts {
+            writeln!(f, "\nHost: {hostname}")?;
+            writeln!(f, "  Total requests: {}", stats.total_requests)?;
+            writeln!(
+                f,
+                "  Successful: {} ({:.1}%)",
+                stats.successful_requests,
+                stats.success_rate() * 100.0
+            )?;
+
+            if stats.rate_limited > 0 {
+                writeln!(
+                    f,
+                    "  Rate limited: {} (429 Too Many Requests)",
+                    stats.rate_limited
+                )?;
+            }
+            if stats.client_errors > 0 {
+                writeln!(f, "  Client errors (4xx): {}", stats.client_errors)?;
+            }
+            if stats.server_errors > 0 {
+                writeln!(f, "  Server errors (5xx): {}", stats.server_errors)?;
+            }
+
+            if let Some(median_time) = stats.median_request_time() {
+                writeln!(
+                    f,
+                    "  Median response time: {:.0}ms",
+                    median_time.as_millis()
+                )?;
+            }
+
+            let cache_hit_rate = stats.cache_hit_rate();
+            if cache_hit_rate > 0.0 {
+                writeln!(f, "  Cache hit rate: {:.1}%", cache_hit_rate * 100.0)?;
+                writeln!(
+                    f,
+                    "  Cache hits: {}, misses: {}",
+                    stats.cache_hits, stats.cache_misses
+                )?;
+            }
+        }
+
+        Ok(())
+    }
+}
+
+pub(crate) struct Detailed;
+
+impl Detailed {
+    pub(crate) const fn new() -> Self {
+        Self
+    }
+}
+
+impl HostStatsFormatter for Detailed {
+    fn format(&self, host_stats: HashMap<String, HostStats>) -> Result<Option<String>> {
+        if host_stats.is_empty() {
+            return Ok(None);
+        }
+
+        let detailed = DetailedHostStats { host_stats };
+        Ok(Some(detailed.to_string()))
+    }
+}
diff --git a/lychee-bin/src/formatters/host_stats/json.rs b/lychee-bin/src/formatters/host_stats/json.rs
new file mode 100644
index 0000000000..24f7fe0d2e
--- /dev/null
+++ b/lychee-bin/src/formatters/host_stats/json.rs
@@ -0,0 +1,57 @@
+use anyhow::{Context, Result};
+use serde_json::json;
+use std::collections::HashMap;
+
+use super::HostStatsFormatter;
+use lychee_lib::ratelimit::HostStats;
+
+pub(crate) struct Json;
+
+impl Json {
+    pub(crate) const fn new() -> Self {
+        Self {}
+    }
+}
+
+impl HostStatsFormatter for Json {
+    /// Format host stats as JSON object
+    fn format(&self, host_stats: HashMap<String, HostStats>) -> Result<Option<String>> {
+        if host_stats.is_empty() {
+            return Ok(None);
+        }
+
+        // Convert HostStats to a more JSON-friendly format
+        let json_stats: HashMap<String, serde_json::Value> = host_stats
+            .into_iter()
+            .map(|(hostname, stats)| {
+                let json_value = json!({
+                    "total_requests": stats.total_requests,
+                    "successful_requests": stats.successful_requests,
+                    "success_rate": stats.success_rate(),
+                    "rate_limited": stats.rate_limited,
+                    "client_errors": stats.client_errors,
+                    "server_errors": stats.server_errors,
+                    "median_request_time_ms": stats.median_request_time()
+                        .map(|d| {
+                            #[allow(clippy::cast_possible_truncation)]
+                            let millis = d.as_millis() as u64;
+                            millis
+                        }),
+                    "cache_hits": stats.cache_hits,
+                    "cache_misses": stats.cache_misses,
+                    "cache_hit_rate": stats.cache_hit_rate(),
+                    "status_codes": stats.status_codes
+                });
+                (hostname, json_value)
+            })
+            .collect();
+
+        let output = json!({
+            "host_statistics": json_stats
+        });
+
+        serde_json::to_string_pretty(&output)
+            .map(Some)
+            .context("Cannot format host stats as JSON")
+    }
+}
diff --git a/lychee-bin/src/formatters/host_stats/markdown.rs b/lychee-bin/src/formatters/host_stats/markdown.rs
new file mode 100644
index 0000000000..8980066107
--- /dev/null
+++ b/lychee-bin/src/formatters/host_stats/markdown.rs
@@ -0,0 +1,92 @@
+use std::{
+    collections::HashMap,
+    fmt::{self, Display},
+};
+
+use super::HostStatsFormatter;
+use anyhow::Result;
+use lychee_lib::ratelimit::HostStats;
+use tabled::{
+    Table, Tabled,
+    settings::{Alignment, Modify, Style, object::Segment},
+};
+
+#[derive(Tabled)]
+struct HostStatsTableEntry {
+    #[tabled(rename = "Host")]
+    host: String,
+    #[tabled(rename = "Requests")]
+    requests: u64,
+    #[tabled(rename = "Success Rate")]
+    success_rate: String,
+    #[tabled(rename = "Median Time")]
+    median_time: String,
+    #[tabled(rename = "Cache Hit Rate")]
+    cache_hit_rate: String,
+}
+
+fn host_stats_table(host_stats: &HashMap<String, HostStats>) -> String {
+    let sorted_hosts = super::sort_host_stats(host_stats);
+
+    let entries: Vec<HostStatsTableEntry> = sorted_hosts
+        .into_iter()
+        .map(|(hostname, stats)| {
+            let median_time = stats
+                .median_request_time()
+                .map_or_else(|| "N/A".to_string(), |d| format!("{:.0}ms", d.as_millis()));
+
+            HostStatsTableEntry {
+                host: hostname.clone(),
+                requests: stats.total_requests,
+                success_rate: format!("{:.1}%", stats.success_rate() * 100.0),
+                median_time,
+                cache_hit_rate: format!("{:.1}%", stats.cache_hit_rate() * 100.0),
+            }
+        })
+        .collect();
+
+    if entries.is_empty() {
+        return String::new();
+    }
+
+    let style = Style::markdown();
+    Table::new(entries)
+        .with(Modify::new(Segment::all()).with(Alignment::left()))
+        .with(style)
+        .to_string()
+}
+
+struct MarkdownHostStats(HashMap<String, HostStats>);
+
+impl Display for MarkdownHostStats {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        if self.0.is_empty() {
+            return Ok(());
+        }
+
+        writeln!(f, "\n## Per-host Statistics")?;
+        writeln!(f)?;
+        writeln!(f, "{}", host_stats_table(&self.0))?;
+
+        Ok(())
+    }
+}
+
+pub(crate) struct Markdown;
+
+impl Markdown {
+    pub(crate) const fn new() -> Self {
+        Self {}
+    }
+}
+
+impl HostStatsFormatter for Markdown {
+    fn format(&self, host_stats: HashMap<String, HostStats>) -> Result<Option<String>> {
+        if host_stats.is_empty() {
+            return Ok(None);
+        }
+
+        let markdown = MarkdownHostStats(host_stats);
+        Ok(Some(markdown.to_string()))
+    }
+}
diff --git a/lychee-bin/src/formatters/host_stats/mod.rs b/lychee-bin/src/formatters/host_stats/mod.rs
new file mode 100644
index 0000000000..8c312bfdd5
--- /dev/null
+++ b/lychee-bin/src/formatters/host_stats/mod.rs
@@ -0,0 +1,28 @@
+mod compact;
+mod detailed;
+mod json;
+mod markdown;
+
+pub(crate) use compact::Compact;
+pub(crate) use detailed::Detailed;
+pub(crate) use json::Json;
+pub(crate) use markdown::Markdown;
+
+use anyhow::Result;
+use lychee_lib::ratelimit::HostStats;
+use std::collections::HashMap;
+
+/// Trait for formatting per-host statistics in different output formats
+pub(crate) trait HostStatsFormatter {
+    /// Format the host statistics and return them as a string
+    fn format(&self, host_stats: HashMap<String, HostStats>) -> Result<Option<String>>;
+}
+
+/// Sort host statistics by request count (descending order)
+/// This matches the display order we want in the output
+fn sort_host_stats(host_stats: &HashMap<String, HostStats>) -> Vec<(&String, &HostStats)> {
+    let mut sorted_hosts: Vec<_> = host_stats.iter().collect();
+    // Sort by total requests (descending)
+    sorted_hosts.sort_by_key(|(_, stats)| std::cmp::Reverse(stats.total_requests));
+    sorted_hosts
+}
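The `sort_host_stats` helper above fixes the display order for every formatter. A minimal sketch of the same descending `Reverse` sort on plain tuples (hostnames and counts are made up):

```rust
use std::cmp::Reverse;

fn main() {
    let mut hosts = vec![("a.example", 3u64), ("b.example", 10), ("c.example", 7)];
    // Sort by total requests, descending, as sort_host_stats does
    hosts.sort_by_key(|&(_, total_requests)| Reverse(total_requests));
    assert_eq!(hosts[0].0, "b.example"); // busiest host prints first
}
```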
diff --git a/lychee-bin/src/formatters/mod.rs b/lychee-bin/src/formatters/mod.rs
index a7b1a90673..de36c32bb6 100644
--- a/lychee-bin/src/formatters/mod.rs
+++ b/lychee-bin/src/formatters/mod.rs
@@ -1,11 +1,12 @@
 pub(crate) mod color;
 pub(crate) mod duration;
+pub(crate) mod host_stats;
 pub(crate) mod log;
 pub(crate) mod response;
 pub(crate) mod stats;
 pub(crate) mod suggestion;
 
-use self::{response::ResponseFormatter, stats::StatsFormatter};
+use self::{host_stats::HostStatsFormatter, response::ResponseFormatter, stats::StatsFormatter};
 use crate::options::{OutputMode, StatsFormat};
 use supports_color::Stream;
 
@@ -41,6 +42,19 @@ pub(crate) fn get_progress_formatter(mode: &OutputMode) -> Box<dyn ResponseForma
 
+/// Create a host stats formatter based on the given format option
+pub(crate) fn get_host_stats_formatter(
+    format: &StatsFormat,
+    mode: &OutputMode,
+) -> Box<dyn HostStatsFormatter> {
+    match format {
+        StatsFormat::Compact | StatsFormat::Raw => Box::new(host_stats::Compact::new()), // Use compact for raw
+        StatsFormat::Detailed => Box::new(host_stats::Detailed::new()),
+        StatsFormat::Json => Box::new(host_stats::Json::new()),
+        StatsFormat::Markdown => Box::new(host_stats::Markdown::new()),
+    }
+}
+
 /// Create a response formatter based on the given format option
 pub(crate) fn get_response_formatter(mode: &OutputMode) -> Box<dyn ResponseFormatter> {
     // Checks if color is supported in current environment or NO_COLOR is set (https://no-color.org)
diff --git a/lychee-bin/src/formatters/stats/mod.rs b/lychee-bin/src/formatters/stats/mod.rs
index dc2d2233d6..8d6cb559e9 100644
--- a/lychee-bin/src/formatters/stats/mod.rs
+++ b/lychee-bin/src/formatters/stats/mod.rs
@@ -13,10 +13,12 @@ pub(crate) use raw::Raw;
 use std::{
     collections::{HashMap, HashSet},
     fmt::Display,
+    fs,
+    io::{Write, stdout},
 };
 
-use crate::stats::ResponseStats;
-use anyhow::Result;
+use crate::{formatters::get_stats_formatter, options::Config, stats::ResponseStats};
+use anyhow::{Context, Result};
 use lychee_lib::InputSource;
 
 pub(crate) trait StatsFormatter {
@@ -24,6 +26,25 @@
     fn format(&self, stats: ResponseStats) -> Result<Option<String>>;
 }
 
+/// If configured to do so, output response statistics to stdout or the specified output file.
+pub(crate) fn output_response_statistics(stats: ResponseStats, config: &Config) -> Result<()> {
+    let is_empty = stats.is_empty();
+    let formatter = get_stats_formatter(&config.format, &config.mode);
+    if let Some(formatted_stats) = formatter.format(stats)? {
+        if let Some(output) = &config.output {
+            fs::write(output, formatted_stats).context("Cannot write status output to file")?;
+        } else {
+            if config.verbose.log_level() >= log::Level::Info && !is_empty {
+                // separate summary from the verbose list of links above with a newline
+                writeln!(stdout())?;
+            }
+            // we assume that the formatted stats don't have a final newline
+            writeln!(stdout(), "{formatted_stats}")?;
+        }
+    }
+    Ok(())
+}
+
 /// Convert a `ResponseStats` `HashMap` to a sorted Vec of key-value pairs
 /// The returned keys and values are both sorted in natural, case-insensitive order
 fn sort_stat_map<T>(stat_map: &HashMap<InputSource, HashSet<T>>) -> Vec<(&InputSource, Vec<&T>)>
diff --git a/lychee-bin/src/host_stats.rs b/lychee-bin/src/host_stats.rs
new file mode 100644
index 0000000000..5ba26d9995
--- /dev/null
+++ b/lychee-bin/src/host_stats.rs
@@ -0,0 +1,27 @@
+use anyhow::{Context, Result};
+use lychee_lib::ratelimit::HostPool;
+
+use crate::{formatters::get_host_stats_formatter, options::Config};
+
+/// If configured to do so, output per-host statistics to stdout or the specified output file.
+pub(crate) fn output_per_host_statistics(host_pool: &HostPool, config: &Config) -> Result<()> {
+    if !config.host_stats {
+        return Ok(());
+    }
+
+    let host_stats = host_pool.all_host_stats();
+    let host_stats_formatter = get_host_stats_formatter(&config.format, &config.mode);
+
+    if let Some(formatted_host_stats) = host_stats_formatter.format(host_stats)? {
+        if let Some(output) = &config.output {
+            // For file output, append to the existing output
+            let mut file_content = std::fs::read_to_string(output).unwrap_or_default();
+            file_content.push_str(&formatted_host_stats);
+            std::fs::write(output, file_content)
+                .context("Cannot write host stats to output file")?;
+        } else {
+            print!("{formatted_host_stats}");
+        }
+    }
+    Ok(())
+}
diff --git a/lychee-bin/src/main.rs b/lychee-bin/src/main.rs
index 01a4411c80..1e3d910e27 100644
--- a/lychee-bin/src/main.rs
+++ b/lychee-bin/src/main.rs
@@ -59,14 +59,13 @@
 #![deny(missing_docs)]
 
 use std::fs::{self, File};
-use std::io::{self, BufRead, BufReader, ErrorKind, Write};
+use std::io::{self, BufRead, BufReader, ErrorKind};
 use std::path::PathBuf;
-use std::sync::Arc;
 
 use anyhow::{Context, Error, Result, bail};
 use clap::{Parser, crate_version};
 use commands::{CommandParams, generate};
-use formatters::{get_stats_formatter, log::init_logging};
+use formatters::log::init_logging;
 use http::HeaderMap;
 use log::{error, info, warn};
@@ -86,6 +85,7 @@ mod client;
 mod commands;
 mod files_from;
 mod formatters;
+mod host_stats;
 mod options;
 mod parse;
 mod progress;
@@ -93,10 +93,13 @@ mod stats;
 mod time;
 mod verbosity;
 
+use crate::formatters::stats::output_response_statistics;
+use crate::stats::ResponseStats;
 use crate::{
     cache::{Cache, StoreExt},
-    formatters::{duration::Duration, stats::StatsFormatter},
+    formatters::duration::Duration,
     generate::generate,
+    host_stats::output_per_host_statistics,
     options::{Config, LYCHEE_CACHE_FILE, LYCHEE_IGNORE_FILE, LycheeOptions},
 };
 
@@ -368,7 +371,6 @@ async fn run(opts: &LycheeOptions) -> Result<i32> {
     let requests = collector.collect_links_from_file_types(inputs, opts.config.extensions.clone());
 
     let cache = load_cache(&opts.config).unwrap_or_default();
-    let cache = Arc::new(cache);
 
     let cookie_jar = load_cookie_jar(&opts.config).with_context(|| {
         format!(
@@ -381,7 +383,6 @@ async fn run(opts: &LycheeOptions) -> Result<i32> {
     })?;
 
     let client = client::create(&opts.config, cookie_jar.as_deref())?;
-
     let params = CommandParams {
         client,
         cache,
@@ -392,39 +393,10 @@ async fn run(opts: &LycheeOptions) -> Result<i32> {
     let exit_code = if opts.config.dump {
         commands::dump(params).await?
     } else {
-        let (stats, cache, exit_code) = commands::check(params).await?;
-
-        let github_issues = stats
-            .error_map
-            .values()
-            .flatten()
-            .any(|body| body.uri.domain() == Some("github.com"));
-
-        let stats_formatter: Box<dyn StatsFormatter> =
-            get_stats_formatter(&opts.config.format, &opts.config.mode);
-
-        let is_empty = stats.is_empty();
-        let formatted_stats = stats_formatter.format(stats)?;
-
-        if let Some(formatted_stats) = formatted_stats {
-            if let Some(output) = &opts.config.output {
-                fs::write(output, formatted_stats).context("Cannot write status output to file")?;
-            } else {
-                if opts.config.verbose.log_level() >= log::Level::Info && !is_empty {
-                    // separate summary from the verbose list of links above
-                    // with a newline
-                    writeln!(io::stdout())?;
-                }
-                // we assume that the formatted stats don't have a final newline
-                writeln!(io::stdout(), "{formatted_stats}")?;
-            }
-        }
-
-        if github_issues && opts.config.github_token.is_none() {
-            warn!(
-                "There were issues with GitHub URLs. You could try setting a GitHub token and running lychee again.",
-            );
-        }
+        let (stats, cache, exit_code, host_pool) = commands::check(params).await?;
+        github_warning(&stats, &opts.config);
+        output_response_statistics(stats, &opts.config)?;
+        output_per_host_statistics(&host_pool, &opts.config)?;
 
         if opts.config.cache {
             cache.store(LYCHEE_CACHE_FILE)?;
@@ -440,3 +412,17 @@ async fn run(opts: &LycheeOptions) -> Result<i32> {
 
     Ok(exit_code as i32)
 }
+
+/// Display user-friendly message if there were any issues with GitHub URLs
+fn github_warning(stats: &ResponseStats, config: &Config) {
+    let github_errors = stats
+        .error_map
+        .values()
+        .flatten()
+        .any(|body| body.uri.domain() == Some("github.com"));
+    if github_errors && config.github_token.is_none() {
+        warn!(
+            "There were issues with GitHub URLs. You could try setting a GitHub token and running lychee again.",
+        );
+    }
+}
diff --git a/lychee-bin/src/options.rs b/lychee-bin/src/options.rs
index b5bb70c142..b9be8fe521 100644
--- a/lychee-bin/src/options.rs
+++ b/lychee-bin/src/options.rs
@@ -11,6 +11,7 @@ use http::{
     header::{HeaderName, HeaderValue},
 };
 use lychee_lib::Preprocessor;
+use lychee_lib::ratelimit::HostConfigs;
 use lychee_lib::{
     Base, BasicAuthSelector, DEFAULT_MAX_REDIRECTS, DEFAULT_MAX_RETRIES,
     DEFAULT_RETRY_WAIT_TIME_SECS, DEFAULT_TIMEOUT_SECS, DEFAULT_USER_AGENT, FileExtensions,
@@ -390,7 +391,7 @@
 pub(crate) struct Config {
     /// Read input filenames from the given file or stdin (if path is '-').
     #[arg(
-        long = "files-from",
+        long,
         value_name = "PATH",
         long_help = "Read input filenames from the given file or stdin (if path is '-').
 
@@ -421,6 +422,11 @@ File Format:
     #[serde(default)]
     pub(crate) no_progress: bool,
 
+    /// Show per-host statistics at the end of the run
+    #[arg(long)]
+    #[serde(default)]
+    pub(crate) host_stats: bool,
+
     /// A list of file extensions. Files not matching the specified extensions are skipped.
     ///
     /// E.g. a user can specify `--extensions html,htm,php,asp,aspx,jsp,cgi`
@@ -444,8 +450,11 @@ specify both extensions explicitly."
     ///
     /// This is useful for files without extensions or with unknown extensions.
     /// The extension will be used to determine the file type for processing.
-    /// Examples: --default-extension md, --default-extension html
-    #[arg(long, value_name = "EXTENSION")]
+    ///
+    /// Examples:
+    ///   --default-extension md
+    ///   --default-extension html
+    #[arg(long, value_name = "EXTENSION", verbatim_doc_comment)]
     #[serde(default)]
     pub(crate) default_extension: Option<String>,
 
@@ -528,6 +537,34 @@ with a status code of 429, 500 and 501."
     #[serde(default = "max_concurrency")]
     pub(crate) max_concurrency: usize,
 
+    /// Default maximum concurrent requests per host (default: 10)
+    ///
+    /// This limits the maximum number of requests that are sent simultaneously
+    /// to the same host. This helps to prevent overwhelming servers and
+    /// running into rate limits. Use the `hosts` option to configure this
+    /// on a per-host basis.
+    ///
+    /// Examples:
+    ///   --host-concurrency 2   # Conservative for slow APIs
+    ///   --host-concurrency 20  # Aggressive for fast APIs
+    #[arg(long, verbatim_doc_comment)]
+    #[serde(default)]
+    pub(crate) host_concurrency: Option<usize>,
+
+    /// Minimum interval between requests to the same host (default: 50ms)
+    ///
+    /// Sets a baseline delay between consecutive requests to prevent
+    /// overloading servers. The adaptive algorithm may increase this based
+    /// on server responses (rate limits, errors). Use the `hosts` option
+    /// to configure this on a per-host basis.
+    ///
+    /// Examples:
+    ///   --host-request-interval 50ms  # Fast for robust APIs
+    ///   --host-request-interval 1s    # Conservative for rate-limited APIs
+    #[arg(long, value_parser = humantime::parse_duration, verbatim_doc_comment)]
+    #[serde(default, with = "humantime_serde")]
+    pub(crate) host_request_interval: Option<Duration>,
+
     /// Number of threads to utilize.
     /// Defaults to number of cores available to the system
     #[arg(short = 'T', long)]
@@ -664,7 +701,7 @@ Note: This option only takes effect on `file://` URIs which exist and point to a
     /// Set custom header for requests
     #[arg(
         short = 'H',
-        long = "header",
+        long,
         // Note: We use a `Vec<(String, String)>` for headers, which is
         // unfortunate. The reason is that `clap::ArgAction::Append` collects
         // multiple values, and `clap` cannot automatically convert these tuples
@@ -677,7 +714,9 @@ Note: This option only takes effect on `file://` URIs which exist and point to a
 Some websites require custom headers to be passed in order to return valid responses.
 You can specify custom headers in the format 'Name: Value'. For example, 'Accept: text/html'.
 This is the same format that other tools like curl or wget use.
-Multiple headers can be specified by using the flag multiple times."
+Multiple headers can be specified by using the flag multiple times.
+The specified headers are used for ALL requests.
+Use the `hosts` option to configure headers on a per-host basis."
     )]
     #[serde(default)]
     #[serde(deserialize_with = "deserialize_headers")]
@@ -887,6 +926,11 @@ esac"#
     )]
     #[serde(default)]
     pub(crate) preprocess: Option<Preprocessor>,
+
+    /// Host-specific configurations from config file
+    #[arg(skip)]
+    #[serde(default)]
+    pub(crate) hosts: HostConfigs,
 }
 
 impl Config {
@@ -923,6 +967,11 @@ impl Config {
             self.github_token = toml.github_token;
         }
 
+        // Hosts configuration is only available in TOML for now (not in the CLI).
+        // That's because it's a bit complex to specify on the command line and
+        // we didn't come up with a good syntax for it yet.
+        self.hosts = toml.hosts;
+
         // NOTE: if you see an error within this macro call, check to make sure that
         // that the fields provided to fold_in! match all the fields of the Config struct.
         fold_in! {
@@ -933,6 +982,7 @@ impl Config {
             // Keys which are handled outside of fold_in
             ..header,
             ..github_token,
+            ..hosts,
 
             // Keys with defaults to assign
             accept: StatusCodeSelector::default(),
@@ -944,6 +994,8 @@ impl Config {
             cache_exclude_status: None,
             cookie_jar: None,
             default_extension: None,
+            host_concurrency: None,
+            host_request_interval: None,
             dump: false,
             dump_inputs: false,
             exclude: Vec::<String>::new(),
@@ -960,6 +1012,7 @@ impl Config {
             generate: None,
             glob_ignore_case: false,
             hidden: false,
+            host_stats: false,
             include: Vec::<String>::new(),
             include_fragments: false,
             include_mail: false,
diff --git a/lychee-bin/tests/cli.rs b/lychee-bin/tests/cli.rs
index 154c202636..c0cae571d0 100644
--- a/lychee-bin/tests/cli.rs
+++ b/lychee-bin/tests/cli.rs
@@ -19,14 +19,14 @@ mod cli {
         fs::{self, File},
         io::{BufRead, Write},
         path::Path,
-        time::Duration,
+        time::{Duration, Instant},
     };
 
     use tempfile::{NamedTempFile, tempdir};
     use test_utils::{fixtures_path, mock_server, redirecting_mock_server, root_path};
     use uuid::Uuid;
     use wiremock::{
-        Mock, ResponseTemplate,
+        Mock, Request, ResponseTemplate,
         matchers::{basic_auth, method},
     };
 
@@ -2358,6 +2358,48 @@ The config file should contain every possible key for documentation purposes."
             .success();
     }
 
+    #[tokio::test]
+    async fn test_retry_rate_limit_headers() {
+        const RETRY_DELAY: Duration = Duration::from_secs(1);
+        const TOLERANCE: Duration = Duration::from_millis(500);
+        let server = wiremock::MockServer::start().await;
+
+        wiremock::Mock::given(wiremock::matchers::method("GET"))
+            .respond_with(
+                ResponseTemplate::new(429)
+                    .append_header("Retry-After", RETRY_DELAY.as_secs().to_string()),
+            )
+            .expect(1)
+            .up_to_n_times(1)
+            .mount(&server)
+            .await;
+
+        let start = Instant::now();
+        wiremock::Mock::given(wiremock::matchers::method("GET"))
+            .respond_with(move |_: &Request| {
+                let delta = Instant::now().duration_since(start);
+                assert!(delta > RETRY_DELAY);
+                assert!(delta < RETRY_DELAY + TOLERANCE);
+                ResponseTemplate::new(200)
+            })
+            .expect(1)
+            .mount(&server)
+            .await;
+
+        cargo_bin_cmd!()
+            // Direct args do not use the host pool; they are resolved earlier via the Collector
+            .arg("-")
+            // Retry wait times are added on top of the host-specific backoff timeout
+            .arg("--retry-wait-time")
+            .arg("0")
+            .write_stdin(server.uri())
+            .assert()
+            .success();
+
+        // Verify that the server received both requests, in order
+        server.verify().await;
+    }
+
     #[tokio::test]
     async fn test_no_header_set_on_input() {
         let server = wiremock::MockServer::start().await;
@@ -2394,7 +2436,6 @@ The config file should contain every possible key for documentation purposes."
             wiremock::Mock::given(wiremock::matchers::method("GET"))
                 .and(wiremock::matchers::header("X-Foo", "Bar"))
                 .respond_with(wiremock::ResponseTemplate::new(200))
-                // We expect the mock to be called exactly least once.
                 .expect(1)
                 .named("GET expecting custom header"),
         )
@@ -2421,7 +2462,6 @@ The config file should contain every possible key for documentation purposes."
                 .and(wiremock::matchers::header("X-Foo", "Bar"))
                 .and(wiremock::matchers::header("X-Bar", "Baz"))
                 .respond_with(wiremock::ResponseTemplate::new(200))
-                // We expect the mock to be called exactly least once.
                 .expect(1)
                 .named("GET expecting custom header"),
         )
@@ -2449,8 +2489,8 @@ The config file should contain every possible key for documentation purposes."
             wiremock::Mock::given(wiremock::matchers::method("GET"))
                 .and(wiremock::matchers::header("X-Foo", "Bar"))
                 .and(wiremock::matchers::header("X-Bar", "Baz"))
+                .and(wiremock::matchers::header("X-Host-Specific", "Foo"))
                 .respond_with(wiremock::ResponseTemplate::new(200))
-                // We expect the mock to be called exactly least once.
                 .expect(1)
                 .named("GET expecting custom header"),
         )
.arg("--verbose") .arg("--config") .arg(config) - .arg(server.uri()) + .arg("-") + .write_stdin(server.uri()) .assert() .success(); diff --git a/lychee-lib/Cargo.toml b/lychee-lib/Cargo.toml index 5f5a6bf65a..0c94b62ca5 100644 --- a/lychee-lib/Cargo.toml +++ b/lychee-lib/Cargo.toml @@ -18,13 +18,17 @@ async-trait = "0.1.88" cached = "0.56.0" check-if-email-exists = { version = "0.9.1", optional = true } cookie_store = "0.22.0" +dashmap = { version = "6.1.0" } email_address = "0.2.9" futures = "0.3.31" glob = "0.3.3" +governor = "0.6.3" headers = "0.4.1" html5ever = "0.36.1" html5gum = "0.8.3" http = "1.4.0" +httpdate = "1.0.3" +humantime-serde = "1.1.1" hyper = "1.8.1" ignore = "0.4.25" ip_network = "0.4.1" diff --git a/lychee-lib/src/checker/website.rs b/lychee-lib/src/checker/website.rs index 82dfd84083..6a00915e46 100644 --- a/lychee-lib/src/checker/website.rs +++ b/lychee-lib/src/checker/website.rs @@ -2,6 +2,7 @@ use crate::{ BasicAuthCredentials, ErrorKind, FileType, Status, Uri, chain::{Chain, ChainResult, ClientRequestChains, Handler, RequestChain}, quirks::Quirks, + ratelimit::HostPool, retry::RetryExt, types::{redirect_history::RedirectHistory, uri::github::GithubUri}, utils::fragment_checker::{FragmentChecker, FragmentInput}, @@ -10,7 +11,7 @@ use async_trait::async_trait; use http::{Method, StatusCode}; use octocrab::Octocrab; use reqwest::{Request, Response, header::CONTENT_TYPE}; -use std::{collections::HashSet, path::Path, time::Duration}; +use std::{collections::HashSet, path::Path, sync::Arc, time::Duration}; use url::Url; #[derive(Debug, Clone)] @@ -18,9 +19,6 @@ pub(crate) struct WebsiteChecker { /// Request method used for making requests. method: reqwest::Method, - /// The HTTP client used for requests. - reqwest_client: reqwest::Client, - /// GitHub client used for requests. github_client: Option, @@ -54,25 +52,36 @@ pub(crate) struct WebsiteChecker { /// Keep track of HTTP redirections for reporting redirect_history: RedirectHistory, + + /// Optional host pool for per-host rate limiting. + /// + /// When present, HTTP requests will be routed through this pool for + /// rate limiting. When None, requests go directly through `reqwest_client`. + host_pool: Arc, } impl WebsiteChecker { + /// Get a reference to `HostPool` + #[must_use] + pub(crate) fn host_pool(&self) -> Arc { + self.host_pool.clone() + } + #[allow(clippy::too_many_arguments)] pub(crate) fn new( method: reqwest::Method, retry_wait_time: Duration, redirect_history: RedirectHistory, max_retries: u64, - reqwest_client: reqwest::Client, accepted: HashSet, github_client: Option, require_https: bool, plugin_request_chain: RequestChain, include_fragments: bool, + host_pool: Arc, ) -> Self { Self { method, - reqwest_client, github_client, plugin_request_chain, redirect_history, @@ -82,11 +91,14 @@ impl WebsiteChecker { require_https, include_fragments, fragment_checker: FragmentChecker::new(), + host_pool, } } /// Retry requests up to `max_retries` times /// with an exponential backoff. + /// Note that, in addition, there also is a host-specific backoff + /// when host-specific rate limiting or errors are detected. 
    pub(crate) async fn retry_request(&self, request: Request) -> Status {
        let mut retries: u64 = 0;
        let mut wait_time = self.retry_wait_time;
@@ -109,7 +121,7 @@ impl WebsiteChecker {
        let method = request.method().clone();
        let request_url = request.url().clone();
 
-        match self.reqwest_client.execute(request).await {
+        match self.host_pool.execute_request(request).await {
            Ok(response) => {
                let status = Status::new(&response, &self.accepted);
                // when `accept=200,429`, `status_code=429` will be treated as success
@@ -146,7 +158,10 @@ impl WebsiteChecker {
                    status
                }
            }
-            Err(e) => e.into(),
+            Err(e) => match e {
+                ErrorKind::NetworkRequest(error) => Status::from(error),
+                _ => e.into(),
+            },
        }
    }
@@ -239,10 +254,7 @@ impl WebsiteChecker {
    /// - The request failed.
    /// - The response status code is not accepted.
    async fn check_website_inner(&self, uri: &Uri, default_chain: &RequestChain) -> Status {
-        let request = self
-            .reqwest_client
-            .request(self.method.clone(), uri.as_str())
-            .build();
+        let request = self.host_pool.build_request(self.method.clone(), uri);
 
        let request = match request {
            Ok(r) => r,
diff --git a/lychee-lib/src/client.rs b/lychee-lib/src/client.rs
index a6efa70dc9..e0ec4e6c5c 100644
--- a/lychee-lib/src/client.rs
+++ b/lychee-lib/src/client.rs
@@ -32,6 +32,7 @@ use crate::{
     chain::RequestChain,
     checker::{file::FileChecker, mail::MailChecker, website::WebsiteChecker},
     filter::Filter,
+    ratelimit::{ClientMap, HostConfigs, HostKey, HostPool, RateLimitConfig},
     remap::Remaps,
     types::{DEFAULT_ACCEPTED_STATUS_CODES, redirect_history::RedirectHistory},
 };
@@ -304,6 +305,12 @@ pub struct ClientBuilder {
     /// early and return a status, so that subsequent chain items are
     /// skipped and the lychee-internal request chain is not activated.
     plugin_request_chain: RequestChain,
+
+    /// Global rate limiting configuration that applies as defaults to all hosts
+    rate_limit_config: RateLimitConfig,
+
+    /// Per-host configuration overrides
+    hosts: HostConfigs,
 }
 
 impl Default for ClientBuilder {
@@ -329,53 +336,20 @@ impl ClientBuilder {
     ///
     /// [here]: https://docs.rs/reqwest/latest/reqwest/struct.ClientBuilder.html#errors
     pub fn client(self) -> Result<Client> {
-        let Self {
-            user_agent,
-            custom_headers: mut headers,
-            ..
-        } = self;
-
-        if let Some(prev_user_agent) =
-            headers.insert(header::USER_AGENT, HeaderValue::try_from(&user_agent)?)
-        {
-            debug!(
-                "Found user-agent in headers: {}. Overriding it with {user_agent}.",
-                prev_user_agent.to_str().unwrap_or("ļæ½"),
-            );
-        }
-
-        headers.insert(
-            header::TRANSFER_ENCODING,
-            HeaderValue::from_static("chunked"),
-        );
-
         let redirect_history = RedirectHistory::new();
+        let reqwest_client = self
+            .build_client(&redirect_history)?
+            .build()
+            .map_err(ErrorKind::BuildRequestClient)?;
 
-        let mut builder = reqwest::ClientBuilder::new()
-            .gzip(true)
-            .default_headers(headers)
-            .danger_accept_invalid_certs(self.allow_insecure)
-            .connect_timeout(Duration::from_secs(CONNECT_TIMEOUT))
-            .tcp_keepalive(Duration::from_secs(TCP_KEEPALIVE))
-            .redirect(redirect_policy(
-                redirect_history.clone(),
-                self.max_redirects,
-            ));
-
-        if let Some(cookie_jar) = self.cookie_jar {
-            builder = builder.cookie_provider(cookie_jar);
-        }
+        let client_map = self.build_host_clients(&redirect_history)?;
 
-        if let Some(min_tls) = self.min_tls_version {
-            builder = builder.min_tls_version(min_tls);
-        }
-
-        let reqwest_client = match self.timeout {
-            Some(t) => builder.timeout(t),
-            None => builder,
-        }
-        .build()
-        .map_err(ErrorKind::BuildRequestClient)?;
+        let host_pool = HostPool::new(
+            self.rate_limit_config,
+            self.hosts,
+            reqwest_client,
+            client_map,
+        );
 
         let github_client = match self.github_token.as_ref().map(ExposeSecret::expose_secret) {
             Some(token) if !token.is_empty() => Some(
@@ -406,12 +380,12 @@ impl ClientBuilder {
             self.retry_wait_time,
             redirect_history.clone(),
             self.max_retries,
-            reqwest_client,
             self.accepted,
             github_client,
             self.require_https,
             self.plugin_request_chain,
             self.include_fragments,
+            Arc::new(host_pool),
         );
 
         Ok(Client {
@@ -427,6 +401,72 @@ impl ClientBuilder {
             ),
         })
     }
+
+    /// Build the host-specific clients with their host-specific headers
+    fn build_host_clients(&self, redirect_history: &RedirectHistory) -> Result<ClientMap> {
+        self.hosts
+            .iter()
+            .map(|(host, config)| {
+                let mut headers = self.default_headers()?;
+                headers.extend(config.headers.clone());
+                let client = self
+                    .build_client(redirect_history)?
+                    .default_headers(headers)
+                    .build()
+                    .map_err(ErrorKind::BuildRequestClient)?;
+                Ok((HostKey::from(host.as_str()), client))
+            })
+            .collect()
+    }
+
+    /// Create a [`reqwest::ClientBuilder`] based on various fields
+    fn build_client(&self, redirect_history: &RedirectHistory) -> Result<reqwest::ClientBuilder> {
+        let mut builder = reqwest::ClientBuilder::new()
+            .gzip(true)
+            .default_headers(self.default_headers()?)
+            .danger_accept_invalid_certs(self.allow_insecure)
+            .connect_timeout(Duration::from_secs(CONNECT_TIMEOUT))
+            .tcp_keepalive(Duration::from_secs(TCP_KEEPALIVE))
+            .redirect(redirect_policy(
+                redirect_history.clone(),
+                self.max_redirects,
+            ));
+
+        if let Some(cookie_jar) = self.cookie_jar.clone() {
+            builder = builder.cookie_provider(cookie_jar);
+        }
+
+        if let Some(min_tls) = self.min_tls_version {
+            builder = builder.min_tls_version(min_tls);
+        }
+
+        if let Some(timeout) = self.timeout {
+            builder = builder.timeout(timeout);
+        }
+
+        Ok(builder)
+    }
+
+    fn default_headers(&self) -> Result<HeaderMap> {
+        let user_agent = self.user_agent.clone();
+        let mut headers = self.custom_headers.clone();
+
+        if let Some(prev_user_agent) =
+            headers.insert(header::USER_AGENT, HeaderValue::try_from(&user_agent)?)
+        {
+            debug!(
+                "Found user-agent in headers: {}. Overriding it with {user_agent}.",
+                prev_user_agent.to_str().unwrap_or("ļæ½"),
+            );
+        }
+
+        headers.insert(
+            header::TRANSFER_ENCODING,
+            HeaderValue::from_static("chunked"),
+        );
+
+        Ok(headers)
+    }
 }
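`build_host_clients` layers the per-host headers on top of the shared defaults via `HeaderMap::extend`, so a host-specific value replaces a global one on conflict. A small standalone sketch of that merge semantics (header names and values are illustrative):

```rust
use http::{HeaderMap, HeaderValue};

fn main() {
    let mut headers = HeaderMap::new(); // the shared defaults
    headers.insert("x-foo", HeaderValue::from_static("global"));

    let mut host_headers = HeaderMap::new(); // as from a `HostConfig`
    host_headers.insert("x-foo", HeaderValue::from_static("host-override"));

    // Same call as in build_host_clients: later entries win on conflict
    headers.extend(host_headers);
    assert_eq!(
        headers.get("x-foo").unwrap().to_str().unwrap(),
        "host-override"
    );
}
```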
 
 /// Create our custom [`redirect::Policy`] in order to stop following redirects
@@ -467,6 +507,12 @@ pub struct Client {
 }
 
 impl Client {
+    /// Get `HostPool`
+    #[must_use]
+    pub fn host_pool(&self) -> Arc<HostPool> {
+        self.website_checker.host_pool()
+    }
+
     /// Check a single request.
     ///
     /// `request` can be either a [`Request`] or a type that can be converted
@@ -498,8 +544,7 @@
         }
 
         let status = match uri.scheme() {
-            // We don't check tel: URIs
-            _ if uri.is_tel() => Status::Excluded,
+            _ if uri.is_tel() => Status::Excluded, // We don't check tel: URIs
             _ if uri.is_file() => self.check_file(uri).await,
             _ if uri.is_mail() => self.check_mail(uri).await,
             _ => self.check_website(uri, credentials).await?,
diff --git a/lychee-lib/src/extract/html/html5gum.rs b/lychee-lib/src/extract/html/html5gum.rs
index f33741b62e..95be6a1f89 100644
--- a/lychee-lib/src/extract/html/html5gum.rs
+++ b/lychee-lib/src/extract/html/html5gum.rs
@@ -238,7 +238,6 @@ impl LinkExtractor {
         if let Some(name) = self.current_attributes.get("name") {
             self.fragments.insert(name.to_string());
         }
-        self.current_attributes.clear();
     }
 }
diff --git a/lychee-lib/src/lib.rs b/lychee-lib/src/lib.rs
index 6c917fda92..2f4fb5381b 100644
--- a/lychee-lib/src/lib.rs
+++ b/lychee-lib/src/lib.rs
@@ -68,6 +68,9 @@ pub mod extract;
 
 pub mod remap;
 
+/// Per-host rate limiting and concurrency control
+pub mod ratelimit;
+
 /// Filters are a way to define behavior when encountering
 /// URIs that need to be treated differently, such as
 /// local IPs or e-mail addresses
diff --git a/lychee-lib/src/ratelimit/config.rs b/lychee-lib/src/ratelimit/config.rs
new file mode 100644
index 0000000000..0d48f52a1b
--- /dev/null
+++ b/lychee-lib/src/ratelimit/config.rs
@@ -0,0 +1,208 @@
+use http::{HeaderMap, HeaderName, HeaderValue};
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::time::Duration;
+
+use crate::ratelimit::HostKey;
+
+/// Default number of concurrent requests per host
+const DEFAULT_CONCURRENCY: usize = 10;
+
+/// Default interval between requests to the same host
+const DEFAULT_REQUEST_INTERVAL: Duration = Duration::from_millis(50);
+
+/// Global rate limiting configuration that applies as defaults to all hosts
+#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
+pub struct RateLimitConfig {
+    /// Default maximum concurrent requests per host
+    #[serde(default = "default_concurrency")]
+    pub concurrency: usize,
+
+    /// Default minimum interval between requests to the same host
+    #[serde(default = "default_request_interval", with = "humantime_serde")]
+    pub request_interval: Duration,
+}
+
+impl Default for RateLimitConfig {
+    fn default() -> Self {
+        Self {
+            concurrency: default_concurrency(),
+            request_interval: default_request_interval(),
+        }
+    }
+}
+
+/// Default number of concurrent requests per host
+const fn default_concurrency() -> usize {
+    DEFAULT_CONCURRENCY
+}
+
+/// Default interval between requests to the same host
+const fn default_request_interval() -> Duration {
+    DEFAULT_REQUEST_INTERVAL
+}
+
+impl RateLimitConfig {
+    /// Create a `RateLimitConfig` from CLI options, using defaults for missing values
+    #[must_use]
+    pub fn from_options(concurrency: Option<usize>, request_interval: Option<Duration>) -> Self {
+        Self {
+            concurrency: concurrency.unwrap_or(DEFAULT_CONCURRENCY),
+            request_interval: request_interval.unwrap_or(DEFAULT_REQUEST_INTERVAL),
+        }
+    }
+}
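`from_options` is the bridge from the CLI flags to this config; absent flags fall back to the documented defaults. A quick usage sketch (durations such as `50ms` or `1s` are parsed by `humantime::parse_duration` before reaching this point):

```rust
use std::time::Duration;
use lychee_lib::ratelimit::RateLimitConfig;

fn main() {
    let custom = RateLimitConfig::from_options(Some(2), Some(Duration::from_secs(1)));
    assert_eq!(custom.concurrency, 2);

    // No flags given: fall back to 10 concurrent requests, 50ms interval
    let defaults = RateLimitConfig::from_options(None, None);
    assert_eq!(defaults.concurrency, 10);
    assert_eq!(defaults.request_interval, Duration::from_millis(50));
}
```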
"humantime_serde")] + pub request_interval: Option, + + /// Custom headers to send with requests to this host + #[serde(default)] + #[serde(deserialize_with = "deserialize_headers")] + #[serde(serialize_with = "serialize_headers")] + pub headers: HeaderMap, +} + +impl Default for HostConfig { + fn default() -> Self { + Self { + concurrency: None, + request_interval: None, + headers: HeaderMap::new(), + } + } +} + +impl HostConfig { + /// Get the effective maximum concurrency, falling back to the global default + #[must_use] + pub fn effective_concurrency(&self, global_config: &RateLimitConfig) -> usize { + self.concurrency.unwrap_or(global_config.concurrency) + } + + /// Get the effective request interval, falling back to the global default + #[must_use] + pub fn effective_request_interval(&self, global_config: &RateLimitConfig) -> Duration { + self.request_interval + .unwrap_or(global_config.request_interval) + } +} + +/// Custom deserializer for headers from TOML config format +fn deserialize_headers<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + let map = HashMap::::deserialize(deserializer)?; + let mut header_map = HeaderMap::new(); + + for (name, value) in map { + let header_name = HeaderName::from_bytes(name.as_bytes()) + .map_err(|e| serde::de::Error::custom(format!("Invalid header name '{name}': {e}")))?; + let header_value = HeaderValue::from_str(&value).map_err(|e| { + serde::de::Error::custom(format!("Invalid header value '{value}': {e}")) + })?; + header_map.insert(header_name, header_value); + } + + Ok(header_map) +} + +/// Custom serializer for headers to TOML config format +fn serialize_headers(headers: &HeaderMap, serializer: S) -> Result +where + S: serde::Serializer, +{ + let map: HashMap = headers + .iter() + .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string())) + .collect(); + map.serialize(serializer) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_rate_limit_config() { + let config = RateLimitConfig::default(); + assert_eq!(config.concurrency, 10); + assert_eq!(config.request_interval, Duration::from_millis(50)); + } + + #[test] + fn test_host_config_effective_values() { + let global_config = RateLimitConfig::default(); + + // Test with no overrides + let host_config = HostConfig::default(); + assert_eq!(host_config.effective_concurrency(&global_config), 10); + assert_eq!( + host_config.effective_request_interval(&global_config), + Duration::from_millis(50) + ); + + // Test with overrides + let host_config = HostConfig { + concurrency: Some(5), + request_interval: Some(Duration::from_millis(500)), + headers: HeaderMap::new(), + }; + assert_eq!(host_config.effective_concurrency(&global_config), 5); + assert_eq!( + host_config.effective_request_interval(&global_config), + Duration::from_millis(500) + ); + } + + #[test] + fn test_config_serialization() { + let config = RateLimitConfig { + concurrency: 15, + request_interval: Duration::from_millis(200), + }; + + let toml = toml::to_string(&config).unwrap(); + let deserialized: RateLimitConfig = toml::from_str(&toml).unwrap(); + + assert_eq!(config.concurrency, deserialized.concurrency); + assert_eq!(config.request_interval, deserialized.request_interval); + } + + #[test] + fn test_headers_serialization() { + let mut headers = HeaderMap::new(); + headers.insert("Authorization", "Bearer token123".parse().unwrap()); + headers.insert("User-Agent", "test-agent".parse().unwrap()); + + let host_config = HostConfig { + concurrency: Some(5), 
+            request_interval: Some(Duration::from_millis(500)),
+            headers,
+        };
+
+        let toml = toml::to_string(&host_config).unwrap();
+        let deserialized: HostConfig = toml::from_str(&toml).unwrap();
+
+        assert_eq!(deserialized.concurrency, Some(5));
+        assert_eq!(
+            deserialized.request_interval,
+            Some(Duration::from_millis(500))
+        );
+        assert_eq!(deserialized.headers.len(), 2);
+        assert!(deserialized.headers.contains_key("authorization"));
+        assert!(deserialized.headers.contains_key("user-agent"));
+    }
+}
diff --git a/lychee-lib/src/ratelimit/headers.rs b/lychee-lib/src/ratelimit/headers.rs
new file mode 100644
index 0000000000..bdc616c9aa
--- /dev/null
+++ b/lychee-lib/src/ratelimit/headers.rs
@@ -0,0 +1,102 @@
+//! Handle rate limiting headers.
+//! Note that we might want to replace this module with the `rate-limits`
+//! crate at some point in the future.
+
+use http::HeaderValue;
+use std::time::{Duration, SystemTime};
+use thiserror::Error;
+
+#[derive(Debug, Error, PartialEq, Eq)]
+pub(crate) enum RetryAfterParseError {
+    #[error("Unable to parse value '{0}'")]
+    ValueError(String),
+
+    #[error("Header value contains invalid chars")]
+    HeaderValueError,
+}
+
+/// Parse the "Retry-After" header as specified per
+/// [RFC 7231 section 7.1.3](https://www.rfc-editor.org/rfc/rfc7231#section-7.1.3)
+pub(crate) fn parse_retry_after(value: &HeaderValue) -> Result<Duration, RetryAfterParseError> {
+    let value = value
+        .to_str()
+        .map_err(|_| RetryAfterParseError::HeaderValueError)?;
+
+    // RFC 7231: Retry-After = HTTP-date / delay-seconds
+    value.parse::<u64>().map(Duration::from_secs).or_else(|_| {
+        httpdate::parse_http_date(value)
+            .map(|s| {
+                s.duration_since(SystemTime::now())
+                    // if date is in the past, we can use ZERO
+                    .unwrap_or(Duration::ZERO)
+            })
+            .map_err(|_| RetryAfterParseError::ValueError(value.into()))
+    })
+}
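`parse_retry_after` is crate-private, so the following standalone sketch mirrors its fallback chain rather than calling it: numeric `delay-seconds` first, HTTP-date second, with past dates clamped to zero.

```rust
use std::time::{Duration, SystemTime};

fn retry_after(value: &str) -> Option<Duration> {
    value.parse::<u64>().map(Duration::from_secs).ok().or_else(|| {
        httpdate::parse_http_date(value).ok().map(|date| {
            date.duration_since(SystemTime::now())
                .unwrap_or(Duration::ZERO) // date already in the past
        })
    })
}

fn main() {
    assert_eq!(retry_after("120"), Some(Duration::from_secs(120)));
    assert_eq!(
        retry_after("Fri, 15 May 2015 15:34:21 GMT"),
        Some(Duration::ZERO)
    );
}
```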
+
+/// Parse the common "X-RateLimit" header fields.
+/// Unfortunately, this is not standardised yet, but there is an
+/// [IETF draft](https://datatracker.ietf.org/doc/draft-ietf-httpapi-ratelimit-headers/).
+pub(crate) fn parse_common_rate_limit_header_fields(
+    headers: &http::HeaderMap,
+) -> (Option<u64>, Option<u64>) {
+    let remaining = self::parse_header_value(
+        headers,
+        &[
+            "x-ratelimit-remaining",
+            "x-rate-limit-remaining",
+            "ratelimit-remaining",
+        ],
+    );
+
+    let limit = self::parse_header_value(
+        headers,
+        &["x-ratelimit-limit", "x-rate-limit-limit", "ratelimit-limit"],
+    );
+
+    (remaining, limit)
+}
+
+/// Helper method to parse numeric header values from common rate limit headers
+fn parse_header_value(headers: &http::HeaderMap, header_names: &[&str]) -> Option<u64> {
+    for header_name in header_names {
+        if let Some(value) = headers.get(*header_name)
+            && let Ok(value_str) = value.to_str()
+            && let Ok(number) = value_str.parse::<u64>()
+        {
+            return Some(number);
+        }
+    }
+    None
+}
+
+#[cfg(test)]
+mod tests {
+    use std::time::Duration;
+
+    use http::HeaderValue;
+
+    use crate::ratelimit::headers::{RetryAfterParseError, parse_retry_after};
+
+    #[test]
+    fn test_retry_after() {
+        assert_eq!(parse_retry_after(&value("1")), Ok(Duration::from_secs(1)));
+        assert_eq!(
+            parse_retry_after(&value("-1")),
+            Err(RetryAfterParseError::ValueError("-1".into()))
+        );
+
+        assert_eq!(
+            parse_retry_after(&value("Fri, 15 May 2015 15:34:21 GMT")),
+            Ok(Duration::ZERO)
+        );
+
+        let result = parse_retry_after(&value("Fri, 15 May 4099 15:34:21 GMT"));
+        let is_in_future = matches!(result, Ok(d) if d.as_secs() > 0);
+        assert!(is_in_future);
+    }
+
+    fn value(v: &str) -> HeaderValue {
+        HeaderValue::from_str(v).unwrap()
+    }
+}
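For context, the header shapes this module targets look like the following; `parse_common_rate_limit_header_fields` would report `(remaining, limit) = (Some(0), Some(60))` for this set. The values are illustrative, and the extraction below repeats the same numeric parse as `parse_header_value` for a single name:

```rust
use http::{HeaderMap, HeaderValue};

fn main() {
    let mut headers = HeaderMap::new();
    headers.insert("x-ratelimit-limit", HeaderValue::from_static("60"));
    headers.insert("x-ratelimit-remaining", HeaderValue::from_static("0"));

    let remaining = headers
        .get("x-ratelimit-remaining")
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.parse::<u64>().ok());
    assert_eq!(remaining, Some(0)); // window exhausted
}
```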
+///
+/// Each host maintains:
+/// - A token bucket rate limiter using governor
+/// - A semaphore for concurrency control
+/// - A dedicated HTTP client with host-specific headers and cookies
+/// - Statistics tracking for adaptive behavior
+/// - A per-host cache to prevent duplicate requests
+#[derive(Debug)]
+pub struct Host {
+    /// The hostname this instance manages
+    pub key: HostKey,
+
+    /// Rate limiter using token bucket algorithm
+    rate_limiter: Option<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
+
+    /// Controls maximum concurrent requests to this host
+    semaphore: Semaphore,
+
+    /// HTTP client configured for this specific host
+    client: ReqwestClient,
+
+    /// Request statistics and adaptive behavior tracking
+    stats: Mutex<HostStats>,
+
+    /// Current backoff duration for adaptive rate limiting
+    backoff_duration: Mutex<Duration>,
+
+    /// Per-host cache to prevent duplicate requests
+    cache: HostCache,
+}
+
+impl Host {
+    /// Create a new Host instance for the given hostname
+    #[must_use]
+    pub fn new(
+        key: HostKey,
+        host_config: &HostConfig,
+        global_config: &RateLimitConfig,
+        client: ReqwestClient,
+    ) -> Self {
+        const MAX_BURST: NonZeroU32 = NonZeroU32::new(1).unwrap();
+        let interval = host_config.effective_request_interval(global_config);
+        let rate_limiter =
+            Quota::with_period(interval).map(|q| RateLimiter::direct(q.allow_burst(MAX_BURST)));
+
+        // Create semaphore for concurrency control
+        let max_concurrent = host_config.effective_concurrency(global_config);
+        let semaphore = Semaphore::new(max_concurrent);
+
+        Host {
+            key,
+            rate_limiter,
+            semaphore,
+            client,
+            stats: Mutex::new(HostStats::default()),
+            backoff_duration: Mutex::new(Duration::from_millis(0)),
+            cache: DashMap::new(),
+        }
+    }
+
+    /// Check if a URI is cached and return the cached status if valid
+    ///
+    /// # Panics
+    ///
+    /// Panics if the statistics mutex is poisoned
+    pub fn get_cached_status(&self, uri: &Uri) -> Option<CacheStatus> {
+        if let Some(entry) = self.cache.get(uri) {
+            // Cache hit
+            self.stats.lock().unwrap().record_cache_hit();
+            return Some(entry.status);
+        }
+
+        // Cache miss
+        self.stats.lock().unwrap().record_cache_miss();
+        None
+    }
+
+    /// Cache a request result
+    pub fn cache_result(&self, uri: &Uri, status: &Status) {
+        let cache_value = HostCacheValue::from(status);
+        self.cache.insert(uri.clone(), cache_value);
+    }
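`Host::new` above compiles the configured interval into a governor quota with a burst of one, i.e. at most one token per interval, so requests to the same host are evenly spaced. A self-contained sketch of that primitive in isolation (the 100ms period is an arbitrary example value, not a lychee default):

```rust
use governor::{Quota, RateLimiter};
use std::num::NonZeroU32;
use std::time::Duration;

#[tokio::main]
async fn main() {
    // One token per 100ms and no burst headroom: calls are evenly spaced.
    let quota = Quota::with_period(Duration::from_millis(100))
        .expect("period must be non-zero")
        .allow_burst(NonZeroU32::new(1).unwrap());
    let limiter = RateLimiter::direct(quota);

    for i in 0..3 {
        limiter.until_ready().await; // waits until the bucket has a token
        println!("request {i} may proceed");
    }
}
```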
+
+    /// Execute a request with rate limiting, concurrency control, and caching
+    ///
+    /// This method:
+    /// 1. Checks the per-host cache for existing results (handled at the `HostPool` level)
+    /// 2. If not cached, acquires a semaphore permit for concurrency control
+    /// 3. Waits for rate limiter permission
+    /// 4. Applies adaptive backoff if needed
+    /// 5. Executes the request
+    /// 6. Updates statistics based on response
+    /// 7. Parses rate limit headers to adjust future behavior
+    /// 8. Caches the result for future use
+    ///
+    /// # Arguments
+    ///
+    /// * `request` - The HTTP request to execute
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the request fails or rate limiting is exceeded
+    ///
+    /// # Panics
+    ///
+    /// Panics if the statistics mutex is poisoned
+    pub async fn execute_request(&self, request: Request) -> Result<Response> {
+        let uri = Uri::from(request.url().clone());
+
+        // Note: Cache checking is handled at the HostPool level
+        // This method focuses on executing the actual HTTP request
+
+        // Acquire semaphore permit for concurrency control
+        let _permit = self
+            .semaphore
+            .acquire()
+            .await
+            // SAFETY: this should not panic as we never close the semaphore
+            .expect("Semaphore was closed unexpectedly");
+
+        // Apply adaptive backoff if needed
+        let backoff_duration = {
+            let backoff = self.backoff_duration.lock().unwrap();
+            *backoff
+        };
+        if !backoff_duration.is_zero() {
+            log::debug!(
+                "Host {} applying backoff delay of {}ms due to previous rate limiting or errors",
+                self.key,
+                backoff_duration.as_millis()
+            );
+            tokio::time::sleep(backoff_duration).await;
+        }
+
+        if let Some(rate_limiter) = &self.rate_limiter {
+            rate_limiter.until_ready().await;
+        }
+
+        // Execute the request and track timing
+        let start_time = Instant::now();
+        let response = match self.client.execute(request).await {
+            Ok(response) => response,
+            Err(e) => {
+                // Wrap network/HTTP errors to preserve the original error
+                return Err(ErrorKind::NetworkRequest(e));
+            }
+        };
+        let request_time = start_time.elapsed();
+
+        // Update statistics based on response
+        let status_code = response.status().as_u16();
+        self.update_stats_and_backoff(status_code, request_time);
+
+        // Parse rate limit headers to adjust behavior
+        self.handle_rate_limit_headers(&response);
+
+        // Cache the result
+        let status = Status::Ok(response.status());
+        self.cache_result(&uri, &status);
+
+        Ok(response)
+    }
+
+    pub(crate) const fn get_client(&self) -> &ReqwestClient {
+        &self.client
+    }
+
+    /// Update internal statistics and backoff based on the response
+    fn update_stats_and_backoff(&self, status_code: u16, request_time: Duration) {
+        // Update statistics
+        {
+            let mut stats = self.stats.lock().unwrap();
+            stats.record_response(status_code, request_time);
+        }
+
+        // Update backoff duration based on response
+        {
+            let mut backoff = self.backoff_duration.lock().unwrap();
+            match status_code {
+                200..=299 => {
+                    // Reset backoff on success
+                    *backoff = Duration::from_millis(0);
+                }
+                429 => {
+                    // Exponential backoff on rate limit, capped at 30 seconds
+                    let new_backoff = std::cmp::min(
+                        if backoff.is_zero() {
+                            Duration::from_millis(500)
+                        } else {
+                            *backoff * 2
+                        },
+                        Duration::from_secs(30),
+                    );
+                    log::debug!(
+                        "Host {} hit rate limit (429), increasing backoff from {}ms to {}ms",
+                        self.key,
+                        backoff.as_millis(),
+                        new_backoff.as_millis()
+                    );
+                    *backoff = new_backoff;
+                }
+                500..=599 => {
+                    // Moderate backoff increase on server errors, capped at 10 seconds
+                    *backoff = std::cmp::min(
+                        *backoff + Duration::from_millis(200),
+                        Duration::from_secs(10),
+                    );
+                }
+                _ => {} // No backoff change for other status codes
+            }
+        }
+    }
+
+    /// Parse rate limit headers from response and adjust behavior
+    fn handle_rate_limit_headers(&self, response: &Response) {
+        // Implement basic parsing here rather than using the rate-limits crate to keep dependencies minimal
+        let headers = response.headers();
+        self.handle_retry_after_header(headers);
+        self.handle_common_rate_limit_header_fields(headers);
+    }
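The 429 arm above doubles the delay from a 500ms seed and caps it at 30s, so consecutive rate-limit responses yield 0.5s, 1s, 2s, 4s, 8s, 16s, then 30s. A standalone sketch of just that update rule (mirroring the match arm, not the surrounding `Host` plumbing):

```rust
use std::cmp::min;
use std::time::Duration;

/// Doubling backoff with a 500ms seed, capped at 30s (as in the 429 arm above).
fn next_backoff(current: Duration) -> Duration {
    min(
        if current.is_zero() {
            Duration::from_millis(500)
        } else {
            current * 2
        },
        Duration::from_secs(30),
    )
}

fn main() {
    let mut backoff = Duration::ZERO;
    for _ in 0..8 {
        backoff = next_backoff(backoff);
        print!("{}ms ", backoff.as_millis());
    }
    // Prints: 500ms 1000ms 2000ms 4000ms 8000ms 16000ms 30000ms 30000ms
}
```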
+
+    /// Handle the common "X-RateLimit" header fields.
+    fn handle_common_rate_limit_header_fields(&self, headers: &http::HeaderMap) {
+        if let (Some(remaining), Some(limit)) =
+            headers::parse_common_rate_limit_header_fields(headers)
+            && limit > 0
+        {
+            // `saturating_sub` guards against servers that report more remaining than the limit
+            #[allow(clippy::cast_precision_loss)]
+            let usage_ratio = limit.saturating_sub(remaining) as f64 / limit as f64;
+
+            // If we've used more than 80% of our quota, apply preventive backoff
+            if usage_ratio > 0.8 {
+                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
+                let duration = Duration::from_millis((200.0 * (usage_ratio - 0.8) / 0.2) as u64);
+                self.increase_backoff(duration);
+            }
+        }
+    }
+
+    /// Handle the "Retry-After" header
+    fn handle_retry_after_header(&self, headers: &http::HeaderMap) {
+        if let Some(retry_after_value) = headers.get("retry-after") {
+            let duration = match headers::parse_retry_after(retry_after_value) {
+                Ok(duration) => duration,
+                Err(e) => {
+                    warn!("Unable to parse Retry-After header as per RFC 7231: {e}");
+                    return;
+                }
+            };
+
+            self.increase_backoff(duration);
+        }
+    }
+
+    fn increase_backoff(&self, mut increased_backoff: Duration) {
+        if increased_backoff > MAXIMUM_BACKOFF {
+            warn!(
+                "Encountered an unexpectedly big rate limit backoff duration of {}. Capping the duration to {} instead.",
+                format_duration(increased_backoff),
+                format_duration(MAXIMUM_BACKOFF)
+            );
+            increased_backoff = MAXIMUM_BACKOFF;
+        }
+
+        let mut backoff = self.backoff_duration.lock().unwrap();
+        *backoff = std::cmp::max(*backoff, increased_backoff);
+    }
+
+    /// Get host statistics
+    ///
+    /// # Panics
+    ///
+    /// Panics if the statistics mutex is poisoned
+    pub fn stats(&self) -> HostStats {
+        self.stats.lock().unwrap().clone()
+    }
+
+    /// Record a cache hit from the persistent disk cache
+    ///
+    /// # Panics
+    ///
+    /// Panics if the statistics mutex is poisoned
+    pub fn record_persistent_cache_hit(&self) {
+        self.stats.lock().unwrap().record_cache_hit();
+    }
+
+    /// Record a cache miss from the persistent disk cache
+    ///
+    /// # Panics
+    ///
+    /// Panics if the statistics mutex is poisoned
+    pub fn record_persistent_cache_miss(&self) {
+        self.stats.lock().unwrap().record_cache_miss();
+    }
+
+    /// Get the current number of available permits (concurrent request slots)
+    pub fn available_permits(&self) -> usize {
+        self.semaphore.available_permits()
+    }
+
+    /// Get the current cache size (number of cached entries)
+    pub fn cache_size(&self) -> usize {
+        self.cache.len()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ratelimit::{HostConfig, RateLimitConfig};
+    use reqwest::Client;
+
+    #[tokio::test]
+    async fn test_host_creation() {
+        let key = HostKey::from("example.com");
+        let host_config = HostConfig::default();
+        let global_config = RateLimitConfig::default();
+
+        let host = Host::new(key.clone(), &host_config, &global_config, Client::default());
+
+        assert_eq!(host.key, key);
+        assert_eq!(host.available_permits(), 10); // Default concurrency
+        assert!((host.stats().success_rate() - 1.0).abs() < f64::EPSILON);
+        assert_eq!(host.cache_size(), 0);
+    }
+}
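The preventive backoff in `handle_common_rate_limit_header_fields` ramps linearly from 0ms at 80% quota usage to 200ms at 100%: with `limit = 100` and `remaining = 10`, the usage ratio is 0.9 and the added delay is 200 * (0.9 - 0.8) / 0.2 = 100ms. A tiny sketch of that mapping, isolated from the `Host` type:

```rust
use std::time::Duration;

/// Linear ramp from 0ms at 80% usage to 200ms at 100% usage
/// (the same formula as in `handle_common_rate_limit_header_fields`).
fn preventive_backoff(remaining: u64, limit: u64) -> Option<Duration> {
    if limit == 0 {
        return None;
    }
    let usage_ratio = limit.saturating_sub(remaining) as f64 / limit as f64;
    (usage_ratio > 0.8)
        .then(|| Duration::from_millis((200.0 * (usage_ratio - 0.8) / 0.2) as u64))
}

fn main() {
    assert_eq!(preventive_backoff(50, 100), None); // 50% used: no backoff
    assert_eq!(preventive_backoff(10, 100), Some(Duration::from_millis(100))); // 90% used
    assert_eq!(preventive_backoff(0, 100), Some(Duration::from_millis(200))); // quota exhausted
}
```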
diff --git a/lychee-lib/src/ratelimit/host/key.rs b/lychee-lib/src/ratelimit/host/key.rs
new file mode 100644
index 0000000000..4361b9b8da
--- /dev/null
+++ b/lychee-lib/src/ratelimit/host/key.rs
@@ -0,0 +1,151 @@
+use serde::Deserialize;
+use std::fmt;
+use url::Url;
+
+use crate::ErrorKind;
+use crate::types::Result;
+
+/// A type-safe representation of a hostname for rate limiting purposes.
+///
+/// This extracts and normalizes hostnames from URLs to ensure consistent
+/// rate limiting across requests to the same host (domain or IP address).
+///
+/// # Examples
+///
+/// ```
+/// use lychee_lib::ratelimit::HostKey;
+/// use url::Url;
+///
+/// let url = Url::parse("https://api.github.com/repos/user/repo").unwrap();
+/// let host_key = HostKey::try_from(&url).unwrap();
+/// assert_eq!(host_key.as_str(), "api.github.com");
+/// ```
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize)]
+pub struct HostKey(String);
+
+impl HostKey {
+    /// Get the hostname as a string slice
+    #[must_use]
+    pub fn as_str(&self) -> &str {
+        &self.0
+    }
+
+    /// Get the hostname as an owned String
+    #[must_use]
+    pub fn into_string(self) -> String {
+        self.0
+    }
+}
+
+impl TryFrom<&Url> for HostKey {
+    type Error = ErrorKind;
+
+    fn try_from(url: &Url) -> Result<Self> {
+        let host = url.host_str().ok_or_else(|| ErrorKind::InvalidUrlHost)?;
+
+        // Normalize to lowercase for consistent lookup
+        Ok(HostKey(host.to_lowercase()))
+    }
+}
+
+impl TryFrom<&crate::Uri> for HostKey {
+    type Error = ErrorKind;
+
+    fn try_from(uri: &crate::Uri) -> Result<Self> {
+        Self::try_from(&uri.url)
+    }
+}
+
+impl TryFrom<Url> for HostKey {
+    type Error = ErrorKind;
+
+    fn try_from(url: Url) -> Result<Self> {
+        HostKey::try_from(&url)
+    }
+}
+
+impl fmt::Display for HostKey {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+impl From<String> for HostKey {
+    fn from(host: String) -> Self {
+        HostKey(host.to_lowercase())
+    }
+}
+
+impl From<&str> for HostKey {
+    fn from(host: &str) -> Self {
+        HostKey(host.to_lowercase())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_host_key_from_url() {
+        let url = Url::parse("https://api.github.com/repos/user/repo").unwrap();
+        let host_key = HostKey::try_from(&url).unwrap();
+        assert_eq!(host_key.as_str(), "api.github.com");
+    }
+
+    #[test]
+    fn test_host_key_normalization() {
+        let url = Url::parse("https://API.GITHUB.COM/repos/user/repo").unwrap();
+        let host_key = HostKey::try_from(&url).unwrap();
+        assert_eq!(host_key.as_str(), "api.github.com");
+    }
+
+    #[test]
+    fn test_host_key_subdomain_separation() {
+        let api_url = Url::parse("https://api.github.com/").unwrap();
+        let www_url = Url::parse("https://www.github.com/").unwrap();
+
+        let api_key = HostKey::try_from(&api_url).unwrap();
+        let www_key = HostKey::try_from(&www_url).unwrap();
+
+        assert_ne!(api_key, www_key);
+        assert_eq!(api_key.as_str(), "api.github.com");
+        assert_eq!(www_key.as_str(), "www.github.com");
+    }
+
+    #[test]
+    fn test_host_key_from_string() {
+        let host_key = HostKey::from("example.com");
+        assert_eq!(host_key.as_str(), "example.com");
+
+        let host_key = HostKey::from("EXAMPLE.COM");
+        assert_eq!(host_key.as_str(), "example.com");
+    }
+
+    #[test]
+    fn test_host_key_no_host() {
+        let url = Url::parse("file:///path/to/file").unwrap();
+        let result = HostKey::try_from(&url);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_host_key_display() {
+        let host_key = HostKey::from("example.com");
+        assert_eq!(format!("{host_key}"), "example.com");
+    }
+
+    #[test]
+    fn test_host_key_hash_equality() {
+        use std::collections::HashMap;
+
+        let key1 = HostKey::from("example.com");
+        let key2 = HostKey::from("EXAMPLE.COM");
+
+        let mut map = HashMap::new();
+        map.insert(key1, "value");
+
+        // Should find the value with normalized key
+        assert_eq!(map.get(&key2), Some(&"value"));
+    }
+}
diff --git a/lychee-lib/src/ratelimit/host/mod.rs b/lychee-lib/src/ratelimit/host/mod.rs
new file mode 100644
index 0000000000..50b8b1ad3e
--- /dev/null
+++ b/lychee-lib/src/ratelimit/host/mod.rs
@@ -0,0 +1,9 @@
+#![allow(clippy::module_inception)]
+
+mod host;
+mod key;
+mod stats;
+
+pub use host::Host;
+pub use key::HostKey;
+pub use stats::HostStats;
diff --git a/lychee-lib/src/ratelimit/host/stats.rs b/lychee-lib/src/ratelimit/host/stats.rs
new file mode 100644
index 0000000000..c78ec43623
--- /dev/null
+++ b/lychee-lib/src/ratelimit/host/stats.rs
@@ -0,0 +1,253 @@
+use std::collections::HashMap;
+use std::time::{Duration, Instant};
+
+/// Record and report statistics for a [`crate::ratelimit::Host`]
+#[derive(Debug, Clone, Default)]
+pub struct HostStats {
+    /// Total number of requests made to this host
+    pub total_requests: u64,
+    /// Number of successful requests (2xx status)
+    pub successful_requests: u64,
+    /// Number of requests that received rate limit responses (429)
+    pub rate_limited: u64,
+    /// Number of server error responses (5xx)
+    pub server_errors: u64,
+    /// Number of client error responses (4xx, excluding 429)
+    pub client_errors: u64,
+    /// Timestamp of the last successful request
+    pub last_success: Option<Instant>,
+    /// Timestamp of the last rate limit response
+    pub last_rate_limit: Option<Instant>,
+    /// Request times for median calculation
+    pub request_times: Vec<Duration>,
+    /// Status code counts
+    pub status_codes: HashMap<u16, u64>,
+    /// Number of cache hits
+    pub cache_hits: u64,
+    /// Number of cache misses
+    pub cache_misses: u64,
+}
+
+impl HostStats {
+    /// Record a response with status code and request duration
+    pub fn record_response(&mut self, status_code: u16, request_time: Duration) {
+        self.total_requests += 1;
+
+        // Track status code
+        *self.status_codes.entry(status_code).or_insert(0) += 1;
+
+        // Categorize response
+        match status_code {
+            200..=299 => {
+                self.successful_requests += 1;
+                self.last_success = Some(Instant::now());
+            }
+            429 => {
+                self.rate_limited += 1;
+                self.last_rate_limit = Some(Instant::now());
+            }
+            400..=499 => {
+                self.client_errors += 1;
+            }
+            500..=599 => {
+                self.server_errors += 1;
+            }
+            _ => {} // Other status codes
+        }
+
+        self.request_times.push(request_time);
+    }
+
+    /// Get median request time
+    #[must_use]
+    pub fn median_request_time(&self) -> Option<Duration> {
+        if self.request_times.is_empty() {
+            return None;
+        }
+
+        let mut times = self.request_times.clone();
+        times.sort();
+        let mid = times.len() / 2;
+
+        if times.len().is_multiple_of(2) {
+            // Average of two middle values
+            Some((times[mid - 1] + times[mid]) / 2)
+        } else {
+            Some(times[mid])
+        }
+    }
+
+    /// Get error rate (percentage)
+    #[must_use]
+    pub fn error_rate(&self) -> f64 {
+        if self.total_requests == 0 {
+            return 0.0;
+        }
+        let errors = self.rate_limited + self.client_errors + self.server_errors;
+        #[allow(clippy::cast_precision_loss)]
+        let error_rate = errors as f64 / self.total_requests as f64;
+        error_rate * 100.0
+    }
+
+    /// Get the current success rate (0.0 to 1.0)
+    #[must_use]
+    pub fn success_rate(&self) -> f64 {
+        if self.total_requests == 0 {
+            1.0 // Assume success until proven otherwise
+        } else {
+            #[allow(clippy::cast_precision_loss)]
+            let success_rate = self.successful_requests as f64 / self.total_requests as f64;
+            success_rate
+        }
+    }
+
+    /// Get average request time
+    #[must_use]
+    pub fn average_request_time(&self) -> Option<Duration> {
+        if self.request_times.is_empty() {
+            return None;
+        }
+
+        let total: Duration = self.request_times.iter().sum();
+        #[allow(clippy::cast_possible_truncation)]
+        Some(total / (self.request_times.len() as u32))
+    }
+
+    /// Get the most recent request time
+    #[must_use]
+    pub fn latest_request_time(&self) -> Option<Duration> {
+        self.request_times.iter().last().copied()
+    }
+
+    /// Check if this host has been experiencing rate limiting recently
+    #[must_use]
+    pub fn is_currently_rate_limited(&self) -> bool {
+        if let Some(last_rate_limit) = self.last_rate_limit {
+            // Consider rate limited if we got a 429 in the last 60 seconds
+            last_rate_limit.elapsed() < Duration::from_secs(60)
+        } else {
+            false
+        }
+    }
+
+    /// Record a cache hit
+    pub const fn record_cache_hit(&mut self) {
+        self.cache_hits += 1;
+        // Cache hits should also count as total requests from user perspective
+        self.total_requests += 1;
+        // Cache hits are typically for successful previous requests, so count as successful
+        self.successful_requests += 1;
+    }
+
+    /// Record a cache miss
+    pub const fn record_cache_miss(&mut self) {
+        self.cache_misses += 1;
+        // Cache misses will be followed by actual requests that increment total_requests
+        // so we don't increment here to avoid double-counting
+    }
+
+    /// Get cache hit rate (0.0 to 1.0)
+    #[must_use]
+    pub fn cache_hit_rate(&self) -> f64 {
+        let total_cache_requests = self.cache_hits + self.cache_misses;
+        if total_cache_requests == 0 {
+            0.0
+        } else {
+            #[allow(clippy::cast_precision_loss)]
+            let hit_rate = self.cache_hits as f64 / total_cache_requests as f64;
+            hit_rate
+        }
+    }
+
+    /// Get human-readable summary of the stats
+    #[must_use]
+    pub fn summary(&self) -> String {
+        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
+        let success_pct = (self.success_rate() * 100.0) as u64;
+        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
+        let error_pct = self.error_rate() as u64;
+
+        let avg_time = self
+            .average_request_time()
+            .map_or_else(|| "N/A".to_string(), |d| format!("{:.0}ms", d.as_millis()));
+
+        format!(
+            "{} requests ({}% success, {}% errors), avg: {}",
+            self.total_requests, success_pct, error_pct, avg_time
+        )
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::time::Duration;
+
+    #[test]
+    fn test_host_stats_success_rate() {
+        let mut stats = HostStats::default();
+
+        // No requests yet - should assume success
+        assert!((stats.success_rate() - 1.0).abs() < f64::EPSILON);
+
+        // Record some successful requests
+        stats.record_response(200, Duration::from_millis(100));
+        stats.record_response(200, Duration::from_millis(120));
+        assert!((stats.success_rate() - 1.0).abs() < f64::EPSILON);
+
+        // Record a rate limited request
+        stats.record_response(429, Duration::from_millis(150));
+        assert!((stats.success_rate() - (2.0 / 3.0)).abs() < 0.001);
+
+        // Record a server error
+        stats.record_response(500, Duration::from_millis(200));
+        assert!((stats.success_rate() - 0.5).abs() < f64::EPSILON);
+    }
+
+    #[test]
+    fn test_host_stats_tracking() {
+        let mut stats = HostStats::default();
+
+        // Initially empty
+        assert_eq!(stats.total_requests, 0);
+        assert_eq!(stats.successful_requests, 0);
+        assert!(stats.error_rate().abs() < f64::EPSILON);
+
+        // Record a successful response
+        stats.record_response(200, Duration::from_millis(100));
+        assert_eq!(stats.total_requests, 1);
+        assert_eq!(stats.successful_requests, 1);
+        assert!(stats.error_rate().abs() < f64::EPSILON);
+        assert_eq!(stats.status_codes.get(&200), Some(&1));
+
+        // Record rate limited response
+        stats.record_response(429, Duration::from_millis(200));
+        assert_eq!(stats.total_requests, 2);
+        assert_eq!(stats.rate_limited, 1);
+        assert!((stats.error_rate() - 50.0).abs() < f64::EPSILON);
+
+        // Record server error
+        stats.record_response(500, Duration::from_millis(150));
+        assert_eq!(stats.total_requests, 3);
+        assert_eq!(stats.server_errors, 1);
+
+        // Check median request time
+        assert_eq!(
+            stats.median_request_time(),
+            Some(Duration::from_millis(150))
+        );
+    }
+
+    #[test]
+    fn test_summary_formatting() {
+        let mut stats = HostStats::default();
+        stats.record_response(200, Duration::from_millis(150));
+        stats.record_response(500, Duration::from_millis(200));
+
+        let summary = stats.summary();
+        assert!(summary.contains("2 requests"));
+        assert!(summary.contains("50% success"));
+        assert!(summary.contains("50% errors"));
+        assert!(summary.contains("175ms")); // average of 150 and 200
+    }
+}
diff --git a/lychee-lib/src/ratelimit/mod.rs b/lychee-lib/src/ratelimit/mod.rs
new file mode 100644
index 0000000000..ad4cb48551
--- /dev/null
+++ b/lychee-lib/src/ratelimit/mod.rs
@@ -0,0 +1,22 @@
+//! Per-host rate limiting and concurrency control.
+//!
+//! This module provides adaptive rate limiting for HTTP requests on a per-host basis.
+//! It prevents overwhelming servers with too many concurrent requests and respects
+//! server-provided rate limit headers.
+//!
+//! # Architecture
+//!
+//! - [`HostKey`]: Represents a hostname/domain for rate limiting
+//! - [`Host`]: Manages rate limiting, concurrency, and caching for a specific host
+//! - [`HostPool`]: Coordinates multiple hosts and routes requests appropriately
+//! - [`HostConfig`]: Configuration for per-host behavior
+//! - [`HostStats`]: Statistics tracking for each host
+
+mod config;
+mod headers;
+mod host;
+mod pool;
+
+pub use config::{HostConfig, HostConfigs, RateLimitConfig};
+pub use host::{Host, HostKey, HostStats};
+pub use pool::{ClientMap, HostPool};
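With these re-exports, the intended call pattern is to build one `HostPool` per run and route every request through it, letting hosts be created lazily per hostname. A hedged usage sketch against the exported API (the URL is illustrative and error handling is simplified):

```rust
use lychee_lib::Uri;
use lychee_lib::ratelimit::{HostConfigs, HostPool, RateLimitConfig};
use reqwest::{Client, Method};
use std::collections::HashMap;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // One pool per run; each hostname lazily gets its own rate-limited Host.
    let pool = HostPool::new(
        RateLimitConfig::default(),
        HostConfigs::default(),
        Client::default(),
        HashMap::new(), // no host-specific clients
    );

    let url: url::Url = "https://example.com/".parse()?;
    let uri = Uri::from(url);
    let request = pool.build_request(Method::GET, &uri)?;
    let response = pool.execute_request(request).await?;
    println!("{uri} -> {}", response.status());
    Ok(())
}
```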
diff --git a/lychee-lib/src/ratelimit/pool.rs b/lychee-lib/src/ratelimit/pool.rs
new file mode 100644
index 0000000000..cbd1d2bf88
--- /dev/null
+++ b/lychee-lib/src/ratelimit/pool.rs
@@ -0,0 +1,374 @@
+use dashmap::DashMap;
+use http::Method;
+use reqwest::{Client, Request, Response};
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use crate::ratelimit::{Host, HostConfigs, HostKey, HostStats, RateLimitConfig};
+use crate::types::Result;
+use crate::{CacheStatus, ErrorKind, Status, Uri};
+
+/// Keep track of host-specific [`reqwest::Client`]s
+pub type ClientMap = HashMap<HostKey, Client>;
+
+/// Manages a pool of Host instances and routes requests to appropriate hosts.
+///
+/// The `HostPool` serves as the central coordinator for per-host rate limiting.
+/// It creates host instances on-demand and provides a unified interface for
+/// executing HTTP requests with appropriate rate limiting applied.
+///
+/// # Architecture
+///
+/// - Each unique hostname gets its own Host instance with dedicated rate limiting
+/// - Hosts are created lazily when first requested
+/// - Thread-safe using `DashMap` for concurrent access to host instances
+#[derive(Debug)]
+pub struct HostPool {
+    /// Map of hostname to Host instances, created on-demand
+    hosts: DashMap<HostKey, Arc<Host>>,
+
+    /// Global configuration for rate limiting defaults
+    global_config: RateLimitConfig,
+
+    /// Per-host configuration overrides
+    host_configs: HostConfigs,
+
+    /// Fallback client for hosts without host-specific client
+    default_client: Client,
+
+    /// Host-specific clients
+    client_map: ClientMap,
+}
+
+impl HostPool {
+    /// Create a new `HostPool` with the given configuration
+    #[must_use]
+    pub fn new(
+        global_config: RateLimitConfig,
+        host_configs: HostConfigs,
+        default_client: Client,
+        client_map: ClientMap,
+    ) -> Self {
+        Self {
+            hosts: DashMap::new(),
+            global_config,
+            host_configs,
+            default_client,
+            client_map,
+        }
+    }
+
+    /// Try to execute a [`Request`] with appropriate per-host rate limiting.
+    ///
+    /// # Errors
+    ///
+    /// Fails if:
+    /// - The request URL has no valid hostname
+    /// - The underlying HTTP request fails
+    ///
+    /// # Examples
+    ///
+    /// ```no_run
+    /// # use lychee_lib::ratelimit::{HostPool, RateLimitConfig};
+    /// # use std::collections::HashMap;
+    /// # use reqwest::{Request, header::HeaderMap};
+    /// # use std::time::Duration;
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    /// let pool = HostPool::default();
+    /// let request = reqwest::Request::new(reqwest::Method::GET, "https://example.com".parse()?);
+    /// let response = pool.execute_request(request).await?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub async fn execute_request(&self, request: Request) -> Result<Response> {
+        let url = request.url();
+        let host_key = HostKey::try_from(url)?;
+        let host = self.get_or_create_host(host_key);
+        host.execute_request(request).await
+    }
+
+    /// Try to build a [`Request`]
+    ///
+    /// # Errors
+    ///
+    /// Fails if:
+    /// - The request URI has no valid hostname
+    /// - The request fails to build
+    pub fn build_request(&self, method: Method, uri: &Uri) -> Result<Request> {
+        let host_key = HostKey::try_from(uri)?;
+        let host = self.get_or_create_host(host_key);
+        host.get_client()
+            .request(method, uri.url.clone())
+            .build()
+            .map_err(ErrorKind::BuildRequestClient)
+    }
+
+    /// Get an existing host or create a new one for the given hostname
+    fn get_or_create_host(&self, host_key: HostKey) -> Arc<Host> {
+        self.hosts
+            .entry(host_key.clone())
+            .or_insert_with(|| {
+                let host_config = self
+                    .host_configs
+                    .get(&host_key)
+                    .cloned()
+                    .unwrap_or_default();
+
+                let client = self
+                    .client_map
+                    .get(&host_key)
+                    .unwrap_or(&self.default_client)
+                    .clone();
+
+                Arc::new(Host::new(
+                    host_key,
+                    &host_config,
+                    &self.global_config,
+                    client,
+                ))
+            })
+            .value()
+            .clone()
+    }
+
+    /// Returns statistics for the host if it exists, otherwise returns empty stats.
+    /// This provides consistent behavior whether or not requests have been made to that host yet.
+    #[must_use]
+    pub fn host_stats(&self, hostname: &str) -> HostStats {
+        let host_key = HostKey::from(hostname);
+        self.hosts
+            .get(&host_key)
+            .map(|host| host.stats())
+            .unwrap_or_default()
+    }
+
+    /// Returns a `HashMap` mapping hostnames to their statistics.
+    /// Only hosts that have had requests will be included.
+    #[must_use]
+    pub fn all_host_stats(&self) -> HashMap<String, HostStats> {
+        self.hosts
+            .iter()
+            .map(|entry| {
+                let hostname = entry.key().to_string();
+                let stats = entry.value().stats();
+                (hostname, stats)
+            })
+            .collect()
+    }
+
+    /// Get the number of host instances that have been created,
+    /// which corresponds to the number of unique hostnames that have
+    /// been accessed.
+    #[must_use]
+    pub fn active_host_count(&self) -> usize {
+        self.hosts.len()
+    }
+
+    /// Get a copy of the current host-specific configurations.
+    /// This is useful for debugging or runtime monitoring of configuration.
+    #[must_use]
+    pub fn host_configurations(&self) -> HostConfigs {
+        self.host_configs.clone()
+    }
+
+    /// Remove a host from the pool.
+    ///
+    /// This forces the host to be recreated with updated configuration
+    /// the next time a request is made to it. Any ongoing requests to
+    /// that host will continue with the old instance.
+    ///
+    /// # Returns
+    ///
+    /// Returns true if a host was removed, false if no host existed for that hostname.
+    #[must_use]
+    pub fn remove_host(&self, hostname: &str) -> bool {
+        let host_key = HostKey::from(hostname);
+        self.hosts.remove(&host_key).is_some()
+    }
+
+    /// Check if a URI is cached in the appropriate host's cache
+    ///
+    /// # Returns
+    ///
+    /// Returns the cached status if found and valid, `None` otherwise
+    #[must_use]
+    pub fn get_cached_status(&self, uri: &Uri) -> Option<CacheStatus> {
+        let host_key = HostKey::try_from(uri).ok()?;
+
+        if let Some(host) = self.hosts.get(&host_key) {
+            host.get_cached_status(uri)
+        } else {
+            None
+        }
+    }
+
+    /// Cache a result for a URI in the appropriate host's cache
+    pub fn cache_result(&self, uri: &Uri, status: &Status) {
+        if let Ok(host_key) = HostKey::try_from(uri)
+            && let Some(host) = self.hosts.get(&host_key)
+        {
+            host.cache_result(uri, status);
+        }
+        // If host doesn't exist yet, we don't cache
+        // The result will be cached when the host is created and the request is made
+    }
+
+    /// Get cache statistics across all hosts
+    #[must_use]
+    pub fn cache_stats(&self) -> HashMap<String, (usize, f64)> {
+        self.hosts
+            .iter()
+            .map(|entry| {
+                let hostname = entry.key().to_string();
+                let cache_size = entry.value().cache_size();
+                let hit_rate = entry.value().stats().cache_hit_rate();
+                (hostname, (cache_size, hit_rate))
+            })
+            .collect()
+    }
+
+    /// Record a cache hit for the given URI in host statistics
+    ///
+    /// This tracks that a request was served from the persistent disk cache
+    /// rather than going through the rate-limited HTTP request flow.
+    /// This method will create a host instance if one doesn't exist yet.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the host key cannot be parsed from the URI.
+    pub fn record_cache_hit(&self, uri: &crate::Uri) -> Result<()> {
+        let host_key = crate::ratelimit::HostKey::try_from(uri)?;
+
+        // Get or create the host (this ensures statistics tracking even for cache-only requests)
+        let host = self.get_or_create_host(host_key);
+        host.record_persistent_cache_hit();
+        Ok(())
+    }
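These two hooks exist so a caller that fronts the pool with lychee's persistent disk cache can keep per-host statistics honest even when no HTTP request is made. A hedged sketch of such a call site; `lookup_disk_cache` is a hypothetical stand-in for the caller's own cache, and only `record_cache_hit`/`record_cache_miss` come from this file:

```rust
use lychee_lib::ratelimit::HostPool;
use lychee_lib::{CacheStatus, ErrorKind, Uri};

// Hypothetical: whatever persistent cache the caller maintains.
fn lookup_disk_cache(_uri: &Uri) -> Option<CacheStatus> {
    None
}

fn account_for_cache(pool: &HostPool, uri: &Uri) -> Result<(), ErrorKind> {
    if lookup_disk_cache(uri).is_some() {
        // Served from disk: no HTTP request is made, but stats stay accurate.
        pool.record_cache_hit(uri)?;
    } else {
        pool.record_cache_miss(uri)?;
        // ...continue into the rate-limited HTTP request flow...
    }
    Ok(())
}
```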
+
+    /// Record a cache miss for the given URI in host statistics
+    ///
+    /// This tracks that a request could not be served from the persistent disk cache
+    /// and will need to go through the rate-limited HTTP request flow.
+    /// This method will create a Host instance if one doesn't exist yet.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the host key cannot be parsed from the URI.
+    pub fn record_cache_miss(&self, uri: &crate::Uri) -> Result<()> {
+        let host_key = crate::ratelimit::HostKey::try_from(uri)?;
+
+        // Get or create the host (this ensures statistics tracking even for cache-only requests)
+        let host = self.get_or_create_host(host_key);
+        host.record_persistent_cache_miss();
+        Ok(())
+    }
+}
+
+impl Default for HostPool {
+    fn default() -> Self {
+        Self::new(
+            RateLimitConfig::default(),
+            HostConfigs::default(),
+            Client::default(),
+            HashMap::new(),
+        )
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ratelimit::RateLimitConfig;
+
+    use url::Url;
+
+    #[test]
+    fn test_host_pool_creation() {
+        let pool = HostPool::new(
+            RateLimitConfig::default(),
+            HostConfigs::default(),
+            Client::default(),
+            HashMap::new(),
+        );
+
+        assert_eq!(pool.active_host_count(), 0);
+    }
+
+    #[test]
+    fn test_host_pool_default() {
+        let pool = HostPool::default();
+        assert_eq!(pool.active_host_count(), 0);
+    }
+
+    #[tokio::test]
+    async fn test_host_creation_on_demand() {
+        let pool = HostPool::default();
+        let url: Url = "https://example.com/path".parse().unwrap();
+        let host_key = HostKey::try_from(&url).unwrap();
+
+        // No hosts initially
+        assert_eq!(pool.active_host_count(), 0);
+        assert_eq!(pool.host_stats("example.com").total_requests, 0);
+
+        // Create host on demand
+        let host = pool.get_or_create_host(host_key);
+
+        // Now we have one host
+        assert_eq!(pool.active_host_count(), 1);
+        assert_eq!(pool.host_stats("example.com").total_requests, 0);
+        assert_eq!(host.key.as_str(), "example.com");
+    }
+
+    #[tokio::test]
+    async fn test_host_reuse() {
+        let pool = HostPool::default();
+        let url: Url = "https://example.com/path1".parse().unwrap();
+        let host_key1 = HostKey::try_from(&url).unwrap();
+
+        let url: Url = "https://example.com/path2".parse().unwrap();
+        let host_key2 = HostKey::try_from(&url).unwrap();
+
+        // Create host for first request
+        let host1 = pool.get_or_create_host(host_key1);
+        assert_eq!(pool.active_host_count(), 1);
+
+        // Second request to same host should reuse
+        let host2 = pool.get_or_create_host(host_key2);
+        assert_eq!(pool.active_host_count(), 1);
+
+        // Should be the same instance
+        assert!(Arc::ptr_eq(&host1, &host2));
+    }
+
+    #[test]
+    fn test_host_config_management() {
+        let pool = HostPool::default();
+
+        // Initially no host configurations
+        let configs = pool.host_configurations();
+        assert_eq!(configs.len(), 0);
+    }
+
+    #[test]
+    fn test_host_removal() {
+        let pool = HostPool::default();
+
+        // Remove non-existent host
+        assert!(!pool.remove_host("nonexistent.com"));
+
+        // We can't easily test removal of existing hosts without making actual requests
+        // due to the async nature of host creation, but the basic functionality works
+    }
+
+    #[test]
+    fn test_all_host_stats() {
+        let pool = HostPool::default();
+
+        // No hosts initially
+        let stats = pool.all_host_stats();
+        assert!(stats.is_empty());
+
+        // Stats would be populated after actual requests are made to create hosts
+    }
+}
diff --git a/lychee-lib/src/types/error.rs b/lychee-lib/src/types/error.rs
index a41cad8f3a..a351c4dbbf 100644
--- a/lychee-lib/src/types/error.rs
+++ b/lychee-lib/src/types/error.rs
@@ -138,7 +138,7 @@ pub enum ErrorKind {
     #[error("Cannot send/receive message from channel")]
     Channel(#[from] tokio::sync::mpsc::error::SendError),
 
-    /// An URL with an invalid host was found
+    /// A URL without a host was found
     #[error("URL is missing a host")]
     InvalidUrlHost,
 
@@ -335,7 +335,7 @@ impl ErrorKind {
             [name] => format!("An index file ({name}) is required"),
             [init @ .., tail] => format!("An index file ({}, or {}) is required", init.join(", "), tail),
         }.into(),
-            ErrorKind::PreprocessorError{command, reason} => Some(format!("Command '{command}' failed {reason}. Check value of the preprocessor option"))
+            ErrorKind::PreprocessorError{command, reason} => Some(format!("Command '{command}' failed {reason}. Check value of the preprocessor option")),
         }
     }
diff --git a/lychee.example.toml b/lychee.example.toml
index e8989b67b4..967031ae92 100644
--- a/lychee.example.toml
+++ b/lychee.example.toml
@@ -16,6 +16,9 @@ no_progress = false
 # Path to summary output file.
 output = ".config.dummy.report.md"
 
+# Show host statistics
+host_stats = true
+
 # Extract links instead of checking them
 dump = true
 
@@ -200,3 +203,20 @@ archive = "wayback"
 
 # Search and suggest link replacements for all broken links
 suggest = true
+
+############################# Hosts #############################
+
+# Maximum simultaneous requests to the same host
+host_concurrency = 5
+
+# Minimum interval between requests to the same host
+host_request_interval = "50ms"
+
+# Customize hosts
+[hosts."blog.example.com"]
+# Override `host_concurrency` for this host
+concurrency = 5
+# Override `host_request_interval` for this host
+request_interval = "0" # zero disables rate limiting
+# Merge global `header` values with the following `headers` for this host
+headers = { "A" = "B" }
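For reference, the per-host overrides above are resolved against the global defaults through the `effective_*` helpers that `Host::new` calls. Conceptually the merge is just `Option::unwrap_or`; a sketch under that assumption (field names match the serialized `HostConfig` shown in the tests earlier, but this is not the exact lychee-lib source):

```rust
use std::time::Duration;

// Conceptual sketch of the fallback: per-host overrides win,
// otherwise the global default applies.
struct RateLimitConfig {
    host_concurrency: usize,
    host_request_interval: Duration,
}

struct HostConfig {
    concurrency: Option<usize>,
    request_interval: Option<Duration>,
}

impl HostConfig {
    fn effective_concurrency(&self, global: &RateLimitConfig) -> usize {
        self.concurrency.unwrap_or(global.host_concurrency)
    }

    fn effective_request_interval(&self, global: &RateLimitConfig) -> Duration {
        self.request_interval.unwrap_or(global.host_request_interval)
    }
}
```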