diff --git a/compiler/noirc_frontend/src/elaborator/lints.rs b/compiler/noirc_frontend/src/elaborator/lints.rs index 532d81ed50a..346d80789b7 100644 --- a/compiler/noirc_frontend/src/elaborator/lints.rs +++ b/compiler/noirc_frontend/src/elaborator/lints.rs @@ -175,27 +175,6 @@ pub(super) fn oracle_returns_multiple_slices( } } -/// Oracle functions may not be called by constrained functions directly. -/// -/// In order for a constrained function to call an oracle it must first call through an unconstrained function. -pub(super) fn oracle_called_from_constrained_function( - interner: &NodeInterner, - called_func: &FuncId, - calling_from_constrained_runtime: bool, - location: Location, -) -> Option { - if !calling_from_constrained_runtime { - return None; - } - - let function_attributes = interner.function_attributes(called_func); - if function_attributes.function()?.kind.is_oracle() { - Some(ResolverError::UnconstrainedOracleReturnToConstrained { location }) - } else { - None - } -} - /// `pub` is required on return types for entry point functions pub(super) fn missing_pub(func: &FuncMeta, modifiers: &FunctionModifiers) -> Option { if func.is_entry_point diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 7f5b9b947ec..d0f79f449e6 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -2334,23 +2334,6 @@ impl Elaborator<'_> { UnsafeBlockStatus::InUnsafeBlockWithUnconstrainedCalls => (), } - // Check whether we are trying to call an oracle directly from ACIR. - // Indirect calls (going through some variable) are okay because we - // already wrap them into proxy functions. Eventually we will wrap - // everything, and then we won't need this lint any more. - if let Some(called_func_id) = self.interner.lookup_function_from_expr(&call.func, false) - { - self.run_lint(|elaborator| { - lints::oracle_called_from_constrained_function( - elaborator.interner, - &called_func_id, - is_current_func_constrained, - location, - ) - .map(Into::into) - }); - } - let errors = lints::unconstrained_function_args(&args); self.push_errors(errors); } @@ -2368,7 +2351,7 @@ impl Elaborator<'_> { /// Check if the callee is an unconstrained function, or a variable referring to one. fn is_unconstrained_call(&self, expr: ExprId) -> bool { - if let Some(func_id) = self.interner.lookup_function_from_expr(&expr, true) { + if let Some(func_id) = self.interner.lookup_function_from_expr(&expr) { let modifiers = self.interner.function_modifiers(&func_id); modifiers.is_unconstrained } else { diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index 95d8bfc17b3..fb842110a92 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -1408,7 +1408,7 @@ fn typed_expr_as_function_definition( let self_argument = check_one_argument(arguments, location)?; let typed_expr = get_typed_expr(self_argument)?; let option_value = if let TypedExpr::ExprId(expr_id) = typed_expr { - let func_id = interner.lookup_function_from_expr(&expr_id, true); + let func_id = interner.lookup_function_from_expr(&expr_id); func_id.map(Value::FunctionDefinition) } else { None diff --git a/compiler/noirc_frontend/src/hir/resolution/errors.rs b/compiler/noirc_frontend/src/hir/resolution/errors.rs index 125b5ced631..dd8c3e54da1 100644 --- a/compiler/noirc_frontend/src/hir/resolution/errors.rs +++ b/compiler/noirc_frontend/src/hir/resolution/errors.rs @@ -74,8 +74,6 @@ pub enum ResolverError { OracleMarkedAsConstrained { ident: Ident, location: Location }, #[error("Oracle functions cannot return multiple slices")] OracleReturnsMultipleSlices { location: Location }, - #[error("Oracle functions cannot be called directly from constrained functions")] - UnconstrainedOracleReturnToConstrained { location: Location }, #[error("Dependency cycle found, '{item}' recursively depends on itself: {cycle} ")] DependencyCycle { location: Location, item: String, cycle: String }, #[error("break/continue are only allowed in unconstrained functions")] @@ -224,7 +222,6 @@ impl ResolverError { | ResolverError::InvalidClosureEnvironment { location, .. } | ResolverError::NestedSlices { location } | ResolverError::AbiAttributeOutsideContract { location } - | ResolverError::UnconstrainedOracleReturnToConstrained { location } | ResolverError::DependencyCycle { location, .. } | ResolverError::JumpInConstrainedFn { location, .. } | ResolverError::LoopInConstrainedFn { location } @@ -487,11 +484,6 @@ impl<'a> From<&'a ResolverError> for Diagnostic { *location, ) }, - ResolverError::UnconstrainedOracleReturnToConstrained { location } => Diagnostic::simple_error( - error.to_string(), - "This oracle call must be wrapped in a call to another unconstrained function before being returned to a constrained runtime".into(), - *location, - ), ResolverError::DependencyCycle { location, item, cycle } => { Diagnostic::simple_error( "Dependency cycle found".into(), diff --git a/compiler/noirc_frontend/src/monomorphization/proxies.rs b/compiler/noirc_frontend/src/monomorphization/proxies.rs index 2d116e2d4cf..52cf4719eb6 100644 --- a/compiler/noirc_frontend/src/monomorphization/proxies.rs +++ b/compiler/noirc_frontend/src/monomorphization/proxies.rs @@ -11,6 +11,10 @@ //! without actually being the target of a [`Call`](crate::monomorphization::ast::Expression::Call), //! and replace them with a normal function, which will preserve the information we need create //! dispatch functions for them in the `defunctionalize` pass. +//! +//! The pass also automatically wraps direct calls to oracle functions from constrained functions, +//! which, after creating wrapper for function values, would only present an inconvenience for users +//! if they have to keep creating wrappers themselves. use std::collections::HashMap; @@ -40,6 +44,7 @@ impl Program { // Replace foreign function identifier definitions with proxy function IDs. for function in self.functions.iter_mut() { + context.in_unconstrained = function.unconstrained; context.visit_expr(&mut function.body); } @@ -52,13 +57,19 @@ impl Program { struct ProxyContext { next_func_id: u32, + in_unconstrained: bool, replacements: HashMap<(Definition, /*unconstrained*/ bool), FuncId>, proxies: Vec<(FuncId, (Ident, /*unconstrained*/ bool))>, } impl ProxyContext { fn new(next_func_id: u32) -> Self { - Self { next_func_id, replacements: HashMap::new(), proxies: Vec::new() } + Self { + next_func_id, + in_unconstrained: false, + replacements: HashMap::new(), + proxies: Vec::new(), + } } fn next_func_id(&mut self) -> FuncId { @@ -73,8 +84,24 @@ impl ProxyContext { fn visit_expr(&mut self, expr: &mut Expression) { visit_expr_mut(expr, &mut |expr| { // Note that if we see a function in `Call::func` then it will be an `Ident`, not a `Tuple`, - // even though its `Ident::typ` will be a `Tuple([Function, Function])`, but since we only - // handle tuples, we don't have to skip the `Call::func` to leave it in tact. + // even though its `Ident::typ` will be a `Tuple([Function, Function])`. + + // If this is a direct from ACIR to an Oracle, we want to create a proxy. + if !self.in_unconstrained { + if let Expression::Call(Call { func, arguments, return_type: _, location: _ }) = + expr + { + if let Expression::Ident(ident) = func.as_mut() { + if matches!(ident.definition, Definition::Oracle(_)) { + self.redirect_to_proxy(ident, true); + for arg in arguments { + self.visit_expr(arg); + } + return false; + } + } + } + } // If this is a foreign function value, we want to replace it with proxies. let Some(mut pair) = ForeignFunctionValue::try_from(expr) else { @@ -82,30 +109,36 @@ impl ProxyContext { }; // Create a separate proxy for the constrained and unconstrained version. - pair.for_each(|ident, mut unconstrained| { - // If we are calling an oracle, there is no reason to create an unconstrained proxy, - // since such a call would be rejected by the SSA validation. - unconstrained |= matches!(ident.definition, Definition::Oracle(_)); - - let key = (ident.definition.clone(), unconstrained); - - let proxy_id = match self.replacements.get(&key) { - Some(id) => *id, - None => { - let func_id = self.next_func_id(); - self.replacements.insert(key, func_id); - self.proxies.push((func_id, (ident.clone(), unconstrained))); - func_id - } - }; - - ident.definition = Definition::Function(proxy_id); + pair.for_each(|ident, unconstrained| { + self.redirect_to_proxy(ident, unconstrained); }); true }); } + /// Get or create a replacement proxy for the function definition in the [Ident], + /// and replace the definition with the ID of the new global proxy function. + fn redirect_to_proxy(&mut self, ident: &mut Ident, mut unconstrained: bool) { + // If we are calling an oracle, there is no reason to create an unconstrained proxy, + // since such a call would be rejected by the SSA validation. + unconstrained |= matches!(ident.definition, Definition::Oracle(_)); + + let key = (ident.definition.clone(), unconstrained); + + let proxy_id = match self.replacements.get(&key) { + Some(id) => *id, + None => { + let func_id = self.next_func_id(); + self.replacements.insert(key, func_id); + self.proxies.push((func_id, (ident.clone(), unconstrained))); + func_id + } + }; + + ident.definition = Definition::Function(proxy_id); + } + /// Create proxy functions for the foreign function values we discovered. fn into_proxies(self) -> impl IntoIterator { self.proxies @@ -259,7 +292,36 @@ mod tests { }; #[test] - fn creates_proxies_for_oracle() { + fn creates_proxies_for_acir_to_oracle_calls() { + let src = " + fn main() { + // safety: still needed as the bar_proxy is unconstrained + unsafe { + bar(0); + } + } + + #[oracle(my_oracle)] + unconstrained fn bar(f: Field) { + } + "; + + let program = get_monomorphized_no_emit_test(src).unwrap(); + insta::assert_snapshot!(program, @r" + fn main$f0() -> () { + { + bar$f1(0); + } + } + #[inline_always] + unconstrained fn bar_proxy$f1(p0$l0: Field) -> () { + bar$my_oracle(p0$l0) + } + "); + } + + #[test] + fn creates_proxies_for_oracle_values() { let src = " unconstrained fn main() { foo(bar); @@ -290,7 +352,7 @@ mod tests { } #[test] - fn creates_proxies_for_builtin() { + fn creates_proxies_for_builtin_values() { let src = " unconstrained fn main() { foo(bar); diff --git a/compiler/noirc_frontend/src/node_interner/function.rs b/compiler/noirc_frontend/src/node_interner/function.rs index e102d68c30b..be6251b2e62 100644 --- a/compiler/noirc_frontend/src/node_interner/function.rs +++ b/compiler/noirc_frontend/src/node_interner/function.rs @@ -112,29 +112,24 @@ impl NodeInterner { /// Returns the [`FuncId`] corresponding to the function referred to by `expr_id`, /// _iff_ the expression is an [HirExpression::Ident] with a `Function` definition, - /// or if `follow_redirects` is `true`, then an immutable `Local` or `Global`, - /// ultimately pointing at a `Function`. + /// or an immutable `Local` or `Global` definition which ultimately points at a `Function`. /// /// Returns `None` for all other cases (tuples, array, mutable variables, etc.). - pub(crate) fn lookup_function_from_expr( - &self, - expr: &ExprId, - follow_redirects: bool, - ) -> Option { + pub(crate) fn lookup_function_from_expr(&self, expr: &ExprId) -> Option { if let HirExpression::Ident(HirIdent { id, .. }, _) = self.expression(expr) { match self.try_definition(id).map(|def| &def.kind) { Some(DefinitionKind::Function(func_id)) => Some(*func_id), - Some(DefinitionKind::Local(Some(expr_id))) if follow_redirects => { - self.lookup_function_from_expr(expr_id, follow_redirects) + Some(DefinitionKind::Local(Some(expr_id))) => { + self.lookup_function_from_expr(expr_id) } - Some(DefinitionKind::Global(global_id)) if follow_redirects => { + Some(DefinitionKind::Global(global_id)) => { let info = self.get_global(*global_id); let HirStatement::Let(HirLetStatement { expression, .. }) = self.statement(&info.let_statement) else { unreachable!("global refers to a let statement"); }; - self.lookup_function_from_expr(&expression, follow_redirects) + self.lookup_function_from_expr(&expression) } _ => None, } diff --git a/compiler/noirc_frontend/src/tests/oracles.rs b/compiler/noirc_frontend/src/tests/oracles.rs index ca5e0d4b97a..f8520b976b1 100644 --- a/compiler/noirc_frontend/src/tests/oracles.rs +++ b/compiler/noirc_frontend/src/tests/oracles.rs @@ -36,21 +36,20 @@ fn errors_if_oracle_returns_multiple_vectors() { } #[test] -fn errors_if_oracle_called_from_constrained_directly() { +fn does_not_error_if_oracle_called_from_constrained_directly() { + // Assuming that direct oracle calls will be automatically wrapped in a proxy function. let src = r#" fn main() { // safety: unsafe { oracle_call(); - ^^^^^^^^^^^^^ Oracle functions cannot be called directly from constrained functions - ~~~~~~~~~~~~~ This oracle call must be wrapped in a call to another unconstrained function before being returned to a constrained runtime } } #[oracle(oracle_call)] unconstrained fn oracle_call() {} "#; - check_errors(src); + assert_no_errors(src); } #[test]