Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 125 additions & 7 deletions crates/goose-cli/src/recipes/template_recipe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,92 @@ use crate::recipes::recipe::BUILT_IN_RECIPE_DIR_PARAM;

const CURRENT_TEMPLATE_NAME: &str = "current_template";

fn preprocess_template_variables(content: &str) -> Result<String> {
let all_template_variables = extract_template_variables(content);
let complex_template_variables = filter_complex_variables(&all_template_variables);
let unparsable_template_variables = filter_unparseable_variables(&complex_template_variables)?;
replace_unparseable_vars_with_raw(content, &unparsable_template_variables)
}

fn extract_template_variables(content: &str) -> Vec<String> {
let template_var_re = Regex::new(r"\{\{(.*?)\}\}").unwrap();
template_var_re
.captures_iter(content)
.map(|cap| cap[1].to_string())
.collect()
}

fn filter_complex_variables(template_variables: &[String]) -> Vec<String> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a little comment about what complex variables are? Is space before and after the name but not inbetween?

let valid_var_re = Regex::new(r"^\s*[a-zA-Z_][a-zA-Z0-9_]*\s*$").unwrap();
template_variables
.iter()
.filter(|var| !valid_var_re.is_match(var))
.cloned()
.collect()
}

fn filter_unparseable_variables(template_variables: &[String]) -> Result<Vec<String>> {
let mut vars_to_convert = Vec::new();

for var in template_variables {
// Create individual environment for each validation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think we need this comment (code self-explains)

let mut env = Environment::new();
env.set_undefined_behavior(UndefinedBehavior::Lenient);

let test_template = format!(
"{open}{content}{close}",
open = "{{",
content = var,
close = "}}"
);
if env.template_from_str(&test_template).is_err() {
vars_to_convert.push(var.clone());
}
}

Ok(vars_to_convert)
}

fn replace_unparseable_vars_with_raw(
content: &str,
unparsable_template_variables: &[String],
) -> Result<String> {
let mut result = content.to_string();

for var in unparsable_template_variables {
let pattern = format!(
"{open}{content}{close}",
open = "{{",
content = var,
close = "}}"
);
let replacement = format!(
"{{% raw %}}{open}{content}{close}{{% endraw %}}",
open = "{{",
close = "}}",
content = var
);
result = result.replace(&pattern, &replacement);
}

Ok(result)
}

pub fn render_recipe_content_with_params(
content: &str,
params: &HashMap<String, String>,
) -> Result<String> {
// Pre-process content to replace empty double quotes with single quotes
// This prevents MiniJinja from escaping "" to "\"\"" which would break YAML parsing
let re = Regex::new(r#":\s*"""#).unwrap();
let processed_content = re.replace_all(content, ": ''");
let content_with_empty_quotes_replaced = re.replace_all(content, ": ''");

// Pre-process template variables to convert invalid variable names to raw content
let content_with_safe_variables =
preprocess_template_variables(&content_with_empty_quotes_replaced)?;

let env = add_template_in_env(
&processed_content,
&content_with_safe_variables,
params.get(BUILT_IN_RECIPE_DIR_PARAM).unwrap().clone(),
UndefinedBehavior::Strict,
)?;
Expand All @@ -37,9 +112,11 @@ pub fn render_recipe_silent_when_variables_are_provided(
content: &str,
params: &HashMap<String, String>,
) -> Result<String> {
let preprocessed_content = preprocess_template_variables(content)?;

let mut env = minijinja::Environment::new();
env.set_undefined_behavior(UndefinedBehavior::Lenient);
let template = env.template_from_str(content)?;
let template = env.template_from_str(&preprocessed_content)?;
let rendered_content = template.render(params)?;
Ok(rendered_content)
}
Expand Down Expand Up @@ -87,8 +164,14 @@ pub fn parse_recipe_content(
content: &str,
recipe_dir: String,
) -> Result<(Recipe, HashSet<String>)> {
let (env, template_variables) =
get_env_with_template_variables(content, recipe_dir, UndefinedBehavior::Lenient)?;
// Pre-process template variables to handle invalid variable names
let preprocessed_content = preprocess_template_variables(content)?;

let (env, template_variables) = get_env_with_template_variables(
&preprocessed_content,
recipe_dir,
UndefinedBehavior::Lenient,
)?;
let template = env.get_template(CURRENT_TEMPLATE_NAME).unwrap();
let rendered_content = template
.render(())
Expand All @@ -104,8 +187,14 @@ pub fn render_recipe_for_preview(
recipe_dir: String,
params: &HashMap<String, String>,
) -> Result<Recipe> {
let (env, template_variables) =
get_env_with_template_variables(content, recipe_dir, UndefinedBehavior::Lenient)?;
// Pre-process template variables to handle invalid variable names
let preprocessed_content = preprocess_template_variables(content)?;

let (env, template_variables) = get_env_with_template_variables(
&preprocessed_content,
recipe_dir,
UndefinedBehavior::Lenient,
)?;
let template = env.get_template(CURRENT_TEMPLATE_NAME).unwrap();
// if the variables are not provided, the template will be rendered with the variables, otherwise it will keep the variables as is
let mut ctx = preserve_vars(&template_variables).clone();
Expand Down Expand Up @@ -175,6 +264,35 @@ mod tests {
assert!(err.to_string().contains("unexpected end of input"));
}

#[test]
fn test_render_content_with_spaced_variables() {
let content = "Hello {{hf model org}}_{{hf model name}}!";
let params = HashMap::from([("recipe_dir".to_string(), "some_dir".to_string())]);
let result = render_recipe_content_with_params(content, &params).unwrap();
assert_eq!(result, "Hello {{hf model org}}_{{hf model name}}!");

let content = "Hello {{hf model org}_{hf model name}}!";
let params = HashMap::from([("recipe_dir".to_string(), "some_dir".to_string())]);
let result = render_recipe_content_with_params(content, &params).unwrap();
assert_eq!(result, "Hello {{hf model org}_{hf model name}}!");

let content = "Hello {{valid_var}}!";
let params = HashMap::from([
("recipe_dir".to_string(), "some_dir".to_string()),
("valid_var".to_string(), "World".to_string()),
]);
let result = render_recipe_content_with_params(content, &params).unwrap();
assert_eq!(result, "Hello World!");

let content = "{{valid_var}} and {{invalid var}}";
let params = HashMap::from([
("recipe_dir".to_string(), "some_dir".to_string()),
("valid_var".to_string(), "Hello".to_string()),
]);
let result = render_recipe_content_with_params(content, &params).unwrap();
assert_eq!(result, "Hello and {{invalid var}}");
}

#[test]
fn test_empty_prompt() {
let content = r#"
Expand Down
Loading