Skip to content

Commit 848ab00

Browse files
committed
Type annotate @pytest.mark.foo
1 parent c0af19d commit 848ab00

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

src/_pytest/mark/structures.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import typing
44
import warnings
55
from typing import Any
6+
from typing import Callable
67
from typing import Iterable
78
from typing import List
89
from typing import Mapping
@@ -11,6 +12,7 @@
1112
from typing import Sequence
1213
from typing import Set
1314
from typing import Tuple
15+
from typing import TypeVar
1416
from typing import Union
1517

1618
import attr
@@ -19,6 +21,7 @@
1921
from ..compat import ascii_escaped
2022
from ..compat import NOTSET
2123
from ..compat import NotSetType
24+
from ..compat import overload
2225
from ..compat import TYPE_CHECKING
2326
from _pytest.config import Config
2427
from _pytest.outcomes import fail
@@ -240,6 +243,12 @@ def combined_with(self, other: "Mark") -> "Mark":
240243
)
241244

242245

246+
# A generic parameter designating an object to which a Mark may
247+
# be applied -- a test function (callable) or class.
248+
# Note: a lambda is not allowed, but this can't be represented.
249+
_Markable = TypeVar("_Markable", bound=Union[Callable[..., object], type])
250+
251+
243252
@attr.s
244253
class MarkDecorator:
245254
"""A decorator for applying a mark on test functions and classes.
@@ -311,7 +320,20 @@ def with_args(self, *args: object, **kwargs: object) -> "MarkDecorator":
311320
mark = Mark(self.name, args, kwargs)
312321
return self.__class__(self.mark.combined_with(mark))
313322

314-
def __call__(self, *args: object, **kwargs: object):
323+
# Type ignored because the overloads overlap with an incompatible
324+
# return type. Not much we can do about that. Thankfully mypy picks
325+
# the first match so it works out even if we break the rules.
326+
@overload
327+
def __call__(self, arg: _Markable) -> _Markable: # type: ignore[misc] # noqa: F821
328+
raise NotImplementedError()
329+
330+
@overload # noqa: F811
331+
def __call__( # noqa: F811
332+
self, *args: object, **kwargs: object
333+
) -> "MarkDecorator":
334+
raise NotImplementedError()
335+
336+
def __call__(self, *args: object, **kwargs: object): # noqa: F811
315337
"""Call the MarkDecorator."""
316338
if args and not kwargs:
317339
func = args[0]

0 commit comments

Comments
 (0)