Skip to content

Commit

Permalink
Prereq6 for async drop - templated coroutine processing and layout
Browse files Browse the repository at this point in the history
  • Loading branch information
azhogin committed Sep 8, 2024
1 parent 62d37f8 commit 64e4cca
Show file tree
Hide file tree
Showing 12 changed files with 263 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -669,8 +669,7 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(
_ => unreachable!(),
};

let coroutine_layout =
cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args.kind_ty()).unwrap();
let coroutine_layout = cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args.args).unwrap();

let common_upvar_names = cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);
let variant_range = coroutine_args.variant_range(coroutine_def_id, cx.tcx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,8 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
DIFlags::FlagZero,
),
|cx, coroutine_type_di_node| {
let coroutine_layout = cx
.tcx
.coroutine_layout(coroutine_def_id, coroutine_args.as_coroutine().kind_ty())
.unwrap();
let coroutine_layout =
cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args).unwrap();

let Variants::Multiple { tag_encoding: TagEncoding::Direct, ref variants, .. } =
coroutine_type_and_layout.variants
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ macro_rules! arena_types {
($macro:path) => (
$macro!([
[] layout: rustc_target::abi::LayoutS<rustc_target::abi::FieldIdx, rustc_target::abi::VariantIdx>,
[] proxy_coroutine_layout: rustc_middle::mir::CoroutineLayout<'tcx>,
[] fn_abi: rustc_target::abi::call::FnAbi<'tcx, rustc_middle::ty::Ty<'tcx>>,
// AdtDef are interned and compared by address
[decode] adt_def: rustc_middle::ty::AdtDefData,
Expand Down
18 changes: 17 additions & 1 deletion compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,13 @@ rustc_queries! {
desc { |tcx| "elaborating drops for `{}`", tcx.def_path_str(key) }
}

query templated_mir_drops_elaborated_and_const_checked(ty: Ty<'tcx>)
-> &'tcx Steal<mir::Body<'tcx>>
{
no_hash
desc { |tcx| "elaborating drops for templated mir `{}`", ty }
}

query mir_for_ctfe(
key: DefId
) -> &'tcx mir::Body<'tcx> {
Expand Down Expand Up @@ -570,6 +577,11 @@ rustc_queries! {
desc { |tcx| "checking for `#[coverage(..)]` on `{}`", tcx.def_path_str(key) }
}

/// MIR for templated coroutine after our optimization passes have run.
query templated_optimized_mir(ty: Ty<'tcx>) -> &'tcx mir::Body<'tcx> {
desc { |tcx| "optimizing templated MIR for `{}`", ty }
}

/// Summarizes coverage IDs inserted by the `InstrumentCoverage` MIR pass
/// (for compiler option `-Cinstrument-coverage`), after MIR optimizations
/// have had a chance to potentially remove some of them.
Expand Down Expand Up @@ -1161,7 +1173,11 @@ rustc_queries! {
/// Generates a MIR body for the shim.
query mir_shims(key: ty::InstanceKind<'tcx>) -> &'tcx mir::Body<'tcx> {
arena_cache
desc { |tcx| "generating MIR shim for `{}`", tcx.def_path_str(key.def_id()) }
desc {
|tcx| "generating MIR shim for `{}`, instance={:?}",
tcx.def_path_str(key.def_id()),
key
}
}

/// The `symbol_name` query provides the symbol name for calling a
Expand Down
63 changes: 49 additions & 14 deletions compiler/rustc_middle/src/ty/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -896,22 +896,57 @@ where
i,
),

ty::Coroutine(def_id, args) => match this.variants {
Variants::Single { index } => TyMaybeWithLayout::Ty(
args.as_coroutine()
.state_tys(def_id, tcx)
.nth(index.as_usize())
.unwrap()
.nth(i)
.unwrap(),
),
Variants::Multiple { tag, tag_field, .. } => {
if i == tag_field {
return TyMaybeWithLayout::TyAndLayout(tag_layout(tag));
ty::Coroutine(def_id, args) => {
// layout of `async_drop_in_place<T>::{closure}` in case,
// when T is a coroutine, contains this internal coroutine's ref
if tcx.is_templated_coroutine(def_id) {
fn find_impl_coroutine<'tcx>(
tcx: TyCtxt<'tcx>,
mut cor_ty: Ty<'tcx>,
) -> Ty<'tcx> {
let mut ty = cor_ty;
loop {
if let ty::Coroutine(def_id, args) = ty.kind() {
cor_ty = ty;
if tcx.is_templated_coroutine(*def_id) {
ty = args.first().unwrap().expect_ty();
continue;
} else {
return cor_ty;
}
} else {
return cor_ty;
}
}
}
let arg_cor_ty = args.first().unwrap().expect_ty();
if arg_cor_ty.is_coroutine() {
assert!(i == 0);
let impl_cor_ty = find_impl_coroutine(tcx, arg_cor_ty);
return TyMaybeWithLayout::Ty(Ty::new_mut_ref(
tcx,
tcx.lifetimes.re_static,
impl_cor_ty,
));
}
TyMaybeWithLayout::Ty(args.as_coroutine().prefix_tys()[i])
}
},
match this.variants {
Variants::Single { index } => TyMaybeWithLayout::Ty(
args.as_coroutine()
.state_tys(def_id, tcx)
.nth(index.as_usize())
.unwrap()
.nth(i)
.unwrap(),
),
Variants::Multiple { tag, tag_field, .. } => {
if i == tag_field {
return TyMaybeWithLayout::TyAndLayout(tag_layout(tag));
}
TyMaybeWithLayout::Ty(args.as_coroutine().prefix_tys()[i])
}
}
}

ty::Tuple(tys) => TyMaybeWithLayout::Ty(tys[i]),

Expand Down
70 changes: 66 additions & 4 deletions compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use rustc_errors::{Diag, ErrorGuaranteed, StashKey};
use rustc_hir::def::{CtorKind, CtorOf, DefKind, DocLinkResMap, LifetimeRes, Res};
use rustc_hir::def_id::{CrateNum, DefId, DefIdMap, LocalDefId, LocalDefIdMap};
use rustc_hir::LangItem;
use rustc_index::bit_set::BitMatrix;
use rustc_index::IndexVec;
use rustc_macros::{
extension, Decodable, Encodable, HashStable, TyDecodable, TyEncodable, TypeFoldable,
Expand Down Expand Up @@ -110,7 +111,7 @@ pub use self::IntVarValue::*;
use crate::error::{OpaqueHiddenTypeMismatch, TypeMismatchReason};
use crate::metadata::ModChild;
use crate::middle::privacy::EffectiveVisibilities;
use crate::mir::{Body, CoroutineLayout};
use crate::mir::{Body, CoroutineLayout, CoroutineSavedLocal, CoroutineSavedTy, SourceInfo};
use crate::query::Providers;
use crate::traits::{self, Reveal};
use crate::ty;
Expand Down Expand Up @@ -1771,7 +1772,7 @@ impl<'tcx> TyCtxt<'tcx> {
| ty::InstanceKind::FnPtrAddrShim(..)
| ty::InstanceKind::AsyncDropGlueCtorShim(..) => self.mir_shims(instance),
// async drop glue should be processed specifically, as a templated coroutine
ty::InstanceKind::AsyncDropGlue(_, _ty) => todo!(),
ty::InstanceKind::AsyncDropGlue(_, ty) => self.templated_optimized_mir(ty),
}
}

Expand Down Expand Up @@ -1851,16 +1852,17 @@ impl<'tcx> TyCtxt<'tcx> {
self.def_kind(trait_def_id) == DefKind::TraitAlias
}

/// Returns layout of a coroutine. Layout might be unavailable if the
/// Returns layout of a non-templated coroutine. Layout might be unavailable if the
/// coroutine is tainted by errors.
///
/// Takes `coroutine_kind` which can be acquired from the `CoroutineArgs::kind_ty`,
/// e.g. `args.as_coroutine().kind_ty()`.
pub fn coroutine_layout(
pub fn ordinary_coroutine_layout(
self,
def_id: DefId,
coroutine_kind_ty: Ty<'tcx>,
) -> Option<&'tcx CoroutineLayout<'tcx>> {
debug_assert_ne!(Some(def_id), self.lang_items().async_drop_in_place_poll_fn());
let mir = self.optimized_mir(def_id);
// Regular coroutine
if coroutine_kind_ty.is_unit() {
Expand Down Expand Up @@ -1890,6 +1892,66 @@ impl<'tcx> TyCtxt<'tcx> {
}
}

/// Returns layout of a templated coroutine. Layout might be unavailable if the
/// coroutine is tainted by errors. Atm, the only templated coroutine is
/// `async_drop_in_place<T>::{closure}` returned from `async fn async_drop_in_place<T>(..)`.
pub fn templated_coroutine_layout(self, ty: Ty<'tcx>) -> Option<&'tcx CoroutineLayout<'tcx>> {
self.templated_optimized_mir(ty).coroutine_layout_raw()
}

/// Returns layout of a templated (or not) coroutine. Layout might be unavailable if the
/// coroutine is tainted by errors.
pub fn coroutine_layout(
self,
def_id: DefId,
args: GenericArgsRef<'tcx>,
) -> Option<&'tcx CoroutineLayout<'tcx>> {
if Some(def_id) == self.lang_items().async_drop_in_place_poll_fn() {
fn find_impl_coroutine<'tcx>(tcx: TyCtxt<'tcx>, mut cor_ty: Ty<'tcx>) -> Ty<'tcx> {
let mut ty = cor_ty;
loop {
if let ty::Coroutine(def_id, args) = ty.kind() {
cor_ty = ty;
if tcx.is_templated_coroutine(*def_id) {
ty = args.first().unwrap().expect_ty();
continue;
} else {
return cor_ty;
}
} else {
return cor_ty;
}
}
}
// layout of `async_drop_in_place<T>::{closure}` in case,
// when T is a coroutine, contains this internal coroutine's ref
let arg_cor_ty = args.first().unwrap().expect_ty();
if arg_cor_ty.is_coroutine() {
let impl_cor_ty = find_impl_coroutine(self, arg_cor_ty);
let impl_ref = Ty::new_mut_ref(self, self.lifetimes.re_static, impl_cor_ty);
let span = self.def_span(def_id);
let source_info = SourceInfo::outermost(span);
let proxy_layout = CoroutineLayout {
field_tys: [CoroutineSavedTy {
ty: impl_ref,
source_info,
ignore_for_traits: true,
}]
.into(),
field_names: [None].into(),
variant_fields: [IndexVec::from([CoroutineSavedLocal::ZERO])].into(),
variant_source_info: [source_info].into(),
storage_conflicts: BitMatrix::new(1, 1),
};
return Some(self.arena.alloc(proxy_layout));
} else {
self.templated_coroutine_layout(Ty::new_coroutine(self, def_id, args))
}
} else {
self.ordinary_coroutine_layout(def_id, args.as_coroutine().kind_ty())
}
}

/// Given the `DefId` of an impl, returns the `DefId` of the trait it implements.
/// If it implements no trait, returns `None`.
pub fn trait_id_of_impl(self, def_id: DefId) -> Option<DefId> {
Expand Down
11 changes: 7 additions & 4 deletions compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ impl<'tcx> ty::CoroutineArgs<TyCtxt<'tcx>> {
#[inline]
fn variant_range(&self, def_id: DefId, tcx: TyCtxt<'tcx>) -> Range<VariantIdx> {
// FIXME requires optimized MIR
FIRST_VARIANT
..tcx.coroutine_layout(def_id, tcx.types.unit).unwrap().variant_fields.next_index()
FIRST_VARIANT..tcx.coroutine_layout(def_id, self.args).unwrap().variant_fields.next_index()
}

/// The discriminant for the given variant. Panics if the `variant_index` is
Expand Down Expand Up @@ -139,10 +138,14 @@ impl<'tcx> ty::CoroutineArgs<TyCtxt<'tcx>> {
def_id: DefId,
tcx: TyCtxt<'tcx>,
) -> impl Iterator<Item: Iterator<Item = Ty<'tcx>> + Captures<'tcx>> {
let layout = tcx.coroutine_layout(def_id, self.kind_ty()).unwrap();
let layout = tcx.coroutine_layout(def_id, self.args).unwrap();
layout.variant_fields.iter().map(move |variant| {
variant.iter().map(move |field| {
ty::EarlyBinder::bind(layout.field_tys[*field].ty).instantiate(tcx, self.args)
if tcx.is_templated_coroutine(def_id) {
layout.field_tys[*field].ty
} else {
ty::EarlyBinder::bind(layout.field_tys[*field].ty).instantiate(tcx, self.args)
}
})
})
}
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_mir_dataflow/src/value_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,9 @@ impl<'tcx> Map<'tcx> {
if exclude.contains(local) {
continue;
}
if decl.ty.is_templated_coroutine(tcx) {
continue;
}

// Create a place for the local.
debug_assert!(self.locals[local].is_none());
Expand Down
6 changes: 5 additions & 1 deletion compiler/rustc_mir_transform/src/known_panics_lint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,11 @@ impl CanConstProp {
};
for (local, val) in cpv.can_const_prop.iter_enumerated_mut() {
let ty = body.local_decls[local].ty;
if ty.is_union() {
if ty.is_templated_coroutine(tcx) {
// No const propagation for templated coroutine (AsyncDropGlue)
*val = ConstPropMode::NoPropagation;
continue;
} else if ty.is_union() {
// Unions are incompatible with the current implementation of
// const prop because Rust has no concept of an active
// variant of a union
Expand Down
47 changes: 46 additions & 1 deletion compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use rustc_middle::mir::{
MirPhase, Operand, Place, ProjectionElem, Promoted, RuntimePhase, Rvalue, SourceInfo,
Statement, StatementKind, TerminatorKind, START_BLOCK,
};
use rustc_middle::ty::{self, TyCtxt, TypeVisitableExt};
use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt};
use rustc_middle::util::Providers;
use rustc_middle::{bug, query, span_bug};
use rustc_span::source_map::Spanned;
Expand Down Expand Up @@ -121,9 +121,11 @@ pub fn provide(providers: &mut Providers) {
mir_const_qualif,
mir_promoted,
mir_drops_elaborated_and_const_checked,
templated_mir_drops_elaborated_and_const_checked,
mir_for_ctfe,
mir_coroutine_witnesses: coroutine::mir_coroutine_witnesses,
optimized_mir,
templated_optimized_mir,
is_mir_available,
is_ctfe_mir_available: is_mir_available,
mir_callgraph_reachable: inline::cycle::mir_callgraph_reachable,
Expand Down Expand Up @@ -459,6 +461,21 @@ fn mir_drops_elaborated_and_const_checked(tcx: TyCtxt<'_>, def: LocalDefId) -> &
tcx.alloc_steal_mir(body)
}

/// mir_drops_elaborated_and_const_checked simplified analog for templated coroutine
fn templated_mir_drops_elaborated_and_const_checked<'tcx>(
tcx: TyCtxt<'tcx>,
ty: Ty<'tcx>,
) -> &'tcx Steal<Body<'tcx>> {
let ty::Coroutine(def_id, _) = ty.kind() else {
bug!();
};
assert!(ty.is_templated_coroutine(tcx));

let instance = ty::InstanceKind::AsyncDropGlue(*def_id, ty);
let body = tcx.mir_shims(instance).clone();
tcx.alloc_steal_mir(body)
}

// Made public such that `mir_drops_elaborated_and_const_checked` can be overridden
// by custom rustc drivers, running all the steps by themselves.
pub fn run_analysis_to_runtime_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
Expand Down Expand Up @@ -623,6 +640,11 @@ fn optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> &Body<'_> {
tcx.arena.alloc(inner_optimized_mir(tcx, did))
}

/// Optimize the templated MIR and prepare it for codegen.
fn templated_optimized_mir<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> &'tcx Body<'tcx> {
tcx.arena.alloc(inner_templated_optimized_mir(tcx, ty))
}

fn inner_optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> Body<'_> {
if tcx.is_constructor(did.to_def_id()) {
// There's no reason to run all of the MIR passes on constructors when
Expand Down Expand Up @@ -667,6 +689,29 @@ fn inner_optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> Body<'_> {
body
}

fn inner_templated_optimized_mir<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Body<'tcx> {
debug!("about to call templated_mir_drops_elaborated...");
let body = tcx.templated_mir_drops_elaborated_and_const_checked(ty).steal();
let mut body = remap_mir_for_const_eval_select(tcx, body, hir::Constness::NotConst);

if body.tainted_by_errors.is_some() {
return body;
}

// If `mir_drops_elaborated_and_const_checked` found that the current body has unsatisfiable
// predicates, it will shrink the MIR to a single `unreachable` terminator.
// More generally, if MIR is a lone `unreachable`, there is nothing to optimize.
if let TerminatorKind::Unreachable = body.basic_blocks[START_BLOCK].terminator().kind
&& body.basic_blocks[START_BLOCK].statements.is_empty()
{
return body;
}

run_optimization_passes(tcx, &mut body);

body
}

/// Fetch all the promoteds of an item and prepare their MIR bodies to be ready for
/// constant evaluation once all generic parameters become known.
fn promoted_mir(tcx: TyCtxt<'_>, def: LocalDefId) -> &IndexVec<Promoted, Body<'_>> {
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_mir_transform/src/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
// Same if this is the by-move body of a coroutine-closure.
self.caller_body.coroutine_layout_raw()
} else {
self.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty())
self.tcx.coroutine_layout(def_id, args)
};

let Some(layout) = layout else {
Expand Down
Loading

0 comments on commit 64e4cca

Please sign in to comment.