Skip to content
3 changes: 0 additions & 3 deletions examples/collect_links/collect_links.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use http::HeaderMap;
use lychee_lib::{Collector, Input, InputSource, Result};
use reqwest::Url;
use std::path::PathBuf;
Expand All @@ -14,13 +13,11 @@ async fn main() -> Result<()> {
)),
file_type_hint: None,
excluded_paths: None,
headers: HeaderMap::new(),
},
Input {
source: InputSource::FsPath(PathBuf::from("fixtures/TEST.md")),
file_type_hint: None,
excluded_paths: None,
headers: HeaderMap::new(),
},
];

Expand Down
4 changes: 3 additions & 1 deletion lychee-bin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,13 @@ use anyhow::{Context, Error, Result, bail};
use clap::Parser;
use commands::CommandParams;
use formatters::{get_stats_formatter, log::init_logging};
use http::HeaderMap;
use log::{error, info, warn};

#[cfg(feature = "native-tls")]
use openssl_sys as _; // required for vendored-openssl feature

use options::LYCHEE_CONFIG_FILE;
use options::{HeaderMapExt, LYCHEE_CONFIG_FILE};
use ring as _; // required for apple silicon

use lychee_lib::BasicAuthExtractor;
Expand Down Expand Up @@ -319,6 +320,7 @@ async fn run(opts: &LycheeOptions) -> Result<i32> {
.skip_hidden(!opts.config.hidden)
.skip_ignored(!opts.config.no_ignore)
.include_verbatim(opts.config.include_verbatim)
.headers(HeaderMap::from_header_pairs(&opts.config.header)?)
// File a bug if you rely on this envvar! It's going to go away eventually.
.use_html5ever(std::env::var("LYCHEE_USE_HTML5EVER").is_ok_and(|x| x == "1"));

Expand Down
11 changes: 1 addition & 10 deletions lychee-bin/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,19 +336,10 @@ impl LycheeOptions {
} else {
Some(self.config.exclude_path.clone())
};
let headers = HeaderMap::from_header_pairs(&self.config.header)?;

self.raw_inputs
.iter()
.map(|s| {
Input::new(
s,
None,
self.config.glob_ignore_case,
excluded.clone(),
headers.clone(),
)
})
.map(|s| Input::new(s, None, self.config.glob_ignore_case, excluded.clone()))
.collect::<Result<_, _>>()
.context("Cannot parse inputs from arguments")
}
Expand Down
25 changes: 22 additions & 3 deletions lychee-bin/tests/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ mod cli {
use serde_json::Value;
use tempfile::NamedTempFile;
use uuid::Uuid;
use wiremock::{Mock, ResponseTemplate, matchers::basic_auth};
use wiremock::{
Mock, ResponseTemplate,
matchers::{basic_auth, method},
};

type Result<T> = std::result::Result<T, Box<dyn Error>>;

Expand Down Expand Up @@ -1673,8 +1676,14 @@ mod cli {
let password = "password123";

let mock_server = wiremock::MockServer::start().await;
Mock::given(basic_auth(username, password))
.respond_with(ResponseTemplate::new(200))

Mock::given(method("GET"))
.and(basic_auth(username, password))
.respond_with(ResponseTemplate::new(200)) // Authenticated requests are accepted
.mount(&mock_server)
.await;
Mock::given(method("GET"))
.respond_with(|_: &_| panic!("Received unauthenticated request"))
.mount(&mock_server)
.await;

Expand All @@ -1690,6 +1699,16 @@ mod cli {
.stdout(contains("1 Total"))
.stdout(contains("1 OK"));

// Websites as direct arguments must also use authentication
main_command()
.arg(mock_server.uri())
.arg("--verbose")
.arg("--basic-auth")
.arg(format!("{} {username}:{password}", mock_server.uri()))
.assert()
.success()
.stdout(contains("0 Total")); // Mock server returns no body, so there are no URLs to check

Ok(())
}

Expand Down
93 changes: 65 additions & 28 deletions lychee-lib/src/collector.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::ErrorKind;
use crate::InputSource;
use crate::types::resolver::UrlContentResolver;
use crate::{
Base, Input, Request, Result, basic_auth::BasicAuthExtractor, extract::Extractor,
types::FileExtensions, types::uri::raw::RawUri, utils::request,
Expand All @@ -9,7 +10,9 @@ use futures::{
StreamExt,
stream::{self, Stream},
};
use http::HeaderMap;
use par_stream::ParStreamExt;
use reqwest::Client;
use std::path::PathBuf;

/// Collector keeps the state of link collection
Expand All @@ -25,9 +28,15 @@ pub struct Collector {
use_html5ever: bool,
root_dir: Option<PathBuf>,
base: Option<Base>,
headers: HeaderMap,
client: Client,
}

impl Default for Collector {
/// # Panics
///
/// We call `Client::new()` which can panic in certain scenarios.
/// Use `Collector::new()` to handle `ClientBuilder` errors gracefully.
fn default() -> Self {
Collector {
basic_auth_extractor: None,
Expand All @@ -38,6 +47,8 @@ impl Default for Collector {
skip_ignored: true,
root_dir: None,
base: None,
headers: HeaderMap::new(),
client: Client::new(),
}
}
}
Expand All @@ -61,6 +72,10 @@ impl Collector {
use_html5ever: false,
skip_hidden: true,
skip_ignored: true,
headers: HeaderMap::new(),
client: Client::builder()
.build()
.map_err(ErrorKind::BuildRequestClient)?,
root_dir,
base,
})
Expand All @@ -87,6 +102,20 @@ impl Collector {
self
}

/// Set headers to use when resolving input URLs
#[must_use]
pub fn headers(mut self, headers: HeaderMap) -> Self {
self.headers = headers;
self
}

/// Set client to use for checking input URLs
#[must_use]
pub fn client(mut self, client: Client) -> Self {
self.client = client;
self
}

/// Use `html5ever` to parse HTML instead of `html5gum`.
#[must_use]
pub const fn use_html5ever(mut self, yes: bool) -> Self {
Expand Down Expand Up @@ -141,17 +170,33 @@ impl Collector {
let skip_hidden = self.skip_hidden;
let skip_ignored = self.skip_ignored;
let global_base = self.base;

let resolver = UrlContentResolver {
basic_auth_extractor: self.basic_auth_extractor.clone(),
headers: self.headers.clone(),
client: self.client,
};

stream::iter(inputs)
.par_then_unordered(None, move |input| {
let default_base = global_base.clone();
let extensions = extensions.clone();
let resolver = resolver.clone();

async move {
let base = match &input.source {
InputSource::RemoteUrl(url) => Base::try_from(url.as_str()).ok(),
_ => default_base,
};

input
.get_contents(skip_missing_inputs, skip_hidden, skip_ignored, extensions)
.get_contents(
skip_missing_inputs,
skip_hidden,
skip_ignored,
extensions,
resolver,
)
.map(move |content| (content, base.clone()))
}
})
Expand Down Expand Up @@ -181,7 +226,7 @@ impl Collector {
mod tests {
use std::{collections::HashSet, convert::TryFrom, fs::File, io::Write};

use http::{HeaderMap, StatusCode};
use http::StatusCode;
use reqwest::Url;

use super::*;
Expand Down Expand Up @@ -229,15 +274,15 @@ mod tests {
// Treat as plaintext file (no extension)
let file_path = temp_dir.path().join("README");
let _file = File::create(&file_path).unwrap();
let input = Input::new(
&file_path.as_path().display().to_string(),
None,
true,
None,
HeaderMap::new(),
)?;
let input = Input::new(&file_path.as_path().display().to_string(), None, true, None)?;
let contents: Vec<_> = input
.get_contents(true, true, true, FileType::default_extensions())
.get_contents(
true,
true,
true,
FileType::default_extensions(),
UrlContentResolver::default(),
)
.collect::<Vec<_>>()
.await;

Expand All @@ -248,9 +293,15 @@ mod tests {

#[tokio::test]
async fn test_url_without_extension_is_html() -> Result<()> {
let input = Input::new("https://example.com/", None, true, None, HeaderMap::new())?;
let input = Input::new("https://example.com/", None, true, None)?;
let contents: Vec<_> = input
.get_contents(true, true, true, FileType::default_extensions())
.get_contents(
true,
true,
true,
FileType::default_extensions(),
UrlContentResolver::default(),
)
.collect::<Vec<_>>()
.await;

Expand Down Expand Up @@ -283,7 +334,6 @@ mod tests {
source: InputSource::String(TEST_STRING.to_owned()),
file_type_hint: None,
excluded_paths: None,
headers: HeaderMap::new(),
},
Input {
source: InputSource::RemoteUrl(Box::new(
Expand All @@ -293,13 +343,11 @@ mod tests {
)),
file_type_hint: None,
excluded_paths: None,
headers: HeaderMap::new(),
},
Input {
source: InputSource::FsPath(file_path),
file_type_hint: None,
excluded_paths: None,
headers: HeaderMap::new(),
},
Input {
source: InputSource::FsGlob {
Expand All @@ -308,7 +356,6 @@ mod tests {
},
file_type_hint: None,
excluded_paths: None,
headers: HeaderMap::new(),
},
];

Expand Down Expand Up @@ -337,7 +384,6 @@ mod tests {
source: InputSource::String("This is [a test](https://endler.dev). This is a relative link test [Relative Link Test](relative_link)".to_string()),
file_type_hint: Some(FileType::Markdown),
excluded_paths: None,
headers: HeaderMap::new(),
};
let links = collect(vec![input], None, Some(base)).await.ok().unwrap();

Expand All @@ -364,7 +410,6 @@ mod tests {
),
file_type_hint: Some(FileType::Html),
excluded_paths: None,
headers: HeaderMap::new(),
};
let links = collect(vec![input], None, Some(base)).await.ok().unwrap();

Expand Down Expand Up @@ -394,7 +439,6 @@ mod tests {
),
file_type_hint: Some(FileType::Html),
excluded_paths: None,
headers: HeaderMap::new(),
};
let links = collect(vec![input], None, Some(base)).await.ok().unwrap();

Expand All @@ -421,7 +465,6 @@ mod tests {
),
file_type_hint: Some(FileType::Markdown),
excluded_paths: None,
headers: HeaderMap::new(),
};

let links = collect(vec![input], None, Some(base)).await.ok().unwrap();
Expand All @@ -445,7 +488,6 @@ mod tests {
source: InputSource::String(input),
file_type_hint: Some(FileType::Html),
excluded_paths: None,
headers: HeaderMap::new(),
};
let links = collect(vec![input], None, Some(base)).await.ok().unwrap();

Expand Down Expand Up @@ -478,7 +520,6 @@ mod tests {
source: InputSource::RemoteUrl(Box::new(server_uri.clone())),
file_type_hint: None,
excluded_paths: None,
headers: HeaderMap::new(),
};

let links = collect(vec![input], None, None).await.ok().unwrap();
Expand All @@ -499,7 +540,6 @@ mod tests {
),
file_type_hint: None,
excluded_paths: None,
headers: HeaderMap::new(),
};
let links = collect(vec![input], None, None).await.ok().unwrap();

Expand Down Expand Up @@ -530,7 +570,6 @@ mod tests {
)),
file_type_hint: Some(FileType::Html),
excluded_paths: None,
headers: HeaderMap::new(),
},
Input {
source: InputSource::RemoteUrl(Box::new(
Expand All @@ -542,7 +581,6 @@ mod tests {
)),
file_type_hint: Some(FileType::Html),
excluded_paths: None,
headers: HeaderMap::new(),
},
];

Expand Down Expand Up @@ -571,14 +609,13 @@ mod tests {
source: InputSource::String(
r#"
<a href="index.html">Index</a>
<a href="about.html">About</a>
<a href="/another.html">Another</a>
<a href="about.html">About</a>
<a href="/another.html">Another</a>
"#
.into(),
),
file_type_hint: Some(FileType::Html),
excluded_paths: None,
headers: HeaderMap::new(),
};

let links = collect(vec![input], None, Some(base)).await.ok().unwrap();
Expand Down
Loading