Skip to content

Commit

Permalink
Auto merge of rust-lang#115582 - compiler-errors:refine-yeet, r=oli-obk
Browse files Browse the repository at this point in the history
Implement refinement lint for RPITIT

Implements a lint that warns against accidentally refining an RPITIT in an implementation. This is not a hard error, and can be suppressed with `#[allow(refining_impl_trait)]`, since this behavior may be desirable -- the lint just serves as an acknowledgement from the impl author that they understand that the types they write in the implementation are an API guarantee.

This compares bounds syntactically, not semantically -- semantic implication is more difficult and essentially relies on adding the ability to keep the RPITIT hidden in the trait system so that things can be proven about the type that shows up in the impl without its own bounds leaking through, either via a new reveal mode or something else. This was experimentally implemented in rust-lang#111931.

Somewhat opinionated choices:
1. Putting the lint behind `refining_impl_trait` rather than a blanket `refine` lint. This could be changed, but I like keeping the lint specialized to RPITITs so the explanation can be tailored to it.
2. This PR does not include the `#[refine]` attribute or the feature gate, since it's kind of orthogonal and can be added in a separate PR.

r? `@oli-obk`
  • Loading branch information
bors committed Sep 7, 2023
2 parents d5573d7 + 7714db8 commit 7f0fa48
Show file tree
Hide file tree
Showing 26 changed files with 577 additions and 45 deletions.
6 changes: 3 additions & 3 deletions compiler/rustc_errors/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use rustc_fluent_macro::fluent_messages;
pub use rustc_lint_defs::{pluralize, Applicability};
use rustc_span::source_map::SourceMap;
pub use rustc_span::ErrorGuaranteed;
use rustc_span::{Loc, Span};
use rustc_span::{Loc, Span, DUMMY_SP};

use std::borrow::Cow;
use std::error::Report;
Expand Down Expand Up @@ -1754,7 +1754,7 @@ impl DelayedDiagnostic {
BacktraceStatus::Captured => {
let inner = &self.inner;
self.inner.subdiagnostic(DelayedAtWithNewline {
span: inner.span.primary_span().unwrap(),
span: inner.span.primary_span().unwrap_or(DUMMY_SP),
emitted_at: inner.emitted_at.clone(),
note: self.note,
});
Expand All @@ -1764,7 +1764,7 @@ impl DelayedDiagnostic {
_ => {
let inner = &self.inner;
self.inner.subdiagnostic(DelayedAtWithoutNewline {
span: inner.span.primary_span().unwrap(),
span: inner.span.primary_span().unwrap_or(DUMMY_SP),
emitted_at: inner.emitted_at.clone(),
note: self.note,
});
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_hir_analysis/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,12 @@ hir_analysis_return_type_notation_on_non_rpitit =
.note = function returns `{$ty}`, which is not compatible with associated type return bounds
.label = this function must be `async` or return `impl Trait`
hir_analysis_rpitit_refined = impl trait in impl method signature does not match trait method signature
.suggestion = replace the return type so that it matches the trait
.label = return type from trait method defined here
.unmatched_bound_label = this bound is stronger than that defined on the trait
.note = add `#[allow(refining_impl_trait)]` if it is intended for this to be part of the public API of this crate
hir_analysis_self_in_impl_self =
`Self` is not valid in the self type of an impl block
.note = replace `Self` with a different type
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_hir_analysis/src/check/compare_impl_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ use rustc_trait_selection::traits::{
use std::borrow::Cow;
use std::iter;

mod refine;

/// Checks that a method from an impl conforms to the signature of
/// the same method as declared in the trait.
///
Expand All @@ -53,6 +55,12 @@ pub(super) fn compare_impl_method<'tcx>(
impl_trait_ref,
CheckImpliedWfMode::Check,
)?;
refine::check_refining_return_position_impl_trait_in_trait(
tcx,
impl_m,
trait_m,
impl_trait_ref,
);
};
}

Expand Down
311 changes: 311 additions & 0 deletions compiler/rustc_hir_analysis/src/check/compare_impl_item/refine.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,311 @@
use rustc_data_structures::fx::FxIndexSet;
use rustc_hir as hir;
use rustc_hir::def_id::DefId;
use rustc_infer::infer::{outlives::env::OutlivesEnvironment, TyCtxtInferExt};
use rustc_lint_defs::builtin::REFINING_IMPL_TRAIT;
use rustc_middle::traits::{ObligationCause, Reveal};
use rustc_middle::ty::{
self, Ty, TyCtxt, TypeFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitor,
};
use rustc_span::{Span, DUMMY_SP};
use rustc_trait_selection::traits::{
elaborate, normalize_param_env_or_error, outlives_bounds::InferCtxtExt, ObligationCtxt,
};
use std::ops::ControlFlow;

/// Check that an implementation does not refine an RPITIT from a trait method signature.
pub(super) fn check_refining_return_position_impl_trait_in_trait<'tcx>(
tcx: TyCtxt<'tcx>,
impl_m: ty::AssocItem,
trait_m: ty::AssocItem,
impl_trait_ref: ty::TraitRef<'tcx>,
) {
if !tcx.impl_method_has_trait_impl_trait_tys(impl_m.def_id) {
return;
}
// crate-private traits don't have any library guarantees, there's no need to do this check.
if !tcx.visibility(trait_m.container_id(tcx)).is_public() {
return;
}

// If a type in the trait ref is private, then there's also no reason to to do this check.
let impl_def_id = impl_m.container_id(tcx);
for arg in impl_trait_ref.args {
if let Some(ty) = arg.as_type()
&& let Some(self_visibility) = type_visibility(tcx, ty)
&& !self_visibility.is_public()
{
return;
}
}

let impl_m_args = ty::GenericArgs::identity_for_item(tcx, impl_m.def_id);
let trait_m_to_impl_m_args = impl_m_args.rebase_onto(tcx, impl_def_id, impl_trait_ref.args);
let bound_trait_m_sig = tcx.fn_sig(trait_m.def_id).instantiate(tcx, trait_m_to_impl_m_args);
let trait_m_sig = tcx.liberate_late_bound_regions(impl_m.def_id, bound_trait_m_sig);
// replace the self type of the trait ref with `Self` so that diagnostics render better.
let trait_m_sig_with_self_for_diag = tcx.liberate_late_bound_regions(
impl_m.def_id,
tcx.fn_sig(trait_m.def_id).instantiate(
tcx,
tcx.mk_args_from_iter(
[tcx.types.self_param.into()]
.into_iter()
.chain(trait_m_to_impl_m_args.iter().skip(1)),
),
),
);

let Ok(hidden_tys) = tcx.collect_return_position_impl_trait_in_trait_tys(impl_m.def_id) else {
// Error already emitted, no need to delay another.
return;
};

let mut collector = ImplTraitInTraitCollector { tcx, types: FxIndexSet::default() };
trait_m_sig.visit_with(&mut collector);

// Bound that we find on RPITITs in the trait signature.
let mut trait_bounds = vec![];
// Bounds that we find on the RPITITs in the impl signature.
let mut impl_bounds = vec![];

for trait_projection in collector.types.into_iter().rev() {
let impl_opaque_args = trait_projection.args.rebase_onto(tcx, trait_m.def_id, impl_m_args);
let hidden_ty = hidden_tys[&trait_projection.def_id].instantiate(tcx, impl_opaque_args);

// If the hidden type is not an opaque, then we have "refined" the trait signature.
let ty::Alias(ty::Opaque, impl_opaque) = *hidden_ty.kind() else {
report_mismatched_rpitit_signature(
tcx,
trait_m_sig_with_self_for_diag,
trait_m.def_id,
impl_m.def_id,
None,
);
return;
};

// This opaque also needs to be from the impl method -- otherwise,
// it's a refinement to a TAIT.
if !tcx.hir().get_if_local(impl_opaque.def_id).map_or(false, |node| {
matches!(
node.expect_item().expect_opaque_ty().origin,
hir::OpaqueTyOrigin::AsyncFn(def_id) | hir::OpaqueTyOrigin::FnReturn(def_id)
if def_id == impl_m.def_id.expect_local()
)
}) {
report_mismatched_rpitit_signature(
tcx,
trait_m_sig_with_self_for_diag,
trait_m.def_id,
impl_m.def_id,
None,
);
return;
}

trait_bounds.extend(
tcx.item_bounds(trait_projection.def_id).iter_instantiated(tcx, trait_projection.args),
);
impl_bounds.extend(elaborate(
tcx,
tcx.explicit_item_bounds(impl_opaque.def_id)
.iter_instantiated_copied(tcx, impl_opaque.args),
));
}

let hybrid_preds = tcx
.predicates_of(impl_def_id)
.instantiate_identity(tcx)
.into_iter()
.chain(tcx.predicates_of(trait_m.def_id).instantiate_own(tcx, trait_m_to_impl_m_args))
.map(|(clause, _)| clause);
let param_env = ty::ParamEnv::new(tcx.mk_clauses_from_iter(hybrid_preds), Reveal::UserFacing);
let param_env = normalize_param_env_or_error(tcx, param_env, ObligationCause::dummy());

let ref infcx = tcx.infer_ctxt().build();
let ocx = ObligationCtxt::new(infcx);

// Normalize the bounds. This has two purposes:
//
// 1. Project the RPITIT projections from the trait to the opaques on the impl,
// which means that they don't need to be mapped manually.
//
// 2. Project any other projections that show up in the bound. That makes sure that
// we don't consider `tests/ui/async-await/in-trait/async-associated-types.rs`
// to be refining.
let (trait_bounds, impl_bounds) =
ocx.normalize(&ObligationCause::dummy(), param_env, (trait_bounds, impl_bounds));

// Since we've normalized things, we need to resolve regions, since we'll
// possibly have introduced region vars during projection. We don't expect
// this resolution to have incurred any region errors -- but if we do, then
// just delay a bug.
let mut implied_wf_types = FxIndexSet::default();
implied_wf_types.extend(trait_m_sig.inputs_and_output);
implied_wf_types.extend(ocx.normalize(
&ObligationCause::dummy(),
param_env,
trait_m_sig.inputs_and_output,
));
if !ocx.select_all_or_error().is_empty() {
tcx.sess.delay_span_bug(
DUMMY_SP,
"encountered errors when checking RPITIT refinement (selection)",
);
return;
}
let outlives_env = OutlivesEnvironment::with_bounds(
param_env,
infcx.implied_bounds_tys(param_env, impl_m.def_id.expect_local(), implied_wf_types),
);
let errors = infcx.resolve_regions(&outlives_env);
if !errors.is_empty() {
tcx.sess.delay_span_bug(
DUMMY_SP,
"encountered errors when checking RPITIT refinement (regions)",
);
return;
}
// Resolve any lifetime variables that may have been introduced during normalization.
let Ok((trait_bounds, impl_bounds)) = infcx.fully_resolve((trait_bounds, impl_bounds)) else {
tcx.sess.delay_span_bug(
DUMMY_SP,
"encountered errors when checking RPITIT refinement (resolution)",
);
return;
};

// For quicker lookup, use an `IndexSet`
// (we don't use one earlier because it's not foldable..)
let trait_bounds = FxIndexSet::from_iter(trait_bounds);

// Find any clauses that are present in the impl's RPITITs that are not
// present in the trait's RPITITs. This will trigger on trivial predicates,
// too, since we *do not* use the trait solver to prove that the RPITIT's
// bounds are not stronger -- we're doing a simple, syntactic compatibility
// check between bounds. This is strictly forwards compatible, though.
for (clause, span) in impl_bounds {
if !trait_bounds.contains(&clause) {
report_mismatched_rpitit_signature(
tcx,
trait_m_sig_with_self_for_diag,
trait_m.def_id,
impl_m.def_id,
Some(span),
);
return;
}
}
}

struct ImplTraitInTraitCollector<'tcx> {
tcx: TyCtxt<'tcx>,
types: FxIndexSet<ty::AliasTy<'tcx>>,
}

impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for ImplTraitInTraitCollector<'tcx> {
type BreakTy = !;

fn visit_ty(&mut self, ty: Ty<'tcx>) -> std::ops::ControlFlow<Self::BreakTy> {
if let ty::Alias(ty::Projection, proj) = *ty.kind()
&& self.tcx.is_impl_trait_in_trait(proj.def_id)
{
if self.types.insert(proj) {
for (pred, _) in self
.tcx
.explicit_item_bounds(proj.def_id)
.iter_instantiated_copied(self.tcx, proj.args)
{
pred.visit_with(self)?;
}
}
ControlFlow::Continue(())
} else {
ty.super_visit_with(self)
}
}
}

fn report_mismatched_rpitit_signature<'tcx>(
tcx: TyCtxt<'tcx>,
trait_m_sig: ty::FnSig<'tcx>,
trait_m_def_id: DefId,
impl_m_def_id: DefId,
unmatched_bound: Option<Span>,
) {
let mapping = std::iter::zip(
tcx.fn_sig(trait_m_def_id).skip_binder().bound_vars(),
tcx.fn_sig(impl_m_def_id).skip_binder().bound_vars(),
)
.filter_map(|(impl_bv, trait_bv)| {
if let ty::BoundVariableKind::Region(impl_bv) = impl_bv
&& let ty::BoundVariableKind::Region(trait_bv) = trait_bv
{
Some((impl_bv, trait_bv))
} else {
None
}
})
.collect();

let mut return_ty =
trait_m_sig.output().fold_with(&mut super::RemapLateBound { tcx, mapping: &mapping });

if tcx.asyncness(impl_m_def_id).is_async() && tcx.asyncness(trait_m_def_id).is_async() {
let ty::Alias(ty::Projection, future_ty) = return_ty.kind() else {
bug!();
};
let Some(future_output_ty) = tcx
.explicit_item_bounds(future_ty.def_id)
.iter_instantiated_copied(tcx, future_ty.args)
.find_map(|(clause, _)| match clause.kind().no_bound_vars()? {
ty::ClauseKind::Projection(proj) => proj.term.ty(),
_ => None,
})
else {
bug!()
};
return_ty = future_output_ty;
}

let (span, impl_return_span, pre, post) =
match tcx.hir().get_by_def_id(impl_m_def_id.expect_local()).fn_decl().unwrap().output {
hir::FnRetTy::DefaultReturn(span) => (tcx.def_span(impl_m_def_id), span, "-> ", " "),
hir::FnRetTy::Return(ty) => (ty.span, ty.span, "", ""),
};
let trait_return_span =
tcx.hir().get_if_local(trait_m_def_id).map(|node| match node.fn_decl().unwrap().output {
hir::FnRetTy::DefaultReturn(_) => tcx.def_span(trait_m_def_id),
hir::FnRetTy::Return(ty) => ty.span,
});

let span = unmatched_bound.unwrap_or(span);
tcx.emit_spanned_lint(
REFINING_IMPL_TRAIT,
tcx.local_def_id_to_hir_id(impl_m_def_id.expect_local()),
span,
crate::errors::ReturnPositionImplTraitInTraitRefined {
impl_return_span,
trait_return_span,
pre,
post,
return_ty,
unmatched_bound,
},
);
}

fn type_visibility<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Option<ty::Visibility<DefId>> {
match *ty.kind() {
ty::Ref(_, ty, _) => type_visibility(tcx, ty),
ty::Adt(def, args) => {
if def.is_fundamental() {
type_visibility(tcx, args.type_at(0))
} else {
Some(tcx.visibility(def.did()))
}
}
_ => None,
}
}
16 changes: 16 additions & 0 deletions compiler/rustc_hir_analysis/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,22 @@ pub struct UnusedAssociatedTypeBounds {
pub span: Span,
}

#[derive(LintDiagnostic)]
#[diag(hir_analysis_rpitit_refined)]
#[note]
pub(crate) struct ReturnPositionImplTraitInTraitRefined<'tcx> {
#[suggestion(applicability = "maybe-incorrect", code = "{pre}{return_ty}{post}")]
pub impl_return_span: Span,
#[label]
pub trait_return_span: Option<Span>,
#[label(hir_analysis_unmatched_bound_label)]
pub unmatched_bound: Option<Span>,

pub pre: &'static str,
pub post: &'static str,
pub return_ty: Ty<'tcx>,
}

#[derive(Diagnostic)]
#[diag(hir_analysis_assoc_bound_on_const)]
#[note]
Expand Down
Loading

0 comments on commit 7f0fa48

Please sign in to comment.