Skip to content

Commit

Permalink
feat: implement __eq__ between BaseFlags and flag_values (#1238)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sharp-Eyes authored Nov 29, 2024
1 parent bb65a60 commit b92f382
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 1 deletion.
1 change: 1 addition & 0 deletions changelog/1238.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for ``BaseFlags`` to allow comparison with ``flag_values`` and vice versa.
16 changes: 15 additions & 1 deletion disnake/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ def __init__(self, func: Callable[[Any], int]) -> None:
self.__doc__ = func.__doc__
self._parent: Type[T] = MISSING

def __eq__(self, other: Any) -> bool:
if isinstance(other, flag_value):
return self.flag == other.flag
if isinstance(other, BaseFlags):
return self._parent is other.__class__ and self.flag == other.value
return False

def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)

def __or__(self, other: Union[flag_value[T], T]) -> T:
if isinstance(other, BaseFlags):
if self._parent is not other.__class__:
Expand Down Expand Up @@ -148,7 +158,11 @@ def _from_value(cls, value: int) -> Self:
return self

def __eq__(self, other: Any) -> bool:
return isinstance(other, self.__class__) and self.value == other.value
if isinstance(other, self.__class__):
return self.value == other.value
if isinstance(other, flag_value):
return self.__class__ is other._parent and self.value == other.flag
return False

def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
Expand Down
15 changes: 15 additions & 0 deletions tests/test_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,21 @@ def test__eq__(self) -> None:
assert not ins == other
assert ins != other

def test__eq__flag_value(self) -> None:
ins = TestFlags(one=True)
other = TestFlags(one=True, two=True)

assert ins == TestFlags.one
assert TestFlags.one == ins

assert not ins != TestFlags.one
assert ins != TestFlags.two

assert other != TestFlags.one
assert other != TestFlags.two

assert other == TestFlags.three

def test__and__(self) -> None:
ins = TestFlags(one=True, two=True)
other = TestFlags(one=True, two=True)
Expand Down

0 comments on commit b92f382

Please sign in to comment.