Skip to content

Commit

Permalink
Implement async gen blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed Nov 28, 2023
1 parent e21ab89 commit 654e0ba
Show file tree
Hide file tree
Showing 31 changed files with 549 additions and 50 deletions.
2 changes: 2 additions & 0 deletions compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1501,6 +1501,7 @@ pub enum ExprKind {
pub enum GenBlockKind {
Async,
Gen,
AsyncGen,
}

impl fmt::Display for GenBlockKind {
Expand All @@ -1514,6 +1515,7 @@ impl GenBlockKind {
match self {
GenBlockKind::Async => "async",
GenBlockKind::Gen => "gen",
GenBlockKind::AsyncGen => "async gen",
}
}
}
Expand Down
138 changes: 125 additions & 13 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,15 @@ impl<'hir> LoweringContext<'_, 'hir> {
hir::CoroutineSource::Block,
|this| this.with_new_scopes(e.span, |this| this.lower_block_expr(block)),
),
ExprKind::Gen(capture_clause, block, GenBlockKind::AsyncGen) => self
.make_async_gen_expr(
*capture_clause,
e.id,
None,
e.span,
hir::CoroutineSource::Block,
|this| this.with_new_scopes(e.span, |this| this.lower_block_expr(block)),
),
ExprKind::Yield(opt_expr) => self.lower_expr_yield(e.span, opt_expr.as_deref()),
ExprKind::Err => hir::ExprKind::Err(
self.tcx.sess.delay_span_bug(e.span, "lowered ExprKind::Err"),
Expand Down Expand Up @@ -726,6 +735,84 @@ impl<'hir> LoweringContext<'_, 'hir> {
}))
}

/// Lower a `async gen` construct to a generator that implements `AsyncIterator`.
///
/// This results in:
///
/// ```text
/// static move? |_task_context| -> () {
/// <body>
/// }
/// ```
pub(super) fn make_async_gen_expr(
&mut self,
capture_clause: CaptureBy,
closure_node_id: NodeId,
_yield_ty: Option<hir::FnRetTy<'hir>>,
span: Span,
async_coroutine_source: hir::CoroutineSource,
body: impl FnOnce(&mut Self) -> hir::Expr<'hir>,
) -> hir::ExprKind<'hir> {
let output = hir::FnRetTy::DefaultReturn(self.lower_span(span));

// Resume argument type: `ResumeTy`
let unstable_span =
self.mark_span_with_reason(DesugaringKind::Async, span, self.allow_gen_future.clone());
let resume_ty = hir::QPath::LangItem(hir::LangItem::ResumeTy, unstable_span);
let input_ty = hir::Ty {
hir_id: self.next_id(),
kind: hir::TyKind::Path(resume_ty),
span: unstable_span,
};

// The closure/coroutine `FnDecl` takes a single (resume) argument of type `input_ty`.
let fn_decl = self.arena.alloc(hir::FnDecl {
inputs: arena_vec![self; input_ty],
output,
c_variadic: false,
implicit_self: hir::ImplicitSelfKind::None,
lifetime_elision_allowed: false,
});

// Lower the argument pattern/ident. The ident is used again in the `.await` lowering.
let (pat, task_context_hid) = self.pat_ident_binding_mode(
span,
Ident::with_dummy_span(sym::_task_context),
hir::BindingAnnotation::MUT,
);
let param = hir::Param {
hir_id: self.next_id(),
pat,
ty_span: self.lower_span(span),
span: self.lower_span(span),
};
let params = arena_vec![self; param];

let body = self.lower_body(move |this| {
this.coroutine_kind = Some(hir::CoroutineKind::AsyncGen(async_coroutine_source));

let old_ctx = this.task_context;
this.task_context = Some(task_context_hid);
let res = body(this);
this.task_context = old_ctx;
(params, res)
});

// `static |_task_context| -> <ret_ty> { body }`:
hir::ExprKind::Closure(self.arena.alloc(hir::Closure {
def_id: self.local_def_id(closure_node_id),
binder: hir::ClosureBinder::Default,
capture_clause,
bound_generic_params: &[],
fn_decl,
body,
fn_decl_span: self.lower_span(span),
fn_arg_span: None,
movability: Some(hir::Movability::Static),
constness: hir::Constness::NotConst,
}))
}

/// Forwards a possible `#[track_caller]` annotation from `outer_hir_id` to
/// `inner_hir_id` in case the `async_fn_track_caller` feature is enabled.
pub(super) fn maybe_forward_track_caller(
Expand Down Expand Up @@ -775,15 +862,18 @@ impl<'hir> LoweringContext<'_, 'hir> {
/// ```
fn lower_expr_await(&mut self, await_kw_span: Span, expr: &Expr) -> hir::ExprKind<'hir> {
let full_span = expr.span.to(await_kw_span);
match self.coroutine_kind {
Some(hir::CoroutineKind::Async(_)) => {}

let is_async_gen = match self.coroutine_kind {
Some(hir::CoroutineKind::Async(_)) => false,
Some(hir::CoroutineKind::AsyncGen(_)) => true,
Some(hir::CoroutineKind::Coroutine) | Some(hir::CoroutineKind::Gen(_)) | None => {
return hir::ExprKind::Err(self.tcx.sess.emit_err(AwaitOnlyInAsyncFnAndBlocks {
await_kw_span,
item_span: self.current_item,
}));
}
}
};

let span = self.mark_span_with_reason(DesugaringKind::Await, await_kw_span, None);
let gen_future_span = self.mark_span_with_reason(
DesugaringKind::Await,
Expand Down Expand Up @@ -872,12 +962,19 @@ impl<'hir> LoweringContext<'_, 'hir> {
self.stmt_expr(span, match_expr)
};

// task_context = yield ();
// Depending on `async` of `async gen`:
// async - task_context = yield ();
// async gen - task_context = yield async_gen_pending();
let yield_stmt = {
let unit = self.expr_unit(span);
let yielded = if is_async_gen {
self.expr_call_lang_item_fn(span, hir::LangItem::AsyncGenPending, &[])
} else {
self.expr_unit(span)
};

let yield_expr = self.expr(
span,
hir::ExprKind::Yield(unit, hir::YieldSource::Await { expr: Some(expr_hir_id) }),
hir::ExprKind::Yield(yielded, hir::YieldSource::Await { expr: Some(expr_hir_id) }),
);
let yield_expr = self.arena.alloc(yield_expr);

Expand Down Expand Up @@ -987,7 +1084,11 @@ impl<'hir> LoweringContext<'_, 'hir> {
}
Some(movability)
}
Some(hir::CoroutineKind::Gen(_)) | Some(hir::CoroutineKind::Async(_)) => {
Some(
hir::CoroutineKind::Gen(_)
| hir::CoroutineKind::Async(_)
| hir::CoroutineKind::AsyncGen(_),
) => {
panic!("non-`async`/`gen` closure body turned `async`/`gen` during lowering");
}
None => {
Expand Down Expand Up @@ -1494,8 +1595,9 @@ impl<'hir> LoweringContext<'_, 'hir> {
}

fn lower_expr_yield(&mut self, span: Span, opt_expr: Option<&Expr>) -> hir::ExprKind<'hir> {
match self.coroutine_kind {
Some(hir::CoroutineKind::Gen(_)) => {}
let is_async_gen = match self.coroutine_kind {
Some(hir::CoroutineKind::Gen(_)) => false,
Some(hir::CoroutineKind::AsyncGen(_)) => true,
Some(hir::CoroutineKind::Async(_)) => {
return hir::ExprKind::Err(
self.tcx.sess.emit_err(AsyncCoroutinesNotSupported { span }),
Expand All @@ -1511,14 +1613,24 @@ impl<'hir> LoweringContext<'_, 'hir> {
)
.emit();
}
self.coroutine_kind = Some(hir::CoroutineKind::Coroutine)
self.coroutine_kind = Some(hir::CoroutineKind::Coroutine);
false
}
}
};

let expr =
let mut yielded =
opt_expr.as_ref().map(|x| self.lower_expr(x)).unwrap_or_else(|| self.expr_unit(span));

hir::ExprKind::Yield(expr, hir::YieldSource::Yield)
if is_async_gen {
// yield async_gen_ready($expr);
yielded = self.expr_call_lang_item_fn(
span,
hir::LangItem::AsyncGenReady,
std::slice::from_ref(yielded),
);
}

hir::ExprKind::Yield(yielded, hir::YieldSource::Yield)
}

/// Desugar `ExprForLoop` from: `[opt_ident]: for <pat> in <head> <body>` into:
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2519,6 +2519,11 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
CoroutineSource::Closure => "gen closure",
_ => bug!("gen block/closure expected, but gen function found."),
},
CoroutineKind::AsyncGen(kind) => match kind {
CoroutineSource::Block => "async gen block",
CoroutineSource::Closure => "async gen closure",
_ => bug!("gen block/closure expected, but gen function found."),
},
CoroutineKind::Async(async_kind) => match async_kind {
CoroutineSource::Block => "async block",
CoroutineSource::Closure => "async closure",
Expand Down
15 changes: 15 additions & 0 deletions compiler/rustc_borrowck/src/diagnostics/region_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,21 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> {
" of gen function"
}
},

Some(hir::CoroutineKind::AsyncGen(gen)) => match gen {
hir::CoroutineSource::Block => " of async gen block",
hir::CoroutineSource::Closure => " of async gen closure",
hir::CoroutineSource::Fn => {
let parent_item =
hir.get_by_def_id(hir.get_parent_item(mir_hir_id).def_id);
let output = &parent_item
.fn_decl()
.expect("coroutine lowered from async gen fn should be in fn")
.output;
span = output.span();
" of async gen function"
}
},
Some(hir::CoroutineKind::Coroutine) => " of coroutine",
None => " of closure",
};
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,9 @@ fn coroutine_kind_label(coroutine_kind: Option<CoroutineKind>) -> &'static str {
Some(CoroutineKind::Async(CoroutineSource::Block)) => "async_block",
Some(CoroutineKind::Async(CoroutineSource::Closure)) => "async_closure",
Some(CoroutineKind::Async(CoroutineSource::Fn)) => "async_fn",
Some(CoroutineKind::AsyncGen(CoroutineSource::Block)) => "async_gen_block",
Some(CoroutineKind::AsyncGen(CoroutineSource::Closure)) => "async_gen_closure",
Some(CoroutineKind::AsyncGen(CoroutineSource::Fn)) => "async_gen_fn",
Some(CoroutineKind::Coroutine) => "coroutine",
None => "closure",
}
Expand Down
6 changes: 5 additions & 1 deletion compiler/rustc_codegen_ssa/src/mir/locals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let local = mir::Local::from_usize(local);
let expected_ty = self.monomorphize(self.mir.local_decls[local].ty);
if expected_ty != op.layout.ty {
warn!("Unexpected initial operand type. See the issues/114858");
warn!(
"Unexpected initial operand type: expected {expected_ty:?}, found {:?}.\
See <https://github.com/rust-lang/rust/issues/114858>.",
op.layout.ty
);
}
}
}
Expand Down
25 changes: 13 additions & 12 deletions compiler/rustc_hir/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1484,12 +1484,16 @@ impl<'hir> Body<'hir> {
/// The type of source expression that caused this coroutine to be created.
#[derive(Clone, PartialEq, Eq, Debug, Copy, Hash, HashStable_Generic, Encodable, Decodable)]
pub enum CoroutineKind {
/// An explicit `async` block or the body of an async function.
/// An explicit `async` block or the body of an `async` function.
Async(CoroutineSource),

/// An explicit `gen` block or the body of a `gen` function.
Gen(CoroutineSource),

/// An explicit `async gen` block or the body of an `async gen` function,
/// which is able to both `yield` and `.await`.
AsyncGen(CoroutineSource),

/// A coroutine literal created via a `yield` inside a closure.
Coroutine,
}
Expand All @@ -1514,6 +1518,14 @@ impl fmt::Display for CoroutineKind {
}
k.fmt(f)
}
CoroutineKind::AsyncGen(k) => {
if f.alternate() {
f.write_str("`async gen` ")?;
} else {
f.write_str("async gen ")?
}
k.fmt(f)
}
}
}
}
Expand Down Expand Up @@ -2209,17 +2221,6 @@ impl fmt::Display for YieldSource {
}
}

impl From<CoroutineKind> for YieldSource {
fn from(kind: CoroutineKind) -> Self {
match kind {
// Guess based on the kind of the current coroutine.
CoroutineKind::Coroutine => Self::Yield,
CoroutineKind::Async(_) => Self::Await { expr: None },
CoroutineKind::Gen(_) => Self::Yield,
}
}
}

// N.B., if you change this, you'll probably want to change the corresponding
// type structure in middle/ty.rs as well.
#[derive(Debug, Clone, Copy, HashStable_Generic)]
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_hir/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ language_item_table! {

Iterator, sym::iterator, iterator_trait, Target::Trait, GenericRequirement::Exact(0);
Future, sym::future_trait, future_trait, Target::Trait, GenericRequirement::Exact(0);
AsyncIterator, sym::async_iterator, async_iterator_trait, Target::Trait, GenericRequirement::Exact(0);
CoroutineState, sym::coroutine_state, coroutine_state, Target::Enum, GenericRequirement::None;
Coroutine, sym::coroutine, coroutine_trait, Target::Trait, GenericRequirement::Minimum(1);
Unpin, sym::unpin, unpin_trait, Target::Trait, GenericRequirement::None;
Expand Down Expand Up @@ -294,6 +295,10 @@ language_item_table! {
PollReady, sym::Ready, poll_ready_variant, Target::Variant, GenericRequirement::None;
PollPending, sym::Pending, poll_pending_variant, Target::Variant, GenericRequirement::None;

AsyncGenReady, sym::AsyncGenReady, async_gen_ready, Target::Method(MethodKind::Inherent), GenericRequirement::Exact(1);
AsyncGenPending, sym::AsyncGenPending, async_gen_pending, Target::Method(MethodKind::Inherent), GenericRequirement::Exact(1);
AsyncGenFinished, sym::AsyncGenFinished, async_gen_finished, Target::AssocConst, GenericRequirement::Exact(1);

// FIXME(swatinem): the following lang items are used for async lowering and
// should become obsolete eventually.
ResumeTy, sym::ResumeTy, resume_ty, Target::Struct, GenericRequirement::None;
Expand Down
4 changes: 3 additions & 1 deletion compiler/rustc_hir_typeck/src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ pub(super) fn check_fn<'a, 'tcx>(
&& can_be_coroutine.is_some()
{
let yield_ty = match kind {
hir::CoroutineKind::Gen(..) | hir::CoroutineKind::Coroutine => {
hir::CoroutineKind::Gen(..)
| hir::CoroutineKind::AsyncGen(..)
| hir::CoroutineKind::Coroutine => {
let yield_ty = fcx.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::TypeInference,
span,
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_middle/src/mir/terminator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,17 @@ impl<O> AssertKind<O> {
RemainderByZero(_) => "attempt to calculate the remainder with a divisor of zero",
ResumedAfterReturn(CoroutineKind::Coroutine) => "coroutine resumed after completion",
ResumedAfterReturn(CoroutineKind::Async(_)) => "`async fn` resumed after completion",
ResumedAfterReturn(CoroutineKind::AsyncGen(_)) => {
"`async gen fn` resumed after completion"
}
ResumedAfterReturn(CoroutineKind::Gen(_)) => {
"`gen fn` should just keep returning `None` after completion"
}
ResumedAfterPanic(CoroutineKind::Coroutine) => "coroutine resumed after panicking",
ResumedAfterPanic(CoroutineKind::Async(_)) => "`async fn` resumed after panicking",
ResumedAfterPanic(CoroutineKind::AsyncGen(_)) => {
"`async gen fn` resumed after panicking"
}
ResumedAfterPanic(CoroutineKind::Gen(_)) => {
"`gen fn` should just keep returning `None` after panicking"
}
Expand Down Expand Up @@ -245,13 +251,15 @@ impl<O> AssertKind<O> {
DivisionByZero(_) => middle_assert_divide_by_zero,
RemainderByZero(_) => middle_assert_remainder_by_zero,
ResumedAfterReturn(CoroutineKind::Async(_)) => middle_assert_async_resume_after_return,
ResumedAfterReturn(CoroutineKind::AsyncGen(_)) => todo!(),
ResumedAfterReturn(CoroutineKind::Gen(_)) => {
bug!("gen blocks can be resumed after they return and will keep returning `None`")
}
ResumedAfterReturn(CoroutineKind::Coroutine) => {
middle_assert_coroutine_resume_after_return
}
ResumedAfterPanic(CoroutineKind::Async(_)) => middle_assert_async_resume_after_panic,
ResumedAfterPanic(CoroutineKind::AsyncGen(_)) => todo!(),
ResumedAfterPanic(CoroutineKind::Gen(_)) => middle_assert_gen_resume_after_panic,
ResumedAfterPanic(CoroutineKind::Coroutine) => {
middle_assert_coroutine_resume_after_panic
Expand Down
Loading

0 comments on commit 654e0ba

Please sign in to comment.