diff --git a/crates/ty_python_semantic/resources/mdtest/call/overloads.md b/crates/ty_python_semantic/resources/mdtest/call/overloads.md new file mode 100644 index 0000000000000..2258035524679 --- /dev/null +++ b/crates/ty_python_semantic/resources/mdtest/call/overloads.md @@ -0,0 +1,401 @@ +# Overloads + +When ty evaluates the call of an overloaded function, it attempts to "match" the supplied arguments +with one or more overloads. This document describes the algorithm that it uses for overload +matching, which is the same as the one mentioned in the +[spec](https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation). + +## Arity check + +The first step is to perform arity check. The non-overloaded cases are described in the +[function](./function.md) document. + +`overloaded.pyi`: + +```pyi +from typing import overload + +@overload +def f() -> None: ... +@overload +def f(x: int) -> int: ... +``` + +```py +from overloaded import f + +# These match a single overload +reveal_type(f()) # revealed: None +reveal_type(f(1)) # revealed: int + +# error: [no-matching-overload] "No overload of function `f` matches arguments" +reveal_type(f("a", "b")) # revealed: Unknown +``` + +## Type checking + +The second step is to perform type checking. This is done for all the overloads that passed the +arity check. + +### Single match + +`overloaded.pyi`: + +```pyi +from typing import overload + +@overload +def f(x: int) -> int: ... +@overload +def f(x: str) -> str: ... +@overload +def f(x: bytes) -> bytes: ... +``` + +Here, all of the calls below pass the arity check for all overloads, so we proceed to type checking +which filters out all but the matching overload: + +```py +from overloaded import f + +reveal_type(f(1)) # revealed: int +reveal_type(f("a")) # revealed: str +reveal_type(f(b"b")) # revealed: bytes +``` + +### Single match error + +`overloaded.pyi`: + +```pyi +from typing import overload + +@overload +def f() -> None: ... +@overload +def f(x: int) -> int: ... +``` + +If the arity check only matches a single overload, it should be evaluated as a regular +(non-overloaded) function call. This means that any diagnostics resulted during type checking that +call should be reported directly and not as a `no-matching-overload` error. + +```py +from overloaded import f + +reveal_type(f()) # revealed: None + +# TODO: This should be `invalid-argument-type` instead +# error: [no-matching-overload] +reveal_type(f("a")) # revealed: Unknown +``` + +### Multiple matches + +`overloaded.pyi`: + +```pyi +from typing import overload + +class A: ... +class B(A): ... + +@overload +def f(x: A) -> A: ... +@overload +def f(x: B, y: int = 0) -> B: ... +``` + +```py +from overloaded import A, B, f + +# These calls pass the arity check, and type checking matches both overloads: +reveal_type(f(A())) # revealed: A +reveal_type(f(B())) # revealed: A + +# But, in this case, the arity check filters out the first overload, so we only have one match: +reveal_type(f(B(), 1)) # revealed: B +``` + +## Argument type expansion + +This step is performed only if the previous steps resulted in **no matches**. + +In this case, the algorithm would perform +[argument type expansion](https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion) +and loops over from the type checking step, evaluating the argument lists. + +### Expanding the only argument + +`overloaded.pyi`: + +```pyi +from typing import overload + +class A: ... +class B: ... +class C: ... + +@overload +def f(x: A) -> A: ... +@overload +def f(x: B) -> B: ... +@overload +def f(x: C) -> C: ... +``` + +```py +from overloaded import A, B, C, f + +def _(ab: A | B, ac: A | C, bc: B | C): + reveal_type(f(ab)) # revealed: A | B + reveal_type(f(bc)) # revealed: B | C + reveal_type(f(ac)) # revealed: A | C +``` + +### Expanding first argument + +If the set of argument lists created by expanding the first argument evaluates successfully, the +algorithm shouldn't expand the second argument. + +`overloaded.pyi`: + +```pyi +from typing import Literal, overload + +class A: ... +class B: ... +class C: ... +class D: ... + +@overload +def f(x: A, y: C) -> A: ... +@overload +def f(x: A, y: D) -> B: ... +@overload +def f(x: B, y: C) -> C: ... +@overload +def f(x: B, y: D) -> D: ... +``` + +```py +from overloaded import A, B, C, D, f + +def _(a_b: A | B): + reveal_type(f(a_b, C())) # revealed: A | C + reveal_type(f(a_b, D())) # revealed: B | D + +# But, if it doesn't, it should expand the second argument and try again: +def _(a_b: A | B, c_d: C | D): + reveal_type(f(a_b, c_d)) # revealed: A | B | C | D +``` + +### Expanding second argument + +If the first argument cannot be expanded, the algorithm should move on to the second argument, +keeping the first argument as is. + +`overloaded.pyi`: + +```pyi +from typing import overload + +class A: ... +class B: ... +class C: ... +class D: ... + +@overload +def f(x: A, y: B) -> B: ... +@overload +def f(x: A, y: C) -> C: ... +@overload +def f(x: B, y: D) -> D: ... +``` + +```py +from overloaded import A, B, C, D, f + +def _(a: A, bc: B | C, cd: C | D): + # This also tests that partial matching works correctly as the argument type expansion results + # in matching the first and second overloads, but not the third one. + reveal_type(f(a, bc)) # revealed: B | C + + # error: [no-matching-overload] "No overload of function `f` matches arguments" + reveal_type(f(a, cd)) # revealed: Unknown +``` + +### Generics (legacy) + +`overloaded.pyi`: + +```pyi +from typing import TypeVar, overload + +_T = TypeVar("_T") + +class A: ... +class B: ... + +@overload +def f(x: A) -> A: ... +@overload +def f(x: _T) -> _T: ... +``` + +```py +from overloaded import A, f + +def _(x: int, y: A | int): + reveal_type(f(x)) # revealed: int + reveal_type(f(y)) # revealed: A | int +``` + +### Generics (PEP 695) + +```toml +[environment] +python-version = "3.12" +``` + +`overloaded.pyi`: + +```pyi +from typing import overload + +class A: ... +class B: ... + +@overload +def f(x: B) -> B: ... +@overload +def f[T](x: T) -> T: ... +``` + +```py +from overloaded import B, f + +def _(x: int, y: B | int): + reveal_type(f(x)) # revealed: int + reveal_type(f(y)) # revealed: B | int +``` + +### Expanding `bool` + +`overloaded.pyi`: + +```pyi +from typing import Literal, overload + +class T: ... +class F: ... + +@overload +def f(x: Literal[True]) -> T: ... +@overload +def f(x: Literal[False]) -> F: ... +``` + +```py +from overloaded import f + +def _(flag: bool): + reveal_type(f(True)) # revealed: T + reveal_type(f(False)) # revealed: F + reveal_type(f(flag)) # revealed: T | F +``` + +### Expanding `tuple` + +`overloaded.pyi`: + +```pyi +from typing import Literal, overload + +class A: ... +class B: ... +class C: ... +class D: ... + +@overload +def f(x: tuple[A, int], y: tuple[int, Literal[True]]) -> A: ... +@overload +def f(x: tuple[A, int], y: tuple[int, Literal[False]]) -> B: ... +@overload +def f(x: tuple[B, int], y: tuple[int, Literal[True]]) -> C: ... +@overload +def f(x: tuple[B, int], y: tuple[int, Literal[False]]) -> D: ... +``` + +```py +from overloaded import A, B, f + +def _(x: tuple[A | B, int], y: tuple[int, bool]): + reveal_type(f(x, y)) # revealed: A | B | C | D +``` + +### Expanding `type` + +There's no special handling for expanding `type[A | B]` type because ty stores this type in it's +distributed form, which is `type[A] | type[B]`. + +`overloaded.pyi`: + +```pyi +from typing import overload + +class A: ... +class B: ... + +@overload +def f(x: type[A]) -> A: ... +@overload +def f(x: type[B]) -> B: ... +``` + +```py +from overloaded import A, B, f + +def _(x: type[A | B]): + reveal_type(x) # revealed: type[A] | type[B] + reveal_type(f(x)) # revealed: A | B +``` + +### Expanding enums + +`overloaded.pyi`: + +```pyi +from enum import Enum +from typing import Literal, overload + +class SomeEnum(Enum): + A = 1 + B = 2 + C = 3 + + +class A: ... +class B: ... +class C: ... + +@overload +def f(x: Literal[SomeEnum.A]) -> A: ... +@overload +def f(x: Literal[SomeEnum.B]) -> B: ... +@overload +def f(x: Literal[SomeEnum.C]) -> C: ... +``` + +```py +from overloaded import SomeEnum, A, B, C, f + +def _(x: SomeEnum): + reveal_type(f(SomeEnum.A)) # revealed: A + # TODO: This should be `B` once enums are supported and are expanded + reveal_type(f(SomeEnum.B)) # revealed: A + # TODO: This should be `C` once enums are supported and are expanded + reveal_type(f(SomeEnum.C)) # revealed: A + # TODO: This should be `A | B | C` once enums are supported and are expanded + reveal_type(f(x)) # revealed: A +``` diff --git a/crates/ty_python_semantic/src/types/call/arguments.rs b/crates/ty_python_semantic/src/types/call/arguments.rs index 392c4ee514268..3ed581e415180 100644 --- a/crates/ty_python_semantic/src/types/call/arguments.rs +++ b/crates/ty_python_semantic/src/types/call/arguments.rs @@ -1,6 +1,11 @@ use std::borrow::Cow; use std::ops::{Deref, DerefMut}; +use itertools::{Either, Itertools}; + +use crate::Db; +use crate::types::{KnownClass, TupleType}; + use super::Type; /// Arguments for a single call, in source order. @@ -86,6 +91,10 @@ impl<'a, 'db> CallArgumentTypes<'a, 'db> { Self { arguments, types } } + pub(crate) fn types(&self) -> &[Type<'db>] { + &self.types + } + /// Prepend an optional extra synthetic argument (for a `self` or `cls` parameter) to the front /// of this argument list. (If `bound_self` is none, we return the argument list /// unmodified.) @@ -108,6 +117,72 @@ impl<'a, 'db> CallArgumentTypes<'a, 'db> { pub(crate) fn iter(&self) -> impl Iterator, Type<'db>)> + '_ { self.arguments.iter().zip(self.types.iter().copied()) } + + /// Returns an iterator on performing [argument type expansion]. + /// + /// Each element of the iterator represents a set of argument lists, where each argument list + /// contains the same arguments, but with one or more of the argument types expanded. + /// + /// [argument type expansion]: https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion + pub(crate) fn expand(&self, db: &'db dyn Db) -> impl Iterator>>> + '_ { + /// Represents the state of the expansion process. + /// + /// This is useful to avoid cloning the initial types vector if none of the types can be + /// expanded. + enum State<'a, 'db> { + Initial(&'a Vec>), + Expanded(Vec>>), + } + + impl<'db> State<'_, 'db> { + fn len(&self) -> usize { + match self { + State::Initial(_) => 1, + State::Expanded(expanded) => expanded.len(), + } + } + + fn iter(&self) -> impl Iterator>> + '_ { + match self { + State::Initial(types) => std::slice::from_ref(*types).iter(), + State::Expanded(expanded) => expanded.iter(), + } + } + } + + let mut index = 0; + + std::iter::successors(Some(State::Initial(&self.types)), move |previous| { + // Find the next type that can be expanded. + let expanded_types = loop { + let arg_type = self.types.get(index)?; + if let Some(expanded_types) = expand_type(db, *arg_type) { + break expanded_types; + } + index += 1; + }; + + let mut expanded_arg_types = Vec::with_capacity(expanded_types.len() * previous.len()); + + for pre_expanded_types in previous.iter() { + for subtype in &expanded_types { + let mut new_expanded_types = pre_expanded_types.clone(); + new_expanded_types[index] = *subtype; + expanded_arg_types.push(new_expanded_types); + } + } + + // Increment the index to move to the next argument type for the next iteration. + index += 1; + + Some(State::Expanded(expanded_arg_types)) + }) + .skip(1) // Skip the initial state, which has no expanded types. + .map(|state| match state { + State::Initial(_) => unreachable!("initial state should be skipped"), + State::Expanded(expanded) => expanded, + }) + } } impl<'a> Deref for CallArgumentTypes<'a, '_> { @@ -122,3 +197,138 @@ impl<'a> DerefMut for CallArgumentTypes<'a, '_> { &mut self.arguments } } + +/// Expands a type into its possible subtypes, if applicable. +/// +/// Returns [`None`] if the type cannot be expanded. +fn expand_type<'db>(db: &'db dyn Db, ty: Type<'db>) -> Option>> { + // TODO: Expand enums to their variants + match ty { + Type::NominalInstance(instance) if instance.class.is_known(db, KnownClass::Bool) => { + Some(vec![ + Type::BooleanLiteral(true), + Type::BooleanLiteral(false), + ]) + } + Type::Tuple(tuple) => { + // Note: This should only account for tuples of known length, i.e., `tuple[bool, ...]` + // should not be expanded here. + let expanded = tuple + .iter(db) + .map(|element| { + if let Some(expanded) = expand_type(db, element) { + Either::Left(expanded.into_iter()) + } else { + Either::Right(std::iter::once(element)) + } + }) + .multi_cartesian_product() + .map(|types| TupleType::from_elements(db, types)) + .collect::>(); + if expanded.len() == 1 { + // There are no elements in the tuple type that can be expanded. + None + } else { + Some(expanded) + } + } + Type::Union(union) => Some(union.iter(db).copied().collect()), + // We don't handle `type[A | B]` here because it's already stored in the expanded form + // i.e., `type[A] | type[B]` which is handled by the `Type::Union` case. + _ => None, + } +} + +#[cfg(test)] +mod tests { + use crate::db::tests::setup_db; + use crate::types::{KnownClass, TupleType, Type, UnionType}; + + use super::expand_type; + + #[test] + fn expand_union_type() { + let db = setup_db(); + let types = [ + KnownClass::Int.to_instance(&db), + KnownClass::Str.to_instance(&db), + KnownClass::Bytes.to_instance(&db), + ]; + let union_type = UnionType::from_elements(&db, types); + let expanded = expand_type(&db, union_type).unwrap(); + assert_eq!(expanded.len(), types.len()); + assert_eq!(expanded, types); + } + + #[test] + fn expand_bool_type() { + let db = setup_db(); + let bool_instance = KnownClass::Bool.to_instance(&db); + let expanded = expand_type(&db, bool_instance).unwrap(); + let expected_types = [Type::BooleanLiteral(true), Type::BooleanLiteral(false)]; + assert_eq!(expanded.len(), expected_types.len()); + assert_eq!(expanded, expected_types); + } + + #[test] + fn expand_tuple_type() { + let db = setup_db(); + + let int_ty = KnownClass::Int.to_instance(&db); + let str_ty = KnownClass::Str.to_instance(&db); + let bytes_ty = KnownClass::Bytes.to_instance(&db); + let bool_ty = KnownClass::Bool.to_instance(&db); + let true_ty = Type::BooleanLiteral(true); + let false_ty = Type::BooleanLiteral(false); + + // Empty tuple + let empty_tuple = TupleType::empty(&db); + let expanded = expand_type(&db, empty_tuple); + assert!(expanded.is_none()); + + // None of the elements can be expanded. + let tuple_type1 = TupleType::from_elements(&db, [int_ty, str_ty]); + let expanded = expand_type(&db, tuple_type1); + assert!(expanded.is_none()); + + // All elements can be expanded. + let tuple_type2 = TupleType::from_elements( + &db, + [ + bool_ty, + UnionType::from_elements(&db, [int_ty, str_ty, bytes_ty]), + ], + ); + let expected_types = [ + TupleType::from_elements(&db, [true_ty, int_ty]), + TupleType::from_elements(&db, [true_ty, str_ty]), + TupleType::from_elements(&db, [true_ty, bytes_ty]), + TupleType::from_elements(&db, [false_ty, int_ty]), + TupleType::from_elements(&db, [false_ty, str_ty]), + TupleType::from_elements(&db, [false_ty, bytes_ty]), + ]; + let expanded = expand_type(&db, tuple_type2).unwrap(); + assert_eq!(expanded.len(), expected_types.len()); + assert_eq!(expanded, expected_types); + + // Mixed set of elements where some can be expanded while others cannot be. + let tuple_type3 = TupleType::from_elements( + &db, + [ + bool_ty, + int_ty, + UnionType::from_elements(&db, [str_ty, bytes_ty]), + str_ty, + ], + ); + let expected_types = [ + TupleType::from_elements(&db, [true_ty, int_ty, str_ty, str_ty]), + TupleType::from_elements(&db, [true_ty, int_ty, bytes_ty, str_ty]), + TupleType::from_elements(&db, [false_ty, int_ty, str_ty, str_ty]), + TupleType::from_elements(&db, [false_ty, int_ty, bytes_ty, str_ty]), + ]; + let expanded = expand_type(&db, tuple_type3).unwrap(); + assert_eq!(expanded.len(), expected_types.len()); + assert_eq!(expanded, expected_types); + } +} diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 51ab005c1b8d5..d6e08cdef5141 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -997,6 +997,7 @@ impl<'db> From> for Bindings<'db> { signature_type, dunder_call_is_possibly_unbound: false, bound_type: None, + return_type: None, overloads: smallvec![from], }; Bindings { @@ -1015,14 +1016,9 @@ impl<'db> From> for Bindings<'db> { /// If the callable has multiple overloads, the first one that matches is used as the overall /// binding match. /// -/// TODO: Implement the call site evaluation algorithm in the [proposed updated typing -/// spec][overloads], which is much more subtle than “first match wins”. -/// /// If the arguments cannot be matched to formal parameters, we store information about the /// specific errors that occurred when trying to match them up. If the callable has multiple /// overloads, we store this error information for each overload. -/// -/// [overloads]: https://github.com/python/typing/pull/1839 #[derive(Debug)] pub(crate) struct CallableBinding<'db> { /// The type that is (hopefully) callable. @@ -1040,6 +1036,14 @@ pub(crate) struct CallableBinding<'db> { /// The type of the bound `self` or `cls` parameter if this signature is for a bound method. pub(crate) bound_type: Option>, + /// The return type of this callable. + /// + /// This is only `Some` if it's an overloaded callable, "argument type expansion" was + /// performed, and one of the expansion evaluated successfully for all of the argument lists. + /// This type is then the union of all the return types of the matched overloads for the + /// expanded argument lists. + return_type: Option>, + /// The bindings of each overload of this callable. Will be empty if the type is not callable. /// /// By using `SmallVec`, we avoid an extra heap allocation for the common case of a @@ -1061,6 +1065,7 @@ impl<'db> CallableBinding<'db> { signature_type, dunder_call_is_possibly_unbound: false, bound_type: None, + return_type: None, overloads, } } @@ -1071,6 +1076,7 @@ impl<'db> CallableBinding<'db> { signature_type, dunder_call_is_possibly_unbound: false, bound_type: None, + return_type: None, overloads: smallvec![], } } @@ -1099,12 +1105,6 @@ impl<'db> CallableBinding<'db> { // before checking. let arguments = arguments.with_self(self.bound_type); - // TODO: This checks every overload. In the proposed more detailed call checking spec [1], - // arguments are checked for arity first, and are only checked for type assignability against - // the matching overloads. Make sure to implement that as part of separating call binding into - // two phases. - // - // [1] https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation for overload in &mut self.overloads { overload.match_parameters(arguments.as_ref(), argument_forms, conflicting_forms); } @@ -1114,9 +1114,154 @@ impl<'db> CallableBinding<'db> { // If this callable is a bound method, prepend the self instance onto the arguments list // before checking. let argument_types = argument_types.with_self(self.bound_type); - for overload in &mut self.overloads { - overload.check_types(db, argument_types.as_ref()); + + // Step 1: Check the result of the arity check which is done by `match_parameters` + let matching_overload_indexes = match self.matching_overload_index() { + MatchingOverloadIndex::None => { + // If no candidate overloads remain from the arity check, we can stop here. We + // still perform type checking for non-overloaded function to provide better user + // experience. + if let [overload] = self.overloads.as_mut_slice() { + overload.check_types(db, argument_types.as_ref(), argument_types.types()); + } + return; + } + MatchingOverloadIndex::Single(index) => { + // If only one candidate overload remains, it is the winning match. + // TODO: Evaluate it as a regular (non-overloaded) call. This means that any + // diagnostics reported in this check should be reported directly instead of + // reporting it as `no-matching-overload`. + self.overloads[index].check_types( + db, + argument_types.as_ref(), + argument_types.types(), + ); + return; + } + MatchingOverloadIndex::Multiple(indexes) => { + // If two or more candidate overloads remain, proceed to step 2. + indexes + } + }; + + let snapshotter = MatchingOverloadsSnapshotter::new(matching_overload_indexes); + + // State of the bindings _before_ evaluating (type checking) the matching overloads using + // the non-expanded argument types. + let pre_evaluation_snapshot = snapshotter.take(self); + + // Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine + // whether it is compatible with the supplied argument list. + for (_, overload) in self.matching_overloads_mut() { + overload.check_types(db, argument_types.as_ref(), argument_types.types()); + } + + match self.matching_overload_index() { + MatchingOverloadIndex::None => { + // If all overloads result in errors, proceed to step 3. + } + MatchingOverloadIndex::Single(_) => { + // If only one overload evaluates without error, it is the winning match. + return; + } + MatchingOverloadIndex::Multiple(_) => { + // If two or more candidate overloads remain, proceed to step 4. + // TODO: Step 4 and Step 5 goes here... + // We're returning here because this shouldn't lead to argument type expansion. + return; + } } + + // Step 3: Perform "argument type expansion". Reference: + // https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion + let mut expansions = argument_types.expand(db).peekable(); + + if expansions.peek().is_none() { + // Return early if there are no argument types to expand. + return; + } + + // State of the bindings _after_ evaluating (type checking) the matching overloads using + // the non-expanded argument types. + let post_evaluation_snapshot = snapshotter.take(self); + + // Restore the bindings state to the one prior to the type checking step in preparation + // for evaluating the expanded argument lists. + snapshotter.restore(self, pre_evaluation_snapshot); + + for expanded_argument_lists in expansions { + // This is the merged state of the bindings after evaluating all of the expanded + // argument lists. This will be the final state to restore the bindings to if all of + // the expanded argument lists evaluated successfully. + let mut merged_evaluation_state: Option> = None; + + let mut return_types = Vec::new(); + + for expanded_argument_types in &expanded_argument_lists { + let pre_evaluation_snapshot = snapshotter.take(self); + + for (_, overload) in self.matching_overloads_mut() { + overload.check_types(db, argument_types.as_ref(), expanded_argument_types); + } + + let return_type = match self.matching_overload_index() { + MatchingOverloadIndex::None => None, + MatchingOverloadIndex::Single(index) => { + Some(self.overloads[index].return_type()) + } + MatchingOverloadIndex::Multiple(index) => { + // TODO: Step 4 and Step 5 goes here... but for now we just use the return + // type of the first matched overload. + Some(self.overloads[index[0]].return_type()) + } + }; + + // This split between initializing and updating the merged evaluation state is + // required because otherwise it's difficult to differentiate between the + // following: + // 1. An initial unmatched overload becomes a matched overload when evaluating the + // first argument list + // 2. An unmatched overload after evaluating the first argument list becomes a + // matched overload when evaluating the second argument list + if let Some(merged_evaluation_state) = merged_evaluation_state.as_mut() { + merged_evaluation_state.update(self); + } else { + merged_evaluation_state = Some(snapshotter.take(self)); + } + + // Restore the bindings state before evaluating the next argument list. + snapshotter.restore(self, pre_evaluation_snapshot); + + if let Some(return_type) = return_type { + return_types.push(return_type); + } else { + // No need to check the remaining argument lists if the current argument list + // doesn't evaluate successfully. Move on to expanding the next argument type. + break; + } + } + + if return_types.len() == expanded_argument_lists.len() { + // If the number of return types is equal to the number of expanded argument lists, + // they all evaluated successfully. So, we need to combine their return types by + // union to determine the final return type. + self.return_type = Some(UnionType::from_elements(db, return_types)); + + // Restore the bindings state to the one that merges the bindings state evaluating + // each of the expanded argument list. + if let Some(merged_evaluation_state) = merged_evaluation_state { + snapshotter.restore(self, merged_evaluation_state); + } + + return; + } + } + + // If the type expansion didn't yield any successful return type, we need to restore the + // bindings state back to the one after the type checking step using the non-expanded + // argument types. This is necessary because we restore the state to the pre-evaluation + // snapshot when processing the expanded argument lists. + snapshotter.restore(self, post_evaluation_snapshot); } fn as_result(&self) -> Result<(), CallErrorKind> { @@ -1145,6 +1290,25 @@ impl<'db> CallableBinding<'db> { self.matching_overloads().next().is_none() } + /// Returns the index of the matching overload in the form of [`MatchingOverloadIndex`]. + fn matching_overload_index(&self) -> MatchingOverloadIndex { + let mut matching_overloads = self.matching_overloads(); + match matching_overloads.next() { + None => MatchingOverloadIndex::None, + Some((first, _)) => { + if let Some((second, _)) = matching_overloads.next() { + let mut indexes = vec![first, second]; + for (index, _) in matching_overloads { + indexes.push(index); + } + MatchingOverloadIndex::Multiple(indexes) + } else { + MatchingOverloadIndex::Single(first) + } + } + } + } + /// Returns an iterator over all the overloads that matched for this call binding. pub(crate) fn matching_overloads(&self) -> impl Iterator)> { self.overloads @@ -1163,16 +1327,20 @@ impl<'db> CallableBinding<'db> { .filter(|(_, overload)| overload.as_result().is_ok()) } - /// Returns the return type of this call. For a valid call, this is the return type of the - /// first overload that the arguments matched against. For an invalid call to a non-overloaded - /// function, this is the return type of the function. For an invalid call to an overloaded - /// function, we return `Type::unknown`, since we cannot make any useful conclusions about - /// which overload was intended to be called. + /// Returns the return type of this call. + /// + /// For a valid call, this is the return type of either a successful argument type expansion of + /// an overloaded function, or the return type of the first overload that the arguments matched + /// against. + /// + /// For an invalid call to a non-overloaded function, this is the return type of the function. + /// + /// For an invalid call to an overloaded function, we return `Type::unknown`, since we cannot + /// make any useful conclusions about which overload was intended to be called. pub(crate) fn return_type(&self) -> Type<'db> { - // TODO: Implement the overload call evaluation algorithm as mentioned in the spec [1] to - // get the matching overload and use that to get the return type. - // - // [1]: https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation + if let Some(return_type) = self.return_type { + return return_type; + } if let Some((_, first_overload)) = self.matching_overloads().next() { return first_overload.return_type(); } @@ -1319,6 +1487,18 @@ impl<'a, 'db> IntoIterator for &'a CallableBinding<'db> { } } +#[derive(Debug)] +enum MatchingOverloadIndex { + /// No matching overloads found. + None, + + /// Exactly one matching overload found at the given index. + Single(usize), + + /// Multiple matching overloads found at the given indexes. + Multiple(Vec), +} + /// Binding information for one of the overloads of a callable. #[derive(Debug)] pub(crate) struct Binding<'db> { @@ -1493,7 +1673,12 @@ impl<'db> Binding<'db> { self.parameter_tys = vec![None; parameters.len()].into_boxed_slice(); } - fn check_types(&mut self, db: &'db dyn Db, argument_types: &CallArgumentTypes<'_, 'db>) { + fn check_types( + &mut self, + db: &'db dyn Db, + arguments: &CallArguments<'_>, + argument_types: &[Type<'db>], + ) { let mut num_synthetic_args = 0; let get_argument_index = |argument_index: usize, num_synthetic_args: usize| { if argument_index >= num_synthetic_args { @@ -1507,13 +1692,20 @@ impl<'db> Binding<'db> { } }; + let enumerate_argument_types = || { + arguments + .iter() + .zip(argument_types.iter().copied()) + .enumerate() + }; + // If this overload is generic, first see if we can infer a specialization of the function // from the arguments that were passed in. let signature = &self.signature; let parameters = signature.parameters(); if signature.generic_context.is_some() || signature.inherited_generic_context.is_some() { let mut builder = SpecializationBuilder::new(db); - for (argument_index, (argument, argument_type)) in argument_types.iter().enumerate() { + for (argument_index, (argument, argument_type)) in enumerate_argument_types() { if matches!(argument, Argument::Synthetic) { num_synthetic_args += 1; } @@ -1545,7 +1737,7 @@ impl<'db> Binding<'db> { } num_synthetic_args = 0; - for (argument_index, (argument, mut argument_type)) in argument_types.iter().enumerate() { + for (argument_index, (argument, mut argument_type)) in enumerate_argument_types() { if matches!(argument, Argument::Synthetic) { num_synthetic_args += 1; } @@ -1648,6 +1840,133 @@ impl<'db> Binding<'db> { } Ok(()) } + + fn snapshot(&self) -> BindingSnapshot<'db> { + BindingSnapshot { + return_ty: self.return_ty, + specialization: self.specialization, + inherited_specialization: self.inherited_specialization, + argument_parameters: self.argument_parameters.clone(), + parameter_tys: self.parameter_tys.clone(), + errors: self.errors.clone(), + } + } + + fn restore(&mut self, snapshot: BindingSnapshot<'db>) { + let BindingSnapshot { + return_ty, + specialization, + inherited_specialization, + argument_parameters, + parameter_tys, + errors, + } = snapshot; + + self.return_ty = return_ty; + self.specialization = specialization; + self.inherited_specialization = inherited_specialization; + self.argument_parameters = argument_parameters; + self.parameter_tys = parameter_tys; + self.errors = errors; + } +} + +#[derive(Clone, Debug)] +struct BindingSnapshot<'db> { + return_ty: Type<'db>, + specialization: Option>, + inherited_specialization: Option>, + argument_parameters: Box<[Option]>, + parameter_tys: Box<[Option>]>, + errors: Vec>, +} + +/// Represents the snapshot of the matched overload bindings. +/// +/// The reason that this only contains the matched overloads are: +/// 1. Avoid creating snapshots for the overloads that have been filtered by the arity check +/// 2. Avoid duplicating errors when merging the snapshots on a successful evaluation of all the +/// expanded argument lists +#[derive(Clone, Debug)] +struct MatchingOverloadsSnapshot<'db>(Vec<(usize, BindingSnapshot<'db>)>); + +impl<'db> MatchingOverloadsSnapshot<'db> { + /// Update the state of the matched overload bindings in this snapshot with the current + /// state in the given `binding`. + fn update(&mut self, binding: &CallableBinding<'db>) { + // Here, the `snapshot` is the state of this binding for the previous argument list and + // `binding` would contain the state after evaluating the current argument list. + for (snapshot, binding) in self + .0 + .iter_mut() + .map(|(index, snapshot)| (snapshot, &binding.overloads[*index])) + { + if binding.errors.is_empty() { + // If the binding has no errors, this means that the current argument list was + // evaluated successfully and this is the matching overload. + // + // Clear the errors from the snapshot of this overload to signal this change ... + snapshot.errors.clear(); + + // ... and update the snapshot with the current state of the binding. + snapshot.return_ty = binding.return_ty; + snapshot.specialization = binding.specialization; + snapshot.inherited_specialization = binding.inherited_specialization; + snapshot + .argument_parameters + .clone_from(&binding.argument_parameters); + snapshot.parameter_tys.clone_from(&binding.parameter_tys); + } + + // If the errors in the snapshot was empty, then this binding is the matching overload + // for a previously evaluated argument list. This means that we don't need to change + // any information for an already matched overload binding. + // + // If it does have errors, we could extend it with the errors from evaluating the + // current argument list. Arguably, this isn't required, since the errors in the + // snapshot should already signal that this is an unmatched overload which is why we + // don't do it. Similarly, due to this being an unmatched overload, there's no point in + // updating the binding state. + } + } +} + +/// A helper to take snapshots of the matched overload bindings for the current state of the +/// bindings. +struct MatchingOverloadsSnapshotter(Vec); + +impl MatchingOverloadsSnapshotter { + /// Creates a new snapshotter for the given indexes of the matched overloads. + fn new(indexes: Vec) -> Self { + debug_assert!(indexes.len() > 1); + MatchingOverloadsSnapshotter(indexes) + } + + /// Takes a snapshot of the current state of the matched overload bindings. + /// + /// # Panics + /// + /// Panics if the indexes of the matched overloads are not valid for the given binding. + fn take<'db>(&self, binding: &CallableBinding<'db>) -> MatchingOverloadsSnapshot<'db> { + MatchingOverloadsSnapshot( + self.0 + .iter() + .map(|index| (*index, binding.overloads[*index].snapshot())) + .collect(), + ) + } + + /// Restores the state of the matched overload bindings from the given snapshot. + fn restore<'db>( + &self, + binding: &mut CallableBinding<'db>, + snapshot: MatchingOverloadsSnapshot<'db>, + ) { + debug_assert_eq!(self.0.len(), snapshot.0.len()); + for (index, snapshot) in snapshot.0 { + binding.overloads[index].restore(snapshot); + } + } } /// Describes a callable for the purposes of diagnostics.