Skip to content

Commit

Permalink
Validate fieldnames and types when using pydantic codegen (#1189)
Browse files Browse the repository at this point in the history
Python and pydantic do not allow arbitrary identifiers to be used as
fields in classes. This PR adds checks to the BAML grammar, which run
conditionally when the user includes a python/pydantic code generator
block:
 - field names must not be Python keywords.
- field names must not be lexographically equal to the field type, or
the base of an optional type.

E.g. rule 1:
```python
# Not ok
class Foo(BaseModel):
  if string
```

E.g. rule 2:
```python
class ETA(BaseModel):
  time: string

# Not ok
class Foo(BaseModel):
  ETA: ETA
```

These rules are now checked during validation of the syntax tree prior
to construction of the IR, and if they are violated we push an error to
`Diagnostics`.

Bonus: There are a few changes in the PR not related to the issue - they
are little cleanups to reduce the number of unnecessary `rustc`
warnings.
<!-- ELLIPSIS_HIDDEN -->


----

> [!IMPORTANT]
> Add validation for field names in BAML classes to prevent Python
keyword and type name conflicts when using Pydantic code generation.
> 
>   - **Validation**:
> - Add `assert_no_field_name_collisions()` in `classes.rs` to check
field names against Python keywords and type names when using Pydantic.
>     - Use `reserved_names()` to map keywords to target languages.
>   - **Diagnostics**:
> - Update `new_field_validation_error()` in `error.rs` to accept
`String` for error messages.
>   - **Miscellaneous**:
> - Remove unused code and features in `lib.rs` and `build.rs` to reduce
rustc warnings.
> - Add tests `generator_keywords1.baml` and `generator_keywords2.baml`
to validate new rules.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for 49d31fb. It will automatically
update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
imalsogreg authored Nov 22, 2024
1 parent 8c3d536 commit 93b393d
Show file tree
Hide file tree
Showing 11 changed files with 190 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@ mod functions;
mod template_strings;
mod types;

use baml_types::GeneratorOutputType;

use crate::{configuration::Generator, validate::generator_loader::load_generators_from_ast};

use super::context::Context;

use std::collections::HashSet;

pub(super) fn validate(ctx: &mut Context<'_>) {
enums::validate(ctx);
classes::validate(ctx);
Expand All @@ -17,6 +23,13 @@ pub(super) fn validate(ctx: &mut Context<'_>) {
template_strings::validate(ctx);
configurations::validate(ctx);

let generators = load_generators_from_ast(ctx.db.ast(), ctx.diagnostics);
let codegen_targets: HashSet<GeneratorOutputType> = generators.into_iter().filter_map(|generator| match generator {
Generator::Codegen(gen) => Some(gen.output_type),
Generator::BoundaryCloud(_) => None
}).collect::<HashSet<_>>();
classes::assert_no_field_name_collisions(ctx, &codegen_targets);

if !ctx.diagnostics.has_errors() {
cycle::validate(ctx);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use internal_baml_schema_ast::ast::{WithName, WithSpan};
use baml_types::GeneratorOutputType;
use internal_baml_schema_ast::ast::{Field, FieldType, WithName, WithSpan};

use super::types::validate_type;
use crate::validate::validation_pipeline::context::Context;
use internal_baml_diagnostics::DatamodelError;

use itertools::join;
use std::collections::{HashMap, HashSet};

pub(super) fn validate(ctx: &mut Context<'_>) {
let mut defined_types = internal_baml_jinja_types::PredefinedTypes::default(
internal_baml_jinja_types::JinjaContext::Prompt,
Expand Down Expand Up @@ -45,3 +49,114 @@ pub(super) fn validate(ctx: &mut Context<'_>) {
defined_types.errors_mut().clear();
}
}

/// Enforce that keywords in the user's requested target languages
/// do not appear as field names in BAML classes, and that field
/// names are not equal to type names when using Pydantic.
pub(super) fn assert_no_field_name_collisions(
ctx: &mut Context<'_>,
generator_output_types: &HashSet<GeneratorOutputType>,
) {
// The list of reserved words for all user-requested codegen targets.
let reserved = reserved_names(generator_output_types);

for cls in ctx.db.walk_classes() {
for c in cls.static_fields() {
let field: &Field<FieldType> = c.ast_field();

// Check for keyword in field name.
if let Some(langs) = reserved.get(field.name()) {
let msg = match langs.as_slice() {
[lang] => format!("Field name is a reserved word in generated {lang} clients."),
_ => format!(
"Field name is a reserved word in language clients: {}.",
join(langs, ", ")
),
};
ctx.push_error(DatamodelError::new_field_validation_error(
msg,
"class",
c.name(),
field.name(),
field.span.clone(),
))
}

// Check for collision between field name and type name when using Pydantic.
if generator_output_types.contains(&GeneratorOutputType::PythonPydantic) {
let type_name = field
.expr
.as_ref()
.map_or("".to_string(), |r#type| r#type.name());
if field.name() == type_name {
ctx.push_error(DatamodelError::new_field_validation_error(
"When using the python/pydantic generator, a field name must not be exactly equal to the type name. Consider changing the field name and using an alias.".to_string(),
"class",
c.name(),
field.name(),
field.span.clone()
))
}
}
}
}
}

/// For a given set of target languages, construct a map from keyword to the
/// list of target languages in which that identifier is a keyword.
///
/// This will be used later to make error messages like, "Could not use name
/// `continue` becase that is a keyword in Python", "Could not use the name
/// `return` because that is a keyword in Python and Typescript".
fn reserved_names(
generator_output_types: &HashSet<GeneratorOutputType>,
) -> HashMap<&'static str, Vec<GeneratorOutputType>> {
let mut keywords: HashMap<&str, Vec<GeneratorOutputType>> = HashMap::new();

let language_keywords: Vec<(&str, GeneratorOutputType)> = [
if generator_output_types.contains(&GeneratorOutputType::PythonPydantic) {
RESERVED_NAMES_PYTHON
.into_iter()
.map(|name| (*name, GeneratorOutputType::PythonPydantic))
.collect()
} else {
Vec::new()
},
if generator_output_types.contains(&GeneratorOutputType::Typescript) {
RESERVED_NAMES_TYPESCRIPT
.into_iter()
.map(|name| (*name, GeneratorOutputType::Typescript))
.collect()
} else {
Vec::new()
},
]
.iter()
.flatten()
.cloned()
.collect();

language_keywords
.into_iter()
.for_each(|(keyword, generator_output_type)| {
keywords
.entry(keyword)
.and_modify(|types| (*types).push(generator_output_type))
.or_insert(vec![generator_output_type]);
});

keywords
}

// This list of keywords was copied from
// https://www.w3schools.com/python/python_ref_keywords.asp
// .
const RESERVED_NAMES_PYTHON: &[&str] = &[
"False", "None", "True", "and", "as", "assert", "async", "await", "break", "class", "continue",
"def", "del", "elif", "else", "except", "finally", "for", "from", "global", "if", "import",
"in", "is", "lambda", "nonlocal", "not", "or", "pass", "raise", "return", "try", "while",
"with", "yield",
];

// Typescript is much more flexible in the key names it allows.
const RESERVED_NAMES_TYPESCRIPT: &[&str] = &[];
9 changes: 9 additions & 0 deletions engine/baml-lib/baml-types/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
strum::VariantArray,
strum::VariantNames,
)]

#[derive(PartialEq, Eq)]
pub enum GeneratorOutputType {
#[strum(serialize = "rest/openapi")]
OpenApi,
Expand All @@ -22,6 +24,13 @@ pub enum GeneratorOutputType {
RubySorbet,
}

impl std::hash::Hash for GeneratorOutputType {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
}
}


impl GeneratorOutputType {
pub fn default_client_mode(&self) -> GeneratorDefaultClientMode {
match self {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
generator lang_python {
output_type python/pydantic
output_dir "../python"
version "0.68.0"
}

class ETA {
thing string
}

class Foo {
if string
ETA ETA?
}

// error: Error validating field `if` in class `if`: Field name is a reserved word in generated python/pydantic clients.
// --> class/generator_keywords1.baml:12
// |
// 11 | class Foo {
// 12 | if string
// 13 | ETA ETA?
// |
// error: Error validating field `ETA` in class `ETA`: When using the python/pydantic generator, a field name must not be exactly equal to the type name. Consider changing the field name and using an alias.
// --> class/generator_keywords1.baml:13
// |
// 12 | if string
// 13 | ETA ETA?
// 14 | }
// |
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// This file is just like generator_keywords1.baml, except that the fieldname
// has been changed in order to not collide with the field type `ETA`, and an
// alias is used to render that field as `ETA` in prompts.

generator lang_python {
output_type python/pydantic
output_dir "../python"
version "0.68.0"
}

class ETA {
thing string
}

class Foo {
eta ETA? @alias("ETA")
}
2 changes: 1 addition & 1 deletion engine/baml-lib/diagnostics/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ impl DatamodelError {
}

pub fn new_field_validation_error(
message: &str,
message: String,
container_type: &str,
container_name: &str,
field: &str,
Expand Down
3 changes: 0 additions & 3 deletions engine/baml-lib/parser-database/build.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,2 @@
fn main() {
// If you have an existing build.rs file, just add this line to it.
#[cfg(feature = "use-pyo3")]
pyo3_build_config::use_pyo3_cfgs();
}
9 changes: 1 addition & 8 deletions engine/baml-runtime/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
#[cfg(all(test, feature = "no_wasm"))]
mod tests;

// #[cfg(all(feature = "wasm", feature = "no_wasm"))]
// compile_error!(
// "The features 'wasm' and 'no_wasm' are mutually exclusive. You can only use one at a time."
// );
// mod tests;

#[cfg(feature = "internal")]
pub mod internal;
Expand All @@ -15,7 +9,6 @@ pub(crate) mod internal;
pub mod cli;
pub mod client_registry;
pub mod errors;
mod macros;
pub mod request;
mod runtime;
pub mod runtime_interface;
Expand Down
9 changes: 0 additions & 9 deletions engine/baml-runtime/src/macros.rs

This file was deleted.

6 changes: 3 additions & 3 deletions engine/baml-schema-wasm/src/runtime_wasm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1232,7 +1232,7 @@ impl WasmRuntime {
if span.file_path.as_str().ends_with(file_name)
&& ((span.start + 1)..=(span.end + 1)).contains(&cursor_idx)
{
if let Some(parent_function) =
if let Some(_parent_function) =
tc.parent_functions.iter().find(|f| f.name == selected_func)
{
return functions.into_iter().find(|f| f.name == selected_func);
Expand All @@ -1251,7 +1251,7 @@ impl WasmRuntime {
if span.file_path.as_str().ends_with(file_name)
&& ((span.start + 1)..=(span.end + 1)).contains(&cursor_idx)
{
if let Some(parent_function) =
if let Some(_parent_function) =
tc.parent_functions.iter().find(|f| f.name == selected_func)
{
return functions.into_iter().find(|f| f.name == selected_func);
Expand Down Expand Up @@ -1441,7 +1441,7 @@ fn js_fn_to_baml_src_reader(get_baml_src_cb: js_sys::Function) -> BamlSrcReader
}

#[wasm_bindgen]
struct WasmCallContext {
pub struct WasmCallContext {
/// Index of the orchestration graph node to use for the call
/// Defaults to 0 when unset
node_index: Option<usize>,
Expand Down
5 changes: 1 addition & 4 deletions engine/baml-schema-wasm/src/runtime_wasm/runtime_prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,13 @@ use baml_runtime::{
},
ChatMessagePart, RenderedPrompt,
};
use serde::Serialize;
use serde_json::json;

use crate::runtime_wasm::ToJsValue;
use baml_types::{BamlMedia, BamlMediaContent, BamlMediaType, MediaBase64};
use baml_types::{BamlMediaContent, BamlMediaType, MediaBase64};
use serde_wasm_bindgen::to_value;
use wasm_bindgen::prelude::*;

use super::WasmFunction;

#[wasm_bindgen(getter_with_clone)]
pub struct WasmScope {
scope: OrchestrationScope,
Expand Down

0 comments on commit 93b393d

Please sign in to comment.