Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Uplift some miscellaneous coroutine-specific machinery into check_closure #119417

Merged
merged 3 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 3 additions & 63 deletions compiler/rustc_hir_typeck/src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode};
pub(super) fn check_fn<'a, 'tcx>(
fcx: &mut FnCtxt<'a, 'tcx>,
fn_sig: ty::FnSig<'tcx>,
coroutine_types: Option<CoroutineTypes<'tcx>>,
decl: &'tcx hir::FnDecl<'tcx>,
fn_def_id: LocalDefId,
body: &'tcx hir::Body<'tcx>,
closure_kind: Option<hir::ClosureKind>,
params_can_be_unsized: bool,
) -> Option<CoroutineTypes<'tcx>> {
let fn_id = fcx.tcx.local_def_id_to_hir_id(fn_def_id);
Expand All @@ -49,54 +49,13 @@ pub(super) fn check_fn<'a, 'tcx>(
fcx.param_env,
));

fcx.coroutine_types = coroutine_types;
fcx.ret_coercion = Some(RefCell::new(CoerceMany::new(ret_ty)));

let span = body.value.span;

forbid_intrinsic_abi(tcx, span, fn_sig.abi);

if let Some(hir::ClosureKind::Coroutine(kind)) = closure_kind {
let yield_ty = match kind {
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _)
| hir::CoroutineKind::Coroutine(_) => {
let yield_ty = fcx.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::TypeInference,
span,
});
fcx.require_type_is_sized(yield_ty, span, traits::SizedYieldType);
yield_ty
}
// HACK(-Ztrait-solver=next): In the *old* trait solver, we must eagerly
// guide inference on the yield type so that we can handle `AsyncIterator`
// in this block in projection correctly. In the new trait solver, it is
// not a problem.
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _) => {
let yield_ty = fcx.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::TypeInference,
span,
});
fcx.require_type_is_sized(yield_ty, span, traits::SizedYieldType);

Ty::new_adt(
tcx,
tcx.adt_def(tcx.require_lang_item(hir::LangItem::Poll, Some(span))),
tcx.mk_args(&[Ty::new_adt(
tcx,
tcx.adt_def(tcx.require_lang_item(hir::LangItem::Option, Some(span))),
tcx.mk_args(&[yield_ty.into()]),
)
.into()]),
)
}
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _) => Ty::new_unit(tcx),
};

// Resume type defaults to `()` if the coroutine has no argument.
let resume_ty = fn_sig.inputs().get(0).copied().unwrap_or_else(|| Ty::new_unit(tcx));

fcx.resume_yield_tys = Some((resume_ty, yield_ty));
}

GatherLocalsVisitor::new(fcx).visit_body(body);

// C-variadic fns also have a `VaList` input that's not listed in `fn_sig`
Expand Down Expand Up @@ -146,25 +105,6 @@ pub(super) fn check_fn<'a, 'tcx>(
fcx.require_type_is_sized(declared_ret_ty, return_or_body_span, traits::SizedReturnType);
fcx.check_return_expr(body.value, false);

// We insert the deferred_coroutine_interiors entry after visiting the body.
// This ensures that all nested coroutines appear before the entry of this coroutine.
// resolve_coroutine_interiors relies on this property.
let coroutine_ty = if let Some(hir::ClosureKind::Coroutine(coroutine_kind)) = closure_kind {
let interior = fcx
.next_ty_var(TypeVariableOrigin { kind: TypeVariableOriginKind::MiscVariable, span });
fcx.deferred_coroutine_interiors.borrow_mut().push((
fn_def_id,
body.id(),
interior,
coroutine_kind,
));

let (resume_ty, yield_ty) = fcx.resume_yield_tys.unwrap();
Some(CoroutineTypes { resume_ty, yield_ty, interior })
} else {
None
};

// Finalize the return check by taking the LUB of the return types
// we saw and assigning it to the expected return type. This isn't
// really expected to fail, since the coercions would have failed
Expand Down Expand Up @@ -200,7 +140,7 @@ pub(super) fn check_fn<'a, 'tcx>(
check_lang_start_fn(tcx, fn_sig, fn_def_id);
}

coroutine_ty
fcx.coroutine_types
}

fn check_panic_info_fn(tcx: TyCtxt<'_>, fn_id: LocalDefId, fn_sig: ty::FnSig<'_>) {
Expand Down
187 changes: 129 additions & 58 deletions compiler/rustc_hir_typeck/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
opt_kind: Option<ty::ClosureKind>,
expected_sig: Option<ExpectedSig<'tcx>>,
) -> Ty<'tcx> {
let body = self.tcx.hir().body(closure.body);
let tcx = self.tcx;
let body = tcx.hir().body(closure.body);

trace!("decl = {:#?}", closure.fn_decl);
let expr_def_id = closure.def_id;
Expand All @@ -83,81 +84,151 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {

debug!(?bound_sig, ?liberated_sig);

// FIXME: We could probably actually just unify this further --
// instead of having a `FnSig` and a `Option<CoroutineTypes>`,
// we can have a `ClosureSignature { Coroutine { .. }, Closure { .. } }`,
// similar to how `ty::GenSig` is a distinct data structure.
let coroutine_types = match closure.kind {
hir::ClosureKind::Closure => None,
hir::ClosureKind::Coroutine(kind) => {
let yield_ty = match kind {
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _)
| hir::CoroutineKind::Coroutine(_) => {
let yield_ty = self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::TypeInference,
span: expr_span,
});
self.require_type_is_sized(yield_ty, expr_span, traits::SizedYieldType);
yield_ty
}
// HACK(-Ztrait-solver=next): In the *old* trait solver, we must eagerly
// guide inference on the yield type so that we can handle `AsyncIterator`
// in this block in projection correctly. In the new trait solver, it is
// not a problem.
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _) => {
let yield_ty = self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::TypeInference,
span: expr_span,
});
self.require_type_is_sized(yield_ty, expr_span, traits::SizedYieldType);

Ty::new_adt(
tcx,
tcx.adt_def(
tcx.require_lang_item(hir::LangItem::Poll, Some(expr_span)),
),
tcx.mk_args(&[Ty::new_adt(
tcx,
tcx.adt_def(
tcx.require_lang_item(hir::LangItem::Option, Some(expr_span)),
),
tcx.mk_args(&[yield_ty.into()]),
)
.into()]),
)
}
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _) => {
tcx.types.unit
}
};

// Resume type defaults to `()` if the coroutine has no argument.
let resume_ty = liberated_sig.inputs().get(0).copied().unwrap_or(tcx.types.unit);

Some(CoroutineTypes { resume_ty, yield_ty })
}
};

let mut fcx = FnCtxt::new(self, self.param_env, closure.def_id);
let coroutine_types = check_fn(
check_fn(
&mut fcx,
liberated_sig,
coroutine_types,
closure.fn_decl,
expr_def_id,
body,
Some(closure.kind),
// Closure "rust-call" ABI doesn't support unsized params
false,
);

let parent_args = GenericArgs::identity_for_item(
self.tcx,
self.tcx.typeck_root_def_id(expr_def_id.to_def_id()),
);
let parent_args =
GenericArgs::identity_for_item(tcx, tcx.typeck_root_def_id(expr_def_id.to_def_id()));

let tupled_upvars_ty = self.next_root_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::ClosureSynthetic,
span: self.tcx.def_span(expr_def_id),
});

if let Some(CoroutineTypes { resume_ty, yield_ty, interior }) = coroutine_types {
let coroutine_args = ty::CoroutineArgs::new(
self.tcx,
ty::CoroutineArgsParts {
parent_args,
resume_ty,
yield_ty,
return_ty: liberated_sig.output(),
witness: interior,
tupled_upvars_ty,
},
);

return Ty::new_coroutine(self.tcx, expr_def_id.to_def_id(), coroutine_args.args);
}

// Tuple up the arguments and insert the resulting function type into
// the `closures` table.
let sig = bound_sig.map_bound(|sig| {
self.tcx.mk_fn_sig(
[Ty::new_tup(self.tcx, sig.inputs())],
sig.output(),
sig.c_variadic,
sig.unsafety,
sig.abi,
)
span: expr_span,
});

debug!(?sig, ?opt_kind);

let closure_kind_ty = match opt_kind {
Some(kind) => Ty::from_closure_kind(self.tcx, kind),
match closure.kind {
hir::ClosureKind::Closure => {
assert_eq!(coroutine_types, None);
// Tuple up the arguments and insert the resulting function type into
// the `closures` table.
let sig = bound_sig.map_bound(|sig| {
tcx.mk_fn_sig(
[Ty::new_tup(tcx, sig.inputs())],
sig.output(),
sig.c_variadic,
sig.unsafety,
sig.abi,
)
});

// Create a type variable (for now) to represent the closure kind.
// It will be unified during the upvar inference phase (`upvar.rs`)
None => self.next_root_ty_var(TypeVariableOrigin {
// FIXME(eddyb) distinguish closure kind inference variables from the rest.
kind: TypeVariableOriginKind::ClosureSynthetic,
span: expr_span,
}),
};
debug!(?sig, ?opt_kind);

let closure_kind_ty = match opt_kind {
Some(kind) => Ty::from_closure_kind(tcx, kind),

// Create a type variable (for now) to represent the closure kind.
// It will be unified during the upvar inference phase (`upvar.rs`)
None => self.next_root_ty_var(TypeVariableOrigin {
// FIXME(eddyb) distinguish closure kind inference variables from the rest.
kind: TypeVariableOriginKind::ClosureSynthetic,
span: expr_span,
}),
};

let closure_args = ty::ClosureArgs::new(
tcx,
ty::ClosureArgsParts {
parent_args,
closure_kind_ty,
closure_sig_as_fn_ptr_ty: Ty::new_fn_ptr(tcx, sig),
tupled_upvars_ty,
},
);

let closure_args = ty::ClosureArgs::new(
self.tcx,
ty::ClosureArgsParts {
parent_args,
closure_kind_ty,
closure_sig_as_fn_ptr_ty: Ty::new_fn_ptr(self.tcx, sig),
tupled_upvars_ty,
},
);
Ty::new_closure(tcx, expr_def_id.to_def_id(), closure_args.args)
}
hir::ClosureKind::Coroutine(_) => {
let Some(CoroutineTypes { resume_ty, yield_ty }) = coroutine_types else {
bug!("expected coroutine to have yield/resume types");
};
let interior = fcx.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::MiscVariable,
span: body.value.span,
});
fcx.deferred_coroutine_interiors.borrow_mut().push((
expr_def_id,
body.id(),
interior,
));

let coroutine_args = ty::CoroutineArgs::new(
tcx,
ty::CoroutineArgsParts {
parent_args,
resume_ty,
yield_ty,
return_ty: liberated_sig.output(),
witness: interior,
tupled_upvars_ty,
},
);

Ty::new_closure(self.tcx, expr_def_id.to_def_id(), closure_args.args)
Ty::new_coroutine(tcx, expr_def_id.to_def_id(), coroutine_args.args)
}
}
}

/// Given the expected type, figures out what it can about this closure we
Expand Down
5 changes: 3 additions & 2 deletions compiler/rustc_hir_typeck/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::errors::{
use crate::fatally_break_rust;
use crate::method::SelfSource;
use crate::type_error_struct;
use crate::CoroutineTypes;
use crate::Expectation::{self, ExpectCastableToType, ExpectHasType, NoExpectation};
use crate::{
report_unexpected_variant_res, BreakableCtxt, Diverges, FnCtxt, Needs,
Expand Down Expand Up @@ -3164,8 +3165,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
expr: &'tcx hir::Expr<'tcx>,
src: &'tcx hir::YieldSource,
) -> Ty<'tcx> {
match self.resume_yield_tys {
Some((resume_ty, yield_ty)) => {
match self.coroutine_types {
Some(CoroutineTypes { resume_ty, yield_ty }) => {
self.check_expr_coercible_to_type(value, yield_ty, None);

resume_ty
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
let coroutines = std::mem::take(&mut *self.deferred_coroutine_interiors.borrow_mut());
debug!(?coroutines);

for &(expr_def_id, body_id, interior, _) in coroutines.iter() {
for &(expr_def_id, body_id, interior) in coroutines.iter() {
debug!(?expr_def_id);

// Create the `CoroutineWitness` type that we will unify with `interior`.
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ mod checks;
mod suggestions;

use crate::coercion::DynamicCoerceMany;
use crate::{Diverges, EnclosingBreakables, Inherited};
use crate::{CoroutineTypes, Diverges, EnclosingBreakables, Inherited};
use rustc_errors::{DiagCtxt, ErrorGuaranteed};
use rustc_hir as hir;
use rustc_hir::def_id::{DefId, LocalDefId};
Expand Down Expand Up @@ -68,7 +68,7 @@ pub struct FnCtxt<'a, 'tcx> {
/// First span of a return site that we find. Used in error messages.
pub(super) ret_coercion_span: Cell<Option<Span>>,

pub(super) resume_yield_tys: Option<(Ty<'tcx>, Ty<'tcx>)>,
pub(super) coroutine_types: Option<CoroutineTypes<'tcx>>,

/// Whether the last checked node generates a divergence (e.g.,
/// `return` will set this to `Always`). In general, when entering
Expand Down Expand Up @@ -122,7 +122,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
err_count_on_creation: inh.tcx.dcx().err_count(),
ret_coercion: None,
ret_coercion_span: Cell::new(None),
resume_yield_tys: None,
coroutine_types: None,
diverges: Cell::new(Diverges::Maybe),
enclosing_breakables: RefCell::new(EnclosingBreakables {
stack: Vec::new(),
Expand Down
3 changes: 1 addition & 2 deletions compiler/rustc_hir_typeck/src/inherited.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ pub struct Inherited<'tcx> {

pub(super) deferred_asm_checks: RefCell<Vec<(&'tcx hir::InlineAsm<'tcx>, hir::HirId)>>,

pub(super) deferred_coroutine_interiors:
RefCell<Vec<(LocalDefId, hir::BodyId, Ty<'tcx>, hir::CoroutineKind)>>,
pub(super) deferred_coroutine_interiors: RefCell<Vec<(LocalDefId, hir::BodyId, Ty<'tcx>)>>,

/// Whenever we introduce an adjustment from `!` into a type variable,
/// we record that type variable here. This is later used to inform
Expand Down
Loading
Loading