diff --git a/airflow/providers/google/cloud/sensors/bigquery.py b/airflow/providers/google/cloud/sensors/bigquery.py index d4f15fac1bf9d..d1492fe2480f0 100644 --- a/airflow/providers/google/cloud/sensors/bigquery.py +++ b/airflow/providers/google/cloud/sensors/bigquery.py @@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Sequence from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook from airflow.providers.google.cloud.triggers.bigquery import ( BigQueryTableExistenceTrigger, @@ -141,8 +141,16 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None if event: if event["status"] == "success": return event["message"] + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) - raise AirflowException("No event received in trigger callback") + + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = "No event received in trigger callback" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) class BigQueryTablePartitionExistenceSensor(BaseSensorOperator): @@ -248,8 +256,17 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None if event: if event["status"] == "success": return event["message"] + + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) - raise AirflowException("No event received in trigger callback") + + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = "No event received in trigger callback" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) class BigQueryTableExistenceAsyncSensor(BigQueryTableExistenceSensor): diff --git a/airflow/providers/google/cloud/sensors/bigquery_dts.py b/airflow/providers/google/cloud/sensors/bigquery_dts.py index 34198d2819bea..b4926b3b95f99 100644 --- a/airflow/providers/google/cloud/sensors/bigquery_dts.py +++ b/airflow/providers/google/cloud/sensors/bigquery_dts.py @@ -23,7 +23,7 @@ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.cloud.bigquery_datatransfer_v1 import TransferState -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.google.cloud.hooks.bigquery_dts import BiqQueryDataTransferServiceHook from airflow.sensors.base import BaseSensorOperator @@ -140,5 +140,9 @@ def poke(self, context: Context) -> bool: self.log.info("Status of %s run: %s", self.run_id, str(run.state)) if run.state in (TransferState.FAILED, TransferState.CANCELLED): - raise AirflowException(f"Transfer {self.run_id} did not succeed") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Transfer {self.run_id} did not succeed" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) return run.state in self.expected_statuses diff --git a/airflow/providers/google/cloud/sensors/cloud_composer.py b/airflow/providers/google/cloud/sensors/cloud_composer.py index ecd717aa5a152..1873b51d68022 100644 --- a/airflow/providers/google/cloud/sensors/cloud_composer.py +++ b/airflow/providers/google/cloud/sensors/cloud_composer.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Sequence -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.google.cloud.triggers.cloud_composer import CloudComposerExecutionTrigger from airflow.sensors.base import BaseSensorOperator @@ -90,5 +90,14 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None if event: if event.get("operation_done"): return event["operation_done"] + + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) - raise AirflowException("No event received in trigger callback") + + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = "No event received in trigger callback" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) diff --git a/airflow/providers/google/cloud/sensors/dataflow.py b/airflow/providers/google/cloud/sensors/dataflow.py index 187b4c00070ab..c9f32588d54a6 100644 --- a/airflow/providers/google/cloud/sensors/dataflow.py +++ b/airflow/providers/google/cloud/sensors/dataflow.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Callable, Sequence -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.google.cloud.hooks.dataflow import ( DEFAULT_DATAFLOW_LOCATION, DataflowHook, @@ -106,7 +106,11 @@ def poke(self, context: Context) -> bool: if job_status in self.expected_statuses: return True elif job_status in DataflowJobStatus.TERMINAL_STATES: - raise AirflowException(f"Job with id '{self.job_id}' is already in terminal state: {job_status}") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Job with id '{self.job_id}' is already in terminal state: {job_status}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) return False @@ -178,9 +182,11 @@ def poke(self, context: Context) -> bool: ) job_status = job["currentState"] if job_status in DataflowJobStatus.TERMINAL_STATES: - raise AirflowException( - f"Job with id '{self.job_id}' is already in terminal state: {job_status}" - ) + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Job with id '{self.job_id}' is already in terminal state: {job_status}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) result = self.hook.fetch_job_metrics_by_id( job_id=self.job_id, @@ -257,9 +263,11 @@ def poke(self, context: Context) -> bool: ) job_status = job["currentState"] if job_status in DataflowJobStatus.TERMINAL_STATES: - raise AirflowException( - f"Job with id '{self.job_id}' is already in terminal state: {job_status}" - ) + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Job with id '{self.job_id}' is already in terminal state: {job_status}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) result = self.hook.fetch_job_messages_by_id( job_id=self.job_id, @@ -336,9 +344,11 @@ def poke(self, context: Context) -> bool: ) job_status = job["currentState"] if job_status in DataflowJobStatus.TERMINAL_STATES: - raise AirflowException( - f"Job with id '{self.job_id}' is already in terminal state: {job_status}" - ) + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Job with id '{self.job_id}' is already in terminal state: {job_status}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) result = self.hook.fetch_job_autoscaling_events_by_id( job_id=self.job_id, diff --git a/airflow/providers/google/cloud/sensors/dataform.py b/airflow/providers/google/cloud/sensors/dataform.py index 965e9c5fe2cde..45c74627a7c52 100644 --- a/airflow/providers/google/cloud/sensors/dataform.py +++ b/airflow/providers/google/cloud/sensors/dataform.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Iterable, Sequence -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.google.cloud.hooks.dataform import DataformHook from airflow.sensors.base import BaseSensorOperator @@ -95,9 +95,13 @@ def poke(self, context: Context) -> bool: workflow_status = workflow_invocation.state if workflow_status is not None: if self.failure_statuses and workflow_status in self.failure_statuses: - raise AirflowException( + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = ( f"Workflow Invocation with id '{self.workflow_invocation_id}' " f"state is: {workflow_status}. Terminating sensor..." ) + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) return workflow_status in self.expected_statuses diff --git a/airflow/providers/google/cloud/sensors/datafusion.py b/airflow/providers/google/cloud/sensors/datafusion.py index 8297d60f4403a..b151a6fae7ea0 100644 --- a/airflow/providers/google/cloud/sensors/datafusion.py +++ b/airflow/providers/google/cloud/sensors/datafusion.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Iterable, Sequence -from airflow.exceptions import AirflowException, AirflowNotFoundException +from airflow.exceptions import AirflowException, AirflowNotFoundException, AirflowSkipException from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook from airflow.sensors.base import BaseSensorOperator @@ -109,15 +109,23 @@ def poke(self, context: Context) -> bool: ) pipeline_status = pipeline_workflow["status"] except AirflowNotFoundException: - raise AirflowException("Specified Pipeline ID was not found.") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = "Specified Pipeline ID was not found." + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) except AirflowException: pass # Because the pipeline may not be visible in system yet if pipeline_status is not None: if self.failure_statuses and pipeline_status in self.failure_statuses: - raise AirflowException( + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = ( f"Pipeline with id '{self.pipeline_id}' state is: {pipeline_status}. " f"Terminating sensor..." ) + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) self.log.debug( "Current status of the pipeline workflow for %s: %s.", self.pipeline_id, pipeline_status diff --git a/airflow/providers/google/cloud/sensors/dataplex.py b/airflow/providers/google/cloud/sensors/dataplex.py index c00373f9476a4..ee0ffc7410d9e 100644 --- a/airflow/providers/google/cloud/sensors/dataplex.py +++ b/airflow/providers/google/cloud/sensors/dataplex.py @@ -24,11 +24,12 @@ from google.api_core.retry import Retry from airflow.utils.context import Context + from google.api_core.exceptions import GoogleAPICallError from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.cloud.dataplex_v1.types import DataScanJob -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.google.cloud.hooks.dataplex import ( AirflowDataQualityScanException, AirflowDataQualityScanResultTimeoutException, @@ -116,7 +117,11 @@ def poke(self, context: Context) -> bool: task_status = task.state if task_status == TaskState.DELETING: - raise AirflowException(f"Task is going to be deleted {self.dataplex_task_id}") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Task is going to be deleted {self.dataplex_task_id}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) self.log.info("Current status of the Dataplex task %s => %s", self.dataplex_task_id, task_status) @@ -196,9 +201,13 @@ def poke(self, context: Context) -> bool: if self.result_timeout: duration = self._duration() if duration > self.result_timeout: - raise AirflowDataQualityScanResultTimeoutException( + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = ( f"Timeout: Data Quality scan {self.job_id} is not ready after {self.result_timeout}s" ) + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowDataQualityScanResultTimeoutException(message) hook = DataplexHook( gcp_conn_id=self.gcp_conn_id, @@ -217,22 +226,36 @@ def poke(self, context: Context) -> bool: metadata=self.metadata, ) except GoogleAPICallError as e: - raise AirflowException( - f"Error occurred when trying to retrieve Data Quality scan job: {self.data_scan_id}", e - ) + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Error occurred when trying to retrieve Data Quality scan job: {self.data_scan_id}" + if self.soft_fail: + raise AirflowSkipException(message, e) + raise AirflowException(message, e) job_status = job.state self.log.info( "Current status of the Dataplex Data Quality scan job %s => %s", self.job_id, job_status ) if job_status == DataScanJob.State.FAILED: - raise AirflowException(f"Data Quality scan job failed: {self.job_id}") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Data Quality scan job failed: {self.job_id}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) if job_status == DataScanJob.State.CANCELLED: - raise AirflowException(f"Data Quality scan job cancelled: {self.job_id}") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Data Quality scan job cancelled: {self.job_id}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) if self.fail_on_dq_failure: if job_status == DataScanJob.State.SUCCEEDED and not job.data_quality_result.passed: - raise AirflowDataQualityScanException( + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = ( f"Data Quality job {self.job_id} execution failed due to failure of its scanning " f"rules: {self.data_scan_id}" ) + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowDataQualityScanException(message) return job_status == DataScanJob.State.SUCCEEDED diff --git a/airflow/providers/google/cloud/sensors/dataproc.py b/airflow/providers/google/cloud/sensors/dataproc.py index b3f87a83b53e4..2acd695dba110 100644 --- a/airflow/providers/google/cloud/sensors/dataproc.py +++ b/airflow/providers/google/cloud/sensors/dataproc.py @@ -24,7 +24,7 @@ from google.api_core.exceptions import ServerError from google.cloud.dataproc_v1.types import Batch, JobStatus -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.google.cloud.hooks.dataproc import DataprocHook from airflow.sensors.base import BaseSensorOperator @@ -83,10 +83,14 @@ def poke(self, context: Context) -> bool: duration = self._duration() self.log.info("DURATION RUN: %f", duration) if duration > self.wait_timeout: - raise AirflowException( + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = ( f"Timeout: dataproc job {self.dataproc_job_id} " f"is not ready after {self.wait_timeout}s" ) + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) self.log.info("Retrying. Dataproc API returned server error when waiting for job: %s", err) return False else: @@ -94,13 +98,21 @@ def poke(self, context: Context) -> bool: state = job.status.state if state == JobStatus.State.ERROR: - raise AirflowException(f"Job failed:\n{job}") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Job failed:\n{job}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) elif state in { JobStatus.State.CANCELLED, JobStatus.State.CANCEL_PENDING, JobStatus.State.CANCEL_STARTED, }: - raise AirflowException(f"Job was cancelled:\n{job}") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Job was cancelled:\n{job}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) elif JobStatus.State.DONE == state: self.log.debug("Job %s completed successfully.", self.dataproc_job_id) return True @@ -171,12 +183,20 @@ def poke(self, context: Context) -> bool: state = batch.state if state == Batch.State.FAILED: - raise AirflowException("Batch failed") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = "Batch failed" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) elif state in { Batch.State.CANCELLED, Batch.State.CANCELLING, }: - raise AirflowException("Batch was cancelled.") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = "Batch was cancelled." + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) elif state == Batch.State.SUCCEEDED: self.log.debug("Batch %s completed successfully.", self.batch_id) return True diff --git a/airflow/providers/google/cloud/sensors/dataproc_metastore.py b/airflow/providers/google/cloud/sensors/dataproc_metastore.py index c50c8f1a8b9ab..ccb222645287d 100644 --- a/airflow/providers/google/cloud/sensors/dataproc_metastore.py +++ b/airflow/providers/google/cloud/sensors/dataproc_metastore.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Sequence -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.google.cloud.hooks.dataproc_metastore import DataprocMetastoreHook from airflow.providers.google.cloud.hooks.gcs import parse_json_from_gcs from airflow.sensors.base import BaseSensorOperator @@ -95,13 +95,21 @@ def poke(self, context: Context) -> bool: self.log.info("Extracting result manifest") manifest: dict = parse_json_from_gcs(gcp_conn_id=self.gcp_conn_id, file_uri=result_manifest_uri) if not (manifest and isinstance(manifest, dict)): - raise AirflowException( + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = ( f"Failed to extract result manifest. " f"Expected not empty dict, but this was received: {manifest}" ) + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) if manifest.get("status", {}).get("code") != 0: - raise AirflowException(f"Request failed: {manifest.get('message')}") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Request failed: {manifest.get('message')}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) # Extract actual query results result_base_uri = result_manifest_uri.rsplit("/", 1)[0] diff --git a/airflow/providers/google/cloud/sensors/gcs.py b/airflow/providers/google/cloud/sensors/gcs.py index 2fb220ab27726..453bb3bf44000 100644 --- a/airflow/providers/google/cloud/sensors/gcs.py +++ b/airflow/providers/google/cloud/sensors/gcs.py @@ -27,7 +27,7 @@ from google.cloud.storage.retry import DEFAULT_RETRY from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.providers.google.cloud.triggers.gcs import ( GCSBlobTrigger, @@ -125,6 +125,9 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> str: Relies on trigger to throw an exception, otherwise it assumes execution was successful. """ if event["status"] == "error": + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) self.log.info("File %s was found in bucket %s.", self.object, self.bucket) return event["message"] @@ -259,8 +262,16 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None "Checking last updated time for object %s in bucket : %s", self.object, self.bucket ) return event["message"] + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) - raise AirflowException("No event received in trigger callback") + + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = "No event received in trigger callback" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) class GCSObjectsWithPrefixExistenceSensor(BaseSensorOperator): @@ -347,6 +358,9 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str | list[ self.log.info("Resuming from trigger and checking status") if event["status"] == "success": return event["matches"] + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) @@ -476,10 +490,14 @@ def is_bucket_updated(self, current_objects: set[str]) -> bool: ) return False - raise AirflowException( + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = ( "Illegal behavior: objects were deleted in " f"{os.path.join(self.bucket, self.prefix)} between pokes." ) + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) if self.last_activity_time: self.inactivity_seconds = (get_time() - self.last_activity_time).total_seconds() @@ -549,5 +567,13 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None if event: if event["status"] == "success": return event["message"] + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) - raise AirflowException("No event received in trigger callback") + + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = "No event received in trigger callback" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) diff --git a/airflow/providers/google/cloud/sensors/looker.py b/airflow/providers/google/cloud/sensors/looker.py index e75d0fb665ac8..55257346275b7 100644 --- a/airflow/providers/google/cloud/sensors/looker.py +++ b/airflow/providers/google/cloud/sensors/looker.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.google.cloud.hooks.looker import JobStatus, LookerHook from airflow.sensors.base import BaseSensorOperator @@ -50,11 +50,14 @@ def __init__( self.hook: LookerHook | None = None def poke(self, context: Context) -> bool: - self.hook = LookerHook(looker_conn_id=self.looker_conn_id) if not self.materialization_id: - raise AirflowException("Invalid `materialization_id`.") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = "Invalid `materialization_id`." + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) # materialization_id is templated var pulling output from start task status_dict = self.hook.pdt_build_status(materialization_id=self.materialization_id) @@ -62,17 +65,23 @@ def poke(self, context: Context) -> bool: if status == JobStatus.ERROR.value: msg = status_dict["message"] - raise AirflowException( - f'PDT materialization job failed. Job id: {self.materialization_id}. Message:\n"{msg}"' - ) + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f'PDT materialization job failed. Job id: {self.materialization_id}. Message:\n"{msg}"' + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) elif status == JobStatus.CANCELLED.value: - raise AirflowException( - f"PDT materialization job was cancelled. Job id: {self.materialization_id}." - ) + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"PDT materialization job was cancelled. Job id: {self.materialization_id}." + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) elif status == JobStatus.UNKNOWN.value: - raise AirflowException( - f"PDT materialization job has unknown status. Job id: {self.materialization_id}." - ) + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"PDT materialization job has unknown status. Job id: {self.materialization_id}." + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) elif status == JobStatus.DONE.value: self.log.debug( "PDT materialization job completed successfully. Job id: %s.", self.materialization_id diff --git a/airflow/providers/google/cloud/sensors/pubsub.py b/airflow/providers/google/cloud/sensors/pubsub.py index 7bd07a08e5059..b4b92889342cf 100644 --- a/airflow/providers/google/cloud/sensors/pubsub.py +++ b/airflow/providers/google/cloud/sensors/pubsub.py @@ -24,7 +24,7 @@ from google.cloud.pubsub_v1.types import ReceivedMessage from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.google.cloud.hooks.pubsub import PubSubHook from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger from airflow.sensors.base import BaseSensorOperator @@ -174,6 +174,9 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str | list[ self.log.info("Sensor pulls messages: %s", event["message"]) return event["message"] self.log.info("Sensor failed: %s", event["message"]) + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(event["message"]) raise AirflowException(event["message"]) def _default_message_callback( diff --git a/airflow/providers/google/cloud/sensors/workflows.py b/airflow/providers/google/cloud/sensors/workflows.py index 712e328bdd670..7f97fafdbb9f7 100644 --- a/airflow/providers/google/cloud/sensors/workflows.py +++ b/airflow/providers/google/cloud/sensors/workflows.py @@ -21,7 +21,7 @@ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.cloud.workflows.executions_v1beta import Execution -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.google.cloud.hooks.workflows import WorkflowsHook from airflow.sensors.base import BaseSensorOperator @@ -100,10 +100,14 @@ def poke(self, context: Context): state = execution.state if state in self.failure_states: - raise AirflowException( + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = ( f"Execution {self.execution_id} for workflow {self.execution_id} " - f"failed and is in `{state}` state", + f"failed and is in `{state}` state" ) + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) if state in self.success_states: self.log.info( diff --git a/tests/providers/google/cloud/sensors/test_bigquery.py b/tests/providers/google/cloud/sensors/test_bigquery.py index 5fe40227c53ae..ec489329fb02d 100644 --- a/tests/providers/google/cloud/sensors/test_bigquery.py +++ b/tests/providers/google/cloud/sensors/test_bigquery.py @@ -20,7 +20,12 @@ import pytest -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred +from airflow.exceptions import ( + AirflowException, + AirflowProviderDeprecationWarning, + AirflowSkipException, + TaskDeferred, +) from airflow.providers.google.cloud.sensors.bigquery import ( BigQueryTableExistenceAsyncSensor, BigQueryTableExistencePartitionAsyncSensor, @@ -100,16 +105,20 @@ def test_execute_deferred(self, mock_hook): exc.value.trigger, BigQueryTableExistenceTrigger ), "Trigger is not a BigQueryTableExistenceTrigger" - def test_execute_deferred_failure(self): - """Tests that an AirflowException is raised in case of error event""" + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_execute_deferred_failure(self, soft_fail, expected_exception): + """Tests that an expected exception is raised in case of error event""" task = BigQueryTableExistenceSensor( task_id="task-id", project_id=TEST_PROJECT_ID, dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID, deferrable=True, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException): + with pytest.raises(expected_exception): task.execute_complete(context={}, event={"status": "error", "message": "test failure message"}) def test_execute_complete(self): @@ -126,15 +135,19 @@ def test_execute_complete(self): task.execute_complete(context={}, event={"status": "success", "message": "Job completed"}) mock_log_info.assert_called_with("Sensor checks existence of table: %s", table_uri) - def test_execute_defered_complete_event_none(self): + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_execute_defered_complete_event_none(self, soft_fail, expected_exception): """Asserts that logging occurs as expected""" task = BigQueryTableExistenceSensor( task_id="task-id", project_id=TEST_PROJECT_ID, dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException): + with pytest.raises(expected_exception): task.execute_complete(context={}, event=None) @@ -206,7 +219,10 @@ def test_execute_with_deferrable_mode(self, mock_hook): exc.value.trigger, BigQueryTablePartitionExistenceTrigger ), "Trigger is not a BigQueryTablePartitionExistenceTrigger" - def test_execute_with_deferrable_mode_execute_failure(self): + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_execute_with_deferrable_mode_execute_failure(self, soft_fail, expected_exception): """Tests that an AirflowException is raised in case of error event""" task = BigQueryTablePartitionExistenceSensor( task_id="test_task_id", @@ -215,11 +231,15 @@ def test_execute_with_deferrable_mode_execute_failure(self): table_id=TEST_TABLE_ID, partition_id=TEST_PARTITION_ID, deferrable=True, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException): + with pytest.raises(expected_exception): task.execute_complete(context={}, event={"status": "error", "message": "test failure message"}) - def test_execute_complete_event_none(self): + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_execute_complete_event_none(self, soft_fail, expected_exception): """Asserts that logging occurs as expected""" task = BigQueryTablePartitionExistenceSensor( task_id="task-id", @@ -228,8 +248,9 @@ def test_execute_complete_event_none(self): table_id=TEST_TABLE_ID, partition_id=TEST_PARTITION_ID, deferrable=True, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException, match="No event received in trigger callback"): + with pytest.raises(expected_exception, match="No event received in trigger callback"): task.execute_complete(context={}, event=None) def test_execute_complete(self): @@ -287,16 +308,20 @@ def test_big_query_table_existence_sensor_async(self, mock_hook): exc.value.trigger, BigQueryTableExistenceTrigger ), "Trigger is not a BigQueryTableExistenceTrigger" - def test_big_query_table_existence_sensor_async_execute_failure(self): - """Tests that an AirflowException is raised in case of error event""" + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_big_query_table_existence_sensor_async_execute_failure(self, soft_fail, expected_exception): + """Tests that an expected_exception is raised in case of error event""" with pytest.warns(AirflowProviderDeprecationWarning, match=self.depcrecation_message): task = BigQueryTableExistenceAsyncSensor( task_id="task-id", project_id=TEST_PROJECT_ID, dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException): + with pytest.raises(expected_exception): task.execute_complete(context={}, event={"status": "error", "message": "test failure message"}) def test_big_query_table_existence_sensor_async_execute_complete(self): @@ -313,7 +338,10 @@ def test_big_query_table_existence_sensor_async_execute_complete(self): task.execute_complete(context={}, event={"status": "success", "message": "Job completed"}) mock_log_info.assert_called_with("Sensor checks existence of table: %s", table_uri) - def test_big_query_sensor_async_execute_complete_event_none(self): + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_big_query_sensor_async_execute_complete_event_none(self, soft_fail, expected_exception): """Asserts that logging occurs as expected""" with pytest.warns(AirflowProviderDeprecationWarning, match=self.depcrecation_message): task = BigQueryTableExistenceAsyncSensor( @@ -321,8 +349,9 @@ def test_big_query_sensor_async_execute_complete_event_none(self): project_id=TEST_PROJECT_ID, dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException): + with pytest.raises(expected_exception): task.execute_complete(context={}, event=None) @@ -355,8 +384,13 @@ def test_big_query_table_existence_partition_sensor_async(self, mock_hook): exc.value.trigger, BigQueryTablePartitionExistenceTrigger ), "Trigger is not a BigQueryTablePartitionExistenceTrigger" - def test_big_query_table_existence_partition_sensor_async_execute_failure(self): - """Tests that an AirflowException is raised in case of error event""" + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_big_query_table_existence_partition_sensor_async_execute_failure( + self, soft_fail, expected_exception + ): + """Tests that an expected exception is raised in case of error event""" with pytest.warns(AirflowProviderDeprecationWarning, match=self.depcrecation_message): task = BigQueryTableExistencePartitionAsyncSensor( task_id="test_task_id", @@ -364,11 +398,17 @@ def test_big_query_table_existence_partition_sensor_async_execute_failure(self): dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID, partition_id=TEST_PARTITION_ID, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException): + with pytest.raises(expected_exception): task.execute_complete(context={}, event={"status": "error", "message": "test failure message"}) - def test_big_query_table_existence_partition_sensor_async_execute_complete_event_none(self): + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_big_query_table_existence_partition_sensor_async_execute_complete_event_none( + self, soft_fail, expected_exception + ): """Asserts that logging occurs as expected""" with pytest.warns(AirflowProviderDeprecationWarning, match=self.depcrecation_message): task = BigQueryTableExistencePartitionAsyncSensor( @@ -377,8 +417,9 @@ def test_big_query_table_existence_partition_sensor_async_execute_complete_event dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID, partition_id=TEST_PARTITION_ID, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException, match="No event received in trigger callback"): + with pytest.raises(expected_exception, match="No event received in trigger callback"): task.execute_complete(context={}, event=None) def test_big_query_table_existence_partition_sensor_async_execute_complete(self): diff --git a/tests/providers/google/cloud/sensors/test_bigtable.py b/tests/providers/google/cloud/sensors/test_bigtable.py index 37bd5eaf8fa06..dea84fb9f9515 100644 --- a/tests/providers/google/cloud/sensors/test_bigtable.py +++ b/tests/providers/google/cloud/sensors/test_bigtable.py @@ -24,7 +24,7 @@ from google.cloud.bigtable.instance import Instance from google.cloud.bigtable.table import ClusterState -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.google.cloud.sensors.bigtable import BigtableTableReplicationCompletedSensor PROJECT_ID = "test_project_id" @@ -35,6 +35,9 @@ class BigtableWaitForTableReplicationTest: + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @pytest.mark.parametrize( "missing_attribute, project_id, instance_id, table_id", [ @@ -43,8 +46,10 @@ class BigtableWaitForTableReplicationTest: ], ) @mock.patch("airflow.providers.google.cloud.sensors.bigtable.BigtableHook") - def test_empty_attribute(self, missing_attribute, project_id, instance_id, table_id, mock_hook): - with pytest.raises(AirflowException) as ctx: + def test_empty_attribute( + self, missing_attribute, project_id, instance_id, table_id, mock_hook, soft_fail, expected_exception + ): + with pytest.raises(expected_exception) as ctx: BigtableTableReplicationCompletedSensor( project_id=project_id, instance_id=instance_id, @@ -52,6 +57,7 @@ def test_empty_attribute(self, missing_attribute, project_id, instance_id, table task_id="id", gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + soft_fail=soft_fail, ) err = ctx.value assert str(err) == f"Empty parameter: {missing_attribute}" diff --git a/tests/providers/google/cloud/sensors/test_cloud_composer.py b/tests/providers/google/cloud/sensors/test_cloud_composer.py index 8062da44b9b9a..f6f3e81a4019b 100644 --- a/tests/providers/google/cloud/sensors/test_cloud_composer.py +++ b/tests/providers/google/cloud/sensors/test_cloud_composer.py @@ -21,7 +21,7 @@ import pytest -from airflow.exceptions import AirflowException, TaskDeferred +from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred from airflow.providers.google.cloud.sensors.cloud_composer import CloudComposerEnvironmentSensor from airflow.providers.google.cloud.triggers.cloud_composer import CloudComposerExecutionTrigger @@ -48,15 +48,19 @@ def test_cloud_composer_existence_sensor_async(self): exc.value.trigger, CloudComposerExecutionTrigger ), "Trigger is not a CloudComposerExecutionTrigger" - def test_cloud_composer_existence_sensor_async_execute_failure(self): - """Tests that an AirflowException is raised in case of error event.""" + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_cloud_composer_existence_sensor_async_execute_failure(self, soft_fail, expected_exception): + """Tests that an expected exception is raised in case of error event.""" task = CloudComposerEnvironmentSensor( task_id="task_id", project_id=TEST_PROJECT_ID, region=TEST_REGION, operation_name=TEST_OPERATION_NAME, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException, match="No event received in trigger callback"): + with pytest.raises(expected_exception, match="No event received in trigger callback"): task.execute_complete(context={}, event=None) def test_cloud_composer_existence_sensor_async_execute_complete(self): diff --git a/tests/providers/google/cloud/sensors/test_dataflow.py b/tests/providers/google/cloud/sensors/test_dataflow.py index 36d8840c815f2..d669b2b11190d 100644 --- a/tests/providers/google/cloud/sensors/test_dataflow.py +++ b/tests/providers/google/cloud/sensors/test_dataflow.py @@ -21,7 +21,7 @@ import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus from airflow.providers.google.cloud.sensors.dataflow import ( DataflowJobAutoScalingEventsSensor, @@ -71,8 +71,11 @@ def test_poke(self, mock_hook, expected_status, current_status, sensor_return): job_id=TEST_JOB_ID, project_id=TEST_PROJECT_ID, location=TEST_LOCATION ) + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch("airflow.providers.google.cloud.sensors.dataflow.DataflowHook") - def test_poke_raise_exception(self, mock_hook): + def test_poke_raise_exception(self, mock_hook, soft_fail, expected_exception): mock_get_job = mock_hook.return_value.get_job task = DataflowJobStatusSensor( task_id=TEST_TASK_ID, @@ -82,11 +85,12 @@ def test_poke_raise_exception(self, mock_hook): project_id=TEST_PROJECT_ID, gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, + soft_fail=soft_fail, ) mock_get_job.return_value = {"id": TEST_JOB_ID, "currentState": DataflowJobStatus.JOB_STATE_CANCELLED} with pytest.raises( - AirflowException, + expected_exception, match=f"Job with id '{TEST_JOB_ID}' is already in terminal state: " f"{DataflowJobStatus.JOB_STATE_CANCELLED}", ): @@ -182,8 +186,11 @@ def test_poke(self, mock_hook, job_current_state, fail_on_terminal_state): ) callback.assert_called_once_with(mock_fetch_job_messages_by_id.return_value) + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch("airflow.providers.google.cloud.sensors.dataflow.DataflowHook") - def test_poke_raise_exception(self, mock_hook): + def test_poke_raise_exception(self, mock_hook, soft_fail, expected_exception): mock_get_job = mock_hook.return_value.get_job mock_fetch_job_messages_by_id = mock_hook.return_value.fetch_job_messages_by_id callback = mock.MagicMock() @@ -197,11 +204,12 @@ def test_poke_raise_exception(self, mock_hook): project_id=TEST_PROJECT_ID, gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, + soft_fail=soft_fail, ) mock_get_job.return_value = {"id": TEST_JOB_ID, "currentState": DataflowJobStatus.JOB_STATE_DONE} with pytest.raises( - AirflowException, + expected_exception, match=f"Job with id '{TEST_JOB_ID}' is already in terminal state: " f"{DataflowJobStatus.JOB_STATE_DONE}", ): @@ -255,8 +263,11 @@ def test_poke(self, mock_hook, job_current_state, fail_on_terminal_state): ) callback.assert_called_once_with(mock_fetch_job_autoscaling_events_by_id.return_value) + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch("airflow.providers.google.cloud.sensors.dataflow.DataflowHook") - def test_poke_raise_exception_on_terminal_state(self, mock_hook): + def test_poke_raise_exception_on_terminal_state(self, mock_hook, soft_fail, expected_exception): mock_get_job = mock_hook.return_value.get_job mock_fetch_job_autoscaling_events_by_id = mock_hook.return_value.fetch_job_autoscaling_events_by_id callback = mock.MagicMock() @@ -270,11 +281,12 @@ def test_poke_raise_exception_on_terminal_state(self, mock_hook): project_id=TEST_PROJECT_ID, gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, + soft_fail=soft_fail, ) mock_get_job.return_value = {"id": TEST_JOB_ID, "currentState": DataflowJobStatus.JOB_STATE_DONE} with pytest.raises( - AirflowException, + expected_exception, match=f"Job with id '{TEST_JOB_ID}' is already in terminal state: " f"{DataflowJobStatus.JOB_STATE_DONE}", ): diff --git a/tests/providers/google/cloud/sensors/test_datafusion.py b/tests/providers/google/cloud/sensors/test_datafusion.py index 32dcfbb0508b8..6de24b794376b 100644 --- a/tests/providers/google/cloud/sensors/test_datafusion.py +++ b/tests/providers/google/cloud/sensors/test_datafusion.py @@ -21,7 +21,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowNotFoundException +from airflow.exceptions import AirflowException, AirflowNotFoundException, AirflowSkipException from airflow.providers.google.cloud.hooks.datafusion import PipelineStates from airflow.providers.google.cloud.sensors.datafusion import CloudDataFusionPipelineStateSensor @@ -74,8 +74,11 @@ def test_poke(self, mock_hook, expected_status, current_status, sensor_return): instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID ) + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch("airflow.providers.google.cloud.sensors.datafusion.DataFusionHook") - def test_assertion(self, mock_hook): + def test_assertion(self, mock_hook, soft_fail, expected_exception): mock_hook.return_value.get_instance.return_value = {"apiEndpoint": INSTANCE_URL} task = CloudDataFusionPipelineStateSensor( @@ -89,17 +92,21 @@ def test_assertion(self, mock_hook): location=LOCATION, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + soft_fail=soft_fail, ) with pytest.raises( - AirflowException, + expected_exception, match=f"Pipeline with id '{PIPELINE_ID}' state is: FAILED. Terminating sensor...", ): mock_hook.return_value.get_pipeline_workflow.return_value = {"status": "FAILED"} task.poke(mock.MagicMock()) + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch("airflow.providers.google.cloud.sensors.datafusion.DataFusionHook") - def test_not_found_exception(self, mock_hook): + def test_not_found_exception(self, mock_hook, soft_fail, expected_exception): mock_hook.return_value.get_instance.return_value = {"apiEndpoint": INSTANCE_URL} mock_hook.return_value.get_pipeline_workflow.side_effect = AirflowNotFoundException() @@ -114,10 +121,11 @@ def test_not_found_exception(self, mock_hook): location=LOCATION, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + soft_fail=soft_fail, ) with pytest.raises( - AirflowException, + expected_exception, match="Specified Pipeline ID was not found.", ): task.poke(mock.MagicMock()) diff --git a/tests/providers/google/cloud/sensors/test_dataplex.py b/tests/providers/google/cloud/sensors/test_dataplex.py index 18f5b68b9da6e..20a4de4ff0279 100644 --- a/tests/providers/google/cloud/sensors/test_dataplex.py +++ b/tests/providers/google/cloud/sensors/test_dataplex.py @@ -22,7 +22,7 @@ from google.api_core.gapic_v1.method import DEFAULT from google.cloud.dataplex_v1.types import DataScanJob -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.google.cloud.hooks.dataplex import AirflowDataQualityScanResultTimeoutException from airflow.providers.google.cloud.sensors.dataplex import ( DataplexDataQualityJobStatusSensor, @@ -81,8 +81,11 @@ def test_done(self, mock_hook): assert result + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch(DATAPLEX_HOOK) - def test_deleting(self, mock_hook): + def test_deleting(self, mock_hook, soft_fail, expected_exception): task = self.create_task(TaskState.DELETING) mock_hook.return_value.get_task.return_value = task @@ -95,9 +98,10 @@ def test_deleting(self, mock_hook): api_version=API_VERSION, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException, match="Task is going to be deleted"): + with pytest.raises(expected_exception, match="Task is going to be deleted"): sensor.poke(context={}) mock_hook.return_value.get_task.assert_called_once_with( diff --git a/tests/providers/google/cloud/sensors/test_dataproc.py b/tests/providers/google/cloud/sensors/test_dataproc.py index 9080705ebd5f0..f123976be9cef 100644 --- a/tests/providers/google/cloud/sensors/test_dataproc.py +++ b/tests/providers/google/cloud/sensors/test_dataproc.py @@ -23,7 +23,7 @@ from google.api_core.exceptions import ServerError from google.cloud.dataproc_v1.types import Batch, JobStatus -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.google.cloud.sensors.dataproc import DataprocBatchSensor, DataprocJobSensor from airflow.version import version as airflow_version @@ -66,8 +66,11 @@ def test_done(self, mock_hook): ) assert ret + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch(DATAPROC_PATH.format("DataprocHook")) - def test_error(self, mock_hook): + def test_error(self, mock_hook, soft_fail, expected_exception): job = self.create_job(JobStatus.State.ERROR) job_id = "job_id" mock_hook.return_value.get_job.return_value = job @@ -79,9 +82,10 @@ def test_error(self, mock_hook): dataproc_job_id=job_id, gcp_conn_id=GCP_CONN_ID, timeout=TIMEOUT, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException, match="Job failed"): + with pytest.raises(expected_exception, match="Job failed"): sensor.poke(context={}) mock_hook.return_value.get_job.assert_called_once_with( @@ -109,8 +113,11 @@ def test_wait(self, mock_hook): ) assert not ret + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch(DATAPROC_PATH.format("DataprocHook")) - def test_cancelled(self, mock_hook): + def test_cancelled(self, mock_hook, soft_fail, expected_exception): job = self.create_job(JobStatus.State.CANCELLED) job_id = "job_id" mock_hook.return_value.get_job.return_value = job @@ -122,8 +129,9 @@ def test_cancelled(self, mock_hook): dataproc_job_id=job_id, gcp_conn_id=GCP_CONN_ID, timeout=TIMEOUT, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException, match="Job was cancelled"): + with pytest.raises(expected_exception, match="Job was cancelled"): sensor.poke(context={}) mock_hook.return_value.get_job.assert_called_once_with( @@ -163,8 +171,11 @@ def test_wait_timeout(self, mock_hook): result = sensor.poke(context={}) assert not result + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch(DATAPROC_PATH.format("DataprocHook")) - def test_wait_timeout_raise_exception(self, mock_hook): + def test_wait_timeout_raise_exception(self, mock_hook, soft_fail, expected_exception): job_id = "job_id" mock_hook.return_value.get_job.side_effect = ServerError("Job are not ready") @@ -176,12 +187,13 @@ def test_wait_timeout_raise_exception(self, mock_hook): gcp_conn_id=GCP_CONN_ID, timeout=TIMEOUT, wait_timeout=300, + soft_fail=soft_fail, ) sensor._duration = Mock() sensor._duration.return_value = 301 - with pytest.raises(AirflowException, match="Timeout: dataproc job job_id is not ready after 300s"): + with pytest.raises(expected_exception, match="Timeout: dataproc job job_id is not ready after 300s"): sensor.poke(context={}) @@ -212,8 +224,11 @@ def test_succeeded(self, mock_hook): ) assert ret + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch(DATAPROC_PATH.format("DataprocHook")) - def test_cancelled(self, mock_hook): + def test_cancelled(self, mock_hook, soft_fail, expected_exception): batch = self.create_batch(Batch.State.CANCELLED) mock_hook.return_value.get_batch.return_value = batch @@ -224,16 +239,20 @@ def test_cancelled(self, mock_hook): batch_id="batch_id", gcp_conn_id=GCP_CONN_ID, timeout=TIMEOUT, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException, match="Batch was cancelled."): + with pytest.raises(expected_exception, match="Batch was cancelled."): sensor.poke(context={}) mock_hook.return_value.get_batch.assert_called_once_with( batch_id="batch_id", region=GCP_LOCATION, project_id=GCP_PROJECT ) + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch(DATAPROC_PATH.format("DataprocHook")) - def test_error(self, mock_hook): + def test_error(self, mock_hook, soft_fail, expected_exception): batch = self.create_batch(Batch.State.FAILED) mock_hook.return_value.get_batch.return_value = batch @@ -244,9 +263,10 @@ def test_error(self, mock_hook): batch_id="batch_id", gcp_conn_id=GCP_CONN_ID, timeout=TIMEOUT, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException, match="Batch failed"): + with pytest.raises(expected_exception, match="Batch failed"): sensor.poke(context={}) mock_hook.return_value.get_batch.assert_called_once_with( diff --git a/tests/providers/google/cloud/sensors/test_dataproc_metastore.py b/tests/providers/google/cloud/sensors/test_dataproc_metastore.py index 117210e2bae3c..435ceac661881 100644 --- a/tests/providers/google/cloud/sensors/test_dataproc_metastore.py +++ b/tests/providers/google/cloud/sensors/test_dataproc_metastore.py @@ -21,7 +21,7 @@ import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.google.cloud.sensors.dataproc_metastore import MetastoreHivePartitionSensor DATAPROC_METASTORE_SENSOR_PATH = "airflow.providers.google.cloud.sensors.dataproc_metastore.{}" @@ -106,14 +106,14 @@ def test_poke_positive_manifest( ) assert sensor.poke(context={}) == expected_result + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @pytest.mark.parametrize("empty_manifest", [dict(), list(), tuple(), None, ""]) @mock.patch(DATAPROC_METASTORE_SENSOR_PATH.format("DataprocMetastoreHook")) @mock.patch(DATAPROC_METASTORE_SENSOR_PATH.format("parse_json_from_gcs")) def test_poke_empty_manifest( - self, - mock_parse_json_from_gcs, - mock_hook, - empty_manifest, + self, mock_parse_json_from_gcs, mock_hook, empty_manifest, soft_fail, expected_exception ): mock_parse_json_from_gcs.return_value = empty_manifest @@ -124,18 +124,18 @@ def test_poke_empty_manifest( table=TEST_TABLE, partitions=[PARTITION_1], gcp_conn_id=GCP_CONN_ID, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException): + with pytest.raises(expected_exception): sensor.poke(context={}) + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch(DATAPROC_METASTORE_SENSOR_PATH.format("DataprocMetastoreHook")) @mock.patch(DATAPROC_METASTORE_SENSOR_PATH.format("parse_json_from_gcs")) - def test_poke_wrong_status( - self, - mock_parse_json_from_gcs, - mock_hook, - ): + def test_poke_wrong_status(self, mock_parse_json_from_gcs, mock_hook, soft_fail, expected_exception): error_message = "Test error message" mock_parse_json_from_gcs.return_value = {"code": 1, "message": error_message} @@ -146,7 +146,8 @@ def test_poke_wrong_status( table=TEST_TABLE, partitions=[PARTITION_1], gcp_conn_id=GCP_CONN_ID, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException, match=f"Request failed: {error_message}"): + with pytest.raises(expected_exception, match=f"Request failed: {error_message}"): sensor.poke(context={}) diff --git a/tests/providers/google/cloud/sensors/test_gcs.py b/tests/providers/google/cloud/sensors/test_gcs.py index 4bb3646152ecb..422cd8f71a39e 100644 --- a/tests/providers/google/cloud/sensors/test_gcs.py +++ b/tests/providers/google/cloud/sensors/test_gcs.py @@ -24,7 +24,12 @@ import pytest from google.cloud.storage.retry import DEFAULT_RETRY -from airflow.exceptions import AirflowProviderDeprecationWarning, AirflowSensorTimeout, TaskDeferred +from airflow.exceptions import ( + AirflowProviderDeprecationWarning, + AirflowSensorTimeout, + AirflowSkipException, + TaskDeferred, +) from airflow.models.dag import DAG, AirflowException from airflow.providers.google.cloud.sensors.gcs import ( GCSObjectExistenceAsyncSensor, @@ -135,7 +140,10 @@ def test_gcs_object_existence_sensor_deferred(self, mock_hook): task.execute(context) assert isinstance(exc.value.trigger, GCSBlobTrigger), "Trigger is not a GCSBlobTrigger" - def test_gcs_object_existence_sensor_deferred_execute_failure(self): + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_gcs_object_existence_sensor_deferred_execute_failure(self, soft_fail, expected_exception): """Tests that an AirflowException is raised in case of error event when deferrable is set to True""" task = GCSObjectExistenceSensor( task_id="task-id", @@ -143,8 +151,9 @@ def test_gcs_object_existence_sensor_deferred_execute_failure(self): object=TEST_OBJECT, google_cloud_conn_id=TEST_GCP_CONN_ID, deferrable=True, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException): + with pytest.raises(expected_exception): task.execute_complete(context=None, event={"status": "error", "message": "test failure message"}) def test_gcs_object_existence_sensor_execute_complete(self): @@ -185,7 +194,10 @@ def test_gcs_object_existence_async_sensor(self, mock_hook): task.execute(context) assert isinstance(exc.value.trigger, GCSBlobTrigger), "Trigger is not a GCSBlobTrigger" - def test_gcs_object_existence_async_sensor_execute_failure(self): + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_gcs_object_existence_async_sensor_execute_failure(self, soft_fail, expected_exception): """Tests that an AirflowException is raised in case of error event""" with pytest.warns(AirflowProviderDeprecationWarning, match=self.depcrecation_message): task = GCSObjectExistenceAsyncSensor( @@ -193,8 +205,9 @@ def test_gcs_object_existence_async_sensor_execute_failure(self): bucket=TEST_BUCKET, object=TEST_OBJECT, google_cloud_conn_id=TEST_GCP_CONN_ID, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException): + with pytest.raises(expected_exception): task.execute_complete(context=None, event={"status": "error", "message": "test failure message"}) def test_gcs_object_existence_async_sensor_execute_complete(self): @@ -289,10 +302,13 @@ def test_gcs_object_update_async_sensor(self, mock_hook): exc.value.trigger, GCSCheckBlobUpdateTimeTrigger ), "Trigger is not a GCSCheckBlobUpdateTimeTrigger" - def test_gcs_object_update_async_sensor_execute_failure(self, context): + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_gcs_object_update_async_sensor_execute_failure(self, context, soft_fail, expected_exception): """Tests that an AirflowException is raised in case of error event""" - - with pytest.raises(AirflowException): + self.OPERATOR.soft_fail = soft_fail + with pytest.raises(expected_exception): self.OPERATOR.execute_complete( context=context, event={"status": "error", "message": "test failure message"} ) @@ -364,13 +380,21 @@ def test_execute(self, mock_hook): mock_hook.return_value.list.assert_called_once_with(TEST_BUCKET, prefix=TEST_PREFIX) assert response == generated_messages + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowSensorTimeout), (True, AirflowSkipException)) + ) @mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook") - def test_execute_timeout(self, mock_hook): + def test_execute_timeout(self, mock_hook, soft_fail, expected_exception): task = GCSObjectsWithPrefixExistenceSensor( - task_id="task-id", bucket=TEST_BUCKET, prefix=TEST_PREFIX, poke_interval=0, timeout=1 + task_id="task-id", + bucket=TEST_BUCKET, + prefix=TEST_PREFIX, + poke_interval=0, + timeout=1, + soft_fail=soft_fail, ) mock_hook.return_value.list.return_value = [] - with pytest.raises(AirflowSensorTimeout): + with pytest.raises(expected_exception): task.execute(mock.MagicMock) mock_hook.return_value.list.assert_called_once_with(TEST_BUCKET, prefix=TEST_PREFIX) @@ -410,10 +434,15 @@ def test_gcs_object_with_prefix_existence_async_sensor(self, mock_hook): self.OPERATOR.execute(mock.MagicMock()) assert isinstance(exc.value.trigger, GCSPrefixBlobTrigger), "Trigger is not a GCSPrefixBlobTrigger" - def test_gcs_object_with_prefix_existence_async_sensor_execute_failure(self, context): + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_gcs_object_with_prefix_existence_async_sensor_execute_failure( + self, context, soft_fail, expected_exception + ): """Tests that an AirflowException is raised in case of error event""" - - with pytest.raises(AirflowException): + self.OPERATOR.soft_fail = soft_fail + with pytest.raises(expected_exception): self.OPERATOR.execute_complete( context=context, event={"status": "error", "message": "test failure message"} ) @@ -461,10 +490,14 @@ def test_get_gcs_hook(self, mock_hook): ) assert mock_hook.return_value == self.sensor.hook + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch("airflow.providers.google.cloud.sensors.gcs.get_time", mock_time) - def test_files_deleted_between_pokes_throw_error(self): + def test_files_deleted_between_pokes_throw_error(self, soft_fail, expected_exception): + self.sensor.soft_fail = soft_fail self.sensor.is_bucket_updated({"a", "b"}) - with pytest.raises(AirflowException): + with pytest.raises(expected_exception): self.sensor.is_bucket_updated({"a"}) @mock.patch("airflow.providers.google.cloud.sensors.gcs.get_time", mock_time) @@ -549,10 +582,14 @@ def test_gcs_upload_session_complete_async_sensor(self, mock_hook): exc.value.trigger, GCSUploadSessionTrigger ), "Trigger is not a GCSUploadSessionTrigger" - def test_gcs_upload_session_complete_sensor_execute_failure(self, context): + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_gcs_upload_session_complete_sensor_execute_failure(self, context, soft_fail, expected_exception): """Tests that an AirflowException is raised in case of error event""" - with pytest.raises(AirflowException): + self.OPERATOR.soft_fail = soft_fail + with pytest.raises(expected_exception): self.OPERATOR.execute_complete( context=context, event={"status": "error", "message": "test failure message"} ) diff --git a/tests/providers/google/cloud/sensors/test_looker.py b/tests/providers/google/cloud/sensors/test_looker.py index 567340f0864d6..8e352340552a1 100644 --- a/tests/providers/google/cloud/sensors/test_looker.py +++ b/tests/providers/google/cloud/sensors/test_looker.py @@ -20,7 +20,7 @@ import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.google.cloud.hooks.looker import JobStatus from airflow.providers.google.cloud.sensors.looker import LookerCheckPdtBuildSensor @@ -51,8 +51,11 @@ def test_done(self, mock_hook): # assert we got a response assert ret + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch(SENSOR_PATH.format("LookerHook")) - def test_error(self, mock_hook): + def test_error(self, mock_hook, soft_fail, expected_exception): mock_hook.return_value.pdt_build_status.return_value = { "status": JobStatus.ERROR.value, "message": "test", @@ -63,9 +66,10 @@ def test_error(self, mock_hook): task_id=TASK_ID, looker_conn_id=LOOKER_CONN_ID, materialization_id=TEST_JOB_ID, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException, match="PDT materialization job failed"): + with pytest.raises(expected_exception, match="PDT materialization job failed"): sensor.poke(context={}) # assert hook.pdt_build_status called once @@ -89,8 +93,11 @@ def test_wait(self, mock_hook): # assert we got NO response assert not ret + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch(SENSOR_PATH.format("LookerHook")) - def test_cancelled(self, mock_hook): + def test_cancelled(self, mock_hook, soft_fail, expected_exception): mock_hook.return_value.pdt_build_status.return_value = {"status": JobStatus.CANCELLED.value} # run task in mock context @@ -98,22 +105,23 @@ def test_cancelled(self, mock_hook): task_id=TASK_ID, looker_conn_id=LOOKER_CONN_ID, materialization_id=TEST_JOB_ID, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException, match="PDT materialization job was cancelled"): + with pytest.raises(expected_exception, match="PDT materialization job was cancelled"): sensor.poke(context={}) # assert hook.pdt_build_status called once mock_hook.return_value.pdt_build_status.assert_called_once_with(materialization_id=TEST_JOB_ID) - def test_empty_materialization_id(self): - + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_empty_materialization_id(self, soft_fail, expected_exception): # run task in mock context sensor = LookerCheckPdtBuildSensor( - task_id=TASK_ID, - looker_conn_id=LOOKER_CONN_ID, - materialization_id="", + task_id=TASK_ID, looker_conn_id=LOOKER_CONN_ID, materialization_id="", soft_fail=soft_fail ) - with pytest.raises(AirflowException, match="^Invalid `materialization_id`.$"): + with pytest.raises(expected_exception, match="^Invalid `materialization_id`.$"): sensor.poke(context={}) diff --git a/tests/providers/google/cloud/sensors/test_pubsub.py b/tests/providers/google/cloud/sensors/test_pubsub.py index 88fa3e296f8d5..dcecce218e788 100644 --- a/tests/providers/google/cloud/sensors/test_pubsub.py +++ b/tests/providers/google/cloud/sensors/test_pubsub.py @@ -23,7 +23,7 @@ import pytest from google.cloud.pubsub_v1.types import ReceivedMessage -from airflow.exceptions import AirflowException, AirflowSensorTimeout, TaskDeferred +from airflow.exceptions import AirflowException, AirflowSensorTimeout, AirflowSkipException, TaskDeferred from airflow.providers.google.cloud.sensors.pubsub import PubSubPullSensor from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger @@ -98,19 +98,23 @@ def test_execute(self, mock_hook): ) assert generated_dicts == response + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowSensorTimeout), (True, AirflowSkipException)) + ) @mock.patch("airflow.providers.google.cloud.sensors.pubsub.PubSubHook") - def test_execute_timeout(self, mock_hook): + def test_execute_timeout(self, mock_hook, soft_fail, expected_exception): operator = PubSubPullSensor( task_id=TASK_ID, project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, poke_interval=0, timeout=1, + soft_fail=soft_fail, ) mock_hook.return_value.pull.return_value = [] - with pytest.raises(AirflowSensorTimeout): + with pytest.raises(expected_exception): operator.execute({}) mock_hook.return_value.pull.assert_called_once_with( project_id=TEST_PROJECT, @@ -173,7 +177,10 @@ def test_pubsub_pull_sensor_async(self): task.execute(context={}) assert isinstance(exc.value.trigger, PubsubPullTrigger), "Trigger is not a PubsubPullTrigger" - def test_pubsub_pull_sensor_async_execute_should_throw_exception(self): + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_pubsub_pull_sensor_async_execute_should_throw_exception(self, soft_fail, expected_exception): """Tests that an AirflowException is raised in case of error event""" operator = PubSubPullSensor( @@ -182,9 +189,10 @@ def test_pubsub_pull_sensor_async_execute_should_throw_exception(self): project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, deferrable=True, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException): + with pytest.raises(expected_exception): operator.execute_complete( context=mock.MagicMock(), event={"status": "error", "message": "test failure message"} ) diff --git a/tests/providers/google/cloud/sensors/test_workflows.py b/tests/providers/google/cloud/sensors/test_workflows.py index 12d66ac62ddef..232c1db0e0fd7 100644 --- a/tests/providers/google/cloud/sensors/test_workflows.py +++ b/tests/providers/google/cloud/sensors/test_workflows.py @@ -21,7 +21,7 @@ import pytest from google.cloud.workflows.executions_v1beta import Execution -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.google.cloud.sensors.workflows import WorkflowExecutionSensor BASE_PATH = "airflow.providers.google.cloud.sensors.workflows.{}" @@ -90,8 +90,11 @@ def test_poke_wait(self, mock_hook): assert result is False + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) @mock.patch(BASE_PATH.format("WorkflowsHook")) - def test_poke_failure(self, mock_hook): + def test_poke_failure(self, mock_hook, soft_fail, expected_exception): mock_hook.return_value.get_execution.return_value = mock.MagicMock(state=Execution.State.FAILED) op = WorkflowExecutionSensor( task_id="test_task", @@ -104,6 +107,7 @@ def test_poke_failure(self, mock_hook): metadata=METADATA, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + soft_fail=soft_fail, ) - with pytest.raises(AirflowException): + with pytest.raises(expected_exception): op.poke({})