diff --git a/airflow/providers/amazon/CHANGELOG.rst b/airflow/providers/amazon/CHANGELOG.rst index 32ac7cce53a7c..cdb23e248519f 100644 --- a/airflow/providers/amazon/CHANGELOG.rst +++ b/airflow/providers/amazon/CHANGELOG.rst @@ -34,7 +34,7 @@ Changelog `Apache Airflow providers support policy `_. .. warning:: When deferrable mode was introduced for ``RedshiftDataOperator``, in version 8.17.0, tasks configured with - ``deferrable=True`` and ``wait_for_completion=True`` wouldn't enter the deferred state. Instead, the task would occupy + ``deferrable=True`` and ``wait_for_completion=True`` would not enter the deferred state. Instead, the task would occupy an executor slot until the statement was completed. A workaround may have been to set ``wait_for_completion=False``. In this version, tasks set up with ``wait_for_completion=False`` will not wait anymore, regardless of the value of ``deferrable``. diff --git a/airflow/providers/amazon/aws/operators/glue_databrew.py b/airflow/providers/amazon/aws/operators/glue_databrew.py index 8ea24d49ffcbb..e4d774ae40f32 100644 --- a/airflow/providers/amazon/aws/operators/glue_databrew.py +++ b/airflow/providers/amazon/aws/operators/glue_databrew.py @@ -17,20 +17,22 @@ # under the License. from __future__ import annotations -from functools import cached_property +import warnings from typing import TYPE_CHECKING, Any, Sequence from airflow.configuration import conf -from airflow.models import BaseOperator +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.glue_databrew import GlueDataBrewHook +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator from airflow.providers.amazon.aws.triggers.glue_databrew import GlueDataBrewJobCompleteTrigger from airflow.providers.amazon.aws.utils import validate_execute_complete_event +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields if TYPE_CHECKING: from airflow.utils.context import Context -class GlueDataBrewStartJobOperator(BaseOperator): +class GlueDataBrewStartJobOperator(AwsBaseOperator[GlueDataBrewHook]): """ Start an AWS Glue DataBrew job. @@ -47,36 +49,55 @@ class GlueDataBrewStartJobOperator(BaseOperator): :param deferrable: If True, the operator will wait asynchronously for the job to complete. This implies waiting for completion. This mode requires aiobotocore module to be installed. (default: False) - :param delay: Time in seconds to wait between status checks. Default is 30. + :param delay: Time in seconds to wait between status checks. (Deprecated). + :param waiter_delay: Time in seconds to wait between status checks. Default is 30. + :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 60) :return: dictionary with key run_id and value of the resulting job's run_id. + + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ( + aws_hook_class = GlueDataBrewHook + + template_fields: Sequence[str] = aws_template_fields( "job_name", "wait_for_completion", - "delay", - "deferrable", + "waiter_delay", + "waiter_max_attemptsdeferrable", ) def __init__( self, job_name: str, wait_for_completion: bool = True, - delay: int = 30, + delay: int | None = None, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), - aws_conn_id: str | None = "aws_default", **kwargs, ): super().__init__(**kwargs) self.job_name = job_name self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts self.deferrable = deferrable - self.delay = delay - self.aws_conn_id = aws_conn_id - - @cached_property - def hook(self) -> GlueDataBrewHook: - return GlueDataBrewHook(aws_conn_id=self.aws_conn_id) + if delay is not None: + warnings.warn( + "please use `waiter_delay` instead of delay.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + self.waiter_delay = delay def execute(self, context: Context): job = self.hook.conn.start_job_run(Name=self.job_name) @@ -88,7 +109,14 @@ def execute(self, context: Context): self.log.info("Deferring job %s with run_id %s", self.job_name, run_id) self.defer( trigger=GlueDataBrewJobCompleteTrigger( - aws_conn_id=self.aws_conn_id, job_name=self.job_name, run_id=run_id, delay=self.delay + job_name=self.job_name, + run_id=run_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + botocore_config=self.botocore_config, ), method_name="execute_complete", ) @@ -97,7 +125,12 @@ def execute(self, context: Context): self.log.info( "Waiting for AWS Glue DataBrew Job: %s. Run Id: %s to complete.", self.job_name, run_id ) - status = self.hook.job_completion(job_name=self.job_name, delay=self.delay, run_id=run_id) + status = self.hook.job_completion( + job_name=self.job_name, + delay=self.waiter_delay, + run_id=run_id, + max_attempts=self.waiter_max_attempts, + ) self.log.info("Glue DataBrew Job: %s status: %s", self.job_name, status) return {"run_id": run_id} @@ -105,6 +138,9 @@ def execute(self, context: Context): def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]: event = validate_execute_complete_event(event) + if event["status"] != "success": + raise AirflowException("Error while running AWS Glue DataBrew job: %s", event) + run_id = event.get("run_id", "") status = event.get("status", "") diff --git a/airflow/providers/amazon/aws/triggers/glue_databrew.py b/airflow/providers/amazon/aws/triggers/glue_databrew.py index 0110817078493..a57dc8a0d0359 100644 --- a/airflow/providers/amazon/aws/triggers/glue_databrew.py +++ b/airflow/providers/amazon/aws/triggers/glue_databrew.py @@ -17,6 +17,9 @@ from __future__ import annotations +import warnings + +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.glue_databrew import GlueDataBrewHook from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger @@ -27,8 +30,10 @@ class GlueDataBrewJobCompleteTrigger(AwsBaseWaiterTrigger): :param job_name: Glue DataBrew job name :param run_id: the ID of the specific run to watch for that job - :param delay: Number of seconds to wait between two checks. Default is 10 seconds. - :param max_attempts: Maximum number of attempts to wait for the job to complete. Default is 60 attempts. + :param delay: Number of seconds to wait between two checks.(Deprecated). + :param waiter_delay: Number of seconds to wait between two checks. Default is 30 seconds. + :param max_attempts: Maximum number of attempts to wait for the job to complete.(Deprecated). + :param waiter_max_attempts: Maximum number of attempts to wait for the job to complete. Default is 60 attempts. :param aws_conn_id: The Airflow connection used for AWS credentials. """ @@ -36,11 +41,27 @@ def __init__( self, job_name: str, run_id: str, - aws_conn_id: str | None, - delay: int = 10, - max_attempts: int = 60, + delay: int | None = None, + max_attempts: int | None = None, + waiter_delay: int = 30, + waiter_max_attempts: int = 60, + aws_conn_id: str | None = "aws_default", **kwargs, ): + if delay is not None: + warnings.warn( + "please use `waiter_delay` instead of delay.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + waiter_delay = delay or waiter_delay + if max_attempts is not None: + warnings.warn( + "please use `waiter_max_attempts` instead of max_attempts.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + waiter_max_attempts = max_attempts or waiter_max_attempts super().__init__( serialized_fields={"job_name": job_name, "run_id": run_id}, waiter_name="job_complete", @@ -50,10 +71,16 @@ def __init__( status_queries=["State"], return_value=run_id, return_key="run_id", - waiter_delay=delay, - waiter_max_attempts=max_attempts, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, aws_conn_id=aws_conn_id, + **kwargs, ) def hook(self) -> GlueDataBrewHook: - return GlueDataBrewHook(aws_conn_id=self.aws_conn_id) + return GlueDataBrewHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) diff --git a/docs/apache-airflow-providers-amazon/operators/glue_databrew.rst b/docs/apache-airflow-providers-amazon/operators/glue_databrew.rst index d14f89848653d..be654335ea1d0 100644 --- a/docs/apache-airflow-providers-amazon/operators/glue_databrew.rst +++ b/docs/apache-airflow-providers-amazon/operators/glue_databrew.rst @@ -31,6 +31,11 @@ Prerequisite Tasks .. include:: ../_partials/prerequisite_tasks.rst +Generic Parameters +------------------ + +.. include:: ../_partials/generic_parameters.rst + Operators --------- diff --git a/tests/providers/amazon/aws/operators/test_glue_databrew.py b/tests/providers/amazon/aws/operators/test_glue_databrew.py index 53a323e6f0bd6..0e88d477f3380 100644 --- a/tests/providers/amazon/aws/operators/test_glue_databrew.py +++ b/tests/providers/amazon/aws/operators/test_glue_databrew.py @@ -23,6 +23,7 @@ import pytest from moto import mock_aws +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.glue_databrew import GlueDataBrewHook from airflow.providers.amazon.aws.operators.glue_databrew import GlueDataBrewStartJobOperator @@ -36,6 +37,30 @@ def hook() -> Generator[GlueDataBrewHook, None, None]: class TestGlueDataBrewOperator: + def test_init(self): + op = GlueDataBrewStartJobOperator( + task_id="task_test", + job_name=JOB_NAME, + aws_conn_id="fake-conn-id", + region_name="eu-central-1", + verify="/spam/egg.pem", + botocore_config={"read_timeout": 42}, + ) + + assert op.hook.client_type == "databrew" + assert op.hook.resource_type is None + assert op.hook.aws_conn_id == "fake-conn-id" + assert op.hook._region_name == "eu-central-1" + assert op.hook._verify == "/spam/egg.pem" + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 + + op = GlueDataBrewStartJobOperator(task_id="fake_task_id", job_name=JOB_NAME) + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None + @mock.patch.object(GlueDataBrewHook, "conn") @mock.patch.object(GlueDataBrewHook, "get_waiter") def test_start_job_wait_for_completion(self, mock_hook_get_waiter, mock_conn): @@ -57,3 +82,22 @@ def test_start_job_no_wait(self, mock_hook_get_waiter, mock_conn): mock_conn.start_job_run(mock.MagicMock(), return_value=TEST_RUN_ID) operator.execute(None) mock_hook_get_waiter.assert_not_called() + + @mock.patch.object(GlueDataBrewHook, "conn") + @mock.patch.object(GlueDataBrewHook, "get_waiter") + def test_start_job_with_deprecation_parameters(self, mock_hook_get_waiter, mock_conn): + TEST_RUN_ID = "12345" + + with pytest.warns(AirflowProviderDeprecationWarning): + operator = GlueDataBrewStartJobOperator( + task_id="task_test", + job_name=JOB_NAME, + wait_for_completion=False, + aws_conn_id="aws_default", + delay=15, + ) + + mock_conn.start_job_run(mock.MagicMock(), return_value=TEST_RUN_ID) + assert operator.waiter_delay == 15 + operator.execute(None) + mock_hook_get_waiter.assert_not_called() diff --git a/tests/providers/amazon/aws/triggers/test_glue_databrew.py b/tests/providers/amazon/aws/triggers/test_glue_databrew.py index 09137a0c7d6ac..c39892c247f4e 100644 --- a/tests/providers/amazon/aws/triggers/test_glue_databrew.py +++ b/tests/providers/amazon/aws/triggers/test_glue_databrew.py @@ -18,6 +18,7 @@ import pytest +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.triggers.glue_databrew import GlueDataBrewJobCompleteTrigger TEST_JOB_NAME = "test_job_name" @@ -44,3 +45,24 @@ def test_serialize(self, trigger): assert class_path == class_path2 assert args == args2 + + def test_serialize_with_deprecated_parameters(self, trigger): + with pytest.warns(AirflowProviderDeprecationWarning): + class_path, args = GlueDataBrewJobCompleteTrigger( + aws_conn_id="aws_default", + job_name=TEST_JOB_NAME, + run_id=TEST_JOB_RUN_ID, + delay=1, + max_attempts=1, + ).serialize() + + class_name = class_path.split(".")[-1] + clazz = globals()[class_name] + instance = clazz(**args) + + class_path2, args2 = instance.serialize() + + assert class_path == class_path2 + assert args == args2 + assert args.get("waiter_delay") == 1 + assert args.get("waiter_max_attempts") == 1 diff --git a/tests/system/providers/amazon/aws/example_glue_databrew.py b/tests/system/providers/amazon/aws/example_glue_databrew.py index d44ebe383fd63..251c7611b3be3 100644 --- a/tests/system/providers/amazon/aws/example_glue_databrew.py +++ b/tests/system/providers/amazon/aws/example_glue_databrew.py @@ -120,7 +120,7 @@ def delete_job(job_name: str): ) # [START howto_operator_glue_databrew_start] - start_job = GlueDataBrewStartJobOperator(task_id="startjob", job_name=job_name, delay=15) + start_job = GlueDataBrewStartJobOperator(task_id="startjob", job_name=job_name, waiter_delay=15) # [END howto_operator_glue_databrew_start] delete_bucket = S3DeleteBucketOperator(