Skip to content

Commit

Permalink
Auto merge of rust-lang#113308 - compiler-errors:poly-select, r=lcnr
Browse files Browse the repository at this point in the history
Split `SelectionContext::select` into fns that take a binder and don't

*most* usages of `SelectionContext::select` don't need to use a binder, but wrap them in a dummy because of the signature. Let's split this out into `SelectionContext::{select,poly_select}` and limit the usages of the latter.

Right now, we only have 3 places where we're calling `poly_select` -- fulfillment, internally within the old solver, and the auto-trait finder.

r? `@lcnr`
  • Loading branch information
bors committed Jul 7, 2023
2 parents 921f669 + 3f8919c commit 1a449dc
Show file tree
Hide file tree
Showing 21 changed files with 125 additions and 112 deletions.
13 changes: 6 additions & 7 deletions compiler/rustc_const_eval/src/transform/check_consts/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceC
use rustc_middle::mir::*;
use rustc_middle::ty::subst::{GenericArgKind, InternalSubsts};
use rustc_middle::ty::{self, adjustment::PointerCast, Instance, InstanceDef, Ty, TyCtxt};
use rustc_middle::ty::{Binder, TraitRef, TypeVisitableExt};
use rustc_middle::ty::{TraitRef, TypeVisitableExt};
use rustc_mir_dataflow::{self, Analysis};
use rustc_span::{sym, Span, Symbol};
use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt as _;
Expand Down Expand Up @@ -755,10 +755,9 @@ impl<'tcx> Visitor<'tcx> for Checker<'_, 'tcx> {
}

let trait_ref = TraitRef::from_method(tcx, trait_id, substs);
let poly_trait_pred =
Binder::dummy(trait_ref).with_constness(ty::BoundConstness::ConstIfConst);
let trait_ref = trait_ref.with_constness(ty::BoundConstness::ConstIfConst);
let obligation =
Obligation::new(tcx, ObligationCause::dummy(), param_env, poly_trait_pred);
Obligation::new(tcx, ObligationCause::dummy(), param_env, trait_ref);

let implsrc = {
let infcx = tcx.infer_ctxt().build();
Expand All @@ -776,11 +775,11 @@ impl<'tcx> Visitor<'tcx> for Checker<'_, 'tcx> {
}
// Closure: Fn{Once|Mut}
Ok(Some(ImplSource::Builtin(_)))
if poly_trait_pred.self_ty().skip_binder().is_closure()
if trait_ref.self_ty().is_closure()
&& tcx.fn_trait_kind_from_def_id(trait_id).is_some() =>
{
let ty::Closure(closure_def_id, substs) =
*poly_trait_pred.self_ty().no_bound_vars().unwrap().kind()
*trait_ref.self_ty().kind()
else {
unreachable!()
};
Expand Down Expand Up @@ -840,7 +839,7 @@ impl<'tcx> Visitor<'tcx> for Checker<'_, 'tcx> {
tcx,
ObligationCause::dummy_with_span(*fn_span),
param_env,
poly_trait_pred,
trait_ref,
);

// improve diagnostics by showing what failed. Our requirements are stricter this time
Expand Down
10 changes: 3 additions & 7 deletions compiler/rustc_const_eval/src/transform/check_consts/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use rustc_infer::traits::{ImplSource, Obligation, ObligationCause};
use rustc_middle::mir::{self, CallSource};
use rustc_middle::ty::print::with_no_trimmed_paths;
use rustc_middle::ty::subst::{GenericArgKind, SubstsRef};
use rustc_middle::ty::TraitRef;
use rustc_middle::ty::{suggest_constraining_type_param, Adt, Closure, FnDef, FnPtr, Param, Ty};
use rustc_middle::ty::{Binder, TraitRef};
use rustc_middle::util::{call_kind, CallDesugaringKind, CallKind};
use rustc_session::parse::feature_err;
use rustc_span::symbol::sym;
Expand Down Expand Up @@ -137,12 +137,8 @@ impl<'tcx> NonConstOp<'tcx> for FnCallNonConst<'tcx> {
}
}
Adt(..) => {
let obligation = Obligation::new(
tcx,
ObligationCause::dummy(),
param_env,
Binder::dummy(trait_ref),
);
let obligation =
Obligation::new(tcx, ObligationCause::dummy(), param_env, trait_ref);

let infcx = tcx.infer_ctxt().build();
let mut selcx = SelectionContext::new(&infcx);
Expand Down
11 changes: 5 additions & 6 deletions compiler/rustc_hir_typeck/src/coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -632,9 +632,8 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {
while !queue.is_empty() {
let obligation = queue.remove(0);
debug!("coerce_unsized resolve step: {:?}", obligation);
let bound_predicate = obligation.predicate.kind();
let trait_pred = match bound_predicate.skip_binder() {
ty::PredicateKind::Clause(ty::ClauseKind::Trait(trait_pred))
let trait_pred = match obligation.predicate.kind().no_bound_vars() {
Some(ty::PredicateKind::Clause(ty::ClauseKind::Trait(trait_pred)))
if traits.contains(&trait_pred.def_id()) =>
{
if unsize_did == trait_pred.def_id() {
Expand All @@ -652,7 +651,7 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {
has_unsized_tuple_coercion = true;
}
}
bound_predicate.rebind(trait_pred)
trait_pred
}
_ => {
coercion.obligations.push(obligation);
Expand All @@ -664,8 +663,8 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {
Ok(None) => {
if trait_pred.def_id() == unsize_did {
let trait_pred = self.resolve_vars_if_possible(trait_pred);
let self_ty = trait_pred.skip_binder().self_ty();
let unsize_ty = trait_pred.skip_binder().trait_ref.substs[1].expect_ty();
let self_ty = trait_pred.self_ty();
let unsize_ty = trait_pred.trait_ref.substs[1].expect_ty();
debug!("coerce_unsized: ambiguous unsize case for {:?}", trait_pred);
match (self_ty.kind(), unsize_ty.kind()) {
(&ty::Infer(ty::TyVar(v)), ty::Dynamic(..))
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1982,7 +1982,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
self.tcx,
traits::ObligationCause::dummy(),
self.param_env,
ty::Binder::dummy(trait_ref),
trait_ref,
);
match SelectionContext::new(&self).select(&obligation) {
Ok(Some(traits::ImplSource::UserDefined(impl_source))) => {
Expand Down
3 changes: 1 addition & 2 deletions compiler/rustc_hir_typeck/src/method/probe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1441,8 +1441,7 @@ impl<'a, 'tcx> ProbeContext<'a, 'tcx> {
trait_ref: ty::TraitRef<'tcx>,
) -> traits::SelectionResult<'tcx, traits::Selection<'tcx>> {
let cause = traits::ObligationCause::misc(self.span, self.body_id);
let predicate = ty::Binder::dummy(trait_ref);
let obligation = traits::Obligation::new(self.tcx, cause, self.param_env, predicate);
let obligation = traits::Obligation::new(self.tcx, cause, self.param_env, trait_ref);
traits::SelectionContext::new(self).select(&obligation)
}

Expand Down
7 changes: 4 additions & 3 deletions compiler/rustc_infer/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ impl<'tcx, P> From<Obligation<'tcx, P>> for solve::Goal<'tcx, P> {
}

pub type PredicateObligation<'tcx> = Obligation<'tcx, ty::Predicate<'tcx>>;
pub type TraitObligation<'tcx> = Obligation<'tcx, ty::PolyTraitPredicate<'tcx>>;
pub type TraitObligation<'tcx> = Obligation<'tcx, ty::TraitPredicate<'tcx>>;
pub type PolyTraitObligation<'tcx> = Obligation<'tcx, ty::PolyTraitPredicate<'tcx>>;

impl<'tcx> PredicateObligation<'tcx> {
/// Flips the polarity of the inner predicate.
Expand All @@ -86,7 +87,7 @@ impl<'tcx> PredicateObligation<'tcx> {
}
}

impl<'tcx> TraitObligation<'tcx> {
impl<'tcx> PolyTraitObligation<'tcx> {
/// Returns `true` if the trait predicate is considered `const` in its ParamEnv.
pub fn is_const(&self) -> bool {
matches!(
Expand Down Expand Up @@ -193,7 +194,7 @@ impl<'tcx> FulfillmentError<'tcx> {
}
}

impl<'tcx> TraitObligation<'tcx> {
impl<'tcx> PolyTraitObligation<'tcx> {
pub fn polarity(&self) -> ty::ImplPolarity {
self.predicate.skip_binder().polarity
}
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_middle/src/query/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,11 +318,11 @@ impl<'tcx> Key for (LocalDefId, DefId, SubstsRef<'tcx>) {
}
}

impl<'tcx> Key for (ty::ParamEnv<'tcx>, ty::PolyTraitRef<'tcx>) {
impl<'tcx> Key for (ty::ParamEnv<'tcx>, ty::TraitRef<'tcx>) {
type CacheSelector = DefaultCacheSelector<Self>;

fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
tcx.def_span(self.1.def_id())
tcx.def_span(self.1.def_id)
}
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1278,7 +1278,7 @@ rustc_queries! {
}

query codegen_select_candidate(
key: (ty::ParamEnv<'tcx>, ty::PolyTraitRef<'tcx>)
key: (ty::ParamEnv<'tcx>, ty::TraitRef<'tcx>)
) -> Result<&'tcx ImplSource<'tcx, ()>, CodegenObligationError> {
cache_on_disk_if { true }
desc { |tcx| "computing candidate for `{}`", key.1 }
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,13 @@ impl<'tcx> ToPredicate<'tcx> for TraitRef<'tcx> {
}
}

impl<'tcx> ToPredicate<'tcx, TraitPredicate<'tcx>> for TraitRef<'tcx> {
#[inline(always)]
fn to_predicate(self, _tcx: TyCtxt<'tcx>) -> TraitPredicate<'tcx> {
self.without_const()
}
}

impl<'tcx> ToPredicate<'tcx, Clause<'tcx>> for TraitRef<'tcx> {
#[inline(always)]
fn to_predicate(self, tcx: TyCtxt<'tcx>) -> Clause<'tcx> {
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_monomorphize/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ fn custom_coerce_unsize_info<'tcx>(
source_ty: Ty<'tcx>,
target_ty: Ty<'tcx>,
) -> CustomCoerceUnsized {
let trait_ref = ty::Binder::dummy(ty::TraitRef::from_lang_item(
let trait_ref = ty::TraitRef::from_lang_item(
tcx.tcx,
LangItem::CoerceUnsized,
tcx.span,
[source_ty, target_ty],
));
);

match tcx.codegen_select_candidate((ty::ParamEnv::reveal_all(), trait_ref)) {
Ok(traits::ImplSource::UserDefined(traits::ImplSourceUserDefinedData {
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_trait_selection/src/solve/eval_ctxt/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rustc_hir::def_id::DefId;
use rustc_infer::infer::{DefineOpaqueTypes, InferCtxt, InferOk};
use rustc_infer::traits::util::supertraits;
use rustc_infer::traits::{
Obligation, PredicateObligation, Selection, SelectionResult, TraitObligation,
Obligation, PolyTraitObligation, PredicateObligation, Selection, SelectionResult,
};
use rustc_middle::traits::solve::{CanonicalInput, Certainty, Goal};
use rustc_middle::traits::{
Expand All @@ -23,14 +23,14 @@ use crate::traits::vtable::{count_own_vtable_entries, prepare_vtable_segments, V
pub trait InferCtxtSelectExt<'tcx> {
fn select_in_new_trait_solver(
&self,
obligation: &TraitObligation<'tcx>,
obligation: &PolyTraitObligation<'tcx>,
) -> SelectionResult<'tcx, Selection<'tcx>>;
}

impl<'tcx> InferCtxtSelectExt<'tcx> for InferCtxt<'tcx> {
fn select_in_new_trait_solver(
&self,
obligation: &TraitObligation<'tcx>,
obligation: &PolyTraitObligation<'tcx>,
) -> SelectionResult<'tcx, Selection<'tcx>> {
assert!(self.next_trait_solver());

Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_trait_selection/src/traits/auto_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,15 @@ impl<'tcx> AutoTraitFinder<'tcx> {
tcx,
ObligationCause::dummy(),
orig_env,
ty::Binder::dummy(ty::TraitPredicate {
ty::TraitPredicate {
trait_ref,
constness: ty::BoundConstness::NotConst,
polarity: if polarity {
ImplPolarity::Positive
} else {
ImplPolarity::Negative
},
}),
},
));
if let Ok(Some(ImplSource::UserDefined(_))) = result {
debug!(
Expand Down Expand Up @@ -292,7 +292,7 @@ impl<'tcx> AutoTraitFinder<'tcx> {
new_env,
pred,
));
let result = select.select(&obligation);
let result = select.poly_select(&obligation);

match result {
Ok(Some(ref impl_source)) => {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use rustc_hir::def_id::DefId;
use rustc_infer::infer::{InferCtxt, LateBoundRegionConversionTime};
use rustc_infer::traits::util::elaborate;
use rustc_infer::traits::{Obligation, ObligationCause, TraitObligation};
use rustc_infer::traits::{Obligation, ObligationCause, PolyTraitObligation};
use rustc_middle::ty;
use rustc_span::{Span, DUMMY_SP};

Expand All @@ -14,7 +14,7 @@ pub enum Ambiguity {

pub fn recompute_applicable_impls<'tcx>(
infcx: &InferCtxt<'tcx>,
obligation: &TraitObligation<'tcx>,
obligation: &PolyTraitObligation<'tcx>,
) -> Vec<Ambiguity> {
let tcx = infcx.tcx;
let param_env = obligation.param_env;
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_trait_selection/src/traits/fulfill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rustc_data_structures::obligation_forest::{Error, ForestObligation, Outcome}
use rustc_data_structures::obligation_forest::{ObligationForest, ObligationProcessor};
use rustc_infer::infer::DefineOpaqueTypes;
use rustc_infer::traits::ProjectionCacheKey;
use rustc_infer::traits::{SelectionError, TraitEngine, TraitObligation};
use rustc_infer::traits::{PolyTraitObligation, SelectionError, TraitEngine};
use rustc_middle::mir::interpret::ErrorHandled;
use rustc_middle::ty::abstract_const::NotConstEvaluatable;
use rustc_middle::ty::error::{ExpectedFound, TypeError};
Expand Down Expand Up @@ -667,7 +667,7 @@ impl<'a, 'tcx> FulfillProcessor<'a, 'tcx> {
fn process_trait_obligation(
&mut self,
obligation: &PredicateObligation<'tcx>,
trait_obligation: TraitObligation<'tcx>,
trait_obligation: PolyTraitObligation<'tcx>,
stalled_on: &mut Vec<TyOrConstInferVar<'tcx>>,
) -> ProcessResult<PendingPredicateObligation<'tcx>, FulfillmentErrorCode<'tcx>> {
let infcx = self.selcx.infcx;
Expand All @@ -683,7 +683,7 @@ impl<'a, 'tcx> FulfillProcessor<'a, 'tcx> {
}
}

match self.selcx.select(&trait_obligation) {
match self.selcx.poly_select(&trait_obligation) {
Ok(Some(impl_source)) => {
debug!("selecting trait at depth {} yielded Ok(Some)", obligation.recursion_depth);
ProcessResult::Changed(mk_pending(impl_source.nested_obligations()))
Expand Down
17 changes: 8 additions & 9 deletions compiler/rustc_trait_selection/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1545,7 +1545,6 @@ fn assemble_candidate_for_impl_trait_in_trait<'cx, 'tcx>(
let trait_def_id = tcx.parent(trait_fn_def_id);
let trait_substs =
obligation.predicate.substs.truncate_to(tcx, tcx.generics_of(trait_def_id));
// FIXME(named-returns): Binders
let trait_predicate = ty::TraitRef::new(tcx, trait_def_id, trait_substs);

let _ = selcx.infcx.commit_if_ok(|_| {
Expand Down Expand Up @@ -1747,8 +1746,8 @@ fn assemble_candidates_from_impls<'cx, 'tcx>(

// If we are resolving `<T as TraitRef<...>>::Item == Type`,
// start out by selecting the predicate `T as TraitRef<...>`:
let poly_trait_ref = ty::Binder::dummy(obligation.predicate.trait_ref(selcx.tcx()));
let trait_obligation = obligation.with(selcx.tcx(), poly_trait_ref);
let trait_ref = obligation.predicate.trait_ref(selcx.tcx());
let trait_obligation = obligation.with(selcx.tcx(), trait_ref);
let _ = selcx.infcx.commit_if_ok(|_| {
let impl_source = match selcx.select(&trait_obligation) {
Ok(Some(impl_source)) => impl_source,
Expand Down Expand Up @@ -1802,7 +1801,7 @@ fn assemble_candidates_from_impls<'cx, 'tcx>(
if obligation.param_env.reveal() == Reveal::All {
// NOTE(eddyb) inference variables can resolve to parameters, so
// assume `poly_trait_ref` isn't monomorphic, if it contains any.
let poly_trait_ref = selcx.infcx.resolve_vars_if_possible(poly_trait_ref);
let poly_trait_ref = selcx.infcx.resolve_vars_if_possible(trait_ref);
!poly_trait_ref.still_further_specializable()
} else {
debug!(
Expand All @@ -1821,11 +1820,11 @@ fn assemble_candidates_from_impls<'cx, 'tcx>(
let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty());

let lang_items = selcx.tcx().lang_items();
if [lang_items.gen_trait(), lang_items.future_trait()].contains(&Some(poly_trait_ref.def_id()))
|| selcx.tcx().fn_trait_kind_from_def_id(poly_trait_ref.def_id()).is_some()
if [lang_items.gen_trait(), lang_items.future_trait()].contains(&Some(trait_ref.def_id))
|| selcx.tcx().fn_trait_kind_from_def_id(trait_ref.def_id).is_some()
{
true
} else if lang_items.discriminant_kind_trait() == Some(poly_trait_ref.def_id()) {
} else if lang_items.discriminant_kind_trait() == Some(trait_ref.def_id) {
match self_ty.kind() {
ty::Bool
| ty::Char
Expand Down Expand Up @@ -1860,7 +1859,7 @@ fn assemble_candidates_from_impls<'cx, 'tcx>(
| ty::Infer(..)
| ty::Error(_) => false,
}
} else if lang_items.pointee_trait() == Some(poly_trait_ref.def_id()) {
} else if lang_items.pointee_trait() == Some(trait_ref.def_id) {
let tail = selcx.tcx().struct_tail_with_normalize(
self_ty,
|ty| {
Expand Down Expand Up @@ -1935,7 +1934,7 @@ fn assemble_candidates_from_impls<'cx, 'tcx>(
}
}
} else {
bug!("unexpected builtin trait with associated type: {poly_trait_ref:?}")
bug!("unexpected builtin trait with associated type: {trait_ref:?}")
}
}
super::ImplSource::Param(..) => {
Expand Down
Loading

0 comments on commit 1a449dc

Please sign in to comment.