diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py index 653885b541114..127ee07a60bbd 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py @@ -59,7 +59,8 @@ class S3ToRedshiftOperator(BaseOperator): - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. You can specify this argument if you want to use a different CA cert bundle than the one used by botocore. - :param column_list: list of column names to load + :param column_list: list of column names to load source data fields into specific target columns + https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-column-mapping.html#copy-column-list :param copy_options: reference to a list of COPY options :param method: Action to be performed on execution. Available ``APPEND``, ``UPSERT`` and ``REPLACE``. :param upsert_keys: List of fields to use as key on upsert action @@ -204,18 +205,13 @@ def execute(self, context: Context) -> None: def get_openlineage_facets_on_complete(self, task_instance): """Implement on_complete as we will query destination table.""" - from pathlib import Path - from airflow.providers.amazon.aws.utils.openlineage import ( get_facets_from_redshift_table, - get_identity_column_lineage_facet, ) from airflow.providers.common.compat.openlineage.facet import ( Dataset, - Identifier, LifecycleStateChange, LifecycleStateChangeDatasetFacet, - SymlinksDatasetFacet, ) from airflow.providers.openlineage.extractors import OperatorLineage @@ -235,36 +231,8 @@ def get_openlineage_facets_on_complete(self, task_instance): database = redshift_sql_hook.conn.schema authority = redshift_sql_hook.get_openlineage_database_info(redshift_sql_hook.conn).authority output_dataset_facets = get_facets_from_redshift_table( - redshift_sql_hook, self.table, self.redshift_data_api_kwargs, self.schema - ) - - input_dataset_facets = {} - if not self.column_list: - # If column_list is not specified, then we know that input file matches columns of output table. - input_dataset_facets["schema"] = output_dataset_facets["schema"] - - dataset_name = self.s3_key - if "*" in dataset_name: - # If wildcard ("*") is used in s3 path, we want the name of dataset to be directory name, - # but we create a symlink to the full object path with wildcard. - input_dataset_facets["symlink"] = SymlinksDatasetFacet( - identifiers=[Identifier(namespace=f"s3://{self.s3_bucket}", name=dataset_name, type="file")] + redshift_sql_hook, self.table, {}, self.schema ) - dataset_name = Path(dataset_name).parent.as_posix() - if dataset_name == ".": - # blob path does not have leading slash, but we need root dataset name to be "/" - dataset_name = "/" - - input_dataset = Dataset( - namespace=f"s3://{self.s3_bucket}", - name=dataset_name, - facets=input_dataset_facets, - ) - - output_dataset_facets["columnLineage"] = get_identity_column_lineage_facet( - field_names=[field.name for field in output_dataset_facets["schema"].fields], - input_datasets=[input_dataset], - ) if self.method == "REPLACE": output_dataset_facets["lifecycleStateChange"] = LifecycleStateChangeDatasetFacet( @@ -277,4 +245,9 @@ def get_openlineage_facets_on_complete(self, task_instance): facets=output_dataset_facets, ) + input_dataset = Dataset( + namespace=f"s3://{self.s3_bucket}", + name=self.s3_key, + ) + return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset]) diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py index f554ce8699971..b80c5991626c0 100644 --- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py +++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py @@ -26,7 +26,13 @@ from airflow.exceptions import AirflowException from airflow.models.connection import Connection from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator -from airflow.providers.common.compat.openlineage.facet import LifecycleStateChange +from airflow.providers.common.compat.openlineage.facet import ( + DocumentationDatasetFacet, + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + SchemaDatasetFacet, + SchemaDatasetFacetFields, +) from tests.test_utils.asserts import assert_equal_ignore_multiple_spaces @@ -502,8 +508,9 @@ def test_using_redshift_data_api(self, mock_rs, mock_run, mock_session, mock_con @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets") @mock.patch("boto3.session.Session") @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run") + @mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table") def test_get_openlineage_facets_on_complete_default( - self, mock_run, mock_session, mock_connection, mock_hook + self, mock_get_facets, mock_run, mock_session, mock_connection, mock_hook ): access_key = "aws_access_key_id" secret_key = "aws_secret_access_key" @@ -515,6 +522,11 @@ def test_get_openlineage_facets_on_complete_default( mock_connection.return_value = mock.MagicMock( schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={} ) + mock_facets = { + "schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]), + "documentation": DocumentationDatasetFacet(description="mock_description"), + } + mock_get_facets.return_value = mock_facets schema = "schema" table = "table" @@ -531,33 +543,30 @@ def test_get_openlineage_facets_on_complete_default( redshift_conn_id="redshift_conn_id", aws_conn_id="aws_conn_id", task_id="task_id", - dag=None, ) op.execute(None) lineage = op.get_openlineage_facets_on_complete(None) - # Hook called two times - on operator execution, and on querying data in redshift to fetch schema - assert mock_run.call_count == 2 + # Hook called only one time - on operator execution - we mocked querying to fetch schema + assert mock_run.call_count == 1 assert len(lineage.inputs) == 1 assert len(lineage.outputs) == 1 assert lineage.inputs[0].name == s3_key + assert lineage.inputs[0].namespace == f"s3://{s3_bucket}" assert lineage.outputs[0].name == f"database.{schema}.{table}" assert lineage.outputs[0].namespace == "redshift://cluster.region:5439" - assert lineage.outputs[0].facets.get("schema") is not None - assert lineage.outputs[0].facets.get("columnLineage") is not None - - assert lineage.inputs[0].facets.get("schema") is not None - # As method was not overwrite, there should be no lifecycleStateChange facet - assert "lifecycleStateChange" not in lineage.outputs[0].facets + assert lineage.outputs[0].facets == mock_facets + assert lineage.inputs[0].facets == {} @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection") @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets") @mock.patch("boto3.session.Session") @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run") + @mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table") def test_get_openlineage_facets_on_complete_replace( - self, mock_run, mock_session, mock_connection, mock_hook + self, mock_get_facets, mock_run, mock_session, mock_connection, mock_hook ): access_key = "aws_access_key_id" secret_key = "aws_secret_access_key" @@ -569,6 +578,11 @@ def test_get_openlineage_facets_on_complete_replace( mock_connection.return_value = mock.MagicMock( schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={} ) + mock_facets = { + "schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]), + "documentation": DocumentationDatasetFacet(description="mock_description"), + } + mock_get_facets.return_value = mock_facets schema = "schema" table = "table" @@ -586,59 +600,25 @@ def test_get_openlineage_facets_on_complete_replace( redshift_conn_id="redshift_conn_id", aws_conn_id="aws_conn_id", task_id="task_id", - dag=None, ) op.execute(None) lineage = op.get_openlineage_facets_on_complete(None) - assert ( - lineage.outputs[0].facets["lifecycleStateChange"].lifecycleStateChange - == LifecycleStateChange.OVERWRITE - ) - - @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection") - @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets") - @mock.patch("boto3.session.Session") - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run") - def test_get_openlineage_facets_on_complete_column_list( - self, mock_run, mock_session, mock_connection, mock_hook - ): - access_key = "aws_access_key_id" - secret_key = "aws_secret_access_key" - mock_session.return_value = Session(access_key, secret_key) - mock_session.return_value.access_key = access_key - mock_session.return_value.secret_key = secret_key - mock_session.return_value.token = None - - mock_connection.return_value = mock.MagicMock( - schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={} - ) - - schema = "schema" - table = "table" - s3_bucket = "bucket" - s3_key = "key" - copy_options = "" - - op = S3ToRedshiftOperator( - schema=schema, - table=table, - s3_bucket=s3_bucket, - s3_key=s3_key, - copy_options=copy_options, - column_list=["column1", "column2"], - redshift_conn_id="redshift_conn_id", - aws_conn_id="aws_conn_id", - task_id="task_id", - dag=None, - ) - op.execute(None) - - lineage = op.get_openlineage_facets_on_complete(None) + assert len(lineage.inputs) == 1 + assert len(lineage.outputs) == 1 + assert lineage.inputs[0].name == s3_key + assert lineage.inputs[0].namespace == f"s3://{s3_bucket}" + assert lineage.outputs[0].name == f"database.{schema}.{table}" + assert lineage.outputs[0].namespace == "redshift://cluster.region:5439" - assert lineage.outputs[0].facets.get("schema") is not None - assert lineage.inputs[0].facets.get("schema") is None + assert lineage.outputs[0].facets == { + **mock_facets, + "lifecycleStateChange": LifecycleStateChangeDatasetFacet( + lifecycleStateChange=LifecycleStateChange.OVERWRITE + ), + } + assert lineage.inputs[0].facets == {} @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection") @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets") @@ -648,8 +628,9 @@ def test_get_openlineage_facets_on_complete_column_list( "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.region_name", new_callable=mock.PropertyMock, ) + @mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table") def test_get_openlineage_facets_on_complete_using_redshift_data_api( - self, mock_rs_region, mock_rs, mock_session, mock_connection, mock_hook + self, mock_get_facets, mock_rs_region, mock_rs, mock_session, mock_connection, mock_hook ): """ Using the Redshift Data API instead of the SQL-based connection @@ -666,6 +647,11 @@ def test_get_openlineage_facets_on_complete_using_redshift_data_api( mock_rs.describe_statement.return_value = {"Status": "FINISHED"} mock_rs_region.return_value = "region" + mock_facets = { + "schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]), + "documentation": DocumentationDatasetFacet(description="mock_description"), + } + mock_get_facets.return_value = mock_facets schema = "schema" table = "table" @@ -689,7 +675,7 @@ def test_get_openlineage_facets_on_complete_using_redshift_data_api( redshift_conn_id="redshift_conn_id", aws_conn_id="aws_conn_id", task_id="task_id", - dag=None, + method="REPLACE", redshift_data_api_kwargs=dict( database=database, cluster_identifier=cluster_identifier, @@ -705,15 +691,17 @@ def test_get_openlineage_facets_on_complete_using_redshift_data_api( assert len(lineage.inputs) == 1 assert len(lineage.outputs) == 1 assert lineage.inputs[0].name == s3_key + assert lineage.inputs[0].namespace == f"s3://{s3_bucket}" assert lineage.outputs[0].name == f"database.{schema}.{table}" assert lineage.outputs[0].namespace == "redshift://cluster.region:5439" - assert lineage.outputs[0].facets.get("schema") is not None - assert lineage.outputs[0].facets.get("columnLineage") is not None - - assert lineage.inputs[0].facets.get("schema") is not None - # As method was not overwrite, there should be no lifecycleStateChange facet - assert "lifecycleStateChange" not in lineage.outputs[0].facets + assert lineage.outputs[0].facets == { + **mock_facets, + "lifecycleStateChange": LifecycleStateChangeDatasetFacet( + lifecycleStateChange=LifecycleStateChange.OVERWRITE + ), + } + assert lineage.inputs[0].facets == {} @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection") @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets") @@ -724,8 +712,9 @@ def test_get_openlineage_facets_on_complete_using_redshift_data_api( "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.region_name", new_callable=mock.PropertyMock, ) + @mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table") def test_get_openlineage_facets_on_complete_data_and_sql_hooks_aligned( - self, mock_rs_region, mock_rs, mock_run, mock_session, mock_connection, mock_hook + self, mock_get_facets, mock_rs_region, mock_rs, mock_run, mock_session, mock_connection, mock_hook ): """ Ensuring both supported hooks - RedshiftDataHook and RedshiftSQLHook return same lineage. @@ -745,6 +734,11 @@ def test_get_openlineage_facets_on_complete_data_and_sql_hooks_aligned( mock_rs.describe_statement.return_value = {"Status": "FINISHED"} mock_rs_region.return_value = "region" + mock_facets = { + "schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]), + "documentation": DocumentationDatasetFacet(description="mock_description"), + } + mock_get_facets.return_value = mock_facets schema = "schema" table = "table" @@ -794,13 +788,9 @@ def test_get_openlineage_facets_on_complete_data_and_sql_hooks_aligned( op_rs_sql.execute(None) rs_sql_lineage = op_rs_sql.get_openlineage_facets_on_complete(None) - assert rs_sql_lineage.inputs == rs_data_lineage.inputs + assert len(rs_sql_lineage.inputs) == 1 assert len(rs_sql_lineage.outputs) == 1 - assert len(rs_data_lineage.outputs) == 1 - assert rs_sql_lineage.outputs[0].facets["schema"] == rs_data_lineage.outputs[0].facets["schema"] - assert ( - rs_sql_lineage.outputs[0].facets["columnLineage"] - == rs_data_lineage.outputs[0].facets["columnLineage"] - ) - assert rs_sql_lineage.outputs[0].name == rs_data_lineage.outputs[0].name - assert rs_sql_lineage.outputs[0].namespace == rs_data_lineage.outputs[0].namespace + assert rs_sql_lineage.inputs == rs_data_lineage.inputs + assert rs_sql_lineage.outputs == rs_data_lineage.outputs + assert rs_sql_lineage.job_facets == rs_data_lineage.job_facets + assert rs_sql_lineage.run_facets == rs_data_lineage.run_facets