diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/callable.md b/crates/ty_python_semantic/resources/mdtest/annotations/callable.md index 920a953622271..76d42b11c4b0b 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/callable.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/callable.md @@ -369,16 +369,14 @@ Using `Concatenate` as the first argument to `Callable`: from typing_extensions import Callable, Concatenate def _(c: Callable[Concatenate[int, str, ...], int]): - # TODO: Should reveal the correct signature - reveal_type(c) # revealed: (...) -> int + reveal_type(c) # revealed: (int, str, /, *args: Any, **kwargs: Any) -> int ``` Other type expressions can be nested inside `Concatenate`: ```py -def _(c: Callable[[Concatenate[int | str, type[str], ...], int], int]): - # TODO: Should reveal the correct signature - reveal_type(c) # revealed: (...) -> int +def _(c: Callable[Concatenate[int | str, type[str], ...], int]): + reveal_type(c) # revealed: (int | str, type[str], /, *args: Any, **kwargs: Any) -> int ``` But providing fewer than 2 arguments to `Concatenate` is an error: diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/unsupported_special_forms.md b/crates/ty_python_semantic/resources/mdtest/annotations/unsupported_special_forms.md index 786c614fb6a56..4164c00c43bf0 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/unsupported_special_forms.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/unsupported_special_forms.md @@ -60,7 +60,7 @@ def _( a: Unpack, # error: [invalid-type-form] "`typing.Unpack` requires exactly one argument when used in a type expression" b: TypeGuard, # error: [invalid-type-form] "`typing.TypeGuard` requires exactly one argument when used in a type expression" c: TypeIs, # error: [invalid-type-form] "`typing.TypeIs` requires exactly one argument when used in a type expression" - d: Concatenate, # error: [invalid-type-form] "`typing.Concatenate` requires at least two arguments when used in a type expression" + d: Concatenate, # error: [invalid-type-form] "`typing.Concatenate` is not allowed in this context in a type expression" e: ParamSpec, f: Generic, # error: [invalid-type-form] "`typing.Generic` is not allowed in type expressions" ) -> None: diff --git a/crates/ty_python_semantic/resources/mdtest/call/callables_as_descriptors.md b/crates/ty_python_semantic/resources/mdtest/call/callables_as_descriptors.md index 471a4b5b8e5df..65bf247c7a0ab 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/callables_as_descriptors.md +++ b/crates/ty_python_semantic/resources/mdtest/call/callables_as_descriptors.md @@ -203,6 +203,38 @@ class Calculator: reveal_type(Calculator().square_then_round(3.14)) # revealed: Unknown | int ``` +## Use case: Wrappers with explicit receivers + +`trio` defines multiple functions that takes in a callable with `Concatenate`-prepended receiver +types, and returns a wrapper function with a different receiver type. They should still preserve +descriptor behavior when the returned callable is assigned in the class body. + +```py +from collections.abc import Callable, Iterable +from typing import Concatenate, ParamSpec, TypeVar + +P = ParamSpec("P") +T = TypeVar("T") + +class RawPath: + def write_bytes(self, data: bytes) -> int: + raise NotImplementedError + +def _wrap_method( + fn: Callable[Concatenate[RawPath, P], T], +) -> Callable[Concatenate["Path", P], T]: + raise NotImplementedError + +class Path: + write_bytes = _wrap_method(RawPath.write_bytes) + +def check(path: Path) -> None: + # TODO: shouldn't be errors, should reveal `int` + # error: [missing-argument] + # error: [invalid-argument-type] + reveal_type(path.write_bytes(b"")) # revealed: Unknown | int +``` + ## Use case: Treating dunder methods as bound-method descriptors pytorch defines a `__pow__` dunder attribute on [`TensorBase`] in a similar way to the following diff --git a/crates/ty_python_semantic/resources/mdtest/final.md b/crates/ty_python_semantic/resources/mdtest/final.md index 78149aabcd1bc..0c02b71604e44 100644 --- a/crates/ty_python_semantic/resources/mdtest/final.md +++ b/crates/ty_python_semantic/resources/mdtest/final.md @@ -1325,7 +1325,7 @@ class Base(ABC): @abstractproperty # error: [deprecated] def value(self) -> int: return 0 - + # error: [invalid-argument-type] @abstractclassmethod # error: [deprecated] def make(cls) -> "Base": raise NotImplementedError diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/concatenate.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/concatenate.md index 2ecc6603cd23f..960d3a038a770 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/concatenate.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/concatenate.md @@ -16,16 +16,14 @@ element. from typing import Callable, Concatenate def foo[**P, R](func: Callable[Concatenate[int, P], R]) -> Callable[Concatenate[int, P], R]: - # TODO: Should reveal `(int, /, *args: P@foo.args, **kwargs: P@foo.kwargs) -> R@foo` - reveal_type(func) # revealed: (...) -> R@foo + reveal_type(func) # revealed: (int, /, *args: P@foo.args, **kwargs: P@foo.kwargs) -> R@foo return func def f(x: int, y: str) -> bool: return True result = foo(f) -# TODO: Should reveal `(int, /, y: str) -> bool` -reveal_type(result) # revealed: (...) -> bool +reveal_type(result) # revealed: (int, /, y: str) -> bool ``` ### With ellipsis @@ -34,8 +32,7 @@ reveal_type(result) # revealed: (...) -> bool from typing import Callable, Concatenate def _(c: Callable[Concatenate[int, str, ...], bool]): - # TODO: Should reveal `(int, str, /, ...) -> bool` - reveal_type(c) # revealed: (...) -> bool + reveal_type(c) # revealed: (int, str, /, *args: Any, **kwargs: Any) -> bool ``` ### Complex types inside `Concatenate` @@ -44,8 +41,7 @@ def _(c: Callable[Concatenate[int, str, ...], bool]): from typing import Callable, Concatenate def _(c: Callable[Concatenate[int | str, list[int], type[str], ...], None]): - # TODO: Should reveal `(int | str, list[int], type[str], ...) -> None` - reveal_type(c) # revealed: (...) -> None + reveal_type(c) # revealed: (int | str, list[int], type[str], /, *args: Any, **kwargs: Any) -> None ``` ### Nested @@ -54,8 +50,20 @@ def _(c: Callable[Concatenate[int | str, list[int], type[str], ...], None]): from typing import Callable, Concatenate def _(c: Callable[Concatenate[int, Callable[Concatenate[str, ...], None], ...], None]): - # TODO: Should reveal `(int, (str, ...) -> None, /, ...) -> None` - reveal_type(c) # revealed: (...) -> None + reveal_type(c) # revealed: (int, (str, /, *args: Any, **kwargs: Any) -> None, /, *args: Any, **kwargs: Any) -> None +``` + +### Both `*args` and `**kwargs` are required + +```py +from typing import Callable, Concatenate + +def decorator[**P](func: Callable[Concatenate[int, P], None]) -> Callable[P, None]: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: + func(1) # TODO: error: [missing-argument] + func(1, *args) # TODO: error: [missing-argument] + func(1, **kwargs) # TODO: error: [missing-argument] + return wrapper ``` ## Decorator patterns @@ -76,12 +84,12 @@ def add_param[**P, R](func: Callable[P, R]) -> Callable[Concatenate[int, P], R]: def f(x: str, y: bytes) -> int: return 1 -# TODO: Should reveal `(int, /, x: str, y: bytes) -> int` -reveal_type(f) # revealed: (...) -> int +reveal_type(f) # revealed: (int, /, x: str, y: bytes) -> int reveal_type(f(1, "", b"")) # revealed: int -# TODO: This should be an error since `param` is a positional-only parameter +# error: [missing-argument] "No argument provided for required parameter 1" +# error: [unknown-argument] "Argument `param` does not match any known parameter" reveal_type(f(param=1, x="", y=b"")) # revealed: int ``` @@ -95,25 +103,18 @@ from typing import Callable, Concatenate def remove_param[**P, R](func: Callable[Concatenate[int, P], R]) -> Callable[P, R]: def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: return func(0, *args, **kwargs) - # TODO: no error expected here - return wrapper # error: [invalid-return-type] + return wrapper @remove_param def f(x: int, y: str, z: bytes) -> int: return 1 -# TODO: Should reveal `(y: str, z: bytes) -> int` -reveal_type(f) # revealed: [**P'return](**P'return) -> int +reveal_type(f) # revealed: (y: str, z: bytes) -> int -# TODO: Shouldn't be an error -# error: [missing-argument] reveal_type(f("", b"")) # revealed: int -# TODO: Shouldn't be an error -# error: [missing-argument] reveal_type(f(y="", z=b"")) # revealed: int -# TODO: missing-argument is an incorrect error, it should be [unknown-argument] since `x` is removed -# error: [missing-argument] "No argument provided for required parameter `*args`" +# error: [unknown-argument] "Argument `x` does not match any known parameter" reveal_type(f(x=1, y="", z=b"")) # revealed: int ``` @@ -133,13 +134,13 @@ def transform[**P, R](func: Callable[Concatenate[int, P], R]) -> Callable[Concat def f(x: int, y: int) -> int: return 1 -# TODO: Should reveal `(str, /, y: int) -> int` -reveal_type(f) # revealed: (...) -> int +reveal_type(f) # revealed: (str, /, y: int) -> int reveal_type(f("", 1)) # revealed: int reveal_type(f("", y=1)) # revealed: int -# TODO: This should be an error since `param` is a positional-only parameter +# error: [missing-argument] "No argument provided for required parameter 1" +# error: [unknown-argument] "Argument `param` does not match any known parameter" reveal_type(f(param="", y=1)) # revealed: int ``` @@ -157,13 +158,14 @@ def multi[**P, R](func: Callable[P, R]) -> Callable[Concatenate[int, str, P], R] def f(x: int) -> int: return 1 -# TODO: Should reveal `(int, str, /, x: int) -> int` -reveal_type(f) # revealed: (...) -> int +reveal_type(f) # revealed: (int, str, /, x: int) -> int reveal_type(f(1, "", 2)) # revealed: int reveal_type(f(1, "", x=2)) # revealed: int -# TODO: This should be an error since `a` and `b` are positional-only parameters +# error: [missing-argument] "No arguments provided for required parameters 1, 2" +# error: [unknown-argument] "Argument `a` does not match any known parameter" +# error: [unknown-argument] "Argument `b` does not match any known parameter" reveal_type(f(a=1, b="", x=2)) # revealed: int ``` @@ -177,14 +179,17 @@ type argument. ```py from typing import Concatenate -# error: [invalid-type-form] "`typing.Concatenate` requires at least two arguments when used in a type expression" -def _(x: Concatenate): ... +# error: [invalid-type-form] "`typing.Concatenate` is not allowed in this context in a type expression" +def invalid0(x: Concatenate): ... + +# error: [invalid-type-form] "`typing.Concatenate` is not allowed in this context in a type expression" +def invalid1(x: Concatenate[int]): ... -# TODO: Should be an error - Concatenate is not a valid standalone type -def invalid1(x: Concatenate[int, ...]) -> None: ... +# error: [invalid-type-form] "`typing.Concatenate` is not allowed in this context in a type expression" +def invalid2(x: Concatenate[int, ...]) -> None: ... -# TODO: Should be an error - Concatenate is not a valid standalone type -def invalid2() -> Concatenate[int, ...]: ... +# error: [invalid-type-form] "`typing.Concatenate` is not allowed in this context in a type expression" +def invalid3() -> Concatenate[int, ...]: ... ``` ### Too few arguments @@ -212,7 +217,7 @@ The final argument to `Concatenate` must be a `ParamSpec` or `...`. ```py from typing import Callable, Concatenate -# TODO: Should be an error - last arg is not ParamSpec or `...` +# error: [invalid-type-arguments] "The last argument to `typing.Concatenate` must be either `...` or a `ParamSpec` type variable: Got ``" def _(c: Callable[Concatenate[int, str], bool]): ... ``` @@ -224,16 +229,18 @@ If a `ParamSpec` appears in `Concatenate`, it must be the last element. from typing import Callable, Concatenate # error: [invalid-type-form] "Bare ParamSpec `P` is not valid in this context" +# error: [invalid-type-arguments] "The last argument to `typing.Concatenate` must be either `...` or a `ParamSpec` type variable: Got ``" def invalid1[**P](c: Callable[Concatenate[P, int], bool]): reveal_type(c) # revealed: (...) -> bool # error: [invalid-type-form] "Bare ParamSpec `P` is not valid in this context" def invalid2[**P](c: Callable[Concatenate[P, ...], bool]): - reveal_type(c) # revealed: (...) -> bool + # The bare `P` falls back to `Unknown` as a prefix parameter, while `...` is a valid + # gradual tail, resulting in `(Unknown, /, *args: Any, **kwargs: Any) -> bool`. + reveal_type(c) # revealed: (Unknown, /, *args: Any, **kwargs: Any) -> bool def valid[**P](c: Callable[Concatenate[int, P], bool]): - # TODO: Should reveal `(int, /, **P@valid) -> bool` - reveal_type(c) # revealed: (...) -> bool + reveal_type(c) # revealed: (int, /, *args: P@valid.args, **kwargs: P@valid.kwargs) -> bool ``` ### Nested `Concatenate` @@ -241,7 +248,7 @@ def valid[**P](c: Callable[Concatenate[int, P], bool]): ```py from typing import Callable, Concatenate -# TODO: This should be an error +# error: [invalid-type-form] "`typing.Concatenate` is not allowed in this context" def invalid[**P](c: Callable[Concatenate[Concatenate[int, ...], P], None]): pass ``` @@ -257,10 +264,9 @@ from typing import Callable, Concatenate def decorator[**P](func: Callable[Concatenate[int, P], bool]) -> Callable[P, bool]: def wrapper(*args: P.args, **kwargs: P.kwargs) -> bool: return func(0, *args, **kwargs) - # TODO: no error expected here - return wrapper # error: [invalid-return-type] + return wrapper -# TODO: This should be an error because the required `int` parameter is missing +# error: [invalid-argument-type] "Argument to function `decorator` is incorrect: Expected `(int, /, *args: Unknown, **kwargs: Unknown) -> bool`, found `def f0() -> bool`" @decorator def f0() -> bool: return True @@ -273,15 +279,13 @@ def f1(a: int) -> bool: def f2(a: int, b: str) -> bool: return True -# TODO: This call should be an error because the `str` is not assignable to `int` +# error: [invalid-argument-type] "Argument to function `decorator` is incorrect: Expected `(int, /, *args: Unknown, **kwargs: Unknown) -> bool`, found `def f3(a: str, b: int) -> bool`" @decorator def f3(a: str, b: int) -> bool: return True -# TODO: Should reveal `() -> bool` -reveal_type(f1) # revealed: [**P'return](**P'return) -> bool -# TODO: Should reveal `(b: str) -> bool` -reveal_type(f2) # revealed: [**P'return](**P'return) -> bool +reveal_type(f1) # revealed: () -> bool +reveal_type(f2) # revealed: (b: str) -> bool ``` ## Generic classes @@ -301,8 +305,7 @@ def my_handler(env: str, x: int, y: float) -> bool: return True m = Middleware(my_handler) -# TODO: Should reveal `Middleware[((x: int, y: float)), bool]` or similar -reveal_type(m) # revealed: Middleware[(...), bool] +reveal_type(m) # revealed: Middleware[(x: int, y: int | float), bool] ``` ### Specializing `ParamSpec` with `Concatenate` @@ -317,8 +320,7 @@ class Foo[**P1]: attr: Callable[P1, None] def with_paramspec[**P2](f: Foo[Concatenate[int, P2]]) -> None: - # TODO: Should reveal `Callable[Concatenate[int, P2], None]` - reveal_type(f.attr) # revealed: (...) -> None + reveal_type(f.attr) # revealed: (int, /, *args: P2@with_paramspec.args, **kwargs: P2@with_paramspec.kwargs) -> None ``` ## `Concatenate` in type aliases @@ -331,8 +333,7 @@ from typing import Callable, Concatenate type Foo[**P, R] = Callable[Concatenate[int, P], R] def _(f: Foo[[str], bool]) -> None: - # TODO: Should reveal `(int, str, /) -> bool` - reveal_type(f) # revealed: (...) -> bool + reveal_type(f) # revealed: (int, str, /) -> bool ``` ### Using `TypeAlias` @@ -347,8 +348,7 @@ R = TypeVar("R") Foo: TypeAlias = Callable[Concatenate[int, P], R] def _(f: Foo[[str], bool]) -> None: - # TODO: Should reveal `(int, str, /) -> bool` - reveal_type(f) # revealed: Unknown + reveal_type(f) # revealed: (int, str, /) -> bool ``` ## `Concatenate` with different parameter kinds @@ -361,14 +361,12 @@ from typing import Callable, Concatenate def decorator[**P](func: Callable[Concatenate[int, P], None]) -> Callable[P, None]: def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: func(0, *args, **kwargs) - # TODO: no error expected here - return wrapper # error: [invalid-return-type] + return wrapper @decorator def kwonly(x: int, *, key: str) -> None: ... -# TODO: Should reveal `(*, key: str) -> None` -reveal_type(kwonly) # revealed: [**P'return](**P'return) -> None +reveal_type(kwonly) # revealed: (*, key: str) -> None ``` ### Function with default values @@ -379,14 +377,12 @@ from typing import Callable, Concatenate def decorator[**P](func: Callable[Concatenate[int, P], None]) -> Callable[P, None]: def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: func(0, *args, **kwargs) - # TODO: no error expected here - return wrapper # error: [invalid-return-type] + return wrapper @decorator def defaults(x: int, y: str = "default", z: int = 0) -> None: ... -# TODO: Should reveal `(y: str = "default", z: int = 0) -> None` -reveal_type(defaults) # revealed: [**P'return](**P'return) -> None +reveal_type(defaults) # revealed: (y: str = "default", z: int = 0) -> None ``` ### Function with `*args` and `**kwargs` @@ -397,26 +393,25 @@ from typing import Callable, Concatenate def decorator[**P](func: Callable[Concatenate[int, P], None]) -> Callable[P, None]: def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: func(0, *args, **kwargs) - # TODO: no error expected here - return wrapper # error: [invalid-return-type] + return wrapper @decorator def variadic(x: int, *args: str, **kwargs: int) -> None: ... -# TODO: Should reveal `(*args: str, **kwargs: int) -> None` -reveal_type(variadic) # revealed: [**P'return](**P'return) -> None +reveal_type(variadic) # revealed: (*args: str, **kwargs: int) -> None +# error: [invalid-argument-type] "Argument to function `decorator` is incorrect: Expected `(int, /, *args: Unknown, **kwargs: Unknown) -> None`, found `def only_variadic(*args: str, **kwargs: int) -> None`" @decorator def only_variadic(*args: str, **kwargs: int) -> None: ... -# TODO: Should reveal `(*args: str, **kwargs: int) -> None` -reveal_type(only_variadic) # revealed: [**P'return](**P'return) -> None +reveal_type(only_variadic) # revealed: (...) -> None +# TODO: This should accept the callable and reveal `(*args: str, **kwargs: int) -> None`. +# error: [invalid-argument-type] @decorator def unpack_variadic(*args: *tuple[int, *tuple[str, ...]], **kwargs: int) -> None: ... -# TODO: should reveal `(*args: str, **kwargs: int) -> None` -reveal_type(unpack_variadic) # revealed: [**P'return](**P'return) -> None +reveal_type(unpack_variadic) # revealed: (...) -> None ``` ## `Concatenate` with `ParamSpec` in generic function calls @@ -429,15 +424,24 @@ from typing import Callable, Concatenate def foo[**P, R](func: Callable[Concatenate[int, P], R], *args: P.args, **kwargs: P.kwargs) -> R: return func(0, *args, **kwargs) -def test(x: str, y: str) -> bool: +def valid(x: int, y: str) -> bool: return True -reveal_type(foo(test, "", "")) # revealed: bool -reveal_type(foo(test, y="", x="")) # revealed: bool +def invalid(x: str, y: str) -> bool: + return True -# TODO: These calls should raise an error -reveal_type(foo(test, 1, "")) # revealed: bool -reveal_type(foo(test, "")) # revealed: bool +reveal_type(foo(valid, "")) # revealed: bool +reveal_type(foo(valid, y="")) # revealed: bool + +# error: [invalid-argument-type] "Argument to function `foo` is incorrect: Expected `str`, found `Literal[1]`" +# error: [too-many-positional-arguments] "Too many positional arguments to function `foo`: expected 1, got 2" +reveal_type(foo(valid, 1, "")) # revealed: bool + +# TODO: These should reveal `bool` +# error: [invalid-argument-type] "Argument to function `foo` is incorrect: Expected `(int, /, *args: Unknown, **kwargs: Unknown) -> Unknown`, found `def invalid(x: str, y: str) -> bool`" +reveal_type(foo(invalid, "")) # revealed: Unknown +# error: [invalid-argument-type] "Argument to function `foo` is incorrect: Expected `(int, /, *args: Unknown, **kwargs: Unknown) -> Unknown`, found `def invalid(x: str, y: str) -> bool`" +reveal_type(foo(invalid, 1, "")) # revealed: Unknown ``` ### Prepended type variable @@ -450,21 +454,26 @@ def decorator[T, R, **P](func: Callable[Concatenate[T, P], R], *args: P.args, ** return func(arg, *args, **kwargs) return wrapper -@decorator -def test1(x: str, y: str) -> bool: +def test1(x: int, y: str) -> bool: return True -# TODO: should reveal (str, /) -> bool -reveal_type(test1) # revealed: [T'return](T'return, /) -> bool -reveal_type(test1("")) # revealed: bool -# error: [too-many-positional-arguments] -reveal_type(test1("", "")) # revealed: bool +# error: [missing-argument] "No argument provided for required parameter `y` of function `decorator`" +reveal_type(decorator(test1)) # revealed: (int, /) -> bool +reveal_type(decorator(test1, "")) # revealed: (int, /) -> bool + +decorated_test1 = decorator(test1, y="") + +reveal_type(decorated_test1(1)) # revealed: bool +# error: [too-many-positional-arguments] "Too many positional arguments: expected 1, got 2" +reveal_type(decorated_test1(1, "")) # revealed: bool -# TODO: This should be an error since a keyword-only parameter cannot be assigned to positional-only -# parameter `T` +# error: [invalid-argument-type] "Argument to function `decorator` is incorrect: Expected `(Unknown, /, *args: Unknown, **kwargs: Unknown) -> Unknown`, found `def test2(*, x: int) -> bool`" @decorator def test2(*, x: int) -> bool: return True + +# TODO: This could reveal `(T, /, x: int) -> bool` using partial specialization +reveal_type(test2) # revealed: (Unknown, /) -> Unknown ``` ## `Concatenate` with overloaded functions @@ -478,8 +487,7 @@ from typing import Callable, Concatenate, overload def remove_param[**P, R](func: Callable[Concatenate[int, P], R]) -> Callable[P, R]: def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: return func(0, *args, **kwargs) - # TODO: no error expected here - return wrapper # error: [invalid-return-type] + return wrapper @overload def f1(x: int, y: str) -> str: ... @@ -490,7 +498,7 @@ def f1(x: int, y: str | int) -> str | int: return y # TODO: Should reveal `Overloaded[(y: str) -> str, (y: int) -> int]` -reveal_type(f1) # revealed: [**P'return](**P'return) -> str | int +reveal_type(f1) # revealed: (y: str) -> str | int ``` But, it's not possible to _add_ a parameter to an overloaded function using `Concatenate` because @@ -513,7 +521,7 @@ def f2(y: str | int) -> str | int: return y # TODO: Should this reveal `Overloaded[(int, /, y: str) -> str, (int, /, y: int) -> int]` ? -reveal_type(f2) # revealed: (...) -> str | int +reveal_type(f2) # revealed: Overload[(int, /, y: str) -> str | int, (int, /, y: int) -> str | int] ``` But, it's possible to add the additional parameter just to the overload signatures and not the @@ -529,7 +537,7 @@ def f3(y: str | int) -> str | int: return y # TODO: Should reveal `Overloaded[(int, /, y: str) -> str, (int, /, y: int) -> int]` -reveal_type(f3) # revealed: (...) -> str | int +reveal_type(f3) # revealed: Overload[(int, x: int, /, y: str) -> str | int, (int, x: int, /, y: int) -> str | int] ``` ## `Concatenate` with protocol classes @@ -550,13 +558,8 @@ class MyHandler: def __call__(self, value: int, name: str) -> bool: return True -# TODO: P should be inferred as [name: str], R as bool from MyHandler.__call__ -# TODO: These should not be errors -# TODO: Should reveal `bool` -# error: [invalid-argument-type] -reveal_type(process(MyHandler(), "hello")) # revealed: Unknown -# error: [invalid-argument-type] -reveal_type(process(MyHandler(), name="hello")) # revealed: Unknown +reveal_type(process(MyHandler(), "hello")) # revealed: bool +reveal_type(process(MyHandler(), name="hello")) # revealed: bool def use_callable[**P, R](func: Callable[Concatenate[int, P], R], handler: Handler[P, R]) -> None: ... ``` @@ -569,6 +572,99 @@ def use_callable[**P, R](func: Callable[Concatenate[int, P], R], handler: Handle from typing_extensions import Callable, Concatenate def _(c: Callable[Concatenate[int, str, ...], bool]): - # TODO: Should reveal `(int, str, ...) -> bool` - reveal_type(c) # revealed: (...) -> bool + reveal_type(c) # revealed: (int, str, /, *args: Any, **kwargs: Any) -> bool +``` + +## Assignability + +### Implicit concatenate to non-concatenated callable + +As per the [spec](https://typing.python.org/en/latest/spec/generics.html#id5): + +> A function declared as `def inner(a: A, b: B, *args: P.args, **kwargs: P.kwargs) -> R` has type +> `Callable[Concatenate[A, B, P], R]`. + +```py +from typing import Callable, Concatenate + +def decorator[**P1](func: Callable[P1, None]) -> Callable[P1, None]: + def wrapper(*args: P1.args, **kwargs: P1.kwargs) -> None: + func(*args, **kwargs) + + return wrapper + +@decorator +def f1[**P2](fn: Callable[P2, None], x: int, *args: P2.args, **kwargs: P2.kwargs) -> None: + pass + +reveal_type(f1) # revealed: [**P2](fn: (**P2) -> None, x: int, *args: P2.args, **kwargs: P2.kwargs) -> None + +def test(a: str) -> None: ... + +reveal_type(f1(test, 1, "")) # revealed: None + +# error: [missing-argument] "No argument provided for required parameter `x`" +# error: [missing-argument] "No argument provided for required parameter `a`" +reveal_type(f1(test)) # revealed: None + +# TODO: Currently, this is allowed but should probably raise a diagnostic given that +# `x` is now a positional-only parameter because of the Concatenate form but it might +# be too strict. +reveal_type(f1(fn=test, x=1, a="")) # revealed: None +``` + +### Non-concatenated to concatenated callable + +```py +from typing import Callable, Concatenate + +def decorator[**P1](func: Callable[Concatenate[int, P1], None]) -> Callable[P1, None]: + def wrapper(*args: P1.args, **kwargs: P1.kwargs) -> None: + pass + return wrapper + +def foo[**P2](f: Callable[P2, None]) -> None: + reveal_type(f) # revealed: (**P2@foo) -> None + # TODO: This should raise an invalid-argument-type error + reveal_type(decorator(f)) # revealed: (...) -> None +``` + +### Concatenate `ParamSpec` to concatenate `...` + +```py +from typing import Callable, Concatenate + +def gradual_generic[T](func: Callable[..., T]) -> T: + return func() + +def concat_paramspec[**P, T](fn: Callable[Concatenate[int, P], T]): + reveal_type(gradual_generic(fn)) # revealed: T@concat_paramspec +``` + +### Concatenate `...` to concatenate `ParamSpec` + +```py +from typing import Callable, Concatenate + +def concat_paramspec[**P, T](fn: Callable[Concatenate[int, P], T]) -> Callable[Concatenate[int, P], T]: + return fn + +def gradual_generic[T](func: Callable[..., T]): + # revealed: (int, /, *args: Any, **kwargs: Any) -> T@gradual_generic + reveal_type(concat_paramspec(func)) +``` + +### Type alias callable + +```py +from collections.abc import Callable +from typing import Concatenate + +type ConsumerType[**P1] = Callable[Concatenate[Callable[P1, None], P1], None] + +def consumer[**P2](x: Callable[P2, None], /, *args: P2.args, **kwargs: P2.kwargs) -> None: ... +def assign[**P3](x: Callable[P3, None], /, *args: P3.args, **kwargs: P3.kwargs) -> None: + # TODO: This shouldn't be an error + # error: [invalid-assignment] "Object of type `def consumer[**P2](x: (**P2) -> None, /, *args: P2.args, **kwargs: P2.kwargs) -> None` is not assignable to `ConsumerType[P3@assign]`" + wrapped: ConsumerType[P3] = consumer ``` diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/paramspec.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/paramspec.md index 166d0d6a2a4ff..d39756eae6501 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/paramspec.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/paramspec.md @@ -1181,8 +1181,5 @@ class Factory[**P](Protocol): def call_factory[**P](ctr: Factory[P], *args: P.args, **kwargs: P.kwargs) -> int: return ctr("", *args, **kwargs) -# TODO: This should be OK - P should be inferred as [] since my_factory only has `arg: str` -# which matches the prefix. Currently this is a false positive. -# error: [invalid-argument-type] call_factory(my_factory) ``` diff --git a/crates/ty_python_semantic/resources/mdtest/liskov.md b/crates/ty_python_semantic/resources/mdtest/liskov.md index 913bb8ba57943..133726c6ca68f 100644 --- a/crates/ty_python_semantic/resources/mdtest/liskov.md +++ b/crates/ty_python_semantic/resources/mdtest/liskov.md @@ -129,9 +129,6 @@ class Sub21(Super4): class Sub22(Super4): def method(self, **kwargs): ... # error: [invalid-method-override] - -class Sub23(Super4): - def method(self, x, *args, y, **kwargs): ... # error: [invalid-method-override] ``` ## The entire class hierarchy is checked diff --git a/crates/ty_python_semantic/resources/mdtest/pep613_type_aliases.md b/crates/ty_python_semantic/resources/mdtest/pep613_type_aliases.md index 03321aeddcdde..31d56a46ca65b 100644 --- a/crates/ty_python_semantic/resources/mdtest/pep613_type_aliases.md +++ b/crates/ty_python_semantic/resources/mdtest/pep613_type_aliases.md @@ -237,8 +237,7 @@ from typing_extensions import Callable, Concatenate, TypeAliasType MyAlias4: TypeAlias = Callable[Concatenate[dict[str, T], ...], list[U]] def _(c: MyAlias4[int, str]): - # TODO: should be (int, / ...) -> str - reveal_type(c) # revealed: Unknown + reveal_type(c) # revealed: (dict[str, int], /, *args: Any, **kwargs: Any) -> list[str] T = TypeVar("T") @@ -270,8 +269,7 @@ def _(x: ListOrDict[int]): MyAlias7: TypeAlias = Callable[Concatenate[T, ...], None] def _(c: MyAlias7[int]): - # TODO: should be (int, / ...) -> None - reveal_type(c) # revealed: Unknown + reveal_type(c) # revealed: (int, /, *args: Any, **kwargs: Any) -> None ``` ## Imported diff --git "a/crates/ty_python_semantic/resources/mdtest/snapshots/liskov.md_-_The_Liskov_Substitut\342\200\246_-_Method_parameters_(d98059266bcc1e13).snap" "b/crates/ty_python_semantic/resources/mdtest/snapshots/liskov.md_-_The_Liskov_Substitut\342\200\246_-_Method_parameters_(d98059266bcc1e13).snap" index 2d3718c4f93a7..c1e693699df4d 100644 --- "a/crates/ty_python_semantic/resources/mdtest/snapshots/liskov.md_-_The_Liskov_Substitut\342\200\246_-_Method_parameters_(d98059266bcc1e13).snap" +++ "b/crates/ty_python_semantic/resources/mdtest/snapshots/liskov.md_-_The_Liskov_Substitut\342\200\246_-_Method_parameters_(d98059266bcc1e13).snap" @@ -2,7 +2,6 @@ source: crates/ty_test/src/lib.rs expression: snapshot --- - --- mdtest name: liskov.md - The Liskov Substitution Principle - Method parameters mdtest path: crates/ty_python_semantic/resources/mdtest/liskov.md @@ -96,9 +95,6 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/liskov.md 81 | 82 | class Sub22(Super4): 83 | def method(self, **kwargs): ... # error: [invalid-method-override] -84 | -85 | class Sub23(Super4): -86 | def method(self, x, *args, y, **kwargs): ... # error: [invalid-method-override] ``` # Diagnostics @@ -294,29 +290,6 @@ error[invalid-method-override]: Invalid override of method `method` 82 | class Sub22(Super4): 83 | def method(self, **kwargs): ... # error: [invalid-method-override] | ^^^^^^^^^^^^^^^^^^^^^^ Definition is incompatible with `Super4.method` -84 | -85 | class Sub23(Super4): - | - ::: src/mdtest_snippet.pyi:74:9 - | -73 | class Super4: -74 | def method(self, *args: int, **kwargs: str): ... - | --------------------------------------- `Super4.method` defined here -75 | -76 | class Sub20(Super4): - | -info: This violates the Liskov Substitution Principle -info: rule `invalid-method-override` is enabled by default - -``` - -``` -error[invalid-method-override]: Invalid override of method `method` - --> src/mdtest_snippet.pyi:86:9 - | -85 | class Sub23(Super4): -86 | def method(self, x, *args, y, **kwargs): ... # error: [invalid-method-override] - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Definition is incompatible with `Super4.method` | ::: src/mdtest_snippet.pyi:74:9 | diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md b/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md index 1fdecd5a72b46..24059a40847f7 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md @@ -1499,6 +1499,22 @@ static_assert(is_assignable_to(Callable[..., None], Callable[Concatenate[int, .. static_assert(is_assignable_to(Callable[..., None], Callable[Concatenate[int, str, ...], None])) ``` +### Assignable from bottom callable + +```py +from ty_extensions import static_assert, is_assignable_to, RegularCallableTypeOf +from typing import Callable, Concatenate, Never + +def bottom(*args: object, **kwargs: object) -> Never: + raise NotImplementedError + +static_assert(is_assignable_to(RegularCallableTypeOf[bottom], Callable[Concatenate[int, ...], None])) +static_assert(is_assignable_to(RegularCallableTypeOf[bottom], Callable[Concatenate[int, str, ...], None])) + +static_assert(not is_assignable_to(Callable[Concatenate[int, ...], None], RegularCallableTypeOf[bottom])) +static_assert(not is_assignable_to(Callable[Concatenate[int, str, ...], None], RegularCallableTypeOf[bottom])) +``` + ### Contravariance of parameters Callable parameters are contravariant: a callable accepting a wider type (`A`) is assignable to one @@ -1512,8 +1528,6 @@ class Parent: ... class Child(Parent): ... static_assert(is_assignable_to(Callable[Concatenate[Parent, ...], None], Callable[Concatenate[Child, ...], None])) -# TODO: should not be assignable (`Parent` is not assignable to `Child`) -# error: [static-assert-error] static_assert(not is_assignable_to(Callable[Concatenate[Child, ...], None], Callable[Concatenate[Parent, ...], None])) ``` @@ -1526,11 +1540,7 @@ from typing import Callable, Concatenate, final class A: ... class B: ... -# TODO: should not be assignable (`A` and `B` are disjoint) -# error: [static-assert-error] static_assert(not is_assignable_to(Callable[Concatenate[A, ...], None], Callable[Concatenate[B, ...], None])) -# TODO: should not be assignable -# error: [static-assert-error] static_assert(not is_assignable_to(Callable[Concatenate[B, ...], None], Callable[Concatenate[A, ...], None])) ``` @@ -1572,6 +1582,88 @@ def with_paramspec[**P](_: Callable[P, None]): static_assert(is_assignable_to(Callable[..., None], Callable[Concatenate[int, P], None])) ``` +### Gradual `Concatenate` with regular function + +```py +from ty_extensions import RegularCallableTypeOf, static_assert, is_assignable_to +from typing import Callable, Concatenate + +class A: ... +class B: ... +class C: ... +``` + +A `Concatenate` form that ends with `...` means that all of the parameters before `...` are +positional-only. + +```py +def positional_only(a: A, b: B, /) -> None: ... +def with_default(a: A, b: B = B(), /) -> None: ... + +static_assert(is_assignable_to(RegularCallableTypeOf[positional_only], Callable[Concatenate[A, ...], None])) +static_assert(is_assignable_to(Callable[Concatenate[A, ...], None], RegularCallableTypeOf[positional_only])) + +static_assert(is_assignable_to(RegularCallableTypeOf[positional_only], Callable[Concatenate[A, B, ...], None])) +static_assert(is_assignable_to(Callable[Concatenate[A, B, ...], None], RegularCallableTypeOf[positional_only])) + +# Concatenate has an additional required positional-only parameter which isn't present in the +# function definition, so they aren't assignable. +static_assert(not is_assignable_to(RegularCallableTypeOf[positional_only], Callable[Concatenate[A, B, C, ...], None])) +static_assert(not is_assignable_to(Callable[Concatenate[A, B, C, ...], None], RegularCallableTypeOf[positional_only])) + +static_assert(is_assignable_to(RegularCallableTypeOf[with_default], Callable[Concatenate[A, ...], None])) +static_assert(is_assignable_to(Callable[Concatenate[A, ...], None], RegularCallableTypeOf[with_default])) + +# For an optional parameter (with default value), it is assignable to a non-optional parameter, but +# the reverse is not true. +static_assert(is_assignable_to(RegularCallableTypeOf[with_default], Callable[Concatenate[A, B, ...], None])) +static_assert(not is_assignable_to(Callable[Concatenate[A, B, ...], None], RegularCallableTypeOf[with_default])) +``` + +But, a regular callable can contain a positional-or-keyword parameter which is sometimes compatible +with the `Concatenate` with gradual form. + +```py +def positional_or_keyword(a: A, b: B) -> None: ... + +static_assert(is_assignable_to(RegularCallableTypeOf[positional_or_keyword], Callable[Concatenate[A, ...], None])) +static_assert(not is_assignable_to(Callable[Concatenate[A, ...], None], RegularCallableTypeOf[positional_or_keyword])) + +static_assert(is_assignable_to(RegularCallableTypeOf[positional_or_keyword], Callable[Concatenate[A, B, ...], None])) +static_assert(not is_assignable_to(Callable[Concatenate[A, B, ...], None], RegularCallableTypeOf[positional_or_keyword])) +``` + +For variadic parameter, it is assignable only when the type of the variadic parameter is compatible +with the type of all the prefix parameters in the `Concatenate` form. + +```py +def variadic_a(*args: A) -> None: ... +def variadic_b(*args: B) -> None: ... + +static_assert(is_assignable_to(RegularCallableTypeOf[variadic_a], Callable[Concatenate[A, ...], None])) +static_assert(is_assignable_to(RegularCallableTypeOf[variadic_a], Callable[Concatenate[A, A, ...], None])) +static_assert(not is_assignable_to(Callable[Concatenate[A, ...], None], RegularCallableTypeOf[variadic_a])) + +static_assert(not is_assignable_to(RegularCallableTypeOf[variadic_a], Callable[Concatenate[A, B, ...], None])) +static_assert(not is_assignable_to(Callable[Concatenate[A, B, ...], None], RegularCallableTypeOf[variadic_a])) + +static_assert(not is_assignable_to(RegularCallableTypeOf[variadic_b], Callable[Concatenate[A, ...], None])) +static_assert(not is_assignable_to(Callable[Concatenate[A, ...], None], RegularCallableTypeOf[variadic_b])) +``` + +For all the other parameter kinds, it is not assignable in either direction. + +```py +def keyword_only(*, a: A, b: B) -> None: ... +def keyword_variadic(**kwargs: A) -> None: ... + +static_assert(not is_assignable_to(RegularCallableTypeOf[keyword_only], Callable[Concatenate[A, B, ...], None])) +static_assert(not is_assignable_to(Callable[Concatenate[A, B, ...], None], RegularCallableTypeOf[keyword_only])) + +static_assert(not is_assignable_to(RegularCallableTypeOf[keyword_variadic], Callable[Concatenate[A, ...], None])) +static_assert(not is_assignable_to(Callable[Concatenate[A, ...], None], RegularCallableTypeOf[keyword_variadic])) +``` + [gradual form]: https://typing.python.org/en/latest/spec/glossary.html#term-gradual-form [gradual tuple]: https://typing.python.org/en/latest/spec/tuples.html#tuple-type-form [typing documentation]: https://typing.python.org/en/latest/spec/concepts.html#the-assignable-to-or-consistent-subtyping-relation diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 91bfabf8fc5ad..2d5b01ad71143 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -6598,6 +6598,9 @@ enum InvalidTypeExpression<'db> { /// Same for `typing.TypeAlias`, anywhere except for as the sole annotation on an annotated /// assignment TypeAlias, + /// Same for `typing.Concatenate`, anywhere except for as the first parameter of a `Callable` + /// type expression + Concatenate, /// Type qualifiers are always invalid in *type expressions*, /// but these ones are okay with 0 arguments in *annotation expressions* TypeQualifier(TypeQualifier), @@ -6695,6 +6698,9 @@ impl<'db> InvalidTypeExpression<'db> { "Bare ParamSpec `{}` is not valid in this context in a type expression", paramspec.name(self.db) ), + InvalidTypeExpression::Concatenate => f.write_str( + "`typing.Concatenate` is not allowed in this context in a type expression", + ), } } } @@ -6767,6 +6773,10 @@ impl<'db> InvalidTypeExpression<'db> { diagnostic.info(" - as the default type for another ParamSpec"); diagnostic.info(" - as part of a type parameter list when defining a generic class"); diagnostic.info(" - or as part of an argument list when specializing a generic class"); + } else if matches!(self, InvalidTypeExpression::Concatenate) { + diagnostic.info("`typing.Concatenate` is only valid:"); + diagnostic.info(" - as the first argument to `typing.Callable`"); + diagnostic.info(" - as a type argument for a `ParamSpec` parameter"); } } } diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 1926b2eec17f6..e07c5113fc59b 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -3635,7 +3635,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { // For ParamSpec parameters, both *args and **kwargs are required since we don't know // what arguments the underlying callable expects. For all other callables, variadic // and keyword_variadic parameters are optional. - let paramspec_parameters = self.parameters.as_paramspec().is_some(); + let paramspec = self.parameters.as_paramspec(); let mut missing = vec![]; for ( @@ -3651,7 +3651,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { continue; } let param = &self.parameters[index]; - if !paramspec_parameters && (param.is_variadic() || param.is_keyword_variadic()) + if paramspec.is_none() && (param.is_variadic() || param.is_keyword_variadic()) || param.default_type().is_some() { // variadic/keywords and defaulted arguments are not required @@ -3664,7 +3664,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { if !missing.is_empty() { self.errors.push(BindingError::MissingArguments { parameters: ParameterContexts(missing), - paramspec: self.parameters.as_paramspec(), + paramspec, }); } @@ -4016,10 +4016,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { } fn check_argument_types(&mut self, constraints: &ConstraintSetBuilder<'db>) { - let paramspec = self - .signature - .parameters() - .find_paramspec_from_args_kwargs(self.db); + let paramspec = self.signature.parameters().as_paramspec_with_prefix(); for (argument_index, adjusted_argument_index, argument, argument_types) in self.enumerate_argument_types() diff --git a/crates/ty_python_semantic/src/types/diagnostic.rs b/crates/ty_python_semantic/src/types/diagnostic.rs index e311cf4f49cfa..5d5b21734d811 100644 --- a/crates/ty_python_semantic/src/types/diagnostic.rs +++ b/crates/ty_python_semantic/src/types/diagnostic.rs @@ -6041,3 +6041,20 @@ pub(super) fn hint_if_stdlib_attribute_exists_on_other_versions( // TODO: determine what platform they need to be on add_inferred_python_version_hint_to_diagnostic(db, &mut diagnostic, action); } + +pub(super) fn report_invalid_concatenate_last_arg<'db>( + context: &InferContext<'db, '_>, + last_arg: &ast::Expr, + last_arg_type: Type<'db>, +) { + if let Some(builder) = context.report_lint(&INVALID_TYPE_ARGUMENTS, last_arg) { + let mut diag = builder.into_diagnostic( + "The last argument to `typing.Concatenate` must be either `...` or a `ParamSpec` \ + type variable", + ); + diag.set_primary_message(format_args!( + "Got `{}`", + last_arg_type.display(context.db()) + )); + } +} diff --git a/crates/ty_python_semantic/src/types/display.rs b/crates/ty_python_semantic/src/types/display.rs index 5716f1437cf0b..677074d779918 100644 --- a/crates/ty_python_semantic/src/types/display.rs +++ b/crates/ty_python_semantic/src/types/display.rs @@ -2136,78 +2136,112 @@ struct DisplayParameters<'a, 'db> { impl<'db> FmtDetailed<'db> for DisplayParameters<'_, 'db> { fn fmt_detailed(&self, f: &mut TypeWriter<'_, '_, 'db>) -> fmt::Result { - // For `ParamSpec` kind, the parameters still contain `*args` and `**kwargs`, but we - // display them as `**P` instead, so avoid multiline in that case. - // TODO: This might change once we support `Concatenate` - let multiline = self.settings.multiline - && self.parameters.len() > 1 - && !matches!( - self.parameters.kind(), - ParametersKind::Gradual | ParametersKind::ParamSpec(_) - ); - // Opening parenthesis - f.write_char('(')?; - if multiline { - f.write_str("\n ")?; - } - match self.parameters.kind() { - ParametersKind::Standard => { - let mut star_added = false; - let mut needs_slash = false; - let mut first = true; - let arg_separator = if multiline { ",\n " } else { ", " }; - - for parameter in self.parameters.as_slice() { - // Handle special separators - if !star_added && parameter.is_keyword_only() { - if !first { - f.write_str(arg_separator)?; - } - f.write_char('*')?; - star_added = true; - first = false; - } - if parameter.is_positional_only() { - needs_slash = true; - } else if needs_slash { - if !first { - f.write_str(arg_separator)?; - } - f.write_char('/')?; - needs_slash = false; - first = false; - } - - // Add comma before parameter if not first + fn display_parameters<'db>( + display: &DisplayParameters<'_, 'db>, + f: &mut TypeWriter<'_, '_, 'db>, + parameters: &[Parameter<'db>], + arg_separator: &str, + ) -> fmt::Result { + let mut star_added = false; + let mut needs_slash = false; + let mut first = true; + + for parameter in parameters { + // Handle special separators + if !star_added && parameter.is_keyword_only() { if !first { f.write_str(arg_separator)?; } - - // Write parameter with range tracking - let param_name = parameter - .display_name() - .map(|name| name.to_string()) - .unwrap_or_default(); - parameter - .display_with(self.db, self.settings.singleline()) - .fmt_detailed(&mut f.with_detail(TypeDetail::Parameter(param_name)))?; - + f.write_char('*')?; + star_added = true; first = false; } - - if needs_slash { + if parameter.is_positional_only() { + needs_slash = true; + } else if needs_slash { if !first { f.write_str(arg_separator)?; } f.write_char('/')?; + needs_slash = false; + first = false; + } + + // Add comma before parameter if not first + if !first { + f.write_str(arg_separator)?; + } + + // Write parameter with range tracking + let param_name = parameter + .display_name() + .map(|name| name.to_string()) + .unwrap_or_default(); + parameter + .display_with(display.db, display.settings.singleline()) + .fmt_detailed(&mut f.with_detail(TypeDetail::Parameter(param_name)))?; + + first = false; + } + + if needs_slash { + if !first { + f.write_str(arg_separator)?; } + f.write_char('/')?; } - ParametersKind::Gradual | ParametersKind::Top => { - // We represent gradual form as `...` in the signature, internally the parameters still - // contain `(*args, **kwargs)` parameters. (Top parameters are displayed the same - // as gradual parameters, we just wrap the entire signature in `Top[]`.) + + Ok(()) + } + + // For `ParamSpec` kind, the parameters still contain `*args` and `**kwargs`, but we + // display them as `**P` instead, so avoid multiline in that case. + // For `Concatenate` kind, use multiline only if there are more than 1 prefix parameters. + // For `Gradual` kind without prefix params (len <= 2), display as `...`. + let multiline = if self.settings.multiline { + match self.parameters.kind() { + ParametersKind::Standard => self.parameters.len() > 1, + ParametersKind::Gradual | ParametersKind::Top | ParametersKind::ParamSpec(_) => { + false + } + ParametersKind::Concatenate(_) => { + // The tail already represents 2 parameters. Additionally, there should be more + // than 1 prefix parameters to use multiline, so the limit becomes 3. + self.parameters.len() > 3 + } + } + } else { + false + }; + + // Opening parenthesis + f.write_char('(')?; + if multiline { + f.write_str("\n ")?; + } + + let arg_separator = if multiline { ",\n " } else { ", " }; + + match self.parameters.kind() { + ParametersKind::Standard | ParametersKind::Concatenate(_) => { + display_parameters(self, f, self.parameters.as_slice(), arg_separator)?; + } + ParametersKind::Top => { + // TODO: Remove `...`, always display all the parameters + // Top parameters are displayed the same as gradual parameters, we just wrap the + // entire signature in `Top[]` + f.write_str("...")?; + } + ParametersKind::Gradual if self.parameters.len() == 2 => { + // TODO: Remove `...`, always display all the parameters + // For gradual parameters with only `(*args, **kwargs)`, display as `...` for + // simplicity ... f.write_str("...")?; } + ParametersKind::Gradual => { + // ... but otherwise display all the parameters as normal. + display_parameters(self, f, self.parameters.as_slice(), arg_separator)?; + } ParametersKind::ParamSpec(typevar) => { write!(f, "**{}", typevar.name(self.db))?; let binding_context = typevar.binding_context(self.db); @@ -2219,9 +2253,11 @@ impl<'db> FmtDetailed<'db> for DisplayParameters<'_, 'db> { } } } + if multiline { f.write_char('\n')?; } + // Closing parenthesis f.write_char(')') } diff --git a/crates/ty_python_semantic/src/types/infer/builder/subscript.rs b/crates/ty_python_semantic/src/types/infer/builder/subscript.rs index d07f4ef240fae..a3f177160f018 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/subscript.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/subscript.rs @@ -7,10 +7,10 @@ use ty_module_resolver::file_to_module; use super::TypeInferenceBuilder; use crate::place::{DefinedPlace, Definedness, Place}; -use crate::semantic_index::SemanticIndex; use crate::semantic_index::definition::Definition; use crate::semantic_index::place::{PlaceExpr, PlaceExprRef}; use crate::semantic_index::scope::FileScopeId; +use crate::semantic_index::{SemanticIndex, semantic_index}; use crate::types::call::CallErrorKind; use crate::types::call::bind::CallableDescription; use crate::types::constraints::ConstraintSetBuilder; @@ -18,19 +18,20 @@ use crate::types::diagnostic::{ CALL_NON_CALLABLE, INVALID_ARGUMENT_TYPE, INVALID_ASSIGNMENT, INVALID_KEY, INVALID_TYPE_ARGUMENTS, INVALID_TYPE_FORM, NOT_SUBSCRIPTABLE, POSSIBLY_MISSING_IMPLICIT_CALL, TypedDictDeleteErrorKind, report_cannot_delete_typed_dict_key, - report_invalid_arguments_to_annotated, report_invalid_key_on_typed_dict, - report_not_subscriptable, + report_invalid_arguments_to_annotated, report_invalid_concatenate_last_arg, + report_invalid_key_on_typed_dict, report_not_subscriptable, }; use crate::types::generics::{GenericContext, InferableTypeVars, bind_typevar}; use crate::types::infer::InferenceFlags; use crate::types::infer::builder::{ArgExpr, ArgumentsIter, MultiInferenceGuard}; +use crate::types::signatures::ConcatenateTail; use crate::types::special_form::AliasSpec; use crate::types::subscript::{LegacyGenericOrigin, SubscriptError, SubscriptErrorKind}; use crate::types::tuple::{Tuple, TupleType}; use crate::types::typed_dict::{TypedDictAssignmentKind, TypedDictKeyAssignment}; use crate::types::{ - BoundTypeVarInstance, CallArguments, CallDunderError, CallableType, DynamicType, InternedType, - KnownClass, KnownInstanceType, LintDiagnosticGuard, Parameter, Parameters, SpecialFormType, + BoundTypeVarInstance, CallArguments, CallDunderError, DynamicType, InternedType, KnownClass, + KnownInstanceType, LintDiagnosticGuard, Parameter, Parameters, SpecialFormType, StaticClassLiteral, Type, TypeAliasType, TypeContext, TypeVarBoundOrConstraints, UnionType, UnionTypeInstance, any_over_type, todo_type, }; @@ -303,51 +304,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { )); } SpecialFormType::Callable => { - let arguments = if let ast::Expr::Tuple(tuple) = &*subscript.slice { - &*tuple.elts - } else { - std::slice::from_ref(&*subscript.slice) - }; - - // TODO: Remove this once we support Concatenate properly. This is necessary - // to avoid a lot of false positives downstream, because we can't represent the typevar- - // specialized `Callable` types yet. - if let [first_arg, second_arg] = arguments - && first_arg.is_subscript_expr() - { - let first_arg_ty = self.infer_expression(first_arg, TypeContext::default()); - if let Type::Dynamic(DynamicType::UnknownGeneric(generic_context)) = - first_arg_ty - { - let mut variables = - generic_context.variables(db).collect::>(); - - let return_ty = - self.infer_expression(second_arg, TypeContext::default()); - return_ty.bind_and_find_all_legacy_typevars( - db, - self.typevar_binding_context, - &mut variables, - ); - - let generic_context = - GenericContext::from_typevar_instances(db, variables); - return Type::Dynamic(DynamicType::UnknownGeneric(generic_context)); - } - - if let Some(builder) = - self.context.report_lint(&INVALID_TYPE_FORM, subscript) - { - builder.into_diagnostic(format_args!( - "The first argument to `Callable` must be either a list of types, \ - ParamSpec, Concatenate, or `...`", - )); - } - return Type::KnownInstance(KnownInstanceType::Callable( - CallableType::unknown(db), - )); - } - let callable = self .infer_callable_type(subscript) .as_callable() @@ -886,8 +842,99 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { return Ok(Type::paramspec_value_callable(db, parameters)); } - ast::Expr::Subscript(_) => { - // TODO: Support `Concatenate[...]` + ast::Expr::Subscript(subscript) => { + let value_ty = self.infer_expression(&subscript.value, TypeContext::default()); + + if matches!(value_ty, Type::SpecialForm(SpecialFormType::Concatenate)) { + let arguments_slice = &*subscript.slice; + let arguments = if let ast::Expr::Tuple(tuple) = arguments_slice { + &*tuple.elts + } else { + std::slice::from_ref(arguments_slice) + }; + + let num_arguments = arguments.len(); + if num_arguments < 2 { + for argument in arguments { + self.infer_type_expression(argument); + } + if arguments_slice.is_tuple_expr() { + self.store_expression_type(arguments_slice, Type::unknown()); + } + return Ok(Type::paramspec_value_callable( + db, + Parameters::gradual_form(), + )); + } + + // SAFETY: `arguments` is guaranteed to have at least two elements from the + // length check above. + let (last_arg, prefix_args) = arguments.split_last().unwrap(); + + let prefix_params = prefix_args + .iter() + .map(|arg| { + Parameter::positional_only(None) + .with_annotated_type(self.infer_type_expression(arg)) + }) + .collect(); + + let parameters = match last_arg { + ast::Expr::EllipsisLiteral(_) => Some(Parameters::concatenate( + self.db(), + prefix_params, + ConcatenateTail::Gradual, + )), + ast::Expr::Name(name) if !name.is_invalid() => { + let name_ty = self.infer_name_load(name); + if let Type::KnownInstance(KnownInstanceType::TypeVar(typevar)) = + name_ty + && typevar.is_paramspec(self.db()) + { + let index = semantic_index(self.db(), self.scope().file(self.db())); + bind_typevar( + self.db(), + index, + self.scope().file_scope_id(self.db()), + self.typevar_binding_context, + typevar, + ) + .map(|bound_typevar| { + Parameters::concatenate( + self.db(), + prefix_params, + ConcatenateTail::ParamSpec(bound_typevar), + ) + }) + } else { + report_invalid_concatenate_last_arg( + &self.context, + last_arg, + name_ty, + ); + None + } + } + _ => { + let ty = self.infer_type_expression(last_arg); + report_invalid_concatenate_last_arg(&self.context, last_arg, ty); + None + } + }; + + if arguments_slice.is_tuple_expr() { + // TODO: What type to store for the argument slice in `Concatenate` because + // `Parameters` is not a `Type` variant? + self.store_expression_type(arguments_slice, Type::unknown()); + } + + return Ok(Type::paramspec_value_callable( + db, + parameters.unwrap_or_else(Parameters::unknown), + )); + } + + // Non-Concatenate subscript: fall back to todo return Ok(Type::paramspec_value_callable(db, Parameters::todo())); } diff --git a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs index b3f0484dc79e7..65a15cc4aa0a4 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs @@ -3,17 +3,20 @@ use ruff_python_ast::{self as ast, PythonVersion}; use super::{DeferredExpressionState, TypeInferenceBuilder}; use crate::semantic_index::scope::ScopeKind; +use crate::semantic_index::semantic_index; use crate::types::diagnostic::{ self, INVALID_TYPE_FORM, NOT_SUBSCRIPTABLE, UNBOUND_TYPE_VARIABLE, UNSUPPORTED_OPERATOR, note_py_version_too_old_for_pep_604, report_invalid_argument_number_to_special_form, - report_invalid_arguments_to_callable, + report_invalid_arguments_to_callable, report_invalid_concatenate_last_arg, }; +use crate::types::generics::bind_typevar; use crate::types::infer::InferenceFlags; use crate::types::infer::builder::{InnerExpressionInferenceState, MultiInferenceState}; -use crate::types::signatures::Signature; +use crate::types::signatures::{ConcatenateTail, Signature}; use crate::types::special_form::{AliasSpec, LegacyStdlibAlias}; use crate::types::string_annotation::parse_string_annotation; use crate::types::tuple::{TupleSpecBuilder, TupleType}; + use crate::types::{ BindingContext, CallableType, DynamicType, GenericContext, IntersectionBuilder, KnownClass, KnownInstanceType, LintDiagnosticGuard, LiteralValueTypeKind, Parameter, Parameters, @@ -1876,6 +1879,15 @@ impl<'db> TypeInferenceBuilder<'db, '_> { ), }, SpecialFormType::Concatenate => { + if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) { + let mut diag = builder.into_diagnostic(format_args!( + "`typing.Concatenate` is not allowed in this context in a type expression", + )); + diag.info("`typing.Concatenate` is only valid:"); + diag.info(" - as the first argument to `typing.Callable`"); + diag.info(" - as a type argument for a `ParamSpec` parameter"); + } + let arguments = if let ast::Expr::Tuple(tuple) = arguments_slice { &*tuple.elts } else { @@ -1901,21 +1913,11 @@ impl<'db> TypeInferenceBuilder<'db, '_> { } } - let num_arguments = arguments.len(); - let inferred_type = if num_arguments < 2 { - if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) { - builder.into_diagnostic(format_args!( - "Special form `{special_form}` expected at least 2 parameters but got {num_arguments}", - )); - } - Type::unknown() - } else { - todo_type!("`Concatenate[]` special form") - }; if arguments_slice.is_tuple_expr() { - self.store_expression_type(arguments_slice, inferred_type); + self.store_expression_type(arguments_slice, Type::unknown()); } - inferred_type + + Type::unknown() } SpecialFormType::Unpack => { let inner_ty = self.infer_type_expression(arguments_slice); @@ -2184,8 +2186,62 @@ impl<'db> TypeInferenceBuilder<'db, '_> { } ast::Expr::Subscript(subscript) => { let value_ty = self.infer_expression(&subscript.value, TypeContext::default()); + + if matches!(value_ty, Type::SpecialForm(SpecialFormType::Concatenate)) { + let arguments_slice = &*subscript.slice; + let arguments = if let ast::Expr::Tuple(tuple) = arguments_slice { + &*tuple.elts + } else { + std::slice::from_ref(arguments_slice) + }; + + let num_arguments = arguments.len(); + if num_arguments < 2 { + for argument in arguments { + self.infer_type_expression(argument); + } + if let Some(builder) = + self.context.report_lint(&INVALID_TYPE_FORM, subscript) + { + builder.into_diagnostic(format_args!( + "Special form `typing.Concatenate` expected at least 2 parameters \ + but got {num_arguments}", + )); + } + if arguments_slice.is_tuple_expr() { + self.store_expression_type(arguments_slice, Type::unknown()); + } + return Some(Parameters::gradual_form()); + } + + // SAFETY: `arguments` is guaranteed to have at least two elements from the + // length check above. + let (last_arg, prefix_args) = arguments.split_last().unwrap(); + + let prefix_params = prefix_args + .iter() + .map(|arg| { + Parameter::positional_only(None) + .with_annotated_type(self.infer_type_expression(arg)) + }) + .collect(); + + let parameters = self + .infer_concatenate_tail(last_arg) + .map(|tail| Parameters::concatenate(self.db(), prefix_params, tail)); + + if arguments_slice.is_tuple_expr() { + // TODO: What type to store for the argument slice in `Concatenate` because + // `Parameters` is not a `Type` variant? + self.store_expression_type(arguments_slice, Type::unknown()); + } + + return Some(parameters.unwrap_or_else(Parameters::unknown)); + } + self.infer_subscript_type_expression(subscript, value_ty); - // TODO: Support `Concatenate[...]` + + // Non-Concatenate subscript (e.g. Unpack): fall back to todo return Some(Parameters::todo()); } ast::Expr::Name(_) | ast::Expr::Attribute(_) => { @@ -2246,6 +2302,64 @@ impl<'db> TypeInferenceBuilder<'db, '_> { None } + fn infer_concatenate_tail(&mut self, expr: &ast::Expr) -> Option> { + match expr { + ast::Expr::EllipsisLiteral(_) => Some(ConcatenateTail::Gradual), + ast::Expr::Name(name) if !name.is_invalid() => { + let name_ty = self.infer_name_load(name); + if let Type::KnownInstance(KnownInstanceType::TypeVar(typevar)) = name_ty + && typevar.is_paramspec(self.db()) + { + let index = semantic_index(self.db(), self.scope().file(self.db())); + bind_typevar( + self.db(), + index, + self.scope().file_scope_id(self.db()), + self.typevar_binding_context, + typevar, + ) + .map(ConcatenateTail::ParamSpec) + } else { + report_invalid_concatenate_last_arg(&self.context, expr, name_ty); + None + } + } + ast::Expr::StringLiteral(string) => { + if let Some(parsed) = parse_string_annotation(&self.context, string) { + self.string_annotations + .insert(ruff_python_ast::ExprRef::StringLiteral(string).into()); + let node_key = self.enclosing_node_key(string.into()); + + let previous_deferred_state = std::mem::replace( + &mut self.deferred_state, + DeferredExpressionState::InStringAnnotation(node_key), + ); + let result = matches!( + parsed.expr(), + ast::Expr::Name(_) | ast::Expr::EllipsisLiteral(_) + ) + .then(|| self.infer_concatenate_tail(parsed.expr())); + self.deferred_state = previous_deferred_state; + + if let Some(result) = result { + result + } else { + report_invalid_concatenate_last_arg(&self.context, expr, Type::unknown()); + None + } + } else { + report_invalid_concatenate_last_arg(&self.context, expr, Type::unknown()); + None + } + } + _ => { + let ty = self.infer_type_expression(expr); + report_invalid_concatenate_last_arg(&self.context, expr, ty); + None + } + } + } + /// Checks if the inferred type is an unbound type variable and reports a diagnostic if so. /// /// Returns `Unknown` as a fallback if the type variable is unbound, otherwise returns the diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 4b767029ccc67..d3ef8fe8944e9 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -224,8 +224,7 @@ impl<'db> CallableSignature<'db> { type_mapping { Self::from_overloads(self.overloads.iter().flat_map(|signature| { - if let Some((prefix, paramspec)) = - signature.parameters.find_paramspec_from_args_kwargs(db) + if let Some((prefix, paramspec)) = signature.parameters.as_paramspec_with_prefix() && let Some(value) = specialization.get(db, paramspec) && let Some(result) = try_apply_type_mapping_for_paramspec( db, @@ -1022,8 +1021,14 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { ([source_signature], [target_signature]) => { // Base case: both callable types contain a single signature. if self.relation.is_constraint_set_assignability() - && (source_signature.parameters.as_paramspec().is_some() - || target_signature.parameters.as_paramspec().is_some()) + && (source_signature + .parameters + .as_paramspec_with_prefix() + .is_some() + || target_signature + .parameters + .as_paramspec_with_prefix() + .is_some()) { self.check_signature_pair_inner(db, source_signature, target_signature) } else { @@ -1220,15 +1225,17 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { let return_type_checks = check_types(source.return_ty, target.return_ty); if self.relation.is_constraint_set_assignability() { - let source_as_paramspec = source.parameters.as_paramspec(); - let target_as_paramspec = target.parameters.as_paramspec(); + let source_paramspec = source.parameters.as_paramspec_with_prefix(); + let target_paramspec = target.parameters.as_paramspec_with_prefix(); // If either signature is a ParamSpec, the constraint set should bind the ParamSpec to // the other signature before the return-type and gradual/top fast paths can return // early. We also need to compare the return types here so a return-type mismatch still // preserves the inferred ParamSpec binding. - match (source_as_paramspec, target_as_paramspec) { - (Some(source_bound_typevar), Some(target_bound_typevar)) => { + match (source_paramspec, target_paramspec) { + // self: `P` + // other: `P` + (Some(([], source_bound_typevar)), Some(([], target_bound_typevar))) => { let param_spec_matches = ConstraintSet::constrain_typevar( db, self.constraints, @@ -1240,12 +1247,51 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { return result; } - (Some(source_bound_typevar), None) => { + // self: `Concatenate[, P]` + // other: `P` + ( + Some((source_prefix_params, source_bound_typevar)), + Some(([], target_bound_typevar)), + ) => { + let lower = Type::Callable(CallableType::new( + db, + CallableSignature::single(Signature::new_generic( + source.generic_context, + Parameters::concatenate( + db, + source_prefix_params.to_vec(), + ConcatenateTail::ParamSpec(source_bound_typevar), + ), + Type::unknown(), + )), + CallableTypeKind::ParamSpecValue, + )); + let param_spec_prefix_matches = ConstraintSet::constrain_typevar( + db, + self.constraints, + target_bound_typevar, + lower, + Type::object(), + ); + result.intersect(db, self.constraints, param_spec_prefix_matches); + return result; + } + + // self: `P` + // other: `Concatenate[, P]` + ( + Some(([], source_bound_typevar)), + Some((target_prefix_params, target_bound_typevar)), + ) => { let upper = Type::Callable(CallableType::new( db, CallableSignature::single(Signature::new_generic( target.generic_context, - target.parameters.clone(), + Parameters::concatenate( + db, + target_prefix_params.to_vec(), + ConcatenateTail::ParamSpec(target_bound_typevar), + ), Type::unknown(), )), CallableTypeKind::ParamSpecValue, @@ -1261,7 +1307,160 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { return result; } - (None, Some(target_bound_typevar)) => { + // self: `Concatenate[, P]` + // other: `Concatenate[, P]` + ( + Some((source_prefix_params, source_bound_typevar)), + Some((target_prefix_params, target_bound_typevar)), + ) => { + let mut parameters = ParametersZip { + current_source: None, + current_target: None, + source_iter: source_prefix_params.iter(), + target_iter: target_prefix_params.iter(), + }; + + // Note that in the following loop, the `Concatenate` case could come from a + // regular function signature like: + // + // ```python + // def test[**P](fn: Callable[P, None], /, x: int, *args: P.args, **kwargs: P.kwargs) -> None: ... + // ``` + // + // Here, `fn` is positional-only parameter because of the `/` while `x` is a + // positional-or-keyword parameter. + + loop { + let Some(EitherOrBoth::Both(source_param, target_param)) = + parameters.next() + else { + break; + }; + + match (source_param.kind(), target_param.kind()) { + ( + ParameterKind::PositionalOnly { + default_type: source_default, + .. + } + | ParameterKind::PositionalOrKeyword { + default_type: source_default, + .. + }, + ParameterKind::PositionalOnly { + default_type: other_default, + .. + }, + ) => { + if source_default.is_none() && other_default.is_some() { + return self.never(); + } + if !check_types( + target_param.annotated_type(), + source_param.annotated_type(), + ) { + return result; + } + } + + ( + ParameterKind::PositionalOrKeyword { + name: self_name, + default_type: source_default, + }, + ParameterKind::PositionalOrKeyword { + name: other_name, + default_type: other_default, + }, + ) => { + if self_name != other_name { + return self.never(); + } + // The following checks are the same as positional-only parameters. + if source_default.is_none() && other_default.is_some() { + return self.never(); + } + if !check_types( + target_param.annotated_type(), + source_param.annotated_type(), + ) { + return result; + } + } + + _ => return self.never(), + } + } + + let (mut source_params, mut target_params) = parameters.into_remaining(); + + // At this point, we should've exhausted at least one of the parameter lists, + // so only one side can have remaining prefix parameters. + if let Some(source_param) = source_params.next() { + let lower = Type::Callable(CallableType::new( + db, + CallableSignature::single(Signature::new_generic( + source.generic_context, + Parameters::concatenate( + db, + std::iter::once(source_param.clone()) + .chain(source_params.cloned()) + .collect(), + ConcatenateTail::ParamSpec(source_bound_typevar), + ), + Type::unknown(), + )), + CallableTypeKind::ParamSpecValue, + )); + let param_spec_prefix_matches = ConstraintSet::constrain_typevar( + db, + self.constraints, + target_bound_typevar, + lower, + Type::object(), + ); + result.intersect(db, self.constraints, param_spec_prefix_matches); + } else if let Some(target_param) = target_params.next() { + let upper = Type::Callable(CallableType::new( + db, + CallableSignature::single(Signature::new_generic( + target.generic_context, + Parameters::concatenate( + db, + std::iter::once(target_param.clone()) + .chain(target_params.cloned()) + .collect(), + ConcatenateTail::ParamSpec(target_bound_typevar), + ), + Type::unknown(), + )), + CallableTypeKind::ParamSpecValue, + )); + let param_spec_prefix_matches = ConstraintSet::constrain_typevar( + db, + self.constraints, + source_bound_typevar, + Type::Never, + upper, + ); + result.intersect(db, self.constraints, param_spec_prefix_matches); + } else { + // When the prefixes match exactly, we just relate the remaining tails. + let param_spec_matches = ConstraintSet::constrain_typevar( + db, + self.constraints, + source_bound_typevar, + Type::TypeVar(target_bound_typevar), + Type::TypeVar(target_bound_typevar), + ); + result.intersect(db, self.constraints, param_spec_matches); + } + return result; + } + + // self: callable without ParamSpec + // other: `P` + (None, Some(([], target_bound_typevar))) => { let lower = Type::Callable(CallableType::new( db, CallableSignature::single(Signature::new_generic( @@ -1282,6 +1481,273 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { return result; } + // self: callable without ParamSpec + // other: `Concatenate[, P]` + (None, Some((target_prefix_params, target_bound_typevar))) => { + // Loop over self parameters and target_prefix_params in a similar manner to the + // above loop + let mut parameters = ParametersZip { + current_source: None, + current_target: None, + source_iter: source.parameters.iter(), + target_iter: target_prefix_params.iter(), + }; + + loop { + let Some(next_parameter) = parameters.next() else { + break; + }; + + match next_parameter { + EitherOrBoth::Left(_) => { + // If the non-Concatenate callable has remaining parameters, they + // should be bound to the `ParamSpec` in other. + break; + } + EitherOrBoth::Right(_) => { + return self.never(); + } + EitherOrBoth::Both(source_param, target_param) => { + match (source_param.kind(), target_param.kind()) { + ( + ParameterKind::PositionalOnly { + default_type: source_default, + .. + } + | ParameterKind::PositionalOrKeyword { + default_type: source_default, + .. + }, + ParameterKind::PositionalOnly { + default_type: target_default, + .. + }, + ) => { + if source_default.is_none() && target_default.is_some() { + return self.never(); + } + if !check_types( + target_param.annotated_type(), + source_param.annotated_type(), + ) { + return result; + } + } + + ( + ParameterKind::PositionalOrKeyword { + name: source_name, + default_type: source_default, + }, + ParameterKind::PositionalOrKeyword { + name: target_name, + default_type: target_default, + }, + ) => { + if source_name != target_name { + return self.never(); + } + // The following checks are the same as positional-only parameters. + if source_default.is_none() && target_default.is_some() { + return self.never(); + } + if !check_types( + target_param.annotated_type(), + source_param.annotated_type(), + ) { + return result; + } + } + + ( + ParameterKind::Variadic { .. }, + ParameterKind::PositionalOnly { .. } + | ParameterKind::PositionalOrKeyword { .. }, + ) => { + if !check_types( + target_param.annotated_type(), + source_param.annotated_type(), + ) { + return result; + } + + loop { + let Some(target_param) = parameters.peek_target() + else { + break; + }; + if !check_types( + target_param.annotated_type(), + source_param.annotated_type(), + ) { + return result; + } + parameters.next_target(); + } + + break; + } + + _ => return self.never(), + } + } + } + } + + let (source_params, _) = parameters.into_remaining(); + let lower = Type::Callable(CallableType::new( + db, + CallableSignature::single(Signature::new_generic( + source.generic_context, + Parameters::new(db, source_params.cloned()), + Type::unknown(), + )), + CallableTypeKind::ParamSpecValue, + )); + let param_spec_prefix_matches = ConstraintSet::constrain_typevar( + db, + self.constraints, + target_bound_typevar, + lower, + Type::object(), + ); + result.intersect(db, self.constraints, param_spec_prefix_matches); + + return result; + } + + // self: `P` + // other: callable without ParamSpec + (Some(([], source_bound_typevar)), None) => { + let upper = Type::Callable(CallableType::new( + db, + CallableSignature::single(Signature::new_generic( + target.generic_context, + target.parameters.clone(), + Type::unknown(), + )), + CallableTypeKind::ParamSpecValue, + )); + let param_spec_matches = ConstraintSet::constrain_typevar( + db, + self.constraints, + source_bound_typevar, + Type::Never, + upper, + ); + result.intersect(db, self.constraints, param_spec_matches); + return result; + } + + // self: `Concatenate[, P]` + // other: callable without ParamSpec + (Some((source_prefix_params, source_bound_typevar)), None) => { + let mut parameters = ParametersZip { + current_source: None, + current_target: None, + source_iter: source_prefix_params.iter(), + target_iter: target.parameters.iter(), + }; + + if target.parameters.kind() != ParametersKind::Gradual { + loop { + let Some(next_parameter) = parameters.next() else { + break; + }; + + match next_parameter { + EitherOrBoth::Left(_) => { + return self.never(); + } + EitherOrBoth::Right(_) => { + // If the non-Concatenate callable has remaining parameters, they + // should be bound to the `ParamSpec` in self. + break; + } + EitherOrBoth::Both(source_param, target_param) => { + match (source_param.kind(), target_param.kind()) { + ( + ParameterKind::PositionalOnly { + default_type: source_default, + .. + } + | ParameterKind::PositionalOrKeyword { + default_type: source_default, + .. + }, + ParameterKind::PositionalOnly { + default_type: target_default, + .. + }, + ) => { + if source_default.is_none() && target_default.is_some() + { + return self.never(); + } + if !check_types( + target_param.annotated_type(), + source_param.annotated_type(), + ) { + return result; + } + } + + ( + ParameterKind::PositionalOrKeyword { + name: source_name, + default_type: source_default, + }, + ParameterKind::PositionalOrKeyword { + name: target_name, + default_type: target_default, + }, + ) => { + if source_name != target_name { + return self.never(); + } + // The following checks are the same as positional-only parameters. + if source_default.is_none() && target_default.is_some() + { + return self.never(); + } + if !check_types( + target_param.annotated_type(), + source_param.annotated_type(), + ) { + return result; + } + } + + _ => return self.never(), + } + } + } + } + } + + let (_, target_params) = parameters.into_remaining(); + let upper = Type::Callable(CallableType::new( + db, + CallableSignature::single(Signature::new_generic( + target.generic_context, + Parameters::new(db, target_params.cloned()), + Type::unknown(), + )), + CallableTypeKind::ParamSpecValue, + )); + let param_spec_prefix_matches = ConstraintSet::constrain_typevar( + db, + self.constraints, + source_bound_typevar, + Type::Never, + upper, + ); + result.intersect(db, self.constraints, param_spec_prefix_matches); + + return result; + } + + // Both self and other are callables without ParamSpecs (None, None) => {} } } @@ -1317,6 +1783,174 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { // If either of the parameter lists is gradual (`...`), then it is assignable to and from // any other parameter list, but not a subtype or supertype of any other parameter list. if source.parameters.is_gradual() || target.parameters.is_gradual() { + match (source.parameters.kind(), target.parameters.kind()) { + // Both parameter lists are `Concatenate` with gradual forms. All prefix parameters + // are going to be positional-only. + ( + ParametersKind::Concatenate(ConcatenateTail::Gradual), + ParametersKind::Concatenate(ConcatenateTail::Gradual), + ) => { + let source_prefix_params = + &source.parameters.value[..source.parameters.len().saturating_sub(2)]; + let target_prefix_params = + &target.parameters.value[..target.parameters.len().saturating_sub(2)]; + + for (source_param, target_param) in + source_prefix_params.iter().zip(target_prefix_params.iter()) + { + if !check_types( + target_param.annotated_type(), + source_param.annotated_type(), + ) { + return result; + } + } + } + + // Self is a `Concatenate` with gradual form while other is a regular non-gradual + // callable + ( + ParametersKind::Concatenate(ConcatenateTail::Gradual), + ParametersKind::Standard, + ) => { + let source_prefix_params = + &source.parameters.value[..source.parameters.len().saturating_sub(2)]; + + for param in source_prefix_params + .iter() + .zip_longest(target.parameters.iter()) + { + match param { + EitherOrBoth::Left(_) => { + // Concatenate (self) has additional positional-only parameters but + // other does not. + return self.never(); + } + EitherOrBoth::Right(_) => { + // Once the left (self) iterator is exhausted, all the remaining + // parameters in other will be consumed by the gradual form of + // `Concatenate`. + break; + } + EitherOrBoth::Both(source_param, target_param) => { + if let ( + ParameterKind::PositionalOnly { .. }, + ParameterKind::PositionalOnly { + default_type: target_default, + .. + }, + ) = (source_param.kind(), target_param.kind()) + { + // `self`'s default is always going to be `None` because it comes + // from the `Concatenate` form which cannot have default value. + if target_default.is_some() { + return self.never(); + } + if !check_types( + target_param.annotated_type(), + source_param.annotated_type(), + ) { + return result; + } + } else { + return self.never(); + } + } + } + } + } + + // Other is a `Concatenate` with gradual form while self is a regular non-gradual + // callable + ( + ParametersKind::Standard, + ParametersKind::Concatenate(ConcatenateTail::Gradual), + ) => { + let target_prefix_params = + &target.parameters.value[..target.parameters.len().saturating_sub(2)]; + + let mut parameters = ParametersZip { + current_source: None, + current_target: None, + source_iter: source.parameters.iter(), + target_iter: target_prefix_params.iter(), + }; + + loop { + let Some(parameter) = parameters.next() else { + break; + }; + + match parameter { + EitherOrBoth::Left(_) => { + // Once the right (other) iterator is exhausted, all the remaining + // parameters in self will be consumed by the gradual form of + // `Concatenate`. + break; + } + EitherOrBoth::Right(_) => { + // Concatenate (other) has additional positional-only parameters but + // self does not. + return self.never(); + } + EitherOrBoth::Both(source_param, target_param) => { + match source_param.kind() { + ParameterKind::PositionalOnly { + default_type: source_default, + .. + } + | ParameterKind::PositionalOrKeyword { + default_type: source_default, + .. + } => { + if source_default.is_none() + && target_param.default_type().is_some() + { + return self.never(); + } + if !check_types( + target_param.annotated_type(), + source_param.annotated_type(), + ) { + return result; + } + } + ParameterKind::Variadic { .. } => { + if !check_types( + target_param.annotated_type(), + source_param.annotated_type(), + ) { + return result; + } + + loop { + let Some(target_param) = parameters.peek_target() + else { + break; + }; + if !check_types( + target_param.annotated_type(), + source_param.annotated_type(), + ) { + return result; + } + parameters.next_target(); + } + } + _ => { + // self has other parameter kinds but other only has + // positional-only parameters, so they cannot be compatible. + return self.never(); + } + } + } + } + } + } + + _ => {} + } + return match self.relation { TypeRelation::Subtyping | TypeRelation::SubtypingAssuming => self.never(), TypeRelation::Redundancy { .. } => result.intersect( @@ -1626,9 +2260,17 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { } } -// TODO: the spec also allows signatures like `Concatenate[int, ...]` or `Concatenate[int, P]`, -// which have some number of required positional-only parameters followed by a gradual form or a -// `ParamSpec`. Our representation will need some adjustments to represent that. +/// The tail of a `Concatenate[T1, T2, Tn, tail]` form. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, salsa::Update, get_size2::GetSize)] +pub(crate) enum ConcatenateTail<'db> { + /// Represents the `Concatenate[T1, T2, Tn, ...]` form where the prefix parameters are followed + /// by a gradual `*args: Any, **kwargs: Any`. + Gradual, + + /// Represents the `Concatenate[T1, T2, Tn, P]` form where the prefix parameters are followed by + /// a `ParamSpec` type variable. + ParamSpec(BoundTypeVarInstance<'db>), +} /// The kind of parameter list represented. #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash, salsa::Update, get_size2::GetSize)] @@ -1655,13 +2297,18 @@ pub(crate) enum ParametersKind<'db> { /// union of all possible parameter signatures. Top, - /// Represents a parameter list containing a `ParamSpec` as the only parameter. + /// Represents a parameter list containing a `ParamSpec` as the _only_ parameter. /// /// Note that this is distinct from a parameter list _containing_ a `ParamSpec` which is - /// considered a standard parameter list that just contains a `ParamSpec`. - // TODO: Maybe we should use `find_paramspec_from_args_kwargs` instead of storing the typevar - // here? + /// represented using the `Concatenate` variant. ParamSpec(BoundTypeVarInstance<'db>), + + /// Represents a parameter list containing positional-only parameters followed by either a + /// gradual form (`...`) or a `ParamSpec`. + /// + /// This is used to represent the parameter list of a `Concatenate[T1, T2, Tn, ...]` and + /// `Concatenate[T1, T2, Tn, P]` form. + Concatenate(ConcatenateTail<'db>), } #[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update, get_size2::GetSize)] @@ -1674,44 +2321,91 @@ pub(crate) struct Parameters<'db> { impl<'db> Parameters<'db> { /// Create a new parameter list from an iterator of parameters. /// - /// The kind of the parameter list is determined based on the provided parameters. - /// Specifically, if the parameters is made up of `*args` and `**kwargs` only, it checks - /// their annotated types to determine if they represent a gradual form or a `ParamSpec`. + /// The kind of the parameter list is determined based on the provided parameters. Specifically, + /// if the parameter list contains `*args` and `**kwargs`, then it checks their annotated types + /// and the presence of other parameter kinds to determine if they represent a gradual form, a + /// `ParamSpec`, or a `Concatenate` form. pub(crate) fn new( db: &'db dyn Db, parameters: impl IntoIterator>, ) -> Self { - fn new_impl<'db>(db: &'db dyn Db, value: Vec>) -> Parameters<'db> { - let mut kind = ParametersKind::Standard; - if let [p1, p2] = value.as_slice() - && p1.is_variadic() - && p2.is_keyword_variadic() - { - match (p1.annotated_type(), p2.annotated_type()) { - (Type::Dynamic(_), Type::Dynamic(_)) => { + let value: Vec> = parameters.into_iter().collect(); + let mut kind = ParametersKind::Standard; + + let variadic_param = value + .iter() + .find_position(|param| param.is_variadic()) + .map(|(index, param)| (index, param.annotated_type)); + let keyword_variadic_param = value + .iter() + .find_position(|param| param.is_keyword_variadic()) + .map(|(index, param)| (index, param.annotated_type)); + + if let ( + Some((variadic_index, variadic_type)), + Some((keyword_variadic_index, keyword_variadic_type)), + ) = (variadic_param, keyword_variadic_param) + { + let prefix_params = value.get(..variadic_index).unwrap_or(&[]); + let keyword_only_params = value + .get(variadic_index + 1..keyword_variadic_index) + .unwrap_or(&[]); + + match (variadic_type, keyword_variadic_type) { + // > If the input signature in a function definition includes both a `*args` and + // > `**kwargs` parameter and both are typed as Any (explicitly or implicitly + // > because it has no annotation), a type checker should treat this as the + // > equivalent of `...`. Any other parameters in the signature are unaffected and + // > are retained as part of the signature. + // + // https://typing.python.org/en/latest/spec/callables.html#meaning-of-in-callable + (Type::Dynamic(_), Type::Dynamic(_)) => { + if keyword_only_params.is_empty() + && !prefix_params.is_empty() + && prefix_params.iter().all(Parameter::is_positional_only) + { + kind = ParametersKind::Concatenate(ConcatenateTail::Gradual); + } else { kind = ParametersKind::Gradual; } - (Type::TypeVar(args_typevar), Type::TypeVar(kwargs_typevar)) => { - if let (Some(ParamSpecAttrKind::Args), Some(ParamSpecAttrKind::Kwargs)) = ( - args_typevar.paramspec_attr(db), - kwargs_typevar.paramspec_attr(db), + } + + // > A function declared as + // > `def inner(a: A, b: B, *args: P.args, **kwargs: P.kwargs) -> R` + // > has type `Callable[Concatenate[A, B, P], R]`. Placing keyword-only parameters + // > between the `*args` and `**kwargs` is forbidden. + // + // https://typing.python.org/en/latest/spec/generics.html#id5 + (Type::TypeVar(variadic_typevar), Type::TypeVar(keyword_variadic_typevar)) + if keyword_only_params.is_empty() => + { + if let (Some(ParamSpecAttrKind::Args), Some(ParamSpecAttrKind::Kwargs)) = ( + variadic_typevar.paramspec_attr(db), + keyword_variadic_typevar.paramspec_attr(db), + ) { + let typevar = variadic_typevar.without_paramspec_attr(db); + if typevar.is_same_typevar_as( + db, + keyword_variadic_typevar.without_paramspec_attr(db), ) { - let typevar = args_typevar.without_paramspec_attr(db); - if typevar - .is_same_typevar_as(db, kwargs_typevar.without_paramspec_attr(db)) - { + if prefix_params.is_empty() { kind = ParametersKind::ParamSpec(typevar); + } else if prefix_params.iter().all(Parameter::is_positional) { + // TODO: Currently, we accept both positional-only and + // positional-or-keyword parameter but we should raise a warning to + // let users know that these parameters should be positional-only + kind = ParametersKind::Concatenate(ConcatenateTail::ParamSpec( + typevar, + )); } } } - _ => {} } + _ => {} } - Parameters { value, kind } } - let value: Vec> = parameters.into_iter().collect(); - new_impl(db, value) + Parameters { value, kind } } /// Create an empty parameter list. @@ -1730,14 +2424,23 @@ impl<'db> Parameters<'db> { self.kind } + /// Returns `true` if the parameters represent a gradual form using `...` as the only parameter + /// or a `Concatenate` form with `...` as the last argument. pub(crate) const fn is_gradual(&self) -> bool { - matches!(self.kind, ParametersKind::Gradual) + matches!( + self.kind, + ParametersKind::Gradual | ParametersKind::Concatenate(ConcatenateTail::Gradual) + ) } pub(crate) const fn is_top(&self) -> bool { matches!(self.kind, ParametersKind::Top) } + /// Returns the bound `ParamSpec` type variable if the entire parameter list is exactly `P`. + /// + /// For either `P` or `Concatenate[, P]`, use + /// [`Self::as_paramspec_with_prefix`]. pub(crate) const fn as_paramspec(&self) -> Option> { match self.kind { ParametersKind::ParamSpec(bound_typevar) => Some(bound_typevar), @@ -1745,6 +2448,22 @@ impl<'db> Parameters<'db> { } } + /// Returns the prefix parameters and bound `ParamSpec` if this parameter list is either `P` or + /// `Concatenate[, P]`. + /// + /// For the narrower bare-`P` case, use [`Self::as_paramspec`]. + pub(crate) fn as_paramspec_with_prefix<'a>( + &'a self, + ) -> Option<(&'a [Parameter<'db>], BoundTypeVarInstance<'db>)> { + match self.kind { + ParametersKind::ParamSpec(typevar) => Some((&[], typevar)), + ParametersKind::Concatenate(ConcatenateTail::ParamSpec(typevar)) => { + Some((&self.value[..self.value.len().saturating_sub(2)], typevar)) + } + _ => None, + } + } + /// Return todo parameters: (*args: Todo, **kwargs: Todo) pub(crate) fn todo() -> Self { Self { @@ -1789,6 +2508,38 @@ impl<'db> Parameters<'db> { } } + /// Create a parameter list representing a `Concatenate` form with the given prefix parameters + /// and the tail (either gradual or a `ParamSpec`). + /// + /// Internally, this is represented as either: + /// - `(, /, *args: Any, **kwargs: Any)` for the gradual form, or + /// - `(, /, *args: P.args, **kwargs: P.kwargs)` for the `ParamSpec` form. + pub(crate) fn concatenate( + db: &'db dyn Db, + mut prefix_params: Vec>, + concatenate_tail: ConcatenateTail<'db>, + ) -> Self { + let (args_type, kwargs_type) = match concatenate_tail { + ConcatenateTail::Gradual => ( + Type::Dynamic(DynamicType::Any), + Type::Dynamic(DynamicType::Any), + ), + ConcatenateTail::ParamSpec(typevar) => ( + Type::TypeVar(typevar.with_paramspec_attr(db, ParamSpecAttrKind::Args)), + Type::TypeVar(typevar.with_paramspec_attr(db, ParamSpecAttrKind::Kwargs)), + ), + }; + prefix_params.extend([ + Parameter::variadic(Name::new_static("args")).with_annotated_type(args_type), + Parameter::keyword_variadic(Name::new_static("kwargs")) + .with_annotated_type(kwargs_type), + ]); + Self { + value: prefix_params, + kind: ParametersKind::Concatenate(concatenate_tail), + } + } + /// Return parameters that represents an unknown list of parameters. /// /// Internally, this is represented as `(*Unknown, **Unknown)` that accepts parameters of type @@ -1839,44 +2590,6 @@ impl<'db> Parameters<'db> { } } - /// Returns the bound `ParamSpec` type variable if the parameters contain a `ParamSpec`. - pub(crate) fn find_paramspec_from_args_kwargs<'a>( - &'a self, - db: &'db dyn Db, - ) -> Option<(&'a [Parameter<'db>], BoundTypeVarInstance<'db>)> { - let [prefix @ .., maybe_args, maybe_kwargs] = self.value.as_slice() else { - return None; - }; - - if !maybe_args.is_variadic() || !maybe_kwargs.is_keyword_variadic() { - return None; - } - - let (Type::TypeVar(args_typevar), Type::TypeVar(kwargs_typevar)) = - (maybe_args.annotated_type(), maybe_kwargs.annotated_type()) - else { - return None; - }; - - if matches!( - ( - args_typevar.paramspec_attr(db), - kwargs_typevar.paramspec_attr(db) - ), - ( - Some(ParamSpecAttrKind::Args), - Some(ParamSpecAttrKind::Kwargs) - ) - ) { - let typevar = args_typevar.without_paramspec_attr(db); - if typevar.is_same_typevar_as(db, kwargs_typevar.without_paramspec_attr(db)) { - return Some((prefix, typevar)); - } - } - - None - } - fn from_parameters( db: &'db dyn Db, definition: Definition<'db>, @@ -2001,7 +2714,10 @@ impl<'db> Parameters<'db> { visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { if let TypeMapping::Materialize(materialization_kind) = type_mapping - && self.kind == ParametersKind::Gradual + && matches!( + self.kind, + ParametersKind::Gradual | ParametersKind::Concatenate(ConcatenateTail::Gradual) + ) { match materialization_kind { MaterializationKind::Bottom => { diff --git a/crates/ty_python_semantic/src/types/special_form.rs b/crates/ty_python_semantic/src/types/special_form.rs index 0063a2ce91608..7a70efb211ad2 100644 --- a/crates/ty_python_semantic/src/types/special_form.rs +++ b/crates/ty_python_semantic/src/types/special_form.rs @@ -728,13 +728,18 @@ impl SpecialFormType { fallback_type: Type::unknown(), }), - Self::Annotated | Self::Concatenate => Err(InvalidTypeExpressionError { + Self::Annotated => Err(InvalidTypeExpressionError { invalid_expressions: smallvec::smallvec_inline![ InvalidTypeExpression::RequiresTwoArguments(self) ], fallback_type: Type::unknown(), }), + Self::Concatenate => Err(InvalidTypeExpressionError { + invalid_expressions: smallvec::smallvec_inline![InvalidTypeExpression::Concatenate], + fallback_type: Type::unknown(), + }), + // We treat `typing.Type` exactly the same as `builtins.type`: SpecialFormType::Type => Ok(KnownClass::Type.to_instance(db)), SpecialFormType::Tuple => Ok(Type::homogeneous_tuple(db, Type::unknown())),