diff --git a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs index a0af3979271..6ba1629b5f6 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs @@ -191,16 +191,6 @@ impl DataFlowGraph { new_block } - /// Get an iterator over references to each basic block within the dfg, paired with the basic - /// block's id. - /// - /// The pairs are order by id, which is not guaranteed to be meaningful. - pub(crate) fn basic_blocks_iter( - &self, - ) -> impl DoubleEndedIterator { - self.blocks.iter() - } - /// Iterate over every Value in this DFG in no particular order, including unused Values pub(crate) fn values_iter(&self) -> impl DoubleEndedIterator { self.values.iter() diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index cb4e6aefca0..5f46c4330d1 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -21,8 +21,6 @@ //! has when not unrolled. //! - Unrolling may be reverted for brillig functions if the increase in instruction count is //! greater than `max_bytecode_increase_percent` (if set). -//! - In ACIR functions, reference count instructions are removed if present (they only have -//! effects in brillig functions so are superfluous in ACIR). //! - Differing post-conditions (see below). //! //! Relevance to other passes: @@ -43,10 +41,9 @@ //! used in the loop condition whose value is unknown) will result in an error. //! - Post-condition (Brillig-only): If `max_bytecode_increase_percent` is set, the instruction count //! of each function should increase by no more than that percentage compared to before the pass. -use std::collections::BTreeSet; +use std::collections::{BTreeSet, HashSet}; use acvm::acir::AcirField; -use im::HashSet; use noirc_errors::call_stack::{CallStack, CallStackId}; use crate::{ @@ -59,7 +56,7 @@ use crate::{ dom::DominatorTree, function::Function, function_inserter::FunctionInserter, - instruction::{Binary, BinaryOp, Instruction, InstructionId, TerminatorInstruction}, + instruction::{Binary, BinaryOp, Instruction, TerminatorInstruction}, integer::IntegerConstant, post_order::PostOrder, value::ValueId, @@ -144,15 +141,84 @@ impl Function { Ok(has_unrolled) } - // Loop unrolling in brillig can lead to a code explosion currently. - // This can also be true for ACIR, but we have no alternative to unrolling in ACIR. - // Brillig also generally prefers smaller code rather than faster code, - // so we only attempt to unroll small loops, which we decide on a case-by-case basis. + /// Unroll all loops within the function. + /// Any loops which fail to be unrolled (due to using non-constant indices) will be unmodified. + /// Returns a flag indicating whether any blocks have been modified. + /// + /// Loop unrolling in brillig can lead to a code explosion currently. + /// This can also be true for ACIR, but we have no alternative to unrolling in ACIR. + /// Brillig also generally prefers smaller code rather than faster code, + /// so we only attempt to unroll small loops, which we decide on a case-by-case basis. fn try_unroll_loops(&mut self) -> (bool, Vec) { - Loops::find_all(self).unroll_each(self) + // The loops that failed to be unrolled so that we do not try to unroll them again. + // Each loop is identified by its header block id. + let mut failed_to_unroll = HashSet::new(); + // The reasons why loops in the above set failed to unroll. + let mut unroll_errors = vec![]; + let mut has_unrolled = false; + + // Repeatedly find all loops as we unroll outer loops and go towards nested ones. + loop { + let mut loops = Loops::find_all(self); + // Blocks which were part of loops we unrolled. Nested loops are included in the outer loops, + // so if an outer loop is unrolled, we have to restart looking for the nested ones. + let mut modified_blocks = HashSet::new(); + // Indicate whether we will have to have another go looking for loops, to deal with nested ones. + let mut needs_refresh = false; + + while let Some(next_loop) = loops.yet_to_unroll.pop() { + // Don't try to unroll the loop again if it is known to fail + if failed_to_unroll.contains(&next_loop.header) { + continue; + } + + // Only unroll small loops in Brillig. + if self.runtime().is_brillig() && !next_loop.is_small_loop(self, &loops.cfg) { + continue; + } + + // Check if we will be able to unroll this loop, before starting to modify the blocks. + if next_loop.has_const_back_edge_induction_value(self) { + // Don't try to unroll this. + failed_to_unroll.insert(next_loop.header); + // If this is Brillig, we can still evaluate this loop at runtime. + if self.runtime().is_acir() { + unroll_errors + .push(RuntimeError::UnknownLoopBound { call_stack: CallStack::new() }); + } + continue; + } + + // If we've previously modified a block in this loop we need to refresh the context. + // This happens any time we have nested loops. + if next_loop.blocks.iter().any(|block| modified_blocks.contains(block)) { + needs_refresh = true; + // Carry on unrolling the loops which weren't related to the ones we have already done. + continue; + } + + // Try to unroll. + match next_loop.unroll(self, &loops.cfg) { + Ok(_) => { + has_unrolled = true; + modified_blocks.extend(next_loop.blocks); + } + Err(call_stack) => { + failed_to_unroll.insert(next_loop.header); + unroll_errors.push(RuntimeError::UnknownLoopBound { call_stack }); + } + } + } + // Once we have no more nested loops, we are done. + if !needs_refresh { + break; + } + } + (has_unrolled, unroll_errors) } } +/// Describe the blocks that constitute up a loop. #[derive(Debug)] pub(super) struct Loop { /// The header block of a loop is the block which dominates all the @@ -167,18 +233,16 @@ pub(super) struct Loop { pub(super) blocks: BTreeSet, } +/// All the unrolled loops in the SSA. pub(super) struct Loops { - /// The loops that failed to be unrolled so that we do not try to unroll them again. - /// Each loop is identified by its header block id. - failed_to_unroll: HashSet, - + /// Loops that haven't been unrolled yet, which is all the loops currently in the CFG. pub(super) yet_to_unroll: Vec, - modified_blocks: HashSet, + /// The CFG so we can query the predecessors of blocks when needed. pub(super) cfg: ControlFlowGraph, } impl Loops { - /// Find a loop in the program by finding a node that dominates any predecessor node. + /// Find all loops in the program by finding a node that dominates any predecessor node. /// The edge where this happens will be the back-edge of the loop. /// /// For example consider the following SSA of a basic loop: @@ -206,6 +270,9 @@ impl Loops { /// loop_end loop_body /// ``` /// `loop_entry` has two predecessors: `main` and `loop_body`, and it dominates `loop_body`. + /// + /// Returns all groups of blocks that look like a loop, even if we might not be able to unroll them, + /// which we can use to check whether we were able to unroll all blocks. pub(super) fn find_all(function: &Function) -> Self { let cfg = ControlFlowGraph::with_function(function); let post_order = PostOrder::with_function(function); @@ -213,16 +280,13 @@ impl Loops { let mut loops = vec![]; - for (block, _) in function.dfg.basic_blocks_iter() { - // These reachable checks wouldn't be needed if we only iterated over reachable blocks - if dom_tree.is_reachable(block) { - for predecessor in cfg.predecessors(block) { - // In the above example, we're looking for when `block` is `loop_entry` and `predecessor` is `loop_body`. - if dom_tree.is_reachable(predecessor) && dom_tree.dominates(block, predecessor) - { - // predecessor -> block is the back-edge of a loop - loops.push(Loop::find_blocks_in_loop(block, predecessor, &cfg)); - } + // Iterating over blocks in reverse-post-order, ie. forward order, just because it's already available. + for block in post_order.into_vec().into_iter().rev() { + for predecessor in cfg.predecessors(block) { + // In the above example, we're looking for when `block` is `loop_entry` and `predecessor` is `loop_body`. + if dom_tree.dominates(block, predecessor) { + // predecessor -> block is the back-edge of a loop + loops.push(Loop::find_blocks_in_loop(block, predecessor, &cfg)); } } } @@ -232,60 +296,7 @@ impl Loops { // their loop range. We will start popping loops from the back. loops.sort_by_key(|loop_| loop_.blocks.len()); - Self { - failed_to_unroll: HashSet::default(), - yet_to_unroll: loops, - modified_blocks: HashSet::default(), - cfg, - } - } - - /// Unroll all loops within a given function. - /// Any loops which fail to be unrolled (due to using non-constant indices) will be unmodified. - /// Returns whether any blocks have been modified - fn unroll_each(mut self, function: &mut Function) -> (bool, Vec) { - let mut unroll_errors = vec![]; - let mut has_unrolled = false; - while let Some(next_loop) = self.yet_to_unroll.pop() { - if function.runtime().is_brillig() && !next_loop.is_small_loop(function, &self.cfg) { - continue; - } - - if next_loop.has_const_back_edge_induction_value(function) { - // Don't try to unroll this. - self.failed_to_unroll.insert(next_loop.header); - // If this is Brillig, we can still evaluate this loop at runtime. - if function.runtime().is_acir() { - unroll_errors - .push(RuntimeError::UnknownLoopBound { call_stack: CallStack::new() }); - } - continue; - } - - // If we've previously modified a block in this loop we need to refresh the context. - // This happens any time we have nested loops. - if next_loop.blocks.iter().any(|block| self.modified_blocks.contains(block)) { - let mut new_loops = Self::find_all(function); - new_loops.failed_to_unroll = self.failed_to_unroll; - let (new_unrolled, new_errors) = new_loops.unroll_each(function); - return (has_unrolled || new_unrolled, [unroll_errors, new_errors].concat()); - } - - // Don't try to unroll the loop again if it is known to fail - if !self.failed_to_unroll.contains(&next_loop.header) { - match next_loop.unroll(function, &self.cfg) { - Ok(_) => { - has_unrolled = true; - self.modified_blocks.extend(next_loop.blocks); - } - Err(call_stack) => { - self.failed_to_unroll.insert(next_loop.header); - unroll_errors.push(RuntimeError::UnknownLoopBound { call_stack }); - } - } - } - } - (has_unrolled, unroll_errors) + Self { yet_to_unroll: loops, cfg } } } @@ -298,6 +309,7 @@ impl Loop { cfg: &ControlFlowGraph, ) -> Self { let mut blocks = BTreeSet::default(); + // Insert the header so we don't go past it when traversing backwards from the back-edge. blocks.insert(header); let mut insert = |block, stack: &mut Vec| { @@ -307,8 +319,7 @@ impl Loop { } }; - // Starting from the back edge of the loop, each predecessor of this block until - // the header is within the loop. + // Starting from the back edge of the loop, enqueue each predecessor of this block until we reach the header. let mut stack = vec![]; insert(back_edge_start, &mut stack); @@ -931,6 +942,7 @@ impl<'f> LoopIteration<'f> { /// for further unrolling. When the loop is finished this will need to be mutated to /// jump to the end of the loop instead. fn unroll_loop_iteration(mut self) -> (BasicBlockId, ValueId) { + // Kick off the unrolling from the initial source block. let mut next_blocks = self.unroll_loop_block(); while let Some(block) = next_blocks.pop() { @@ -951,25 +963,18 @@ impl<'f> LoopIteration<'f> { (end_block, induction_value) } - /// Unroll a single block in the current iteration of the loop + /// Unroll a single block in the current iteration of the loop. + /// + /// Returns the next blocks to unroll, based on whether the jmp terminator has 1 or 2 destinations. fn unroll_loop_block(&mut self) -> Vec { - let mut next_blocks = self.unroll_loop_block_helper(); - // Guarantee that the next blocks we set up to be unrolled, are actually part of the loop, - // which we recorded while inlining the instructions of the blocks already processed. - next_blocks.retain(|block| { - let b = self.get_original_block(*block); - self.loop_.blocks.contains(&b) - }); - next_blocks - } + self.visited_blocks.insert(self.source_block); - /// Unroll a single block in the current iteration of the loop - fn unroll_loop_block_helper(&mut self) -> Vec { // Copy instructions from the loop body to the unroll destination, replacing the terminator. self.inline_instructions_from_block(); - self.visited_blocks.insert(self.source_block); - match self.inserter.function.dfg[self.insert_block].unwrap_terminator() { + let terminator = self.inserter.function.dfg[self.insert_block].unwrap_terminator(); + + let next_blocks = match terminator { TerminatorInstruction::JmpIf { condition, then_destination, @@ -979,15 +984,32 @@ impl<'f> LoopIteration<'f> { TerminatorInstruction::Jmp { destination, arguments, call_stack: _ } => { if self.get_original_block(*destination) == self.loop_.header { // We found the back-edge of the loop. - assert_eq!(arguments.len(), 1); + assert_eq!(arguments.len(), 1, "back-edge should only have 1 argument"); + assert!(self.induction_value.is_none(), "there should be only one back-edge"); self.induction_value = Some((self.insert_block, arguments[0])); } vec![*destination] } - TerminatorInstruction::Return { .. } | TerminatorInstruction::Unreachable { .. } => { - vec![] + TerminatorInstruction::Return { .. } => { + // Early returns from loops are not implemented. + unreachable!("unexpected return terminator in loop body"); } - } + TerminatorInstruction::Unreachable { .. } => { + // The SSA pass that adds unreachable terminators must come after unrolling. + unreachable!("unexpected unreachable terminator in loop body"); + } + }; + + // Guarantee that the next blocks we set up to be unrolled, are actually part of the loop, + // which we recorded while inlining the instructions of the blocks already processed. + // Since we only call `unroll_loop_block` from `unroll_loop_iteration`, which we only call + // if the single destination in `unroll_header` is *not* outside the loop, this should hold. + next_blocks.iter().for_each(|block| { + let b = self.get_original_block(*block); + assert!(self.loop_.blocks.contains(&b), "destination not in original loop"); + }); + + next_blocks } /// Find the next branch(es) to take from a jmpif terminator and return them. @@ -1021,8 +1043,10 @@ impl<'f> LoopIteration<'f> { } } - /// Translate a block id to a block id in the unrolled loop. If the given - /// block id is not within the loop, it is returned as-is. + /// Translate a block id to a block id in the unrolled loop. + /// + /// If the given block id is not within the loop, it is returned as-is, + /// which is the case for when the header jumps to the block following the loop. fn get_or_insert_block(&mut self, block: BasicBlockId) -> BasicBlockId { if let Some(new_block) = self.blocks.get(&block) { return *new_block; @@ -1056,10 +1080,6 @@ impl<'f> LoopIteration<'f> { // instances of the induction variable or any values that were changed as a result // of the new induction variable value. for instruction in instructions { - // Reference counting is only used by Brillig, ACIR doesn't need them. - if self.inserter.function.runtime().is_acir() && self.is_refcount(instruction) { - continue; - } self.inserter.push_instruction(instruction, self.insert_block); } let mut terminator = self.dfg()[self.source_block].unwrap_terminator().clone(); @@ -1072,14 +1092,6 @@ impl<'f> LoopIteration<'f> { self.inserter.function.dfg.set_block_terminator(self.insert_block, terminator); } - /// Is the instruction an `Rc`? - fn is_refcount(&self, instruction: InstructionId) -> bool { - matches!( - self.dfg()[instruction], - Instruction::IncrementRc { .. } | Instruction::DecrementRc { .. } - ) - } - fn dfg(&self) -> &DataFlowGraph { &self.inserter.function.dfg }