-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Change: Move StrEnum to pontos.enum and add functions for argparse
Create a dedicated enum module and add functions for using enum with argparse.
- Loading branch information
1 parent
2ed58cb
commit 9498433
Showing
3 changed files
with
91 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# SPDX-FileCopyrightText: 2024 Greenbone AG | ||
# | ||
# SPDX-License-Identifier: GPL-3.0-or-later | ||
|
||
from argparse import ArgumentTypeError | ||
from enum import Enum | ||
from typing import Callable, Type, TypeVar, Union | ||
|
||
|
||
class StrEnum(str, Enum): | ||
# Should be replaced by enum.StrEnum when we require Python >= 3.11 | ||
""" | ||
An Enum that provides str like behavior | ||
""" | ||
|
||
def __str__(self) -> str: | ||
return self.value | ||
|
||
|
||
def enum_choice(enum: Type[Enum]) -> list[str]: | ||
""" | ||
Return a sequence of choices for argparse from an enum | ||
""" | ||
return [str(e) for e in enum] | ||
|
||
|
||
def to_choices(enum: Type[Enum]) -> str: | ||
""" | ||
Convert an enum to a comma separated string of choices. For example useful | ||
in help messages for argparse. | ||
""" | ||
return ", ".join([str(t) for t in enum]) | ||
|
||
|
||
T = TypeVar("T", bound=Enum) | ||
|
||
|
||
def enum_type(enum: Type[T]) -> Callable[[Union[str, T]], T]: | ||
""" | ||
Create a argparse type function for converting the string input into an Enum | ||
""" | ||
|
||
def convert(value: Union[str, T]) -> T: | ||
if isinstance(value, str): | ||
try: | ||
return enum(value) | ||
except ValueError: | ||
raise ArgumentTypeError( | ||
f"invalid value {value}. Expected one of {to_choices(enum)}." | ||
) from None | ||
return value | ||
|
||
return convert |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# SPDX-FileCopyrightText: 2024 Greenbone AG | ||
# | ||
# SPDX-License-Identifier: GPL-3.0-or-later | ||
|
||
|
||
import unittest | ||
from argparse import ArgumentTypeError | ||
|
||
from pontos.enum import StrEnum, enum_type | ||
|
||
|
||
class EnumTypeTestCase(unittest.TestCase): | ||
def test_enum_type(self): | ||
class FooEnum(StrEnum): | ||
ALL = "all" | ||
NONE = "none" | ||
|
||
func = enum_type(FooEnum) | ||
|
||
self.assertEqual(func("all"), FooEnum.ALL) | ||
self.assertEqual(func("none"), FooEnum.NONE) | ||
|
||
self.assertEqual(func(FooEnum.ALL), FooEnum.ALL) | ||
self.assertEqual(func(FooEnum.NONE), FooEnum.NONE) | ||
|
||
def test_enum_type_error(self): | ||
class FooEnum(StrEnum): | ||
ALL = "all" | ||
NONE = "none" | ||
|
||
func = enum_type(FooEnum) | ||
|
||
with self.assertRaisesRegex( | ||
ArgumentTypeError, | ||
r"invalid value foo. Expected one of all, none", | ||
): | ||
func("foo") |