Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions compiler/noirc_evaluator/src/ssa/ir/dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,30 @@ impl DataFlowGraph {
self.instructions[id] = instruction;
}

/// Replaces values in the given block according to the given HashMap.
pub(crate) fn replace_values_in_block(
&mut self,
block: BasicBlockId,
values_to_replace: &HashMap<ValueId, ValueId>,
) {
self.replace_values_in_block_instructions(block, values_to_replace);
self.replace_values_in_block_terminator(block, values_to_replace);
}

/// Replaces values in the given block instructions according to the given HashMap.
pub(crate) fn replace_values_in_block_instructions(
&mut self,
block: BasicBlockId,
values_to_replace: &HashMap<ValueId, ValueId>,
) {
let instruction_ids = self.blocks[block].take_instructions();
for instruction_id in &instruction_ids {
let instruction = &mut self[*instruction_id];
instruction.replace_values(values_to_replace);
}
*self[block].instructions_mut() = instruction_ids;
}

/// Replaces values in the given block terminator (if it has any) according to the given HashMap.
pub(crate) fn replace_values_in_block_terminator(
&mut self,
Expand Down
166 changes: 55 additions & 111 deletions compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//! only 1 successor then (2) also will be applied.
//!
//! Currently, 1 and 4 are unimplemented.
use fxhash::FxHashMap as HashMap;
use std::collections::HashSet;

use acvm::acir::AcirField;
Expand All @@ -19,7 +20,7 @@ use crate::ssa::{
cfg::ControlFlowGraph,
function::{Function, RuntimeType},
instruction::{Instruction, TerminatorInstruction},
value::Value,
value::{Value, ValueId},
},
ssa_gen::Ssa,
};
Expand Down Expand Up @@ -49,6 +50,7 @@ impl Function {
/// be inlined into their predecessor.
pub(crate) fn simplify_function(&mut self) {
let mut cfg = ControlFlowGraph::with_function(self);
let mut values_to_replace = HashMap::default();
let mut stack = vec![self.entry_block()];
let mut visited = HashSet::new();

Expand All @@ -57,6 +59,10 @@ impl Function {
stack.extend(self.dfg[block].successors().filter(|block| !visited.contains(block)));
}

if !values_to_replace.is_empty() {
self.dfg.replace_values_in_block_instructions(block, &values_to_replace);
}

check_for_negated_jmpif_condition(self, block, &mut cfg);

// This call is before try_inline_into_predecessor so that if it succeeds in changing a
Expand All @@ -70,7 +76,7 @@ impl Function {
drop(predecessors);

// If the block has only 1 predecessor, we can safely remove its block parameters
remove_block_parameters(self, block, predecessor);
remove_block_parameters(self, block, predecessor, &mut values_to_replace);

// Note: this function relies on `remove_block_parameters` being called first.
// Otherwise the inlined block will refer to parameters that no longer exist.
Expand All @@ -84,6 +90,18 @@ impl Function {

check_for_double_jmp(self, block, &mut cfg);
}

if !values_to_replace.is_empty() {
self.dfg.replace_values_in_block_terminator(block, &values_to_replace);
}
}

if !values_to_replace.is_empty() {
// Values from previous blocks might need to be replaced
for block in self.reachable_blocks() {
self.dfg.replace_values_in_block(block, &values_to_replace);
}
self.dfg.data_bus.replace_values(&values_to_replace);
}
}
}
Expand Down Expand Up @@ -246,6 +264,7 @@ fn remove_block_parameters(
function: &mut Function,
block: BasicBlockId,
predecessor: BasicBlockId,
values_to_replace: &mut HashMap<ValueId, ValueId>,
) {
let block = &mut function.dfg[block];

Expand All @@ -264,7 +283,7 @@ fn remove_block_parameters(

assert_eq!(block_params.len(), jump_args.len());
for (param, arg) in block_params.iter().zip(jump_args) {
function.dfg.set_value_from_id(*param, arg);
values_to_replace.insert(*param, arg);
}
}
}
Expand Down Expand Up @@ -296,128 +315,53 @@ fn try_inline_into_predecessor(
mod test {
use crate::{
assert_ssa_snapshot,
ssa::{
Ssa,
function_builder::FunctionBuilder,
ir::{
instruction::{BinaryOp, TerminatorInstruction},
map::Id,
types::Type,
},
opt::assert_normalized_ssa_equals,
},
ssa::{Ssa, opt::assert_normalized_ssa_equals},
};
use acvm::acir::AcirField;

#[test]
fn inline_blocks() {
// fn main {
// b0():
// jmp b1(Field 7)
// b1(v0: Field):
// jmp b2(v0)
// b2(v1: Field):
// return v1
// }
let main_id = Id::test_new(0);
let mut builder = FunctionBuilder::new("main".into(), main_id);

let b1 = builder.insert_block();
let b2 = builder.insert_block();

let v0 = builder.add_block_parameter(b1, Type::field());
let v1 = builder.add_block_parameter(b2, Type::field());

let expected_return = 7u128;
let seven = builder.field_constant(expected_return);
builder.terminate_with_jmp(b1, vec![seven]);

builder.switch_to_block(b1);
builder.terminate_with_jmp(b2, vec![v0]);

builder.switch_to_block(b2);
builder.terminate_with_return(vec![v1]);

let ssa = builder.finish();
assert_eq!(ssa.main().reachable_blocks().len(), 3);
let src = "
acir(inline) fn main f0 {
b0():
jmp b1(Field 7)
b1(v0: Field):
jmp b2(v0)
b2(v1: Field):
return v1
}
";
let ssa = Ssa::from_str(src).unwrap();

// Expected output:
// fn main {
// b0():
// return Field 7
// }
let ssa = ssa.simplify_cfg();
let main = ssa.main();
assert_eq!(main.reachable_blocks().len(), 1);

match main.dfg[main.entry_block()].terminator() {
Some(TerminatorInstruction::Return { return_values, .. }) => {
assert_eq!(return_values.len(), 1);
let return_value = main
.dfg
.get_numeric_constant(return_values[0])
.expect("Expected return value to be constant")
.to_u128();
assert_eq!(return_value, expected_return);
}
other => panic!("Unexpected terminator {other:?}"),
assert_ssa_snapshot!(ssa, @r"
acir(inline) fn main f0 {
b0():
return Field 7
}
");
}

#[test]
fn remove_known_jmpif() {
// fn main {
// b0(v0: u1):
// v1 = eq v0, v0
// jmpif v1, then: b1, else: b2
// b1():
// return Field 1
// b2():
// return Field 2
// }
let main_id = Id::test_new(0);
let mut builder = FunctionBuilder::new("main".into(), main_id);
let v0 = builder.add_parameter(Type::bool());

let b1 = builder.insert_block();
let b2 = builder.insert_block();

let one = builder.field_constant(1u128);
let two = builder.field_constant(2u128);

let v1 = builder.insert_binary(v0, BinaryOp::Eq, v0);
builder.terminate_with_jmpif(v1, b1, b2);

builder.switch_to_block(b1);
builder.terminate_with_return(vec![one]);

builder.switch_to_block(b2);
builder.terminate_with_return(vec![two]);

let ssa = builder.finish();
assert_eq!(ssa.main().reachable_blocks().len(), 3);
let src = "
acir(inline) fn main f0 {
b0(v0: u1):
jmpif u1 1 then: b1, else: b2
b1():
return Field 1
b2():
return Field 2
}
";
let ssa = Ssa::from_str(src).unwrap();

// Expected output:
// fn main {
// b0():
// return Field 1
// }
let ssa = ssa.simplify_cfg();
let main = ssa.main();
assert_eq!(main.reachable_blocks().len(), 1);

match main.dfg[main.entry_block()].terminator() {
Some(TerminatorInstruction::Return { return_values, .. }) => {
assert_eq!(return_values.len(), 1);
let return_value = main
.dfg
.get_numeric_constant(return_values[0])
.expect("Expected return value to be constant")
.to_u128();
assert_eq!(return_value, 1u128);
}
other => panic!("Unexpected terminator {other:?}"),
assert_ssa_snapshot!(ssa, @r"
acir(inline) fn main f0 {
b0(v0: u1):
return Field 1
}
");
}

#[test]
Expand Down
Loading