diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index c83b10a5e583a..7cab3d2792bf0 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -211,6 +211,9 @@ struct TransformVisitor<'tcx> { old_yield_ty: Ty<'tcx>, old_ret_ty: Ty<'tcx>, + + // Fields in the coroutine struct that reference other fields. + self_referential_fields: DenseBitSet, } impl<'tcx> TransformVisitor<'tcx> { @@ -356,6 +359,45 @@ impl<'tcx> TransformVisitor<'tcx> { Place { local: base.local, projection: self.tcx.mk_place_elems(&projection) } } + // Create a `Place` referencing a self-referential coroutine struct field. + // Self-referential coroutine struct fields are wrapped in `UnsafePinned`. This method + // creates the projections to get the `Place` behind `UnsafePinned`. + #[tracing::instrument(level = "trace", skip(self), ret)] + fn make_self_referential_field( + &self, + variant_index: VariantIdx, + idx: FieldIdx, + ty: Ty<'tcx>, + ) -> Place<'tcx> { + let self_place = Place::from(SELF_ARG); + let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index); + + let ty::Adt(adt_def, args) = ty.kind() else { + bug!("expected self-referential field to be an ADT, but it is {:?}", ty); + }; + assert_eq!(adt_def.did(), self.tcx.require_lang_item(LangItem::UnsafePinned, DUMMY_SP)); + let original_ty = args.type_at(0); + + let unsafe_cell_def_id = self.tcx.require_lang_item(LangItem::UnsafeCell, DUMMY_SP); + let unsafe_cell_ty = Ty::new_adt( + self.tcx, + self.tcx.adt_def(unsafe_cell_def_id), + self.tcx.mk_args(&[original_ty.into()]), + ); + + let mut projection = base.projection.to_vec(); + // self.field (UnsafePinned) + projection.push(ProjectionElem::Field(idx, ty)); + // value (UnsafeCell) + projection.push(ProjectionElem::Field(FieldIdx::from_u32(0), unsafe_cell_ty)); + // value (T) + projection.push(ProjectionElem::Field(FieldIdx::from_u32(0), original_ty)); + + let place = Place { local: base.local, projection: self.tcx.mk_place_elems(&projection) }; + debug!(?place); + place + } + // Create a statement which changes the discriminant #[tracing::instrument(level = "trace", skip(self))] fn set_discr(&self, state_disc: VariantIdx, source_info: SourceInfo) -> Statement<'tcx> { @@ -412,8 +454,17 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> { #[tracing::instrument(level = "trace", skip(self), ret)] fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _location: Location) { // Replace an Local in the remap with a coroutine struct access - if let Some(&Some((ty, variant_index, idx))) = self.remap.get(place.local) { - replace_base(place, self.make_field(variant_index, idx, ty), self.tcx); + let local = place.local; + if let Some(&Some((ty, variant_index, idx))) = self.remap.get(local) { + if self.self_referential_fields.contains(local) { + replace_base( + place, + self.make_self_referential_field(variant_index, idx, ty), + self.tcx, + ); + } else { + replace_base(place, self.make_field(variant_index, idx, ty), self.tcx); + } } } @@ -964,10 +1015,12 @@ impl StorageConflictVisitor<'_, '_> { } } -#[tracing::instrument(level = "trace", skip(liveness, body))] +#[tracing::instrument(level = "trace", skip(tcx, liveness, body))] fn compute_layout<'tcx>( + tcx: TyCtxt<'tcx>, liveness: LivenessInfo, body: &Body<'tcx>, + self_referential_fields: &DenseBitSet, ) -> ( IndexVec, VariantIdx, FieldIdx)>>, CoroutineLayout<'tcx>, @@ -1006,8 +1059,19 @@ fn compute_layout<'tcx>( ClearCrossCrate::Set(box LocalInfo::FakeBorrow) => true, _ => false, }; + + // Use `UnsafePinned` for self-referential fields. + let local_ty = if self_referential_fields.contains(local) { + let unsafe_pinned_did = tcx.require_lang_item(LangItem::UnsafePinned, body.span); + let unsafe_pinned_adt_def = tcx.adt_def(unsafe_pinned_did); + let args = tcx.mk_args(&[decl.ty.into()]); + Ty::new_adt(tcx, unsafe_pinned_adt_def, args) + } else { + decl.ty + }; + let decl = - CoroutineSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits }; + CoroutineSavedTy { ty: local_ty, source_info: decl.source_info, ignore_for_traits }; debug!(?decl); tys.push(decl); @@ -1411,10 +1475,16 @@ pub(crate) fn mir_coroutine_witnesses<'tcx>( let always_live_locals = always_storage_live_locals(body); let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable); + let mut self_referential_fields_finder = + SelfReferentialFieldsFinder::new(&liveness_info.saved_locals); + self_referential_fields_finder.visit_body(body); + let self_referential_fields = self_referential_fields_finder.locals_requiring_unsafe_pinned; + // Extract locals which are live across suspension point into `layout` // `remap` gives a mapping from local indices onto coroutine struct indices // `storage_liveness` tells us which locals have live storage at suspension points - let (_, coroutine_layout, _) = compute_layout(liveness_info, body); + let (_, coroutine_layout, _) = + compute_layout(tcx, liveness_info, body, &self_referential_fields); check_suspend_tys(tcx, &coroutine_layout, body); check_field_tys_sized(tcx, &coroutine_layout, def_id); @@ -1557,10 +1627,17 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform { vis.visit_body(body); } + let mut self_referential_fields_finder = + SelfReferentialFieldsFinder::new(&liveness_info.saved_locals); + self_referential_fields_finder.visit_body(body); + let self_referential_fields = self_referential_fields_finder.locals_requiring_unsafe_pinned; + debug!("self referential fields: {:?}", self_referential_fields,); + // Extract locals which are live across suspension point into `layout` // `remap` gives a mapping from local indices onto coroutine struct indices // `storage_liveness` tells us which locals have live storage at suspension points - let (remap, layout, storage_liveness) = compute_layout(liveness_info, body); + let (remap, layout, storage_liveness) = + compute_layout(tcx, liveness_info, body, &self_referential_fields); let can_return = can_return(tcx, body, body.typing_env(tcx)); @@ -1585,6 +1662,7 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform { new_ret_local, old_ret_ty, old_yield_ty, + self_referential_fields, }; transform.visit_body(body); @@ -1599,7 +1677,13 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform { 0..0, args_iter.filter_map(|local| { let (ty, variant_index, idx) = transform.remap[local]?; - let lhs = transform.make_field(variant_index, idx, ty); + + let lhs = if transform.self_referential_fields.contains(local) { + transform.make_self_referential_field(variant_index, idx, ty) + } else { + transform.make_field(variant_index, idx, ty) + }; + let rhs = Rvalue::Use(Operand::Move(local.into())); let assign = StatementKind::Assign(Box::new((lhs, rhs))); Some(Statement::new(source_info, assign)) @@ -1812,6 +1896,36 @@ impl<'tcx> Visitor<'tcx> for EnsureCoroutineFieldAssignmentsNeverAlias<'_> { } } +/// Visitor to find all fields in the coroutine struct that are self-referential. +struct SelfReferentialFieldsFinder<'a> { + saved_locals: &'a CoroutineSavedLocals, + locals_requiring_unsafe_pinned: DenseBitSet, +} + +impl<'a> SelfReferentialFieldsFinder<'a> { + fn new(saved_locals: &'a CoroutineSavedLocals) -> Self { + SelfReferentialFieldsFinder { + saved_locals, + locals_requiring_unsafe_pinned: DenseBitSet::new_empty(saved_locals.domain_size()), + } + } +} + +impl<'tcx> Visitor<'tcx> for SelfReferentialFieldsFinder<'_> { + fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, _loc: Location) { + if let Some(_) = self.saved_locals.get(place.local) { + match rvalue { + Rvalue::Ref(_, _, borrowed_place) => { + if let Some(_) = self.saved_locals.get(borrowed_place.local) { + self.locals_requiring_unsafe_pinned.insert(borrowed_place.local); + } + } + _ => {} + } + } + } +} + fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &CoroutineLayout<'tcx>, body: &Body<'tcx>) { let mut linted_tys = FxHashSet::default(); diff --git a/tests/codegen-llvm/non-self-ref-coroutine.rs b/tests/codegen-llvm/non-self-ref-coroutine.rs new file mode 100644 index 0000000000000..229c289b663e3 --- /dev/null +++ b/tests/codegen-llvm/non-self-ref-coroutine.rs @@ -0,0 +1,30 @@ +// Tests that the coroutine struct is not noalias if it has no self-referential +// fields. +// NOTE: this should eventually use noalias + +//@ compile-flags: -C opt-level=3 +//@ edition: 2021 + +#![crate_type = "lib"] + +use std::future::Future; +use std::pin::Pin; + +async fn inner() {} + +// CHECK-LABEL: ; non_self_ref_coroutine::my_async_fn::{closure#0} +// CHECK-LABEL: my_async_fn +// CHECK-NOT: noalias +// CHECK-SAME: %_1 +async fn my_async_fn(b: bool) -> i32 { + let x = Box::new(5); + if b { + inner().await; + } + *x + 1 +} + +#[no_mangle] +pub fn create_future_as_trait(b: bool) -> Pin>> { + Box::pin(my_async_fn(b)) +} diff --git a/tests/codegen-llvm/self-ref-coroutine.rs b/tests/codegen-llvm/self-ref-coroutine.rs new file mode 100644 index 0000000000000..f09db99704c7b --- /dev/null +++ b/tests/codegen-llvm/self-ref-coroutine.rs @@ -0,0 +1,31 @@ +// Tests that the coroutine struct is not noalias if it has self-referential +// fields. + +//@ compile-flags: -C opt-level=3 +//@ edition: 2021 + +#![crate_type = "lib"] + +use std::future::Future; +use std::pin::Pin; + +async fn inner() {} + +// CHECK-LABEL: ; self_ref_coroutine::my_async_fn::{closure#0} +// CHECK-LABEL: my_async_fn +// CHECK-NOT: noalias +// CHECK-SAME: %_1 +async fn my_async_fn(b: bool) -> i32 { + let x = Box::new(5); + let y = &x; + if b { + inner().await; + std::hint::black_box(y); + } + *x + 1 +} + +#[no_mangle] +pub fn create_future_as_trait(b: bool) -> Pin>> { + Box::pin(my_async_fn(b)) +} diff --git a/tests/mir-opt/async_self_referential_fields.my_async_fn-{closure#0}.StateTransform.after.mir b/tests/mir-opt/async_self_referential_fields.my_async_fn-{closure#0}.StateTransform.after.mir new file mode 100644 index 0000000000000..40c4c4fde1a7b --- /dev/null +++ b/tests/mir-opt/async_self_referential_fields.my_async_fn-{closure#0}.StateTransform.after.mir @@ -0,0 +1,331 @@ +// MIR for `my_async_fn::{closure#0}` after StateTransform +/* coroutine_layout = CoroutineLayout { + field_tys: { + _s0: CoroutineSavedTy { + ty: std::pin::UnsafePinned>, + source_info: SourceInfo { + span: $DIR/async_self_referential_fields.rs:27:9: 27:10 (#0), + scope: scope[0], + }, + ignore_for_traits: false, + }, + _s1: CoroutineSavedTy { + ty: &'{erased} std::boxed::Box, + source_info: SourceInfo { + span: $DIR/async_self_referential_fields.rs:28:9: 28:10 (#0), + scope: scope[1], + }, + ignore_for_traits: false, + }, + _s2: CoroutineSavedTy { + ty: Coroutine( + DefId(0:15 ~ async_self_referential_fields[ea9b]::inner_async_fn::{closure#0}), + [ + (), + std::future::ResumeTy, + (), + (), + (), + ], + ), + source_info: SourceInfo { + span: $DIR/async_self_referential_fields.rs:29:5: 29:27 (#10), + scope: scope[2], + }, + ignore_for_traits: false, + }, + }, + variant_fields: { + Unresumed(0): [], + Returned (1): [], + Panicked (2): [], + Suspend0 (3): [_s0, _s1, _s2], + }, + storage_conflicts: BitMatrix(3x3) { + (_s0, _s0), + (_s0, _s1), + (_s0, _s2), + (_s1, _s0), + (_s1, _s1), + (_s1, _s2), + (_s2, _s0), + (_s2, _s1), + (_s2, _s2), + }, +} */ + +fn my_async_fn::{closure#0}(_1: Pin<&mut {async fn body of my_async_fn()}>, _2: &mut std::task::Context<'_>) -> std::task::Poll { + debug _task_context => _2; + let mut _0: std::task::Poll; + let _3: std::boxed::Box; + let _5: (); + let mut _6: {async fn body of inner_async_fn()}; + let mut _7: {async fn body of inner_async_fn()}; + let mut _9: (); + let _10: (); + let mut _11: std::task::Poll<()>; + let mut _12: std::pin::Pin<&mut {async fn body of inner_async_fn()}>; + let mut _13: &mut {async fn body of inner_async_fn()}; + let mut _14: &mut {async fn body of inner_async_fn()}; + let mut _15: &mut std::task::Context<'_>; + let mut _16: &mut std::task::Context<'_>; + let mut _17: &mut std::task::Context<'_>; + let mut _18: isize; + let mut _20: !; + let mut _21: &mut std::task::Context<'_>; + let mut _22: (); + let _23: &std::boxed::Box; + let mut _24: &std::boxed::Box; + let mut _25: i32; + let mut _26: *const i32; + let mut _27: i32; + let mut _28: u32; + let mut _29: &mut {async fn body of my_async_fn()}; + scope 1 { + debug x => (((((*_29) as variant#3).0: std::pin::UnsafePinned>).0: std::cell::UnsafeCell>).0: std::boxed::Box); + let _4: &std::boxed::Box; + scope 2 { + debug y => (((*_29) as variant#3).1: &std::boxed::Box); + let mut _8: {async fn body of inner_async_fn()}; + scope 3 { + debug __awaitee => (((*_29) as variant#3).2: {async fn body of inner_async_fn()}); + let _19: (); + scope 4 { + debug result => _19; + } + } + } + } + + bb0: { + _29 = copy (_1.0: &mut {async fn body of my_async_fn()}); + _28 = discriminant((*_29)); + switchInt(move _28) -> [0: bb1, 1: bb33, 2: bb32, 3: bb31, otherwise: bb9]; + } + + bb1: { + nop; + (((((*_29) as variant#3).0: std::pin::UnsafePinned>).0: std::cell::UnsafeCell>).0: std::boxed::Box) = Box::::new(const 5_i32) -> [return: bb2, unwind: bb27]; + } + + bb2: { + nop; + (((*_29) as variant#3).1: &std::boxed::Box) = &(((((*_29) as variant#3).0: std::pin::UnsafePinned>).0: std::cell::UnsafeCell>).0: std::boxed::Box); + StorageLive(_5); + StorageLive(_6); + StorageLive(_7); + _7 = inner_async_fn() -> [return: bb3, unwind: bb24]; + } + + bb3: { + _6 = <{async fn body of inner_async_fn()} as IntoFuture>::into_future(move _7) -> [return: bb4, unwind: bb23]; + } + + bb4: { + StorageDead(_7); + PlaceMention(_6); + nop; + (((*_29) as variant#3).2: {async fn body of inner_async_fn()}) = move _6; + goto -> bb5; + } + + bb5: { + StorageLive(_10); + StorageLive(_11); + StorageLive(_12); + StorageLive(_13); + StorageLive(_14); + _14 = &mut (((*_29) as variant#3).2: {async fn body of inner_async_fn()}); + _13 = &mut (*_14); + _12 = Pin::<&mut {async fn body of inner_async_fn()}>::new_unchecked(move _13) -> [return: bb6, unwind: bb20]; + } + + bb6: { + StorageDead(_13); + StorageLive(_15); + StorageLive(_16); + StorageLive(_17); + _17 = copy _2; + _16 = move _17; + goto -> bb7; + } + + bb7: { + _15 = &mut (*_16); + StorageDead(_17); + _11 = <{async fn body of inner_async_fn()} as Future>::poll(move _12, move _15) -> [return: bb8, unwind: bb19]; + } + + bb8: { + StorageDead(_15); + StorageDead(_12); + PlaceMention(_11); + _18 = discriminant(_11); + switchInt(move _18) -> [0: bb11, 1: bb10, otherwise: bb9]; + } + + bb9: { + unreachable; + } + + bb10: { + _10 = const (); + StorageDead(_16); + StorageDead(_14); + StorageDead(_11); + StorageDead(_10); + StorageLive(_21); + StorageLive(_22); + _22 = (); + _0 = std::task::Poll::::Pending; + StorageDead(_5); + StorageDead(_6); + StorageDead(_21); + StorageDead(_22); + discriminant((*_29)) = 3; + return; + } + + bb11: { + StorageLive(_19); + _19 = copy ((_11 as Ready).0: ()); + _5 = copy _19; + StorageDead(_19); + StorageDead(_16); + StorageDead(_14); + StorageDead(_11); + StorageDead(_10); + drop((((*_29) as variant#3).2: {async fn body of inner_async_fn()})) -> [return: bb13, unwind: bb22]; + } + + bb12: { + StorageDead(_22); + _2 = move _21; + StorageDead(_21); + _9 = const (); + goto -> bb5; + } + + bb13: { + nop; + goto -> bb14; + } + + bb14: { + StorageDead(_6); + StorageDead(_5); + StorageLive(_23); + StorageLive(_24); + _24 = copy (((*_29) as variant#3).1: &std::boxed::Box); + _23 = std::hint::black_box::<&Box>(move _24) -> [return: bb15, unwind: bb18]; + } + + bb15: { + StorageDead(_24); + StorageDead(_23); + StorageLive(_25); + _26 = copy (((((((*_29) as variant#3).0: std::pin::UnsafePinned>).0: std::cell::UnsafeCell>).0: std::boxed::Box).0: std::ptr::Unique).0: std::ptr::NonNull) as *const i32 (Transmute); + _25 = copy (*_26); + _27 = Add(move _25, const 1_i32); + StorageDead(_25); + nop; + drop((((((*_29) as variant#3).0: std::pin::UnsafePinned>).0: std::cell::UnsafeCell>).0: std::boxed::Box)) -> [return: bb16, unwind: bb27]; + } + + bb16: { + nop; + goto -> bb29; + } + + bb17: { + _0 = std::task::Poll::::Ready(move _27); + discriminant((*_29)) = 1; + return; + } + + bb18 (cleanup): { + StorageDead(_24); + StorageDead(_23); + goto -> bb26; + } + + bb19 (cleanup): { + StorageDead(_15); + StorageDead(_12); + StorageDead(_16); + goto -> bb21; + } + + bb20 (cleanup): { + StorageDead(_13); + StorageDead(_12); + goto -> bb21; + } + + bb21 (cleanup): { + StorageDead(_14); + StorageDead(_11); + StorageDead(_10); + drop((((*_29) as variant#3).2: {async fn body of inner_async_fn()})) -> [return: bb22, unwind terminate(cleanup)]; + } + + bb22 (cleanup): { + nop; + goto -> bb25; + } + + bb23 (cleanup): { + goto -> bb24; + } + + bb24 (cleanup): { + StorageDead(_7); + goto -> bb25; + } + + bb25 (cleanup): { + StorageDead(_6); + StorageDead(_5); + goto -> bb26; + } + + bb26 (cleanup): { + nop; + drop((((((*_29) as variant#3).0: std::pin::UnsafePinned>).0: std::cell::UnsafeCell>).0: std::boxed::Box)) -> [return: bb27, unwind terminate(cleanup)]; + } + + bb27 (cleanup): { + nop; + goto -> bb28; + } + + bb28 (cleanup): { + goto -> bb30; + } + + bb29: { + goto -> bb17; + } + + bb30 (cleanup): { + discriminant((*_29)) = 2; + resume; + } + + bb31: { + StorageLive(_5); + StorageLive(_6); + StorageLive(_21); + StorageLive(_22); + _21 = move _2; + goto -> bb12; + } + + bb32: { + assert(const false, "`async fn` resumed after panicking") -> [success: bb32, unwind continue]; + } + + bb33: { + assert(const false, "`async fn` resumed after completion") -> [success: bb33, unwind continue]; + } +} diff --git a/tests/mir-opt/async_self_referential_fields.rs b/tests/mir-opt/async_self_referential_fields.rs new file mode 100644 index 0000000000000..c676f9fd3f290 --- /dev/null +++ b/tests/mir-opt/async_self_referential_fields.rs @@ -0,0 +1,38 @@ +//@ edition:2021 +// skip-filecheck +// EMIT_MIR async_self_referential_fields.my_async_fn-{closure#0}.StateTransform.after.mir + +#![allow(unused)] + +use std::future::Future; +use std::ops::{AsyncFn, AsyncFnMut, AsyncFnOnce}; +use std::pin::pin; +use std::task::*; + +pub fn block_on(fut: impl Future) -> T { + let mut fut = pin!(fut); + let ctx = &mut Context::from_waker(Waker::noop()); + + loop { + match fut.as_mut().poll(ctx) { + Poll::Pending => {} + Poll::Ready(t) => break t, + } + } +} + +async fn inner_async_fn() {} + +async fn my_async_fn() -> i32 { + let x = Box::new(5); + let y = &x; + inner_async_fn().await; + std::hint::black_box(y); + *x + 1 +} + +fn main() { + block_on(async { + my_async_fn().await; + }); +}