diff --git a/docs/apache-airflow-providers-ssh/connections/ssh.rst b/docs/apache-airflow-providers-ssh/connections/ssh.rst index 50c6a75910423..2d13de5cef937 100644 --- a/docs/apache-airflow-providers-ssh/connections/ssh.rst +++ b/docs/apache-airflow-providers-ssh/connections/ssh.rst @@ -57,6 +57,7 @@ Extra (optional) * ``host_key`` - The base64 encoded ssh-rsa public key of the host or "ssh-<key type> <key data>" (as you would find in the ``known_hosts`` file). Specifying this allows making the connection if and only if the public key of the endpoint matches this value. * ``disabled_algorithms`` - A dictionary mapping algorithm type to an iterable of algorithm identifiers, which will be disabled for the lifetime of the transport. * ``ciphers`` - A list of ciphers to use in order of preference. + * ``host_proxy_cmd`` - The proxy command used to reach the remote host. When set, it takes precedence over a ``ProxyCommand`` entry found in the user's SSH config file. Example "extras" field: diff --git a/providers/src/airflow/providers/ssh/hooks/ssh.py b/providers/src/airflow/providers/ssh/hooks/ssh.py index 3502459c64483..f17a5d1c5b863 100644 --- a/providers/src/airflow/providers/ssh/hooks/ssh.py +++ b/providers/src/airflow/providers/ssh/hooks/ssh.py @@ -170,6 +170,19 @@ def __init__( if private_key: self.pkey = self._pkey_from_private_key(private_key, passphrase=private_key_passphrase) + if "host_proxy_cmd" in extra_options: + self.host_proxy_cmd = extra_options.get("host_proxy_cmd") + + if "timeout" in extra_options: + warnings.warn( + "Extra option `timeout` is deprecated. " + "Please use `conn_timeout` instead. " 
+ "The old option `timeout` will be removed in a future version.", + category=AirflowProviderDeprecationWarning, + stacklevel=2, + ) + self.timeout = int(extra_options["timeout"]) + if "conn_timeout" in extra_options and self.conn_timeout is None: self.conn_timeout = int(extra_options["conn_timeout"]) @@ -247,8 +260,10 @@ def __init__( with open(user_ssh_config_filename) as config_fd: ssh_conf.parse(config_fd) host_info = ssh_conf.lookup(self.remote_host) - if host_info and host_info.get("proxycommand") and not self.host_proxy_cmd: - self.host_proxy_cmd = host_info["proxycommand"] + # A proxy command already set via the extra options is not overwritten; the SSH config value is only a fallback. + if not self.host_proxy_cmd: + if host_info and host_info.get("proxycommand"): + self.host_proxy_cmd = host_info["proxycommand"] if not (self.password or self.key_file): if host_info and host_info.get("identityfile"): diff --git a/providers/tests/ssh/hooks/test_ssh.py b/providers/tests/ssh/hooks/test_ssh.py index e09f2eeee0af7..71fa48f951a93 100644 --- a/providers/tests/ssh/hooks/test_ssh.py +++ b/providers/tests/ssh/hooks/test_ssh.py @@ -755,6 +755,24 @@ def test_ssh_with_extra_ciphers(self, ssh_mock): transport = ssh_mock.return_value.get_transport.return_value assert transport.get_security_options.return_value.ciphers == TEST_CIPHERS + def test_host_proxy_cmd_in_extra(self): + TEST_HOST_PROXY_CMD = "ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p" + session = settings.Session() + try: + conn = Connection( + conn_id="ssh_with_proxy_cmd", + host="localhost", + conn_type="ssh", + extra={"host_proxy_cmd": TEST_HOST_PROXY_CMD}, + ) + session.add(conn) + session.flush() + hook = SSHHook(ssh_conn_id=conn.conn_id) + assert hook.host_proxy_cmd == TEST_HOST_PROXY_CMD + finally: + session.delete(conn) + session.commit() + def test_openssh_private_key(self): # Paramiko behaves differently with OpenSSH generated keys to paramiko # generated keys, so we need a test one.