diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index 5caefda4e192ac..24b04f68f39174 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -146,7 +146,75 @@ r: dict[int | str, int | str] = {1: 1, 2: 2, 3: 3} reveal_type(r) # revealed: dict[int | str, int | str] ``` -## Incorrect collection literal assignments are complained aobut +## Optional collection literal annotations are understood + +```toml +[environment] +python-version = "3.12" +``` + +```py +import typing + +a: list[int] | None = [1, 2, 3] +reveal_type(a) # revealed: list[int] + +b: list[int | str] | None = [1, 2, 3] +reveal_type(b) # revealed: list[int | str] + +c: typing.List[int] | None = [1, 2, 3] +reveal_type(c) # revealed: list[int] + +d: list[typing.Any] | None = [] +reveal_type(d) # revealed: list[Any] + +e: set[int] | None = {1, 2, 3} +reveal_type(e) # revealed: set[int] + +f: set[int | str] | None = {1, 2, 3} +reveal_type(f) # revealed: set[int | str] + +g: typing.Set[int] | None = {1, 2, 3} +reveal_type(g) # revealed: set[int] + +h: list[list[int]] | None = [[], [42]] +reveal_type(h) # revealed: list[list[int]] + +i: list[typing.Any] | None = [1, 2, "3", ([4],)] +reveal_type(i) # revealed: list[Any | int | str | tuple[list[Unknown | int]]] + +j: list[tuple[str | int, ...]] | None = [(1, 2), ("foo", "bar"), ()] +reveal_type(j) # revealed: list[tuple[str | int, ...]] + +k: list[tuple[list[int], ...]] | None = [([],), ([1, 2], [3, 4]), ([5], [6], [7])] +reveal_type(k) # revealed: list[tuple[list[int], ...]] + +l: tuple[list[int], *tuple[list[typing.Any], ...], list[str]] | None = ([1, 2, 3], [4, 5, 6], [7, 8, 9], ["10", "11", "12"]) +# TODO: this should be `tuple[list[int], list[Any | int], list[Any | int], list[str]]` +reveal_type(l) # revealed: tuple[list[Unknown | int], list[Unknown | int], list[Unknown | int], list[Unknown | str]] + +type IntList = list[int] + +m: IntList | None = [1, 2, 3] +reveal_type(m) # revealed: list[int] + +n: list[typing.Literal[1, 2, 3]] | None = [1, 2, 3] +reveal_type(n) # revealed: list[Literal[1, 2, 3]] + +o: list[typing.LiteralString] | None = ["a", "b", "c"] +reveal_type(o) # revealed: list[LiteralString] + +p: dict[int, int] | None = {} +reveal_type(p) # revealed: dict[int, int] + +q: dict[int | str, int] | None = {1: 1, 2: 2, 3: 3} +reveal_type(q) # revealed: dict[int | str, int] + +r: dict[int | str, int | str] | None = {1: 1, 2: 2, 3: 3} +reveal_type(r) # revealed: dict[int | str, int | str] +``` + +## Incorrect collection literal assignments are complained about ```py # error: [invalid-assignment] "Object of type `list[Unknown | int]` is not assignable to `list[str]`" diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md new file mode 100644 index 00000000000000..3485304b6b4224 --- /dev/null +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -0,0 +1,147 @@ +# Bidirectional type inference + +ty partially supports bidirectional type inference. This is a mechanism for inferring the type of an +expression "from the outside in". Normally, type inference proceeds "from the inside out". That is, +in order to infer the type of an expression, the types of all sub-expressions must first be +inferred. There is no reverse dependency. However, when performing complex type inference, such as +when generics are involved, the type of an outer expression can sometimes be useful in inferring +inner expressions. Bidirectional type inference is a mechanism that propagates such "expected types" +to the inference of inner expressions. + +## Propagating target type annotation + +```toml +[environment] +python-version = "3.12" +``` + +```py +def list1[T](x: T) -> list[T]: + return [x] + +l1 = list1(1) +reveal_type(l1) # revealed: list[Literal[1]] +l2: list[int] = list1(1) +reveal_type(l2) # revealed: list[int] + +# `list[Literal[1]]` and `list[int]` are incompatible, since `list[T]` is invariant in `T`. +# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`" +l2 = l1 + +intermediate = list1(1) +# TODO: the error will not occur if we can infer the type of `intermediate` to be `list[int]` +# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`" +l3: list[int] = intermediate +# TODO: it would be nice if this were `list[int]` +reveal_type(intermediate) # revealed: list[Literal[1]] +reveal_type(l3) # revealed: list[int] + +l4: list[int | str] | None = list1(1) +reveal_type(l4) # revealed: list[int | str] + +def _(l: list[int] | None = None): + l1 = l or list() + reveal_type(l1) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown] + + l2: list[int] = l or list() + # it would be better if this were `list[int]`? (https://github.com/astral-sh/ty/issues/136) + reveal_type(l2) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown] + +def f[T](x: T, cond: bool) -> T | list[T]: + return x if cond else [x] + +# TODO: no error +# error: [invalid-assignment] "Object of type `Literal[1] | list[Literal[1]]` is not assignable to `int | list[int]`" +l5: int | list[int] = f(1, True) +``` + +`typed_dict.py`: + +```py +from typing import TypedDict + +class TD(TypedDict): + x: int + +d1 = {"x": 1} +d2: TD = {"x": 1} +d3: dict[str, int] = {"x": 1} + +reveal_type(d1) # revealed: dict[Unknown | str, Unknown | int] +reveal_type(d2) # revealed: TD +reveal_type(d3) # revealed: dict[str, int] + +def _() -> TD: + return {"x": 1} + +def _() -> TD: + # error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor" + return {} +``` + +## Propagating return type annotation + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing import overload, Callable + +def list1[T](x: T) -> list[T]: + return [x] + +def get_data() -> dict | None: + return {} + +def wrap_data() -> list[dict]: + if not (res := get_data()): + return list1({}) + reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy] + # `list[dict[Unknown, Unknown] & ~AlwaysFalsy]` and `list[dict[Unknown, Unknown]]` are incompatible, + # but the return type check passes here because the type of `list1(res)` is inferred + # by bidirectional type inference using the annotated return type, and the type of `res` is not used. + return list1(res) + +def wrap_data2() -> list[dict] | None: + if not (res := get_data()): + return None + reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy] + return list1(res) + +def deco[T](func: Callable[[], T]) -> Callable[[], T]: + return func + +def outer() -> Callable[[], list[dict]]: + @deco + def inner() -> list[dict]: + if not (res := get_data()): + return list1({}) + reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy] + return list1(res) + return inner + +@overload +def f(x: int) -> list[int]: ... +@overload +def f(x: str) -> list[str]: ... +def f(x: int | str) -> list[int] | list[str]: + # `list[int] | list[str]` is disjoint from `list[int | str]`. + if isinstance(x, int): + return list1(x) + else: + return list1(x) + +reveal_type(f(1)) # revealed: list[int] +reveal_type(f("a")) # revealed: list[str] + +async def g() -> list[int | str]: + return list1(1) + +def h[T](x: T, cond: bool) -> T | list[T]: + return i(x, cond) + +def i[T](x: T, cond: bool) -> T | list[T]: + return x if cond else [x] +``` diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md index 34c4f04ead029f..7da1d19286836d 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md @@ -323,6 +323,9 @@ def union_param(x: T | None) -> T: reveal_type(union_param("a")) # revealed: Literal["a"] reveal_type(union_param(1)) # revealed: Literal[1] reveal_type(union_param(None)) # revealed: Unknown + +def _(x: int | None): + reveal_type(union_param(x)) # revealed: int ``` ```py diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md index c9ee5359ce68eb..a5e62f686623de 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md @@ -286,6 +286,9 @@ def union_param[T](x: T | None) -> T: reveal_type(union_param("a")) # revealed: Literal["a"] reveal_type(union_param(1)) # revealed: Literal[1] reveal_type(union_param(None)) # revealed: Unknown + +def _(x: int | None): + reveal_type(union_param(x)) # revealed: int ``` ```py diff --git a/crates/ty_python_semantic/resources/mdtest/typed_dict.md b/crates/ty_python_semantic/resources/mdtest/typed_dict.md index c77a52b5256761..8fc8b2fcb616e5 100644 --- a/crates/ty_python_semantic/resources/mdtest/typed_dict.md +++ b/crates/ty_python_semantic/resources/mdtest/typed_dict.md @@ -125,9 +125,10 @@ def homogeneous_list[T](*args: T) -> list[T]: reveal_type(homogeneous_list(1, 2, 3)) # revealed: list[Literal[1, 2, 3]] plot2: Plot = {"y": homogeneous_list(1, 2, 3), "x": None} reveal_type(plot2["y"]) # revealed: list[int] -# TODO: no error -# error: [invalid-argument-type] + plot3: Plot = {"y": homogeneous_list(1, 2, 3), "x": homogeneous_list(1, 2, 3)} +reveal_type(plot3["y"]) # revealed: list[int] +reveal_type(plot3["x"]) # revealed: list[int] | None Y = "y" X = "x" @@ -362,7 +363,7 @@ qualifiers override the class-level `total` setting, which sets the default (`to all keys are required by default, `total=False` means that all keys are non-required by default): ```py -from typing_extensions import TypedDict, Required, NotRequired +from typing_extensions import TypedDict, Required, NotRequired, Final # total=False by default, but id is explicitly Required class Message(TypedDict, total=False): @@ -376,10 +377,17 @@ class User(TypedDict): email: Required[str] # Explicitly required (redundant here) bio: NotRequired[str] # Optional despite total=True +ID: Final = "id" + # Valid Message constructions msg1 = Message(id=1) # id required, content optional msg2 = Message(id=2, content="Hello") # both provided msg3 = Message(id=3, timestamp="2024-01-01") # id required, timestamp optional +msg4: Message = {"id": 4} # id required, content optional +msg5: Message = {ID: 5} # id required, content optional + +def msg() -> Message: + return {ID: 1} # Valid User constructions user1 = User(name="Alice", email="alice@example.com") # required fields diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index dc27519dbca7a5..c8c1965bccc4e3 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -977,6 +977,10 @@ impl<'db> Type<'db> { } } + pub(crate) fn has_type_var(self, db: &'db dyn Db) -> bool { + any_over_type(db, self, &|ty| matches!(ty, Type::TypeVar(_)), false) + } + pub(crate) const fn into_class_literal(self) -> Option> { match self { Type::ClassLiteral(class_type) => Some(class_type), @@ -1167,6 +1171,15 @@ impl<'db> Type<'db> { if yes { self.negate(db) } else { *self } } + /// Remove the union elements that are not related to `target`. + pub(crate) fn filter_disjoint_elements(self, db: &'db dyn Db, target: Type<'db>) -> Type<'db> { + if let Type::Union(union) = self { + union.filter(db, |elem| !elem.is_disjoint_from(db, target)) + } else { + self + } + } + /// Returns the fallback instance type that a literal is an instance of, or `None` if the type /// is not a literal. pub(crate) fn literal_fallback_instance(self, db: &'db dyn Db) -> Option> { diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 2f6b5858d75963..89ebac378fd15e 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -341,6 +341,48 @@ impl<'db> OverloadLiteral<'db> { /// a cross-module dependency directly on the full AST which will lead to cache /// over-invalidation. pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> { + let mut signature = self.raw_signature(db); + + let scope = self.body_scope(db); + let module = parsed_module(db, self.file(db)).load(db); + let function_node = scope.node(db).expect_function().node(&module); + let index = semantic_index(db, scope.file(db)); + let file_scope_id = scope.file_scope_id(db); + let is_generator = file_scope_id.is_generator_function(index); + + if function_node.is_async && !is_generator { + signature = signature.wrap_coroutine_return_type(db); + } + signature = signature.mark_typevars_inferable(db); + + let pep695_ctx = function_node.type_params.as_ref().map(|type_params| { + GenericContext::from_type_params(db, index, self.definition(db), type_params) + }); + let legacy_ctx = GenericContext::from_function_params( + db, + self.definition(db), + signature.parameters(), + signature.return_ty, + ); + // We need to update `signature.generic_context` here, + // because type variables in `GenericContext::variables` are still non-inferable. + signature.generic_context = + GenericContext::merge_pep695_and_legacy(db, pep695_ctx, legacy_ctx); + + signature + } + + /// Typed internally-visible "raw" signature for this function. + /// That is, type variables in parameter types and the return type remain non-inferable, + /// and the return types of async functions are not wrapped in `CoroutineType[...]`. + /// + /// ## Warning + /// + /// This uses the semantic index to find the definition of the function. This means that if the + /// calling query is not in the same file as this function is defined in, then this will create + /// a cross-module dependency directly on the full AST which will lead to cache + /// over-invalidation. + fn raw_signature(self, db: &'db dyn Db) -> Signature<'db> { /// `self` or `cls` can be implicitly positional-only if: /// - It is a method AND /// - No parameters in the method use PEP-570 syntax AND @@ -402,11 +444,11 @@ impl<'db> OverloadLiteral<'db> { let function_stmt_node = scope.node(db).expect_function().node(&module); let definition = self.definition(db); let index = semantic_index(db, scope.file(db)); - let generic_context = function_stmt_node.type_params.as_ref().map(|type_params| { + let pep695_ctx = function_stmt_node.type_params.as_ref().map(|type_params| { GenericContext::from_type_params(db, index, definition, type_params) }); let file_scope_id = scope.file_scope_id(db); - let is_generator = file_scope_id.is_generator_function(index); + let has_implicitly_positional_first_parameter = has_implicitly_positional_only_first_param( db, self, @@ -417,10 +459,9 @@ impl<'db> OverloadLiteral<'db> { Signature::from_function( db, - generic_context, + pep695_ctx, definition, function_stmt_node, - is_generator, has_implicitly_positional_first_parameter, ) } @@ -599,6 +640,18 @@ impl<'db> FunctionLiteral<'db> { fn last_definition_signature(self, db: &'db dyn Db) -> Signature<'db> { self.last_definition(db).signature(db) } + + /// Typed externally-visible "raw" signature of the last overload or implementation of this function. + /// + /// ## Warning + /// + /// This uses the semantic index to find the definition of the function. This means that if the + /// calling query is not in the same file as this function is defined in, then this will create + /// a cross-module dependency directly on the full AST which will lead to cache + /// over-invalidation. + fn last_definition_raw_signature(self, db: &'db dyn Db) -> Signature<'db> { + self.last_definition(db).raw_signature(db) + } } /// Represents a function type, which might be a non-generic function, or a specialization of a @@ -877,6 +930,17 @@ impl<'db> FunctionType<'db> { .unwrap_or_else(|| self.literal(db).last_definition_signature(db)) } + /// Typed externally-visible "raw" signature of the last overload or implementation of this function. + #[salsa::tracked( + returns(ref), + cycle_fn=last_definition_signature_cycle_recover, + cycle_initial=last_definition_signature_cycle_initial, + heap_size=ruff_memory_usage::heap_size, + )] + pub(crate) fn last_definition_raw_signature(self, db: &'db dyn Db) -> Signature<'db> { + self.literal(db).last_definition_raw_signature(db) + } + /// Convert the `FunctionType` into a [`CallableType`]. pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> CallableType<'db> { CallableType::new(db, self.signature(db), false) diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index eaf1c73f434cd2..436638ae4032fa 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -291,6 +291,28 @@ impl<'db> GenericContext<'db> { Some(Self::from_typevar_instances(db, variables)) } + pub(crate) fn merge_pep695_and_legacy( + db: &'db dyn Db, + pep695_generic_context: Option, + legacy_generic_context: Option, + ) -> Option { + match (legacy_generic_context, pep695_generic_context) { + (Some(legacy_ctx), Some(ctx)) => { + if legacy_ctx + .variables(db) + .exactly_one() + .is_ok_and(|bound_typevar| bound_typevar.typevar(db).is_self(db)) + { + Some(legacy_ctx.merge(db, ctx)) + } else { + // TODO: Raise a diagnostic — mixing PEP 695 and legacy typevars is not allowed + Some(ctx) + } + } + (left, right) => left.or(right), + } + } + /// Creates a generic context from the legacy `TypeVar`s that appear in class's base class /// list. pub(crate) fn from_base_classes( @@ -1174,7 +1196,7 @@ impl<'db> SpecializationBuilder<'db> { pub(crate) fn infer( &mut self, formal: Type<'db>, - actual: Type<'db>, + mut actual: Type<'db>, ) -> Result<(), SpecializationError<'db>> { if formal == actual { return Ok(()); @@ -1203,6 +1225,10 @@ impl<'db> SpecializationBuilder<'db> { return Ok(()); } + // For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` to `int`. + // So, here we remove the union elements that are not related to `formal`. + actual = actual.filter_disjoint_elements(self.db, formal); + match (formal, actual) { // TODO: We haven't implemented a full unification solver yet. If typevars appear in // multiple union elements, we ideally want to express that _only one_ of them needs to @@ -1228,9 +1254,15 @@ impl<'db> SpecializationBuilder<'db> { // def _(y: str | int | None): // reveal_type(g(x)) # revealed: str | int // ``` - let formal_bound_typevars = - (formal_union.elements(self.db).iter()).filter_map(|ty| ty.into_type_var()); - let Ok(formal_bound_typevar) = formal_bound_typevars.exactly_one() else { + // We do not handle cases where the `formal` types contain other types that contain type variables + // to prevent incorrect specialization: e.g. `T = int | list[int]` for `formal: T | list[T], actual: int | list[int]` + // (the correct specialization is `T = int`). + let types_have_typevars = formal_union + .elements(self.db) + .iter() + .filter(|ty| ty.has_type_var(self.db)); + let Ok(Type::TypeVar(formal_bound_typevar)) = types_have_typevars.exactly_one() + else { return Ok(()); }; if (actual_union.elements(self.db).iter()).any(|ty| ty.is_type_var()) { @@ -1241,7 +1273,7 @@ impl<'db> SpecializationBuilder<'db> { if remaining_actual.is_never() { return Ok(()); } - self.add_type_mapping(formal_bound_typevar, remaining_actual); + self.add_type_mapping(*formal_bound_typevar, remaining_actual); } (Type::Union(formal), _) => { // Second, if the formal is a union, and precisely one union element _is_ a typevar (not diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 705731980b7e9c..3d949a395ae311 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -50,6 +50,7 @@ use crate::semantic_index::expression::Expression; use crate::semantic_index::scope::ScopeId; use crate::semantic_index::{SemanticIndex, semantic_index}; use crate::types::diagnostic::TypeCheckDiagnostics; +use crate::types::function::FunctionType; use crate::types::generics::Specialization; use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ClassLiteral, KnownClass, Truthiness, Type, TypeAndQualifiers}; @@ -389,6 +390,12 @@ impl<'db> TypeContext<'db> { self.annotation .and_then(|ty| ty.known_specialization(db, known_class)) } + + pub(crate) fn map_annotation(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self { + Self { + annotation: self.annotation.map(f), + } + } } /// Returns the statically-known truthiness of a given expression. @@ -487,6 +494,30 @@ pub(crate) fn nearest_enclosing_class<'db>( }) } +/// Returns the type of the nearest enclosing function for the given scope. +/// +/// This function walks up the ancestor scopes starting from the given scope, +/// and finds the closest (non-lambda) function definition. +/// +/// Returns `None` if no enclosing function is found. +pub(crate) fn nearest_enclosing_function<'db>( + db: &'db dyn Db, + semantic: &SemanticIndex<'db>, + scope: ScopeId, +) -> Option> { + semantic + .ancestor_scopes(scope.file_scope_id(db)) + .find_map(|(_, ancestor_scope)| { + let func = ancestor_scope.node().as_function()?; + let definition = semantic.expect_single_definition(func); + let inference = infer_definition_types(db, definition); + inference + .undecorated_type() + .unwrap_or_else(|| inference.declaration_type(definition).inner_type()) + .into_function_literal() + }) +} + /// A region within which we can infer types. #[derive(Copy, Clone, Debug)] pub(crate) enum InferenceRegion<'db> { diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index a9fabc92010824..1e93f456ca8865 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -79,6 +79,7 @@ use crate::types::function::{ }; use crate::types::generics::{GenericContext, bind_typevar}; use crate::types::generics::{LegacyGenericBase, SpecializationBuilder}; +use crate::types::infer::nearest_enclosing_function; use crate::types::instance::SliceLiteral; use crate::types::mro::MroErrorKind; use crate::types::signatures::Signature; @@ -5101,9 +5102,20 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } fn infer_return_statement(&mut self, ret: &ast::StmtReturn) { - if let Some(ty) = - self.infer_optional_expression(ret.value.as_deref(), TypeContext::default()) - { + let tcx = if ret.value.is_some() { + nearest_enclosing_function(self.db(), self.index, self.scope()) + .map(|func| { + // When inferring expressions within a function body, + // the expected type passed should be the "raw" type, + // i.e. type variables in the return type are non-inferable, + // and the return types of async functions are not wrapped in `CoroutineType[...]`. + TypeContext::new(func.last_definition_raw_signature(self.db()).return_ty) + }) + .unwrap_or_default() + } else { + TypeContext::default() + }; + if let Some(ty) = self.infer_optional_expression(ret.value.as_deref(), tcx) { let range = ret .value .as_ref() @@ -5900,6 +5912,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { return None; }; + let tcx = tcx.map_annotation(|annotation| { + // Remove any union elements of `annotation` that are not related to `collection_ty`. + // e.g. `annotation: list[int] | None => list[int]` if `collection_ty: list` + let collection_ty = collection_class.to_instance(self.db()); + annotation.filter_disjoint_elements(self.db(), collection_ty) + }); + // Extract the annotated type of `T`, if provided. let annotated_elt_tys = tcx .known_specialization(self.db(), collection_class) diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 41b955f4ec8d8d..039b89a6ebd174 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -26,9 +26,10 @@ use crate::types::function::FunctionType; use crate::types::generics::{GenericContext, typing_self, walk_generic_context}; use crate::types::infer::nearest_enclosing_class; use crate::types::{ - ApplyTypeMappingVisitor, BoundTypeVarInstance, ClassLiteral, FindLegacyTypeVarsVisitor, - HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, KnownClass, MaterializationKind, - NormalizedVisitor, TypeContext, TypeMapping, TypeRelation, VarianceInferable, todo_type, + ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, ClassLiteral, + FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, + KnownClass, MaterializationKind, NormalizedVisitor, TypeContext, TypeMapping, TypeRelation, + VarianceInferable, todo_type, }; use crate::{Db, FxOrderSet}; use ruff_python_ast::{self as ast, name::Name}; @@ -419,10 +420,9 @@ impl<'db> Signature<'db> { /// Return a typed signature from a function definition. pub(super) fn from_function( db: &'db dyn Db, - generic_context: Option>, + pep695_generic_context: Option>, definition: Definition<'db>, function_node: &ast::StmtFunctionDef, - is_generator: bool, has_implicitly_positional_first_parameter: bool, ) -> Self { let parameters = Parameters::from_parameters( @@ -431,38 +431,17 @@ impl<'db> Signature<'db> { function_node.parameters.as_ref(), has_implicitly_positional_first_parameter, ); - let return_ty = function_node.returns.as_ref().map(|returns| { - let plain_return_ty = definition_expression_type(db, definition, returns.as_ref()) - .apply_type_mapping( - db, - &TypeMapping::MarkTypeVarsInferable(Some(definition.into())), - TypeContext::default(), - ); - if function_node.is_async && !is_generator { - KnownClass::CoroutineType - .to_specialized_instance(db, [Type::any(), Type::any(), plain_return_ty]) - } else { - plain_return_ty - } - }); + let return_ty = function_node + .returns + .as_ref() + .map(|returns| definition_expression_type(db, definition, returns.as_ref())); let legacy_generic_context = GenericContext::from_function_params(db, definition, ¶meters, return_ty); - - let full_generic_context = match (legacy_generic_context, generic_context) { - (Some(legacy_ctx), Some(ctx)) => { - if legacy_ctx - .variables(db) - .exactly_one() - .is_ok_and(|bound_typevar| bound_typevar.typevar(db).is_self(db)) - { - Some(legacy_ctx.merge(db, ctx)) - } else { - // TODO: Raise a diagnostic — mixing PEP 695 and legacy typevars is not allowed - Some(ctx) - } - } - (left, right) => left.or(right), - }; + let full_generic_context = GenericContext::merge_pep695_and_legacy( + db, + pep695_generic_context, + legacy_generic_context, + ); Self { generic_context: full_generic_context, @@ -472,6 +451,27 @@ impl<'db> Signature<'db> { } } + pub(super) fn mark_typevars_inferable(self, db: &'db dyn Db) -> Self { + if let Some(definition) = self.definition { + self.apply_type_mapping_impl( + db, + &TypeMapping::MarkTypeVarsInferable(Some(BindingContext::Definition(definition))), + TypeContext::default(), + &ApplyTypeMappingVisitor::default(), + ) + } else { + self + } + } + + pub(super) fn wrap_coroutine_return_type(self, db: &'db dyn Db) -> Self { + let return_ty = self.return_ty.map(|return_ty| { + KnownClass::CoroutineType + .to_specialized_instance(db, [Type::any(), Type::any(), return_ty]) + }); + Self { return_ty, ..self } + } + /// Returns the signature which accepts any parameters and returns an `Unknown` type. pub(crate) fn unknown() -> Self { Self::new(Parameters::unknown(), Some(Type::unknown())) @@ -1728,13 +1728,9 @@ impl<'db> Parameter<'db> { kind: ParameterKind<'db>, ) -> Self { Self { - annotated_type: parameter.annotation().map(|annotation| { - definition_expression_type(db, definition, annotation).apply_type_mapping( - db, - &TypeMapping::MarkTypeVarsInferable(Some(definition.into())), - TypeContext::default(), - ) - }), + annotated_type: parameter + .annotation() + .map(|annotation| definition_expression_type(db, definition, annotation)), kind, form: ParameterForm::Value, inferred_annotation: false,