Skip to content
233 changes: 143 additions & 90 deletions compiler/noirc_evaluator/src/ssa/opt/basic_conditional.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
//! An SSA pass that operates on Brillig functions
//! This optimization pass identifies simple conditional control flow patterns in unconstrained code
//! and flattens them to reduce the number of basic blocks and improve performance.
//!
//! e.g: `if c {a} else {b}` would be flattened to `c*(a-b)+b`
//! A simple conditional pattern is defined as a conditional sub-graph of the form `jmpif c: A, else B`, where A and B are basic blocks which join
//! on the same successor. This excludes the graph from having any nested conditional or loop statements.
//! Performance improvement is based on a simple execution cost metric
//!
//! This pass does not have any pre/post conditions.

use std::collections::HashSet;

use acvm::AcirField;
Expand Down Expand Up @@ -30,13 +41,8 @@ struct BasicConditional {

impl Ssa {
#[tracing::instrument(level = "trace", skip(self))]
/// This pass flatten simple IF-THEN-ELSE statements
/// This optimization pass identifies simple conditional control flow patterns in unconstrained code
/// and flattens them to reduce the number of basic blocks and improve performance.
///
/// e.g: if c {a} else {b} would be flattened to c*(a-b)+b
/// A simple conditional pattern is defined as an IF-THEN (with optional ELSE) statement, with no nested conditional nor loop statements
/// Performance improvement is based on a simple execution cost metric
/// Applies the basic_conditional pass to all functions of the program.
/// It first retrieves the `no_predicates` attribute of each function, which is used during the flattening.
pub(crate) fn flatten_basic_conditionals(mut self) -> Ssa {
// Retrieve the 'no_predicates' attribute of the functions in a map, to avoid problems with borrowing
let mut no_predicates = HashMap::default();
Expand All @@ -50,109 +56,112 @@ impl Ssa {
}
}

/// Returns the blocks of the simple conditional sub-graph whose input block is the entry.
/// Returns the blocks of the simple conditional sub-graph of the CFG whose input block is the entry.
/// Returns None if the input block is not the entry block of a simple conditional.
/// A simple conditional is an if-then(-else) statement where branches are 'small' basic blocks.
/// 'Small' basic blocks means that we expect their execution cost to be small.
///
/// In case the block is the entry of a 'simple conditional', the function returns a BasicConditional which
/// consists of the list of the conditional blocks:
/// block_entry
/// / \
/// block_then block_else
/// \ /
/// block_exit
/// block_then and block_else are optional, in order to account for the case when there is no 'then' or no 'else' branch
/// Only structured CFGs with this shape are considered:
/// - block_entry has exactly 2 successors
/// - block_then and block_else have exactly 1 successor, which is block_exit, or one of them is block_exit
/// - block_exit has exactly 2 predecessors (block_then and block_else)
///
/// Furthermore, cost of block_then + cost of block_else must be less than their average cost + jump overhead cost
fn is_conditional(
block: BasicBlockId,
cfg: &ControlFlowGraph,
function: &Function,
) -> Option<BasicConditional> {
// jump overhead is the cost for doing the conditional and jump around the blocks
// jump overhead is the cost for doing the conditional and jumping around the blocks
// We use 10 as a rough estimate, the real cost is less.
let jump_overhead = 10;
let mut successors = cfg.successors(block);
let mut result = None;
// a conditional must have 2 branches
if successors.len() != 2 {
return None;
}
let left = successors.next().unwrap();
let right = successors.next().unwrap();
let mut left_successors = cfg.successors(left);
let mut right_successors = cfg.successors(right);
let left_successors_len = left_successors.len();
let right_successors_len = right_successors.len();
let next_left = left_successors.next();
let next_right = right_successors.next();
if next_left == Some(block) || next_right == Some(block) {
// this is a loop, not a conditional
return None;
}
if left_successors_len == 1 && right_successors_len == 1 && next_left == next_right {
// The branches join on one block so it is a non-nested conditional
let cost_left = block_cost(left, &function.dfg);
let cost_right = block_cost(right, &function.dfg);
// For the flattening to be valuable, we compare the cost of the flattened code with the average cost of the 2 branches,
// including an overhead to take into account the jumps between the blocks.
let cost = cost_right.saturating_add(cost_left);
if cost < cost / 2 + jump_overhead {
if let Some(TerminatorInstruction::JmpIf {
condition: _,
then_destination,
else_destination,
call_stack: _,
}) = function.dfg[block].terminator()
{

if let Some(TerminatorInstruction::JmpIf {
condition: _,
then_destination,
else_destination,
call_stack: _,
}) = function.dfg[block].terminator()
{
// A conditional must end with a JmpIf
let mut then_successors = cfg.successors(*then_destination);
let mut else_successors = cfg.successors(*else_destination);
let then_successors_len = then_successors.len();
let else_successors_len = else_successors.len();
let next_then = then_successors.next();
let next_else = else_successors.next();
if next_then == Some(block) || next_else == Some(block) {
// this is a loop, not a conditional
return None;
}

if then_successors_len == 1 && else_successors_len == 1 && next_then == next_else {
// The branches join on one block so it is a non-nested conditional with a classical diamond shape:
// block
// / \
// then else
// \ /
// next_then
// We check that the cost of the flattened code is lower than the cost of the branches
let cost_left = block_cost(*then_destination, &function.dfg);
let cost_right = block_cost(*else_destination, &function.dfg);
// For the flattening to be valuable, we compare the cost of the flattened code with the average cost of the 2 branches,
// including an overhead to take into account the jumps between the blocks.
// We use the average cost of the 2 branches, assuming that both branches are equally likely to be executed.
let cost = cost_right.saturating_add(cost_left);
if cost < cost / 2 + jump_overhead {
result = Some(BasicConditional {
block_entry: block,
block_then: Some(*then_destination),
block_else: Some(*else_destination),
block_exit: next_left.unwrap(),
block_exit: next_then.unwrap(),
});
}
}
} else if left_successors_len == 1 && next_left == Some(right) {
// Left branch joins the right branch, e.g if/then statement with no else
// This case may not happen (i.e not generated), but it is safer to handle it (e.g in case it happens due to some optimizations)
let cost = block_cost(left, &function.dfg);
if cost < cost / 2 + jump_overhead {
if let Some(TerminatorInstruction::JmpIf {
condition: _,
then_destination,
else_destination,
call_stack: _,
}) = function.dfg[block].terminator()
{
let (block_then, block_else) = if left == *then_destination {
(Some(left), None)
} else if left == *else_destination {
(None, Some(left))
} else {
return None;
};

} else if then_successors_len == 1 && next_then == Some(*else_destination) {
// Left branch joins the right branch, e.g if/then statement with no else:
// block
// / \
// then \
// \ |
// -> else
// This case may not happen (i.e not generated), but it is safer to handle it (e.g in case it happens due to some optimizations)
let cost = block_cost(*then_destination, &function.dfg);
if cost < cost / 2 + jump_overhead {
// Use the terminator of the entry block to identify the 'then/else' branches
// Indeed, the left/right namings are arbitrary, and we now map them
// to the then/else naming of JmpIf.
result = Some(BasicConditional {
block_entry: block,
block_then,
block_else,
block_exit: right,
block_then: Some(*then_destination),
block_else: None,
block_exit: *else_destination,
});
}
}
} else if right_successors_len == 1 && next_right == Some(left) {
// Right branch joins the left branch, e.g if/else statement with no then
// This case may not happen (i.e not generated), but it is safer to handle it (e.g in case it happens due to some optimizations)
let cost = block_cost(right, &function.dfg);
if cost < cost / 2 + jump_overhead {
if let Some(TerminatorInstruction::JmpIf {
condition: _,
then_destination,
else_destination,
call_stack: _,
}) = function.dfg[block].terminator()
{
let (block_then, block_else) = if right == *then_destination {
(Some(right), None)
} else if right == *else_destination {
(None, Some(right))
} else {
return None;
};
} else if else_successors_len == 1 && next_else == Some(*then_destination) {
// Right branch joins the left branch, e.g if/else statement with no then
// This case may not happen (i.e not generated), but it is safer to handle it (e.g in case it happens due to some optimizations)
// block
// / \
// | else
// | |
// \ /
// then
let cost = block_cost(*else_destination, &function.dfg);
if cost < cost / 2 + jump_overhead {
result = Some(BasicConditional {
block_entry: block,
block_then,
block_else,
block_exit: right,
block_then: None,
block_else: Some(*else_destination),
block_exit: *else_destination,
});
}
}
Expand All @@ -161,7 +170,7 @@ fn is_conditional(
result.filter(|result| cfg.predecessors(result.block_exit).len() == 2)
}

/// Computes a cost estimate of a basic block
/// Computes a cost estimate for the execution of a basic block
/// returns u32::MAX if the block has side-effect instructions
/// WARNING: these are estimates of the runtime cost of each instruction,
/// 1 being the cost of the simplest instruction. These numbers can be improved.
Expand Down Expand Up @@ -241,6 +250,7 @@ fn flatten_function(function: &mut Function, no_predicates: &mut HashMap<Functio
let cfg = ControlFlowGraph::with_function(function);
let mut stack = vec![function.entry_block()];
let mut processed = HashSet::new();
// List of all the simple conditionals that we will identify in the function
let mut conditionals = Vec::new();

// 1. Process all blocks of the cfg, starting from the root and following the successors
Expand Down Expand Up @@ -268,6 +278,21 @@ fn flatten_function(function: &mut Function, no_predicates: &mut HashMap<Functio
flatten_multiple(&conditionals, function, no_predicates);
}

/// Flattens multiple basic conditionals within a function.
///
/// This function processes a collection of basic conditionals identified in the CFG and flattens them
/// to reduce control flow complexity. Each conditional is processed with its own context, and the
/// flattening results are then propagated throughout the entire function.
///
/// # Parameters
/// * `conditionals` - The list of basic conditionals to flatten, assumed in reverse order
/// * `function` - The function being optimized
/// * `no_predicates` - Map of function IDs to their no_predicates attribute for handling function calls
///
/// # Process
/// 1. Each conditional is flattened independently using a fresh context
/// 2. Value mappings from all conditionals are collected into a unified mapping
/// 3. The entire function is remapped only once in post-order to apply all value simplifications
fn flatten_multiple(
conditionals: &Vec<BasicConditional>,
function: &mut Function,
Expand Down Expand Up @@ -296,6 +321,24 @@ fn flatten_multiple(
}

impl Context<'_> {
/// Flattens a single basic conditional by inlining the 2 branches into the entry block.
///
/// This method transforms a conditional control flow pattern (if-then-else) into straight-line code
/// by merging the entry, then, else, and exit blocks. The conditional logic is converted into
/// predicated operations using cast and multiplication operations to select between branch values.
/// The method is adapted from `flatten_cfg`, tailored to flatten only the input conditional.
///
/// # Parameters
/// * `conditional` - The basic conditional structure to flatten
/// * `no_predicates` - Map of function IDs to their no_predicates attribute
///
/// # Implementation Details
/// - Sets up context state (target_block, no_predicate) to enable proper inlining
/// - Inlines each block's instructions into the entry block
/// - Handles terminators to manage control flow during inlining
/// - Uses a WorkList to track which blocks need processing
/// - Copies the exit block's terminator to the entry block after inlining
/// - Restores original context state after completion
fn flatten_single_conditional(
&mut self,
conditional: &BasicConditional,
Expand Down Expand Up @@ -369,6 +412,16 @@ impl Context<'_> {
self.no_predicate = old_no_predicate;
}

/// Applies value mappings to all instructions and terminators in a block.
///
/// This method rewrites a block by replacing old value IDs with their mapped equivalents
/// according to the provided mapping. This is used to propagate value simplifications
/// from conditional flattening throughout the rest of the function.
///
/// # Parameters
/// * `mapping` - HashMap mapping old ValueIds to their simplified/replaced ValueIds
/// * `func` - The function containing the block to update
/// * `block` - The BasicBlockId of the block to remap
fn map_block_with_mapping(
mapping: HashMap<ValueId, ValueId>,
func: &mut Function,
Expand Down
Loading