From 801d6566cb86d46d0f1ec37d585d8f15ce8acd88 Mon Sep 17 00:00:00 2001 From: Michal Sosnicki Date: Thu, 12 Oct 2023 00:15:45 +0200 Subject: [PATCH 1/2] Cancel operation in on_kill in DataprocInstantiateWorkflowTemplateOperator --- .../google/cloud/operators/dataproc.py | 42 ++++++++++++++----- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index 3d41fdd1be324..48e66a831a617 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -1790,6 +1790,7 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator): account from the list granting this role to the originating account (templated). :param deferrable: Run operator in the deferrable mode. :param polling_interval_seconds: Time (seconds) to wait between calls to check the run status. + :param cancel_on_kill: Flag which indicates whether cancel the workflow, when on_kill is called """ template_fields: Sequence[str] = ("template_id", "impersonation_chain", "request_id", "parameters") @@ -1812,6 +1813,7 @@ def __init__( impersonation_chain: str | Sequence[str] | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), polling_interval_seconds: int = 10, + cancel_on_kill: bool = True, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1830,6 +1832,8 @@ def __init__( self.impersonation_chain = impersonation_chain self.deferrable = deferrable self.polling_interval_seconds = polling_interval_seconds + self.cancel_on_kill = cancel_on_kill + self.operation_name: str | None = None def execute(self, context: Context): hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) @@ -1845,24 +1849,26 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - self.workflow_id = operation.operation.name.split("/")[-1] + operation_name = operation.operation.name + self.operation_name = operation_name + workflow_id = operation_name.split("/")[-1] project_id = self.project_id or hook.project_id if project_id: DataprocWorkflowLink.persist( context=context, operator=self, - workflow_id=self.workflow_id, + workflow_id=workflow_id, region=self.region, project_id=project_id, ) - self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id) + self.log.info("Template instantiated. Workflow Id : %s", workflow_id) if not self.deferrable: hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation) - self.log.info("Workflow %s completed successfully", self.workflow_id) + self.log.info("Workflow %s completed successfully", workflow_id) else: self.defer( trigger=DataprocWorkflowTrigger( - name=operation.operation.name, + name=operation_name, project_id=self.project_id, region=self.region, gcp_conn_id=self.gcp_conn_id, @@ -1884,6 +1890,11 @@ def execute_complete(self, context, event=None) -> None: self.log.info("Workflow %s completed successfully", event["operation_name"]) + def on_kill(self) -> None: + if self.cancel_on_kill and self.operation_name: + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + hook.get_operations_client(region=self.region).cancel_operation(name=self.operation_name) + class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator): """Instantiate a WorkflowTemplate Inline on Google Cloud Dataproc. @@ -1926,6 +1937,7 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator) account from the list granting this role to the originating account (templated). :param deferrable: Run operator in the deferrable mode. :param polling_interval_seconds: Time (seconds) to wait between calls to check the run status. + :param cancel_on_kill: Flag which indicates whether cancel the workflow, when on_kill is called """ template_fields: Sequence[str] = ("template", "impersonation_chain") @@ -1946,6 +1958,7 @@ def __init__( impersonation_chain: str | Sequence[str] | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), polling_interval_seconds: int = 10, + cancel_on_kill: bool = True, **kwargs, ) -> None: super().__init__(**kwargs) @@ -1963,6 +1976,8 @@ def __init__( self.impersonation_chain = impersonation_chain self.deferrable = deferrable self.polling_interval_seconds = polling_interval_seconds + self.cancel_on_kill = cancel_on_kill + self.operation_name: str | None = None def execute(self, context: Context): self.log.info("Instantiating Inline Template") @@ -1977,23 +1992,25 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - self.workflow_id = operation.operation.name.split("/")[-1] + operation_name = operation.operation.name + self.operation_name = operation_name + workflow_id = operation_name.split("/")[-1] if project_id: DataprocWorkflowLink.persist( context=context, operator=self, - workflow_id=self.workflow_id, + workflow_id=workflow_id, region=self.region, project_id=project_id, ) if not self.deferrable: - self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id) + self.log.info("Template instantiated. Workflow Id : %s", workflow_id) operation.result() - self.log.info("Workflow %s completed successfully", self.workflow_id) + self.log.info("Workflow %s completed successfully", workflow_id) else: self.defer( trigger=DataprocWorkflowTrigger( - name=operation.operation.name, + name=operation_name, project_id=self.project_id or hook.project_id, region=self.region, gcp_conn_id=self.gcp_conn_id, @@ -2015,6 +2032,11 @@ def execute_complete(self, context, event=None) -> None: self.log.info("Workflow %s completed successfully", event["operation_name"]) + def on_kill(self) -> None: + if self.cancel_on_kill and self.operation_name: + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + hook.get_operations_client(region=self.region).cancel_operation(name=self.operation_name) + class DataprocSubmitJobOperator(GoogleCloudBaseOperator): """Submit a job to a cluster. From f615e63007321fdc76a93cc032e7e01cb49afa0f Mon Sep 17 00:00:00 2001 From: Michal Sosnicki Date: Mon, 16 Oct 2023 03:50:53 +0200 Subject: [PATCH 2/2] Test on_kill method in DataprocInstantiateWorkflowTemplateOperator --- .../google/cloud/operators/test_dataproc.py | 64 ++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index 40180d4b47881..02620ccb6c7b2 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -1399,7 +1399,7 @@ def test_update_cluster_operator_extra_links(dag_maker, create_task_instance_of_ assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED -class TestDataprocWorkflowTemplateInstantiateOperator: +class TestDataprocInstantiateWorkflowTemplateOperator: @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook): version = 6 @@ -1463,6 +1463,37 @@ def test_execute_call_defer_method(self, mock_trigger_hook, mock_hook): assert isinstance(exc.value.trigger, DataprocWorkflowTrigger) assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_on_kill(self, mock_hook): + operation_name = "operation_name" + mock_hook.return_value.instantiate_workflow_template.return_value.operation.name = operation_name + op = DataprocInstantiateWorkflowTemplateOperator( + task_id=TASK_ID, + template_id=TEMPLATE_ID, + region=GCP_REGION, + project_id=GCP_PROJECT, + version=2, + parameters={}, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + cancel_on_kill=False, + ) + + op.execute(context=mock.MagicMock()) + + op.on_kill() + mock_hook.return_value.get_operations_client.return_value.cancel_operation.assert_not_called() + + op.cancel_on_kill = True + op.on_kill() + mock_hook.return_value.get_operations_client.return_value.cancel_operation.assert_called_once_with( + name=operation_name + ) + @pytest.mark.need_serialized_dag @mock.patch(DATAPROC_PATH.format("DataprocHook")) @@ -1561,6 +1592,37 @@ def test_execute_call_defer_method(self, mock_trigger_hook, mock_hook): assert isinstance(exc.value.trigger, DataprocWorkflowTrigger) assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_on_kill(self, mock_hook): + operation_name = "operation_name" + mock_hook.return_value.instantiate_inline_workflow_template.return_value.operation.name = ( + operation_name + ) + op = DataprocInstantiateInlineWorkflowTemplateOperator( + task_id=TASK_ID, + template={}, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + cancel_on_kill=False, + ) + + op.execute(context=mock.MagicMock()) + + op.on_kill() + mock_hook.return_value.get_operations_client.return_value.cancel_operation.assert_not_called() + + op.cancel_on_kill = True + op.on_kill() + mock_hook.return_value.get_operations_client.return_value.cancel_operation.assert_called_once_with( + name=operation_name + ) + @pytest.mark.need_serialized_dag @mock.patch(DATAPROC_PATH.format("DataprocHook"))