Skip to content

Commit

Permalink
Use ObligationCtxt in favor of TraitEngine in many places
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed May 1, 2024
1 parent be9bca2 commit 4fc00ec
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 103 deletions.
37 changes: 16 additions & 21 deletions compiler/rustc_hir_analysis/src/autoderef.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use crate::errors::AutoDerefReachedRecursionLimit;
use crate::traits;
use crate::traits::query::evaluate_obligation::InferCtxtExt;
use crate::traits::{self, TraitEngine, TraitEngineExt};
use rustc_infer::infer::InferCtxt;
use rustc_middle::ty::TypeVisitableExt;
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_session::Limit;
use rustc_span::def_id::LocalDefId;
use rustc_span::def_id::LOCAL_CRATE;
use rustc_span::Span;
use rustc_trait_selection::traits::StructurallyNormalizeExt;
use rustc_trait_selection::traits::ObligationCtxt;

#[derive(Copy, Clone, Debug)]
pub enum AutoderefKind {
Expand Down Expand Up @@ -167,25 +167,20 @@ impl<'a, 'tcx> Autoderef<'a, 'tcx> {
&self,
ty: Ty<'tcx>,
) -> Option<(Ty<'tcx>, Vec<traits::PredicateObligation<'tcx>>)> {
let mut fulfill_cx = <dyn TraitEngine<'tcx>>::new(self.infcx);

let cause = traits::ObligationCause::misc(self.span, self.body_id);
let normalized_ty = match self
.infcx
.at(&cause, self.param_env)
.structurally_normalize(ty, &mut *fulfill_cx)
{
Ok(normalized_ty) => normalized_ty,
Err(errors) => {
// This shouldn't happen, except for evaluate/fulfill mismatches,
// but that's not a reason for an ICE (`predicate_may_hold` is conservative
// by design).
debug!(?errors, "encountered errors while fulfilling");
return None;
}
};
let ocx = ObligationCtxt::new(self.infcx);
let normalized_ty = ocx
.structurally_normalize(
&traits::ObligationCause::misc(self.span, self.body_id),
self.param_env,
ty,
)
// We shouldn't have errors here, except for evaluate/fulfill mismatches,
// but that's not a reason for an ICE (`predicate_may_hold` is conservative
// by design).
// FIXME(-Znext-solver): This *actually* shouldn't happen then.
.ok()?;

let errors = fulfill_cx.select_where_possible(self.infcx);
let errors = ocx.select_where_possible();
if !errors.is_empty() {
// This shouldn't happen, except for evaluate/fulfill mismatches,
// but that's not a reason for an ICE (`predicate_may_hold` is conservative
Expand All @@ -194,7 +189,7 @@ impl<'a, 'tcx> Autoderef<'a, 'tcx> {
return None;
}

Some((normalized_ty, fulfill_cx.pending_obligations()))
Some((normalized_ty, ocx.pending_obligations()))
}

/// Returns the final type we ended up with, which may be an inference
Expand Down
22 changes: 9 additions & 13 deletions compiler/rustc_hir_typeck/src/coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ use rustc_hir::Expr;
use rustc_hir_analysis::hir_ty_lowering::HirTyLowerer;
use rustc_infer::infer::type_variable::TypeVariableOrigin;
use rustc_infer::infer::{Coercion, DefineOpaqueTypes, InferOk, InferResult};
use rustc_infer::traits::TraitEngineExt as _;
use rustc_infer::traits::{IfExpressionCause, MatchExpressionArmCause, TraitEngine};
use rustc_infer::traits::{IfExpressionCause, MatchExpressionArmCause};
use rustc_infer::traits::{Obligation, PredicateObligation};
use rustc_middle::lint::in_external_macro;
use rustc_middle::traits::BuiltinImplSource;
Expand All @@ -65,7 +64,6 @@ use rustc_trait_selection::infer::InferCtxtExt as _;
use rustc_trait_selection::traits::error_reporting::suggestions::TypeErrCtxtExt;
use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt as _;
use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt;
use rustc_trait_selection::traits::TraitEngineExt as _;
use rustc_trait_selection::traits::{
self, NormalizeExt, ObligationCause, ObligationCauseCode, ObligationCtxt,
};
Expand Down Expand Up @@ -164,11 +162,10 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {
// Filter these cases out to make sure our coercion is more accurate.
match res {
Ok(InferOk { value, obligations }) if self.next_trait_solver() => {
let mut fulfill_cx = <dyn TraitEngine<'tcx>>::new(self);
fulfill_cx.register_predicate_obligations(self, obligations);
let errs = fulfill_cx.select_where_possible(self);
if errs.is_empty() {
Ok(InferOk { value, obligations: fulfill_cx.pending_obligations() })
let ocx = ObligationCtxt::new(self);
ocx.register_obligations(obligations);
if ocx.select_where_possible().is_empty() {
Ok(InferOk { value, obligations: ocx.pending_obligations() })
} else {
Err(TypeError::Mismatch)
}
Expand Down Expand Up @@ -631,13 +628,12 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {
// but we need to constrain vars before processing goals mentioning
// them.
Some(ty::PredicateKind::AliasRelate(..)) => {
let mut fulfill_cx = <dyn TraitEngine<'tcx>>::new(self);
fulfill_cx.register_predicate_obligation(self, obligation);
let errs = fulfill_cx.select_where_possible(self);
if !errs.is_empty() {
let ocx = ObligationCtxt::new(self);
ocx.register_obligation(obligation);
if !ocx.select_where_possible().is_empty() {
return Err(TypeError::Mismatch);
}
coercion.obligations.extend(fulfill_cx.pending_obligations());
coercion.obligations.extend(ocx.pending_obligations());
continue;
}
_ => {
Expand Down
12 changes: 12 additions & 0 deletions compiler/rustc_infer/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,18 @@ impl<'tcx> FulfillmentError<'tcx> {
) -> FulfillmentError<'tcx> {
FulfillmentError { obligation, code, root_obligation }
}

pub fn is_true_error(&self) -> bool {
match self.code {
FulfillmentErrorCode::SelectionError(_)
| FulfillmentErrorCode::ProjectionError(_)
| FulfillmentErrorCode::SubtypeError(_, _)
| FulfillmentErrorCode::ConstEquateError(_, _) => true,
FulfillmentErrorCode::Cycle(_) | FulfillmentErrorCode::Ambiguity { overflow: _ } => {
false
}
}
}
}

impl<'tcx> PolyTraitObligation<'tcx> {
Expand Down
11 changes: 6 additions & 5 deletions compiler/rustc_trait_selection/src/infer.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::traits::query::evaluate_obligation::InferCtxtExt as _;
use crate::traits::{self, ObligationCtxt, SelectionContext, TraitEngineExt as _};
use crate::traits::{self, ObligationCtxt, SelectionContext};

use rustc_hir::def_id::DefId;
use rustc_hir::lang_items::LangItem;
use rustc_infer::traits::{Obligation, TraitEngine, TraitEngineExt as _};
use rustc_infer::traits::Obligation;
use rustc_macros::extension;
use rustc_middle::arena::ArenaAllocatable;
use rustc_middle::infer::canonical::{Canonical, CanonicalQueryResponse, QueryResponse};
Expand Down Expand Up @@ -93,9 +94,9 @@ impl<'tcx> InferCtxt<'tcx> {
ty::TraitRef::new(self.tcx, trait_def_id, [ty]),
)) {
Ok(Some(selection)) => {
let mut fulfill_cx = <dyn TraitEngine<'tcx>>::new(self);
fulfill_cx.register_predicate_obligations(self, selection.nested_obligations());
Some(fulfill_cx.select_all_or_error(self))
let ocx = ObligationCtxt::new(self);
ocx.register_obligations(selection.nested_obligations());
Some(ocx.select_all_or_error())
}
Ok(None) | Err(_) => None,
}
Expand Down
79 changes: 37 additions & 42 deletions compiler/rustc_trait_selection/src/traits/coherence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@

use crate::infer::outlives::env::OutlivesEnvironment;
use crate::infer::InferOk;
use crate::regions::InferCtxtRegionExt;
use crate::solve::inspect::{InspectGoal, ProofTreeInferCtxtExt, ProofTreeVisitor};
use crate::solve::{deeply_normalize_for_diagnostics, inspect, FulfillmentCtxt};
use crate::traits::engine::TraitEngineExt as _;
use crate::solve::{deeply_normalize_for_diagnostics, inspect};
use crate::traits::select::IntercrateAmbiguityCause;
use crate::traits::structural_normalize::StructurallyNormalizeExt;
use crate::traits::NormalizeExt;
use crate::traits::SkipLeakCheck;
use crate::traits::{
Expand All @@ -22,7 +19,7 @@ use rustc_errors::{Diag, EmissionGuarantee};
use rustc_hir::def::DefKind;
use rustc_hir::def_id::DefId;
use rustc_infer::infer::{DefineOpaqueTypes, InferCtxt, TyCtxtInferExt};
use rustc_infer::traits::{util, FulfillmentErrorCode, TraitEngine, TraitEngineExt};
use rustc_infer::traits::{util, FulfillmentErrorCode};
use rustc_middle::traits::query::NoSolution;
use rustc_middle::traits::solve::{CandidateSource, Certainty, Goal};
use rustc_middle::traits::specialization_graph::OverlapMode;
Expand All @@ -35,6 +32,7 @@ use std::fmt::Debug;
use std::ops::ControlFlow;

use super::error_reporting::suggest_new_overflow_limit;
use super::ObligationCtxt;

/// Whether we do the orphan check relative to this crate or to some remote crate.
#[derive(Copy, Clone, Debug)]
Expand Down Expand Up @@ -361,23 +359,27 @@ fn impl_intersection_has_impossible_obligation<'a, 'cx, 'tcx>(
let infcx = selcx.infcx;

if infcx.next_trait_solver() {
let mut fulfill_cx = FulfillmentCtxt::new(infcx);
fulfill_cx.register_predicate_obligations(infcx, obligations.iter().cloned());

let ocx = ObligationCtxt::new(infcx);
ocx.register_obligations(obligations.iter().cloned());
let errors_and_ambiguities = ocx.select_all_or_error();
// We only care about the obligations that are *definitely* true errors.
// Ambiguities do not prove the disjointness of two impls.
let errors = fulfill_cx.select_where_possible(infcx);
let (errors, ambiguities): (Vec<_>, Vec<_>) =
errors_and_ambiguities.into_iter().partition(|error| error.is_true_error());

if errors.is_empty() {
let overflow_errors = fulfill_cx.collect_remaining_errors(infcx);
let overflowing_predicates = overflow_errors
.into_iter()
.filter(|e| match e.code {
FulfillmentErrorCode::Ambiguity { overflow: Some(true) } => true,
_ => false,
})
.map(|e| infcx.resolve_vars_if_possible(e.obligation.predicate))
.collect();
IntersectionHasImpossibleObligations::No { overflowing_predicates }
IntersectionHasImpossibleObligations::No {
overflowing_predicates: ambiguities
.into_iter()
.filter(|error| {
matches!(
error.code,
FulfillmentErrorCode::Ambiguity { overflow: Some(true) }
)
})
.map(|e| infcx.resolve_vars_if_possible(e.obligation.predicate))
.collect(),
}
} else {
IntersectionHasImpossibleObligations::Yes
}
Expand Down Expand Up @@ -589,22 +591,22 @@ fn try_prove_negated_where_clause<'tcx>(
// Without this, we over-eagerly register coherence ambiguity candidates when
// impl candidates do exist.
let ref infcx = root_infcx.fork_with_intercrate(false);
let mut fulfill_cx = FulfillmentCtxt::new(infcx);

fulfill_cx.register_predicate_obligation(
infcx,
Obligation::new(infcx.tcx, ObligationCause::dummy(), param_env, negative_predicate),
);
if !fulfill_cx.select_all_or_error(infcx).is_empty() {
let ocx = ObligationCtxt::new(infcx);
ocx.register_obligation(Obligation::new(
infcx.tcx,
ObligationCause::dummy(),
param_env,
negative_predicate,
));
if !ocx.select_all_or_error().is_empty() {
return false;
}

// FIXME: We could use the assumed_wf_types from both impls, I think,
// if that wasn't implemented just for LocalDefId, and we'd need to do
// the normalization ourselves since this is totally fallible...
let outlives_env = OutlivesEnvironment::new(param_env);

let errors = infcx.resolve_regions(&outlives_env);
let errors = ocx.resolve_regions(&outlives_env);
if !errors.is_empty() {
return false;
}
Expand Down Expand Up @@ -1130,22 +1132,15 @@ impl<'a, 'tcx> ProofTreeVisitor<'tcx> for AmbiguityCausesVisitor<'a, 'tcx> {
result: Ok(_),
} = cand.kind()
{
let lazily_normalize_ty = |ty: Ty<'tcx>| {
let mut fulfill_cx = <dyn TraitEngine<'tcx>>::new(infcx);
let lazily_normalize_ty = |mut ty: Ty<'tcx>| {
let ocx = ObligationCtxt::new(infcx);
ty = infcx.resolve_vars_if_possible(ty);
if matches!(ty.kind(), ty::Alias(..)) {
// FIXME(-Znext-solver=coherence): we currently don't
// normalize opaque types here, resulting in diverging behavior
// for TAITs.
match infcx
.at(&ObligationCause::dummy(), param_env)
.structurally_normalize(ty, &mut *fulfill_cx)
{
Ok(ty) => Ok(ty),
Err(_errs) => Err(()),
}
} else {
Ok(ty)
ty = ocx
.structurally_normalize(&ObligationCause::dummy(), param_env, ty)
.map_err(|_| ())?;
}
if ocx.select_where_possible().is_empty() { Ok(ty) } else { Err(()) }
};

infcx.probe(|_| {
Expand Down
30 changes: 30 additions & 0 deletions compiler/rustc_trait_selection/src/traits/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::regions::InferCtxtRegionExt;
use crate::solve::FulfillmentCtxt as NextFulfillmentCtxt;
use crate::traits::error_reporting::TypeErrCtxtExt;
use crate::traits::NormalizeExt;
use crate::traits::StructurallyNormalizeExt;
use rustc_data_structures::fx::FxIndexSet;
use rustc_errors::ErrorGuaranteed;
use rustc_hir::def_id::{DefId, LocalDefId};
Expand All @@ -15,6 +16,7 @@ use rustc_infer::infer::canonical::{
Canonical, CanonicalQueryResponse, CanonicalVarValues, QueryResponse,
};
use rustc_infer::infer::outlives::env::OutlivesEnvironment;
use rustc_infer::infer::RegionResolutionError;
use rustc_infer::infer::{DefineOpaqueTypes, InferCtxt, InferOk};
use rustc_infer::traits::{
FulfillmentError, Obligation, ObligationCause, PredicateObligation, TraitEngineExt as _,
Expand Down Expand Up @@ -117,6 +119,17 @@ impl<'a, 'tcx> ObligationCtxt<'a, 'tcx> {
self.infcx.at(cause, param_env).deeply_normalize(value, &mut **self.engine.borrow_mut())
}

pub fn structurally_normalize(
&self,
cause: &ObligationCause<'tcx>,
param_env: ty::ParamEnv<'tcx>,
value: Ty<'tcx>,
) -> Result<Ty<'tcx>, Vec<FulfillmentError<'tcx>>> {
self.infcx
.at(cause, param_env)
.structurally_normalize(value, &mut **self.engine.borrow_mut())
}

pub fn eq<T: ToTrace<'tcx>>(
&self,
cause: &ObligationCause<'tcx>,
Expand Down Expand Up @@ -182,6 +195,11 @@ impl<'a, 'tcx> ObligationCtxt<'a, 'tcx> {
self.engine.borrow_mut().select_all_or_error(self.infcx)
}

#[must_use]
pub fn pending_obligations(&self) -> Vec<PredicateObligation<'tcx>> {
self.engine.borrow().pending_obligations()
}

/// Resolves regions and reports errors.
///
/// Takes ownership of the context as doing trait solving afterwards
Expand All @@ -199,6 +217,18 @@ impl<'a, 'tcx> ObligationCtxt<'a, 'tcx> {
}
}

/// Resolves regions and reports errors.
///
/// Takes ownership of the context as doing trait solving afterwards
/// will result in region constraints getting ignored.
#[must_use]
pub fn resolve_regions(
self,
outlives_env: &OutlivesEnvironment<'tcx>,
) -> Vec<RegionResolutionError<'tcx>> {
self.infcx.resolve_regions(outlives_env)
}

pub fn assumed_wf_types_and_report_errors(
&self,
param_env: ty::ParamEnv<'tcx>,
Expand Down
Loading

0 comments on commit 4fc00ec

Please sign in to comment.