diff --git a/compiler/noirc_evaluator/src/ssa/ir/printer.rs b/compiler/noirc_evaluator/src/ssa/ir/printer.rs index 59e5730d12e..3c19b322a00 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/printer.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/printer.rs @@ -37,6 +37,9 @@ impl Display for Ssa { Value::Global(_) => { panic!("Value::Global should only be in the function dfg"); } + Value::Function(id) => { + writeln!(f, "{}", id)?; + } _ => panic!("Expected only numeric constant or instruction"), }; } diff --git a/compiler/noirc_evaluator/src/ssa/ir/types.rs b/compiler/noirc_evaluator/src/ssa/ir/types.rs index 5384780e5ff..ba978654346 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/types.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/types.rs @@ -297,6 +297,18 @@ impl Type { } } } + + /// True if this is a function type or if it is a composite type which contains a function. + pub(crate) fn contains_function(&self) -> bool { + match self { + Type::Reference(element_type) => element_type.contains_function(), + Type::Function => true, + Type::Numeric(_) => false, + Type::Array(elements, _) | Type::Slice(elements) => { + elements.iter().any(|elem| elem.contains_function()) + } + } + } } /// Composite Types are essentially flattened struct or tuple types. diff --git a/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs b/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs index db4f6f24720..18c2e8a0a54 100644 --- a/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs +++ b/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs @@ -454,7 +454,14 @@ impl Translator { .globals_function .dfg .make_constant(constant.value, constant.typ.unwrap_numeric()), - ParsedValue::Variable(identifier) => self.lookup_global(identifier)?, + ParsedValue::Variable(identifier) => { + match self.lookup_global(identifier.clone()) { + Ok(global) => global, + Err(lookup_global_err) => self + .lookup_call_function(identifier) + .map_err(|_| lookup_global_err)?, + } + } }; elements.push_back(element_id); } diff --git a/compiler/noirc_evaluator/src/ssa/parser/tests.rs b/compiler/noirc_evaluator/src/ssa/parser/tests.rs index 0ac080991b2..cf68edd31b5 100644 --- a/compiler/noirc_evaluator/src/ssa/parser/tests.rs +++ b/compiler/noirc_evaluator/src/ssa/parser/tests.rs @@ -823,3 +823,52 @@ fn parses_variable_from_a_syntantically_following_block_but_logically_preceding_ "; assert_ssa_roundtrip(src); } + +#[test] +fn function_pointer_in_global_array() { + let src = " + g2 = make_array [f1, f2] : [function; 2] + acir(inline) fn main f0 { + b0(v3: u32, v4: Field): + v6 = call f1() -> Field + v8 = call f2() -> Field + v10 = lt v3, u32 2 + constrain v10 == u1 1 + v12 = array_get g2, index v3 -> function + v13 = call v12() -> Field + v14 = eq v13, v4 + constrain v13 == v4 + return + } + acir(inline) fn f1 f1 { + b0(): + return Field 1 + } + acir(inline) fn f2 f2 { + b0(): + return Field 2 + } + "; + let _ = Ssa::from_str_no_validation(src).unwrap(); +} + +#[test] +#[should_panic(expected = "Unknown global")] +fn unknown_function_global_function_pointer() { + let src = " + g2 = make_array [f1, f2] : [function; 2] + acir(inline) fn main f0 { + b0(v3: u32, v4: Field): + v6 = call f1() -> Field + v8 = call f2() -> Field + v10 = lt v3, u32 2 + constrain v10 == u1 1 + v12 = array_get g2, index v3 -> function + v13 = call v12() -> Field + v14 = eq v13, v4 + constrain v13 == v4 + return + } + "; + let _ = Ssa::from_str_no_validation(src).unwrap(); +} diff --git a/compiler/noirc_evaluator/src/ssa/validation/mod.rs b/compiler/noirc_evaluator/src/ssa/validation/mod.rs index 99d657d6f82..8299cd2c81a 100644 --- a/compiler/noirc_evaluator/src/ssa/validation/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/validation/mod.rs @@ -327,7 +327,18 @@ impl<'f> Validator<'f> { } } + fn type_check_globals(&self) { + let globals = (*self.function.dfg.globals).clone(); + for (_, global) in globals.values_iter() { + let global_typ = global.get_type(); + if global_typ.contains_function() { + panic!("Globals cannot contain function pointers"); + } + } + } + fn run(&mut self) { + self.type_check_globals(); self.validate_single_return_block(); for block in self.function.reachable_blocks() { @@ -367,7 +378,7 @@ mod tests { } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -382,7 +393,7 @@ mod tests { } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -396,7 +407,7 @@ mod tests { } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -410,7 +421,7 @@ mod tests { } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -425,7 +436,7 @@ mod tests { } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -440,7 +451,7 @@ mod tests { } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -454,7 +465,7 @@ mod tests { } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -468,7 +479,7 @@ mod tests { } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -482,7 +493,7 @@ mod tests { } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -496,7 +507,7 @@ mod tests { } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -514,7 +525,7 @@ mod tests { } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -530,7 +541,7 @@ mod tests { return v4 } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -545,7 +556,7 @@ mod tests { return v4 } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -561,7 +572,7 @@ mod tests { return v4 } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -577,7 +588,7 @@ mod tests { return v4 } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -593,7 +604,7 @@ mod tests { return v4 } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -609,7 +620,7 @@ mod tests { return v4 } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -623,7 +634,7 @@ mod tests { } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -638,7 +649,7 @@ mod tests { } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -650,7 +661,7 @@ mod tests { return v1 } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -663,7 +674,7 @@ mod tests { return } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -678,7 +689,7 @@ mod tests { return } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -691,7 +702,7 @@ mod tests { return } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -704,7 +715,7 @@ mod tests { return } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -724,7 +735,7 @@ mod tests { return } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -744,7 +755,7 @@ mod tests { return } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -752,11 +763,11 @@ mod tests { let src = " brillig(inline) predicate_pure fn main f0 { b0(v0: Field): - v1 = call to_le_bytes(v0, u32 256) -> [u8; 1] + v1 = call to_le_radix(v0, u32 256) -> [u8; 1] return } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -768,7 +779,7 @@ mod tests { return } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[should_panic( @@ -783,7 +794,7 @@ mod tests { return } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[should_panic( @@ -798,7 +809,7 @@ mod tests { return } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -810,7 +821,7 @@ mod tests { return v0 } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -823,7 +834,7 @@ mod tests { return v1 } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -836,7 +847,7 @@ mod tests { return v1 } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -849,7 +860,7 @@ mod tests { return v0 } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -863,7 +874,7 @@ mod tests { return v1 } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); } #[test] @@ -877,6 +888,36 @@ mod tests { return v1 } "; - let _ = Ssa::from_str(src); + let _ = Ssa::from_str(src).unwrap(); + } + + #[test] + #[should_panic(expected = "Globals cannot contain function pointers")] + fn function_pointer_in_global_array() { + let src = " + g2 = make_array [f1, f2] : [function; 2] + + acir(inline) fn main f0 { + b0(v3: u32, v4: Field): + v6 = call f1() -> Field + v8 = call f2() -> Field + v10 = lt v3, u32 2 + constrain v10 == u1 1 + v12 = array_get g2, index v3 -> function + v13 = call v12() -> Field + v14 = eq v13, v4 + constrain v13 == v4 + return + } + acir(inline) fn f1 f1 { + b0(): + return Field 1 + } + acir(inline) fn f2 f2 { + b0(): + return Field 2 + } + "; + let _ = Ssa::from_str(src).unwrap(); } } diff --git a/compiler/noirc_frontend/src/hir/comptime/value.rs b/compiler/noirc_frontend/src/hir/comptime/value.rs index 05b53a783ff..e033876096a 100644 --- a/compiler/noirc_frontend/src/hir/comptime/value.rs +++ b/compiler/noirc_frontend/src/hir/comptime/value.rs @@ -586,8 +586,7 @@ impl Value { Value::Array(values, _) => { values.iter().any(|value| value.contains_function_or_closure()) } - Value::Function(..) => true, - Value::Closure(..) => true, + Value::Function(..) | Value::Closure(..) => true, _ => false, } }