Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions dirty_equals/_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import io
from abc import ABCMeta
from pprint import PrettyPrinter
from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Optional, Protocol, Tuple, TypeVar

from ._utils import Omit
Expand Down Expand Up @@ -131,6 +133,26 @@ def __repr__(self) -> str:
# else return something which explains what's going on.
return self._repr_ne()

def _pprint_format(self, pprinter: PrettyPrinter, stream: io.StringIO, *args: Any, **kwargs: Any) -> None:
# pytest diffs use pprint to format objects, so we patch pprint to call this method
# for DirtyEquals objects. So this method needs to follow the same pattern as __repr__.
# We check that the protected _format method actually exists
# to be safe and to make linters happy.
if self._was_equal and hasattr(pprinter, '_format'):
pprinter._format(self._other, stream, *args, **kwargs)
else:
stream.write(repr(self)) # i.e. self._repr_ne() (for now)


# Patch pprint to call _pprint_format for DirtyEquals objects
# Check that the protected attribute _dispatch exists to be safe and to make linters happy.
# The reason we modify _dispatch rather than _format
# is that pytest sometimes uses a subclass of PrettyPrinter which overrides _format.
if hasattr(PrettyPrinter, '_dispatch'): # pragma: no branch
PrettyPrinter._dispatch[DirtyEquals.__repr__] = lambda pprinter, obj, *args, **kwargs: obj._pprint_format(
pprinter, *args, **kwargs
)


InstanceOrType: 'TypeAlias' = 'Union[DirtyEquals[Any], DirtyEqualsMeta]'

Expand Down
66 changes: 58 additions & 8 deletions tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import platform
import pprint

import packaging.version
import pytest

from dirty_equals import Contains, IsApprox, IsInt, IsNegative, IsOneOf, IsPositive, IsStr
from dirty_equals import Contains, IsApprox, IsInt, IsList, IsNegative, IsOneOf, IsPositive, IsStr
from dirty_equals.version import VERSION


Expand Down Expand Up @@ -39,8 +40,7 @@ def test_value_eq():
v.value

assert 'foo' == v
assert str(v) == "'foo'"
assert repr(v) == "'foo'"
assert repr(v) == str(v) == "'foo'" == pprint.pformat(v)
assert v.value == 'foo'


Expand All @@ -50,8 +50,7 @@ def test_value_ne():
with pytest.raises(AssertionError):
assert 1 == v

assert str(v) == 'IsStr()'
assert repr(v) == 'IsStr()'
assert repr(v) == str(v) == 'IsStr()' == pprint.pformat(v)
with pytest.raises(AttributeError, match='value is not available until __eq__ has been called'):
v.value

Expand Down Expand Up @@ -110,7 +109,7 @@ def test_repr():
],
)
def test_repr_class(v, v_repr):
assert repr(v) == v_repr
assert repr(v) == str(v) == v_repr == pprint.pformat(v)


def test_is_approx_without_init():
Expand All @@ -119,11 +118,62 @@ def test_is_approx_without_init():

def test_ne_repr():
v = IsInt
assert repr(v) == 'IsInt'
assert repr(v) == str(v) == 'IsInt' == pprint.pformat(v)

assert 'x' != v

assert repr(v) == 'IsInt'
assert repr(v) == str(v) == 'IsInt' == pprint.pformat(v)


def test_pprint():
v = [IsList(length=...), 1, [IsList(length=...), 2], 3, IsInt()]
lorem = ['lorem', 'ipsum', 'dolor', 'sit', 'amet'] * 2
with pytest.raises(AssertionError):
assert [lorem, 1, [lorem, 2], 3, '4'] == v

assert repr(v) == (f'[{lorem}, 1, [{lorem}, 2], 3, IsInt()]')
assert pprint.pformat(v) == (
"[['lorem',\n"
" 'ipsum',\n"
" 'dolor',\n"
" 'sit',\n"
" 'amet',\n"
" 'lorem',\n"
" 'ipsum',\n"
" 'dolor',\n"
" 'sit',\n"
" 'amet'],\n"
' 1,\n'
" [['lorem',\n"
" 'ipsum',\n"
" 'dolor',\n"
" 'sit',\n"
" 'amet',\n"
" 'lorem',\n"
" 'ipsum',\n"
" 'dolor',\n"
" 'sit',\n"
" 'amet'],\n"
' 2],\n'
' 3,\n'
' IsInt()]'
)


def test_pprint_not_equal():
v = IsList(*range(30)) # need a big value to trigger pprint
with pytest.raises(AssertionError):
assert [] == v

assert (
pprint.pformat(v)
== (
'IsList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, '
'15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29)'
)
== repr(v)
== str(v)
)


@pytest.mark.parametrize(
Expand Down