From 60712a8cbdd9cb2fb29d52492928443856a3b869 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Ricks?= Date: Mon, 29 Jan 2024 14:52:48 +0100 Subject: [PATCH] Change: Convert GvmConnection into a protocol Improve GvmConnection by being just a protocol and not a specific implementation. This allows for more flexibility. For example the DebugConnection is now also a GvmConnection. --- gvm/connections.py | 38 ++++++++++++------ tests/connections/test_gvm_connection.py | 40 +++++++++++++------ tests/connections/test_ssh_connection.py | 19 +++++---- tests/connections/test_tls_connection.py | 5 +++ .../test_unix_socket_connection.py | 13 +++--- 5 files changed, 76 insertions(+), 39 deletions(-) diff --git a/gvm/connections.py b/gvm/connections.py index e1a1fb38a..cf77cfccd 100644 --- a/gvm/connections.py +++ b/gvm/connections.py @@ -13,9 +13,10 @@ import ssl import sys import time +from abc import ABC, abstractmethod from os import PathLike from pathlib import Path -from typing import Optional, Union +from typing import Optional, Protocol, Union, runtime_checkable import paramiko import paramiko.ssh_exception @@ -41,6 +42,19 @@ Data = Union[str, bytes] +@runtime_checkable +class GvmConnection(Protocol): + def connect(self) -> None: ... + + def disconnect(self) -> None: ... + + def send(self, data: Data) -> None: ... + + def read(self) -> str: ... + + def finish_send(self): ... + + class XmlReader: """ Read a XML command until its closing element @@ -77,7 +91,7 @@ def feed_xml(self, data: Data) -> None: ) from None -class GvmConnection: +class AbstractGvmConnection(ABC): """ Base class for establishing a connection to a remote server daemon. @@ -97,6 +111,7 @@ def _read(self) -> bytes: return self._socket.recv(BUF_SIZE) + @abstractmethod def connect(self) -> None: """Establish a connection to a remote server""" raise NotImplementedError @@ -164,7 +179,7 @@ def finish_send(self): self._socket.shutdown(socketlib.SHUT_WR) -class SSHConnection(GvmConnection): +class SSHConnection(AbstractGvmConnection): """ SSH Class to connect, read and write from GVM via SSH @@ -174,7 +189,7 @@ class SSHConnection(GvmConnection): 127.0.0.1. port: Port of the remote SSH server. Default is port 22. username: Username to use for SSH login. Default is "gmp". - password: Passwort to use for SSH login. Default is "". + password: Password to use for SSH login. Default is "". """ def __init__( @@ -188,8 +203,7 @@ def __init__( known_hosts_file: Optional[Union[str, PathLike]] = None, auto_accept_host: Optional[bool] = None, ) -> None: - super().__init__(timeout=timeout) - + super().__init__(timeout) self.hostname = hostname if hostname is not None else DEFAULT_HOSTNAME self.port = int(port) if port is not None else DEFAULT_SSH_PORT self.username = ( @@ -414,11 +428,11 @@ def connect(self) -> None: def _read(self) -> bytes: return self._stdout.channel.recv(BUF_SIZE) - def send(self, data: Union[bytes, str]) -> int: + def send(self, data: Data) -> None: if isinstance(data, str): - return self._send_all(data.encode()) - - return self._send_all(data) + self._send_all(data.encode()) + else: + self._send_all(data) def finish_send(self) -> None: # shutdown socket for sending. only allow reading data afterwards @@ -439,7 +453,7 @@ def disconnect(self) -> None: del self._socket, self._stdin, self._stdout, self._stderr -class TLSConnection(GvmConnection): +class TLSConnection(AbstractGvmConnection): """ TLS class to connect, read and write from a remote GVM daemon via TLS secured socket. @@ -524,7 +538,7 @@ def disconnect(self): return super().disconnect() -class UnixSocketConnection(GvmConnection): +class UnixSocketConnection(AbstractGvmConnection): """ UNIX-Socket class to connect, read, write from a daemon via direct communicating UNIX-Socket diff --git a/tests/connections/test_gvm_connection.py b/tests/connections/test_gvm_connection.py index fea552073..48b415062 100644 --- a/tests/connections/test_gvm_connection.py +++ b/tests/connections/test_gvm_connection.py @@ -6,7 +6,13 @@ import unittest from unittest.mock import patch -from gvm.connections import DEFAULT_TIMEOUT, GvmConnection, XmlReader +from gvm.connections import ( + DEFAULT_TIMEOUT, + AbstractGvmConnection, + DebugConnection, + GvmConnection, + XmlReader, +) from gvm.errors import GvmError @@ -19,39 +25,49 @@ def test_is_end_xml_false(self): self.assertFalse(false) +class TestConnection(AbstractGvmConnection): + def connect(self) -> None: + pass + + class GvmConnectionTestCase(unittest.TestCase): # pylint: disable=protected-access def test_init_no_args(self): - connection = GvmConnection() + connection = TestConnection() self.check_for_default_values(connection) def test_init_with_none(self): - connection = GvmConnection(timeout=None) + connection = TestConnection(timeout=None) self.check_for_default_values(connection) def check_for_default_values(self, gvm_connection: GvmConnection): self.assertIsNone(gvm_connection._socket) self.assertEqual(gvm_connection._timeout, DEFAULT_TIMEOUT) - def test_connect_not_implemented(self): - connection = GvmConnection() - with self.assertRaises(NotImplementedError): - connection.connect() - - @patch("gvm.connections.GvmConnection._read") + @patch("gvm.connections.AbstractGvmConnection._read") def test_read_no_data(self, _read_mock): _read_mock.return_value = None - connection = GvmConnection() + connection = TestConnection() with self.assertRaises(GvmError, msg="Remote closed the connection"): connection.read() - @patch("gvm.connections.GvmConnection._read") + @patch("gvm.connections.AbstractGvmConnection._read") def test_read_trigger_timeout(self, _read_mock): # mocking the response into two parts, so we run into the timeout # check in the loop _read_mock.side_effect = [b"xyz", b""] - connection = GvmConnection(timeout=0) + connection = TestConnection(timeout=0) with self.assertRaises( GvmError, msg="Timeout while reading the response" ): connection.read() + + def test_is_gvm_connection(self): + connection = TestConnection() + self.assertTrue(isinstance(connection, GvmConnection)) + + +class DebugConnectionTestCase(unittest.TestCase): + def test_is_gvm_connection(self): + connection = DebugConnection(TestConnection()) + self.assertTrue(isinstance(connection, GvmConnection)) diff --git a/tests/connections/test_ssh_connection.py b/tests/connections/test_ssh_connection.py index 029e7922d..75d5fe70b 100644 --- a/tests/connections/test_ssh_connection.py +++ b/tests/connections/test_ssh_connection.py @@ -17,6 +17,7 @@ DEFAULT_SSH_PASSWORD, DEFAULT_SSH_PORT, DEFAULT_SSH_USERNAME, + GvmConnection, SSHConnection, ) from gvm.errors import GvmError @@ -177,7 +178,7 @@ def test_connect_adding_and_save_hostkey(self, input_mock, _print_mock): ) with self.assertLogs("gvm.connections", level="INFO") as cm: - hostkeys = paramiko.HostKeys(filename=self.known_hosts_file) + hostkeys = paramiko.HostKeys(filename=str(self.known_hosts_file)) ssh_connection._ssh_authentication_input_loop( hostkeys=hostkeys, key=key ) @@ -229,7 +230,7 @@ def test_connect_adding_and_dont_save_hostkey( ) with self.assertLogs("gvm.connections", level="INFO") as cm: - hostkeys = paramiko.HostKeys(filename=self.known_hosts_file) + hostkeys = paramiko.HostKeys(filename=str(self.known_hosts_file)) ssh_connection._ssh_authentication_input_loop( hostkeys=hostkeys, key=key ) @@ -274,7 +275,7 @@ def test_connect_wrong_input(self, stdout_mock, input_mock): ssh_connection._socket = paramiko.SSHClient() with self.assertLogs("gvm.connections", level="INFO") as cm: - hostkeys = paramiko.HostKeys(filename=self.known_hosts_file) + hostkeys = paramiko.HostKeys(filename=str(self.known_hosts_file)) ssh_connection._ssh_authentication_input_loop( hostkeys=hostkeys, key=key ) @@ -323,7 +324,7 @@ def test_user_denies_auth(self, input_mock): with self.assertRaises( SystemExit, msg="User denied key. Host key verification failed." ): - hostkeys = paramiko.HostKeys(filename=self.known_hosts_file) + hostkeys = paramiko.HostKeys(filename=str(self.known_hosts_file)) ssh_connection._ssh_authentication_input_loop( hostkeys=hostkeys, key=key ) @@ -441,8 +442,7 @@ def test_send(self): ) ssh_connection.connect() - req = ssh_connection.send("blah") - self.assertEqual(req, 4) + ssh_connection.send("blah") ssh_connection.disconnect() def test_send_error(self): @@ -473,8 +473,7 @@ def test_send_and_slice(self): ) ssh_connection.connect() - req = ssh_connection.send("blah") - self.assertEqual(req, 4) + ssh_connection.send("blah") stdin.channel.send.assert_called() with self.assertRaises(AssertionError): @@ -495,3 +494,7 @@ def test_read(self): recved = ssh_connection._read() self.assertEqual(recved, b"foo bar baz") ssh_connection.disconnect() + + def test_is_gvm_connection(self): + ssh_connection = SSHConnection(known_hosts_file=self.known_hosts_file) + self.assertTrue(isinstance(ssh_connection, GvmConnection)) diff --git a/tests/connections/test_tls_connection.py b/tests/connections/test_tls_connection.py index f479f16da..7b1cf9dd6 100644 --- a/tests/connections/test_tls_connection.py +++ b/tests/connections/test_tls_connection.py @@ -10,6 +10,7 @@ DEFAULT_GVM_PORT, DEFAULT_HOSTNAME, DEFAULT_TIMEOUT, + GvmConnection, TLSConnection, ) @@ -62,3 +63,7 @@ def test_connect_auth(self): context_mock.load_cert_chain.assert_called_once() context_mock.wrap_socket.assert_called_once() self.assertFalse(context_mock.check_hostname) + + def test_is_gvm_connection(self): + connection = TLSConnection() + self.assertTrue(isinstance(connection, GvmConnection)) diff --git a/tests/connections/test_unix_socket_connection.py b/tests/connections/test_unix_socket_connection.py index 478928edd..4568f78e6 100644 --- a/tests/connections/test_unix_socket_connection.py +++ b/tests/connections/test_unix_socket_connection.py @@ -14,6 +14,7 @@ from gvm.connections import ( DEFAULT_TIMEOUT, DEFAULT_UNIX_SOCKET_PATH, + GvmConnection, UnixSocketConnection, ) from gvm.errors import GvmError @@ -65,8 +66,7 @@ def test_unix_socket_connection_connect_send_bytes_read(self): path=self.socketname, timeout=DEFAULT_TIMEOUT ) connection.connect() - req = connection.send(bytes("", "utf-8")) - self.assertIsNone(req) + connection.send(bytes("", "utf-8")) resp = connection.read() self.assertEqual(resp, '') connection.disconnect() @@ -76,8 +76,7 @@ def test_unix_socket_connection_connect_send_str_read(self): path=self.socketname, timeout=DEFAULT_TIMEOUT ) connection.connect() - req = connection.send("") - self.assertIsNone(req) + connection.send("") resp = connection.read() self.assertEqual(resp, '') connection.disconnect() @@ -120,6 +119,6 @@ def check_default_values(self, connection: UnixSocketConnection): self.assertEqual(connection._timeout, DEFAULT_TIMEOUT) self.assertEqual(connection.path, DEFAULT_UNIX_SOCKET_PATH) - -if __name__ == "__main__": - unittest.main() + def test_is_gvm_connection(self): + connection = UnixSocketConnection() + self.assertTrue(isinstance(connection, GvmConnection))