Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 10 additions & 1 deletion compiler/noirc_driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,19 @@ pub struct CompileOptions {
#[arg(long)]
pub skip_underconstrained_check: bool,

/// Setting to decide on an inlining strategy for brillig functions.
/// Setting to decide on an inlining strategy for Brillig functions.
/// A more aggressive inliner should generate larger programs but more optimized
/// A less aggressive inliner should generate smaller programs
#[arg(long, hide = true, allow_hyphen_values = true, default_value_t = i64::MAX)]
pub inliner_aggressiveness: i64,

/// Setting the maximum acceptable increase in Brillig bytecode size due to
/// unrolling small loops. When left empty, any change is accepted as long
/// as it required fewer SSA instructions.
/// A higher value results in fewer jumps but a larger program.
/// A lower value keeps the original program if it was smaller, even if it has more jumps.
#[arg(long, hide = true, allow_hyphen_values = true)]
pub max_bytecode_increase_percent: Option<i32>,
}

pub fn parse_expression_width(input: &str) -> Result<ExpressionWidth, std::io::Error> {
Expand Down Expand Up @@ -589,6 +597,7 @@ pub fn compile_no_check(
emit_ssa: if options.emit_ssa { Some(context.package_build_path.clone()) } else { None },
skip_underconstrained_check: options.skip_underconstrained_check,
inliner_aggressiveness: options.inliner_aggressiveness,
max_bytecode_increase_percent: options.max_bytecode_increase_percent,
};

let SsaProgramArtifact { program, debug, warnings, names, brillig_names, error_types, .. } =
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_evaluator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ cfg-if.workspace = true
proptest.workspace = true
similar-asserts.workspace = true
num-traits.workspace = true
test-case.workspace = true

[features]
bn254 = ["noirc_frontend/bn254"]
Expand Down
8 changes: 3 additions & 5 deletions compiler/noirc_evaluator/src/brillig/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use self::{
},
};
use crate::ssa::{
ir::function::{Function, FunctionId, RuntimeType},
ir::function::{Function, FunctionId},
ssa_gen::Ssa,
};
use fxhash::FxHashMap as HashMap;
Expand Down Expand Up @@ -59,17 +59,15 @@ impl std::ops::Index<FunctionId> for Brillig {
}

impl Ssa {
/// Compile to brillig brillig functions and ACIR functions reachable from them
/// Compile Brillig functions and ACIR functions reachable from them
#[tracing::instrument(level = "trace", skip_all)]
pub(crate) fn to_brillig(&self, enable_debug_trace: bool) -> Brillig {
// Collect all the function ids that are reachable from brillig
// That means all the functions marked as brillig and ACIR functions called by them
let brillig_reachable_function_ids = self
.functions
.iter()
.filter_map(|(id, func)| {
matches!(func.runtime(), RuntimeType::Brillig(_)).then_some(*id)
})
.filter_map(|(id, func)| func.runtime().is_brillig().then_some(*id))
.collect::<BTreeSet<_>>();

let mut brillig = Brillig::default();
Expand Down
19 changes: 13 additions & 6 deletions compiler/noirc_evaluator/src/ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ pub struct SsaEvaluatorOptions {

/// The higher the value, the more inlined brillig functions will be.
pub inliner_aggressiveness: i64,

/// Maximum accepted percentage increase in the Brillig bytecode size after unrolling loops.
/// When `None` the size increase check is skipped altogether and any decrease in the SSA
/// instruction count is accepted.
pub max_bytecode_increase_percent: Option<i32>,
}

pub(crate) struct ArtifactsAndWarnings(Artifacts, Vec<SsaReport>);
Expand Down Expand Up @@ -104,7 +109,10 @@ pub(crate) fn optimize_into_acir(
"After `static_assert` and `assert_constant`:",
)?
.run_pass(Ssa::loop_invariant_code_motion, "After Loop Invariant Code Motion:")
.try_run_pass(Ssa::unroll_loops_iteratively, "After Unrolling:")?
.try_run_pass(
|ssa| ssa.unroll_loops_iteratively(options.max_bytecode_increase_percent),
"After Unrolling:",
)?
.run_pass(Ssa::simplify_cfg, "After Simplifying (2nd):")
.run_pass(Ssa::flatten_cfg, "After Flattening:")
.run_pass(Ssa::remove_bit_shifts, "After Removing Bit Shifts:")
Expand Down Expand Up @@ -450,11 +458,10 @@ impl SsaBuilder {
}

/// The same as `run_pass` but for passes that may fail
fn try_run_pass(
mut self,
pass: fn(Ssa) -> Result<Ssa, RuntimeError>,
msg: &str,
) -> Result<Self, RuntimeError> {
fn try_run_pass<F>(mut self, pass: F, msg: &str) -> Result<Self, RuntimeError>
where
F: FnOnce(Ssa) -> Result<Ssa, RuntimeError>,
{
self.ssa = time(msg, self.print_codegen_timings, || pass(self.ssa))?;
Ok(self.print(msg))
}
Expand Down
6 changes: 6 additions & 0 deletions compiler/noirc_evaluator/src/ssa/ir/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,12 @@ impl Function {
}
}

impl Clone for Function {
fn clone(&self) -> Self {
Function::clone_with_id(self.id(), self)
}
}

impl std::fmt::Display for RuntimeType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Expand Down
137 changes: 116 additions & 21 deletions compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
//! When unrolling ACIR code, we remove reference count instructions because they are
//! only used by Brillig bytecode.
use acvm::{acir::AcirField, FieldElement};
use im::HashSet;

use crate::{
brillig::brillig_gen::convert_ssa_function,
errors::RuntimeError,
ssa::{
ir::{
Expand All @@ -37,38 +39,60 @@ use crate::{
ssa_gen::Ssa,
},
};
use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};
use fxhash::FxHashMap as HashMap;

impl Ssa {
/// Loop unrolling can return errors, since ACIR functions need to be fully unrolled.
/// This meta-pass will keep trying to unroll loops and simplifying the SSA until no more errors are found.
#[tracing::instrument(level = "trace", skip(ssa))]
pub(crate) fn unroll_loops_iteratively(mut ssa: Ssa) -> Result<Ssa, RuntimeError> {
for (_, function) in ssa.functions.iter_mut() {
///
/// The `max_bytecode_increase_percent`, when given, limits how much the Brillig bytecode of a function
/// may grow after unrolling its small loops, expressed as a percentage of the bytecode size before unrolling.
/// For example a value of 150 means the new bytecode can be up to 150% larger than the original
/// (i.e. 2.5 times its size); if the limit is exceeded, the function is restored to its original form.
/// The unrolled code will still contain fewer SSA instructions, but that can still result in more Brillig opcodes.
#[tracing::instrument(level = "trace", skip(self))]
pub(crate) fn unroll_loops_iteratively(
mut self: Ssa,
max_bytecode_increase_percent: Option<i32>,
) -> Result<Ssa, RuntimeError> {
for (_, function) in self.functions.iter_mut() {
// Take a snapshot of the function to compare byte size increase,
// but only if the setting indicates we have to, otherwise skip it.
let orig_func_and_max_incr_pct = max_bytecode_increase_percent
.filter(|_| function.runtime().is_brillig())
.map(|max_incr_pct| (function.clone(), max_incr_pct));

// Try to unroll loops first:
let mut unroll_errors = function.try_unroll_loops();
let (mut has_unrolled, mut unroll_errors) = function.try_unroll_loops();

// Keep unrolling until no more errors are found
while !unroll_errors.is_empty() {
let prev_unroll_err_count = unroll_errors.len();

// Simplify the SSA before retrying

// Do a mem2reg after the last unroll to aid simplify_cfg
function.mem2reg();
function.simplify_function();
// Do another mem2reg after simplify_cfg to aid the next unroll
function.mem2reg();
simplify_between_unrolls(function);

// Unroll again
unroll_errors = function.try_unroll_loops();
let (new_unrolled, new_errors) = function.try_unroll_loops();
unroll_errors = new_errors;
has_unrolled |= new_unrolled;

// If we didn't manage to unroll any more loops, exit
if unroll_errors.len() >= prev_unroll_err_count {
return Err(unroll_errors.swap_remove(0));
}
}

if has_unrolled {
if let Some((orig_function, max_incr_pct)) = orig_func_and_max_incr_pct {
let new_size = brillig_bytecode_size(function);
let orig_size = brillig_bytecode_size(&orig_function);
if !is_new_size_ok(orig_size, new_size, max_incr_pct) {
*function = orig_function;
}
}
}
}
Ok(ssa)
Ok(self)
}
}

Expand All @@ -77,7 +101,7 @@ impl Function {
// This can also be true for ACIR, but we have no alternative to unrolling in ACIR.
// Brillig also generally prefers smaller code rather than faster code,
// so we only attempt to unroll small loops, which we decide on a case-by-case basis.
fn try_unroll_loops(&mut self) -> Vec<RuntimeError> {
fn try_unroll_loops(&mut self) -> (bool, Vec<RuntimeError>) {
Loops::find_all(self).unroll_each(self)
}
}
Expand Down Expand Up @@ -170,8 +194,10 @@ impl Loops {

/// Unroll all loops within a given function.
/// Any loops which fail to be unrolled (due to using non-constant indices) will be unmodified.
fn unroll_each(mut self, function: &mut Function) -> Vec<RuntimeError> {
/// Returns whether any blocks have been modified
fn unroll_each(mut self, function: &mut Function) -> (bool, Vec<RuntimeError>) {
let mut unroll_errors = vec![];
let mut has_unrolled = false;
while let Some(next_loop) = self.yet_to_unroll.pop() {
if function.runtime().is_brillig() && !next_loop.is_small_loop(function, &self.cfg) {
continue;
Expand All @@ -181,21 +207,25 @@ impl Loops {
if next_loop.blocks.iter().any(|block| self.modified_blocks.contains(block)) {
let mut new_loops = Self::find_all(function);
new_loops.failed_to_unroll = self.failed_to_unroll;
return unroll_errors.into_iter().chain(new_loops.unroll_each(function)).collect();
let (new_unrolled, new_errors) = new_loops.unroll_each(function);
return (has_unrolled || new_unrolled, [unroll_errors, new_errors].concat());
}

// Don't try to unroll the loop again if it is known to fail
if !self.failed_to_unroll.contains(&next_loop.header) {
match next_loop.unroll(function, &self.cfg) {
Ok(_) => self.modified_blocks.extend(next_loop.blocks),
Ok(_) => {
has_unrolled = true;
self.modified_blocks.extend(next_loop.blocks);
}
Err(call_stack) => {
self.failed_to_unroll.insert(next_loop.header);
unroll_errors.push(RuntimeError::UnknownLoopBound { call_stack });
}
}
}
}
unroll_errors
(has_unrolled, unroll_errors)
}
}

Expand Down Expand Up @@ -947,21 +977,59 @@ impl<'f> LoopIteration<'f> {
}
}

/// Clean up a function after an unrolling attempt.
///
/// Unrolling leaves some duplicate instructions which can potentially be removed.
/// The function is modified in place by running `mem2reg`, then `simplify_function`,
/// then `mem2reg` again; the order of the three passes matters (see the comments below).
fn simplify_between_unrolls(function: &mut Function) {
// Do a mem2reg after the last unroll to aid simplify_cfg
function.mem2reg();
// Simplify the function itself, building on the results of the first mem2reg.
function.simplify_function();
// Do another mem2reg after simplify_cfg to aid the next unroll
function.mem2reg();
}

/// Measure how large `function` becomes once it is lowered to Brillig bytecode.
///
/// All work happens on a private clone, so the function passed in is never modified.
/// The returned value is the length of the `byte_code` produced by `convert_ssa_function`.
fn brillig_bytecode_size(function: &Function) -> usize {
    // Work on a copy: the SSA passes below mutate the function, and the caller's
    // version has to stay exactly as it was.
    let mut scratch = function.clone();

    // The conversion cannot go ahead on arbitrary SSA; without some passes first we can
    // hit `unreachable!()` instructions in `convert_ssa_instruction`. Reusing the regular
    // clean-up also gives the function the best chance at a small size.
    simplify_between_unrolls(&mut scratch);

    // Run dead instruction elimination as well, to try to prevent hitting an ICE.
    scratch.dead_instruction_elimination(false);

    let artifact = convert_ssa_function(&scratch, false);
    artifact.byte_code.len()
}

/// Decide if the new bytecode size is acceptable, compared to the original.
///
/// `max_incr_pct` is the largest accepted growth, as a percentage of `orig_size`;
/// the new size is accepted when `new_size * 100 <= orig_size * (100 + max_incr_pct)`.
///
/// The maximum increase can be expressed as a negative value if we demand a decrease.
/// (Values -100 and under mean the new size should be 0).
///
/// The comparison is carried out in `u128`, so extremely large sizes cannot overflow
/// (or saturate) and end up being accepted by mistake.
fn is_new_size_ok(orig_size: usize, new_size: usize, max_incr_pct: i32) -> bool {
    // Allowed size as a percentage of the original, clamped at 0%: asking for more than
    // a 100% decrease is the same as asking for an empty program. The value is
    // non-negative at this point, so `unsigned_abs` is a lossless conversion.
    let max_size_pct = 100i32.saturating_add(max_incr_pct).max(0).unsigned_abs();

    // Cross-multiply instead of dividing to avoid rounding. `usize` is at most 64 bits
    // wide and both factors fit in 32 bits, so neither product can overflow `u128`.
    let max_size_x100 = (orig_size as u128) * u128::from(max_size_pct);
    let new_size_x100 = (new_size as u128) * 100;

    new_size_x100 <= max_size_x100
}

#[cfg(test)]
mod tests {
use acvm::FieldElement;
use test_case::test_case;

use crate::errors::RuntimeError;
use crate::ssa::{ir::value::ValueId, opt::assert_normalized_ssa_equals, Ssa};

use super::{BoilerplateStats, Loops};
use super::{is_new_size_ok, BoilerplateStats, Loops};

/// Tries to unroll all loops in each SSA function.
/// Tries to unroll all loops in each SSA function once, calling the `Function` directly,
/// bypassing the iterative loop done by the SSA which does further optimisations.
///
/// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state.
fn try_unroll_loops(mut ssa: Ssa) -> (Ssa, Vec<RuntimeError>) {
let mut errors = vec![];
for function in ssa.functions.values_mut() {
errors.extend(function.try_unroll_loops());
errors.extend(function.try_unroll_loops().1);
}
(ssa, errors)
}
Expand Down Expand Up @@ -1221,9 +1289,26 @@ mod tests {

let (ssa, errors) = try_unroll_loops(ssa);
assert_eq!(errors.len(), 0, "Unroll should have no errors");
// Check that it's still the original
assert_normalized_ssa_equals(ssa, parse_ssa().to_string().as_str());
}

#[test]
fn test_brillig_unroll_iteratively_respects_max_increase() {
    // A limit of -90 demands that the bytecode shrinks to at most 10% of its original size.
    let unrolled = brillig_unroll_test_case().unroll_loops_iteratively(Some(-90)).unwrap();

    // The result is expected to be identical to a freshly built, untouched test case,
    // i.e. the function is still the original.
    let expected = brillig_unroll_test_case().to_string();
    assert_normalized_ssa_equals(unrolled, expected.as_str());
}

#[test]
fn test_brillig_unroll_iteratively_with_large_max_increase() {
    // A limit of 50 allows the bytecode to grow by up to half of its original size.
    let unrolled = brillig_unroll_test_case().unroll_loops_iteratively(Some(50)).unwrap();

    // Two reachable blocks remaining in `main` is taken as proof that the unroll happened.
    let block_count = unrolled.main().reachable_blocks().len();
    assert_eq!(block_count, 2, "The loop should be unrolled");
}

/// Test that `break` and `continue` stop unrolling without any panic.
#[test]
fn test_brillig_unroll_break_and_continue() {
Expand Down Expand Up @@ -1377,4 +1462,14 @@ mod tests {
let loop0 = loops.yet_to_unroll.pop().expect("there should be a loop");
loop0.boilerplate_stats(function, &loops.cfg).expect("there should be stats")
}

// Table-driven check of `is_new_size_ok`.
// Each case lists: original size, new size, maximum increase in percent, expected verdict.
// The cases cover a plain decrease, growth exactly at and just over the limit,
// negative limits (a demanded decrease), and a limit below -100.
#[test_case(1000, 700, 50, true; "size decreased")]
#[test_case(1000, 1500, 50, true; "size increased just by the max")]
#[test_case(1000, 1501, 50, false; "size increased over the max")]
#[test_case(1000, 700, -50, false; "size decreased but not enough")]
#[test_case(1000, 250, -50, true; "size decreased over expectations")]
#[test_case(1000, 250, -1250, false; "demanding more than minus 100 is handled")]
fn test_is_new_size_ok(old: usize, new: usize, max: i32, ok: bool) {
assert_eq!(is_new_size_ok(old, new, max), ok);
}
}
5 changes: 5 additions & 0 deletions tooling/nargo_cli/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,13 @@ fn test_{test_name}(force_brillig: ForceBrillig, inliner_aggressiveness: Inliner
nargo.arg("--program-dir").arg(test_program_dir);
nargo.arg("{test_command}").arg("--force");
nargo.arg("--inliner-aggressiveness").arg(inliner_aggressiveness.0.to_string());

if force_brillig.0 {{
nargo.arg("--force-brillig");

// Set the maximum increase so that part of the optimization is exercised (it might fail).
nargo.arg("--max-bytecode-increase-percent");
nargo.arg("50");
}}

{test_content}
Expand Down