diff --git a/airflow/providers/ssh/hooks/ssh.py b/airflow/providers/ssh/hooks/ssh.py index f3f7d11a57ea2..53c74f44c5e9b 100644 --- a/airflow/providers/ssh/hooks/ssh.py +++ b/airflow/providers/ssh/hooks/ssh.py @@ -68,6 +68,9 @@ class SSHHook(BaseHook): :param keepalive_interval: send a keepalive packet to remote host every keepalive_interval seconds :param banner_timeout: timeout to wait for banner from the server in seconds + :param disabled_algorithms: dictionary mapping algorithm type to an + iterable of algorithm identifiers, which will be disabled for the + lifetime of the transport """ # List of classes to try loading private keys as, ordered (roughly) by most common to least common @@ -112,6 +115,7 @@ def __init__( conn_timeout: Optional[int] = None, keepalive_interval: int = 30, banner_timeout: float = 30.0, + disabled_algorithms: Optional[dict] = None, ) -> None: super().__init__() self.ssh_conn_id = ssh_conn_id @@ -125,6 +129,7 @@ def __init__( self.conn_timeout = conn_timeout self.keepalive_interval = keepalive_interval self.banner_timeout = banner_timeout + self.disabled_algorithms = disabled_algorithms self.host_proxy_cmd = None # Default values, overridable from Connection @@ -197,6 +202,9 @@ def __init__( ): self.look_for_keys = False + if "disabled_algorithms" in extra_options: + self.disabled_algorithms = extra_options.get("disabled_algorithms") + if host_key is not None: if host_key.startswith("ssh-"): key_type, host_key = host_key.split(None)[:2] @@ -313,6 +321,9 @@ def get_conn(self) -> paramiko.SSHClient: if self.key_file: connect_kwargs.update(key_filename=self.key_file) + if self.disabled_algorithms: + connect_kwargs.update(disabled_algorithms=self.disabled_algorithms) + log_before_sleep = lambda retry_state: self.log.info( "Failed to connect. Sleeping before retry attempt %d", retry_state.attempt_number ) diff --git a/docs/apache-airflow-providers-ssh/connections/ssh.rst b/docs/apache-airflow-providers-ssh/connections/ssh.rst index 3cb0e0d8985e5..b91c0854a19ae 100644 --- a/docs/apache-airflow-providers-ssh/connections/ssh.rst +++ b/docs/apache-airflow-providers-ssh/connections/ssh.rst @@ -54,6 +54,7 @@ Extra (optional) * ``allow_host_key_change`` - Set to ``true`` if you want to allow connecting to hosts that has host key changed or when you get 'REMOTE HOST IDENTIFICATION HAS CHANGED' error. This won't protect against Man-In-The-Middle attacks. Other possible solution is to remove the host entry from ``~/.ssh/known_hosts`` file. Default is ``false``. * ``look_for_keys`` - Set to ``false`` if you want to disable searching for discoverable private key files in ``~/.ssh/`` * ``host_key`` - The base64 encoded ssh-rsa public key of the host or "ssh- " (as you would find in the ``known_hosts`` file). Specifying this allows making the connection if and only if the public key of the endpoint matches this value. + * ``disabled_algorithms`` - A dictionary mapping algorithm type to an iterable of algorithm identifiers, which will be disabled for the lifetime of the transport. Example "extras" field: @@ -66,6 +67,7 @@ Extra (optional) "look_for_keys": "false", "allow_host_key_change": "false", "host_key": "AAAHD...YDWwq==" + "disabled_algorithms": {"pubkeys": ["rsa-sha2-256", "rsa-sha2-512"]} } When specifying the connection as URI (in :envvar:`AIRFLOW_CONN_{CONN_ID}` variable) you should specify it diff --git a/tests/providers/ssh/hooks/test_ssh.py b/tests/providers/ssh/hooks/test_ssh.py index 7362ed918dee8..195823857e7ac 100644 --- a/tests/providers/ssh/hooks/test_ssh.py +++ b/tests/providers/ssh/hooks/test_ssh.py @@ -76,6 +76,8 @@ def generate_host_key(pkey: paramiko.PKey): PASSPHRASE = ''.join(random.choice(string.ascii_letters) for i in range(10)) TEST_ENCRYPTED_PRIVATE_KEY = generate_key_string(pkey=TEST_PKEY, passphrase=PASSPHRASE) +TEST_DISABLED_ALGORITHMS = {"pubkeys": ["rsa-sha2-256", "rsa-sha2-512"]} + class TestSSHHook(unittest.TestCase): CONN_SSH_WITH_NO_EXTRA = 'ssh_with_no_extra' @@ -96,6 +98,7 @@ class TestSSHHook(unittest.TestCase): CONN_SSH_WITH_HOST_KEY_AND_ALLOW_HOST_KEY_CHANGES_TRUE = ( 'ssh_with_host_key_and_allow_host_key_changes_true' ) + CONN_SSH_WITH_EXTRA_DISABLED_ALGORITHMS = 'ssh_with_extra_disabled_algorithms' @classmethod def tearDownClass(cls) -> None: @@ -115,6 +118,7 @@ def tearDownClass(cls) -> None: cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE, cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE, cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE, + cls.CONN_SSH_WITH_EXTRA_DISABLED_ALGORITHMS, ] connections = session.query(Connection).filter(Connection.conn_id.in_(conns_to_reset)) connections.delete(synchronize_session=False) @@ -263,6 +267,14 @@ def setUpClass(cls) -> None: ), ) ) + db.merge_conn( + Connection( + conn_id=cls.CONN_SSH_WITH_EXTRA_DISABLED_ALGORITHMS, + host='localhost', + conn_type='ssh', + extra=json.dumps({"disabled_algorithms": TEST_DISABLED_ALGORITHMS}), + ) + ) @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient') def test_ssh_connection_with_password(self, ssh_mock): @@ -747,6 +759,28 @@ def test_ssh_connection_with_all_timeout_param_and_extra_combinations( look_for_keys=True, ) + @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient') + def test_ssh_with_extra_disabled_algorithms(self, ssh_mock): + hook = SSHHook( + ssh_conn_id=self.CONN_SSH_WITH_EXTRA_DISABLED_ALGORITHMS, + remote_host='remote_host', + port='port', + username='username', + ) + + with hook.get_conn(): + ssh_mock.return_value.connect.assert_called_once_with( + banner_timeout=30.0, + hostname='remote_host', + username='username', + compress=True, + timeout=10, + port='port', + sock=None, + look_for_keys=True, + disabled_algorithms=TEST_DISABLED_ALGORITHMS, + ) + def test_openssh_private_key(self): # Paramiko behaves differently with OpenSSH generated keys to paramiko # generated keys, so we need a test one.