diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py index 3ec5b2a987c68..443b5912c5593 100644 --- a/airflow/providers/amazon/aws/hooks/emr.py +++ b/airflow/providers/amazon/aws/hooks/emr.py @@ -24,7 +24,6 @@ from botocore.exceptions import ClientError -from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.utils.helpers import prune_dict @@ -253,10 +252,41 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs["client_type"] = "emr-serverless" super().__init__(*args, **kwargs) - @cached_property - def conn(self): - """Get the underlying boto3 EmrServerlessAPIService client (cached)""" - return super().conn + def cancel_running_jobs(self, application_id: str, waiter_config: dict = {}): + """ + List all jobs in an intermediate state and cancel them. + Then wait for those jobs to reach a terminal state. + Note: if new jobs are triggered while this operation is ongoing, + it's going to time out and return an error. + """ + paginator = self.conn.get_paginator("list_job_runs") + results_per_response = 50 + iterator = paginator.paginate( + applicationId=application_id, + states=list(self.JOB_INTERMEDIATE_STATES), + PaginationConfig={ + "PageSize": results_per_response, + }, + ) + count = 0 + for r in iterator: + job_ids = [jr["id"] for jr in r["jobRuns"]] + count += len(job_ids) + if len(job_ids) > 0: + self.log.info( + "Cancelling %s pending job(s) for the application %s so that it can be stopped", + len(job_ids), + application_id, + ) + for job_id in job_ids: + self.conn.cancel_job_run(applicationId=application_id, jobRunId=job_id) + if count > 0: + self.log.info("now waiting for the %s cancelled job(s) to terminate", count) + self.get_waiter("no_job_running").wait( + applicationId=application_id, + states=list(self.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})), + WaiterConfig=waiter_config, + ) class EmrContainerHook(AwsBaseHook): diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index f5f23cccf6531..220a4ddea0032 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -1025,6 +1025,10 @@ class EmrServerlessStopApplicationOperator(BaseOperator): the application be stopped. Defaults to 5 minutes. :param waiter_check_interval_seconds: Number of seconds between polling the state of the application. Defaults to 30 seconds. + :param force_stop: If set to True, any job for that app that is not in a terminal state will be cancelled. + Otherwise, trying to stop an app with running jobs will return an error. + If you want to wait for the jobs to finish gracefully, use + :class:`airflow.providers.amazon.aws.sensors.emr.EmrServerlessJobSensor` """ template_fields: Sequence[str] = ("application_id",) @@ -1036,6 +1040,7 @@ def __init__( aws_conn_id: str = "aws_default", waiter_countdown: int = 5 * 60, waiter_check_interval_seconds: int = 30, + force_stop: bool = False, **kwargs, ): self.aws_conn_id = aws_conn_id @@ -1043,6 +1048,7 @@ def __init__( self.wait_for_completion = wait_for_completion self.waiter_countdown = waiter_countdown self.waiter_check_interval_seconds = waiter_check_interval_seconds + self.force_stop = force_stop super().__init__(**kwargs) @cached_property @@ -1052,6 +1058,16 @@ def hook(self) -> EmrServerlessHook: def execute(self, context: Context) -> None: self.log.info("Stopping application: %s", self.application_id) + + if self.force_stop: + self.hook.cancel_running_jobs( + self.application_id, + waiter_config={ + "Delay": self.waiter_check_interval_seconds, + "MaxAttempts": self.waiter_countdown / self.waiter_check_interval_seconds, + }, + ) + self.hook.conn.stop_application(applicationId=self.application_id) if self.wait_for_completion: @@ -1088,6 +1104,10 @@ class EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperato the application to be stopped, and then deleted. Defaults to 25 minutes. :param waiter_check_interval_seconds: Number of seconds between polling the state of the application. Defaults to 60 seconds. + :param force_stop: If set to True, any job for that app that is not in a terminal state will be cancelled. + Otherwise, trying to delete an app with running jobs will return an error. + If you want to wait for the jobs to finish gracefully, use + :class:`airflow.providers.amazon.aws.sensors.emr.EmrServerlessJobSensor` """ template_fields: Sequence[str] = ("application_id",) @@ -1099,6 +1119,7 @@ def __init__( aws_conn_id: str = "aws_default", waiter_countdown: int = 25 * 60, waiter_check_interval_seconds: int = 60, + force_stop: bool = False, **kwargs, ): self.wait_for_delete_completion = wait_for_completion @@ -1110,6 +1131,7 @@ def __init__( aws_conn_id=aws_conn_id, waiter_countdown=waiter_countdown, waiter_check_interval_seconds=waiter_check_interval_seconds, + force_stop=force_stop, **kwargs, ) diff --git a/airflow/providers/amazon/aws/waiters/emr-serverless.json b/airflow/providers/amazon/aws/waiters/emr-serverless.json new file mode 100644 index 0000000000000..a77d07f243687 --- /dev/null +++ b/airflow/providers/amazon/aws/waiters/emr-serverless.json @@ -0,0 +1,18 @@ +{ + "version": 2, + "waiters": { + "no_job_running": { + "operation": "ListJobRuns", + "delay": 10, + "maxAttempts": 60, + "acceptors": [ + { + "matcher": "path", + "argument": "length(jobRuns) == `0`", + "expected": true, + "state": "success" + } + ] + } + } +} diff --git a/tests/providers/amazon/aws/hooks/test_emr_serverless.py b/tests/providers/amazon/aws/hooks/test_emr_serverless.py index c77cdb77af03b..33726bc62b5dd 100644 --- a/tests/providers/amazon/aws/hooks/test_emr_serverless.py +++ b/tests/providers/amazon/aws/hooks/test_emr_serverless.py @@ -16,6 +16,8 @@ # under the License. from __future__ import annotations +from unittest.mock import MagicMock, PropertyMock, patch + from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook task_id = "test_emr_serverless_create_application_operator" @@ -34,3 +36,42 @@ def test_conn_attribute(self): conn = hook.conn conn2 = hook.conn assert conn is conn2 + + @patch.object(EmrServerlessHook, "conn", new_callable=PropertyMock) + def test_cancel_jobs(self, conn_mock: MagicMock): + conn_mock().get_paginator().paginate.return_value = [{"jobRuns": [{"id": "job1"}, {"id": "job2"}]}] + hook = EmrServerlessHook(aws_conn_id="aws_default") + waiter_mock = MagicMock() + hook.get_waiter = waiter_mock + + hook.cancel_running_jobs("app") + + assert conn_mock().cancel_job_run.call_count == 2 + conn_mock().cancel_job_run.assert_any_call(applicationId="app", jobRunId="job1") + conn_mock().cancel_job_run.assert_any_call(applicationId="app", jobRunId="job2") + waiter_mock.assert_called_with("no_job_running") + + @patch.object(EmrServerlessHook, "conn", new_callable=PropertyMock) + def test_cancel_jobs_several_calls(self, conn_mock: MagicMock): + conn_mock().get_paginator().paginate.return_value = [ + {"jobRuns": [{"id": "job1"}, {"id": "job2"}]}, + {"jobRuns": [{"id": "job3"}, {"id": "job4"}]}, + ] + hook = EmrServerlessHook(aws_conn_id="aws_default") + waiter_mock = MagicMock() + hook.get_waiter = waiter_mock + + hook.cancel_running_jobs("app") + + assert conn_mock().cancel_job_run.call_count == 4 + waiter_mock.assert_called_once() # we should wait once for all jobs, not once per page + + @patch.object(EmrServerlessHook, "conn", new_callable=PropertyMock) + def test_cancel_jobs_but_no_jobs(self, conn_mock: MagicMock): + conn_mock.return_value.list_job_runs.return_value = {"jobRuns": []} + hook = EmrServerlessHook(aws_conn_id="aws_default") + + hook.cancel_running_jobs("app") + + # nothing very interesting should happen + conn_mock.assert_called_once() diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py index 698710f0a1c7d..99ae2de04db7b 100644 --- a/tests/providers/amazon/aws/operators/test_emr_serverless.py +++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py @@ -17,7 +17,7 @@ from __future__ import annotations from unittest import mock -from unittest.mock import MagicMock +from unittest.mock import MagicMock, PropertyMock from uuid import UUID import pytest @@ -654,3 +654,16 @@ def test_stop_no_wait(self, mock_conn: MagicMock, mock_waiter: MagicMock): mock_waiter.assert_not_called() mock_conn.stop_application.assert_called_once() + + @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter") + @mock.patch.object(EmrServerlessStopApplicationOperator, "hook", new_callable=PropertyMock) + def test_force_stop(self, mock_hook: MagicMock, mock_waiter: MagicMock): + operator = EmrServerlessStopApplicationOperator( + task_id=task_id, application_id="test", force_stop=True + ) + + operator.execute(None) + + mock_hook().cancel_running_jobs.assert_called_once() + mock_hook().conn.stop_application.assert_called_once() + mock_waiter.assert_called_once() diff --git a/tests/system/providers/amazon/aws/example_emr_serverless.py b/tests/system/providers/amazon/aws/example_emr_serverless.py index 8bd20a0dccc3c..0c0a19e17bffb 100644 --- a/tests/system/providers/amazon/aws/example_emr_serverless.py +++ b/tests/system/providers/amazon/aws/example_emr_serverless.py @@ -100,13 +100,15 @@ configuration_overrides=SPARK_CONFIGURATION_OVERRIDES, ) # [END howto_operator_emr_serverless_start_job] - start_job.waiter_check_interval_seconds = 10 + start_job.wait_for_completion = False # [START howto_sensor_emr_serverless_job] wait_for_job = EmrServerlessJobSensor( task_id="wait_for_job", application_id=emr_serverless_app_id, job_run_id=start_job.output, + # the default is to wait for job completion, here we just wait for the job to be running. + target_states={"RUNNING"}, ) # [END howto_sensor_emr_serverless_job] wait_for_job.poke_interval = 10 @@ -115,6 +117,7 @@ stop_app = EmrServerlessStopApplicationOperator( task_id="stop_application", application_id=emr_serverless_app_id, + force_stop=True, ) # [END howto_operator_emr_serverless_stop_application] stop_app.waiter_check_interval_seconds = 1