Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion compiler/noirc_frontend/src/ast/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,15 @@ impl Display for TraitImplItemKind {
}
}

/// Does both `desugar_generic_trait_bounds` and `reorder_where_clause`.
pub(crate) fn desugar_generic_trait_bounds_and_reorder_where_clause(
generics: &mut Vec<UnresolvedGeneric>,
where_clause: &mut Vec<UnresolvedTraitConstraint>,
) {
desugar_generic_trait_bounds(generics, where_clause);
reorder_where_clause(where_clause);
}

/// Moves trait bounds from generics into where clauses. For example:
///
/// ```noir
Expand All @@ -300,7 +309,7 @@ impl Display for TraitImplItemKind {
/// ```noir
/// fn foo<T>(x: T) -> T where T: Trait {}
/// ```
pub(crate) fn desugar_generic_trait_bounds(
fn desugar_generic_trait_bounds(
generics: &mut Vec<UnresolvedGeneric>,
where_clause: &mut Vec<UnresolvedTraitConstraint>,
) {
Expand All @@ -323,3 +332,29 @@ pub(crate) fn desugar_generic_trait_bounds(
}
}
}

/// Reorders a where clause in-place so that simpler constraints come before more complex ones.
/// The resulting where clause will have constraints in this order:
/// 1. Paths without generics
/// 2. Paths with generics
/// 3. Everything else
fn reorder_where_clause(where_clause: &mut Vec<UnresolvedTraitConstraint>) {
let mut paths_without_generics = Vec::new();
let mut paths_with_generics = Vec::new();
let mut others = Vec::new();

for clause in std::mem::take(where_clause) {
if let UnresolvedTypeData::Named(_, generics, _) = &clause.typ.typ {
if generics.is_empty() {
paths_without_generics.push(clause);
} else {
paths_with_generics.push(clause);
}
} else {
others.push(clause);
}
}
where_clause.extend(paths_without_generics);
where_clause.extend(paths_with_generics);
where_clause.extend(others);
}
60 changes: 38 additions & 22 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1394,8 +1394,11 @@ impl<'context> Elaborator<'context> {
};
let trait_constraint_trait_name = trait_constraint_trait.name.to_string();

let trait_constraint_type = trait_constraint.typ.substitute(&bindings);
let trait_bound = &trait_constraint.trait_bound;
let mut trait_constraint = trait_constraint.clone();
trait_constraint.apply_bindings(&bindings);

let trait_constraint_type = trait_constraint.typ;
let trait_bound = trait_constraint.trait_bound;

let mut named_generics = trait_bound.trait_generics.named.clone();

Expand All @@ -1407,43 +1410,56 @@ impl<'context> Elaborator<'context> {
// so they'll unify (the bindings aren't applied here so this is fine).
// If they are bound though, we won't replace them as we want to ensure the binding
// matches.
//
// `bindings` is passed here because these implicitly added named generics might
// have a constraint on them later on and we want to remember what type they ended
// up being.
self.replace_implicitly_added_unbound_named_generics_with_fresh_type_variables(
&mut named_generics,
&mut bindings,
);

if self
.interner
.try_lookup_trait_implementation(
&trait_constraint_type,
trait_bound.trait_id,
&trait_bound.trait_generics.ordered,
&named_generics,
)
.is_err()
{
let missing_trait =
format!("{}{}", trait_constraint_trait_name, trait_bound.trait_generics);
self.push_err(ResolverError::TraitNotImplemented {
impl_trait: impl_trait.clone(),
missing_trait,
type_missing_trait: trait_constraint_type.to_string(),
location: trait_impl.object_type.location,
missing_trait_location: trait_bound.location,
});
match self.interner.try_lookup_trait_implementation(
&trait_constraint_type,
trait_bound.trait_id,
&trait_bound.trait_generics.ordered,
&named_generics,
) {
Ok((_, impl_bindings, impl_instantiation_bindings)) => {
bindings.extend(impl_bindings);
bindings.extend(impl_instantiation_bindings);
}
Err(_) => {
let missing_trait =
format!("{}{}", trait_constraint_trait_name, trait_bound.trait_generics);
self.push_err(ResolverError::TraitNotImplemented {
impl_trait: impl_trait.clone(),
missing_trait,
type_missing_trait: trait_constraint_type.to_string(),
location: trait_impl.object_type.location,
missing_trait_location: trait_bound.location,
});
}
}
}
}

fn replace_implicitly_added_unbound_named_generics_with_fresh_type_variables(
&mut self,
named_generics: &mut [NamedType],
bindings: &mut TypeBindings,
) {
for named_type in named_generics.iter_mut() {
match &named_type.typ {
Type::NamedGeneric(NamedGeneric { type_var, implicit: true, .. })
if type_var.borrow().is_unbound() =>
{
named_type.typ = self.interner.next_type_variable();
let type_var_id = type_var.id();
let new_type_var_id = self.interner.next_type_variable_id();
let kind = type_var.kind();
let new_type_var = TypeVariable::unbound(new_type_var_id, kind.clone());
named_type.typ = Type::TypeVariable(new_type_var.clone());
bindings.insert(type_var_id, (new_type_var, kind, named_type.typ.clone()));
}
_ => (),
};
Expand Down
23 changes: 15 additions & 8 deletions compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::ast::{
Documented, Expression, FunctionDefinition, Ident, ItemVisibility, LetStatement,
ModuleDeclaration, NoirEnumeration, NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl,
NoirTypeAlias, Pattern, TraitImplItemKind, TraitItem, TypeImpl, UnresolvedType,
UnresolvedTypeData, desugar_generic_trait_bounds,
UnresolvedTypeData, desugar_generic_trait_bounds_and_reorder_where_clause,
};
use crate::elaborator::PrimitiveType;
use crate::hir::resolution::errors::ResolverError;
Expand Down Expand Up @@ -177,7 +177,10 @@ impl ModCollector<'_> {
let module_id = ModuleId { krate, local_id: self.module_id };

for mut r#impl in impls {
desugar_generic_trait_bounds(&mut r#impl.generics, &mut r#impl.where_clause);
desugar_generic_trait_bounds_and_reorder_where_clause(
&mut r#impl.generics,
&mut r#impl.where_clause,
);

collect_impl(
&mut context.def_interner,
Expand All @@ -201,7 +204,7 @@ impl ModCollector<'_> {
let mut errors = Vec::new();

for mut trait_impl in impls {
desugar_generic_trait_bounds(
desugar_generic_trait_bounds_and_reorder_where_clause(
&mut trait_impl.impl_generics,
&mut trait_impl.where_clause,
);
Expand Down Expand Up @@ -298,7 +301,7 @@ impl ModCollector<'_> {
// With this method we iterate each function in the Crate and not each module
// This may not be great because we have to pull the module_data for each function
let mut noir_function = function.item;
desugar_generic_trait_bounds(
desugar_generic_trait_bounds_and_reorder_where_clause(
&mut noir_function.def.generics,
&mut noir_function.def.where_clause,
);
Expand Down Expand Up @@ -449,7 +452,7 @@ impl ModCollector<'_> {
let has_allow_dead_code =
trait_definition.attributes.iter().any(|attr| attr.kind.is_allow("dead_code"));

desugar_generic_trait_bounds(
desugar_generic_trait_bounds_and_reorder_where_clause(
&mut trait_definition.generics,
&mut trait_definition.where_clause,
);
Expand Down Expand Up @@ -516,7 +519,7 @@ impl ModCollector<'_> {

for item in &mut trait_definition.items {
if let TraitItem::Function { generics, where_clause, .. } = &mut item.item {
desugar_generic_trait_bounds(generics, where_clause);
desugar_generic_trait_bounds_and_reorder_where_clause(generics, where_clause);
}
}

Expand Down Expand Up @@ -1329,7 +1332,11 @@ pub fn collect_impl(

let func_id = interner.push_empty_fn();
method.def.where_clause.extend(r#impl.where_clause.clone());
desugar_generic_trait_bounds(&mut method.def.generics, &mut method.def.where_clause);
desugar_generic_trait_bounds_and_reorder_where_clause(
&mut method.def.generics,
&mut method.def.where_clause,
);

let location = method.location();
interner.push_function(func_id, &method.def, module_id, location);
unresolved_functions.push_fn(module_id.local_id, func_id, method);
Expand Down Expand Up @@ -1454,7 +1461,7 @@ pub(crate) fn collect_trait_impl_items(
let location = impl_method.location();
interner.push_function(func_id, &impl_method.def, module, location);
interner.set_doc_comments(ReferenceId::Function(func_id), item.doc_comments);
desugar_generic_trait_bounds(
desugar_generic_trait_bounds_and_reorder_where_clause(
&mut impl_method.def.generics,
&mut impl_method.def.where_clause,
);
Expand Down
86 changes: 86 additions & 0 deletions compiler/noirc_frontend/src/tests/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2088,3 +2088,89 @@ fn trait_method_call_when_it_has_bounds_on_generic() {
"#;
assert_no_errors!(src);
}

#[named]
#[test]
fn trait_bound_constraining_two_generics() {
let src = r#"
pub trait Foo<U> {}

pub trait Baz<T, U>
where
T: Foo<U>,
{}

pub struct HasFoo1 {}
impl Foo<()> for HasFoo1 {}

pub struct HasBaz1 {}
impl Baz<HasFoo1, ()> for HasBaz1 {}

fn main() {}
"#;
assert_no_errors!(src);
}

#[named]
#[test]
fn trait_where_clause_associated_type_constraint_expected_order() {
let src = r#"
pub trait BarTrait {}

pub trait Foo {
type Bar;
}

pub trait Baz<T>
where
T: Foo,
<T as Foo>::Bar: BarTrait,
{}

pub struct HasBarTrait1 {}
impl BarTrait for HasBarTrait1 {}

pub struct HasFoo1 {}
impl Foo for HasFoo1 {
type Bar = HasBarTrait1;
}

pub struct HasBaz1 {}
impl Baz<HasFoo1> for HasBaz1 {}

fn main() {}
"#;
assert_no_errors!(src);
}

#[named]
#[test]
fn trait_where_clause_associated_type_constraint_unexpected_order() {
let src = r#"
pub trait BarTrait {}

pub trait Foo {
type Bar;
}

pub trait Baz<T>
where
<T as Foo>::Bar: BarTrait,
T: Foo,
{}

pub struct HasBarTrait1 {}
impl BarTrait for HasBarTrait1 {}

pub struct HasFoo1 {}
impl Foo for HasFoo1 {
type Bar = HasBarTrait1;
}

pub struct HasBaz1 {}
impl Baz<HasFoo1> for HasBaz1 {}

fn main() {}
"#;
assert_no_errors!(src);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

[package]
name = "noirc_frontend_tests_traits_trait_bound_constraining_two_generics"
type = "bin"
authors = [""]

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

pub trait Foo<U> {}

pub trait Baz<T, U>
where
T: Foo<U>,
{}

pub struct HasFoo1 {}
impl Foo<()> for HasFoo1 {}

pub struct HasBaz1 {}
impl Baz<HasFoo1, ()> for HasBaz1 {}

fn main() {}

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
10102510944933110504
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

[package]
name = "noirc_frontend_tests_traits_trait_where_clause_associated_type_constraint_expected_order"
type = "bin"
authors = [""]

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

pub trait BarTrait {}

pub trait Foo {
type Bar;
}

pub trait Baz<T>
where
T: Foo,
<T as Foo>::Bar: BarTrait,
{}

pub struct HasBarTrait1 {}
impl BarTrait for HasBarTrait1 {}

pub struct HasFoo1 {}
impl Foo for HasFoo1 {
type Bar = HasBarTrait1;
}

pub struct HasBaz1 {}
impl Baz<HasFoo1> for HasBaz1 {}

fn main() {}

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
17365504020734663832
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

[package]
name = "noirc_frontend_tests_traits_trait_where_clause_associated_type_constraint_unexpected_order"
type = "bin"
authors = [""]

[dependencies]
Loading
Loading