diff --git a/compiler/noirc_evaluator/src/brillig/mod.rs b/compiler/noirc_evaluator/src/brillig/mod.rs index 3783af8ca56..7226c621c9b 100644 --- a/compiler/noirc_evaluator/src/brillig/mod.rs +++ b/compiler/noirc_evaluator/src/brillig/mod.rs @@ -38,7 +38,7 @@ pub struct BrilligOptions { /// Context structure for the brillig pass. /// It stores brillig-related data required for brillig generation. -#[derive(Default)] +#[derive(Default, Clone)] pub struct Brillig { /// Maps SSA function labels to their brillig artifact ssa_function_to_brillig: HashMap>, diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index bbbc43494a6..243c9dca917 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -7,7 +7,7 @@ //! This module heavily borrows from Cranelift use std::{ - collections::{BTreeMap, BTreeSet}, + collections::{BTreeMap, BTreeSet, HashMap}, fs::File, io::Write, path::{Path, PathBuf}, @@ -15,7 +15,7 @@ use std::{ use crate::{ acir::ssa::Artifacts, - brillig::BrilligOptions, + brillig::{Brillig, BrilligOptions}, errors::{RuntimeError, SsaReport}, }; use acvm::{ @@ -91,17 +91,142 @@ pub struct SsaEvaluatorOptions { pub max_bytecode_increase_percent: Option, } +/// An SSA pass reified as a construct we can put into a list, +/// which facilitates equivalence testing between different +/// stages of the processing pipeline. +pub struct SsaPass<'a> { + msg: &'static str, + run: Box Result + 'a>, +} + +impl<'a> SsaPass<'a> { + pub(crate) fn new(f: F, msg: &'static str) -> Self + where + F: Fn(Ssa) -> Ssa + 'a, + { + Self::new_try(move |ssa| Ok(f(ssa)), msg) + } + + pub(crate) fn new_try(f: F, msg: &'static str) -> Self + where + F: Fn(Ssa) -> Result + 'a, + { + Self { msg, run: Box::new(f) } + } +} + pub(crate) struct ArtifactsAndWarnings(Artifacts, Vec); +/// The default SSA optimization pipeline. +/// +/// After these passes everything is ready for execution, which is +/// something we take can advantage of in the [secondary_passes]. +pub fn primary_passes(options: &SsaEvaluatorOptions) -> Vec { + vec![ + SsaPass::new(Ssa::remove_unreachable_functions, "Removing Unreachable Functions"), + SsaPass::new(Ssa::defunctionalize, "Defunctionalization"), + SsaPass::new(Ssa::inline_simple_functions, "Inlining simple functions"), + // BUG: Enabling this mem2reg causes an integration test failure in aztec-package; see: + // https://github.com/AztecProtocol/aztec-packages/pull/11294#issuecomment-2622809518 + //SsaPass::new(Ssa::mem2reg, "Mem2Reg (1st)"), + SsaPass::new(Ssa::remove_paired_rc, "Removing Paired rc_inc & rc_decs"), + SsaPass::new( + move |ssa| ssa.preprocess_functions(options.inliner_aggressiveness), + "Preprocessing Functions", + ), + SsaPass::new(move |ssa| ssa.inline_functions(options.inliner_aggressiveness), "Inlining"), + // Run mem2reg with the CFG separated into blocks + SsaPass::new(Ssa::mem2reg, "Mem2Reg"), + SsaPass::new(Ssa::simplify_cfg, "Simplifying"), + SsaPass::new(Ssa::as_slice_optimization, "`as_slice` optimization"), + SsaPass::new(Ssa::remove_unreachable_functions, "Removing Unreachable Functions"), + SsaPass::new_try( + Ssa::evaluate_static_assert_and_assert_constant, + "`static_assert` and `assert_constant`", + ), + SsaPass::new(Ssa::purity_analysis, "Purity Analysis"), + SsaPass::new(Ssa::loop_invariant_code_motion, "Loop Invariant Code Motion"), + SsaPass::new_try( + move |ssa| ssa.unroll_loops_iteratively(options.max_bytecode_increase_percent), + "Unrolling", + ), + SsaPass::new(Ssa::simplify_cfg, "Simplifying"), + SsaPass::new(Ssa::mem2reg, "Mem2Reg"), + SsaPass::new(Ssa::flatten_cfg, "Flattening"), + SsaPass::new(Ssa::remove_bit_shifts, "Removing Bit Shifts"), + // Run mem2reg once more with the flattened CFG to catch any remaining loads/stores + SsaPass::new(Ssa::mem2reg, "Mem2Reg"), + // Run the inlining pass again to handle functions with `InlineType::NoPredicates`. + // Before flattening is run, we treat functions marked with the `InlineType::NoPredicates` as an entry point. + // This pass must come immediately following `mem2reg` as the succeeding passes + // may create an SSA which inlining fails to handle. + SsaPass::new( + move |ssa| ssa.inline_functions_with_no_predicates(options.inliner_aggressiveness), + "Inlining", + ), + SsaPass::new(Ssa::remove_if_else, "Remove IfElse"), + SsaPass::new(Ssa::purity_analysis, "Purity Analysis"), + SsaPass::new(Ssa::fold_constants, "Constant Folding"), + SsaPass::new(Ssa::flatten_basic_conditionals, "Simplify conditionals for unconstrained"), + SsaPass::new(Ssa::remove_enable_side_effects, "EnableSideEffectsIf removal"), + SsaPass::new(Ssa::fold_constants_using_constraints, "Constraint Folding"), + SsaPass::new(Ssa::make_constrain_not_equal_instructions, "Adding constrain not equal"), + SsaPass::new(Ssa::check_u128_mul_overflow, "Check u128 mul overflow"), + SsaPass::new(Ssa::dead_instruction_elimination, "Dead Instruction Elimination"), + SsaPass::new(Ssa::simplify_cfg, "Simplifying"), + SsaPass::new(Ssa::array_set_optimization, "Array Set Optimizations"), + // The Brillig globals pass expected that we have the used globals map set for each function. + // The used globals map is determined during DIE, so we should duplicate entry points before a DIE pass run. + SsaPass::new(Ssa::brillig_entry_point_analysis, "Brillig Entry Point Analysis"), + // Remove any potentially unnecessary duplication from the Brillig entry point analysis. + SsaPass::new(Ssa::remove_unreachable_functions, "Removing Unreachable Functions"), + SsaPass::new(Ssa::remove_truncate_after_range_check, "Removing Truncate after RangeCheck"), + // This pass makes transformations specific to Brillig generation. + // It must be the last pass to either alter or add new instructions before Brillig generation, + // as other semantics in the compiler can potentially break (e.g. inserting instructions). + // We can safely place the pass before DIE as that pass only removes instructions. + // We also need DIE's tracking of used globals in case the array get transformations + // end up using an existing constant from the globals space. + SsaPass::new(Ssa::brillig_array_gets, "Brillig Array Get Optimizations"), + SsaPass::new(Ssa::dead_instruction_elimination, "Dead Instruction Elimination"), + ] +} + +/// The second SSA pipeline, in which we take the Brillig functions compiled after +/// the primary pipeline, and execute the ones with all-constant arguments, +/// to replace the calls with the return value. +pub fn secondary_passes(brillig: &Brillig) -> Vec { + vec![ + SsaPass::new(move |ssa| ssa.fold_constants_with_brillig(brillig), "Inlining Brillig Calls"), + // It could happen that we inlined all calls to a given brillig function. + // In that case it's unused so we can remove it. This is what we check next. + SsaPass::new(Ssa::remove_unreachable_functions, "Removing Unreachable Functions"), + SsaPass::new(Ssa::dead_instruction_elimination_acir, "Dead Instruction Elimination"), + ] +} + /// Optimize the given program by converting it into SSA /// form and performing optimizations there. When finished, /// convert the final SSA into an ACIR program and return it. /// An ACIR program is made up of both ACIR functions /// and Brillig functions for unconstrained execution. -pub(crate) fn optimize_into_acir( +/// +/// The `primary` SSA passes are applied on the initial SSA. +/// Then we compile the Brillig functions, and use the output +/// to run a `secondary` pass, which can use the Brillig +/// artifacts to do constant folding. +/// +/// See the [primary_passes] and [secondary_passes] for +/// the default implementations. +fn optimize_into_acir( program: Program, options: &SsaEvaluatorOptions, -) -> Result { + primary: &[SsaPass], + secondary: S, +) -> Result +where + S: for<'b> Fn(&'b Brillig) -> Vec>, +{ let ssa_gen_span = span!(Level::TRACE, "ssa_generation"); let ssa_gen_span_guard = ssa_gen_span.enter(); let builder = SsaBuilder::new( @@ -111,7 +236,9 @@ pub(crate) fn optimize_into_acir( &options.emit_ssa, )?; - let mut ssa = optimize_all(builder, options)?; + let mut builder = builder.run_passes(primary)?; + let passed = std::mem::take(&mut builder.passed); + let mut ssa = builder.finish(); let mut ssa_level_warnings = vec![]; @@ -129,12 +256,9 @@ pub(crate) fn optimize_into_acir( ssa, ssa_logging: options.ssa_logging.clone(), print_codegen_timings: options.print_codegen_timings, + passed, } - .run_pass(|ssa| ssa.fold_constants_with_brillig(&brillig), "Inlining Brillig Calls Inlining") - // It could happen that we inlined all calls to a given brillig function. - // In that case it's unused so we can remove it. This is what we check next. - .run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (4th)") - .run_pass(Ssa::dead_instruction_elimination_acir, "Dead Instruction Elimination (3rd)") + .run_passes(&secondary(&brillig))? .finish(); if !options.skip_underconstrained_check { @@ -166,78 +290,6 @@ pub(crate) fn optimize_into_acir( Ok(ArtifactsAndWarnings(artifacts, ssa_level_warnings)) } -/// Run all SSA passes. -fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result { - Ok(builder - .run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (1st)") - .run_pass(Ssa::defunctionalize, "Defunctionalization") - .run_pass(Ssa::inline_simple_functions, "Inlining simple functions") - // BUG: Enabling this mem2reg causes an integration test failure in aztec-package; see: - // https://github.com/AztecProtocol/aztec-packages/pull/11294#issuecomment-2622809518 - //.run_pass(Ssa::mem2reg, "Mem2Reg (1st)") - .run_pass(Ssa::remove_paired_rc, "Removing Paired rc_inc & rc_decs") - .run_pass( - |ssa| ssa.preprocess_functions(options.inliner_aggressiveness), - "Preprocessing Functions", - ) - .run_pass(|ssa| ssa.inline_functions(options.inliner_aggressiveness), "Inlining (1st)") - // Run mem2reg with the CFG separated into blocks - .run_pass(Ssa::mem2reg, "Mem2Reg (2nd)") - .run_pass(Ssa::simplify_cfg, "Simplifying (1st)") - .run_pass(Ssa::as_slice_optimization, "`as_slice` optimization") - .run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (2nd)") - .try_run_pass( - Ssa::evaluate_static_assert_and_assert_constant, - "`static_assert` and `assert_constant`", - )? - .run_pass(Ssa::purity_analysis, "Purity Analysis") - .run_pass(Ssa::loop_invariant_code_motion, "Loop Invariant Code Motion") - .try_run_pass( - |ssa| ssa.unroll_loops_iteratively(options.max_bytecode_increase_percent), - "Unrolling", - )? - .run_pass(Ssa::simplify_cfg, "Simplifying (2nd)") - .run_pass(Ssa::mem2reg, "Mem2Reg (3rd)") - .run_pass(Ssa::flatten_cfg, "Flattening") - .run_pass(Ssa::remove_bit_shifts, "Removing Bit Shifts") - // Run mem2reg once more with the flattened CFG to catch any remaining loads/stores - .run_pass(Ssa::mem2reg, "Mem2Reg (4th)") - // Run the inlining pass again to handle functions with `InlineType::NoPredicates`. - // Before flattening is run, we treat functions marked with the `InlineType::NoPredicates` as an entry point. - // This pass must come immediately following `mem2reg` as the succeeding passes - // may create an SSA which inlining fails to handle. - .run_pass( - |ssa| ssa.inline_functions_with_no_predicates(options.inliner_aggressiveness), - "Inlining (2nd)", - ) - .run_pass(Ssa::remove_if_else, "Remove IfElse") - .run_pass(Ssa::purity_analysis, "Purity Analysis (2nd)") - .run_pass(Ssa::fold_constants, "Constant Folding") - .run_pass(Ssa::flatten_basic_conditionals, "Simplify conditionals for unconstrained") - .run_pass(Ssa::remove_enable_side_effects, "EnableSideEffectsIf removal") - .run_pass(Ssa::fold_constants_using_constraints, "Constraint Folding") - .run_pass(Ssa::make_constrain_not_equal_instructions, "Adding constrain not equal") - .run_pass(Ssa::check_u128_mul_overflow, "Check u128 mul overflow") - .run_pass(Ssa::dead_instruction_elimination, "Dead Instruction Elimination (1st)") - .run_pass(Ssa::simplify_cfg, "Simplifying (3rd):") - .run_pass(Ssa::array_set_optimization, "Array Set Optimizations") - // The Brillig globals pass expected that we have the used globals map set for each function. - // The used globals map is determined during DIE, so we should duplicate entry points before a DIE pass run. - .run_pass(Ssa::brillig_entry_point_analysis, "Brillig Entry Point Analysis") - // Remove any potentially unnecessary duplication from the Brillig entry point analysis. - .run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (3rd)") - .run_pass(Ssa::remove_truncate_after_range_check, "Removing Truncate after RangeCheck") - // This pass makes transformations specific to Brillig generation. - // It must be the last pass to either alter or add new instructions before Brillig generation, - // as other semantics in the compiler can potentially break (e.g. inserting instructions). - // We can safely place the pass before DIE as that pass only removes instructions. - // We also need DIE's tracking of used globals in case the array get transformations - // end up using an existing constant from the globals space. - .run_pass(Ssa::brillig_array_gets, "Brillig Array Get Optimizations") - .run_pass(Ssa::dead_instruction_elimination, "Dead Instruction Elimination (2nd)") - .finish()) -} - // Helper to time SSA passes fn time(name: &str, print_timings: bool, f: impl FnOnce() -> T) -> T { let start_time = chrono::Utc::now().time(); @@ -308,6 +360,21 @@ pub fn create_program( program: Program, options: &SsaEvaluatorOptions, ) -> Result { + create_program_with_passes(program, options, &primary_passes(options), secondary_passes) +} + +/// Compiles the [`Program`] into [`ACIR`][acvm::acir::circuit::Program] by running it through +/// `primary` and `secondary` SSA passes. +#[tracing::instrument(level = "trace", skip_all)] +pub fn create_program_with_passes( + program: Program, + options: &SsaEvaluatorOptions, + primary: &[SsaPass], + secondary: S, +) -> Result +where + S: for<'b> Fn(&'b Brillig) -> Vec>, +{ let debug_variables = program.debug_variables.clone(); let debug_types = program.debug_types.clone(); let debug_functions = program.debug_functions.clone(); @@ -317,7 +384,7 @@ pub fn create_program( let ArtifactsAndWarnings( (generated_acirs, generated_brillig, brillig_function_names, error_types), ssa_level_warnings, - ) = optimize_into_acir(program, options)?; + ) = optimize_into_acir(program, options, primary, secondary)?; assert_eq!( generated_acirs.len(), @@ -480,6 +547,7 @@ struct SsaBuilder { ssa: Ssa, ssa_logging: SsaLogging, print_codegen_timings: bool, + passed: HashMap, } impl SsaBuilder { @@ -499,14 +567,26 @@ impl SsaBuilder { let ssa_path = emit_ssa.with_extension("ssa.json"); write_to_file(&serde_json::to_vec(&ssa).unwrap(), &ssa_path); } - Ok(SsaBuilder { ssa_logging, print_codegen_timings, ssa }.print("Initial SSA")) + let builder = + SsaBuilder { ssa_logging, print_codegen_timings, ssa, passed: Default::default() }; + let builder = builder.print("Initial SSA"); + Ok(builder) } fn finish(self) -> Ssa { self.ssa.generate_entry_point_index() } + /// Run a list of SSA passes. + fn run_passes(mut self, passes: &[SsaPass]) -> Result { + for pass in passes { + self = self.try_run_pass(|ssa| (pass.run)(ssa), pass.msg)?; + } + Ok(self) + } + /// Runs the given SSA pass and prints the SSA afterward if `print_ssa_passes` is true. + #[allow(dead_code)] fn run_pass(mut self, pass: F, msg: &str) -> Self where F: FnOnce(Ssa) -> Ssa, @@ -525,6 +605,9 @@ impl SsaBuilder { } fn print(mut self, msg: &str) -> Self { + // Count the number of times we have seen this message. + let cnt = self.passed.entry(msg.to_string()).and_modify(|cnt| *cnt += 1).or_insert(1); + // Always normalize if we are going to print at least one of the passes if !matches!(self.ssa_logging, SsaLogging::None) { self.ssa.normalize_ids(); @@ -541,7 +624,7 @@ impl SsaBuilder { } }; if print_ssa_pass { - println!("After {msg}:\n{}", self.ssa); + println!("After {msg} ({cnt}):\n{}", self.ssa); } self } diff --git a/compiler/noirc_evaluator/src/ssa/opt/hint.rs b/compiler/noirc_evaluator/src/ssa/opt/hint.rs index 3a70a9939a9..679bc792eca 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/hint.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/hint.rs @@ -6,7 +6,7 @@ mod tests { assert_ssa_snapshot, brillig::BrilligOptions, errors::RuntimeError, - ssa::{Ssa, SsaBuilder, SsaEvaluatorOptions, SsaLogging, optimize_all}, + ssa::{Ssa, SsaBuilder, SsaEvaluatorOptions, SsaLogging, primary_passes}, }; fn run_all_passes(ssa: Ssa) -> Result { @@ -27,9 +27,10 @@ mod tests { ssa, ssa_logging: options.ssa_logging.clone(), print_codegen_timings: false, + passed: Default::default(), }; - optimize_all(builder, options) + Ok(builder.run_passes(&primary_passes(options))?.finish()) } /// Test that the `std::hint::black_box` function prevents some of the optimizations. diff --git a/tooling/ast_fuzzer/fuzz/src/lib.rs b/tooling/ast_fuzzer/fuzz/src/lib.rs index 2e192ff7fd5..dfddf0b118a 100644 --- a/tooling/ast_fuzzer/fuzz/src/lib.rs +++ b/tooling/ast_fuzzer/fuzz/src/lib.rs @@ -3,6 +3,7 @@ use color_eyre::eyre; use noir_ast_fuzzer::DisplayAstAsNoir; use noir_ast_fuzzer::compare::{CompareResult, CompareSsa}; use noirc_abi::input_parser::Format; +use noirc_evaluator::ssa::{primary_passes, secondary_passes}; use noirc_evaluator::{ brillig::BrilligOptions, ssa::{self, SsaEvaluatorOptions, SsaProgramArtifact}, @@ -50,12 +51,13 @@ pub fn create_ssa_or_die( eprintln!("---\n{}\n---", program); } - ssa::create_program(program, options).unwrap_or_else(|e| { - panic!( - "failed to compile program: {}{e}", - msg.map(|s| format!("{s}: ")).unwrap_or_default() - ) - }) + ssa::create_program_with_passes(program, options, &primary_passes(options), secondary_passes) + .unwrap_or_else(|e| { + panic!( + "failed to compile program: {}{e}", + msg.map(|s| format!("{s}: ")).unwrap_or_default() + ) + }) } /// Compare the execution result and print the inputs if the result is a failure.