diff --git a/compiler/noirc_frontend/src/elaborator/comptime.rs b/compiler/noirc_frontend/src/elaborator/comptime.rs index fb5d6651cda..0cd0824b6d9 100644 --- a/compiler/noirc_frontend/src/elaborator/comptime.rs +++ b/compiler/noirc_frontend/src/elaborator/comptime.rs @@ -52,10 +52,41 @@ impl<'context> Elaborator<'context> { /// Elaborate an expression from the middle of a comptime scope. /// When this happens we require additional information to know /// what variables should be in scope. - pub fn elaborate_item_from_comptime<'a, T>( + pub fn elaborate_item_from_comptime_in_function<'a, T>( &'a mut self, current_function: Option, f: impl FnOnce(&mut Elaborator<'a>) -> T, + ) -> T { + self.elaborate_item_from_comptime(f, |elaborator| { + if let Some(function) = current_function { + let meta = elaborator.interner.function_meta(&function); + elaborator.current_item = Some(DependencyId::Function(function)); + elaborator.crate_id = meta.source_crate; + elaborator.local_module = meta.source_module; + elaborator.file = meta.source_file; + elaborator.introduce_generics_into_scope(meta.all_generics.clone()); + } + }) + } + + pub fn elaborate_item_from_comptime_in_module<'a, T>( + &'a mut self, + module: ModuleId, + file: FileId, + f: impl FnOnce(&mut Elaborator<'a>) -> T, + ) -> T { + self.elaborate_item_from_comptime(f, |elaborator| { + elaborator.current_item = None; + elaborator.crate_id = module.krate; + elaborator.local_module = module.local_id; + elaborator.file = file; + }) + } + + fn elaborate_item_from_comptime<'a, T>( + &'a mut self, + f: impl FnOnce(&mut Elaborator<'a>) -> T, + setup: impl FnOnce(&mut Elaborator<'a>), ) -> T { // Create a fresh elaborator to ensure no state is changed from // this elaborator @@ -70,14 +101,7 @@ impl<'context> Elaborator<'context> { elaborator.function_context.push(FunctionContext::default()); elaborator.scopes.start_function(); - if let Some(function) = current_function { - let meta = elaborator.interner.function_meta(&function); - elaborator.current_item = Some(DependencyId::Function(function)); - elaborator.crate_id = meta.source_crate; - elaborator.local_module = meta.source_module; - elaborator.file = meta.source_file; - elaborator.introduce_generics_into_scope(meta.all_generics.clone()); - } + setup(&mut elaborator); elaborator.populate_scope_from_comptime_scopes(); @@ -351,7 +375,7 @@ impl<'context> Elaborator<'context> { } } - fn add_item( + pub(crate) fn add_item( &mut self, item: TopLevelStatement, generated_items: &mut CollectedItems, diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 9925206003b..d321d04bef9 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -253,7 +253,7 @@ impl<'context> Elaborator<'context> { this } - fn elaborate_items(&mut self, mut items: CollectedItems) { + pub(crate) fn elaborate_items(&mut self, mut items: CollectedItems) { // We must first resolve and intern the globals before we can resolve any stmts inside each function. // Each function uses its own resolver with a newly created ScopeForest, and must be resolved again to be within a function's scope // diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index 9f559b7c5e6..5f58c18d66e 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -2,6 +2,7 @@ use std::collections::VecDeque; use std::{collections::hash_map::Entry, rc::Rc}; use acvm::{acir::AcirField, FieldElement}; +use fm::FileId; use im::Vector; use iter_extended::try_vecmap; use noirc_errors::Location; @@ -10,6 +11,7 @@ use rustc_hash::FxHashMap as HashMap; use crate::ast::{BinaryOpKind, FunctionKind, IntegerBitSize, Signedness}; use crate::elaborator::Elaborator; use crate::graph::CrateId; +use crate::hir::def_map::ModuleId; use crate::hir_def::expr::ImplKind; use crate::hir_def::function::FunctionBody; use crate::macros_api::UnaryOp; @@ -170,7 +172,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { Some(body) => Ok(body), None => { if matches!(&meta.function_body, FunctionBody::Unresolved(..)) { - self.elaborate_item(None, |elaborator| { + self.elaborate_in_function(None, |elaborator| { elaborator.elaborate_function(function); }); @@ -183,13 +185,25 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { } } - fn elaborate_item( + fn elaborate_in_function( &mut self, function: Option, f: impl FnOnce(&mut Elaborator) -> T, ) -> T { self.unbind_generics_from_previous_function(); - let result = self.elaborator.elaborate_item_from_comptime(function, f); + let result = self.elaborator.elaborate_item_from_comptime_in_function(function, f); + self.rebind_generics_from_previous_function(); + result + } + + fn elaborate_in_module( + &mut self, + module: ModuleId, + file: FileId, + f: impl FnOnce(&mut Elaborator) -> T, + ) -> T { + self.unbind_generics_from_previous_function(); + let result = self.elaborator.elaborate_item_from_comptime_in_module(module, file, f); self.rebind_generics_from_previous_function(); result } @@ -1244,7 +1258,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { let mut result = self.call_function(function_id, arguments, bindings, location)?; if call.is_macro_call { let expr = result.into_expression(self.elaborator.interner, location)?; - let expr = self.elaborate_item(self.current_function, |elaborator| { + let expr = self.elaborate_in_function(self.current_function, |elaborator| { elaborator.elaborate_expression(expr).0 }); result = self.evaluate(expr)?; diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index e6ef685e278..5943e352510 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -25,6 +25,7 @@ use crate::{ FunctionReturnType, IntegerBitSize, LValue, Literal, Statement, StatementKind, UnaryOp, UnresolvedType, UnresolvedTypeData, Visibility, }, + hir::def_collector::dc_crate::CollectedItems, hir::{ comptime::{ errors::IResult, @@ -117,6 +118,7 @@ impl<'local, 'context> Interpreter<'local, 'context> { "function_def_set_return_public" => { function_def_set_return_public(self, arguments, location) } + "module_add_item" => module_add_item(self, arguments, location), "module_functions" => module_functions(self, arguments, location), "module_has_named_attribute" => module_has_named_attribute(self, arguments, location), "module_is_contract" => module_is_contract(self, arguments, location), @@ -607,9 +609,10 @@ fn quoted_as_module( let path = parse(argument, parser::path_no_turbofish(), "a path").ok(); let option_value = path.and_then(|path| { - let module = interpreter.elaborate_item(interpreter.current_function, |elaborator| { - elaborator.resolve_module_by_path(path) - }); + let module = interpreter + .elaborate_in_function(interpreter.current_function, |elaborator| { + elaborator.resolve_module_by_path(path) + }); module.map(Value::ModuleDefinition) }); @@ -625,7 +628,7 @@ fn quoted_as_trait_constraint( let argument = check_one_argument(arguments, location)?; let trait_bound = parse(argument, parser::trait_bound(), "a trait constraint")?; let bound = interpreter - .elaborate_item(interpreter.current_function, |elaborator| { + .elaborate_in_function(interpreter.current_function, |elaborator| { elaborator.resolve_trait_bound(&trait_bound, Type::Unit) }) .ok_or(InterpreterError::FailedToResolveTraitBound { trait_bound, location })?; @@ -641,8 +644,8 @@ fn quoted_as_type( ) -> IResult { let argument = check_one_argument(arguments, location)?; let typ = parse(argument, parser::parse_type(), "a type")?; - let typ = - interpreter.elaborate_item(interpreter.current_function, |elab| elab.resolve_type(typ)); + let typ = interpreter + .elaborate_in_function(interpreter.current_function, |elab| elab.resolve_type(typ)); Ok(Value::Type(typ)) } @@ -1712,23 +1715,25 @@ fn expr_resolve( interpreter.current_function }; - let value = interpreter.elaborate_item(function_to_resolve_in, |elaborator| match expr_value { - ExprValue::Expression(expression_kind) => { - let expr = Expression { kind: expression_kind, span: self_argument_location.span }; - let (expr_id, _) = elaborator.elaborate_expression(expr); - Value::TypedExpr(TypedExpr::ExprId(expr_id)) - } - ExprValue::Statement(statement_kind) => { - let statement = Statement { kind: statement_kind, span: self_argument_location.span }; - let (stmt_id, _) = elaborator.elaborate_statement(statement); - Value::TypedExpr(TypedExpr::StmtId(stmt_id)) - } - ExprValue::LValue(lvalue) => { - let expr = lvalue.as_expression(); - let (expr_id, _) = elaborator.elaborate_expression(expr); - Value::TypedExpr(TypedExpr::ExprId(expr_id)) - } - }); + let value = + interpreter.elaborate_in_function(function_to_resolve_in, |elaborator| match expr_value { + ExprValue::Expression(expression_kind) => { + let expr = Expression { kind: expression_kind, span: self_argument_location.span }; + let (expr_id, _) = elaborator.elaborate_expression(expr); + Value::TypedExpr(TypedExpr::ExprId(expr_id)) + } + ExprValue::Statement(statement_kind) => { + let statement = + Statement { kind: statement_kind, span: self_argument_location.span }; + let (stmt_id, _) = elaborator.elaborate_statement(statement); + Value::TypedExpr(TypedExpr::StmtId(stmt_id)) + } + ExprValue::LValue(lvalue) => { + let expr = lvalue.as_expression(); + let (expr_id, _) = elaborator.elaborate_expression(expr); + Value::TypedExpr(TypedExpr::ExprId(expr_id)) + } + }); Ok(value) } @@ -1996,7 +2001,7 @@ fn function_def_set_parameters( "a pattern", )?; - let hir_pattern = interpreter.elaborate_item(Some(func_id), |elaborator| { + let hir_pattern = interpreter.elaborate_in_function(Some(func_id), |elaborator| { elaborator.elaborate_pattern_and_store_ids( parameter_pattern, parameter_type.clone(), @@ -2063,6 +2068,34 @@ fn function_def_set_return_public( Ok(Value::Unit) } +// fn add_item(self, item: Quoted) +fn module_add_item( + interpreter: &mut Interpreter, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { + let (self_argument, item) = check_two_arguments(arguments, location)?; + let module_id = get_module(self_argument)?; + let module_data = interpreter.elaborator.get_module(module_id); + + let parser = parser::top_level_items(); + let top_level_statements = parse(item, parser, "a top-level item")?; + + interpreter.elaborate_in_module(module_id, module_data.location.file, |elaborator| { + let mut generated_items = CollectedItems::default(); + + for top_level_statement in top_level_statements { + elaborator.add_item(top_level_statement, &mut generated_items, location); + } + + if !generated_items.is_empty() { + elaborator.elaborate_items(generated_items); + } + }); + + Ok(Value::Unit) +} + // fn functions(self) -> [FunctionDefinition] fn module_functions( interpreter: &Interpreter, diff --git a/compiler/noirc_frontend/src/parser/mod.rs b/compiler/noirc_frontend/src/parser/mod.rs index 2995e90ab01..66be0fdced5 100644 --- a/compiler/noirc_frontend/src/parser/mod.rs +++ b/compiler/noirc_frontend/src/parser/mod.rs @@ -26,8 +26,8 @@ use noirc_errors::Span; pub use parser::path::path_no_turbofish; pub use parser::traits::trait_bound; pub use parser::{ - block, expression, fresh_statement, lvalue, parse_program, parse_type, pattern, - top_level_items, visibility, + block, expression, fresh_statement, lvalue, module, parse_program, parse_type, pattern, + top_level_items, top_level_statement, visibility, }; #[derive(Debug, Clone)] diff --git a/compiler/noirc_frontend/src/parser/parser.rs b/compiler/noirc_frontend/src/parser/parser.rs index 1aee697aa88..48d25e7a1d8 100644 --- a/compiler/noirc_frontend/src/parser/parser.rs +++ b/compiler/noirc_frontend/src/parser/parser.rs @@ -175,7 +175,7 @@ fn program() -> impl NoirParser { /// module: top_level_statement module /// | %empty -fn module() -> impl NoirParser { +pub fn module() -> impl NoirParser { recursive(|module_parser| { empty() .to(ParsedModule::default()) @@ -202,7 +202,7 @@ pub fn top_level_items() -> impl NoirParser> { /// | module_declaration /// | use_statement /// | global_declaration -fn top_level_statement<'a>( +pub fn top_level_statement<'a>( module_parser: impl NoirParser + 'a, ) -> impl NoirParser + 'a { choice(( diff --git a/docs/docs/noir/standard_library/meta/module.md b/docs/docs/noir/standard_library/meta/module.md index 870e366461c..de042760d51 100644 --- a/docs/docs/noir/standard_library/meta/module.md +++ b/docs/docs/noir/standard_library/meta/module.md @@ -8,6 +8,14 @@ declarations in the source program. ## Methods +### add_item + +#include_code add_item noir_stdlib/src/meta/module.nr rust + +Adds a top-level item (a function, a struct, a global, etc.) to the module. +Adding multiple items in one go is also valid if the `Quoted` value has multiple items in it. +Note that the items are type-checked as if they are inside the module they are being added to. + ### name #include_code name noir_stdlib/src/meta/module.nr rust diff --git a/noir_stdlib/src/meta/module.nr b/noir_stdlib/src/meta/module.nr index b3f76812b8a..bee6612e1bf 100644 --- a/noir_stdlib/src/meta/module.nr +++ b/noir_stdlib/src/meta/module.nr @@ -1,4 +1,9 @@ impl Module { + #[builtin(module_add_item)] + // docs:start:add_item + fn add_item(self, item: Quoted) {} + // docs:end:add_item + #[builtin(module_has_named_attribute)] // docs:start:has_named_attribute fn has_named_attribute(self, name: Quoted) -> bool {} diff --git a/test_programs/compile_success_empty/comptime_module/src/main.nr b/test_programs/compile_success_empty/comptime_module/src/main.nr index 1d1690c4017..baf45c517ed 100644 --- a/test_programs/compile_success_empty/comptime_module/src/main.nr +++ b/test_programs/compile_success_empty/comptime_module/src/main.nr @@ -42,6 +42,22 @@ fn outer_attribute_separate_module(m: Module) { increment_counter(); } +struct Foo {} + +#[add_function] +mod add_to_me { + fn add_to_me_function() {} +} + +fn add_function(m: Module) { + m.add_item( + quote { pub fn added_function() -> super::Foo { + add_to_me_function(); + super::Foo {} + } } + ); +} + fn main() { comptime { @@ -73,6 +89,8 @@ fn main() { yet_another_module::generated_outer_function(); yet_another_module::generated_inner_function(); + + let _ = add_to_me::added_function(); } // docs:start:as_module_example