Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix --base-url improper path and protocol handling using zola serve #2311

Merged
merged 2 commits into from
Mar 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 113 additions & 27 deletions src/cmd/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,34 @@ fn set_serve_error(msg: &'static str, e: errors::Error) {
}
}

async fn handle_request(req: Request<Body>, mut root: PathBuf) -> Result<Response<Body>> {
async fn handle_request(
jamwil marked this conversation as resolved.
Show resolved Hide resolved
req: Request<Body>,
mut root: PathBuf,
base_path: String,
) -> Result<Response<Body>> {
let path_str = req.uri().path();
if !path_str.starts_with(&base_path) {
jamwil marked this conversation as resolved.
Show resolved Hide resolved
return Ok(not_found());
}

let trimmed_path = &path_str[base_path.len() - 1..];

let original_root = root.clone();
let mut path = RelativePathBuf::new();
// https://zola.discourse.group/t/percent-encoding-for-slugs/736
let decoded = match percent_encoding::percent_decode_str(req.uri().path()).decode_utf8() {
let decoded = match percent_encoding::percent_decode_str(trimmed_path).decode_utf8() {
Ok(d) => d,
Err(_) => return Ok(not_found()),
};

for c in decoded.split('/') {
let decoded_path = if base_path != "/" && decoded.starts_with(&base_path) {
// Remove the base_path from the request path before processing
decoded[base_path.len()..].to_string()
} else {
decoded.to_string()
};

for c in decoded_path.split('/') {
path.push(c);
}

Expand Down Expand Up @@ -318,6 +336,39 @@ fn rebuild_done_handling(broadcaster: &Sender, res: Result<()>, reload_path: &st
}
}

fn construct_url(base_url: &str, no_port_append: bool, interface_port: u16) -> String {
jamwil marked this conversation as resolved.
Show resolved Hide resolved
if base_url == "/" {
return String::from("/");
}

let (protocol, stripped_url) = match base_url {
url if url.starts_with("http://") => ("http://", &url[7..]),
url if url.starts_with("https://") => ("https://", &url[8..]),
url => ("http://", url),
};

let (domain, path) = {
let parts: Vec<&str> = stripped_url.splitn(2, '/').collect();
if parts.len() > 1 {
(parts[0], format!("/{}", parts[1]))
} else {
(parts[0], String::new())
}
};

let full_address = if no_port_append {
format!("{}{}{}", protocol, domain, path)
} else {
format!("{}{}:{}{}", protocol, domain, interface_port, path)
};

if full_address.ends_with('/') {
full_address
} else {
format!("{}/", full_address)
}
}

#[allow(clippy::too_many_arguments)]
fn create_new_site(
root_dir: &Path,
Expand All @@ -330,7 +381,7 @@ fn create_new_site(
include_drafts: bool,
mut no_port_append: bool,
ws_port: Option<u16>,
) -> Result<(Site, SocketAddr)> {
) -> Result<(Site, SocketAddr, String)> {
SITE_CONTENT.write().unwrap().clear();

let mut site = Site::new(root_dir, config_file)?;
Expand All @@ -345,24 +396,10 @@ fn create_new_site(
|u| u.to_string(),
);

let base_url = if base_url == "/" {
String::from("/")
} else {
let base_address = if no_port_append {
base_url.to_string()
} else {
format!("{}:{}", base_url, interface_port)
};

if site.config.base_url.ends_with('/') {
format!("http://{}/", base_address)
} else {
format!("http://{}", base_address)
}
};
let constructed_base_url = construct_url(&base_url, no_port_append, interface_port);

site.enable_serve_mode();
site.set_base_url(base_url);
site.set_base_url(constructed_base_url.clone());
if let Some(output_dir) = output_dir {
if !force && output_dir.exists() {
return Err(Error::msg(format!(
Expand All @@ -384,7 +421,7 @@ fn create_new_site(
messages::notify_site_size(&site);
messages::warn_about_ignored_pages(&site);
site.build()?;
Ok((site, address))
Ok((site, address, constructed_base_url))
}

#[allow(clippy::too_many_arguments)]
Expand All @@ -403,7 +440,7 @@ pub fn serve(
utc_offset: UtcOffset,
) -> Result<()> {
let start = Instant::now();
let (mut site, bind_address) = create_new_site(
let (mut site, bind_address, constructed_base_url) = create_new_site(
root_dir,
interface,
interface_port,
Expand All @@ -415,6 +452,11 @@ pub fn serve(
no_port_append,
None,
)?;
let base_path = match constructed_base_url.splitn(4, '/').nth(3) {
Some(path) => format!("/{}", path),
None => "/".to_string(),
};

messages::report_elapsed_time(start);

// Stop right there if we can't bind to the address
Expand Down Expand Up @@ -479,19 +521,27 @@ pub fn serve(
rt.block_on(async {
let make_service = make_service_fn(move |_| {
let static_root = static_root.clone();
let base_path = base_path.clone();

async {
Ok::<_, hyper::Error>(service_fn(move |req| {
response_error_injector(handle_request(req, static_root.clone()))
response_error_injector(handle_request(
req,
static_root.clone(),
base_path.clone(),
))
}))
}
});

let server = Server::bind(&bind_address).serve(make_service);

println!("Web server is available at http://{}\n", bind_address);
println!(
"Web server is available at {} (bound to {})\n",
&constructed_base_url, &bind_address
);
if open {
if let Err(err) = open::that(format!("http://{}", bind_address)) {
if let Err(err) = open::that(format!("{}", &constructed_base_url)) {
eprintln!("Failed to open URL in your browser: {}", err);
}
}
Expand Down Expand Up @@ -618,7 +668,7 @@ pub fn serve(
no_port_append,
ws_port,
) {
Ok((s, _)) => {
Ok((s, _, _)) => {
clear_serve_error();
rebuild_done_handling(&broadcaster, Ok(()), "/x.js");

Expand Down Expand Up @@ -801,7 +851,7 @@ fn is_folder_empty(dir: &Path) -> bool {
mod tests {
use std::path::{Path, PathBuf};

use super::{detect_change_kind, is_temp_file, ChangeKind};
use super::{construct_url, detect_change_kind, is_temp_file, ChangeKind};

#[test]
fn can_recognize_temp_files() {
Expand Down Expand Up @@ -893,4 +943,40 @@ mod tests {
let config_filename = Path::new("config.toml");
assert_eq!(expected, detect_change_kind(pwd, path, config_filename));
}

#[test]
fn test_construct_url_base_url_is_slash() {
let result = construct_url("/", false, 8080);
assert_eq!(result, "/");
}

#[test]
fn test_construct_url_http_protocol() {
let result = construct_url("http://example.com", false, 8080);
assert_eq!(result, "http://example.com:8080/");
}

#[test]
fn test_construct_url_https_protocol() {
let result = construct_url("https://example.com", false, 8080);
assert_eq!(result, "https://example.com:8080/");
}

#[test]
fn test_construct_url_no_protocol() {
let result = construct_url("example.com", false, 8080);
assert_eq!(result, "http://example.com:8080/");
}

#[test]
fn test_construct_url_no_port_append() {
let result = construct_url("https://example.com", true, 8080);
assert_eq!(result, "https://example.com/");
}

#[test]
fn test_construct_url_trailing_slash() {
let result = construct_url("http://example.com/", false, 8080);
assert_eq!(result, "http://example.com:8080/");
}
}