Skip to content
Merged
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/brillig/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub struct BrilligOptions {

/// Context structure for the brillig pass.
/// It stores brillig-related data required for brillig generation.
#[derive(Default)]
#[derive(Default, Clone)]
pub struct Brillig {
/// Maps SSA function labels to their brillig artifact
ssa_function_to_brillig: HashMap<FunctionId, BrilligArtifact<FieldElement>>,
Expand Down
253 changes: 168 additions & 85 deletions compiler/noirc_evaluator/src/ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
//! This module heavily borrows from Cranelift

use std::{
collections::{BTreeMap, BTreeSet},
collections::{BTreeMap, BTreeSet, HashMap},
fs::File,
io::Write,
path::{Path, PathBuf},
};

use crate::{
acir::ssa::Artifacts,
brillig::BrilligOptions,
brillig::{Brillig, BrilligOptions},
errors::{RuntimeError, SsaReport},
};
use acvm::{
Expand Down Expand Up @@ -91,17 +91,142 @@ pub struct SsaEvaluatorOptions {
pub max_bytecode_increase_percent: Option<i32>,
}

/// An SSA pass reified as a construct we can put into a list,
/// which facilitates equivalence testing between different
/// stages of the processing pipeline.
pub struct SsaPass<'a> {
msg: &'static str,
run: Box<dyn Fn(Ssa) -> Result<Ssa, RuntimeError> + 'a>,
}

impl<'a> SsaPass<'a> {
pub(crate) fn new<F>(f: F, msg: &'static str) -> Self
where
F: Fn(Ssa) -> Ssa + 'a,
{
Self::new_try(move |ssa| Ok(f(ssa)), msg)
}

pub(crate) fn new_try<F>(f: F, msg: &'static str) -> Self
where
F: Fn(Ssa) -> Result<Ssa, RuntimeError> + 'a,
{
Self { msg, run: Box::new(f) }
}
}

pub(crate) struct ArtifactsAndWarnings(Artifacts, Vec<SsaReport>);

/// The default SSA optimization pipeline.
///
/// After these passes everything is ready for execution, which is
/// something we take can advantage of in the [secondary_passes].
pub fn primary_passes(options: &SsaEvaluatorOptions) -> Vec<SsaPass> {
vec![
SsaPass::new(Ssa::remove_unreachable_functions, "Removing Unreachable Functions"),
SsaPass::new(Ssa::defunctionalize, "Defunctionalization"),
SsaPass::new(Ssa::inline_simple_functions, "Inlining simple functions"),
// BUG: Enabling this mem2reg causes an integration test failure in aztec-package; see:
// https://github.com/AztecProtocol/aztec-packages/pull/11294#issuecomment-2622809518
//SsaPass::new(Ssa::mem2reg, "Mem2Reg (1st)"),
SsaPass::new(Ssa::remove_paired_rc, "Removing Paired rc_inc & rc_decs"),
SsaPass::new(
move |ssa| ssa.preprocess_functions(options.inliner_aggressiveness),
"Preprocessing Functions",
),
SsaPass::new(move |ssa| ssa.inline_functions(options.inliner_aggressiveness), "Inlining"),
// Run mem2reg with the CFG separated into blocks
SsaPass::new(Ssa::mem2reg, "Mem2Reg"),
SsaPass::new(Ssa::simplify_cfg, "Simplifying"),
SsaPass::new(Ssa::as_slice_optimization, "`as_slice` optimization"),
SsaPass::new(Ssa::remove_unreachable_functions, "Removing Unreachable Functions"),
SsaPass::new_try(
Ssa::evaluate_static_assert_and_assert_constant,
"`static_assert` and `assert_constant`",
),
SsaPass::new(Ssa::purity_analysis, "Purity Analysis"),
SsaPass::new(Ssa::loop_invariant_code_motion, "Loop Invariant Code Motion"),
SsaPass::new_try(
move |ssa| ssa.unroll_loops_iteratively(options.max_bytecode_increase_percent),
"Unrolling",
),
SsaPass::new(Ssa::simplify_cfg, "Simplifying"),
SsaPass::new(Ssa::mem2reg, "Mem2Reg"),
SsaPass::new(Ssa::flatten_cfg, "Flattening"),
SsaPass::new(Ssa::remove_bit_shifts, "Removing Bit Shifts"),
// Run mem2reg once more with the flattened CFG to catch any remaining loads/stores
SsaPass::new(Ssa::mem2reg, "Mem2Reg"),
// Run the inlining pass again to handle functions with `InlineType::NoPredicates`.
// Before flattening is run, we treat functions marked with the `InlineType::NoPredicates` as an entry point.
// This pass must come immediately following `mem2reg` as the succeeding passes
// may create an SSA which inlining fails to handle.
SsaPass::new(
move |ssa| ssa.inline_functions_with_no_predicates(options.inliner_aggressiveness),
"Inlining",
),
SsaPass::new(Ssa::remove_if_else, "Remove IfElse"),
SsaPass::new(Ssa::purity_analysis, "Purity Analysis"),
SsaPass::new(Ssa::fold_constants, "Constant Folding"),
SsaPass::new(Ssa::flatten_basic_conditionals, "Simplify conditionals for unconstrained"),
SsaPass::new(Ssa::remove_enable_side_effects, "EnableSideEffectsIf removal"),
SsaPass::new(Ssa::fold_constants_using_constraints, "Constraint Folding"),
SsaPass::new(Ssa::make_constrain_not_equal_instructions, "Adding constrain not equal"),
SsaPass::new(Ssa::check_u128_mul_overflow, "Check u128 mul overflow"),
SsaPass::new(Ssa::dead_instruction_elimination, "Dead Instruction Elimination"),
SsaPass::new(Ssa::simplify_cfg, "Simplifying"),
SsaPass::new(Ssa::array_set_optimization, "Array Set Optimizations"),
// The Brillig globals pass expected that we have the used globals map set for each function.
// The used globals map is determined during DIE, so we should duplicate entry points before a DIE pass run.
SsaPass::new(Ssa::brillig_entry_point_analysis, "Brillig Entry Point Analysis"),
// Remove any potentially unnecessary duplication from the Brillig entry point analysis.
SsaPass::new(Ssa::remove_unreachable_functions, "Removing Unreachable Functions"),
SsaPass::new(Ssa::remove_truncate_after_range_check, "Removing Truncate after RangeCheck"),
// This pass makes transformations specific to Brillig generation.
// It must be the last pass to either alter or add new instructions before Brillig generation,
// as other semantics in the compiler can potentially break (e.g. inserting instructions).
// We can safely place the pass before DIE as that pass only removes instructions.
// We also need DIE's tracking of used globals in case the array get transformations
// end up using an existing constant from the globals space.
SsaPass::new(Ssa::brillig_array_gets, "Brillig Array Get Optimizations"),
SsaPass::new(Ssa::dead_instruction_elimination, "Dead Instruction Elimination"),
]
}

/// The second SSA pipeline, in which we take the Brillig functions compiled after
/// the primary pipeline, and execute the ones with all-constant arguments,
/// to replace the calls with the return value.
pub fn secondary_passes(brillig: &Brillig) -> Vec<SsaPass> {
vec![
SsaPass::new(move |ssa| ssa.fold_constants_with_brillig(brillig), "Inlining Brillig Calls"),
// It could happen that we inlined all calls to a given brillig function.
// In that case it's unused so we can remove it. This is what we check next.
SsaPass::new(Ssa::remove_unreachable_functions, "Removing Unreachable Functions"),
SsaPass::new(Ssa::dead_instruction_elimination_acir, "Dead Instruction Elimination"),
]
}

/// Optimize the given program by converting it into SSA
/// form and performing optimizations there. When finished,
/// convert the final SSA into an ACIR program and return it.
/// An ACIR program is made up of both ACIR functions
/// and Brillig functions for unconstrained execution.
pub(crate) fn optimize_into_acir(
///
/// The `primary` SSA passes are applied on the initial SSA.
/// Then we compile the Brillig functions, and use the output
/// to run a `secondary` pass, which can use the Brillig
/// artifacts to do constant folding.
///
/// See the [primary_passes] and [secondary_passes] for
/// the default implementations.
fn optimize_into_acir<S>(
program: Program,
options: &SsaEvaluatorOptions,
) -> Result<ArtifactsAndWarnings, RuntimeError> {
primary: &[SsaPass],
secondary: S,
) -> Result<ArtifactsAndWarnings, RuntimeError>
where
S: for<'b> Fn(&'b Brillig) -> Vec<SsaPass<'b>>,
{
let ssa_gen_span = span!(Level::TRACE, "ssa_generation");
let ssa_gen_span_guard = ssa_gen_span.enter();
let builder = SsaBuilder::new(
Expand All @@ -111,7 +236,9 @@ pub(crate) fn optimize_into_acir(
&options.emit_ssa,
)?;

let mut ssa = optimize_all(builder, options)?;
let mut builder = builder.run_passes(primary)?;
let passed = std::mem::take(&mut builder.passed);
let mut ssa = builder.finish();

let mut ssa_level_warnings = vec![];

Expand All @@ -129,12 +256,9 @@ pub(crate) fn optimize_into_acir(
ssa,
ssa_logging: options.ssa_logging.clone(),
print_codegen_timings: options.print_codegen_timings,
passed,
}
.run_pass(|ssa| ssa.fold_constants_with_brillig(&brillig), "Inlining Brillig Calls Inlining")
// It could happen that we inlined all calls to a given brillig function.
// In that case it's unused so we can remove it. This is what we check next.
.run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (4th)")
.run_pass(Ssa::dead_instruction_elimination_acir, "Dead Instruction Elimination (3rd)")
.run_passes(&secondary(&brillig))?
.finish();

if !options.skip_underconstrained_check {
Expand Down Expand Up @@ -166,78 +290,6 @@ pub(crate) fn optimize_into_acir(
Ok(ArtifactsAndWarnings(artifacts, ssa_level_warnings))
}

/// Run all SSA passes.
fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result<Ssa, RuntimeError> {
Ok(builder
.run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (1st)")
.run_pass(Ssa::defunctionalize, "Defunctionalization")
.run_pass(Ssa::inline_simple_functions, "Inlining simple functions")
// BUG: Enabling this mem2reg causes an integration test failure in aztec-package; see:
// https://github.com/AztecProtocol/aztec-packages/pull/11294#issuecomment-2622809518
//.run_pass(Ssa::mem2reg, "Mem2Reg (1st)")
.run_pass(Ssa::remove_paired_rc, "Removing Paired rc_inc & rc_decs")
.run_pass(
|ssa| ssa.preprocess_functions(options.inliner_aggressiveness),
"Preprocessing Functions",
)
.run_pass(|ssa| ssa.inline_functions(options.inliner_aggressiveness), "Inlining (1st)")
// Run mem2reg with the CFG separated into blocks
.run_pass(Ssa::mem2reg, "Mem2Reg (2nd)")
.run_pass(Ssa::simplify_cfg, "Simplifying (1st)")
.run_pass(Ssa::as_slice_optimization, "`as_slice` optimization")
.run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (2nd)")
.try_run_pass(
Ssa::evaluate_static_assert_and_assert_constant,
"`static_assert` and `assert_constant`",
)?
.run_pass(Ssa::purity_analysis, "Purity Analysis")
.run_pass(Ssa::loop_invariant_code_motion, "Loop Invariant Code Motion")
.try_run_pass(
|ssa| ssa.unroll_loops_iteratively(options.max_bytecode_increase_percent),
"Unrolling",
)?
.run_pass(Ssa::simplify_cfg, "Simplifying (2nd)")
.run_pass(Ssa::mem2reg, "Mem2Reg (3rd)")
.run_pass(Ssa::flatten_cfg, "Flattening")
.run_pass(Ssa::remove_bit_shifts, "Removing Bit Shifts")
// Run mem2reg once more with the flattened CFG to catch any remaining loads/stores
.run_pass(Ssa::mem2reg, "Mem2Reg (4th)")
// Run the inlining pass again to handle functions with `InlineType::NoPredicates`.
// Before flattening is run, we treat functions marked with the `InlineType::NoPredicates` as an entry point.
// This pass must come immediately following `mem2reg` as the succeeding passes
// may create an SSA which inlining fails to handle.
.run_pass(
|ssa| ssa.inline_functions_with_no_predicates(options.inliner_aggressiveness),
"Inlining (2nd)",
)
.run_pass(Ssa::remove_if_else, "Remove IfElse")
.run_pass(Ssa::purity_analysis, "Purity Analysis (2nd)")
.run_pass(Ssa::fold_constants, "Constant Folding")
.run_pass(Ssa::flatten_basic_conditionals, "Simplify conditionals for unconstrained")
.run_pass(Ssa::remove_enable_side_effects, "EnableSideEffectsIf removal")
.run_pass(Ssa::fold_constants_using_constraints, "Constraint Folding")
.run_pass(Ssa::make_constrain_not_equal_instructions, "Adding constrain not equal")
.run_pass(Ssa::check_u128_mul_overflow, "Check u128 mul overflow")
.run_pass(Ssa::dead_instruction_elimination, "Dead Instruction Elimination (1st)")
.run_pass(Ssa::simplify_cfg, "Simplifying (3rd):")
.run_pass(Ssa::array_set_optimization, "Array Set Optimizations")
// The Brillig globals pass expected that we have the used globals map set for each function.
// The used globals map is determined during DIE, so we should duplicate entry points before a DIE pass run.
.run_pass(Ssa::brillig_entry_point_analysis, "Brillig Entry Point Analysis")
// Remove any potentially unnecessary duplication from the Brillig entry point analysis.
.run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (3rd)")
.run_pass(Ssa::remove_truncate_after_range_check, "Removing Truncate after RangeCheck")
// This pass makes transformations specific to Brillig generation.
// It must be the last pass to either alter or add new instructions before Brillig generation,
// as other semantics in the compiler can potentially break (e.g. inserting instructions).
// We can safely place the pass before DIE as that pass only removes instructions.
// We also need DIE's tracking of used globals in case the array get transformations
// end up using an existing constant from the globals space.
.run_pass(Ssa::brillig_array_gets, "Brillig Array Get Optimizations")
.run_pass(Ssa::dead_instruction_elimination, "Dead Instruction Elimination (2nd)")
.finish())
}

// Helper to time SSA passes
fn time<T>(name: &str, print_timings: bool, f: impl FnOnce() -> T) -> T {
let start_time = chrono::Utc::now().time();
Expand Down Expand Up @@ -308,6 +360,21 @@ pub fn create_program(
program: Program,
options: &SsaEvaluatorOptions,
) -> Result<SsaProgramArtifact, RuntimeError> {
create_program_with_passes(program, options, &primary_passes(options), secondary_passes)
}

/// Compiles the [`Program`] into [`ACIR`][acvm::acir::circuit::Program] by running it through
/// `primary` and `secondary` SSA passes.
#[tracing::instrument(level = "trace", skip_all)]
pub fn create_program_with_passes<S>(
program: Program,
options: &SsaEvaluatorOptions,
primary: &[SsaPass],
secondary: S,
) -> Result<SsaProgramArtifact, RuntimeError>
where
S: for<'b> Fn(&'b Brillig) -> Vec<SsaPass<'b>>,
{
let debug_variables = program.debug_variables.clone();
let debug_types = program.debug_types.clone();
let debug_functions = program.debug_functions.clone();
Expand All @@ -317,7 +384,7 @@ pub fn create_program(
let ArtifactsAndWarnings(
(generated_acirs, generated_brillig, brillig_function_names, error_types),
ssa_level_warnings,
) = optimize_into_acir(program, options)?;
) = optimize_into_acir(program, options, primary, secondary)?;

assert_eq!(
generated_acirs.len(),
Expand Down Expand Up @@ -480,6 +547,7 @@ struct SsaBuilder {
ssa: Ssa,
ssa_logging: SsaLogging,
print_codegen_timings: bool,
passed: HashMap<String, usize>,
}

impl SsaBuilder {
Expand All @@ -499,14 +567,26 @@ impl SsaBuilder {
let ssa_path = emit_ssa.with_extension("ssa.json");
write_to_file(&serde_json::to_vec(&ssa).unwrap(), &ssa_path);
}
Ok(SsaBuilder { ssa_logging, print_codegen_timings, ssa }.print("Initial SSA"))
let builder =
SsaBuilder { ssa_logging, print_codegen_timings, ssa, passed: Default::default() };
let builder = builder.print("Initial SSA");
Ok(builder)
}

fn finish(self) -> Ssa {
self.ssa.generate_entry_point_index()
}

/// Run a list of SSA passes.
fn run_passes(mut self, passes: &[SsaPass]) -> Result<Self, RuntimeError> {
for pass in passes {
self = self.try_run_pass(|ssa| (pass.run)(ssa), pass.msg)?;
}
Ok(self)
}

/// Runs the given SSA pass and prints the SSA afterward if `print_ssa_passes` is true.
#[allow(dead_code)]
fn run_pass<F>(mut self, pass: F, msg: &str) -> Self
where
F: FnOnce(Ssa) -> Ssa,
Expand All @@ -525,6 +605,9 @@ impl SsaBuilder {
}

fn print(mut self, msg: &str) -> Self {
// Count the number of times we have seen this message.
let cnt = self.passed.entry(msg.to_string()).and_modify(|cnt| *cnt += 1).or_insert(1);

// Always normalize if we are going to print at least one of the passes
if !matches!(self.ssa_logging, SsaLogging::None) {
self.ssa.normalize_ids();
Expand All @@ -541,7 +624,7 @@ impl SsaBuilder {
}
};
if print_ssa_pass {
println!("After {msg}:\n{}", self.ssa);
println!("After {msg} ({cnt}):\n{}", self.ssa);
}
self
}
Expand Down
5 changes: 3 additions & 2 deletions compiler/noirc_evaluator/src/ssa/opt/hint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod tests {
assert_ssa_snapshot,
brillig::BrilligOptions,
errors::RuntimeError,
ssa::{Ssa, SsaBuilder, SsaEvaluatorOptions, SsaLogging, optimize_all},
ssa::{Ssa, SsaBuilder, SsaEvaluatorOptions, SsaLogging, primary_passes},
};

fn run_all_passes(ssa: Ssa) -> Result<Ssa, RuntimeError> {
Expand All @@ -27,9 +27,10 @@ mod tests {
ssa,
ssa_logging: options.ssa_logging.clone(),
print_codegen_timings: false,
passed: Default::default(),
};

optimize_all(builder, options)
Ok(builder.run_passes(&primary_passes(options))?.finish())
}

/// Test that the `std::hint::black_box` function prevents some of the optimizations.
Expand Down
Loading
Loading