diff --git a/hathor/sysctl/runner.py b/hathor/sysctl/runner.py index ef75a21b6..eb948e6f7 100644 --- a/hathor/sysctl/runner.py +++ b/hathor/sysctl/runner.py @@ -13,6 +13,7 @@ # limitations under the License. import json +import re from typing import TYPE_CHECKING, Any from hathor.sysctl.exception import SysctlRunnerException @@ -20,6 +21,22 @@ if TYPE_CHECKING: from hathor.sysctl.sysctl import Sysctl +# Top level: +# - Allow numbers, integers and float +# - Allow string in double quote +# - Allow mix of numbers with strings +# - Allow space before, after and between elements +# - Don't allow empty element like [1,,2]; empty element between 1 and 2 +# - Don't allow comma (,) after last element +# - Don't allow double quote (") inside string +# Regex: +# - \[ - start list +# - ]] - end list +# - \s* - accept space in any quantity +# - \d+ - accept at least 1 number +# - ^" - negate double quote +array_pattern = r'\[(?:\s*(\d+\.\d+|\s*\d+)\s*|\s*"([^"]*)"\s*)(?:,\s*(\d+\.\d+|\s*\d+)\s*|,\s*"([^"]*)"\s*)*\]|\[\]' + class SysctlRunner: """ Encapsulates the Sysctl to decouple it from the SyctlProtocol. @@ -76,6 +93,9 @@ def deserialize(self, value_str: str) -> Any: if len(value_str) == 0: return () + if re.match(array_pattern, value_str): + return list(json.loads(value_str)) + parts = [x.strip() for x in value_str.split(',')] if len(parts) > 1: return tuple(json.loads(x) for x in parts) diff --git a/tests/cli/test_sysctl_init.py b/tests/cli/test_sysctl_init.py index a696e2008..d3241723e 100644 --- a/tests/cli/test_sysctl_init.py +++ b/tests/cli/test_sysctl_init.py @@ -113,12 +113,14 @@ def register_signal_handlers(self) -> None: 'p2p.max_enabled_sync': 7, 'p2p.rate_limit.global.send_tips': (5, 3), 'p2p.sync_update_interval': 17, + 'p2p.always_enable_sync': ['peer-1', 'peer-2'], } file_content = [ 'p2p.max_enabled_sync=7', 'p2p.rate_limit.global.send_tips=5,3', 'p2p.sync_update_interval=17', + 'p2p.always_enable_sync=["peer-1","peer-2"]', ] with tempfile.NamedTemporaryFile( @@ -155,6 +157,10 @@ def register_signal_handlers(self) -> None: curr_sync_update_interval, expected_sysctl_dict['p2p.sync_update_interval']) + curr_always_enabled_sync = list(conn.always_enable_sync) + self.assertTrue( + set(curr_always_enabled_sync).issuperset(set(expected_sysctl_dict['p2p.always_enable_sync']))) + # assert always_enabled_sync when it is set with a file expected_sysctl_dict = { 'p2p.always_enable_sync': ['peer-3', 'peer-4'], @@ -198,3 +204,36 @@ def register_signal_handlers(self) -> None: curr_always_enabled_sync = list(conn.always_enable_sync) self.assertTrue( set(curr_always_enabled_sync).issuperset(set(expected_sysctl_dict['p2p.always_enable_sync']))) + + @patch('twisted.internet.endpoints.serverFromString') # avoid open sock + def test_sysctl_init_file_failing_list_parsing(self, mock_endpoint): + class CustomRunNode(RunNode): + def start_manager(self) -> None: + pass + + def register_signal_handlers(self) -> None: + pass + + test_cases = [ + ['p2p.always_enable_sync=[ "peer-1", "peer-2", ]'], + ['p2p.always_enable_sync=["peer-1",,"peer-2"]'], + ['p2p.always_enable_sync=["peer-1","\"peer-2\""]'], + ['p2p.always_enable_sync=[1,2,]'], + ['p2p.always_enable_sync=[1,,2]'] + ] + + for file_content in test_cases: + with tempfile.NamedTemporaryFile( + dir=self.tmp_dir, + suffix='.txt', + prefix='sysctl_', + delete=False) as sysctl_init_file: + sysctl_init_file.write('\n'.join(file_content).encode()) + sysctl_init_file_path = str(Path(sysctl_init_file.name)) + + with self.assertRaises(SysctlRunnerException): + CustomRunNode(argv=[ + '--sysctl', 'tcp:8181', + '--sysctl-init-file', sysctl_init_file_path, # relative to src/hathor + '--memory-storage', + ])