diff --git a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs index 15597bf6ef5..0c1cca8d47b 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs @@ -383,6 +383,30 @@ impl DataFlowGraph { self.instructions[id] = instruction; } + /// Replaces values in the given block according to the given HashMap. + pub(crate) fn replace_values_in_block( + &mut self, + block: BasicBlockId, + values_to_replace: &HashMap, + ) { + self.replace_values_in_block_instructions(block, values_to_replace); + self.replace_values_in_block_terminator(block, values_to_replace); + } + + /// Replaces values in the given block instructions according to the given HashMap. + pub(crate) fn replace_values_in_block_instructions( + &mut self, + block: BasicBlockId, + values_to_replace: &HashMap, + ) { + let instruction_ids = self.blocks[block].take_instructions(); + for instruction_id in &instruction_ids { + let instruction = &mut self[*instruction_id]; + instruction.replace_values(values_to_replace); + } + *self[block].instructions_mut() = instruction_ids; + } + /// Replaces values in the given block terminator (if it has any) according to the given HashMap. pub(crate) fn replace_values_in_block_terminator( &mut self, diff --git a/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs index b042cf9a41e..580d20e4f5b 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs @@ -9,6 +9,7 @@ //! only 1 successor then (2) also will be applied. //! //! Currently, 1 and 4 are unimplemented. +use fxhash::FxHashMap as HashMap; use std::collections::HashSet; use acvm::acir::AcirField; @@ -19,7 +20,7 @@ use crate::ssa::{ cfg::ControlFlowGraph, function::{Function, RuntimeType}, instruction::{Instruction, TerminatorInstruction}, - value::Value, + value::{Value, ValueId}, }, ssa_gen::Ssa, }; @@ -49,6 +50,7 @@ impl Function { /// be inlined into their predecessor. pub(crate) fn simplify_function(&mut self) { let mut cfg = ControlFlowGraph::with_function(self); + let mut values_to_replace = HashMap::default(); let mut stack = vec![self.entry_block()]; let mut visited = HashSet::new(); @@ -57,6 +59,10 @@ impl Function { stack.extend(self.dfg[block].successors().filter(|block| !visited.contains(block))); } + if !values_to_replace.is_empty() { + self.dfg.replace_values_in_block_instructions(block, &values_to_replace); + } + check_for_negated_jmpif_condition(self, block, &mut cfg); // This call is before try_inline_into_predecessor so that if it succeeds in changing a @@ -70,7 +76,7 @@ impl Function { drop(predecessors); // If the block has only 1 predecessor, we can safely remove its block parameters - remove_block_parameters(self, block, predecessor); + remove_block_parameters(self, block, predecessor, &mut values_to_replace); // Note: this function relies on `remove_block_parameters` being called first. // Otherwise the inlined block will refer to parameters that no longer exist. @@ -84,6 +90,18 @@ impl Function { check_for_double_jmp(self, block, &mut cfg); } + + if !values_to_replace.is_empty() { + self.dfg.replace_values_in_block_terminator(block, &values_to_replace); + } + } + + if !values_to_replace.is_empty() { + // Values from previous blocks might need to be replaced + for block in self.reachable_blocks() { + self.dfg.replace_values_in_block(block, &values_to_replace); + } + self.dfg.data_bus.replace_values(&values_to_replace); } } } @@ -246,6 +264,7 @@ fn remove_block_parameters( function: &mut Function, block: BasicBlockId, predecessor: BasicBlockId, + values_to_replace: &mut HashMap, ) { let block = &mut function.dfg[block]; @@ -264,7 +283,7 @@ fn remove_block_parameters( assert_eq!(block_params.len(), jump_args.len()); for (param, arg) in block_params.iter().zip(jump_args) { - function.dfg.set_value_from_id(*param, arg); + values_to_replace.insert(*param, arg); } } } @@ -296,128 +315,53 @@ fn try_inline_into_predecessor( mod test { use crate::{ assert_ssa_snapshot, - ssa::{ - Ssa, - function_builder::FunctionBuilder, - ir::{ - instruction::{BinaryOp, TerminatorInstruction}, - map::Id, - types::Type, - }, - opt::assert_normalized_ssa_equals, - }, + ssa::{Ssa, opt::assert_normalized_ssa_equals}, }; - use acvm::acir::AcirField; #[test] fn inline_blocks() { - // fn main { - // b0(): - // jmp b1(Field 7) - // b1(v0: Field): - // jmp b2(v0) - // b2(v1: Field): - // return v1 - // } - let main_id = Id::test_new(0); - let mut builder = FunctionBuilder::new("main".into(), main_id); - - let b1 = builder.insert_block(); - let b2 = builder.insert_block(); - - let v0 = builder.add_block_parameter(b1, Type::field()); - let v1 = builder.add_block_parameter(b2, Type::field()); - - let expected_return = 7u128; - let seven = builder.field_constant(expected_return); - builder.terminate_with_jmp(b1, vec![seven]); - - builder.switch_to_block(b1); - builder.terminate_with_jmp(b2, vec![v0]); - - builder.switch_to_block(b2); - builder.terminate_with_return(vec![v1]); - - let ssa = builder.finish(); - assert_eq!(ssa.main().reachable_blocks().len(), 3); + let src = " + acir(inline) fn main f0 { + b0(): + jmp b1(Field 7) + b1(v0: Field): + jmp b2(v0) + b2(v1: Field): + return v1 + } + "; + let ssa = Ssa::from_str(src).unwrap(); - // Expected output: - // fn main { - // b0(): - // return Field 7 - // } let ssa = ssa.simplify_cfg(); - let main = ssa.main(); - assert_eq!(main.reachable_blocks().len(), 1); - - match main.dfg[main.entry_block()].terminator() { - Some(TerminatorInstruction::Return { return_values, .. }) => { - assert_eq!(return_values.len(), 1); - let return_value = main - .dfg - .get_numeric_constant(return_values[0]) - .expect("Expected return value to be constant") - .to_u128(); - assert_eq!(return_value, expected_return); - } - other => panic!("Unexpected terminator {other:?}"), + assert_ssa_snapshot!(ssa, @r" + acir(inline) fn main f0 { + b0(): + return Field 7 } + "); } #[test] fn remove_known_jmpif() { - // fn main { - // b0(v0: u1): - // v1 = eq v0, v0 - // jmpif v1, then: b1, else: b2 - // b1(): - // return Field 1 - // b2(): - // return Field 2 - // } - let main_id = Id::test_new(0); - let mut builder = FunctionBuilder::new("main".into(), main_id); - let v0 = builder.add_parameter(Type::bool()); - - let b1 = builder.insert_block(); - let b2 = builder.insert_block(); - - let one = builder.field_constant(1u128); - let two = builder.field_constant(2u128); - - let v1 = builder.insert_binary(v0, BinaryOp::Eq, v0); - builder.terminate_with_jmpif(v1, b1, b2); - - builder.switch_to_block(b1); - builder.terminate_with_return(vec![one]); - - builder.switch_to_block(b2); - builder.terminate_with_return(vec![two]); - - let ssa = builder.finish(); - assert_eq!(ssa.main().reachable_blocks().len(), 3); + let src = " + acir(inline) fn main f0 { + b0(v0: u1): + jmpif u1 1 then: b1, else: b2 + b1(): + return Field 1 + b2(): + return Field 2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); - // Expected output: - // fn main { - // b0(): - // return Field 1 - // } let ssa = ssa.simplify_cfg(); - let main = ssa.main(); - assert_eq!(main.reachable_blocks().len(), 1); - - match main.dfg[main.entry_block()].terminator() { - Some(TerminatorInstruction::Return { return_values, .. }) => { - assert_eq!(return_values.len(), 1); - let return_value = main - .dfg - .get_numeric_constant(return_values[0]) - .expect("Expected return value to be constant") - .to_u128(); - assert_eq!(return_value, 1u128); - } - other => panic!("Unexpected terminator {other:?}"), + assert_ssa_snapshot!(ssa, @r" + acir(inline) fn main f0 { + b0(v0: u1): + return Field 1 } + "); } #[test]