Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions compiler/noirc_evaluator/src/ssa/ir/dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,16 +191,6 @@ impl DataFlowGraph {
new_block
}

/// Get an iterator over references to each basic block within the dfg, paired with the basic
/// block's id.
///
/// The pairs are order by id, which is not guaranteed to be meaningful.
pub(crate) fn basic_blocks_iter(
&self,
) -> impl DoubleEndedIterator<Item = (BasicBlockId, &BasicBlock)> {
self.blocks.iter()
}

/// Iterate over every Value in this DFG in no particular order, including unused Values
pub(crate) fn values_iter(&self) -> impl DoubleEndedIterator<Item = (ValueId, &Value)> {
self.values.iter()
Expand Down
240 changes: 126 additions & 114 deletions compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
//! has when not unrolled.
//! - Unrolling may be reverted for brillig functions if the increase in instruction count is
//! greater than `max_bytecode_increase_percent` (if set).
//! - In ACIR functions, reference count instructions are removed if present (they only have
//! effects in brillig functions so are superfluous in ACIR).
//! - Differing post-conditions (see below).
//!
//! Relevance to other passes:
Expand All @@ -43,10 +41,9 @@
//! used in the loop condition whose value is unknown) will result in an error.
//! - Post-condition (Brillig-only): If `max_bytecode_increase_percent` is set, the instruction count
//! of each function should increase by no more than that percentage compared to before the pass.
use std::collections::BTreeSet;
use std::collections::{BTreeSet, HashSet};

use acvm::acir::AcirField;
use im::HashSet;
use noirc_errors::call_stack::{CallStack, CallStackId};

use crate::{
Expand All @@ -59,7 +56,7 @@ use crate::{
dom::DominatorTree,
function::Function,
function_inserter::FunctionInserter,
instruction::{Binary, BinaryOp, Instruction, InstructionId, TerminatorInstruction},
instruction::{Binary, BinaryOp, Instruction, TerminatorInstruction},
integer::IntegerConstant,
post_order::PostOrder,
value::ValueId,
Expand Down Expand Up @@ -144,15 +141,84 @@ impl Function {
Ok(has_unrolled)
}

// Loop unrolling in brillig can lead to a code explosion currently.
// This can also be true for ACIR, but we have no alternative to unrolling in ACIR.
// Brillig also generally prefers smaller code rather than faster code,
// so we only attempt to unroll small loops, which we decide on a case-by-case basis.
/// Unroll all loops within the function.
/// Any loops which fail to be unrolled (due to using non-constant indices) will be unmodified.
/// Returns a flag indicating whether any blocks have been modified.
///
/// Loop unrolling in brillig can lead to a code explosion currently.
/// This can also be true for ACIR, but we have no alternative to unrolling in ACIR.
/// Brillig also generally prefers smaller code rather than faster code,
/// so we only attempt to unroll small loops, which we decide on a case-by-case basis.
fn try_unroll_loops(&mut self) -> (bool, Vec<RuntimeError>) {
Loops::find_all(self).unroll_each(self)
// The loops that failed to be unrolled so that we do not try to unroll them again.
// Each loop is identified by its header block id.
let mut failed_to_unroll = HashSet::new();
// The reasons why loops in the above set failed to unroll.
let mut unroll_errors = vec![];
let mut has_unrolled = false;

// Repeatedly find all loops as we unroll outer loops and go towards nested ones.
loop {
let mut loops = Loops::find_all(self);
// Blocks which were part of loops we unrolled. Nested loops are included in the outer loops,
// so if an outer loop is unrolled, we have to restart looking for the nested ones.
let mut modified_blocks = HashSet::new();
// Indicate whether we will have to have another go looking for loops, to deal with nested ones.
let mut needs_refresh = false;

while let Some(next_loop) = loops.yet_to_unroll.pop() {
// Don't try to unroll the loop again if it is known to fail
if failed_to_unroll.contains(&next_loop.header) {
continue;
}

// Only unroll small loops in Brillig.
if self.runtime().is_brillig() && !next_loop.is_small_loop(self, &loops.cfg) {
continue;
}

// Check if we will be able to unroll this loop, before starting to modify the blocks.
if next_loop.has_const_back_edge_induction_value(self) {
// Don't try to unroll this.
failed_to_unroll.insert(next_loop.header);
// If this is Brillig, we can still evaluate this loop at runtime.
if self.runtime().is_acir() {
unroll_errors
.push(RuntimeError::UnknownLoopBound { call_stack: CallStack::new() });
}
continue;
}

// If we've previously modified a block in this loop we need to refresh the context.
// This happens any time we have nested loops.
if next_loop.blocks.iter().any(|block| modified_blocks.contains(block)) {
needs_refresh = true;
// Carry on unrolling the loops which weren't related to the ones we have already done.
continue;
}

// Try to unroll.
match next_loop.unroll(self, &loops.cfg) {
Ok(_) => {
has_unrolled = true;
modified_blocks.extend(next_loop.blocks);
}
Err(call_stack) => {
failed_to_unroll.insert(next_loop.header);
unroll_errors.push(RuntimeError::UnknownLoopBound { call_stack });
}
}
}
// Once we have no more nested loops, we are done.
if !needs_refresh {
break;
}
}
(has_unrolled, unroll_errors)
}
}

/// Describe the blocks that constitute up a loop.
#[derive(Debug)]
pub(super) struct Loop {
/// The header block of a loop is the block which dominates all the
Expand All @@ -167,18 +233,16 @@ pub(super) struct Loop {
pub(super) blocks: BTreeSet<BasicBlockId>,
}

/// All the unrolled loops in the SSA.
pub(super) struct Loops {
/// The loops that failed to be unrolled so that we do not try to unroll them again.
/// Each loop is identified by its header block id.
failed_to_unroll: HashSet<BasicBlockId>,

/// Loops that haven't been unrolled yet, which is all the loops currently in the CFG.
pub(super) yet_to_unroll: Vec<Loop>,
modified_blocks: HashSet<BasicBlockId>,
/// The CFG so we can query the predecessors of blocks when needed.
pub(super) cfg: ControlFlowGraph,
}

impl Loops {
/// Find a loop in the program by finding a node that dominates any predecessor node.
/// Find all loops in the program by finding a node that dominates any predecessor node.
/// The edge where this happens will be the back-edge of the loop.
///
/// For example consider the following SSA of a basic loop:
Expand Down Expand Up @@ -206,23 +270,23 @@ impl Loops {
/// loop_end loop_body
/// ```
/// `loop_entry` has two predecessors: `main` and `loop_body`, and it dominates `loop_body`.
///
/// Returns all groups of blocks that look like a loop, even if we might not be able to unroll them,
/// which we can use to check whether we were able to unroll all blocks.
pub(super) fn find_all(function: &Function) -> Self {
let cfg = ControlFlowGraph::with_function(function);
let post_order = PostOrder::with_function(function);
let mut dom_tree = DominatorTree::with_cfg_and_post_order(&cfg, &post_order);

let mut loops = vec![];

for (block, _) in function.dfg.basic_blocks_iter() {
// These reachable checks wouldn't be needed if we only iterated over reachable blocks
if dom_tree.is_reachable(block) {
for predecessor in cfg.predecessors(block) {
// In the above example, we're looking for when `block` is `loop_entry` and `predecessor` is `loop_body`.
if dom_tree.is_reachable(predecessor) && dom_tree.dominates(block, predecessor)
{
// predecessor -> block is the back-edge of a loop
loops.push(Loop::find_blocks_in_loop(block, predecessor, &cfg));
}
// Iterating over blocks in reverse-post-order, ie. forward order, just because it's already available.
for block in post_order.into_vec().into_iter().rev() {
for predecessor in cfg.predecessors(block) {
// In the above example, we're looking for when `block` is `loop_entry` and `predecessor` is `loop_body`.
if dom_tree.dominates(block, predecessor) {
// predecessor -> block is the back-edge of a loop
loops.push(Loop::find_blocks_in_loop(block, predecessor, &cfg));
}
}
}
Expand All @@ -232,60 +296,7 @@ impl Loops {
// their loop range. We will start popping loops from the back.
loops.sort_by_key(|loop_| loop_.blocks.len());

Self {
failed_to_unroll: HashSet::default(),
yet_to_unroll: loops,
modified_blocks: HashSet::default(),
cfg,
}
}

/// Unroll all loops within a given function.
/// Any loops which fail to be unrolled (due to using non-constant indices) will be unmodified.
/// Returns whether any blocks have been modified
fn unroll_each(mut self, function: &mut Function) -> (bool, Vec<RuntimeError>) {
let mut unroll_errors = vec![];
let mut has_unrolled = false;
while let Some(next_loop) = self.yet_to_unroll.pop() {
if function.runtime().is_brillig() && !next_loop.is_small_loop(function, &self.cfg) {
continue;
}

if next_loop.has_const_back_edge_induction_value(function) {
// Don't try to unroll this.
self.failed_to_unroll.insert(next_loop.header);
// If this is Brillig, we can still evaluate this loop at runtime.
if function.runtime().is_acir() {
unroll_errors
.push(RuntimeError::UnknownLoopBound { call_stack: CallStack::new() });
}
continue;
}

// If we've previously modified a block in this loop we need to refresh the context.
// This happens any time we have nested loops.
if next_loop.blocks.iter().any(|block| self.modified_blocks.contains(block)) {
let mut new_loops = Self::find_all(function);
new_loops.failed_to_unroll = self.failed_to_unroll;
let (new_unrolled, new_errors) = new_loops.unroll_each(function);
return (has_unrolled || new_unrolled, [unroll_errors, new_errors].concat());
}

// Don't try to unroll the loop again if it is known to fail
if !self.failed_to_unroll.contains(&next_loop.header) {
match next_loop.unroll(function, &self.cfg) {
Ok(_) => {
has_unrolled = true;
self.modified_blocks.extend(next_loop.blocks);
}
Err(call_stack) => {
self.failed_to_unroll.insert(next_loop.header);
unroll_errors.push(RuntimeError::UnknownLoopBound { call_stack });
}
}
}
}
(has_unrolled, unroll_errors)
Self { yet_to_unroll: loops, cfg }
}
}

Expand All @@ -298,6 +309,7 @@ impl Loop {
cfg: &ControlFlowGraph,
) -> Self {
let mut blocks = BTreeSet::default();
// Insert the header so we don't go past it when traversing backwards from the back-edge.
blocks.insert(header);

let mut insert = |block, stack: &mut Vec<BasicBlockId>| {
Expand All @@ -307,8 +319,7 @@ impl Loop {
}
};

// Starting from the back edge of the loop, each predecessor of this block until
// the header is within the loop.
// Starting from the back edge of the loop, enqueue each predecessor of this block until we reach the header.
let mut stack = vec![];
insert(back_edge_start, &mut stack);

Expand Down Expand Up @@ -931,6 +942,7 @@ impl<'f> LoopIteration<'f> {
/// for further unrolling. When the loop is finished this will need to be mutated to
/// jump to the end of the loop instead.
fn unroll_loop_iteration(mut self) -> (BasicBlockId, ValueId) {
// Kick off the unrolling from the initial source block.
let mut next_blocks = self.unroll_loop_block();

while let Some(block) = next_blocks.pop() {
Expand All @@ -951,25 +963,18 @@ impl<'f> LoopIteration<'f> {
(end_block, induction_value)
}

/// Unroll a single block in the current iteration of the loop
/// Unroll a single block in the current iteration of the loop.
///
/// Returns the next blocks to unroll, based on whether the jmp terminator has 1 or 2 destinations.
fn unroll_loop_block(&mut self) -> Vec<BasicBlockId> {
let mut next_blocks = self.unroll_loop_block_helper();
// Guarantee that the next blocks we set up to be unrolled, are actually part of the loop,
// which we recorded while inlining the instructions of the blocks already processed.
next_blocks.retain(|block| {
let b = self.get_original_block(*block);
self.loop_.blocks.contains(&b)
});
next_blocks
}
self.visited_blocks.insert(self.source_block);

/// Unroll a single block in the current iteration of the loop
fn unroll_loop_block_helper(&mut self) -> Vec<BasicBlockId> {
// Copy instructions from the loop body to the unroll destination, replacing the terminator.
self.inline_instructions_from_block();
self.visited_blocks.insert(self.source_block);

match self.inserter.function.dfg[self.insert_block].unwrap_terminator() {
let terminator = self.inserter.function.dfg[self.insert_block].unwrap_terminator();

let next_blocks = match terminator {
TerminatorInstruction::JmpIf {
condition,
then_destination,
Expand All @@ -979,15 +984,32 @@ impl<'f> LoopIteration<'f> {
TerminatorInstruction::Jmp { destination, arguments, call_stack: _ } => {
if self.get_original_block(*destination) == self.loop_.header {
// We found the back-edge of the loop.
assert_eq!(arguments.len(), 1);
assert_eq!(arguments.len(), 1, "back-edge should only have 1 argument");
assert!(self.induction_value.is_none(), "there should be only one back-edge");
self.induction_value = Some((self.insert_block, arguments[0]));
}
vec![*destination]
}
TerminatorInstruction::Return { .. } | TerminatorInstruction::Unreachable { .. } => {
vec![]
TerminatorInstruction::Return { .. } => {
// Early returns from loops are not implemented.
unreachable!("unexpected return terminator in loop body");
}
}
TerminatorInstruction::Unreachable { .. } => {
// The SSA pass that adds unreachable terminators must come after unrolling.
unreachable!("unexpected unreachable terminator in loop body");
}
};

// Guarantee that the next blocks we set up to be unrolled, are actually part of the loop,
// which we recorded while inlining the instructions of the blocks already processed.
// Since we only call `unroll_loop_block` from `unroll_loop_iteration`, which we only call
// if the single destination in `unroll_header` is *not* outside the loop, this should hold.
next_blocks.iter().for_each(|block| {
let b = self.get_original_block(*block);
assert!(self.loop_.blocks.contains(&b), "destination not in original loop");
});

next_blocks
}

/// Find the next branch(es) to take from a jmpif terminator and return them.
Expand Down Expand Up @@ -1021,8 +1043,10 @@ impl<'f> LoopIteration<'f> {
}
}

/// Translate a block id to a block id in the unrolled loop. If the given
/// block id is not within the loop, it is returned as-is.
/// Translate a block id to a block id in the unrolled loop.
///
/// If the given block id is not within the loop, it is returned as-is,
/// which is the case for when the header jumps to the block following the loop.
fn get_or_insert_block(&mut self, block: BasicBlockId) -> BasicBlockId {
if let Some(new_block) = self.blocks.get(&block) {
return *new_block;
Expand Down Expand Up @@ -1056,10 +1080,6 @@ impl<'f> LoopIteration<'f> {
// instances of the induction variable or any values that were changed as a result
// of the new induction variable value.
for instruction in instructions {
// Reference counting is only used by Brillig, ACIR doesn't need them.
if self.inserter.function.runtime().is_acir() && self.is_refcount(instruction) {
continue;
}
self.inserter.push_instruction(instruction, self.insert_block);
}
let mut terminator = self.dfg()[self.source_block].unwrap_terminator().clone();
Expand All @@ -1072,14 +1092,6 @@ impl<'f> LoopIteration<'f> {
self.inserter.function.dfg.set_block_terminator(self.insert_block, terminator);
}

/// Is the instruction an `Rc`?
fn is_refcount(&self, instruction: InstructionId) -> bool {
matches!(
self.dfg()[instruction],
Instruction::IncrementRc { .. } | Instruction::DecrementRc { .. }
)
}

fn dfg(&self) -> &DataFlowGraph {
&self.inserter.function.dfg
}
Expand Down
Loading