diff --git a/crates/nargo_cli/tests/test_data/closures_mut_ref/Nargo.toml b/crates/nargo_cli/tests/test_data/closures_mut_ref/Nargo.toml new file mode 100644 index 00000000000..c829bb160b1 --- /dev/null +++ b/crates/nargo_cli/tests/test_data/closures_mut_ref/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "closures_mut_ref" +authors = [""] +compiler_version = "0.8.0" + +[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data/closures_mut_ref/Prover.toml b/crates/nargo_cli/tests/test_data/closures_mut_ref/Prover.toml new file mode 100644 index 00000000000..11497a473bc --- /dev/null +++ b/crates/nargo_cli/tests/test_data/closures_mut_ref/Prover.toml @@ -0,0 +1 @@ +x = "0" diff --git a/crates/nargo_cli/tests/test_data/closures_mut_ref/src/main.nr b/crates/nargo_cli/tests/test_data/closures_mut_ref/src/main.nr new file mode 100644 index 00000000000..ae990e004fd --- /dev/null +++ b/crates/nargo_cli/tests/test_data/closures_mut_ref/src/main.nr @@ -0,0 +1,20 @@ +use dep::std; + +fn main(mut x: Field) { + let one = 1; + let add1 = |z| { + *z = *z + one; + }; + + let two = 2; + let add2 = |z| { + *z = *z + two; + }; + + add1(&mut x); + assert(x == 1); + + add2(&mut x); + assert(x == 3); + +} diff --git a/crates/nargo_cli/tests/test_data/higher_order_fn_selector/Nargo.toml b/crates/nargo_cli/tests/test_data/higher_order_fn_selector/Nargo.toml new file mode 100644 index 00000000000..3c2277e35a5 --- /dev/null +++ b/crates/nargo_cli/tests/test_data/higher_order_fn_selector/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "higher_order_fn_selector" +authors = [""] +compiler_version = "0.8.0" + +[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data/higher_order_fn_selector/src/main.nr b/crates/nargo_cli/tests/test_data/higher_order_fn_selector/src/main.nr new file mode 100644 index 00000000000..767cff0c409 --- /dev/null +++ b/crates/nargo_cli/tests/test_data/higher_order_fn_selector/src/main.nr @@ -0,0 +1,39 @@ +use dep::std; + +fn g(x: &mut Field) -> () { + *x *= 2; +} + +fn h(x: &mut Field) -> () { + *x *= 3; +} + +fn selector(flag: &mut bool) -> fn(&mut Field) -> () { + let my_func = if *flag { + g + } else { + h + }; + + // Flip the flag for the next function call + *flag = !(*flag); + my_func +} + +fn main() { + + let mut flag: bool = true; + + let mut x: Field = 100; + let returned_func = selector(&mut flag); + returned_func(&mut x); + + assert(x == 200); + + let mut y: Field = 100; + let returned_func2 = selector(&mut flag); + returned_func2(&mut y); + + assert(y == 300); + +} diff --git a/crates/nargo_cli/tests/test_data/higher_order_functions/Nargo.toml b/crates/nargo_cli/tests/test_data/higher_order_functions/Nargo.toml new file mode 100644 index 00000000000..cf7526abc7f --- /dev/null +++ b/crates/nargo_cli/tests/test_data/higher_order_functions/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "higher_order_functions" +authors = [""] +compiler_version = "0.1" + +[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data/higher_order_functions/Prover.toml b/crates/nargo_cli/tests/test_data/higher_order_functions/Prover.toml new file mode 100644 index 00000000000..e69de29bb2d diff --git a/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr b/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr new file mode 100644 index 00000000000..fefd23b7dbc --- /dev/null +++ b/crates/nargo_cli/tests/test_data/higher_order_functions/src/main.nr @@ -0,0 +1,87 @@ +use dep::std; + +fn main() -> pub Field { + let f = if 3 * 7 > 200 as u32 { foo } else { bar }; + assert(f()[1] == 2); + // Lambdas: + assert(twice(|x| x * 2, 5) == 20); + assert((|x, y| x + y + 1)(2, 3) == 6); + + // nested lambdas + assert((|a, b| { + a + (|c| c + 2)(b) + })(0, 1) == 3); + + + // Closures: + let a = 42; + let g = || a; + assert(g() == 42); + + // When you copy mutable variables, + // the capture of the copies shouldn't change: + let mut x = 2; + x = x + 1; + let z = x; + + // Add extra mutations to ensure we can mutate x without the + // captured z changing. + x = x + 1; + assert((|y| y + z)(1) == 4); + + // When you capture mutable variables, + // again, the captured variable doesn't change: + let closure_capturing_mutable = (|y| y + x); + assert(closure_capturing_mutable(1) == 5); + x += 1; + assert(closure_capturing_mutable(1) == 5); + + let ret = twice(add1, 3); + + test_array_functions(); + ret +} + +/// Test the array functions in std::array +fn test_array_functions() { + let myarray: [i32; 3] = [1, 2, 3]; + assert(myarray.any(|n| n > 2)); + + let evens: [i32; 3] = [2, 4, 6]; + assert(evens.all(|n| n > 1)); + + assert(evens.fold(0, |a, b| a + b) == 12); + assert(evens.reduce(|a, b| a + b) == 12); + + // TODO: is this a sort_via issue with the new backend, + // or something more general? + // + // currently it fails only with `--experimental-ssa` with + // "not yet implemented: Cast into signed" + // but it worked with the original ssa backend + // (before dropping it) + // + // opened #2121 for it + // https://github.com/noir-lang/noir/issues/2121 + + // let descending = myarray.sort_via(|a, b| a > b); + // assert(descending == [3, 2, 1]); + + assert(evens.map(|n| n / 2) == myarray); +} + +fn foo() -> [u32; 2] { + [1, 3] +} + +fn bar() -> [u32; 2] { + [3, 2] +} + +fn add1(x: Field) -> Field { + x + 1 +} + +fn twice(f: fn(Field) -> Field, x: Field) -> Field { + f(f(x)) +} diff --git a/crates/nargo_cli/tests/test_data/higher_order_functions/target/c.json b/crates/nargo_cli/tests/test_data/higher_order_functions/target/c.json new file mode 100644 index 00000000000..c1233b8160b --- /dev/null +++ b/crates/nargo_cli/tests/test_data/higher_order_functions/target/c.json @@ -0,0 +1 @@ +{"backend":"acvm-backend-barretenberg","abi":{"parameters":[],"param_witnesses":{},"return_type":null,"return_witnesses":[]},"bytecode":[155,194,56,97,194,4,0],"proving_key":null,"verification_key":null} \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data/higher_order_functions/target/main.json b/crates/nargo_cli/tests/test_data/higher_order_functions/target/main.json new file mode 100644 index 00000000000..8d7a1566313 --- /dev/null +++ b/crates/nargo_cli/tests/test_data/higher_order_functions/target/main.json @@ -0,0 +1 @@ +{"backend":"acvm-backend-barretenberg","abi":{"parameters":[{"name":"x","type":{"kind":"integer","sign":"unsigned","width":32},"visibility":"private"},{"name":"y","type":{"kind":"integer","sign":"unsigned","width":32},"visibility":"private"},{"name":"z","type":{"kind":"integer","sign":"unsigned","width":32},"visibility":"private"}],"param_witnesses":{"x":[1],"y":[2],"z":[3]},"return_type":null,"return_witnesses":[]},"bytecode":"H4sIAAAAAAAA/9WUTW6DMBSEJ/yFhoY26bYLjoAxBLPrVYpK7n+EgmoHamWXeShYQsYSvJ+Z9/kDwCf+1m58ArsXi3PgnUN7dt/u7P9fdi8fW8rlATduCW89GFe5l2iMES90YBd+EyTyjIjtGYIm+HF1eanroa0GpdV3WXW9acq66S9GGdWY5qcyWg+mNm3Xd23ZqVoP6tp0+moDJ5AxNOTUWdk6VUTsOSb6wtRPCuDYziaZAzGA92OMFCsAPCUqMAOcQg5gZwIb4BdsA+A9seeU6AtTPymAUzubZA7EAD6MMTKsAPCUqMAMcAY5gJ0JbIBfsQ2AD8SeM6IvTP2kAM7sbJI5EAP4OMbIsQLAU6ICM8A55AB2JrABfsM2AD4Se86Jvjy5freeQ2LPObGud6J+Ce5ADz6LzJqX9Z4W75HdgzszkQj0BC+Pr6PohSpl0kkg7hm84Zfq+8z36N/l9OyaLtcv2EfpKJUUAAA=","proving_key":null,"verification_key":null} \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data/higher_order_functions/target/witness.tr b/crates/nargo_cli/tests/test_data/higher_order_functions/target/witness.tr new file mode 100644 index 00000000000..a539f87a554 Binary files /dev/null and b/crates/nargo_cli/tests/test_data/higher_order_functions/target/witness.tr differ diff --git a/crates/nargo_cli/tests/test_data/inner_outer_cl/Nargo.toml b/crates/nargo_cli/tests/test_data/inner_outer_cl/Nargo.toml new file mode 100644 index 00000000000..1470053df2f --- /dev/null +++ b/crates/nargo_cli/tests/test_data/inner_outer_cl/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "inner_outer_cl" +authors = [""] +compiler_version = "0.7.1" + +[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data/inner_outer_cl/src/main.nr b/crates/nargo_cli/tests/test_data/inner_outer_cl/src/main.nr new file mode 100644 index 00000000000..ce847b56b93 --- /dev/null +++ b/crates/nargo_cli/tests/test_data/inner_outer_cl/src/main.nr @@ -0,0 +1,12 @@ +fn main() { + let z1 = 0; + let z2 = 1; + let cl_outer = |x| { + let cl_inner = |y| { + x + y + z2 + }; + cl_inner(1) + z1 + }; + let result = cl_outer(1); + assert(result == 3); +} diff --git a/crates/nargo_cli/tests/test_data/ret_fn_ret_cl/Nargo.toml b/crates/nargo_cli/tests/test_data/ret_fn_ret_cl/Nargo.toml new file mode 100644 index 00000000000..3e411b2849b --- /dev/null +++ b/crates/nargo_cli/tests/test_data/ret_fn_ret_cl/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "ret_fn_ret_cl" +authors = [""] +compiler_version = "0.7.1" + +[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data/ret_fn_ret_cl/Prover.toml b/crates/nargo_cli/tests/test_data/ret_fn_ret_cl/Prover.toml new file mode 100644 index 00000000000..3a627b9188b --- /dev/null +++ b/crates/nargo_cli/tests/test_data/ret_fn_ret_cl/Prover.toml @@ -0,0 +1 @@ +x = "10" diff --git a/crates/nargo_cli/tests/test_data/ret_fn_ret_cl/src/main.nr b/crates/nargo_cli/tests/test_data/ret_fn_ret_cl/src/main.nr new file mode 100644 index 00000000000..d3a3346b541 --- /dev/null +++ b/crates/nargo_cli/tests/test_data/ret_fn_ret_cl/src/main.nr @@ -0,0 +1,39 @@ +use dep::std; + +fn f(x: Field) -> Field { + x + 1 +} + +fn ret_fn() -> fn(Field) -> Field { + f +} + +// TODO: in the advanced implicitly generic function with closures branch +// which would support higher-order functions in a better way +// support returning closures: +// +// fn ret_closure() -> fn(Field) -> Field { +// let y = 1; +// let inner_closure = |z| -> Field{ +// z + y +// }; +// inner_closure +// } + +fn ret_lambda() -> fn(Field) -> Field { + let cl = |z: Field| -> Field { + z + 1 + }; + cl +} + +fn main(x : Field) { + let result_fn = ret_fn(); + assert(result_fn(x) == x + 1); + + // let result_closure = ret_closure(); + // assert(result_closure(x) == x + 1); + + let result_lambda = ret_lambda(); + assert(result_lambda(x) == x + 1); +} diff --git a/crates/noirc_evaluator/src/ssa/ssa_gen/context.rs b/crates/noirc_evaluator/src/ssa/ssa_gen/context.rs index 3e0bbff2a83..c3578e5ee7e 100644 --- a/crates/noirc_evaluator/src/ssa/ssa_gen/context.rs +++ b/crates/noirc_evaluator/src/ssa/ssa_gen/context.rs @@ -218,7 +218,7 @@ impl<'a> FunctionContext<'a> { } ast::Type::Unit => panic!("convert_non_tuple_type called on a unit type"), ast::Type::Tuple(_) => panic!("convert_non_tuple_type called on a tuple: {typ}"), - ast::Type::Function(_, _) => Type::Function, + ast::Type::Function(_, _, _) => Type::Function, ast::Type::Slice(element) => { let element_types = Self::convert_type(element).flatten(); Type::Slice(Rc::new(element_types)) diff --git a/crates/noirc_frontend/src/hir/def_collector/dc_crate.rs b/crates/noirc_frontend/src/hir/def_collector/dc_crate.rs index 76fbea289be..2beebf6871c 100644 --- a/crates/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/crates/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -12,8 +12,8 @@ use crate::hir::type_check::{type_check_func, TypeChecker}; use crate::hir::Context; use crate::node_interner::{FuncId, NodeInterner, StmtId, StructId, TypeAliasId}; use crate::{ - ExpressionKind, Generics, Ident, LetStatement, NoirFunction, NoirStruct, NoirTypeAlias, - ParsedModule, Shared, Type, TypeBinding, UnresolvedGenerics, UnresolvedType, Literal, + ExpressionKind, Generics, Ident, LetStatement, Literal, NoirFunction, NoirStruct, + NoirTypeAlias, ParsedModule, Shared, Type, TypeBinding, UnresolvedGenerics, UnresolvedType, }; use fm::FileId; use iter_extended::vecmap; diff --git a/crates/noirc_frontend/src/hir/resolution/resolver.rs b/crates/noirc_frontend/src/hir/resolution/resolver.rs index 8b4f97dbd8e..681c853899f 100644 --- a/crates/noirc_frontend/src/hir/resolution/resolver.rs +++ b/crates/noirc_frontend/src/hir/resolution/resolver.rs @@ -12,10 +12,10 @@ // // XXX: Resolver does not check for unused functions use crate::hir_def::expr::{ - HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCastExpression, - HirConstructorExpression, HirExpression, HirForExpression, HirIdent, HirIfExpression, - HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral, HirMemberAccess, - HirMethodCallExpression, HirPrefixExpression, + HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCapturedVar, + HirCastExpression, HirConstructorExpression, HirExpression, HirForExpression, HirIdent, + HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral, + HirMemberAccess, HirMethodCallExpression, HirPrefixExpression, }; use crate::token::Attribute; use regex::Regex; @@ -58,6 +58,13 @@ type Scope = GenericScope; type ScopeTree = GenericScopeTree; type ScopeForest = GenericScopeForest; +pub struct LambdaContext { + captures: Vec, + /// the index in the scope tree + /// (sometimes being filled by ScopeTree's find method) + scope_index: usize, +} + /// The primary jobs of the Resolver are to validate that every variable found refers to exactly 1 /// definition in scope, and to convert the AST into the HIR. /// @@ -81,12 +88,10 @@ pub struct Resolver<'a> { /// were declared in. generics: Vec<(Rc, TypeVariable, Span)>, - /// Lambdas share the function scope of the function they're defined in, - /// so to identify whether they use any variables from the parent function - /// we keep track of the scope index a variable is declared in. When a lambda - /// is declared we push a scope and set this lambda_index to the scope index. - /// Any variable from a scope less than that must be from the parent function. - lambda_index: usize, + /// When resolving lambda expressions, we need to keep track of the variables + /// that are captured. We do this in order to create the hidden environment + /// parameter for the lambda function. + lambda_stack: Vec, } /// ResolverMetas are tagged onto each definition to track how many times they are used @@ -112,7 +117,7 @@ impl<'a> Resolver<'a> { self_type: None, generics: Vec::new(), errors: Vec::new(), - lambda_index: 0, + lambda_stack: Vec::new(), file, } } @@ -125,10 +130,6 @@ impl<'a> Resolver<'a> { self.errors.push(err); } - fn current_lambda_index(&self) -> usize { - self.scopes.current_scope_index() - } - /// Resolving a function involves interning the metadata /// interning any statements inside of the function /// and interning the function itself @@ -279,25 +280,25 @@ impl<'a> Resolver<'a> { // // If a variable is not found, then an error is logged and a dummy id // is returned, for better error reporting UX - fn find_variable_or_default(&mut self, name: &Ident) -> HirIdent { + fn find_variable_or_default(&mut self, name: &Ident) -> (HirIdent, usize) { self.find_variable(name).unwrap_or_else(|error| { self.push_err(error); let id = DefinitionId::dummy_id(); let location = Location::new(name.span(), self.file); - HirIdent { location, id } + (HirIdent { location, id }, 0) }) } - fn find_variable(&mut self, name: &Ident) -> Result { + fn find_variable(&mut self, name: &Ident) -> Result<(HirIdent, usize), ResolverError> { // Find the definition for this Ident let scope_tree = self.scopes.current_scope_tree(); let variable = scope_tree.find(&name.0.contents); let location = Location::new(name.span(), self.file); - if let Some((variable_found, _)) = variable { + if let Some((variable_found, scope)) = variable { variable_found.num_times_used += 1; let id = variable_found.ident.id; - Ok(HirIdent { location, id }) + Ok((HirIdent { location, id }, scope)) } else { Err(ResolverError::VariableNotDeclared { name: name.0.contents.clone(), @@ -363,7 +364,8 @@ impl<'a> Resolver<'a> { UnresolvedType::Function(args, ret) => { let args = vecmap(args, |arg| self.resolve_type_inner(arg, new_variables)); let ret = Box::new(self.resolve_type_inner(*ret, new_variables)); - Type::Function(args, ret) + let env = Box::new(Type::Unit); + Type::Function(args, ret, env) } UnresolvedType::MutableReference(element) => { Type::MutableReference(Box::new(self.resolve_type_inner(*element, new_variables))) @@ -517,24 +519,24 @@ impl<'a> Resolver<'a> { } } - fn get_ident_from_path(&mut self, path: Path) -> HirIdent { + fn get_ident_from_path(&mut self, path: Path) -> (HirIdent, usize) { let location = Location::new(path.span(), self.file); let error = match path.as_ident().map(|ident| self.find_variable(ident)) { - Some(Ok(ident)) => return ident, + Some(Ok(found)) => return found, // Try to look it up as a global, but still issue the first error if we fail Some(Err(error)) => match self.lookup_global(path) { - Ok(id) => return HirIdent { location, id }, + Ok(id) => return (HirIdent { location, id }, 0), Err(_) => error, }, None => match self.lookup_global(path) { - Ok(id) => return HirIdent { location, id }, + Ok(id) => return (HirIdent { location, id }, 0), Err(error) => error, }, }; self.push_err(error); let id = DefinitionId::dummy_id(); - HirIdent { location, id } + (HirIdent { location, id }, 0) } /// Translates an UnresolvedType to a Type @@ -705,7 +707,7 @@ impl<'a> Resolver<'a> { }); } - let mut typ = Type::Function(parameter_types, return_type); + let mut typ = Type::Function(parameter_types, return_type, Box::new(Type::Unit)); if !generics.is_empty() { typ = Type::Forall(generics, Box::new(typ)); @@ -837,12 +839,14 @@ impl<'a> Resolver<'a> { Self::find_numeric_generics_in_type(field, found); } } - Type::Function(parameters, return_type) => { + + Type::Function(parameters, return_type, _env) => { for parameter in parameters { Self::find_numeric_generics_in_type(parameter, found); } Self::find_numeric_generics_in_type(return_type, found); } + Type::Struct(struct_type, generics) => { for (i, generic) in generics.iter().enumerate() { if let Type::NamedGeneric(type_variable, name) = generic { @@ -915,7 +919,7 @@ impl<'a> Resolver<'a> { fn resolve_lvalue(&mut self, lvalue: LValue) -> HirLValue { match lvalue { LValue::Ident(ident) => { - HirLValue::Ident(self.find_variable_or_default(&ident), Type::Error) + HirLValue::Ident(self.find_variable_or_default(&ident).0, Type::Error) } LValue::MemberAccess { object, field_name } => { let object = Box::new(self.resolve_lvalue(*object)); @@ -933,6 +937,39 @@ impl<'a> Resolver<'a> { } } + fn resolve_local_variable(&mut self, hir_ident: HirIdent, var_scope_index: usize) { + let mut transitive_capture_index: Option = None; + + for lambda_index in 0..self.lambda_stack.len() { + if self.lambda_stack[lambda_index].scope_index > var_scope_index { + // Beware: the same variable may be captured multiple times, so we check + // for its presence before adding the capture below. + let pos = self.lambda_stack[lambda_index] + .captures + .iter() + .position(|capture| capture.ident.id == hir_ident.id); + + if pos.is_none() { + self.lambda_stack[lambda_index] + .captures + .push(HirCapturedVar { ident: hir_ident, transitive_capture_index }); + } + + if lambda_index + 1 < self.lambda_stack.len() { + // There is more than one closure between the current scope and + // the scope of the variable, so this is a propagated capture. + // We need to track the transitive capture index as we go up in + // the closure stack. + transitive_capture_index = Some(pos.unwrap_or( + // If this was a fresh capture, we added it to the end of + // the captures vector: + self.lambda_stack[lambda_index].captures.len() - 1, + )); + } + } + } + } + pub fn resolve_expression(&mut self, expr: Expression) -> ExprId { let hir_expr = match expr.kind { ExpressionKind::Literal(literal) => HirExpression::Literal(match literal { @@ -965,7 +1002,20 @@ impl<'a> Resolver<'a> { // Otherwise, then it is referring to an Identifier // This lookup allows support of such statements: let x = foo::bar::SOME_GLOBAL + 10; // If the expression is a singular indent, we search the resolver's current scope as normal. - let hir_ident = self.get_ident_from_path(path); + let (hir_ident, var_scope_index) = self.get_ident_from_path(path); + + if hir_ident.id != DefinitionId::dummy_id() { + match self.interner.definition(hir_ident.id).kind { + DefinitionKind::Function(_) => {} + DefinitionKind::Global(_) => {} + DefinitionKind::GenericType(_) => {} + // We ignore the above definition kinds because only local variables can be captured by closures. + DefinitionKind::Local(_) => { + self.resolve_local_variable(hir_ident, var_scope_index); + } + } + } + HirExpression::Ident(hir_ident) } ExpressionKind::Prefix(prefix) => { @@ -1087,8 +1137,9 @@ impl<'a> Resolver<'a> { // We must stay in the same function scope as the parent function to allow for closures // to capture variables. This is currently limited to immutable variables. ExpressionKind::Lambda(lambda) => self.in_new_scope(|this| { - let new_index = this.current_lambda_index(); - let old_index = std::mem::replace(&mut this.lambda_index, new_index); + let scope_index = this.scopes.current_scope_index(); + + this.lambda_stack.push(LambdaContext { captures: Vec::new(), scope_index }); let parameters = vecmap(lambda.parameters, |(pattern, typ)| { let parameter = DefinitionKind::Local(None); @@ -1098,8 +1149,14 @@ impl<'a> Resolver<'a> { let return_type = this.resolve_inferred_type(lambda.return_type); let body = this.resolve_expression(lambda.body); - this.lambda_index = old_index; - HirExpression::Lambda(HirLambda { parameters, return_type, body }) + let lambda_context = this.lambda_stack.pop().unwrap(); + + HirExpression::Lambda(HirLambda { + parameters, + return_type, + body, + captures: lambda_context.captures, + }) }), }; @@ -1411,6 +1468,7 @@ pub fn verify_mutable_reference(interner: &NodeInterner, rhs: ExprId) -> Result< #[cfg(test)] mod test { + use core::panic; use std::collections::HashMap; use fm::FileId; @@ -1419,10 +1477,14 @@ mod test { use crate::hir::def_map::{ModuleData, ModuleId, ModuleOrigin}; use crate::hir::resolution::errors::ResolverError; use crate::hir::resolution::import::PathResolutionError; + use crate::hir::resolution::resolver::StmtId; use crate::graph::CrateId; + use crate::hir_def::expr::HirExpression; use crate::hir_def::function::HirFunction; + use crate::hir_def::stmt::HirStatement; use crate::node_interner::{FuncId, NodeInterner}; + use crate::ParsedModule; use crate::{ hir::def_map::{CrateDefMap, LocalModuleId, ModuleDefId}, parse_program, Path, @@ -1432,29 +1494,24 @@ mod test { // func_namespace is used to emulate the fact that functions can be imported // and functions can be forward declared - fn resolve_src_code(src: &str, func_namespace: Vec<&str>) -> Vec { + fn init_src_code_resolution( + src: &str, + ) -> (ParsedModule, NodeInterner, HashMap, FileId, TestPathResolver) { let (program, errors) = parse_program(src); - assert!(errors.is_empty()); - - let mut interner = NodeInterner::default(); - - let func_ids = vecmap(&func_namespace, |name| { - let id = interner.push_fn(HirFunction::empty()); - interner.push_function_definition(name.to_string(), id); - id - }); - - let mut path_resolver = TestPathResolver(HashMap::new()); - for (name, id) in func_namespace.into_iter().zip(func_ids) { - path_resolver.insert_func(name.to_owned(), id); + if !errors.is_empty() { + panic!("Unexpected parse errors in test code: {:?}", errors); } + let interner: NodeInterner = NodeInterner::default(); + let mut def_maps: HashMap = HashMap::new(); let file = FileId::default(); let mut modules = arena::Arena::new(); modules.insert(ModuleData::new(None, ModuleOrigin::File(file), false)); + let path_resolver = TestPathResolver(HashMap::new()); + def_maps.insert( CrateId::dummy_id(), CrateDefMap { @@ -1465,10 +1522,30 @@ mod test { }, ); + (program, interner, def_maps, file, path_resolver) + } + + // func_namespace is used to emulate the fact that functions can be imported + // and functions can be forward declared + fn resolve_src_code(src: &str, func_namespace: Vec<&str>) -> Vec { + let (program, mut interner, def_maps, file, mut path_resolver) = + init_src_code_resolution(src); + + let func_ids = vecmap(&func_namespace, |name| { + let id = interner.push_fn(HirFunction::empty()); + interner.push_function_definition(name.to_string(), id); + id + }); + + for (name, id) in func_namespace.into_iter().zip(func_ids) { + path_resolver.insert_func(name.to_owned(), id); + } + let mut errors = Vec::new(); for func in program.functions { let id = interner.push_fn(HirFunction::empty()); interner.push_function_definition(func.name().to_string(), id); + let resolver = Resolver::new(&mut interner, &path_resolver, &def_maps, file); let (_, _, err) = resolver.resolve_function(func, id, ModuleId::dummy_id()); errors.extend(err); @@ -1477,6 +1554,81 @@ mod test { errors } + fn get_program_captures(src: &str) -> Vec> { + let (program, mut interner, def_maps, file, mut path_resolver) = + init_src_code_resolution(src); + + let mut all_captures: Vec> = Vec::new(); + for func in program.functions { + let id = interner.push_fn(HirFunction::empty()); + interner.push_function_definition(func.name().clone().to_string(), id); + path_resolver.insert_func(func.name().to_owned(), id); + + let resolver = Resolver::new(&mut interner, &path_resolver, &def_maps, file); + let (hir_func, _, _) = resolver.resolve_function(func, id, ModuleId::dummy_id()); + + // Iterate over function statements and apply filtering function + parse_statement_blocks( + hir_func.block(&interner).statements(), + &interner, + &mut all_captures, + ); + } + all_captures + } + + fn parse_statement_blocks( + stmts: &[StmtId], + interner: &NodeInterner, + result: &mut Vec>, + ) { + let mut expr: HirExpression; + + for stmt_id in stmts.iter() { + let hir_stmt = interner.statement(stmt_id); + match hir_stmt { + HirStatement::Expression(expr_id) => { + expr = interner.expression(&expr_id); + } + HirStatement::Let(let_stmt) => { + expr = interner.expression(&let_stmt.expression); + } + HirStatement::Assign(assign_stmt) => { + expr = interner.expression(&assign_stmt.expression); + } + HirStatement::Constrain(constr_stmt) => { + expr = interner.expression(&constr_stmt.0); + } + HirStatement::Semi(semi_expr) => { + expr = interner.expression(&semi_expr); + } + HirStatement::Error => panic!("Invalid HirStatement!"), + } + get_lambda_captures(expr, &interner, result); // TODO: dyn filter function as parameter + } + } + + fn get_lambda_captures( + expr: HirExpression, + interner: &NodeInterner, + result: &mut Vec>, + ) { + if let HirExpression::Lambda(lambda_expr) = expr { + let mut cur_capture = Vec::new(); + + for capture in lambda_expr.captures.iter() { + cur_capture.push(interner.definition(capture.ident.id).name.clone()); + } + result.push(cur_capture); + + // Check for other captures recursively within the lambda body + let hir_body_expr = interner.expression(&lambda_expr.body); + if let HirExpression::Block(block_expr) = hir_body_expr.clone() { + parse_statement_blocks(block_expr.statements(), interner, result); + } + } + } + #[test] fn resolve_empty_function() { let src = " @@ -1656,9 +1808,103 @@ mod test { x } "#; + let errors = resolve_src_code(src, vec!["main", "foo"]); + if !errors.is_empty() { + println!("Unexpected errors: {:?}", errors); + assert!(false); // there should be no errors + } + } + + #[test] + fn resolve_basic_closure() { + let src = r#" + fn main(x : Field) -> pub Field { + let closure = |y| y + x; + closure(x) + } + "#; + + let errors = resolve_src_code(src, vec!["main", "foo"]); + if !errors.is_empty() { + panic!("Unexpected errors: {:?}", errors); + } + } + + #[test] + fn resolve_simplified_closure() { + // based on bug https://github.com/noir-lang/noir/issues/1088 + + let src = r#"fn do_closure(x: Field) -> Field { + let y = x; + let ret_capture = || { + y + }; + ret_capture() + } + + fn main(x: Field) { + assert(do_closure(x) == 100); + } + + "#; + let parsed_captures = get_program_captures(src); + let mut expected_captures = vec![]; + expected_captures.push(vec!["y".to_string()]); + assert_eq!(expected_captures, parsed_captures); + } + + #[test] + fn resolve_complex_closures() { + let src = r#" + fn main(x: Field) -> pub Field { + let closure_without_captures = |x| x + x; + let a = closure_without_captures(1); + + let closure_capturing_a_param = |y| y + x; + let b = closure_capturing_a_param(2); + + let closure_capturing_a_local_var = |y| y + b; + let c = closure_capturing_a_local_var(3); + + let closure_with_transitive_captures = |y| { + let d = 5; + let nested_closure = |z| { + let doubly_nested_closure = |w| w + x + b; + a + z + y + d + x + doubly_nested_closure(4) + x + y + }; + let res = nested_closure(5); + res + }; + + a + b + c + closure_with_transitive_captures(6) + } + "#; let errors = resolve_src_code(src, vec!["main", "foo"]); assert!(errors.is_empty()); + if !errors.is_empty() { + println!("Unexpected errors: {:?}", errors); + assert!(false); // there should be no errors + } + + let expected_captures = vec![ + vec![], + vec!["x".to_string()], + vec!["b".to_string()], + vec!["x".to_string(), "b".to_string(), "a".to_string()], + vec![ + "x".to_string(), + "b".to_string(), + "a".to_string(), + "y".to_string(), + "d".to_string(), + ], + vec!["x".to_string(), "b".to_string()], + ]; + + let parsed_captures = get_program_captures(src); + + assert_eq!(expected_captures, parsed_captures); } #[test] @@ -1694,6 +1940,9 @@ mod test { } } + // possible TODO: Create a more sophisticated set of search functions over the HIR, so we can check + // that the correct variables are captured in each closure + fn path_unresolved_error(err: ResolverError, expected_unresolved_path: &str) { match err { ResolverError::PathResolutionError(PathResolutionError::Unresolved(name)) => { diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index 24ac5f3443e..6c111a1d6a0 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -279,6 +279,12 @@ impl<'interner> TypeChecker<'interner> { Type::Tuple(vecmap(&elements, |elem| self.check_expression(elem))) } HirExpression::Lambda(lambda) => { + let captured_vars = + vecmap(lambda.captures, |capture| self.interner.id_type(capture.ident.id)); + + let env_type: Type = + if captured_vars.is_empty() { Type::Unit } else { Type::Tuple(captured_vars) }; + let params = vecmap(lambda.parameters, |(pattern, typ)| { self.bind_pattern(&pattern, typ.clone()); typ @@ -294,7 +300,8 @@ impl<'interner> TypeChecker<'interner> { expr_span: span, } }); - Type::Function(params, Box::new(lambda.return_type)) + + Type::Function(params, Box::new(lambda.return_type), Box::new(env_type)) } }; @@ -319,9 +326,9 @@ impl<'interner> TypeChecker<'interner> { argument_types: &mut [(Type, ExprId, noirc_errors::Span)], ) { let expected_object_type = match function_type { - Type::Function(args, _) => args.get(0), + Type::Function(args, _, _) => args.get(0), Type::Forall(_, typ) => match typ.as_ref() { - Type::Function(args, _) => args.get(0), + Type::Function(args, _, _) => args.get(0), typ => unreachable!("Unexpected type for function: {typ}"), }, typ => unreachable!("Unexpected type for function: {typ}"), @@ -870,6 +877,35 @@ impl<'interner> TypeChecker<'interner> { } } + fn bind_function_type_impl( + &mut self, + fn_params: &Vec, + fn_ret: &Type, + callsite_args: &Vec<(Type, ExprId, Span)>, + span: Span, + ) -> Type { + if fn_params.len() != callsite_args.len() { + self.errors.push(TypeCheckError::ParameterCountMismatch { + expected: fn_params.len(), + found: callsite_args.len(), + span, + }); + return Type::Error; + } + + for (param, (arg, _, arg_span)) in fn_params.iter().zip(callsite_args) { + arg.make_subtype_of(param, *arg_span, &mut self.errors, || { + TypeCheckError::TypeMismatch { + expected_typ: param.to_string(), + expr_typ: arg.to_string(), + expr_span: *arg_span, + } + }); + } + + fn_ret.clone() + } + fn bind_function_type( &mut self, function: Type, @@ -886,38 +922,17 @@ impl<'interner> TypeChecker<'interner> { let ret = self.interner.next_type_variable(); let args = vecmap(args, |(arg, _, _)| arg); - let expected = Type::Function(args, Box::new(ret.clone())); + let env_type = self.interner.next_type_variable(); + let expected = Type::Function(args, Box::new(ret.clone()), Box::new(env_type)); if let Err(error) = binding.borrow_mut().bind_to(expected, span) { self.errors.push(error); } ret } - Type::Function(parameters, ret) => { - if parameters.len() != args.len() { - self.errors.push(TypeCheckError::ParameterCountMismatch { - expected: parameters.len(), - found: args.len(), - span, - }); - return Type::Error; - } - - for (param, (arg, arg_id, arg_span)) in parameters.iter().zip(args) { - arg.make_subtype_with_coercions( - param, - arg_id, - self.interner, - &mut self.errors, - || TypeCheckError::TypeMismatch { - expected_typ: param.to_string(), - expr_typ: arg.to_string(), - expr_span: arg_span, - }, - ); - } - - *ret + Type::Function(parameters, ret, _env) => { + // ignoring env for subtype on purpose + self.bind_function_type_impl(parameters.as_ref(), ret.as_ref(), args.as_ref(), span) } Type::Error => Type::Error, found => { diff --git a/crates/noirc_frontend/src/hir/type_check/mod.rs b/crates/noirc_frontend/src/hir/type_check/mod.rs index 26d0e36abf9..1883c0abf62 100644 --- a/crates/noirc_frontend/src/hir/type_check/mod.rs +++ b/crates/noirc_frontend/src/hir/type_check/mod.rs @@ -152,6 +152,7 @@ impl<'interner> TypeChecker<'interner> { #[cfg(test)] mod test { use std::collections::HashMap; + use std::vec; use fm::FileId; use iter_extended::vecmap; @@ -245,7 +246,11 @@ mod test { contract_function_type: None, is_internal: None, is_unconstrained: false, - typ: Type::Function(vec![Type::field(None), Type::field(None)], Box::new(Type::Unit)), + typ: Type::Function( + vec![Type::field(None), Type::field(None)], + Box::new(Type::Unit), + Box::new(Type::Unit), + ), parameters: vec![ Param(Identifier(x), Type::field(None), noirc_abi::AbiVisibility::Private), Param(Identifier(y), Type::field(None), noirc_abi::AbiVisibility::Private), @@ -314,7 +319,29 @@ mod test { type_check_src_code(src, vec![String::from("main"), String::from("foo")]); } + #[test] + fn basic_closure() { + let src = r#" + fn main(x : Field) -> pub Field { + let closure = |y| y + x; + closure(x) + } + "#; + + type_check_src_code(src, vec![String::from("main"), String::from("foo")]); + } + #[test] + fn closure_with_no_args() { + let src = r#" + fn main(x : Field) -> pub Field { + let closure = || x; + closure() + } + "#; + + type_check_src_code(src, vec![String::from("main")]); + } // This is the same Stub that is in the resolver, maybe we can pull this out into a test module and re-use? struct TestPathResolver(HashMap); diff --git a/crates/noirc_frontend/src/hir_def/expr.rs b/crates/noirc_frontend/src/hir_def/expr.rs index db7db0a803d..fd980328f5f 100644 --- a/crates/noirc_frontend/src/hir_def/expr.rs +++ b/crates/noirc_frontend/src/hir_def/expr.rs @@ -197,9 +197,25 @@ impl HirBlockExpression { } } +/// A variable captured inside a closure +#[derive(Debug, Clone)] +pub struct HirCapturedVar { + pub ident: HirIdent, + + /// This will be None when the capture refers to a local variable declared + /// in the same scope as the closure. In a closure-inside-another-closure + /// scenarios, we might have a transitive captures of variables that must + /// be propagated during the construction of each closure. In this case, + /// we store the index of the captured variable in the environment of our + /// direct parent closure. We do this in order to simplify the HIR to AST + /// transformation in the monomorphization pass. + pub transitive_capture_index: Option, +} + #[derive(Debug, Clone)] pub struct HirLambda { pub parameters: Vec<(HirPattern, Type)>, pub return_type: Type, pub body: ExprId, + pub captures: Vec, } diff --git a/crates/noirc_frontend/src/hir_def/function.rs b/crates/noirc_frontend/src/hir_def/function.rs index a69e8bb08b5..225731626f0 100644 --- a/crates/noirc_frontend/src/hir_def/function.rs +++ b/crates/noirc_frontend/src/hir_def/function.rs @@ -180,9 +180,9 @@ impl FuncMeta { /// Gives the (uninstantiated) return type of this function. pub fn return_type(&self) -> &Type { match &self.typ { - Type::Function(_, ret) => ret, + Type::Function(_, ret, _env) => ret, Type::Forall(_, typ) => match typ.as_ref() { - Type::Function(_, ret) => ret, + Type::Function(_, ret, _env) => ret, _ => unreachable!(), }, _ => unreachable!(), diff --git a/crates/noirc_frontend/src/hir_def/types.rs b/crates/noirc_frontend/src/hir_def/types.rs index ff0a4e53fae..d77b8033ba1 100644 --- a/crates/noirc_frontend/src/hir_def/types.rs +++ b/crates/noirc_frontend/src/hir_def/types.rs @@ -70,8 +70,11 @@ pub enum Type { /// like `fn foo(...) {}`. Unlike TypeVariables, they cannot be bound over. NamedGeneric(TypeVariable, Rc), - /// A functions with arguments, and a return type. - Function(Vec, Box), + /// A functions with arguments, a return type and environment. + /// the environment should be `Unit` by default, + /// for closures it should contain a `Tuple` type with the captured + /// variable types. + Function(Vec, Box, Box), /// &mut T MutableReference(Box), @@ -697,9 +700,10 @@ impl Type { Type::Tuple(fields) => { fields.iter().any(|field| field.contains_numeric_typevar(target_id)) } - Type::Function(parameters, return_type) => { + Type::Function(parameters, return_type, env) => { parameters.iter().any(|parameter| parameter.contains_numeric_typevar(target_id)) || return_type.contains_numeric_typevar(target_id) + || env.contains_numeric_typevar(target_id) } Type::Struct(struct_type, generics) => { generics.iter().enumerate().any(|(i, generic)| { @@ -797,9 +801,15 @@ impl std::fmt::Display for Type { let typevars = vecmap(typevars, |(var, _)| var.to_string()); write!(f, "forall {}. {}", typevars.join(" "), typ) } - Type::Function(args, ret) => { - let args = vecmap(args, ToString::to_string); - write!(f, "fn({}) -> {}", args.join(", "), ret) + Type::Function(args, ret, env) => { + let closure_env_text = match **env { + Type::Unit => "".to_string(), + _ => format!(" with closure environment {env}"), + }; + + let args = vecmap(args.iter(), ToString::to_string); + + write!(f, "fn({}) -> {ret}{closure_env_text}", args.join(", ")) } Type::MutableReference(element) => { write!(f, "&mut {element}") @@ -1196,9 +1206,9 @@ impl Type { } } - (Function(params_a, ret_a), Function(params_b, ret_b)) => { + (Function(params_a, ret_a, _env_a), Function(params_b, ret_b, _env_b)) => { if params_a.len() == params_b.len() { - for (a, b) in params_a.iter().zip(params_b) { + for (a, b) in params_a.iter().zip(params_b.iter()) { a.try_unify(b, span)?; } @@ -1403,7 +1413,7 @@ impl Type { } } - (Function(params_a, ret_a), Function(params_b, ret_b)) => { + (Function(params_a, ret_a, _env_a), Function(params_b, ret_b, _env_b)) => { if params_a.len() == params_b.len() { for (a, b) in params_a.iter().zip(params_b) { a.is_subtype_of(b, span)?; @@ -1505,7 +1515,7 @@ impl Type { Type::TypeVariable(_, _) => unreachable!(), Type::NamedGeneric(..) => unreachable!(), Type::Forall(..) => unreachable!(), - Type::Function(_, _) => unreachable!(), + Type::Function(_, _, _) => unreachable!(), Type::MutableReference(_) => unreachable!("&mut cannot be used in the abi"), Type::NotConstant => unreachable!(), } @@ -1620,10 +1630,11 @@ impl Type { let typ = Box::new(typ.substitute(type_bindings)); Type::Forall(typevars.clone(), typ) } - Type::Function(args, ret) => { + Type::Function(args, ret, env) => { let args = vecmap(args, |arg| arg.substitute(type_bindings)); let ret = Box::new(ret.substitute(type_bindings)); - Type::Function(args, ret) + let env = Box::new(env.substitute(type_bindings)); + Type::Function(args, ret, env) } Type::MutableReference(element) => { Type::MutableReference(Box::new(element.substitute(type_bindings))) @@ -1660,8 +1671,10 @@ impl Type { Type::Forall(typevars, typ) => { !typevars.iter().any(|(id, _)| *id == target_id) && typ.occurs(target_id) } - Type::Function(args, ret) => { - args.iter().any(|arg| arg.occurs(target_id)) || ret.occurs(target_id) + Type::Function(args, ret, env) => { + args.iter().any(|arg| arg.occurs(target_id)) + || ret.occurs(target_id) + || env.occurs(target_id) } Type::MutableReference(element) => element.occurs(target_id), @@ -1706,11 +1719,13 @@ impl Type { self.clone() } - Function(args, ret) => { + Function(args, ret, env) => { let args = vecmap(args, |arg| arg.follow_bindings()); let ret = Box::new(ret.follow_bindings()); - Function(args, ret) + let env = Box::new(env.follow_bindings()); + Function(args, ret, env) } + MutableReference(element) => MutableReference(Box::new(element.follow_bindings())), // Expect that this function should only be called on instantiated types @@ -1751,7 +1766,10 @@ fn convert_array_expression_to_slice( interner.push_expr_location(func, location.span, location.file); interner.push_expr_type(&call, target_type.clone()); - interner.push_expr_type(&func, Type::Function(vec![array_type], Box::new(target_type))); + interner.push_expr_type( + &func, + Type::Function(vec![array_type], Box::new(target_type), Box::new(Type::Unit)), + ); } impl BinaryTypeOperator { diff --git a/crates/noirc_frontend/src/monomorphization/ast.rs b/crates/noirc_frontend/src/monomorphization/ast.rs index 7ad05f09231..33c3bbebff4 100644 --- a/crates/noirc_frontend/src/monomorphization/ast.rs +++ b/crates/noirc_frontend/src/monomorphization/ast.rs @@ -29,7 +29,6 @@ pub enum Expression { Tuple(Vec), ExtractTupleField(Box, usize), Call(Call), - Let(Let), Constrain(Box, Location), Assign(Assign), @@ -103,6 +102,12 @@ pub struct Binary { pub location: Location, } +#[derive(Debug, Clone)] +pub struct Lambda { + pub function: Ident, + pub env: Ident, +} + #[derive(Debug, Clone)] pub struct If { pub condition: Box, @@ -213,7 +218,7 @@ pub enum Type { Tuple(Vec), Slice(Box), MutableReference(Box), - Function(/*args:*/ Vec, /*ret:*/ Box), + Function(/*args:*/ Vec, /*ret:*/ Box, /*env:*/ Box), } impl Type { @@ -324,9 +329,13 @@ impl std::fmt::Display for Type { let elements = vecmap(elements, ToString::to_string); write!(f, "({})", elements.join(", ")) } - Type::Function(args, ret) => { + Type::Function(args, ret, env) => { let args = vecmap(args, ToString::to_string); - write!(f, "fn({}) -> {}", args.join(", "), ret) + let closure_env_text = match **env { + Type::Unit => "".to_string(), + _ => format!(" with closure environment {env}"), + }; + write!(f, "fn({}) -> {}{}", args.join(", "), ret, closure_env_text) } Type::Slice(element) => write!(f, "[{element}"), Type::MutableReference(element) => write!(f, "&mut {element}"), diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index dbe2ee080bf..c8167baf6bb 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -19,6 +19,7 @@ use crate::{ expr::*, function::{FuncMeta, Param, Parameters}, stmt::{HirAssignStatement, HirLValue, HirLetStatement, HirPattern, HirStatement}, + types, }, node_interner::{self, DefinitionKind, NodeInterner, StmtId}, token::Attribute, @@ -30,6 +31,11 @@ use self::ast::{Definition, FuncId, Function, LocalId, Program}; pub mod ast; pub mod printer; +struct LambdaContext { + env_ident: Box, + captures: Vec, +} + /// The context struct for the monomorphization pass. /// /// This struct holds the FIFO queue of functions to monomorphize, which is added to @@ -58,6 +64,8 @@ struct Monomorphizer<'interner> { /// Used to reference existing definitions in the HIR interner: &'interner NodeInterner, + lambda_envs_stack: Vec, + next_local_id: u32, next_function_id: u32, } @@ -103,6 +111,7 @@ impl<'interner> Monomorphizer<'interner> { next_local_id: 0, next_function_id: 0, interner, + lambda_envs_stack: Vec::new(), } } @@ -348,7 +357,7 @@ impl<'interner> Monomorphizer<'interner> { } HirExpression::Constructor(constructor) => self.constructor(constructor, expr), - HirExpression::Lambda(lambda) => self.lambda(lambda), + HirExpression::Lambda(lambda) => self.lambda(lambda, expr), HirExpression::MethodCall(_) => { unreachable!("Encountered HirExpression::MethodCall during monomorphization") @@ -541,6 +550,15 @@ impl<'interner> Monomorphizer<'interner> { ast::Expression::Block(definitions) } + /// Find a captured variable in the innermost closure + fn lookup_captured(&mut self, id: node_interner::DefinitionId) -> Option { + let ctx = self.lambda_envs_stack.last()?; + ctx.captures + .iter() + .position(|capture| capture.ident.id == id) + .map(|index| ast::Expression::ExtractTupleField(ctx.env_ident.clone(), index)) + } + /// A local (ie non-global) ident only fn local_ident(&mut self, ident: &HirIdent) -> Option { let definition = self.interner.definition(ident.id); @@ -564,14 +582,25 @@ impl<'interner> Monomorphizer<'interner> { let definition = self.lookup_function(*func_id, expr_id, &typ); let typ = Self::convert_type(&typ); - let ident = ast::Ident { location, mutable, definition, name, typ }; - ast::Expression::Ident(ident) + let ident = ast::Ident { location, mutable, definition, name, typ: typ.clone() }; + let ident_expression = ast::Expression::Ident(ident); + if self.is_function_closure_type(&typ) { + ast::Expression::Tuple(vec![ + ast::Expression::ExtractTupleField( + Box::new(ident_expression.clone()), + 0usize, + ), + ast::Expression::ExtractTupleField(Box::new(ident_expression), 1usize), + ]) + } else { + ident_expression + } } DefinitionKind::Global(expr_id) => self.expr(*expr_id), - DefinitionKind::Local(_) => { + DefinitionKind::Local(_) => self.lookup_captured(ident.id).unwrap_or_else(|| { let ident = self.local_ident(&ident).unwrap(); ast::Expression::Ident(ident) - } + }), DefinitionKind::GenericType(type_variable) => { let value = match &*type_variable.borrow() { TypeBinding::Unbound(_) => { @@ -657,10 +686,11 @@ impl<'interner> Monomorphizer<'interner> { ast::Type::Tuple(fields) } - HirType::Function(args, ret) => { + HirType::Function(args, ret, env) => { let args = vecmap(args, Self::convert_type); let ret = Box::new(Self::convert_type(ret)); - ast::Type::Function(args, ret) + let env = Box::new(Self::convert_type(env)); + ast::Type::Function(args, ret, env) } HirType::MutableReference(element) => { @@ -677,19 +707,44 @@ impl<'interner> Monomorphizer<'interner> { } } + fn is_function_closure(&self, raw_func_id: node_interner::ExprId) -> bool { + let t = Self::convert_type(&self.interner.id_type(raw_func_id)); + if self.is_function_closure_type(&t) { + true + } else if let ast::Type::Tuple(elements) = t { + if elements.len() == 2 { + matches!(elements[1], ast::Type::Function(_, _, _)) + } else { + false + } + } else { + false + } + } + + fn is_function_closure_type(&self, t: &ast::Type) -> bool { + if let ast::Type::Function(_, _, env) = t { + let e = (*env).clone(); + matches!(*e, ast::Type::Tuple(_captures)) + } else { + false + } + } + fn function_call( &mut self, call: HirCallExpression, id: node_interner::ExprId, ) -> ast::Expression { - let func = Box::new(self.expr(call.func)); + let original_func = Box::new(self.expr(call.func)); let mut arguments = vecmap(&call.arguments, |id| self.expr(*id)); let hir_arguments = vecmap(&call.arguments, |id| self.interner.expression(id)); + let func: Box; let return_type = self.interner.id_type(id); let return_type = Self::convert_type(&return_type); let location = call.location; - if let ast::Expression::Ident(ident) = func.as_ref() { + if let ast::Expression::Ident(ident) = original_func.as_ref() { if let Definition::Oracle(name) = &ident.definition { if name.as_str() == "println" { // Oracle calls are required to be wrapped in an unconstrained function @@ -699,12 +754,39 @@ impl<'interner> Monomorphizer<'interner> { } } - self.try_evaluate_call(&func, &return_type).unwrap_or(ast::Expression::Call(ast::Call { - func, - arguments, - return_type, - location, - })) + let mut block_expressions = vec![]; + + let is_closure = self.is_function_closure(call.func); + if is_closure { + let extracted_func: ast::Expression; + let hir_call_func = self.interner.expression(&call.func); + if let HirExpression::Lambda(l) = hir_call_func { + let (setup, closure_variable) = self.lambda_with_setup(l, call.func); + block_expressions.push(setup); + extracted_func = closure_variable; + } else { + extracted_func = *original_func; + } + func = Box::new(ast::Expression::ExtractTupleField( + Box::new(extracted_func.clone()), + 1usize, + )); + let env_argument = ast::Expression::ExtractTupleField(Box::new(extracted_func), 0usize); + arguments.insert(0, env_argument); + } else { + func = original_func.clone(); + }; + + let call = self + .try_evaluate_call(&func, &return_type) + .unwrap_or(ast::Expression::Call(ast::Call { func, arguments, return_type, location })); + + if !block_expressions.is_empty() { + block_expressions.push(call); + ast::Expression::Block(block_expressions) + } else { + call + } } /// Adds a function argument that contains type metadata that is required to tell @@ -914,7 +996,16 @@ impl<'interner> Monomorphizer<'interner> { } } - fn lambda(&mut self, lambda: HirLambda) -> ast::Expression { + fn lambda(&mut self, lambda: HirLambda, expr: node_interner::ExprId) -> ast::Expression { + if lambda.captures.is_empty() { + self.lambda_no_capture(lambda) + } else { + let (setup, closure_variable) = self.lambda_with_setup(lambda, expr); + ast::Expression::Block(vec![setup, closure_variable]) + } + } + + fn lambda_no_capture(&mut self, lambda: HirLambda) -> ast::Expression { let ret_type = Self::convert_type(&lambda.return_type); let lambda_name = "lambda"; let parameter_types = vecmap(&lambda.parameters, |(_, typ)| Self::convert_type(typ)); @@ -935,7 +1026,8 @@ impl<'interner> Monomorphizer<'interner> { let function = ast::Function { id, name, parameters, body, return_type, unconstrained }; self.push_function(id, function); - let typ = ast::Type::Function(parameter_types, Box::new(ret_type)); + let typ = + ast::Type::Function(parameter_types, Box::new(ret_type), Box::new(ast::Type::Unit)); let name = lambda_name.to_owned(); ast::Expression::Ident(ast::Ident { @@ -947,6 +1039,133 @@ impl<'interner> Monomorphizer<'interner> { }) } + fn lambda_with_setup( + &mut self, + lambda: HirLambda, + expr: node_interner::ExprId, + ) -> (ast::Expression, ast::Expression) { + // returns (, ) + // which can be used directly in callsites or transformed + // directly to a single `Expression` + // for other cases by `lambda` which is called by `expr` + // + // it solves the problem of detecting special cases where + // we call something like + // `{let env$.. = ..;}.1({let env$.. = ..;}.0, ..)` + // which was leading to redefinition errors + // + // instead of detecting and extracting + // patterns in the resulting tree, + // which seems more fragile, we directly reuse the return parameters + // of this function in those cases + let ret_type = Self::convert_type(&lambda.return_type); + let lambda_name = "lambda"; + let parameter_types = vecmap(&lambda.parameters, |(_, typ)| Self::convert_type(typ)); + + // Manually convert to Parameters type so we can reuse the self.parameters method + let parameters = Parameters(vecmap(lambda.parameters, |(pattern, typ)| { + Param(pattern, typ, noirc_abi::AbiVisibility::Private) + })); + + let mut converted_parameters = self.parameters(parameters); + + let id = self.next_function_id(); + let name = lambda_name.to_owned(); + let return_type = ret_type.clone(); + + let env_local_id = self.next_local_id(); + let env_name = "env"; + let env_tuple = ast::Expression::Tuple(vecmap(&lambda.captures, |capture| { + match capture.transitive_capture_index { + Some(field_index) => match self.lambda_envs_stack.last() { + Some(lambda_ctx) => ast::Expression::ExtractTupleField( + lambda_ctx.env_ident.clone(), + field_index, + ), + None => unreachable!( + "Expected to find a parent closure environment, but found none" + ), + }, + None => { + let ident = self.local_ident(&capture.ident).unwrap(); + ast::Expression::Ident(ident) + } + } + })); + let expr_type = self.interner.id_type(expr); + let env_typ = if let types::Type::Function(_, _, function_env_type) = expr_type { + Self::convert_type(&function_env_type) + } else { + unreachable!("expected a Function type for a Lambda node") + }; + + let env_let_stmt = ast::Expression::Let(ast::Let { + id: env_local_id, + mutable: false, + name: env_name.to_string(), + expression: Box::new(env_tuple), + }); + + let location = None; // TODO: This should match the location of the lambda expression + let mutable = false; + let definition = Definition::Local(env_local_id); + + let env_ident = ast::Expression::Ident(ast::Ident { + location, + mutable, + definition, + name: env_name.to_string(), + typ: env_typ.clone(), + }); + + self.lambda_envs_stack.push(LambdaContext { + env_ident: Box::new(env_ident.clone()), + captures: lambda.captures, + }); + let body = self.expr(lambda.body); + self.lambda_envs_stack.pop(); + + let lambda_fn_typ: ast::Type = + ast::Type::Function(parameter_types, Box::new(ret_type), Box::new(env_typ.clone())); + let lambda_fn = ast::Expression::Ident(ast::Ident { + definition: Definition::Function(id), + mutable: false, + location: None, // TODO: This should match the location of the lambda expression + name: name.clone(), + typ: lambda_fn_typ.clone(), + }); + + let mut parameters = vec![]; + parameters.push((env_local_id, true, env_name.to_string(), env_typ.clone())); + parameters.append(&mut converted_parameters); + + let unconstrained = false; + let function = ast::Function { id, name, parameters, body, return_type, unconstrained }; + self.push_function(id, function); + + let lambda_value = ast::Expression::Tuple(vec![env_ident, lambda_fn]); + let block_local_id = self.next_local_id(); + let block_ident_name = "closure_variable"; + let block_let_stmt = ast::Expression::Let(ast::Let { + id: block_local_id, + mutable: false, + name: block_ident_name.to_string(), + expression: Box::new(ast::Expression::Block(vec![env_let_stmt, lambda_value])), + }); + + let closure_definition = Definition::Local(block_local_id); + + let closure_ident = ast::Expression::Ident(ast::Ident { + location, + mutable: false, + definition: closure_definition, + name: block_ident_name.to_string(), + typ: ast::Type::Tuple(vec![env_typ, lambda_fn_typ]), + }); + + (block_let_stmt, closure_ident) + } + /// Implements std::unsafe::zeroed by returning an appropriate zeroed /// ast literal or collection node for the given type. Note that for functions /// there is no obvious zeroed value so this should be considered unsafe to use. @@ -984,8 +1203,8 @@ impl<'interner> Monomorphizer<'interner> { ast::Type::Tuple(fields) => { ast::Expression::Tuple(vecmap(fields, |field| self.zeroed_value_of_type(field))) } - ast::Type::Function(parameter_types, ret_type) => { - self.create_zeroed_function(parameter_types, ret_type) + ast::Type::Function(parameter_types, ret_type, env) => { + self.create_zeroed_function(parameter_types, ret_type, env) } ast::Type::Slice(element_type) => { ast::Expression::Literal(ast::Literal::Array(ast::ArrayLiteral { @@ -1012,6 +1231,7 @@ impl<'interner> Monomorphizer<'interner> { &mut self, parameter_types: &[ast::Type], ret_type: &ast::Type, + env_type: &ast::Type, ) -> ast::Expression { let lambda_name = "zeroed_lambda"; @@ -1034,7 +1254,11 @@ impl<'interner> Monomorphizer<'interner> { mutable: false, location: None, name: lambda_name.to_owned(), - typ: ast::Type::Function(parameter_types.to_owned(), Box::new(ret_type.clone())), + typ: ast::Type::Function( + parameter_types.to_owned(), + Box::new(ret_type.clone()), + Box::new(env_type.clone()), + ), }) } } @@ -1072,3 +1296,167 @@ fn undo_instantiation_bindings(bindings: TypeBindings) { *var.borrow_mut() = TypeBinding::Unbound(id); } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use fm::FileId; + use iter_extended::vecmap; + + use crate::{ + graph::CrateId, + hir::{ + def_map::{ + CrateDefMap, LocalModuleId, ModuleData, ModuleDefId, ModuleId, ModuleOrigin, + }, + resolution::{ + import::PathResolutionError, path_resolver::PathResolver, resolver::Resolver, + }, + }, + hir_def::function::HirFunction, + node_interner::{FuncId, NodeInterner}, + parse_program, + }; + + use super::monomorphize; + + // TODO: refactor into a more general test utility? + // mostly copied from hir / type_check / mod.rs and adapted a bit + fn type_check_src_code(src: &str, func_namespace: Vec) -> (FuncId, NodeInterner) { + let (program, errors) = parse_program(src); + let mut interner = NodeInterner::default(); + + // Using assert_eq here instead of assert(errors.is_empty()) displays + // the whole vec if the assert fails rather than just two booleans + assert_eq!(errors, vec![]); + + let main_id = interner.push_fn(HirFunction::empty()); + interner.push_function_definition("main".into(), main_id); + + let func_ids = vecmap(&func_namespace, |name| { + let id = interner.push_fn(HirFunction::empty()); + interner.push_function_definition(name.into(), id); + id + }); + + let mut path_resolver = TestPathResolver(HashMap::new()); + for (name, id) in func_namespace.into_iter().zip(func_ids.clone()) { + path_resolver.insert_func(name.to_owned(), id); + } + + let mut def_maps: HashMap = HashMap::new(); + let file = FileId::default(); + + let mut modules = arena::Arena::new(); + modules.insert(ModuleData::new(None, ModuleOrigin::File(file), false)); + + def_maps.insert( + CrateId::dummy_id(), + CrateDefMap { + root: path_resolver.local_module_id(), + modules, + krate: CrateId::dummy_id(), + extern_prelude: HashMap::new(), + }, + ); + + let func_meta = vecmap(program.functions, |nf| { + let resolver = Resolver::new(&mut interner, &path_resolver, &def_maps, file); + let (hir_func, func_meta, _resolver_errors) = + resolver.resolve_function(nf, main_id, ModuleId::dummy_id()); + // TODO: not sure why, we do get an error here, + // but otherwise seem to get an ok monomorphization result + // assert_eq!(resolver_errors, vec![]); + (hir_func, func_meta) + }); + + println!("Before update_fn"); + + for ((hir_func, meta), func_id) in func_meta.into_iter().zip(func_ids.clone()) { + interner.update_fn(func_id, hir_func); + interner.push_fn_meta(meta, func_id); + } + + println!("Before type_check_func"); + + // Type check section + let errors = crate::hir::type_check::type_check_func( + &mut interner, + func_ids.first().cloned().unwrap(), + ); + assert_eq!(errors, vec![]); + (func_ids.first().cloned().unwrap(), interner) + } + + // TODO: refactor into a more general test utility? + // TestPathResolver struct and impls copied from hir / type_check / mod.rs + struct TestPathResolver(HashMap); + + impl PathResolver for TestPathResolver { + fn resolve( + &self, + _def_maps: &HashMap, + path: crate::Path, + ) -> Result { + // Not here that foo::bar and hello::foo::bar would fetch the same thing + let name = path.segments.last().unwrap(); + let mod_def = self.0.get(&name.0.contents).cloned(); + mod_def.ok_or_else(move || PathResolutionError::Unresolved(name.clone())) + } + + fn local_module_id(&self) -> LocalModuleId { + // This is not LocalModuleId::dummy since we need to use this to index into a Vec + // later and do not want to push u32::MAX number of elements before we do. + LocalModuleId(arena::Index::from_raw_parts(0, 0)) + } + + fn module_id(&self) -> ModuleId { + ModuleId { krate: CrateId::dummy_id(), local_id: self.local_module_id() } + } + } + + impl TestPathResolver { + fn insert_func(&mut self, name: String, func_id: FuncId) { + self.0.insert(name, func_id.into()); + } + } + + // a helper test method + // TODO: maybe just compare trimmed src/expected + // for easier formatting? + fn check_rewrite(src: &str, expected: &str) { + let (func, interner) = type_check_src_code(src, vec!["main".to_string()]); + let program = monomorphize(func, &interner); + // println!("[{}]", program); + assert!(format!("{}", program) == expected); + } + + #[test] + fn simple_closure_with_no_captured_variables() { + let src = r#" + fn main() -> Field { + let x = 1; + let closure = || x; + closure() + } + "#; + + let expected_rewrite = r#"fn main$f0() -> Field { + let x$0 = 1; + let closure$3 = { + let closure_variable$2 = { + let env$1 = (x$l0); + (env$l1, lambda$f1) + }; + closure_variable$l2 + }; + closure$l3.1(closure$l3.0) +} +fn lambda$f1(mut env$l1: (Field)) -> Field { + env$l1.0 +} +"#; + check_rewrite(src, expected_rewrite); + } +} diff --git a/crates/noirc_frontend/src/node_interner.rs b/crates/noirc_frontend/src/node_interner.rs index f5fea5c1ea7..6b3d2757c14 100644 --- a/crates/noirc_frontend/src/node_interner.rs +++ b/crates/noirc_frontend/src/node_interner.rs @@ -672,7 +672,7 @@ fn get_type_method_key(typ: &Type) -> Option { Type::String(_) => Some(String), Type::Unit => Some(Unit), Type::Tuple(_) => Some(Tuple), - Type::Function(_, _) => Some(Function), + Type::Function(_, _, _) => Some(Function), Type::MutableReference(element) => get_type_method_key(element), // We do not support adding methods to these types