Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Trait inheritance with cycles on associated types #79209

Merged
merged 28 commits into from
Nov 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
24dcf6f
Allow to use super trait bounds in where clauses
spastorino Nov 3, 2020
2ca4964
Allow to self reference associated types in where clauses
spastorino Nov 13, 2020
4406805
Update tests
spastorino Nov 17, 2020
4a97c52
Add test to check that we do not get a cycle due to resolving `Self::…
nikomatsakis Nov 14, 2020
5b6f206
Fix super_traits_of to consider where bounds
spastorino Nov 19, 2020
3c8cf6d
Avoid ICEing when associated type bound trait is missing
spastorino Nov 19, 2020
a6d2235
Add test for a still ambiguous scenario
spastorino Nov 20, 2020
f76d6e0
Add description on test
spastorino Nov 20, 2020
aa1cafd
Add `` to variable name in docs
spastorino Nov 20, 2020
30e933c
Extract trait_may_define_assoc_type helper function
spastorino Nov 20, 2020
6631215
Remove unneeded logic
spastorino Nov 20, 2020
b60a214
super_traits_of is now a query
spastorino Nov 21, 2020
c0007a2
Extract function trait_may_define_assoc_type
spastorino Nov 21, 2020
dd267fe
compute_bounds takes &[GenericBound]
spastorino Nov 21, 2020
6ab8fe2
Bless tests
spastorino Nov 21, 2020
b916ac6
adjust super_predicates_that_define_assoc_type query description
spastorino Nov 22, 2020
af38d71
Add super_traits_of docs
spastorino Nov 22, 2020
a175f36
Document compute_bounds_that_match_assoc_type
spastorino Nov 24, 2020
9e0538b
Document elaborate_trait_refs_that_define_assoc_type
spastorino Nov 24, 2020
28446ef
Inline elaborate_trait_refs_that_define_assoc_type into transitive_bo…
spastorino Nov 24, 2020
ac1845a
Fix super_predicates_that_define_assoc_type API doc
spastorino Nov 24, 2020
1895e52
Fix super_traits_of API doc
spastorino Nov 24, 2020
67ea9b2
Make super_traits_of return Lrc for cheaper clone
spastorino Nov 24, 2020
b02a905
Add test to check that we handle predicates that can define assoc typ…
spastorino Nov 25, 2020
a6136d8
Simplify super_traits_of
spastorino Nov 25, 2020
35bf466
Remove super_traits_of query, just leave a helper function
spastorino Nov 25, 2020
504d27c
Make super_traits_of return an iterator
spastorino Nov 25, 2020
ada7c1f
Return FxIndexSet instead of FxHashSet to avoid order errors on diffe…
spastorino Nov 27, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion compiler/rustc_infer/src/traits/util.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use smallvec::smallvec;

use crate::traits::{Obligation, ObligationCause, PredicateObligation};
use rustc_data_structures::fx::FxHashSet;
use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
use rustc_middle::ty::outlives::Component;
use rustc_middle::ty::{self, ToPredicate, TyCtxt, WithConstness};
use rustc_span::symbol::Ident;

pub fn anonymize_predicate<'tcx>(
tcx: TyCtxt<'tcx>,
Expand Down Expand Up @@ -287,6 +288,37 @@ pub fn transitive_bounds<'tcx>(
elaborate_trait_refs(tcx, bounds).filter_to_traits()
}

/// A specialized variant of `elaborate_trait_refs` that only elaborates trait references that may
/// define the given associated type `assoc_name`. It uses the
/// `super_predicates_that_define_assoc_type` query to avoid enumerating super-predicates that
/// aren't related to `assoc_item`. This is used when resolving types like `Self::Item` or
/// `T::Item` and helps to avoid cycle errors (see e.g. #35237).
pub fn transitive_bounds_that_define_assoc_type<'tcx>(
tcx: TyCtxt<'tcx>,
bounds: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
assoc_name: Ident,
) -> FxIndexSet<ty::PolyTraitRef<'tcx>> {
let mut stack: Vec<_> = bounds.collect();
let mut trait_refs = FxIndexSet::default();

while let Some(trait_ref) = stack.pop() {
if trait_refs.insert(trait_ref) {
let super_predicates =
tcx.super_predicates_that_define_assoc_type((trait_ref.def_id(), Some(assoc_name)));
for (super_predicate, _) in super_predicates.predicates {
let bound_predicate = super_predicate.bound_atom();
let subst_predicate = super_predicate
.subst_supertrait(tcx, &bound_predicate.rebind(trait_ref.skip_binder()));
if let Some(binder) = subst_predicate.to_opt_poly_trait_ref() {
stack.push(binder.value);
}
}
}
}

trait_refs
}

///////////////////////////////////////////////////////////////////////////
// Other
///////////////////////////////////////////////////////////////////////////
Expand Down
15 changes: 13 additions & 2 deletions compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,12 +433,23 @@ rustc_queries! {
/// full predicates are available (note that supertraits have
/// additional acyclicity requirements).
query super_predicates_of(key: DefId) -> ty::GenericPredicates<'tcx> {
desc { |tcx| "computing the supertraits of `{}`", tcx.def_path_str(key) }
desc { |tcx| "computing the super predicates of `{}`", tcx.def_path_str(key) }
}

/// The `Option<Ident>` is the name of an associated type. If it is `None`, then this query
/// returns the full set of predicates. If `Some<Ident>`, then the query returns only the
/// subset of super-predicates that reference traits that define the given associated type.
/// This is used to avoid cycles in resolving types like `T::Item`.
query super_predicates_that_define_assoc_type(key: (DefId, Option<rustc_span::symbol::Ident>)) -> ty::GenericPredicates<'tcx> {
desc { |tcx| "computing the super traits of `{}`{}",
tcx.def_path_str(key.0),
if let Some(assoc_name) = key.1 { format!(" with associated type name `{}`", assoc_name) } else { "".to_string() },
}
}

/// To avoid cycles within the predicates of a single item we compute
/// per-type-parameter predicates for resolving `T::AssocTy`.
query type_param_predicates(key: (DefId, LocalDefId)) -> ty::GenericPredicates<'tcx> {
query type_param_predicates(key: (DefId, LocalDefId, rustc_span::symbol::Ident)) -> ty::GenericPredicates<'tcx> {
desc { |tcx| "computing the bounds for type parameter `{}`", {
let id = tcx.hir().local_def_id_to_hir_id(key.1);
tcx.hir().ty_param_name(id)
Expand Down
38 changes: 37 additions & 1 deletion compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ use rustc_session::config::{BorrowckMode, CrateType, OutputFilenames};
use rustc_session::lint::{Level, Lint};
use rustc_session::Session;
use rustc_span::source_map::MultiSpan;
use rustc_span::symbol::{kw, sym, Symbol};
use rustc_span::symbol::{kw, sym, Ident, Symbol};
use rustc_span::{Span, DUMMY_SP};
use rustc_target::abi::{Layout, TargetDataLayout, VariantIdx};
use rustc_target::spec::abi;
Expand Down Expand Up @@ -2085,6 +2085,42 @@ impl<'tcx> TyCtxt<'tcx> {
self.mk_fn_ptr(sig.map_bound(|sig| ty::FnSig { unsafety: hir::Unsafety::Unsafe, ..sig }))
}

/// Given the def_id of a Trait `trait_def_id` and the name of an associated item `assoc_name`
/// returns true if the `trait_def_id` defines an associated item of name `assoc_name`.
pub fn trait_may_define_assoc_type(self, trait_def_id: DefId, assoc_name: Ident) -> bool {
self.super_traits_of(trait_def_id).any(|trait_did| {
self.associated_items(trait_did)
.find_by_name_and_kind(self, assoc_name, ty::AssocKind::Type, trait_did)
.is_some()
})
}

/// Computes the def-ids of the transitive super-traits of `trait_def_id`. This (intentionally)
/// does not compute the full elaborated super-predicates but just the set of def-ids. It is used
/// to identify which traits may define a given associated type to help avoid cycle errors.
/// Returns a `DefId` iterator.
fn super_traits_of(self, trait_def_id: DefId) -> impl Iterator<Item = DefId> + 'tcx {
let mut set = FxHashSet::default();
let mut stack = vec![trait_def_id];

set.insert(trait_def_id);

iter::from_fn(move || -> Option<DefId> {
let trait_did = stack.pop()?;
let generic_predicates = self.super_predicates_of(trait_did);

for (predicate, _) in generic_predicates.predicates {
if let ty::PredicateAtom::Trait(data, _) = predicate.skip_binders() {
if set.insert(data.def_id()) {
stack.push(data.def_id());
}
}
}

Some(trait_did)
})
}

/// Given a closure signature, returns an equivalent fn signature. Detuples
/// and so forth -- so e.g., if we have a sig with `Fn<(u32, i32)>` then
/// you would get a `fn(u32, i32)`.
Expand Down
24 changes: 23 additions & 1 deletion compiler/rustc_middle/src/ty/query/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::ty::subst::{GenericArg, SubstsRef};
use crate::ty::{self, Ty, TyCtxt};
use rustc_hir::def_id::{CrateNum, DefId, LocalDefId, LOCAL_CRATE};
use rustc_query_system::query::DefaultCacheSelector;
use rustc_span::symbol::Symbol;
use rustc_span::symbol::{Ident, Symbol};
use rustc_span::{Span, DUMMY_SP};

/// The `Key` trait controls what types can legally be used as the key
Expand Down Expand Up @@ -149,6 +149,28 @@ impl Key for (LocalDefId, DefId) {
}
}

impl Key for (DefId, Option<Ident>) {
type CacheSelector = DefaultCacheSelector;

fn query_crate(&self) -> CrateNum {
self.0.krate
}
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
tcx.def_span(self.0)
}
}

impl Key for (DefId, LocalDefId, Ident) {
type CacheSelector = DefaultCacheSelector;

fn query_crate(&self) -> CrateNum {
self.0.krate
}
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
self.1.default_span(tcx)
}
}

impl Key for (CrateNum, DefId) {
type CacheSelector = DefaultCacheSelector;

Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_trait_selection/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ pub use self::util::{
get_vtable_index_of_object_method, impl_item_is_final, predicate_for_trait_def, upcast_choices,
};
pub use self::util::{
supertrait_def_ids, supertraits, transitive_bounds, SupertraitDefIds, Supertraits,
supertrait_def_ids, supertraits, transitive_bounds, transitive_bounds_that_define_assoc_type,
SupertraitDefIds, Supertraits,
};

pub use self::chalk_fulfill::FulfillmentContext as ChalkFulfillmentContext;
Expand Down
68 changes: 57 additions & 11 deletions compiler/rustc_typeck/src/astconv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ pub trait AstConv<'tcx> {

fn default_constness_for_trait_bounds(&self) -> Constness;

/// Returns predicates in scope of the form `X: Foo`, where `X` is
/// a type parameter `X` with the given id `def_id`. This is a
/// subset of the full set of predicates.
/// Returns predicates in scope of the form `X: Foo<T>`, where `X`
/// is a type parameter `X` with the given id `def_id` and T
/// matches `assoc_name`. This is a subset of the full set of
/// predicates.
///
/// This is used for one specific purpose: resolving "short-hand"
/// associated type references like `T::Item`. In principle, we
Expand All @@ -60,7 +61,12 @@ pub trait AstConv<'tcx> {
/// but this can lead to cycle errors. The problem is that we have
/// to do this resolution *in order to create the predicates in
/// the first place*. Hence, we have this "special pass".
fn get_type_parameter_bounds(&self, span: Span, def_id: DefId) -> ty::GenericPredicates<'tcx>;
fn get_type_parameter_bounds(
&self,
span: Span,
def_id: DefId,
assoc_name: Ident,
) -> ty::GenericPredicates<'tcx>;

/// Returns the lifetime to use when a lifetime is omitted (and not elided).
fn re_infer(&self, param: Option<&ty::GenericParamDef>, span: Span)
Expand Down Expand Up @@ -762,7 +768,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
}

// Returns `true` if a bounds list includes `?Sized`.
pub fn is_unsized(&self, ast_bounds: &[hir::GenericBound<'_>], span: Span) -> bool {
pub fn is_unsized(&self, ast_bounds: &[&hir::GenericBound<'_>], span: Span) -> bool {
let tcx = self.tcx();

// Try to find an unbound in bounds.
Expand Down Expand Up @@ -820,7 +826,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
fn add_bounds(
&self,
param_ty: Ty<'tcx>,
ast_bounds: &[hir::GenericBound<'_>],
ast_bounds: &[&hir::GenericBound<'_>],
bounds: &mut Bounds<'tcx>,
) {
let mut trait_bounds = Vec::new();
Expand All @@ -838,7 +844,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
hir::GenericBound::Trait(_, hir::TraitBoundModifier::Maybe) => {}
hir::GenericBound::LangItemTrait(lang_item, span, hir_id, args) => self
.instantiate_lang_item_trait_ref(
lang_item, span, hir_id, args, param_ty, bounds,
*lang_item, *span, *hir_id, args, param_ty, bounds,
),
hir::GenericBound::Outlives(ref l) => region_bounds.push(l),
}
Expand Down Expand Up @@ -875,6 +881,42 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
ast_bounds: &[hir::GenericBound<'_>],
sized_by_default: SizedByDefault,
span: Span,
) -> Bounds<'tcx> {
let ast_bounds: Vec<_> = ast_bounds.iter().collect();
self.compute_bounds_inner(param_ty, &ast_bounds, sized_by_default, span)
}

/// Convert the bounds in `ast_bounds` that refer to traits which define an associated type
/// named `assoc_name` into ty::Bounds. Ignore the rest.
pub fn compute_bounds_that_match_assoc_type(
spastorino marked this conversation as resolved.
Show resolved Hide resolved
&self,
param_ty: Ty<'tcx>,
ast_bounds: &[hir::GenericBound<'_>],
sized_by_default: SizedByDefault,
span: Span,
assoc_name: Ident,
) -> Bounds<'tcx> {
let mut result = Vec::new();

for ast_bound in ast_bounds {
if let Some(trait_ref) = ast_bound.trait_ref() {
if let Some(trait_did) = trait_ref.trait_def_id() {
if self.tcx().trait_may_define_assoc_type(trait_did, assoc_name) {
result.push(ast_bound);
}
}
}
}

self.compute_bounds_inner(param_ty, &result, sized_by_default, span)
}

fn compute_bounds_inner(
&self,
param_ty: Ty<'tcx>,
ast_bounds: &[&hir::GenericBound<'_>],
sized_by_default: SizedByDefault,
span: Span,
) -> Bounds<'tcx> {
let mut bounds = Bounds::default();

Expand Down Expand Up @@ -1044,7 +1086,8 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
// Calling `skip_binder` is okay, because `add_bounds` expects the `param_ty`
// parameter to have a skipped binder.
let param_ty = tcx.mk_projection(assoc_ty.def_id, candidate.skip_binder().substs);
self.add_bounds(param_ty, ast_bounds, bounds);
let ast_bounds: Vec<_> = ast_bounds.iter().collect();
self.add_bounds(param_ty, &ast_bounds, bounds);
}
}
Ok(())
Expand Down Expand Up @@ -1361,21 +1404,24 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
ty_param_def_id, assoc_name, span,
);

let predicates =
&self.get_type_parameter_bounds(span, ty_param_def_id.to_def_id()).predicates;
let predicates = &self
.get_type_parameter_bounds(span, ty_param_def_id.to_def_id(), assoc_name)
.predicates;

debug!("find_bound_for_assoc_item: predicates={:#?}", predicates);

let param_hir_id = tcx.hir().local_def_id_to_hir_id(ty_param_def_id);
let param_name = tcx.hir().ty_param_name(param_hir_id);
self.one_bound_for_assoc_type(
|| {
traits::transitive_bounds(
traits::transitive_bounds_that_define_assoc_type(
tcx,
predicates.iter().filter_map(|(p, _)| {
p.to_opt_poly_trait_ref().map(|trait_ref| trait_ref.value)
}),
assoc_name,
)
.into_iter()
},
|| param_name.to_string(),
assoc_name,
Expand Down
8 changes: 7 additions & 1 deletion compiler/rustc_typeck/src/check/fn_ctxt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use rustc_middle::ty::fold::TypeFoldable;
use rustc_middle::ty::subst::GenericArgKind;
use rustc_middle::ty::{self, Const, Ty, TyCtxt};
use rustc_session::Session;
use rustc_span::symbol::Ident;
use rustc_span::{self, Span};
use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode};

Expand Down Expand Up @@ -183,7 +184,12 @@ impl<'a, 'tcx> AstConv<'tcx> for FnCtxt<'a, 'tcx> {
}
}

fn get_type_parameter_bounds(&self, _: Span, def_id: DefId) -> ty::GenericPredicates<'tcx> {
fn get_type_parameter_bounds(
&self,
_: Span,
def_id: DefId,
_: Ident,
) -> ty::GenericPredicates<'tcx> {
let tcx = self.tcx;
let hir_id = tcx.hir().local_def_id_to_hir_id(def_id.expect_local());
let item_id = tcx.hir().ty_param_owner(hir_id);
Expand Down
Loading