diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index aef0771c486..6f6679d14ce 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -1027,11 +1027,14 @@ impl<'context> Elaborator<'context> { self.file = trait_impl.file_id; self.local_module = trait_impl.module_id; - self.check_parent_traits_are_implemented(&trait_impl); - - self.generics = trait_impl.resolved_generics; + self.generics = trait_impl.resolved_generics.clone(); self.current_trait_impl = trait_impl.impl_id; + self.add_trait_impl_assumed_trait_implementations(trait_impl.impl_id); + self.check_trait_impl_where_clause_matches_trait_where_clause(&trait_impl); + self.check_parent_traits_are_implemented(&trait_impl); + self.remove_trait_impl_assumed_trait_implementations(trait_impl.impl_id); + for (module, function, _) in &trait_impl.methods.functions { self.local_module = *module; let errors = check_trait_impl_method_matches_declaration(self.interner, *function); @@ -1045,6 +1048,95 @@ impl<'context> Elaborator<'context> { self.generics.clear(); } + fn add_trait_impl_assumed_trait_implementations(&mut self, impl_id: Option) { + if let Some(impl_id) = impl_id { + if let Some(trait_implementation) = self.interner.try_get_trait_implementation(impl_id) + { + for trait_constrain in &trait_implementation.borrow().where_clause { + let trait_bound = &trait_constrain.trait_bound; + self.interner.add_assumed_trait_implementation( + trait_constrain.typ.clone(), + trait_bound.trait_id, + trait_bound.trait_generics.clone(), + ); + } + } + } + } + + fn remove_trait_impl_assumed_trait_implementations(&mut self, impl_id: Option) { + if let Some(impl_id) = impl_id { + if let Some(trait_implementation) = self.interner.try_get_trait_implementation(impl_id) + { + for trait_constrain in &trait_implementation.borrow().where_clause { + self.interner.remove_assumed_trait_implementations_for_trait( + trait_constrain.trait_bound.trait_id, + ); + } + } + } + } + + fn check_trait_impl_where_clause_matches_trait_where_clause( + &mut self, + trait_impl: &UnresolvedTraitImpl, + ) { + let Some(trait_id) = trait_impl.trait_id else { + return; + }; + + let Some(the_trait) = self.interner.try_get_trait(trait_id) else { + return; + }; + + if the_trait.where_clause.is_empty() { + return; + } + + let impl_trait = the_trait.name.to_string(); + let the_trait_file = the_trait.location.file; + + let mut bindings = TypeBindings::new(); + bind_ordered_generics( + &the_trait.generics, + &trait_impl.resolved_trait_generics, + &mut bindings, + ); + + // Check that each of the trait's where clause constraints is satisfied + for trait_constraint in the_trait.where_clause.clone() { + let Some(trait_constraint_trait) = + self.interner.try_get_trait(trait_constraint.trait_bound.trait_id) + else { + continue; + }; + + let trait_constraint_type = trait_constraint.typ.substitute(&bindings); + let trait_bound = &trait_constraint.trait_bound; + + if self + .interner + .try_lookup_trait_implementation( + &trait_constraint_type, + trait_bound.trait_id, + &trait_bound.trait_generics.ordered, + &trait_bound.trait_generics.named, + ) + .is_err() + { + let missing_trait = + format!("{}{}", trait_constraint_trait.name, trait_bound.trait_generics); + self.push_err(ResolverError::TraitNotImplemented { + impl_trait: impl_trait.clone(), + missing_trait, + type_missing_trait: trait_constraint_type.to_string(), + span: trait_impl.object_type.span, + missing_trait_location: Location::new(trait_bound.span, the_trait_file), + }); + } + } + } + fn check_parent_traits_are_implemented(&mut self, trait_impl: &UnresolvedTraitImpl) { let Some(trait_id) = trait_impl.trait_id else { return; @@ -1168,7 +1260,7 @@ impl<'context> Elaborator<'context> { trait_id, trait_generics, file: trait_impl.file_id, - where_clause: where_clause.clone(), + where_clause, methods, }); diff --git a/compiler/noirc_frontend/src/elaborator/traits.rs b/compiler/noirc_frontend/src/elaborator/traits.rs index e877682972c..ae278616e03 100644 --- a/compiler/noirc_frontend/src/elaborator/traits.rs +++ b/compiler/noirc_frontend/src/elaborator/traits.rs @@ -32,6 +32,9 @@ impl<'context> Elaborator<'context> { &resolved_generics, ); + let where_clause = + this.resolve_trait_constraints(&unresolved_trait.trait_def.where_clause); + // Each associated type in this trait is also an implicit generic for associated_type in &this.interner.get_trait(*trait_id).associated_types { this.generics.push(associated_type.clone()); @@ -48,6 +51,7 @@ impl<'context> Elaborator<'context> { this.interner.update_trait(*trait_id, |trait_def| { trait_def.set_methods(methods); trait_def.set_trait_bounds(resolved_trait_bounds); + trait_def.set_where_clause(where_clause); }); }); diff --git a/compiler/noirc_frontend/src/hir_def/traits.rs b/compiler/noirc_frontend/src/hir_def/traits.rs index 534805c2dad..6fd3c4f7a24 100644 --- a/compiler/noirc_frontend/src/hir_def/traits.rs +++ b/compiler/noirc_frontend/src/hir_def/traits.rs @@ -75,6 +75,8 @@ pub struct Trait { /// The resolved trait bounds (for example in `trait Foo: Bar + Baz`, this would be `Bar + Baz`) pub trait_bounds: Vec, + + pub where_clause: Vec, } #[derive(Debug)] @@ -154,6 +156,10 @@ impl Trait { self.trait_bounds = trait_bounds; } + pub fn set_where_clause(&mut self, where_clause: Vec) { + self.where_clause = where_clause; + } + pub fn find_method(&self, name: &str) -> Option { for (idx, method) in self.methods.iter().enumerate() { if &method.name == name { diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index ca7e0c6aa59..2183cfba0ef 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -735,6 +735,7 @@ impl NodeInterner { method_ids: unresolved_trait.method_ids.clone(), associated_types, trait_bounds: Vec::new(), + where_clause: Vec::new(), }; self.traits.insert(type_id, new_trait); diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 17acd17dcc9..829a68ba3a3 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -2969,9 +2969,7 @@ fn uses_self_type_in_trait_where_clause() { } } - struct Bar { - - } + struct Bar {} impl Foo for Bar { @@ -2983,12 +2981,17 @@ fn uses_self_type_in_trait_where_clause() { "#; let errors = get_program_errors(src); - assert_eq!(errors.len(), 1); + assert_eq!(errors.len(), 2); + + let CompilationError::ResolverError(ResolverError::TraitNotImplemented { .. }) = &errors[0].0 + else { + panic!("Expected a trait not implemented error, got {:?}", errors[0].0); + }; let CompilationError::TypeError(TypeCheckError::UnresolvedMethodCall { method_name, .. }) = - &errors[0].0 + &errors[1].0 else { - panic!("Expected an unresolved method call error, got {:?}", errors[0].0); + panic!("Expected an unresolved method call error, got {:?}", errors[1].0); }; assert_eq!(method_name, "trait_func"); diff --git a/compiler/noirc_frontend/src/tests/traits.rs b/compiler/noirc_frontend/src/tests/traits.rs index ee84cc0e890..88138ecde4d 100644 --- a/compiler/noirc_frontend/src/tests/traits.rs +++ b/compiler/noirc_frontend/src/tests/traits.rs @@ -148,3 +148,116 @@ fn trait_inheritance_missing_parent_implementation() { assert_eq!(typ, "Struct"); assert_eq!(impl_trait, "Bar"); } + +#[test] +fn errors_on_unknown_type_in_trait_where_clause() { + let src = r#" + pub trait Foo where T: Unknown {} + + fn main() { + } + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); +} + +#[test] +fn does_not_error_if_impl_trait_constraint_is_satisfied_for_concrete_type() { + let src = r#" + pub trait Greeter { + fn greet(self); + } + + pub trait Foo + where + T: Greeter, + { + fn greet(object: U) + where + U: Greeter, + { + object.greet(); + } + } + + pub struct SomeGreeter; + impl Greeter for SomeGreeter { + fn greet(self) {} + } + + pub struct Bar; + + impl Foo for Bar {} + + fn main() {} + "#; + assert_no_errors(src); +} + +#[test] +fn does_not_error_if_impl_trait_constraint_is_satisfied_for_type_variable() { + let src = r#" + pub trait Greeter { + fn greet(self); + } + + pub trait Foo where T: Greeter { + fn greet(object: T) { + object.greet(); + } + } + + pub struct Bar; + + impl Foo for Bar where T: Greeter { + } + + fn main() { + } + "#; + assert_no_errors(src); +} +#[test] +fn errors_if_impl_trait_constraint_is_not_satisfied() { + let src = r#" + pub trait Greeter { + fn greet(self); + } + + pub trait Foo + where + T: Greeter, + { + fn greet(object: U) + where + U: Greeter, + { + object.greet(); + } + } + + pub struct SomeGreeter; + + pub struct Bar; + + impl Foo for Bar {} + + fn main() {} + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + let CompilationError::ResolverError(ResolverError::TraitNotImplemented { + impl_trait, + missing_trait: the_trait, + type_missing_trait: typ, + .. + }) = &errors[0].0 + else { + panic!("Expected a TraitNotImplemented error, got {:?}", &errors[0].0); + }; + + assert_eq!(the_trait, "Greeter"); + assert_eq!(typ, "SomeGreeter"); + assert_eq!(impl_trait, "Foo"); +}