diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs index 7c91b5f0fe5..419120fc4e8 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/simplify_cfg.rs @@ -5,8 +5,10 @@ //! 2. Inlines a block into its sole predecessor if that predecessor only has one successor. //! 3. Removes any block arguments for blocks with only a single predecessor. //! 4. Removes any blocks which have no instructions other than a single terminating jmp. +//! 5. Replaces any jmpifs with constant conditions with jmps. If this causes the block to have +//! only 1 successor then (2) also will be applied. //! -//! Currently, only 2 and 3 are implemented. +//! Currently, 1 and 4 are unimplemented. use std::collections::HashSet; use crate::ssa_refactor::{ @@ -23,8 +25,10 @@ impl Ssa { /// 2. Inlining a block into its sole predecessor if that predecessor only has one successor. /// 3. Removing any block arguments for blocks with only a single predecessor. /// 4. Removing any blocks which have no instructions other than a single terminating jmp. + /// 5. Replacing any jmpifs with constant conditions with jmps. If this causes the block to have + /// only 1 successor then (2) also will be applied. /// - /// Currently, only 2 and 3 are implemented. + /// Currently, 1 and 4 are unimplemented. pub(crate) fn simplify_cfg(mut self) -> Self { for function in self.functions.values_mut() { simplify_function(function); @@ -45,6 +49,10 @@ fn simplify_function(function: &mut Function) { stack.extend(function.dfg[block].successors().filter(|block| !visited.contains(block))); } + // This call is before try_inline_into_predecessor so that if it succeeds in changing a + // jmpif into a jmp, the block may then be inlined entirely into its predecessor in try_inline_into_predecessor. + check_for_constant_jmpif(function, block, &mut cfg); + let mut predecessors = cfg.predecessors(block); if predecessors.len() == 1 { @@ -65,6 +73,26 @@ fn simplify_function(function: &mut Function) { } } +/// Optimize a jmpif into a jmp if the condition is known +fn check_for_constant_jmpif( + function: &mut Function, + block: BasicBlockId, + cfg: &mut ControlFlowGraph, +) { + if let Some(TerminatorInstruction::JmpIf { condition, then_destination, else_destination }) = + function.dfg[block].terminator() + { + if let Some(constant) = function.dfg.get_numeric_constant(*condition) { + let destination = + if constant.is_zero() { *else_destination } else { *then_destination }; + + let jmp = TerminatorInstruction::Jmp { destination, arguments: Vec::new() }; + function.dfg[block].set_terminator(jmp); + cfg.recompute_block(function, block); + } + } +} + /// If the given block has block parameters, replace them with the jump arguments from the predecessor. /// /// Currently, if this function is needed, `try_inline_into_predecessor` will also always apply, @@ -130,7 +158,11 @@ fn try_inline_into_predecessor( #[cfg(test)] mod test { use crate::ssa_refactor::{ - ir::{instruction::TerminatorInstruction, map::Id, types::Type}, + ir::{ + instruction::{BinaryOp, TerminatorInstruction}, + map::Id, + types::Type, + }, ssa_builder::FunctionBuilder, }; @@ -189,4 +221,61 @@ mod test { other => panic!("Unexpected terminator {other:?}"), } } + + #[test] + fn remove_known_jmpif() { + // fn main { + // b0(v0: u1): + // v1 = eq v0, v0 + // jmpif v1, then: b1, else: b2 + // b1(): + // return Field 1 + // b2(): + // return Field 2 + // } + let main_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("main".into(), main_id); + let v0 = builder.add_parameter(Type::bool()); + + let b1 = builder.insert_block(); + let b2 = builder.insert_block(); + + let one = builder.field_constant(1u128); + let two = builder.field_constant(2u128); + + let v1 = builder.insert_binary(v0, BinaryOp::Eq, v0); + builder.terminate_with_jmpif(v1, b1, b2); + + builder.switch_to_block(b1); + builder.terminate_with_return(vec![one]); + + builder.switch_to_block(b2); + builder.terminate_with_return(vec![two]); + + let ssa = builder.finish(); + assert_eq!(ssa.main().reachable_blocks().len(), 3); + + // Expected output: + // fn main { + // b0(): + // return Field 1 + // } + let ssa = ssa.simplify_cfg(); + let main = ssa.main(); + println!("{}", main); + assert_eq!(main.reachable_blocks().len(), 1); + + match main.dfg[main.entry_block()].terminator() { + Some(TerminatorInstruction::Return { return_values }) => { + assert_eq!(return_values.len(), 1); + let return_value = main + .dfg + .get_numeric_constant(return_values[0]) + .expect("Expected return value to be constant") + .to_u128(); + assert_eq!(return_value, 1u128); + } + other => panic!("Unexpected terminator {other:?}"), + } + } }