Skip to content

Commit

Permalink
Fix ABI for FnMut/Fn impls for async closures
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed Mar 19, 2024
1 parent 05116c5 commit f1fef64
Show file tree
Hide file tree
Showing 12 changed files with 81 additions and 18 deletions.
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/mir/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ macro_rules! make_mir_visitor {
ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } |
ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id: _def_id,
receiver_by_ref: _,
} |
ty::InstanceDef::CoroutineKindShim { coroutine_def_id: _def_id } |
ty::InstanceDef::DropGlue(_def_id, None) => {}
Expand Down
11 changes: 10 additions & 1 deletion compiler/rustc_middle/src/ty/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,15 @@ pub enum InstanceDef<'tcx> {
/// The body generated here differs significantly from the `ClosureOnceShim`,
/// since we need to generate a distinct coroutine type that will move the
/// closure's upvars *out* of the closure.
ConstructCoroutineInClosureShim { coroutine_closure_def_id: DefId },
ConstructCoroutineInClosureShim {
coroutine_closure_def_id: DefId,
// Whether the generated MIR body takes the coroutine by-ref. This is
// because the signature of `<{async fn} as FnMut>::call_mut` is:
// `fn(&mut self, args: A) -> <Self as FnOnce>::Output`, that is to say
// that it returns the `FnOnce`-flavored coroutine but takes the closure
// by ref (and similarly for `Fn::call`).
receiver_by_ref: bool,
},

/// `<[coroutine] as Future>::poll`, but for coroutines produced when `AsyncFnOnce`
/// is called on a coroutine-closure whose closure kind greater than `FnOnce`, or
Expand Down Expand Up @@ -188,6 +196,7 @@ impl<'tcx> InstanceDef<'tcx> {
| InstanceDef::ClosureOnceShim { call_once: def_id, track_caller: _ }
| ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id: def_id,
receiver_by_ref: _,
}
| ty::InstanceDef::CoroutineKindShim { coroutine_def_id: def_id }
| InstanceDef::DropGlue(def_id, _)
Expand Down
24 changes: 19 additions & 5 deletions compiler/rustc_mir_transform/src/shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
build_call_shim(tcx, instance, Some(Adjustment::RefMut), CallKind::Direct(call_mut))
}

ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id } => {
build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id)
}
ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id,
receiver_by_ref,
} => build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id, receiver_by_ref),

ty::InstanceDef::CoroutineKindShim { coroutine_def_id } => {
return tcx.optimized_mir(coroutine_def_id).coroutine_by_move_body().unwrap().clone();
Expand Down Expand Up @@ -1015,12 +1016,17 @@ fn build_fn_ptr_addr_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'t
fn build_construct_coroutine_by_move_shim<'tcx>(
tcx: TyCtxt<'tcx>,
coroutine_closure_def_id: DefId,
receiver_by_ref: bool,
) -> Body<'tcx> {
let self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
let mut self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
let ty::CoroutineClosure(_, args) = *self_ty.kind() else {
bug!();
};

if receiver_by_ref {
self_ty = Ty::new_mut_ptr(tcx, self_ty);
}

let poly_sig = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
tcx.mk_fn_sig(
[self_ty].into_iter().chain(sig.tupled_inputs_ty.tuple_fields()),
Expand Down Expand Up @@ -1076,11 +1082,19 @@ fn build_construct_coroutine_by_move_shim<'tcx>(

let source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id,
receiver_by_ref,
});

let body =
new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span);
dump_mir(tcx, false, "coroutine_closure_by_move", &0, &body, |_, _| Ok(()));
dump_mir(
tcx,
false,
if receiver_by_ref { "coroutine_closure_by_ref" } else { "coroutine_closure_by_move" },
&0,
&body,
|_, _| Ok(()),
);

body
}
15 changes: 11 additions & 4 deletions compiler/rustc_ty_utils/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,18 @@ fn fn_sig_for_fn_abi<'tcx>(
// a separate def-id for these bodies.
let mut coroutine_kind = args.as_coroutine_closure().kind();

if let InstanceDef::ConstructCoroutineInClosureShim { .. } = instance.def {
coroutine_kind = ty::ClosureKind::FnOnce;
}
let env_ty =
if let InstanceDef::ConstructCoroutineInClosureShim { receiver_by_ref, .. } =
instance.def
{
coroutine_kind = ty::ClosureKind::FnOnce;

let env_ty = tcx.closure_env_ty(coroutine_ty, coroutine_kind, env_region);
// Implementations of `FnMut` and `Fn` for coroutine-closures
// still take their receiver by ref.
if receiver_by_ref { Ty::new_mut_ptr(tcx, coroutine_ty) } else { coroutine_ty }
} else {
tcx.closure_env_ty(coroutine_ty, coroutine_kind, env_region)
};

let sig = sig.skip_binder();
ty::Binder::bind_with_vars(
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_ty_utils/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ fn resolve_associated_item<'tcx>(
Some(Instance {
def: ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id,
receiver_by_ref: target_kind != ty::ClosureKind::FnOnce,
},
args,
})
Expand All @@ -304,6 +305,7 @@ fn resolve_associated_item<'tcx>(
Some(Instance {
def: ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id,
receiver_by_ref: false,
},
args,
})
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// MIR for `main::{closure#0}::{closure#0}::{closure#0}` 0 coroutine_by_move

fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10}, _2: ResumeTy) -> ()
fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10}, _2: ResumeTy) -> ()
yields ()
{
debug _task_context => _2;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// MIR for `main::{closure#0}::{closure#0}::{closure#0}` 0 coroutine_by_move

fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10}, _2: ResumeTy) -> ()
fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10}, _2: ResumeTy) -> ()
yields ()
{
debug _task_context => _2;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// MIR for `main::{closure#0}::{closure#0}` 0 coroutine_closure_by_move

fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:37:33: 37:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10};
fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:42:33: 42:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10};

bb0: {
_0 = {coroutine@$DIR/async_closure_shims.rs:37:53: 40:10 (#0)} { a: move _2, b: move (_1.0: i32) };
_0 = {coroutine@$DIR/async_closure_shims.rs:42:53: 45:10 (#0)} { a: move _2, b: move (_1.0: i32) };
return;
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// MIR for `main::{closure#0}::{closure#0}` 0 coroutine_closure_by_move

fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:37:33: 37:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10};
fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:42:33: 42:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10};

bb0: {
_0 = {coroutine@$DIR/async_closure_shims.rs:37:53: 40:10 (#0)} { a: move _2, b: move (_1.0: i32) };
_0 = {coroutine@$DIR/async_closure_shims.rs:42:53: 45:10 (#0)} { a: move _2, b: move (_1.0: i32) };
return;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// MIR for `main::{closure#0}::{closure#1}` 0 coroutine_closure_by_ref

fn main::{closure#0}::{closure#1}(_1: *mut {async closure@$DIR/async_closure_shims.rs:49:29: 49:48}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10};

bb0: {
_0 = {coroutine@$DIR/async_closure_shims.rs:49:49: 51:10 (#0)} { a: move _2 };
return;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// MIR for `main::{closure#0}::{closure#1}` 0 coroutine_closure_by_ref

fn main::{closure#0}::{closure#1}(_1: *mut {async closure@$DIR/async_closure_shims.rs:49:29: 49:48}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10};

bb0: {
_0 = {coroutine@$DIR/async_closure_shims.rs:49:49: 51:10 (#0)} { a: move _2 };
return;
}
}
10 changes: 10 additions & 0 deletions tests/mir-opt/async_closure_shims.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,13 @@ async fn call_once(f: impl AsyncFnOnce(i32)) {
f(1).await;
}

async fn call_normal<F: Future<Output = ()>>(f: &impl Fn(i32) -> F) {
f(1).await;
}

// EMIT_MIR async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.mir
// EMIT_MIR async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.mir
// EMIT_MIR async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.mir
pub fn main() {
block_on(async {
let b = 2i32;
Expand All @@ -40,5 +45,10 @@ pub fn main() {
};
call_mut(&mut async_closure).await;
call_once(async_closure).await;

let async_closure = async move |a: i32| {
let a = &a;
};
call_normal(&async_closure).await;
});
}

0 comments on commit f1fef64

Please sign in to comment.