diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs index 05ef9448b9c..e069cf4398d 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs @@ -2,9 +2,9 @@ //! This includes branches in the CFG with non-constant conditions. Flattening these requires //! special handling for operations with side-effects and can lead to a loss of information since //! the jmpif will no longer be in the program. As a result, this pass should usually be towards or -//! at the end of the optimization passes. Note that this pass will also perform unexpectedly if -//! loops are still present in the program. Since the pass sees a normal jmpif, it will attempt to -//! merge both blocks, but no actual looping will occur. +//! at the end of the optimization passes. +//! Furthermore, this pass assumes that no loops are present in the program and will assume +//! that a jmpif is a branch point and will attempt to merge both blocks. No actual looping will occur. //! //! This pass is also known to produce some extra instructions which may go unused (usually 'Not') //! while merging branches. These extra instructions can be cleaned up by a later dead instruction @@ -218,7 +218,10 @@ pub(crate) struct Context<'f> { not_instructions: HashMap, /// Flag to tell the context to not issue 'enable_side_effect' instructions during flattening. - /// This should be set to true only by flatten_single(), when no instruction is known to fail. + /// + /// It is set with an attribute when defining a function that cannot fail whatsoever to avoid + /// the overhead of handling side effects. + /// It can also be set to true by flatten_single(), when no instruction is known to fail. pub(crate) no_predicate: bool, } @@ -247,6 +250,8 @@ struct ConditionalContext { call_stack: CallStackId, } +/// Flattens the control flow graph of the function such that it is left with a +/// single block containing all instructions and no more control-flow. fn flatten_function_cfg(function: &mut Function, no_predicates: &HashMap) { // This pass may run forever on a brillig function. // Analyze will check if the predecessors have been processed and push the block to the back of @@ -254,6 +259,10 @@ fn flatten_function_cfg(function: &mut Function, no_predicates: &HashMap Context<'f> { } } + /// Flatten the CFG by inlining all instructions from the queued blocks + /// until all blocks have been flattened. + /// We follow the terminator of each block to determine which blocks to + /// process next: + /// If the terminator is a 'JumpIf', we assume we are entering a conditional statement and + /// add the start blocks of the 'then_branch', 'else_branch' and the 'exit' block to the queue. + /// Other blocks will have only one successor, so we will process them iteratively, + /// until we reach one block already in the queue, i.e added when entering a conditional statement, + /// i.e the 'else_branch' or the 'exit'. In that case we switch to the next block in the queue, instead + /// of the successor. + /// This process ensure that the blocks are always processed in this order: + /// if_entry -> then_branch -> else_branch -> exit + /// In case of nested if statements, for instance in the 'then_branch', it will be: + /// if_entry -> then_branch -> if_entry_2 -> then_branch_2 -> exit_2 -> else_branch -> exit + /// Information about the nested if statements is stored in the 'condition_stack' which + /// is pop-ed/push-ed when entering/leaving a conditional statement. pub(crate) fn flatten(&mut self, no_predicates: &HashMap) { - // Flatten the CFG by inlining all instructions from the queued blocks - // until all blocks have been flattened. - // We follow the terminator of each block to determine which blocks to - // process next let mut queue = vec![self.target_block]; while let Some(block) = queue.pop() { self.inline_block(block, no_predicates); let to_process = self.handle_terminator(block, &queue); for incoming_block in to_process { + // Do not add blocks already in the queue if !queue.contains(&incoming_block) { queue.push(incoming_block); } @@ -326,6 +348,11 @@ impl<'f> Context<'f> { } /// Returns the current condition + /// + /// The conditions are in a stack, they are added as conditional branches are encountered + /// so the last one is the current condition. + /// When processing a conditional branch, we first follow the 'then' branch and only after we + /// process the 'else' branch. At that point, the ConditionalContext has the 'else_branch' fn get_last_condition(&self) -> Option { self.condition_stack.last().map(|context| match &context.else_branch { Some(else_branch) => else_branch.condition, @@ -348,7 +375,11 @@ impl<'f> Context<'f> { result } - // Inline all instructions from the given block into the target block, and track slice capacities + /// Inline all instructions from the given block into the target block, and track slice capacities + /// This is done by processing every instructions in the block and using the flattening context + /// to push them in the target block + /// + /// - `no_predicates` indicates which functions have no predicates and for which we disable the handling side effects pub(crate) fn inline_block( &mut self, block: BasicBlockId, @@ -388,6 +419,12 @@ impl<'f> Context<'f> { /// For a normal block, it would be its successor /// For blocks related to a conditional statement, we ensure to process /// the 'then-branch', then the 'else-branch' (if it exists), and finally the end block + /// The update of the context is done by the functions 'if_start', 'then_stop' and 'else_stop' + /// which perform the business logic when entering a conditional statement, finishing the 'then-branch' + /// and the 'else-branch, respectively. + /// We know if a block is related to the conditional statement if is referenced by the 'work_list' + /// Indeed, the start blocks of the 'then_branch' and 'else_branch' are added to the 'work_list' when + /// starting to process a conditional statement. pub(crate) fn handle_terminator( &mut self, block: BasicBlockId, @@ -430,7 +467,11 @@ impl<'f> Context<'f> { } } - /// Process a conditional statement + /// Process a conditional statement by creating a 'ConditionalContext' + /// with information about the branch, and storing it in the dedicated stack. + /// Local allocations are moved to the 'then_branch' of the ConditionalContext. + /// Returns the blocks corresponding to the 'then_branch', 'else_branch', and exit block of the conditional statement, + /// so that they will be processed in this order. fn if_start( &mut self, condition: &ValueId, @@ -472,7 +513,11 @@ impl<'f> Context<'f> { vec![self.branch_ends[if_entry], *else_destination, *then_destination] } - /// Switch context to the 'else-branch' + /// Switch context to the 'else-branch': + /// - Negates the condition for the 'else_branch' and set it in the ConditionalContext + /// - Move the local allocations to the 'else_branch' + /// - Issues the 'enable_side_effect' instruction + /// - Returns the exit block of the conditional statement fn then_stop(&mut self, block: &BasicBlockId) -> Vec { let mut cond_context = self.condition_stack.pop().unwrap(); cond_context.then_branch.last_block = *block; @@ -500,6 +545,7 @@ impl<'f> Context<'f> { vec![self.cfg.successors(*block).next().unwrap()] } + /// Negates a boolean value by inserting a Not instruction fn not_instruction(&mut self, condition: ValueId, call_stack: CallStackId) -> ValueId { if let Some(existing) = self.not_instructions.get(&condition) { return *existing; @@ -510,7 +556,10 @@ impl<'f> Context<'f> { not } - /// Process the 'exit' block of a conditional statement + /// Process the 'exit' block of a conditional statement: + /// - Retrieves the local allocations from the Conditional Context + /// - Issues the 'enable_side_effect' instruction + /// - Joins the arguments from both branches fn else_stop(&mut self, block: &BasicBlockId) -> Vec { let mut cond_context = self.condition_stack.pop().unwrap(); if cond_context.else_branch.is_none() { @@ -547,8 +596,9 @@ impl<'f> Context<'f> { /// all of the join point's predecessors, and it must handle any differing side effects from /// each branch. /// - /// Afterwards, continues inlining recursively until it finds the next end block or finds the - /// end of the function. + /// The merge of arguments is done by inserting an 'IfElse' instructions which returns + /// the argument from the then_branch or the else_branch depending the the condition. + /// They are added to the 'arguments_stack' instead of the arguments of the 2 branches. /// /// Returns the final block that was inlined. fn inline_branch_end( @@ -678,8 +728,10 @@ impl<'f> Context<'f> { } } - /// If we are currently in a branch, we need to modify constrain instructions - /// to multiply them by the branch's condition (see optimization #1 in the module comment). + /// If we are currently in a branch, we need to modify instructions that have side effects + /// (e.g. constraints, stores, range checks) to ensure that the side effect is only applied + /// if their branch is taken. + /// For instance we multiply constrain instructions by the branch's condition (see optimization #1 in the module comment). fn handle_instruction_side_effects( &mut self, instruction: Instruction, @@ -703,7 +755,7 @@ impl<'f> Context<'f> { if self.local_allocations.contains(&address) { Instruction::Store { address, value } } else { - // Instead of storing `value`, store `if condition { value } else { previous_value }` + // Instead of storing `value`, we store: `if condition { value } else { previous_value }` let typ = self.inserter.function.dfg.type_of_value(value); let load = Instruction::Load { address }; let previous_value = self @@ -734,6 +786,8 @@ impl<'f> Context<'f> { } Instruction::Call { func, mut arguments } => match self.inserter.function.dfg[func] { + // A ToBits (or ToRadix in general) can fail if the input has more bits than the target. + // We ensure it does not fail by multiplying the input by the condition. Value::Intrinsic(Intrinsic::ToBits(_) | Intrinsic::ToRadix(_)) => { let field = arguments[0]; let casted_condition = @@ -744,13 +798,14 @@ impl<'f> Context<'f> { Instruction::Call { func, arguments } } - //Issue #5045: We set curve points to infinity if condition is false + //Issue #5045: We set curve points to infinity if condition is false, to ensure that they are on the curve, if not the addition may fail. Value::Intrinsic(Intrinsic::BlackBox(BlackBoxFunc::EmbeddedCurveAdd)) => { arguments[2] = self.var_or_one(arguments[2], condition, call_stack); arguments[5] = self.var_or_one(arguments[5], condition, call_stack); Instruction::Call { func, arguments } } + // For MSM, we also ensure the inputs are on the curve if the predicate is false. Value::Intrinsic(Intrinsic::BlackBox(BlackBoxFunc::MultiScalarMul)) => { let points_array_idx = if matches!( self.inserter.function.dfg.type_of_value(arguments[0]), @@ -782,6 +837,11 @@ impl<'f> Context<'f> { } } + /// 'Cast' the 'condition' to 'value' type + /// + /// This needed because we need to multiply the condition with several values + /// in order to 'nullify' side-effects when the 'condition' is false (in 'handle_instruction_side_effects()' function). + /// Since the condition is a boolean, it can be safely casted to any other type. fn cast_condition_to_value_type( &mut self, condition: ValueId, @@ -793,6 +853,7 @@ impl<'f> Context<'f> { self.insert_instruction(cast, call_stack) } + /// Insert a multiplication between 'condition' and 'value' fn mul_by_condition( &mut self, value: ValueId,