diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 446e5b62ead..0f9d22ca9b5 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -4,14 +4,13 @@ use std::{ rc::Rc, }; -use crate::hir::def_map::CrateDefMap; use crate::{ ast::{ ArrayLiteral, ConstructorExpression, FunctionKind, IfExpression, InfixExpression, Lambda, UnresolvedTraitConstraint, UnresolvedTypeExpression, }, hir::{ - def_collector::dc_crate::CompilationError, + def_collector::{dc_crate::CompilationError, errors::DuplicateType}, resolution::{errors::ResolverError, path_resolver::PathResolver, resolver::LambdaContext}, scope::ScopeForest as GenericScopeForest, type_check::TypeCheckError, @@ -56,6 +55,14 @@ use crate::{ token::FunctionAttribute, Generics, }; +use crate::{ + hir::{ + def_collector::dc_crate::{UnresolvedFunctions, UnresolvedTraitImpl}, + def_map::{CrateDefMap, ModuleData}, + }, + hir_def::traits::TraitImpl, + macros_api::ItemVisibility, +}; mod expressions; mod patterns; @@ -86,7 +93,7 @@ pub struct Elaborator<'context> { interner: &'context mut NodeInterner, - def_maps: &'context BTreeMap, + def_maps: &'context mut BTreeMap, file: FileId, @@ -168,7 +175,7 @@ impl<'context> Elaborator<'context> { scopes: ScopeForest::default(), errors: Vec::new(), interner: &mut context.def_interner, - def_maps: &context.def_maps, + def_maps: &mut context.def_maps, file: FileId::dummy(), in_unconstrained_fn: false, nested_loops: 0, @@ -192,7 +199,7 @@ impl<'context> Elaborator<'context> { pub fn elaborate( context: &'context mut Context, crate_id: CrateId, - items: CollectedItems, + mut items: CollectedItems, ) -> Vec<(CompilationError, FileId)> { let mut this = Self::new(context, crate_id); @@ -205,28 +212,27 @@ impl<'context> Elaborator<'context> { for struct_ in items.types {} - for trait_impl in &items.trait_impls { - // only collect now + for trait_impl in &mut items.trait_impls { + this.collect_trait_impl(trait_impl); } - for impl_ in &items.impls { - // only collect now + for ((typ, module), impls) in &items.impls { + this.collect_impls(typ, *module, impls); } // resolver resolves non-literal globals here for functions in items.functions { - this.file = functions.file_id; - this.trait_id = functions.trait_id; // TODO: Resolve? - for (local_module, id, func) in functions.functions { - this.local_module = local_module; - this.elaborate_function(func, id); - } + this.elaborate_functions(functions); } - for impl_ in items.impls {} + for ((typ, module), impls) in items.impls { + this.elaborate_impls(typ, module, impls); + } - for trait_impl in items.trait_impls {} + for trait_impl in items.trait_impls { + this.elaborate_trait_impl(trait_impl); + } let cycle_errors = this.interner.check_for_dependency_cycles(); this.errors.extend(cycle_errors); @@ -234,6 +240,17 @@ impl<'context> Elaborator<'context> { this.errors } + fn elaborate_functions(&mut self, functions: UnresolvedFunctions) { + self.file = functions.file_id; + self.trait_id = functions.trait_id; // TODO: Resolve? + for (local_module, id, func) in functions.functions { + self.local_module = local_module; + let generics_count = self.generics.len(); + self.elaborate_function(func, id); + self.generics.truncate(generics_count); + } + } + fn elaborate_function(&mut self, mut function: NoirFunction, id: FuncId) { self.current_function = Some(id); self.resolve_where_clause(&mut function.def.where_clause); @@ -779,4 +796,308 @@ impl<'context> Elaborator<'context> { } } } + + fn elaborate_impls( + &mut self, + typ: UnresolvedType, + module: LocalModuleId, + impls: Vec<(Vec, Span, UnresolvedFunctions)>, + ) { + self.generics.clear(); + + for (generics, _, functions) in impls { + self.file = functions.file_id; + self.add_generics(&generics); + let self_type = self.resolve_type(typ.clone()); + self.self_type = Some(self_type.clone()); + + let function_ids = vecmap(&functions.functions, |(_, id, _)| *id); + self.elaborate_functions(functions); + + if self_type != Type::Error { + for method_id in function_ids { + let method_name = self.interner.function_name(&method_id).to_owned(); + + if let Some(first_fn) = + self.interner.add_method(&self_type, method_name.clone(), method_id, false) + { + let error = ResolverError::DuplicateDefinition { + name: method_name, + first_span: self.interner.function_ident(&first_fn).span(), + second_span: self.interner.function_ident(&method_id).span(), + }; + self.push_err(error); + } + } + } + } + } + + fn elaborate_trait_impl(&mut self, trait_impl: UnresolvedTraitImpl) { + self.file = trait_impl.file_id; + self.local_module = trait_impl.module_id; + + let unresolved_type = trait_impl.object_type; + let self_type_span = unresolved_type.span; + self.add_generics(&trait_impl.generics); + + let trait_generics = + vecmap(&trait_impl.trait_generics, |generic| self.resolve_type(generic.clone())); + + let self_type = self.resolve_type(unresolved_type.clone()); + let impl_id = self.interner.next_trait_impl_id(); + + self.self_type = Some(self_type.clone()); + self.current_trait_impl = Some(impl_id); + + let mut methods = trait_impl.methods.function_ids(); + + self.elaborate_functions(trait_impl.methods); + + if matches!(self_type, Type::MutableReference(_)) { + let span = self_type_span.unwrap_or_else(|| trait_impl.trait_path.span()); + self.push_err(DefCollectorErrorKind::MutableReferenceInTraitImpl { span }); + } + + if let Some(trait_id) = trait_impl.trait_id { + for func_id in &methods { + self.interner.set_function_trait(*func_id, self_type.clone(), trait_id); + } + + let where_clause = trait_impl + .where_clause + .into_iter() + .flat_map(|item| self.resolve_trait_constraint(item)) + .collect(); + + let resolved_trait_impl = Shared::new(TraitImpl { + ident: trait_impl.trait_path.last_segment().clone(), + typ: self_type.clone(), + trait_id, + trait_generics: trait_generics.clone(), + file: trait_impl.file_id, + where_clause, + methods, + }); + + let generics = vecmap(&self.generics, |(_, type_variable, _)| type_variable.clone()); + + if let Err((prev_span, prev_file)) = self.interner.add_trait_implementation( + self_type.clone(), + trait_id, + trait_generics, + impl_id, + generics, + resolved_trait_impl, + ) { + self.push_err(DefCollectorErrorKind::OverlappingImpl { + typ: self_type.clone(), + span: self_type_span.unwrap_or_else(|| trait_impl.trait_path.span()), + }); + + // The 'previous impl defined here' note must be a separate error currently + // since it may be in a different file and all errors have the same file id. + self.file = prev_file; + self.push_err(DefCollectorErrorKind::OverlappingImplNote { span: prev_span }); + self.file = trait_impl.file_id; + } + } + + self.self_type = None; + self.current_trait_impl = None; + self.generics.clear(); + } + + fn collect_impls( + &mut self, + self_type: &UnresolvedType, + module: LocalModuleId, + impls: &[(Vec, Span, UnresolvedFunctions)], + ) { + self.local_module = module; + + for (generics, span, unresolved) in impls { + self.file = unresolved.file_id; + self.declare_method_on_struct(self_type, generics, false, unresolved, *span); + } + } + + fn collect_trait_impl(&mut self, trait_impl: &mut UnresolvedTraitImpl) { + self.local_module = trait_impl.module_id; + self.file = trait_impl.file_id; + trait_impl.trait_id = self.resolve_trait_by_path(trait_impl.trait_path.clone()); + + if let Some(trait_id) = trait_impl.trait_id { + self.collect_trait_impl_methods(trait_id, trait_impl); + + let span = trait_impl.object_type.span.expect("All trait self types should have spans"); + let object_type = &trait_impl.object_type; + let generics = &trait_impl.generics; + self.declare_method_on_struct(object_type, generics, true, &trait_impl.methods, span); + } + } + + fn get_module_mut(&mut self, module: ModuleId) -> &mut ModuleData { + let message = "A crate should always be present for a given crate id"; + &mut self.def_maps.get_mut(&module.krate).expect(message).modules[module.local_id.0] + } + + fn declare_method_on_struct( + &mut self, + self_type: &UnresolvedType, + generics: &UnresolvedGenerics, + is_trait_impl: bool, + functions: &UnresolvedFunctions, + span: Span, + ) { + let generic_count = self.generics.len(); + self.add_generics(generics); + let typ = self.resolve_type(self_type.clone()); + + if let Type::Struct(struct_type, _generics) = typ { + let struct_type = struct_type.borrow(); + + // `impl`s are only allowed on types defined within the current crate + if !is_trait_impl && struct_type.id.krate() != self.crate_id { + let type_name = struct_type.name.to_string(); + self.push_err(DefCollectorErrorKind::ForeignImpl { span, type_name }); + self.generics.truncate(generic_count); + return; + } + + // Grab the module defined by the struct type. Note that impls are a case + // where the module the methods are added to is not the same as the module + // they are resolved in. + let module = self.get_module_mut(struct_type.id.module_id()); + + for (_, method_id, method) in &functions.functions { + // If this method was already declared, remove it from the module so it cannot + // be accessed with the `TypeName::method` syntax. We'll check later whether the + // object types in each method overlap or not. If they do, we issue an error. + // If not, that is specialization which is allowed. + let name = method.name_ident().clone(); + if module.declare_function(name, ItemVisibility::Public, *method_id).is_err() { + module.remove_function(method.name_ident()); + } + } + // Prohibit defining impls for primitive types if we're not in the stdlib + } else if !is_trait_impl && typ != Type::Error && !self.crate_id.is_stdlib() { + self.push_err(DefCollectorErrorKind::NonStructTypeInImpl { span }); + } + self.generics.truncate(generic_count); + } + + fn collect_trait_impl_methods( + &mut self, + trait_id: TraitId, + trait_impl: &mut UnresolvedTraitImpl, + ) { + self.local_module = trait_impl.module_id; + self.file = trait_impl.file_id; + + // In this Vec methods[i] corresponds to trait.methods[i]. If the impl has no implementation + // for a particular method, the default implementation will be added at that slot. + let mut ordered_methods = Vec::new(); + + // check whether the trait implementation is in the same crate as either the trait or the type + self.check_trait_impl_crate_coherence(trait_id, trait_impl); + + // set of function ids that have a corresponding method in the trait + let mut func_ids_in_trait = HashSet::default(); + + // Temporarily take ownership of the trait's methods so we can iterate over them + // while also mutating the interner + let the_trait = self.interner.get_trait_mut(trait_id); + let methods = std::mem::take(&mut the_trait.methods); + + for method in &methods { + let overrides: Vec<_> = trait_impl + .methods + .functions + .iter() + .filter(|(_, _, f)| f.name() == method.name.0.contents) + .collect(); + + if overrides.is_empty() { + if let Some(default_impl) = &method.default_impl { + // copy 'where' clause from unresolved trait impl + let mut default_impl_clone = default_impl.clone(); + default_impl_clone.def.where_clause.extend(trait_impl.where_clause.clone()); + + let func_id = self.interner.push_empty_fn(); + let module = self.module_id(); + let location = Location::new(default_impl.def.span, trait_impl.file_id); + self.interner.push_function(func_id, &default_impl.def, module, location); + func_ids_in_trait.insert(func_id); + ordered_methods.push(( + method.default_impl_module_id, + func_id, + *default_impl_clone, + )); + } else { + self.push_err(DefCollectorErrorKind::TraitMissingMethod { + trait_name: self.interner.get_trait(trait_id).name.clone(), + method_name: method.name.clone(), + trait_impl_span: trait_impl + .object_type + .span + .expect("type must have a span"), + }); + } + } else { + for (_, func_id, _) in &overrides { + func_ids_in_trait.insert(*func_id); + } + + if overrides.len() > 1 { + self.push_err(DefCollectorErrorKind::Duplicate { + typ: DuplicateType::TraitAssociatedFunction, + first_def: overrides[0].2.name_ident().clone(), + second_def: overrides[1].2.name_ident().clone(), + }); + } + + ordered_methods.push(overrides[0].clone()); + } + } + + // Restore the methods that were taken before the for loop + let the_trait = self.interner.get_trait_mut(trait_id); + the_trait.set_methods(methods); + + // Emit MethodNotInTrait error for methods in the impl block that + // don't have a corresponding method signature defined in the trait + for (_, func_id, func) in &trait_impl.methods.functions { + if !func_ids_in_trait.contains(func_id) { + let trait_name = the_trait.name.clone(); + let impl_method = func.name_ident().clone(); + let error = DefCollectorErrorKind::MethodNotInTrait { trait_name, impl_method }; + self.errors.push((error.into(), self.file)); + } + } + + trait_impl.methods.functions = ordered_methods; + trait_impl.methods.trait_id = Some(trait_id); + } + + fn check_trait_impl_crate_coherence( + &mut self, + trait_id: TraitId, + trait_impl: &UnresolvedTraitImpl, + ) { + self.local_module = trait_impl.module_id; + self.file = trait_impl.file_id; + + let object_crate = match self.resolve_type(trait_impl.object_type.clone()) { + Type::Struct(struct_type, _) => struct_type.borrow().id.krate(), + _ => CrateId::Dummy, + }; + + let the_trait = self.interner.get_trait(trait_id); + if self.crate_id != the_trait.crate_id && self.crate_id != object_crate { + self.push_err(DefCollectorErrorKind::TraitImplOrphaned { + span: trait_impl.object_type.span.expect("object type must have a span"), + }); + } + } } diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 4aac0fec9c3..d2eaf79b0f0 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -54,6 +54,10 @@ impl UnresolvedFunctions { self.functions.push((mod_id, func_id, func)); } + pub fn function_ids(&self) -> Vec { + vecmap(&self.functions, |(_, id, _)| *id) + } + pub fn resolve_trait_bounds_trait_ids( &mut self, def_maps: &BTreeMap,