diff --git a/compiler/noirc_evaluator/src/ssa/opt/basic_conditional.rs b/compiler/noirc_evaluator/src/ssa/opt/basic_conditional.rs index c241b30cdaf..623564a4578 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/basic_conditional.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/basic_conditional.rs @@ -1,3 +1,14 @@ +//! An SSA pass that operates on Brillig functions +//! This optimization pass identifies simple conditional control flow patterns in unconstrained code +//! and flattens them to reduce the number of basic blocks and improve performance. +//! +//! e.g: `if c {a} else {b}` would be flattened to `c*(a-b)+b` +//! A simple conditional pattern is defined as a conditional sub-graph of the form `jmpif c: A, else B`, where A and B are basic blocks which join +//! on the same successor. This excludes graphs with nested conditional or loop statements. +//! Performance improvement is based on a simple execution cost metric +//! +//! This pass does not have any pre/post conditions. + use std::collections::HashSet; use acvm::AcirField; @@ -30,13 +41,8 @@ struct BasicConditional { impl Ssa { #[tracing::instrument(level = "trace", skip(self))] - /// This pass flatten simple IF-THEN-ELSE statements - /// This optimization pass identifies simple conditional control flow patterns in unconstrained code - /// and flattens them to reduce the number of basic blocks and improve performance. - /// - /// e.g: if c {a} else {b} would be flattened to c*(a-b)+b - /// A simple conditional pattern is defined as an IF-THEN (with optional ELSE) statement, with no nested conditional nor loop statements - /// Performance improvement is based on a simple execution cost metric + /// Apply the basic_conditional pass to all functions of the program. + /// It first retrieves the `no_predicates` attribute of each function which will be used during the flattening. 
pub(crate) fn flatten_basic_conditionals(mut self) -> Ssa { // Retrieve the 'no_predicates' attribute of the functions in a map, to avoid problems with borrowing let mut no_predicates = HashMap::default(); @@ -50,109 +56,112 @@ impl Ssa { } } -/// Returns the blocks of the simple conditional sub-graph whose input block is the entry. +/// Returns the blocks of the simple conditional sub-graph of the CFG whose input block is the entry. /// Returns None if the input block is not the entry block of a simple conditional. +/// A simple conditional is an if-then(-else) statement where branches are 'small' basic blocks. +/// 'Small' basic blocks means that we expect their execution cost to be small. +/// +/// In case the block is the entry of a 'simple conditional', the function returns a BasicConditional which +/// consists of the list of the conditional blocks: +/// block_entry +/// / \ +/// block_then block_else +/// \ / +/// block_exit +/// block_then and block_else are optional, in order to account for the case when there is no 'then' or no 'else' branch +/// Only structured CFGs with this shape are considered: +/// - block_entry has exactly 2 successors +/// - block_then and block_else have exactly 1 successor, which is block_exit, or one of them is block_exit +/// - block_exit has exactly 2 predecessors (block_then and block_else, or block_entry and the only branch) +/// +/// Furthermore, cost of block_then + cost of block_else must be less than their average cost + jump overhead cost fn is_conditional( block: BasicBlockId, cfg: &ControlFlowGraph, function: &Function, ) -> Option<BasicConditional> { - // jump overhead is the cost for doing the conditional and jump around the blocks + // jump overhead is the cost for doing the conditional and jumping around the blocks // We use 10 as a rough estimate, the real cost is less. 
let jump_overhead = 10; - let mut successors = cfg.successors(block); let mut result = None; - // a conditional must have 2 branches - if successors.len() != 2 { - return None; - } - let left = successors.next().unwrap(); - let right = successors.next().unwrap(); - let mut left_successors = cfg.successors(left); - let mut right_successors = cfg.successors(right); - let left_successors_len = left_successors.len(); - let right_successors_len = right_successors.len(); - let next_left = left_successors.next(); - let next_right = right_successors.next(); - if next_left == Some(block) || next_right == Some(block) { - // this is a loop, not a conditional - return None; - } - if left_successors_len == 1 && right_successors_len == 1 && next_left == next_right { - // The branches join on one block so it is a non-nested conditional - let cost_left = block_cost(left, &function.dfg); - let cost_right = block_cost(right, &function.dfg); - // For the flattening to be valuable, we compare the cost of the flattened code with the average cost of the 2 branches, - // including an overhead to take into account the jumps between the blocks. 
- let cost = cost_right.saturating_add(cost_left); - if cost < cost / 2 + jump_overhead { - if let Some(TerminatorInstruction::JmpIf { - condition: _, - then_destination, - else_destination, - call_stack: _, - }) = function.dfg[block].terminator() - { + + if let Some(TerminatorInstruction::JmpIf { + condition: _, + then_destination, + else_destination, + call_stack: _, + }) = function.dfg[block].terminator() + { + // A conditional must end with a JmpIf + let mut then_successors = cfg.successors(*then_destination); + let mut else_successors = cfg.successors(*else_destination); + let then_successors_len = then_successors.len(); + let else_successors_len = else_successors.len(); + let next_then = then_successors.next(); + let next_else = else_successors.next(); + if next_then == Some(block) || next_else == Some(block) { + // this is a loop, not a conditional + return None; + } + + if then_successors_len == 1 && else_successors_len == 1 && next_then == next_else { + // The branches join on one block so it is a non-nested conditional with a classical diamond shape: + // block + // / \ + // then else + // \ / + // next_then + // We check that the cost of the flattened code is lower than the cost of the branches + let cost_left = block_cost(*then_destination, &function.dfg); + let cost_right = block_cost(*else_destination, &function.dfg); + // For the flattening to be valuable, we compare the cost of the flattened code with the average cost of the 2 branches, + // including an overhead to take into account the jumps between the blocks. + // We use the average cost of the 2 branches, assuming that both branches are equally likely to be executed. 
+ let cost = cost_right.saturating_add(cost_left); + if cost < cost / 2 + jump_overhead { result = Some(BasicConditional { block_entry: block, block_then: Some(*then_destination), block_else: Some(*else_destination), - block_exit: next_left.unwrap(), + block_exit: next_then.unwrap(), }); } - } - } else if left_successors_len == 1 && next_left == Some(right) { - // Left branch joins the right branch, e.g if/then statement with no else - // This case may not happen (i.e not generated), but it is safer to handle it (e.g in case it happens due to some optimizations) - let cost = block_cost(left, &function.dfg); - if cost < cost / 2 + jump_overhead { - if let Some(TerminatorInstruction::JmpIf { - condition: _, - then_destination, - else_destination, - call_stack: _, - }) = function.dfg[block].terminator() - { - let (block_then, block_else) = if left == *then_destination { - (Some(left), None) - } else if left == *else_destination { - (None, Some(left)) - } else { - return None; - }; - + } else if then_successors_len == 1 && next_then == Some(*else_destination) { + // Left branch joins the right branch, e.g if/then statement with no else: + // block + // / \ + // then \ + // \ | + // -> else + // This case may not happen (i.e not generated), but it is safer to handle it (e.g in case it happens due to some optimizations) + let cost = block_cost(*then_destination, &function.dfg); + if cost < cost / 2 + jump_overhead { + // Use the terminator of the entry block to identify the 'then/else' branches + // Indeed, the left/right namings are arbitrary, and we now map them + // to the then/else naming of JmpIf. 
result = Some(BasicConditional { block_entry: block, - block_then, - block_else, - block_exit: right, + block_then: Some(*then_destination), + block_else: None, + block_exit: *else_destination, }); } - } - } else if right_successors_len == 1 && next_right == Some(left) { - // Right branch joins the left branch, e.g if/else statement with no then - // This case may not happen (i.e not generated), but it is safer to handle it (e.g in case it happens due to some optimizations) - let cost = block_cost(right, &function.dfg); - if cost < cost / 2 + jump_overhead { - if let Some(TerminatorInstruction::JmpIf { - condition: _, - then_destination, - else_destination, - call_stack: _, - }) = function.dfg[block].terminator() - { - let (block_then, block_else) = if right == *then_destination { - (Some(right), None) - } else if right == *else_destination { - (None, Some(right)) - } else { - return None; - }; + } else if else_successors_len == 1 && next_else == Some(*then_destination) { + // Right branch joins the left branch, e.g if/else statement with no then + // This case may not happen (i.e not generated), but it is safer to handle it (e.g in case it happens due to some optimizations) + // block + // / \ + // | else + // | | + // \ / + // then + let cost = block_cost(*else_destination, &function.dfg); + if cost < cost / 2 + jump_overhead { result = Some(BasicConditional { block_entry: block, - block_then, - block_else, - block_exit: right, + block_then: None, + block_else: Some(*else_destination), + block_exit: *then_destination, }); } } @@ -161,7 +170,7 @@ fn is_conditional( result.filter(|result| cfg.predecessors(result.block_exit).len() == 2) } -/// Computes a cost estimate of a basic block +/// Computes a cost estimate for the execution of a basic block /// returns u32::MAX if the block has side-effect instructions /// WARNING: these are estimates of the runtime cost of each instruction, /// 1 being the cost of the simplest instruction. These numbers can be improved. 
@@ -241,6 +250,7 @@ fn flatten_function(function: &mut Function, no_predicates: &mut HashMap, function: &mut Function, @@ -296,6 +321,24 @@ fn flatten_multiple( } impl Context<'_> { + /// Flattens a single basic conditional by inlining the 2 branches into the entry block. + /// + /// This method transforms a conditional control flow pattern (if-then-else) into straight-line code + /// by merging the entry, then, else, and exit blocks. The conditional logic is converted into + /// predicated operations using cast and multiplication operations to select between branch values. + /// The method is adapted from `flatten_cfg`, tailored to flatten only the input conditional. + /// + /// # Parameters + /// * `conditional` - The basic conditional structure to flatten + /// * `no_predicates` - Map of function IDs to their no_predicates attribute + /// + /// # Implementation Details + /// - Sets up context state (target_block, no_predicate) to enable proper inlining + /// - Inlines each block's instructions into the entry block + /// - Handles terminators to manage control flow during inlining + /// - Uses a WorkList to track which blocks need processing + /// - Copies the exit block's terminator to the entry block after inlining + /// - Restores original context state after completion fn flatten_single_conditional( &mut self, conditional: &BasicConditional, @@ -369,6 +412,16 @@ impl Context<'_> { self.no_predicate = old_no_predicate; } + /// Applies value mappings to all instructions and terminators in a block. + /// + /// This method rewrites a block by replacing old value IDs with their mapped equivalents + /// according to the provided mapping. This is used to propagate value simplifications + /// from conditional flattening throughout the rest of the function. 
+ /// + /// # Parameters + /// * `mapping` - HashMap mapping old ValueIds to their simplified/replaced ValueIds + /// * `func` - The function containing the block to update + /// * `block` - The BasicBlockId of the block to remap fn map_block_with_mapping( mapping: HashMap, func: &mut Function,