Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 121 additions & 7 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ struct TransformVisitor<'tcx> {
old_yield_ty: Ty<'tcx>,

old_ret_ty: Ty<'tcx>,

// Fields in the coroutine struct that reference other fields.
self_referential_fields: DenseBitSet<Local>,
}

impl<'tcx> TransformVisitor<'tcx> {
Expand Down Expand Up @@ -356,6 +359,45 @@ impl<'tcx> TransformVisitor<'tcx> {
Place { local: base.local, projection: self.tcx.mk_place_elems(&projection) }
}

// Create a `Place` referencing a self-referential coroutine struct field.
// Self-referential coroutine struct fields are wrapped in `UnsafePinned`. This method
// creates the projections to get the `Place` behind `UnsafePinned`.
#[tracing::instrument(level = "trace", skip(self), ret)]
fn make_self_referential_field(
&self,
variant_index: VariantIdx,
idx: FieldIdx,
ty: Ty<'tcx>,
) -> Place<'tcx> {
let self_place = Place::from(SELF_ARG);
let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index);

let ty::Adt(adt_def, args) = ty.kind() else {
bug!("expected self-referential field to be an ADT, but it is {:?}", ty);
};
assert_eq!(adt_def.did(), self.tcx.require_lang_item(LangItem::UnsafePinned, DUMMY_SP));
let original_ty = args.type_at(0);

let unsafe_cell_def_id = self.tcx.require_lang_item(LangItem::UnsafeCell, DUMMY_SP);
let unsafe_cell_ty = Ty::new_adt(
self.tcx,
self.tcx.adt_def(unsafe_cell_def_id),
self.tcx.mk_args(&[original_ty.into()]),
);

let mut projection = base.projection.to_vec();
// self.field (UnsafePinned<T>)
projection.push(ProjectionElem::Field(idx, ty));
// value (UnsafeCell<T>)
projection.push(ProjectionElem::Field(FieldIdx::from_u32(0), unsafe_cell_ty));
// value (T)
projection.push(ProjectionElem::Field(FieldIdx::from_u32(0), original_ty));

let place = Place { local: base.local, projection: self.tcx.mk_place_elems(&projection) };
debug!(?place);
place
}

// Create a statement which changes the discriminant
#[tracing::instrument(level = "trace", skip(self))]
fn set_discr(&self, state_disc: VariantIdx, source_info: SourceInfo) -> Statement<'tcx> {
Expand Down Expand Up @@ -412,8 +454,17 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
#[tracing::instrument(level = "trace", skip(self), ret)]
fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _location: Location) {
// Replace an Local in the remap with a coroutine struct access
if let Some(&Some((ty, variant_index, idx))) = self.remap.get(place.local) {
replace_base(place, self.make_field(variant_index, idx, ty), self.tcx);
let local = place.local;
if let Some(&Some((ty, variant_index, idx))) = self.remap.get(local) {
if self.self_referential_fields.contains(local) {
replace_base(
place,
self.make_self_referential_field(variant_index, idx, ty),
self.tcx,
);
} else {
replace_base(place, self.make_field(variant_index, idx, ty), self.tcx);
}
}
}

Expand Down Expand Up @@ -964,10 +1015,12 @@ impl StorageConflictVisitor<'_, '_> {
}
}

#[tracing::instrument(level = "trace", skip(liveness, body))]
#[tracing::instrument(level = "trace", skip(tcx, liveness, body))]
fn compute_layout<'tcx>(
tcx: TyCtxt<'tcx>,
liveness: LivenessInfo,
body: &Body<'tcx>,
self_referential_fields: &DenseBitSet<Local>,
) -> (
IndexVec<Local, Option<(Ty<'tcx>, VariantIdx, FieldIdx)>>,
CoroutineLayout<'tcx>,
Expand Down Expand Up @@ -1006,8 +1059,19 @@ fn compute_layout<'tcx>(
ClearCrossCrate::Set(box LocalInfo::FakeBorrow) => true,
_ => false,
};

// Use `UnsafePinned` for self-referential fields.
let local_ty = if self_referential_fields.contains(local) {
let unsafe_pinned_did = tcx.require_lang_item(LangItem::UnsafePinned, body.span);
let unsafe_pinned_adt_def = tcx.adt_def(unsafe_pinned_did);
let args = tcx.mk_args(&[decl.ty.into()]);
Ty::new_adt(tcx, unsafe_pinned_adt_def, args)
} else {
decl.ty
};

let decl =
CoroutineSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits };
CoroutineSavedTy { ty: local_ty, source_info: decl.source_info, ignore_for_traits };
debug!(?decl);

tys.push(decl);
Expand Down Expand Up @@ -1411,10 +1475,16 @@ pub(crate) fn mir_coroutine_witnesses<'tcx>(
let always_live_locals = always_storage_live_locals(body);
let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);

let mut self_referential_fields_finder =
SelfReferentialFieldsFinder::new(&liveness_info.saved_locals);
self_referential_fields_finder.visit_body(body);
let self_referential_fields = self_referential_fields_finder.locals_requiring_unsafe_pinned;

// Extract locals which are live across suspension point into `layout`
// `remap` gives a mapping from local indices onto coroutine struct indices
// `storage_liveness` tells us which locals have live storage at suspension points
let (_, coroutine_layout, _) = compute_layout(liveness_info, body);
let (_, coroutine_layout, _) =
compute_layout(tcx, liveness_info, body, &self_referential_fields);

check_suspend_tys(tcx, &coroutine_layout, body);
check_field_tys_sized(tcx, &coroutine_layout, def_id);
Expand Down Expand Up @@ -1557,10 +1627,17 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
vis.visit_body(body);
}

let mut self_referential_fields_finder =
SelfReferentialFieldsFinder::new(&liveness_info.saved_locals);
self_referential_fields_finder.visit_body(body);
let self_referential_fields = self_referential_fields_finder.locals_requiring_unsafe_pinned;
debug!("self referential fields: {:?}", self_referential_fields,);

// Extract locals which are live across suspension point into `layout`
// `remap` gives a mapping from local indices onto coroutine struct indices
// `storage_liveness` tells us which locals have live storage at suspension points
let (remap, layout, storage_liveness) = compute_layout(liveness_info, body);
let (remap, layout, storage_liveness) =
compute_layout(tcx, liveness_info, body, &self_referential_fields);

let can_return = can_return(tcx, body, body.typing_env(tcx));

Expand All @@ -1585,6 +1662,7 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
new_ret_local,
old_ret_ty,
old_yield_ty,
self_referential_fields,
};
transform.visit_body(body);

Expand All @@ -1599,7 +1677,13 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
0..0,
args_iter.filter_map(|local| {
let (ty, variant_index, idx) = transform.remap[local]?;
let lhs = transform.make_field(variant_index, idx, ty);

let lhs = if transform.self_referential_fields.contains(local) {
transform.make_self_referential_field(variant_index, idx, ty)
} else {
transform.make_field(variant_index, idx, ty)
};

let rhs = Rvalue::Use(Operand::Move(local.into()));
let assign = StatementKind::Assign(Box::new((lhs, rhs)));
Some(Statement::new(source_info, assign))
Expand Down Expand Up @@ -1812,6 +1896,36 @@ impl<'tcx> Visitor<'tcx> for EnsureCoroutineFieldAssignmentsNeverAlias<'_> {
}
}

/// Visitor to find all fields in the coroutine struct that are self-referential.
struct SelfReferentialFieldsFinder<'a> {
saved_locals: &'a CoroutineSavedLocals,
locals_requiring_unsafe_pinned: DenseBitSet<Local>,
}

impl<'a> SelfReferentialFieldsFinder<'a> {
fn new(saved_locals: &'a CoroutineSavedLocals) -> Self {
SelfReferentialFieldsFinder {
saved_locals,
locals_requiring_unsafe_pinned: DenseBitSet::new_empty(saved_locals.domain_size()),
}
}
}

impl<'tcx> Visitor<'tcx> for SelfReferentialFieldsFinder<'_> {
fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, _loc: Location) {
if let Some(_) = self.saved_locals.get(place.local) {
match rvalue {
Rvalue::Ref(_, _, borrowed_place) => {
if let Some(_) = self.saved_locals.get(borrowed_place.local) {
self.locals_requiring_unsafe_pinned.insert(borrowed_place.local);
}
}
_ => {}
}
}
}
}

fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &CoroutineLayout<'tcx>, body: &Body<'tcx>) {
let mut linted_tys = FxHashSet::default();

Expand Down
30 changes: 30 additions & 0 deletions tests/codegen-llvm/non-self-ref-coroutine.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Tests that the coroutine struct is not noalias if it has no self-referential
// fields.
// NOTE: this should eventually use noalias

//@ compile-flags: -C opt-level=3
//@ edition: 2021

#![crate_type = "lib"]

use std::future::Future;
use std::pin::Pin;

async fn inner() {}

// CHECK-LABEL: ; non_self_ref_coroutine::my_async_fn::{closure#0}
// CHECK-LABEL: my_async_fn
// CHECK-NOT: noalias
// CHECK-SAME: %_1
async fn my_async_fn(b: bool) -> i32 {
let x = Box::new(5);
if b {
inner().await;
}
*x + 1
}

#[no_mangle]
pub fn create_future_as_trait(b: bool) -> Pin<Box<dyn Future<Output = i32>>> {
Box::pin(my_async_fn(b))
}
31 changes: 31 additions & 0 deletions tests/codegen-llvm/self-ref-coroutine.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Tests that the coroutine struct is not noalias if it has self-referential
// fields.

//@ compile-flags: -C opt-level=3
//@ edition: 2021

#![crate_type = "lib"]

use std::future::Future;
use std::pin::Pin;

async fn inner() {}

// CHECK-LABEL: ; self_ref_coroutine::my_async_fn::{closure#0}
// CHECK-LABEL: my_async_fn
// CHECK-NOT: noalias
// CHECK-SAME: %_1
async fn my_async_fn(b: bool) -> i32 {
let x = Box::new(5);
let y = &x;
if b {
inner().await;
std::hint::black_box(y);
}
*x + 1
}

#[no_mangle]
pub fn create_future_as_trait(b: bool) -> Pin<Box<dyn Future<Output = i32>>> {
Box::pin(my_async_fn(b))
}
Loading
Loading