Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,7 @@ k: list[tuple[list[int], ...]] | None = [([],), ([1, 2], [3, 4]), ([5], [6], [7]
reveal_type(k) # revealed: list[tuple[list[int], ...]]

l: tuple[list[int], *tuple[list[typing.Any], ...], list[str]] | None = ([1, 2, 3], [4, 5, 6], [7, 8, 9], ["10", "11", "12"])
# TODO: this should be `tuple[list[int], list[Any | int], list[Any | int], list[str]]`
reveal_type(l) # revealed: tuple[list[Unknown | int], list[Unknown | int], list[Unknown | int], list[Unknown | str]]
reveal_type(l) # revealed: tuple[list[int], list[Any | int], list[Any | int], list[str]]

type IntList = list[int]

Expand Down Expand Up @@ -416,13 +415,14 @@ a = f("a")
reveal_type(a) # revealed: list[Literal["a"]]

b: list[int | Literal["a"]] = f("a")
reveal_type(b) # revealed: list[int | Literal["a"]]
reveal_type(b) # revealed: list[Literal["a"] | int]

c: list[int | str] = f("a")
reveal_type(c) # revealed: list[int | str]
reveal_type(c) # revealed: list[str | int]

d: list[int | tuple[int, int]] = f((1, 2))
reveal_type(d) # revealed: list[int | tuple[int, int]]
# TODO: We could avoid reordering the union elements here.
reveal_type(d) # revealed: list[tuple[int, int] | int]
Copy link
Member Author

@ibraheemdev ibraheemdev Oct 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The churn here could be avoided if we performed a third specialization, inferring the call expression annotation first before the argument types, but I'm not sure it's worth it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a TODO below indicating where we'd make the change if we wanted to fix this, so I'm 👍 to this. Though maybe add a TODO comment here, too, calling out that we realize that we reordered the union, and that we know how we'd fix it?


e: list[int] = f(True)
reveal_type(e) # revealed: list[int]
Expand All @@ -437,8 +437,49 @@ def f2[T: int](x: T) -> T:
return x

i: int = f2(True)
reveal_type(i) # revealed: int
reveal_type(i) # revealed: Literal[True]

j: int | str = f2(True)
reveal_type(j) # revealed: Literal[True]
```

Types are not widened unnecessarily:

```py
def id[T](x: T) -> T:
return x

def lst[T](x: T) -> list[T]:
return [x]

def _(i: int):
a: int | None = i
b: int | None = id(i)
c: int | str | None = id(i)
reveal_type(a) # revealed: int
reveal_type(b) # revealed: int
reveal_type(c) # revealed: int

a: list[int | None] | None = [i]
b: list[int | None] | None = id([i])
c: list[int | None] | int | None = id([i])
reveal_type(a) # revealed: list[int | None]
# TODO: these should reveal `list[int | None]`
# we currently do not use the call expression annotation as type context for argument inference
reveal_type(b) # revealed: list[Unknown | int]
reveal_type(c) # revealed: list[Unknown | int]

a: list[int | None] | None = [i]
b: list[int | None] | None = lst(i)
c: list[int | None] | int | None = lst(i)
reveal_type(a) # revealed: list[int | None]
reveal_type(b) # revealed: list[int | None]
reveal_type(c) # revealed: list[int | None]

a: list | None = []
b: list | None = id([])
c: list | int | None = id([])
reveal_type(a) # revealed: list[Unknown]
reveal_type(b) # revealed: list[Unknown]
reveal_type(c) # revealed: list[Unknown]
```
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Member:
role: str = field(default="user")
tag: str | None = field(default=None, init=False)

# revealed: (self: Member, name: str, role: str = str) -> None
# revealed: (self: Member, name: str, role: str = Literal["user"]) -> None
reveal_type(Member.__init__)

alice = Member(name="Alice", role="admin")
Expand All @@ -37,7 +37,7 @@ class Data:
content: list[int] = field(default_factory=list)
timestamp: datetime = field(default_factory=datetime.now, init=False)

# revealed: (self: Data, content: list[int] = list[int]) -> None
# revealed: (self: Data, content: list[int] = Unknown) -> None
reveal_type(Data.__init__)

data = Data([1, 2, 3])
Expand All @@ -64,7 +64,7 @@ class Person:
role: str = field(default="user", kw_only=True)

# TODO: this would ideally show a default value of `None` for `age`
# revealed: (self: Person, name: str, *, age: int | None = int | None, role: str = str) -> None
# revealed: (self: Person, name: str, *, age: int | None = None, role: str = Literal["user"]) -> None
reveal_type(Person.__init__)

alice = Person(role="admin", name="Alice")
Expand Down
2 changes: 1 addition & 1 deletion crates/ty_python_semantic/resources/mdtest/typed_dict.md
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,7 @@ grandchild: Node = {"name": "grandchild", "parent": child}

nested: Node = {"name": "n1", "parent": {"name": "n2", "parent": {"name": "n3", "parent": None}}}

# TODO: this should be an error (invalid type for `name` in innermost node)
# error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Node`: value of type `Literal[3]`"
nested_invalid: Node = {"name": "n1", "parent": {"name": "n2", "parent": {"name": 3, "parent": None}}}
```

Expand Down
35 changes: 24 additions & 11 deletions crates/ty_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1233,24 +1233,37 @@ impl<'db> Type<'db> {
if yes { self.negate(db) } else { *self }
}

/// Remove the union elements that are not related to `target`.
pub(crate) fn filter_disjoint_elements(
/// If the type is a union, filters union elements based on the provided predicate.
///
/// Otherwise, returns the type unchanged.
pub(crate) fn filter_union(
self,
db: &'db dyn Db,
target: Type<'db>,
inferable: InferableTypeVars<'_, 'db>,
f: impl FnMut(&Type<'db>) -> bool,
) -> Type<'db> {
if let Type::Union(union) = self {
union.filter(db, |elem| {
!elem
.when_disjoint_from(db, target, inferable)
.is_always_satisfied()
})
union.filter(db, f)
} else {
self
}
}

/// If the type is a union, removes union elements that are disjoint from `target`.
///
/// Otherwise, returns the type unchanged.
pub(crate) fn filter_disjoint_elements(
self,
db: &'db dyn Db,
target: Type<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> Type<'db> {
self.filter_union(db, |elem| {
!elem
.when_disjoint_from(db, target, inferable)
.is_always_satisfied()
})
}

/// Returns the fallback instance type that a literal is an instance of, or `None` if the type
/// is not a literal.
pub(crate) fn literal_fallback_instance(self, db: &'db dyn Db) -> Option<Type<'db>> {
Expand Down Expand Up @@ -11185,9 +11198,9 @@ impl<'db> UnionType<'db> {
pub(crate) fn filter(
self,
db: &'db dyn Db,
filter_fn: impl FnMut(&&Type<'db>) -> bool,
mut f: impl FnMut(&Type<'db>) -> bool,
) -> Type<'db> {
Self::from_elements(db, self.elements(db).iter().filter(filter_fn))
Self::from_elements(db, self.elements(db).iter().filter(|ty| f(ty)))
Comment on lines -11188 to +11203
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make the callback take in Type<'db> instead of &Type<'db>? I find that much more ergonomic, and I think it just requires a .copied() before the filter. (Ditto above in Type::filter)

Copy link
Member

@AlexWaygood AlexWaygood Oct 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ibraheem's current signature matches the signature for UnionType::map(). I don't have a strong objection to changing that signature, but I think it's nice to keep the signature of map() consistent with the signature of filter(). (And IIRC, I made the callback passed to UnionType::map() take &Type<'db> because it seemed generally to make most callsites more ergonomic than if it took Type<'db>. Due to the fact that some_union.elements(db).iter() calls return Iterator<Item = &Type<'db>>s rather than Iterator<Item = Type<'db>>s)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 I agree with keeping them consistent, but feel strongly enough about avoiding &Type when we can that I might tackle that as a separate follow on. You know, in my copious spare time.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some Type methods that take &self, so we should probably address them all at once. Only changing filter would mean .filter(Type::is_typed_dict) no longer works, for example.

Copy link
Member

@AlexWaygood AlexWaygood Oct 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the last time we discussed it, we decided that it was pretty borderline but that it was probably slightly more efficient for methods like that to take &self rather than self, despite the fact that Type is Copy, given how big Type is

}

pub(crate) fn map_with_boundness(
Expand Down
77 changes: 50 additions & 27 deletions crates/ty_python_semantic/src/types/call/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2524,20 +2524,23 @@ struct ArgumentTypeChecker<'a, 'db> {
argument_matches: &'a [MatchedArgument<'db>],
parameter_tys: &'a mut [Option<Type<'db>>],
call_expression_tcx: &'a TypeContext<'db>,
return_ty: Type<'db>,
errors: &'a mut Vec<BindingError<'db>>,

inferable_typevars: InferableTypeVars<'db, 'db>,
specialization: Option<Specialization<'db>>,
}

impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
#[expect(clippy::too_many_arguments)]
fn new(
db: &'db dyn Db,
signature: &'a Signature<'db>,
arguments: &'a CallArguments<'a, 'db>,
argument_matches: &'a [MatchedArgument<'db>],
parameter_tys: &'a mut [Option<Type<'db>>],
call_expression_tcx: &'a TypeContext<'db>,
return_ty: Type<'db>,
errors: &'a mut Vec<BindingError<'db>>,
) -> Self {
Self {
Expand All @@ -2547,6 +2550,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
argument_matches,
parameter_tys,
call_expression_tcx,
return_ty,
errors,
inferable_typevars: InferableTypeVars::None,
specialization: None,
Expand Down Expand Up @@ -2588,25 +2592,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
// TODO: Use the list of inferable typevars from the generic context of the callable.
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);

// Note that we infer the annotated type _before_ the arguments if this call is part of
// an annotated assignment, to closer match the order of any unions written in the type
// annotation.
if let Some(return_ty) = self.signature.return_ty
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation
{
match call_expression_tcx {
// A type variable is not a useful type-context for expression inference, and applying it
// to the return type can lead to confusing unions in nested generic calls.
Type::TypeVar(_) => {}

_ => {
// Ignore any specialization errors here, because the type context is only used as a hint
// to infer a more assignable return type.
let _ = builder.infer(return_ty, call_expression_tcx);
}
}
}

let parameters = self.signature.parameters();
for (argument_index, adjusted_argument_index, _, argument_type) in
self.enumerate_argument_types()
Expand All @@ -2631,7 +2616,41 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
}
}

self.specialization = Some(builder.build(generic_context, *self.call_expression_tcx));
// Build the specialization first without inferring the type context.
let isolated_specialization = builder.build(generic_context, *self.call_expression_tcx);
let isolated_return_ty = self
.return_ty
.apply_specialization(self.db, isolated_specialization);

let mut try_infer_tcx = || {
let return_ty = self.signature.return_ty?;
let call_expression_tcx = self.call_expression_tcx.annotation?;

// A type variable is not a useful type-context for expression inference, and applying it
// to the return type can lead to confusing unions in nested generic calls.
if call_expression_tcx.is_type_var() {
return None;
}

// If the return type is already assignable to the annotated type, we can ignore the
// type context and prefer the narrower inferred type.
if isolated_return_ty.is_assignable_to(self.db, call_expression_tcx) {
return None;
}

// TODO: Ideally we would infer the annotated type _before_ the arguments if this call is part of an
// annotated assignment, to closer match the order of any unions written in the type annotation.
builder.infer(return_ty, call_expression_tcx).ok()?;

// Otherwise, build the specialization again after inferring the type context.
let specialization = builder.build(generic_context, *self.call_expression_tcx);
let return_ty = return_ty.apply_specialization(self.db, specialization);

Some((Some(specialization), return_ty))
};

(self.specialization, self.return_ty) =
try_infer_tcx().unwrap_or((Some(isolated_specialization), isolated_return_ty));
}

fn check_argument_type(
Expand Down Expand Up @@ -2826,8 +2845,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
}
}

fn finish(self) -> (InferableTypeVars<'db, 'db>, Option<Specialization<'db>>) {
(self.inferable_typevars, self.specialization)
fn finish(
self,
) -> (
InferableTypeVars<'db, 'db>,
Option<Specialization<'db>>,
Type<'db>,
) {
(self.inferable_typevars, self.specialization, self.return_ty)
}
}

Expand Down Expand Up @@ -2985,18 +3010,16 @@ impl<'db> Binding<'db> {
&self.argument_matches,
&mut self.parameter_tys,
call_expression_tcx,
self.return_ty,
&mut self.errors,
);

// If this overload is generic, first see if we can infer a specialization of the function
// from the arguments that were passed in.
checker.infer_specialization();

checker.check_argument_types();
(self.inferable_typevars, self.specialization) = checker.finish();
if let Some(specialization) = self.specialization {
self.return_ty = self.return_ty.apply_specialization(db, specialization);
}

(self.inferable_typevars, self.specialization, self.return_ty) = checker.finish();
}

pub(crate) fn set_return_type(&mut self, return_ty: Type<'db>) {
Expand Down
11 changes: 7 additions & 4 deletions crates/ty_python_semantic/src/types/generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1229,6 +1229,7 @@ impl<'db> SpecializationBuilder<'db> {
let tcx = tcx_specialization.and_then(|specialization| {
specialization.get(self.db, variable.bound_typevar)
});

ty = ty.map(|ty| ty.promote_literals(self.db, TypeContext::new(tcx)));
}

Expand All @@ -1251,7 +1252,7 @@ impl<'db> SpecializationBuilder<'db> {
pub(crate) fn infer(
&mut self,
formal: Type<'db>,
mut actual: Type<'db>,
actual: Type<'db>,
) -> Result<(), SpecializationError<'db>> {
if formal == actual {
return Ok(());
Expand Down Expand Up @@ -1282,9 +1283,11 @@ impl<'db> SpecializationBuilder<'db> {
return Ok(());
}

// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` to `int`.
// So, here we remove the union elements that are not related to `formal`.
actual = actual.filter_disjoint_elements(self.db, formal, self.inferable);
// Remove the union elements that are not related to `formal`.
//
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T`
// to `int`.
let actual = actual.filter_disjoint_elements(self.db, formal, self.inferable);

match (formal, actual) {
// TODO: We haven't implemented a full unification solver yet. If typevars appear in
Expand Down
2 changes: 1 addition & 1 deletion crates/ty_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ impl<'db> TypeContext<'db> {
.and_then(|ty| ty.known_specialization(db, known_class))
}

pub(crate) fn map_annotation(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
pub(crate) fn map(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
Self {
annotation: self.annotation.map(f),
}
Expand Down
Loading
Loading