Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions compiler/noirc_evaluator/src/ssa/ir/dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,15 +424,19 @@ impl DataFlowGraph {
if let Some(existing) = self.functions.get(&function) {
return *existing;
}
self.values.insert(Value::Function(function))
let result = self.values.insert(Value::Function(function));
self.functions.insert(function, result);
result
}

/// Gets or creates a ValueId for the given FunctionId.
pub(crate) fn import_foreign_function(&mut self, function: &str) -> ValueId {
if let Some(existing) = self.foreign_functions.get(function) {
return *existing;
}
self.values.insert(Value::ForeignFunction(function.to_owned()))
let result = self.values.insert(Value::ForeignFunction(function.to_owned()));
self.foreign_functions.insert(function.to_owned(), result);
result
}

/// Gets or creates a ValueId for the given Intrinsic.
Expand Down
157 changes: 111 additions & 46 deletions compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//! with a non-literal target can be replaced with a call to an apply function.
//! The apply function is a dispatch function that takes the function id as a parameter
//! and dispatches to the correct target.
use std::collections::{BTreeMap, BTreeSet, HashSet};
use std::collections::{BTreeMap, BTreeSet};

use acvm::FieldElement;
use iter_extended::vecmap;
Expand Down Expand Up @@ -80,21 +80,54 @@ impl DefunctionalizationContext {

/// Defunctionalize a single function
fn defunctionalize(&mut self, func: &mut Function) {
let mut call_target_values = HashSet::new();

for block_id in func.reachable_blocks() {
let block = &func.dfg[block_id];
let instructions = block.instructions().to_vec();
let block = &mut func.dfg[block_id];

// Temporarily take the parameters here just to avoid cloning them
let parameters = block.take_parameters();
for parameter in &parameters {
if func.dfg.type_of_value(*parameter) == Type::Function {
func.dfg.set_type_of_value(*parameter, Type::field());
}
}

let block = &mut func.dfg[block_id];
block.set_parameters(parameters);

// Do the same for the terminator
let mut terminator = block.take_terminator();
terminator.map_values_mut(|value| map_function_to_field(func, value).unwrap_or(value));

let block = &mut func.dfg[block_id];
block.set_terminator(terminator);

for instruction_id in instructions {
let instruction = func.dfg[instruction_id].clone();
// Now we can finally change each instruction, replacing
// each first class function with a field value and replacing calls
// to a first class function to a call to the relevant `apply` function.
#[allow(clippy::unnecessary_to_owned)] // clippy is wrong here
for instruction_id in block.instructions().to_vec() {
let mut instruction = func.dfg[instruction_id].clone();
let mut replacement_instruction = None;

if remove_first_class_functions_in_instruction(func, &mut instruction) {
func.dfg[instruction_id] = instruction.clone();
}

#[allow(clippy::unnecessary_to_owned)] // clippy is wrong here
for result in func.dfg.instruction_results(instruction_id).to_vec() {
if func.dfg.type_of_value(result) == Type::Function {
func.dfg.set_type_of_value(result, Type::field());
}
}

// Operate on call instructions
let (target_func_id, arguments) = match &instruction {
Instruction::Call { func: target_func_id, arguments } => {
(*target_func_id, arguments)
}
_ => continue,
_ => {
continue;
}
};

match func.dfg[target_func_id] {
Expand All @@ -116,43 +149,15 @@ impl DefunctionalizationContext {
arguments.insert(0, target_func_id);
}
let func = apply_function_value_id;
call_target_values.insert(func);

replacement_instruction = Some(Instruction::Call { func, arguments });
}
Value::Function(..) => {
call_target_values.insert(target_func_id);
}
_ => {}
}
if let Some(new_instruction) = replacement_instruction {
func.dfg[instruction_id] = new_instruction;
}
}
}

// Change the type of all the values that are not call targets to NativeField
let value_ids = vecmap(func.dfg.values_iter(), |(id, _)| id);
for value_id in value_ids {
if let Type::Function = func.dfg[value_id].get_type().as_ref() {
match &func.dfg[value_id] {
// If the value is a static function, transform it to the function id
Value::Function(id) => {
if !call_target_values.contains(&value_id) {
let field = NumericType::NativeField;
let new_value =
func.dfg.make_constant(function_id_to_field(*id), field);
func.dfg.set_value_from_id(value_id, new_value);
}
}
// If the value is a function used as value, just change the type of it
Value::Instruction { .. } | Value::Param { .. } => {
func.dfg.set_type_of_value(value_id, Type::field());
}
_ => {}
}
}
}
}

/// Returns the apply function for the given signature
Expand All @@ -161,6 +166,54 @@ impl DefunctionalizationContext {
}
}

/// Replace any first class functions used in an instruction with a field value.
/// This applies to any function used anywhere else other than the function position
/// of a call instruction. Returns true if the instruction was modified
fn remove_first_class_functions_in_instruction(
func: &mut Function,
instruction: &mut Instruction,
) -> bool {
let mut modified = false;
let mut map_value = |value: ValueId| {
if let Some(new_value) = map_function_to_field(func, value) {
modified = true;
new_value
} else {
value
}
};

if let Instruction::Call { func: _, arguments } = instruction {
for arg in arguments {
*arg = map_value(*arg);
}
} else {
instruction.map_values_mut(map_value);
}

modified
}

/// Try to map the given function literal to a field, returning Some(field) on success.
/// Returns none if the given value was not a function or doesn't need to be mapped.
fn map_function_to_field(func: &mut Function, value: ValueId) -> Option<ValueId> {
if let Type::Function = func.dfg[value].get_type().as_ref() {
match &func.dfg[value] {
// If the value is a static function, transform it to the function id
Value::Function(id) => {
let new_value = function_id_to_field(*id);
return Some(func.dfg.make_constant(new_value, NumericType::NativeField));
}
// If the value is a function used as value, just change the type of it
Value::Instruction { .. } | Value::Param { .. } => {
func.dfg.set_type_of_value(value, Type::field());
}
_ => (),
}
}
None
}

/// Collects all functions used as values that can be called by their signatures
fn find_variants(ssa: &Ssa) -> Variants {
let mut dynamic_dispatches: BTreeSet<(Signature, RuntimeType)> = BTreeSet::new();
Expand Down Expand Up @@ -252,13 +305,25 @@ fn create_apply_functions(
variants_map: BTreeMap<(Signature, RuntimeType), Vec<FunctionId>>,
) -> ApplyFunctions {
let mut apply_functions = HashMap::default();
for ((signature, runtime), variants) in variants_map.into_iter() {
for ((mut signature, runtime), variants) in variants_map.into_iter() {
assert!(
!variants.is_empty(),
"ICE: at least one variant should exist for a dynamic call {signature:?}"
);
let dispatches_to_multiple_functions = variants.len() > 1;

for param in &mut signature.params {
if *param == Type::Function {
*param = Type::field();
}
}

for ret in &mut signature.returns {
if *ret == Type::Function {
*ret = Type::field();
}
}

let id = if dispatches_to_multiple_functions {
create_apply_function(ssa, signature.clone(), runtime, variants)
} else {
Expand All @@ -282,7 +347,7 @@ fn create_apply_function(
function_ids: Vec<FunctionId>,
) -> FunctionId {
assert!(!function_ids.is_empty());
let globals = ssa.functions[&function_ids[0]].dfg.globals.clone();
let globals = ssa.main().dfg.globals.clone();
ssa.add_fn(|id| {
let mut function_builder = FunctionBuilder::new("apply".to_string(), id);
function_builder.set_globals(globals);
Expand Down Expand Up @@ -386,10 +451,10 @@ mod tests {
v5 = add v0, u32 1
v6 = eq v3, v5
constrain v3 == v5
v9 = call f1(f3, v0) -> u32
v10 = add v0, u32 1
v11 = eq v9, v10
constrain v9 == v10
v8 = call f1(f3, v0) -> u32
v9 = add v0, u32 1
v10 = eq v8, v9
constrain v8 == v9
return
}
brillig(inline) fn wrapper f1 {
Expand Down Expand Up @@ -419,10 +484,10 @@ mod tests {
v5 = add v0, u32 1
v6 = eq v3, v5
constrain v3 == v5
v9 = call f1(Field 3, v0) -> u32
v10 = add v0, u32 1
v11 = eq v9, v10
constrain v9 == v10
v8 = call f1(Field 3, v0) -> u32
v9 = add v0, u32 1
v10 = eq v8, v9
constrain v8 == v9
return
}
brillig(inline) fn wrapper f1 {
Expand Down
8 changes: 4 additions & 4 deletions compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -649,8 +649,8 @@ mod test {
v21 = add v1, v2
v23 = array_set v19, index v21, value Field 128
call f1(v23)
v25 = add v2, u32 1
jmp b1(v25)
v24 = add v2, u32 1
jmp b1(v24)
}
brillig(inline) fn foo f1 {
b0(v0: [Field; 5]):
Expand Down Expand Up @@ -685,8 +685,8 @@ mod test {
v21 = add v1, v2
v23 = array_set v14, index v21, value Field 128
call f1(v23)
v25 = add v2, u32 1
jmp b1(v25)
v24 = add v2, u32 1
jmp b1(v24)
}
brillig(inline) fn foo f1 {
b0(v0: [Field; 5]):
Expand Down
Loading