diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md b/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md index 76d96d746baf1..eb93a96f1cde9 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md @@ -242,3 +242,17 @@ def bar(): # v was reassigned, so any narrowing shouldn't apply reveal_type(v) # revealed: int | None ``` + +## Narrowing preserved when `await`ing a `NoReturn` function in one branch + +```py +from typing import NoReturn + +async def stop() -> NoReturn: + raise NotImplementedError + +async def main(val: int | None): + if val is None: + await stop() + reveal_type(val) # revealed: int +``` diff --git a/crates/ty_python_semantic/resources/mdtest/terminal_statements.md b/crates/ty_python_semantic/resources/mdtest/terminal_statements.md index 762a77746ee1f..ff883ea576f0a 100644 --- a/crates/ty_python_semantic/resources/mdtest/terminal_statements.md +++ b/crates/ty_python_semantic/resources/mdtest/terminal_statements.md @@ -808,6 +808,28 @@ def _() -> NoReturn: C().die() ``` +### Awaiting async `NoReturn` functions + +Awaiting an async function annotated as returning `NoReturn` should be treated as terminal, just +like calling a synchronous `NoReturn` function. + +```py +from typing import NoReturn + +async def stop() -> NoReturn: + raise NotImplementedError + +async def main(flag: bool): + if flag: + x = "terminal" + await stop() + else: + x = "test" + pass + + reveal_type(x) # revealed: Literal["test"] +``` + ## Nested functions Free references inside of a function body refer to variables defined in the containing scope. diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index 76cb3c3df1ef3..8cb1948a85a0b 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -2708,8 +2708,9 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { self.visit_expr(value); - // If the statement is a call, it could possibly be a call to a function - // marked with `NoReturn` (for example, `sys.exit()`). In this case, we use a special + // If the statement is a call (or an `await` wrapping a call), it could + // possibly be a call to a function marked with `NoReturn` (for example, + // `sys.exit()` or `await async_exit()`). In this case, we use a special // kind of constraint to mark the following code as unreachable. // // Ideally, these constraints should be added for every call expression, even those in @@ -2721,15 +2722,29 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { // We also only add these inside function scopes, since considering module-level // constraints can affect the type of imported symbols, leading to a lot more // work in third-party code. - if let ast::Expr::Call(ast::ExprCall { func, .. }) = value.as_ref() { + let call_info = match value.as_ref() { + ast::Expr::Call(ast::ExprCall { func, .. }) => { + Some((func.as_ref(), value.as_ref(), false)) + } + ast::Expr::Await(ast::ExprAwait { value: inner, .. }) => match inner.as_ref() { + ast::Expr::Call(ast::ExprCall { func, .. }) => { + Some((func.as_ref(), value.as_ref(), true)) + } + _ => None, + }, + _ => None, + }; + + if let Some((func, expr, is_await)) = call_info { if !self.source_type.is_stub() && self.in_function_scope() { let callable = self.add_standalone_expression(func); - let call_expr = self.add_standalone_expression(value.as_ref()); + let call_expr = self.add_standalone_expression(expr); let predicate = Predicate { node: PredicateNode::ReturnsNever(CallableAndCallExpr { callable, call_expr, + is_await, }), is_positive: false, }; diff --git a/crates/ty_python_semantic/src/semantic_index/predicate.rs b/crates/ty_python_semantic/src/semantic_index/predicate.rs index cb0519e6ca674..528e8143385a3 100644 --- a/crates/ty_python_semantic/src/semantic_index/predicate.rs +++ b/crates/ty_python_semantic/src/semantic_index/predicate.rs @@ -102,6 +102,10 @@ impl PredicateOrLiteral<'_> { pub(crate) struct CallableAndCallExpr<'db> { pub(crate) callable: Expression<'db>, pub(crate) call_expr: Expression<'db>, + /// Whether the call is wrapped in an `await` expression. If `true`, `call_expr` refers to the + /// `await` expression rather than the call itself. This is used to detect terminal `await`s of + /// async functions that return `Never`. + pub(crate) is_await: bool, } #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)] diff --git a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs index 647bb088d1582..20baeaa42df7d 100644 --- a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs +++ b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs @@ -1070,6 +1070,7 @@ impl ReachabilityConstraints { PredicateNode::ReturnsNever(CallableAndCallExpr { callable, call_expr, + is_await, }) => { // We first infer just the type of the callable. In the most likely case that the // function is not marked with `NoReturn`, or that it always returns `NoReturn`, @@ -1111,7 +1112,7 @@ impl ReachabilityConstraints { any_overload_is_generic |= overload.return_ty.has_typevar(db); } - if no_overloads_return_never && !any_overload_is_generic { + if no_overloads_return_never && !any_overload_is_generic && !is_await { Truthiness::AlwaysFalse } else if all_overloads_return_never { Truthiness::AlwaysTrue