|
3 | 3 | import typing |
4 | 4 | import warnings |
5 | 5 | from typing import Any |
| 6 | +from typing import Callable |
6 | 7 | from typing import Iterable |
7 | 8 | from typing import List |
8 | 9 | from typing import Mapping |
|
11 | 12 | from typing import Sequence |
12 | 13 | from typing import Set |
13 | 14 | from typing import Tuple |
| 15 | +from typing import TypeVar |
14 | 16 | from typing import Union |
15 | 17 |
|
16 | 18 | import attr |
|
19 | 21 | from ..compat import ascii_escaped |
20 | 22 | from ..compat import NOTSET |
21 | 23 | from ..compat import NotSetType |
| 24 | +from ..compat import overload |
22 | 25 | from ..compat import TYPE_CHECKING |
23 | 26 | from _pytest.config import Config |
24 | 27 | from _pytest.outcomes import fail |
@@ -240,6 +243,12 @@ def combined_with(self, other: "Mark") -> "Mark": |
240 | 243 | ) |
241 | 244 |
|
242 | 245 |
|
| 246 | +# A generic parameter designating an object to which a Mark may |
| 247 | +# be applied -- a test function (callable) or class. |
| 248 | +# Note: a lambda is not allowed, but this can't be represented. |
| 249 | +_Markable = TypeVar("_Markable", bound=Union[Callable[..., object], type]) |
| 250 | + |
| 251 | + |
243 | 252 | @attr.s |
244 | 253 | class MarkDecorator: |
245 | 254 | """A decorator for applying a mark on test functions and classes. |
@@ -311,7 +320,20 @@ def with_args(self, *args: object, **kwargs: object) -> "MarkDecorator": |
311 | 320 | mark = Mark(self.name, args, kwargs) |
312 | 321 | return self.__class__(self.mark.combined_with(mark)) |
313 | 322 |
|
314 | | - def __call__(self, *args: object, **kwargs: object): |
| 323 | + # Type ignored because the overloads overlap with an incompatible |
| 324 | + # return type. Not much we can do about that. Thankfully mypy picks |
| 325 | + # the first match so it works out even if we break the rules. |
| 326 | + @overload |
| 327 | + def __call__(self, arg: _Markable) -> _Markable: # type: ignore[misc] # noqa: F821 |
| 328 | + raise NotImplementedError() |
| 329 | + |
| 330 | + @overload # noqa: F811 |
| 331 | + def __call__( # noqa: F811 |
| 332 | + self, *args: object, **kwargs: object |
| 333 | + ) -> "MarkDecorator": |
| 334 | + raise NotImplementedError() |
| 335 | + |
| 336 | + def __call__(self, *args: object, **kwargs: object): # noqa: F811 |
315 | 337 | """Call the MarkDecorator.""" |
316 | 338 | if args and not kwargs: |
317 | 339 | func = args[0] |
|
0 commit comments