Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 165 additions & 1 deletion compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
//!
//! After this pass all first-class functions are replaced with numeric IDs
//! and calls are routed via the newly generated `apply` functions.
use std::collections::{BTreeMap, BTreeSet};
use std::{
collections::{BTreeMap, BTreeSet},
sync::Arc,
};

use acvm::FieldElement;
use iter_extended::vecmap;
Expand Down Expand Up @@ -235,6 +238,33 @@ fn remove_first_class_functions_in_instruction(
for arg in arguments {
*arg = map_value(*arg);
}
} else if let Instruction::MakeArray { typ, .. } = instruction {
match typ {
Type::Array(element_types, len) => {
let new_element_types =
element_types
.iter()
.map(|typ| {
if matches!(typ, Type::Function) { Type::field() } else { typ.clone() }
})
.collect::<Vec<_>>();
*typ = Type::Array(Arc::new(new_element_types), *len);
}
Type::Slice(element_types) => {
let new_element_types =
element_types
.iter()
.map(|typ| {
if matches!(typ, Type::Function) { Type::field() } else { typ.clone() }
})
.collect::<Vec<_>>();
*typ = Type::Slice(Arc::new(new_element_types));
}
_ => {}
}
instruction.map_values_mut(map_value);

modified = true;
} else {
instruction.map_values_mut(map_value);
}
Expand Down Expand Up @@ -328,6 +358,9 @@ fn find_functions_as_values(func: &Function) -> BTreeSet<FunctionId> {
Instruction::Store { value, .. } => {
process_value(*value);
}
Instruction::MakeArray { elements, .. } => {
elements.iter().for_each(|element| process_value(*element));
}
_ => continue,
};
}
Expand Down Expand Up @@ -1024,4 +1057,135 @@ mod tests {
}
");
}

#[test]
fn fn_in_array() {
let src = r#"
acir(inline) fn main f0 {
b0(v0: u32):
v5 = make_array [f1, f2, f3, f4] : [function; 4]
v7 = lt v0, u32 4
constrain v7 == u1 1, "Index out of bounds"
v9 = array_get v5, index v0 -> function
call v9()
return
}
acir(inline) fn lambda f1 {
b0():
return
}
acir(inline) fn lambda f2 {
b0():
return
}
acir(inline) fn lambda f3 {
b0():
return
}
acir(inline) fn lambda f4 {
b0():
return
}
"#;

let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.defunctionalize();

assert_ssa_snapshot!(ssa, @r#"
acir(inline) fn main f0 {
b0(v0: u32):
v5 = make_array [Field 1, Field 2, Field 3, Field 4] : [Field; 4]
v7 = lt v0, u32 4
constrain v7 == u1 1, "Index out of bounds"
v9 = array_get v5, index v0 -> Field
call f5(v9)
return
}
acir(inline) fn lambda f1 {
b0():
return
}
acir(inline) fn lambda f2 {
b0():
return
}
acir(inline) fn lambda f3 {
b0():
return
}
acir(inline) fn lambda f4 {
b0():
return
}
acir(inline_always) fn apply f5 {
b0(v0: Field):
v2 = eq v0, Field 1
jmpif v2 then: b3, else: b2
b1():
return
b2():
v5 = eq v0, Field 2
jmpif v5 then: b6, else: b5
b3():
call f1()
jmp b4()
b4():
jmp b14()
b5():
v8 = eq v0, Field 3
jmpif v8 then: b9, else: b8
b6():
call f2()
jmp b7()
b7():
jmp b13()
b8():
constrain v0 == Field 4
call f4()
jmp b11()
b9():
call f3()
jmp b10()
b10():
jmp b12()
b11():
jmp b12()
b12():
jmp b13()
b13():
jmp b14()
b14():
jmp b1()
}
"#);
}

#[test]
fn empty_make_array_updates_type() {
let src = r#"
acir(inline) fn main f0 {
b0(v0: u32):
v1 = make_array [] : [function; 0]
constrain u1 0 == u1 1, "Index out of bounds"
v5 = array_get v1, index u32 0 -> function
call v5()
return
}
"#;

let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.defunctionalize();

// Guarantee that we still accurately modify the make_array instruction type for an empty array
assert_ssa_snapshot!(ssa, @r#"
acir(inline) fn main f0 {
b0(v0: u32):
v1 = make_array [] : [Field; 0]
constrain u1 0 == u1 1, "Index out of bounds"
v5 = array_get v1, index u32 0 -> Field
call v5()
return
}
"#);
}
}
36 changes: 35 additions & 1 deletion compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,12 @@ fn used_functions(func: &Function) -> BTreeSet<FunctionId> {
for instruction_id in block.instructions() {
let instruction = &func.dfg[*instruction_id];

if matches!(instruction, Instruction::Store { .. } | Instruction::Call { .. }) {
if matches!(
instruction,
Instruction::Store { .. }
| Instruction::Call { .. }
| Instruction::MakeArray { .. }
) {
instruction.for_each_value(&mut find_functions);
}
}
Expand Down Expand Up @@ -177,4 +182,33 @@ mod tests {
// It should not remove anything.
assert_normalized_ssa_equals(ssa, src);
}

#[test]
fn keep_functions_used_in_array() {
// f1 and f2 are used within an array. Thus, we do not want to remove them.
let src = r#"
acir(inline) fn main f0 {
b0(v0: u32):
v5 = make_array [f1, f2] : [function; 2]
v7 = lt v0, u32 4
constrain v7 == u1 1, "Index out of bounds"
v9 = array_get v5, index v0 -> function
call v9()
return
}
acir(inline) fn lambda f1 {
b0():
return
}
acir(inline) fn lambda f2 {
b0():
return
}
"#;

let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.remove_unreachable_functions();

assert_normalized_ssa_equals(ssa, src);
}
}
6 changes: 6 additions & 0 deletions test_programs/execution_success/lambda_from_array/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "lambda_from_array"
type = "bin"
authors = [""]

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
x = 1
70 changes: 70 additions & 0 deletions test_programs/execution_success/lambda_from_array/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Many parts of the code here are regressions from issue #5503 (https://github.com/noir-lang/noir/issues/5503)
fn main(x: u32) {
lambdas_in_array_literal(x - 1);
lambdas_in_array_literal(x);
lambdas_in_array_literal(x + 2);
lambdas_in_array_literal(x + 1);

lambdas_in_slice_literal(x - 1);
lambdas_in_slice_literal(x);
lambdas_in_slice_literal(x + 1);
lambdas_in_slice_literal(x + 2);

functions_in_array_literal(x - 1);
functions_in_array_literal(x);
functions_in_slice_literal(x - 1);
functions_in_slice_literal(x);

let example_lambda: fn(u8) -> u8 = |x| x + 1;
let lambdas: [fn(u8) -> u8; 8] = [example_lambda; 8];
println(lambdas[0](5));
// Dynamic dispatch
println(lambdas[x - 1](5));

let lambdas: [fn(()) -> (); 1] = [|_: ()| {}];
lambdas[0](());
lambdas[x - 1](());

// Also check against slices
let lambdas: [fn(()) -> ()] = &[|_: ()| {}];
lambdas[0](());
lambdas[x - 1](());

// Still panics when there are no other lambdas
// This should fail either way as we are attempting to access an empty array at zero
// let lambdas: [fn(()) -> (); 0] = [];
// lambdas[0](());
}

fn lambdas_in_array_literal(x: u32) {
let xs = [|| println("hi"), || println("bye"), || println("wow"), || println("big")];
(xs[x])();
}

fn lambdas_in_slice_literal(x: u32) {
let xs = &[|| println("hi"), || println("bye"), || println("big"), || println("wow")];
(xs[x])();
}

fn functions_in_array_literal(x: u32) {
let xs = [foo, bar];
(xs[x])();
}

fn functions_in_slice_literal(x: u32) {
let xs = &[baz, qux];
(xs[x])();
}

fn foo() {
println("hi");
}
fn bar() {
println("bye");
}
fn baz() {
println("hi");
}
fn qux() {
println("bye");
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading