diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index 18f185f351fb1..afaa06b017e9b 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -580,7 +580,7 @@ def _(): reveal_type(x4) # revealed: X ``` -## Prefer the declared type of generic classes +## Prefer the declared type of generic classes and callables ```toml [environment] @@ -682,6 +682,38 @@ x1: X[int | None] = X() reveal_type(x1) # revealed: X[None] ``` +We also prefer the declared type of `Callable` parameters, which are in contravariant position: + +```py +from typing import Callable + +type AnyToBool = Callable[[Any], bool] + +def wrap[**P, T](f: Callable[P, T]) -> Callable[P, T]: + return f + +def make_callable[T](x: T) -> Callable[[T], bool]: + raise NotImplementedError + +def maybe_make_callable[T](x: T) -> Callable[[T], bool] | None: + raise NotImplementedError + +x1: Callable[[Any], bool] = make_callable(0) +reveal_type(x1) # revealed: (Any, /) -> bool + +x2: AnyToBool = make_callable(0) +reveal_type(x2) # revealed: (Any, /) -> bool + +x3: Callable[[list[Any]], bool] = make_callable([0]) +reveal_type(x3) # revealed: (list[Any], /) -> bool + +x4: Callable[[Any], bool] = wrap(make_callable(0)) +reveal_type(x4) # revealed: (Any, /) -> bool + +x5: Callable[[Any], bool] | None = maybe_make_callable(0) +reveal_type(x5) # revealed: ((Any, /) -> bool) | None +``` + ## Declared type preference sees through subtyping ```toml @@ -775,33 +807,48 @@ python-version = "3.12" ``` ```py -from typing import reveal_type, TypedDict +from typing import reveal_type, Any, Callable, TypedDict def identity[T](x: T) -> T: return x -def _(narrow: dict[str, str], target: list[str] | dict[str, str] | None): +type Target = Any | list[str] | dict[str, str] | Callable[[str], None] | None + +def _(narrow: dict[str, str], target: Target): target = identity(narrow) reveal_type(target) # revealed: dict[str, str] -def _(narrow: list[str], target: list[str] | dict[str, str] | None): +def _(narrow: list[str], target: Target): target = identity(narrow) reveal_type(target) # revealed: list[str] -def _(narrow: list[str] | dict[str, str], target: list[str] | dict[str, str] | None): +def _(narrow: Callable[[str], None], target: Target): + target = identity(narrow) + reveal_type(target) # revealed: (str, /) -> None + +def _(narrow: list[str] | dict[str, str], target: Target): target = identity(narrow) reveal_type(target) # revealed: list[str] | dict[str, str] class TD(TypedDict): x: int -def _(target: list[TD] | dict[str, TD] | None): +type TargetWithTD = Any | list[TD] | dict[str, TD] | Callable[[TD], None] | None + +def _(target: TargetWithTD): target = identity([{"x": 1}]) reveal_type(target) # revealed: list[TD] -def _(target: list[TD] | dict[str, TD] | None): +def _(target: TargetWithTD): target = identity({"x": {"x": 1}}) reveal_type(target) # revealed: dict[str, TD] + +def _(target: TargetWithTD): + def make_callable[T](x: T) -> Callable[[T], None]: + raise NotImplementedError + + target = identity(make_callable({"x": 1})) + reveal_type(target) # revealed: (TD, /) -> None ``` ## Prefer the inferred type of non-generic classes @@ -886,7 +933,7 @@ def _(a: int, b: str, c: int | str): reveal_type(x10) # revealed: int | str | None ``` -## Assignability diagnostics ignore declared type of generic classes +## Assignability diagnostics ignore declared type ```toml [environment] @@ -912,19 +959,27 @@ class A(TypedDict): x2: list[A | bool] = [{"bar": 1}, 1] ``` -However, the declared type of generic classes should be ignored if the specialization is not -solvable: +However, the declared type should be ignored if the specialization is not solvable: ```py +from typing import Any, Callable + def g[T](x: list[T]) -> T: return x[0] def _(a: int | None): # error: [invalid-assignment] "Object of type `list[int | None]` is not assignable to `list[str]`" - y1: list[str] = f(a) + x1: list[str] = f(a) # error: [invalid-assignment] "Object of type `int | None` is not assignable to `str`" - y2: str = g(f(a)) + x2: str = g(f(a)) + +def make_callable[T](x: T) -> Callable[[T], bool]: + raise NotImplementedError + +def _(a: int | None): + # error: [invalid-assignment] "Object of type `(int | None, /) -> bool` is not assignable to `(str, /) -> bool`" + x1: Callable[[str], bool] = make_callable(a) ``` ## Forward annotation with unclosed string literal diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/callables.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/callables.md index da958ee4ee149..5730ac232bc0c 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/callables.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/callables.md @@ -181,7 +181,7 @@ def outside_callable(t: T) -> Callable[[T], T]: # revealed: ty_extensions.GenericContext[T@outside_callable] reveal_type(generic_context(outside_callable)) -# revealed: (Literal[1], /) -> Literal[1] +# revealed: (int, /) -> int reveal_type(outside_callable(1)) # revealed: None reveal_type(generic_context(outside_callable(1))) diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/callables.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/callables.md index 42300e67b0ed5..3338ba3e5054d 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/callables.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/callables.md @@ -181,7 +181,7 @@ def outside_callable[T](t: T) -> Callable[[T], T]: # revealed: ty_extensions.GenericContext[T@outside_callable] reveal_type(generic_context(outside_callable)) -# revealed: (Literal[1], /) -> Literal[1] +# revealed: (int, /) -> int reveal_type(outside_callable(1)) # revealed: None reveal_type(generic_context(outside_callable(1))) diff --git a/crates/ty_python_semantic/resources/mdtest/promotion.md b/crates/ty_python_semantic/resources/mdtest/promotion.md index 1ddedb2c5b832..5de65428207cd 100644 --- a/crates/ty_python_semantic/resources/mdtest/promotion.md +++ b/crates/ty_python_semantic/resources/mdtest/promotion.md @@ -96,6 +96,8 @@ We promote in non-covariant position in the return type of a generic function, o generic class: ```py +from typing import Callable, Literal + class Bivariant[T]: def __init__(self, value: T): ... @@ -124,6 +126,8 @@ def f8[T](x: T) -> Invariant[T] | Covariant[T] | None: ... def f9[T](x: T) -> tuple[Invariant[T], Invariant[T]] | None: ... def f10[T, U](x: T, y: U) -> tuple[Invariant[T], Covariant[U]] | None: ... def f11[T, U](x: T, y: U) -> tuple[Invariant[Covariant[T] | None], Covariant[U]] | None: ... +def f12[T](x: T) -> Callable[[T], bool] | None: ... +def f13[T](x: T) -> Callable[[bool], Invariant[T]] | None: ... reveal_type(Bivariant(1)) # revealed: Bivariant[Literal[1]] reveal_type(Covariant(1)) # revealed: Covariant[Literal[1]] @@ -144,6 +148,9 @@ reveal_type(f9(1)) # revealed: tuple[Invariant[int], Invariant[int]] | None reveal_type(f10(1, 1)) # revealed: tuple[Invariant[int], Covariant[Literal[1]]] | None reveal_type(f11(1, 1)) # revealed: tuple[Invariant[Covariant[int] | None], Covariant[Literal[1]]] | None + +reveal_type(f12(1)) # revealed: ((int, /) -> bool) | None +reveal_type(f13(1)) # revealed: ((bool, /) -> Invariant[int]) | None ``` ## Promotion is recursive @@ -190,6 +197,7 @@ declared in a promotable position: ```py from enum import Enum from typing import Sequence, Literal, LiteralString +from typing import Callable class Color(Enum): RED = "red" @@ -274,6 +282,18 @@ reveal_type(x21) # revealed: X[Literal[1]] x22: X[Literal[1]] | None = x([1]) reveal_type(x22) # revealed: X[Literal[1]] + +def make_callable[T](x: T) -> Callable[[T], bool]: + raise NotImplementedError + +def maybe_make_callable[T](x: T) -> Callable[[T], bool] | None: + raise NotImplementedError + +x23: Callable[[Literal[1]], bool] = make_callable(1) +reveal_type(x23) # revealed: (Literal[1], /) -> bool + +x24: Callable[[Literal[1]], bool] | None = maybe_make_callable(1) +reveal_type(x24) # revealed: ((Literal[1], /) -> bool) | None ``` ## Literal annotations see through subtyping @@ -403,7 +423,7 @@ later used in a promotable position: ```py from enum import Enum -from typing import Literal +from typing import Callable, Literal def promote[T](x: T) -> list[T]: return [x] @@ -449,6 +469,16 @@ class MyEnum(Enum): def _(x: Literal[MyEnum.A, MyEnum.B]): reveal_type(x) # revealed: Literal[MyEnum.A, MyEnum.B] reveal_type([x]) # revealed: list[Literal[MyEnum.A, MyEnum.B]] + +def make_callable[T](x: T) -> Callable[[T], bool]: + raise NotImplementedError + +def maybe_make_callable[T](x: T) -> Callable[[T], bool] | None: + raise NotImplementedError + +def _(x: Literal[1]): + reveal_type(make_callable(x)) # revealed: (Literal[1], /) -> bool + reveal_type(maybe_make_callable(x)) # revealed: ((Literal[1], /) -> bool) | None ``` Literal promotability is respected by unions: diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 32db314e75cf9..3aa65f83ecbb3 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -4,6 +4,7 @@ use ruff_diagnostics::{Edit, Fix}; use rustc_hash::FxHashMap; use std::borrow::Cow; +use std::iter; use std::time::Duration; use bitflags::bitflags; @@ -50,7 +51,7 @@ use crate::types::bound_super::BoundSuperType; use crate::types::call::{Binding, Bindings, CallArguments, CallableBinding}; pub(crate) use crate::types::callable::{CallableType, CallableTypes}; pub(crate) use crate::types::class_base::ClassBase; -use crate::types::constraints::{ConstraintSetBuilder, Solutions}; +use crate::types::constraints::ConstraintSetBuilder; use crate::types::context::{LintDiagnosticGuard, LintDiagnosticGuardBuilder}; use crate::types::diagnostic::{INVALID_AWAIT, INVALID_TYPE_FORM}; pub use crate::types::display::{DisplaySettings, TypeDetail, TypeDisplayDetails}; @@ -1085,6 +1086,15 @@ impl<'db> Type<'db> { specialization } + /// Returns `true` if this type may contain preferred type mappings when provided as type context + /// during generic call inference. + /// + /// This is the case for any type which may contain types in non-covariant position within it, + /// e.g., nominal instances of a generic class, or callables. + pub(crate) fn may_prefer_declared_type(self, db: &'db dyn Db) -> bool { + self.class_specialization(db).is_some() || self.expand_eagerly(db).is_callable_type() + } + /// Returns the top materialization (or upper bound materialization) of this type, which is the /// most general form of the type that is fully static. #[must_use] @@ -1831,22 +1841,16 @@ impl<'db> Type<'db> { /// Recursively visit the specialization of a generic class instance. /// - /// The provided closure will be called with each assignment of a type variable present in this - /// type, along with the variance of the outermost type with respect to the type variable. - /// - /// If a `TypeContext` is provided, it will be narrowed as nested types are visited, if the - /// type is a specialized instance of the same class. - pub(crate) fn visit_specialization(self, db: &'db dyn Db, tcx: TypeContext<'db>, mut f: F) + /// The provided closure will be called on any nested types, along with their variance with + /// respect to the outermost type. + pub(crate) fn visit_specialization(self, db: &'db dyn Db, mut f: F) where - F: FnMut(BoundTypeVarInstance<'db>, Type<'db>, TypeVarVariance, TypeContext<'db>), + F: FnMut(Type<'db>, TypeVarVariance), { - let constraints = ConstraintSetBuilder::new(); self.visit_specialization_impl( db, - tcx, TypeVarVariance::Covariant, &mut f, - &constraints, &SpecializationVisitor::default(), ); } @@ -1854,103 +1858,64 @@ impl<'db> Type<'db> { fn visit_specialization_impl( self, db: &'db dyn Db, - tcx: TypeContext<'db>, polarity: TypeVarVariance, - f: &mut dyn FnMut(BoundTypeVarInstance<'db>, Type<'db>, TypeVarVariance, TypeContext<'db>), - constraints: &ConstraintSetBuilder<'db>, + f: &mut dyn FnMut(Type<'db>, TypeVarVariance), visitor: &SpecializationVisitor<'db>, ) { - let Type::NominalInstance(instance) = self else { + let Some(specialization) = self.class_specialization(db) else { match self { Type::Union(union) => { for element in union.elements(db) { - element.visit_specialization_impl( - db, - tcx, - polarity, - f, - constraints, - visitor, - ); + element.visit_specialization_impl(db, polarity, f, visitor); } } Type::Intersection(intersection) => { for element in intersection.positive(db) { - element.visit_specialization_impl( - db, - tcx, - polarity, - f, - constraints, - visitor, - ); + element.visit_specialization_impl(db, polarity, f, visitor); } } Type::TypeAlias(alias) => visitor.visit(self, || { - alias.value_type(db).visit_specialization_impl( - db, - tcx, - polarity, - f, - constraints, - visitor, - ); + alias + .value_type(db) + .visit_specialization_impl(db, polarity, f, visitor); }), + Type::Callable(callable) => { + for signature in callable.signatures(db) { + for parameter in signature.parameters() { + let variance = TypeVarVariance::Contravariant.compose(polarity); + + f(parameter.annotated_type(), variance); + + visitor.visit(parameter.annotated_type(), || { + parameter + .annotated_type() + .visit_specialization_impl(db, variance, f, visitor); + }); + } + + visitor.visit(signature.return_ty, || { + signature + .return_ty + .visit_specialization_impl(db, polarity, f, visitor); + }); + } + } _ => {} } return; }; - let Some((class_literal, Some(specialization))) = - instance.class(db).static_class_literal(db) - else { - return; - }; - let generic_context = specialization.generic_context(db); - - // Collect the type mappings used to narrow the type context. - // - // We use a forward CSA check (`alias_instance ≤ tcx`) to infer what each typevar - // in the identity specialization maps to in the type context. For example, if - // `tcx = list[int]` and `alias_instance = list[T]`, the CSA produces `T = int`. - let tcx_mappings: FxHashMap<_, _> = tcx - .annotation - .and_then(|tcx| { - let alias_instance = Type::instance(db, class_literal.identity_specialization(db)); - let set = alias_instance.when_constraint_set_assignable_to(db, tcx, constraints); - match set.solutions(db, constraints) { - Solutions::Constrained(solutions) => { - let mut mappings = FxHashMap::default(); - for solution in solutions.iter() { - for binding in solution { - mappings - .entry(binding.bound_typevar.identity(db)) - .and_modify(|existing| { - *existing = UnionType::from_two_elements( - db, - *existing, - binding.solution, - ); - }) - .or_insert(binding.solution); - } - } - Some(mappings) - } - _ => None, - } - }) - .unwrap_or_default(); - - for (type_var, ty) in generic_context.variables(db).zip(specialization.types(db)) { - let variance = type_var.variance_with_polarity(db, polarity); - let narrowed_tcx = TypeContext::new(tcx_mappings.get(&type_var.identity(db)).copied()); + for (typevar, ty) in iter::zip( + specialization.generic_context(db).variables(db), + specialization.types(db), + ) { + let variance = typevar.variance_with_polarity(db, polarity); - f(type_var, *ty, variance, narrowed_tcx); + f(*ty, variance); visitor.visit(*ty, || { - ty.visit_specialization_impl(db, narrowed_tcx, variance, f, constraints, visitor); + ty.visit_specialization_impl(db, variance, f, visitor); }); } } diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index f120dfb33ce4a..72d542c2f7238 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -3930,7 +3930,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { FxHashSet::default(); // Attempt to solve the specialization while preferring the declared type of non-covariant - // type parameters from generic classes. + // type parameters from generic classes, or callable types. // // We use an assignability check (`return_ty ≤ tcx`) to infer what each typevar in the // function's return type maps to in the type context. (We use _constraint set_ @@ -3949,8 +3949,12 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { // code does via `assignable_to_declared_type`). let preferred_type_mappings = return_with_tcx .and_then(|(return_ty, tcx)| { - tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some()) - .class_specialization(self.db)?; + if !tcx + .filter_union(self.db, |ty| ty.may_prefer_declared_type(self.db)) + .may_prefer_declared_type(self.db) + { + return None; + } let return_ty = return_ty.filter_disjoint_elements(self.db, tcx, self.inferable_typevars); @@ -4098,18 +4102,16 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { let mut variance_in_return = TypeVarVariance::Bivariant; // Find all occurrences of the type variable in the return type. - let visit_return_ty = |_, ty, variance, _| { + return_ty.visit_specialization(self.db, |ty, variance| { if ty != Type::TypeVar(typevar) { return; } variance_in_return = variance_in_return.join(variance); - }; - - return_ty.visit_specialization(self.db, self.call_expression_tcx, visit_return_ty); + }); - // Promotion is only useful if the type variable is in invariant or contravariant - // position in the return type. + // Promotion is only useful if the type variable is in non-covariant position + // in the return type. if variance_in_return.is_covariant() { return None; } diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index f01689d32cc20..8997c25fe3b11 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -5131,13 +5131,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { )) }; - // Prefer generic class instances when narrowing. + // Prefer the declared type of generic classes or callables when narrowing. // // Splitting up this loop is not necessary for correctness, but leads to a slight // performance improvement. for narrowed_ty in narrow_targets .iter() - .filter(|ty| ty.class_specialization(db).is_some()) + .filter(|ty| ty.may_prefer_declared_type(db)) { if let Some(result) = try_narrow(*narrowed_ty) { return result; @@ -5145,7 +5145,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } for narrowed_ty in narrow_targets .iter() - .filter(|ty| ty.class_specialization(db).is_none()) + .filter(|ty| !ty.may_prefer_declared_type(db)) { if let Some(result) = try_narrow(*narrowed_ty) { return result;