Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions airflow/providers/ssh/hooks/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,18 +471,24 @@ def exec_ssh_client_command(
command: str,
get_pty: bool,
environment: dict | None,
timeout: int | None = None,
timeout: int | ArgNotSet | None = NOTSET,
) -> tuple[int, bytes, bytes]:
self.log.info("Running command: %s", command)

if timeout is None:
timeout = self.cmd_timeout # type: ignore[assignment]
cmd_timeout: int | None
if not isinstance(timeout, ArgNotSet):
cmd_timeout = timeout
elif not isinstance(self.cmd_timeout, ArgNotSet):
cmd_timeout = self.cmd_timeout
else:
cmd_timeout = CMD_TIMEOUT
del timeout # Too easy to confuse with "timedout" below.

# set timeout taken as params
stdin, stdout, stderr = ssh_client.exec_command(
command=command,
get_pty=get_pty,
timeout=timeout,
timeout=cmd_timeout,
environment=environment,
)
# get channels
Expand All @@ -505,8 +511,8 @@ def exec_ssh_client_command(

# read from both stdout and stderr
while not channel.closed or channel.recv_ready() or channel.recv_stderr_ready():
readq, _, _ = select([channel], [], [], timeout)
if timeout is not None:
readq, _, _ = select([channel], [], [], cmd_timeout)
if cmd_timeout is not None:
timedout = len(readq) == 0
for recv in readq:
if recv.recv_ready():
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/ssh/operators/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def exec_ssh_client_command(self, ssh_client: SSHClient, command: str):
)
assert self.ssh_hook
return self.ssh_hook.exec_ssh_client_command(
ssh_client, command, environment=self.environment, get_pty=self.get_pty
ssh_client, command, timeout=self.cmd_timeout, environment=self.environment, get_pty=self.get_pty
)

def raise_for_status(self, exit_status: int, stderr: bytes, context=None) -> None:
Expand All @@ -156,7 +156,7 @@ def raise_for_status(self, exit_status: int, stderr: bytes, context=None) -> Non
def run_ssh_client_command(self, ssh_client: SSHClient, command: str, context=None) -> bytes:
assert self.ssh_hook
exit_status, agg_stdout, agg_stderr = self.ssh_hook.exec_ssh_client_command(
ssh_client, command, environment=self.environment, get_pty=self.get_pty
ssh_client, command, timeout=self.cmd_timeout, environment=self.environment, get_pty=self.get_pty
)
self.raise_for_status(exit_status, agg_stderr, context=context)
return agg_stdout
Expand Down
36 changes: 18 additions & 18 deletions tests/providers/ssh/hooks/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,24 +974,6 @@ def test_exec_ssh_client_command(self):
)
assert ret == (0, b"airflow\n", b"")

@pytest.mark.flaky(reruns=5)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deleted as it's basically the same test as test_command_timeout_fail

def test_command_timeout_default(self):
hook = SSHHook(
ssh_conn_id="ssh_default",
conn_timeout=30,
banner_timeout=100,
)

with hook.get_conn() as client:
with pytest.raises(AirflowException):
hook.exec_ssh_client_command(
client,
"sleep 10",
False,
None,
1,
)

@pytest.mark.flaky(reruns=5)
def test_command_timeout_success(self):
hook = SSHHook(
Expand Down Expand Up @@ -1028,6 +1010,24 @@ def test_command_timeout_fail(self):
None,
)

def test_command_timeout_not_set(self):

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this illustrates how to set infinite timeout

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like with your changes when cmd_timeout is not set in hook/operator, but set to null in connection extra it still means "infinite timeout", so for me it's ok.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correct, you can set it to null in connection extra or None when you create the hook. Also i think after we replace the default value from None to ArgNotSet in SSHOperator as suggested by @uranusjr, we should also be able to set it as the operator level.

hook = SSHHook(
ssh_conn_id="ssh_default",
conn_timeout=30,
cmd_timeout=None,
banner_timeout=100,
)

with hook.get_conn() as client:
# sleeping for 20 sec which is longer than default timeout of 10 seconds
# to validate that no timeout is applied
hook.exec_ssh_client_command(
client,
"sleep 20",
environment=False,
get_pty=None,
)

@mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient")
def test_ssh_connection_with_no_host_key_check_true_and_allow_host_key_changes_true(self, ssh_mock):
hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_NO_HOST_KEY_CHECK_TRUE_AND_ALLOW_HOST_KEY_CHANGES_TRUE)
Expand Down
3 changes: 2 additions & 1 deletion tests/providers/ssh/operators/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.providers.ssh.operators.ssh import SSHOperator
from airflow.utils.timezone import datetime
from airflow.utils.types import NOTSET
from tests.test_utils.config import conf_vars

TEST_DAG_ID = "unit_tests_ssh_test_op"
Expand Down Expand Up @@ -108,7 +109,7 @@ def test_return_value(self, enable_xcom_pickling, output, expected):
result = task.execute(None)
assert result == expected
self.exec_ssh_client_command.assert_called_with(
mock.ANY, COMMAND, environment={"TEST": "value"}, get_pty=False
mock.ANY, COMMAND, timeout=NOTSET, environment={"TEST": "value"}, get_pty=False
)

@mock.patch("os.environ", {"AIRFLOW_CONN_" + TEST_CONN_ID.upper(): "ssh://test_id@localhost"})
Expand Down