Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 134 additions & 20 deletions compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,9 @@ impl DefunctionalizationContext {
// Temporarily take the parameters here just to avoid cloning them
let parameters = block.take_parameters();
for parameter in &parameters {
if func.dfg.type_of_value(*parameter) == Type::Function {
func.dfg.set_type_of_value(*parameter, Type::field());
let typ = &func.dfg.type_of_value(*parameter);
if is_function_type(typ) {
func.dfg.set_type_of_value(*parameter, replacement_type(typ));
}
}

Expand Down Expand Up @@ -158,8 +159,9 @@ impl DefunctionalizationContext {

#[allow(clippy::unnecessary_to_owned)] // clippy is wrong here
for result in func.dfg.instruction_results(instruction_id).to_vec() {
if func.dfg.type_of_value(result) == Type::Function {
func.dfg.set_type_of_value(result, Type::field());
let typ = &func.dfg.type_of_value(result);
if is_function_type(typ) {
func.dfg.set_type_of_value(result, replacement_type(typ));
}
}

Expand Down Expand Up @@ -275,7 +277,8 @@ fn remove_first_class_functions_in_instruction(
/// Try to map the given function literal to a field, returning Some(field) on success.
/// Returns none if the given value was not a function or doesn't need to be mapped.
fn map_function_to_field(func: &mut Function, value: ValueId) -> Option<ValueId> {
if let Type::Function = func.dfg[value].get_type().as_ref() {
let typ = func.dfg[value].get_type();
if is_function_type(typ.as_ref()) {
match &func.dfg[value] {
// If the value is a static function, transform it to the function id
Value::Function(id) => {
Expand All @@ -284,7 +287,7 @@ fn map_function_to_field(func: &mut Function, value: ValueId) -> Option<ValueId>
}
// If the value is a function used as value, just change the type of it
Value::Instruction { .. } | Value::Param { .. } => {
func.dfg.set_type_of_value(value, Type::field());
func.dfg.set_type_of_value(value, replacement_type(typ.as_ref()));
}
_ => (),
}
Expand Down Expand Up @@ -426,15 +429,15 @@ fn create_apply_functions(ssa: &mut Ssa, variants_map: Variants) -> ApplyFunctio
// Update the shared function signature of the higher-order function variants
// to replace any function passed as a value to a numeric field type.
for param in &mut signature.params {
if *param == Type::Function {
*param = Type::field();
if is_function_type(param) {
*param = replacement_type(param);
}
}

// Update the return value types as we did for the signature parameters above.
for ret in &mut signature.returns {
if *ret == Type::Function {
*ret = Type::field();
if is_function_type(ret) {
*ret = replacement_type(ret);
}
}

Expand Down Expand Up @@ -614,27 +617,38 @@ fn create_apply_function(
/// * All blocks which took function parameters should receive a discriminator instead
#[cfg(debug_assertions)]
fn defunctionalize_post_check(func: &Function) {
fn is_function(typ: &Type) -> bool {
match typ {
Type::Function => true,
Type::Reference(typ) => is_function(typ),
_ => false,
}
}
for block_id in func.reachable_blocks() {
for param in func.dfg[block_id].parameters() {
let value = &func.dfg[*param];
let Value::Param { typ, .. } = value else {
panic!("unexpected parameter value: {value:?}");
};
assert!(
!is_function(typ),
"Blocks are not expected to take function parameters any more."
!is_function_type(typ),
"Blocks are not expected to take function parameters any more. Got '{typ}' in param {param} of block {block_id} in function {} {}",
func.name(),
func.id()
);
}
}
}

fn is_function_type(typ: &Type) -> bool {
match typ {
Type::Function => true,
Type::Reference(typ) => is_function_type(typ),
_ => false,
}
}

fn replacement_type(typ: &Type) -> Type {
if matches!(typ, Type::Reference(_)) {
Type::Reference(Arc::new(Type::field()))
} else {
Type::field()
}
}

#[cfg(test)]
mod tests {
use crate::assert_ssa_snapshot;
Expand Down Expand Up @@ -842,7 +856,7 @@ mod tests {
@r"
acir(inline) fn main f0 {
b0(v0: u1):
v1 = allocate -> &mut function
v1 = allocate -> &mut Field
store Field 1 at v1
jmpif v0 then: b1, else: b2
b1():
Expand Down Expand Up @@ -1188,4 +1202,104 @@ mod tests {
}
"#);
}

#[test]
fn mut_ref_function() {
let src = "
acir(inline) fn main f0 {
b0():
v0 = allocate -> &mut function
store f1 at v0
v3 = call f2(v0) -> u1
return v3
}
acir(inline) fn bar f1 {
b0():
return u1 0
}
acir(inline) fn foo f2 {
b0(v0: &mut function):
v1 = load v0 -> function
v2 = call v1() -> u1
return v2
}
";

let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.defunctionalize();

assert_ssa_snapshot!(ssa, @r"
acir(inline) fn main f0 {
b0():
v0 = allocate -> &mut Field
store Field 1 at v0
v3 = call f2(v0) -> u1
return v3
}
acir(inline) fn bar f1 {
b0():
return u1 0
}
acir(inline) fn foo f2 {
b0(v0: &mut Field):
v1 = load v0 -> Field
v3 = call f1() -> u1
return v3
}
");
}

#[test]
fn mut_ref_function_matching() {
let src = "
brillig(inline) fn add_to_tally_public f0 {
b0():
v4 = allocate -> &mut function
store f2 at v4
v10 = call f10(v4, f33) -> Field
return
}
brillig(inline) fn lambda f2 {
b0():
return Field 1
}
brillig(inline) fn at f10 {
b0(v4: &mut function, v6: function):
v10 = call v6(v4) -> Field
return v10
}
brillig(inline) fn lambda f33 {
b0(v4: &mut function):
v10 = call v4() -> Field
return v10
}
";

let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.defunctionalize();

assert_ssa_snapshot!(ssa, @r"
brillig(inline) fn add_to_tally_public f0 {
b0():
v0 = allocate -> &mut Field
store Field 1 at v0
v4 = call f2(v0, Field 3) -> Field
return
}
brillig(inline) fn lambda f1 {
b0():
return Field 1
}
brillig(inline) fn at f2 {
b0(v0: &mut Field, v1: Field):
v3 = call f3(v0) -> Field
return v3
}
brillig(inline) fn lambda f3 {
b0(v0: &mut Field):
v2 = call f1() -> Field
return v2
}
");
}
}
7 changes: 7 additions & 0 deletions test_programs/execution_success/regression_8662/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "regression_8662"
version = "0.1.0"
type = "bin"
authors = [""]

[dependencies]
2 changes: 2 additions & 0 deletions test_programs/execution_success/regression_8662/Prover.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
c = true
return = true
22 changes: 22 additions & 0 deletions test_programs/execution_success/regression_8662/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
struct Context {
f: fn() -> bool,
}
fn main(c: bool) -> pub bool {
let mut ctx = Context { f: bar };
if c {
ctx.f = qux;
}
foo(&mut ctx)
}

fn foo(ctx: &mut Context) -> bool {
(ctx.f)()
}

fn bar() -> bool {
false
}

fn qux() -> bool {
true
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading