diff --git a/hathor/sysctl/protocol.py b/hathor/sysctl/protocol.py index 45c879dc7..c84b10900 100644 --- a/hathor/sysctl/protocol.py +++ b/hathor/sysctl/protocol.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import json -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Callable, List, Optional from pydantic import ValidationError from twisted.protocols.basic import LineReceiver @@ -35,7 +36,11 @@ def lineReceived(self, raw: bytes) -> None: line = raw.decode('utf-8').strip() except UnicodeDecodeError: self.sendError('command is not utf-8 valid') - if line == '!backup': + if line.startswith('!help'): + _, _, path = line.partition(' ') + self.help(path) + return + elif line == '!backup': self.backup() return head, separator, tail = line.partition('=') @@ -89,6 +94,28 @@ def backup(self) -> None: output = f'{key}={self._serialize(value)}' self.sendLine(output.encode('utf-8')) + def help(self, path: str) -> None: + """Show all available commands.""" + if path == '': + self._send_all_commands() + return + try: + cmd = self.root.get_command(path) + except SysctlEntryNotFound: + self.sendError(f'{path} not found') + return + + output: List[str] = [] + output.extend(self._get_method_help('getter', cmd.getter)) + output.append('') + output.extend(self._get_method_help('setter', cmd.setter)) + self.sendLine('\n'.join(output).encode('utf-8')) + + def _send_all_commands(self) -> None: + all_paths = list(self.root.get_all_paths()) + for path in sorted(all_paths): + self.sendLine(path.encode('utf-8')) + def _serialize(self, value: Any) -> str: """Serialize the return of a sysctl getter.""" output: str @@ -107,3 +134,16 @@ def _deserialize(self, value_str: str) -> Any: if len(parts) > 1: return tuple(json.loads(x) for x in parts) return json.loads(value_str) + + def _get_method_help(self, method_name: str, method: Optional[Callable]) -> List[str]: + """Return a list of strings with the help for `method`.""" + if method is None: + return [f'{method_name}: not available'] + + output: List[str] = [] + doc: str = inspect.getdoc(method) or '(no help found)' + signature = inspect.signature(method) + output.append(f'{method_name}{signature}:') + for line in doc.splitlines(): + output.append(f' {line.strip()}') + return output diff --git a/hathor/sysctl/sysctl.py b/hathor/sysctl/sysctl.py index c45f62667..a2f40a778 100644 --- a/hathor/sysctl/sysctl.py +++ b/hathor/sysctl/sysctl.py @@ -102,3 +102,10 @@ def get_all(self, prefix: str = '') -> Iterator[Tuple[str, Any]]: continue value = cmd.getter() yield (self.path_join(prefix, path), value) + + def get_all_paths(self, prefix: str = '') -> Iterator[str]: + """Return all available paths.""" + for path, child in self._children.items(): + yield from child.get_all_paths(self.path_join(prefix, path)) + for path, cmd in self._commands.items(): + yield self.path_join(prefix, path) diff --git a/tests/sysctl/test_sysctl.py b/tests/sysctl/test_sysctl.py index ead97371c..6e0e80ab5 100644 --- a/tests/sysctl/test_sysctl.py +++ b/tests/sysctl/test_sysctl.py @@ -16,11 +16,16 @@ class SysctlTest(unittest.TestCase): def setUp(self) -> None: super().setUp() + getter_max_connections = MagicMock(return_value=3) + getter_max_connections.__doc__ = 'Return the number of maximum connections.' + setter_max_connections = MagicMock() + setter_max_connections.__doc__ = 'Set the number of maximum connections.' + net = Sysctl() net.register( 'max_connections', - MagicMock(return_value=3), # int - MagicMock(), + getter_max_connections, # int + setter_max_connections, ) net.register( 'readonly', @@ -141,6 +146,17 @@ def test_get_all(self) -> None: ('net.readonly', 0.25), }) + def test_get_all_paths(self) -> None: + all_items = set(self.root.get_all_paths()) + self.assertEqual(all_items, { + 'net.max_connections', + 'core.writeonly', + 'core.loglevel', + 'net.rate_limit', + 'net.readonly', + 'ab.bc.cd.useless', + }) + ################## # Protocol: Get ################## @@ -228,3 +244,30 @@ def test_proto_backup(self) -> None: b'net.readonly=0.25', b'', # output ends with a new line (\n) }) + + def test_proto_help(self) -> None: + self.proto.lineReceived(b'!help') + output = self.tr.value() + lines = set(output.split(b'\n')) + self.assertEqual(lines, { + b'net.max_connections', + b'core.writeonly', + b'core.loglevel', + b'net.rate_limit', + b'net.readonly', + b'ab.bc.cd.useless', + b'', # output ends with a new line (\n) + }) + + def test_proto_help_method(self) -> None: + self.proto.lineReceived(b'!help net.max_connections') + output = self.tr.value() + lines = output.split(b'\n') + self.assertEqual(lines, [ + b'getter(*args, **kwargs):', + b' Return the number of maximum connections.', + b'', + b'setter(*args, **kwargs):', + b' Set the number of maximum connections.', + b'' + ])