diff --git a/pytest_mock.py b/pytest_mock.py index 295171d..b9afca9 100644 --- a/pytest_mock.py +++ b/pytest_mock.py @@ -137,3 +137,82 @@ def mock(mocker): warnings.warn('"mock" fixture has been deprecated, use "mocker" instead', DeprecationWarning) return mocker + + +_mock_module_patches = [] +_mock_module_originals = {} + + +def assert_wrapper(method, *args, **kwargs): + __tracebackhide__ = True + try: + method(*args, **kwargs) + except AssertionError as e: + raise AssertionError(*e.args) + + +def wrap_assert_not_called(*args, **kwargs): + __tracebackhide__ = True + assert_wrapper(_mock_module_originals["assert_not_called"], + *args, **kwargs) + + +def wrap_assert_called_with(*args, **kwargs): + __tracebackhide__ = True + assert_wrapper(_mock_module_originals["assert_called_with"], + *args, **kwargs) + + +def wrap_assert_called_once_with(*args, **kwargs): + __tracebackhide__ = True + assert_wrapper(_mock_module_originals["assert_called_once_with"], + *args, **kwargs) + + +def wrap_assert_has_calls(*args, **kwargs): + __tracebackhide__ = True + assert_wrapper(_mock_module_originals["assert_has_calls"], + *args, **kwargs) + + +def wrap_assert_any_call(*args, **kwargs): + __tracebackhide__ = True + assert_wrapper(_mock_module_originals["assert_any_call"], + *args, **kwargs) + + +def wrap_assert_methods(): + """ + Wrap assert methods of mock module so we can hide their traceback + """ + wrappers = { + 'assert_not_called': wrap_assert_not_called, + 'assert_called_with': wrap_assert_called_with, + 'assert_called_once_with': wrap_assert_called_once_with, + 'assert_has_calls': wrap_assert_has_calls, + 'assert_any_call': wrap_assert_any_call, + } + for method, wrapper in wrappers.items(): + try: + original = getattr(mock_module.NonCallableMock, method) + except AttributeError: + continue + _mock_module_originals[method] = original + patcher = mock_module.patch.object( + mock_module.NonCallableMock, method, wrapper) + patcher.start() + _mock_module_patches.append(patcher) + + +def unwrap_assert_methods(): + for patcher in _mock_module_patches: + patcher.stop() + _mock_module_patches[:] = [] + + +def pytest_configure(config): + wrap_assert_methods() + + +def pytest_unconfigure(config): + unwrap_assert_methods() diff --git a/test_pytest_mock.py b/test_pytest_mock.py index 23868b8..77581c7 100644 --- a/test_pytest_mock.py +++ b/test_pytest_mock.py @@ -1,6 +1,9 @@ import os import platform +import sys +from contextlib import contextmanager +import py.code import pytest @@ -246,3 +249,69 @@ def bar(arg): assert Foo.bar(arg=10) == 20 Foo.bar.assert_called_once_with(arg=10) spy.assert_called_once_with(arg=10) + + +@contextmanager +def assert_traceback(): + """ + Assert that this file is at the top of the filtered traceback + """ + try: + yield + except AssertionError: + traceback = py.code.ExceptionInfo().traceback + crashentry = traceback.getcrashentry() + assert crashentry.path == __file__ + else: + raise AssertionError("DID NOT RAISE") + + +@pytest.mark.skipif(sys.version_info >= (3, 4) and sys.version_info < (3, 5), + reason="assert_not_called not available in python 3.4") +def test_assert_not_called_wrapper(mocker): + stub = mocker.stub() + stub.assert_not_called() + stub() + with assert_traceback(): + stub.assert_not_called() + + +def test_assert_called_with_wrapper(mocker): + stub = mocker.stub() + stub("foo") + stub.assert_called_with("foo") + with assert_traceback(): + stub.assert_called_with("bar") + + +def test_assert_called_once_with_wrapper(mocker): + stub = mocker.stub() + stub("foo") + stub.assert_called_once_with("foo") + stub("foo") + with assert_traceback(): + stub.assert_called_once_with("foo") + + +def test_assert_any_call_wrapper(mocker): + stub = mocker.stub() + stub("foo") + stub("foo") + stub.assert_any_call("foo") + with assert_traceback(): + stub.assert_any_call("bar") + + +def test_assert_has_calls(mocker): + from pytest_mock import mock_module + stub = mocker.stub() + stub("foo") + stub.assert_has_calls([mock_module.call("foo")]) + with assert_traceback(): + stub.assert_has_calls([mock_module.call("bar")]) + + +def test_dirty(mocker): + stub = mocker.stub() + stub.testa() + assert False