Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 52 additions & 7 deletions compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,22 @@
//! Removes any unreachable functions from the code. These can result from
//! optimizations making existing functions unreachable, e.g. `if false { foo() }`,
//! or even from monomorphizing an unconstrained version of a constrained function
//! where the original constrained version ends up never being used.
//!
//! This pass identifies all unreachable functions and prunes them from the
//! function set. Reachability is defined as:
//! - A function is reachable if it is an entry point (e.g., `main`)
//! - A function is reachable if it is called from another reachable function
//! - A function is reachable if it is stored in a reference (e.g., in a `Store` instruction) from another reachable function.
//! Even if not immediately called, it may later be dynamically loaded and invoked.
//! This marking is conservative but ensures correctness. We should instead rely on [mem2reg][crate::ssa::opt::mem2reg]
//! for resolving loads/stores.
//!
//! The pass performs a recursive traversal starting from all entry points and marks
//! any transitively reachable functions. It then discards the rest.
//!
//! This pass helps shrink the SSA before compilation stages like inlining and dead code elimination.

use std::collections::BTreeSet;

use fxhash::FxHashSet as HashSet;
Expand All @@ -12,36 +31,49 @@ use crate::ssa::{
};

impl Ssa {
/// Removes any unreachable functions from the code. These can result from
/// optimizations making existing functions unreachable, e.g. `if false { foo() }`,
/// or even from monomorphizing an unconstrained version of a constrained function
/// where the original constrained version ends up never being used.
/// See [`remove_unreachable`][self] module for more information.
pub(crate) fn remove_unreachable_functions(mut self) -> Self {
let mut used_functions = HashSet::default();
let mut reachable_functions = HashSet::default();

// Go through all the functions, and if we have an entry point, extend the set of all
// functions which are reachable.
for (id, function) in self.functions.iter() {
// XXX: `self.is_entry_point(*id)` could leave Brillig functions that nobody calls in the SSA.
let is_entry_point = function.id() == self.main_id
|| function.runtime().is_acir() && function.runtime().is_entry_point();

if is_entry_point {
collect_reachable_functions(&self, *id, &mut used_functions);
collect_reachable_functions(&self, *id, &mut reachable_functions);
}
}

self.functions.retain(|id, _| used_functions.contains(id));
// Discard all functions not marked as reachable
self.functions.retain(|id, _| reachable_functions.contains(id));
self
}
}

/// Recursively determine the reachable functions from a given function.
/// This function is only intended to be called on functions that are already known
/// to be entry points or transitively reachable from one.
///
/// # Arguments
/// - `ssa`: The full [Ssa] structure containing all functions.
/// - `current_func_id`: The [FunctionId] from which to begin a traversal.
/// - `reachable_functions`: A mutable set used to collect all reachable functions.
/// It serves both as the final output of this traversal and as a visited set
/// to prevent cycles and redundant recursion.
fn collect_reachable_functions(
ssa: &Ssa,
current_func_id: FunctionId,
reachable_functions: &mut HashSet<FunctionId>,
) {
// If this function has already been determine as reachable, then we have already
// processed the given function and we can simply return.
if reachable_functions.contains(&current_func_id) {
return;
}
// Mark the given function as reachable
reachable_functions.insert(current_func_id);

// If the debugger is used, its possible for function inlining
Expand All @@ -50,13 +82,26 @@ fn collect_reachable_functions(
return;
};

// Get the set of reachable functions from the given function
let used_functions = used_functions(func);

// For each reachable function within the given function recursively collect
// any more reachable functions.
for called_func_id in used_functions.iter() {
collect_reachable_functions(ssa, *called_func_id, reachable_functions);
}
}

/// Identifies all reachable function IDs within a given function.
/// This includes:
/// - Function calls (functions used via `Call` instructions)
/// - Function references (functions stored via `Store` instructions)
///
/// # Arguments
/// - `func`: The [Function] to analyze for usage
///
/// # Returns
/// A sorted set of [`FunctionId`]s that are reachable from the function.
fn used_functions(func: &Function) -> BTreeSet<FunctionId> {
let mut used_function_ids = BTreeSet::default();

Expand Down
Loading