diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index 464622e79419b..2cc55132621af 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -8,13 +8,13 @@ when generics are involved, the type of an outer expression can sometimes be use inner expressions. Bidirectional type inference is a mechanism that propagates such "expected types" to the inference of inner expressions. -## Propagating target type annotation - ```toml [environment] python-version = "3.12" ``` +## Propagating target type annotation + ```py from typing import Literal @@ -80,11 +80,6 @@ def _() -> TD: ## Propagating return type annotation -```toml -[environment] -python-version = "3.12" -``` - ```py from typing import overload, Callable @@ -192,11 +187,6 @@ def f() -> list[Literal[1]]: ## Instance attributes -```toml -[environment] -python-version = "3.12" -``` - Both meta and class/instance attribute annotations are used as type context: ```py @@ -240,13 +230,110 @@ def _(xy: X | Y): xy.x = reveal_type([1]) # revealed: list[int] ``` -## Class constructor parameters +## Overload evaluation -```toml -[environment] -python-version = "3.12" +The type context of all matching overloads are considered during argument inference: + +```py +from typing import overload, TypedDict + +def int_or_str() -> int | str: + raise NotImplementedError + +@overload +def f1(x: list[int | None], y: int) -> int: ... +@overload +def f1(x: list[int | str], y: str) -> str: ... +def f1(x, y) -> int | str: + raise NotImplementedError + +# TODO: We should reveal `list[int]` here. +x1 = f1(reveal_type([1]), 1) # revealed: list[int] +reveal_type(x1) # revealed: int + +x2 = f1(reveal_type([1]), int_or_str()) # revealed: list[int] +reveal_type(x2) # revealed: int | str + +@overload +def f2[T](x: T, y: int) -> T: ... +@overload +def f2(x: list[int | str], y: str) -> object: ... +def f2(x, y) -> object: ... + +x3 = f2(reveal_type([1]), 1) # revealed: list[int] +reveal_type(x3) # revealed: list[int] + +class TD(TypedDict): + x: list[int | str] + +class TD2(TypedDict): + x: list[int | None] + +@overload +def f3(x: TD, y: int) -> int: ... +@overload +def f3(x: TD2, y: str) -> str: ... +def f3(x, y) -> object: ... + +# TODO: We should reveal `TD2` here. +x4 = f3(reveal_type({"x": [1]}), "1") # revealed: dict[str, list[int]] +reveal_type(x4) # revealed: str + +x5 = f3(reveal_type({"x": [1]}), int_or_str()) # revealed: dict[str, list[int]] +reveal_type(x5) # revealed: int | str + +@overload +def f4[T](_: list[T]) -> list[T]: ... +@overload +def f4(_: list[str]) -> list[str]: ... +def f4(_: object): ... + +x6 = f4(reveal_type([])) # revealed: list[Unknown] +reveal_type(x6) # revealed: list[Unknown] + +@overload +def f5(_: list[int | str]) -> int: ... +@overload +def f5(_: set[int | str]) -> str: ... +def f5(_) -> object: + raise NotImplementedError + +def list_or_set[T](x: T) -> list[T] | set[T]: + raise NotImplementedError + +# TODO: We should reveal `list[int | str] | set[int | str]` here. +x7 = f5(reveal_type(list_or_set(1))) # revealed: list[int] | set[int] +reveal_type(x7) # revealed: int | str + +@overload +def f6(_: list[int | None]) -> int: ... +@overload +def f6(_: set[int | str]) -> str: ... +def f6(_) -> object: + raise NotImplementedError + +def list_or_set2[T, U](x: T, y: U) -> list[T] | set[U]: + raise NotImplementedError + +# TODO: We should not error here. +# error: [no-matching-overload] +x8 = f6(reveal_type(list_or_set2(1, 1))) # revealed: list[int] | set[int] +reveal_type(x8) # revealed: Unknown + +@overload +def f7(y: list[int | str]) -> list[int | str]: ... +@overload +def f7[T](y: list[T]) -> list[T]: ... +def f7(y: object) -> object: + raise NotImplementedError + +# TODO: We should reveal `list[int | str]` here. +x9 = f7(reveal_type(["Sheet1"])) # revealed: list[str] +reveal_type(x9) # revealed: list[int | str] ``` +## Class constructor parameters + The parameters of both `__init__` and `__new__` are used as type context sources for constructor calls: @@ -269,11 +356,6 @@ A(f([])) ## Conditional expressions -```toml -[environment] -python-version = "3.12" -``` - The type context is propagated through both branches of conditional expressions: ```py @@ -290,11 +372,6 @@ def _(flag: bool): ## Dunder Calls -```toml -[environment] -python-version = "3.12" -``` - The key and value parameters types are used as type context for `__setitem__` dunder calls: ```py @@ -387,11 +464,6 @@ def _(x: Intersection[X, Y]): ## Multi-inference diagnostics -```toml -[environment] -python-version = "3.12" -``` - Diagnostics unrelated to the type-context are only reported once: ```py diff --git a/crates/ty_python_semantic/resources/mdtest/typed_dict.md b/crates/ty_python_semantic/resources/mdtest/typed_dict.md index 7340dcc006fa5..6034ccf428735 100644 --- a/crates/ty_python_semantic/resources/mdtest/typed_dict.md +++ b/crates/ty_python_semantic/resources/mdtest/typed_dict.md @@ -1566,10 +1566,14 @@ config["port"] = 80 from typing import TypedDict from typing_extensions import NotRequired +class Inner(TypedDict): + inner: int + class Person(TypedDict): name: str age: int | None extra: NotRequired[str] + inner: NotRequired[Inner] def _(p: Person) -> None: reveal_type(p.keys()) # revealed: dict_keys[str, object] @@ -1590,6 +1594,10 @@ def _(p: Person) -> None: # The type of the default parameter can be anything: reveal_type(p.get("extra", 0)) # revealed: str | Literal[0] + # Even another typed dict: + # TODO: This should evaluate to `Inner`. + reveal_type(p.get("inner", {"inner": 0})) # revealed: Inner | dict[str, int] + # We allow access to unknown keys (they could be set for a subtype of Person) reveal_type(p.get("unknown")) # revealed: Unknown | None reveal_type(p.get("unknown", "default")) # revealed: Unknown | Literal["default"] diff --git a/crates/ty_python_semantic/src/types/call/arguments.rs b/crates/ty_python_semantic/src/types/call/arguments.rs index c944bdc3aa29e..bf03f68058280 100644 --- a/crates/ty_python_semantic/src/types/call/arguments.rs +++ b/crates/ty_python_semantic/src/types/call/arguments.rs @@ -3,11 +3,12 @@ use std::fmt::Display; use itertools::{Either, Itertools}; use ruff_python_ast as ast; +use rustc_hash::FxHashMap; use crate::Db; use crate::types::enums::{enum_member_literals, enum_metadata}; use crate::types::tuple::Tuple; -use crate::types::{KnownClass, Type}; +use crate::types::{KnownClass, Type, TypeContext}; /// Maximum number of expanded types that can be generated from a single tuple's /// Cartesian product in [`expand_type`]. @@ -43,11 +44,72 @@ pub(crate) enum Argument<'a> { #[derive(Clone, Debug, Default)] pub(crate) struct CallArguments<'a, 'db> { arguments: Vec>, - types: Vec>>, + types: Vec>, +} + +/// Inferred types for a given argument. +/// +/// Note that a single argument may produce multiple distinct inferred types when inferred +/// with type context across multiple bindings. +#[derive(Clone, Debug, Default)] +pub(crate) struct CallArgumentTypes<'db> { + fallback_type: Option>, + types: FxHashMap, Type<'db>>, +} + +impl<'db> CallArgumentTypes<'db> { + pub(crate) fn new(fallback_ty: Option>) -> Self { + Self { + fallback_type: fallback_ty, + types: FxHashMap::default(), + } + } + + /// Returns the most appropriate type of this argument when there is no specific declared type. + pub(crate) fn get_default(&self) -> Option> { + // If this type was inferred against exactly one declared type, or was inferred against + // multiple, but resulted in a single inferred type, we have an exact type to return. + if let Ok(exact_ty) = self + .types + .values() + .exactly_one() + .or(self.types.values().all_equal_value()) + { + return Some(*exact_ty); + } + + self.fallback_type + } + + /// Returns the type of this argument when inferred against the provided declared type. + pub(crate) fn get_for_declared_type(&self, tcx: Type<'db>) -> Type<'db> { + self.types + .get(&tcx) + .copied() + .or(self.get_default()) + .unwrap_or(Type::unknown()) + } + + /// Insert the type of this argument when inferred with the provided type context. + pub(crate) fn insert(&mut self, tcx: impl Into>, ty: Type<'db>) { + match tcx.into().annotation { + None => self.fallback_type = Some(ty), + Some(tcx) => { + self.types.insert(tcx, ty); + } + } + } + + pub(crate) fn iter(&self) -> impl Iterator, Type<'db>)> { + self.types + .iter() + .map(|(tcx, ty)| (TypeContext::new(Some(*tcx)), *ty)) + .chain(self.fallback_type.map(|ty| (TypeContext::default(), ty))) + } } impl<'a, 'db> CallArguments<'a, 'db> { - fn new(arguments: Vec>, types: Vec>>) -> Self { + fn new(arguments: Vec>, types: Vec>) -> Self { debug_assert!(arguments.len() == types.len()); Self { arguments, types } } @@ -121,7 +183,11 @@ impl<'a, 'db> CallArguments<'a, 'db> { /// Create a [`CallArguments`] from an iterator over non-variadic positional argument types. pub(crate) fn positional(positional_tys: impl IntoIterator>) -> Self { - let types: Vec<_> = positional_tys.into_iter().map(Some).collect(); + let types: Vec<_> = positional_tys + .into_iter() + .map(Some) + .map(CallArgumentTypes::new) + .collect(); let arguments = vec![Argument::Positional; types.len()]; Self { arguments, types } } @@ -130,14 +196,10 @@ impl<'a, 'db> CallArguments<'a, 'db> { self.arguments.len() } - pub(crate) fn types(&self) -> &[Option>] { + pub(crate) fn types(&self) -> &[CallArgumentTypes<'db>] { &self.types } - pub(crate) fn iter_types(&self) -> impl Iterator> { - self.types.iter().map(|ty| ty.unwrap_or_else(Type::unknown)) - } - /// Prepend an optional extra synthetic argument (for a `self` or `cls` parameter) to the front /// of this argument list. (If `bound_self` is none, we return the argument list /// unmodified.) @@ -146,8 +208,8 @@ impl<'a, 'db> CallArguments<'a, 'db> { let arguments = std::iter::once(Argument::Synthetic) .chain(self.arguments.iter().copied()) .collect(); - let types = std::iter::once(bound_self) - .chain(self.types.iter().copied()) + let types = std::iter::once(CallArgumentTypes::new(bound_self)) + .chain(self.types.iter().cloned()) .collect(); Cow::Owned(CallArguments { arguments, types }) } else { @@ -155,13 +217,15 @@ impl<'a, 'db> CallArguments<'a, 'db> { } } - pub(crate) fn iter(&self) -> impl Iterator, Option>)> + '_ { - (self.arguments.iter().copied()).zip(self.types.iter().copied()) + pub(crate) fn iter( + &self, + ) -> impl Iterator, &CallArgumentTypes<'db>)> + '_ { + (self.arguments.iter().copied()).zip(self.types.iter()) } pub(crate) fn iter_mut( &mut self, - ) -> impl Iterator, &mut Option>)> + '_ { + ) -> impl Iterator, &mut CallArgumentTypes<'db>)> + '_ { (self.arguments.iter().copied()).zip(self.types.iter_mut()) } @@ -191,7 +255,7 @@ impl<'a, 'db> CallArguments<'a, 'db> { /// This is useful to avoid cloning the initial types vector if none of the types can be /// expanded. enum ExpandingState<'a, 'b, 'db> { - Initial(&'b Vec>>), + Initial(&'b Vec>), Expanded(Vec>), } @@ -203,7 +267,7 @@ impl<'a, 'db> CallArguments<'a, 'db> { } } - fn iter(&self) -> impl Iterator>]> + '_ { + fn iter(&self) -> impl Iterator]> + '_ { match self { ExpandingState::Initial(types) => { Either::Left(std::iter::once(types.as_slice())) @@ -228,10 +292,17 @@ impl<'a, 'db> CallArguments<'a, 'db> { // Find the next type that can be expanded. let expanded_types = loop { let arg_type = self.types.get(index)?; - if let Some(arg_type) = arg_type { - if let Some(expanded_types) = expand_type(db, *arg_type) { - break expanded_types; - } + // TODO: For types inferred multiple times with distinct type context, we currently only + // expand the default inference. Note that direct expansion of a type inferred against a + // given declared type would not likely be assignable to other declared types without + // re-inference, and so a more complete implementation would likely have to re-infer the + // argument type against the union a given subset of type contexts before expansion. However, + // this only shows up in very convoluted instances of generic call inference across multiple + // overloads, and is unlikely to happen in practice. + if let Some(arg_type) = arg_type.get_default() + && let Some(expanded_types) = expand_type(db, arg_type) + { + break expanded_types; } index += 1; }; @@ -250,7 +321,7 @@ impl<'a, 'db> CallArguments<'a, 'db> { for pre_expanded_types in state.iter() { for subtype in &expanded_types { let mut new_expanded_types = pre_expanded_types.to_vec(); - new_expanded_types[index] = Some(*subtype); + new_expanded_types[index] = CallArgumentTypes::new(Some(*subtype)); expanded_arguments.push(CallArguments::new( self.arguments.clone(), new_expanded_types, @@ -277,6 +348,24 @@ impl<'a, 'db> CallArguments<'a, 'db> { } pub(super) fn display(&self, db: &'db dyn Db) -> impl Display { + struct DisplayCallArgumentTypes<'a, 'db> { + types: &'a CallArgumentTypes<'db>, + db: &'db dyn Db, + } + + impl std::fmt::Display for DisplayCallArgumentTypes<'_, '_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_map() + .entries(self.types.iter().map(|(tcx, ty)| { + ( + tcx.annotation.as_ref().map(|ty| ty.display(self.db)), + ty.display(self.db), + ) + })) + .finish() + } + } + struct DisplayCallArguments<'a, 'db> { call_arguments: &'a CallArguments<'a, 'db>, db: &'db dyn Db, @@ -285,30 +374,32 @@ impl<'a, 'db> CallArguments<'a, 'db> { impl std::fmt::Display for DisplayCallArguments<'_, '_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str("(")?; - for (index, (argument, ty)) in self.call_arguments.iter().enumerate() { + for (index, (argument, types)) in self.call_arguments.iter().enumerate() { if index > 0 { write!(f, ", ")?; } match argument { - Argument::Synthetic => write!( - f, - "self: {}", - ty.unwrap_or_else(Type::unknown).display(self.db) - )?, + Argument::Synthetic => { + write!( + f, + "self: {}", + DisplayCallArgumentTypes { types, db: self.db } + )?; + } Argument::Positional => { - write!(f, "{}", ty.unwrap_or_else(Type::unknown).display(self.db))?; + write!(f, "{}", DisplayCallArgumentTypes { types, db: self.db })?; } Argument::Variadic => { - write!(f, "*{}", ty.unwrap_or_else(Type::unknown).display(self.db))?; + write!(f, "*{}", DisplayCallArgumentTypes { types, db: self.db })?; } Argument::Keyword(name) => write!( f, "{}={}", name, - ty.unwrap_or_else(Type::unknown).display(self.db) + DisplayCallArgumentTypes { types, db: self.db } )?, Argument::Keywords => { - write!(f, "**{}", ty.unwrap_or_else(Type::unknown).display(self.db))?; + write!(f, "**{}", DisplayCallArgumentTypes { types, db: self.db })?; } } } @@ -344,7 +435,10 @@ impl<'a, 'db> FromIterator<(Argument<'a>, Option>)> for CallArguments< where T: IntoIterator, Option>)>, { - let (arguments, types) = iter.into_iter().unzip(); + let (arguments, types) = iter + .into_iter() + .map(|(arg, ty)| (arg, CallArgumentTypes::new(ty))) + .unzip(); Self { arguments, types } } } diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 4eb2418b89ce0..6ed0ac4401bbc 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -25,7 +25,7 @@ use crate::db::Db; use crate::dunder_all::dunder_all_names; use crate::place::{DefinedPlace, Definedness, Place, known_module_symbol}; use crate::subscript::PyIndex; -use crate::types::call::arguments::{Expansion, is_expandable_type}; +use crate::types::call::arguments::{CallArgumentTypes, Expansion, is_expandable_type}; use crate::types::callable::CallableTypeKind; use crate::types::constraints::{ConstraintSet, ConstraintSetBuilder}; use crate::types::diagnostic::{ @@ -94,14 +94,14 @@ impl<'db> BindingsElement<'db> { &mut self, db: &'db dyn Db, constraints: &ConstraintSetBuilder<'db>, - argument_types: &CallArguments<'_, 'db>, + call_arguments: &CallArguments<'_, 'db>, call_expression_tcx: TypeContext<'db>, ) -> Option { let mut result = ArgumentForms::default(); let mut any_forms = false; for binding in &mut self.bindings { if let Some(forms) = - binding.check_types(db, constraints, argument_types, call_expression_tcx) + binding.check_types(db, constraints, call_arguments, call_expression_tcx) { result.merge(&forms); any_forms = true; @@ -421,7 +421,7 @@ impl<'db> Bindings<'db> { /// Verify that the type of each argument is assignable to type of the parameter that it was /// matched to. /// - /// You must provide an `argument_types` that was created from the same `arguments` that you + /// You must provide an `call_arguments` that was created from the same `arguments` that you /// provided to [`match_parameters`][Self::match_parameters]. /// /// The type context of the call expression is also used to infer the specialization of generic @@ -434,14 +434,14 @@ impl<'db> Bindings<'db> { mut self, db: &'db dyn Db, constraints: &ConstraintSetBuilder<'db>, - argument_types: &CallArguments<'_, 'db>, + call_arguments: &CallArguments<'_, 'db>, call_expression_tcx: TypeContext<'db>, dataclass_field_specifiers: &[Type<'db>], ) -> Result> { match self.check_types_impl( db, constraints, - argument_types, + call_arguments, call_expression_tcx, dataclass_field_specifiers, ) { @@ -454,14 +454,14 @@ impl<'db> Bindings<'db> { &mut self, db: &'db dyn Db, constraints: &ConstraintSetBuilder<'db>, - argument_types: &CallArguments<'_, 'db>, + call_arguments: &CallArguments<'_, 'db>, call_expression_tcx: TypeContext<'db>, dataclass_field_specifiers: &[Type<'db>], ) -> Result<(), CallErrorKind> { // Check types for each element (union variant) for element in &mut self.elements { if let Some(updated_argument_forms) = - element.check_types(db, constraints, argument_types, call_expression_tcx) + element.check_types(db, constraints, call_arguments, call_expression_tcx) { // If this element returned a new set of argument forms (indicating successful // argument type expansion), merge them into the existing forms. @@ -470,7 +470,7 @@ impl<'db> Bindings<'db> { } self.argument_forms.shrink_to_fit(); - self.evaluate_known_cases(db, argument_types, dataclass_field_specifiers); + self.evaluate_known_cases(db, call_arguments, dataclass_field_specifiers); // For intersection elements with at least one successful binding, // filter out the failing bindings. @@ -611,6 +611,28 @@ impl<'db> Bindings<'db> { UnionType::from_elements(db, element_return_types) } + /// Returns the inferred type for the argument at the specified index. + pub(crate) fn type_for_argument<'a>( + &'a self, + call_arguments: &'a CallArguments<'a, 'db>, + argument_index: usize, + ) -> Type<'db> { + let argument_types = &call_arguments.types()[argument_index]; + + // If there is a single matching parameter, return the argument type inferred against + // its declared type. + if let Some(binding) = self.single_element() + && let Ok((_, overload)) = binding.matching_overloads().exactly_one() + && let [parameter_index] = *overload.argument_matches[argument_index].parameters + { + let declared_type = overload.signature.parameters()[parameter_index].annotated_type(); + return argument_types.get_for_declared_type(declared_type); + } + + // Otherwise, return the default type. + argument_types.get_default().unwrap_or(Type::unknown()) + } + /// Report diagnostics for all of the errors that occurred when trying to match actual /// arguments to formal parameters. If the callable is a union, or has multiple overloads, we /// report a single diagnostic if we couldn't match any union element or overload. @@ -721,7 +743,7 @@ impl<'db> Bindings<'db> { fn evaluate_known_cases( &mut self, db: &'db dyn Db, - argument_types: &CallArguments<'_, 'db>, + call_arguments: &CallArguments<'_, 'db>, dataclass_field_specifiers: &[Type<'db>], ) { let to_bool = |ty: &Option>, default: bool| -> bool { @@ -1089,9 +1111,9 @@ impl<'db> Bindings<'db> { { return ty; } - argument_types.iter().find_map(|(arg, ty)| { + call_arguments.iter().find_map(|(arg, types)| { if matches!(arg, Argument::Keyword(arg_name) if arg_name == name) { - ty + types.get_default() } else { None } @@ -2145,16 +2167,16 @@ impl<'db> CallableBinding<'db> { &mut self, db: &'db dyn Db, constraints: &ConstraintSetBuilder<'db>, - argument_types: &CallArguments<'_, 'db>, + call_arguments: &CallArguments<'_, 'db>, call_expression_tcx: TypeContext<'db>, ) -> Option { // If this callable is a bound method, prepend the self instance onto the arguments list // before checking. - let argument_types = argument_types.with_self(self.bound_type); + let call_arguments = call_arguments.with_self(self.bound_type); let _span = tracing::trace_span!( "CallableBinding::check_types", - arguments = %argument_types.display(db), + arguments = %call_arguments.display(db), signature = %self.signature_type.display(db), ) .entered(); @@ -2175,7 +2197,7 @@ impl<'db> CallableBinding<'db> { overload.check_types( db, constraints, - argument_types.as_ref(), + call_arguments.as_ref(), call_expression_tcx, ); } @@ -2188,7 +2210,7 @@ impl<'db> CallableBinding<'db> { self.overloads[index].check_types( db, constraints, - argument_types.as_ref(), + call_arguments.as_ref(), call_expression_tcx, ); return None; @@ -2205,7 +2227,7 @@ impl<'db> CallableBinding<'db> { overload.check_types( db, constraints, - argument_types.as_ref(), + call_arguments.as_ref(), call_expression_tcx, ); } @@ -2250,7 +2272,7 @@ impl<'db> CallableBinding<'db> { self.filter_overloads_using_any_or_unknown( db, constraints, - argument_types.as_ref(), + call_arguments.as_ref(), &indexes, ); @@ -2269,7 +2291,7 @@ impl<'db> CallableBinding<'db> { // Step 3: Perform "argument type expansion". Reference: // https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion - let mut expansions = argument_types.expand(db).peekable(); + let mut expansions = call_arguments.expand(db).peekable(); // Return early if there are no argument types to expand. expansions.peek()?; @@ -2279,12 +2301,14 @@ impl<'db> CallableBinding<'db> { // This heuristic tries to detect if there's any need to perform argument type expansion or // not by checking whether there are any non-expandable argument type that cannot be // assigned to any of the overloads. - for (argument_index, (argument, argument_type)) in argument_types.iter().enumerate() { + for (argument_index, (argument, argument_types)) in call_arguments.iter().enumerate() { // TODO: Remove `Keywords` once `**kwargs` support is added if matches!(argument, Argument::Synthetic | Argument::Keywords) { continue; } - let Some(argument_type) = argument_type else { + // TODO: For types inferred multiple times with distinct type context, we currently only + // expand the default inference. + let Some(argument_type) = argument_types.get_default() else { continue; }; if is_expandable_type(db, argument_type) { @@ -2295,6 +2319,7 @@ impl<'db> CallableBinding<'db> { for parameter_index in &overload.argument_matches[argument_index].parameters { let parameter_type = overload.signature.parameters()[*parameter_index].annotated_type(); + let argument_type = argument_types.get_for_declared_type(parameter_type); if argument_type .when_assignable_to( db, @@ -2565,68 +2590,6 @@ impl<'db> CallableBinding<'db> { } } - let mut union_argument_type_builders = std::iter::repeat_with(|| UnionBuilder::new(db)) - .take(max_parameter_count) - .collect::>(); - - // The following loop is trying to construct a tuple of argument types that correspond to - // the participating parameter indexes. Considering the following example: - // - // ```python - // @overload - // def f(x: Literal[1], y: Literal[2]) -> tuple[int, int]: ... - // @overload - // def f(*args: Any) -> tuple[Any, ...]: ... - // - // f(1, 2) - // ``` - // - // Here, only the first parameter participates in the filtering process because only one - // overload has the second parameter. So, while going through the argument types, the - // second argument needs to be skipped but for the second overload both arguments map to - // the first parameter and that parameter is considered for the filtering process. This - // flag is to handle that special case of many-to-one mapping from arguments to parameters. - let mut variadic_parameter_handled = false; - - for (argument_index, argument_type) in arguments.iter_types().enumerate() { - if variadic_parameter_handled { - continue; - } - for overload_index in matching_overload_indexes { - let overload = &self.overloads[*overload_index]; - for (parameter_index, variadic_argument_type) in - overload.argument_matches[argument_index].iter() - { - if overload.signature.parameters()[parameter_index].is_variadic() { - variadic_parameter_handled = true; - } - if !participating_parameter_indexes.contains(¶meter_index) { - continue; - } - union_argument_type_builders[parameter_index].add_in_place( - variadic_argument_type - .unwrap_or(argument_type) - .top_materialization(db), - ); - } - } - } - - // These only contain the top materialized argument types for the corresponding - // participating parameter indexes. - let top_materialized_argument_type = Type::heterogeneous_tuple( - db, - union_argument_type_builders - .into_iter() - .filter_map(|builder| { - if builder.is_empty() { - None - } else { - Some(builder.build()) - } - }), - ); - // A flag to indicate whether we've found the overload that makes the remaining overloads // unmatched for the given argument types. let mut filter_remaining_overloads = false; @@ -2637,6 +2600,84 @@ impl<'db> CallableBinding<'db> { continue; } + let mut union_argument_type_builders = std::iter::repeat_with(|| UnionBuilder::new(db)) + .take(max_parameter_count) + .collect::>(); + + // The following loop is trying to construct a tuple of argument types that correspond to + // the participating parameter indexes. Considering the following example: + // + // ```python + // @overload + // def f(x: Literal[1], y: Literal[2]) -> tuple[int, int]: ... + // @overload + // def f(*args: Any) -> tuple[Any, ...]: ... + // + // f(1, 2) + // ``` + // + // Here, only the first parameter participates in the filtering process because only one + // overload has the second parameter. So, while going through the argument types, the + // second argument needs to be skipped but for the second overload both arguments map to + // the first parameter and that parameter is considered for the filtering process. This + // flag is to handle that special case of many-to-one mapping from arguments to parameters. + let mut variadic_parameter_handled = false; + + for (argument_index, argument_types) in arguments.types().iter().enumerate() { + if variadic_parameter_handled { + continue; + } + + // Get the argument type as inferred against the target overload. + let current_overload = &self.overloads[*current_index]; + let argument_type = + match *current_overload.argument_matches[argument_index].parameters { + [parameter_index] => { + let declared_type = current_overload.signature.parameters() + [parameter_index] + .annotated_type(); + argument_types.get_for_declared_type(declared_type) + } + // Splatted arguments are inferred without type context. + _ => argument_types.get_default().unwrap_or(Type::unknown()), + }; + + for overload_index in matching_overload_indexes { + let overload = &self.overloads[*overload_index]; + for (parameter_index, variadic_argument_type) in + overload.argument_matches[argument_index].iter() + { + let parameter = &overload.signature.parameters()[parameter_index]; + if parameter.is_variadic() { + variadic_parameter_handled = true; + } + if !participating_parameter_indexes.contains(¶meter_index) { + continue; + } + union_argument_type_builders[parameter_index].add_in_place( + variadic_argument_type + .unwrap_or(argument_type) + .top_materialization(db), + ); + } + } + } + + // These only contain the top materialized argument types for the corresponding + // participating parameter indexes. + let top_materialized_argument_type = Type::heterogeneous_tuple( + db, + union_argument_type_builders + .into_iter() + .filter_map(|builder| { + if builder.is_empty() { + None + } else { + Some(builder.build()) + } + }), + ); + let mut union_parameter_types = std::iter::repeat_with(|| UnionBuilder::new(db)) .take(max_parameter_count) .collect::>(); @@ -3676,11 +3717,12 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { fn enumerate_argument_types( &self, - ) -> impl Iterator, Argument<'a>, Type<'db>)> + 'a { + ) -> impl Iterator, Argument<'a>, &CallArgumentTypes<'db>)> + 'a + { let mut iter = self.arguments.iter().enumerate(); let mut num_synthetic_args = 0; std::iter::from_fn(move || { - let (argument_index, (argument, argument_type)) = iter.next()?; + let (argument_index, (argument, argument_types)) = iter.next()?; let adjusted_argument_index = if matches!(argument, Argument::Synthetic) { // If we are erroring on a synthetic argument, we'll just emit the // diagnostic on the entire Call node, since there's no argument node for @@ -3696,7 +3738,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { argument_index, adjusted_argument_index, argument, - argument_type.unwrap_or_else(Type::unknown), + argument_types, )) }) } @@ -3767,7 +3809,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { .unwrap_or_default(); let mut specialization_errors = Vec::new(); - let assignable_to_declared_type = self.infer_argument_types( + let assignable_to_declared_type = self.infer_argument_constraints( constraints, &mut builder, &preferred_type_mappings, @@ -3784,7 +3826,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { builder = SpecializationBuilder::new(self.db, self.inferable_typevars); specialization_errors.clear(); - self.infer_argument_types( + self.infer_argument_constraints( constraints, &mut builder, &FxHashMap::default(), @@ -3850,7 +3892,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { self.specialization = Some(specialization); } - fn infer_argument_types( + fn infer_argument_constraints( &mut self, constraints: &ConstraintSetBuilder<'db>, builder: &mut SpecializationBuilder<'db>, @@ -3861,15 +3903,18 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { let mut assignable_to_declared_type = true; let parameters = self.signature.parameters(); - for (argument_index, adjusted_argument_index, _, argument_type) in + for (argument_index, adjusted_argument_index, _, argument_types) in self.enumerate_argument_types() { for (parameter_index, variadic_argument_type) in self.argument_matches[argument_index].iter() { + let declared_type = parameters[parameter_index].annotated_type(); + let argument_type = argument_types.get_for_declared_type(declared_type); + let specialization_result = builder.infer_map( constraints, - parameters[parameter_index].annotated_type(), + declared_type, variadic_argument_type.unwrap_or(argument_type), |(identity, _, inferred_ty)| { // Avoid widening the inferred type if it is already assignable to the @@ -3964,7 +4009,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { .parameters() .find_paramspec_from_args_kwargs(self.db); - for (argument_index, adjusted_argument_index, argument, argument_type) in + for (argument_index, adjusted_argument_index, argument, argument_types) in self.enumerate_argument_types() { if let Some((_, paramspec)) = paramspec { @@ -3988,11 +4033,16 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { argument_index, adjusted_argument_index, argument, - argument_type, + // Splatted arguments are inferred without type context. + argument_types.get_default().unwrap_or(Type::unknown()), ), _ => { // If the argument isn't splatted, just check its type directly. for parameter_index in &self.argument_matches[argument_index].parameters { + let declared_type = + self.signature.parameters()[*parameter_index].annotated_type(); + let argument_type = argument_types.get_for_declared_type(declared_type); + self.check_argument_type( constraints, argument_index, @@ -4434,7 +4484,7 @@ impl<'db> Binding<'db> { let mut matcher = ArgumentMatcher::new(arguments, parameters, argument_forms, &mut self.errors); let mut keywords_arguments = vec![]; - for (argument_index, (argument, argument_type)) in arguments.iter().enumerate() { + for (argument_index, (argument, argument_types)) in arguments.iter().enumerate() { match argument { Argument::Positional | Argument::Synthetic => { let _ = matcher.match_positional(argument_index, argument, None, false); @@ -4443,15 +4493,26 @@ impl<'db> Binding<'db> { let _ = matcher.match_keyword(argument_index, argument, None, name); } Argument::Variadic => { - let _ = matcher.match_variadic(db, argument_index, argument, argument_type); + let _ = matcher.match_variadic( + db, + argument_index, + argument, + // Splatted arguments are inferred without type context. + argument_types.get_default(), + ); } Argument::Keywords => { - keywords_arguments.push((argument_index, argument_type)); + keywords_arguments.push((argument_index, argument_types)); } } } for (keywords_index, keywords_type) in keywords_arguments { - matcher.match_keyword_variadic(db, keywords_index, keywords_type); + matcher.match_keyword_variadic( + db, + keywords_index, + // Splatted arguments are inferred without type context. + keywords_type.get_default(), + ); } // For constructor calls, return the constructed instance type (not `__init__`'s `None`). self.return_ty = self @@ -4539,17 +4600,21 @@ impl<'db> Binding<'db> { pub(crate) fn arguments_for_parameter<'a>( &'a self, - argument_types: &'a CallArguments<'a, 'db>, + call_arguments: &'a CallArguments<'a, 'db>, parameter_index: usize, ) -> impl Iterator, Type<'db>)> + 'a { - argument_types + call_arguments .iter() .zip(&self.argument_matches) .filter(move |(_, argument_matches)| { argument_matches.parameters.contains(¶meter_index) }) - .map(|((argument, argument_type), _)| { - (argument, argument_type.unwrap_or_else(Type::unknown)) + .map(move |((argument, argument_types), _)| { + let declared_type = self.signature.parameters()[parameter_index].annotated_type(); + ( + argument, + argument_types.get_for_declared_type(declared_type), + ) }) } diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 185b83d98ebcf..27f2cf53c9e56 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -429,6 +429,12 @@ impl<'db> TypeContext<'db> { } } +impl<'db> From> for TypeContext<'db> { + fn from(annotation: Type<'db>) -> Self { + Self::new(Some(annotation)) + } +} + /// Infer the types for an [`Unpack`] operation. /// /// This infers the expression type and performs structural match against the target expression diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 6688c173fea8f..33677f801ef01 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -50,7 +50,6 @@ use crate::semantic_index::symbol::{ScopedSymbolId, Symbol}; use crate::semantic_index::{ ApplicableConstraints, EnclosingSnapshotResult, SemanticIndex, place_table, }; -use crate::types::CallableTypes; use crate::types::call::bind::MatchingOverloadIndex; use crate::types::call::{Binding, Bindings, CallArguments, CallError, CallErrorKind}; use crate::types::callable::CallableTypeKind; @@ -100,14 +99,14 @@ use crate::types::type_alias::{ManualPEP695TypeAliasType, PEP695TypeAliasType}; use crate::types::typed_dict::{validate_typed_dict_constructor, validate_typed_dict_dict_literal}; use crate::types::typevar::{BoundTypeVarIdentity, TypeVarConstraints, TypeVarIdentity}; use crate::types::{ - CallDunderError, CallableBinding, CallableType, ClassType, DynamicType, EvaluationMode, - InferenceFlags, InternedConstraintSet, InternedType, IntersectionBuilder, IntersectionType, - KnownClass, KnownInstanceType, KnownUnion, LiteralValueTypeKind, MemberLookupPolicy, - ParamSpecAttrKind, Parameter, ParameterForm, Parameters, Signature, SpecialFormType, - SubclassOfType, Truthiness, Type, TypeAliasType, TypeAndQualifiers, TypeContext, - TypeQualifiers, TypeVarBoundOrConstraints, TypeVarKind, TypeVarVariance, TypedDictType, - UnionBuilder, UnionType, binding_type, definition_expression_type, infer_complete_scope_types, - infer_scope_types, todo_type, + CallDunderError, CallableBinding, CallableType, CallableTypes, ClassType, DynamicType, + EvaluationMode, InferenceFlags, InternedConstraintSet, InternedType, IntersectionBuilder, + IntersectionType, KnownClass, KnownInstanceType, KnownUnion, LiteralValueTypeKind, + MemberLookupPolicy, ParamSpecAttrKind, Parameter, ParameterForm, Parameters, Signature, + SpecialFormType, SubclassOfType, Truthiness, Type, TypeAliasType, TypeAndQualifiers, + TypeContext, TypeQualifiers, TypeVarBoundOrConstraints, TypeVarKind, TypeVarVariance, + TypedDictType, UnionBuilder, UnionType, binding_type, definition_expression_type, + infer_complete_scope_types, infer_scope_types, todo_type, }; use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic}; use crate::unpack::UnpackPosition; @@ -4371,27 +4370,27 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { TypeContext::default(), ); - let [Some(value_ty)] = call_arguments.types() else { - unreachable!(); - }; - match call { Ok(outcome) => outcome.return_type(db), Err(CallDunderError::MethodNotAvailable) => { let value_ty = infer_value_ty(self, TypeContext::default()); binary_return_ty(self, value_ty) } - Err(CallDunderError::PossiblyUnbound(outcome)) => UnionType::from_two_elements( - db, - outcome.return_type(db), - binary_return_ty(self, *value_ty), - ), + Err(CallDunderError::PossiblyUnbound(outcome)) => { + let value_ty = outcome.type_for_argument(&call_arguments, 0); + UnionType::from_two_elements( + db, + outcome.return_type(db), + binary_return_ty(self, value_ty), + ) + } Err(CallDunderError::CallError(_, bindings)) => { + let value_ty = bindings.type_for_argument(&call_arguments, 0); report_unsupported_augmented_assignment( &self.context, assignment, target_type, - *value_ty, + value_ty, ); bindings.return_type(db) } @@ -4890,7 +4889,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ast_arguments.arguments_source_order() ); - for ((_, argument_type), argument_form, ast_argument) in iter { + for ((_, argument_types), argument_form, ast_argument) in iter { let argument = match ast_argument { // Splatted arguments are inferred before parameter matching to // determine their length. @@ -4902,7 +4901,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }; let ty = self.infer_argument_type(argument, argument_form, TypeContext::default()); - *argument_type = Some(ty); + argument_types.insert(TypeContext::default(), ty); } } @@ -5077,12 +5076,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { /// Note that this method may infer the type of a given argument expression multiple times with /// distinct type context. The provided `MultiInferenceState` can be used to dictate multi-inference /// behavior. - fn infer_all_argument_types( + fn infer_all_argument_types<'bindings>( &mut self, ast_arguments: ArgumentsIter<'_>, arguments_types: &mut CallArguments<'_, 'db>, infer_argument_ty: &mut dyn FnMut(&mut Self, ArgExpr<'db, '_>) -> Type<'db>, - bindings: &Bindings<'db>, + bindings: &'bindings Bindings<'db>, call_expression_tcx: TypeContext<'db>, ) { debug_assert_eq!(arguments_types.len(), bindings.argument_forms().len()); @@ -5119,14 +5118,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .flatten() .collect::>(); - // Each type is a valid independent inference of the given argument, and we may require - // different permutations of argument types to correctly perform argument expansion during - // overload evaluation, so we take the intersection of all the types we inferred for each - // argument. - let old_multi_inference_state = - self.set_multi_inference_state(MultiInferenceState::Intersect); - - for (argument_index, (_, argument_type), argument_form, ast_argument) in iter { + for (argument_index, (_, argument_types), argument_form, ast_argument) in iter { let ast_argument = match ast_argument { // Splatted arguments are inferred before parameter matching to // determine their length. @@ -5141,12 +5133,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Type-form arguments are inferred without type context, so we can infer the argument type directly. if let Some(ParameterForm::Type) = argument_form { - *argument_type = Some(self.infer_type_expression(ast_argument)); + argument_types.insert( + TypeContext::default(), + self.infer_type_expression(ast_argument), + ); + continue; } - // Retrieve the parameter type for the current argument in a given overload and its binding. - let parameter_type = |overload: &Binding<'db>, binding: &CallableBinding<'db>| { + // Retrieve the parameter type context for the current argument in a given overload and its binding. + let parameter_tcx = |overload: &'bindings Binding<'db>, + binding: &CallableBinding<'db>| { let argument_index = if binding.bound_type.is_some() { argument_index + 1 } else { @@ -5158,8 +5155,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { return None; }; - let mut parameter_type = - overload.signature.parameters()[*parameter_index].annotated_type(); + let parameter = &overload.signature.parameters()[*parameter_index]; + let mut parameter_type = parameter.annotated_type(); // If the parameter is a single type variable with an upper bound, e.g., `typing.Self`, // use the upper bound as type context. @@ -5167,7 +5164,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { && let Some(TypeVarBoundOrConstraints::UpperBound(bound)) = typevar.typevar(db).bound_or_constraints(db) { - return Some(bound); + return Some((parameter, TypeContext::new(Some(bound)))); } // If this is a generic call, attempt to specialize the parameter type using the @@ -5204,67 +5201,63 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { parameter_type = parameter_type.apply_specialization(db, specialization); } - Some(parameter_type) + Some((parameter, TypeContext::new(Some(parameter_type)))) }; // If there is only a single binding and overload, we can infer the argument directly with // the unique parameter type annotation. if let Ok((overload, binding)) = overloads_with_binding.iter().exactly_one() { - let tcx = TypeContext::new(parameter_type(overload, binding)); - *argument_type = Some(infer_argument_ty(self, (argument_index, ast_argument, tcx))); + if let Some((parameter, parameter_tcx)) = parameter_tcx(overload, binding) { + argument_types.insert( + parameter.annotated_type(), + infer_argument_ty(self, (argument_index, ast_argument, parameter_tcx)), + ); + } else { + argument_types.insert( + TypeContext::default(), + infer_argument_ty( + self, + (argument_index, ast_argument, TypeContext::default()), + ), + ); + } } else { // We perform inference once without any type context, emitting any diagnostics that are unrelated // to bidirectional type inference. - *argument_type = Some(infer_argument_ty( - self, - (argument_index, ast_argument, TypeContext::default()), - )); + argument_types.insert( + TypeContext::default(), + infer_argument_ty(self, (argument_index, ast_argument, TypeContext::default())), + ); // We then silence any diagnostics emitted during multi-inference, as the type context is only // used as a hint to infer a more assignable argument type, and should not lead to diagnostics // for non-matching overloads. let was_in_multi_inference = self.context.set_multi_inference(true); + let prev_multi_inference_state = + self.set_multi_inference_state(MultiInferenceState::Ignore); // Infer the type of each argument once with each distinct parameter type as type context. let parameter_types = overloads_with_binding .iter() - .filter_map(|(overload, binding)| parameter_type(overload, binding)); + .filter_map(|(overload, binding)| parameter_tcx(overload, binding)); let mut seen = FxHashSet::default(); - for parameter_type in parameter_types { - if !seen.insert(parameter_type) { - continue; - } - - let tcx = TypeContext::new(Some(parameter_type)); - let inferred_ty = infer_argument_ty(self, (argument_index, ast_argument, tcx)); - - // Ensure the inferred type is assignable to the declared type. - // - // If not, we want to avoid storing the "failed" inference attempt. - if !inferred_ty.is_assignable_to(db, parameter_type) { + for (parameter, parameter_tcx) in parameter_types { + if !seen.insert(parameter.annotated_type()) { continue; } - // TODO: Intersecting the inferred argument types is correct for unions of - // callables, since the argument must satisfy each callable, but it's not clear - // that it's correct for an intersection of callables, or for a case where - // different overloads provide different type context; unioning may be more - // correct in those cases. - *argument_type = argument_type - .map(|current| { - IntersectionType::from_two_elements(db, inferred_ty, current) - }) - .or(Some(inferred_ty)); + let inferred_ty = + infer_argument_ty(self, (argument_index, ast_argument, parameter_tcx)); + argument_types.insert(parameter.annotated_type(), inferred_ty); } // Re-enable diagnostics. self.context.set_multi_inference(was_in_multi_inference); + self.set_multi_inference_state(prev_multi_inference_state); } } - - self.set_multi_inference_state(old_multi_inference_state); } fn infer_argument_type( @@ -5423,30 +5416,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ty = Type::LiteralValue(literal.to_unpromotable()); } - self.store_expression_type_impl(expression, ty, tcx); + self.store_expression_type(expression, ty); ty } #[track_caller] fn store_expression_type(&mut self, expression: &ast::Expr, ty: Type<'db>) { - self.store_expression_type_impl(expression, ty, TypeContext::default()); - } - - #[track_caller] - fn store_expression_type_impl( - &mut self, - expression: &ast::Expr, - ty: Type<'db>, - tcx: TypeContext<'db>, - ) { if self.inner_expression_inference_state.is_get() { // If `inner_expression_inference_state` is `Get`, the expression type has already been stored. return; } - let db = self.db(); - match self.multi_inference_state { MultiInferenceState::Ignore => {} @@ -5454,22 +5435,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let previous = self.expressions.insert(expression.into(), ty); assert_eq!(previous, None); } - - MultiInferenceState::Intersect => { - self.expressions - .entry(expression.into()) - .and_modify(|current| { - // Avoid storing "failed" multi-inference attempts, which can lead to - // unnecessary union simplification overhead. - if tcx - .annotation - .is_none_or(|tcx| ty.is_assignable_to(db, tcx)) - { - *current = IntersectionType::from_two_elements(db, *current, ty); - } - }) - .or_insert(ty); - } } } @@ -5816,52 +5781,47 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { return ty; } } else if let Type::Union(tcx) = tcx { - // Otherwise, disable diagnostics as we attempt to narrow to specific elements of the union. + // Otherwise, we have to narrow to specific elements of the union. + // + // Infer all expressions with diagnostics enabled before starting multi-inference. + for item in items { + if let Some(key) = item.key.as_ref() { + let key_ty = self.infer_expression(key, TypeContext::default()); + item_types.insert(key.node_index().load(), key_ty); + } + + let value_ty = self.infer_expression(&item.value, TypeContext::default()); + item_types.insert(item.value.node_index().load(), value_ty); + } + + // Disable diagnostics as we attempt to narrow to specific elements of the union. let old_multi_inference = self.context.set_multi_inference(true); let old_multi_inference_state = self.set_multi_inference_state(MultiInferenceState::Ignore); - let mut narrowed_typed_dicts = Vec::new(); + let mut narrowed_tys = Vec::new(); + let mut item_types = FxHashMap::default(); for element in tcx.elements(self.db()) { let typed_dict = element .as_typed_dict() .expect("filtered out non-typed-dict types above"); - if self - .infer_typed_dict_expression(dict, typed_dict, &mut item_types) - .is_some() + if let Some(inferred_ty) = + self.infer_typed_dict_expression(dict, typed_dict, &mut item_types) { - narrowed_typed_dicts.push(typed_dict); + narrowed_tys.push(inferred_ty); } item_types.clear(); } - if !narrowed_typed_dicts.is_empty() { - // Now that we know which typed dict annotations are valid, re-infer with diagnostics enabled, - self.context.set_multi_inference(old_multi_inference); - - // We may have to infer the same expression multiple times with distinct type context, - // so we take the intersection of all valid inferences for a given expression. - self.set_multi_inference_state(MultiInferenceState::Intersect); - - let mut narrowed_tys = Vec::new(); - for typed_dict in narrowed_typed_dicts { - let mut item_types = FxHashMap::default(); - - let ty = self - .infer_typed_dict_expression(dict, typed_dict, &mut item_types) - .expect("ensured the typed dict is valid above"); - - narrowed_tys.push(ty); - } + self.context.set_multi_inference(old_multi_inference); + self.set_multi_inference_state(old_multi_inference_state); - self.set_multi_inference_state(old_multi_inference_state); + // Successfully narrowed to a subset of typed dicts. + if !narrowed_tys.is_empty() { return UnionType::from_elements(self.db(), narrowed_tys); } - - self.context.set_multi_inference(old_multi_inference); - self.set_multi_inference_state(old_multi_inference_state); } } @@ -9114,9 +9074,6 @@ enum MultiInferenceState { /// Ignore the newly inferred value. Ignore, - - /// Store the intersection of all types inferred for the expression. - Intersect, } impl MultiInferenceState { diff --git a/crates/ty_python_semantic/src/types/infer/builder/subscript.rs b/crates/ty_python_semantic/src/types/infer/builder/subscript.rs index 6a42c296bf965..d07f4ef240fae 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/subscript.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/subscript.rs @@ -1370,10 +1370,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { return true; }; - let [Some(slice_ty), Some(rhs_value_ty)] = call_arguments.types() else { - unreachable!(); - }; - match call_dunder_err { CallDunderError::PossiblyUnbound { .. } => { if emit_diagnostic @@ -1390,6 +1386,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { false } CallDunderError::CallError(call_error_kind, bindings) => { + let slice_ty = bindings.type_for_argument(&call_arguments, 0); + let rhs_value_ty = bindings.type_for_argument(&call_arguments, 1); + match call_error_kind { CallErrorKind::NotCallable => { if emit_diagnostic @@ -1414,7 +1413,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { typed_dict, full_object_ty, key, - value_ty: *rhs_value_ty, + value_ty: rhs_value_ty, typed_dict_node: target.value.as_ref().into(), key_node: target.slice.as_ref().into(), value_node: rhs_value_node.into(),