Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inline check_closure, simplify deduce_sig_from_projection #119900

Merged
merged 2 commits into from
Jan 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 15 additions & 52 deletions compiler/rustc_hir_typeck/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use rustc_middle::ty::visit::{TypeVisitable, TypeVisitableExt};
use rustc_middle::ty::GenericArgs;
use rustc_middle::ty::{self, Ty, TyCtxt, TypeSuperVisitable, TypeVisitor};
use rustc_span::def_id::LocalDefId;
use rustc_span::{sym, Span};
use rustc_span::Span;
use rustc_target::spec::abi::Abi;
use rustc_trait_selection::traits;
use rustc_trait_selection::traits::error_reporting::ArgKind;
Expand Down Expand Up @@ -49,7 +49,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
expr_span: Span,
expected: Expectation<'tcx>,
) -> Ty<'tcx> {
trace!("decl = {:#?}", closure.fn_decl);
let tcx = self.tcx;
let body = tcx.hir().body(closure.body);
let expr_def_id = closure.def_id;

// It's always helpful for inference if we know the kind of
// closure sooner rather than later, so first examine the expected
Expand All @@ -61,24 +63,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
None => (None, None),
};

self.check_closure(closure, expr_span, expected_kind, expected_sig)
}

#[instrument(skip(self, closure), level = "debug", ret)]
fn check_closure(
&self,
closure: &hir::Closure<'tcx>,
expr_span: Span,
opt_kind: Option<ty::ClosureKind>,
expected_sig: Option<ExpectedSig<'tcx>>,
) -> Ty<'tcx> {
let tcx = self.tcx;
let body = tcx.hir().body(closure.body);

trace!("decl = {:#?}", closure.fn_decl);
let expr_def_id = closure.def_id;
debug!(?expr_def_id);

let ClosureSignatures { bound_sig, liberated_sig } =
self.sig_of_closure(expr_def_id, closure.fn_decl, closure.kind, expected_sig);

Expand Down Expand Up @@ -139,9 +123,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
}
};

let mut fcx = FnCtxt::new(self, self.param_env, closure.def_id);
check_fn(
&mut fcx,
&mut FnCtxt::new(self, self.param_env, closure.def_id),
liberated_sig,
coroutine_types,
closure.fn_decl,
Expand Down Expand Up @@ -174,9 +157,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
)
});

debug!(?sig, ?opt_kind);
debug!(?sig, ?expected_kind);

let closure_kind_ty = match opt_kind {
let closure_kind_ty = match expected_kind {
Some(kind) => Ty::from_closure_kind(tcx, kind),

// Create a type variable (for now) to represent the closure kind.
Expand Down Expand Up @@ -204,11 +187,11 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
let Some(CoroutineTypes { resume_ty, yield_ty }) = coroutine_types else {
bug!("expected coroutine to have yield/resume types");
};
let interior = fcx.next_ty_var(TypeVariableOrigin {
let interior = self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::MiscVariable,
span: body.value.span,
});
fcx.deferred_coroutine_interiors.borrow_mut().push((
self.deferred_coroutine_interiors.borrow_mut().push((
expr_def_id,
body.id(),
interior,
Expand Down Expand Up @@ -364,36 +347,16 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
let tcx = self.tcx;

let trait_def_id = projection.trait_def_id(tcx);

let is_fn = tcx.is_fn_trait(trait_def_id);

let coroutine_trait = tcx.lang_items().coroutine_trait();
let is_gen = coroutine_trait == Some(trait_def_id);

if !is_fn && !is_gen {
debug!("not fn or coroutine");
// For now, we only do signature deduction based off of the `Fn` traits.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't expect us to extend this to deduce other coroutines' signatures: #109338

Or at least, when we do start doing so, this method should be reworked to be cleaner and more extensible. Right now there's a lot of funky control flow for absolutely no reason 😅

if !tcx.is_fn_trait(trait_def_id) {
return None;
}

// Check that we deduce the signature from the `<_ as std::ops::Coroutine>::Return`
// associated item and not yield.
if is_gen && self.tcx.associated_item(projection.projection_def_id()).name != sym::Return {
debug!("not `Return` assoc item of `Coroutine`");
return None;
}

let input_tys = if is_fn {
let arg_param_ty = projection.skip_binder().projection_ty.args.type_at(1);
let arg_param_ty = self.resolve_vars_if_possible(arg_param_ty);
debug!(?arg_param_ty);
let arg_param_ty = projection.skip_binder().projection_ty.args.type_at(1);
let arg_param_ty = self.resolve_vars_if_possible(arg_param_ty);
debug!(?arg_param_ty);

match arg_param_ty.kind() {
&ty::Tuple(tys) => tys,
_ => return None,
}
} else {
// Coroutines with a `()` resume type may be defined with 0 or 1 explicit arguments,
// else they must have exactly 1 argument. For now though, just give up in this case.
let ty::Tuple(input_tys) = *arg_param_ty.kind() else {
return None;
};

Expand Down
Loading