-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ability to validate types for template strings (#1161)
<!-- ELLIPSIS_HIDDEN --> > [!IMPORTANT] > Add type validation for template strings in BAML engine with new validation logic and test cases. > > - **Behavior**: > - Add `template_strings` module to `validations.rs` and integrate it into the validation pipeline. > - Implement `validate()` in `template_strings.rs` to check template string types and handle errors. > - Update `functions.rs` to make `has_checks_nested()` public for use in template validation. > - **Walkers**: > - Add `walk_input_args()` to `TemplateStringWalker` in `template_string.rs` to iterate over template arguments. > - Add `ArgWalker` type to handle template string arguments. > - **Tests**: > - Add test cases in `bad_calls.baml`, `good_calls.baml`, and `invalid.baml` to verify template string validation. > - **Misc**: > - Minor documentation updates in `class.rs` and `mod.rs`. > > <sup>This description was created by </sup>[<img alt="Ellipsis" src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup> for 0881047. It will automatically update as commits are pushed.</sup> <!-- ELLIPSIS_HIDDEN -->
- Loading branch information
Showing
10 changed files
with
308 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
120 changes: 120 additions & 0 deletions
120
engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/template_strings.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
use std::collections::HashSet; | ||
|
||
use crate::validate::validation_pipeline::context::Context; | ||
|
||
use either::Either; | ||
use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Span}; | ||
|
||
use internal_baml_schema_ast::ast::{FieldType, TypeExpId, WithIdentifier, WithName, WithSpan}; | ||
|
||
use super::types::validate_type; | ||
|
||
pub(super) fn validate(ctx: &mut Context<'_>) { | ||
let mut defined_types = internal_baml_jinja_types::PredefinedTypes::default( | ||
internal_baml_jinja_types::JinjaContext::Prompt, | ||
); | ||
ctx.db.walk_classes().for_each(|t| { | ||
t.add_to_types(&mut defined_types); | ||
}); | ||
ctx.db.walk_templates().for_each(|t| { | ||
t.add_to_types(&mut defined_types); | ||
}); | ||
|
||
for template in ctx.db.walk_templates() { | ||
for args in template.walk_input_args() { | ||
let arg = args.ast_arg(); | ||
validate_type(ctx, &arg.1.field_type); | ||
} | ||
|
||
for args in template.walk_input_args() { | ||
let arg = args.ast_arg(); | ||
let field_type = &arg.1.field_type; | ||
|
||
let span = field_type.span().clone(); | ||
if super::functions::has_checks_nested(ctx, field_type) { | ||
ctx.push_error(DatamodelError::new_validation_error( | ||
"Types with checks are not allowed as function parameters.", | ||
span, | ||
)); | ||
} | ||
} | ||
|
||
let prompt = match template.template_raw() { | ||
Some(p) => p, | ||
None => { | ||
ctx.push_error(DatamodelError::new_validation_error( | ||
"Template string must be a raw string literal like `template_string MyTemplate(myArg: string) #\"\n\n\"#`", | ||
template.identifier().span().clone(), | ||
)); | ||
continue; | ||
} | ||
}; | ||
|
||
defined_types.start_scope(); | ||
|
||
template.walk_input_args().for_each(|arg| { | ||
let name = match arg.ast_arg().0 { | ||
Some(arg) => arg.name().to_string(), | ||
None => { | ||
ctx.push_error(DatamodelError::new_validation_error( | ||
"Argument name is missing.", | ||
arg.ast_arg().1.span().clone(), | ||
)); | ||
return; | ||
} | ||
}; | ||
|
||
let field_type = ctx.db.to_jinja_type(&arg.ast_arg().1.field_type); | ||
|
||
defined_types.add_variable(&name, field_type); | ||
}); | ||
match internal_baml_jinja_types::validate_template( | ||
template.name(), | ||
prompt.raw_value(), | ||
&mut defined_types, | ||
) { | ||
Ok(_) => {} | ||
Err(e) => { | ||
let pspan = prompt.span(); | ||
if let Some(e) = e.parsing_errors { | ||
let range = match e.range() { | ||
Some(range) => range, | ||
None => { | ||
ctx.push_error(DatamodelError::new_validation_error( | ||
&format!("Error parsing jinja template: {}", e), | ||
pspan.clone(), | ||
)); | ||
continue; | ||
} | ||
}; | ||
|
||
let start_offset = pspan.start + range.start; | ||
let end_offset = pspan.start + range.end; | ||
|
||
let span = Span::new( | ||
pspan.file.clone(), | ||
start_offset as usize, | ||
end_offset as usize, | ||
); | ||
|
||
ctx.push_error(DatamodelError::new_validation_error( | ||
&format!("Error parsing jinja template: {}", e), | ||
span, | ||
)) | ||
} else { | ||
e.errors.iter().for_each(|t| { | ||
let span = t.span(); | ||
let span = Span::new( | ||
pspan.file.clone(), | ||
pspan.start + span.start_offset as usize, | ||
pspan.start + span.end_offset as usize, | ||
); | ||
ctx.push_warning(DatamodelWarning::new(t.message().to_string(), span)) | ||
}) | ||
} | ||
} | ||
} | ||
defined_types.end_scope(); | ||
defined_types.errors_mut().clear(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
80 changes: 80 additions & 0 deletions
80
engine/baml-lib/baml/tests/validation_files/template_string/bad_calls.baml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
template_string WithParams(a: int) #" | ||
... | ||
"# | ||
|
||
template_string BadCall1 #" | ||
{{ WithParams(a=2, b=2) }} | ||
"# | ||
|
||
template_string BadCall2 #" | ||
{{ WithParams("a") }} | ||
"# | ||
|
||
template_string BadCall3 #" | ||
{{ WithParams() }} | ||
"# | ||
|
||
template_string BadCall4 #" | ||
{{ Random(2) }} | ||
"# | ||
|
||
// warning: Function 'WithParams' expects 1 arguments, but got 2 | ||
// --> template_string/bad_calls.baml:6 | ||
// | | ||
// 5 | template_string BadCall1 #" | ||
// 6 | {{ WithParams(a=2, b=2) }} | ||
// | | ||
// warning: Function 'WithParams' expects argument 'a' to be of type int, but got literal["a"] | ||
// --> template_string/bad_calls.baml:10 | ||
// | | ||
// 9 | template_string BadCall2 #" | ||
// 10 | {{ WithParams("a") }} | ||
// | | ||
// warning: Function 'WithParams' expects 1 arguments, but got 0 | ||
// --> template_string/bad_calls.baml:14 | ||
// | | ||
// 13 | template_string BadCall3 #" | ||
// 14 | {{ WithParams() }} | ||
// | | ||
// warning: Variable `Random` does not exist. Did you mean one of these: `_`, `ctx`? | ||
// --> template_string/bad_calls.baml:18 | ||
// | | ||
// 17 | template_string BadCall4 #" | ||
// 18 | {{ Random(2) }} | ||
// | | ||
// warning: 'Random' is undefined, expected function | ||
// --> template_string/bad_calls.baml:18 | ||
// | | ||
// 17 | template_string BadCall4 #" | ||
// 18 | {{ Random(2) }} | ||
// | | ||
// warning: Function 'WithParams' expects 1 arguments, but got 2 | ||
// --> template_string/bad_calls.baml:6 | ||
// | | ||
// 5 | template_string BadCall1 #" | ||
// 6 | {{ WithParams(a=2, b=2) }} | ||
// | | ||
// warning: Function 'WithParams' expects argument 'a' to be of type int, but got literal["a"] | ||
// --> template_string/bad_calls.baml:10 | ||
// | | ||
// 9 | template_string BadCall2 #" | ||
// 10 | {{ WithParams("a") }} | ||
// | | ||
// warning: Function 'WithParams' expects 1 arguments, but got 0 | ||
// --> template_string/bad_calls.baml:14 | ||
// | | ||
// 13 | template_string BadCall3 #" | ||
// 14 | {{ WithParams() }} | ||
// | | ||
// warning: Variable `Random` does not exist. Did you mean one of these: `_`, `ctx`? | ||
// --> template_string/bad_calls.baml:18 | ||
// | | ||
// 17 | template_string BadCall4 #" | ||
// 18 | {{ Random(2) }} | ||
// | | ||
// warning: 'Random' is undefined, expected function | ||
// --> template_string/bad_calls.baml:18 | ||
// | | ||
// 17 | template_string BadCall4 #" | ||
// 18 | {{ Random(2) }} | ||
// | |
11 changes: 11 additions & 0 deletions
11
engine/baml-lib/baml/tests/validation_files/template_string/good_calls.baml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
template_string WithParams(a: int) #" | ||
... | ||
"# | ||
|
||
template_string GoodCall1 #" | ||
{{ WithParams(a=2) }} | ||
"# | ||
|
||
template_string GoodCall2 #" | ||
{{ WithParams(2) }} | ||
"# |
34 changes: 34 additions & 0 deletions
34
engine/baml-lib/baml/tests/validation_files/template_string/invalid.baml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
template_string FunctionWithBadParams( | ||
param: Unknown, | ||
param2: Unknown2[], | ||
param3: string | ||
) #" | ||
{{ param.foo }} | ||
{{ param2[0].doc }} | ||
{{ param3 }} | ||
"# | ||
|
||
// warning: 'param' is undefined, expected class | ||
// --> template_string/invalid.baml:6 | ||
// | | ||
// 5 | ) #" | ||
// 6 | {{ param.foo }} | ||
// | | ||
// warning: 'param' is undefined, expected class | ||
// --> template_string/invalid.baml:6 | ||
// | | ||
// 5 | ) #" | ||
// 6 | {{ param.foo }} | ||
// | | ||
// error: Type `Unknown` does not exist. Did you mean one of these: `int`, `float`, `bool`, `string`, `true`, `false`? | ||
// --> template_string/invalid.baml:2 | ||
// | | ||
// 1 | template_string FunctionWithBadParams( | ||
// 2 | param: Unknown, | ||
// | | ||
// error: Type `Unknown2` does not exist. Did you mean one of these: `string`, `int`, `float`, `bool`, `true`, `false`? | ||
// --> template_string/invalid.baml:3 | ||
// | | ||
// 2 | param: Unknown, | ||
// 3 | param2: Unknown2[], | ||
// | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters