From 6fc3caa941d5f1c34dc9b8d7b26a70a218362ea8 Mon Sep 17 00:00:00 2001 From: James Williams Date: Fri, 1 Mar 2024 02:17:45 -0800 Subject: [PATCH] Fix --base-url improper path and protocol handling using `zola serve` (#2311) * Fix --base-url improper path and protocol handling. * Fix formatting. --- src/cmd/serve.rs | 140 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 113 insertions(+), 27 deletions(-) diff --git a/src/cmd/serve.rs b/src/cmd/serve.rs index d13b945942..27c9b057bb 100644 --- a/src/cmd/serve.rs +++ b/src/cmd/serve.rs @@ -92,16 +92,34 @@ fn set_serve_error(msg: &'static str, e: errors::Error) { } } -async fn handle_request(req: Request, mut root: PathBuf) -> Result> { +async fn handle_request( + req: Request, + mut root: PathBuf, + base_path: String, +) -> Result> { + let path_str = req.uri().path(); + if !path_str.starts_with(&base_path) { + return Ok(not_found()); + } + + let trimmed_path = &path_str[base_path.len() - 1..]; + let original_root = root.clone(); let mut path = RelativePathBuf::new(); // https://zola.discourse.group/t/percent-encoding-for-slugs/736 - let decoded = match percent_encoding::percent_decode_str(req.uri().path()).decode_utf8() { + let decoded = match percent_encoding::percent_decode_str(trimmed_path).decode_utf8() { Ok(d) => d, Err(_) => return Ok(not_found()), }; - for c in decoded.split('/') { + let decoded_path = if base_path != "/" && decoded.starts_with(&base_path) { + // Remove the base_path from the request path before processing + decoded[base_path.len()..].to_string() + } else { + decoded.to_string() + }; + + for c in decoded_path.split('/') { path.push(c); } @@ -318,6 +336,39 @@ fn rebuild_done_handling(broadcaster: &Sender, res: Result<()>, reload_path: &st } } +fn construct_url(base_url: &str, no_port_append: bool, interface_port: u16) -> String { + if base_url == "/" { + return String::from("/"); + } + + let (protocol, stripped_url) = match base_url { + url if url.starts_with("http://") => ("http://", &url[7..]), + url if url.starts_with("https://") => ("https://", &url[8..]), + url => ("http://", url), + }; + + let (domain, path) = { + let parts: Vec<&str> = stripped_url.splitn(2, '/').collect(); + if parts.len() > 1 { + (parts[0], format!("/{}", parts[1])) + } else { + (parts[0], String::new()) + } + }; + + let full_address = if no_port_append { + format!("{}{}{}", protocol, domain, path) + } else { + format!("{}{}:{}{}", protocol, domain, interface_port, path) + }; + + if full_address.ends_with('/') { + full_address + } else { + format!("{}/", full_address) + } +} + #[allow(clippy::too_many_arguments)] fn create_new_site( root_dir: &Path, @@ -330,7 +381,7 @@ fn create_new_site( include_drafts: bool, mut no_port_append: bool, ws_port: Option, -) -> Result<(Site, SocketAddr)> { +) -> Result<(Site, SocketAddr, String)> { SITE_CONTENT.write().unwrap().clear(); let mut site = Site::new(root_dir, config_file)?; @@ -345,24 +396,10 @@ fn create_new_site( |u| u.to_string(), ); - let base_url = if base_url == "/" { - String::from("/") - } else { - let base_address = if no_port_append { - base_url.to_string() - } else { - format!("{}:{}", base_url, interface_port) - }; - - if site.config.base_url.ends_with('/') { - format!("http://{}/", base_address) - } else { - format!("http://{}", base_address) - } - }; + let constructed_base_url = construct_url(&base_url, no_port_append, interface_port); site.enable_serve_mode(); - site.set_base_url(base_url); + site.set_base_url(constructed_base_url.clone()); if let Some(output_dir) = output_dir { if !force && output_dir.exists() { return Err(Error::msg(format!( @@ -384,7 +421,7 @@ fn create_new_site( messages::notify_site_size(&site); messages::warn_about_ignored_pages(&site); site.build()?; - Ok((site, address)) + Ok((site, address, constructed_base_url)) } #[allow(clippy::too_many_arguments)] @@ -403,7 +440,7 @@ pub fn serve( utc_offset: UtcOffset, ) -> Result<()> { let start = Instant::now(); - let (mut site, bind_address) = create_new_site( + let (mut site, bind_address, constructed_base_url) = create_new_site( root_dir, interface, interface_port, @@ -415,6 +452,11 @@ pub fn serve( no_port_append, None, )?; + let base_path = match constructed_base_url.splitn(4, '/').nth(3) { + Some(path) => format!("/{}", path), + None => "/".to_string(), + }; + messages::report_elapsed_time(start); // Stop right there if we can't bind to the address @@ -479,19 +521,27 @@ pub fn serve( rt.block_on(async { let make_service = make_service_fn(move |_| { let static_root = static_root.clone(); + let base_path = base_path.clone(); async { Ok::<_, hyper::Error>(service_fn(move |req| { - response_error_injector(handle_request(req, static_root.clone())) + response_error_injector(handle_request( + req, + static_root.clone(), + base_path.clone(), + )) })) } }); let server = Server::bind(&bind_address).serve(make_service); - println!("Web server is available at http://{}\n", bind_address); + println!( + "Web server is available at {} (bound to {})\n", + &constructed_base_url, &bind_address + ); if open { - if let Err(err) = open::that(format!("http://{}", bind_address)) { + if let Err(err) = open::that(format!("{}", &constructed_base_url)) { eprintln!("Failed to open URL in your browser: {}", err); } } @@ -618,7 +668,7 @@ pub fn serve( no_port_append, ws_port, ) { - Ok((s, _)) => { + Ok((s, _, _)) => { clear_serve_error(); rebuild_done_handling(&broadcaster, Ok(()), "/x.js"); @@ -801,7 +851,7 @@ fn is_folder_empty(dir: &Path) -> bool { mod tests { use std::path::{Path, PathBuf}; - use super::{detect_change_kind, is_temp_file, ChangeKind}; + use super::{construct_url, detect_change_kind, is_temp_file, ChangeKind}; #[test] fn can_recognize_temp_files() { @@ -893,4 +943,40 @@ mod tests { let config_filename = Path::new("config.toml"); assert_eq!(expected, detect_change_kind(pwd, path, config_filename)); } + + #[test] + fn test_construct_url_base_url_is_slash() { + let result = construct_url("/", false, 8080); + assert_eq!(result, "/"); + } + + #[test] + fn test_construct_url_http_protocol() { + let result = construct_url("http://example.com", false, 8080); + assert_eq!(result, "http://example.com:8080/"); + } + + #[test] + fn test_construct_url_https_protocol() { + let result = construct_url("https://example.com", false, 8080); + assert_eq!(result, "https://example.com:8080/"); + } + + #[test] + fn test_construct_url_no_protocol() { + let result = construct_url("example.com", false, 8080); + assert_eq!(result, "http://example.com:8080/"); + } + + #[test] + fn test_construct_url_no_port_append() { + let result = construct_url("https://example.com", true, 8080); + assert_eq!(result, "https://example.com/"); + } + + #[test] + fn test_construct_url_trailing_slash() { + let result = construct_url("http://example.com/", false, 8080); + assert_eq!(result, "http://example.com:8080/"); + } }