diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 935918c6b7e..a6e66764418 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -211,6 +211,11 @@ fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result Result Ssa { + // We first gather all calls to brillig functions that have some constants in them, + // together with how many calls were done to it (in total, and with a given set of constants) + + // Calls to a given function with arguments where some might be constants + // function_id -> (constants -> count) + let mut calls = HashMap::new(); + + // Count of all calls to a given function + // function_id -> count + let mut total_calls = HashMap::new(); + + for function in self.functions.values() { + if !function.runtime().is_acir() { + continue; + } + + function.gather_calls_to_brillig_functions_with_constants( + &self, + &mut calls, + &mut total_calls, + ); + } + + // Now we determine which constants we are going to inline. + // The rule we'll use is: if a given set of constants was used more than 30% + // of the time across all calls to a given function, we'll create a specific + // function with those constants inlined. + calls.retain(|func_id, entries| { + let total_count = total_calls[func_id] as f64; + entries.retain(|_, count| (*count as f64 / total_count) >= 0.3); + !entries.is_empty() + }); + + // Next, create specialized functions where those constants are inlined + // function_id -> (constants -> new_function_id) + let mut new_functions: HashMap>, FunctionId>> = + HashMap::new(); + + for (func_id, entries) in calls { + let function = self.functions[&func_id].clone(); + let function_num_instructions = function.num_instructions(); + for (constants, _) in entries { + let Some(new_function_id) = self.maybe_add_fn(|func_id| { + let new_function = + inline_constants_into_function(&function, &constants, func_id); + // No point in using the new function if it's not more optimal + if new_function.num_instructions() < function_num_instructions { + Some(new_function) + } else { + None + } + }) else { + continue; + }; + let entry = new_functions.entry(func_id).or_default(); + entry.entry(constants).insert_entry(new_function_id); + } + } + + // Finally, redirect calls to use the new functions + for function in self.functions.values_mut() { + if !function.runtime().is_acir() { + continue; + } + + function.replace_brillig_calls_with_constants(&new_functions); + } + + self + } +} + +#[derive(Debug, Hash, PartialEq, Eq)] +enum Constant { + Number(FieldElement, NumericType), + Array(Vec, Type), +} + +impl Function { + fn gather_calls_to_brillig_functions_with_constants( + &self, + ssa: &Ssa, + calls: &mut HashMap>, usize>>, + total_calls: &mut HashMap, + ) { + for block in self.reachable_blocks() { + for instruction_id in self.dfg[block].instructions() { + let instruction = &self.dfg[*instruction_id]; + let Instruction::Call { func, arguments } = instruction else { + continue; + }; + + let Value::Function(func_id) = self.dfg[*func] else { + continue; + }; + + let func = &ssa.functions[&func_id]; + if !func.runtime().is_brillig() { + continue; + } + + *total_calls.entry(func_id).or_default() += 1; + + if !arguments.iter().any(|argument| self.dfg.is_constant(*argument)) { + continue; + } + + let constants = vecmap(arguments, |argument| get_constant(*argument, &self.dfg)); + *calls.entry(func_id).or_default().entry(constants).or_default() += 1; + } + } + } + + fn replace_brillig_calls_with_constants( + &mut self, + functions: &HashMap>, FunctionId>>, + ) { + for block in self.reachable_blocks() { + let instruction_ids = self.dfg[block].take_instructions(); + for instruction_id in instruction_ids { + let instruction = &self.dfg[instruction_id]; + let Instruction::Call { func, arguments } = instruction else { + self.dfg[block].insert_instruction(instruction_id); + continue; + }; + + let Value::Function(func_id) = self.dfg[*func] else { + self.dfg[block].insert_instruction(instruction_id); + continue; + }; + + let Some(entries) = functions.get(&func_id) else { + self.dfg[block].insert_instruction(instruction_id); + continue; + }; + + if !arguments.iter().any(|argument| self.dfg.is_constant(*argument)) { + self.dfg[block].insert_instruction(instruction_id); + continue; + } + + let constants = vecmap(arguments, |argument| get_constant(*argument, &self.dfg)); + let Some(new_function_id) = entries.get(&constants) else { + self.dfg[block].insert_instruction(instruction_id); + continue; + }; + + let mut new_arguments = Vec::new(); + for (index, constant) in constants.iter().enumerate() { + if constant.is_none() { + new_arguments.push(arguments[index]); + } + } + + let new_function_id = self.dfg.import_function(*new_function_id); + let new_instruction = + Instruction::Call { func: new_function_id, arguments: new_arguments }; + let call_stack = self.dfg.get_instruction_call_stack_id(instruction_id); + let old_results = self.dfg.instruction_results(instruction_id); + let old_results = old_results.to_vec(); + let typevars = old_results + .iter() + .map(|value| self.dfg.type_of_value(*value)) + .collect::>(); + + let new_results = self.dfg.insert_instruction_and_results( + new_instruction, + block, + Some(typevars), + call_stack, + ); + let new_results = new_results.results().iter().copied().collect::>(); + for (old_result, new_result) in old_results.into_iter().zip(new_results) { + self.dfg.set_value_from_id(old_result, new_result); + } + } + } + } +} + +fn get_constant(value: ValueId, dfg: &DataFlowGraph) -> Option { + if let Some((value, typ)) = dfg.get_numeric_constant_with_type(value) { + return Some(Constant::Number(value, typ)); + } + + if let Some((values, typ)) = dfg.get_array_constant(value) { + let mut constants = Vec::with_capacity(values.len()); + for value in values { + constants.push(get_constant(value, dfg)?); + } + return Some(Constant::Array(constants, typ)); + } + + None +} + +fn inline_constants_into_function( + function: &Function, + constants: &[Option], + id: FunctionId, +) -> Function { + let mut function = Function::clone_with_id(id, function); + let entry_block_id = function.entry_block(); + + // Take the entry block instructions as we first might need to insert a few MakeArray instructions + // and they must appear before everything else. + let entry_block_instructions = function.dfg[entry_block_id].take_instructions(); + + let parameters = function.parameters().to_vec(); + + // First replace all constant parameters + for (parameter, constant) in parameters.iter().zip(constants) { + if let Some(constant) = constant { + let constant = make_constant(&mut function.dfg, constant, entry_block_id); + function.dfg.set_value_from_id(*parameter, constant); + } + } + + function.dfg[entry_block_id].instructions_mut().extend(entry_block_instructions); + + // Then keep only those parameters for which the argument is not a constant + let mut new_parameters = Vec::new(); + for (index, constant) in constants.iter().enumerate() { + if constant.is_none() { + new_parameters.push(parameters[index]); + } + } + let entry_block = &mut function.dfg[entry_block_id]; + entry_block.set_parameters(new_parameters); + + // Next, optimize the function a bit... + + // Help unrolling determine bounds. + function.as_slice_optimization(); + // Prepare for unrolling + function.loop_invariant_code_motion(); + // We might not be able to unroll all loops without fully inlining them, so ignore errors. + let _ = function.unroll_loops_iteratively(); + // Reduce the number of redundant stores/loads after unrolling + function.mem2reg(); + // Try to reduce the number of blocks. + function.simplify_function(); + // Remove leftover instructions. + function.dead_instruction_elimination(true, false, false); + + function +} + +fn make_constant(dfg: &mut DataFlowGraph, constant: &Constant, block: BasicBlockId) -> ValueId { + match constant { + Constant::Number(value, typ) => dfg.make_constant(*value, *typ), + Constant::Array(constants, typ) => { + let elements = + constants.iter().map(|constant| make_constant(dfg, constant, block)).collect(); + let instruction = Instruction::MakeArray { elements, typ: typ.clone() }; + // TODO: call stack + dfg.insert_instruction_and_results(instruction, block, None, CallStackId::root()) + .first() + } + } +} + +#[cfg(test)] +mod tests { + use crate::ssa::{opt::assert_normalized_ssa_equals, ssa_gen::Ssa}; + + #[test] + fn inlines_if_same_constant_is_always_used() { + let src = " + acir(inline) fn main f0 { + b0(v0: Field): + v3 = call f1(Field 1, v0) -> Field + v4 = call f1(Field 1, v0) -> Field + v5 = add v3, v4 + return v5 + } + brillig(inline) fn foo f1 { + b0(v0: Field, v1: Field): + v3 = add v0, Field 1 + return v3 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + + let expected = " + acir(inline) fn main f0 { + b0(v0: Field): + v2 = call f2(v0) -> Field + v3 = call f2(v0) -> Field + v4 = add v2, v3 + return v4 + } + brillig(inline) fn foo f1 { + b0(v0: Field, v1: Field): + v3 = add v0, Field 1 + return v3 + } + brillig(inline) fn foo f2 { + b0(v0: Field): + return Field 2 + } + "; + let ssa = ssa.inline_constants_into_brillig_functions(); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inlines_if_same_array_is_always_used() { + let src = " + acir(inline) fn main f0 { + b0(v0: Field): + v2 = make_array [Field 1, Field 2]: [Field; 2] + v3 = call f1(v2, v0) -> Field + v4 = make_array [Field 1, Field 2]: [Field; 2] + v5 = call f1(v4, v0) -> Field + v6 = add v3, v5 + return v6 + } + brillig(inline) fn foo f1 { + b0(v0: [Field; 2], v1: Field): + v2 = array_get v0, index u32 0 -> Field + v3 = add v2, v1 + return v3 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + + let expected = " + acir(inline) fn main f0 { + b0(v0: Field): + v3 = make_array [Field 1, Field 2] : [Field; 2] + v5 = call f2(v0) -> Field + v6 = make_array [Field 1, Field 2] : [Field; 2] + v7 = call f2(v0) -> Field + v8 = add v5, v7 + return v8 + } + brillig(inline) fn foo f1 { + b0(v0: [Field; 2], v1: Field): + v3 = array_get v0, index u32 0 -> Field + v4 = add v3, v1 + return v4 + } + brillig(inline) fn foo f2 { + b0(v0: Field): + v2 = add Field 1, v0 + return v2 + } + "; + let ssa = ssa.inline_constants_into_brillig_functions(); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn does_not_inline_if_inlined_function_does_not_have_less_instructions() { + let src = " + acir(inline) fn main f0 { + b0(v0: Field): + v3 = call f1(Field 1, v0) -> Field + v4 = call f1(Field 1, v0) -> Field + v5 = add v3, v4 + return v5 + } + brillig(inline) fn foo f1 { + b0(v0: Field, v1: Field): + v2 = add v0, v1 + return v2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.inline_constants_into_brillig_functions(); + assert_normalized_ssa_equals(ssa, src); + } + + #[test] + fn does_not_inline_brillig_call_into_brillig_function() { + let src = " + brillig(inline) fn main f0 { + b0(v0: Field): + v3 = call f1(Field 1, v0) -> Field + v4 = call f1(Field 1, v0) -> Field + v5 = add v3, v4 + return v5 + } + brillig(inline) fn foo f1 { + b0(v0: Field, v1: Field): + v2 = add v0, v1 + return v2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + + let ssa = ssa.inline_constants_into_brillig_functions(); + assert_normalized_ssa_equals(ssa, src); + } +} diff --git a/compiler/noirc_evaluator/src/ssa/opt/mod.rs b/compiler/noirc_evaluator/src/ssa/opt/mod.rs index a9784d4c7cf..4f8514da6f1 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mod.rs @@ -16,6 +16,7 @@ mod defunctionalize; mod die; pub(crate) mod flatten_cfg; mod hint; +mod inline_constants_into_brillig_functions; pub(crate) mod inlining; mod loop_invariant; mod make_constrain_not_equal; diff --git a/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs index 3d812870c06..c616468e6e5 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs @@ -221,7 +221,8 @@ fn check_for_negated_jmpif_condition( call_stack, }) = function.dfg[block].terminator() { - if let Value::Instruction { instruction, .. } = function.dfg[*condition] { + let condition = function.dfg.resolve(*condition); + if let Value::Instruction { instruction, .. } = function.dfg[condition] { if let Instruction::Not(negated_condition) = function.dfg[instruction] { let call_stack = *call_stack; let jmpif = TerminatorInstruction::JmpIf { diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs index ad52473620d..94c2bb579bd 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs @@ -85,6 +85,17 @@ impl Ssa { new_id } + /// Adds a new function to the program, but only if the given lambda returns `Some`. + pub(crate) fn maybe_add_fn( + &mut self, + build_with_id: impl FnOnce(FunctionId) -> Option, + ) -> Option { + let new_id = self.next_id.next(); + let function = build_with_id(new_id)?; + self.functions.insert(new_id, function); + Some(new_id) + } + pub(crate) fn generate_entry_point_index(mut self) -> Self { let entry_points = self.functions.keys().filter(|function| self.is_entry_point(**function)).enumerate(); diff --git a/test_programs/execution_success/reference_counts/src/main.nr b/test_programs/execution_success/reference_counts/src/main.nr index 2ee8a13f7a4..276210ef58a 100644 --- a/test_programs/execution_success/reference_counts/src/main.nr +++ b/test_programs/execution_success/reference_counts/src/main.nr @@ -1,8 +1,8 @@ -use std::mem::array_refcount; +use std::{hint::black_box, mem::array_refcount}; fn main() { let mut array = [0, 1, 2]; - assert_refcount(array, 1); + assert_refcount(black_box(array), 1); borrow(array, array_refcount(array)); borrow_mut(&mut array, array_refcount(array)); @@ -22,14 +22,14 @@ fn main() { } fn borrow(array: [Field; 3], rc_before_call: u32) { - assert_refcount(array, rc_before_call); + assert_refcount(black_box(array), rc_before_call); println(array[0]); } fn borrow_mut(array: &mut [Field; 3], rc_before_call: u32) { // Optimization: inc_rc isn't needed since there is only one array (`array`) // of the same type that `array` can be modified through - assert_refcount(*array, rc_before_call + 0); + assert_refcount(black_box(*array), rc_before_call + 0); array[0] = 3; println(array[0]); } @@ -37,7 +37,7 @@ fn borrow_mut(array: &mut [Field; 3], rc_before_call: u32) { // Returning a copy of the array, otherwise the SSA can end up optimizing away // the `array_set`, with the whole body just becoming basically `println(4);`. fn copy_mut(mut array: [Field; 3], rc_before_call: u32) -> [Field; 3] { - assert_refcount(array, rc_before_call + 1); + assert_refcount(black_box(array), rc_before_call + 1); array[0] = 4; println(array[0]); array @@ -46,8 +46,8 @@ fn copy_mut(mut array: [Field; 3], rc_before_call: u32) -> [Field; 3] { /// Borrow the same array mutably through both parameters, inc_rc is necessary here, although /// only one is needed to bring the rc from 1 to 2. fn borrow_mut_two(array1: &mut [Field; 3], array2: &mut [Field; 3], rc_before_call: u32) { - assert_refcount(*array1, rc_before_call + 1); - assert_refcount(*array2, rc_before_call + 1); + assert_refcount(black_box(*array1), rc_before_call + 1); + assert_refcount(black_box(*array2), rc_before_call + 1); array1[0] = 5; array2[0] = 6; println(array1[0]); // array1 & 2 alias, so this should also print 6 @@ -62,8 +62,8 @@ fn borrow_mut_two_separate( rc_before_call1: u32, rc_before_call2: u32, ) { - assert_refcount(*array1, rc_before_call1 + 0); - assert_refcount(*array2, rc_before_call2 + 0); + assert_refcount(black_box(*array1), rc_before_call1 + 0); + assert_refcount(black_box(*array2), rc_before_call2 + 0); array1[0] = 7; array2[0] = 8; println(array1[0]);