diff --git a/dirty_equals/_base.py b/dirty_equals/_base.py index 251f216..df7d1ec 100644 --- a/dirty_equals/_base.py +++ b/dirty_equals/_base.py @@ -1,4 +1,6 @@ +import io from abc import ABCMeta +from pprint import PrettyPrinter from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Optional, Protocol, Tuple, TypeVar from ._utils import Omit @@ -131,6 +133,26 @@ def __repr__(self) -> str: # else return something which explains what's going on. return self._repr_ne() + def _pprint_format(self, pprinter: PrettyPrinter, stream: io.StringIO, *args: Any, **kwargs: Any) -> None: + # pytest diffs use pprint to format objects, so we patch pprint to call this method + # for DirtyEquals objects. So this method needs to follow the same pattern as __repr__. + # We check that the protected _format method actually exists + # to be safe and to make linters happy. + if self._was_equal and hasattr(pprinter, '_format'): + pprinter._format(self._other, stream, *args, **kwargs) + else: + stream.write(repr(self)) # i.e. self._repr_ne() (for now) + + +# Patch pprint to call _pprint_format for DirtyEquals objects +# Check that the protected attribute _dispatch exists to be safe and to make linters happy. +# The reason we modify _dispatch rather than _format +# is that pytest sometimes uses a subclass of PrettyPrinter which overrides _format. +if hasattr(PrettyPrinter, '_dispatch'): # pragma: no branch + PrettyPrinter._dispatch[DirtyEquals.__repr__] = lambda pprinter, obj, *args, **kwargs: obj._pprint_format( + pprinter, *args, **kwargs + ) + InstanceOrType: 'TypeAlias' = 'Union[DirtyEquals[Any], DirtyEqualsMeta]' diff --git a/tests/test_base.py b/tests/test_base.py index 3fc519d..7545d2f 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,9 +1,10 @@ import platform +import pprint import packaging.version import pytest -from dirty_equals import Contains, IsApprox, IsInt, IsNegative, IsOneOf, IsPositive, IsStr +from dirty_equals import Contains, IsApprox, IsInt, IsList, IsNegative, IsOneOf, IsPositive, IsStr from dirty_equals.version import VERSION @@ -39,8 +40,7 @@ def test_value_eq(): v.value assert 'foo' == v - assert str(v) == "'foo'" - assert repr(v) == "'foo'" + assert repr(v) == str(v) == "'foo'" == pprint.pformat(v) assert v.value == 'foo' @@ -50,8 +50,7 @@ def test_value_ne(): with pytest.raises(AssertionError): assert 1 == v - assert str(v) == 'IsStr()' - assert repr(v) == 'IsStr()' + assert repr(v) == str(v) == 'IsStr()' == pprint.pformat(v) with pytest.raises(AttributeError, match='value is not available until __eq__ has been called'): v.value @@ -110,7 +109,7 @@ def test_repr(): ], ) def test_repr_class(v, v_repr): - assert repr(v) == v_repr + assert repr(v) == str(v) == v_repr == pprint.pformat(v) def test_is_approx_without_init(): @@ -119,11 +118,62 @@ def test_is_approx_without_init(): def test_ne_repr(): v = IsInt - assert repr(v) == 'IsInt' + assert repr(v) == str(v) == 'IsInt' == pprint.pformat(v) assert 'x' != v - assert repr(v) == 'IsInt' + assert repr(v) == str(v) == 'IsInt' == pprint.pformat(v) + + +def test_pprint(): + v = [IsList(length=...), 1, [IsList(length=...), 2], 3, IsInt()] + lorem = ['lorem', 'ipsum', 'dolor', 'sit', 'amet'] * 2 + with pytest.raises(AssertionError): + assert [lorem, 1, [lorem, 2], 3, '4'] == v + + assert repr(v) == (f'[{lorem}, 1, [{lorem}, 2], 3, IsInt()]') + assert pprint.pformat(v) == ( + "[['lorem',\n" + " 'ipsum',\n" + " 'dolor',\n" + " 'sit',\n" + " 'amet',\n" + " 'lorem',\n" + " 'ipsum',\n" + " 'dolor',\n" + " 'sit',\n" + " 'amet'],\n" + ' 1,\n' + " [['lorem',\n" + " 'ipsum',\n" + " 'dolor',\n" + " 'sit',\n" + " 'amet',\n" + " 'lorem',\n" + " 'ipsum',\n" + " 'dolor',\n" + " 'sit',\n" + " 'amet'],\n" + ' 2],\n' + ' 3,\n' + ' IsInt()]' + ) + + +def test_pprint_not_equal(): + v = IsList(*range(30)) # need a big value to trigger pprint + with pytest.raises(AssertionError): + assert [] == v + + assert ( + pprint.pformat(v) + == ( + 'IsList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, ' + '15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29)' + ) + == repr(v) + == str(v) + ) @pytest.mark.parametrize(