From 32a04eba891170eb7cdd79d5f5923052fb21eed0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 6 Sep 2022 16:51:23 +0200 Subject: [PATCH] Integrate `StrEnum` (#38) --- src/lightning_utilities/core/enums.py | 24 ++++++++++++++++ tests/unittests/core/test_enums.py | 41 +++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) create mode 100644 src/lightning_utilities/core/enums.py create mode 100644 tests/unittests/core/test_enums.py diff --git a/src/lightning_utilities/core/enums.py b/src/lightning_utilities/core/enums.py new file mode 100644 index 00000000..7da30bc1 --- /dev/null +++ b/src/lightning_utilities/core/enums.py @@ -0,0 +1,24 @@ +from enum import Enum +from typing import Optional + + +class StrEnum(str, Enum): + """Type of any enumerator with allowed comparison to string invariant to cases.""" + + @classmethod + def from_str(cls, value: str) -> Optional["StrEnum"]: + statuses = cls.__members__.keys() + for st in statuses: + if st.lower() == value.lower(): + return cls[st] + return None + + def __eq__(self, other: object) -> bool: + if isinstance(other, Enum): + other = other.value + return self.value.lower() == str(other).lower() + + def __hash__(self) -> int: + # re-enable hashtable so it can be used as a dict key or in a set + # example: set(LightningEnum) + return hash(self.value.lower()) diff --git a/tests/unittests/core/test_enums.py b/tests/unittests/core/test_enums.py new file mode 100644 index 00000000..499fe13e --- /dev/null +++ b/tests/unittests/core/test_enums.py @@ -0,0 +1,41 @@ +from enum import Enum + +from lightning_utilities.core.enums import StrEnum + + +def test_consistency(): + class MyEnum(StrEnum): + FOO = "FOO" + BAR = "BAR" + BAZ = "BAZ" + NUM = "32" + + # normal equality, case invariant + assert MyEnum.FOO == "FOO" + assert MyEnum.FOO == "foo" + + # int support + assert MyEnum.NUM == 32 + assert MyEnum.NUM in (32, "32") + + # key-based + assert MyEnum.NUM == MyEnum.from_str("num") + + # collections + assert MyEnum.BAZ not in ("FOO", "BAR") + assert MyEnum.BAZ in ("FOO", "BAZ") + assert MyEnum.BAZ in ("baz", "FOO") + assert MyEnum.BAZ not in {"BAR", "FOO"} + # hash cannot be case invariant + assert MyEnum.BAZ not in {"BAZ", "FOO"} + assert MyEnum.BAZ in {"baz", "FOO"} + + +def test_comparison_with_other_enum(): + class MyEnum(StrEnum): + FOO = "FOO" + + class OtherEnum(Enum): + FOO = 123 + + assert not MyEnum.FOO.__eq__(OtherEnum.FOO)