diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 09d1679..14a5379 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,13 @@ Releases ======== +UNRELEASED +------------------- + +* Updated type annotations for ``mocker.patch`` and ``mocker.spy`` (`#364`_). + +.. _#364: https://github.com/pytest-dev/pytest-mock/pull/364 + 3.10.0 (2022-10-05) ------------------- diff --git a/src/pytest_mock/plugin.py b/src/pytest_mock/plugin.py index 1d52555..6699f4b 100644 --- a/src/pytest_mock/plugin.py +++ b/src/pytest_mock/plugin.py @@ -27,10 +27,16 @@ _T = TypeVar("_T") -if sys.version_info[:2] > (3, 7): +if sys.version_info >= (3, 8): AsyncMockType = unittest.mock.AsyncMock + MockType = Union[ + unittest.mock.MagicMock, + unittest.mock.AsyncMock, + unittest.mock.NonCallableMagicMock, + ] else: AsyncMockType = Any + MockType = Union[unittest.mock.MagicMock, unittest.mock.NonCallableMagicMock] class PytestMockWarning(UserWarning): @@ -112,7 +118,7 @@ def stop(self, mock: unittest.mock.MagicMock) -> None: else: raise ValueError("This mock object is not registered") - def spy(self, obj: object, name: str) -> unittest.mock.MagicMock: + def spy(self, obj: object, name: str) -> MockType: """ Create a spy of method. It will run method normally, but it is now possible to use `mock` call features with it, like call count. @@ -205,13 +211,13 @@ def __init__(self, patches_and_mocks, mock_module): def _start_patch( self, mock_func: Any, warn_on_mock_enter: bool, *args: Any, **kwargs: Any - ) -> unittest.mock.MagicMock: + ) -> MockType: """Patches something by calling the given function from the mock module, registering the patch to stop it later and returns the mock object resulting from the mock call. """ p = mock_func(*args, **kwargs) - mocked = p.start() # type: unittest.mock.MagicMock + mocked: MockType = p.start() self.__patches_and_mocks.append((p, mocked)) if hasattr(mocked, "reset_mock"): # check if `mocked` is actually a mock object, as depending on autospec or target @@ -242,7 +248,7 @@ def object( autospec: Optional[object] = None, new_callable: object = None, **kwargs: Any - ) -> unittest.mock.MagicMock: + ) -> MockType: """API to mock.patch.object""" if new is self.DEFAULT: new = self.mock_module.DEFAULT @@ -271,7 +277,7 @@ def context_manager( autospec: Optional[builtins.object] = None, new_callable: builtins.object = None, **kwargs: Any - ) -> unittest.mock.MagicMock: + ) -> MockType: """This is equivalent to mock.patch.object except that the returned mock does not issue a warning when used as a context manager.""" if new is self.DEFAULT: @@ -299,7 +305,7 @@ def multiple( autospec: Optional[builtins.object] = None, new_callable: Optional[builtins.object] = None, **kwargs: Any - ) -> Dict[str, unittest.mock.MagicMock]: + ) -> Dict[str, MockType]: """API to mock.patch.multiple""" return self._start_patch( self.mock_module.patch.multiple, @@ -341,7 +347,7 @@ def __call__( autospec: Optional[builtins.object] = ..., new_callable: None = ..., **kwargs: Any - ) -> unittest.mock.MagicMock: + ) -> MockType: ... @overload