Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 96 additions & 4 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1027,11 +1027,14 @@ impl<'context> Elaborator<'context> {
self.file = trait_impl.file_id;
self.local_module = trait_impl.module_id;

self.check_parent_traits_are_implemented(&trait_impl);

self.generics = trait_impl.resolved_generics;
self.generics = trait_impl.resolved_generics.clone();
self.current_trait_impl = trait_impl.impl_id;

self.add_trait_impl_assumed_trait_implementations(trait_impl.impl_id);
self.check_trait_impl_where_clause_matches_trait_where_clause(&trait_impl);
self.check_parent_traits_are_implemented(&trait_impl);
self.remove_trait_impl_assumed_trait_implementations(trait_impl.impl_id);

for (module, function, _) in &trait_impl.methods.functions {
self.local_module = *module;
let errors = check_trait_impl_method_matches_declaration(self.interner, *function);
Expand All @@ -1045,6 +1048,95 @@ impl<'context> Elaborator<'context> {
self.generics.clear();
}

fn add_trait_impl_assumed_trait_implementations(&mut self, impl_id: Option<TraitImplId>) {
if let Some(impl_id) = impl_id {
if let Some(trait_implementation) = self.interner.try_get_trait_implementation(impl_id)
{
for trait_constrain in &trait_implementation.borrow().where_clause {
let trait_bound = &trait_constrain.trait_bound;
self.interner.add_assumed_trait_implementation(
trait_constrain.typ.clone(),
trait_bound.trait_id,
trait_bound.trait_generics.clone(),
);
}
}
}
}

fn remove_trait_impl_assumed_trait_implementations(&mut self, impl_id: Option<TraitImplId>) {
if let Some(impl_id) = impl_id {
if let Some(trait_implementation) = self.interner.try_get_trait_implementation(impl_id)
{
for trait_constrain in &trait_implementation.borrow().where_clause {
self.interner.remove_assumed_trait_implementations_for_trait(
trait_constrain.trait_bound.trait_id,
);
}
}
}
}

fn check_trait_impl_where_clause_matches_trait_where_clause(
&mut self,
trait_impl: &UnresolvedTraitImpl,
) {
let Some(trait_id) = trait_impl.trait_id else {
return;
};

let Some(the_trait) = self.interner.try_get_trait(trait_id) else {
return;
};

if the_trait.where_clause.is_empty() {
return;
}

let impl_trait = the_trait.name.to_string();
let the_trait_file = the_trait.location.file;

let mut bindings = TypeBindings::new();
bind_ordered_generics(
&the_trait.generics,
&trait_impl.resolved_trait_generics,
&mut bindings,
);

// Check that each of the trait's where clause constraints is satisfied
for trait_constraint in the_trait.where_clause.clone() {
let Some(trait_constraint_trait) =
self.interner.try_get_trait(trait_constraint.trait_bound.trait_id)
else {
continue;
};

let trait_constraint_type = trait_constraint.typ.substitute(&bindings);
let trait_bound = &trait_constraint.trait_bound;

if self
.interner
.try_lookup_trait_implementation(
&trait_constraint_type,
trait_bound.trait_id,
&trait_bound.trait_generics.ordered,
&trait_bound.trait_generics.named,
)
.is_err()
{
let missing_trait =
format!("{}{}", trait_constraint_trait.name, trait_bound.trait_generics);
self.push_err(ResolverError::TraitNotImplemented {
impl_trait: impl_trait.clone(),
missing_trait,
type_missing_trait: trait_constraint_type.to_string(),
span: trait_impl.object_type.span,
missing_trait_location: Location::new(trait_bound.span, the_trait_file),
});
}
}
}

fn check_parent_traits_are_implemented(&mut self, trait_impl: &UnresolvedTraitImpl) {
let Some(trait_id) = trait_impl.trait_id else {
return;
Expand Down Expand Up @@ -1168,7 +1260,7 @@ impl<'context> Elaborator<'context> {
trait_id,
trait_generics,
file: trait_impl.file_id,
where_clause: where_clause.clone(),
where_clause,
methods,
});

Expand Down
4 changes: 4 additions & 0 deletions compiler/noirc_frontend/src/elaborator/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ impl<'context> Elaborator<'context> {
&resolved_generics,
);

let where_clause =
this.resolve_trait_constraints(&unresolved_trait.trait_def.where_clause);

// Each associated type in this trait is also an implicit generic
for associated_type in &this.interner.get_trait(*trait_id).associated_types {
this.generics.push(associated_type.clone());
Expand All @@ -48,6 +51,7 @@ impl<'context> Elaborator<'context> {
this.interner.update_trait(*trait_id, |trait_def| {
trait_def.set_methods(methods);
trait_def.set_trait_bounds(resolved_trait_bounds);
trait_def.set_where_clause(where_clause);
});
});

Expand Down
6 changes: 6 additions & 0 deletions compiler/noirc_frontend/src/hir_def/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ pub struct Trait {

/// The resolved trait bounds (for example in `trait Foo: Bar + Baz`, this would be `Bar + Baz`)
pub trait_bounds: Vec<ResolvedTraitBound>,

pub where_clause: Vec<TraitConstraint>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -154,6 +156,10 @@ impl Trait {
self.trait_bounds = trait_bounds;
}

pub fn set_where_clause(&mut self, where_clause: Vec<TraitConstraint>) {
self.where_clause = where_clause;
}

pub fn find_method(&self, name: &str) -> Option<TraitMethodId> {
for (idx, method) in self.methods.iter().enumerate() {
if &method.name == name {
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_frontend/src/node_interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,7 @@ impl NodeInterner {
method_ids: unresolved_trait.method_ids.clone(),
associated_types,
trait_bounds: Vec::new(),
where_clause: Vec::new(),
};

self.traits.insert(type_id, new_trait);
Expand Down
15 changes: 9 additions & 6 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2969,9 +2969,7 @@ fn uses_self_type_in_trait_where_clause() {
}
}

struct Bar {

}
struct Bar {}

impl Foo for Bar {

Expand All @@ -2983,12 +2981,17 @@ fn uses_self_type_in_trait_where_clause() {
"#;

let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);
assert_eq!(errors.len(), 2);

let CompilationError::ResolverError(ResolverError::TraitNotImplemented { .. }) = &errors[0].0
else {
panic!("Expected a trait not implemented error, got {:?}", errors[0].0);
};

let CompilationError::TypeError(TypeCheckError::UnresolvedMethodCall { method_name, .. }) =
&errors[0].0
&errors[1].0
else {
panic!("Expected an unresolved method call error, got {:?}", errors[0].0);
panic!("Expected an unresolved method call error, got {:?}", errors[1].0);
};

assert_eq!(method_name, "trait_func");
Expand Down
113 changes: 113 additions & 0 deletions compiler/noirc_frontend/src/tests/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,116 @@ fn trait_inheritance_missing_parent_implementation() {
assert_eq!(typ, "Struct");
assert_eq!(impl_trait, "Bar");
}

#[test]
fn errors_on_unknown_type_in_trait_where_clause() {
let src = r#"
pub trait Foo<T> where T: Unknown {}

fn main() {
}
"#;
let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);
}

#[test]
fn does_not_error_if_impl_trait_constraint_is_satisfied_for_concrete_type() {
let src = r#"
pub trait Greeter {
fn greet(self);
}

pub trait Foo<T>
where
T: Greeter,
{
fn greet<U>(object: U)
where
U: Greeter,
{
object.greet();
}
}

pub struct SomeGreeter;
impl Greeter for SomeGreeter {
fn greet(self) {}
}

pub struct Bar;

impl Foo<SomeGreeter> for Bar {}

fn main() {}
"#;
assert_no_errors(src);
}

#[test]
fn does_not_error_if_impl_trait_constraint_is_satisfied_for_type_variable() {
let src = r#"
pub trait Greeter {
fn greet(self);
}

pub trait Foo<T> where T: Greeter {
fn greet(object: T) {
object.greet();
}
}

pub struct Bar;

impl<T> Foo<T> for Bar where T: Greeter {
}

fn main() {
}
"#;
assert_no_errors(src);
}
#[test]
fn errors_if_impl_trait_constraint_is_not_satisfied() {
let src = r#"
pub trait Greeter {
fn greet(self);
}

pub trait Foo<T>
where
T: Greeter,
{
fn greet<U>(object: U)
where
U: Greeter,
{
object.greet();
}
}

pub struct SomeGreeter;

pub struct Bar;

impl Foo<SomeGreeter> for Bar {}

fn main() {}
"#;
let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);

let CompilationError::ResolverError(ResolverError::TraitNotImplemented {
impl_trait,
missing_trait: the_trait,
type_missing_trait: typ,
..
}) = &errors[0].0
else {
panic!("Expected a TraitNotImplemented error, got {:?}", &errors[0].0);
};

assert_eq!(the_trait, "Greeter");
assert_eq!(typ, "SomeGreeter");
assert_eq!(impl_trait, "Foo");
}