diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs index 24ca8241f35..5a72aceb5f5 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs @@ -1,3 +1,22 @@ +//! Removes any unreachable functions from the code. These can result from +//! optimizations making existing functions unreachable, e.g. `if false { foo() }`, +//! or even from monomorphizing an unconstrained version of a constrained function +//! where the original constrained version ends up never being used. +//! +//! This pass identifies all unreachable functions and prunes them from the +//! function set. Reachability is defined as: +//! - A function is reachable if it is an entry point (e.g., `main`) +//! - A function is reachable if it is called from another reachable function +//! - A function is reachable if it is stored in a reference (e.g., in a `Store` instruction) from another reachable function. +//! Even if not immediately called, it may later be dynamically loaded and invoked. +//! This marking is conservative but ensures correctness. We should instead rely on [mem2reg][crate::ssa::opt::mem2reg] +//! for resolving loads/stores. +//! +//! The pass performs a recursive traversal starting from all entry points and marks +//! any transitively reachable functions. It then discards the rest. +//! +//! This pass helps shrink the SSA before compilation stages like inlining and dead code elimination. + use std::collections::BTreeSet; use fxhash::FxHashSet as HashSet; @@ -12,36 +31,49 @@ use crate::ssa::{ }; impl Ssa { - /// Removes any unreachable functions from the code. These can result from - /// optimizations making existing functions unreachable, e.g. `if false { foo() }`, - /// or even from monomorphizing an unconstrained version of a constrained function - /// where the original constrained version ends up never being used. + /// See [`remove_unreachable`][self] module for more information. pub(crate) fn remove_unreachable_functions(mut self) -> Self { - let mut used_functions = HashSet::default(); + let mut reachable_functions = HashSet::default(); + // Go through all the functions, and if we have an entry point, extend the set of all + // functions which are reachable. for (id, function) in self.functions.iter() { // XXX: `self.is_entry_point(*id)` could leave Brillig functions that nobody calls in the SSA. let is_entry_point = function.id() == self.main_id || function.runtime().is_acir() && function.runtime().is_entry_point(); if is_entry_point { - collect_reachable_functions(&self, *id, &mut used_functions); + collect_reachable_functions(&self, *id, &mut reachable_functions); } } - self.functions.retain(|id, _| used_functions.contains(id)); + // Discard all functions not marked as reachable + self.functions.retain(|id, _| reachable_functions.contains(id)); self } } +/// Recursively determine the reachable functions from a given function. +/// This function is only intended to be called on functions that are already known +/// to be entry points or transitively reachable from one. +/// +/// # Arguments +/// - `ssa`: The full [Ssa] structure containing all functions. +/// - `current_func_id`: The [FunctionId] from which to begin a traversal. +/// - `reachable_functions`: A mutable set used to collect all reachable functions. +/// It serves both as the final output of this traversal and as a visited set +/// to prevent cycles and redundant recursion. fn collect_reachable_functions( ssa: &Ssa, current_func_id: FunctionId, reachable_functions: &mut HashSet, ) { + // If this function has already been determine as reachable, then we have already + // processed the given function and we can simply return. if reachable_functions.contains(¤t_func_id) { return; } + // Mark the given function as reachable reachable_functions.insert(current_func_id); // If the debugger is used, its possible for function inlining @@ -50,13 +82,26 @@ fn collect_reachable_functions( return; }; + // Get the set of reachable functions from the given function let used_functions = used_functions(func); + // For each reachable function within the given function recursively collect + // any more reachable functions. for called_func_id in used_functions.iter() { collect_reachable_functions(ssa, *called_func_id, reachable_functions); } } +/// Identifies all reachable function IDs within a given function. +/// This includes: +/// - Function calls (functions used via `Call` instructions) +/// - Function references (functions stored via `Store` instructions) +/// +/// # Arguments +/// - `func`: The [Function] to analyze for usage +/// +/// # Returns +/// A sorted set of [`FunctionId`]s that are reachable from the function. fn used_functions(func: &Function) -> BTreeSet { let mut used_function_ids = BTreeSet::default();