Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -485,5 +485,7 @@ def i[T: (int, str)](x: T) -> T:
case _:
assert_never(x)

# TODO: no error here
# error: [invalid-return-type] "Return type does not match returned value: expected `T@i`, found `str | int`"
return x
```
59 changes: 22 additions & 37 deletions crates/ty_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ use crate::semantic_index::expression::{Expression, ExpressionKind};
use crate::semantic_index::member::MemberExprBuilder;
use crate::semantic_index::place::{PlaceExpr, PlaceTableBuilder, ScopedPlaceId};
use crate::semantic_index::predicate::{
CallableAndCallExpr, ClassPatternKind, PatternPredicate, PatternPredicateKind, Predicate,
PredicateNode, PredicateOrLiteral, ScopedPredicateId, StarImportPlaceholderPredicate,
ClassPatternKind, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode,
PredicateOrLiteral, ScopedPredicateId, StarImportPlaceholderPredicate,
};
use crate::semantic_index::re_exports::exported_names;
use crate::semantic_index::reachability_constraints::{
Expand Down Expand Up @@ -2784,44 +2784,29 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
// We also only add these inside function scopes, since considering module-level
// constraints can affect the type of imported symbols, leading to a lot more
// work in third-party code.
let call_info = match value.as_ref() {
ast::Expr::Call(ast::ExprCall { func, .. }) => {
Some((func.as_ref(), value.as_ref(), false))
}
ast::Expr::Await(ast::ExprAwait { value: inner, .. }) => match inner.as_ref() {
ast::Expr::Call(ast::ExprCall { func, .. }) => {
Some((func.as_ref(), value.as_ref(), true))
}
_ => None,
},
_ => None,
let is_call = match value.as_ref() {
ast::Expr::Call(_) => true,
ast::Expr::Await(ast::ExprAwait { value: inner, .. }) => inner.is_call_expr(),
_ => false,
};

if let Some((func, expr, is_await)) = call_info {
if !self.source_type.is_stub() && self.in_function_scope() {
let callable = self.add_standalone_expression(func);
let call_expr = self.add_standalone_expression(expr);

let predicate = Predicate {
node: PredicateNode::ReturnsNever(CallableAndCallExpr {
callable,
call_expr,
is_await,
}),
is_positive: false,
};
let constraint = self.record_reachability_constraint(
PredicateOrLiteral::Predicate(predicate),
);
if is_call && !self.source_type.is_stub() && self.in_function_scope() {
let call_expr = self.add_standalone_expression(value.as_ref());

// Also gate narrowing by this constraint: if the call returns
// `Never`, any narrowing in the current branch should be
// invalidated (since this path is unreachable). This enables
// narrowing to be preserved after if-statements where one branch
// calls a `NoReturn` function like `sys.exit()`.
self.current_use_def_map_mut()
.record_narrowing_constraint_for_all_places(constraint);
}
let predicate = Predicate {
node: PredicateNode::ReturnsNever(call_expr),
is_positive: false,
};
let constraint = self
.record_reachability_constraint(PredicateOrLiteral::Predicate(predicate));

// Also gate narrowing by this constraint: if the call returns
// `Never`, any narrowing in the current branch should be
// invalidated (since this path is unreachable). This enables
// narrowing to be preserved after if-statements where one branch
// calls a `NoReturn` function like `sys.exit()`.
self.current_use_def_map_mut()
.record_narrowing_constraint_for_all_places(constraint);
}
}
_ => {
Expand Down
12 changes: 1 addition & 11 deletions crates/ty_python_semantic/src/semantic_index/predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,20 +98,10 @@ impl PredicateOrLiteral<'_> {
}
}

#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
pub(crate) struct CallableAndCallExpr<'db> {
pub(crate) callable: Expression<'db>,
pub(crate) call_expr: Expression<'db>,
/// Whether the call is wrapped in an `await` expression. If `true`, `call_expr` refers to the
/// `await` expression rather than the call itself. This is used to detect terminal `await`s of
/// async functions that return `Never`.
pub(crate) is_await: bool,
}

#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
pub(crate) enum PredicateNode<'db> {
Expression(Expression<'db>),
ReturnsNever(CallableAndCallExpr<'db>),
ReturnsNever(Expression<'db>),
Pattern(PatternPredicate<'db>),
StarImportPlaceholder(StarImportPlaceholderPredicate<'db>),
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,11 @@ use crate::rank::RankBitBox;
use crate::semantic_index::place::ScopedPlaceId;
use crate::semantic_index::place_table;
use crate::semantic_index::predicate::{
CallableAndCallExpr, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode,
Predicates, ScopedPredicateId,
PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, Predicates, ScopedPredicateId,
};
use crate::types::{
CallableTypes, IntersectionBuilder, KnownClass, NarrowingConstraint, Truthiness, Type,
TypeContext, UnionBuilder, UnionType, infer_expression_type, infer_narrowing_constraint,
IntersectionBuilder, KnownClass, NarrowingConstraint, Truthiness, Type, TypeContext,
UnionBuilder, UnionType, infer_expression_type, infer_narrowing_constraint,
};

/// A ternary formula that defines under what conditions a binding is visible. (A ternary formula
Expand Down Expand Up @@ -1090,62 +1089,12 @@ impl ReachabilityConstraints {
.bool(db)
.negate_if(!predicate.is_positive)
}
PredicateNode::ReturnsNever(CallableAndCallExpr {
callable,
call_expr,
is_await,
}) => {
// We first infer just the type of the callable. In the most likely case that the
// function is not marked with `NoReturn`, or that it always returns `NoReturn`,
// doing so allows us to avoid the more expensive work of inferring the entire call
// expression (which could involve inferring argument types to possibly run the overload
// selection algorithm).
// Avoiding this on the happy-path is important because these constraints can be
// very large in number, since we add them on all statement level function calls.
let ty = infer_expression_type(db, callable, TypeContext::default());

// Short-circuit for well known types that are known not to return `Never` when called.
// Without the short-circuit, we've seen that threads keep blocking each other
// because they all try to acquire Salsa's `CallableType` lock that ensures each type
// is only interned once. The lock is so heavily congested because there are only
// very few dynamic types, in which case Salsa's sharding the locks by value
// doesn't help much.
// See <https://github.com/astral-sh/ty/issues/968>.
if matches!(ty, Type::Dynamic(_)) {
return Truthiness::AlwaysFalse.negate_if(!predicate.is_positive);
}

let overloads_iterator = if let Some(callable) = ty
.try_upcast_to_callable(db)
.and_then(CallableTypes::exactly_one)
{
callable.signatures(db).overloads.iter()
} else {
return Truthiness::AlwaysFalse.negate_if(!predicate.is_positive);
};

let mut no_overloads_return_never = true;
let mut all_overloads_return_never = true;
let mut any_overload_is_generic = false;

for overload in overloads_iterator {
let returns_never = overload.return_ty.is_equivalent_to(db, Type::Never);
no_overloads_return_never &= !returns_never;
all_overloads_return_never &= returns_never;
any_overload_is_generic |= overload.return_ty.has_typevar(db);
}

if no_overloads_return_never && !any_overload_is_generic && !is_await {
Truthiness::AlwaysFalse
} else if all_overloads_return_never {
PredicateNode::ReturnsNever(call_expr) => {
let call_expr_ty = infer_expression_type(db, call_expr, TypeContext::default());
if call_expr_ty.is_equivalent_to(db, Type::Never) {
Truthiness::AlwaysTrue
} else {
let call_expr_ty = infer_expression_type(db, call_expr, TypeContext::default());
if call_expr_ty.is_equivalent_to(db, Type::Never) {
Truthiness::AlwaysTrue
} else {
Truthiness::AlwaysFalse
}
Truthiness::AlwaysFalse
}
.negate_if(!predicate.is_positive)
}
Expand Down
7 changes: 2 additions & 5 deletions crates/ty_python_semantic/src/types/narrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ use crate::semantic_index::expression::Expression;
use crate::semantic_index::place::{PlaceExpr, PlaceTable, PlaceTableBuilder, ScopedPlaceId};
use crate::semantic_index::place_table;
use crate::semantic_index::predicate::{
CallableAndCallExpr, ClassPatternKind, PatternPredicate, PatternPredicateKind, Predicate,
PredicateNode,
ClassPatternKind, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode,
};
use crate::semantic_index::scope::ScopeId;
use crate::subscript::PyIndex;
Expand Down Expand Up @@ -762,9 +761,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
match self.predicate {
PredicateNode::Expression(expression) => expression.scope(self.db),
PredicateNode::Pattern(pattern) => pattern.scope(self.db),
PredicateNode::ReturnsNever(CallableAndCallExpr { callable, .. }) => {
callable.scope(self.db)
}
PredicateNode::ReturnsNever(call_expr) => call_expr.scope(self.db),
PredicateNode::StarImportPlaceholder(definition) => definition.scope(self.db),
}
}
Expand Down