From 7fbb07c7a466bb6732b1350971723f961db57ccd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Tue, 18 Apr 2023 16:31:47 -0700 Subject: [PATCH 1/9] add an option to EMR operator to force an app to stop --- airflow/providers/amazon/aws/hooks/emr.py | 29 +++++++++++++++++-- airflow/providers/amazon/aws/operators/emr.py | 15 ++++++++++ .../amazon/aws/waiters/emr-serverless.json | 18 ++++++++++++ .../amazon/aws/hooks/test_emr_serverless.py | 26 +++++++++++++++++ .../aws/operators/test_emr_serverless.py | 15 +++++++++- .../amazon/aws/example_emr_serverless.py | 5 +++- 6 files changed, 104 insertions(+), 4 deletions(-) create mode 100644 airflow/providers/amazon/aws/waiters/emr-serverless.json diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py index dcb993af79660..fb2aa48e6d863 100644 --- a/airflow/providers/amazon/aws/hooks/emr.py +++ b/airflow/providers/amazon/aws/hooks/emr.py @@ -24,7 +24,6 @@ from botocore.exceptions import ClientError -from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.utils.waiter import get_state, waiter @@ -254,7 +253,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs["client_type"] = "emr-serverless" super().__init__(*args, **kwargs) - @cached_property + @property def conn(self): """Get the underlying boto3 EmrServerlessAPIService client (cached)""" return super().conn @@ -315,6 +314,32 @@ def get_state(self, response, keys) -> str: ) return get_state(response=response, keys=keys) + def cancel_running_jobs(self, application_id: str, waiter_config: dict = {}): + """ + List all jobs in an intermediate state and cancel them. + Then wait for those jobs to reach a terminal state. + + Note: if new jobs are triggered while this operation is ongoing, + it's going to time out and return an error. + """ + r = self.conn.list_job_runs( + applicationId=application_id, maxResults=50, states=list(self.JOB_INTERMEDIATE_STATES) + ) + job_ids = [jr["id"] for jr in r["jobRuns"]] + self.log.info("there are %s job(s) running for app %s", len(job_ids), application_id) + if len(job_ids) > 0: + self.log.warning( + "Cancelling the pending jobs for the application %s so that it can be stopped", application_id + ) + for job_id in job_ids: + self.conn.cancel_job_run(applicationId=application_id, jobRunId=job_id) + self.log.info("now waiting for the cancelled jobs to terminate") + self.get_waiter("no_job_running").wait( + applicationId=application_id, + states=["PENDING", "RUNNING", "SCHEDULED", "SUBMITTED", "CANCELLING"], + WaiterConfig=waiter_config, + ) + class EmrContainerHook(AwsBaseHook): """ diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index f5f23cccf6531..8b2b5aadf4622 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -1025,6 +1025,9 @@ class EmrServerlessStopApplicationOperator(BaseOperator): the application be stopped. Defaults to 5 minutes. :param waiter_check_interval_seconds: Number of seconds between polling the state of the application. Defaults to 30 seconds. + :param force_stop: If set to True, any job for that app that is not in a terminal state will be cancelled. + Otherwise, trying to stop an app with running jobs will return an error. + If you want to wait for the jobs to finish gracefully, use :ref:`_howto/sensor:EmrServerlessJobSensor` """ template_fields: Sequence[str] = ("application_id",) @@ -1036,6 +1039,7 @@ def __init__( aws_conn_id: str = "aws_default", waiter_countdown: int = 5 * 60, waiter_check_interval_seconds: int = 30, + force_stop: bool = False, **kwargs, ): self.aws_conn_id = aws_conn_id @@ -1043,6 +1047,7 @@ def __init__( self.wait_for_completion = wait_for_completion self.waiter_countdown = waiter_countdown self.waiter_check_interval_seconds = waiter_check_interval_seconds + self.force_stop = force_stop super().__init__(**kwargs) @cached_property @@ -1052,6 +1057,16 @@ def hook(self) -> EmrServerlessHook: def execute(self, context: Context) -> None: self.log.info("Stopping application: %s", self.application_id) + + if self.force_stop: + self.hook.cancel_running_jobs( + self.application_id, + waiter_config={ + "Delay": self.waiter_check_interval_seconds, + "MaxAttempts": self.waiter_countdown / self.waiter_check_interval_seconds, + }, + ) + self.hook.conn.stop_application(applicationId=self.application_id) if self.wait_for_completion: diff --git a/airflow/providers/amazon/aws/waiters/emr-serverless.json b/airflow/providers/amazon/aws/waiters/emr-serverless.json new file mode 100644 index 0000000000000..a77d07f243687 --- /dev/null +++ b/airflow/providers/amazon/aws/waiters/emr-serverless.json @@ -0,0 +1,18 @@ +{ + "version": 2, + "waiters": { + "no_job_running": { + "operation": "ListJobRuns", + "delay": 10, + "maxAttempts": 60, + "acceptors": [ + { + "matcher": "path", + "argument": "length(jobRuns) == `0`", + "expected": true, + "state": "success" + } + ] + } + } +} diff --git a/tests/providers/amazon/aws/hooks/test_emr_serverless.py b/tests/providers/amazon/aws/hooks/test_emr_serverless.py index c77cdb77af03b..dc3b119dea9a8 100644 --- a/tests/providers/amazon/aws/hooks/test_emr_serverless.py +++ b/tests/providers/amazon/aws/hooks/test_emr_serverless.py @@ -16,6 +16,8 @@ # under the License. from __future__ import annotations +from unittest.mock import MagicMock, PropertyMock, patch + from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook task_id = "test_emr_serverless_create_application_operator" @@ -34,3 +36,27 @@ def test_conn_attribute(self): conn = hook.conn conn2 = hook.conn assert conn is conn2 + + @patch.object(EmrServerlessHook, "conn", new_callable=PropertyMock) + def test_cancel_jobs(self, conn_mock: MagicMock): + conn_mock().list_job_runs.return_value = {"jobRuns": [{"id": "job1"}, {"id": "job2"}]} + hook = EmrServerlessHook(aws_conn_id="aws_default") + waiter_mock = MagicMock() + hook.get_waiter = waiter_mock + + hook.cancel_running_jobs("app") + + assert conn_mock().cancel_job_run.call_count == 2 + conn_mock().cancel_job_run.assert_any_call(applicationId="app", jobRunId="job1") + conn_mock().cancel_job_run.assert_any_call(applicationId="app", jobRunId="job2") + waiter_mock.assert_called_with("no_job_running") + + @patch.object(EmrServerlessHook, "conn", new_callable=PropertyMock) + def test_cancel_jobs_but_no_jobs(self, conn_mock: MagicMock): + conn_mock.return_value.list_job_runs.return_value = {"jobRuns": []} + hook = EmrServerlessHook(aws_conn_id="aws_default") + + hook.cancel_running_jobs("app") + + # nothing very interesting should happen + conn_mock.assert_called_once() diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py index 698710f0a1c7d..99ae2de04db7b 100644 --- a/tests/providers/amazon/aws/operators/test_emr_serverless.py +++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py @@ -17,7 +17,7 @@ from __future__ import annotations from unittest import mock -from unittest.mock import MagicMock +from unittest.mock import MagicMock, PropertyMock from uuid import UUID import pytest @@ -654,3 +654,16 @@ def test_stop_no_wait(self, mock_conn: MagicMock, mock_waiter: MagicMock): mock_waiter.assert_not_called() mock_conn.stop_application.assert_called_once() + + @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter") + @mock.patch.object(EmrServerlessStopApplicationOperator, "hook", new_callable=PropertyMock) + def test_force_stop(self, mock_hook: MagicMock, mock_waiter: MagicMock): + operator = EmrServerlessStopApplicationOperator( + task_id=task_id, application_id="test", force_stop=True + ) + + operator.execute(None) + + mock_hook().cancel_running_jobs.assert_called_once() + mock_hook().conn.stop_application.assert_called_once() + mock_waiter.assert_called_once() diff --git a/tests/system/providers/amazon/aws/example_emr_serverless.py b/tests/system/providers/amazon/aws/example_emr_serverless.py index 8bd20a0dccc3c..3e014392dec43 100644 --- a/tests/system/providers/amazon/aws/example_emr_serverless.py +++ b/tests/system/providers/amazon/aws/example_emr_serverless.py @@ -22,6 +22,7 @@ from airflow.models.baseoperator import chain from airflow.models.dag import DAG +from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook from airflow.providers.amazon.aws.operators.emr import ( EmrServerlessCreateApplicationOperator, EmrServerlessDeleteApplicationOperator, @@ -100,13 +101,14 @@ configuration_overrides=SPARK_CONFIGURATION_OVERRIDES, ) # [END howto_operator_emr_serverless_start_job] - start_job.waiter_check_interval_seconds = 10 + start_job.wait_for_completion = False # [START howto_sensor_emr_serverless_job] wait_for_job = EmrServerlessJobSensor( task_id="wait_for_job", application_id=emr_serverless_app_id, job_run_id=start_job.output, + target_states=EmrServerlessHook.JOB_INTERMEDIATE_STATES, ) # [END howto_sensor_emr_serverless_job] wait_for_job.poke_interval = 10 @@ -115,6 +117,7 @@ stop_app = EmrServerlessStopApplicationOperator( task_id="stop_application", application_id=emr_serverless_app_id, + force_stop=True, ) # [END howto_operator_emr_serverless_stop_application] stop_app.waiter_check_interval_seconds = 1 From a7f25a14fa11387eefe2acab9c5efd281ab59a93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Tue, 18 Apr 2023 16:33:50 -0700 Subject: [PATCH 2/9] add that option to the delete operator as well --- airflow/providers/amazon/aws/operators/emr.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 8b2b5aadf4622..140fe0d92b723 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -1103,6 +1103,9 @@ class EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperato the application to be stopped, and then deleted. Defaults to 25 minutes. :param waiter_check_interval_seconds: Number of seconds between polling the state of the application. Defaults to 60 seconds. + :param force_stop: If set to True, any job for that app that is not in a terminal state will be cancelled. + Otherwise, trying to delete an app with running jobs will return an error. + If you want to wait for the jobs to finish gracefully, use :ref:`_howto/sensor:EmrServerlessJobSensor` """ template_fields: Sequence[str] = ("application_id",) @@ -1114,6 +1117,7 @@ def __init__( aws_conn_id: str = "aws_default", waiter_countdown: int = 25 * 60, waiter_check_interval_seconds: int = 60, + force_stop: bool = False, **kwargs, ): self.wait_for_delete_completion = wait_for_completion @@ -1125,6 +1129,7 @@ def __init__( aws_conn_id=aws_conn_id, waiter_countdown=waiter_countdown, waiter_check_interval_seconds=waiter_check_interval_seconds, + force_stop=force_stop, **kwargs, ) From 8370c8172d8c28ffeec3ca460621b9d0447bfba6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Wed, 19 Apr 2023 16:16:16 -0700 Subject: [PATCH 3/9] remove conn property shadowing the one from super --- airflow/providers/amazon/aws/hooks/emr.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py index fb2aa48e6d863..a97011dde1287 100644 --- a/airflow/providers/amazon/aws/hooks/emr.py +++ b/airflow/providers/amazon/aws/hooks/emr.py @@ -253,11 +253,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs["client_type"] = "emr-serverless" super().__init__(*args, **kwargs) - @property - def conn(self): - """Get the underlying boto3 EmrServerlessAPIService client (cached)""" - return super().conn - # This method should be replaced with boto waiters which would implement timeouts and backoff nicely. def waiter( self, From ede8b85e359486e099b15d8475a29a884e13ecb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Wed, 19 Apr 2023 16:19:28 -0700 Subject: [PATCH 4/9] better way to pass wanted states --- airflow/providers/amazon/aws/hooks/emr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py index a97011dde1287..e8b697dc9c42d 100644 --- a/airflow/providers/amazon/aws/hooks/emr.py +++ b/airflow/providers/amazon/aws/hooks/emr.py @@ -331,7 +331,7 @@ def cancel_running_jobs(self, application_id: str, waiter_config: dict = {}): self.log.info("now waiting for the cancelled jobs to terminate") self.get_waiter("no_job_running").wait( applicationId=application_id, - states=["PENDING", "RUNNING", "SCHEDULED", "SUBMITTED", "CANCELLING"], + states=list(self.JOB_INTERMEDIATE_STATES.union("CANCELLING")), WaiterConfig=waiter_config, ) From d132b705284f5177a2cdc5a5b662d6d3663fef23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= <114772123+vandonr-amz@users.noreply.github.com> Date: Thu, 20 Apr 2023 10:18:22 -0700 Subject: [PATCH 5/9] reduce log level of cancelling job Co-authored-by: Vincent <97131062+vincbeck@users.noreply.github.com> --- airflow/providers/amazon/aws/hooks/emr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py index e8b697dc9c42d..ffc3185adf56e 100644 --- a/airflow/providers/amazon/aws/hooks/emr.py +++ b/airflow/providers/amazon/aws/hooks/emr.py @@ -323,7 +323,7 @@ def cancel_running_jobs(self, application_id: str, waiter_config: dict = {}): job_ids = [jr["id"] for jr in r["jobRuns"]] self.log.info("there are %s job(s) running for app %s", len(job_ids), application_id) if len(job_ids) > 0: - self.log.warning( + self.log.info( "Cancelling the pending jobs for the application %s so that it can be stopped", application_id ) for job_id in job_ids: From 0232d2e0b5e6b489586e29cecf189e4cd0bcc151 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Thu, 20 Apr 2023 13:03:33 -0700 Subject: [PATCH 6/9] use paginator to get all jobs --- airflow/providers/amazon/aws/hooks/emr.py | 35 ++++++++++++------- .../amazon/aws/hooks/test_emr_serverless.py | 17 ++++++++- 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py index ffc3185adf56e..305d335bf47ad 100644 --- a/airflow/providers/amazon/aws/hooks/emr.py +++ b/airflow/providers/amazon/aws/hooks/emr.py @@ -317,21 +317,32 @@ def cancel_running_jobs(self, application_id: str, waiter_config: dict = {}): Note: if new jobs are triggered while this operation is ongoing, it's going to time out and return an error. """ - r = self.conn.list_job_runs( - applicationId=application_id, maxResults=50, states=list(self.JOB_INTERMEDIATE_STATES) + paginator = self.conn.get_paginator("list_job_runs") + results_per_response = 50 + iterator = paginator.paginate( + applicationId=application_id, + states=list(self.JOB_INTERMEDIATE_STATES), + PaginationConfig={ + "PageSize": results_per_response, + }, ) - job_ids = [jr["id"] for jr in r["jobRuns"]] - self.log.info("there are %s job(s) running for app %s", len(job_ids), application_id) - if len(job_ids) > 0: - self.log.info( - "Cancelling the pending jobs for the application %s so that it can be stopped", application_id - ) - for job_id in job_ids: - self.conn.cancel_job_run(applicationId=application_id, jobRunId=job_id) - self.log.info("now waiting for the cancelled jobs to terminate") + count = 0 + for r in iterator: + job_ids = [jr["id"] for jr in r["jobRuns"]] + count += len(job_ids) + if len(job_ids) > 0: + self.log.info( + "Cancelling %s pending job(s) for the application %s so that it can be stopped", + len(job_ids), + application_id, + ) + for job_id in job_ids: + self.conn.cancel_job_run(applicationId=application_id, jobRunId=job_id) + if count > 0: + self.log.info("now waiting for the %s cancelled job(s) to terminate", count) self.get_waiter("no_job_running").wait( applicationId=application_id, - states=list(self.JOB_INTERMEDIATE_STATES.union("CANCELLING")), + states=list(self.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})), WaiterConfig=waiter_config, ) diff --git a/tests/providers/amazon/aws/hooks/test_emr_serverless.py b/tests/providers/amazon/aws/hooks/test_emr_serverless.py index dc3b119dea9a8..33726bc62b5dd 100644 --- a/tests/providers/amazon/aws/hooks/test_emr_serverless.py +++ b/tests/providers/amazon/aws/hooks/test_emr_serverless.py @@ -39,7 +39,7 @@ def test_conn_attribute(self): @patch.object(EmrServerlessHook, "conn", new_callable=PropertyMock) def test_cancel_jobs(self, conn_mock: MagicMock): - conn_mock().list_job_runs.return_value = {"jobRuns": [{"id": "job1"}, {"id": "job2"}]} + conn_mock().get_paginator().paginate.return_value = [{"jobRuns": [{"id": "job1"}, {"id": "job2"}]}] hook = EmrServerlessHook(aws_conn_id="aws_default") waiter_mock = MagicMock() hook.get_waiter = waiter_mock @@ -51,6 +51,21 @@ def test_cancel_jobs(self, conn_mock: MagicMock): conn_mock().cancel_job_run.assert_any_call(applicationId="app", jobRunId="job2") waiter_mock.assert_called_with("no_job_running") + @patch.object(EmrServerlessHook, "conn", new_callable=PropertyMock) + def test_cancel_jobs_several_calls(self, conn_mock: MagicMock): + conn_mock().get_paginator().paginate.return_value = [ + {"jobRuns": [{"id": "job1"}, {"id": "job2"}]}, + {"jobRuns": [{"id": "job3"}, {"id": "job4"}]}, + ] + hook = EmrServerlessHook(aws_conn_id="aws_default") + waiter_mock = MagicMock() + hook.get_waiter = waiter_mock + + hook.cancel_running_jobs("app") + + assert conn_mock().cancel_job_run.call_count == 4 + waiter_mock.assert_called_once() # we should wait once for all jobs, not once per page + @patch.object(EmrServerlessHook, "conn", new_callable=PropertyMock) def test_cancel_jobs_but_no_jobs(self, conn_mock: MagicMock): conn_mock.return_value.list_job_runs.return_value = {"jobRuns": []} From 4b44efecbac4c7b6c22255a374166bc3291d4b4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Thu, 20 Apr 2023 13:07:45 -0700 Subject: [PATCH 7/9] fix ref in doc, using class instead --- airflow/providers/amazon/aws/operators/emr.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 140fe0d92b723..220a4ddea0032 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -1027,7 +1027,8 @@ class EmrServerlessStopApplicationOperator(BaseOperator): Defaults to 30 seconds. :param force_stop: If set to True, any job for that app that is not in a terminal state will be cancelled. Otherwise, trying to stop an app with running jobs will return an error. - If you want to wait for the jobs to finish gracefully, use :ref:`_howto/sensor:EmrServerlessJobSensor` + If you want to wait for the jobs to finish gracefully, use + :class:`airflow.providers.amazon.aws.sensors.emr.EmrServerlessJobSensor` """ template_fields: Sequence[str] = ("application_id",) @@ -1105,7 +1106,8 @@ class EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperato Defaults to 60 seconds. :param force_stop: If set to True, any job for that app that is not in a terminal state will be cancelled. Otherwise, trying to delete an app with running jobs will return an error. - If you want to wait for the jobs to finish gracefully, use :ref:`_howto/sensor:EmrServerlessJobSensor` + If you want to wait for the jobs to finish gracefully, use + :class:`airflow.providers.amazon.aws.sensors.emr.EmrServerlessJobSensor` """ template_fields: Sequence[str] = ("application_id",) From af8eb19a501870fc041454a2e4b92a0899b6e66b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Thu, 20 Apr 2023 16:58:27 -0700 Subject: [PATCH 8/9] add comment on target states --- tests/system/providers/amazon/aws/example_emr_serverless.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/system/providers/amazon/aws/example_emr_serverless.py b/tests/system/providers/amazon/aws/example_emr_serverless.py index 3e014392dec43..0c0a19e17bffb 100644 --- a/tests/system/providers/amazon/aws/example_emr_serverless.py +++ b/tests/system/providers/amazon/aws/example_emr_serverless.py @@ -22,7 +22,6 @@ from airflow.models.baseoperator import chain from airflow.models.dag import DAG -from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook from airflow.providers.amazon.aws.operators.emr import ( EmrServerlessCreateApplicationOperator, EmrServerlessDeleteApplicationOperator, @@ -108,7 +107,8 @@ task_id="wait_for_job", application_id=emr_serverless_app_id, job_run_id=start_job.output, - target_states=EmrServerlessHook.JOB_INTERMEDIATE_STATES, + # the default is to wait for job completion, here we just wait for the job to be running. + target_states={"RUNNING"}, ) # [END howto_sensor_emr_serverless_job] wait_for_job.poke_interval = 10 From 9ddaafe522efad5bf3f6546ae319c8f42a867449 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Fri, 21 Apr 2023 15:30:53 -0700 Subject: [PATCH 9/9] bring back method that I mistakenly deleted in merge conflict resolution --- airflow/providers/amazon/aws/hooks/emr.py | 36 +++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py index cfdb68b07c397..443b5912c5593 100644 --- a/airflow/providers/amazon/aws/hooks/emr.py +++ b/airflow/providers/amazon/aws/hooks/emr.py @@ -252,6 +252,42 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs["client_type"] = "emr-serverless" super().__init__(*args, **kwargs) + def cancel_running_jobs(self, application_id: str, waiter_config: dict = {}): + """ + List all jobs in an intermediate state and cancel them. + Then wait for those jobs to reach a terminal state. + Note: if new jobs are triggered while this operation is ongoing, + it's going to time out and return an error. + """ + paginator = self.conn.get_paginator("list_job_runs") + results_per_response = 50 + iterator = paginator.paginate( + applicationId=application_id, + states=list(self.JOB_INTERMEDIATE_STATES), + PaginationConfig={ + "PageSize": results_per_response, + }, + ) + count = 0 + for r in iterator: + job_ids = [jr["id"] for jr in r["jobRuns"]] + count += len(job_ids) + if len(job_ids) > 0: + self.log.info( + "Cancelling %s pending job(s) for the application %s so that it can be stopped", + len(job_ids), + application_id, + ) + for job_id in job_ids: + self.conn.cancel_job_run(applicationId=application_id, jobRunId=job_id) + if count > 0: + self.log.info("now waiting for the %s cancelled job(s) to terminate", count) + self.get_waiter("no_job_running").wait( + applicationId=application_id, + states=list(self.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})), + WaiterConfig=waiter_config, + ) + class EmrContainerHook(AwsBaseHook): """