diff --git a/src/wxflow/configuration.py b/src/wxflow/configuration.py index 4a152d3..f06a3f3 100644 --- a/src/wxflow/configuration.py +++ b/src/wxflow/configuration.py @@ -156,21 +156,27 @@ def cast_strdict_as_dtypedict(ctx: Dict[str, str]) -> Dict[str, Any]: def cast_as_dtype(string: str) -> Union[str, int, float, bool, Any]: """ Cast a value into known datatype + Parameters ---------- string: str + Returns ------- - value : str or int or float or datetime + value : str, int, float, bool or datetime; or List of these default: str """ TRUTHS = ['y', 'yes', 't', 'true', '.t.', '.true.'] BOOLS = ['n', 'no', 'f', 'false', '.f.', '.false.'] + TRUTHS BOOLS = [x.upper() for x in BOOLS] + BOOLS + ['Yes', 'No', 'True', 'False'] - def _cast_or_not(type: Any, string: str): + if ',' in string: + # Convert comma-separated list to python list + return [cast_as_dtype(elem.strip()) for elem in string.split(',')] + + def _cast_or_not(to_type: Any, string: str): try: - return type(string) + return to_type(string) except ValueError: return string diff --git a/tests/test_configuration.py b/tests/test_configuration.py index da1f926..1070226 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -30,6 +30,11 @@ export SOME_BOOL4=NO export SOME_BOOL5=.false. export SOME_BOOL6=.F. +export SOME_LIST1="3, 15, -999" +export SOME_LIST2="0.2,3.5,-9999." +export SOME_LIST3="20221225, 202212251845" +export SOME_LIST4="YES, .false., .T." +export SOME_LIST5="0.2, test_str, 15, 20221225, NO" """ file1 = """#!/bin/bash @@ -60,7 +65,12 @@ 'SOME_BOOL3': True, 'SOME_BOOL4': False, 'SOME_BOOL5': False, - 'SOME_BOOL6': False + 'SOME_BOOL6': False, + 'SOME_LIST1': [3, 15, -999], + 'SOME_LIST2': [0.2, 3.5, -9999.], + 'SOME_LIST3': [datetime(2022, 12, 25, 0, 0, 0), datetime(2022, 12, 25, 18, 45, 0)], + 'SOME_LIST4': [True, False, True], + 'SOME_LIST5': [0.2, 'test_str', 15, datetime(2022, 12, 25, 0, 0, 0), False], } file0_dict_set_envvar = file0_dict.copy() @@ -107,6 +117,14 @@ ('20221215T1830Z', datetime(2022, 12, 15, 18, 30, 0)), ] +list_dtypes = [ + ('3, 15, -999', [3, 15, -999]), + ('0.2,3.5,-9999.', [0.2, 3.5, -9999.]), + ('20221215,20221215T1830Z', [datetime(2022, 12, 15, 0, 0, 0), datetime(2022, 12, 15, 18, 30, 0)]), + ('YES, .false., .T.', [True, False, True]), + ('0.2, test_str, 15, 20221225, NO', [0.2, 'test_str', 15, datetime(2022, 12, 25, 0, 0, 0), False]), +] + def evaluate(dtypes): for pair in dtypes: @@ -134,6 +152,10 @@ def test_cast_as_dtype_datetimes(): evaluate(datetime_dtypes) +def test_cast_as_dtype_list(): + evaluate(list_dtypes) + + @pytest.fixture def create_configs(tmp_path):