diff --git a/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py b/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py index 4c13db45b2bf8..6df7139ab3603 100644 --- a/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py +++ b/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py @@ -16,7 +16,7 @@ # under the License. """ -This is an example dag for ECSOperator. +This is an example dag for EcsOperator. The task "hello_world" runs `hello-world` task in `c` cluster. It overrides the command in the `hello-world-container` container. @@ -26,7 +26,7 @@ import os from airflow import DAG -from airflow.providers.amazon.aws.operators.ecs import ECSOperator +from airflow.providers.amazon.aws.operators.ecs import EcsOperator dag = DAG( dag_id="ecs_fargate_dag", @@ -40,7 +40,7 @@ dag.doc_md = __doc__ # [START howto_operator_ecs] -hello_world = ECSOperator( +hello_world = EcsOperator( task_id="hello_world", dag=dag, aws_conn_id="aws_ecs", diff --git a/airflow/providers/amazon/aws/exceptions.py b/airflow/providers/amazon/aws/exceptions.py index d0e5b54a0748e..477075d36c9c6 100644 --- a/airflow/providers/amazon/aws/exceptions.py +++ b/airflow/providers/amazon/aws/exceptions.py @@ -18,12 +18,29 @@ # # Note: Any AirflowException raised is expected to cause the TaskInstance # to be marked in an ERROR state +import warnings -class ECSOperatorError(Exception): +class EcsOperatorError(Exception): """Raise when ECS cannot handle the request.""" def __init__(self, failures: list, message: str): self.failures = failures self.message = message super().__init__(message) + + +class ECSOperatorError(EcsOperatorError): + """ + This class is deprecated. + Please use :class:`airflow.providers.amazon.aws.exceptions.EcsOperatorError`. + """ + + def __init__(self, *args, **kwargs): + warnings.warn( + "This class is deprecated. " + "Please use `airflow.providers.amazon.aws.exceptions.EcsOperatorError`.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index f651cc7dd758d..b34337fe9b92c 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -18,6 +18,7 @@ import re import sys import time +import warnings from collections import deque from datetime import datetime, timedelta from logging import Logger @@ -29,7 +30,7 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator, XCom -from airflow.providers.amazon.aws.exceptions import ECSOperatorError +from airflow.providers.amazon.aws.exceptions import EcsOperatorError from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.typing_compat import Protocol, runtime_checkable @@ -38,7 +39,7 @@ def should_retry(exception: Exception): """Check if exception is related to ECS resource quota (CPU, MEM).""" - if isinstance(exception, ECSOperatorError): + if isinstance(exception, EcsOperatorError): return any( quota_reason in failure['reason'] for quota_reason in ['RESOURCE:MEMORY', 'RESOURCE:CPU'] @@ -48,10 +49,10 @@ def should_retry(exception: Exception): @runtime_checkable -class ECSProtocol(Protocol): +class EcsProtocol(Protocol): """ A structured Protocol for ``boto3.client('ecs')``. This is used for type hints on - :py:meth:`.ECSOperator.client`. + :py:meth:`.EcsOperator.client`. .. seealso:: @@ -84,7 +85,7 @@ def list_tasks(self, cluster: str, launchType: str, desiredStatus: str, family: ... -class ECSTaskLogFetcher(Thread): +class EcsTaskLogFetcher(Thread): """ Fetches Cloudwatch log events with specific interval as a thread and sends the log events to the info channel of the provided logger. @@ -151,13 +152,13 @@ def stop(self): self._event.set() -class ECSOperator(BaseOperator): +class EcsOperator(BaseOperator): """ Execute a task on AWS ECS (Elastic Container Service) .. seealso:: For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:ECSOperator` + :ref:`howto/operator:EcsOperator` :param task_definition: the task definition name on Elastic Container Service :type task_definition: str @@ -289,17 +290,17 @@ def __init__( self.awslogs_region = region_name self.hook: Optional[AwsBaseHook] = None - self.client: Optional[ECSProtocol] = None + self.client: Optional[EcsProtocol] = None self.arn: Optional[str] = None self.retry_args = quota_retry - self.task_log_fetcher: Optional[ECSTaskLogFetcher] = None + self.task_log_fetcher: Optional[EcsTaskLogFetcher] = None @provide_session def execute(self, context, session=None): self.log.info( 'Running ECS Task - Task definition: %s - on cluster %s', self.task_definition, self.cluster ) - self.log.info('ECSOperator overrides: %s', self.overrides) + self.log.info('EcsOperator overrides: %s', self.overrides) self.client = self.get_hook().get_conn() @@ -371,7 +372,7 @@ def _start_task(self, context): failures = response['failures'] if len(failures) > 0: - raise ECSOperatorError(failures, response) + raise EcsOperatorError(failures, response) self.log.info('ECS Task started: %s', response) self.arn = response['tasks'][0]['taskArn'] @@ -430,11 +431,12 @@ def _wait_for_task_ended(self) -> None: def _aws_logs_enabled(self): return self.awslogs_group and self.awslogs_stream_prefix - def _get_task_log_fetcher(self) -> ECSTaskLogFetcher: + def _get_task_log_fetcher(self) -> EcsTaskLogFetcher: if not self.awslogs_group: raise ValueError("must specify awslogs_group to fetch task logs") log_stream_name = f"{self.awslogs_stream_prefix}/{self.ecs_task_id}" - return ECSTaskLogFetcher( + + return EcsTaskLogFetcher( aws_conn_id=self.aws_conn_id, region_name=self.awslogs_region, log_group=self.awslogs_group, @@ -509,3 +511,50 @@ def on_kill(self) -> None: cluster=self.cluster, task=self.arn, reason='Task killed by the user' ) self.log.info(response) + + +class ECSOperator(EcsOperator): + """ + This operator is deprecated. + Please use :class:`airflow.providers.amazon.aws.operators.ecs.EcsOperator`. + """ + + def __init__(self, *args, **kwargs): + warnings.warn( + "This operator is deprecated. " + "Please use `airflow.providers.amazon.aws.operators.ecs.EcsOperator`.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) + + +class ECSTaskLogFetcher(EcsTaskLogFetcher): + """ + This class is deprecated. + Please use :class:`airflow.providers.amazon.aws.operators.ecs.EcsTaskLogFetcher`. + """ + + def __init__(self, *args, **kwargs): + warnings.warn( + "This class is deprecated. " + "Please use `airflow.providers.amazon.aws.operators.ecs.EcsTaskLogFetcher`.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) + + +class ECSProtocol(EcsProtocol): + """ + This class is deprecated. + Please use :class:`airflow.providers.amazon.aws.operators.ecs.EcsProtocol`. + """ + + def __init__(self): + warnings.warn( + "This class is deprecated. " + "Please use `airflow.providers.amazon.aws.operators.ecs.EcsProtocol`.", + DeprecationWarning, + stacklevel=2, + ) diff --git a/docs/apache-airflow-providers-amazon/operators/ecs.rst b/docs/apache-airflow-providers-amazon/operators/ecs.rst index e89b8d64e9dc9..d24fb5ec9fe15 100644 --- a/docs/apache-airflow-providers-amazon/operators/ecs.rst +++ b/docs/apache-airflow-providers-amazon/operators/ecs.rst @@ -16,7 +16,7 @@ under the License. -.. _howto/operator:ECSOperator: +.. _howto/operator:EcsOperator: ECS Operator ============ @@ -30,14 +30,14 @@ Using Operator -------------- Use the -:class:`~airflow.providers.amazon.aws.operators.ecs.ECSOperator` +:class:`~airflow.providers.amazon.aws.operators.ecs.EcsOperator` to run a task defined in AWS ECS. In the following example, the task "hello_world" runs ``hello-world`` task in ``c`` cluster. It overrides the command in the ``hello-world-container`` container. -Before using ECSOperator, *cluster* and *task definition* need to be created. +Before using EcsOperator, *cluster* and *task definition* need to be created. .. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py :language: python diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index 9bd1887fb6c0d..fbc147b4090bf 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -28,8 +28,8 @@ from parameterized import parameterized from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.exceptions import ECSOperatorError -from airflow.providers.amazon.aws.operators.ecs import ECSOperator, ECSTaskLogFetcher, should_retry +from airflow.providers.amazon.aws.exceptions import EcsOperatorError +from airflow.providers.amazon.aws.operators.ecs import EcsOperator, EcsTaskLogFetcher, should_retry # fmt: off RESPONSE_WITHOUT_FAILURES = { @@ -55,7 +55,7 @@ # fmt: on -class TestECSOperator(unittest.TestCase): +class TestEcsOperator(unittest.TestCase): @mock.patch('airflow.providers.amazon.aws.operators.ecs.AwsBaseHook') def set_up_operator(self, aws_hook_mock, **kwargs): self.aws_hook_mock = aws_hook_mock @@ -77,7 +77,7 @@ def set_up_operator(self, aws_hook_mock, **kwargs): }, 'propagate_tags': 'TASK_DEFINITION', } - self.ecs = ECSOperator(**self.ecs_operator_args, **kwargs) + self.ecs = EcsOperator(**self.ecs_operator_args, **kwargs) self.ecs.get_hook() def setUp(self): @@ -163,8 +163,8 @@ def test_template_fields_overrides(self): ], ] ) - @mock.patch.object(ECSOperator, '_wait_for_task_ended') - @mock.patch.object(ECSOperator, '_check_success_task') + @mock.patch.object(EcsOperator, '_wait_for_task_ended') + @mock.patch.object(EcsOperator, '_check_success_task') def test_execute_without_failures( self, launch_type, @@ -214,7 +214,7 @@ def test_execute_with_failures(self): resp_failures['failures'].append('dummy error') client_mock.run_task.return_value = resp_failures - with pytest.raises(ECSOperatorError): + with pytest.raises(EcsOperatorError): self.ecs.execute(None) self.aws_hook_mock.return_value.get_conn.assert_called_once() @@ -409,15 +409,15 @@ def test_check_success_task_not_raises(self): ['', {'testTagKey': 'testTagValue'}], ] ) - @mock.patch.object(ECSOperator, "_xcom_del") + @mock.patch.object(EcsOperator, "_xcom_del") @mock.patch.object( - ECSOperator, + EcsOperator, "xcom_pull", return_value="arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55", ) - @mock.patch.object(ECSOperator, '_wait_for_task_ended') - @mock.patch.object(ECSOperator, '_check_success_task') - @mock.patch.object(ECSOperator, '_start_task') + @mock.patch.object(EcsOperator, '_wait_for_task_ended') + @mock.patch.object(EcsOperator, '_check_success_task') + @mock.patch.object(EcsOperator, '_start_task') def test_reattach_successful( self, launch_type, tags, start_mock, check_mock, wait_mock, xcom_pull_mock, xcom_del_mock ): @@ -467,11 +467,11 @@ def test_reattach_successful( ['', {'testTagKey': 'testTagValue'}], ] ) - @mock.patch.object(ECSOperator, '_xcom_del') - @mock.patch.object(ECSOperator, '_xcom_set') - @mock.patch.object(ECSOperator, '_try_reattach_task') - @mock.patch.object(ECSOperator, '_wait_for_task_ended') - @mock.patch.object(ECSOperator, '_check_success_task') + @mock.patch.object(EcsOperator, '_xcom_del') + @mock.patch.object(EcsOperator, '_xcom_set') + @mock.patch.object(EcsOperator, '_try_reattach_task') + @mock.patch.object(EcsOperator, '_wait_for_task_ended') + @mock.patch.object(EcsOperator, '_check_success_task') def test_reattach_save_task_arn_xcom( self, launch_type, tags, check_mock, wait_mock, reattach_mock, xcom_set_mock, xcom_del_mock ): @@ -532,18 +532,18 @@ def test_execute_xcom_disabled(self): class TestShouldRetry(unittest.TestCase): def test_return_true_on_valid_reason(self): - self.assertTrue(should_retry(ECSOperatorError([{'reason': 'RESOURCE:MEMORY'}], 'Foo'))) + self.assertTrue(should_retry(EcsOperatorError([{'reason': 'RESOURCE:MEMORY'}], 'Foo'))) def test_return_false_on_invalid_reason(self): - self.assertFalse(should_retry(ECSOperatorError([{'reason': 'CLUSTER_NOT_FOUND'}], 'Foo'))) + self.assertFalse(should_retry(EcsOperatorError([{'reason': 'CLUSTER_NOT_FOUND'}], 'Foo'))) -class TestECSTaskLogFetcher(unittest.TestCase): +class TestEcsTaskLogFetcher(unittest.TestCase): @mock.patch('logging.Logger') def set_up_log_fetcher(self, logger_mock): self.logger_mock = logger_mock - self.log_fetcher = ECSTaskLogFetcher( + self.log_fetcher = EcsTaskLogFetcher( log_group="test_log_group", log_stream_name="test_log_stream_name", fetch_interval=timedelta(milliseconds=1), diff --git a/tests/providers/amazon/aws/operators/test_ecs_system.py b/tests/providers/amazon/aws/operators/test_ecs_system.py index 1a6eec70a1d57..a5460b981a670 100644 --- a/tests/providers/amazon/aws/operators/test_ecs_system.py +++ b/tests/providers/amazon/aws/operators/test_ecs_system.py @@ -23,7 +23,7 @@ @pytest.mark.backend("postgres", "mysql") -class ECSSystemTest(AmazonSystemTest): +class EcsSystemTest(AmazonSystemTest): """ ECS System Test to run and test example ECS dags