From 09c5f59549d94acae9af064fc6218eb432933e23 Mon Sep 17 00:00:00 2001 From: Paul Williams Date: Mon, 26 Sep 2022 02:05:29 +0000 Subject: [PATCH 1/2] Add support for list of filepaths to SFTPOperator --- airflow/providers/sftp/operators/sftp.py | 57 +++++++++++++-------- tests/providers/sftp/operators/test_sftp.py | 56 ++++++++++++++++++++ 2 files changed, 92 insertions(+), 21 deletions(-) diff --git a/airflow/providers/sftp/operators/sftp.py b/airflow/providers/sftp/operators/sftp.py index 0d18594b818ce..d0f11c61a881b 100644 --- a/airflow/providers/sftp/operators/sftp.py +++ b/airflow/providers/sftp/operators/sftp.py @@ -52,8 +52,8 @@ class SFTPOperator(BaseOperator): :param remote_host: remote host to connect (templated) Nullable. If provided, it will replace the `remote_host` which was defined in `sftp_hook`/`ssh_hook` or predefined in the connection of `ssh_conn_id`. - :param local_filepath: local file path to get or put. (templated) - :param remote_filepath: remote file path to get or put. (templated) + :param local_filepath: local file path or list of local file paths to get or put. (templated) + :param remote_filepath: remote file path or list of remote file paths to get or put. (templated) :param operation: specify operation 'get' or 'put', defaults to put :param confirm: specify if the SFTP operation should be confirmed, defaults to True :param create_intermediate_dirs: create missing intermediate directories when @@ -85,8 +85,8 @@ def __init__( sftp_hook: SFTPHook | None = None, ssh_conn_id: str | None = None, remote_host: str | None = None, - local_filepath: str, - remote_filepath: str, + local_filepath: str | list[str], + remote_filepath: str | list[str], operation: str = SFTPOperation.PUT, confirm: bool = True, create_intermediate_dirs: bool = False, @@ -97,12 +97,26 @@ def __init__( self.sftp_hook = sftp_hook self.ssh_conn_id = ssh_conn_id self.remote_host = remote_host - self.local_filepath = local_filepath - self.remote_filepath = remote_filepath self.operation = operation self.confirm = confirm self.create_intermediate_dirs = create_intermediate_dirs + if isinstance(local_filepath, str): + self.local_filepath = [local_filepath] + else: + self.local_filepath = local_filepath + + if isinstance(remote_filepath, str): + self.remote_filepath = [remote_filepath] + else: + self.remote_filepath = remote_filepath + + if len(self.local_filepath) != len(self.remote_filepath): + raise ValueError( + f'{len(self.local_filepath)} paths in local_filepath ' + f'!= {len(self.remote_filepath)} paths in remote_filepath' + ) + if not (self.operation.lower() == SFTPOperation.GET or self.operation.lower() == SFTPOperation.PUT): raise TypeError( f"Unsupported operation value {self.operation}, " @@ -129,7 +143,7 @@ def __init__( ) self.sftp_hook = SFTPHook(ssh_hook=self.ssh_hook) - def execute(self, context: Any) -> str | None: + def execute(self, context: Any) -> list[str] | None: file_msg = None try: if self.ssh_conn_id: @@ -152,20 +166,21 @@ def execute(self, context: Any) -> str | None: ) self.sftp_hook.remote_host = self.remote_host - if self.operation.lower() == SFTPOperation.GET: - local_folder = os.path.dirname(self.local_filepath) - if self.create_intermediate_dirs: - Path(local_folder).mkdir(parents=True, exist_ok=True) - file_msg = f"from {self.remote_filepath} to {self.local_filepath}" - self.log.info("Starting to transfer %s", file_msg) - self.sftp_hook.retrieve_file(self.remote_filepath, self.local_filepath) - else: - remote_folder = os.path.dirname(self.remote_filepath) - if self.create_intermediate_dirs: - self.sftp_hook.create_directory(remote_folder) - file_msg = f"from {self.local_filepath} to {self.remote_filepath}" - self.log.info("Starting to transfer file %s", file_msg) - self.sftp_hook.store_file(self.remote_filepath, self.local_filepath, confirm=self.confirm) + for local_filepath, remote_filepath in zip(self.local_filepath, self.remote_filepath): + if self.operation.lower() == SFTPOperation.GET: + local_folder = os.path.dirname(local_filepath) + if self.create_intermediate_dirs: + Path(local_folder).mkdir(parents=True, exist_ok=True) + file_msg = f"from {remote_filepath} to {local_filepath}" + self.log.info("Starting to transfer %s", file_msg) + self.sftp_hook.retrieve_file(remote_filepath, local_filepath) + else: + remote_folder = os.path.dirname(remote_filepath) + if self.create_intermediate_dirs: + self.sftp_hook.create_directory(remote_folder) + file_msg = f"from {local_filepath} to {remote_filepath}" + self.log.info("Starting to transfer file %s", file_msg) + self.sftp_hook.store_file(remote_filepath, local_filepath, confirm=self.confirm) except Exception as e: raise AirflowException(f"Error while transferring {file_msg}, error: {str(e)}") diff --git a/tests/providers/sftp/operators/test_sftp.py b/tests/providers/sftp/operators/test_sftp.py index 92ca2d6f6dd6f..f2754a53a5b88 100644 --- a/tests/providers/sftp/operators/test_sftp.py +++ b/tests/providers/sftp/operators/test_sftp.py @@ -43,6 +43,9 @@ def setup_method(self): hook = SSHHook(ssh_conn_id='ssh_default') hook.no_host_key_check = True self.hook = hook + sftp_hook = SFTPHook(ssh_conn_id='ssh_default') + sftp_hook.no_host_key_check = True + self.sftp_hook = sftp_hook self.test_dir = "/tmp" self.test_local_dir = "/tmp/tmp2" self.test_remote_dir = "/tmp/tmp1" @@ -386,3 +389,56 @@ def test_arg_checking(self): except Exception: pass assert task_6.sftp_hook.remote_host == 'remotehost' + + def test_unequal_local_remote_file_paths(self): + with pytest.raises(ValueError): + SFTPOperator( + task_id='test_sftp_unequal_paths', + local_filepath='/tmp/test', + remote_filepath=['/tmp/test1', '/tmp/test2'], + ) + + def test_str_filepaths_converted_to_lists(self): + local_filepath = '/tmp/test' + remote_filepath = '/tmp/remotetest' + sftp_op = SFTPOperator( + task_id='test_str_to_list', local_filepath=local_filepath, remote_filepath=remote_filepath + ) + assert sftp_op.local_filepath == [local_filepath] + assert sftp_op.remote_filepath == [remote_filepath] + + @mock.patch('airflow.providers.sftp.operators.sftp.SFTPHook.retrieve_file') + def test_multiple_paths_get(self, mock_get): + local_filepath = ['/tmp/ltest1', '/tmp/ltest2'] + remote_filepath = ['/tmp/rtest1', '/tmp/rtest2'] + sftp_op = SFTPOperator( + task_id='test_multiple_paths_get', + sftp_hook=self.sftp_hook, + local_filepath=local_filepath, + remote_filepath=remote_filepath, + operation=SFTPOperation.GET, + ) + sftp_op.execute(None) + assert mock_get.call_count == 2 + args0, _ = mock_get.call_args_list[0] + args1, _ = mock_get.call_args_list[1] + assert args0 == (remote_filepath[0], local_filepath[0]) + assert args1 == (remote_filepath[1], local_filepath[1]) + + @mock.patch('airflow.providers.sftp.operators.sftp.SFTPHook.store_file') + def test_multiple_paths_put(self, mock_put): + local_filepath = ['/tmp/ltest1', '/tmp/ltest2'] + remote_filepath = ['/tmp/rtest1', '/tmp/rtest2'] + sftp_op = SFTPOperator( + task_id='test_multiple_paths_get', + sftp_hook=self.sftp_hook, + local_filepath=local_filepath, + remote_filepath=remote_filepath, + operation=SFTPOperation.PUT, + ) + sftp_op.execute(None) + assert mock_put.call_count == 2 + args0, _ = mock_put.call_args_list[0] + args1, _ = mock_put.call_args_list[1] + assert args0 == (remote_filepath[0], local_filepath[0]) + assert args1 == (remote_filepath[1], local_filepath[1]) From 48f12e637ea906b38e6ff26b137dab345985c81c Mon Sep 17 00:00:00 2001 From: Paul Williams Date: Mon, 26 Sep 2022 02:44:34 +0000 Subject: [PATCH 2/2] Return str in execute if local_filepath was passed as str --- airflow/providers/sftp/operators/sftp.py | 6 ++++-- tests/providers/sftp/operators/test_sftp.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/airflow/providers/sftp/operators/sftp.py b/airflow/providers/sftp/operators/sftp.py index d0f11c61a881b..bf15dc4959c8e 100644 --- a/airflow/providers/sftp/operators/sftp.py +++ b/airflow/providers/sftp/operators/sftp.py @@ -101,8 +101,10 @@ def __init__( self.confirm = confirm self.create_intermediate_dirs = create_intermediate_dirs + self.local_filepath_was_str = False if isinstance(local_filepath, str): self.local_filepath = [local_filepath] + self.local_filepath_was_str = True else: self.local_filepath = local_filepath @@ -143,7 +145,7 @@ def __init__( ) self.sftp_hook = SFTPHook(ssh_hook=self.ssh_hook) - def execute(self, context: Any) -> list[str] | None: + def execute(self, context: Any) -> str | list[str] | None: file_msg = None try: if self.ssh_conn_id: @@ -185,4 +187,4 @@ def execute(self, context: Any) -> list[str] | None: except Exception as e: raise AirflowException(f"Error while transferring {file_msg}, error: {str(e)}") - return self.local_filepath + return self.local_filepath[0] if self.local_filepath_was_str else self.local_filepath diff --git a/tests/providers/sftp/operators/test_sftp.py b/tests/providers/sftp/operators/test_sftp.py index f2754a53a5b88..1ea1838952186 100644 --- a/tests/providers/sftp/operators/test_sftp.py +++ b/tests/providers/sftp/operators/test_sftp.py @@ -442,3 +442,18 @@ def test_multiple_paths_put(self, mock_put): args1, _ = mock_put.call_args_list[1] assert args0 == (remote_filepath[0], local_filepath[0]) assert args1 == (remote_filepath[1], local_filepath[1]) + + @mock.patch('airflow.providers.sftp.operators.sftp.SFTPHook.retrieve_file') + def test_return_str_when_local_filepath_was_str(self, mock_get): + local_filepath = '/tmp/ltest1' + remote_filepath = '/tmp/rtest1' + sftp_op = SFTPOperator( + task_id='test_returns_str', + sftp_hook=self.sftp_hook, + local_filepath=local_filepath, + remote_filepath=remote_filepath, + operation=SFTPOperation.GET, + ) + return_value = sftp_op.execute(None) + assert isinstance(return_value, str) + assert return_value == local_filepath