diff --git a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs index eca31b4b00c..c3e27f6a070 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs @@ -1,9 +1,35 @@ //! This module defines the defunctionalization pass for the SSA IR. -//! The purpose of this pass is to transforms all functions used as values into -//! constant numbers (fields) that represent the function id. That way all calls -//! with a non-literal target can be replaced with a call to an apply function. -//! The apply function is a dispatch function that takes the function id as a parameter -//! and dispatches to the correct target. +//! Certain IR targets (e.g., Brillig and ACIR) do not support higher-order functions directly. +//! +//! The pass eliminates higher-order functions (a function which accepts function values as arguments or returns functions) +//! by transforming functions used as values (i.e., first-class functions) +//! into constant numbers (fields) that represent their function IDs. +//! +//! Defunctionalization handles higher-order functions by lowering function values into +//! constant identifiers and replacing calls of function values with calls to a single +//! dispatch `apply` function. +//! +//! ## How the pass works: +//! - Every function used as a value (e.g., passed as a parameter) is assigned a unique [NumericType::NativeField] value. +//! This value now represents the first-class function's ID. +//! - All call instructions with non-literal targets are replaced by calls to an `apply` function. +//! - The `apply` function is a dispatcher. It takes the function ID as its first argument +//! and calls the appropriate function based on that ID. +//! +//! Pseudocode of an `apply` function is given below: +//! ```text +//! fn apply(function_id: Field, arg1: Field, arg2: Field) -> Field { +//! match function_id { +//! 
0 -> function0(arg1, arg2), +//! 1 -> function1(arg1, arg2), +//! ... +//! N -> functionN(arg1, arg2), +//! } +//! } +//! ``` +//! +//! After this pass all first-class functions are replaced with numeric IDs +//! and calls are routed via the newly generated `apply` functions. use std::collections::{BTreeMap, BTreeSet}; use acvm::FieldElement; @@ -44,7 +70,12 @@ struct ApplyFunction { dispatches_to_multiple_functions: bool, } +/// All functions used as a value that share the same signature and runtime type +/// Maps ([Signature], [RuntimeType]) -> Vec<[FunctionId]> type Variants = BTreeMap<(Signature, RuntimeType), Vec>; +/// All generated apply functions for each grouping of function variants. +/// Each apply function handles a specific ([Signature], [RuntimeType]) group. +/// Maps ([Signature], [RuntimeType]) -> [ApplyFunction] type ApplyFunctions = HashMap<(Signature, RuntimeType), ApplyFunction>; /// Performs defunctionalization on all functions @@ -56,15 +87,19 @@ struct DefunctionalizationContext { } impl Ssa { + /// See [`defunctionalize`][self] module for more information. 
#[tracing::instrument(level = "trace", skip(self))] pub(crate) fn defunctionalize(mut self) -> Ssa { - // Find all functions used as value that share the same signature + // Find all functions used as value that share the same signature and runtime type let variants = find_variants(&self); + // Generate the apply functions for the provided variants let apply_functions = create_apply_functions(&mut self, variants); + // Setup the pass context let context = DefunctionalizationContext { apply_functions }; + // Run defunctionalization over all functions in the SSA context.defunctionalize_all(&mut self); self } @@ -215,6 +250,15 @@ fn map_function_to_field(func: &mut Function, value: ValueId) -> Option } /// Collects all functions used as values that can be called by their signatures +/// +/// Groups all [FunctionId]s used as values by their [Signature] and [RuntimeType], +/// producing a mapping from these tuples to the list of variant functions to be dynamically dispatched. +/// +/// # Arguments +/// - `ssa`: The full [Ssa] structure +/// +/// # Returns +/// [Variants] that should then be used to generate apply functions for dispatching fn find_variants(ssa: &Ssa) -> Variants { let mut dynamic_dispatches: BTreeSet<(Signature, RuntimeType)> = BTreeSet::new(); let mut functions_as_values: BTreeSet = BTreeSet::new(); @@ -226,6 +270,7 @@ fn find_variants(ssa: &Ssa) -> Variants { ); } + // Group function variant candidates by their signature let mut signature_to_functions_as_value: BTreeMap> = BTreeMap::new(); for function_id in functions_as_values { @@ -235,16 +280,21 @@ fn find_variants(ssa: &Ssa) -> Variants { let mut variants: Variants = BTreeMap::new(); + // Further group function variant candidates by their caller runtime. 
for (dispatch_signature, caller_runtime) in dynamic_dispatches { let target_fns = signature_to_functions_as_value.get(&dispatch_signature).cloned().unwrap_or_default(); variants.insert((dispatch_signature, caller_runtime), target_fns); } + // We will now have fully constructed our variants map and can return it variants } /// Finds all literal functions used as values in the given function +/// +/// It is assumed that function values will only ever be used in a call instruction +/// or a store instruction. fn find_functions_as_values(func: &Function) -> BTreeSet { let mut functions_as_values: BTreeSet = BTreeSet::new(); @@ -276,6 +326,10 @@ fn find_functions_as_values(func: &Function) -> BTreeSet { } /// Finds all dynamic dispatch signatures in the given function +/// +/// A dynamic dispatch is defined as a call into a function value where that +/// value comes from a parameter (i.e., calling a function passed as a function parameter) +/// or another instruction (i.e., calling a function returned from another function call). fn find_dynamic_dispatches(func: &Function) -> BTreeSet { let mut dispatches = BTreeSet::new(); @@ -300,10 +354,21 @@ fn find_dynamic_dispatches(func: &Function) -> BTreeSet { dispatches } -fn create_apply_functions( - ssa: &mut Ssa, - variants_map: BTreeMap<(Signature, RuntimeType), Vec>, -) -> ApplyFunctions { +/// Creates all apply functions needed for dispatch of function values. +/// +/// This function maintains the grouping set in [Variants], meaning an apply +/// function is grouped by functions that share a signature and runtime. +/// An apply function is only created if there are multiple function variants +/// for a specific ([Signature], [RuntimeType]) group. +/// Otherwise, if there is a single variant that function is simply reused. +/// +/// # Arguments +/// - `ssa`: A mutable reference to the full [Ssa] structure containing all functions. 
+/// - `variants_map`: [Variants] +/// +/// # Returns +/// [ApplyFunctions] +fn create_apply_functions(ssa: &mut Ssa, variants_map: Variants) -> ApplyFunctions { let mut apply_functions = HashMap::default(); for ((mut signature, runtime), variants) in variants_map.into_iter() { assert!( @@ -312,12 +377,15 @@ fn create_apply_functions( ); let dispatches_to_multiple_functions = variants.len() > 1; + // Update the shared function signature of the higher-order function variants + // to replace any function passed as a value with a numeric field type. for param in &mut signature.params { if *param == Type::Function { *param = Type::field(); } } + // Update the return value types as we did for the signature parameters above. for ret in &mut signature.returns { if *ret == Type::Function { *ret = Type::field(); @@ -325,8 +393,11 @@ fn create_apply_functions( } let id = if dispatches_to_multiple_functions { + // If we have multiple variants for this signature and runtime type group + // we need to generate an apply function. create_apply_function(ssa, signature.clone(), runtime, variants) } else { + // If there is only one variant, we can use it directly rather than creating a new apply function. variants[0] }; apply_functions @@ -335,11 +406,34 @@ fn create_apply_functions( apply_functions } +/// Transforms a [FunctionId] into a [FieldElement] fn function_id_to_field(function_id: FunctionId) -> FieldElement { (function_id.to_u32() as u128).into() } -/// Creates an apply function for the given signature and variants +/// Creates a single apply function to enable dispatch across multiple function variants +/// that share the same [Signature] and [RuntimeType]. +/// +/// This function is responsible for generating an entry point that dispatches between several +/// concrete functions at runtime based on a target field value. It builds a sequence of +/// conditional checks (if-else chain) to compare the target against each +/// function's ID, and calls the matching function. 
+/// +/// These apply functions are to be aggressively inlined as it is assumed that they will be optimized +/// away by the constants at the call site. +/// +/// # Arguments +/// - `ssa`: A mutable reference to the full [Ssa] structure containing all functions. +/// - `signature`: The shared [Signature] of all variants. +/// - `caller_runtime`: The runtime in which the apply function will be called, used to update inlining policies. +/// - `function_ids`: A non-empty list of [FunctionId]s representing concrete functions to dispatch between. +/// This method will panic if `function_ids` is empty. +/// +/// # Returns +/// The [FunctionId] of the new apply function +/// +/// # Panics +/// If the `function_ids` argument is empty. fn create_apply_function( ssa: &mut Ssa, signature: Signature, @@ -347,6 +441,10 @@ fn create_apply_function( function_ids: Vec, ) -> FunctionId { assert!(!function_ids.is_empty()); + // Clone the user-defined globals and the function purities mapping, + // which are shared across all functions. + // We will be borrowing `ssa` mutably so we need to fetch this shared information + // before attempting to add a new function to the SSA. 
let globals = ssa.main().dfg.globals.clone(); let purities = ssa.main().dfg.function_purities.clone(); ssa.add_fn(|id| { @@ -361,10 +459,18 @@ fn create_apply_function( RuntimeType::Brillig(_) => RuntimeType::Brillig(InlineType::InlineAlways), }; function_builder.set_runtime(runtime); + // Set up the parameters of the apply function + // The first argument is the target function ID for which we are dispatching a call let target_id = function_builder.add_parameter(Type::field()); + // The remaining apply function parameters are the actual parameters of the variants for which we are dispatching calls let params_ids = vecmap(signature.params, |typ| function_builder.add_parameter(typ)); - let mut previous_target_block = None; + let entry_block = function_builder.current_block(); + + let return_block = build_return_block(&mut function_builder, &signature.returns); + // Switch back to the entry block to build the rest of the dispatch function + function_builder.switch_to_block(entry_block); + for (index, function_id) in function_ids.iter().enumerate() { let is_last = index == function_ids.len() - 1; let mut next_function_block = None; @@ -389,25 +495,15 @@ fn create_apply_function( // Else just constrain the condition function_builder.insert_constrain(target_id, function_id_constant, None); } - // Find the target block or build it if necessary - let current_block = function_builder.current_block(); - - let target_block = build_return_block( - &mut function_builder, - current_block, - &signature.returns, - previous_target_block, - ); - previous_target_block = Some(target_block); - - // Call the function + + // Call the function variant let target_function_value = function_builder.import_function(*function_id); let call_results = function_builder .insert_call(target_function_value, params_ids.clone(), signature.returns.clone()) .to_vec(); - // Jump to the target block for returning - function_builder.terminate_with_jmp(target_block, call_results); + // Jump to the return 
block + function_builder.terminate_with_jmp(return_block, call_results); if let Some(next_block) = next_function_block { // Switch to the next block for the else branch @@ -418,22 +514,23 @@ fn create_apply_function( }) } -/// If no previous return target exists, it will create a final return, -/// otherwise returns the existing return block to jump to. -fn build_return_block( - builder: &mut FunctionBuilder, - previous_block: BasicBlockId, - passed_types: &[Type], - target: Option, -) -> BasicBlockId { - if let Some(return_block) = target { - return return_block; - } +/// Create the final return block for an apply function. +/// +/// The return block is meant to be shared among all branches of the apply function. +/// The apply function will jump to this block after calling the appropriate +/// target function. +/// +/// # Arguments +/// * `builder` - [FunctionBuilder] used to construct the function's SSA. +/// * `passed_types` - A slice of [Type]s representing the values to be returned from the function. +/// +/// # Returns +/// A [BasicBlockId] representing the newly created return block. 
+fn build_return_block(builder: &mut FunctionBuilder, passed_types: &[Type]) -> BasicBlockId { let return_block = builder.insert_block(); builder.switch_to_block(return_block); let params = vecmap(passed_types, |typ| builder.add_block_parameter(return_block, typ.clone())); builder.terminate_with_return(params); - builder.switch_to_block(previous_block); return_block } @@ -446,6 +543,7 @@ mod tests { #[test] fn apply_inherits_caller_runtime() { // Extracted from `execution_success/brillig_fns_as_values` with `--force-brillig` + // with an additional simple higher-order function let src = " brillig(inline) fn main f0 { b0(v0: u32): @@ -457,6 +555,9 @@ mod tests { v9 = add v0, u32 1 v10 = eq v8, v9 constrain v8 == v9 + v11 = call f1(f4, v0) -> u32 + v12 = add v0, u32 1 + constrain v11 == v12 return } brillig(inline) fn wrapper f1 { @@ -474,6 +575,11 @@ mod tests { v2 = add v0, u32 1 return v2 } + brillig(inline) fn increment_three f4 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } "; let ssa = Ssa::from_str(src).unwrap(); @@ -490,11 +596,14 @@ mod tests { v9 = add v0, u32 1 v10 = eq v8, v9 constrain v8 == v9 + v12 = call f1(Field 4, v0) -> u32 + v13 = add v0, u32 1 + constrain v12 == v13 return } brillig(inline) fn wrapper f1 { b0(v0: Field, v1: u32): - v3 = call f4(v0, v1) -> u32 + v3 = call f5(v0, v1) -> u32 return v3 } brillig(inline) fn increment f2 { @@ -507,19 +616,30 @@ mod tests { v2 = add v0, u32 1 return v2 } - brillig(inline_always) fn apply f4 { + brillig(inline) fn increment_three f4 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + brillig(inline_always) fn apply f5 { b0(v0: Field, v1: u32): v4 = eq v0, Field 2 - jmpif v4 then: b2, else: b1 - b1(): - constrain v0 == Field 3 - v9 = call f3(v1) -> u32 - jmp b3(v9) + jmpif v4 then: b3, else: b2 + b1(v2: u32): + return v2 b2(): + v8 = eq v0, Field 3 + jmpif v8 then: b5, else: b4 + b3(): v6 = call f2(v1) -> u32 - jmp b3(v6) - b3(v2: u32): - return v2 + jmp b1(v6) + b4(): + constrain v0 == Field 4 
+ v13 = call f4(v1) -> u32 + jmp b1(v13) + b5(): + v10 = call f3(v1) -> u32 + jmp b1(v10) } "); } diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs index 7e3113424e0..957ae2787fd 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs @@ -3,7 +3,7 @@ //! any `Store` instructions within a block that are no longer needed because no more loads occur in //! between the Store in question and the next Store. //! -//! The pass works as follows: +//! ## How the pass works: //! - Each block in each function is iterated in forward-order. //! - The starting value of each reference in the block is the unification of the same references //! at the end of each direct predecessor block to the current block.