Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ f: dict[list[Literal[1]], list[Literal[Color.RED]]] = {[1]: [Color.RED, Color.RE
reveal_type(f) # revealed: dict[list[Literal[1]], list[Color]]

class X[T]:
value: T

def __init__(self, value: T): ...

g: X[Literal[1]] = X(1)
Expand Down Expand Up @@ -366,7 +368,7 @@ class Z[T]:

z1: Z[Any] = Z(1)
# TODO: This should reveal `Z[Any]`.
reveal_type(z1) # revealed: Z[int]
reveal_type(z1) # revealed: Z[Literal[1]]
```

## PEP-604 annotations are supported
Expand Down Expand Up @@ -485,7 +487,7 @@ def f[T](x: T) -> list[T]:
return [x]

a = f("a")
reveal_type(a) # revealed: list[Literal["a"]]
reveal_type(a) # revealed: list[str]

b: list[int | Literal["a"]] = f("a")
reveal_type(b) # revealed: list[int | Literal["a"]]
Expand All @@ -499,10 +501,10 @@ reveal_type(d) # revealed: list[int | tuple[int, int]]
e: list[int] = f(True)
reveal_type(e) # revealed: list[int]

# error: [invalid-assignment] "Object of type `list[Literal["a"]]` is not assignable to `list[int]`"
# error: [invalid-assignment] "Object of type `list[str]` is not assignable to `list[int]`"
g: list[int] = f("a")

# error: [invalid-assignment] "Object of type `list[Literal["a"]]` is not assignable to `tuple[int]`"
# error: [invalid-assignment] "Object of type `list[str]` is not assignable to `tuple[int]`"
h: tuple[int] = f("a")

def f2[T: int](x: T) -> T:
Expand Down Expand Up @@ -607,7 +609,7 @@ def f3[T](x: T) -> list[T] | dict[T, T]:
return [x]

a = f(1)
reveal_type(a) # revealed: list[Literal[1]]
reveal_type(a) # revealed: list[int]

b: list[Any] = f(1)
reveal_type(b) # revealed: list[Any]
Expand Down Expand Up @@ -667,7 +669,7 @@ x4 = invariant(1)
reveal_type(x1) # revealed: Bivariant[Literal[1]]
reveal_type(x2) # revealed: Covariant[Literal[1]]
reveal_type(x3) # revealed: Contravariant[Literal[1]]
reveal_type(x4) # revealed: Invariant[Literal[1]]
reveal_type(x4) # revealed: Invariant[int]

x5: Bivariant[Any] = bivariant(1)
x6: Covariant[Any] = covariant(1)
Expand Down
26 changes: 10 additions & 16 deletions crates/ty_python_semantic/resources/mdtest/bidirectional.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,19 @@ python-version = "3.12"
```

```py
from typing import Literal

def list1[T](x: T) -> list[T]:
return [x]

l1 = list1(1)
l1: list[Literal[1]] = list1(1)
reveal_type(l1) # revealed: list[Literal[1]]
l2: list[int] = list1(1)
reveal_type(l2) # revealed: list[int]

# `list[Literal[1]]` and `list[int]` are incompatible, since `list[T]` is invariant in `T`.
# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`"
l2 = l1

intermediate = list1(1)
# TODO: the error will not occur if we can infer the type of `intermediate` to be `list[int]`
# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`"
l3: list[int] = intermediate
# TODO: it would be nice if this were `list[int]`
reveal_type(intermediate) # revealed: list[Literal[1]]
reveal_type(l3) # revealed: list[int]
l2 = list1(1)
reveal_type(l2) # revealed: list[int]

l4: list[int | str] | None = list1(1)
reveal_type(l4) # revealed: list[int | str]
l3: list[int | str] | None = list1(1)
reveal_type(l3) # revealed: list[int | str]

def _(l: list[int] | None = None):
l1 = l or list()
Expand Down Expand Up @@ -233,6 +224,9 @@ def _(flag: bool):

def _(c: C):
c.x = lst(1)

# TODO: Use the parameter type of `__set__` as type context to avoid this error.
# error: [invalid-assignment]
C.x = lst(1)
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -866,10 +866,10 @@ reveal_type(ParentDataclass.__init__)
reveal_type(ChildOfParentDataclass.__init__)

result_int = uses_dataclass(42)
reveal_type(result_int) # revealed: ChildOfParentDataclass[Literal[42]]
reveal_type(result_int) # revealed: ChildOfParentDataclass[int]

result_str = uses_dataclass("hello")
reveal_type(result_str) # revealed: ChildOfParentDataclass[Literal["hello"]]
reveal_type(result_str) # revealed: ChildOfParentDataclass[str]
```

## Descriptor-typed fields
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ class C[T, U]:
class D[V](C[V, int]):
def __init__(self, x: V) -> None: ...

reveal_type(D(1)) # revealed: D[int]
reveal_type(D(1)) # revealed: D[Literal[1]]
```

### Generic class inherits `__init__` from generic base class
Expand All @@ -334,8 +334,8 @@ class C[T, U]:
class D[T, U](C[T, U]):
pass

reveal_type(C(1, "str")) # revealed: C[int, str]
reveal_type(D(1, "str")) # revealed: D[int, str]
reveal_type(C(1, "str")) # revealed: C[Literal[1], Literal["str"]]
reveal_type(D(1, "str")) # revealed: D[Literal[1], Literal["str"]]
```

### Generic class inherits `__init__` from `dict`
Expand All @@ -358,7 +358,7 @@ context. But from the user's point of view, this is another example of the above
```py
class C[T, U](tuple[T, U]): ...

reveal_type(C((1, 2))) # revealed: C[int, int]
reveal_type(C((1, 2))) # revealed: C[Literal[1], Literal[2]]
```

### Upcasting a `tuple` to its `Sequence` supertype
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,10 +514,8 @@ For covariant types, such as `frozenset`, the ideal behaviour would be to not pr
types to their instance supertypes: doing so causes more false positives than it fixes:

```py
# TODO: better here would be `frozenset[Literal[1, 2, 3]]`
reveal_type(frozenset((1, 2, 3))) # revealed: frozenset[int]
# TODO: better here would be `frozenset[tuple[Literal[1], Literal[2], Literal[3]]]`
reveal_type(frozenset(((1, 2, 3),))) # revealed: frozenset[tuple[int, int, int]]
reveal_type(frozenset((1, 2, 3))) # revealed: frozenset[Literal[1, 2, 3]]
reveal_type(frozenset(((1, 2, 3),))) # revealed: frozenset[tuple[Literal[1], Literal[2], Literal[3]]]
Comment on lines +517 to +518
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wonderful!

```

Literals are always promoted for invariant containers such as `list`, however, even though this can
Expand Down
12 changes: 6 additions & 6 deletions crates/ty_python_semantic/resources/mdtest/typed_dict.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,21 +126,21 @@ Also, the value types ​​declared in a `TypedDict` affect generic call infere

```py
class Plot(TypedDict):
y: list[int]
x: list[int] | None
y: list[int | None]
x: list[int | None] | None

plot1: Plot = {"y": [1, 2, 3], "x": None}

def homogeneous_list[T](*args: T) -> list[T]:
return list(args)

reveal_type(homogeneous_list(1, 2, 3)) # revealed: list[Literal[1, 2, 3]]
reveal_type(homogeneous_list(1, 2, 3)) # revealed: list[int]
plot2: Plot = {"y": homogeneous_list(1, 2, 3), "x": None}
reveal_type(plot2["y"]) # revealed: list[int]
reveal_type(plot2["y"]) # revealed: list[int | None]

plot3: Plot = {"y": homogeneous_list(1, 2, 3), "x": homogeneous_list(1, 2, 3)}
reveal_type(plot3["y"]) # revealed: list[int]
reveal_type(plot3["x"]) # revealed: list[int] | None
reveal_type(plot3["y"]) # revealed: list[int | None]
reveal_type(plot3["x"]) # revealed: list[int | None] | None

Y = "y"
X = "x"
Expand Down
56 changes: 53 additions & 3 deletions crates/ty_python_semantic/src/types/call/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ use crate::types::function::{
OverloadLiteral,
};
use crate::types::generics::{
InferableTypeVars, Specialization, SpecializationBuilder, SpecializationError,
GenericContextTypeVar, InferableTypeVars, Specialization, SpecializationBuilder,
SpecializationError,
};
use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters};
use crate::types::tuple::{TupleLength, TupleType};
Expand Down Expand Up @@ -2762,6 +2763,51 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
.or(self.signature.return_ty)
.zip(self.call_expression_tcx.annotation);

let tcx_specialization = self
.call_expression_tcx
.annotation
.and_then(|annotation| annotation.class_specialization(self.db));

let promote_literals = |typevar: GenericContextTypeVar<'db>, ty: Type<'db>| -> Type<'db> {
let bound_typevar = typevar.bound_typevar();

if typevar.is_inherited() && bound_typevar.variance(self.db).is_invariant() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does is_inherited mean here? Why does it matter whether a typevar is inherited or not?

(It would be great if is_inherited could have a doc-comment explaining what it does 😃)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I think I see -- you're using the term to mean that the typevar is associated with a class's generic context rather than the generic context of a function or method? So it doesn't have any thing to do with the generic context of a class's base classes etc.?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's right, though I didn't invent that terminology.

Ideally we wouldn't need to special case inherited type variables, but we don't currently have a robust way to get access to the bound type of a constructor method (the synthesized_constructor_return_ty).

return ty.promote_literals(
self.db,
TypeContext::new(
tcx_specialization
.and_then(|specialization| specialization.get(self.db, bound_typevar)),
),
);
}

let Some(return_specialization) = self
.signature
.return_ty
.and_then(|return_ty| return_ty.class_specialization(self.db))
else {
return ty;
};

if let Some((typevar, _)) = return_specialization
.generic_context(self.db)
.variables(self.db)
.zip(return_specialization.types(self.db))
.find(|(_, ty)| **ty == Type::TypeVar(bound_typevar))
.filter(|(typevar, _)| typevar.variance(self.db).is_invariant())
{
return ty.promote_literals(
self.db,
TypeContext::new(
tcx_specialization
.and_then(|specialization| specialization.get(self.db, typevar)),
),
);
}

ty
};

self.inferable_typevars = generic_context.inferable_typevars(self.db);
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);

Expand Down Expand Up @@ -2811,7 +2857,9 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
}

// Build the specialization first without inferring the complete type context.
let isolated_specialization = builder.build(generic_context, self.call_expression_tcx);
let isolated_specialization = builder
.mapped(generic_context, promote_literals)
.build(generic_context);
let isolated_return_ty = self
.return_ty
.apply_specialization(self.db, isolated_specialization);
Expand All @@ -2836,7 +2884,9 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
builder.infer(return_ty, call_expression_tcx).ok()?;

// Otherwise, build the specialization again after inferring the complete type context.
let specialization = builder.build(generic_context, self.call_expression_tcx);
let specialization = builder
.mapped(generic_context, promote_literals)
.build(generic_context);
let return_ty = return_ty.apply_specialization(self.db, specialization);

Some((Some(specialization), return_ty))
Expand Down
2 changes: 1 addition & 1 deletion crates/ty_python_semantic/src/types/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1475,7 +1475,7 @@ impl<'db> ClassLiteral<'db> {
/// promote any typevars that are inferred as a literal to the corresponding instance type.
fn inherited_generic_context(self, db: &'db dyn Db) -> Option<GenericContext<'db>> {
self.generic_context(db)
.map(|generic_context| generic_context.promote_literals(db))
.map(|generic_context| generic_context.set_inherited(db))
}

pub(super) fn file(self, db: &dyn Db) -> File {
Expand Down
70 changes: 39 additions & 31 deletions crates/ty_python_semantic/src/types/generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,21 +184,31 @@ impl<'a, 'db> InferableTypeVars<'a, 'db> {
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize)]
pub struct GenericContextTypeVar<'db> {
bound_typevar: BoundTypeVarInstance<'db>,
should_promote_literals: bool,
is_inherited: bool,
}

impl<'db> GenericContextTypeVar<'db> {
fn new(bound_typevar: BoundTypeVarInstance<'db>) -> Self {
Self {
bound_typevar,
should_promote_literals: false,
is_inherited: false,
}
}

fn promote_literals(mut self) -> Self {
self.should_promote_literals = true;
/// Returns `true` if this type variable is inherited from the generic context of an
/// outer class.
pub fn is_inherited(&self) -> bool {
self.is_inherited
}

fn set_inherited(mut self) -> Self {
self.is_inherited = true;
self
}

pub fn bound_typevar(&self) -> BoundTypeVarInstance<'db> {
self.bound_typevar
}
}

/// A list of formal type variables for a generic function, class, or type alias.
Expand Down Expand Up @@ -262,14 +272,13 @@ impl<'db> GenericContext<'db> {
Self::from_variables(db, type_params.into_iter().map(GenericContextTypeVar::new))
}

/// Returns a copy of this generic context where we will promote literal types in any inferred
/// specializations.
pub(crate) fn promote_literals(self, db: &'db dyn Db) -> Self {
/// Mark the variables in this generic context as inherited from an outer class definition.
pub(crate) fn set_inherited(self, db: &'db dyn Db) -> Self {
Self::from_variables(
db,
self.variables_inner(db)
.values()
.map(|variable| variable.promote_literals()),
.map(|variable| variable.set_inherited()),
)
}

Expand Down Expand Up @@ -1321,31 +1330,30 @@ impl<'db> SpecializationBuilder<'db> {
&self.types
}

pub(crate) fn build(
&mut self,
pub(crate) fn mapped(
&self,
generic_context: GenericContext<'db>,
tcx: TypeContext<'db>,
) -> Specialization<'db> {
let tcx_specialization = tcx
.annotation
.and_then(|annotation| annotation.class_specialization(self.db));

let types =
(generic_context.variables_inner(self.db).iter()).map(|(identity, variable)| {
let mut ty = self.types.get(identity).copied();

// When inferring a specialization for a generic class typevar from a constructor call,
// promote any typevars that are inferred as a literal to the corresponding instance type.
if variable.should_promote_literals {
let tcx = tcx_specialization.and_then(|specialization| {
specialization.get(self.db, variable.bound_typevar)
});

ty = ty.map(|ty| ty.promote_literals(self.db, TypeContext::new(tcx)));
}
f: impl Fn(GenericContextTypeVar<'db>, Type<'db>) -> Type<'db>,
) -> Self {
let mut types = self.types.clone();
for (identity, variable) in generic_context.variables_inner(self.db) {
if let Some(ty) = types.get_mut(identity) {
*ty = f(*variable, *ty);
}
}

Self {
db: self.db,
inferable: self.inferable,
types,
}
}

ty
});
pub(crate) fn build(&mut self, generic_context: GenericContext<'db>) -> Specialization<'db> {
let types = generic_context
.variables_inner(self.db)
.iter()
.map(|(identity, _)| self.types.get(identity).copied());

// TODO Infer the tuple spec for a tuple type
generic_context.specialize_partial(self.db, types)
Expand Down
Loading
Loading