diff --git a/examples/collect_links/collect_links.rs b/examples/collect_links/collect_links.rs
index a7266b026e..57edd5fcdb 100644
--- a/examples/collect_links/collect_links.rs
+++ b/examples/collect_links/collect_links.rs
@@ -1,4 +1,3 @@
-use http::HeaderMap;
 use lychee_lib::{Collector, Input, InputSource, Result};
 use reqwest::Url;
 use std::path::PathBuf;
@@ -14,13 +13,11 @@ async fn main() -> Result<()> {
             )),
             file_type_hint: None,
             excluded_paths: None,
-            headers: HeaderMap::new(),
         },
         Input {
             source: InputSource::FsPath(PathBuf::from("fixtures/TEST.md")),
             file_type_hint: None,
             excluded_paths: None,
-            headers: HeaderMap::new(),
         },
     ];
 
diff --git a/lychee-bin/src/main.rs b/lychee-bin/src/main.rs
index b4af97925d..fb9ef5a49c 100644
--- a/lychee-bin/src/main.rs
+++ b/lychee-bin/src/main.rs
@@ -67,12 +67,13 @@ use anyhow::{Context, Error, Result, bail};
 use clap::Parser;
 use commands::CommandParams;
 use formatters::{get_stats_formatter, log::init_logging};
+use http::HeaderMap;
 use log::{error, info, warn};
 
 #[cfg(feature = "native-tls")]
 use openssl_sys as _; // required for vendored-openssl feature
 
-use options::LYCHEE_CONFIG_FILE;
+use options::{HeaderMapExt, LYCHEE_CONFIG_FILE};
 use ring as _; // required for apple silicon
 
 use lychee_lib::BasicAuthExtractor;
@@ -319,6 +320,7 @@ async fn run(opts: &LycheeOptions) -> Result<i32> {
         .skip_hidden(!opts.config.hidden)
         .skip_ignored(!opts.config.no_ignore)
         .include_verbatim(opts.config.include_verbatim)
+        .headers(HeaderMap::from_header_pairs(&opts.config.header)?)
         // File a bug if you rely on this envvar! It's going to go away eventually.
         .use_html5ever(std::env::var("LYCHEE_USE_HTML5EVER").is_ok_and(|x| x == "1"));
 
diff --git a/lychee-bin/src/options.rs b/lychee-bin/src/options.rs
index 93b61a0092..924d0c77e7 100644
--- a/lychee-bin/src/options.rs
+++ b/lychee-bin/src/options.rs
@@ -336,19 +336,10 @@ impl LycheeOptions {
         } else {
             Some(self.config.exclude_path.clone())
         };
-        let headers = HeaderMap::from_header_pairs(&self.config.header)?;
         self.raw_inputs
             .iter()
-            .map(|s| {
-                Input::new(
-                    s,
-                    None,
-                    self.config.glob_ignore_case,
-                    excluded.clone(),
-                    headers.clone(),
-                )
-            })
+            .map(|s| Input::new(s, None, self.config.glob_ignore_case, excluded.clone()))
             .collect::<Result<Vec<_>>>()
             .context("Cannot parse inputs from arguments")
     }
 
diff --git a/lychee-bin/tests/cli.rs b/lychee-bin/tests/cli.rs
index 569edf7831..c39749ec55 100644
--- a/lychee-bin/tests/cli.rs
+++ b/lychee-bin/tests/cli.rs
@@ -24,7 +24,10 @@ mod cli {
     use serde_json::Value;
     use tempfile::NamedTempFile;
     use uuid::Uuid;
-    use wiremock::{Mock, ResponseTemplate, matchers::basic_auth};
+    use wiremock::{
+        Mock, ResponseTemplate,
+        matchers::{basic_auth, method},
+    };
 
     type Result<T> = std::result::Result<T, Box<dyn Error>>;
 
@@ -1673,8 +1676,14 @@
         let password = "password123";
 
         let mock_server = wiremock::MockServer::start().await;
-        Mock::given(basic_auth(username, password))
-            .respond_with(ResponseTemplate::new(200))
+
+        Mock::given(method("GET"))
+            .and(basic_auth(username, password))
+            .respond_with(ResponseTemplate::new(200)) // Authenticated requests are accepted
+            .mount(&mock_server)
+            .await;
+        Mock::given(method("GET"))
+            .respond_with(|_: &_| panic!("Received unauthenticated request"))
             .mount(&mock_server)
             .await;
 
@@ -1690,6 +1699,16 @@
             .stdout(contains("1 Total"))
             .stdout(contains("1 OK"));
 
+        // Websites as direct arguments must also use authentication
+        main_command()
+            .arg(mock_server.uri())
+            .arg("--verbose")
+            .arg("--basic-auth")
+            .arg(format!("{} {username}:{password}", mock_server.uri()))
+            .assert()
+            .success()
+            .stdout(contains("0 Total")); // Mock server returns no body, so there are no URLs to check
+
         Ok(())
     }
 
diff --git a/lychee-lib/src/collector.rs b/lychee-lib/src/collector.rs
index 935220c606..8ce6861f0f 100644
--- a/lychee-lib/src/collector.rs
+++ b/lychee-lib/src/collector.rs
@@ -1,5 +1,6 @@
 use crate::ErrorKind;
 use crate::InputSource;
+use crate::types::resolver::UrlContentResolver;
 use crate::{
     Base, Input, Request, Result, basic_auth::BasicAuthExtractor, extract::Extractor,
     types::FileExtensions, types::uri::raw::RawUri, utils::request,
@@ -9,7 +10,9 @@ use futures::{
     StreamExt,
     stream::{self, Stream},
 };
+use http::HeaderMap;
 use par_stream::ParStreamExt;
+use reqwest::Client;
 use std::path::PathBuf;
 
 /// Collector keeps the state of link collection
@@ -25,9 +28,15 @@ pub struct Collector {
     use_html5ever: bool,
     root_dir: Option<PathBuf>,
     base: Option<Base>,
+    headers: HeaderMap,
+    client: Client,
 }
 
 impl Default for Collector {
+    /// # Panics
+    ///
+    /// We call `Client::new()` which can panic in certain scenarios.
+    /// Use `Collector::new()` to handle `ClientBuilder` errors gracefully.
     fn default() -> Self {
         Collector {
             basic_auth_extractor: None,
@@ -38,6 +47,8 @@ impl Default for Collector {
             skip_ignored: true,
             root_dir: None,
             base: None,
+            headers: HeaderMap::new(),
+            client: Client::new(),
         }
     }
 }
@@ -61,6 +72,10 @@ impl Collector {
             use_html5ever: false,
             skip_hidden: true,
             skip_ignored: true,
+            headers: HeaderMap::new(),
+            client: Client::builder()
+                .build()
+                .map_err(ErrorKind::BuildRequestClient)?,
             root_dir,
             base,
         })
@@ -87,6 +102,20 @@ impl Collector {
         self
     }
 
+    /// Set headers to use when resolving input URLs
+    #[must_use]
+    pub fn headers(mut self, headers: HeaderMap) -> Self {
+        self.headers = headers;
+        self
+    }
+
+    /// Set client to use for checking input URLs
+    #[must_use]
+    pub fn client(mut self, client: Client) -> Self {
+        self.client = client;
+        self
+    }
+
     /// Use `html5ever` to parse HTML instead of `html5gum`.
     #[must_use]
     pub const fn use_html5ever(mut self, yes: bool) -> Self {
@@ -141,17 +170,33 @@ impl Collector {
         let skip_hidden = self.skip_hidden;
         let skip_ignored = self.skip_ignored;
         let global_base = self.base;
+
+        let resolver = UrlContentResolver {
+            basic_auth_extractor: self.basic_auth_extractor.clone(),
+            headers: self.headers.clone(),
+            client: self.client,
+        };
+
         stream::iter(inputs)
             .par_then_unordered(None, move |input| {
                 let default_base = global_base.clone();
                 let extensions = extensions.clone();
+                let resolver = resolver.clone();
+
                 async move {
                     let base = match &input.source {
                         InputSource::RemoteUrl(url) => Base::try_from(url.as_str()).ok(),
                         _ => default_base,
                     };
+
                     input
-                        .get_contents(skip_missing_inputs, skip_hidden, skip_ignored, extensions)
+                        .get_contents(
+                            skip_missing_inputs,
+                            skip_hidden,
+                            skip_ignored,
+                            extensions,
+                            resolver,
+                        )
                         .map(move |content| (content, base.clone()))
                 }
             })
@@ -181,7 +226,7 @@ impl Collector {
 mod tests {
     use std::{collections::HashSet, convert::TryFrom, fs::File, io::Write};
 
-    use http::{HeaderMap, StatusCode};
+    use http::StatusCode;
     use reqwest::Url;
 
     use super::*;
@@ -229,15 +274,15 @@ mod tests {
         // Treat as plaintext file (no extension)
         let file_path = temp_dir.path().join("README");
         let _file = File::create(&file_path).unwrap();
-        let input = Input::new(
-            &file_path.as_path().display().to_string(),
-            None,
-            true,
-            None,
-            HeaderMap::new(),
-        )?;
+        let input = Input::new(&file_path.as_path().display().to_string(), None, true, None)?;
         let contents: Vec<_> = input
-            .get_contents(true, true, true, FileType::default_extensions())
+            .get_contents(
+                true,
+                true,
+                true,
+                FileType::default_extensions(),
+                UrlContentResolver::default(),
+            )
             .collect::<Vec<_>>()
             .await;
 
@@ -248,9 +293,15 @@ mod tests {
 
     #[tokio::test]
     async fn test_url_without_extension_is_html() -> Result<()> {
-        let input = Input::new("https://example.com/", None, true, None, HeaderMap::new())?;
+        let input = Input::new("https://example.com/", None, true, None)?;
         let contents: Vec<_> = input
-            .get_contents(true, true, true, FileType::default_extensions())
+            .get_contents(
+                true,
+                true,
+                true,
+                FileType::default_extensions(),
+                UrlContentResolver::default(),
+            )
             .collect::<Vec<_>>()
             .await;
 
@@ -283,7 +334,6 @@ mod tests {
                 source: InputSource::String(TEST_STRING.to_owned()),
                 file_type_hint: None,
                 excluded_paths: None,
-                headers: HeaderMap::new(),
             },
             Input {
                 source: InputSource::RemoteUrl(Box::new(
@@ -293,13 +343,11 @@ mod tests {
                 )),
                 file_type_hint: None,
                 excluded_paths: None,
-                headers: HeaderMap::new(),
             },
             Input {
                 source: InputSource::FsPath(file_path),
                 file_type_hint: None,
                 excluded_paths: None,
-                headers: HeaderMap::new(),
             },
             Input {
                 source: InputSource::FsGlob {
@@ -308,7 +356,6 @@ mod tests {
                 },
                 file_type_hint: None,
                 excluded_paths: None,
-                headers: HeaderMap::new(),
             },
         ];
 
@@ -337,7 +384,6 @@ mod tests {
             source: InputSource::String("This is [a test](https://endler.dev). This is a relative link test [Relative Link Test](relative_link)".to_string()),
             file_type_hint: Some(FileType::Markdown),
             excluded_paths: None,
-            headers: HeaderMap::new(),
         };
 
         let links = collect(vec![input], None, Some(base)).await.ok().unwrap();
@@ -364,7 +410,6 @@ mod tests {
             ),
             file_type_hint: Some(FileType::Html),
             excluded_paths: None,
-            headers: HeaderMap::new(),
         };
 
         let links = collect(vec![input], None, Some(base)).await.ok().unwrap();
@@ -394,7 +439,6 @@ mod tests {
             ),
             file_type_hint: Some(FileType::Html),
             excluded_paths: None,
-            headers: HeaderMap::new(),
         };
 
         let links = collect(vec![input], None, Some(base)).await.ok().unwrap();
@@ -421,7 +465,6 @@ mod tests {
             ),
             file_type_hint: Some(FileType::Markdown),
             excluded_paths: None,
-            headers: HeaderMap::new(),
         };
 
         let links = collect(vec![input], None, Some(base)).await.ok().unwrap();
@@ -445,7 +488,6 @@ mod tests {
             source: InputSource::String(input),
             file_type_hint: Some(FileType::Html),
             excluded_paths: None,
-            headers: HeaderMap::new(),
         };
 
         let links = collect(vec![input], None, Some(base)).await.ok().unwrap();
@@ -478,7 +520,6 @@ mod tests {
             source: InputSource::RemoteUrl(Box::new(server_uri.clone())),
             file_type_hint: None,
             excluded_paths: None,
-            headers: HeaderMap::new(),
         };
 
         let links = collect(vec![input], None, None).await.ok().unwrap();
@@ -499,7 +540,6 @@ mod tests {
             ),
             file_type_hint: None,
             excluded_paths: None,
-            headers: HeaderMap::new(),
         };
 
         let links = collect(vec![input], None, None).await.ok().unwrap();
@@ -530,7 +570,6 @@ mod tests {
                 )),
                 file_type_hint: Some(FileType::Html),
                 excluded_paths: None,
-                headers: HeaderMap::new(),
             },
             Input {
                 source: InputSource::RemoteUrl(Box::new(
@@ -542,7 +581,6 @@ mod tests {
                 )),
                 file_type_hint: Some(FileType::Html),
                 excluded_paths: None,
-                headers: HeaderMap::new(),
             },
         ];
 
@@ -571,14 +609,13 @@ mod tests {
             source: InputSource::String(
                 r#"
                Index
-                About
-                Another
+                About
+                Another
            "#
                .into(),
             ),
             file_type_hint: Some(FileType::Html),
             excluded_paths: None,
-            headers: HeaderMap::new(),
         };
 
         let links = collect(vec![input], None, Some(base)).await.ok().unwrap();
diff --git a/lychee-lib/src/types/basic_auth/credentials.rs b/lychee-lib/src/types/basic_auth/credentials.rs
index 5435dc6641..82a630870f 100644
--- a/lychee-lib/src/types/basic_auth/credentials.rs
+++ b/lychee-lib/src/types/basic_auth/credentials.rs
@@ -8,7 +8,6 @@ use reqwest::Request;
 use serde::Deserialize;
 use thiserror::Error;
 
-use crate::Status;
 use crate::chain::{ChainResult, Handler};
 
 #[derive(Copy, Clone, Debug, Error, PartialEq)]
@@ -74,15 +73,20 @@ impl BasicAuthCredentials {
     pub fn to_authorization(&self) -> Authorization<Basic> {
         Authorization::basic(&self.username, &self.password)
     }
+
+    /// Append the credentials as headers to a `Request`
+    pub fn append_to_request(&self, request: &mut Request) {
+        request
+            .headers_mut()
+            .append(AUTHORIZATION, self.to_authorization().0.encode());
+    }
 }
 
 #[async_trait]
-impl Handler<Request, Status> for Option<BasicAuthCredentials> {
-    async fn handle(&mut self, mut request: Request) -> ChainResult<Request, Status> {
+impl Handler<Request, Request> for Option<BasicAuthCredentials> {
+    async fn handle(&mut self, mut request: Request) -> ChainResult<Request, Request> {
         if let Some(credentials) = self {
-            request
-                .headers_mut()
-                .append(AUTHORIZATION, credentials.to_authorization().0.encode());
+            credentials.append_to_request(&mut request);
         }
 
         ChainResult::Next(request)
diff --git a/lychee-lib/src/types/input.rs b/lychee-lib/src/types/input.rs
index 7e02b84910..66650855e3 100644
--- a/lychee-lib/src/types/input.rs
+++ b/lychee-lib/src/types/input.rs
@@ -1,9 +1,10 @@
+use super::file::FileExtensions;
+use super::resolver::UrlContentResolver;
 use crate::types::FileType;
 use crate::{ErrorKind, Result, utils};
 use async_stream::try_stream;
 use futures::stream::Stream;
 use glob::glob_with;
-use http::HeaderMap;
 use ignore::WalkBuilder;
 use reqwest::Url;
 use serde::{Deserialize, Serialize};
@@ -13,8 +14,6 @@ use std::fs;
 use std::path::{Path, PathBuf};
 use tokio::io::{AsyncReadExt, stdin};
 
-use super::file::FileExtensions;
-
 const STDIN: &str = "-";
 
 #[derive(Debug)]
@@ -110,8 +109,6 @@ pub struct Input {
     pub file_type_hint: Option<FileType>,
     /// Excluded paths that will be skipped when reading content
     pub excluded_paths: Option<Vec<PathBuf>>,
-    /// Custom headers to be used when fetching remote URLs
-    pub headers: reqwest::header::HeaderMap,
 }
 
 impl Input {
@@ -128,7 +125,6 @@ impl Input {
         file_type_hint: Option<FileType>,
         glob_ignore_case: bool,
         excluded_paths: Option<Vec<PathBuf>>,
-        headers: reqwest::header::HeaderMap,
     ) -> Result<Self> {
         let source = if value == STDIN {
             InputSource::Stdin
@@ -194,7 +190,6 @@ impl Input {
             source,
             file_type_hint,
             excluded_paths,
-            headers,
         })
     }
 
@@ -205,7 +200,7 @@ impl Input {
     /// Returns an error if the input does not exist (i.e. invalid path)
     /// and the input cannot be parsed as a URL.
     pub fn from_value(value: &str) -> Result<Self> {
-        Self::new(value, None, false, None, HeaderMap::new())
+        Self::new(value, None, false, None)
     }
 
     /// Retrieve the contents from the input
@@ -226,11 +221,12 @@ impl Input {
         // If `Input` is a file path, try the given file extensions in order.
         // Stop on the first match.
         file_extensions: FileExtensions,
+        resolver: UrlContentResolver,
     ) -> impl Stream<Item = Result<InputContent>> {
         try_stream! {
             match self.source {
-                InputSource::RemoteUrl(ref url) => {
-                    let content = Self::url_contents(url, &self.headers).await;
+                InputSource::RemoteUrl(url) => {
+                    let content = resolver.url_contents(*url).await;
                     match content {
                         Err(_) if skip_missing => (),
                         Err(e) => Err(e)?,
@@ -328,31 +324,6 @@ impl Input {
         }
     }
 
-    async fn url_contents(url: &Url, headers: &HeaderMap) -> Result<InputContent> {
-        // Assume HTML for default paths
-        let file_type = if url.path().is_empty() || url.path() == "/" {
-            FileType::Html
-        } else {
-            FileType::from(url.as_str())
-        };
-
-        let client = reqwest::Client::new();
-
-        let res = client
-            .get(url.clone())
-            .headers(headers.clone())
-            .send()
-            .await
-            .map_err(ErrorKind::NetworkRequest)?;
-        let input_content = InputContent {
-            source: InputSource::RemoteUrl(Box::new(url.clone())),
-            file_type,
-            content: res.text().await.map_err(ErrorKind::ReadResponseBody)?,
-        };
-
-        Ok(input_content)
-    }
-
     fn glob_contents(
         &self,
         pattern: &str,
@@ -456,8 +427,6 @@ fn is_excluded_path(excluded_paths: &[PathBuf], path: &PathBuf) -> bool {
 
 #[cfg(test)]
 mod tests {
-    use http::HeaderMap;
-
     use super::*;
 
     #[test]
@@ -468,7 +437,7 @@ mod tests {
         assert!(path.exists());
         assert!(path.is_relative());
 
-        let input = Input::new(test_file, None, false, None, HeaderMap::new());
+        let input = Input::new(test_file, None, false, None);
         assert!(input.is_ok());
         assert!(matches!(
             input,
@@ -476,7 +445,6 @@ mod tests {
                 source: InputSource::FsPath(PathBuf { .. }),
                 file_type_hint: None,
                 excluded_paths: None,
-                headers: _,
             })
         ));
     }
 }
diff --git a/lychee-lib/src/types/mod.rs b/lychee-lib/src/types/mod.rs
index ab1c9e7c65..53093fae6d 100644
--- a/lychee-lib/src/types/mod.rs
+++ b/lychee-lib/src/types/mod.rs
@@ -10,6 +10,7 @@ mod file;
 mod input;
 pub(crate) mod mail;
 mod request;
+pub(crate) mod resolver;
 mod response;
 mod status;
 mod status_code;
diff --git a/lychee-lib/src/types/resolver.rs b/lychee-lib/src/types/resolver.rs
new file mode 100644
index 0000000000..550583fcfb
--- /dev/null
+++ b/lychee-lib/src/types/resolver.rs
@@ -0,0 +1,72 @@
+use super::{FileType, InputContent, InputSource};
+use crate::utils::request;
+use crate::{BasicAuthExtractor, ErrorKind, Result, Uri};
+use http::HeaderMap;
+use reqwest::{Client, Request, Url};
+
+/// Structure to fetch remote content.
+#[derive(Debug, Default, Clone)]
+pub struct UrlContentResolver {
+    pub basic_auth_extractor: Option<BasicAuthExtractor>,
+    pub headers: HeaderMap,
+    pub client: reqwest::Client,
+}
+
+impl UrlContentResolver {
+    /// Fetch remote content by URL.
+    ///
+    /// This method is not intended to check if a URL is functional but
+    /// to get a URL's content and process the content.
+    pub async fn url_contents(&self, url: Url) -> Result<InputContent> {
+        // Assume HTML for default paths
+        let file_type = match url.path() {
+            path if path.is_empty() || path == "/" => FileType::Html,
+            _ => FileType::from(url.as_str()),
+        };
+
+        let credentials = request::extract_credentials(
+            self.basic_auth_extractor.as_ref(),
+            &Uri { url: url.clone() },
+        );
+
+        let request = self.build_request(&url, credentials)?;
+        let content = get_request_body_text(&self.client, request).await?;
+
+        let input_content = InputContent {
+            source: InputSource::RemoteUrl(Box::new(url.clone())),
+            file_type,
+            content,
+        };
+
+        Ok(input_content)
+    }
+
+    fn build_request(
+        &self,
+        url: &Url,
+        credentials: Option<BasicAuthCredentials>,
+    ) -> Result<Request> {
+        let mut request = self
+            .client
+            .request(reqwest::Method::GET, url.clone())
+            .build()
+            .map_err(ErrorKind::BuildRequestClient)?;
+
+        request.headers_mut().extend(self.headers.clone());
+        if let Some(credentials) = credentials {
+            credentials.append_to_request(&mut request);
+        }
+
+        Ok(request)
+    }
+}
+
+async fn get_request_body_text(client: &Client, request: Request) -> Result<String> {
+    client
+        .execute(request)
+        .await
+        .map_err(ErrorKind::NetworkRequest)?
+        .text()
+        .await
+        .map_err(ErrorKind::ReadResponseBody)
+}
diff --git a/lychee-lib/src/utils/request.rs b/lychee-lib/src/utils/request.rs
index 779ced469d..6b57b9ee8c 100644
--- a/lychee-lib/src/utils/request.rs
+++ b/lychee-lib/src/utils/request.rs
@@ -14,7 +14,7 @@ use crate::{
 };
 
 /// Extract basic auth credentials for a given URL.
-fn extract_credentials(
+pub(crate) fn extract_credentials(
     extractor: Option<&BasicAuthExtractor>,
     uri: &Uri,
 ) -> Option<BasicAuthCredentials> {
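Usage sketch (not part of the patch above): a minimal illustration of how the reworked API might be driven once per-input headers move onto the Collector. It is modelled on examples/collect_links/collect_links.rs as changed in this diff plus the new `headers` builder method. The URL, the header value, the `Collector::new(None, None)` arity, and the `collect_links` call are assumptions drawn from the surrounding lychee codebase rather than shown in this diff.

use futures::StreamExt;
use http::HeaderMap;
use lychee_lib::{Collector, Input, Result};

#[tokio::main]
async fn main() -> Result<()> {
    // Inputs no longer carry per-input headers; they only describe where content comes from.
    let inputs: Vec<Input> = vec![Input::from_value("https://example.com")?];

    // Custom headers are configured once on the Collector; its internal
    // UrlContentResolver applies them when fetching remote input URLs.
    let mut headers = HeaderMap::new();
    headers.insert("accept-language", "en".parse().unwrap());

    // Assumed constructor signature: Collector::new(root_dir, base) -> Result<Collector>.
    let links: Vec<_> = Collector::new(None, None)?
        .headers(headers)
        // Assumed to yield a stream of Result<Request>, as in the collect_links example.
        .collect_links(inputs)
        .collect()
        .await;

    for link in links {
        println!("{:?}", link?);
    }
    Ok(())
}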