diff --git a/crates/ruff_benchmark/benches/ty.rs b/crates/ruff_benchmark/benches/ty.rs index 9ae6e9c40bd6a..57f6eab5a8dad 100644 --- a/crates/ruff_benchmark/benches/ty.rs +++ b/crates/ruff_benchmark/benches/ty.rs @@ -667,7 +667,7 @@ fn attrs(criterion: &mut Criterion) { max_dep_date: "2025-06-17", python_version: PythonVersion::PY313, }, - 120, + 136, ); bench_project(&benchmark, criterion); diff --git a/crates/ruff_benchmark/benches/ty_walltime.rs b/crates/ruff_benchmark/benches/ty_walltime.rs index d6391b94659cb..20d8856c31cfa 100644 --- a/crates/ruff_benchmark/benches/ty_walltime.rs +++ b/crates/ruff_benchmark/benches/ty_walltime.rs @@ -143,7 +143,7 @@ static FREQTRADE: Benchmark = Benchmark::new( max_dep_date: "2025-06-17", python_version: PythonVersion::PY312, }, - 600, + 626, ); static PANDAS: Benchmark = Benchmark::new( @@ -163,7 +163,7 @@ static PANDAS: Benchmark = Benchmark::new( max_dep_date: "2025-06-17", python_version: PythonVersion::PY312, }, - 4000, + 5430, ); static PYDANTIC: Benchmark = Benchmark::new( @@ -194,7 +194,8 @@ static SYMPY: Benchmark = Benchmark::new( max_dep_date: "2025-06-17", python_version: PythonVersion::PY312, }, - 13116, + // TODO: With better decorator support, `__slots__` support, etc., it should be possible to reduce the number of errors considerably. + 58000, ); static TANJUN: Benchmark = Benchmark::new( diff --git a/crates/ty_python_semantic/resources/corpus/cycle_into_callable.py b/crates/ty_python_semantic/resources/corpus/cycle_into_callable.py new file mode 100644 index 0000000000000..ce4cd6a795d02 --- /dev/null +++ b/crates/ty_python_semantic/resources/corpus/cycle_into_callable.py @@ -0,0 +1,17 @@ +# Regression test for https://github.com/astral-sh/ruff/issues/17371 +# panicked in commit d1088545a08aeb57b67ec1e3a7f5141159efefa5 +# error message: +# dependency graph cycle when querying ClassType < 'db >::into_callable_(Id(1c00)) + +try: + class foo[T: bar](object): + pass + bar = foo +except Exception: + bar = lambda: 0 +def bar(): + pass + +@bar() +class bar: + pass diff --git a/crates/ty_python_semantic/resources/corpus/divergent.py b/crates/ty_python_semantic/resources/corpus/divergent.py new file mode 100644 index 0000000000000..2bacad7f72b9e --- /dev/null +++ b/crates/ty_python_semantic/resources/corpus/divergent.py @@ -0,0 +1,93 @@ +def f(cond: bool): + if cond: + result = () + result += (f(cond),) + return result + + return None + +reveal_type(f(True)) + +def f(cond: bool): + if cond: + result = () + result += (f(cond),) + return result + + return None + +def f(cond: bool): + result = None + if cond: + result = () + result += (f(cond),) + + return result + +reveal_type(f(True)) + +def f(cond: bool): + result = None + if cond: + result = [f(cond) for _ in range(1)] + + return result + +reveal_type(f(True)) + +class Foo: + def value(self): + return 1 + +def unwrap(value): + if isinstance(value, Foo): + foo = value + return foo.value() + elif type(value) is tuple: + length = len(value) + if length == 0: + return () + elif length == 1: + return (unwrap(value[0]),) + else: + result = [] + for item in value: + result.append(unwrap(item)) + return tuple(result) + else: + raise TypeError() + +def descent(x: int, y: int): + if x > y: + y, x = descent(y, x) + return x, y + if x == 1: + return (1, 0) + if y == 1: + return (0, 1) + else: + return descent(x-1, y-1) + +def count_set_bits(n): + return 1 + count_set_bits(n & n - 1) if n else 0 + +class Literal: + def __invert__(self): + return Literal() + +class OR: + def __invert__(self): + return AND() + +class AND: + def __invert__(self): + return OR() + +def to_NNF(cond): + if cond: + return ~to_NNF(cond) + if cond: + return OR() + if cond: + return AND() + return Literal() diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md b/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md index 32626dc5190b0..2f45a9800617b 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md @@ -79,7 +79,9 @@ def outer_sync(): # `yield` from is only valid syntax inside a synchronous func a: (yield from [1]), # error: [invalid-type-form] "`yield from` expressions are not allowed in type expressions" ): ... -async def baz(): ... +async def baz(): + yield + async def outer_async(): # avoid unrelated syntax errors on `yield` and `await` def _( a: 1, # error: [invalid-type-form] "Int literals are not allowed in this context in a type expression" diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index c0f42e204be19..47dd959074225 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -616,6 +616,14 @@ class X[T]: x1: X[int | None] = X() reveal_type(x1) # revealed: X[None] + +class Y[T]: + def __init__(self: Y[None]) -> None: ... + def pop(self) -> T: + raise NotImplementedError + +y1: Y[int | None] = Y() +reveal_type(y1) # revealed: Y[None] ``` ## Declared type preference sees through subtyping diff --git a/crates/ty_python_semantic/resources/mdtest/call/union.md b/crates/ty_python_semantic/resources/mdtest/call/union.md index 09943c0801765..0b0095a39242b 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/union.md +++ b/crates/ty_python_semantic/resources/mdtest/call/union.md @@ -111,7 +111,7 @@ def _(flag: bool): # error: [call-non-callable] "Object of type `Literal["This is a string literal"]` is not callable" x = f(3) - reveal_type(x) # revealed: Unknown + reveal_type(x) # revealed: None | Unknown ``` ## Union of binding errors @@ -128,7 +128,7 @@ def _(flag: bool): # error: [too-many-positional-arguments] "Too many positional arguments to function `f1`: expected 0, got 1" # error: [too-many-positional-arguments] "Too many positional arguments to function `f2`: expected 0, got 1" x = f(3) - reveal_type(x) # revealed: Unknown + reveal_type(x) # revealed: None ``` ## One not-callable, one wrong argument @@ -146,7 +146,7 @@ def _(flag: bool): # error: [too-many-positional-arguments] "Too many positional arguments to function `f1`: expected 0, got 1" # error: [call-non-callable] "Object of type `C` is not callable" x = f(3) - reveal_type(x) # revealed: Unknown + reveal_type(x) # revealed: None | Unknown ``` ## Union including a special-cased function diff --git a/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md b/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md index b4e0b1ae24ef4..5c2c958a15c05 100644 --- a/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md +++ b/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md @@ -125,7 +125,8 @@ match obj: ```py class C: - def __await__(self): ... + def __await__(self): + yield # error: [invalid-syntax] "`return` statement outside of a function" return diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 28bd509e440bd..dd071fa5efc11 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -295,6 +295,331 @@ def f(cond: bool) -> int: return 2 ``` +## Inferred return type + +### Free function + +If a function's return type is not annotated, it is inferred. The inferred type is the union of all +possible return types. + +```py +def f(): + return 1 + +reveal_type(f()) # revealed: Literal[1] +# TODO: should be `def f() -> Literal[1]` +reveal_type(f) # revealed: def f() -> Unknown + +def g(cond: bool): + if cond: + return 1 + else: + return "a" + +reveal_type(g(True)) # revealed: Literal[1, "a"] + +# This function implicitly returns `None`. +def h(x: int, y: str): + if x > 10: + return x + elif x > 5: + return y + +reveal_type(h(1, "a")) # revealed: int | str | None + +lambda_func = lambda: 1 +# TODO: lambda function type inference +# Should be `Literal[1]` +reveal_type(lambda_func()) # revealed: Unknown + +def generator(): + yield 1 + yield 2 + return None + +# TODO: Should be `Generator[Literal[1, 2], Any, None]` +reveal_type(generator()) # revealed: Unknown + +async def async_generator(): + yield + +# TODO: Should be `AsyncGenerator[None, Any]` +reveal_type(async_generator()) # revealed: Unknown + +async def coroutine(): + return + +# TODO: Should be `CoroutineType[Any, Any, None]` +reveal_type(coroutine()) # revealed: Unknown +``` + +The return type of a recursive function is also inferred. When the return type inference would +diverge, it is truncated and replaced with the special dynamic type `Divergent`. + +```toml +[environment] +python-version = "3.12" +``` + +```py +def fibonacci(n: int): + if n == 0: + return 0 + elif n == 1: + return 1 + else: + return fibonacci(n - 1) + fibonacci(n - 2) + +reveal_type(fibonacci(5)) # revealed: int + +def even(n: int): + if n == 0: + return True + else: + return odd(n - 1) + +def odd(n: int): + if n == 0: + return False + else: + return even(n - 1) + +reveal_type(even(1)) # revealed: bool +reveal_type(odd(1)) # revealed: bool + +def repeat_a(n: int): + if n <= 0: + return "" + else: + return repeat_a(n - 1) + "a" + +reveal_type(repeat_a(3)) # revealed: str + +def divergent(value): + if type(value) is tuple: + return (divergent(value[0]),) + else: + return None + +# tuple[tuple[tuple[...] | None] | None] | None => tuple[Divergent] | None +reveal_type(divergent((1,))) # revealed: tuple[Divergent] | None + +def call_divergent(x: int): + return (divergent((1, 2, 3)), x) + +reveal_type(call_divergent(1)) # revealed: tuple[tuple[Divergent] | None, int] + +def list1[T](x: T) -> list[T]: + return [x] + +def divergent2(value): + if type(value) is tuple: + return (divergent2(value[0]),) + elif type(value) is list: + return list1(divergent2(value[0])) + else: + return None + +reveal_type(divergent2((1,))) # revealed: tuple[Divergent] | list[Divergent] | None + +def list_int(x: int): + if x > 0: + return list1(list_int(x - 1)) + else: + return list1(x) + +# TODO: should be `list[int]` +reveal_type(list_int(1)) # revealed: list[Divergent] | list[Divergent] | list[int] + +def tuple_obj(cond: bool): + if cond: + x = object() + else: + x = tuple_obj(cond) + return (x,) + +reveal_type(tuple_obj(True)) # revealed: tuple[object] + +def get_non_empty(node): + for child in node.children: + node = get_non_empty(child) + if node is not None: + return node + return None + +reveal_type(get_non_empty(None)) # revealed: (Divergent & ~None) | None + +def nested_scope(): + def inner(): + return nested_scope() + return inner() + +reveal_type(nested_scope()) # revealed: Divergent + +def eager_nested_scope(): + class A: + x = eager_nested_scope() + + return A.x + +reveal_type(eager_nested_scope()) # revealed: Unknown + +class C: + def flip(self) -> "D": + return D() + +class D(C): + # error: [invalid-method-override] + def flip(self) -> "C": + return C() + +def c_or_d(n: int): + if n == 0: + return D() + else: + return c_or_d(n - 1).flip() + +# In fixed-point iteration of the return type inference, the return type is monotonically widened. +# For example, once the return type of `c_or_d` is determined to be `C`, +# it will never be determined to be a subtype `D` in the subsequent iterations. +reveal_type(c_or_d(1)) # revealed: C +``` + +### Class method + +If a method's return type is not annotated, it is also inferred, but the inferred type is a union of +all possible return types and `Unknown`. This is because a method of a class may be overridden by +its subtypes. For example, if the return type of a method is inferred to be `int`, the type the +coder really intended might be `int | None`, in which case it would be impossible for the overridden +method to return `None`. + +```py +class C: + def f(self): + return 1 + +class D(C): + def f(self): + return None + +reveal_type(C().f()) # revealed: Literal[1] | Unknown +reveal_type(D().f()) # revealed: None | Literal[1] | Unknown +``` + +However, in the following cases, `Unknown` is not included in the inferred return type because there +is no ambiguity in the subclass. + +- The class or the method is marked as `final`. + +```py +from typing import final + +@final +class C: + def f(self): + return 1 + +class D: + @final + def f(self): + return "a" + +reveal_type(C().f()) # revealed: Literal[1] +reveal_type(D().f()) # revealed: Literal["a"] +``` + +- The method overrides the methods of the base classes, and the return types of the base class + methods are known (In this case, the return type of the method is the intersection of the return + types of the methods in the base classes). + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing import Literal + +class C: + def f(self) -> int: + return 1 + + def g[T](self, x: T) -> T: + return x + + def h[T: int](self, x: T) -> T: + return x + + def i[T: int](self, x: T) -> list[T]: + return [x] + +class D(C): + def f(self): + return 2 + # TODO: This should be an invalid-override error. + def g(self, x: int): + return 2 + # A strict application of the Liskov Substitution Principle would consider + # this an invalid override because it violates the guarantee that the method returns + # the same type as its input type (any type smaller than int), + # but neither mypy nor pyright will throw an error for this. + def h(self, x: int): + return 2 + + def i(self, x: int): + return [2] + +class E(D): + def f(self): + return 3 + +reveal_type(C().f()) # revealed: int +reveal_type(D().f()) # revealed: int +reveal_type(E().f()) # revealed: int +reveal_type(C().g(1)) # revealed: Literal[1] +reveal_type(D().g(1)) # revealed: Literal[2] | Unknown +reveal_type(C().h(1)) # revealed: Literal[1] +reveal_type(D().h(1)) # revealed: Literal[2] | Unknown +reveal_type(C().h(True)) # revealed: Literal[True] +reveal_type(D().h(True)) # revealed: Literal[2] | Unknown +reveal_type(C().i(1)) # revealed: list[int] +# TODO: better type for list elements +reveal_type(D().i(1)) # revealed: list[Unknown | int] | list[Unknown] + +class F: + def f(self) -> Literal[1, 2]: + return 2 + +class G: + def f(self) -> Literal[2, 3]: + return 2 + +class H(F, G): + # TODO: should be an invalid-override error + def f(self): + raise NotImplementedError + +class I(F, G): + # TODO: should be an invalid-override error + @final + def f(self): + raise NotImplementedError + +# We use a return type of `F.f` according to the MRO. +reveal_type(H().f()) # revealed: Literal[1, 2] +reveal_type(I().f()) # revealed: Never + +class C2[T]: + def f(self, x: T) -> T: + return x + +class D2(C2[int]): + def f(self, x: int): + return x + +reveal_type(D2().f(1)) # revealed: int +``` + ## Invalid return type diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type.md b/crates/ty_python_semantic/resources/mdtest/narrow/type.md index de962d2075807..06c0744d7aa3f 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type.md @@ -142,20 +142,25 @@ def _(x: A | B): reveal_type(x) # revealed: A | B ``` -## No narrowing for custom `type` callable +## No special narrowing for custom `type` callable ```py +def type(x: object): + return int + class A: ... class B: ... -def type(x): - return int - def _(x: A | B): + # The custom `type` function always returns `int`, + # so any branch other than `type(...) is int` is unreachable. if type(x) is A: + reveal_type(x) # revealed: Never + # And the condition here is always `True` and has no effect on the narrowing of `x`. + elif type(x) is int: reveal_type(x) # revealed: A | B else: - reveal_type(x) # revealed: A | B + reveal_type(x) # revealed: Never ``` ## No narrowing for multiple arguments diff --git a/crates/ty_python_semantic/src/semantic_index/scope.rs b/crates/ty_python_semantic/src/semantic_index/scope.rs index c7c42241a3487..1ec58a3299ba2 100644 --- a/crates/ty_python_semantic/src/semantic_index/scope.rs +++ b/crates/ty_python_semantic/src/semantic_index/scope.rs @@ -1,6 +1,9 @@ use std::ops::Range; -use ruff_db::{files::File, parsed::ParsedModuleRef}; +use ruff_db::{ + files::File, + parsed::{ParsedModuleRef, parsed_module}, +}; use ruff_index::newtype_index; use ruff_python_ast as ast; @@ -27,6 +30,10 @@ pub struct ScopeId<'db> { impl get_size2::GetSize for ScopeId<'_> {} impl<'db> ScopeId<'db> { + pub(crate) fn is_non_lambda_function(self, db: &'db dyn Db) -> bool { + self.node(db).scope_kind().is_non_lambda_function() + } + pub(crate) fn is_annotation(self, db: &'db dyn Db) -> bool { self.node(db).scope_kind().is_annotation() } @@ -64,6 +71,18 @@ impl<'db> ScopeId<'db> { NodeWithScopeKind::GeneratorExpression(_) => "", } } + + pub(crate) fn is_coroutine_function(self, db: &'db dyn Db) -> bool { + let module = parsed_module(db, self.file(db)).load(db); + self.node(db) + .as_function() + .is_some_and(|func| func.node(&module).is_async && !self.is_generator_function(db)) + } + + pub(crate) fn is_generator_function(self, db: &'db dyn Db) -> bool { + let index = semantic_index(db, self.file(db)); + self.file_scope_id(db).is_generator_function(index) + } } /// ID that uniquely identifies a scope inside of a module. diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 129af6d7e7340..a72086f60e580 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1419,6 +1419,13 @@ impl<'db> Type<'db> { matches!(self, Type::FunctionLiteral(..)) } + pub(crate) const fn as_bound_method(self) -> Option> { + match self { + Type::BoundMethod(bound_method_type) => Some(bound_method_type), + _ => None, + } + } + /// Detects types which are valid to appear inside a `Literal[…]` type annotation. pub(crate) fn is_literal_or_union_of_literals(&self, db: &'db dyn Db) -> bool { match self { @@ -6624,6 +6631,19 @@ impl<'db> Type<'db> { } } + /// Returns the inferred return type of `self` if it is a function literal / bound method. + fn infer_return_type(self, db: &'db dyn Db) -> Option> { + match self { + Type::FunctionLiteral(function_type) if !function_type.file(db).is_stub(db) => { + Some(function_type.infer_return_type(db)) + } + Type::BoundMethod(method_type) if !method_type.function(db).file(db).is_stub(db) => { + Some(method_type.infer_return_type(db)) + } + _ => None, + } + } + /// Calls `self`. Returns a [`CallError`] if `self` is (always or possibly) not callable, or if /// the arguments are not compatible with the formal parameters. /// @@ -12452,6 +12472,88 @@ impl<'db> BoundMethodType<'db> { ) } + /// Infers this method scope's types and returns the inferred return type. + pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> { + let inferred_return_type = self.function(db).infer_return_type(db); + // If the method is not final and the typing is implicit, the inferred return type should be unioned with `Unknown`. + // If any method in a base class does not have an annotated return type, `base_return_type` will include `Unknown`. + // On the other hand, if the return types of all methods in the base classes are annotated, there is no need to include `Unknown`. + if !self.is_final(db) { + UnionType::from_elements( + db, + [ + inferred_return_type, + self.base_return_type(db).unwrap_or(Type::unknown()), + ], + ) + } else { + inferred_return_type + } + } + + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + fn class_definition(self, db: &'db dyn Db) -> Option> { + let definition_scope = self.function(db).definition(db).scope(db); + let index = semantic_index(db, definition_scope.file(db)); + Some(index.expect_single_definition(definition_scope.node(db).as_class()?)) + } + + fn is_final(self, db: &'db dyn Db) -> bool { + if self + .function(db) + .has_known_decorator(db, FunctionDecorators::FINAL) + { + return true; + } + let Some(class_ty) = self + .class_definition(db) + .and_then(|class| binding_type(db, class).as_class_literal()) + else { + return false; + }; + class_ty + .known_function_decorators(db) + .any(|deco| deco == KnownFunction::Final) + } + + fn is_init(self, db: &'db dyn Db) -> bool { + self.function(db).name(db) == "__init__" + } + + fn base_return_type(self, db: &'db dyn Db) -> Option> { + let class = binding_type(db, self.class_definition(db)?).to_class_type(db)?; + let name = self.function(db).name(db); + + let base = class + .iter_mro(db) + .nth(1) + .and_then(class_base::ClassBase::into_class)?; + let base_member = base.class_member(db, name, MemberLookupPolicy::default()); + if let Place::Defined(Type::FunctionLiteral(base_func), _, _, _) = base_member.place { + if let [signature] = base_func.signature(db).overloads.as_slice() { + let unspecialized_return_ty = signature.return_ty.unwrap_or_else(|| { + let base_method_ty = + base_func.into_bound_method_type(db, Type::instance(db, class)); + base_method_ty.infer_return_type(db) + }); + if let Some(generic_context) = signature.generic_context.as_ref() { + // If the return type of the base method contains a type variable, replace it with `Unknown` to avoid dangling type variables. + Some( + unspecialized_return_ty + .apply_specialization(db, generic_context.unknown_specialization(db)), + ) + } else { + Some(unspecialized_return_ty) + } + } else { + // TODO: Handle overloaded base methods. + None + } + } else { + None + } + } + fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { Self::new( db, diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 36a258f20d012..f68d627737ee9 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -659,8 +659,31 @@ impl<'db> UnionBuilder<'db> { types.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(self.db, l, r)); } match types.len() { - 0 => None, - 1 => Some(types[0]), + 0 => { + if self.recursively_defined.is_yes() { + // See the comment below for why this is necessary. + Some(Type::Union(UnionType::new( + self.db, + Box::from([Type::Never]), + self.recursively_defined, + ))) + } else { + None + } + } + 1 => { + if self.recursively_defined.is_yes() { + // We need to mark this type with a "recursively-defined" marker, so build it as a single-element recursively-defined union type. + // This will only happen very early in the fixed-point iteration, and a single-element union type should never appear in the final converged type. + Some(Type::Union(UnionType::new( + self.db, + Box::from([types[0]]), + self.recursively_defined, + ))) + } else { + Some(types[0]) + } + } _ => Some(Type::Union(UnionType::new( self.db, types.into_boxed_slice(), @@ -699,6 +722,14 @@ impl<'db> IntersectionBuilder<'db> { } } + fn extend(&mut self, sub: Self) { + for inner in sub.intersections { + if !self.intersections.contains(&inner) { + self.intersections.push(inner); + } + } + } + pub(crate) fn add_positive(self, ty: Type<'db>) -> Self { self.add_positive_impl(ty, &mut vec![]) } @@ -735,7 +766,7 @@ impl<'db> IntersectionBuilder<'db> { .iter() .map(|elem| self.clone().add_positive_impl(*elem, seen_aliases)) .fold(IntersectionBuilder::empty(self.db), |mut builder, sub| { - builder.intersections.extend(sub.intersections); + builder.extend(sub); builder }) } @@ -859,7 +890,7 @@ impl<'db> IntersectionBuilder<'db> { positive_side.chain(negative_side).fold( IntersectionBuilder::empty(self.db), |mut builder, sub| { - builder.intersections.extend(sub.intersections); + builder.extend(sub); builder }, ) @@ -957,7 +988,7 @@ impl<'db> IntersectionBuilder<'db> { } } -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, PartialEq, Eq)] struct InnerIntersectionBuilder<'db> { positive: FxOrderSet>, negative: FxOrderSet>, diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 24c98e94dc09e..8cc0644aedc2d 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -3151,9 +3151,18 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { return None; } - // TODO: Ideally we would infer the annotated type _before_ the arguments if this call is part of an - // annotated assignment, to closer match the order of any unions written in the type annotation. - builder.infer(return_ty, call_expression_tcx).ok()?; + // For `__init__`, do not the use type context to widen the return type, + // as it can lead to argument assignability errors if the type variable + // is constrained by a narrower parameter type. + if self + .signature_type + .as_bound_method() + .is_none_or(|method| !method.is_init(self.db)) + { + // TODO: Ideally we would infer the annotated type _before_ the arguments if this call is part of an + // annotated assignment, to closer match the order of any unions written in the type annotation. + builder.infer(return_ty, call_expression_tcx).ok()?; + } // Otherwise, build the specialization again after inferring the complete type context. let specialization = builder @@ -3674,10 +3683,15 @@ impl<'db> Binding<'db> { } } } + for (keywords_index, keywords_type) in keywords_arguments { matcher.match_keyword_variadic(db, keywords_index, keywords_type); } - self.return_ty = self.signature.return_ty.unwrap_or(Type::unknown()); + self.return_ty = self.signature.return_ty.unwrap_or_else(|| { + self.callable_type + .infer_return_type(db) + .unwrap_or(Type::unknown()) + }); self.parameter_tys = vec![None; parameters.len()].into_boxed_slice(); self.variadic_argument_matched_to_variadic_parameter = matcher.variadic_argument_matched_to_variadic_parameter; diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 19f82ae177bbd..ae63618fc48cc 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -84,7 +84,7 @@ use crate::types::{ HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, KnownClass, KnownInstanceType, NormalizedVisitor, SpecialFormType, SubclassOfInner, SubclassOfType, Truthiness, Type, TypeContext, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, UnionBuilder, binding_type, - definition_expression_type, infer_definition_types, walk_signature, + definition_expression_type, infer_definition_types, infer_scope_types, walk_signature, }; use crate::{Db, FxOrderSet}; use ty_module_resolver::{KnownModule, ModuleName, file_to_module, resolve_module}; @@ -1204,6 +1204,32 @@ impl<'db> FunctionType<'db> { updated_last_definition_signature, )) } + + /// Infers this function scope's types and returns the inferred return type. + #[salsa::tracked(cycle_fn=return_type_cycle_recover, cycle_initial=return_type_cycle_initial, heap_size=get_size2::heap_size)] + pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> { + let scope = self.literal(db).last_definition(db).body_scope(db); + let inference = infer_scope_types(db, scope); + inference.infer_return_type(db, scope) + } +} + +fn return_type_cycle_recover<'db>( + db: &'db dyn Db, + cycle: &salsa::Cycle, + previous_return_type: &Type<'db>, + return_type: Type<'db>, + _self: FunctionType<'db>, +) -> Type<'db> { + return_type.cycle_normalized(db, *previous_return_type, cycle) +} + +fn return_type_cycle_initial<'db>( + _db: &'db dyn Db, + id: salsa::Id, + _function: FunctionType<'db>, +) -> Type<'db> { + Type::divergent(id) } /// Evaluate an `isinstance` call. Return `Truthiness::AlwaysTrue` if we can definitely infer that diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 3a1f4425928c9..a64debc0dfd9d 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -47,13 +47,14 @@ use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; use crate::semantic_index::definition::Definition; use crate::semantic_index::expression::Expression; use crate::semantic_index::scope::ScopeId; -use crate::semantic_index::{SemanticIndex, semantic_index}; +use crate::semantic_index::{SemanticIndex, semantic_index, use_def_map}; +use crate::types::builder::RecursivelyDefined; use crate::types::diagnostic::TypeCheckDiagnostics; use crate::types::function::FunctionType; use crate::types::generics::Specialization; use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ - ClassLiteral, KnownClass, Truthiness, Type, TypeAndQualifiers, declaration_type, + ClassLiteral, KnownClass, Truthiness, Type, TypeAndQualifiers, UnionBuilder, declaration_type, }; use crate::unpack::Unpack; use builder::TypeInferenceBuilder; @@ -546,6 +547,9 @@ struct ScopeInferenceExtra<'db> { /// The diagnostics for this region. diagnostics: TypeCheckDiagnostics, + + /// The returned types, if it is a function body. + return_types: Vec>, } impl<'db> ScopeInference<'db> { @@ -570,6 +574,23 @@ impl<'db> ScopeInference<'db> { *ty = ty.cycle_normalized(db, previous_ty, cycle); } + if let Some(extra) = &mut self.extra { + for (i, return_ty) in extra.return_types.iter_mut().enumerate() { + match previous_inference.extra.as_ref() { + Some(previous_extra) => { + if let Some(previous_return_ty) = previous_extra.return_types.get(i) { + *return_ty = return_ty.cycle_normalized(db, *previous_return_ty, cycle); + } else { + *return_ty = return_ty.recursive_type_normalized(db, cycle); + } + } + None => { + *return_ty = return_ty.recursive_type_normalized(db, cycle); + } + } + } + } + self } @@ -605,6 +626,41 @@ impl<'db> ScopeInference<'db> { extra.string_annotations.contains(&expression.into()) } + + /// Returns the inferred return type of this function body (union of all possible return types), + /// or `None` if the region is not a function body. + /// In the case of methods, the return type of the superclass method is further unioned. + /// If there is no superclass method and this method is not `final`, it will be unioned with `Unknown`. + pub(crate) fn infer_return_type(&self, db: &'db dyn Db, scope: ScopeId<'db>) -> Type<'db> { + // TODO: coroutine function type inference + // TODO: generator function type inference + if scope.is_coroutine_function(db) || scope.is_generator_function(db) { + return Type::unknown(); + } + + let mut union = UnionBuilder::new(db); + // If this method is called early in the query cycle of `infer_scope_types`, `extra.return_types` will be empty. + // To properly propagate divergence, we must add `Divergent` to the union type. + if let Some(divergent) = self.fallback_type() { + union = union.recursively_defined(RecursivelyDefined::Yes); + union = union.add(divergent); + } + + let Some(extra) = &self.extra else { + unreachable!( + "infer_return_type should only be called on a function body scope inference" + ); + }; + for return_ty in &extra.return_types { + union = union.add(*return_ty); + } + let use_def = use_def_map(db, scope); + if use_def.can_implicitly_return_none(db) { + union = union.add(Type::none(db)); + } + + union.build() + } } /// The inferred types for a definition region. diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 453eb9a15e1dc..f9dd5256f11f6 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -12976,6 +12976,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { pub(super) fn finish_scope(mut self) -> ScopeInference<'db> { self.infer_region(); + let db = self.db(); let Self { context, @@ -13002,21 +13003,27 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { called_functions: _, index: _, region: _, - return_types_and_ranges: _, + return_types_and_ranges, } = self; - let _ = scope; let diagnostics = context.finish(); - let extra = - (!string_annotations.is_empty() || !diagnostics.is_empty() || cycle_recovery.is_some()) - .then(|| { - Box::new(ScopeInferenceExtra { - string_annotations, - cycle_recovery, - diagnostics, - }) - }); + let extra = (!string_annotations.is_empty() + || !diagnostics.is_empty() + || cycle_recovery.is_some() + || scope.is_non_lambda_function(db)) + .then(|| { + let return_types = return_types_and_ranges + .into_iter() + .map(|ty_range| ty_range.ty) + .collect(); + Box::new(ScopeInferenceExtra { + string_annotations, + cycle_recovery, + diagnostics, + return_types, + }) + }); expressions.shrink_to_fit();