diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs index 3dfd7ddf3c3..759baab49ac 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs @@ -6,6 +6,8 @@ use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; use super::brillig_block::BrilligBlock; use super::{BrilligVariable, Function, FunctionContext, ValueId}; +use crate::ssa::ir::call_graph::CallGraph; +use crate::ssa::ssa_gen::Ssa; use crate::{ brillig::{ Brillig, BrilligOptions, ConstantAllocation, DataFlowGraph, FunctionId, Label, @@ -55,11 +57,12 @@ pub(crate) type ConstantCounterMap = HashMap<(FieldElement, NumericType), usize> impl BrilligGlobals { pub(crate) fn new( - functions: &BTreeMap, + ssa: &Ssa, mut used_globals: HashMap>, main_id: FunctionId, ) -> Self { - let brillig_entry_points = get_brillig_entry_points(functions, main_id); + let call_graph = CallGraph::from_ssa(ssa); + let brillig_entry_points = get_brillig_entry_points(&ssa.functions, main_id, &call_graph); let mut hoisted_global_constants: HashMap = HashMap::default(); @@ -70,14 +73,14 @@ impl BrilligGlobals { Self::mark_globals_for_hoisting( &mut hoisted_global_constants, *entry_point, - &functions[entry_point], + &ssa.functions[entry_point], ); for inner_call in entry_point_inner_calls.iter() { Self::mark_globals_for_hoisting( &mut hoisted_global_constants, *entry_point, - &functions[inner_call], + &ssa.functions[inner_call], ); let inner_globals = used_globals @@ -285,6 +288,7 @@ impl Brillig { (artifact, function_context.ssa_value_allocations, globals_size, hoisted_global_constants) } } + #[cfg(test)] mod tests { use acvm::{ diff --git a/compiler/noirc_evaluator/src/brillig/mod.rs b/compiler/noirc_evaluator/src/brillig/mod.rs index b99217b6d50..a5e9c64a1b4 100644 --- a/compiler/noirc_evaluator/src/brillig/mod.rs +++ b/compiler/noirc_evaluator/src/brillig/mod.rs @@ -153,8 +153,7 @@ impl Ssa { return brillig; } - let mut brillig_globals = - BrilligGlobals::new(&self.functions, used_globals_map, self.main_id); + let mut brillig_globals = BrilligGlobals::new(self, used_globals_map, self.main_id); // SSA Globals are computed once at compile time and shared across all functions, // thus we can just fetch globals from the main function. diff --git a/compiler/noirc_evaluator/src/ssa/opt/brillig_entry_points.rs b/compiler/noirc_evaluator/src/ssa/opt/brillig_entry_points.rs index bccc1dcd8a2..d8c8c316034 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/brillig_entry_points.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/brillig_entry_points.rs @@ -1,16 +1,17 @@ //! The purpose of this pass is to perform function specialization of Brillig functions based upon //! a function's entry points. Function specialization is performed through duplication of functions. +//! Brillig entry points are defined as functions called directly by ACIR functions or are `main`. //! //! This pass is done due to how globals are initialized for Brillig generation. -//! We allow multiple Brillig entry points (every call to Brillig from ACIR is an entry point), -//! and in order to avoid re-initializing globals used in one entry point but not another, -//! we set the globals initialization code based upon the globals used in a given entry point. -//! The ultimate goal is to optimize for runtime execution. +//! We allow multiple Brillig entry points, and in order to avoid re-initializing globals +//! used in one entry point but not another, we set the globals initialization code based +//! upon the globals used in a given entry point. The ultimate goal is to optimize for runtime execution. //! //! However, doing the above on its own is insufficient as we allow entry points to be called from //! other entry points and functions can be called across multiple entry points. -//! As all functions can potentially share entry points and use globals, the global allocations maps -//! generated for different entry points can conflict. +//! Without specialization, the following issues arise: +//! 1. Entry points calling the same function may conflict on global allocations. +//! 2. Entry points calling other entry points may cause overlapping global usage. //! //! To provide a more concrete example, let's take this program: //! ```noir @@ -51,12 +52,14 @@ //! CONST M32836 = 3 //! RETURN //! ``` -//! It is then not clear when generating the bytecode for `inner_func` which global allocations map should be used, -//! and any choice will lead to an incorrect program. +//! Here, `inner_func` is called by two different entry points. It is then not clear when generating the bytecode +//! for `inner_func` which global allocations map should be used, and any choice will lead to an incorrect program. //! If `inner_func` used the map for `entry_point_one` the bytecode generated would use `M32837` to represent `THREE`. //! However, when `inner_func` is called from `entry_point_two`, the address for `THREE` is `M32836`. //! -//! This pass will duplicate `inner_func` so that different functions are called by the different entry points. +//! This pass duplicates functions like `inner_func` so that each entry point gets its own specialized +//! version. The result is that bytecode can safely reference the correct globals without conflicts. +//! //! The test module for this pass can be referenced to see how this function duplication looks in SSA. use std::collections::{BTreeMap, BTreeSet}; @@ -66,7 +69,7 @@ use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; use crate::ssa::{ Ssa, ir::{ - call_graph::called_functions_vec, + call_graph::CallGraph, function::{Function, FunctionId}, instruction::{Instruction, InstructionId}, value::{Value, ValueId}, @@ -75,13 +78,18 @@ use crate::ssa::{ impl Ssa { pub(crate) fn brillig_entry_point_analysis(mut self) -> Ssa { - if self.main().runtime().is_brillig() { + let main = self.main(); + if main.runtime().is_brillig() { return self; } - // Build a call graph based upon the Brillig entry points and set up + // Build a call graph from the SSA + let call_graph = CallGraph::from_ssa(&self); + + // From the call graph find the Brillig entry points and set up // the functions needing specialization before performing the actual call site rewrites. - let brillig_entry_points = get_brillig_entry_points(&self.functions, self.main_id); + let brillig_entry_points = + get_brillig_entry_points(&self.functions, self.main_id, &call_graph); let functions_to_clone_map = build_functions_to_clone(&brillig_entry_points); let (calls_to_update, mut new_functions_map) = build_calls_to_update(&mut self, functions_to_clone_map, &brillig_entry_points); @@ -304,89 +312,59 @@ fn collect_callsites_to_rewrite( new_calls_to_update } -/// Returns a map of Brillig entry points to all functions called in that entry point. -/// This includes any nested calls as well, as we want to be able to associate -/// any Brillig function with the appropriate global allocations. +/// Returns a map of Brillig entry points to all reachable functions from that entry point. +/// +/// A Brillig entry point is defined as a Brillig function that is directly called +/// from at least one ACIR function, or is the `main` function itself if it is Brillig. +/// +/// The value set for each entry point includes all functions reachable +/// from the entry point (excluding the entry itself if it is non-recursive). pub(crate) fn get_brillig_entry_points( functions: &BTreeMap, main_id: FunctionId, + call_graph: &CallGraph, ) -> BTreeMap> { + let recursive_functions = call_graph.get_recursive_functions(); let mut brillig_entry_points = BTreeMap::default(); - let acir_functions = functions.iter().filter(|(_, func)| func.runtime().is_acir()); - for (_, function) in acir_functions { - for block_id in function.reachable_blocks() { - for instruction_id in function.dfg[block_id].instructions() { - let instruction = &function.dfg[*instruction_id]; - let Instruction::Call { func: func_id, arguments: _ } = instruction else { - continue; - }; - - let func_value = &function.dfg[*func_id]; - let Value::Function(func_id) = func_value else { continue }; - - let called_function = &functions[func_id]; - if called_function.runtime().is_acir() { - continue; - } - - // We have now found a Brillig entry point. - brillig_entry_points.insert(*func_id, BTreeSet::default()); - build_entry_points_map_recursive( - functions, - *func_id, - *func_id, - &mut brillig_entry_points, - im::HashSet::new(), - ); - } + + // Only ACIR callers can introduce Brillig entry points + let acir_callers = call_graph + .callees() + .into_iter() + .filter(|(caller, _)| functions[caller].runtime().is_acir()); + for (_, callees) in acir_callers { + // Filter only the Brillig callees. These are the Brillig entry points. + let entry_points = callees.keys().filter(|callee| functions[callee].runtime().is_brillig()); + for &entry_point in entry_points { + brillig_entry_points.insert( + entry_point, + brillig_reachable(call_graph, &recursive_functions, entry_point), + ); } } // If main has been marked as Brillig, it is itself an entry point. // Run the same analysis from above on main. - let main_func = &functions[&main_id]; - if main_func.runtime().is_brillig() { - brillig_entry_points.insert(main_id, BTreeSet::default()); - build_entry_points_map_recursive( - functions, - main_id, - main_id, - &mut brillig_entry_points, - im::HashSet::new(), - ); + if functions[&main_id].runtime().is_brillig() { + brillig_entry_points + .insert(main_id, brillig_reachable(call_graph, &recursive_functions, main_id)); } brillig_entry_points } -/// Recursively mark any functions called in an entry point -fn build_entry_points_map_recursive( - functions: &BTreeMap, - entry_point: FunctionId, - called_function: FunctionId, - brillig_entry_points: &mut BTreeMap>, - mut explored_functions: im::HashSet, -) { - if explored_functions.insert(called_function).is_some() { - return; - } - - let inner_calls: HashSet = - called_functions_vec(&functions[&called_function]).into_iter().collect(); - - for inner_call in inner_calls { - if let Some(inner_calls) = brillig_entry_points.get_mut(&entry_point) { - inner_calls.insert(inner_call); - } - - build_entry_points_map_recursive( - functions, - entry_point, - inner_call, - brillig_entry_points, - explored_functions.clone(), - ); +/// Returns all functions reachable from the given Brillig entry point. +/// Includes the entry point itself if it is recursive, otherwise excludes it. +fn brillig_reachable( + call_graph: &CallGraph, + recursive_functions: &HashSet, + func: FunctionId, +) -> BTreeSet { + let mut reachable = call_graph.reachable_from([func]); + if !recursive_functions.contains(&func) { + reachable.remove(&func); } + reachable.into_iter().collect() } /// Builds a mapping from a [`FunctionId`] to the set of [`FunctionId`s][`FunctionId`] of all the brillig entrypoints @@ -937,4 +915,194 @@ mod tests { } "#); } + + #[test] + fn functions_reachable_from_single_entry_point_are_not_duplicated() { + let src = " + g0 = Field 1 + + acir(inline) fn main f0 { + b0(v1: Field): + call f1(v1) + return + } + brillig(inline) fn entry_point f1 { + b0(v1: Field): + call f2(v1) + return + } + brillig(inline) fn helper_func f2 { + b0(v1: Field): + call f3(v1) + return + } + brillig(inline) fn leaf_func f3 { + b0(v1: Field): + v2 = add g0, v1 + return + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.brillig_entry_point_analysis(); + + // f2 and f3 are reachable from only one entry point, so they are not duplicated + assert_ssa_snapshot!(ssa, @r#" + g0 = Field 1 + + acir(inline) fn main f0 { + b0(v1: Field): + call f1(v1) + return + } + brillig(inline) fn entry_point f1 { + b0(v1: Field): + call f2(v1) + return + } + brillig(inline) fn helper_func f2 { + b0(v1: Field): + call f3(v1) + return + } + brillig(inline) fn leaf_func f3 { + b0(v1: Field): + v2 = add Field 1, v1 + return + } + "#); + } + + #[test] + fn idempotency() { + let src = " + g0 = Field 1 + g1 = Field 2 + g2 = Field 3 + + acir(inline) fn main f0 { + b0(v3: Field, v4: Field): + call f1(v3, v4) + call f2(v3, v4) + return + } + brillig(inline) fn entry_point_one f1 { + b0(v3: Field, v4: Field): + v5 = add g0, v3 + v6 = add v5, v4 + constrain v6 == Field 2 + call f3(v3, v4) + return + } + brillig(inline) fn entry_point_two f2 { + b0(v3: Field, v4: Field): + v5 = add g1, v3 + v6 = add v5, v4 + constrain v6 == Field 3 + call f3(v3, v4) + return + } + brillig(inline) fn inner_func f3 { + b0(v3: Field, v4: Field): + v5 = add g2, v3 + v6 = add v5, v4 + constrain v6 == Field 4 + return + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + let mut first_ssa = ssa.brillig_entry_point_analysis().remove_unreachable_functions(); + + // We expect `inner_func` to be duplicated + assert_ssa_snapshot!(&mut first_ssa, @r" + g0 = Field 1 + g1 = Field 2 + g2 = Field 3 + + acir(inline) fn main f0 { + b0(v3: Field, v4: Field): + call f1(v3, v4) + call f2(v3, v4) + return + } + brillig(inline) fn entry_point_one f1 { + b0(v3: Field, v4: Field): + v5 = add Field 1, v3 + v6 = add v5, v4 + constrain v6 == Field 2 + call f3(v3, v4) + return + } + brillig(inline) fn entry_point_two f2 { + b0(v3: Field, v4: Field): + v5 = add Field 2, v3 + v6 = add v5, v4 + constrain v6 == Field 3 + call f4(v3, v4) + return + } + brillig(inline) fn inner_func f3 { + b0(v3: Field, v4: Field): + v5 = add Field 3, v3 + v6 = add v5, v4 + constrain v6 == Field 4 + return + } + brillig(inline) fn inner_func f4 { + b0(v3: Field, v4: Field): + v5 = add Field 3, v3 + v6 = add v5, v4 + constrain v6 == Field 4 + return + } + "); + + let mut second_ssa = + first_ssa.brillig_entry_point_analysis().remove_unreachable_functions(); + + // We expect `inner_func` to be duplicated + assert_ssa_snapshot!(&mut second_ssa, @r" + g0 = Field 1 + g1 = Field 2 + g2 = Field 3 + + acir(inline) fn main f0 { + b0(v3: Field, v4: Field): + call f1(v3, v4) + call f2(v3, v4) + return + } + brillig(inline) fn entry_point_one f1 { + b0(v3: Field, v4: Field): + v5 = add Field 1, v3 + v6 = add v5, v4 + constrain v6 == Field 2 + call f3(v3, v4) + return + } + brillig(inline) fn entry_point_two f2 { + b0(v3: Field, v4: Field): + v5 = add Field 2, v3 + v6 = add v5, v4 + constrain v6 == Field 3 + call f4(v3, v4) + return + } + brillig(inline) fn inner_func f3 { + b0(v3: Field, v4: Field): + v5 = add Field 3, v3 + v6 = add v5, v4 + constrain v6 == Field 4 + return + } + brillig(inline) fn inner_func f4 { + b0(v3: Field, v4: Field): + v5 = add Field 3, v3 + v6 = add v5, v4 + constrain v6 == Field 4 + return + } + "); + } } diff --git a/compiler/noirc_evaluator/src/ssa/opt/pure.rs b/compiler/noirc_evaluator/src/ssa/opt/pure.rs index 7ac8e6c4d1a..0f5953e9959 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/pure.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/pure.rs @@ -29,7 +29,10 @@ impl Ssa { /// identified as calling known pure functions. #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn purity_analysis(mut self) -> Ssa { - let brillig_entry_points = get_brillig_entry_points(&self.functions, self.main_id); + let call_graph = CallGraph::from_ssa(&self); + + let brillig_entry_points = + get_brillig_entry_points(&self.functions, self.main_id, &call_graph); // First look through each function to get a baseline on its purity and collect // the functions it calls to build a call graph. @@ -44,7 +47,6 @@ impl Ssa { // Then transitively 'infect' any functions which call impure functions as also // impure. - let call_graph = CallGraph::from_ssa(&self); let purities = analyze_call_graph(call_graph, purities, &self.functions); let purities = Arc::new(purities);