diff --git a/tests/providers/sftp/operators/test_sftp.py b/tests/providers/sftp/operators/test_sftp.py index 3b7529bd9abc1..92ca2d6f6dd6f 100644 --- a/tests/providers/sftp/operators/test_sftp.py +++ b/tests/providers/sftp/operators/test_sftp.py @@ -48,6 +48,10 @@ def setup_method(self): self.test_remote_dir = "/tmp/tmp1" self.test_local_filename = 'test_local_file' self.test_remote_filename = 'test_remote_file' + self.test_remote_file_content = ( + b"This is remote file content \n which is also multiline " + b"another line here \n this is last line. EOF" + ) self.test_local_filepath = f'{self.test_dir}/{self.test_local_filename}' # Local Filepath with Intermediate Directory self.test_local_filepath_int_dir = f'{self.test_local_dir}/{self.test_local_filename}' @@ -125,7 +129,7 @@ def test_file_transfer_no_intermediate_dir_error_put(self, create_task_instance_ operation=SFTPOperation.PUT, create_intermediate_dirs=False, ) - with pytest.raises(Exception) as ctx: + with pytest.raises(AirflowException) as ctx: ti2.run() assert 'No such file' in str(ctx.value) @@ -196,20 +200,16 @@ def test_json_file_transfer_put(self, dag_maker): pulled = tis["check_file_task"].xcom_pull(task_ids="check_file_task", key='return_value') assert pulled.strip() == b64encode(test_local_file_content).decode('utf-8') - @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) - def test_pickle_file_transfer_get(self, dag_maker): - test_remote_file_content = ( - "This is remote file content \n which is also multiline " - "another line here \n this is last line. EOF" - ) + @pytest.fixture + def create_remote_file_and_cleanup(self): + with open(self.test_remote_filepath, 'wb') as file: + file.write(self.test_remote_file_content) + yield + os.remove(self.test_remote_filepath) + @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) + def test_pickle_file_transfer_get(self, dag_maker, create_remote_file_and_cleanup): with dag_maker(dag_id="unit_tests_sftp_op_pickle_file_transfer_get"): - SSHOperator( # Create a test file on remote. - task_id="test_create_file", - ssh_hook=self.hook, - command=f"echo '{test_remote_file_content}' > {self.test_remote_filepath}", - do_xcom_push=True, - ) SFTPOperator( # Get remote file to local. task_id="test_sftp", ssh_hook=self.hook, @@ -222,24 +222,13 @@ def test_pickle_file_transfer_get(self, dag_maker): ti.run() # Test the received content. - with open(self.test_local_filepath) as file: + with open(self.test_local_filepath, 'rb') as file: content_received = file.read() - assert content_received.strip() == test_remote_file_content + assert content_received == self.test_remote_file_content @conf_vars({('core', 'enable_xcom_pickling'): 'False'}) - def test_json_file_transfer_get(self, dag_maker): - test_remote_file_content = ( - "This is remote file content \n which is also multiline " - "another line here \n this is last line. EOF" - ) - + def test_json_file_transfer_get(self, dag_maker, create_remote_file_and_cleanup): with dag_maker(dag_id="unit_tests_sftp_op_json_file_transfer_get"): - SSHOperator( # Create a test file on remote. - task_id="test_create_file", - ssh_hook=self.hook, - command=f"echo '{test_remote_file_content}' > {self.test_remote_filepath}", - do_xcom_push=True, - ) SFTPOperator( # Get remote file to local. task_id="test_sftp", ssh_hook=self.hook, @@ -253,24 +242,13 @@ def test_json_file_transfer_get(self, dag_maker): # Test the received content. content_received = None - with open(self.test_local_filepath) as file: + with open(self.test_local_filepath, 'rb') as file: content_received = file.read() - assert content_received.strip() == test_remote_file_content.encode('utf-8').decode('utf-8') + assert content_received == self.test_remote_file_content @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) - def test_file_transfer_no_intermediate_dir_error_get(self, dag_maker): - test_remote_file_content = ( - "This is remote file content \n which is also multiline " - "another line here \n this is last line. EOF" - ) - + def test_file_transfer_no_intermediate_dir_error_get(self, dag_maker, create_remote_file_and_cleanup): with dag_maker(dag_id="unit_tests_sftp_op_file_transfer_no_intermediate_dir_error_get"): - SSHOperator( # Create a test file on remote. - task_id="test_create_file", - ssh_hook=self.hook, - command=f"echo '{test_remote_file_content}' > {self.test_remote_filepath}", - do_xcom_push=True, - ) SFTPOperator( # Try to GET test file from remote. task_id="test_sftp", ssh_hook=self.hook, @@ -279,29 +257,16 @@ def test_file_transfer_no_intermediate_dir_error_get(self, dag_maker): operation=SFTPOperation.GET, ) - ti1, ti2 = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances - ti1.run() - - # This should raise an error with "No such file" as the directory - # does not exist. - with pytest.raises(Exception) as ctx: - ti2.run() - assert 'No such file' in str(ctx.value) + for ti in dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances: + # This should raise an error with "No such file" as the directory + # does not exist. + with pytest.raises(AirflowException) as ctx: + ti.run() + assert 'No such file' in str(ctx.value) @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) - def test_file_transfer_with_intermediate_dir_error_get(self, dag_maker): - test_remote_file_content = ( - "This is remote file content \n which is also multiline " - "another line here \n this is last line. EOF" - ) - + def test_file_transfer_with_intermediate_dir_error_get(self, dag_maker, create_remote_file_and_cleanup): with dag_maker(dag_id="unit_tests_sftp_op_file_transfer_with_intermediate_dir_error_get"): - SSHOperator( # Create a test file on remote. - task_id="test_create_file", - ssh_hook=self.hook, - command=f"echo '{test_remote_file_content}' > {self.test_remote_filepath}", - do_xcom_push=True, - ) SFTPOperator( # Get remote file to local. task_id="test_sftp", ssh_hook=self.hook, @@ -316,9 +281,9 @@ def test_file_transfer_with_intermediate_dir_error_get(self, dag_maker): # Test the received content. content_received = None - with open(self.test_local_filepath_int_dir) as file: + with open(self.test_local_filepath_int_dir, 'rb') as file: content_received = file.read() - assert content_received.strip() == test_remote_file_content + assert content_received == self.test_remote_file_content @mock.patch.dict('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"}) def test_arg_checking(self):