diff --git a/tooling/ast_fuzzer/src/lib.rs b/tooling/ast_fuzzer/src/lib.rs index 2329846a5b7..5ef74e871e8 100644 --- a/tooling/ast_fuzzer/src/lib.rs +++ b/tooling/ast_fuzzer/src/lib.rs @@ -30,8 +30,10 @@ pub struct Config { pub max_array_size: usize, /// Maximum size of for loop ranges, which affects unrolling in ACIR. pub max_loop_size: usize, - /// Maximum call depth for recursive calls. - pub max_call_depth: usize, + /// Whether to choose the backstop for `loop` and `while` randomly. + pub vary_loop_size: bool, + /// Maximum number of recursive calls to make at runtime. + pub max_recursive_calls: usize, /// Frequency of expressions, which produce a value. pub expr_freqs: Freqs, /// Frequency of statements in ACIR functions. @@ -81,7 +83,8 @@ impl Default for Config { max_tuple_size: 5, max_array_size: 4, max_loop_size: 10, - max_call_depth: 5, + vary_loop_size: true, + max_recursive_calls: 25, expr_freqs, stmt_freqs_acir, stmt_freqs_brillig, diff --git a/tooling/ast_fuzzer/src/program/expr.rs b/tooling/ast_fuzzer/src/program/expr.rs index 60a5f0a2a33..9bb5b06aa79 100644 --- a/tooling/ast_fuzzer/src/program/expr.rs +++ b/tooling/ast_fuzzer/src/program/expr.rs @@ -277,10 +277,18 @@ pub(crate) fn if_else( } /// Assign a value to an identifier. -pub(crate) fn assign(ident: Ident, expr: Expression) -> Expression { +pub(crate) fn assign_ident(ident: Ident, expr: Expression) -> Expression { Expression::Assign(Assign { lvalue: LValue::Ident(ident), expression: Box::new(expr) }) } +/// Assign a value to a mutable reference. +pub(crate) fn assign_ref(ident: Ident, expr: Expression) -> Expression { + let typ = ident.typ.clone(); + let lvalue = LValue::Ident(ident); + let lvalue = LValue::Dereference { reference: Box::new(lvalue), element_type: typ }; + Expression::Assign(Assign { lvalue, expression: Box::new(expr) }) +} + /// Cast an expression to a target type. 
pub(crate) fn cast(lhs: Expression, tgt_type: Type) -> Expression { Expression::Cast(Cast { lhs: Box::new(lhs), r#type: tgt_type, location: Location::dummy() }) @@ -312,6 +320,11 @@ pub(crate) fn deref(rhs: Expression, tgt_type: Type) -> Expression { unary(UnaryOp::Dereference { implicitly_added: false }, rhs, tgt_type) } +/// Reference an expression as a target type +pub(crate) fn ref_mut(rhs: Expression, tgt_type: Type) -> Expression { + unary(UnaryOp::Reference { mutable: true }, rhs, tgt_type) +} + /// Make a unary expression. pub(crate) fn unary(op: UnaryOp, rhs: Expression, tgt_type: Type) -> Expression { Expression::Unary(Unary { diff --git a/tooling/ast_fuzzer/src/program/func.rs b/tooling/ast_fuzzer/src/program/func.rs index 4d7b15e2dde..8f94c671c73 100644 --- a/tooling/ast_fuzzer/src/program/func.rs +++ b/tooling/ast_fuzzer/src/program/func.rs @@ -931,19 +931,16 @@ impl<'a> FunctionContext<'a> { // Increment the index in the beginning of the body. expr::prepend( &mut loop_body, - expr::assign( + expr::assign_ident( idx_ident, expr::binary(idx_expr.clone(), BinaryOp::Add, expr::u32_literal(1)), ), ); // Put everything into if/else + let max_loop_size = self.gen_loop_size(u)?; let loop_body = expr::if_else( - expr::binary( - idx_expr, - BinaryOp::Equal, - expr::u32_literal(self.ctx.config.max_loop_size as u32), - ), + expr::binary(idx_expr, BinaryOp::Equal, expr::u32_literal(max_loop_size as u32)), Expression::Break, loop_body, Type::Unit, @@ -985,19 +982,16 @@ impl<'a> FunctionContext<'a> { // Increment the index in the beginning of the body. 
expr::prepend( &mut loop_body, - expr::assign( + expr::assign_ident( idx_ident, expr::binary(idx_expr.clone(), BinaryOp::Add, expr::u32_literal(1)), ), ); // Put everything into if/else + let max_loop_size = self.gen_loop_size(u)?; let inner_block = Expression::Block(vec![expr::if_else( - expr::binary( - idx_expr, - BinaryOp::Equal, - expr::u32_literal(self.ctx.config.max_loop_size as u32), - ), + expr::binary(idx_expr, BinaryOp::Equal, expr::u32_literal(max_loop_size as u32)), Expression::Break, loop_body, Type::Unit, @@ -1013,6 +1007,14 @@ impl<'a> FunctionContext<'a> { Ok(Expression::Block(stmts)) } + + fn gen_loop_size(&self, u: &mut Unstructured) -> arbitrary::Result { + if self.ctx.config.vary_loop_size { + u.choose_index(self.ctx.config.max_loop_size) + } else { + Ok(self.ctx.config.max_loop_size) + } + } } #[test] @@ -1020,7 +1022,8 @@ fn test_loop() { let mut u = Unstructured::new(&[0u8; 1]); let mut ctx = Context::default(); ctx.config.max_loop_size = 10; - ctx.add_main_decl(&mut u); + ctx.config.vary_loop_size = false; + ctx.gen_main_decl(&mut u); let mut fctx = FunctionContext::new(&mut ctx, FuncId(0)); fctx.budget = 2; let loop_code = format!("{}", fctx.gen_loop(&mut u).unwrap()).replace(" ", ""); @@ -1045,7 +1048,8 @@ fn test_while() { let mut u = Unstructured::new(&[0u8; 1]); let mut ctx = Context::default(); ctx.config.max_loop_size = 10; - ctx.add_main_decl(&mut u); + ctx.config.vary_loop_size = false; + ctx.gen_main_decl(&mut u); let mut fctx = FunctionContext::new(&mut ctx, FuncId(0)); fctx.budget = 2; let while_code = format!("{}", fctx.gen_while(&mut u).unwrap()).replace(" ", ""); diff --git a/tooling/ast_fuzzer/src/program/mod.rs b/tooling/ast_fuzzer/src/program/mod.rs index 476911776df..4ef6852d5ca 100644 --- a/tooling/ast_fuzzer/src/program/mod.rs +++ b/tooling/ast_fuzzer/src/program/mod.rs @@ -180,7 +180,7 @@ impl Context { /// Generate and add main (for testing) #[cfg(test)] - fn add_main_decl(&mut self, u: &mut Unstructured) { + fn 
gen_main_decl(&mut self, u: &mut Unstructured) { let d = self.gen_function_decl(u, 0).unwrap(); self.function_declarations.insert(FuncId(0u32), d); } @@ -208,7 +208,7 @@ impl Context { /// As a post-processing step, identify recursive functions and add a call depth parameter to them. fn rewrite_functions(&mut self, u: &mut Unstructured) -> arbitrary::Result<()> { - rewrite::add_recursion_depth(self, u) + rewrite::add_recursion_limit(self, u) } /// Return the generated [Program]. diff --git a/tooling/ast_fuzzer/src/program/rewrite.rs b/tooling/ast_fuzzer/src/program/rewrite.rs index f7daefbc875..a5adb844416 100644 --- a/tooling/ast_fuzzer/src/program/rewrite.rs +++ b/tooling/ast_fuzzer/src/program/rewrite.rs @@ -1,84 +1,185 @@ -use std::collections::HashSet; +use std::collections::BTreeMap; use arbitrary::Unstructured; +use im::HashMap; +use nargo::errors::Location; use noirc_frontend::{ ast::BinaryOpKind, - monomorphization::ast::{Definition, Expression, Function, IdentId, LocalId, Program, Type}, - shared::Visibility, + monomorphization::ast::{ + Call, Definition, Expression, FuncId, Function, Ident, IdentId, LocalId, Program, Type, + }, }; use super::{ - Context, VariableId, expr, func, types, + Context, VariableId, expr, types, visitor::{visit_expr, visit_expr_mut}, }; -/// Find recursive functions and add a `ctx_depth` parameter to them. -pub(crate) fn add_recursion_depth( +/// Find recursive functions and add a `ctx_limit: &mut u32` parameter to them, +/// which we use to limit the number of recursive calls. This is complicated by +/// the fact that we cannot pass mutable references from ACIR to Brillig. To +/// overcome that, we create a proxy function for unconstrained functions that +/// take `mut ctx_limit: u32` instead, and pass it on as a mutable ref. +pub(crate) fn add_recursion_limit( ctx: &mut Context, u: &mut Unstructured, ) -> arbitrary::Result<()> { // Collect recursive functions, ie. the ones which call other functions. 
+ // Remember if they are unconstrained; those need proxies as well. let recursive_functions = ctx .functions .iter() - .filter_map(|(id, func)| expr::has_call(&func.body).then_some(*id)) - .collect::>(); + .filter_map(|(id, func)| expr::has_call(&func.body).then_some((*id, func.unconstrained))) + .collect::>(); - for (func_id, func) in - ctx.functions.iter_mut().filter(|(id, _)| recursive_functions.contains(id)) - { + // Create proxies for unconstrained recursive functions. + // We could check whether they are called from ACIR, but that would require further traversals. + let mut proxy_functions = HashMap::new(); + let mut next_func_id = FuncId(ctx.functions.len() as u32); + + for (func_id, unconstrained) in &recursive_functions { + if !*unconstrained || *func_id == Program::main_id() { + continue; + } + let mut proxy = ctx.functions[func_id].clone(); + proxy.id = next_func_id; + proxy.name = format!("{}_proxy", proxy.name); + // We will replace the body and update the params later. + proxy_functions.insert(*func_id, proxy); + next_func_id = FuncId(next_func_id.0 + 1); + } + + // Rewrite recursive functions. + for (func_id, unconstrained) in recursive_functions.iter() { + let func = ctx.functions.get_mut(func_id).unwrap(); let is_main = *func_id == Program::main_id(); + // We'll need a new ID for variables or parameters. We could speed this up by // 1) caching this value in a "function meta" construct, or - // 2) using `u32::MAX`, but we wouldn't be able to add caching to `Program`, - // so eventually we'll need to look at the values to do random mutations. 
- let (next_local_id, next_ident_id) = next_local_and_ident_id(func); - let depth_id = LocalId(next_local_id); - let depth_name = "ctx_depth".to_string(); - let depth_ident_id = IdentId(next_ident_id); - let depth_ident = expr::ident_inner( - VariableId::Local(depth_id), - depth_ident_id, - !is_main, - depth_name.clone(), - types::U32, - ); - let depth_expr = Expression::Ident(depth_ident.clone()); - let depth_decreased = - expr::binary(depth_expr.clone(), BinaryOpKind::Subtract, expr::u32_literal(1)); + // 2) using `u32::MAX`, but then we would be in a worse situation next time + // 3) draw values from `Context` instead of `FunctionContext`, which breaks continuity, but saves an extra traversal. + // We wouldn't be able to add caching to `Program` without changing it, so eventually we'll need to look at the values + // to do random mutations, or we have to pass back some meta along with `Program` and look it up there. For now we + // traverse the AST to figure out what the next ID to use is. + let (mut next_local_id, mut next_ident_id) = next_local_and_ident_id(func); + + let mut next_local_id = || { + let id = next_local_id; + next_local_id += 1; + LocalId(id) + }; + + let mut next_ident_id = || { + let id = next_ident_id; + next_ident_id += 1; + IdentId(id) + }; + + let limit_name = "ctx_limit".to_string(); + let limit_id = next_local_id(); + let limit_var = VariableId::Local(limit_id); if is_main { - // In main we initialize the depth to its maximum value. - let init_depth = expr::let_var( - depth_id, - false, - depth_name, - expr::u32_literal(ctx.config.max_call_depth as u32), + // In main we initialize the limit to its maximum value. 
+ let init_limit = expr::let_var( + limit_id, + true, + limit_name.clone(), + expr::u32_literal(ctx.config.max_recursive_calls as u32), ); - expr::prepend(&mut func.body, init_depth); + expr::prepend(&mut func.body, init_limit); } else { - // In non-main we look at the depth and return a random value if it's zero, + // In non-main we look at the limit and return a random value if it's zero, // otherwise decrease it by one and continue with the original body. - func.parameters.push((depth_id, true, depth_name.clone(), types::U32)); - func.func_sig.0.push(func::hir_param(true, &types::U32, Visibility::Private)); + let limit_type = types::ref_mut(types::U32); + func.parameters.push((limit_id, false, limit_name.clone(), limit_type.clone())); + // Generate a random value to return. let default_return = expr::gen_literal(u, &func.return_type)?; - expr::replace(&mut func.body, |body| { + let limit_ident = expr::ident_inner( + limit_var, + next_ident_id(), + false, + limit_name.clone(), + limit_type, + ); + let limit_expr = Expression::Ident(limit_ident.clone()); + + expr::replace(&mut func.body, |mut body| { + expr::prepend( + &mut body, + expr::assign_ref( + limit_ident, + expr::binary( + expr::deref(limit_expr.clone(), types::U32), + BinaryOpKind::Subtract, + expr::u32_literal(1), + ), + ), + ); expr::if_else( - expr::equal(depth_expr.clone(), expr::u32_literal(0)), + expr::equal(expr::deref(limit_expr.clone(), types::U32), expr::u32_literal(0)), default_return, - Expression::Block(vec![ - expr::assign(depth_ident, depth_decreased.clone()), - body, - ]), + body, func.return_type.clone(), ) }); } - // Update calls to pass along the depth. - visit_expr_mut(&mut func.body, &mut |expr| { + // Add the non-reference version of the parameter to the proxy function. + if let Some(proxy) = proxy_functions.get_mut(func_id) { + proxy.parameters.push((limit_id, true, limit_name.clone(), types::U32)); + // The body is just a call to the non-proxy function. 
+ proxy.body = Expression::Call(Call { + func: Box::new(Expression::Ident(Ident { + location: None, + definition: Definition::Function(*func_id), + mutable: false, + name: func.name.clone(), + typ: Type::Function( + func.parameters.iter().map(|p| p.3.clone()).collect(), + Box::new(func.return_type.clone()), + Box::new(Type::Unit), + func.unconstrained, + ), + id: next_ident_id(), + })), + arguments: proxy + .parameters + .iter() + .map(|(id, mutable, name, typ)| { + if *id == limit_id { + // Pass mutable reference to the limit. + expr::ref_mut( + expr::ident( + VariableId::Local(*id), + next_ident_id(), + *mutable, + name.clone(), + typ.clone(), + ), + typ.clone(), + ) + } else { + // Pass every other parameter as-is. + expr::ident( + VariableId::Local(*id), + next_ident_id(), + *mutable, + name.clone(), + typ.clone(), + ) + } + }) + .collect(), + return_type: proxy.return_type.clone(), + location: Location::dummy(), + }); + } + + // Update calls to pass along the limit and call the proxy if necessary. + visit_expr_mut(&mut func.body, &mut |expr: &mut Expression| { if let Expression::Call(call) = expr { let Expression::Ident(func) = call.func.as_mut() else { unreachable!("functions are called by ident"); @@ -87,19 +188,78 @@ pub(crate) fn add_recursion_depth( unreachable!("function definition expected"); }; // If the callee isn't recursive, it won't have the extra parameter. - if !recursive_functions.contains(&func_id) { + let Some(callee_unconstrained) = recursive_functions.get(&func_id) else { return true; - } + }; let Type::Function(param_types, _, _, _) = &mut func.typ else { unreachable!("function type expected"); }; - param_types.push(types::U32); - call.arguments.push(depth_expr.clone()); + if *callee_unconstrained && !unconstrained { + // Calling Brillig from ACIR: call the proxy. 
+ let Some(proxy) = proxy_functions.get(&func_id) else { + unreachable!("expected to have a proxy"); + }; + func.name = proxy.name.clone(); + func.definition = Definition::Function(proxy.id); + // Pass the limit by value. + let limit_expr = if is_main { + expr::ident( + limit_var, + next_ident_id(), + true, + limit_name.clone(), + types::U32, + ) + } else { + expr::deref( + expr::ident( + limit_var, + next_ident_id(), + false, + limit_name.clone(), + types::ref_mut(types::U32), + ), + types::U32, + ) + }; + param_types.push(types::U32); + call.arguments.push(limit_expr); + } else { + // Pass the limit by reference. + let limit_type = types::ref_mut(types::U32); + let limit_expr = if is_main { + expr::ref_mut( + expr::ident( + limit_var, + next_ident_id(), + true, + limit_name.clone(), + types::U32, + ), + limit_type, + ) + } else { + expr::ident( + limit_var, + next_ident_id(), + false, + limit_name.clone(), + limit_type, + ) + }; + param_types.push(types::U32); + call.arguments.push(limit_expr); + } } true }); } + // Append proxy functions. 
+ for (_, proxy) in proxy_functions { + ctx.functions.insert(proxy.id, proxy); + } + Ok(()) } diff --git a/tooling/ast_fuzzer/src/program/tests.rs b/tooling/ast_fuzzer/src/program/tests.rs index d7f1c4dff50..9c6d4544334 100644 --- a/tooling/ast_fuzzer/src/program/tests.rs +++ b/tooling/ast_fuzzer/src/program/tests.rs @@ -1,13 +1,19 @@ +use arbitrary::Unstructured; use nargo::errors::Location; use noirc_evaluator::{assert_ssa_snapshot, ssa::ssa_gen}; use noirc_frontend::{ ast::IntegerBitSize, monomorphization::ast::{ - Expression, For, FuncId, Function, InlineType, LocalId, Program, Type, + Call, Definition, Expression, For, FuncId, Function, Ident, IdentId, InlineType, LocalId, + Program, Type, }, shared::Visibility, }; +use crate::{Config, program::FunctionDeclaration}; + +use super::{Context, DisplayAstAsNoir}; + #[test] fn test_make_name() { use crate::program::make_name; @@ -92,3 +98,146 @@ fn test_modulo_of_negative_literals_in_range() { } "); } + +/// Check that the AST we generate for recursive functions is as expected. 
+#[test] +fn test_recursion_limit_rewrite() { + let mut ctx = Context::new(Config::default()); + let mut next_ident_id = 0; + + let mut add_func = |id: FuncId, name: &str, unconstrained: bool, calling: &[FuncId]| { + let calls = calling + .iter() + .map(|callee_id| { + let (callee_name, callee_unconstrained) = if *callee_id == id { + (name.to_string(), unconstrained) + } else { + let callee = &ctx.functions[callee_id]; + (callee.name.clone(), callee.unconstrained) + }; + + let ident_id = IdentId(next_ident_id); + next_ident_id += 1; + + Expression::Call(Call { + func: Box::new(Expression::Ident(Ident { + location: None, + definition: Definition::Function(*callee_id), + mutable: false, + name: callee_name, + typ: Type::Function( + vec![], + Box::new(Type::Unit), + Box::new(Type::Unit), + callee_unconstrained, + ), + id: ident_id, + })), + arguments: vec![], + return_type: Type::Unit, + location: Location::dummy(), + }) + }) + .collect(); + + let func = Function { + id, + name: name.to_string(), + parameters: vec![], + body: Expression::Block(calls), + return_type: Type::Unit, + unconstrained, + inline_type: InlineType::InlineAlways, + func_sig: (vec![], None), + }; + + ctx.function_declarations.insert( + id, + FunctionDeclaration { + name: name.to_string(), + params: vec![], + param_visibilities: vec![], + return_type: Type::Unit, + return_visibility: Visibility::Private, + inline_type: func.inline_type, + unconstrained: func.unconstrained, + }, + ); + + ctx.functions.insert(id, func); + }; + + // Create functions: + // - ACIR main, calling foo + // - ACIR foo, calling bar + // - Brillig bar, calling baz and qux + // - Brillig baz, calling itself + // - Brillig qux, not calling anything + + let main_id = FuncId(0); + let foo_id = FuncId(1); + let bar_id = FuncId(2); + let baz_id = FuncId(3); + let qux_id = FuncId(4); + + add_func(qux_id, "qux", true, &[]); + add_func(baz_id, "baz", true, &[baz_id]); + add_func(bar_id, "bar", true, &[baz_id, qux_id]); + 
+ add_func(foo_id, "foo", false, &[bar_id]); + add_func(main_id, "main", false, &[foo_id]); + + // We only generate `Unit` returns, so no randomness is expected, + // but it would be deterministic anyway. + let mut u = Unstructured::new(&[0u8; 1]); + ctx.rewrite_functions(&mut u).unwrap(); + let program = ctx.finalize(); + + // Check that: + // - main passes the limit to foo by ref + // - foo passes the limit to bar_proxy by value + // - bar_proxy passes the limit to bar by ref, and bar passes it on to baz by ref + // - bar does not pass the limit to qux + // - baz passes the limit to itself by ref + + let code = format!("{}", DisplayAstAsNoir(&program)); + + insta::assert_snapshot!(code, @r" + fn main() -> () { + let mut ctx_limit = 25; + foo((&mut ctx_limit)) + } + fn foo(ctx_limit: &mut u32) -> () { + if ((*ctx_limit) == 0) { + () + } else { + *ctx_limit = ((*ctx_limit) - 1); + unsafe { bar_proxy((*ctx_limit)) } + } + } + unconstrained fn bar(ctx_limit: &mut u32) -> () { + if ((*ctx_limit) == 0) { + () + } else { + *ctx_limit = ((*ctx_limit) - 1); + baz(ctx_limit); + qux() + } + } + unconstrained fn baz(ctx_limit: &mut u32) -> () { + if ((*ctx_limit) == 0) { + () + } else { + *ctx_limit = ((*ctx_limit) - 1); + baz(ctx_limit) + } + } + unconstrained fn qux() -> () { + } + unconstrained fn bar_proxy(mut ctx_limit: u32) -> () { + bar((&mut ctx_limit)) + } + unconstrained fn baz_proxy(mut ctx_limit: u32) -> () { + baz((&mut ctx_limit)) + } + "); +} diff --git a/tooling/ast_fuzzer/src/program/types.rs b/tooling/ast_fuzzer/src/program/types.rs index d4fcfcba71c..8f391d07b85 100644 --- a/tooling/ast_fuzzer/src/program/types.rs +++ b/tooling/ast_fuzzer/src/program/types.rs @@ -222,3 +222,8 @@ pub(crate) fn can_binary_op_return_from_input(op: &BinaryOp, input: &Type, outpu _ => false, } } + +/// Wrap a type in a mutable reference type +pub(crate) fn ref_mut(typ: Type) -> Type { + Type::Reference(Box::new(typ), true) +}