Skip to content

Commit

Permalink
feat: Add constraints via Jinja expressions (#1006)
Browse files Browse the repository at this point in the history
<!-- ELLIPSIS_HIDDEN -->


> [!IMPORTANT]
> This PR adds support for constraints using Jinja expressions, updates
models and validation logic, and enhances testing for constraint
functionality.
> 
>   - **Behavior**:
> - Add support for constraints using Jinja expressions in `baml-core`.
> - Implement `evaluate_predicate` in `internal_baml_jinja` for
constraint evaluation.
> - Update `FieldType` to support `Constrained` type with constraints.
>   - **Models**:
>     - Add `Constraint` and `ConstraintLevel` in `baml-types`.
>     - Update `NodeAttributes` to include constraints.
>   - **Validation**:
>     - Add validation for constraints in `validation_pipeline`.
>     - Ensure constraints are not allowed as function parameters.
>   - **Testing**:
>     - Add tests for constraints in `test_constraints.rs`.
>     - Update integration tests to check constraint functionality.
>   - **Misc**:
>     - Add `itertools` dependency in `Cargo.lock`.
>     - Update `shell.nix` to include necessary tools for development.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for 316ac5f. It will automatically
update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
imalsogreg authored Oct 23, 2024
1 parent 45b91d3 commit d794f28
Show file tree
Hide file tree
Showing 131 changed files with 5,645 additions and 2,206 deletions.
12 changes: 12 additions & 0 deletions engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions engine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,25 @@ http-body = "1.0.0"
indexmap = { version = "2.1.0", features = ["serde"] }
indoc = "2.0.5"
log = "0.4.20"
# TODO: disable imports, etc
minijinja = { version = "1.0.16", default-features = false, features = [
"macros",
"builtins",
"debug",
"preserve_order",
"adjacent_loop_items",
"unicode",
"json",
"unstable_machinery",
"unstable_machinery_serde",
"custom_syntax",
"internal_debug",
"deserialization",
# We don't want to use these features:
# multi_template
# loader
#
] }
regex = "1.10.4"
scopeguard = "1.2.0"
serde_json = { version = "1", features = ["float_roundtrip", "preserve_order"] }
Expand Down
1 change: 1 addition & 0 deletions engine/baml-lib/baml-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ internal-baml-jinja-types = { path = "../jinja" }
internal-baml-parser-database = { path = "../parser-database" }
internal-baml-prompt-parser = { path = "../prompt-parser" }
internal-baml-schema-ast = { path = "../schema-ast" }
minijinja.workspace = true
rayon = "1.8.0"
regex = "1.10.3"
semver = "1.0.20"
Expand Down
109 changes: 107 additions & 2 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use crate::{
Class, Client, Enum, EnumValue, Field, FunctionNode, RetryPolicy, TemplateString, TestCase,
},
};
use anyhow::Result;
use baml_types::{BamlMap, BamlValue};
use anyhow::{Context, Result};
use baml_types::{BamlMap, BamlValue, BamlValueWithMeta, FieldType, TypeValue};
pub use to_baml_arg::ArgCoercer;

use super::repr;
Expand Down Expand Up @@ -44,6 +44,7 @@ pub trait IRHelper {
params: &BamlMap<String, BamlValue>,
coerce_settings: ArgCoercer,
) -> Result<BamlValue>;
fn distribute_type<'a>(&'a self, value: BamlValue, field_type: FieldType) -> Result<BamlValueWithMeta<FieldType>>;
}

impl IRHelper for IntermediateRepr {
Expand Down Expand Up @@ -184,4 +185,108 @@ impl IRHelper for IntermediateRepr {
Ok(BamlValue::Map(baml_arg_map))
}
}

/// For some `BamlValue` with type `FieldType`, walk the structure of both the value
/// and the type simultaneously, associating each node in the `BamlValue` with its
/// `FieldType`.
fn distribute_type<'a>(
&'a self,
value: BamlValue,
field_type: FieldType,
) -> anyhow::Result<BamlValueWithMeta<FieldType>> {
let (unconstrained_type, _) = field_type.distribute_constraints();
match (value, unconstrained_type) {

(BamlValue::String(s), FieldType::Primitive(TypeValue::String)) => Ok(BamlValueWithMeta::String(s, field_type)),
(BamlValue::String(_), _) => anyhow::bail!("Could not unify Strinig with {:?}", field_type),

(BamlValue::Int(i), FieldType::Primitive(TypeValue::Int)) => Ok(BamlValueWithMeta::Int(i, field_type)),
(BamlValue::Int(_), _) => anyhow::bail!("Could not unify Int with {:?}", field_type),

(BamlValue::Float(f), FieldType::Primitive(TypeValue::Float)) => Ok(BamlValueWithMeta::Float(f, field_type)),
(BamlValue::Float(_), _) => anyhow::bail!("Could not unify Float with {:?}", field_type),

(BamlValue::Bool(b), FieldType::Primitive(TypeValue::Bool)) => Ok(BamlValueWithMeta::Bool(b, field_type)),
(BamlValue::Bool(_), _) => anyhow::bail!("Could not unify Bool with {:?}", field_type),

(BamlValue::Null, FieldType::Primitive(TypeValue::Null)) => Ok(BamlValueWithMeta::Null(field_type)),
(BamlValue::Null, _) => anyhow::bail!("Could not unify Null with {:?}", field_type),

(BamlValue::Map(pairs), FieldType::Map(k,val_type)) => {
let mapped_fields: BamlMap<String, BamlValueWithMeta<FieldType>> =
pairs
.into_iter()
.map(|(key, val)| {
let sub_value = self.distribute_type(val, *val_type.clone())?;
Ok((key, sub_value))
})
.collect::<anyhow::Result<BamlMap<String,BamlValueWithMeta<FieldType>>>>()?;
Ok(BamlValueWithMeta::Map( mapped_fields, field_type ))
},
(BamlValue::Map(_), _) => anyhow::bail!("Could not unify Map with {:?}", field_type),

(BamlValue::List(items), FieldType::List(item_type)) => {
let mapped_items: Vec<BamlValueWithMeta<FieldType>> =
items
.into_iter()
.map(|i| self.distribute_type(i, *item_type.clone()))
.collect::<anyhow::Result<Vec<_>>>()?;
Ok(BamlValueWithMeta::List(mapped_items, field_type))
}
(BamlValue::List(_), _) => anyhow::bail!("Could not unify List with {:?}", field_type),

(BamlValue::Media(m), FieldType::Primitive(TypeValue::Media(_))) => Ok(BamlValueWithMeta::Media(m, field_type)),
(BamlValue::Media(_), _) => anyhow::bail!("Could not unify Media with {:?}", field_type),

(BamlValue::Enum(name, val), FieldType::Enum(type_name)) => if name == *type_name {
Ok(BamlValueWithMeta::Enum(name, val, field_type))
} else {
Err(anyhow::anyhow!("Could not unify Enum {name} with Enum type {type_name}"))
}
(BamlValue::Enum(enum_name,_), _) => anyhow::bail!("Could not unify Enum {enum_name} with {:?}", field_type),

(BamlValue::Class(name, fields), FieldType::Class(type_name)) => if name == *type_name {
let class_type = &self.find_class(type_name)?.item.elem;
let class_fields: BamlMap<String, FieldType> = class_type.static_fields.iter().map(|field_node| (field_node.elem.name.clone(), field_node.elem.r#type.elem.clone())).collect();
let mapped_fields = fields.into_iter().map(|(k,v)| {
let field_type = match class_fields.get(k.as_str()) {
Some(ft) => ft.clone(),
None => infer_type(&v),
};
let mapped_field = self.distribute_type(v, field_type)?;
Ok((k, mapped_field))
}).collect::<anyhow::Result<BamlMap<String, BamlValueWithMeta<FieldType>>>>()?;
Ok(BamlValueWithMeta::Class(name, mapped_fields, field_type))
} else {
Err(anyhow::anyhow!("Could not unify Class {name} with Class type {type_name}"))
}
(BamlValue::Class(class_name,_), _) => anyhow::bail!("Could not unify Class {class_name} with {:?}", field_type),

}
}
}


/// Derive the simplest type that can categorize a given value. This is meant to be used
/// by `distribute_type`, for dynamic fields of classes, whose types are not known statically.
///
/// Container types are inferred from their first element only:
/// - a map becomes `Map<String, T>` where `T` is inferred from the first entry's value,
/// - a list becomes `List<T>` where `T` is inferred from the first item.
///
/// Empty maps and lists use `Null` as the element type.
pub fn infer_type<'a>(value: &'a BamlValue) -> FieldType {
    match value {
        BamlValue::Int(_) => FieldType::Primitive(TypeValue::Int),
        BamlValue::Bool(_) => FieldType::Primitive(TypeValue::Bool),
        BamlValue::Float(_) => FieldType::Primitive(TypeValue::Float),
        BamlValue::String(_) => FieldType::Primitive(TypeValue::String),
        BamlValue::Null => FieldType::Primitive(TypeValue::Null),
        BamlValue::Map(pairs) => {
            let value_type = match pairs.iter().next() {
                Some((_, v)) => infer_type(v),
                None => FieldType::Primitive(TypeValue::Null),
            };
            FieldType::Map(
                Box::new(FieldType::Primitive(TypeValue::String)),
                Box::new(value_type),
            )
        }
        BamlValue::List(items) => {
            let item_type = match items.iter().next() {
                Some(i) => infer_type(i),
                None => FieldType::Primitive(TypeValue::Null),
            };
            // The list itself must be typed as a list of its item type; returning the
            // bare item type would never unify with a list value in `distribute_type`.
            FieldType::List(Box::new(item_type))
        }
        BamlValue::Media(m) => FieldType::Primitive(TypeValue::Media(m.media_type)),
        BamlValue::Enum(enum_name, _) => FieldType::Enum(enum_name.clone()),
        BamlValue::Class(class_name, _) => FieldType::Class(class_name.clone()),
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,8 @@ impl ScopeStack {
/// Appends `error` to the `errors` of the last scope in `self.scopes`.
///
/// Panics if `self.scopes` has no last element.
pub fn push_error(&mut self, error: String) {
    let current_scope = self.scopes.last_mut().unwrap();
    current_scope.errors.push(error);
}

/// Appends `warning` to the `warnings` of the last scope in `self.scopes`.
///
/// Panics if `self.scopes` has no last element.
pub fn push_warning(&mut self, warning: String) {
    let current_scope = self.scopes.last_mut().unwrap();
    current_scope.warnings.push(warning);
}
}
97 changes: 80 additions & 17 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use baml_types::{BamlMap, BamlMediaType, BamlValue, FieldType, LiteralValue, TypeValue};
use baml_types::{
BamlMap, BamlValue, BamlValueWithMeta, Constraint, ConstraintLevel, FieldType, LiteralValue, TypeValue
};
use core::result::Result;
use std::path::PathBuf;

use crate::ir::IntermediateRepr;

use super::{scope_diagnostics::ScopeStack, IRHelper};
use crate::ir::jinja_helpers::evaluate_predicate;

#[derive(Default)]
pub struct ParameterError {
Expand Down Expand Up @@ -39,8 +42,8 @@ impl ArgCoercer {
value: &BamlValue, // original value passed in by user
scope: &mut ScopeStack,
) -> Result<BamlValue, ()> {
match field_type {
FieldType::Primitive(t) => match t {
let value = match field_type.distribute_constraints() {
(FieldType::Primitive(t), _) => match t {
TypeValue::String if matches!(value, BamlValue::String(_)) => Ok(value.clone()),
TypeValue::String if self.allow_implicit_cast_to_string => match value {
BamlValue::Int(i) => Ok(BamlValue::String(i.to_string())),
Expand Down Expand Up @@ -170,7 +173,7 @@ impl ArgCoercer {
Err(())
}
},
FieldType::Enum(name) => match value {
(FieldType::Enum(name), _) => match value {
BamlValue::String(s) => {
if let Ok(e) = ir.find_enum(name) {
if e.walk_values().find(|v| v.item.elem.0 == *s).is_some() {
Expand Down Expand Up @@ -198,7 +201,7 @@ impl ArgCoercer {
Err(())
}
},
FieldType::Literal(literal) => Ok(match (literal, value) {
(FieldType::Literal(literal), _) => Ok(match (literal, value) {
(LiteralValue::Int(lit), BamlValue::Int(baml)) if lit == baml => value.clone(),
(LiteralValue::String(lit), BamlValue::String(baml)) if lit == baml => {
value.clone()
Expand All @@ -209,8 +212,8 @@ impl ArgCoercer {
return Err(());
}
}),
FieldType::Class(name) => match value {
BamlValue::Class(n, _) if n == name => return Ok(value.clone()),
(FieldType::Class(name), _) => match value {
BamlValue::Class(n, _) if n == name => Ok(value.clone()),
BamlValue::Class(_, obj) | BamlValue::Map(obj) => match ir.find_class(name) {
Ok(c) => {
let mut fields = BamlMap::new();
Expand Down Expand Up @@ -259,7 +262,7 @@ impl ArgCoercer {
Err(())
}
},
FieldType::List(item) => match value {
(FieldType::List(item), _) => match value {
BamlValue::List(arr) => {
let mut items = Vec::new();
for v in arr {
Expand All @@ -274,11 +277,11 @@ impl ArgCoercer {
Err(())
}
},
FieldType::Tuple(_) => {
(FieldType::Tuple(_), _) => {
scope.push_error(format!("Tuples are not yet supported"));
Err(())
}
FieldType::Map(k, v) => {
(FieldType::Map(k, v), _) => {
if let BamlValue::Map(kv) = value {
let mut map = BamlMap::new();
for (key, value) in kv {
Expand All @@ -300,18 +303,27 @@ impl ArgCoercer {
Err(())
}
}
FieldType::Union(options) => {
(FieldType::Union(options), _) => {
let mut first_good_result = Err(());
for option in options {
let mut scope = ScopeStack::new();
let result = self.coerce_arg(ir, option, value, &mut scope);
if !scope.has_errors() {
return result;
if first_good_result.is_err() {
let result = self.coerce_arg(ir, option, value, &mut scope);
if !scope.has_errors() {
if first_good_result.is_err() {
first_good_result = result
}
}
}
}
scope.push_error(format!("Expected one of {:?}, got `{}`", options, value));
Err(())
if first_good_result.is_err(){
scope.push_error(format!("Expected one of {:?}, got `{}`", options, value));
Err(())
} else {
first_good_result
}
}
FieldType::Optional(inner) => {
(FieldType::Optional(inner), _) => {
if matches!(value, BamlValue::Null) {
Ok(value.clone())
} else {
Expand All @@ -325,6 +337,57 @@ impl ArgCoercer {
}
}
}
(FieldType::Constrained { .. }, _) => {
unreachable!("The return value of distribute_constraints can never be FieldType::Constrainted");
}
}?;


let search_for_failures_result = first_failing_assert_nested(ir, &value, field_type).map_err(|e| {
scope.push_error(format!("Failed to evaluate assert: {:?}", e));
()
})?;
match search_for_failures_result {
Some(Constraint {label, expression, ..}) => {
let msg = label.as_ref().unwrap_or(&expression.0);
scope.push_error(format!("Failed assert: {msg}"));
Ok(value)
}
None => Ok(value)
}
}
}

/// Walk a (possibly nested) `BamlValue` together with its `FieldType` and report
/// the first `Assert`-level constraint whose predicate evaluates to `false`.
///
/// Returns `Ok(None)` when no assert fails. Constraints that fail but are not at
/// `ConstraintLevel::Assert` are ignored. An error is returned if the value cannot be
/// paired with its type via `distribute_type`, or if the first reportable outcome is a
/// predicate evaluation error.
fn first_failing_assert_nested<'a>(
    ir: &'a IntermediateRepr,
    baml_value: &BamlValue,
    field_type: &'a FieldType,
) -> anyhow::Result<Option<Constraint>> {
    let typed_value = ir.distribute_type(baml_value.clone(), field_type.clone())?;
    typed_value
        .iter()
        .flat_map(|node| {
            // Only the constraints attached to this node's own type are checked here;
            // nested nodes are visited by the outer iteration.
            let (_, node_constraints) = node.meta().distribute_constraints();
            node_constraints
                .into_iter()
                .filter_map(|c| {
                    let failed_constraint = c.clone();
                    let node_value: BamlValue = node.into();
                    match evaluate_predicate(&&node_value, &c.expression) {
                        Ok(true) => None,
                        Ok(false) if c.level == ConstraintLevel::Assert => {
                            Some(Ok(failed_constraint))
                        }
                        Ok(false) => None,
                        Err(e) => Some(Err(anyhow::anyhow!(
                            "Error evaluating constraint: {:?}",
                            e
                        ))),
                    }
                })
                .collect::<Vec<_>>()
        })
        .next()
        .transpose()
}
Loading

0 comments on commit d794f28

Please sign in to comment.