Skip to content

Simplify universal impl trait lowering #97598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 12 additions & 16 deletions compiler/rustc_ast_lowering/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ impl<'a, 'hir> ItemLowerer<'a, 'hir> {
task_context: None,
current_item: None,
captured_lifetimes: None,
impl_trait_defs: Vec::new(),
impl_trait_bounds: Vec::new(),
allow_try_trait: Some([sym::try_trait_v2, sym::yeet_desugar_details][..].into()),
allow_gen_future: Some([sym::gen_future][..].into()),
allow_into_future: Some([sym::into_future][..].into()),
Expand Down Expand Up @@ -264,16 +266,10 @@ impl<'hir> LoweringContext<'_, 'hir> {
let body_id =
this.lower_maybe_async_body(span, &decl, asyncness, body.as_deref());

let (generics, decl) =
this.add_implicit_generics(generics, id, |this, idty, idpb| {
let ret_id = asyncness.opt_return_id();
this.lower_fn_decl(
&decl,
Some((id, idty, idpb)),
FnDeclKind::Fn,
ret_id,
)
});
let (generics, decl) = this.add_implicit_generics(generics, id, |this| {
let ret_id = asyncness.opt_return_id();
this.lower_fn_decl(&decl, Some(id), FnDeclKind::Fn, ret_id)
});
let sig = hir::FnSig {
decl,
header: this.lower_fn_header(header),
Expand Down Expand Up @@ -387,7 +383,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
// lifetime to be added, but rather a reference to a
// parent lifetime.
let (generics, (trait_ref, lowered_ty)) =
self.add_implicit_generics(ast_generics, id, |this, _, _| {
self.add_implicit_generics(ast_generics, id, |this| {
let trait_ref = trait_ref.as_ref().map(|trait_ref| {
this.lower_trait_ref(
trait_ref,
Expand Down Expand Up @@ -652,7 +648,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
ForeignItemKind::Fn(box Fn { ref sig, ref generics, .. }) => {
let fdec = &sig.decl;
let (generics, (fn_dec, fn_args)) =
self.add_implicit_generics(generics, i.id, |this, _, _| {
self.add_implicit_generics(generics, i.id, |this| {
(
// Disallow `impl Trait` in foreign items.
this.lower_fn_decl(fdec, None, FnDeclKind::ExternFn, None),
Expand Down Expand Up @@ -1231,8 +1227,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
is_async: Option<NodeId>,
) -> (&'hir hir::Generics<'hir>, hir::FnSig<'hir>) {
let header = self.lower_fn_header(sig.header);
let (generics, decl) = self.add_implicit_generics(generics, id, |this, idty, idpb| {
this.lower_fn_decl(&sig.decl, Some((id, idty, idpb)), kind, is_async)
let (generics, decl) = self.add_implicit_generics(generics, id, |this| {
this.lower_fn_decl(&sig.decl, Some(id), kind, is_async)
});
(generics, hir::FnSig { header, decl, span: self.lower_span(sig.span) })
}
Expand Down Expand Up @@ -1292,7 +1288,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
pub(super) fn lower_generics_mut(
&mut self,
generics: &Generics,
mut itctx: ImplTraitContext<'_, 'hir>,
mut itctx: ImplTraitContext,
) -> GenericsCtor<'hir> {
// Error if `?Trait` bounds in where clauses don't refer directly to type parameters.
// Note: we used to clone these bounds directly onto the type parameter (and avoid lowering
Expand Down Expand Up @@ -1372,7 +1368,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
pub(super) fn lower_generics(
&mut self,
generics: &Generics,
itctx: ImplTraitContext<'_, 'hir>,
itctx: ImplTraitContext,
) -> &'hir hir::Generics<'hir> {
let generics_ctor = self.lower_generics_mut(generics, itctx);
generics_ctor.into_generics(self.arena)
Expand Down
113 changes: 45 additions & 68 deletions compiler/rustc_ast_lowering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ struct LoweringContext<'a, 'hir: 'a> {
local_id_to_def_id: SortedMap<ItemLocalId, LocalDefId>,
trait_map: FxHashMap<ItemLocalId, Box<[TraitCandidate]>>,

impl_trait_defs: Vec<hir::GenericParam<'hir>>,
impl_trait_bounds: Vec<hir::WherePredicate<'hir>>,

/// NodeIds that are lowered inside the current HIR owner.
node_id_to_local_id: FxHashMap<NodeId, hir::ItemLocalId>,

Expand Down Expand Up @@ -244,13 +247,13 @@ pub trait ResolverAstLowering {
/// Context of `impl Trait` in code, which determines whether it is allowed in an HIR subtree,
/// and if so, what meaning it has.
#[derive(Debug)]
enum ImplTraitContext<'b, 'a> {
enum ImplTraitContext {
/// Treat `impl Trait` as shorthand for a new universal generic parameter.
/// Example: `fn foo(x: impl Debug)`, where `impl Debug` is conceptually
/// equivalent to a fresh universal parameter like `fn foo<T: Debug>(x: T)`.
///
/// Newly generated parameters should be inserted into the given `Vec`.
Universal(&'b mut Vec<hir::GenericParam<'a>>, &'b mut Vec<hir::WherePredicate<'a>>, LocalDefId),
Universal(LocalDefId),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this LocalDefId always equal to self.current_hir_id_owner, or can there be a mismatch?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so but maybe worth checking. Do you think I should check that in this PR? or should we leave this for a followup?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok for follow-up.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I've just checked this and yes, it's always equal. We are always creating Universal passing self.current_hir_id_owner to it.


/// Treat `impl Trait` as shorthand for a new opaque type.
/// Example: `fn foo() -> impl Debug`, where `impl Debug` is conceptually
Expand Down Expand Up @@ -290,11 +293,11 @@ enum ImplTraitPosition {
ImplReturn,
}

impl<'a> ImplTraitContext<'_, 'a> {
fn reborrow<'this>(&'this mut self) -> ImplTraitContext<'this, 'a> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉

impl ImplTraitContext {
fn reborrow<'this>(&'this mut self) -> ImplTraitContext {
use self::ImplTraitContext::*;
match self {
Universal(params, bounds, parent) => Universal(params, bounds, *parent),
Universal(parent) => Universal(*parent),
ReturnPositionOpaqueTy { origin } => ReturnPositionOpaqueTy { origin: *origin },
TypeAliasesOpaqueTy => TypeAliasesOpaqueTy,
Disallowed(pos) => Disallowed(*pos),
Expand Down Expand Up @@ -701,34 +704,24 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
&mut self,
generics: &Generics,
parent_node_id: NodeId,
f: impl FnOnce(
&mut Self,
&mut Vec<hir::GenericParam<'hir>>,
&mut Vec<hir::WherePredicate<'hir>>,
) -> T,
f: impl FnOnce(&mut Self) -> T,
) -> (&'hir hir::Generics<'hir>, T) {
let mut impl_trait_defs = Vec::new();
let mut impl_trait_bounds = Vec::new();
let mut lowered_generics = self.lower_generics_mut(
generics,
ImplTraitContext::Universal(
&mut impl_trait_defs,
&mut impl_trait_bounds,
self.current_hir_id_owner,
),
);
let res = f(self, &mut impl_trait_defs, &mut impl_trait_bounds);
let mut lowered_generics = self
.lower_generics_mut(generics, ImplTraitContext::Universal(self.current_hir_id_owner));
let res = f(self);

let extra_lifetimes = self.resolver.take_extra_lifetime_params(parent_node_id);
let impl_trait_defs = std::mem::take(&mut self.impl_trait_defs);
lowered_generics.params.extend(
extra_lifetimes
.into_iter()
.filter_map(|(ident, node_id, res)| {
self.lifetime_res_to_generic_param(ident, node_id, res)
})
.chain(impl_trait_defs),
.chain(impl_trait_defs.into_iter()),
);
lowered_generics.predicates.extend(impl_trait_bounds);
let impl_trait_bounds = std::mem::take(&mut self.impl_trait_bounds);
lowered_generics.predicates.extend(impl_trait_bounds.into_iter());

let lowered_generics = lowered_generics.into_generics(self.arena);
(lowered_generics, res)
Expand Down Expand Up @@ -898,7 +891,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
fn lower_assoc_ty_constraint(
&mut self,
constraint: &AssocConstraint,
mut itctx: ImplTraitContext<'_, 'hir>,
mut itctx: ImplTraitContext,
) -> hir::TypeBinding<'hir> {
debug!("lower_assoc_ty_constraint(constraint={:?}, itctx={:?})", constraint, itctx);

Expand Down Expand Up @@ -962,7 +955,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
// so desugar to
//
// fn foo(x: dyn Iterator<Item = impl Debug>)
ImplTraitContext::Universal(_, _, parent) if self.is_in_dyn_type => {
ImplTraitContext::Universal(parent) if self.is_in_dyn_type => {
parent_def_id = parent;
(true, itctx)
}
Expand Down Expand Up @@ -1036,7 +1029,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
fn lower_generic_arg(
&mut self,
arg: &ast::GenericArg,
itctx: ImplTraitContext<'_, 'hir>,
itctx: ImplTraitContext,
) -> hir::GenericArg<'hir> {
match arg {
ast::GenericArg::Lifetime(lt) => GenericArg::Lifetime(self.lower_lifetime(&lt)),
Expand Down Expand Up @@ -1103,7 +1096,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
}
}

fn lower_ty(&mut self, t: &Ty, itctx: ImplTraitContext<'_, 'hir>) -> &'hir hir::Ty<'hir> {
fn lower_ty(&mut self, t: &Ty, itctx: ImplTraitContext) -> &'hir hir::Ty<'hir> {
self.arena.alloc(self.lower_ty_direct(t, itctx))
}

Expand All @@ -1113,7 +1106,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
qself: &Option<QSelf>,
path: &Path,
param_mode: ParamMode,
itctx: ImplTraitContext<'_, 'hir>,
itctx: ImplTraitContext,
) -> hir::Ty<'hir> {
let id = self.lower_node_id(t.id);
let qpath = self.lower_qpath(t.id, qself, path, param_mode, itctx);
Expand All @@ -1128,7 +1121,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
self.ty(span, hir::TyKind::Tup(tys))
}

fn lower_ty_direct(&mut self, t: &Ty, mut itctx: ImplTraitContext<'_, 'hir>) -> hir::Ty<'hir> {
fn lower_ty_direct(&mut self, t: &Ty, mut itctx: ImplTraitContext) -> hir::Ty<'hir> {
let kind = match t.kind {
TyKind::Infer => hir::TyKind::Infer,
TyKind::Err => hir::TyKind::Err,
Expand Down Expand Up @@ -1235,40 +1228,32 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
|this| this.lower_param_bounds(bounds, nested_itctx),
)
}
ImplTraitContext::Universal(
in_band_ty_params,
in_band_ty_bounds,
parent_def_id,
) => {
ImplTraitContext::Universal(parent_def_id) => {
// Add a definition for the in-band `Param`.
let def_id = self.resolver.local_def_id(def_node_id);

let hir_bounds = self.lower_param_bounds(
bounds,
ImplTraitContext::Universal(
in_band_ty_params,
in_band_ty_bounds,
parent_def_id,
),
);
let hir_bounds = self
.lower_param_bounds(bounds, ImplTraitContext::Universal(parent_def_id));
// Set the name to `impl Bound1 + Bound2`.
let ident = Ident::from_str_and_span(&pprust::ty_to_string(t), span);
in_band_ty_params.push(hir::GenericParam {
let param = hir::GenericParam {
hir_id: self.lower_node_id(def_node_id),
name: ParamName::Plain(self.lower_ident(ident)),
pure_wrt_drop: false,
span: self.lower_span(span),
kind: hir::GenericParamKind::Type { default: None, synthetic: true },
colon_span: None,
});
};
self.impl_trait_defs.push(param);

if let Some(preds) = self.lower_generic_bound_predicate(
ident,
def_node_id,
&GenericParamKind::Type { default: None },
hir_bounds,
hir::PredicateOrigin::ImplTrait,
) {
in_band_ty_bounds.push(preds)
self.impl_trait_bounds.push(preds)
}

hir::TyKind::Path(hir::QPath::Resolved(
Expand Down Expand Up @@ -1442,21 +1427,17 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
fn lower_fn_decl(
&mut self,
decl: &FnDecl,
mut in_band_ty_params: Option<(
NodeId,
&mut Vec<hir::GenericParam<'hir>>,
&mut Vec<hir::WherePredicate<'hir>>,
)>,
fn_node_id: Option<NodeId>,
kind: FnDeclKind,
make_ret_async: Option<NodeId>,
) -> &'hir hir::FnDecl<'hir> {
debug!(
"lower_fn_decl(\
fn_decl: {:?}, \
in_band_ty_params: {:?}, \
fn_node_id: {:?}, \
kind: {:?}, \
make_ret_async: {:?})",
decl, in_band_ty_params, kind, make_ret_async,
decl, fn_node_id, kind, make_ret_async,
);

let c_variadic = decl.c_variadic();
Expand All @@ -1469,10 +1450,10 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
inputs = &inputs[..inputs.len() - 1];
}
let inputs = self.arena.alloc_from_iter(inputs.iter().map(|param| {
if let Some((_, ibty, ibpb)) = &mut in_band_ty_params {
if fn_node_id.is_some() {
self.lower_ty_direct(
&param.ty,
ImplTraitContext::Universal(ibty, ibpb, self.current_hir_id_owner),
ImplTraitContext::Universal(self.current_hir_id_owner),
)
} else {
self.lower_ty_direct(
Expand All @@ -1494,15 +1475,15 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
let output = if let Some(ret_id) = make_ret_async {
self.lower_async_fn_ret_ty(
&decl.output,
in_band_ty_params.expect("`make_ret_async` but no `fn_def_id`").0,
fn_node_id.expect("`make_ret_async` but no `fn_def_id`"),
ret_id,
)
} else {
match decl.output {
FnRetTy::Ty(ref ty) => {
let context = match in_band_ty_params {
Some((node_id, _, _)) if kind.impl_trait_return_allowed() => {
let fn_def_id = self.resolver.local_def_id(node_id);
let context = match fn_node_id {
Some(fn_node_id) if kind.impl_trait_return_allowed() => {
let fn_def_id = self.resolver.local_def_id(fn_node_id);
ImplTraitContext::ReturnPositionOpaqueTy {
origin: hir::OpaqueTyOrigin::FnReturn(fn_def_id),
}
Expand Down Expand Up @@ -1788,7 +1769,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
fn lower_param_bound(
&mut self,
tpb: &GenericBound,
itctx: ImplTraitContext<'_, 'hir>,
itctx: ImplTraitContext,
) -> hir::GenericBound<'hir> {
match tpb {
GenericBound::Trait(p, modifier) => hir::GenericBound::Trait(
Expand Down Expand Up @@ -1966,11 +1947,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
}
}

fn lower_trait_ref(
&mut self,
p: &TraitRef,
itctx: ImplTraitContext<'_, 'hir>,
) -> hir::TraitRef<'hir> {
fn lower_trait_ref(&mut self, p: &TraitRef, itctx: ImplTraitContext) -> hir::TraitRef<'hir> {
let path = match self.lower_qpath(p.ref_id, &None, &p.path, ParamMode::Explicit, itctx) {
hir::QPath::Resolved(None, path) => path,
qpath => panic!("lower_trait_ref: unexpected QPath `{:?}`", qpath),
Expand All @@ -1982,7 +1959,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
fn lower_poly_trait_ref(
&mut self,
p: &PolyTraitRef,
mut itctx: ImplTraitContext<'_, 'hir>,
mut itctx: ImplTraitContext,
) -> hir::PolyTraitRef<'hir> {
let bound_generic_params = self.lower_generic_params(&p.bound_generic_params);

Expand All @@ -1993,22 +1970,22 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
hir::PolyTraitRef { bound_generic_params, trait_ref, span: self.lower_span(p.span) }
}

fn lower_mt(&mut self, mt: &MutTy, itctx: ImplTraitContext<'_, 'hir>) -> hir::MutTy<'hir> {
fn lower_mt(&mut self, mt: &MutTy, itctx: ImplTraitContext) -> hir::MutTy<'hir> {
hir::MutTy { ty: self.lower_ty(&mt.ty, itctx), mutbl: mt.mutbl }
}

fn lower_param_bounds(
&mut self,
bounds: &[GenericBound],
itctx: ImplTraitContext<'_, 'hir>,
itctx: ImplTraitContext,
) -> hir::GenericBounds<'hir> {
self.arena.alloc_from_iter(self.lower_param_bounds_mut(bounds, itctx))
}

fn lower_param_bounds_mut<'s>(
&'s mut self,
bounds: &'s [GenericBound],
mut itctx: ImplTraitContext<'s, 'hir>,
mut itctx: ImplTraitContext,
) -> impl Iterator<Item = hir::GenericBound<'hir>> + Captures<'s> + Captures<'a> {
bounds.iter().map(move |bound| self.lower_param_bound(bound, itctx.reborrow()))
}
Expand Down
Loading