From d71f7be21812203f343285e40456975ff5cb8ae9 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Tue, 26 Dec 2023 16:42:40 +0000 Subject: [PATCH] Couple of random coroutine pass simplifications --- compiler/rustc_mir_transform/src/coroutine.rs | 35 +++++++------------ 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index 5e434d30b003b..ce1a36cf67021 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -1417,20 +1417,18 @@ fn create_coroutine_resume_function<'tcx>( cases.insert(0, (UNRESUMED, START_BLOCK)); // Panic when resumed on the returned or poisoned state - let coroutine_kind = body.coroutine_kind().unwrap(); - if can_unwind { cases.insert( 1, - (POISONED, insert_panic_block(tcx, body, ResumedAfterPanic(coroutine_kind))), + (POISONED, insert_panic_block(tcx, body, ResumedAfterPanic(transform.coroutine_kind))), ); } if can_return { - let block = match coroutine_kind { + let block = match transform.coroutine_kind { CoroutineKind::Desugared(CoroutineDesugaring::Async, _) | CoroutineKind::Coroutine(_) => { - insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind)) + insert_panic_block(tcx, body, ResumedAfterReturn(transform.coroutine_kind)) } CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) | CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => { @@ -1444,7 +1442,7 @@ fn create_coroutine_resume_function<'tcx>( make_coroutine_state_argument_indirect(tcx, body); - match coroutine_kind { + match transform.coroutine_kind { // Iterator::next doesn't accept a pinned argument, // unlike for all other coroutine kinds. CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {} @@ -1614,12 +1612,6 @@ impl<'tcx> MirPass<'tcx> for StateTransform { } }; - let is_async_kind = - matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Async, _)); - let is_async_gen_kind = - matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)); - let is_gen_kind = - matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)); let new_ret_ty = match coroutine_kind { CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => { // Compute Poll @@ -1653,7 +1645,10 @@ impl<'tcx> MirPass<'tcx> for StateTransform { let old_ret_local = replace_local(RETURN_PLACE, new_ret_ty, body, tcx); // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies. - if is_async_kind || is_async_gen_kind { + if matches!( + coroutine_kind, + CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _) + ) { transform_async_context(tcx, body); } @@ -1662,11 +1657,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // case there is no `Assign` to it that the transform can turn into a store to the coroutine // state. After the yield the slot in the coroutine state would then be uninitialized. let resume_local = Local::new(2); - let resume_ty = if is_async_kind { - Ty::new_task_context(tcx) - } else { - body.local_decls[resume_local].ty - }; + let resume_ty = body.local_decls[resume_local].ty; let old_resume_local = replace_local(resume_local, resume_ty, body, tcx); // When first entering the coroutine, move the resume argument into its old local @@ -1709,11 +1700,11 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // Run the transformation which converts Places from Local to coroutine struct // accesses for locals in `remap`. // It also rewrites `return x` and `yield y` as writing a new coroutine state and returning - // either CoroutineState::Complete(x) and CoroutineState::Yielded(y), - // or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`. + // either `CoroutineState::Complete(x)` and `CoroutineState::Yielded(y)`, + // or `Poll::Ready(x)` and `Poll::Pending` respectively depending on the coroutine kind. let mut transform = TransformVisitor { tcx, - coroutine_kind: body.coroutine_kind().unwrap(), + coroutine_kind, remap, storage_liveness, always_live_locals, @@ -1730,7 +1721,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { body.spread_arg = None; // Remove the context argument within generator bodies. - if is_gen_kind { + if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) { transform_gen_context(tcx, body); }