diff --git a/src/cmd/serve.rs b/src/cmd/serve.rs
index d13b945942..27c9b057bb 100644
--- a/src/cmd/serve.rs
+++ b/src/cmd/serve.rs
@@ -92,16 +92,34 @@ fn set_serve_error(msg: &'static str, e: errors::Error) {
}
}
-async fn handle_request(req: Request
, mut root: PathBuf) -> Result> {
+async fn handle_request(
+ req: Request,
+ mut root: PathBuf,
+ base_path: String,
+) -> Result> {
+ let path_str = req.uri().path();
+ if !path_str.starts_with(&base_path) {
+ return Ok(not_found());
+ }
+
+ let trimmed_path = &path_str[base_path.len() - 1..];
+
let original_root = root.clone();
let mut path = RelativePathBuf::new();
// https://zola.discourse.group/t/percent-encoding-for-slugs/736
- let decoded = match percent_encoding::percent_decode_str(req.uri().path()).decode_utf8() {
+ let decoded = match percent_encoding::percent_decode_str(trimmed_path).decode_utf8() {
Ok(d) => d,
Err(_) => return Ok(not_found()),
};
- for c in decoded.split('/') {
+ let decoded_path = if base_path != "/" && decoded.starts_with(&base_path) {
+ // Remove the base_path from the request path before processing
+ decoded[base_path.len()..].to_string()
+ } else {
+ decoded.to_string()
+ };
+
+ for c in decoded_path.split('/') {
path.push(c);
}
@@ -318,6 +336,39 @@ fn rebuild_done_handling(broadcaster: &Sender, res: Result<()>, reload_path: &st
}
}
+fn construct_url(base_url: &str, no_port_append: bool, interface_port: u16) -> String {
+ if base_url == "/" {
+ return String::from("/");
+ }
+
+ let (protocol, stripped_url) = match base_url {
+ url if url.starts_with("http://") => ("http://", &url[7..]),
+ url if url.starts_with("https://") => ("https://", &url[8..]),
+ url => ("http://", url),
+ };
+
+ let (domain, path) = {
+ let parts: Vec<&str> = stripped_url.splitn(2, '/').collect();
+ if parts.len() > 1 {
+ (parts[0], format!("/{}", parts[1]))
+ } else {
+ (parts[0], String::new())
+ }
+ };
+
+ let full_address = if no_port_append {
+ format!("{}{}{}", protocol, domain, path)
+ } else {
+ format!("{}{}:{}{}", protocol, domain, interface_port, path)
+ };
+
+ if full_address.ends_with('/') {
+ full_address
+ } else {
+ format!("{}/", full_address)
+ }
+}
+
#[allow(clippy::too_many_arguments)]
fn create_new_site(
root_dir: &Path,
@@ -330,7 +381,7 @@ fn create_new_site(
include_drafts: bool,
mut no_port_append: bool,
ws_port: Option,
-) -> Result<(Site, SocketAddr)> {
+) -> Result<(Site, SocketAddr, String)> {
SITE_CONTENT.write().unwrap().clear();
let mut site = Site::new(root_dir, config_file)?;
@@ -345,24 +396,10 @@ fn create_new_site(
|u| u.to_string(),
);
- let base_url = if base_url == "/" {
- String::from("/")
- } else {
- let base_address = if no_port_append {
- base_url.to_string()
- } else {
- format!("{}:{}", base_url, interface_port)
- };
-
- if site.config.base_url.ends_with('/') {
- format!("http://{}/", base_address)
- } else {
- format!("http://{}", base_address)
- }
- };
+ let constructed_base_url = construct_url(&base_url, no_port_append, interface_port);
site.enable_serve_mode();
- site.set_base_url(base_url);
+ site.set_base_url(constructed_base_url.clone());
if let Some(output_dir) = output_dir {
if !force && output_dir.exists() {
return Err(Error::msg(format!(
@@ -384,7 +421,7 @@ fn create_new_site(
messages::notify_site_size(&site);
messages::warn_about_ignored_pages(&site);
site.build()?;
- Ok((site, address))
+ Ok((site, address, constructed_base_url))
}
#[allow(clippy::too_many_arguments)]
@@ -403,7 +440,7 @@ pub fn serve(
utc_offset: UtcOffset,
) -> Result<()> {
let start = Instant::now();
- let (mut site, bind_address) = create_new_site(
+ let (mut site, bind_address, constructed_base_url) = create_new_site(
root_dir,
interface,
interface_port,
@@ -415,6 +452,11 @@ pub fn serve(
no_port_append,
None,
)?;
+ let base_path = match constructed_base_url.splitn(4, '/').nth(3) {
+ Some(path) => format!("/{}", path),
+ None => "/".to_string(),
+ };
+
messages::report_elapsed_time(start);
// Stop right there if we can't bind to the address
@@ -479,19 +521,27 @@ pub fn serve(
rt.block_on(async {
let make_service = make_service_fn(move |_| {
let static_root = static_root.clone();
+ let base_path = base_path.clone();
async {
Ok::<_, hyper::Error>(service_fn(move |req| {
- response_error_injector(handle_request(req, static_root.clone()))
+ response_error_injector(handle_request(
+ req,
+ static_root.clone(),
+ base_path.clone(),
+ ))
}))
}
});
let server = Server::bind(&bind_address).serve(make_service);
- println!("Web server is available at http://{}\n", bind_address);
+ println!(
+ "Web server is available at {} (bound to {})\n",
+ &constructed_base_url, &bind_address
+ );
if open {
- if let Err(err) = open::that(format!("http://{}", bind_address)) {
+ if let Err(err) = open::that(format!("{}", &constructed_base_url)) {
eprintln!("Failed to open URL in your browser: {}", err);
}
}
@@ -618,7 +668,7 @@ pub fn serve(
no_port_append,
ws_port,
) {
- Ok((s, _)) => {
+ Ok((s, _, _)) => {
clear_serve_error();
rebuild_done_handling(&broadcaster, Ok(()), "/x.js");
@@ -801,7 +851,7 @@ fn is_folder_empty(dir: &Path) -> bool {
mod tests {
use std::path::{Path, PathBuf};
- use super::{detect_change_kind, is_temp_file, ChangeKind};
+ use super::{construct_url, detect_change_kind, is_temp_file, ChangeKind};
#[test]
fn can_recognize_temp_files() {
@@ -893,4 +943,40 @@ mod tests {
let config_filename = Path::new("config.toml");
assert_eq!(expected, detect_change_kind(pwd, path, config_filename));
}
+
+ #[test]
+ fn test_construct_url_base_url_is_slash() {
+ let result = construct_url("/", false, 8080);
+ assert_eq!(result, "/");
+ }
+
+ #[test]
+ fn test_construct_url_http_protocol() {
+ let result = construct_url("http://example.com", false, 8080);
+ assert_eq!(result, "http://example.com:8080/");
+ }
+
+ #[test]
+ fn test_construct_url_https_protocol() {
+ let result = construct_url("https://example.com", false, 8080);
+ assert_eq!(result, "https://example.com:8080/");
+ }
+
+ #[test]
+ fn test_construct_url_no_protocol() {
+ let result = construct_url("example.com", false, 8080);
+ assert_eq!(result, "http://example.com:8080/");
+ }
+
+ #[test]
+ fn test_construct_url_no_port_append() {
+ let result = construct_url("https://example.com", true, 8080);
+ assert_eq!(result, "https://example.com/");
+ }
+
+ #[test]
+ fn test_construct_url_trailing_slash() {
+ let result = construct_url("http://example.com/", false, 8080);
+ assert_eq!(result, "http://example.com:8080/");
+ }
}