From ce4a9ea60ab10c1b6ca5501ff59fd43a80e584f3 Mon Sep 17 00:00:00 2001 From: liushuyu Date: Tue, 11 Jan 2022 13:36:21 -0700 Subject: [PATCH] templates/load_data: add an optional parameter `headers` (#1710) * templates/load_data: add an optional parameter headers ... ... now `load_data` function supports setting extra headers * docs/templates/overview: cover some edge-cases in the explanation * templates/load_data: fix caching logic with headers * docs/templates: change wording for load_data headers explanations --- .../templates/src/global_fns/load_data.rs | 157 +++++++++++++++++- .../documentation/templates/overview.md | 37 +++++ 2 files changed, 191 insertions(+), 3 deletions(-) diff --git a/components/templates/src/global_fns/load_data.rs b/components/templates/src/global_fns/load_data.rs index 0d5526528a..366b3be9f7 100644 --- a/components/templates/src/global_fns/load_data.rs +++ b/components/templates/src/global_fns/load_data.rs @@ -6,7 +6,7 @@ use std::str::FromStr; use std::sync::{Arc, Mutex}; use csv::Reader; -use reqwest::header::{HeaderValue, CONTENT_TYPE}; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE}; use reqwest::{blocking::Client, header}; use tera::{from_value, to_value, Error, Function as TeraFn, Map, Result, Value}; use url::Url; @@ -120,12 +120,14 @@ impl DataSource { method: Method, post_body: &Option, post_content_type: &Option, + headers: &Option>, ) -> u64 { let mut hasher = DefaultHasher::new(); format.hash(&mut hasher); method.hash(&mut hasher); post_body.hash(&mut hasher); post_content_type.hash(&mut hasher); + headers.hash(&mut hasher); self.hash(&mut hasher); hasher.finish() } @@ -162,6 +164,31 @@ fn get_output_format_from_args( } } +fn add_headers_from_args(header_args: Option>) -> Result { + let mut headers = HeaderMap::new(); + if let Some(header_args) = header_args { + for arg in header_args { + let mut splitter = arg.splitn(2, '='); + let key = splitter + .next() + .ok_or_else(|| { + format!("Invalid header argument. Expecting header key, got '{}'", arg) + })? + .to_string(); + let value = splitter.next().ok_or_else(|| { + format!("Invalid header argument. Expecting header value, got '{}'", arg) + })?; + headers.append( + HeaderName::from_str(&key) + .map_err(|e| format!("Invalid header name '{}': {}", key, e))?, + value.parse().map_err(|e| format!("Invalid header value '{}': {}", value, e))?, + ); + } + } + + Ok(headers) +} + /// A Tera function to load data from a file or from a URL /// Currently the supported formats are json, toml, csv, bibtex and plain text #[derive(Debug)] @@ -223,6 +250,11 @@ impl TeraFn for LoadData { }, _ => Method::Get, }; + let headers = optional_arg!( + Vec, + args.get("headers"), + "`load_data`: `headers` needs to be an argument with a list of strings of format =." + ); // If the file doesn't exist, source is None let data_source = match ( @@ -255,8 +287,13 @@ impl TeraFn for LoadData { }; let file_format = get_output_format_from_args(format_arg, &data_source)?; - let cache_key = - data_source.get_cache_key(&file_format, method, &post_body_arg, &post_content_type); + let cache_key = data_source.get_cache_key( + &file_format, + method, + &post_body_arg, + &post_content_type, + &headers, + ); let mut cache = self.result_cache.lock().expect("result cache lock"); if let Some(cached_result) = cache.get(&cache_key) { @@ -271,10 +308,12 @@ impl TeraFn for LoadData { let req = match method { Method::Get => response_client .get(url.as_str()) + .headers(add_headers_from_args(headers)?) .header(header::ACCEPT, file_format.as_accept_header()), Method::Post => { let mut resp = response_client .post(url.as_str()) + .headers(add_headers_from_args(headers)?) .header(header::ACCEPT, file_format.as_accept_header()); if let Some(content_type) = post_content_type { match HeaderValue::from_str(&content_type) { @@ -660,12 +699,14 @@ mod tests { Method::Get, &None, &None, + &Some(vec![]), ); let cache_key_2 = DataSource::Path(get_test_file("test.toml")).get_cache_key( &OutputFormat::Toml, Method::Get, &None, &None, + &Some(vec![]), ); assert_eq!(cache_key, cache_key_2); } @@ -677,12 +718,14 @@ mod tests { Method::Get, &None, &None, + &Some(vec![]), ); let json_cache_key = DataSource::Path(get_test_file("test.json")).get_cache_key( &OutputFormat::Toml, Method::Get, &None, &None, + &Some(vec![]), ); assert_ne!(toml_cache_key, json_cache_key); } @@ -694,16 +737,37 @@ mod tests { Method::Get, &None, &None, + &Some(vec![]), ); let json_cache_key = DataSource::Path(get_test_file("test.toml")).get_cache_key( &OutputFormat::Json, Method::Get, &None, &None, + &Some(vec![]), ); assert_ne!(toml_cache_key, json_cache_key); } + #[test] + fn different_cache_key_per_headers() { + let header1_cache_key = DataSource::Path(get_test_file("test.toml")).get_cache_key( + &OutputFormat::Json, + Method::Get, + &None, + &None, + &Some(vec!["a=b".to_string()]), + ); + let header2_cache_key = DataSource::Path(get_test_file("test.toml")).get_cache_key( + &OutputFormat::Json, + Method::Get, + &None, + &None, + &Some(vec![]), + ); + assert_ne!(header1_cache_key, header2_cache_key); + } + #[test] fn can_load_remote_data() { let _m = mock("GET", "/zpydpkjj67") @@ -1002,4 +1066,91 @@ mod tests { _mjson.assert(); } + + #[test] + fn is_custom_headers_working() { + let _mjson = mock("POST", "/kr1zdgbm4y4") + .with_header("content-type", "application/json") + .match_header("accept", "text/plain") + .match_header("x-custom-header", "some-values") + .with_body("{i_am:'json'}") + .expect(1) + .create(); + let url = format!("{}{}", mockito::server_url(), "/kr1zdgbm4y4"); + + let static_fn = LoadData::new(PathBuf::from("../utils"), None, PathBuf::new()); + let mut args = HashMap::new(); + args.insert("url".to_string(), to_value(&url).unwrap()); + args.insert("format".to_string(), to_value("plain").unwrap()); + args.insert("method".to_string(), to_value("post").unwrap()); + args.insert("content_type".to_string(), to_value("text/plain").unwrap()); + args.insert("body".to_string(), to_value("this is a match").unwrap()); + args.insert("headers".to_string(), to_value(["x-custom-header=some-values"]).unwrap()); + let result = static_fn.call(&args); + assert!(result.is_ok()); + + _mjson.assert(); + } + + #[test] + fn is_custom_headers_working_with_multiple_values() { + let _mjson = mock("POST", "/kr1zdgbm4y5") + .with_status(201) + .with_header("content-type", "application/json") + .match_header("authorization", "Bearer 123") + // Mockito currently does not have a way to validate multiple headers with the same name + // see https://github.com/lipanski/mockito/issues/117 + .match_header("accept", mockito::Matcher::Any) + .match_header("x-custom-header", "some-values") + .match_header("x-other-header", "some-other-values") + .with_body("I am a server that needs authentication and returns HTML with Accept set to JSON") + .expect(1) + .create(); + let url = format!("{}{}", mockito::server_url(), "/kr1zdgbm4y5"); + + let static_fn = LoadData::new(PathBuf::from("../utils"), None, PathBuf::new()); + let mut args = HashMap::new(); + args.insert("url".to_string(), to_value(&url).unwrap()); + args.insert("format".to_string(), to_value("plain").unwrap()); + args.insert("method".to_string(), to_value("post").unwrap()); + args.insert("content_type".to_string(), to_value("text/plain").unwrap()); + args.insert("body".to_string(), to_value("this is a match").unwrap()); + args.insert( + "headers".to_string(), + to_value([ + "x-custom-header=some-values", + "x-other-header=some-other-values", + "accept=application/json", + "authorization=Bearer 123", + ]) + .unwrap(), + ); + let result = static_fn.call(&args); + assert!(result.is_ok()); + + _mjson.assert(); + } + + #[test] + fn fails_when_specifying_invalid_headers() { + let _mjson = mock("GET", "/kr1zdgbm4y6").with_status(204).expect(0).create(); + let static_fn = LoadData::new(PathBuf::from("../utils"), None, PathBuf::new()); + let url = format!("{}{}", mockito::server_url(), "/kr1zdgbm4y6"); + let mut args = HashMap::new(); + args.insert("url".to_string(), to_value(&url).unwrap()); + args.insert("format".to_string(), to_value("plain").unwrap()); + args.insert("headers".to_string(), to_value(["bad-entry::bad-header"]).unwrap()); + let result = static_fn.call(&args); + assert!(result.is_err()); + + let static_fn = LoadData::new(PathBuf::from("../utils"), None, PathBuf::new()); + let mut args = HashMap::new(); + args.insert("url".to_string(), to_value(&url).unwrap()); + args.insert("format".to_string(), to_value("plain").unwrap()); + args.insert("headers".to_string(), to_value(["\n=\r"]).unwrap()); + let result = static_fn.call(&args); + assert!(result.is_err()); + + _mjson.assert(); + } } diff --git a/docs/content/documentation/templates/overview.md b/docs/content/documentation/templates/overview.md index 7ab7013ac7..468ee804c6 100644 --- a/docs/content/documentation/templates/overview.md +++ b/docs/content/documentation/templates/overview.md @@ -409,6 +409,43 @@ This example will make a POST request to the kroki service to generate a SVG. {{postdata|safe}} ``` +If you need additional handling for the HTTP headers, you can use the `headers` parameter. +You might need this parameter when the resource requires authentication or require passing additional +parameters via special headers. +Please note that the headers will be appended to the default headers set by Zola itself instead of replacing them. + +This example will make a POST request to the GitHub markdown rendering service. + +```jinja2 +{% set postdata = load_data(url="https://api.github.com/markdown", format="plain", method="POST", content_type="application/json", headers=["accept=application/vnd.github.v3+json"], body='{"text":"headers support added in #1710, commit before it: b3918f124d13ec1bedad4860c15a060dd3751368","context":"getzola/zola","mode":"gfm"}')%} +{{postdata|safe}} +``` + +The following example shows how to send a GraphQL query to GitHub (requires authentication). +If you want to try this example on your own machine, you need to provide a GitHub PAT (Personal Access Token), +you can acquire the access token at this link: https://github.com/settings/tokens and then set `GITHUB_TOKEN` +environment variable to the access token you have obtained. + +```jinja2 +{% set token = get_env(name="GITHUB_TOKEN") %} +{% set postdata = load_data(url="https://api.github.com/graphql", format="json", method="POST" ,content_type="application/json", headers=["accept=application/vnd.github.v4.idl", "authentication=Bearer " ~ token], body='{"query":"query { viewer { login }}"}')%} +{{postdata|safe}} +``` + +In case you need to specify multiple headers with the same name, you can specify them like this: + +``` +headers=["accept=application/json,text/html"] +``` + +Which is equivalent to two `Accept` headers with `application/json` and `text/html`. + +If it doesn't work, you can instead specify the headers multiple times to achieve a similar effect: + +``` +headers=["accept=application/json", "accept=text/html"] +``` + #### Data caching Data file loading and remote requests are cached in memory during the build, so multiple requests aren't made