Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include type aliases in Jinja #1286

Merged
merged 5 commits into from
Jan 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions engine/baml-lib/baml-core/src/ir/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ pub struct IntermediateRepr {
finite_recursive_cycles: Vec<IndexSet<String>>,

/// Type alias cycles introduced by lists and maps.
///
/// These are the only allowed cycles, because lists and maps introduce a
/// level of indirection that makes the cycle finite.
structural_recursive_alias_cycles: Vec<IndexMap<String, FieldType>>,

configuration: Configuration,
Expand Down Expand Up @@ -186,7 +189,7 @@ impl IntermediateRepr {
.collect(),
structural_recursive_alias_cycles: {
let mut recursive_aliases = vec![];
for cycle in db.structural_recursive_alias_cycles() {
for cycle in db.recursive_alias_cycles() {
let mut component = IndexMap::new();
for id in cycle {
let alias = &db.ast()[*id];
Expand Down Expand Up @@ -448,11 +451,7 @@ impl WithRepr<FieldType> for ast::FieldType {
}
}
Some(TypeWalker::TypeAlias(alias_walker)) => {
if db
.structural_recursive_alias_cycles()
.iter()
.any(|cycle| cycle.contains(&alias_walker.id))
{
if db.is_recursive_type_alias(&alias_walker.id) {
FieldType::RecursiveTypeAlias(alias_walker.name().to_string())
} else {
alias_walker.resolved().to_owned().repr(db)?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub(super) fn validate(ctx: &mut Context<'_>) {
// We'll check type alias cycles first. Just like Typescript, cycles are
// allowed only for maps and lists. We'll call such cycles "structural
// recursion". Anything else like nulls or unions won't terminate a cycle.
let structural_type_aliases = HashMap::from_iter(ctx.db.walk_type_aliases().map(|alias| {
let non_structural_type_aliases = HashMap::from_iter(ctx.db.walk_type_aliases().map(|alias| {
let mut dependencies = HashSet::new();
insert_required_alias_deps(alias.target(), ctx, &mut dependencies);

Expand All @@ -27,15 +27,17 @@ pub(super) fn validate(ctx: &mut Context<'_>) {
// Based on the graph we've built with does not include the edges created
// by maps and lists, check the cycles and report them.
report_infinite_cycles(
&structural_type_aliases,
&non_structural_type_aliases,
ctx,
"These aliases form a dependency cycle",
);

// In order to avoid infinite recursion when resolving types for class
// dependencies below, we'll compute the cycles of aliases including maps
// and lists so that the recursion can be stopped before entering a cycle.
let complete_alias_cycles = Tarjan::components(ctx.db.type_alias_dependencies())
let complete_alias_cycles = ctx
.db
.recursive_alias_cycles()
.iter()
.flatten()
.copied()
Expand Down Expand Up @@ -139,29 +141,9 @@ fn insert_required_class_deps(
deps.insert(class.id);
}
Some(TypeWalker::TypeAlias(alias)) => {
// TODO: By the time this code runs we would ideally want
// type aliases to be resolved but we can't do that because
// type alias cycles are not validated yet, we have to
// do that in this file. Take a look at the `validate`
// function at `baml-lib/baml-core/src/lib.rs`.
//
// First we run the `ParserDatabase::validate` function
// which creates the alias graph by visiting all aliases.
// Then we run the `validate::validate` which ends up
// running this code here. Finally we run the
// `ParserDatabase::finalize` which is the place where we
// can resolve type aliases since we've already validated
// that there are no cycles so we won't run into infinite
// recursion. Ideally we want this:
//
// insert_required_deps(id, alias.resolved(), ctx, deps);

// But we'll run this instead which will follow all the
// alias pointers again until it finds the resolved type.
// We also have to stop recursion if we know the alias is
// part of a cycle.
// This code runs after aliases are already resolved.
if !alias_cycles.contains(&alias.id) {
insert_required_class_deps(id, alias.target(), ctx, deps, alias_cycles)
insert_required_class_deps(id, alias.resolved(), ctx, deps, alias_cycles)
}
}
_ => {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ pub(super) fn validate(ctx: &mut Context<'_>) {
ctx.db.walk_templates().for_each(|t| {
t.add_to_types(&mut defined_types);
});
ctx.db.walk_type_aliases().for_each(|t| {
t.add_to_types(&mut defined_types);
});

for template in ctx.db.walk_templates() {
for args in template.walk_input_args() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
type ProjectId = int

function NormalAlias(pid: ProjectId) -> string {
client "openai/gpt-4o"
prompt #"Pid: {{ pid.id }}. Generate a fake name for it."#
}

type A = float
type B = A
type C = B

function LongerAlias(c: C) -> string {
client "openai/gpt-4o"
prompt #"{{ c.value }}"#
}

type JsonValue = int | string | bool | float | JsonObject | JsonArray
type JsonObject = map<string, JsonValue>
type JsonArray = JsonValue[]

function RecursiveAliases(j: JsonValue) -> string {
client "openai/gpt-4o"
prompt #"{{ j.value }}"#
}

type I = J
type J = I

function InvalidAlias(i: I) -> string {
client "openai/gpt-4o"
prompt #"{{ i.value }}"#
}

// warning: 'pid' is a type alias ProjectId (resolves to int), expected class
// --> class/type_aliases_jinja.baml:5
// |
// 4 | client "openai/gpt-4o"
// 5 | prompt #"Pid: {{ pid.id }}. Generate a fake name for it."#
// |
// warning: 'c' is a type alias C (resolves to float), expected class
// --> class/type_aliases_jinja.baml:14
// |
// 13 | client "openai/gpt-4o"
// 14 | prompt #"{{ c.value }}"#
// |
// warning: 'j' is a recursive type alias JsonValue, expected class
// --> class/type_aliases_jinja.baml:23
// |
// 22 | client "openai/gpt-4o"
// 23 | prompt #"{{ j.value }}"#
// |
// warning: 'i' is a recursive type alias I, expected class
// --> class/type_aliases_jinja.baml:31
// |
// 30 | client "openai/gpt-4o"
// 31 | prompt #"{{ i.value }}"#
// |
// error: Error validating: These aliases form a dependency cycle: I -> J
// --> class/type_aliases_jinja.baml:26
// |
// 25 |
// 26 | type I = J
// |
12 changes: 3 additions & 9 deletions engine/baml-lib/jinja/src/evaluate_type/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ impl TypeError {
} else {
// If there are multiple close names, suggest them all
let suggestions = close_names.join("`, `");
format!(
"Variable `{name}` does not exist. Did you mean one of these: `{suggestions}`?"
)
format!("Variable `{name}` does not exist. Did you mean one of these: `{suggestions}`?")
};

Self { message, span }
Expand Down Expand Up @@ -137,9 +135,7 @@ impl TypeError {

fn new_wrong_arg_count(func: &str, span: Span, expected: usize, got: usize) -> Self {
Self {
message: format!(
"Function '{func}' expects {expected} arguments, but got {got}"
),
message: format!("Function '{func}' expects {expected} arguments, but got {got}"),
span,
}
}
Expand Down Expand Up @@ -187,9 +183,7 @@ impl TypeError {
} else {
// If there are multiple close names, suggest them all
let suggestions = close_names.join("', '");
format!(
"Filter '{name}' does not exist. Did you mean one of these: '{suggestions}'?"
)
format!("Filter '{name}' does not exist. Did you mean one of these: '{suggestions}'?")
};

Self { message: format!("{message}\n\nSee: https://docs.rs/minijinja/latest/minijinja/filters/index.html#functions for the compelete list"), span }
Expand Down
30 changes: 30 additions & 0 deletions engine/baml-lib/jinja/src/evaluate_type/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,21 @@ pub enum Type {
Both(Box<Type>, Box<Type>),
ClassRef(String),
FunctionRef(String),
/// TODO: This should be `AliasRef(String)` but functions like
/// [`Self::is_subtype_of`] or [`Self::bitor`] don't have access to the
/// [`PredefinedTypes`] instance, so we can't grab type resolutions from
/// there.
///
/// We'll just store all the necessary information in the type itself for
/// now.
Alias {
name: String,
target: Box<Type>,
resolved: Box<Type>,
},
/// TODO: This one could store the target so that we can report what it
/// points to instead of just the name.
RecursiveTypeAlias(String),
Image,
Audio,
}
Expand Down Expand Up @@ -92,6 +107,8 @@ impl Type {

(Type::ClassRef(_), _) => false,
(Type::FunctionRef(_), _) => false,
(Type::Alias { resolved, .. }, _) => resolved.is_subtype_of(other),
(Type::RecursiveTypeAlias(_), _) => false,
(Type::Image, _) => false,
(Type::Audio, _) => false,
(Type::String, _) => false,
Expand Down Expand Up @@ -147,6 +164,10 @@ impl Type {
Type::Both(l, r) => format!("{} & {}", l.name(), r.name()),
Type::ClassRef(name) => format!("class {name}"),
Type::FunctionRef(name) => format!("function {name}"),
Type::Alias { name, resolved, .. } => {
format!("type alias {name} (resolves to {})", resolved.name())
}
Type::RecursiveTypeAlias(name) => format!("recursive type alias {name}"),
Type::Image => "image".into(),
Type::Audio => "audio".into(),
}
Expand Down Expand Up @@ -219,6 +240,10 @@ enum Scope {
pub struct PredefinedTypes {
functions: HashMap<String, (Type, Vec<(String, Type)>)>,
classes: HashMap<String, HashMap<String, Type>>,
/// TODO: See the comment for [`Type::AliasRef`].
///
/// We should use this but we can't without a significant refactor.
aliases: HashMap<String, Type>,
// Variable name <--> Definition
variables: HashMap<String, Type>,
scopes: Vec<Scope>,
Expand Down Expand Up @@ -336,6 +361,7 @@ impl PredefinedTypes {
]),
JinjaContext::Parsing => Default::default(),
},
aliases: HashMap::new(),
scopes: Vec::new(),
errors: Vec::new(),
}
Expand Down Expand Up @@ -449,6 +475,10 @@ impl PredefinedTypes {
self.classes.insert(name.to_string(), fields);
}

pub fn add_alias(&mut self, name: &str, target: Type) {
self.aliases.insert(name.to_string(), target);
}

pub fn add_variable(&mut self, name: &str, t: Type) {
match self.scopes.last_mut() {
Some(Scope::Branch(true_vars, false_vars, branch_cond)) => {
Expand Down
26 changes: 6 additions & 20 deletions engine/baml-lib/parser-database/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ pub use coerce_expression::{coerce, coerce_array, coerce_opt};
pub use internal_baml_schema_ast::ast;
use internal_baml_schema_ast::ast::{FieldType, SchemaAst, WithName};
pub use tarjan::Tarjan;
use types::resolve_type_alias;
pub use types::{
Attributes, ClientProperties, ContantDelayStrategy, ExponentialBackoffStrategy, PrinterType,
PromptAst, PromptVariable, RetryPolicy, RetryPolicyStrategy, StaticType,
Expand Down Expand Up @@ -117,6 +116,10 @@ impl ParserDatabase {
// Second pass: resolve top-level items and field types.
types::resolve_types(&mut ctx);

// Resolve type aliases now because Jinja template validation needs this
// information.
types::resolve_type_aliases(&mut ctx);

// Return early on type resolution errors.
ctx.diagnostics.to_result()?;

Expand All @@ -130,19 +133,6 @@ impl ParserDatabase {
}

fn finalize_dependencies(&mut self, diag: &mut Diagnostics) {
// Cycles left here after cycle validation are allowed. Basically lists
// and maps can introduce cycles.
self.types.structural_recursive_alias_cycles =
Tarjan::components(&self.types.type_alias_dependencies);

// Resolve type aliases.
// Cycles are already validated so this should not stack overflow and
// it should find the final type.
for alias_id in self.types.type_alias_dependencies.keys() {
let resolved = resolve_type_alias(&self.ast[*alias_id].value, &self);
self.types.resolved_type_aliases.insert(*alias_id, resolved);
}

// NOTE: Class dependency cycles are already checked at
// baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs
//
Expand Down Expand Up @@ -254,11 +244,7 @@ impl ParserDatabase {
// Add the resolved name itself to the deps.
collected_deps.insert(ident.name().to_owned());
// If the type is an alias then don't recurse.
if self
.structural_recursive_alias_cycles()
.iter()
.any(|cycle| cycle.contains(&walker.id))
{
if self.is_recursive_type_alias(&walker.id) {
None
} else {
Some(ident.name())
Expand Down Expand Up @@ -358,7 +344,7 @@ mod test {
let db = parse(baml)?;

assert_eq!(
db.structural_recursive_alias_cycles()
db.recursive_alias_cycles()
.iter()
.map(|ids| Vec::from_iter(ids.iter().map(|id| db.ast()[*id].name().to_string())))
.collect::<Vec<_>>(),
Expand Down
Loading
Loading