diff --git a/hathor/builder/sysctl_builder.py b/hathor/builder/sysctl_builder.py index 60b2cb0ed..46c189ebd 100644 --- a/hathor/builder/sysctl_builder.py +++ b/hathor/builder/sysctl_builder.py @@ -18,6 +18,7 @@ class SysctlBuilder: """Builder for the sysctl tree.""" + def __init__(self, artifacts: BuildArtifacts) -> None: self.artifacts = artifacts diff --git a/hathor/cli/run_node.py b/hathor/cli/run_node.py index 06fefd4c1..00ab40956 100644 --- a/hathor/cli/run_node.py +++ b/hathor/cli/run_node.py @@ -15,7 +15,7 @@ import os import sys from argparse import SUPPRESS, ArgumentParser, Namespace -from typing import Any, Callable +from typing import Any, Callable, Optional from pydantic import ValidationError from structlog import get_logger @@ -58,6 +58,8 @@ def create_parser(cls) -> ArgumentParser: parser.add_argument('--peer', help='json file with peer info') parser.add_argument('--sysctl', help='Endpoint description (eg: unix:/path/sysctl.sock, tcp:5000:interface:127.0.0.1)') + parser.add_argument('--sysctl-init-file', + help='File path to the sysctl.txt init file (eg: conf/sysctl.txt)') parser.add_argument('--listen', action='append', default=[], help='Address to listen for new connections (eg: tcp:8000)') parser.add_argument('--bootstrap', action='append', help='Address to connect to (eg: tcp:127.0.0.1:8000') @@ -371,10 +373,10 @@ def __init__(self, *, argv=None): self.prepare() self.register_signal_handlers() if self._args.sysctl: - self.init_sysctl(self._args.sysctl) + self.init_sysctl(self._args.sysctl, self._args.sysctl_init_file) - def init_sysctl(self, description: str) -> None: - """Initialize sysctl and listen for connections. + def init_sysctl(self, description: str, sysctl_init_file: Optional[str] = None) -> None: + """Initialize sysctl, listen for connections and apply settings from config file if required. Examples of description: - tcp:5000 @@ -389,12 +391,17 @@ def init_sysctl(self, description: str) -> None: from hathor.builder.sysctl_builder import SysctlBuilder from hathor.sysctl.factory import SysctlFactory + from hathor.sysctl.init_file_loader import SysctlInitFileLoader from hathor.sysctl.runner import SysctlRunner builder = SysctlBuilder(self.artifacts) root = builder.build() - runner = SysctlRunner(root) + + if sysctl_init_file: + init_file_loader = SysctlInitFileLoader(runner, sysctl_init_file) + init_file_loader.load() + factory = SysctlFactory(runner) endpoint = serverFromString(self.reactor, description) endpoint.listen(factory) diff --git a/hathor/sysctl/init_file_loader.py b/hathor/sysctl/init_file_loader.py new file mode 100644 index 000000000..9bf09efb6 --- /dev/null +++ b/hathor/sysctl/init_file_loader.py @@ -0,0 +1,18 @@ +from hathor.sysctl.runner import SysctlRunner + + +class SysctlInitFileLoader: + def __init__(self, runner: SysctlRunner, init_file: str) -> None: + assert runner + assert init_file + + self.runner = runner + self.init_file = init_file + + def load(self) -> None: + """Read the init_file and execute each line as a syctl command in the runner.""" + with open(self.init_file, 'r', encoding='utf-8') as file: + lines = file.readlines() + + for line in lines: + self.runner.run(line.strip()) diff --git a/hathor/sysctl/runner.py b/hathor/sysctl/runner.py index 85850b2db..ef75a21b6 100644 --- a/hathor/sysctl/runner.py +++ b/hathor/sysctl/runner.py @@ -75,6 +75,7 @@ def deserialize(self, value_str: str) -> Any: """Deserialize a value sent by the client.""" if len(value_str) == 0: return () + parts = [x.strip() for x in value_str.split(',')] if len(parts) > 1: return tuple(json.loads(x) for x in parts) diff --git a/tests/cli/test_sysctl_init.py b/tests/cli/test_sysctl_init.py new file mode 100644 index 000000000..a696e2008 --- /dev/null +++ b/tests/cli/test_sysctl_init.py @@ -0,0 +1,200 @@ +import json +import shutil +import tempfile +from pathlib import Path +from unittest.mock import Mock, patch + +from hathor.builder.sysctl_builder import SysctlBuilder +from hathor.cli.run_node import RunNode +from hathor.sysctl.exception import SysctlEntryNotFound, SysctlRunnerException +from hathor.sysctl.init_file_loader import SysctlInitFileLoader +from hathor.sysctl.runner import SysctlRunner +from tests import unittest + + +class SysctlInitTest(unittest.TestCase): + + def setUp(self): + super().setUp() + self.tmp_dir = tempfile.mkdtemp() + + def tearDown(self): + super().tearDown() + # Removing tmpdir + shutil.rmtree(self.tmp_dir) + + def test_sysctl_builder_fail_with_invalid_property(self): + file_content = [ + 'invalid.property=10', + ] + + with tempfile.NamedTemporaryFile( + dir=self.tmp_dir, + suffix='.txt', + prefix='sysctl_', + delete=False) as sysctl_init_file: + sysctl_init_file.write('\n'.join(file_content).encode()) + sysctl_init_file_path = str(Path(sysctl_init_file.name)) + + # prepare to register only p2p commands + artifacts = Mock(**{ + 'p2p_manager': Mock(), + 'manager.metrics.websocket_factory.return_value': None + }) + + with self.assertRaises(SysctlEntryNotFound) as context: + root = SysctlBuilder(artifacts).build() + runner = SysctlRunner(root) + loader = SysctlInitFileLoader(runner, sysctl_init_file_path) + loader.load() + + # assert message in the caught exception + expected_msg = 'invalid.property' + self.assertEqual(str(context.exception), expected_msg) + + def test_sysctl_builder_fail_with_invalid_value(self): + file_content = [ + 'p2p.rate_limit.global.send_tips=!!tuple [1,2]' + ] + + with tempfile.NamedTemporaryFile( + dir=self.tmp_dir, + suffix='.txt', + prefix='sysctl_', + delete=False) as sysctl_init_file: + sysctl_init_file.write('\n'.join(file_content).encode()) + sysctl_init_file_path = str(Path(sysctl_init_file.name)) + + # prepare to register only p2p commands + artifacts = Mock(**{ + 'p2p_manager': Mock(), + 'manager.metrics.websocket_factory.return_value': None + }) + + with self.assertRaises(SysctlRunnerException) as context: + root = SysctlBuilder(artifacts).build() + runner = SysctlRunner(root) + loader = SysctlInitFileLoader(runner, sysctl_init_file_path) + loader.load() + + # assert message in the caught exception + expected_msg = 'value: wrong format' + self.assertEqual(str(context.exception), expected_msg) + + def test_syctl_init_file_fail_with_empty_or_invalid_file(self): + # prepare to register only p2p commands + artifacts = Mock(**{ + 'p2p_manager': Mock(), + 'manager.metrics.websocket_factory.return_value': None + }) + + with self.assertRaises(AssertionError): + root = SysctlBuilder(artifacts).build() + runner = SysctlRunner(root) + loader = SysctlInitFileLoader(runner, None) + loader.load() + + with self.assertRaises(AssertionError): + root = SysctlBuilder(artifacts).build() + runner = SysctlRunner(root) + loader = SysctlInitFileLoader(runner, "") + loader.load() + + @patch('twisted.internet.endpoints.serverFromString') # avoid open sock + def test_command_option_sysctl_init_file(self, mock_endpoint): + class CustomRunNode(RunNode): + def start_manager(self) -> None: + pass + + def register_signal_handlers(self) -> None: + pass + + expected_sysctl_dict = { + 'p2p.max_enabled_sync': 7, + 'p2p.rate_limit.global.send_tips': (5, 3), + 'p2p.sync_update_interval': 17, + } + + file_content = [ + 'p2p.max_enabled_sync=7', + 'p2p.rate_limit.global.send_tips=5,3', + 'p2p.sync_update_interval=17', + ] + + with tempfile.NamedTemporaryFile( + dir=self.tmp_dir, + suffix='.txt', + prefix='sysctl_', + delete=False) as sysctl_init_file: + sysctl_init_file.write('\n'.join(file_content).encode()) + sysctl_init_file_path = str(Path(sysctl_init_file.name)) + + run_node = CustomRunNode(argv=[ + '--sysctl', 'tcp:8181', + '--sysctl-init-file', sysctl_init_file_path, # relative to src/hathor + '--memory-storage', + ]) + self.assertTrue(run_node is not None) + conn = run_node.manager.connections + + curr_max_enabled_sync = conn.MAX_ENABLED_SYNC + self.assertEqual( + curr_max_enabled_sync, + expected_sysctl_dict['p2p.max_enabled_sync']) + + curr_rate_limit_global_send_tips = conn.rate_limiter.get_limit(conn.GlobalRateLimiter.SEND_TIPS) + self.assertEqual( + curr_rate_limit_global_send_tips.max_hits, + expected_sysctl_dict['p2p.rate_limit.global.send_tips'][0]) + self.assertEqual( + curr_rate_limit_global_send_tips.window_seconds, + expected_sysctl_dict['p2p.rate_limit.global.send_tips'][1]) + + curr_sync_update_interval = conn.lc_sync_update_interval + self.assertEqual( + curr_sync_update_interval, + expected_sysctl_dict['p2p.sync_update_interval']) + + # assert always_enabled_sync when it is set with a file + expected_sysctl_dict = { + 'p2p.always_enable_sync': ['peer-3', 'peer-4'], + } + + file_content = [ + 'peer-3', + 'peer-4', + ] + + # set the always_enabled_sync peers file + with tempfile.NamedTemporaryFile( + dir=self.tmp_dir, + suffix='.txt', + prefix='always_enable_sync_', + delete=False) as always_enabled_peers_file: + always_enabled_peers_file.write('\n'.join(file_content).encode()) + always_enabled_peers_file_path = str(Path(always_enabled_peers_file.name)) + + file_content = [ + f'p2p.always_enable_sync.readtxt={json.dumps(always_enabled_peers_file_path)}' + ] + + # set the sysctl.txt file + with tempfile.NamedTemporaryFile( + dir=self.tmp_dir, + suffix='.txt', + prefix='sysctl_', + delete=False) as sysctl_init_file: + sysctl_init_file.write('\n'.join(file_content).encode()) + sysctl_init_file_path = str(Path(sysctl_init_file.name)) + + run_node = CustomRunNode(argv=[ + '--sysctl', 'tcp:8181', + '--sysctl-init-file', sysctl_init_file_path, # relative to src/hathor + '--memory-storage', + ]) + self.assertTrue(run_node is not None) + conn = run_node.manager.connections + + curr_always_enabled_sync = list(conn.always_enable_sync) + self.assertTrue( + set(curr_always_enabled_sync).issuperset(set(expected_sysctl_dict['p2p.always_enable_sync'])))