diff --git a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs index c3e27f6a070..0f2c2a14abf 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs @@ -101,12 +101,17 @@ impl Ssa { // Run defunctionalization over all functions in the SSA context.defunctionalize_all(&mut self); + + // Check that we have established the properties expected from this pass. + #[cfg(debug_assertions)] + self.functions.values().for_each(defunctionalize_post_check); + self } } impl DefunctionalizationContext { - /// Defunctionalize all functions in the Ssa + /// Defunctionalize all functions in the SSA fn defunctionalize_all(mut self, ssa: &mut Ssa) { for function in ssa.functions.values_mut() { self.defunctionalize(function); @@ -534,6 +539,31 @@ fn build_return_block(builder: &mut FunctionBuilder, passed_types: &[Type]) -> B return_block } +/// Check post-execution properties: +/// * All blocks which took function parameters should receive a discriminator instead +#[cfg(debug_assertions)] +fn defunctionalize_post_check(func: &Function) { + fn is_function(typ: &Type) -> bool { + match typ { + Type::Function => true, + Type::Reference(typ) => is_function(typ), + _ => false, + } + } + for block_id in func.reachable_blocks() { + for param in func.dfg[block_id].parameters() { + let value = &func.dfg[*param]; + let Value::Param { typ, .. } = value else { + panic!("unexpected parameter value: {value:?}"); + }; + assert!( + !is_function(typ), + "Blocks are not expected to take function parameters any more." + ); + } + } +} + #[cfg(test)] mod tests { use crate::assert_ssa_snapshot; @@ -689,4 +719,89 @@ mod tests { assert!(applies.iter().any(|f| f.runtime().is_acir())); assert!(applies.iter().any(|f| f.runtime().is_brillig())); } + + #[test] + fn apply_created_for_stored_functions() { + let src = " + acir(inline) fn main f0 { + b0(v0: u1): + v1 = allocate -> &mut function + store f1 at v1 + jmpif v0 then: b1, else: b2 + b1(): + store f2 at v1 + jmp b2() + b2(): + v4 = load v1 -> function + v6 = call f3(v4) -> u32 + return v6 + } + acir(inline) fn foo f1 { + b0(): + return u32 1 + } + acir(inline) fn bar f2 { + b0(): + return u32 2 + } + acir(inline) fn caller f3 { + b0(v0: function): + v1 = call v0() -> u32 + v2 = call v0() -> u32 + v3 = add v1, v2 + return v3 + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.defunctionalize(); + + assert_ssa_snapshot!( + ssa, + @r" + acir(inline) fn main f0 { + b0(v0: u1): + v1 = allocate -> &mut function + store Field 1 at v1 + jmpif v0 then: b1, else: b2 + b1(): + store Field 2 at v1 + jmp b2() + b2(): + v4 = load v1 -> Field + v6 = call f3(v4) -> u32 + return v6 + } + acir(inline) fn foo f1 { + b0(): + return u32 1 + } + acir(inline) fn bar f2 { + b0(): + return u32 2 + } + acir(inline) fn caller f3 { + b0(v0: Field): + v2 = call f4(v0) -> u32 + v3 = call f4(v0) -> u32 + v4 = add v2, v3 + return v4 + } + acir(inline_always) fn apply f4 { + b0(v0: Field): + v3 = eq v0, Field 1 + jmpif v3 then: b3, else: b2 + b1(v1: u32): + return v1 + b2(): + constrain v0 == Field 2 + v8 = call f2() -> u32 + jmp b1(v8) + b3(): + v5 = call f1() -> u32 + jmp b1(v5) + } + " + ); + } }