diff --git a/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md b/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md index 8d6f69215f14d..0b95194737ccc 100644 --- a/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md +++ b/crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md @@ -485,7 +485,5 @@ def i[T: (int, str)](x: T) -> T: case _: assert_never(x) - # TODO: no error here - # error: [invalid-return-type] "Return type does not match returned value: expected `T@i`, found `str | int`" return x ``` diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index 02e6047f04712..ea24d8594a6f0 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -35,8 +35,8 @@ use crate::semantic_index::expression::{Expression, ExpressionKind}; use crate::semantic_index::member::MemberExprBuilder; use crate::semantic_index::place::{PlaceExpr, PlaceTableBuilder, ScopedPlaceId}; use crate::semantic_index::predicate::{ - ClassPatternKind, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, - PredicateOrLiteral, ScopedPredicateId, StarImportPlaceholderPredicate, + CallableAndCallExpr, ClassPatternKind, PatternPredicate, PatternPredicateKind, Predicate, + PredicateNode, PredicateOrLiteral, ScopedPredicateId, StarImportPlaceholderPredicate, }; use crate::semantic_index::re_exports::exported_names; use crate::semantic_index::reachability_constraints::{ @@ -2784,29 +2784,44 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { // We also only add these inside function scopes, since considering module-level // constraints can affect the type of imported symbols, leading to a lot more // work in third-party code. - let is_call = match value.as_ref() { - ast::Expr::Call(_) => true, - ast::Expr::Await(ast::ExprAwait { value: inner, .. }) => inner.is_call_expr(), - _ => false, + let call_info = match value.as_ref() { + ast::Expr::Call(ast::ExprCall { func, .. }) => { + Some((func.as_ref(), value.as_ref(), false)) + } + ast::Expr::Await(ast::ExprAwait { value: inner, .. }) => match inner.as_ref() { + ast::Expr::Call(ast::ExprCall { func, .. }) => { + Some((func.as_ref(), value.as_ref(), true)) + } + _ => None, + }, + _ => None, }; - if is_call && !self.source_type.is_stub() && self.in_function_scope() { - let call_expr = self.add_standalone_expression(value.as_ref()); + if let Some((func, expr, is_await)) = call_info { + if !self.source_type.is_stub() && self.in_function_scope() { + let callable = self.add_standalone_expression(func); + let call_expr = self.add_standalone_expression(expr); + + let predicate = Predicate { + node: PredicateNode::ReturnsNever(CallableAndCallExpr { + callable, + call_expr, + is_await, + }), + is_positive: false, + }; + let constraint = self.record_reachability_constraint( + PredicateOrLiteral::Predicate(predicate), + ); - let predicate = Predicate { - node: PredicateNode::ReturnsNever(call_expr), - is_positive: false, - }; - let constraint = self - .record_reachability_constraint(PredicateOrLiteral::Predicate(predicate)); - - // Also gate narrowing by this constraint: if the call returns - // `Never`, any narrowing in the current branch should be - // invalidated (since this path is unreachable). This enables - // narrowing to be preserved after if-statements where one branch - // calls a `NoReturn` function like `sys.exit()`. - self.current_use_def_map_mut() - .record_narrowing_constraint_for_all_places(constraint); + // Also gate narrowing by this constraint: if the call returns + // `Never`, any narrowing in the current branch should be + // invalidated (since this path is unreachable). This enables + // narrowing to be preserved after if-statements where one branch + // calls a `NoReturn` function like `sys.exit()`. + self.current_use_def_map_mut() + .record_narrowing_constraint_for_all_places(constraint); + } } } _ => { diff --git a/crates/ty_python_semantic/src/semantic_index/predicate.rs b/crates/ty_python_semantic/src/semantic_index/predicate.rs index 3aa5374b2721a..39a2b2fbeaa9c 100644 --- a/crates/ty_python_semantic/src/semantic_index/predicate.rs +++ b/crates/ty_python_semantic/src/semantic_index/predicate.rs @@ -98,10 +98,20 @@ impl PredicateOrLiteral<'_> { } } +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)] +pub(crate) struct CallableAndCallExpr<'db> { + pub(crate) callable: Expression<'db>, + pub(crate) call_expr: Expression<'db>, + /// Whether the call is wrapped in an `await` expression. If `true`, `call_expr` refers to the + /// `await` expression rather than the call itself. This is used to detect terminal `await`s of + /// async functions that return `Never`. + pub(crate) is_await: bool, +} + #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)] pub(crate) enum PredicateNode<'db> { Expression(Expression<'db>), - ReturnsNever(Expression<'db>), + ReturnsNever(CallableAndCallExpr<'db>), Pattern(PatternPredicate<'db>), StarImportPlaceholder(StarImportPlaceholderPredicate<'db>), } diff --git a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs index bc03dce8b0228..a678ec1930efc 100644 --- a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs +++ b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs @@ -205,11 +205,12 @@ use crate::rank::RankBitBox; use crate::semantic_index::place::ScopedPlaceId; use crate::semantic_index::place_table; use crate::semantic_index::predicate::{ - PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, Predicates, ScopedPredicateId, + CallableAndCallExpr, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, + Predicates, ScopedPredicateId, }; use crate::types::{ - IntersectionBuilder, KnownClass, NarrowingConstraint, Truthiness, Type, TypeContext, - UnionBuilder, UnionType, infer_expression_type, infer_narrowing_constraint, + CallableTypes, IntersectionBuilder, KnownClass, NarrowingConstraint, Truthiness, Type, + TypeContext, UnionBuilder, UnionType, infer_expression_type, infer_narrowing_constraint, }; /// A ternary formula that defines under what conditions a binding is visible. (A ternary formula @@ -1089,12 +1090,62 @@ impl ReachabilityConstraints { .bool(db) .negate_if(!predicate.is_positive) } - PredicateNode::ReturnsNever(call_expr) => { - let call_expr_ty = infer_expression_type(db, call_expr, TypeContext::default()); - if call_expr_ty.is_equivalent_to(db, Type::Never) { - Truthiness::AlwaysTrue + PredicateNode::ReturnsNever(CallableAndCallExpr { + callable, + call_expr, + is_await, + }) => { + // We first infer just the type of the callable. In the most likely case that the + // function is not marked with `NoReturn`, or that it always returns `NoReturn`, + // doing so allows us to avoid the more expensive work of inferring the entire call + // expression (which could involve inferring argument types to possibly run the overload + // selection algorithm). + // Avoiding this on the happy-path is important because these constraints can be + // very large in number, since we add them on all statement level function calls. + let ty = infer_expression_type(db, callable, TypeContext::default()); + + // Short-circuit for well known types that are known not to return `Never` when called. + // Without the short-circuit, we've seen that threads keep blocking each other + // because they all try to acquire Salsa's `CallableType` lock that ensures each type + // is only interned once. The lock is so heavily congested because there are only + // very few dynamic types, in which case Salsa's sharding the locks by value + // doesn't help much. + // See . + if matches!(ty, Type::Dynamic(_)) { + return Truthiness::AlwaysFalse.negate_if(!predicate.is_positive); + } + + let overloads_iterator = if let Some(callable) = ty + .try_upcast_to_callable(db) + .and_then(CallableTypes::exactly_one) + { + callable.signatures(db).overloads.iter() } else { + return Truthiness::AlwaysFalse.negate_if(!predicate.is_positive); + }; + + let mut no_overloads_return_never = true; + let mut all_overloads_return_never = true; + let mut any_overload_is_generic = false; + + for overload in overloads_iterator { + let returns_never = overload.return_ty.is_equivalent_to(db, Type::Never); + no_overloads_return_never &= !returns_never; + all_overloads_return_never &= returns_never; + any_overload_is_generic |= overload.return_ty.has_typevar(db); + } + + if no_overloads_return_never && !any_overload_is_generic && !is_await { Truthiness::AlwaysFalse + } else if all_overloads_return_never { + Truthiness::AlwaysTrue + } else { + let call_expr_ty = infer_expression_type(db, call_expr, TypeContext::default()); + if call_expr_ty.is_equivalent_to(db, Type::Never) { + Truthiness::AlwaysTrue + } else { + Truthiness::AlwaysFalse + } } .negate_if(!predicate.is_positive) } diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index f621b3b59bcb6..ea783b83363da 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -3,7 +3,8 @@ use crate::semantic_index::expression::Expression; use crate::semantic_index::place::{PlaceExpr, PlaceTable, PlaceTableBuilder, ScopedPlaceId}; use crate::semantic_index::place_table; use crate::semantic_index::predicate::{ - ClassPatternKind, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, + CallableAndCallExpr, ClassPatternKind, PatternPredicate, PatternPredicateKind, Predicate, + PredicateNode, }; use crate::semantic_index::scope::ScopeId; use crate::subscript::PyIndex; @@ -761,7 +762,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { match self.predicate { PredicateNode::Expression(expression) => expression.scope(self.db), PredicateNode::Pattern(pattern) => pattern.scope(self.db), - PredicateNode::ReturnsNever(call_expr) => call_expr.scope(self.db), + PredicateNode::ReturnsNever(CallableAndCallExpr { callable, .. }) => { + callable.scope(self.db) + } PredicateNode::StarImportPlaceholder(definition) => definition.scope(self.db), } }