diff --git a/compiler/noirc_frontend/src/elaborator/patterns.rs b/compiler/noirc_frontend/src/elaborator/patterns.rs index eab0b91f0f6..b01c4e0d768 100644 --- a/compiler/noirc_frontend/src/elaborator/patterns.rs +++ b/compiler/noirc_frontend/src/elaborator/patterns.rs @@ -699,7 +699,7 @@ impl<'context> Elaborator<'context> { } } - pub(super) fn type_check_variable( + pub(crate) fn type_check_variable( &mut self, ident: HirIdent, expr_id: ExprId, diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index 6655c8977e2..ff46592f9ed 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -35,7 +35,7 @@ use crate::{ }, hir_def::{ self, - expr::{HirExpression, HirLiteral}, + expr::{HirExpression, HirIdent, HirLiteral}, function::FunctionBody, }, node_interner::{DefinitionKind, NodeInterner, TraitImplKind}, @@ -121,6 +121,7 @@ impl<'local, 'context> Interpreter<'local, 'context> { "fmtstr_quoted_contents" => fmtstr_quoted_contents(interner, arguments, location), "fresh_type_variable" => fresh_type_variable(interner), "function_def_add_attribute" => function_def_add_attribute(self, arguments, location), + "function_def_as_typed_expr" => function_def_as_typed_expr(self, arguments, location), "function_def_body" => function_def_body(interner, arguments, location), "function_def_eq" => function_def_eq(arguments, location), "function_def_has_named_attribute" => { @@ -2423,6 +2424,25 @@ fn function_def_add_attribute( Ok(Value::Unit) } +// fn as_typed_expr(self) -> TypedExpr +fn function_def_as_typed_expr( + interpreter: &mut Interpreter, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { + let self_argument = check_one_argument(arguments, location)?; + let func_id = get_function_def(self_argument)?; + let definition_id = interpreter.elaborator.interner.function_definition_id(func_id); + let hir_ident = HirIdent::non_trait_method(definition_id, location); + let generics = None; + let hir_expr = HirExpression::Ident(hir_ident.clone(), generics.clone()); + let expr_id = interpreter.elaborator.interner.push_expr(hir_expr); + interpreter.elaborator.interner.push_expr_location(expr_id, location.span, location.file); + let typ = interpreter.elaborator.type_check_variable(hir_ident, expr_id, generics); + interpreter.elaborator.interner.push_expr_type(expr_id, typ); + Ok(Value::TypedExpr(TypedExpr::ExprId(expr_id))) +} + // fn body(self) -> Expr fn function_def_body( interner: &NodeInterner, diff --git a/compiler/noirc_frontend/src/hir/comptime/value.rs b/compiler/noirc_frontend/src/hir/comptime/value.rs index c1a831c70a8..d4866836dcd 100644 --- a/compiler/noirc_frontend/src/hir/comptime/value.rs +++ b/compiler/noirc_frontend/src/hir/comptime/value.rs @@ -495,6 +495,7 @@ impl Value { Value::UnresolvedType(typ) => { Token::InternedUnresolvedTypeData(interner.push_unresolved_type_data(typ)) } + Value::TypedExpr(TypedExpr::ExprId(expr_id)) => Token::UnquoteMarker(expr_id), Value::U1(bool) => Token::Bool(bool), Value::U8(value) => Token::Int((value as u128).into()), Value::U16(value) => Token::Int((value as u128).into()), diff --git a/compiler/noirc_frontend/src/lexer/token.rs b/compiler/noirc_frontend/src/lexer/token.rs index d0a6f05e05a..ef90ff4e581 100644 --- a/compiler/noirc_frontend/src/lexer/token.rs +++ b/compiler/noirc_frontend/src/lexer/token.rs @@ -162,7 +162,7 @@ pub enum Token { InternedLValue(InternedExpressionKind), /// A reference to an interned `UnresolvedTypeData`. InternedUnresolvedTypeData(InternedUnresolvedTypeData), - /// A reference to an interned `Patter`. + /// A reference to an interned `Pattern`. InternedPattern(InternedPattern), /// < Less, diff --git a/docs/docs/noir/standard_library/meta/function_def.md b/docs/docs/noir/standard_library/meta/function_def.md index 583771aeb73..ad6212b8e36 100644 --- a/docs/docs/noir/standard_library/meta/function_def.md +++ b/docs/docs/noir/standard_library/meta/function_def.md @@ -15,6 +15,17 @@ Adds an attribute to the function. This is only valid on functions in the current crate which have not yet been resolved. This means any functions called at compile-time are invalid targets for this method. +### as_typed_expr + +#include_code as_typed_expr noir_stdlib/src/meta/function_def.nr rust + +Returns this function as a `TypedExpr`, which can be unquoted. For example: + +```rust +let typed_expr = some_function.as_typed_expr(); +let _ = quote { $typed_expr(1, 2, 3); }; +``` + ### body #include_code body noir_stdlib/src/meta/function_def.nr rust diff --git a/noir_stdlib/src/meta/function_def.nr b/noir_stdlib/src/meta/function_def.nr index 010110d678e..e68e17cc1ac 100644 --- a/noir_stdlib/src/meta/function_def.nr +++ b/noir_stdlib/src/meta/function_def.nr @@ -4,6 +4,11 @@ impl FunctionDefinition { pub comptime fn add_attribute(self, attribute: str) {} // docs:end:add_attribute + #[builtin(function_def_as_typed_expr)] + // docs:start:as_typed_expr + pub comptime fn as_typed_expr(self) -> TypedExpr {} + // docs:end:as_typed_expr + #[builtin(function_def_body)] // docs:start:body pub comptime fn body(self) -> Expr {} diff --git a/test_programs/compile_success_empty/comptime_function_definition/src/main.nr b/test_programs/compile_success_empty/comptime_function_definition/src/main.nr index 105c1e8d577..f912ecb7bfc 100644 --- a/test_programs/compile_success_empty/comptime_function_definition/src/main.nr +++ b/test_programs/compile_success_empty/comptime_function_definition/src/main.nr @@ -105,3 +105,106 @@ mod foo { assert(!f.is_unconstrained()); } } + +mod test_as_typed_expr_1 { + #![foo] + + pub fn method(x: T) -> T { + x + } + + comptime fn foo(module: Module) -> Quoted { + let method = module.functions().filter(|f| f.name() == quote { method })[0]; + let func = method.as_typed_expr(); + quote { + pub fn bar() -> i32 { + $func(1) + } + } + } + + pub fn test() { + comptime { + assert_eq(bar(), 1); + } + } +} + +mod test_as_typed_expr_2 { + #![foo] + + unconstrained fn method(xs: [T; N]) -> u32 { + xs.len() + } + + comptime fn foo(module: Module) -> Quoted { + let method = module.functions().filter(|f| f.name() == quote { method })[0]; + let func = method.as_typed_expr(); + quote { + pub fn bar() -> u32 { + /// Safety: test program + unsafe { $func([1, 2, 3, 0]) } + } + } + } + + pub fn test() { + comptime { + assert_eq(bar(), 4); + } + } +} + +mod test_as_typed_expr_3 { + #![foo] + + pub comptime fn method(xs_ys: ([T; N], U)) -> u32 { + let (xs, _ys) = xs_ys; + xs.len() + } + + comptime fn foo(module: Module) -> Quoted { + let method = module.functions().filter(|f| f.name() == quote { method })[0]; + let func = method.as_typed_expr(); + quote { + pub fn bar() -> u32 { + /// Safety: test program + comptime { $func(([1, 2, 3, 0], "a")) } + } + } + } + + pub fn test() { + comptime { + assert_eq(bar(), 4); + } + } +} + +mod test_as_typed_expr_4 { + comptime fn foo(f: TypedExpr) -> Quoted { + quote { + $f() + } + } + + fn bar() -> Field { + 1 + } + + fn baz() -> Field { + let x: Field = comptime { + let bar_q = quote { + bar + }; + foo(bar_q.as_expr().unwrap().resolve(Option::none())) + }; + x + } + + pub fn test() { + comptime { + assert_eq(baz(), 1); + } + } +}